Merge b950e3a018 into 13b826a31d
This commit is contained in:
commit
41501a4e1b
17
.env.example
17
.env.example
|
|
@ -1,2 +1,17 @@
|
|||
# Data vendor API keys
|
||||
ALPHA_VANTAGE_API_KEY=alpha_vantage_api_key_placeholder
|
||||
OPENAI_API_KEY=openai_api_key_placeholder
|
||||
|
||||
# LLM Provider API keys (set the ones you want to use)
|
||||
OPENAI_API_KEY=openai_api_key_placeholder
|
||||
ANTHROPIC_API_KEY=anthropic_api_key_placeholder
|
||||
GEMINI_API_KEY=gemini_api_key_placeholder
|
||||
OPENROUTER_API_KEY=openrouter_api_key_placeholder
|
||||
|
||||
# Local LLM provider URLs (optional, defaults shown)
|
||||
# OLLAMA_URL=http://localhost:11434
|
||||
# LM_STUDIO_URL=http://localhost:1234
|
||||
|
||||
# Feature flags
|
||||
# Set to "true" to fetch latest models from APIs and use latest web_search tool
|
||||
# Set to "false" or leave unset for static model lists and web_search_preview (legacy)
|
||||
FETCH_LATEST=true
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
"""API key and endpoint validation for LLM providers."""
|
||||
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
import httpx
|
||||
|
||||
|
||||
# Map cloud providers to their required environment variables
|
||||
PROVIDER_API_KEYS = {
|
||||
"openai": "OPENAI_API_KEY",
|
||||
"anthropic": "ANTHROPIC_API_KEY",
|
||||
"google": "GEMINI_API_KEY",
|
||||
"openrouter": "OPENROUTER_API_KEY",
|
||||
}
|
||||
|
||||
# Default endpoints for local providers
|
||||
LOCAL_PROVIDER_DEFAULTS = {
|
||||
"ollama": ("OLLAMA_URL", "http://localhost:11434"),
|
||||
"lm studio": ("LM_STUDIO_URL", "http://localhost:1234"),
|
||||
}
|
||||
|
||||
|
||||
def get_api_key(provider: str) -> Optional[str]:
|
||||
"""Get API key for a cloud provider, returns None if not set."""
|
||||
provider_lower = provider.lower()
|
||||
|
||||
# Special case: OpenRouter can use OPENROUTER_API_KEY or OPENAI_API_KEY with sk-or- prefix
|
||||
if provider_lower == "openrouter":
|
||||
openrouter_key = os.getenv("OPENROUTER_API_KEY")
|
||||
if openrouter_key:
|
||||
return openrouter_key
|
||||
# Check if OPENAI_API_KEY is actually an OpenRouter key
|
||||
openai_key = os.getenv("OPENAI_API_KEY", "")
|
||||
if openai_key.startswith("sk-or-"):
|
||||
return openai_key
|
||||
return None
|
||||
|
||||
env_var = PROVIDER_API_KEYS.get(provider_lower)
|
||||
if env_var is None:
|
||||
return None
|
||||
return os.getenv(env_var)
|
||||
|
||||
|
||||
def get_local_endpoint(provider: str) -> Optional[str]:
|
||||
"""Get the endpoint URL for a local provider."""
|
||||
provider_lower = provider.lower()
|
||||
if provider_lower not in LOCAL_PROVIDER_DEFAULTS:
|
||||
return None
|
||||
|
||||
env_var, default_url = LOCAL_PROVIDER_DEFAULTS[provider_lower]
|
||||
return os.getenv(env_var, default_url)
|
||||
|
||||
|
||||
def is_local_provider_running(provider: str) -> bool:
|
||||
"""Check if a local provider (Ollama/LM Studio) is running by probing its endpoint."""
|
||||
endpoint = get_local_endpoint(provider)
|
||||
if not endpoint:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Probe the models endpoint with a short timeout
|
||||
response = httpx.get(
|
||||
f"{endpoint}/v1/models",
|
||||
timeout=1.0
|
||||
)
|
||||
return response.status_code == 200
|
||||
except (httpx.RequestError, httpx.TimeoutException):
|
||||
return False
|
||||
|
||||
|
||||
def is_provider_available(provider: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
Check if a provider is available.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_available, reason_if_unavailable)
|
||||
"""
|
||||
provider_lower = provider.lower()
|
||||
|
||||
# Local providers: check if endpoint is reachable
|
||||
if provider_lower in LOCAL_PROVIDER_DEFAULTS:
|
||||
if is_local_provider_running(provider):
|
||||
return (True, "")
|
||||
return (False, "Not running")
|
||||
|
||||
# Cloud providers: check for API key
|
||||
if get_api_key(provider) is not None:
|
||||
return (True, "")
|
||||
return (False, "No API key")
|
||||
|
||||
|
||||
def get_all_provider_availability() -> dict:
|
||||
"""
|
||||
Get availability status for all providers.
|
||||
|
||||
Returns:
|
||||
Dict mapping provider name to (is_available, reason) tuple
|
||||
"""
|
||||
all_providers = list(PROVIDER_API_KEYS.keys()) + list(LOCAL_PROVIDER_DEFAULTS.keys())
|
||||
return {provider: is_provider_available(provider) for provider in all_providers}
|
||||
|
|
@ -0,0 +1,501 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Compile all trading agent reports into a single consolidated PDF.
|
||||
|
||||
Creates a PDF with:
|
||||
1. Summary table showing all symbols, their decisions, and analysis dates
|
||||
2. Detailed reports for each symbol (in order specified by REPORT_ORDER)
|
||||
|
||||
Usage:
|
||||
python cli/compile_reports.py # Compile all results into single PDF
|
||||
python cli/compile_reports.py --output report.pdf # Custom output filename
|
||||
python cli/compile_reports.py --date 2026-01-18 # Filter to specific date (auto-names output)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import markdown2
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
# Report order (top to bottom for each symbol's section)
|
||||
REPORT_ORDER = [
|
||||
("final_trade_decision.md", "Final Trade Decision"),
|
||||
("trader_investment_plan.md", "Trader Investment Plan"),
|
||||
("investment_plan.md", "Investment Plan"),
|
||||
("fundamentals_report.md", "Fundamentals Analysis"),
|
||||
("news_report.md", "News Analysis"),
|
||||
("sentiment_report.md", "Sentiment Analysis"),
|
||||
("market_report.md", "Market Analysis"),
|
||||
]
|
||||
|
||||
# Clean GitHub-style markdown CSS
|
||||
CSS_STYLES = """
|
||||
@page {
|
||||
size: A4;
|
||||
margin: 0.75in;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Noto Sans', Helvetica, Arial, sans-serif;
|
||||
font-size: 14px;
|
||||
line-height: 1.6;
|
||||
color: #24292f;
|
||||
max-width: 100%;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 2em;
|
||||
font-weight: 600;
|
||||
border-bottom: 1px solid #d0d7de;
|
||||
padding-bottom: 0.3em;
|
||||
margin-top: 24px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
h2 {
|
||||
font-size: 1.5em;
|
||||
font-weight: 600;
|
||||
border-bottom: 1px solid #d0d7de;
|
||||
padding-bottom: 0.3em;
|
||||
margin-top: 24px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
h3 {
|
||||
font-size: 1.25em;
|
||||
font-weight: 600;
|
||||
margin-top: 24px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
h4 {
|
||||
font-size: 1em;
|
||||
font-weight: 600;
|
||||
margin-top: 24px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
p {
|
||||
margin-top: 0;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
ul, ol {
|
||||
padding-left: 2em;
|
||||
margin-top: 0;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
li {
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
li + li {
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
table {
|
||||
border-collapse: collapse;
|
||||
width: 100%;
|
||||
margin-top: 0;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
th, td {
|
||||
padding: 6px 13px;
|
||||
border: 1px solid #d0d7de;
|
||||
}
|
||||
|
||||
th {
|
||||
background-color: #f6f8fa;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
tr:nth-child(2n) {
|
||||
background-color: #f6f8fa;
|
||||
}
|
||||
|
||||
hr {
|
||||
border: 0;
|
||||
border-top: 1px solid #d0d7de;
|
||||
margin: 24px 0;
|
||||
}
|
||||
|
||||
code {
|
||||
background-color: rgba(175, 184, 193, 0.2);
|
||||
padding: 0.2em 0.4em;
|
||||
border-radius: 6px;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Menlo, Consolas, monospace;
|
||||
font-size: 85%;
|
||||
}
|
||||
|
||||
pre {
|
||||
background-color: #f6f8fa;
|
||||
padding: 16px;
|
||||
border-radius: 6px;
|
||||
overflow-x: auto;
|
||||
margin-bottom: 16px;
|
||||
font-size: 85%;
|
||||
line-height: 1.45;
|
||||
}
|
||||
|
||||
pre code {
|
||||
padding: 0;
|
||||
background: none;
|
||||
font-size: 100%;
|
||||
}
|
||||
|
||||
blockquote {
|
||||
border-left: 0.25em solid #d0d7de;
|
||||
padding: 0 1em;
|
||||
margin: 0 0 16px 0;
|
||||
color: #57606a;
|
||||
}
|
||||
|
||||
strong {
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
/* Decision color styling */
|
||||
.decision-buy { color: #1a7f37; font-weight: 700; }
|
||||
.decision-sell { color: #cf222e; font-weight: 700; }
|
||||
.decision-hold { color: #9a6700; font-weight: 700; }
|
||||
|
||||
/* Symbol section - page break before each new symbol */
|
||||
.symbol-section {
|
||||
page-break-before: always;
|
||||
}
|
||||
|
||||
.symbol-section:first-of-type {
|
||||
page-break-before: avoid;
|
||||
}
|
||||
|
||||
/* Report title styling */
|
||||
.report-title {
|
||||
color: #0969da;
|
||||
font-size: 1.3em;
|
||||
font-weight: 600;
|
||||
margin-top: 32px;
|
||||
margin-bottom: 16px;
|
||||
padding-bottom: 8px;
|
||||
border-bottom: 2px solid #0969da;
|
||||
}
|
||||
|
||||
.report-title:first-of-type {
|
||||
margin-top: 16px;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def extract_decision(content: str) -> str:
|
||||
"""Extract BUY/SELL/HOLD decision from final trade decision content."""
|
||||
content_lower = content.lower()
|
||||
|
||||
patterns = [
|
||||
r"recommendation[:\s]*\*{0,2}(buy|sell|hold)\*{0,2}",
|
||||
r"\*{0,2}(buy|sell|hold)\*{0,2}[:\s]*recommendation",
|
||||
r"final.*?decision[:\s]*\*{0,2}(buy|sell|hold)\*{0,2}",
|
||||
r"recommend.*?(buy|sell|hold)",
|
||||
r"action[:\s]*\*{0,2}(buy|sell|hold)\*{0,2}",
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, content_lower)
|
||||
if match:
|
||||
return match.group(1).upper()
|
||||
|
||||
buy_count = len(re.findall(r"\bbuy\b", content_lower))
|
||||
sell_count = len(re.findall(r"\bsell\b", content_lower))
|
||||
hold_count = len(re.findall(r"\bhold\b", content_lower))
|
||||
|
||||
max_count = max(buy_count, sell_count, hold_count)
|
||||
if max_count > 0:
|
||||
if sell_count == max_count:
|
||||
return "SELL"
|
||||
if buy_count == max_count:
|
||||
return "BUY"
|
||||
return "HOLD"
|
||||
|
||||
return "N/A"
|
||||
|
||||
|
||||
def markdown_to_html(md_content: str) -> str:
|
||||
"""Convert markdown to HTML with extras."""
|
||||
return markdown2.markdown(
|
||||
md_content,
|
||||
extras=[
|
||||
"tables",
|
||||
"fenced-code-blocks",
|
||||
"strike",
|
||||
"task_list",
|
||||
"cuddled-lists",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def find_all_reports(results_dir: Path, date_filter: str | None = None) -> list[dict]:
|
||||
"""Find all report directories and extract their data.
|
||||
|
||||
Args:
|
||||
results_dir: Path to the results directory
|
||||
date_filter: Optional date string (YYYY-MM-DD) to filter reports
|
||||
"""
|
||||
all_reports = []
|
||||
|
||||
if not results_dir.exists():
|
||||
return all_reports
|
||||
|
||||
for symbol_dir in sorted(results_dir.iterdir()):
|
||||
if not symbol_dir.is_dir():
|
||||
continue
|
||||
|
||||
symbol = symbol_dir.name
|
||||
if symbol.startswith(".") or " " in symbol:
|
||||
continue
|
||||
|
||||
for date_dir in sorted(symbol_dir.iterdir(), reverse=True):
|
||||
if not date_dir.is_dir():
|
||||
continue
|
||||
|
||||
date = date_dir.name
|
||||
|
||||
# Skip if date doesn't match filter
|
||||
if date_filter and date != date_filter:
|
||||
continue
|
||||
|
||||
reports_dir = date_dir / "reports"
|
||||
|
||||
if not reports_dir.exists():
|
||||
continue
|
||||
|
||||
report_files = []
|
||||
decision = "N/A"
|
||||
|
||||
for filename, title in REPORT_ORDER:
|
||||
file_path = reports_dir / filename
|
||||
if file_path.exists():
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
html_content = markdown_to_html(content)
|
||||
report_files.append((filename, title, html_content))
|
||||
|
||||
if filename == "final_trade_decision.md":
|
||||
decision = extract_decision(content)
|
||||
|
||||
if report_files:
|
||||
all_reports.append({
|
||||
"symbol": symbol,
|
||||
"date": date,
|
||||
"decision": decision,
|
||||
"reports_dir": reports_dir,
|
||||
"reports": report_files,
|
||||
})
|
||||
|
||||
return all_reports
|
||||
|
||||
|
||||
def build_html_document(all_reports: list[dict]) -> str:
|
||||
"""Build complete HTML document with summary table and all reports."""
|
||||
|
||||
# Build summary table rows
|
||||
summary_rows = []
|
||||
for report_data in all_reports:
|
||||
decision = report_data["decision"]
|
||||
decision_class = f"decision-{decision.lower()}" if decision in ["BUY", "SELL", "HOLD"] else ""
|
||||
summary_rows.append(f'''<tr>
|
||||
<td><strong>{report_data["symbol"]}</strong></td>
|
||||
<td>{report_data["date"]}</td>
|
||||
<td class="{decision_class}">{decision}</td>
|
||||
<td>{len(report_data["reports"])} reports</td>
|
||||
</tr>''')
|
||||
|
||||
summary_table = "\n".join(summary_rows)
|
||||
|
||||
# Build symbol sections
|
||||
symbol_sections = []
|
||||
for report_data in all_reports:
|
||||
symbol = report_data["symbol"]
|
||||
date = report_data["date"]
|
||||
decision = report_data["decision"]
|
||||
decision_class = f"decision-{decision.lower()}" if decision in ["BUY", "SELL", "HOLD"] else ""
|
||||
|
||||
# Build report content - simple flowing structure
|
||||
reports_html_parts = []
|
||||
for _, title, html_content in report_data["reports"]:
|
||||
reports_html_parts.append(f'''<div class="report-title">{title}</div>
|
||||
{html_content}''')
|
||||
|
||||
reports_html = "\n".join(reports_html_parts)
|
||||
|
||||
symbol_sections.append(f'''<div class="symbol-section">
|
||||
<h1>{symbol} Trading Analysis Report</h1>
|
||||
<p><strong>Date:</strong> {date} | <strong>Recommendation:</strong> <span class="{decision_class}">{decision}</span></p>
|
||||
<hr>
|
||||
{reports_html}
|
||||
</div>''')
|
||||
|
||||
all_symbols_html = "\n".join(symbol_sections)
|
||||
|
||||
generated_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
html = f'''<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Trading Analysis Report</title>
|
||||
<style>
|
||||
{CSS_STYLES}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Trading Analysis Report</h1>
|
||||
<p><em>Generated: {generated_date}</em></p>
|
||||
|
||||
<h2>Summary of Recommendations</h2>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Symbol</th>
|
||||
<th>Analysis Date</th>
|
||||
<th>Decision</th>
|
||||
<th>Reports</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{summary_table}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<hr>
|
||||
|
||||
{all_symbols_html}
|
||||
|
||||
<hr>
|
||||
<p><em>Report generated by TradingAgents</em></p>
|
||||
</body>
|
||||
</html>'''
|
||||
return html
|
||||
|
||||
|
||||
def compile_to_pdf(html_content: str, output_path: Path) -> bool:
|
||||
"""Generate PDF from HTML using Playwright."""
|
||||
try:
|
||||
with sync_playwright() as p:
|
||||
browser = p.chromium.launch()
|
||||
page = browser.new_page()
|
||||
page.set_content(html_content, wait_until="networkidle")
|
||||
|
||||
page.pdf(
|
||||
path=str(output_path),
|
||||
format="A4",
|
||||
margin={
|
||||
"top": "0.5in",
|
||||
"bottom": "0.5in",
|
||||
"left": "0.5in",
|
||||
"right": "0.5in",
|
||||
},
|
||||
print_background=True,
|
||||
)
|
||||
browser.close()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error generating PDF: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Compile all trading agent reports into a single consolidated PDF",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python cli/compile_reports.py
|
||||
python cli/compile_reports.py --output my_report.pdf
|
||||
python cli/compile_reports.py --date 2026-01-18
|
||||
python cli/compile_reports.py --date 2026-01-18 --output custom.pdf
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", "-o",
|
||||
default="./results/trading_analysis_report.pdf",
|
||||
help="Output PDF filename (default: ./results/trading_analysis_report.pdf)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--results-dir", "-r",
|
||||
default="./results",
|
||||
help="Results directory (default: ./results)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--date", "-d",
|
||||
help="Filter reports to a specific date (format: YYYY-MM-DD)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate date format if provided
|
||||
if args.date:
|
||||
import re as re_module
|
||||
if not re_module.match(r"^\d{4}-\d{2}-\d{2}$", args.date):
|
||||
print(f"Error: Invalid date format '{args.date}'. Expected YYYY-MM-DD")
|
||||
sys.exit(1)
|
||||
|
||||
results_dir = Path(args.results_dir)
|
||||
default_output = "./results/trading_analysis_report.pdf"
|
||||
|
||||
if not results_dir.exists():
|
||||
print(f"Error: Results directory not found: {results_dir}")
|
||||
sys.exit(1)
|
||||
|
||||
if args.date:
|
||||
print(f"Scanning {results_dir} for reports on {args.date}...")
|
||||
else:
|
||||
print(f"Scanning {results_dir} for reports...")
|
||||
|
||||
all_reports = find_all_reports(results_dir, date_filter=args.date)
|
||||
|
||||
if not all_reports:
|
||||
if args.date:
|
||||
print(f"No reports found for date {args.date}")
|
||||
else:
|
||||
print("No reports found")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Found {len(all_reports)} symbol analysis report(s):\n")
|
||||
|
||||
for report_data in all_reports:
|
||||
decision_indicator = {
|
||||
"BUY": "[BUY]",
|
||||
"SELL": "[SELL]",
|
||||
"HOLD": "[HOLD]",
|
||||
}.get(report_data["decision"], "[N/A]")
|
||||
|
||||
print(f" {report_data['symbol']:6} | {report_data['date']} | {decision_indicator:6} | {len(report_data['reports'])} reports")
|
||||
|
||||
# Determine output path
|
||||
if args.date and args.output == default_output:
|
||||
# Generate dynamic filename from date + symbols (up to 5)
|
||||
symbols = [r["symbol"] for r in all_reports[:5]]
|
||||
symbols_str = "_".join(symbols)
|
||||
output_path = Path(f"./results/trading_report_{args.date}_{symbols_str}.pdf")
|
||||
else:
|
||||
output_path = Path(args.output)
|
||||
|
||||
print("\nGenerating PDF...")
|
||||
|
||||
html_document = build_html_document(all_reports)
|
||||
|
||||
if compile_to_pdf(html_document, output_path):
|
||||
print(f"\n+ PDF created: {output_path}")
|
||||
else:
|
||||
print("\n- Failed to create PDF")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
55
cli/main.py
55
cli/main.py
|
|
@ -429,10 +429,12 @@ def get_user_selections():
|
|||
box_content += f"\n[dim]Default: {default}[/dim]"
|
||||
return Panel(box_content, border_style="blue", padding=(1, 2))
|
||||
|
||||
# Step 1: Ticker symbol
|
||||
# Step 1: Ticker symbol(s)
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY"
|
||||
"Step 1: Ticker Symbol(s)",
|
||||
"Enter ticker symbol(s) to analyze (comma-separated for multiple)",
|
||||
"SPY",
|
||||
)
|
||||
)
|
||||
selected_ticker = get_ticker()
|
||||
|
|
@ -475,13 +477,20 @@ def get_user_selections():
|
|||
)
|
||||
selected_llm_provider, backend_url = select_llm_provider()
|
||||
|
||||
# Step 6: Thinking agents
|
||||
# Step 6: Quick-Thinking LLM Engine
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 6: Thinking Agents", "Select your thinking agents for analysis"
|
||||
"Step 6: Quick-Thinking LLM Engine", "Select your quick-thinking model for fast operations"
|
||||
)
|
||||
)
|
||||
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
||||
|
||||
# Step 7: Deep-Thinking LLM Engine
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 7: Deep-Thinking LLM Engine", "Select your deep-thinking model for complex reasoning"
|
||||
)
|
||||
)
|
||||
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
|
||||
|
||||
return {
|
||||
|
|
@ -497,8 +506,11 @@ def get_user_selections():
|
|||
|
||||
|
||||
def get_ticker():
|
||||
"""Get ticker symbol from user input."""
|
||||
return typer.prompt("", default="SPY")
|
||||
"""Get ticker symbol(s) from user input. Supports comma-separated symbols."""
|
||||
raw_input = typer.prompt("", default="SPY")
|
||||
# Split by comma, strip whitespace, convert to uppercase
|
||||
symbols = [s.strip().upper() for s in raw_input.split(",") if s.strip()]
|
||||
return symbols if len(symbols) > 1 else symbols[0]
|
||||
|
||||
|
||||
def get_analysis_date():
|
||||
|
|
@ -736,6 +748,7 @@ def extract_content_string(content):
|
|||
return str(content)
|
||||
|
||||
def run_analysis():
|
||||
"""Run analysis for one or more ticker symbols."""
|
||||
# First get all user selections
|
||||
selections = get_user_selections()
|
||||
|
||||
|
|
@ -748,13 +761,33 @@ def run_analysis():
|
|||
config["backend_url"] = selections["backend_url"]
|
||||
config["llm_provider"] = selections["llm_provider"].lower()
|
||||
|
||||
# Initialize the graph
|
||||
# Normalize ticker(s) to list
|
||||
tickers = selections["ticker"] if isinstance(selections["ticker"], list) else [selections["ticker"]]
|
||||
|
||||
# Initialize the graph once and reuse for all symbols
|
||||
graph = TradingAgentsGraph(
|
||||
[analyst.value for analyst in selections["analysts"]], config=config, debug=True
|
||||
)
|
||||
|
||||
for i, ticker in enumerate(tickers, 1):
|
||||
if len(tickers) > 1:
|
||||
console.print(f"\n[bold cyan]{'═' * 50}[/bold cyan]")
|
||||
console.print(f"[bold cyan] Analyzing {ticker} ({i}/{len(tickers)})[/bold cyan]")
|
||||
console.print(f"[bold cyan]{'═' * 50}[/bold cyan]\n")
|
||||
|
||||
run_single_analysis(ticker, selections, config, graph)
|
||||
|
||||
if i < len(tickers):
|
||||
console.print(f"\n[dim]Moving to next symbol...[/dim]\n")
|
||||
|
||||
if len(tickers) > 1:
|
||||
console.print(f"\n[bold green]Completed analysis for all {len(tickers)} symbols: {', '.join(tickers)}[/bold green]")
|
||||
|
||||
|
||||
def run_single_analysis(ticker: str, selections: dict, config: dict, graph: TradingAgentsGraph):
|
||||
"""Run analysis for a single ticker symbol."""
|
||||
# Create result directory
|
||||
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
||||
results_dir = Path(config["results_dir"]) / ticker / selections["analysis_date"]
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
report_dir = results_dir / "reports"
|
||||
report_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -808,7 +841,7 @@ def run_analysis():
|
|||
update_display(layout)
|
||||
|
||||
# Add initial messages
|
||||
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
|
||||
message_buffer.add_message("System", f"Selected ticker: {ticker}")
|
||||
message_buffer.add_message(
|
||||
"System", f"Analysis date: {selections['analysis_date']}"
|
||||
)
|
||||
|
|
@ -835,13 +868,13 @@ def run_analysis():
|
|||
|
||||
# Create spinner text
|
||||
spinner_text = (
|
||||
f"Analyzing {selections['ticker']} on {selections['analysis_date']}..."
|
||||
f"Analyzing {ticker} on {selections['analysis_date']}..."
|
||||
)
|
||||
update_display(layout, spinner_text)
|
||||
|
||||
# Initialize state and get graph args
|
||||
init_agent_state = graph.propagator.create_initial_state(
|
||||
selections["ticker"], selections["analysis_date"]
|
||||
ticker, selections["analysis_date"]
|
||||
)
|
||||
args = graph.propagator.get_graph_args()
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,225 @@
|
|||
"""Dynamic model fetching from LLM provider APIs with caching."""
|
||||
|
||||
import os
|
||||
from typing import List, Tuple, Optional
|
||||
import httpx
|
||||
|
||||
# Cache for fetched models (provider -> list of models)
|
||||
_model_cache: dict = {}
|
||||
|
||||
# Maximum number of models to display (None = no limit, show all)
|
||||
MAX_MODELS = None
|
||||
|
||||
|
||||
def is_fetch_latest() -> bool:
|
||||
"""Check if FETCH_LATEST is enabled in environment.
|
||||
|
||||
When enabled, fetches models dynamically from provider APIs.
|
||||
When disabled, falls back to static hardcoded model lists.
|
||||
"""
|
||||
return os.getenv("FETCH_LATEST", "false").lower() in ("true", "1", "yes")
|
||||
|
||||
|
||||
def fetch_openai_models() -> Optional[List[Tuple[str, str]]]:
|
||||
"""
|
||||
Fetch available models from OpenAI API, sorted by creation date (newest first).
|
||||
|
||||
Returns:
|
||||
List of (display_name, model_id) tuples, or None on failure
|
||||
"""
|
||||
if "openai" in _model_cache:
|
||||
return _model_cache["openai"]
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key or api_key.startswith("sk-or-"):
|
||||
return None
|
||||
|
||||
try:
|
||||
response = httpx.get(
|
||||
"https://api.openai.com/v1/models",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
timeout=10.0
|
||||
)
|
||||
response.raise_for_status()
|
||||
models_data = response.json().get("data", [])
|
||||
|
||||
# Filter to chat/reasoning models and keep metadata for sorting
|
||||
chat_models = []
|
||||
for model in models_data:
|
||||
model_id = model.get("id", "")
|
||||
created = model.get("created", 0)
|
||||
|
||||
# Include GPT models and reasoning models (o-series)
|
||||
if (model_id.startswith("gpt-") or
|
||||
model_id.startswith("o1") or
|
||||
model_id.startswith("o3") or
|
||||
model_id.startswith("o4") or
|
||||
model_id.startswith("o5") or
|
||||
model_id.startswith("gpt-5")):
|
||||
# Skip snapshot/dated versions to keep list clean
|
||||
if "-20" not in model_id and "-preview" not in model_id.lower():
|
||||
chat_models.append((model_id, created))
|
||||
|
||||
# Remove duplicates (keep highest created timestamp for each model_id)
|
||||
model_dict = {}
|
||||
for model_id, created in chat_models:
|
||||
if model_id not in model_dict or created > model_dict[model_id]:
|
||||
model_dict[model_id] = created
|
||||
|
||||
# Sort by created timestamp (newest first) and limit
|
||||
sorted_models = sorted(model_dict.items(), key=lambda x: -x[1])[:MAX_MODELS]
|
||||
result = [(model_id, model_id) for model_id, _ in sorted_models]
|
||||
|
||||
_model_cache["openai"] = result
|
||||
return result
|
||||
except (httpx.RequestError, httpx.HTTPStatusError, ValueError, KeyError):
|
||||
return None
|
||||
|
||||
|
||||
def fetch_anthropic_models() -> Optional[List[Tuple[str, str]]]:
|
||||
"""
|
||||
Fetch available models from Anthropic API, sorted by creation date (newest first).
|
||||
|
||||
Returns:
|
||||
List of (display_name, model_id) tuples, or None on failure
|
||||
"""
|
||||
if "anthropic" in _model_cache:
|
||||
return _model_cache["anthropic"]
|
||||
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
response = httpx.get(
|
||||
"https://api.anthropic.com/v1/models",
|
||||
headers={
|
||||
"x-api-key": api_key,
|
||||
"anthropic-version": "2023-06-01"
|
||||
},
|
||||
timeout=10.0
|
||||
)
|
||||
response.raise_for_status()
|
||||
models_data = response.json().get("data", [])
|
||||
|
||||
# Filter to Claude models and keep metadata for sorting
|
||||
claude_models = []
|
||||
for model in models_data:
|
||||
model_id = model.get("id", "")
|
||||
# Anthropic API returns created_at as ISO string (RFC 3339)
|
||||
created_at = model.get("created_at", "")
|
||||
display_name = model.get("display_name", "")
|
||||
|
||||
if model_id.startswith("claude-"):
|
||||
# Skip dated versions (e.g., claude-3-sonnet-20240229)
|
||||
if "-20" not in model_id:
|
||||
# Use display_name if available, otherwise model_id
|
||||
label = display_name if display_name else model_id
|
||||
claude_models.append((model_id, label, created_at))
|
||||
|
||||
# Remove duplicates (keep latest for each model_id)
|
||||
model_dict = {}
|
||||
for model_id, label, created_at in claude_models:
|
||||
if model_id not in model_dict or created_at > model_dict[model_id][1]:
|
||||
model_dict[model_id] = (label, created_at)
|
||||
|
||||
# Sort by created_at (newest first) and limit
|
||||
sorted_models = sorted(model_dict.items(), key=lambda x: x[1][1], reverse=True)[:MAX_MODELS]
|
||||
result = [(label, model_id) for model_id, (label, _) in sorted_models]
|
||||
|
||||
_model_cache["anthropic"] = result
|
||||
return result
|
||||
except (httpx.RequestError, httpx.HTTPStatusError, ValueError, KeyError):
|
||||
return None
|
||||
|
||||
|
||||
def fetch_google_models() -> Optional[List[Tuple[str, str]]]:
|
||||
"""
|
||||
Fetch available models from Google Generative AI API.
|
||||
Uses displayName for user-friendly labels, sorted as returned by API (typically newest first).
|
||||
|
||||
Returns:
|
||||
List of (display_name, model_id) tuples, or None on failure
|
||||
"""
|
||||
if "google" in _model_cache:
|
||||
return _model_cache["google"]
|
||||
|
||||
api_key = os.getenv("GEMINI_API_KEY")
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
response = httpx.get(
|
||||
f"https://generativelanguage.googleapis.com/v1/models?key={api_key}",
|
||||
timeout=10.0
|
||||
)
|
||||
response.raise_for_status()
|
||||
models_data = response.json().get("models", [])
|
||||
|
||||
# Filter to Gemini models that support generateContent
|
||||
gemini_models = []
|
||||
for model in models_data:
|
||||
model_name = model.get("name", "")
|
||||
display_name = model.get("displayName", "")
|
||||
supported_methods = model.get("supportedGenerationMethods", [])
|
||||
|
||||
# Extract model ID from "models/gemini-..." format
|
||||
if model_name.startswith("models/"):
|
||||
model_id = model_name.replace("models/", "")
|
||||
else:
|
||||
model_id = model_name
|
||||
|
||||
# Only include Gemini models that support content generation
|
||||
if model_id.startswith("gemini") and "generateContent" in supported_methods:
|
||||
# Use displayName if available, otherwise model_id
|
||||
label = display_name if display_name else model_id
|
||||
gemini_models.append((label, model_id))
|
||||
|
||||
# API returns in a reasonable order, just dedupe and limit
|
||||
seen = set()
|
||||
unique_models = []
|
||||
for label, model_id in gemini_models:
|
||||
if model_id not in seen:
|
||||
seen.add(model_id)
|
||||
unique_models.append((label, model_id))
|
||||
|
||||
result = unique_models[:MAX_MODELS]
|
||||
|
||||
_model_cache["google"] = result
|
||||
return result
|
||||
except (httpx.RequestError, httpx.HTTPStatusError, ValueError, KeyError):
|
||||
return None
|
||||
|
||||
|
||||
def fetch_models_for_provider(provider: str) -> Optional[List[Tuple[str, str]]]:
|
||||
"""
|
||||
Fetch models for a given provider.
|
||||
|
||||
Only fetches dynamically if FETCH_LATEST is enabled. Otherwise returns None
|
||||
to trigger fallback to static model lists.
|
||||
|
||||
Args:
|
||||
provider: Provider name (openai, anthropic, google)
|
||||
|
||||
Returns:
|
||||
List of (display_name, model_id) tuples, or None if not supported/failed
|
||||
"""
|
||||
# Return None if FETCH_LATEST is not enabled - will use static lists
|
||||
if not is_fetch_latest():
|
||||
return None
|
||||
|
||||
provider_lower = provider.lower()
|
||||
|
||||
if provider_lower == "openai":
|
||||
return fetch_openai_models()
|
||||
elif provider_lower == "anthropic":
|
||||
return fetch_anthropic_models()
|
||||
elif provider_lower == "google":
|
||||
return fetch_google_models()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def clear_cache():
|
||||
"""Clear the model cache."""
|
||||
_model_cache.clear()
|
||||
106
cli/utils.py
106
cli/utils.py
|
|
@ -1,7 +1,12 @@
|
|||
import questionary
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
from rich.console import Console
|
||||
|
||||
from cli.models import AnalystType
|
||||
from cli.api_keys import is_provider_available
|
||||
from cli.model_fetcher import fetch_models_for_provider
|
||||
|
||||
console = Console()
|
||||
|
||||
ANALYST_ORDER = [
|
||||
("Market Analyst", AnalystType.MARKET),
|
||||
|
|
@ -125,7 +130,7 @@ def select_research_depth() -> int:
|
|||
def select_shallow_thinking_agent(provider) -> str:
|
||||
"""Select shallow thinking llm engine using an interactive selection."""
|
||||
|
||||
# Define shallow thinking llm engine options with their corresponding model names
|
||||
# Static fallback options for each provider
|
||||
SHALLOW_AGENT_OPTIONS = {
|
||||
"openai": [
|
||||
("GPT-4o-mini - Fast and efficient for quick tasks", "gpt-4o-mini"),
|
||||
|
|
@ -142,24 +147,43 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||
"google": [
|
||||
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
|
||||
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
|
||||
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
|
||||
("Gemini 2.5 Flash-Lite - Lightweight and cost efficient", "gemini-2.5-flash-lite"),
|
||||
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash"),
|
||||
("Gemini 3 Flash Preview - Latest generation flash model", "gemini-3-flash-preview"),
|
||||
],
|
||||
"openrouter": [
|
||||
("Xiaomi MiMo V2 Flash - Fast and efficient multimodal model", "xiaomi/mimo-v2-flash:free"),
|
||||
("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"),
|
||||
("Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B", "meta-llama/llama-3.3-8b-instruct:free"),
|
||||
("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"),
|
||||
],
|
||||
"ollama": [
|
||||
("llama3.1 local", "llama3.1"),
|
||||
("llama3.2 local", "llama3.2"),
|
||||
("llama3.2:3b local", "llama3.2:3b"),
|
||||
("phi3.5 local", "phi3.5:latest"),
|
||||
],
|
||||
"lm studio": [
|
||||
("Local Model (default)", "local-model"),
|
||||
]
|
||||
}
|
||||
|
||||
provider_lower = provider.lower()
|
||||
|
||||
# Try dynamic fetch for supported providers (OpenAI, Anthropic, Google)
|
||||
model_options = None
|
||||
if provider_lower in ["openai", "anthropic", "google"]:
|
||||
dynamic_models = fetch_models_for_provider(provider_lower)
|
||||
if dynamic_models:
|
||||
model_options = dynamic_models
|
||||
|
||||
# Fall back to static list if dynamic fetch failed or not supported
|
||||
if model_options is None:
|
||||
model_options = SHALLOW_AGENT_OPTIONS.get(provider_lower, [])
|
||||
|
||||
choice = questionary.select(
|
||||
"Select Your [Quick-Thinking LLM Engine]:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value)
|
||||
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()]
|
||||
for display, value in model_options
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
|
|
@ -183,7 +207,7 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||
def select_deep_thinking_agent(provider) -> str:
|
||||
"""Select deep thinking llm engine using an interactive selection."""
|
||||
|
||||
# Define deep thinking llm engine options with their corresponding model names
|
||||
# Static fallback options for each provider
|
||||
DEEP_AGENT_OPTIONS = {
|
||||
"openai": [
|
||||
("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
|
||||
|
|
@ -199,29 +223,46 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"),
|
||||
("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"),
|
||||
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
|
||||
("Claude Opus 4 - Most powerful Anthropic model", " claude-opus-4-0"),
|
||||
("Claude Opus 4 - Most powerful Anthropic model", "claude-opus-4-0"),
|
||||
],
|
||||
"google": [
|
||||
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
|
||||
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
|
||||
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
|
||||
("Gemini 2.5 Pro", "gemini-2.5-pro-preview-06-05"),
|
||||
("Gemini 2.5 Flash-Lite - Lightweight and cost efficient", "gemini-2.5-flash-lite"),
|
||||
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash"),
|
||||
("Gemini 3 Flash Preview - Latest generation flash model", "gemini-3-flash-preview"),
|
||||
],
|
||||
"openrouter": [
|
||||
("Xiaomi MiMo V2 Flash - Fast and efficient multimodal model", "xiaomi/mimo-v2-flash:free"),
|
||||
("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"),
|
||||
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"),
|
||||
],
|
||||
"ollama": [
|
||||
("llama3.1 local", "llama3.1"),
|
||||
("qwen3", "qwen3"),
|
||||
("llama3.2:3b local", "llama3.2:3b"),
|
||||
("phi3.5 local", "phi3.5:latest"),
|
||||
],
|
||||
"lm studio": [
|
||||
("Local Model (default)", "local-model"),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
provider_lower = provider.lower()
|
||||
|
||||
# Try dynamic fetch for supported providers (OpenAI, Anthropic, Google)
|
||||
model_options = None
|
||||
if provider_lower in ["openai", "anthropic", "google"]:
|
||||
dynamic_models = fetch_models_for_provider(provider_lower)
|
||||
if dynamic_models:
|
||||
model_options = dynamic_models
|
||||
|
||||
# Fall back to static list if dynamic fetch failed or not supported
|
||||
if model_options is None:
|
||||
model_options = DEEP_AGENT_OPTIONS.get(provider_lower, [])
|
||||
|
||||
choice = questionary.select(
|
||||
"Select Your [Deep-Thinking LLM Engine]:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value)
|
||||
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
|
||||
for display, value in model_options
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
|
|
@ -240,22 +281,35 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
return choice
|
||||
|
||||
def select_llm_provider() -> tuple[str, str]:
|
||||
"""Select the OpenAI api url using interactive selection."""
|
||||
# Define OpenAI api options with their corresponding endpoints
|
||||
"""Select the LLM provider using interactive selection with availability checks."""
|
||||
# Define provider options with their corresponding endpoints
|
||||
BASE_URLS = [
|
||||
("OpenAI", "https://api.openai.com/v1"),
|
||||
("Anthropic", "https://api.anthropic.com/"),
|
||||
("Google", "https://generativelanguage.googleapis.com/v1"),
|
||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
||||
("Ollama", "http://localhost:11434/v1"),
|
||||
("Ollama", "http://localhost:11434/v1"),
|
||||
("LM Studio", "http://localhost:1234/v1"),
|
||||
]
|
||||
|
||||
|
||||
# Build choices with availability status
|
||||
choices = []
|
||||
for display, url in BASE_URLS:
|
||||
available, reason = is_provider_available(display)
|
||||
if available:
|
||||
choices.append(questionary.Choice(display, value=(display, url)))
|
||||
else:
|
||||
# Show disabled option with reason
|
||||
disabled_label = f"{display} ({reason})"
|
||||
choices.append(questionary.Choice(
|
||||
disabled_label,
|
||||
value=(display, url),
|
||||
disabled=reason
|
||||
))
|
||||
|
||||
choice = questionary.select(
|
||||
"Select your LLM Provider:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=(display, value))
|
||||
for display, value in BASE_URLS
|
||||
],
|
||||
choices=choices,
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
[
|
||||
|
|
@ -265,12 +319,12 @@ def select_llm_provider() -> tuple[str, str]:
|
|||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
|
||||
if choice is None:
|
||||
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
|
||||
console.print("\n[red]No LLM provider selected. Exiting...[/red]")
|
||||
exit(1)
|
||||
|
||||
|
||||
display_name, url = choice
|
||||
print(f"You selected: {display_name}\tURL: {url}")
|
||||
|
||||
|
||||
return display_name, url
|
||||
|
|
|
|||
|
|
@ -24,3 +24,5 @@ rich
|
|||
questionary
|
||||
langchain_anthropic
|
||||
langchain-google-genai
|
||||
playwright
|
||||
markdown2
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
from tradingagents.agents.utils.agent_utils import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement, get_insider_sentiment, get_insider_transactions
|
||||
from tradingagents.agents.utils.agent_utils import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement, get_insider_sentiment, get_insider_transactions, normalize_content
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -53,7 +53,7 @@ def create_fundamentals_analyst(llm):
|
|||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
report = normalize_content(result.content)
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
from tradingagents.agents.utils.agent_utils import get_stock_data, get_indicators
|
||||
from tradingagents.agents.utils.agent_utils import get_stock_data, get_indicators, normalize_content
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -75,7 +75,7 @@ Volume-Based Indicators:
|
|||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
report = normalize_content(result.content)
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
from tradingagents.agents.utils.agent_utils import get_news, get_global_news
|
||||
from tradingagents.agents.utils.agent_utils import get_news, get_global_news, normalize_content
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ def create_news_analyst(llm):
|
|||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
report = normalize_content(result.content)
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
from tradingagents.agents.utils.agent_utils import get_news
|
||||
from tradingagents.agents.utils.agent_utils import get_news, normalize_content
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -49,7 +49,7 @@ def create_social_media_analyst(llm):
|
|||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
report = normalize_content(result.content)
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
|
|
|
|||
|
|
@ -1,5 +1,20 @@
|
|||
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||
|
||||
|
||||
def normalize_content(content):
|
||||
"""Normalize LLM response content to string.
|
||||
|
||||
Gemini returns content as a list of dicts with 'text' keys,
|
||||
while OpenAI/Anthropic return a simple string.
|
||||
"""
|
||||
if isinstance(content, list):
|
||||
return "".join(
|
||||
block.get("text", "") if isinstance(block, dict) else str(block)
|
||||
for block in content
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
# Import tools from separate utility files
|
||||
from tradingagents.agents.utils.core_stock_tools import (
|
||||
get_stock_data
|
||||
|
|
|
|||
|
|
@ -5,6 +5,38 @@ from typing import Dict, Optional
|
|||
_config: Optional[Dict] = None
|
||||
DATA_DIR: Optional[str] = None
|
||||
|
||||
# Local LLM providers that don't support OpenAI's web_search_preview
|
||||
LOCAL_LLM_PROVIDERS = ["ollama", "lm studio"]
|
||||
# Methods that require OpenAI's web_search_preview tool
|
||||
OPENAI_ONLY_METHODS = ["get_news", "get_global_news", "get_fundamentals"]
|
||||
|
||||
|
||||
def validate_config(config: Dict):
|
||||
"""Validate configuration and warn about incompatible settings."""
|
||||
llm_provider = config.get("llm_provider", "").lower()
|
||||
|
||||
if llm_provider in LOCAL_LLM_PROVIDERS:
|
||||
# Check data vendors
|
||||
data_vendors = config.get("data_vendors", {})
|
||||
tool_vendors = config.get("tool_vendors", {})
|
||||
|
||||
warnings = []
|
||||
if data_vendors.get("news_data") == "openai":
|
||||
warnings.append("data_vendors.news_data")
|
||||
if data_vendors.get("fundamental_data") == "openai":
|
||||
warnings.append("data_vendors.fundamental_data")
|
||||
|
||||
for method in OPENAI_ONLY_METHODS:
|
||||
if tool_vendors.get(method) == "openai":
|
||||
warnings.append(f"tool_vendors.{method}")
|
||||
|
||||
if warnings:
|
||||
print(f"WARNING: Using local LLM provider '{llm_provider}' with 'openai' data vendors.")
|
||||
print(f" The following settings use OpenAI's web_search_preview which is not available locally:")
|
||||
for w in warnings:
|
||||
print(f" - {w}")
|
||||
print(f" Recommendation: Change these to 'alpha_vantage', 'google', or 'local'.")
|
||||
|
||||
|
||||
def initialize_config():
|
||||
"""Initialize the configuration with default values."""
|
||||
|
|
@ -21,6 +53,7 @@ def set_config(config: Dict):
|
|||
_config = default_config.DEFAULT_CONFIG.copy()
|
||||
_config.update(config)
|
||||
DATA_DIR = _config["data_dir"]
|
||||
validate_config(_config)
|
||||
|
||||
|
||||
def get_config() -> Dict:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Annotated
|
||||
import time
|
||||
|
||||
# Import from vendor-specific modules
|
||||
from .local import get_YFin_data, get_finnhub_news, get_finnhub_company_insider_sentiment, get_finnhub_company_insider_transactions, get_simfin_balance_sheet, get_simfin_cashflow, get_simfin_income_statements, get_reddit_global_news, get_reddit_company_news
|
||||
|
|
@ -16,6 +17,7 @@ from .alpha_vantage import (
|
|||
get_news as get_alpha_vantage_news
|
||||
)
|
||||
from .alpha_vantage_common import AlphaVantageRateLimitError
|
||||
from openai import APIConnectionError, APITimeoutError, RateLimitError
|
||||
|
||||
# Configuration and routing logic
|
||||
from .config import get_config
|
||||
|
|
@ -194,25 +196,58 @@ def route_to_vendor(method: str, *args, **kwargs):
|
|||
else:
|
||||
vendor_methods = [(vendor_impl, vendor)]
|
||||
|
||||
# Run methods for this vendor
|
||||
# Run methods for this vendor with retry logic
|
||||
vendor_results = []
|
||||
for impl_func, vendor_name in vendor_methods:
|
||||
try:
|
||||
print(f"DEBUG: Calling {impl_func.__name__} from vendor '{vendor_name}'...")
|
||||
result = impl_func(*args, **kwargs)
|
||||
vendor_results.append(result)
|
||||
print(f"SUCCESS: {impl_func.__name__} from vendor '{vendor_name}' completed successfully")
|
||||
|
||||
except AlphaVantageRateLimitError as e:
|
||||
if vendor == "alpha_vantage":
|
||||
print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded, falling back to next available vendor")
|
||||
max_retries = 3
|
||||
base_delay = 1.0
|
||||
last_error = None
|
||||
|
||||
for retry_attempt in range(max_retries):
|
||||
try:
|
||||
if retry_attempt > 0:
|
||||
print(f"RETRY: Attempt {retry_attempt + 1}/{max_retries} "
|
||||
f"for {impl_func.__name__}")
|
||||
else:
|
||||
print(f"DEBUG: Calling {impl_func.__name__} "
|
||||
f"from vendor '{vendor_name}'...")
|
||||
|
||||
result = impl_func(*args, **kwargs)
|
||||
vendor_results.append(result)
|
||||
print(f"SUCCESS: {impl_func.__name__} from vendor "
|
||||
f"'{vendor_name}' completed successfully")
|
||||
last_error = None
|
||||
break # Success, exit retry loop
|
||||
|
||||
except (AlphaVantageRateLimitError, RateLimitError) as e:
|
||||
print(f"RATE_LIMIT: {type(e).__name__} exceeded, falling back to next vendor.")
|
||||
print(f"DEBUG: Rate limit details: {e}")
|
||||
# Continue to next vendor for fallback
|
||||
continue
|
||||
except Exception as e:
|
||||
# Log error but continue with other implementations
|
||||
print(f"FAILED: {impl_func.__name__} from vendor '{vendor_name}' failed: {e}")
|
||||
continue
|
||||
last_error = e
|
||||
break # Don't retry rate limits, move to next vendor
|
||||
|
||||
except (ConnectionError, TimeoutError, OSError,
|
||||
APIConnectionError, APITimeoutError) as e:
|
||||
# Transient errors - retry with backoff
|
||||
last_error = e
|
||||
if retry_attempt < max_retries - 1:
|
||||
delay = base_delay * (2 ** retry_attempt)
|
||||
print(f"TRANSIENT_ERROR: {type(e).__name__} - {e}")
|
||||
print(f"RETRY: Waiting {delay}s before retry...")
|
||||
time.sleep(delay)
|
||||
else:
|
||||
print(f"FAILED: {impl_func.__name__} from vendor "
|
||||
f"'{vendor_name}' failed after {max_retries} "
|
||||
f"attempts: {e}")
|
||||
|
||||
except Exception as e:
|
||||
# Non-transient errors - don't retry
|
||||
last_error = e
|
||||
print(f"FAILED: {impl_func.__name__} from vendor "
|
||||
f"'{vendor_name}' failed: {type(e).__name__}: {e}")
|
||||
break
|
||||
|
||||
if last_error is not None:
|
||||
continue # Move to next implementation
|
||||
|
||||
# Add this vendor's results
|
||||
if vendor_results:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,46 @@
|
|||
import os
|
||||
|
||||
from openai import OpenAI
|
||||
from .config import get_config
|
||||
|
||||
|
||||
def _get_web_search_tool_type() -> str:
|
||||
"""Return the appropriate web search tool type based on FETCH_LATEST setting.
|
||||
|
||||
- FETCH_LATEST=true: Use 'web_search' (GA version, supports GPT-5)
|
||||
- FETCH_LATEST=false/unset: Use 'web_search_preview' (legacy, wider compatibility)
|
||||
"""
|
||||
fetch_latest = os.getenv("FETCH_LATEST", "false").lower() in ("true", "1", "yes")
|
||||
return "web_search" if fetch_latest else "web_search_preview"
|
||||
|
||||
|
||||
def _extract_text_from_response(response):
|
||||
"""Safely extract text content from OpenAI Responses API output.
|
||||
|
||||
The response.output array typically contains:
|
||||
- output[0]: ResponseFunctionWebSearch (the web search call)
|
||||
- output[1]: ResponseOutputMessage (the text response)
|
||||
|
||||
This function handles edge cases where the structure may differ.
|
||||
"""
|
||||
if not response.output:
|
||||
raise RuntimeError("OpenAI response has empty output")
|
||||
|
||||
# Look for a message with text content
|
||||
for item in response.output:
|
||||
if hasattr(item, 'content') and item.content:
|
||||
for content_block in item.content:
|
||||
if hasattr(content_block, 'text') and content_block.text:
|
||||
return content_block.text
|
||||
|
||||
# If we get here, no text was found
|
||||
output_types = [type(item).__name__ for item in response.output]
|
||||
raise RuntimeError(
|
||||
f"No text content found in OpenAI response. "
|
||||
f"Output types: {output_types}"
|
||||
)
|
||||
|
||||
|
||||
def get_stock_news_openai(query, start_date, end_date):
|
||||
config = get_config()
|
||||
client = OpenAI(base_url=config["backend_url"])
|
||||
|
|
@ -23,7 +62,7 @@ def get_stock_news_openai(query, start_date, end_date):
|
|||
reasoning={},
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"type": _get_web_search_tool_type(),
|
||||
"user_location": {"type": "approximate"},
|
||||
"search_context_size": "low",
|
||||
}
|
||||
|
|
@ -34,7 +73,7 @@ def get_stock_news_openai(query, start_date, end_date):
|
|||
store=True,
|
||||
)
|
||||
|
||||
return response.output[1].content[0].text
|
||||
return _extract_text_from_response(response)
|
||||
|
||||
|
||||
def get_global_news_openai(curr_date, look_back_days=7, limit=5):
|
||||
|
|
@ -58,7 +97,7 @@ def get_global_news_openai(curr_date, look_back_days=7, limit=5):
|
|||
reasoning={},
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"type": _get_web_search_tool_type(),
|
||||
"user_location": {"type": "approximate"},
|
||||
"search_context_size": "low",
|
||||
}
|
||||
|
|
@ -69,7 +108,7 @@ def get_global_news_openai(curr_date, look_back_days=7, limit=5):
|
|||
store=True,
|
||||
)
|
||||
|
||||
return response.output[1].content[0].text
|
||||
return _extract_text_from_response(response)
|
||||
|
||||
|
||||
def get_fundamentals_openai(ticker, curr_date):
|
||||
|
|
@ -93,7 +132,7 @@ def get_fundamentals_openai(ticker, curr_date):
|
|||
reasoning={},
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"type": _get_web_search_tool_type(),
|
||||
"user_location": {"type": "approximate"},
|
||||
"search_context_size": "low",
|
||||
}
|
||||
|
|
@ -104,4 +143,4 @@ def get_fundamentals_openai(ticker, curr_date):
|
|||
store=True,
|
||||
)
|
||||
|
||||
return response.output[1].content[0].text
|
||||
return _extract_text_from_response(response)
|
||||
|
|
@ -12,6 +12,8 @@ from langchain_google_genai import ChatGoogleGenerativeAI
|
|||
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from tradingagents.llm import requires_responses_api, ChatOpenAIResponses
|
||||
|
||||
from tradingagents.agents import *
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||
|
|
@ -72,9 +74,20 @@ class TradingAgentsGraph:
|
|||
)
|
||||
|
||||
# Initialize LLMs
|
||||
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
|
||||
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
||||
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
||||
if self.config["llm_provider"].lower() in ["openai", "ollama", "openrouter", "lm studio"]:
|
||||
# Select LLM class based on model - newer models require Responses API
|
||||
deep_model = self.config["deep_think_llm"]
|
||||
quick_model = self.config["quick_think_llm"]
|
||||
|
||||
if requires_responses_api(deep_model):
|
||||
self.deep_thinking_llm = ChatOpenAIResponses(model=deep_model, base_url=self.config["backend_url"])
|
||||
else:
|
||||
self.deep_thinking_llm = ChatOpenAI(model=deep_model, base_url=self.config["backend_url"])
|
||||
|
||||
if requires_responses_api(quick_model):
|
||||
self.quick_thinking_llm = ChatOpenAIResponses(model=quick_model, base_url=self.config["backend_url"])
|
||||
else:
|
||||
self.quick_thinking_llm = ChatOpenAI(model=quick_model, base_url=self.config["backend_url"])
|
||||
elif self.config["llm_provider"].lower() == "anthropic":
|
||||
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
||||
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
"""LLM wrapper utilities for TradingAgents.
|
||||
|
||||
This module provides custom LLM wrappers for different API endpoints.
|
||||
"""
|
||||
|
||||
from tradingagents.llm.model_utils import requires_responses_api
|
||||
from tradingagents.llm.openai_responses import ChatOpenAIResponses
|
||||
|
||||
__all__ = ["requires_responses_api", "ChatOpenAIResponses"]
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
"""Utility functions for model detection and selection."""
|
||||
|
||||
# Model prefixes that require the OpenAI Responses API (/v1/responses)
|
||||
# instead of the Chat Completions API (/v1/chat/completions)
|
||||
RESPONSES_API_PREFIXES = [
|
||||
"gpt-5", # All GPT-5 variants (gpt-5, gpt-5.1, gpt-5.1-codex-mini, etc.)
|
||||
"codex", # Codex models that use Responses API
|
||||
]
|
||||
|
||||
|
||||
def requires_responses_api(model_name: str) -> bool:
|
||||
"""Check if a model requires the Responses API instead of Chat Completions.
|
||||
|
||||
Some newer OpenAI models only support the /v1/responses endpoint and will
|
||||
return a 404 error if called via /v1/chat/completions.
|
||||
|
||||
Args:
|
||||
model_name: The model identifier (e.g., "gpt-5.1-codex-mini", "gpt-4o")
|
||||
|
||||
Returns:
|
||||
True if the model requires the Responses API, False otherwise.
|
||||
"""
|
||||
model_lower = model_name.lower()
|
||||
return any(prefix in model_lower for prefix in RESPONSES_API_PREFIXES)
|
||||
|
|
@ -0,0 +1,307 @@
|
|||
"""LangChain-compatible wrapper for OpenAI's Responses API.
|
||||
|
||||
This module provides ChatOpenAIResponses, a drop-in replacement for ChatOpenAI
|
||||
that uses the /v1/responses endpoint instead of /v1/chat/completions.
|
||||
|
||||
This is required for newer models like gpt-5.1-codex-mini that only support
|
||||
the Responses API.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.tools import BaseTool
|
||||
from openai import OpenAI
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class ChatOpenAIResponses(BaseChatModel):
|
||||
"""LangChain-compatible chat model using OpenAI's Responses API.
|
||||
|
||||
This class provides the same interface as ChatOpenAI but uses the
|
||||
/v1/responses endpoint, which is required for certain newer models.
|
||||
|
||||
Example:
|
||||
>>> llm = ChatOpenAIResponses(model="gpt-5.1-codex-mini")
|
||||
>>> llm_with_tools = llm.bind_tools([my_tool])
|
||||
>>> result = llm_with_tools.invoke([HumanMessage(content="Hello")])
|
||||
"""
|
||||
|
||||
model: str = Field(default="gpt-5.1-codex-mini")
|
||||
base_url: Optional[str] = Field(default=None)
|
||||
api_key: Optional[str] = Field(default=None)
|
||||
temperature: float = Field(default=1.0)
|
||||
max_output_tokens: int = Field(default=4096)
|
||||
top_p: float = Field(default=1.0)
|
||||
|
||||
# Internal state for tool binding
|
||||
_bound_tools: List[Dict[str, Any]] = []
|
||||
_client: Optional[OpenAI] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._bound_tools = []
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "openai-responses"
|
||||
|
||||
@property
|
||||
def client(self) -> OpenAI:
|
||||
"""Lazily initialize the OpenAI client."""
|
||||
if self._client is None:
|
||||
api_key = self.api_key or os.getenv("OPENAI_API_KEY")
|
||||
if self.base_url:
|
||||
self._client = OpenAI(api_key=api_key, base_url=self.base_url)
|
||||
else:
|
||||
self._client = OpenAI(api_key=api_key)
|
||||
return self._client
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> "ChatOpenAIResponses":
|
||||
"""Bind tools to this model instance.
|
||||
|
||||
Args:
|
||||
tools: A sequence of tools to bind. Can be LangChain tools or dicts.
|
||||
|
||||
Returns:
|
||||
A new ChatOpenAIResponses instance with the tools bound.
|
||||
"""
|
||||
new_instance = ChatOpenAIResponses(
|
||||
model=self.model,
|
||||
base_url=self.base_url,
|
||||
api_key=self.api_key,
|
||||
temperature=self.temperature,
|
||||
max_output_tokens=self.max_output_tokens,
|
||||
top_p=self.top_p,
|
||||
)
|
||||
new_instance._bound_tools = self._convert_tools(tools)
|
||||
return new_instance
|
||||
|
||||
def _convert_tools(
|
||||
self, tools: Sequence[Union[Dict[str, Any], BaseTool]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert LangChain tools to OpenAI Responses API function format.
|
||||
|
||||
The Responses API uses a flat structure for function tools:
|
||||
{
|
||||
"type": "function",
|
||||
"name": "function_name",
|
||||
"description": "...",
|
||||
"parameters": {...}
|
||||
}
|
||||
|
||||
This differs from Chat Completions which nests under "function" key.
|
||||
"""
|
||||
converted = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, BaseTool):
|
||||
# Get the JSON schema for parameters
|
||||
if tool.args_schema:
|
||||
params = tool.args_schema.model_json_schema()
|
||||
# Remove title field that OpenAI doesn't expect at schema level
|
||||
params.pop("title", None)
|
||||
else:
|
||||
params = {"type": "object", "properties": {}}
|
||||
|
||||
# Responses API uses flat structure - name at top level, not nested
|
||||
tool_schema = {
|
||||
"type": "function",
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"parameters": params,
|
||||
}
|
||||
converted.append(tool_schema)
|
||||
elif isinstance(tool, dict):
|
||||
# Handle dict format - convert from Chat Completions format if needed
|
||||
if "function" in tool:
|
||||
# Chat Completions format - flatten it
|
||||
func = tool["function"]
|
||||
tool_schema = {
|
||||
"type": "function",
|
||||
"name": func.get("name", ""),
|
||||
"description": func.get("description", ""),
|
||||
"parameters": func.get("parameters", {"type": "object", "properties": {}}),
|
||||
}
|
||||
converted.append(tool_schema)
|
||||
elif "name" in tool:
|
||||
# Already in Responses API format
|
||||
converted.append(tool)
|
||||
else:
|
||||
# Unknown format, try to use as-is
|
||||
converted.append(tool)
|
||||
return converted
|
||||
|
||||
def _convert_messages(
|
||||
self, messages: List[BaseMessage]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert LangChain messages to OpenAI Responses API format.
|
||||
|
||||
The Responses API uses a different message format than Chat Completions:
|
||||
- System/user messages use 'input_text' content type
|
||||
- Assistant messages use 'output_text' content type (no function_call in content)
|
||||
- Tool calls from assistant are represented as separate 'function_call' items
|
||||
- Tool results use 'function_call_output' content type
|
||||
"""
|
||||
converted = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, SystemMessage):
|
||||
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
converted.append({
|
||||
"role": "system",
|
||||
"content": [{"type": "input_text", "text": content}],
|
||||
})
|
||||
elif isinstance(msg, HumanMessage):
|
||||
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
converted.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": content}],
|
||||
})
|
||||
elif isinstance(msg, AIMessage):
|
||||
# Handle AI messages (assistant responses)
|
||||
# First add text content if present
|
||||
if msg.content:
|
||||
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
converted.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": content}],
|
||||
})
|
||||
|
||||
# Tool calls need to be added as separate items in the Responses API
|
||||
if hasattr(msg, 'tool_calls') and msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
# Convert args to JSON string for the API
|
||||
args = tc.get("args", {})
|
||||
if isinstance(args, dict):
|
||||
args_str = json.dumps(args)
|
||||
else:
|
||||
args_str = str(args)
|
||||
|
||||
# Add tool call as a separate item (not inside assistant content)
|
||||
converted.append({
|
||||
"type": "function_call",
|
||||
"call_id": tc.get("id", str(uuid.uuid4())),
|
||||
"name": tc["name"],
|
||||
"arguments": args_str,
|
||||
})
|
||||
elif not msg.content:
|
||||
# Empty assistant message - add placeholder
|
||||
converted.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": ""}],
|
||||
})
|
||||
elif isinstance(msg, ToolMessage):
|
||||
# Tool results need to be formatted as function call outputs
|
||||
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
converted.append({
|
||||
"type": "function_call_output",
|
||||
"call_id": msg.tool_call_id,
|
||||
"output": content,
|
||||
})
|
||||
return converted
|
||||
|
||||
def _parse_response(self, response: Any) -> AIMessage:
|
||||
"""Parse OpenAI Responses API response into LangChain AIMessage."""
|
||||
text_content = ""
|
||||
tool_calls = []
|
||||
|
||||
if not response.output:
|
||||
return AIMessage(content="")
|
||||
|
||||
for item in response.output:
|
||||
# Handle text output
|
||||
if hasattr(item, 'content') and item.content:
|
||||
for block in item.content:
|
||||
if hasattr(block, 'text') and block.text:
|
||||
text_content += block.text
|
||||
|
||||
# Handle function/tool calls
|
||||
if hasattr(item, 'type') and item.type == 'function_call':
|
||||
args = item.arguments
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
args = {"raw": args}
|
||||
|
||||
tool_calls.append({
|
||||
"id": getattr(item, 'id', None) or getattr(item, 'call_id', None) or str(uuid.uuid4()),
|
||||
"name": item.name,
|
||||
"args": args,
|
||||
})
|
||||
|
||||
if tool_calls:
|
||||
return AIMessage(content=text_content, tool_calls=tool_calls)
|
||||
return AIMessage(content=text_content)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Generate a response using the OpenAI Responses API.
|
||||
|
||||
Args:
|
||||
messages: List of LangChain messages to send.
|
||||
stop: Optional stop sequences (not used by Responses API).
|
||||
run_manager: Optional callback manager.
|
||||
|
||||
Returns:
|
||||
ChatResult containing the model's response.
|
||||
"""
|
||||
# Convert messages to Responses API format
|
||||
converted_messages = self._convert_messages(messages)
|
||||
|
||||
# Build request parameters
|
||||
request_params = {
|
||||
"model": self.model,
|
||||
"input": converted_messages,
|
||||
"temperature": self.temperature,
|
||||
"max_output_tokens": self.max_output_tokens,
|
||||
"top_p": self.top_p,
|
||||
}
|
||||
|
||||
# Add tools if bound
|
||||
if self._bound_tools:
|
||||
request_params["tools"] = self._bound_tools
|
||||
|
||||
# Make the API call
|
||||
response = self.client.responses.create(**request_params)
|
||||
|
||||
# Parse the response
|
||||
ai_message = self._parse_response(response)
|
||||
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=ai_message)],
|
||||
)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Return identifying parameters for this LLM."""
|
||||
return {
|
||||
"model": self.model,
|
||||
"base_url": self.base_url,
|
||||
"temperature": self.temperature,
|
||||
"max_output_tokens": self.max_output_tokens,
|
||||
}
|
||||
Loading…
Reference in New Issue