133 lines
4.5 KiB
Python
133 lines
4.5 KiB
Python
from types import SimpleNamespace
|
|
from unittest.mock import patch
|
|
|
|
from langchain_core.messages import AIMessage, HumanMessage
|
|
from langchain_core.runnables import Runnable
|
|
|
|
from tradingagents.agents.scanners.drift_scanner import create_drift_scanner
|
|
from tradingagents.agents.scanners.factor_alignment_scanner import (
|
|
create_factor_alignment_scanner,
|
|
)
|
|
|
|
|
|
class MockRunnable(Runnable):
    """Scripted stand-in for an LLM runnable that replays canned responses.

    Each ``invoke`` call ignores its input and returns the next item from
    ``invoke_responses`` in order, incrementing ``call_count`` as it goes.
    """

    def __init__(self, invoke_responses):
        # Copy so later mutation of the caller's sequence cannot alter the script.
        self.invoke_responses = list(invoke_responses)
        self.call_count = 0

    def invoke(self, input, config=None, **kwargs):
        """Return the next scripted response.

        ``input``, ``config`` and extra keyword arguments are accepted to match
        the ``Runnable.invoke`` signature but are not inspected.

        Raises:
            IndexError: when called more times than there are scripted
                responses; the message states how many were scripted so a
                failing test points at the script rather than at list indexing.
        """
        if self.call_count >= len(self.invoke_responses):
            raise IndexError(
                f"MockRunnable received call #{self.call_count + 1} but only "
                f"{len(self.invoke_responses)} response(s) were scripted"
            )
        response = self.invoke_responses[self.call_count]
        self.call_count += 1
        return response
|
|
|
|
|
|
class MockLLM(Runnable):
    """Fake chat model exposing the ``invoke`` / ``bind_tools`` surface scanners use.

    Both ``invoke`` on this object and ``invoke`` on the runnable returned by
    ``bind_tools`` go through the same ``MockRunnable``. That means they
    consume one shared sequence of scripted responses.
    """

    def __init__(self, invoke_responses):
        self.runnable = MockRunnable(invoke_responses)
        # Populated by bind_tools so tests can assert which tools were bound.
        self.tools_bound = None
        # Extra keyword options (e.g. tool_choice) passed to bind_tools, if any.
        self.bind_kwargs = None

    def invoke(self, input, config=None, **kwargs):
        """Delegate to the shared scripted runnable."""
        return self.runnable.invoke(input, config=config, **kwargs)

    def bind_tools(self, tools, **kwargs):
        """Record ``tools`` and any binding options, then return the scripted runnable.

        Extra keyword arguments are accepted, like the real chat-model
        ``bind_tools``, so a scanner that passes e.g. ``tool_choice`` does not
        crash the mock. They are kept in ``bind_kwargs`` for inspection.
        """
        # Snapshot into a list so an iterator argument can still be inspected later.
        self.tools_bound = list(tools)
        self.bind_kwargs = kwargs
        return self.runnable
|
|
|
|
|
|
def _base_state():
    """Build a fresh minimal scanner state for one test invocation.

    A brand-new dict and message list are created on every call, so one test
    mutating its state cannot affect another.
    """
    opening_prompt = HumanMessage(content="Run the market scan.")
    sector_table = "| Sector | 1-Month % |\n| Technology | +5.0% |"
    movers_table = "| Symbol | Change % |\n| NVDA | +4.0% |"
    return dict(
        messages=[opening_prompt],
        scan_date="2026-03-27",
        sector_performance_report=sector_table,
        market_movers_report=movers_table,
    )
