# Source code for gyoza.deployment.io_extractor

"""AST-based extractor for input/output specs from Python source files.

Parses a ``.py`` file using only the standard-library ``ast`` module so
that it works without any third-party dependencies installed (e.g. in CI
environments that lack pydantic, torch, etc.).

The extractor looks for:

1. A function decorated with ``@gyoza_op(...)`` to discover the
   ``input_model`` and ``output_model`` class names.
2. Class definitions whose names match the discovered model names
   (normally Pydantic ``BaseModel`` subclasses, although the base class
   itself is not checked), extracting their annotated fields.
"""

from __future__ import annotations

import ast
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from gyoza.models.op_definition import (
    InputSpec,
    InputSpecs,
    OutputSpec,
    OutputSpecs,
)

_TYPE_MAP: dict[str, str] = {
    "str": "string",
    "int": "int",
    "float": "float",
    "bool": "boolean",
    "bytes": "bytes",
    "list": "array",
    "dict": "object",
    "List": "array",
    "Dict": "object",
    "Optional": "optional",
    "Any": "any",
}


@dataclass
class _DecoratorInfo:
    input_model_name: str | None = None
    output_model_name: str | None = None


def _union_members(node: ast.expr) -> list[ast.expr]:
    """Flatten a PEP 604 union annotation (``A | B | C``) into its members.

    Parameters
    ----------
    node : ast.expr
        Annotation node; nested ``|`` operators are unrolled recursively.

    Returns
    -------
    list[ast.expr]
        Member annotation nodes in source order.
    """
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        return _union_members(node.left) + _union_members(node.right)
    return [node]


def _resolve_union(members: list[ast.expr]) -> str:
    """Resolve the members of a union to a simplified type string.

    ``X | None`` and ``Union[X, None]`` collapse to ``X``, which matches how
    ``Optional[X]`` is handled. Any other union has no single simplified
    representation and becomes ``"any"``.

    Parameters
    ----------
    members : list[ast.expr]
        Annotation nodes making up the union.

    Returns
    -------
    str
        Simplified type string.
    """
    non_none = [
        member
        for member in members
        if not (isinstance(member, ast.Constant) and member.value is None)
    ]
    if len(non_none) == 1:
        return _resolve_type(non_none[0])
    return "any"


def _resolve_type(node: ast.expr) -> str:
    """Convert an AST annotation node to a simplified type string.

    Handles the following forms:

    - plain names, dotted attributes and string (forward-ref) annotations,
      mapped through ``_TYPE_MAP``;
    - ``Optional[X]``, ``X | None`` and ``Union[X, None]``, which resolve
      to ``X``;
    - ``Annotated[X, ...]``, which resolves to ``X``;
    - other subscripts, rendered as ``base[arg1, arg2, ...]``.

    Anything else resolves to ``"any"``.

    Parameters
    ----------
    node : ast.expr
        The annotation AST node.

    Returns
    -------
    str
        Simplified type string (e.g. ``"string"``, ``"int"``,
        ``"object[string, int]"``).
    """
    if isinstance(node, ast.Name):
        return _TYPE_MAP.get(node.id, node.id)
    if isinstance(node, ast.Constant) and isinstance(node.value, str):
        return _TYPE_MAP.get(node.value, node.value)
    if isinstance(node, ast.Attribute):
        return _TYPE_MAP.get(node.attr, node.attr)
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        # PEP 604 union syntax, e.g. ``str | None``.
        return _resolve_union(_union_members(node))
    if isinstance(node, ast.Subscript):
        base = _resolve_type(node.value)
        if base == "optional":
            return _resolve_type(node.slice)
        # Multi-argument subscripts such as ``dict[str, int]`` carry an
        # ``ast.Tuple`` slice; resolve each element individually.
        if isinstance(node.slice, ast.Tuple):
            args = list(node.slice.elts)
        else:
            args = [node.slice]
        if base == "Union":
            return _resolve_union(args)
        if base == "Annotated" and args:
            # Annotated metadata does not change the underlying type.
            return _resolve_type(args[0])
        inner = ", ".join(_resolve_type(arg) for arg in args)
        return f"{base}[{inner}]"
    return "any"


