This commit is contained in:
godnight10061 2025-11-14 11:40:36 +08:00 committed by GitHub
commit e0c6ddb8f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 199 additions and 57 deletions

2
tests/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
# Tests for TradingAgents

View File

@@ -0,0 +1,135 @@
"""
Tests for openai dataflow module to ensure compatibility with different LLM providers.
This test reproduces issue #275 where Gemini and OpenRouter fail with openai vendor.
"""
import pytest
from unittest.mock import Mock, patch, MagicMock
from tradingagents.dataflows.openai import (
get_stock_news_openai,
get_global_news_openai,
get_fundamentals_openai
)
class TestOpenAIDataflowCompatibility:
    """Test that openai dataflow functions work with different LLM providers.

    Reproduces issue #275: Gemini and OpenRouter fail when the openai vendor
    uses the Responses API. The fixed dataflow functions use the standard
    Chat Completions API, which all three providers' OpenAI-compatible
    endpoints support. Each test patches the OpenAI client so no network
    calls are made.
    """

    @staticmethod
    def _mock_chat_client(mock_openai_class, content):
        """Build a mocked OpenAI client whose chat completion returns `content`.

        Centralizes the mock setup that was previously duplicated in every
        test: the client returned by OpenAI(...) yields a response whose
        `choices[0].message.content` equals `content`.
        """
        mock_client = Mock()
        mock_openai_class.return_value = mock_client
        mock_response = Mock()
        mock_response.choices = [Mock()]
        mock_response.choices[0].message.content = content
        mock_client.chat.completions.create.return_value = mock_response
        return mock_client

    @pytest.fixture
    def mock_config_openai(self):
        """Mock config for OpenAI provider."""
        return {
            "backend_url": "https://api.openai.com/v1",
            "quick_think_llm": "gpt-4o-mini",
            "llm_provider": "openai"
        }

    @pytest.fixture
    def mock_config_gemini(self):
        """Mock config for Google Gemini provider."""
        return {
            "backend_url": "https://generativelanguage.googleapis.com/v1",
            "quick_think_llm": "gemini-2.0-flash",
            "llm_provider": "google"
        }

    @pytest.fixture
    def mock_config_openrouter(self):
        """Mock config for OpenRouter provider."""
        return {
            "backend_url": "https://openrouter.ai/api/v1",
            "quick_think_llm": "deepseek/deepseek-chat-v3-0324:free",
            "llm_provider": "openrouter"
        }

    @patch('tradingagents.dataflows.openai.get_config')
    @patch('tradingagents.dataflows.openai.OpenAI')
    def test_get_global_news_with_openai(self, mock_openai_class, mock_get_config, mock_config_openai):
        """Test get_global_news_openai works with OpenAI provider."""
        mock_get_config.return_value = mock_config_openai
        mock_client = self._mock_chat_client(mock_openai_class, "Test news content")

        result = get_global_news_openai("2024-11-09", 7, 5)

        # Verify it went through the standard chat completion API
        assert mock_client.chat.completions.create.called
        assert result == "Test news content"

    @patch('tradingagents.dataflows.openai.get_config')
    @patch('tradingagents.dataflows.openai.OpenAI')
    def test_get_global_news_with_gemini(self, mock_openai_class, mock_get_config, mock_config_gemini):
        """Test get_global_news_openai works with Gemini provider (via OpenAI-compatible API)."""
        mock_get_config.return_value = mock_config_gemini
        mock_client = self._mock_chat_client(mock_openai_class, "Test Gemini news content")

        # Call the function - should not raise an error
        result = get_global_news_openai("2024-11-09", 7, 5)

        # Verify it went through the standard chat completion API
        assert mock_client.chat.completions.create.called
        assert result == "Test Gemini news content"

    @patch('tradingagents.dataflows.openai.get_config')
    @patch('tradingagents.dataflows.openai.OpenAI')
    def test_get_global_news_with_openrouter(self, mock_openai_class, mock_get_config, mock_config_openrouter):
        """Test get_global_news_openai works with OpenRouter provider."""
        mock_get_config.return_value = mock_config_openrouter
        mock_client = self._mock_chat_client(mock_openai_class, "Test OpenRouter news content")

        # Call the function - should not raise an error
        result = get_global_news_openai("2024-11-09", 7, 5)

        # Verify it went through the standard chat completion API
        assert mock_client.chat.completions.create.called
        assert result == "Test OpenRouter news content"

    @patch('tradingagents.dataflows.openai.get_config')
    @patch('tradingagents.dataflows.openai.OpenAI')
    def test_get_fundamentals_with_different_providers(self, mock_openai_class, mock_get_config, mock_config_gemini):
        """Test get_fundamentals_openai works with different providers."""
        mock_get_config.return_value = mock_config_gemini
        mock_client = self._mock_chat_client(mock_openai_class, "Test fundamentals data")

        result = get_fundamentals_openai("AAPL", "2024-11-09")

        # Verify it went through the standard chat completion API
        assert mock_client.chat.completions.create.called
        assert result == "Test fundamentals data"

