feat: improve Industry Deep Dive report quality with enriched data, sector routing, and tool-call nudge

* Initial plan

* Improve Industry Deep Dive quality: enrich tool data, explicit sector keys, tool-call nudge

- Enrich get_industry_performance_yfinance with 1-day/1-week/1-month price returns
  via batched yf.download() for top 10 tickers (Step 1)
- Add VALID_SECTOR_KEYS, _DISPLAY_TO_KEY, _extract_top_sectors() to industry_deep_dive.py
  to pre-extract top sectors from Phase 1 report and inject them into the prompt (Step 2)
- Add tool-call nudge to run_tool_loop: if first LLM response has no tool calls and is
  under 500 chars, re-prompt with explicit instruction to call tools (Step 3)
- Update scanner_tools.py get_industry_performance docstring to list all valid sector keys (Step 4)
- Add 15 unit tests covering _extract_top_sectors, tool_runner nudge, and enriched output (Step 5)

Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com>

* Address code review: move cols[3] access into try block for IndexError safety

Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com>

* fix: align display row count with download count in get_industry_performance_yfinance

The enriched function downloads price data for top 10 tickers but displayed
20 rows, causing rows 11-20 to show N/A in all price columns. This broke
test_industry_perf_falls_back_to_yfinance which asserts N/A count < 5.
Now both download and display use head(10) for consistency.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com>
Co-authored-by: Ahmet Guzererler <guzererler@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Copilot 2026-03-17 20:10:45 +01:00 committed by GitHub
parent 8279295348
commit 1362781291
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 424 additions and 19 deletions

View File

