diff --git a/README.md b/README.md index c246c384..a3685271 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,7 @@ An interface will appear showing results as they load, letting you track the age ### Implementation Details -We built Litadel with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o4-mini` and `gpt-4.1-mini` to save on costs as our framework makes **lots of** API calls. +We built Litadel with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o1-mini` and `gpt-4o-mini` to save on costs as our framework makes **lots of** API calls. ### Python Usage @@ -157,8 +157,8 @@ from tradingagents.default_config import DEFAULT_CONFIG # Create a custom config config = DEFAULT_CONFIG.copy() -config["deep_think_llm"] = "gpt-4.1-nano" # Use a different model -config["quick_think_llm"] = "gpt-4.1-nano" # Use a different model +config["deep_think_llm"] = "o1-mini" # Use a different model +config["quick_think_llm"] = "gpt-4o-mini" # Use a different model config["max_debate_rounds"] = 1 # Increase debate rounds # Configure data vendors (default uses yfinance and Alpha Vantage) diff --git a/api/API_IMPLEMENTATION_SUMMARY.md b/api/API_IMPLEMENTATION_SUMMARY.md new file mode 100644 index 00000000..0c4ee0d9 --- /dev/null +++ b/api/API_IMPLEMENTATION_SUMMARY.md @@ -0,0 +1,286 @@ +# FastAPI Trading Agents API - Implementation Summary + +## โœ… Implementation Complete + +The FastAPI Trading Agents API has been successfully implemented with all planned features. + +## ๐Ÿ“ Project Structure + +``` +TradingAgents/ +โ”œโ”€โ”€ api/ +โ”‚ โ”œโ”€โ”€ __init__.py โœ… Created +โ”‚ โ”œโ”€โ”€ main.py โœ… FastAPI application +โ”‚ โ”œโ”€โ”€ database.py โœ… SQLAlchemy models +โ”‚ โ”œโ”€โ”€ auth.py โœ… API key authentication +โ”‚ โ”œโ”€โ”€ state_manager.py โœ… Analysis execution manager +โ”‚ โ”œโ”€โ”€ cli_admin.py โœ… Admin CLI tool +โ”‚ โ”œโ”€โ”€ example_client.py โœ… Example Python client +โ”‚ โ”œโ”€โ”€ README.md โœ… Full documentation +โ”‚ โ”œโ”€โ”€ models/ +โ”‚ โ”‚ โ”œโ”€โ”€ __init__.py โœ… Exports +โ”‚ โ”‚ โ”œโ”€โ”€ requests.py โœ… Pydantic request models +โ”‚ โ”‚ โ””โ”€โ”€ responses.py โœ… Pydantic response models +โ”‚ โ”œโ”€โ”€ endpoints/ +โ”‚ โ”‚ โ”œโ”€โ”€ __init__.py โœ… Created +โ”‚ โ”‚ โ”œโ”€โ”€ analyses.py โœ… Analysis CRUD endpoints +โ”‚ โ”‚ โ”œโ”€โ”€ tickers.py โœ… Ticker history endpoints +โ”‚ โ”‚ โ””โ”€โ”€ data.py โœ… Cached data endpoints +โ”‚ โ””โ”€โ”€ websockets/ +โ”‚ โ”œโ”€โ”€ __init__.py โœ… Created +โ”‚ โ””โ”€โ”€ status.py โœ… Real-time status updates +โ”œโ”€โ”€ requirements-api.txt โœ… API dependencies +โ”œโ”€โ”€ run_api.sh โœ… Startup script +โ”œโ”€โ”€ API_QUICKSTART.md โœ… Quick start guide +โ””โ”€โ”€ API_IMPLEMENTATION_SUMMARY.md โœ… This file +``` + +## ๐ŸŽฏ Features Implemented + +### Core API Features +- โœ… REST API with FastAPI +- โœ… WebSocket support for real-time updates +- โœ… SQLite database with SQLAlchemy ORM +- โœ… API key authentication with bcrypt hashing +- โœ… Parallel analysis execution (ThreadPoolExecutor) +- โœ… CORS middleware for frontend integration +- โœ… Auto-generated OpenAPI documentation + +### Database Schema +- โœ… `analyses` - Analysis metadata and status +- โœ… `analysis_logs` - Tool calls and reasoning logs +- โœ… `analysis_reports` - Generated reports by type +- โœ… `api_keys` - Hashed authentication keys + +### REST Endpoints + +#### Analyses (`/api/v1/analyses`) +- โœ… `POST /` - Create and start new analysis +- โœ… `GET /` - List all analyses with filtering +- โœ… `GET /{id}` - Get full analysis details +- โœ… `GET /{id}/status` - Get current status +- โœ… `GET /{id}/reports` - Get all reports +- โœ… `GET /{id}/reports/{type}` - Get specific report +- โœ… `GET /{id}/logs` - Get execution logs +- โœ… `DELETE /{id}` - Cancel/delete analysis + +#### Tickers (`/api/v1/tickers`) +- โœ… `GET /` - List all tickers with counts +- โœ… `GET /{ticker}/analyses` - Get analyses for ticker +- โœ… `GET /{ticker}/latest` - Get latest analysis + +#### Data (`/api/v1/data`) +- โœ… `GET /cache` - List cached tickers +- โœ… `GET /cache/{ticker}` - Get cached market data + +#### WebSocket (`/api/v1/ws`) +- โœ… `WS /analyses/{id}` - Real-time status streaming + +### State Management +- โœ… `AnalysisExecutor` class for managing parallel execution +- โœ… ThreadPoolExecutor with configurable max workers +- โœ… Status callbacks for WebSocket broadcasting +- โœ… Real-time progress tracking +- โœ… Graceful shutdown and cleanup + +### Admin Tools +- โœ… `cli_admin.py` - Command-line admin interface + - โœ… `create-key` - Generate new API key + - โœ… `list-keys` - Show all keys + - โœ… `revoke-key` - Deactivate key + - โœ… `activate-key` - Reactivate key + - โœ… `init-database` - Initialize database + +### Documentation +- โœ… `API_QUICKSTART.md` - Quick start guide +- โœ… `api/README.md` - Full API documentation +- โœ… `api/example_client.py` - Working example client +- โœ… Auto-generated Swagger UI at `/docs` +- โœ… Auto-generated ReDoc at `/redoc` + +## ๐Ÿ”ง Configuration + +### Environment Variables +- `API_DATABASE_URL` - Database connection (default: SQLite) +- `MAX_CONCURRENT_ANALYSES` - Concurrent analysis limit (default: 4) +- Standard TradingAgents env vars (OPENAI_API_KEY, etc.) + +### Parallel Execution +- Default: 4 concurrent analyses +- Configurable via environment variable +- Thread-safe database operations +- Independent graph instances per analysis + +## ๐Ÿ“Š Data Flow + +1. **Client** sends POST request to create analysis +2. **API** creates database record, returns analysis_id +3. **Executor** starts analysis in background thread +4. **Graph** streams chunks during execution +5. **State Manager** captures logs, reports, tool calls +6. **Database** stores all data in real-time +7. **WebSocket** broadcasts status updates +8. **Client** polls or streams for results + +## ๐Ÿ” Security + +- โœ… API key authentication (bcrypt hashed) +- โœ… Secure password hashing with passlib +- โœ… CORS middleware (configurable origins) +- โœ… Input validation with Pydantic +- โœ… SQL injection prevention (SQLAlchemy ORM) + +## ๐Ÿงช Testing + +### Manual Testing +```bash +# 1. Initialize +python -m api.cli_admin init-database +python -m api.cli_admin create-key "Test Key" + +# 2. Start API +python -m api.main + +# 3. Test endpoints +curl http://localhost:8000/health +curl -X POST http://localhost:8000/api/v1/analyses \ + -H "X-API-Key: YOUR_KEY" \ + -H "Content-Type: application/json" \ + -d '{"ticker": "AAPL", "analysis_date": "2025-10-21", "selected_analysts": ["market", "news"], "research_depth": 1}' + +# 4. Run example client +python -m api.example_client +``` + +### Automated Testing +Recommended test suite (not included in this implementation): +- Unit tests for each endpoint +- Integration tests for analysis workflow +- WebSocket connection tests +- Concurrent execution tests +- Authentication tests + +## ๐Ÿ“ˆ Performance + +- **Concurrency**: 4 parallel analyses by default (configurable) +- **Database**: SQLite for development, PostgreSQL recommended for production +- **Threading**: Thread-safe operations throughout +- **WebSocket**: Efficient real-time updates +- **Caching**: Leverages existing TradingAgents data cache + +## ๐Ÿš€ Deployment + +### Development +```bash +python -m api.main +# or +./run_api.sh +``` + +### Production Checklist +- [ ] Switch to PostgreSQL +- [ ] Configure specific CORS origins +- [ ] Enable HTTPS/WSS (reverse proxy) +- [ ] Add rate limiting +- [ ] Set up monitoring/logging +- [ ] Configure database backups +- [ ] Set environment-specific configs + +## ๐ŸŽ“ Usage Examples + +### cURL +```bash +# Create analysis +curl -X POST http://localhost:8000/api/v1/analyses \ + -H "X-API-Key: YOUR_KEY" \ + -H "Content-Type: application/json" \ + -d '{"ticker": "AAPL", "analysis_date": "2025-10-21"}' + +# Get status +curl http://localhost:8000/api/v1/analyses/{id}/status \ + -H "X-API-Key: YOUR_KEY" +``` + +### Python +```python +from api.example_client import TradingAgentsAPIClient + +client = TradingAgentsAPIClient("YOUR_API_KEY") +analysis = await client.create_analysis("AAPL") +await client.monitor_via_websocket(analysis["id"]) +``` + +### JavaScript +```javascript +const response = await fetch('http://localhost:8000/api/v1/analyses', { + method: 'POST', + headers: { + 'X-API-Key': 'YOUR_KEY', + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + ticker: 'AAPL', + analysis_date: '2025-10-21', + }), +}); +const analysis = await response.json(); +``` + +## โœจ Key Achievements + +1. **Separation of Concerns**: API in separate `api/` directory +2. **Persistent Storage**: SQLite with easy PostgreSQL migration +3. **Real-time Updates**: WebSocket streaming of status +4. **Parallel Processing**: ThreadPoolExecutor with configurable concurrency +5. **Complete Logging**: All tool calls, reasoning, and reports saved +6. **Security**: API key authentication with bcrypt +7. **Documentation**: Comprehensive docs with examples +8. **Admin Tools**: CLI for key management +9. **Easy Setup**: Quick start in < 5 minutes + +## ๐Ÿ”„ Integration Points + +### With Existing TradingAgents +- โœ… Uses `TradingAgentsGraph` for execution +- โœ… Leverages `DEFAULT_CONFIG` for configuration +- โœ… Integrates with asset detection +- โœ… Uses existing data cache +- โœ… Compatible with all LLM providers + +### For Frontend Development +- โœ… RESTful API design +- โœ… WebSocket for real-time updates +- โœ… CORS enabled for cross-origin requests +- โœ… Comprehensive error responses +- โœ… Pagination support +- โœ… Filtering and search + +## ๐Ÿ“ Next Steps (Optional Enhancements) + +Future improvements could include: +- User authentication and multi-tenancy +- Rate limiting per API key +- Analysis templates +- Scheduled/recurring analyses +- Email/webhook notifications +- Analysis comparison tools +- Performance metrics dashboard +- Analysis export (PDF, CSV) +- Bulk operations API +- GraphQL endpoint + +## ๐ŸŽ‰ Conclusion + +The FastAPI Trading Agents API is fully implemented and ready for use. It provides a robust, scalable foundation for frontend applications to interact with the TradingAgents multi-agent system. + +### Getting Started +1. Read `API_QUICKSTART.md` for setup instructions +2. Check `api/README.md` for full documentation +3. Run `api/example_client.py` to see it in action +4. Start building your frontend! + +### Support +- API Documentation: http://localhost:8000/docs +- Project README: `api/README.md` +- Quick Start: `API_QUICKSTART.md` + diff --git a/api/API_QUICKSTART.md b/api/API_QUICKSTART.md new file mode 100644 index 00000000..a02ad844 --- /dev/null +++ b/api/API_QUICKSTART.md @@ -0,0 +1,300 @@ +# Trading Agents API - Quick Start Guide + +This guide will get your FastAPI Trading Agents API up and running in minutes. + +## Prerequisites + +- Python 3.10+ +- TradingAgents installed and configured +- Required environment variables (OPENAI_API_KEY, ALPHA_VANTAGE_API_KEY, etc.) + +## Setup (5 minutes) + +### 1. Install API Dependencies + +```bash +cd TradingAgents +pip install -r requirements-api.txt +``` + +### 2. Initialize Database + +```bash +python -m api.cli_admin init-database +``` + +### 3. Create Your First API Key + +```bash +python -m api.cli_admin create-key "Development Key" +``` + +**IMPORTANT**: Save the API key that's displayed. You won't be able to see it again! + +Example output: +``` +โœ“ API Key created successfully! + +Name: Development Key +Created: 2025-10-21 14:30:00 + +API Key (save this, it won't be shown again): +xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx + +Use this key in the X-API-Key header for all API requests. +``` + +### 4. Start the API Server + +```bash +python -m api.main +``` + +Or use the startup script: +```bash +./run_api.sh +``` + +The API will start at: `http://localhost:8000` + +## Test Your API (2 minutes) + +### View API Documentation + +Open your browser to: +- **Swagger UI**: http://localhost:8000/docs +- **ReDoc**: http://localhost:8000/redoc + +### Create Your First Analysis + +Replace `YOUR_API_KEY` with the key you created: + +```bash +curl -X POST "http://localhost:8000/api/v1/analyses" \ + -H "X-API-Key: YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "ticker": "AAPL", + "analysis_date": "2025-10-21", + "selected_analysts": ["market", "news"], + "research_depth": 1 + }' +``` + +You'll get a response with an `analysis_id`: +```json +{ + "id": "550e8400-e29b-41d4-a716-446655440000", + "ticker": "AAPL", + "status": "pending", + ... +} +``` + +### Check Status + +```bash +curl "http://localhost:8000/api/v1/analyses/YOUR_ANALYSIS_ID/status" \ + -H "X-API-Key: YOUR_API_KEY" +``` + +### Get Results + +```bash +curl "http://localhost:8000/api/v1/analyses/YOUR_ANALYSIS_ID" \ + -H "X-API-Key: YOUR_API_KEY" +``` + +## Using the Python Client + +### Install Additional Dependencies + +```bash +pip install httpx websockets +``` + +### Run Example Client + +Edit `api/example_client.py` and replace `YOUR_API_KEY`, then: + +```bash +python -m api.example_client +``` + +This will: +1. Create an analysis for AAPL +2. Monitor it via WebSocket +3. Display the results + +## Next Steps + +### Multiple Parallel Analyses + +Create multiple analyses at once - they'll run in parallel: + +```bash +# Start 3 analyses +for ticker in AAPL MSFT GOOGL; do + curl -X POST "http://localhost:8000/api/v1/analyses" \ + -H "X-API-Key: YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d "{\"ticker\": \"$ticker\", \"analysis_date\": \"2025-10-21\", \"selected_analysts\": [\"market\", \"news\"], \"research_depth\": 1}" +done +``` + +### List All Analyses + +```bash +curl "http://localhost:8000/api/v1/analyses" \ + -H "X-API-Key: YOUR_API_KEY" +``` + +### Get Analyses for Specific Ticker + +```bash +curl "http://localhost:8000/api/v1/tickers/AAPL/analyses" \ + -H "X-API-Key: YOUR_API_KEY" +``` + +### View Cached Market Data + +```bash +# List all cached tickers +curl "http://localhost:8000/api/v1/data/cache" \ + -H "X-API-Key: YOUR_API_KEY" + +# Get data for specific ticker +curl "http://localhost:8000/api/v1/data/cache/AAPL" \ + -H "X-API-Key: YOUR_API_KEY" +``` + +## WebSocket Real-Time Monitoring + +### JavaScript Example + +```javascript +const ws = new WebSocket('ws://localhost:8000/api/v1/ws/analyses/YOUR_ANALYSIS_ID'); + +ws.onmessage = (event) => { + const update = JSON.parse(event.data); + console.log(`Status: ${update.status}, Progress: ${update.progress_percentage}%`); + + if (update.status === 'completed') { + console.log('Analysis finished!'); + ws.close(); + } +}; + +ws.onerror = (error) => console.error('WebSocket error:', error); +``` + +### Python Example + +```python +import asyncio +import websockets +import json + +async def monitor_analysis(analysis_id, api_key): + uri = f"ws://localhost:8000/api/v1/ws/analyses/{analysis_id}" + + async with websockets.connect(uri) as websocket: + while True: + message = await websocket.recv() + data = json.loads(message) + + print(f"Status: {data['status']}, Progress: {data['progress_percentage']}%") + + if data['status'] in ['completed', 'failed', 'cancelled']: + break + +asyncio.run(monitor_analysis('YOUR_ANALYSIS_ID', 'YOUR_API_KEY')) +``` + +## API Key Management + +### List All Keys + +```bash +python -m api.cli_admin list-keys +``` + +### Create Additional Key + +```bash +python -m api.cli_admin create-key "Frontend App" +``` + +### Revoke a Key + +```bash +python -m api.cli_admin revoke-key 1 +``` + +### Activate a Key + +```bash +python -m api.cli_admin activate-key 1 +``` + +## Configuration + +### Environment Variables + +Set these before starting the API: + +```bash +# Maximum concurrent analyses (default: 4) +export MAX_CONCURRENT_ANALYSES=8 + +# Database URL (default: SQLite in current directory) +export API_DATABASE_URL="sqlite:///./api_database.db" + +# Or use PostgreSQL for production +export API_DATABASE_URL="postgresql://user:pass@localhost/trading_agents" + +# Standard TradingAgents config +export OPENAI_API_KEY="your-key" +export ALPHA_VANTAGE_API_KEY="your-key" +``` + +## Troubleshooting + +### "Invalid or inactive API key" +- Make sure you're using the exact key that was displayed when you created it +- Check that the key hasn't been revoked: `python -m api.cli_admin list-keys` + +### "Analysis not found" +- The analysis ID must be exactly as returned from the create endpoint +- Check available analyses: `curl "http://localhost:8000/api/v1/analyses" -H "X-API-Key: YOUR_KEY"` + +### Database locked +- SQLite has limited concurrency +- Reduce `MAX_CONCURRENT_ANALYSES` to 2-3 +- Or switch to PostgreSQL + +### Import errors +- Make sure you're in the TradingAgents directory +- Run with: `python -m api.main` (not `python api/main.py`) + +## Full API Reference + +See `api/README.md` for complete documentation of all endpoints and features. + +## Production Deployment + +For production use, see the deployment section in `api/README.md`. Key points: + +1. Use PostgreSQL instead of SQLite +2. Configure CORS for your frontend domain +3. Use HTTPS/WSS with reverse proxy +4. Add rate limiting +5. Set up monitoring and logging + +## Support + +For issues or questions: +- Check the full documentation: `api/README.md` +- Review the main TradingAgents README +- Check the interactive docs: http://localhost:8000/docs + diff --git a/api/README.md b/api/README.md new file mode 100644 index 00000000..0b07dd0c --- /dev/null +++ b/api/README.md @@ -0,0 +1,257 @@ +# Trading Agents REST API + +FastAPI-based REST API for managing multi-agent trading analyses with real-time WebSocket support. + +## Features + +- **REST API** for creating, monitoring, and managing trading analyses +- **WebSocket** support for real-time status updates +- **Parallel execution** of multiple analyses (configurable concurrency) +- **SQLite database** for persistent storage +- **API key authentication** for secure access +- **Complete analysis history** with logs and reports + +## Installation + +1. Install API dependencies: +```bash +cd TradingAgents +pip install -r requirements-api.txt +``` + +2. Initialize the database and create an API key: +```bash +python -m api.cli_admin init-database +python -m api.cli_admin create-key "My First Key" +``` + +Save the generated API key - you'll need it for all API requests. + +## Running the API + +Start the API server: +```bash +python -m api.main +``` + +Or with uvicorn directly: +```bash +uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload +``` + +The API will be available at `http://localhost:8000` + +## API Documentation + +Once the server is running, visit: +- **Swagger UI**: http://localhost:8000/docs +- **ReDoc**: http://localhost:8000/redoc + +## Quick Start + +### 1. Create an Analysis + +```bash +curl -X POST "http://localhost:8000/api/v1/analyses" \ + -H "X-API-Key: YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "ticker": "AAPL", + "analysis_date": "2025-10-21", + "selected_analysts": ["market", "news", "social", "fundamentals"], + "research_depth": 1 + }' +``` + +Response: +```json +{ + "id": "uuid-here", + "ticker": "AAPL", + "analysis_date": "2025-10-21", + "status": "pending", + "progress_percentage": 0, + ... +} +``` + +### 2. Monitor via WebSocket + +```javascript +const ws = new WebSocket('ws://localhost:8000/api/v1/ws/analyses/{analysis_id}'); + +ws.onmessage = (event) => { + const update = JSON.parse(event.data); + console.log(`Status: ${update.status}, Progress: ${update.progress_percentage}%`); +}; +``` + +### 3. Get Analysis Results + +```bash +curl "http://localhost:8000/api/v1/analyses/{analysis_id}" \ + -H "X-API-Key: YOUR_API_KEY" +``` + +## API Endpoints + +### Analyses + +- `POST /api/v1/analyses` - Create and start new analysis +- `GET /api/v1/analyses` - List all analyses (with filtering) +- `GET /api/v1/analyses/{id}` - Get full analysis details +- `GET /api/v1/analyses/{id}/status` - Get current status +- `GET /api/v1/analyses/{id}/reports` - Get all reports +- `GET /api/v1/analyses/{id}/reports/{type}` - Get specific report +- `GET /api/v1/analyses/{id}/logs` - Get execution logs +- `DELETE /api/v1/analyses/{id}` - Cancel/delete analysis + +### Tickers + +- `GET /api/v1/tickers` - List all tickers with analysis counts +- `GET /api/v1/tickers/{ticker}/analyses` - Get all analyses for ticker +- `GET /api/v1/tickers/{ticker}/latest` - Get latest analysis for ticker + +### Data + +- `GET /api/v1/data/cache` - List cached ticker data +- `GET /api/v1/data/cache/{ticker}` - Get cached market data + +### WebSocket + +- `WS /api/v1/ws/analyses/{id}` - Real-time status updates + +## API Key Management + +### Create API Key +```bash +python -m api.cli_admin create-key "Description" +``` + +### List API Keys +```bash +python -m api.cli_admin list-keys +``` + +### Revoke API Key +```bash +python -m api.cli_admin revoke-key +``` + +### Activate API Key +```bash +python -m api.cli_admin activate-key +``` + +## Configuration + +### Environment Variables + +- `API_DATABASE_URL` - Database connection string (default: `sqlite:///./api_database.db`) +- `MAX_CONCURRENT_ANALYSES` - Maximum parallel analyses (default: `4`) +- Standard TradingAgents config (LLM providers, API keys, etc.) + +## Architecture + +``` +api/ +โ”œโ”€โ”€ main.py # FastAPI application +โ”œโ”€โ”€ database.py # SQLAlchemy models +โ”œโ”€โ”€ auth.py # API key authentication +โ”œโ”€โ”€ state_manager.py # Analysis execution manager +โ”œโ”€โ”€ models/ # Pydantic request/response models +โ”œโ”€โ”€ endpoints/ # REST endpoint handlers +โ””โ”€โ”€ websockets/ # WebSocket handlers +``` + +## Database Schema + +- **analyses** - Analysis metadata and status +- **analysis_logs** - Execution logs (tool calls, reasoning) +- **analysis_reports** - Generated reports (by type) +- **api_keys** - Authentication keys + +## Parallel Execution + +The API uses a `ThreadPoolExecutor` to run multiple analyses concurrently: + +- Default: 4 concurrent analyses +- Configurable via `MAX_CONCURRENT_ANALYSES` env var +- Each analysis runs in its own thread +- Database writes are thread-safe +- Cached data is read-only (thread-safe) + +## Example Frontend Integration + +```javascript +class TradingAnalysisClient { + constructor(apiKey, baseURL = 'http://localhost:8000') { + this.apiKey = apiKey; + this.baseURL = baseURL; + } + + async createAnalysis(ticker, date, analysts = ['market', 'news']) { + const response = await fetch(`${this.baseURL}/api/v1/analyses`, { + method: 'POST', + headers: { + 'X-API-Key': this.apiKey, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + ticker, + analysis_date: date, + selected_analysts: analysts, + research_depth: 1, + }), + }); + return await response.json(); + } + + connectWebSocket(analysisId, onUpdate) { + const ws = new WebSocket(`ws://localhost:8000/api/v1/ws/analyses/${analysisId}`); + ws.onmessage = (event) => onUpdate(JSON.parse(event.data)); + return ws; + } + + async getAnalysis(analysisId) { + const response = await fetch( + `${this.baseURL}/api/v1/analyses/${analysisId}`, + { headers: { 'X-API-Key': this.apiKey } } + ); + return await response.json(); + } +} +``` + +## Troubleshooting + +### Database locked error +SQLite has limited concurrent write support. If you get database locked errors: +- Reduce `MAX_CONCURRENT_ANALYSES` +- Or switch to PostgreSQL by changing `API_DATABASE_URL` + +### Import errors +Make sure you're running from the TradingAgents root directory: +```bash +cd TradingAgents +python -m api.main +``` + +### WebSocket connection refused +Check that the server is running and CORS is properly configured for your frontend origin. + +## Production Deployment + +For production use: + +1. **Use PostgreSQL** instead of SQLite +2. **Secure CORS** - Set specific allowed origins in `main.py` +3. **HTTPS/WSS** - Use reverse proxy (nginx) with SSL +4. **Monitoring** - Add logging, metrics, and health checks +5. **Rate limiting** - Add rate limiting middleware +6. **Backup** - Regular database backups + +## License + +Same as TradingAgents/Litadel main project. + diff --git a/api/START_API.md b/api/START_API.md new file mode 100644 index 00000000..475e354d --- /dev/null +++ b/api/START_API.md @@ -0,0 +1,114 @@ +# Quick Start - Trading Agents API + +## 1. Install Dependencies (if not already done) + +```bash +cd TradingAgents +pip install -r requirements.txt +``` + +## 2. Initialize Database & Create API Key + +```bash +# Initialize the database +python -m api.cli_admin init-database + +# Create your first API key +python -m api.cli_admin create-key "My Development Key" +``` + +**IMPORTANT**: Save the API key that's displayed! You'll need it for all requests. + +## 3. Start the API + +```bash +python -m api.main +``` + +Or use the startup script: +```bash +./run_api.sh +``` + +The API will start at: **http://localhost:8001** + +## 4. Test It + +Open your browser to see the interactive API documentation: +- **Swagger UI**: http://localhost:8001/docs +- **ReDoc**: http://localhost:8001/redoc + +## 5. Create Your First Analysis + +Using curl (replace `YOUR_API_KEY` with the key from step 2): + +```bash +curl -X POST "http://localhost:8001/api/v1/analyses" \ + -H "X-API-Key: YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "ticker": "AAPL", + "analysis_date": "2025-10-21", + "selected_analysts": ["market", "news"], + "research_depth": 1 + }' +``` + +You'll get back an `analysis_id`. Use it to check status: + +```bash +curl "http://localhost:8001/api/v1/analyses/YOUR_ANALYSIS_ID/status" \ + -H "X-API-Key: YOUR_API_KEY" +``` + +## Configuration (Optional) + +Set environment variables before starting: + +```bash +# Maximum concurrent analyses (default: 4) +export MAX_CONCURRENT_ANALYSES=8 + +# Your LLM API keys (if not already set) +export OPENAI_API_KEY="your-key" +export ALPHA_VANTAGE_API_KEY="your-key" + +# Then start the API +python -m api.main +``` + +## Common Commands + +```bash +# List all API keys +python -m api.cli_admin list-keys + +# Create a new API key +python -m api.cli_admin create-key "Frontend App" + +# Revoke an API key (use ID from list-keys) +python -m api.cli_admin revoke-key 1 +``` + +## Full Documentation + +- Quick Start: `API_QUICKSTART.md` +- Full API Docs: `api/README.md` +- Implementation Details: `API_IMPLEMENTATION_SUMMARY.md` + +## Troubleshooting + +**"Invalid or inactive API key"** +- Make sure you're using the exact key from step 2 +- Check: `python -m api.cli_admin list-keys` + +**Import errors** +- Make sure you're in the TradingAgents directory +- Use: `python -m api.main` (not `python api/main.py`) + +**Port 8001 already in use** +- Change port: `API_PORT=8002 python -m api.main` +- Or: `python -m uvicorn api.main:app --port 8002` + +That's it! Your API is ready to use. ๐Ÿš€ + diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 00000000..df5a80d9 --- /dev/null +++ b/api/__init__.py @@ -0,0 +1,2 @@ +"""FastAPI Trading Agents API.""" + diff --git a/api/auth.py b/api/auth.py new file mode 100644 index 00000000..f06cfb88 --- /dev/null +++ b/api/auth.py @@ -0,0 +1,101 @@ +"""API key authentication.""" + +import secrets +import warnings +from typing import Optional + +from fastapi import Depends, HTTPException, Security, status +from fastapi.security import APIKeyHeader + +# Suppress the known passlib/bcrypt compatibility warning +with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message=".*trapped.*") + from passlib.context import CryptContext + +from sqlalchemy.orm import Session + +from api.database import APIKey, get_db + +# Password hashing context +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +# API key header scheme +api_key_header = APIKeyHeader(name="X-API-Key", auto_error=True) + + +def hash_api_key(api_key: str) -> str: + """Hash an API key.""" + return pwd_context.hash(api_key) + + +def verify_api_key_hash(api_key: str, key_hash: str) -> bool: + """Verify an API key against its hash.""" + return pwd_context.verify(api_key, key_hash) + + +def generate_api_key() -> str: + """Generate a new random API key.""" + return secrets.token_urlsafe(32) + + +def create_api_key(db: Session, name: str) -> tuple[str, APIKey]: + """ + Create a new API key in the database. + + Returns: + tuple: (plain_key, db_record) + """ + plain_key = generate_api_key() + key_hash = hash_api_key(plain_key) + + db_key = APIKey(key_hash=key_hash, name=name, is_active=True) + db.add(db_key) + db.commit() + db.refresh(db_key) + + return plain_key, db_key + + +def get_api_key_by_hash(db: Session, key_hash: str) -> Optional[APIKey]: + """Get an API key record by its hash.""" + return db.query(APIKey).filter(APIKey.key_hash == key_hash).first() + + +def verify_api_key_from_db(db: Session, api_key: str) -> Optional[APIKey]: + """ + Verify an API key against the database. + + Returns: + APIKey record if valid and active, None otherwise + """ + # Get all active keys + active_keys = db.query(APIKey).filter(APIKey.is_active == True).all() + + # Check each one + for key_record in active_keys: + if verify_api_key_hash(api_key, key_record.key_hash): + return key_record + + return None + + +async def get_current_api_key( + api_key: str = Security(api_key_header), + db: Session = Depends(get_db), +) -> APIKey: + """ + Dependency to verify API key and return the key record. + + Raises: + HTTPException: If API key is invalid or inactive + """ + key_record = verify_api_key_from_db(db, api_key) + + if not key_record: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or inactive API key", + ) + + return key_record + diff --git a/api/cli_admin.py b/api/cli_admin.py new file mode 100644 index 00000000..59226185 --- /dev/null +++ b/api/cli_admin.py @@ -0,0 +1,146 @@ +"""Admin CLI tool for API key management.""" + +import sys +from datetime import datetime + +import typer +from rich.console import Console +from rich.table import Table + +from api.auth import create_api_key +from api.database import APIKey, SessionLocal, init_db + +app = typer.Typer( + name="api-admin", + help="Trading Agents API Administration Tool", +) +console = Console() + + +@app.command() +def create_key(name: str = typer.Argument(..., help="Name/description for the API key")): + """Create a new API key.""" + # Initialize database + init_db() + + db = SessionLocal() + try: + plain_key, db_key = create_api_key(db, name) + + console.print("\n[green]โœ“ API Key created successfully![/green]\n") + console.print(f"[bold]Name:[/bold] {db_key.name}") + console.print(f"[bold]Created:[/bold] {db_key.created_at}") + console.print(f"\n[bold yellow]API Key (save this, it won't be shown again):[/bold yellow]") + console.print(f"[cyan]{plain_key}[/cyan]\n") + console.print("[dim]Use this key in the X-API-Key header for all API requests.[/dim]\n") + except Exception as e: + console.print(f"[red]Error creating API key: {e}[/red]") + sys.exit(1) + finally: + db.close() + + +@app.command() +def list_keys(): + """List all API keys.""" + init_db() + + db = SessionLocal() + try: + keys = db.query(APIKey).order_by(APIKey.created_at.desc()).all() + + if not keys: + console.print("[yellow]No API keys found.[/yellow]") + return + + table = Table(title="API Keys") + table.add_column("ID", style="cyan") + table.add_column("Name", style="green") + table.add_column("Created", style="blue") + table.add_column("Status", style="magenta") + + for key in keys: + status = "Active" if key.is_active else "Revoked" + status_color = "green" if key.is_active else "red" + table.add_row( + str(key.id), + key.name, + key.created_at.strftime("%Y-%m-%d %H:%M:%S"), + f"[{status_color}]{status}[/{status_color}]", + ) + + console.print(table) + finally: + db.close() + + +@app.command() +def revoke_key(key_id: int = typer.Argument(..., help="API key ID to revoke")): + """Revoke an API key.""" + init_db() + + db = SessionLocal() + try: + key = db.query(APIKey).filter(APIKey.id == key_id).first() + + if not key: + console.print(f"[red]API key with ID {key_id} not found.[/red]") + sys.exit(1) + + if not key.is_active: + console.print(f"[yellow]API key '{key.name}' is already revoked.[/yellow]") + return + + key.is_active = False + db.commit() + + console.print(f"[green]โœ“ API key '{key.name}' has been revoked.[/green]") + except Exception as e: + console.print(f"[red]Error revoking API key: {e}[/red]") + sys.exit(1) + finally: + db.close() + + +@app.command() +def activate_key(key_id: int = typer.Argument(..., help="API key ID to activate")): + """Activate a revoked API key.""" + init_db() + + db = SessionLocal() + try: + key = db.query(APIKey).filter(APIKey.id == key_id).first() + + if not key: + console.print(f"[red]API key with ID {key_id} not found.[/red]") + sys.exit(1) + + if key.is_active: + console.print(f"[yellow]API key '{key.name}' is already active.[/yellow]") + return + + key.is_active = True + db.commit() + + console.print(f"[green]โœ“ API key '{key.name}' has been activated.[/green]") + except Exception as e: + console.print(f"[red]Error activating API key: {e}[/red]") + sys.exit(1) + finally: + db.close() + + +@app.command() +def init_database(): + """Initialize the database (create tables).""" + try: + init_db() + console.print("[green]โœ“ Database initialized successfully![/green]") + except Exception as e: + console.print(f"[red]Error initializing database: {e}[/red]") + sys.exit(1) + + +if __name__ == "__main__": + app() + diff --git a/api/database.py b/api/database.py new file mode 100644 index 00000000..fb28e7da --- /dev/null +++ b/api/database.py @@ -0,0 +1,114 @@ +"""Database models and connection management.""" + +import os +from datetime import datetime +from typing import Generator + +from sqlalchemy import ( + Boolean, + Column, + DateTime, + Integer, + String, + Text, + create_engine, + ForeignKey, +) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session, sessionmaker, relationship + +# Database file location +DATABASE_URL = os.getenv("API_DATABASE_URL", "sqlite:///./api_database.db") + +# Create engine +engine = create_engine( + DATABASE_URL, + connect_args={"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {}, +) + +# Session factory +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# Base class for models +Base = declarative_base() + + +class Analysis(Base): + """Analysis metadata and configuration.""" + + __tablename__ = "analyses" + + id = Column(String, primary_key=True, index=True) + ticker = Column(String, index=True, nullable=False) + analysis_date = Column(String, nullable=False) + status = Column( + String, nullable=False, default="pending" + ) # pending, running, completed, failed, cancelled + config_json = Column(Text, nullable=False) + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + completed_at = Column(DateTime, nullable=True) + error_message = Column(Text, nullable=True) + progress_percentage = Column(Integer, default=0) + current_agent = Column(String, nullable=True) + + # Relationships + logs = relationship("AnalysisLog", back_populates="analysis", cascade="all, delete-orphan") + reports = relationship("AnalysisReport", back_populates="analysis", cascade="all, delete-orphan") + + +class AnalysisLog(Base): + """Log entries from analysis execution.""" + + __tablename__ = "analysis_logs" + + id = Column(Integer, primary_key=True, index=True) + analysis_id = Column(String, ForeignKey("analyses.id"), nullable=False) + timestamp = Column(DateTime, default=datetime.utcnow, nullable=False) + log_type = Column(String, nullable=False) # Tool Call, Reasoning, System + content = Column(Text, nullable=False) + + # Relationships + analysis = relationship("Analysis", back_populates="logs") + + +class AnalysisReport(Base): + """Report sections from analysis.""" + + __tablename__ = "analysis_reports" + + id = Column(Integer, primary_key=True, index=True) + analysis_id = Column(String, ForeignKey("analyses.id"), nullable=False) + report_type = Column(String, nullable=False) # market_report, news_report, etc. + content = Column(Text, nullable=False) + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + + # Relationships + analysis = relationship("Analysis", back_populates="reports") + + +class APIKey(Base): + """API keys for authentication.""" + + __tablename__ = "api_keys" + + id = Column(Integer, primary_key=True, index=True) + key_hash = Column(String, unique=True, nullable=False, index=True) + name = Column(String, nullable=False) + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + is_active = Column(Boolean, default=True, nullable=False) + + +def init_db(): + """Initialize database and create all tables.""" + Base.metadata.create_all(bind=engine) + + +def get_db() -> Generator[Session, None, None]: + """Dependency for getting database sessions.""" + db = SessionLocal() + try: + yield db + finally: + db.close() + diff --git a/api/endpoints/__init__.py b/api/endpoints/__init__.py new file mode 100644 index 00000000..fc387163 --- /dev/null +++ b/api/endpoints/__init__.py @@ -0,0 +1,2 @@ +"""API endpoint routers.""" + diff --git a/api/endpoints/analyses.py b/api/endpoints/analyses.py new file mode 100644 index 00000000..4e54337c --- /dev/null +++ b/api/endpoints/analyses.py @@ -0,0 +1,356 @@ +"""Analysis CRUD endpoints.""" + +import json +import logging +import uuid +from datetime import datetime +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.orm import Session + +from api.auth import APIKey, get_current_api_key +from api.database import Analysis, AnalysisLog, AnalysisReport, get_db +from api.models import ( + AnalysisResponse, + AnalysisStatusResponse, + AnalysisSummary, + CreateAnalysisRequest, + LogEntry, + ReportResponse, +) +from api.state_manager import get_executor +from tradingagents.default_config import DEFAULT_CONFIG + +router = APIRouter(prefix="/api/v1/analyses", tags=["analyses"]) +logger = logging.getLogger(__name__) + + +@router.post("", response_model=AnalysisResponse, status_code=status.HTTP_201_CREATED) +async def create_analysis( + request: CreateAnalysisRequest, + db: Session = Depends(get_db), + api_key: APIKey = Depends(get_current_api_key), +): + """Create and start a new analysis.""" + # Generate analysis ID + analysis_id = str(uuid.uuid4()) + + # Build configuration + config = DEFAULT_CONFIG.copy() + config["max_debate_rounds"] = request.research_depth + config["max_risk_discuss_rounds"] = request.research_depth + + if request.llm_provider: + config["llm_provider"] = request.llm_provider + if request.backend_url: + config["backend_url"] = request.backend_url + if request.quick_think_llm: + config["quick_think_llm"] = request.quick_think_llm + if request.deep_think_llm: + config["deep_think_llm"] = request.deep_think_llm + + # Auto-detect asset class + from cli.asset_detection import detect_asset_class + asset_class = detect_asset_class(request.ticker) + config["asset_class"] = asset_class + + # Filter out fundamentals for commodities/crypto + selected_analysts = request.selected_analysts + if asset_class in ["commodity", "crypto"] and "fundamentals" in selected_analysts: + selected_analysts = [a for a in selected_analysts if a != "fundamentals"] + + # Create database record + analysis = Analysis( + id=analysis_id, + ticker=request.ticker, + analysis_date=request.analysis_date, + status="pending", + config_json=json.dumps(config), + progress_percentage=0, + ) + db.add(analysis) + db.commit() + db.refresh(analysis) + + # Start analysis in background + logger.info(f"Creating analysis {analysis_id} for {request.ticker}") + executor = get_executor() + try: + executor.start_analysis( + analysis_id=analysis_id, + ticker=request.ticker, + analysis_date=request.analysis_date, + selected_analysts=selected_analysts, + config=config, + ) + logger.info(f"Analysis {analysis_id} started successfully") + except Exception as e: + logger.error(f"Failed to start analysis {analysis_id}: {str(e)}") + # Update status to failed + analysis.status = "failed" + analysis.error_message = str(e) + db.commit() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to start analysis: {str(e)}", + ) + + return AnalysisResponse( + id=analysis.id, + ticker=analysis.ticker, + analysis_date=analysis.analysis_date, + status=analysis.status, + config=config, + reports=[], + progress_percentage=analysis.progress_percentage, + current_agent=analysis.current_agent, + created_at=analysis.created_at, + updated_at=analysis.updated_at, + completed_at=analysis.completed_at, + error_message=analysis.error_message, + ) + + +@router.get("", response_model=List[AnalysisSummary]) +async def list_analyses( + ticker: Optional[str] = Query(None, description="Filter by ticker"), + status: Optional[str] = Query(None, description="Filter by status"), + date_from: Optional[str] = Query(None, description="Filter by date (from)"), + date_to: Optional[str] = Query(None, description="Filter by date (to)"), + limit: int = Query(100, ge=1, le=1000, description="Max results"), + offset: int = Query(0, ge=0, description="Offset for pagination"), + db: Session = Depends(get_db), + api_key: APIKey = Depends(get_current_api_key), +): + """List all analyses with optional filtering.""" + query = db.query(Analysis) + + if ticker: + query = query.filter(Analysis.ticker == ticker.upper()) + if status: + query = query.filter(Analysis.status == status) + if date_from: + query = query.filter(Analysis.analysis_date >= date_from) + if date_to: + query = query.filter(Analysis.analysis_date <= date_to) + + # Order by created_at descending + query = query.order_by(Analysis.created_at.desc()) + + # Apply pagination + analyses = query.offset(offset).limit(limit).all() + + return [ + AnalysisSummary( + id=a.id, + ticker=a.ticker, + analysis_date=a.analysis_date, + status=a.status, + created_at=a.created_at, + completed_at=a.completed_at, + error_message=a.error_message, + ) + for a in analyses + ] + + +@router.get("/{analysis_id}", response_model=AnalysisResponse) +async def get_analysis( + analysis_id: str, + db: Session = Depends(get_db), + api_key: APIKey = Depends(get_current_api_key), +): + """Get full analysis details.""" + analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first() + + if not analysis: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Analysis {analysis_id} not found", + ) + + # Get reports + reports = ( + db.query(AnalysisReport) + .filter(AnalysisReport.analysis_id == analysis_id) + .all() + ) + + return AnalysisResponse( + id=analysis.id, + ticker=analysis.ticker, + analysis_date=analysis.analysis_date, + status=analysis.status, + config=json.loads(analysis.config_json), + reports=[ + ReportResponse( + report_type=r.report_type, + content=r.content, + created_at=r.created_at, + ) + for r in reports + ], + progress_percentage=analysis.progress_percentage, + current_agent=analysis.current_agent, + created_at=analysis.created_at, + updated_at=analysis.updated_at, + completed_at=analysis.completed_at, + error_message=analysis.error_message, + ) + + +@router.get("/{analysis_id}/status", response_model=AnalysisStatusResponse) +async def get_analysis_status( + analysis_id: str, + db: Session = Depends(get_db), + api_key: APIKey = Depends(get_current_api_key), +): + """Get current analysis status.""" + analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first() + + if not analysis: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Analysis {analysis_id} not found", + ) + + return AnalysisStatusResponse( + id=analysis.id, + status=analysis.status, + progress_percentage=analysis.progress_percentage, + current_agent=analysis.current_agent, + updated_at=analysis.updated_at, + ) + + +@router.get("/{analysis_id}/reports", response_model=List[ReportResponse]) +async def get_analysis_reports( + analysis_id: str, + db: Session = Depends(get_db), + api_key: APIKey = Depends(get_current_api_key), +): + """Get all reports for an analysis.""" + # Check if analysis exists + analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first() + if not analysis: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Analysis {analysis_id} not found", + ) + + reports = ( + db.query(AnalysisReport) + .filter(AnalysisReport.analysis_id == analysis_id) + .order_by(AnalysisReport.created_at) + .all() + ) + + return [ + ReportResponse( + report_type=r.report_type, + content=r.content, + created_at=r.created_at, + ) + for r in reports + ] + + +@router.get("/{analysis_id}/reports/{report_type}", response_model=ReportResponse) +async def get_analysis_report( + analysis_id: str, + report_type: str, + db: Session = Depends(get_db), + api_key: APIKey = Depends(get_current_api_key), +): + """Get a specific report for an analysis.""" + report = ( + db.query(AnalysisReport) + .filter( + AnalysisReport.analysis_id == analysis_id, + AnalysisReport.report_type == report_type, + ) + .first() + ) + + if not report: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Report {report_type} not found for analysis {analysis_id}", + ) + + return ReportResponse( + report_type=report.report_type, + content=report.content, + created_at=report.created_at, + ) + + +@router.get("/{analysis_id}/logs", response_model=List[LogEntry]) +async def get_analysis_logs( + analysis_id: str, + log_type: Optional[str] = Query(None, description="Filter by log type"), + limit: int = Query(100, ge=1, le=1000, description="Max results"), + offset: int = Query(0, ge=0, description="Offset for pagination"), + db: Session = Depends(get_db), + api_key: APIKey = Depends(get_current_api_key), +): + """Get execution logs for an analysis.""" + # Check if analysis exists + analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first() + if not analysis: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Analysis {analysis_id} not found", + ) + + query = db.query(AnalysisLog).filter(AnalysisLog.analysis_id == analysis_id) + + if log_type: + query = query.filter(AnalysisLog.log_type == log_type) + + logs = query.order_by(AnalysisLog.timestamp).offset(offset).limit(limit).all() + + return [ + LogEntry( + timestamp=log.timestamp, + log_type=log.log_type, + content=log.content, + ) + for log in logs + ] + + +@router.delete("/{analysis_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_analysis( + analysis_id: str, + permanent: bool = Query(False, description="Permanently delete from database"), + db: Session = Depends(get_db), + api_key: APIKey = Depends(get_current_api_key), +): + """Cancel and/or delete an analysis.""" + analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first() + + if not analysis: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Analysis {analysis_id} not found", + ) + + # Try to cancel if running + if analysis.status in ["pending", "running"]: + executor = get_executor() + executor.cancel_analysis(analysis_id) + + # Delete from database if requested + if permanent: + db.delete(analysis) + db.commit() + elif analysis.status not in ["cancelled", "failed", "completed"]: + # Just mark as cancelled + analysis.status = "cancelled" + analysis.updated_at = datetime.utcnow() + db.commit() + + return None + diff --git a/api/endpoints/data.py b/api/endpoints/data.py new file mode 100644 index 00000000..f66934ae --- /dev/null +++ b/api/endpoints/data.py @@ -0,0 +1,129 @@ +"""Cached data access endpoints.""" + +import csv +import glob +import os +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query, status + +from api.auth import APIKey, get_current_api_key +from api.models.responses import CachedDataResponse, CachedTickerInfo + +router = APIRouter(prefix="/api/v1/data", tags=["data"]) + +# Data cache directory +DATA_CACHE_DIR = Path("./tradingagents/dataflows/data_cache") + + +def _parse_date_range(filename: str) -> Optional[Dict[str, str]]: + """Parse date range from cache filename.""" + try: + # Format: TICKER-YFin-data-START-END.csv + parts = filename.replace(".csv", "").split("-") + if len(parts) >= 5: + start_date = parts[-2] + end_date = parts[-1] + return {"start": start_date, "end": end_date} + except: + pass + return None + + +@router.get("/cache", response_model=List[CachedTickerInfo]) +async def list_cached_tickers( + api_key: APIKey = Depends(get_current_api_key), +): + """List all cached tickers with date ranges.""" + if not DATA_CACHE_DIR.exists(): + return [] + + cached_tickers = [] + + for csv_file in DATA_CACHE_DIR.glob("*-YFin-data-*.csv"): + ticker = csv_file.name.split("-")[0] + date_range = _parse_date_range(csv_file.name) + + if date_range: + # Count records + try: + with open(csv_file, "r") as f: + record_count = sum(1 for _ in f) - 1 # Subtract header + except: + record_count = 0 + + cached_tickers.append( + CachedTickerInfo( + ticker=ticker, + date_range=date_range, + record_count=record_count, + ) + ) + + return sorted(cached_tickers, key=lambda x: x.ticker) + + +@router.get("/cache/{ticker}", response_model=CachedDataResponse) +async def get_cached_data( + ticker: str, + start_date: Optional[str] = Query(None, description="Filter from date (YYYY-MM-DD)"), + end_date: Optional[str] = Query(None, description="Filter to date (YYYY-MM-DD)"), + api_key: APIKey = Depends(get_current_api_key), +): + """Get cached market data for a ticker.""" + # Find matching file + pattern = f"{ticker.upper()}-YFin-data-*.csv" + matching_files = list(DATA_CACHE_DIR.glob(pattern)) + + if not matching_files: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"No cached data found for ticker {ticker}", + ) + + # Use the first matching file (should only be one) + csv_file = matching_files[0] + date_range = _parse_date_range(csv_file.name) + + if not date_range: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Could not parse date range from cache file", + ) + + # Read CSV data + data = [] + try: + with open(csv_file, "r") as f: + reader = csv.DictReader(f) + for row in reader: + # Filter by date if specified + row_date = row.get("Date", "") + if start_date and row_date < start_date: + continue + if end_date and row_date > end_date: + continue + + # Convert numeric fields + for field in ["Close", "High", "Low", "Open", "Volume"]: + if field in row and row[field]: + try: + row[field] = float(row[field]) + except: + pass + + data.append(row) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error reading cache file: {str(e)}", + ) + + return CachedDataResponse( + ticker=ticker.upper(), + date_range=date_range, + data=data, + ) + diff --git a/api/endpoints/tickers.py b/api/endpoints/tickers.py new file mode 100644 index 00000000..e2ec0f7f --- /dev/null +++ b/api/endpoints/tickers.py @@ -0,0 +1,95 @@ +"""Ticker history endpoints.""" + +from typing import List + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy import func +from sqlalchemy.orm import Session + +from api.auth import APIKey, get_current_api_key +from api.database import Analysis, get_db +from api.models import AnalysisResponse, AnalysisSummary, ReportResponse, TickerInfo +from api.endpoints.analyses import get_analysis + +router = APIRouter(prefix="/api/v1/tickers", tags=["tickers"]) + + +@router.get("", response_model=List[TickerInfo]) +async def list_tickers( + db: Session = Depends(get_db), + api_key: APIKey = Depends(get_current_api_key), +): + """List all tickers with analysis count.""" + # Query for ticker stats + results = ( + db.query( + Analysis.ticker, + func.count(Analysis.id).label("analysis_count"), + func.max(Analysis.analysis_date).label("latest_date"), + ) + .group_by(Analysis.ticker) + .order_by(Analysis.ticker) + .all() + ) + + return [ + TickerInfo( + ticker=r.ticker, + analysis_count=r.analysis_count, + latest_date=r.latest_date, + ) + for r in results + ] + + +@router.get("/{ticker}/analyses", response_model=List[AnalysisSummary]) +async def get_ticker_analyses( + ticker: str, + db: Session = Depends(get_db), + api_key: APIKey = Depends(get_current_api_key), +): + """Get all analyses for a ticker.""" + analyses = ( + db.query(Analysis) + .filter(Analysis.ticker == ticker.upper()) + .order_by(Analysis.created_at.desc()) + .all() + ) + + return [ + AnalysisSummary( + id=a.id, + ticker=a.ticker, + analysis_date=a.analysis_date, + status=a.status, + created_at=a.created_at, + completed_at=a.completed_at, + error_message=a.error_message, + ) + for a in analyses + ] + + +@router.get("/{ticker}/latest", response_model=AnalysisResponse) +async def get_ticker_latest_analysis( + ticker: str, + db: Session = Depends(get_db), + api_key: APIKey = Depends(get_current_api_key), +): + """Get the most recent analysis for a ticker.""" + analysis = ( + db.query(Analysis) + .filter(Analysis.ticker == ticker.upper()) + .order_by(Analysis.created_at.desc()) + .first() + ) + + if not analysis: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"No analyses found for ticker {ticker}", + ) + + # Use the existing get_analysis function + return await get_analysis(analysis.id, db, api_key) + diff --git a/api/example_client.py b/api/example_client.py new file mode 100644 index 00000000..c7fe5b3b --- /dev/null +++ b/api/example_client.py @@ -0,0 +1,169 @@ +"""Example client for Testing the Trading Agents API.""" + +import asyncio +import json +import time +from datetime import date + +import httpx +import websockets + + +class TradingAgentsAPIClient: + """Simple client for interacting with the Trading Agents API.""" + + def __init__(self, api_key: str, base_url: str = "http://localhost:8000"): + self.api_key = api_key + self.base_url = base_url + self.headers = { + "X-API-Key": api_key, + "Content-Type": "application/json", + } + + async def create_analysis( + self, + ticker: str, + analysis_date: str = None, + selected_analysts: list = None, + research_depth: int = 1, + ): + """Create a new analysis.""" + if analysis_date is None: + analysis_date = date.today().strftime("%Y-%m-%d") + + if selected_analysts is None: + selected_analysts = ["market", "news"] + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/api/v1/analyses", + headers=self.headers, + json={ + "ticker": ticker, + "analysis_date": analysis_date, + "selected_analysts": selected_analysts, + "research_depth": research_depth, + }, + ) + response.raise_for_status() + return response.json() + + async def get_analysis(self, analysis_id: str): + """Get full analysis details.""" + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/api/v1/analyses/{analysis_id}", + headers=self.headers, + ) + response.raise_for_status() + return response.json() + + async def get_status(self, analysis_id: str): + """Get analysis status.""" + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/api/v1/analyses/{analysis_id}/status", + headers=self.headers, + ) + response.raise_for_status() + return response.json() + + async def list_analyses(self, ticker: str = None): + """List all analyses.""" + params = {} + if ticker: + params["ticker"] = ticker + + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/api/v1/analyses", + headers=self.headers, + params=params, + ) + response.raise_for_status() + return response.json() + + async def monitor_via_websocket(self, analysis_id: str, duration: int = 300): + """Monitor analysis via WebSocket.""" + ws_url = f"ws://localhost:8000/api/v1/ws/analyses/{analysis_id}" + + print(f"Connecting to WebSocket: {ws_url}") + + try: + async with websockets.connect(ws_url) as websocket: + print("Connected! Waiting for updates...") + + start_time = time.time() + while time.time() - start_time < duration: + try: + message = await asyncio.wait_for( + websocket.recv(), timeout=10.0 + ) + data = json.loads(message) + + print(f"\n[Update] Status: {data['status']}") + print(f" Progress: {data['progress_percentage']}%") + if data.get('current_agent'): + print(f" Agent: {data['current_agent']}") + + if data['status'] in ['completed', 'failed', 'cancelled']: + print("\nAnalysis finished!") + break + + except asyncio.TimeoutError: + # Send ping to keep connection alive + await websocket.send("ping") + continue + + except Exception as e: + print(f"WebSocket error: {e}") + + +async def main(): + """Example usage.""" + # Replace with your actual API key + API_KEY = "your-api-key-here" + + client = TradingAgentsAPIClient(API_KEY) + + print("=" * 60) + print("Trading Agents API Client Example") + print("=" * 60) + + # 1. Create an analysis + print("\n1. Creating analysis for AAPL...") + analysis = await client.create_analysis( + ticker="AAPL", + selected_analysts=["market", "news"], + research_depth=1, + ) + analysis_id = analysis["id"] + print(f" Created: {analysis_id}") + print(f" Status: {analysis['status']}") + + # 2. Monitor via WebSocket (run this in background or separately) + print("\n2. Monitoring via WebSocket...") + await client.monitor_via_websocket(analysis_id, duration=600) + + # 3. Get final results + print("\n3. Getting final results...") + final = await client.get_analysis(analysis_id) + print(f" Status: {final['status']}") + print(f" Reports: {len(final['reports'])} available") + + for report in final['reports']: + print(f"\n - {report['report_type']}:") + print(f" {report['content'][:200]}...") + + # 4. List all analyses + print("\n4. Listing all AAPL analyses...") + all_analyses = await client.list_analyses(ticker="AAPL") + print(f" Found {len(all_analyses)} analyses") + + print("\n" + "=" * 60) + print("Done!") + + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/api/main.py b/api/main.py new file mode 100644 index 00000000..9c944e54 --- /dev/null +++ b/api/main.py @@ -0,0 +1,104 @@ +"""FastAPI Trading Agents API application.""" + +import logging +import sys +from contextlib import asynccontextmanager + +from dotenv import load_dotenv +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +# Load environment variables from .env file +load_dotenv() + +from api.database import init_db +from api.endpoints import analyses, data, tickers +from api.state_manager import get_executor, shutdown_executor +from api.websockets import status + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.StreamHandler(sys.stdout), + ], +) +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Lifespan context manager for startup and shutdown events.""" + # Startup + logger.info("Initializing Trading Agents API...") + init_db() + get_executor() + logger.info("Trading Agents API started successfully") + + yield + + # Shutdown + logger.info("Shutting down Trading Agents API...") + shutdown_executor() + logger.info("Trading Agents API shutdown complete") + + +# Create FastAPI app +app = FastAPI( + title="Trading Agents API", + description="REST API for managing multi-agent trading analyses", + version="1.0.0", + docs_url="/docs", + redoc_url="/redoc", + lifespan=lifespan, +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # In production, specify actual origins + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# Include routers +app.include_router(analyses.router) +app.include_router(tickers.router) +app.include_router(data.router) +app.include_router(status.router) + + +@app.get("/") +async def root(): + """Root endpoint.""" + return { + "name": "Trading Agents API", + "version": "1.0.0", + "docs": "/docs", + "redoc": "/redoc", + } + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy"} + + +if __name__ == "__main__": + import uvicorn + import os + + port = int(os.getenv("API_PORT", "8002")) # Default to 8001 instead of 8000 + + uvicorn.run( + "api.main:app", + host="0.0.0.0", + port=port, + reload=True, + log_level="info", + ) + diff --git a/api/models/__init__.py b/api/models/__init__.py new file mode 100644 index 00000000..53eadd2c --- /dev/null +++ b/api/models/__init__.py @@ -0,0 +1,27 @@ +"""Pydantic models for API requests and responses.""" + +from api.models.requests import CreateAnalysisRequest, UpdateAnalysisRequest +from api.models.responses import ( + AnalysisResponse, + AnalysisStatusResponse, + AnalysisSummary, + ErrorResponse, + LogEntry, + ReportResponse, + TickerInfo, + CachedDataResponse, +) + +__all__ = [ + "CreateAnalysisRequest", + "UpdateAnalysisRequest", + "AnalysisResponse", + "AnalysisStatusResponse", + "AnalysisSummary", + "ErrorResponse", + "LogEntry", + "ReportResponse", + "TickerInfo", + "CachedDataResponse", +] + diff --git a/api/models/requests.py b/api/models/requests.py new file mode 100644 index 00000000..86e4e22f --- /dev/null +++ b/api/models/requests.py @@ -0,0 +1,68 @@ +"""Request models for API endpoints.""" + +from typing import List, Optional + +from pydantic import BaseModel, Field, field_validator + + +class CreateAnalysisRequest(BaseModel): + """Request to create a new analysis.""" + + ticker: str = Field(..., description="Ticker symbol to analyze") + analysis_date: str = Field(..., description="Analysis date in YYYY-MM-DD format") + selected_analysts: List[str] = Field( + default=["market", "news", "social", "fundamentals"], + description="List of analysts to run (market, news, social, fundamentals)", + ) + research_depth: int = Field( + default=1, + ge=1, + le=5, + description="Research depth (1=shallow, 3=medium, 5=deep)", + ) + llm_provider: Optional[str] = Field( + default=None, + description="LLM provider (openai, anthropic, google). Uses default if not specified.", + ) + backend_url: Optional[str] = Field( + default=None, + description="Backend URL for LLM. Uses default if not specified.", + ) + quick_think_llm: Optional[str] = Field( + default=None, + description="Model for quick thinking. Uses default if not specified.", + ) + deep_think_llm: Optional[str] = Field( + default=None, + description="Model for deep thinking. Uses default if not specified.", + ) + + @field_validator("ticker") + @classmethod + def ticker_uppercase(cls, v: str) -> str: + """Convert ticker to uppercase.""" + return v.upper().strip() + + @field_validator("selected_analysts") + @classmethod + def validate_analysts(cls, v: List[str]) -> List[str]: + """Validate analyst selections.""" + valid_analysts = {"market", "news", "social", "fundamentals"} + for analyst in v: + if analyst not in valid_analysts: + raise ValueError( + f"Invalid analyst: {analyst}. Must be one of {valid_analysts}" + ) + if not v: + raise ValueError("At least one analyst must be selected") + return v + + +class UpdateAnalysisRequest(BaseModel): + """Request to update analysis metadata.""" + + status: Optional[str] = Field( + default=None, + description="New status (pending, running, completed, failed, cancelled)", + ) + diff --git a/api/models/responses.py b/api/models/responses.py new file mode 100644 index 00000000..9e0b4238 --- /dev/null +++ b/api/models/responses.py @@ -0,0 +1,95 @@ +"""Response models for API endpoints.""" + +from datetime import datetime +from typing import Dict, List, Optional + +from pydantic import BaseModel, Field + + +class ErrorResponse(BaseModel): + """Standard error response.""" + + error: str = Field(..., description="Error message") + detail: Optional[str] = Field(None, description="Additional error details") + + +class ReportResponse(BaseModel): + """Single report section.""" + + report_type: str = Field(..., description="Type of report") + content: str = Field(..., description="Report content in markdown") + created_at: datetime = Field(..., description="When the report was created") + + +class LogEntry(BaseModel): + """Single log entry.""" + + timestamp: datetime = Field(..., description="When the log was created") + log_type: str = Field(..., description="Type of log (Tool Call, Reasoning, System)") + content: str = Field(..., description="Log content") + + +class AnalysisStatusResponse(BaseModel): + """Lightweight analysis status.""" + + id: str = Field(..., description="Analysis ID") + status: str = Field(..., description="Current status") + progress_percentage: int = Field(..., description="Progress percentage (0-100)") + current_agent: Optional[str] = Field(None, description="Currently active agent") + updated_at: datetime = Field(..., description="Last update timestamp") + + +class AnalysisSummary(BaseModel): + """Summary view of an analysis.""" + + id: str = Field(..., description="Analysis ID") + ticker: str = Field(..., description="Ticker symbol") + analysis_date: str = Field(..., description="Analysis date") + status: str = Field(..., description="Current status") + created_at: datetime = Field(..., description="Creation timestamp") + completed_at: Optional[datetime] = Field(None, description="Completion timestamp") + error_message: Optional[str] = Field(None, description="Error message if failed") + + +class AnalysisResponse(BaseModel): + """Full analysis details.""" + + id: str = Field(..., description="Analysis ID") + ticker: str = Field(..., description="Ticker symbol") + analysis_date: str = Field(..., description="Analysis date") + status: str = Field(..., description="Current status") + config: Dict = Field(..., description="Analysis configuration") + reports: List[ReportResponse] = Field( + default_factory=list, description="All report sections" + ) + progress_percentage: int = Field(..., description="Progress percentage (0-100)") + current_agent: Optional[str] = Field(None, description="Currently active agent") + created_at: datetime = Field(..., description="Creation timestamp") + updated_at: datetime = Field(..., description="Last update timestamp") + completed_at: Optional[datetime] = Field(None, description="Completion timestamp") + error_message: Optional[str] = Field(None, description="Error message if failed") + + +class TickerInfo(BaseModel): + """Ticker information summary.""" + + ticker: str = Field(..., description="Ticker symbol") + analysis_count: int = Field(..., description="Number of analyses") + latest_date: Optional[str] = Field(None, description="Most recent analysis date") + + +class CachedDataResponse(BaseModel): + """Cached market data response.""" + + ticker: str = Field(..., description="Ticker symbol") + date_range: Dict[str, str] = Field(..., description="Start and end dates") + data: List[Dict] = Field(..., description="OHLCV data records") + + +class CachedTickerInfo(BaseModel): + """Information about cached ticker data.""" + + ticker: str = Field(..., description="Ticker symbol") + date_range: Dict[str, str] = Field(..., description="Start and end dates") + record_count: int = Field(..., description="Number of records") + diff --git a/api/state_manager.py b/api/state_manager.py new file mode 100644 index 00000000..60b78c6e --- /dev/null +++ b/api/state_manager.py @@ -0,0 +1,447 @@ +"""Analysis execution and state management.""" + +import json +import logging +import os +import threading +import traceback +import uuid +from concurrent.futures import Future, ThreadPoolExecutor +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional + +from sqlalchemy.orm import Session + +from api.database import Analysis, AnalysisLog, AnalysisReport, SessionLocal +from tradingagents.default_config import DEFAULT_CONFIG +from tradingagents.graph.trading_graph import TradingAgentsGraph + +logger = logging.getLogger(__name__) + + +class AnalysisExecutor: + """Manages analysis execution with thread pool and state tracking.""" + + def __init__(self, max_workers: int = None): + """ + Initialize the executor. + + Args: + max_workers: Maximum concurrent analyses (default: from env or 4) + """ + if max_workers is None: + max_workers = int(os.getenv("MAX_CONCURRENT_ANALYSES", "4")) + + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.active_analyses: Dict[str, Future] = {} + self.status_callbacks: Dict[str, List[Callable]] = {} + self._lock = threading.Lock() + + def register_status_callback(self, analysis_id: str, callback: Callable): + """Register a callback for status updates.""" + with self._lock: + if analysis_id not in self.status_callbacks: + self.status_callbacks[analysis_id] = [] + self.status_callbacks[analysis_id].append(callback) + + def unregister_status_callbacks(self, analysis_id: str): + """Remove all callbacks for an analysis.""" + with self._lock: + if analysis_id in self.status_callbacks: + del self.status_callbacks[analysis_id] + + def _notify_callbacks(self, analysis_id: str, status_data: Dict[str, Any]): + """Notify all registered callbacks.""" + with self._lock: + callbacks = self.status_callbacks.get(analysis_id, []) + + for callback in callbacks: + try: + callback(status_data) + except Exception as e: + print(f"Error in status callback: {e}") + + def start_analysis( + self, + analysis_id: str, + ticker: str, + analysis_date: str, + selected_analysts: List[str], + config: Dict[str, Any], + ) -> str: + """ + Start a new analysis in the background. + + Args: + analysis_id: Unique analysis ID + ticker: Ticker symbol + analysis_date: Analysis date + selected_analysts: List of analyst types + config: Trading agents configuration + + Returns: + analysis_id + """ + future = self.executor.submit( + self._run_analysis, + analysis_id, + ticker, + analysis_date, + selected_analysts, + config, + ) + + with self._lock: + self.active_analyses[analysis_id] = future + + # Cleanup when done + future.add_done_callback(lambda f: self._cleanup_analysis(analysis_id)) + + return analysis_id + + def _cleanup_analysis(self, analysis_id: str): + """Clean up after analysis completes.""" + with self._lock: + if analysis_id in self.active_analyses: + del self.active_analyses[analysis_id] + self.unregister_status_callbacks(analysis_id) + + def cancel_analysis(self, analysis_id: str) -> bool: + """ + Attempt to cancel a running analysis. + + Returns: + True if cancelled, False if not found or already completed + """ + with self._lock: + future = self.active_analyses.get(analysis_id) + + if future and not future.done(): + cancelled = future.cancel() + if cancelled: + # Update database status + db = SessionLocal() + try: + analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first() + if analysis: + analysis.status = "cancelled" + analysis.updated_at = datetime.utcnow() + db.commit() + finally: + db.close() + return cancelled + return False + + def get_status(self, analysis_id: str) -> Optional[Dict[str, Any]]: + """Get current status of an analysis.""" + db = SessionLocal() + try: + analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first() + if not analysis: + return None + + return { + "id": analysis.id, + "status": analysis.status, + "progress_percentage": analysis.progress_percentage, + "current_agent": analysis.current_agent, + "updated_at": analysis.updated_at, + } + finally: + db.close() + + def _update_status( + self, + analysis_id: str, + status: Optional[str] = None, + progress: Optional[int] = None, + current_agent: Optional[str] = None, + error_message: Optional[str] = None, + ): + """Update analysis status in database and notify callbacks.""" + db = SessionLocal() + try: + analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first() + if not analysis: + return + + if status: + analysis.status = status + if progress is not None: + analysis.progress_percentage = progress + if current_agent: + analysis.current_agent = current_agent + if error_message: + analysis.error_message = error_message + + analysis.updated_at = datetime.utcnow() + + if status == "completed": + analysis.completed_at = datetime.utcnow() + analysis.progress_percentage = 100 + + db.commit() + db.refresh(analysis) + + # Notify callbacks + status_data = { + "type": "status_update", + "analysis_id": analysis.id, + "status": analysis.status, + "progress_percentage": analysis.progress_percentage, + "current_agent": analysis.current_agent, + "timestamp": analysis.updated_at.isoformat(), + } + self._notify_callbacks(analysis_id, status_data) + + finally: + db.close() + + def _store_log(self, analysis_id: str, log_type: str, content: str): + """Store a log entry.""" + db = SessionLocal() + try: + log = AnalysisLog( + analysis_id=analysis_id, + log_type=log_type, + content=content, + timestamp=datetime.utcnow(), + ) + db.add(log) + db.commit() + finally: + db.close() + + def _store_report(self, analysis_id: str, report_type: str, content: str): + """Store or update a report section.""" + db = SessionLocal() + try: + # Check if report already exists + report = ( + db.query(AnalysisReport) + .filter( + AnalysisReport.analysis_id == analysis_id, + AnalysisReport.report_type == report_type, + ) + .first() + ) + + if report: + # Update existing + report.content = content + report.created_at = datetime.utcnow() + else: + # Create new + report = AnalysisReport( + analysis_id=analysis_id, + report_type=report_type, + content=content, + ) + db.add(report) + + db.commit() + finally: + db.close() + + def _run_analysis( + self, + analysis_id: str, + ticker: str, + analysis_date: str, + selected_analysts: List[str], + config: Dict[str, Any], + ): + """Execute the analysis (runs in thread pool).""" + logger.info(f"Starting analysis {analysis_id} for {ticker} on {analysis_date}") + try: + # Update status to running + self._update_status(analysis_id, status="running", progress=0) + logger.info(f"Analysis {analysis_id}: Initializing trading graph...") + + # Initialize the graph + graph = TradingAgentsGraph( + selected_analysts=selected_analysts, + config=config, + debug=False, + ) + + # Create initial state + init_agent_state = graph.propagator.create_initial_state(ticker, analysis_date) + init_agent_state["asset_class"] = config.get("asset_class", "equity") + args = graph.propagator.get_graph_args() + + # Track agent progress + agent_order = self._get_agent_order(selected_analysts) + total_agents = len(agent_order) + current_agent_index = 0 + + # Stream the analysis + trace = [] + for chunk in graph.graph.stream(init_agent_state, **args): + if len(chunk.get("messages", [])) == 0: + continue + + # Process the chunk + last_message = chunk["messages"][-1] + + # Extract content + if hasattr(last_message, "content"): + content = self._extract_content(last_message.content) + msg_type = "Reasoning" + + # Store log + self._store_log(analysis_id, msg_type, content) + + # Handle tool calls + if hasattr(last_message, "tool_calls") and last_message.tool_calls: + for tool_call in last_message.tool_calls: + if isinstance(tool_call, dict): + tool_name = tool_call["name"] + tool_args = tool_call["args"] + else: + tool_name = tool_call.name + tool_args = tool_call.args + + args_str = ", ".join(f"{k}={v}" for k, v in tool_args.items()) + self._store_log( + analysis_id, "Tool Call", f"{tool_name}({args_str})" + ) + + # Check for completed reports + for report_type in [ + "market_report", + "sentiment_report", + "news_report", + "fundamentals_report", + ]: + if report_type in chunk and chunk[report_type]: + self._store_report(analysis_id, report_type, chunk[report_type]) + current_agent_index += 1 + progress = int((current_agent_index / total_agents) * 100) + agent_name = self._get_agent_name(report_type) + self._update_status( + analysis_id, + progress=min(progress, 95), + current_agent=agent_name, + ) + + # Check for investment debate state + if "investment_debate_state" in chunk and chunk["investment_debate_state"]: + debate_state = chunk["investment_debate_state"] + if "judge_decision" in debate_state and debate_state["judge_decision"]: + self._store_report( + analysis_id, "investment_plan", debate_state["judge_decision"] + ) + current_agent_index += 1 + progress = int((current_agent_index / total_agents) * 100) + self._update_status( + analysis_id, + progress=min(progress, 98), + current_agent="Research Manager", + ) + + # Check for trader plan + if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]: + self._store_report( + analysis_id, "trader_investment_plan", chunk["trader_investment_plan"] + ) + self._update_status( + analysis_id, + progress=99, + current_agent="Trader", + ) + + trace.append(chunk) + + # Get final state + if trace: + final_state = trace[-1] + + # Store final trade decision + if "final_trade_decision" in final_state: + decision = graph.process_signal(final_state["final_trade_decision"]) + self._store_report( + analysis_id, "final_trade_decision", final_state["final_trade_decision"] + ) + self._store_log( + analysis_id, "System", f"Final decision: {decision}" + ) + + # Mark as completed + logger.info(f"Analysis {analysis_id} completed successfully") + self._update_status(analysis_id, status="completed", progress=100) + + except Exception as e: + error_msg = str(e) + error_trace = traceback.format_exc() + logger.error(f"Analysis {analysis_id} failed: {error_msg}") + logger.error(f"Traceback:\n{error_trace}") + self._update_status( + analysis_id, status="failed", error_message=error_msg + ) + self._store_log(analysis_id, "System", f"Error: {error_msg}\n\nTraceback:\n{error_trace}") + + def _get_agent_order(self, selected_analysts: List[str]) -> List[str]: + """Get the order of agents for progress tracking.""" + agents = selected_analysts.copy() + agents.extend(["bull_researcher", "bear_researcher", "research_manager", "trader", "risk", "portfolio"]) + return agents + + def _get_agent_name(self, report_type: str) -> str: + """Get human-readable agent name from report type.""" + mapping = { + "market_report": "Market Analyst", + "sentiment_report": "Social Analyst", + "news_report": "News Analyst", + "fundamentals_report": "Fundamentals Analyst", + } + return mapping.get(report_type, "Unknown") + + def _extract_content(self, content: Any) -> str: + """Extract string content from various message formats.""" + if isinstance(content, str): + return content + elif isinstance(content, list): + text_parts = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + text_parts.append(item.get("text", "")) + elif item.get("type") == "tool_use": + text_parts.append(f"[Tool: {item.get('name', 'unknown')}]") + else: + text_parts.append(str(item)) + return " ".join(text_parts) + else: + return str(content) + + def shutdown(self): + """Shutdown the executor and cancel all running analyses.""" + with self._lock: + # Cancel all active analyses + for analysis_id in list(self.active_analyses.keys()): + self.cancel_analysis(analysis_id) + + # Shutdown executor + self.executor.shutdown(wait=True) + + +# Global executor instance +_executor: Optional[AnalysisExecutor] = None + + +def get_executor() -> AnalysisExecutor: + """Get the global executor instance.""" + global _executor + if _executor is None: + _executor = AnalysisExecutor() + return _executor + + +def shutdown_executor(): + """Shutdown the global executor.""" + global _executor + if _executor: + _executor.shutdown() + _executor = None + diff --git a/api/websockets/__init__.py b/api/websockets/__init__.py new file mode 100644 index 00000000..a7fe0e14 --- /dev/null +++ b/api/websockets/__init__.py @@ -0,0 +1,2 @@ +"""WebSocket handlers.""" + diff --git a/api/websockets/status.py b/api/websockets/status.py new file mode 100644 index 00000000..0036f00f --- /dev/null +++ b/api/websockets/status.py @@ -0,0 +1,130 @@ +"""WebSocket status streaming.""" + +import json +from typing import Dict, List + +from fastapi import APIRouter, WebSocket, WebSocketDisconnect +from sqlalchemy.orm import Session + +from api.database import Analysis, SessionLocal +from api.state_manager import get_executor + +router = APIRouter() + + +class ConnectionManager: + """Manages WebSocket connections for analyses.""" + + def __init__(self): + self.active_connections: Dict[str, List[WebSocket]] = {} + + async def connect(self, analysis_id: str, websocket: WebSocket): + """Accept and register a WebSocket connection.""" + await websocket.accept() + + if analysis_id not in self.active_connections: + self.active_connections[analysis_id] = [] + + self.active_connections[analysis_id].append(websocket) + + # Register callback with executor + executor = get_executor() + executor.register_status_callback(analysis_id, self._create_callback(analysis_id)) + + def disconnect(self, analysis_id: str, websocket: WebSocket): + """Remove a WebSocket connection.""" + if analysis_id in self.active_connections: + if websocket in self.active_connections[analysis_id]: + self.active_connections[analysis_id].remove(websocket) + + # Clean up if no more connections + if not self.active_connections[analysis_id]: + del self.active_connections[analysis_id] + + def _create_callback(self, analysis_id: str): + """Create a callback function for status updates.""" + def callback(status_data: dict): + # Note: This is called from a thread, so we can't use async here + # The actual broadcasting happens via the websocket event loop + import asyncio + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + asyncio.create_task(self.broadcast(analysis_id, status_data)) + else: + loop.run_until_complete(self.broadcast(analysis_id, status_data)) + except: + # If no loop, we can't broadcast (connection will poll status instead) + pass + return callback + + async def broadcast(self, analysis_id: str, message: dict): + """Broadcast a message to all connections for an analysis.""" + if analysis_id not in self.active_connections: + return + + disconnected = [] + for connection in self.active_connections[analysis_id]: + try: + await connection.send_json(message) + except: + disconnected.append(connection) + + # Clean up disconnected clients + for connection in disconnected: + self.disconnect(analysis_id, connection) + + +# Global connection manager +manager = ConnectionManager() + + +@router.websocket("/api/v1/ws/analyses/{analysis_id}") +async def websocket_analysis_status(websocket: WebSocket, analysis_id: str): + """WebSocket endpoint for real-time analysis status updates.""" + # Verify analysis exists + db = SessionLocal() + try: + analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first() + if not analysis: + await websocket.close(code=1008, reason="Analysis not found") + return + finally: + db.close() + + # Connect + await manager.connect(analysis_id, websocket) + + try: + # Send initial status + db = SessionLocal() + try: + analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first() + if analysis: + initial_status = { + "type": "status_update", + "analysis_id": analysis.id, + "status": analysis.status, + "progress_percentage": analysis.progress_percentage, + "current_agent": analysis.current_agent, + "timestamp": analysis.updated_at.isoformat(), + } + await websocket.send_json(initial_status) + finally: + db.close() + + # Keep connection alive and handle messages + while True: + # Wait for any messages from client (like ping) + data = await websocket.receive_text() + + # Echo back if it's a ping + if data == "ping": + await websocket.send_text("pong") + + except WebSocketDisconnect: + manager.disconnect(analysis_id, websocket) + except Exception as e: + print(f"WebSocket error: {e}") + manager.disconnect(analysis_id, websocket) + diff --git a/api_database.db b/api_database.db new file mode 100644 index 00000000..66e359b1 Binary files /dev/null and b/api_database.db differ diff --git a/requirements.txt b/requirements.txt index a6154cd2..ec9d9000 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,3 +24,12 @@ rich questionary langchain_anthropic langchain-google-genai + +# API Dependencies +fastapi==0.109.0 +uvicorn[standard]==0.27.0 +sqlalchemy==2.0.25 +passlib[bcrypt]>=1.7.4 +bcrypt>=4.0.0 +python-multipart==0.0.6 +websockets==12.0 diff --git a/run_api.sh b/run_api.sh new file mode 100755 index 00000000..cd953f24 --- /dev/null +++ b/run_api.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Startup script for Trading Agents API + +echo "Starting Trading Agents API..." +echo "" +echo "Make sure you have:" +echo "1. Installed dependencies: pip install -r requirements.txt" +echo "2. Initialized database: python -m api.cli_admin init-database" +echo "3. Created an API key: python -m api.cli_admin create-key 'My Key'" +echo "" +echo "API will be available at: http://localhost:8001" +echo "API Documentation: http://localhost:8001/docs" +echo "" + +cd "$(dirname "$0")" +python -m api.main +