Compare commits

...

11 Commits
v0.2.3 ... main

34 changed files with 340 additions and 149 deletions

15
.dockerignore Normal file
View File

@ -0,0 +1,15 @@
.git
.venv
.env
.claude
.idea
.vscode
.DS_Store
__pycache__
*.egg-info
build
dist
results
eval_results
Dockerfile
docker-compose.yml

5
.env.enterprise.example Normal file
View File

@ -0,0 +1,5 @@
# Azure OpenAI
AZURE_OPENAI_API_KEY=
AZURE_OPENAI_ENDPOINT=https://your-resource-name.openai.azure.com/
AZURE_OPENAI_DEPLOYMENT_NAME=
# OPENAI_API_VERSION=2024-10-21 # optional, required for non-v1 API

View File

@ -3,4 +3,7 @@ OPENAI_API_KEY=
GOOGLE_API_KEY= GOOGLE_API_KEY=
ANTHROPIC_API_KEY= ANTHROPIC_API_KEY=
XAI_API_KEY= XAI_API_KEY=
DEEPSEEK_API_KEY=
DASHSCOPE_API_KEY=
ZHIPU_API_KEY=
OPENROUTER_API_KEY= OPENROUTER_API_KEY=

27
Dockerfile Normal file
View File

@ -0,0 +1,27 @@
FROM python:3.12-slim AS builder
ENV PYTHONDONTWRITEBYTECODE=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
WORKDIR /build
COPY . .
RUN pip install --no-cache-dir .
FROM python:3.12-slim
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
COPY --from=builder /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
RUN useradd --create-home appuser
USER appuser
WORKDIR /home/appuser/app
COPY --from=builder --chown=appuser:appuser /build .
ENTRYPOINT ["tradingagents"]

View File

@ -118,6 +118,19 @@ Install the package and its dependencies:
pip install . pip install .
``` ```
### Docker
Alternatively, run with Docker:
```bash
cp .env.example .env # add your API keys
docker compose run --rm tradingagents
```
For local models with Ollama:
```bash
docker compose --profile ollama run --rm tradingagents-ollama
```
### Required APIs ### Required APIs
TradingAgents supports multiple LLM providers. Set the API key for your chosen provider: TradingAgents supports multiple LLM providers. Set the API key for your chosen provider:
@ -127,10 +140,15 @@ export OPENAI_API_KEY=... # OpenAI (GPT)
export GOOGLE_API_KEY=... # Google (Gemini) export GOOGLE_API_KEY=... # Google (Gemini)
export ANTHROPIC_API_KEY=... # Anthropic (Claude) export ANTHROPIC_API_KEY=... # Anthropic (Claude)
export XAI_API_KEY=... # xAI (Grok) export XAI_API_KEY=... # xAI (Grok)
export DEEPSEEK_API_KEY=... # DeepSeek
export DASHSCOPE_API_KEY=... # Qwen (Alibaba DashScope)
export ZHIPU_API_KEY=... # GLM (Zhipu)
export OPENROUTER_API_KEY=... # OpenRouter export OPENROUTER_API_KEY=... # OpenRouter
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
``` ```
For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials.
For local models, configure Ollama with `llm_provider: "ollama"` in your config. For local models, configure Ollama with `llm_provider: "ollama"` in your config.
Alternatively, copy `.env.example` to `.env` and fill in your keys: Alternatively, copy `.env.example` to `.env` and fill in your keys:

View File

