test: add tests and parsing logic for text formats in _extract_top_sectors
Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com>
This commit is contained in:
parent
5799bb3f00
commit
644ce57b9c
|
|
@ -100,6 +100,41 @@ class TestExtractTopSectors:
|
||||||
assert name in _DISPLAY_TO_KEY, f"Missing mapping for '{name}'"
|
assert name in _DISPLAY_TO_KEY, f"Missing mapping for '{name}'"
|
||||||
assert _DISPLAY_TO_KEY[name] in VALID_SECTOR_KEYS
|
assert _DISPLAY_TO_KEY[name] in VALID_SECTOR_KEYS
|
||||||
|
|
||||||
|
def test_extracts_from_bullet_points(self):
    """Bullet-point reports yield sector keys in order of appearance.

    The report lists four sectors but ``top_n=3``, so only the first
    three recognised names are expected. Both ``Name:`` and ``Name -``
    delimiters are exercised.
    """
    # Leading indentation inside the literal is harmless: the list-format
    # parser strips every line before matching it.
    report = """
    Here are the top sectors:
    - Technology: The technology sector has been performing well.
    - Healthcare: Innovations in biotech are driving growth.
    - Energy - showing strong recovery.
    - Utilities
    """
    result = _extract_top_sectors(report, top_n=3)
    assert result == ["technology", "healthcare", "energy"]
|
||||||
|
|
||||||
|
def test_extracts_from_numbered_lists(self):
    """Numbered-list reports are parsed and truncated to ``top_n``.

    Three sectors are listed but only two are requested, so the third
    entry ("Real Estate") must not appear in the result. The display
    names are expected to resolve to their canonical sector keys.
    """
    # Leading indentation inside the literal is harmless: the list-format
    # parser strips every line before matching it.
    report = """
    Top performers this month:
    1. Financial Services: Interest rates are up.
    2. Consumer Staples - steady growth.
    3. Real Estate: Rebounding.
    """
    result = _extract_top_sectors(report, top_n=2)
    assert result == ["financial-services", "consumer-defensive"]
|
||||||
|
|
||||||
|
def test_extracts_from_plain_text(self):
    """Sector names embedded in a plain sentence are found in reading order.

    The sentence contains no table rows or list markers, so this covers
    the last-resort free-text search. "materials" is expected to resolve
    to the ``basic-materials`` key.
    """
    report = "We recommend looking into technology, energy, and materials this quarter."
    result = _extract_top_sectors(report, top_n=3)
    assert result == ["technology", "energy", "basic-materials"]
|
||||||
|
|
||||||
|
def test_extracts_with_mixed_capitalization_and_whitespace(self):
    """Sector names match regardless of letter case or padding.

    Uses ``*`` bullets with erratic capitalisation and a space before
    the ``:`` delimiter; every entry should still resolve to its
    canonical sector key.
    """
    # Leading indentation inside the literal is harmless: the list-format
    # parser strips every line before matching it.
    report = """
    * ComMunication SeRvices : strong user growth.
    * inDuStrials: Infrastructure spending is up.
    * basiC MaTerials - high demand.
    """
    result = _extract_top_sectors(report, top_n=3)
    assert result == ["communication-services", "industrials", "basic-materials"]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# run_tool_loop nudge tests
|
# run_tool_loop nudge tests
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
from tradingagents.agents.utils.agent_utils import get_industry_performance, get_topic_news
|
from tradingagents.agents.utils.agent_utils import get_industry_performance, get_topic_news
|
||||||
from tradingagents.agents.utils.tool_runner import run_tool_loop
|
from tradingagents.agents.utils.tool_runner import run_tool_loop
|
||||||
|
|
@ -46,6 +47,8 @@ def _extract_top_sectors(sector_report: str, top_n: int = 3) -> list[str]:
|
||||||
| Technology | +0.45% | +1.20% | +5.67% | +12.3% |
|
| Technology | +0.45% | +1.20% | +5.67% | +12.3% |
|
||||||
|
|
||||||
We parse the 1-month column (index 3) and sort by absolute value.
|
We parse the 1-month column (index 3) and sort by absolute value.
|
||||||
|
If the report is not a table, it attempts to parse list formats
|
||||||
|
(bullet points, numbered lists, or plain text).
|
||||||
|
|
||||||
Returns a list of valid sector keys (e.g. ``["technology", "energy"]``).
|
Returns a list of valid sector keys (e.g. ``["technology", "energy"]``).
|
||||||
Falls back to a sensible default if parsing fails.
|
Falls back to a sensible default if parsing fails.
|
||||||
|
|
@ -73,12 +76,48 @@ def _extract_top_sectors(sector_report: str, top_n: int = 3) -> list[str]:
|
||||||
if key:
|
if key:
|
||||||
rows.append((key, month_val))
|
rows.append((key, month_val))
|
||||||
|
|
||||||
if not rows:
|
if rows:
|
||||||
return VALID_SECTOR_KEYS[:top_n]
|
# Sort by absolute 1-month move (biggest mover first)
|
||||||
|
rows.sort(key=lambda r: abs(r[1]), reverse=True)
|
||||||
|
return [r[0] for r in rows[:top_n]]
|
||||||
|
|
||||||
# Sort by absolute 1-month move (biggest mover first)
|
# Fallback to parsing text formats: bullet points, numbered lists, plain text
|
||||||
rows.sort(key=lambda r: abs(r[1]), reverse=True)
|
sectors = []
|
||||||
return [r[0] for r in rows[:top_n]]
|
lines = sector_report.split("\n")
|
||||||
|
for line in lines:
|
||||||
|
line_clean = line.strip()
|
||||||
|
# Regex to match list formats: e.g., "- Technology:", "1. Energy -", "* Healthcare:"
|
||||||
|
match = re.match(r'^(?:-|\*|\d+\.)?\s*([a-zA-Z\s]+?)\s*[:\-]', line_clean)
|
||||||
|
if match:
|
||||||
|
sector_name = match.group(1).strip().lower()
|
||||||
|
key = _DISPLAY_TO_KEY.get(sector_name)
|
||||||
|
if key and key not in sectors:
|
||||||
|
sectors.append(key)
|
||||||
|
if len(sectors) == top_n:
|
||||||
|
return sectors
|
||||||
|
|
||||||
|
# Final fallback for plain text search
|
||||||
|
if not sectors:
|
||||||
|
report_lower = sector_report.lower()
|
||||||
|
|
||||||
|
found_sectors = []
|
||||||
|
for disp_name, key in _DISPLAY_TO_KEY.items():
|
||||||
|
idx = report_lower.find(disp_name)
|
||||||
|
if idx != -1:
|
||||||
|
found_sectors.append((idx, key))
|
||||||
|
|
||||||
|
# Sort by appearance order
|
||||||
|
found_sectors.sort(key=lambda x: x[0])
|
||||||
|
for _, key in found_sectors:
|
||||||
|
if key not in sectors:
|
||||||
|
sectors.append(key)
|
||||||
|
if len(sectors) == top_n:
|
||||||
|
return sectors
|
||||||
|
|
||||||
|
if sectors:
|
||||||
|
return sectors[:top_n]
|
||||||
|
|
||||||
|
return VALID_SECTOR_KEYS[:top_n]
|
||||||
|
|
||||||
|
|
||||||
def create_industry_deep_dive(llm):
|
def create_industry_deep_dive(llm):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue