Merge pull request #92 from aguzererler/copilot/fix-report-saving-in-runs
Fix report persistence, run status tracking, auto-mode stock sourcing, and portfolio context loading
This commit is contained in:
commit
36a6b17a22
|
|
@ -1,5 +1,5 @@
|
|||
from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any, List, AsyncGenerator
|
||||
import logging
|
||||
import uuid
|
||||
import time
|
||||
|
|
@ -13,6 +13,21 @@ router = APIRouter(prefix="/api/run", tags=["runs"])
|
|||
|
||||
engine = LangGraphEngine()
|
||||
|
||||
|
||||
async def _run_and_store(run_id: str, gen: AsyncGenerator[Dict[str, Any], None]) -> None:
|
||||
"""Drive an engine generator, updating run status and caching events."""
|
||||
runs[run_id]["status"] = "running"
|
||||
runs[run_id]["events"] = []
|
||||
try:
|
||||
async for event in gen:
|
||||
runs[run_id]["events"].append(event)
|
||||
runs[run_id]["status"] = "completed"
|
||||
except Exception as exc:
|
||||
runs[run_id]["status"] = "failed"
|
||||
runs[run_id]["error"] = str(exc)
|
||||
logger.exception("Run failed run=%s", run_id)
|
||||
|
||||
|
||||
@router.post("/scan")
|
||||
async def trigger_scan(
|
||||
background_tasks: BackgroundTasks,
|
||||
|
|
@ -29,7 +44,7 @@ async def trigger_scan(
|
|||
"params": params or {}
|
||||
}
|
||||
logger.info("Queued SCAN run=%s user=%s", run_id, user["user_id"])
|
||||
background_tasks.add_task(engine.run_scan, run_id, params or {})
|
||||
background_tasks.add_task(_run_and_store, run_id, engine.run_scan(run_id, params or {}))
|
||||
return {"run_id": run_id, "status": "queued"}
|
||||
|
||||
@router.post("/pipeline")
|
||||
|
|
@ -48,7 +63,7 @@ async def trigger_pipeline(
|
|||
"params": params or {}
|
||||
}
|
||||
logger.info("Queued PIPELINE run=%s user=%s", run_id, user["user_id"])
|
||||
background_tasks.add_task(engine.run_pipeline, run_id, params or {})
|
||||
background_tasks.add_task(_run_and_store, run_id, engine.run_pipeline(run_id, params or {}))
|
||||
return {"run_id": run_id, "status": "queued"}
|
||||
|
||||
@router.post("/portfolio")
|
||||
|
|
@ -67,7 +82,7 @@ async def trigger_portfolio(
|
|||
"params": params or {}
|
||||
}
|
||||
logger.info("Queued PORTFOLIO run=%s user=%s", run_id, user["user_id"])
|
||||
background_tasks.add_task(engine.run_portfolio, run_id, params or {})
|
||||
background_tasks.add_task(_run_and_store, run_id, engine.run_portfolio(run_id, params or {}))
|
||||
return {"run_id": run_id, "status": "queued"}
|
||||
|
||||
@router.post("/auto")
|
||||
|
|
@ -86,7 +101,7 @@ async def trigger_auto(
|
|||
"params": params or {}
|
||||
}
|
||||
logger.info("Queued AUTO run=%s user=%s", run_id, user["user_id"])
|
||||
background_tasks.add_task(engine.run_auto, run_id, params or {})
|
||||
background_tasks.add_task(_run_and_store, run_id, engine.run_auto(run_id, params or {}))
|
||||
return {"run_id": run_id, "status": "queued"}
|
||||
|
||||
@router.get("/")
|
||||
|
|
|
|||
|
|
@ -12,6 +12,9 @@ logger = logging.getLogger("agent_os.websocket")
|
|||
|
||||
router = APIRouter(prefix="/ws", tags=["websocket"])
|
||||
|
||||
# Polling interval when streaming cached events from a background-task-driven run
|
||||
_EVENT_POLL_INTERVAL_SECONDS = 0.05
|
||||
|
||||
engine = LangGraphEngine()
|
||||
|
||||
@router.websocket("/stream/{run_id}")
|
||||
|
|
@ -33,6 +36,35 @@ async def websocket_endpoint(
|
|||
params = run_info.get("params", {})
|
||||
|
||||
try:
|
||||
status = run_info.get("status", "queued")
|
||||
|
||||
if status in ("running", "completed", "failed"):
|
||||
# Background task is already executing (or finished) — stream its cached events
|
||||
# then wait for completion if still running.
|
||||
logger.info(
|
||||
"WebSocket streaming from cache run=%s status=%s", run_id, status
|
||||
)
|
||||
sent = 0
|
||||
while True:
|
||||
cached = run_info.get("events") or []
|
||||
while sent < len(cached):
|
||||
payload = cached[sent]
|
||||
if "timestamp" not in payload:
|
||||
payload["timestamp"] = time.strftime("%H:%M:%S")
|
||||
await websocket.send_json(payload)
|
||||
sent += 1
|
||||
current_status = run_info.get("status")
|
||||
if current_status in ("completed", "failed"):
|
||||
break
|
||||
# Yield to the event loop so the background task can produce more events
|
||||
await asyncio.sleep(_EVENT_POLL_INTERVAL_SECONDS)
|
||||
|
||||
if run_info.get("status") == "failed":
|
||||
await websocket.send_json(
|
||||
{"type": "system", "message": f"Run failed: {run_info.get('error', 'unknown error')}"}
|
||||
)
|
||||
else:
|
||||
# status == "queued" — WebSocket is the executor (background task didn't start yet)
|
||||
stream_gen = None
|
||||
if run_type == "scan":
|
||||
stream_gen = engine.run_scan(run_id, params)
|
||||
|
|
@ -44,8 +76,11 @@ async def websocket_endpoint(
|
|||
stream_gen = engine.run_auto(run_id, params)
|
||||
|
||||
if stream_gen:
|
||||
run_info["status"] = "running"
|
||||
run_info.setdefault("events", [])
|
||||
try:
|
||||
async for payload in stream_gen:
|
||||
# Add timestamp if not present
|
||||
run_info["events"].append(payload)
|
||||
if "timestamp" not in payload:
|
||||
payload["timestamp"] = time.strftime("%H:%M:%S")
|
||||
await websocket.send_json(payload)
|
||||
|
|
@ -55,6 +90,11 @@ async def websocket_endpoint(
|
|||
payload.get("node_id"),
|
||||
run_id,
|
||||
)
|
||||
run_info["status"] = "completed"
|
||||
except Exception as exc:
|
||||
run_info["status"] = "failed"
|
||||
run_info["error"] = str(exc)
|
||||
raise
|
||||
else:
|
||||
msg = f"Run type '{run_type}' streaming not yet implemented."
|
||||
logger.warning(msg)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,17 @@
|
|||
import asyncio
|
||||
import datetime as _dt
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, AsyncGenerator
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.graph.scanner_graph import ScannerGraph
|
||||
from tradingagents.graph.portfolio_graph import PortfolioGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.report_paths import get_market_dir, get_ticker_dir
|
||||
from tradingagents.portfolio.report_store import ReportStore
|
||||
from tradingagents.daily_digest import append_to_digest
|
||||
from tradingagents.agents.utils.json_utils import extract_json
|
||||
|
||||
logger = logging.getLogger("agent_os.engine")
|
||||
|
||||
|
|
@ -54,14 +60,75 @@ class LangGraphEngine:
|
|||
}
|
||||
|
||||
self._node_start_times[run_id] = {}
|
||||
final_state: Dict[str, Any] = {}
|
||||
|
||||
async for event in scanner.graph.astream_events(initial_state, version="v2"):
|
||||
# Capture the complete final state from the root graph's terminal event.
|
||||
# LangGraph v2 emits one root-level on_chain_end (parent_ids=[], no
|
||||
# langgraph_node in metadata) whose data.output is the full accumulated state.
|
||||
if self._is_root_chain_end(event):
|
||||
output = (event.get("data") or {}).get("output")
|
||||
if isinstance(output, dict):
|
||||
final_state = output
|
||||
mapped = self._map_langgraph_event(run_id, event)
|
||||
if mapped:
|
||||
yield mapped
|
||||
|
||||
self._node_start_times.pop(run_id, None)
|
||||
self._node_prompts.pop(run_id, None)
|
||||
|
||||
# Save scan reports to disk
|
||||
if final_state:
|
||||
yield self._system_log("Saving scan reports to disk…")
|
||||
try:
|
||||
save_dir = get_market_dir(date)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for key in (
|
||||
"geopolitical_report",
|
||||
"market_movers_report",
|
||||
"sector_performance_report",
|
||||
"industry_deep_dive_report",
|
||||
"macro_scan_summary",
|
||||
):
|
||||
content = final_state.get(key, "")
|
||||
if content:
|
||||
(save_dir / f"{key}.md").write_text(content)
|
||||
|
||||
# Parse and save macro_scan_summary.json via ReportStore for downstream use
|
||||
summary_text = final_state.get("macro_scan_summary", "")
|
||||
if summary_text:
|
||||
try:
|
||||
summary_data = extract_json(summary_text)
|
||||
ReportStore().save_scan(date, summary_data)
|
||||
except (ValueError, KeyError, TypeError):
|
||||
logger.warning(
|
||||
"macro_scan_summary for date=%s is not valid JSON "
|
||||
"(summary already saved as .md — downstream loads may fail)",
|
||||
date,
|
||||
)
|
||||
|
||||
# Append to daily digest
|
||||
scan_parts = []
|
||||
for key, label in (
|
||||
("geopolitical_report", "Geopolitical & Macro"),
|
||||
("market_movers_report", "Market Movers"),
|
||||
("sector_performance_report", "Sector Performance"),
|
||||
("industry_deep_dive_report", "Industry Deep Dive"),
|
||||
("macro_scan_summary", "Macro Scan Summary"),
|
||||
):
|
||||
content = final_state.get(key, "")
|
||||
if content:
|
||||
scan_parts.append(f"### {label}\n{content}")
|
||||
if scan_parts:
|
||||
append_to_digest(date, "scan", "Market Scan", "\n\n".join(scan_parts))
|
||||
|
||||
yield self._system_log(f"Scan reports saved to {save_dir}")
|
||||
logger.info("Saved scan reports run=%s date=%s dir=%s", run_id, date, save_dir)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to save scan reports run=%s", run_id)
|
||||
yield self._system_log(f"Warning: could not save scan reports: {exc}")
|
||||
|
||||
logger.info("Completed SCAN run=%s", run_id)
|
||||
|
||||
async def run_pipeline(
|
||||
|
|
@ -86,18 +153,53 @@ class LangGraphEngine:
|
|||
initial_state = graph_wrapper.propagator.create_initial_state(ticker, date)
|
||||
|
||||
self._node_start_times[run_id] = {}
|
||||
final_state: Dict[str, Any] = {}
|
||||
|
||||
async for event in graph_wrapper.graph.astream_events(
|
||||
initial_state,
|
||||
version="v2",
|
||||
config={"recursion_limit": graph_wrapper.propagator.max_recur_limit},
|
||||
):
|
||||
# Capture the complete final state from the root graph's terminal event.
|
||||
if self._is_root_chain_end(event):
|
||||
output = (event.get("data") or {}).get("output")
|
||||
if isinstance(output, dict):
|
||||
final_state = output
|
||||
mapped = self._map_langgraph_event(run_id, event)
|
||||
if mapped:
|
||||
yield mapped
|
||||
|
||||
self._node_start_times.pop(run_id, None)
|
||||
self._node_prompts.pop(run_id, None)
|
||||
|
||||
# Save pipeline reports to disk
|
||||
if final_state:
|
||||
yield self._system_log(f"Saving analysis report for {ticker}…")
|
||||
try:
|
||||
save_dir = get_ticker_dir(date, ticker)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save JSON via ReportStore (complete_report.json)
|
||||
ReportStore().save_analysis(date, ticker, final_state)
|
||||
|
||||
# Write human-readable complete_report.md
|
||||
self._write_complete_report_md(final_state, ticker, save_dir)
|
||||
|
||||
# Append to daily digest
|
||||
digest_content = (
|
||||
final_state.get("final_trade_decision")
|
||||
or final_state.get("trader_investment_plan")
|
||||
or ""
|
||||
)
|
||||
if digest_content:
|
||||
append_to_digest(date, "analyze", ticker, digest_content)
|
||||
|
||||
yield self._system_log(f"Analysis report for {ticker} saved to {save_dir}")
|
||||
logger.info("Saved pipeline report run=%s ticker=%s dir=%s", run_id, ticker, save_dir)
|
||||
except Exception as exc:
|
||||
logger.exception("Failed to save pipeline reports run=%s ticker=%s", run_id, ticker)
|
||||
yield self._system_log(f"Warning: could not save analysis report for {ticker}: {exc}")
|
||||
|
||||
logger.info("Completed PIPELINE run=%s", run_id)
|
||||
|
||||
async def run_portfolio(
|
||||
|
|
@ -117,9 +219,34 @@ class LangGraphEngine:
|
|||
|
||||
portfolio_graph = PortfolioGraph(config=self.config)
|
||||
|
||||
# Load scan summary and per-ticker analyses from the daily report folder
|
||||
store = ReportStore()
|
||||
scan_summary = store.load_scan(date) or {}
|
||||
ticker_analyses: Dict[str, Any] = {}
|
||||
|
||||
from tradingagents.report_paths import get_daily_dir
|
||||
daily_dir = get_daily_dir(date)
|
||||
if daily_dir.exists():
|
||||
for ticker_dir in daily_dir.iterdir():
|
||||
if ticker_dir.is_dir() and ticker_dir.name not in ("market", "portfolio"):
|
||||
analysis = store.load_analysis(date, ticker_dir.name)
|
||||
if analysis:
|
||||
ticker_analyses[ticker_dir.name] = analysis
|
||||
|
||||
if scan_summary:
|
||||
yield self._system_log(f"Loaded macro scan summary for {date}")
|
||||
else:
|
||||
yield self._system_log(f"No scan summary found for {date}, proceeding without it")
|
||||
if ticker_analyses:
|
||||
yield self._system_log(f"Loaded analyses for: {', '.join(sorted(ticker_analyses.keys()))}")
|
||||
else:
|
||||
yield self._system_log("No per-ticker analyses found for this date")
|
||||
|
||||
initial_state = {
|
||||
"portfolio_id": portfolio_id,
|
||||
"scan_date": date,
|
||||
"scan_summary": scan_summary,
|
||||
"ticker_analyses": ticker_analyses,
|
||||
"messages": [],
|
||||
}
|
||||
|
||||
|
|
@ -150,27 +277,131 @@ class LangGraphEngine:
|
|||
async for evt in self.run_scan(f"{run_id}_scan", {"date": date}):
|
||||
yield evt
|
||||
|
||||
# Phase 2: Pipeline analysis (default ticker for now)
|
||||
ticker = params.get("ticker", "AAPL")
|
||||
# Phase 2: Pipeline analysis — get tickers from saved scan report
|
||||
yield self._system_log("Phase 2/3: Loading stocks from scan report…")
|
||||
scan_data = ReportStore().load_scan(date)
|
||||
tickers = self._extract_tickers_from_scan_data(scan_data)
|
||||
|
||||
if not tickers:
|
||||
yield self._system_log(
|
||||
"Warning: no stocks found in scan summary — ensure the scan completed "
|
||||
"successfully and produced a 'stocks_to_investigate' list. "
|
||||
"Skipping pipeline phase."
|
||||
)
|
||||
else:
|
||||
for ticker in tickers:
|
||||
yield self._system_log(f"Phase 2/3: Running analysis pipeline for {ticker}…")
|
||||
async for evt in self.run_pipeline(
|
||||
f"{run_id}_pipeline", {"ticker": ticker, "date": date}
|
||||
f"{run_id}_pipeline_{ticker}", {"ticker": ticker, "date": date}
|
||||
):
|
||||
yield evt
|
||||
|
||||
# Phase 3: Portfolio management
|
||||
yield self._system_log("Phase 3/3: Running portfolio manager…")
|
||||
portfolio_params = {k: v for k, v in params.items() if k != "ticker"}
|
||||
async for evt in self.run_portfolio(
|
||||
f"{run_id}_portfolio", {"date": date, **params}
|
||||
f"{run_id}_portfolio", {"date": date, **portfolio_params}
|
||||
):
|
||||
yield evt
|
||||
|
||||
logger.info("Completed AUTO run=%s", run_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Report helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _write_complete_report_md(
|
||||
final_state: Dict[str, Any], ticker: str, save_dir: Path
|
||||
) -> None:
|
||||
"""Write a human-readable complete_report.md from the pipeline final state."""
|
||||
sections = []
|
||||
header = (
|
||||
f"# Trading Analysis Report: {ticker}\n\n"
|
||||
f"Generated: {_dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||
)
|
||||
|
||||
analyst_parts = []
|
||||
for key, label in (
|
||||
("market_report", "Market Analyst"),
|
||||
("sentiment_report", "Social Analyst"),
|
||||
("news_report", "News Analyst"),
|
||||
("fundamentals_report", "Fundamentals Analyst"),
|
||||
):
|
||||
if final_state.get(key):
|
||||
analyst_parts.append(f"### {label}\n{final_state[key]}")
|
||||
if analyst_parts:
|
||||
sections.append("## I. Analyst Team Reports\n\n" + "\n\n".join(analyst_parts))
|
||||
|
||||
if final_state.get("investment_plan"):
|
||||
sections.append(f"## II. Research Team Decision\n\n{final_state['investment_plan']}")
|
||||
|
||||
if final_state.get("trader_investment_plan"):
|
||||
sections.append(f"## III. Trading Team Plan\n\n{final_state['trader_investment_plan']}")
|
||||
|
||||
if final_state.get("final_trade_decision"):
|
||||
sections.append(f"## IV. Final Decision\n\n{final_state['final_trade_decision']}")
|
||||
|
||||
(save_dir / "complete_report.md").write_text(header + "\n\n".join(sections))
|
||||
|
||||
@staticmethod
|
||||
def _extract_tickers_from_scan_data(scan_data: Dict[str, Any] | None) -> list[str]:
|
||||
"""Extract ticker symbols from a ReportStore scan summary dict.
|
||||
|
||||
Handles two shapes from the macro synthesis LLM output:
|
||||
* List of dicts: ``[{"ticker": "AAPL", ...}, ...]``
|
||||
* List of strings: ``["AAPL", "TSLA", ...]``
|
||||
|
||||
Also checks both ``stocks_to_investigate`` and ``watchlist`` keys.
|
||||
Returns an uppercase, deduplicated list in original order.
|
||||
"""
|
||||
if not scan_data:
|
||||
return []
|
||||
raw_stocks = (
|
||||
scan_data.get("stocks_to_investigate")
|
||||
or scan_data.get("watchlist")
|
||||
or []
|
||||
)
|
||||
seen: set[str] = set()
|
||||
tickers: list[str] = []
|
||||
for item in raw_stocks:
|
||||
if isinstance(item, dict):
|
||||
sym = item.get("ticker") or item.get("symbol") or ""
|
||||
elif isinstance(item, str):
|
||||
sym = item
|
||||
else:
|
||||
continue
|
||||
sym = sym.strip().upper()
|
||||
if sym and sym not in seen:
|
||||
seen.add(sym)
|
||||
tickers.append(sym)
|
||||
return tickers
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Event mapping
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _is_root_chain_end(event: Dict[str, Any]) -> bool:
|
||||
"""Return True for the root-graph terminal event in a LangGraph v2 stream.
|
||||
|
||||
LangGraph v2 emits one ``on_chain_end`` event per node AND one for the
|
||||
root graph itself. The root-graph event is distinguished by:
|
||||
|
||||
* ``event["metadata"]`` has no ``langgraph_node`` key (node events always do)
|
||||
* ``event["parent_ids"]`` is empty (root has no parent run)
|
||||
|
||||
Its ``data["output"]`` contains the **complete** final state — the
|
||||
canonical way to read the propagated state without re-running the graph.
|
||||
"""
|
||||
if event.get("event") != "on_chain_end":
|
||||
return False
|
||||
metadata = event.get("metadata") or {}
|
||||
if metadata.get("langgraph_node"):
|
||||
return False # This is a node event, not the root
|
||||
parent_ids = event.get("parent_ids")
|
||||
return parent_ids is not None and len(parent_ids) == 0
|
||||
|
||||
@staticmethod
|
||||
def _extract_node_name(event: Dict[str, Any]) -> str:
|
||||
"""Extract the LangGraph node name from event metadata or tags."""
|
||||
|
|
|
|||
|
|
@ -168,6 +168,94 @@ class TestLangGraphEngineExtraction(unittest.TestCase):
|
|||
result = self.engine._map_langgraph_event("run_123", event)
|
||||
self.assertIsNone(result)
|
||||
|
||||
# ── _is_root_chain_end ──────────────────────────────────────────
|
||||
|
||||
def test_is_root_chain_end_true(self):
|
||||
"""Root graph terminal event: on_chain_end with empty parent_ids and no node."""
|
||||
event = {
|
||||
"event": "on_chain_end",
|
||||
"name": "LangGraph",
|
||||
"parent_ids": [],
|
||||
"metadata": {},
|
||||
"data": {"output": {"x": "final"}},
|
||||
}
|
||||
self.assertTrue(self.engine._is_root_chain_end(event))
|
||||
|
||||
def test_is_root_chain_end_false_for_node_event(self):
|
||||
"""Node on_chain_end should NOT be treated as the root graph end."""
|
||||
event = {
|
||||
"event": "on_chain_end",
|
||||
"name": "geopolitical_scanner",
|
||||
"parent_ids": ["some-parent-run-id"],
|
||||
"metadata": {"langgraph_node": "geopolitical_scanner"},
|
||||
"data": {"output": {"geopolitical_report": "..."}},
|
||||
}
|
||||
self.assertFalse(self.engine._is_root_chain_end(event))
|
||||
|
||||
def test_is_root_chain_end_false_for_non_chain_end(self):
|
||||
"""on_chain_start should never match."""
|
||||
event = {
|
||||
"event": "on_chain_start",
|
||||
"name": "LangGraph",
|
||||
"parent_ids": [],
|
||||
"metadata": {},
|
||||
"data": {},
|
||||
}
|
||||
self.assertFalse(self.engine._is_root_chain_end(event))
|
||||
|
||||
def test_is_root_chain_end_false_when_parent_ids_missing(self):
|
||||
"""If parent_ids is absent (unexpected), should not match."""
|
||||
event = {
|
||||
"event": "on_chain_end",
|
||||
"name": "LangGraph",
|
||||
"metadata": {},
|
||||
"data": {"output": {"x": "v"}},
|
||||
}
|
||||
self.assertFalse(self.engine._is_root_chain_end(event))
|
||||
|
||||
def test_is_root_chain_end_false_when_langgraph_node_present(self):
|
||||
"""Node-level event with empty parent_ids should still not match."""
|
||||
event = {
|
||||
"event": "on_chain_end",
|
||||
"name": "some_node",
|
||||
"parent_ids": [],
|
||||
"metadata": {"langgraph_node": "some_node"},
|
||||
"data": {"output": {}},
|
||||
}
|
||||
self.assertFalse(self.engine._is_root_chain_end(event))
|
||||
|
||||
# ── _extract_tickers_from_scan_data ─────────────────────────────
|
||||
|
||||
def test_extract_tickers_list_of_dicts(self):
|
||||
scan = {"stocks_to_investigate": [
|
||||
{"ticker": "AAPL", "name": "Apple"},
|
||||
{"ticker": "tsla", "sector": "EV"},
|
||||
]}
|
||||
self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["AAPL", "TSLA"])
|
||||
|
||||
def test_extract_tickers_list_of_strings(self):
|
||||
scan = {"watchlist": ["msft", "GOOG", "amzn"]}
|
||||
self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["MSFT", "GOOG", "AMZN"])
|
||||
|
||||
def test_extract_tickers_prefers_stocks_to_investigate(self):
|
||||
scan = {
|
||||
"stocks_to_investigate": [{"ticker": "NVDA"}],
|
||||
"watchlist": [{"ticker": "AMD"}],
|
||||
}
|
||||
self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["NVDA"])
|
||||
|
||||
def test_extract_tickers_deduplicates(self):
|
||||
scan = {"stocks_to_investigate": ["AAPL", "aapl", "AAPL"]}
|
||||
self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["AAPL"])
|
||||
|
||||
def test_extract_tickers_empty_scan(self):
|
||||
self.assertEqual(self.engine._extract_tickers_from_scan_data(None), [])
|
||||
self.assertEqual(self.engine._extract_tickers_from_scan_data({}), [])
|
||||
|
||||
def test_extract_tickers_symbol_key_fallback(self):
|
||||
scan = {"stocks_to_investigate": [{"symbol": "META"}]}
|
||||
self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["META"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue