# 451 lines · 16 KiB · Python
import glob
import json
import os
import time
import uuid
from datetime import datetime
from typing import Any, Dict, Optional

from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from starlette.concurrency import run_in_threadpool


# Populate os.environ from the nearest .env file, when python-dotenv is present.
try:
    from dotenv import find_dotenv, load_dotenv

    _dotenv_path = find_dotenv()
    if not _dotenv_path:
        # No .env file was located; the process environment is used as-is.
        pass
    else:
        load_dotenv(_dotenv_path)
except Exception:
    # python-dotenv is an optional extra: any failure here (including it not
    # being installed) leaves the environment untouched.
    pass


# Put the directory two levels above this file on sys.path so the
# `tradingagents` package can be imported from here.
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))

from tradingagents.graph.trading_graph import TradingAgentsGraph  # noqa: E402
from tradingagents.default_config import DEFAULT_CONFIG  # noqa: E402


app = FastAPI(title="TradingAgents API", version="1.0.0", debug=True)

# Root folder that holds one sub-directory per symbol; every /results and
# /transformed-results endpoint resolves files beneath it.
RESULTS_BASE = os.path.join(os.path.dirname(__file__), "..", "..", "output_data")

# Startup diagnostic: tell the operator whether an OpenAI key is available.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
print(
    "[INFO] OPENAI_API_KEY detected from environment."
    if OPENAI_API_KEY
    else "[WARN] OPENAI_API_KEY is not set. Set it in your shell or in a .env file."
)

# Let the local React dev server (port 3000) call this API from the browser.
_CORS_SETTINGS = dict(
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)


# Pydantic models


class AnalysisRequest(BaseModel):
    """Request body for POST /analysis/start."""

    symbol: str  # ticker; the endpoint upper-cases and strips it
    date: str  # analysis date; the endpoint requires YYYY-MM-DD
    config_overrides: Optional[Dict[str, Any]] = None  # merged over DEFAULT_CONFIG


class AnalysisResponse(BaseModel):
    """Acknowledgement returned once an analysis job has been queued."""

    job_id: str
    status: str
    message: str


class JobStatus(BaseModel):
    """Tracked state of a single background analysis job."""

    job_id: str
    start_time: Optional[float] = None  # time.time() when the task began
    status: str  # set to "queued", "running", "completed" or "failed"
    progress: Optional[str] = None  # human-readable progress message
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    # Live TradingAgentsGraph instance, kept so a later reflection request can
    # reuse its memories. It is not JSON-serializable, so it is excluded from
    # every dump; otherwise returning a JobStatus from an endpoint with
    # response_model=JobStatus (e.g. /analysis/status/{job_id}) fails once the
    # agent has been attached.
    trading_agent: Any = Field(default=None, exclude=True)


# In-memory job storage (in production, use Redis or database)
jobs: Dict[str, JobStatus] = {}


@app.get("/")
async def root():
    """Return a fixed banner confirming the API process is serving requests."""
    banner = {"message": "TradingAgents API is running"}
    return banner


@app.get("/health")
async def health_check():
    """Report a static healthy status plus the server's local time in ISO 8601."""
    now_iso = datetime.now().isoformat()
    return {"status": "healthy", "timestamp": now_iso}


async def run_analysis_task(
    job_id: str,
    symbol: str,
    analysis_date: str,
    config_overrides: Optional[Dict[str, Any]] = None,
):
    """Run one TradingAgents analysis in the background and record its outcome.

    The blocking ``TradingAgentsGraph.propagate`` call runs in a worker thread
    so the event loop stays responsive. Status, progress, result and error are
    written onto ``jobs[job_id]`` as the task advances. On both success and
    failure, ``result`` holds symbol, date, decision (None if no decision was
    produced), completion timestamp and execution time.

    Args:
        job_id: Key of an existing entry in ``jobs`` to update.
        symbol: Ticker symbol to analyze.
        analysis_date: Date string forwarded to ``propagate``.
        config_overrides: Optional keys merged over a copy of ``DEFAULT_CONFIG``.
    """
    job = jobs[job_id]
    job.start_time = time.time()  # also the reference point for execution_time()

    def execution_time() -> Optional[str]:
        """Elapsed run time for a finished (completed/failed) job, else None."""
        if job.status in ("completed", "failed"):
            return f"{time.time() - job.start_time:.2f} seconds"
        return None

    def record_result(outcome: Any) -> None:
        """Store the final result payload; call after the status is final."""
        job.result = {
            "symbol": symbol,
            "date": analysis_date,
            "decision": outcome,
            "completed_at": datetime.now().isoformat(),
            "execution_time": execution_time(),
        }

    # Bound up front so the failure path can always reference it. Before this
    # fix, an early failure left it unbound and the handler raised NameError.
    decision = None
    try:
        job.status = "running"
        job.progress = "Initializing TradingAgents..."

        # Shallow copy: overrides replace top-level keys of the copy only;
        # nested values are still shared with DEFAULT_CONFIG.
        config = DEFAULT_CONFIG.copy()
        if config_overrides:
            config.update(config_overrides)

        job.progress = "Setting up trading graph..."
        ta = TradingAgentsGraph(debug=True, config=config)
        # Keep the live graph so /reflect-on-analysis can use its memories later.
        job.trading_agent = ta
        job.progress = f"Analyzing {symbol} for {analysis_date}..."

        _, decision = await run_in_threadpool(ta.propagate, symbol, analysis_date)

        job.status = "completed"
        record_result(decision)
        job.progress = "Analysis completed successfully"
    except Exception as e:
        job.status = "failed"
        job.error = str(e)
        job.progress = f"Error: {e}"
        record_result(decision)


