This commit is contained in:
claytonbrown 2026-04-20 05:41:16 -07:00 committed by GitHub
commit 7f1d607bfa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 806 additions and 145 deletions

View File

@ -253,7 +253,7 @@ def format_tokens(n):
return str(n)
def update_display(layout, spinner_text=None, stats_handler=None, start_time=None):
def update_display(layout, spinner_text=None, stats_handler=None, start_time=None, spend_tracker=None):
# Header with welcome message
layout["header"].update(
Panel(
@ -445,6 +445,15 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
tokens_str = "Tokens: --"
stats_parts.append(tokens_str)
# Cost display from spend tracker
if spend_tracker:
cost_str = f"${spend_tracker.total_cost_usd:.4f}"
if spend_tracker.max_cost is not None:
cost_str += f"/${spend_tracker.max_cost:.2f}"
if spend_tracker.budget_exceeded:
cost_str = f"[red]{cost_str} OVER BUDGET[/red]"
stats_parts.append(f"Cost: {cost_str}")
stats_parts.append(f"Reports: {reports_completed}/{reports_total}")
# Elapsed time
@ -926,7 +935,7 @@ def format_tool_args(args, max_length=80) -> str:
return result[:max_length - 3] + "..."
return result
def run_analysis():
def run_analysis(max_cost: float | None = None):
# First get all user selections
selections = get_user_selections()
@ -943,10 +952,17 @@ def run_analysis():
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
config["anthropic_effort"] = selections.get("anthropic_effort")
config["output_language"] = selections.get("output_language", "English")
# Spend limit (CLI flag → config → DEFAULT_CONFIG)
if max_cost is not None:
config["max_cost"] = max_cost
# Create stats callback handler for tracking LLM/tool calls
stats_handler = StatsCallbackHandler()
# Create spend tracker with optional budget limit
from tradingagents import SpendTracker, BudgetExceededError
spend_tracker = SpendTracker(max_cost=config.get("max_cost"))
# Normalize analyst selection to predefined order (selection is a 'set', order is fixed)
selected_set = {analyst.value for analyst in selections["analysts"]}
selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set]
@ -956,7 +972,7 @@ def run_analysis():
selected_analyst_keys,
config=config,
debug=True,
callbacks=[stats_handler],
callbacks=[stats_handler, spend_tracker],
)
# Initialize message buffer with selected analysts
@ -1018,7 +1034,7 @@ def run_analysis():
with Live(layout, refresh_per_second=4) as live:
# Initial display
update_display(layout, stats_handler=stats_handler, start_time=start_time)
update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker)
# Add initial messages
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
@ -1029,18 +1045,18 @@ def run_analysis():
"System",
f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}",
)
update_display(layout, stats_handler=stats_handler, start_time=start_time)
update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker)
# Update agent status to in_progress for the first analyst
first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst"
message_buffer.update_agent_status(first_analyst, "in_progress")
update_display(layout, stats_handler=stats_handler, start_time=start_time)
update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker)
# Create spinner text
spinner_text = (
f"Analyzing {selections['ticker']} on {selections['analysis_date']}..."
)
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time)
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker)
# Initialize state and get graph args with callbacks
init_agent_state = graph.propagator.create_initial_state(
@ -1048,119 +1064,140 @@ def run_analysis():
)
# Pass callbacks to graph config for tool execution tracking
# (LLM tracking is handled separately via LLM constructor)
args = graph.propagator.get_graph_args(callbacks=[stats_handler])
args = graph.propagator.get_graph_args(callbacks=[stats_handler, spend_tracker])
# Notify spend tracker of ticker start
spend_tracker.begin_ticker(selections["ticker"])
# Stream the analysis
trace = []
for chunk in graph.graph.stream(init_agent_state, **args):
# Process all messages in chunk, deduplicating by message ID
for message in chunk.get("messages", []):
msg_id = getattr(message, "id", None)
if msg_id is not None:
if msg_id in message_buffer._processed_message_ids:
continue
message_buffer._processed_message_ids.add(msg_id)
budget_aborted = False
try:
for chunk in graph.graph.stream(init_agent_state, **args):
# Process all messages in chunk, deduplicating by message ID
for message in chunk.get("messages", []):
msg_id = getattr(message, "id", None)
if msg_id is not None:
if msg_id in message_buffer._processed_message_ids:
continue
message_buffer._processed_message_ids.add(msg_id)
msg_type, content = classify_message_type(message)
if content and content.strip():
message_buffer.add_message(msg_type, content)
msg_type, content = classify_message_type(message)
if content and content.strip():
message_buffer.add_message(msg_type, content)
if hasattr(message, "tool_calls") and message.tool_calls:
for tool_call in message.tool_calls:
if isinstance(tool_call, dict):
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
else:
message_buffer.add_tool_call(tool_call.name, tool_call.args)
if hasattr(message, "tool_calls") and message.tool_calls:
for tool_call in message.tool_calls:
if isinstance(tool_call, dict):
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
else:
message_buffer.add_tool_call(tool_call.name, tool_call.args)
# Update analyst statuses based on report state (runs on every chunk)
update_analyst_statuses(message_buffer, chunk)
# Update analyst statuses based on report state (runs on every chunk)
update_analyst_statuses(message_buffer, chunk)
# Research Team - Handle Investment Debate State
if chunk.get("investment_debate_state"):
debate_state = chunk["investment_debate_state"]
bull_hist = debate_state.get("bull_history", "").strip()
bear_hist = debate_state.get("bear_history", "").strip()
judge = debate_state.get("judge_decision", "").strip()
# Research Team - Handle Investment Debate State
if chunk.get("investment_debate_state"):
debate_state = chunk["investment_debate_state"]
bull_hist = debate_state.get("bull_history", "").strip()
bear_hist = debate_state.get("bear_history", "").strip()
judge = debate_state.get("judge_decision", "").strip()
# Only update status when there's actual content
if bull_hist or bear_hist:
update_research_team_status("in_progress")
if bull_hist:
message_buffer.update_report_section(
"investment_plan", f"### Bull Researcher Analysis\n{bull_hist}"
)
if bear_hist:
message_buffer.update_report_section(
"investment_plan", f"### Bear Researcher Analysis\n{bear_hist}"
)
if judge:
message_buffer.update_report_section(
"investment_plan", f"### Research Manager Decision\n{judge}"
)
update_research_team_status("completed")
message_buffer.update_agent_status("Trader", "in_progress")
# Trading Team
if chunk.get("trader_investment_plan"):
message_buffer.update_report_section(
"trader_investment_plan", chunk["trader_investment_plan"]
)
if message_buffer.agent_status.get("Trader") != "completed":
message_buffer.update_agent_status("Trader", "completed")
message_buffer.update_agent_status("Aggressive Analyst", "in_progress")
# Risk Management Team - Handle Risk Debate State
if chunk.get("risk_debate_state"):
risk_state = chunk["risk_debate_state"]
agg_hist = risk_state.get("aggressive_history", "").strip()
con_hist = risk_state.get("conservative_history", "").strip()
neu_hist = risk_state.get("neutral_history", "").strip()
judge = risk_state.get("judge_decision", "").strip()
if agg_hist:
if message_buffer.agent_status.get("Aggressive Analyst") != "completed":
message_buffer.update_agent_status("Aggressive Analyst", "in_progress")
message_buffer.update_report_section(
"final_trade_decision", f"### Aggressive Analyst Analysis\n{agg_hist}"
)
if con_hist:
if message_buffer.agent_status.get("Conservative Analyst") != "completed":
message_buffer.update_agent_status("Conservative Analyst", "in_progress")
message_buffer.update_report_section(
"final_trade_decision", f"### Conservative Analyst Analysis\n{con_hist}"
)
if neu_hist:
if message_buffer.agent_status.get("Neutral Analyst") != "completed":
message_buffer.update_agent_status("Neutral Analyst", "in_progress")
message_buffer.update_report_section(
"final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}"
)
if judge:
if message_buffer.agent_status.get("Portfolio Manager") != "completed":
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
# Only update status when there's actual content
if bull_hist or bear_hist:
update_research_team_status("in_progress")
if bull_hist:
message_buffer.update_report_section(
"final_trade_decision", f"### Portfolio Manager Decision\n{judge}"
"investment_plan", f"### Bull Researcher Analysis\n{bull_hist}"
)
message_buffer.update_agent_status("Aggressive Analyst", "completed")
message_buffer.update_agent_status("Conservative Analyst", "completed")
message_buffer.update_agent_status("Neutral Analyst", "completed")
message_buffer.update_agent_status("Portfolio Manager", "completed")
if bear_hist:
message_buffer.update_report_section(
"investment_plan", f"### Bear Researcher Analysis\n{bear_hist}"
)
if judge:
message_buffer.update_report_section(
"investment_plan", f"### Research Manager Decision\n{judge}"
)
update_research_team_status("completed")
message_buffer.update_agent_status("Trader", "in_progress")
# Update the display
update_display(layout, stats_handler=stats_handler, start_time=start_time)
# Trading Team
if chunk.get("trader_investment_plan"):
message_buffer.update_report_section(
"trader_investment_plan", chunk["trader_investment_plan"]
)
if message_buffer.agent_status.get("Trader") != "completed":
message_buffer.update_agent_status("Trader", "completed")
message_buffer.update_agent_status("Aggressive Analyst", "in_progress")
trace.append(chunk)
# Risk Management Team - Handle Risk Debate State
if chunk.get("risk_debate_state"):
risk_state = chunk["risk_debate_state"]
agg_hist = risk_state.get("aggressive_history", "").strip()
con_hist = risk_state.get("conservative_history", "").strip()
neu_hist = risk_state.get("neutral_history", "").strip()
judge = risk_state.get("judge_decision", "").strip()
if agg_hist:
if message_buffer.agent_status.get("Aggressive Analyst") != "completed":
message_buffer.update_agent_status("Aggressive Analyst", "in_progress")
message_buffer.update_report_section(
"final_trade_decision", f"### Aggressive Analyst Analysis\n{agg_hist}"
)
if con_hist:
if message_buffer.agent_status.get("Conservative Analyst") != "completed":
message_buffer.update_agent_status("Conservative Analyst", "in_progress")
message_buffer.update_report_section(
"final_trade_decision", f"### Conservative Analyst Analysis\n{con_hist}"
)
if neu_hist:
if message_buffer.agent_status.get("Neutral Analyst") != "completed":
message_buffer.update_agent_status("Neutral Analyst", "in_progress")
message_buffer.update_report_section(
"final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}"
)
if judge:
if message_buffer.agent_status.get("Portfolio Manager") != "completed":
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
message_buffer.update_report_section(
"final_trade_decision", f"### Portfolio Manager Decision\n{judge}"
)
message_buffer.update_agent_status("Aggressive Analyst", "completed")
message_buffer.update_agent_status("Conservative Analyst", "completed")
message_buffer.update_agent_status("Neutral Analyst", "completed")
message_buffer.update_agent_status("Portfolio Manager", "completed")
# Update the display
update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker)
trace.append(chunk)
except BudgetExceededError:
budget_aborted = True
message_buffer.add_message(
"System",
f"[red]Budget exceeded (${spend_tracker.total_cost_usd:.4f} / ${spend_tracker.max_cost:.2f}). Saving partial results.[/red]",
)
update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker)
# Log per-ticker and cumulative spend to stderr
spend_tracker.log_ticker_spend(selections["ticker"])
# Get final state and decision
final_state = trace[-1]
decision = graph.process_signal(final_state["final_trade_decision"])
if trace:
final_state = trace[-1]
else:
final_state = dict(init_agent_state)
if budget_aborted:
final_state.setdefault("final_trade_decision", "BUDGET_EXCEEDED")
decision = graph.process_signal(final_state.get("final_trade_decision", "BUDGET_EXCEEDED"))
# Update all agent statuses to completed
for agent in message_buffer.agent_status:
message_buffer.update_agent_status(agent, "completed")
message_buffer.add_message(
"System", f"Completed analysis for {selections['analysis_date']}"
"System",
f"{'Budget exceeded — partial' if budget_aborted else 'Completed'} analysis for {selections['analysis_date']}"
)
# Update final report sections
@ -1168,10 +1205,13 @@ def run_analysis():
if section in final_state:
message_buffer.update_report_section(section, final_state[section])
update_display(layout, stats_handler=stats_handler, start_time=start_time)
update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker)
# Post-analysis prompts (outside Live context for clean interaction)
console.print("\n[bold cyan]Analysis Complete![/bold cyan]\n")
if budget_aborted:
console.print("\n[bold red]Analysis aborted — budget exceeded. Partial results saved.[/bold red]\n")
else:
console.print("\n[bold cyan]Analysis Complete![/bold cyan]\n")
# Prompt to save report
save_choice = typer.prompt("Save report?", default="Y").strip().upper()
@ -1197,8 +1237,12 @@ def run_analysis():
@app.command()
def analyze():
run_analysis()
def analyze(
max_cost: Optional[float] = typer.Option(
None, "--max-cost", help="Maximum estimated USD spend for this run. Aborts when exceeded."
),
):
run_analysis(max_cost=max_cost)
if __name__ == "__main__":

