diff --git a/cli/main.py b/cli/main.py index 33d110fb..4f353f0b 100644 --- a/cli/main.py +++ b/cli/main.py @@ -253,7 +253,7 @@ def format_tokens(n): return str(n) -def update_display(layout, spinner_text=None, stats_handler=None, start_time=None): +def update_display(layout, spinner_text=None, stats_handler=None, start_time=None, spend_tracker=None): # Header with welcome message layout["header"].update( Panel( @@ -445,6 +445,15 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non tokens_str = "Tokens: --" stats_parts.append(tokens_str) + # Cost display from spend tracker + if spend_tracker: + cost_str = f"${spend_tracker.total_cost_usd:.4f}" + if spend_tracker.max_cost is not None: + cost_str += f"/${spend_tracker.max_cost:.2f}" + if spend_tracker.budget_exceeded: + cost_str = f"[red]{cost_str} OVER BUDGET[/red]" + stats_parts.append(f"Cost: {cost_str}") + stats_parts.append(f"Reports: {reports_completed}/{reports_total}") # Elapsed time @@ -926,7 +935,7 @@ def format_tool_args(args, max_length=80) -> str: return result[:max_length - 3] + "..." return result -def run_analysis(): +def run_analysis(max_cost: float | None = None): # First get all user selections selections = get_user_selections() @@ -943,10 +952,17 @@ def run_analysis(): config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort") config["anthropic_effort"] = selections.get("anthropic_effort") config["output_language"] = selections.get("output_language", "English") + # Spend limit (CLI flag → config → DEFAULT_CONFIG) + if max_cost is not None: + config["max_cost"] = max_cost # Create stats callback handler for tracking LLM/tool calls stats_handler = StatsCallbackHandler() + # Create spend tracker with optional budget limit + from tradingagents import SpendTracker, BudgetExceededError + spend_tracker = SpendTracker(max_cost=config.get("max_cost")) + # Normalize analyst selection to predefined order (selection is a 'set', order is fixed) selected_set = {analyst.value for analyst in selections["analysts"]} selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set] @@ -956,7 +972,7 @@ def run_analysis(): selected_analyst_keys, config=config, debug=True, - callbacks=[stats_handler], + callbacks=[stats_handler, spend_tracker], ) # Initialize message buffer with selected analysts @@ -1018,7 +1034,7 @@ def run_analysis(): with Live(layout, refresh_per_second=4) as live: # Initial display - update_display(layout, stats_handler=stats_handler, start_time=start_time) + update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker) # Add initial messages message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}") @@ -1029,18 +1045,18 @@ def run_analysis(): "System", f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}", ) - update_display(layout, stats_handler=stats_handler, start_time=start_time) + update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker) # Update agent status to in_progress for the first analyst first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst" message_buffer.update_agent_status(first_analyst, "in_progress") - update_display(layout, stats_handler=stats_handler, start_time=start_time) + update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker) # Create spinner text spinner_text = ( f"Analyzing {selections['ticker']} on {selections['analysis_date']}..." ) - update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time) + update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker) # Initialize state and get graph args with callbacks init_agent_state = graph.propagator.create_initial_state( @@ -1048,119 +1064,140 @@ def run_analysis(): ) # Pass callbacks to graph config for tool execution tracking # (LLM tracking is handled separately via LLM constructor) - args = graph.propagator.get_graph_args(callbacks=[stats_handler]) + args = graph.propagator.get_graph_args(callbacks=[stats_handler, spend_tracker]) + + # Notify spend tracker of ticker start + spend_tracker.begin_ticker(selections["ticker"]) # Stream the analysis trace = [] - for chunk in graph.graph.stream(init_agent_state, **args): - # Process all messages in chunk, deduplicating by message ID - for message in chunk.get("messages", []): - msg_id = getattr(message, "id", None) - if msg_id is not None: - if msg_id in message_buffer._processed_message_ids: - continue - message_buffer._processed_message_ids.add(msg_id) + budget_aborted = False + try: + for chunk in graph.graph.stream(init_agent_state, **args): + # Process all messages in chunk, deduplicating by message ID + for message in chunk.get("messages", []): + msg_id = getattr(message, "id", None) + if msg_id is not None: + if msg_id in message_buffer._processed_message_ids: + continue + message_buffer._processed_message_ids.add(msg_id) - msg_type, content = classify_message_type(message) - if content and content.strip(): - message_buffer.add_message(msg_type, content) + msg_type, content = classify_message_type(message) + if content and content.strip(): + message_buffer.add_message(msg_type, content) - if hasattr(message, "tool_calls") and message.tool_calls: - for tool_call in message.tool_calls: - if isinstance(tool_call, dict): - message_buffer.add_tool_call(tool_call["name"], tool_call["args"]) - else: - message_buffer.add_tool_call(tool_call.name, tool_call.args) + if hasattr(message, "tool_calls") and message.tool_calls: + for tool_call in message.tool_calls: + if isinstance(tool_call, dict): + message_buffer.add_tool_call(tool_call["name"], tool_call["args"]) + else: + message_buffer.add_tool_call(tool_call.name, tool_call.args) - # Update analyst statuses based on report state (runs on every chunk) - update_analyst_statuses(message_buffer, chunk) + # Update analyst statuses based on report state (runs on every chunk) + update_analyst_statuses(message_buffer, chunk) - # Research Team - Handle Investment Debate State - if chunk.get("investment_debate_state"): - debate_state = chunk["investment_debate_state"] - bull_hist = debate_state.get("bull_history", "").strip() - bear_hist = debate_state.get("bear_history", "").strip() - judge = debate_state.get("judge_decision", "").strip() + # Research Team - Handle Investment Debate State + if chunk.get("investment_debate_state"): + debate_state = chunk["investment_debate_state"] + bull_hist = debate_state.get("bull_history", "").strip() + bear_hist = debate_state.get("bear_history", "").strip() + judge = debate_state.get("judge_decision", "").strip() - # Only update status when there's actual content - if bull_hist or bear_hist: - update_research_team_status("in_progress") - if bull_hist: - message_buffer.update_report_section( - "investment_plan", f"### Bull Researcher Analysis\n{bull_hist}" - ) - if bear_hist: - message_buffer.update_report_section( - "investment_plan", f"### Bear Researcher Analysis\n{bear_hist}" - ) - if judge: - message_buffer.update_report_section( - "investment_plan", f"### Research Manager Decision\n{judge}" - ) - update_research_team_status("completed") - message_buffer.update_agent_status("Trader", "in_progress") - - # Trading Team - if chunk.get("trader_investment_plan"): - message_buffer.update_report_section( - "trader_investment_plan", chunk["trader_investment_plan"] - ) - if message_buffer.agent_status.get("Trader") != "completed": - message_buffer.update_agent_status("Trader", "completed") - message_buffer.update_agent_status("Aggressive Analyst", "in_progress") - - # Risk Management Team - Handle Risk Debate State - if chunk.get("risk_debate_state"): - risk_state = chunk["risk_debate_state"] - agg_hist = risk_state.get("aggressive_history", "").strip() - con_hist = risk_state.get("conservative_history", "").strip() - neu_hist = risk_state.get("neutral_history", "").strip() - judge = risk_state.get("judge_decision", "").strip() - - if agg_hist: - if message_buffer.agent_status.get("Aggressive Analyst") != "completed": - message_buffer.update_agent_status("Aggressive Analyst", "in_progress") - message_buffer.update_report_section( - "final_trade_decision", f"### Aggressive Analyst Analysis\n{agg_hist}" - ) - if con_hist: - if message_buffer.agent_status.get("Conservative Analyst") != "completed": - message_buffer.update_agent_status("Conservative Analyst", "in_progress") - message_buffer.update_report_section( - "final_trade_decision", f"### Conservative Analyst Analysis\n{con_hist}" - ) - if neu_hist: - if message_buffer.agent_status.get("Neutral Analyst") != "completed": - message_buffer.update_agent_status("Neutral Analyst", "in_progress") - message_buffer.update_report_section( - "final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}" - ) - if judge: - if message_buffer.agent_status.get("Portfolio Manager") != "completed": - message_buffer.update_agent_status("Portfolio Manager", "in_progress") + # Only update status when there's actual content + if bull_hist or bear_hist: + update_research_team_status("in_progress") + if bull_hist: message_buffer.update_report_section( - "final_trade_decision", f"### Portfolio Manager Decision\n{judge}" + "investment_plan", f"### Bull Researcher Analysis\n{bull_hist}" ) - message_buffer.update_agent_status("Aggressive Analyst", "completed") - message_buffer.update_agent_status("Conservative Analyst", "completed") - message_buffer.update_agent_status("Neutral Analyst", "completed") - message_buffer.update_agent_status("Portfolio Manager", "completed") + if bear_hist: + message_buffer.update_report_section( + "investment_plan", f"### Bear Researcher Analysis\n{bear_hist}" + ) + if judge: + message_buffer.update_report_section( + "investment_plan", f"### Research Manager Decision\n{judge}" + ) + update_research_team_status("completed") + message_buffer.update_agent_status("Trader", "in_progress") - # Update the display - update_display(layout, stats_handler=stats_handler, start_time=start_time) + # Trading Team + if chunk.get("trader_investment_plan"): + message_buffer.update_report_section( + "trader_investment_plan", chunk["trader_investment_plan"] + ) + if message_buffer.agent_status.get("Trader") != "completed": + message_buffer.update_agent_status("Trader", "completed") + message_buffer.update_agent_status("Aggressive Analyst", "in_progress") - trace.append(chunk) + # Risk Management Team - Handle Risk Debate State + if chunk.get("risk_debate_state"): + risk_state = chunk["risk_debate_state"] + agg_hist = risk_state.get("aggressive_history", "").strip() + con_hist = risk_state.get("conservative_history", "").strip() + neu_hist = risk_state.get("neutral_history", "").strip() + judge = risk_state.get("judge_decision", "").strip() + + if agg_hist: + if message_buffer.agent_status.get("Aggressive Analyst") != "completed": + message_buffer.update_agent_status("Aggressive Analyst", "in_progress") + message_buffer.update_report_section( + "final_trade_decision", f"### Aggressive Analyst Analysis\n{agg_hist}" + ) + if con_hist: + if message_buffer.agent_status.get("Conservative Analyst") != "completed": + message_buffer.update_agent_status("Conservative Analyst", "in_progress") + message_buffer.update_report_section( + "final_trade_decision", f"### Conservative Analyst Analysis\n{con_hist}" + ) + if neu_hist: + if message_buffer.agent_status.get("Neutral Analyst") != "completed": + message_buffer.update_agent_status("Neutral Analyst", "in_progress") + message_buffer.update_report_section( + "final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}" + ) + if judge: + if message_buffer.agent_status.get("Portfolio Manager") != "completed": + message_buffer.update_agent_status("Portfolio Manager", "in_progress") + message_buffer.update_report_section( + "final_trade_decision", f"### Portfolio Manager Decision\n{judge}" + ) + message_buffer.update_agent_status("Aggressive Analyst", "completed") + message_buffer.update_agent_status("Conservative Analyst", "completed") + message_buffer.update_agent_status("Neutral Analyst", "completed") + message_buffer.update_agent_status("Portfolio Manager", "completed") + + # Update the display + update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker) + + trace.append(chunk) + except BudgetExceededError: + budget_aborted = True + message_buffer.add_message( + "System", + f"[red]Budget exceeded (${spend_tracker.total_cost_usd:.4f} / ${spend_tracker.max_cost:.2f}). Saving partial results.[/red]", + ) + update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker) + + # Log per-ticker and cumulative spend to stderr + spend_tracker.log_ticker_spend(selections["ticker"]) # Get final state and decision - final_state = trace[-1] - decision = graph.process_signal(final_state["final_trade_decision"]) + if trace: + final_state = trace[-1] + else: + final_state = dict(init_agent_state) + if budget_aborted: + final_state.setdefault("final_trade_decision", "BUDGET_EXCEEDED") + decision = graph.process_signal(final_state.get("final_trade_decision", "BUDGET_EXCEEDED")) # Update all agent statuses to completed for agent in message_buffer.agent_status: message_buffer.update_agent_status(agent, "completed") message_buffer.add_message( - "System", f"Completed analysis for {selections['analysis_date']}" + "System", + f"{'Budget exceeded — partial' if budget_aborted else 'Completed'} analysis for {selections['analysis_date']}" ) # Update final report sections @@ -1168,10 +1205,13 @@ def run_analysis(): if section in final_state: message_buffer.update_report_section(section, final_state[section]) - update_display(layout, stats_handler=stats_handler, start_time=start_time) + update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker) # Post-analysis prompts (outside Live context for clean interaction) - console.print("\n[bold cyan]Analysis Complete![/bold cyan]\n") + if budget_aborted: + console.print("\n[bold red]Analysis aborted — budget exceeded. Partial results saved.[/bold red]\n") + else: + console.print("\n[bold cyan]Analysis Complete![/bold cyan]\n") # Prompt to save report save_choice = typer.prompt("Save report?", default="Y").strip().upper() @@ -1197,8 +1237,12 @@ def run_analysis(): @app.command() -def analyze(): - run_analysis() +def analyze( + max_cost: Optional[float] = typer.Option( + None, "--max-cost", help="Maximum estimated USD spend for this run. Aborts when exceeded." + ), +): + run_analysis(max_cost=max_cost) if __name__ == "__main__": diff --git a/tests/test_spend_tracker.py b/tests/test_spend_tracker.py new file mode 100644 index 00000000..1235d160 --- /dev/null +++ b/tests/test_spend_tracker.py @@ -0,0 +1,237 @@ +"""Tests for spend tracking, budget abort, and partial result saving.""" + +import unittest +from unittest.mock import MagicMock +from uuid import uuid4 + +from langchain_core.outputs import LLMResult, Generation + +from tradingagents.spend_tracker import ( + AuditEntry, + BudgetExceededError, + SpendTracker, + TokenRecord, + _get_pricing, + MODEL_PRICING, + _DEFAULT_PRICING, +) + + +def _make_llm_result( + prompt_tokens: int = 100, + completion_tokens: int = 50, + model: str = "gpt-4o", +) -> LLMResult: + """Build a minimal LLMResult with token usage.""" + return LLMResult( + generations=[[Generation(text="ok")]], + llm_output={ + "token_usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "model_name": model, + }, + ) + + +class TestSpendTrackerAccumulation(unittest.TestCase): + """Token and cost accumulation.""" + + def test_single_call_updates_totals(self): + tracker = SpendTracker() + tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4()) + self.assertEqual(tracker.prompt_tokens, 100) + self.assertEqual(tracker.completion_tokens, 50) + self.assertEqual(tracker.total_tokens, 150) + self.assertGreater(tracker.total_cost_usd, 0) + + def test_multiple_calls_accumulate(self): + tracker = SpendTracker() + tracker.on_llm_end(_make_llm_result(100, 50), run_id=uuid4()) + tracker.on_llm_end(_make_llm_result(200, 100), run_id=uuid4()) + self.assertEqual(tracker.prompt_tokens, 300) + self.assertEqual(tracker.completion_tokens, 150) + self.assertEqual(len(tracker.records), 2) + + def test_cost_calculation_matches_pricing(self): + tracker = SpendTracker() + inp_price, out_price = MODEL_PRICING["gpt-4o"] + tracker.on_llm_end(_make_llm_result(1_000_000, 1_000_000, "gpt-4o"), run_id=uuid4()) + expected = inp_price + out_price # 1M tokens each + self.assertAlmostEqual(tracker.total_cost_usd, expected, places=2) + + def test_reset_clears_all(self): + tracker = SpendTracker(max_cost=10.0) + tracker.on_llm_end(_make_llm_result(100, 50), run_id=uuid4()) + tracker.budget_exceeded = True + tracker.reset() + self.assertEqual(tracker.total_tokens, 0) + self.assertEqual(tracker.total_cost_usd, 0.0) + self.assertFalse(tracker.budget_exceeded) + self.assertEqual(len(tracker.records), 0) + self.assertEqual(len(tracker.audit_trail), 0) + + +class TestBudgetAbort(unittest.TestCase): + """Budget exceeded detection and abort.""" + + def test_budget_exceeded_flag_set(self): + tracker = SpendTracker(max_cost=0.0001) + tracker.on_llm_end(_make_llm_result(1000, 500, "gpt-4o"), run_id=uuid4()) + self.assertTrue(tracker.budget_exceeded) + + def test_on_llm_start_raises_when_over_budget(self): + tracker = SpendTracker(max_cost=0.0001) + tracker.on_llm_end(_make_llm_result(1000, 500, "gpt-4o"), run_id=uuid4()) + self.assertTrue(tracker.budget_exceeded) + with self.assertRaises(BudgetExceededError): + tracker.on_llm_start({}, ["prompt"], run_id=uuid4()) + + def test_on_chat_model_start_raises_when_over_budget(self): + tracker = SpendTracker(max_cost=0.0001) + tracker.on_llm_end(_make_llm_result(1000, 500, "gpt-4o"), run_id=uuid4()) + with self.assertRaises(BudgetExceededError): + tracker.on_chat_model_start({}, [[]], run_id=uuid4()) + + def test_no_abort_when_under_budget(self): + tracker = SpendTracker(max_cost=999.0) + tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4()) + self.assertFalse(tracker.budget_exceeded) + # Should not raise + tracker.on_llm_start({}, ["prompt"], run_id=uuid4()) + + def test_no_budget_means_never_exceeded(self): + tracker = SpendTracker(max_cost=None) + tracker.on_llm_end(_make_llm_result(1_000_000, 1_000_000, "gpt-4o"), run_id=uuid4()) + self.assertFalse(tracker.budget_exceeded) + + +class TestPartialResults(unittest.TestCase): + """Verify partial results are preserved when budget is exceeded.""" + + def test_propagate_returns_partial_on_budget_exceeded(self): + """Simulate what TradingAgentsGraph.propagate() does on BudgetExceededError.""" + # This mirrors the logic in trading_graph.py propagate() method: + # on BudgetExceededError, it builds partial state from init_agent_state + init_state = { + "company_of_interest": "AAPL", + "trade_date": "2025-01-15", + "market_report": "partial market data", + "messages": [], + } + + # Simulate the except BudgetExceededError branch + final_state = dict(init_state) + final_state.setdefault("final_trade_decision", "BUDGET_EXCEEDED") + + self.assertEqual(final_state["final_trade_decision"], "BUDGET_EXCEEDED") + self.assertEqual(final_state["company_of_interest"], "AAPL") + self.assertEqual(final_state["market_report"], "partial market data") + + def test_budget_exceeded_error_message_contains_amounts(self): + tracker = SpendTracker(max_cost=0.01) + tracker.on_llm_end(_make_llm_result(10000, 5000, "gpt-4o"), run_id=uuid4()) + try: + tracker.on_llm_start({}, ["prompt"], run_id=uuid4()) + self.fail("Expected BudgetExceededError") + except BudgetExceededError as e: + msg = str(e) + self.assertIn("$0.01", msg) # budget + self.assertIn("spent", msg.lower()) + + +class TestTickerSpend(unittest.TestCase): + """Per-ticker spend tracking.""" + + def test_ticker_cost_recorded(self): + tracker = SpendTracker() + tracker.begin_ticker("AAPL") + tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4()) + tracker.log_ticker_spend("AAPL") + self.assertIn("AAPL", tracker.ticker_costs) + self.assertGreater(tracker.ticker_costs["AAPL"], 0) + + def test_multiple_tickers_tracked_separately(self): + tracker = SpendTracker() + tracker.begin_ticker("AAPL") + tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4()) + tracker.log_ticker_spend("AAPL") + + tracker.begin_ticker("MSFT") + tracker.on_llm_end(_make_llm_result(200, 100, "gpt-4o"), run_id=uuid4()) + tracker.log_ticker_spend("MSFT") + + self.assertIn("AAPL", tracker.ticker_costs) + self.assertIn("MSFT", tracker.ticker_costs) + self.assertGreater(tracker.ticker_costs["MSFT"], tracker.ticker_costs["AAPL"]) + + +class TestAuditTrail(unittest.TestCase): + """Delegation chain audit trail.""" + + def test_llm_call_creates_audit_entry(self): + tracker = SpendTracker() + tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4()) + self.assertEqual(len(tracker.audit_trail), 1) + entry = tracker.audit_trail[0] + self.assertEqual(entry.call_type, "llm") + self.assertEqual(entry.name, "gpt-4o") + self.assertGreater(entry.cost_usd, 0) + + def test_tool_call_creates_audit_entry(self): + tracker = SpendTracker() + run_id = uuid4() + tracker.on_tool_start({"name": "get_stock_data"}, "", run_id=run_id) + tracker.on_tool_end("result", run_id=run_id) + self.assertEqual(len(tracker.audit_trail), 1) + entry = tracker.audit_trail[0] + self.assertEqual(entry.call_type, "tool") + self.assertEqual(entry.name, "get_stock_data") + + def test_agent_name_resolved_from_chain(self): + tracker = SpendTracker() + parent_id = uuid4() + child_id = uuid4() + tracker.on_chain_start( + {"name": "market_analyst"}, {}, run_id=parent_id + ) + tracker.on_llm_end( + _make_llm_result(100, 50, "gpt-4o"), + run_id=child_id, + parent_run_id=parent_id, + ) + self.assertEqual(tracker.audit_trail[0].agent, "market_analyst") + + def test_format_audit_trail_output(self): + tracker = SpendTracker() + tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4()) + output = tracker.format_audit_trail() + self.assertIn("Delegation chain", output) + self.assertIn("LLM(gpt-4o)", output) + + def test_empty_audit_trail(self): + tracker = SpendTracker() + output = tracker.format_audit_trail() + self.assertIn("No calls recorded", output) + + +class TestPricing(unittest.TestCase): + """Model pricing lookup.""" + + def test_exact_match(self): + self.assertEqual(_get_pricing("gpt-4o"), MODEL_PRICING["gpt-4o"]) + + def test_prefix_match(self): + # "gpt-4o-2024-08-06" should match "gpt-4o" + result = _get_pricing("gpt-4o-2024-08-06") + self.assertEqual(result, MODEL_PRICING["gpt-4o"]) + + def test_unknown_model_returns_default(self): + result = _get_pricing("totally-unknown-model-xyz") + self.assertEqual(result, _DEFAULT_PRICING) + + +if __name__ == "__main__": + unittest.main() diff --git a/tradingagents/__init__.py b/tradingagents/__init__.py index 43a2b439..be1d2704 100644 --- a/tradingagents/__init__.py +++ b/tradingagents/__init__.py @@ -1,2 +1,6 @@ import os os.environ.setdefault("PYTHONUTF8", "1") + +from .spend_tracker import SpendTracker, BudgetExceededError, MODEL_PRICING, AuditEntry # noqa: E402 + +__all__ = ["SpendTracker", "BudgetExceededError", "MODEL_PRICING", "AuditEntry"] diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index a9b75e4b..2a6b898c 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -18,6 +18,8 @@ DEFAULT_CONFIG = { # Output language for analyst reports and final decision # Internal agent debate stays in English for reasoning quality "output_language": "English", + # Spend limit: maximum estimated USD cost per run (None = unlimited) + "max_cost": None, # Debate and discussion settings "max_debate_rounds": 1, "max_risk_discuss_rounds": 1, diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 78bc13e5..72e5a069 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -190,70 +190,94 @@ class TradingAgentsGraph: } def propagate(self, company_name, trade_date): - """Run the trading agents graph for a company on a specific date.""" + """Run the trading agents graph for a company on a specific date. + + Returns: + Tuple of (final_state, signal). If budget is exceeded mid-run, + returns partial state with ``final_trade_decision`` set to + ``"BUDGET_EXCEEDED"`` and signal ``"hold"``. + """ + from tradingagents.spend_tracker import BudgetExceededError, SpendTracker self.ticker = company_name + # Notify spend trackers of ticker start + for cb in self.callbacks: + if isinstance(cb, SpendTracker): + cb.begin_ticker(company_name) + # Initialize state init_agent_state = self.propagator.create_initial_state( company_name, trade_date ) args = self.propagator.get_graph_args() - if self.debug: - # Debug mode with tracing - trace = [] - for chunk in self.graph.stream(init_agent_state, **args): - if len(chunk["messages"]) == 0: - pass - else: - chunk["messages"][-1].pretty_print() - trace.append(chunk) + try: + if self.debug: + # Debug mode with tracing + trace = [] + for chunk in self.graph.stream(init_agent_state, **args): + if len(chunk["messages"]) == 0: + pass + else: + chunk["messages"][-1].pretty_print() + trace.append(chunk) - final_state = trace[-1] - else: - # Standard mode without tracing - final_state = self.graph.invoke(init_agent_state, **args) + final_state = trace[-1] if trace else dict(init_agent_state) + else: + # Standard mode without tracing + final_state = self.graph.invoke(init_agent_state, **args) + except BudgetExceededError: + # Graceful abort: build partial state from what we have + final_state = dict(init_agent_state) + if self.debug and trace: + final_state.update(trace[-1]) + final_state.setdefault("final_trade_decision", "BUDGET_EXCEEDED") # Store current state for reflection self.curr_state = final_state - # Log state + # Log state (partial or full) self._log_state(trade_date, final_state) - # Return decision and processed signal - return final_state, self.process_signal(final_state["final_trade_decision"]) + # Log per-ticker and cumulative spend to stderr + for cb in self.callbacks: + if isinstance(cb, SpendTracker): + cb.log_ticker_spend(company_name) + cb.log_audit_trail() + + decision_text = final_state.get("final_trade_decision", "BUDGET_EXCEEDED") + return final_state, self.process_signal(decision_text) def _log_state(self, trade_date, final_state): - """Log the final state to a JSON file.""" + """Log the final state to a JSON file. Tolerates partial state from budget abort.""" + invest_debate = final_state.get("investment_debate_state") or {} + risk_debate = final_state.get("risk_debate_state") or {} + self.log_states_dict[str(trade_date)] = { - "company_of_interest": final_state["company_of_interest"], - "trade_date": final_state["trade_date"], - "market_report": final_state["market_report"], - "sentiment_report": final_state["sentiment_report"], - "news_report": final_state["news_report"], - "fundamentals_report": final_state["fundamentals_report"], + "company_of_interest": final_state.get("company_of_interest", ""), + "trade_date": final_state.get("trade_date", ""), + "market_report": final_state.get("market_report", ""), + "sentiment_report": final_state.get("sentiment_report", ""), + "news_report": final_state.get("news_report", ""), + "fundamentals_report": final_state.get("fundamentals_report", ""), "investment_debate_state": { - "bull_history": final_state["investment_debate_state"]["bull_history"], - "bear_history": final_state["investment_debate_state"]["bear_history"], - "history": final_state["investment_debate_state"]["history"], - "current_response": final_state["investment_debate_state"][ - "current_response" - ], - "judge_decision": final_state["investment_debate_state"][ - "judge_decision" - ], + "bull_history": invest_debate.get("bull_history", ""), + "bear_history": invest_debate.get("bear_history", ""), + "history": invest_debate.get("history", ""), + "current_response": invest_debate.get("current_response", ""), + "judge_decision": invest_debate.get("judge_decision", ""), }, - "trader_investment_decision": final_state["trader_investment_plan"], + "trader_investment_decision": final_state.get("trader_investment_plan", ""), "risk_debate_state": { - "aggressive_history": final_state["risk_debate_state"]["aggressive_history"], - "conservative_history": final_state["risk_debate_state"]["conservative_history"], - "neutral_history": final_state["risk_debate_state"]["neutral_history"], - "history": final_state["risk_debate_state"]["history"], - "judge_decision": final_state["risk_debate_state"]["judge_decision"], + "aggressive_history": risk_debate.get("aggressive_history", ""), + "conservative_history": risk_debate.get("conservative_history", ""), + "neutral_history": risk_debate.get("neutral_history", ""), + "history": risk_debate.get("history", ""), + "judge_decision": risk_debate.get("judge_decision", ""), }, - "investment_plan": final_state["investment_plan"], - "final_trade_decision": final_state["final_trade_decision"], + "investment_plan": final_state.get("investment_plan", ""), + "final_trade_decision": final_state.get("final_trade_decision", ""), } # Save to file diff --git a/tradingagents/spend_tracker.py b/tradingagents/spend_tracker.py new file mode 100644 index 00000000..92fe5368 --- /dev/null +++ b/tradingagents/spend_tracker.py @@ -0,0 +1,350 @@ +"""LangChain callback handler that tracks cumulative token usage and cost.""" + +from __future__ import annotations + +import sys +import threading +import time +from dataclasses import dataclass, field +from typing import Any +from uuid import UUID + +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.outputs import LLMResult + + +class BudgetExceededError(Exception): + """Raised when the spend budget is exceeded mid-graph.""" + pass + + +@dataclass +class AuditEntry: + """Single entry in the delegation chain audit trail.""" + + agent: str + call_type: str # "llm" or "tool" + name: str # model name or tool name + cost_usd: float + prompt_tokens: int + completion_tokens: int + timestamp: float = field(default_factory=time.monotonic) + +# Pricing per 1M tokens: (input_usd, output_usd) +# Source: provider pricing pages as of 2025-Q2. Add new models as needed. +MODEL_PRICING: dict[str, tuple[float, float]] = { + # OpenAI + "gpt-4o": (2.50, 10.00), + "gpt-4o-mini": (0.15, 0.60), + "gpt-4-turbo": (10.00, 30.00), + "gpt-4": (30.00, 60.00), + "gpt-3.5-turbo": (0.50, 1.50), + "o1": (15.00, 60.00), + "o1-mini": (3.00, 12.00), + # Anthropic + "claude-3-5-sonnet-20241022": (3.00, 15.00), + "claude-3-5-haiku-20241022": (0.80, 4.00), + "claude-3-opus-20240229": (15.00, 75.00), + # Google + "gemini-2.0-flash": (0.10, 0.40), + "gemini-1.5-pro": (1.25, 5.00), + "gemini-1.5-flash": (0.075, 0.30), +} + +# Fallback: generous estimate so unknown models still get a cost ceiling +_DEFAULT_PRICING: tuple[float, float] = (10.00, 30.00) + + +def _get_pricing(model: str) -> tuple[float, float]: + """Return (input_per_1M, output_per_1M) for *model*, with prefix matching.""" + if model in MODEL_PRICING: + return MODEL_PRICING[model] + # Try prefix match (e.g. "gpt-4o-2024-08-06" → "gpt-4o") + for key in sorted(MODEL_PRICING, key=len, reverse=True): + if model.startswith(key): + return MODEL_PRICING[key] + return _DEFAULT_PRICING + + +@dataclass +class TokenRecord: + """Single LLM call token record.""" + + model: str + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class SpendTracker(BaseCallbackHandler): + """Tracks cumulative token usage and estimated USD cost. + + Thread-safe: LangGraph may invoke nodes concurrently. + + Args: + max_cost: Optional USD budget. When exceeded, ``budget_exceeded`` + becomes ``True``. Callers should check this flag and abort. + + Usage:: + + tracker = SpendTracker(max_cost=0.50) + graph = TradingAgentsGraph(callbacks=[tracker]) + # ... run graph ... + print(tracker.total_cost_usd) + print(tracker.budget_exceeded) + """ + + def __init__(self, max_cost: float | None = None) -> None: + super().__init__() + self._lock = threading.Lock() + self.prompt_tokens: int = 0 + self.completion_tokens: int = 0 + self.total_tokens: int = 0 + self.total_cost_usd: float = 0.0 + self.max_cost: float | None = max_cost + self.budget_exceeded: bool = False + self.records: list[TokenRecord] = [] + # Per-ticker spend tracking + self.ticker_costs: dict[str, float] = {} + self._ticker_snapshot: float = 0.0 # cumulative cost at last ticker start + # Delegation chain audit trail + self.audit_trail: list[AuditEntry] = [] + self._run_names: dict[UUID, str] = {} # run_id → chain/agent name + self._run_parents: dict[UUID, UUID] = {} # run_id → parent_run_id + self._tool_starts: dict[UUID, str] = {} # run_id → tool name + + # -- LangChain callback hooks ------------------------------------------ + + def on_chain_start( + self, + serialized: dict[str, Any], + inputs: dict[str, Any], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + **kwargs: Any, + ) -> None: + """Track chain/agent names for the delegation audit trail.""" + name = serialized.get("name") or serialized.get("id", [""])[-1] or "" + with self._lock: + if name: + self._run_names[run_id] = name + if parent_run_id is not None: + self._run_parents[run_id] = parent_run_id + + def on_llm_start( + self, + serialized: dict[str, Any], + prompts: list[str], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + """Abort before the next LLM call if budget already exceeded.""" + with self._lock: + if parent_run_id is not None: + self._run_parents[run_id] = parent_run_id + if self.budget_exceeded: + raise BudgetExceededError( + f"Budget ${self.max_cost:.2f} exceeded (spent ${self.total_cost_usd:.4f})" + ) + + def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[Any]], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + """Abort before the next chat model call if budget already exceeded.""" + with self._lock: + if parent_run_id is not None: + self._run_parents[run_id] = parent_run_id + if self.budget_exceeded: + raise BudgetExceededError( + f"Budget ${self.max_cost:.2f} exceeded (spent ${self.total_cost_usd:.4f})" + ) + + def on_tool_start( + self, + serialized: dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + """Record tool invocation start for audit trail.""" + tool_name = serialized.get("name") or "unknown_tool" + with self._lock: + self._tool_starts[run_id] = tool_name + if parent_run_id is not None: + self._run_parents[run_id] = parent_run_id + + def on_tool_end( + self, + output: Any, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + """Record completed tool call in audit trail.""" + with self._lock: + tool_name = self._tool_starts.pop(run_id, "unknown_tool") + agent = self._resolve_agent(parent_run_id or run_id) + self.audit_trail.append( + AuditEntry( + agent=agent, + call_type="tool", + name=tool_name, + cost_usd=0.0, + prompt_tokens=0, + completion_tokens=0, + ) + ) + + def on_llm_end( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + usage = self._extract_usage(response) + if usage is None: + return + prompt, completion, total = usage + model = self._extract_model(response) + inp_price, out_price = _get_pricing(model) + call_cost = (prompt * inp_price + completion * out_price) / 1_000_000 + + with self._lock: + self.prompt_tokens += prompt + self.completion_tokens += completion + self.total_tokens += total + self.total_cost_usd += call_cost + self.records.append( + TokenRecord( + model=model, + prompt_tokens=prompt, + completion_tokens=completion, + total_tokens=total, + ) + ) + # Audit trail: resolve which agent triggered this LLM call + agent = self._resolve_agent(parent_run_id or run_id) + self.audit_trail.append( + AuditEntry( + agent=agent, + call_type="llm", + name=model, + cost_usd=call_cost, + prompt_tokens=prompt, + completion_tokens=completion, + ) + ) + if self.max_cost is not None and self.total_cost_usd >= self.max_cost: + self.budget_exceeded = True + + # -- helpers ----------------------------------------------------------- + + def _resolve_agent(self, run_id: UUID | None) -> str: + """Walk the parent chain to find the nearest named agent/chain. Must hold _lock.""" + visited: set[UUID] = set() + current = run_id + while current and current not in visited: + visited.add(current) + name = self._run_names.get(current) + if name: + return name + current = self._run_parents.get(current) + return "unknown" + + @staticmethod + def _extract_usage(response: LLMResult) -> tuple[int, int, int] | None: + """Pull token counts from LLMResult.llm_output or generation metadata.""" + # Most providers put token_usage in llm_output + llm_out = response.llm_output or {} + usage = llm_out.get("token_usage") or llm_out.get("usage") or {} + if usage: + prompt = usage.get("prompt_tokens", 0) or 0 + completion = usage.get("completion_tokens", 0) or 0 + total = usage.get("total_tokens", 0) or (prompt + completion) + return prompt, completion, total + + # Fallback: check generation-level response_metadata (langchain-openai ≥0.3) + for gen_list in response.generations: + for gen in gen_list: + meta = getattr(gen, "generation_info", None) or {} + usage = meta.get("usage") or meta.get("token_usage") or {} + if usage: + prompt = usage.get("prompt_tokens", 0) or 0 + completion = usage.get("completion_tokens", 0) or 0 + total = usage.get("total_tokens", 0) or (prompt + completion) + return prompt, completion, total + + return None + + @staticmethod + def _extract_model(response: LLMResult) -> str: + llm_out = response.llm_output or {} + return llm_out.get("model_name", "") or llm_out.get("model", "unknown") + + def begin_ticker(self, ticker: str) -> None: + """Mark the start of a new ticker analysis. Call before propagate().""" + with self._lock: + self._ticker_snapshot = self.total_cost_usd + + def log_ticker_spend(self, ticker: str) -> None: + """Log per-ticker and cumulative spend to stderr. Call after propagate().""" + print(self.format_ticker_spend(ticker), file=sys.stderr) + + def format_ticker_spend(self, ticker: str) -> str: + """Return a human-readable ticker spend string.""" + with self._lock: + ticker_cost = self.total_cost_usd - self._ticker_snapshot + self.ticker_costs[ticker] = self.ticker_costs.get(ticker, 0.0) + ticker_cost + budget_str = f" / budget ${self.max_cost:.2f}" if self.max_cost is not None else "" + exceeded = " [OVER BUDGET]" if self.budget_exceeded else "" + return f"[spend] {ticker}: ${ticker_cost:.4f} | cumulative: ${self.total_cost_usd:.4f}{budget_str}{exceeded}" + + def reset(self) -> None: + """Reset all counters.""" + with self._lock: + self.prompt_tokens = 0 + self.completion_tokens = 0 + self.total_tokens = 0 + self.total_cost_usd = 0.0 + self.budget_exceeded = False + self.records.clear() + self.audit_trail.clear() + self._run_names.clear() + self._run_parents.clear() + self._tool_starts.clear() + + def format_audit_trail(self) -> str: + """Return a human-readable audit trail string.""" + with self._lock: + if not self.audit_trail: + return "[audit] No calls recorded." + lines = ["[audit] Delegation chain:"] + for i, e in enumerate(self.audit_trail, 1): + if e.call_type == "llm": + lines.append( + f" {i}. {e.agent} → LLM({e.name}) " + f"tokens={e.prompt_tokens}+{e.completion_tokens} " + f"cost=${e.cost_usd:.4f}" + ) + else: + lines.append(f" {i}. {e.agent} → tool({e.name})") + return "\n".join(lines) + + def log_audit_trail(self) -> None: + """Print the audit trail to stderr.""" + print(self.format_audit_trail(), file=sys.stderr)