MarkLo 2025-12-16 19:03:58 +08:00
parent 488eeac64c
commit 803885305a
6 changed files with 390 additions and 143 deletions

View File

@@ -30,10 +30,13 @@ router = APIRouter(prefix="/api", tags=["TradingAgentsX"])
@router.get("/health", response_model=HealthResponse)
async def health_check():
"""Health check endpoint"""
from backend.app.services.redis_client import is_redis_available
return HealthResponse(
status="healthy",
version=settings.app_version,
timestamp=datetime.now().isoformat(),
redis_connected=is_redis_available(),
)
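For a quick check of the new field, a client call might look like this (a minimal sketch; the local base URL and port are assumptions, not part of this commit):

import httpx

# Calls the health endpoint added above and inspects the new redis_connected flag.
resp = httpx.get("http://localhost:8000/api/health", timeout=5)
payload = resp.json()
print(payload["status"], payload["redis_connected"])  # e.g. "healthy" False when Redis is absent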

View File

@@ -105,6 +105,7 @@ class HealthResponse(BaseModel):
status: str = Field(..., description="API health status")
version: str = Field(..., description="API version")
timestamp: str = Field(..., description="Current server timestamp")
redis_connected: bool = Field(False, description="Whether Redis is connected")
class ErrorResponse(BaseModel):

View File

@@ -0,0 +1,275 @@
"""
Redis client for production caching
This module provides an optional Redis connection for:
- Task status persistence (survives server restarts)
- Rate limiting across multiple instances
- API response caching
If REDIS_URL is not set, all operations will be no-ops and
the system will fall back to in-memory storage.
"""
import os
import json
import logging
from typing import Optional, Any
logger = logging.getLogger(__name__)
# Redis URL from environment (Railway provides this automatically)
REDIS_URL = os.getenv("REDIS_URL", "")
# Redis client instance (lazy initialization)
_redis_client = None
def get_redis_client():
"""
Get Redis client instance (lazy initialization).
Returns None if Redis is not configured.
"""
global _redis_client
if _redis_client is not None:
return _redis_client
if not REDIS_URL:
logger.info("Redis not configured (REDIS_URL not set) - using in-memory storage")
return None
try:
import redis
_redis_client = redis.from_url(
REDIS_URL,
decode_responses=True,
socket_connect_timeout=5,
socket_timeout=5,
)
# Test connection
_redis_client.ping()
logger.info("✅ Redis connected successfully")
return _redis_client
except Exception as e:
logger.warning(f"⚠️ Redis connection failed: {e} - using in-memory storage")
return None
def is_redis_available() -> bool:
"""Check if Redis is available and connected."""
client = get_redis_client()
if client is None:
return False
try:
client.ping()
return True
except Exception:
return False
# ============== Task Storage ==============
def save_task_to_redis(task_id: str, data: dict, expire_seconds: int = 86400) -> bool:
"""
Save task data to Redis.
Args:
task_id: Unique task identifier
data: Task data dictionary
expire_seconds: TTL in seconds (default 24 hours)
Returns:
True if saved successfully, False otherwise
"""
client = get_redis_client()
if client is None:
return False
try:
key = f"task:{task_id}"
client.setex(key, expire_seconds, json.dumps(data, default=str))
return True
except Exception as e:
logger.error(f"Failed to save task to Redis: {e}")
return False
def get_task_from_redis(task_id: str) -> Optional[dict]:
"""
Get task data from Redis.
Args:
task_id: Unique task identifier
Returns:
Task data dictionary or None if not found
"""
client = get_redis_client()
if client is None:
return None
try:
key = f"task:{task_id}"
data = client.get(key)
if data:
return json.loads(data)
return None
except Exception as e:
logger.error(f"Failed to get task from Redis: {e}")
return None
def delete_task_from_redis(task_id: str) -> bool:
"""
Delete task data from Redis.
Args:
task_id: Unique task identifier
Returns:
True if deleted successfully, False otherwise
"""
client = get_redis_client()
if client is None:
return False
try:
key = f"task:{task_id}"
client.delete(key)
return True
except Exception as e:
logger.error(f"Failed to delete task from Redis: {e}")
return False
def update_task_in_redis(task_id: str, updates: dict) -> bool:
"""
Update specific fields in task data.
Args:
task_id: Unique task identifier
updates: Dictionary of fields to update
Returns:
True if updated successfully, False otherwise
"""
existing = get_task_from_redis(task_id)
if existing is None:
return False
existing.update(updates)
return save_task_to_redis(task_id, existing)
# ============== Rate Limiting ==============
def check_rate_limit(key: str, max_requests: int, window_seconds: int) -> tuple[bool, int]:
"""
Check rate limit for a given key.
Args:
key: Unique identifier (e.g., IP address)
max_requests: Maximum allowed requests
window_seconds: Time window in seconds
Returns:
Tuple of (is_allowed, remaining_requests)
"""
client = get_redis_client()
if client is None:
# If Redis not available, allow all (fall back to in-memory rate limiting)
return True, max_requests
try:
rate_key = f"ratelimit:{key}"
current = client.get(rate_key)
if current is None:
# First request in window
client.setex(rate_key, window_seconds, 1)
return True, max_requests - 1
count = int(current)
if count >= max_requests:
return False, 0
client.incr(rate_key)
return True, max_requests - count - 1
except Exception as e:
logger.error(f"Rate limit check failed: {e}")
return True, max_requests # Allow on error
# ============== Caching ==============
def cache_set(key: str, value: Any, expire_seconds: int = 3600) -> bool:
"""
Set a cache value.
Args:
key: Cache key
value: Value to cache (will be JSON serialized)
expire_seconds: TTL in seconds (default 1 hour)
Returns:
True if cached successfully, False otherwise
"""
client = get_redis_client()
if client is None:
return False
try:
cache_key = f"cache:{key}"
client.setex(cache_key, expire_seconds, json.dumps(value, default=str))
return True
except Exception as e:
logger.error(f"Failed to set cache: {e}")
return False
def cache_get(key: str) -> Optional[Any]:
"""
Get a cached value.
Args:
key: Cache key
Returns:
Cached value or None if not found
"""
client = get_redis_client()
if client is None:
return None
try:
cache_key = f"cache:{key}"
data = client.get(cache_key)
if data:
return json.loads(data)
return None
except Exception as e:
logger.error(f"Failed to get cache: {e}")
return None
def cache_delete(key: str) -> bool:
"""
Delete a cached value.
Args:
key: Cache key
Returns:
True if deleted successfully, False otherwise
"""
client = get_redis_client()
if client is None:
return False
try:
cache_key = f"cache:{key}"
client.delete(cache_key)
return True
except Exception as e:
logger.error(f"Failed to delete cache: {e}")
return False
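As a rough sketch of the graceful degradation described in this module's docstring (the task id, IP, and payload below are illustrative only, not values from this commit):

# Assumes REDIS_URL is not set in the environment (e.g. local development),
# so every helper degrades to its fallback return value.
from backend.app.services.redis_client import (
    is_redis_available,
    save_task_to_redis,
    get_task_from_redis,
    check_rate_limit,
)

print(is_redis_available())                                 # False -> in-memory fallback
print(save_task_to_redis("demo-1", {"status": "pending"}))  # False, write is skipped
print(get_task_from_redis("demo-1"))                        # None, nothing persisted
print(check_rate_limit("127.0.0.1", 10, 60))                # (True, 10), request allowed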

View File

@@ -1,28 +1,52 @@
"""
In-Memory Task Manager for managing async analysis tasks
Hybrid Task Manager - Redis + In-Memory
Uses Redis when available (production on Railway),
falls back to in-memory storage (local development).
"""
import uuid
import json
import threading
import logging
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
from backend.app.services.redis_client import (
save_task_to_redis,
get_task_from_redis,
delete_task_from_redis,
is_redis_available,
)
class InMemoryTaskManager:
logger = logging.getLogger(__name__)
class HybridTaskManager:
"""
Manages async tasks using in-memory storage with thread safety.
Manages async tasks using Redis when available,
with in-memory fallback for local development.
Note: Tasks will be lost if the server restarts.
Consider using Redis for production if persistence is needed.
Features:
- Thread-safe in-memory storage
- Redis persistence when REDIS_URL is configured
- Automatic cleanup of expired tasks
- Seamless fallback between storage backends
"""
def __init__(self):
"""Initialize in-memory task storage"""
"""Initialize hybrid task storage"""
# In-memory storage (always available as fallback)
self._tasks: Dict[str, Dict[str, Any]] = {}
self._lock = threading.RLock() # Reentrant lock for thread safety
self._lock = threading.RLock()
self._cleanup_interval = 3600 # 1 hour
self._task_expiry = 86400 # 24 hours
# Check Redis availability on startup
if is_redis_available():
logger.info("📦 Task Manager: Using Redis for task storage")
else:
logger.info("📦 Task Manager: Using in-memory storage (Redis not available)")
# Start background cleanup thread
self._start_cleanup_thread()
@@ -37,19 +61,52 @@ class InMemoryTaskManager:
cleanup_thread.start()
def _cleanup_expired_tasks(self):
"""Remove tasks older than expiry time"""
"""Remove tasks older than expiry time (in-memory only, Redis has TTL)"""
with self._lock:
current_time = datetime.now()
expired_keys = []
for task_id, task_data in self._tasks.items():
created_at = datetime.fromisoformat(task_data.get("created_at", ""))
if current_time - created_at > timedelta(seconds=self._task_expiry):
expired_keys.append(task_id)
created_at_str = task_data.get("created_at", "")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str)
if current_time - created_at > timedelta(seconds=self._task_expiry):
expired_keys.append(task_id)
except Exception:
pass
for key in expired_keys:
del self._tasks[key]
def _save_to_storage(self, task_id: str, task_data: dict):
"""Save task to both Redis (if available) and in-memory"""
# Always save to in-memory (fast access)
with self._lock:
self._tasks[task_id] = task_data
# Also save to Redis if available (persistence)
if is_redis_available():
save_task_to_redis(task_id, task_data, self._task_expiry)
def _get_from_storage(self, task_id: str) -> Optional[dict]:
"""Get task from in-memory first, then Redis"""
# Check in-memory first (fastest)
with self._lock:
if task_id in self._tasks:
return self._tasks[task_id]
# Try Redis if not in memory (e.g., after server restart)
if is_redis_available():
redis_data = get_task_from_redis(task_id)
if redis_data:
# Cache in memory for future access
with self._lock:
self._tasks[task_id] = redis_data
return redis_data
return None
def create_task(self, initial_data: Dict[str, Any]) -> str:
"""
Create a new task with initial data
@@ -72,9 +129,7 @@ class InMemoryTaskManager:
**initial_data
}
with self._lock:
self._tasks[task_id] = task_data
self._save_to_storage(task_id, task_data)
return task_id
def update_task_status(self, task_id: str, status: str, progress: Optional[str] = None):
@@ -86,12 +141,13 @@
status: New status (pending, running, completed, failed)
progress: Optional progress message
"""
with self._lock:
if task_id in self._tasks:
self._tasks[task_id]["status"] = status
if progress:
self._tasks[task_id]["progress"] = progress
self._tasks[task_id]["updated_at"] = datetime.now().isoformat()
task_data = self._get_from_storage(task_id)
if task_data:
task_data["status"] = status
if progress:
task_data["progress"] = progress
task_data["updated_at"] = datetime.now().isoformat()
self._save_to_storage(task_id, task_data)
def update_task_progress(self, task_id: str, progress: str):
"""
@@ -101,10 +157,11 @@
task_id: Task ID
progress: Progress message
"""
with self._lock:
if task_id in self._tasks:
self._tasks[task_id]["progress"] = progress
self._tasks[task_id]["updated_at"] = datetime.now().isoformat()
task_data = self._get_from_storage(task_id)
if task_data:
task_data["progress"] = progress
task_data["updated_at"] = datetime.now().isoformat()
self._save_to_storage(task_id, task_data)
def set_task_result(self, task_id: str, result: Any):
"""
@@ -112,14 +169,15 @@
Args:
task_id: Task ID
result: Task result (will be JSON serialized)
result: Task result
"""
with self._lock:
if task_id in self._tasks:
self._tasks[task_id]["status"] = "completed"
self._tasks[task_id]["result"] = result
self._tasks[task_id]["progress"] = "Analysis completed"
self._tasks[task_id]["completed_at"] = datetime.now().isoformat()
task_data = self._get_from_storage(task_id)
if task_data:
task_data["status"] = "completed"
task_data["result"] = result
task_data["progress"] = "Analysis completed"
task_data["completed_at"] = datetime.now().isoformat()
self._save_to_storage(task_id, task_data)
def set_task_error(self, task_id: str, error: str):
"""
@@ -129,12 +187,13 @@
task_id: Task ID
error: Error message
"""
with self._lock:
if task_id in self._tasks:
self._tasks[task_id]["status"] = "failed"
self._tasks[task_id]["error"] = error
self._tasks[task_id]["progress"] = "Analysis failed"
self._tasks[task_id]["failed_at"] = datetime.now().isoformat()
task_data = self._get_from_storage(task_id)
if task_data:
task_data["status"] = "failed"
task_data["error"] = error
task_data["progress"] = "Analysis failed"
task_data["failed_at"] = datetime.now().isoformat()
self._save_to_storage(task_id, task_data)
def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
"""
@@ -146,8 +205,7 @@ class InMemoryTaskManager:
Returns:
Task data or None if not found
"""
with self._lock:
return self._tasks.get(task_id)
return self._get_from_storage(task_id)
def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
"""
@@ -157,7 +215,7 @@
task_id: Task ID
Returns:
Dictionary with task status information including all required fields
Dictionary with task status information
"""
task = self.get_task(task_id)
if not task:
@@ -167,7 +225,7 @@
"task_id": task["task_id"],
"status": task["status"],
"created_at": task.get("created_at"),
"updated_at": task.get("updated_at", task.get("created_at")), # Fallback to created_at if updated_at not set
"updated_at": task.get("updated_at", task.get("created_at")),
"progress": task.get("progress"),
"result": task.get("result"),
"error": task.get("error"),
@@ -184,17 +242,20 @@ class InMemoryTaskManager:
with self._lock:
if task_id in self._tasks:
del self._tasks[task_id]
if is_redis_available():
delete_task_from_redis(task_id)
def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
"""
Get all tasks (for debugging)
Returns:
Dictionary of all tasks
Dictionary of all tasks (in-memory only)
"""
with self._lock:
return self._tasks.copy()
# Global task manager instance
task_manager = InMemoryTaskManager()
task_manager = HybridTaskManager()
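A rough sketch of how the hybrid manager might be exercised via the singleton defined above (the ticker, progress text, and result payload are made-up illustration):

# Uses the module-level task_manager singleton.
task_id = task_manager.create_task({"ticker": "AAPL", "status": "pending"})
task_manager.update_task_status(task_id, "running", progress="Fetching data")
task_manager.set_task_result(task_id, {"decision": "HOLD"})

status = task_manager.get_task_status(task_id)
print(status["status"], status["progress"])  # "completed" "Analysis completed"
# With REDIS_URL configured, the task is also persisted to Redis and survives a restart;
# without it, the task lives only in this process's memory.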

View File

@@ -22,6 +22,9 @@ PyJWT>=2.8.0
cryptography>=41.0.0
httpx>=0.25.0
# Redis (optional - for production caching)
redis>=5.0.0
# Existing TradingAgentsX dependencies
typing-extensions
langchain-openai

View File

@@ -27,7 +27,7 @@ import {
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { Trash2, Eye, RefreshCw, TrendingUp, Cloud, CloudOff, FileText, Download } from "lucide-react";
import { Trash2, Eye, RefreshCw, TrendingUp, CloudOff, FileText, Download } from "lucide-react";
import {
getReportsByMarketType,
deleteReport,
@@ -152,9 +152,7 @@ export default function HistoryPage() {
);
const [deleting, setDeleting] = useState(false);
// Sync state
const [syncing, setSyncing] = useState(false);
const [syncResult, setSyncResult] = useState<{ success: number; failed: number } | null>(null);
// Auto-sync tracking ref
const hasAutoSyncedRef = useRef(false);
// Load reports when tab changes or auth state changes
@ -352,83 +350,6 @@ export default function HistoryPage() {
}
};
// Sync local reports to cloud
const handleSyncToCloud = async () => {
if (!isAuthenticated || !isCloudSyncEnabled()) {
alert("請先登入以啟用雲端同步");
return;
}
setSyncing(true);
setSyncResult(null);
try {
// Get all local reports
const [usLocal, twseLocal, tpexLocal] = await Promise.all([
getReportsByMarketType("us"),
getReportsByMarketType("twse"),
getReportsByMarketType("tpex"),
]);
const allLocal = [...usLocal, ...twseLocal, ...tpexLocal];
// Get cloud reports to check for duplicates
const cloudReports = await getCloudReports();
const cloudKeys = new Set(
cloudReports.map(r => `${r.ticker}_${r.analysis_date}`)
);
// Find local-only reports to upload
const toUpload = allLocal.filter(
r => !cloudKeys.has(`${r.ticker}_${r.analysis_date}`)
);
if (toUpload.length === 0) {
setSyncResult({ success: 0, failed: 0 });
alert("所有報告已同步到雲端!");
return;
}
// Upload each report
let success = 0;
let failed = 0;
for (const report of toUpload) {
try {
const cloudId = await saveCloudReport({
ticker: report.ticker,
market_type: report.market_type,
analysis_date: report.analysis_date,
result: report.result,
});
if (cloudId) {
success++;
} else {
failed++;
}
} catch (e) {
failed++;
}
}
setSyncResult({ success, failed });
// Reload data after sync
await loadReports();
await loadCounts();
if (failed === 0) {
alert(`成功同步 ${success} 份報告到雲端!`);
} else {
alert(`同步完成:${success} 成功,${failed} 失敗`);
}
} catch (error) {
console.error("Sync failed:", error);
alert("同步失敗,請稍後再試");
} finally {
setSyncing(false);
}
};
const handleViewReport = (report: SavedReport) => {
// Set the context with the saved report data
setAnalysisResult(report.result);
@ -592,25 +513,8 @@ export default function HistoryPage() {
(marketType) => (
<TabsContent key={marketType} value={marketType} className="mt-6">
<div className="space-y-4">
{/* Action buttons */}
<div className="flex justify-end gap-2">
{/* Sync to Cloud button - only show when authenticated */}
{isAuthenticated && (
<Button
variant="outline"
size="sm"
onClick={handleSyncToCloud}
disabled={syncing || loading}
className="gap-2"
>
<Cloud
className={`h-4 w-4 ${syncing ? "animate-pulse" : ""}`}
/>
{syncing ? "同步中..." : "同步到雲端"}
</Button>
)}
{/* Refresh button */}
{/* Refresh button */}
<div className="flex justify-end">
<Button
variant="outline"
size="sm"