diff --git a/tests/unit/agents/test_analyst_agents.py b/tests/unit/agents/test_analyst_agents.py new file mode 100644 index 00000000..b576a97b --- /dev/null +++ b/tests/unit/agents/test_analyst_agents.py @@ -0,0 +1,73 @@ +from unittest.mock import MagicMock +import pytest +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage +from langchain_core.runnables import Runnable +from tradingagents.agents.analysts.fundamentals_analyst import create_fundamentals_analyst +from tradingagents.agents.analysts.market_analyst import create_market_analyst +from tradingagents.agents.analysts.social_media_analyst import create_social_media_analyst +from tradingagents.agents.analysts.news_analyst import create_news_analyst + +class MockRunnable(Runnable): + def __init__(self, invoke_responses): + self.invoke_responses = invoke_responses + self.call_count = 0 + + def invoke(self, input, config=None, **kwargs): + response = self.invoke_responses[self.call_count] + self.call_count += 1 + return response + +class MockLLM(Runnable): + def __init__(self, invoke_responses): + self.runnable = MockRunnable(invoke_responses) + self.tools_bound = None + + def invoke(self, input, config=None, **kwargs): + return self.runnable.invoke(input, config=config, **kwargs) + + def bind_tools(self, tools): + self.tools_bound = tools + return self.runnable + +@pytest.fixture +def mock_state(): + return { + "messages": [HumanMessage(content="Analyze AAPL.")], + "trade_date": "2024-05-15", + "company_of_interest": "AAPL", + } + +@pytest.fixture +def mock_llm_with_tool_call(): + # 1. First call: The LLM decides to use a tool + tool_call_msg = AIMessage( + content="", + tool_calls=[ + {"name": "mock_tool", "args": {"query": "test"}, "id": "call_123"} + ] + ) + # 2. Second call: The LLM receives the tool output and writes the report + final_report_msg = AIMessage( + content="This is the final report after running the tool." + ) + return MockLLM([tool_call_msg, final_report_msg]) + +def test_fundamentals_analyst_tool_loop(mock_state, mock_llm_with_tool_call): + node = create_fundamentals_analyst(mock_llm_with_tool_call) + result = node(mock_state) + assert "This is the final report after running the tool." in result["fundamentals_report"] + +def test_market_analyst_tool_loop(mock_state, mock_llm_with_tool_call): + node = create_market_analyst(mock_llm_with_tool_call) + result = node(mock_state) + assert "This is the final report after running the tool." in result["market_report"] + +def test_social_media_analyst_tool_loop(mock_state, mock_llm_with_tool_call): + node = create_social_media_analyst(mock_llm_with_tool_call) + result = node(mock_state) + assert "This is the final report after running the tool." in result["sentiment_report"] + +def test_news_analyst_tool_loop(mock_state, mock_llm_with_tool_call): + node = create_news_analyst(mock_llm_with_tool_call) + result = node(mock_state) + assert "This is the final report after running the tool." in result["news_report"] diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index 6b63b1b4..1a5f8aef 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -11,6 +11,8 @@ from tradingagents.agents.utils.fundamental_data_tools import ( get_sector_relative, ) from tradingagents.agents.utils.news_data_tools import get_insider_transactions +from tradingagents.agents.utils.tool_runner import run_tool_loop +from tradingagents.agents.utils.agent_utils import build_instrument_context from tradingagents.dataflows.config import get_config @@ -66,12 +68,9 @@ def create_fundamentals_analyst(llm): chain = prompt | llm.bind_tools(tools) - result = chain.invoke(state["messages"]) + result = run_tool_loop(chain, state["messages"], tools) - report = "" - - if len(result.tool_calls) == 0: - report = result.content + report = result.content or "" return { "messages": [result], diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index e5a9982d..78df7f7d 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -3,6 +3,8 @@ import time from tradingagents.agents.utils.core_stock_tools import get_stock_data from tradingagents.agents.utils.technical_indicators_tools import get_indicators from tradingagents.agents.utils.fundamental_data_tools import get_macro_regime +from tradingagents.agents.utils.tool_runner import run_tool_loop +from tradingagents.agents.utils.agent_utils import build_instrument_context from tradingagents.dataflows.config import get_config @@ -73,16 +75,14 @@ Volume-Based Indicators: chain = prompt | llm.bind_tools(tools) - result = chain.invoke(state["messages"]) + result = run_tool_loop(chain, state["messages"], tools) - report = "" + report = result.content or "" macro_regime_report = "" - if len(result.tool_calls) == 0: - report = result.content - # Extract macro regime section if present - if "Macro Regime Classification" in report or "RISK-ON" in report.upper() or "RISK-OFF" in report.upper() or "TRANSITION" in report.upper(): - macro_regime_report = report + # Extract macro regime section if present + if report and ("Macro Regime Classification" in report or "RISK-ON" in report.upper() or "RISK-OFF" in report.upper() or "TRANSITION" in report.upper()): + macro_regime_report = report return { "messages": [result], diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index 7c29b7b4..ec50880c 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -1,6 +1,8 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder import json from tradingagents.agents.utils.news_data_tools import get_news, get_global_news +from tradingagents.agents.utils.tool_runner import run_tool_loop +from tradingagents.agents.utils.agent_utils import build_instrument_context from tradingagents.dataflows.config import get_config @@ -42,12 +44,9 @@ def create_news_analyst(llm): prompt = prompt.partial(instrument_context=instrument_context) chain = prompt | llm.bind_tools(tools) - result = chain.invoke(state["messages"]) + result = run_tool_loop(chain, state["messages"], tools) - report = "" - - if len(result.tool_calls) == 0: - report = result.content + report = result.content or "" return { "messages": [result], diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index 9c34a5f1..ac7c0afd 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -2,6 +2,8 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder import time import json from tradingagents.agents.utils.news_data_tools import get_news +from tradingagents.agents.utils.tool_runner import run_tool_loop +from tradingagents.agents.utils.agent_utils import build_instrument_context from tradingagents.dataflows.config import get_config @@ -43,12 +45,9 @@ def create_social_media_analyst(llm): chain = prompt | llm.bind_tools(tools) - result = chain.invoke(state["messages"]) + result = run_tool_loop(chain, state["messages"], tools) - report = "" - - if len(result.tool_calls) == 0: - report = result.content + report = result.content or "" return { "messages": [result],