|
|
|
|
|
|
def test_factor_alignment_scanner_end_to_end():
    """Drive the factor-alignment scanner through one tool round-trip.

    The scripted LLM first requests two topic-news lookups and one earnings
    calendar lookup, then returns the final report. The patched tools record
    the arguments they receive. This lets the test confirm the scanner actually
    executed every requested tool call, rather than only surfacing the final
    message.
    """
    llm = MockLLM(
        [
            AIMessage(
                content="",
                tool_calls=[
                    {"name": "get_topic_news", "args": {"topic": "analyst upgrades downgrades", "limit": 3}, "id": "tc1"},
                    {"name": "get_topic_news", "args": {"topic": "earnings estimate revisions", "limit": 3}, "id": "tc2"},
                    {"name": "get_earnings_calendar", "args": {"from_date": "2026-03-27", "to_date": "2026-04-17"}, "id": "tc3"},
                ],
            ),
            AIMessage(content="Factor alignment report with globally surfaced tickers."),
        ]
    )

    topic_calls = []
    earnings_calls = []

    def fake_topic_news(args):
        # Record the call, then answer differently per topic as the real tool would.
        topic_calls.append(args)
        return "analyst news" if "analyst" in args["topic"] else "revision news"

    def fake_earnings_calendar(args):
        earnings_calls.append(args)
        return "earnings calendar"

    topic_tool = SimpleNamespace(name="get_topic_news", invoke=fake_topic_news)
    earnings_tool = SimpleNamespace(name="get_earnings_calendar", invoke=fake_earnings_calendar)

    with patch(
        "tradingagents.agents.scanners.factor_alignment_scanner.get_topic_news",
        topic_tool,
    ), patch(
        "tradingagents.agents.scanners.factor_alignment_scanner.get_earnings_calendar",
        earnings_tool,
    ):
        node = create_factor_alignment_scanner(llm)
        result = node(_base_state())

    assert "Factor alignment report" in result["factor_alignment_report"]
    assert result["sender"] == "factor_alignment_scanner"
    assert [tool.name for tool in llm.tools_bound] == ["get_topic_news", "get_earnings_calendar"]

    # Each tool call requested by the LLM must have been executed exactly once.
    # Compare order-insensitively so the test does not pin execution order.
    assert sorted(call["topic"] for call in topic_calls) == [
        "analyst upgrades downgrades",
        "earnings estimate revisions",
    ]
    assert len(earnings_calls) == 1
    assert earnings_calls[0]["from_date"] == "2026-03-27"
    assert earnings_calls[0]["to_date"] == "2026-04-17"
|
|
|
|
|
|
def test_drift_scanner_end_to_end():
    """Drive the drift scanner through one tool round-trip.

    The scripted LLM first requests gap candidates, one topic-news lookup and
    one earnings calendar lookup, then returns the final report. The patched
    tools record their arguments so the test can confirm that every requested
    tool call was actually executed.
    """
    llm = MockLLM(
        [
            AIMessage(
                content="",
                tool_calls=[
                    {"name": "get_gap_candidates", "args": {}, "id": "tc1"},
                    {"name": "get_topic_news", "args": {"topic": "earnings beats raised guidance", "limit": 3}, "id": "tc2"},
                    {"name": "get_earnings_calendar", "args": {"from_date": "2026-03-27", "to_date": "2026-04-10"}, "id": "tc3"},
                ],
            ),
            AIMessage(content="Drift opportunities report with continuation setups."),
        ]
    )

    gap_calls = []
    topic_calls = []
    earnings_calls = []

    def fake_gap_candidates(args):
        gap_calls.append(args)
        return "gap candidates table"

    def fake_topic_news(args):
        topic_calls.append(args)
        return "continuation news"

    def fake_earnings_calendar(args):
        earnings_calls.append(args)
        return "earnings calendar"

    gap_tool = SimpleNamespace(name="get_gap_candidates", invoke=fake_gap_candidates)
    topic_tool = SimpleNamespace(name="get_topic_news", invoke=fake_topic_news)
    earnings_tool = SimpleNamespace(name="get_earnings_calendar", invoke=fake_earnings_calendar)

    with patch(
        "tradingagents.agents.scanners.drift_scanner.get_gap_candidates",
        gap_tool,
    ), patch(
        "tradingagents.agents.scanners.drift_scanner.get_topic_news",
        topic_tool,
    ), patch(
        "tradingagents.agents.scanners.drift_scanner.get_earnings_calendar",
        earnings_tool,
    ):
        node = create_drift_scanner(llm)
        result = node(_base_state())

    assert "Drift opportunities report" in result["drift_opportunities_report"]
    assert result["sender"] == "drift_scanner"
    assert [tool.name for tool in llm.tools_bound] == [
        "get_gap_candidates",
        "get_topic_news",
        "get_earnings_calendar",
    ]

    # Each tool call requested by the LLM must have been executed exactly once.
    assert len(gap_calls) == 1
    assert [call["topic"] for call in topic_calls] == ["earnings beats raised guidance"]
    assert len(earnings_calls) == 1
    assert earnings_calls[0]["from_date"] == "2026-03-27"
    assert earnings_calls[0]["to_date"] == "2026-04-10"
|