Combined pydantic model with parameters and request body
This commit is contained in:
@@ -18,7 +18,7 @@ class OperationDetails:
|
|||||||
path: str
|
path: str
|
||||||
method: str
|
method: str
|
||||||
parameters: dict[str, Any]
|
parameters: dict[str, Any]
|
||||||
request_body: type[BaseModel] | None = None
|
input_model: type[BaseModel]
|
||||||
response_model: type[BaseModel] | None = None
|
response_model: type[BaseModel] | None = None
|
||||||
|
|
||||||
|
|
||||||
@@ -35,7 +35,6 @@ class OperationParser:
|
|||||||
ValueError: If spec_path is invalid or spec cannot be loaded
|
ValueError: If spec_path is invalid or spec cannot be loaded
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Load and parse OpenAPI spec
|
|
||||||
if isinstance(spec_path, bytes):
|
if isinstance(spec_path, bytes):
|
||||||
self.raw_spec = yaml.safe_load(spec_path)
|
self.raw_spec = yaml.safe_load(spec_path)
|
||||||
elif isinstance(spec_path, dict):
|
elif isinstance(spec_path, dict):
|
||||||
@@ -48,7 +47,6 @@ class OperationParser:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid spec_path type: {type(spec_path)}. Expected Path, str, dict, bytes or file-like object")
|
raise ValueError(f"Invalid spec_path type: {type(spec_path)}. Expected Path, str, dict, bytes or file-like object")
|
||||||
|
|
||||||
# Initialize OpenAPI spec
|
|
||||||
spec = OpenAPI.from_dict(self.raw_spec)
|
spec = OpenAPI.from_dict(self.raw_spec)
|
||||||
self.spec = spec
|
self.spec = spec
|
||||||
self._paths = self.raw_spec["paths"]
|
self._paths = self.raw_spec["paths"]
|
||||||
@@ -72,7 +70,6 @@ class OperationParser:
|
|||||||
ValueError: If operation not found or invalid
|
ValueError: If operation not found or invalid
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Find operation in spec
|
|
||||||
for path, path_item in self._paths.items():
|
for path, path_item in self._paths.items():
|
||||||
for method, operation in path_item.items():
|
for method, operation in path_item.items():
|
||||||
if method.startswith("x-") or method == "parameters":
|
if method.startswith("x-") or method == "parameters":
|
||||||
@@ -81,21 +78,30 @@ class OperationParser:
|
|||||||
if operation.get("operationId") == operation_id:
|
if operation.get("operationId") == operation_id:
|
||||||
logger.debug("Found operation %s at %s %s", operation_id, method, path)
|
logger.debug("Found operation %s at %s %s", operation_id, method, path)
|
||||||
|
|
||||||
# Add path to operation for parameter context
|
|
||||||
operation["path"] = path
|
operation["path"] = path
|
||||||
operation["path_item"] = path_item
|
operation["path_item"] = path_item
|
||||||
|
|
||||||
# Extract operation details
|
|
||||||
parameters = self.extract_parameters(operation)
|
parameters = self.extract_parameters(operation)
|
||||||
request_body = self._parse_request_body(operation)
|
|
||||||
response_model = self._parse_response_model(operation)
|
response_model = self._parse_response_model(operation)
|
||||||
|
|
||||||
|
# Get request body schema if present
|
||||||
|
body_schema = None
|
||||||
|
if "requestBody" in operation:
|
||||||
|
content = operation["requestBody"].get("content", {})
|
||||||
|
if "application/json" in content:
|
||||||
|
body_schema = content["application/json"].get("schema", {})
|
||||||
|
if "$ref" in body_schema:
|
||||||
|
body_schema = self._resolve_ref(body_schema["$ref"])
|
||||||
|
|
||||||
|
# Create unified input model
|
||||||
|
input_model = self._create_input_model(operation_id, parameters, body_schema)
|
||||||
|
|
||||||
return OperationDetails(
|
return OperationDetails(
|
||||||
operation_id=operation_id,
|
operation_id=operation_id,
|
||||||
path=str(path),
|
path=str(path),
|
||||||
method=method,
|
method=method,
|
||||||
parameters=parameters,
|
parameters=parameters,
|
||||||
request_body=request_body,
|
input_model=input_model,
|
||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -105,6 +111,37 @@ class OperationParser:
|
|||||||
logger.error("Error parsing operation %s: %s", operation_id, e)
|
logger.error("Error parsing operation %s: %s", operation_id, e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def _create_input_model(
|
||||||
|
self,
|
||||||
|
operation_id: str,
|
||||||
|
parameters: dict[str, Any],
|
||||||
|
body_schema: dict[str, Any] | None = None,
|
||||||
|
) -> type[BaseModel]:
|
||||||
|
"""Create unified input model for all parameters."""
|
||||||
|
fields: dict[str, tuple[type, Any]] = {}
|
||||||
|
|
||||||
|
# Add path parameters
|
||||||
|
for name, schema in parameters.get("path", {}).items():
|
||||||
|
field_type = schema["type"]
|
||||||
|
required = schema.get("required", True) # Path parameters are required by default
|
||||||
|
fields[f"path_{name}"] = (field_type, ... if required else None)
|
||||||
|
|
||||||
|
# Add query parameters
|
||||||
|
for name, schema in parameters.get("query", {}).items():
|
||||||
|
field_type = schema["type"]
|
||||||
|
required = schema.get("required", False) # Query parameters are optional by default
|
||||||
|
fields[f"query_{name}"] = (field_type, ... if required else None)
|
||||||
|
|
||||||
|
# Add body fields if present
|
||||||
|
if body_schema and body_schema.get("type") == "object":
|
||||||
|
for prop_name, prop_schema in body_schema.get("properties", {}).items():
|
||||||
|
field_type = self._map_type(prop_schema.get("type", "string"))
|
||||||
|
required = prop_name in body_schema.get("required", [])
|
||||||
|
fields[f"body_{prop_name}"] = (field_type, ... if required else None)
|
||||||
|
|
||||||
|
logger.debug("Creating input model for %s with fields: %s", operation_id, fields)
|
||||||
|
return create_model(f"{operation_id}_input", **fields)
|
||||||
|
|
||||||
def extract_parameters(self, operation: dict[str, Any]) -> dict[str, Any]:
|
def extract_parameters(self, operation: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Extract and categorize operation parameters.
|
"""Extract and categorize operation parameters.
|
||||||
|
|
||||||
@@ -120,12 +157,10 @@ class OperationParser:
|
|||||||
"header": {},
|
"header": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Handle path-level parameters
|
|
||||||
path_item = operation.get("path_item", {})
|
path_item = operation.get("path_item", {})
|
||||||
if path_item and "parameters" in path_item:
|
if path_item and "parameters" in path_item:
|
||||||
self._process_parameters(path_item["parameters"], parameters)
|
self._process_parameters(path_item["parameters"], parameters)
|
||||||
|
|
||||||
# Handle operation-level parameters
|
|
||||||
self._process_parameters(operation.get("parameters", []), parameters)
|
self._process_parameters(operation.get("parameters", []), parameters)
|
||||||
|
|
||||||
return parameters
|
return parameters
|
||||||
@@ -138,11 +173,9 @@ class OperationParser:
|
|||||||
target: Target dictionary to store processed parameters
|
target: Target dictionary to store processed parameters
|
||||||
"""
|
"""
|
||||||
for param in params:
|
for param in params:
|
||||||
# Resolve parameter reference if needed
|
|
||||||
if "$ref" in param:
|
if "$ref" in param:
|
||||||
param = self._resolve_ref(param["$ref"])
|
param = self._resolve_ref(param["$ref"])
|
||||||
|
|
||||||
# Validate parameter structure
|
|
||||||
if not isinstance(param, dict) or "in" not in param:
|
if not isinstance(param, dict) or "in" not in param:
|
||||||
logger.warning("Invalid parameter format: %s", param)
|
logger.warning("Invalid parameter format: %s", param)
|
||||||
continue
|
continue
|
||||||
@@ -165,7 +198,7 @@ class OperationParser:
|
|||||||
|
|
||||||
parts = ref.split("/")
|
parts = ref.split("/")
|
||||||
current = self.raw_spec
|
current = self.raw_spec
|
||||||
for part in parts[1:]: # Skip first '#'
|
for part in parts[1:]:
|
||||||
current = current[part]
|
current = current[part]
|
||||||
|
|
||||||
self._schema_cache[ref] = current
|
self._schema_cache[ref] = current
|
||||||
@@ -192,14 +225,7 @@ class OperationParser:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def _map_type(self, openapi_type: str) -> type:
|
def _map_type(self, openapi_type: str) -> type:
|
||||||
"""Map OpenAPI type to Python type.
|
"""Map OpenAPI type to Python type."""
|
||||||
|
|
||||||
Args:
|
|
||||||
openapi_type: OpenAPI type string
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Corresponding Python type
|
|
||||||
"""
|
|
||||||
type_map = {
|
type_map = {
|
||||||
"string": str,
|
"string": str,
|
||||||
"integer": int,
|
"integer": int,
|
||||||
@@ -210,28 +236,6 @@ class OperationParser:
|
|||||||
}
|
}
|
||||||
return type_map.get(openapi_type, Any)
|
return type_map.get(openapi_type, Any)
|
||||||
|
|
||||||
def _parse_request_body(self, operation: dict[str, Any]) -> type[BaseModel] | None:
|
|
||||||
"""Parse request body schema into Pydantic model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
operation: Operation object from OpenAPI spec
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Pydantic model for request body or None
|
|
||||||
"""
|
|
||||||
if "requestBody" not in operation:
|
|
||||||
return None
|
|
||||||
|
|
||||||
content = operation["requestBody"].get("content", {})
|
|
||||||
if "application/json" not in content:
|
|
||||||
return None
|
|
||||||
|
|
||||||
schema = content["application/json"].get("schema", {})
|
|
||||||
if "$ref" in schema:
|
|
||||||
schema = self._resolve_ref(schema["$ref"])
|
|
||||||
|
|
||||||
return self._create_model("RequestBody", schema)
|
|
||||||
|
|
||||||
def _parse_response_model(self, operation: dict[str, Any]) -> type[BaseModel] | None:
|
def _parse_response_model(self, operation: dict[str, Any]) -> type[BaseModel] | None:
|
||||||
"""Parse response schema into Pydantic model.
|
"""Parse response schema into Pydantic model.
|
||||||
|
|
||||||
@@ -259,23 +263,6 @@ class OperationParser:
|
|||||||
|
|
||||||
return self._create_model("Response", schema)
|
return self._create_model("Response", schema)
|
||||||
|
|
||||||
def get_operations(self) -> list[str]:
|
|
||||||
"""Get list of all operation IDs from spec.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of operation IDs
|
|
||||||
"""
|
|
||||||
operations = []
|
|
||||||
|
|
||||||
for path in self._paths.values():
|
|
||||||
for method, operation in path.items():
|
|
||||||
if method.startswith("x-") or method == "parameters":
|
|
||||||
continue
|
|
||||||
if "operationId" in operation:
|
|
||||||
operations.append(operation["operationId"])
|
|
||||||
|
|
||||||
return operations
|
|
||||||
|
|
||||||
def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]:
|
def _create_model(self, name: str, schema: dict[str, Any]) -> type[BaseModel]:
|
||||||
"""Create Pydantic model from schema.
|
"""Create Pydantic model from schema.
|
||||||
|
|
||||||
@@ -292,21 +279,18 @@ class OperationParser:
|
|||||||
if "$ref" in schema:
|
if "$ref" in schema:
|
||||||
schema = self._resolve_ref(schema["$ref"])
|
schema = self._resolve_ref(schema["$ref"])
|
||||||
|
|
||||||
if schema.get("type") != "object":
|
if schema.get("type", "object") != "object":
|
||||||
raise ValueError("Schema must be an object type")
|
raise ValueError("Schema must be an object type")
|
||||||
|
|
||||||
fields = {}
|
fields = {}
|
||||||
for prop_name, prop_schema in schema.get("properties", {}).items():
|
for prop_name, prop_schema in schema.get("properties", {}).items():
|
||||||
# Resolve property schema reference if needed
|
|
||||||
if "$ref" in prop_schema:
|
if "$ref" in prop_schema:
|
||||||
prop_schema = self._resolve_ref(prop_schema["$ref"])
|
prop_schema = self._resolve_ref(prop_schema["$ref"])
|
||||||
|
|
||||||
if prop_schema.get("type") == "object":
|
if prop_schema.get("type") == "object":
|
||||||
# Create nested model
|
|
||||||
nested_model = self._create_model(f"{name}_{prop_name}", prop_schema)
|
nested_model = self._create_model(f"{name}_{prop_name}", prop_schema)
|
||||||
field_type = nested_model
|
field_type = nested_model
|
||||||
elif prop_schema.get("type") == "array":
|
elif prop_schema.get("type") == "array":
|
||||||
# Handle array types
|
|
||||||
items = prop_schema.get("items", {})
|
items = prop_schema.get("items", {})
|
||||||
if "$ref" in items:
|
if "$ref" in items:
|
||||||
items = self._resolve_ref(items["$ref"])
|
items = self._resolve_ref(items["$ref"])
|
||||||
@@ -328,3 +312,16 @@ class OperationParser:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error creating model %s: %s", name, e)
|
logger.error("Error creating model %s: %s", name, e)
|
||||||
raise ValueError(f"Failed to create model {name}: {e}")
|
raise ValueError(f"Failed to create model {name}: {e}")
|
||||||
|
|
||||||
|
def get_operations(self) -> list[str]:
|
||||||
|
"""Get list of all operation IDs from spec."""
|
||||||
|
operations = []
|
||||||
|
|
||||||
|
for path in self._paths.values():
|
||||||
|
for method, operation in path.items():
|
||||||
|
if method.startswith("x-") or method == "parameters":
|
||||||
|
continue
|
||||||
|
if "operationId" in operation:
|
||||||
|
operations.append(operation["operationId"])
|
||||||
|
|
||||||
|
return operations
|
||||||
|
|||||||
@@ -35,4 +35,4 @@ async def serve() -> None:
|
|||||||
|
|
||||||
options = server.create_initialization_options()
|
options = server.create_initialization_options()
|
||||||
async with stdio_server() as (read_stream, write_stream):
|
async with stdio_server() as (read_stream, write_stream):
|
||||||
server.run(read_stream, write_stream, options, raise_exceptions=True)
|
await server.run(read_stream, write_stream, options, raise_exceptions=True)
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
from airflow_mcp_server.client.airflow_client import AirflowClient
|
from airflow_mcp_server.client.airflow_client import AirflowClient
|
||||||
from airflow_mcp_server.parser.operation_parser import OperationDetails
|
from airflow_mcp_server.parser.operation_parser import OperationDetails
|
||||||
from airflow_mcp_server.tools.base_tools import BaseTools
|
from airflow_mcp_server.tools.base_tools import BaseTools
|
||||||
from pydantic import BaseModel, ValidationError
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -48,108 +49,82 @@ class AirflowTool(BaseTools):
|
|||||||
self.operation = operation_details
|
self.operation = operation_details
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
def _validate_parameters(
|
def _validate_input(
|
||||||
self,
|
self,
|
||||||
path_params: dict[str, Any] | None = None,
|
path_params: dict[str, Any] | None = None,
|
||||||
query_params: dict[str, Any] | None = None,
|
query_params: dict[str, Any] | None = None,
|
||||||
body: dict[str, Any] | None = None,
|
body: dict[str, Any] | None = None,
|
||||||
) -> tuple[dict[str, Any] | None, dict[str, Any] | None, dict[str, Any] | None]:
|
) -> dict[str, Any]:
|
||||||
"""Validate input parameters against operation schemas.
|
"""Validate input parameters using unified input model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path_params: URL path parameters
|
path_params: Path parameters
|
||||||
query_params: URL query parameters
|
query_params: Query parameters
|
||||||
body: Request body data
|
body: Body parameters
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of validated (path_params, query_params, body)
|
dict[str, Any]: Validated input parameters
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValidationError: If parameters fail validation
|
|
||||||
"""
|
"""
|
||||||
validated_params: dict[str, dict[str, Any] | None] = {
|
|
||||||
"path": None,
|
|
||||||
"query": None,
|
|
||||||
"body": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Validate path parameters
|
input_data = {}
|
||||||
if path_params and "path" in self.operation.parameters:
|
|
||||||
path_schema = self.operation.parameters["path"]
|
|
||||||
for name, value in path_params.items():
|
|
||||||
if name in path_schema:
|
|
||||||
param_type = path_schema[name]["type"]
|
|
||||||
if not isinstance(value, param_type):
|
|
||||||
raise create_validation_error(
|
|
||||||
field=name,
|
|
||||||
message=f"Path parameter {name} must be of type {param_type.__name__}",
|
|
||||||
)
|
|
||||||
validated_params["path"] = path_params
|
|
||||||
|
|
||||||
# Validate query parameters
|
if path_params:
|
||||||
if query_params and "query" in self.operation.parameters:
|
input_data.update({f"path_{k}": v for k, v in path_params.items()})
|
||||||
query_schema = self.operation.parameters["query"]
|
|
||||||
for name, value in query_params.items():
|
if query_params:
|
||||||
if name in query_schema:
|
input_data.update({f"query_{k}": v for k, v in query_params.items()})
|
||||||
param_type = query_schema[name]["type"]
|
|
||||||
if not isinstance(value, param_type):
|
if body:
|
||||||
raise create_validation_error(
|
input_data.update({f"body_{k}": v for k, v in body.items()})
|
||||||
field=name,
|
|
||||||
message=f"Query parameter {name} must be of type {param_type.__name__}",
|
validated = self.operation.input_model(**input_data)
|
||||||
)
|
return validated.model_dump()
|
||||||
validated_params["query"] = query_params
|
|
||||||
|
|
||||||
# Validate request body
|
|
||||||
if body and self.operation.request_body:
|
|
||||||
try:
|
|
||||||
model: type[BaseModel] = self.operation.request_body
|
|
||||||
validated_body = model(**body)
|
|
||||||
validated_params["body"] = validated_body.model_dump()
|
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
# Re-raise Pydantic validation errors directly
|
logger.error("Input validation failed: %s", e)
|
||||||
raise e
|
|
||||||
|
|
||||||
return (
|
|
||||||
validated_params["path"],
|
|
||||||
validated_params["query"],
|
|
||||||
validated_params["body"],
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Parameter validation failed: %s", e)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def _extract_parameters(self, validated_input: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
|
||||||
|
"""Extract validated parameters by type."""
|
||||||
|
path_params = {}
|
||||||
|
query_params = {}
|
||||||
|
body = {}
|
||||||
|
|
||||||
|
# Extract parameters based on operation definition
|
||||||
|
for key, value in validated_input.items():
|
||||||
|
# Remove prefix from key if present
|
||||||
|
param_key = key
|
||||||
|
if key.startswith(("path_", "query_", "body_")):
|
||||||
|
param_key = key.split("_", 1)[1]
|
||||||
|
|
||||||
|
if key.startswith("path_"):
|
||||||
|
path_params[param_key] = value
|
||||||
|
elif key.startswith("query_"):
|
||||||
|
query_params[param_key] = value
|
||||||
|
elif key.startswith("body_"):
|
||||||
|
body[param_key] = value
|
||||||
|
else:
|
||||||
|
body[key] = value
|
||||||
|
|
||||||
|
return path_params, query_params, body
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
path_params: dict[str, Any] | None = None,
|
path_params: dict[str, Any] | None = None,
|
||||||
query_params: dict[str, Any] | None = None,
|
query_params: dict[str, Any] | None = None,
|
||||||
body: dict[str, Any] | None = None,
|
body: dict[str, Any] | None = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Execute the operation with provided parameters.
|
"""Execute the operation with provided parameters."""
|
||||||
|
|
||||||
Args:
|
|
||||||
path_params: URL path parameters
|
|
||||||
query_params: URL query parameters
|
|
||||||
body: Request body data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
API response data
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValidationError: If parameters fail validation
|
|
||||||
RuntimeError: If client execution fails
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Validate parameters
|
validated_input = self._validate_input(path_params, query_params, body)
|
||||||
validated_path_params, validated_query_params, validated_body = self._validate_parameters(path_params, query_params, body)
|
path_params, query_params, body = self._extract_parameters(validated_input)
|
||||||
|
|
||||||
# Execute operation
|
# Execute operation
|
||||||
response = await self.client.execute(
|
response = await self.client.execute(
|
||||||
operation_id=self.operation.operation_id,
|
operation_id=self.operation.operation_id,
|
||||||
path_params=validated_path_params,
|
path_params=path_params,
|
||||||
query_params=validated_query_params,
|
query_params=query_params,
|
||||||
body=validated_body,
|
body=body,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate response if model exists
|
# Validate response if model exists
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ def get_airflow_tools() -> list[Tool]:
|
|||||||
Tool(
|
Tool(
|
||||||
name=operation_id,
|
name=operation_id,
|
||||||
description=tool.operation.operation_id,
|
description=tool.operation.operation_id,
|
||||||
inputSchema=tool.operation.request_body.model_json_schema() if tool.operation.request_body else None,
|
inputSchema=tool.operation.input_model.model_json_schema(),
|
||||||
)
|
)
|
||||||
for operation_id, tool in _tools_cache.items()
|
for operation_id, tool in _tools_cache.items()
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -39,33 +39,32 @@ def test_parse_operation_with_path_params(parser: OperationParser) -> None:
|
|||||||
operation = parser.parse_operation("get_dag")
|
operation = parser.parse_operation("get_dag")
|
||||||
|
|
||||||
assert operation.path == "/dags/{dag_id}"
|
assert operation.path == "/dags/{dag_id}"
|
||||||
assert "dag_id" in operation.parameters["path"]
|
assert isinstance(operation.input_model, type(BaseModel))
|
||||||
param = operation.parameters["path"]["dag_id"]
|
|
||||||
assert isinstance(param["type"], type(str))
|
# Verify path parameter field exists
|
||||||
assert param["required"] is True
|
fields = operation.input_model.__annotations__
|
||||||
|
assert "path_dag_id" in fields
|
||||||
|
assert isinstance(fields["path_dag_id"], type(str))
|
||||||
|
|
||||||
|
|
||||||
def test_parse_operation_with_query_params(parser: OperationParser) -> None:
|
def test_parse_operation_with_query_params(parser: OperationParser) -> None:
|
||||||
"""Test parsing operation with query parameters."""
|
"""Test parsing operation with query parameters."""
|
||||||
operation = parser.parse_operation("get_dags")
|
operation = parser.parse_operation("get_dags")
|
||||||
|
|
||||||
assert "limit" in operation.parameters["query"]
|
# Verify query parameter field exists
|
||||||
param = operation.parameters["query"]["limit"]
|
fields = operation.input_model.__annotations__
|
||||||
assert isinstance(param["type"], type(int))
|
assert "query_limit" in fields
|
||||||
assert param["required"] is False
|
assert isinstance(fields["query_limit"], type(int))
|
||||||
|
|
||||||
|
|
||||||
def test_parse_operation_with_request_body(parser: OperationParser) -> None:
|
def test_parse_operation_with_body_params(parser: OperationParser) -> None:
|
||||||
"""Test parsing operation with request body."""
|
"""Test parsing operation with request body."""
|
||||||
operation = parser.parse_operation("post_dag_run")
|
operation = parser.parse_operation("post_dag_run")
|
||||||
|
|
||||||
assert operation.request_body is not None
|
# Verify body fields exist
|
||||||
assert issubclass(operation.request_body, BaseModel)
|
fields = operation.input_model.__annotations__
|
||||||
|
assert "body_dag_run_id" in fields
|
||||||
# Test model fields
|
assert isinstance(fields["body_dag_run_id"], type(str))
|
||||||
fields = operation.request_body.__annotations__
|
|
||||||
assert "dag_run_id" in fields
|
|
||||||
assert isinstance(fields["dag_run_id"], type(str))
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_operation_with_response_model(parser: OperationParser) -> None:
|
def test_parse_operation_with_response_model(parser: OperationParser) -> None:
|
||||||
@@ -140,9 +139,7 @@ def test_create_model_nested_objects(parser: OperationParser) -> None:
|
|||||||
assert issubclass(model, BaseModel)
|
assert issubclass(model, BaseModel)
|
||||||
fields = model.__annotations__
|
fields = model.__annotations__
|
||||||
assert "nested" in fields
|
assert "nested" in fields
|
||||||
# Check that nested field is a Pydantic model
|
|
||||||
assert issubclass(fields["nested"], BaseModel)
|
assert issubclass(fields["nested"], BaseModel)
|
||||||
# Verify nested model structure
|
|
||||||
nested_fields = fields["nested"].__annotations__
|
nested_fields = fields["nested"].__annotations__
|
||||||
assert "field" in nested_fields
|
assert "field" in nested_fields
|
||||||
assert isinstance(nested_fields["field"], type(str))
|
assert isinstance(nested_fields["field"], type(str))
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ def operation_details():
|
|||||||
"filter": {"type": str, "required": False},
|
"filter": {"type": str, "required": False},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
request_body=TestRequestModel,
|
input_model=TestRequestModel,
|
||||||
response_model=TestResponseModel,
|
response_model=TestResponseModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ from pydantic import BaseModel
|
|||||||
class TestRequestModel(BaseModel):
|
class TestRequestModel(BaseModel):
|
||||||
"""Test request model."""
|
"""Test request model."""
|
||||||
|
|
||||||
name: str
|
path_id: int
|
||||||
value: int
|
query_filter: str | None = None
|
||||||
|
body_name: str
|
||||||
|
body_value: int
|
||||||
|
|
||||||
|
|
||||||
class TestResponseModel(BaseModel):
|
class TestResponseModel(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user