class TokenCallbackHandler(BaseCallbackHandler):
    """Accumulate LLM token usage between snapshots.

    The handler is registered as a LangChain callback. Every finished LLM call
    adds its input/output token counts to two running totals; the caller
    invokes ``snapshot_and_reset()`` after each agent step to read the delta
    consumed by that step and zero the counters for the next one.

    All counter access is guarded by a lock because ``on_llm_end`` may be
    invoked from several threads.
    """

    def __init__(self) -> None:
        super().__init__()
        self._lock = threading.Lock()
        self._tokens_in = 0
        self._tokens_out = 0

    @staticmethod
    def _extract_usage(response: LLMResult) -> tuple[int, int]:
        """Return ``(input_tokens, output_tokens)`` reported for one LLM call.

        The counts attached to the first generation's ``AIMessage`` are
        preferred. Only the first candidate is inspected so that a response
        carrying several candidates is never counted more than once. When the
        message carries no usage, fall back to the ``token_usage`` entry of
        ``response.llm_output`` (``prompt_tokens`` / ``completion_tokens``),
        which is where some providers report it instead. Anything missing or
        malformed counts as zero rather than raising inside a callback.
        """
        try:
            generation = response.generations[0][0]
        except (IndexError, TypeError):
            generation = None

        message = getattr(generation, "message", None)
        if isinstance(message, AIMessage):
            usage = getattr(message, "usage_metadata", None)
            if usage:
                return (
                    usage.get("input_tokens") or 0,
                    usage.get("output_tokens") or 0,
                )

        llm_output = getattr(response, "llm_output", None)
        legacy = llm_output.get("token_usage") if isinstance(llm_output, dict) else None
        if isinstance(legacy, dict):
            return (
                legacy.get("prompt_tokens") or 0,
                legacy.get("completion_tokens") or 0,
            )
        return (0, 0)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Add the token usage of a finished LLM call to the running totals."""
        tokens_in, tokens_out = self._extract_usage(response)
        if not (tokens_in or tokens_out):
            return
        with self._lock:
            self._tokens_in += tokens_in
            self._tokens_out += tokens_out

    def snapshot_and_reset(self) -> dict[str, int]:
        """Return {"in": N, "out": M} for the current period and zero counters."""
        with self._lock:
            result = {"in": self._tokens_in, "out": self._tokens_out}
            self._tokens_in = 0
            self._tokens_out = 0
        return result
Field(default_factory=dict) error: Optional[str] = None + token_usage: dict[str, TokenUsage] = Field(default_factory=dict) diff --git a/api/routers/runs.py b/api/routers/runs.py index 7d799ef3..01f47b43 100644 --- a/api/routers/runs.py +++ b/api/routers/runs.py @@ -1,12 +1,19 @@ import json +import pathlib from fastapi import APIRouter, HTTPException from fastapi.responses import StreamingResponse from api.models.run import RunConfig, RunResult, RunSummary from api.services.run_service import RunService from api.store.runs_store import RunsStore +try: + from tradingagents.default_config import DEFAULT_CONFIG +except ImportError: + DEFAULT_CONFIG = {"results_dir": "./results"} + router = APIRouter() -_store = RunsStore() +_db_path = pathlib.Path(DEFAULT_CONFIG["results_dir"]) / "runs.sqlite" +_store = RunsStore(_db_path) _service = RunService(_store) diff --git a/api/services/run_service.py b/api/services/run_service.py index 6c94375d..d31e4600 100644 --- a/api/services/run_service.py +++ b/api/services/run_service.py @@ -3,6 +3,7 @@ from collections import defaultdict from typing import Generator from api.store.runs_store import RunsStore from api.models.run import RunConfig, RunStatus +from api.callbacks.token_handler import TokenCallbackHandler logger = logging.getLogger(__name__) @@ -24,6 +25,37 @@ class RunService: yield {"event": "run:error", "data": {"message": "Run not found"}} return + if run.status == RunStatus.COMPLETE: + token_usage = {k: v.model_dump() for k, v in (run.token_usage or {}).items()} + for key, report in run.reports.items(): + if ":" not in key: + logger.warning( + "Skipping malformed report key %r for run %s", key, run_id + ) + continue + step_key, turn_str = key.rsplit(":", 1) + if not turn_str.isdigit(): + logger.warning( + "Skipping report key with non-numeric turn %r for run %s", key, run_id + ) + continue + turn = int(turn_str) + raw = token_usage.get(key, {"tokens_in": 0, "tokens_out": 0}) + yield {"event": "agent:start", "data": 
{"step": step_key, "turn": turn}} + yield {"event": "agent:complete", "data": { + "step": step_key, "turn": turn, "report": report, + "tokens_in": raw.get("tokens_in", 0), + "tokens_out": raw.get("tokens_out", 0), + }} + yield {"event": "run:complete", "data": {"decision": run.decision or "HOLD", "run_id": run_id}} + return + + if run.status == RunStatus.RUNNING: + yield {"event": "run:error", "data": {"message": "Run is already in progress"}} + return + + self._store.clear_reports(run_id) + self._store.clear_token_usage(run_id) self._store.update_status(run_id, RunStatus.RUNNING) config = run.config @@ -35,20 +67,31 @@ class RunService: ta_config["max_risk_discuss_rounds"] = config.max_risk_discuss_rounds try: + token_handler = TokenCallbackHandler() ta = TradingAgentsGraph( debug=False, config=ta_config, selected_analysts=config.enabled_analysts or ["market", "news", "fundamentals", "social"], + callbacks=[token_handler], ) turn_counts: defaultdict[str, int] = defaultdict(int) for step_key, report in ta.stream_propagate(config.ticker, config.date): + tokens = token_handler.snapshot_and_reset() turn = turn_counts[step_key] yield {"event": "agent:start", "data": {"step": step_key, "turn": turn}} - yield {"event": "agent:complete", "data": {"step": step_key, "turn": turn, "report": report}} + yield {"event": "agent:complete", "data": { + "step": step_key, "turn": turn, "report": report, + "tokens_in": tokens["in"], "tokens_out": tokens["out"], + }} self._store.add_report(run_id, f"{step_key}:{turn}", report) + # Normalize to TokenUsage field names before persisting + self._store.add_token_usage( + run_id, f"{step_key}:{turn}", + {"tokens_in": tokens["in"], "tokens_out": tokens["out"]}, + ) turn_counts[step_key] += 1 decision = ta._last_decision or "HOLD" diff --git a/api/store/runs_store.py b/api/store/runs_store.py index d6509c53..4fedb964 100644 --- a/api/store/runs_store.py +++ b/api/store/runs_store.py @@ -1,63 +1,165 @@ +import json +import sqlite3 import uuid 
_CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS runs (
    id TEXT PRIMARY KEY,
    ticker TEXT NOT NULL,
    date TEXT NOT NULL,
    status TEXT NOT NULL DEFAULT 'queued',
    decision TEXT,
    created_at TEXT NOT NULL,
    config TEXT,
    reports TEXT NOT NULL DEFAULT '{}',
    error TEXT,
    token_usage TEXT NOT NULL DEFAULT '{}'
)
"""


