"""Repository for OpRun aggregate persistence.
This is the only place that knows about MongoDB. It handles serialization
and deserialization of the OpRun aggregate (including its attempts).
"""
from datetime import UTC, datetime
from typing import Any
from bson import ObjectId
from pymongo import UpdateOne
from pymongo.collection import Collection
from gyoza.models import (
Constraints,
EventDelivery,
EventEntry,
ExecutionSummary,
OpAttempt,
OpRun,
OpRunState,
RetryPolicy,
)
class OpRunRepository:
    """Repository for OpRun aggregate persistence.

    Handles both OpRun and its attempts as a single aggregate.
    Attempts are stored in a separate collection but loaded together.

    Parameters
    ----------
    runs_collection : Collection
        MongoDB collection for OpRuns.
    attempts_collection : Collection
        MongoDB collection for OpAttempts.
    events_collection : Collection | None
        Optional MongoDB collection for the topic event feed. When provided,
        events from runs that declared an ``event_delivery.topic`` are mirrored
        here so external consumers can tail them by topic with a cursor.
    """

    def __init__(
        self,
        runs_collection: Collection,
        attempts_collection: Collection,
        events_collection: Collection | None = None,
    ) -> None:
        self._runs = runs_collection
        self._attempts = attempts_collection
        self._events = events_collection

    def get(self, run_id: str) -> OpRun | None:
        """Retrieve an OpRun aggregate by its ID.

        Loads the run and all its attempts, reconstructing the full aggregate.

        Parameters
        ----------
        run_id : str
            The unique identifier of the OpRun.

        Returns
        -------
        OpRun | None
            The complete OpRun aggregate if found, None otherwise.
        """
        run_doc = self._runs.find_one({"_id": run_id})
        if not run_doc:
            return None
        return self._load_full_aggregate(run_doc)

    def save(self, run: OpRun) -> None:
        """Persist an OpRun aggregate (insert or update).

        Upserts the run document, then each of its attempts, then mirrors the
        run's events into the topic feed (if configured).

        The writes are issued one after another without a transaction. An
        error part-way through propagates to the caller and can leave the run
        saved while some attempts or feed events are not.

        Parameters
        ----------
        run : OpRun
            The OpRun aggregate to persist.
        """
        run_doc = self._op_run_to_doc(run)
        self._runs.replace_one({"_id": run.id}, run_doc, upsert=True)
        for attempt in run._attempts:
            attempt_doc = self._attempt_to_doc(attempt)
            self._attempts.replace_one(
                {"_id": attempt.id}, attempt_doc, upsert=True
            )
        # Mirror events into the topic feed (best-effort, idempotent)
        self._mirror_events(run)

    def _mirror_events(self, run: OpRun) -> None:
        """Append a run's events to the topic feed, idempotently.

        Only runs that declared an ``event_delivery.topic`` are mirrored, and
        only when an events collection was configured.

        Each event has a deterministic ``_id`` (``<attempt_id>:<index>``) so
        repeated saves never create duplicates. Every field is written with
        ``$setOnInsert``, so values such as ``seq`` and the run ``state``
        reflect the moment the event was first mirrored and are never
        updated afterwards. ``seq`` is an ObjectId that consumers use as a
        cursor.

        NOTE: ``seq`` is generated client-side (``ObjectId()`` in this
        process), not assigned by the server. Events written concurrently by
        different processes are therefore not guaranteed to become visible in
        ``seq`` order, so a consumer tailing with ``seq > cursor`` can
        occasionally skip one. This is why the feed is best-effort: the
        durable source of truth for whether a run finished is the run's
        ``state`` (queried via the runs endpoint for reconciliation), not
        this feed. Write errors from ``bulk_write`` are not caught here.

        Parameters
        ----------
        run : OpRun
            The run whose events should be mirrored.
        """
        if self._events is None:
            return
        topic = run.event_delivery.topic
        if not topic:
            return
        ops: list[UpdateOne] = []
        for attempt in run._attempts:
            for index, event in enumerate(attempt.events):
                ops.append(
                    UpdateOne(
                        {"_id": f"{attempt.id}:{index}"},
                        {
                            "$setOnInsert": {
                                "seq": ObjectId(),
                                "topic": topic,
                                "run_id": run.id,
                                "attempt": attempt.attempt,
                                "type": event.type,
                                "msg": event.msg,
                                "t": event.t,
                                "state": run.state.value,
                                "payload": dict(event.payload),
                            }
                        },
                        upsert=True,
                    )
                )
        if ops:
            self._events.bulk_write(ops, ordered=False)

    def list_events_by_topic(
        self,
        *,
        topic: str,
        after: str | None = None,
        limit: int = 100,
    ) -> tuple[list[dict[str, Any]], str | None, bool]:
        """Tail the event feed for a topic using a cursor.

        Parameters
        ----------
        topic : str
            The event delivery topic to read.
        after : str | None
            Cursor: only events after this ``seq`` are returned. Pass the
            ``next_cursor`` from the previous page to resume.
        limit : int
            Maximum number of events to return (1–500). The range is not
            enforced here.

        Returns
        -------
        tuple[list[dict[str, Any]], str | None, bool]
            The page of events (oldest first).
            The next cursor: the last returned event's ``seq``, or, when the
            page is empty, ``after`` unchanged (``None`` if no cursor was
            given).
            Whether more events likely exist.

        Raises
        ------
        ValueError
            If the feed is not configured or the cursor is malformed.
        """
        if self._events is None:
            raise ValueError("Event feed is not configured")
        query: dict[str, Any] = {"topic": topic}
        if after:
            # is_valid reports malformed ids without raising, so only that
            # failure mode is turned into a ValueError.
            if not ObjectId.is_valid(after):
                raise ValueError(f"Invalid cursor '{after}'")
            query["seq"] = {"$gt": ObjectId(after)}
        # Fetch one extra to know whether more events exist beyond this page.
        docs = list(self._events.find(query).sort("seq", 1).limit(limit + 1))
        has_more = len(docs) > limit
        events = [self._feed_doc_to_event(doc) for doc in docs[:limit]]
        # On an empty page hand back the caller's cursor, so tailing resumes
        # from the same position instead of restarting from the beginning.
        next_cursor = events[-1]["cursor"] if events else after
        return events, next_cursor, has_more

    def delete(self, run_id: str) -> bool:
        """Delete an OpRun and all its attempts.

        Events already mirrored into the topic feed are not removed.

        Parameters
        ----------
        run_id : str
            The unique identifier of the OpRun to delete.

        Returns
        -------
        bool
            True if the run was deleted, False if not found.
        """
        self._attempts.delete_many({"op_run_id": run_id})
        result = self._runs.delete_one({"_id": run_id})
        return result.deleted_count > 0

    def list_by_state(self, state: OpRunState) -> list[OpRun]:
        """List all OpRuns with a specific state.

        Parameters
        ----------
        state : OpRunState
            The state to filter by.

        Returns
        -------
        list[OpRun]
            List of OpRuns matching the state.
        """
        run_docs = list(self._runs.find({"state": state.value}))
        return self._load_full_aggregates(run_docs)

    def list_cursor(
        self,
        *,
        limit: int = 10,
        starting_after: str | None = None,
        ending_before: str | None = None,
        filters: dict[str, Any] | None = None,
    ) -> list[OpRun]:
        """List OpRuns using cursor-based pagination.

        Results are sorted by ``created_at`` descending (newest first), with
        ``_id`` descending as a tie-breaker. Runs that share a ``created_at``
        value are therefore neither skipped nor repeated across pages.
        Only one of ``starting_after`` / ``ending_before`` may be supplied.

        Parameters
        ----------
        limit : int
            Maximum number of results to return (1–100). Not enforced here;
            note that MongoDB treats a limit of 0 as "no limit".
        starting_after : str | None
            Cursor ID; returns items ordered *after* this run (next page,
            i.e. older).
        ending_before : str | None
            Cursor ID; returns items ordered *before* this run (previous page,
            i.e. newer).
        filters : dict[str, Any] | None
            Explicit MongoDB filters to apply with pagination. The caller's
            dict is not mutated.

        Returns
        -------
        list[OpRun]
            Page of OpRun aggregates, newest first.

        Raises
        ------
        ValueError
            If both cursors are provided, the referenced cursor is not found,
            or a cursor is used while ``filters["created_at"]`` is not a
            range dict.
        """
        if starting_after and ending_before:
            raise ValueError("starting_after and ending_before are mutually exclusive")
        query: dict[str, Any] = dict(filters or {})
        if starting_after:
            ref = self._find_cursor_ref(starting_after)
            if not ref:
                raise ValueError(f"Cursor '{starting_after}' not found")
            self._apply_cursor_filter(query, ref, older=True)
        if ending_before:
            ref = self._find_cursor_ref(ending_before)
            if not ref:
                raise ValueError(f"Cursor '{ending_before}' not found")
            self._apply_cursor_filter(query, ref, older=False)
            # Fetch the closest newer runs in ascending order, then reverse so
            # the caller always receives newest-first within the page.
            run_docs = list(
                self._runs.find(query)
                .sort([("created_at", 1), ("_id", 1)])
                .limit(limit)
            )
            run_docs.reverse()
        else:
            run_docs = list(
                self._runs.find(query)
                .sort([("created_at", -1), ("_id", -1)])
                .limit(limit)
            )
        return self._load_full_aggregates(run_docs)

    def has_more(
        self,
        *,
        direction: str,
        cursor_id: str,
        filters: dict[str, Any] | None = None,
    ) -> bool:
        """Check whether more items exist beyond a cursor.

        Uses the same ``(created_at, _id)`` ordering as :meth:`list_cursor`.

        Parameters
        ----------
        direction : str
            ``"after"`` to check for older items. Any other value (expected:
            ``"before"``) checks for newer items.
        cursor_id : str
            The ``_id`` of the boundary run.
        filters : dict[str, Any] | None
            Explicit MongoDB filters to apply with pagination.

        Returns
        -------
        bool
            True if at least one more document exists in the given direction;
            False if it does not or the cursor run does not exist.
        """
        ref = self._find_cursor_ref(cursor_id)
        if not ref:
            return False
        query: dict[str, Any] = dict(filters or {})
        self._apply_cursor_filter(query, ref, older=direction == "after")
        return self._runs.count_documents(query, limit=1) > 0

    def list_pending_ordered(self) -> list[OpRun]:
        """List pending OpRuns ordered by priority (descending).

        Returns
        -------
        list[OpRun]
            Pending OpRuns sorted by priority (highest first).
        """
        run_docs = list(
            self._runs.find({"state": OpRunState.PENDING.value}).sort("priority", -1)
        )
        return self._load_full_aggregates(run_docs)

    # -------------------------------------------------------------------------
    # Internal helpers
    # -------------------------------------------------------------------------

    @staticmethod
    def _merge_range_filter(
        query: dict[str, Any], field: str, range_filter: dict[str, Any]
    ) -> None:
        """Merge a range filter into an existing query field.

        Operators in ``range_filter`` override same-named operators already
        present for ``field``.

        Raises
        ------
        ValueError
            If ``query[field]`` exists and is not a dict (e.g. an equality
            match).
        """
        existing = query.get(field)
        if existing is None:
            query[field] = dict(range_filter)
            return
        if not isinstance(existing, dict):
            raise ValueError(f"Filter for '{field}' must be a range object")
        query[field] = {**existing, **range_filter}

    def _find_cursor_ref(self, cursor_id: str) -> dict[str, Any] | None:
        """Fetch the boundary run's ``_id`` and ``created_at`` (or None)."""
        return self._runs.find_one({"_id": cursor_id}, {"created_at": 1})

    def _apply_cursor_filter(
        self, query: dict[str, Any], ref: dict[str, Any], *, older: bool
    ) -> None:
        """Restrict ``query`` to runs strictly older/newer than ``ref``.

        Ordering is lexicographic on ``(created_at, _id)``. The inclusive
        ``created_at`` bound reuses :meth:`_merge_range_filter` (and its
        validation). The ``$or`` then excludes the cursor run itself and any
        same-instant run on the wrong side of it.
        """
        created_at = ref["created_at"]
        cursor_id = ref["_id"]
        bound, strict = ("$lte", "$lt") if older else ("$gte", "$gt")
        self._merge_range_filter(query, "created_at", {bound: created_at})
        tie_break = {
            "$or": [
                {"created_at": {strict: created_at}},
                {"_id": {strict: cursor_id}},
            ]
        }
        # Build a new list so a caller-supplied "$and" is never mutated.
        query["$and"] = [*query.get("$and", []), tie_break]

    @staticmethod
    def _feed_doc_to_event(doc: dict[str, Any]) -> dict[str, Any]:
        """Convert a stored feed document into the public event dict."""
        t = doc["t"]
        return {
            "cursor": str(doc["seq"]),
            "run_id": doc["run_id"],
            "attempt": doc["attempt"],
            "type": doc["type"],
            "msg": doc["msg"],
            "t": t.isoformat() if hasattr(t, "isoformat") else t,
            "state": doc["state"],
            "payload": doc.get("payload", {}),
        }

    def _load_full_aggregate(self, run_doc: dict[str, Any]) -> OpRun:
        """Load a full aggregate from a run document."""
        attempt_docs = list(
            self._attempts.find({"op_run_id": run_doc["_id"]}).sort("attempt", 1)
        )
        return self._reconstruct_aggregate(run_doc, attempt_docs)

    def _load_full_aggregates(self, run_docs: list[dict[str, Any]]) -> list[OpRun]:
        """Load full aggregates for many runs with a single attempts query.

        Preserves the order of ``run_docs``. Attempts are fetched in ascending
        ``attempt`` order and grouped by ``op_run_id``, so each run's list
        stays sorted.
        """
        if not run_docs:
            return []
        run_ids = [doc["_id"] for doc in run_docs]
        attempts_by_run: dict[Any, list[dict[str, Any]]] = {}
        attempt_cursor = self._attempts.find({"op_run_id": {"$in": run_ids}})
        for attempt_doc in attempt_cursor.sort("attempt", 1):
            attempts_by_run.setdefault(attempt_doc["op_run_id"], []).append(
                attempt_doc
            )
        return [
            self._reconstruct_aggregate(doc, attempts_by_run.get(doc["_id"], []))
            for doc in run_docs
        ]

    def _reconstruct_aggregate(
        self, run_doc: dict[str, Any], attempt_docs: list[dict[str, Any]]
    ) -> OpRun:
        """Reconstruct an OpRun aggregate from its run and attempt documents.

        Missing optional fields fall back to defaults. If no attempts are
        stored, ``OpRun._create_attempt()`` is called so the returned run
        has at least one attempt. Otherwise the last attempt becomes the
        current one.
        """
        constraints_doc = run_doc.get("constraints", {})
        retry_doc = run_doc.get("retry_policy", {})
        event_delivery_doc = run_doc.get("event_delivery", {})
        # One fallback timestamp, so a document missing both fields gets
        # matching created_at/updated_at values.
        now = datetime.now(UTC)
        run = OpRun(
            id=run_doc["_id"],
            state=OpRunState(run_doc["state"]),
            priority=run_doc.get("priority", 0),
            image=run_doc["image"],
            inputs=run_doc.get("inputs", {}),
            constraints=Constraints.from_dict(constraints_doc),
            retry_policy=RetryPolicy.from_dict(retry_doc),
            event_delivery=EventDelivery.from_dict(event_delivery_doc),
            op_definition=run_doc.get("op_definition"),
            created_at=run_doc.get("created_at", now),
            updated_at=run_doc.get("updated_at", now),
            _attempts=[],
            _current_attempt_index=-1,
        )
        for attempt_doc in attempt_docs:
            run._attempts.append(self._doc_to_attempt(attempt_doc))
        if not run._attempts:
            run._create_attempt()
        else:
            run._current_attempt_index = len(run._attempts) - 1
        return run

    def _op_run_to_doc(self, run: OpRun) -> dict[str, Any]:
        """Convert an OpRun to its MongoDB document (attempts excluded)."""
        doc: dict[str, Any] = {
            "_id": run.id,
            "state": run.state.value,
            "priority": run.priority,
            "image": run.image,
            "inputs": run.inputs,
            "constraints": run.constraints.to_dict(),
            "retry_policy": run.retry_policy.to_dict(),
            "event_delivery": run.event_delivery.to_dict(),
            "op_definition": run.op_definition,
            "created_at": run.created_at,
            "updated_at": run.updated_at,
        }
        return doc

    def _attempt_to_doc(self, attempt: OpAttempt) -> dict[str, Any]:
        """Convert an attempt to its MongoDB document."""
        return {
            "_id": attempt.id,
            "op_run_id": attempt.op_run_id,
            "attempt": attempt.attempt,
            "state": attempt.state.value,
            "progress": attempt.progress,
            "events": [e.to_dict() for e in attempt.events],
            "inputs": attempt.inputs,
            "outputs": attempt.outputs,
            "execution_summary": attempt.execution_summary.to_dict(),
            "constraints": attempt.constraints.to_dict(),
            "started_at": attempt.started_at,
            "finished_at": attempt.finished_at,
        }

    def _doc_to_attempt(self, doc: dict[str, Any]) -> OpAttempt:
        """Convert a MongoDB document to an attempt, defaulting missing fields."""
        events = [EventEntry.from_dict(e) for e in doc.get("events", [])]
        return OpAttempt(
            id=doc["_id"],
            op_run_id=doc["op_run_id"],
            attempt=doc["attempt"],
            state=OpRunState(doc["state"]),
            progress=doc.get("progress", 0),
            events=events,
            inputs=doc.get("inputs", {}),
            outputs=doc.get("outputs", {}),
            execution_summary=ExecutionSummary.from_dict(
                doc.get("execution_summary", {})
            ),
            constraints=Constraints.from_dict(doc.get("constraints", {})),
            started_at=doc.get("started_at"),
            finished_at=doc.get("finished_at"),
        )