237
tests/test_spend_tracker.py Normal file
View File

@ -0,0 +1,237 @@
"""Tests for spend tracking, budget abort, and partial result saving."""
import unittest
from unittest.mock import MagicMock
from uuid import uuid4
from langchain_core.outputs import LLMResult, Generation
from tradingagents.spend_tracker import (
AuditEntry,
BudgetExceededError,
SpendTracker,
TokenRecord,
_get_pricing,
MODEL_PRICING,
_DEFAULT_PRICING,
)
def _make_llm_result(
    prompt_tokens: int = 100,
    completion_tokens: int = 50,
    model: str = "gpt-4o",
) -> LLMResult:
    """Construct a bare LLMResult carrying only what the tracker reads.

    The tracker consumes ``llm_output["token_usage"]`` and
    ``llm_output["model_name"]``; the generation text is a placeholder.
    """
    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
    return LLMResult(
        generations=[[Generation(text="ok")]],
        llm_output={"token_usage": usage, "model_name": model},
    )
class TestSpendTrackerAccumulation(unittest.TestCase):
    """Token and cost accumulation."""

    def test_single_call_updates_totals(self):
        # One on_llm_end with 100/50 tokens must update all running totals
        # and produce a strictly positive estimated cost.
        tracker = SpendTracker()
        tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4())
        self.assertEqual(tracker.prompt_tokens, 100)
        self.assertEqual(tracker.completion_tokens, 50)
        self.assertEqual(tracker.total_tokens, 150)
        self.assertGreater(tracker.total_cost_usd, 0)

    def test_multiple_calls_accumulate(self):
        # Totals sum across calls, and each call appends one TokenRecord.
        tracker = SpendTracker()
        tracker.on_llm_end(_make_llm_result(100, 50), run_id=uuid4())
        tracker.on_llm_end(_make_llm_result(200, 100), run_id=uuid4())
        self.assertEqual(tracker.prompt_tokens, 300)
        self.assertEqual(tracker.completion_tokens, 150)
        self.assertEqual(len(tracker.records), 2)

    def test_cost_calculation_matches_pricing(self):
        # With exactly 1M input + 1M output tokens, the cost equals the
        # per-1M (input, output) prices summed — a direct pricing check.
        tracker = SpendTracker()
        inp_price, out_price = MODEL_PRICING["gpt-4o"]
        tracker.on_llm_end(_make_llm_result(1_000_000, 1_000_000, "gpt-4o"), run_id=uuid4())
        expected = inp_price + out_price  # 1M tokens each
        self.assertAlmostEqual(tracker.total_cost_usd, expected, places=2)

    def test_reset_clears_all(self):
        # reset() must zero counters, clear records/audit trail, and
        # drop the budget_exceeded flag.
        tracker = SpendTracker(max_cost=10.0)
        tracker.on_llm_end(_make_llm_result(100, 50), run_id=uuid4())
        tracker.budget_exceeded = True
        tracker.reset()
        self.assertEqual(tracker.total_tokens, 0)
        self.assertEqual(tracker.total_cost_usd, 0.0)
        self.assertFalse(tracker.budget_exceeded)
        self.assertEqual(len(tracker.records), 0)
        self.assertEqual(len(tracker.audit_trail), 0)