def _is_gyoza_op(node: ast.expr) -> bool:
    if isinstance(node, ast.Call):
        return _is_gyoza_op(node.func)
    if isinstance(node, ast.Name):
        return node.id in {"gyoza_op", "GyozaOp"}
    if isinstance(node, ast.Attribute):
        return node.attr in {"gyoza_op", "GyozaOp"}
    return False


def _find_decorator_info(tree: ast.Module) -> _DecoratorInfo:
    """Return the model names from the first ``@gyoza_op(...)`` call.

    Functions are visited in ``ast.walk`` order, and only decorators
    written as a call (``@gyoza_op(...)``) are considered. Only
    ``input_model=`` / ``output_model=`` keywords whose value is a bare
    identifier are recorded.

    Parameters
    ----------
    tree : ast.Module
        Parsed module AST.

    Returns
    -------
    _DecoratorInfo
        Extracted ``input_model`` and ``output_model`` class names.

    Raises
    ------
    ValueError
        If no ``@gyoza_op(...)`` decorator is found.
    """
    function_kinds = (ast.FunctionDef, ast.AsyncFunctionDef)
    # Lazily enumerate candidate decorator calls; ``next`` stops the walk
    # at the first match.
    matches = (
        decorator
        for candidate in ast.walk(tree)
        if isinstance(candidate, function_kinds)
        for decorator in candidate.decorator_list
        if isinstance(decorator, ast.Call) and _is_gyoza_op(decorator)
    )
    call = next(matches, None)
    if call is None:
        msg = "No @gyoza_op decorator found in the file"
        raise ValueError(msg)

    result = _DecoratorInfo()
    for keyword in call.keywords:
        target = keyword.value
        if not isinstance(target, ast.Name):
            continue  # only plain class names can be looked up later
        if keyword.arg == "input_model":
            result.input_model_name = target.id
        elif keyword.arg == "output_model":
            result.output_model_name = target.id
    return result