class RunsStore:
    """Persist runs in a single SQLite table.

    One connection is shared by every caller (it is opened with
    ``check_same_thread=False``), so all access is serialised through
    ``self._lock``. Every write runs inside the connection's context manager,
    which commits on success and rolls back if a statement raises, so a failed
    write never leaves a half-open transaction on the shared connection.

    ``reports`` and ``token_usage`` are stored as JSON objects in TEXT columns.
    """

    # Columns holding a JSON object that may be updated key by key. Used to
    # validate the column name before it is interpolated into SQL.
    _JSON_COLUMNS = frozenset({"reports", "token_usage"})

    def __init__(self, db_path: Path) -> None:
        """Open (creating if needed) the database at *db_path*.

        *db_path* may also be a plain string. The parent directory is created
        when missing. A database created before the ``token_usage`` column
        existed is migrated in place.
        """
        db_path = Path(db_path)
        db_path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = Lock()
        self._conn = sqlite3.connect(str(db_path), check_same_thread=False)
        try:
            self._conn.row_factory = sqlite3.Row
            self._conn.execute("PRAGMA journal_mode=WAL")
            with self._conn:
                self._conn.execute(_CREATE_TABLE_SQL)
                # Migration: add token_usage column if not present
                # (handles databases created by an older version).
                existing_cols = {
                    row["name"]
                    for row in self._conn.execute("PRAGMA table_info(runs)")
                }
                if "token_usage" not in existing_cols:
                    self._conn.execute(
                        "ALTER TABLE runs ADD COLUMN token_usage"
                        " TEXT NOT NULL DEFAULT '{}'"
                    )
        except Exception:
            # Do not leak the connection when schema setup fails.
            self._conn.close()
            raise

    def close(self) -> None:
        """Close the underlying connection; the store is unusable afterwards."""
        with self._lock:
            self._conn.close()

    def _write(self, sql: str, params: tuple) -> None:
        """Run one write statement under the lock, committing or rolling back."""
        with self._lock, self._conn:
            self._conn.execute(sql, params)

    def _merge_json(self, column: str, run_id: str, key: str, value) -> None:
        """Set ``key`` to ``value`` inside the JSON object stored in *column*.

        Does nothing when *run_id* does not exist.
        """
        if column not in self._JSON_COLUMNS:
            raise ValueError(f"not a JSON column: {column!r}")
        with self._lock, self._conn:
            row = self._conn.execute(
                f"SELECT {column} FROM runs WHERE id = ?", (run_id,)
            ).fetchone()
            if row is None:
                return
            merged = json.loads(row[0] or "{}")
            merged[key] = value
            self._conn.execute(
                f"UPDATE runs SET {column} = ? WHERE id = ?",
                (json.dumps(merged), run_id),
            )

    def _row_to_run(self, row: sqlite3.Row) -> "RunResult":
        """Build a ``RunResult`` from one row of the ``runs`` table."""
        return RunResult(
            id=row["id"],
            ticker=row["ticker"],
            date=row["date"],
            status=RunStatus(row["status"]),
            decision=row["decision"],
            created_at=row["created_at"],
            config=RunConfig(**json.loads(row["config"])) if row["config"] else None,
            reports=json.loads(row["reports"]),
            error=row["error"],
            token_usage={
                k: TokenUsage(**v)
                for k, v in json.loads(row["token_usage"] or "{}").items()
            },
        )

    def create(self, config: "RunConfig") -> "RunResult":
        """Insert a new queued run for *config* and return it."""
        run_id = str(uuid.uuid4())[:8]
        now = datetime.now(timezone.utc).isoformat()
        self._write(
            "INSERT INTO runs (id, ticker, date, status, created_at, config)"
            " VALUES (?, ?, ?, ?, ?, ?)",
            (
                run_id,
                config.ticker,
                config.date,
                RunStatus.QUEUED.value,
                now,
                config.model_dump_json(),
            ),
        )
        return RunResult(
            id=run_id,
            ticker=config.ticker,
            date=config.date,
            status=RunStatus.QUEUED,
            created_at=now,
            config=config,
        )

    def get(self, run_id: str) -> Optional["RunResult"]:
        """Return the run with *run_id*, or ``None`` when it does not exist."""
        with self._lock:
            row = self._conn.execute(
                "SELECT * FROM runs WHERE id = ?", (run_id,)
            ).fetchone()
        return self._row_to_run(row) if row else None

    def list_all(self) -> list["RunResult"]:
        """Return every run, newest ``created_at`` first."""
        with self._lock:
            rows = self._conn.execute(
                "SELECT * FROM runs ORDER BY created_at DESC"
            ).fetchall()
        return [self._row_to_run(row) for row in rows]

    def update_status(self, run_id: str, status: "RunStatus") -> None:
        """Store ``status.value`` as the run's status (no-op for unknown ids)."""
        self._write(
            "UPDATE runs SET status = ? WHERE id = ?",
            (status.value, run_id),
        )

    def update_decision(
        self, run_id: str, decision: Literal["BUY", "SELL", "HOLD"]
    ) -> None:
        """Store the final decision for the run (no-op for unknown ids)."""
        self._write(
            "UPDATE runs SET decision = ? WHERE id = ?",
            (decision, run_id),
        )

    def add_report(self, run_id: str, step: str, report: str) -> None:
        """Record *report* under key *step* in the run's reports."""
        self._merge_json("reports", run_id, step, report)

    def set_error(self, run_id: str, error: str) -> None:
        """Mark the run as errored and store the error message."""
        self._write(
            "UPDATE runs SET status = ?, error = ? WHERE id = ?",
            (RunStatus.ERROR.value, error, run_id),
        )

    def clear_reports(self, run_id: str) -> None:
        """Reset the run's reports to an empty object."""
        self._write("UPDATE runs SET reports = '{}' WHERE id = ?", (run_id,))

    def add_token_usage(self, run_id: str, key: str, tokens: dict) -> None:
        """Record *tokens* under *key* in the run's token usage."""
        self._merge_json("token_usage", run_id, key, tokens)

    def clear_token_usage(self, run_id: str) -> None:
        """Reset the run's token usage to an empty object."""
        self._write("UPDATE runs SET token_usage = '{}' WHERE id = ?", (run_id,))
