test: add tests and parsing logic for text formats in _extract_top_sectors

Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com>
This commit is contained in:
google-labs-jules[bot] 2026-03-21 08:28:35 +00:00
parent 5799bb3f00
commit 644ce57b9c
2 changed files with 79 additions and 5 deletions

View File

@ -100,6 +100,41 @@ class TestExtractTopSectors:
assert name in _DISPLAY_TO_KEY, f"Missing mapping for '{name}'" assert name in _DISPLAY_TO_KEY, f"Missing mapping for '{name}'"
assert _DISPLAY_TO_KEY[name] in VALID_SECTOR_KEYS assert _DISPLAY_TO_KEY[name] in VALID_SECTOR_KEYS
def test_extracts_from_bullet_points(self):
report = """
Here are the top sectors:
- Technology: The technology sector has been performing well.
- Healthcare: Innovations in biotech are driving growth.
- Energy - showing strong recovery.
- Utilities
"""
result = _extract_top_sectors(report, top_n=3)
assert result == ["technology", "healthcare", "energy"]
def test_extracts_from_numbered_lists(self):
report = """
Top performers this month:
1. Financial Services: Interest rates are up.
2. Consumer Staples - steady growth.
3. Real Estate: Rebounding.
"""
result = _extract_top_sectors(report, top_n=2)
assert result == ["financial-services", "consumer-defensive"]
def test_extracts_from_plain_text(self):
report = "We recommend looking into technology, energy, and materials this quarter."
result = _extract_top_sectors(report, top_n=3)
assert result == ["technology", "energy", "basic-materials"]
def test_extracts_with_mixed_capitalization_and_whitespace(self):
report = """
* ComMunication SeRvices : strong user growth.
* inDuStrials: Infrastructure spending is up.
* basiC MaTerials - high demand.
"""
result = _extract_top_sectors(report, top_n=3)
assert result == ["communication-services", "industrials", "basic-materials"]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# run_tool_loop nudge tests # run_tool_loop nudge tests

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import re
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.agent_utils import get_industry_performance, get_topic_news from tradingagents.agents.utils.agent_utils import get_industry_performance, get_topic_news
from tradingagents.agents.utils.tool_runner import run_tool_loop from tradingagents.agents.utils.tool_runner import run_tool_loop
@ -46,6 +47,8 @@ def _extract_top_sectors(sector_report: str, top_n: int = 3) -> list[str]:
| Technology | +0.45% | +1.20% | +5.67% | +12.3% | | Technology | +0.45% | +1.20% | +5.67% | +12.3% |
We parse the 1-month column (index 3) and sort by absolute value. We parse the 1-month column (index 3) and sort by absolute value.
If the report is not a table, it attempts to parse list formats
(bullet points, numbered lists, or plain text).
Returns a list of valid sector keys (e.g. ``["technology", "energy"]``). Returns a list of valid sector keys (e.g. ``["technology", "energy"]``).
Falls back to a sensible default if parsing fails. Falls back to a sensible default if parsing fails.
@ -73,13 +76,49 @@ def _extract_top_sectors(sector_report: str, top_n: int = 3) -> list[str]:
if key: if key:
rows.append((key, month_val)) rows.append((key, month_val))
if not rows: if rows:
return VALID_SECTOR_KEYS[:top_n]
# Sort by absolute 1-month move (biggest mover first) # Sort by absolute 1-month move (biggest mover first)
rows.sort(key=lambda r: abs(r[1]), reverse=True) rows.sort(key=lambda r: abs(r[1]), reverse=True)
return [r[0] for r in rows[:top_n]] return [r[0] for r in rows[:top_n]]
# Fallback to parsing text formats: bullet points, numbered lists, plain text
sectors = []
lines = sector_report.split("\n")
for line in lines:
line_clean = line.strip()
# Regex to match list formats: e.g., "- Technology:", "1. Energy -", "* Healthcare:"
match = re.match(r'^(?:-|\*|\d+\.)?\s*([a-zA-Z\s]+?)\s*[:\-]', line_clean)
if match:
sector_name = match.group(1).strip().lower()
key = _DISPLAY_TO_KEY.get(sector_name)
if key and key not in sectors:
sectors.append(key)
if len(sectors) == top_n:
return sectors
# Final fallback for plain text search
if not sectors:
report_lower = sector_report.lower()
found_sectors = []
for disp_name, key in _DISPLAY_TO_KEY.items():
idx = report_lower.find(disp_name)
if idx != -1:
found_sectors.append((idx, key))
# Sort by appearance order
found_sectors.sort(key=lambda x: x[0])
for _, key in found_sectors:
if key not in sectors:
sectors.append(key)
if len(sectors) == top_n:
return sectors
if sectors:
return sectors[:top_n]
return VALID_SECTOR_KEYS[:top_n]
def create_industry_deep_dive(llm): def create_industry_deep_dive(llm):
def industry_deep_dive_node(state): def industry_deep_dive_node(state):