add rest and websocket api
This commit is contained in:
parent
3de318602f
commit
4f26352220
|
|
@ -132,7 +132,7 @@ An interface will appear showing results as they load, letting you track the age
|
|||
|
||||
### Implementation Details
|
||||
|
||||
We built Litadel with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o4-mini` and `gpt-4.1-mini` to save on costs as our framework makes **lots of** API calls.
|
||||
We built Litadel with LangGraph to ensure flexibility and modularity. We utilize `o1-preview` and `gpt-4o` as our deep thinking and fast thinking LLMs for our experiments. However, for testing purposes, we recommend you use `o1-mini` and `gpt-4o-mini` to save on costs as our framework makes **lots of** API calls.
|
||||
|
||||
### Python Usage
|
||||
|
||||
|
|
@ -157,8 +157,8 @@ from tradingagents.default_config import DEFAULT_CONFIG
|
|||
|
||||
# Create a custom config
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["deep_think_llm"] = "gpt-4.1-nano" # Use a different model
|
||||
config["quick_think_llm"] = "gpt-4.1-nano" # Use a different model
|
||||
config["deep_think_llm"] = "o1-mini" # Use a different model
|
||||
config["quick_think_llm"] = "gpt-4o-mini" # Use a different model
|
||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||
|
||||
# Configure data vendors (default uses yfinance and Alpha Vantage)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,286 @@
|
|||
# FastAPI Trading Agents API - Implementation Summary
|
||||
|
||||
## ✅ Implementation Complete
|
||||
|
||||
The FastAPI Trading Agents API has been successfully implemented with all planned features.
|
||||
|
||||
## 📁 Project Structure
|
||||
|
||||
```
|
||||
TradingAgents/
|
||||
├── api/
|
||||
│ ├── __init__.py ✅ Created
|
||||
│ ├── main.py ✅ FastAPI application
|
||||
│ ├── database.py ✅ SQLAlchemy models
|
||||
│ ├── auth.py ✅ API key authentication
|
||||
│ ├── state_manager.py ✅ Analysis execution manager
|
||||
│ ├── cli_admin.py ✅ Admin CLI tool
|
||||
│ ├── example_client.py ✅ Example Python client
|
||||
│ ├── README.md ✅ Full documentation
|
||||
│ ├── models/
|
||||
│ │ ├── __init__.py ✅ Exports
|
||||
│ │ ├── requests.py ✅ Pydantic request models
|
||||
│ │ └── responses.py ✅ Pydantic response models
|
||||
│ ├── endpoints/
|
||||
│ │ ├── __init__.py ✅ Created
|
||||
│ │ ├── analyses.py ✅ Analysis CRUD endpoints
|
||||
│ │ ├── tickers.py ✅ Ticker history endpoints
|
||||
│ │ └── data.py ✅ Cached data endpoints
|
||||
│ └── websockets/
|
||||
│ ├── __init__.py ✅ Created
|
||||
│ └── status.py ✅ Real-time status updates
|
||||
├── requirements-api.txt ✅ API dependencies
|
||||
├── run_api.sh ✅ Startup script
|
||||
├── API_QUICKSTART.md ✅ Quick start guide
|
||||
└── API_IMPLEMENTATION_SUMMARY.md ✅ This file
|
||||
```
|
||||
|
||||
## 🎯 Features Implemented
|
||||
|
||||
### Core API Features
|
||||
- ✅ REST API with FastAPI
|
||||
- ✅ WebSocket support for real-time updates
|
||||
- ✅ SQLite database with SQLAlchemy ORM
|
||||
- ✅ API key authentication with bcrypt hashing
|
||||
- ✅ Parallel analysis execution (ThreadPoolExecutor)
|
||||
- ✅ CORS middleware for frontend integration
|
||||
- ✅ Auto-generated OpenAPI documentation
|
||||
|
||||
### Database Schema
|
||||
- ✅ `analyses` - Analysis metadata and status
|
||||
- ✅ `analysis_logs` - Tool calls and reasoning logs
|
||||
- ✅ `analysis_reports` - Generated reports by type
|
||||
- ✅ `api_keys` - Hashed authentication keys
|
||||
|
||||
### REST Endpoints
|
||||
|
||||
#### Analyses (`/api/v1/analyses`)
|
||||
- ✅ `POST /` - Create and start new analysis
|
||||
- ✅ `GET /` - List all analyses with filtering
|
||||
- ✅ `GET /{id}` - Get full analysis details
|
||||
- ✅ `GET /{id}/status` - Get current status
|
||||
- ✅ `GET /{id}/reports` - Get all reports
|
||||
- ✅ `GET /{id}/reports/{type}` - Get specific report
|
||||
- ✅ `GET /{id}/logs` - Get execution logs
|
||||
- ✅ `DELETE /{id}` - Cancel/delete analysis
|
||||
|
||||
#### Tickers (`/api/v1/tickers`)
|
||||
- ✅ `GET /` - List all tickers with counts
|
||||
- ✅ `GET /{ticker}/analyses` - Get analyses for ticker
|
||||
- ✅ `GET /{ticker}/latest` - Get latest analysis
|
||||
|
||||
#### Data (`/api/v1/data`)
|
||||
- ✅ `GET /cache` - List cached tickers
|
||||
- ✅ `GET /cache/{ticker}` - Get cached market data
|
||||
|
||||
#### WebSocket (`/api/v1/ws`)
|
||||
- ✅ `WS /analyses/{id}` - Real-time status streaming
|
||||
|
||||
### State Management
|
||||
- ✅ `AnalysisExecutor` class for managing parallel execution
|
||||
- ✅ ThreadPoolExecutor with configurable max workers
|
||||
- ✅ Status callbacks for WebSocket broadcasting
|
||||
- ✅ Real-time progress tracking
|
||||
- ✅ Graceful shutdown and cleanup
|
||||
|
||||
### Admin Tools
|
||||
- ✅ `cli_admin.py` - Command-line admin interface
|
||||
- ✅ `create-key` - Generate new API key
|
||||
- ✅ `list-keys` - Show all keys
|
||||
- ✅ `revoke-key` - Deactivate key
|
||||
- ✅ `activate-key` - Reactivate key
|
||||
- ✅ `init-database` - Initialize database
|
||||
|
||||
### Documentation
|
||||
- ✅ `API_QUICKSTART.md` - Quick start guide
|
||||
- ✅ `api/README.md` - Full API documentation
|
||||
- ✅ `api/example_client.py` - Working example client
|
||||
- ✅ Auto-generated Swagger UI at `/docs`
|
||||
- ✅ Auto-generated ReDoc at `/redoc`
|
||||
|
||||
## 🔧 Configuration
|
||||
|
||||
### Environment Variables
|
||||
- `API_DATABASE_URL` - Database connection (default: SQLite)
|
||||
- `MAX_CONCURRENT_ANALYSES` - Concurrent analysis limit (default: 4)
|
||||
- Standard TradingAgents env vars (OPENAI_API_KEY, etc.)
|
||||
|
||||
### Parallel Execution
|
||||
- Default: 4 concurrent analyses
|
||||
- Configurable via environment variable
|
||||
- Thread-safe database operations
|
||||
- Independent graph instances per analysis
|
||||
|
||||
## 📊 Data Flow
|
||||
|
||||
1. **Client** sends POST request to create analysis
|
||||
2. **API** creates database record, returns analysis_id
|
||||
3. **Executor** starts analysis in background thread
|
||||
4. **Graph** streams chunks during execution
|
||||
5. **State Manager** captures logs, reports, tool calls
|
||||
6. **Database** stores all data in real-time
|
||||
7. **WebSocket** broadcasts status updates
|
||||
8. **Client** polls or streams for results
|
||||
|
||||
## 🔐 Security
|
||||
|
||||
- ✅ API key authentication (bcrypt hashed)
|
||||
- ✅ Secure password hashing with passlib
|
||||
- ✅ CORS middleware (configurable origins)
|
||||
- ✅ Input validation with Pydantic
|
||||
- ✅ SQL injection prevention (SQLAlchemy ORM)
|
||||
|
||||
## 🧪 Testing
|
||||
|
||||
### Manual Testing
|
||||
```bash
|
||||
# 1. Initialize
|
||||
python -m api.cli_admin init-database
|
||||
python -m api.cli_admin create-key "Test Key"
|
||||
|
||||
# 2. Start API
|
||||
python -m api.main
|
||||
|
||||
# 3. Test endpoints
|
||||
curl http://localhost:8000/health
|
||||
curl -X POST http://localhost:8000/api/v1/analyses \
|
||||
-H "X-API-Key: YOUR_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"ticker": "AAPL", "analysis_date": "2025-10-21", "selected_analysts": ["market", "news"], "research_depth": 1}'
|
||||
|
||||
# 4. Run example client
|
||||
python -m api.example_client
|
||||
```
|
||||
|
||||
### Automated Testing
|
||||
Recommended test suite (not included in this implementation):
|
||||
- Unit tests for each endpoint
|
||||
- Integration tests for analysis workflow
|
||||
- WebSocket connection tests
|
||||
- Concurrent execution tests
|
||||
- Authentication tests
|
||||
|
||||
## 📈 Performance
|
||||
|
||||
- **Concurrency**: 4 parallel analyses by default (configurable)
|
||||
- **Database**: SQLite for development, PostgreSQL recommended for production
|
||||
- **Threading**: Thread-safe operations throughout
|
||||
- **WebSocket**: Efficient real-time updates
|
||||
- **Caching**: Leverages existing TradingAgents data cache
|
||||
|
||||
## 🚀 Deployment
|
||||
|
||||
### Development
|
||||
```bash
|
||||
python -m api.main
|
||||
# or
|
||||
./run_api.sh
|
||||
```
|
||||
|
||||
### Production Checklist
|
||||
- [ ] Switch to PostgreSQL
|
||||
- [ ] Configure specific CORS origins
|
||||
- [ ] Enable HTTPS/WSS (reverse proxy)
|
||||
- [ ] Add rate limiting
|
||||
- [ ] Set up monitoring/logging
|
||||
- [ ] Configure database backups
|
||||
- [ ] Set environment-specific configs
|
||||
|
||||
## 🎓 Usage Examples
|
||||
|
||||
### cURL
|
||||
```bash
|
||||
# Create analysis
|
||||
curl -X POST http://localhost:8000/api/v1/analyses \
|
||||
-H "X-API-Key: YOUR_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"ticker": "AAPL", "analysis_date": "2025-10-21"}'
|
||||
|
||||
# Get status
|
||||
curl http://localhost:8000/api/v1/analyses/{id}/status \
|
||||
-H "X-API-Key: YOUR_KEY"
|
||||
```
|
||||
|
||||
### Python
|
||||
```python
|
||||
from api.example_client import TradingAgentsAPIClient
|
||||
|
||||
client = TradingAgentsAPIClient("YOUR_API_KEY")
|
||||
analysis = await client.create_analysis("AAPL")
|
||||
await client.monitor_via_websocket(analysis["id"])
|
||||
```
|
||||
|
||||
### JavaScript
|
||||
```javascript
|
||||
const response = await fetch('http://localhost:8000/api/v1/analyses', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'X-API-Key': 'YOUR_KEY',
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
ticker: 'AAPL',
|
||||
analysis_date: '2025-10-21',
|
||||
}),
|
||||
});
|
||||
const analysis = await response.json();
|
||||
```
|
||||
|
||||
## ✨ Key Achievements
|
||||
|
||||
1. **Separation of Concerns**: API in separate `api/` directory
|
||||
2. **Persistent Storage**: SQLite with easy PostgreSQL migration
|
||||
3. **Real-time Updates**: WebSocket streaming of status
|
||||
4. **Parallel Processing**: ThreadPoolExecutor with configurable concurrency
|
||||
5. **Complete Logging**: All tool calls, reasoning, and reports saved
|
||||
6. **Security**: API key authentication with bcrypt
|
||||
7. **Documentation**: Comprehensive docs with examples
|
||||
8. **Admin Tools**: CLI for key management
|
||||
9. **Easy Setup**: Quick start in < 5 minutes
|
||||
|
||||
## 🔄 Integration Points
|
||||
|
||||
### With Existing TradingAgents
|
||||
- ✅ Uses `TradingAgentsGraph` for execution
|
||||
- ✅ Leverages `DEFAULT_CONFIG` for configuration
|
||||
- ✅ Integrates with asset detection
|
||||
- ✅ Uses existing data cache
|
||||
- ✅ Compatible with all LLM providers
|
||||
|
||||
### For Frontend Development
|
||||
- ✅ RESTful API design
|
||||
- ✅ WebSocket for real-time updates
|
||||
- ✅ CORS enabled for cross-origin requests
|
||||
- ✅ Comprehensive error responses
|
||||
- ✅ Pagination support
|
||||
- ✅ Filtering and search
|
||||
|
||||
## 📝 Next Steps (Optional Enhancements)
|
||||
|
||||
Future improvements could include:
|
||||
- User authentication and multi-tenancy
|
||||
- Rate limiting per API key
|
||||
- Analysis templates
|
||||
- Scheduled/recurring analyses
|
||||
- Email/webhook notifications
|
||||
- Analysis comparison tools
|
||||
- Performance metrics dashboard
|
||||
- Analysis export (PDF, CSV)
|
||||
- Bulk operations API
|
||||
- GraphQL endpoint
|
||||
|
||||
## 🎉 Conclusion
|
||||
|
||||
The FastAPI Trading Agents API is fully implemented and ready for use. It provides a robust, scalable foundation for frontend applications to interact with the TradingAgents multi-agent system.
|
||||
|
||||
### Getting Started
|
||||
1. Read `API_QUICKSTART.md` for setup instructions
|
||||
2. Check `api/README.md` for full documentation
|
||||
3. Run `api/example_client.py` to see it in action
|
||||
4. Start building your frontend!
|
||||
|
||||
### Support
|
||||
- API Documentation: http://localhost:8000/docs
|
||||
- Project README: `api/README.md`
|
||||
- Quick Start: `API_QUICKSTART.md`
|
||||
|
||||
|
|
@ -0,0 +1,300 @@
|
|||
# Trading Agents API - Quick Start Guide
|
||||
|
||||
This guide will get your FastAPI Trading Agents API up and running in minutes.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Python 3.10+
|
||||
- TradingAgents installed and configured
|
||||
- Required environment variables (OPENAI_API_KEY, ALPHA_VANTAGE_API_KEY, etc.)
|
||||
|
||||
## Setup (5 minutes)
|
||||
|
||||
### 1. Install API Dependencies
|
||||
|
||||
```bash
|
||||
cd TradingAgents
|
||||
pip install -r requirements-api.txt
|
||||
```
|
||||
|
||||
### 2. Initialize Database
|
||||
|
||||
```bash
|
||||
python -m api.cli_admin init-database
|
||||
```
|
||||
|
||||
### 3. Create Your First API Key
|
||||
|
||||
```bash
|
||||
python -m api.cli_admin create-key "Development Key"
|
||||
```
|
||||
|
||||
**IMPORTANT**: Save the API key that's displayed. You won't be able to see it again!
|
||||
|
||||
Example output:
|
||||
```
|
||||
✓ API Key created successfully!
|
||||
|
||||
Name: Development Key
|
||||
Created: 2025-10-21 14:30:00
|
||||
|
||||
API Key (save this, it won't be shown again):
|
||||
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
|
||||
Use this key in the X-API-Key header for all API requests.
|
||||
```
|
||||
|
||||
### 4. Start the API Server
|
||||
|
||||
```bash
|
||||
python -m api.main
|
||||
```
|
||||
|
||||
Or use the startup script:
|
||||
```bash
|
||||
./run_api.sh
|
||||
```
|
||||
|
||||
The API will start at: `http://localhost:8000`
|
||||
|
||||
## Test Your API (2 minutes)
|
||||
|
||||
### View API Documentation
|
||||
|
||||
Open your browser to:
|
||||
- **Swagger UI**: http://localhost:8000/docs
|
||||
- **ReDoc**: http://localhost:8000/redoc
|
||||
|
||||
### Create Your First Analysis
|
||||
|
||||
Replace `YOUR_API_KEY` with the key you created:
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/api/v1/analyses" \
|
||||
-H "X-API-Key: YOUR_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"ticker": "AAPL",
|
||||
"analysis_date": "2025-10-21",
|
||||
"selected_analysts": ["market", "news"],
|
||||
"research_depth": 1
|
||||
}'
|
||||
```
|
||||
|
||||
You'll get a response with an `analysis_id`:
|
||||
```json
|
||||
{
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"ticker": "AAPL",
|
||||
"status": "pending",
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
### Check Status
|
||||
|
||||
```bash
|
||||
curl "http://localhost:8000/api/v1/analyses/YOUR_ANALYSIS_ID/status" \
|
||||
-H "X-API-Key: YOUR_API_KEY"
|
||||
```
|
||||
|
||||
### Get Results
|
||||
|
||||
```bash
|
||||
curl "http://localhost:8000/api/v1/analyses/YOUR_ANALYSIS_ID" \
|
||||
-H "X-API-Key: YOUR_API_KEY"
|
||||
```
|
||||
|
||||
## Using the Python Client
|
||||
|
||||
### Install Additional Dependencies
|
||||
|
||||
```bash
|
||||
pip install httpx websockets
|
||||
```
|
||||
|
||||
### Run Example Client
|
||||
|
||||
Edit `api/example_client.py` and replace `YOUR_API_KEY`, then:
|
||||
|
||||
```bash
|
||||
python -m api.example_client
|
||||
```
|
||||
|
||||
This will:
|
||||
1. Create an analysis for AAPL
|
||||
2. Monitor it via WebSocket
|
||||
3. Display the results
|
||||
|
||||
## Next Steps
|
||||
|
||||
### Multiple Parallel Analyses
|
||||
|
||||
Create multiple analyses at once - they'll run in parallel:
|
||||
|
||||
```bash
|
||||
# Start 3 analyses
|
||||
for ticker in AAPL MSFT GOOGL; do
|
||||
curl -X POST "http://localhost:8000/api/v1/analyses" \
|
||||
-H "X-API-Key: YOUR_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{\"ticker\": \"$ticker\", \"analysis_date\": \"2025-10-21\", \"selected_analysts\": [\"market\", \"news\"], \"research_depth\": 1}"
|
||||
done
|
||||
```
|
||||
|
||||
### List All Analyses
|
||||
|
||||
```bash
|
||||
curl "http://localhost:8000/api/v1/analyses" \
|
||||
-H "X-API-Key: YOUR_API_KEY"
|
||||
```
|
||||
|
||||
### Get Analyses for Specific Ticker
|
||||
|
||||
```bash
|
||||
curl "http://localhost:8000/api/v1/tickers/AAPL/analyses" \
|
||||
-H "X-API-Key: YOUR_API_KEY"
|
||||
```
|
||||
|
||||
### View Cached Market Data
|
||||
|
||||
```bash
|
||||
# List all cached tickers
|
||||
curl "http://localhost:8000/api/v1/data/cache" \
|
||||
-H "X-API-Key: YOUR_API_KEY"
|
||||
|
||||
# Get data for specific ticker
|
||||
curl "http://localhost:8000/api/v1/data/cache/AAPL" \
|
||||
-H "X-API-Key: YOUR_API_KEY"
|
||||
```
|
||||
|
||||
## WebSocket Real-Time Monitoring
|
||||
|
||||
### JavaScript Example
|
||||
|
||||
```javascript
|
||||
const ws = new WebSocket('ws://localhost:8000/api/v1/ws/analyses/YOUR_ANALYSIS_ID');
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
const update = JSON.parse(event.data);
|
||||
console.log(`Status: ${update.status}, Progress: ${update.progress_percentage}%`);
|
||||
|
||||
if (update.status === 'completed') {
|
||||
console.log('Analysis finished!');
|
||||
ws.close();
|
||||
}
|
||||
};
|
||||
|
||||
ws.onerror = (error) => console.error('WebSocket error:', error);
|
||||
```
|
||||
|
||||
### Python Example
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
|
||||
async def monitor_analysis(analysis_id, api_key):
|
||||
uri = f"ws://localhost:8000/api/v1/ws/analyses/{analysis_id}"
|
||||
|
||||
async with websockets.connect(uri) as websocket:
|
||||
while True:
|
||||
message = await websocket.recv()
|
||||
data = json.loads(message)
|
||||
|
||||
print(f"Status: {data['status']}, Progress: {data['progress_percentage']}%")
|
||||
|
||||
if data['status'] in ['completed', 'failed', 'cancelled']:
|
||||
break
|
||||
|
||||
asyncio.run(monitor_analysis('YOUR_ANALYSIS_ID', 'YOUR_API_KEY'))
|
||||
```
|
||||
|
||||
## API Key Management
|
||||
|
||||
### List All Keys
|
||||
|
||||
```bash
|
||||
python -m api.cli_admin list-keys
|
||||
```
|
||||
|
||||
### Create Additional Key
|
||||
|
||||
```bash
|
||||
python -m api.cli_admin create-key "Frontend App"
|
||||
```
|
||||
|
||||
### Revoke a Key
|
||||
|
||||
```bash
|
||||
python -m api.cli_admin revoke-key 1
|
||||
```
|
||||
|
||||
### Activate a Key
|
||||
|
||||
```bash
|
||||
python -m api.cli_admin activate-key 1
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
Set these before starting the API:
|
||||
|
||||
```bash
|
||||
# Maximum concurrent analyses (default: 4)
|
||||
export MAX_CONCURRENT_ANALYSES=8
|
||||
|
||||
# Database URL (default: SQLite in current directory)
|
||||
export API_DATABASE_URL="sqlite:///./api_database.db"
|
||||
|
||||
# Or use PostgreSQL for production
|
||||
export API_DATABASE_URL="postgresql://user:pass@localhost/trading_agents"
|
||||
|
||||
# Standard TradingAgents config
|
||||
export OPENAI_API_KEY="your-key"
|
||||
export ALPHA_VANTAGE_API_KEY="your-key"
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Invalid or inactive API key"
|
||||
- Make sure you're using the exact key that was displayed when you created it
|
||||
- Check that the key hasn't been revoked: `python -m api.cli_admin list-keys`
|
||||
|
||||
### "Analysis not found"
|
||||
- The analysis ID must be exactly as returned from the create endpoint
|
||||
- Check available analyses: `curl "http://localhost:8000/api/v1/analyses" -H "X-API-Key: YOUR_KEY"`
|
||||
|
||||
### Database locked
|
||||
- SQLite has limited concurrency
|
||||
- Reduce `MAX_CONCURRENT_ANALYSES` to 2-3
|
||||
- Or switch to PostgreSQL
|
||||
|
||||
### Import errors
|
||||
- Make sure you're in the TradingAgents directory
|
||||
- Run with: `python -m api.main` (not `python api/main.py`)
|
||||
|
||||
## Full API Reference
|
||||
|
||||
See `api/README.md` for complete documentation of all endpoints and features.
|
||||
|
||||
## Production Deployment
|
||||
|
||||
For production use, see the deployment section in `api/README.md`. Key points:
|
||||
|
||||
1. Use PostgreSQL instead of SQLite
|
||||
2. Configure CORS for your frontend domain
|
||||
3. Use HTTPS/WSS with reverse proxy
|
||||
4. Add rate limiting
|
||||
5. Set up monitoring and logging
|
||||
|
||||
## Support
|
||||
|
||||
For issues or questions:
|
||||
- Check the full documentation: `api/README.md`
|
||||
- Review the main TradingAgents README
|
||||
- Check the interactive docs: http://localhost:8000/docs
|
||||
|
||||
|
|
@ -0,0 +1,257 @@
|
|||
# Trading Agents REST API
|
||||
|
||||
FastAPI-based REST API for managing multi-agent trading analyses with real-time WebSocket support.
|
||||
|
||||
## Features
|
||||
|
||||
- **REST API** for creating, monitoring, and managing trading analyses
|
||||
- **WebSocket** support for real-time status updates
|
||||
- **Parallel execution** of multiple analyses (configurable concurrency)
|
||||
- **SQLite database** for persistent storage
|
||||
- **API key authentication** for secure access
|
||||
- **Complete analysis history** with logs and reports
|
||||
|
||||
## Installation
|
||||
|
||||
1. Install API dependencies:
|
||||
```bash
|
||||
cd TradingAgents
|
||||
pip install -r requirements-api.txt
|
||||
```
|
||||
|
||||
2. Initialize the database and create an API key:
|
||||
```bash
|
||||
python -m api.cli_admin init-database
|
||||
python -m api.cli_admin create-key "My First Key"
|
||||
```
|
||||
|
||||
Save the generated API key - you'll need it for all API requests.
|
||||
|
||||
## Running the API
|
||||
|
||||
Start the API server:
|
||||
```bash
|
||||
python -m api.main
|
||||
```
|
||||
|
||||
Or with uvicorn directly:
|
||||
```bash
|
||||
uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
The API will be available at `http://localhost:8000`
|
||||
|
||||
## API Documentation
|
||||
|
||||
Once the server is running, visit:
|
||||
- **Swagger UI**: http://localhost:8000/docs
|
||||
- **ReDoc**: http://localhost:8000/redoc
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Create an Analysis
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/api/v1/analyses" \
|
||||
-H "X-API-Key: YOUR_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"ticker": "AAPL",
|
||||
"analysis_date": "2025-10-21",
|
||||
"selected_analysts": ["market", "news", "social", "fundamentals"],
|
||||
"research_depth": 1
|
||||
}'
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"id": "uuid-here",
|
||||
"ticker": "AAPL",
|
||||
"analysis_date": "2025-10-21",
|
||||
"status": "pending",
|
||||
"progress_percentage": 0,
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Monitor via WebSocket
|
||||
|
||||
```javascript
|
||||
const ws = new WebSocket('ws://localhost:8000/api/v1/ws/analyses/{analysis_id}');
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
const update = JSON.parse(event.data);
|
||||
console.log(`Status: ${update.status}, Progress: ${update.progress_percentage}%`);
|
||||
};
|
||||
```
|
||||
|
||||
### 3. Get Analysis Results
|
||||
|
||||
```bash
|
||||
curl "http://localhost:8000/api/v1/analyses/{analysis_id}" \
|
||||
-H "X-API-Key: YOUR_API_KEY"
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Analyses
|
||||
|
||||
- `POST /api/v1/analyses` - Create and start new analysis
|
||||
- `GET /api/v1/analyses` - List all analyses (with filtering)
|
||||
- `GET /api/v1/analyses/{id}` - Get full analysis details
|
||||
- `GET /api/v1/analyses/{id}/status` - Get current status
|
||||
- `GET /api/v1/analyses/{id}/reports` - Get all reports
|
||||
- `GET /api/v1/analyses/{id}/reports/{type}` - Get specific report
|
||||
- `GET /api/v1/analyses/{id}/logs` - Get execution logs
|
||||
- `DELETE /api/v1/analyses/{id}` - Cancel/delete analysis
|
||||
|
||||
### Tickers
|
||||
|
||||
- `GET /api/v1/tickers` - List all tickers with analysis counts
|
||||
- `GET /api/v1/tickers/{ticker}/analyses` - Get all analyses for ticker
|
||||
- `GET /api/v1/tickers/{ticker}/latest` - Get latest analysis for ticker
|
||||
|
||||
### Data
|
||||
|
||||
- `GET /api/v1/data/cache` - List cached ticker data
|
||||
- `GET /api/v1/data/cache/{ticker}` - Get cached market data
|
||||
|
||||
### WebSocket
|
||||
|
||||
- `WS /api/v1/ws/analyses/{id}` - Real-time status updates
|
||||
|
||||
## API Key Management
|
||||
|
||||
### Create API Key
|
||||
```bash
|
||||
python -m api.cli_admin create-key "Description"
|
||||
```
|
||||
|
||||
### List API Keys
|
||||
```bash
|
||||
python -m api.cli_admin list-keys
|
||||
```
|
||||
|
||||
### Revoke API Key
|
||||
```bash
|
||||
python -m api.cli_admin revoke-key <key_id>
|
||||
```
|
||||
|
||||
### Activate API Key
|
||||
```bash
|
||||
python -m api.cli_admin activate-key <key_id>
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
- `API_DATABASE_URL` - Database connection string (default: `sqlite:///./api_database.db`)
|
||||
- `MAX_CONCURRENT_ANALYSES` - Maximum parallel analyses (default: `4`)
|
||||
- Standard TradingAgents config (LLM providers, API keys, etc.)
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
api/
|
||||
├── main.py # FastAPI application
|
||||
├── database.py # SQLAlchemy models
|
||||
├── auth.py # API key authentication
|
||||
├── state_manager.py # Analysis execution manager
|
||||
├── models/ # Pydantic request/response models
|
||||
├── endpoints/ # REST endpoint handlers
|
||||
└── websockets/ # WebSocket handlers
|
||||
```
|
||||
|
||||
## Database Schema
|
||||
|
||||
- **analyses** - Analysis metadata and status
|
||||
- **analysis_logs** - Execution logs (tool calls, reasoning)
|
||||
- **analysis_reports** - Generated reports (by type)
|
||||
- **api_keys** - Authentication keys
|
||||
|
||||
## Parallel Execution
|
||||
|
||||
The API uses a `ThreadPoolExecutor` to run multiple analyses concurrently:
|
||||
|
||||
- Default: 4 concurrent analyses
|
||||
- Configurable via `MAX_CONCURRENT_ANALYSES` env var
|
||||
- Each analysis runs in its own thread
|
||||
- Database writes are thread-safe
|
||||
- Cached data is read-only (thread-safe)
|
||||
|
||||
## Example Frontend Integration
|
||||
|
||||
```javascript
|
||||
class TradingAnalysisClient {
|
||||
constructor(apiKey, baseURL = 'http://localhost:8000') {
|
||||
this.apiKey = apiKey;
|
||||
this.baseURL = baseURL;
|
||||
}
|
||||
|
||||
async createAnalysis(ticker, date, analysts = ['market', 'news']) {
|
||||
const response = await fetch(`${this.baseURL}/api/v1/analyses`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'X-API-Key': this.apiKey,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
ticker,
|
||||
analysis_date: date,
|
||||
selected_analysts: analysts,
|
||||
research_depth: 1,
|
||||
}),
|
||||
});
|
||||
return await response.json();
|
||||
}
|
||||
|
||||
connectWebSocket(analysisId, onUpdate) {
|
||||
const ws = new WebSocket(`ws://localhost:8000/api/v1/ws/analyses/${analysisId}`);
|
||||
ws.onmessage = (event) => onUpdate(JSON.parse(event.data));
|
||||
return ws;
|
||||
}
|
||||
|
||||
async getAnalysis(analysisId) {
|
||||
const response = await fetch(
|
||||
`${this.baseURL}/api/v1/analyses/${analysisId}`,
|
||||
{ headers: { 'X-API-Key': this.apiKey } }
|
||||
);
|
||||
return await response.json();
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Database locked error
|
||||
SQLite has limited concurrent write support. If you get database locked errors:
|
||||
- Reduce `MAX_CONCURRENT_ANALYSES`
|
||||
- Or switch to PostgreSQL by changing `API_DATABASE_URL`
|
||||
|
||||
### Import errors
|
||||
Make sure you're running from the TradingAgents root directory:
|
||||
```bash
|
||||
cd TradingAgents
|
||||
python -m api.main
|
||||
```
|
||||
|
||||
### WebSocket connection refused
|
||||
Check that the server is running and CORS is properly configured for your frontend origin.
|
||||
|
||||
## Production Deployment
|
||||
|
||||
For production use:
|
||||
|
||||
1. **Use PostgreSQL** instead of SQLite
|
||||
2. **Secure CORS** - Set specific allowed origins in `main.py`
|
||||
3. **HTTPS/WSS** - Use reverse proxy (nginx) with SSL
|
||||
4. **Monitoring** - Add logging, metrics, and health checks
|
||||
5. **Rate limiting** - Add rate limiting middleware
|
||||
6. **Backup** - Regular database backups
|
||||
|
||||
## License
|
||||
|
||||
Same as TradingAgents/Litadel main project.
|
||||
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
# Quick Start - Trading Agents API
|
||||
|
||||
## 1. Install Dependencies (if not already done)
|
||||
|
||||
```bash
|
||||
cd TradingAgents
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## 2. Initialize Database & Create API Key
|
||||
|
||||
```bash
|
||||
# Initialize the database
|
||||
python -m api.cli_admin init-database
|
||||
|
||||
# Create your first API key
|
||||
python -m api.cli_admin create-key "My Development Key"
|
||||
```
|
||||
|
||||
**IMPORTANT**: Save the API key that's displayed! You'll need it for all requests.
|
||||
|
||||
## 3. Start the API
|
||||
|
||||
```bash
|
||||
python -m api.main
|
||||
```
|
||||
|
||||
Or use the startup script:
|
||||
```bash
|
||||
./run_api.sh
|
||||
```
|
||||
|
||||
The API will start at: **http://localhost:8001**
|
||||
|
||||
## 4. Test It
|
||||
|
||||
Open your browser to see the interactive API documentation:
|
||||
- **Swagger UI**: http://localhost:8001/docs
|
||||
- **ReDoc**: http://localhost:8001/redoc
|
||||
|
||||
## 5. Create Your First Analysis
|
||||
|
||||
Using curl (replace `YOUR_API_KEY` with the key from step 2):
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8001/api/v1/analyses" \
|
||||
-H "X-API-Key: YOUR_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"ticker": "AAPL",
|
||||
"analysis_date": "2025-10-21",
|
||||
"selected_analysts": ["market", "news"],
|
||||
"research_depth": 1
|
||||
}'
|
||||
```
|
||||
|
||||
You'll get back an `analysis_id`. Use it to check status:
|
||||
|
||||
```bash
|
||||
curl "http://localhost:8001/api/v1/analyses/YOUR_ANALYSIS_ID/status" \
|
||||
-H "X-API-Key: YOUR_API_KEY"
|
||||
```
|
||||
|
||||
## Configuration (Optional)
|
||||
|
||||
Set environment variables before starting:
|
||||
|
||||
```bash
|
||||
# Maximum concurrent analyses (default: 4)
|
||||
export MAX_CONCURRENT_ANALYSES=8
|
||||
|
||||
# Your LLM API keys (if not already set)
|
||||
export OPENAI_API_KEY="your-key"
|
||||
export ALPHA_VANTAGE_API_KEY="your-key"
|
||||
|
||||
# Then start the API
|
||||
python -m api.main
|
||||
```
|
||||
|
||||
## Common Commands
|
||||
|
||||
```bash
|
||||
# List all API keys
|
||||
python -m api.cli_admin list-keys
|
||||
|
||||
# Create a new API key
|
||||
python -m api.cli_admin create-key "Frontend App"
|
||||
|
||||
# Revoke an API key (use ID from list-keys)
|
||||
python -m api.cli_admin revoke-key 1
|
||||
```
|
||||
|
||||
## Full Documentation
|
||||
|
||||
- Quick Start: `API_QUICKSTART.md`
|
||||
- Full API Docs: `api/README.md`
|
||||
- Implementation Details: `API_IMPLEMENTATION_SUMMARY.md`
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**"Invalid or inactive API key"**
|
||||
- Make sure you're using the exact key from step 2
|
||||
- Check: `python -m api.cli_admin list-keys`
|
||||
|
||||
**Import errors**
|
||||
- Make sure you're in the TradingAgents directory
|
||||
- Use: `python -m api.main` (not `python api/main.py`)
|
||||
|
||||
**Port 8001 already in use**
|
||||
- Change port: `API_PORT=8002 python -m api.main`
|
||||
- Or: `python -m uvicorn api.main:app --port 8002`
|
||||
|
||||
That's it! Your API is ready to use. 🚀
|
||||
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
"""FastAPI Trading Agents API."""
|
||||
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
"""API key authentication."""
|
||||
|
||||
import secrets
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, Security, status
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
# Suppress the known passlib/bcrypt compatibility warning
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", message=".*trapped.*")
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.database import APIKey, get_db
|
||||
|
||||
# Password hashing context
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# API key header scheme
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=True)
|
||||
|
||||
|
||||
def hash_api_key(api_key: str) -> str:
|
||||
"""Hash an API key."""
|
||||
return pwd_context.hash(api_key)
|
||||
|
||||
|
||||
def verify_api_key_hash(api_key: str, key_hash: str) -> bool:
|
||||
"""Verify an API key against its hash."""
|
||||
return pwd_context.verify(api_key, key_hash)
|
||||
|
||||
|
||||
def generate_api_key() -> str:
|
||||
"""Generate a new random API key."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def create_api_key(db: Session, name: str) -> tuple[str, APIKey]:
|
||||
"""
|
||||
Create a new API key in the database.
|
||||
|
||||
Returns:
|
||||
tuple: (plain_key, db_record)
|
||||
"""
|
||||
plain_key = generate_api_key()
|
||||
key_hash = hash_api_key(plain_key)
|
||||
|
||||
db_key = APIKey(key_hash=key_hash, name=name, is_active=True)
|
||||
db.add(db_key)
|
||||
db.commit()
|
||||
db.refresh(db_key)
|
||||
|
||||
return plain_key, db_key
|
||||
|
||||
|
||||
def get_api_key_by_hash(db: Session, key_hash: str) -> Optional[APIKey]:
|
||||
"""Get an API key record by its hash."""
|
||||
return db.query(APIKey).filter(APIKey.key_hash == key_hash).first()
|
||||
|
||||
|
||||
def verify_api_key_from_db(db: Session, api_key: str) -> Optional[APIKey]:
|
||||
"""
|
||||
Verify an API key against the database.
|
||||
|
||||
Returns:
|
||||
APIKey record if valid and active, None otherwise
|
||||
"""
|
||||
# Get all active keys
|
||||
active_keys = db.query(APIKey).filter(APIKey.is_active == True).all()
|
||||
|
||||
# Check each one
|
||||
for key_record in active_keys:
|
||||
if verify_api_key_hash(api_key, key_record.key_hash):
|
||||
return key_record
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_current_api_key(
|
||||
api_key: str = Security(api_key_header),
|
||||
db: Session = Depends(get_db),
|
||||
) -> APIKey:
|
||||
"""
|
||||
Dependency to verify API key and return the key record.
|
||||
|
||||
Raises:
|
||||
HTTPException: If API key is invalid or inactive
|
||||
"""
|
||||
key_record = verify_api_key_from_db(db, api_key)
|
||||
|
||||
if not key_record:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or inactive API key",
|
||||
)
|
||||
|
||||
return key_record
|
||||
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
"""Admin CLI tool for API key management."""
|
||||
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from api.auth import create_api_key
|
||||
from api.database import APIKey, SessionLocal, init_db
|
||||
|
||||
app = typer.Typer(
|
||||
name="api-admin",
|
||||
help="Trading Agents API Administration Tool",
|
||||
)
|
||||
console = Console()
|
||||
|
||||
|
||||
@app.command()
|
||||
def create_key(name: str = typer.Argument(..., help="Name/description for the API key")):
|
||||
"""Create a new API key."""
|
||||
# Initialize database
|
||||
init_db()
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
plain_key, db_key = create_api_key(db, name)
|
||||
|
||||
console.print("\n[green]✓ API Key created successfully![/green]\n")
|
||||
console.print(f"[bold]Name:[/bold] {db_key.name}")
|
||||
console.print(f"[bold]Created:[/bold] {db_key.created_at}")
|
||||
console.print(f"\n[bold yellow]API Key (save this, it won't be shown again):[/bold yellow]")
|
||||
console.print(f"[cyan]{plain_key}[/cyan]\n")
|
||||
console.print("[dim]Use this key in the X-API-Key header for all API requests.[/dim]\n")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Error creating API key: {e}[/red]")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@app.command()
|
||||
def list_keys():
|
||||
"""List all API keys."""
|
||||
init_db()
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
keys = db.query(APIKey).order_by(APIKey.created_at.desc()).all()
|
||||
|
||||
if not keys:
|
||||
console.print("[yellow]No API keys found.[/yellow]")
|
||||
return
|
||||
|
||||
table = Table(title="API Keys")
|
||||
table.add_column("ID", style="cyan")
|
||||
table.add_column("Name", style="green")
|
||||
table.add_column("Created", style="blue")
|
||||
table.add_column("Status", style="magenta")
|
||||
|
||||
for key in keys:
|
||||
status = "Active" if key.is_active else "Revoked"
|
||||
status_color = "green" if key.is_active else "red"
|
||||
table.add_row(
|
||||
str(key.id),
|
||||
key.name,
|
||||
key.created_at.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
f"[{status_color}]{status}[/{status_color}]",
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@app.command()
|
||||
def revoke_key(key_id: int = typer.Argument(..., help="API key ID to revoke")):
|
||||
"""Revoke an API key."""
|
||||
init_db()
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
key = db.query(APIKey).filter(APIKey.id == key_id).first()
|
||||
|
||||
if not key:
|
||||
console.print(f"[red]API key with ID {key_id} not found.[/red]")
|
||||
sys.exit(1)
|
||||
|
||||
if not key.is_active:
|
||||
console.print(f"[yellow]API key '{key.name}' is already revoked.[/yellow]")
|
||||
return
|
||||
|
||||
key.is_active = False
|
||||
db.commit()
|
||||
|
||||
console.print(f"[green]✓ API key '{key.name}' has been revoked.[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Error revoking API key: {e}[/red]")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@app.command()
|
||||
def activate_key(key_id: int = typer.Argument(..., help="API key ID to activate")):
|
||||
"""Activate a revoked API key."""
|
||||
init_db()
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
key = db.query(APIKey).filter(APIKey.id == key_id).first()
|
||||
|
||||
if not key:
|
||||
console.print(f"[red]API key with ID {key_id} not found.[/red]")
|
||||
sys.exit(1)
|
||||
|
||||
if key.is_active:
|
||||
console.print(f"[yellow]API key '{key.name}' is already active.[/yellow]")
|
||||
return
|
||||
|
||||
key.is_active = True
|
||||
db.commit()
|
||||
|
||||
console.print(f"[green]✓ API key '{key.name}' has been activated.[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Error activating API key: {e}[/red]")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@app.command()
|
||||
def init_database():
|
||||
"""Initialize the database (create tables)."""
|
||||
try:
|
||||
init_db()
|
||||
console.print("[green]✓ Database initialized successfully![/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Error initializing database: {e}[/red]")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
"""Database models and connection management."""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Generator
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
create_engine,
|
||||
ForeignKey,
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session, sessionmaker, relationship
|
||||
|
||||
# Database file location
|
||||
DATABASE_URL = os.getenv("API_DATABASE_URL", "sqlite:///./api_database.db")
|
||||
|
||||
# Create engine
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
connect_args={"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {},
|
||||
)
|
||||
|
||||
# Session factory
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# Base class for models
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class Analysis(Base):
|
||||
"""Analysis metadata and configuration."""
|
||||
|
||||
__tablename__ = "analyses"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
ticker = Column(String, index=True, nullable=False)
|
||||
analysis_date = Column(String, nullable=False)
|
||||
status = Column(
|
||||
String, nullable=False, default="pending"
|
||||
) # pending, running, completed, failed, cancelled
|
||||
config_json = Column(Text, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
error_message = Column(Text, nullable=True)
|
||||
progress_percentage = Column(Integer, default=0)
|
||||
current_agent = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
logs = relationship("AnalysisLog", back_populates="analysis", cascade="all, delete-orphan")
|
||||
reports = relationship("AnalysisReport", back_populates="analysis", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class AnalysisLog(Base):
|
||||
"""Log entries from analysis execution."""
|
||||
|
||||
__tablename__ = "analysis_logs"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
analysis_id = Column(String, ForeignKey("analyses.id"), nullable=False)
|
||||
timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
log_type = Column(String, nullable=False) # Tool Call, Reasoning, System
|
||||
content = Column(Text, nullable=False)
|
||||
|
||||
# Relationships
|
||||
analysis = relationship("Analysis", back_populates="logs")
|
||||
|
||||
|
||||
class AnalysisReport(Base):
|
||||
"""Report sections from analysis."""
|
||||
|
||||
__tablename__ = "analysis_reports"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
analysis_id = Column(String, ForeignKey("analyses.id"), nullable=False)
|
||||
report_type = Column(String, nullable=False) # market_report, news_report, etc.
|
||||
content = Column(Text, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
# Relationships
|
||||
analysis = relationship("Analysis", back_populates="reports")
|
||||
|
||||
|
||||
class APIKey(Base):
|
||||
"""API keys for authentication."""
|
||||
|
||||
__tablename__ = "api_keys"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
key_hash = Column(String, unique=True, nullable=False, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
|
||||
|
||||
def init_db():
|
||||
"""Initialize database and create all tables."""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""Dependency for getting database sessions."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
"""API endpoint routers."""
|
||||
|
||||
|
|
@ -0,0 +1,356 @@
|
|||
"""Analysis CRUD endpoints."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.auth import APIKey, get_current_api_key
|
||||
from api.database import Analysis, AnalysisLog, AnalysisReport, get_db
|
||||
from api.models import (
|
||||
AnalysisResponse,
|
||||
AnalysisStatusResponse,
|
||||
AnalysisSummary,
|
||||
CreateAnalysisRequest,
|
||||
LogEntry,
|
||||
ReportResponse,
|
||||
)
|
||||
from api.state_manager import get_executor
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
router = APIRouter(prefix="/api/v1/analyses", tags=["analyses"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("", response_model=AnalysisResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_analysis(
|
||||
request: CreateAnalysisRequest,
|
||||
db: Session = Depends(get_db),
|
||||
api_key: APIKey = Depends(get_current_api_key),
|
||||
):
|
||||
"""Create and start a new analysis."""
|
||||
# Generate analysis ID
|
||||
analysis_id = str(uuid.uuid4())
|
||||
|
||||
# Build configuration
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["max_debate_rounds"] = request.research_depth
|
||||
config["max_risk_discuss_rounds"] = request.research_depth
|
||||
|
||||
if request.llm_provider:
|
||||
config["llm_provider"] = request.llm_provider
|
||||
if request.backend_url:
|
||||
config["backend_url"] = request.backend_url
|
||||
if request.quick_think_llm:
|
||||
config["quick_think_llm"] = request.quick_think_llm
|
||||
if request.deep_think_llm:
|
||||
config["deep_think_llm"] = request.deep_think_llm
|
||||
|
||||
# Auto-detect asset class
|
||||
from cli.asset_detection import detect_asset_class
|
||||
asset_class = detect_asset_class(request.ticker)
|
||||
config["asset_class"] = asset_class
|
||||
|
||||
# Filter out fundamentals for commodities/crypto
|
||||
selected_analysts = request.selected_analysts
|
||||
if asset_class in ["commodity", "crypto"] and "fundamentals" in selected_analysts:
|
||||
selected_analysts = [a for a in selected_analysts if a != "fundamentals"]
|
||||
|
||||
# Create database record
|
||||
analysis = Analysis(
|
||||
id=analysis_id,
|
||||
ticker=request.ticker,
|
||||
analysis_date=request.analysis_date,
|
||||
status="pending",
|
||||
config_json=json.dumps(config),
|
||||
progress_percentage=0,
|
||||
)
|
||||
db.add(analysis)
|
||||
db.commit()
|
||||
db.refresh(analysis)
|
||||
|
||||
# Start analysis in background
|
||||
logger.info(f"Creating analysis {analysis_id} for {request.ticker}")
|
||||
executor = get_executor()
|
||||
try:
|
||||
executor.start_analysis(
|
||||
analysis_id=analysis_id,
|
||||
ticker=request.ticker,
|
||||
analysis_date=request.analysis_date,
|
||||
selected_analysts=selected_analysts,
|
||||
config=config,
|
||||
)
|
||||
logger.info(f"Analysis {analysis_id} started successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start analysis {analysis_id}: {str(e)}")
|
||||
# Update status to failed
|
||||
analysis.status = "failed"
|
||||
analysis.error_message = str(e)
|
||||
db.commit()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to start analysis: {str(e)}",
|
||||
)
|
||||
|
||||
return AnalysisResponse(
|
||||
id=analysis.id,
|
||||
ticker=analysis.ticker,
|
||||
analysis_date=analysis.analysis_date,
|
||||
status=analysis.status,
|
||||
config=config,
|
||||
reports=[],
|
||||
progress_percentage=analysis.progress_percentage,
|
||||
current_agent=analysis.current_agent,
|
||||
created_at=analysis.created_at,
|
||||
updated_at=analysis.updated_at,
|
||||
completed_at=analysis.completed_at,
|
||||
error_message=analysis.error_message,
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=List[AnalysisSummary])
|
||||
async def list_analyses(
|
||||
ticker: Optional[str] = Query(None, description="Filter by ticker"),
|
||||
status: Optional[str] = Query(None, description="Filter by status"),
|
||||
date_from: Optional[str] = Query(None, description="Filter by date (from)"),
|
||||
date_to: Optional[str] = Query(None, description="Filter by date (to)"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="Max results"),
|
||||
offset: int = Query(0, ge=0, description="Offset for pagination"),
|
||||
db: Session = Depends(get_db),
|
||||
api_key: APIKey = Depends(get_current_api_key),
|
||||
):
|
||||
"""List all analyses with optional filtering."""
|
||||
query = db.query(Analysis)
|
||||
|
||||
if ticker:
|
||||
query = query.filter(Analysis.ticker == ticker.upper())
|
||||
if status:
|
||||
query = query.filter(Analysis.status == status)
|
||||
if date_from:
|
||||
query = query.filter(Analysis.analysis_date >= date_from)
|
||||
if date_to:
|
||||
query = query.filter(Analysis.analysis_date <= date_to)
|
||||
|
||||
# Order by created_at descending
|
||||
query = query.order_by(Analysis.created_at.desc())
|
||||
|
||||
# Apply pagination
|
||||
analyses = query.offset(offset).limit(limit).all()
|
||||
|
||||
return [
|
||||
AnalysisSummary(
|
||||
id=a.id,
|
||||
ticker=a.ticker,
|
||||
analysis_date=a.analysis_date,
|
||||
status=a.status,
|
||||
created_at=a.created_at,
|
||||
completed_at=a.completed_at,
|
||||
error_message=a.error_message,
|
||||
)
|
||||
for a in analyses
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{analysis_id}", response_model=AnalysisResponse)
|
||||
async def get_analysis(
|
||||
analysis_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
api_key: APIKey = Depends(get_current_api_key),
|
||||
):
|
||||
"""Get full analysis details."""
|
||||
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
||||
|
||||
if not analysis:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Analysis {analysis_id} not found",
|
||||
)
|
||||
|
||||
# Get reports
|
||||
reports = (
|
||||
db.query(AnalysisReport)
|
||||
.filter(AnalysisReport.analysis_id == analysis_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
return AnalysisResponse(
|
||||
id=analysis.id,
|
||||
ticker=analysis.ticker,
|
||||
analysis_date=analysis.analysis_date,
|
||||
status=analysis.status,
|
||||
config=json.loads(analysis.config_json),
|
||||
reports=[
|
||||
ReportResponse(
|
||||
report_type=r.report_type,
|
||||
content=r.content,
|
||||
created_at=r.created_at,
|
||||
)
|
||||
for r in reports
|
||||
],
|
||||
progress_percentage=analysis.progress_percentage,
|
||||
current_agent=analysis.current_agent,
|
||||
created_at=analysis.created_at,
|
||||
updated_at=analysis.updated_at,
|
||||
completed_at=analysis.completed_at,
|
||||
error_message=analysis.error_message,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{analysis_id}/status", response_model=AnalysisStatusResponse)
|
||||
async def get_analysis_status(
|
||||
analysis_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
api_key: APIKey = Depends(get_current_api_key),
|
||||
):
|
||||
"""Get current analysis status."""
|
||||
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
||||
|
||||
if not analysis:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Analysis {analysis_id} not found",
|
||||
)
|
||||
|
||||
return AnalysisStatusResponse(
|
||||
id=analysis.id,
|
||||
status=analysis.status,
|
||||
progress_percentage=analysis.progress_percentage,
|
||||
current_agent=analysis.current_agent,
|
||||
updated_at=analysis.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{analysis_id}/reports", response_model=List[ReportResponse])
|
||||
async def get_analysis_reports(
|
||||
analysis_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
api_key: APIKey = Depends(get_current_api_key),
|
||||
):
|
||||
"""Get all reports for an analysis."""
|
||||
# Check if analysis exists
|
||||
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
||||
if not analysis:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Analysis {analysis_id} not found",
|
||||
)
|
||||
|
||||
reports = (
|
||||
db.query(AnalysisReport)
|
||||
.filter(AnalysisReport.analysis_id == analysis_id)
|
||||
.order_by(AnalysisReport.created_at)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
ReportResponse(
|
||||
report_type=r.report_type,
|
||||
content=r.content,
|
||||
created_at=r.created_at,
|
||||
)
|
||||
for r in reports
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{analysis_id}/reports/{report_type}", response_model=ReportResponse)
|
||||
async def get_analysis_report(
|
||||
analysis_id: str,
|
||||
report_type: str,
|
||||
db: Session = Depends(get_db),
|
||||
api_key: APIKey = Depends(get_current_api_key),
|
||||
):
|
||||
"""Get a specific report for an analysis."""
|
||||
report = (
|
||||
db.query(AnalysisReport)
|
||||
.filter(
|
||||
AnalysisReport.analysis_id == analysis_id,
|
||||
AnalysisReport.report_type == report_type,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not report:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Report {report_type} not found for analysis {analysis_id}",
|
||||
)
|
||||
|
||||
return ReportResponse(
|
||||
report_type=report.report_type,
|
||||
content=report.content,
|
||||
created_at=report.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{analysis_id}/logs", response_model=List[LogEntry])
|
||||
async def get_analysis_logs(
|
||||
analysis_id: str,
|
||||
log_type: Optional[str] = Query(None, description="Filter by log type"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="Max results"),
|
||||
offset: int = Query(0, ge=0, description="Offset for pagination"),
|
||||
db: Session = Depends(get_db),
|
||||
api_key: APIKey = Depends(get_current_api_key),
|
||||
):
|
||||
"""Get execution logs for an analysis."""
|
||||
# Check if analysis exists
|
||||
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
||||
if not analysis:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Analysis {analysis_id} not found",
|
||||
)
|
||||
|
||||
query = db.query(AnalysisLog).filter(AnalysisLog.analysis_id == analysis_id)
|
||||
|
||||
if log_type:
|
||||
query = query.filter(AnalysisLog.log_type == log_type)
|
||||
|
||||
logs = query.order_by(AnalysisLog.timestamp).offset(offset).limit(limit).all()
|
||||
|
||||
return [
|
||||
LogEntry(
|
||||
timestamp=log.timestamp,
|
||||
log_type=log.log_type,
|
||||
content=log.content,
|
||||
)
|
||||
for log in logs
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/{analysis_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_analysis(
|
||||
analysis_id: str,
|
||||
permanent: bool = Query(False, description="Permanently delete from database"),
|
||||
db: Session = Depends(get_db),
|
||||
api_key: APIKey = Depends(get_current_api_key),
|
||||
):
|
||||
"""Cancel and/or delete an analysis."""
|
||||
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
||||
|
||||
if not analysis:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Analysis {analysis_id} not found",
|
||||
)
|
||||
|
||||
# Try to cancel if running
|
||||
if analysis.status in ["pending", "running"]:
|
||||
executor = get_executor()
|
||||
executor.cancel_analysis(analysis_id)
|
||||
|
||||
# Delete from database if requested
|
||||
if permanent:
|
||||
db.delete(analysis)
|
||||
db.commit()
|
||||
elif analysis.status not in ["cancelled", "failed", "completed"]:
|
||||
# Just mark as cancelled
|
||||
analysis.status = "cancelled"
|
||||
analysis.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
return None
|
||||
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
"""Cached data access endpoints."""
|
||||
|
||||
import csv
|
||||
import glob
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
|
||||
from api.auth import APIKey, get_current_api_key
|
||||
from api.models.responses import CachedDataResponse, CachedTickerInfo
|
||||
|
||||
router = APIRouter(prefix="/api/v1/data", tags=["data"])
|
||||
|
||||
# Data cache directory
|
||||
DATA_CACHE_DIR = Path("./tradingagents/dataflows/data_cache")
|
||||
|
||||
|
||||
def _parse_date_range(filename: str) -> Optional[Dict[str, str]]:
|
||||
"""Parse date range from cache filename."""
|
||||
try:
|
||||
# Format: TICKER-YFin-data-START-END.csv
|
||||
parts = filename.replace(".csv", "").split("-")
|
||||
if len(parts) >= 5:
|
||||
start_date = parts[-2]
|
||||
end_date = parts[-1]
|
||||
return {"start": start_date, "end": end_date}
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/cache", response_model=List[CachedTickerInfo])
|
||||
async def list_cached_tickers(
|
||||
api_key: APIKey = Depends(get_current_api_key),
|
||||
):
|
||||
"""List all cached tickers with date ranges."""
|
||||
if not DATA_CACHE_DIR.exists():
|
||||
return []
|
||||
|
||||
cached_tickers = []
|
||||
|
||||
for csv_file in DATA_CACHE_DIR.glob("*-YFin-data-*.csv"):
|
||||
ticker = csv_file.name.split("-")[0]
|
||||
date_range = _parse_date_range(csv_file.name)
|
||||
|
||||
if date_range:
|
||||
# Count records
|
||||
try:
|
||||
with open(csv_file, "r") as f:
|
||||
record_count = sum(1 for _ in f) - 1 # Subtract header
|
||||
except:
|
||||
record_count = 0
|
||||
|
||||
cached_tickers.append(
|
||||
CachedTickerInfo(
|
||||
ticker=ticker,
|
||||
date_range=date_range,
|
||||
record_count=record_count,
|
||||
)
|
||||
)
|
||||
|
||||
return sorted(cached_tickers, key=lambda x: x.ticker)
|
||||
|
||||
|
||||
@router.get("/cache/{ticker}", response_model=CachedDataResponse)
|
||||
async def get_cached_data(
|
||||
ticker: str,
|
||||
start_date: Optional[str] = Query(None, description="Filter from date (YYYY-MM-DD)"),
|
||||
end_date: Optional[str] = Query(None, description="Filter to date (YYYY-MM-DD)"),
|
||||
api_key: APIKey = Depends(get_current_api_key),
|
||||
):
|
||||
"""Get cached market data for a ticker."""
|
||||
# Find matching file
|
||||
pattern = f"{ticker.upper()}-YFin-data-*.csv"
|
||||
matching_files = list(DATA_CACHE_DIR.glob(pattern))
|
||||
|
||||
if not matching_files:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No cached data found for ticker {ticker}",
|
||||
)
|
||||
|
||||
# Use the first matching file (should only be one)
|
||||
csv_file = matching_files[0]
|
||||
date_range = _parse_date_range(csv_file.name)
|
||||
|
||||
if not date_range:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Could not parse date range from cache file",
|
||||
)
|
||||
|
||||
# Read CSV data
|
||||
data = []
|
||||
try:
|
||||
with open(csv_file, "r") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
# Filter by date if specified
|
||||
row_date = row.get("Date", "")
|
||||
if start_date and row_date < start_date:
|
||||
continue
|
||||
if end_date and row_date > end_date:
|
||||
continue
|
||||
|
||||
# Convert numeric fields
|
||||
for field in ["Close", "High", "Low", "Open", "Volume"]:
|
||||
if field in row and row[field]:
|
||||
try:
|
||||
row[field] = float(row[field])
|
||||
except:
|
||||
pass
|
||||
|
||||
data.append(row)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error reading cache file: {str(e)}",
|
||||
)
|
||||
|
||||
return CachedDataResponse(
|
||||
ticker=ticker.upper(),
|
||||
date_range=date_range,
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
"""Ticker history endpoints."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.auth import APIKey, get_current_api_key
|
||||
from api.database import Analysis, get_db
|
||||
from api.models import AnalysisResponse, AnalysisSummary, ReportResponse, TickerInfo
|
||||
from api.endpoints.analyses import get_analysis
|
||||
|
||||
router = APIRouter(prefix="/api/v1/tickers", tags=["tickers"])
|
||||
|
||||
|
||||
@router.get("", response_model=List[TickerInfo])
|
||||
async def list_tickers(
|
||||
db: Session = Depends(get_db),
|
||||
api_key: APIKey = Depends(get_current_api_key),
|
||||
):
|
||||
"""List all tickers with analysis count."""
|
||||
# Query for ticker stats
|
||||
results = (
|
||||
db.query(
|
||||
Analysis.ticker,
|
||||
func.count(Analysis.id).label("analysis_count"),
|
||||
func.max(Analysis.analysis_date).label("latest_date"),
|
||||
)
|
||||
.group_by(Analysis.ticker)
|
||||
.order_by(Analysis.ticker)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
TickerInfo(
|
||||
ticker=r.ticker,
|
||||
analysis_count=r.analysis_count,
|
||||
latest_date=r.latest_date,
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{ticker}/analyses", response_model=List[AnalysisSummary])
|
||||
async def get_ticker_analyses(
|
||||
ticker: str,
|
||||
db: Session = Depends(get_db),
|
||||
api_key: APIKey = Depends(get_current_api_key),
|
||||
):
|
||||
"""Get all analyses for a ticker."""
|
||||
analyses = (
|
||||
db.query(Analysis)
|
||||
.filter(Analysis.ticker == ticker.upper())
|
||||
.order_by(Analysis.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
AnalysisSummary(
|
||||
id=a.id,
|
||||
ticker=a.ticker,
|
||||
analysis_date=a.analysis_date,
|
||||
status=a.status,
|
||||
created_at=a.created_at,
|
||||
completed_at=a.completed_at,
|
||||
error_message=a.error_message,
|
||||
)
|
||||
for a in analyses
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{ticker}/latest", response_model=AnalysisResponse)
|
||||
async def get_ticker_latest_analysis(
|
||||
ticker: str,
|
||||
db: Session = Depends(get_db),
|
||||
api_key: APIKey = Depends(get_current_api_key),
|
||||
):
|
||||
"""Get the most recent analysis for a ticker."""
|
||||
analysis = (
|
||||
db.query(Analysis)
|
||||
.filter(Analysis.ticker == ticker.upper())
|
||||
.order_by(Analysis.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if not analysis:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No analyses found for ticker {ticker}",
|
||||
)
|
||||
|
||||
# Use the existing get_analysis function
|
||||
return await get_analysis(analysis.id, db, api_key)
|
||||
|
||||
|
|
@ -0,0 +1,169 @@
|
|||
"""Example client for Testing the Trading Agents API."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from datetime import date
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
|
||||
|
||||
class TradingAgentsAPIClient:
|
||||
"""Simple client for interacting with the Trading Agents API."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str = "http://localhost:8000"):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.headers = {
|
||||
"X-API-Key": api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def create_analysis(
|
||||
self,
|
||||
ticker: str,
|
||||
analysis_date: str = None,
|
||||
selected_analysts: list = None,
|
||||
research_depth: int = 1,
|
||||
):
|
||||
"""Create a new analysis."""
|
||||
if analysis_date is None:
|
||||
analysis_date = date.today().strftime("%Y-%m-%d")
|
||||
|
||||
if selected_analysts is None:
|
||||
selected_analysts = ["market", "news"]
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/analyses",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"ticker": ticker,
|
||||
"analysis_date": analysis_date,
|
||||
"selected_analysts": selected_analysts,
|
||||
"research_depth": research_depth,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_analysis(self, analysis_id: str):
|
||||
"""Get full analysis details."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/analyses/{analysis_id}",
|
||||
headers=self.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_status(self, analysis_id: str):
|
||||
"""Get analysis status."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/analyses/{analysis_id}/status",
|
||||
headers=self.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def list_analyses(self, ticker: str = None):
|
||||
"""List all analyses."""
|
||||
params = {}
|
||||
if ticker:
|
||||
params["ticker"] = ticker
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/api/v1/analyses",
|
||||
headers=self.headers,
|
||||
params=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def monitor_via_websocket(self, analysis_id: str, duration: int = 300):
|
||||
"""Monitor analysis via WebSocket."""
|
||||
ws_url = f"ws://localhost:8000/api/v1/ws/analyses/{analysis_id}"
|
||||
|
||||
print(f"Connecting to WebSocket: {ws_url}")
|
||||
|
||||
try:
|
||||
async with websockets.connect(ws_url) as websocket:
|
||||
print("Connected! Waiting for updates...")
|
||||
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < duration:
|
||||
try:
|
||||
message = await asyncio.wait_for(
|
||||
websocket.recv(), timeout=10.0
|
||||
)
|
||||
data = json.loads(message)
|
||||
|
||||
print(f"\n[Update] Status: {data['status']}")
|
||||
print(f" Progress: {data['progress_percentage']}%")
|
||||
if data.get('current_agent'):
|
||||
print(f" Agent: {data['current_agent']}")
|
||||
|
||||
if data['status'] in ['completed', 'failed', 'cancelled']:
|
||||
print("\nAnalysis finished!")
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Send ping to keep connection alive
|
||||
await websocket.send("ping")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"WebSocket error: {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Example usage."""
|
||||
# Replace with your actual API key
|
||||
API_KEY = "your-api-key-here"
|
||||
|
||||
client = TradingAgentsAPIClient(API_KEY)
|
||||
|
||||
print("=" * 60)
|
||||
print("Trading Agents API Client Example")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. Create an analysis
|
||||
print("\n1. Creating analysis for AAPL...")
|
||||
analysis = await client.create_analysis(
|
||||
ticker="AAPL",
|
||||
selected_analysts=["market", "news"],
|
||||
research_depth=1,
|
||||
)
|
||||
analysis_id = analysis["id"]
|
||||
print(f" Created: {analysis_id}")
|
||||
print(f" Status: {analysis['status']}")
|
||||
|
||||
# 2. Monitor via WebSocket (run this in background or separately)
|
||||
print("\n2. Monitoring via WebSocket...")
|
||||
await client.monitor_via_websocket(analysis_id, duration=600)
|
||||
|
||||
# 3. Get final results
|
||||
print("\n3. Getting final results...")
|
||||
final = await client.get_analysis(analysis_id)
|
||||
print(f" Status: {final['status']}")
|
||||
print(f" Reports: {len(final['reports'])} available")
|
||||
|
||||
for report in final['reports']:
|
||||
print(f"\n - {report['report_type']}:")
|
||||
print(f" {report['content'][:200]}...")
|
||||
|
||||
# 4. List all analyses
|
||||
print("\n4. Listing all AAPL analyses...")
|
||||
all_analyses = await client.list_analyses(ticker="AAPL")
|
||||
print(f" Found {len(all_analyses)} analyses")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
|
@ -0,0 +1,104 @@
|
|||
"""FastAPI Trading Agents API application."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
from api.database import init_db
|
||||
from api.endpoints import analyses, data, tickers
|
||||
from api.state_manager import get_executor, shutdown_executor
|
||||
from api.websockets import status
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout),
|
||||
],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for startup and shutdown events."""
|
||||
# Startup
|
||||
logger.info("Initializing Trading Agents API...")
|
||||
init_db()
|
||||
get_executor()
|
||||
logger.info("Trading Agents API started successfully")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down Trading Agents API...")
|
||||
shutdown_executor()
|
||||
logger.info("Trading Agents API shutdown complete")
|
||||
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Trading Agents API",
|
||||
description="REST API for managing multi-agent trading analyses",
|
||||
version="1.0.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # In production, specify actual origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# Include routers
|
||||
app.include_router(analyses.router)
|
||||
app.include_router(tickers.router)
|
||||
app.include_router(data.router)
|
||||
app.include_router(status.router)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint."""
|
||||
return {
|
||||
"name": "Trading Agents API",
|
||||
"version": "1.0.0",
|
||||
"docs": "/docs",
|
||||
"redoc": "/redoc",
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
import os
|
||||
|
||||
port = int(os.getenv("API_PORT", "8002")) # Default to 8001 instead of 8000
|
||||
|
||||
uvicorn.run(
|
||||
"api.main:app",
|
||||
host="0.0.0.0",
|
||||
port=port,
|
||||
reload=True,
|
||||
log_level="info",
|
||||
)
|
||||
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""Pydantic models for API requests and responses."""
|
||||
|
||||
from api.models.requests import CreateAnalysisRequest, UpdateAnalysisRequest
|
||||
from api.models.responses import (
|
||||
AnalysisResponse,
|
||||
AnalysisStatusResponse,
|
||||
AnalysisSummary,
|
||||
ErrorResponse,
|
||||
LogEntry,
|
||||
ReportResponse,
|
||||
TickerInfo,
|
||||
CachedDataResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CreateAnalysisRequest",
|
||||
"UpdateAnalysisRequest",
|
||||
"AnalysisResponse",
|
||||
"AnalysisStatusResponse",
|
||||
"AnalysisSummary",
|
||||
"ErrorResponse",
|
||||
"LogEntry",
|
||||
"ReportResponse",
|
||||
"TickerInfo",
|
||||
"CachedDataResponse",
|
||||
]
|
||||
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
"""Request models for API endpoints."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class CreateAnalysisRequest(BaseModel):
|
||||
"""Request to create a new analysis."""
|
||||
|
||||
ticker: str = Field(..., description="Ticker symbol to analyze")
|
||||
analysis_date: str = Field(..., description="Analysis date in YYYY-MM-DD format")
|
||||
selected_analysts: List[str] = Field(
|
||||
default=["market", "news", "social", "fundamentals"],
|
||||
description="List of analysts to run (market, news, social, fundamentals)",
|
||||
)
|
||||
research_depth: int = Field(
|
||||
default=1,
|
||||
ge=1,
|
||||
le=5,
|
||||
description="Research depth (1=shallow, 3=medium, 5=deep)",
|
||||
)
|
||||
llm_provider: Optional[str] = Field(
|
||||
default=None,
|
||||
description="LLM provider (openai, anthropic, google). Uses default if not specified.",
|
||||
)
|
||||
backend_url: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Backend URL for LLM. Uses default if not specified.",
|
||||
)
|
||||
quick_think_llm: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Model for quick thinking. Uses default if not specified.",
|
||||
)
|
||||
deep_think_llm: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Model for deep thinking. Uses default if not specified.",
|
||||
)
|
||||
|
||||
@field_validator("ticker")
|
||||
@classmethod
|
||||
def ticker_uppercase(cls, v: str) -> str:
|
||||
"""Convert ticker to uppercase."""
|
||||
return v.upper().strip()
|
||||
|
||||
@field_validator("selected_analysts")
|
||||
@classmethod
|
||||
def validate_analysts(cls, v: List[str]) -> List[str]:
|
||||
"""Validate analyst selections."""
|
||||
valid_analysts = {"market", "news", "social", "fundamentals"}
|
||||
for analyst in v:
|
||||
if analyst not in valid_analysts:
|
||||
raise ValueError(
|
||||
f"Invalid analyst: {analyst}. Must be one of {valid_analysts}"
|
||||
)
|
||||
if not v:
|
||||
raise ValueError("At least one analyst must be selected")
|
||||
return v
|
||||
|
||||
|
||||
class UpdateAnalysisRequest(BaseModel):
|
||||
"""Request to update analysis metadata."""
|
||||
|
||||
status: Optional[str] = Field(
|
||||
default=None,
|
||||
description="New status (pending, running, completed, failed, cancelled)",
|
||||
)
|
||||
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
"""Response models for API endpoints."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Standard error response."""
|
||||
|
||||
error: str = Field(..., description="Error message")
|
||||
detail: Optional[str] = Field(None, description="Additional error details")
|
||||
|
||||
|
||||
class ReportResponse(BaseModel):
|
||||
"""Single report section."""
|
||||
|
||||
report_type: str = Field(..., description="Type of report")
|
||||
content: str = Field(..., description="Report content in markdown")
|
||||
created_at: datetime = Field(..., description="When the report was created")
|
||||
|
||||
|
||||
class LogEntry(BaseModel):
|
||||
"""Single log entry."""
|
||||
|
||||
timestamp: datetime = Field(..., description="When the log was created")
|
||||
log_type: str = Field(..., description="Type of log (Tool Call, Reasoning, System)")
|
||||
content: str = Field(..., description="Log content")
|
||||
|
||||
|
||||
class AnalysisStatusResponse(BaseModel):
|
||||
"""Lightweight analysis status."""
|
||||
|
||||
id: str = Field(..., description="Analysis ID")
|
||||
status: str = Field(..., description="Current status")
|
||||
progress_percentage: int = Field(..., description="Progress percentage (0-100)")
|
||||
current_agent: Optional[str] = Field(None, description="Currently active agent")
|
||||
updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
|
||||
|
||||
class AnalysisSummary(BaseModel):
|
||||
"""Summary view of an analysis."""
|
||||
|
||||
id: str = Field(..., description="Analysis ID")
|
||||
ticker: str = Field(..., description="Ticker symbol")
|
||||
analysis_date: str = Field(..., description="Analysis date")
|
||||
status: str = Field(..., description="Current status")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
completed_at: Optional[datetime] = Field(None, description="Completion timestamp")
|
||||
error_message: Optional[str] = Field(None, description="Error message if failed")
|
||||
|
||||
|
||||
class AnalysisResponse(BaseModel):
|
||||
"""Full analysis details."""
|
||||
|
||||
id: str = Field(..., description="Analysis ID")
|
||||
ticker: str = Field(..., description="Ticker symbol")
|
||||
analysis_date: str = Field(..., description="Analysis date")
|
||||
status: str = Field(..., description="Current status")
|
||||
config: Dict = Field(..., description="Analysis configuration")
|
||||
reports: List[ReportResponse] = Field(
|
||||
default_factory=list, description="All report sections"
|
||||
)
|
||||
progress_percentage: int = Field(..., description="Progress percentage (0-100)")
|
||||
current_agent: Optional[str] = Field(None, description="Currently active agent")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
completed_at: Optional[datetime] = Field(None, description="Completion timestamp")
|
||||
error_message: Optional[str] = Field(None, description="Error message if failed")
|
||||
|
||||
|
||||
class TickerInfo(BaseModel):
|
||||
"""Ticker information summary."""
|
||||
|
||||
ticker: str = Field(..., description="Ticker symbol")
|
||||
analysis_count: int = Field(..., description="Number of analyses")
|
||||
latest_date: Optional[str] = Field(None, description="Most recent analysis date")
|
||||
|
||||
|
||||
class CachedDataResponse(BaseModel):
|
||||
"""Cached market data response."""
|
||||
|
||||
ticker: str = Field(..., description="Ticker symbol")
|
||||
date_range: Dict[str, str] = Field(..., description="Start and end dates")
|
||||
data: List[Dict] = Field(..., description="OHLCV data records")
|
||||
|
||||
|
||||
class CachedTickerInfo(BaseModel):
|
||||
"""Information about cached ticker data."""
|
||||
|
||||
ticker: str = Field(..., description="Ticker symbol")
|
||||
date_range: Dict[str, str] = Field(..., description="Start and end dates")
|
||||
record_count: int = Field(..., description="Number of records")
|
||||
|
||||
|
|
@ -0,0 +1,447 @@
|
|||
"""Analysis execution and state management."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import traceback
|
||||
import uuid
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.database import Analysis, AnalysisLog, AnalysisReport, SessionLocal
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnalysisExecutor:
|
||||
"""Manages analysis execution with thread pool and state tracking."""
|
||||
|
||||
def __init__(self, max_workers: int = None):
|
||||
"""
|
||||
Initialize the executor.
|
||||
|
||||
Args:
|
||||
max_workers: Maximum concurrent analyses (default: from env or 4)
|
||||
"""
|
||||
if max_workers is None:
|
||||
max_workers = int(os.getenv("MAX_CONCURRENT_ANALYSES", "4"))
|
||||
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.active_analyses: Dict[str, Future] = {}
|
||||
self.status_callbacks: Dict[str, List[Callable]] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def register_status_callback(self, analysis_id: str, callback: Callable):
|
||||
"""Register a callback for status updates."""
|
||||
with self._lock:
|
||||
if analysis_id not in self.status_callbacks:
|
||||
self.status_callbacks[analysis_id] = []
|
||||
self.status_callbacks[analysis_id].append(callback)
|
||||
|
||||
def unregister_status_callbacks(self, analysis_id: str):
|
||||
"""Remove all callbacks for an analysis."""
|
||||
with self._lock:
|
||||
if analysis_id in self.status_callbacks:
|
||||
del self.status_callbacks[analysis_id]
|
||||
|
||||
def _notify_callbacks(self, analysis_id: str, status_data: Dict[str, Any]):
|
||||
"""Notify all registered callbacks."""
|
||||
with self._lock:
|
||||
callbacks = self.status_callbacks.get(analysis_id, [])
|
||||
|
||||
for callback in callbacks:
|
||||
try:
|
||||
callback(status_data)
|
||||
except Exception as e:
|
||||
print(f"Error in status callback: {e}")
|
||||
|
||||
def start_analysis(
|
||||
self,
|
||||
analysis_id: str,
|
||||
ticker: str,
|
||||
analysis_date: str,
|
||||
selected_analysts: List[str],
|
||||
config: Dict[str, Any],
|
||||
) -> str:
|
||||
"""
|
||||
Start a new analysis in the background.
|
||||
|
||||
Args:
|
||||
analysis_id: Unique analysis ID
|
||||
ticker: Ticker symbol
|
||||
analysis_date: Analysis date
|
||||
selected_analysts: List of analyst types
|
||||
config: Trading agents configuration
|
||||
|
||||
Returns:
|
||||
analysis_id
|
||||
"""
|
||||
future = self.executor.submit(
|
||||
self._run_analysis,
|
||||
analysis_id,
|
||||
ticker,
|
||||
analysis_date,
|
||||
selected_analysts,
|
||||
config,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self.active_analyses[analysis_id] = future
|
||||
|
||||
# Cleanup when done
|
||||
future.add_done_callback(lambda f: self._cleanup_analysis(analysis_id))
|
||||
|
||||
return analysis_id
|
||||
|
||||
def _cleanup_analysis(self, analysis_id: str):
|
||||
"""Clean up after analysis completes."""
|
||||
with self._lock:
|
||||
if analysis_id in self.active_analyses:
|
||||
del self.active_analyses[analysis_id]
|
||||
self.unregister_status_callbacks(analysis_id)
|
||||
|
||||
def cancel_analysis(self, analysis_id: str) -> bool:
|
||||
"""
|
||||
Attempt to cancel a running analysis.
|
||||
|
||||
Returns:
|
||||
True if cancelled, False if not found or already completed
|
||||
"""
|
||||
with self._lock:
|
||||
future = self.active_analyses.get(analysis_id)
|
||||
|
||||
if future and not future.done():
|
||||
cancelled = future.cancel()
|
||||
if cancelled:
|
||||
# Update database status
|
||||
db = SessionLocal()
|
||||
try:
|
||||
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
||||
if analysis:
|
||||
analysis.status = "cancelled"
|
||||
analysis.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
return cancelled
|
||||
return False
|
||||
|
||||
def get_status(self, analysis_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get current status of an analysis."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
||||
if not analysis:
|
||||
return None
|
||||
|
||||
return {
|
||||
"id": analysis.id,
|
||||
"status": analysis.status,
|
||||
"progress_percentage": analysis.progress_percentage,
|
||||
"current_agent": analysis.current_agent,
|
||||
"updated_at": analysis.updated_at,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _update_status(
|
||||
self,
|
||||
analysis_id: str,
|
||||
status: Optional[str] = None,
|
||||
progress: Optional[int] = None,
|
||||
current_agent: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
):
|
||||
"""Update analysis status in database and notify callbacks."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
||||
if not analysis:
|
||||
return
|
||||
|
||||
if status:
|
||||
analysis.status = status
|
||||
if progress is not None:
|
||||
analysis.progress_percentage = progress
|
||||
if current_agent:
|
||||
analysis.current_agent = current_agent
|
||||
if error_message:
|
||||
analysis.error_message = error_message
|
||||
|
||||
analysis.updated_at = datetime.utcnow()
|
||||
|
||||
if status == "completed":
|
||||
analysis.completed_at = datetime.utcnow()
|
||||
analysis.progress_percentage = 100
|
||||
|
||||
db.commit()
|
||||
db.refresh(analysis)
|
||||
|
||||
# Notify callbacks
|
||||
status_data = {
|
||||
"type": "status_update",
|
||||
"analysis_id": analysis.id,
|
||||
"status": analysis.status,
|
||||
"progress_percentage": analysis.progress_percentage,
|
||||
"current_agent": analysis.current_agent,
|
||||
"timestamp": analysis.updated_at.isoformat(),
|
||||
}
|
||||
self._notify_callbacks(analysis_id, status_data)
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _store_log(self, analysis_id: str, log_type: str, content: str):
|
||||
"""Store a log entry."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
log = AnalysisLog(
|
||||
analysis_id=analysis_id,
|
||||
log_type=log_type,
|
||||
content=content,
|
||||
timestamp=datetime.utcnow(),
|
||||
)
|
||||
db.add(log)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _store_report(self, analysis_id: str, report_type: str, content: str):
|
||||
"""Store or update a report section."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Check if report already exists
|
||||
report = (
|
||||
db.query(AnalysisReport)
|
||||
.filter(
|
||||
AnalysisReport.analysis_id == analysis_id,
|
||||
AnalysisReport.report_type == report_type,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if report:
|
||||
# Update existing
|
||||
report.content = content
|
||||
report.created_at = datetime.utcnow()
|
||||
else:
|
||||
# Create new
|
||||
report = AnalysisReport(
|
||||
analysis_id=analysis_id,
|
||||
report_type=report_type,
|
||||
content=content,
|
||||
)
|
||||
db.add(report)
|
||||
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _run_analysis(
|
||||
self,
|
||||
analysis_id: str,
|
||||
ticker: str,
|
||||
analysis_date: str,
|
||||
selected_analysts: List[str],
|
||||
config: Dict[str, Any],
|
||||
):
|
||||
"""Execute the analysis (runs in thread pool)."""
|
||||
logger.info(f"Starting analysis {analysis_id} for {ticker} on {analysis_date}")
|
||||
try:
|
||||
# Update status to running
|
||||
self._update_status(analysis_id, status="running", progress=0)
|
||||
logger.info(f"Analysis {analysis_id}: Initializing trading graph...")
|
||||
|
||||
# Initialize the graph
|
||||
graph = TradingAgentsGraph(
|
||||
selected_analysts=selected_analysts,
|
||||
config=config,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
# Create initial state
|
||||
init_agent_state = graph.propagator.create_initial_state(ticker, analysis_date)
|
||||
init_agent_state["asset_class"] = config.get("asset_class", "equity")
|
||||
args = graph.propagator.get_graph_args()
|
||||
|
||||
# Track agent progress
|
||||
agent_order = self._get_agent_order(selected_analysts)
|
||||
total_agents = len(agent_order)
|
||||
current_agent_index = 0
|
||||
|
||||
# Stream the analysis
|
||||
trace = []
|
||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
||||
if len(chunk.get("messages", [])) == 0:
|
||||
continue
|
||||
|
||||
# Process the chunk
|
||||
last_message = chunk["messages"][-1]
|
||||
|
||||
# Extract content
|
||||
if hasattr(last_message, "content"):
|
||||
content = self._extract_content(last_message.content)
|
||||
msg_type = "Reasoning"
|
||||
|
||||
# Store log
|
||||
self._store_log(analysis_id, msg_type, content)
|
||||
|
||||
# Handle tool calls
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = tool_call["args"]
|
||||
else:
|
||||
tool_name = tool_call.name
|
||||
tool_args = tool_call.args
|
||||
|
||||
args_str = ", ".join(f"{k}={v}" for k, v in tool_args.items())
|
||||
self._store_log(
|
||||
analysis_id, "Tool Call", f"{tool_name}({args_str})"
|
||||
)
|
||||
|
||||
# Check for completed reports
|
||||
for report_type in [
|
||||
"market_report",
|
||||
"sentiment_report",
|
||||
"news_report",
|
||||
"fundamentals_report",
|
||||
]:
|
||||
if report_type in chunk and chunk[report_type]:
|
||||
self._store_report(analysis_id, report_type, chunk[report_type])
|
||||
current_agent_index += 1
|
||||
progress = int((current_agent_index / total_agents) * 100)
|
||||
agent_name = self._get_agent_name(report_type)
|
||||
self._update_status(
|
||||
analysis_id,
|
||||
progress=min(progress, 95),
|
||||
current_agent=agent_name,
|
||||
)
|
||||
|
||||
# Check for investment debate state
|
||||
if "investment_debate_state" in chunk and chunk["investment_debate_state"]:
|
||||
debate_state = chunk["investment_debate_state"]
|
||||
if "judge_decision" in debate_state and debate_state["judge_decision"]:
|
||||
self._store_report(
|
||||
analysis_id, "investment_plan", debate_state["judge_decision"]
|
||||
)
|
||||
current_agent_index += 1
|
||||
progress = int((current_agent_index / total_agents) * 100)
|
||||
self._update_status(
|
||||
analysis_id,
|
||||
progress=min(progress, 98),
|
||||
current_agent="Research Manager",
|
||||
)
|
||||
|
||||
# Check for trader plan
|
||||
if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
|
||||
self._store_report(
|
||||
analysis_id, "trader_investment_plan", chunk["trader_investment_plan"]
|
||||
)
|
||||
self._update_status(
|
||||
analysis_id,
|
||||
progress=99,
|
||||
current_agent="Trader",
|
||||
)
|
||||
|
||||
trace.append(chunk)
|
||||
|
||||
# Get final state
|
||||
if trace:
|
||||
final_state = trace[-1]
|
||||
|
||||
# Store final trade decision
|
||||
if "final_trade_decision" in final_state:
|
||||
decision = graph.process_signal(final_state["final_trade_decision"])
|
||||
self._store_report(
|
||||
analysis_id, "final_trade_decision", final_state["final_trade_decision"]
|
||||
)
|
||||
self._store_log(
|
||||
analysis_id, "System", f"Final decision: {decision}"
|
||||
)
|
||||
|
||||
# Mark as completed
|
||||
logger.info(f"Analysis {analysis_id} completed successfully")
|
||||
self._update_status(analysis_id, status="completed", progress=100)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
error_trace = traceback.format_exc()
|
||||
logger.error(f"Analysis {analysis_id} failed: {error_msg}")
|
||||
logger.error(f"Traceback:\n{error_trace}")
|
||||
self._update_status(
|
||||
analysis_id, status="failed", error_message=error_msg
|
||||
)
|
||||
self._store_log(analysis_id, "System", f"Error: {error_msg}\n\nTraceback:\n{error_trace}")
|
||||
|
||||
def _get_agent_order(self, selected_analysts: List[str]) -> List[str]:
|
||||
"""Get the order of agents for progress tracking."""
|
||||
agents = selected_analysts.copy()
|
||||
agents.extend(["bull_researcher", "bear_researcher", "research_manager", "trader", "risk", "portfolio"])
|
||||
return agents
|
||||
|
||||
def _get_agent_name(self, report_type: str) -> str:
|
||||
"""Get human-readable agent name from report type."""
|
||||
mapping = {
|
||||
"market_report": "Market Analyst",
|
||||
"sentiment_report": "Social Analyst",
|
||||
"news_report": "News Analyst",
|
||||
"fundamentals_report": "Fundamentals Analyst",
|
||||
}
|
||||
return mapping.get(report_type, "Unknown")
|
||||
|
||||
def _extract_content(self, content: Any) -> str:
|
||||
"""Extract string content from various message formats."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
elif item.get("type") == "tool_use":
|
||||
text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
|
||||
else:
|
||||
text_parts.append(str(item))
|
||||
return " ".join(text_parts)
|
||||
else:
|
||||
return str(content)
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown the executor and cancel all running analyses."""
|
||||
with self._lock:
|
||||
# Cancel all active analyses
|
||||
for analysis_id in list(self.active_analyses.keys()):
|
||||
self.cancel_analysis(analysis_id)
|
||||
|
||||
# Shutdown executor
|
||||
self.executor.shutdown(wait=True)
|
||||
|
||||
|
||||
# Global executor instance
|
||||
_executor: Optional[AnalysisExecutor] = None
|
||||
|
||||
|
||||
def get_executor() -> AnalysisExecutor:
|
||||
"""Get the global executor instance."""
|
||||
global _executor
|
||||
if _executor is None:
|
||||
_executor = AnalysisExecutor()
|
||||
return _executor
|
||||
|
||||
|
||||
def shutdown_executor():
|
||||
"""Shutdown the global executor."""
|
||||
global _executor
|
||||
if _executor:
|
||||
_executor.shutdown()
|
||||
_executor = None
|
||||
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
"""WebSocket handlers."""
|
||||
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
"""WebSocket status streaming."""
|
||||
|
||||
import json
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.database import Analysis, SessionLocal
|
||||
from api.state_manager import get_executor
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manages WebSocket connections for analyses."""
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, List[WebSocket]] = {}
|
||||
|
||||
async def connect(self, analysis_id: str, websocket: WebSocket):
|
||||
"""Accept and register a WebSocket connection."""
|
||||
await websocket.accept()
|
||||
|
||||
if analysis_id not in self.active_connections:
|
||||
self.active_connections[analysis_id] = []
|
||||
|
||||
self.active_connections[analysis_id].append(websocket)
|
||||
|
||||
# Register callback with executor
|
||||
executor = get_executor()
|
||||
executor.register_status_callback(analysis_id, self._create_callback(analysis_id))
|
||||
|
||||
def disconnect(self, analysis_id: str, websocket: WebSocket):
|
||||
"""Remove a WebSocket connection."""
|
||||
if analysis_id in self.active_connections:
|
||||
if websocket in self.active_connections[analysis_id]:
|
||||
self.active_connections[analysis_id].remove(websocket)
|
||||
|
||||
# Clean up if no more connections
|
||||
if not self.active_connections[analysis_id]:
|
||||
del self.active_connections[analysis_id]
|
||||
|
||||
def _create_callback(self, analysis_id: str):
|
||||
"""Create a callback function for status updates."""
|
||||
def callback(status_data: dict):
|
||||
# Note: This is called from a thread, so we can't use async here
|
||||
# The actual broadcasting happens via the websocket event loop
|
||||
import asyncio
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
asyncio.create_task(self.broadcast(analysis_id, status_data))
|
||||
else:
|
||||
loop.run_until_complete(self.broadcast(analysis_id, status_data))
|
||||
except:
|
||||
# If no loop, we can't broadcast (connection will poll status instead)
|
||||
pass
|
||||
return callback
|
||||
|
||||
async def broadcast(self, analysis_id: str, message: dict):
|
||||
"""Broadcast a message to all connections for an analysis."""
|
||||
if analysis_id not in self.active_connections:
|
||||
return
|
||||
|
||||
disconnected = []
|
||||
for connection in self.active_connections[analysis_id]:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except:
|
||||
disconnected.append(connection)
|
||||
|
||||
# Clean up disconnected clients
|
||||
for connection in disconnected:
|
||||
self.disconnect(analysis_id, connection)
|
||||
|
||||
|
||||
# Global connection manager
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
@router.websocket("/api/v1/ws/analyses/{analysis_id}")
|
||||
async def websocket_analysis_status(websocket: WebSocket, analysis_id: str):
|
||||
"""WebSocket endpoint for real-time analysis status updates."""
|
||||
# Verify analysis exists
|
||||
db = SessionLocal()
|
||||
try:
|
||||
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
||||
if not analysis:
|
||||
await websocket.close(code=1008, reason="Analysis not found")
|
||||
return
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Connect
|
||||
await manager.connect(analysis_id, websocket)
|
||||
|
||||
try:
|
||||
# Send initial status
|
||||
db = SessionLocal()
|
||||
try:
|
||||
analysis = db.query(Analysis).filter(Analysis.id == analysis_id).first()
|
||||
if analysis:
|
||||
initial_status = {
|
||||
"type": "status_update",
|
||||
"analysis_id": analysis.id,
|
||||
"status": analysis.status,
|
||||
"progress_percentage": analysis.progress_percentage,
|
||||
"current_agent": analysis.current_agent,
|
||||
"timestamp": analysis.updated_at.isoformat(),
|
||||
}
|
||||
await websocket.send_json(initial_status)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Keep connection alive and handle messages
|
||||
while True:
|
||||
# Wait for any messages from client (like ping)
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# Echo back if it's a ping
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(analysis_id, websocket)
|
||||
except Exception as e:
|
||||
print(f"WebSocket error: {e}")
|
||||
manager.disconnect(analysis_id, websocket)
|
||||
|
||||
Binary file not shown.
|
|
@ -24,3 +24,12 @@ rich
|
|||
questionary
|
||||
langchain_anthropic
|
||||
langchain-google-genai
|
||||
|
||||
# API Dependencies
|
||||
fastapi==0.109.0
|
||||
uvicorn[standard]==0.27.0
|
||||
sqlalchemy==2.0.25
|
||||
passlib[bcrypt]>=1.7.4
|
||||
bcrypt>=4.0.0
|
||||
python-multipart==0.0.6
|
||||
websockets==12.0
|
||||
|
|
|
|||
|
|
@ -0,0 +1,17 @@
|
|||
#!/bin/bash
|
||||
# Startup script for Trading Agents API
|
||||
|
||||
echo "Starting Trading Agents API..."
|
||||
echo ""
|
||||
echo "Make sure you have:"
|
||||
echo "1. Installed dependencies: pip install -r requirements.txt"
|
||||
echo "2. Initialized database: python -m api.cli_admin init-database"
|
||||
echo "3. Created an API key: python -m api.cli_admin create-key 'My Key'"
|
||||
echo ""
|
||||
echo "API will be available at: http://localhost:8001"
|
||||
echo "API Documentation: http://localhost:8001/docs"
|
||||
echo ""
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
python -m api.main
|
||||
|
||||
Loading…
Reference in New Issue