View File

@@ -3,105 +3,110 @@ from .config import get_config
def get_stock_news_openai(query, start_date, end_date):
    """
    Retrieve stock news using LLM provider configured in backend_url.

    Compatible with OpenAI, Gemini (via OpenAI-compatible API), and OpenRouter.

    Args:
        query: Stock ticker or search query
        start_date: Start date for news search
        end_date: End date for news search

    Returns:
        str: News content as text
    """
    config = get_config()
    client = OpenAI(base_url=config["backend_url"])

    # Use the standard Chat Completions API instead of the OpenAI-only
    # Responses API so Gemini and OpenRouter compatible endpoints work too
    # (issue #275).
    response = client.chat.completions.create(
        model=config["quick_think_llm"],
        messages=[
            {
                "role": "system",
                "content": "You are a financial news analyst. Search and summarize relevant news from social media and news sources."
            },
            {
                "role": "user",
                "content": f"Can you search Social Media for {query} from {start_date} to {end_date}? Make sure you only get the data posted during that period."
            }
        ],
        temperature=1,
        max_tokens=4096,
        top_p=1,
        # NOTE(review): Responses-API-only params (store, tools/web_search_preview)
        # were dropped; non-OpenAI compat endpoints may reject unknown fields.
    )
    return response.choices[0].message.content
def get_global_news_openai(curr_date, look_back_days=7, limit=5):
    """
    Retrieve global news using LLM provider configured in backend_url.

    Compatible with OpenAI, Gemini (via OpenAI-compatible API), and OpenRouter.

    Args:
        curr_date: Current date in yyyy-mm-dd format
        look_back_days: Number of days to look back (default 7)
        limit: Maximum number of articles to return (default 5)

    Returns:
        str: Global news content as text
    """
    config = get_config()
    client = OpenAI(base_url=config["backend_url"])

    # Use the standard Chat Completions API instead of the OpenAI-only
    # Responses API so Gemini and OpenRouter compatible endpoints work too
    # (issue #275).
    response = client.chat.completions.create(
        model=config["quick_think_llm"],
        messages=[
            {
                "role": "system",
                "content": "You are a financial news analyst. Search and summarize relevant global and macroeconomic news for trading purposes."
            },
            {
                "role": "user",
                "content": f"Can you search global or macroeconomics news from {look_back_days} days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period. Limit the results to {limit} articles."
            }
        ],
        temperature=1,
        max_tokens=4096,
        top_p=1,
        # NOTE(review): Responses-API-only params (store, tools/web_search_preview)
        # were dropped; non-OpenAI compat endpoints may reject unknown fields.
    )
    return response.choices[0].message.content
def get_fundamentals_openai(ticker, curr_date):
    """
    Retrieve fundamental data using LLM provider configured in backend_url.

    Compatible with OpenAI, Gemini (via OpenAI-compatible API), and OpenRouter.

    Args:
        ticker: Stock ticker symbol
        curr_date: Current date in yyyy-mm-dd format

    Returns:
        str: Fundamental data as text (table format with PE/PS/Cash flow etc)
    """
    config = get_config()
    client = OpenAI(base_url=config["backend_url"])

    # Use the standard Chat Completions API instead of the OpenAI-only
    # Responses API so Gemini and OpenRouter compatible endpoints work too
    # (issue #275).
    response = client.chat.completions.create(
        model=config["quick_think_llm"],
        messages=[
            {
                "role": "system",
                "content": "You are a financial analyst. Search and provide fundamental data for stocks in a structured table format."
            },
            {
                "role": "user",
                "content": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc"
            }
        ],
        temperature=1,
        max_tokens=4096,
        top_p=1,
        # NOTE(review): Responses-API-only params (store, tools/web_search_preview)
        # were dropped; non-OpenAI compat endpoints may reject unknown fields.
    )
    return response.choices[0].message.content