# File: TradingAgents/tests/graph/test_trading_graph.py
#
# Viewer metadata: 527 lines, 22 KiB, Python

import pytest
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime, date
from tradingagents.graph.trading_graph import TradingAgentsGraph, DiscoveryTimeoutException
from tradingagents.agents.discovery import (
DiscoveryRequest,
DiscoveryResult,
DiscoveryStatus,
TrendingStock,
Sector,
EventCategory,
NewsArticle,
)
class TestTradingAgentsGraphInit:
    """Test suite for TradingAgentsGraph initialization.

    Every test replaces ``ChatOpenAI``, ``FinancialSituationMemory`` and
    ``GraphSetup`` in ``tradingagents.graph.trading_graph`` with mocks for
    the duration of the test. ``patch`` decorators are applied bottom-up,
    so the injected mocks arrive as ``(mock_setup, mock_memory, mock_llm)``.
    """

    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_init_with_default_config(self, mock_setup, mock_memory, mock_llm):
        """Test initialization with default configuration."""
        graph = TradingAgentsGraph(debug=False)

        # Identity check: the flag must be the boolean False itself, not
        # merely something that compares equal to it.
        assert graph.debug is False
        assert graph.config is not None
        assert "llm_provider" in graph.config

    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_init_with_custom_config(self, mock_setup, mock_memory, mock_llm):
        """Test initialization with custom configuration."""
        custom_config = {
            "llm_provider": "openai",
            "deep_think_llm": "gpt-4",
            "quick_think_llm": "gpt-3.5-turbo",
            "backend_url": "https://api.openai.com/v1",
            "max_debate_rounds": 3,
            "max_risk_discuss_rounds": 2,
            "max_recur_limit": 100,
            "project_dir": "/tmp/test",
            "data_vendors": {},
            "tool_vendors": {},
        }

        graph = TradingAgentsGraph(debug=True, config=custom_config)

        assert graph.config["llm_provider"] == "openai"
        assert graph.config["max_debate_rounds"] == 3

    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_init_with_anthropic_provider(self, mock_setup, mock_memory, mock_llm):
        """Test initialization with Anthropic provider."""
        with patch('tradingagents.graph.trading_graph.ChatAnthropic') as mock_anthropic:
            config = {
                "llm_provider": "anthropic",
                "deep_think_llm": "claude-3-opus",
                "quick_think_llm": "claude-3-haiku",
                "backend_url": "https://api.anthropic.com",
                "project_dir": "/tmp/test",
                "data_vendors": {},
                "tool_vendors": {},
                "max_debate_rounds": 2,
                "max_risk_discuss_rounds": 2,
                "max_recur_limit": 100,
            }

            # Only construction matters here; the instance itself is unused.
            TradingAgentsGraph(config=config)

            assert mock_anthropic.called

    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_init_with_google_provider(self, mock_setup, mock_memory, mock_llm):
        """Test initialization with Google provider."""
        with patch('tradingagents.graph.trading_graph.ChatGoogleGenerativeAI') as mock_google:
            # Note: this config intentionally carries no "backend_url" key.
            config = {
                "llm_provider": "google",
                "deep_think_llm": "gemini-pro",
                "quick_think_llm": "gemini-pro",
                "project_dir": "/tmp/test",
                "data_vendors": {},
                "tool_vendors": {},
                "max_debate_rounds": 2,
                "max_risk_discuss_rounds": 2,
                "max_recur_limit": 100,
            }

            TradingAgentsGraph(config=config)

            assert mock_google.called

    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_init_creates_memory_instances(self, mock_setup, mock_memory, mock_llm):
        """Test that all required memory instances are created."""
        config = {
            "llm_provider": "openai",
            "backend_url": "https://api.openai.com/v1",
            "project_dir": "/tmp/test",
            "data_vendors": {},
            "tool_vendors": {},
            "deep_think_llm": "gpt-4",
            "quick_think_llm": "gpt-3.5",
            "max_debate_rounds": 2,
            "max_risk_discuss_rounds": 2,
            "max_recur_limit": 100,
        }

        TradingAgentsGraph(config=config)

        # Should create 5 memory instances
        assert mock_memory.call_count == 5

        # The memory name is read from the first positional argument of
        # each constructor call recorded on the mock.
        memory_names = [call[0][0] for call in mock_memory.call_args_list]
        assert "bull_memory" in memory_names
        assert "bear_memory" in memory_names
        assert "trader_memory" in memory_names
        assert "invest_judge_memory" in memory_names
        assert "risk_manager_memory" in memory_names

    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_init_creates_tool_nodes(self, mock_setup, mock_memory, mock_llm):
        """Test that tool nodes are created for analysts."""
        graph = TradingAgentsGraph()

        assert hasattr(graph, 'tool_nodes')
        assert isinstance(graph.tool_nodes, dict)
        assert "market" in graph.tool_nodes
        assert "social" in graph.tool_nodes
        assert "news" in graph.tool_nodes
        assert "fundamentals" in graph.tool_nodes

    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_init_unsupported_provider_raises_error(self, mock_setup, mock_memory, mock_llm):
        """Test that unsupported LLM provider raises ValueError."""
        config = {
            "llm_provider": "unsupported_provider",
            "project_dir": "/tmp/test",
            "data_vendors": {},
            "tool_vendors": {},
            "deep_think_llm": "model",
            "quick_think_llm": "model",
            "max_debate_rounds": 2,
            "max_risk_discuss_rounds": 2,
            "max_recur_limit": 100,
        }

        # Keep only the raising call inside the context manager.
        with pytest.raises(ValueError, match="Unsupported LLM provider"):
            TradingAgentsGraph(config=config)
