# TradingAgents/frontend/backend/server.py
"""FastAPI server for Nifty50 AI recommendations."""
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import database as db
import sys
import os
from pathlib import Path
from datetime import datetime
import threading
# Add parent directories to path for importing trading agents
# (this file lives at <repo>/frontend/backend/, so three .parent hops reach the repo root).
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
# Track running analyses, keyed by stock symbol; mutated by background worker threads.
running_analyses = {} # {symbol: {"status": "running", "started_at": datetime, "progress": str}}
app = FastAPI(
    title="Nifty50 AI API",
    description="API for Nifty 50 stock recommendations",
    version="1.0.0"
)
# Enable CORS for frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], # In production, replace with specific origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class StockAnalysis(BaseModel):
    """Analysis result for a single stock within a daily recommendation."""
    symbol: str  # stock ticker, e.g. "RELIANCE"
    company_name: str
    decision: Optional[str] = None  # presumably BUY/SELL/HOLD (see Summary counts) — confirm
    confidence: Optional[str] = None
    risk: Optional[str] = None
    raw_analysis: Optional[str] = None  # full unparsed analysis text
class TopPick(BaseModel):
    """A ranked pick inside a daily recommendation."""
    rank: int
    symbol: str
    company_name: str
    decision: str
    reason: str
    risk_level: str
class StockToAvoid(BaseModel):
    """A stock flagged as one to avoid, with the reasoning."""
    symbol: str
    company_name: str
    reason: str
class Summary(BaseModel):
    """Decision counts across all analyzed stocks for one day."""
    total: int
    buy: int
    sell: int
    hold: int
class DailyRecommendation(BaseModel):
    """Complete recommendation payload for one trading day."""
    date: str  # YYYY-MM-DD
    analysis: dict[str, StockAnalysis]  # presumably keyed by symbol — verify against writers
    summary: Summary
    top_picks: list[TopPick]
    stocks_to_avoid: list[StockToAvoid]
class SaveRecommendationRequest(BaseModel):
    """Request body for POST /recommendations.

    Uses loose container types (dict/list) rather than the typed models,
    so callers may submit arbitrarily shaped analysis payloads.
    """
    date: str
    analysis: dict
    summary: dict
    top_picks: list
    stocks_to_avoid: list
# ============== Pipeline Data Models ==============
class AgentReport(BaseModel):
    """One agent's report produced by the analysis pipeline."""
    agent_type: str
    report_content: str
    # Mutable default is safe here: Pydantic deep-copies field defaults per instance.
    data_sources_used: Optional[list] = []
    created_at: Optional[str] = None
class DebateHistory(BaseModel):
    """Transcript of one agent debate (investment or risk) for a stock."""
    debate_type: str
    bull_arguments: Optional[str] = None
    bear_arguments: Optional[str] = None
    risky_arguments: Optional[str] = None
    safe_arguments: Optional[str] = None
    neutral_arguments: Optional[str] = None
    judge_decision: Optional[str] = None
    full_history: Optional[str] = None  # raw combined transcript
class PipelineStep(BaseModel):
    """Status record for one step of the analysis pipeline."""
    step_number: int
    step_name: str
    status: str
    started_at: Optional[str] = None  # ISO timestamp
    completed_at: Optional[str] = None
    duration_ms: Optional[int] = None
    output_summary: Optional[str] = None
class DataSourceLog(BaseModel):
    """Record of a single external data fetch performed by the pipeline."""
    source_type: str
    source_name: str
    data_fetched: Optional[dict] = None
    fetch_timestamp: Optional[str] = None
    success: bool = True
    error_message: Optional[str] = None  # populated when success is False — confirm
class SavePipelineDataRequest(BaseModel):
    """Request body for POST /pipeline; all pipeline sections are optional."""
    date: str
    symbol: str
    agent_reports: Optional[dict] = None
    investment_debate: Optional[dict] = None
    risk_debate: Optional[dict] = None
    pipeline_steps: Optional[list] = None
    data_sources: Optional[list] = None
class AnalysisConfig(BaseModel):
    """LLM model/provider settings for an analysis run."""
    deep_think_model: Optional[str] = "opus"
    quick_think_model: Optional[str] = "sonnet"
    provider: Optional[str] = "claude_subscription" # claude_subscription or anthropic_api
    api_key: Optional[str] = None  # only consulted when provider == "anthropic_api"
    max_debate_rounds: Optional[int] = 1
