feat: add footer statistics tracking with LangChain callbacks
- Add StatsCallbackHandler for tracking LLM calls, tool calls, and tokens
- Integrate callbacks into TradingAgentsGraph and all LLM clients
- Dynamic agent/report counts based on selected analysts
- Fix report completion counting (tied to agent completion)
This commit is contained in:
parent
b06936f420
commit
54cdb146d0
255
cli/main.py
255
cli/main.py
|
|
@ -15,7 +15,6 @@ from rich.columns import Columns
|
|||
from rich.markdown import Markdown
|
||||
from rich.layout import Layout
|
||||
from rich.text import Text
|
||||
from rich.live import Live
|
||||
from rich.table import Table
|
||||
from collections import deque
|
||||
import time
|
||||
|
|
@ -29,6 +28,7 @@ from tradingagents.default_config import DEFAULT_CONFIG
|
|||
from cli.models import AnalystType
|
||||
from cli.utils import *
|
||||
from cli.announcements import fetch_announcements, display_announcements
|
||||
from cli.stats_handler import StatsCallbackHandler
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -41,40 +41,99 @@ app = typer.Typer(
|
|||
|
||||
# Create a deque to store recent messages with a maximum length
|
||||
class MessageBuffer:
|
||||
# Fixed teams that always run (not user-selectable)
|
||||
FIXED_AGENTS = {
|
||||
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
|
||||
"Trading Team": ["Trader"],
|
||||
"Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"],
|
||||
"Portfolio Management": ["Portfolio Manager"],
|
||||
}
|
||||
|
||||
# Analyst name mapping
|
||||
ANALYST_MAPPING = {
|
||||
"market": "Market Analyst",
|
||||
"social": "Social Analyst",
|
||||
"news": "News Analyst",
|
||||
"fundamentals": "Fundamentals Analyst",
|
||||
}
|
||||
|
||||
# Report section mapping: section -> (analyst_key for filtering, finalizing_agent)
|
||||
# analyst_key: which analyst selection controls this section (None = always included)
|
||||
# finalizing_agent: which agent must be "completed" for this report to count as done
|
||||
REPORT_SECTIONS = {
|
||||
"market_report": ("market", "Market Analyst"),
|
||||
"sentiment_report": ("social", "Social Analyst"),
|
||||
"news_report": ("news", "News Analyst"),
|
||||
"fundamentals_report": ("fundamentals", "Fundamentals Analyst"),
|
||||
"investment_plan": (None, "Research Manager"),
|
||||
"trader_investment_plan": (None, "Trader"),
|
||||
"final_trade_decision": (None, "Portfolio Manager"),
|
||||
}
|
||||
|
||||
def __init__(self, max_length=100):
|
||||
self.messages = deque(maxlen=max_length)
|
||||
self.tool_calls = deque(maxlen=max_length)
|
||||
self.current_report = None
|
||||
self.final_report = None # Store the complete final report
|
||||
self.agent_status = {
|
||||
# Analyst Team
|
||||
"Market Analyst": "pending",
|
||||
"Social Analyst": "pending",
|
||||
"News Analyst": "pending",
|
||||
"Fundamentals Analyst": "pending",
|
||||
# Research Team
|
||||
"Bull Researcher": "pending",
|
||||
"Bear Researcher": "pending",
|
||||
"Research Manager": "pending",
|
||||
# Trading Team
|
||||
"Trader": "pending",
|
||||
# Risk Management Team
|
||||
"Aggressive Analyst": "pending",
|
||||
"Neutral Analyst": "pending",
|
||||
"Conservative Analyst": "pending",
|
||||
# Portfolio Management Team
|
||||
"Portfolio Manager": "pending",
|
||||
}
|
||||
self.agent_status = {}
|
||||
self.current_agent = None
|
||||
self.report_sections = {
|
||||
"market_report": None,
|
||||
"sentiment_report": None,
|
||||
"news_report": None,
|
||||
"fundamentals_report": None,
|
||||
"investment_plan": None,
|
||||
"trader_investment_plan": None,
|
||||
"final_trade_decision": None,
|
||||
}
|
||||
self.report_sections = {}
|
||||
self.selected_analysts = []
|
||||
|
||||
def init_for_analysis(self, selected_analysts):
|
||||
"""Initialize agent status and report sections based on selected analysts.
|
||||
|
||||
Args:
|
||||
selected_analysts: List of analyst type strings (e.g., ["market", "news"])
|
||||
"""
|
||||
self.selected_analysts = [a.lower() for a in selected_analysts]
|
||||
|
||||
# Build agent_status dynamically
|
||||
self.agent_status = {}
|
||||
|
||||
# Add selected analysts
|
||||
for analyst_key in self.selected_analysts:
|
||||
if analyst_key in self.ANALYST_MAPPING:
|
||||
self.agent_status[self.ANALYST_MAPPING[analyst_key]] = "pending"
|
||||
|
||||
# Add fixed teams
|
||||
for team_agents in self.FIXED_AGENTS.values():
|
||||
for agent in team_agents:
|
||||
self.agent_status[agent] = "pending"
|
||||
|
||||
# Build report_sections dynamically
|
||||
self.report_sections = {}
|
||||
for section, (analyst_key, _) in self.REPORT_SECTIONS.items():
|
||||
if analyst_key is None or analyst_key in self.selected_analysts:
|
||||
self.report_sections[section] = None
|
||||
|
||||
# Reset other state
|
||||
self.current_report = None
|
||||
self.final_report = None
|
||||
self.current_agent = None
|
||||
self.messages.clear()
|
||||
self.tool_calls.clear()
|
||||
|
||||
def get_completed_reports_count(self):
|
||||
"""Count reports that are finalized (their finalizing agent is completed).
|
||||
|
||||
A report is considered complete when:
|
||||
1. The report section has content (not None), AND
|
||||
2. The agent responsible for finalizing that report has status "completed"
|
||||
|
||||
This prevents interim updates (like debate rounds) from counting as completed.
|
||||
"""
|
||||
count = 0
|
||||
for section in self.report_sections:
|
||||
if section not in self.REPORT_SECTIONS:
|
||||
continue
|
||||
_, finalizing_agent = self.REPORT_SECTIONS[section]
|
||||
# Report is complete if it has content AND its finalizing agent is done
|
||||
has_content = self.report_sections.get(section) is not None
|
||||
agent_done = self.agent_status.get(finalizing_agent) == "completed"
|
||||
if has_content and agent_done:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def add_message(self, message_type, content):
|
||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
|
|
@ -126,46 +185,39 @@ class MessageBuffer:
|
|||
def _update_final_report(self):
|
||||
report_parts = []
|
||||
|
||||
# Analyst Team Reports
|
||||
if any(
|
||||
self.report_sections[section]
|
||||
for section in [
|
||||
"market_report",
|
||||
"sentiment_report",
|
||||
"news_report",
|
||||
"fundamentals_report",
|
||||
]
|
||||
):
|
||||
# Analyst Team Reports - use .get() to handle missing sections
|
||||
analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report"]
|
||||
if any(self.report_sections.get(section) for section in analyst_sections):
|
||||
report_parts.append("## Analyst Team Reports")
|
||||
if self.report_sections["market_report"]:
|
||||
if self.report_sections.get("market_report"):
|
||||
report_parts.append(
|
||||
f"### Market Analysis\n{self.report_sections['market_report']}"
|
||||
)
|
||||
if self.report_sections["sentiment_report"]:
|
||||
if self.report_sections.get("sentiment_report"):
|
||||
report_parts.append(
|
||||
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
|
||||
)
|
||||
if self.report_sections["news_report"]:
|
||||
if self.report_sections.get("news_report"):
|
||||
report_parts.append(
|
||||
f"### News Analysis\n{self.report_sections['news_report']}"
|
||||
)
|
||||
if self.report_sections["fundamentals_report"]:
|
||||
if self.report_sections.get("fundamentals_report"):
|
||||
report_parts.append(
|
||||
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
|
||||
)
|
||||
|
||||
# Research Team Reports
|
||||
if self.report_sections["investment_plan"]:
|
||||
if self.report_sections.get("investment_plan"):
|
||||
report_parts.append("## Research Team Decision")
|
||||
report_parts.append(f"{self.report_sections['investment_plan']}")
|
||||
|
||||
# Trading Team Reports
|
||||
if self.report_sections["trader_investment_plan"]:
|
||||
if self.report_sections.get("trader_investment_plan"):
|
||||
report_parts.append("## Trading Team Plan")
|
||||
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
|
||||
|
||||
# Portfolio Management Decision
|
||||
if self.report_sections["final_trade_decision"]:
|
||||
if self.report_sections.get("final_trade_decision"):
|
||||
report_parts.append("## Portfolio Management Decision")
|
||||
report_parts.append(f"{self.report_sections['final_trade_decision']}")
|
||||
|
||||
|
|
@ -191,7 +243,14 @@ def create_layout():
|
|||
return layout
|
||||
|
||||
|
||||
def update_display(layout, spinner_text=None):
|
||||
def format_tokens(n):
|
||||
"""Format token count for display."""
|
||||
if n >= 1000:
|
||||
return f"{n/1000:.1f}k"
|
||||
return str(n)
|
||||
|
||||
|
||||
def update_display(layout, spinner_text=None, stats_handler=None, start_time=None):
|
||||
# Header with welcome message
|
||||
layout["header"].update(
|
||||
Panel(
|
||||
|
|
@ -218,8 +277,8 @@ def update_display(layout, spinner_text=None):
|
|||
progress_table.add_column("Agent", style="green", justify="center", width=20)
|
||||
progress_table.add_column("Status", style="yellow", justify="center", width=20)
|
||||
|
||||
# Group agents by team
|
||||
teams = {
|
||||
# Group agents by team - filter to only include agents in agent_status
|
||||
all_teams = {
|
||||
"Analyst Team": [
|
||||
"Market Analyst",
|
||||
"Social Analyst",
|
||||
|
|
@ -232,10 +291,17 @@ def update_display(layout, spinner_text=None):
|
|||
"Portfolio Management": ["Portfolio Manager"],
|
||||
}
|
||||
|
||||
# Filter teams to only include agents that are in agent_status
|
||||
teams = {}
|
||||
for team, agents in all_teams.items():
|
||||
active_agents = [a for a in agents if a in message_buffer.agent_status]
|
||||
if active_agents:
|
||||
teams[team] = active_agents
|
||||
|
||||
for team, agents in teams.items():
|
||||
# Add first agent with team name
|
||||
first_agent = agents[0]
|
||||
status = message_buffer.agent_status[first_agent]
|
||||
status = message_buffer.agent_status.get(first_agent, "pending")
|
||||
if status == "in_progress":
|
||||
spinner = Spinner(
|
||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
|
||||
|
|
@ -252,7 +318,7 @@ def update_display(layout, spinner_text=None):
|
|||
|
||||
# Add remaining agents in team
|
||||
for agent in agents[1:]:
|
||||
status = message_buffer.agent_status[agent]
|
||||
status = message_buffer.agent_status.get(agent, "pending")
|
||||
if status == "in_progress":
|
||||
spinner = Spinner(
|
||||
"dots", text="[blue]in_progress[/blue]", style="bold cyan"
|
||||
|
|
@ -379,19 +445,43 @@ def update_display(layout, spinner_text=None):
|
|||
)
|
||||
|
||||
# Footer with statistics
|
||||
tool_calls_count = len(message_buffer.tool_calls)
|
||||
llm_calls_count = sum(
|
||||
1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning"
|
||||
)
|
||||
reports_count = sum(
|
||||
1 for content in message_buffer.report_sections.values() if content is not None
|
||||
# Agent progress - derived from agent_status dict
|
||||
agents_completed = sum(
|
||||
1 for status in message_buffer.agent_status.values() if status == "completed"
|
||||
)
|
||||
agents_total = len(message_buffer.agent_status)
|
||||
|
||||
# Report progress - based on agent completion (not just content existence)
|
||||
reports_completed = message_buffer.get_completed_reports_count()
|
||||
reports_total = len(message_buffer.report_sections)
|
||||
|
||||
# Build stats parts
|
||||
stats_parts = [f"Agents: {agents_completed}/{agents_total}"]
|
||||
|
||||
# LLM and tool stats from callback handler
|
||||
if stats_handler:
|
||||
stats = stats_handler.get_stats()
|
||||
stats_parts.append(f"LLM: {stats['llm_calls']}")
|
||||
stats_parts.append(f"Tools: {stats['tool_calls']}")
|
||||
|
||||
# Token display with graceful fallback
|
||||
if stats["tokens_in"] > 0 or stats["tokens_out"] > 0:
|
||||
tokens_str = f"Tokens: {format_tokens(stats['tokens_in'])}\u2191 {format_tokens(stats['tokens_out'])}\u2193"
|
||||
else:
|
||||
tokens_str = "Tokens: --"
|
||||
stats_parts.append(tokens_str)
|
||||
|
||||
stats_parts.append(f"Reports: {reports_completed}/{reports_total}")
|
||||
|
||||
# Elapsed time
|
||||
if start_time:
|
||||
elapsed = time.time() - start_time
|
||||
elapsed_str = f"\u23f1 {int(elapsed // 60):02d}:{int(elapsed % 60):02d}"
|
||||
stats_parts.append(elapsed_str)
|
||||
|
||||
stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
|
||||
stats_table.add_column("Stats", justify="center")
|
||||
stats_table.add_row(
|
||||
f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}"
|
||||
)
|
||||
stats_table.add_row(" | ".join(stats_parts))
|
||||
|
||||
layout["footer"].update(Panel(stats_table, border_style="grey50"))
|
||||
|
||||
|
|
@ -803,11 +893,24 @@ def run_analysis():
|
|||
config["google_thinking_level"] = selections.get("google_thinking_level")
|
||||
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
|
||||
|
||||
# Initialize the graph
|
||||
# Create stats callback handler for tracking LLM/tool calls
|
||||
stats_handler = StatsCallbackHandler()
|
||||
|
||||
# Initialize the graph with callbacks bound to LLMs
|
||||
graph = TradingAgentsGraph(
|
||||
[analyst.value for analyst in selections["analysts"]], config=config, debug=True
|
||||
[analyst.value for analyst in selections["analysts"]],
|
||||
config=config,
|
||||
debug=True,
|
||||
callbacks=[stats_handler],
|
||||
)
|
||||
|
||||
# Initialize message buffer with selected analysts
|
||||
selected_analyst_keys = [analyst.value for analyst in selections["analysts"]]
|
||||
message_buffer.init_for_analysis(selected_analyst_keys)
|
||||
|
||||
# Track start time for elapsed display
|
||||
start_time = time.time()
|
||||
|
||||
# Create result directory
|
||||
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -860,7 +963,7 @@ def run_analysis():
|
|||
|
||||
with Live(layout, refresh_per_second=4) as live:
|
||||
# Initial display
|
||||
update_display(layout)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
# Add initial messages
|
||||
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
|
||||
|
|
@ -871,34 +974,26 @@ def run_analysis():
|
|||
"System",
|
||||
f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}",
|
||||
)
|
||||
update_display(layout)
|
||||
|
||||
# Reset agent statuses
|
||||
for agent in message_buffer.agent_status:
|
||||
message_buffer.update_agent_status(agent, "pending")
|
||||
|
||||
# Reset report sections
|
||||
for section in message_buffer.report_sections:
|
||||
message_buffer.report_sections[section] = None
|
||||
message_buffer.current_report = None
|
||||
message_buffer.final_report = None
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
# Update agent status to in_progress for the first analyst
|
||||
first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst"
|
||||
message_buffer.update_agent_status(first_analyst, "in_progress")
|
||||
update_display(layout)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
# Create spinner text
|
||||
spinner_text = (
|
||||
f"Analyzing {selections['ticker']} on {selections['analysis_date']}..."
|
||||
)
|
||||
update_display(layout, spinner_text)
|
||||
update_display(layout, spinner_text, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
# Initialize state and get graph args
|
||||
# Initialize state and get graph args with callbacks
|
||||
init_agent_state = graph.propagator.create_initial_state(
|
||||
selections["ticker"], selections["analysis_date"]
|
||||
)
|
||||
args = graph.propagator.get_graph_args()
|
||||
# Pass callbacks to graph config for tool execution tracking
|
||||
# (LLM tracking is handled separately via LLM constructor)
|
||||
args = graph.propagator.get_graph_args(callbacks=[stats_handler])
|
||||
|
||||
# Stream the analysis
|
||||
trace = []
|
||||
|
|
@ -1112,7 +1207,7 @@ def run_analysis():
|
|||
)
|
||||
|
||||
# Update the display
|
||||
update_display(layout)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
trace.append(chunk)
|
||||
|
||||
|
|
@ -1136,7 +1231,7 @@ def run_analysis():
|
|||
# Display the complete final report
|
||||
display_complete_report(final_state)
|
||||
|
||||
update_display(layout)
|
||||
update_display(layout, stats_handler=stats_handler, start_time=start_time)
|
||||
|
||||
|
||||
@app.command()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,76 @@
|
|||
import threading
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
|
||||
class StatsCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback handler that tracks LLM calls, tool calls, and token usage."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._lock = threading.Lock()
|
||||
self.llm_calls = 0
|
||||
self.tool_calls = 0
|
||||
self.tokens_in = 0
|
||||
self.tokens_out = 0
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Increment LLM call counter when an LLM starts."""
|
||||
with self._lock:
|
||||
self.llm_calls += 1
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[Any]],
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Increment LLM call counter when a chat model starts."""
|
||||
with self._lock:
|
||||
self.llm_calls += 1
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Extract token usage from LLM response."""
|
||||
try:
|
||||
generation = response.generations[0][0]
|
||||
except (IndexError, TypeError):
|
||||
return
|
||||
|
||||
usage_metadata = None
|
||||
if hasattr(generation, "message"):
|
||||
message = generation.message
|
||||
if isinstance(message, AIMessage) and hasattr(message, "usage_metadata"):
|
||||
usage_metadata = message.usage_metadata
|
||||
|
||||
if usage_metadata:
|
||||
with self._lock:
|
||||
self.tokens_in += usage_metadata.get("input_tokens", 0)
|
||||
self.tokens_out += usage_metadata.get("output_tokens", 0)
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
input_str: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Increment tool call counter when a tool starts."""
|
||||
with self._lock:
|
||||
self.tool_calls += 1
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Return current statistics."""
|
||||
with self._lock:
|
||||
return {
|
||||
"llm_calls": self.llm_calls,
|
||||
"tool_calls": self.tool_calls,
|
||||
"tokens_in": self.tokens_in,
|
||||
"tokens_out": self.tokens_out,
|
||||
}
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
# TradingAgents/graph/propagation.py
|
||||
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, List, Optional
|
||||
from tradingagents.agents.utils.agent_states import (
|
||||
AgentState,
|
||||
InvestDebateState,
|
||||
|
|
@ -41,9 +41,17 @@ class Propagator:
|
|||
"news_report": "",
|
||||
}
|
||||
|
||||
def get_graph_args(self) -> Dict[str, Any]:
|
||||
"""Get arguments for the graph invocation."""
|
||||
def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]:
|
||||
"""Get arguments for the graph invocation.
|
||||
|
||||
Args:
|
||||
callbacks: Optional list of callback handlers for tool execution tracking.
|
||||
Note: LLM callbacks are handled separately via LLM constructor.
|
||||
"""
|
||||
config = {"recursion_limit": self.max_recur_limit}
|
||||
if callbacks:
|
||||
config["callbacks"] = callbacks
|
||||
return {
|
||||
"stream_mode": "values",
|
||||
"config": {"recursion_limit": self.max_recur_limit},
|
||||
"config": config,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ class TradingAgentsGraph:
|
|||
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||
debug=False,
|
||||
config: Dict[str, Any] = None,
|
||||
callbacks: Optional[List] = None,
|
||||
):
|
||||
"""Initialize the trading agents graph and components.
|
||||
|
||||
|
|
@ -55,9 +56,11 @@ class TradingAgentsGraph:
|
|||
selected_analysts: List of analyst types to include
|
||||
debug: Whether to run in debug mode
|
||||
config: Configuration dictionary. If None, uses default config
|
||||
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
|
||||
"""
|
||||
self.debug = debug
|
||||
self.config = config or DEFAULT_CONFIG
|
||||
self.callbacks = callbacks or []
|
||||
|
||||
# Update the interface's config
|
||||
set_config(self.config)
|
||||
|
|
@ -71,6 +74,10 @@ class TradingAgentsGraph:
|
|||
# Initialize LLMs with provider-specific thinking configuration
|
||||
llm_kwargs = self._get_provider_kwargs()
|
||||
|
||||
# Add callbacks to kwargs if provided (passed to LLM constructor)
|
||||
if self.callbacks:
|
||||
llm_kwargs["callbacks"] = self.callbacks
|
||||
|
||||
deep_client = create_llm_client(
|
||||
provider=self.config["llm_provider"],
|
||||
model=self.config["deep_think_llm"],
|
||||
|
|
@ -83,6 +90,7 @@ class TradingAgentsGraph:
|
|||
base_url=self.config.get("backend_url"),
|
||||
**llm_kwargs,
|
||||
)
|
||||
|
||||
self.deep_thinking_llm = deep_client.get_llm()
|
||||
self.quick_thinking_llm = quick_client.get_llm()
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class AnthropicClient(BaseLLMClient):
|
|||
"""Return configured ChatAnthropic instance."""
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
for key in ("timeout", "max_retries", "api_key", "max_tokens"):
|
||||
for key in ("timeout", "max_retries", "api_key", "max_tokens", "callbacks"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from .base_client import BaseLLMClient
|
|||
from .openai_client import OpenAIClient
|
||||
from .anthropic_client import AnthropicClient
|
||||
from .google_client import GoogleClient
|
||||
from .vllm_client import VLLMClient
|
||||
|
||||
|
||||
def create_llm_client(
|
||||
|
|
@ -16,7 +15,7 @@ def create_llm_client(
|
|||
"""Create an LLM client for the specified provider.
|
||||
|
||||
Args:
|
||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
|
||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
|
||||
model: Model name/identifier
|
||||
base_url: Optional base URL for API endpoint
|
||||
**kwargs: Additional provider-specific arguments
|
||||
|
|
@ -41,7 +40,4 @@ def create_llm_client(
|
|||
if provider_lower == "google":
|
||||
return GoogleClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "vllm":
|
||||
return VLLMClient(model, base_url, **kwargs)
|
||||
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class GoogleClient(BaseLLMClient):
|
|||
"""Return configured ChatGoogleGenerativeAI instance."""
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
for key in ("timeout", "max_retries", "google_api_key"):
|
||||
for key in ("timeout", "max_retries", "google_api_key", "callbacks"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class OpenAIClient(BaseLLMClient):
|
|||
elif self.base_url:
|
||||
llm_kwargs["base_url"] = self.base_url
|
||||
|
||||
for key in ("timeout", "max_retries", "reasoning_effort", "api_key"):
|
||||
for key in ("timeout", "max_retries", "reasoning_effort", "api_key", "callbacks"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
|
|
|
|||
|
|
@ -69,11 +69,11 @@ VALID_MODELS = {
|
|||
def validate_model(provider: str, model: str) -> bool:
|
||||
"""Check if model name is valid for the given provider.
|
||||
|
||||
For ollama, openrouter, vllm - any model is accepted.
|
||||
For ollama, openrouter - any model is accepted.
|
||||
"""
|
||||
provider_lower = provider.lower()
|
||||
|
||||
if provider_lower in ("ollama", "openrouter", "vllm"):
|
||||
if provider_lower in ("ollama", "openrouter"):
|
||||
return True
|
||||
|
||||
if provider_lower not in VALID_MODELS:
|
||||
|
|
|
|||
|
|
@ -1,18 +0,0 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from .base_client import BaseLLMClient
|
||||
|
||||
|
||||
class VLLMClient(BaseLLMClient):
|
||||
"""Client for vLLM (placeholder for future implementation)."""
|
||||
|
||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(model, base_url, **kwargs)
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured vLLM instance."""
|
||||
raise NotImplementedError("vLLM client not yet implemented")
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
"""Validate model for vLLM."""
|
||||
return True
|
||||
Loading…
Reference in New Issue