@ -0,0 +1,242 @@
"""Tests for the Industry Deep Dive improvements:
1. _extract_top_sectors() parses sector performance reports correctly
2. Enriched get_industry_performance_yfinance returns price columns
3. run_tool_loop nudge triggers when first response is short & no tool calls
"""
import pytest
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from tradingagents.agents.scanners.industry_deep_dive import (
VALID_SECTOR_KEYS,
_DISPLAY_TO_KEY,
_extract_top_sectors,
)
from tradingagents.agents.utils.tool_runner import (
run_tool_loop,
MAX_TOOL_ROUNDS,
MIN_REPORT_LENGTH,
)
# ---------------------------------------------------------------------------
# _extract_top_sectors tests
# ---------------------------------------------------------------------------
SAMPLE_SECTOR_REPORT = """\
# Sector Performance Overview
# Data retrieved on: 2026-03-17 12:00:00
| Sector | 1-Day % | 1-Week % | 1-Month % | YTD % |
|--------|---------|----------|-----------|-------|
| Technology | +0.45% | +1.20% | +5.67% | +12.30% |
| Healthcare | -0.12% | -0.50% | -2.10% | +3.40% |
| Financials | +0.30% | +0.80% | +3.25% | +8.10% |
| Energy | +1.10% | +2.50% | +7.80% | +15.20% |
| Consumer Discretionary | -0.20% | -0.10% | -1.50% | +2.00% |
| Consumer Staples | +0.05% | +0.30% | +0.90% | +4.50% |
| Industrials | +0.25% | +0.60% | +2.80% | +6.70% |
| Materials | +0.40% | +1.00% | +4.20% | +9.30% |
| Real Estate | -0.35% | -0.80% | -3.40% | -1.20% |
| Utilities | +0.10% | +0.20% | +1.10% | +5.60% |
| Communication Services | +0.55% | +1.50% | +6.30% | +11.00% |
"""
class TestExtractTopSectors:
"""Verify _extract_top_sectors parses the table correctly."""
def test_returns_top_3_by_absolute_1month(self):
result = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=3)
assert len(result) == 3
# Energy (+7.80%), Communication Services (+6.30%), Technology (+5.67%)
assert result[0] == "energy"
assert result[1] == "communication-services"
assert result[2] == "technology"
def test_returns_top_n_variable(self):
result = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=5)
assert len(result) == 5
# All should be valid sector keys
for key in result:
assert key in VALID_SECTOR_KEYS, f"Invalid key: {key}"
def test_empty_report_returns_defaults(self):
result = _extract_top_sectors("", top_n=3)
assert result == VALID_SECTOR_KEYS[:3]
def test_none_report_returns_defaults(self):
result = _extract_top_sectors(None, top_n=3)
assert result == VALID_SECTOR_KEYS[:3]
def test_garbage_report_returns_defaults(self):
result = _extract_top_sectors("not a table at all\njust random text", top_n=3)
assert result == VALID_SECTOR_KEYS[:3]
def test_negative_returns_sorted_by_absolute_value(self):
"""Sectors with large negative moves should rank high (big movers)."""
report = """\
| Sector | 1-Day % | 1-Week % | 1-Month % | YTD % |
|--------|---------|----------|-----------|-------|
| Technology | +0.10% | +0.20% | +1.00% | +2.00% |
| Energy | -0.50% | -1.00% | -8.50% | -5.00% |
| Healthcare | +0.05% | +0.10% | +0.50% | +1.00% |
"""
result = _extract_top_sectors(report, top_n=2)
assert result[0] == "energy" # |-8.50| > |1.00|
def test_all_returned_keys_are_valid(self):
result = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=11)
for key in result:
assert key in VALID_SECTOR_KEYS
def test_display_to_key_covers_all_sectors(self):
"""Every sector name that appears in the ETF performance table
should map to a valid key."""
display_names = [
"technology", "healthcare", "financials", "energy",
"consumer discretionary", "consumer staples", "industrials",
"materials", "real estate", "utilities", "communication services",
]
for name in display_names:
assert name in _DISPLAY_TO_KEY, f"Missing mapping for '{name}'"
assert _DISPLAY_TO_KEY[name] in VALID_SECTOR_KEYS
# ---------------------------------------------------------------------------
# run_tool_loop nudge tests
# ---------------------------------------------------------------------------
class TestToolLoopNudge:
"""Verify the nudge mechanism in run_tool_loop."""
def _make_chain(self, responses):
"""Create a mock chain that returns responses in sequence."""
chain = MagicMock()
chain.invoke = MagicMock(side_effect=responses)
return chain
def _make_tool(self, name="my_tool"):
tool = MagicMock()
tool.name = name
tool.invoke = MagicMock(return_value="tool result")
return tool
def test_long_response_no_nudge(self):
"""A long first response (no tool calls) should be returned as-is."""
long_text = "A" * 600
response = AIMessage(content=long_text, tool_calls=[])
chain = self._make_chain([response])
tool = self._make_tool()
result = run_tool_loop(chain, [], [tool])
assert result.content == long_text
assert chain.invoke.call_count == 1
def test_short_response_triggers_nudge(self):
"""A short first response triggers a nudge, then the LLM is re-invoked."""
short_resp = AIMessage(content="Brief.", tool_calls=[])
long_resp = AIMessage(content="A" * 600, tool_calls=[])
chain = self._make_chain([short_resp, long_resp])
tool = self._make_tool()
result = run_tool_loop(chain, [], [tool])
assert result.content == long_resp.content
assert chain.invoke.call_count == 2
# The second invoke should have received a HumanMessage nudge
second_call_messages = chain.invoke.call_args_list[1][0][0]
nudge_msgs = [m for m in second_call_messages if isinstance(m, HumanMessage)]
assert len(nudge_msgs) == 1
assert "MUST call at least one tool" in nudge_msgs[0].content
def test_nudge_only_on_first_round(self):
"""Nudge should NOT trigger after tools have been used."""
# Round 1: LLM calls a tool
tool_call_resp = AIMessage(
content="",
tool_calls=[{"name": "my_tool", "args": {}, "id": "tc1"}],
)
# Round 2: LLM returns a short text — no nudge expected
short_resp = AIMessage(content="Done.", tool_calls=[])
chain = self._make_chain([tool_call_resp, short_resp])
tool = self._make_tool()
result = run_tool_loop(chain, [], [tool])
assert result.content == "Done."
assert chain.invoke.call_count == 2
def test_tool_calls_execute_normally(self):
"""Normal tool-calling flow should still work unchanged."""
tool_call_resp = AIMessage(
content="",
tool_calls=[{"name": "my_tool", "args": {"x": 1}, "id": "tc1"}],
)
final_resp = AIMessage(content="Final report" * 50, tool_calls=[])
chain = self._make_chain([tool_call_resp, final_resp])
tool = self._make_tool()
result = run_tool_loop(chain, [], [tool])
tool.invoke.assert_called_once_with({"x": 1})
assert "Final report" in result.content
# ---------------------------------------------------------------------------
# Enriched industry performance tests
# ---------------------------------------------------------------------------
class TestEnrichedIndustryPerformance:
"""Verify that get_industry_performance_yfinance now returns price columns.
These tests require network access to Yahoo Finance. If the host is not
reachable (e.g. in sandboxed CI), they are automatically skipped.
"""
@pytest.fixture(autouse=True)
def _require_yahoo(self):
import socket
try:
socket.setdefaulttimeout(3)
socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect(
("query2.finance.yahoo.com", 443)
)
except (socket.error, OSError):
pytest.skip("Yahoo Finance not reachable")
def test_technology_has_price_columns(self):
from tradingagents.dataflows.yfinance_scanner import (
get_industry_performance_yfinance,
)
result = get_industry_performance_yfinance("technology")
assert "# Industry Performance: Technology" in result
# New columns should be present in the header
assert "1-Day %" in result
assert "1-Week %" in result
assert "1-Month %" in result
def test_table_has_seven_columns(self):
from tradingagents.dataflows.yfinance_scanner import (
get_industry_performance_yfinance,
)
result = get_industry_performance_yfinance("technology")
lines = result.strip().split("\n")
# Find the header separator line
sep_lines = [l for l in lines if l.startswith("|") and "---" in l]
assert len(sep_lines) >= 1
# Count columns in separator
cols = [c.strip() for c in sep_lines[0].split("|")[1:-1]]
assert len(cols) == 7, f"Expected 7 columns, got {len(cols)}: {cols}"
def test_healthcare_sector_key(self):
from tradingagents.dataflows.yfinance_scanner import (
get_industry_performance_yfinance,
)
result = get_industry_performance_yfinance("healthcare")
assert "Industry Performance: Healthcare" in result
assert "1-Day %" in result