@ -6,8 +6,9 @@ from functools import wraps
from rich.console import Console from rich.console import Console
from dotenv import load_dotenv from dotenv import load_dotenv
# Load environment variables from .env file # Load environment variables
load_dotenv() load_dotenv()
load_dotenv(".env.enterprise", override=False)
from rich.panel import Panel from rich.panel import Panel
from rich.spinner import Spinner from rich.spinner import Spinner
from rich.live import Live from rich.live import Live
@ -79,7 +80,7 @@ class MessageBuffer:
self.current_agent = None self.current_agent = None
self.report_sections = {} self.report_sections = {}
self.selected_analysts = [] self.selected_analysts = []
self._last_message_id = None self._processed_message_ids = set()
def init_for_analysis(self, selected_analysts): def init_for_analysis(self, selected_analysts):
"""Initialize agent status and report sections based on selected analysts. """Initialize agent status and report sections based on selected analysts.
@ -114,7 +115,7 @@ class MessageBuffer:
self.current_agent = None self.current_agent = None
self.messages.clear() self.messages.clear()
self.tool_calls.clear() self.tool_calls.clear()
self._last_message_id = None self._processed_message_ids.clear()
def get_completed_reports_count(self): def get_completed_reports_count(self):
"""Count reports that are finalized (their finalizing agent is completed). """Count reports that are finalized (their finalizing agent is completed).
@ -1052,26 +1053,22 @@ def run_analysis():
# Stream the analysis # Stream the analysis
trace = [] trace = []
for chunk in graph.graph.stream(init_agent_state, **args): for chunk in graph.graph.stream(init_agent_state, **args):
# Process messages if present (skip duplicates via message ID) # Process all messages in chunk, deduplicating by message ID
if len(chunk["messages"]) > 0: for message in chunk.get("messages", []):
last_message = chunk["messages"][-1] msg_id = getattr(message, "id", None)
msg_id = getattr(last_message, "id", None) if msg_id is not None:
if msg_id in message_buffer._processed_message_ids:
continue
message_buffer._processed_message_ids.add(msg_id)
if msg_id != message_buffer._last_message_id: msg_type, content = classify_message_type(message)
message_buffer._last_message_id = msg_id
# Add message to buffer
msg_type, content = classify_message_type(last_message)
if content and content.strip(): if content and content.strip():
message_buffer.add_message(msg_type, content) message_buffer.add_message(msg_type, content)
# Handle tool calls if hasattr(message, "tool_calls") and message.tool_calls:
if hasattr(last_message, "tool_calls") and last_message.tool_calls: for tool_call in message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict): if isinstance(tool_call, dict):
message_buffer.add_tool_call( message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
tool_call["name"], tool_call["args"]
)
else: else:
message_buffer.add_tool_call(tool_call.name, tool_call.args) message_buffer.add_tool_call(tool_call.name, tool_call.args)

View File

