feat: Implement async utilities for MCP server management and JSON-RPC communication
- Added `process.py` for managing MCP server subprocesses with async capabilities.
- Introduced `protocol.py` for handling JSON-RPC communication over streams.
- Created `llm_client.py` to support chat completion requests to various LLM providers, integrating with MCP tools.
- Defined model configurations in `llm_models.py` for different LLM providers.
- Removed the synchronous `mcp_manager.py` in favor of a more modular approach.
- Established a provider framework in the `providers` directory with a base class and specific implementations.
- Implemented `OpenAIProvider` for interacting with OpenAI's API, including streaming support and tool call handling.
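For reviewers, a minimal end-to-end sketch of how these pieces are meant to be wired together. The model name and environment variable below are placeholders and not defined by this commit; only the factory and provider methods come from the code under review.

# Hypothetical wiring sketch (not part of this commit): create a provider via the
# factory from providers/__init__.py and stream a completion through it.
import os

from providers import create_llm_provider

provider = create_llm_provider("openai", api_key=os.environ["OPENAI_API_KEY"])
response = provider.create_chat_completion(
    messages=[{"role": "user", "content": "Hello"}],
    model="gpt-4o-mini",  # placeholder model id
    stream=True,
)
for chunk in provider.get_streaming_content(response):
    print(chunk, end="", flush=True)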
src/providers/__init__.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# src/providers/__init__.py
import logging

from providers.base import BaseProvider

# Import specific provider implementations here as they are created
from providers.openai_provider import OpenAIProvider

# from .anthropic_provider import AnthropicProvider
# from .google_provider import GoogleProvider
# from .openrouter_provider import OpenRouterProvider

logger = logging.getLogger(__name__)

# Map provider names (lowercase) to their corresponding class implementations
PROVIDER_MAP: dict[str, type[BaseProvider]] = {
    "openai": OpenAIProvider,
    # "anthropic": AnthropicProvider,
    # "google": GoogleProvider,
    # "openrouter": OpenRouterProvider,
}


def register_provider(name: str, provider_class: type[BaseProvider]):
    """Registers a provider class."""
    if name.lower() in PROVIDER_MAP:
        logger.warning(f"Provider '{name}' is already registered. Overwriting.")
    PROVIDER_MAP[name.lower()] = provider_class
    logger.info(f"Registered provider: {name}")


def create_llm_provider(provider_name: str, api_key: str, base_url: str | None = None) -> BaseProvider:
    """
    Factory function to create an instance of a specific LLM provider.

    Args:
        provider_name: The name of the provider (e.g., 'openai', 'anthropic').
        api_key: The API key for the provider.
        base_url: Optional base URL for the provider's API.

    Returns:
        An instance of the requested BaseProvider subclass.

    Raises:
        ValueError: If the requested provider_name is not registered.
    """
    provider_class = PROVIDER_MAP.get(provider_name.lower())

    if provider_class is None:
        available = ", ".join(PROVIDER_MAP.keys()) or "None"
        raise ValueError(f"Unsupported LLM provider: '{provider_name}'. Available providers: {available}")

    logger.info(f"Creating LLM provider instance for: {provider_name}")
    try:
        return provider_class(api_key=api_key, base_url=base_url)
    except Exception as e:
        logger.error(f"Failed to instantiate provider '{provider_name}': {e}", exc_info=True)
        raise RuntimeError(f"Could not create provider '{provider_name}'.") from e


def get_available_providers() -> list[str]:
    """Returns a list of registered provider names."""
    return list(PROVIDER_MAP.keys())


# Example of how specific providers would register themselves if structured as plugins,
# but for now, we'll explicitly import and map them above.
# def load_providers():
#     # Potentially load providers dynamically if designed as plugins
#     pass
# load_providers()
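A quick illustration of the registration hook above. The "LocalOpenAI" name and localhost URL are made up for the sketch; the point is that register_provider stores the class under a lowercase key, which the factory then resolves case-insensitively.

# Sketch: register an existing provider class under an extra name for an
# OpenAI-compatible endpoint, then resolve it through the factory.
from providers import create_llm_provider, get_available_providers, register_provider
from providers.openai_provider import OpenAIProvider

register_provider("LocalOpenAI", OpenAIProvider)  # stored under "localopenai"
print(get_available_providers())  # ['openai', 'localopenai']

local = create_llm_provider("localopenai", api_key="sk-placeholder", base_url="http://localhost:8000/v1")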
src/providers/base.py (new file, 140 lines)
@@ -0,0 +1,140 @@
# src/providers/base.py
import abc
from collections.abc import Generator
from typing import Any


class BaseProvider(abc.ABC):
    """
    Abstract base class for LLM providers.

    Defines the common interface for interacting with different LLM APIs,
    including handling chat completions and tool usage.
    """

    def __init__(self, api_key: str, base_url: str | None = None):
        """
        Initialize the provider.

        Args:
            api_key: The API key for the provider.
            base_url: Optional base URL for the provider's API.
        """
        self.api_key = api_key
        self.base_url = base_url

    @abc.abstractmethod
    def create_chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        stream: bool = True,
        tools: list[dict[str, Any]] | None = None,
    ) -> Any:
        """
        Send a chat completion request to the LLM provider.

        Args:
            messages: List of message dictionaries with 'role' and 'content'.
            model: Model identifier.
            temperature: Sampling temperature (0-1).
            max_tokens: Maximum tokens to generate.
            stream: Whether to stream the response.
            tools: Optional list of tools in the provider-specific format.

        Returns:
            Provider-specific response object (e.g., API response, stream object).
        """
        pass

    @abc.abstractmethod
    def get_streaming_content(self, response: Any) -> Generator[str, None, None]:
        """
        Extracts and yields content chunks from a streaming response object.

        Args:
            response: The streaming response object returned by create_chat_completion.

        Yields:
            String chunks of the response content.
        """
        pass

    @abc.abstractmethod
    def get_content(self, response: Any) -> str:
        """
        Extracts the complete content from a non-streaming response object.

        Args:
            response: The non-streaming response object.

        Returns:
            The complete response content as a string.
        """
        pass

    @abc.abstractmethod
    def has_tool_calls(self, response: Any) -> bool:
        """
        Checks if the response object contains tool calls.

        Args:
            response: The response object (streaming or non-streaming).

        Returns:
            True if tool calls are present, False otherwise.
        """
        pass

    @abc.abstractmethod
    def parse_tool_calls(self, response: Any) -> list[dict[str, Any]]:
        """
        Parses tool calls from the response object.

        Args:
            response: The response object containing tool calls.

        Returns:
            A list of dictionaries, each representing a tool call with details
            like 'id', 'function_name', 'arguments'. The exact structure might
            vary slightly based on provider needs but should contain enough
            info for execution.
        """
        pass

    @abc.abstractmethod
    def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
        """
        Formats the result of a tool execution into the structure expected
        by the provider for follow-up requests.

        Args:
            tool_call_id: The unique ID of the tool call (from parse_tool_calls).
            result: The data returned by the tool execution.

        Returns:
            A dictionary representing the tool result in the provider's format.
        """
        pass

    @abc.abstractmethod
    def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """
        Converts a list of tools from the standard internal format to the
        provider-specific format required for the API call.

        Args:
            tools: List of tool definitions in the standard internal format.
                Each dict contains 'server_name', 'name', 'description', 'input_schema'.

        Returns:
            List of tool definitions in the provider-specific format.
        """
        pass

    # Optional: Add a method for follow-up completions if the provider API
    # requires a specific structure different from just appending messages.
    # def create_follow_up_completion(...) -> Any:
    #     pass
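The value of the abstract base is that calling code can stay provider-agnostic. A short sketch of such caller code, written only against the BaseProvider interface above; run_turn, model, and internal_tools are illustrative names, and the provider/tool list are assumed to come from configuration elsewhere (llm_client.py / the MCP tool registry), not from this file.

# Sketch of provider-agnostic calling code against BaseProvider (illustrative only).
from typing import Any

from providers.base import BaseProvider


def run_turn(
    provider: BaseProvider,
    model: str,
    messages: list[dict[str, str]],
    internal_tools: list[dict[str, Any]],
) -> str:
    # Convert internal tool definitions into the provider's wire format first.
    tools = provider.convert_tools(internal_tools)
    response = provider.create_chat_completion(messages, model=model, stream=False, tools=tools)
    if provider.has_tool_calls(response):
        # Tool execution is out of scope here; parse_tool_calls yields the
        # 'id', 'server_name', 'function_name', and 'arguments' needed to route it.
        calls = provider.parse_tool_calls(response)
        return f"[model requested {len(calls)} tool call(s)]"
    return provider.get_content(response)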
src/providers/openai_provider.py (new file, 239 lines)
@@ -0,0 +1,239 @@
# src/providers/openai_provider.py
import json
import logging
from collections.abc import Generator
from typing import Any

from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall

from providers.base import BaseProvider
from src.llm_models import MODELS  # Use absolute import

logger = logging.getLogger(__name__)


class OpenAIProvider(BaseProvider):
    """Provider implementation for OpenAI and compatible APIs."""

    def __init__(self, api_key: str, base_url: str | None = None):
        # Use default OpenAI endpoint if base_url is not provided
        effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
        super().__init__(api_key, effective_base_url)
        logger.info(f"Initializing OpenAIProvider with base URL: {self.base_url}")
        try:
            # TODO: Add default headers like in original client?
            self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
        except Exception as e:
            logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
            raise

    def create_chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        stream: bool = True,
        tools: list[dict[str, Any]] | None = None,
    ) -> Stream[ChatCompletionChunk] | ChatCompletion:
        """Creates a chat completion using the OpenAI API."""
        logger.debug(f"OpenAI create_chat_completion called. Stream: {stream}, Tools: {bool(tools)}")
        try:
            completion_params = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "stream": stream,
            }
            if tools:
                completion_params["tools"] = tools
                completion_params["tool_choice"] = "auto"  # Let OpenAI decide when to use tools

            # Remove None values like max_tokens if not provided
            completion_params = {k: v for k, v in completion_params.items() if v is not None}

            # --- Added Debug Logging ---
            log_params = completion_params.copy()
            # Avoid logging full messages if they are too long
            if "messages" in log_params:
                log_params["messages"] = [
                    {k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
                    for msg in log_params["messages"][-2:]  # Log last 2 messages summary
                ]
            # Specifically log tools structure if present
            tools_log = log_params.get("tools", "Not Present")
            logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
            logger.debug(f"Full API Params (messages summarized): {log_params}")
            # --- End Added Debug Logging ---

            response = self.client.chat.completions.create(**completion_params)
            logger.debug("OpenAI API call successful.")
            return response
        except Exception as e:
            logger.error(f"OpenAI API error: {e}", exc_info=True)
            # Re-raise for the LLMClient to handle
            raise

    def get_streaming_content(self, response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
        """Yields content chunks from an OpenAI streaming response."""
        logger.debug("Processing OpenAI stream...")
        full_delta = ""
        try:
            for chunk in response:
                delta = chunk.choices[0].delta.content
                if delta:
                    full_delta += delta
                    yield delta
            logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
        except Exception as e:
            logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
            # Yield an error message? Or let the generator stop?
            yield json.dumps({"error": f"Stream processing error: {str(e)}"})

    def get_content(self, response: ChatCompletion) -> str:
        """Extracts content from a non-streaming OpenAI response."""
        try:
            content = response.choices[0].message.content
            logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
            return content or ""  # Return empty string if content is None
        except Exception as e:
            logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
            return f"[Error extracting content: {str(e)}]"

    def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
        """Checks if the OpenAI response contains tool calls."""
        try:
            if isinstance(response, ChatCompletion):  # Non-streaming
                return bool(response.choices[0].message.tool_calls)
            elif hasattr(response, "_iterator"):  # Check if it looks like our stream wrapper
                # This is tricky for streams. We'd need to peek at the first chunk(s)
                # or buffer the response. For simplicity, this check might be unreliable
                # for streams *before* they are consumed. LLMClient needs robust handling.
                logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
                # A more robust check would involve consuming the start of the stream
                # or relying on the structure after consumption.
                return False  # Assume no for unconsumed stream for now
            else:
                # If it's already consumed stream or unexpected type
                logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
                return False
        except Exception as e:
            logger.error(f"Error checking for tool calls: {e}", exc_info=True)
            return False

    def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
        """Parses tool calls from a non-streaming OpenAI response."""
        # This implementation assumes a non-streaming response or a fully buffered stream
        parsed_calls = []
        try:
            if not isinstance(response, ChatCompletion):
                logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
                # Attempt to handle buffered stream if possible? Complex.
                return []

            tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
            if not tool_calls:
                return []

            logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
            for call in tool_calls:
                if call.type == "function":
                    # Attempt to parse server_name from function name if prefixed
                    # e.g., "server-name__actual-tool-name"
                    parts = call.function.name.split("__", 1)
                    if len(parts) == 2:
                        server_name, func_name = parts
                    else:
                        # If no prefix, how do we know the server? Needs refinement.
                        # Defaulting to None or a default server? Log warning.
                        logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
                        server_name = None  # Or raise error, or use a default?
                        func_name = call.function.name

                    parsed_calls.append({
                        "id": call.id,
                        "server_name": server_name,  # May be None if not prefixed
                        "function_name": func_name,
                        "arguments": call.function.arguments,  # Arguments are already a string here
                    })
                else:
                    logger.warning(f"Unsupported tool call type: {call.type}")

            return parsed_calls
        except Exception as e:
            logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
            return []  # Return empty list on error

    def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
        """Formats a tool result for an OpenAI follow-up request."""
        # Result might be a dict (including potential errors) or simple string/number
        # OpenAI expects the content to be a string, often JSON.
        try:
            if isinstance(result, dict):
                content = json.dumps(result)
            else:
                content = str(result)  # Ensure it's a string
        except Exception as e:
            logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
            content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})

        logger.debug(f"Formatting tool result for call ID {tool_call_id}")
        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "content": content,
        }

    def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Converts internal tool format to OpenAI's format."""
        openai_tools = []
        logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
        for tool in tools:
            server_name = tool.get("server_name")
            tool_name = tool.get("name")
            description = tool.get("description")
            input_schema = tool.get("inputSchema")

            if not server_name or not tool_name or not description or not input_schema:
                logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
                continue

            # Prefix tool name with server name to avoid clashes and allow routing
            prefixed_tool_name = f"{server_name}__{tool_name}"

            openai_tool_format = {
                "type": "function",
                "function": {
                    "name": prefixed_tool_name,
                    "description": description,
                    "parameters": input_schema,  # OpenAI uses JSON Schema directly
                },
            }
            openai_tools.append(openai_tool_format)
            logger.debug(f"Converted tool: {prefixed_tool_name}")

        return openai_tools

    # Helper needed by LLMClient's current tool handling logic
    def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
        """Extracts the assistant's message containing tool calls."""
        try:
            if isinstance(response, ChatCompletion) and response.choices[0].message.tool_calls:
                message = response.choices[0].message
                # Convert Pydantic model to dict for message history
                return message.model_dump(exclude_unset=True)
            else:
                logger.warning("Could not extract original message with tool calls from response.")
                # Return a placeholder or raise error?
                return {"role": "assistant", "content": "[Could not extract tool calls message]"}
        except Exception as e:
            logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
            return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}


# Register this provider (if using the registration mechanism)
# from . import register_provider
# register_provider("openai", OpenAIProvider)
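To make the "server__tool" prefixing convention concrete, here is a small round trip through the OpenAI-specific helpers above. The MCP tool definition, API key, and tool_call_id are made up for the sketch; only the field names mirror what convert_tools and format_tool_results actually read and produce.

# Illustrative round trip through convert_tools / format_tool_results (placeholder data).
from providers.openai_provider import OpenAIProvider

provider = OpenAIProvider(api_key="sk-placeholder")

internal_tools = [{
    "server_name": "filesystem",
    "name": "read_file",
    "description": "Read a file from disk.",
    "inputSchema": {"type": "object", "properties": {"path": {"type": "string"}}, "required": ["path"]},
}]

openai_tools = provider.convert_tools(internal_tools)
# -> [{"type": "function", "function": {"name": "filesystem__read_file", ...}}]
# The "server__tool" prefix is what parse_tool_calls later splits on "__" to
# recover the originating MCP server for routing.

# After executing a parsed call, the result goes back as a "tool" role message:
follow_up = provider.format_tool_results(tool_call_id="call_123", result={"content": "file contents here"})
# -> {"role": "tool", "tool_call_id": "call_123", "content": '{"content": "file contents here"}'}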