Merge pull request #92 from aguzererler/copilot/fix-report-saving-in-runs

Fix report persistence, run status tracking, auto-mode stock sourcing, and portfolio context loading
This commit is contained in:
ahmet guzererler 2026-03-23 19:55:42 +01:00 committed by GitHub
commit 36a6b17a22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 412 additions and 38 deletions

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException
from typing import Dict, Any, List
from typing import Dict, Any, List, AsyncGenerator
import logging
import uuid
import time
@@ -13,6 +13,21 @@ router = APIRouter(prefix="/api/run", tags=["runs"])
engine = LangGraphEngine()
async def _run_and_store(run_id: str, gen: AsyncGenerator[Dict[str, Any], None]) -> None:
"""Drive an engine generator, updating run status and caching events."""
runs[run_id]["status"] = "running"
runs[run_id]["events"] = []
try:
async for event in gen:
runs[run_id]["events"].append(event)
runs[run_id]["status"] = "completed"
except Exception as exc:
runs[run_id]["status"] = "failed"
runs[run_id]["error"] = str(exc)
logger.exception("Run failed run=%s", run_id)
@router.post("/scan")
async def trigger_scan(
background_tasks: BackgroundTasks,
@@ -29,7 +44,7 @@ async def trigger_scan(
"params": params or {}
}
logger.info("Queued SCAN run=%s user=%s", run_id, user["user_id"])
background_tasks.add_task(engine.run_scan, run_id, params or {})
background_tasks.add_task(_run_and_store, run_id, engine.run_scan(run_id, params or {}))
return {"run_id": run_id, "status": "queued"}
@router.post("/pipeline")
@@ -48,7 +63,7 @@ async def trigger_pipeline(
"params": params or {}
}
logger.info("Queued PIPELINE run=%s user=%s", run_id, user["user_id"])
background_tasks.add_task(engine.run_pipeline, run_id, params or {})
background_tasks.add_task(_run_and_store, run_id, engine.run_pipeline(run_id, params or {}))
return {"run_id": run_id, "status": "queued"}
@router.post("/portfolio")
@@ -67,7 +82,7 @@ async def trigger_portfolio(
"params": params or {}
}
logger.info("Queued PORTFOLIO run=%s user=%s", run_id, user["user_id"])
background_tasks.add_task(engine.run_portfolio, run_id, params or {})
background_tasks.add_task(_run_and_store, run_id, engine.run_portfolio(run_id, params or {}))
return {"run_id": run_id, "status": "queued"}
@router.post("/auto")
@@ -86,7 +101,7 @@ async def trigger_auto(
"params": params or {}
}
logger.info("Queued AUTO run=%s user=%s", run_id, user["user_id"])
background_tasks.add_task(engine.run_auto, run_id, params or {})
background_tasks.add_task(_run_and_store, run_id, engine.run_auto(run_id, params or {}))
return {"run_id": run_id, "status": "queued"}
@router.get("/")
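The add_task change above relies on a detail worth pinning down: an async generator only runs while something iterates it, so handing the generator function itself to BackgroundTasks does nothing. A minimal, self-contained sketch of the wrapper pattern, with a toy generator and registry standing in for the real engine and module state:

import asyncio
from typing import Any, AsyncGenerator, Dict

runs: Dict[str, Dict[str, Any]] = {}

async def toy_engine(run_id: str) -> AsyncGenerator[Dict[str, Any], None]:
    # Stand-in for engine.run_scan / run_pipeline / run_portfolio / run_auto.
    for step in range(3):
        await asyncio.sleep(0)
        yield {"type": "node", "step": step}

async def run_and_store(run_id: str, gen: AsyncGenerator[Dict[str, Any], None]) -> None:
    runs[run_id] = {"status": "running", "events": []}
    try:
        async for event in gen:
            runs[run_id]["events"].append(event)  # cache for late subscribers
        runs[run_id]["status"] = "completed"
    except Exception as exc:  # broad catch mirrors the route helper
        runs[run_id]["status"] = "failed"
        runs[run_id]["error"] = str(exc)

asyncio.run(run_and_store("demo", toy_engine("demo")))
assert runs["demo"]["status"] == "completed"
assert [e["step"] for e in runs["demo"]["events"]] == [0, 1, 2]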

View File

@@ -12,6 +12,9 @@ logger = logging.getLogger("agent_os.websocket")
router = APIRouter(prefix="/ws", tags=["websocket"])
# Polling interval when streaming cached events from a background-task-driven run
_EVENT_POLL_INTERVAL_SECONDS = 0.05
engine = LangGraphEngine()
@router.websocket("/stream/{run_id}")
@@ -33,33 +36,70 @@ async def websocket_endpoint(
params = run_info.get("params", {})
try:
stream_gen = None
if run_type == "scan":
stream_gen = engine.run_scan(run_id, params)
elif run_type == "pipeline":
stream_gen = engine.run_pipeline(run_id, params)
elif run_type == "portfolio":
stream_gen = engine.run_portfolio(run_id, params)
elif run_type == "auto":
stream_gen = engine.run_auto(run_id, params)
if stream_gen:
async for payload in stream_gen:
# Add timestamp if not present
if "timestamp" not in payload:
payload["timestamp"] = time.strftime("%H:%M:%S")
await websocket.send_json(payload)
logger.debug(
"Sent event type=%s node=%s run=%s",
payload.get("type"),
payload.get("node_id"),
run_id,
status = run_info.get("status", "queued")
if status in ("running", "completed", "failed"):
# The background task is already executing (or finished): stream its cached
# events, then wait for completion if it is still running.
logger.info(
"WebSocket streaming from cache run=%s status=%s", run_id, status
)
sent = 0
while True:
cached = run_info.get("events") or []
while sent < len(cached):
payload = cached[sent]
if "timestamp" not in payload:
payload["timestamp"] = time.strftime("%H:%M:%S")
await websocket.send_json(payload)
sent += 1
current_status = run_info.get("status")
if current_status in ("completed", "failed"):
break
# Yield to the event loop so the background task can produce more events
await asyncio.sleep(_EVENT_POLL_INTERVAL_SECONDS)
if run_info.get("status") == "failed":
await websocket.send_json(
{"type": "system", "message": f"Run failed: {run_info.get('error', 'unknown error')}"}
)
else:
msg = f"Run type '{run_type}' streaming not yet implemented."
logger.warning(msg)
await websocket.send_json({"type": "system", "message": f"Error: {msg}"})
# status == "queued": the WebSocket is the executor (the background task has not started yet)
stream_gen = None
if run_type == "scan":
stream_gen = engine.run_scan(run_id, params)
elif run_type == "pipeline":
stream_gen = engine.run_pipeline(run_id, params)
elif run_type == "portfolio":
stream_gen = engine.run_portfolio(run_id, params)
elif run_type == "auto":
stream_gen = engine.run_auto(run_id, params)
if stream_gen:
run_info["status"] = "running"
run_info.setdefault("events", [])
try:
async for payload in stream_gen:
run_info["events"].append(payload)
if "timestamp" not in payload:
payload["timestamp"] = time.strftime("%H:%M:%S")
await websocket.send_json(payload)
logger.debug(
"Sent event type=%s node=%s run=%s",
payload.get("type"),
payload.get("node_id"),
run_id,
)
run_info["status"] = "completed"
except Exception as exc:
run_info["status"] = "failed"
run_info["error"] = str(exc)
raise
else:
msg = f"Run type '{run_type}' streaming not yet implemented."
logger.warning(msg)
await websocket.send_json({"type": "system", "message": f"Error: {msg}"})
await websocket.send_json({"type": "system", "message": "Run completed."})
logger.info("Run completed run=%s type=%s", run_id, run_type)

View File

@@ -1,11 +1,17 @@
import asyncio
import datetime as _dt
import logging
import time
from pathlib import Path
from typing import Dict, Any, AsyncGenerator
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.graph.scanner_graph import ScannerGraph
from tradingagents.graph.portfolio_graph import PortfolioGraph
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.report_paths import get_market_dir, get_ticker_dir
from tradingagents.portfolio.report_store import ReportStore
from tradingagents.daily_digest import append_to_digest
from tradingagents.agents.utils.json_utils import extract_json
logger = logging.getLogger("agent_os.engine")
@@ -54,14 +60,75 @@ class LangGraphEngine:
}
self._node_start_times[run_id] = {}
final_state: Dict[str, Any] = {}
async for event in scanner.graph.astream_events(initial_state, version="v2"):
# Capture the complete final state from the root graph's terminal event.
# LangGraph v2 emits one root-level on_chain_end (parent_ids=[], no
# langgraph_node in metadata) whose data.output is the full accumulated state.
if self._is_root_chain_end(event):
output = (event.get("data") or {}).get("output")
if isinstance(output, dict):
final_state = output
mapped = self._map_langgraph_event(run_id, event)
if mapped:
yield mapped
self._node_start_times.pop(run_id, None)
self._node_prompts.pop(run_id, None)
# Save scan reports to disk
if final_state:
yield self._system_log("Saving scan reports to disk…")
try:
save_dir = get_market_dir(date)
save_dir.mkdir(parents=True, exist_ok=True)
for key in (
"geopolitical_report",
"market_movers_report",
"sector_performance_report",
"industry_deep_dive_report",
"macro_scan_summary",
):
content = final_state.get(key, "")
if content:
(save_dir / f"{key}.md").write_text(content)
# Parse and save macro_scan_summary.json via ReportStore for downstream use
summary_text = final_state.get("macro_scan_summary", "")
if summary_text:
try:
summary_data = extract_json(summary_text)
ReportStore().save_scan(date, summary_data)
except (ValueError, KeyError, TypeError):
logger.warning(
"macro_scan_summary for date=%s is not valid JSON "
"(summary already saved as .md — downstream loads may fail)",
date,
)
# Append to daily digest
scan_parts = []
for key, label in (
("geopolitical_report", "Geopolitical & Macro"),
("market_movers_report", "Market Movers"),
("sector_performance_report", "Sector Performance"),
("industry_deep_dive_report", "Industry Deep Dive"),
("macro_scan_summary", "Macro Scan Summary"),
):
content = final_state.get(key, "")
if content:
scan_parts.append(f"### {label}\n{content}")
if scan_parts:
append_to_digest(date, "scan", "Market Scan", "\n\n".join(scan_parts))
yield self._system_log(f"Scan reports saved to {save_dir}")
logger.info("Saved scan reports run=%s date=%s dir=%s", run_id, date, save_dir)
except Exception as exc:
logger.exception("Failed to save scan reports run=%s", run_id)
yield self._system_log(f"Warning: could not save scan reports: {exc}")
logger.info("Completed SCAN run=%s", run_id)
async def run_pipeline(
@@ -86,18 +153,53 @@ class LangGraphEngine:
initial_state = graph_wrapper.propagator.create_initial_state(ticker, date)
self._node_start_times[run_id] = {}
final_state: Dict[str, Any] = {}
async for event in graph_wrapper.graph.astream_events(
initial_state,
version="v2",
config={"recursion_limit": graph_wrapper.propagator.max_recur_limit},
):
# Capture the complete final state from the root graph's terminal event.
if self._is_root_chain_end(event):
output = (event.get("data") or {}).get("output")
if isinstance(output, dict):
final_state = output
mapped = self._map_langgraph_event(run_id, event)
if mapped:
yield mapped
self._node_start_times.pop(run_id, None)
self._node_prompts.pop(run_id, None)
# Save pipeline reports to disk
if final_state:
yield self._system_log(f"Saving analysis report for {ticker}")
try:
save_dir = get_ticker_dir(date, ticker)
save_dir.mkdir(parents=True, exist_ok=True)
# Save JSON via ReportStore (complete_report.json)
ReportStore().save_analysis(date, ticker, final_state)
# Write human-readable complete_report.md
self._write_complete_report_md(final_state, ticker, save_dir)
# Append to daily digest
digest_content = (
final_state.get("final_trade_decision")
or final_state.get("trader_investment_plan")
or ""
)
if digest_content:
append_to_digest(date, "analyze", ticker, digest_content)
yield self._system_log(f"Analysis report for {ticker} saved to {save_dir}")
logger.info("Saved pipeline report run=%s ticker=%s dir=%s", run_id, ticker, save_dir)
except Exception as exc:
logger.exception("Failed to save pipeline reports run=%s ticker=%s", run_id, ticker)
yield self._system_log(f"Warning: could not save analysis report for {ticker}: {exc}")
logger.info("Completed PIPELINE run=%s", run_id)
async def run_portfolio(
@@ -117,9 +219,34 @@ class LangGraphEngine:
portfolio_graph = PortfolioGraph(config=self.config)
# Load scan summary and per-ticker analyses from the daily report folder
store = ReportStore()
scan_summary = store.load_scan(date) or {}
ticker_analyses: Dict[str, Any] = {}
from tradingagents.report_paths import get_daily_dir
daily_dir = get_daily_dir(date)
if daily_dir.exists():
for ticker_dir in daily_dir.iterdir():
if ticker_dir.is_dir() and ticker_dir.name not in ("market", "portfolio"):
analysis = store.load_analysis(date, ticker_dir.name)
if analysis:
ticker_analyses[ticker_dir.name] = analysis
if scan_summary:
yield self._system_log(f"Loaded macro scan summary for {date}")
else:
yield self._system_log(f"No scan summary found for {date}, proceeding without it")
if ticker_analyses:
yield self._system_log(f"Loaded analyses for: {', '.join(sorted(ticker_analyses.keys()))}")
else:
yield self._system_log("No per-ticker analyses found for this date")
initial_state = {
"portfolio_id": portfolio_id,
"scan_date": date,
"scan_summary": scan_summary,
"ticker_analyses": ticker_analyses,
"messages": [],
}
@@ -150,27 +277,131 @@ class LangGraphEngine:
async for evt in self.run_scan(f"{run_id}_scan", {"date": date}):
yield evt
# Phase 2: Pipeline analysis (default ticker for now)
ticker = params.get("ticker", "AAPL")
yield self._system_log(f"Phase 2/3: Running analysis pipeline for {ticker}")
async for evt in self.run_pipeline(
f"{run_id}_pipeline", {"ticker": ticker, "date": date}
):
yield evt
# Phase 2: Pipeline analysis — get tickers from saved scan report
yield self._system_log("Phase 2/3: Loading stocks from scan report…")
scan_data = ReportStore().load_scan(date)
tickers = self._extract_tickers_from_scan_data(scan_data)
if not tickers:
yield self._system_log(
"Warning: no stocks found in scan summary — ensure the scan completed "
"successfully and produced a 'stocks_to_investigate' list. "
"Skipping pipeline phase."
)
else:
for ticker in tickers:
yield self._system_log(f"Phase 2/3: Running analysis pipeline for {ticker}")
async for evt in self.run_pipeline(
f"{run_id}_pipeline_{ticker}", {"ticker": ticker, "date": date}
):
yield evt
# Phase 3: Portfolio management
yield self._system_log("Phase 3/3: Running portfolio manager…")
portfolio_params = {k: v for k, v in params.items() if k != "ticker"}
async for evt in self.run_portfolio(
f"{run_id}_portfolio", {"date": date, **params}
f"{run_id}_portfolio", {"date": date, **portfolio_params}
):
yield evt
logger.info("Completed AUTO run=%s", run_id)
# ------------------------------------------------------------------
# Report helpers
# ------------------------------------------------------------------
@staticmethod
def _write_complete_report_md(
final_state: Dict[str, Any], ticker: str, save_dir: Path
) -> None:
"""Write a human-readable complete_report.md from the pipeline final state."""
sections = []
header = (
f"# Trading Analysis Report: {ticker}\n\n"
f"Generated: {_dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
)
analyst_parts = []
for key, label in (
("market_report", "Market Analyst"),
("sentiment_report", "Social Analyst"),
("news_report", "News Analyst"),
("fundamentals_report", "Fundamentals Analyst"),
):
if final_state.get(key):
analyst_parts.append(f"### {label}\n{final_state[key]}")
if analyst_parts:
sections.append("## I. Analyst Team Reports\n\n" + "\n\n".join(analyst_parts))
if final_state.get("investment_plan"):
sections.append(f"## II. Research Team Decision\n\n{final_state['investment_plan']}")
if final_state.get("trader_investment_plan"):
sections.append(f"## III. Trading Team Plan\n\n{final_state['trader_investment_plan']}")
if final_state.get("final_trade_decision"):
sections.append(f"## IV. Final Decision\n\n{final_state['final_trade_decision']}")
(save_dir / "complete_report.md").write_text(header + "\n\n".join(sections))
@staticmethod
def _extract_tickers_from_scan_data(scan_data: Dict[str, Any] | None) -> list[str]:
"""Extract ticker symbols from a ReportStore scan summary dict.
Handles two shapes from the macro synthesis LLM output:
* List of dicts: ``[{"ticker": "AAPL", ...}, ...]``
* List of strings: ``["AAPL", "TSLA", ...]``
Checks ``stocks_to_investigate`` first, falling back to ``watchlist`` if it is missing or empty.
Returns an uppercase, deduplicated list in original order.
"""
if not scan_data:
return []
raw_stocks = (
scan_data.get("stocks_to_investigate")
or scan_data.get("watchlist")
or []
)
seen: set[str] = set()
tickers: list[str] = []
for item in raw_stocks:
if isinstance(item, dict):
sym = item.get("ticker") or item.get("symbol") or ""
elif isinstance(item, str):
sym = item
else:
continue
sym = sym.strip().upper()
if sym and sym not in seen:
seen.add(sym)
tickers.append(sym)
return tickers
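Both accepted shapes, the symbol-key fallback, and deduplication in a single call (made-up symbols; assumes LangGraphEngine is in scope):

mixed = {"stocks_to_investigate": [{"ticker": "nvda"}, "NVDA", {"symbol": "amd"}]}
assert LangGraphEngine._extract_tickers_from_scan_data(mixed) == ["NVDA", "AMD"]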
# ------------------------------------------------------------------
# Event mapping
# ------------------------------------------------------------------
@staticmethod
def _is_root_chain_end(event: Dict[str, Any]) -> bool:
"""Return True for the root-graph terminal event in a LangGraph v2 stream.
LangGraph v2 emits one ``on_chain_end`` event per node AND one for the
root graph itself. The root-graph event is distinguished by:
* ``event["metadata"]`` has no ``langgraph_node`` key (node events always do)
* ``event["parent_ids"]`` is empty (root has no parent run)
Its ``data["output"]`` contains the **complete** final state, which is the
canonical way to read the propagated state without re-running the graph.
"""
if event.get("event") != "on_chain_end":
return False
metadata = event.get("metadata") or {}
if metadata.get("langgraph_node"):
return False # This is a node event, not the root
parent_ids = event.get("parent_ids")
return parent_ids is not None and len(parent_ids) == 0
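In the run loops above, this predicate is the only filter needed to recover the final state from the event stream. A stripped-down pass over a fake stream (assumes LangGraphEngine is in scope):

fake_stream = [
    {"event": "on_chain_end", "parent_ids": ["parent-run"],
     "metadata": {"langgraph_node": "scanner"}, "data": {"output": {"partial": 1}}},
    {"event": "on_chain_end", "parent_ids": [], "metadata": {},
     "data": {"output": {"macro_scan_summary": "{}"}}},
]
final_state = {}
for event in fake_stream:
    if LangGraphEngine._is_root_chain_end(event):
        output = (event.get("data") or {}).get("output")
        if isinstance(output, dict):
            final_state = output
assert final_state == {"macro_scan_summary": "{}"}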
@staticmethod
def _extract_node_name(event: Dict[str, Any]) -> str:
"""Extract the LangGraph node name from event metadata or tags."""

View File

@@ -168,6 +168,94 @@ class TestLangGraphEngineExtraction(unittest.TestCase):
result = self.engine._map_langgraph_event("run_123", event)
self.assertIsNone(result)
# ── _is_root_chain_end ──────────────────────────────────────────
def test_is_root_chain_end_true(self):
"""Root graph terminal event: on_chain_end with empty parent_ids and no node."""
event = {
"event": "on_chain_end",
"name": "LangGraph",
"parent_ids": [],
"metadata": {},
"data": {"output": {"x": "final"}},
}
self.assertTrue(self.engine._is_root_chain_end(event))
def test_is_root_chain_end_false_for_node_event(self):
"""Node on_chain_end should NOT be treated as the root graph end."""
event = {
"event": "on_chain_end",
"name": "geopolitical_scanner",
"parent_ids": ["some-parent-run-id"],
"metadata": {"langgraph_node": "geopolitical_scanner"},
"data": {"output": {"geopolitical_report": "..."}},
}
self.assertFalse(self.engine._is_root_chain_end(event))
def test_is_root_chain_end_false_for_non_chain_end(self):
"""on_chain_start should never match."""
event = {
"event": "on_chain_start",
"name": "LangGraph",
"parent_ids": [],
"metadata": {},
"data": {},
}
self.assertFalse(self.engine._is_root_chain_end(event))
def test_is_root_chain_end_false_when_parent_ids_missing(self):
"""If parent_ids is absent (unexpected), should not match."""
event = {
"event": "on_chain_end",
"name": "LangGraph",
"metadata": {},
"data": {"output": {"x": "v"}},
}
self.assertFalse(self.engine._is_root_chain_end(event))
def test_is_root_chain_end_false_when_langgraph_node_present(self):
"""Node-level event with empty parent_ids should still not match."""
event = {
"event": "on_chain_end",
"name": "some_node",
"parent_ids": [],
"metadata": {"langgraph_node": "some_node"},
"data": {"output": {}},
}
self.assertFalse(self.engine._is_root_chain_end(event))
# ── _extract_tickers_from_scan_data ─────────────────────────────
def test_extract_tickers_list_of_dicts(self):
scan = {"stocks_to_investigate": [
{"ticker": "AAPL", "name": "Apple"},
{"ticker": "tsla", "sector": "EV"},
]}
self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["AAPL", "TSLA"])
def test_extract_tickers_list_of_strings(self):
scan = {"watchlist": ["msft", "GOOG", "amzn"]}
self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["MSFT", "GOOG", "AMZN"])
def test_extract_tickers_prefers_stocks_to_investigate(self):
scan = {
"stocks_to_investigate": [{"ticker": "NVDA"}],
"watchlist": [{"ticker": "AMD"}],
}
self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["NVDA"])
def test_extract_tickers_deduplicates(self):
scan = {"stocks_to_investigate": ["AAPL", "aapl", "AAPL"]}
self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["AAPL"])
def test_extract_tickers_empty_scan(self):
self.assertEqual(self.engine._extract_tickers_from_scan_data(None), [])
self.assertEqual(self.engine._extract_tickers_from_scan_data({}), [])
def test_extract_tickers_symbol_key_fallback(self):
scan = {"stocks_to_investigate": [{"symbol": "META"}]}
self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["META"])
if __name__ == "__main__":
unittest.main()