This commit is contained in:
parent
2b9d8f1880
commit
d3ab03ccd9
|
|
@ -4,6 +4,7 @@ API route definitions for TradingAgents Backend
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from backend.app.models.schemas import (
|
||||
AnalysisRequest,
|
||||
|
|
@ -11,11 +12,17 @@ from backend.app.models.schemas import (
|
|||
ConfigResponse,
|
||||
HealthResponse,
|
||||
Ticker,
|
||||
TaskCreatedResponse,
|
||||
TaskStatusResponse,
|
||||
)
|
||||
from backend.app.services.trading_service import TradingService
|
||||
from backend.app.services.task_manager import RedisTaskManager, TaskStatus
|
||||
from backend.app.api.dependencies import get_trading_service
|
||||
from backend.app.core.config import settings
|
||||
|
||||
# Initialize task manager
|
||||
task_manager = RedisTaskManager(settings.redis_url)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create API router
|
||||
|
|
@ -42,49 +49,102 @@ async def get_config(service: TradingService = Depends(get_trading_service)):
|
|||
)
|
||||
|
||||
|
||||
@router.post("/analyze", response_model=AnalysisResponse)
|
||||
@router.post("/analyze", response_model=TaskCreatedResponse)
|
||||
async def run_analysis(
|
||||
request: AnalysisRequest,
|
||||
service: TradingService = Depends(get_trading_service),
|
||||
):
|
||||
"""
|
||||
Run a comprehensive trading analysis for a given ticker and date.
|
||||
Start an async trading analysis task.
|
||||
|
||||
Requires OpenAI API key to be provided in the request.
|
||||
This endpoint creates an async task and returns immediately with a task ID.
|
||||
Use the /api/task/{task_id} endpoint to check the status and get results.
|
||||
|
||||
Args:
|
||||
request: Analysis request configuration
|
||||
service: Trading service instance (injected)
|
||||
|
||||
Returns:
|
||||
TaskCreatedResponse: Task ID and initial status
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Received analysis request for {request.ticker} on {request.analysis_date}")
|
||||
|
||||
# Run analysis with all provided parameters including API keys
|
||||
result = await service.run_analysis(
|
||||
ticker=request.ticker,
|
||||
analysis_date=request.analysis_date,
|
||||
openai_api_key=request.openai_api_key,
|
||||
openai_base_url=request.openai_base_url,
|
||||
alpha_vantage_api_key=request.alpha_vantage_api_key,
|
||||
analysts=request.analysts,
|
||||
research_depth=request.research_depth,
|
||||
deep_think_llm=request.deep_think_llm,
|
||||
quick_think_llm=request.quick_think_llm,
|
||||
)
|
||||
|
||||
# Check if result contains error
|
||||
if result.get("status") == "error":
|
||||
logger.error(f"Analysis failed: {result.get('error')}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Analysis failed: {result.get('error', 'Unknown error')}"
|
||||
logger.info(f"Creating analysis task for {request.ticker} on {request.analysis_date}")
|
||||
|
||||
# Create task in Redis
|
||||
task_id = task_manager.create_task({
|
||||
"ticker": request.ticker,
|
||||
"analysis_date": request.analysis_date,
|
||||
})
|
||||
|
||||
# Start background analysis
|
||||
def run_background_analysis():
|
||||
try:
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
TaskStatus.RUNNING,
|
||||
progress="Starting analysis..."
|
||||
)
|
||||
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during analysis: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Analysis failed: {str(e)}"
|
||||
)
|
||||
|
||||
result = service.run_analysis(
|
||||
ticker=request.ticker,
|
||||
analysis_date=request.analysis_date,
|
||||
analysts=request.analysts,
|
||||
research_depth=request.research_depth,
|
||||
deep_think_llm=request.deep_think_llm,
|
||||
quick_think_llm=request.quick_think_llm,
|
||||
openai_api_key=request.openai_api_key,
|
||||
openai_base_url=request.openai_base_url,
|
||||
alpha_vantage_api_key=request.alpha_vantage_api_key,
|
||||
)
|
||||
|
||||
# Check for errors in result
|
||||
if "status" in result and result["status"] == "error":
|
||||
task_manager.set_task_result(
|
||||
task_id,
|
||||
result={},
|
||||
error=result.get("message", "Analysis failed")
|
||||
)
|
||||
else:
|
||||
task_manager.set_task_result(task_id, result=result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Analysis task {task_id} failed: {str(e)}", exc_info=True)
|
||||
task_manager.set_task_result(
|
||||
task_id,
|
||||
result={},
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
# Start background thread
|
||||
thread = threading.Thread(target=run_background_analysis, daemon=True)
|
||||
thread.start()
|
||||
|
||||
return TaskCreatedResponse(
|
||||
task_id=task_id,
|
||||
status="pending",
|
||||
message="Analysis task created successfully"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/task/{task_id}", response_model=TaskStatusResponse)
|
||||
async def get_task_status(task_id: str):
|
||||
"""
|
||||
Get the status of an analysis task.
|
||||
|
||||
Args:
|
||||
task_id: Task identifier
|
||||
|
||||
Returns:
|
||||
TaskStatusResponse: Current task status and results if completed
|
||||
|
||||
Raises:
|
||||
HTTPException: If task not found
|
||||
"""
|
||||
task = task_manager.get_task(task_id)
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
||||
|
||||
return TaskStatusResponse(**task)
|
||||
|
||||
|
||||
@router.get("/tickers")
|
||||
|
|
|
|||
|
|
@ -4,15 +4,20 @@ Configuration management for TradingAgents Backend API
|
|||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional
|
||||
import os
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables"""
|
||||
|
||||
# API Configuration
|
||||
# Application settings
|
||||
app_name: str = "TradingAgents API"
|
||||
app_version: str = "1.0.0"
|
||||
debug: bool = True
|
||||
debug: bool = Field(default=False)
|
||||
results_dir: str = Field(default="./results")
|
||||
|
||||
# Redis configuration for task queue
|
||||
redis_url: str = Field(default="redis://localhost:6379")
|
||||
|
||||
# API Keys
|
||||
openai_api_key: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
Pydantic models for request/response schemas
|
||||
"""
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from typing import List, Optional, Dict, Any, Literal
|
||||
from datetime import date
|
||||
|
||||
|
||||
|
|
@ -87,3 +87,24 @@ class Ticker(BaseModel):
|
|||
"""Ticker information model"""
|
||||
symbol: str = Field(..., description="Stock ticker symbol")
|
||||
name: str = Field(..., description="Company name")
|
||||
|
||||
|
||||
# Task Management Schemas
|
||||
|
||||
class TaskCreatedResponse(BaseModel):
|
||||
"""Response when a task is created"""
|
||||
task_id: str = Field(..., description="Unique task identifier")
|
||||
status: Literal["pending"] = Field(default="pending", description="Initial task status")
|
||||
message: str = Field(default="Analysis task created successfully", description="Success message")
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
"""Response for task status query"""
|
||||
task_id: str = Field(..., description="Task identifier")
|
||||
status: Literal["pending", "running", "completed", "failed"] = Field(..., description="Current task status")
|
||||
created_at: str = Field(..., description="Task creation timestamp")
|
||||
updated_at: str = Field(..., description="Last update timestamp")
|
||||
progress: Optional[str] = Field(None, description="Progress message")
|
||||
result: Optional[AnalysisResponse] = Field(None, description="Analysis result (only when completed)")
|
||||
error: Optional[str] = Field(None, description="Error message (only when failed)")
|
||||
completed_at: Optional[str] = Field(None, description="Completion timestamp")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,173 @@
|
|||
"""
|
||||
Redis Task Manager for async analysis processing
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Optional
|
||||
from enum import Enum
|
||||
import redis
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""Task status enum"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class RedisTaskManager:
|
||||
"""Manages async tasks using Redis as storage"""
|
||||
|
||||
def __init__(self, redis_url: str):
|
||||
"""
|
||||
Initialize Redis task manager
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL
|
||||
"""
|
||||
self.redis_client = redis.from_url(
|
||||
redis_url,
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=5
|
||||
)
|
||||
self.task_expiry = timedelta(hours=24)
|
||||
|
||||
def _task_key(self, task_id: str) -> str:
|
||||
"""Generate Redis key for task"""
|
||||
return f"task:{task_id}"
|
||||
|
||||
def create_task(self, task_data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Create a new task
|
||||
|
||||
Args:
|
||||
task_data: Initial task data
|
||||
|
||||
Returns:
|
||||
task_id: Unique task identifier
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
task = {
|
||||
"task_id": task_id,
|
||||
"status": TaskStatus.PENDING,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
**task_data
|
||||
}
|
||||
|
||||
key = self._task_key(task_id)
|
||||
self.redis_client.setex(
|
||||
key,
|
||||
self.task_expiry,
|
||||
json.dumps(task)
|
||||
)
|
||||
|
||||
logger.info(f"Created task {task_id}")
|
||||
return task_id
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get task by ID
|
||||
|
||||
Args:
|
||||
task_id: Task identifier
|
||||
|
||||
Returns:
|
||||
Task data or None if not found
|
||||
"""
|
||||
key = self._task_key(task_id)
|
||||
data = self.redis_client.get(key)
|
||||
|
||||
if data:
|
||||
return json.loads(data)
|
||||
return None
|
||||
|
||||
def update_task_status(
|
||||
self,
|
||||
task_id: str,
|
||||
status: TaskStatus,
|
||||
progress: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Update task status and progress
|
||||
|
||||
Args:
|
||||
task_id: Task identifier
|
||||
status: New task status
|
||||
progress: Progress message
|
||||
"""
|
||||
task = self.get_task(task_id)
|
||||
if not task:
|
||||
logger.warning(f"Task {task_id} not found for status update")
|
||||
return
|
||||
|
||||
task["status"] = status
|
||||
task["updated_at"] = datetime.now().isoformat()
|
||||
|
||||
if progress:
|
||||
task["progress"] = progress
|
||||
|
||||
key = self._task_key(task_id)
|
||||
self.redis_client.setex(
|
||||
key,
|
||||
self.task_expiry,
|
||||
json.dumps(task)
|
||||
)
|
||||
|
||||
logger.info(f"Updated task {task_id} status to {status}")
|
||||
|
||||
def set_task_result(
|
||||
self,
|
||||
task_id: str,
|
||||
result: Dict[str, Any],
|
||||
error: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Set task result (success or failure)
|
||||
|
||||
Args:
|
||||
task_id: Task identifier
|
||||
result: Task result data
|
||||
error: Error message if failed
|
||||
"""
|
||||
task = self.get_task(task_id)
|
||||
if not task:
|
||||
logger.warning(f"Task {task_id} not found for result update")
|
||||
return
|
||||
|
||||
if error:
|
||||
task["status"] = TaskStatus.FAILED
|
||||
task["error"] = error
|
||||
else:
|
||||
task["status"] = TaskStatus.COMPLETED
|
||||
task["result"] = result
|
||||
|
||||
task["updated_at"] = datetime.now().isoformat()
|
||||
task["completed_at"] = datetime.now().isoformat()
|
||||
|
||||
key = self._task_key(task_id)
|
||||
self.redis_client.setex(
|
||||
key,
|
||||
self.task_expiry,
|
||||
json.dumps(task)
|
||||
)
|
||||
|
||||
status_msg = "completed" if not error else f"failed: {error}"
|
||||
logger.info(f"Task {task_id} {status_msg}")
|
||||
|
||||
def delete_task(self, task_id: str):
|
||||
"""
|
||||
Delete a task
|
||||
|
||||
Args:
|
||||
task_id: Task identifier
|
||||
"""
|
||||
key = self._task_key(task_id)
|
||||
self.redis_client.delete(key)
|
||||
logger.info(f"Deleted task {task_id}")
|
||||
|
|
@ -38,3 +38,4 @@ langchain_anthropic
|
|||
langchain-google-genai
|
||||
beautifulsoup4>=4.12.0
|
||||
tenacity>=8.2.0
|
||||
redis>=5.0.0
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
/**
|
||||
* Custom hook for trading analysis
|
||||
* Custom hook for trading analysis with async task support
|
||||
*/
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useState, useEffect, useRef } from "react";
|
||||
import { api } from "@/lib/api";
|
||||
import type { AnalysisRequest, AnalysisResponse } from "@/lib/types";
|
||||
|
||||
|
|
@ -11,30 +11,120 @@ export function useAnalysis() {
|
|||
const [loading, setLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [result, setResult] = useState<AnalysisResponse | null>(null);
|
||||
const [taskId, setTaskId] = useState<string | null>(null);
|
||||
const [progress, setProgress] = useState<string | null>(null);
|
||||
const pollingIntervalRef = useRef<NodeJS.Timeout | null>(null);
|
||||
|
||||
// Poll for task status
|
||||
const pollTaskStatus = async (id: string) => {
|
||||
try {
|
||||
const status = await api.getTaskStatus(id);
|
||||
|
||||
// Update progress
|
||||
if (status.progress) {
|
||||
setProgress(status.progress);
|
||||
}
|
||||
|
||||
// Check if completed
|
||||
if (status.status === "completed") {
|
||||
if (status.result) {
|
||||
setResult(status.result);
|
||||
}
|
||||
setLoading(false);
|
||||
setProgress(null);
|
||||
// Stop polling
|
||||
if (pollingIntervalRef.current) {
|
||||
clearInterval(pollingIntervalRef.current);
|
||||
pollingIntervalRef.current = null;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check if failed
|
||||
if (status.status === "failed") {
|
||||
setError(status.error || "Analysis failed");
|
||||
setLoading(false);
|
||||
setProgress(null);
|
||||
// Stop polling
|
||||
if (pollingIntervalRef.current) {
|
||||
clearInterval(pollingIntervalRef.current);
|
||||
pollingIntervalRef.current = null;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
return false; // Still running
|
||||
} catch (err: any) {
|
||||
console.error("Error polling task status:", err);
|
||||
// Don't stop polling on temporary errors
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
// Start polling
|
||||
const startPolling = (id: string) => {
|
||||
// Clear any existing interval
|
||||
if (pollingIntervalRef.current) {
|
||||
clearInterval(pollingIntervalRef.current);
|
||||
}
|
||||
|
||||
// Poll every 3 seconds
|
||||
pollingIntervalRef.current = setInterval(async () => {
|
||||
await pollTaskStatus(id);
|
||||
}, 3000);
|
||||
|
||||
// Also poll immediately
|
||||
pollTaskStatus(id);
|
||||
};
|
||||
|
||||
// Cleanup on unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (pollingIntervalRef.current) {
|
||||
clearInterval(pollingIntervalRef.current);
|
||||
pollingIntervalRef.current = null;
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
const runAnalysis = async (request: AnalysisRequest) => {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
setResult(null);
|
||||
setProgress("Submitting analysis request...");
|
||||
|
||||
try {
|
||||
const response = await api.runAnalysis(request);
|
||||
setResult(response);
|
||||
return response;
|
||||
// Start analysis task
|
||||
const taskResponse = await api.runAnalysis(request);
|
||||
setTaskId(taskResponse.task_id);
|
||||
setProgress("Analysis started, waiting for results...");
|
||||
|
||||
// Start polling for status
|
||||
startPolling(taskResponse.task_id);
|
||||
|
||||
return taskResponse;
|
||||
} catch (err: any) {
|
||||
const errorMessage =
|
||||
err.response?.data?.detail || err.message || "Analysis failed";
|
||||
err.response?.data?.detail || err.message || "Failed to start analysis";
|
||||
setError(errorMessage);
|
||||
throw err;
|
||||
} finally {
|
||||
setLoading(false);
|
||||
setProgress(null);
|
||||
throw err;
|
||||
}
|
||||
};
|
||||
|
||||
const reset = () => {
|
||||
// Stop polling
|
||||
if (pollingIntervalRef.current) {
|
||||
clearInterval(pollingIntervalRef.current);
|
||||
pollingIntervalRef.current = null;
|
||||
}
|
||||
|
||||
setLoading(false);
|
||||
setError(null);
|
||||
setResult(null);
|
||||
setTaskId(null);
|
||||
setProgress(null);
|
||||
};
|
||||
|
||||
return {
|
||||
|
|
@ -42,6 +132,8 @@ export function useAnalysis() {
|
|||
loading,
|
||||
error,
|
||||
result,
|
||||
taskId,
|
||||
progress,
|
||||
reset,
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ import type {
|
|||
ConfigResponse,
|
||||
HealthResponse,
|
||||
Ticker,
|
||||
TaskCreatedResponse,
|
||||
TaskStatusResponse,
|
||||
} from "./types";
|
||||
|
||||
const apiClient = axios.create({
|
||||
|
|
@ -34,16 +36,24 @@ export const api = {
|
|||
},
|
||||
|
||||
/**
|
||||
* Run trading analysis
|
||||
* Start analysis (returns task ID)
|
||||
*/
|
||||
async runAnalysis(request: AnalysisRequest): Promise<AnalysisResponse> {
|
||||
const response = await apiClient.post<AnalysisResponse>(
|
||||
async runAnalysis(request: AnalysisRequest): Promise<TaskCreatedResponse> {
|
||||
const response = await apiClient.post<TaskCreatedResponse>(
|
||||
"/api/analyze",
|
||||
request
|
||||
);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
/**
|
||||
* Get task status
|
||||
*/
|
||||
async getTaskStatus(taskId: string): Promise<TaskStatusResponse> {
|
||||
const response = await apiClient.get<TaskStatusResponse>(`/api/task/${taskId}`);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
/**
|
||||
* Get list of popular tickers
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -93,3 +93,23 @@ export interface Ticker {
|
|||
symbol: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
// Task Management Types
|
||||
|
||||
export interface TaskCreatedResponse {
|
||||
task_id: string;
|
||||
status: "pending";
|
||||
message: string;
|
||||
}
|
||||
|
||||
export interface TaskStatusResponse {
|
||||
task_id: string;
|
||||
status: "pending" | "running" | "completed" | "failed";
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
progress?: string;
|
||||
result?: AnalysisResponse;
|
||||
error?: string;
|
||||
completed_at?: string;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue