Compare commits

...

24 Commits
v0.2.2 ... main

Author SHA1 Message Date
Yijia-Xiao fa4d01c23a
fix: process all chunk messages for tool call logging, harden memory score normalization (#534, #531) 2026-04-13 07:21:33 +00:00
Yijia-Xiao b0f6058299
feat: add DeepSeek, Qwen, GLM, and Azure OpenAI provider support 2026-04-13 07:12:07 +00:00
Yijia-Xiao 59d6b2152d
fix: use ~/.tradingagents/ for cache and logs, resolving Docker permission issue (#519) 2026-04-13 05:26:04 +00:00
Yijia-Xiao 10c136f49c
feat: add Docker support for cross-platform deployment 2026-04-04 08:14:01 +00:00
Yijia-Xiao 4f965bf46a
feat: dynamic OpenRouter model selection with search (#482, #337) 2026-04-04 07:56:44 +00:00
Yijia-Xiao bdb9c29d44
refactor: remove stale imports, use configurable results path (#499) 2026-04-04 07:35:35 +00:00
Yijia-Xiao bdc5fc62d3
chore: bump langchain-google-genai minimum to 4.0.0 for thought signature support 2026-04-04 07:28:03 +00:00
Yijia-Xiao 78fb66aed1
fix: normalize indicator names to lowercase (#490) 2026-04-04 07:23:31 +00:00
Yijia-Xiao 7269f877c1
fix: portfolio manager reads trader's proposal and research plan (#503) 2026-04-04 07:22:01 +00:00
Yijia-Xiao 28d5cc661f
fix: add missing pandas import in y_finance.py (#488) 2026-04-04 07:14:10 +00:00
Yijia-Xiao 7004dfe554
fix: remove hardcoded Google endpoint that caused 404 (#493, #496) 2026-04-04 07:07:53 +00:00
Yijia-Xiao 4641c03340
TradingAgents v0.2.3 2026-03-29 19:50:46 +00:00
Yijia-Xiao e75d17bc51
chore: update model lists and defaults to GPT-5.4 family 2026-03-29 19:45:36 +00:00
Yijia-Xiao 6cddd26d6e
feat: multi-language output support for analyst reports and final decision (#472) 2026-03-29 19:19:01 +00:00
Yijia Xiao c61242a28c
Merge pull request #464 from CadeYu/sync-validator-models
sync model validation with cli catalog
2026-03-29 11:07:51 -07:00
Yijia-Xiao 58e99421bd
fix: pass base_url to Google and Anthropic clients for proxy support (#427) 2026-03-29 17:59:52 +00:00
Yijia Xiao 46e1b600b8
Merge pull request #453 from javierdejesusda/fix/standardize-google-api-key
fix(llm_clients): standardize Google API key to unified api_key param
2026-03-29 10:54:28 -07:00
Yijia-Xiao ae8c8aebe8
fix: gracefully handle invalid indicator names in tool calls (#429) 2026-03-29 17:50:30 +00:00
Yijia-Xiao f3f58bdbdc
fix: add yf_retry to yfinance news fetchers (#445) 2026-03-29 17:42:24 +00:00
Yijia-Xiao e1113880a1
fix: prevent look-ahead bias in backtesting data fetchers (#475) 2026-03-29 17:34:35 +00:00
CadeYu bd6a5b75b5 fix model catalog typing and known-model helper 2026-03-25 21:46:56 +08:00
CadeYu 8793336dad sync model validation with cli catalog 2026-03-25 21:23:02 +08:00
javierdejesusda 047b38971c refactor: simplify api_key mapping and consolidate tests
Apply review suggestions: use concise `or` pattern for API key
resolution, consolidate tests into parameterized subTest, move
import to module level per PEP 8.
2026-03-24 14:52:51 +01:00
javierdejesusda f5026009f9 fix(llm_clients): standardize Google API key to unified api_key param
GoogleClient now accepts the unified `api_key` parameter used by
OpenAI and Anthropic clients, mapping it to the provider-specific
`google_api_key` that ChatGoogleGenerativeAI expects. Legacy
`google_api_key` still works for backward compatibility.

Resolves TODO.md item #2 (inconsistent parameter handling).
2026-03-24 14:35:02 +01:00
46 changed files with 785 additions and 457 deletions

15
.dockerignore Normal file
View File

@ -0,0 +1,15 @@
.git
.venv
.env
.claude
.idea
.vscode
.DS_Store
__pycache__
*.egg-info
build
dist
results
eval_results
Dockerfile
docker-compose.yml

5
.env.enterprise.example Normal file
View File

@ -0,0 +1,5 @@
# Azure OpenAI
AZURE_OPENAI_API_KEY=
AZURE_OPENAI_ENDPOINT=https://your-resource-name.openai.azure.com/
AZURE_OPENAI_DEPLOYMENT_NAME=
# OPENAI_API_VERSION=2024-10-21 # optional, required for non-v1 API

View File

@ -3,4 +3,7 @@ OPENAI_API_KEY=
GOOGLE_API_KEY= GOOGLE_API_KEY=
ANTHROPIC_API_KEY= ANTHROPIC_API_KEY=
XAI_API_KEY= XAI_API_KEY=
DEEPSEEK_API_KEY=
DASHSCOPE_API_KEY=
ZHIPU_API_KEY=
OPENROUTER_API_KEY= OPENROUTER_API_KEY=

27
Dockerfile Normal file
View File

@ -0,0 +1,27 @@
FROM python:3.12-slim AS builder
ENV PYTHONDONTWRITEBYTECODE=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
WORKDIR /build
COPY . .
RUN pip install --no-cache-dir .
FROM python:3.12-slim
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
COPY --from=builder /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
RUN useradd --create-home appuser
USER appuser
WORKDIR /home/appuser/app
COPY --from=builder --chown=appuser:appuser /build .
ENTRYPOINT ["tradingagents"]

View File

@ -28,6 +28,7 @@
# TradingAgents: Multi-Agents LLM Financial Trading Framework # TradingAgents: Multi-Agents LLM Financial Trading Framework
## News ## News
- [2026-03] **TradingAgents v0.2.3** released with multi-language support, GPT-5.4 family models, unified model catalog, backtesting date fidelity, and proxy support.
- [2026-03] **TradingAgents v0.2.2** released with GPT-5.4/Gemini 3.1/Claude 4.6 model coverage, five-tier rating scale, OpenAI Responses API, Anthropic effort control, and cross-platform stability. - [2026-03] **TradingAgents v0.2.2** released with GPT-5.4/Gemini 3.1/Claude 4.6 model coverage, five-tier rating scale, OpenAI Responses API, Anthropic effort control, and cross-platform stability.
- [2026-02] **TradingAgents v0.2.0** released with multi-provider LLM support (GPT-5.x, Gemini 3.x, Claude 4.x, Grok 4.x) and improved system architecture. - [2026-02] **TradingAgents v0.2.0** released with multi-provider LLM support (GPT-5.x, Gemini 3.x, Claude 4.x, Grok 4.x) and improved system architecture.
- [2026-01] **Trading-R1** [Technical Report](https://arxiv.org/abs/2509.11420) released, with [Terminal](https://github.com/TauricResearch/Trading-R1) expected to land soon. - [2026-01] **Trading-R1** [Technical Report](https://arxiv.org/abs/2509.11420) released, with [Terminal](https://github.com/TauricResearch/Trading-R1) expected to land soon.
@ -117,6 +118,19 @@ Install the package and its dependencies:
pip install . pip install .
``` ```
### Docker
Alternatively, run with Docker:
```bash
cp .env.example .env # add your API keys
docker compose run --rm tradingagents
```
For local models with Ollama:
```bash
docker compose --profile ollama run --rm tradingagents-ollama
```
### Required APIs ### Required APIs
TradingAgents supports multiple LLM providers. Set the API key for your chosen provider: TradingAgents supports multiple LLM providers. Set the API key for your chosen provider:
@ -126,10 +140,15 @@ export OPENAI_API_KEY=... # OpenAI (GPT)
export GOOGLE_API_KEY=... # Google (Gemini) export GOOGLE_API_KEY=... # Google (Gemini)
export ANTHROPIC_API_KEY=... # Anthropic (Claude) export ANTHROPIC_API_KEY=... # Anthropic (Claude)
export XAI_API_KEY=... # xAI (Grok) export XAI_API_KEY=... # xAI (Grok)
export DEEPSEEK_API_KEY=... # DeepSeek
export DASHSCOPE_API_KEY=... # Qwen (Alibaba DashScope)
export ZHIPU_API_KEY=... # GLM (Zhipu)
export OPENROUTER_API_KEY=... # OpenRouter export OPENROUTER_API_KEY=... # OpenRouter
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
``` ```
For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials.
For local models, configure Ollama with `llm_provider: "ollama"` in your config. For local models, configure Ollama with `llm_provider: "ollama"` in your config.
Alternatively, copy `.env.example` to `.env` and fill in your keys: Alternatively, copy `.env.example` to `.env` and fill in your keys:
@ -189,8 +208,8 @@ from tradingagents.default_config import DEFAULT_CONFIG
config = DEFAULT_CONFIG.copy() config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "openai" # openai, google, anthropic, xai, openrouter, ollama config["llm_provider"] = "openai" # openai, google, anthropic, xai, openrouter, ollama
config["deep_think_llm"] = "gpt-5.2" # Model for complex reasoning config["deep_think_llm"] = "gpt-5.4" # Model for complex reasoning
config["quick_think_llm"] = "gpt-5-mini" # Model for quick tasks config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks
config["max_debate_rounds"] = 2 config["max_debate_rounds"] = 2
ta = TradingAgentsGraph(debug=True, config=config) ta = TradingAgentsGraph(debug=True, config=config)

View File

@ -6,8 +6,9 @@ from functools import wraps
from rich.console import Console from rich.console import Console
from dotenv import load_dotenv from dotenv import load_dotenv
# Load environment variables from .env file # Load environment variables
load_dotenv() load_dotenv()
load_dotenv(".env.enterprise", override=False)
from rich.panel import Panel from rich.panel import Panel
from rich.spinner import Spinner from rich.spinner import Spinner
from rich.live import Live from rich.live import Live
@ -79,7 +80,7 @@ class MessageBuffer:
self.current_agent = None self.current_agent = None
self.report_sections = {} self.report_sections = {}
self.selected_analysts = [] self.selected_analysts = []
self._last_message_id = None self._processed_message_ids = set()
def init_for_analysis(self, selected_analysts): def init_for_analysis(self, selected_analysts):
"""Initialize agent status and report sections based on selected analysts. """Initialize agent status and report sections based on selected analysts.
@ -114,7 +115,7 @@ class MessageBuffer:
self.current_agent = None self.current_agent = None
self.messages.clear() self.messages.clear()
self.tool_calls.clear() self.tool_calls.clear()
self._last_message_id = None self._processed_message_ids.clear()
def get_completed_reports_count(self): def get_completed_reports_count(self):
"""Count reports that are finalized (their finalizing agent is completed). """Count reports that are finalized (their finalizing agent is completed).
@ -519,10 +520,19 @@ def get_user_selections():
) )
analysis_date = get_analysis_date() analysis_date = get_analysis_date()
# Step 3: Select analysts # Step 3: Output language
console.print( console.print(
create_question_box( create_question_box(
"Step 3: Analysts Team", "Select your LLM analyst agents for the analysis" "Step 3: Output Language",
"Select the language for analyst reports and final decision"
)
)
output_language = ask_output_language()
# Step 4: Select analysts
console.print(
create_question_box(
"Step 4: Analysts Team", "Select your LLM analyst agents for the analysis"
) )
) )
selected_analysts = select_analysts() selected_analysts = select_analysts()
@ -530,32 +540,32 @@ def get_user_selections():
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}" f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
) )
# Step 4: Research depth # Step 5: Research depth
console.print( console.print(
create_question_box( create_question_box(
"Step 4: Research Depth", "Select your research depth level" "Step 5: Research Depth", "Select your research depth level"
) )
) )
selected_research_depth = select_research_depth() selected_research_depth = select_research_depth()
# Step 5: OpenAI backend # Step 6: LLM Provider
console.print( console.print(
create_question_box( create_question_box(
"Step 5: OpenAI backend", "Select which service to talk to" "Step 6: LLM Provider", "Select your LLM provider"
) )
) )
selected_llm_provider, backend_url = select_llm_provider() selected_llm_provider, backend_url = select_llm_provider()
# Step 6: Thinking agents # Step 7: Thinking agents
console.print( console.print(
create_question_box( create_question_box(
"Step 6: Thinking Agents", "Select your thinking agents for analysis" "Step 7: Thinking Agents", "Select your thinking agents for analysis"
) )
) )
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider) selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
# Step 7: Provider-specific thinking configuration # Step 8: Provider-specific thinking configuration
thinking_level = None thinking_level = None
reasoning_effort = None reasoning_effort = None
anthropic_effort = None anthropic_effort = None
@ -564,7 +574,7 @@ def get_user_selections():
if provider_lower == "google": if provider_lower == "google":
console.print( console.print(
create_question_box( create_question_box(
"Step 7: Thinking Mode", "Step 8: Thinking Mode",
"Configure Gemini thinking mode" "Configure Gemini thinking mode"
) )
) )
@ -572,7 +582,7 @@ def get_user_selections():
elif provider_lower == "openai": elif provider_lower == "openai":
console.print( console.print(
create_question_box( create_question_box(
"Step 7: Reasoning Effort", "Step 8: Reasoning Effort",
"Configure OpenAI reasoning effort level" "Configure OpenAI reasoning effort level"
) )
) )
@ -580,7 +590,7 @@ def get_user_selections():
elif provider_lower == "anthropic": elif provider_lower == "anthropic":
console.print( console.print(
create_question_box( create_question_box(
"Step 7: Effort Level", "Step 8: Effort Level",
"Configure Claude effort level" "Configure Claude effort level"
) )
) )
@ -598,6 +608,7 @@ def get_user_selections():
"google_thinking_level": thinking_level, "google_thinking_level": thinking_level,
"openai_reasoning_effort": reasoning_effort, "openai_reasoning_effort": reasoning_effort,
"anthropic_effort": anthropic_effort, "anthropic_effort": anthropic_effort,
"output_language": output_language,
} }
@ -931,6 +942,7 @@ def run_analysis():
config["google_thinking_level"] = selections.get("google_thinking_level") config["google_thinking_level"] = selections.get("google_thinking_level")
config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort") config["openai_reasoning_effort"] = selections.get("openai_reasoning_effort")
config["anthropic_effort"] = selections.get("anthropic_effort") config["anthropic_effort"] = selections.get("anthropic_effort")
config["output_language"] = selections.get("output_language", "English")
# Create stats callback handler for tracking LLM/tool calls # Create stats callback handler for tracking LLM/tool calls
stats_handler = StatsCallbackHandler() stats_handler = StatsCallbackHandler()
@ -1041,28 +1053,24 @@ def run_analysis():
# Stream the analysis # Stream the analysis
trace = [] trace = []
for chunk in graph.graph.stream(init_agent_state, **args): for chunk in graph.graph.stream(init_agent_state, **args):
# Process messages if present (skip duplicates via message ID) # Process all messages in chunk, deduplicating by message ID
if len(chunk["messages"]) > 0: for message in chunk.get("messages", []):
last_message = chunk["messages"][-1] msg_id = getattr(message, "id", None)
msg_id = getattr(last_message, "id", None) if msg_id is not None:
if msg_id in message_buffer._processed_message_ids:
continue
message_buffer._processed_message_ids.add(msg_id)
if msg_id != message_buffer._last_message_id: msg_type, content = classify_message_type(message)
message_buffer._last_message_id = msg_id if content and content.strip():
message_buffer.add_message(msg_type, content)
# Add message to buffer if hasattr(message, "tool_calls") and message.tool_calls:
msg_type, content = classify_message_type(last_message) for tool_call in message.tool_calls:
if content and content.strip(): if isinstance(tool_call, dict):
message_buffer.add_message(msg_type, content) message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
else:
# Handle tool calls message_buffer.add_tool_call(tool_call.name, tool_call.args)
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict):
message_buffer.add_tool_call(
tool_call["name"], tool_call["args"]
)
else:
message_buffer.add_tool_call(tool_call.name, tool_call.args)
# Update analyst statuses based on report state (runs on every chunk) # Update analyst statuses based on report state (runs on every chunk)
update_analyst_statuses(message_buffer, chunk) update_analyst_statuses(message_buffer, chunk)

View File

@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Dict
from rich.console import Console from rich.console import Console
from cli.models import AnalystType from cli.models import AnalystType
from tradingagents.llm_clients.model_catalog import get_model_options
console = Console() console = Console()
@ -133,51 +134,70 @@ def select_research_depth() -> int:
return choice return choice
def select_shallow_thinking_agent(provider) -> str: def _fetch_openrouter_models() -> List[Tuple[str, str]]:
"""Select shallow thinking llm engine using an interactive selection.""" """Fetch available models from the OpenRouter API."""
import requests
try:
resp = requests.get("https://openrouter.ai/api/v1/models", timeout=10)
resp.raise_for_status()
models = resp.json().get("data", [])
return [(m.get("name") or m["id"], m["id"]) for m in models]
except Exception as e:
console.print(f"\n[yellow]Could not fetch OpenRouter models: {e}[/yellow]")
return []
# Define shallow thinking llm engine options with their corresponding model names
# Ordering: medium → light → heavy (balanced first for quick tasks) def select_openrouter_model() -> str:
# Within same tier, newer models first """Select an OpenRouter model from the newest available, or enter a custom ID."""
SHALLOW_AGENT_OPTIONS = { models = _fetch_openrouter_models()
"openai": [
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"), choices = [questionary.Choice(name, value=mid) for name, mid in models[:5]]
("GPT-5 Nano - High-throughput, simple tasks", "gpt-5-nano"), choices.append(questionary.Choice("Custom model ID", value="custom"))
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"),
],
"anthropic": [
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"),
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
],
"google": [
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"),
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
],
"xai": [
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
],
"openrouter": [
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
],
"ollama": [
("Qwen3:latest (8B, local)", "qwen3:latest"),
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
],
}
choice = questionary.select( choice = questionary.select(
"Select Your [Quick-Thinking LLM Engine]:", "Select OpenRouter Model (latest available):",
choices=choices,
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style([
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]),
).ask()
if choice is None or choice == "custom":
return questionary.text(
"Enter OpenRouter model ID (e.g. google/gemma-4-26b-a4b-it):",
validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
).ask().strip()
return choice
def _prompt_custom_model_id() -> str:
"""Prompt user to type a custom model ID."""
return questionary.text(
"Enter model ID:",
validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
).ask().strip()
def _select_model(provider: str, mode: str) -> str:
"""Select a model for the given provider and mode (quick/deep)."""
if provider.lower() == "openrouter":
return select_openrouter_model()
if provider.lower() == "azure":
return questionary.text(
f"Enter Azure deployment name ({mode}-thinking):",
validate=lambda x: len(x.strip()) > 0 or "Please enter a deployment name.",
).ask().strip()
choice = questionary.select(
f"Select Your [{mode.title()}-Thinking LLM Engine]:",
choices=[ choices=[
questionary.Choice(display, value=value) questionary.Choice(display, value=value)
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()] for display, value in get_model_options(provider, mode)
], ],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select", instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style( style=questionary.Style(
@ -190,95 +210,45 @@ def select_shallow_thinking_agent(provider) -> str:
).ask() ).ask()
if choice is None: if choice is None:
console.print( console.print(f"\n[red]No {mode} thinking llm engine selected. Exiting...[/red]")
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
)
exit(1) exit(1)
if choice == "custom":
return _prompt_custom_model_id()
return choice return choice
def select_shallow_thinking_agent(provider) -> str:
"""Select shallow thinking llm engine using an interactive selection."""
return _select_model(provider, "quick")
def select_deep_thinking_agent(provider) -> str: def select_deep_thinking_agent(provider) -> str:
"""Select deep thinking llm engine using an interactive selection.""" """Select deep thinking llm engine using an interactive selection."""
return _select_model(provider, "deep")
# Define deep thinking llm engine options with their corresponding model names def select_llm_provider() -> tuple[str, str | None]:
# Ordering: heavy → medium → light (most capable first for deep tasks) """Select the LLM provider and its API endpoint."""
# Within same tier, newer models first # (display_name, provider_key, base_url)
DEEP_AGENT_OPTIONS = { PROVIDERS = [
"openai": [ ("OpenAI", "openai", "https://api.openai.com/v1"),
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"), ("Google", "google", None),
("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"), ("Anthropic", "anthropic", "https://api.anthropic.com/"),
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"), ("xAI", "xai", "https://api.x.ai/v1"),
("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"), ("DeepSeek", "deepseek", "https://api.deepseek.com"),
], ("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
"anthropic": [ ("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"), ("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"), ("Azure OpenAI", "azure", None),
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"), ("Ollama", "ollama", "http://localhost:11434/v1"),
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
],
"google": [
("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"),
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"),
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
],
"xai": [
("Grok 4 - Flagship model", "grok-4-0709"),
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
],
"openrouter": [
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
],
"ollama": [
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("Qwen3:latest (8B, local)", "qwen3:latest"),
],
}
choice = questionary.select(
"Select Your [Deep-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()
if choice is None:
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
exit(1)
return choice
def select_llm_provider() -> tuple[str, str]:
"""Select the OpenAI api url using interactive selection."""
# Define OpenAI api options with their corresponding endpoints
BASE_URLS = [
("OpenAI", "https://api.openai.com/v1"),
("Google", "https://generativelanguage.googleapis.com/v1"),
("Anthropic", "https://api.anthropic.com/"),
("xAI", "https://api.x.ai/v1"),
("Openrouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"),
] ]
choice = questionary.select( choice = questionary.select(
"Select your LLM Provider:", "Select your LLM Provider:",
choices=[ choices=[
questionary.Choice(display, value=(display, value)) questionary.Choice(display, value=(provider_key, url))
for display, value in BASE_URLS for display, provider_key, url in PROVIDERS
], ],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select", instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style( style=questionary.Style(
@ -291,13 +261,11 @@ def select_llm_provider() -> tuple[str, str]:
).ask() ).ask()
if choice is None: if choice is None:
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]") console.print("\n[red]No LLM provider selected. Exiting...[/red]")
exit(1) exit(1)
display_name, url = choice
print(f"You selected: {display_name}\tURL: {url}")
return display_name, url provider, url = choice
return provider, url
def ask_openai_reasoning_effort() -> str: def ask_openai_reasoning_effort() -> str:
@ -356,3 +324,37 @@ def ask_gemini_thinking_config() -> str | None:
("pointer", "fg:green noinherit"), ("pointer", "fg:green noinherit"),
]), ]),
).ask() ).ask()
def ask_output_language() -> str:
"""Ask for report output language."""
choice = questionary.select(
"Select Output Language:",
choices=[
questionary.Choice("English (default)", "English"),
questionary.Choice("Chinese (中文)", "Chinese"),
questionary.Choice("Japanese (日本語)", "Japanese"),
questionary.Choice("Korean (한국어)", "Korean"),
questionary.Choice("Hindi (हिन्दी)", "Hindi"),
questionary.Choice("Spanish (Español)", "Spanish"),
questionary.Choice("Portuguese (Português)", "Portuguese"),
questionary.Choice("French (Français)", "French"),
questionary.Choice("German (Deutsch)", "German"),
questionary.Choice("Arabic (العربية)", "Arabic"),
questionary.Choice("Russian (Русский)", "Russian"),
questionary.Choice("Custom language", "custom"),
],
style=questionary.Style([
("selected", "fg:yellow noinherit"),
("highlighted", "fg:yellow noinherit"),
("pointer", "fg:yellow noinherit"),
]),
).ask()
if choice == "custom":
return questionary.text(
"Enter language name (e.g. Turkish, Vietnamese, Thai, Indonesian):",
validate=lambda x: len(x.strip()) > 0 or "Please enter a language name.",
).ask().strip()
return choice

