#!/usr/bin/env python
"""
Comprehensive test for the FastAPI app in run_api.py - exercises all endpoints,
streaming, and parallel execution.
"""
import json
import os
import subprocess
import sys
import threading
import time
from collections import defaultdict
from datetime import datetime
from pathlib import Path

import requests


class APITestLogger:
    """Enhanced logger for API testing"""

    def __init__(self):
        self.log_file = f"test_results/api_test_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
        Path("test_results").mkdir(exist_ok=True)
        self.test_results = []
        self.stream_events = defaultdict(list)

    def log(self, message, level="INFO"):
        """Log with timestamp and level"""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
        log_entry = f"[{timestamp}] [{level}] {message}"
        print(log_entry)

        # Also write to file
        with open(self.log_file, 'a') as f:
            f.write(log_entry + '\n')

    def log_test_result(self, test_name, success, duration, details=""):
        """Log test result"""
        result = {
            'test': test_name,
            'success': success,
            'duration': duration,
            'details': details,
            'timestamp': datetime.now().isoformat()
        }
        self.test_results.append(result)

        status = "āœ… PASS" if success else "āŒ FAIL"
        self.log(f"{status} {test_name} ({duration:.2f}s) {details}", "RESULT")

    def log_stream_event(self, ticker, event):
        """Log streaming event"""
        self.stream_events[ticker].append({
            'time': time.time(),
            'event': event
        })

    def print_summary(self):
        """Print test summary"""
        self.log("\n" + "=" * 80, "SUMMARY")
        self.log("API TEST SUMMARY", "SUMMARY")
        self.log("=" * 80, "SUMMARY")

        passed = sum(1 for r in self.test_results if r['success'])
        total = len(self.test_results)

        self.log(f"\nšŸ“Š Test Results: {passed}/{total} passed", "SUMMARY")

        for result in self.test_results:
            status = "āœ…" if result['success'] else "āŒ"
            self.log(f"  {status} {result['test']} ({result['duration']:.2f}s)", "SUMMARY")

        # Stream event summary
        if self.stream_events:
            self.log("\nšŸ“” Streaming Events Summary:", "SUMMARY")
            for ticker, events in self.stream_events.items():
                self.log(f"  {ticker}: {len(events)} events", "SUMMARY")

                # Count event types
                event_types = defaultdict(int)
                for event_data in events:
                    if 'type' in event_data['event']:
                        event_types[event_data['event']['type']] += 1

                for event_type, count in event_types.items():
                    self.log(f"    - {event_type}: {count}", "SUMMARY")

        self.log(f"\nšŸ“ Full log saved to: {self.log_file}", "SUMMARY")


def start_api_server(logger):
    """Start the API server in a separate process"""
    logger.log("šŸš€ Starting API server...", "SERVER")

    env = os.environ.copy()
    # Ensure the server runs on the expected port
    env['API_PORT'] = '8000'
    server_process = subprocess.Popen([sys.executable, "run_api.py"], env=env)

    # Wait for server to start
    logger.log("ā³ Waiting for server to start...", "SERVER")
    time.sleep(5)

    # Poll the health endpoint until the server responds
    max_retries = 10
    for _ in range(max_retries):
        try:
            response = requests.get("http://localhost:8000/health", timeout=2)
            if response.status_code == 200:
                logger.log("āœ… API server is running", "SERVER")
                return server_process
        except requests.exceptions.RequestException:
            pass
        time.sleep(2)

    logger.log("āŒ Failed to start API server", "ERROR")
    server_process.terminate()
    return None


def test_health_endpoint(base_url, logger):
    """Test health check endpoint"""
    test_name = "Health Check"
    start_time = time.time()

    try:
        response = requests.get(f"{base_url}/health", timeout=5)
        duration = time.time() - start_time

        if response.status_code == 200 and response.json().get("status") == "healthy":
            logger.log_test_result(test_name, True, duration, "Server is healthy")
        else:
            logger.log_test_result(test_name, False, duration,
                                   f"Unexpected response: {response.text}")
    except Exception as e:
        duration = time.time() - start_time
        logger.log_test_result(test_name, False, duration, str(e))


def test_root_endpoint(base_url, logger):
    """Test root endpoint"""
    test_name = "Root Endpoint"
    start_time = time.time()

    try:
        response = requests.get(f"{base_url}/", timeout=5)
        duration = time.time() - start_time

        if response.status_code == 200:
            logger.log_test_result(test_name, True, duration, f"Response: {response.json()}")
        else:
            logger.log_test_result(test_name, False, duration, f"Status: {response.status_code}")
    except Exception as e:
        duration = time.time() - start_time
        logger.log_test_result(test_name, False, duration, str(e))


def test_analyze_endpoint(base_url, ticker, logger):
    """Test synchronous analysis endpoint"""
    test_name = f"Analyze Endpoint ({ticker})"
    start_time = time.time()

    logger.log(f"\nšŸ” Testing analysis for {ticker}...", "TEST")
    logger.log("ā³ This may take 30-60 seconds...", "TEST")

    try:
        response = requests.post(
            f"{base_url}/analyze",
            json={"ticker": ticker},
            timeout=120  # 2 minute timeout
        )
        duration = time.time() - start_time

        if response.status_code == 200:
            result = response.json()

            # Check for required fields
            required_fields = ['ticker', 'analysis_date', 'market_report',
                               'sentiment_report', 'news_report',
                               'fundamentals_report', 'final_trade_decision',
                               'processed_signal']

            missing_fields = [f for f in required_fields if not result.get(f)]

            if not missing_fields and not result.get('error'):
                logger.log_test_result(test_name, True, duration,
                                       f"Signal: {result.get('processed_signal', 'N/A')}")

                # Log report sizes
                for field in required_fields[2:]:  # Skip ticker and date
                    if result.get(field):
                        logger.log(f"  šŸ“„ {field}: {len(str(result[field]))} chars", "INFO")
            else:
                details = (f"Missing fields: {missing_fields}" if missing_fields
                           else f"Error: {result.get('error')}")
                logger.log_test_result(test_name, False, duration, details)
        else:
            logger.log_test_result(test_name, False, duration,
                                   f"Status: {response.status_code}, Response: {response.text[:200]}")
    except Exception as e:
        duration = time.time() - start_time
        logger.log_test_result(test_name, False, duration, str(e))


def test_streaming_endpoint(base_url, ticker, logger):
    """Test streaming analysis endpoint"""
    test_name = f"Streaming Endpoint ({ticker})"
    start_time = time.time()

    logger.log(f"\nšŸ“” Testing streaming analysis for {ticker}...", "TEST")

    try:
        # Track streaming events
        events_received = []
        agent_progress = {}
        reports_received = []

        with requests.get(f"{base_url}/analyze/stream?ticker={ticker}",
                          stream=True, timeout=120) as response:
            if response.status_code != 200:
                logger.log_test_result(test_name, False, time.time() - start_time,
                                       f"Status: {response.status_code}")
                return

            # Process SSE stream
            for line in response.iter_lines():
                if not line:
                    continue
                line_str = line.decode('utf-8')
                if not line_str.startswith('data: '):
                    continue

                try:
                    event_data = json.loads(line_str[6:])
                    events_received.append(event_data)
                    logger.log_stream_event(ticker, event_data)

                    # Log different event types
                    event_type = event_data.get('type', 'unknown')

                    if event_type == 'status':
                        logger.log(f"  šŸ“Š Status: {event_data.get('message', '')}", "STREAM")

                    elif event_type == 'agent_status':
                        agent = event_data.get('agent', 'unknown')
                        status = event_data.get('status', 'unknown')
                        agent_progress[agent] = status
                        logger.log(f"  šŸ¤– Agent '{agent}' -> {status}", "STREAM")

                        # Check for parallel execution
                        active_agents = [a for a, s in agent_progress.items()
                                         if s == 'in_progress']
                        if len(active_agents) > 1:
                            logger.log(f"  šŸ”„ PARALLEL AGENTS: {active_agents}", "PARALLEL")

                    elif event_type == 'report':
                        section = event_data.get('section', 'unknown')
                        content_len = len(event_data.get('content', ''))
                        reports_received.append(section)
                        logger.log(f"  šŸ“„ Report received: {section} ({content_len} chars)",
                                   "STREAM")

                    elif event_type == 'progress':
                        progress = event_data.get('content', '0')
                        logger.log(f"  šŸ“ˆ Progress: {progress}%", "STREAM")

                    elif event_type == 'reasoning':
                        content_preview = event_data.get('content', '')[:100]
                        logger.log(f"  šŸ’­ Reasoning: {content_preview}...", "STREAM")

                    elif event_type == 'complete':
                        signal = event_data.get('signal', 'N/A')
                        logger.log(f"  āœ… Complete! Signal: {signal}", "STREAM")
                        break

                    elif event_type == 'error':
                        logger.log(f"  āŒ Error: {event_data.get('message', 'Unknown error')}",
                                   "ERROR")
                        break

                except json.JSONDecodeError as e:
                    logger.log(f"  āš ļø Failed to parse SSE data: {e}", "WARNING")

        duration = time.time() - start_time

        # Validate results
        success = (
            len(events_received) > 0 and
            len(reports_received) >= 6 and  # Should receive all main reports
            any(e.get('type') == 'complete' for e in events_received)
        )

        details = (f"Events: {len(events_received)}, Reports: {len(reports_received)}, "
                   f"Agents: {len(agent_progress)}")
        logger.log_test_result(test_name, success, duration, details)

    except Exception as e:
        duration = time.time() - start_time
        logger.log_test_result(test_name, False, duration, str(e))


def test_parallel_requests(base_url, logger):
    """Test multiple parallel requests to verify server handles concurrent load"""
    test_name = "Parallel Requests"
    start_time = time.time()

    logger.log("\nšŸ”„ Testing parallel requests...", "TEST")

    tickers = ["AAPL", "GOOGL", "MSFT"]
    threads = []
    results = []

    def analyze_ticker(ticker):
        try:
            response = requests.post(
                f"{base_url}/analyze",
                json={"ticker": ticker},
                timeout=120
            )
            results.append({
                'ticker': ticker,
                'success': response.status_code == 200,
                'time': time.time() - start_time
            })
        except Exception as e:
            results.append({
                'ticker': ticker,
                'success': False,
                'error': str(e),
                'time': time.time() - start_time
            })

    # Start parallel requests
    for ticker in tickers:
        thread = threading.Thread(target=analyze_ticker, args=(ticker,))
        thread.start()
        threads.append(thread)
        logger.log(f"  šŸš€ Started request for {ticker}", "PARALLEL")

    # Wait for all to complete
    for thread in threads:
        thread.join()

    duration = time.time() - start_time

    # Check results
    successful = sum(1 for r in results if r['success'])
    details = f"Success: {successful}/{len(tickers)}, Total time: {duration:.2f}s"

    for result in results:
        status = "āœ…" if result['success'] else "āŒ"
        logger.log(f"  {status} {result['ticker']} completed at {result['time']:.2f}s", "PARALLEL")

    logger.log_test_result(test_name, successful == len(tickers), duration, details)


def test_error_handling(base_url, logger):
    """Test API error handling"""
    test_name = "Error Handling"
    start_time = time.time()

    logger.log("\nšŸ›”ļø Testing error handling...", "TEST")

    handled = False

    # Test invalid ticker
    try:
        response = requests.post(
            f"{base_url}/analyze",
            json={"ticker": ""},
            timeout=10
        )

        if response.status_code == 400 or (response.status_code == 200
                                           and 'error' in response.json()):
            handled = True
            logger.log("  āœ… Empty ticker handled correctly", "TEST")
        else:
            logger.log("  āŒ Empty ticker not handled properly", "TEST")
    except Exception as e:
        logger.log(f"  āŒ Error testing invalid ticker: {e}", "ERROR")

    duration = time.time() - start_time
    logger.log_test_result(test_name, handled, duration, "Error handling tested")


def run_comprehensive_api_tests():
    """Run all API tests comprehensively"""
    logger = APITestLogger()

    logger.log("šŸš€ Starting Comprehensive TradingAgents API Test Suite", "START")
    logger.log(f"šŸ“… Test started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", "START")
    logger.log("-" * 80)

    # Base URL
    base_url = "http://localhost:8000"

    # Check if the server is already running; start one only if it is not
    server_process = None
    try:
        response = requests.get(f"{base_url}/health", timeout=5)
        already_running = response.status_code == 200
    except requests.exceptions.RequestException:
        already_running = False

    if already_running:
        logger.log("āœ… API server already running", "SERVER")
    else:
        server_process = start_api_server(logger)
        if not server_process:
            logger.log("āŒ Cannot start API server. Please run 'python run_api.py' manually",
                       "ERROR")
            return

    try:
        # Run tests
        test_health_endpoint(base_url, logger)
        test_root_endpoint(base_url, logger)

        # Test with different tickers
        test_analyze_endpoint(base_url, "NVDA", logger)
        test_streaming_endpoint(base_url, "AAPL", logger)

        # Test parallel handling
        test_parallel_requests(base_url, logger)

        # Test error handling
        test_error_handling(base_url, logger)

        # Print summary
        logger.print_summary()

    finally:
        # Clean up the server only if this script started it
        if server_process:
            logger.log("\nšŸ›‘ Stopping API server...", "SERVER")
            server_process.terminate()
            try:
                server_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                server_process.kill()


if __name__ == "__main__":
    run_comprehensive_api_tests()
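# Usage sketch (assumes this file is saved as test_api.py and that run_api.py
# serves /health, /, /analyze and /analyze/stream on http://localhost:8000;
# adjust base_url if your server listens elsewhere):
#
#   python test_api.py            # full suite; starts the server if none is running
#
# Or exercise a single test against an already-running server:
#
#   from test_api import APITestLogger, test_streaming_endpoint
#   logger = APITestLogger()
#   test_streaming_endpoint("http://localhost:8000", "AAPL", logger)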