This commit is contained in:
MarkLo127 2026-03-12 09:30:42 +08:00
parent 2a8bc6ade8
commit 69eba15bbb
7 changed files with 244 additions and 72 deletions

View File

@ -1,11 +1,14 @@
"""
Shared dependencies for API routes
"""
import logging
from typing import Optional, Dict, Any
from fastapi import Depends, HTTPException, Header
from backend.app.services.trading_service import TradingService, trading_service
from backend.app.services.auth_utils import verify_access_token
logger = logging.getLogger(__name__)
def get_trading_service() -> TradingService:
"""Dependency to get trading service instance"""
@ -17,24 +20,29 @@ async def get_current_user_optional(
) -> Optional[Dict[str, Any]]:
"""
Get current user from JWT token (optional - returns None if not authenticated)
Use this for endpoints that work both with and without authentication.
All exceptions are caught to prevent 500 errors on malformed tokens.
"""
if not authorization or not authorization.startswith("Bearer "):
return None
token = authorization.replace("Bearer ", "")
payload = verify_access_token(token)
if not payload:
try:
token = authorization.replace("Bearer ", "")
payload = verify_access_token(token)
if not payload:
return None
return {
"id": payload.get("sub"),
"email": payload.get("email"),
"name": payload.get("name"),
"avatar_url": payload.get("avatar_url"),
}
except Exception as e:
logger.warning(f"Token validation error in optional auth: {type(e).__name__}")
return None
return {
"id": payload.get("sub"),
"email": payload.get("email"),
"name": payload.get("name"),
"avatar_url": payload.get("avatar_url"),
}
async def get_current_user_required(

View File

@ -93,14 +93,21 @@ async def run_analysis(
# Log with user info for tracking
user_info = f"user={current_user['email']}" if current_user else "user=anonymous"
logger.info(f"Creating analysis task for {request.ticker} on {request.analysis_date} ({user_info})")
# Create task in Redis with user info
task_id = task_manager.create_task({
"ticker": request.ticker,
"analysis_date": request.analysis_date,
"user_id": current_user["id"] if current_user else None,
"user_email": current_user["email"] if current_user else None,
})
try:
task_id = task_manager.create_task({
"ticker": request.ticker,
"analysis_date": request.analysis_date,
"user_id": current_user["id"] if current_user else None,
"user_email": current_user["email"] if current_user else None,
})
except Exception as e:
logger.error(f"Failed to create analysis task: {type(e).__name__}: {str(e)}", exc_info=True)
raise HTTPException(
status_code=503,
detail=f"Task creation failed ({type(e).__name__}). The task service may be temporarily unavailable. Please try again."
)
# Start background analysis
def run_background_analysis():

View File

@ -7,7 +7,8 @@ from uuid import UUID
from fastapi import APIRouter, HTTPException, Depends, Header
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete, func
from sqlalchemy import select, delete, func, text
from sqlalchemy.orm import defer
from datetime import datetime
from backend.app.db import get_db, User, UserSettings, Report
@ -155,7 +156,8 @@ async def get_reports(
market_type: Optional[str] = None,
language: Optional[str] = None,
limit: int = 100,
offset: int = 0
offset: int = 0,
include_result: bool = True,
):
"""Get user's reports with optional filtering and pagination
@ -164,6 +166,8 @@ async def get_reports(
language: Filter by language (en, zh-TW)
limit: Maximum number of reports to return (default 100, max 500)
offset: Number of reports to skip for pagination
include_result: Whether to include result JSONB field (default True for backward compat).
Set to False for lightweight list queries.
"""
# Cap limit at 500 to prevent memory issues
limit = min(limit, 500)
@ -171,31 +175,35 @@ async def get_reports(
# Build query with filters
query = select(Report).where(Report.user_id == user.id)
# Skip loading the large result JSONB column when not needed
if not include_result:
query = query.options(defer(Report.result))
if market_type:
query = query.where(Report.market_type == market_type)
if language:
query = query.where(Report.language == language)
query = query.where(func.coalesce(Report.language, "zh-TW") == language)
# Order by created_at DESC and apply pagination
query = query.order_by(Report.created_at.desc()).offset(offset).limit(limit)
result = await db.execute(query)
reports = result.scalars().all()
# Process reports to strip large payloads for the list view
# Process reports
optimized_reports = []
for r in reports:
# Create a copy of the result to avoid modifying SQLAlchemy objects directly
if r.result and isinstance(r.result, dict):
# Shallow copy the dictionary
if include_result and r.result and isinstance(r.result, dict):
# Shallow copy the dictionary and strip massive reports field
optimized_result = dict(r.result)
# Remove the massive reports field if it exists
if "reports" in optimized_result:
optimized_result["reports"] = None
elif include_result:
optimized_result = r.result or {}
else:
optimized_result = r.result
optimized_result = {}
optimized_reports.append(
ReportResponse(
id=str(r.id),
@ -263,6 +271,40 @@ async def create_report(
}
@router.get("/reports/counts")
async def get_report_counts(
    user: User = Depends(get_current_user_required),
    db: AsyncSession = Depends(get_db),
    language: Optional[str] = None,
):
    """Get report counts by market_type without loading full rows (lightweight query)"""
    # Aggregate in the database so only (market_type, count) pairs are
    # transferred instead of full report rows.
    stmt = (
        select(Report.market_type, func.count(Report.id).label("count"))
        .where(Report.user_id == user.id)
        .group_by(Report.market_type)
    )
    if language:
        # NULL language is treated as the zh-TW default, matching the
        # filtering used by the report list endpoint.
        stmt = stmt.where(func.coalesce(Report.language, "zh-TW") == language)

    query_result = await db.execute(stmt)

    # Only the known market buckets are reported; any other market_type
    # value is ignored.
    counts = {"us": 0, "twse": 0, "tpex": 0}
    for market, n in query_result.all():
        if market in counts:
            counts[market] = n

    return {"counts": counts, "total": sum(counts.values())}
@router.get("/reports/{report_id}")
async def get_report(
report_id: str,
@ -299,39 +341,34 @@ async def cleanup_duplicate_reports(
user: User = Depends(get_current_user_required),
db: AsyncSession = Depends(get_db)
):
"""Remove duplicate reports, keeping only the most recent one per (ticker, analysis_date, market_type, language)"""
# Fetch all user reports ordered newest first
result = await db.execute(
select(Report)
.where(Report.user_id == user.id)
.order_by(Report.created_at.desc())
)
all_reports = result.scalars().all()
seen: set = set()
ids_to_delete: list = []
for report in all_reports:
# Normalize language
lang = report.language or "zh-TW"
key = (report.ticker, report.analysis_date, report.market_type, lang)
if key in seen:
ids_to_delete.append(report.id)
else:
seen.add(key)
if ids_to_delete:
await db.execute(
delete(Report)
.where(Report.user_id == user.id)
.where(Report.id.in_(ids_to_delete))
"""Remove duplicate reports using SQL-level deduplication.
Keeps only the most recent one per (ticker, analysis_date, market_type, language).
"""
# Use SQL window function to find duplicates without loading all rows into memory
result = await db.execute(text("""
DELETE FROM reports
WHERE id IN (
SELECT id FROM (
SELECT id,
ROW_NUMBER() OVER (
PARTITION BY user_id, ticker, analysis_date, market_type,
COALESCE(language, 'zh-TW')
ORDER BY created_at DESC
) as rn
FROM reports
WHERE user_id = :user_id
) ranked
WHERE rn > 1
)
await db.commit()
"""), {"user_id": str(user.id)})
deleted_count = result.rowcount
await db.commit()
return {
"success": True,
"deleted": len(ids_to_delete),
"message": f"Cleaned up {len(ids_to_delete)} duplicate reports"
"deleted": deleted_count,
"message": f"Cleaned up {deleted_count} duplicate reports"
}

View File

@ -77,6 +77,8 @@ async def init_db():
await conn.execute(text("CREATE INDEX IF NOT EXISTS ix_reports_language ON reports (language);"))
# Add composite index for user + market_type + language queries
await conn.execute(text("CREATE INDEX IF NOT EXISTS ix_reports_user_market_lang ON reports (user_id, market_type, language);"))
# Covering index for counts query (GROUP BY market_type with language filter)
await conn.execute(text("CREATE INDEX IF NOT EXISTS ix_reports_user_market_lang_count ON reports (user_id, market_type, COALESCE(language, 'zh-TW'));"))
except Exception as e:
print(f"Skipping manual migration (might be SQLite or syntax not supported): {e}")

View File

@ -171,10 +171,10 @@ async def root():
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
"""Global exception handler - masks sensitive data in errors"""
async def global_exception_handler(request: Request, exc: Exception):
"""Global exception handler - returns appropriate status codes and masks sensitive data"""
error_msg = str(exc)
# Mask any API keys that might be in the error message
import re
patterns = ["sk-", "sk-ant-", "xai-", "AIza"]
@ -184,14 +184,32 @@ async def global_exception_handler(request, exc):
f'{pattern}**********',
error_msg
)
logger.error(f"Unhandled exception: {error_msg}", exc_info=True)
# Determine appropriate status code based on exception type
error_type = type(exc).__name__
status_code = 500
if isinstance(exc, ValueError):
status_code = 400
elif isinstance(exc, (ConnectionError, ConnectionRefusedError, TimeoutError)):
status_code = 503
elif isinstance(exc, PermissionError):
status_code = 403
elif isinstance(exc, FileNotFoundError):
status_code = 404
logger.error(
f"Unhandled exception on {request.method} {request.url.path}: "
f"[{error_type}] {error_msg}",
exc_info=True
)
return JSONResponse(
status_code=500,
status_code=status_code,
content={
"error": "Internal server error",
"error": "Internal server error" if status_code == 500 else error_type,
"detail": error_msg,
"type": type(exc).__name__,
"type": error_type,
"path": str(request.url.path),
},
)

View File

@ -0,0 +1,81 @@
/**
* Shared utilities for report deduplication, language detection, and date parsing.
* Single source of truth imported by history/page.tsx, analysis/page.tsx, results/page.tsx, etc.
*/
/**
 * Normalize a report's language field. Any value other than "en"
 * (including null/undefined) collapses to "zh-TW", mirroring the
 * backend's COALESCE(language, 'zh-TW') behavior.
 */
export function normalizeLanguage(
  language?: string | null
): "en" | "zh-TW" {
  return language === "en" ? "en" : "zh-TW";
}
/**
 * Detect report language from content (for backward compatibility with old
 * reports that don't have a language field stored).
 *
 * Checks trader_investment_plan for Chinese/English decision keywords,
 * then falls back to scanning for Chinese characters.
 */
export function detectReportLanguage(reports: any): "en" | "zh-TW" {
  const chineseRegex = /[\u4e00-\u9fa5]/;
  const traderPlan = reports?.trader_investment_plan;

  if (!traderPlan || typeof traderPlan !== "string") {
    // No trader plan: scan the serialized reports for Chinese characters.
    const allText = JSON.stringify(reports || {});
    return chineseRegex.test(allText) ? "zh-TW" : "en";
  }

  // Chinese decision keywords (buy / sell / hold / final trading proposal).
  const chineseKeywords = ["買入", "賣出", "持有", "最終交易提案"];
  for (const keyword of chineseKeywords) {
    if (traderPlan.includes(keyword)) {
      return "zh-TW";
    }
  }

  // English decision keywords, matched on word boundaries so substrings
  // such as "threshold" or "stakeholder" do not falsely register as "hold".
  const englishKeywordRegex = /\b(buy|sell|hold)\b|final trading proposal/;
  if (englishKeywordRegex.test(traderPlan.toLowerCase())) {
    return "en";
  }

  // Fallback: any Chinese character in the plan implies zh-TW.
  return chineseRegex.test(traderPlan) ? "zh-TW" : "en";
}
/**
 * Build a stable deduplication signature for a report.
 *
 * Combines ticker + date + market_type + language. market_type defaults
 * to "us" when absent, and language is normalized (null/undefined ->
 * "zh-TW") to match backend behavior.
 */
export function getReportSignature(report: {
  ticker: string;
  analysis_date: string;
  market_type?: string;
  language?: string | null;
}): string {
  const market = report.market_type || "us";
  const lang = normalizeLanguage(report.language);
  return [report.ticker, report.analysis_date, market, lang].join("_");
}
/**
 * Parse a date string from the backend as UTC.
 *
 * Backend stores created_at in UTC but may not always include timezone
 * info. When no timezone designator is present, a trailing "Z" is appended
 * so the value is interpreted as UTC and the browser converts it to the
 * user's local timezone for display.
 */
export function parseUTCDate(dateStr: string): Date {
  // Recognized designators: trailing Z/z, or a numeric offset in either
  // "+HH:MM" or compact "+HHMM" form (likewise for "-"). The original
  // check missed lowercase "z" and compact offsets, which would have
  // produced invalid strings like "...+0800Z".
  const hasTimezone = /([Zz]|[+-]\d{2}:?\d{2})$/.test(dateStr);
  if (hasTimezone) {
    return new Date(dateStr);
  }
  // Otherwise, append 'Z' to treat the timestamp as UTC.
  return new Date(dateStr + "Z");
}

View File

@ -5,6 +5,7 @@
import Dexie, { type Table } from "dexie";
import type { AnalysisResponse } from "./types";
import { normalizeLanguage } from "./report-utils";
// Saved report interface
export interface SavedReport {
@ -16,6 +17,8 @@ export interface SavedReport {
task_id?: string; // Original task ID
result: AnalysisResponse; // Full analysis result
language?: "en" | "zh-TW"; // Language of the report (for filtering)
cloud_id?: string; // Corresponding cloud report ID (for sync tracking)
pending_sync?: boolean; // Whether report is waiting to be synced to cloud
}
// Database class extending Dexie
@ -32,6 +35,11 @@ class ReportsDatabase extends Dexie {
this.version(2).stores({
reports: "++id, ticker, market_type, analysis_date, saved_at, language",
});
// Version 3: Added cloud_id and pending_sync for sync tracking
this.version(3).stores({
reports:
"++id, ticker, market_type, analysis_date, saved_at, language, cloud_id, pending_sync",
});
}
}
@ -123,22 +131,32 @@ export async function getReportCountByMarketType(): Promise<{
}
/**
* Check if a report with the same ticker and analysis_date already exists
* Check if a report with the same signature already exists.
* Supports optional market_type and language for precise matching.
*/
export async function checkDuplicateReport(
ticker: string,
analysis_date: string,
market_type?: "us" | "twse" | "tpex",
language?: "en" | "zh-TW",
): Promise<SavedReport | undefined> {
const normalizedLang = normalizeLanguage(language);
return await db.reports
.where("ticker")
.equals(ticker)
.and((report) => report.analysis_date === analysis_date)
.and((report) => {
if (report.analysis_date !== analysis_date) return false;
if (market_type && report.market_type !== market_type) return false;
if (normalizeLanguage(report.language) !== normalizedLang) return false;
return true;
})
.first();
}
/**
* Check if a report exists by ticker, date, market type, and language
* Used for bidirectional sync to prevent duplicates
* Used for bidirectional sync to prevent duplicates.
* Language is normalized so null/undefined matches "zh-TW".
*/
export async function findExistingReport(
ticker: string,
@ -146,6 +164,7 @@ export async function findExistingReport(
market_type: "us" | "twse" | "tpex",
language?: "en" | "zh-TW",
): Promise<SavedReport | undefined> {
const normalizedLang = normalizeLanguage(language);
return await db.reports
.where("ticker")
.equals(ticker)
@ -153,7 +172,7 @@ export async function findExistingReport(
(report) =>
report.analysis_date === analysis_date &&
report.market_type === market_type &&
report.language === language
normalizeLanguage(report.language) === normalizedLang
)
.first();
}