@app.post("/analysis/start", response_model=AnalysisResponse)
async def start_analysis(request: AnalysisRequest, background_tasks: BackgroundTasks):
    """Validate the request, register a queued job and schedule the analysis.

    The symbol is upper-cased and stripped; the date is stripped and must be
    YYYY-MM-DD. The job is stored in ``jobs`` with status "queued", and
    ``run_analysis_task`` is scheduled as a FastAPI background task.

    Raises:
        HTTPException: 400 if the symbol is empty after stripping or the date
            is not in YYYY-MM-DD format.
    """
    symbol = request.symbol.upper().strip()
    date = request.date.strip()

    # An empty ticker would still launch a full (costly) analysis run.
    if not symbol:
        raise HTTPException(status_code=400, detail="Symbol must not be empty")

    try:
        datetime.strptime(date, "%Y-%m-%d")
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid date format. Use YYYY-MM-DD") from None

    job_id = str(uuid.uuid4())
    jobs[job_id] = JobStatus(
        job_id=job_id,
        status="queued",
        progress="Analysis queued",
    )

    background_tasks.add_task(
        run_analysis_task,
        job_id,
        symbol,
        date,
        request.config_overrides or {},
    )

    return AnalysisResponse(
        job_id=job_id,
        status="queued",
        message=f"Analysis started for {symbol} on {date}",
    )


@app.get("/analysis/status/{job_id}", response_model=JobStatus)
async def get_analysis_status(job_id: str):
    """Return the stored state of one analysis job; 404 if the id is unknown."""
    job = jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    return job


@app.get("/results/companies")
async def get_companies():
    """List symbol directories under RESULTS_BASE that contain stored analyses.

    For each sub-directory, counts ``*.json`` files in
    ``TradingAgentsStrategy_logs`` and ``*_transformed.json`` files in
    ``TradingAgentsStrategy_transformed_logs``. A symbol is listed only if
    either count is non-zero. ``latest_analysis`` is the date part of the raw
    log filename with the newest ctime, or None when there are no raw logs.
    """
    if not os.path.exists(RESULTS_BASE):
        return {"companies": []}

    def _matching_files(directory, pattern):
        # A missing directory simply contributes no files.
        if not os.path.exists(directory):
            return []
        return glob.glob(os.path.join(directory, pattern))

    companies = []
    for entry in os.listdir(RESULTS_BASE):
        entry_path = os.path.join(RESULTS_BASE, entry)
        if not os.path.isdir(entry_path):
            continue

        raw_logs = _matching_files(
            os.path.join(entry_path, "TradingAgentsStrategy_logs"), "*.json"
        )
        transformed_logs = _matching_files(
            os.path.join(entry_path, "TradingAgentsStrategy_transformed_logs"),
            "*_transformed.json",
        )
        if not raw_logs and not transformed_logs:
            continue

        newest_date = None
        if raw_logs:
            newest_file = max(raw_logs, key=os.path.getctime)
            newest_date = (
                os.path.basename(newest_file)
                .replace("full_states_log_", "")
                .replace(".json", "")
            )

        companies.append(
            {
                "symbol": entry,
                "latest_analysis": newest_date,
                "total_analyses": len(raw_logs),
                "transformed_analyses": len(transformed_logs),
            }
        )

    return {"companies": companies}