class TestBudgetAbort(unittest.TestCase):
    """Budget exceeded detection and abort."""

    def test_budget_exceeded_flag_set(self):
        # Exceeding max_cost flips the flag on the very call that crosses it.
        tracker = SpendTracker(max_cost=0.0001)
        tracker.on_llm_end(_make_llm_result(1000, 500, "gpt-4o"), run_id=uuid4())
        self.assertTrue(tracker.budget_exceeded)

    def test_on_llm_start_raises_when_over_budget(self):
        # Once over budget, the *next* LLM start must abort with
        # BudgetExceededError rather than spend more.
        tracker = SpendTracker(max_cost=0.0001)
        tracker.on_llm_end(_make_llm_result(1000, 500, "gpt-4o"), run_id=uuid4())
        self.assertTrue(tracker.budget_exceeded)
        with self.assertRaises(BudgetExceededError):
            tracker.on_llm_start({}, ["prompt"], run_id=uuid4())

    def test_on_chat_model_start_raises_when_over_budget(self):
        # Chat-model entry point mirrors on_llm_start's abort behavior.
        tracker = SpendTracker(max_cost=0.0001)
        tracker.on_llm_end(_make_llm_result(1000, 500, "gpt-4o"), run_id=uuid4())
        with self.assertRaises(BudgetExceededError):
            tracker.on_chat_model_start({}, [[]], run_id=uuid4())

    def test_no_abort_when_under_budget(self):
        # A generous budget must neither set the flag nor raise.
        tracker = SpendTracker(max_cost=999.0)
        tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4())
        self.assertFalse(tracker.budget_exceeded)
        # Should not raise
        tracker.on_llm_start({}, ["prompt"], run_id=uuid4())

    def test_no_budget_means_never_exceeded(self):
        # max_cost=None disables budget enforcement entirely.
        tracker = SpendTracker(max_cost=None)
        tracker.on_llm_end(_make_llm_result(1_000_000, 1_000_000, "gpt-4o"), run_id=uuid4())
        self.assertFalse(tracker.budget_exceeded)
