# TradingAgents - Potential Improvements & Enhancements

**Date:** 2025-11-14
**Analysis by:** Claude (AI Code Analysis)

---

## Executive Summary

This document outlines potential improvements and enhancements for the TradingAgents framework. These suggestions focus on code quality, performance, maintainability, and feature additions that could benefit the project and its community.

---

## Category 1: Code Quality & Architecture

### 1.1 Add Type Hints Throughout Codebase

**Priority:** High
**Effort:** Medium
**Impact:** High maintainability

**Current State:**
Most files lack comprehensive type hints.

**Proposed:**
```python
from datetime import datetime
from typing import Any, Dict, Optional, Union


def get_stock_data(
    ticker: str,
    start_date: Union[str, datetime],
    end_date: Union[str, datetime],
    config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Fetch stock data for a given ticker and date range.

    Args:
        ticker: Stock ticker symbol (e.g., 'AAPL')
        start_date: Start date for data fetch
        end_date: End date for data fetch
        config: Optional configuration dictionary

    Returns:
        Dictionary containing stock data

    Raises:
        ValueError: If dates are invalid
        APIError: If API call fails (project-specific exception)
    """
    ...  # Implementation omitted; this example illustrates the signature only
```

**Benefits:**
- Better IDE autocomplete
- Catch type errors early
- Improved documentation
- Easier onboarding for contributors

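When a function returns `Dict[str, Any]`, a type checker cannot tell whether a key such as `'close'` exists. Where the shape of a result is stable, a `TypedDict` gives mypy or Pyright enough information to flag misspelled keys. The sketch below assumes a hypothetical OHLCV layout. The field names and dummy values are illustrative, not the project's actual data schema.

```python
from typing import List, TypedDict


class PriceBar(TypedDict):
    """One daily OHLCV bar (illustrative field names)."""
    date: str
    open: float
    high: float
    low: float
    close: float
    volume: int


class StockData(TypedDict):
    """Hypothetical structured return type for get_stock_data."""
    ticker: str
    bars: List[PriceBar]


def latest_close(data: StockData) -> float:
    # A type checker knows data["bars"] holds PriceBar dicts, so a typo such as
    # bar["clsoe"] is reported before the code ever runs.
    return data["bars"][-1]["close"]


# Dummy values for illustration only
sample: StockData = {
    "ticker": "XYZ",
    "bars": [
        {"date": "2024-01-02", "open": 100.0, "high": 101.0,
         "low": 99.0, "close": 100.5, "volume": 1000},
    ],
}
print(latest_close(sample))  # 100.5
```

At runtime a `TypedDict` is still a plain `dict`, so this adds no overhead. The checking happens only in static analysis.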
### 1.2 Implement Dependency Injection

**Priority:** Medium
**Effort:** High
**Impact:** Better testability

**Current State:**
Heavy use of global configuration and direct instantiation.

**Proposed:**
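```python
from typing import Any, Dict, Protocol


class DataVendor(Protocol):
    """Anything that can return stock data for a ticker and date."""

    def get_stock_data(self, ticker: str, date: str) -> dict:
        ...


class LLMProvider(Protocol):
    """Minimal interface for a chat model (illustrative)."""

    def invoke(self, prompt: str) -> Any:
        ...


Config = Dict[str, Any]


class TradingAgentsGraph:
    # Proposed constructor; differs from the current TradingAgentsGraph signature
    def __init__(
        self,
        data_vendor: DataVendor,
        llm_provider: LLMProvider,
        config: Config
    ):
        # Dependencies are passed in rather than created here, so tests can
        # substitute fakes and production code can swap vendors via config.
        self.data_vendor = data_vendor
        self.llm_provider = llm_provider
        self.config = config
```
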
**Benefits:**
- Easier testing with mocks
- More flexible architecture
- Better separation of concerns

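To illustrate the testing benefit, a hand-written fake that satisfies the `DataVendor` protocol can be injected directly, with no `unittest.mock` patching. `FakeVendor` and `PriceReporter` are hypothetical names. The protocol is restated so the snippet runs on its own.

```python
from typing import Dict, List, Protocol, Tuple


class DataVendor(Protocol):
    def get_stock_data(self, ticker: str, date: str) -> dict:
        ...


class FakeVendor:
    """Test double: returns canned data and records every call."""

    def __init__(self, prices: Dict[str, float]):
        self.prices = prices
        self.calls: List[Tuple[str, str]] = []

    def get_stock_data(self, ticker: str, date: str) -> dict:
        self.calls.append((ticker, date))
        return {"ticker": ticker, "date": date, "close": self.prices[ticker]}


class PriceReporter:
    """Hypothetical component that depends only on the protocol."""

    def __init__(self, data_vendor: DataVendor):
        self.data_vendor = data_vendor

    def describe(self, ticker: str, date: str) -> str:
        data = self.data_vendor.get_stock_data(ticker, date)
        return f"{data['ticker']} closed at {data['close']:.2f} on {data['date']}"


fake = FakeVendor({"NVDA": 100.0})
reporter = PriceReporter(fake)
print(reporter.describe("NVDA", "2024-05-10"))  # NVDA closed at 100.00 on 2024-05-10
```

Because `FakeVendor` never subclasses `DataVendor`, the check is structural. Any object with a matching `get_stock_data` method is accepted by the type checker.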
### 1.3 Add Comprehensive Logging

**Priority:** High
**Effort:** Medium
**Impact:** Better debugging and monitoring

**Proposed:**
```python
import logging
from typing import Any, Dict

from pythonjsonlogger import jsonlogger


def setup_logging(config: Dict[str, Any]) -> logging.Logger:
    """Set up structured (JSON) logging for TradingAgents."""
    logger = logging.getLogger('tradingagents')

    # Avoid attaching duplicate handlers if setup_logging is called twice
    if not logger.handlers:
        handler = logging.StreamHandler()
        # Format fields must be real LogRecord attributes (asctime, levelname, ...)
        formatter = jsonlogger.JsonFormatter(
            '%(asctime)s %(levelname)s %(name)s %(message)s'
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    level = config.get('log_level', 'INFO')
    logger.setLevel(getattr(logging, level.upper(), logging.INFO))

    return logger


# Usage throughout the codebase: child loggers propagate to 'tradingagents'
logger = logging.getLogger('tradingagents.dataflows')

ticker, vendor_name, date = 'AAPL', 'yfinance', '2024-01-01'  # example values
logger.info(
    "Fetching stock data",
    extra={
        "ticker": ticker,
        "vendor": vendor_name,
        "date": date
    }
)
```

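The format string passed to `JsonFormatter` has to name real `LogRecord` attributes (`asctime`, `levelname`, `name`, `message`). If you want to avoid the `python-json-logger` dependency, a standard-library formatter can produce similar output, including any fields passed via `extra=`. This is a minimal sketch, not a drop-in replacement. The output key names (`time`, `level`, `logger`) are arbitrary choices.

```python
import json
import logging

# Attributes every LogRecord has; anything else on a record came from `extra=`.
_STANDARD_ATTRS = set(vars(logging.LogRecord("", 0, "", 0, "", None, None))) | {"message", "asctime"}


class JsonFormatter(logging.Formatter):
    """Render each record as one JSON object, including `extra=` fields."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        # Copy user-supplied extra fields (e.g. ticker, vendor) into the payload
        for key, value in vars(record).items():
            if key not in _STANDARD_ATTRS:
                payload[key] = value
        return json.dumps(payload, default=str)


handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
log = logging.getLogger("tradingagents.demo")
log.addHandler(handler)
log.setLevel(logging.INFO)
log.info("Fetching stock data", extra={"ticker": "AAPL", "vendor": "yfinance"})
```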
---

## Category 2: Performance Optimizations

### 2.1 Implement Caching Layer

**Priority:** High
**Effort:** Medium
**Impact:** Significant performance improvement

**Current State:**
Some caching exists but it's inconsistent.

**Proposed:**
```python
import hashlib
import json
import time
from pathlib import Path
from typing import Any, Optional


class CacheManager:
    """Unified file-based caching for API calls and LLM responses."""

    def __init__(self, cache_dir: str, ttl: int = 3600):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.ttl = ttl

    def get(self, key: str) -> Optional[Any]:
        """Get cached value if it exists and has not expired."""
        cache_file = self.cache_dir / f"{key}.json"
        if not cache_file.exists():
            return None

        try:
            with open(cache_file, 'r') as f:
                data = json.load(f)
        except (OSError, json.JSONDecodeError):
            # Treat unreadable or partially written files as a cache miss
            return None

        # Check if expired
        if time.time() - data['timestamp'] > self.ttl:
            cache_file.unlink(missing_ok=True)
            return None

        return data['value']

    def set(self, key: str, value: Any) -> None:
        """Set cache value (value must be JSON-serializable)."""
        cache_file = self.cache_dir / f"{key}.json"
        with open(cache_file, 'w') as f:
            json.dump({
                'timestamp': time.time(),
                'value': value
            }, f)

    def cache_key(self, *args, **kwargs) -> str:
        """Generate a deterministic cache key from arguments."""
        key_data = json.dumps({'args': args, 'kwargs': kwargs}, sort_keys=True, default=str)
        return hashlib.sha256(key_data.encode()).hexdigest()


# Usage
cache = CacheManager('./cache', ttl=3600)

def get_stock_data(ticker: str, date: str) -> dict:
    # Include the function name so different cached functions never share keys
    key = cache.cache_key('get_stock_data', ticker, date)

    # Try cache first; compare against None so empty results still count as hits
    cached = cache.get(key)
    if cached is not None:
        return cached

    # Fetch fresh data (fetch_from_api stands in for the real vendor call)
    data = fetch_from_api(ticker, date)

    # Cache result
    cache.set(key, data)
    return data
```

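The get/fetch/set pattern in the usage example would repeat in every cached function. One way to factor it out is a decorator that builds the key from the function's qualified name plus its arguments. The sketch below keeps entries in an in-memory dictionary so it runs standalone. `ttl_cache` is a hypothetical name, and a real version could delegate storage to `CacheManager` instead.

```python
import functools
import hashlib
import json
import time
from typing import Any, Callable, Dict, Tuple


def ttl_cache(ttl: float) -> Callable:
    """Cache a function's return values in memory for `ttl` seconds."""

    def decorator(func: Callable) -> Callable:
        store: Dict[str, Tuple[float, Any]] = {}

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Include the function name so different functions never share keys
            raw = json.dumps(
                {"fn": func.__qualname__, "args": args, "kwargs": kwargs},
                sort_keys=True, default=str,
            )
            key = hashlib.sha256(raw.encode()).hexdigest()

            hit = store.get(key)
            if hit is not None and time.monotonic() - hit[0] <= ttl:
                return hit[1]

            value = func(*args, **kwargs)
            store[key] = (time.monotonic(), value)
            return value

        wrapper.cache_clear = store.clear  # handy in tests
        return wrapper

    return decorator


calls = []

@ttl_cache(ttl=60)
def get_stock_data(ticker: str, date: str) -> dict:
    calls.append((ticker, date))  # stands in for a slow vendor API call
    return {"ticker": ticker, "date": date}

get_stock_data("AAPL", "2024-01-02")
get_stock_data("AAPL", "2024-01-02")  # served from the cache
print(len(calls))  # 1
```

`time.monotonic()` is used for expiry because, unlike `time.time()`, it cannot jump backwards when the system clock is adjusted.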
### 2.2 Parallelize API Calls

**Priority:** Medium
**Effort:** Medium
**Impact:** Faster execution

**Proposed:**
```python
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Optional

logger = logging.getLogger('tradingagents.dataflows')


class ParallelDataFetcher:
    """Fetch data from multiple sources in parallel (I/O-bound work)."""

    def __init__(self, max_workers: int = 5):
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def fetch_all(
        self,
        tasks: List[Callable[[], Any]],
        timeout: float = 30
    ) -> List[Optional[Any]]:
        """Execute all tasks in parallel; failed or timed-out tasks yield None.

        Note: `timeout` applies to each future in turn, so the total wait can
        exceed `timeout` when several tasks are slow. Timed-out tasks keep
        running in their worker threads.
        """
        futures = [self.executor.submit(task) for task in tasks]

        results: List[Optional[Any]] = []
        for future in futures:
            try:
                results.append(future.result(timeout=timeout))
            except Exception:
                logger.exception("Task failed")
                results.append(None)

        return results

    def close(self) -> None:
        """Release worker threads."""
        self.executor.shutdown(wait=False)


# Usage: ticker/date come from the caller; the three fetchers are placeholders
fetcher = ParallelDataFetcher()
try:
    stock, news, fundamentals = fetcher.fetch_all([
        lambda: get_stock_data(ticker, date),
        lambda: get_news_data(ticker, date),
        lambda: get_fundamentals(ticker, date),
    ])
finally:
    fetcher.close()
```

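If the calling code is already async, an alternative to managing a thread pool directly is to push each blocking call onto a thread with `asyncio.to_thread` (Python 3.9+) and gather the results under one deadline for the whole batch. This is a sketch under that assumption. `fetch_all_async`, `slow`, and `broken` are hypothetical helpers.

```python
import asyncio
import time
from typing import Any, Callable, List, Optional


async def fetch_all_async(
    tasks: List[Callable[[], Any]],
    timeout: float = 30,
) -> List[Optional[Any]]:
    """Run blocking callables concurrently; failed tasks yield None.

    `timeout` bounds the whole batch. If it is exceeded, asyncio.TimeoutError
    is raised; threads already running are not interrupted.
    """
    coros = [asyncio.to_thread(task) for task in tasks]
    results = await asyncio.wait_for(
        asyncio.gather(*coros, return_exceptions=True), timeout=timeout
    )
    # gather(return_exceptions=True) returns exceptions in place of results
    return [None if isinstance(r, BaseException) else r for r in results]


def slow(value: int) -> int:
    time.sleep(0.1)  # simulates network latency
    return value


def broken() -> int:
    raise RuntimeError("vendor unavailable")


print(asyncio.run(fetch_all_async([lambda: slow(1), lambda: slow(2), broken])))  # [1, 2, None]
```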
### 2.3 Optimize LLM Token Usage

**Priority:** High
**Effort:** Low
**Impact:** Cost reduction

**Proposed:**
```python
import tiktoken
from langchain_openai import ChatOpenAI

# cl100k_base matches GPT-4/GPT-3.5; newer models such as gpt-4o use o200k_base
# (tiktoken.encoding_for_model(name) picks the right one).


def count_tokens(text: str, encoding: str = "cl100k_base") -> int:
    """Count tokens using a tiktoken encoding."""
    return len(tiktoken.get_encoding(encoding).encode(text))


class TokenOptimizer:
    """Optimize prompts to reduce token usage."""

    @staticmethod
    def truncate_context(
        context: str,
        max_tokens: int,
        encoding: str = "cl100k_base"
    ) -> str:
        """Truncate context to fit a token limit, keeping the start and end."""
        enc = tiktoken.get_encoding(encoding)
        tokens = enc.encode(context)

        if len(tokens) <= max_tokens:
            return context
        if max_tokens <= 0:
            return ""

        # Truncate from the middle, keep beginning and end
        keep_start = max_tokens // 2
        keep_end = max_tokens - keep_start  # >= 1 here, so tokens[-keep_end:] is safe

        truncated = tokens[:keep_start] + tokens[-keep_end:]
        return enc.decode(truncated)

    @staticmethod
    def summarize_if_needed(
        text: str,
        max_tokens: int,
        llm: ChatOpenAI
    ) -> str:
        """Summarize text if it exceeds the token limit."""
        if count_tokens(text) <= max_tokens:
            return text

        # Pass a cheaper model (e.g. the quick-think LLM) as `llm` for summarization
        summary_prompt = (
            f"Summarize this concisely in under {max_tokens} tokens:\n\n{text}"
        )
        return llm.invoke(summary_prompt).content
```

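Middle truncation has an easy-to-miss edge case: in Python, `tokens[-0:]` returns the *whole* list, so a budget of zero must be handled explicitly. Separating the list logic from tiktoken makes it simple to unit test. The sketch below also reserves an optional slot for a marker token, so the model can see that text was removed. `truncate_middle` and the marker idea are assumptions, not existing project code.

```python
from typing import List, Optional, TypeVar

T = TypeVar("T")


def truncate_middle(items: List[T], max_len: int, marker: Optional[T] = None) -> List[T]:
    """Keep the first and last items so the result has at most max_len items."""
    if len(items) <= max_len:
        return list(items)
    if max_len <= 0:
        return []  # guards against items[-0:], which would return everything

    budget = max_len - 1 if marker is not None else max_len  # reserve a slot for the marker
    keep_start = budget // 2
    keep_end = budget - keep_start
    tail = items[len(items) - keep_end:]  # explicit index avoids the -0 pitfall
    middle = [marker] if marker is not None else []
    return items[:keep_start] + middle + tail


print(truncate_middle(list(range(10)), 4))      # [0, 1, 8, 9]
print(truncate_middle(list(range(10)), 5, -1))  # [0, 1, -1, 8, 9]
```

With tiktoken, `items` would be the encoded token list and `marker` could be the encoding of a short string such as `" ... "`.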
---

## Category 3: Feature Enhancements

### 3.1 Add Backtesting Framework

**Priority:** High
**Effort:** High
**Impact:** Critical for validation

**Proposed:**
```python
from dataclasses import dataclass
from typing import Any, Dict, List

import pandas as pd

# TradingAgentsGraph lives in tradingagents.graph.trading_graph;
# Portfolio is the class proposed in section 3.3.


@dataclass
class BacktestResult:
    """Results from a backtest run."""
    total_return: float
    sharpe_ratio: float
    max_drawdown: float
    win_rate: float
    trades: List[Dict[str, Any]]
    equity_curve: pd.Series


class Backtester:
    """Backtest trading strategies."""

    def __init__(
        self,
        initial_capital: float = 100000,
        commission: float = 0.001  # fraction of trade value, i.e. 0.1%
    ):
        self.initial_capital = initial_capital
        self.commission = commission

    def run(
        self,
        strategy: "TradingAgentsGraph",
        tickers: List[str],
        start_date: str,
        end_date: str
    ) -> BacktestResult:
        """Run backtest over date range."""
        # Business days only (exchange holidays still need a trading calendar)
        dates = pd.bdate_range(start_date, end_date)
        portfolio = Portfolio(self.initial_capital)
        trades = []

        for date in dates:
            day = date.strftime('%Y-%m-%d')
            for ticker in tickers:
                # Get strategy decision. Assumes propagate() returns a structured
                # dict with 'action', 'quantity' and 'price'; if it returns only a
                # BUY/SELL/HOLD signal, this needs an adapter.
                _, decision = strategy.propagate(ticker, day)

                action = decision['action']
                if action not in ('BUY', 'SELL'):
                    continue  # HOLD or unrecognized signal: no trade

                # Ideally the fill price comes from market data (e.g. the next
                # bar's open, to avoid look-ahead bias), not from the LLM.
                execute = portfolio.buy if action == 'BUY' else portfolio.sell
                try:
                    trade = execute(
                        ticker,
                        decision['quantity'],
                        decision['price'],
                        self.commission
                    )
                except ValueError:
                    continue  # insufficient cash or shares: skip rather than abort
                trades.append(trade)

        # sharpe_ratio(), max_drawdown(), win_rate() and equity_curve() are not part
        # of the section 3.3 Portfolio sketch; they require recording daily values.
        return BacktestResult(
            total_return=portfolio.total_return(),
            sharpe_ratio=portfolio.sharpe_ratio(),
            max_drawdown=portfolio.max_drawdown(),
            win_rate=portfolio.win_rate(),
            trades=trades,
            equity_curve=portfolio.equity_curve()
        )
```

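`BacktestResult` relies on a Sharpe ratio and a maximum drawdown, but the `Portfolio` sketch in section 3.3 does not compute either. Below is a dependency-free sketch that derives both from a series of daily portfolio values. It assumes a zero risk-free rate and 252 trading days per year; both are common conventions, but they are assumptions here.

```python
import math
import statistics
from typing import List, Sequence


def daily_returns(equity: Sequence[float]) -> List[float]:
    """Simple returns between consecutive portfolio values."""
    return [curr / prev - 1 for prev, curr in zip(equity, equity[1:])]


def sharpe_ratio(returns: Sequence[float], periods_per_year: int = 252) -> float:
    """Annualized Sharpe ratio with a zero risk-free rate."""
    if len(returns) < 2:
        return 0.0
    stdev = statistics.stdev(returns)  # sample standard deviation
    if stdev == 0:
        return 0.0  # no variability: ratio undefined, report 0 rather than inf
    return statistics.mean(returns) / stdev * math.sqrt(periods_per_year)


def max_drawdown(equity: Sequence[float]) -> float:
    """Largest peak-to-trough decline as a negative fraction (-0.1 means -10%)."""
    peak = float("-inf")
    worst = 0.0
    for value in equity:
        peak = max(peak, value)
        worst = min(worst, value / peak - 1)
    return worst


equity = [100.0, 110.0, 99.0, 120.0]
print(round(max_drawdown(equity), 4))  # -0.1 (the drop from 110 to 99)
print(round(sharpe_ratio(daily_returns(equity)), 2))
```

With a zero risk-free rate, negating every return exactly negates the Sharpe ratio. The property test in section 4.2 relies on this.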
### 3.2 Add Real-time Market Data Stream

**Priority:** Medium
**Effort:** High
**Impact:** Production readiness

**Proposed:**
```python
import asyncio
import json
from typing import Any, Awaitable, Callable, Dict, List, Tuple

import websockets  # third-party: pip install websockets

Callback = Callable[[Dict[str, Any]], Awaitable[None]]


class MarketDataStream:
    """Stream real-time market data."""

    def __init__(self, websocket_url: str):
        self.websocket_url = websocket_url
        # (ticker, callback) pairs registered before start()
        self.subscriptions: List[Tuple[str, Callback]] = []

    def subscribe(self, ticker: str, callback: Callback) -> None:
        """Register an async callback for a ticker's updates."""
        self.subscriptions.append((ticker, callback))

    async def _stream(self, ticker: str, callback: Callback) -> None:
        """Open a connection, subscribe to one ticker, and forward messages."""
        async with websockets.connect(self.websocket_url) as ws:
            # The subscribe message format is vendor-specific; this one is illustrative
            await ws.send(json.dumps({
                'action': 'subscribe',
                'ticker': ticker
            }))

            async for message in ws:
                data = json.loads(message)
                await callback(data)

    async def start(self) -> None:
        """Start streaming data for all registered subscriptions."""
        tasks = [
            self._stream(ticker, callback)
            for ticker, callback in self.subscriptions
        ]
        await asyncio.gather(*tasks)
```

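The version above opens one connection per ticker. Many feeds accept several subscriptions on a single connection, in which case incoming messages must be routed to the right callbacks. The sketch below covers only that routing step. It consumes any async iterator of JSON strings, so it can be tested without a network connection. The `'ticker'` field in each message is an assumed schema.

```python
import asyncio
import json
from collections import defaultdict
from typing import Any, AsyncIterator, Awaitable, Callable, DefaultDict, Dict, List

Callback = Callable[[Dict[str, Any]], Awaitable[None]]


class Dispatcher:
    """Route decoded messages to the callbacks registered for their ticker."""

    def __init__(self) -> None:
        self.callbacks: DefaultDict[str, List[Callback]] = defaultdict(list)

    def subscribe(self, ticker: str, callback: Callback) -> None:
        self.callbacks[ticker].append(callback)

    async def run(self, messages: AsyncIterator[str]) -> None:
        async for raw in messages:
            data = json.loads(raw)
            # Messages for tickers nobody subscribed to are ignored
            for callback in self.callbacks.get(data.get("ticker"), []):
                await callback(data)


async def fake_feed() -> AsyncIterator[str]:
    """Stands in for `async for message in ws` on a live connection."""
    for ticker, price in [("AAPL", 100.0), ("MSFT", 200.0), ("AAPL", 101.0)]:
        yield json.dumps({"ticker": ticker, "price": price})


received: List[float] = []

async def on_aapl(data: Dict[str, Any]) -> None:
    received.append(data["price"])

dispatcher = Dispatcher()
dispatcher.subscribe("AAPL", on_aapl)
asyncio.run(dispatcher.run(fake_feed()))
print(received)  # [100.0, 101.0]
```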
### 3.3 Add Portfolio Management

**Priority:** High
**Effort:** Medium
**Impact:** Essential for production

**Proposed:**
```python
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List


@dataclass
class Position:
    """Represents a position in a security."""
    ticker: str
    quantity: float
    avg_cost: float
    current_price: float

    @property
    def market_value(self) -> float:
        return self.quantity * self.current_price

    @property
    def unrealized_pnl(self) -> float:
        return (self.current_price - self.avg_cost) * self.quantity


class Portfolio:
    """Manage a long-only trading portfolio."""

    def __init__(self, initial_capital: float):
        self.cash = initial_capital
        self.initial_capital = initial_capital
        self.positions: Dict[str, Position] = {}
        self.trade_history: List[Dict[str, Any]] = []

    def buy(
        self,
        ticker: str,
        quantity: float,
        price: float,
        commission: float = 0.0  # rate as a fraction of trade value, e.g. 0.001
    ) -> Dict[str, Any]:
        """Execute buy order."""
        if quantity <= 0 or price <= 0:
            raise ValueError("Quantity and price must be positive")

        cost = quantity * price * (1 + commission)
        if cost > self.cash:
            raise ValueError(f"Insufficient funds: need ${cost:,.2f}, have ${self.cash:,.2f}")

        self.cash -= cost

        if ticker in self.positions:
            # Update existing position using a quantity-weighted average cost
            pos = self.positions[ticker]
            total_qty = pos.quantity + quantity
            pos.avg_cost = (pos.avg_cost * pos.quantity + price * quantity) / total_qty
            pos.quantity = total_qty
            pos.current_price = price
        else:
            # Create new position
            self.positions[ticker] = Position(
                ticker=ticker,
                quantity=quantity,
                avg_cost=price,
                current_price=price
            )

        trade = {
            'action': 'BUY',
            'ticker': ticker,
            'quantity': quantity,
            'price': price,
            'commission': commission,
            'timestamp': datetime.now()
        }
        self.trade_history.append(trade)
        return trade

    def sell(
        self,
        ticker: str,
        quantity: float,
        price: float,
        commission: float = 0.0
    ) -> Dict[str, Any]:
        """Execute sell order."""
        if ticker not in self.positions:
            raise ValueError(f"No position in {ticker}")
        if quantity <= 0:
            raise ValueError("Quantity must be positive")

        pos = self.positions[ticker]
        if quantity > pos.quantity:
            raise ValueError(
                f"Insufficient shares: have {pos.quantity}, trying to sell {quantity}"
            )

        proceeds = quantity * price * (1 - commission)
        self.cash += proceeds
        # Realized P&L net of the sell-side commission (avg_cost excludes buy commission)
        realized_pnl = proceeds - pos.avg_cost * quantity

        pos.quantity -= quantity
        pos.current_price = price
        if abs(pos.quantity) < 1e-9:  # tolerate floating-point residue
            del self.positions[ticker]

        trade = {
            'action': 'SELL',
            'ticker': ticker,
            'quantity': quantity,
            'price': price,
            'commission': commission,
            'realized_pnl': realized_pnl,
            'timestamp': datetime.now()
        }
        self.trade_history.append(trade)
        return trade

    def update_prices(self, prices: Dict[str, float]) -> None:
        """Update current prices for all positions."""
        for ticker, price in prices.items():
            if ticker in self.positions:
                self.positions[ticker].current_price = price

    def total_value(self) -> float:
        """Calculate total portfolio value."""
        return self.cash + sum(
            pos.market_value for pos in self.positions.values()
        )

    def total_return(self) -> float:
        """Calculate total return as a fraction (0.05 = 5%)."""
        return (self.total_value() - self.initial_capital) / self.initial_capital
```

### 3.4 Add Model Performance Tracking

**Priority:** Medium
**Effort:** Medium
**Impact:** Better decision making

**Proposed:**
```python
import json
import math
import sqlite3
from datetime import datetime
from typing import Any, Dict, Optional


class PerformanceTracker:
    """Track LLM agent performance in SQLite."""

    def __init__(self, db_path: str):
        self.db = sqlite3.connect(db_path)
        self.db.row_factory = sqlite3.Row  # lets rows be converted with dict()
        self._create_tables()

    def _create_tables(self) -> None:
        self.db.execute(
            """
            CREATE TABLE IF NOT EXISTS decisions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                agent_name TEXT NOT NULL,
                ticker TEXT NOT NULL,
                date TEXT NOT NULL,
                decision TEXT NOT NULL,
                reasoning TEXT,
                timestamp TEXT NOT NULL,
                actual_return REAL,
                market_return REAL,
                alpha REAL
            )
            """
        )
        self.db.commit()

    def log_decision(
        self,
        agent_name: str,
        ticker: str,
        date: str,
        decision: Dict[str, Any],
        reasoning: str
    ) -> int:
        """Log agent decision for later analysis; returns the new row id."""
        cursor = self.db.execute(
            """
            INSERT INTO decisions
            (agent_name, ticker, date, decision, reasoning, timestamp)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (agent_name, ticker, date, json.dumps(decision), reasoning,
             datetime.now().isoformat())
        )
        self.db.commit()
        return cursor.lastrowid

    def log_outcome(
        self,
        decision_id: int,
        actual_return: float,
        market_return: float
    ) -> None:
        """Log actual outcome of a decision (use the id from log_decision)."""
        self.db.execute(
            """
            UPDATE decisions
            SET actual_return = ?, market_return = ?, alpha = ?
            WHERE id = ?
            """,
            (actual_return, market_return, actual_return - market_return, decision_id)
        )
        self.db.commit()

    def get_agent_stats(self, agent_name: str) -> Dict[str, Optional[float]]:
        """Get performance statistics for an agent."""
        # SQLite has no built-in STDDEV, so compute the variance in SQL and
        # take the square root in Python (population standard deviation).
        row = self.db.execute(
            """
            SELECT
                COUNT(*) AS total_decisions,
                AVG(actual_return) AS avg_return,
                AVG(alpha) AS avg_alpha,
                AVG(actual_return * actual_return)
                    - AVG(actual_return) * AVG(actual_return) AS variance
            FROM decisions
            WHERE agent_name = ? AND actual_return IS NOT NULL
            """,
            (agent_name,)
        ).fetchone()

        stats = dict(row)
        variance = stats.pop('variance')
        stats['volatility'] = math.sqrt(max(variance, 0.0)) if variance is not None else None
        return stats
```

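SQLite does not ship a `STDDEV` aggregate, which is why the query above computes the variance by hand. As an alternative, Python's `sqlite3` module can register a custom aggregate with `Connection.create_aggregate`. On that connection, a query using `STDDEV(actual_return)` would then work as written. The sketch below computes the sample standard deviation with Welford's online algorithm. The class name `StdDev` is illustrative.

```python
import math
import sqlite3
from typing import Optional


class StdDev:
    """Sample standard deviation aggregate (Welford's online algorithm)."""

    def __init__(self) -> None:
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def step(self, value: Optional[float]) -> None:
        if value is None:  # ignore NULLs, like built-in aggregates
            return
        self.n += 1
        delta = value - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (value - self.mean)

    def finalize(self) -> Optional[float]:
        # Return NULL when fewer than two values are available
        return math.sqrt(self.m2 / (self.n - 1)) if self.n >= 2 else None


conn = sqlite3.connect(":memory:")
conn.create_aggregate("STDDEV", 1, StdDev)
conn.execute("CREATE TABLE decisions (actual_return REAL)")
conn.executemany("INSERT INTO decisions VALUES (?)", [(0.01,), (0.03,), (-0.02,), (None,)])
print(conn.execute("SELECT STDDEV(actual_return) FROM decisions").fetchone()[0])  # ~0.02517
```

Welford's update avoids the cancellation error that the `AVG(x*x) - AVG(x)*AVG(x)` formula can suffer when values are large relative to their spread.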
---

## Category 4: Testing & Quality Assurance

### 4.1 Comprehensive Test Suite

**Priority:** Critical
**Effort:** High
**Impact:** Code reliability

**Proposed Structure:**
```
tests/
├── __init__.py
├── conftest.py          # Pytest fixtures
├── unit/
│   ├── test_config.py
│   ├── test_agents.py
│   ├── test_dataflows.py
│   └── test_portfolio.py
├── integration/
│   ├── test_trading_graph.py
│   ├── test_api_vendors.py
│   └── test_end_to_end.py
├── security/
│   ├── test_input_validation.py
│   ├── test_path_traversal.py
│   └── test_api_security.py
└── performance/
    ├── test_caching.py
    └── test_parallel_execution.py
```

**Example Test:**
```python
import os
from unittest.mock import patch

import pytest

from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph

# Constructing the graph typically requires an API key, and propagate() calls
# the LLM, so these tests only run when a key is configured.
requires_openai = pytest.mark.skipif(
    not os.getenv('OPENAI_API_KEY'), reason='requires OPENAI_API_KEY'
)


@pytest.fixture
def mock_config():
    # Start from the full default config and override only what the tests need
    config = DEFAULT_CONFIG.copy()
    config.update({
        'deep_think_llm': 'gpt-4o-mini',
        'quick_think_llm': 'gpt-4o-mini',
        'max_debate_rounds': 1,
    })
    return config


@pytest.fixture
def trading_graph(mock_config):
    return TradingAgentsGraph(config=mock_config, debug=False)


@requires_openai
def test_propagate_valid_ticker(trading_graph):
    """Test propagation with valid ticker."""
    # Patch the name where it is looked up; this target path is illustrative
    with patch('tradingagents.dataflows.y_finance.get_stock_data') as mock_data:
        mock_data.return_value = {'price': 100.0}

        _, decision = trading_graph.propagate('AAPL', '2024-01-01')

    # Assumes propagate() returns a structured decision dict
    assert decision is not None
    assert 'action' in decision
    assert decision['action'] in ['BUY', 'SELL', 'HOLD']


@requires_openai
def test_propagate_invalid_ticker(trading_graph):
    """Test propagation with invalid ticker (requires input validation)."""
    with pytest.raises(ValueError, match="Invalid ticker"):
        trading_graph.propagate('../etc/passwd', '2024-01-01')


def test_path_traversal_prevention():
    """Test that path traversal is prevented."""
    from cli.main import sanitize_path_component  # helper assumed by this test

    dangerous_inputs = [
        '../../../etc/passwd',
        '..\\..\\..\\windows\\system32',
        'ticker/../../../secrets'
    ]

    for dangerous in dangerous_inputs:
        safe = sanitize_path_component(dangerous)
        assert '..' not in safe
        assert '/' not in safe
        assert '\\' not in safe
```

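`test_path_traversal_prevention` imports a `sanitize_path_component` helper from `cli.main`. If no such helper exists yet, one possible whitelist-based implementation that passes that test is sketched below. The function name comes from the test; the allowed character set is an assumption.

```python
import re

_UNSAFE = re.compile(r"[^A-Za-z0-9._-]")  # anything outside a conservative whitelist
_DOT_RUNS = re.compile(r"\.{2,}")          # '..' and longer runs of dots


def sanitize_path_component(value: str) -> str:
    """Turn untrusted input (e.g. a ticker) into a single safe path component."""
    cleaned = _UNSAFE.sub("_", value)      # removes '/', '\\', spaces, etc.
    cleaned = _DOT_RUNS.sub("_", cleaned)  # no parent-directory references survive
    cleaned = cleaned.lstrip(".")          # no hidden files or a lone '.'
    if not cleaned:
        raise ValueError(f"Cannot build a safe path component from {value!r}")
    return cleaned


print(sanitize_path_component("../../../etc/passwd"))  # ______etc_passwd
print(sanitize_path_component("BRK.B"))                # BRK.B
```

A whitelist is used rather than stripping known-bad sequences, because blacklists tend to miss encodings and separators that differ across operating systems.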
### 4.2 Property-Based Testing

**Priority:** Medium
**Effort:** Medium
**Impact:** Find edge cases

**Proposed:**
```python
import numpy as np
from hypothesis import assume, given, strategies as st


# ASCII only: Unicode categories Lu/Ll/Nd would also generate letters such as 'É'
@given(ticker=st.from_regex(r"[A-Za-z0-9]{1,10}", fullmatch=True))
def test_ticker_validation_property(ticker):
    """Property: all short alphanumeric tickers should be accepted."""
    from tradingagents.utils import validate_ticker  # assumed helper

    # Should not raise for alphanumeric tickers
    validate_ticker(ticker)


@given(
    returns=st.lists(st.floats(min_value=-0.5, max_value=0.5), min_size=10, max_size=100)
)
def test_sharpe_ratio_properties(returns):
    """Property: Sharpe ratio should be finite and flip sign with the returns."""
    from tradingagents.metrics import calculate_sharpe_ratio  # assumed helper

    # Constant returns have zero volatility, so the ratio is undefined
    assume(np.std(returns) > 1e-9)

    sharpe = calculate_sharpe_ratio(returns)

    # Sharpe ratio should be finite
    assert np.isfinite(sharpe)

    # Negating every return should negate the Sharpe ratio (zero risk-free rate)
    negated_sharpe = calculate_sharpe_ratio([-r for r in returns])
    assert np.isclose(sharpe, -negated_sharpe, rtol=0.01)
```

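The tests in 4.1 and 4.2 both assume a `validate_ticker` helper. It should accept plausible symbols and reject anything else with `ValueError("Invalid ticker ...")`. The sketch below is consistent with those tests. The accepted pattern is an assumption: 1 to 10 ASCII letters or digits, plus an optional `.` or `-` share-class suffix such as `BRK.B`. Exchange-specific symbols may need a wider rule.

```python
import re

_TICKER = re.compile(r"[A-Z0-9]{1,10}(?:[.-][A-Z0-9]{1,5})?")


def validate_ticker(ticker: str) -> str:
    """Return the normalized (upper-case) ticker or raise ValueError."""
    # Check ASCII first: str.upper() maps some non-ASCII letters to ASCII ('ß' -> 'SS')
    if not isinstance(ticker, str) or not ticker.isascii():
        raise ValueError(f"Invalid ticker: {ticker!r}")
    normalized = ticker.strip().upper()
    if not _TICKER.fullmatch(normalized):
        raise ValueError(f"Invalid ticker: {ticker!r}")
    return normalized


print(validate_ticker("nvda"))   # NVDA
print(validate_ticker("BRK.B"))  # BRK.B
```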
---

## Category 5: Documentation & Developer Experience

### 5.1 Interactive Documentation

**Priority:** Medium
**Effort:** Medium
**Impact:** Better onboarding

**Proposed:**
- Add Jupyter notebooks with examples
- Create video tutorials
- Add interactive API documentation with Swagger/OpenAPI

**Example Notebook:**
```python
# notebooks/01_getting_started.ipynb
"""
# Getting Started with TradingAgents

This notebook walks you through basic usage of TradingAgents.

## Setup
"""
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph

# Configure your agents
config = DEFAULT_CONFIG.copy()
config['deep_think_llm'] = 'gpt-4o-mini'

"""
## Basic Usage

Let's analyze NVIDIA stock on a specific date:
"""
ta = TradingAgentsGraph(config=config)
state, decision = ta.propagate('NVDA', '2024-05-10')

"""
## Understanding the Decision

The decision contains:
- Action: BUY, SELL, or HOLD
- Confidence: 0-1 scale
- Reasoning: Why the decision was made
"""
# Assumes a structured decision dict with these fields; if propagate()
# returns only the signal string, use print(decision) instead.
print(f"Action: {decision['action']}")
print(f"Reasoning: {decision['reasoning']}")
```

### 5.2 Contributing Guide

**Priority:** Medium
**Effort:** Low
**Impact:** Community growth

**Proposed CONTRIBUTING.md:**

```markdown
# Contributing to TradingAgents

## Getting Started

1. Fork the repository
2. Clone your fork
3. Create a virtual environment
4. Install dependencies: `pip install -r requirements-dev.txt`
5. Run tests: `pytest`

## Development Workflow

1. Create a feature branch
2. Make your changes
3. Add tests
4. Run security checks: `bandit -r tradingagents/`
5. Format code: `black tradingagents/`
6. Submit PR

## Code Standards

- Follow PEP 8
- Add type hints
- Write docstrings
- Add tests for new features
- Keep security in mind
```

---

## Category 6: Monitoring & Observability

### 6.1 Metrics Collection

**Priority:** Medium
**Effort:** Medium
**Impact:** Production readiness

**Proposed:**
```python
import time

from prometheus_client import Counter, Gauge, Histogram

# Define metrics
api_calls = Counter(
    'trading_agents_api_calls_total',
    'Total API calls',
    ['vendor', 'endpoint', 'outcome']  # outcome: 'success' or 'error'
)

api_latency = Histogram(
    'trading_agents_api_latency_seconds',
    'API call latency',
    ['vendor', 'endpoint']
)

llm_tokens = Counter(
    'trading_agents_llm_tokens_total',
    'Total LLM tokens used',
    ['model', 'operation']
)

portfolio_value = Gauge(
    'trading_agents_portfolio_value_usd',
    'Current portfolio value in USD'
)


class MonitoredAPIClient:
    """API client with metrics."""

    def __init__(self, vendor: str):
        self.vendor = vendor

    def make_request(self, endpoint: str, **kwargs):
        """Make API request, recording call count, outcome, and latency."""
        start = time.perf_counter()  # monotonic clock, suited to measuring durations
        outcome = 'error'
        try:
            result = self._execute_request(endpoint, **kwargs)
            outcome = 'success'
            return result
        finally:
            api_calls.labels(vendor=self.vendor, endpoint=endpoint, outcome=outcome).inc()
            api_latency.labels(vendor=self.vendor, endpoint=endpoint).observe(
                time.perf_counter() - start
            )

    def _execute_request(self, endpoint: str, **kwargs):
        """Vendor-specific request logic; implement per data vendor."""
        raise NotImplementedError


# To expose /metrics for Prometheus to scrape, call once at startup:
#   from prometheus_client import start_http_server; start_http_server(8000)
```

### 6.2 Health Checks

**Priority:** Medium
**Effort:** Low
**Impact:** Production reliability

**Proposed:**
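```python
import os
from datetime import datetime
from typing import Any, Dict

from fastapi import FastAPI, Response, status

app = FastAPI()


@app.get("/health")
async def health_check() -> Dict[str, str]:
    """Basic liveness check."""
    return {"status": "healthy"}


@app.get("/health/detailed")
async def detailed_health_check(response: Response) -> Dict[str, Any]:
    """Detailed health check; responds with HTTP 503 when any check fails."""
    checks = {
        "api_keys": check_api_keys(),
        "data_vendors": check_data_vendors(),
        "llm_providers": check_llm_providers(),
        "cache": check_cache_availability(),
    }

    all_healthy = all(check['status'] == 'healthy' for check in checks.values())
    if not all_healthy:
        # Load balancers and orchestrators act on the status code, not the body
        response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE

    return {
        "status": "healthy" if all_healthy else "degraded",
        "checks": checks,
        "timestamp": datetime.now().isoformat()
    }


def check_api_keys() -> Dict[str, Any]:
    """Check if required API keys are set."""
    required_keys = ['OPENAI_API_KEY', 'ALPHA_VANTAGE_API_KEY']
    missing = [key for key in required_keys if not os.getenv(key)]

    return {
        "status": "healthy" if not missing else "unhealthy",
        "missing_keys": missing
    }


def check_cache_availability(cache_dir: str = "./cache") -> Dict[str, Any]:
    """Check that the cache directory exists and is writable."""
    ok = os.path.isdir(cache_dir) and os.access(cache_dir, os.W_OK)
    return {"status": "healthy" if ok else "unhealthy", "cache_dir": cache_dir}


# check_data_vendors() and check_llm_providers() still need implementing with
# the same return shape, e.g. a lightweight request to each configured service.
```
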
---

## Category 7: Advanced Features

### 7.1 Multi-Asset Support

**Priority:** Medium
**Effort:** High
**Impact:** Broader applicability

**Proposed:**
- Support for options, futures, crypto
- Cross-asset correlation analysis
- Asset allocation strategies

### 7.2 Custom Agent Development Kit

**Priority:** Low
**Effort:** High
**Impact:** Extensibility

**Proposed:**
```python
from typing import Any, Dict

# Proposed SDK module (hypothetical)
from tradingagents.sdk import BaseAgent, AgentCapability


class MyCustomAnalyst(BaseAgent):
    """Custom analyst agent."""

    capabilities = [
        AgentCapability.TECHNICAL_ANALYSIS,
        AgentCapability.SENTIMENT_ANALYSIS
    ]

    def analyze(self, ticker: str, date: str) -> Dict[str, Any]:
        """Implement custom analysis logic."""
        # Your logic here
        return {
            'signal': 'BUY',
            'confidence': 0.85,
            'reasoning': 'Custom analysis reasoning'
        }

    def validate_input(self, ticker: str, date: str) -> bool:
        """Validate inputs."""
        return self.is_valid_ticker(ticker) and self.is_valid_date(date)
```

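Because `tradingagents.sdk` is a proposal, the example above cannot run yet. Here is one minimal shape `BaseAgent` and `AgentCapability` could take, using an abstract base class and an enum. `is_valid_ticker` and `is_valid_date` mirror the names in the example above. The validation rules, the `run` wrapper, and `AlwaysHold` are assumptions for illustration.

```python
import re
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List


class AgentCapability(Enum):
    TECHNICAL_ANALYSIS = "technical_analysis"
    SENTIMENT_ANALYSIS = "sentiment_analysis"


class BaseAgent(ABC):
    """Base class that custom agents subclass."""

    capabilities: List[AgentCapability] = []

    @abstractmethod
    def analyze(self, ticker: str, date: str) -> Dict[str, Any]:
        """Return at least 'signal', 'confidence' and 'reasoning'."""

    def validate_input(self, ticker: str, date: str) -> bool:
        return self.is_valid_ticker(ticker) and self.is_valid_date(date)

    @staticmethod
    def is_valid_ticker(ticker: str) -> bool:
        return bool(re.fullmatch(r"[A-Z0-9.\-]{1,10}", ticker))

    @staticmethod
    def is_valid_date(date: str) -> bool:
        try:
            datetime.strptime(date, "%Y-%m-%d")
            return True
        except ValueError:
            return False

    def run(self, ticker: str, date: str) -> Dict[str, Any]:
        """Validate inputs, analyze, and check the result's shape."""
        if not self.validate_input(ticker, date):
            raise ValueError(f"Invalid input: {ticker!r}, {date!r}")
        result = self.analyze(ticker, date)
        if result.get("signal") not in {"BUY", "SELL", "HOLD"}:
            raise ValueError(f"Unexpected signal: {result.get('signal')!r}")
        if not 0.0 <= result.get("confidence", -1) <= 1.0:
            raise ValueError("confidence must be between 0 and 1")
        return result


class AlwaysHold(BaseAgent):
    capabilities = [AgentCapability.TECHNICAL_ANALYSIS]

    def analyze(self, ticker: str, date: str) -> Dict[str, Any]:
        return {"signal": "HOLD", "confidence": 0.5, "reasoning": "placeholder"}


print(AlwaysHold().run("NVDA", "2024-05-10")["signal"])  # HOLD
```

Using `abstractmethod` means a subclass that forgets to implement `analyze` fails when it is instantiated, rather than later in the middle of a run.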
### 7.3 Explainable AI Features

**Priority:** Medium
**Effort:** Medium
**Impact:** Trust and transparency

**Proposed:**
```python
from typing import Any, Dict, List, Tuple


class ExplainableDecision:
    """Make LLM decisions more explainable."""

    def explain_decision(self, decision: Dict[str, Any]) -> Dict[str, Any]:
        """Generate explanation for a decision."""
        # The four helpers below are still to be designed, e.g. by parsing each
        # analyst's report or asking the LLM to justify the decision.
        return {
            'decision': decision,
            'contributing_factors': self._extract_factors(decision),
            'confidence_breakdown': self._break_down_confidence(decision),
            'alternative_scenarios': self._generate_alternatives(decision),
            'risk_assessment': self._assess_risks(decision)
        }

    def visualize_reasoning(self, steps: List[Tuple[str, str]]):
        """Create a visual representation of the reasoning process.

        `steps` lists (source, target) pairs describing how information flowed
        between analysis steps, e.g. ("Market Analyst", "Trader").
        """
        import matplotlib.pyplot as plt
        import networkx as nx

        G = nx.DiGraph()
        G.add_edges_from(steps)  # one node per analysis step, one edge per hand-off

        fig, ax = plt.subplots(figsize=(8, 6))
        nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), ax=ax,
                         node_color='lightblue', arrows=True)
        ax.set_axis_off()
        return fig
```

---

## Priority Matrix

| Enhancement | Priority | Effort | Impact | Quick Win |
|------------|----------|--------|--------|-----------|
| Type Hints | High | Medium | High | Yes |
| Security Fixes | Critical | Low | Critical | Yes |
| Caching | High | Medium | High | Yes |
| Test Suite | Critical | High | Critical | No |
| Logging | High | Medium | High | Yes |
| Backtesting | High | High | Critical | No |
| Portfolio Mgmt | High | Medium | High | No |
| Documentation | Medium | Medium | Medium | Yes |
| Monitoring | Medium | Medium | Medium | No |

---

## Implementation Roadmap

### Phase 1: Foundation (Weeks 1-2)
- Fix critical security issues
- Add comprehensive logging
- Implement type hints for core modules
- Add basic test coverage (>50%)

### Phase 2: Performance (Weeks 3-4)
- Implement caching layer
- Optimize LLM token usage
- Add parallel execution for data fetching
- Performance benchmarking

### Phase 3: Features (Weeks 5-8)
- Portfolio management system
- Backtesting framework
- Real-time data streaming
- Performance tracking

### Phase 4: Production Ready (Weeks 9-12)
- Comprehensive test coverage (>80%)
- Monitoring and metrics
- Health checks
- Documentation improvements

---

## Conclusion

These improvements would strengthen the TradingAgents framework in five areas:

- **Security**: Critical fixes close known vulnerabilities
- **Performance**: Caching and parallelization improve speed
- **Reliability**: Tests and monitoring help maintain stability
- **Usability**: Better docs and error handling
- **Extensibility**: Clear architecture for custom agents

The suggested enhancements align with industry best practices and would move TradingAgents closer to production readiness for serious financial analysis.