class RunAnalysisRequest(BaseModel):
    """Request body for triggering a single-stock analysis."""
    symbol: str
    date: Optional[str] = None # Defaults to today if not provided
    config: Optional[AnalysisConfig] = None
def run_analysis_task(symbol: str, date: str, analysis_config: Optional[dict] = None):
    """Background task: run the trading-agents analysis for one stock.

    Progress is published through the module-level ``running_analyses``
    registry, ending in a ``completed`` or ``error`` status for ``symbol``.

    Args:
        symbol: Stock ticker symbol (e.g. "RELIANCE").
        date: Analysis date, YYYY-MM-DD.
        analysis_config: Optional dict with keys ``deep_think_model``,
            ``quick_think_model``, ``provider``, ``api_key`` and
            ``max_debate_rounds``; missing keys fall back to defaults.
    """
    global running_analyses
    # Default config values
    if analysis_config is None:
        analysis_config = {}
    deep_think_model = analysis_config.get("deep_think_model", "opus")
    quick_think_model = analysis_config.get("quick_think_model", "sonnet")
    provider = analysis_config.get("provider", "claude_subscription")
    api_key = analysis_config.get("api_key")
    max_debate_rounds = analysis_config.get("max_debate_rounds", 1)
    started_at = datetime.now().isoformat()
    try:
        running_analyses[symbol] = {
            "status": "initializing",
            "started_at": started_at,
            "progress": "Loading trading agents..."
        }
        # Imported lazily so the server can start even when the agents package
        # is unavailable; an import failure surfaces as an "error" status.
        from tradingagents.graph.trading_graph import TradingAgentsGraph
        from tradingagents.default_config import DEFAULT_CONFIG
        running_analyses[symbol]["progress"] = "Initializing analysis pipeline..."
        # Create config from user settings
        config = DEFAULT_CONFIG.copy()
        config["llm_provider"] = "anthropic"  # Use Claude for all LLM
        config["deep_think_llm"] = deep_think_model
        config["quick_think_llm"] = quick_think_model
        config["max_debate_rounds"] = max_debate_rounds
        # If using API provider and key is provided, expose it via environment
        if provider == "anthropic_api" and api_key:
            os.environ["ANTHROPIC_API_KEY"] = api_key
        running_analyses[symbol]["status"] = "running"
        running_analyses[symbol]["progress"] = f"Running market analysis (model: {deep_think_model})..."
        # Initialize and run
        ta = TradingAgentsGraph(debug=False, config=config)
        running_analyses[symbol]["progress"] = f"Analyzing {symbol}..."
        final_state, decision = ta.propagate(symbol, date)
        running_analyses[symbol] = {
            "status": "completed",
            # Keep the start time (previously dropped) so callers can compute duration.
            "started_at": started_at,
            "completed_at": datetime.now().isoformat(),
            "progress": f"Analysis complete: {decision}",
            "decision": decision
        }
    except Exception as e:
        error_msg = str(e) if str(e) else f"{type(e).__name__}: No details provided"
        running_analyses[symbol] = {
            "status": "error",
            # Keep the start time (previously dropped) for diagnostics.
            "started_at": started_at,
            "error": error_msg,
            "progress": f"Error: {error_msg[:100]}"
        }
        import traceback
        print(f"Analysis error for {symbol}: {type(e).__name__}: {error_msg}")
        traceback.print_exc()
@app.get("/")
async def root():
    """Describe the API: name, version, and the available endpoints."""
    endpoint_docs = {
        "GET /recommendations": "Get all recommendations",
        "GET /recommendations/latest": "Get latest recommendation",
        "GET /recommendations/{date}": "Get recommendation by date",
        "GET /recommendations/{date}/{symbol}/pipeline": "Get full pipeline data for a stock",
        "GET /recommendations/{date}/{symbol}/agents": "Get agent reports for a stock",
        "GET /recommendations/{date}/{symbol}/debates": "Get debate history for a stock",
        "GET /recommendations/{date}/{symbol}/data-sources": "Get data source logs for a stock",
        "GET /recommendations/{date}/pipeline-summary": "Get pipeline summary for all stocks on a date",
        "GET /stocks/{symbol}/history": "Get stock history",
        "GET /dates": "Get all available dates",
        "POST /recommendations": "Save a new recommendation",
        "POST /pipeline": "Save pipeline data for a stock"
    }
    return {
        "name": "Nifty50 AI API",
        "version": "2.0.0",
        "endpoints": endpoint_docs
    }
