351 lines
10 KiB
Python
351 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Thread Safety and Performance Verification Tests
|
|
|
|
Tests the concurrency and performance improvements:
|
|
1. Thread safety of AlpacaBroker
|
|
2. Connection pooling performance
|
|
3. Session isolation in web app
|
|
|
|
Run with: python test_concurrency_fixes.py
|
|
"""
|
|
|
|
import time
|
|
import threading
|
|
import os
|
|
from decimal import Decimal
|
|
from typing import List
|
|
import sys
|
|
|
|
# Add project to path
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
from tradingagents.brokers import AlpacaBroker
|
|
|
|
|
|
class TestResults:
    """Accumulate and report pass/fail outcomes for the checks in this script.

    Attributes:
        passed: Number of checks recorded through ``add_pass``.
        failed: Number of checks recorded through ``add_fail``.
        errors: One ``(test_name, error)`` tuple per recorded failure, kept in
            the order the failures were reported.
    """

    # The class name matches pytest's default ``Test*`` collection pattern.
    # Opt out explicitly so pytest does not try to collect this helper (and
    # warn about its ``__init__``) when the file is picked up by discovery.
    __test__ = False

    def __init__(self):
        self.passed = 0
        self.failed = 0
        self.errors = []  # list of (test_name, error) tuples

    def add_pass(self, test_name: str) -> None:
        """Count one passing check and print a confirmation line for it."""
        self.passed += 1
        print(f"✓ {test_name}")

    def add_fail(self, test_name: str, error: str) -> None:
        """Count one failing check, remember its error text, and print it."""
        self.failed += 1
        self.errors.append((test_name, error))
        print(f"✗ {test_name}: {error}")

    def print_summary(self) -> None:
        """Print the pass/fail totals followed by every recorded failure."""
        print("\n" + "="*60)
        print("TEST SUMMARY")
        print("="*60)
        print(f"Passed: {self.passed}")
        print(f"Failed: {self.failed}")
        if self.errors:
            print("\nFailures:")
            for test_name, error in self.errors:
                print(f" - {test_name}: {error}")
        print("="*60)
|
|
|
|
|
|
# Module-level collector: every test function in this script records its
# outcomes here, and main() prints the summary and derives the exit code from it.
results = TestResults()
|
|
|
|
|
|
def test_thread_safe_connection():
    """Check that several threads can call ``connect()`` on one broker at once.

    Starts 10 threads that each call ``connect()`` on a single shared
    ``AlpacaBroker`` and records two outcomes in the module-level ``results``:

    * ``test_thread_safe_connection`` passes when no call raised and every
      call returned a truthy value.
    * ``test_connection_state_consistency`` passes when ``broker.connected``
      is truthy after all threads have finished.

    The test is skipped (nothing is recorded) when ``ALPACA_API_KEY`` is unset.
    """
    print("\n[TEST 1] Thread-Safe Connection")
    print("-" * 60)

    # Skip if no API keys
    if not os.getenv("ALPACA_API_KEY"):
        print("⚠ Skipping (no API keys configured)")
        return

    thread_count = 10
    broker = AlpacaBroker(paper_trading=True)
    errors = []
    success_count = 0
    lock = threading.Lock()  # guards errors and success_count across workers

    def connect_broker():
        nonlocal success_count
        try:
            result = broker.connect()
            with lock:
                if result:
                    success_count += 1
        except Exception as e:
            with lock:
                errors.append(str(e))

    try:
        # Create the concurrent threads that all try to connect
        threads = [threading.Thread(target=connect_broker) for _ in range(thread_count)]

        # perf_counter() is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments while the threads run.
        start_time = time.perf_counter()
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        elapsed = time.perf_counter() - start_time

        # Verify results
        if errors:
            results.add_fail("test_thread_safe_connection", f"Errors: {errors}")
        elif success_count != thread_count:
            results.add_fail("test_thread_safe_connection", f"Only {success_count}/{thread_count} succeeded")
        else:
            results.add_pass("test_thread_safe_connection")
            print(f" All {thread_count} threads connected successfully in {elapsed:.2f}s")

        # Verify the broker reports a connected state once all threads are done
        if broker.connected:
            results.add_pass("test_connection_state_consistency")
            print(" Broker connection state is consistent")
        else:
            results.add_fail("test_connection_state_consistency", "Broker not connected after threads")
    finally:
        # Always release the connection, even if a check above raised.
        broker.disconnect()
|
|
|
|
|
|
def test_connection_pooling_performance():
    """Time repeated ``get_account()`` calls on one connected broker.

    Makes 5 sequential ``get_account()`` calls, prints the total and average
    duration, and records ``test_connection_pooling_performance`` in the
    module-level ``results``: pass when the average is below 2.0 seconds per
    call, fail otherwise or when connecting or any call fails.

    The test is skipped (nothing is recorded) when ``ALPACA_API_KEY`` is unset.
    """
    print("\n[TEST 2] Connection Pooling Performance")
    print("-" * 60)

    # Skip if no API keys
    if not os.getenv("ALPACA_API_KEY"):
        print("⚠ Skipping (no API keys configured)")
        return

    broker = AlpacaBroker(paper_trading=True)
    num_calls = 5

    try:
        # connect() lives inside the try so that a connection failure is
        # recorded as a test failure and cleanup still runs, instead of
        # aborting the whole script before the summary is printed.
        if not broker.connect():
            results.add_fail("test_connection_pooling_performance",
                             "Broker connection failed")
            return

        # Test multiple API calls; perf_counter() is the monotonic clock
        # intended for measuring short durations.
        start_time = time.perf_counter()
        for i in range(num_calls):
            account = broker.get_account()
            print(f" Call {i+1}: Got account {account.account_number[:8]}...")
        elapsed = time.perf_counter() - start_time
        avg_time = elapsed / num_calls

        print(f"\n Total time: {elapsed:.2f}s")
        print(f" Average per call: {avg_time:.2f}s")

        # With connection pooling, should be < 1s per call; up to 2s is
        # still accepted as a pass.
        if avg_time < 2.0:
            rating = "Excellent" if avg_time < 1.0 else "Good"
            results.add_pass("test_connection_pooling_performance")
            print(f" ✓ {rating} performance ({avg_time:.2f}s per call)")
        else:
            results.add_fail("test_connection_pooling_performance",
                             f"Slow performance ({avg_time:.2f}s per call)")

    except Exception as e:
        results.add_fail("test_connection_pooling_performance", str(e))

    finally:
        broker.disconnect()
|
|
|
|
|
|
def test_concurrent_api_calls():
    """Check that several threads can call ``get_account()`` concurrently.

    Starts 5 threads that each call ``get_account()`` on one shared, connected
    ``AlpacaBroker`` and records two outcomes in the module-level ``results``:

    * ``test_concurrent_api_calls`` passes when no call raised and all 5
      calls returned an account.
    * ``test_account_data_consistency`` passes when every returned account
      carries the same ``account_number``.

    The test is skipped (nothing is recorded) when ``ALPACA_API_KEY`` is unset.
    """
    print("\n[TEST 3] Concurrent API Calls")
    print("-" * 60)

    # Skip if no API keys
    if not os.getenv("ALPACA_API_KEY"):
        print("⚠ Skipping (no API keys configured)")
        return

    thread_count = 5
    broker = AlpacaBroker(paper_trading=True)
    errors = []
    accounts = []
    lock = threading.Lock()  # guards errors and accounts across workers

    def get_account():
        try:
            account = broker.get_account()
            with lock:
                accounts.append(account)
        except Exception as e:
            with lock:
                errors.append(str(e))

    try:
        broker.connect()

        # Create the concurrent threads
        threads = [threading.Thread(target=get_account) for _ in range(thread_count)]

        # perf_counter() is monotonic, so wall-clock adjustments cannot skew
        # the reported duration.
        start_time = time.perf_counter()
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        elapsed = time.perf_counter() - start_time

        # Verify results
        if errors:
            results.add_fail("test_concurrent_api_calls", f"Errors: {errors}")
        elif len(accounts) != thread_count:
            results.add_fail("test_concurrent_api_calls", f"Only {len(accounts)}/{thread_count} succeeded")
        else:
            results.add_pass("test_concurrent_api_calls")
            print(f" All {thread_count} threads completed successfully in {elapsed:.2f}s")

        # Verify all accounts are the same
        account_numbers = {a.account_number for a in accounts}
        if not accounts:
            # Nothing came back, so "different account numbers" would be a
            # misleading diagnosis; report the real reason instead.
            results.add_fail("test_account_data_consistency",
                             "No account data returned to compare")
        elif len(account_numbers) == 1:
            results.add_pass("test_account_data_consistency")
            print(" Account data is consistent across threads")
        else:
            results.add_fail("test_account_data_consistency",
                             "Different account numbers returned")
    finally:
        # Always release the connection, even if a check above raised.
        broker.disconnect()
|
|
|
|
|
|
def test_session_cleanup():
    """Check that ``disconnect()`` clears the broker's connected flag.

    Connects and then disconnects an ``AlpacaBroker`` and records two outcomes
    in the module-level ``results``:

    * ``test_connection_flag_cleared`` passes when ``broker.connected`` is
      falsy after ``disconnect()``.
    * ``test_session_cleanup`` is recorded as a pass once ``disconnect()`` has
      returned without raising; the session object itself is not inspected.

    The test is skipped (nothing is recorded) when ``ALPACA_API_KEY`` is unset.
    """
    print("\n[TEST 4] Session Cleanup")
    print("-" * 60)

    # Skip if no API keys
    if not os.getenv("ALPACA_API_KEY"):
        print("⚠ Skipping (no API keys configured)")
        return

    broker = AlpacaBroker(paper_trading=True)
    broker.connect()

    # Disconnect
    broker.disconnect()

    # Verify cleanup
    if not broker.connected:
        results.add_pass("test_connection_flag_cleared")
        print(" Connection flag cleared")
    else:
        results.add_fail("test_connection_flag_cleared", "Connection still marked as active")

    # Note: Session close is called but session might still exist
    # Just verify disconnect was called
    results.add_pass("test_session_cleanup")
    print(" Session cleanup completed")
|
|
|
|
|
|
def test_no_global_state():
    """Scan ``web_app.py`` for global declarations and session-storage usage.

    This is a plain text search of the file's contents, not an import. It
    records in the module-level ``results``:

    * ``test_no_global_declarations`` as a pass when neither
      ``global ta_graph`` nor ``global broker`` occurs in the file, otherwise
      ``test_no_global_state`` as a failure.
    * ``test_session_usage`` as a pass when both ``cl.user_session.get(`` and
      ``cl.user_session.set(`` occur in the file, a failure otherwise.

    If the file cannot be read, ``test_no_global_state`` is recorded as a
    failure with the error text.
    """
    print("\n[TEST 5] No Global State in web_app.py")
    print("-" * 60)

    # Look in the current working directory first; fall back to the directory
    # containing this script so the check also works when the script is
    # launched from somewhere else.
    web_app_path = 'web_app.py'
    if not os.path.exists(web_app_path):
        web_app_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'web_app.py')

    try:
        # Python source files are UTF-8 by default, so do not depend on the
        # platform's locale encoding when reading one.
        with open(web_app_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, UnicodeError) as e:
        results.add_fail("test_no_global_state", str(e))
        return

    # Check for global declarations
    if 'global ta_graph' in content or 'global broker' in content:
        results.add_fail("test_no_global_state", "Found global declarations in web_app.py")
    else:
        results.add_pass("test_no_global_declarations")
        print(" No global declarations found")

    # Check for session usage
    if 'cl.user_session.get(' in content and 'cl.user_session.set(' in content:
        results.add_pass("test_session_usage")
        print(" Session storage is used")
    else:
        results.add_fail("test_session_usage", "Session storage not used properly")
|
|
|
|
|
|
def test_broker_thread_safety():
    """Check that an ``AlpacaBroker`` instance exposes the expected attributes.

    Constructs a broker (injecting placeholder credentials when none are
    configured) and records one outcome per check in the module-level
    ``results``:

    * ``test_lock_exists``: the instance has a ``_lock`` attribute.
    * ``test_private_connected``: the instance has a ``_connected`` attribute.
    * ``test_connected_property``: reading ``broker.connected`` does not raise.
    * ``test_session_exists``: the instance has a ``_session`` attribute.

    No network call is made by this function itself. Placeholder credentials
    that were injected are removed again before returning, so they cannot
    leak into later ``os.getenv("ALPACA_API_KEY")`` checks in this process.
    """
    print("\n[TEST 6] Broker Thread Safety Mechanisms")
    print("-" * 60)

    # Create a dummy broker with fake credentials for testing. Only variables
    # that are currently unset are filled in (same rule as setdefault), and
    # their names are remembered so exactly those can be removed afterwards.
    fake_credentials = {
        'ALPACA_API_KEY': 'test_key',
        'ALPACA_SECRET_KEY': 'test_secret',
    }
    injected = [name for name in fake_credentials if name not in os.environ]
    for name in injected:
        os.environ[name] = fake_credentials[name]

    try:
        broker = AlpacaBroker(paper_trading=True)

        # Verify lock exists
        if hasattr(broker, '_lock'):
            results.add_pass("test_lock_exists")
            print(" Thread lock exists")
        else:
            results.add_fail("test_lock_exists", "No thread lock found")

        # Verify private _connected variable
        if hasattr(broker, '_connected'):
            results.add_pass("test_private_connected")
            print(" Private _connected variable exists")
        else:
            results.add_fail("test_private_connected", "No private _connected variable")

        # Verify connected property
        try:
            is_connected = broker.connected
            results.add_pass("test_connected_property")
            print(f" Connected property accessible (value: {is_connected})")
        except Exception as e:
            results.add_fail("test_connected_property", str(e))

        # Verify session exists
        if hasattr(broker, '_session'):
            results.add_pass("test_session_exists")
            print(" Session exists for connection pooling")
        else:
            results.add_fail("test_session_exists", "No session found")
    finally:
        # Undo only what this function added; real credentials are untouched.
        for name in injected:
            os.environ.pop(name, None)
|
|
|
|
|
|
def main():
    """Run every check, print the summary, and return a process exit code.

    The structural checks always run. The checks that talk to the broker API
    run only when both ``ALPACA_API_KEY`` and ``ALPACA_SECRET_KEY`` are set.

    Returns:
        0 when no failure was recorded, 1 otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("CONCURRENCY & PERFORMANCE VERIFICATION TESTS")
    print(banner)

    # Credentials are sampled before any test runs, so the decision reflects
    # the environment the script was started with.
    credential_vars = ("ALPACA_API_KEY", "ALPACA_SECRET_KEY")
    credentials_present = all(os.getenv(var) for var in credential_vars)

    if not credentials_present:
        print("\n⚠ No API keys - will run limited tests")
        print(" Set ALPACA_API_KEY and ALPACA_SECRET_KEY for full testing")
    else:
        print("\n✓ API keys found - will run full tests")

    # Checks that need no live credentials
    test_broker_thread_safety()
    test_no_global_state()

    # Checks that call the broker API
    if not credentials_present:
        print("\n⚠ Skipping API-dependent tests (no credentials)")
    else:
        api_tests = (
            test_thread_safe_connection,
            test_connection_pooling_performance,
            test_concurrent_api_calls,
            test_session_cleanup,
        )
        for api_test in api_tests:
            api_test()

    results.print_summary()

    # Any recorded failure turns into a non-zero exit status.
    return 1 if results.failed else 0
|
|
|
|
|
|
# Script entry point: run the checks and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|