625 lines
22 KiB
Python
625 lines
22 KiB
Python
from datetime import date
|
|
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
|
|
from tradingagents.agents.discovery import (
|
|
DiscoveryRequest,
|
|
DiscoveryResult,
|
|
DiscoveryStatus,
|
|
EventCategory,
|
|
NewsArticle,
|
|
Sector,
|
|
TrendingStock,
|
|
)
|
|
from tradingagents.graph.trading_graph import (
|
|
TradingAgentsGraph,
|
|
)
|
|
|
|
|
|
class TestTradingAgentsGraphInit:
    """Test suite for TradingAgentsGraph initialization.

    Every test replaces ``ChatOpenAI``, ``FinancialSituationMemory`` and
    ``GraphSetup`` in ``tradingagents.graph.trading_graph`` with mocks.
    ``patch`` decorators are applied bottom-up, so the injected arguments
    arrive in the order ``mock_setup, mock_memory, mock_llm``.
    """

    @staticmethod
    def _make_config(**overrides):
        """Return a config dict with the settings shared by these tests.

        Args:
            **overrides: Keys to add to, or replace in, the shared settings
                (for example ``llm_provider`` or ``backend_url``).

        Returns:
            A fresh dict, so each test gets its own mutable copy.
        """
        config = {
            "project_dir": "/tmp/test",
            "data_vendors": {},
            "tool_vendors": {},
            "max_debate_rounds": 2,
            "max_risk_discuss_rounds": 2,
            "max_recur_limit": 100,
        }
        config.update(overrides)
        return config

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_init_with_default_config(self, mock_setup, mock_memory, mock_llm):
        """Test initialization with default configuration."""
        graph = TradingAgentsGraph(debug=False)

        # Identity check: the flag must be the boolean False, not merely falsy.
        assert graph.debug is False
        assert graph.config is not None
        assert "llm_provider" in graph.config

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_init_with_custom_config(self, mock_setup, mock_memory, mock_llm):
        """Test initialization with custom configuration."""
        custom_config = self._make_config(
            llm_provider="openai",
            deep_think_llm="gpt-4",
            quick_think_llm="gpt-3.5-turbo",
            backend_url="https://api.openai.com/v1",
            max_debate_rounds=3,
        )

        graph = TradingAgentsGraph(debug=True, config=custom_config)

        assert graph.config["llm_provider"] == "openai"
        assert graph.config["max_debate_rounds"] == 3

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_init_with_anthropic_provider(self, mock_setup, mock_memory, mock_llm):
        """Test initialization with Anthropic provider."""
        with patch("tradingagents.graph.trading_graph.ChatAnthropic") as mock_anthropic:
            config = self._make_config(
                llm_provider="anthropic",
                deep_think_llm="claude-3-opus",
                quick_think_llm="claude-3-haiku",
                backend_url="https://api.anthropic.com",
            )

            TradingAgentsGraph(config=config)

            assert mock_anthropic.called

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_init_with_google_provider(self, mock_setup, mock_memory, mock_llm):
        """Test initialization with Google provider."""
        with patch(
            "tradingagents.graph.trading_graph.ChatGoogleGenerativeAI"
        ) as mock_google:
            config = self._make_config(
                llm_provider="google",
                deep_think_llm="gemini-pro",
                quick_think_llm="gemini-pro",
            )

            TradingAgentsGraph(config=config)

            assert mock_google.called

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_init_creates_memory_instances(self, mock_setup, mock_memory, mock_llm):
        """Test that all required memory instances are created."""
        config = self._make_config(
            llm_provider="openai",
            backend_url="https://api.openai.com/v1",
            deep_think_llm="gpt-4",
            quick_think_llm="gpt-3.5",
        )

        TradingAgentsGraph(config=config)

        # Should create 5 memory instances
        assert mock_memory.call_count == 5

        # Check that memories were created with correct names; the name is
        # read from the first positional argument of each constructor call.
        memory_names = [call.args[0] for call in mock_memory.call_args_list]
        for expected_name in (
            "bull_memory",
            "bear_memory",
            "trader_memory",
            "invest_judge_memory",
            "risk_manager_memory",
        ):
            assert expected_name in memory_names

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_init_creates_tool_nodes(self, mock_setup, mock_memory, mock_llm):
        """Test that tool nodes are created for analysts."""
        graph = TradingAgentsGraph()

        assert hasattr(graph, "tool_nodes")
        assert isinstance(graph.tool_nodes, dict)
        for analyst in ("market", "social", "news", "fundamentals"):
            assert analyst in graph.tool_nodes

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_init_unsupported_provider_raises_error(
        self, mock_setup, mock_memory, mock_llm
    ):
        """Test that an unsupported LLM provider is rejected.

        The constructor must raise an exception whose message matches
        ``"Invalid LLM provider"``.
        """
        config = self._make_config(
            llm_provider="unsupported_provider",
            deep_think_llm="model",
            quick_think_llm="model",
        )

        # ``Exception`` already covers ``ValueError``, so listing both was
        # redundant; the ``match`` pattern is what pins the failure down.
        # NOTE(review): narrow this to ValueError once the raised type is
        # confirmed against TradingAgentsGraph.
        with pytest.raises(Exception, match="Invalid LLM provider"):
            TradingAgentsGraph(config=config)
