This commit is contained in:
MarkLo 2025-12-16 19:03:58 +08:00
parent 488eeac64c
commit 803885305a
6 changed files with 390 additions and 143 deletions

View File

@ -30,10 +30,13 @@ router = APIRouter(prefix="/api", tags=["TradingAgentsX"])
@router.get("/health", response_model=HealthResponse) @router.get("/health", response_model=HealthResponse)
async def health_check(): async def health_check():
"""Health check endpoint""" """Health check endpoint"""
from backend.app.services.redis_client import is_redis_available
return HealthResponse( return HealthResponse(
status="healthy", status="healthy",
version=settings.app_version, version=settings.app_version,
timestamp=datetime.now().isoformat(), timestamp=datetime.now().isoformat(),
redis_connected=is_redis_available(),
) )

View File

@ -105,6 +105,7 @@ class HealthResponse(BaseModel):
status: str = Field(..., description="API health status") status: str = Field(..., description="API health status")
version: str = Field(..., description="API version") version: str = Field(..., description="API version")
timestamp: str = Field(..., description="Current server timestamp") timestamp: str = Field(..., description="Current server timestamp")
redis_connected: bool = Field(False, description="Whether Redis is connected")
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):

View File

@ -0,0 +1,275 @@
"""
Redis client for production caching
This module provides an optional Redis connection for:
- Task status persistence (survives server restarts)
- Rate limiting across multiple instances
- API response caching
If REDIS_URL is not set, all operations will be no-ops and
the system will fall back to in-memory storage.
"""
import os
import json
import logging
from typing import Optional, Any
logger = logging.getLogger(__name__)
# Redis URL from environment (Railway provides this automatically)
REDIS_URL = os.getenv("REDIS_URL", "")
# Redis client instance (lazy initialization)
_redis_client = None
def get_redis_client():
"""
Get Redis client instance (lazy initialization).
Returns None if Redis is not configured.
"""
global _redis_client
if _redis_client is not None:
return _redis_client
if not REDIS_URL:
logger.info("Redis not configured (REDIS_URL not set) - using in-memory storage")
return None
try:
import redis
_redis_client = redis.from_url(
REDIS_URL,
decode_responses=True,
socket_connect_timeout=5,
socket_timeout=5,
)
# Test connection
_redis_client.ping()
logger.info("✅ Redis connected successfully")
return _redis_client
except Exception as e:
logger.warning(f"⚠️ Redis connection failed: {e} - using in-memory storage")
return None
def is_redis_available() -> bool:
"""Check if Redis is available and connected."""
client = get_redis_client()
if client is None:
return False
try:
client.ping()
return True
except:
return False
# ============== Task Storage ==============
def save_task_to_redis(task_id: str, data: dict, expire_seconds: int = 86400) -> bool:
"""
Save task data to Redis.
Args:
task_id: Unique task identifier
data: Task data dictionary
expire_seconds: TTL in seconds (default 24 hours)
Returns:
True if saved successfully, False otherwise
"""
client = get_redis_client()
if client is None:
return False
try:
key = f"task:{task_id}"
client.setex(key, expire_seconds, json.dumps(data, default=str))
return True
except Exception as e:
logger.error(f"Failed to save task to Redis: {e}")
return False
def get_task_from_redis(task_id: str) -> Optional[dict]:
"""
Get task data from Redis.
Args:
task_id: Unique task identifier
Returns:
Task data dictionary or None if not found
"""
client = get_redis_client()
if client is None:
return None
try:
key = f"task:{task_id}"
data = client.get(key)
if data:
return json.loads(data)
return None
except Exception as e:
logger.error(f"Failed to get task from Redis: {e}")
return None
def delete_task_from_redis(task_id: str) -> bool:
"""
Delete task data from Redis.
Args:
task_id: Unique task identifier
Returns:
True if deleted successfully, False otherwise
"""
client = get_redis_client()
if client is None:
return False
try:
key = f"task:{task_id}"
client.delete(key)
return True
except Exception as e:
logger.error(f"Failed to delete task from Redis: {e}")
return False
def update_task_in_redis(task_id: str, updates: dict) -> bool:
"""
Update specific fields in task data.
Args:
task_id: Unique task identifier
updates: Dictionary of fields to update
Returns:
True if updated successfully, False otherwise
"""
existing = get_task_from_redis(task_id)
if existing is None:
return False
existing.update(updates)
return save_task_to_redis(task_id, existing)
# ============== Rate Limiting ==============
def check_rate_limit(key: str, max_requests: int, window_seconds: int) -> tuple[bool, int]:
"""
Check rate limit for a given key.
Args:
key: Unique identifier (e.g., IP address)
max_requests: Maximum allowed requests
window_seconds: Time window in seconds
Returns:
Tuple of (is_allowed, remaining_requests)
"""
client = get_redis_client()
if client is None:
# If Redis not available, allow all (fall back to in-memory rate limiting)
return True, max_requests
try:
rate_key = f"ratelimit:{key}"
current = client.get(rate_key)
if current is None:
# First request in window
client.setex(rate_key, window_seconds, 1)
return True, max_requests - 1
count = int(current)
if count >= max_requests:
return False, 0
client.incr(rate_key)
return True, max_requests - count - 1
except Exception as e:
logger.error(f"Rate limit check failed: {e}")
return True, max_requests # Allow on error
# ============== Caching ==============
def cache_set(key: str, value: Any, expire_seconds: int = 3600) -> bool:
"""
Set a cache value.
Args:
key: Cache key
value: Value to cache (will be JSON serialized)
expire_seconds: TTL in seconds (default 1 hour)
Returns:
True if cached successfully, False otherwise
"""
client = get_redis_client()
if client is None:
return False
try:
cache_key = f"cache:{key}"
client.setex(cache_key, expire_seconds, json.dumps(value, default=str))
return True
except Exception as e:
logger.error(f"Failed to set cache: {e}")
return False
def cache_get(key: str) -> Optional[Any]:
"""
Get a cached value.
Args:
key: Cache key
Returns:
Cached value or None if not found
"""
client = get_redis_client()
if client is None:
return None
try:
cache_key = f"cache:{key}"
data = client.get(cache_key)
if data:
return json.loads(data)
return None
except Exception as e:
logger.error(f"Failed to get cache: {e}")
return None
def cache_delete(key: str) -> bool:
"""
Delete a cached value.
Args:
key: Cache key
Returns:
True if deleted successfully, False otherwise
"""
client = get_redis_client()
if client is None:
return False
try:
cache_key = f"cache:{key}"
client.delete(cache_key)
return True
except Exception as e:
logger.error(f"Failed to delete cache: {e}")
return False