View File

@ -1,7 +1,85 @@
from __future__ import annotations
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.agent_utils import get_industry_performance, get_topic_news
from tradingagents.agents.utils.tool_runner import run_tool_loop
# All valid sector keys accepted by yfinance Sector() and get_industry_performance.
VALID_SECTOR_KEYS = [
"technology",
"healthcare",
"financial-services",
"energy",
"consumer-cyclical",
"consumer-defensive",
"industrials",
"basic-materials",
"real-estate",
"utilities",
"communication-services",
]
# Map display names used in the sector performance report to valid keys.
_DISPLAY_TO_KEY = {
"technology": "technology",
"healthcare": "healthcare",
"financials": "financial-services",
"financial services": "financial-services",
"energy": "energy",
"consumer discretionary": "consumer-cyclical",
"consumer staples": "consumer-defensive",
"industrials": "industrials",
"materials": "basic-materials",
"basic materials": "basic-materials",
"real estate": "real-estate",
"utilities": "utilities",
"communication services": "communication-services",
}
def _extract_top_sectors(sector_report: str, top_n: int = 3) -> list[str]:
"""Parse the sector performance report and return the *top_n* sector keys
ranked by absolute 1-month performance (largest absolute move first).
The sector performance table looks like:
| Technology | +0.45% | +1.20% | +5.67% | +12.3% |
We parse the 1-month column (index 3) and sort by absolute value.
Returns a list of valid sector keys (e.g. ``["technology", "energy"]``).
Falls back to a sensible default if parsing fails.
"""
if not sector_report:
return VALID_SECTOR_KEYS[:top_n]
rows: list[tuple[str, float]] = []
for line in sector_report.split("\n"):
if not line.startswith("|"):
continue
cols = [c.strip() for c in line.split("|")[1:-1]]
if len(cols) < 4:
continue
sector_name = cols[0].lower()
if sector_name in ("sector", "---", "") or "---" in sector_name:
continue
# Try to parse the 1-month column (index 3)
try:
month_str = cols[3].replace("%", "").replace("+", "").strip()
month_val = float(month_str)
except (ValueError, IndexError):
continue
key = _DISPLAY_TO_KEY.get(sector_name)
if key:
rows.append((key, month_val))
if not rows:
return VALID_SECTOR_KEYS[:top_n]
# Sort by absolute 1-month move (biggest mover first)
rows.sort(key=lambda r: abs(r[1]), reverse=True)
return [r[0] for r in rows[:top_n]]
def create_industry_deep_dive(llm):
def industry_deep_dive_node(state):
@ -9,6 +87,9 @@ def create_industry_deep_dive(llm):
tools = [get_industry_performance, get_topic_news]
sector_report = state.get("sector_performance_report", "")
top_sectors = _extract_top_sectors(sector_report, top_n=3)
# Inject Phase 1 context so the LLM can decide which sectors to drill into
phase1_context = f"""## Phase 1 Scanner Reports (for your reference)
@ -19,20 +100,29 @@ def create_industry_deep_dive(llm):
{state.get("market_movers_report", "Not available")}
### Sector Performance Report:
{state.get("sector_performance_report", "Not available")}
{sector_report or "Not available"}
"""
sector_list_str = ", ".join(f"'{s}'" for s in top_sectors)
all_keys_str = ", ".join(f"'{s}'" for s in VALID_SECTOR_KEYS)
system_message = (
"You are a senior research analyst performing an industry deep dive. "
"You have received reports from three parallel scanners (geopolitical, market movers, sector performance). "
"Review these reports and identify the 2-3 most promising sectors/industries to investigate further. "
"Use get_industry_performance to drill into those sectors and get_topic_news for sector-specific news. "
"Write a detailed report covering: "
"(1) Why these industries were selected, "
"(2) Top companies within each industry and their recent performance, "
"(3) Industry-specific catalysts and risks, "
"(4) Cross-references between geopolitical events and sector opportunities."
f"\n\n{phase1_context}"
"You are a senior research analyst performing an industry deep dive.\n\n"
"## Your task\n"
"Based on the Phase 1 reports below, drill into the most interesting sectors "
"using the tools provided and write a detailed analysis.\n\n"
"## IMPORTANT — You MUST call tools before writing your report\n"
f"1. Call get_industry_performance for EACH of these top sectors: {sector_list_str}\n"
"2. Call get_topic_news for at least 2 sector-specific topics "
"(e.g., 'semiconductor industry', 'renewable energy stocks').\n"
"3. After receiving tool results, write your detailed report.\n\n"
f"Valid sector_key values for get_industry_performance: {all_keys_str}\n\n"
"## Report structure\n"
"(1) Why these industries were selected (link to Phase 1 findings)\n"
"(2) Top companies within each industry and their recent performance\n"
"(3) Industry-specific catalysts and risks\n"
"(4) Cross-references between geopolitical events and sector opportunities\n\n"
f"{phase1_context}"
)
prompt = ChatPromptTemplate.from_messages(

View File

@ -52,11 +52,15 @@ def get_industry_performance(
) -> str:
"""
Get industry-level drill-down within a specific sector.
Shows top companies and industries in the sector.
Shows top companies with rating, market weight, and recent price performance
(1-day, 1-week, 1-month returns).
Uses the configured scanner_data vendor.
Args:
sector_key (str): Sector identifier (e.g., 'technology', 'healthcare', 'energy')
sector_key (str): Sector identifier. Must be one of:
'technology', 'healthcare', 'financial-services', 'energy',
'consumer-cyclical', 'consumer-defensive', 'industrials',
'basic-materials', 'real-estate', 'utilities', 'communication-services'
Returns:
str: Formatted table of top companies/industries in the sector with performance data

View File

@ -9,42 +9,72 @@ from __future__ import annotations
from typing import Any, List
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
# Most LLM tool-calling patterns resolve within 2-3 rounds;
# 5 provides headroom for complex scenarios while preventing runaway loops.
MAX_TOOL_ROUNDS = 5
# If the LLM's first response has no tool calls AND is shorter than this,
# a nudge message is appended to encourage tool usage.
MIN_REPORT_LENGTH = 500
def run_tool_loop(
chain,
messages: List[Any],
tools: List[Any],
max_rounds: int = MAX_TOOL_ROUNDS,
min_report_length: int = MIN_REPORT_LENGTH,
) -> AIMessage:
"""Invoke *chain* in a loop, executing any tool calls until the LLM
produces a final text response (i.e. no more tool_calls).
If the very first LLM response contains no tool calls **and** the text
is shorter than *min_report_length*, the loop appends a nudge message
asking the LLM to call tools first, then re-invokes once before
accepting the response. This prevents under-powered models from
skipping tool use when overwhelmed by long context.
Args:
chain: A LangChain runnable (prompt | llm.bind_tools(tools)).
messages: The initial list of messages to send.
tools: List of LangChain tool objects (must match the tools bound to the LLM).
max_rounds: Maximum number of tool-calling rounds before forcing a stop.
min_report_length: Minimum acceptable length (chars) of a text-only
first response. Shorter responses trigger a nudge to use tools.
Returns:
The final AIMessage with a text ``content`` (report).
"""
tool_map = {t.name: t for t in tools}
current_messages = list(messages)
first_round = True
for _ in range(max_rounds):
result: AIMessage = chain.invoke(current_messages)
current_messages.append(result)
if not result.tool_calls:
# Nudge: if the LLM skipped tools on its first turn and the
# response is suspiciously short, ask it to try again with tools.
if first_round and len(result.content or "") < min_report_length:
tool_names = ", ".join(tool_map.keys())
nudge = (
"Your response was too brief. You MUST call at least one tool "
f"({tool_names}) before writing your final report. "
"Please call the tools now."
)
current_messages.append(
HumanMessage(content=nudge)
)
first_round = False
continue
return result
first_round = False
# Execute each requested tool call and append ToolMessages
for tc in result.tool_calls:
tool_name = tc["name"]

View File

@ -249,6 +249,10 @@ def get_industry_performance_yfinance(
) -> str:
"""
Get industry-level drill-down within a sector.
Returns top companies with metadata (rating, market weight) **plus**
recent price performance (1-day, 1-week, 1-month returns) obtained
via a single batched ``yf.download()`` call for the top 10 tickers.
Args:
sector_key: Sector identifier (e.g., 'technology', 'healthcare')
@ -265,17 +269,44 @@ def get_industry_performance_yfinance(
if top_companies is None or top_companies.empty:
return f"No industry data found for sector '{sector_key}'"
# --- Batch-download price history for the top 10 tickers ----------
tickers = list(top_companies.head(10).index)
price_returns: dict[str, dict[str, float | None]] = {}
try:
hist = yf.download(
tickers, period="1mo", auto_adjust=True, progress=False, threads=True,
)
for tkr in tickers:
try:
if len(tickers) > 1:
closes = hist["Close"][tkr].dropna()
else:
closes = hist["Close"].dropna()
if closes.empty or len(closes) < 2:
continue
price_returns[tkr] = {
"1d": _safe_pct(closes, 1),
"1w": _safe_pct(closes, 5),
"1m": _safe_pct(closes, len(closes) - 1),
}
except Exception:
continue
except Exception:
pass # Fall through — table will show N/A for returns
# ------------------------------------------------------------------
header = f"# Industry Performance: {sector_key.replace('-', ' ').title()}\n"
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
result_str = header
result_str += "| Company | Symbol | Rating | Market Weight |\n"
result_str += "|---------|--------|--------|---------------|\n"
result_str += "| Company | Symbol | Rating | Market Weight | 1-Day % | 1-Week % | 1-Month % |\n"
result_str += "|---------|--------|--------|---------------|---------|----------|-----------|\n"
# top_companies has ticker as the DataFrame index (index.name == 'symbol')
# Columns: name, rating, market weight
for symbol, row in top_companies.head(20).iterrows():
# Display only the tickers we downloaded prices for to avoid N/A gaps
for symbol, row in top_companies.head(10).iterrows():
name = row.get('name', 'N/A')
rating = row.get('rating', 'N/A')
market_weight = row.get('market weight', None)
@ -283,7 +314,15 @@ def get_industry_performance_yfinance(
name_short = name[:30] if isinstance(name, str) else str(name)
weight_str = f"{market_weight:.2%}" if isinstance(market_weight, (int, float)) else "N/A"
result_str += f"| {name_short} | {symbol} | {rating} | {weight_str} |\n"
ret = price_returns.get(symbol, {})
day_str = f"{ret['1d']:+.2f}%" if ret.get('1d') is not None else "N/A"
week_str = f"{ret['1w']:+.2f}%" if ret.get('1w') is not None else "N/A"
month_str = f"{ret['1m']:+.2f}%" if ret.get('1m') is not None else "N/A"
result_str += (
f"| {name_short} | {symbol} | {rating} | {weight_str}"
f" | {day_str} | {week_str} | {month_str} |\n"
)
return result_str