class TestPartialResults(unittest.TestCase):
    """Verify partial results are preserved when budget is exceeded."""

    def test_propagate_returns_partial_on_budget_exceeded(self):
        """Simulate what TradingAgentsGraph.propagate() does on BudgetExceededError."""
        # This mirrors the logic in trading_graph.py propagate() method:
        # on BudgetExceededError, it builds partial state from init_agent_state
        init_state = {
            "company_of_interest": "AAPL",
            "trade_date": "2025-01-15",
            "market_report": "partial market data",
            "messages": [],
        }
        # Simulate the except BudgetExceededError branch: copy the partial
        # state and stamp a sentinel decision without overwriting real data.
        final_state = dict(init_state)
        final_state.setdefault("final_trade_decision", "BUDGET_EXCEEDED")
        self.assertEqual(final_state["final_trade_decision"], "BUDGET_EXCEEDED")
        self.assertEqual(final_state["company_of_interest"], "AAPL")
        self.assertEqual(final_state["market_report"], "partial market data")

    def test_budget_exceeded_error_message_contains_amounts(self):
        # The abort error message must surface both the budget ("$0.01",
        # formatted with :.2f) and the amount spent so far.
        tracker = SpendTracker(max_cost=0.01)
        tracker.on_llm_end(_make_llm_result(10000, 5000, "gpt-4o"), run_id=uuid4())
        try:
            tracker.on_llm_start({}, ["prompt"], run_id=uuid4())
            self.fail("Expected BudgetExceededError")
        except BudgetExceededError as e:
            msg = str(e)
            self.assertIn("$0.01", msg)  # budget
            self.assertIn("spent", msg.lower())
class TestTickerSpend(unittest.TestCase):
    """Per-ticker spend tracking."""

    def test_ticker_cost_recorded(self):
        # begin_ticker marks a spend baseline; log_ticker_spend records the
        # delta under the ticker symbol in ticker_costs.
        tracker = SpendTracker()
        tracker.begin_ticker("AAPL")
        tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4())
        tracker.log_ticker_spend("AAPL")
        self.assertIn("AAPL", tracker.ticker_costs)
        self.assertGreater(tracker.ticker_costs["AAPL"], 0)

    def test_multiple_tickers_tracked_separately(self):
        # Each begin/log pair isolates a ticker's spend; MSFT's larger call
        # (200/100 tokens vs 100/50) must cost more than AAPL's.
        tracker = SpendTracker()
        tracker.begin_ticker("AAPL")
        tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4())
        tracker.log_ticker_spend("AAPL")
        tracker.begin_ticker("MSFT")
        tracker.on_llm_end(_make_llm_result(200, 100, "gpt-4o"), run_id=uuid4())
        tracker.log_ticker_spend("MSFT")
        self.assertIn("AAPL", tracker.ticker_costs)
        self.assertIn("MSFT", tracker.ticker_costs)
        self.assertGreater(tracker.ticker_costs["MSFT"], tracker.ticker_costs["AAPL"])
