This commit is contained in:
parent
b554ae71d5
commit
4c18ea3833
|
|
@ -16,13 +16,10 @@ from backend.app.models.schemas import (
|
|||
TaskStatusResponse,
|
||||
)
|
||||
from backend.app.services.trading_service import TradingService
|
||||
from backend.app.services.task_manager import RedisTaskManager, TaskStatus
|
||||
from backend.app.services.task_manager import task_manager
|
||||
from backend.app.api.dependencies import get_trading_service
|
||||
from backend.app.core.config import settings
|
||||
|
||||
# Initialize task manager
|
||||
task_manager = RedisTaskManager(settings.redis_url)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create API router
|
||||
|
|
|
|||
|
|
@ -16,9 +16,6 @@ class Settings(BaseSettings):
|
|||
debug: bool = Field(default=False)
|
||||
results_dir: str = Field(default="./results")
|
||||
|
||||
# Redis configuration for task queue
|
||||
redis_url: str = Field(default="redis://localhost:6379", validation_alias="REDIS_URL")
|
||||
|
||||
# API Keys
|
||||
openai_api_key: Optional[str] = None
|
||||
alpha_vantage_api_key: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -27,21 +27,6 @@ app = FastAPI(
|
|||
redoc_url="/redoc",
|
||||
)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Log configuration on startup"""
|
||||
# Debug: Log available environment variables (keys only)
|
||||
import os
|
||||
logger.warning(f"Available environment variables: {list(os.environ.keys())}")
|
||||
|
||||
redis_url = settings.redis_url
|
||||
# Mask password if present
|
||||
if "@" in redis_url:
|
||||
masked_url = redis_url.split("@")[1]
|
||||
logger.warning(f"Redis configured with host: {masked_url}")
|
||||
else:
|
||||
logger.warning(f"Redis configured with URL: {redis_url}")
|
||||
|
||||
# Setup CORS
|
||||
setup_cors(app)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,173 +1,197 @@
|
|||
"""
|
||||
Redis Task Manager for async analysis processing
|
||||
In-Memory Task Manager for managing async analysis tasks
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
import threading
|
||||
from typing import Dict, Any, Optional
|
||||
from enum import Enum
|
||||
import redis
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""Task status enum"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class RedisTaskManager:
|
||||
"""Manages async tasks using Redis as storage"""
|
||||
class InMemoryTaskManager:
|
||||
"""
|
||||
Manages async tasks using in-memory storage with thread safety.
|
||||
|
||||
def __init__(self, redis_url: str):
|
||||
Note: Tasks will be lost if the server restarts.
|
||||
Consider using Redis for production if persistence is needed.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize in-memory task storage"""
|
||||
self._tasks: Dict[str, Dict[str, Any]] = {}
|
||||
self._lock = threading.RLock() # Reentrant lock for thread safety
|
||||
self._cleanup_interval = 3600 # 1 hour
|
||||
self._task_expiry = 86400 # 24 hours
|
||||
|
||||
# Start background cleanup thread
|
||||
self._start_cleanup_thread()
|
||||
|
||||
def _start_cleanup_thread(self):
|
||||
"""Start a background thread to clean up expired tasks"""
|
||||
def cleanup_worker():
|
||||
while True:
|
||||
threading.Event().wait(self._cleanup_interval)
|
||||
self._cleanup_expired_tasks()
|
||||
|
||||
cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True)
|
||||
cleanup_thread.start()
|
||||
|
||||
def _cleanup_expired_tasks(self):
|
||||
"""Remove tasks older than expiry time"""
|
||||
with self._lock:
|
||||
current_time = datetime.now()
|
||||
expired_keys = []
|
||||
|
||||
for task_id, task_data in self._tasks.items():
|
||||
created_at = datetime.fromisoformat(task_data.get("created_at", ""))
|
||||
if current_time - created_at > timedelta(seconds=self._task_expiry):
|
||||
expired_keys.append(task_id)
|
||||
|
||||
for key in expired_keys:
|
||||
del self._tasks[key]
|
||||
|
||||
def create_task(self, initial_data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Initialize Redis task manager
|
||||
Create a new task with initial data
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL
|
||||
"""
|
||||
self.redis_client = redis.from_url(
|
||||
redis_url,
|
||||
decode_responses=True,
|
||||
socket_connect_timeout=5
|
||||
)
|
||||
self.task_expiry = timedelta(hours=24)
|
||||
|
||||
def _task_key(self, task_id: str) -> str:
|
||||
"""Generate Redis key for task"""
|
||||
return f"task:{task_id}"
|
||||
|
||||
def create_task(self, task_data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Create a new task
|
||||
|
||||
Args:
|
||||
task_data: Initial task data
|
||||
initial_data: Initial task data
|
||||
|
||||
Returns:
|
||||
task_id: Unique task identifier
|
||||
Task ID
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
task = {
|
||||
task_data = {
|
||||
"task_id": task_id,
|
||||
"status": TaskStatus.PENDING,
|
||||
"status": "pending",
|
||||
"progress": "Task created",
|
||||
"result": None,
|
||||
"error": None,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
**task_data
|
||||
**initial_data
|
||||
}
|
||||
|
||||
key = self._task_key(task_id)
|
||||
self.redis_client.setex(
|
||||
key,
|
||||
self.task_expiry,
|
||||
json.dumps(task)
|
||||
)
|
||||
with self._lock:
|
||||
self._tasks[task_id] = task_data
|
||||
|
||||
logger.info(f"Created task {task_id}")
|
||||
return task_id
|
||||
|
||||
def update_task_status(self, task_id: str, status: str, progress: Optional[str] = None):
|
||||
"""
|
||||
Update task status and optional progress message
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
status: New status (pending, running, completed, failed)
|
||||
progress: Optional progress message
|
||||
"""
|
||||
with self._lock:
|
||||
if task_id in self._tasks:
|
||||
self._tasks[task_id]["status"] = status
|
||||
if progress:
|
||||
self._tasks[task_id]["progress"] = progress
|
||||
self._tasks[task_id]["updated_at"] = datetime.now().isoformat()
|
||||
|
||||
def update_task_progress(self, task_id: str, progress: str):
|
||||
"""
|
||||
Update task progress message
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
progress: Progress message
|
||||
"""
|
||||
with self._lock:
|
||||
if task_id in self._tasks:
|
||||
self._tasks[task_id]["progress"] = progress
|
||||
self._tasks[task_id]["updated_at"] = datetime.now().isoformat()
|
||||
|
||||
def set_task_result(self, task_id: str, result: Any):
|
||||
"""
|
||||
Set task result and mark as completed
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
result: Task result (will be JSON serialized)
|
||||
"""
|
||||
with self._lock:
|
||||
if task_id in self._tasks:
|
||||
self._tasks[task_id]["status"] = "completed"
|
||||
self._tasks[task_id]["result"] = result
|
||||
self._tasks[task_id]["progress"] = "Analysis completed"
|
||||
self._tasks[task_id]["completed_at"] = datetime.now().isoformat()
|
||||
|
||||
def set_task_error(self, task_id: str, error: str):
|
||||
"""
|
||||
Set task error and mark as failed
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
error: Error message
|
||||
"""
|
||||
with self._lock:
|
||||
if task_id in self._tasks:
|
||||
self._tasks[task_id]["status"] = "failed"
|
||||
self._tasks[task_id]["error"] = error
|
||||
self._tasks[task_id]["progress"] = "Analysis failed"
|
||||
self._tasks[task_id]["failed_at"] = datetime.now().isoformat()
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get task by ID
|
||||
Get task data by ID
|
||||
|
||||
Args:
|
||||
task_id: Task identifier
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Task data or None if not found
|
||||
"""
|
||||
key = self._task_key(task_id)
|
||||
data = self.redis_client.get(key)
|
||||
|
||||
if data:
|
||||
return json.loads(data)
|
||||
return None
|
||||
with self._lock:
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
def update_task_status(
|
||||
self,
|
||||
task_id: str,
|
||||
status: TaskStatus,
|
||||
progress: Optional[str] = None
|
||||
):
|
||||
def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Update task status and progress
|
||||
Get task status information
|
||||
|
||||
Args:
|
||||
task_id: Task identifier
|
||||
status: New task status
|
||||
progress: Progress message
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Dictionary with task status information
|
||||
"""
|
||||
task = self.get_task(task_id)
|
||||
if not task:
|
||||
logger.warning(f"Task {task_id} not found for status update")
|
||||
return
|
||||
return None
|
||||
|
||||
task["status"] = status
|
||||
task["updated_at"] = datetime.now().isoformat()
|
||||
|
||||
if progress:
|
||||
task["progress"] = progress
|
||||
|
||||
key = self._task_key(task_id)
|
||||
self.redis_client.setex(
|
||||
key,
|
||||
self.task_expiry,
|
||||
json.dumps(task)
|
||||
)
|
||||
|
||||
logger.info(f"Updated task {task_id} status to {status}")
|
||||
|
||||
def set_task_result(
|
||||
self,
|
||||
task_id: str,
|
||||
result: Dict[str, Any],
|
||||
error: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Set task result (success or failure)
|
||||
|
||||
Args:
|
||||
task_id: Task identifier
|
||||
result: Task result data
|
||||
error: Error message if failed
|
||||
"""
|
||||
task = self.get_task(task_id)
|
||||
if not task:
|
||||
logger.warning(f"Task {task_id} not found for result update")
|
||||
return
|
||||
|
||||
if error:
|
||||
task["status"] = TaskStatus.FAILED
|
||||
task["error"] = error
|
||||
else:
|
||||
task["status"] = TaskStatus.COMPLETED
|
||||
task["result"] = result
|
||||
|
||||
task["updated_at"] = datetime.now().isoformat()
|
||||
task["completed_at"] = datetime.now().isoformat()
|
||||
|
||||
key = self._task_key(task_id)
|
||||
self.redis_client.setex(
|
||||
key,
|
||||
self.task_expiry,
|
||||
json.dumps(task)
|
||||
)
|
||||
|
||||
status_msg = "completed" if not error else f"failed: {error}"
|
||||
logger.info(f"Task {task_id} {status_msg}")
|
||||
return {
|
||||
"task_id": task["task_id"],
|
||||
"status": task["status"],
|
||||
"progress": task.get("progress"),
|
||||
"result": task.get("result"),
|
||||
"error": task.get("error"),
|
||||
}
|
||||
|
||||
def delete_task(self, task_id: str):
|
||||
"""
|
||||
Delete a task
|
||||
|
||||
Args:
|
||||
task_id: Task identifier
|
||||
task_id: Task ID
|
||||
"""
|
||||
key = self._task_key(task_id)
|
||||
self.redis_client.delete(key)
|
||||
logger.info(f"Deleted task {task_id}")
|
||||
with self._lock:
|
||||
if task_id in self._tasks:
|
||||
del self._tasks[task_id]
|
||||
|
||||
def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Get all tasks (for debugging)
|
||||
|
||||
Returns:
|
||||
Dictionary of all tasks
|
||||
"""
|
||||
with self._lock:
|
||||
return self._tasks.copy()
|
||||
|
||||
|
||||
# Global task manager instance
|
||||
task_manager = InMemoryTaskManager()
|
||||
|
|
|
|||
|
|
@ -33,9 +33,7 @@ parsel
|
|||
requests
|
||||
tqdm
|
||||
pytz
|
||||
redis
|
||||
langchain_anthropic
|
||||
langchain-google-genai
|
||||
beautifulsoup4>=4.12.0
|
||||
tenacity>=8.2.0
|
||||
redis>=5.0.0
|
||||
|
|
|
|||
Loading…
Reference in New Issue