feat: implement OpenAIProvider with client initialization, message handling, and utility functions
src/providers/openai_provider.py (deleted, 390 lines)
@@ -1,390 +0,0 @@
# src/providers/openai_provider.py
import json
import logging
import math
from collections.abc import Generator
from typing import Any

from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall

from src.llm_models import MODELS
from src.providers.base import BaseProvider

logger = logging.getLogger(__name__)


class OpenAIProvider(BaseProvider):
    """Provider implementation for OpenAI and compatible APIs."""

    def __init__(self, api_key: str, base_url: str | None = None):
        # Use default OpenAI endpoint if base_url is not provided
        effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
        super().__init__(api_key, effective_base_url)
        logger.info(f"Initializing OpenAIProvider with base URL: {self.base_url}")
        try:
            # TODO: Add default headers like in original client?
            self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
        except Exception as e:
            logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
            raise

    def _get_context_window(self, model: str) -> int:
        """Retrieves the context window size for a given model."""
        # Default to a safe fallback if model or provider info is missing
        default_window = 8000
        try:
            # Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
            provider_models = MODELS.get("openai", {}).get("models", [])
            for m in provider_models:
                if m.get("id") == model:
                    return m.get("context_window", default_window)
            # Fallback if specific model ID not found in our list
            logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
            return default_window
        except Exception as e:
            logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
            return default_window

    def _estimate_openai_token_count(self, messages: list[dict[str, str]]) -> int:
        """
        Estimates the token count for OpenAI messages using char count / 4 approximation.
        Note: This is less accurate than using tiktoken.
        """
        total_chars = 0
        for message in messages:
            total_chars += len(message.get("role", ""))
            content = message.get("content")
            if isinstance(content, str):
                total_chars += len(content)
            # Rough approximation for function/tool call overhead if needed later
        # Using math.ceil to round up, ensuring we don't underestimate too much.
        estimated_tokens = math.ceil(total_chars / 4.0)
        logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
        return estimated_tokens

    def _truncate_messages(self, messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
        """
        Truncates messages from the beginning if estimated token count exceeds the limit.
        Preserves the first message if it's a system prompt.

        Returns:
            - The potentially truncated list of messages.
            - The initial estimated token count.
            - The final estimated token count after truncation (if any).
        """
        context_limit = self._get_context_window(model)
        # Add a buffer to be safer with approximation
        buffer = 200  # Reduce buffer slightly as we round up now
        effective_limit = context_limit - buffer

        initial_estimated_count = self._estimate_openai_token_count(messages)
        final_estimated_count = initial_estimated_count

        truncated_messages = list(messages)  # Make a copy

        # Identify if the first message is a system prompt
        has_system_prompt = False
        if truncated_messages and truncated_messages[0].get("role") == "system":
            has_system_prompt = True
            # If only system prompt exists, don't truncate further
            if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
                logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
                # Return original messages to avoid removing the only message
                return messages, initial_estimated_count, final_estimated_count

        while final_estimated_count > effective_limit:
            if has_system_prompt and len(truncated_messages) <= 1:
                # Should not happen if check above works, but safety break
                logger.warning("Truncation stopped: Only system prompt remains.")
                break
            if not has_system_prompt and len(truncated_messages) <= 0:
                logger.warning("Truncation stopped: No messages left.")
                break  # No messages left

            # Determine index to remove: 1 if system prompt exists and list is long enough, else 0
            remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0

            if remove_index >= len(truncated_messages):
                logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
                break  # Avoid index error

            removed_message = truncated_messages.pop(remove_index)
            logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")

            # Recalculate estimated count
            final_estimated_count = self._estimate_openai_token_count(truncated_messages)
            logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")

            # Safety break if list becomes unexpectedly empty
            if not truncated_messages:
                logger.warning("Truncation resulted in empty message list.")
                break

        if initial_estimated_count != final_estimated_count:
            logger.info(
                f"Truncated messages for model {model}. "
                f"Initial estimated tokens: {initial_estimated_count}, "
                f"Final estimated tokens: {final_estimated_count}, "
                f"Limit: {context_limit} (Effective: {effective_limit})"
            )
        else:
            logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})")

        return truncated_messages, initial_estimated_count, final_estimated_count

    def create_chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        stream: bool = True,
        tools: list[dict[str, Any]] | None = None,
        # Add usage dict to return type hint? Needs careful thought for streaming vs non-streaming
    ) -> Stream[ChatCompletionChunk] | ChatCompletion:  # How to return usage info cleanly?
        """Creates a chat completion using the OpenAI API, handling context window truncation."""
        logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")

        # --- Truncation Step ---
        truncated_messages, initial_est_tokens, final_est_tokens = self._truncate_messages(messages, model)
        # -----------------------

        try:
            completion_params = {
                "model": model,
                "messages": truncated_messages,  # Use truncated messages
                "temperature": temperature,
                "max_tokens": max_tokens,
                "stream": stream,
            }
            if tools:
                completion_params["tools"] = tools
                completion_params["tool_choice"] = "auto"  # Let OpenAI decide when to use tools

            # Remove None values like max_tokens if not provided
            completion_params = {k: v for k, v in completion_params.items() if v is not None}

            # --- Added Debug Logging ---
            log_params = completion_params.copy()
            # Avoid logging full messages if they are too long
            if "messages" in log_params:
                log_params["messages"] = [
                    {k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
                    for msg in log_params["messages"][-2:]  # Log last 2 messages summary
                ]
            # Specifically log tools structure if present
            tools_log = log_params.get("tools", "Not Present")
            logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
            logger.debug(f"Full API Params (messages summarized): {log_params}")
            # --- End Added Debug Logging ---

            response = self.client.chat.completions.create(**completion_params)
            logger.debug("OpenAI API call successful.")

            # --- Capture Actual Usage (for UI display later) ---
            # This part is tricky. Usage info is easily available on the *non-streaming* response.
            # For streaming, it's often not available until the stream is fully consumed,
            # or sometimes via response headers or a final event (provider-dependent).
            # For now, let's focus on getting it from the non-streaming case.
            # We need a way to pass this back alongside the content/stream.
            # Option 1: Modify return type (complex for stream/non-stream union)
            # Option 2: Store it in the provider instance (stateful, maybe bad)
            # Option 3: Have LLMClient handle extraction (requires LLMClient to know response structure)

            # Let's try returning it alongside for non-streaming, and figure out streaming later.
            # This requires changing the BaseProvider interface and LLMClient handling.
            # For now, just log it here.
            actual_usage = None
            if isinstance(response, ChatCompletion) and response.usage:
                actual_usage = {
                    "prompt_tokens": response.usage.prompt_tokens,
                    "completion_tokens": response.usage.completion_tokens,
                    "total_tokens": response.usage.total_tokens,
                }
                logger.info(f"Actual OpenAI API usage: {actual_usage}")
            # TODO: How to handle usage for streaming responses? Needs investigation.

            # Return the raw response for now. LLMClient will process it.
            return response
            # ----------------------------------------------------

        except Exception as e:
            logger.error(f"OpenAI API error: {e}", exc_info=True)
            # Re-raise for the LLMClient to handle
            raise

    def get_streaming_content(self, response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
        """Yields content chunks from an OpenAI streaming response."""
        logger.debug("Processing OpenAI stream...")
        full_delta = ""
        try:
            for chunk in response:
                delta = chunk.choices[0].delta.content
                if delta:
                    full_delta += delta
                    yield delta
            logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
        except Exception as e:
            logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
            # Yield an error message? Or let the generator stop?
            yield json.dumps({"error": f"Stream processing error: {str(e)}"})

    def get_content(self, response: ChatCompletion) -> str:
        """Extracts content from a non-streaming OpenAI response."""
        try:
            content = response.choices[0].message.content
            logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
            return content or ""  # Return empty string if content is None
        except Exception as e:
            logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
            return f"[Error extracting content: {str(e)}]"

    def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
        """Checks if the OpenAI response contains tool calls."""
        try:
            if isinstance(response, ChatCompletion):  # Non-streaming
                return bool(response.choices[0].message.tool_calls)
            elif hasattr(response, "_iterator"):  # Check if it looks like our stream wrapper
                # This is tricky for streams. We'd need to peek at the first chunk(s)
                # or buffer the response. For simplicity, this check might be unreliable
                # for streams *before* they are consumed. LLMClient needs robust handling.
                logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
                # A more robust check would involve consuming the start of the stream
                # or relying on the structure after consumption.
                return False  # Assume no for unconsumed stream for now
            else:
                # If it's already consumed stream or unexpected type
                logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
                return False
        except Exception as e:
            logger.error(f"Error checking for tool calls: {e}", exc_info=True)
            return False

    def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
        """Parses tool calls from a non-streaming OpenAI response."""
        # This implementation assumes a non-streaming response or a fully buffered stream
        parsed_calls = []
        try:
            if not isinstance(response, ChatCompletion):
                logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
                # Attempt to handle buffered stream if possible? Complex.
                return []

            tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
            if not tool_calls:
                return []

            logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
            for call in tool_calls:
                if call.type == "function":
                    # Attempt to parse server_name from function name if prefixed
                    # e.g., "server-name__actual-tool-name"
                    parts = call.function.name.split("__", 1)
                    if len(parts) == 2:
                        server_name, func_name = parts
                    else:
                        # If no prefix, how do we know the server? Needs refinement.
                        # Defaulting to None or a default server? Log warning.
                        logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
                        server_name = None  # Or raise error, or use a default?
                        func_name = call.function.name

                    parsed_calls.append({
                        "id": call.id,
                        "server_name": server_name,  # May be None if not prefixed
                        "function_name": func_name,
                        "arguments": call.function.arguments,  # Arguments are already a string here
                    })
                else:
                    logger.warning(f"Unsupported tool call type: {call.type}")

            return parsed_calls
        except Exception as e:
            logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
            return []  # Return empty list on error

    def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
        """Formats a tool result for an OpenAI follow-up request."""
        # Result might be a dict (including potential errors) or simple string/number
        # OpenAI expects the content to be a string, often JSON.
        try:
            if isinstance(result, dict):
                content = json.dumps(result)
            else:
                content = str(result)  # Ensure it's a string
        except Exception as e:
            logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
            content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})

        logger.debug(f"Formatting tool result for call ID {tool_call_id}")
        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "content": content,
        }

    def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Converts internal tool format to OpenAI's format."""
        openai_tools = []
        logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
        for tool in tools:
            server_name = tool.get("server_name")
            tool_name = tool.get("name")
            description = tool.get("description")
            input_schema = tool.get("inputSchema")

            if not server_name or not tool_name or not description or not input_schema:
                logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
                continue

            # Prefix tool name with server name to avoid clashes and allow routing
            prefixed_tool_name = f"{server_name}__{tool_name}"

            openai_tool_format = {
                "type": "function",
                "function": {
                    "name": prefixed_tool_name,
                    "description": description,
                    "parameters": input_schema,  # OpenAI uses JSON Schema directly
                },
            }
            openai_tools.append(openai_tool_format)
            logger.debug(f"Converted tool: {prefixed_tool_name}")

        return openai_tools

    # Helper needed by LLMClient's current tool handling logic
    def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
        """Extracts the assistant's message containing tool calls."""
        try:
            if isinstance(response, ChatCompletion) and response.choices[0].message.tool_calls:
                message = response.choices[0].message
                # Convert Pydantic model to dict for message history
                return message.model_dump(exclude_unset=True)
            else:
                logger.warning("Could not extract original message with tool calls from response.")
                # Return a placeholder or raise error?
                return {"role": "assistant", "content": "[Could not extract tool calls message]"}
        except Exception as e:
            logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
            return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}

    def get_usage(self, response: Any) -> dict[str, int] | None:
        """Extracts token usage from a non-streaming OpenAI response."""
        try:
            if isinstance(response, ChatCompletion) and response.usage:
                usage = {
                    "prompt_tokens": response.usage.prompt_tokens,
                    "completion_tokens": response.usage.completion_tokens,
                    # "total_tokens": response.usage.total_tokens,  # Optional
                }
                logger.debug(f"Extracted usage from OpenAI response: {usage}")
                return usage
            else:
                logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}")
                return None
        except Exception as e:
            logger.error(f"Error extracting usage from OpenAI response: {e}", exc_info=True)
            return None
src/providers/openai_provider/__init__.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# src/providers/openai_provider/__init__.py
from typing import Any

from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk

from src.providers.openai_provider.client import initialize_client
from src.providers.openai_provider.completion import create_chat_completion
from src.providers.openai_provider.response import get_content, get_streaming_content, get_usage
from src.providers.openai_provider.tools import (
    convert_tools,
    format_tool_results,
    get_original_message_with_calls,
    has_tool_calls,
    parse_tool_calls,
)
from src.providers.base import BaseProvider


class OpenAIProvider(BaseProvider):
    """Provider implementation for OpenAI and compatible APIs."""

    def __init__(self, api_key: str, base_url: str | None = None):
        # BaseProvider __init__ might not be needed if client init handles base_url logic
        # super().__init__(api_key, base_url)  # Let's see if we need this
        self.client = initialize_client(api_key, base_url)
        # Store api_key and base_url if needed by BaseProvider or other methods
        self.api_key = api_key
        self.base_url = self.client.base_url  # Get effective base_url from client

    def create_chat_completion(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.4,
        max_tokens: int | None = None,
        stream: bool = True,
        tools: list[dict[str, Any]] | None = None,
    ) -> Stream[ChatCompletionChunk] | ChatCompletion:
        # Pass self (provider instance) to the helper function
        return create_chat_completion(self, messages, model, temperature, max_tokens, stream, tools)

    def get_streaming_content(self, response: Stream[ChatCompletionChunk]):
        return get_streaming_content(response)

    def get_content(self, response: ChatCompletion) -> str:
        return get_content(response)

    def has_tool_calls(self, response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
        # This method might need the full response after streaming, handled by LLMClient
        return has_tool_calls(response)

    def parse_tool_calls(self, response: ChatCompletion) -> list[dict[str, Any]]:
        return parse_tool_calls(response)

    def format_tool_results(self, tool_call_id: str, result: Any) -> dict[str, Any]:
        return format_tool_results(tool_call_id, result)

    def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        return convert_tools(tools)

    def get_original_message_with_calls(self, response: ChatCompletion) -> dict[str, Any]:
        return get_original_message_with_calls(response)

    def get_usage(self, response: Any) -> dict[str, int] | None:
        return get_usage(response)
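For orientation, a minimal usage sketch of the re-exported facade above; the environment variable name, model id, and call site are illustrative assumptions rather than code from this commit.

# Hypothetical caller of the package-level OpenAIProvider (not part of the diff).
import os

from src.providers.openai_provider import OpenAIProvider

provider = OpenAIProvider(api_key=os.environ["OPENAI_API_KEY"])  # assumed env var name
response = provider.create_chat_completion(
    messages=[{"role": "user", "content": "Hello"}],
    model="gpt-4o-mini",  # assumed model id; should match an entry in MODELS for truncation
    stream=True,
)
for chunk in provider.get_streaming_content(response):
    print(chunk, end="", flush=True)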
src/providers/openai_provider/client.py (new file, 23 lines)
@@ -0,0 +1,23 @@
# src/providers/openai_provider/client.py
import logging

from openai import OpenAI

from src.llm_models import MODELS

logger = logging.getLogger(__name__)


def initialize_client(api_key: str, base_url: str | None = None) -> OpenAI:
    """Initializes and returns an OpenAI client instance."""
    # Use default OpenAI endpoint if base_url is not provided explicitly
    effective_base_url = base_url or MODELS.get("openai", {}).get("endpoint")
    logger.info(f"Initializing OpenAI client with base URL: {effective_base_url}")
    try:
        # TODO: Add default headers if needed, similar to the original openai_client.py?
        # default_headers={"HTTP-Referer": "...", "X-Title": "..."}
        client = OpenAI(api_key=api_key, base_url=effective_base_url)
        return client
    except Exception as e:
        logger.error(f"Failed to initialize OpenAI client: {e}", exc_info=True)
        raise
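On the TODO above: the openai SDK accepts a default_headers mapping at construction time, so the headers from the original client could be restored roughly as sketched below; the header values are placeholders, not values taken from this repository.

# Sketch only: a default-headers variant of initialize_client; header values are placeholders.
from openai import OpenAI


def initialize_client_with_headers(api_key: str, base_url: str | None = None) -> OpenAI:
    return OpenAI(
        api_key=api_key,
        base_url=base_url,
        default_headers={"HTTP-Referer": "https://example.invalid", "X-Title": "my-app"},
    )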
src/providers/openai_provider/completion.py (new file, 80 lines)
@@ -0,0 +1,80 @@
# src/providers/openai_provider/completion.py
import logging
from typing import Any

from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk

from src.providers.openai_provider.utils import truncate_messages

logger = logging.getLogger(__name__)


def create_chat_completion(
    provider,  # The OpenAIProvider instance
    messages: list[dict[str, str]],
    model: str,
    temperature: float = 0.4,
    max_tokens: int | None = None,
    stream: bool = True,
    tools: list[dict[str, Any]] | None = None,
) -> Stream[ChatCompletionChunk] | ChatCompletion:
    """Creates a chat completion using the OpenAI API, handling context window truncation."""
    logger.debug(f"OpenAI create_chat_completion called. Model: {model}, Stream: {stream}, Tools: {bool(tools)}")

    # --- Truncation Step ---
    truncated_messages, initial_est_tokens, final_est_tokens = truncate_messages(messages, model)
    # -----------------------

    try:
        completion_params = {
            "model": model,
            "messages": truncated_messages,  # Use truncated messages
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": stream,
        }
        if tools:
            completion_params["tools"] = tools
            completion_params["tool_choice"] = "auto"  # Let OpenAI decide when to use tools

        # Remove None values like max_tokens if not provided
        completion_params = {k: v for k, v in completion_params.items() if v is not None}

        # --- Added Debug Logging ---
        log_params = completion_params.copy()
        # Avoid logging full messages if they are too long
        if "messages" in log_params:
            log_params["messages"] = [
                {k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in msg.items()}
                for msg in log_params["messages"][-2:]  # Log last 2 messages summary
            ]
        # Specifically log tools structure if present
        tools_log = log_params.get("tools", "Not Present")
        logger.debug(f"Calling OpenAI API. Model: {log_params.get('model')}, Stream: {log_params.get('stream')}, Tools: {tools_log}")
        logger.debug(f"Full API Params (messages summarized): {log_params}")
        # --- End Added Debug Logging ---

        response = provider.client.chat.completions.create(**completion_params)
        logger.debug("OpenAI API call successful.")

        # --- Capture Actual Usage (for UI display later) ---
        # Log usage if available (primarily non-streaming)
        actual_usage = None
        if isinstance(response, ChatCompletion) and response.usage:
            actual_usage = {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            }
            logger.info(f"Actual OpenAI API usage: {actual_usage}")
        # TODO: How to handle usage for streaming responses? Needs investigation.

        # Return the raw response for now. LLMClient will process it.
        return response
        # ----------------------------------------------------

    except Exception as e:
        logger.error(f"OpenAI API error: {e}", exc_info=True)
        # Re-raise for the LLMClient to handle
        raise
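Regarding the streaming-usage TODO: newer Chat Completions API versions can report usage on the final streamed chunk when stream_options={"include_usage": True} is passed. A hedged sketch, assuming an openai SDK recent enough to support that parameter; note that the usage-bearing final chunk arrives with an empty choices list, which the chunk.choices guard in response.py below already tolerates.

# Sketch only: requires an openai SDK / API version that supports stream_options.
def create_streaming_completion_with_usage(provider, messages: list[dict[str, str]], model: str):
    return provider.client.chat.completions.create(
        model=model,
        messages=messages,
        stream=True,
        stream_options={"include_usage": True},  # final chunk carries a usage object
    )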
src/providers/openai_provider/response.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# src/providers/openai_provider/response.py
import json
import logging
from collections.abc import Generator
from typing import Any

from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk

logger = logging.getLogger(__name__)


def get_streaming_content(response: Stream[ChatCompletionChunk]) -> Generator[str, None, None]:
    """Yields content chunks from an OpenAI streaming response."""
    logger.debug("Processing OpenAI stream...")
    full_delta = ""
    try:
        for chunk in response:
            # Check if choices exist and are not empty
            if chunk.choices:
                delta = chunk.choices[0].delta.content
                if delta:
                    full_delta += delta
                    yield delta
            # Handle potential finish reasons or other stream elements if needed
            # else:
            #     logger.debug(f"Stream chunk without choices: {chunk}")  # Or handle finish reason etc.
        logger.debug(f"Stream finished. Total delta length: {len(full_delta)}")
    except Exception as e:
        logger.error(f"Error processing OpenAI stream: {e}", exc_info=True)
        # Yield an error message? Or let the generator stop?
        yield json.dumps({"error": f"Stream processing error: {str(e)}"})


def get_content(response: ChatCompletion) -> str:
    """Extracts content from a non-streaming OpenAI response."""
    try:
        # Check if choices exist and are not empty
        if response.choices:
            content = response.choices[0].message.content
            logger.debug(f"Extracted content (length {len(content) if content else 0}) from non-streaming response.")
            return content or ""  # Return empty string if content is None
        else:
            logger.warning("No choices found in OpenAI non-streaming response.")
            return "[No content received]"
    except Exception as e:
        logger.error(f"Error extracting content from OpenAI response: {e}", exc_info=True)
        return f"[Error extracting content: {str(e)}]"


def get_usage(response: Any) -> dict[str, int] | None:
    """Extracts token usage from a non-streaming OpenAI response."""
    try:
        if isinstance(response, ChatCompletion) and response.usage:
            usage = {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                # "total_tokens": response.usage.total_tokens,  # Optional
            }
            logger.debug(f"Extracted usage from OpenAI response: {usage}")
            return usage
        else:
            # Don't log warning for streams, as usage isn't expected here
            if not isinstance(response, Stream):
                logger.warning(f"Could not extract usage from OpenAI response object of type {type(response)}")
            return None
    except Exception as e:
        logger.error(f"Error extracting usage from OpenAI response: {e}", exc_info=True)
        return None
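Because get_streaming_content reports failures in-band as a JSON string, callers that need to distinguish an error payload from ordinary text have to check for that shape; a minimal, hypothetical consumer:

# Hypothetical consumer; the error check mirrors the json.dumps({"error": ...})
# payload yielded by get_streaming_content on stream failure.
import json


def collect_stream(chunks) -> str:
    parts = []
    for chunk in chunks:
        try:
            maybe_error = json.loads(chunk)
            if isinstance(maybe_error, dict) and "error" in maybe_error:
                raise RuntimeError(maybe_error["error"])
        except json.JSONDecodeError:
            pass  # ordinary text chunk
        parts.append(chunk)
    return "".join(parts)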
src/providers/openai_provider/tools.py (new file, 170 lines)
@@ -0,0 +1,170 @@
# src/providers/openai_provider/tools.py
import json
import logging
from typing import Any

from openai import Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall

logger = logging.getLogger(__name__)


def has_tool_calls(response: Stream[ChatCompletionChunk] | ChatCompletion) -> bool:
    """Checks if the OpenAI response contains tool calls."""
    try:
        if isinstance(response, ChatCompletion):  # Non-streaming
            # Check if choices exist and are not empty
            if response.choices:
                return bool(response.choices[0].message.tool_calls)
            else:
                logger.warning("No choices found in OpenAI non-streaming response for tool check.")
                return False
        elif isinstance(response, Stream):
            # This check remains unreliable for unconsumed streams.
            # LLMClient needs robust handling after consumption.
            logger.warning("has_tool_calls check on a stream is unreliable before consumption.")
            return False  # Assume no for unconsumed stream for now
        else:
            # If it's already consumed stream or unexpected type
            logger.warning(f"has_tool_calls received unexpected type: {type(response)}")
            return False
    except Exception as e:
        logger.error(f"Error checking for tool calls: {e}", exc_info=True)
        return False


def parse_tool_calls(response: ChatCompletion) -> list[dict[str, Any]]:
    """Parses tool calls from a non-streaming OpenAI response."""
    # This implementation assumes a non-streaming response or a fully buffered stream
    parsed_calls = []
    try:
        if not isinstance(response, ChatCompletion):
            logger.error(f"parse_tool_calls expects ChatCompletion, got {type(response)}")
            return []

        # Check if choices exist and are not empty
        if not response.choices:
            logger.warning("No choices found in OpenAI non-streaming response for tool parsing.")
            return []

        tool_calls: list[ChatCompletionMessageToolCall] | None = response.choices[0].message.tool_calls
        if not tool_calls:
            return []

        logger.debug(f"Parsing {len(tool_calls)} tool calls from OpenAI response.")
        for call in tool_calls:
            if call.type == "function":
                # Attempt to parse server_name from function name if prefixed
                # e.g., "server-name__actual-tool-name"
                parts = call.function.name.split("__", 1)
                if len(parts) == 2:
                    server_name, func_name = parts
                else:
                    # If no prefix, how do we know the server? Needs refinement.
                    # Defaulting to None or a default server? Log warning.
                    logger.warning(f"Could not determine server_name from tool name '{call.function.name}'. Assuming default or error needed.")
                    server_name = None  # Or raise error, or use a default?
                    func_name = call.function.name

                # Arguments might be a string needing JSON parsing, or already parsed dict
                arguments_obj = None
                try:
                    if isinstance(call.function.arguments, str):
                        arguments_obj = json.loads(call.function.arguments)
                    else:
                        # Assuming it might already be a dict if not a string (less common)
                        arguments_obj = call.function.arguments
                except json.JSONDecodeError as json_err:
                    logger.error(f"Failed to parse JSON arguments for tool {func_name} (ID: {call.id}): {json_err}")
                    logger.error(f"Raw arguments string: {call.function.arguments}")
                    # Decide how to handle: skip tool, pass raw string, pass error?
                    # Passing raw string for now, but this might break consumers.
                    arguments_obj = {"error": "Failed to parse arguments", "raw_arguments": call.function.arguments}

                parsed_calls.append({
                    "id": call.id,
                    "server_name": server_name,  # May be None if not prefixed
                    "function_name": func_name,
                    "arguments": arguments_obj,  # Pass parsed arguments (or error dict)
                })
            else:
                logger.warning(f"Unsupported tool call type: {call.type}")

        return parsed_calls
    except Exception as e:
        logger.error(f"Error parsing OpenAI tool calls: {e}", exc_info=True)
        return []  # Return empty list on error


def format_tool_results(tool_call_id: str, result: Any) -> dict[str, Any]:
    """Formats a tool result for an OpenAI follow-up request."""
    # Result might be a dict (including potential errors) or simple string/number
    # OpenAI expects the content to be a string, often JSON.
    try:
        if isinstance(result, dict):
            content = json.dumps(result)
        elif isinstance(result, str):
            content = result  # Allow plain strings if result is already string
        else:
            content = str(result)  # Ensure it's a string otherwise
    except Exception as e:
        logger.error(f"Error JSON-encoding tool result for {tool_call_id}: {e}")
        content = json.dumps({"error": "Failed to encode tool result", "original_type": str(type(result))})

    logger.debug(f"Formatting tool result for call ID {tool_call_id}")
    return {
        "role": "tool",
        "tool_call_id": tool_call_id,
        "content": content,
    }


def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Converts internal tool format to OpenAI's format."""
    # This function seems identical to the one in src/tools/conversion.py
    # We can potentially remove it from here and import from the central location.
    # For now, keep it duplicated to maintain modularity until a decision is made.
    openai_tools = []
    logger.debug(f"Converting {len(tools)} tools to OpenAI format.")
    for tool in tools:
        server_name = tool.get("server_name")
        tool_name = tool.get("name")
        description = tool.get("description")
        input_schema = tool.get("inputSchema")

        if not server_name or not tool_name or not description or not input_schema:
            logger.warning(f"Skipping invalid tool definition during conversion: {tool}")
            continue

        # Prefix tool name with server name to avoid clashes and allow routing
        prefixed_tool_name = f"{server_name}__{tool_name}"

        openai_tool_format = {
            "type": "function",
            "function": {
                "name": prefixed_tool_name,
                "description": description,
                "parameters": input_schema,  # OpenAI uses JSON Schema directly
            },
        }
        openai_tools.append(openai_tool_format)
        logger.debug(f"Converted tool: {prefixed_tool_name}")

    return openai_tools


def get_original_message_with_calls(response: ChatCompletion) -> dict[str, Any]:
    """Extracts the assistant's message containing tool calls."""
    try:
        if isinstance(response, ChatCompletion) and response.choices and response.choices[0].message.tool_calls:
            message = response.choices[0].message
            # Convert Pydantic model to dict for message history
            return message.model_dump(exclude_unset=True)
        else:
            logger.warning("Could not extract original message with tool calls from response.")
            # Return a placeholder or raise error?
            return {"role": "assistant", "content": "[Could not extract tool calls message]"}
    except Exception as e:
        logger.error(f"Error extracting original message with calls: {e}", exc_info=True)
        return {"role": "assistant", "content": f"[Error extracting tool calls message: {str(e)}]"}
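To make the server__tool prefixing concrete, a small worked example of what convert_tools produces for one hypothetical MCP tool definition, and how parse_tool_calls splits the name back apart:

# Hypothetical tool definition; the field values are illustrative only.
mcp_tool = {
    "server_name": "filesystem",
    "name": "read_file",
    "description": "Read a file from disk.",
    "inputSchema": {"type": "object", "properties": {"path": {"type": "string"}}},
}

# convert_tools([mcp_tool]) yields the OpenAI function-tool shape:
# {
#     "type": "function",
#     "function": {
#         "name": "filesystem__read_file",
#         "description": "Read a file from disk.",
#         "parameters": {"type": "object", "properties": {"path": {"type": "string"}}},
#     },
# }
# parse_tool_calls later splits "filesystem__read_file" on "__" back into
# server_name="filesystem" and function_name="read_file".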
src/providers/openai_provider/utils.py (new file, 114 lines)
@@ -0,0 +1,114 @@
# src/providers/openai_provider/utils.py
import logging
import math

from src.llm_models import MODELS

logger = logging.getLogger(__name__)


def get_context_window(model: str) -> int:
    """Retrieves the context window size for a given model."""
    # Default to a safe fallback if model or provider info is missing
    default_window = 8000
    try:
        # Assuming MODELS structure: MODELS['openai']['models'] is a list of dicts
        provider_models = MODELS.get("openai", {}).get("models", [])
        for m in provider_models:
            if m.get("id") == model:
                return m.get("context_window", default_window)
        # Fallback if specific model ID not found in our list
        logger.warning(f"Context window for OpenAI model '{model}' not found in MODELS config. Using default: {default_window}")
        return default_window
    except Exception as e:
        logger.error(f"Error retrieving context window for model {model}: {e}. Using default: {default_window}", exc_info=True)
        return default_window


def estimate_openai_token_count(messages: list[dict[str, str]]) -> int:
    """
    Estimates the token count for OpenAI messages using char count / 4 approximation.
    Note: This is less accurate than using tiktoken.
    """
    total_chars = 0
    for message in messages:
        total_chars += len(message.get("role", ""))
        content = message.get("content")
        if isinstance(content, str):
            total_chars += len(content)
        # Rough approximation for function/tool call overhead if needed later
    # Using math.ceil to round up, ensuring we don't underestimate too much.
    estimated_tokens = math.ceil(total_chars / 4.0)
    logger.debug(f"Estimated OpenAI token count (char/4): {estimated_tokens} for {len(messages)} messages")
    return estimated_tokens


def truncate_messages(messages: list[dict[str, str]], model: str) -> tuple[list[dict[str, str]], int, int]:
    """
    Truncates messages from the beginning if estimated token count exceeds the limit.
    Preserves the first message if it's a system prompt.

    Returns:
        - The potentially truncated list of messages.
        - The initial estimated token count.
        - The final estimated token count after truncation (if any).
    """
    context_limit = get_context_window(model)
    # Add a buffer to be safer with approximation
    buffer = 200  # Reduce buffer slightly as we round up now
    effective_limit = context_limit - buffer

    initial_estimated_count = estimate_openai_token_count(messages)
    final_estimated_count = initial_estimated_count

    truncated_messages = list(messages)  # Make a copy

    # Identify if the first message is a system prompt
    has_system_prompt = False
    if truncated_messages and truncated_messages[0].get("role") == "system":
        has_system_prompt = True
        # If only system prompt exists, don't truncate further
        if len(truncated_messages) == 1 and final_estimated_count > effective_limit:
            logger.warning(f"System prompt alone ({final_estimated_count} tokens) exceeds effective limit ({effective_limit}). Cannot truncate further.")
            # Return original messages to avoid removing the only message
            return messages, initial_estimated_count, final_estimated_count

    while final_estimated_count > effective_limit:
        if has_system_prompt and len(truncated_messages) <= 1:
            # Should not happen if check above works, but safety break
            logger.warning("Truncation stopped: Only system prompt remains.")
            break
        if not has_system_prompt and len(truncated_messages) <= 0:
            logger.warning("Truncation stopped: No messages left.")
            break  # No messages left

        # Determine index to remove: 1 if system prompt exists and list is long enough, else 0
        remove_index = 1 if has_system_prompt and len(truncated_messages) > 1 else 0

        if remove_index >= len(truncated_messages):
            logger.error(f"Truncation logic error: remove_index {remove_index} out of bounds for {len(truncated_messages)} messages.")
            break  # Avoid index error

        removed_message = truncated_messages.pop(remove_index)
        logger.debug(f"Truncating message at index {remove_index} (Role: {removed_message.get('role')}) due to context limit.")

        # Recalculate estimated count
        final_estimated_count = estimate_openai_token_count(truncated_messages)
        logger.debug(f"Recalculated estimated tokens: {final_estimated_count}")

        # Safety break if list becomes unexpectedly empty
        if not truncated_messages:
            logger.warning("Truncation resulted in empty message list.")
            break

    if initial_estimated_count != final_estimated_count:
        logger.info(
            f"Truncated messages for model {model}. "
            f"Initial estimated tokens: {initial_estimated_count}, "
            f"Final estimated tokens: {final_estimated_count}, "
            f"Limit: {context_limit} (Effective: {effective_limit})"
        )
    else:
        logger.debug(f"No truncation needed for model {model}. Estimated tokens: {final_estimated_count}, Limit: {context_limit} (Effective: {effective_limit})")

    return truncated_messages, initial_estimated_count, final_estimated_count
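The estimate_openai_token_count docstring notes that the char/4 heuristic is less accurate than tiktoken; a possible drop-in refinement, assuming tiktoken were added as a dependency (this commit does not use it):

# Sketch only: requires tiktoken, which this commit does not depend on.
import tiktoken


def estimate_token_count_tiktoken(messages: list[dict[str, str]], model: str) -> int:
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")  # fallback for unknown model ids
    total = 0
    for message in messages:
        total += len(encoding.encode(message.get("role", "")))
        content = message.get("content")
        if isinstance(content, str):
            total += len(encoding.encode(content))
    return total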
@@ -11,59 +11,6 @@ from typing import Any
logger = logging.getLogger(__name__)


def convert_to_openai_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """
    Convert MCP tools to OpenAI tool definitions.

    Args:
        mcp_tools: List of MCP tools (each with server_name, name, description, inputSchema).

    Returns:
        List of OpenAI tool definitions.
    """
    openai_tools = []
    logger.debug(f"Converting {len(mcp_tools)} MCP tools to OpenAI format.")

    for tool in mcp_tools:
        server_name = tool.get("server_name")
        tool_name = tool.get("name")
        description = tool.get("description")
        input_schema = tool.get("inputSchema")

        if not server_name or not tool_name or not description or not input_schema:
            logger.warning(f"Skipping invalid MCP tool definition during OpenAI conversion: {tool}")
            continue

        # Prefix tool name with server name for routing
        prefixed_tool_name = f"{server_name}__{tool_name}"

        # Initialize the OpenAI tool structure
        openai_tool = {
            "type": "function",
            "function": {
                "name": prefixed_tool_name,
                "description": description,
                "parameters": input_schema,  # OpenAI uses JSON Schema directly
            },
        }
        # Basic validation/cleaning of schema if needed could go here
        if not isinstance(input_schema, dict) or input_schema.get("type") != "object":
            logger.warning(f"Input schema for tool '{prefixed_tool_name}' is not a valid JSON object schema. OpenAI might reject this.")
            # Ensure basic structure if missing
            if not isinstance(input_schema, dict):
                input_schema = {}
            if "type" not in input_schema:
                input_schema["type"] = "object"
            if "properties" not in input_schema:
                input_schema["properties"] = {}
            openai_tool["function"]["parameters"] = input_schema

        openai_tools.append(openai_tool)
        logger.debug(f"Converted MCP tool to OpenAI: {prefixed_tool_name}")

    return openai_tools


def convert_to_google_tools(mcp_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """
    Convert MCP tools to Google Gemini format (dictionary structure).