class TestAuditTrail(unittest.TestCase):
    """Delegation chain audit trail."""

    def test_llm_call_creates_audit_entry(self):
        # Every costed LLM call appends one AuditEntry with the model name.
        tracker = SpendTracker()
        tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4())
        self.assertEqual(len(tracker.audit_trail), 1)
        entry = tracker.audit_trail[0]
        self.assertEqual(entry.call_type, "llm")
        self.assertEqual(entry.name, "gpt-4o")
        self.assertGreater(entry.cost_usd, 0)

    def test_tool_call_creates_audit_entry(self):
        # A start/end pair with the same run_id yields a single "tool" entry
        # carrying the tool name from on_tool_start's serialized payload.
        tracker = SpendTracker()
        run_id = uuid4()
        tracker.on_tool_start({"name": "get_stock_data"}, "", run_id=run_id)
        tracker.on_tool_end("result", run_id=run_id)
        self.assertEqual(len(tracker.audit_trail), 1)
        entry = tracker.audit_trail[0]
        self.assertEqual(entry.call_type, "tool")
        self.assertEqual(entry.name, "get_stock_data")

    def test_agent_name_resolved_from_chain(self):
        # An LLM call whose parent_run_id points at a named chain should be
        # attributed to that chain's name in the audit entry.
        tracker = SpendTracker()
        parent_id = uuid4()
        child_id = uuid4()
        tracker.on_chain_start(
            {"name": "market_analyst"}, {}, run_id=parent_id
        )
        tracker.on_llm_end(
            _make_llm_result(100, 50, "gpt-4o"),
            run_id=child_id,
            parent_run_id=parent_id,
        )
        self.assertEqual(tracker.audit_trail[0].agent, "market_analyst")

    def test_format_audit_trail_output(self):
        # The human-readable rendering names the feature and each LLM call.
        tracker = SpendTracker()
        tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4())
        output = tracker.format_audit_trail()
        self.assertIn("Delegation chain", output)
        self.assertIn("LLM(gpt-4o)", output)

    def test_empty_audit_trail(self):
        # With no recorded calls, formatting still returns a sensible message.
        tracker = SpendTracker()
        output = tracker.format_audit_trail()
        self.assertIn("No calls recorded", output)
class TestPricing(unittest.TestCase):
    """Model pricing lookup."""

    def test_exact_match(self):
        # An exact table key resolves directly to its pricing tuple.
        self.assertEqual(MODEL_PRICING["gpt-4o"], _get_pricing("gpt-4o"))

    def test_prefix_match(self):
        # Dated variants (e.g. "gpt-4o-2024-08-06") fall back to the
        # closest prefix entry ("gpt-4o").
        pricing = _get_pricing("gpt-4o-2024-08-06")
        self.assertEqual(MODEL_PRICING["gpt-4o"], pricing)

    def test_unknown_model_returns_default(self):
        # Models with no table entry get the conservative default ceiling.
        pricing = _get_pricing("totally-unknown-model-xyz")
        self.assertEqual(_DEFAULT_PRICING, pricing)
if __name__ == "__main__":
unittest.main()

View File

@ -1,2 +1,6 @@
import os
os.environ.setdefault("PYTHONUTF8", "1")
from .spend_tracker import SpendTracker, BudgetExceededError, MODEL_PRICING, AuditEntry # noqa: E402
__all__ = ["SpendTracker", "BudgetExceededError", "MODEL_PRICING", "AuditEntry"]

View File

@ -18,6 +18,8 @@ DEFAULT_CONFIG = {
# Output language for analyst reports and final decision
# Internal agent debate stays in English for reasoning quality
"output_language": "English",
# Spend limit: maximum estimated USD cost per run (None = unlimited)
"max_cost": None,
# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,

View File

@ -190,70 +190,94 @@ class TradingAgentsGraph:
}
def propagate(self, company_name, trade_date):
"""Run the trading agents graph for a company on a specific date."""
"""Run the trading agents graph for a company on a specific date.
Returns:
Tuple of (final_state, signal). If budget is exceeded mid-run,
returns partial state with ``final_trade_decision`` set to
``"BUDGET_EXCEEDED"`` and signal ``"hold"``.
"""
from tradingagents.spend_tracker import BudgetExceededError, SpendTracker
self.ticker = company_name
# Notify spend trackers of ticker start
for cb in self.callbacks:
if isinstance(cb, SpendTracker):
cb.begin_ticker(company_name)
# Initialize state
init_agent_state = self.propagator.create_initial_state(
company_name, trade_date
)
args = self.propagator.get_graph_args()
if self.debug:
# Debug mode with tracing
trace = []
for chunk in self.graph.stream(init_agent_state, **args):
if len(chunk["messages"]) == 0:
pass
else:
chunk["messages"][-1].pretty_print()
trace.append(chunk)
try:
if self.debug:
# Debug mode with tracing
trace = []
for chunk in self.graph.stream(init_agent_state, **args):
if len(chunk["messages"]) == 0:
pass
else:
chunk["messages"][-1].pretty_print()
trace.append(chunk)
final_state = trace[-1]
else:
# Standard mode without tracing
final_state = self.graph.invoke(init_agent_state, **args)
final_state = trace[-1] if trace else dict(init_agent_state)
else:
# Standard mode without tracing
final_state = self.graph.invoke(init_agent_state, **args)
except BudgetExceededError:
# Graceful abort: build partial state from what we have
final_state = dict(init_agent_state)
if self.debug and trace:
final_state.update(trace[-1])
final_state.setdefault("final_trade_decision", "BUDGET_EXCEEDED")
# Store current state for reflection
self.curr_state = final_state
# Log state
# Log state (partial or full)
self._log_state(trade_date, final_state)
# Return decision and processed signal
return final_state, self.process_signal(final_state["final_trade_decision"])
# Log per-ticker and cumulative spend to stderr
for cb in self.callbacks:
if isinstance(cb, SpendTracker):
cb.log_ticker_spend(company_name)
cb.log_audit_trail()
decision_text = final_state.get("final_trade_decision", "BUDGET_EXCEEDED")
return final_state, self.process_signal(decision_text)
def _log_state(self, trade_date, final_state):
"""Log the final state to a JSON file."""
"""Log the final state to a JSON file. Tolerates partial state from budget abort."""
invest_debate = final_state.get("investment_debate_state") or {}
risk_debate = final_state.get("risk_debate_state") or {}
self.log_states_dict[str(trade_date)] = {
"company_of_interest": final_state["company_of_interest"],
"trade_date": final_state["trade_date"],
"market_report": final_state["market_report"],
"sentiment_report": final_state["sentiment_report"],
"news_report": final_state["news_report"],
"fundamentals_report": final_state["fundamentals_report"],
"company_of_interest": final_state.get("company_of_interest", ""),
"trade_date": final_state.get("trade_date", ""),
"market_report": final_state.get("market_report", ""),
"sentiment_report": final_state.get("sentiment_report", ""),
"news_report": final_state.get("news_report", ""),
"fundamentals_report": final_state.get("fundamentals_report", ""),
"investment_debate_state": {
"bull_history": final_state["investment_debate_state"]["bull_history"],
"bear_history": final_state["investment_debate_state"]["bear_history"],
"history": final_state["investment_debate_state"]["history"],
"current_response": final_state["investment_debate_state"][
"current_response"
],
"judge_decision": final_state["investment_debate_state"][
"judge_decision"
],
"bull_history": invest_debate.get("bull_history", ""),
"bear_history": invest_debate.get("bear_history", ""),
"history": invest_debate.get("history", ""),
"current_response": invest_debate.get("current_response", ""),
"judge_decision": invest_debate.get("judge_decision", ""),
},
"trader_investment_decision": final_state["trader_investment_plan"],
"trader_investment_decision": final_state.get("trader_investment_plan", ""),
"risk_debate_state": {
"aggressive_history": final_state["risk_debate_state"]["aggressive_history"],
"conservative_history": final_state["risk_debate_state"]["conservative_history"],
"neutral_history": final_state["risk_debate_state"]["neutral_history"],
"history": final_state["risk_debate_state"]["history"],
"judge_decision": final_state["risk_debate_state"]["judge_decision"],
"aggressive_history": risk_debate.get("aggressive_history", ""),
"conservative_history": risk_debate.get("conservative_history", ""),
"neutral_history": risk_debate.get("neutral_history", ""),
"history": risk_debate.get("history", ""),
"judge_decision": risk_debate.get("judge_decision", ""),
},
"investment_plan": final_state["investment_plan"],
"final_trade_decision": final_state["final_trade_decision"],
"investment_plan": final_state.get("investment_plan", ""),
"final_trade_decision": final_state.get("final_trade_decision", ""),
}
# Save to file

