Merge pull request #3 from abhishekbhakat/fix-openai-provider
Fix openai provider
@@ -1,7 +1,7 @@
 MODELS = {
     "openai": {
         "name": "OpenAI",
-        "endpoint": "https://api.openai.com/v1/chat/completions",
+        "endpoint": "https://api.openai.com/v1",
         "models": [
             {
                 "id": "gpt-4o",
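
The only change in this hunk is the stored endpoint. The OpenAI Python SDK
joins base_url with each route itself, so keeping the full /chat/completions
path in the config produced doubled URLs. A minimal sketch of the corrected
usage (the API key is a placeholder):

    from openai import OpenAI

    client = OpenAI(
        api_key="sk-...",                      # placeholder
        base_url="https://api.openai.com/v1",  # stop at /v1; the SDK appends routes
    )
    # The SDK resolves this call to {base_url}/chat/completions.
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": "ping"}],
    )
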
@@ -37,6 +37,15 @@ class OpenAIProvider(BaseLLMProvider):
             base_url: Optional base URL for the API (used for OpenRouter)
         """
         self.api_key = api_key
+
+        # Ensure the base_url doesn't end with /chat/completions to prevent URL duplication
+        if base_url and '/chat/completions' in base_url:
+            # Strip the /chat/completions part and ensure we have a proper base URL
+            base_url = base_url.split('/chat/completions')[0]
+            if not base_url.endswith('/v1'):
+                base_url = f"{base_url}/v1" if not base_url.endswith('/') else f"{base_url}v1"
+            logger.info(f"Modified base_url to prevent endpoint duplication: {base_url}")
+
         self.client = OpenAI(api_key=api_key, base_url=base_url)

     def convert_tools(self, airflow_tools: list) -> list:
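
Pulled out as a standalone helper (the name normalize_base_url is ours, for
illustration), the new guard behaves like this:

    def normalize_base_url(base_url: str | None) -> str | None:
        """Strip a trailing /chat/completions and ensure the URL ends in /v1."""
        if base_url and '/chat/completions' in base_url:
            base_url = base_url.split('/chat/completions')[0]
            if not base_url.endswith('/v1'):
                base_url = f"{base_url}/v1" if not base_url.endswith('/') else f"{base_url}v1"
        return base_url

    assert normalize_base_url("https://api.openai.com/v1/chat/completions") == "https://api.openai.com/v1"
    assert normalize_base_url("https://openrouter.ai/api/v1/chat/completions") == "https://openrouter.ai/api/v1"
    assert normalize_base_url("https://api.openai.com/v1") == "https://api.openai.com/v1"  # already clean

Note the endswith('/v1') re-check: a caller passing https://example.com/chat/completions
(no /v1 segment at all) still ends up with https://example.com/v1.
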
@@ -200,7 +209,7 @@ class OpenAIProvider(BaseLLMProvider):
         return results

     def create_follow_up_completion(
-        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None
+        self, messages: list[dict[str, Any]], model: str, temperature: float = 0.4, max_tokens: int | None = None, tool_results: dict[str, Any] = None, original_response: Any = None, stream: bool = False, tools: list[dict[str, Any]] | None = None
     ) -> Any:
         """
         Create a follow-up completion with tool results.
@@ -212,22 +221,51 @@ class OpenAIProvider(BaseLLMProvider):
             max_tokens: Maximum tokens to generate
             tool_results: Results of tool executions
             original_response: Original response with tool calls
+            stream: Whether to stream the response
+            tools: List of tool definitions in OpenAI format

         Returns:
-            OpenAI response object
+            OpenAI response object or StreamingResponse if streaming
         """
         if not original_response or not tool_results:
             return original_response

-        # Get the original message with tool calls
-        original_message = original_response.choices[0].message
+        # Handle StreamingResponse objects
+        if isinstance(original_response, StreamingResponse):
+            logger.info("Processing StreamingResponse in create_follow_up_completion")
+            # Extract tool calls from StreamingResponse
+            tool_calls = []
+            if original_response.tool_call is not None:
+                logger.info(f"Found tool call in StreamingResponse: {original_response.tool_call}")
+                tool_call = original_response.tool_call
+                # Create a simplified tool call structure for the assistant message
+                tool_calls.append({
+                    "id": tool_call.get("id", ""),
+                    "type": "function",
+                    "function": {
+                        "name": tool_call.get("name", ""),
+                        "arguments": json.dumps(tool_call.get("input", {}))
+                    }
+                })
+
+            # Create a new message with the tool calls
+            assistant_message = {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": tool_calls,
+            }
+        else:
+            # Handle regular OpenAI response objects
+            logger.info("Processing regular OpenAI response in create_follow_up_completion")
+            # Get the original message with tool calls
+            original_message = original_response.choices[0].message

             # Create a new message with the tool calls
             assistant_message = {
                 "role": "assistant",
                 "content": None,
                 "tool_calls": [{"id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments}} for tc in original_message.tool_calls],
             }

         # Create tool result messages
         tool_messages = []
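
For reference, the follow-up request replays the sequence the OpenAI chat API
expects: the assistant message that carried the tool_calls, then one
role="tool" message per call id. A minimal sketch of that shape (ids and
payloads illustrative):

    follow_up_messages = [
        {"role": "user", "content": "List my DAGs"},
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [{
                "id": "call_123",
                "type": "function",
                "function": {"name": "list_dags", "arguments": "{}"},
            }],
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "content": '{"dags": ["example_dag"]}',
        },
    ]

The StreamingResponse branch exists because a streamed first response never had
a choices[0].message to read tool calls from, so they are rebuilt from the
tool_call dict collected by the streaming path.
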
@@ -238,14 +276,14 @@ class OpenAIProvider(BaseLLMProvider):
         new_messages = messages + [assistant_message] + tool_messages

         # Make a second request to get the final response
-        logger.info("Making second request with tool results")
+        logger.info(f"Making second request with tool results (stream={stream})")
         return self.create_chat_completion(
             messages=new_messages,
             model=model,
             temperature=temperature,
             max_tokens=max_tokens,
-            stream=False,
-            tools=None,  # No tools needed for follow-up
+            stream=stream,
+            tools=tools,  # Pass tools parameter for follow-up
         )

     def get_content(self, response: Any) -> str:
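
With stream and tools now threaded through, a hypothetical call site (provider,
messages, tool_results, and first_response assumed from an earlier turn) looks
like:

    final = provider.create_follow_up_completion(
        messages=messages,
        model="gpt-4o",
        tool_results=tool_results,
        original_response=first_response,
        stream=True,   # new: stream the follow-up answer to the client
        tools=tools,   # new: keep tools available instead of hardcoding tools=None
    )

Passing tools again matters when the model needs a second tool call to finish
answering; the old tools=None ruled out chained calls in the follow-up.
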
@@ -283,6 +321,9 @@ class OpenAIProvider(BaseLLMProvider):

         def generate():
             nonlocal tool_call, tool_use_detected, current_tool_call
+
+            # Flag to track if we've yielded any content
+            has_yielded_content = False

             for chunk in response:
                 # Check for tool call in the delta
@@ -315,16 +356,25 @@ class OpenAIProvider(BaseLLMProvider):

                     # Update the arguments if they're provided in this chunk
                     if hasattr(delta_tool_call, "function") and hasattr(delta_tool_call.function, "arguments") and delta_tool_call.function.arguments and current_tool_call:
+                        # Instead of trying to parse each chunk as JSON, accumulate the arguments
+                        # and only parse the complete JSON at the end
+                        if "_raw_arguments" not in current_tool_call:
+                            current_tool_call["_raw_arguments"] = ""
+
+                        # Accumulate the raw arguments
+                        current_tool_call["_raw_arguments"] += delta_tool_call.function.arguments
+
+                        # Try to parse the accumulated arguments
                         try:
-                            # Try to parse the arguments JSON
-                            arguments = json.loads(delta_tool_call.function.arguments)
+                            arguments = json.loads(current_tool_call["_raw_arguments"])
                             if isinstance(arguments, dict):
-                                current_tool_call["input"].update(arguments)
+                                # Successfully parsed the complete JSON
+                                current_tool_call["input"] = arguments  # Replace instead of update
                                 # Update the StreamingResponse object's tool_call attribute
                                 streaming_response.tool_call = current_tool_call
                         except json.JSONDecodeError:
-                            # If the arguments are not valid JSON, just log a warning
-                            logger.warning(f"Failed to parse arguments: {delta_tool_call.function.arguments}")
+                            # This is expected for partial JSON - we'll try again with the next chunk
+                            logger.debug(f"Accumulated partial arguments: {current_tool_call['_raw_arguments']}")

                     # Skip yielding content for tool call chunks
                     continue
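
The key fix here: OpenAI streams tool-call arguments as JSON fragments, so
parsing each delta on its own fails until the document is complete. A reduced,
self-contained sketch of the accumulate-then-parse idea (fragment values
illustrative):

    import json

    fragments = ['{"dag_id": "exa', 'mple_dag", "limit"', ': 10}']

    raw = ""
    parsed = None
    for fragment in fragments:
        raw += fragment
        try:
            parsed = json.loads(raw)   # succeeds only once the JSON is complete
        except json.JSONDecodeError:
            pass                       # expected for partial JSON; keep accumulating

    print(parsed)  # {'dag_id': 'example_dag', 'limit': 10}

This is also why the except branch was downgraded from logger.warning to
logger.debug: a decode failure is now the normal mid-stream case, not an error.
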
@@ -332,7 +382,30 @@ class OpenAIProvider(BaseLLMProvider):
                 # For the final chunk, set the tool_call attribute
                 if chunk.choices and hasattr(chunk.choices[0], "finish_reason") and chunk.choices[0].finish_reason == "tool_calls":
                     logger.info("Streaming response finished with tool_calls reason")
+
+                    # If we haven't yielded any content yet and we're finishing with tool_calls,
+                    # yield a placeholder message so the frontend has something to display
+                    if not has_yielded_content and tool_use_detected:
+                        logger.info("Yielding placeholder content for tool call")
+                        yield "I'll help you with that."  # Simple placeholder message
+                        has_yielded_content = True
                     if current_tool_call:
+                        # One final attempt to parse the arguments if we have accumulated raw arguments
+                        if "_raw_arguments" in current_tool_call and current_tool_call["_raw_arguments"]:
+                            try:
+                                arguments = json.loads(current_tool_call["_raw_arguments"])
+                                if isinstance(arguments, dict):
+                                    current_tool_call["input"] = arguments
+                            except json.JSONDecodeError:
+                                logger.warning(f"Failed to parse final arguments: {current_tool_call['_raw_arguments']}")
+                                # If we still can't parse it, use an empty dict as fallback
+                                if not current_tool_call["input"]:
+                                    current_tool_call["input"] = {}
+
+                            # Remove the raw arguments from the final tool call
+                            if "_raw_arguments" in current_tool_call:
+                                del current_tool_call["_raw_arguments"]
+
                         tool_call = current_tool_call
                         logger.info(f"Final tool call: {json.dumps(tool_call)}")
                         # Update the StreamingResponse object's tool_call attribute
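
A reduced sketch of the "always yield something" guarantee this adds (chunk
handling simplified to (text, finish_reason) pairs; real chunks are OpenAI
delta objects, and the real check also consults tool_use_detected):

    def generate_with_placeholder(chunks):
        has_yielded_content = False
        for text, finish_reason in chunks:
            if finish_reason == "tool_calls" and not has_yielded_content:
                yield "I'll help you with that."  # placeholder for tool-only turns
                has_yielded_content = True
            elif text:
                yield text
                has_yielded_content = True

    # A pure tool-call turn still produces one displayable chunk:
    print(list(generate_with_placeholder([(None, "tool_calls")])))
    # ["I'll help you with that."]
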
@@ -343,6 +416,7 @@ class OpenAIProvider(BaseLLMProvider):
                 if chunk.choices and hasattr(chunk.choices[0].delta, "content") and chunk.choices[0].delta.content:
                     content = chunk.choices[0].delta.content
                     yield content
+                    has_yielded_content = True

         # Create the generator
         gen = generate()
@@ -66,6 +66,15 @@ def convert_to_openai_tools(airflow_tools: list) -> list:
         # Add default value if available
         if "default" in param_info and param_info["default"] is not None:
             param_def["default"] = param_info["default"]
+
+        # Add items property for array types
+        if param_def.get("type") == "array" and "items" not in param_def:
+            # If items is defined in the original schema, use it
+            if "items" in param_info:
+                param_def["items"] = param_info["items"]
+            else:
+                # Otherwise, default to string items
+                param_def["items"] = {"type": "string"}

         # Add to properties
         openai_tool["function"]["parameters"]["properties"][param_name] = param_def
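
This guards against a real OpenAI API constraint: function-tool parameters of
type "array" must carry an "items" schema, or the request is rejected. The
shape being guaranteed (tool name and fields illustrative):

    tool = {
        "type": "function",
        "function": {
            "name": "trigger_dags",
            "description": "Trigger one or more DAGs",
            "parameters": {
                "type": "object",
                "properties": {
                    "dag_ids": {
                        "type": "array",
                        "items": {"type": "string"},  # required for array params
                    },
                },
                "required": ["dag_ids"],
            },
        },
    }
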
@@ -141,3 +150,7 @@ def _handle_schema_construct(param_def: dict[str, Any], param_info: dict[str, Any]
     # If no type was found, default to string
     if "type" not in param_def:
         param_def["type"] = "string"
+
+    # Add items property for array types
+    if param_def.get("type") == "array" and "items" not in param_def:
+        param_def["items"] = {"type": "string"}
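
Illustrative effect of the fallback, using the helper's in-place contract from
the hunk header, and assuming the construct handling earlier in the function
falls through for an empty param_info:

    param_def = {"type": "array"}
    _handle_schema_construct(param_def, {})
    assert param_def == {"type": "array", "items": {"type": "string"}}
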