# Source code for gyoza.server.op_run.repository

"""Repository for OpRun aggregate persistence.

This is the only place that knows about MongoDB. It handles serialization
and deserialization of the OpRun aggregate (including its attempts).
"""

from datetime import UTC, datetime
from typing import Any

from bson import ObjectId
from pymongo import UpdateOne
from pymongo.collection import Collection

from gyoza.models import (
    Constraints,
    EventDelivery,
    EventEntry,
    ExecutionSummary,
    OpAttempt,
    OpRun,
    OpRunState,
    RetryPolicy,
)


class OpRunRepository:
    """Repository for OpRun aggregate persistence.

    Handles both OpRun and its attempts as a single aggregate. Attempts are
    stored in a separate collection but loaded together.

    Parameters
    ----------
    runs_collection : Collection
        MongoDB collection for OpRuns.
    attempts_collection : Collection
        MongoDB collection for OpAttempts.
    events_collection : Collection | None
        Optional MongoDB collection for the topic event feed. When provided,
        events from runs that declared an ``event_delivery.topic`` are
        mirrored here so external consumers can tail them by topic with a
        cursor.
    """

    def __init__(
        self,
        runs_collection: Collection,
        attempts_collection: Collection,
        events_collection: Collection | None = None,
    ) -> None:
        self._runs = runs_collection
        self._attempts = attempts_collection
        self._events = events_collection

    def get(self, run_id: str) -> OpRun | None:
        """Retrieve an OpRun aggregate by its ID.

        Loads the run and all its attempts, reconstructing the full aggregate.

        Parameters
        ----------
        run_id : str
            The unique identifier of the OpRun.

        Returns
        -------
        OpRun | None
            The complete OpRun aggregate if found, None otherwise.
        """
        run_doc = self._runs.find_one({"_id": run_id})
        if not run_doc:
            return None
        return self._load_full_aggregate(run_doc)

    def save(self, run: OpRun) -> None:
        """Persist an OpRun aggregate (insert or update).

        Upserts the run document, then every attempt currently held by the
        aggregate, then mirrors events into the topic feed. Each write is
        issued separately (no session/transaction is used here), so a failure
        part-way leaves the earlier writes in place. Attempt documents that
        are no longer present in ``run._attempts`` are not removed.

        Parameters
        ----------
        run : OpRun
            The OpRun aggregate to persist.
        """
        self._runs.replace_one({"_id": run.id}, self._op_run_to_doc(run), upsert=True)

        for attempt in run._attempts:
            self._attempts.replace_one(
                {"_id": attempt.id}, self._attempt_to_doc(attempt), upsert=True
            )

        # Mirror events into the topic feed (best-effort, idempotent)
        self._mirror_events(run)

    def _mirror_events(self, run: OpRun) -> None:
        """Append a run's events to the topic feed, idempotently.

        Only runs that declared an ``event_delivery.topic`` are mirrored, and
        only when an events collection was configured.

        Each event gets a deterministic ``_id`` (``<attempt_id>:<index>``), and
        all fields are written with ``$setOnInsert``, so repeated saves never
        create duplicates or overwrite an already-mirrored event. This relies
        on an attempt's event list being append-only; if events were
        reordered or removed, an index would refer to a different event.

        ``seq`` is an ObjectId generated client-side when the update is built,
        and it is persisted only on first insert. Consumers use it as a
        cursor. NOTE: ObjectIds embed a second-resolution timestamp followed
        by per-process bytes, and they are generated before the write
        commits. Ordering across concurrent writers is therefore only
        approximate, and a consumer tailing with ``seq > cursor`` may miss an
        event written concurrently by another writer. ``state`` records the
        run's state at the moment the event was first mirrored, which is not
        necessarily its state when the event occurred.

        This is intentionally best-effort: the durable source of truth for
        whether a run finished is the run's ``state`` (queried via the runs
        endpoint for reconciliation), not this feed.

        Parameters
        ----------
        run : OpRun
            The run whose events should be mirrored.
        """
        if self._events is None:
            return

        topic = run.event_delivery.topic
        if not topic:
            return

        ops: list[UpdateOne] = []
        for attempt in run._attempts:
            for index, event in enumerate(attempt.events):
                ops.append(
                    UpdateOne(
                        {"_id": f"{attempt.id}:{index}"},
                        {
                            "$setOnInsert": {
                                "seq": ObjectId(),
                                "topic": topic,
                                "run_id": run.id,
                                "attempt": attempt.attempt,
                                "type": event.type,
                                "msg": event.msg,
                                "t": event.t,
                                "state": run.state.value,
                                "payload": dict(event.payload),
                            }
                        },
                        upsert=True,
                    )
                )

        if ops:
            # Unordered: one failing upsert does not stop the others.
            self._events.bulk_write(ops, ordered=False)

    def list_events_by_topic(
        self,
        *,
        topic: str,
        after: str | None = None,
        limit: int = 100,
    ) -> tuple[list[dict[str, Any]], str | None, bool]:
        """Tail the event feed for a topic using a cursor.

        Parameters
        ----------
        topic : str
            The event delivery topic to read.
        after : str | None
            Cursor: only events whose ``seq`` is greater than this value are
            returned. Pass the ``next_cursor`` from the previous page to
            resume. ``None`` or an empty string reads from the beginning.
        limit : int
            Maximum number of events to return (1–500). Values below 1 are
            treated as 1; the upper bound is not enforced here.

        Returns
        -------
        tuple[list[dict[str, Any]], str | None, bool]
            The page of events (oldest first), the next cursor, and whether
            at least one more event existed beyond this page at query time.
            The next cursor is the last returned event's ``cursor``. When the
            page is empty it is ``after`` unchanged (``None`` if no cursor was
            given), so a consumer can keep polling from the same position.

        Raises
        ------
        ValueError
            If the feed is not configured or the cursor is malformed.
        """
        if self._events is None:
            raise ValueError("Event feed is not configured")

        # A limit below 1 would make ``limit + 1`` zero or negative. pymongo
        # treats limit(0) as "no limit", which would turn this into an
        # unbounded scan.
        limit = max(1, limit)

        query: dict[str, Any] = {"topic": topic}
        if after:
            if not ObjectId.is_valid(after):
                raise ValueError(f"Invalid cursor '{after}'")
            query["seq"] = {"$gt": ObjectId(after)}

        # Fetch one extra to know whether more events exist beyond this page.
        docs = list(self._events.find(query).sort("seq", 1).limit(limit + 1))
        has_more = len(docs) > limit

        events: list[dict[str, Any]] = []
        for doc in docs[:limit]:
            t = doc["t"]
            events.append(
                {
                    "cursor": str(doc["seq"]),
                    "run_id": doc["run_id"],
                    "attempt": doc["attempt"],
                    "type": doc["type"],
                    "msg": doc["msg"],
                    "t": t.isoformat() if hasattr(t, "isoformat") else t,
                    "state": doc["state"],
                    "payload": doc.get("payload", {}),
                }
            )

        next_cursor = events[-1]["cursor"] if events else after
        return events, next_cursor, has_more

    def delete(self, run_id: str) -> bool:
        """Delete an OpRun and all its attempts.

        Attempts are deleted first (even if the run itself is missing).
        Events already mirrored into the topic feed are not touched.

        Parameters
        ----------
        run_id : str
            The unique identifier of the OpRun to delete.

        Returns
        -------
        bool
            True if the run was deleted, False if not found.
        """
        self._attempts.delete_many({"op_run_id": run_id})
        result = self._runs.delete_one({"_id": run_id})
        return result.deleted_count > 0

    def list_by_state(self, state: OpRunState) -> list[OpRun]:
        """List all OpRuns with a specific state.

        Parameters
        ----------
        state : OpRunState
            The state to filter by.

        Returns
        -------
        list[OpRun]
            List of OpRuns matching the state.
        """
        run_docs = list(self._runs.find({"state": state.value}))
        return self._load_full_aggregates(run_docs)

    def list_cursor(
        self,
        *,
        limit: int = 10,
        starting_after: str | None = None,
        ending_before: str | None = None,
        filters: dict[str, Any] | None = None,
    ) -> list[OpRun]:
        """List OpRuns using cursor-based pagination.

        Results are sorted by ``created_at`` descending (newest first).
        Only one of ``starting_after`` / ``ending_before`` may be supplied.

        NOTE(review): the cursor boundary compares ``created_at`` strictly,
        so runs that share the cursor run's exact ``created_at`` are left out
        of the adjacent page. A ``(created_at, _id)`` keyset would close this
        gap but needs a matching compound index. Confirm the index layout
        before changing it.

        Parameters
        ----------
        limit : int
            Maximum number of results to return (1–100). Values below 1 are
            treated as 1; the upper bound is not enforced here.
        starting_after : str | None
            Cursor ID; returns items created *before* this run (next page).
        ending_before : str | None
            Cursor ID; returns items created *after* this run (previous page).
        filters : dict[str, Any] | None
            Explicit MongoDB filters to apply with pagination.

        Returns
        -------
        list[OpRun]
            Page of OpRun aggregates.

        Raises
        ------
        ValueError
            If both cursors are provided, the referenced cursor is not found,
            or ``filters`` holds a non-range value for ``created_at`` while a
            cursor is given.
        """
        if starting_after and ending_before:
            raise ValueError("starting_after and ending_before are mutually exclusive")

        # pymongo treats limit(0) as "no limit", so never let an out-of-range
        # value turn a page request into an unbounded scan.
        limit = max(1, limit)

        query: dict[str, Any] = dict(filters or {})

        if ending_before:
            ref = self._runs.find_one({"_id": ending_before}, {"created_at": 1})
            if not ref:
                raise ValueError(f"Cursor '{ending_before}' not found")
            self._merge_range_filter(query, "created_at", {"$gt": ref["created_at"]})
            # Take the ``limit`` runs closest to the cursor by fetching
            # ascending, then reverse so the caller always receives
            # newest-first within the page.
            ascending = self._runs.find(query).sort("created_at", 1).limit(limit)
            run_docs = list(ascending)[::-1]
        else:
            if starting_after:
                ref = self._runs.find_one({"_id": starting_after}, {"created_at": 1})
                if not ref:
                    raise ValueError(f"Cursor '{starting_after}' not found")
                self._merge_range_filter(
                    query, "created_at", {"$lt": ref["created_at"]}
                )
            run_docs = list(self._runs.find(query).sort("created_at", -1).limit(limit))

        return self._load_full_aggregates(run_docs)

    def has_more(
        self,
        *,
        direction: str,
        cursor_id: str,
        filters: dict[str, Any] | None = None,
    ) -> bool:
        """Check whether more items exist beyond a cursor.

        Parameters
        ----------
        direction : str
            ``"after"`` to check for older items; any other value checks for
            newer items.
        cursor_id : str
            The ``_id`` of the boundary run.
        filters : dict[str, Any] | None
            Explicit MongoDB filters to apply with pagination.

        Returns
        -------
        bool
            True if at least one more document exists in the given direction.
            False if the boundary run does not exist.
        """
        ref = self._runs.find_one({"_id": cursor_id}, {"created_at": 1})
        if not ref:
            return False

        query: dict[str, Any] = dict(filters or {})
        operator = "$lt" if direction == "after" else "$gt"
        self._merge_range_filter(query, "created_at", {operator: ref["created_at"]})

        # limit=1 lets the server stop counting at the first match.
        return self._runs.count_documents(query, limit=1) > 0

    def list_pending_ordered(self) -> list[OpRun]:
        """List pending OpRuns ordered by priority (descending).

        Only ``priority`` is sorted on, so the relative order of runs with
        equal priority is not specified by this query.

        Returns
        -------
        list[OpRun]
            Pending OpRuns sorted by priority (highest first).
        """
        run_docs = list(
            self._runs.find({"state": OpRunState.PENDING.value}).sort("priority", -1)
        )
        return self._load_full_aggregates(run_docs)

    # -------------------------------------------------------------------------
    # Internal helpers
    # -------------------------------------------------------------------------

    @staticmethod
    def _merge_range_filter(
        query: dict[str, Any], field: str, range_filter: dict[str, Any]
    ) -> None:
        """Merge a range filter into an existing query field, in place.

        Raises
        ------
        ValueError
            If ``query[field]`` already holds a non-dict (e.g. an equality
            value) that a range cannot be merged into.
        """
        existing = query.get(field)
        if existing is None:
            query[field] = dict(range_filter)
            return
        if not isinstance(existing, dict):
            raise ValueError(f"Filter for '{field}' must be a range object")
        # Keys in ``range_filter`` win over same-named operators already present.
        query[field] = {**existing, **range_filter}

    def _load_full_aggregate(self, run_doc: dict[str, Any]) -> OpRun:
        """Load a full aggregate from a single run document."""
        attempt_docs = list(
            self._attempts.find({"op_run_id": run_doc["_id"]}).sort("attempt", 1)
        )
        return self._reconstruct_aggregate(run_doc, attempt_docs)

    def _load_full_aggregates(self, run_docs: list[dict[str, Any]]) -> list[OpRun]:
        """Load full aggregates for many run documents, preserving their order.

        Fetches the attempts of every run with one ``$in`` query instead of one
        query per run. Each run's attempts are then ordered by ``attempt``
        ascending, matching ``_load_full_aggregate``.
        """
        if not run_docs:
            return []

        attempts_by_run: dict[Any, list[dict[str, Any]]] = {}
        run_ids = [run_doc["_id"] for run_doc in run_docs]
        for attempt_doc in self._attempts.find({"op_run_id": {"$in": run_ids}}):
            attempts_by_run.setdefault(attempt_doc["op_run_id"], []).append(
                attempt_doc
            )

        return [
            self._reconstruct_aggregate(
                run_doc,
                sorted(
                    attempts_by_run.get(run_doc["_id"], []),
                    key=lambda attempt_doc: attempt_doc["attempt"],
                ),
            )
            for run_doc in run_docs
        ]

    def _reconstruct_aggregate(
        self, run_doc: dict[str, Any], attempt_docs: list[dict[str, Any]]
    ) -> OpRun:
        """Reconstruct an OpRun aggregate from its run and attempt documents.

        ``attempt_docs`` is expected in ascending ``attempt`` order; the last
        one becomes the current attempt. When no attempts are stored, the
        aggregate's ``_create_attempt`` is called to create one in memory.
        """
        # Fallback for documents missing timestamps. Computed once so both
        # fields agree when both are absent.
        now = datetime.now(UTC)

        run = OpRun(
            id=run_doc["_id"],
            state=OpRunState(run_doc["state"]),
            priority=run_doc.get("priority", 0),
            image=run_doc["image"],
            inputs=run_doc.get("inputs", {}),
            # ``or {}`` also covers sub-documents stored as null.
            constraints=Constraints.from_dict(run_doc.get("constraints") or {}),
            retry_policy=RetryPolicy.from_dict(run_doc.get("retry_policy") or {}),
            event_delivery=EventDelivery.from_dict(
                run_doc.get("event_delivery") or {}
            ),
            op_definition=run_doc.get("op_definition"),
            created_at=run_doc.get("created_at", now),
            updated_at=run_doc.get("updated_at", now),
            _attempts=[],
            _current_attempt_index=-1,
        )

        run._attempts.extend(self._doc_to_attempt(doc) for doc in attempt_docs)

        if run._attempts:
            run._current_attempt_index = len(run._attempts) - 1
        else:
            run._create_attempt()

        return run

    def _op_run_to_doc(self, run: OpRun) -> dict[str, Any]:
        """Convert an OpRun to its MongoDB document (attempts excluded)."""
        return {
            "_id": run.id,
            "state": run.state.value,
            "priority": run.priority,
            "image": run.image,
            "inputs": run.inputs,
            "constraints": run.constraints.to_dict(),
            "retry_policy": run.retry_policy.to_dict(),
            "event_delivery": run.event_delivery.to_dict(),
            "op_definition": run.op_definition,
            "created_at": run.created_at,
            "updated_at": run.updated_at,
        }

    def _attempt_to_doc(self, attempt: OpAttempt) -> dict[str, Any]:
        """Convert an attempt to its MongoDB document."""
        return {
            "_id": attempt.id,
            "op_run_id": attempt.op_run_id,
            "attempt": attempt.attempt,
            "state": attempt.state.value,
            "progress": attempt.progress,
            "events": [e.to_dict() for e in attempt.events],
            "inputs": attempt.inputs,
            "outputs": attempt.outputs,
            "execution_summary": attempt.execution_summary.to_dict(),
            "constraints": attempt.constraints.to_dict(),
            "started_at": attempt.started_at,
            "finished_at": attempt.finished_at,
        }

    def _doc_to_attempt(self, doc: dict[str, Any]) -> OpAttempt:
        """Convert a MongoDB document back into an attempt."""
        # ``or`` fallbacks also cover fields stored as null.
        events = [EventEntry.from_dict(e) for e in doc.get("events") or []]
        return OpAttempt(
            id=doc["_id"],
            op_run_id=doc["op_run_id"],
            attempt=doc["attempt"],
            state=OpRunState(doc["state"]),
            progress=doc.get("progress", 0),
            events=events,
            inputs=doc.get("inputs", {}),
            outputs=doc.get("outputs", {}),
            execution_summary=ExecutionSummary.from_dict(
                doc.get("execution_summary") or {}
            ),
            constraints=Constraints.from_dict(doc.get("constraints") or {}),
            started_at=doc.get("started_at"),
            finished_at=doc.get("finished_at"),
        )