View File

@ -0,0 +1,350 @@
"""LangChain callback handler that tracks cumulative token usage and cost."""
from __future__ import annotations
import sys
import threading
import time
from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
class BudgetExceededError(Exception):
    """Raised when the spend budget is exceeded mid-graph.

    Graph drivers catch this to abort streaming gracefully and save
    whatever partial results have accumulated so far.
    """
    # NOTE: the redundant `pass` after the docstring was removed — a class
    # body with a docstring needs no placeholder statement.
@dataclass
class AuditEntry:
    """Single entry in the delegation chain audit trail."""

    agent: str  # resolved chain/agent name that triggered the call
    call_type: str  # "llm" or "tool"
    name: str  # model name or tool name
    cost_usd: float  # estimated USD cost; 0.0 for tool calls
    prompt_tokens: int  # 0 for tool calls
    completion_tokens: int  # 0 for tool calls
    # time.monotonic() at creation — useful for ordering/deltas only,
    # not a wall-clock timestamp.
    timestamp: float = field(default_factory=time.monotonic)
# Pricing per 1M tokens: (input_usd, output_usd)
# Source: provider pricing pages as of 2025-Q2. Add new models as needed.
MODEL_PRICING: dict[str, tuple[float, float]] = {
# OpenAI
"gpt-4o": (2.50, 10.00),
"gpt-4o-mini": (0.15, 0.60),
"gpt-4-turbo": (10.00, 30.00),
"gpt-4": (30.00, 60.00),
"gpt-3.5-turbo": (0.50, 1.50),
"o1": (15.00, 60.00),
"o1-mini": (3.00, 12.00),
# Anthropic
"claude-3-5-sonnet-20241022": (3.00, 15.00),
"claude-3-5-haiku-20241022": (0.80, 4.00),
"claude-3-opus-20240229": (15.00, 75.00),
# Google
"gemini-2.0-flash": (0.10, 0.40),
"gemini-1.5-pro": (1.25, 5.00),
"gemini-1.5-flash": (0.075, 0.30),
}
# Fallback: generous estimate so unknown models still get a cost ceiling
_DEFAULT_PRICING: tuple[float, float] = (10.00, 30.00)
def _get_pricing(model: str) -> tuple[float, float]:
"""Return (input_per_1M, output_per_1M) for *model*, with prefix matching."""
if model in MODEL_PRICING:
return MODEL_PRICING[model]
# Try prefix match (e.g. "gpt-4o-2024-08-06" → "gpt-4o")
for key in sorted(MODEL_PRICING, key=len, reverse=True):
if model.startswith(key):
return MODEL_PRICING[key]
return _DEFAULT_PRICING
@dataclass
class TokenRecord:
    """Single LLM call token record."""

    model: str  # model name as reported in the LLM response
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int  # as reported; typically prompt + completion
class SpendTracker(BaseCallbackHandler):
"""Tracks cumulative token usage and estimated USD cost.
Thread-safe: LangGraph may invoke nodes concurrently.
Args:
max_cost: Optional USD budget. When exceeded, ``budget_exceeded``
becomes ``True``. Callers should check this flag and abort.
Usage::
tracker = SpendTracker(max_cost=0.50)
graph = TradingAgentsGraph(callbacks=[tracker])
# ... run graph ...
print(tracker.total_cost_usd)
print(tracker.budget_exceeded)
"""
    def __init__(self, max_cost: float | None = None) -> None:
        """Initialize counters, budget state, and audit bookkeeping.

        Args:
            max_cost: Optional USD budget; ``None`` disables budget checks.
        """
        super().__init__()
        # Guards all mutable state below — callback hooks may fire concurrently.
        self._lock = threading.Lock()
        self.prompt_tokens: int = 0
        self.completion_tokens: int = 0
        self.total_tokens: int = 0
        self.total_cost_usd: float = 0.0
        self.max_cost: float | None = max_cost
        # Set once cumulative cost passes max_cost; checked by on_llm_start /
        # on_chat_model_start, which raise BudgetExceededError when it is True.
        self.budget_exceeded: bool = False
        self.records: list[TokenRecord] = []
        # Per-ticker spend tracking
        self.ticker_costs: dict[str, float] = {}
        self._ticker_snapshot: float = 0.0  # cumulative cost at last ticker start
        # Delegation chain audit trail
        self.audit_trail: list[AuditEntry] = []
        self._run_names: dict[UUID, str] = {}  # run_id → chain/agent name
        self._run_parents: dict[UUID, UUID] = {}  # run_id → parent_run_id
        self._tool_starts: dict[UUID, str] = {}  # run_id → tool name