@ -134,14 +134,70 @@ def select_research_depth() -> int:
return choice return choice
def select_shallow_thinking_agent(provider) -> str: def _fetch_openrouter_models() -> List[Tuple[str, str]]:
"""Select shallow thinking llm engine using an interactive selection.""" """Fetch available models from the OpenRouter API."""
import requests
try:
resp = requests.get("https://openrouter.ai/api/v1/models", timeout=10)
resp.raise_for_status()
models = resp.json().get("data", [])
return [(m.get("name") or m["id"], m["id"]) for m in models]
except Exception as e:
console.print(f"\n[yellow]Could not fetch OpenRouter models: {e}[/yellow]")
return []
def select_openrouter_model() -> str:
"""Select an OpenRouter model from the newest available, or enter a custom ID."""
models = _fetch_openrouter_models()
choices = [questionary.Choice(name, value=mid) for name, mid in models[:5]]
choices.append(questionary.Choice("Custom model ID", value="custom"))
choice = questionary.select( choice = questionary.select(
"Select Your [Quick-Thinking LLM Engine]:", "Select OpenRouter Model (latest available):",
choices=choices,
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style([
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]),
).ask()
if choice is None or choice == "custom":
return questionary.text(
"Enter OpenRouter model ID (e.g. google/gemma-4-26b-a4b-it):",
validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
).ask().strip()
return choice
def _prompt_custom_model_id() -> str:
"""Prompt user to type a custom model ID."""
return questionary.text(
"Enter model ID:",
validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
).ask().strip()
def _select_model(provider: str, mode: str) -> str:
"""Select a model for the given provider and mode (quick/deep)."""
if provider.lower() == "openrouter":
return select_openrouter_model()
if provider.lower() == "azure":
return questionary.text(
f"Enter Azure deployment name ({mode}-thinking):",
validate=lambda x: len(x.strip()) > 0 or "Please enter a deployment name.",
).ask().strip()
choice = questionary.select(
f"Select Your [{mode.title()}-Thinking LLM Engine]:",
choices=[ choices=[
questionary.Choice(display, value=value) questionary.Choice(display, value=value)
for display, value in get_model_options(provider, "quick") for display, value in get_model_options(provider, mode)
], ],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select", instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style( style=questionary.Style(
@ -154,56 +210,45 @@ def select_shallow_thinking_agent(provider) -> str:
).ask() ).ask()
if choice is None: if choice is None:
console.print( console.print(f"\n[red]No {mode} thinking llm engine selected. Exiting...[/red]")
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
)
exit(1) exit(1)
if choice == "custom":
return _prompt_custom_model_id()
return choice return choice
def select_shallow_thinking_agent(provider) -> str:
"""Select shallow thinking llm engine using an interactive selection."""
return _select_model(provider, "quick")
def select_deep_thinking_agent(provider) -> str: def select_deep_thinking_agent(provider) -> str:
"""Select deep thinking llm engine using an interactive selection.""" """Select deep thinking llm engine using an interactive selection."""
return _select_model(provider, "deep")
choice = questionary.select( def select_llm_provider() -> tuple[str, str | None]:
"Select Your [Deep-Thinking LLM Engine]:", """Select the LLM provider and its API endpoint."""
choices=[ # (display_name, provider_key, base_url)
questionary.Choice(display, value=value) PROVIDERS = [
for display, value in get_model_options(provider, "deep") ("OpenAI", "openai", "https://api.openai.com/v1"),
], ("Google", "google", None),
instruction="\n- Use arrow keys to navigate\n- Press Enter to select", ("Anthropic", "anthropic", "https://api.anthropic.com/"),
style=questionary.Style( ("xAI", "xai", "https://api.x.ai/v1"),
[ ("DeepSeek", "deepseek", "https://api.deepseek.com"),
("selected", "fg:magenta noinherit"), ("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
("highlighted", "fg:magenta noinherit"), ("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
("pointer", "fg:magenta noinherit"), ("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
] ("Azure OpenAI", "azure", None),
), ("Ollama", "ollama", "http://localhost:11434/v1"),
).ask()
if choice is None:
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
exit(1)
return choice
def select_llm_provider() -> tuple[str, str]:
"""Select the OpenAI api url using interactive selection."""
# Define OpenAI api options with their corresponding endpoints
BASE_URLS = [
("OpenAI", "https://api.openai.com/v1"),
("Google", "https://generativelanguage.googleapis.com/v1"),
("Anthropic", "https://api.anthropic.com/"),
("xAI", "https://api.x.ai/v1"),
("Openrouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"),
] ]
choice = questionary.select( choice = questionary.select(
"Select your LLM Provider:", "Select your LLM Provider:",
choices=[ choices=[
questionary.Choice(display, value=(display, value)) questionary.Choice(display, value=(provider_key, url))
for display, value in BASE_URLS for display, provider_key, url in PROVIDERS
], ],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select", instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style( style=questionary.Style(
@ -216,13 +261,11 @@ def select_llm_provider() -> tuple[str, str]:
).ask() ).ask()
if choice is None: if choice is None:
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]") console.print("\n[red]No LLM provider selected. Exiting...[/red]")
exit(1) exit(1)
display_name, url = choice provider, url = choice
print(f"You selected: {display_name}\tURL: {url}") return provider, url
return display_name, url
def ask_openai_reasoning_effort() -> str: def ask_openai_reasoning_effort() -> str:

35
docker-compose.yml Normal file
View File

@ -0,0 +1,35 @@
services:
tradingagents:
build: .
env_file:
- .env
volumes:
- tradingagents_data:/home/appuser/.tradingagents
tty: true
stdin_open: true
ollama:
image: ollama/ollama:latest
volumes:
- ollama_data:/root/.ollama
profiles:
- ollama
tradingagents-ollama:
build: .
env_file:
- .env
environment:
- LLM_PROVIDER=ollama
volumes:
- tradingagents_data:/home/appuser/.tradingagents
depends_on:
- ollama
tty: true
stdin_open: true
profiles:
- ollama
volumes:
tradingagents_data:
ollama_data:

View File

@ -13,7 +13,7 @@ dependencies = [
"backtrader>=1.9.78.123", "backtrader>=1.9.78.123",
"langchain-anthropic>=0.3.15", "langchain-anthropic>=0.3.15",
"langchain-experimental>=0.3.4", "langchain-experimental>=0.3.4",
"langchain-google-genai>=2.1.5", "langchain-google-genai>=4.0.0",
"langchain-openai>=0.3.23", "langchain-openai>=0.3.23",
"langgraph>=0.4.8", "langgraph>=0.4.8",
"pandas>=2.3.0", "pandas>=2.3.0",

View File

@ -1,6 +1,4 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import ( from tradingagents.agents.utils.agent_utils import (
build_instrument_context, build_instrument_context,
get_balance_sheet, get_balance_sheet,

View File

@ -1,6 +1,4 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import ( from tradingagents.agents.utils.agent_utils import (
build_instrument_context, build_instrument_context,
get_indicators, get_indicators,

View File

@ -1,6 +1,4 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import ( from tradingagents.agents.utils.agent_utils import (
build_instrument_context, build_instrument_context,
get_global_news, get_global_news,

View File

@ -1,6 +1,4 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
from tradingagents.dataflows.config import get_config from tradingagents.dataflows.config import get_config

View File

@ -12,7 +12,8 @@ def create_portfolio_manager(llm, memory):
news_report = state["news_report"] news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"] fundamentals_report = state["fundamentals_report"]
sentiment_report = state["sentiment_report"] sentiment_report = state["sentiment_report"]
trader_plan = state["investment_plan"] research_plan = state["investment_plan"]
trader_plan = state["trader_investment_plan"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
past_memories = memory.get_memories(curr_situation, n_matches=2) past_memories = memory.get_memories(curr_situation, n_matches=2)
@ -35,7 +36,8 @@ def create_portfolio_manager(llm, memory):
- **Sell**: Exit position or avoid entry - **Sell**: Exit position or avoid entry
**Context:** **Context:**
- Trader's proposed plan: **{trader_plan}** - Research Manager's investment plan: **{research_plan}**
- Trader's transaction proposal: **{trader_plan}**
- Lessons from past decisions: **{past_memory_str}** - Lessons from past decisions: **{past_memory_str}**
**Required Output Structure:** **Required Output Structure:**

View File

@ -1,5 +1,3 @@
import time
import json
from tradingagents.agents.utils.agent_utils import build_instrument_context from tradingagents.agents.utils.agent_utils import build_instrument_context

View File

@ -1,6 +1,3 @@
from langchain_core.messages import AIMessage
import time
import json
def create_bear_researcher(llm, memory): def create_bear_researcher(llm, memory):

View File

@ -1,6 +1,3 @@
from langchain_core.messages import AIMessage
import time
import json
def create_bull_researcher(llm, memory): def create_bull_researcher(llm, memory):

View File

@ -1,5 +1,3 @@
import time
import json
def create_aggressive_debator(llm): def create_aggressive_debator(llm):

View File

@ -1,6 +1,3 @@
from langchain_core.messages import AIMessage
import time
import json
def create_conservative_debator(llm): def create_conservative_debator(llm):

View File

@ -1,5 +1,3 @@
import time
import json
def create_neutral_debator(llm): def create_neutral_debator(llm):

View File

@ -1,6 +1,4 @@
import functools import functools
import time
import json
from tradingagents.agents.utils.agent_utils import build_instrument_context from tradingagents.agents.utils.agent_utils import build_instrument_context

View File

@ -1,10 +1,6 @@
from typing import Annotated, Sequence from typing import Annotated
from datetime import date, timedelta, datetime from typing_extensions import TypedDict
from typing_extensions import TypedDict, Optional from langgraph.graph import MessagesState
from langchain_openai import ChatOpenAI
from tradingagents.agents import *
from langgraph.prebuilt import ToolNode
from langgraph.graph import END, StateGraph, START, MessagesState
# Researcher team state # Researcher team state

View File

@ -78,7 +78,7 @@ class FinancialSituationMemory:
# Build results # Build results
results = [] results = []
max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores max_score = float(scores.max()) if len(scores) > 0 and scores.max() > 0 else 1.0
for idx in top_indices: for idx in top_indices:
# Normalize score to 0-1 range for consistency # Normalize score to 0-1 range for consistency

View File

@ -22,7 +22,7 @@ def get_indicators(
""" """
# LLMs sometimes pass multiple indicators as a comma-separated string; # LLMs sometimes pass multiple indicators as a comma-separated string;
# split and process each individually. # split and process each individually.
indicators = [i.strip() for i in indicator.split(",") if i.strip()] indicators = [i.strip().lower() for i in indicator.split(",") if i.strip()]
results = [] results = []
for ind in indicators: for ind in indicators:
try: try:

View File

@ -1,6 +1,7 @@
from typing import Annotated from typing import Annotated
from datetime import datetime from datetime import datetime
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
import pandas as pd
import yfinance as yf import yfinance as yf
import os import os
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry, load_ohlcv, filter_financials_by_date from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry, load_ohlcv, filter_financials_by_date

View File

@ -1,12 +1,11 @@
import os import os
_TRADINGAGENTS_HOME = os.path.join(os.path.expanduser("~"), ".tradingagents")
DEFAULT_CONFIG = { DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
"data_cache_dir": os.path.join( "data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
),
# LLM settings # LLM settings
"llm_provider": "openai", "llm_provider": "openai",
"deep_think_llm": "gpt-5.4", "deep_think_llm": "gpt-5.4",

View File

@ -1,13 +1,12 @@
# TradingAgents/graph/reflection.py # TradingAgents/graph/reflection.py
from typing import Dict, Any from typing import Any, Dict
from langchain_openai import ChatOpenAI
class Reflector: class Reflector:
"""Handles reflection on decisions and updating memory.""" """Handles reflection on decisions and updating memory."""
def __init__(self, quick_thinking_llm: ChatOpenAI): def __init__(self, quick_thinking_llm: Any):
"""Initialize the reflector with an LLM.""" """Initialize the reflector with an LLM."""
self.quick_thinking_llm = quick_thinking_llm self.quick_thinking_llm = quick_thinking_llm
self.reflection_system_prompt = self._get_reflection_prompt() self.reflection_system_prompt = self._get_reflection_prompt()

View File

@ -1,8 +1,7 @@
# TradingAgents/graph/setup.py # TradingAgents/graph/setup.py
from typing import Dict, Any from typing import Any, Dict
from langchain_openai import ChatOpenAI from langgraph.graph import END, START, StateGraph
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode from langgraph.prebuilt import ToolNode
from tradingagents.agents import * from tradingagents.agents import *
@ -16,8 +15,8 @@ class GraphSetup:
def __init__( def __init__(
self, self,
quick_thinking_llm: ChatOpenAI, quick_thinking_llm: Any,
deep_thinking_llm: ChatOpenAI, deep_thinking_llm: Any,
tool_nodes: Dict[str, ToolNode], tool_nodes: Dict[str, ToolNode],
bull_memory, bull_memory,
bear_memory, bear_memory,

View File

@ -1,12 +1,12 @@
# TradingAgents/graph/signal_processing.py # TradingAgents/graph/signal_processing.py
from langchain_openai import ChatOpenAI from typing import Any
class SignalProcessor: class SignalProcessor:
"""Processes trading signals to extract actionable decisions.""" """Processes trading signals to extract actionable decisions."""
def __init__(self, quick_thinking_llm: ChatOpenAI): def __init__(self, quick_thinking_llm: Any):
"""Initialize with an LLM for processing.""" """Initialize with an LLM for processing."""
self.quick_thinking_llm = quick_thinking_llm self.quick_thinking_llm = quick_thinking_llm

View File

@ -66,10 +66,8 @@ class TradingAgentsGraph:
set_config(self.config) set_config(self.config)
# Create necessary directories # Create necessary directories
os.makedirs( os.makedirs(self.config["data_cache_dir"], exist_ok=True)
os.path.join(self.config["project_dir"], "dataflows/data_cache"), os.makedirs(self.config["results_dir"], exist_ok=True)
exist_ok=True,
)
# Initialize LLMs with provider-specific thinking configuration # Initialize LLMs with provider-specific thinking configuration
llm_kwargs = self._get_provider_kwargs() llm_kwargs = self._get_provider_kwargs()
@ -259,15 +257,12 @@ class TradingAgentsGraph:
} }
# Save to file # Save to file
directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/") directory = Path(self.config["results_dir"]) / self.ticker / "TradingAgentsStrategy_logs"
directory.mkdir(parents=True, exist_ok=True) directory.mkdir(parents=True, exist_ok=True)
with open( log_path = directory / f"full_states_log_{trade_date}.json"
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json", with open(log_path, "w", encoding="utf-8") as f:
"w", json.dump(self.log_states_dict[str(trade_date)], f, indent=4)
encoding="utf-8",
) as f:
json.dump(self.log_states_dict, f, indent=4)
def reflect_and_remember(self, returns_losses): def reflect_and_remember(self, returns_losses):
"""Reflect on decisions and update memory based on returns.""" """Reflect on decisions and update memory based on returns."""

View File

@ -0,0 +1,52 @@
import os
from typing import Any, Optional
from langchain_openai import AzureChatOpenAI
from .base_client import BaseLLMClient, normalize_content
from .validators import validate_model
_PASSTHROUGH_KWARGS = (
"timeout", "max_retries", "api_key", "reasoning_effort",
"callbacks", "http_client", "http_async_client",
)
class NormalizedAzureChatOpenAI(AzureChatOpenAI):
"""AzureChatOpenAI with normalized content output."""
def invoke(self, input, config=None, **kwargs):
return normalize_content(super().invoke(input, config, **kwargs))
class AzureOpenAIClient(BaseLLMClient):
"""Client for Azure OpenAI deployments.
Requires environment variables:
AZURE_OPENAI_API_KEY: API key
AZURE_OPENAI_ENDPOINT: Endpoint URL (e.g. https://<resource>.openai.azure.com/)
AZURE_OPENAI_DEPLOYMENT_NAME: Deployment name
OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview)
"""
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
super().__init__(model, base_url, **kwargs)
def get_llm(self) -> Any:
"""Return configured AzureChatOpenAI instance."""
self.warn_if_unknown_model()
llm_kwargs = {
"model": self.model,
"azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", self.model),
}
for key in _PASSTHROUGH_KWARGS:
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
return NormalizedAzureChatOpenAI(**llm_kwargs)
def validate_model(self) -> bool:
"""Azure accepts any deployed model name."""
return True

View File

@ -4,6 +4,12 @@ from .base_client import BaseLLMClient
from .openai_client import OpenAIClient from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient from .anthropic_client import AnthropicClient
from .google_client import GoogleClient from .google_client import GoogleClient
from .azure_client import AzureOpenAIClient
# Providers that use the OpenAI-compatible chat completions API
_OPENAI_COMPATIBLE = (
"openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter",
)
def create_llm_client( def create_llm_client(
@ -15,16 +21,10 @@ def create_llm_client(
"""Create an LLM client for the specified provider. """Create an LLM client for the specified provider.
Args: Args:
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter) provider: LLM provider name
model: Model name/identifier model: Model name/identifier
base_url: Optional base URL for API endpoint base_url: Optional base URL for API endpoint
**kwargs: Additional provider-specific arguments **kwargs: Additional provider-specific arguments
- http_client: Custom httpx.Client for SSL proxy or certificate customization
- http_async_client: Custom httpx.AsyncClient for async operations
- timeout: Request timeout in seconds
- max_retries: Maximum retry attempts
- api_key: API key for the provider
- callbacks: LangChain callbacks
Returns: Returns:
Configured BaseLLMClient instance Configured BaseLLMClient instance
@ -34,16 +34,16 @@ def create_llm_client(
""" """
provider_lower = provider.lower() provider_lower = provider.lower()
if provider_lower in ("openai", "ollama", "openrouter"): if provider_lower in _OPENAI_COMPATIBLE:
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs) return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
if provider_lower == "xai":
return OpenAIClient(model, base_url, provider="xai", **kwargs)
if provider_lower == "anthropic": if provider_lower == "anthropic":
return AnthropicClient(model, base_url, **kwargs) return AnthropicClient(model, base_url, **kwargs)
if provider_lower == "google": if provider_lower == "google":
return GoogleClient(model, base_url, **kwargs) return GoogleClient(model, base_url, **kwargs)
if provider_lower == "azure":
return AzureOpenAIClient(model, base_url, **kwargs)
raise ValueError(f"Unsupported LLM provider: {provider}") raise ValueError(f"Unsupported LLM provider: {provider}")

View File

@ -63,16 +63,43 @@ MODEL_OPTIONS: ProviderModeOptions = {
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"), ("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
], ],
}, },
"openrouter": { "deepseek": {
"quick": [ "quick": [
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), ("DeepSeek V3.2", "deepseek-chat"),
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), ("Custom model ID", "custom"),
], ],
"deep": [ "deep": [
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), ("DeepSeek V3.2 (thinking)", "deepseek-reasoner"),
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"), ("DeepSeek V3.2", "deepseek-chat"),
("Custom model ID", "custom"),
], ],
}, },
"qwen": {
"quick": [
("Qwen 3.5 Flash", "qwen3.5-flash"),
("Qwen Plus", "qwen-plus"),
("Custom model ID", "custom"),
],
"deep": [
("Qwen 3.6 Plus", "qwen3.6-plus"),
("Qwen 3.5 Plus", "qwen3.5-plus"),
("Qwen 3 Max", "qwen3-max"),
("Custom model ID", "custom"),
],
},
"glm": {
"quick": [
("GLM-4.7", "glm-4.7"),
("GLM-5", "glm-5"),
("Custom model ID", "custom"),
],
"deep": [
("GLM-5.1", "glm-5.1"),
("GLM-5", "glm-5"),
("Custom model ID", "custom"),
],
},
# OpenRouter: fetched dynamically. Azure: any deployed model name.
"ollama": { "ollama": {
"quick": [ "quick": [
("Qwen3:latest (8B, local)", "qwen3:latest"), ("Qwen3:latest (8B, local)", "qwen3:latest"),

View File

@ -27,6 +27,9 @@ _PASSTHROUGH_KWARGS = (
# Provider base URLs and API key env vars # Provider base URLs and API key env vars
_PROVIDER_CONFIG = { _PROVIDER_CONFIG = {
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"), "xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
"deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"),
"qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"),
"glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"),
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"), "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
"ollama": ("http://localhost:11434/v1", None), "ollama": ("http://localhost:11434/v1", None),
} }