TradingAgents/backend/test_parallel_execution.py

323 lines
15 KiB
Python

#!/usr/bin/env python
"""
Test specifically for parallel execution verification
"""
import time
from datetime import datetime
from collections import defaultdict
import threading
import json
from pathlib import Path
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
class ParallelExecutionTracker:
"""Track parallel execution of agents"""
def __init__(self):
self.active_agents = {} # agent_name -> start_time
self.parallel_groups = [] # List of sets of agents that ran in parallel
self.agent_timeline = [] # List of (time, agent, action) tuples
self.lock = threading.Lock()
def agent_started(self, agent_name, timestamp=None):
"""Record agent start"""
timestamp = timestamp or time.time()
with self.lock:
self.active_agents[agent_name] = timestamp
self.agent_timeline.append((timestamp, agent_name, 'start'))
# Check if multiple agents are active
if len(self.active_agents) > 1:
parallel_set = set(self.active_agents.keys())
self.parallel_groups.append({
'agents': parallel_set,
'time': timestamp,
'count': len(parallel_set)
})
print(f"🔄 PARALLEL EXECUTION: {list(parallel_set)} at {datetime.fromtimestamp(timestamp).strftime('%H:%M:%S.%f')[:-3]}")
def agent_ended(self, agent_name, timestamp=None):
"""Record agent end"""
timestamp = timestamp or time.time()
with self.lock:
if agent_name in self.active_agents:
start_time = self.active_agents.pop(agent_name)
duration = timestamp - start_time
self.agent_timeline.append((timestamp, agent_name, 'end'))
print(f"{agent_name} completed in {duration:.2f}s")
def get_parallel_summary(self):
"""Get summary of parallel executions"""
summary = {
'total_parallel_groups': len(self.parallel_groups),
'max_parallel_agents': max((g['count'] for g in self.parallel_groups), default=0),
'parallel_groups': self.parallel_groups,
'timeline': sorted(self.agent_timeline, key=lambda x: x[0])
}
return summary
def test_parallel_execution():
"""Test that agents execute in parallel when expected"""
print("🚀 Testing Parallel Execution of TradingAgents")
print("=" * 80)
# Create results directory
results_dir = Path("test_results/parallel_execution")
results_dir.mkdir(parents=True, exist_ok=True)
# Configure for testing
config = DEFAULT_CONFIG.copy()
config.update({
"llm_provider": "google",
"backend_url": "https://generativelanguage.googleapis.com/v1",
"deep_think_llm": "gemini-2.0-flash",
"quick_think_llm": "gemini-2.0-flash",
"max_debate_rounds": 2,
"online_tools": True
})
# Create tracker
tracker = ParallelExecutionTracker()
# Custom TradingAgentsGraph to track execution
class TrackedGraph(TradingAgentsGraph):
def __init__(self, *args, tracker=None, **kwargs):
super().__init__(*args, **kwargs)
self.tracker = tracker
self.message_timestamps = []
def propagate(self, company_name, trade_date):
"""Enhanced propagate with parallel tracking"""
self.ticker = company_name
# Initialize state
init_agent_state = self.propagator.create_initial_state(company_name, trade_date)
args = self.propagator.get_graph_args()
trace = []
agent_states = {} # Track agent states
print(f"\n📊 Starting analysis for {company_name} on {trade_date}")
print("-" * 60)
# Process stream
for chunk_idx, chunk in enumerate(self.graph.stream(init_agent_state, **args)):
timestamp = time.time()
# Detect which agents are active based on chunk content
chunk_agents = set()
# Check for analyst reports
if "market_report" in chunk and chunk["market_report"] and "market_analyst" not in agent_states:
agent_states["market_analyst"] = "completed"
if self.tracker:
self.tracker.agent_ended("market_analyst", timestamp)
if "sentiment_report" in chunk and chunk["sentiment_report"] and "social_analyst" not in agent_states:
agent_states["social_analyst"] = "completed"
if self.tracker:
self.tracker.agent_ended("social_analyst", timestamp)
if "news_report" in chunk and chunk["news_report"] and "news_analyst" not in agent_states:
agent_states["news_analyst"] = "completed"
if self.tracker:
self.tracker.agent_ended("news_analyst", timestamp)
if "fundamentals_report" in chunk and chunk["fundamentals_report"] and "fundamentals_analyst" not in agent_states:
agent_states["fundamentals_analyst"] = "completed"
if self.tracker:
self.tracker.agent_ended("fundamentals_analyst", timestamp)
# Check messages for agent activity
if len(chunk.get("messages", [])) > 0:
last_message = chunk["messages"][-1]
# Try to identify agent from message
agent_name = None
if hasattr(last_message, 'name') and last_message.name:
agent_name = last_message.name
# Map common agent names
agent_mapping = {
"MarketAnalyst": "market_analyst",
"SocialMediaAnalyst": "social_analyst",
"NewsAnalyst": "news_analyst",
"FundamentalsAnalyst": "fundamentals_analyst",
"BullResearcher": "bull_researcher",
"BearResearcher": "bear_researcher",
"ResearchManager": "research_manager",
"Trader": "trader",
"RiskManager": "risk_manager"
}
if agent_name in agent_mapping:
mapped_name = agent_mapping[agent_name]
if mapped_name not in agent_states:
agent_states[mapped_name] = "active"
if self.tracker:
self.tracker.agent_started(mapped_name, timestamp)
# Check for tool calls which indicate agent activity
if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
# Analysts are likely active when tools are called
tool_names = [tc.name if hasattr(tc, 'name') else '' for tc in last_message.tool_calls]
# Map tools to analysts
if any('YFin' in name or 'stockstats' in name for name in tool_names):
if "market_analyst" not in agent_states:
agent_states["market_analyst"] = "active"
if self.tracker:
self.tracker.agent_started("market_analyst", timestamp)
if any('reddit' in name or 'stock_news' in name for name in tool_names):
if "social_analyst" not in agent_states:
agent_states["social_analyst"] = "active"
if self.tracker:
self.tracker.agent_started("social_analyst", timestamp)
if any('news' in name or 'google_news' in name for name in tool_names):
if "news_analyst" not in agent_states:
agent_states["news_analyst"] = "active"
if self.tracker:
self.tracker.agent_started("news_analyst", timestamp)
if any('fundamentals' in name or 'simfin' in name or 'finnhub' in name for name in tool_names):
if "fundamentals_analyst" not in agent_states:
agent_states["fundamentals_analyst"] = "active"
if self.tracker:
self.tracker.agent_started("fundamentals_analyst", timestamp)
# Check for debate states indicating researcher activity
if "investment_debate_state" in chunk:
debate_state = chunk["investment_debate_state"]
if debate_state.get("bull_history") and "bull_researcher" not in agent_states:
agent_states["bull_researcher"] = "active"
if self.tracker:
self.tracker.agent_started("bull_researcher", timestamp)
if debate_state.get("bear_history") and "bear_researcher" not in agent_states:
agent_states["bear_researcher"] = "active"
if self.tracker:
self.tracker.agent_started("bear_researcher", timestamp)
if debate_state.get("judge_decision"):
# Mark researchers as completed
if "bull_researcher" in agent_states and agent_states["bull_researcher"] == "active":
agent_states["bull_researcher"] = "completed"
if self.tracker:
self.tracker.agent_ended("bull_researcher", timestamp)
if "bear_researcher" in agent_states and agent_states["bear_researcher"] == "active":
agent_states["bear_researcher"] = "completed"
if self.tracker:
self.tracker.agent_ended("bear_researcher", timestamp)
trace.append(chunk)
# Mark any remaining active agents as completed
final_timestamp = time.time()
for agent, state in agent_states.items():
if state == "active" and self.tracker:
self.tracker.agent_ended(agent, final_timestamp)
final_state = trace[-1] if trace else {}
self.curr_state = final_state
self._log_state(trade_date, final_state)
return final_state, self.process_signal(final_state["final_trade_decision"])
# Run test
print("\n🧪 Running parallel execution test...")
try:
# Create tracked graph
graph = TrackedGraph(
debug=True,
config=config,
tracker=tracker
)
# Run analysis
start_time = time.time()
final_state, decision = graph.propagate("AAPL", "2024-05-15")
total_time = time.time() - start_time
print(f"\n✅ Analysis completed in {total_time:.2f}s")
print(f"📊 Decision: {decision}")
# Get parallel execution summary
summary = tracker.get_parallel_summary()
print("\n" + "=" * 80)
print("PARALLEL EXECUTION SUMMARY")
print("=" * 80)
print(f"Total parallel groups detected: {summary['total_parallel_groups']}")
print(f"Maximum agents running in parallel: {summary['max_parallel_agents']}")
if summary['parallel_groups']:
print("\nParallel execution instances:")
for i, group in enumerate(summary['parallel_groups']):
agents_str = ", ".join(sorted(group['agents']))
timestamp_str = datetime.fromtimestamp(group['time']).strftime('%H:%M:%S.%f')[:-3]
print(f" {i+1}. [{timestamp_str}] {group['count']} agents: {agents_str}")
# Analyze timeline
print("\nExecution timeline:")
for timestamp, agent, action in summary['timeline'][:20]: # Show first 20 events
timestamp_str = datetime.fromtimestamp(timestamp).strftime('%H:%M:%S.%f')[:-3]
symbol = "▶️" if action == "start" else "⏹️"
print(f" [{timestamp_str}] {symbol} {agent} {action}")
if len(summary['timeline']) > 20:
print(f" ... and {len(summary['timeline']) - 20} more events")
# Save results
results_file = results_dir / "parallel_execution_summary.json"
with open(results_file, 'w') as f:
# Convert to serializable format
serializable_summary = {
'total_time': total_time,
'decision': decision,
'parallel_summary': {
'total_parallel_groups': summary['total_parallel_groups'],
'max_parallel_agents': summary['max_parallel_agents'],
'parallel_groups': [
{
'agents': list(g['agents']),
'time': g['time'],
'count': g['count']
}
for g in summary['parallel_groups']
],
'timeline': [
{
'timestamp': t,
'agent': a,
'action': act
}
for t, a, act in summary['timeline']
]
}
}
json.dump(serializable_summary, f, indent=2)
print(f"\n📁 Results saved to: {results_file}")
# Verify parallel execution occurred
if summary['total_parallel_groups'] > 0:
print("\n✅ PARALLEL EXECUTION VERIFIED!")
print(f" Found {summary['total_parallel_groups']} instances of parallel agent execution")
else:
print("\n⚠️ WARNING: No parallel execution detected!")
print(" This might indicate a performance issue or sequential execution")
except Exception as e:
print(f"\n❌ Error during test: {str(e)}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
test_parallel_execution()