# -- LangChain callback hooks ------------------------------------------
def on_chain_start(
self,
serialized: dict[str, Any],
inputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
"""Track chain/agent names for the delegation audit trail."""
name = serialized.get("name") or serialized.get("id", [""])[-1] or ""
with self._lock:
if name:
self._run_names[run_id] = name
if parent_run_id is not None:
self._run_parents[run_id] = parent_run_id
def on_llm_start(
self,
serialized: dict[str, Any],
prompts: list[str],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any,
) -> None:
"""Abort before the next LLM call if budget already exceeded."""
with self._lock:
if parent_run_id is not None:
self._run_parents[run_id] = parent_run_id
if self.budget_exceeded:
raise BudgetExceededError(
f"Budget ${self.max_cost:.2f} exceeded (spent ${self.total_cost_usd:.4f})"
)
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[Any]],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any,
) -> None:
"""Abort before the next chat model call if budget already exceeded."""
with self._lock:
if parent_run_id is not None:
self._run_parents[run_id] = parent_run_id
if self.budget_exceeded:
raise BudgetExceededError(
f"Budget ${self.max_cost:.2f} exceeded (spent ${self.total_cost_usd:.4f})"
)
def on_tool_start(
self,
serialized: dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any,
) -> None:
"""Record tool invocation start for audit trail."""
tool_name = serialized.get("name") or "unknown_tool"
with self._lock:
self._tool_starts[run_id] = tool_name
if parent_run_id is not None:
self._run_parents[run_id] = parent_run_id
def on_tool_end(
    self,
    output: Any,
    *,
    run_id: UUID,
    parent_run_id: UUID | None = None,
    **kwargs: Any,
) -> None:
    """Append a zero-cost tool entry to the audit trail.

    The tool name recorded by ``on_tool_start`` is consumed here; if the
    start hook never fired, the entry is labelled "unknown_tool".
    """
    with self._lock:
        name = self._tool_starts.pop(run_id, "unknown_tool")
        # Attribute the call to the nearest named ancestor chain/agent.
        owner = self._resolve_agent(parent_run_id or run_id)
        entry = AuditEntry(
            agent=owner,
            call_type="tool",
            name=name,
            cost_usd=0.0,
            prompt_tokens=0,
            completion_tokens=0,
        )
        self.audit_trail.append(entry)
def on_llm_end(
    self,
    response: LLMResult,
    *,
    run_id: UUID,
    parent_run_id: UUID | None = None,
    **kwargs: Any,
) -> None:
    """Accumulate token usage and cost for a completed LLM call.

    Updates the aggregate counters, appends a ``TokenRecord`` and an
    ``AuditEntry``, and flips ``budget_exceeded`` once cumulative spend
    reaches ``max_cost``.
    """
    usage = self._extract_usage(response)
    if usage is None:
        # Provider reported no token usage; nothing to account for.
        return
    prompt_toks, completion_toks, total_toks = usage
    model_name = self._extract_model(response)
    in_rate, out_rate = _get_pricing(model_name)
    # Pricing is quoted in USD per million tokens.
    cost = (prompt_toks * in_rate + completion_toks * out_rate) / 1_000_000
    with self._lock:
        self.prompt_tokens += prompt_toks
        self.completion_tokens += completion_toks
        self.total_tokens += total_toks
        self.total_cost_usd += cost
        self.records.append(
            TokenRecord(
                model=model_name,
                prompt_tokens=prompt_toks,
                completion_tokens=completion_toks,
                total_tokens=total_toks,
            )
        )
        # Audit trail: attribute the spend to the nearest named agent/chain.
        owner = self._resolve_agent(parent_run_id or run_id)
        self.audit_trail.append(
            AuditEntry(
                agent=owner,
                call_type="llm",
                name=model_name,
                cost_usd=cost,
                prompt_tokens=prompt_toks,
                completion_tokens=completion_toks,
            )
        )
        if self.max_cost is not None and self.total_cost_usd >= self.max_cost:
            self.budget_exceeded = True
