feat: Implement async utilities for MCP server management and JSON-RPC communication

- Added `process.py` for managing MCP server subprocesses with async capabilities.
- Introduced `protocol.py` for handling JSON-RPC communication over streams.
- Created `llm_client.py` to support chat completion requests to various LLM providers, integrating with MCP tools.
- Defined model configurations in `llm_models.py` for different LLM providers.
- Removed the synchronous `mcp_manager.py` in favor of a more modular approach.
- Established a provider framework in `providers` directory with a base class and specific implementations.
- Implemented `OpenAIProvider` for interacting with OpenAI's API, including streaming support and tool call handling.
This commit is contained in:
2025-03-26 11:00:20 +00:00
parent a7d5a4cb33
commit 80ba05338f
14 changed files with 1749 additions and 273 deletions

71
src/providers/__init__.py Normal file
View File

@@ -0,0 +1,71 @@
# src/providers/__init__.py
import logging
from providers.base import BaseProvider
# Import specific provider implementations here as they are created
from providers.openai_provider import OpenAIProvider
# from .anthropic_provider import AnthropicProvider
# from .google_provider import GoogleProvider
# from .openrouter_provider import OpenRouterProvider
logger = logging.getLogger(__name__)
# Map provider names (lowercase) to their corresponding class implementations
PROVIDER_MAP: dict[str, type[BaseProvider]] = {
"openai": OpenAIProvider,
# "anthropic": AnthropicProvider,
# "google": GoogleProvider,
# "openrouter": OpenRouterProvider,
}
def register_provider(name: str, provider_class: type[BaseProvider]):
"""Registers a provider class."""
if name.lower() in PROVIDER_MAP:
logger.warning(f"Provider '{name}' is already registered. Overwriting.")
PROVIDER_MAP[name.lower()] = provider_class
logger.info(f"Registered provider: {name}")
def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseProvider:
"""
Factory function to create an instance of a specific LLM provider.
Args:
provider_name: The name of the provider (e.g., 'openai', 'anthropic').
api_key: The API key for the provider.
base_url: Optional base URL for the provider's API.
Returns:
An instance of the requested BaseProvider subclass.
Raises:
ValueError: If the requested provider_name is not registered.
"""
provider_class = PROVIDER_MAP.get(provider_name.lower())
if provider_class is None:
available = ", ".join(PROVIDER_MAP.keys()) or "None"
raise ValueError(f"Unsupported LLM provider: '{provider_name}'. Available providers: {available}")
logger.info(f"Creating LLM provider instance for: {provider_name}")
try:
return provider_class(api_key=api_key, base_url=base_url)
except Exception as e:
logger.error(f"Failed to instantiate provider '{provider_name}': {e}", exc_info=True)
raise RuntimeError(f"Could not create provider '{provider_name}'.") from e
def get_available_providers() -> list[str]:
"""Returns a list of registered provider names."""
return list(PROVIDER_MAP.keys())
# Example of how specific providers would register themselves if structured as plugins,
# but for now, we'll explicitly import and map them above.
# def load_providers():
# # Potentially load providers dynamically if designed as plugins
# pass
# load_providers()

140
src/providers/base.py Normal file
View File