|
|
|
|
|
|
class TestDiscoverTrending:
    """Test suite for discover_trending method.

    The collaborators ``get_bulk_news``, ``extract_entities`` and
    ``calculate_trending_scores`` are patched on
    ``tradingagents.graph.trading_graph`` alongside the mocks needed to
    construct the graph. ``patch`` decorators are applied bottom-up, which
    fixes the order of the injected mock arguments.
    """

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_discover_trending_basic(
        self,
        mock_setup,
        mock_memory,
        mock_llm,
        mock_score,
        mock_extract,
        mock_bulk_news,
    ):
        """Test basic discover_trending functionality."""
        # Setup mocks
        mock_article = Mock(spec=NewsArticle)
        mock_bulk_news.return_value = [mock_article]
        mock_extract.return_value = []
        mock_score.return_value = []

        graph = TradingAgentsGraph()
        request = DiscoveryRequest(lookback_period="24h")

        result = graph.discover_trending(request)

        assert isinstance(result, DiscoveryResult)
        assert result.status == DiscoveryStatus.COMPLETED

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_discover_trending_with_results(
        self,
        mock_setup,
        mock_memory,
        mock_llm,
        mock_score,
        mock_extract,
        mock_bulk_news,
    ):
        """Test discover_trending with actual trending stocks."""
        mock_article = Mock(spec=NewsArticle)
        mock_bulk_news.return_value = [mock_article]
        mock_extract.return_value = []

        mock_stock = TrendingStock(
            ticker="AAPL",
            company_name="Apple Inc.",
            score=85.5,
            mention_count=10,
            sentiment=0.75,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.PRODUCT_LAUNCH,
            news_summary="Apple announced new products",
            source_articles=[mock_article],
        )
        mock_score.return_value = [mock_stock]

        graph = TradingAgentsGraph()
        request = DiscoveryRequest(lookback_period="24h")

        result = graph.discover_trending(request)

        assert len(result.trending_stocks) == 1
        assert result.trending_stocks[0].ticker == "AAPL"

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_discover_trending_timeout(
        self, mock_setup, mock_memory, mock_llm, mock_bulk_news
    ):
        """Test that discovery respects timeout."""
        import time

        from tradingagents.agents.discovery.exceptions import DiscoveryTimeoutError

        # Simulate a long-running operation. The stub accepts any call
        # signature so that a change in how get_bulk_news is invoked cannot
        # turn the simulated delay into an unrelated TypeError.
        # NOTE(review): the stub blocks for 200 seconds; how long this test
        # really takes depends on the timeout enforced by discover_trending,
        # which is not visible from this file - confirm it is short.
        mock_bulk_news.side_effect = lambda *args, **kwargs: time.sleep(200)

        graph = TradingAgentsGraph()
        request = DiscoveryRequest(lookback_period="24h")

        # Should raise DiscoveryTimeoutError
        with pytest.raises(DiscoveryTimeoutError):
            graph.discover_trending(request)

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_discover_trending_sector_filter(
        self,
        mock_setup,
        mock_memory,
        mock_llm,
        mock_score,
        mock_extract,
        mock_bulk_news,
    ):
        """Test discover_trending with sector filter."""
        mock_article = Mock(spec=NewsArticle)
        mock_bulk_news.return_value = [mock_article]
        mock_extract.return_value = []

        tech_stock = TrendingStock(
            ticker="AAPL",
            company_name="Apple",
            score=90.0,
            mention_count=10,
            sentiment=0.8,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.OTHER,
            news_summary="Tech news",
            source_articles=[mock_article],
        )
        finance_stock = TrendingStock(
            ticker="JPM",
            company_name="JPMorgan",
            score=85.0,
            mention_count=8,
            sentiment=0.7,
            sector=Sector.FINANCE,
            event_type=EventCategory.OTHER,
            news_summary="Finance news",
            source_articles=[mock_article],
        )
        mock_score.return_value = [tech_stock, finance_stock]

        graph = TradingAgentsGraph()
        request = DiscoveryRequest(
            lookback_period="24h",
            sector_filter=[Sector.TECHNOLOGY],
        )

        result = graph.discover_trending(request)

        # Should only return technology stocks
        assert len(result.trending_stocks) == 1
        assert result.trending_stocks[0].sector == Sector.TECHNOLOGY

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_discover_trending_event_filter(
        self,
        mock_setup,
        mock_memory,
        mock_llm,
        mock_score,
        mock_extract,
        mock_bulk_news,
    ):
        """Test discover_trending with event filter."""
        mock_article = Mock(spec=NewsArticle)
        mock_bulk_news.return_value = [mock_article]
        mock_extract.return_value = []

        earnings_stock = TrendingStock(
            ticker="AAPL",
            company_name="Apple",
            score=90.0,
            mention_count=10,
            sentiment=0.8,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.EARNINGS,
            news_summary="Earnings report",
            source_articles=[mock_article],
        )
        merger_stock = TrendingStock(
            ticker="MSFT",
            company_name="Microsoft",
            score=85.0,
            mention_count=8,
            sentiment=0.7,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.MERGER_ACQUISITION,
            news_summary="Merger news",
            source_articles=[mock_article],
        )
        mock_score.return_value = [earnings_stock, merger_stock]

        graph = TradingAgentsGraph()
        request = DiscoveryRequest(
            lookback_period="24h",
            event_filter=[EventCategory.EARNINGS],
        )

        result = graph.discover_trending(request)

        # Should only return earnings events
        assert len(result.trending_stocks) == 1
        assert result.trending_stocks[0].event_type == EventCategory.EARNINGS

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_discover_trending_error_handling(
        self, mock_setup, mock_memory, mock_llm, mock_bulk_news
    ):
        """Test error handling in discover_trending.

        A ``RuntimeError`` from the news fetch must be reported through the
        result (``FAILED`` status plus an error message) rather than raised.
        """
        mock_bulk_news.side_effect = RuntimeError("API Error")

        graph = TradingAgentsGraph()
        request = DiscoveryRequest(lookback_period="24h")

        result = graph.discover_trending(request)

        assert result.status == DiscoveryStatus.FAILED
        assert result.error_message is not None

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_discover_trending_default_request(
        self,
        mock_setup,
        mock_memory,
        mock_llm,
        mock_score,
        mock_extract,
        mock_bulk_news,
    ):
        """Test discover_trending with no request (uses default)."""
        mock_bulk_news.return_value = []
        mock_extract.return_value = []
        mock_score.return_value = []

        graph = TradingAgentsGraph()
        result = graph.discover_trending()  # No request parameter

        assert isinstance(result, DiscoveryResult)
        assert result.request.lookback_period == "24h"