+ + assert MockGraph.call_count == 0 # no agent execution + assert len(events) == 1 + assert events[0]["event"] == "run:error" + assert "already in progress" in events[0]["data"]["message"] + + +def test_error_run_retries_and_clears_reports(service, store): + # Simulate a run that failed partway through with a stale report + run = store.create(RunConfig(ticker="NVDA", date="2026-03-23")) + store.update_status(run.id, RunStatus.RUNNING) + store.add_report(run.id, "market_analyst:0", "stale data") + store.set_error(run.id, "timeout") + + assert store.get(run.id).status == RunStatus.ERROR + assert store.get(run.id).reports == {"market_analyst:0": "stale data"} + + with patch("api.services.run_service.TradingAgentsGraph") as MockGraph: + MockGraph.return_value = _mock_graph([("news_analyst", "fresh")], decision="HOLD") + events = list(service.stream_events(run.id)) + + assert MockGraph.call_count == 1 # graph was executed on retry + final_reports = store.get(run.id).reports + assert "news_analyst:0" in final_reports + assert "market_analyst:0" not in final_reports # stale data cleared on retry + + +def test_live_run_agent_complete_includes_token_fields(service, store): + """agent:complete events must include tokens_in and tokens_out.""" + config = RunConfig(ticker="NVDA", date="2026-03-23") + run = store.create(config) + with patch("api.services.run_service.TradingAgentsGraph") as MockGraph: + MockGraph.return_value = _mock_graph([("market_analyst", "bullish")]) + events = list(service.stream_events(run.id)) + + complete_events = [e for e in events if e["event"] == "agent:complete"] + assert len(complete_events) == 1 + data = complete_events[0]["data"] + assert "tokens_in" in data + assert "tokens_out" in data + assert isinstance(data["tokens_in"], int) + assert isinstance(data["tokens_out"], int) + + +def test_retry_clears_stale_token_usage(service, store): + """After an errored run is retried, stale token keys are absent.""" + run = 
store.create(RunConfig(ticker="NVDA", date="2026-03-23")) + store.update_status(run.id, RunStatus.RUNNING) + store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 500, "tokens_out": 200}) + store.set_error(run.id, "timeout") + + with patch("api.services.run_service.TradingAgentsGraph") as MockGraph: + MockGraph.return_value = _mock_graph([("news_analyst", "fresh")]) + list(service.stream_events(run.id)) + + final = store.get(run.id) + assert "market_analyst:0" not in final.token_usage # stale key gone + assert "news_analyst:0" in final.token_usage # new key present + + +def test_replay_attaches_token_data_to_agent_complete(service, store): + """Replaying a completed run emits agent:complete with token data.""" + run = store.create(RunConfig(ticker="NVDA", date="2026-03-23")) + store.update_status(run.id, RunStatus.RUNNING) + store.add_report(run.id, "market_analyst:0", "bullish") + store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 1200, "tokens_out": 400}) + store.update_status(run.id, RunStatus.COMPLETE) + store.update_decision(run.id, "BUY") + + with patch("api.services.run_service.TradingAgentsGraph") as MockGraph: + events = list(service.stream_events(run.id)) + + assert MockGraph.call_count == 0 + complete_events = [e for e in events if e["event"] == "agent:complete"] + assert len(complete_events) == 1 + assert complete_events[0]["data"]["tokens_in"] == 1200 + assert complete_events[0]["data"]["tokens_out"] == 400 + + +def test_replay_defaults_to_zero_tokens_when_missing(service, store): + """Replay of a run with no token_usage emits tokens_in=0, tokens_out=0.""" + run = store.create(RunConfig(ticker="NVDA", date="2026-03-23")) + store.update_status(run.id, RunStatus.RUNNING) + store.add_report(run.id, "market_analyst:0", "bearish") + store.update_status(run.id, RunStatus.COMPLETE) + store.update_decision(run.id, "SELL") + # No add_token_usage call — simulates old run with no token data + + with 
patch("api.services.run_service.TradingAgentsGraph"): + events = list(service.stream_events(run.id)) + + complete = next(e for e in events if e["event"] == "agent:complete") + assert complete["data"]["tokens_in"] == 0 + assert complete["data"]["tokens_out"] == 0 diff --git a/tests/api/test_store.py b/tests/api/test_store.py index a8607d57..7eb616af 100644 --- a/tests/api/test_store.py +++ b/tests/api/test_store.py @@ -1,10 +1,11 @@ import pytest +import sqlite3 as _sqlite3 from api.store.runs_store import RunsStore from api.models.run import RunConfig, RunStatus -def test_create_and_get_run(): - store = RunsStore() +def test_create_and_get_run(tmp_path): + store = RunsStore(tmp_path / "test.sqlite") config = RunConfig(ticker="NVDA", date="2024-05-10") run = store.create(config) assert run.id is not None @@ -13,16 +14,108 @@ def test_create_and_get_run(): assert fetched.ticker == "NVDA" -def test_list_runs(): - store = RunsStore() +def test_list_runs(tmp_path): + store = RunsStore(tmp_path / "test.sqlite") store.create(RunConfig(ticker="NVDA", date="2024-05-10")) store.create(RunConfig(ticker="AAPL", date="2024-05-09")) runs = store.list_all() assert len(runs) == 2 -def test_update_run_status(): - store = RunsStore() +def test_update_run_status(tmp_path): + store = RunsStore(tmp_path / "test.sqlite") run = store.create(RunConfig(ticker="NVDA", date="2024-05-10")) store.update_status(run.id, RunStatus.RUNNING) assert store.get(run.id).status == RunStatus.RUNNING + + +def test_add_report(tmp_path): + store = RunsStore(tmp_path / "test.sqlite") + run = store.create(RunConfig(ticker="NVDA", date="2024-05-10")) + store.add_report(run.id, "market_analyst:0", "bullish") + assert store.get(run.id).reports == {"market_analyst:0": "bullish"} + + +def test_set_error(tmp_path): + store = RunsStore(tmp_path / "test.sqlite") + run = store.create(RunConfig(ticker="NVDA", date="2024-05-10")) + store.set_error(run.id, "timeout") + fetched = store.get(run.id) + assert fetched.status 
== RunStatus.ERROR + assert fetched.error == "timeout" + + +def test_clear_reports(tmp_path): + store = RunsStore(tmp_path / "test.sqlite") + run = store.create(RunConfig(ticker="NVDA", date="2024-05-10")) + store.add_report(run.id, "market_analyst:0", "bullish") + store.clear_reports(run.id) + assert store.get(run.id).reports == {} + + +def test_update_decision(tmp_path): + store = RunsStore(tmp_path / "test.sqlite") + run = store.create(RunConfig(ticker="NVDA", date="2024-05-10")) + store.update_decision(run.id, "BUY") + assert store.get(run.id).decision == "BUY" + + +def test_migration_adds_token_usage_column_to_existing_db(tmp_path): + """DB created without token_usage gets the column added on RunsStore.__init__.""" + db_path = tmp_path / "old.sqlite" + # Create a DB without the token_usage column (simulate pre-migration state) + conn = _sqlite3.connect(str(db_path)) + conn.execute(""" + CREATE TABLE runs ( + id TEXT PRIMARY KEY, ticker TEXT NOT NULL, date TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'queued', decision TEXT, + created_at TEXT NOT NULL, config TEXT, + reports TEXT NOT NULL DEFAULT '{}', error TEXT + ) + """) + conn.commit() + conn.close() + + # Initialising the store should migrate the column + store = RunsStore(db_path) + cols = { + row["name"] + for row in store._conn.execute("PRAGMA table_info(runs)") + } + assert "token_usage" in cols + + +def test_migration_is_idempotent(tmp_path): + """Re-initialising the store after migration does not crash.""" + db_path = tmp_path / "test.sqlite" + RunsStore(db_path) # first init — creates table + column + RunsStore(db_path) # second init — column already present, should not crash + + +def test_add_token_usage(tmp_path): + store = RunsStore(tmp_path / "test.sqlite") + run = store.create(RunConfig(ticker="NVDA", date="2026-03-23")) + store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 1200, "tokens_out": 400}) + result = store.get(run.id) + assert "market_analyst:0" in result.token_usage + 
assert result.token_usage["market_analyst:0"].tokens_in == 1200 + assert result.token_usage["market_analyst:0"].tokens_out == 400 + + +def test_clear_token_usage(tmp_path): + store = RunsStore(tmp_path / "test.sqlite") + run = store.create(RunConfig(ticker="NVDA", date="2026-03-23")) + store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 1200, "tokens_out": 400}) + store.clear_token_usage(run.id) + assert store.get(run.id).token_usage == {} + + +def test_clear_token_usage_does_not_affect_reports(tmp_path): + store = RunsStore(tmp_path / "test.sqlite") + run = store.create(RunConfig(ticker="NVDA", date="2026-03-23")) + store.add_report(run.id, "market_analyst:0", "bullish") + store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 100, "tokens_out": 50}) + store.clear_token_usage(run.id) + fetched = store.get(run.id) + assert fetched.reports == {"market_analyst:0": "bullish"} + assert fetched.token_usage == {} diff --git a/tests/api/test_token_handler.py b/tests/api/test_token_handler.py new file mode 100644 index 00000000..62cf497d --- /dev/null +++ b/tests/api/test_token_handler.py @@ -0,0 +1,74 @@ +import threading +from unittest.mock import MagicMock +from langchain_core.messages import AIMessage +from langchain_core.outputs import LLMResult, ChatGeneration +from api.callbacks.token_handler import TokenCallbackHandler + + +def _make_response(input_tokens: int, output_tokens: int) -> LLMResult: + """Build a minimal LLMResult that on_llm_end will parse.""" + msg = AIMessage(content="ok") + msg.usage_metadata = {"input_tokens": input_tokens, "output_tokens": output_tokens} + gen = ChatGeneration(message=msg) + return LLMResult(generations=[[gen]]) + + +def test_snapshot_and_reset_returns_delta(): + handler = TokenCallbackHandler() + handler.on_llm_end(_make_response(100, 40)) + result = handler.snapshot_and_reset() + assert result == {"in": 100, "out": 40} + + +def test_snapshot_and_reset_zeroes_counters(): + handler = TokenCallbackHandler() + 
handler.on_llm_end(_make_response(100, 40)) + handler.snapshot_and_reset() + second = handler.snapshot_and_reset() + assert second == {"in": 0, "out": 0} + + +def test_multiple_llm_calls_accumulate(): + handler = TokenCallbackHandler() + handler.on_llm_end(_make_response(100, 40)) + handler.on_llm_end(_make_response(200, 60)) + result = handler.snapshot_and_reset() + assert result == {"in": 300, "out": 100} + + +def test_concurrent_on_llm_end_does_not_corrupt(): + handler = TokenCallbackHandler() + threads = [ + threading.Thread(target=handler.on_llm_end, args=(_make_response(10, 5),)) + for _ in range(20) + ] + for t in threads: + t.start() + for t in threads: + t.join() + result = handler.snapshot_and_reset() + assert result == {"in": 200, "out": 100} + + +def test_missing_usage_metadata_does_not_crash(): + handler = TokenCallbackHandler() + msg = AIMessage(content="no metadata") + # No usage_metadata attribute + gen = ChatGeneration(message=msg) + response = LLMResult(generations=[[gen]]) + handler.on_llm_end(response) # should not raise + assert handler.snapshot_and_reset() == {"in": 0, "out": 0} + + +def test_empty_outer_generations_does_not_crash(): + handler = TokenCallbackHandler() + response = LLMResult(generations=[]) + handler.on_llm_end(response) # IndexError guard + assert handler.snapshot_and_reset() == {"in": 0, "out": 0} + + +def test_empty_inner_generations_does_not_crash(): + handler = TokenCallbackHandler() + response = LLMResult(generations=[[]]) + handler.on_llm_end(response) # IndexError guard + assert handler.snapshot_and_reset() == {"in": 0, "out": 0} diff --git a/ui/__tests__/features/run-detail/TokenStatsBar.test.tsx b/ui/__tests__/features/run-detail/TokenStatsBar.test.tsx new file mode 100644 index 00000000..1605afb2 --- /dev/null +++ b/ui/__tests__/features/run-detail/TokenStatsBar.test.tsx @@ -0,0 +1,30 @@ +import { render, screen } from '@testing-library/react' +import TokenStatsBar from 
'@/features/run-detail/components/TokenStatsBar' + +test('renders nothing when both token counts are zero', () => { + const { container } = render( + + ) + expect(container.firstChild).toBeNull() +}) + +test('renders total, input, output when tokens are non-zero', () => { + render( + + ) + // Total = 3100 + 1100 = 4200 → "4.2k" + expect(screen.getByText('4.2k')).toBeInTheDocument() + // Input → "3.1k" + expect(screen.getByText('3.1k')).toBeInTheDocument() + // Output → "1.1k" + expect(screen.getByText('1.1k')).toBeInTheDocument() +}) + +test('formats values below 1000 without k suffix', () => { + render( + + ) + expect(screen.getByText('950')).toBeInTheDocument() // total + expect(screen.getByText('800')).toBeInTheDocument() // input + expect(screen.getByText('150')).toBeInTheDocument() // output +}) diff --git a/ui/__tests__/features/run-detail/useRunStream.test.ts b/ui/__tests__/features/run-detail/useRunStream.test.ts index 3edc2898..b7b62334 100644 --- a/ui/__tests__/features/run-detail/useRunStream.test.ts +++ b/ui/__tests__/features/run-detail/useRunStream.test.ts @@ -16,8 +16,21 @@ jest.mock('@/lib/sse', () => ({ }), })) +// getRun defaults to 'queued' so existing SSE-path tests still exercise the SSE branch. +// Tests that need a different status use mockResolvedValueOnce to override. 
jest.mock('@/lib/api-client', () => ({ getRunStreamUrl: (id: string) => `/api/runs/${id}/stream`, + getRun: jest.fn().mockResolvedValue({ + id: 'abc', + ticker: 'NVDA', + date: '2026-03-23', + status: 'queued', + decision: null, + created_at: '2026-03-23T00:00:00Z', + config: null, + reports: {}, + error: null, + }), })) test('appends multiple turns for same step', async () => { @@ -46,3 +59,118 @@ test('initial reports are empty arrays', () => { const { result } = renderHook(() => useRunStream('abc')) expect(result.current.reports['market_analyst']).toEqual([]) }) + +test('hydrates from reports when run is complete, skipping SSE', async () => { + const { getRun } = jest.requireMock('@/lib/api-client') + const { createSSEConnection } = jest.requireMock('@/lib/sse') + + // Reset call history so we can assert createSSEConnection was NOT called for this run + jest.clearAllMocks() + + getRun.mockResolvedValueOnce({ + id: 'xyz', + ticker: 'AAPL', + date: '2026-03-23', + status: 'complete', + decision: 'SELL', + created_at: '2026-03-23T00:00:00Z', + config: null, + reports: { 'market_analyst:0': 'bearish signal' }, + error: null, + }) + + const { result } = renderHook(() => useRunStream('xyz')) + await act(async () => { await new Promise((r) => setTimeout(r, 10)) }) + + expect(result.current.status).toBe('complete') + expect(result.current.verdict).toBe('SELL') + expect(result.current.reports['market_analyst']).toEqual(['bearish signal']) + expect(createSSEConnection).not.toHaveBeenCalled() // SSE skipped for completed run +}) + +test('AGENT_COMPLETE accumulates tokensByStep and tokensTotal', async () => { + const { getRun } = jest.requireMock('@/lib/api-client') + const { createSSEConnection } = jest.requireMock('@/lib/sse') + jest.clearAllMocks() + + // getRun returns queued so SSE path runs + getRun.mockResolvedValueOnce({ + id: 'abc', ticker: 'NVDA', date: '2026-03-23', status: 'queued', + decision: null, created_at: '2026-03-23T00:00:00Z', + config: null, reports: 
{}, error: null, token_usage: null, + }) + + // SSE mock emits one agent:complete with token data + createSSEConnection.mockImplementationOnce( + (_url: string, handlers: Record void>) => { + setTimeout(() => { + handlers.onAgentStart?.({ step: 'market_analyst', turn: 0 }) + handlers.onAgentComplete?.({ + step: 'market_analyst', turn: 0, report: 'bullish', + tokens_in: 1200, tokens_out: 400, + }) + handlers.onRunComplete?.({ decision: 'BUY', run_id: 'abc' }) + }, 0) + return jest.fn() + } + ) + + const { result } = renderHook(() => useRunStream('abc')) + await act(async () => { await new Promise((r) => setTimeout(r, 10)) }) + + expect(result.current.tokensByStep['market_analyst']).toEqual({ in: 1200, out: 400 }) + expect(result.current.tokensTotal).toEqual({ in: 1200, out: 400 }) +}) + +test('missing tokens_in/out in AGENT_COMPLETE defaults to 0', async () => { + const { getRun } = jest.requireMock('@/lib/api-client') + const { createSSEConnection } = jest.requireMock('@/lib/sse') + jest.clearAllMocks() + + getRun.mockResolvedValueOnce({ + id: 'abc', ticker: 'NVDA', date: '2026-03-23', status: 'queued', + decision: null, created_at: '2026-03-23T00:00:00Z', + config: null, reports: {}, error: null, token_usage: null, + }) + + createSSEConnection.mockImplementationOnce( + (_url: string, handlers: Record void>) => { + setTimeout(() => { + handlers.onAgentStart?.({ step: 'news_analyst', turn: 0 }) + // No tokens_in/tokens_out in payload + handlers.onAgentComplete?.({ step: 'news_analyst', turn: 0, report: 'ok' }) + handlers.onRunComplete?.({ decision: 'HOLD', run_id: 'abc' }) + }, 0) + return jest.fn() + } + ) + + const { result } = renderHook(() => useRunStream('abc')) + await act(async () => { await new Promise((r) => setTimeout(r, 10)) }) + + expect(result.current.tokensByStep['news_analyst']).toEqual({ in: 0, out: 0 }) + expect(result.current.tokensTotal).toEqual({ in: 0, out: 0 }) +}) + +test('completed-run hydration populates tokens from getRun().token_usage 
without SSE', async () => { + const { getRun } = jest.requireMock('@/lib/api-client') + const { createSSEConnection } = jest.requireMock('@/lib/sse') + jest.clearAllMocks() + + getRun.mockResolvedValueOnce({ + id: 'tok', ticker: 'AAPL', date: '2026-03-23', status: 'complete', + decision: 'BUY', created_at: '2026-03-23T00:00:00Z', + config: null, + reports: { 'market_analyst:0': 'bullish' }, + error: null, + token_usage: { 'market_analyst:0': { tokens_in: 1200, tokens_out: 400 } }, + }) + + const { result } = renderHook(() => useRunStream('tok')) + await act(async () => { await new Promise((r) => setTimeout(r, 10)) }) + + expect(createSSEConnection).not.toHaveBeenCalled() + expect(result.current.status).toBe('complete') + expect(result.current.tokensByStep['market_analyst']).toEqual({ in: 1200, out: 400 }) + expect(result.current.tokensTotal).toEqual({ in: 1200, out: 400 }) +}) diff --git a/ui/app/(dashboard)/runs/[id]/page.tsx b/ui/app/(dashboard)/runs/[id]/page.tsx index 4ea8b36e..35722012 100644 --- a/ui/app/(dashboard)/runs/[id]/page.tsx +++ b/ui/app/(dashboard)/runs/[id]/page.tsx @@ -6,6 +6,7 @@ import VerdictBanner from '@/features/run-detail/components/VerdictBanner' import PhaseTabs from '@/features/run-detail/components/PhaseTabs' import { getRun } from '@/lib/api-client' import type { RunSummary } from '@/lib/types/run' +import TokenStatsBar from '@/features/run-detail/components/TokenStatsBar' const STATUS_CONFIG: Record }) { const { id } = use(params) - const { steps, reports, verdict, status, error } = useRunStream(id) + const { steps, reports, verdict, status, error, tokensTotal, tokensByStep } = useRunStream(id) const [run, setRun] = useState(null) useEffect(() => { @@ -88,6 +89,9 @@ export default function RunDetailPage({ params }: { params: Promise<{ id: string + {/* Token stats bar */} + + {/* Pipeline */} @@ -111,7 +115,7 @@ export default function RunDetailPage({ params }: { params: Promise<{ id: string )} {/* Phase tabs + reports */} - + ) } 
diff --git a/ui/features/run-detail/components/AnalystReports.tsx b/ui/features/run-detail/components/AnalystReports.tsx index 8a9575ef..adff6dc4 100644 --- a/ui/features/run-detail/components/AnalystReports.tsx +++ b/ui/features/run-detail/components/AnalystReports.tsx @@ -2,6 +2,10 @@ import { AGENT_STEPS, AGENT_STEP_LABELS, STEP_PHASE } from '@/lib/types/run' import type { AgentStep } from '@/lib/types/run' import type { StepStatus } from '@/lib/types/agents' +function formatTokens(n: number): string { + return n >= 1000 ? `${(n / 1000).toFixed(1)}k` : String(n) +} + type Phase = 'analysts' | 'researchers' | 'trader' | 'risk' const MULTI_TURN_STEPS = new Set([ @@ -43,9 +47,10 @@ type Props = { phase: Phase steps: Record reports: Record + tokensByStep: Record } -export default function AnalystReports({ phase, steps, reports }: Props) { +export default function AnalystReports({ phase, steps, reports, tokensByStep }: Props) { const phaseSteps = AGENT_STEPS.filter((s) => STEP_PHASE[s] === phase) return ( @@ -112,6 +117,22 @@ export default function AnalystReports({ phase, steps, reports }: Props) { T{i + 1} )} + {(() => { + const tok = tokensByStep[step] + return tok && (tok.in > 0 || tok.out > 0) ? ( + + {formatTokens(tok.in)}↑{' '} + {formatTokens(tok.out)}↓ + + ) : null + })()}
- reports: Record + steps: Record + reports: Record + tokensByStep: Record } function getPhaseCompletion(phase: Phase, steps: Record): number { @@ -32,7 +33,7 @@ function getPhaseStatus(phase: Phase, steps: Record): 'do return 'pending' } -export default function PhaseTabs({ steps, reports }: Props) { +export default function PhaseTabs({ steps, reports, tokensByStep }: Props) { const [active, setActive] = useState('analysts') return ( @@ -118,7 +119,7 @@ export default function PhaseTabs({ steps, reports }: Props) {
{/* Reports */} - + ) } diff --git a/ui/features/run-detail/components/TokenStatsBar.tsx b/ui/features/run-detail/components/TokenStatsBar.tsx new file mode 100644 index 00000000..a7736eee --- /dev/null +++ b/ui/features/run-detail/components/TokenStatsBar.tsx @@ -0,0 +1,73 @@ +import type { TokenCount } from '../types' + +function formatTokens(n: number): string { + return n >= 1000 ? `${(n / 1000).toFixed(1)}k` : String(n) +} + +type Props = { + tokensTotal: TokenCount + status: string +} + +export default function TokenStatsBar({ tokensTotal }: Props) { + if (tokensTotal.in === 0 && tokensTotal.out === 0) return null + + const total = tokensTotal.in + tokensTotal.out + + return ( +
+
+ + Total + + + {formatTokens(total)} + +
+ +
+ +
+ + Input ↑ + + + {formatTokens(tokensTotal.in)} + +
+ +
+ + Output ↓ + + + {formatTokens(tokensTotal.out)} + +
+
+ ) +} diff --git a/ui/features/run-detail/hooks/useRunStream.ts b/ui/features/run-detail/hooks/useRunStream.ts index 999311f2..d4394163 100644 --- a/ui/features/run-detail/hooks/useRunStream.ts +++ b/ui/features/run-detail/hooks/useRunStream.ts @@ -1,24 +1,28 @@ 'use client' import { useEffect, useReducer } from 'react' import { createSSEConnection } from '@/lib/sse' -import { getRunStreamUrl } from '@/lib/api-client' +import { getRun, getRunStreamUrl } from '@/lib/api-client' import { AGENT_STEPS } from '@/lib/types/run' import type { AgentStep } from '@/lib/types/run' -import type { RunStreamState } from '../types' +import type { RunStreamState, TokenCount } from '../types' + +const zeroTokens = (): TokenCount => ({ in: 0, out: 0 }) const initialState: RunStreamState = { status: 'connecting', - steps: Object.fromEntries(AGENT_STEPS.map((s) => [s, 'pending'])) as RunStreamState['steps'], - reports: Object.fromEntries(AGENT_STEPS.map((s) => [s, []])) as RunStreamState['reports'], + steps: Object.fromEntries(AGENT_STEPS.map((s) => [s, 'pending'])) as RunStreamState['steps'], + reports: Object.fromEntries(AGENT_STEPS.map((s) => [s, []])) as RunStreamState['reports'], + tokensByStep: Object.fromEntries(AGENT_STEPS.map((s) => [s, zeroTokens()])) as RunStreamState['tokensByStep'], + tokensTotal: zeroTokens(), verdict: null, error: null, } type Action = - | { type: 'AGENT_START'; step: AgentStep; turn: number } - | { type: 'AGENT_COMPLETE'; step: AgentStep; turn: number; report: string } - | { type: 'RUN_COMPLETE'; decision: string } - | { type: 'RUN_ERROR'; message: string } + | { type: 'AGENT_START'; step: AgentStep; turn: number } + | { type: 'AGENT_COMPLETE'; step: AgentStep; turn: number; report: string; tokens_in?: number; tokens_out?: number } + | { type: 'RUN_COMPLETE'; decision: string } + | { type: 'RUN_ERROR'; message: string } | { type: 'CONNECTED' } function reducer(state: RunStreamState, action: Action): RunStreamState { @@ -27,19 +31,30 @@ function 
reducer(state: RunStreamState, action: Action): RunStreamState { return { ...state, status: 'running' } case 'AGENT_START': - // Only transition to 'running' on first turn (don't regress from 'done') if (state.steps[action.step] !== 'pending') return state return { ...state, steps: { ...state.steps, [action.step]: 'running' } } - case 'AGENT_COMPLETE': + case 'AGENT_COMPLETE': { + const dIn = action.tokens_in ?? 0 + const dOut = action.tokens_out ?? 0 + const prev = state.tokensByStep[action.step] ?? zeroTokens() return { ...state, - steps: { ...state.steps, [action.step]: 'done' }, + steps: { ...state.steps, [action.step]: 'done' }, reports: { ...state.reports, [action.step]: [...(state.reports[action.step] ?? []), action.report], }, + tokensByStep: { + ...state.tokensByStep, + [action.step]: { in: prev.in + dIn, out: prev.out + dOut }, + }, + tokensTotal: { + in: state.tokensTotal.in + dIn, + out: state.tokensTotal.out + dOut, + }, } + } case 'RUN_COMPLETE': return { ...state, status: 'complete', verdict: action.decision as RunStreamState['verdict'] } @@ -56,17 +71,43 @@ export function useRunStream(runId: string): RunStreamState { const [state, dispatch] = useReducer(reducer, initialState) useEffect(() => { - const url = getRunStreamUrl(runId) - const close = createSSEConnection(url, { - onOpen: () => dispatch({ type: 'CONNECTED' }), - onAgentStart: ({ step, turn }) => - dispatch({ type: 'AGENT_START', step: step as AgentStep, turn }), - onAgentComplete: ({ step, turn, report }) => - dispatch({ type: 'AGENT_COMPLETE', step: step as AgentStep, turn, report }), - onRunComplete: ({ decision }) => dispatch({ type: 'RUN_COMPLETE', decision }), - onRunError: ({ message }) => dispatch({ type: 'RUN_ERROR', message }), + let close: (() => void) | undefined + let aborted = false + + getRun(runId).then((run) => { + if (aborted) return + + if (run.status === 'complete' && run.reports) { + dispatch({ type: 'CONNECTED' }) + for (const [key, report] of 
Object.entries(run.reports)) { + const lastColon = key.lastIndexOf(':') + const step = key.slice(0, lastColon) as AgentStep + const turn = parseInt(key.slice(lastColon + 1), 10) + const tok = run.token_usage?.[key] ?? { tokens_in: 0, tokens_out: 0 } + dispatch({ type: 'AGENT_START', step, turn }) + dispatch({ type: 'AGENT_COMPLETE', step, turn, report, + tokens_in: tok.tokens_in, tokens_out: tok.tokens_out }) + } + dispatch({ type: 'RUN_COMPLETE', decision: run.decision ?? 'HOLD' }) + return + } + + const url = getRunStreamUrl(runId) + close = createSSEConnection(url, { + onOpen: () => dispatch({ type: 'CONNECTED' }), + onAgentStart: ({ step, turn }) => + dispatch({ type: 'AGENT_START', step: step as AgentStep, turn }), + onAgentComplete: ({ step, turn, report, tokens_in, tokens_out }) => + dispatch({ type: 'AGENT_COMPLETE', step: step as AgentStep, turn, report, + tokens_in, tokens_out }), + onRunComplete: ({ decision }) => dispatch({ type: 'RUN_COMPLETE', decision }), + onRunError: ({ message }) => dispatch({ type: 'RUN_ERROR', message }), + }) + }).catch(() => { + if (!aborted) dispatch({ type: 'RUN_ERROR', message: 'Failed to load run' }) }) - return close + + return () => { aborted = true; close?.() } }, [runId]) return state diff --git a/ui/features/run-detail/types.ts b/ui/features/run-detail/types.ts index cd099543..75dff0b5 100644 --- a/ui/features/run-detail/types.ts +++ b/ui/features/run-detail/types.ts @@ -1,10 +1,14 @@ import type { AgentStep, RunStatus } from '@/lib/types/run' import type { Decision, StepStatus } from '@/lib/types/agents' +export type TokenCount = { in: number; out: number } + export type RunStreamState = { status: RunStatus | 'connecting' steps: Record reports: Record verdict: Decision | null error: string | null + tokensByStep: Record + tokensTotal: TokenCount } diff --git a/ui/lib/api-client.ts b/ui/lib/api-client.ts index b8074e64..5eab65e0 100644 --- a/ui/lib/api-client.ts +++ b/ui/lib/api-client.ts @@ -14,14 +14,21 @@ async 
function apiFetch(path: string, init?: RequestInit): Promise { return res.json() as Promise } +export type RunResult = RunSummary & { + config: RunConfig | null + reports: Record + error: string | null + token_usage: Record | null +} + export const createRun = (config: RunConfig): Promise => apiFetch('/api/runs', { method: 'POST', body: JSON.stringify(config) }) export const listRuns = (): Promise => apiFetch('/api/runs') -export const getRun = (id: string): Promise => - apiFetch(`/api/runs/${id}`) +export const getRun = (id: string): Promise => + apiFetch(`/api/runs/${id}`) export const getSettings = (): Promise => apiFetch('/api/settings') diff --git a/ui/lib/sse.ts b/ui/lib/sse.ts index dc700d13..4516e072 100644 --- a/ui/lib/sse.ts +++ b/ui/lib/sse.ts @@ -1,6 +1,9 @@ export type SSEHandlers = { onAgentStart?: (data: { step: string; turn: number }) => void - onAgentComplete?: (data: { step: string; turn: number; report: string }) => void + onAgentComplete?: (data: { + step: string; turn: number; report: string; + tokens_in?: number; tokens_out?: number + }) => void onRunComplete?: (data: { decision: string; run_id: string }) => void onRunError?: (data: { message: string }) => void onOpen?: () => void