@app.get("/recommendations")
async def get_all_recommendations():
    """Return every stored daily recommendation, with a total count."""
    all_recs = db.get_all_recommendations()
    return {"recommendations": all_recs, "count": len(all_recs)}
@app.get("/recommendations/latest")
async def get_latest_recommendation():
    """Return the most recent recommendation; 404 when none exist."""
    latest = db.get_latest_recommendation()
    if latest:
        return latest
    raise HTTPException(status_code=404, detail="No recommendations found")
@app.get("/recommendations/{date}")
async def get_recommendation_by_date(date: str):
    """Return the recommendation stored for *date* (YYYY-MM-DD); 404 if absent."""
    result = db.get_recommendation_by_date(date)
    if not result:
        raise HTTPException(status_code=404, detail=f"No recommendation found for {date}")
    return result
@app.get("/stocks/{symbol}/history")
async def get_stock_history(symbol: str):
    """Return all historical recommendations recorded for one stock."""
    ticker = symbol.upper()
    past_recs = db.get_stock_history(ticker)
    return {"symbol": ticker, "history": past_recs, "count": len(past_recs)}
@app.get("/dates")
async def get_available_dates():
    """List every date that has a stored recommendation."""
    all_dates = db.get_all_dates()
    return {"dates": all_dates, "count": len(all_dates)}
@app.post("/recommendations")
async def save_recommendation(request: SaveRecommendationRequest):
    """Persist a full daily recommendation; responds 500 with the error text on failure."""
    try:
        db.save_recommendation(
            date=request.date,
            analysis_data=request.analysis,
            summary=request.summary,
            top_picks=request.top_picks,
            stocks_to_avoid=request.stocks_to_avoid,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"message": f"Recommendation for {request.date} saved successfully"}
@app.get("/health")
async def health_check():
    """Liveness probe.

    NOTE(review): the "database" value is a hard-coded string — no actual
    connectivity check is performed; confirm whether one is intended.
    """
    return {"status": "healthy", "database": "connected"}
# ============== Pipeline Data Endpoints ==============
@app.get("/recommendations/{date}/{symbol}/pipeline")
async def get_pipeline_data(date: str, symbol: str):
    """Return the complete pipeline record for one stock on one date.

    Falls back to an empty "no_data" payload when nothing was stored.
    """
    ticker = symbol.upper()
    pipeline_data = db.get_full_pipeline_data(date, ticker)
    sections = ('agent_reports', 'debates', 'pipeline_steps', 'data_sources')
    if any(pipeline_data.get(name) for name in sections):
        return {**pipeline_data, "status": "complete"}
    # Nothing stored: empty structure so the frontend still renders.
    return {
        "date": date,
        "symbol": ticker,
        "agent_reports": {},
        "debates": {},
        "pipeline_steps": [],
        "data_sources": [],
        "status": "no_data"
    }
@app.get("/recommendations/{date}/{symbol}/agents")
async def get_agent_reports(date: str, symbol: str):
    """List the stored agent reports for a stock on one date."""
    ticker = symbol.upper()
    agent_reports = db.get_agent_reports(date, ticker)
    return {
        "date": date,
        "symbol": ticker,
        "reports": agent_reports,
        "count": len(agent_reports)
    }
@app.get("/recommendations/{date}/{symbol}/debates")
async def get_debate_history(date: str, symbol: str):
    """Fetch the stored debate transcripts for a stock on one date."""
    ticker = symbol.upper()
    return {
        "date": date,
        "symbol": ticker,
        "debates": db.get_debate_history(date, ticker)
    }
@app.get("/recommendations/{date}/{symbol}/data-sources")
async def get_data_sources(date: str, symbol: str):
    """Fetch data-source fetch logs for a stock on one date."""
    ticker = symbol.upper()
    source_logs = db.get_data_source_logs(date, ticker)
    return {
        "date": date,
        "symbol": ticker,
        "data_sources": source_logs,
        "count": len(source_logs)
    }
@app.get("/recommendations/{date}/pipeline-summary")
async def get_pipeline_summary(date: str):
    """Summarize pipeline progress for every stock analyzed on *date*."""
    per_stock = db.get_pipeline_summary_for_date(date)
    return {"date": date, "stocks": per_stock, "count": len(per_stock)}
