# TradingAgents/tests/unit/test_global_search_scanners.py
#
# 133 lines
# 4.5 KiB
# Python

from types import SimpleNamespace
from unittest.mock import patch
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable
from tradingagents.agents.scanners.drift_scanner import create_drift_scanner
from tradingagents.agents.scanners.factor_alignment_scanner import (
create_factor_alignment_scanner,
)
class MockRunnable(Runnable):
    """Runnable stub that replays a fixed script of responses, one per ``invoke``.

    Lets scanner nodes be driven through a deterministic conversation without
    a real LLM: the Nth ``invoke`` call returns ``invoke_responses[N]``.
    """

    def __init__(self, invoke_responses):
        # Ordered responses to hand back; indexed by call_count.
        self.invoke_responses = invoke_responses
        # Number of invoke() calls that have returned a response so far.
        self.call_count = 0

    def invoke(self, input, config=None, **kwargs):
        """Return the next scripted response; ``input`` and ``config`` are ignored.

        Raises:
            AssertionError: if called more times than there are scripted
                responses, i.e. the code under test made an unexpected extra
                LLM call. A plain IndexError here would hide that cause.
        """
        if self.call_count >= len(self.invoke_responses):
            raise AssertionError(
                f"MockRunnable exhausted: invoke() call #{self.call_count + 1} "
                f"but only {len(self.invoke_responses)} responses were scripted"
            )
        response = self.invoke_responses[self.call_count]
        self.call_count += 1
        return response
class MockLLM(Runnable):
    """Chat-model stand-in whose ``bind_tools`` returns a scripted runnable.

    Direct ``invoke`` calls and calls made through the runnable returned by
    ``bind_tools`` both draw from the same ``MockRunnable`` script, so
    responses are consumed in order whichever path the code under test uses.
    """

    def __init__(self, invoke_responses):
        self.runnable = MockRunnable(invoke_responses)
        # Tools passed to the most recent bind_tools() call (None until called).
        self.tools_bound = None
        # Extra keyword options passed to the most recent bind_tools() call.
        self.bind_tools_kwargs = None

    def invoke(self, input, config=None, **kwargs):
        """Delegate to the shared scripted runnable."""
        return self.runnable.invoke(input, config=config, **kwargs)

    def bind_tools(self, tools, **kwargs):
        """Record ``tools`` and any binding options, then return the scripted runnable.

        ``**kwargs`` mirrors LangChain chat models' ``bind_tools(tools, **kwargs)``
        signature so options such as ``tool_choice`` do not raise TypeError.
        """
        self.tools_bound = tools
        self.bind_tools_kwargs = kwargs
        return self.runnable
def _base_state():
    """Build a fresh minimal scanner state for a single test run.

    A new dict and message list are created on every call, so one test's
    mutations cannot leak into another.
    """
    sector_table = "| Sector | 1-Month % |\n| Technology | +5.0% |"
    movers_table = "| Symbol | Change % |\n| NVDA | +4.0% |"
    return dict(
        messages=[HumanMessage(content="Run the market scan.")],
        scan_date="2026-03-27",
        sector_performance_report=sector_table,
        market_movers_report=movers_table,
    )
def test_factor_alignment_scanner_end_to_end():
    """Run the factor-alignment scanner node against a two-step scripted LLM.

    The first scripted reply requests three tool calls (two topic-news lookups
    and one earnings-calendar lookup). The second is the final report text.
    The scanner module's tool references are swapped for lightweight fakes
    that expose only ``name`` and ``invoke``.
    """
    requested_calls = [
        {
            "name": "get_topic_news",
            "args": {"topic": "analyst upgrades downgrades", "limit": 3},
            "id": "tc1",
        },
        {
            "name": "get_topic_news",
            "args": {"topic": "earnings estimate revisions", "limit": 3},
            "id": "tc2",
        },
        {
            "name": "get_earnings_calendar",
            "args": {"from_date": "2026-03-27", "to_date": "2026-04-17"},
            "id": "tc3",
        },
    ]
    script = [
        AIMessage(content="", tool_calls=requested_calls),
        AIMessage(content="Factor alignment report with globally surfaced tickers."),
    ]
    llm = MockLLM(script)

    def fake_topic_news(args):
        # Answer per topic so the two news requests yield distinct outputs.
        if "analyst" in args["topic"]:
            return "analyst news"
        return "revision news"

    fake_tools = {
        "get_topic_news": SimpleNamespace(
            name="get_topic_news",
            invoke=fake_topic_news,
        ),
        "get_earnings_calendar": SimpleNamespace(
            name="get_earnings_calendar",
            invoke=lambda args: "earnings calendar",
        ),
    }

    with patch.multiple(
        "tradingagents.agents.scanners.factor_alignment_scanner",
        **fake_tools,
    ):
        node = create_factor_alignment_scanner(llm)
        result = node(_base_state())

    assert "Factor alignment report" in result["factor_alignment_report"]
    assert result["sender"] == "factor_alignment_scanner"
    bound_names = [tool.name for tool in llm.tools_bound]
    assert bound_names == ["get_topic_news", "get_earnings_calendar"]
def test_drift_scanner_end_to_end():
    """Run the drift scanner node against a two-step scripted LLM.

    The first scripted reply requests gap candidates, a topic-news lookup and
    an earnings-calendar lookup. The second is the final report text. Each of
    the scanner module's tool references is replaced by a fake returning a
    fixed string.
    """
    requested_calls = [
        {"name": "get_gap_candidates", "args": {}, "id": "tc1"},
        {
            "name": "get_topic_news",
            "args": {"topic": "earnings beats raised guidance", "limit": 3},
            "id": "tc2",
        },
        {
            "name": "get_earnings_calendar",
            "args": {"from_date": "2026-03-27", "to_date": "2026-04-10"},
            "id": "tc3",
        },
    ]
    llm = MockLLM(
        [
            AIMessage(content="", tool_calls=requested_calls),
            AIMessage(content="Drift opportunities report with continuation setups."),
        ]
    )

    def make_stub(tool_name, canned_output):
        # Each call gets its own scope, so the lambda captures this call's output.
        return SimpleNamespace(name=tool_name, invoke=lambda args: canned_output)

    stubs = {
        tool_name: make_stub(tool_name, canned_output)
        for tool_name, canned_output in (
            ("get_gap_candidates", "gap candidates table"),
            ("get_topic_news", "continuation news"),
            ("get_earnings_calendar", "earnings calendar"),
        )
    }

    with patch.multiple("tradingagents.agents.scanners.drift_scanner", **stubs):
        node = create_drift_scanner(llm)
        result = node(_base_state())

    assert "Drift opportunities report" in result["drift_opportunities_report"]
    assert result["sender"] == "drift_scanner"
    bound_names = [tool.name for tool in llm.tools_bound]
    assert bound_names == [
        "get_gap_candidates",
        "get_topic_news",
        "get_earnings_calendar",
    ]