@app.get("/results/{symbol}")
async def get_company_results(symbol: str):
    """Summarize every raw analysis log stored for one symbol, newest ctime first.

    Each entry carries the date parsed from the filename, the filename, its
    size and mtime, and a preview (top-level keys and length of ``str(data)``).
    A file that cannot be read or parsed yields an entry with an ``error``
    message instead.

    Raises:
        HTTPException: 404 if the symbol has no ``TradingAgentsStrategy_logs`` directory.
    """
    ticker = symbol.upper()
    logs_dir = os.path.join(RESULTS_BASE, ticker, "TradingAgentsStrategy_logs")
    if not os.path.exists(logs_dir):
        raise HTTPException(status_code=404, detail=f"No results found for {symbol}")

    newest_first = sorted(
        glob.glob(os.path.join(logs_dir, "*.json")),
        key=os.path.getctime,
        reverse=True,
    )

    def _summarize(path):
        name = os.path.basename(path)
        day = name.replace("full_states_log_", "").replace(".json", "")
        try:
            with open(path, 'r') as fh:
                payload = json.load(fh)
            return {
                "date": day,
                "filename": name,
                "file_size": os.path.getsize(path),
                "modified_at": datetime.fromtimestamp(os.path.getmtime(path)).isoformat(),
                "preview": {
                    "keys": list(payload.keys()) if isinstance(payload, dict) else "Not a dict",
                    "size": len(str(payload)),
                },
            }
        except Exception as exc:
            return {
                "date": day,
                "filename": name,
                "error": f"Could not read file: {str(exc)}",
            }

    return {"symbol": ticker, "results": [_summarize(p) for p in newest_first]}


@app.get("/transformed-results/{symbol}")
async def get_transformed_company_results(symbol: str):
    """Summarize every transformed analysis file for one symbol, newest ctime first.

    Filenames look like ``full_states_log_<date>_transformed.json``; ``<date>``
    becomes the entry's ``date``. The preview pulls ticker, recommendation and
    confidence from ``metadata`` ("N/A" if absent) and ``current_price`` from
    ``financial_data`` (0 if absent). Unreadable files yield an ``error`` entry.

    Raises:
        HTTPException: 404 if the symbol has no transformed-logs directory.
    """
    ticker = symbol.upper()
    logs_dir = os.path.join(RESULTS_BASE, ticker, "TradingAgentsStrategy_transformed_logs")
    if not os.path.exists(logs_dir):
        raise HTTPException(status_code=404, detail=f"No transformed results found for {symbol}")

    newest_first = sorted(
        glob.glob(os.path.join(logs_dir, "*_transformed.json")),
        key=os.path.getctime,
        reverse=True,
    )

    def _summarize(path):
        name = os.path.basename(path)
        day = name.replace("full_states_log_", "").replace("_transformed.json", "")
        try:
            with open(path, 'r') as fh:
                payload = json.load(fh)
            size = os.path.getsize(path)
            modified = datetime.fromtimestamp(os.path.getmtime(path)).isoformat()
            meta = payload.get("metadata", {})
            financials = payload.get("financial_data", {})
            return {
                "date": day,
                "filename": name,
                "file_size": size,
                "modified_at": modified,
                "preview": {
                    "company_ticker": meta.get("company_ticker", "N/A"),
                    "final_recommendation": meta.get("final_recommendation", "N/A"),
                    "confidence_level": meta.get("confidence_level", "N/A"),
                    "current_price": financials.get("current_price", 0),
                },
            }
        except Exception as exc:
            return {
                "date": day,
                "filename": name,
                "error": f"Could not read file: {str(exc)}",
            }

    return {"symbol": ticker, "results": [_summarize(p) for p in newest_first]}