35
docker-compose.yml Normal file
View File

@ -0,0 +1,35 @@
services:
tradingagents:
build: .
env_file:
- .env
volumes:
- tradingagents_data:/home/appuser/.tradingagents
tty: true
stdin_open: true
ollama:
image: ollama/ollama:latest
volumes:
- ollama_data:/root/.ollama
profiles:
- ollama
tradingagents-ollama:
build: .
env_file:
- .env
environment:
- LLM_PROVIDER=ollama
volumes:
- tradingagents_data:/home/appuser/.tradingagents
depends_on:
- ollama
tty: true
stdin_open: true
profiles:
- ollama
volumes:
tradingagents_data:
ollama_data:

View File

@ -8,8 +8,8 @@ load_dotenv()
# Create a custom config # Create a custom config
config = DEFAULT_CONFIG.copy() config = DEFAULT_CONFIG.copy()
config["deep_think_llm"] = "gpt-5-mini" # Use a different model config["deep_think_llm"] = "gpt-5.4-mini" # Use a different model
config["quick_think_llm"] = "gpt-5-mini" # Use a different model config["quick_think_llm"] = "gpt-5.4-mini" # Use a different model
config["max_debate_rounds"] = 1 # Increase debate rounds config["max_debate_rounds"] = 1 # Increase debate rounds
# Configure data vendors (default uses yfinance, no extra API keys needed) # Configure data vendors (default uses yfinance, no extra API keys needed)

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "tradingagents" name = "tradingagents"
version = "0.2.2" version = "0.2.3"
description = "TradingAgents: Multi-Agents LLM Financial Trading Framework" description = "TradingAgents: Multi-Agents LLM Financial Trading Framework"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
@ -13,7 +13,7 @@ dependencies = [
"backtrader>=1.9.78.123", "backtrader>=1.9.78.123",
"langchain-anthropic>=0.3.15", "langchain-anthropic>=0.3.15",
"langchain-experimental>=0.3.4", "langchain-experimental>=0.3.4",
"langchain-google-genai>=2.1.5", "langchain-google-genai>=4.0.0",
"langchain-openai>=0.3.23", "langchain-openai>=0.3.23",
"langgraph>=0.4.8", "langgraph>=0.4.8",
"pandas>=2.3.0", "pandas>=2.3.0",

View File

@ -0,0 +1,28 @@
import unittest
from unittest.mock import patch
from tradingagents.llm_clients.google_client import GoogleClient
class TestGoogleApiKeyStandardization(unittest.TestCase):
"""Verify GoogleClient accepts unified api_key parameter."""
@patch("tradingagents.llm_clients.google_client.NormalizedChatGoogleGenerativeAI")
def test_api_key_handling(self, mock_chat):
test_cases = [
("unified api_key is mapped", {"api_key": "test-key-123"}, "test-key-123"),
("legacy google_api_key still works", {"google_api_key": "legacy-key-456"}, "legacy-key-456"),
("unified api_key takes precedence", {"api_key": "unified", "google_api_key": "legacy"}, "unified"),
]
for msg, kwargs, expected_key in test_cases:
with self.subTest(msg=msg):
mock_chat.reset_mock()
client = GoogleClient("gemini-2.5-flash", **kwargs)
client.get_llm()
call_kwargs = mock_chat.call_args[1]
self.assertEqual(call_kwargs.get("google_api_key"), expected_key)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,52 @@
import unittest
import warnings
from tradingagents.llm_clients.base_client import BaseLLMClient
from tradingagents.llm_clients.model_catalog import get_known_models
from tradingagents.llm_clients.validators import validate_model
class DummyLLMClient(BaseLLMClient):
def __init__(self, provider: str, model: str):
self.provider = provider
super().__init__(model)
def get_llm(self):
self.warn_if_unknown_model()
return object()
def validate_model(self) -> bool:
return validate_model(self.provider, self.model)
class ModelValidationTests(unittest.TestCase):
def test_cli_catalog_models_are_all_validator_approved(self):
for provider, models in get_known_models().items():
if provider in ("ollama", "openrouter"):
continue
for model in models:
with self.subTest(provider=provider, model=model):
self.assertTrue(validate_model(provider, model))
def test_unknown_model_emits_warning_for_strict_provider(self):
client = DummyLLMClient("openai", "not-a-real-openai-model")
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
client.get_llm()
self.assertEqual(len(caught), 1)
self.assertIn("not-a-real-openai-model", str(caught[0].message))
self.assertIn("openai", str(caught[0].message))
def test_openrouter_and_ollama_accept_custom_models_without_warning(self):
for provider in ("openrouter", "ollama"):
client = DummyLLMClient(provider, "custom-model-name")
with self.subTest(provider=provider):
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
client.get_llm()
self.assertEqual(caught, [])

View File

@ -1,6 +1,4 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import ( from tradingagents.agents.utils.agent_utils import (
build_instrument_context, build_instrument_context,
get_balance_sheet, get_balance_sheet,
@ -8,6 +6,7 @@ from tradingagents.agents.utils.agent_utils import (
get_fundamentals, get_fundamentals,
get_income_statement, get_income_statement,
get_insider_transactions, get_insider_transactions,
get_language_instruction,
) )
from tradingagents.dataflows.config import get_config from tradingagents.dataflows.config import get_config
@ -27,7 +26,8 @@ def create_fundamentals_analyst(llm):
system_message = ( system_message = (
"You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, and company financial history to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Provide specific, actionable insights with supporting evidence to help traders make informed decisions." "You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, and company financial history to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
+ " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read." + " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."
+ " Use the available tools: `get_fundamentals` for comprehensive company analysis, `get_balance_sheet`, `get_cashflow`, and `get_income_statement` for specific financial statements.", + " Use the available tools: `get_fundamentals` for comprehensive company analysis, `get_balance_sheet`, `get_cashflow`, and `get_income_statement` for specific financial statements."
+ get_language_instruction(),
) )
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(

View File

@ -1,9 +1,8 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import ( from tradingagents.agents.utils.agent_utils import (
build_instrument_context, build_instrument_context,
get_indicators, get_indicators,
get_language_instruction,
get_stock_data, get_stock_data,
) )
from tradingagents.dataflows.config import get_config from tradingagents.dataflows.config import get_config
@ -47,6 +46,7 @@ Volume-Based Indicators:
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. Write a very detailed and nuanced report of the trends you observe. Provide specific, actionable insights with supporting evidence to help traders make informed decisions.""" - Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. Write a very detailed and nuanced report of the trends you observe. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."""
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
+ get_language_instruction()
) )
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(

View File

@ -1,9 +1,8 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import ( from tradingagents.agents.utils.agent_utils import (
build_instrument_context, build_instrument_context,
get_global_news, get_global_news,
get_language_instruction,
get_news, get_news,
) )
from tradingagents.dataflows.config import get_config from tradingagents.dataflows.config import get_config
@ -22,6 +21,7 @@ def create_news_analyst(llm):
system_message = ( system_message = (
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Use the available tools: get_news(query, start_date, end_date) for company-specific or targeted news searches, and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions." "You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Use the available tools: get_news(query, start_date, end_date) for company-specific or targeted news searches, and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
+ get_language_instruction()
) )
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(

View File

@ -1,7 +1,5 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
import json
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_news
from tradingagents.dataflows.config import get_config from tradingagents.dataflows.config import get_config
@ -17,6 +15,7 @@ def create_social_media_analyst(llm):
system_message = ( system_message = (
"You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions." "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Provide specific, actionable insights with supporting evidence to help traders make informed decisions."
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
+ get_language_instruction()
) )
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(

View File

@ -1,4 +1,4 @@
from tradingagents.agents.utils.agent_utils import build_instrument_context from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction
def create_portfolio_manager(llm, memory): def create_portfolio_manager(llm, memory):
@ -12,7 +12,8 @@ def create_portfolio_manager(llm, memory):
news_report = state["news_report"] news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"] fundamentals_report = state["fundamentals_report"]
sentiment_report = state["sentiment_report"] sentiment_report = state["sentiment_report"]
trader_plan = state["investment_plan"] research_plan = state["investment_plan"]
trader_plan = state["trader_investment_plan"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2) past_memories = memory.get_memories(curr_situation, n_matches=2)
@ -35,7 +36,8 @@ def create_portfolio_manager(llm, memory):
- **Sell**: Exit position or avoid entry - **Sell**: Exit position or avoid entry
**Context:** **Context:**
- Trader's proposed plan: **{trader_plan}** - Research Manager's investment plan: **{research_plan}**
- Trader's transaction proposal: **{trader_plan}**
- Lessons from past decisions: **{past_memory_str}** - Lessons from past decisions: **{past_memory_str}**
**Required Output Structure:** **Required Output Structure:**
@ -50,7 +52,7 @@ def create_portfolio_manager(llm, memory):
--- ---
Be decisive and ground every conclusion in specific evidence from the analysts.""" Be decisive and ground every conclusion in specific evidence from the analysts.{get_language_instruction()}"""
response = llm.invoke(prompt) response = llm.invoke(prompt)

View File

@ -1,5 +1,3 @@
import time
import json
from tradingagents.agents.utils.agent_utils import build_instrument_context from tradingagents.agents.utils.agent_utils import build_instrument_context

View File

@ -1,6 +1,3 @@
from langchain_core.messages import AIMessage
import time
import json
def create_bear_researcher(llm, memory): def create_bear_researcher(llm, memory):

View File

@ -1,6 +1,3 @@
from langchain_core.messages import AIMessage
import time
import json
def create_bull_researcher(llm, memory): def create_bull_researcher(llm, memory):

View File

@ -1,5 +1,3 @@
import time
import json
def create_aggressive_debator(llm): def create_aggressive_debator(llm):

View File

@ -1,6 +1,3 @@
from langchain_core.messages import AIMessage
import time
import json
def create_conservative_debator(llm): def create_conservative_debator(llm):

View File

@ -1,5 +1,3 @@
import time
import json
def create_neutral_debator(llm): def create_neutral_debator(llm):

View File

@ -1,6 +1,4 @@
import functools import functools
import time
import json
from tradingagents.agents.utils.agent_utils import build_instrument_context from tradingagents.agents.utils.agent_utils import build_instrument_context

View File

@ -1,10 +1,6 @@
from typing import Annotated, Sequence from typing import Annotated
from datetime import date, timedelta, datetime from typing_extensions import TypedDict
from typing_extensions import TypedDict, Optional from langgraph.graph import MessagesState
from langchain_openai import ChatOpenAI
from tradingagents.agents import *
from langgraph.prebuilt import ToolNode
from langgraph.graph import END, StateGraph, START, MessagesState
# Researcher team state # Researcher team state

View File

@ -20,6 +20,20 @@ from tradingagents.agents.utils.news_data_tools import (
) )
def get_language_instruction() -> str:
"""Return a prompt instruction for the configured output language.
Returns empty string when English (default), so no extra tokens are used.
Only applied to user-facing agents (analysts, portfolio manager).
Internal debate agents stay in English for reasoning quality.
"""
from tradingagents.dataflows.config import get_config
lang = get_config().get("output_language", "English")
if lang.strip().lower() == "english":
return ""
return f" Write your entire response in {lang}."
def build_instrument_context(ticker: str) -> str: def build_instrument_context(ticker: str) -> str:
"""Describe the exact instrument so agents preserve exchange-qualified tickers.""" """Describe the exact instrument so agents preserve exchange-qualified tickers."""
return ( return (

View File

@ -78,7 +78,7 @@ class FinancialSituationMemory:
# Build results # Build results
results = [] results = []
max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores max_score = float(scores.max()) if len(scores) > 0 and scores.max() > 0 else 1.0
for idx in top_indices: for idx in top_indices:
# Normalize score to 0-1 range for consistency # Normalize score to 0-1 range for consistency

View File

@ -22,10 +22,11 @@ def get_indicators(
""" """
# LLMs sometimes pass multiple indicators as a comma-separated string; # LLMs sometimes pass multiple indicators as a comma-separated string;
# split and process each individually. # split and process each individually.
indicators = [i.strip() for i in indicator.split(",") if i.strip()] indicators = [i.strip().lower() for i in indicator.split(",") if i.strip()]
if len(indicators) > 1: results = []
results = [] for ind in indicators:
for ind in indicators: try:
results.append(route_to_vendor("get_indicators", symbol, ind, curr_date, look_back_days)) results.append(route_to_vendor("get_indicators", symbol, ind, curr_date, look_back_days))
return "\n\n".join(results) except ValueError as e:
return route_to_vendor("get_indicators", symbol, indicator.strip(), curr_date, look_back_days) results.append(str(e))
return "\n\n".join(results)

View File

@ -1,6 +1,23 @@
from .alpha_vantage_common import _make_api_request from .alpha_vantage_common import _make_api_request
def _filter_reports_by_date(result, curr_date: str):
"""Filter annualReports/quarterlyReports to exclude entries after curr_date.
Prevents look-ahead bias by removing fiscal periods that end after
the simulation's current date.
"""
if not curr_date or not isinstance(result, dict):
return result
for key in ("annualReports", "quarterlyReports"):
if key in result:
result[key] = [
r for r in result[key]
if r.get("fiscalDateEnding", "") <= curr_date
]
return result
def get_fundamentals(ticker: str, curr_date: str = None) -> str: def get_fundamentals(ticker: str, curr_date: str = None) -> str:
""" """
Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage. Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage.
@ -19,59 +36,20 @@ def get_fundamentals(ticker: str, curr_date: str = None) -> str:
return _make_api_request("OVERVIEW", params) return _make_api_request("OVERVIEW", params)
def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None):
""" """Retrieve balance sheet data for a given ticker symbol using Alpha Vantage."""
Retrieve balance sheet data for a given ticker symbol using Alpha Vantage. result = _make_api_request("BALANCE_SHEET", {"symbol": ticker})
return _filter_reports_by_date(result, curr_date)
Args:
ticker (str): Ticker symbol of the company
freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage
curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage)
Returns:
str: Balance sheet data with normalized fields
"""
params = {
"symbol": ticker,
}
return _make_api_request("BALANCE_SHEET", params)
def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None):
""" """Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage."""
Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage. result = _make_api_request("CASH_FLOW", {"symbol": ticker})
return _filter_reports_by_date(result, curr_date)
Args:
ticker (str): Ticker symbol of the company
freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage
curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage)
Returns:
str: Cash flow statement data with normalized fields
"""
params = {
"symbol": ticker,
}
return _make_api_request("CASH_FLOW", params)
def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None):
""" """Retrieve income statement data for a given ticker symbol using Alpha Vantage."""
Retrieve income statement data for a given ticker symbol using Alpha Vantage. result = _make_api_request("INCOME_STATEMENT", {"symbol": ticker})
return _filter_reports_by_date(result, curr_date)
Args:
ticker (str): Ticker symbol of the company
freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage
curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage)
Returns:
str: Income statement data with normalized fields
"""
params = {
"symbol": ticker,
}
return _make_api_request("INCOME_STATEMENT", params)

View File

@ -44,6 +44,64 @@ def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
return data return data
def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
"""Fetch OHLCV data with caching, filtered to prevent look-ahead bias.
Downloads 15 years of data up to today and caches per symbol. On
subsequent calls the cache is reused. Rows after curr_date are
filtered out so backtests never see future prices.
"""
config = get_config()
curr_date_dt = pd.to_datetime(curr_date)
# Cache uses a fixed window (15y to today) so one file per symbol
today_date = pd.Timestamp.today()
start_date = today_date - pd.DateOffset(years=5)
start_str = start_date.strftime("%Y-%m-%d")
end_str = today_date.strftime("%Y-%m-%d")
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_str}-{end_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file, on_bad_lines="skip")
else:
data = yf_retry(lambda: yf.download(
symbol,
start=start_str,
end=end_str,
multi_level_index=False,
progress=False,
auto_adjust=True,
))
data = data.reset_index()
data.to_csv(data_file, index=False)
data = _clean_dataframe(data)
# Filter to curr_date to prevent look-ahead bias in backtesting
data = data[data["Date"] <= curr_date_dt]
return data
def filter_financials_by_date(data: pd.DataFrame, curr_date: str) -> pd.DataFrame:
"""Drop financial statement columns (fiscal period timestamps) after curr_date.
yfinance financial statements use fiscal period end dates as columns.
Columns after curr_date represent future data and are removed to
prevent look-ahead bias.
"""
if not curr_date or data.empty:
return data
cutoff = pd.Timestamp(curr_date)
mask = pd.to_datetime(data.columns, errors="coerce") <= cutoff
return data.loc[:, mask]
class StockstatsUtils: class StockstatsUtils:
@staticmethod @staticmethod
def get_stock_stats( def get_stock_stats(
@ -55,42 +113,10 @@ class StockstatsUtils:
str, "curr date for retrieving stock price data, YYYY-mm-dd" str, "curr date for retrieving stock price data, YYYY-mm-dd"
], ],
): ):
config = get_config() data = load_ohlcv(symbol, curr_date)
today_date = pd.Timestamp.today()
curr_date_dt = pd.to_datetime(curr_date)
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
# Ensure cache directory exists
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file, on_bad_lines="skip")
else:
data = yf_retry(lambda: yf.download(
symbol,
start=start_date_str,
end=end_date_str,
multi_level_index=False,
progress=False,
auto_adjust=True,
))
data = data.reset_index()
data.to_csv(data_file, index=False)
data = _clean_dataframe(data)
df = wrap(data) df = wrap(data)
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d") df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
curr_date_str = curr_date_dt.strftime("%Y-%m-%d") curr_date_str = pd.to_datetime(curr_date).strftime("%Y-%m-%d")
df[indicator] # trigger stockstats to calculate the indicator df[indicator] # trigger stockstats to calculate the indicator
matching_rows = df[df["Date"].str.startswith(curr_date_str)] matching_rows = df[df["Date"].str.startswith(curr_date_str)]

