"""Worker router - worker management endpoints."""
from typing import Any
from fastapi import APIRouter, status
from gyoza.server.api.models import (
ClaimOpsResponse,
ErrorResponse,
HeartbeatRequest,
WorkerResponse,
)
from gyoza.server.scheduler import scheduler
from gyoza.server.worker import GPU, Heartbeat, Resources, worker_pool
router = APIRouter()
[docs]
@router.get(
    "",
    response_model=list[WorkerResponse],
    responses={200: {"description": "List of workers"}},
)
def list_workers() -> list[dict[str, Any]]:
    """Return every worker currently held by the worker pool.

    The pool is queried through ``worker_pool.list_all()`` and no filtering
    is applied in this handler, so whatever the pool returns is serialized.

    Returns
    -------
    list[dict[str, Any]]
        One ``to_dict()`` representation per worker, in the order the pool
        yields them.
    """
    return [entry.to_dict() for entry in worker_pool.list_all()]
[docs]
@router.post(
    "/heartbeat",
    response_model=WorkerResponse,
    status_code=status.HTTP_200_OK,
    responses={
        200: {"description": "Heartbeat received and worker updated"},
        400: {"model": ErrorResponse, "description": "Invalid heartbeat data"},
    },
)
def heartbeat(request: HeartbeatRequest) -> dict[str, Any]:
    """Translate a heartbeat request into domain objects and submit it.

    The request's resources, tags and (optional) running ops are converted
    into ``GPU``, ``Resources``, ``WorkerOpRun`` and ``Heartbeat`` objects,
    then handed to ``worker_pool.heartbeat``. The pool is expected to
    register unknown workers and refresh known ones; that logic lives in
    the pool, not in this handler.

    Parameters
    ----------
    request : HeartbeatRequest
        Payload carrying ``worker_id``, ``resources`` (``cpu_cores``,
        ``ram_mb``, ``gpus``), ``tags`` and optionally ``running_ops``.

    Returns
    -------
    dict[str, Any]
        ``to_dict()`` of the worker returned by the pool.

    Examples
    --------
    >>> POST /workers/heartbeat
    >>> {
    ...     "worker_id": "worker-gpu-01",
    ...     "resources": {
    ...         "cpu_cores": 8,
    ...         "ram_mb": 32768,
    ...         "gpus": [{"id": 0, "vram_mb": 8192, "tags": ["cuda"]}]
    ...     },
    ...     "tags": ["gpu", "high-mem"],
    ...     "running_ops": []
    ... }
    """
    reported = request.resources

    # Build the domain-level resource description from the request model.
    gpu_devices = [
        GPU(id=device.id, vram_mb=device.vram_mb, tags=device.tags)
        for device in reported.gpus
    ]
    resources = Resources(
        cpu_cores=reported.cpu_cores,
        ram_mb=reported.ram_mb,
        gpus=gpu_devices,
    )

    # ``None`` (field omitted) is forwarded as-is; an explicit list, even an
    # empty one, is converted element by element.
    if request.running_ops is None:
        running_ops = None
    else:
        # Imported lazily, only on the path that needs it.
        from gyoza.server.worker import WorkerOpRun

        running_ops = [
            WorkerOpRun.from_dict(entry) for entry in request.running_ops
        ]

    beat = Heartbeat(
        worker_id=request.worker_id,
        resources=resources,
        tags=request.tags,
        running_ops=running_ops,
    )
    return worker_pool.heartbeat(beat).to_dict()
[docs]
@router.post(
    "/{worker_id}/claim",
    response_model=ClaimOpsResponse,
    status_code=status.HTTP_200_OK,
    responses={200: {"description": "Ops claimed successfully"}},
)
def claim_ops(worker_id: str) -> dict[str, Any]:
    """Ask the scheduler for work on behalf of a worker.

    The handler delegates entirely to ``scheduler.allocate_for_worker``;
    which ops are chosen is decided there, not here.

    Parameters
    ----------
    worker_id : str
        ID of the worker requesting ops, taken from the URL path.

    Returns
    -------
    dict[str, Any]
        A mapping with a single ``"ops"`` key holding the ``to_dict()``
        form of every op the scheduler allocated.

    Examples
    --------
    Illustrative response shape (actual fields come from ``to_dict()``):

    >>> POST /workers/worker-gpu-01/claim
    >>> Response:
    >>> {
    ...     "ops": [
    ...         {
    ...             "id": "run_abc123",
    ...             "image": "geoiahub/product:v1",
    ...             "inputs": {"param": "value"},
    ...             "constraints": {"ram_mb": 4096, "vram_mb": 2048}
    ...         }
    ...     ]
    ... }
    """
    allocated = scheduler.allocate_for_worker(worker_id)
    claimed = [run.to_dict() for run in allocated]
    return {"ops": claimed}