@app.get("/results/{symbol}/{date}")
async def get_specific_result(symbol: str, date: str):
    """Return the full raw analysis log for one symbol and date, plus file stats.

    Raises:
        HTTPException: 404 if ``full_states_log_<date>.json`` does not exist;
            500 if it cannot be read or parsed.
    """
    ticker = symbol.upper()
    log_path = os.path.join(
        RESULTS_BASE, ticker, "TradingAgentsStrategy_logs", f"full_states_log_{date}.json"
    )
    if not os.path.exists(log_path):
        raise HTTPException(status_code=404, detail=f"No result found for {symbol} on {date}")

    try:
        with open(log_path, 'r') as fh:
            payload = json.load(fh)
        file_stats = {
            "file_size": os.path.getsize(log_path),
            "modified_at": datetime.fromtimestamp(os.path.getmtime(log_path)).isoformat(),
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error reading result: {str(exc)}")

    return {"symbol": ticker, "date": date, "data": payload, "metadata": file_stats}


@app.get("/transformed-results/{symbol}/{date}")
async def get_specific_transformed_result(symbol: str, date: str):
    """Return one transformed analysis file for a symbol and date, plus file info.

    Raises:
        HTTPException: 404 if ``full_states_log_<date>_transformed.json`` does
            not exist; 500 if it cannot be read or parsed.
    """
    ticker = symbol.upper()
    target = os.path.join(
        RESULTS_BASE,
        ticker,
        "TradingAgentsStrategy_transformed_logs",
        f"full_states_log_{date}_transformed.json",
    )
    if not os.path.exists(target):
        raise HTTPException(status_code=404, detail=f"Transformed result not found for {symbol} on {date}")

    try:
        with open(target, 'r') as fh:
            payload = json.load(fh)
        info = {
            "filename": os.path.basename(target),
            "file_size": os.path.getsize(target),
            "modified_at": datetime.fromtimestamp(os.path.getmtime(target)).isoformat(),
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error reading file: {str(exc)}")

    return {"symbol": ticker, "date": date, "data": payload, "file_info": info}


@app.get("/jobs")
async def get_jobs():
    """List every in-memory job with its status, progress, result and error.

    ``execution_time`` is added only for finished (completed/failed) jobs that
    have a start time. The duration stored by the background task in
    ``job.result["execution_time"]`` is used when present. Recomputing
    ``time.time() - start_time`` on every request made a finished job's duration
    keep growing, so that computation is only a fallback when no duration was
    recorded.
    """
    job_list = []
    for job_id, job in jobs.items():
        entry = {
            "job_id": job_id,
            "status": job.status,
            "progress": job.progress,
            "result": job.result,
            "error": job.error,
        }

        start_time = getattr(job, "start_time", None)
        if start_time is not None and job.status in ("completed", "failed"):
            recorded = (job.result or {}).get("execution_time")
            entry["execution_time"] = recorded or f"{time.time() - start_time:.2f} seconds"

        job_list.append(entry)

    return {"jobs": job_list}


@app.post("/reflect-on-analysis/{symbol}/{date}")
async def reflect_on_analysis(symbol: str, date: str, request: dict):
    """Run reflection for a completed analysis and return each agent's latest memory.

    Finds the first completed in-memory job whose result matches ``symbol``
    (compared upper-cased) and ``date`` and that still holds its trading agent.
    It then calls ``reflect_and_remember(returns_losses)`` on that agent and
    collects ``get_latest_situation()`` from its bull, bear, trader,
    invest-judge and risk-manager memories.

    Request body: JSON object with a required ``returns_losses`` key.

    Raises:
        HTTPException: 400 if ``returns_losses`` is missing; 404 if no matching
            job or no memory is found; 500 if reflection or memory lookup fails.
    """
    returns_losses = request.get("returns_losses")
    if returns_losses is None:
        raise HTTPException(status_code=400, detail="returns_losses is required in request body")

    wanted_symbol = symbol.upper()
    matching_job = next(
        (
            job
            for job in jobs.values()
            # Queued/running jobs have result=None; filtering them out avoids
            # the AttributeError that previously turned every lookup into a 500.
            if job.status == "completed"
            and job.result
            and job.result.get("symbol") == wanted_symbol
            and job.result.get("date") == date
            and job.trading_agent
        ),
        None,
    )
    if matching_job is None:
        raise HTTPException(status_code=404, detail=f"No active job found for {symbol} on {date}")

    agent = matching_job.trading_agent
    # NOTE(review): confirm TradingAgentsGraph actually exposes `memory`;
    # otherwise this check always answers 404.
    if not getattr(agent, "memory", None):
        raise HTTPException(status_code=404, detail="No memory found for this analysis")

    try:
        # reflect_and_remember is synchronous; run it off the event loop, as the
        # analysis itself is. NOTE: concurrent reflections on the same job are
        # not serialized.
        await run_in_threadpool(agent.reflect_and_remember, returns_losses)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during reflection: {e}") from e

    try:
        reflections = {
            name: getattr(agent, name).get_latest_situation()
            for name in (
                "bull_memory",
                "bear_memory",
                "trader_memory",
                "invest_judge_memory",
                "risk_manager_memory",
            )
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error retrieving latest situation: {str(e)}") from e

    return {
        "symbol": wanted_symbol,
        "date": date,
        "job_id": matching_job.job_id,
        "reflections": reflections,
    }


@app.get("/config")
async def get_default_config():
    """Expose DEFAULT_CONFIG, the base onto which per-request overrides are merged."""
    return dict(config=DEFAULT_CONFIG)


if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 8000, with reload.
    # NOTE(review): "main:app" assumes this module is importable as `main`.
    import uvicorn

    _server_options = dict(host="0.0.0.0", port=8000, reload=True)
    uvicorn.run("main:app", **_server_options)