class TestDiscoverTrending:
    """Test suite for discover_trending method.

    The pipeline helpers ``get_bulk_news``, ``extract_entities`` and
    ``calculate_trending_scores`` are patched where
    ``tradingagents.graph.trading_graph`` looks them up, alongside the
    LLM, memory and graph-setup collaborators. ``patch`` decorators are
    applied bottom-up, so the bottom-most decorator supplies the first
    mock argument after ``self``.
    """

    @patch('tradingagents.graph.trading_graph.get_bulk_news')
    @patch('tradingagents.graph.trading_graph.extract_entities')
    @patch('tradingagents.graph.trading_graph.calculate_trending_scores')
    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_discover_trending_basic(self, mock_setup, mock_memory, mock_llm,
                                     mock_score, mock_extract, mock_bulk_news):
        """Test basic discover_trending functionality."""
        # Setup mocks
        mock_article = Mock(spec=NewsArticle)
        mock_bulk_news.return_value = [mock_article]
        mock_extract.return_value = []
        mock_score.return_value = []

        graph = TradingAgentsGraph()
        request = DiscoveryRequest(lookback_period="24h")
        result = graph.discover_trending(request)

        assert isinstance(result, DiscoveryResult)
        assert result.status == DiscoveryStatus.COMPLETED

    @patch('tradingagents.graph.trading_graph.get_bulk_news')
    @patch('tradingagents.graph.trading_graph.extract_entities')
    @patch('tradingagents.graph.trading_graph.calculate_trending_scores')
    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_discover_trending_with_results(self, mock_setup, mock_memory, mock_llm,
                                            mock_score, mock_extract, mock_bulk_news):
        """Test discover_trending with actual trending stocks."""
        mock_article = Mock(spec=NewsArticle)
        mock_bulk_news.return_value = [mock_article]
        mock_extract.return_value = []

        mock_stock = TrendingStock(
            ticker="AAPL",
            company_name="Apple Inc.",
            score=85.5,
            mention_count=10,
            sentiment=0.75,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.PRODUCT_LAUNCH,
            news_summary="Apple announced new products",
            source_articles=[mock_article],
        )
        mock_score.return_value = [mock_stock]

        graph = TradingAgentsGraph()
        request = DiscoveryRequest(lookback_period="24h")
        result = graph.discover_trending(request)

        assert len(result.trending_stocks) == 1
        assert result.trending_stocks[0].ticker == "AAPL"

    @patch('tradingagents.graph.trading_graph.get_bulk_news')
    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_discover_trending_timeout(self, mock_setup, mock_memory, mock_llm, mock_bulk_news):
        """Test that discovery respects timeout."""
        import time
        from tradingagents.agents.discovery.exceptions import DiscoveryTimeoutError

        # Simulate a long-running fetch. The stand-in accepts any call
        # signature so that the test can only fail on timeout handling,
        # never on a TypeError caused by how get_bulk_news is invoked.
        # NOTE(review): the sleep is longer than the discovery timeout is
        # presumed to be; if the timeout mechanism cannot interrupt the
        # sleep, this test blocks for the full 200 s -- confirm against
        # the discover_trending implementation.
        mock_bulk_news.side_effect = lambda *args, **kwargs: time.sleep(200)

        graph = TradingAgentsGraph()
        request = DiscoveryRequest(lookback_period="24h")

        # Should raise DiscoveryTimeoutError
        with pytest.raises(DiscoveryTimeoutError):
            graph.discover_trending(request)

    @patch('tradingagents.graph.trading_graph.get_bulk_news')
    @patch('tradingagents.graph.trading_graph.extract_entities')
    @patch('tradingagents.graph.trading_graph.calculate_trending_scores')
    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_discover_trending_sector_filter(self, mock_setup, mock_memory, mock_llm,
                                             mock_score, mock_extract, mock_bulk_news):
        """Test discover_trending with sector filter."""
        mock_article = Mock(spec=NewsArticle)
        mock_bulk_news.return_value = [mock_article]
        mock_extract.return_value = []

        tech_stock = TrendingStock(
            ticker="AAPL",
            company_name="Apple",
            score=90.0,
            mention_count=10,
            sentiment=0.8,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.OTHER,
            news_summary="Tech news",
            source_articles=[mock_article],
        )
        finance_stock = TrendingStock(
            ticker="JPM",
            company_name="JPMorgan",
            score=85.0,
            mention_count=8,
            sentiment=0.7,
            sector=Sector.FINANCE,
            event_type=EventCategory.OTHER,
            news_summary="Finance news",
            source_articles=[mock_article],
        )
        # The scorer yields one stock per sector; the request below asks
        # for technology only.
        mock_score.return_value = [tech_stock, finance_stock]

        graph = TradingAgentsGraph()
        request = DiscoveryRequest(
            lookback_period="24h",
            sector_filter=[Sector.TECHNOLOGY],
        )
        result = graph.discover_trending(request)

        # Should only return technology stocks
        assert len(result.trending_stocks) == 1
        assert result.trending_stocks[0].sector == Sector.TECHNOLOGY

    @patch('tradingagents.graph.trading_graph.get_bulk_news')
    @patch('tradingagents.graph.trading_graph.extract_entities')
    @patch('tradingagents.graph.trading_graph.calculate_trending_scores')
    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_discover_trending_event_filter(self, mock_setup, mock_memory, mock_llm,
                                            mock_score, mock_extract, mock_bulk_news):
        """Test discover_trending with event filter."""
        mock_article = Mock(spec=NewsArticle)
        mock_bulk_news.return_value = [mock_article]
        mock_extract.return_value = []

        earnings_stock = TrendingStock(
            ticker="AAPL",
            company_name="Apple",
            score=90.0,
            mention_count=10,
            sentiment=0.8,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.EARNINGS,
            news_summary="Earnings report",
            source_articles=[mock_article],
        )
        merger_stock = TrendingStock(
            ticker="MSFT",
            company_name="Microsoft",
            score=85.0,
            mention_count=8,
            sentiment=0.7,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.MERGER_ACQUISITION,
            news_summary="Merger news",
            source_articles=[mock_article],
        )
        # Same sector, different event types; the request below asks for
        # earnings events only.
        mock_score.return_value = [earnings_stock, merger_stock]

        graph = TradingAgentsGraph()
        request = DiscoveryRequest(
            lookback_period="24h",
            event_filter=[EventCategory.EARNINGS],
        )
        result = graph.discover_trending(request)

        # Should only return earnings events
        assert len(result.trending_stocks) == 1
        assert result.trending_stocks[0].event_type == EventCategory.EARNINGS

    @patch('tradingagents.graph.trading_graph.get_bulk_news')
    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_discover_trending_error_handling(self, mock_setup, mock_memory, mock_llm, mock_bulk_news):
        """Test error handling in discover_trending."""
        # The news fetch raises; the test expects a FAILED result rather
        # than a propagated exception.
        mock_bulk_news.side_effect = Exception("API Error")

        graph = TradingAgentsGraph()
        request = DiscoveryRequest(lookback_period="24h")
        result = graph.discover_trending(request)

        assert result.status == DiscoveryStatus.FAILED
        assert result.error_message is not None

    @patch('tradingagents.graph.trading_graph.get_bulk_news')
    @patch('tradingagents.graph.trading_graph.extract_entities')
    @patch('tradingagents.graph.trading_graph.calculate_trending_scores')
    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_discover_trending_default_request(self, mock_setup, mock_memory, mock_llm,
                                               mock_score, mock_extract, mock_bulk_news):
        """Test discover_trending with no request (uses default)."""
        mock_bulk_news.return_value = []
        mock_extract.return_value = []
        mock_score.return_value = []

        graph = TradingAgentsGraph()
        result = graph.discover_trending()  # No request parameter

        assert isinstance(result, DiscoveryResult)
        assert result.request.lookback_period == "24h"
