add rest and websocket api

Marvin Gabler 2025-10-21 18:58:04 +02:00
parent 3de318602f
commit 4f26352220
24 changed files with 2973 additions and 3 deletions


@@ -132,7 +132,7 @@ An interface will appear showing results as they load, letting you track the age
### Implementation Details
-We built Litadel with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o4-mini` and `gpt-4.1-mini` to save on costs as our framework makes **lots of** API calls.
+We built Litadel with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o1-mini` and `gpt-4o-mini` to save on costs as our framework makes **lots of** API calls.
### Python Usage
@@ -157,8 +157,8 @@ from tradingagents.default_config import DEFAULT_CONFIG
# Create a custom config
config = DEFAULT_CONFIG.copy()
-config["deep_think_llm"] = "gpt-4.1-nano"  # Use a different model
-config["quick_think_llm"] = "gpt-4.1-nano"  # Use a different model
+config["deep_think_llm"] = "o1-mini"  # Use a different model
+config["quick_think_llm"] = "gpt-4o-mini"  # Use a different model
config["max_debate_rounds"] = 1 # Increase debate rounds
# Configure data vendors (default uses yfinance and Alpha Vantage)


@@ -0,0 +1,286 @@
# FastAPI Trading Agents API - Implementation Summary
## ✅ Implementation Complete
The FastAPI Trading Agents API has been successfully implemented with all planned features.
## 📁 Project Structure
```
TradingAgents/
├── api/
│   ├── __init__.py            ✅ Created
│   ├── main.py                ✅ FastAPI application
│   ├── database.py            ✅ SQLAlchemy models
│   ├── auth.py                ✅ API key authentication
│   ├── state_manager.py       ✅ Analysis execution manager
│   ├── cli_admin.py           ✅ Admin CLI tool
│   ├── example_client.py      ✅ Example Python client
│   ├── README.md              ✅ Full documentation
│   ├── models/
│   │   ├── __init__.py        ✅ Exports
│   │   ├── requests.py        ✅ Pydantic request models
│   │   └── responses.py       ✅ Pydantic response models
│   ├── endpoints/
│   │   ├── __init__.py        ✅ Created
│   │   ├── analyses.py        ✅ Analysis CRUD endpoints
│   │   ├── tickers.py         ✅ Ticker history endpoints
│   │   └── data.py            ✅ Cached data endpoints
│   └── websockets/
│       ├── __init__.py        ✅ Created
│       └── status.py          ✅ Real-time status updates
├── requirements-api.txt       ✅ API dependencies
├── run_api.sh                 ✅ Startup script
├── API_QUICKSTART.md          ✅ Quick start guide
└── API_IMPLEMENTATION_SUMMARY.md  ✅ This file
```
## 🎯 Features Implemented
### Core API Features
- ✅ REST API with FastAPI
- ✅ WebSocket support for real-time updates
- ✅ SQLite database with SQLAlchemy ORM
- ✅ API key authentication with bcrypt hashing
- ✅ Parallel analysis execution (ThreadPoolExecutor)
- ✅ CORS middleware for frontend integration
- ✅ Auto-generated OpenAPI documentation
### Database Schema
- ✅ `analyses` - Analysis metadata and status
- ✅ `analysis_logs` - Tool calls and reasoning logs
- ✅ `analysis_reports` - Generated reports by type
- ✅ `api_keys` - Hashed authentication keys
### REST Endpoints
#### Analyses (`/api/v1/analyses`)
- ✅ `POST /` - Create and start new analysis
- ✅ `GET /` - List all analyses with filtering
- ✅ `GET /{id}` - Get full analysis details
- ✅ `GET /{id}/status` - Get current status
- ✅ `GET /{id}/reports` - Get all reports
- ✅ `GET /{id}/reports/{type}` - Get specific report
- ✅ `GET /{id}/logs` - Get execution logs
- ✅ `DELETE /{id}` - Cancel/delete analysis
#### Tickers (`/api/v1/tickers`)
- ✅ `GET /` - List all tickers with counts
- ✅ `GET /{ticker}/analyses` - Get analyses for ticker
- ✅ `GET /{ticker}/latest` - Get latest analysis
#### Data (`/api/v1/data`)
- ✅ `GET /cache` - List cached tickers
- ✅ `GET /cache/{ticker}` - Get cached market data
#### WebSocket (`/api/v1/ws`)
- ✅ `WS /analyses/{id}` - Real-time status streaming
### State Management
- ✅ `AnalysisExecutor` class for managing parallel execution
- ✅ ThreadPoolExecutor with configurable max workers
- ✅ Status callbacks for WebSocket broadcasting
- ✅ Real-time progress tracking
- ✅ Graceful shutdown and cleanup
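For intuition, the executor pattern described above can be sketched as a toy. This is illustrative only, not the actual `AnalysisExecutor` from `api/state_manager.py`: the class name `MiniExecutor`, the `work` callable, and the callback signature are all assumptions.

```python
import time
from concurrent.futures import ThreadPoolExecutor

class MiniExecutor:
    """Toy sketch: run jobs in a thread pool and report status via callback."""

    def __init__(self, max_workers=4):
        self.pool = ThreadPoolExecutor(max_workers=max_workers)
        self.status = {}  # analysis_id -> status string

    def submit(self, analysis_id, work, on_update=None):
        def task():
            self.status[analysis_id] = "running"
            if on_update:
                on_update(analysis_id, "running")
            try:
                work()  # stand-in for running the trading graph
                self.status[analysis_id] = "completed"
            except Exception:
                self.status[analysis_id] = "failed"
            if on_update:
                on_update(analysis_id, self.status[analysis_id])
        return self.pool.submit(task)

    def shutdown(self):
        self.pool.shutdown(wait=True)

executor = MiniExecutor(max_workers=2)
updates = []
f = executor.submit("abc-123", lambda: time.sleep(0.01),
                    lambda _id, s: updates.append(s))
f.result()
executor.shutdown()
print(updates)  # ['running', 'completed']
```

In the real API the callback would broadcast to WebSocket subscribers instead of appending to a list.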
### Admin Tools
- ✅ `cli_admin.py` - Command-line admin interface
- ✅ `create-key` - Generate new API key
- ✅ `list-keys` - Show all keys
- ✅ `revoke-key` - Deactivate key
- ✅ `activate-key` - Reactivate key
- ✅ `init-database` - Initialize database
### Documentation
- ✅ `API_QUICKSTART.md` - Quick start guide
- ✅ `api/README.md` - Full API documentation
- ✅ `api/example_client.py` - Working example client
- ✅ Auto-generated Swagger UI at `/docs`
- ✅ Auto-generated ReDoc at `/redoc`
## 🔧 Configuration
### Environment Variables
- `API_DATABASE_URL` - Database connection (default: SQLite)
- `MAX_CONCURRENT_ANALYSES` - Concurrent analysis limit (default: 4)
- Standard TradingAgents env vars (OPENAI_API_KEY, etc.)
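The configuration pattern is plain environment-variable lookup with defaults. The `API_DATABASE_URL` line below matches `api/database.py`; the `MAX_CONCURRENT_ANALYSES` lookup is an assumption about how `state_manager.py` reads its worker count.

```python
import os

# Same default as api/database.py
DATABASE_URL = os.getenv("API_DATABASE_URL", "sqlite:///./api_database.db")
# Assumed lookup for the worker-pool size (default documented above)
MAX_WORKERS = int(os.getenv("MAX_CONCURRENT_ANALYSES", "4"))

print(DATABASE_URL, MAX_WORKERS)
```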
### Parallel Execution
- Default: 4 concurrent analyses
- Configurable via environment variable
- Thread-safe database operations
- Independent graph instances per analysis
## 📊 Data Flow
1. **Client** sends POST request to create analysis
2. **API** creates database record, returns analysis_id
3. **Executor** starts analysis in background thread
4. **Graph** streams chunks during execution
5. **State Manager** captures logs, reports, tool calls
6. **Database** stores all data in real-time
7. **WebSocket** broadcasts status updates
8. **Client** polls or streams for results
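Step 8 (polling) can be sketched as a small loop. Here `get_status` stands in for an HTTP GET to `/api/v1/analyses/{id}/status`; the helper name and the injected-callable design are illustrative, not part of the implementation.

```python
import time

def poll_until_done(get_status, interval=0.0, max_polls=100):
    """Call get_status() until the analysis reaches a terminal state."""
    for _ in range(max_polls):
        data = get_status()
        # Terminal states as described in this document's schema
        if data["status"] in ("completed", "failed", "cancelled"):
            return data
        time.sleep(interval)
    raise TimeoutError("analysis did not finish in time")

# Faked server responses for illustration:
responses = iter([{"status": "pending"}, {"status": "running"}, {"status": "completed"}])
result = poll_until_done(lambda: next(responses))
print(result["status"])  # completed
```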
## 🔐 Security
- ✅ API key authentication (bcrypt hashed)
- ✅ Secure password hashing with passlib
- ✅ CORS middleware (configurable origins)
- ✅ Input validation with Pydantic
- ✅ SQL injection prevention (SQLAlchemy ORM)
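The key-handling flow (generate once, store only a hash, compare on every request) looks roughly like the stdlib sketch below. Note the real implementation uses salted bcrypt via passlib; the plain SHA-256 here is a simplified stand-in for illustration only.

```python
import hashlib
import secrets

def generate_api_key() -> str:
    # Same call the implementation uses: a 32-byte URL-safe token
    return secrets.token_urlsafe(32)

def hash_key(plain: str) -> str:
    # Simplified: production code uses salted bcrypt, not bare SHA-256
    return hashlib.sha256(plain.encode()).hexdigest()

def verify_key(plain: str, stored_hash: str) -> bool:
    # compare_digest avoids leaking information through timing
    return secrets.compare_digest(hash_key(plain), stored_hash)

key = generate_api_key()
stored = hash_key(key)       # only this hash goes in the database
assert verify_key(key, stored)
assert not verify_key("wrong-key", stored)
```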
## 🧪 Testing
### Manual Testing
```bash
# 1. Initialize
python -m api.cli_admin init-database
python -m api.cli_admin create-key "Test Key"

# 2. Start API
python -m api.main

# 3. Test endpoints
curl http://localhost:8000/health
curl -X POST http://localhost:8000/api/v1/analyses \
  -H "X-API-Key: YOUR_KEY" \
  -H "Content-Type: application/json" \
  -d '{"ticker": "AAPL", "analysis_date": "2025-10-21", "selected_analysts": ["market", "news"], "research_depth": 1}'

# 4. Run example client
python -m api.example_client
```
### Automated Testing
Recommended test suite (not included in this implementation):
- Unit tests for each endpoint
- Integration tests for analysis workflow
- WebSocket connection tests
- Concurrent execution tests
- Authentication tests
## 📈 Performance
- **Concurrency**: 4 parallel analyses by default (configurable)
- **Database**: SQLite for development, PostgreSQL recommended for production
- **Threading**: Thread-safe operations throughout
- **WebSocket**: Efficient real-time updates
- **Caching**: Leverages existing TradingAgents data cache
## 🚀 Deployment
### Development
```bash
python -m api.main
# or
./run_api.sh
```
### Production Checklist
- [ ] Switch to PostgreSQL
- [ ] Configure specific CORS origins
- [ ] Enable HTTPS/WSS (reverse proxy)
- [ ] Add rate limiting
- [ ] Set up monitoring/logging
- [ ] Configure database backups
- [ ] Set environment-specific configs
## 🎓 Usage Examples
### cURL
```bash
# Create analysis
curl -X POST http://localhost:8000/api/v1/analyses \
  -H "X-API-Key: YOUR_KEY" \
  -H "Content-Type: application/json" \
  -d '{"ticker": "AAPL", "analysis_date": "2025-10-21"}'

# Get status
curl http://localhost:8000/api/v1/analyses/{id}/status \
  -H "X-API-Key: YOUR_KEY"
```
### Python
```python
import asyncio

from api.example_client import TradingAgentsAPIClient

async def main():
    client = TradingAgentsAPIClient("YOUR_API_KEY")
    analysis = await client.create_analysis("AAPL")
    await client.monitor_via_websocket(analysis["id"])

asyncio.run(main())
```
### JavaScript
```javascript
const response = await fetch('http://localhost:8000/api/v1/analyses', {
  method: 'POST',
  headers: {
    'X-API-Key': 'YOUR_KEY',
    'Content-Type': 'application/json',
  },
  body: JSON.stringify({
    ticker: 'AAPL',
    analysis_date: '2025-10-21',
  }),
});
const analysis = await response.json();
```
## ✨ Key Achievements
1. **Separation of Concerns**: API in separate `api/` directory
2. **Persistent Storage**: SQLite with easy PostgreSQL migration
3. **Real-time Updates**: WebSocket streaming of status
4. **Parallel Processing**: ThreadPoolExecutor with configurable concurrency
5. **Complete Logging**: All tool calls, reasoning, and reports saved
6. **Security**: API key authentication with bcrypt
7. **Documentation**: Comprehensive docs with examples
8. **Admin Tools**: CLI for key management
9. **Easy Setup**: Quick start in < 5 minutes
## 🔄 Integration Points
### With Existing TradingAgents
- ✅ Uses `TradingAgentsGraph` for execution
- ✅ Leverages `DEFAULT_CONFIG` for configuration
- ✅ Integrates with asset detection
- ✅ Uses existing data cache
- ✅ Compatible with all LLM providers
### For Frontend Development
- ✅ RESTful API design
- ✅ WebSocket for real-time updates
- ✅ CORS enabled for cross-origin requests
- ✅ Comprehensive error responses
- ✅ Pagination support
- ✅ Filtering and search
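A client typically combines filtering and pagination through query parameters. The parameter names below (`ticker`, `status`, `limit`, `offset`) are hypothetical; the names actually accepted by `GET /api/v1/analyses` are defined in `api/endpoints/analyses.py`.

```python
from urllib.parse import urlencode

base = "http://localhost:8000/api/v1/analyses"
# Hypothetical filter/pagination parameters for illustration
params = {"ticker": "AAPL", "status": "completed", "limit": 10, "offset": 0}
url = f"{base}?{urlencode(params)}"
print(url)
```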
## 📝 Next Steps (Optional Enhancements)
Future improvements could include:
- User authentication and multi-tenancy
- Rate limiting per API key
- Analysis templates
- Scheduled/recurring analyses
- Email/webhook notifications
- Analysis comparison tools
- Performance metrics dashboard
- Analysis export (PDF, CSV)
- Bulk operations API
- GraphQL endpoint
## 🎉 Conclusion
The FastAPI Trading Agents API is fully implemented and ready for use. It provides a robust, scalable foundation for frontend applications to interact with the TradingAgents multi-agent system.
### Getting Started
1. Read `API_QUICKSTART.md` for setup instructions
2. Check `api/README.md` for full documentation
3. Run `api/example_client.py` to see it in action
4. Start building your frontend!
### Support
- API Documentation: http://localhost:8000/docs
- Project README: `api/README.md`
- Quick Start: `API_QUICKSTART.md`

api/API_QUICKSTART.md

@@ -0,0 +1,300 @@
# Trading Agents API - Quick Start Guide
This guide will get your FastAPI Trading Agents API up and running in minutes.
## Prerequisites
- Python 3.10+
- TradingAgents installed and configured
- Required environment variables (OPENAI_API_KEY, ALPHA_VANTAGE_API_KEY, etc.)
## Setup (5 minutes)
### 1. Install API Dependencies
```bash
cd TradingAgents
pip install -r requirements-api.txt
```
### 2. Initialize Database
```bash
python -m api.cli_admin init-database
```
### 3. Create Your First API Key
```bash
python -m api.cli_admin create-key "Development Key"
```
**IMPORTANT**: Save the API key that's displayed. You won't be able to see it again!
Example output:
```
✓ API Key created successfully!
Name: Development Key
Created: 2025-10-21 14:30:00
API Key (save this, it won't be shown again):
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
Use this key in the X-API-Key header for all API requests.
```
### 4. Start the API Server
```bash
python -m api.main
```
Or use the startup script:
```bash
./run_api.sh
```
The API will start at: `http://localhost:8000`
## Test Your API (2 minutes)
### View API Documentation
Open your browser to:
- **Swagger UI**: http://localhost:8000/docs
- **ReDoc**: http://localhost:8000/redoc
### Create Your First Analysis
Replace `YOUR_API_KEY` with the key you created:
```bash
curl -X POST "http://localhost:8000/api/v1/analyses" \
  -H "X-API-Key: YOUR_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{
    "ticker": "AAPL",
    "analysis_date": "2025-10-21",
    "selected_analysts": ["market", "news"],
    "research_depth": 1
  }'
```
You'll get a response with an `analysis_id`:
```json
{
  "id": "550e8400-e29b-41d4-a716-446655440000",
  "ticker": "AAPL",
  "status": "pending",
  ...
}
```
### Check Status
```bash
curl "http://localhost:8000/api/v1/analyses/YOUR_ANALYSIS_ID/status" \
-H "X-API-Key: YOUR_API_KEY"
```
### Get Results
```bash
curl "http://localhost:8000/api/v1/analyses/YOUR_ANALYSIS_ID" \
-H "X-API-Key: YOUR_API_KEY"
```
## Using the Python Client
### Install Additional Dependencies
```bash
pip install httpx websockets
```
### Run Example Client
Edit `api/example_client.py` and replace `YOUR_API_KEY`, then:
```bash
python -m api.example_client
```
This will:
1. Create an analysis for AAPL
2. Monitor it via WebSocket
3. Display the results
## Next Steps
### Multiple Parallel Analyses
Create multiple analyses at once - they'll run in parallel:
```bash
# Start 3 analyses
for ticker in AAPL MSFT GOOGL; do
  curl -X POST "http://localhost:8000/api/v1/analyses" \
    -H "X-API-Key: YOUR_API_KEY" \
    -H "Content-Type: application/json" \
    -d "{\"ticker\": \"$ticker\", \"analysis_date\": \"2025-10-21\", \"selected_analysts\": [\"market\", \"news\"], \"research_depth\": 1}"
done
```
### List All Analyses
```bash
curl "http://localhost:8000/api/v1/analyses" \
-H "X-API-Key: YOUR_API_KEY"
```
### Get Analyses for Specific Ticker
```bash
curl "http://localhost:8000/api/v1/tickers/AAPL/analyses" \
-H "X-API-Key: YOUR_API_KEY"
```
### View Cached Market Data
```bash
# List all cached tickers
curl "http://localhost:8000/api/v1/data/cache" \
-H "X-API-Key: YOUR_API_KEY"
# Get data for specific ticker
curl "http://localhost:8000/api/v1/data/cache/AAPL" \
-H "X-API-Key: YOUR_API_KEY"
```
## WebSocket Real-Time Monitoring
### JavaScript Example
```javascript
const ws = new WebSocket('ws://localhost:8000/api/v1/ws/analyses/YOUR_ANALYSIS_ID');

ws.onmessage = (event) => {
  const update = JSON.parse(event.data);
  console.log(`Status: ${update.status}, Progress: ${update.progress_percentage}%`);
  if (update.status === 'completed') {
    console.log('Analysis finished!');
    ws.close();
  }
};

ws.onerror = (error) => console.error('WebSocket error:', error);
```
### Python Example
```python
import asyncio
import json

import websockets

async def monitor_analysis(analysis_id, api_key):
    uri = f"ws://localhost:8000/api/v1/ws/analyses/{analysis_id}"
    async with websockets.connect(uri) as websocket:
        while True:
            message = await websocket.recv()
            data = json.loads(message)
            print(f"Status: {data['status']}, Progress: {data['progress_percentage']}%")
            if data['status'] in ['completed', 'failed', 'cancelled']:
                break

asyncio.run(monitor_analysis('YOUR_ANALYSIS_ID', 'YOUR_API_KEY'))
```
## API Key Management
### List All Keys
```bash
python -m api.cli_admin list-keys
```
### Create Additional Key
```bash
python -m api.cli_admin create-key "Frontend App"
```
### Revoke a Key
```bash
python -m api.cli_admin revoke-key 1
```
### Activate a Key
```bash
python -m api.cli_admin activate-key 1
```
## Configuration
### Environment Variables
Set these before starting the API:
```bash
# Maximum concurrent analyses (default: 4)
export MAX_CONCURRENT_ANALYSES=8
# Database URL (default: SQLite in current directory)
export API_DATABASE_URL="sqlite:///./api_database.db"
# Or use PostgreSQL for production
export API_DATABASE_URL="postgresql://user:pass@localhost/trading_agents"
# Standard TradingAgents config
export OPENAI_API_KEY="your-key"
export ALPHA_VANTAGE_API_KEY="your-key"
```
## Troubleshooting
### "Invalid or inactive API key"
- Make sure you're using the exact key that was displayed when you created it
- Check that the key hasn't been revoked: `python -m api.cli_admin list-keys`
### "Analysis not found"
- The analysis ID must be exactly as returned from the create endpoint
- Check available analyses: `curl "http://localhost:8000/api/v1/analyses" -H "X-API-Key: YOUR_KEY"`
### Database locked
- SQLite has limited concurrency
- Reduce `MAX_CONCURRENT_ANALYSES` to 2-3
- Or switch to PostgreSQL
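If you stay on SQLite, switching the journal to WAL mode usually reduces lock errors, since readers no longer block on the writer. Below is a stdlib sketch (with SQLAlchemy you would issue the same PRAGMA from a connection event listener); the file path is illustrative:

```python
import os
import sqlite3
import tempfile

# Illustrative path; WAL requires a file-backed database, not :memory:
db_path = os.path.join(tempfile.gettempdir(), "api_database_demo.db")
conn = sqlite3.connect(db_path)

# WAL lets readers proceed while a single writer holds the lock
mode = conn.execute("PRAGMA journal_mode=WAL").fetchone()[0]
# Wait up to 5 s for a lock instead of failing immediately
conn.execute("PRAGMA busy_timeout=5000")
print(mode)  # wal
conn.close()
```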
### Import errors
- Make sure you're in the TradingAgents directory
- Run with: `python -m api.main` (not `python api/main.py`)
## Full API Reference
See `api/README.md` for complete documentation of all endpoints and features.
## Production Deployment
For production use, see the deployment section in `api/README.md`. Key points:
1. Use PostgreSQL instead of SQLite
2. Configure CORS for your frontend domain
3. Use HTTPS/WSS with reverse proxy
4. Add rate limiting
5. Set up monitoring and logging
## Support
For issues or questions:
- Check the full documentation: `api/README.md`
- Review the main TradingAgents README
- Check the interactive docs: http://localhost:8000/docs

api/README.md

@@ -0,0 +1,257 @@
# Trading Agents REST API
FastAPI-based REST API for managing multi-agent trading analyses with real-time WebSocket support.
## Features
- **REST API** for creating, monitoring, and managing trading analyses
- **WebSocket** support for real-time status updates
- **Parallel execution** of multiple analyses (configurable concurrency)
- **SQLite database** for persistent storage
- **API key authentication** for secure access
- **Complete analysis history** with logs and reports
## Installation
1. Install API dependencies:
```bash
cd TradingAgents
pip install -r requirements-api.txt
```
2. Initialize the database and create an API key:
```bash
python -m api.cli_admin init-database
python -m api.cli_admin create-key "My First Key"
```
Save the generated API key - you'll need it for all API requests.
## Running the API
Start the API server:
```bash
python -m api.main
```
Or with uvicorn directly:
```bash
uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload
```
The API will be available at `http://localhost:8000`
## API Documentation
Once the server is running, visit:
- **Swagger UI**: http://localhost:8000/docs
- **ReDoc**: http://localhost:8000/redoc
## Quick Start
### 1. Create an Analysis
```bash
curl -X POST "http://localhost:8000/api/v1/analyses" \
  -H "X-API-Key: YOUR_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{
    "ticker": "AAPL",
    "analysis_date": "2025-10-21",
    "selected_analysts": ["market", "news", "social", "fundamentals"],
    "research_depth": 1
  }'
```
Response:
```json
{
  "id": "uuid-here",
  "ticker": "AAPL",
  "analysis_date": "2025-10-21",
  "status": "pending",
  "progress_percentage": 0,
  ...
}
```
### 2. Monitor via WebSocket
```javascript
const ws = new WebSocket('ws://localhost:8000/api/v1/ws/analyses/{analysis_id}');

ws.onmessage = (event) => {
  const update = JSON.parse(event.data);
  console.log(`Status: ${update.status}, Progress: ${update.progress_percentage}%`);
};
```
### 3. Get Analysis Results
```bash
curl "http://localhost:8000/api/v1/analyses/{analysis_id}" \
-H "X-API-Key: YOUR_API_KEY"
```
## API Endpoints
### Analyses
- `POST /api/v1/analyses` - Create and start new analysis
- `GET /api/v1/analyses` - List all analyses (with filtering)
- `GET /api/v1/analyses/{id}` - Get full analysis details
- `GET /api/v1/analyses/{id}/status` - Get current status
- `GET /api/v1/analyses/{id}/reports` - Get all reports
- `GET /api/v1/analyses/{id}/reports/{type}` - Get specific report
- `GET /api/v1/analyses/{id}/logs` - Get execution logs
- `DELETE /api/v1/analyses/{id}` - Cancel/delete analysis
### Tickers
- `GET /api/v1/tickers` - List all tickers with analysis counts
- `GET /api/v1/tickers/{ticker}/analyses` - Get all analyses for ticker
- `GET /api/v1/tickers/{ticker}/latest` - Get latest analysis for ticker
### Data
- `GET /api/v1/data/cache` - List cached ticker data
- `GET /api/v1/data/cache/{ticker}` - Get cached market data
### WebSocket
- `WS /api/v1/ws/analyses/{id}` - Real-time status updates
## API Key Management
### Create API Key
```bash
python -m api.cli_admin create-key "Description"
```
### List API Keys
```bash
python -m api.cli_admin list-keys
```
### Revoke API Key
```bash
python -m api.cli_admin revoke-key <key_id>
```
### Activate API Key
```bash
python -m api.cli_admin activate-key <key_id>
```
## Configuration
### Environment Variables
- `API_DATABASE_URL` - Database connection string (default: `sqlite:///./api_database.db`)
- `MAX_CONCURRENT_ANALYSES` - Maximum parallel analyses (default: `4`)
- Standard TradingAgents config (LLM providers, API keys, etc.)
## Architecture
```
api/
├── main.py            # FastAPI application
├── database.py        # SQLAlchemy models
├── auth.py            # API key authentication
├── state_manager.py   # Analysis execution manager
├── models/            # Pydantic request/response models
├── endpoints/         # REST endpoint handlers
└── websockets/        # WebSocket handlers
```
## Database Schema
- **analyses** - Analysis metadata and status
- **analysis_logs** - Execution logs (tool calls, reasoning)
- **analysis_reports** - Generated reports (by type)
- **api_keys** - Authentication keys
## Parallel Execution
The API uses a `ThreadPoolExecutor` to run multiple analyses concurrently:
- Default: 4 concurrent analyses
- Configurable via `MAX_CONCURRENT_ANALYSES` env var
- Each analysis runs in its own thread
- Database writes are thread-safe
- Cached data is read-only (thread-safe)
## Example Frontend Integration
```javascript
class TradingAnalysisClient {
  constructor(apiKey, baseURL = 'http://localhost:8000') {
    this.apiKey = apiKey;
    this.baseURL = baseURL;
  }

  async createAnalysis(ticker, date, analysts = ['market', 'news']) {
    const response = await fetch(`${this.baseURL}/api/v1/analyses`, {
      method: 'POST',
      headers: {
        'X-API-Key': this.apiKey,
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({
        ticker,
        analysis_date: date,
        selected_analysts: analysts,
        research_depth: 1,
      }),
    });
    return await response.json();
  }

  connectWebSocket(analysisId, onUpdate) {
    const ws = new WebSocket(`ws://localhost:8000/api/v1/ws/analyses/${analysisId}`);
    ws.onmessage = (event) => onUpdate(JSON.parse(event.data));
    return ws;
  }

  async getAnalysis(analysisId) {
    const response = await fetch(
      `${this.baseURL}/api/v1/analyses/${analysisId}`,
      { headers: { 'X-API-Key': this.apiKey } }
    );
    return await response.json();
  }
}
```
## Troubleshooting
### Database locked error
SQLite has limited concurrent write support. If you get database locked errors:
- Reduce `MAX_CONCURRENT_ANALYSES`
- Or switch to PostgreSQL by changing `API_DATABASE_URL`
### Import errors
Make sure you're running from the TradingAgents root directory:
```bash
cd TradingAgents
python -m api.main
```
### WebSocket connection refused
Check that the server is running and CORS is properly configured for your frontend origin.
## Production Deployment
For production use:
1. **Use PostgreSQL** instead of SQLite
2. **Secure CORS** - Set specific allowed origins in `main.py`
3. **HTTPS/WSS** - Use reverse proxy (nginx) with SSL
4. **Monitoring** - Add logging, metrics, and health checks
5. **Rate limiting** - Add rate limiting middleware
6. **Backup** - Regular database backups
## License
Same as TradingAgents/Litadel main project.

api/START_API.md

@@ -0,0 +1,114 @@
# Quick Start - Trading Agents API
## 1. Install Dependencies (if not already done)
```bash
cd TradingAgents
pip install -r requirements.txt
```
## 2. Initialize Database & Create API Key
```bash
# Initialize the database
python -m api.cli_admin init-database
# Create your first API key
python -m api.cli_admin create-key "My Development Key"
```
**IMPORTANT**: Save the API key that's displayed! You'll need it for all requests.
## 3. Start the API
```bash
python -m api.main
```
Or use the startup script:
```bash
./run_api.sh
```
The API will start at: **http://localhost:8001**
## 4. Test It
Open your browser to see the interactive API documentation:
- **Swagger UI**: http://localhost:8001/docs
- **ReDoc**: http://localhost:8001/redoc
## 5. Create Your First Analysis
Using curl (replace `YOUR_API_KEY` with the key from step 2):
```bash
curl -X POST "http://localhost:8001/api/v1/analyses" \
  -H "X-API-Key: YOUR_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{
    "ticker": "AAPL",
    "analysis_date": "2025-10-21",
    "selected_analysts": ["market", "news"],
    "research_depth": 1
  }'
```
You'll get back an `analysis_id`. Use it to check status:
```bash
curl "http://localhost:8001/api/v1/analyses/YOUR_ANALYSIS_ID/status" \
-H "X-API-Key: YOUR_API_KEY"
```
## Configuration (Optional)
Set environment variables before starting:
```bash
# Maximum concurrent analyses (default: 4)
export MAX_CONCURRENT_ANALYSES=8
# Your LLM API keys (if not already set)
export OPENAI_API_KEY="your-key"
export ALPHA_VANTAGE_API_KEY="your-key"
# Then start the API
python -m api.main
```
## Common Commands
```bash
# List all API keys
python -m api.cli_admin list-keys
# Create a new API key
python -m api.cli_admin create-key "Frontend App"
# Revoke an API key (use ID from list-keys)
python -m api.cli_admin revoke-key 1
```
## Full Documentation
- Quick Start: `API_QUICKSTART.md`
- Full API Docs: `api/README.md`
- Implementation Details: `API_IMPLEMENTATION_SUMMARY.md`
## Troubleshooting
**"Invalid or inactive API key"**
- Make sure you're using the exact key from step 2
- Check: `python -m api.cli_admin list-keys`
**Import errors**
- Make sure you're in the TradingAgents directory
- Use: `python -m api.main` (not `python api/main.py`)
**Port 8001 already in use**
- Change port: `API_PORT=8002 python -m api.main`
- Or: `python -m uvicorn api.main:app --port 8002`
That's it! Your API is ready to use. 🚀

api/__init__.py

@@ -0,0 +1,2 @@
"""FastAPI Trading Agents API."""

api/auth.py

@@ -0,0 +1,101 @@
"""API key authentication."""
import secrets
import warnings
from typing import Optional

from fastapi import Depends, HTTPException, Security, status
from fastapi.security import APIKeyHeader

# Suppress the known passlib/bcrypt compatibility warning
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message=".*trapped.*")
    from passlib.context import CryptContext

from sqlalchemy.orm import Session

from api.database import APIKey, get_db

# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# API key header scheme
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=True)


def hash_api_key(api_key: str) -> str:
    """Hash an API key."""
    return pwd_context.hash(api_key)


def verify_api_key_hash(api_key: str, key_hash: str) -> bool:
    """Verify an API key against its hash."""
    return pwd_context.verify(api_key, key_hash)


def generate_api_key() -> str:
    """Generate a new random API key."""
    return secrets.token_urlsafe(32)


def create_api_key(db: Session, name: str) -> tuple[str, APIKey]:
    """
    Create a new API key in the database.

    Returns:
        tuple: (plain_key, db_record)
    """
    plain_key = generate_api_key()
    key_hash = hash_api_key(plain_key)
    db_key = APIKey(key_hash=key_hash, name=name, is_active=True)
    db.add(db_key)
    db.commit()
    db.refresh(db_key)
    return plain_key, db_key


def get_api_key_by_hash(db: Session, key_hash: str) -> Optional[APIKey]:
    """Get an API key record by its hash."""
    return db.query(APIKey).filter(APIKey.key_hash == key_hash).first()


def verify_api_key_from_db(db: Session, api_key: str) -> Optional[APIKey]:
    """
    Verify an API key against the database.

    Returns:
        APIKey record if valid and active, None otherwise
    """
    # Get all active keys
    active_keys = db.query(APIKey).filter(APIKey.is_active == True).all()

    # Check each one
    for key_record in active_keys:
        if verify_api_key_hash(api_key, key_record.key_hash):
            return key_record
    return None


async def get_current_api_key(
    api_key: str = Security(api_key_header),
    db: Session = Depends(get_db),
) -> APIKey:
    """
    Dependency to verify API key and return the key record.

    Raises:
        HTTPException: If API key is invalid or inactive
    """
    key_record = verify_api_key_from_db(db, api_key)
    if not key_record:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or inactive API key",
        )
    return key_record

api/cli_admin.py

@@ -0,0 +1,146 @@
"""Admin CLI tool for API key management."""
import sys
from datetime import datetime

import typer
from rich.console import Console
from rich.table import Table

from api.auth import create_api_key
from api.database import APIKey, SessionLocal, init_db

app = typer.Typer(
    name="api-admin",
    help="Trading Agents API Administration Tool",
)
console = Console()


@app.command()
def create_key(name: str = typer.Argument(..., help="Name/description for the API key")):
    """Create a new API key."""
    # Initialize database
    init_db()
    db = SessionLocal()
    try:
        plain_key, db_key = create_api_key(db, name)
        console.print("\n[green]✓ API Key created successfully![/green]\n")
        console.print(f"[bold]Name:[/bold] {db_key.name}")
        console.print(f"[bold]Created:[/bold] {db_key.created_at}")
        console.print("\n[bold yellow]API Key (save this, it won't be shown again):[/bold yellow]")
        console.print(f"[cyan]{plain_key}[/cyan]\n")
        console.print("[dim]Use this key in the X-API-Key header for all API requests.[/dim]\n")
    except Exception as e:
        console.print(f"[red]Error creating API key: {e}[/red]")
        sys.exit(1)
    finally:
        db.close()


@app.command()
def list_keys():
    """List all API keys."""
    init_db()
    db = SessionLocal()
    try:
        keys = db.query(APIKey).order_by(APIKey.created_at.desc()).all()
        if not keys:
            console.print("[yellow]No API keys found.[/yellow]")
            return
        table = Table(title="API Keys")
        table.add_column("ID", style="cyan")
        table.add_column("Name", style="green")
        table.add_column("Created", style="blue")
        table.add_column("Status", style="magenta")
        for key in keys:
            status = "Active" if key.is_active else "Revoked"
            status_color = "green" if key.is_active else "red"
            table.add_row(
                str(key.id),
                key.name,
                key.created_at.strftime("%Y-%m-%d %H:%M:%S"),
                f"[{status_color}]{status}[/{status_color}]",
            )
        console.print(table)
    finally:
        db.close()


@app.command()
def revoke_key(key_id: int = typer.Argument(..., help="API key ID to revoke")):
    """Revoke an API key."""
    init_db()
    db = SessionLocal()
    try:
        key = db.query(APIKey).filter(APIKey.id == key_id).first()
        if not key:
            console.print(f"[red]API key with ID {key_id} not found.[/red]")
            sys.exit(1)
        if not key.is_active:
            console.print(f"[yellow]API key '{key.name}' is already revoked.[/yellow]")
            return
        key.is_active = False
        db.commit()
        console.print(f"[green]✓ API key '{key.name}' has been revoked.[/green]")
    except Exception as e:
        console.print(f"[red]Error revoking API key: {e}[/red]")
        sys.exit(1)
    finally:
        db.close()


@app.command()
def activate_key(key_id: int = typer.Argument(..., help="API key ID to activate")):
    """Activate a revoked API key."""
    init_db()
    db = SessionLocal()
    try:
        key = db.query(APIKey).filter(APIKey.id == key_id).first()
        if not key:
            console.print(f"[red]API key with ID {key_id} not found.[/red]")
            sys.exit(1)
        if key.is_active:
            console.print(f"[yellow]API key '{key.name}' is already active.[/yellow]")
            return
        key.is_active = True
        db.commit()
        console.print(f"[green]✓ API key '{key.name}' has been activated.[/green]")
    except Exception as e:
        console.print(f"[red]Error activating API key: {e}[/red]")
        sys.exit(1)
    finally:
        db.close()


@app.command()
def init_database():
    """Initialize the database (create tables)."""
    try:
        init_db()
        console.print("[green]✓ Database initialized successfully![/green]")
    except Exception as e:
        console.print(f"[red]Error initializing database: {e}[/red]")
        sys.exit(1)


if __name__ == "__main__":
    app()

api/database.py

@@ -0,0 +1,114 @@
"""Database models and connection management."""
import os
from datetime import datetime
from typing import Generator

from sqlalchemy import (
    Boolean,
    Column,
    DateTime,
    Integer,
    String,
    Text,
    create_engine,
    ForeignKey,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker, relationship

# Database file location
DATABASE_URL = os.getenv("API_DATABASE_URL", "sqlite:///./api_database.db")

# Create engine
engine = create_engine(
    DATABASE_URL,
    connect_args={"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {},
)

# Session factory
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Base class for models
Base = declarative_base()


class Analysis(Base):
    """Analysis metadata and configuration."""

    __tablename__ = "analyses"

    id = Column(String, primary_key=True, index=True)
    ticker = Column(String, index=True, nullable=False)
    analysis_date = Column(String, nullable=False)
    status = Column(
        String, nullable=False, default="pending"
    )  # pending, running, completed, failed, cancelled
    config_json = Column(Text, nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    completed_at = Column(DateTime, nullable=True)
    error_message = Column(Text, nullable=True)
    progress_percentage = Column(Integer, default=0)
    current_agent = Column(String, nullable=True)

    # Relationships
    logs = relationship("AnalysisLog", back_populates="analysis", cascade="all, delete-orphan")
    reports = relationship("AnalysisReport", back_populates="analysis", cascade="all, delete-orphan")


class AnalysisLog(Base):
    """Log entries from analysis execution."""

    __tablename__ = "analysis_logs"

    id = Column(Integer, primary_key=True, index=True)
    analysis_id = Column(String, ForeignKey("analyses.id"), nullable=False)
    timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)
    log_type = Column(String, nullable=False)  # Tool Call, Reasoning, System
    content = Column(Text, nullable=False)

    # Relationships
    analysis = relationship("Analysis", back_populates="logs")


class AnalysisReport(Base):
    """Report sections from analysis."""

    __tablename__ = "analysis_reports"

    id = Column(Integer, primary_key=True, index=True)
    analysis_id = Column(String, ForeignKey("analyses.id"), nullable=False)
    report_type = Column(String, nullable=False)  # market_report, news_report, etc.
    content = Column(Text, nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)

    # Relationships
    analysis = relationship("Analysis", back_populates="reports")


class APIKey(Base):
    """API keys for authentication."""

    __tablename__ = "api_keys"

    id = Column(Integer, primary_key=True, index=True)
    key_hash = Column(String, unique=True, nullable=False, index=True)
    name = Column(String, nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
    is_active = Column(Boolean, default=True, nullable=False)


def init_db():
    """Initialize database and create all tables."""
    Base.metadata.create_all(bind=engine)
def get_db() -> Generator[Session, None, None]:
"""Dependency for getting database sessions."""
db = SessionLocal()
try:
yield db
finally:
db.close()
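`get_db` relies on generator semantics: the code after `yield` runs when FastAPI closes the generator at the end of the request, so the session is released even if the handler raises. The same pattern in isolation, with no SQLAlchemy dependency:

```python
from typing import Generator

closed = []

def get_resource() -> Generator[str, None, None]:
    resource = "session"
    try:
        yield resource  # handed to the request handler
    finally:
        closed.append(resource)  # cleanup runs when the generator is closed

gen = get_resource()
res = next(gen)  # dependency resolution: the handler receives "session"
gen.close()      # FastAPI does this after the response is sent
assert closed == ["session"]
```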

api/endpoints/__init__.py
"""API endpoint routers."""

api/endpoints/analyses.py
"""Analysis CRUD endpoints."""
import json
import logging
import uuid
from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from api.auth import APIKey, get_current_api_key
from api.database import Analysis, AnalysisLog, AnalysisReport, get_db
from api.models import (
AnalysisResponse,
AnalysisStatusResponse,
AnalysisSummary,
CreateAnalysisRequest,
LogEntry,
ReportResponse,
)
from api.state_manager import get_executor
from tradingagents.default_config import DEFAULT_CONFIG
router = APIRouter(prefix="/api/v1/analyses", tags=["analyses"])
logger = logging.getLogger(__name__)
@router.post("", response_model=AnalysisResponse, status_code=status.HTTP_201_CREATED)
async def create_analysis(
request: CreateAnalysisRequest,
db: Session = Depends(get_db),
api_key: APIKey = Depends(get_current_api_key),
):
"""Create and start a new analysis."""
# Generate analysis ID
analysis_id = str(uuid.uuid4())
# Build configuration
config = DEFAULT_CONFIG.copy()
config["max_debate_rounds"] = request.research_depth
config["max_risk_discuss_rounds"] = request.research_depth
if request.llm_provider:
config["llm_provider"] = request.llm_provider
if request.backend_url:
config["backend_url"] = request.backend_url
if request.quick_think_llm:
config["quick_think_llm"] = request.quick_think_llm
if request.deep_think_llm:
config["deep_think_llm"] = request.deep_think_llm
# Auto-detect asset class
from cli.asset_detection import detect_asset_class
asset_class = detect_asset_class(request.ticker)
config["asset_class"] = asset_class
# Filter out fundamentals for commodities/crypto
selected_analysts = request.selected_analysts
if asset_class in ["commodity", "crypto"] and "fundamentals" in selected_analysts:
selected_analysts = [a for a in selected_analysts if a != "fundamentals"]
# Create database record
analysis = Analysis(
id=analysis_id,
ticker=request.ticker,
analysis_date=request.analysis_date,
status="pending",
config_json=json.dumps(config),
progress_percentage=0,
)
db.add(analysis)
db.commit()
db.refresh(analysis)
# Start analysis in background
logger.info(f"Creating analysis {analysis_id} for {request.ticker}")
executor = get_executor()
try:
executor.start_analysis(
analysis_id=analysis_id,
ticker=request.ticker,
analysis_date=request.analysis_date,
selected_analysts=selected_analysts,
config=config,
)
logger.info(f"Analysis {analysis_id} started successfully")
except Exception as e:
logger.error(f"Failed to start analysis {analysis_id}: {str(e)}")
# Update status to failed
analysis.status = "failed"
analysis.error_message = str(e)
db.commit()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to start analysis: {str(e)}",
)
return AnalysisResponse(
id=analysis.id,
ticker=analysis.ticker,
analysis_date=analysis.analysis_date,
status=analysis.status,
config=config,
reports=[],
progress_percentage=analysis.progress_percentage,
current_agent=analysis.current_agent,
created_at=analysis.created_at,
updated_at=analysis.updated_at,
completed_at=analysis.completed_at,
error_message=analysis.error_message,
)
@router.get("", response_model=List[AnalysisSummary])
async def list_analyses(
ticker: Optional[str] = Query(None, description="Filter by ticker"),
status: Optional[str] = Query(None, description="Filter by status"),
date_from: Optional[str] = Query(None, description="Filter by date (from)"),
date_to: Optional[str] = Query(None, description="Filter by date (to)"),
limit: int = Query(100, ge=1, le=1000, description="Max results"),
offset: int = Query(0, ge=0, description="Offset for pagination"),
db: Session = Depends(get_db),
api_key: APIKey = Depends(get_current_api_key),
):
"""List all analyses with optional filtering."""
query = db.query(Analysis)
if ticker:
query = query.filter(Analysis.ticker == ticker.upper())
if status:
query = query.filter(Analysis.status == status)
if date_from:
query = query.filter(Analysis.analysis_date >= date_from)
if date_to:
query = query.filter(Analysis.analysis_date <= date_to)
# Order by created_at descending
query = query.order_by(Analysis.created_at.desc())
# Apply pagination
analyses = query.offset(offset).limit(limit).all()
return [
AnalysisSummary(
id=a.id,
ticker=a.ticker,
analysis_date=a.analysis_date,
status=a.status,
created_at=a.created_at,
completed_at=a.completed_at,
error_message=a.error_message,
)
for a in analyses
]
@router.get("/{analysis_id}", response_model=AnalysisResponse)
async def get_analysis(
analysis_id: str,
db: Session = Depends(get_db),
api_key: APIKey = Depends(get_current_api_key),
):
"""Get full analysis details."""
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
if not analysis:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Analysis {analysis_id} not found",
)
# Get reports
reports = (
db.query(AnalysisReport)
.filter(AnalysisReport.analysis_id == analysis_id)
.all()
)
return AnalysisResponse(
id=analysis.id,
ticker=analysis.ticker,
analysis_date=analysis.analysis_date,
status=analysis.status,
config=json.loads(analysis.config_json),
reports=[
ReportResponse(
report_type=r.report_type,
content=r.content,
created_at=r.created_at,
)
for r in reports
],
progress_percentage=analysis.progress_percentage,
current_agent=analysis.current_agent,
created_at=analysis.created_at,
updated_at=analysis.updated_at,
completed_at=analysis.completed_at,
error_message=analysis.error_message,
)
@router.get("/{analysis_id}/status", response_model=AnalysisStatusResponse)
async def get_analysis_status(
analysis_id: str,
db: Session = Depends(get_db),
api_key: APIKey = Depends(get_current_api_key),
):
"""Get current analysis status."""
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
if not analysis:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Analysis {analysis_id} not found",
)
return AnalysisStatusResponse(
id=analysis.id,
status=analysis.status,
progress_percentage=analysis.progress_percentage,
current_agent=analysis.current_agent,
updated_at=analysis.updated_at,
)
@router.get("/{analysis_id}/reports", response_model=List[ReportResponse])
async def get_analysis_reports(
analysis_id: str,
db: Session = Depends(get_db),
api_key: APIKey = Depends(get_current_api_key),
):
"""Get all reports for an analysis."""
# Check if analysis exists
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
if not analysis:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Analysis {analysis_id} not found",
)
reports = (
db.query(AnalysisReport)
.filter(AnalysisReport.analysis_id == analysis_id)
.order_by(AnalysisReport.created_at)
.all()
)
return [
ReportResponse(
report_type=r.report_type,
content=r.content,
created_at=r.created_at,
)
for r in reports
]
@router.get("/{analysis_id}/reports/{report_type}", response_model=ReportResponse)
async def get_analysis_report(
analysis_id: str,
report_type: str,
db: Session = Depends(get_db),
api_key: APIKey = Depends(get_current_api_key),
):
"""Get a specific report for an analysis."""
report = (
db.query(AnalysisReport)
.filter(
AnalysisReport.analysis_id == analysis_id,
AnalysisReport.report_type == report_type,
)
.first()
)
if not report:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Report {report_type} not found for analysis {analysis_id}",
)
return ReportResponse(
report_type=report.report_type,
content=report.content,
created_at=report.created_at,
)
@router.get("/{analysis_id}/logs", response_model=List[LogEntry])
async def get_analysis_logs(
analysis_id: str,
log_type: Optional[str] = Query(None, description="Filter by log type"),
limit: int = Query(100, ge=1, le=1000, description="Max results"),
offset: int = Query(0, ge=0, description="Offset for pagination"),
db: Session = Depends(get_db),
api_key: APIKey = Depends(get_current_api_key),
):
"""Get execution logs for an analysis."""
# Check if analysis exists
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
if not analysis:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Analysis {analysis_id} not found",
)
query = db.query(AnalysisLog).filter(AnalysisLog.analysis_id == analysis_id)
if log_type:
query = query.filter(AnalysisLog.log_type == log_type)
logs = query.order_by(AnalysisLog.timestamp).offset(offset).limit(limit).all()
return [
LogEntry(
timestamp=log.timestamp,
log_type=log.log_type,
content=log.content,
)
for log in logs
]
@router.delete("/{analysis_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_analysis(
analysis_id: str,
permanent: bool = Query(False, description="Permanently delete from database"),
db: Session = Depends(get_db),
api_key: APIKey = Depends(get_current_api_key),
):
"""Cancel and/or delete an analysis."""
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
if not analysis:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Analysis {analysis_id} not found",
)
# Try to cancel if running
if analysis.status in ["pending", "running"]:
executor = get_executor()
executor.cancel_analysis(analysis_id)
# Delete from database if requested
if permanent:
db.delete(analysis)
db.commit()
elif analysis.status not in ["cancelled", "failed", "completed"]:
# Just mark as cancelled
analysis.status = "cancelled"
analysis.updated_at = datetime.utcnow()
db.commit()
return None

api/endpoints/data.py
"""Cached data access endpoints."""
import csv
import glob
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from api.auth import APIKey, get_current_api_key
from api.models.responses import CachedDataResponse, CachedTickerInfo
router = APIRouter(prefix="/api/v1/data", tags=["data"])
# Data cache directory
DATA_CACHE_DIR = Path("./tradingagents/dataflows/data_cache")
def _parse_date_range(filename: str) -> Optional[Dict[str, str]]:
"""Parse the start/end date range from a cache filename."""
try:
# Format: TICKER-YFin-data-START-END.csv, where START and END are
# ISO dates (YYYY-MM-DD), so each date spans three hyphen-separated parts
parts = filename.replace(".csv", "").split("-")
if len(parts) >= 9:
start_date = "-".join(parts[-6:-3])
end_date = "-".join(parts[-3:])
return {"start": start_date, "end": end_date}
except (ValueError, IndexError):
pass
return None
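Assuming cache filenames embed ISO dates (consistent with the `YYYY-MM-DD` filters used elsewhere in this module), each date contributes three hyphen-separated tokens after the split; a standalone mirror of the parsing logic, for illustration only, makes that concrete:

```python
def parse_date_range(filename: str):
    """Hypothetical standalone version of the cache-filename parser."""
    parts = filename.replace(".csv", "").split("-")
    if len(parts) < 9:  # ticker + "YFin" + "data" + 2 x 3 date tokens
        return None
    return {"start": "-".join(parts[-6:-3]), "end": "-".join(parts[-3:])}

assert parse_date_range("AAPL-YFin-data-2024-01-01-2024-12-31.csv") == {
    "start": "2024-01-01",
    "end": "2024-12-31",
}
assert parse_date_range("notes.csv") is None
```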
@router.get("/cache", response_model=List[CachedTickerInfo])
async def list_cached_tickers(
api_key: APIKey = Depends(get_current_api_key),
):
"""List all cached tickers with date ranges."""
if not DATA_CACHE_DIR.exists():
return []
cached_tickers = []
for csv_file in DATA_CACHE_DIR.glob("*-YFin-data-*.csv"):
ticker = csv_file.name.split("-")[0]
date_range = _parse_date_range(csv_file.name)
if date_range:
# Count records
try:
with open(csv_file, "r") as f:
record_count = sum(1 for _ in f) - 1 # Subtract header
except OSError:
record_count = 0
cached_tickers.append(
CachedTickerInfo(
ticker=ticker,
date_range=date_range,
record_count=record_count,
)
)
return sorted(cached_tickers, key=lambda x: x.ticker)
@router.get("/cache/{ticker}", response_model=CachedDataResponse)
async def get_cached_data(
ticker: str,
start_date: Optional[str] = Query(None, description="Filter from date (YYYY-MM-DD)"),
end_date: Optional[str] = Query(None, description="Filter to date (YYYY-MM-DD)"),
api_key: APIKey = Depends(get_current_api_key),
):
"""Get cached market data for a ticker."""
# Find matching file
pattern = f"{ticker.upper()}-YFin-data-*.csv"
matching_files = list(DATA_CACHE_DIR.glob(pattern))
if not matching_files:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No cached data found for ticker {ticker}",
)
# Use the first matching file (should only be one)
csv_file = matching_files[0]
date_range = _parse_date_range(csv_file.name)
if not date_range:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Could not parse date range from cache file",
)
# Read CSV data
data = []
try:
with open(csv_file, "r") as f:
reader = csv.DictReader(f)
for row in reader:
# Filter by date if specified
row_date = row.get("Date", "")
if start_date and row_date < start_date:
continue
if end_date and row_date > end_date:
continue
# Convert numeric fields
for field in ["Close", "High", "Low", "Open", "Volume"]:
if field in row and row[field]:
try:
row[field] = float(row[field])
except ValueError:
pass
data.append(row)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error reading cache file: {str(e)}",
)
return CachedDataResponse(
ticker=ticker.upper(),
date_range=date_range,
data=data,
)
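The start/end filters above compare dates as plain strings; that is correct only because ISO `YYYY-MM-DD` dates sort lexicographically in chronological order. The same filtering logic in isolation:

```python
rows = [
    {"Date": "2024-01-02", "Close": 185.6},
    {"Date": "2024-02-15", "Close": 182.3},
    {"Date": "2024-03-28", "Close": 171.5},
]

def filter_rows(rows, start_date=None, end_date=None):
    # Lexicographic string comparison matches chronological order for ISO dates
    return [
        r for r in rows
        if (start_date is None or r["Date"] >= start_date)
        and (end_date is None or r["Date"] <= end_date)
    ]

assert [r["Date"] for r in filter_rows(rows, "2024-02-01", "2024-03-01")] == ["2024-02-15"]
```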

api/endpoints/tickers.py
"""Ticker history endpoints."""
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func
from sqlalchemy.orm import Session
from api.auth import APIKey, get_current_api_key
from api.database import Analysis, get_db
from api.models import AnalysisResponse, AnalysisSummary, ReportResponse, TickerInfo
from api.endpoints.analyses import get_analysis
router = APIRouter(prefix="/api/v1/tickers", tags=["tickers"])
@router.get("", response_model=List[TickerInfo])
async def list_tickers(
db: Session = Depends(get_db),
api_key: APIKey = Depends(get_current_api_key),
):
"""List all tickers with analysis count."""
# Query for ticker stats
results = (
db.query(
Analysis.ticker,
func.count(Analysis.id).label("analysis_count"),
func.max(Analysis.analysis_date).label("latest_date"),
)
.group_by(Analysis.ticker)
.order_by(Analysis.ticker)
.all()
)
return [
TickerInfo(
ticker=r.ticker,
analysis_count=r.analysis_count,
latest_date=r.latest_date,
)
for r in results
]
@router.get("/{ticker}/analyses", response_model=List[AnalysisSummary])
async def get_ticker_analyses(
ticker: str,
db: Session = Depends(get_db),
api_key: APIKey = Depends(get_current_api_key),
):
"""Get all analyses for a ticker."""
analyses = (
db.query(Analysis)
.filter(Analysis.ticker == ticker.upper())
.order_by(Analysis.created_at.desc())
.all()
)
return [
AnalysisSummary(
id=a.id,
ticker=a.ticker,
analysis_date=a.analysis_date,
status=a.status,
created_at=a.created_at,
completed_at=a.completed_at,
error_message=a.error_message,
)
for a in analyses
]
@router.get("/{ticker}/latest", response_model=AnalysisResponse)
async def get_ticker_latest_analysis(
ticker: str,
db: Session = Depends(get_db),
api_key: APIKey = Depends(get_current_api_key),
):
"""Get the most recent analysis for a ticker."""
analysis = (
db.query(Analysis)
.filter(Analysis.ticker == ticker.upper())
.order_by(Analysis.created_at.desc())
.first()
)
if not analysis:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No analyses found for ticker {ticker}",
)
# Use the existing get_analysis function
return await get_analysis(analysis.id, db, api_key)
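The `list_tickers` aggregate compiles down to a plain `GROUP BY`; the equivalent raw SQL, shown against an in-memory SQLite database with a simplified schema:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE analyses (id TEXT PRIMARY KEY, ticker TEXT, analysis_date TEXT)"
)
conn.executemany(
    "INSERT INTO analyses VALUES (?, ?, ?)",
    [
        ("a1", "AAPL", "2024-01-05"),
        ("a2", "AAPL", "2024-03-01"),
        ("a3", "NVDA", "2024-02-10"),
    ],
)
# Mirrors func.count(Analysis.id) and func.max(Analysis.analysis_date)
rows = conn.execute(
    "SELECT ticker, COUNT(id), MAX(analysis_date) "
    "FROM analyses GROUP BY ticker ORDER BY ticker"
).fetchall()
assert rows == [("AAPL", 2, "2024-03-01"), ("NVDA", 1, "2024-02-10")]
conn.close()
```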

api/example_client.py
"""Example client for Testing the Trading Agents API."""
import asyncio
import json
import time
from datetime import date
import httpx
import websockets
class TradingAgentsAPIClient:
"""Simple client for interacting with the Trading Agents API."""
def __init__(self, api_key: str, base_url: str = "http://localhost:8000"):
self.api_key = api_key
self.base_url = base_url
self.headers = {
"X-API-Key": api_key,
"Content-Type": "application/json",
}
async def create_analysis(
self,
ticker: str,
analysis_date: str | None = None,
selected_analysts: list | None = None,
research_depth: int = 1,
):
"""Create a new analysis."""
if analysis_date is None:
analysis_date = date.today().strftime("%Y-%m-%d")
if selected_analysts is None:
selected_analysts = ["market", "news"]
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/api/v1/analyses",
headers=self.headers,
json={
"ticker": ticker,
"analysis_date": analysis_date,
"selected_analysts": selected_analysts,
"research_depth": research_depth,
},
)
response.raise_for_status()
return response.json()
async def get_analysis(self, analysis_id: str):
"""Get full analysis details."""
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/api/v1/analyses/{analysis_id}",
headers=self.headers,
)
response.raise_for_status()
return response.json()
async def get_status(self, analysis_id: str):
"""Get analysis status."""
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/api/v1/analyses/{analysis_id}/status",
headers=self.headers,
)
response.raise_for_status()
return response.json()
async def list_analyses(self, ticker: str | None = None):
"""List all analyses."""
params = {}
if ticker:
params["ticker"] = ticker
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/api/v1/analyses",
headers=self.headers,
params=params,
)
response.raise_for_status()
return response.json()
async def monitor_via_websocket(self, analysis_id: str, duration: int = 300):
"""Monitor analysis via WebSocket."""
ws_url = f"ws://localhost:8000/api/v1/ws/analyses/{analysis_id}"
print(f"Connecting to WebSocket: {ws_url}")
try:
async with websockets.connect(ws_url) as websocket:
print("Connected! Waiting for updates...")
start_time = time.time()
while time.time() - start_time < duration:
try:
message = await asyncio.wait_for(
websocket.recv(), timeout=10.0
)
data = json.loads(message)
print(f"\n[Update] Status: {data['status']}")
print(f" Progress: {data['progress_percentage']}%")
if data.get('current_agent'):
print(f" Agent: {data['current_agent']}")
if data['status'] in ['completed', 'failed', 'cancelled']:
print("\nAnalysis finished!")
break
except asyncio.TimeoutError:
# Send ping to keep connection alive
await websocket.send("ping")
continue
except Exception as e:
print(f"WebSocket error: {e}")
async def main():
"""Example usage."""
# Replace with your actual API key
API_KEY = "your-api-key-here"
client = TradingAgentsAPIClient(API_KEY)
print("=" * 60)
print("Trading Agents API Client Example")
print("=" * 60)
# 1. Create an analysis
print("\n1. Creating analysis for AAPL...")
analysis = await client.create_analysis(
ticker="AAPL",
selected_analysts=["market", "news"],
research_depth=1,
)
analysis_id = analysis["id"]
print(f" Created: {analysis_id}")
print(f" Status: {analysis['status']}")
# 2. Monitor via WebSocket (run this in background or separately)
print("\n2. Monitoring via WebSocket...")
await client.monitor_via_websocket(analysis_id, duration=600)
# 3. Get final results
print("\n3. Getting final results...")
final = await client.get_analysis(analysis_id)
print(f" Status: {final['status']}")
print(f" Reports: {len(final['reports'])} available")
for report in final['reports']:
print(f"\n - {report['report_type']}:")
print(f" {report['content'][:200]}...")
# 4. List all analyses
print("\n4. Listing all AAPL analyses...")
all_analyses = await client.list_analyses(ticker="AAPL")
print(f" Found {len(all_analyses)} analyses")
print("\n" + "=" * 60)
print("Done!")
if __name__ == "__main__":
asyncio.run(main())

api/main.py
"""FastAPI Trading Agents API application."""
import logging
import sys
from contextlib import asynccontextmanager
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
# Load environment variables from .env file
load_dotenv()
from api.database import init_db
from api.endpoints import analyses, data, tickers
from api.state_manager import get_executor, shutdown_executor
from api.websockets import status
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.StreamHandler(sys.stdout),
],
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events."""
# Startup
logger.info("Initializing Trading Agents API...")
init_db()
get_executor()
logger.info("Trading Agents API started successfully")
yield
# Shutdown
logger.info("Shutting down Trading Agents API...")
shutdown_executor()
logger.info("Trading Agents API shutdown complete")
# Create FastAPI app
app = FastAPI(
title="Trading Agents API",
description="REST API for managing multi-agent trading analyses",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc",
lifespan=lifespan,
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, specify actual origins
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include routers
app.include_router(analyses.router)
app.include_router(tickers.router)
app.include_router(data.router)
app.include_router(status.router)
@app.get("/")
async def root():
"""Root endpoint."""
return {
"name": "Trading Agents API",
"version": "1.0.0",
"docs": "/docs",
"redoc": "/redoc",
}
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {"status": "healthy"}
if __name__ == "__main__":
import uvicorn
import os
port = int(os.getenv("API_PORT", "8002"))  # Defaults to 8002; override with API_PORT
uvicorn.run(
"api.main:app",
host="0.0.0.0",
port=port,
reload=True,
log_level="info",
)

api/models/__init__.py
"""Pydantic models for API requests and responses."""
from api.models.requests import CreateAnalysisRequest, UpdateAnalysisRequest
from api.models.responses import (
AnalysisResponse,
AnalysisStatusResponse,
AnalysisSummary,
ErrorResponse,
LogEntry,
ReportResponse,
TickerInfo,
CachedDataResponse,
)
__all__ = [
"CreateAnalysisRequest",
"UpdateAnalysisRequest",
"AnalysisResponse",
"AnalysisStatusResponse",
"AnalysisSummary",
"ErrorResponse",
"LogEntry",
"ReportResponse",
"TickerInfo",
"CachedDataResponse",
]

api/models/requests.py
"""Request models for API endpoints."""
from typing import List, Optional
from pydantic import BaseModel, Field, field_validator
class CreateAnalysisRequest(BaseModel):
"""Request to create a new analysis."""
ticker: str = Field(..., description="Ticker symbol to analyze")
analysis_date: str = Field(..., description="Analysis date in YYYY-MM-DD format")
selected_analysts: List[str] = Field(
default=["market", "news", "social", "fundamentals"],
description="List of analysts to run (market, news, social, fundamentals)",
)
research_depth: int = Field(
default=1,
ge=1,
le=5,
description="Research depth (1=shallow, 3=medium, 5=deep)",
)
llm_provider: Optional[str] = Field(
default=None,
description="LLM provider (openai, anthropic, google). Uses default if not specified.",
)
backend_url: Optional[str] = Field(
default=None,
description="Backend URL for LLM. Uses default if not specified.",
)
quick_think_llm: Optional[str] = Field(
default=None,
description="Model for quick thinking. Uses default if not specified.",
)
deep_think_llm: Optional[str] = Field(
default=None,
description="Model for deep thinking. Uses default if not specified.",
)
@field_validator("ticker")
@classmethod
def ticker_uppercase(cls, v: str) -> str:
"""Convert ticker to uppercase."""
return v.upper().strip()
@field_validator("selected_analysts")
@classmethod
def validate_analysts(cls, v: List[str]) -> List[str]:
"""Validate analyst selections."""
valid_analysts = {"market", "news", "social", "fundamentals"}
for analyst in v:
if analyst not in valid_analysts:
raise ValueError(
f"Invalid analyst: {analyst}. Must be one of {valid_analysts}"
)
if not v:
raise ValueError("At least one analyst must be selected")
return v
class UpdateAnalysisRequest(BaseModel):
"""Request to update analysis metadata."""
status: Optional[str] = Field(
default=None,
description="New status (pending, running, completed, failed, cancelled)",
)
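The two field validators run before a request ever reaches an endpoint; the same checks, expressed in plain Python for illustration:

```python
VALID_ANALYSTS = {"market", "news", "social", "fundamentals"}

def normalize_ticker(ticker: str) -> str:
    return ticker.upper().strip()

def validate_analysts(analysts: list) -> list:
    for analyst in analysts:
        if analyst not in VALID_ANALYSTS:
            raise ValueError(f"Invalid analyst: {analyst}")
    if not analysts:
        raise ValueError("At least one analyst must be selected")
    return analysts

assert normalize_ticker(" aapl ") == "AAPL"
assert validate_analysts(["market", "news"]) == ["market", "news"]
raised = False
try:
    validate_analysts(["quant"])  # not a valid analyst type
except ValueError as exc:
    raised = "quant" in str(exc)
assert raised
```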

api/models/responses.py
"""Response models for API endpoints."""
from datetime import datetime
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
class ErrorResponse(BaseModel):
"""Standard error response."""
error: str = Field(..., description="Error message")
detail: Optional[str] = Field(None, description="Additional error details")
class ReportResponse(BaseModel):
"""Single report section."""
report_type: str = Field(..., description="Type of report")
content: str = Field(..., description="Report content in markdown")
created_at: datetime = Field(..., description="When the report was created")
class LogEntry(BaseModel):
"""Single log entry."""
timestamp: datetime = Field(..., description="When the log was created")
log_type: str = Field(..., description="Type of log (Tool Call, Reasoning, System)")
content: str = Field(..., description="Log content")
class AnalysisStatusResponse(BaseModel):
"""Lightweight analysis status."""
id: str = Field(..., description="Analysis ID")
status: str = Field(..., description="Current status")
progress_percentage: int = Field(..., description="Progress percentage (0-100)")
current_agent: Optional[str] = Field(None, description="Currently active agent")
updated_at: datetime = Field(..., description="Last update timestamp")
class AnalysisSummary(BaseModel):
"""Summary view of an analysis."""
id: str = Field(..., description="Analysis ID")
ticker: str = Field(..., description="Ticker symbol")
analysis_date: str = Field(..., description="Analysis date")
status: str = Field(..., description="Current status")
created_at: datetime = Field(..., description="Creation timestamp")
completed_at: Optional[datetime] = Field(None, description="Completion timestamp")
error_message: Optional[str] = Field(None, description="Error message if failed")
class AnalysisResponse(BaseModel):
"""Full analysis details."""
id: str = Field(..., description="Analysis ID")
ticker: str = Field(..., description="Ticker symbol")
analysis_date: str = Field(..., description="Analysis date")
status: str = Field(..., description="Current status")
config: Dict = Field(..., description="Analysis configuration")
reports: List[ReportResponse] = Field(
default_factory=list, description="All report sections"
)
progress_percentage: int = Field(..., description="Progress percentage (0-100)")
current_agent: Optional[str] = Field(None, description="Currently active agent")
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
completed_at: Optional[datetime] = Field(None, description="Completion timestamp")
error_message: Optional[str] = Field(None, description="Error message if failed")
class TickerInfo(BaseModel):
"""Ticker information summary."""
ticker: str = Field(..., description="Ticker symbol")
analysis_count: int = Field(..., description="Number of analyses")
latest_date: Optional[str] = Field(None, description="Most recent analysis date")
class CachedDataResponse(BaseModel):
"""Cached market data response."""
ticker: str = Field(..., description="Ticker symbol")
date_range: Dict[str, str] = Field(..., description="Start and end dates")
data: List[Dict] = Field(..., description="OHLCV data records")
class CachedTickerInfo(BaseModel):
"""Information about cached ticker data."""
ticker: str = Field(..., description="Ticker symbol")
date_range: Dict[str, str] = Field(..., description="Start and end dates")
record_count: int = Field(..., description="Number of records")

api/state_manager.py
"""Analysis execution and state management."""
import json
import logging
import os
import threading
import traceback
import uuid
from concurrent.futures import Future, ThreadPoolExecutor
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
from sqlalchemy.orm import Session
from api.database import Analysis, AnalysisLog, AnalysisReport, SessionLocal
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph
logger = logging.getLogger(__name__)
class AnalysisExecutor:
"""Manages analysis execution with thread pool and state tracking."""
def __init__(self, max_workers: Optional[int] = None):
"""
Initialize the executor.
Args:
max_workers: Maximum concurrent analyses (default: from env or 4)
"""
if max_workers is None:
max_workers = int(os.getenv("MAX_CONCURRENT_ANALYSES", "4"))
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.active_analyses: Dict[str, Future] = {}
self.status_callbacks: Dict[str, List[Callable]] = {}
self._lock = threading.Lock()
def register_status_callback(self, analysis_id: str, callback: Callable):
"""Register a callback for status updates."""
with self._lock:
if analysis_id not in self.status_callbacks:
self.status_callbacks[analysis_id] = []
self.status_callbacks[analysis_id].append(callback)
def unregister_status_callbacks(self, analysis_id: str):
"""Remove all callbacks for an analysis."""
with self._lock:
if analysis_id in self.status_callbacks:
del self.status_callbacks[analysis_id]
def _notify_callbacks(self, analysis_id: str, status_data: Dict[str, Any]):
"""Notify all registered callbacks."""
with self._lock:
callbacks = self.status_callbacks.get(analysis_id, [])
for callback in callbacks:
try:
callback(status_data)
except Exception as e:
print(f"Error in status callback: {e}")
def start_analysis(
self,
analysis_id: str,
ticker: str,
analysis_date: str,
selected_analysts: List[str],
config: Dict[str, Any],
) -> str:
"""
Start a new analysis in the background.
Args:
analysis_id: Unique analysis ID
ticker: Ticker symbol
analysis_date: Analysis date
selected_analysts: List of analyst types
config: Trading agents configuration
Returns:
analysis_id
"""
future = self.executor.submit(
self._run_analysis,
analysis_id,
ticker,
analysis_date,
selected_analysts,
config,
)
with self._lock:
self.active_analyses[analysis_id] = future
# Cleanup when done
future.add_done_callback(lambda f: self._cleanup_analysis(analysis_id))
return analysis_id
def _cleanup_analysis(self, analysis_id: str):
"""Clean up after analysis completes."""
with self._lock:
if analysis_id in self.active_analyses:
del self.active_analyses[analysis_id]
self.unregister_status_callbacks(analysis_id)
def cancel_analysis(self, analysis_id: str) -> bool:
"""
Attempt to cancel a running analysis.
Returns:
True if cancelled, False if not found or already completed
"""
with self._lock:
future = self.active_analyses.get(analysis_id)
if future and not future.done():
cancelled = future.cancel()
if cancelled:
# Update database status
db = SessionLocal()
try:
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
if analysis:
analysis.status = "cancelled"
analysis.updated_at = datetime.utcnow()
db.commit()
finally:
db.close()
return cancelled
return False
def get_status(self, analysis_id: str) -> Optional[Dict[str, Any]]:
"""Get current status of an analysis."""
db = SessionLocal()
try:
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
if not analysis:
return None
return {
"id": analysis.id,
"status": analysis.status,
"progress_percentage": analysis.progress_percentage,
"current_agent": analysis.current_agent,
"updated_at": analysis.updated_at,
}
finally:
db.close()
def _update_status(
self,
analysis_id: str,
status: Optional[str] = None,
progress: Optional[int] = None,
current_agent: Optional[str] = None,
error_message: Optional[str] = None,
):
"""Update analysis status in database and notify callbacks."""
db = SessionLocal()
try:
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
if not analysis:
return
if status:
analysis.status = status
if progress is not None:
analysis.progress_percentage = progress
if current_agent:
analysis.current_agent = current_agent
if error_message:
analysis.error_message = error_message
analysis.updated_at = datetime.utcnow()
if status == "completed":
analysis.completed_at = datetime.utcnow()
analysis.progress_percentage = 100
db.commit()
db.refresh(analysis)
# Notify callbacks
status_data = {
"type": "status_update",
"analysis_id": analysis.id,
"status": analysis.status,
"progress_percentage": analysis.progress_percentage,
"current_agent": analysis.current_agent,
"timestamp": analysis.updated_at.isoformat(),
}
self._notify_callbacks(analysis_id, status_data)
finally:
db.close()
def _store_log(self, analysis_id: str, log_type: str, content: str):
"""Store a log entry."""
db = SessionLocal()
try:
log = AnalysisLog(
analysis_id=analysis_id,
log_type=log_type,
content=content,
timestamp=datetime.utcnow(),
)
db.add(log)
db.commit()
finally:
db.close()
def _store_report(self, analysis_id: str, report_type: str, content: str):
"""Store or update a report section."""
db = SessionLocal()
try:
# Check if report already exists
report = (
db.query(AnalysisReport)
.filter(
AnalysisReport.analysis_id == analysis_id,
AnalysisReport.report_type == report_type,
)
.first()
)
if report:
# Update existing
report.content = content
report.created_at = datetime.utcnow()
else:
# Create new
report = AnalysisReport(
analysis_id=analysis_id,
report_type=report_type,
content=content,
)
db.add(report)
db.commit()
finally:
db.close()
def _run_analysis(
self,
analysis_id: str,
ticker: str,
analysis_date: str,
selected_analysts: List[str],
config: Dict[str, Any],
):
"""Execute the analysis (runs in thread pool)."""
logger.info(f"Starting analysis {analysis_id} for {ticker} on {analysis_date}")
try:
# Update status to running
self._update_status(analysis_id, status="running", progress=0)
logger.info(f"Analysis {analysis_id}: Initializing trading graph...")
# Initialize the graph
graph = TradingAgentsGraph(
selected_analysts=selected_analysts,
config=config,
debug=False,
)
# Create initial state
init_agent_state = graph.propagator.create_initial_state(ticker, analysis_date)
init_agent_state["asset_class"] = config.get("asset_class", "equity")
args = graph.propagator.get_graph_args()
# Track agent progress
agent_order = self._get_agent_order(selected_analysts)
total_agents = len(agent_order)
current_agent_index = 0
# Stream the analysis
trace = []
for chunk in graph.graph.stream(init_agent_state, **args):
if len(chunk.get("messages", [])) == 0:
continue
# Process the chunk
last_message = chunk["messages"][-1]
# Extract content
if hasattr(last_message, "content"):
content = self._extract_content(last_message.content)
msg_type = "Reasoning"
# Store log
self._store_log(analysis_id, msg_type, content)
# Handle tool calls
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict):
tool_name = tool_call["name"]
tool_args = tool_call["args"]
else:
tool_name = tool_call.name
tool_args = tool_call.args
args_str = ", ".join(f"{k}={v}" for k, v in tool_args.items())
self._store_log(
analysis_id, "Tool Call", f"{tool_name}({args_str})"
)
# Check for completed reports
for report_type in [
"market_report",
"sentiment_report",
"news_report",
"fundamentals_report",
]:
if report_type in chunk and chunk[report_type]:
self._store_report(analysis_id, report_type, chunk[report_type])
current_agent_index += 1
progress = int((current_agent_index / total_agents) * 100)
agent_name = self._get_agent_name(report_type)
self._update_status(
analysis_id,
progress=min(progress, 95),
current_agent=agent_name,
)
# Check for investment debate state
if "investment_debate_state" in chunk and chunk["investment_debate_state"]:
debate_state = chunk["investment_debate_state"]
if "judge_decision" in debate_state and debate_state["judge_decision"]:
self._store_report(
analysis_id, "investment_plan", debate_state["judge_decision"]
)
current_agent_index += 1
progress = int((current_agent_index / total_agents) * 100)
self._update_status(
analysis_id,
progress=min(progress, 98),
current_agent="Research Manager",
)
# Check for trader plan
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
self._store_report(
analysis_id, "trader_investment_plan", chunk["trader_investment_plan"]
)
self._update_status(
analysis_id,
progress=99,
current_agent="Trader",
)
trace.append(chunk)
# Get final state
if trace:
final_state = trace[-1]
# Store final trade decision
if "final_trade_decision" in final_state:
decision = graph.process_signal(final_state["final_trade_decision"])
self._store_report(
analysis_id, "final_trade_decision", final_state["final_trade_decision"]
)
self._store_log(
analysis_id, "System", f"Final decision: {decision}"
)
# Mark as completed
logger.info(f"Analysis {analysis_id} completed successfully")
self._update_status(analysis_id, status="completed", progress=100)
except Exception as e:
error_msg = str(e)
error_trace = traceback.format_exc()
logger.error(f"Analysis {analysis_id} failed: {error_msg}")
logger.error(f"Traceback:\n{error_trace}")
self._update_status(
analysis_id, status="failed", error_message=error_msg
)
self._store_log(analysis_id, "System", f"Error: {error_msg}\n\nTraceback:\n{error_trace}")
def _get_agent_order(self, selected_analysts: List[str]) -> List[str]:
"""Get the order of agents for progress tracking."""
agents = selected_analysts.copy()
agents.extend(["bull_researcher", "bear_researcher", "research_manager", "trader", "risk", "portfolio"])
return agents
def _get_agent_name(self, report_type: str) -> str:
"""Get human-readable agent name from report type."""
mapping = {
"market_report": "Market Analyst",
"sentiment_report": "Social Analyst",
"news_report": "News Analyst",
"fundamentals_report": "Fundamentals Analyst",
}
return mapping.get(report_type, "Unknown")
def _extract_content(self, content: Any) -> str:
"""Extract string content from various message formats."""
if isinstance(content, str):
return content
elif isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif item.get("type") == "tool_use":
text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
else:
text_parts.append(str(item))
return " ".join(text_parts)
else:
return str(content)
def shutdown(self):
"""Shutdown the executor and cancel all running analyses."""
with self._lock:
# Cancel all active analyses
for analysis_id in list(self.active_analyses.keys()):
self.cancel_analysis(analysis_id)
# Shutdown executor
self.executor.shutdown(wait=True)
# Global executor instance
_executor: Optional[AnalysisExecutor] = None
def get_executor() -> AnalysisExecutor:
"""Get the global executor instance."""
global _executor
if _executor is None:
_executor = AnalysisExecutor()
return _executor
def shutdown_executor():
"""Shutdown the global executor."""
global _executor
if _executor:
_executor.shutdown()
_executor = None
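
The executor pairs a `ThreadPoolExecutor` with per-analysis status callbacks: work runs on a pool thread, and observers register plain callables that receive each status update. A minimal standalone sketch of that pattern (independent of the database layer; `MiniExecutor` and its names are hypothetical, for illustration only):

```python
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, List


class MiniExecutor:
    """Stripped-down sketch of the callback pattern used by AnalysisExecutor."""

    def __init__(self, max_workers: int = 2):
        self.pool = ThreadPoolExecutor(max_workers=max_workers)
        self.callbacks: Dict[str, List[Callable]] = {}
        self._lock = threading.Lock()

    def register(self, job_id: str, cb: Callable) -> None:
        with self._lock:
            self.callbacks.setdefault(job_id, []).append(cb)

    def _notify(self, job_id: str, payload: dict) -> None:
        # Snapshot under the lock, invoke outside it (same discipline as above)
        with self._lock:
            cbs = list(self.callbacks.get(job_id, []))
        for cb in cbs:
            cb(payload)

    def run(self, job_id: str, work: Callable[[], str]):
        def task():
            self._notify(job_id, {"status": "running"})
            result = work()
            self._notify(job_id, {"status": "completed", "result": result})
            return result
        return self.pool.submit(task)


# Usage: collect status updates for one job
events = []
ex = MiniExecutor()
ex.register("job-1", events.append)
future = ex.run("job-1", lambda: "BUY")
future.result()  # block until the job finishes
print([e["status"] for e in events])  # → ['running', 'completed']
```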


@ -0,0 +1,2 @@
"""WebSocket handlers."""

api/websockets/status.py Normal file

@ -0,0 +1,130 @@
"""WebSocket status streaming."""
import json
from typing import Dict, List
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from sqlalchemy.orm import Session
from api.database import Analysis, SessionLocal
from api.state_manager import get_executor
router = APIRouter()
class ConnectionManager:
"""Manages WebSocket connections for analyses."""
def __init__(self):
self.active_connections: Dict[str, List[WebSocket]] = {}
async def connect(self, analysis_id: str, websocket: WebSocket):
"""Accept and register a WebSocket connection."""
await websocket.accept()
if analysis_id not in self.active_connections:
self.active_connections[analysis_id] = []
self.active_connections[analysis_id].append(websocket)
# Register callback with executor
executor = get_executor()
executor.register_status_callback(analysis_id, self._create_callback(analysis_id))
def disconnect(self, analysis_id: str, websocket: WebSocket):
"""Remove a WebSocket connection."""
if analysis_id in self.active_connections:
if websocket in self.active_connections[analysis_id]:
self.active_connections[analysis_id].remove(websocket)
# Clean up if no more connections
if not self.active_connections[analysis_id]:
del self.active_connections[analysis_id]
def _create_callback(self, analysis_id: str):
"""Create a callback function for status updates."""
def callback(status_data: dict):
# Note: This is called from a thread, so we can't use async here
# The actual broadcasting happens via the websocket event loop
import asyncio
try:
loop = asyncio.get_event_loop()
if loop.is_running():
asyncio.create_task(self.broadcast(analysis_id, status_data))
else:
loop.run_until_complete(self.broadcast(analysis_id, status_data))
except:
# If no loop, we can't broadcast (connection will poll status instead)
pass
return callback
async def broadcast(self, analysis_id: str, message: dict):
"""Broadcast a message to all connections for an analysis."""
if analysis_id not in self.active_connections:
return
disconnected = []
for connection in self.active_connections[analysis_id]:
try:
await connection.send_json(message)
except:
disconnected.append(connection)
# Clean up disconnected clients
for connection in disconnected:
self.disconnect(analysis_id, connection)
# Global connection manager
manager = ConnectionManager()
@router.websocket("/api/v1/ws/analyses/{analysis_id}")
async def websocket_analysis_status(websocket: WebSocket, analysis_id: str):
"""WebSocket endpoint for real-time analysis status updates."""
# Verify analysis exists
db = SessionLocal()
try:
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
if not analysis:
await websocket.close(code=1008, reason="Analysis not found")
return
finally:
db.close()
# Connect
await manager.connect(analysis_id, websocket)
try:
# Send initial status
db = SessionLocal()
try:
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
if analysis:
initial_status = {
"type": "status_update",
"analysis_id": analysis.id,
"status": analysis.status,
"progress_percentage": analysis.progress_percentage,
"current_agent": analysis.current_agent,
"timestamp": analysis.updated_at.isoformat(),
}
await websocket.send_json(initial_status)
finally:
db.close()
# Keep connection alive and handle messages
while True:
# Wait for any messages from client (like ping)
data = await websocket.receive_text()
# Echo back if it's a ping
if data == "ping":
await websocket.send_text("pong")
except WebSocketDisconnect:
manager.disconnect(analysis_id, websocket)
except Exception as e:
print(f"WebSocket error: {e}")
manager.disconnect(analysis_id, websocket)
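
Every frame the endpoint sends is a JSON `status_update` shaped like the initial status above, so a client mainly needs to parse each frame and decide when the stream is finished. A small hedged helper (names like `is_terminal` are hypothetical, not part of the API):

```python
import json

# Terminal states used by the analysis status fields above
TERMINAL_STATUSES = {"completed", "failed", "cancelled"}


def is_terminal(frame: str) -> bool:
    """Return True when a status_update frame reports a terminal state."""
    msg = json.loads(frame)
    if msg.get("type") != "status_update":
        return False
    return msg.get("status") in TERMINAL_STATUSES


# Example frame matching the shape sent on connect
frame = json.dumps({
    "type": "status_update",
    "analysis_id": "abc123",
    "status": "completed",
    "progress_percentage": 100,
    "current_agent": None,
    "timestamp": "2025-10-21T18:58:04",
})
print(is_terminal(frame))  # → True
```

A client loop would call this on each received frame and close the socket once it returns True, sending `"ping"` periodically to keep the connection alive.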

api_database.db Normal file

Binary file not shown.


@ -24,3 +24,12 @@ rich
questionary
langchain_anthropic
langchain-google-genai
# API Dependencies
fastapi==0.109.0
uvicorn[standard]==0.27.0
sqlalchemy==2.0.25
passlib[bcrypt]>=1.7.4
bcrypt>=4.0.0
python-multipart==0.0.6
websockets==12.0

run_api.sh Executable file

@ -0,0 +1,17 @@
#!/bin/bash
# Startup script for Trading Agents API
echo "Starting Trading Agents API..."
echo ""
echo "Make sure you have:"
echo "1. Installed dependencies: pip install -r requirements.txt"
echo "2. Initialized database: python -m api.cli_admin init-database"
echo "3. Created an API key: python -m api.cli_admin create-key 'My Key'"
echo ""
echo "API will be available at: http://localhost:8001"
echo "API Documentation: http://localhost:8001/docs"
echo ""
cd "$(dirname "$0")"
python -m api.main