Merge pull request #91 from aguzererler/fix-analyst-tool-loop-18174774985345323969

fix: use run_tool_loop instead of invoke in analyst agents
This commit is contained in:
ahmet guzererler 2026-03-23 18:36:33 +01:00 committed by GitHub
commit 41681e0f9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 92 additions and 22 deletions

View File

@@ -0,0 +1,73 @@
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.runnables import Runnable
from tradingagents.agents.analysts.fundamentals_analyst import create_fundamentals_analyst
from tradingagents.agents.analysts.market_analyst import create_market_analyst
from tradingagents.agents.analysts.social_media_analyst import create_social_media_analyst
from tradingagents.agents.analysts.news_analyst import create_news_analyst
class MockRunnable(Runnable):
    """Scripted Runnable that replays a fixed sequence of responses, one per call."""

    def __init__(self, invoke_responses):
        # Responses are handed back in order on successive invoke() calls.
        self.invoke_responses = invoke_responses
        self.call_count = 0

    def invoke(self, input, config=None, **kwargs):
        idx = self.call_count
        self.call_count = idx + 1
        return self.invoke_responses[idx]
class MockLLM(Runnable):
    """Minimal chat-model stand-in exposing invoke() and bind_tools()."""

    def __init__(self, invoke_responses):
        # Delegate all scripted responses to an inner MockRunnable.
        self.runnable = MockRunnable(invoke_responses)
        self.tools_bound = None  # records whatever bind_tools() received

    def invoke(self, input, config=None, **kwargs):
        # Unbound invoke goes straight through to the scripted runnable.
        return self.runnable.invoke(input, config=config, **kwargs)

    def bind_tools(self, tools):
        # Remember the bound tools, then hand back the scripted runnable.
        self.tools_bound = tools
        return self.runnable
@pytest.fixture
def mock_state():
    """Baseline agent state: one user message asking for an AAPL analysis."""
    state = {
        "messages": [HumanMessage(content="Analyze AAPL.")],
        "trade_date": "2024-05-15",
        "company_of_interest": "AAPL",
    }
    return state
@pytest.fixture
def mock_llm_with_tool_call():
    """LLM mock that first requests a tool call, then emits the final report."""
    # Call 1: the model decides it needs a tool.
    wants_tool = AIMessage(
        content="",
        tool_calls=[
            {"name": "mock_tool", "args": {"query": "test"}, "id": "call_123"}
        ],
    )
    # Call 2: having seen the tool output, the model writes the report.
    writes_report = AIMessage(
        content="This is the final report after running the tool."
    )
    return MockLLM([wants_tool, writes_report])
def test_fundamentals_analyst_tool_loop(mock_state, mock_llm_with_tool_call):
    """Fundamentals analyst must follow the tool loop through to a final report."""
    analyst_node = create_fundamentals_analyst(mock_llm_with_tool_call)
    output = analyst_node(mock_state)
    report = output["fundamentals_report"]
    assert "This is the final report after running the tool." in report
def test_market_analyst_tool_loop(mock_state, mock_llm_with_tool_call):
    """Market analyst must follow the tool loop through to a final report."""
    analyst_node = create_market_analyst(mock_llm_with_tool_call)
    output = analyst_node(mock_state)
    report = output["market_report"]
    assert "This is the final report after running the tool." in report
def test_social_media_analyst_tool_loop(mock_state, mock_llm_with_tool_call):
    """Social media analyst must follow the tool loop through to a final report."""
    analyst_node = create_social_media_analyst(mock_llm_with_tool_call)
    output = analyst_node(mock_state)
    report = output["sentiment_report"]
    assert "This is the final report after running the tool." in report
def test_news_analyst_tool_loop(mock_state, mock_llm_with_tool_call):
    """News analyst must follow the tool loop through to a final report."""
    analyst_node = create_news_analyst(mock_llm_with_tool_call)
    output = analyst_node(mock_state)
    report = output["news_report"]
    assert "This is the final report after running the tool." in report

View File

