Merge pull request #91 from aguzererler/fix-analyst-tool-loop-18174774985345323969
fix: use run_tool_loop instead of invoke in analyst agents
This commit is contained in:
commit
41681e0f9e
|
|
@ -0,0 +1,73 @@
|
|||
from unittest.mock import MagicMock
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from tradingagents.agents.analysts.fundamentals_analyst import create_fundamentals_analyst
|
||||
from tradingagents.agents.analysts.market_analyst import create_market_analyst
|
||||
from tradingagents.agents.analysts.social_media_analyst import create_social_media_analyst
|
||||
from tradingagents.agents.analysts.news_analyst import create_news_analyst
|
||||
|
||||
class MockRunnable(Runnable):
|
||||
def __init__(self, invoke_responses):
|
||||
self.invoke_responses = invoke_responses
|
||||
self.call_count = 0
|
||||
|
||||
def invoke(self, input, config=None, **kwargs):
|
||||
response = self.invoke_responses[self.call_count]
|
||||
self.call_count += 1
|
||||
return response
|
||||
|
||||
class MockLLM(Runnable):
|
||||
def __init__(self, invoke_responses):
|
||||
self.runnable = MockRunnable(invoke_responses)
|
||||
self.tools_bound = None
|
||||
|
||||
def invoke(self, input, config=None, **kwargs):
|
||||
return self.runnable.invoke(input, config=config, **kwargs)
|
||||
|
||||
def bind_tools(self, tools):
|
||||
self.tools_bound = tools
|
||||
return self.runnable
|
||||
|
||||
@pytest.fixture
|
||||
def mock_state():
|
||||
return {
|
||||
"messages": [HumanMessage(content="Analyze AAPL.")],
|
||||
"trade_date": "2024-05-15",
|
||||
"company_of_interest": "AAPL",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_with_tool_call():
|
||||
# 1. First call: The LLM decides to use a tool
|
||||
tool_call_msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"name": "mock_tool", "args": {"query": "test"}, "id": "call_123"}
|
||||
]
|
||||
)
|
||||
# 2. Second call: The LLM receives the tool output and writes the report
|
||||
final_report_msg = AIMessage(
|
||||
content="This is the final report after running the tool."
|
||||
)
|
||||
return MockLLM([tool_call_msg, final_report_msg])
|
||||
|
||||
def test_fundamentals_analyst_tool_loop(mock_state, mock_llm_with_tool_call):
|
||||
node = create_fundamentals_analyst(mock_llm_with_tool_call)
|
||||
result = node(mock_state)
|
||||
assert "This is the final report after running the tool." in result["fundamentals_report"]
|
||||
|
||||
def test_market_analyst_tool_loop(mock_state, mock_llm_with_tool_call):
|
||||
node = create_market_analyst(mock_llm_with_tool_call)
|
||||
result = node(mock_state)
|
||||
assert "This is the final report after running the tool." in result["market_report"]
|
||||
|
||||
def test_social_media_analyst_tool_loop(mock_state, mock_llm_with_tool_call):
|
||||
node = create_social_media_analyst(mock_llm_with_tool_call)
|
||||
result = node(mock_state)
|
||||
assert "This is the final report after running the tool." in result["sentiment_report"]
|
||||
|
||||
def test_news_analyst_tool_loop(mock_state, mock_llm_with_tool_call):
|
||||
node = create_news_analyst(mock_llm_with_tool_call)
|
||||
result = node(mock_state)
|
||||
assert "This is the final report after running the tool." in result["news_report"]
|
||||
|
|
@ -11,6 +11,8 @@ from tradingagents.agents.utils.fundamental_data_tools import (
|
|||
get_sector_relative,
|
||||
)
|
||||
from tradingagents.agents.utils.news_data_tools import get_insider_transactions
|
||||
from tradingagents.agents.utils.tool_runner import run_tool_loop
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -66,12 +68,9 @@ def create_fundamentals_analyst(llm):
|
|||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
result = run_tool_loop(chain, state["messages"], tools)
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
report = result.content or ""
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ import time
|
|||
from tradingagents.agents.utils.core_stock_tools import get_stock_data
|
||||
from tradingagents.agents.utils.technical_indicators_tools import get_indicators
|
||||
from tradingagents.agents.utils.fundamental_data_tools import get_macro_regime
|
||||
from tradingagents.agents.utils.tool_runner import run_tool_loop
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -73,16 +75,14 @@ Volume-Based Indicators:
|
|||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
result = run_tool_loop(chain, state["messages"], tools)
|
||||
|
||||
report = ""
|
||||
report = result.content or ""
|
||||
macro_regime_report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
# Extract macro regime section if present
|
||||
if "Macro Regime Classification" in report or "RISK-ON" in report.upper() or "RISK-OFF" in report.upper() or "TRANSITION" in report.upper():
|
||||
macro_regime_report = report
|
||||
# Extract macro regime section if present
|
||||
if report and ("Macro Regime Classification" in report or "RISK-ON" in report.upper() or "RISK-OFF" in report.upper() or "TRANSITION" in report.upper()):
|
||||
macro_regime_report = report
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import json
|
||||
from tradingagents.agents.utils.news_data_tools import get_news, get_global_news
|
||||
from tradingagents.agents.utils.tool_runner import run_tool_loop
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -42,12 +44,9 @@ def create_news_analyst(llm):
|
|||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
result = run_tool_loop(chain, state["messages"], tools)
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
report = result.content or ""
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|||
import time
|
||||
import json
|
||||
from tradingagents.agents.utils.news_data_tools import get_news
|
||||
from tradingagents.agents.utils.tool_runner import run_tool_loop
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -43,12 +45,9 @@ def create_social_media_analyst(llm):
|
|||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
result = run_tool_loop(chain, state["messages"], tools)
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
report = result.content or ""
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
|
|
|
|||
Loading…
Reference in New Issue