|
|
|
|
|
|
class TestPropagateAndReflect:
    """Test suite for propagate and reflect methods.

    ``ChatOpenAI``, ``FinancialSituationMemory`` and ``GraphSetup`` are
    patched on ``tradingagents.graph.trading_graph``; ``patch`` decorators
    are applied bottom-up, which fixes the order of the injected mocks.
    """

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_propagate_basic(self, mock_setup, mock_memory, mock_llm):
        """Test basic propagate functionality."""
        mock_graph = Mock()
        mock_graph.invoke.return_value = {
            "company_of_interest": "AAPL",
            "trade_date": "2024-01-15",
            "final_trade_decision": "BUY 100 shares",
            "messages": [],
            "investment_debate_state": {
                "bull_history": "",
                "bear_history": "",
                "history": "",
                "current_response": "",
                "judge_decision": "",
                "count": 0,
            },
            "risk_debate_state": {
                "risky_history": "",
                "safe_history": "",
                "neutral_history": "",
                "history": "",
                "judge_decision": "",
                "count": 0,
            },
            "market_report": "",
            "sentiment_report": "",
            "news_report": "",
            "fundamentals_report": "",
            "trader_investment_plan": "",
            "investment_plan": "",
        }

        mock_setup.return_value.setup_graph.return_value = mock_graph

        graph = TradingAgentsGraph(debug=False)
        # Assign the compiled graph explicitly as well, so the test does not
        # depend on how the constructor wires up the GraphSetup result.
        graph.graph = mock_graph

        # Unpacking also checks that propagate returns exactly two values.
        final_state, _decision = graph.propagate("AAPL", "2024-01-15")

        assert final_state["company_of_interest"] == "AAPL"
        assert graph.ticker == "AAPL"

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    @patch("tradingagents.graph.trading_graph.Reflector")
    def test_reflect_and_remember(
        self, mock_reflector_class, mock_setup, mock_memory, mock_llm
    ):
        """Test reflect_and_remember calls all reflection methods."""
        mock_reflector = Mock()
        mock_reflector_class.return_value = mock_reflector

        graph = TradingAgentsGraph()
        graph.curr_state = {"test": "state"}

        returns_losses = {"returns": 0.05, "losses": 0.02}
        graph.reflect_and_remember(returns_losses)

        # Should call reflection for all agent types. These used to be
        # written as ``<mock>.called or True``, which can never fail; each
        # call is now actually verified.
        for method_name in (
            "reflect_bull_researcher",
            "reflect_bear_researcher",
            "reflect_trader",
            "reflect_invest_judge",
            "reflect_risk_manager",
        ):
            assert getattr(
                mock_reflector, method_name
            ).called, f"{method_name} was not called"
