TradingAgents/backend/test_comprehensive_fixes.py

314 lines
12 KiB
Python

#!/usr/bin/env python3
"""
Comprehensive test to verify all fixes:
1. Tool call limits (max 3 per analyst)
2. No duplicate completion messages
3. Bear researcher completion status
4. Risk analysts parallel execution
5. Proper status updates for all agents
"""
import asyncio
import json
import time
from datetime import datetime
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
def create_test_config():
"""Create test configuration"""
config = DEFAULT_CONFIG.copy()
config.update({
"llm_provider": "openai",
"deep_think_llm": "gpt-4o",
"quick_think_llm": "gpt-4o",
"backend_url": "https://api.openai.com/v1",
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,
"online_tools": True,
})
return config
def analyze_tool_calls(state, channel_name):
"""Analyze tool calls in a message channel"""
messages = state.get(channel_name, [])
tool_calls = []
tool_responses = []
for msg in messages:
if hasattr(msg, 'tool_calls') and msg.tool_calls:
for tc in msg.tool_calls:
tool_calls.append({
'name': tc.name if hasattr(tc, 'name') else 'unknown',
'args': tc.args if hasattr(tc, 'args') else {}
})
if hasattr(msg, 'type') and str(getattr(msg, 'type', '')) == 'tool':
tool_responses.append(msg)
return tool_calls, tool_responses
def test_graph_execution():
"""Test the trading graph execution with all fixes"""
print("\n" + "="*80)
print("🧪 COMPREHENSIVE TEST: Trading Graph Fixes")
print("="*80)
# Initialize graph
config = create_test_config()
graph = TradingAgentsGraph(
selected_analysts=["market", "social", "news", "fundamentals"],
debug=True,
config=config
)
# Test parameters
ticker = "TSLA"
date = datetime.now().strftime("%Y-%m-%d")
print(f"\n📊 Testing with: {ticker} on {date}")
print("-"*80)
# Track execution metrics
start_time = time.time()
chunks_processed = 0
agent_completions = {}
tool_call_counts = {}
duplicate_completions = set()
# Initialize state
init_state = graph.propagator.create_initial_state(ticker, date)
args = graph.propagator.get_graph_args()
# Stream execution
print("\n🔄 Starting graph execution...")
for chunk in graph.graph.stream(init_state, **args):
chunks_processed += 1
# Analyze message channels
for channel in ["market_messages", "social_messages", "news_messages", "fundamentals_messages"]:
if channel in chunk and chunk[channel]:
analyst_type = channel.replace("_messages", "")
# Count tool calls
tool_calls, tool_responses = analyze_tool_calls(chunk, channel)
if tool_calls:
if analyst_type not in tool_call_counts:
tool_call_counts[analyst_type] = []
tool_call_counts[analyst_type].extend(tool_calls)
print(f"🔧 {analyst_type.upper()}: {len(tool_calls)} new tool calls (total: {len(tool_call_counts[analyst_type])})")
# Check report completions
report_fields = {
"market_report": "Market Analyst",
"sentiment_report": "Social Media Analyst",
"news_report": "News Analyst",
"fundamentals_report": "Fundamentals Analyst"
}
for report_key, agent_name in report_fields.items():
if report_key in chunk and chunk[report_key]:
if agent_name in agent_completions:
duplicate_completions.add(agent_name)
print(f"⚠️ DUPLICATE COMPLETION: {agent_name}")
else:
agent_completions[agent_name] = time.time()
print(f"{agent_name} completed")
# Check Bull/Bear completion
if "investment_debate_state" in chunk:
debate_state = chunk["investment_debate_state"]
if debate_state.get("judge_decision"):
if "Bull Researcher" not in agent_completions:
agent_completions["Bull Researcher"] = time.time()
print("✅ Bull Researcher completed")
if "Bear Researcher" not in agent_completions:
agent_completions["Bear Researcher"] = time.time()
print("✅ Bear Researcher completed")
# Check risk analysts
if "risk_debate_state" in chunk:
risk_state = chunk["risk_debate_state"]
# Track parallel execution timing
if risk_state.get("current_risky_response") and "Risky Analyst" not in agent_completions:
agent_completions["Risky Analyst"] = time.time()
print("✅ Risky Analyst completed")
if risk_state.get("current_safe_response") and "Safe Analyst" not in agent_completions:
agent_completions["Safe Analyst"] = time.time()
print("✅ Safe Analyst completed")
if risk_state.get("current_neutral_response") and "Neutral Analyst" not in agent_completions:
agent_completions["Neutral Analyst"] = time.time()
print("✅ Neutral Analyst completed")
# Final state
final_state = chunk
execution_time = time.time() - start_time
# Validate final trade decision
final_decision = final_state.get("final_trade_decision", "")
decision_valid = True
decision_issues = []
if not final_decision:
decision_issues.append("No final trade decision generated")
decision_valid = False
elif "I'm sorry, but it looks like there is no paragraph" in final_decision:
decision_issues.append("Risk manager received invalid/empty data")
decision_valid = False
elif "no paragraph or financial report provided" in final_decision:
decision_issues.append("Risk manager missing required reports")
decision_valid = False
elif len(final_decision) < 100:
decision_issues.append(f"Final decision too short ({len(final_decision)} chars)")
decision_valid = False
elif not any(keyword in final_decision.upper() for keyword in ["BUY", "SELL", "HOLD"]):
decision_issues.append("Final decision missing BUY/SELL/HOLD recommendation")
decision_valid = False
# Validate risk debate state
risk_debate_state = final_state.get("risk_debate_state", {})
judge_decision = risk_debate_state.get("judge_decision", "")
if not judge_decision:
decision_issues.append("No judge decision in risk debate state")
decision_valid = False
elif judge_decision != final_decision:
decision_issues.append("Mismatch between judge_decision and final_trade_decision")
decision_valid = False
# Generate report
print("\n" + "="*80)
print("📊 EXECUTION REPORT")
print("="*80)
print(f"\n⏱️ Total execution time: {execution_time:.2f} seconds")
print(f"📦 Chunks processed: {chunks_processed}")
# Decision validation
print(f"\n🎯 FINAL DECISION VALIDATION:")
print("-"*40)
if decision_valid:
print("✅ Final trade decision is valid")
print(f"📝 Decision length: {len(final_decision)} chars")
print(f"📝 Decision preview: {final_decision[:200]}...")
else:
print("❌ Final trade decision has issues:")
for issue in decision_issues:
print(f" - {issue}")
if final_decision:
print(f"📝 Decision content: {final_decision[:500]}...")
# Tool call analysis
print("\n🔧 TOOL CALL ANALYSIS:")
print("-"*40)
issues = []
for analyst, calls in tool_call_counts.items():
print(f"\n{analyst.upper()} ANALYST:")
print(f" Total tool calls: {len(calls)}")
# Check limit
if len(calls) > 3:
issues.append(f"{analyst} exceeded tool call limit: {len(calls)} > 3")
print(f" ❌ EXCEEDED LIMIT!")
else:
print(f" ✅ Within limit")
# Check for duplicates
unique_calls = set()
duplicates = []
for call in calls:
call_str = f"{call['name']}:{json.dumps(call['args'], sort_keys=True)}"
if call_str in unique_calls:
duplicates.append(call_str)
unique_calls.add(call_str)
if duplicates:
issues.append(f"{analyst} has duplicate tool calls: {duplicates}")
print(f" ❌ DUPLICATE CALLS: {len(duplicates)}")
else:
print(f" ✅ No duplicate calls")
# Completion analysis
print("\n✅ AGENT COMPLETION ANALYSIS:")
print("-"*40)
expected_agents = [
"Market Analyst", "Social Media Analyst", "News Analyst", "Fundamentals Analyst",
"Bull Researcher", "Bear Researcher", "Risky Analyst", "Safe Analyst", "Neutral Analyst"
]
for agent in expected_agents:
if agent in agent_completions:
print(f"{agent}: Completed")
else:
issues.append(f"{agent} did not complete")
print(f"{agent}: NOT COMPLETED")
# Duplicate completion check
if duplicate_completions:
print(f"\n⚠️ DUPLICATE COMPLETIONS DETECTED: {list(duplicate_completions)}")
issues.extend([f"Duplicate completion: {agent}" for agent in duplicate_completions])
else:
print("\n✅ No duplicate completions")
# Risk analyst parallelization check
print("\n⚡ RISK ANALYST PARALLELIZATION:")
print("-"*40)
risk_analysts = ["Risky Analyst", "Safe Analyst", "Neutral Analyst"]
risk_times = {agent: agent_completions.get(agent, 0) for agent in risk_analysts}
if all(risk_times.values()):
min_time = min(risk_times.values())
max_time = max(risk_times.values())
time_spread = max_time - min_time
print(f"Time spread: {time_spread:.2f} seconds")
if time_spread < 5: # Should complete within 5 seconds of each other if parallel
print("✅ Risk analysts executed in parallel")
else:
issues.append(f"Risk analysts may not be parallel: {time_spread:.2f}s spread")
print(f"⚠️ Risk analysts may not be parallel")
else:
missing = [a for a in risk_analysts if not risk_times[a]]
issues.append(f"Risk analysts did not complete: {missing}")
print(f"❌ Risk analysts did not complete: {missing}")
# Final verdict
print("\n" + "="*80)
print("🎯 FINAL VERDICT")
print("="*80)
# Add decision issues to overall issues
if decision_issues:
issues.extend(decision_issues)
if not issues:
print("\n✅ ALL TESTS PASSED! 🎉")
print("\nKey achievements:")
print("- Tool calls limited to 3 per analyst")
print("- No duplicate completions")
print("- Bear researcher properly tracked")
print("- Risk analysts run in parallel")
print("- Final trade decision is valid and complete")
print(f"- Total execution time: {execution_time:.2f}s")
else:
print("\n❌ ISSUES FOUND:")
for i, issue in enumerate(issues, 1):
print(f"{i}. {issue}")
return not bool(issues), final_state
if __name__ == "__main__":
success, final_state = test_graph_execution()
# Check final decision
if final_state.get("final_trade_decision"):
print(f"\n📊 Final trading decision: {final_state['final_trade_decision'][:100]}...")
exit(0 if success else 1)