feat: implement OpenAIProvider with client initialization, message handling, and utility functions
This commit is contained in:
@@ -1,390 +0,0 @@
|
||||
# src/providers/openai_provider.py
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from openai import OpenAI, Stream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
||||
|
||||
from src.llm_models import MODELS
|
||||
from src.providers.base import BaseProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIProvider(BaseProvider):
|
||||
"""Provider implementation for OpenAI and compatible APIs."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None):
|
||||
# Use default OpenAI endpoint if base_url is not provided
|
||||
effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
|
||||
super().__init__(api_key, effective_base_url)
|
||||
logger.info(f"Initializing OpenAIProvider with base URL: {self.base_url}")
|
||||
try:
|
||||
# TODO: Add default headers like in original client?
|
||||
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _get_context_window(self, model: str) -> int:
|
||||
"""Retrieves the context window size for a given model."""
|
||||
# Default to a safe fallback if model or provider info is missing
|
||||
default_window = 8000
|
||||
try:
|
||||
# Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
|
||||
provider_models = MODELS.get("openai", {}).get("models", [])
|
||||
for m in provider_models:
|
||||
if m.get("id") == model:
|
||||
return m.get("context_window", default_window)
|
||||
# Fallback if specific model ID not found in our list
|
||||
logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
|
||||
return default_window
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
|
||||
return default_window
|
||||
|
||||
def _estimate_openai_token_count(self, messages: list[dict[str, str]]) -> int:
|
||||
"""
|
||||
Estimates the token count for OpenAI messages using char count / 4 approximation.
|
||||
Note: This is less accurate than using tiktoken.
|
||||
"""
|
||||
total_chars = 0
|
||||
for message in messages:
|
||||
total_chars += len(message.get("role", ""))
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
total_chars += len(content)
|
||||
# Rough approximation for function/tool call overhead if needed later
|
||||
# Using math.ceil to round up, ensuring we don't underestimate too much.
|
||||
estimated_tokens = math.ceil(total_chars / 4.0)
|
||||
logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
|
||||
return estimated_tokens
|
||||
|
||||
def _truncate_messages(self, messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
|
||||
"""
|
||||
Truncates messages from the beginning if estimated token count exceeds the limit.
|
||||
Preserves the first message if it's a system prompt.
|
||||
|
||||
Returns:
|
||||
- The potentially truncated list of messages.
|
||||
- The initial estimated token count.
|
||||
- The final estimated token count after truncation (if any).
|
||||
"""
|
||||
context_limit = self._get_context_window(model)
|
||||
# Add a buffer to be safer with approximation
|
||||
buffer = 200 # Reduce buffer slightly as we round up now
|
||||
effective_limit = context_limit - buffer
|
||||
|
||||
initial_estimated_count = self._estimate_openai_token_count(messages)
|
||||
final_estimated_count = initial_estimated_count
|
||||
|
||||
truncated_messages = list(messages) # Make a copy
|
||||
|
||||
# Identify if the first message is a system prompt
|
||||
has_system_prompt = False
|
||||
if truncated_messages and truncated_messages[0].get("role") == "system":
|
||||
has_system_prompt = True
|
||||
# If only system prompt exists, don't truncate further
|
||||
if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
|
||||
logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
|
||||
# Return original messages to avoid removing the only message
|
||||
return messages, initial_estimated_count, final_estimated_count
|
||||
|
||||
while final_estimated_count > effective_limit:
|
||||
if has_system_prompt and len(truncated_messages) <= 1:
|
||||
# Should not happen if check above works, but safety break
|
||||
logger.warning("Truncation stopped: Only system prompt remains.")
|
||||
break
|
||||
if not has_system_prompt and len(truncated_messages) <= 0:
|
||||
logger.warning("Truncation stopped: No messages left.")
|
||||
break # No messages left
|
||||
|
||||
# Determine index to remove: 1 if system prompt exists and list is long enough, else 0
|
||||
remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0
|
||||
|
||||
if remove_index >= len(truncated_messages):
|
||||
logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
|
||||
break # Avoid index error
|
||||
|
||||
removed_message = truncated_messages.pop(remove_index)
|
||||
logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")
|
||||
|
||||
# Recalculate estimated count
|
||||
final_estimated_count = self._estimate_openai_token_count(truncated_messages)
|
||||
logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")
|
||||
|
||||
# Safety break if list becomes unexpectedly empty
|
||||
if not truncated_messages:
|
||||
logger.warning("Truncation resulted in empty message list.")
|
||||
break
|
||||
|
||||
if initial_estimated_count != final_estimated_count:
|
||||
logger.info(
|
||||
f"Truncated messages for model {model}. "
|
||||
f"Initial estimated tokens: {initial_estimated_count}, "
|
||||
f"Final estimated tokens: {final_estimated_count}, "
|
||||
f"Limit: {context_limit} (Effective: {effective_limit})"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})")
|
||||
|
||||
return truncated_messages, initial_estimated_count, final_estimated_count
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.4,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
# Add usage dict to return type hint? Needs careful thought for streaming vs non-streaming
|
||||
) -> Stream[ChatCompletionChunk] | ChatCompletion: # How to return usage info cleanly?
|
||||
"""Creates a chat completion using the OpenAI API, handling context window truncation."""
|
||||
logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
|
||||
|
||||
# --- Truncation Step ---
|
||||
truncated_messages, initial_est_tokens, final_est_tokens = self._truncate_messages(messages, model)
|
||||
# -----------------------
|
||||
|
||||
try:
|
||||
completion_params = {
|
||||
"model": model,
|
||||
"messages": truncated_messages, # Use truncated messages
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": stream,
|
||||
}
|
||||
if tools:
|
||||
completion_params["tools"] = tools
|
||||
completion_params["tool_choice"] = "auto" # Let OpenAI decide when to use tools
|
||||
|
||||
# Remove None values like max_tokens if not provided
|
||||
completion_params = {k: v for k, v in completion_params.items() if v is not None}
|
||||
|
||||
# --- Added Debug Logging ---
|
||||
log_params = completion_params.copy()
|
||||
# Avoid logging full messages if they are too long
|
||||
if "messages" in log_params:
|
||||
log_params["messages"] = [
|
||||
{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
|
||||
for msg in log_params["messages"][-2:] # Log last 2 messages summary
|
||||
]
|
||||
# Specifically log tools structure if present
|
||||
tools_log = log_params.get("tools", "Not Present")
|
||||
logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
|
||||
logger.debug(f"Full API Params (messages summarized): {log_params}")
|
||||
# --- End Added Debug Logging ---
|
||||
|
||||
response = self.client.chat.completions.create(**completion_params)
|
||||
logger.debug("OpenAI API call successful.")
|
||||
|
||||
# --- Capture Actual Usage (for UI display later) ---
|
||||
# This part is tricky. Usage info is easily available on the *non-streaming* response.
|
||||
# For streaming, it's often not available until the stream is fully consumed,
|
||||
# or sometimes via response headers or a final event (provider-dependent).
|
||||
# For now, let's focus on getting it from the non-streaming case.
|
||||
# We need a way to pass this back alongside the content/stream.
|
||||
# Option 1: Modify return type (complex for stream/non-stream union)
|
||||
# Option 2: Store it in the provider instance (stateful, maybe bad)
|
||||
# Option 3: Have LLMClient handle extraction (requires LLMClient to know response structure)
|
||||
|
||||
# Let's try returning it alongside for non-streaming, and figure out streaming later.
|
||||
# This requires changing the BaseProvider interface and LLMClient handling.
|
||||
# For now, just log it here.
|
||||
actual_usage = None
|
||||
if isinstance(response, ChatCompletion) and response.usage:
|
||||
actual_usage = {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
}
|
||||
logger.info(f"Actual OpenAI API usage: {actual_usage}")
|
||||
# TODO: How to handle usage for streaming responses? Needs investigation.
|
||||
|
||||
# Return the raw response for now. LLMClient will process it.
|
||||
return response
|
||||
# ----------------------------------------------------
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}", exc_info=True)
|
||||
# Re-raise for the LLMClient to handle
|
||||
raise
|
||||
|
||||
def get_streaming_content(self, response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
|
||||
"""Yields content chunks from an OpenAI streaming response."""
|
||||
logger.debug("Processing OpenAI stream...")
|
||||
full_delta = ""
|
||||
try:
|
||||
for chunk in response:
|
||||
delta = chunk.choices[0].delta.content
|
||||
if delta:
|
||||
full_delta += delta
|
||||
yield delta
|
||||
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
|
||||
# Yield an error message? Or let the generator stop?
|
||||
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
|
||||
|
||||
def get_content(self, response: ChatCompletion) -> str:
|
||||
"""Extracts content from a non-streaming OpenAI response."""
|
||||
try:
|
||||
content = response.choices[0].message.content
|
||||
logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
|
||||
return content or "" # Return empty string if content is None
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
|
||||
return f"[Error extracting content: {str(e)}]"
|
||||
|
||||
def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
|
||||
"""Checks if the OpenAI response contains tool calls."""
|
||||
try:
|
||||
if isinstance(response, ChatCompletion): # Non-streaming
|
||||
return bool(response.choices[0].message.tool_calls)
|
||||
elif hasattr(response, "_iterator"): # Check if it looks like our stream wrapper
|
||||
# This is tricky for streams. We'd need to peek at the first chunk(s)
|
||||
# or buffer the response. For simplicity, this check might be unreliable
|
||||
# for streams *before* they are consumed. LLMClient needs robust handling.
|
||||
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
|
||||
# A more robust check would involve consuming the start of the stream
|
||||
# or relying on the structure after consumption.
|
||||
return False # Assume no for unconsumed stream for now
|
||||
else:
|
||||
# If it's already consumed stream or unexpected type
|
||||
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
|
||||
"""Parses tool calls from a non-streaming OpenAI response."""
|
||||
# This implementation assumes a non-streaming response or a fully buffered stream
|
||||
parsed_calls = []
|
||||
try:
|
||||
if not isinstance(response, ChatCompletion):
|
||||
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
|
||||
# Attempt to handle buffered stream if possible? Complex.
|
||||
return []
|
||||
|
||||
tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
|
||||
for call in tool_calls:
|
||||
if call.type == "function":
|
||||
# Attempt to parse server_name from function name if prefixed
|
||||
# e.g., "server-name__actual-tool-name"
|
||||
parts = call.function.name.split("__", 1)
|
||||
if len(parts) == 2:
|
||||
server_name, func_name = parts
|
||||
else:
|
||||
# If no prefix, how do we know the server? Needs refinement.
|
||||
# Defaulting to None or a default server? Log warning.
|
||||
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
|
||||
server_name = None # Or raise error, or use a default?
|
||||
func_name = call.function.name
|
||||
|
||||
parsed_calls.append({
|
||||
"id": call.id,
|
||||
"server_name": server_name, # May be None if not prefixed
|
||||
"function_name": func_name,
|
||||
"arguments": call.function.arguments, # Arguments are already a string here
|
||||
})
|
||||
else:
|
||||
logger.warning(f"Unsupported tool call type: {call.type}")
|
||||
|
||||
return parsed_calls
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
|
||||
return [] # Return empty list on error
|
||||
|
||||
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
"""Formats a tool result for an OpenAI follow-up request."""
|
||||
# Result might be a dict (including potential errors) or simple string/number
|
||||
# OpenAI expects the content to be a string, often JSON.
|
||||
try:
|
||||
if isinstance(result, dict):
|
||||
content = json.dumps(result)
|
||||
else:
|
||||
content = str(result) # Ensure it's a string
|
||||
except Exception as e:
|
||||
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
|
||||
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
|
||||
|
||||
logger.debug(f"Formatting tool result for call ID {tool_call_id}")
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Converts internal tool format to OpenAI's format."""
|
||||
openai_tools = []
|
||||
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
|
||||
for tool in tools:
|
||||
server_name = tool.get("server_name")
|
||||
tool_name = tool.get("name")
|
||||
description = tool.get("description")
|
||||
input_schema = tool.get("inputSchema")
|
||||
|
||||
if not server_name or not tool_name or not description or not input_schema:
|
||||
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
|
||||
continue
|
||||
|
||||
# Prefix tool name with server name to avoid clashes and allow routing
|
||||
prefixed_tool_name = f"{server_name}__{tool_name}"
|
||||
|
||||
openai_tool_format = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": prefixed_tool_name,
|
||||
"description": description,
|
||||
"parameters": input_schema, # OpenAI uses JSON Schema directly
|
||||
},
|
||||
}
|
||||
openai_tools.append(openai_tool_format)
|
||||
logger.debug(f"Converted tool: {prefixed_tool_name}")
|
||||
|
||||
return openai_tools
|
||||
|
||||
# Helper needed by LLMClient's current tool handling logic
|
||||
def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
|
||||
"""Extracts the assistant's message containing tool calls."""
|
||||
try:
|
||||
if isinstance(response, ChatCompletion) and response.choices[0].message.tool_calls:
|
||||
message = response.choices[0].message
|
||||
# Convert Pydantic model to dict for message history
|
||||
return message.model_dump(exclude_unset=True)
|
||||
else:
|
||||
logger.warning("Could not extract original message with tool calls from response.")
|
||||
# Return a placeholder or raise error?
|
||||
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
|
||||
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
|
||||
|
||||
def get_usage(self, response: Any) -> dict[str, int] | None:
|
||||
"""Extracts token usage from a non-streaming OpenAI response."""
|
||||
try:
|
||||
if isinstance(response, ChatCompletion) and response.usage:
|
||||
usage = {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
# "total_tokens": response.usage.total_tokens, # Optional
|
||||
}
|
||||
logger.debug(f"Extracted usage from OpenAI response: {usage}")
|
||||
return usage
|
||||
else:
|
||||
logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting usage from OpenAI response: {e}", exc_info=True)
|
||||
return None
|
||||
66
src/providers/openai_provider/__init__.py
Normal file
66
src/providers/openai_provider/__init__.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# src/providers/openai_provider/__init__.py
|
||||
from typing import Any
|
||||
|
||||
from openai import Stream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
from providers.openai_provider.client import initialize_client
|
||||
from providers.openai_provider.completion import create_chat_completion
|
||||
from providers.openai_provider.response import get_content, get_streaming_content, get_usage
|
||||
from providers.openai_provider.tools import (
|
||||
convert_tools,
|
||||
format_tool_results,
|
||||
get_original_message_with_calls,
|
||||
has_tool_calls,
|
||||
parse_tool_calls,
|
||||
)
|
||||
from src.providers.base import BaseProvider
|
||||
|
||||
|
||||
class OpenAIProvider(BaseProvider):
|
||||
"""Provider implementation for OpenAI and compatible APIs."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None):
|
||||
# BaseProvider __init__ might not be needed if client init handles base_url logic
|
||||
# super().__init__(api_key, base_url) # Let's see if we need this
|
||||
self.client = initialize_client(api_key, base_url)
|
||||
# Store api_key and base_url if needed by BaseProvider or other methods
|
||||
self.api_key = api_key
|
||||
self.base_url = self.client.base_url # Get effective base_url from client
|
||||
|
||||
def create_chat_completion(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.4,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> Stream[ChatCompletionChunk] | ChatCompletion:
|
||||
# Pass self (provider instance) to the helper function
|
||||
return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)
|
||||
|
||||
def get_streaming_content(self, response: Stream[ChatCompletionChunk]):
|
||||
return get_streaming_content(response)
|
||||
|
||||
def get_content(self, response: ChatCompletion) -> str:
|
||||
return get_content(response)
|
||||
|
||||
def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
|
||||
# This method might need the full response after streaming, handled by LLMClient
|
||||
return has_tool_calls(response)
|
||||
|
||||
def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
|
||||
return parse_tool_calls(response)
|
||||
|
||||
def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
return format_tool_results(tool_call_id, result)
|
||||
|
||||
def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
return convert_tools(tools)
|
||||
|
||||
def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
|
||||
return get_original_message_with_calls(response)
|
||||
|
||||
def get_usage(self, response: Any) -> dict[str, int] | None:
|
||||
return get_usage(response)
|
||||
23
src/providers/openai_provider/client.py
Normal file
23
src/providers/openai_provider/client.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# src/providers/openai_provider/client.py
|
||||
import logging
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from src.llm_models import MODELS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_client(api_key: str, base_url: str | None = None) -> OpenAI:
|
||||
"""Initializes and returns an OpenAI client instance."""
|
||||
# Use default OpenAI endpoint if base_url is not provided explicitly
|
||||
effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
|
||||
logger.info(f"Initializing OpenAI client with base URL: {effective_base_url}")
|
||||
try:
|
||||
# TODO: Add default headers if needed, similar to the original openai_client.py?
|
||||
# default_headers={"HTTP-Referer": "...", "X-Title": "..."}
|
||||
client = OpenAI(api_key=api_key, base_url=effective_base_url)
|
||||
return client
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
|
||||
raise
|
||||
80
src/providers/openai_provider/completion.py
Normal file
80
src/providers/openai_provider/completion.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# src/providers/openai_provider/completion.py
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from openai import Stream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
from providers.openai_provider.utils import truncate_messages
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_chat_completion(
|
||||
provider, # The OpenAIProvider instance
|
||||
messages: list[dict[str, str]],
|
||||
model: str,
|
||||
temperature: float = 0.4,
|
||||
max_tokens: int | None = None,
|
||||
stream: bool = True,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> Stream[ChatCompletionChunk] | ChatCompletion:
|
||||
"""Creates a chat completion using the OpenAI API, handling context window truncation."""
|
||||
logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")
|
||||
|
||||
# --- Truncation Step ---
|
||||
truncated_messages, initial_est_tokens, final_est_tokens = truncate_messages(messages, model)
|
||||
# -----------------------
|
||||
|
||||
try:
|
||||
completion_params = {
|
||||
"model": model,
|
||||
"messages": truncated_messages, # Use truncated messages
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": stream,
|
||||
}
|
||||
if tools:
|
||||
completion_params["tools"] = tools
|
||||
completion_params["tool_choice"] = "auto" # Let OpenAI decide when to use tools
|
||||
|
||||
# Remove None values like max_tokens if not provided
|
||||
completion_params = {k: v for k, v in completion_params.items() if v is not None}
|
||||
|
||||
# --- Added Debug Logging ---
|
||||
log_params = completion_params.copy()
|
||||
# Avoid logging full messages if they are too long
|
||||
if "messages" in log_params:
|
||||
log_params["messages"] = [
|
||||
{k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
|
||||
for msg in log_params["messages"][-2:] # Log last 2 messages summary
|
||||
]
|
||||
# Specifically log tools structure if present
|
||||
tools_log = log_params.get("tools", "Not Present")
|
||||
logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
|
||||
logger.debug(f"Full API Params (messages summarized): {log_params}")
|
||||
# --- End Added Debug Logging ---
|
||||
|
||||
response = provider.client.chat.completions.create(**completion_params)
|
||||
logger.debug("OpenAI API call successful.")
|
||||
|
||||
# --- Capture Actual Usage (for UI display later) ---
|
||||
# Log usage if available (primarily non-streaming)
|
||||
actual_usage = None
|
||||
if isinstance(response, ChatCompletion) and response.usage:
|
||||
actual_usage = {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
}
|
||||
logger.info(f"Actual OpenAI API usage: {actual_usage}")
|
||||
# TODO: How to handle usage for streaming responses? Needs investigation.
|
||||
|
||||
# Return the raw response for now. LLMClient will process it.
|
||||
return response
|
||||
# ----------------------------------------------------
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API error: {e}", exc_info=True)
|
||||
# Re-raise for the LLMClient to handle
|
||||
raise
|
||||
69
src/providers/openai_provider/response.py
Normal file
69
src/providers/openai_provider/response.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# src/providers/openai_provider/response.py
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from openai import Stream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_streaming_content(response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
|
||||
"""Yields content chunks from an OpenAI streaming response."""
|
||||
logger.debug("Processing OpenAI stream...")
|
||||
full_delta = ""
|
||||
try:
|
||||
for chunk in response:
|
||||
# Check if choices exist and are not empty
|
||||
if chunk.choices:
|
||||
delta = chunk.choices[0].delta.content
|
||||
if delta:
|
||||
full_delta += delta
|
||||
yield delta
|
||||
# Handle potential finish reasons or other stream elements if needed
|
||||
# else:
|
||||
# logger.debug(f"Stream chunk without choices: {chunk}") # Or handle finish reason etc.
|
||||
logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
|
||||
# Yield an error message? Or let the generator stop?
|
||||
yield json.dumps({"error": f"Stream processing error: {str(e)}"})
|
||||
|
||||
|
||||
def get_content(response: ChatCompletion) -> str:
|
||||
"""Extracts content from a non-streaming OpenAI response."""
|
||||
try:
|
||||
# Check if choices exist and are not empty
|
||||
if response.choices:
|
||||
content = response.choices[0].message.content
|
||||
logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
|
||||
return content or "" # Return empty string if content is None
|
||||
else:
|
||||
logger.warning("No choices found in OpenAI non-streaming response.")
|
||||
return "[No content received]"
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
|
||||
return f"[Error extracting content: {str(e)}]"
|
||||
|
||||
|
||||
def get_usage(response: Any) -> dict[str, int] | None:
|
||||
"""Extracts token usage from a non-streaming OpenAI response."""
|
||||
try:
|
||||
if isinstance(response, ChatCompletion) and response.usage:
|
||||
usage = {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
# "total_tokens": response.usage.total_tokens, # Optional
|
||||
}
|
||||
logger.debug(f"Extracted usage from OpenAI response: {usage}")
|
||||
return usage
|
||||
else:
|
||||
# Don't log warning for streams, as usage isn't expected here
|
||||
if not isinstance(response, Stream):
|
||||
logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting usage from OpenAI response: {e}", exc_info=True)
|
||||
return None
|
||||
170
src/providers/openai_provider/tools.py
Normal file
170
src/providers/openai_provider/tools.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# src/providers/openai_provider/tools.py
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from openai import Stream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
|
||||
"""Checks if the OpenAI response contains tool calls."""
|
||||
try:
|
||||
if isinstance(response, ChatCompletion): # Non-streaming
|
||||
# Check if choices exist and are not empty
|
||||
if response.choices:
|
||||
return bool(response.choices[0].message.tool_calls)
|
||||
else:
|
||||
logger.warning("No choices found in OpenAI non-streaming response for tool check.")
|
||||
return False
|
||||
elif isinstance(response, Stream):
|
||||
# This check remains unreliable for unconsumed streams.
|
||||
# LLMClient needs robust handling after consumption.
|
||||
logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
|
||||
return False # Assume no for unconsumed stream for now
|
||||
else:
|
||||
# If it's already consumed stream or unexpected type
|
||||
logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking for tool calls: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
|
||||
"""Parses tool calls from a non-streaming OpenAI response."""
|
||||
# This implementation assumes a non-streaming response or a fully buffered stream
|
||||
parsed_calls = []
|
||||
try:
|
||||
if not isinstance(response, ChatCompletion):
|
||||
logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
|
||||
return []
|
||||
|
||||
# Check if choices exist and are not empty
|
||||
if not response.choices:
|
||||
logger.warning("No choices found in OpenAI non-streaming response for tool parsing.")
|
||||
return []
|
||||
|
||||
tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
|
||||
for call in tool_calls:
|
||||
if call.type == "function":
|
||||
# Attempt to parse server_name from function name if prefixed
|
||||
# e.g., "server-name__actual-tool-name"
|
||||
parts = call.function.name.split("__", 1)
|
||||
if len(parts) == 2:
|
||||
server_name, func_name = parts
|
||||
else:
|
||||
# If no prefix, how do we know the server? Needs refinement.
|
||||
# Defaulting to None or a default server? Log warning.
|
||||
logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
|
||||
server_name = None # Or raise error, or use a default?
|
||||
func_name = call.function.name
|
||||
|
||||
# Arguments might be a string needing JSON parsing, or already parsed dict
|
||||
arguments_obj = None
|
||||
try:
|
||||
if isinstance(call.function.arguments, str):
|
||||
arguments_obj = json.loads(call.function.arguments)
|
||||
else:
|
||||
# Assuming it might already be a dict if not a string (less common)
|
||||
arguments_obj = call.function.arguments
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(f"Failed to parse JSON arguments for tool {func_name} (ID: {call.id}): {json_err}")
|
||||
logger.error(f"Raw arguments string: {call.function.arguments}")
|
||||
# Decide how to handle: skip tool, pass raw string, pass error?
|
||||
# Passing raw string for now, but this might break consumers.
|
||||
arguments_obj = {"error": "Failed to parse arguments", "raw_arguments": call.function.arguments}
|
||||
|
||||
parsed_calls.append({
|
||||
"id": call.id,
|
||||
"server_name": server_name, # May be None if not prefixed
|
||||
"function_name": func_name,
|
||||
"arguments": arguments_obj, # Pass parsed arguments (or error dict)
|
||||
})
|
||||
else:
|
||||
logger.warning(f"Unsupported tool call type: {call.type}")
|
||||
|
||||
return parsed_calls
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
|
||||
return [] # Return empty list on error
|
||||
|
||||
|
||||
def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
|
||||
"""Formats a tool result for an OpenAI follow-up request."""
|
||||
# Result might be a dict (including potential errors) or simple string/number
|
||||
# OpenAI expects the content to be a string, often JSON.
|
||||
try:
|
||||
if isinstance(result, dict):
|
||||
content = json.dumps(result)
|
||||
elif isinstance(result, str):
|
||||
content = result # Allow plain strings if result is already string
|
||||
else:
|
||||
content = str(result) # Ensure it's a string otherwise
|
||||
except Exception as e:
|
||||
logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
|
||||
content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})
|
||||
|
||||
logger.debug(f"Formatting tool result for call ID {tool_call_id}")
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
|
||||
def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Converts internal tool format to OpenAI's format."""
|
||||
# This function seems identical to the one in src/tools/conversion.py
|
||||
# We can potentially remove it from here and import from the central location.
|
||||
# For now, keep it duplicated to maintain modularity until a decision is made.
|
||||
openai_tools = []
|
||||
logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
|
||||
for tool in tools:
|
||||
server_name = tool.get("server_name")
|
||||
tool_name = tool.get("name")
|
||||
description = tool.get("description")
|
||||
input_schema = tool.get("inputSchema")
|
||||
|
||||
if not server_name or not tool_name or not description or not input_schema:
|
||||
logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
|
||||
continue
|
||||
|
||||
# Prefix tool name with server name to avoid clashes and allow routing
|
||||
prefixed_tool_name = f"{server_name}__{tool_name}"
|
||||
|
||||
openai_tool_format = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": prefixed_tool_name,
|
||||
"description": description,
|
||||
"parameters": input_schema, # OpenAI uses JSON Schema directly
|
||||
},
|
||||
}
|
||||
openai_tools.append(openai_tool_format)
|
||||
logger.debug(f"Converted tool: {prefixed_tool_name}")
|
||||
|
||||
return openai_tools
|
||||
|
||||
|
||||
def get_original_message_with_calls(response: ChatCompletion) -> dict[str, Any]:
|
||||
"""Extracts the assistant's message containing tool calls."""
|
||||
try:
|
||||
if isinstance(response, ChatCompletion) and response.choices and response.choices[0].message.tool_calls:
|
||||
message = response.choices[0].message
|
||||
# Convert Pydantic model to dict for message history
|
||||
return message.model_dump(exclude_unset=True)
|
||||
else:
|
||||
logger.warning("Could not extract original message with tool calls from response.")
|
||||
# Return a placeholder or raise error?
|
||||
return {"role": "assistant", "content": "[Could not extract tool calls message]"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
|
||||
return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
|
||||
114
src/providers/openai_provider/utils.py
Normal file
114
src/providers/openai_provider/utils.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# src/providers/openai_provider/utils.py
|
||||
import logging
|
||||
import math
|
||||
|
||||
from src.llm_models import MODELS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_context_window(model: str) -> int:
|
||||
"""Retrieves the context window size for a given model."""
|
||||
# Default to a safe fallback if model or provider info is missing
|
||||
default_window = 8000
|
||||
try:
|
||||
# Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
|
||||
provider_models = MODELS.get("openai", {}).get("models", [])
|
||||
for m in provider_models:
|
||||
if m.get("id") == model:
|
||||
return m.get("context_window", default_window)
|
||||
# Fallback if specific model ID not found in our list
|
||||
logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
|
||||
return default_window
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
|
||||
return default_window
|
||||
|
||||
|
||||
def estimate_openai_token_count(messages: list[dict[str, str]]) -> int:
|
||||
"""
|
||||
Estimates the token count for OpenAI messages using char count / 4 approximation.
|
||||
Note: This is less accurate than using tiktoken.
|
||||
"""
|
||||
total_chars = 0
|
||||
for message in messages:
|
||||
total_chars += len(message.get("role", ""))
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
total_chars += len(content)
|
||||
# Rough approximation for function/tool call overhead if needed later
|
||||
# Using math.ceil to round up, ensuring we don't underestimate too much.
|
||||
estimated_tokens = math.ceil(total_chars / 4.0)
|
||||
logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
|
||||
return estimated_tokens
|
||||
|
||||
|
||||
def truncate_messages(messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
|
||||
"""
|
||||
Truncates messages from the beginning if estimated token count exceeds the limit.
|
||||
Preserves the first message if it's a system prompt.
|
||||
|
||||
Returns:
|
||||
- The potentially truncated list of messages.
|
||||
- The initial estimated token count.
|
||||
- The final estimated token count after truncation (if any).
|
||||
"""
|
||||
context_limit = get_context_window(model)
|
||||
# Add a buffer to be safer with approximation
|
||||
buffer = 200 # Reduce buffer slightly as we round up now
|
||||
effective_limit = context_limit - buffer
|
||||
|
||||
initial_estimated_count = estimate_openai_token_count(messages)
|
||||
final_estimated_count = initial_estimated_count
|
||||
|
||||
truncated_messages = list(messages) # Make a copy
|
||||
|
||||
# Identify if the first message is a system prompt
|
||||
has_system_prompt = False
|
||||
if truncated_messages and truncated_messages[0].get("role") == "system":
|
||||
has_system_prompt = True
|
||||
# If only system prompt exists, don't truncate further
|
||||
if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
|
||||
logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
|
||||
# Return original messages to avoid removing the only message
|
||||
return messages, initial_estimated_count, final_estimated_count
|
||||
|
||||
while final_estimated_count > effective_limit:
|
||||
if has_system_prompt and len(truncated_messages) <= 1:
|
||||
# Should not happen if check above works, but safety break
|
||||
logger.warning("Truncation stopped: Only system prompt remains.")
|
||||
break
|
||||
if not has_system_prompt and len(truncated_messages) <= 0:
|
||||
logger.warning("Truncation stopped: No messages left.")
|
||||
break # No messages left
|
||||
|
||||
# Determine index to remove: 1 if system prompt exists and list is long enough, else 0
|
||||
remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0
|
||||
|
||||
if remove_index >= len(truncated_messages):
|
||||
logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
|
||||
break # Avoid index error
|
||||
|
||||
removed_message = truncated_messages.pop(remove_index)
|
||||
logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")
|
||||
|
||||
# Recalculate estimated count
|
||||
final_estimated_count = estimate_openai_token_count(truncated_messages)
|
||||
logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")
|
||||
|
||||
# Safety break if list becomes unexpectedly empty
|
||||
if not truncated_messages:
|
||||
logger.warning("Truncation resulted in empty message list.")
|
||||
break
|
||||
|
||||
if initial_estimated_count != final_estimated_count:
|
||||
logger.info(
|
||||
f"Truncated messages for model {model}. "
|
||||
f"Initial estimated tokens: {initial_estimated_count}, "
|
||||
f"Final estimated tokens: {final_estimated_count}, "
|
||||
f"Limit: {context_limit} (Effective: {effective_limit})"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})")
|
||||
|
||||
return truncated_messages, initial_estimated_count, final_estimated_count
|
||||
@@ -11,59 +11,6 @@ from typing import Any
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert_to_openai_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert MCP tools to OpenAI tool definitions.
|
||||
|
||||
Args:
|
||||
mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).
|
||||
|
||||
Returns:
|
||||
List of OpenAI tool definitions.
|
||||
"""
|
||||
openai_tools = []
|
||||
logger.debug(f"Converting {len(mcp_tools)} MCP tools to OpenAI format.")
|
||||
|
||||
for tool in mcp_tools:
|
||||
server_name = tool.get("server_name")
|
||||
tool_name = tool.get("name")
|
||||
description = tool.get("description")
|
||||
input_schema = tool.get("inputSchema")
|
||||
|
||||
if not server_name or not tool_name or not description or not input_schema:
|
||||
logger.warning(f"Skipping invalid MCP tool definition during OpenAI conversion: {tool}")
|
||||
continue
|
||||
|
||||
# Prefix tool name with server name for routing
|
||||
prefixed_tool_name = f"{server_name}__{tool_name}"
|
||||
|
||||
# Initialize the OpenAI tool structure
|
||||
openai_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": prefixed_tool_name,
|
||||
"description": description,
|
||||
"parameters": input_schema, # OpenAI uses JSON Schema directly
|
||||
},
|
||||
}
|
||||
# Basic validation/cleaning of schema if needed could go here
|
||||
if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
|
||||
logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. OpenAI might reject this.")
|
||||
# Ensure basic structure if missing
|
||||
if not isinstance(input_schema, dict):
|
||||
input_schema = {}
|
||||
if "type" not in input_schema:
|
||||
input_schema["type"] = "object"
|
||||
if "properties" not in input_schema:
|
||||
input_schema["properties"] = {}
|
||||
openai_tool["function"]["parameters"] = input_schema
|
||||
|
||||
openai_tools.append(openai_tool)
|
||||
logger.debug(f"Converted MCP tool to OpenAI: {prefixed_tool_name}")
|
||||
|
||||
return openai_tools
|
||||
|
||||
|
||||
def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Convert MCP tools to Google Gemini format (dictionary structure).
|
||||
|
||||
Reference in New Issue
Block a user