# -- helpers -----------------------------------------------------------
def _resolve_agent(self, run_id: UUID | None) -> str:
    """Walk up the parent chain to the nearest named agent/chain.

    Caller must hold ``_lock``. Cycles in the parent map are broken by
    tracking visited ids; returns "unknown" if no named ancestor exists.
    """
    seen: set[UUID] = set()
    node = run_id
    while node and node not in seen:
        seen.add(node)
        label = self._run_names.get(node)
        if label:
            return label
        node = self._run_parents.get(node)
    return "unknown"
@staticmethod
def _extract_usage(response: LLMResult) -> tuple[int, int, int] | None:
    """Pull token counts from an ``LLMResult``.

    Sources are checked in order:

    1. ``response.llm_output["token_usage"|"usage"]`` (most providers).
    2. Per-generation ``generation_info["usage"|"token_usage"]``.
    3. Per-generation ``message.usage_metadata`` — modern langchain-core
       chat models report usage here with ``input_tokens`` /
       ``output_tokens`` keys. The original only claimed this fallback in
       a comment but never read it, silently dropping usage (and thus
       cost/budget enforcement) for those providers.

    Returns:
        ``(prompt, completion, total)`` or ``None`` if no usage found.
    """
    def _normalize(usage: dict) -> tuple[int, int, int]:
        # Accept both OpenAI-style and usage_metadata-style key names.
        prompt = usage.get("prompt_tokens", 0) or usage.get("input_tokens", 0) or 0
        completion = usage.get("completion_tokens", 0) or usage.get("output_tokens", 0) or 0
        total = usage.get("total_tokens", 0) or (prompt + completion)
        return prompt, completion, total

    # Most providers put token_usage in llm_output.
    llm_out = response.llm_output or {}
    usage = llm_out.get("token_usage") or llm_out.get("usage") or {}
    if usage:
        return _normalize(usage)
    # Fallback: per-generation metadata.
    for gen_list in response.generations:
        for gen in gen_list:
            meta = getattr(gen, "generation_info", None) or {}
            usage = meta.get("usage") or meta.get("token_usage") or {}
            if usage:
                return _normalize(usage)
            # ChatGeneration: usage may live on the message itself.
            message = getattr(gen, "message", None)
            msg_usage = getattr(message, "usage_metadata", None) or {}
            if msg_usage:
                return _normalize(msg_usage)
    return None
@staticmethod
def _extract_model(response: LLMResult) -> str:
    """Best-effort model name from ``llm_output`` ("model_name", then "model")."""
    info = response.llm_output or {}
    name = info.get("model_name", "")
    return name or info.get("model", "unknown")
def begin_ticker(self, ticker: str) -> None:
    """Snapshot cumulative spend at the start of *ticker*'s analysis.

    Call before ``propagate()``; ``format_ticker_spend`` later reports
    spend relative to this snapshot. The *ticker* argument is accepted
    for interface symmetry but is not recorded here.
    """
    with self._lock:
        snapshot = self.total_cost_usd
        self._ticker_snapshot = snapshot
def log_ticker_spend(self, ticker: str) -> None:
    """Emit the per-ticker and cumulative spend line on stderr.

    Call after ``propagate()`` for the ticker.
    """
    report = self.format_ticker_spend(ticker)
    print(report, file=sys.stderr)
def format_ticker_spend(self, ticker: str) -> str:
    """Return a one-line human-readable spend summary for *ticker*.

    NOTE(review): this also accrues the delta since the last
    ``begin_ticker`` snapshot into ``ticker_costs``, so calling it twice
    for the same ticker without an intervening ``begin_ticker``
    double-counts that delta — confirm callers invoke it exactly once
    per ticker.
    """
    with self._lock:
        delta = self.total_cost_usd - self._ticker_snapshot
        self.ticker_costs[ticker] = self.ticker_costs.get(ticker, 0.0) + delta
        budget = f" / budget ${self.max_cost:.2f}" if self.max_cost is not None else ""
        flag = " [OVER BUDGET]" if self.budget_exceeded else ""
        return (
            f"[spend] {ticker}: ${delta:.4f} | "
            f"cumulative: ${self.total_cost_usd:.4f}{budget}{flag}"
        )
def reset(self) -> None:
    """Reset all counters, records, and per-ticker/audit bookkeeping.

    Clears every piece of accumulated state so the tracker can be
    reused for a fresh run. This includes ``ticker_costs`` and
    ``_ticker_snapshot``, which the original left stale after a reset —
    a reused tracker then reported wrong per-ticker deltas (cost since a
    snapshot from the previous run) and carried over old ticker totals.
    """
    with self._lock:
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0
        self.total_cost_usd = 0.0
        self.budget_exceeded = False
        self.records.clear()
        self.audit_trail.clear()
        self._run_names.clear()
        self._run_parents.clear()
        self._tool_starts.clear()
        # Bug fix: per-ticker state must also be cleared for reuse.
        self.ticker_costs.clear()
        self._ticker_snapshot = 0.0
def format_audit_trail(self) -> str:
    """Render the recorded delegation chain as a human-readable string."""
    with self._lock:
        if not self.audit_trail:
            return "[audit] No calls recorded."
        out = ["[audit] Delegation chain:"]
        for idx, entry in enumerate(self.audit_trail, start=1):
            if entry.call_type == "llm":
                out.append(
                    f"  {idx}. {entry.agent} → LLM({entry.name}) "
                    f"tokens={entry.prompt_tokens}+{entry.completion_tokens} "
                    f"cost=${entry.cost_usd:.4f}"
                )
            else:
                out.append(f"  {idx}. {entry.agent} → tool({entry.name})")
        return "\n".join(out)
def log_audit_trail(self) -> None:
    """Emit the formatted audit trail on stderr."""
    trail = self.format_audit_trail()
    print(trail, file=sys.stderr)