View File

@ -1,9 +1,10 @@
from typing import Annotated from typing import Annotated
from datetime import datetime from datetime import datetime
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
import pandas as pd
import yfinance as yf import yfinance as yf
import os import os
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry, load_ohlcv, filter_financials_by_date
def get_YFin_data_online( def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
@ -194,58 +195,9 @@ def _get_stock_stats_bulk(
Fetches data once and calculates indicator for all available dates. Fetches data once and calculates indicator for all available dates.
Returns dict mapping date strings to indicator values. Returns dict mapping date strings to indicator values.
""" """
from .config import get_config
import pandas as pd
from stockstats import wrap from stockstats import wrap
import os
config = get_config()
online = config["data_vendors"]["technical_indicators"] != "local"
if not online:
# Local data path
try:
data = pd.read_csv(
os.path.join(
config.get("data_cache_dir", "data"),
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
),
on_bad_lines="skip",
)
except FileNotFoundError:
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
else:
# Online data fetching with caching
today_date = pd.Timestamp.today()
curr_date_dt = pd.to_datetime(curr_date)
end_date = today_date data = load_ohlcv(symbol, curr_date)
start_date = today_date - pd.DateOffset(years=15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file, on_bad_lines="skip")
else:
data = yf_retry(lambda: yf.download(
symbol,
start=start_date_str,
end=end_date_str,
multi_level_index=False,
progress=False,
auto_adjust=True,
))
data = data.reset_index()
data.to_csv(data_file, index=False)
data = _clean_dataframe(data)
df = wrap(data) df = wrap(data)
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d") df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
@ -353,7 +305,7 @@ def get_fundamentals(
def get_balance_sheet( def get_balance_sheet(
ticker: Annotated[str, "ticker symbol of the company"], ticker: Annotated[str, "ticker symbol of the company"],
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
curr_date: Annotated[str, "current date (not used for yfinance)"] = None curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
): ):
"""Get balance sheet data from yfinance.""" """Get balance sheet data from yfinance."""
try: try:
@ -363,7 +315,9 @@ def get_balance_sheet(
data = yf_retry(lambda: ticker_obj.quarterly_balance_sheet) data = yf_retry(lambda: ticker_obj.quarterly_balance_sheet)
else: else:
data = yf_retry(lambda: ticker_obj.balance_sheet) data = yf_retry(lambda: ticker_obj.balance_sheet)
data = filter_financials_by_date(data, curr_date)
if data.empty: if data.empty:
return f"No balance sheet data found for symbol '{ticker}'" return f"No balance sheet data found for symbol '{ticker}'"
@ -383,7 +337,7 @@ def get_balance_sheet(
def get_cashflow( def get_cashflow(
ticker: Annotated[str, "ticker symbol of the company"], ticker: Annotated[str, "ticker symbol of the company"],
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
curr_date: Annotated[str, "current date (not used for yfinance)"] = None curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
): ):
"""Get cash flow data from yfinance.""" """Get cash flow data from yfinance."""
try: try:
@ -393,7 +347,9 @@ def get_cashflow(
data = yf_retry(lambda: ticker_obj.quarterly_cashflow) data = yf_retry(lambda: ticker_obj.quarterly_cashflow)
else: else:
data = yf_retry(lambda: ticker_obj.cashflow) data = yf_retry(lambda: ticker_obj.cashflow)
data = filter_financials_by_date(data, curr_date)
if data.empty: if data.empty:
return f"No cash flow data found for symbol '{ticker}'" return f"No cash flow data found for symbol '{ticker}'"
@ -413,7 +369,7 @@ def get_cashflow(
def get_income_statement( def get_income_statement(
ticker: Annotated[str, "ticker symbol of the company"], ticker: Annotated[str, "ticker symbol of the company"],
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
curr_date: Annotated[str, "current date (not used for yfinance)"] = None curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
): ):
"""Get income statement data from yfinance.""" """Get income statement data from yfinance."""
try: try:
@ -423,7 +379,9 @@ def get_income_statement(
data = yf_retry(lambda: ticker_obj.quarterly_income_stmt) data = yf_retry(lambda: ticker_obj.quarterly_income_stmt)
else: else:
data = yf_retry(lambda: ticker_obj.income_stmt) data = yf_retry(lambda: ticker_obj.income_stmt)
data = filter_financials_by_date(data, curr_date)
if data.empty: if data.empty:
return f"No income statement data found for symbol '{ticker}'" return f"No income statement data found for symbol '{ticker}'"

View File

@ -4,6 +4,8 @@ import yfinance as yf
from datetime import datetime from datetime import datetime
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
from .stockstats_utils import yf_retry
def _extract_article_data(article: dict) -> dict: def _extract_article_data(article: dict) -> dict:
"""Extract article data from yfinance news format (handles nested 'content' structure).""" """Extract article data from yfinance news format (handles nested 'content' structure)."""
@ -64,7 +66,7 @@ def get_news_yfinance(
""" """
try: try:
stock = yf.Ticker(ticker) stock = yf.Ticker(ticker)
news = stock.get_news(count=20) news = yf_retry(lambda: stock.get_news(count=20))
if not news: if not news:
return f"No news found for {ticker}" return f"No news found for {ticker}"
@ -131,11 +133,11 @@ def get_global_news_yfinance(
try: try:
for query in search_queries: for query in search_queries:
search = yf.Search( search = yf_retry(lambda q=query: yf.Search(
query=query, query=q,
news_count=limit, news_count=limit,
enable_fuzzy_query=True, enable_fuzzy_query=True,
) ))
if search.news: if search.news:
for article in search.news: for article in search.news:
@ -167,6 +169,11 @@ def get_global_news_yfinance(
# Handle both flat and nested structures # Handle both flat and nested structures
if "content" in article: if "content" in article:
data = _extract_article_data(article) data = _extract_article_data(article)
# Skip articles published after curr_date (look-ahead guard)
if data.get("pub_date"):
pub_naive = data["pub_date"].replace(tzinfo=None) if hasattr(data["pub_date"], "replace") else data["pub_date"]
if pub_naive > curr_dt + relativedelta(days=1):
continue
title = data["title"] title = data["title"]
publisher = data["publisher"] publisher = data["publisher"]
link = data["link"] link = data["link"]

View File

@ -1,21 +1,23 @@
import os import os
_TRADINGAGENTS_HOME = os.path.join(os.path.expanduser("~"), ".tradingagents")
DEFAULT_CONFIG = { DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
"data_cache_dir": os.path.join( "data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
),
# LLM settings # LLM settings
"llm_provider": "openai", "llm_provider": "openai",
"deep_think_llm": "gpt-5.2", "deep_think_llm": "gpt-5.4",
"quick_think_llm": "gpt-5-mini", "quick_think_llm": "gpt-5.4-mini",
"backend_url": "https://api.openai.com/v1", "backend_url": "https://api.openai.com/v1",
# Provider-specific thinking configuration # Provider-specific thinking configuration
"google_thinking_level": None, # "high", "minimal", etc. "google_thinking_level": None, # "high", "minimal", etc.
"openai_reasoning_effort": None, # "medium", "high", "low" "openai_reasoning_effort": None, # "medium", "high", "low"
"anthropic_effort": None, # "high", "medium", "low" "anthropic_effort": None, # "high", "medium", "low"
# Output language for analyst reports and final decision
# Internal agent debate stays in English for reasoning quality
"output_language": "English",
# Debate and discussion settings # Debate and discussion settings
"max_debate_rounds": 1, "max_debate_rounds": 1,
"max_risk_discuss_rounds": 1, "max_risk_discuss_rounds": 1,

View File

@ -1,13 +1,12 @@
# TradingAgents/graph/reflection.py # TradingAgents/graph/reflection.py
from typing import Dict, Any from typing import Any, Dict
from langchain_openai import ChatOpenAI
class Reflector: class Reflector:
"""Handles reflection on decisions and updating memory.""" """Handles reflection on decisions and updating memory."""
def __init__(self, quick_thinking_llm: ChatOpenAI): def __init__(self, quick_thinking_llm: Any):
"""Initialize the reflector with an LLM.""" """Initialize the reflector with an LLM."""
self.quick_thinking_llm = quick_thinking_llm self.quick_thinking_llm = quick_thinking_llm
self.reflection_system_prompt = self._get_reflection_prompt() self.reflection_system_prompt = self._get_reflection_prompt()

View File

@ -1,8 +1,7 @@
# TradingAgents/graph/setup.py # TradingAgents/graph/setup.py
from typing import Dict, Any from typing import Any, Dict
from langchain_openai import ChatOpenAI from langgraph.graph import END, START, StateGraph
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode from langgraph.prebuilt import ToolNode
from tradingagents.agents import * from tradingagents.agents import *
@ -16,8 +15,8 @@ class GraphSetup:
def __init__( def __init__(
self, self,
quick_thinking_llm: ChatOpenAI, quick_thinking_llm: Any,
deep_thinking_llm: ChatOpenAI, deep_thinking_llm: Any,
tool_nodes: Dict[str, ToolNode], tool_nodes: Dict[str, ToolNode],
bull_memory, bull_memory,
bear_memory, bear_memory,

View File

@ -1,12 +1,12 @@
# TradingAgents/graph/signal_processing.py # TradingAgents/graph/signal_processing.py
from langchain_openai import ChatOpenAI from typing import Any
class SignalProcessor: class SignalProcessor:
"""Processes trading signals to extract actionable decisions.""" """Processes trading signals to extract actionable decisions."""
def __init__(self, quick_thinking_llm: ChatOpenAI): def __init__(self, quick_thinking_llm: Any):
"""Initialize with an LLM for processing.""" """Initialize with an LLM for processing."""
self.quick_thinking_llm = quick_thinking_llm self.quick_thinking_llm = quick_thinking_llm

View File

@ -66,10 +66,8 @@ class TradingAgentsGraph:
set_config(self.config) set_config(self.config)
# Create necessary directories # Create necessary directories
os.makedirs( os.makedirs(self.config["data_cache_dir"], exist_ok=True)
os.path.join(self.config["project_dir"], "dataflows/data_cache"), os.makedirs(self.config["results_dir"], exist_ok=True)
exist_ok=True,
)
# Initialize LLMs with provider-specific thinking configuration # Initialize LLMs with provider-specific thinking configuration
llm_kwargs = self._get_provider_kwargs() llm_kwargs = self._get_provider_kwargs()
@ -259,15 +257,12 @@ class TradingAgentsGraph:
} }
# Save to file # Save to file
directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/") directory = Path(self.config["results_dir"]) / self.ticker / "TradingAgentsStrategy_logs"
directory.mkdir(parents=True, exist_ok=True) directory.mkdir(parents=True, exist_ok=True)
with open( log_path = directory / f"full_states_log_{trade_date}.json"
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json", with open(log_path, "w", encoding="utf-8") as f:
"w", json.dump(self.log_states_dict[str(trade_date)], f, indent=4)
encoding="utf-8",
) as f:
json.dump(self.log_states_dict, f, indent=4)
def reflect_and_remember(self, returns_losses): def reflect_and_remember(self, returns_losses):
"""Reflect on decisions and update memory based on returns.""" """Reflect on decisions and update memory based on returns."""

View File

@ -5,20 +5,11 @@
### 1. `validate_model()` is never called ### 1. `validate_model()` is never called
- Add validation call in `get_llm()` with warning (not error) for unknown models - Add validation call in `get_llm()` with warning (not error) for unknown models
### 2. Inconsistent parameter handling ### 2. ~~Inconsistent parameter handling~~ (Fixed)
| Client | API Key Param | Special Params | - GoogleClient now accepts unified `api_key` and maps it to `google_api_key`
|--------|---------------|----------------|
| OpenAI | `api_key` | `reasoning_effort` |
| Anthropic | `api_key` | `thinking_config``thinking` |
| Google | `google_api_key` | `thinking_budget` |
**Fix:** Standardize with unified `api_key` that maps to provider-specific keys ### 3. ~~`base_url` accepted but ignored~~ (Fixed)
- All clients now pass `base_url` to their respective LLM constructors
### 3. `base_url` accepted but ignored ### 4. ~~Update validators.py with models from CLI~~ (Fixed)
- `AnthropicClient`: accepts `base_url` but never uses it - Synced in v0.2.2
- `GoogleClient`: accepts `base_url` but never uses it (correct - Google doesn't support it)
**Fix:** Remove unused `base_url` from clients that don't support it
### 4. Update validators.py with models from CLI
- Sync `VALID_MODELS` dict with CLI model options after Feature 2 is complete

View File

@ -31,8 +31,12 @@ class AnthropicClient(BaseLLMClient):
def get_llm(self) -> Any: def get_llm(self) -> Any:
"""Return configured ChatAnthropic instance.""" """Return configured ChatAnthropic instance."""
self.warn_if_unknown_model()
llm_kwargs = {"model": self.model} llm_kwargs = {"model": self.model}
if self.base_url:
llm_kwargs["base_url"] = self.base_url
for key in _PASSTHROUGH_KWARGS: for key in _PASSTHROUGH_KWARGS:
if key in self.kwargs: if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key] llm_kwargs[key] = self.kwargs[key]

View File

@ -0,0 +1,52 @@
import os
from typing import Any, Optional
from langchain_openai import AzureChatOpenAI
from .base_client import BaseLLMClient, normalize_content
from .validators import validate_model
_PASSTHROUGH_KWARGS = (
"timeout", "max_retries", "api_key", "reasoning_effort",
"callbacks", "http_client", "http_async_client",
)
class NormalizedAzureChatOpenAI(AzureChatOpenAI):
"""AzureChatOpenAI with normalized content output."""
def invoke(self, input, config=None, **kwargs):
return normalize_content(super().invoke(input, config, **kwargs))
class AzureOpenAIClient(BaseLLMClient):
"""Client for Azure OpenAI deployments.
Requires environment variables:
AZURE_OPENAI_API_KEY: API key
AZURE_OPENAI_ENDPOINT: Endpoint URL (e.g. https://<resource>.openai.azure.com/)
AZURE_OPENAI_DEPLOYMENT_NAME: Deployment name
OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview)
"""
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
super().__init__(model, base_url, **kwargs)
def get_llm(self) -> Any:
"""Return configured AzureChatOpenAI instance."""
self.warn_if_unknown_model()
llm_kwargs = {
"model": self.model,
"azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", self.model),
}
for key in _PASSTHROUGH_KWARGS:
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
return NormalizedAzureChatOpenAI(**llm_kwargs)
def validate_model(self) -> bool:
"""Azure accepts any deployed model name."""
return True

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Optional from typing import Any, Optional
import warnings
def normalize_content(response): def normalize_content(response):
@ -29,6 +30,27 @@ class BaseLLMClient(ABC):
self.base_url = base_url self.base_url = base_url
self.kwargs = kwargs self.kwargs = kwargs
def get_provider_name(self) -> str:
"""Return the provider name used in warning messages."""
provider = getattr(self, "provider", None)
if provider:
return str(provider)
return self.__class__.__name__.removesuffix("Client").lower()
def warn_if_unknown_model(self) -> None:
"""Warn when the model is outside the known list for the provider."""
if self.validate_model():
return
warnings.warn(
(
f"Model '{self.model}' is not in the known model list for "
f"provider '{self.get_provider_name()}'. Continuing anyway."
),
RuntimeWarning,
stacklevel=2,
)
@abstractmethod @abstractmethod
def get_llm(self) -> Any: def get_llm(self) -> Any:
"""Return the configured LLM instance.""" """Return the configured LLM instance."""

View File

@ -4,6 +4,12 @@ from .base_client import BaseLLMClient
from .openai_client import OpenAIClient from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient from .anthropic_client import AnthropicClient
from .google_client import GoogleClient from .google_client import GoogleClient
from .azure_client import AzureOpenAIClient
# Providers that use the OpenAI-compatible chat completions API
_OPENAI_COMPATIBLE = (
"openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter",
)
def create_llm_client( def create_llm_client(
@ -15,16 +21,10 @@ def create_llm_client(
"""Create an LLM client for the specified provider. """Create an LLM client for the specified provider.
Args: Args:
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter) provider: LLM provider name
model: Model name/identifier model: Model name/identifier
base_url: Optional base URL for API endpoint base_url: Optional base URL for API endpoint
**kwargs: Additional provider-specific arguments **kwargs: Additional provider-specific arguments
- http_client: Custom httpx.Client for SSL proxy or certificate customization
- http_async_client: Custom httpx.AsyncClient for async operations
- timeout: Request timeout in seconds
- max_retries: Maximum retry attempts
- api_key: API key for the provider
- callbacks: LangChain callbacks
Returns: Returns:
Configured BaseLLMClient instance Configured BaseLLMClient instance
@ -34,16 +34,16 @@ def create_llm_client(
""" """
provider_lower = provider.lower() provider_lower = provider.lower()
if provider_lower in ("openai", "ollama", "openrouter"): if provider_lower in _OPENAI_COMPATIBLE:
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs) return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
if provider_lower == "xai":
return OpenAIClient(model, base_url, provider="xai", **kwargs)
if provider_lower == "anthropic": if provider_lower == "anthropic":
return AnthropicClient(model, base_url, **kwargs) return AnthropicClient(model, base_url, **kwargs)
if provider_lower == "google": if provider_lower == "google":
return GoogleClient(model, base_url, **kwargs) return GoogleClient(model, base_url, **kwargs)
if provider_lower == "azure":
return AzureOpenAIClient(model, base_url, **kwargs)
raise ValueError(f"Unsupported LLM provider: {provider}") raise ValueError(f"Unsupported LLM provider: {provider}")

View File

@ -25,12 +25,21 @@ class GoogleClient(BaseLLMClient):
def get_llm(self) -> Any: def get_llm(self) -> Any:
"""Return configured ChatGoogleGenerativeAI instance.""" """Return configured ChatGoogleGenerativeAI instance."""
self.warn_if_unknown_model()
llm_kwargs = {"model": self.model} llm_kwargs = {"model": self.model}
for key in ("timeout", "max_retries", "google_api_key", "callbacks", "http_client", "http_async_client"): if self.base_url:
llm_kwargs["base_url"] = self.base_url
for key in ("timeout", "max_retries", "callbacks", "http_client", "http_async_client"):
if key in self.kwargs: if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key] llm_kwargs[key] = self.kwargs[key]
# Unified api_key maps to provider-specific google_api_key
google_api_key = self.kwargs.get("api_key") or self.kwargs.get("google_api_key")
if google_api_key:
llm_kwargs["google_api_key"] = google_api_key
# Map thinking_level to appropriate API param based on model # Map thinking_level to appropriate API param based on model
# Gemini 3 Pro: low, high # Gemini 3 Pro: low, high
# Gemini 3 Flash: minimal, low, medium, high # Gemini 3 Flash: minimal, low, medium, high

View File

@ -0,0 +1,134 @@
"""Shared model catalog for CLI selections and validation."""
from __future__ import annotations
from typing import Dict, List, Tuple
ModelOption = Tuple[str, str]
ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]]
MODEL_OPTIONS: ProviderModeOptions = {
"openai": {
"quick": [
("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"),
("GPT-5.4 Nano - Cheapest, high-volume tasks", "gpt-5.4-nano"),
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"),
],
"deep": [
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"),
("GPT-5.4 Mini - Fast, strong coding and tool use", "gpt-5.4-mini"),
("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"),
],
},
"anthropic": {
"quick": [
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"),
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
],
"deep": [
("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"),
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
],
},
"google": {
"quick": [
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"),
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
],
"deep": [
("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"),
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"),
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
],
},
"xai": {
"quick": [
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
],
"deep": [
("Grok 4 - Flagship model", "grok-4-0709"),
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
],
},
"deepseek": {
"quick": [
("DeepSeek V3.2", "deepseek-chat"),
("Custom model ID", "custom"),
],
"deep": [
("DeepSeek V3.2 (thinking)", "deepseek-reasoner"),
("DeepSeek V3.2", "deepseek-chat"),
("Custom model ID", "custom"),
],
},
"qwen": {
"quick": [
("Qwen 3.5 Flash", "qwen3.5-flash"),
("Qwen Plus", "qwen-plus"),
("Custom model ID", "custom"),
],
"deep": [
("Qwen 3.6 Plus", "qwen3.6-plus"),
("Qwen 3.5 Plus", "qwen3.5-plus"),
("Qwen 3 Max", "qwen3-max"),
("Custom model ID", "custom"),
],
},
"glm": {
"quick": [
("GLM-4.7", "glm-4.7"),
("GLM-5", "glm-5"),
("Custom model ID", "custom"),
],
"deep": [
("GLM-5.1", "glm-5.1"),
("GLM-5", "glm-5"),
("Custom model ID", "custom"),
],
},
# OpenRouter: fetched dynamically. Azure: any deployed model name.
"ollama": {
"quick": [
("Qwen3:latest (8B, local)", "qwen3:latest"),
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
],
"deep": [
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("Qwen3:latest (8B, local)", "qwen3:latest"),
],
},
}
def get_model_options(provider: str, mode: str) -> List[ModelOption]:
"""Return shared model options for a provider and selection mode."""
return MODEL_OPTIONS[provider.lower()][mode]
def get_known_models() -> Dict[str, List[str]]:
"""Build known model names from the shared CLI catalog."""
return {
provider: sorted(
{
value
for options in mode_options.values()
for _, value in options
}
)
for provider, mode_options in MODEL_OPTIONS.items()
}

View File

@ -27,6 +27,9 @@ _PASSTHROUGH_KWARGS = (
# Provider base URLs and API key env vars # Provider base URLs and API key env vars
_PROVIDER_CONFIG = { _PROVIDER_CONFIG = {
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"), "xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
"deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"),
"qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"),
"glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"),
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"), "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
"ollama": ("http://localhost:11434/v1", None), "ollama": ("http://localhost:11434/v1", None),
} }
@ -53,6 +56,7 @@ class OpenAIClient(BaseLLMClient):
def get_llm(self) -> Any: def get_llm(self) -> Any:
"""Return configured ChatOpenAI instance.""" """Return configured ChatOpenAI instance."""
self.warn_if_unknown_model()
llm_kwargs = {"model": self.model} llm_kwargs = {"model": self.model}
# Provider-specific base URL and auth # Provider-specific base URL and auth

View File

@ -1,53 +1,12 @@
"""Model name validators for each provider. """Model name validators for each provider."""
from .model_catalog import get_known_models
Only validates model names - does NOT enforce limits.
Let LLM providers use their own defaults for unspecified params.
"""
VALID_MODELS = { VALID_MODELS = {
"openai": [ provider: models
# GPT-5 series for provider, models in get_known_models().items()
"gpt-5.4-pro", if provider not in ("ollama", "openrouter")
"gpt-5.4",
"gpt-5.2",
"gpt-5.1",
"gpt-5",
"gpt-5-mini",
"gpt-5-nano",
# GPT-4.1 series
"gpt-4.1",
"gpt-4.1-mini",
"gpt-4.1-nano",
],
"anthropic": [
# Claude 4.6 series (latest)
"claude-opus-4-6",
"claude-sonnet-4-6",
# Claude 4.5 series
"claude-opus-4-5",
"claude-sonnet-4-5",
"claude-haiku-4-5",
],
"google": [
# Gemini 3.1 series (preview)
"gemini-3.1-pro-preview",
"gemini-3.1-flash-lite-preview",
# Gemini 3 series (preview)
"gemini-3-flash-preview",
# Gemini 2.5 series
"gemini-2.5-pro",
"gemini-2.5-flash",
"gemini-2.5-flash-lite",
],
"xai": [
# Grok 4.1 series
"grok-4-1-fast-reasoning",
"grok-4-1-fast-non-reasoning",
# Grok 4 series
"grok-4-0709",
"grok-4-fast-reasoning",
"grok-4-fast-non-reasoning",
],
} }