"""AST-based extractor for input/output specs from Python source files.
Parses a ``.py`` file using only the standard-library ``ast`` module so
that it works without any third-party dependencies installed (e.g. in CI
environments that lack pydantic, torch, etc.).
The extractor looks for:
1. A function decorated with ``@gyoza_op(...)`` to discover the
``input_model`` and ``output_model`` class names.
2. Class definitions that inherit from ``BaseModel`` whose names match
the discovered model names, extracting their annotated fields.
"""
from __future__ import annotations
import ast
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from gyoza.models.op_definition import (
InputSpec,
InputSpecs,
OutputSpec,
OutputSpecs,
)
# Lookup table from Python annotation names to the simplified type strings
# returned by ``_resolve_type``. Names that are not listed here are passed
# through unchanged by the ``.get(name, name)`` lookups in that function.
_TYPE_MAP: dict[str, str] = {
    # Scalar builtins.
    "str": "string",
    "int": "int",
    "float": "float",
    "bool": "boolean",
    "bytes": "bytes",
    # Containers, in both lowercase and capitalised spellings.
    "list": "array",
    "dict": "object",
    "List": "array",
    "Dict": "object",
    # ``"optional"`` is a marker rather than a real type: ``_resolve_type``
    # unwraps a subscript whose base resolves to it and returns the inner type.
    "Optional": "optional",
    "Any": "any",
}
@dataclass
class _DecoratorInfo:
    """Model class names read from a ``@gyoza_op(...)`` decorator call.

    Both attributes default to ``None`` and are filled in by
    ``_find_decorator_info`` only when the matching keyword argument is a
    bare name.
    """

    # Identifier passed as the ``input_model`` keyword, if one was captured.
    input_model_name: str | None = None
    # Identifier passed as the ``output_model`` keyword, if one was captured.
    output_model_name: str | None = None
def _resolve_type(node: ast.expr) -> str:
    """Convert an AST annotation node to a simplified type string.

    Names, attribute accesses and string annotations are looked up in
    ``_TYPE_MAP``; names missing from the table are returned unchanged.
    Subscripts become ``base[inner]``, with several type arguments joined
    by ``", "`` (``dict[str, int]`` -> ``"object[string, int]"``).
    ``Optional[X]`` and the PEP 604 spelling ``X | None`` both collapse to
    the resolved ``X``.

    Parameters
    ----------
    node : ast.expr
        The annotation AST node.

    Returns
    -------
    str
        Simplified type string (e.g. ``"string"``, ``"array[int]"``), or
        ``"any"`` when the annotation shape is not recognised.
    """
    if isinstance(node, ast.Name):
        return _TYPE_MAP.get(node.id, node.id)
    if isinstance(node, ast.Constant) and isinstance(node.value, str):
        return _TYPE_MAP.get(node.value, node.value)
    if isinstance(node, ast.Attribute):
        # Only the final attribute matters: ``typing.List`` -> ``List``.
        return _TYPE_MAP.get(node.attr, node.attr)
    if isinstance(node, ast.Subscript):
        base = _resolve_type(node.value)
        args = node.slice
        if isinstance(args, ast.Tuple) and args.elts:
            # Several type arguments arrive as a tuple; resolve each one
            # instead of collapsing the whole tuple to "any".
            inner = ", ".join(_resolve_type(elt) for elt in args.elts)
        else:
            inner = _resolve_type(args)
        return inner if base == "optional" else f"{base}[{inner}]"
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        # PEP 604 optional: drop the ``None`` operand so that ``X | None``
        # (or ``None | X``) resolves exactly like ``Optional[X]``. Unions of
        # several real types still fall through to "any".
        for member, other in ((node.left, node.right), (node.right, node.left)):
            if isinstance(other, ast.Constant) and other.value is None:
                return _resolve_type(member)
    return "any"
def _is_gyoza_op(node: ast.expr) -> bool:
if isinstance(node, ast.Call):
return _is_gyoza_op(node.func)
if isinstance(node, ast.Name):
return node.id in {"gyoza_op", "GyozaOp"}
if isinstance(node, ast.Attribute):
return node.attr in {"gyoza_op", "GyozaOp"}
return False
def _find_decorator_info(tree: ast.Module) -> _DecoratorInfo:
    """Locate a ``@gyoza_op(...)`` decorator call and read its model names.

    Sync and async function definitions are visited in ``ast.walk`` order
    and the first decorator that is a call to ``gyoza_op``/``GyozaOp`` wins.
    Only keyword arguments whose value is a bare name are captured; any
    other value leaves the corresponding attribute as ``None``.

    Parameters
    ----------
    tree : ast.Module
        Parsed module AST.

    Returns
    -------
    _DecoratorInfo
        Extracted ``input_model`` and ``output_model`` class names.

    Raises
    ------
    ValueError
        If no ``@gyoza_op(...)`` decorator call is found.
    """
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    for candidate in ast.walk(tree):
        if not isinstance(candidate, function_types):
            continue
        for decorator in candidate.decorator_list:
            # A bare ``@gyoza_op`` (no call) carries no keywords, so skip it.
            if isinstance(decorator, ast.Call) and _is_gyoza_op(decorator):
                named_values = {
                    keyword.arg: keyword.value.id
                    for keyword in decorator.keywords
                    if isinstance(keyword.value, ast.Name)
                }
                return _DecoratorInfo(
                    input_model_name=named_values.get("input_model"),
                    output_model_name=named_values.get("output_model"),
                )
    msg = "No @gyoza_op decorator found in the file"
    raise ValueError(msg)
def _extract_class_fields(
tree: ast.Module,
class_name: str,
) -> dict[str, tuple[str, bool, Any]]:
"""Extract annotated fields from a class definition.
Parameters
----------
tree : ast.Module
Parsed module AST.
class_name : str
Name of the class to extract fields from.
Returns
-------
dict[str, tuple[str, bool, Any]]
Mapping of field name to ``(type_str, required, default)``.
Raises
------
ValueError
If the class is not found.
"""
for node in ast.walk(tree):
if not isinstance(node, ast.ClassDef) or node.name != class_name:
continue
fields: dict[str, tuple[str, bool, Any]] = {}
for stmt in node.body:
if not isinstance(stmt, ast.AnnAssign) or not isinstance(
stmt.target, ast.Name
):
continue
name = stmt.target.id
type_str = _resolve_type(stmt.annotation)
if stmt.value is not None:
default = (
stmt.value.value if isinstance(stmt.value, ast.Constant) else None
)
fields[name] = (type_str, False, default)
else:
fields[name] = (type_str, True, None)
return fields
msg = f"Class '{class_name}' not found in the file"
raise ValueError(msg)