@app.post("/pipeline")
async def save_pipeline_data(request: SavePipelineDataRequest):
    """Persist pipeline artifacts (reports, debates, steps, source logs) for a stock."""
    payload = {
        'agent_reports': request.agent_reports,
        'investment_debate': request.investment_debate,
        'risk_debate': request.risk_debate,
        'pipeline_steps': request.pipeline_steps,
        'data_sources': request.data_sources
    }
    try:
        db.save_full_pipeline_data(
            date=request.date,
            symbol=request.symbol.upper(),
            pipeline_data=payload
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"message": f"Pipeline data for {request.symbol} on {request.date} saved successfully"}
# ============== Analysis Endpoints ==============
# Track bulk analysis state; mutated in place by the bulk-analysis worker thread.
bulk_analysis_state = {
    "status": "idle", # idle, running, completed, error
    "total": 0,
    "completed": 0,
    "failed": 0,
    "current_symbol": None,
    "started_at": None,
    "completed_at": None,
    "results": {}  # per-symbol terminal status, e.g. "completed" / "error: ..."
}
# List of Nifty 50 stocks
NIFTY_50_SYMBOLS = [
    "RELIANCE", "TCS", "HDFCBANK", "INFY", "ICICIBANK", "HINDUNILVR", "ITC", "SBIN",
    "BHARTIARTL", "KOTAKBANK", "LT", "AXISBANK", "ASIANPAINT", "MARUTI", "HCLTECH",
    "SUNPHARMA", "TITAN", "BAJFINANCE", "WIPRO", "ULTRACEMCO", "NESTLEIND", "NTPC",
    "POWERGRID", "M&M", "TATAMOTORS", "ONGC", "JSWSTEEL", "TATASTEEL", "ADANIENT",
    "ADANIPORTS", "COALINDIA", "BAJAJFINSV", "TECHM", "HDFCLIFE", "SBILIFE", "GRASIM",
    "DIVISLAB", "DRREDDY", "CIPLA", "BRITANNIA", "EICHERMOT", "APOLLOHOSP", "INDUSINDBK",
    "HEROMOTOCO", "TATACONSUM", "BPCL", "UPL", "HINDALCO", "BAJAJ-AUTO", "LTIM"
]
class BulkAnalysisRequest(BaseModel):
    """Config for POST /analyze/all; same fields as AnalysisConfig minus provider docs."""
    deep_think_model: Optional[str] = "opus"
    quick_think_model: Optional[str] = "sonnet"
    provider: Optional[str] = "claude_subscription"
    api_key: Optional[str] = None
    max_debate_rounds: Optional[int] = 1
@app.post("/analyze/all")
async def run_bulk_analysis(request: Optional[BulkAnalysisRequest] = None, date: Optional[str] = None):
    """Trigger analysis for all Nifty 50 stocks. Runs in a background thread.

    Progress is exposed via GET /analyze/all/status. Returns immediately with
    a "started" payload, or the current state if a run is already in flight.

    Args:
        request: Optional model/provider configuration for the run.
        date: Analysis date (YYYY-MM-DD); defaults to today.
    """
    global bulk_analysis_state
    # Refuse to start a second run while one is in progress.
    if bulk_analysis_state.get("status") == "running":
        return {
            "message": "Bulk analysis already running",
            "status": bulk_analysis_state
        }
    # Use today's date if not provided
    if not date:
        date = datetime.now().strftime("%Y-%m-%d")
    # Flatten the request model into the plain dict run_analysis_task expects.
    analysis_config = {}
    if request:
        analysis_config = {
            "deep_think_model": request.deep_think_model,
            "quick_think_model": request.quick_think_model,
            "provider": request.provider,
            "api_key": request.api_key,
            "max_debate_rounds": request.max_debate_rounds
        }

    def run_bulk():
        """Worker: analyze each Nifty 50 symbol sequentially, tracking progress."""
        global bulk_analysis_state
        import time  # hoisted: was re-imported on every loop iteration
        bulk_analysis_state = {
            "status": "running",
            "total": len(NIFTY_50_SYMBOLS),
            "completed": 0,
            "failed": 0,
            "current_symbol": None,
            "started_at": datetime.now().isoformat(),
            "completed_at": None,
            "results": {}
        }
        for symbol in NIFTY_50_SYMBOLS:
            try:
                bulk_analysis_state["current_symbol"] = symbol
                # run_analysis_task executes synchronously in this thread, so it
                # normally reaches a terminal status before returning; the poll
                # below is a safety net in case the status is still "running".
                run_analysis_task(symbol, date, analysis_config)
                while symbol in running_analyses and running_analyses[symbol].get("status") == "running":
                    time.sleep(2)
                if symbol in running_analyses:
                    status = running_analyses[symbol].get("status", "unknown")
                    bulk_analysis_state["results"][symbol] = status
                    if status == "completed":
                        bulk_analysis_state["completed"] += 1
                    else:
                        bulk_analysis_state["failed"] += 1
                else:
                    bulk_analysis_state["results"][symbol] = "unknown"
                    bulk_analysis_state["failed"] += 1
            except Exception as e:
                bulk_analysis_state["results"][symbol] = f"error: {str(e)}"
                bulk_analysis_state["failed"] += 1
        bulk_analysis_state["status"] = "completed"
        bulk_analysis_state["current_symbol"] = None
        bulk_analysis_state["completed_at"] = datetime.now().isoformat()

    thread = threading.Thread(target=run_bulk)
    thread.start()
    return {
        "message": "Bulk analysis started for all Nifty 50 stocks",
        "date": date,
        "total_stocks": len(NIFTY_50_SYMBOLS),
        "status": "started"
    }
