feat: add token usage tracking and reporting to run service

- Introduced TokenCallbackHandler to track input and output token usage during LLM operations.
- Updated RunResult model to include token usage data.
- Enhanced RunsStore to support token usage persistence in the database.
- Modified RunService to yield token usage information during event streaming.
- Implemented UI components to display token statistics in the run detail view.
- Added tests for token handling and reporting functionality.

Made-with: Cursor
Ali AL OGAILI 2026-03-24 00:40:38 +01:00
commit 29a338d957
19 changed files with 867 additions and 68 deletions

View File

@@ -0,0 +1,41 @@
import threading
from typing import Any
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.outputs import LLMResult
class TokenCallbackHandler(BaseCallbackHandler):
"""Tracks LLM token usage. Call snapshot_and_reset() after each agent step
to get the delta tokens consumed by that step, then zero the counters."""
def __init__(self) -> None:
super().__init__()
self._lock = threading.Lock()
self._tokens_in = 0
self._tokens_out = 0
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
try:
generation = response.generations[0][0]
except (IndexError, TypeError):
return
if not hasattr(generation, "message"):
return
message = generation.message
if not isinstance(message, AIMessage):
return
usage = getattr(message, "usage_metadata", None)
if usage is not None:
with self._lock:
self._tokens_in += usage.get("input_tokens", 0)
self._tokens_out += usage.get("output_tokens", 0)
def snapshot_and_reset(self) -> dict[str, int]:
"""Return {"in": N, "out": M} for the current period and zero counters."""
with self._lock:
result = {"in": self._tokens_in, "out": self._tokens_out}
self._tokens_in = 0
self._tokens_out = 0
return result
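For orientation (not part of the diff): a minimal usage sketch, assuming a LangChain chat model that populates usage_metadata on its AIMessage responses. ChatOpenAI and the prompts are illustrative stand-ins, not mandated by this commit.

from langchain_openai import ChatOpenAI  # illustrative provider; any usage-reporting chat model works
from api.callbacks.token_handler import TokenCallbackHandler

handler = TokenCallbackHandler()
llm = ChatOpenAI(model="gpt-4o-mini", callbacks=[handler])  # hypothetical model name
llm.invoke("Summarize market sentiment for NVDA.")
print(handler.snapshot_and_reset())  # e.g. {"in": 42, "out": 17}
llm.invoke("Now summarize the news headlines.")
print(handler.snapshot_and_reset())  # counters were zeroed, so this is the second call's delta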

View File

@@ -32,7 +32,13 @@ class RunSummary(BaseModel):
created_at: str
class TokenUsage(BaseModel):
tokens_in: int = 0
tokens_out: int = 0
class RunResult(RunSummary):
config: Optional[RunConfig] = None
reports: dict[str, str] = Field(default_factory=dict)
error: Optional[str] = None
token_usage: dict[str, TokenUsage] = Field(default_factory=dict)
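The token_usage map reuses the "step:turn" key convention of reports, so a persisted run looks roughly like this (values illustrative, taken from the tests):

RunResult(
    id="a1b2c3d4",
    ticker="NVDA",
    date="2026-03-23",
    status=RunStatus.COMPLETE,
    decision="BUY",
    created_at="2026-03-23T00:00:00Z",
    reports={"market_analyst:0": "bullish"},
    token_usage={"market_analyst:0": TokenUsage(tokens_in=1200, tokens_out=400)},
)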

View File

@@ -1,12 +1,19 @@
import json
import pathlib
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from api.models.run import RunConfig, RunResult, RunSummary
from api.services.run_service import RunService
from api.store.runs_store import RunsStore
try:
from tradingagents.default_config import DEFAULT_CONFIG
except ImportError:
DEFAULT_CONFIG = {"results_dir": "./results"}
router = APIRouter()
_db_path = pathlib.Path(DEFAULT_CONFIG["results_dir"]) / "runs.sqlite"
_store = RunsStore(_db_path)
_service = RunService(_store)

View File

@@ -3,6 +3,7 @@ from collections import defaultdict
from typing import Generator
from api.store.runs_store import RunsStore
from api.models.run import RunConfig, RunStatus
from api.callbacks.token_handler import TokenCallbackHandler
logger = logging.getLogger(__name__)
@@ -24,6 +25,37 @@ class RunService:
yield {"event": "run:error", "data": {"message": "Run not found"}}
return
if run.status == RunStatus.COMPLETE:
token_usage = {k: v.model_dump() for k, v in (run.token_usage or {}).items()}
for key, report in run.reports.items():
if ":" not in key:
logger.warning(
"Skipping malformed report key %r for run %s", key, run_id
)
continue
step_key, turn_str = key.rsplit(":", 1)
if not turn_str.isdigit():
logger.warning(
"Skipping report key with non-numeric turn %r for run %s", key, run_id
)
continue
turn = int(turn_str)
raw = token_usage.get(key, {"tokens_in": 0, "tokens_out": 0})
yield {"event": "agent:start", "data": {"step": step_key, "turn": turn}}
yield {"event": "agent:complete", "data": {
"step": step_key, "turn": turn, "report": report,
"tokens_in": raw.get("tokens_in", 0),
"tokens_out": raw.get("tokens_out", 0),
}}
yield {"event": "run:complete", "data": {"decision": run.decision or "HOLD", "run_id": run_id}}
return
if run.status == RunStatus.RUNNING:
yield {"event": "run:error", "data": {"message": "Run is already in progress"}}
return
self._store.clear_reports(run_id)
self._store.clear_token_usage(run_id)
self._store.update_status(run_id, RunStatus.RUNNING)
config = run.config
@@ -35,20 +67,31 @@ class RunService:
ta_config["max_risk_discuss_rounds"] = config.max_risk_discuss_rounds
try:
token_handler = TokenCallbackHandler()
ta = TradingAgentsGraph(
debug=False,
config=ta_config,
selected_analysts=config.enabled_analysts or
["market", "news", "fundamentals", "social"],
callbacks=[token_handler],
)
turn_counts: defaultdict[str, int] = defaultdict(int)
for step_key, report in ta.stream_propagate(config.ticker, config.date):
tokens = token_handler.snapshot_and_reset()
turn = turn_counts[step_key]
yield {"event": "agent:start", "data": {"step": step_key, "turn": turn}}
yield {"event": "agent:complete", "data": {"step": step_key, "turn": turn, "report": report}}
yield {"event": "agent:complete", "data": {
"step": step_key, "turn": turn, "report": report,
"tokens_in": tokens["in"], "tokens_out": tokens["out"],
}}
self._store.add_report(run_id, f"{step_key}:{turn}", report)
# Normalize to TokenUsage field names before persisting
self._store.add_token_usage(
run_id, f"{step_key}:{turn}",
{"tokens_in": tokens["in"], "tokens_out": tokens["out"]},
)
turn_counts[step_key] += 1
decision = ta._last_decision or "HOLD"

View File

@@ -1,63 +1,165 @@
import json
import sqlite3
import uuid
from datetime import datetime, timezone
from pathlib import Path
from threading import Lock
from typing import Literal, Optional
from api.models.run import RunConfig, RunResult, RunStatus, TokenUsage
_CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS runs (
id TEXT PRIMARY KEY,
ticker TEXT NOT NULL,
date TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'queued',
decision TEXT,
created_at TEXT NOT NULL,
config TEXT,
reports TEXT NOT NULL DEFAULT '{}',
error TEXT,
token_usage TEXT NOT NULL DEFAULT '{}'
)
"""
class RunsStore:
def __init__(self, db_path: Path) -> None:
db_path.parent.mkdir(parents=True, exist_ok=True)
self._conn = sqlite3.connect(str(db_path), check_same_thread=False)
self._conn.row_factory = sqlite3.Row
self._conn.execute("PRAGMA journal_mode=WAL")
self._conn.execute(_CREATE_TABLE_SQL)
# Migration: add token_usage column if not present (handles existing DBs)
existing_cols = {
row["name"]
for row in self._conn.execute("PRAGMA table_info(runs)")
}
if "token_usage" not in existing_cols:
self._conn.execute(
"ALTER TABLE runs ADD COLUMN token_usage TEXT NOT NULL DEFAULT '{}'"
)
self._conn.commit()
self._lock = Lock()
def _row_to_run(self, row: sqlite3.Row) -> RunResult:
return RunResult(
id=row["id"],
ticker=row["ticker"],
date=row["date"],
status=RunStatus(row["status"]),
decision=row["decision"],
created_at=row["created_at"],
config=RunConfig(**json.loads(row["config"])) if row["config"] else None,
reports=json.loads(row["reports"]),
error=row["error"],
token_usage={
k: TokenUsage(**v)
for k, v in json.loads(row["token_usage"] or "{}").items()
},
)
def create(self, config: RunConfig) -> RunResult:
run_id = str(uuid.uuid4())[:8]
now = datetime.now(timezone.utc).isoformat()
with self._lock:
self._conn.execute(
"INSERT INTO runs (id, ticker, date, status, created_at, config)"
" VALUES (?, ?, ?, ?, ?, ?)",
(run_id, config.ticker, config.date, RunStatus.QUEUED.value,
now, config.model_dump_json()),
)
self._conn.commit()
return RunResult(
id=run_id,
ticker=config.ticker,
date=config.date,
status=RunStatus.QUEUED,
created_at=now,
config=config,
)
def get(self, run_id: str) -> Optional[RunResult]:
with self._lock:
row = self._conn.execute(
"SELECT * FROM runs WHERE id = ?", (run_id,)
).fetchone()
return self._row_to_run(row) if row else None
def list_all(self) -> list[RunResult]:
with self._lock:
rows = self._conn.execute(
"SELECT * FROM runs ORDER BY created_at DESC"
).fetchall()
return [self._row_to_run(row) for row in rows]
def update_status(self, run_id: str, status: RunStatus) -> None:
with self._lock:
self._conn.execute(
"UPDATE runs SET status = ? WHERE id = ?",
(status.value, run_id),
)
self._conn.commit()
def update_decision(
self, run_id: str, decision: Literal["BUY", "SELL", "HOLD"]
) -> None:
with self._lock:
self._conn.execute(
"UPDATE runs SET decision = ? WHERE id = ?",
(decision, run_id),
)
self._conn.commit()
def add_report(self, run_id: str, step: str, report: str) -> None:
with self._lock:
row = self._conn.execute(
"SELECT reports FROM runs WHERE id = ?", (run_id,)
).fetchone()
if row:
reports = json.loads(row[0])
reports[step] = report
self._conn.execute(
"UPDATE runs SET reports = ? WHERE id = ?",
(json.dumps(reports), run_id),
)
self._conn.commit()
def set_error(self, run_id: str, error: str) -> None:
with self._lock:
self._conn.execute(
"UPDATE runs SET status = ?, error = ? WHERE id = ?",
(RunStatus.ERROR.value, error, run_id),
)
self._conn.commit()
def clear_reports(self, run_id: str) -> None:
with self._lock:
self._conn.execute(
"UPDATE runs SET reports = '{}' WHERE id = ?",
(run_id,),
)
self._conn.commit()
def add_token_usage(self, run_id: str, key: str, tokens: dict) -> None:
with self._lock:
row = self._conn.execute(
"SELECT token_usage FROM runs WHERE id = ?", (run_id,)
).fetchone()
if row:
usage = json.loads(row[0] or "{}")
usage[key] = tokens
self._conn.execute(
"UPDATE runs SET token_usage = ? WHERE id = ?",
(json.dumps(usage), run_id),
)
self._conn.commit()
def clear_token_usage(self, run_id: str) -> None:
with self._lock:
self._conn.execute(
"UPDATE runs SET token_usage = '{}' WHERE id = ?",
(run_id,),
)
self._conn.commit()
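A round-trip sketch of the new persistence API (path and values illustrative), mirroring what the store tests further down exercise:

from pathlib import Path
from api.models.run import RunConfig
from api.store.runs_store import RunsStore

store = RunsStore(Path("./results/runs.sqlite"))
run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
store.add_report(run.id, "market_analyst:0", "bullish")
store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 1200, "tokens_out": 400})
fetched = store.get(run.id)
assert fetched.token_usage["market_analyst:0"].tokens_in == 1200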

View File

@@ -7,8 +7,8 @@ from api.models.run import RunConfig, RunStatus
@pytest.fixture
def store(tmp_path):
return RunsStore(tmp_path / "test.sqlite")
@pytest.fixture
@@ -116,3 +116,124 @@ def test_selected_analysts_passed_to_graph(service, store):
call_kwargs = MockGraph.call_args.kwargs
assert call_kwargs.get("selected_analysts") == ["market", "news"]
def test_completed_run_replays_without_re_running_graph(service, store):
# Pre-populate a completed run in the store
run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
store.update_status(run.id, RunStatus.RUNNING)
store.add_report(run.id, "market_analyst:0", "bullish")
store.update_status(run.id, RunStatus.COMPLETE)
store.update_decision(run.id, "BUY")
with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
events = list(service.stream_events(run.id))
assert MockGraph.call_count == 0 # no agent execution on replay
agent_completes = [e for e in events if e["event"] == "agent:complete"]
assert len(agent_completes) == 1
assert agent_completes[0]["data"]["report"] == "bullish"
run_complete = next(e for e in events if e["event"] == "run:complete")
assert run_complete["data"]["decision"] == "BUY"
def test_running_run_returns_in_progress_error(service, store):
run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
store.update_status(run.id, RunStatus.RUNNING)
with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
events = list(service.stream_events(run.id))
assert MockGraph.call_count == 0 # no agent execution
assert len(events) == 1
assert events[0]["event"] == "run:error"
assert "already in progress" in events[0]["data"]["message"]
def test_error_run_retries_and_clears_reports(service, store):
# Simulate a run that failed partway through with a stale report
run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
store.update_status(run.id, RunStatus.RUNNING)
store.add_report(run.id, "market_analyst:0", "stale data")
store.set_error(run.id, "timeout")
assert store.get(run.id).status == RunStatus.ERROR
assert store.get(run.id).reports == {"market_analyst:0": "stale data"}
with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
MockGraph.return_value = _mock_graph([("news_analyst", "fresh")], decision="HOLD")
events = list(service.stream_events(run.id))
assert MockGraph.call_count == 1 # graph was executed on retry
final_reports = store.get(run.id).reports
assert "news_analyst:0" in final_reports
assert "market_analyst:0" not in final_reports # stale data cleared on retry
def test_live_run_agent_complete_includes_token_fields(service, store):
"""agent:complete events must include tokens_in and tokens_out."""
config = RunConfig(ticker="NVDA", date="2026-03-23")
run = store.create(config)
with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
MockGraph.return_value = _mock_graph([("market_analyst", "bullish")])
events = list(service.stream_events(run.id))
complete_events = [e for e in events if e["event"] == "agent:complete"]
assert len(complete_events) == 1
data = complete_events[0]["data"]
assert "tokens_in" in data
assert "tokens_out" in data
assert isinstance(data["tokens_in"], int)
assert isinstance(data["tokens_out"], int)
def test_retry_clears_stale_token_usage(service, store):
"""After an errored run is retried, stale token keys are absent."""
run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
store.update_status(run.id, RunStatus.RUNNING)
store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 500, "tokens_out": 200})
store.set_error(run.id, "timeout")
with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
MockGraph.return_value = _mock_graph([("news_analyst", "fresh")])
list(service.stream_events(run.id))
final = store.get(run.id)
assert "market_analyst:0" not in final.token_usage # stale key gone
assert "news_analyst:0" in final.token_usage # new key present
def test_replay_attaches_token_data_to_agent_complete(service, store):
"""Replaying a completed run emits agent:complete with token data."""
run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
store.update_status(run.id, RunStatus.RUNNING)
store.add_report(run.id, "market_analyst:0", "bullish")
store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 1200, "tokens_out": 400})
store.update_status(run.id, RunStatus.COMPLETE)
store.update_decision(run.id, "BUY")
with patch("api.services.run_service.TradingAgentsGraph") as MockGraph:
events = list(service.stream_events(run.id))
assert MockGraph.call_count == 0
complete_events = [e for e in events if e["event"] == "agent:complete"]
assert len(complete_events) == 1
assert complete_events[0]["data"]["tokens_in"] == 1200
assert complete_events[0]["data"]["tokens_out"] == 400
def test_replay_defaults_to_zero_tokens_when_missing(service, store):
"""Replay of a run with no token_usage emits tokens_in=0, tokens_out=0."""
run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
store.update_status(run.id, RunStatus.RUNNING)
store.add_report(run.id, "market_analyst:0", "bearish")
store.update_status(run.id, RunStatus.COMPLETE)
store.update_decision(run.id, "SELL")
# No add_token_usage call — simulates old run with no token data
with patch("api.services.run_service.TradingAgentsGraph"):
events = list(service.stream_events(run.id))
complete = next(e for e in events if e["event"] == "agent:complete")
assert complete["data"]["tokens_in"] == 0
assert complete["data"]["tokens_out"] == 0

View File

@@ -1,10 +1,11 @@
import pytest
import sqlite3 as _sqlite3
from api.store.runs_store import RunsStore
from api.models.run import RunConfig, RunStatus
def test_create_and_get_run(tmp_path):
store = RunsStore(tmp_path / "test.sqlite")
config = RunConfig(ticker="NVDA", date="2024-05-10")
run = store.create(config)
assert run.id is not None
@@ -13,16 +14,108 @@ def test_create_and_get_run():
assert fetched.ticker == "NVDA"
def test_list_runs(tmp_path):
store = RunsStore(tmp_path / "test.sqlite")
store.create(RunConfig(ticker="NVDA", date="2024-05-10"))
store.create(RunConfig(ticker="AAPL", date="2024-05-09"))
runs = store.list_all()
assert len(runs) == 2
def test_update_run_status(tmp_path):
store = RunsStore(tmp_path / "test.sqlite")
run = store.create(RunConfig(ticker="NVDA", date="2024-05-10"))
store.update_status(run.id, RunStatus.RUNNING)
assert store.get(run.id).status == RunStatus.RUNNING
def test_add_report(tmp_path):
store = RunsStore(tmp_path / "test.sqlite")
run = store.create(RunConfig(ticker="NVDA", date="2024-05-10"))
store.add_report(run.id, "market_analyst:0", "bullish")
assert store.get(run.id).reports == {"market_analyst:0": "bullish"}
def test_set_error(tmp_path):
store = RunsStore(tmp_path / "test.sqlite")
run = store.create(RunConfig(ticker="NVDA", date="2024-05-10"))
store.set_error(run.id, "timeout")
fetched = store.get(run.id)
assert fetched.status == RunStatus.ERROR
assert fetched.error == "timeout"
def test_clear_reports(tmp_path):
store = RunsStore(tmp_path / "test.sqlite")
run = store.create(RunConfig(ticker="NVDA", date="2024-05-10"))
store.add_report(run.id, "market_analyst:0", "bullish")
store.clear_reports(run.id)
assert store.get(run.id).reports == {}
def test_update_decision(tmp_path):
store = RunsStore(tmp_path / "test.sqlite")
run = store.create(RunConfig(ticker="NVDA", date="2024-05-10"))
store.update_decision(run.id, "BUY")
assert store.get(run.id).decision == "BUY"
def test_migration_adds_token_usage_column_to_existing_db(tmp_path):
"""DB created without token_usage gets the column added on RunsStore.__init__."""
db_path = tmp_path / "old.sqlite"
# Create a DB without the token_usage column (simulate pre-migration state)
conn = _sqlite3.connect(str(db_path))
conn.execute("""
CREATE TABLE runs (
id TEXT PRIMARY KEY, ticker TEXT NOT NULL, date TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'queued', decision TEXT,
created_at TEXT NOT NULL, config TEXT,
reports TEXT NOT NULL DEFAULT '{}', error TEXT
)
""")
conn.commit()
conn.close()
# Initialising the store should migrate the column
store = RunsStore(db_path)
cols = {
row["name"]
for row in store._conn.execute("PRAGMA table_info(runs)")
}
assert "token_usage" in cols
def test_migration_is_idempotent(tmp_path):
"""Re-initialising the store after migration does not crash."""
db_path = tmp_path / "test.sqlite"
RunsStore(db_path) # first init — creates table + column
RunsStore(db_path) # second init — column already present, should not crash
def test_add_token_usage(tmp_path):
store = RunsStore(tmp_path / "test.sqlite")
run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 1200, "tokens_out": 400})
result = store.get(run.id)
assert "market_analyst:0" in result.token_usage
assert result.token_usage["market_analyst:0"].tokens_in == 1200
assert result.token_usage["market_analyst:0"].tokens_out == 400
def test_clear_token_usage(tmp_path):
store = RunsStore(tmp_path / "test.sqlite")
run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 1200, "tokens_out": 400})
store.clear_token_usage(run.id)
assert store.get(run.id).token_usage == {}
def test_clear_token_usage_does_not_affect_reports(tmp_path):
store = RunsStore(tmp_path / "test.sqlite")
run = store.create(RunConfig(ticker="NVDA", date="2026-03-23"))
store.add_report(run.id, "market_analyst:0", "bullish")
store.add_token_usage(run.id, "market_analyst:0", {"tokens_in": 100, "tokens_out": 50})
store.clear_token_usage(run.id)
fetched = store.get(run.id)
assert fetched.reports == {"market_analyst:0": "bullish"}
assert fetched.token_usage == {}

View File

@@ -0,0 +1,74 @@
import threading
from langchain_core.messages import AIMessage
from langchain_core.outputs import LLMResult, ChatGeneration
from api.callbacks.token_handler import TokenCallbackHandler
def _make_response(input_tokens: int, output_tokens: int) -> LLMResult:
"""Build a minimal LLMResult that on_llm_end will parse."""
msg = AIMessage(content="ok")
msg.usage_metadata = {"input_tokens": input_tokens, "output_tokens": output_tokens}
gen = ChatGeneration(message=msg)
return LLMResult(generations=[[gen]])
def test_snapshot_and_reset_returns_delta():
handler = TokenCallbackHandler()
handler.on_llm_end(_make_response(100, 40))
result = handler.snapshot_and_reset()
assert result == {"in": 100, "out": 40}
def test_snapshot_and_reset_zeroes_counters():
handler = TokenCallbackHandler()
handler.on_llm_end(_make_response(100, 40))
handler.snapshot_and_reset()
second = handler.snapshot_and_reset()
assert second == {"in": 0, "out": 0}
def test_multiple_llm_calls_accumulate():
handler = TokenCallbackHandler()
handler.on_llm_end(_make_response(100, 40))
handler.on_llm_end(_make_response(200, 60))
result = handler.snapshot_and_reset()
assert result == {"in": 300, "out": 100}
def test_concurrent_on_llm_end_does_not_corrupt():
handler = TokenCallbackHandler()
threads = [
threading.Thread(target=handler.on_llm_end, args=(_make_response(10, 5),))
for _ in range(20)
]
for t in threads:
t.start()
for t in threads:
t.join()
result = handler.snapshot_and_reset()
assert result == {"in": 200, "out": 100}
def test_missing_usage_metadata_does_not_crash():
handler = TokenCallbackHandler()
msg = AIMessage(content="no metadata")
# No usage_metadata attribute
gen = ChatGeneration(message=msg)
response = LLMResult(generations=[[gen]])
handler.on_llm_end(response) # should not raise
assert handler.snapshot_and_reset() == {"in": 0, "out": 0}
def test_empty_outer_generations_does_not_crash():
handler = TokenCallbackHandler()
response = LLMResult(generations=[])
handler.on_llm_end(response) # IndexError guard
assert handler.snapshot_and_reset() == {"in": 0, "out": 0}
def test_empty_inner_generations_does_not_crash():
handler = TokenCallbackHandler()
response = LLMResult(generations=[[]])
handler.on_llm_end(response) # IndexError guard
assert handler.snapshot_and_reset() == {"in": 0, "out": 0}

View File

@@ -0,0 +1,30 @@
import { render, screen } from '@testing-library/react'
import TokenStatsBar from '@/features/run-detail/components/TokenStatsBar'
test('renders nothing when both token counts are zero', () => {
const { container } = render(
<TokenStatsBar tokensTotal={{ in: 0, out: 0 }} status="running" />
)
expect(container.firstChild).toBeNull()
})
test('renders total, input, output when tokens are non-zero', () => {
render(
<TokenStatsBar tokensTotal={{ in: 3100, out: 1100 }} status="running" />
)
// Total = 3100 + 1100 = 4200 → "4.2k"
expect(screen.getByText('4.2k')).toBeInTheDocument()
// Input → "3.1k"
expect(screen.getByText('3.1k')).toBeInTheDocument()
// Output → "1.1k"
expect(screen.getByText('1.1k')).toBeInTheDocument()
})
test('formats values below 1000 without k suffix', () => {
render(
<TokenStatsBar tokensTotal={{ in: 800, out: 150 }} status="complete" />
)
expect(screen.getByText('950')).toBeInTheDocument() // total
expect(screen.getByText('800')).toBeInTheDocument() // input
expect(screen.getByText('150')).toBeInTheDocument() // output
})

View File

@@ -16,8 +16,21 @@ jest.mock('@/lib/sse', () => ({
}),
}))
// getRun defaults to 'queued' so existing SSE-path tests still exercise the SSE branch.
// Tests that need a different status use mockResolvedValueOnce to override.
jest.mock('@/lib/api-client', () => ({
getRunStreamUrl: (id: string) => `/api/runs/${id}/stream`,
getRun: jest.fn().mockResolvedValue({
id: 'abc',
ticker: 'NVDA',
date: '2026-03-23',
status: 'queued',
decision: null,
created_at: '2026-03-23T00:00:00Z',
config: null,
reports: {},
error: null,
}),
}))
test('appends multiple turns for same step', async () => {
@@ -46,3 +59,118 @@ test('initial reports are empty arrays', () => {
const { result } = renderHook(() => useRunStream('abc'))
expect(result.current.reports['market_analyst']).toEqual([])
})
test('hydrates from reports when run is complete, skipping SSE', async () => {
const { getRun } = jest.requireMock('@/lib/api-client')
const { createSSEConnection } = jest.requireMock('@/lib/sse')
// Reset call history so we can assert createSSEConnection was NOT called for this run
jest.clearAllMocks()
getRun.mockResolvedValueOnce({
id: 'xyz',
ticker: 'AAPL',
date: '2026-03-23',
status: 'complete',
decision: 'SELL',
created_at: '2026-03-23T00:00:00Z',
config: null,
reports: { 'market_analyst:0': 'bearish signal' },
error: null,
})
const { result } = renderHook(() => useRunStream('xyz'))
await act(async () => { await new Promise((r) => setTimeout(r, 10)) })
expect(result.current.status).toBe('complete')
expect(result.current.verdict).toBe('SELL')
expect(result.current.reports['market_analyst']).toEqual(['bearish signal'])
expect(createSSEConnection).not.toHaveBeenCalled() // SSE skipped for completed run
})
test('AGENT_COMPLETE accumulates tokensByStep and tokensTotal', async () => {
const { getRun } = jest.requireMock('@/lib/api-client')
const { createSSEConnection } = jest.requireMock('@/lib/sse')
jest.clearAllMocks()
// getRun returns queued so SSE path runs
getRun.mockResolvedValueOnce({
id: 'abc', ticker: 'NVDA', date: '2026-03-23', status: 'queued',
decision: null, created_at: '2026-03-23T00:00:00Z',
config: null, reports: {}, error: null, token_usage: null,
})
// SSE mock emits one agent:complete with token data
createSSEConnection.mockImplementationOnce(
(_url: string, handlers: Record<string, (d: unknown) => void>) => {
setTimeout(() => {
handlers.onAgentStart?.({ step: 'market_analyst', turn: 0 })
handlers.onAgentComplete?.({
step: 'market_analyst', turn: 0, report: 'bullish',
tokens_in: 1200, tokens_out: 400,
})
handlers.onRunComplete?.({ decision: 'BUY', run_id: 'abc' })
}, 0)
return jest.fn()
}
)
const { result } = renderHook(() => useRunStream('abc'))
await act(async () => { await new Promise((r) => setTimeout(r, 10)) })
expect(result.current.tokensByStep['market_analyst']).toEqual({ in: 1200, out: 400 })
expect(result.current.tokensTotal).toEqual({ in: 1200, out: 400 })
})
test('missing tokens_in/out in AGENT_COMPLETE defaults to 0', async () => {
const { getRun } = jest.requireMock('@/lib/api-client')
const { createSSEConnection } = jest.requireMock('@/lib/sse')
jest.clearAllMocks()
getRun.mockResolvedValueOnce({
id: 'abc', ticker: 'NVDA', date: '2026-03-23', status: 'queued',
decision: null, created_at: '2026-03-23T00:00:00Z',
config: null, reports: {}, error: null, token_usage: null,
})
createSSEConnection.mockImplementationOnce(
(_url: string, handlers: Record<string, (d: unknown) => void>) => {
setTimeout(() => {
handlers.onAgentStart?.({ step: 'news_analyst', turn: 0 })
// No tokens_in/tokens_out in payload
handlers.onAgentComplete?.({ step: 'news_analyst', turn: 0, report: 'ok' })
handlers.onRunComplete?.({ decision: 'HOLD', run_id: 'abc' })
}, 0)
return jest.fn()
}
)
const { result } = renderHook(() => useRunStream('abc'))
await act(async () => { await new Promise((r) => setTimeout(r, 10)) })
expect(result.current.tokensByStep['news_analyst']).toEqual({ in: 0, out: 0 })
expect(result.current.tokensTotal).toEqual({ in: 0, out: 0 })
})
test('completed-run hydration populates tokens from getRun().token_usage without SSE', async () => {
const { getRun } = jest.requireMock('@/lib/api-client')
const { createSSEConnection } = jest.requireMock('@/lib/sse')
jest.clearAllMocks()
getRun.mockResolvedValueOnce({
id: 'tok', ticker: 'AAPL', date: '2026-03-23', status: 'complete',
decision: 'BUY', created_at: '2026-03-23T00:00:00Z',
config: null,
reports: { 'market_analyst:0': 'bullish' },
error: null,
token_usage: { 'market_analyst:0': { tokens_in: 1200, tokens_out: 400 } },
})
const { result } = renderHook(() => useRunStream('tok'))
await act(async () => { await new Promise((r) => setTimeout(r, 10)) })
expect(createSSEConnection).not.toHaveBeenCalled()
expect(result.current.status).toBe('complete')
expect(result.current.tokensByStep['market_analyst']).toEqual({ in: 1200, out: 400 })
expect(result.current.tokensTotal).toEqual({ in: 1200, out: 400 })
})

View File

@@ -6,6 +6,7 @@ import VerdictBanner from '@/features/run-detail/components/VerdictBanner'
import PhaseTabs from '@/features/run-detail/components/PhaseTabs'
import { getRun } from '@/lib/api-client'
import type { RunSummary } from '@/lib/types/run'
import TokenStatsBar from '@/features/run-detail/components/TokenStatsBar'
const STATUS_CONFIG: Record<string, {
bg: string; color: string; dot: string; label: string; pulse: boolean
@@ -18,7 +19,7 @@ const STATUS_CONFIG: Record<string, {
export default function RunDetailPage({ params }: { params: Promise<{ id: string }> }) {
const { id } = use(params)
const { steps, reports, verdict, status, error, tokensTotal, tokensByStep } = useRunStream(id)
const [run, setRun] = useState<RunSummary | null>(null)
useEffect(() => {
@@ -88,6 +89,9 @@ export default function RunDetailPage({ params }: { params: Promise<{ id: string
</div>
</div>
{/* Token stats bar */}
<TokenStatsBar tokensTotal={tokensTotal} status={status} />
{/* Pipeline */}
<PipelineStepper steps={steps} />
@@ -111,7 +115,7 @@ export default function RunDetailPage({ params }: { params: Promise<{ id: string
)}
{/* Phase tabs + reports */}
<PhaseTabs steps={steps} reports={reports} tokensByStep={tokensByStep} />
</div>
)
}

View File

@@ -2,6 +2,10 @@ import { AGENT_STEPS, AGENT_STEP_LABELS, STEP_PHASE } from '@/lib/types/run'
import type { AgentStep } from '@/lib/types/run'
import type { StepStatus } from '@/lib/types/agents'
function formatTokens(n: number): string {
return n >= 1000 ? `${(n / 1000).toFixed(1)}k` : String(n)
}
type Phase = 'analysts' | 'researchers' | 'trader' | 'risk'
const MULTI_TURN_STEPS = new Set<AgentStep>([
@@ -43,9 +47,10 @@ type Props = {
phase: Phase
steps: Record<AgentStep, StepStatus>
reports: Record<AgentStep, string[]>
tokensByStep: Record<AgentStep, { in: number; out: number }>
}
export default function AnalystReports({ phase, steps, reports, tokensByStep }: Props) {
const phaseSteps = AGENT_STEPS.filter((s) => STEP_PHASE[s] === phase)
return (
@@ -112,6 +117,22 @@ export default function AnalystReports({ phase, steps, reports }: Props) {
T{i + 1}
</span>
)}
{(() => {
const tok = tokensByStep[step]
return tok && (tok.in > 0 || tok.out > 0) ? (
<span
style={{
fontFamily: 'var(--font-mono)',
fontSize: '9px',
color: 'var(--text-mid)',
letterSpacing: '0.02em',
}}
>
<span style={{ color: 'var(--accent)' }}>{formatTokens(tok.in)}</span>{' '}
{formatTokens(tok.out)}
</span>
) : null
})()}
</div>
<div

View File

@@ -15,8 +15,9 @@ const TABS: { label: string; phase: Phase; count: string }[] = [
]
type Props = {
steps: Record<AgentStep, StepStatus>
reports: Record<AgentStep, string[]>
tokensByStep: Record<AgentStep, { in: number; out: number }>
}
function getPhaseCompletion(phase: Phase, steps: Record<AgentStep, StepStatus>): number {
@ -32,7 +33,7 @@ function getPhaseStatus(phase: Phase, steps: Record<AgentStep, StepStatus>): 'do
return 'pending'
}
export default function PhaseTabs({ steps, reports, tokensByStep }: Props) {
const [active, setActive] = useState<Phase>('analysts')
return (
@@ -118,7 +119,7 @@ export default function PhaseTabs({ steps, reports }: Props) {
</div>
{/* Reports */}
<AnalystReports phase={active} steps={steps} reports={reports} tokensByStep={tokensByStep} />
</div>
)
}

View File

@@ -0,0 +1,73 @@
import type { TokenCount } from '../types'
function formatTokens(n: number): string {
return n >= 1000 ? `${(n / 1000).toFixed(1)}k` : String(n)
}
type Props = {
tokensTotal: TokenCount
status: string
}
export default function TokenStatsBar({ tokensTotal }: Props) {
if (tokensTotal.in === 0 && tokensTotal.out === 0) return null
const total = tokensTotal.in + tokensTotal.out
return (
<div
className="flex items-center gap-4 px-4 py-2.5 rounded-xl"
style={{
background: 'var(--bg-elevated)',
border: '1px solid var(--border-raised)',
}}
>
<div className="flex flex-col gap-0.5">
<span
className="text-[8px] uppercase tracking-widest"
style={{ color: 'var(--text-mid)', fontFamily: 'var(--font-mono)' }}
>
Total
</span>
<span
className="text-[11px] font-bold"
style={{ color: 'var(--text-high)', fontFamily: 'var(--font-mono)' }}
>
{formatTokens(total)}
</span>
</div>
<div className="w-px h-6 shrink-0" style={{ background: 'var(--text-low)', opacity: 0.35 }} />
<div className="flex flex-col gap-0.5">
<span
className="text-[8px] uppercase tracking-widest"
style={{ color: 'var(--text-mid)', fontFamily: 'var(--font-mono)' }}
>
Input
</span>
<span
className="text-[11px] font-bold"
style={{ color: 'var(--accent-light)', fontFamily: 'var(--font-mono)' }}
>
{formatTokens(tokensTotal.in)}
</span>
</div>
<div className="flex flex-col gap-0.5">
<span
className="text-[8px] uppercase tracking-widest"
style={{ color: 'var(--text-mid)', fontFamily: 'var(--font-mono)' }}
>
Output
</span>
<span
className="text-[11px] font-bold"
style={{ color: 'var(--text-mid)', fontFamily: 'var(--font-mono)' }}
>
{formatTokens(tokensTotal.out)}
</span>
</div>
</div>
)
}

View File

@@ -1,24 +1,28 @@
'use client'
import { useEffect, useReducer } from 'react'
import { createSSEConnection } from '@/lib/sse'
import { getRun, getRunStreamUrl } from '@/lib/api-client'
import { AGENT_STEPS } from '@/lib/types/run'
import type { AgentStep } from '@/lib/types/run'
import type { RunStreamState, TokenCount } from '../types'
const zeroTokens = (): TokenCount => ({ in: 0, out: 0 })
const initialState: RunStreamState = {
status: 'connecting',
steps: Object.fromEntries(AGENT_STEPS.map((s) => [s, 'pending'])) as RunStreamState['steps'],
reports: Object.fromEntries(AGENT_STEPS.map((s) => [s, []])) as RunStreamState['reports'],
tokensByStep: Object.fromEntries(AGENT_STEPS.map((s) => [s, zeroTokens()])) as RunStreamState['tokensByStep'],
tokensTotal: zeroTokens(),
verdict: null,
error: null,
}
type Action =
| { type: 'AGENT_START'; step: AgentStep; turn: number }
| { type: 'AGENT_COMPLETE'; step: AgentStep; turn: number; report: string; tokens_in?: number; tokens_out?: number }
| { type: 'RUN_COMPLETE'; decision: string }
| { type: 'RUN_ERROR'; message: string }
| { type: 'CONNECTED' }
function reducer(state: RunStreamState, action: Action): RunStreamState {
@@ -27,19 +31,30 @@ function reducer(state: RunStreamState, action: Action): RunStreamState {
return { ...state, status: 'running' }
case 'AGENT_START':
// Only transition to 'running' on first turn (don't regress from 'done')
if (state.steps[action.step] !== 'pending') return state
return { ...state, steps: { ...state.steps, [action.step]: 'running' } }
case 'AGENT_COMPLETE': {
const dIn = action.tokens_in ?? 0
const dOut = action.tokens_out ?? 0
const prev = state.tokensByStep[action.step] ?? zeroTokens()
return {
...state,
steps: { ...state.steps, [action.step]: 'done' },
reports: {
...state.reports,
[action.step]: [...(state.reports[action.step] ?? []), action.report],
},
tokensByStep: {
...state.tokensByStep,
[action.step]: { in: prev.in + dIn, out: prev.out + dOut },
},
tokensTotal: {
in: state.tokensTotal.in + dIn,
out: state.tokensTotal.out + dOut,
},
}
}
case 'RUN_COMPLETE':
return { ...state, status: 'complete', verdict: action.decision as RunStreamState['verdict'] }
@@ -56,17 +71,43 @@ export function useRunStream(runId: string): RunStreamState {
const [state, dispatch] = useReducer(reducer, initialState)
useEffect(() => {
let close: (() => void) | undefined
let aborted = false
getRun(runId).then((run) => {
if (aborted) return
if (run.status === 'complete' && run.reports) {
dispatch({ type: 'CONNECTED' })
for (const [key, report] of Object.entries(run.reports)) {
const lastColon = key.lastIndexOf(':')
const step = key.slice(0, lastColon) as AgentStep
const turn = parseInt(key.slice(lastColon + 1), 10)
const tok = run.token_usage?.[key] ?? { tokens_in: 0, tokens_out: 0 }
dispatch({ type: 'AGENT_START', step, turn })
dispatch({ type: 'AGENT_COMPLETE', step, turn, report,
tokens_in: tok.tokens_in, tokens_out: tok.tokens_out })
}
dispatch({ type: 'RUN_COMPLETE', decision: run.decision ?? 'HOLD' })
return
}
const url = getRunStreamUrl(runId)
close = createSSEConnection(url, {
onOpen: () => dispatch({ type: 'CONNECTED' }),
onAgentStart: ({ step, turn }) =>
dispatch({ type: 'AGENT_START', step: step as AgentStep, turn }),
onAgentComplete: ({ step, turn, report, tokens_in, tokens_out }) =>
dispatch({ type: 'AGENT_COMPLETE', step: step as AgentStep, turn, report,
tokens_in, tokens_out }),
onRunComplete: ({ decision }) => dispatch({ type: 'RUN_COMPLETE', decision }),
onRunError: ({ message }) => dispatch({ type: 'RUN_ERROR', message }),
})
}).catch(() => {
if (!aborted) dispatch({ type: 'RUN_ERROR', message: 'Failed to load run' })
})
return () => { aborted = true; close?.() }
}, [runId])
return state

View File

@@ -1,10 +1,14 @@
import type { AgentStep, RunStatus } from '@/lib/types/run'
import type { Decision, StepStatus } from '@/lib/types/agents'
export type TokenCount = { in: number; out: number }
export type RunStreamState = {
status: RunStatus | 'connecting'
steps: Record<AgentStep, StepStatus>
reports: Record<AgentStep, string[]>
verdict: Decision | null
error: string | null
tokensByStep: Record<AgentStep, TokenCount>
tokensTotal: TokenCount
}

View File

@@ -14,14 +14,21 @@ async function apiFetch<T>(path: string, init?: RequestInit): Promise<T> {
return res.json() as Promise<T>
}
export type RunResult = RunSummary & {
config: RunConfig | null
reports: Record<string, string>
error: string | null
token_usage: Record<string, { tokens_in: number; tokens_out: number }> | null
}
export const createRun = (config: RunConfig): Promise<RunSummary> =>
apiFetch('/api/runs', { method: 'POST', body: JSON.stringify(config) })
export const listRuns = (): Promise<RunSummary[]> =>
apiFetch('/api/runs')
export const getRun = (id: string): Promise<RunResult> =>
apiFetch<RunResult>(`/api/runs/${id}`)
export const getSettings = (): Promise<Settings> =>
apiFetch('/api/settings')

View File

@@ -1,6 +1,9 @@
export type SSEHandlers = {
onAgentStart?: (data: { step: string; turn: number }) => void
onAgentComplete?: (data: {
step: string; turn: number; report: string;
tokens_in?: number; tokens_out?: number
}) => void
onRunComplete?: (data: { decision: string; run_id: string }) => void
onRunError?: (data: { message: string }) => void
onOpen?: () => void