View File

@ -1,28 +1,52 @@
""" """
In-Memory Task Manager for managing async analysis tasks Hybrid Task Manager - Redis + In-Memory
Uses Redis when available (production on Railway),
falls back to in-memory storage (local development).
""" """
import uuid import uuid
import json import json
import threading import threading
import logging
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
from datetime import datetime, timedelta from datetime import datetime, timedelta
from backend.app.services.redis_client import (
save_task_to_redis,
get_task_from_redis,
delete_task_from_redis,
is_redis_available,
)
class InMemoryTaskManager: logger = logging.getLogger(__name__)
class HybridTaskManager:
""" """
Manages async tasks using in-memory storage with thread safety. Manages async tasks using Redis when available,
with in-memory fallback for local development.
Note: Tasks will be lost if the server restarts. Features:
Consider using Redis for production if persistence is needed. - Thread-safe in-memory storage
- Redis persistence when REDIS_URL is configured
- Automatic cleanup of expired tasks
- Seamless fallback between storage backends
""" """
def __init__(self): def __init__(self):
"""Initialize in-memory task storage""" """Initialize hybrid task storage"""
# In-memory storage (always available as fallback)
self._tasks: Dict[str, Dict[str, Any]] = {} self._tasks: Dict[str, Dict[str, Any]] = {}
self._lock = threading.RLock() # Reentrant lock for thread safety self._lock = threading.RLock()
self._cleanup_interval = 3600 # 1 hour self._cleanup_interval = 3600 # 1 hour
self._task_expiry = 86400 # 24 hours self._task_expiry = 86400 # 24 hours
# Check Redis availability on startup
if is_redis_available():
logger.info("📦 Task Manager: Using Redis for task storage")
else:
logger.info("📦 Task Manager: Using in-memory storage (Redis not available)")
# Start background cleanup thread # Start background cleanup thread
self._start_cleanup_thread() self._start_cleanup_thread()
@ -37,19 +61,52 @@ class InMemoryTaskManager:
cleanup_thread.start() cleanup_thread.start()
def _cleanup_expired_tasks(self): def _cleanup_expired_tasks(self):
"""Remove tasks older than expiry time""" """Remove tasks older than expiry time (in-memory only, Redis has TTL)"""
with self._lock: with self._lock:
current_time = datetime.now() current_time = datetime.now()
expired_keys = [] expired_keys = []
for task_id, task_data in self._tasks.items(): for task_id, task_data in self._tasks.items():
created_at = datetime.fromisoformat(task_data.get("created_at", "")) created_at_str = task_data.get("created_at", "")
if current_time - created_at > timedelta(seconds=self._task_expiry): if created_at_str:
expired_keys.append(task_id) try:
created_at = datetime.fromisoformat(created_at_str)
if current_time - created_at > timedelta(seconds=self._task_expiry):
expired_keys.append(task_id)
except:
pass
for key in expired_keys: for key in expired_keys:
del self._tasks[key] del self._tasks[key]
def _save_to_storage(self, task_id: str, task_data: dict):
"""Save task to both Redis (if available) and in-memory"""
# Always save to in-memory (fast access)
with self._lock:
self._tasks[task_id] = task_data
# Also save to Redis if available (persistence)
if is_redis_available():
save_task_to_redis(task_id, task_data, self._task_expiry)
def _get_from_storage(self, task_id: str) -> Optional[dict]:
"""Get task from in-memory first, then Redis"""
# Check in-memory first (fastest)
with self._lock:
if task_id in self._tasks:
return self._tasks[task_id]
# Try Redis if not in memory (e.g., after server restart)
if is_redis_available():
redis_data = get_task_from_redis(task_id)
if redis_data:
# Cache in memory for future access
with self._lock:
self._tasks[task_id] = redis_data
return redis_data
return None
def create_task(self, initial_data: Dict[str, Any]) -> str: def create_task(self, initial_data: Dict[str, Any]) -> str:
""" """
Create a new task with initial data Create a new task with initial data
@ -72,9 +129,7 @@ class InMemoryTaskManager:
**initial_data **initial_data
} }
with self._lock: self._save_to_storage(task_id, task_data)
self._tasks[task_id] = task_data
return task_id return task_id
def update_task_status(self, task_id: str, status: str, progress: Optional[str] = None): def update_task_status(self, task_id: str, status: str, progress: Optional[str] = None):
@ -86,12 +141,13 @@ class InMemoryTaskManager:
status: New status (pending, running, completed, failed) status: New status (pending, running, completed, failed)
progress: Optional progress message progress: Optional progress message
""" """
with self._lock: task_data = self._get_from_storage(task_id)
if task_id in self._tasks: if task_data:
self._tasks[task_id]["status"] = status task_data["status"] = status
if progress: if progress:
self._tasks[task_id]["progress"] = progress task_data["progress"] = progress
self._tasks[task_id]["updated_at"] = datetime.now().isoformat() task_data["updated_at"] = datetime.now().isoformat()
self._save_to_storage(task_id, task_data)
def update_task_progress(self, task_id: str, progress: str): def update_task_progress(self, task_id: str, progress: str):
""" """
@ -101,10 +157,11 @@ class InMemoryTaskManager:
task_id: Task ID task_id: Task ID
progress: Progress message progress: Progress message
""" """
with self._lock: task_data = self._get_from_storage(task_id)
if task_id in self._tasks: if task_data:
self._tasks[task_id]["progress"] = progress task_data["progress"] = progress
self._tasks[task_id]["updated_at"] = datetime.now().isoformat() task_data["updated_at"] = datetime.now().isoformat()
self._save_to_storage(task_id, task_data)
def set_task_result(self, task_id: str, result: Any): def set_task_result(self, task_id: str, result: Any):
""" """
@ -112,14 +169,15 @@ class InMemoryTaskManager:
Args: Args:
task_id: Task ID task_id: Task ID
result: Task result (will be JSON serialized) result: Task result
""" """
with self._lock: task_data = self._get_from_storage(task_id)
if task_id in self._tasks: if task_data:
self._tasks[task_id]["status"] = "completed" task_data["status"] = "completed"
self._tasks[task_id]["result"] = result task_data["result"] = result
self._tasks[task_id]["progress"] = "Analysis completed" task_data["progress"] = "Analysis completed"
self._tasks[task_id]["completed_at"] = datetime.now().isoformat() task_data["completed_at"] = datetime.now().isoformat()
self._save_to_storage(task_id, task_data)
def set_task_error(self, task_id: str, error: str): def set_task_error(self, task_id: str, error: str):
""" """
@ -129,12 +187,13 @@ class InMemoryTaskManager:
task_id: Task ID task_id: Task ID
error: Error message error: Error message
""" """
with self._lock: task_data = self._get_from_storage(task_id)
if task_id in self._tasks: if task_data:
self._tasks[task_id]["status"] = "failed" task_data["status"] = "failed"
self._tasks[task_id]["error"] = error task_data["error"] = error
self._tasks[task_id]["progress"] = "Analysis failed" task_data["progress"] = "Analysis failed"
self._tasks[task_id]["failed_at"] = datetime.now().isoformat() task_data["failed_at"] = datetime.now().isoformat()
self._save_to_storage(task_id, task_data)
def get_task(self, task_id: str) -> Optional[Dict[str, Any]]: def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
""" """
@ -146,8 +205,7 @@ class InMemoryTaskManager:
Returns: Returns:
Task data or None if not found Task data or None if not found
""" """
with self._lock: return self._get_from_storage(task_id)
return self._tasks.get(task_id)
def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]: def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
""" """
@ -157,7 +215,7 @@ class InMemoryTaskManager:
task_id: Task ID task_id: Task ID
Returns: Returns:
Dictionary with task status information including all required fields Dictionary with task status information
""" """
task = self.get_task(task_id) task = self.get_task(task_id)
if not task: if not task:
@ -167,7 +225,7 @@ class InMemoryTaskManager:
"task_id": task["task_id"], "task_id": task["task_id"],
"status": task["status"], "status": task["status"],
"created_at": task.get("created_at"), "created_at": task.get("created_at"),
"updated_at": task.get("updated_at", task.get("created_at")), # Fallback to created_at if updated_at not set "updated_at": task.get("updated_at", task.get("created_at")),
"progress": task.get("progress"), "progress": task.get("progress"),
"result": task.get("result"), "result": task.get("result"),
"error": task.get("error"), "error": task.get("error"),
@ -184,17 +242,20 @@ class InMemoryTaskManager:
with self._lock: with self._lock:
if task_id in self._tasks: if task_id in self._tasks:
del self._tasks[task_id] del self._tasks[task_id]
if is_redis_available():
delete_task_from_redis(task_id)
def get_all_tasks(self) -> Dict[str, Dict[str, Any]]: def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
""" """
Get all tasks (for debugging) Get all tasks (for debugging)
Returns: Returns:
Dictionary of all tasks Dictionary of all tasks (in-memory only)
""" """
with self._lock: with self._lock:
return self._tasks.copy() return self._tasks.copy()
# Global task manager instance # Global task manager instance
task_manager = InMemoryTaskManager() task_manager = HybridTaskManager()

