375 lines
12 KiB
Python
375 lines
12 KiB
Python
"""Analysis CRUD endpoints."""
|
|
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import List, Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from sqlalchemy.orm import Session
|
|
|
|
from api.auth import APIKey, get_current_api_key
|
|
from api.database import Analysis, AnalysisLog, AnalysisReport, get_db
|
|
from api.models import (
|
|
AnalysisResponse,
|
|
AnalysisStatusResponse,
|
|
AnalysisSummary,
|
|
CreateAnalysisRequest,
|
|
LogEntry,
|
|
ReportResponse,
|
|
)
|
|
from api.state_manager import get_executor
|
|
from api.utils import extract_trading_decision
|
|
from tradingagents.default_config import DEFAULT_CONFIG
|
|
|
|
router = APIRouter(prefix="/api/v1/analyses", tags=["analyses"])
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@router.post("", response_model=AnalysisResponse, status_code=status.HTTP_201_CREATED)
|
|
async def create_analysis(
|
|
request: CreateAnalysisRequest,
|
|
db: Session = Depends(get_db),
|
|
api_key: APIKey = Depends(get_current_api_key),
|
|
):
|
|
"""Create and start a new analysis."""
|
|
# Generate analysis ID
|
|
analysis_id = str(uuid.uuid4())
|
|
|
|
# Build configuration
|
|
config = DEFAULT_CONFIG.copy()
|
|
config["max_debate_rounds"] = request.research_depth
|
|
config["max_risk_discuss_rounds"] = request.research_depth
|
|
|
|
if request.llm_provider:
|
|
config["llm_provider"] = request.llm_provider
|
|
if request.backend_url:
|
|
config["backend_url"] = request.backend_url
|
|
if request.quick_think_llm:
|
|
config["quick_think_llm"] = request.quick_think_llm
|
|
if request.deep_think_llm:
|
|
config["deep_think_llm"] = request.deep_think_llm
|
|
|
|
# Auto-detect asset class
|
|
from cli.asset_detection import detect_asset_class
|
|
asset_class = detect_asset_class(request.ticker)
|
|
config["asset_class"] = asset_class
|
|
|
|
# Filter out fundamentals for commodities/crypto
|
|
selected_analysts = request.selected_analysts
|
|
if asset_class in ["commodity", "crypto"] and "fundamentals" in selected_analysts:
|
|
selected_analysts = [a for a in selected_analysts if a != "fundamentals"]
|
|
|
|
# Create database record
|
|
analysis = Analysis(
|
|
id=analysis_id,
|
|
ticker=request.ticker,
|
|
analysis_date=request.analysis_date,
|
|
status="pending",
|
|
config_json=json.dumps(config),
|
|
progress_percentage=0,
|
|
)
|
|
db.add(analysis)
|
|
db.commit()
|
|
db.refresh(analysis)
|
|
|
|
# Start analysis in background
|
|
logger.info(f"Creating analysis {analysis_id} for {request.ticker}")
|
|
executor = get_executor()
|
|
try:
|
|
executor.start_analysis(
|
|
analysis_id=analysis_id,
|
|
ticker=request.ticker,
|
|
analysis_date=request.analysis_date,
|
|
selected_analysts=selected_analysts,
|
|
config=config,
|
|
)
|
|
logger.info(f"Analysis {analysis_id} started successfully")
|
|
except Exception as e:
|
|
logger.error(f"Failed to start analysis {analysis_id}: {str(e)}")
|
|
# Update status to failed
|
|
analysis.status = "failed"
|
|
analysis.error_message = str(e)
|
|
db.commit()
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail=f"Failed to start analysis: {str(e)}",
|
|
)
|
|
|
|
return AnalysisResponse(
|
|
id=analysis.id,
|
|
ticker=analysis.ticker,
|
|
analysis_date=analysis.analysis_date,
|
|
status=analysis.status,
|
|
config=config,
|
|
reports=[],
|
|
progress_percentage=analysis.progress_percentage,
|
|
current_agent=analysis.current_agent,
|
|
created_at=analysis.created_at,
|
|
updated_at=analysis.updated_at,
|
|
completed_at=analysis.completed_at,
|
|
error_message=analysis.error_message,
|
|
)
|
|
|
|
|
|
@router.get("", response_model=List[AnalysisSummary])
|
|
async def list_analyses(
|
|
ticker: Optional[str] = Query(None, description="Filter by ticker"),
|
|
status: Optional[str] = Query(None, description="Filter by status"),
|
|
date_from: Optional[str] = Query(None, description="Filter by date (from)"),
|
|
date_to: Optional[str] = Query(None, description="Filter by date (to)"),
|
|
limit: int = Query(100, ge=1, le=1000, description="Max results"),
|
|
offset: int = Query(0, ge=0, description="Offset for pagination"),
|
|
db: Session = Depends(get_db),
|
|
api_key: APIKey = Depends(get_current_api_key),
|
|
):
|
|
"""List all analyses with optional filtering."""
|
|
query = db.query(Analysis)
|
|
|
|
if ticker:
|
|
query = query.filter(Analysis.ticker == ticker.upper())
|
|
if status:
|
|
query = query.filter(Analysis.status == status)
|
|
if date_from:
|
|
query = query.filter(Analysis.analysis_date >= date_from)
|
|
if date_to:
|
|
query = query.filter(Analysis.analysis_date <= date_to)
|
|
|
|
# Order by created_at descending
|
|
query = query.order_by(Analysis.created_at.desc())
|
|
|
|
# Apply pagination
|
|
analyses = query.offset(offset).limit(limit).all()
|
|
|
|
results = []
|
|
for a in analyses:
|
|
# Get trading decision for completed analyses
|
|
trading_decision = None
|
|
if a.status == "completed":
|
|
reports = db.query(AnalysisReport).filter(AnalysisReport.analysis_id == a.id).all()
|
|
if reports:
|
|
trading_decision = extract_trading_decision(reports)
|
|
|
|
results.append(
|
|
AnalysisSummary(
|
|
id=a.id,
|
|
ticker=a.ticker,
|
|
analysis_date=a.analysis_date,
|
|
status=a.status,
|
|
created_at=a.created_at,
|
|
completed_at=a.completed_at,
|
|
error_message=a.error_message,
|
|
trading_decision=trading_decision,
|
|
)
|
|
)
|
|
|
|
return results
|
|
|
|
|
|
@router.get("/{analysis_id}", response_model=AnalysisResponse)
|
|
async def get_analysis(
|
|
analysis_id: str,
|
|
db: Session = Depends(get_db),
|
|
api_key: APIKey = Depends(get_current_api_key),
|
|
):
|
|
"""Get full analysis details."""
|
|
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
|
|
|
if not analysis:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail=f"Analysis {analysis_id} not found",
|
|
)
|
|
|
|
# Get reports
|
|
reports = (
|
|
db.query(AnalysisReport)
|
|
.filter(AnalysisReport.analysis_id == analysis_id)
|
|
.all()
|
|
)
|
|
|
|
# Extract trading decision if analysis is completed
|
|
trading_decision = None
|
|
if analysis.status == "completed" and reports:
|
|
trading_decision = extract_trading_decision(reports)
|
|
|
|
return AnalysisResponse(
|
|
id=analysis.id,
|
|
ticker=analysis.ticker,
|
|
analysis_date=analysis.analysis_date,
|
|
status=analysis.status,
|
|
config=json.loads(analysis.config_json),
|
|
reports=[
|
|
ReportResponse(
|
|
report_type=r.report_type,
|
|
content=r.content,
|
|
created_at=r.created_at,
|
|
)
|
|
for r in reports
|
|
],
|
|
progress_percentage=analysis.progress_percentage,
|
|
current_agent=analysis.current_agent,
|
|
created_at=analysis.created_at,
|
|
updated_at=analysis.updated_at,
|
|
completed_at=analysis.completed_at,
|
|
error_message=analysis.error_message,
|
|
trading_decision=trading_decision,
|
|
)
|
|
|
|
|
|
@router.get("/{analysis_id}/status", response_model=AnalysisStatusResponse)
|
|
async def get_analysis_status(
|
|
analysis_id: str,
|
|
db: Session = Depends(get_db),
|
|
api_key: APIKey = Depends(get_current_api_key),
|
|
):
|
|
"""Get current analysis status."""
|
|
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
|
|
|
if not analysis:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail=f"Analysis {analysis_id} not found",
|
|
)
|
|
|
|
return AnalysisStatusResponse(
|
|
id=analysis.id,
|
|
status=analysis.status,
|
|
progress_percentage=analysis.progress_percentage,
|
|
current_agent=analysis.current_agent,
|
|
updated_at=analysis.updated_at,
|
|
)
|
|
|
|
|
|
@router.get("/{analysis_id}/reports", response_model=List[ReportResponse])
|
|
async def get_analysis_reports(
|
|
analysis_id: str,
|
|
db: Session = Depends(get_db),
|
|
api_key: APIKey = Depends(get_current_api_key),
|
|
):
|
|
"""Get all reports for an analysis."""
|
|
# Check if analysis exists
|
|
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
|
if not analysis:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail=f"Analysis {analysis_id} not found",
|
|
)
|
|
|
|
reports = (
|
|
db.query(AnalysisReport)
|
|
.filter(AnalysisReport.analysis_id == analysis_id)
|
|
.order_by(AnalysisReport.created_at)
|
|
.all()
|
|
)
|
|
|
|
return [
|
|
ReportResponse(
|
|
report_type=r.report_type,
|
|
content=r.content,
|
|
created_at=r.created_at,
|
|
)
|
|
for r in reports
|
|
]
|
|
|
|
|
|
@router.get("/{analysis_id}/reports/{report_type}", response_model=ReportResponse)
|
|
async def get_analysis_report(
|
|
analysis_id: str,
|
|
report_type: str,
|
|
db: Session = Depends(get_db),
|
|
api_key: APIKey = Depends(get_current_api_key),
|
|
):
|
|
"""Get a specific report for an analysis."""
|
|
report = (
|
|
db.query(AnalysisReport)
|
|
.filter(
|
|
AnalysisReport.analysis_id == analysis_id,
|
|
AnalysisReport.report_type == report_type,
|
|
)
|
|
.first()
|
|
)
|
|
|
|
if not report:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail=f"Report {report_type} not found for analysis {analysis_id}",
|
|
)
|
|
|
|
return ReportResponse(
|
|
report_type=report.report_type,
|
|
content=report.content,
|
|
created_at=report.created_at,
|
|
)
|
|
|
|
|
|
@router.get("/{analysis_id}/logs", response_model=List[LogEntry])
|
|
async def get_analysis_logs(
|
|
analysis_id: str,
|
|
log_type: Optional[str] = Query(None, description="Filter by log type"),
|
|
limit: int = Query(100, ge=1, le=1000, description="Max results"),
|
|
offset: int = Query(0, ge=0, description="Offset for pagination"),
|
|
db: Session = Depends(get_db),
|
|
api_key: APIKey = Depends(get_current_api_key),
|
|
):
|
|
"""Get execution logs for an analysis."""
|
|
# Check if analysis exists
|
|
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
|
if not analysis:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail=f"Analysis {analysis_id} not found",
|
|
)
|
|
|
|
query = db.query(AnalysisLog).filter(AnalysisLog.analysis_id == analysis_id)
|
|
|
|
if log_type:
|
|
query = query.filter(AnalysisLog.log_type == log_type)
|
|
|
|
logs = query.order_by(AnalysisLog.timestamp).offset(offset).limit(limit).all()
|
|
|
|
return [
|
|
LogEntry(
|
|
timestamp=log.timestamp,
|
|
log_type=log.log_type,
|
|
content=log.content,
|
|
)
|
|
for log in logs
|
|
]
|
|
|
|
|
|
@router.delete("/{analysis_id}", status_code=status.HTTP_204_NO_CONTENT)
|
|
async def delete_analysis(
|
|
analysis_id: str,
|
|
permanent: bool = Query(False, description="Permanently delete from database"),
|
|
db: Session = Depends(get_db),
|
|
api_key: APIKey = Depends(get_current_api_key),
|
|
):
|
|
"""Cancel and/or delete an analysis."""
|
|
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
|
|
|
if not analysis:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail=f"Analysis {analysis_id} not found",
|
|
)
|
|
|
|
# Try to cancel if running
|
|
if analysis.status in ["pending", "running"]:
|
|
executor = get_executor()
|
|
executor.cancel_analysis(analysis_id)
|
|
|
|
# Delete from database if requested
|
|
if permanent:
|
|
db.delete(analysis)
|
|
db.commit()
|
|
elif analysis.status not in ["cancelled", "failed", "completed"]:
|
|
# Just mark as cancelled
|
|
analysis.status = "cancelled"
|
|
analysis.updated_at = datetime.utcnow()
|
|
db.commit()
|
|
|
|
return None
|
|
|