feat: spend limits with automatic abort (#524)
Add --max-cost flag (USD) to cap LLM spend per analysis run. Review feedback applied: - Guard trace[-1] IndexError when stream yields no chunks - Remove speculative model entries from pricing table - Move budget checks inside lock for thread safety - Split log_ticker_spend into format + log for modularity Closes #524
This commit is contained in:
parent
fa4d01c23a
commit
872ec8d119
252
cli/main.py
252
cli/main.py
|
|
@ -253,7 +253,7 @@ def format_tokens(n):
|
|||
return str(n)
|
||||
|
||||
|
||||
def update_display(layout, spinner_text=None, stats_handler=None, start_time=None):
|
||||
def update_display(layout, spinner_text=None, stats_handler=None, start_time=None, spend_tracker=None):
|
||||
# Header with welcome message
|
||||
layout["header"].update(
|
||||
Panel(
|
||||
|
|
@ -445,6 +445,15 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
|
|||
tokens_str = "Tokens: --"
|
||||
stats_parts.append(tokens_str)
|
||||
|
||||
# Cost display from spend tracker
|
||||
if spend_tracker:
|
||||
cost_str = f"${spend_tracker.total_cost_usd:.4f}"
|
||||
if spend_tracker.max_cost is not None:
|
||||
cost_str += f"/${spend_tracker.max_cost:.2f}"
|
||||
if spend_tracker.budget_exceeded:
|
||||
cost_str = f"[red]{cost_str} OVER BUDGET[/red]"
|
||||
stats_parts.append(f"Cost: {cost_str}")
|
||||
|
||||
stats_parts.append(f"Reports: {reports_completed}/{reports_total}")
|
||||
|
||||
# Elapsed time
|
||||
|
|
@ -926,7 +935,7 @@ def format_tool_args(args, max_length=80) -> str:
|
|||
return result[:max_length - 3] + "..."
|
||||
return result
|
||||
|
||||
def run_analysis():
|
||||
def run_analysis(max_cost: float | None = None):
|
||||
# First get all user selections
|
||||
selections = get_user_selections()
|
||||
|
||||
|
|
@ -943,10 +952,17 @@ def run_analysis():
|
|||
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
||||
config["anthropic_effort"] = selections.get("anthropic_effort")
|
||||
config["output_language"] = selections.get("output_language", "English")
|
||||
# Spend limit (CLI flag → config → DEFAULT_CONFIG)
|
||||
if max_cost is not None:
|
||||
config["max_cost"] = max_cost
|
||||
|
||||
# Create stats callback handler for tracking LLM/tool calls
|
||||
stats_handler = StatsCallbackHandler()
|
||||
|
||||
# Create spend tracker with optional budget limit
|
||||
from tradingagents import SpendTracker, BudgetExceededError
|
||||
spend_tracker = SpendTracker(max_cost=config.get("max_cost"))
|
||||
|
||||
# Normalize analyst selection to predefined order (selection is a 'set', order is fixed)
|
||||
selected_set = {analyst.value for analyst in selections["analysts"]}
|
||||
selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set]
|
||||
|
|
@ -956,7 +972,7 @@ def run_analysis():
|
|||
selected_analyst_keys,
|
||||
config=config,
|
||||
debug=True,
|
||||
callbacks=[stats_handler],
|
||||
callbacks=[stats_handler, spend_tracker],
|
||||
)
|
||||
|
||||
# Initialize message buffer with selected analysts
|
||||
|
|
@ -1018,7 +1034,7 @@ def run_analysis():
|
|||
|
||||
with Live(layout, refresh_per_second=4) as live:
|
||||
# Initial display
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker)
|
||||
|
||||
# Add initial messages
|
||||
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
|
||||
|
|
@ -1029,18 +1045,18 @@ def run_analysis():
|
|||
"System",
|
||||
f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}",
|
||||
)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker)
|
||||
|
||||
# Update agent status to in_progress for the first analyst
|
||||
first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst"
|
||||
message_buffer.update_agent_status(first_analyst, "in_progress")
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker)
|
||||
|
||||
# Create spinner text
|
||||
spinner_text = (
|
||||
f"Analyzing {selections['ticker']} on {selections['analysis_date']}..."
|
||||
)
|
||||
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time)
|
||||
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker)
|
||||
|
||||
# Initialize state and get graph args with callbacks
|
||||
init_agent_state = graph.propagator.create_initial_state(
|
||||
|
|
@ -1048,119 +1064,140 @@ def run_analysis():
|
|||
)
|
||||
# Pass callbacks to graph config for tool execution tracking
|
||||
# (LLM tracking is handled separately via LLM constructor)
|
||||
args = graph.propagator.get_graph_args(callbacks=[stats_handler])
|
||||
args = graph.propagator.get_graph_args(callbacks=[stats_handler, spend_tracker])
|
||||
|
||||
# Notify spend tracker of ticker start
|
||||
spend_tracker.begin_ticker(selections["ticker"])
|
||||
|
||||
# Stream the analysis
|
||||
trace = []
|
||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
||||
# Process all messages in chunk, deduplicating by message ID
|
||||
for message in chunk.get("messages", []):
|
||||
msg_id = getattr(message, "id", None)
|
||||
if msg_id is not None:
|
||||
if msg_id in message_buffer._processed_message_ids:
|
||||
continue
|
||||
message_buffer._processed_message_ids.add(msg_id)
|
||||
budget_aborted = False
|
||||
try:
|
||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
||||
# Process all messages in chunk, deduplicating by message ID
|
||||
for message in chunk.get("messages", []):
|
||||
msg_id = getattr(message, "id", None)
|
||||
if msg_id is not None:
|
||||
if msg_id in message_buffer._processed_message_ids:
|
||||
continue
|
||||
message_buffer._processed_message_ids.add(msg_id)
|
||||
|
||||
msg_type, content = classify_message_type(message)
|
||||
if content and content.strip():
|
||||
message_buffer.add_message(msg_type, content)
|
||||
msg_type, content = classify_message_type(message)
|
||||
if content and content.strip():
|
||||
message_buffer.add_message(msg_type, content)
|
||||
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
|
||||
# Update analyst statuses based on report state (runs on every chunk)
|
||||
update_analyst_statuses(message_buffer, chunk)
|
||||
# Update analyst statuses based on report state (runs on every chunk)
|
||||
update_analyst_statuses(message_buffer, chunk)
|
||||
|
||||
# Research Team - Handle Investment Debate State
|
||||
if chunk.get("investment_debate_state"):
|
||||
debate_state = chunk["investment_debate_state"]
|
||||
bull_hist = debate_state.get("bull_history", "").strip()
|
||||
bear_hist = debate_state.get("bear_history", "").strip()
|
||||
judge = debate_state.get("judge_decision", "").strip()
|
||||
# Research Team - Handle Investment Debate State
|
||||
if chunk.get("investment_debate_state"):
|
||||
debate_state = chunk["investment_debate_state"]
|
||||
bull_hist = debate_state.get("bull_history", "").strip()
|
||||
bear_hist = debate_state.get("bear_history", "").strip()
|
||||
judge = debate_state.get("judge_decision", "").strip()
|
||||
|
||||
# Only update status when there's actual content
|
||||
if bull_hist or bear_hist:
|
||||
update_research_team_status("in_progress")
|
||||
if bull_hist:
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan", f"### Bull Researcher Analysis\n{bull_hist}"
|
||||
)
|
||||
if bear_hist:
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan", f"### Bear Researcher Analysis\n{bear_hist}"
|
||||
)
|
||||
if judge:
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan", f"### Research Manager Decision\n{judge}"
|
||||
)
|
||||
update_research_team_status("completed")
|
||||
message_buffer.update_agent_status("Trader", "in_progress")
|
||||
|
||||
# Trading Team
|
||||
if chunk.get("trader_investment_plan"):
|
||||
message_buffer.update_report_section(
|
||||
"trader_investment_plan", chunk["trader_investment_plan"]
|
||||
)
|
||||
if message_buffer.agent_status.get("Trader") != "completed":
|
||||
message_buffer.update_agent_status("Trader", "completed")
|
||||
message_buffer.update_agent_status("Aggressive Analyst", "in_progress")
|
||||
|
||||
# Risk Management Team - Handle Risk Debate State
|
||||
if chunk.get("risk_debate_state"):
|
||||
risk_state = chunk["risk_debate_state"]
|
||||
agg_hist = risk_state.get("aggressive_history", "").strip()
|
||||
con_hist = risk_state.get("conservative_history", "").strip()
|
||||
neu_hist = risk_state.get("neutral_history", "").strip()
|
||||
judge = risk_state.get("judge_decision", "").strip()
|
||||
|
||||
if agg_hist:
|
||||
if message_buffer.agent_status.get("Aggressive Analyst") != "completed":
|
||||
message_buffer.update_agent_status("Aggressive Analyst", "in_progress")
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Aggressive Analyst Analysis\n{agg_hist}"
|
||||
)
|
||||
if con_hist:
|
||||
if message_buffer.agent_status.get("Conservative Analyst") != "completed":
|
||||
message_buffer.update_agent_status("Conservative Analyst", "in_progress")
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Conservative Analyst Analysis\n{con_hist}"
|
||||
)
|
||||
if neu_hist:
|
||||
if message_buffer.agent_status.get("Neutral Analyst") != "completed":
|
||||
message_buffer.update_agent_status("Neutral Analyst", "in_progress")
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}"
|
||||
)
|
||||
if judge:
|
||||
if message_buffer.agent_status.get("Portfolio Manager") != "completed":
|
||||
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
|
||||
# Only update status when there's actual content
|
||||
if bull_hist or bear_hist:
|
||||
update_research_team_status("in_progress")
|
||||
if bull_hist:
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Portfolio Manager Decision\n{judge}"
|
||||
"investment_plan", f"### Bull Researcher Analysis\n{bull_hist}"
|
||||
)
|
||||
message_buffer.update_agent_status("Aggressive Analyst", "completed")
|
||||
message_buffer.update_agent_status("Conservative Analyst", "completed")
|
||||
message_buffer.update_agent_status("Neutral Analyst", "completed")
|
||||
message_buffer.update_agent_status("Portfolio Manager", "completed")
|
||||
if bear_hist:
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan", f"### Bear Researcher Analysis\n{bear_hist}"
|
||||
)
|
||||
if judge:
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan", f"### Research Manager Decision\n{judge}"
|
||||
)
|
||||
update_research_team_status("completed")
|
||||
message_buffer.update_agent_status("Trader", "in_progress")
|
||||
|
||||
# Update the display
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
# Trading Team
|
||||
if chunk.get("trader_investment_plan"):
|
||||
message_buffer.update_report_section(
|
||||
"trader_investment_plan", chunk["trader_investment_plan"]
|
||||
)
|
||||
if message_buffer.agent_status.get("Trader") != "completed":
|
||||
message_buffer.update_agent_status("Trader", "completed")
|
||||
message_buffer.update_agent_status("Aggressive Analyst", "in_progress")
|
||||
|
||||
trace.append(chunk)
|
||||
# Risk Management Team - Handle Risk Debate State
|
||||
if chunk.get("risk_debate_state"):
|
||||
risk_state = chunk["risk_debate_state"]
|
||||
agg_hist = risk_state.get("aggressive_history", "").strip()
|
||||
con_hist = risk_state.get("conservative_history", "").strip()
|
||||
neu_hist = risk_state.get("neutral_history", "").strip()
|
||||
judge = risk_state.get("judge_decision", "").strip()
|
||||
|
||||
if agg_hist:
|
||||
if message_buffer.agent_status.get("Aggressive Analyst") != "completed":
|
||||
message_buffer.update_agent_status("Aggressive Analyst", "in_progress")
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Aggressive Analyst Analysis\n{agg_hist}"
|
||||
)
|
||||
if con_hist:
|
||||
if message_buffer.agent_status.get("Conservative Analyst") != "completed":
|
||||
message_buffer.update_agent_status("Conservative Analyst", "in_progress")
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Conservative Analyst Analysis\n{con_hist}"
|
||||
)
|
||||
if neu_hist:
|
||||
if message_buffer.agent_status.get("Neutral Analyst") != "completed":
|
||||
message_buffer.update_agent_status("Neutral Analyst", "in_progress")
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Neutral Analyst Analysis\n{neu_hist}"
|
||||
)
|
||||
if judge:
|
||||
if message_buffer.agent_status.get("Portfolio Manager") != "completed":
|
||||
message_buffer.update_agent_status("Portfolio Manager", "in_progress")
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision", f"### Portfolio Manager Decision\n{judge}"
|
||||
)
|
||||
message_buffer.update_agent_status("Aggressive Analyst", "completed")
|
||||
message_buffer.update_agent_status("Conservative Analyst", "completed")
|
||||
message_buffer.update_agent_status("Neutral Analyst", "completed")
|
||||
message_buffer.update_agent_status("Portfolio Manager", "completed")
|
||||
|
||||
# Update the display
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker)
|
||||
|
||||
trace.append(chunk)
|
||||
except BudgetExceededError:
|
||||
budget_aborted = True
|
||||
message_buffer.add_message(
|
||||
"System",
|
||||
f"[red]Budget exceeded (${spend_tracker.total_cost_usd:.4f} / ${spend_tracker.max_cost:.2f}). Saving partial results.[/red]",
|
||||
)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker)
|
||||
|
||||
# Log per-ticker and cumulative spend to stderr
|
||||
spend_tracker.log_ticker_spend(selections["ticker"])
|
||||
|
||||
# Get final state and decision
|
||||
final_state = trace[-1]
|
||||
decision = graph.process_signal(final_state["final_trade_decision"])
|
||||
if trace:
|
||||
final_state = trace[-1]
|
||||
else:
|
||||
final_state = dict(init_agent_state)
|
||||
if budget_aborted:
|
||||
final_state.setdefault("final_trade_decision", "BUDGET_EXCEEDED")
|
||||
decision = graph.process_signal(final_state.get("final_trade_decision", "BUDGET_EXCEEDED"))
|
||||
|
||||
# Update all agent statuses to completed
|
||||
for agent in message_buffer.agent_status:
|
||||
message_buffer.update_agent_status(agent, "completed")
|
||||
|
||||
message_buffer.add_message(
|
||||
"System", f"Completed analysis for {selections['analysis_date']}"
|
||||
"System",
|
||||
f"{'Budget exceeded — partial' if budget_aborted else 'Completed'} analysis for {selections['analysis_date']}"
|
||||
)
|
||||
|
||||
# Update final report sections
|
||||
|
|
@ -1168,10 +1205,13 @@ def run_analysis():
|
|||
if section in final_state:
|
||||
message_buffer.update_report_section(section, final_state[section])
|
||||
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time, spend_tracker=spend_tracker)
|
||||
|
||||
# Post-analysis prompts (outside Live context for clean interaction)
|
||||
console.print("\n[bold cyan]Analysis Complete![/bold cyan]\n")
|
||||
if budget_aborted:
|
||||
console.print("\n[bold red]Analysis aborted — budget exceeded. Partial results saved.[/bold red]\n")
|
||||
else:
|
||||
console.print("\n[bold cyan]Analysis Complete![/bold cyan]\n")
|
||||
|
||||
# Prompt to save report
|
||||
save_choice = typer.prompt("Save report?", default="Y").strip().upper()
|
||||
|
|
@ -1197,8 +1237,12 @@ def run_analysis():
|
|||
|
||||
|
||||
@app.command()
|
||||
def analyze():
|
||||
run_analysis()
|
||||
def analyze(
|
||||
max_cost: Optional[float] = typer.Option(
|
||||
None, "--max-cost", help="Maximum estimated USD spend for this run. Aborts when exceeded."
|
||||
),
|
||||
):
|
||||
run_analysis(max_cost=max_cost)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -0,0 +1,237 @@
|
|||
"""Tests for spend tracking, budget abort, and partial result saving."""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.outputs import LLMResult, Generation
|
||||
|
||||
from tradingagents.spend_tracker import (
|
||||
AuditEntry,
|
||||
BudgetExceededError,
|
||||
SpendTracker,
|
||||
TokenRecord,
|
||||
_get_pricing,
|
||||
MODEL_PRICING,
|
||||
_DEFAULT_PRICING,
|
||||
)
|
||||
|
||||
|
||||
def _make_llm_result(
|
||||
prompt_tokens: int = 100,
|
||||
completion_tokens: int = 50,
|
||||
model: str = "gpt-4o",
|
||||
) -> LLMResult:
|
||||
"""Build a minimal LLMResult with token usage."""
|
||||
return LLMResult(
|
||||
generations=[[Generation(text="ok")]],
|
||||
llm_output={
|
||||
"token_usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
"model_name": model,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TestSpendTrackerAccumulation(unittest.TestCase):
|
||||
"""Token and cost accumulation."""
|
||||
|
||||
def test_single_call_updates_totals(self):
|
||||
tracker = SpendTracker()
|
||||
tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4())
|
||||
self.assertEqual(tracker.prompt_tokens, 100)
|
||||
self.assertEqual(tracker.completion_tokens, 50)
|
||||
self.assertEqual(tracker.total_tokens, 150)
|
||||
self.assertGreater(tracker.total_cost_usd, 0)
|
||||
|
||||
def test_multiple_calls_accumulate(self):
|
||||
tracker = SpendTracker()
|
||||
tracker.on_llm_end(_make_llm_result(100, 50), run_id=uuid4())
|
||||
tracker.on_llm_end(_make_llm_result(200, 100), run_id=uuid4())
|
||||
self.assertEqual(tracker.prompt_tokens, 300)
|
||||
self.assertEqual(tracker.completion_tokens, 150)
|
||||
self.assertEqual(len(tracker.records), 2)
|
||||
|
||||
def test_cost_calculation_matches_pricing(self):
|
||||
tracker = SpendTracker()
|
||||
inp_price, out_price = MODEL_PRICING["gpt-4o"]
|
||||
tracker.on_llm_end(_make_llm_result(1_000_000, 1_000_000, "gpt-4o"), run_id=uuid4())
|
||||
expected = inp_price + out_price # 1M tokens each
|
||||
self.assertAlmostEqual(tracker.total_cost_usd, expected, places=2)
|
||||
|
||||
def test_reset_clears_all(self):
|
||||
tracker = SpendTracker(max_cost=10.0)
|
||||
tracker.on_llm_end(_make_llm_result(100, 50), run_id=uuid4())
|
||||
tracker.budget_exceeded = True
|
||||
tracker.reset()
|
||||
self.assertEqual(tracker.total_tokens, 0)
|
||||
self.assertEqual(tracker.total_cost_usd, 0.0)
|
||||
self.assertFalse(tracker.budget_exceeded)
|
||||
self.assertEqual(len(tracker.records), 0)
|
||||
self.assertEqual(len(tracker.audit_trail), 0)
|
||||
|
||||
|
||||
class TestBudgetAbort(unittest.TestCase):
|
||||
"""Budget exceeded detection and abort."""
|
||||
|
||||
def test_budget_exceeded_flag_set(self):
|
||||
tracker = SpendTracker(max_cost=0.0001)
|
||||
tracker.on_llm_end(_make_llm_result(1000, 500, "gpt-4o"), run_id=uuid4())
|
||||
self.assertTrue(tracker.budget_exceeded)
|
||||
|
||||
def test_on_llm_start_raises_when_over_budget(self):
|
||||
tracker = SpendTracker(max_cost=0.0001)
|
||||
tracker.on_llm_end(_make_llm_result(1000, 500, "gpt-4o"), run_id=uuid4())
|
||||
self.assertTrue(tracker.budget_exceeded)
|
||||
with self.assertRaises(BudgetExceededError):
|
||||
tracker.on_llm_start({}, ["prompt"], run_id=uuid4())
|
||||
|
||||
def test_on_chat_model_start_raises_when_over_budget(self):
|
||||
tracker = SpendTracker(max_cost=0.0001)
|
||||
tracker.on_llm_end(_make_llm_result(1000, 500, "gpt-4o"), run_id=uuid4())
|
||||
with self.assertRaises(BudgetExceededError):
|
||||
tracker.on_chat_model_start({}, [[]], run_id=uuid4())
|
||||
|
||||
def test_no_abort_when_under_budget(self):
|
||||
tracker = SpendTracker(max_cost=999.0)
|
||||
tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4())
|
||||
self.assertFalse(tracker.budget_exceeded)
|
||||
# Should not raise
|
||||
tracker.on_llm_start({}, ["prompt"], run_id=uuid4())
|
||||
|
||||
def test_no_budget_means_never_exceeded(self):
|
||||
tracker = SpendTracker(max_cost=None)
|
||||
tracker.on_llm_end(_make_llm_result(1_000_000, 1_000_000, "gpt-4o"), run_id=uuid4())
|
||||
self.assertFalse(tracker.budget_exceeded)
|
||||
|
||||
|
||||
class TestPartialResults(unittest.TestCase):
|
||||
"""Verify partial results are preserved when budget is exceeded."""
|
||||
|
||||
def test_propagate_returns_partial_on_budget_exceeded(self):
|
||||
"""Simulate what TradingAgentsGraph.propagate() does on BudgetExceededError."""
|
||||
# This mirrors the logic in trading_graph.py propagate() method:
|
||||
# on BudgetExceededError, it builds partial state from init_agent_state
|
||||
init_state = {
|
||||
"company_of_interest": "AAPL",
|
||||
"trade_date": "2025-01-15",
|
||||
"market_report": "partial market data",
|
||||
"messages": [],
|
||||
}
|
||||
|
||||
# Simulate the except BudgetExceededError branch
|
||||
final_state = dict(init_state)
|
||||
final_state.setdefault("final_trade_decision", "BUDGET_EXCEEDED")
|
||||
|
||||
self.assertEqual(final_state["final_trade_decision"], "BUDGET_EXCEEDED")
|
||||
self.assertEqual(final_state["company_of_interest"], "AAPL")
|
||||
self.assertEqual(final_state["market_report"], "partial market data")
|
||||
|
||||
def test_budget_exceeded_error_message_contains_amounts(self):
|
||||
tracker = SpendTracker(max_cost=0.01)
|
||||
tracker.on_llm_end(_make_llm_result(10000, 5000, "gpt-4o"), run_id=uuid4())
|
||||
try:
|
||||
tracker.on_llm_start({}, ["prompt"], run_id=uuid4())
|
||||
self.fail("Expected BudgetExceededError")
|
||||
except BudgetExceededError as e:
|
||||
msg = str(e)
|
||||
self.assertIn("$0.01", msg) # budget
|
||||
self.assertIn("spent", msg.lower())
|
||||
|
||||
|
||||
class TestTickerSpend(unittest.TestCase):
|
||||
"""Per-ticker spend tracking."""
|
||||
|
||||
def test_ticker_cost_recorded(self):
|
||||
tracker = SpendTracker()
|
||||
tracker.begin_ticker("AAPL")
|
||||
tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4())
|
||||
tracker.log_ticker_spend("AAPL")
|
||||
self.assertIn("AAPL", tracker.ticker_costs)
|
||||
self.assertGreater(tracker.ticker_costs["AAPL"], 0)
|
||||
|
||||
def test_multiple_tickers_tracked_separately(self):
|
||||
tracker = SpendTracker()
|
||||
tracker.begin_ticker("AAPL")
|
||||
tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4())
|
||||
tracker.log_ticker_spend("AAPL")
|
||||
|
||||
tracker.begin_ticker("MSFT")
|
||||
tracker.on_llm_end(_make_llm_result(200, 100, "gpt-4o"), run_id=uuid4())
|
||||
tracker.log_ticker_spend("MSFT")
|
||||
|
||||
self.assertIn("AAPL", tracker.ticker_costs)
|
||||
self.assertIn("MSFT", tracker.ticker_costs)
|
||||
self.assertGreater(tracker.ticker_costs["MSFT"], tracker.ticker_costs["AAPL"])
|
||||
|
||||
|
||||
class TestAuditTrail(unittest.TestCase):
|
||||
"""Delegation chain audit trail."""
|
||||
|
||||
def test_llm_call_creates_audit_entry(self):
|
||||
tracker = SpendTracker()
|
||||
tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4())
|
||||
self.assertEqual(len(tracker.audit_trail), 1)
|
||||
entry = tracker.audit_trail[0]
|
||||
self.assertEqual(entry.call_type, "llm")
|
||||
self.assertEqual(entry.name, "gpt-4o")
|
||||
self.assertGreater(entry.cost_usd, 0)
|
||||
|
||||
def test_tool_call_creates_audit_entry(self):
|
||||
tracker = SpendTracker()
|
||||
run_id = uuid4()
|
||||
tracker.on_tool_start({"name": "get_stock_data"}, "", run_id=run_id)
|
||||
tracker.on_tool_end("result", run_id=run_id)
|
||||
self.assertEqual(len(tracker.audit_trail), 1)
|
||||
entry = tracker.audit_trail[0]
|
||||
self.assertEqual(entry.call_type, "tool")
|
||||
self.assertEqual(entry.name, "get_stock_data")
|
||||
|
||||
def test_agent_name_resolved_from_chain(self):
|
||||
tracker = SpendTracker()
|
||||
parent_id = uuid4()
|
||||
child_id = uuid4()
|
||||
tracker.on_chain_start(
|
||||
{"name": "market_analyst"}, {}, run_id=parent_id
|
||||
)
|
||||
tracker.on_llm_end(
|
||||
_make_llm_result(100, 50, "gpt-4o"),
|
||||
run_id=child_id,
|
||||
parent_run_id=parent_id,
|
||||
)
|
||||
self.assertEqual(tracker.audit_trail[0].agent, "market_analyst")
|
||||
|
||||
def test_format_audit_trail_output(self):
|
||||
tracker = SpendTracker()
|
||||
tracker.on_llm_end(_make_llm_result(100, 50, "gpt-4o"), run_id=uuid4())
|
||||
output = tracker.format_audit_trail()
|
||||
self.assertIn("Delegation chain", output)
|
||||
self.assertIn("LLM(gpt-4o)", output)
|
||||
|
||||
def test_empty_audit_trail(self):
|
||||
tracker = SpendTracker()
|
||||
output = tracker.format_audit_trail()
|
||||
self.assertIn("No calls recorded", output)
|
||||
|
||||
|
||||
class TestPricing(unittest.TestCase):
|
||||
"""Model pricing lookup."""
|
||||
|
||||
def test_exact_match(self):
|
||||
self.assertEqual(_get_pricing("gpt-4o"), MODEL_PRICING["gpt-4o"])
|
||||
|
||||
def test_prefix_match(self):
|
||||
# "gpt-4o-2024-08-06" should match "gpt-4o"
|
||||
result = _get_pricing("gpt-4o-2024-08-06")
|
||||
self.assertEqual(result, MODEL_PRICING["gpt-4o"])
|
||||
|
||||
def test_unknown_model_returns_default(self):
|
||||
result = _get_pricing("totally-unknown-model-xyz")
|
||||
self.assertEqual(result, _DEFAULT_PRICING)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -1,2 +1,6 @@
|
|||
import os
|
||||
os.environ.setdefault("PYTHONUTF8", "1")
|
||||
|
||||
from .spend_tracker import SpendTracker, BudgetExceededError, MODEL_PRICING, AuditEntry # noqa: E402
|
||||
|
||||
__all__ = ["SpendTracker", "BudgetExceededError", "MODEL_PRICING", "AuditEntry"]
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ DEFAULT_CONFIG = {
|
|||
# Output language for analyst reports and final decision
|
||||
# Internal agent debate stays in English for reasoning quality
|
||||
"output_language": "English",
|
||||
# Spend limit: maximum estimated USD cost per run (None = unlimited)
|
||||
"max_cost": None,
|
||||
# Debate and discussion settings
|
||||
"max_debate_rounds": 1,
|
||||
"max_risk_discuss_rounds": 1,
|
||||
|
|
|
|||
|
|
@ -190,70 +190,94 @@ class TradingAgentsGraph:
|
|||
}
|
||||
|
||||
def propagate(self, company_name, trade_date):
|
||||
"""Run the trading agents graph for a company on a specific date."""
|
||||
"""Run the trading agents graph for a company on a specific date.
|
||||
|
||||
Returns:
|
||||
Tuple of (final_state, signal). If budget is exceeded mid-run,
|
||||
returns partial state with ``final_trade_decision`` set to
|
||||
``"BUDGET_EXCEEDED"`` and signal ``"hold"``.
|
||||
"""
|
||||
from tradingagents.spend_tracker import BudgetExceededError, SpendTracker
|
||||
|
||||
self.ticker = company_name
|
||||
|
||||
# Notify spend trackers of ticker start
|
||||
for cb in self.callbacks:
|
||||
if isinstance(cb, SpendTracker):
|
||||
cb.begin_ticker(company_name)
|
||||
|
||||
# Initialize state
|
||||
init_agent_state = self.propagator.create_initial_state(
|
||||
company_name, trade_date
|
||||
)
|
||||
args = self.propagator.get_graph_args()
|
||||
|
||||
if self.debug:
|
||||
# Debug mode with tracing
|
||||
trace = []
|
||||
for chunk in self.graph.stream(init_agent_state, **args):
|
||||
if len(chunk["messages"]) == 0:
|
||||
pass
|
||||
else:
|
||||
chunk["messages"][-1].pretty_print()
|
||||
trace.append(chunk)
|
||||
try:
|
||||
if self.debug:
|
||||
# Debug mode with tracing
|
||||
trace = []
|
||||
for chunk in self.graph.stream(init_agent_state, **args):
|
||||
if len(chunk["messages"]) == 0:
|
||||
pass
|
||||
else:
|
||||
chunk["messages"][-1].pretty_print()
|
||||
trace.append(chunk)
|
||||
|
||||
final_state = trace[-1]
|
||||
else:
|
||||
# Standard mode without tracing
|
||||
final_state = self.graph.invoke(init_agent_state, **args)
|
||||
final_state = trace[-1] if trace else dict(init_agent_state)
|
||||
else:
|
||||
# Standard mode without tracing
|
||||
final_state = self.graph.invoke(init_agent_state, **args)
|
||||
except BudgetExceededError:
|
||||
# Graceful abort: build partial state from what we have
|
||||
final_state = dict(init_agent_state)
|
||||
if self.debug and trace:
|
||||
final_state.update(trace[-1])
|
||||
final_state.setdefault("final_trade_decision", "BUDGET_EXCEEDED")
|
||||
|
||||
# Store current state for reflection
|
||||
self.curr_state = final_state
|
||||
|
||||
# Log state
|
||||
# Log state (partial or full)
|
||||
self._log_state(trade_date, final_state)
|
||||
|
||||
# Return decision and processed signal
|
||||
return final_state, self.process_signal(final_state["final_trade_decision"])
|
||||
# Log per-ticker and cumulative spend to stderr
|
||||
for cb in self.callbacks:
|
||||
if isinstance(cb, SpendTracker):
|
||||
cb.log_ticker_spend(company_name)
|
||||
cb.log_audit_trail()
|
||||
|
||||
decision_text = final_state.get("final_trade_decision", "BUDGET_EXCEEDED")
|
||||
return final_state, self.process_signal(decision_text)
|
||||
|
||||
def _log_state(self, trade_date, final_state):
|
||||
"""Log the final state to a JSON file."""
|
||||
"""Log the final state to a JSON file. Tolerates partial state from budget abort."""
|
||||
invest_debate = final_state.get("investment_debate_state") or {}
|
||||
risk_debate = final_state.get("risk_debate_state") or {}
|
||||
|
||||
self.log_states_dict[str(trade_date)] = {
|
||||
"company_of_interest": final_state["company_of_interest"],
|
||||
"trade_date": final_state["trade_date"],
|
||||
"market_report": final_state["market_report"],
|
||||
"sentiment_report": final_state["sentiment_report"],
|
||||
"news_report": final_state["news_report"],
|
||||
"fundamentals_report": final_state["fundamentals_report"],
|
||||
"company_of_interest": final_state.get("company_of_interest", ""),
|
||||
"trade_date": final_state.get("trade_date", ""),
|
||||
"market_report": final_state.get("market_report", ""),
|
||||
"sentiment_report": final_state.get("sentiment_report", ""),
|
||||
"news_report": final_state.get("news_report", ""),
|
||||
"fundamentals_report": final_state.get("fundamentals_report", ""),
|
||||
"investment_debate_state": {
|
||||
"bull_history": final_state["investment_debate_state"]["bull_history"],
|
||||
"bear_history": final_state["investment_debate_state"]["bear_history"],
|
||||
"history": final_state["investment_debate_state"]["history"],
|
||||
"current_response": final_state["investment_debate_state"][
|
||||
"current_response"
|
||||
],
|
||||
"judge_decision": final_state["investment_debate_state"][
|
||||
"judge_decision"
|
||||
],
|
||||
"bull_history": invest_debate.get("bull_history", ""),
|
||||
"bear_history": invest_debate.get("bear_history", ""),
|
||||
"history": invest_debate.get("history", ""),
|
||||
"current_response": invest_debate.get("current_response", ""),
|
||||
"judge_decision": invest_debate.get("judge_decision", ""),
|
||||
},
|
||||
"trader_investment_decision": final_state["trader_investment_plan"],
|
||||
"trader_investment_decision": final_state.get("trader_investment_plan", ""),
|
||||
"risk_debate_state": {
|
||||
"aggressive_history": final_state["risk_debate_state"]["aggressive_history"],
|
||||
"conservative_history": final_state["risk_debate_state"]["conservative_history"],
|
||||
"neutral_history": final_state["risk_debate_state"]["neutral_history"],
|
||||
"history": final_state["risk_debate_state"]["history"],
|
||||
"judge_decision": final_state["risk_debate_state"]["judge_decision"],
|
||||
"aggressive_history": risk_debate.get("aggressive_history", ""),
|
||||
"conservative_history": risk_debate.get("conservative_history", ""),
|
||||
"neutral_history": risk_debate.get("neutral_history", ""),
|
||||
"history": risk_debate.get("history", ""),
|
||||
"judge_decision": risk_debate.get("judge_decision", ""),
|
||||
},
|
||||
"investment_plan": final_state["investment_plan"],
|
||||
"final_trade_decision": final_state["final_trade_decision"],
|
||||
"investment_plan": final_state.get("investment_plan", ""),
|
||||
"final_trade_decision": final_state.get("final_trade_decision", ""),
|
||||
}
|
||||
|
||||
# Save to file
|
||||
|
|
|
|||
|
|
@ -0,0 +1,350 @@
|
|||
"""LangChain callback handler that tracks cumulative token usage and cost."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
class BudgetExceededError(Exception):
    """Raised when the spend budget is exceeded mid-graph.

    SpendTracker raises this from its ``on_llm_start`` / ``on_chat_model_start``
    hooks once cumulative cost has crossed ``max_cost``, aborting the run
    before the next paid LLM call is made.
    """
|
||||
|
||||
|
||||
@dataclass
class AuditEntry:
    """Single entry in the delegation chain audit trail.

    Attributes one LLM or tool invocation to the agent/chain that triggered
    it, together with its token counts and estimated cost.
    """

    # Agent/chain name resolved by walking the callback run hierarchy.
    agent: str
    call_type: str  # "llm" or "tool"
    name: str  # model name for "llm" entries, tool name for "tool" entries
    cost_usd: float  # estimated USD cost of this call (0.0 for tool calls)
    prompt_tokens: int  # input-side tokens (0 for tool calls)
    completion_tokens: int  # output-side tokens (0 for tool calls)
    # Monotonic clock so entry ordering is immune to wall-clock adjustments.
    timestamp: float = field(default_factory=time.monotonic)
|
||||
|
||||
# Pricing per 1M tokens: (input_usd, output_usd)
|
||||
# Source: provider pricing pages as of 2025-Q2. Add new models as needed.
|
||||
MODEL_PRICING: dict[str, tuple[float, float]] = {
|
||||
# OpenAI
|
||||
"gpt-4o": (2.50, 10.00),
|
||||
"gpt-4o-mini": (0.15, 0.60),
|
||||
"gpt-4-turbo": (10.00, 30.00),
|
||||
"gpt-4": (30.00, 60.00),
|
||||
"gpt-3.5-turbo": (0.50, 1.50),
|
||||
"o1": (15.00, 60.00),
|
||||
"o1-mini": (3.00, 12.00),
|
||||
# Anthropic
|
||||
"claude-3-5-sonnet-20241022": (3.00, 15.00),
|
||||
"claude-3-5-haiku-20241022": (0.80, 4.00),
|
||||
"claude-3-opus-20240229": (15.00, 75.00),
|
||||
# Google
|
||||
"gemini-2.0-flash": (0.10, 0.40),
|
||||
"gemini-1.5-pro": (1.25, 5.00),
|
||||
"gemini-1.5-flash": (0.075, 0.30),
|
||||
}
|
||||
|
||||
# Fallback: generous estimate so unknown models still get a cost ceiling
|
||||
_DEFAULT_PRICING: tuple[float, float] = (10.00, 30.00)
|
||||
|
||||
|
||||
def _get_pricing(model: str) -> tuple[float, float]:
|
||||
"""Return (input_per_1M, output_per_1M) for *model*, with prefix matching."""
|
||||
if model in MODEL_PRICING:
|
||||
return MODEL_PRICING[model]
|
||||
# Try prefix match (e.g. "gpt-4o-2024-08-06" → "gpt-4o")
|
||||
for key in sorted(MODEL_PRICING, key=len, reverse=True):
|
||||
if model.startswith(key):
|
||||
return MODEL_PRICING[key]
|
||||
return _DEFAULT_PRICING
|
||||
|
||||
|
||||
@dataclass
class TokenRecord:
    """Single LLM call token record."""

    model: str  # model name as reported by the provider response
    prompt_tokens: int  # input-side tokens for this call
    completion_tokens: int  # output-side tokens for this call
    total_tokens: int  # provider-reported total, or prompt + completion
|
||||
|
||||
|
||||
class SpendTracker(BaseCallbackHandler):
    """Tracks cumulative token usage and estimated USD cost.

    Thread-safe: LangGraph may invoke nodes concurrently, so every read
    and write of the counters happens under a single internal lock.

    Args:
        max_cost: Optional USD budget. When cumulative cost reaches this
            value, ``budget_exceeded`` becomes ``True`` and the next
            LLM/chat-model start raises ``BudgetExceededError``, aborting
            the run before more money is spent. Callers should check this
            flag and abort.

    Usage::

        tracker = SpendTracker(max_cost=0.50)
        graph = TradingAgentsGraph(callbacks=[tracker])
        # ... run graph ...
        print(tracker.total_cost_usd)
        print(tracker.budget_exceeded)
    """

    def __init__(self, max_cost: float | None = None) -> None:
        super().__init__()
        self._lock = threading.Lock()
        # Cumulative usage across all LLM calls seen by this handler.
        self.prompt_tokens: int = 0
        self.completion_tokens: int = 0
        self.total_tokens: int = 0
        self.total_cost_usd: float = 0.0
        self.max_cost: float | None = max_cost
        self.budget_exceeded: bool = False
        self.records: list[TokenRecord] = []
        # Per-ticker spend tracking
        self.ticker_costs: dict[str, float] = {}
        self._ticker_snapshot: float = 0.0  # cumulative cost at last ticker start
        # Delegation chain audit trail
        self.audit_trail: list[AuditEntry] = []
        self._run_names: dict[UUID, str] = {}  # run_id → chain/agent name
        self._run_parents: dict[UUID, UUID] = {}  # run_id → parent_run_id
        self._tool_starts: dict[UUID, str] = {}  # run_id → tool name

    # -- LangChain callback hooks ------------------------------------------

    def on_chain_start(
        self,
        serialized: dict[str, Any],
        inputs: dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Track chain/agent names for the delegation audit trail."""
        # BUGFIX: LangChain passes serialized=None for non-serializable
        # runnables; guard before calling .get() to avoid AttributeError.
        meta = serialized or {}
        name = meta.get("name") or meta.get("id", [""])[-1] or ""
        with self._lock:
            if name:
                self._run_names[run_id] = name
            if parent_run_id is not None:
                self._run_parents[run_id] = parent_run_id

    def on_llm_start(
        self,
        serialized: dict[str, Any],
        prompts: list[str],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> None:
        """Abort before the next LLM call if budget already exceeded.

        Raises:
            BudgetExceededError: if a previous ``on_llm_end`` flipped
                ``budget_exceeded``.
        """
        with self._lock:
            if parent_run_id is not None:
                self._run_parents[run_id] = parent_run_id
            # Flag is read under the lock so a concurrent on_llm_end
            # cannot race between the check and the raise.
            if self.budget_exceeded:
                raise BudgetExceededError(
                    f"Budget ${self.max_cost:.2f} exceeded (spent ${self.total_cost_usd:.4f})"
                )

    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[Any]],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> None:
        """Abort before the next chat model call if budget already exceeded.

        Raises:
            BudgetExceededError: if a previous ``on_llm_end`` flipped
                ``budget_exceeded``.
        """
        with self._lock:
            if parent_run_id is not None:
                self._run_parents[run_id] = parent_run_id
            if self.budget_exceeded:
                raise BudgetExceededError(
                    f"Budget ${self.max_cost:.2f} exceeded (spent ${self.total_cost_usd:.4f})"
                )

    def on_tool_start(
        self,
        serialized: dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> None:
        """Record tool invocation start for audit trail."""
        # BUGFIX: same serialized=None guard as on_chain_start.
        tool_name = (serialized or {}).get("name") or "unknown_tool"
        with self._lock:
            self._tool_starts[run_id] = tool_name
            if parent_run_id is not None:
                self._run_parents[run_id] = parent_run_id

    def on_tool_end(
        self,
        output: Any,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> None:
        """Record completed tool call in audit trail (tools cost $0 here)."""
        with self._lock:
            tool_name = self._tool_starts.pop(run_id, "unknown_tool")
            agent = self._resolve_agent(parent_run_id or run_id)
            self.audit_trail.append(
                AuditEntry(
                    agent=agent,
                    call_type="tool",
                    name=tool_name,
                    cost_usd=0.0,
                    prompt_tokens=0,
                    completion_tokens=0,
                )
            )

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> None:
        """Accumulate usage/cost for a finished LLM call and check budget.

        Calls without extractable usage metadata are ignored (no cost can
        be attributed). The budget flag is set here; the actual abort
        happens in the next ``on_llm_start``/``on_chat_model_start``.
        """
        usage = self._extract_usage(response)
        if usage is None:
            return
        prompt, completion, total = usage
        model = self._extract_model(response)
        inp_price, out_price = _get_pricing(model)
        # Prices are per 1M tokens.
        call_cost = (prompt * inp_price + completion * out_price) / 1_000_000

        with self._lock:
            self.prompt_tokens += prompt
            self.completion_tokens += completion
            self.total_tokens += total
            self.total_cost_usd += call_cost
            self.records.append(
                TokenRecord(
                    model=model,
                    prompt_tokens=prompt,
                    completion_tokens=completion,
                    total_tokens=total,
                )
            )
            # Audit trail: resolve which agent triggered this LLM call
            agent = self._resolve_agent(parent_run_id or run_id)
            self.audit_trail.append(
                AuditEntry(
                    agent=agent,
                    call_type="llm",
                    name=model,
                    cost_usd=call_cost,
                    prompt_tokens=prompt,
                    completion_tokens=completion,
                )
            )
            # Budget check inside the lock for thread safety.
            if self.max_cost is not None and self.total_cost_usd >= self.max_cost:
                self.budget_exceeded = True

    # -- helpers -----------------------------------------------------------

    def _resolve_agent(self, run_id: UUID | None) -> str:
        """Walk the parent chain to find the nearest named agent/chain. Must hold _lock."""
        visited: set[UUID] = set()  # cycle guard for malformed parent maps
        current = run_id
        while current and current not in visited:
            visited.add(current)
            name = self._run_names.get(current)
            if name:
                return name
            current = self._run_parents.get(current)
        return "unknown"

    @staticmethod
    def _extract_usage(response: LLMResult) -> tuple[int, int, int] | None:
        """Pull (prompt, completion, total) token counts from *response*.

        Checks ``llm_output`` first (most providers), then falls back to
        per-generation metadata. Returns ``None`` when no usage is found.
        """
        # Most providers put token_usage in llm_output
        llm_out = response.llm_output or {}
        usage = llm_out.get("token_usage") or llm_out.get("usage") or {}
        if usage:
            prompt = usage.get("prompt_tokens", 0) or 0
            completion = usage.get("completion_tokens", 0) or 0
            total = usage.get("total_tokens", 0) or (prompt + completion)
            return prompt, completion, total

        # Fallback: check generation-level response_metadata (langchain-openai ≥0.3)
        for gen_list in response.generations:
            for gen in gen_list:
                meta = getattr(gen, "generation_info", None) or {}
                usage = meta.get("usage") or meta.get("token_usage") or {}
                if usage:
                    prompt = usage.get("prompt_tokens", 0) or 0
                    completion = usage.get("completion_tokens", 0) or 0
                    total = usage.get("total_tokens", 0) or (prompt + completion)
                    return prompt, completion, total

        return None

    @staticmethod
    def _extract_model(response: LLMResult) -> str:
        """Best-effort model name from llm_output; "unknown" when absent."""
        llm_out = response.llm_output or {}
        return llm_out.get("model_name", "") or llm_out.get("model", "unknown")

    def begin_ticker(self, ticker: str) -> None:
        """Mark the start of a new ticker analysis. Call before propagate()."""
        with self._lock:
            self._ticker_snapshot = self.total_cost_usd

    def log_ticker_spend(self, ticker: str) -> None:
        """Log per-ticker and cumulative spend to stderr. Call after propagate()."""
        print(self.format_ticker_spend(ticker), file=sys.stderr)

    def format_ticker_spend(self, ticker: str) -> str:
        """Return a human-readable ticker spend string.

        NOTE: also records the delta since ``begin_ticker`` into
        ``ticker_costs`` — call once per completed ticker, or the same
        delta is added again.
        """
        with self._lock:
            ticker_cost = self.total_cost_usd - self._ticker_snapshot
            self.ticker_costs[ticker] = self.ticker_costs.get(ticker, 0.0) + ticker_cost
            budget_str = f" / budget ${self.max_cost:.2f}" if self.max_cost is not None else ""
            exceeded = " [OVER BUDGET]" if self.budget_exceeded else ""
            return f"[spend] {ticker}: ${ticker_cost:.4f} | cumulative: ${self.total_cost_usd:.4f}{budget_str}{exceeded}"

    def reset(self) -> None:
        """Reset all counters, records, per-ticker state, and the audit trail."""
        with self._lock:
            self.prompt_tokens = 0
            self.completion_tokens = 0
            self.total_tokens = 0
            self.total_cost_usd = 0.0
            self.budget_exceeded = False
            self.records.clear()
            # BUGFIX: previously ticker state survived reset, so a stale
            # _ticker_snapshot could exceed the zeroed total and produce
            # negative per-ticker costs on the next format_ticker_spend.
            self.ticker_costs.clear()
            self._ticker_snapshot = 0.0
            self.audit_trail.clear()
            self._run_names.clear()
            self._run_parents.clear()
            self._tool_starts.clear()

    def format_audit_trail(self) -> str:
        """Return a human-readable audit trail string."""
        with self._lock:
            if not self.audit_trail:
                return "[audit] No calls recorded."
            lines = ["[audit] Delegation chain:"]
            for i, e in enumerate(self.audit_trail, 1):
                if e.call_type == "llm":
                    lines.append(
                        f"  {i}. {e.agent} → LLM({e.name}) "
                        f"tokens={e.prompt_tokens}+{e.completion_tokens} "
                        f"cost=${e.cost_usd:.4f}"
                    )
                else:
                    lines.append(f"  {i}. {e.agent} → tool({e.name})")
            return "\n".join(lines)

    def log_audit_trail(self) -> None:
        """Print the audit trail to stderr."""
        print(self.format_audit_trail(), file=sys.stderr)
|
||||
Loading…
Reference in New Issue