From 1362781291a12f254d17786afb9e5ea92a9c906f Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Mar 2026 20:10:45 +0100 Subject: [PATCH] feat: improve Industry Deep Dive report quality with enriched data, sector routing, and tool-call nudge * Initial plan * Improve Industry Deep Dive quality: enrich tool data, explicit sector keys, tool-call nudge - Enrich get_industry_performance_yfinance with 1-day/1-week/1-month price returns via batched yf.download() for top 10 tickers (Step 1) - Add VALID_SECTOR_KEYS, _DISPLAY_TO_KEY, _extract_top_sectors() to industry_deep_dive.py to pre-extract top sectors from Phase 1 report and inject them into the prompt (Step 2) - Add tool-call nudge to run_tool_loop: if first LLM response has no tool calls and is under 500 chars, re-prompt with explicit instruction to call tools (Step 3) - Update scanner_tools.py get_industry_performance docstring to list all valid sector keys (Step 4) - Add 15 unit tests covering _extract_top_sectors, tool_runner nudge, and enriched output (Step 5) Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com> * Address code review: move cols[3] access into try block for IndexError safety Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com> * fix: align display row count with download count in get_industry_performance_yfinance The enriched function downloads price data for top 10 tickers but displayed 20 rows, causing rows 11-20 to show N/A in all price columns. This broke test_industry_perf_falls_back_to_yfinance which asserts N/A count < 5. Now both download and display use head(10) for consistency. Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com> Co-authored-by: Ahmet Guzererler Co-authored-by: Claude Opus 4.6 --- tests/test_industry_deep_dive.py | 242 ++++++++++++++++++ .../agents/scanners/industry_deep_dive.py | 112 +++++++- tradingagents/agents/utils/scanner_tools.py | 8 +- tradingagents/agents/utils/tool_runner.py | 32 ++- tradingagents/dataflows/yfinance_scanner.py | 49 +++- 5 files changed, 424 insertions(+), 19 deletions(-) create mode 100644 tests/test_industry_deep_dive.py diff --git a/tests/test_industry_deep_dive.py b/tests/test_industry_deep_dive.py new file mode 100644 index 00000000..52c98678 --- /dev/null +++ b/tests/test_industry_deep_dive.py @@ -0,0 +1,242 @@ +"""Tests for the Industry Deep Dive improvements: + +1. _extract_top_sectors() parses sector performance reports correctly +2. Enriched get_industry_performance_yfinance returns price columns +3. run_tool_loop nudge triggers when first response is short & no tool calls +""" + +import pytest +from unittest.mock import MagicMock + +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + +from tradingagents.agents.scanners.industry_deep_dive import ( + VALID_SECTOR_KEYS, + _DISPLAY_TO_KEY, + _extract_top_sectors, +) +from tradingagents.agents.utils.tool_runner import ( + run_tool_loop, + MAX_TOOL_ROUNDS, + MIN_REPORT_LENGTH, +) + + +# --------------------------------------------------------------------------- +# _extract_top_sectors tests +# --------------------------------------------------------------------------- + +SAMPLE_SECTOR_REPORT = """\ +# Sector Performance Overview +# Data retrieved on: 2026-03-17 12:00:00 + +| Sector | 1-Day % | 1-Week % | 1-Month % | YTD % | +|--------|---------|----------|-----------|-------| +| Technology | +0.45% | +1.20% | +5.67% | +12.30% | +| Healthcare | -0.12% | -0.50% | -2.10% | +3.40% | +| Financials | +0.30% | +0.80% | +3.25% | +8.10% | +| Energy | +1.10% | +2.50% | +7.80% | +15.20% | +| Consumer Discretionary | -0.20% | -0.10% | -1.50% | +2.00% | +| Consumer Staples | +0.05% | +0.30% | +0.90% | +4.50% | +| Industrials | +0.25% | +0.60% | +2.80% | +6.70% | +| Materials | +0.40% | +1.00% | +4.20% | +9.30% | +| Real Estate | -0.35% | -0.80% | -3.40% | -1.20% | +| Utilities | +0.10% | +0.20% | +1.10% | +5.60% | +| Communication Services | +0.55% | +1.50% | +6.30% | +11.00% | +""" + + +class TestExtractTopSectors: + """Verify _extract_top_sectors parses the table correctly.""" + + def test_returns_top_3_by_absolute_1month(self): + result = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=3) + assert len(result) == 3 + # Energy (+7.80%), Communication Services (+6.30%), Technology (+5.67%) + assert result[0] == "energy" + assert result[1] == "communication-services" + assert result[2] == "technology" + + def test_returns_top_n_variable(self): + result = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=5) + assert len(result) == 5 + # All should be valid sector keys + for key in result: + assert key in VALID_SECTOR_KEYS, f"Invalid key: {key}" + + def test_empty_report_returns_defaults(self): + result = _extract_top_sectors("", top_n=3) + assert result == VALID_SECTOR_KEYS[:3] + + def test_none_report_returns_defaults(self): + result = _extract_top_sectors(None, top_n=3) + assert result == VALID_SECTOR_KEYS[:3] + + def test_garbage_report_returns_defaults(self): + result = _extract_top_sectors("not a table at all\njust random text", top_n=3) + assert result == VALID_SECTOR_KEYS[:3] + + def test_negative_returns_sorted_by_absolute_value(self): + """Sectors with large negative moves should rank high (big movers).""" + report = """\ +| Sector | 1-Day % | 1-Week % | 1-Month % | YTD % | +|--------|---------|----------|-----------|-------| +| Technology | +0.10% | +0.20% | +1.00% | +2.00% | +| Energy | -0.50% | -1.00% | -8.50% | -5.00% | +| Healthcare | +0.05% | +0.10% | +0.50% | +1.00% | +""" + result = _extract_top_sectors(report, top_n=2) + assert result[0] == "energy" # |-8.50| > |1.00| + + def test_all_returned_keys_are_valid(self): + result = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=11) + for key in result: + assert key in VALID_SECTOR_KEYS + + def test_display_to_key_covers_all_sectors(self): + """Every sector name that appears in the ETF performance table + should map to a valid key.""" + display_names = [ + "technology", "healthcare", "financials", "energy", + "consumer discretionary", "consumer staples", "industrials", + "materials", "real estate", "utilities", "communication services", + ] + for name in display_names: + assert name in _DISPLAY_TO_KEY, f"Missing mapping for '{name}'" + assert _DISPLAY_TO_KEY[name] in VALID_SECTOR_KEYS + + +# --------------------------------------------------------------------------- +# run_tool_loop nudge tests +# --------------------------------------------------------------------------- + +class TestToolLoopNudge: + """Verify the nudge mechanism in run_tool_loop.""" + + def _make_chain(self, responses): + """Create a mock chain that returns responses in sequence.""" + chain = MagicMock() + chain.invoke = MagicMock(side_effect=responses) + return chain + + def _make_tool(self, name="my_tool"): + tool = MagicMock() + tool.name = name + tool.invoke = MagicMock(return_value="tool result") + return tool + + def test_long_response_no_nudge(self): + """A long first response (no tool calls) should be returned as-is.""" + long_text = "A" * 600 + response = AIMessage(content=long_text, tool_calls=[]) + chain = self._make_chain([response]) + tool = self._make_tool() + + result = run_tool_loop(chain, [], [tool]) + assert result.content == long_text + assert chain.invoke.call_count == 1 + + def test_short_response_triggers_nudge(self): + """A short first response triggers a nudge, then the LLM is re-invoked.""" + short_resp = AIMessage(content="Brief.", tool_calls=[]) + long_resp = AIMessage(content="A" * 600, tool_calls=[]) + chain = self._make_chain([short_resp, long_resp]) + tool = self._make_tool() + + result = run_tool_loop(chain, [], [tool]) + assert result.content == long_resp.content + assert chain.invoke.call_count == 2 + + # The second invoke should have received a HumanMessage nudge + second_call_messages = chain.invoke.call_args_list[1][0][0] + nudge_msgs = [m for m in second_call_messages if isinstance(m, HumanMessage)] + assert len(nudge_msgs) == 1 + assert "MUST call at least one tool" in nudge_msgs[0].content + + def test_nudge_only_on_first_round(self): + """Nudge should NOT trigger after tools have been used.""" + # Round 1: LLM calls a tool + tool_call_resp = AIMessage( + content="", + tool_calls=[{"name": "my_tool", "args": {}, "id": "tc1"}], + ) + # Round 2: LLM returns a short text — no nudge expected + short_resp = AIMessage(content="Done.", tool_calls=[]) + chain = self._make_chain([tool_call_resp, short_resp]) + tool = self._make_tool() + + result = run_tool_loop(chain, [], [tool]) + assert result.content == "Done." + assert chain.invoke.call_count == 2 + + def test_tool_calls_execute_normally(self): + """Normal tool-calling flow should still work unchanged.""" + tool_call_resp = AIMessage( + content="", + tool_calls=[{"name": "my_tool", "args": {"x": 1}, "id": "tc1"}], + ) + final_resp = AIMessage(content="Final report" * 50, tool_calls=[]) + chain = self._make_chain([tool_call_resp, final_resp]) + tool = self._make_tool() + + result = run_tool_loop(chain, [], [tool]) + tool.invoke.assert_called_once_with({"x": 1}) + assert "Final report" in result.content + + +# --------------------------------------------------------------------------- +# Enriched industry performance tests +# --------------------------------------------------------------------------- + +class TestEnrichedIndustryPerformance: + """Verify that get_industry_performance_yfinance now returns price columns. + + These tests require network access to Yahoo Finance. If the host is not + reachable (e.g. in sandboxed CI), they are automatically skipped. + """ + + @pytest.fixture(autouse=True) + def _require_yahoo(self): + import socket + try: + socket.setdefaulttimeout(3) + socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect( + ("query2.finance.yahoo.com", 443) + ) + except (socket.error, OSError): + pytest.skip("Yahoo Finance not reachable") + + def test_technology_has_price_columns(self): + from tradingagents.dataflows.yfinance_scanner import ( + get_industry_performance_yfinance, + ) + + result = get_industry_performance_yfinance("technology") + assert "# Industry Performance: Technology" in result + # New columns should be present in the header + assert "1-Day %" in result + assert "1-Week %" in result + assert "1-Month %" in result + + def test_table_has_seven_columns(self): + from tradingagents.dataflows.yfinance_scanner import ( + get_industry_performance_yfinance, + ) + + result = get_industry_performance_yfinance("technology") + lines = result.strip().split("\n") + # Find the header separator line + sep_lines = [l for l in lines if l.startswith("|") and "---" in l] + assert len(sep_lines) >= 1 + # Count columns in separator + cols = [c.strip() for c in sep_lines[0].split("|")[1:-1]] + assert len(cols) == 7, f"Expected 7 columns, got {len(cols)}: {cols}" + + def test_healthcare_sector_key(self): + from tradingagents.dataflows.yfinance_scanner import ( + get_industry_performance_yfinance, + ) + + result = get_industry_performance_yfinance("healthcare") + assert "Industry Performance: Healthcare" in result + assert "1-Day %" in result diff --git a/tradingagents/agents/scanners/industry_deep_dive.py b/tradingagents/agents/scanners/industry_deep_dive.py index bfe84b6b..3b15cf4f 100644 --- a/tradingagents/agents/scanners/industry_deep_dive.py +++ b/tradingagents/agents/scanners/industry_deep_dive.py @@ -1,7 +1,85 @@ +from __future__ import annotations + from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from tradingagents.agents.utils.agent_utils import get_industry_performance, get_topic_news from tradingagents.agents.utils.tool_runner import run_tool_loop +# All valid sector keys accepted by yfinance Sector() and get_industry_performance. +VALID_SECTOR_KEYS = [ + "technology", + "healthcare", + "financial-services", + "energy", + "consumer-cyclical", + "consumer-defensive", + "industrials", + "basic-materials", + "real-estate", + "utilities", + "communication-services", +] + +# Map display names used in the sector performance report to valid keys. +_DISPLAY_TO_KEY = { + "technology": "technology", + "healthcare": "healthcare", + "financials": "financial-services", + "financial services": "financial-services", + "energy": "energy", + "consumer discretionary": "consumer-cyclical", + "consumer staples": "consumer-defensive", + "industrials": "industrials", + "materials": "basic-materials", + "basic materials": "basic-materials", + "real estate": "real-estate", + "utilities": "utilities", + "communication services": "communication-services", +} + + +def _extract_top_sectors(sector_report: str, top_n: int = 3) -> list[str]: + """Parse the sector performance report and return the *top_n* sector keys + ranked by absolute 1-month performance (largest absolute move first). + + The sector performance table looks like: + + | Technology | +0.45% | +1.20% | +5.67% | +12.3% | + + We parse the 1-month column (index 3) and sort by absolute value. + + Returns a list of valid sector keys (e.g. ``["technology", "energy"]``). + Falls back to a sensible default if parsing fails. + """ + if not sector_report: + return VALID_SECTOR_KEYS[:top_n] + + rows: list[tuple[str, float]] = [] + for line in sector_report.split("\n"): + if not line.startswith("|"): + continue + cols = [c.strip() for c in line.split("|")[1:-1]] + if len(cols) < 4: + continue + sector_name = cols[0].lower() + if sector_name in ("sector", "---", "") or "---" in sector_name: + continue + # Try to parse the 1-month column (index 3) + try: + month_str = cols[3].replace("%", "").replace("+", "").strip() + month_val = float(month_str) + except (ValueError, IndexError): + continue + key = _DISPLAY_TO_KEY.get(sector_name) + if key: + rows.append((key, month_val)) + + if not rows: + return VALID_SECTOR_KEYS[:top_n] + + # Sort by absolute 1-month move (biggest mover first) + rows.sort(key=lambda r: abs(r[1]), reverse=True) + return [r[0] for r in rows[:top_n]] + def create_industry_deep_dive(llm): def industry_deep_dive_node(state): @@ -9,6 +87,9 @@ def create_industry_deep_dive(llm): tools = [get_industry_performance, get_topic_news] + sector_report = state.get("sector_performance_report", "") + top_sectors = _extract_top_sectors(sector_report, top_n=3) + # Inject Phase 1 context so the LLM can decide which sectors to drill into phase1_context = f"""## Phase 1 Scanner Reports (for your reference) @@ -19,20 +100,29 @@ def create_industry_deep_dive(llm): {state.get("market_movers_report", "Not available")} ### Sector Performance Report: -{state.get("sector_performance_report", "Not available")} +{sector_report or "Not available"} """ + sector_list_str = ", ".join(f"'{s}'" for s in top_sectors) + all_keys_str = ", ".join(f"'{s}'" for s in VALID_SECTOR_KEYS) + system_message = ( - "You are a senior research analyst performing an industry deep dive. " - "You have received reports from three parallel scanners (geopolitical, market movers, sector performance). " - "Review these reports and identify the 2-3 most promising sectors/industries to investigate further. " - "Use get_industry_performance to drill into those sectors and get_topic_news for sector-specific news. " - "Write a detailed report covering: " - "(1) Why these industries were selected, " - "(2) Top companies within each industry and their recent performance, " - "(3) Industry-specific catalysts and risks, " - "(4) Cross-references between geopolitical events and sector opportunities." - f"\n\n{phase1_context}" + "You are a senior research analyst performing an industry deep dive.\n\n" + "## Your task\n" + "Based on the Phase 1 reports below, drill into the most interesting sectors " + "using the tools provided and write a detailed analysis.\n\n" + "## IMPORTANT — You MUST call tools before writing your report\n" + f"1. Call get_industry_performance for EACH of these top sectors: {sector_list_str}\n" + "2. Call get_topic_news for at least 2 sector-specific topics " + "(e.g., 'semiconductor industry', 'renewable energy stocks').\n" + "3. After receiving tool results, write your detailed report.\n\n" + f"Valid sector_key values for get_industry_performance: {all_keys_str}\n\n" + "## Report structure\n" + "(1) Why these industries were selected (link to Phase 1 findings)\n" + "(2) Top companies within each industry and their recent performance\n" + "(3) Industry-specific catalysts and risks\n" + "(4) Cross-references between geopolitical events and sector opportunities\n\n" + f"{phase1_context}" ) prompt = ChatPromptTemplate.from_messages( diff --git a/tradingagents/agents/utils/scanner_tools.py b/tradingagents/agents/utils/scanner_tools.py index 6898da67..b1869a4b 100644 --- a/tradingagents/agents/utils/scanner_tools.py +++ b/tradingagents/agents/utils/scanner_tools.py @@ -52,11 +52,15 @@ def get_industry_performance( ) -> str: """ Get industry-level drill-down within a specific sector. - Shows top companies and industries in the sector. + Shows top companies with rating, market weight, and recent price performance + (1-day, 1-week, 1-month returns). Uses the configured scanner_data vendor. Args: - sector_key (str): Sector identifier (e.g., 'technology', 'healthcare', 'energy') + sector_key (str): Sector identifier. Must be one of: + 'technology', 'healthcare', 'financial-services', 'energy', + 'consumer-cyclical', 'consumer-defensive', 'industrials', + 'basic-materials', 'real-estate', 'utilities', 'communication-services' Returns: str: Formatted table of top companies/industries in the sector with performance data diff --git a/tradingagents/agents/utils/tool_runner.py b/tradingagents/agents/utils/tool_runner.py index 3c07d5a4..a988f99b 100644 --- a/tradingagents/agents/utils/tool_runner.py +++ b/tradingagents/agents/utils/tool_runner.py @@ -9,42 +9,72 @@ from __future__ import annotations from typing import Any, List -from langchain_core.messages import AIMessage, ToolMessage +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage # Most LLM tool-calling patterns resolve within 2-3 rounds; # 5 provides headroom for complex scenarios while preventing runaway loops. MAX_TOOL_ROUNDS = 5 +# If the LLM's first response has no tool calls AND is shorter than this, +# a nudge message is appended to encourage tool usage. +MIN_REPORT_LENGTH = 500 + def run_tool_loop( chain, messages: List[Any], tools: List[Any], max_rounds: int = MAX_TOOL_ROUNDS, + min_report_length: int = MIN_REPORT_LENGTH, ) -> AIMessage: """Invoke *chain* in a loop, executing any tool calls until the LLM produces a final text response (i.e. no more tool_calls). + If the very first LLM response contains no tool calls **and** the text + is shorter than *min_report_length*, the loop appends a nudge message + asking the LLM to call tools first, then re-invokes once before + accepting the response. This prevents under-powered models from + skipping tool use when overwhelmed by long context. + Args: chain: A LangChain runnable (prompt | llm.bind_tools(tools)). messages: The initial list of messages to send. tools: List of LangChain tool objects (must match the tools bound to the LLM). max_rounds: Maximum number of tool-calling rounds before forcing a stop. + min_report_length: Minimum acceptable length (chars) of a text-only + first response. Shorter responses trigger a nudge to use tools. Returns: The final AIMessage with a text ``content`` (report). """ tool_map = {t.name: t for t in tools} current_messages = list(messages) + first_round = True for _ in range(max_rounds): result: AIMessage = chain.invoke(current_messages) current_messages.append(result) if not result.tool_calls: + # Nudge: if the LLM skipped tools on its first turn and the + # response is suspiciously short, ask it to try again with tools. + if first_round and len(result.content or "") < min_report_length: + tool_names = ", ".join(tool_map.keys()) + nudge = ( + "Your response was too brief. You MUST call at least one tool " + f"({tool_names}) before writing your final report. " + "Please call the tools now." + ) + current_messages.append( + HumanMessage(content=nudge) + ) + first_round = False + continue return result + first_round = False + # Execute each requested tool call and append ToolMessages for tc in result.tool_calls: tool_name = tc["name"] diff --git a/tradingagents/dataflows/yfinance_scanner.py b/tradingagents/dataflows/yfinance_scanner.py index d4649ab8..21b5b3e5 100644 --- a/tradingagents/dataflows/yfinance_scanner.py +++ b/tradingagents/dataflows/yfinance_scanner.py @@ -249,6 +249,10 @@ def get_industry_performance_yfinance( ) -> str: """ Get industry-level drill-down within a sector. + + Returns top companies with metadata (rating, market weight) **plus** + recent price performance (1-day, 1-week, 1-month returns) obtained + via a single batched ``yf.download()`` call for the top 10 tickers. Args: sector_key: Sector identifier (e.g., 'technology', 'healthcare') @@ -265,17 +269,44 @@ def get_industry_performance_yfinance( if top_companies is None or top_companies.empty: return f"No industry data found for sector '{sector_key}'" - + + # --- Batch-download price history for the top 10 tickers ---------- + tickers = list(top_companies.head(10).index) + price_returns: dict[str, dict[str, float | None]] = {} + try: + hist = yf.download( + tickers, period="1mo", auto_adjust=True, progress=False, threads=True, + ) + for tkr in tickers: + try: + if len(tickers) > 1: + closes = hist["Close"][tkr].dropna() + else: + closes = hist["Close"].dropna() + if closes.empty or len(closes) < 2: + continue + price_returns[tkr] = { + "1d": _safe_pct(closes, 1), + "1w": _safe_pct(closes, 5), + "1m": _safe_pct(closes, len(closes) - 1), + } + except Exception: + continue + except Exception: + pass # Fall through — table will show N/A for returns + # ------------------------------------------------------------------ + header = f"# Industry Performance: {sector_key.replace('-', ' ').title()}\n" header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" result_str = header - result_str += "| Company | Symbol | Rating | Market Weight |\n" - result_str += "|---------|--------|--------|---------------|\n" + result_str += "| Company | Symbol | Rating | Market Weight | 1-Day % | 1-Week % | 1-Month % |\n" + result_str += "|---------|--------|--------|---------------|---------|----------|-----------|\n" # top_companies has ticker as the DataFrame index (index.name == 'symbol') # Columns: name, rating, market weight - for symbol, row in top_companies.head(20).iterrows(): + # Display only the tickers we downloaded prices for to avoid N/A gaps + for symbol, row in top_companies.head(10).iterrows(): name = row.get('name', 'N/A') rating = row.get('rating', 'N/A') market_weight = row.get('market weight', None) @@ -283,7 +314,15 @@ def get_industry_performance_yfinance( name_short = name[:30] if isinstance(name, str) else str(name) weight_str = f"{market_weight:.2%}" if isinstance(market_weight, (int, float)) else "N/A" - result_str += f"| {name_short} | {symbol} | {rating} | {weight_str} |\n" + ret = price_returns.get(symbol, {}) + day_str = f"{ret['1d']:+.2f}%" if ret.get('1d') is not None else "N/A" + week_str = f"{ret['1w']:+.2f}%" if ret.get('1w') is not None else "N/A" + month_str = f"{ret['1m']:+.2f}%" if ret.get('1m') is not None else "N/A" + + result_str += ( + f"| {name_short} | {symbol} | {rating} | {weight_str}" + f" | {day_str} | {week_str} | {month_str} |\n" + ) return result_str