M. Umar Jahangir 2026-01-19 22:23:02 -05:00 committed by GitHub
commit 41501a4e1b
19 changed files with 1476 additions and 71 deletions

View File

@ -1,2 +1,17 @@
# Data vendor API keys
ALPHA_VANTAGE_API_KEY=alpha_vantage_api_key_placeholder
# LLM Provider API keys (set the ones you want to use)
OPENAI_API_KEY=openai_api_key_placeholder
ANTHROPIC_API_KEY=anthropic_api_key_placeholder
GEMINI_API_KEY=gemini_api_key_placeholder
OPENROUTER_API_KEY=openrouter_api_key_placeholder
# Local LLM provider URLs (optional, defaults shown)
# OLLAMA_URL=http://localhost:11434
# LM_STUDIO_URL=http://localhost:1234
# Feature flags
# Set to "true" to fetch latest models from APIs and use latest web_search tool
# Set to "false" or leave unset for static model lists and web_search_preview (legacy)
FETCH_LATEST=true
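
For illustration, a minimal sketch of how variables like these are typically consumed at startup. This assumes the project loads .env via python-dotenv; the snippet itself is not part of this commit:

# Hypothetical startup sketch; python-dotenv is an assumption, not shown in this diff
import os
from dotenv import load_dotenv

load_dotenv()  # read key=value pairs from .env into the process environment
fetch_latest = os.getenv("FETCH_LATEST", "false").lower() in ("true", "1", "yes")
print("dynamic model fetching:", "on" if fetch_latest else "off")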

cli/api_keys.py (new file, 100 lines added)
View File

@ -0,0 +1,100 @@
"""API key and endpoint validation for LLM providers."""
import os
from typing import Optional, Tuple
import httpx
# Map cloud providers to their required environment variables
PROVIDER_API_KEYS = {
"openai": "OPENAI_API_KEY",
"anthropic": "ANTHROPIC_API_KEY",
"google": "GEMINI_API_KEY",
"openrouter": "OPENROUTER_API_KEY",
}
# Default endpoints for local providers
LOCAL_PROVIDER_DEFAULTS = {
"ollama": ("OLLAMA_URL", "http://localhost:11434"),
"lm studio": ("LM_STUDIO_URL", "http://localhost:1234"),
}
def get_api_key(provider: str) -> Optional[str]:
"""Get API key for a cloud provider, returns None if not set."""
provider_lower = provider.lower()
# Special case: OpenRouter can use OPENROUTER_API_KEY or OPENAI_API_KEY with sk-or- prefix
if provider_lower == "openrouter":
openrouter_key = os.getenv("OPENROUTER_API_KEY")
if openrouter_key:
return openrouter_key
# Check if OPENAI_API_KEY is actually an OpenRouter key
openai_key = os.getenv("OPENAI_API_KEY", "")
if openai_key.startswith("sk-or-"):
return openai_key
return None
env_var = PROVIDER_API_KEYS.get(provider_lower)
if env_var is None:
return None
return os.getenv(env_var)
def get_local_endpoint(provider: str) -> Optional[str]:
"""Get the endpoint URL for a local provider."""
provider_lower = provider.lower()
if provider_lower not in LOCAL_PROVIDER_DEFAULTS:
return None
env_var, default_url = LOCAL_PROVIDER_DEFAULTS[provider_lower]
return os.getenv(env_var, default_url)
def is_local_provider_running(provider: str) -> bool:
"""Check if a local provider (Ollama/LM Studio) is running by probing its endpoint."""
endpoint = get_local_endpoint(provider)
if not endpoint:
return False
try:
# Probe the models endpoint with a short timeout
response = httpx.get(
f"{endpoint}/v1/models",
timeout=1.0
)
return response.status_code == 200
except (httpx.RequestError, httpx.TimeoutException):
return False
def is_provider_available(provider: str) -> Tuple[bool, str]:
"""
Check if a provider is available.
Returns:
Tuple of (is_available, reason_if_unavailable)
"""
provider_lower = provider.lower()
# Local providers: check if endpoint is reachable
if provider_lower in LOCAL_PROVIDER_DEFAULTS:
if is_local_provider_running(provider):
return (True, "")
return (False, "Not running")
# Cloud providers: check for API key
if get_api_key(provider) is not None:
return (True, "")
return (False, "No API key")
def get_all_provider_availability() -> dict:
"""
Get availability status for all providers.
Returns:
Dict mapping provider name to (is_available, reason) tuple
"""
all_providers = list(PROVIDER_API_KEYS.keys()) + list(LOCAL_PROVIDER_DEFAULTS.keys())
return {provider: is_provider_available(provider) for provider in all_providers}
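
A short usage sketch of the helpers above; the printed values are illustrative and depend on which environment variables are set and which local servers are running:

from cli.api_keys import is_provider_available, get_all_provider_availability

available, reason = is_provider_available("openai")
print(available, reason)  # e.g. False "No API key" when OPENAI_API_KEY is unset

for provider, (ok, why) in get_all_provider_availability().items():
    print(f"{provider}: {'available' if ok else why}")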

cli/compile_reports.py (new file, 501 lines added)
View File

@ -0,0 +1,501 @@
#!/usr/bin/env python3
"""
Compile all trading agent reports into a single consolidated PDF.
Creates a PDF with:
1. Summary table showing all symbols, their decisions, and analysis dates
2. Detailed reports for each symbol (in order specified by REPORT_ORDER)
Usage:
python cli/compile_reports.py # Compile all results into single PDF
python cli/compile_reports.py --output report.pdf # Custom output filename
python cli/compile_reports.py --date 2026-01-18 # Filter to specific date (auto-names output)
"""
import argparse
import re
import sys
from datetime import datetime
from pathlib import Path
import markdown2
from playwright.sync_api import sync_playwright
# Report order (top to bottom for each symbol's section)
REPORT_ORDER = [
("final_trade_decision.md", "Final Trade Decision"),
("trader_investment_plan.md", "Trader Investment Plan"),
("investment_plan.md", "Investment Plan"),
("fundamentals_report.md", "Fundamentals Analysis"),
("news_report.md", "News Analysis"),
("sentiment_report.md", "Sentiment Analysis"),
("market_report.md", "Market Analysis"),
]
# Clean GitHub-style markdown CSS
CSS_STYLES = """
@page {
size: A4;
margin: 0.75in;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Noto Sans', Helvetica, Arial, sans-serif;
font-size: 14px;
line-height: 1.6;
color: #24292f;
max-width: 100%;
margin: 0;
padding: 0;
}
h1 {
font-size: 2em;
font-weight: 600;
border-bottom: 1px solid #d0d7de;
padding-bottom: 0.3em;
margin-top: 24px;
margin-bottom: 16px;
}
h2 {
font-size: 1.5em;
font-weight: 600;
border-bottom: 1px solid #d0d7de;
padding-bottom: 0.3em;
margin-top: 24px;
margin-bottom: 16px;
}
h3 {
font-size: 1.25em;
font-weight: 600;
margin-top: 24px;
margin-bottom: 16px;
}
h4 {
font-size: 1em;
font-weight: 600;
margin-top: 24px;
margin-bottom: 16px;
}
p {
margin-top: 0;
margin-bottom: 16px;
}
ul, ol {
padding-left: 2em;
margin-top: 0;
margin-bottom: 16px;
}
li {
margin-bottom: 4px;
}
li + li {
margin-top: 4px;
}
table {
border-collapse: collapse;
width: 100%;
margin-top: 0;
margin-bottom: 16px;
}
th, td {
padding: 6px 13px;
border: 1px solid #d0d7de;
}
th {
background-color: #f6f8fa;
font-weight: 600;
}
tr:nth-child(2n) {
background-color: #f6f8fa;
}
hr {
border: 0;
border-top: 1px solid #d0d7de;
margin: 24px 0;
}
code {
background-color: rgba(175, 184, 193, 0.2);
padding: 0.2em 0.4em;
border-radius: 6px;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Menlo, Consolas, monospace;
font-size: 85%;
}
pre {
background-color: #f6f8fa;
padding: 16px;
border-radius: 6px;
overflow-x: auto;
margin-bottom: 16px;
font-size: 85%;
line-height: 1.45;
}
pre code {
padding: 0;
background: none;
font-size: 100%;
}
blockquote {
border-left: 0.25em solid #d0d7de;
padding: 0 1em;
margin: 0 0 16px 0;
color: #57606a;
}
strong {
font-weight: 600;
}
/* Decision color styling */
.decision-buy { color: #1a7f37; font-weight: 700; }
.decision-sell { color: #cf222e; font-weight: 700; }
.decision-hold { color: #9a6700; font-weight: 700; }
/* Symbol section - page break before each new symbol */
.symbol-section {
page-break-before: always;
}
.symbol-section:first-of-type {
page-break-before: avoid;
}
/* Report title styling */
.report-title {
color: #0969da;
font-size: 1.3em;
font-weight: 600;
margin-top: 32px;
margin-bottom: 16px;
padding-bottom: 8px;
border-bottom: 2px solid #0969da;
}
.report-title:first-of-type {
margin-top: 16px;
}
"""
def extract_decision(content: str) -> str:
"""Extract BUY/SELL/HOLD decision from final trade decision content."""
content_lower = content.lower()
patterns = [
r"recommendation[:\s]*\*{0,2}(buy|sell|hold)\*{0,2}",
r"\*{0,2}(buy|sell|hold)\*{0,2}[:\s]*recommendation",
r"final.*?decision[:\s]*\*{0,2}(buy|sell|hold)\*{0,2}",
r"recommend.*?(buy|sell|hold)",
r"action[:\s]*\*{0,2}(buy|sell|hold)\*{0,2}",
]
for pattern in patterns:
match = re.search(pattern, content_lower)
if match:
return match.group(1).upper()
buy_count = len(re.findall(r"\bbuy\b", content_lower))
sell_count = len(re.findall(r"\bsell\b", content_lower))
hold_count = len(re.findall(r"\bhold\b", content_lower))
max_count = max(buy_count, sell_count, hold_count)
if max_count > 0:
if sell_count == max_count:
return "SELL"
if buy_count == max_count:
return "BUY"
return "HOLD"
return "N/A"
def markdown_to_html(md_content: str) -> str:
"""Convert markdown to HTML with extras."""
return markdown2.markdown(
md_content,
extras=[
"tables",
"fenced-code-blocks",
"strike",
"task_list",
"cuddled-lists",
],
)
def find_all_reports(results_dir: Path, date_filter: str | None = None) -> list[dict]:
"""Find all report directories and extract their data.
Args:
results_dir: Path to the results directory
date_filter: Optional date string (YYYY-MM-DD) to filter reports
"""
all_reports = []
if not results_dir.exists():
return all_reports
for symbol_dir in sorted(results_dir.iterdir()):
if not symbol_dir.is_dir():
continue
symbol = symbol_dir.name
if symbol.startswith(".") or " " in symbol:
continue
for date_dir in sorted(symbol_dir.iterdir(), reverse=True):
if not date_dir.is_dir():
continue
date = date_dir.name
# Skip if date doesn't match filter
if date_filter and date != date_filter:
continue
reports_dir = date_dir / "reports"
if not reports_dir.exists():
continue
report_files = []
decision = "N/A"
for filename, title in REPORT_ORDER:
file_path = reports_dir / filename
if file_path.exists():
content = file_path.read_text(encoding="utf-8")
html_content = markdown_to_html(content)
report_files.append((filename, title, html_content))
if filename == "final_trade_decision.md":
decision = extract_decision(content)
if report_files:
all_reports.append({
"symbol": symbol,
"date": date,
"decision": decision,
"reports_dir": reports_dir,
"reports": report_files,
})
return all_reports
def build_html_document(all_reports: list[dict]) -> str:
"""Build complete HTML document with summary table and all reports."""
# Build summary table rows
summary_rows = []
for report_data in all_reports:
decision = report_data["decision"]
decision_class = f"decision-{decision.lower()}" if decision in ["BUY", "SELL", "HOLD"] else ""
summary_rows.append(f'''<tr>
<td><strong>{report_data["symbol"]}</strong></td>
<td>{report_data["date"]}</td>
<td class="{decision_class}">{decision}</td>
<td>{len(report_data["reports"])} reports</td>
</tr>''')
summary_table = "\n".join(summary_rows)
# Build symbol sections
symbol_sections = []
for report_data in all_reports:
symbol = report_data["symbol"]
date = report_data["date"]
decision = report_data["decision"]
decision_class = f"decision-{decision.lower()}" if decision in ["BUY", "SELL", "HOLD"] else ""
# Build report content - simple flowing structure
reports_html_parts = []
for _, title, html_content in report_data["reports"]:
reports_html_parts.append(f'''<div class="report-title">{title}</div>
{html_content}''')
reports_html = "\n".join(reports_html_parts)
symbol_sections.append(f'''<div class="symbol-section">
<h1>{symbol} Trading Analysis Report</h1>
<p><strong>Date:</strong> {date} &nbsp;|&nbsp; <strong>Recommendation:</strong> <span class="{decision_class}">{decision}</span></p>
<hr>
{reports_html}
</div>''')
all_symbols_html = "\n".join(symbol_sections)
generated_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
html = f'''<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Trading Analysis Report</title>
<style>
{CSS_STYLES}
</style>
</head>
<body>
<h1>Trading Analysis Report</h1>
<p><em>Generated: {generated_date}</em></p>
<h2>Summary of Recommendations</h2>
<table>
<thead>
<tr>
<th>Symbol</th>
<th>Analysis Date</th>
<th>Decision</th>
<th>Reports</th>
</tr>
</thead>
<tbody>
{summary_table}
</tbody>
</table>
<hr>
{all_symbols_html}
<hr>
<p><em>Report generated by TradingAgents</em></p>
</body>
</html>'''
return html
def compile_to_pdf(html_content: str, output_path: Path) -> bool:
"""Generate PDF from HTML using Playwright."""
try:
with sync_playwright() as p:
browser = p.chromium.launch()
page = browser.new_page()
page.set_content(html_content, wait_until="networkidle")
page.pdf(
path=str(output_path),
format="A4",
margin={
"top": "0.5in",
"bottom": "0.5in",
"left": "0.5in",
"right": "0.5in",
},
print_background=True,
)
browser.close()
return True
except Exception as e:
print(f"Error generating PDF: {e}")
return False
def main():
parser = argparse.ArgumentParser(
description="Compile all trading agent reports into a single consolidated PDF",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python cli/compile_reports.py
python cli/compile_reports.py --output my_report.pdf
python cli/compile_reports.py --date 2026-01-18
python cli/compile_reports.py --date 2026-01-18 --output custom.pdf
""",
)
parser.add_argument(
"--output", "-o",
default="./results/trading_analysis_report.pdf",
help="Output PDF filename (default: ./results/trading_analysis_report.pdf)",
)
parser.add_argument(
"--results-dir", "-r",
default="./results",
help="Results directory (default: ./results)",
)
parser.add_argument(
"--date", "-d",
help="Filter reports to a specific date (format: YYYY-MM-DD)",
)
args = parser.parse_args()
# Validate date format if provided
if args.date:
if not re.match(r"^\d{4}-\d{2}-\d{2}$", args.date):
print(f"Error: Invalid date format '{args.date}'. Expected YYYY-MM-DD")
sys.exit(1)
results_dir = Path(args.results_dir)
default_output = "./results/trading_analysis_report.pdf"
if not results_dir.exists():
print(f"Error: Results directory not found: {results_dir}")
sys.exit(1)
if args.date:
print(f"Scanning {results_dir} for reports on {args.date}...")
else:
print(f"Scanning {results_dir} for reports...")
all_reports = find_all_reports(results_dir, date_filter=args.date)
if not all_reports:
if args.date:
print(f"No reports found for date {args.date}")
else:
print("No reports found")
sys.exit(1)
print(f"Found {len(all_reports)} symbol analysis report(s):\n")
for report_data in all_reports:
decision_indicator = {
"BUY": "[BUY]",
"SELL": "[SELL]",
"HOLD": "[HOLD]",
}.get(report_data["decision"], "[N/A]")
print(f" {report_data['symbol']:6} | {report_data['date']} | {decision_indicator:6} | {len(report_data['reports'])} reports")
# Determine output path
if args.date and args.output == default_output:
# Generate dynamic filename from date + symbols (up to 5)
symbols = [r["symbol"] for r in all_reports[:5]]
symbols_str = "_".join(symbols)
output_path = Path(f"./results/trading_report_{args.date}_{symbols_str}.pdf")
else:
output_path = Path(args.output)
print("\nGenerating PDF...")
html_document = build_html_document(all_reports)
if compile_to_pdf(html_document, output_path):
print(f"\n+ PDF created: {output_path}")
else:
print("\n- Failed to create PDF")
sys.exit(1)
if __name__ == "__main__":
main()
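
To make the decision-extraction heuristic concrete, a few illustrative inputs and the value extract_decision returns for each (the strings are invented for the example):

from cli.compile_reports import extract_decision

print(extract_decision("Final decision: **BUY** with tight stops"))  # BUY (pattern match)
print(extract_decision("We recommend hold until earnings"))          # HOLD (pattern match)
print(extract_decision("Buy pressure faded; sell, sell, sell."))     # SELL (keyword-count fallback)
print(extract_decision("No clear action stated"))                    # N/A (no signal)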

View File

@ -429,10 +429,12 @@ def get_user_selections():
        box_content += f"\n[dim]Default: {default}[/dim]"
        return Panel(box_content, border_style="blue", padding=(1, 2))

    # Step 1: Ticker symbol(s)
    console.print(
        create_question_box(
            "Step 1: Ticker Symbol(s)",
            "Enter ticker symbol(s) to analyze (comma-separated for multiple)",
            "SPY",
        )
    )
    selected_ticker = get_ticker()
@ -475,13 +477,20 @@ def get_user_selections():
    )
    selected_llm_provider, backend_url = select_llm_provider()

    # Step 6: Quick-Thinking LLM Engine
    console.print(
        create_question_box(
            "Step 6: Quick-Thinking LLM Engine", "Select your quick-thinking model for fast operations"
        )
    )
    selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)

    # Step 7: Deep-Thinking LLM Engine
    console.print(
        create_question_box(
            "Step 7: Deep-Thinking LLM Engine", "Select your deep-thinking model for complex reasoning"
        )
    )
    selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)

    return {
@ -497,8 +506,11 @@ def get_user_selections():
def get_ticker():
    """Get ticker symbol(s) from user input. Supports comma-separated symbols."""
    raw_input = typer.prompt("", default="SPY")
    # Split by comma, strip whitespace, convert to uppercase
    symbols = [s.strip().upper() for s in raw_input.split(",") if s.strip()]
    return symbols if len(symbols) > 1 else symbols[0]


def get_analysis_date():
@ -736,6 +748,7 @@ def extract_content_string(content):
    return str(content)


def run_analysis():
    """Run analysis for one or more ticker symbols."""
    # First get all user selections
    selections = get_user_selections()
@ -748,13 +761,33 @@ def run_analysis():
config["backend_url"] = selections["backend_url"] config["backend_url"] = selections["backend_url"]
config["llm_provider"] = selections["llm_provider"].lower() config["llm_provider"] = selections["llm_provider"].lower()
# Initialize the graph # Normalize ticker(s) to list
tickers = selections["ticker"] if isinstance(selections["ticker"], list) else [selections["ticker"]]
# Initialize the graph once and reuse for all symbols
graph = TradingAgentsGraph( graph = TradingAgentsGraph(
[analyst.value for analyst in selections["analysts"]], config=config, debug=True [analyst.value for analyst in selections["analysts"]], config=config, debug=True
) )
for i, ticker in enumerate(tickers, 1):
if len(tickers) > 1:
console.print(f"\n[bold cyan]{'' * 50}[/bold cyan]")
console.print(f"[bold cyan] Analyzing {ticker} ({i}/{len(tickers)})[/bold cyan]")
console.print(f"[bold cyan]{'' * 50}[/bold cyan]\n")
run_single_analysis(ticker, selections, config, graph)
if i < len(tickers):
console.print(f"\n[dim]Moving to next symbol...[/dim]\n")
if len(tickers) > 1:
console.print(f"\n[bold green]Completed analysis for all {len(tickers)} symbols: {', '.join(tickers)}[/bold green]")
def run_single_analysis(ticker: str, selections: dict, config: dict, graph: TradingAgentsGraph):
"""Run analysis for a single ticker symbol."""
# Create result directory # Create result directory
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"] results_dir = Path(config["results_dir"]) / ticker / selections["analysis_date"]
results_dir.mkdir(parents=True, exist_ok=True) results_dir.mkdir(parents=True, exist_ok=True)
report_dir = results_dir / "reports" report_dir = results_dir / "reports"
report_dir.mkdir(parents=True, exist_ok=True) report_dir.mkdir(parents=True, exist_ok=True)
@ -808,7 +841,7 @@ def run_analysis():
        update_display(layout)

        # Add initial messages
        message_buffer.add_message("System", f"Selected ticker: {ticker}")
        message_buffer.add_message(
            "System", f"Analysis date: {selections['analysis_date']}"
        )
@ -835,13 +868,13 @@ def run_analysis():
        # Create spinner text
        spinner_text = (
            f"Analyzing {ticker} on {selections['analysis_date']}..."
        )
        update_display(layout, spinner_text)

        # Initialize state and get graph args
        init_agent_state = graph.propagator.create_initial_state(
            ticker, selections["analysis_date"]
        )
        args = graph.propagator.get_graph_args()
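
As a quick illustration of the new comma-separated ticker parsing in get_ticker() (inputs invented):

raw = "aapl, MSFT ,tsla"
symbols = [s.strip().upper() for s in raw.split(",") if s.strip()]
print(symbols)  # ['AAPL', 'MSFT', 'TSLA'] -> run_analysis() loops over all three

raw = "spy"
symbols = [s.strip().upper() for s in raw.split(",") if s.strip()]
print(symbols if len(symbols) > 1 else symbols[0])  # 'SPY' -> single-symbol path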

cli/model_fetcher.py (new file, 225 lines added)
View File

@ -0,0 +1,225 @@
"""Dynamic model fetching from LLM provider APIs with caching."""
import os
from typing import List, Tuple, Optional
import httpx
# Cache for fetched models (provider -> list of models)
_model_cache: dict = {}
# Maximum number of models to display (None = no limit, show all)
MAX_MODELS = None
def is_fetch_latest() -> bool:
"""Check if FETCH_LATEST is enabled in environment.
When enabled, fetches models dynamically from provider APIs.
When disabled, falls back to static hardcoded model lists.
"""
return os.getenv("FETCH_LATEST", "false").lower() in ("true", "1", "yes")
def fetch_openai_models() -> Optional[List[Tuple[str, str]]]:
"""
Fetch available models from OpenAI API, sorted by creation date (newest first).
Returns:
List of (display_name, model_id) tuples, or None on failure
"""
if "openai" in _model_cache:
return _model_cache["openai"]
api_key = os.getenv("OPENAI_API_KEY")
if not api_key or api_key.startswith("sk-or-"):
return None
try:
response = httpx.get(
"https://api.openai.com/v1/models",
headers={"Authorization": f"Bearer {api_key}"},
timeout=10.0
)
response.raise_for_status()
models_data = response.json().get("data", [])
# Filter to chat/reasoning models and keep metadata for sorting
chat_models = []
for model in models_data:
model_id = model.get("id", "")
created = model.get("created", 0)
# Include GPT models and reasoning models (o-series)
if (model_id.startswith("gpt-") or
model_id.startswith("o1") or
model_id.startswith("o3") or
model_id.startswith("o4") or
model_id.startswith("o5") or
model_id.startswith("gpt-5")):
# Skip snapshot/dated versions to keep list clean
if "-20" not in model_id and "-preview" not in model_id.lower():
chat_models.append((model_id, created))
# Remove duplicates (keep highest created timestamp for each model_id)
model_dict = {}
for model_id, created in chat_models:
if model_id not in model_dict or created > model_dict[model_id]:
model_dict[model_id] = created
# Sort by created timestamp (newest first) and limit
sorted_models = sorted(model_dict.items(), key=lambda x: -x[1])[:MAX_MODELS]
result = [(model_id, model_id) for model_id, _ in sorted_models]
_model_cache["openai"] = result
return result
except (httpx.RequestError, httpx.HTTPStatusError, ValueError, KeyError):
return None
def fetch_anthropic_models() -> Optional[List[Tuple[str, str]]]:
"""
Fetch available models from Anthropic API, sorted by creation date (newest first).
Returns:
List of (display_name, model_id) tuples, or None on failure
"""
if "anthropic" in _model_cache:
return _model_cache["anthropic"]
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
return None
try:
response = httpx.get(
"https://api.anthropic.com/v1/models",
headers={
"x-api-key": api_key,
"anthropic-version": "2023-06-01"
},
timeout=10.0
)
response.raise_for_status()
models_data = response.json().get("data", [])
# Filter to Claude models and keep metadata for sorting
claude_models = []
for model in models_data:
model_id = model.get("id", "")
# Anthropic API returns created_at as ISO string (RFC 3339)
created_at = model.get("created_at", "")
display_name = model.get("display_name", "")
if model_id.startswith("claude-"):
# Skip dated versions (e.g., claude-3-sonnet-20240229)
if "-20" not in model_id:
# Use display_name if available, otherwise model_id
label = display_name if display_name else model_id
claude_models.append((model_id, label, created_at))
# Remove duplicates (keep latest for each model_id)
model_dict = {}
for model_id, label, created_at in claude_models:
if model_id not in model_dict or created_at > model_dict[model_id][1]:
model_dict[model_id] = (label, created_at)
# Sort by created_at (newest first) and limit
sorted_models = sorted(model_dict.items(), key=lambda x: x[1][1], reverse=True)[:MAX_MODELS]
result = [(label, model_id) for model_id, (label, _) in sorted_models]
_model_cache["anthropic"] = result
return result
except (httpx.RequestError, httpx.HTTPStatusError, ValueError, KeyError):
return None
def fetch_google_models() -> Optional[List[Tuple[str, str]]]:
"""
Fetch available models from Google Generative AI API.
Uses displayName for user-friendly labels, sorted as returned by API (typically newest first).
Returns:
List of (display_name, model_id) tuples, or None on failure
"""
if "google" in _model_cache:
return _model_cache["google"]
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
return None
try:
response = httpx.get(
f"https://generativelanguage.googleapis.com/v1/models?key={api_key}",
timeout=10.0
)
response.raise_for_status()
models_data = response.json().get("models", [])
# Filter to Gemini models that support generateContent
gemini_models = []
for model in models_data:
model_name = model.get("name", "")
display_name = model.get("displayName", "")
supported_methods = model.get("supportedGenerationMethods", [])
# Extract model ID from "models/gemini-..." format
if model_name.startswith("models/"):
model_id = model_name.replace("models/", "")
else:
model_id = model_name
# Only include Gemini models that support content generation
if model_id.startswith("gemini") and "generateContent" in supported_methods:
# Use displayName if available, otherwise model_id
label = display_name if display_name else model_id
gemini_models.append((label, model_id))
# API returns in a reasonable order, just dedupe and limit
seen = set()
unique_models = []
for label, model_id in gemini_models:
if model_id not in seen:
seen.add(model_id)
unique_models.append((label, model_id))
result = unique_models[:MAX_MODELS]
_model_cache["google"] = result
return result
except (httpx.RequestError, httpx.HTTPStatusError, ValueError, KeyError):
return None
def fetch_models_for_provider(provider: str) -> Optional[List[Tuple[str, str]]]:
"""
Fetch models for a given provider.
Only fetches dynamically if FETCH_LATEST is enabled. Otherwise returns None
to trigger fallback to static model lists.
Args:
provider: Provider name (openai, anthropic, google)
Returns:
List of (display_name, model_id) tuples, or None if not supported/failed
"""
# Return None if FETCH_LATEST is not enabled - will use static lists
if not is_fetch_latest():
return None
provider_lower = provider.lower()
if provider_lower == "openai":
return fetch_openai_models()
elif provider_lower == "anthropic":
return fetch_anthropic_models()
elif provider_lower == "google":
return fetch_google_models()
return None
def clear_cache():
"""Clear the model cache."""
_model_cache.clear()
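
A hedged usage sketch of the fetch-and-cache flow above. It assumes FETCH_LATEST=true and a valid OPENAI_API_KEY; otherwise fetch_models_for_provider returns None and callers fall back to the static lists:

import os
from cli.model_fetcher import fetch_models_for_provider, clear_cache

os.environ["FETCH_LATEST"] = "true"
models = fetch_models_for_provider("openai")   # first call hits the API (if key is set)
cached = fetch_models_for_provider("openai")   # second call is served from _model_cache
clear_cache()                                  # next call will refetch
if models is None:
    print("fetch disabled or failed; static model lists will be used")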

View File

@ -1,7 +1,12 @@
import questionary
from typing import List, Optional, Tuple, Dict
from rich.console import Console

from cli.models import AnalystType
from cli.api_keys import is_provider_available
from cli.model_fetcher import fetch_models_for_provider

console = Console()

ANALYST_ORDER = [
    ("Market Analyst", AnalystType.MARKET),
@ -125,7 +130,7 @@ def select_research_depth() -> int:
def select_shallow_thinking_agent(provider) -> str:
    """Select shallow thinking llm engine using an interactive selection."""
    # Static fallback options for each provider
    SHALLOW_AGENT_OPTIONS = {
        "openai": [
            ("GPT-4o-mini - Fast and efficient for quick tasks", "gpt-4o-mini"),
@ -142,24 +147,43 @@ def select_shallow_thinking_agent(provider) -> str:
"google": [ "google": [
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"), ("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"), ("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"), ("Gemini 2.5 Flash-Lite - Lightweight and cost efficient", "gemini-2.5-flash-lite"),
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash"),
("Gemini 3 Flash Preview - Latest generation flash model", "gemini-3-flash-preview"),
], ],
"openrouter": [ "openrouter": [
("Xiaomi MiMo V2 Flash - Fast and efficient multimodal model", "xiaomi/mimo-v2-flash:free"),
("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"), ("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"),
("Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B", "meta-llama/llama-3.3-8b-instruct:free"), ("Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B", "meta-llama/llama-3.3-8b-instruct:free"),
("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"), ("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"),
], ],
"ollama": [ "ollama": [
("llama3.1 local", "llama3.1"), ("llama3.2:3b local", "llama3.2:3b"),
("llama3.2 local", "llama3.2"), ("phi3.5 local", "phi3.5:latest"),
],
"lm studio": [
("Local Model (default)", "local-model"),
] ]
} }
provider_lower = provider.lower()
# Try dynamic fetch for supported providers (OpenAI, Anthropic, Google)
model_options = None
if provider_lower in ["openai", "anthropic", "google"]:
dynamic_models = fetch_models_for_provider(provider_lower)
if dynamic_models:
model_options = dynamic_models
# Fall back to static list if dynamic fetch failed or not supported
if model_options is None:
model_options = SHALLOW_AGENT_OPTIONS.get(provider_lower, [])
choice = questionary.select( choice = questionary.select(
"Select Your [Quick-Thinking LLM Engine]:", "Select Your [Quick-Thinking LLM Engine]:",
choices=[ choices=[
questionary.Choice(display, value=value) questionary.Choice(display, value=value)
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()] for display, value in model_options
], ],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select", instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style( style=questionary.Style(
@ -183,7 +207,7 @@ def select_shallow_thinking_agent(provider) -> str:
def select_deep_thinking_agent(provider) -> str:
    """Select deep thinking llm engine using an interactive selection."""
    # Static fallback options for each provider
    DEEP_AGENT_OPTIONS = {
        "openai": [
            ("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
@ -199,29 +223,46 @@ def select_deep_thinking_agent(provider) -> str:
("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"), ("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"),
("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"), ("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"),
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"), ("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
("Claude Opus 4 - Most powerful Anthropic model", " claude-opus-4-0"), ("Claude Opus 4 - Most powerful Anthropic model", "claude-opus-4-0"),
], ],
"google": [ "google": [
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"), ("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"), ("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"), ("Gemini 2.5 Flash-Lite - Lightweight and cost efficient", "gemini-2.5-flash-lite"),
("Gemini 2.5 Pro", "gemini-2.5-pro-preview-06-05"), ("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash"),
("Gemini 3 Flash Preview - Latest generation flash model", "gemini-3-flash-preview"),
], ],
"openrouter": [ "openrouter": [
("Xiaomi MiMo V2 Flash - Fast and efficient multimodal model", "xiaomi/mimo-v2-flash:free"),
("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"), ("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"),
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"),
], ],
"ollama": [ "ollama": [
("llama3.1 local", "llama3.1"), ("llama3.2:3b local", "llama3.2:3b"),
("qwen3", "qwen3"), ("phi3.5 local", "phi3.5:latest"),
],
"lm studio": [
("Local Model (default)", "local-model"),
] ]
} }
provider_lower = provider.lower()
# Try dynamic fetch for supported providers (OpenAI, Anthropic, Google)
model_options = None
if provider_lower in ["openai", "anthropic", "google"]:
dynamic_models = fetch_models_for_provider(provider_lower)
if dynamic_models:
model_options = dynamic_models
# Fall back to static list if dynamic fetch failed or not supported
if model_options is None:
model_options = DEEP_AGENT_OPTIONS.get(provider_lower, [])
choice = questionary.select( choice = questionary.select(
"Select Your [Deep-Thinking LLM Engine]:", "Select Your [Deep-Thinking LLM Engine]:",
choices=[ choices=[
questionary.Choice(display, value=value) questionary.Choice(display, value=value)
for display, value in DEEP_AGENT_OPTIONS[provider.lower()] for display, value in model_options
], ],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select", instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style( style=questionary.Style(
@ -240,22 +281,35 @@ def select_deep_thinking_agent(provider) -> str:
    return choice


def select_llm_provider() -> tuple[str, str]:
    """Select the LLM provider using interactive selection with availability checks."""
    # Define provider options with their corresponding endpoints
    BASE_URLS = [
        ("OpenAI", "https://api.openai.com/v1"),
        ("Anthropic", "https://api.anthropic.com/"),
        ("Google", "https://generativelanguage.googleapis.com/v1"),
        ("Openrouter", "https://openrouter.ai/api/v1"),
        ("Ollama", "http://localhost:11434/v1"),
        ("LM Studio", "http://localhost:1234/v1"),
    ]

    # Build choices with availability status
    choices = []
    for display, url in BASE_URLS:
        available, reason = is_provider_available(display)
        if available:
            choices.append(questionary.Choice(display, value=(display, url)))
        else:
            # Show disabled option with reason
            disabled_label = f"{display} ({reason})"
            choices.append(questionary.Choice(
                disabled_label,
                value=(display, url),
                disabled=reason
            ))

    choice = questionary.select(
        "Select your LLM Provider:",
        choices=choices,
        instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
        style=questionary.Style(
            [
@ -267,7 +321,7 @@ def select_llm_provider() -> tuple[str, str]:
    ).ask()

    if choice is None:
        console.print("\n[red]No LLM provider selected. Exiting...[/red]")
        exit(1)

    display_name, url = choice

View File

@ -24,3 +24,5 @@ rich
questionary
langchain_anthropic
langchain-google-genai
playwright
markdown2

View File

@ -1,7 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement, get_insider_sentiment, get_insider_transactions, normalize_content
from tradingagents.dataflows.config import get_config
@ -53,7 +53,7 @@ def create_fundamentals_analyst(llm):
report = "" report = ""
if len(result.tool_calls) == 0: if len(result.tool_calls) == 0:
report = result.content report = normalize_content(result.content)
return { return {
"messages": [result], "messages": [result],

View File

@ -1,7 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import get_stock_data, get_indicators, normalize_content
from tradingagents.dataflows.config import get_config
@ -75,7 +75,7 @@ Volume-Based Indicators:
report = "" report = ""
if len(result.tool_calls) == 0: if len(result.tool_calls) == 0:
report = result.content report = normalize_content(result.content)
return { return {
"messages": [result], "messages": [result],

View File

@ -1,7 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import get_news, get_global_news, normalize_content
from tradingagents.dataflows.config import get_config
@ -48,7 +48,7 @@ def create_news_analyst(llm):
report = "" report = ""
if len(result.tool_calls) == 0: if len(result.tool_calls) == 0:
report = result.content report = normalize_content(result.content)
return { return {
"messages": [result], "messages": [result],

View File

@ -1,7 +1,7 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.agents.utils.agent_utils import get_news, normalize_content
from tradingagents.dataflows.config import get_config
@ -49,7 +49,7 @@ def create_social_media_analyst(llm):
report = "" report = ""
if len(result.tool_calls) == 0: if len(result.tool_calls) == 0:
report = result.content report = normalize_content(result.content)
return { return {
"messages": [result], "messages": [result],

View File

@ -1,5 +1,20 @@
from langchain_core.messages import HumanMessage, RemoveMessage


def normalize_content(content):
    """Normalize LLM response content to string.

    Gemini returns content as a list of dicts with 'text' keys,
    while OpenAI/Anthropic return a simple string.
    """
    if isinstance(content, list):
        return "".join(
            block.get("text", "") if isinstance(block, dict) else str(block)
            for block in content
        )
    return content


# Import tools from separate utility files
from tradingagents.agents.utils.core_stock_tools import (
    get_stock_data
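
Illustrative behavior of normalize_content across provider response shapes (inputs invented):

from tradingagents.agents.utils.agent_utils import normalize_content

print(normalize_content("plain string"))                             # 'plain string'
print(normalize_content([{"text": "Gemini "}, {"text": "blocks"}]))  # 'Gemini blocks'
print(normalize_content([{"type": "thinking"}, {"text": "kept"}]))   # 'kept' (dict without 'text' contributes '')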

View File

@ -5,6 +5,38 @@ from typing import Dict, Optional
_config: Optional[Dict] = None
DATA_DIR: Optional[str] = None

# Local LLM providers that don't support OpenAI's web_search_preview
LOCAL_LLM_PROVIDERS = ["ollama", "lm studio"]

# Methods that require OpenAI's web_search_preview tool
OPENAI_ONLY_METHODS = ["get_news", "get_global_news", "get_fundamentals"]


def validate_config(config: Dict):
    """Validate configuration and warn about incompatible settings."""
    llm_provider = config.get("llm_provider", "").lower()

    if llm_provider in LOCAL_LLM_PROVIDERS:
        # Check data vendors
        data_vendors = config.get("data_vendors", {})
        tool_vendors = config.get("tool_vendors", {})
        warnings = []

        if data_vendors.get("news_data") == "openai":
            warnings.append("data_vendors.news_data")
        if data_vendors.get("fundamental_data") == "openai":
            warnings.append("data_vendors.fundamental_data")

        for method in OPENAI_ONLY_METHODS:
            if tool_vendors.get(method) == "openai":
                warnings.append(f"tool_vendors.{method}")

        if warnings:
            print(f"WARNING: Using local LLM provider '{llm_provider}' with 'openai' data vendors.")
            print("         The following settings use OpenAI's web_search_preview, which is not available locally:")
            for w in warnings:
                print(f"         - {w}")
            print("         Recommendation: Change these to 'alpha_vantage', 'google', or 'local'.")


def initialize_config():
    """Initialize the configuration with default values."""
@ -21,6 +53,7 @@ def set_config(config: Dict):
    _config = default_config.DEFAULT_CONFIG.copy()
    _config.update(config)
    DATA_DIR = _config["data_dir"]
    validate_config(_config)


def get_config() -> Dict:
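
A hypothetical config that would trip the validation above (vendor values invented for the example):

from tradingagents.dataflows.config import validate_config

validate_config({
    "llm_provider": "ollama",
    "data_vendors": {"news_data": "openai"},
    "tool_vendors": {"get_fundamentals": "openai"},
})
# Prints the WARNING block listing data_vendors.news_data and tool_vendors.get_fundamentals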

View File

@ -1,4 +1,5 @@
from typing import Annotated
import time

# Import from vendor-specific modules
from .local import get_YFin_data, get_finnhub_news, get_finnhub_company_insider_sentiment, get_finnhub_company_insider_transactions, get_simfin_balance_sheet, get_simfin_cashflow, get_simfin_income_statements, get_reddit_global_news, get_reddit_company_news
@ -16,6 +17,7 @@ from .alpha_vantage import (
    get_news as get_alpha_vantage_news
)
from .alpha_vantage_common import AlphaVantageRateLimitError
from openai import APIConnectionError, APITimeoutError, RateLimitError

# Configuration and routing logic
from .config import get_config
@ -194,25 +196,58 @@ def route_to_vendor(method: str, *args, **kwargs):
    else:
        vendor_methods = [(vendor_impl, vendor)]

    # Run methods for this vendor with retry logic
    vendor_results = []
    for impl_func, vendor_name in vendor_methods:
        max_retries = 3
        base_delay = 1.0
        last_error = None

        for retry_attempt in range(max_retries):
            try:
                if retry_attempt > 0:
                    print(f"RETRY: Attempt {retry_attempt + 1}/{max_retries} "
                          f"for {impl_func.__name__}")
                else:
                    print(f"DEBUG: Calling {impl_func.__name__} "
                          f"from vendor '{vendor_name}'...")
                result = impl_func(*args, **kwargs)
                vendor_results.append(result)
                print(f"SUCCESS: {impl_func.__name__} from vendor "
                      f"'{vendor_name}' completed successfully")
                last_error = None
                break  # Success, exit retry loop
            except (AlphaVantageRateLimitError, RateLimitError) as e:
                print(f"RATE_LIMIT: {type(e).__name__} exceeded, falling back to next vendor.")
                print(f"DEBUG: Rate limit details: {e}")
                last_error = e
                break  # Don't retry rate limits, move to next vendor
            except (ConnectionError, TimeoutError, OSError,
                    APIConnectionError, APITimeoutError) as e:
                # Transient errors - retry with backoff
                last_error = e
                if retry_attempt < max_retries - 1:
                    delay = base_delay * (2 ** retry_attempt)
                    print(f"TRANSIENT_ERROR: {type(e).__name__} - {e}")
                    print(f"RETRY: Waiting {delay}s before retry...")
                    time.sleep(delay)
                else:
                    print(f"FAILED: {impl_func.__name__} from vendor "
                          f"'{vendor_name}' failed after {max_retries} "
                          f"attempts: {e}")
            except Exception as e:
                # Non-transient errors - don't retry
                last_error = e
                print(f"FAILED: {impl_func.__name__} from vendor "
                      f"'{vendor_name}' failed: {type(e).__name__}: {e}")
                break

        if last_error is not None:
            continue  # Move to next implementation

    # Add this vendor's results
    if vendor_results:
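
The exponential backoff above yields a short, bounded schedule; a standalone illustration of the delays implied by base_delay=1.0 and max_retries=3:

base_delay, max_retries = 1.0, 3
for retry_attempt in range(max_retries - 1):
    print(f"after failed attempt {retry_attempt + 1}: wait {base_delay * (2 ** retry_attempt)}s")
# after failed attempt 1: wait 1.0s
# after failed attempt 2: wait 2.0s  (the third attempt is the last; no further wait)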

View File

@ -1,7 +1,46 @@
import os

from openai import OpenAI
from .config import get_config


def _get_web_search_tool_type() -> str:
    """Return the appropriate web search tool type based on FETCH_LATEST setting.

    - FETCH_LATEST=true: Use 'web_search' (GA version, supports GPT-5)
    - FETCH_LATEST=false/unset: Use 'web_search_preview' (legacy, wider compatibility)
    """
    fetch_latest = os.getenv("FETCH_LATEST", "false").lower() in ("true", "1", "yes")
    return "web_search" if fetch_latest else "web_search_preview"


def _extract_text_from_response(response):
    """Safely extract text content from OpenAI Responses API output.

    The response.output array typically contains:
    - output[0]: ResponseFunctionWebSearch (the web search call)
    - output[1]: ResponseOutputMessage (the text response)

    This function handles edge cases where the structure may differ.
    """
    if not response.output:
        raise RuntimeError("OpenAI response has empty output")

    # Look for a message with text content
    for item in response.output:
        if hasattr(item, 'content') and item.content:
            for content_block in item.content:
                if hasattr(content_block, 'text') and content_block.text:
                    return content_block.text

    # If we get here, no text was found
    output_types = [type(item).__name__ for item in response.output]
    raise RuntimeError(
        f"No text content found in OpenAI response. "
        f"Output types: {output_types}"
    )


def get_stock_news_openai(query, start_date, end_date):
    config = get_config()
    client = OpenAI(base_url=config["backend_url"])
@ -23,7 +62,7 @@ def get_stock_news_openai(query, start_date, end_date):
        reasoning={},
        tools=[
            {
                "type": _get_web_search_tool_type(),
                "user_location": {"type": "approximate"},
                "search_context_size": "low",
            }
@ -34,7 +73,7 @@ def get_stock_news_openai(query, start_date, end_date):
        store=True,
    )

    return _extract_text_from_response(response)


def get_global_news_openai(curr_date, look_back_days=7, limit=5):
@ -58,7 +97,7 @@ def get_global_news_openai(curr_date, look_back_days=7, limit=5):
        reasoning={},
        tools=[
            {
                "type": _get_web_search_tool_type(),
                "user_location": {"type": "approximate"},
                "search_context_size": "low",
            }
@ -69,7 +108,7 @@ def get_global_news_openai(curr_date, look_back_days=7, limit=5):
        store=True,
    )

    return _extract_text_from_response(response)


def get_fundamentals_openai(ticker, curr_date):
@ -93,7 +132,7 @@ def get_fundamentals_openai(ticker, curr_date):
        reasoning={},
        tools=[
            {
                "type": _get_web_search_tool_type(),
                "user_location": {"type": "approximate"},
                "search_context_size": "low",
            }
@ -104,4 +143,4 @@ def get_fundamentals_openai(ticker, curr_date):
        store=True,
    )

    return _extract_text_from_response(response)
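
To show why _extract_text_from_response is safer than indexing response.output[1], a minimal stand-in object run against the helper in this module (the types are invented; a real Responses API result has richer fields):

from types import SimpleNamespace

block = SimpleNamespace(text="headline summary")
message = SimpleNamespace(content=[block])
response = SimpleNamespace(output=[message])  # no web-search item at output[0]

# response.output[1].content[0].text would raise IndexError here;
# the helper scans all output items and returns 'headline summary'
print(_extract_text_from_response(response))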

View File

@ -12,6 +12,8 @@ from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.prebuilt import ToolNode

from tradingagents.llm import requires_responses_api, ChatOpenAIResponses
from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
@ -72,9 +74,20 @@ class TradingAgentsGraph:
        )

        # Initialize LLMs
        if self.config["llm_provider"].lower() in ["openai", "ollama", "openrouter", "lm studio"]:
            # Select LLM class based on model - newer models require Responses API
            deep_model = self.config["deep_think_llm"]
            quick_model = self.config["quick_think_llm"]

            if requires_responses_api(deep_model):
                self.deep_thinking_llm = ChatOpenAIResponses(model=deep_model, base_url=self.config["backend_url"])
            else:
                self.deep_thinking_llm = ChatOpenAI(model=deep_model, base_url=self.config["backend_url"])

            if requires_responses_api(quick_model):
                self.quick_thinking_llm = ChatOpenAIResponses(model=quick_model, base_url=self.config["backend_url"])
            else:
                self.quick_thinking_llm = ChatOpenAI(model=quick_model, base_url=self.config["backend_url"])
        elif self.config["llm_provider"].lower() == "anthropic":
            self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
            self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])

View File

@ -0,0 +1,9 @@
"""LLM wrapper utilities for TradingAgents.
This module provides custom LLM wrappers for different API endpoints.
"""
from tradingagents.llm.model_utils import requires_responses_api
from tradingagents.llm.openai_responses import ChatOpenAIResponses
__all__ = ["requires_responses_api", "ChatOpenAIResponses"]

View File

@ -0,0 +1,24 @@
"""Utility functions for model detection and selection."""
# Model prefixes that require the OpenAI Responses API (/v1/responses)
# instead of the Chat Completions API (/v1/chat/completions)
RESPONSES_API_PREFIXES = [
"gpt-5", # All GPT-5 variants (gpt-5, gpt-5.1, gpt-5.1-codex-mini, etc.)
"codex", # Codex models that use Responses API
]
def requires_responses_api(model_name: str) -> bool:
"""Check if a model requires the Responses API instead of Chat Completions.
Some newer OpenAI models only support the /v1/responses endpoint and will
return a 404 error if called via /v1/chat/completions.
Args:
model_name: The model identifier (e.g., "gpt-5.1-codex-mini", "gpt-4o")
Returns:
True if the model requires the Responses API, False otherwise.
"""
model_lower = model_name.lower()
return any(prefix in model_lower for prefix in RESPONSES_API_PREFIXES)
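
Illustrative checks of the prefix matching (note it is a substring test via `in`, not startswith):

from tradingagents.llm import requires_responses_api

print(requires_responses_api("gpt-5.1-codex-mini"))  # True  ('gpt-5' matches)
print(requires_responses_api("gpt-4o"))              # False
print(requires_responses_api("codex-mini-latest"))   # True  ('codex' matches)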

View File

@ -0,0 +1,307 @@
"""LangChain-compatible wrapper for OpenAI's Responses API.
This module provides ChatOpenAIResponses, a drop-in replacement for ChatOpenAI
that uses the /v1/responses endpoint instead of /v1/chat/completions.
This is required for newer models like gpt-5.1-codex-mini that only support
the Responses API.
"""
import json
import os
import uuid
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import BaseTool
from openai import OpenAI
from pydantic import Field
class ChatOpenAIResponses(BaseChatModel):
"""LangChain-compatible chat model using OpenAI's Responses API.
This class provides the same interface as ChatOpenAI but uses the
/v1/responses endpoint, which is required for certain newer models.
Example:
>>> llm = ChatOpenAIResponses(model="gpt-5.1-codex-mini")
>>> llm_with_tools = llm.bind_tools([my_tool])
>>> result = llm_with_tools.invoke([HumanMessage(content="Hello")])
"""
model: str = Field(default="gpt-5.1-codex-mini")
base_url: Optional[str] = Field(default=None)
api_key: Optional[str] = Field(default=None)
temperature: float = Field(default=1.0)
max_output_tokens: int = Field(default=4096)
top_p: float = Field(default=1.0)
# Internal state for tool binding
_bound_tools: List[Dict[str, Any]] = []
_client: Optional[OpenAI] = None
class Config:
arbitrary_types_allowed = True
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._bound_tools = []
self._client = None
@property
def _llm_type(self) -> str:
return "openai-responses"
@property
def client(self) -> OpenAI:
"""Lazily initialize the OpenAI client."""
if self._client is None:
api_key = self.api_key or os.getenv("OPENAI_API_KEY")
if self.base_url:
self._client = OpenAI(api_key=api_key, base_url=self.base_url)
else:
self._client = OpenAI(api_key=api_key)
return self._client
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], BaseTool]],
**kwargs: Any,
) -> "ChatOpenAIResponses":
"""Bind tools to this model instance.
Args:
tools: A sequence of tools to bind. Can be LangChain tools or dicts.
Returns:
A new ChatOpenAIResponses instance with the tools bound.
"""
new_instance = ChatOpenAIResponses(
model=self.model,
base_url=self.base_url,
api_key=self.api_key,
temperature=self.temperature,
max_output_tokens=self.max_output_tokens,
top_p=self.top_p,
)
new_instance._bound_tools = self._convert_tools(tools)
return new_instance
def _convert_tools(
self, tools: Sequence[Union[Dict[str, Any], BaseTool]]
) -> List[Dict[str, Any]]:
"""Convert LangChain tools to OpenAI Responses API function format.
The Responses API uses a flat structure for function tools:
{
"type": "function",
"name": "function_name",
"description": "...",
"parameters": {...}
}
This differs from Chat Completions which nests under "function" key.
"""
converted = []
for tool in tools:
if isinstance(tool, BaseTool):
# Get the JSON schema for parameters
if tool.args_schema:
params = tool.args_schema.model_json_schema()
# Remove title field that OpenAI doesn't expect at schema level
params.pop("title", None)
else:
params = {"type": "object", "properties": {}}
# Responses API uses flat structure - name at top level, not nested
tool_schema = {
"type": "function",
"name": tool.name,
"description": tool.description or "",
"parameters": params,
}
converted.append(tool_schema)
elif isinstance(tool, dict):
# Handle dict format - convert from Chat Completions format if needed
if "function" in tool:
# Chat Completions format - flatten it
func = tool["function"]
tool_schema = {
"type": "function",
"name": func.get("name", ""),
"description": func.get("description", ""),
"parameters": func.get("parameters", {"type": "object", "properties": {}}),
}
converted.append(tool_schema)
elif "name" in tool:
# Already in Responses API format
converted.append(tool)
else:
# Unknown format, try to use as-is
converted.append(tool)
return converted
def _convert_messages(
self, messages: List[BaseMessage]
) -> List[Dict[str, Any]]:
"""Convert LangChain messages to OpenAI Responses API format.
The Responses API uses a different message format than Chat Completions:
- System/user messages use 'input_text' content type
- Assistant messages use 'output_text' content type (no function_call in content)
- Tool calls from assistant are represented as separate 'function_call' items
- Tool results use 'function_call_output' content type
"""
converted = []
for msg in messages:
if isinstance(msg, SystemMessage):
content = msg.content if isinstance(msg.content, str) else str(msg.content)
converted.append({
"role": "system",
"content": [{"type": "input_text", "text": content}],
})
elif isinstance(msg, HumanMessage):
content = msg.content if isinstance(msg.content, str) else str(msg.content)
converted.append({
"role": "user",
"content": [{"type": "input_text", "text": content}],
})
elif isinstance(msg, AIMessage):
# Handle AI messages (assistant responses)
# First add text content if present
if msg.content:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
converted.append({
"role": "assistant",
"content": [{"type": "output_text", "text": content}],
})
# Tool calls need to be added as separate items in the Responses API
if hasattr(msg, 'tool_calls') and msg.tool_calls:
for tc in msg.tool_calls:
# Convert args to JSON string for the API
args = tc.get("args", {})
if isinstance(args, dict):
args_str = json.dumps(args)
else:
args_str = str(args)
# Add tool call as a separate item (not inside assistant content)
converted.append({
"type": "function_call",
"call_id": tc.get("id", str(uuid.uuid4())),
"name": tc["name"],
"arguments": args_str,
})
elif not msg.content:
# Empty assistant message - add placeholder
converted.append({
"role": "assistant",
"content": [{"type": "output_text", "text": ""}],
})
elif isinstance(msg, ToolMessage):
# Tool results need to be formatted as function call outputs
content = msg.content if isinstance(msg.content, str) else str(msg.content)
converted.append({
"type": "function_call_output",
"call_id": msg.tool_call_id,
"output": content,
})
return converted
def _parse_response(self, response: Any) -> AIMessage:
"""Parse OpenAI Responses API response into LangChain AIMessage."""
text_content = ""
tool_calls = []
if not response.output:
return AIMessage(content="")
for item in response.output:
# Handle text output
if hasattr(item, 'content') and item.content:
for block in item.content:
if hasattr(block, 'text') and block.text:
text_content += block.text
# Handle function/tool calls
if hasattr(item, 'type') and item.type == 'function_call':
args = item.arguments
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
args = {"raw": args}
tool_calls.append({
"id": getattr(item, 'id', None) or getattr(item, 'call_id', None) or str(uuid.uuid4()),
"name": item.name,
"args": args,
})
if tool_calls:
return AIMessage(content=text_content, tool_calls=tool_calls)
return AIMessage(content=text_content)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Generate a response using the OpenAI Responses API.
Args:
messages: List of LangChain messages to send.
stop: Optional stop sequences (not used by Responses API).
run_manager: Optional callback manager.
Returns:
ChatResult containing the model's response.
"""
# Convert messages to Responses API format
converted_messages = self._convert_messages(messages)
# Build request parameters
request_params = {
"model": self.model,
"input": converted_messages,
"temperature": self.temperature,
"max_output_tokens": self.max_output_tokens,
"top_p": self.top_p,
}
# Add tools if bound
if self._bound_tools:
request_params["tools"] = self._bound_tools
# Make the API call
response = self.client.responses.create(**request_params)
# Parse the response
ai_message = self._parse_response(response)
return ChatResult(
generations=[ChatGeneration(message=ai_message)],
)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Return identifying parameters for this LLM."""
return {
"model": self.model,
"base_url": self.base_url,
"temperature": self.temperature,
"max_output_tokens": self.max_output_tokens,
}
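
A hypothetical end-to-end sketch of the wrapper, mirroring the docstring example above; the model name and tool are invented, and a valid OPENAI_API_KEY is assumed:

from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from tradingagents.llm import ChatOpenAIResponses

@tool
def get_price(symbol: str) -> str:
    """Return the latest price for a ticker symbol."""
    return "100.00"

llm = ChatOpenAIResponses(model="gpt-5.1-codex-mini")
llm_with_tools = llm.bind_tools([get_price])  # converted to the flat Responses tool schema
result = llm_with_tools.invoke([HumanMessage(content="Price of SPY?")])
print(result.tool_calls)  # e.g. [{'name': 'get_price', 'args': {'symbol': 'SPY'}, 'id': ...}]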