View File

@ -22,6 +22,9 @@ PyJWT>=2.8.0
cryptography>=41.0.0 cryptography>=41.0.0
httpx>=0.25.0 httpx>=0.25.0
# Redis (optional - for production caching)
redis>=5.0.0
# Existing TradingAgentsX dependencies # Existing TradingAgentsX dependencies
typing-extensions typing-extensions
langchain-openai langchain-openai

View File

@ -27,7 +27,7 @@ import {
DialogHeader, DialogHeader,
DialogTitle, DialogTitle,
} from "@/components/ui/dialog"; } from "@/components/ui/dialog";
import { Trash2, Eye, RefreshCw, TrendingUp, Cloud, CloudOff, FileText, Download } from "lucide-react"; import { Trash2, Eye, RefreshCw, TrendingUp, CloudOff, FileText, Download } from "lucide-react";
import { import {
getReportsByMarketType, getReportsByMarketType,
deleteReport, deleteReport,
@ -152,9 +152,7 @@ export default function HistoryPage() {
); );
const [deleting, setDeleting] = useState(false); const [deleting, setDeleting] = useState(false);
// Sync state // Auto-sync tracking ref
const [syncing, setSyncing] = useState(false);
const [syncResult, setSyncResult] = useState<{ success: number; failed: number } | null>(null);
const hasAutoSyncedRef = useRef(false); const hasAutoSyncedRef = useRef(false);
// Load reports when tab changes or auth state changes // Load reports when tab changes or auth state changes
@ -352,83 +350,6 @@ export default function HistoryPage() {
} }
}; };
// Sync local reports to cloud
const handleSyncToCloud = async () => {
if (!isAuthenticated || !isCloudSyncEnabled()) {
alert("請先登入以啟用雲端同步");
return;
}
setSyncing(true);
setSyncResult(null);
try {
// Get all local reports
const [usLocal, twseLocal, tpexLocal] = await Promise.all([
getReportsByMarketType("us"),
getReportsByMarketType("twse"),
getReportsByMarketType("tpex"),
]);
const allLocal = [...usLocal, ...twseLocal, ...tpexLocal];
// Get cloud reports to check for duplicates
const cloudReports = await getCloudReports();
const cloudKeys = new Set(
cloudReports.map(r => `${r.ticker}_${r.analysis_date}`)
);
// Find local-only reports to upload
const toUpload = allLocal.filter(
r => !cloudKeys.has(`${r.ticker}_${r.analysis_date}`)
);
if (toUpload.length === 0) {
setSyncResult({ success: 0, failed: 0 });
alert("所有報告已同步到雲端!");
return;
}
// Upload each report
let success = 0;
let failed = 0;
for (const report of toUpload) {
try {
const cloudId = await saveCloudReport({
ticker: report.ticker,
market_type: report.market_type,
analysis_date: report.analysis_date,
result: report.result,
});
if (cloudId) {
success++;
} else {
failed++;
}
} catch (e) {
failed++;
}
}
setSyncResult({ success, failed });
// Reload data after sync
await loadReports();
await loadCounts();
if (failed === 0) {
alert(`成功同步 ${success} 份報告到雲端!`);
} else {
alert(`同步完成:${success} 成功,${failed} 失敗`);
}
} catch (error) {
console.error("Sync failed:", error);
alert("同步失敗,請稍後再試");
} finally {
setSyncing(false);
}
};
const handleViewReport = (report: SavedReport) => { const handleViewReport = (report: SavedReport) => {
// Set the context with the saved report data // Set the context with the saved report data
setAnalysisResult(report.result); setAnalysisResult(report.result);
@ -592,25 +513,8 @@ export default function HistoryPage() {
(marketType) => ( (marketType) => (
<TabsContent key={marketType} value={marketType} className="mt-6"> <TabsContent key={marketType} value={marketType} className="mt-6">
<div className="space-y-4"> <div className="space-y-4">
{/* Action buttons */} {/* Refresh button */}
<div className="flex justify-end gap-2"> <div className="flex justify-end">
{/* Sync to Cloud button - only show when authenticated */}
{isAuthenticated && (
<Button
variant="outline"
size="sm"
onClick={handleSyncToCloud}
disabled={syncing || loading}
className="gap-2"
>
<Cloud
className={`h-4 w-4 ${syncing ? "animate-pulse" : ""}`}
/>
{syncing ? "同步中..." : "同步到雲端"}
</Button>
)}
{/* Refresh button */}
<Button <Button
variant="outline" variant="outline"
size="sm" size="sm"