feat: improve Industry Deep Dive report quality with enriched data, sector routing, and tool-call nudge
* Initial plan * Improve Industry Deep Dive quality: enrich tool data, explicit sector keys, tool-call nudge - Enrich get_industry_performance_yfinance with 1-day/1-week/1-month price returns via batched yf.download() for top 10 tickers (Step 1) - Add VALID_SECTOR_KEYS, _DISPLAY_TO_KEY, _extract_top_sectors() to industry_deep_dive.py to pre-extract top sectors from Phase 1 report and inject them into the prompt (Step 2) - Add tool-call nudge to run_tool_loop: if first LLM response has no tool calls and is under 500 chars, re-prompt with explicit instruction to call tools (Step 3) - Update scanner_tools.py get_industry_performance docstring to list all valid sector keys (Step 4) - Add 15 unit tests covering _extract_top_sectors, tool_runner nudge, and enriched output (Step 5) Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com> * Address code review: move cols[3] access into try block for IndexError safety Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com> * fix: align display row count with download count in get_industry_performance_yfinance The enriched function downloads price data for top 10 tickers but displayed 20 rows, causing rows 11-20 to show N/A in all price columns. This broke test_industry_perf_falls_back_to_yfinance which asserts N/A count < 5. Now both download and display use head(10) for consistency. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com> Co-authored-by: Ahmet Guzererler <guzererler@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
8279295348
commit
1362781291
|
|
@ -0,0 +1,242 @@
|
|||
"""Tests for the Industry Deep Dive improvements:
|
||||
|
||||
1. _extract_top_sectors() parses sector performance reports correctly
|
||||
2. Enriched get_industry_performance_yfinance returns price columns
|
||||
3. run_tool_loop nudge triggers when first response is short & no tool calls
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
from tradingagents.agents.scanners.industry_deep_dive import (
|
||||
VALID_SECTOR_KEYS,
|
||||
_DISPLAY_TO_KEY,
|
||||
_extract_top_sectors,
|
||||
)
|
||||
from tradingagents.agents.utils.tool_runner import (
|
||||
run_tool_loop,
|
||||
MAX_TOOL_ROUNDS,
|
||||
MIN_REPORT_LENGTH,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_top_sectors tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SAMPLE_SECTOR_REPORT = """\
|
||||
# Sector Performance Overview
|
||||
# Data retrieved on: 2026-03-17 12:00:00
|
||||
|
||||
| Sector | 1-Day % | 1-Week % | 1-Month % | YTD % |
|
||||
|--------|---------|----------|-----------|-------|
|
||||
| Technology | +0.45% | +1.20% | +5.67% | +12.30% |
|
||||
| Healthcare | -0.12% | -0.50% | -2.10% | +3.40% |
|
||||
| Financials | +0.30% | +0.80% | +3.25% | +8.10% |
|
||||
| Energy | +1.10% | +2.50% | +7.80% | +15.20% |
|
||||
| Consumer Discretionary | -0.20% | -0.10% | -1.50% | +2.00% |
|
||||
| Consumer Staples | +0.05% | +0.30% | +0.90% | +4.50% |
|
||||
| Industrials | +0.25% | +0.60% | +2.80% | +6.70% |
|
||||
| Materials | +0.40% | +1.00% | +4.20% | +9.30% |
|
||||
| Real Estate | -0.35% | -0.80% | -3.40% | -1.20% |
|
||||
| Utilities | +0.10% | +0.20% | +1.10% | +5.60% |
|
||||
| Communication Services | +0.55% | +1.50% | +6.30% | +11.00% |
|
||||
"""
|
||||
|
||||
|
||||
class TestExtractTopSectors:
|
||||
"""Verify _extract_top_sectors parses the table correctly."""
|
||||
|
||||
def test_returns_top_3_by_absolute_1month(self):
|
||||
result = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=3)
|
||||
assert len(result) == 3
|
||||
# Energy (+7.80%), Communication Services (+6.30%), Technology (+5.67%)
|
||||
assert result[0] == "energy"
|
||||
assert result[1] == "communication-services"
|
||||
assert result[2] == "technology"
|
||||
|
||||
def test_returns_top_n_variable(self):
|
||||
result = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=5)
|
||||
assert len(result) == 5
|
||||
# All should be valid sector keys
|
||||
for key in result:
|
||||
assert key in VALID_SECTOR_KEYS, f"Invalid key: {key}"
|
||||
|
||||
def test_empty_report_returns_defaults(self):
|
||||
result = _extract_top_sectors("", top_n=3)
|
||||
assert result == VALID_SECTOR_KEYS[:3]
|
||||
|
||||
def test_none_report_returns_defaults(self):
|
||||
result = _extract_top_sectors(None, top_n=3)
|
||||
assert result == VALID_SECTOR_KEYS[:3]
|
||||
|
||||
def test_garbage_report_returns_defaults(self):
|
||||
result = _extract_top_sectors("not a table at all\njust random text", top_n=3)
|
||||
assert result == VALID_SECTOR_KEYS[:3]
|
||||
|
||||
def test_negative_returns_sorted_by_absolute_value(self):
|
||||
"""Sectors with large negative moves should rank high (big movers)."""
|
||||
report = """\
|
||||
| Sector | 1-Day % | 1-Week % | 1-Month % | YTD % |
|
||||
|--------|---------|----------|-----------|-------|
|
||||
| Technology | +0.10% | +0.20% | +1.00% | +2.00% |
|
||||
| Energy | -0.50% | -1.00% | -8.50% | -5.00% |
|
||||
| Healthcare | +0.05% | +0.10% | +0.50% | +1.00% |
|
||||
"""
|
||||
result = _extract_top_sectors(report, top_n=2)
|
||||
assert result[0] == "energy" # |-8.50| > |1.00|
|
||||
|
||||
def test_all_returned_keys_are_valid(self):
|
||||
result = _extract_top_sectors(SAMPLE_SECTOR_REPORT, top_n=11)
|
||||
for key in result:
|
||||
assert key in VALID_SECTOR_KEYS
|
||||
|
||||
def test_display_to_key_covers_all_sectors(self):
|
||||
"""Every sector name that appears in the ETF performance table
|
||||
should map to a valid key."""
|
||||
display_names = [
|
||||
"technology", "healthcare", "financials", "energy",
|
||||
"consumer discretionary", "consumer staples", "industrials",
|
||||
"materials", "real estate", "utilities", "communication services",
|
||||
]
|
||||
for name in display_names:
|
||||
assert name in _DISPLAY_TO_KEY, f"Missing mapping for '{name}'"
|
||||
assert _DISPLAY_TO_KEY[name] in VALID_SECTOR_KEYS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_tool_loop nudge tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToolLoopNudge:
|
||||
"""Verify the nudge mechanism in run_tool_loop."""
|
||||
|
||||
def _make_chain(self, responses):
|
||||
"""Create a mock chain that returns responses in sequence."""
|
||||
chain = MagicMock()
|
||||
chain.invoke = MagicMock(side_effect=responses)
|
||||
return chain
|
||||
|
||||
def _make_tool(self, name="my_tool"):
|
||||
tool = MagicMock()
|
||||
tool.name = name
|
||||
tool.invoke = MagicMock(return_value="tool result")
|
||||
return tool
|
||||
|
||||
def test_long_response_no_nudge(self):
|
||||
"""A long first response (no tool calls) should be returned as-is."""
|
||||
long_text = "A" * 600
|
||||
response = AIMessage(content=long_text, tool_calls=[])
|
||||
chain = self._make_chain([response])
|
||||
tool = self._make_tool()
|
||||
|
||||
result = run_tool_loop(chain, [], [tool])
|
||||
assert result.content == long_text
|
||||
assert chain.invoke.call_count == 1
|
||||
|
||||
def test_short_response_triggers_nudge(self):
|
||||
"""A short first response triggers a nudge, then the LLM is re-invoked."""
|
||||
short_resp = AIMessage(content="Brief.", tool_calls=[])
|
||||
long_resp = AIMessage(content="A" * 600, tool_calls=[])
|
||||
chain = self._make_chain([short_resp, long_resp])
|
||||
tool = self._make_tool()
|
||||
|
||||
result = run_tool_loop(chain, [], [tool])
|
||||
assert result.content == long_resp.content
|
||||
assert chain.invoke.call_count == 2
|
||||
|
||||
# The second invoke should have received a HumanMessage nudge
|
||||
second_call_messages = chain.invoke.call_args_list[1][0][0]
|
||||
nudge_msgs = [m for m in second_call_messages if isinstance(m, HumanMessage)]
|
||||
assert len(nudge_msgs) == 1
|
||||
assert "MUST call at least one tool" in nudge_msgs[0].content
|
||||
|
||||
def test_nudge_only_on_first_round(self):
|
||||
"""Nudge should NOT trigger after tools have been used."""
|
||||
# Round 1: LLM calls a tool
|
||||
tool_call_resp = AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"name": "my_tool", "args": {}, "id": "tc1"}],
|
||||
)
|
||||
# Round 2: LLM returns a short text — no nudge expected
|
||||
short_resp = AIMessage(content="Done.", tool_calls=[])
|
||||
chain = self._make_chain([tool_call_resp, short_resp])
|
||||
tool = self._make_tool()
|
||||
|
||||
result = run_tool_loop(chain, [], [tool])
|
||||
assert result.content == "Done."
|
||||
assert chain.invoke.call_count == 2
|
||||
|
||||
def test_tool_calls_execute_normally(self):
|
||||
"""Normal tool-calling flow should still work unchanged."""
|
||||
tool_call_resp = AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"name": "my_tool", "args": {"x": 1}, "id": "tc1"}],
|
||||
)
|
||||
final_resp = AIMessage(content="Final report" * 50, tool_calls=[])
|
||||
chain = self._make_chain([tool_call_resp, final_resp])
|
||||
tool = self._make_tool()
|
||||
|
||||
result = run_tool_loop(chain, [], [tool])
|
||||
tool.invoke.assert_called_once_with({"x": 1})
|
||||
assert "Final report" in result.content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enriched industry performance tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnrichedIndustryPerformance:
|
||||
"""Verify that get_industry_performance_yfinance now returns price columns.
|
||||
|
||||
These tests require network access to Yahoo Finance. If the host is not
|
||||
reachable (e.g. in sandboxed CI), they are automatically skipped.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _require_yahoo(self):
|
||||
import socket
|
||||
try:
|
||||
socket.setdefaulttimeout(3)
|
||||
socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect(
|
||||
("query2.finance.yahoo.com", 443)
|
||||
)
|
||||
except (socket.error, OSError):
|
||||
pytest.skip("Yahoo Finance not reachable")
|
||||
|
||||
def test_technology_has_price_columns(self):
|
||||
from tradingagents.dataflows.yfinance_scanner import (
|
||||
get_industry_performance_yfinance,
|
||||
)
|
||||
|
||||
result = get_industry_performance_yfinance("technology")
|
||||
assert "# Industry Performance: Technology" in result
|
||||
# New columns should be present in the header
|
||||
assert "1-Day %" in result
|
||||
assert "1-Week %" in result
|
||||
assert "1-Month %" in result
|
||||
|
||||
def test_table_has_seven_columns(self):
|
||||
from tradingagents.dataflows.yfinance_scanner import (
|
||||
get_industry_performance_yfinance,
|
||||
)
|
||||
|
||||
result = get_industry_performance_yfinance("technology")
|
||||
lines = result.strip().split("\n")
|
||||
# Find the header separator line
|
||||
sep_lines = [l for l in lines if l.startswith("|") and "---" in l]
|
||||
assert len(sep_lines) >= 1
|
||||
# Count columns in separator
|
||||
cols = [c.strip() for c in sep_lines[0].split("|")[1:-1]]
|
||||
assert len(cols) == 7, f"Expected 7 columns, got {len(cols)}: {cols}"
|
||||
|
||||
def test_healthcare_sector_key(self):
|
||||
from tradingagents.dataflows.yfinance_scanner import (
|
||||
get_industry_performance_yfinance,
|
||||
)
|
||||
|
||||
result = get_industry_performance_yfinance("healthcare")
|
||||
assert "Industry Performance: Healthcare" in result
|
||||
assert "1-Day %" in result
|
||||
|
|
@ -1,7 +1,85 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.utils.agent_utils import get_industry_performance, get_topic_news
|
||||
from tradingagents.agents.utils.tool_runner import run_tool_loop
|
||||
|
||||
# All valid sector keys accepted by yfinance Sector() and get_industry_performance.
|
||||
VALID_SECTOR_KEYS = [
|
||||
"technology",
|
||||
"healthcare",
|
||||
"financial-services",
|
||||
"energy",
|
||||
"consumer-cyclical",
|
||||
"consumer-defensive",
|
||||
"industrials",
|
||||
"basic-materials",
|
||||
"real-estate",
|
||||
"utilities",
|
||||
"communication-services",
|
||||
]
|
||||
|
||||
# Map display names used in the sector performance report to valid keys.
|
||||
_DISPLAY_TO_KEY = {
|
||||
"technology": "technology",
|
||||
"healthcare": "healthcare",
|
||||
"financials": "financial-services",
|
||||
"financial services": "financial-services",
|
||||
"energy": "energy",
|
||||
"consumer discretionary": "consumer-cyclical",
|
||||
"consumer staples": "consumer-defensive",
|
||||
"industrials": "industrials",
|
||||
"materials": "basic-materials",
|
||||
"basic materials": "basic-materials",
|
||||
"real estate": "real-estate",
|
||||
"utilities": "utilities",
|
||||
"communication services": "communication-services",
|
||||
}
|
||||
|
||||
|
||||
def _extract_top_sectors(sector_report: str, top_n: int = 3) -> list[str]:
|
||||
"""Parse the sector performance report and return the *top_n* sector keys
|
||||
ranked by absolute 1-month performance (largest absolute move first).
|
||||
|
||||
The sector performance table looks like:
|
||||
|
||||
| Technology | +0.45% | +1.20% | +5.67% | +12.3% |
|
||||
|
||||
We parse the 1-month column (index 3) and sort by absolute value.
|
||||
|
||||
Returns a list of valid sector keys (e.g. ``["technology", "energy"]``).
|
||||
Falls back to a sensible default if parsing fails.
|
||||
"""
|
||||
if not sector_report:
|
||||
return VALID_SECTOR_KEYS[:top_n]
|
||||
|
||||
rows: list[tuple[str, float]] = []
|
||||
for line in sector_report.split("\n"):
|
||||
if not line.startswith("|"):
|
||||
continue
|
||||
cols = [c.strip() for c in line.split("|")[1:-1]]
|
||||
if len(cols) < 4:
|
||||
continue
|
||||
sector_name = cols[0].lower()
|
||||
if sector_name in ("sector", "---", "") or "---" in sector_name:
|
||||
continue
|
||||
# Try to parse the 1-month column (index 3)
|
||||
try:
|
||||
month_str = cols[3].replace("%", "").replace("+", "").strip()
|
||||
month_val = float(month_str)
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
key = _DISPLAY_TO_KEY.get(sector_name)
|
||||
if key:
|
||||
rows.append((key, month_val))
|
||||
|
||||
if not rows:
|
||||
return VALID_SECTOR_KEYS[:top_n]
|
||||
|
||||
# Sort by absolute 1-month move (biggest mover first)
|
||||
rows.sort(key=lambda r: abs(r[1]), reverse=True)
|
||||
return [r[0] for r in rows[:top_n]]
|
||||
|
||||
|
||||
def create_industry_deep_dive(llm):
|
||||
def industry_deep_dive_node(state):
|
||||
|
|
@ -9,6 +87,9 @@ def create_industry_deep_dive(llm):
|
|||
|
||||
tools = [get_industry_performance, get_topic_news]
|
||||
|
||||
sector_report = state.get("sector_performance_report", "")
|
||||
top_sectors = _extract_top_sectors(sector_report, top_n=3)
|
||||
|
||||
# Inject Phase 1 context so the LLM can decide which sectors to drill into
|
||||
phase1_context = f"""## Phase 1 Scanner Reports (for your reference)
|
||||
|
||||
|
|
@ -19,20 +100,29 @@ def create_industry_deep_dive(llm):
|
|||
{state.get("market_movers_report", "Not available")}
|
||||
|
||||
### Sector Performance Report:
|
||||
{state.get("sector_performance_report", "Not available")}
|
||||
{sector_report or "Not available"}
|
||||
"""
|
||||
|
||||
sector_list_str = ", ".join(f"'{s}'" for s in top_sectors)
|
||||
all_keys_str = ", ".join(f"'{s}'" for s in VALID_SECTOR_KEYS)
|
||||
|
||||
system_message = (
|
||||
"You are a senior research analyst performing an industry deep dive. "
|
||||
"You have received reports from three parallel scanners (geopolitical, market movers, sector performance). "
|
||||
"Review these reports and identify the 2-3 most promising sectors/industries to investigate further. "
|
||||
"Use get_industry_performance to drill into those sectors and get_topic_news for sector-specific news. "
|
||||
"Write a detailed report covering: "
|
||||
"(1) Why these industries were selected, "
|
||||
"(2) Top companies within each industry and their recent performance, "
|
||||
"(3) Industry-specific catalysts and risks, "
|
||||
"(4) Cross-references between geopolitical events and sector opportunities."
|
||||
f"\n\n{phase1_context}"
|
||||
"You are a senior research analyst performing an industry deep dive.\n\n"
|
||||
"## Your task\n"
|
||||
"Based on the Phase 1 reports below, drill into the most interesting sectors "
|
||||
"using the tools provided and write a detailed analysis.\n\n"
|
||||
"## IMPORTANT — You MUST call tools before writing your report\n"
|
||||
f"1. Call get_industry_performance for EACH of these top sectors: {sector_list_str}\n"
|
||||
"2. Call get_topic_news for at least 2 sector-specific topics "
|
||||
"(e.g., 'semiconductor industry', 'renewable energy stocks').\n"
|
||||
"3. After receiving tool results, write your detailed report.\n\n"
|
||||
f"Valid sector_key values for get_industry_performance: {all_keys_str}\n\n"
|
||||
"## Report structure\n"
|
||||
"(1) Why these industries were selected (link to Phase 1 findings)\n"
|
||||
"(2) Top companies within each industry and their recent performance\n"
|
||||
"(3) Industry-specific catalysts and risks\n"
|
||||
"(4) Cross-references between geopolitical events and sector opportunities\n\n"
|
||||
f"{phase1_context}"
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
|
|
|
|||
|
|
@ -52,11 +52,15 @@ def get_industry_performance(
|
|||
) -> str:
|
||||
"""
|
||||
Get industry-level drill-down within a specific sector.
|
||||
Shows top companies and industries in the sector.
|
||||
Shows top companies with rating, market weight, and recent price performance
|
||||
(1-day, 1-week, 1-month returns).
|
||||
Uses the configured scanner_data vendor.
|
||||
|
||||
Args:
|
||||
sector_key (str): Sector identifier (e.g., 'technology', 'healthcare', 'energy')
|
||||
sector_key (str): Sector identifier. Must be one of:
|
||||
'technology', 'healthcare', 'financial-services', 'energy',
|
||||
'consumer-cyclical', 'consumer-defensive', 'industrials',
|
||||
'basic-materials', 'real-estate', 'utilities', 'communication-services'
|
||||
|
||||
Returns:
|
||||
str: Formatted table of top companies/industries in the sector with performance data
|
||||
|
|
|
|||
|
|
@ -9,42 +9,72 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
|
||||
# Most LLM tool-calling patterns resolve within 2-3 rounds;
|
||||
# 5 provides headroom for complex scenarios while preventing runaway loops.
|
||||
MAX_TOOL_ROUNDS = 5
|
||||
|
||||
# If the LLM's first response has no tool calls AND is shorter than this,
|
||||
# a nudge message is appended to encourage tool usage.
|
||||
MIN_REPORT_LENGTH = 500
|
||||
|
||||
|
||||
def run_tool_loop(
|
||||
chain,
|
||||
messages: List[Any],
|
||||
tools: List[Any],
|
||||
max_rounds: int = MAX_TOOL_ROUNDS,
|
||||
min_report_length: int = MIN_REPORT_LENGTH,
|
||||
) -> AIMessage:
|
||||
"""Invoke *chain* in a loop, executing any tool calls until the LLM
|
||||
produces a final text response (i.e. no more tool_calls).
|
||||
|
||||
If the very first LLM response contains no tool calls **and** the text
|
||||
is shorter than *min_report_length*, the loop appends a nudge message
|
||||
asking the LLM to call tools first, then re-invokes once before
|
||||
accepting the response. This prevents under-powered models from
|
||||
skipping tool use when overwhelmed by long context.
|
||||
|
||||
Args:
|
||||
chain: A LangChain runnable (prompt | llm.bind_tools(tools)).
|
||||
messages: The initial list of messages to send.
|
||||
tools: List of LangChain tool objects (must match the tools bound to the LLM).
|
||||
max_rounds: Maximum number of tool-calling rounds before forcing a stop.
|
||||
min_report_length: Minimum acceptable length (chars) of a text-only
|
||||
first response. Shorter responses trigger a nudge to use tools.
|
||||
|
||||
Returns:
|
||||
The final AIMessage with a text ``content`` (report).
|
||||
"""
|
||||
tool_map = {t.name: t for t in tools}
|
||||
current_messages = list(messages)
|
||||
first_round = True
|
||||
|
||||
for _ in range(max_rounds):
|
||||
result: AIMessage = chain.invoke(current_messages)
|
||||
current_messages.append(result)
|
||||
|
||||
if not result.tool_calls:
|
||||
# Nudge: if the LLM skipped tools on its first turn and the
|
||||
# response is suspiciously short, ask it to try again with tools.
|
||||
if first_round and len(result.content or "") < min_report_length:
|
||||
tool_names = ", ".join(tool_map.keys())
|
||||
nudge = (
|
||||
"Your response was too brief. You MUST call at least one tool "
|
||||
f"({tool_names}) before writing your final report. "
|
||||
"Please call the tools now."
|
||||
)
|
||||
current_messages.append(
|
||||
HumanMessage(content=nudge)
|
||||
)
|
||||
first_round = False
|
||||
continue
|
||||
return result
|
||||
|
||||
first_round = False
|
||||
|
||||
# Execute each requested tool call and append ToolMessages
|
||||
for tc in result.tool_calls:
|
||||
tool_name = tc["name"]
|
||||
|
|
|
|||
|
|
@ -249,6 +249,10 @@ def get_industry_performance_yfinance(
|
|||
) -> str:
|
||||
"""
|
||||
Get industry-level drill-down within a sector.
|
||||
|
||||
Returns top companies with metadata (rating, market weight) **plus**
|
||||
recent price performance (1-day, 1-week, 1-month returns) obtained
|
||||
via a single batched ``yf.download()`` call for the top 10 tickers.
|
||||
|
||||
Args:
|
||||
sector_key: Sector identifier (e.g., 'technology', 'healthcare')
|
||||
|
|
@ -265,17 +269,44 @@ def get_industry_performance_yfinance(
|
|||
|
||||
if top_companies is None or top_companies.empty:
|
||||
return f"No industry data found for sector '{sector_key}'"
|
||||
|
||||
|
||||
# --- Batch-download price history for the top 10 tickers ----------
|
||||
tickers = list(top_companies.head(10).index)
|
||||
price_returns: dict[str, dict[str, float | None]] = {}
|
||||
try:
|
||||
hist = yf.download(
|
||||
tickers, period="1mo", auto_adjust=True, progress=False, threads=True,
|
||||
)
|
||||
for tkr in tickers:
|
||||
try:
|
||||
if len(tickers) > 1:
|
||||
closes = hist["Close"][tkr].dropna()
|
||||
else:
|
||||
closes = hist["Close"].dropna()
|
||||
if closes.empty or len(closes) < 2:
|
||||
continue
|
||||
price_returns[tkr] = {
|
||||
"1d": _safe_pct(closes, 1),
|
||||
"1w": _safe_pct(closes, 5),
|
||||
"1m": _safe_pct(closes, len(closes) - 1),
|
||||
}
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
pass # Fall through — table will show N/A for returns
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
header = f"# Industry Performance: {sector_key.replace('-', ' ').title()}\n"
|
||||
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||
|
||||
result_str = header
|
||||
result_str += "| Company | Symbol | Rating | Market Weight |\n"
|
||||
result_str += "|---------|--------|--------|---------------|\n"
|
||||
result_str += "| Company | Symbol | Rating | Market Weight | 1-Day % | 1-Week % | 1-Month % |\n"
|
||||
result_str += "|---------|--------|--------|---------------|---------|----------|-----------|\n"
|
||||
|
||||
# top_companies has ticker as the DataFrame index (index.name == 'symbol')
|
||||
# Columns: name, rating, market weight
|
||||
for symbol, row in top_companies.head(20).iterrows():
|
||||
# Display only the tickers we downloaded prices for to avoid N/A gaps
|
||||
for symbol, row in top_companies.head(10).iterrows():
|
||||
name = row.get('name', 'N/A')
|
||||
rating = row.get('rating', 'N/A')
|
||||
market_weight = row.get('market weight', None)
|
||||
|
|
@ -283,7 +314,15 @@ def get_industry_performance_yfinance(
|
|||
name_short = name[:30] if isinstance(name, str) else str(name)
|
||||
weight_str = f"{market_weight:.2%}" if isinstance(market_weight, (int, float)) else "N/A"
|
||||
|
||||
result_str += f"| {name_short} | {symbol} | {rating} | {weight_str} |\n"
|
||||
ret = price_returns.get(symbol, {})
|
||||
day_str = f"{ret['1d']:+.2f}%" if ret.get('1d') is not None else "N/A"
|
||||
week_str = f"{ret['1w']:+.2f}%" if ret.get('1w') is not None else "N/A"
|
||||
month_str = f"{ret['1m']:+.2f}%" if ret.get('1m') is not None else "N/A"
|
||||
|
||||
result_str += (
|
||||
f"| {name_short} | {symbol} | {rating} | {weight_str}"
|
||||
f" | {day_str} | {week_str} | {month_str} |\n"
|
||||
)
|
||||
|
||||
return result_str
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue