Merge pull request #92 from aguzererler/copilot/fix-report-saving-in-runs

Fix report persistence, run status tracking, auto-mode stock sourcing, and portfolio context loading
ahmet guzererler 2026-03-23 19:55:42 +01:00 committed by GitHub
commit 36a6b17a22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 412 additions and 38 deletions

View File

@@ -1,5 +1,5 @@
 from fastapi import APIRouter, Depends, BackgroundTasks, HTTPException
-from typing import Dict, Any, List
+from typing import Dict, Any, List, AsyncGenerator
 import logging
 import uuid
 import time
@@ -13,6 +13,21 @@ router = APIRouter(prefix="/api/run", tags=["runs"])
 engine = LangGraphEngine()
 
+
+async def _run_and_store(run_id: str, gen: AsyncGenerator[Dict[str, Any], None]) -> None:
+    """Drive an engine generator, updating run status and caching events."""
+    runs[run_id]["status"] = "running"
+    runs[run_id]["events"] = []
+    try:
+        async for event in gen:
+            runs[run_id]["events"].append(event)
+        runs[run_id]["status"] = "completed"
+    except Exception as exc:
+        runs[run_id]["status"] = "failed"
+        runs[run_id]["error"] = str(exc)
+        logger.exception("Run failed run=%s", run_id)
+
+
 @router.post("/scan")
 async def trigger_scan(
     background_tasks: BackgroundTasks,
@@ -29,7 +44,7 @@ async def trigger_scan(
         "params": params or {}
     }
     logger.info("Queued SCAN run=%s user=%s", run_id, user["user_id"])
-    background_tasks.add_task(engine.run_scan, run_id, params or {})
+    background_tasks.add_task(_run_and_store, run_id, engine.run_scan(run_id, params or {}))
     return {"run_id": run_id, "status": "queued"}
 
 @router.post("/pipeline")
@@ -48,7 +63,7 @@ async def trigger_pipeline(
         "params": params or {}
     }
     logger.info("Queued PIPELINE run=%s user=%s", run_id, user["user_id"])
-    background_tasks.add_task(engine.run_pipeline, run_id, params or {})
+    background_tasks.add_task(_run_and_store, run_id, engine.run_pipeline(run_id, params or {}))
     return {"run_id": run_id, "status": "queued"}
 
 @router.post("/portfolio")
@@ -67,7 +82,7 @@ async def trigger_portfolio(
         "params": params or {}
     }
     logger.info("Queued PORTFOLIO run=%s user=%s", run_id, user["user_id"])
-    background_tasks.add_task(engine.run_portfolio, run_id, params or {})
+    background_tasks.add_task(_run_and_store, run_id, engine.run_portfolio(run_id, params or {}))
     return {"run_id": run_id, "status": "queued"}
 
 @router.post("/auto")
@@ -86,7 +101,7 @@ async def trigger_auto(
         "params": params or {}
     }
     logger.info("Queued AUTO run=%s user=%s", run_id, user["user_id"])
-    background_tasks.add_task(engine.run_auto, run_id, params or {})
+    background_tasks.add_task(_run_and_store, run_id, engine.run_auto(run_id, params or {}))
     return {"run_id": run_id, "status": "queued"}
 
 @router.get("/")
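
Note on the router change above: `add_task(engine.run_scan, run_id, params)` scheduled a function whose body never actually ran, because calling an async generator function only creates the generator object; nothing ever iterated it, so no events were cached and no terminal status was recorded. The new `_run_and_store` wrapper is a plain coroutine that drives the generator to completion. A minimal, self-contained sketch of that driver pattern follows; `fake_run` and the local `runs` dict are stand-ins for the engine and the router's module-level registry, not code from this repository.

import asyncio
from typing import Any, AsyncGenerator, Dict

runs: Dict[str, Dict[str, Any]] = {"r1": {}}  # stand-in for the router's registry

async def fake_run(run_id: str, params: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]:
    # Stand-in for engine.run_scan / run_pipeline / run_portfolio / run_auto.
    for i in range(3):
        await asyncio.sleep(0.01)
        yield {"type": "node", "seq": i, "run": run_id}

async def run_and_store(run_id: str, gen: AsyncGenerator[Dict[str, Any], None]) -> None:
    # Same shape as the PR's _run_and_store: drive the generator,
    # cache each event, and record a terminal status.
    runs[run_id]["status"] = "running"
    runs[run_id]["events"] = []
    try:
        async for event in gen:
            runs[run_id]["events"].append(event)
        runs[run_id]["status"] = "completed"
    except Exception as exc:
        runs[run_id]["status"] = "failed"
        runs[run_id]["error"] = str(exc)

# Calling fake_run(...) only builds the generator; nothing executes until
# run_and_store iterates it. That is why the endpoint can construct the
# generator inline and hand the plain coroutine to BackgroundTasks.
asyncio.run(run_and_store("r1", fake_run("r1", {})))
assert runs["r1"]["status"] == "completed" and len(runs["r1"]["events"]) == 3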

View File

@@ -12,6 +12,9 @@ logger = logging.getLogger("agent_os.websocket")
 router = APIRouter(prefix="/ws", tags=["websocket"])
 
+# Polling interval when streaming cached events from a background-task-driven run
+_EVENT_POLL_INTERVAL_SECONDS = 0.05
+
 engine = LangGraphEngine()
 
 @router.websocket("/stream/{run_id}")
@@ -33,32 +36,69 @@ async def websocket_endpoint(
     params = run_info.get("params", {})
 
     try:
-        stream_gen = None
-        if run_type == "scan":
-            stream_gen = engine.run_scan(run_id, params)
-        elif run_type == "pipeline":
-            stream_gen = engine.run_pipeline(run_id, params)
-        elif run_type == "portfolio":
-            stream_gen = engine.run_portfolio(run_id, params)
-        elif run_type == "auto":
-            stream_gen = engine.run_auto(run_id, params)
-
-        if stream_gen:
-            async for payload in stream_gen:
-                # Add timestamp if not present
-                if "timestamp" not in payload:
-                    payload["timestamp"] = time.strftime("%H:%M:%S")
-                await websocket.send_json(payload)
-                logger.debug(
-                    "Sent event type=%s node=%s run=%s",
-                    payload.get("type"),
-                    payload.get("node_id"),
-                    run_id,
-                )
-        else:
-            msg = f"Run type '{run_type}' streaming not yet implemented."
-            logger.warning(msg)
-            await websocket.send_json({"type": "system", "message": f"Error: {msg}"})
+        status = run_info.get("status", "queued")
+
+        if status in ("running", "completed", "failed"):
+            # Background task is already executing (or finished) — stream its cached events,
+            # then wait for completion if still running.
+            logger.info(
+                "WebSocket streaming from cache run=%s status=%s", run_id, status
+            )
+            sent = 0
+            while True:
+                cached = run_info.get("events") or []
+                while sent < len(cached):
+                    payload = cached[sent]
+                    if "timestamp" not in payload:
+                        payload["timestamp"] = time.strftime("%H:%M:%S")
+                    await websocket.send_json(payload)
+                    sent += 1
+                current_status = run_info.get("status")
+                if current_status in ("completed", "failed"):
+                    break
+                # Yield to the event loop so the background task can produce more events
+                await asyncio.sleep(_EVENT_POLL_INTERVAL_SECONDS)
+            if run_info.get("status") == "failed":
+                await websocket.send_json(
+                    {"type": "system", "message": f"Run failed: {run_info.get('error', 'unknown error')}"}
+                )
+        else:
+            # status == "queued" — WebSocket is the executor (background task didn't start yet)
+            stream_gen = None
+            if run_type == "scan":
+                stream_gen = engine.run_scan(run_id, params)
+            elif run_type == "pipeline":
+                stream_gen = engine.run_pipeline(run_id, params)
+            elif run_type == "portfolio":
+                stream_gen = engine.run_portfolio(run_id, params)
+            elif run_type == "auto":
+                stream_gen = engine.run_auto(run_id, params)
+
+            if stream_gen:
+                run_info["status"] = "running"
+                run_info.setdefault("events", [])
+                try:
+                    async for payload in stream_gen:
+                        run_info["events"].append(payload)
+                        if "timestamp" not in payload:
+                            payload["timestamp"] = time.strftime("%H:%M:%S")
+                        await websocket.send_json(payload)
+                        logger.debug(
+                            "Sent event type=%s node=%s run=%s",
+                            payload.get("type"),
+                            payload.get("node_id"),
+                            run_id,
+                        )
+                    run_info["status"] = "completed"
+                except Exception as exc:
+                    run_info["status"] = "failed"
+                    run_info["error"] = str(exc)
+                    raise
+            else:
+                msg = f"Run type '{run_type}' streaming not yet implemented."
+                logger.warning(msg)
+                await websocket.send_json({"type": "system", "message": f"Error: {msg}"})
 
         await websocket.send_json({"type": "system", "message": "Run completed."})
         logger.info("Run completed run=%s type=%s", run_id, run_type)

View File

@ -1,11 +1,17 @@
import asyncio import asyncio
import datetime as _dt
import logging import logging
import time import time
from pathlib import Path
from typing import Dict, Any, AsyncGenerator from typing import Dict, Any, AsyncGenerator
from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.graph.scanner_graph import ScannerGraph from tradingagents.graph.scanner_graph import ScannerGraph
from tradingagents.graph.portfolio_graph import PortfolioGraph from tradingagents.graph.portfolio_graph import PortfolioGraph
from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.report_paths import get_market_dir, get_ticker_dir
from tradingagents.portfolio.report_store import ReportStore
from tradingagents.daily_digest import append_to_digest
from tradingagents.agents.utils.json_utils import extract_json
logger = logging.getLogger("agent_os.engine") logger = logging.getLogger("agent_os.engine")
@ -54,14 +60,75 @@ class LangGraphEngine:
} }
self._node_start_times[run_id] = {} self._node_start_times[run_id] = {}
final_state: Dict[str, Any] = {}
async for event in scanner.graph.astream_events(initial_state, version="v2"): async for event in scanner.graph.astream_events(initial_state, version="v2"):
# Capture the complete final state from the root graph's terminal event.
# LangGraph v2 emits one root-level on_chain_end (parent_ids=[], no
# langgraph_node in metadata) whose data.output is the full accumulated state.
if self._is_root_chain_end(event):
output = (event.get("data") or {}).get("output")
if isinstance(output, dict):
final_state = output
mapped = self._map_langgraph_event(run_id, event) mapped = self._map_langgraph_event(run_id, event)
if mapped: if mapped:
yield mapped yield mapped
self._node_start_times.pop(run_id, None) self._node_start_times.pop(run_id, None)
self._node_prompts.pop(run_id, None) self._node_prompts.pop(run_id, None)
# Save scan reports to disk
if final_state:
yield self._system_log("Saving scan reports to disk…")
try:
save_dir = get_market_dir(date)
save_dir.mkdir(parents=True, exist_ok=True)
for key in (
"geopolitical_report",
"market_movers_report",
"sector_performance_report",
"industry_deep_dive_report",
"macro_scan_summary",
):
content = final_state.get(key, "")
if content:
(save_dir / f"{key}.md").write_text(content)
# Parse and save macro_scan_summary.json via ReportStore for downstream use
summary_text = final_state.get("macro_scan_summary", "")
if summary_text:
try:
summary_data = extract_json(summary_text)
ReportStore().save_scan(date, summary_data)
except (ValueError, KeyError, TypeError):
logger.warning(
"macro_scan_summary for date=%s is not valid JSON "
"(summary already saved as .md — downstream loads may fail)",
date,
)
# Append to daily digest
scan_parts = []
for key, label in (
("geopolitical_report", "Geopolitical & Macro"),
("market_movers_report", "Market Movers"),
("sector_performance_report", "Sector Performance"),
("industry_deep_dive_report", "Industry Deep Dive"),
("macro_scan_summary", "Macro Scan Summary"),
):
content = final_state.get(key, "")
if content:
scan_parts.append(f"### {label}\n{content}")
if scan_parts:
append_to_digest(date, "scan", "Market Scan", "\n\n".join(scan_parts))
yield self._system_log(f"Scan reports saved to {save_dir}")
logger.info("Saved scan reports run=%s date=%s dir=%s", run_id, date, save_dir)
except Exception as exc:
logger.exception("Failed to save scan reports run=%s", run_id)
yield self._system_log(f"Warning: could not save scan reports: {exc}")
logger.info("Completed SCAN run=%s", run_id) logger.info("Completed SCAN run=%s", run_id)
async def run_pipeline( async def run_pipeline(
@ -86,18 +153,53 @@ class LangGraphEngine:
initial_state = graph_wrapper.propagator.create_initial_state(ticker, date) initial_state = graph_wrapper.propagator.create_initial_state(ticker, date)
self._node_start_times[run_id] = {} self._node_start_times[run_id] = {}
final_state: Dict[str, Any] = {}
async for event in graph_wrapper.graph.astream_events( async for event in graph_wrapper.graph.astream_events(
initial_state, initial_state,
version="v2", version="v2",
config={"recursion_limit": graph_wrapper.propagator.max_recur_limit}, config={"recursion_limit": graph_wrapper.propagator.max_recur_limit},
): ):
# Capture the complete final state from the root graph's terminal event.
if self._is_root_chain_end(event):
output = (event.get("data") or {}).get("output")
if isinstance(output, dict):
final_state = output
mapped = self._map_langgraph_event(run_id, event) mapped = self._map_langgraph_event(run_id, event)
if mapped: if mapped:
yield mapped yield mapped
self._node_start_times.pop(run_id, None) self._node_start_times.pop(run_id, None)
self._node_prompts.pop(run_id, None) self._node_prompts.pop(run_id, None)
# Save pipeline reports to disk
if final_state:
yield self._system_log(f"Saving analysis report for {ticker}")
try:
save_dir = get_ticker_dir(date, ticker)
save_dir.mkdir(parents=True, exist_ok=True)
# Save JSON via ReportStore (complete_report.json)
ReportStore().save_analysis(date, ticker, final_state)
# Write human-readable complete_report.md
self._write_complete_report_md(final_state, ticker, save_dir)
# Append to daily digest
digest_content = (
final_state.get("final_trade_decision")
or final_state.get("trader_investment_plan")
or ""
)
if digest_content:
append_to_digest(date, "analyze", ticker, digest_content)
yield self._system_log(f"Analysis report for {ticker} saved to {save_dir}")
logger.info("Saved pipeline report run=%s ticker=%s dir=%s", run_id, ticker, save_dir)
except Exception as exc:
logger.exception("Failed to save pipeline reports run=%s ticker=%s", run_id, ticker)
yield self._system_log(f"Warning: could not save analysis report for {ticker}: {exc}")
logger.info("Completed PIPELINE run=%s", run_id) logger.info("Completed PIPELINE run=%s", run_id)
async def run_portfolio( async def run_portfolio(
@ -117,9 +219,34 @@ class LangGraphEngine:
portfolio_graph = PortfolioGraph(config=self.config) portfolio_graph = PortfolioGraph(config=self.config)
# Load scan summary and per-ticker analyses from the daily report folder
store = ReportStore()
scan_summary = store.load_scan(date) or {}
ticker_analyses: Dict[str, Any] = {}
from tradingagents.report_paths import get_daily_dir
daily_dir = get_daily_dir(date)
if daily_dir.exists():
for ticker_dir in daily_dir.iterdir():
if ticker_dir.is_dir() and ticker_dir.name not in ("market", "portfolio"):
analysis = store.load_analysis(date, ticker_dir.name)
if analysis:
ticker_analyses[ticker_dir.name] = analysis
if scan_summary:
yield self._system_log(f"Loaded macro scan summary for {date}")
else:
yield self._system_log(f"No scan summary found for {date}, proceeding without it")
if ticker_analyses:
yield self._system_log(f"Loaded analyses for: {', '.join(sorted(ticker_analyses.keys()))}")
else:
yield self._system_log("No per-ticker analyses found for this date")
initial_state = { initial_state = {
"portfolio_id": portfolio_id, "portfolio_id": portfolio_id,
"scan_date": date, "scan_date": date,
"scan_summary": scan_summary,
"ticker_analyses": ticker_analyses,
"messages": [], "messages": [],
} }
@ -150,27 +277,131 @@ class LangGraphEngine:
async for evt in self.run_scan(f"{run_id}_scan", {"date": date}): async for evt in self.run_scan(f"{run_id}_scan", {"date": date}):
yield evt yield evt
# Phase 2: Pipeline analysis (default ticker for now) # Phase 2: Pipeline analysis — get tickers from saved scan report
ticker = params.get("ticker", "AAPL") yield self._system_log("Phase 2/3: Loading stocks from scan report…")
yield self._system_log(f"Phase 2/3: Running analysis pipeline for {ticker}") scan_data = ReportStore().load_scan(date)
async for evt in self.run_pipeline( tickers = self._extract_tickers_from_scan_data(scan_data)
f"{run_id}_pipeline", {"ticker": ticker, "date": date}
): if not tickers:
yield evt yield self._system_log(
"Warning: no stocks found in scan summary — ensure the scan completed "
"successfully and produced a 'stocks_to_investigate' list. "
"Skipping pipeline phase."
)
else:
for ticker in tickers:
yield self._system_log(f"Phase 2/3: Running analysis pipeline for {ticker}")
async for evt in self.run_pipeline(
f"{run_id}_pipeline_{ticker}", {"ticker": ticker, "date": date}
):
yield evt
# Phase 3: Portfolio management # Phase 3: Portfolio management
yield self._system_log("Phase 3/3: Running portfolio manager…") yield self._system_log("Phase 3/3: Running portfolio manager…")
portfolio_params = {k: v for k, v in params.items() if k != "ticker"}
async for evt in self.run_portfolio( async for evt in self.run_portfolio(
f"{run_id}_portfolio", {"date": date, **params} f"{run_id}_portfolio", {"date": date, **portfolio_params}
): ):
yield evt yield evt
logger.info("Completed AUTO run=%s", run_id) logger.info("Completed AUTO run=%s", run_id)
# ------------------------------------------------------------------
# Report helpers
# ------------------------------------------------------------------
@staticmethod
def _write_complete_report_md(
final_state: Dict[str, Any], ticker: str, save_dir: Path
) -> None:
"""Write a human-readable complete_report.md from the pipeline final state."""
sections = []
header = (
f"# Trading Analysis Report: {ticker}\n\n"
f"Generated: {_dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
)
analyst_parts = []
for key, label in (
("market_report", "Market Analyst"),
("sentiment_report", "Social Analyst"),
("news_report", "News Analyst"),
("fundamentals_report", "Fundamentals Analyst"),
):
if final_state.get(key):
analyst_parts.append(f"### {label}\n{final_state[key]}")
if analyst_parts:
sections.append("## I. Analyst Team Reports\n\n" + "\n\n".join(analyst_parts))
if final_state.get("investment_plan"):
sections.append(f"## II. Research Team Decision\n\n{final_state['investment_plan']}")
if final_state.get("trader_investment_plan"):
sections.append(f"## III. Trading Team Plan\n\n{final_state['trader_investment_plan']}")
if final_state.get("final_trade_decision"):
sections.append(f"## IV. Final Decision\n\n{final_state['final_trade_decision']}")
(save_dir / "complete_report.md").write_text(header + "\n\n".join(sections))
@staticmethod
def _extract_tickers_from_scan_data(scan_data: Dict[str, Any] | None) -> list[str]:
"""Extract ticker symbols from a ReportStore scan summary dict.
Handles two shapes from the macro synthesis LLM output:
* List of dicts: ``[{"ticker": "AAPL", ...}, ...]``
* List of strings: ``["AAPL", "TSLA", ...]``
Also checks both ``stocks_to_investigate`` and ``watchlist`` keys.
Returns an uppercase, deduplicated list in original order.
"""
if not scan_data:
return []
raw_stocks = (
scan_data.get("stocks_to_investigate")
or scan_data.get("watchlist")
or []
)
seen: set[str] = set()
tickers: list[str] = []
for item in raw_stocks:
if isinstance(item, dict):
sym = item.get("ticker") or item.get("symbol") or ""
elif isinstance(item, str):
sym = item
else:
continue
sym = sym.strip().upper()
if sym and sym not in seen:
seen.add(sym)
tickers.append(sym)
return tickers
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Event mapping # Event mapping
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@staticmethod
def _is_root_chain_end(event: Dict[str, Any]) -> bool:
"""Return True for the root-graph terminal event in a LangGraph v2 stream.
LangGraph v2 emits one ``on_chain_end`` event per node AND one for the
root graph itself. The root-graph event is distinguished by:
* ``event["metadata"]`` has no ``langgraph_node`` key (node events always do)
* ``event["parent_ids"]`` is empty (root has no parent run)
Its ``data["output"]`` contains the **complete** final state the
canonical way to read the propagated state without re-running the graph.
"""
if event.get("event") != "on_chain_end":
return False
metadata = event.get("metadata") or {}
if metadata.get("langgraph_node"):
return False # This is a node event, not the root
parent_ids = event.get("parent_ids")
return parent_ids is not None and len(parent_ids) == 0
@staticmethod @staticmethod
def _extract_node_name(event: Dict[str, Any]) -> str: def _extract_node_name(event: Dict[str, Any]) -> str:
"""Extract the LangGraph node name from event metadata or tags.""" """Extract the LangGraph node name from event metadata or tags."""

View File

@@ -168,6 +168,94 @@ class TestLangGraphEngineExtraction(unittest.TestCase):
         result = self.engine._map_langgraph_event("run_123", event)
         self.assertIsNone(result)
 
+    # ── _is_root_chain_end ──────────────────────────────────────────
+
+    def test_is_root_chain_end_true(self):
+        """Root graph terminal event: on_chain_end with empty parent_ids and no node."""
+        event = {
+            "event": "on_chain_end",
+            "name": "LangGraph",
+            "parent_ids": [],
+            "metadata": {},
+            "data": {"output": {"x": "final"}},
+        }
+        self.assertTrue(self.engine._is_root_chain_end(event))
+
+    def test_is_root_chain_end_false_for_node_event(self):
+        """Node on_chain_end should NOT be treated as the root graph end."""
+        event = {
+            "event": "on_chain_end",
+            "name": "geopolitical_scanner",
+            "parent_ids": ["some-parent-run-id"],
+            "metadata": {"langgraph_node": "geopolitical_scanner"},
+            "data": {"output": {"geopolitical_report": "..."}},
+        }
+        self.assertFalse(self.engine._is_root_chain_end(event))
+
+    def test_is_root_chain_end_false_for_non_chain_end(self):
+        """on_chain_start should never match."""
+        event = {
+            "event": "on_chain_start",
+            "name": "LangGraph",
+            "parent_ids": [],
+            "metadata": {},
+            "data": {},
+        }
+        self.assertFalse(self.engine._is_root_chain_end(event))
+
+    def test_is_root_chain_end_false_when_parent_ids_missing(self):
+        """If parent_ids is absent (unexpected), should not match."""
+        event = {
+            "event": "on_chain_end",
+            "name": "LangGraph",
+            "metadata": {},
+            "data": {"output": {"x": "v"}},
+        }
+        self.assertFalse(self.engine._is_root_chain_end(event))
+
+    def test_is_root_chain_end_false_when_langgraph_node_present(self):
+        """Node-level event with empty parent_ids should still not match."""
+        event = {
+            "event": "on_chain_end",
+            "name": "some_node",
+            "parent_ids": [],
+            "metadata": {"langgraph_node": "some_node"},
+            "data": {"output": {}},
+        }
+        self.assertFalse(self.engine._is_root_chain_end(event))
+
+    # ── _extract_tickers_from_scan_data ─────────────────────────────
+
+    def test_extract_tickers_list_of_dicts(self):
+        scan = {"stocks_to_investigate": [
+            {"ticker": "AAPL", "name": "Apple"},
+            {"ticker": "tsla", "sector": "EV"},
+        ]}
+        self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["AAPL", "TSLA"])
+
+    def test_extract_tickers_list_of_strings(self):
+        scan = {"watchlist": ["msft", "GOOG", "amzn"]}
+        self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["MSFT", "GOOG", "AMZN"])
+
+    def test_extract_tickers_prefers_stocks_to_investigate(self):
+        scan = {
+            "stocks_to_investigate": [{"ticker": "NVDA"}],
+            "watchlist": [{"ticker": "AMD"}],
+        }
+        self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["NVDA"])
+
+    def test_extract_tickers_deduplicates(self):
+        scan = {"stocks_to_investigate": ["AAPL", "aapl", "AAPL"]}
+        self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["AAPL"])
+
+    def test_extract_tickers_empty_scan(self):
+        self.assertEqual(self.engine._extract_tickers_from_scan_data(None), [])
+        self.assertEqual(self.engine._extract_tickers_from_scan_data({}), [])
+
+    def test_extract_tickers_symbol_key_fallback(self):
+        scan = {"stocks_to_investigate": [{"symbol": "META"}]}
+        self.assertEqual(self.engine._extract_tickers_from_scan_data(scan), ["META"])
+
 
 if __name__ == "__main__":
     unittest.main()
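
For completeness, a sketch of the auto-mode handoff these tests exercise: scan output is persisted as JSON, reloaded, and reduced to a ticker list. The flat-file round trip below is a hypothetical stand-in for `ReportStore.save_scan`/`load_scan`, whose real layout this diff does not show, and `extract_tickers` is a standalone copy of the `_extract_tickers_from_scan_data` contract:

import json
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional

def extract_tickers(scan_data: Optional[Dict[str, Any]]) -> List[str]:
    # Standalone copy of _extract_tickers_from_scan_data's contract.
    if not scan_data:
        return []
    raw = scan_data.get("stocks_to_investigate") or scan_data.get("watchlist") or []
    seen: set = set()
    out: List[str] = []
    for item in raw:
        if isinstance(item, dict):
            sym = item.get("ticker") or item.get("symbol") or ""
        elif isinstance(item, str):
            sym = item
        else:
            continue  # ignore anything that is neither dict nor string
        sym = sym.strip().upper()
        if sym and sym not in seen:
            seen.add(sym)
            out.append(sym)
    return out

# Hypothetical flat-file round trip standing in for ReportStore.save_scan/load_scan:
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "macro_scan_summary.json"
    path.write_text(json.dumps(
        {"stocks_to_investigate": [{"ticker": "nvda"}, "NVDA", {"symbol": "amd"}, 42]}
    ))
    loaded = json.loads(path.read_text())
    assert extract_tickers(loaded) == ["NVDA", "AMD"]  # deduped, uppercased, junk skipped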