Compare commits
5 Commits
23e41671e0
...
a9c4844105
| Author | SHA1 | Date |
|---|---|---|
|
|
a9c4844105 | |
|
|
fa4d01c23a | |
|
|
b0f6058299 | |
|
|
59d6b2152d | |
|
|
15b9f90ae2 |
|
|
@ -0,0 +1,5 @@
|
|||
# Azure OpenAI
|
||||
AZURE_OPENAI_API_KEY=
|
||||
AZURE_OPENAI_ENDPOINT=https://your-resource-name.openai.azure.com/
|
||||
AZURE_OPENAI_DEPLOYMENT_NAME=
|
||||
# OPENAI_API_VERSION=2024-10-21 # optional, required for non-v1 API
|
||||
|
|
@ -3,4 +3,7 @@ OPENAI_API_KEY=
|
|||
GOOGLE_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
XAI_API_KEY=
|
||||
DEEPSEEK_API_KEY=
|
||||
DASHSCOPE_API_KEY=
|
||||
ZHIPU_API_KEY=
|
||||
OPENROUTER_API_KEY=
|
||||
|
|
|
|||
|
|
@ -140,10 +140,15 @@ export OPENAI_API_KEY=... # OpenAI (GPT)
|
|||
export GOOGLE_API_KEY=... # Google (Gemini)
|
||||
export ANTHROPIC_API_KEY=... # Anthropic (Claude)
|
||||
export XAI_API_KEY=... # xAI (Grok)
|
||||
export DEEPSEEK_API_KEY=... # DeepSeek
|
||||
export DASHSCOPE_API_KEY=... # Qwen (Alibaba DashScope)
|
||||
export ZHIPU_API_KEY=... # GLM (Zhipu)
|
||||
export OPENROUTER_API_KEY=... # OpenRouter
|
||||
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
||||
```
|
||||
|
||||
For enterprise providers (e.g. Azure OpenAI, AWS Bedrock), copy `.env.enterprise.example` to `.env.enterprise` and fill in your credentials.
|
||||
|
||||
For local models, configure Ollama with `llm_provider: "ollama"` in your config.
|
||||
|
||||
Alternatively, copy `.env.example` to `.env` and fill in your keys:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,276 @@
|
|||
import backtrader as bt
|
||||
import pandas as pd
|
||||
import yfinance as yf
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from datetime import datetime, timedelta
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
class TradingAgentsStrategy(bt.Strategy):
|
||||
"""Strategy that uses TradingAgents for decision making"""
|
||||
|
||||
def __init__(self, trading_agent, ticker, backtest_config):
|
||||
self.trading_agent = trading_agent
|
||||
self.ticker = ticker
|
||||
self.backtest_config = backtest_config
|
||||
self.decisions = {}
|
||||
self.trade_count = 0
|
||||
self.data_checks = {}
|
||||
|
||||
def next(self):
|
||||
# Get current date
|
||||
current_date = self.datas[0].datetime.date(0)
|
||||
date_str = current_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Verify data range to avoid look-ahead bias
|
||||
self.verify_data_range(date_str)
|
||||
|
||||
# Get decision from TradingAgents
|
||||
if date_str not in self.decisions:
|
||||
print(f"Processing date: {date_str}")
|
||||
try:
|
||||
_, decision = self.trading_agent.propagate(self.ticker, date_str)
|
||||
self.decisions[date_str] = decision
|
||||
print(f"Decision: {decision}")
|
||||
except Exception as e:
|
||||
print(f"Error getting decision for {date_str}: {e}")
|
||||
self.decisions[date_str] = "HOLD"
|
||||
|
||||
decision = self.decisions[date_str]
|
||||
|
||||
# Execute trade based on decision
|
||||
if decision == "BUY" and not self.position:
|
||||
# Buy with 100% of available cash
|
||||
size = int(self.broker.getcash() / self.data.close[0])
|
||||
if size > 0:
|
||||
self.buy(size=size)
|
||||
self.trade_count += 1
|
||||
print(f"BUY {self.ticker} on {date_str} at ${self.data.close[0]:.2f}")
|
||||
|
||||
elif decision == "SELL" and self.position:
|
||||
# Sell all positions
|
||||
self.sell(size=self.position.size)
|
||||
self.trade_count += 1
|
||||
print(f"SELL {self.ticker} on {date_str} at ${self.data.close[0]:.2f}")
|
||||
|
||||
def verify_data_range(self, date_str):
|
||||
"""Verify that data range is correct to avoid look-ahead bias"""
|
||||
current_date = datetime.strptime(date_str, "%Y-%m-%d")
|
||||
|
||||
# Check if we already verified this date
|
||||
if date_str in self.data_checks:
|
||||
return
|
||||
|
||||
# Verify data feed doesn't contain future data
|
||||
data_end_date = self.datas[0].datetime.date(-1)
|
||||
if data_end_date > current_date:
|
||||
print(f"⚠️ Warning: Data feed contains future data beyond {date_str}")
|
||||
|
||||
self.data_checks[date_str] = True
|
||||
|
||||
def clean_cache():
|
||||
"""Clean cache to avoid look-ahead bias"""
|
||||
print("\n=== Cleaning cache to avoid look-ahead bias ===")
|
||||
|
||||
# Clean yfinance cache
|
||||
yfinance_cache = "yfinance_cache"
|
||||
if os.path.exists(yfinance_cache):
|
||||
shutil.rmtree(yfinance_cache)
|
||||
print(f"✓ Cleaned yfinance cache: {yfinance_cache}")
|
||||
|
||||
# Clean dataflows cache
|
||||
dataflows_cache = "dataflows/data_cache"
|
||||
if os.path.exists(dataflows_cache):
|
||||
shutil.rmtree(dataflows_cache)
|
||||
print(f"✓ Cleaned dataflows cache: {dataflows_cache}")
|
||||
|
||||
# Clean backtest results (optional)
|
||||
# backtest_results = "backtest_results"
|
||||
# if os.path.exists(backtest_results):
|
||||
# shutil.rmtree(backtest_results)
|
||||
# print(f"✓ Cleaned backtest results: {backtest_results}")
|
||||
|
||||
def run_backtest(ticker, start_date, end_date, initial_cash=100000, clean_cache_flag=True):
|
||||
"""Run backtest for a given ticker and date range"""
|
||||
|
||||
# Clean cache to avoid look-ahead bias
|
||||
if clean_cache_flag:
|
||||
clean_cache()
|
||||
|
||||
# Create Cerebro engine
|
||||
cerebro = bt.Cerebro()
|
||||
|
||||
# Set initial cash
|
||||
cerebro.broker.setcash(initial_cash)
|
||||
|
||||
# Add strategy
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["llm_provider"] = "openrouter"
|
||||
config["deep_think_llm"] = "deepseek/deepseek-chat"
|
||||
config["quick_think_llm"] = "openai/gpt-4o-mini"
|
||||
config["max_debate_rounds"] = 2
|
||||
|
||||
# Verify date range
|
||||
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
|
||||
if start_dt >= end_dt:
|
||||
raise ValueError("Start date must be before end date")
|
||||
|
||||
if end_dt > datetime.now():
|
||||
raise ValueError("End date cannot be in the future")
|
||||
|
||||
trading_agent = TradingAgentsGraph(debug=False, config=config)
|
||||
|
||||
backtest_config = {
|
||||
"ticker": ticker,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"initial_cash": initial_cash,
|
||||
"clean_cache": clean_cache_flag
|
||||
}
|
||||
|
||||
cerebro.addstrategy(TradingAgentsStrategy, trading_agent=trading_agent, ticker=ticker, backtest_config=backtest_config)
|
||||
|
||||
# Get historical data from yfinance
|
||||
print("\n=== Fetching historical data ===")
|
||||
data = yf.download(ticker, start=start_date, end=end_date)
|
||||
|
||||
# Verify data quality
|
||||
if data.empty:
|
||||
raise ValueError(f"No data found for {ticker} between {start_date} and {end_date}")
|
||||
|
||||
print(f"✓ Data fetched: {len(data)} trading days")
|
||||
print(f"✓ Date range: {data.index.min().date()} to {data.index.max().date()}")
|
||||
|
||||
# Convert to backtrader data feed
|
||||
data_feed = bt.feeds.PandasData(
|
||||
dataname=data,
|
||||
datetime=0,
|
||||
high=1,
|
||||
low=2,
|
||||
open=3,
|
||||
close=4,
|
||||
volume=5,
|
||||
openinterest=-1
|
||||
)
|
||||
|
||||
# Add data feed to cerebro
|
||||
cerebro.adddata(data_feed, name=ticker)
|
||||
|
||||
# Add analyzers
|
||||
cerebro.addanalyzer(bt.analyzers.SharpeRatio, _name='sharpe')
|
||||
cerebro.addanalyzer(bt.analyzers.DrawDown, _name='drawdown')
|
||||
cerebro.addanalyzer(bt.analyzers.TradeAnalyzer, _name='trades')
|
||||
cerebro.addanalyzer(bt.analyzers.AnnualReturn, _name='annual')
|
||||
cerebro.addanalyzer(bt.analyzers.Returns, _name='returns')
|
||||
cerebro.addanalyzer(bt.analyzers.PositionsValue, _name='positions')
|
||||
|
||||
# Run backtest
|
||||
print(f"\n=== Starting Backtest for {ticker} ===")
|
||||
print(f"Date range: {start_date} to {end_date}")
|
||||
print(f"Initial cash: ${initial_cash:.2f}")
|
||||
print(f"LLM Provider: {config['llm_provider']}")
|
||||
print(f"Models: Deep={config['deep_think_llm']}, Quick={config['quick_think_llm']}")
|
||||
|
||||
results = cerebro.run()
|
||||
|
||||
# Get results
|
||||
strategy = results[0]
|
||||
final_value = cerebro.broker.getvalue()
|
||||
total_return = ((final_value - initial_cash) / initial_cash) * 100
|
||||
|
||||
# Get analyzer results
|
||||
sharpe = strategy.analyzers.sharpe.get_analysis()
|
||||
drawdown = strategy.analyzers.drawdown.get_analysis()
|
||||
trades = strategy.analyzers.trades.get_analysis()
|
||||
annual = strategy.analyzers.annual.get_analysis()
|
||||
returns = strategy.analyzers.returns.get_analysis()
|
||||
|
||||
# Calculate additional metrics
|
||||
total_trades = trades.get('total', {}).get('total', 0)
|
||||
won_trades = trades.get('won', {}).get('total', 0)
|
||||
win_rate = won_trades / max(total_trades, 1) * 100
|
||||
|
||||
# Print results
|
||||
print(f"\n=== Backtest Results ===")
|
||||
print(f"Final portfolio value: ${final_value:.2f}")
|
||||
print(f"Total return: {total_return:.2f}%")
|
||||
print(f"Daily return: {returns.get('rnorm', 0) * 100:.4f}%")
|
||||
print(f"Sharpe Ratio: {sharpe.get('sharperatio', 'N/A'):.2f}")
|
||||
print(f"Max Drawdown: {drawdown.get('max', {}).get('drawdown', 'N/A'):.2f}%")
|
||||
print(f"Total trades: {total_trades}")
|
||||
print(f"Win rate: {win_rate:.2f}%")
|
||||
print(f"Average trade duration: {trades.get('len', {}).get('average', 'N/A'):.1f} days")
|
||||
|
||||
# Save results
|
||||
save_results(ticker, start_date, end_date, {
|
||||
"initial_cash": initial_cash,
|
||||
"final_value": final_value,
|
||||
"total_return": total_return,
|
||||
"daily_return": returns.get('rnorm', 0),
|
||||
"sharpe_ratio": sharpe.get('sharperatio', None),
|
||||
"max_drawdown": drawdown.get('max', {}).get('drawdown', None),
|
||||
"total_trades": total_trades,
|
||||
"won_trades": won_trades,
|
||||
"win_rate": win_rate,
|
||||
"average_trade_duration": trades.get('len', {}).get('average', None),
|
||||
"decisions": strategy.decisions,
|
||||
"config": backtest_config
|
||||
})
|
||||
|
||||
# Plot results
|
||||
print("\n=== Generating backtest chart ===")
|
||||
cerebro.plot(style='candlestick')
|
||||
|
||||
def save_results(ticker, start_date, end_date, results):
|
||||
"""Save backtest results to file"""
|
||||
results_dir = f"backtest_results/{ticker}/"
|
||||
os.makedirs(results_dir, exist_ok=True)
|
||||
|
||||
filename = f"backtest_{start_date}_{end_date}.json"
|
||||
filepath = os.path.join(results_dir, filename)
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=4, default=str)
|
||||
|
||||
print(f"Results saved to: {filepath}")
|
||||
|
||||
def run_multiple_backtests(ticker_list, start_date, end_date, initial_cash=100000):
|
||||
"""Run backtests for multiple tickers"""
|
||||
all_results = {}
|
||||
|
||||
for ticker in ticker_list:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Running backtest for {ticker}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
try:
|
||||
# Run backtest without cleaning cache for subsequent tickers
|
||||
clean_cache_flag = (ticker == ticker_list[0])
|
||||
run_backtest(ticker, start_date, end_date, initial_cash, clean_cache_flag)
|
||||
except Exception as e:
|
||||
print(f"Error running backtest for {ticker}: {e}")
|
||||
all_results[ticker] = {"error": str(e)}
|
||||
|
||||
return all_results
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Define parameters
|
||||
ticker = "NVDA"
|
||||
start_date = "2024-01-01"
|
||||
end_date = "2024-03-29"
|
||||
initial_cash = 100000
|
||||
|
||||
# Run backtest
|
||||
run_backtest(ticker, start_date, end_date, initial_cash)
|
||||
|
||||
# Example: Run multiple backtests
|
||||
# tickers = ["NVDA", "AAPL", "MSFT"]
|
||||
# run_multiple_backtests(tickers, start_date, end_date, initial_cash)
|
||||
43
cli/main.py
43
cli/main.py
|
|
@ -6,8 +6,9 @@ from functools import wraps
|
|||
from rich.console import Console
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
load_dotenv(".env.enterprise", override=False)
|
||||
from rich.panel import Panel
|
||||
from rich.spinner import Spinner
|
||||
from rich.live import Live
|
||||
|
|
@ -79,7 +80,7 @@ class MessageBuffer:
|
|||
self.current_agent = None
|
||||
self.report_sections = {}
|
||||
self.selected_analysts = []
|
||||
self._last_message_id = None
|
||||
self._processed_message_ids = set()
|
||||
|
||||
def init_for_analysis(self, selected_analysts):
|
||||
"""Initialize agent status and report sections based on selected analysts.
|
||||
|
|
@ -114,7 +115,7 @@ class MessageBuffer:
|
|||
self.current_agent = None
|
||||
self.messages.clear()
|
||||
self.tool_calls.clear()
|
||||
self._last_message_id = None
|
||||
self._processed_message_ids.clear()
|
||||
|
||||
def get_completed_reports_count(self):
|
||||
"""Count reports that are finalized (their finalizing agent is completed).
|
||||
|
|
@ -1052,28 +1053,24 @@ def run_analysis():
|
|||
# Stream the analysis
|
||||
trace = []
|
||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
||||
# Process messages if present (skip duplicates via message ID)
|
||||
if len(chunk["messages"]) > 0:
|
||||
last_message = chunk["messages"][-1]
|
||||
msg_id = getattr(last_message, "id", None)
|
||||
# Process all messages in chunk, deduplicating by message ID
|
||||
for message in chunk.get("messages", []):
|
||||
msg_id = getattr(message, "id", None)
|
||||
if msg_id is not None:
|
||||
if msg_id in message_buffer._processed_message_ids:
|
||||
continue
|
||||
message_buffer._processed_message_ids.add(msg_id)
|
||||
|
||||
if msg_id != message_buffer._last_message_id:
|
||||
message_buffer._last_message_id = msg_id
|
||||
msg_type, content = classify_message_type(message)
|
||||
if content and content.strip():
|
||||
message_buffer.add_message(msg_type, content)
|
||||
|
||||
# Add message to buffer
|
||||
msg_type, content = classify_message_type(last_message)
|
||||
if content and content.strip():
|
||||
message_buffer.add_message(msg_type, content)
|
||||
|
||||
# Handle tool calls
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(
|
||||
tool_call["name"], tool_call["args"]
|
||||
)
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
|
||||
# Update analyst statuses based on report state (runs on every chunk)
|
||||
update_analyst_statuses(message_buffer, chunk)
|
||||
|
|
|
|||
92
cli/utils.py
92
cli/utils.py
|
|
@ -174,17 +174,30 @@ def select_openrouter_model() -> str:
|
|||
return choice
|
||||
|
||||
|
||||
def select_shallow_thinking_agent(provider) -> str:
|
||||
"""Select shallow thinking llm engine using an interactive selection."""
|
||||
def _prompt_custom_model_id() -> str:
|
||||
"""Prompt user to type a custom model ID."""
|
||||
return questionary.text(
|
||||
"Enter model ID:",
|
||||
validate=lambda x: len(x.strip()) > 0 or "Please enter a model ID.",
|
||||
).ask().strip()
|
||||
|
||||
|
||||
def _select_model(provider: str, mode: str) -> str:
|
||||
"""Select a model for the given provider and mode (quick/deep)."""
|
||||
if provider.lower() == "openrouter":
|
||||
return select_openrouter_model()
|
||||
|
||||
if provider.lower() == "azure":
|
||||
return questionary.text(
|
||||
f"Enter Azure deployment name ({mode}-thinking):",
|
||||
validate=lambda x: len(x.strip()) > 0 or "Please enter a deployment name.",
|
||||
).ask().strip()
|
||||
|
||||
choice = questionary.select(
|
||||
"Select Your [Quick-Thinking LLM Engine]:",
|
||||
f"Select Your [{mode.title()}-Thinking LLM Engine]:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value)
|
||||
for display, value in get_model_options(provider, "quick")
|
||||
for display, value in get_model_options(provider, mode)
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
|
|
@ -197,58 +210,45 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||
).ask()
|
||||
|
||||
if choice is None:
|
||||
console.print(
|
||||
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
|
||||
)
|
||||
console.print(f"\n[red]No {mode} thinking llm engine selected. Exiting...[/red]")
|
||||
exit(1)
|
||||
|
||||
if choice == "custom":
|
||||
return _prompt_custom_model_id()
|
||||
|
||||
return choice
|
||||
|
||||
|
||||
def select_shallow_thinking_agent(provider) -> str:
|
||||
"""Select shallow thinking llm engine using an interactive selection."""
|
||||
return _select_model(provider, "quick")
|
||||
|
||||
|
||||
def select_deep_thinking_agent(provider) -> str:
|
||||
"""Select deep thinking llm engine using an interactive selection."""
|
||||
|
||||
if provider.lower() == "openrouter":
|
||||
return select_openrouter_model()
|
||||
|
||||
choice = questionary.select(
|
||||
"Select Your [Deep-Thinking LLM Engine]:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value)
|
||||
for display, value in get_model_options(provider, "deep")
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
[
|
||||
("selected", "fg:magenta noinherit"),
|
||||
("highlighted", "fg:magenta noinherit"),
|
||||
("pointer", "fg:magenta noinherit"),
|
||||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
if choice is None:
|
||||
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
|
||||
exit(1)
|
||||
|
||||
return choice
|
||||
return _select_model(provider, "deep")
|
||||
|
||||
def select_llm_provider() -> tuple[str, str | None]:
|
||||
"""Select the LLM provider and its API endpoint."""
|
||||
BASE_URLS = [
|
||||
("OpenAI", "https://api.openai.com/v1"),
|
||||
("Google", None), # google-genai SDK manages its own endpoint
|
||||
("Anthropic", "https://api.anthropic.com/"),
|
||||
("xAI", "https://api.x.ai/v1"),
|
||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
||||
("Ollama", "http://localhost:11434/v1"),
|
||||
# (display_name, provider_key, base_url)
|
||||
PROVIDERS = [
|
||||
("OpenAI", "openai", "https://api.openai.com/v1"),
|
||||
("Google", "google", None),
|
||||
("Anthropic", "anthropic", "https://api.anthropic.com/"),
|
||||
("xAI", "xai", "https://api.x.ai/v1"),
|
||||
("DeepSeek", "deepseek", "https://api.deepseek.com"),
|
||||
("Qwen", "qwen", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
|
||||
("GLM", "glm", "https://open.bigmodel.cn/api/paas/v4/"),
|
||||
("OpenRouter", "openrouter", "https://openrouter.ai/api/v1"),
|
||||
("Azure OpenAI", "azure", None),
|
||||
("Ollama", "ollama", "http://localhost:11434/v1"),
|
||||
]
|
||||
|
||||
|
||||
choice = questionary.select(
|
||||
"Select your LLM Provider:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=(display, value))
|
||||
for display, value in BASE_URLS
|
||||
questionary.Choice(display, value=(provider_key, url))
|
||||
for display, provider_key, url in PROVIDERS
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
|
|
@ -261,13 +261,11 @@ def select_llm_provider() -> tuple[str, str | None]:
|
|||
).ask()
|
||||
|
||||
if choice is None:
|
||||
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
|
||||
console.print("\n[red]No LLM provider selected. Exiting...[/red]")
|
||||
exit(1)
|
||||
|
||||
display_name, url = choice
|
||||
print(f"You selected: {display_name}\tURL: {url}")
|
||||
|
||||
return display_name, url
|
||||
provider, url = choice
|
||||
return provider, url
|
||||
|
||||
|
||||
def ask_openai_reasoning_effort() -> str:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ services:
|
|||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
- ./results:/home/appuser/app/results
|
||||
- tradingagents_data:/home/appuser/.tradingagents
|
||||
tty: true
|
||||
stdin_open: true
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ services:
|
|||
environment:
|
||||
- LLM_PROVIDER=ollama
|
||||
volumes:
|
||||
- ./results:/home/appuser/app/results
|
||||
- tradingagents_data:/home/appuser/.tradingagents
|
||||
depends_on:
|
||||
- ollama
|
||||
tty: true
|
||||
|
|
@ -31,4 +31,5 @@ services:
|
|||
- ollama
|
||||
|
||||
volumes:
|
||||
tradingagents_data:
|
||||
ollama_data:
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ class FinancialSituationMemory:
|
|||
|
||||
# Build results
|
||||
results = []
|
||||
max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores
|
||||
max_score = float(scores.max()) if len(scores) > 0 and scores.max() > 0 else 1.0
|
||||
|
||||
for idx in top_indices:
|
||||
# Normalize score to 0-1 range for consistency
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
import os
|
||||
|
||||
_TRADINGAGENTS_HOME = os.path.join(os.path.expanduser("~"), ".tradingagents")
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||
"data_cache_dir": os.path.join(
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"dataflows/data_cache",
|
||||
),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
|
||||
"data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
|
||||
# LLM settings
|
||||
"llm_provider": "openai",
|
||||
"deep_think_llm": "gpt-5.4",
|
||||
|
|
|
|||
|
|
@ -66,10 +66,8 @@ class TradingAgentsGraph:
|
|||
set_config(self.config)
|
||||
|
||||
# Create necessary directories
|
||||
os.makedirs(
|
||||
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
|
||||
exist_ok=True,
|
||||
)
|
||||
os.makedirs(self.config["data_cache_dir"], exist_ok=True)
|
||||
os.makedirs(self.config["results_dir"], exist_ok=True)
|
||||
|
||||
# Initialize LLMs with provider-specific thinking configuration
|
||||
llm_kwargs = self._get_provider_kwargs()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,52 @@
|
|||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
|
||||
from .base_client import BaseLLMClient, normalize_content
|
||||
from .validators import validate_model
|
||||
|
||||
_PASSTHROUGH_KWARGS = (
|
||||
"timeout", "max_retries", "api_key", "reasoning_effort",
|
||||
"callbacks", "http_client", "http_async_client",
|
||||
)
|
||||
|
||||
|
||||
class NormalizedAzureChatOpenAI(AzureChatOpenAI):
|
||||
"""AzureChatOpenAI with normalized content output."""
|
||||
|
||||
def invoke(self, input, config=None, **kwargs):
|
||||
return normalize_content(super().invoke(input, config, **kwargs))
|
||||
|
||||
|
||||
class AzureOpenAIClient(BaseLLMClient):
|
||||
"""Client for Azure OpenAI deployments.
|
||||
|
||||
Requires environment variables:
|
||||
AZURE_OPENAI_API_KEY: API key
|
||||
AZURE_OPENAI_ENDPOINT: Endpoint URL (e.g. https://<resource>.openai.azure.com/)
|
||||
AZURE_OPENAI_DEPLOYMENT_NAME: Deployment name
|
||||
OPENAI_API_VERSION: API version (e.g. 2025-03-01-preview)
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(model, base_url, **kwargs)
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured AzureChatOpenAI instance."""
|
||||
self.warn_if_unknown_model()
|
||||
|
||||
llm_kwargs = {
|
||||
"model": self.model,
|
||||
"azure_deployment": os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME", self.model),
|
||||
}
|
||||
|
||||
for key in _PASSTHROUGH_KWARGS:
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
return NormalizedAzureChatOpenAI(**llm_kwargs)
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
"""Azure accepts any deployed model name."""
|
||||
return True
|
||||
|
|
@ -4,6 +4,12 @@ from .base_client import BaseLLMClient
|
|||
from .openai_client import OpenAIClient
|
||||
from .anthropic_client import AnthropicClient
|
||||
from .google_client import GoogleClient
|
||||
from .azure_client import AzureOpenAIClient
|
||||
|
||||
# Providers that use the OpenAI-compatible chat completions API
|
||||
_OPENAI_COMPATIBLE = (
|
||||
"openai", "xai", "deepseek", "qwen", "glm", "ollama", "openrouter",
|
||||
)
|
||||
|
||||
|
||||
def create_llm_client(
|
||||
|
|
@ -15,16 +21,10 @@ def create_llm_client(
|
|||
"""Create an LLM client for the specified provider.
|
||||
|
||||
Args:
|
||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
|
||||
provider: LLM provider name
|
||||
model: Model name/identifier
|
||||
base_url: Optional base URL for API endpoint
|
||||
**kwargs: Additional provider-specific arguments
|
||||
- http_client: Custom httpx.Client for SSL proxy or certificate customization
|
||||
- http_async_client: Custom httpx.AsyncClient for async operations
|
||||
- timeout: Request timeout in seconds
|
||||
- max_retries: Maximum retry attempts
|
||||
- api_key: API key for the provider
|
||||
- callbacks: LangChain callbacks
|
||||
|
||||
Returns:
|
||||
Configured BaseLLMClient instance
|
||||
|
|
@ -34,16 +34,16 @@ def create_llm_client(
|
|||
"""
|
||||
provider_lower = provider.lower()
|
||||
|
||||
if provider_lower in ("openai", "ollama", "openrouter"):
|
||||
if provider_lower in _OPENAI_COMPATIBLE:
|
||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||
|
||||
if provider_lower == "xai":
|
||||
return OpenAIClient(model, base_url, provider="xai", **kwargs)
|
||||
|
||||
if provider_lower == "anthropic":
|
||||
return AnthropicClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "google":
|
||||
return GoogleClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "azure":
|
||||
return AzureOpenAIClient(model, base_url, **kwargs)
|
||||
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
|
|
|||
|
|
@ -63,8 +63,43 @@ MODEL_OPTIONS: ProviderModeOptions = {
|
|||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||
],
|
||||
},
|
||||
# OpenRouter models are fetched dynamically at CLI runtime.
|
||||
# No static entries needed; any model ID is accepted by the validator.
|
||||
"deepseek": {
|
||||
"quick": [
|
||||
("DeepSeek V3.2", "deepseek-chat"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
"deep": [
|
||||
("DeepSeek V3.2 (thinking)", "deepseek-reasoner"),
|
||||
("DeepSeek V3.2", "deepseek-chat"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
},
|
||||
"qwen": {
|
||||
"quick": [
|
||||
("Qwen 3.5 Flash", "qwen3.5-flash"),
|
||||
("Qwen Plus", "qwen-plus"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
"deep": [
|
||||
("Qwen 3.6 Plus", "qwen3.6-plus"),
|
||||
("Qwen 3.5 Plus", "qwen3.5-plus"),
|
||||
("Qwen 3 Max", "qwen3-max"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
},
|
||||
"glm": {
|
||||
"quick": [
|
||||
("GLM-4.7", "glm-4.7"),
|
||||
("GLM-5", "glm-5"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
"deep": [
|
||||
("GLM-5.1", "glm-5.1"),
|
||||
("GLM-5", "glm-5"),
|
||||
("Custom model ID", "custom"),
|
||||
],
|
||||
},
|
||||
# OpenRouter: fetched dynamically. Azure: any deployed model name.
|
||||
"ollama": {
|
||||
"quick": [
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
|
|
|
|||
|
|
@ -27,6 +27,9 @@ _PASSTHROUGH_KWARGS = (
|
|||
# Provider base URLs and API key env vars
|
||||
_PROVIDER_CONFIG = {
|
||||
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
|
||||
"deepseek": ("https://api.deepseek.com", "DEEPSEEK_API_KEY"),
|
||||
"qwen": ("https://dashscope-intl.aliyuncs.com/compatible-mode/v1", "DASHSCOPE_API_KEY"),
|
||||
"glm": ("https://api.z.ai/api/paas/v4/", "ZHIPU_API_KEY"),
|
||||
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
|
||||
"ollama": ("http://localhost:11434/v1", None),
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue