172 lines
6.3 KiB
Python
172 lines
6.3 KiB
Python
"""Tests for the Industry Deep Dive improvements:
|
|
|
|
1. _extract_top_sectors() parses sector performance reports correctly
|
|
2. run_tool_loop nudge triggers when first response is short & no tool calls
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock
|
|
|
|
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
|
|
|
from tradingagents.agents.scanners.industry_deep_dive import (
|
|
VALID_SECTOR_KEYS,
|
|
_DISPLAY_TO_KEY,
|
|
_extract_top_sectors,
|
|
)
|
|
from tradingagents.agents.utils.tool_runner import (
|
|
run_tool_loop,
|
|
MAX_TOOL_ROUNDS,
|
|
MIN_REPORT_LENGTH,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# _extract_top_sectors tests
# ---------------------------------------------------------------------------
|
|
|
|
# Markdown sector-performance report used as the shared fixture for the
# _extract_top_sectors tests. Layout: two "#" header lines, one blank line,
# a table header row, a separator row, then one row per sector (11 sectors).
# Columns are: sector display name, 1-Day %, 1-Week %, 1-Month %, YTD %.
SAMPLE_SECTOR_REPORT = """\
# Sector Performance Overview
# Data retrieved on: 2026-03-17 12:00:00

| Sector | 1-Day % | 1-Week % | 1-Month % | YTD % |
|--------|---------|----------|-----------|-------|
| Technology | +0.45% | +1.20% | +5.67% | +12.30% |
| Healthcare | -0.12% | -0.50% | -2.10% | +3.40% |
| Financials | +0.30% | +0.80% | +3.25% | +8.10% |
| Energy | +1.10% | +2.50% | +7.80% | +15.20% |
| Consumer Discretionary | -0.20% | -0.10% | -1.50% | +2.00% |
| Consumer Staples | +0.05% | +0.30% | +0.90% | +4.50% |
| Industrials | +0.25% | +0.60% | +2.80% | +6.70% |
| Materials | +0.40% | +1.00% | +4.20% | +9.30% |
| Real Estate | -0.35% | -0.80% | -3.40% | -1.20% |
| Utilities | +0.10% | +0.20% | +1.10% | +5.60% |
| Communication Services | +0.55% | +1.50% | +6.30% | +11.00% |
"""
|
|
|
|
|
|
class TestExtractTopSectors:
    """Check how _extract_top_sectors ranks sectors and when it falls back."""

    def test_returns_top_3_by_absolute_1month(self):
        """The three largest absolute 1-month moves come back, biggest first."""
        top_three = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=3)
        assert len(top_three) == 3
        # In the sample's 1-Month column the biggest moves are Energy (7.80),
        # Communication Services (6.30) and Technology (5.67).
        first, second, third = top_three
        assert first == "energy"
        assert second == "communication-services"
        assert third == "technology"

    def test_returns_top_n_variable(self):
        """Asking for five sectors yields five entries, all of them valid keys."""
        top_five = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=5)
        assert len(top_five) == 5
        for key in top_five:
            assert key in VALID_SECTOR_KEYS, f"Invalid key: {key}"

    def test_empty_report_returns_defaults(self):
        """An empty report falls back to the first three valid sector keys."""
        expected_defaults = VALID_SECTOR_KEYS[:3]
        assert _extract_top_sectors("", top_n=3) == expected_defaults

    def test_none_report_returns_defaults(self):
        """A missing (None) report falls back to the first three valid keys."""
        expected_defaults = VALID_SECTOR_KEYS[:3]
        assert _extract_top_sectors(None, top_n=3) == expected_defaults

    def test_garbage_report_returns_defaults(self):
        """Text that contains no table falls back to the first three valid keys."""
        not_a_table = "not a table at all\njust random text"
        expected_defaults = VALID_SECTOR_KEYS[:3]
        assert _extract_top_sectors(not_a_table, top_n=3) == expected_defaults

    def test_negative_returns_sorted_by_absolute_value(self):
        """A large negative 1-month move outranks smaller positive ones."""
        # Energy's -8.50% is the largest move by magnitude in this table.
        report = (
            "| Sector | 1-Day % | 1-Week % | 1-Month % | YTD % |\n"
            "|--------|---------|----------|-----------|-------|\n"
            "| Technology | +0.10% | +0.20% | +1.00% | +2.00% |\n"
            "| Energy | -0.50% | -1.00% | -8.50% | -5.00% |\n"
            "| Healthcare | +0.05% | +0.10% | +0.50% | +1.00% |\n"
        )
        biggest_movers = _extract_top_sectors(report, top_n=2)
        assert biggest_movers[0] == "energy"

    def test_all_returned_keys_are_valid(self):
        """Requesting every sector in the sample only ever yields valid keys."""
        every_sector = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=11)
        assert all(key in VALID_SECTOR_KEYS for key in every_sector)

    def test_display_to_key_covers_all_sectors(self):
        """Each lower-cased display name maps to a valid sector key."""
        display_names = (
            "technology",
            "healthcare",
            "financials",
            "energy",
            "consumer discretionary",
            "consumer staples",
            "industrials",
            "materials",
            "real estate",
            "utilities",
            "communication services",
        )
        for name in display_names:
            assert name in _DISPLAY_TO_KEY, f"Missing mapping for '{name}'"
            mapped_key = _DISPLAY_TO_KEY[name]
            assert mapped_key in VALID_SECTOR_KEYS
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# run_tool_loop nudge tests
# ---------------------------------------------------------------------------
|
|
|
|
class TestToolLoopNudge:
    """Verify the nudge mechanism in run_tool_loop.

    The chain and the tool are MagicMock stand-ins, so every test fully
    scripts the sequence of model responses and can count how many times
    the chain was invoked.
    """

    def _make_chain(self, responses):
        """Return a mock chain whose successive invoke() calls yield *responses*.

        Because the responses are supplied via ``side_effect``, an invoke
        beyond the scripted list raises, so an unexpected extra round fails
        the test instead of passing silently.
        """
        chain = MagicMock()
        chain.invoke = MagicMock(side_effect=responses)
        return chain

    def _make_tool(self, name="my_tool"):
        """Return a mock tool called *name* whose invoke() returns "tool result"."""
        tool = MagicMock()
        tool.name = name
        tool.invoke = MagicMock(return_value="tool result")
        return tool

    def _long_text(self):
        """Return content meant to be too long to trigger a nudge.

        The length is derived from MIN_REPORT_LENGTH instead of a hard-coded
        number so the tests follow the configured value.
        NOTE(review): presumes MIN_REPORT_LENGTH is the length below which
        run_tool_loop nudges -- confirm against tool_runner.
        """
        return "A" * (MIN_REPORT_LENGTH + 100)

    def test_long_response_no_nudge(self):
        """A long first response without tool calls is returned after one invoke."""
        long_text = self._long_text()
        response = AIMessage(content=long_text, tool_calls=[])
        chain = self._make_chain([response])
        tool = self._make_tool()

        result = run_tool_loop(chain, [], [tool])

        assert result.content == long_text
        assert chain.invoke.call_count == 1

    def test_short_response_triggers_nudge(self):
        """A short first response without tool calls causes one nudge and a retry."""
        short_resp = AIMessage(content="Brief.", tool_calls=[])
        long_resp = AIMessage(content=self._long_text(), tool_calls=[])
        chain = self._make_chain([short_resp, long_resp])
        tool = self._make_tool()

        result = run_tool_loop(chain, [], [tool])

        assert result.content == long_resp.content
        assert chain.invoke.call_count == 2

        # The initial message list is empty, so any HumanMessage passed to the
        # second invoke must have been injected by run_tool_loop as the nudge.
        second_call_messages = chain.invoke.call_args_list[1][0][0]
        nudge_msgs = [m for m in second_call_messages if isinstance(m, HumanMessage)]
        assert len(nudge_msgs) == 1
        assert "MUST call at least one tool" in nudge_msgs[0].content

    def test_nudge_only_on_first_round(self):
        """A short reply that follows a tool round is returned without a nudge."""
        tool_call_resp = AIMessage(
            content="",
            tool_calls=[{"name": "my_tool", "args": {}, "id": "tc1"}],
        )
        short_resp = AIMessage(content="Done.", tool_calls=[])
        chain = self._make_chain([tool_call_resp, short_resp])
        tool = self._make_tool()

        result = run_tool_loop(chain, [], [tool])

        assert result.content == "Done."
        # Exactly two invokes: the tool-call round and the final answer.
        assert chain.invoke.call_count == 2

    def test_tool_calls_execute_normally(self):
        """A requested tool is invoked once with the arguments from the tool call."""
        tool_call_resp = AIMessage(
            content="",
            tool_calls=[{"name": "my_tool", "args": {"x": 1}, "id": "tc1"}],
        )
        final_resp = AIMessage(content="Final report" * 50, tool_calls=[])
        chain = self._make_chain([tool_call_resp, final_resp])
        tool = self._make_tool()

        result = run_tool_loop(chain, [], [tool])

        tool.invoke.assert_called_once_with({"x": 1})
        assert "Final report" in result.content
|