TradingAgents/tests/unit/test_industry_deep_dive.py

172 lines
6.3 KiB
Python

"""Tests for the Industry Deep Dive improvements:
1. _extract_top_sectors() parses sector performance reports correctly
2. run_tool_loop nudge triggers when first response is short & no tool calls
"""
import pytest
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from tradingagents.agents.scanners.industry_deep_dive import (
VALID_SECTOR_KEYS,
_DISPLAY_TO_KEY,
_extract_top_sectors,
)
from tradingagents.agents.utils.tool_runner import (
run_tool_loop,
MAX_TOOL_ROUNDS,
MIN_REPORT_LENGTH,
)
# ---------------------------------------------------------------------------
# _extract_top_sectors tests
# ---------------------------------------------------------------------------
SAMPLE_SECTOR_REPORT = """\
# Sector Performance Overview
# Data retrieved on: 2026-03-17 12:00:00
| Sector | 1-Day % | 1-Week % | 1-Month % | YTD % |
|--------|---------|----------|-----------|-------|
| Technology | +0.45% | +1.20% | +5.67% | +12.30% |
| Healthcare | -0.12% | -0.50% | -2.10% | +3.40% |
| Financials | +0.30% | +0.80% | +3.25% | +8.10% |
| Energy | +1.10% | +2.50% | +7.80% | +15.20% |
| Consumer Discretionary | -0.20% | -0.10% | -1.50% | +2.00% |
| Consumer Staples | +0.05% | +0.30% | +0.90% | +4.50% |
| Industrials | +0.25% | +0.60% | +2.80% | +6.70% |
| Materials | +0.40% | +1.00% | +4.20% | +9.30% |
| Real Estate | -0.35% | -0.80% | -3.40% | -1.20% |
| Utilities | +0.10% | +0.20% | +1.10% | +5.60% |
| Communication Services | +0.55% | +1.50% | +6.30% | +11.00% |
"""
class TestExtractTopSectors:
"""Verify _extract_top_sectors parses the table correctly."""
def test_returns_top_3_by_absolute_1month(self):
result = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=3)
assert len(result) == 3
assert result[0] == "energy"
assert result[1] == "communication-services"
assert result[2] == "technology"
def test_returns_top_n_variable(self):
result = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=5)
assert len(result) == 5
for key in result:
assert key in VALID_SECTOR_KEYS, f"Invalid key: {key}"
def test_empty_report_returns_defaults(self):
result = _extract_top_sectors("", top_n=3)
assert result == VALID_SECTOR_KEYS[:3]
def test_none_report_returns_defaults(self):
result = _extract_top_sectors(None, top_n=3)
assert result == VALID_SECTOR_KEYS[:3]
def test_garbage_report_returns_defaults(self):
result = _extract_top_sectors("not a table at all\njust random text", top_n=3)
assert result == VALID_SECTOR_KEYS[:3]
def test_negative_returns_sorted_by_absolute_value(self):
"""Sectors with large negative moves should rank high (big movers)."""
report = """\
| Sector | 1-Day % | 1-Week % | 1-Month % | YTD % |
|--------|---------|----------|-----------|-------|
| Technology | +0.10% | +0.20% | +1.00% | +2.00% |
| Energy | -0.50% | -1.00% | -8.50% | -5.00% |
| Healthcare | +0.05% | +0.10% | +0.50% | +1.00% |
"""
result = _extract_top_sectors(report, top_n=2)
assert result[0] == "energy"
def test_all_returned_keys_are_valid(self):
result = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=11)
for key in result:
assert key in VALID_SECTOR_KEYS
def test_display_to_key_covers_all_sectors(self):
display_names = [
"technology", "healthcare", "financials", "energy",
"consumer discretionary", "consumer staples", "industrials",
"materials", "real estate", "utilities", "communication services",
]
for name in display_names:
assert name in _DISPLAY_TO_KEY, f"Missing mapping for '{name}'"
assert _DISPLAY_TO_KEY[name] in VALID_SECTOR_KEYS
# ---------------------------------------------------------------------------
# run_tool_loop nudge tests
# ---------------------------------------------------------------------------
class TestToolLoopNudge:
"""Verify the nudge mechanism in run_tool_loop."""
def _make_chain(self, responses):
chain = MagicMock()
chain.invoke = MagicMock(side_effect=responses)
return chain
def _make_tool(self, name="my_tool"):
tool = MagicMock()
tool.name = name
tool.invoke = MagicMock(return_value="tool result")
return tool
def test_long_response_no_nudge(self):
long_text = "A" * 2100
response = AIMessage(content=long_text, tool_calls=[])
chain = self._make_chain([response])
tool = self._make_tool()
result = run_tool_loop(chain, [], [tool])
assert result.content == long_text
assert chain.invoke.call_count == 1
def test_short_response_triggers_nudge(self):
short_resp = AIMessage(content="Brief.", tool_calls=[])
long_resp = AIMessage(content="A" * 2100, tool_calls=[])
chain = self._make_chain([short_resp, long_resp])
tool = self._make_tool()
result = run_tool_loop(chain, [], [tool])
assert result.content == long_resp.content
assert chain.invoke.call_count == 2
second_call_messages = chain.invoke.call_args_list[1][0][0]
nudge_msgs = [m for m in second_call_messages if isinstance(m, HumanMessage)]
assert len(nudge_msgs) == 1
assert "MUST call at least one tool" in nudge_msgs[0].content
def test_nudge_only_on_first_round(self):
tool_call_resp = AIMessage(
content="",
tool_calls=[{"name": "my_tool", "args": {}, "id": "tc1"}],
)
short_resp = AIMessage(content="Done.", tool_calls=[])
chain = self._make_chain([tool_call_resp, short_resp])
tool = self._make_tool()
result = run_tool_loop(chain, [], [tool])
assert result.content == "Done."
assert chain.invoke.call_count == 2
def test_tool_calls_execute_normally(self):
tool_call_resp = AIMessage(
content="",
tool_calls=[{"name": "my_tool", "args": {"x": 1}, "id": "tc1"}],
)
final_resp = AIMessage(content="Final report" * 50, tool_calls=[])
chain = self._make_chain([tool_call_resp, final_resp])
tool = self._make_tool()
result = run_tool_loop(chain, [], [tool])
tool.invoke.assert_called_once_with({"x": 1})
assert "Final report" in result.content