|
|
|
|
|
|
class TestAnalyzeTrending:
    """Test suite for analyze_trending method.

    The compiled graph is replaced by a ``Mock`` whose ``invoke`` returns a
    canned final state, so only the plumbing of ``analyze_trending`` runs.
    """

    @staticmethod
    def _canned_state(ticker, trade_date, decision):
        """Build the final-state dict returned by the mocked ``invoke``.

        Args:
            ticker: Value stored under ``company_of_interest``.
            trade_date: Value stored under ``trade_date``.
            decision: Value stored under ``final_trade_decision``.

        Returns:
            A dict with empty debate histories, zeroed counters and empty
            report fields around the three supplied values.
        """
        investment_debate = dict.fromkeys(
            (
                "bull_history",
                "bear_history",
                "history",
                "current_response",
                "judge_decision",
            ),
            "",
        )
        investment_debate["count"] = 0

        risk_debate = dict.fromkeys(
            (
                "risky_history",
                "safe_history",
                "neutral_history",
                "history",
                "judge_decision",
            ),
            "",
        )
        risk_debate["count"] = 0

        state = {
            "company_of_interest": ticker,
            "trade_date": trade_date,
            "final_trade_decision": decision,
            "messages": [],
            "investment_debate_state": investment_debate,
            "risk_debate_state": risk_debate,
        }
        for report_key in (
            "market_report",
            "sentiment_report",
            "news_report",
            "fundamentals_report",
            "trader_investment_plan",
            "investment_plan",
        ):
            state[report_key] = ""
        return state

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_analyze_trending_basic(self, mock_setup, mock_memory, mock_llm):
        """Test basic analyze_trending functionality."""
        article = Mock(spec=NewsArticle)
        stock = TrendingStock(
            ticker="AAPL",
            company_name="Apple Inc.",
            score=90.0,
            mention_count=10,
            sentiment=0.8,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.EARNINGS,
            news_summary="Strong earnings",
            source_articles=[article],
        )

        compiled = Mock()
        compiled.invoke.return_value = self._canned_state(
            "AAPL", str(date.today()), "BUY"
        )
        mock_setup.return_value.setup_graph.return_value = compiled

        graph = TradingAgentsGraph()
        graph.graph = compiled

        final_state, _decision = graph.analyze_trending(stock)

        assert final_state["company_of_interest"] == "AAPL"

    @patch("tradingagents.graph.trading_graph.ChatOpenAI")
    @patch("tradingagents.graph.trading_graph.FinancialSituationMemory")
    @patch("tradingagents.graph.trading_graph.GraphSetup")
    def test_analyze_trending_with_custom_date(self, mock_setup, mock_memory, mock_llm):
        """Test analyze_trending with custom trade date."""
        article = Mock(spec=NewsArticle)
        stock = TrendingStock(
            ticker="TSLA",
            company_name="Tesla",
            score=85.0,
            mention_count=8,
            sentiment=0.7,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.PRODUCT_LAUNCH,
            news_summary="New product launch",
            source_articles=[article],
        )
        custom_date = date(2024, 3, 15)

        compiled = Mock()
        compiled.invoke.return_value = self._canned_state(
            "TSLA", str(custom_date), "HOLD"
        )
        mock_setup.return_value.setup_graph.return_value = compiled

        graph = TradingAgentsGraph()
        graph.graph = compiled

        final_state, _decision = graph.analyze_trending(
            stock, trade_date=custom_date
        )

        assert final_state["trade_date"] == str(custom_date)
|