class TestPropagateAndReflect:
    """Test suite for propagate and reflect methods.

    ``patch`` decorators are applied bottom-up, so the bottom-most
    decorator supplies the first mock argument after ``self``.
    """

    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_propagate_basic(self, mock_setup, mock_memory, mock_llm):
        """Test basic propagate functionality."""
        mock_graph = Mock()
        mock_graph.invoke.return_value = {
            "company_of_interest": "AAPL",
            "trade_date": "2024-01-15",
            "final_trade_decision": "BUY 100 shares",
            "messages": [],
            "investment_debate_state": {
                "bull_history": "",
                "bear_history": "",
                "history": "",
                "current_response": "",
                "judge_decision": "",
                "count": 0,
            },
            "risk_debate_state": {
                "risky_history": "",
                "safe_history": "",
                "neutral_history": "",
                "history": "",
                "judge_decision": "",
                "count": 0,
            },
            "market_report": "",
            "sentiment_report": "",
            "news_report": "",
            "fundamentals_report": "",
            "trader_investment_plan": "",
            "investment_plan": "",
        }
        mock_setup.return_value.setup_graph.return_value = mock_graph

        graph = TradingAgentsGraph(debug=False)
        # Assign the stub explicitly as well, so the test does not depend
        # on how the constructor wires up the compiled graph.
        graph.graph = mock_graph

        final_state, _ = graph.propagate("AAPL", "2024-01-15")

        assert final_state["company_of_interest"] == "AAPL"
        assert graph.ticker == "AAPL"

    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    @patch('tradingagents.graph.trading_graph.Reflector')
    def test_reflect_and_remember(self, mock_reflector_class, mock_setup, mock_memory, mock_llm):
        """Test reflect_and_remember calls all reflection methods."""
        mock_reflector = Mock()
        mock_reflector_class.return_value = mock_reflector

        graph = TradingAgentsGraph()
        graph.curr_state = {"test": "state"}
        returns_losses = {"returns": 0.05, "losses": 0.02}

        graph.reflect_and_remember(returns_losses)

        # Should call reflection for all agent types. These assertions
        # are deliberately unconditional: a reflection hook that is never
        # invoked must fail the test.
        assert mock_reflector.reflect_bull_researcher.called
        assert mock_reflector.reflect_bear_researcher.called
        assert mock_reflector.reflect_trader.called
        assert mock_reflector.reflect_invest_judge.called
        assert mock_reflector.reflect_risk_manager.called