@app.get("/analyze/all/status")
async def get_bulk_analysis_status():
    """Report progress of the Nifty-50 bulk analysis run.

    Returns the module-level state dict, which the bulk worker thread
    mutates in place as it progresses.
    """
    return bulk_analysis_state
@app.get("/analyze/running")
async def get_running_analyses():
    """List all analyses whose status is currently "running"."""
    active = {
        sym: state
        for sym, state in running_analyses.items()
        if state.get("status") == "running"
    }
    return {"running": active, "count": len(active)}
class SingleAnalysisRequest(BaseModel):
    """Config for POST /analyze/{symbol}; mirrors BulkAnalysisRequest's fields."""
    deep_think_model: Optional[str] = "opus"
    quick_think_model: Optional[str] = "sonnet"
    provider: Optional[str] = "claude_subscription"
    api_key: Optional[str] = None
    max_debate_rounds: Optional[int] = 1
@app.post("/analyze/{symbol}")
async def run_analysis(symbol: str, background_tasks: BackgroundTasks, request: Optional[SingleAnalysisRequest] = None, date: Optional[str] = None):
    """Kick off a background analysis for one stock and return immediately.

    NOTE(review): the injected BackgroundTasks is unused — a raw thread is
    spawned instead; confirm which mechanism is intended.
    """
    symbol = symbol.upper()
    # Reject duplicate requests for a symbol that is still being analyzed.
    active = running_analyses.get(symbol)
    if active and active.get("status") == "running":
        return {
            "message": f"Analysis already running for {symbol}",
            "status": active
        }
    # Default to today's date when none is supplied.
    date = date or datetime.now().strftime("%Y-%m-%d")
    # Flatten the request model into the plain dict run_analysis_task expects.
    analysis_config = {}
    if request:
        analysis_config = {
            "deep_think_model": request.deep_think_model,
            "quick_think_model": request.quick_think_model,
            "provider": request.provider,
            "api_key": request.api_key,
            "max_debate_rounds": request.max_debate_rounds
        }
    worker = threading.Thread(target=run_analysis_task, args=(symbol, date, analysis_config))
    worker.start()
    return {
        "message": f"Analysis started for {symbol}",
        "symbol": symbol,
        "date": date,
        "status": "started"
    }
@app.get("/analyze/{symbol}/status")
async def get_analysis_status(symbol: str):
    """Return the latest known state of the analysis for *symbol*."""
    ticker = symbol.upper()
    state = running_analyses.get(ticker)
    if state is None:
        return {
            "symbol": ticker,
            "status": "not_started",
            "message": "No analysis has been run for this stock"
        }
    return {"symbol": ticker, **state}
@app.get("/analyze/running")
async def get_running_analyses():
    """Get all currently running analyses.

    NOTE(review): this is an exact duplicate of the earlier GET
    /analyze/running handler defined above. FastAPI matches the first
    registered route, so this second registration is dead code and the
    duplicate `def` silently rebinds the module name — consider removing it.
    """
    running = {k: v for k, v in running_analyses.items() if v.get("status") == "running"}
    return {
        "running": running,
        "count": len(running)
    }
if __name__ == "__main__":
    # Dev entry point: serve the API on all interfaces, port 8001.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)