@@ -11,6 +11,8 @@ from tradingagents.agents.utils.fundamental_data_tools import (
get_sector_relative, get_sector_relative,
) )
from tradingagents.agents.utils.news_data_tools import get_insider_transactions from tradingagents.agents.utils.news_data_tools import get_insider_transactions
from tradingagents.agents.utils.tool_runner import run_tool_loop
from tradingagents.agents.utils.agent_utils import build_instrument_context
from tradingagents.dataflows.config import get_config from tradingagents.dataflows.config import get_config
@@ -66,12 +68,9 @@ def create_fundamentals_analyst(llm):
chain = prompt | llm.bind_tools(tools) chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"]) result = run_tool_loop(chain, state["messages"], tools)
report = "" report = result.content or ""
if len(result.tool_calls) == 0:
report = result.content
return { return {
"messages": [result], "messages": [result],

View File

@@ -3,6 +3,8 @@ import time
from tradingagents.agents.utils.core_stock_tools import get_stock_data from tradingagents.agents.utils.core_stock_tools import get_stock_data
from tradingagents.agents.utils.technical_indicators_tools import get_indicators from tradingagents.agents.utils.technical_indicators_tools import get_indicators
from tradingagents.agents.utils.fundamental_data_tools import get_macro_regime from tradingagents.agents.utils.fundamental_data_tools import get_macro_regime
from tradingagents.agents.utils.tool_runner import run_tool_loop
from tradingagents.agents.utils.agent_utils import build_instrument_context
from tradingagents.dataflows.config import get_config from tradingagents.dataflows.config import get_config
@@ -73,16 +75,14 @@ Volume-Based Indicators:
chain = prompt | llm.bind_tools(tools) chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"]) result = run_tool_loop(chain, state["messages"], tools)
report = "" report = result.content or ""
macro_regime_report = "" macro_regime_report = ""
if len(result.tool_calls) == 0: # Extract macro regime section if present
report = result.content if report and ("Macro Regime Classification" in report or "RISK-ON" in report.upper() or "RISK-OFF" in report.upper() or "TRANSITION" in report.upper()):
# Extract macro regime section if present macro_regime_report = report
if "Macro Regime Classification" in report or "RISK-ON" in report.upper() or "RISK-OFF" in report.upper() or "TRANSITION" in report.upper():
macro_regime_report = report
return { return {
"messages": [result], "messages": [result],

View File

@@ -1,6 +1,8 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import json import json
from tradingagents.agents.utils.news_data_tools import get_news, get_global_news from tradingagents.agents.utils.news_data_tools import get_news, get_global_news
from tradingagents.agents.utils.tool_runner import run_tool_loop
from tradingagents.agents.utils.agent_utils import build_instrument_context
from tradingagents.dataflows.config import get_config from tradingagents.dataflows.config import get_config
@@ -42,12 +44,9 @@ def create_news_analyst(llm):
prompt = prompt.partial(instrument_context=instrument_context) prompt = prompt.partial(instrument_context=instrument_context)
chain = prompt | llm.bind_tools(tools) chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"]) result = run_tool_loop(chain, state["messages"], tools)
report = "" report = result.content or ""
if len(result.tool_calls) == 0:
report = result.content
return { return {
"messages": [result], "messages": [result],

View File

@@ -2,6 +2,8 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time import time
import json import json
from tradingagents.agents.utils.news_data_tools import get_news from tradingagents.agents.utils.news_data_tools import get_news
from tradingagents.agents.utils.tool_runner import run_tool_loop
from tradingagents.agents.utils.agent_utils import build_instrument_context
from tradingagents.dataflows.config import get_config from tradingagents.dataflows.config import get_config
@@ -43,12 +45,9 @@ def create_social_media_analyst(llm):
chain = prompt | llm.bind_tools(tools) chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"]) result = run_tool_loop(chain, state["messages"], tools)
report = "" report = result.content or ""
if len(result.tool_calls) == 0:
report = result.content
return { return {
"messages": [result], "messages": [result],