class TestAnalyzeTrending:
    """Tests covering the ``analyze_trending`` entry point of the graph."""

    @staticmethod
    def _stub_final_state(ticker, trade_date, decision):
        """Build the state dict returned by the stubbed graph's ``invoke``."""
        debate_state = {
            "bull_history": "",
            "bear_history": "",
            "history": "",
            "current_response": "",
            "judge_decision": "",
            "count": 0,
        }
        risk_state = {
            "risky_history": "",
            "safe_history": "",
            "neutral_history": "",
            "history": "",
            "judge_decision": "",
            "count": 0,
        }
        return {
            "company_of_interest": ticker,
            "trade_date": trade_date,
            "final_trade_decision": decision,
            "messages": [],
            "investment_debate_state": debate_state,
            "risk_debate_state": risk_state,
            "market_report": "",
            "sentiment_report": "",
            "news_report": "",
            "fundamentals_report": "",
            "trader_investment_plan": "",
            "investment_plan": "",
        }

    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_analyze_trending_basic(self, mock_setup, mock_memory, mock_llm):
        """Analyzing a trending stock yields a state for that stock's ticker."""
        article = Mock(spec=NewsArticle)
        stock = TrendingStock(
            ticker="AAPL",
            company_name="Apple Inc.",
            score=90.0,
            mention_count=10,
            sentiment=0.8,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.EARNINGS,
            news_summary="Strong earnings",
            source_articles=[article],
        )

        # Stub graph whose invoke() hands back a canned final state.
        compiled_graph = Mock()
        compiled_graph.invoke.return_value = self._stub_final_state(
            "AAPL", str(date.today()), "BUY"
        )
        mock_setup.return_value.setup_graph.return_value = compiled_graph

        trading_graph = TradingAgentsGraph()
        trading_graph.graph = compiled_graph

        state, _decision = trading_graph.analyze_trending(stock)

        assert state["company_of_interest"] == "AAPL"

    @patch('tradingagents.graph.trading_graph.ChatOpenAI')
    @patch('tradingagents.graph.trading_graph.FinancialSituationMemory')
    @patch('tradingagents.graph.trading_graph.GraphSetup')
    def test_analyze_trending_with_custom_date(self, mock_setup, mock_memory, mock_llm):
        """An explicit ``trade_date`` shows up in the returned state."""
        custom_date = date(2024, 3, 15)
        article = Mock(spec=NewsArticle)
        stock = TrendingStock(
            ticker="TSLA",
            company_name="Tesla",
            score=85.0,
            mention_count=8,
            sentiment=0.7,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.PRODUCT_LAUNCH,
            news_summary="New product launch",
            source_articles=[article],
        )

        # Stub graph whose invoke() hands back a canned final state.
        compiled_graph = Mock()
        compiled_graph.invoke.return_value = self._stub_final_state(
            "TSLA", str(custom_date), "HOLD"
        )
        mock_setup.return_value.setup_graph.return_value = compiled_graph

        trading_graph = TradingAgentsGraph()
        trading_graph.graph = compiled_graph

        state, _decision = trading_graph.analyze_trending(
            stock, trade_date=custom_date
        )

        assert state["trade_date"] == str(custom_date)