@@ -0,0 +1,140 @@
# src/providers/base.py
import abc
from collections.abc import Generator
from typing import Any
class BaseProvider(abc.ABC):
"""
Abstract base class for LLM providers.
Defines the common interface for interacting with different LLM APIs,
including handling chat completions and tool usage.
"""
def __init__(self, api_key: str, base_url: str | None = None):
"""
Initialize the provider.
Args:
api_key: The API key for the provider.
base_url: Optional base URL for the provider's API.
"""
self.api_key = api_key
self.base_url = base_url
@abc.abstractmethod
def create_chat_completion(
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Any:
"""
Send a chat completion request to the LLM provider.
Args:
messages: List of message dictionaries with 'role' and 'content'.
model: Model identifier.
temperature: Sampling temperature (0-1).
max_tokens: Maximum tokens to generate.
stream: Whether to stream the response.
tools: Optional list of tools in the provider-specific format.
Returns:
Provider-specific response object (e.g., API response, stream object).
"""
pass
@abc.abstractmethod
def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
"""
Extracts and yields content chunks from a streaming response object.
Args:
response: The streaming response object returned by create_chat_completion.
Yields:
String chunks of the response content.
"""
pass
@abc.abstractmethod
def get_content(self, response: Any) -> str:
"""
Extracts the complete content from a non-streaming response object.
Args:
response: The non-streaming response object.
Returns:
The complete response content as a string.
"""
pass
@abc.abstractmethod
def has_tool_calls(self, response: Any) -> bool:
"""
Checks if the response object contains tool calls.
Args:
response: The response object (streaming or non-streaming).
Returns:
True if tool calls are present, False otherwise.
"""
pass
@abc.abstractmethod
def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
"""
Parses tool calls from the response object.
Args:
response: The response object containing tool calls.
Returns:
A list of dictionaries, each representing a tool call with details
like 'id', 'function_name', 'arguments'. The exact structure might
vary slightly based on provider needs but should contain enough
info for execution.
"""
pass
@abc.abstractmethod
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
"""
Formats the result of a tool execution into the structure expected
by the provider for follow-up requests.
Args:
tool_call_id: The unique ID of the tool call (from parse_tool_calls).
result: The data returned by the tool execution.
Returns:
A dictionary representing the tool result in the provider's format.
"""
pass
@abc.abstractmethod
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Converts a list of tools from the standard internal format to the
provider-specific format required for the API call.
Args:
tools: List of tool definitions in the standard internal format.
Each dict contains 'server_name', 'name', 'description', 'input_schema'.
Returns:
List of tool definitions in the provider-specific format.
"""
pass
# Optional: Add a method for follow-up completions if the provider API
# requires a specific structure different from just appending messages.
# def create_follow_up_completion(...) -> Any:
# pass

View File

@@ -0,0 +1,239 @@
# src/providers/openai_provider.py
import json
import logging
from collections.abc import Generator
from typing import Any
from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
from providers.base import BaseProvider
from src.llm_models import MODELS # Use absolute import
logger = logging.getLogger(__name__)
class OpenAIProvider(BaseProvider):
"""Provider implementation for OpenAI and compatible APIs."""
def __init__(self, api_key: str, base_url: str | None = None):
# Use default OpenAI endpoint if base_url is not provided
effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
super().__init__(api_key, effective_base_url)
logger.info(f"Initializing OpenAIProvider with base URL: {self.base_url}")
try:
# TODO: Add default headers like in original client?
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
except Exception as e:
logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
raise
def create_chat_completion(
self,
messages: list[dict[str, str]],
model: str,
temperature: float = 0.4,
max_tokens: int | None = None,
stream: bool = True,
tools: list[dict[str, Any]] | None = None,
) -> Stream[ChatCompletionChunk] | ChatCompletion:
"""Creates a chat completion using the OpenAI API."""
logger.debug(f"OpenAI create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
try:
completion_params = {
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
if tools:
completion_params["tools"] = tools
completion_params["tool_choice"] = "auto" # Let OpenAI decide when to use tools
# Remove None values like max_tokens if not provided
completion_params = {k: v for k, v in completion_params.items() if v is not None}
# --- Added Debug Logging ---
log_params = completion_params.copy()
# Avoid logging full messages if they are too long
if "messages" in log_params:
log_params["messages"] = [
{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
for msg in log_params["messages"][-2:] # Log last 2 messages summary
]
# Specifically log tools structure if present
tools_log = log_params.get("tools", "Not Present")
logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
logger.debug(f"Full API Params (messages summarized): {log_params}")
# --- End Added Debug Logging ---
response = self.client.chat.completions.create(**completion_params)
logger.debug("OpenAI API call successful.")
return response
except Exception as e:
logger.error(f"OpenAI API error: {e}", exc_info=True)
# Re-raise for the LLMClient to handle
raise
def get_streaming_content(self, response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
"""Yields content chunks from an OpenAI streaming response."""
logger.debug("Processing OpenAI stream...")
full_delta = ""
try:
for chunk in response:
delta = chunk.choices[0].delta.content
if delta:
full_delta += delta
yield delta
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
except Exception as e:
logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
# Yield an error message? Or let the generator stop?
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
def get_content(self, response: ChatCompletion) -> str:
"""Extracts content from a non-streaming OpenAI response."""
try:
content = response.choices[0].message.content
logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
return content or "" # Return empty string if content is None
except Exception as e:
logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
return f"[Error extracting content: {str(e)}]"
def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
"""Checks if the OpenAI response contains tool calls."""
try:
if isinstance(response, ChatCompletion): # Non-streaming
return bool(response.choices[0].message.tool_calls)
elif hasattr(response, "_iterator"): # Check if it looks like our stream wrapper
# This is tricky for streams. We'd need to peek at the first chunk(s)
# or buffer the response. For simplicity, this check might be unreliable
# for streams *before* they are consumed. LLMClient needs robust handling.
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
# A more robust check would involve consuming the start of the stream
# or relying on the structure after consumption.
return False # Assume no for unconsumed stream for now
else:
# If it's already consumed stream or unexpected type
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
return False
except Exception as e:
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
return False
def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
"""Parses tool calls from a non-streaming OpenAI response."""
# This implementation assumes a non-streaming response or a fully buffered stream
parsed_calls = []
try:
if not isinstance(response, ChatCompletion):
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
# Attempt to handle buffered stream if possible? Complex.
return []
tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
if not tool_calls:
return []
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
for call in tool_calls:
if call.type == "function":
# Attempt to parse server_name from function name if prefixed
# e.g., "server-name__actual-tool-name"
parts = call.function.name.split("__", 1)
if len(parts) == 2:
server_name, func_name = parts
else:
# If no prefix, how do we know the server? Needs refinement.
# Defaulting to None or a default server? Log warning.
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
server_name = None # Or raise error, or use a default?
func_name = call.function.name
parsed_calls.append({
"id": call.id,
"server_name": server_name, # May be None if not prefixed
"function_name": func_name,
"arguments": call.function.arguments, # Arguments are already a string here
})
else:
logger.warning(f"Unsupported tool call type: {call.type}")
return parsed_calls
except Exception as e:
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
return [] # Return empty list on error
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
"""Formats a tool result for an OpenAI follow-up request."""
# Result might be a dict (including potential errors) or simple string/number
# OpenAI expects the content to be a string, often JSON.
try:
if isinstance(result, dict):
content = json.dumps(result)
else:
content = str(result) # Ensure it's a string
except Exception as e:
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
logger.debug(f"Formatting tool result for call ID {tool_call_id}")
return {
"role": "tool",
"tool_call_id": tool_call_id,
"content": content,
}
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Converts internal tool format to OpenAI's format."""
openai_tools = []
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
for tool in tools:
server_name = tool.get("server_name")
tool_name = tool.get("name")
description = tool.get("description")
input_schema = tool.get("inputSchema")
if not server_name or not tool_name or not description or not input_schema:
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
continue
# Prefix tool name with server name to avoid clashes and allow routing
prefixed_tool_name = f"{server_name}__{tool_name}"
openai_tool_format = {
"type": "function",
"function": {
"name": prefixed_tool_name,
"description": description,
"parameters": input_schema, # OpenAI uses JSON Schema directly
},
}
openai_tools.append(openai_tool_format)
logger.debug(f"Converted tool: {prefixed_tool_name}")
return openai_tools
# Helper needed by LLMClient's current tool handling logic
def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
"""Extracts the assistant's message containing tool calls."""
try:
if isinstance(response, ChatCompletion) and response.choices[0].message.tool_calls:
message = response.choices[0].message
# Convert Pydantic model to dict for message history
return message.model_dump(exclude_unset=True)
else:
logger.warning("Could not extract original message with tool calls from response.")
# Return a placeholder or raise error?
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
except Exception as e:
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
# Register this provider (if using the registration mechanism)
# from . import register_provider
# register_provider("openai", OpenAIProvider)