From 644ce57b9ce58059cf6169d92841a95f96d14ea4 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 08:28:35 +0000 Subject: [PATCH] test: add tests and parsing logic for text formats in _extract_top_sectors Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com> --- tests/unit/test_industry_deep_dive.py | 35 +++++++++++++ .../agents/scanners/industry_deep_dive.py | 49 +++++++++++++++++-- 2 files changed, 79 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_industry_deep_dive.py b/tests/unit/test_industry_deep_dive.py index ffed3339..87d93b8d 100644 --- a/tests/unit/test_industry_deep_dive.py +++ b/tests/unit/test_industry_deep_dive.py @@ -100,6 +100,41 @@ class TestExtractTopSectors: assert name in _DISPLAY_TO_KEY, f"Missing mapping for '{name}'" assert _DISPLAY_TO_KEY[name] in VALID_SECTOR_KEYS + def test_extracts_from_bullet_points(self): + report = """ + Here are the top sectors: + - Technology: The technology sector has been performing well. + - Healthcare: Innovations in biotech are driving growth. + - Energy - showing strong recovery. + - Utilities + """ + result = _extract_top_sectors(report, top_n=3) + assert result == ["technology", "healthcare", "energy"] + + def test_extracts_from_numbered_lists(self): + report = """ + Top performers this month: + 1. Financial Services: Interest rates are up. + 2. Consumer Staples - steady growth. + 3. Real Estate: Rebounding. + """ + result = _extract_top_sectors(report, top_n=2) + assert result == ["financial-services", "consumer-defensive"] + + def test_extracts_from_plain_text(self): + report = "We recommend looking into technology, energy, and materials this quarter." + result = _extract_top_sectors(report, top_n=3) + assert result == ["technology", "energy", "basic-materials"] + + def test_extracts_with_mixed_capitalization_and_whitespace(self): + report = """ + * ComMunication SeRvices : strong user growth. + * inDuStrials: Infrastructure spending is up. + * basiC MaTerials - high demand. + """ + result = _extract_top_sectors(report, top_n=3) + assert result == ["communication-services", "industrials", "basic-materials"] + # --------------------------------------------------------------------------- # run_tool_loop nudge tests diff --git a/tradingagents/agents/scanners/industry_deep_dive.py b/tradingagents/agents/scanners/industry_deep_dive.py index 3b15cf4f..8821e840 100644 --- a/tradingagents/agents/scanners/industry_deep_dive.py +++ b/tradingagents/agents/scanners/industry_deep_dive.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from tradingagents.agents.utils.agent_utils import get_industry_performance, get_topic_news from tradingagents.agents.utils.tool_runner import run_tool_loop @@ -46,6 +47,8 @@ def _extract_top_sectors(sector_report: str, top_n: int = 3) -> list[str]: | Technology | +0.45% | +1.20% | +5.67% | +12.3% | We parse the 1-month column (index 3) and sort by absolute value. + If the report is not a table, it attempts to parse list formats + (bullet points, numbered lists, or plain text). Returns a list of valid sector keys (e.g. ``["technology", "energy"]``). Falls back to a sensible default if parsing fails. @@ -73,12 +76,48 @@ def _extract_top_sectors(sector_report: str, top_n: int = 3) -> list[str]: if key: rows.append((key, month_val)) - if not rows: - return VALID_SECTOR_KEYS[:top_n] + if rows: + # Sort by absolute 1-month move (biggest mover first) + rows.sort(key=lambda r: abs(r[1]), reverse=True) + return [r[0] for r in rows[:top_n]] - # Sort by absolute 1-month move (biggest mover first) - rows.sort(key=lambda r: abs(r[1]), reverse=True) - return [r[0] for r in rows[:top_n]] + # Fallback to parsing text formats: bullet points, numbered lists, plain text + sectors = [] + lines = sector_report.split("\n") + for line in lines: + line_clean = line.strip() + # Regex to match list formats: e.g., "- Technology:", "1. Energy -", "* Healthcare:" + match = re.match(r'^(?:-|\*|\d+\.)?\s*([a-zA-Z\s]+?)\s*[:\-]', line_clean) + if match: + sector_name = match.group(1).strip().lower() + key = _DISPLAY_TO_KEY.get(sector_name) + if key and key not in sectors: + sectors.append(key) + if len(sectors) == top_n: + return sectors + + # Final fallback for plain text search + if not sectors: + report_lower = sector_report.lower() + + found_sectors = [] + for disp_name, key in _DISPLAY_TO_KEY.items(): + idx = report_lower.find(disp_name) + if idx != -1: + found_sectors.append((idx, key)) + + # Sort by appearance order + found_sectors.sort(key=lambda x: x[0]) + for _, key in found_sectors: + if key not in sectors: + sectors.append(key) + if len(sectors) == top_n: + return sectors + + if sectors: + return sectors[:top_n] + + return VALID_SECTOR_KEYS[:top_n] def create_industry_deep_dive(llm):