add frontend

Marvin Gabler 2025-10-21 23:55:54 +02:00
parent b6275e731d
commit 5887986f73
10 changed files with 280 additions and 84 deletions

View File

@@ -1,65 +0,0 @@
# Bug Fix: ChromaDB Collection Name Collision
## Issue Description
When running multiple analyses through the API (either concurrently or sequentially), the system would fail with:
```
chromadb.errors.InternalError: Collection [bull_memory] already exists
```
### Root Cause
The `TradingAgentsGraph` class was creating ChromaDB memory collections with **hardcoded names**:
- `bull_memory`
- `bear_memory`
- `trader_memory`
- `invest_judge_memory`
- `risk_manager_memory`
When multiple analyses ran (even for different tickers), they all tried to create collections with the same names, causing ChromaDB to reject duplicate collection creation.
**Location of the bug:**
- `tradingagents/graph/trading_graph.py` lines 90-94
- `tradingagents/agents/utils/memory.py` line 14
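The collision is easy to reproduce in isolation; a minimal sketch, assuming an in-process ChromaDB client (the exact exception type varies by ChromaDB version):
```python
import chromadb

client = chromadb.Client()
client.create_collection("bull_memory")
# A second creation with the same hardcoded name fails, e.g.:
# chromadb.errors.InternalError: Collection [bull_memory] already exists
client.create_collection("bull_memory")
```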
## Solution Implemented
### Changes Made
1. **Modified `TradingAgentsGraph.__init__`** (`tradingagents/graph/trading_graph.py`):
- Added optional `analysis_id` parameter
- Collection names now include the analysis ID as a suffix: `bull_memory_{analysis_id}` (see the sketch after this list)
- When `analysis_id` is None, collections use original names (backward compatibility)
2. **Modified `state_manager.py`** (`api/state_manager.py`):
- Pass the unique `analysis_id` when creating `TradingAgentsGraph`
- Added cleanup in `finally` block to delete collections after analysis completes
3. **Added cleanup method** (`tradingagents/graph/trading_graph.py`):
- New `cleanup_memories()` method to delete ChromaDB collections
- Called after each analysis (success or failure) to prevent memory leaks
- Prevents accumulation of old collections in the database
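A minimal sketch of both changes, assuming the graph holds a ChromaDB client directly (the real constructor takes more parameters and wraps each collection in the memory class from `tradingagents/agents/utils/memory.py`):
```python
import chromadb
from typing import Optional

class TradingAgentsGraph:
    def __init__(self, analysis_id: Optional[str] = None):
        self.chroma_client = chromadb.Client()
        # analysis_id=None keeps the original names (CLI/standalone unchanged)
        suffix = f"_{analysis_id}" if analysis_id else ""
        self.memory_names = [
            f"bull_memory{suffix}",
            f"bear_memory{suffix}",
            f"trader_memory{suffix}",
            f"invest_judge_memory{suffix}",
            f"risk_manager_memory{suffix}",
        ]
        self.memories = {
            name: self.chroma_client.create_collection(name)
            for name in self.memory_names
        }

    def cleanup_memories(self) -> None:
        """Delete this run's collections so stale ones don't accumulate."""
        for name in self.memory_names:
            try:
                self.chroma_client.delete_collection(name)
            except Exception:
                pass  # collection already gone; nothing to clean up
```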
### Backward Compatibility
The fix is **fully backward compatible**:
- CLI usage (`cli/main.py`) - continues to work without `analysis_id`
- Standalone usage (`main.py`) - continues to work without `analysis_id`
- API usage - now provides unique `analysis_id` for isolation
## Testing Recommendations
1. **Test concurrent analyses**: Run multiple analyses simultaneously for the same or different tickers
2. **Test sequential analyses**: Run multiple analyses one after another for the same ticker
3. **Test failure scenarios**: Ensure collections are cleaned up even when analysis fails
4. **Test CLI**: Verify CLI still works without regression
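A sketch covering the first three cases, assuming `propagate` is the run entry point (adjust to the real signature):
```python
import uuid
from concurrent.futures import ThreadPoolExecutor

def run_analysis(ticker: str) -> None:
    # A unique ID per run is what isolates the ChromaDB collections
    graph = TradingAgentsGraph(analysis_id=str(uuid.uuid4()))
    try:
        graph.propagate(ticker, "2025-10-21")  # assumed entry point
    finally:
        graph.cleanup_memories()  # must run on success *and* failure

# Same ticker twice plus a different ticker, all at once
with ThreadPoolExecutor(max_workers=3) as pool:
    list(pool.map(run_analysis, ["AAPL", "AAPL", "NVDA"]))
```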
## Benefits
✅ Multiple analyses can now run concurrently without conflicts
✅ Same ticker can be analyzed multiple times without errors
✅ Memory collections are properly cleaned up after each analysis
✅ No breaking changes to existing code
✅ Prevents ChromaDB from accumulating stale collections

View File

@@ -20,6 +20,7 @@ from api.models import (
ReportResponse,
)
from api.state_manager import get_executor
from api.utils import extract_trading_decision
from tradingagents.default_config import DEFAULT_CONFIG
router = APIRouter(prefix="/api/v1/analyses", tags=["analyses"])
@@ -140,19 +141,30 @@ async def list_analyses(
# Apply pagination
analyses = query.offset(offset).limit(limit).all()
-    return [
-        AnalysisSummary(
-            id=a.id,
-            ticker=a.ticker,
-            analysis_date=a.analysis_date,
-            status=a.status,
-            created_at=a.created_at,
-            completed_at=a.completed_at,
-            error_message=a.error_message,
-        )
-        for a in analyses
-    ]
+    results = []
+    for a in analyses:
+        # Get the trading decision for completed analyses
+        trading_decision = None
+        if a.status == "completed":
+            reports = db.query(AnalysisReport).filter(AnalysisReport.analysis_id == a.id).all()
+            if reports:
+                trading_decision = extract_trading_decision(reports)
+        results.append(
+            AnalysisSummary(
+                id=a.id,
+                ticker=a.ticker,
+                analysis_date=a.analysis_date,
+                status=a.status,
+                created_at=a.created_at,
+                completed_at=a.completed_at,
+                error_message=a.error_message,
+                trading_decision=trading_decision,
+            )
+        )
+    return results
@router.get("/{analysis_id}", response_model=AnalysisResponse)
@@ -177,6 +189,11 @@ async def get_analysis(
.all()
)
# Extract trading decision if analysis is completed
trading_decision = None
if analysis.status == "completed" and reports:
trading_decision = extract_trading_decision(reports)
return AnalysisResponse(
id=analysis.id,
ticker=analysis.ticker,
@@ -197,6 +214,7 @@ async def get_analysis(
updated_at=analysis.updated_at,
completed_at=analysis.completed_at,
error_message=analysis.error_message,
trading_decision=trading_decision,
)

View File

@@ -3,11 +3,13 @@
 import csv
 import glob
 import os
-from datetime import datetime
+from datetime import datetime, timedelta
 from pathlib import Path
 from typing import Dict, List, Optional

 from fastapi import APIRouter, Depends, HTTPException, Query, status

+from cli.asset_detection import detect_asset_class
+from tradingagents.dataflows.interface import route_to_vendor
 from api.auth import APIKey, get_current_api_key
 from api.models.responses import CachedDataResponse, CachedTickerInfo
@@ -32,6 +34,104 @@ def _parse_date_range(filename: str) -> Optional[Dict[str, str]]:
return None
def _normalize_ohlcv_rows_from_csv(csv_text: str) -> List[Dict[str, str]]:
"""Normalize various vendor CSV formats to standard OHLCV schema.
Output fields: Date, Close, High, Low, Open, Volume
"""
import io
rows: List[Dict[str, str]] = []
if not csv_text:
return rows
f = io.StringIO(csv_text)
reader = csv.DictReader(f)
# Map common header variants to our standard fields
def get_field(d: Dict[str, str], *candidates: str) -> Optional[str]:
for c in candidates:
if c in d and d[c] not in (None, ""):
return d[c]
# case-insensitive
for k in d.keys():
if k.lower() == c.lower() and d[k] not in (None, ""):
return d[k]
return None
for r in reader:
date_val = get_field(r, "Date", "date", "time", "timestamp")
open_val = get_field(r, "Open", "open")
high_val = get_field(r, "High", "high")
low_val = get_field(r, "Low", "low")
close_val = get_field(r, "Close", "close")
volume_val = get_field(r, "Volume", "volume")
if not date_val:
# Skip rows without date
continue
rows.append({
"Date": str(date_val)[:10], # ensure YYYY-MM-DD
"Close": close_val if close_val is not None else "",
"High": high_val if high_val is not None else "",
"Low": low_val if low_val is not None else "",
"Open": open_val if open_val is not None else "",
"Volume": volume_val if volume_val is not None else "",
})
return rows
def _write_cache_csv(ticker: str, start_date: str, end_date: str, rows: List[Dict[str, str]]) -> Path:
"""Write normalized OHLCV rows to cache using standard filename pattern."""
DATA_CACHE_DIR.mkdir(parents=True, exist_ok=True)
out_path = DATA_CACHE_DIR / f"{ticker.upper()}-YFin-data-{start_date}-{end_date}.csv"
with open(out_path, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=["Date", "Close", "High", "Low", "Open", "Volume"])
writer.writeheader()
for r in rows:
writer.writerow(r)
return out_path
def _ensure_cached_data(ticker: str, start_date: Optional[str], end_date: Optional[str]) -> Optional[Path]:
"""Ensure OHLCV cache exists for ticker. If missing, fetch via vendor and write cache.
Returns the cache file path if created, else None.
"""
# Determine date window if not provided: last ~15 years
today = datetime.utcnow().date()
default_start = (today - timedelta(days=365 * 15)).strftime("%Y-%m-%d")
default_end = today.strftime("%Y-%m-%d")
start = (start_date or default_start)
end = (end_date or default_end)
pattern = f"{ticker.upper()}-YFin-data-*.csv"
existing = list(DATA_CACHE_DIR.glob(pattern))
if existing:
return None # already present
# Detect asset class and fetch
asset_class = detect_asset_class(ticker)
try:
if asset_class == "crypto":
csv_text = route_to_vendor("get_crypto_data", ticker.upper(), start, end, "USD")
elif asset_class == "commodity":
csv_text = route_to_vendor("get_commodity_data", ticker.upper(), start, end, "daily")
else:
csv_text = route_to_vendor("get_stock_data", ticker.upper(), start, end)
    except Exception:
        # If the vendor fetch fails, don't block the request
        return None
rows = _normalize_ohlcv_rows_from_csv(csv_text)
if not rows:
return None
# Sort by date to be safe
rows.sort(key=lambda r: r.get("Date", ""))
return _write_cache_csv(ticker, start, end, rows)
@router.get("/cache", response_model=List[CachedTickerInfo])
async def list_cached_tickers(
api_key: APIKey = Depends(get_current_api_key),
@@ -73,6 +173,9 @@ async def get_cached_data(
api_key: APIKey = Depends(get_current_api_key),
):
"""Get cached market data for a ticker."""
# Ensure cache exists (auto-fetches if missing for crypto/commodities/stocks)
_ensure_cached_data(ticker, start_date, end_date)
# Find matching file
pattern = f"{ticker.upper()}-YFin-data-*.csv"
matching_files = list(DATA_CACHE_DIR.glob(pattern))
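For illustration, the first request for an uncached ticker now triggers a vendor fetch and cache write, while later requests read the file directly; a hedged sketch of the helper's behavior (dates hypothetical):
```python
# First call: no BTC-YFin-data-*.csv exists, so the helper detects the
# asset class, fetches via route_to_vendor, and writes something like
# DATA_CACHE_DIR / "BTC-YFin-data-2010-10-25-2025-10-21.csv"
path = _ensure_cached_data("btc", None, None)

# Second call: the glob pattern matches the existing file, so it
# returns None and the endpoint serves the cached CSV
assert _ensure_cached_data("btc", None, None) is None
```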

View File

@@ -9,6 +9,7 @@ from api.models.responses import (
LogEntry,
ReportResponse,
TickerInfo,
TradingDecision,
CachedDataResponse,
)
@@ -22,6 +23,7 @@ __all__ = [
"LogEntry",
"ReportResponse",
"TickerInfo",
"TradingDecision",
"CachedDataResponse",
]

View File

@@ -39,6 +39,14 @@ class AnalysisStatusResponse(BaseModel):
updated_at: datetime = Field(..., description="Last update timestamp")
class TradingDecision(BaseModel):
"""Trading decision extracted from analysis reports."""
decision: str = Field(..., description="Trading decision (BUY/SELL/HOLD)")
confidence: Optional[int] = Field(None, description="Confidence percentage")
rationale: Optional[str] = Field(None, description="Brief rationale for the decision")
class AnalysisSummary(BaseModel):
"""Summary view of an analysis."""
@@ -49,6 +57,7 @@ class AnalysisSummary(BaseModel):
created_at: datetime = Field(..., description="Creation timestamp")
completed_at: Optional[datetime] = Field(None, description="Completion timestamp")
error_message: Optional[str] = Field(None, description="Error message if failed")
trading_decision: Optional[TradingDecision] = Field(None, description="Extracted trading decision")
class AnalysisResponse(BaseModel):
@@ -68,6 +77,7 @@ class AnalysisResponse(BaseModel):
updated_at: datetime = Field(..., description="Last update timestamp")
completed_at: Optional[datetime] = Field(None, description="Completion timestamp")
error_message: Optional[str] = Field(None, description="Error message if failed")
trading_decision: Optional[TradingDecision] = Field(None, description="Extracted trading decision")
class TickerInfo(BaseModel):

api/utils.py Normal file
View File

@@ -0,0 +1,87 @@
"""Utility functions for the API."""
import re
from typing import List, Optional
from api.database import AnalysisReport
from api.models.responses import TradingDecision
def extract_trading_decision(reports: List[AnalysisReport]) -> Optional[TradingDecision]:
"""Extract trading decision from analysis reports."""
# Look for final_trade_decision first, then investment_plan as fallback
trade_report = None
# Priority order: final_trade_decision > investment_plan
for report_type in ['final_trade_decision', 'investment_plan']:
for report in reports:
if report.report_type == report_type:
trade_report = report
break
if trade_report:
break
if not trade_report:
return None
content = trade_report.content
# Initialize defaults
decision = 'HOLD'
confidence = None
rationale = None
# Look for patterns like "Final Verdict: Sell (partial reduction)"
# Updated to capture everything after the colon, not just the word
decision_patterns = [
(r'final\s+verdict:\s*([^.\n]+)', 'extract'),
(r'final\s+decision:\s*([^.\n]+)', 'extract'),
(r'trade\s+decision:\s*([^.\n]+)', 'extract'),
(r'recommendation:\s*([^.\n]+)', 'extract'),
(r'decision:\s*([^.\n]+)', 'extract'),
(r'verdict:\s*([^.\n]+)', 'extract'),
(r'action:\s*([^.\n]+)', 'extract'),
(r'suggest\s+(buying|selling|holding)', 'verb'),
(r'recommend\s+(buying|selling|holding)', 'verb')
]
# Check each pattern
for pattern, pattern_type in decision_patterns:
match = re.search(pattern, content, re.IGNORECASE | re.MULTILINE)
if match:
found_text = match.group(1).lower()
# Extract the actual decision from the text
if 'buy' in found_text and 'not' not in found_text and "don't" not in found_text:
decision = 'BUY'
elif 'sell' in found_text and 'not' not in found_text and "don't" not in found_text:
decision = 'SELL'
elif 'hold' in found_text:
decision = 'HOLD'
# For patterns that captured the full text after colon, use it as rationale
            if pattern_type == 'extract':
# Get the original text with proper capitalization
rationale = match.group(0).strip()
break
# Extract confidence if mentioned
confidence_match = re.search(r'(\d+)%?\s*(confidence|confident|certainty)', content, re.IGNORECASE)
if confidence_match:
confidence = int(confidence_match.group(1))
# If no rationale was extracted from the patterns, look for the decision line
if not rationale:
lines = content.split('\n')
for line in lines:
line_lower = line.lower()
if any(keyword in line_lower for keyword in ['verdict', 'decision', 'recommendation']):
# This line likely contains our decision
rationale = line.strip()
break
return TradingDecision(
decision=decision,
confidence=confidence,
rationale=rationale
)
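A quick sanity check of the heuristics with made-up report content; a stand-in object with the same two attributes works in place of an `AnalysisReport` row:
```python
class _StubReport:
    report_type = "final_trade_decision"
    content = "Final Verdict: Sell (partial reduction). I have 80% confidence in this call."

d = extract_trading_decision([_StubReport()])
print(d.decision, d.confidence)  # SELL 80
print(d.rationale)               # Final Verdict: Sell (partial reduction)
```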

View File

@@ -13,6 +13,7 @@
"autoprefixer": "^10.4.21",
"axios": "^1.12.2",
"framer-motion": "^12.23.24",
"lightweight-charts": "^4.2.3",
"lucide-react": "^0.546.0",
"postcss": "^8.5.6",
"react": "^19.1.1",
@@ -2991,6 +2992,12 @@
"integrity": "sha512-GWkBvjiSZK87ELrYOSESUYeVIc9mvLLf/nXalMOS5dYrgZq9o5OVkbZAVM06CVxYsCwH9BDZFPlQTlPA1j4ahA==",
"license": "MIT"
},
"node_modules/fancy-canvas": {
"version": "2.1.0",
"resolved": "https://registry.npmjs.org/fancy-canvas/-/fancy-canvas-2.1.0.tgz",
"integrity": "sha512-nifxXJ95JNLFR2NgRV4/MxVP45G9909wJTEKz5fg/TZS20JJZA6hfgRVh/bC9bwl2zBtBNcYPjiBE4njQHVBwQ==",
"license": "MIT"
},
"node_modules/fast-deep-equal": {
"version": "3.1.3",
"resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz",
@@ -3817,6 +3824,15 @@
"url": "https://opencollective.com/parcel"
}
},
"node_modules/lightweight-charts": {
"version": "4.2.3",
"resolved": "https://registry.npmjs.org/lightweight-charts/-/lightweight-charts-4.2.3.tgz",
"integrity": "sha512-5kS/2hY3wNYNzhnS8Gb+GAS07DX8GPF2YVDnd2NMC85gJVQ6RLU6YrXNgNJ6eg0AnWPwCnvaGtYmGky3HiLQEw==",
"license": "Apache-2.0",
"dependencies": {
"fancy-canvas": "2.1.0"
}
},
"node_modules/locate-path": {
"version": "6.0.0",
"resolved": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz",

View File

@@ -15,6 +15,7 @@
"autoprefixer": "^10.4.21",
"axios": "^1.12.2",
"framer-motion": "^12.23.24",
"lightweight-charts": "^4.2.3",
"lucide-react": "^0.546.0",
"postcss": "^8.5.6",
"react": "^19.1.1",

View File

@@ -4,9 +4,8 @@ from langchain_core.messages import HumanMessage, RemoveMessage
 from tradingagents.agents.utils.core_stock_tools import (
     get_stock_data
 )
-from tradingagents.agents.utils.technical_indicators_tools import (
-    get_indicators
-)
+# Note: get_indicators is imported from unified_market_tools (see below);
+# the old technical_indicators_tools version is deprecated
 from tradingagents.agents.utils.fundamental_data_tools import (
     get_fundamentals,
     get_balance_sheet,
@@ -24,10 +23,11 @@ from tradingagents.agents.utils.news_data_tools import (
 from tradingagents.agents.utils.crypto_data_tools import (
     get_crypto_data
 )
+# Unified tools provide a consistent interface across asset classes
 from tradingagents.agents.utils.unified_market_tools import (
     get_market_data,
     get_asset_news,
-    get_indicators,
+    get_indicators,  # uses the new unified signature, with an adapter for old vendor implementations
     get_global_news as get_global_news_unified,
 )

View File

@@ -67,8 +67,32 @@ def get_indicators(
Returns:
CSV-formatted data with requested technical indicators
"""
-    # Indicators are equity-specific for now
-    return route_to_vendor("get_indicators", symbol, start_date, end_date, indicators)
+    # Adapter: translate the new unified signature to the old vendor signature
+    # Old signature: (symbol, indicator, curr_date, look_back_days)
+    # New signature: (symbol, start_date, end_date, indicators)
+    from datetime import datetime
+
+    # Calculate look_back_days from the date range
+    start_dt = datetime.strptime(start_date, "%Y-%m-%d")
+    end_dt = datetime.strptime(end_date, "%Y-%m-%d")
+    look_back_days = (end_dt - start_dt).days
+
+    # Parse the comma-separated indicators, skipping empty entries
+    indicator_list = [ind.strip() for ind in indicators.split(",") if ind.strip()]
+
+    # Call the vendor for each indicator and collect the results
+    results = []
+    for indicator in indicator_list:
+        result = route_to_vendor("get_indicators", symbol, indicator, end_date, look_back_days)
+        results.append(result)
+
+    # Join the results with a divider line between indicators
+    if len(results) == 1:
+        return results[0]
+    return ("\n\n" + "=" * 80 + "\n\n").join(results)
@tool