def _bare_name(node: ast.expr) -> str | None:
    """Return the last identifier of a ``Name``/``Attribute`` node, else ``None``."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        return node.attr
    return None


def _is_class_var(annotation: ast.expr) -> bool:
    """Return ``True`` for ``ClassVar`` / ``ClassVar[...]`` annotations."""
    head = annotation.value if isinstance(annotation, ast.Subscript) else annotation
    return _bare_name(head) == "ClassVar"


def _field_default(value: ast.expr) -> tuple[bool, Any]:
    """Derive ``(required, default)`` from a field's assigned expression.

    Follows Pydantic conventions:

    - ``x: T = ...`` is required;
    - ``Field(...)`` is required, as is ``Field(...)`` with neither a
      default nor a ``default_factory``;
    - ``Field(<default>)`` / ``Field(default=...)`` use that default;
    - ``Field(default_factory=...)`` is optional with an unknown default.

    Defaults are only reported when they are Python literals (evaluated
    with ``ast.literal_eval``, so ``-1``, ``[]`` and ``{"a": 1}`` work).
    Other expressions give ``None``.

    Parameters
    ----------
    value : ast.expr
        Right-hand side of the annotated assignment.

    Returns
    -------
    tuple[bool, Any]
        ``(required, default)``.
    """
    if isinstance(value, ast.Call) and _bare_name(value.func) == "Field":
        default_node = value.args[0] if value.args else None
        has_factory = False
        for kw in value.keywords:
            if kw.arg == "default":
                default_node = kw.value
            elif kw.arg == "default_factory":
                has_factory = True
        if default_node is None:
            return not has_factory, None
        value = default_node
    try:
        default = ast.literal_eval(value)
    except (ValueError, TypeError):
        # Non-literal default (name, call, ...): optional, value unknown.
        return False, None
    if default is Ellipsis:
        # ``...`` is Pydantic's explicit "required" marker.
        return True, None
    return False, default


def _extract_class_fields(
    tree: ast.Module,
    class_name: str,
) -> dict[str, tuple[str, bool, Any]]:
    """Extract annotated fields from a class definition.

    The first class named ``class_name`` found by ``ast.walk`` is used.
    Annotated assignments are skipped when their target is underscore-
    prefixed or annotated with ``ClassVar``; Pydantic treats those as
    private or class attributes, not fields. Inherited fields are not
    collected.

    Parameters
    ----------
    tree : ast.Module
        Parsed module AST.
    class_name : str
        Name of the class to extract fields from.

    Returns
    -------
    dict[str, tuple[str, bool, Any]]
        Mapping of field name to ``(type_str, required, default)``.

    Raises
    ------
    ValueError
        If the class is not found.
    """
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef) or node.name != class_name:
            continue
        fields: dict[str, tuple[str, bool, Any]] = {}
        for stmt in node.body:
            if not isinstance(stmt, ast.AnnAssign) or not isinstance(
                stmt.target, ast.Name
            ):
                continue
            name = stmt.target.id
            if name.startswith("_") or _is_class_var(stmt.annotation):
                continue
            type_str = _resolve_type(stmt.annotation)
            if stmt.value is None:
                fields[name] = (type_str, True, None)
            else:
                required, default = _field_default(stmt.value)
                fields[name] = (type_str, required, default)
        return fields
    msg = f"Class '{class_name}' not found in the file"
    raise ValueError(msg)


class _IOFileParseError(SyntaxError, ValueError):
    """Raised when an IO file is not valid Python.

    It subclasses ``ValueError`` because that is what ``extract_io_specs``
    documents for unparsable files. It also subclasses ``SyntaxError``,
    which ``ast.parse`` raises, so handlers written for either type keep
    working.
    """


def extract_io_specs(io_file_path: Path | str) -> tuple[InputSpecs, OutputSpecs]:
    """Extract input and output specs from a Python file via AST analysis.

    Reads the file as text, parses it with ``ast.parse``, then locates the
    ``@gyoza_op(input_model=..., output_model=...)`` decorator and the
    corresponding model classes to build the specs. This function requires
    **no third-party imports** at parse time.

    Parameters
    ----------
    io_file_path : Path | str
        Path to the ``.py`` file containing the decorated function and
        model definitions.

    Returns
    -------
    tuple[InputSpecs, OutputSpecs]
        Extracted input and output specifications. A side whose model name
        was not found in the decorator gets an empty default spec object.

    Raises
    ------
    FileNotFoundError
        If the file does not exist.
    ValueError
        If the file is not valid UTF-8, cannot be parsed (the raised error
        is also a ``SyntaxError``), has no ``@gyoza_op(...)`` decorator, or
        lacks a referenced model class.
    """
    path = Path(io_file_path)
    if not path.exists():
        msg = f"IO file not found: {path}"
        raise FileNotFoundError(msg)

    source = path.read_text(encoding="utf-8")
    try:
        tree = ast.parse(source, filename=str(path))
    except SyntaxError as err:
        # Re-raise as a ValueError-compatible type so the documented
        # contract holds; reusing ``err.args`` keeps file/line details.
        raise _IOFileParseError(*err.args) from err

    info = _find_decorator_info(tree)

    input_specs = InputSpecs()
    if info.input_model_name:
        fields = _extract_class_fields(tree, info.input_model_name)
        input_specs = InputSpecs(
            specs={
                name: InputSpec(type=t, required=req, default=default)
                for name, (t, req, default) in fields.items()
            }
        )

    output_specs = OutputSpecs()
    if info.output_model_name:
        fields = _extract_class_fields(tree, info.output_model_name)
        output_specs = OutputSpecs(
            specs={name: OutputSpec(type=t) for name, (t, _, _) in fields.items()}
        )

    return input_specs, output_specs