test: add tests and parsing logic for text formats in _extract_top_sectors
Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com>
This commit is contained in:
parent
5799bb3f00
commit
644ce57b9c
|
|
@ -100,6 +100,41 @@ class TestExtractTopSectors:
|
|||
assert name in _DISPLAY_TO_KEY, f"Missing mapping for '{name}'"
|
||||
assert _DISPLAY_TO_KEY[name] in VALID_SECTOR_KEYS
|
||||
|
||||
def test_extracts_from_bullet_points(self):
    """Bulleted sector names come back in listed order, capped at ``top_n``."""
    bulleted_report = """
    Here are the top sectors:
    - Technology: The technology sector has been performing well.
    - Healthcare: Innovations in biotech are driving growth.
    - Energy - showing strong recovery.
    - Utilities
    """
    # Four sectors are listed but only three are requested, so the
    # trailing "Utilities" bullet must not show up in the result.
    expected_keys = ["technology", "healthcare", "energy"]
    assert _extract_top_sectors(bulleted_report, top_n=3) == expected_keys
|
||||
|
||||
def test_extracts_from_numbered_lists(self):
    """Numbered list entries are parsed and mapped to their sector keys."""
    numbered_report = """
    Top performers this month:
    1. Financial Services: Interest rates are up.
    2. Consumer Staples - steady growth.
    3. Real Estate: Rebounding.
    """
    # Only the first two entries are requested; "Consumer Staples" is
    # expected to resolve to the "consumer-defensive" key.
    expected_keys = ["financial-services", "consumer-defensive"]
    assert _extract_top_sectors(numbered_report, top_n=2) == expected_keys
|
||||
|
||||
def test_extracts_from_plain_text(self):
    """Sector names mentioned in a plain sentence are found in reading order."""
    sentence = "We recommend looking into technology, energy, and materials this quarter."
    # "materials" is expected to resolve to the "basic-materials" key.
    expected_keys = ["technology", "energy", "basic-materials"]
    assert _extract_top_sectors(sentence, top_n=3) == expected_keys
|
||||
|
||||
def test_extracts_with_mixed_capitalization_and_whitespace(self):
    """Odd casing and stray spacing around names still yield the right keys."""
    messy_report = """
    * ComMunication SeRvices : strong user growth.
    * inDuStrials: Infrastructure spending is up.
    * basiC MaTerials - high demand.
    """
    # Each bullet uses irregular capitalisation; the first also has a
    # space before its colon delimiter.
    expected_keys = ["communication-services", "industrials", "basic-materials"]
    assert _extract_top_sectors(messy_report, top_n=3) == expected_keys
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_tool_loop nudge tests
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.utils.agent_utils import get_industry_performance, get_topic_news
|
||||
from tradingagents.agents.utils.tool_runner import run_tool_loop
|
||||
|
|
@ -46,6 +47,8 @@ def _extract_top_sectors(sector_report: str, top_n: int = 3) -> list[str]:
|
|||
| Technology | +0.45% | +1.20% | +5.67% | +12.3% |
|
||||
|
||||
We parse the 1-month column (index 3) and sort by absolute value.
|
||||
If the report is not a table, it attempts to parse list formats
|
||||
(bullet points, numbered lists, or plain text).
|
||||
|
||||
Returns a list of valid sector keys (e.g. ``["technology", "energy"]``).
|
||||
Falls back to a sensible default if parsing fails.
|
||||
|
|
@ -73,13 +76,49 @@ def _extract_top_sectors(sector_report: str, top_n: int = 3) -> list[str]:
|
|||
if key:
|
||||
rows.append((key, month_val))
|
||||
|
||||
if not rows:
|
||||
return VALID_SECTOR_KEYS[:top_n]
|
||||
|
||||
if rows:
|
||||
# Sort by absolute 1-month move (biggest mover first)
|
||||
rows.sort(key=lambda r: abs(r[1]), reverse=True)
|
||||
return [r[0] for r in rows[:top_n]]
|
||||
|
||||
# Fallback to parsing text formats: bullet points, numbered lists, plain text
|
||||
sectors = []
|
||||
lines = sector_report.split("\n")
|
||||
for line in lines:
|
||||
line_clean = line.strip()
|
||||
# Regex to match list formats: e.g., "- Technology:", "1. Energy -", "* Healthcare:"
|
||||
match = re.match(r'^(?:-|\*|\d+\.)?\s*([a-zA-Z\s]+?)\s*[:\-]', line_clean)
|
||||
if match:
|
||||
sector_name = match.group(1).strip().lower()
|
||||
key = _DISPLAY_TO_KEY.get(sector_name)
|
||||
if key and key not in sectors:
|
||||
sectors.append(key)
|
||||
if len(sectors) == top_n:
|
||||
return sectors
|
||||
|
||||
# Final fallback for plain text search
|
||||
if not sectors:
|
||||
report_lower = sector_report.lower()
|
||||
|
||||
found_sectors = []
|
||||
for disp_name, key in _DISPLAY_TO_KEY.items():
|
||||
idx = report_lower.find(disp_name)
|
||||
if idx != -1:
|
||||
found_sectors.append((idx, key))
|
||||
|
||||
# Sort by appearance order
|
||||
found_sectors.sort(key=lambda x: x[0])
|
||||
for _, key in found_sectors:
|
||||
if key not in sectors:
|
||||
sectors.append(key)
|
||||
if len(sectors) == top_n:
|
||||
return sectors
|
||||
|
||||
if sectors:
|
||||
return sectors[:top_n]
|
||||
|
||||
return VALID_SECTOR_KEYS[:top_n]
|
||||
|
||||
|
||||
def create_industry_deep_dive(llm):
|
||||
def industry_deep_dive_node(state):
|
||||
|
|
|
|||
Loading…
Reference in New Issue