From 678f39564904f314fd2dde423ad750e79602f784 Mon Sep 17 00:00:00 2001 From: abhishekbhakat Date: Wed, 26 Mar 2025 19:59:01 +0000 Subject: [PATCH] feat: implement OpenAIProvider with client initialization, message handling, and utility functions --- src/providers/openai_provider.py | 390 -------------------- src/providers/openai_provider/__init__.py | 66 ++++ src/providers/openai_provider/client.py | 23 ++ src/providers/openai_provider/completion.py | 80 ++++ src/providers/openai_provider/response.py | 69 ++++ src/providers/openai_provider/tools.py | 170 +++++++++ src/providers/openai_provider/utils.py | 114 ++++++ src/tools/conversion.py | 53 --- 8 files changed, 522 insertions(+), 443 deletions(-) delete mode 100644 src/providers/openai_provider.py create mode 100644 src/providers/openai_provider/__init__.py create mode 100644 src/providers/openai_provider/client.py create mode 100644 src/providers/openai_provider/completion.py create mode 100644 src/providers/openai_provider/response.py create mode 100644 src/providers/openai_provider/tools.py create mode 100644 src/providers/openai_provider/utils.py diff --git a/src/providers/openai_provider.py b/src/providers/openai_provider.py deleted file mode 100644 index 7b16eb4..0000000 --- a/src/providers/openai_provider.py +++ /dev/null @@ -1,390 +0,0 @@ -# src/providers/openai_provider.py -import json -import logging -import math -from collections.abc import Generator -from typing import Any - -from openai import OpenAI, Stream -from openai.types.chat import ChatCompletion, ChatCompletionChunk -from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall - -from src.llm_models import MODELS -from src.providers.base import BaseProvider - -logger = logging.getLogger(__name__) - - -class OpenAIProvider(BaseProvider): - """Provider implementation for OpenAI and compatible APIs.""" - - def __init__(self, api_key: str, base_url: str | None = None): - # Use default OpenAI endpoint if base_url is not provided - effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint") - super().__init__(api_key, effective_base_url) - logger.info(f"Initializing OpenAIProvider with base URL: {self.base_url}") - try: - # TODO: Add default headers like in original client? - self.client = OpenAI(api_key=self.api_key, base_url=self.base_url) - except Exception as e: - logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True) - raise - - def _get_context_window(self, model: str) -> int: - """Retrieves the context window size for a given model.""" - # Default to a safe fallback if model or provider info is missing - default_window = 8000 - try: - # Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts - provider_models = MODELS.get("openai", {}).get("models", []) - for m in provider_models: - if m.get("id") == model: - return m.get("context_window", default_window) - # Fallback if specific model ID not found in our list - logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}") - return default_window - except Exception as e: - logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True) - return default_window - - def _estimate_openai_token_count(self, messages: list[dict[str, str]]) -> int: - """ - Estimates the token count for OpenAI messages using char count / 4 approximation. - Note: This is less accurate than using tiktoken. 
- """ - total_chars = 0 - for message in messages: - total_chars += len(message.get("role", "")) - content = message.get("content") - if isinstance(content, str): - total_chars += len(content) - # Rough approximation for function/tool call overhead if needed later - # Using math.ceil to round up, ensuring we don't underestimate too much. - estimated_tokens = math.ceil(total_chars / 4.0) - logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages") - return estimated_tokens - - def _truncate_messages(self, messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]: - """ - Truncates messages from the beginning if estimated token count exceeds the limit. - Preserves the first message if it's a system prompt. - - Returns: - - The potentially truncated list of messages. - - The initial estimated token count. - - The final estimated token count after truncation (if any). - """ - context_limit = self._get_context_window(model) - # Add a buffer to be safer with approximation - buffer = 200 # Reduce buffer slightly as we round up now - effective_limit = context_limit - buffer - - initial_estimated_count = self._estimate_openai_token_count(messages) - final_estimated_count = initial_estimated_count - - truncated_messages = list(messages) # Make a copy - - # Identify if the first message is a system prompt - has_system_prompt = False - if truncated_messages and truncated_messages[0].get("role") == "system": - has_system_prompt = True - # If only system prompt exists, don't truncate further - if len(truncated_messages) == 1 and final_estimated_count > effective_limit: - logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.") - # Return original messages to avoid removing the only message - return messages, initial_estimated_count, final_estimated_count - - while final_estimated_count > effective_limit: - if has_system_prompt and len(truncated_messages) <= 1: - # Should not happen if check above works, but safety break - logger.warning("Truncation stopped: Only system prompt remains.") - break - if not has_system_prompt and len(truncated_messages) <= 0: - logger.warning("Truncation stopped: No messages left.") - break # No messages left - - # Determine index to remove: 1 if system prompt exists and list is long enough, else 0 - remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0 - - if remove_index >= len(truncated_messages): - logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.") - break # Avoid index error - - removed_message = truncated_messages.pop(remove_index) - logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.") - - # Recalculate estimated count - final_estimated_count = self._estimate_openai_token_count(truncated_messages) - logger.debug(f"Recalculated estimated tokens: {final_estimated_count}") - - # Safety break if list becomes unexpectedly empty - if not truncated_messages: - logger.warning("Truncation resulted in empty message list.") - break - - if initial_estimated_count != final_estimated_count: - logger.info( - f"Truncated messages for model {model}. 
" - f"Initial estimated tokens: {initial_estimated_count}, " - f"Final estimated tokens: {final_estimated_count}, " - f"Limit: {context_limit} (Effective: {effective_limit})" - ) - else: - logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})") - - return truncated_messages, initial_estimated_count, final_estimated_count - - def create_chat_completion( - self, - messages: list[dict[str, str]], - model: str, - temperature: float = 0.4, - max_tokens: int | None = None, - stream: bool = True, - tools: list[dict[str, Any]] | None = None, - # Add usage dict to return type hint? Needs careful thought for streaming vs non-streaming - ) -> Stream[ChatCompletionChunk] | ChatCompletion: # How to return usage info cleanly? - """Creates a chat completion using the OpenAI API, handling context window truncation.""" - logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}") - - # --- Truncation Step --- - truncated_messages, initial_est_tokens, final_est_tokens = self._truncate_messages(messages, model) - # ----------------------- - - try: - completion_params = { - "model": model, - "messages": truncated_messages, # Use truncated messages - "temperature": temperature, - "max_tokens": max_tokens, - "stream": stream, - } - if tools: - completion_params["tools"] = tools - completion_params["tool_choice"] = "auto" # Let OpenAI decide when to use tools - - # Remove None values like max_tokens if not provided - completion_params = {k: v for k, v in completion_params.items() if v is not None} - - # --- Added Debug Logging --- - log_params = completion_params.copy() - # Avoid logging full messages if they are too long - if "messages" in log_params: - log_params["messages"] = [ - {k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()} - for msg in log_params["messages"][-2:] # Log last 2 messages summary - ] - # Specifically log tools structure if present - tools_log = log_params.get("tools", "Not Present") - logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}") - logger.debug(f"Full API Params (messages summarized): {log_params}") - # --- End Added Debug Logging --- - - response = self.client.chat.completions.create(**completion_params) - logger.debug("OpenAI API call successful.") - - # --- Capture Actual Usage (for UI display later) --- - # This part is tricky. Usage info is easily available on the *non-streaming* response. - # For streaming, it's often not available until the stream is fully consumed, - # or sometimes via response headers or a final event (provider-dependent). - # For now, let's focus on getting it from the non-streaming case. - # We need a way to pass this back alongside the content/stream. - # Option 1: Modify return type (complex for stream/non-stream union) - # Option 2: Store it in the provider instance (stateful, maybe bad) - # Option 3: Have LLMClient handle extraction (requires LLMClient to know response structure) - - # Let's try returning it alongside for non-streaming, and figure out streaming later. - # This requires changing the BaseProvider interface and LLMClient handling. - # For now, just log it here. 
- actual_usage = None - if isinstance(response, ChatCompletion) and response.usage: - actual_usage = { - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, - } - logger.info(f"Actual OpenAI API usage: {actual_usage}") - # TODO: How to handle usage for streaming responses? Needs investigation. - - # Return the raw response for now. LLMClient will process it. - return response - # ---------------------------------------------------- - - except Exception as e: - logger.error(f"OpenAI API error: {e}", exc_info=True) - # Re-raise for the LLMClient to handle - raise - - def get_streaming_content(self, response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]: - """Yields content chunks from an OpenAI streaming response.""" - logger.debug("Processing OpenAI stream...") - full_delta = "" - try: - for chunk in response: - delta = chunk.choices[0].delta.content - if delta: - full_delta += delta - yield delta - logger.debug(f"Stream finished. Total delta length: {len(full_delta)}") - except Exception as e: - logger.error(f"Error processing OpenAI stream: {e}", exc_info=True) - # Yield an error message? Or let the generator stop? - yield json.dumps({"error": f"Stream processing error: {str(e)}"}) - - def get_content(self, response: ChatCompletion) -> str: - """Extracts content from a non-streaming OpenAI response.""" - try: - content = response.choices[0].message.content - logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.") - return content or "" # Return empty string if content is None - except Exception as e: - logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True) - return f"[Error extracting content: {str(e)}]" - - def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool: - """Checks if the OpenAI response contains tool calls.""" - try: - if isinstance(response, ChatCompletion): # Non-streaming - return bool(response.choices[0].message.tool_calls) - elif hasattr(response, "_iterator"): # Check if it looks like our stream wrapper - # This is tricky for streams. We'd need to peek at the first chunk(s) - # or buffer the response. For simplicity, this check might be unreliable - # for streams *before* they are consumed. LLMClient needs robust handling. - logger.warning("has_tool_calls check on a stream is unreliable before consumption.") - # A more robust check would involve consuming the start of the stream - # or relying on the structure after consumption. - return False # Assume no for unconsumed stream for now - else: - # If it's already consumed stream or unexpected type - logger.warning(f"has_tool_calls received unexpected type: {type(response)}") - return False - except Exception as e: - logger.error(f"Error checking for tool calls: {e}", exc_info=True) - return False - - def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]: - """Parses tool calls from a non-streaming OpenAI response.""" - # This implementation assumes a non-streaming response or a fully buffered stream - parsed_calls = [] - try: - if not isinstance(response, ChatCompletion): - logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}") - # Attempt to handle buffered stream if possible? Complex. 
- return [] - - tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls - if not tool_calls: - return [] - - logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.") - for call in tool_calls: - if call.type == "function": - # Attempt to parse server_name from function name if prefixed - # e.g., "server-name__actual-tool-name" - parts = call.function.name.split("__", 1) - if len(parts) == 2: - server_name, func_name = parts - else: - # If no prefix, how do we know the server? Needs refinement. - # Defaulting to None or a default server? Log warning. - logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.") - server_name = None # Or raise error, or use a default? - func_name = call.function.name - - parsed_calls.append({ - "id": call.id, - "server_name": server_name, # May be None if not prefixed - "function_name": func_name, - "arguments": call.function.arguments, # Arguments are already a string here - }) - else: - logger.warning(f"Unsupported tool call type: {call.type}") - - return parsed_calls - except Exception as e: - logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True) - return [] # Return empty list on error - - def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]: - """Formats a tool result for an OpenAI follow-up request.""" - # Result might be a dict (including potential errors) or simple string/number - # OpenAI expects the content to be a string, often JSON. - try: - if isinstance(result, dict): - content = json.dumps(result) - else: - content = str(result) # Ensure it's a string - except Exception as e: - logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}") - content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))}) - - logger.debug(f"Formatting tool result for call ID {tool_call_id}") - return { - "role": "tool", - "tool_call_id": tool_call_id, - "content": content, - } - - def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Converts internal tool format to OpenAI's format.""" - openai_tools = [] - logger.debug(f"Converting {len(tools)} tools to OpenAI format.") - for tool in tools: - server_name = tool.get("server_name") - tool_name = tool.get("name") - description = tool.get("description") - input_schema = tool.get("inputSchema") - - if not server_name or not tool_name or not description or not input_schema: - logger.warning(f"Skipping invalid tool definition during conversion: {tool}") - continue - - # Prefix tool name with server name to avoid clashes and allow routing - prefixed_tool_name = f"{server_name}__{tool_name}" - - openai_tool_format = { - "type": "function", - "function": { - "name": prefixed_tool_name, - "description": description, - "parameters": input_schema, # OpenAI uses JSON Schema directly - }, - } - openai_tools.append(openai_tool_format) - logger.debug(f"Converted tool: {prefixed_tool_name}") - - return openai_tools - - # Helper needed by LLMClient's current tool handling logic - def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]: - """Extracts the assistant's message containing tool calls.""" - try: - if isinstance(response, ChatCompletion) and response.choices[0].message.tool_calls: - message = response.choices[0].message - # Convert Pydantic model to dict for message history - return message.model_dump(exclude_unset=True) - else: - logger.warning("Could not 
extract original message with tool calls from response.") - # Return a placeholder or raise error? - return {"role": "assistant", "content": "[Could not extract tool calls message]"} - except Exception as e: - logger.error(f"Error extracting original message with calls: {e}", exc_info=True) - return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"} - - def get_usage(self, response: Any) -> dict[str, int] | None: - """Extracts token usage from a non-streaming OpenAI response.""" - try: - if isinstance(response, ChatCompletion) and response.usage: - usage = { - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - # "total_tokens": response.usage.total_tokens, # Optional - } - logger.debug(f"Extracted usage from OpenAI response: {usage}") - return usage - else: - logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}") - return None - except Exception as e: - logger.error(f"Error extracting usage from OpenAI response: {e}", exc_info=True) - return None diff --git a/src/providers/openai_provider/__init__.py b/src/providers/openai_provider/__init__.py new file mode 100644 index 0000000..96995ff --- /dev/null +++ b/src/providers/openai_provider/__init__.py @@ -0,0 +1,66 @@ +# src/providers/openai_provider/__init__.py +from typing import Any + +from openai import Stream +from openai.types.chat import ChatCompletion, ChatCompletionChunk + +from providers.openai_provider.client import initialize_client +from providers.openai_provider.completion import create_chat_completion +from providers.openai_provider.response import get_content, get_streaming_content, get_usage +from providers.openai_provider.tools import ( + convert_tools, + format_tool_results, + get_original_message_with_calls, + has_tool_calls, + parse_tool_calls, +) +from src.providers.base import BaseProvider + + +class OpenAIProvider(BaseProvider): + """Provider implementation for OpenAI and compatible APIs.""" + + def __init__(self, api_key: str, base_url: str | None = None): + # BaseProvider __init__ might not be needed if client init handles base_url logic + # super().__init__(api_key, base_url) # Let's see if we need this + self.client = initialize_client(api_key, base_url) + # Store api_key and base_url if needed by BaseProvider or other methods + self.api_key = api_key + self.base_url = self.client.base_url # Get effective base_url from client + + def create_chat_completion( + self, + messages: list[dict[str, str]], + model: str, + temperature: float = 0.4, + max_tokens: int | None = None, + stream: bool = True, + tools: list[dict[str, Any]] | None = None, + ) -> Stream[ChatCompletionChunk] | ChatCompletion: + # Pass self (provider instance) to the helper function + return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools) + + def get_streaming_content(self, response: Stream[ChatCompletionChunk]): + return get_streaming_content(response) + + def get_content(self, response: ChatCompletion) -> str: + return get_content(response) + + def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool: + # This method might need the full response after streaming, handled by LLMClient + return has_tool_calls(response) + + def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]: + return parse_tool_calls(response) + + def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]: + return 
format_tool_results(tool_call_id, result) + + def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + return convert_tools(tools) + + def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]: + return get_original_message_with_calls(response) + + def get_usage(self, response: Any) -> dict[str, int] | None: + return get_usage(response) diff --git a/src/providers/openai_provider/client.py b/src/providers/openai_provider/client.py new file mode 100644 index 0000000..3d58e6c --- /dev/null +++ b/src/providers/openai_provider/client.py @@ -0,0 +1,23 @@ +# src/providers/openai_provider/client.py +import logging + +from openai import OpenAI + +from src.llm_models import MODELS + +logger = logging.getLogger(__name__) + + +def initialize_client(api_key: str, base_url: str | None = None) -> OpenAI: + """Initializes and returns an OpenAI client instance.""" + # Use default OpenAI endpoint if base_url is not provided explicitly + effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint") + logger.info(f"Initializing OpenAI client with base URL: {effective_base_url}") + try: + # TODO: Add default headers if needed, similar to the original openai_client.py? + # default_headers={"HTTP-Referer": "...", "X-Title": "..."} + client = OpenAI(api_key=api_key, base_url=effective_base_url) + return client + except Exception as e: + logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True) + raise diff --git a/src/providers/openai_provider/completion.py b/src/providers/openai_provider/completion.py new file mode 100644 index 0000000..78652c6 --- /dev/null +++ b/src/providers/openai_provider/completion.py @@ -0,0 +1,80 @@ +# src/providers/openai_provider/completion.py +import logging +from typing import Any + +from openai import Stream +from openai.types.chat import ChatCompletion, ChatCompletionChunk + +from providers.openai_provider.utils import truncate_messages + +logger = logging.getLogger(__name__) + + +def create_chat_completion( + provider, # The OpenAIProvider instance + messages: list[dict[str, str]], + model: str, + temperature: float = 0.4, + max_tokens: int | None = None, + stream: bool = True, + tools: list[dict[str, Any]] | None = None, +) -> Stream[ChatCompletionChunk] | ChatCompletion: + """Creates a chat completion using the OpenAI API, handling context window truncation.""" + logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}") + + # --- Truncation Step --- + truncated_messages, initial_est_tokens, final_est_tokens = truncate_messages(messages, model) + # ----------------------- + + try: + completion_params = { + "model": model, + "messages": truncated_messages, # Use truncated messages + "temperature": temperature, + "max_tokens": max_tokens, + "stream": stream, + } + if tools: + completion_params["tools"] = tools + completion_params["tool_choice"] = "auto" # Let OpenAI decide when to use tools + + # Remove None values like max_tokens if not provided + completion_params = {k: v for k, v in completion_params.items() if v is not None} + + # --- Added Debug Logging --- + log_params = completion_params.copy() + # Avoid logging full messages if they are too long + if "messages" in log_params: + log_params["messages"] = [ + {k: (v[:100] + "..." 
if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()} + for msg in log_params["messages"][-2:] # Log last 2 messages summary + ] + # Specifically log tools structure if present + tools_log = log_params.get("tools", "Not Present") + logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}") + logger.debug(f"Full API Params (messages summarized): {log_params}") + # --- End Added Debug Logging --- + + response = provider.client.chat.completions.create(**completion_params) + logger.debug("OpenAI API call successful.") + + # --- Capture Actual Usage (for UI display later) --- + # Log usage if available (primarily non-streaming) + actual_usage = None + if isinstance(response, ChatCompletion) and response.usage: + actual_usage = { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + } + logger.info(f"Actual OpenAI API usage: {actual_usage}") + # TODO: How to handle usage for streaming responses? Needs investigation. + + # Return the raw response for now. LLMClient will process it. + return response + # ---------------------------------------------------- + + except Exception as e: + logger.error(f"OpenAI API error: {e}", exc_info=True) + # Re-raise for the LLMClient to handle + raise diff --git a/src/providers/openai_provider/response.py b/src/providers/openai_provider/response.py new file mode 100644 index 0000000..c18062f --- /dev/null +++ b/src/providers/openai_provider/response.py @@ -0,0 +1,69 @@ +# src/providers/openai_provider/response.py +import json +import logging +from collections.abc import Generator +from typing import Any + +from openai import Stream +from openai.types.chat import ChatCompletion, ChatCompletionChunk + +logger = logging.getLogger(__name__) + + +def get_streaming_content(response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]: + """Yields content chunks from an OpenAI streaming response.""" + logger.debug("Processing OpenAI stream...") + full_delta = "" + try: + for chunk in response: + # Check if choices exist and are not empty + if chunk.choices: + delta = chunk.choices[0].delta.content + if delta: + full_delta += delta + yield delta + # Handle potential finish reasons or other stream elements if needed + # else: + # logger.debug(f"Stream chunk without choices: {chunk}") # Or handle finish reason etc. + logger.debug(f"Stream finished. Total delta length: {len(full_delta)}") + except Exception as e: + logger.error(f"Error processing OpenAI stream: {e}", exc_info=True) + # Yield an error message? Or let the generator stop? 
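+        # Yielding a JSON error string keeps the generator alive, but downstream
+        # consumers will receive it as if it were model content. An illustrative
+        # (not implemented) way for a caller to detect this sentinel:
+        #   text = "".join(provider.get_streaming_content(stream))
+        #   if text.startswith('{"error"'):
+        #       # treat as a stream failure rather than model output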
+ yield json.dumps({"error": f"Stream processing error: {str(e)}"}) + + +def get_content(response: ChatCompletion) -> str: + """Extracts content from a non-streaming OpenAI response.""" + try: + # Check if choices exist and are not empty + if response.choices: + content = response.choices[0].message.content + logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.") + return content or "" # Return empty string if content is None + else: + logger.warning("No choices found in OpenAI non-streaming response.") + return "[No content received]" + except Exception as e: + logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True) + return f"[Error extracting content: {str(e)}]" + + +def get_usage(response: Any) -> dict[str, int] | None: + """Extracts token usage from a non-streaming OpenAI response.""" + try: + if isinstance(response, ChatCompletion) and response.usage: + usage = { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + # "total_tokens": response.usage.total_tokens, # Optional + } + logger.debug(f"Extracted usage from OpenAI response: {usage}") + return usage + else: + # Don't log warning for streams, as usage isn't expected here + if not isinstance(response, Stream): + logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}") + return None + except Exception as e: + logger.error(f"Error extracting usage from OpenAI response: {e}", exc_info=True) + return None diff --git a/src/providers/openai_provider/tools.py b/src/providers/openai_provider/tools.py new file mode 100644 index 0000000..84d47c2 --- /dev/null +++ b/src/providers/openai_provider/tools.py @@ -0,0 +1,170 @@ +# src/providers/openai_provider/tools.py +import json +import logging +from typing import Any + +from openai import Stream +from openai.types.chat import ChatCompletion, ChatCompletionChunk +from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall + +logger = logging.getLogger(__name__) + + +def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool: + """Checks if the OpenAI response contains tool calls.""" + try: + if isinstance(response, ChatCompletion): # Non-streaming + # Check if choices exist and are not empty + if response.choices: + return bool(response.choices[0].message.tool_calls) + else: + logger.warning("No choices found in OpenAI non-streaming response for tool check.") + return False + elif isinstance(response, Stream): + # This check remains unreliable for unconsumed streams. + # LLMClient needs robust handling after consumption. 
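+            # A possible (not implemented here) approach: buffer the stream and look
+            # for tool-call deltas in the accumulated chunks, e.g.
+            #   chunks = list(response)
+            #   found = any(c.choices and c.choices[0].delta.tool_calls for c in chunks)
+            # This consumes the iterator, so the buffered chunks would then have to be
+            # replayed to whatever reads the content.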
+ logger.warning("has_tool_calls check on a stream is unreliable before consumption.") + return False # Assume no for unconsumed stream for now + else: + # If it's already consumed stream or unexpected type + logger.warning(f"has_tool_calls received unexpected type: {type(response)}") + return False + except Exception as e: + logger.error(f"Error checking for tool calls: {e}", exc_info=True) + return False + + +def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]: + """Parses tool calls from a non-streaming OpenAI response.""" + # This implementation assumes a non-streaming response or a fully buffered stream + parsed_calls = [] + try: + if not isinstance(response, ChatCompletion): + logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}") + return [] + + # Check if choices exist and are not empty + if not response.choices: + logger.warning("No choices found in OpenAI non-streaming response for tool parsing.") + return [] + + tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls + if not tool_calls: + return [] + + logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.") + for call in tool_calls: + if call.type == "function": + # Attempt to parse server_name from function name if prefixed + # e.g., "server-name__actual-tool-name" + parts = call.function.name.split("__", 1) + if len(parts) == 2: + server_name, func_name = parts + else: + # If no prefix, how do we know the server? Needs refinement. + # Defaulting to None or a default server? Log warning. + logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.") + server_name = None # Or raise error, or use a default? + func_name = call.function.name + + # Arguments might be a string needing JSON parsing, or already parsed dict + arguments_obj = None + try: + if isinstance(call.function.arguments, str): + arguments_obj = json.loads(call.function.arguments) + else: + # Assuming it might already be a dict if not a string (less common) + arguments_obj = call.function.arguments + except json.JSONDecodeError as json_err: + logger.error(f"Failed to parse JSON arguments for tool {func_name} (ID: {call.id}): {json_err}") + logger.error(f"Raw arguments string: {call.function.arguments}") + # Decide how to handle: skip tool, pass raw string, pass error? + # Passing raw string for now, but this might break consumers. + arguments_obj = {"error": "Failed to parse arguments", "raw_arguments": call.function.arguments} + + parsed_calls.append({ + "id": call.id, + "server_name": server_name, # May be None if not prefixed + "function_name": func_name, + "arguments": arguments_obj, # Pass parsed arguments (or error dict) + }) + else: + logger.warning(f"Unsupported tool call type: {call.type}") + + return parsed_calls + except Exception as e: + logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True) + return [] # Return empty list on error + + +def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]: + """Formats a tool result for an OpenAI follow-up request.""" + # Result might be a dict (including potential errors) or simple string/number + # OpenAI expects the content to be a string, often JSON. 
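+    # The returned dict is appended to the follow-up message history, e.g.
+    # (illustrative values only):
+    #   {"role": "tool", "tool_call_id": "call_abc123", "content": "{\"ok\": true}"}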
+ try: + if isinstance(result, dict): + content = json.dumps(result) + elif isinstance(result, str): + content = result # Allow plain strings if result is already string + else: + content = str(result) # Ensure it's a string otherwise + except Exception as e: + logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}") + content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))}) + + logger.debug(f"Formatting tool result for call ID {tool_call_id}") + return { + "role": "tool", + "tool_call_id": tool_call_id, + "content": content, + } + + +def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Converts internal tool format to OpenAI's format.""" + # This function seems identical to the one in src/tools/conversion.py + # We can potentially remove it from here and import from the central location. + # For now, keep it duplicated to maintain modularity until a decision is made. + openai_tools = [] + logger.debug(f"Converting {len(tools)} tools to OpenAI format.") + for tool in tools: + server_name = tool.get("server_name") + tool_name = tool.get("name") + description = tool.get("description") + input_schema = tool.get("inputSchema") + + if not server_name or not tool_name or not description or not input_schema: + logger.warning(f"Skipping invalid tool definition during conversion: {tool}") + continue + + # Prefix tool name with server name to avoid clashes and allow routing + prefixed_tool_name = f"{server_name}__{tool_name}" + + openai_tool_format = { + "type": "function", + "function": { + "name": prefixed_tool_name, + "description": description, + "parameters": input_schema, # OpenAI uses JSON Schema directly + }, + } + openai_tools.append(openai_tool_format) + logger.debug(f"Converted tool: {prefixed_tool_name}") + + return openai_tools + + +def get_original_message_with_calls(response: ChatCompletion) -> dict[str, Any]: + """Extracts the assistant's message containing tool calls.""" + try: + if isinstance(response, ChatCompletion) and response.choices and response.choices[0].message.tool_calls: + message = response.choices[0].message + # Convert Pydantic model to dict for message history + return message.model_dump(exclude_unset=True) + else: + logger.warning("Could not extract original message with tool calls from response.") + # Return a placeholder or raise error? 
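+            # Returning a placeholder keeps the message history well-formed for the
+            # follow-up request; raising instead would surface the failure to
+            # LLMClient, which may be preferable once its error handling is settled.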
+ return {"role": "assistant", "content": "[Could not extract tool calls message]"} + except Exception as e: + logger.error(f"Error extracting original message with calls: {e}", exc_info=True) + return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"} diff --git a/src/providers/openai_provider/utils.py b/src/providers/openai_provider/utils.py new file mode 100644 index 0000000..0d47f72 --- /dev/null +++ b/src/providers/openai_provider/utils.py @@ -0,0 +1,114 @@ +# src/providers/openai_provider/utils.py +import logging +import math + +from src.llm_models import MODELS + +logger = logging.getLogger(__name__) + + +def get_context_window(model: str) -> int: + """Retrieves the context window size for a given model.""" + # Default to a safe fallback if model or provider info is missing + default_window = 8000 + try: + # Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts + provider_models = MODELS.get("openai", {}).get("models", []) + for m in provider_models: + if m.get("id") == model: + return m.get("context_window", default_window) + # Fallback if specific model ID not found in our list + logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}") + return default_window + except Exception as e: + logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True) + return default_window + + +def estimate_openai_token_count(messages: list[dict[str, str]]) -> int: + """ + Estimates the token count for OpenAI messages using char count / 4 approximation. + Note: This is less accurate than using tiktoken. + """ + total_chars = 0 + for message in messages: + total_chars += len(message.get("role", "")) + content = message.get("content") + if isinstance(content, str): + total_chars += len(content) + # Rough approximation for function/tool call overhead if needed later + # Using math.ceil to round up, ensuring we don't underestimate too much. + estimated_tokens = math.ceil(total_chars / 4.0) + logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages") + return estimated_tokens + + +def truncate_messages(messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]: + """ + Truncates messages from the beginning if estimated token count exceeds the limit. + Preserves the first message if it's a system prompt. + + Returns: + - The potentially truncated list of messages. + - The initial estimated token count. + - The final estimated token count after truncation (if any). + """ + context_limit = get_context_window(model) + # Add a buffer to be safer with approximation + buffer = 200 # Reduce buffer slightly as we round up now + effective_limit = context_limit - buffer + + initial_estimated_count = estimate_openai_token_count(messages) + final_estimated_count = initial_estimated_count + + truncated_messages = list(messages) # Make a copy + + # Identify if the first message is a system prompt + has_system_prompt = False + if truncated_messages and truncated_messages[0].get("role") == "system": + has_system_prompt = True + # If only system prompt exists, don't truncate further + if len(truncated_messages) == 1 and final_estimated_count > effective_limit: + logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). 
Cannot truncate further.") + # Return original messages to avoid removing the only message + return messages, initial_estimated_count, final_estimated_count + + while final_estimated_count > effective_limit: + if has_system_prompt and len(truncated_messages) <= 1: + # Should not happen if check above works, but safety break + logger.warning("Truncation stopped: Only system prompt remains.") + break + if not has_system_prompt and len(truncated_messages) <= 0: + logger.warning("Truncation stopped: No messages left.") + break # No messages left + + # Determine index to remove: 1 if system prompt exists and list is long enough, else 0 + remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0 + + if remove_index >= len(truncated_messages): + logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.") + break # Avoid index error + + removed_message = truncated_messages.pop(remove_index) + logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.") + + # Recalculate estimated count + final_estimated_count = estimate_openai_token_count(truncated_messages) + logger.debug(f"Recalculated estimated tokens: {final_estimated_count}") + + # Safety break if list becomes unexpectedly empty + if not truncated_messages: + logger.warning("Truncation resulted in empty message list.") + break + + if initial_estimated_count != final_estimated_count: + logger.info( + f"Truncated messages for model {model}. " + f"Initial estimated tokens: {initial_estimated_count}, " + f"Final estimated tokens: {final_estimated_count}, " + f"Limit: {context_limit} (Effective: {effective_limit})" + ) + else: + logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})") + + return truncated_messages, initial_estimated_count, final_estimated_count diff --git a/src/tools/conversion.py b/src/tools/conversion.py index fa5ae02..8ec6b21 100644 --- a/src/tools/conversion.py +++ b/src/tools/conversion.py @@ -11,59 +11,6 @@ from typing import Any logger = logging.getLogger(__name__) -def convert_to_openai_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]: - """ - Convert MCP tools to OpenAI tool definitions. - - Args: - mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema). - - Returns: - List of OpenAI tool definitions. - """ - openai_tools = [] - logger.debug(f"Converting {len(mcp_tools)} MCP tools to OpenAI format.") - - for tool in mcp_tools: - server_name = tool.get("server_name") - tool_name = tool.get("name") - description = tool.get("description") - input_schema = tool.get("inputSchema") - - if not server_name or not tool_name or not description or not input_schema: - logger.warning(f"Skipping invalid MCP tool definition during OpenAI conversion: {tool}") - continue - - # Prefix tool name with server name for routing - prefixed_tool_name = f"{server_name}__{tool_name}" - - # Initialize the OpenAI tool structure - openai_tool = { - "type": "function", - "function": { - "name": prefixed_tool_name, - "description": description, - "parameters": input_schema, # OpenAI uses JSON Schema directly - }, - } - # Basic validation/cleaning of schema if needed could go here - if not isinstance(input_schema, dict) or input_schema.get("type") != "object": - logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. 
OpenAI might reject this.") - # Ensure basic structure if missing - if not isinstance(input_schema, dict): - input_schema = {} - if "type" not in input_schema: - input_schema["type"] = "object" - if "properties" not in input_schema: - input_schema["properties"] = {} - openai_tool["function"]["parameters"] = input_schema - - openai_tools.append(openai_tool) - logger.debug(f"Converted MCP tool to OpenAI: {prefixed_tool_name}") - - return openai_tools - - def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Convert MCP tools to Google Gemini format (dictionary structure).