From 95d351a123f56444af3191e6ad880b440f865884 Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Sun, 23 Feb 2025 06:13:33 +0000 Subject: [PATCH] Complex schema compositions with `allOf`, `oneOf`, or `anyOf` --- .../parser/operation_parser.py | 43 ++++++++++++++++--- .../tests/parser/test_operation_parser.py | 31 +++++++++++++ 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py b/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py index fdddd5b..482048d 100644 --- a/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py +++ b/airflow-mcp-server/src/airflow_mcp_server/parser/operation_parser.py @@ -57,6 +57,30 @@ class OperationParser: logger.error("Error initializing OperationParser: %s", e) raise ValueError(f"Failed to initialize parser: {e}") from e + def _merge_allof_schema(self, schema: dict[str, Any]) -> dict[str, Any]: + """Merge an allOf schema into a single effective schema. + + Args: + schema: The schema potentially containing allOf + + Returns: + A merged schema with unified properties and required fields + """ + if "allOf" not in schema: + return schema + merged = {"type": "object", "properties": {}, "required": []} + for subschema in schema["allOf"]: + resolved = subschema + if "$ref" in subschema: + resolved = self._resolve_ref(subschema["$ref"]) + if "properties" in resolved: + merged["properties"].update(resolved["properties"]) + if "required" in resolved: + merged["required"].extend(resolved.get("required", [])) + merged["required"] = list(set(merged["required"])) # Remove duplicates + logger.debug("Merged allOf schema: %s", merged) + return merged + def parse_operation(self, operation_id: str) -> OperationDetails: """Parse operation details from OpenAPI spec. @@ -125,12 +149,19 @@ class OperationParser: fields[name] = (field_type | None, None) # Make all optional parameter_mapping["query"].append(name) - # Add body fields if present - if body_schema and body_schema.get("type") == "object": - for prop_name, prop_schema in body_schema.get("properties", {}).items(): - field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"), prop_schema) - fields[prop_name] = (field_type | None, None) # Make all optional - parameter_mapping["body"].append(prop_name) + # Handle body schema with allOf support + if body_schema: + effective_schema = self._merge_allof_schema(body_schema) + if "properties" in effective_schema or effective_schema.get("type") == "object": + for prop_name, prop_schema in effective_schema.get("properties", {}).items(): + field_type = self._map_type(prop_schema.get("type", "string"), prop_schema.get("format"), prop_schema) + default = None if prop_name in effective_schema.get("required", []) else None + if prop_name == "schema": # Avoid shadowing BaseModel.schema + fields["connection_schema"] = (field_type | None, Field(default, alias="schema")) + parameter_mapping["body"].append("connection_schema") + else: + fields[prop_name] = (field_type | None, default) + parameter_mapping["body"].append(prop_name) logger.debug("Creating input model for %s with fields: %s", operation_id, fields) model = create_model(f"{operation_id}_input", **fields) diff --git a/airflow-mcp-server/tests/parser/test_operation_parser.py b/airflow-mcp-server/tests/parser/test_operation_parser.py index a726a89..ab02873 100644 --- a/airflow-mcp-server/tests/parser/test_operation_parser.py +++ b/airflow-mcp-server/tests/parser/test_operation_parser.py @@ -141,3 +141,34 @@ def test_create_model_nested_objects(parser: OperationParser) -> None: nested_fields = fields["nested"].__annotations__ assert "field" in nested_fields assert isinstance(nested_fields["field"], type(str)) + + +def test_parse_operation_with_allof_body(parser: OperationParser) -> None: + """Test parsing operation with allOf schema in request body.""" + operation = parser.parse_operation("test_connection") + + assert isinstance(operation, OperationDetails) + assert operation.operation_id == "test_connection" + assert operation.path == "/connections/test" + assert operation.method == "post" + + # Verify input model includes fields from allOf schema + fields = operation.input_model.__annotations__ + assert "connection_id" in fields, "Missing connection_id from ConnectionCollectionItem" + assert str in fields["connection_id"].__args__, "connection_id should be a string" + assert "password" in fields, "Missing password from Connection" + assert str in fields["password"].__args__, "password should be a string" + assert "connection_schema" in fields, "Missing schema field (aliased as connection_schema)" + assert str in fields["connection_schema"].__args__, "connection_schema should be a string" + + # Verify parameter mapping + mapping = operation.input_model.model_config["parameter_mapping"] + assert "body" in mapping + assert "connection_id" in mapping["body"] + assert "password" in mapping["body"] + assert "connection_schema" in mapping["body"] + + # Verify alias configuration + model_fields = operation.input_model.model_fields + assert "connection_schema" in model_fields + assert model_fields["connection_schema"].alias == "schema", "connection_schema should alias to schema"