# TradingAgents/graph/setup.py
"""Build the TradingAgents LangGraph workflow with parallel analyst execution."""

import hashlib
import json
import logging
from typing import Any, Callable, Dict, List, Sequence, Set, Tuple

from langchain_core.messages import HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode

from tradingagents.agents import *
from tradingagents.agents.utils.agent_states import AgentState
from tradingagents.agents.utils.agent_utils import Toolkit

from .conditional_logic import ConditionalLogic

# Set up logging
# NOTE(review): calling basicConfig at import time configures the root logger
# for every program that imports this module; kept for backward compatibility.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Analysts wired into the graph when the caller does not choose explicitly.
DEFAULT_ANALYSTS: Tuple[str, ...] = ("market", "social", "news", "fundamentals")

# analyst type -> (AgentState key of its private message channel,
#                  AgentState key that receives its final report)
ANALYST_CHANNELS: Dict[str, Tuple[str, str]] = {
    "market": ("market_messages", "market_report"),
    "social": ("social_messages", "sentiment_report"),
    "news": ("news_messages", "news_report"),
    "fundamentals": ("fundamentals_messages", "fundamentals_report"),
}


def _count_tool_messages(messages: Sequence[Any]) -> int:
    """Return how many entries of *messages* have ``type == "tool"``."""
    return sum(1 for msg in messages if str(getattr(msg, "type", "")) == "tool")


class ToolCallTracker:
    """Track tool calls per analyst to enforce call budgets and reject duplicates.

    Bookkeeping, keyed by analyst type:
      * ``call_history``: {tool_name: [(params_hash, params_json), ...]} of executed calls
      * ``call_counts``:  {tool_name: number of executed calls}
      * ``total_calls``:  attempts charged against the budget -- executed calls
        plus rejected/failed attempts (see ``record_rejected_call``)
    """

    def __init__(self):
        self.call_history = {}  # analyst_type -> {tool_name: [(params_hash, params_str)]}
        self.call_counts = {}  # analyst_type -> {tool_name: count}
        # Different limits for different analyst types
        self.max_total_calls = {
            "market": 20,  # Market analyst needs more tool calls for comprehensive analysis
            "social": 3,
            "news": 3,
            "fundamentals": 3,
        }
        self.total_calls = {}  # analyst_type -> total_count

    def reset(self) -> None:
        """Forget every recorded call (limits are kept); call at the start of each run."""
        self.call_history.clear()
        self.call_counts.clear()
        self.total_calls.clear()

    def _hash_params(self, params: dict) -> str:
        """Create a hash of parameters for comparison."""
        # Sort keys so that logically equal argument dicts hash identically.
        sorted_params = json.dumps(params, sort_keys=True)
        return hashlib.md5(sorted_params.encode()).hexdigest()

    def _get_max_calls_for_analyst(self, analyst_type: str) -> int:
        """Get the maximum number of calls allowed for a specific analyst type."""
        return self.max_total_calls.get(analyst_type, 3)  # Default to 3 if not specified

    def _ensure_tracking(self, analyst_type: str, tool_name: str = None) -> None:
        """Create empty bookkeeping entries for *analyst_type* (and *tool_name*)."""
        if analyst_type not in self.call_history:
            self.call_history[analyst_type] = {}
            self.call_counts[analyst_type] = {}
            self.total_calls[analyst_type] = 0
        if tool_name is not None and tool_name not in self.call_history[analyst_type]:
            self.call_history[analyst_type][tool_name] = []
            self.call_counts[analyst_type][tool_name] = 0

    def can_call_tool(self, analyst_type: str, tool_name: str, params: dict) -> Tuple[bool, str]:
        """Check whether *tool_name* may be called with *params*.

        Returns:
            ``(True, "OK")`` when allowed, otherwise ``(False, reason)`` -- either
            the analyst's budget is exhausted or identical parameters were
            already used for this tool.
        """
        self._ensure_tracking(analyst_type, tool_name)

        # Check total call limit for this analyst
        max_calls = self._get_max_calls_for_analyst(analyst_type)
        if self.total_calls[analyst_type] >= max_calls:
            return False, f"Analyst {analyst_type} has reached maximum total tool calls ({max_calls})"

        # Check for duplicate parameters - each request/query must be different
        param_hash = self._hash_params(params)
        for existing_hash, existing_params in self.call_history[analyst_type][tool_name]:
            if param_hash == existing_hash:
                return False, f"Tool {tool_name} already called with identical parameters. Each request must be different. Previous: {existing_params}"

        return True, "OK"

    def record_tool_call(self, analyst_type: str, tool_name: str, params: dict):
        """Record an executed tool call and charge it against the analyst's budget."""
        self._ensure_tracking(analyst_type, tool_name)

        param_hash = self._hash_params(params)
        param_str = json.dumps(params, sort_keys=True)
        self.call_history[analyst_type][tool_name].append((param_hash, param_str))
        self.call_counts[analyst_type][tool_name] += 1
        self.total_calls[analyst_type] += 1

        max_calls = self._get_max_calls_for_analyst(analyst_type)
        logger.info(f"🔧 Recorded tool call: {analyst_type}/{tool_name} (total calls: {self.total_calls[analyst_type]}/{max_calls})")

    def record_rejected_call(self, analyst_type: str) -> None:
        """Charge a rejected or failed tool-call attempt against the analyst's budget.

        The attempt is NOT added to ``call_history``, so a call that failed may
        be retried. Charging it guarantees that an analyst which keeps
        requesting disallowed calls eventually exhausts its budget instead of
        looping between the analyst and tools nodes.
        """
        self._ensure_tracking(analyst_type)
        self.total_calls[analyst_type] += 1
        max_calls = self._get_max_calls_for_analyst(analyst_type)
        logger.info(f"🔧 Charged rejected tool call: {analyst_type} (total calls: {self.total_calls[analyst_type]}/{max_calls})")


class GraphSetup:
    """Handles the setup and configuration of the agent graph with parallel analyst execution."""

    def __init__(
        self,
        quick_thinking_llm: ChatOpenAI,
        deep_thinking_llm: ChatOpenAI,
        toolkit: Toolkit,
        tool_nodes: Dict[str, ToolNode],
        bull_memory,
        bear_memory,
        trader_memory,
        invest_judge_memory,
        risk_manager_memory,
        conditional_logic: ConditionalLogic,
    ):
        """Store the LLMs, tools, memories and routing logic used to build the graph."""
        self.quick_thinking_llm = quick_thinking_llm
        self.deep_thinking_llm = deep_thinking_llm
        self.toolkit = toolkit
        self.tool_nodes = tool_nodes
        self.bull_memory = bull_memory
        self.bear_memory = bear_memory
        self.trader_memory = trader_memory
        self.invest_judge_memory = invest_judge_memory
        self.risk_manager_memory = risk_manager_memory
        self.conditional_logic = conditional_logic

        # Per-run tool-call bookkeeping; reset by the Dispatcher on every invocation.
        self.tool_tracker = ToolCallTracker()
        # Track report completions to prevent duplicates; also reset per run.
        self.completed_reports = set()

    # ------------------------------------------------------------------ graph

    def setup_graph(self, selected_analysts=DEFAULT_ANALYSTS):
        """Set up and compile the agent workflow graph with parallel analyst execution.

        Args:
            selected_analysts: Iterable of analyst types (keys of
                ``ANALYST_CHANNELS``) to run in parallel.

        Returns:
            The compiled LangGraph workflow.

        Raises:
            ValueError: if no analyst is selected or an analyst type is unknown.
        """
        selected_analysts = list(selected_analysts)
        if not selected_analysts:
            raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")

        logger.info(f"🚀 Setting up parallel graph with analysts: {selected_analysts}")

        workflow = StateGraph(AgentState)

        logger.info("📋 Adding Dispatcher node")
        workflow.add_node("Dispatcher", self._create_dispatcher())

        for analyst_type in selected_analysts:
            self._add_analyst_nodes(workflow, analyst_type)

        logger.info("📊 Adding Aggregator node")
        workflow.add_node("Aggregator", self._create_aggregator())

        self._add_downstream_nodes(workflow)

        logger.info("🔗 Setting up graph edges for parallel execution")
        workflow.add_edge(START, "Dispatcher")
        for analyst_type in selected_analysts:
            self._add_analyst_edges(workflow, analyst_type)

        # Fan-in: a list of start keys makes LangGraph wait until EVERY analyst
        # branch has reached its "_done" node, so the Aggregator (and everything
        # downstream of it) runs exactly once, even when analysts finish in
        # different supersteps.
        workflow.add_edge([f"{analyst_type}_done" for analyst_type in selected_analysts], "Aggregator")

        self._add_downstream_edges(workflow)

        logger.info("✅ Graph setup complete, compiling...")
        return workflow.compile()

    def _create_analyst_node(self, analyst_type: str):
        """Instantiate the raw analyst node for *analyst_type*.

        Raises:
            ValueError: for an unknown analyst type.
        """
        if analyst_type == "market":
            return create_market_analyst(self.quick_thinking_llm, self.toolkit)
        if analyst_type == "social":
            return create_social_media_analyst(self.quick_thinking_llm, self.toolkit)
        if analyst_type == "news":
            return create_news_analyst(self.quick_thinking_llm, self.toolkit)
        if analyst_type == "fundamentals":
            return create_fundamentals_analyst(self.quick_thinking_llm, self.toolkit)
        raise ValueError(f"Unknown analyst type: {analyst_type}")

    def _add_analyst_nodes(self, workflow: StateGraph, analyst_type: str) -> None:
        """Add the analyst, tools and completion-marker nodes for one analyst."""
        logger.info(f"🔧 Adding {analyst_type} analyst nodes")

        analyst_node = self._create_analyst_node(analyst_type)
        tool_node = self.tool_nodes[analyst_type]
        message_key, report_key = ANALYST_CHANNELS[analyst_type]

        workflow.add_node(
            f"{analyst_type}_analyst",
            self._wrap_analyst_for_channel(analyst_node, message_key, report_key, analyst_type),
        )
        workflow.add_node(
            f"{analyst_type}_tools",
            self._wrap_tool_node_for_channel(tool_node, message_key, analyst_type),
        )
        workflow.add_node(f"{analyst_type}_done", self._create_analyst_done(analyst_type))

    def _add_analyst_edges(self, workflow: StateGraph, analyst_type: str) -> None:
        """Wire Dispatcher -> analyst <-> tools -> done for one analyst."""
        workflow.add_edge("Dispatcher", f"{analyst_type}_analyst")
        workflow.add_conditional_edges(
            f"{analyst_type}_analyst",
            self._make_analyst_router(analyst_type),
            {
                "tools": f"{analyst_type}_tools",
                "aggregator": f"{analyst_type}_done",
            },
        )
        # After tools execute, always return to the analyst: it decides whether
        # to request more tools or to produce its final report.
        workflow.add_edge(f"{analyst_type}_tools", f"{analyst_type}_analyst")

    def _make_analyst_router(self, analyst_type: str) -> Callable[[AgentState], str]:
        """Build the router deciding where an analyst goes after it runs.

        The router returns ``"tools"`` while the analyst's last message requests
        tool calls and its budget is not exhausted, and ``"aggregator"`` (this
        analyst is finished) otherwise, including once its report exists.
        """
        message_key, report_key = ANALYST_CHANNELS[analyst_type]

        def should_continue_analyst(state: AgentState) -> str:
            if state.get(report_key, ""):
                return "aggregator"

            messages = state.get(message_key, [])
            if not messages:
                return "aggregator"

            has_tool_calls = bool(getattr(messages[-1], "tool_calls", None))
            total_calls = self.tool_tracker.total_calls.get(analyst_type, 0)
            max_calls = self.tool_tracker.max_total_calls.get(analyst_type, 3)

            if has_tool_calls and total_calls < max_calls:
                return "tools"
            return "aggregator"

        return should_continue_analyst

    @staticmethod
    def _create_analyst_done(analyst_type: str):
        """Create a no-op node marking that *analyst_type*'s branch has finished."""

        def mark_done(state: AgentState) -> dict:
            logger.info(f"🏁 {analyst_type} analyst branch finished")
            return {}

        return mark_done

    def _add_downstream_nodes(self, workflow: StateGraph) -> None:
        """Add the research debate, trader and parallel risk-analysis nodes."""
        bull_researcher_node = create_bull_researcher(self.quick_thinking_llm, self.bull_memory)
        bear_researcher_node = create_bear_researcher(self.quick_thinking_llm, self.bear_memory)
        research_manager_node = create_research_manager(self.deep_thinking_llm, self.invest_judge_memory)
        trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)

        risky_analyst_node = create_risky_debator(self.quick_thinking_llm)
        neutral_analyst_node = create_neutral_debator(self.quick_thinking_llm)
        safe_analyst_node = create_safe_debator(self.quick_thinking_llm)
        risk_manager_node = create_risk_manager(self.deep_thinking_llm, self.risk_manager_memory)

        workflow.add_node("Bull Researcher", bull_researcher_node)
        workflow.add_node("Bear Researcher", bear_researcher_node)
        workflow.add_node("Research Manager", research_manager_node)
        workflow.add_node("Trader", trader_node)

        # Risk dispatcher/aggregator bracket the parallel risk analysts.
        workflow.add_node("Risk Dispatcher", self._create_risk_dispatcher())
        workflow.add_node("Risky Analyst", self._wrap_risk_analyst_for_channel(risky_analyst_node, "risky"))
        workflow.add_node("Safe Analyst", self._wrap_risk_analyst_for_channel(safe_analyst_node, "safe"))
        workflow.add_node("Neutral Analyst", self._wrap_risk_analyst_for_channel(neutral_analyst_node, "neutral"))
        workflow.add_node("Risk Aggregator", self._create_risk_aggregator())
        workflow.add_node("Risk Judge", risk_manager_node)

    def _add_downstream_edges(self, workflow: StateGraph) -> None:
        """Wire Aggregator -> debate -> Trader -> parallel risk analysis -> END."""
        workflow.add_edge("Aggregator", "Bull Researcher")

        workflow.add_conditional_edges(
            "Bull Researcher",
            self.conditional_logic.should_continue_debate,
            {
                "Bear Researcher": "Bear Researcher",
                "Research Manager": "Research Manager",
            },
        )
        workflow.add_conditional_edges(
            "Bear Researcher",
            self.conditional_logic.should_continue_debate,
            {
                "Bull Researcher": "Bull Researcher",
                "Research Manager": "Research Manager",
            },
        )

        workflow.add_edge("Research Manager", "Trader")
        workflow.add_edge("Trader", "Risk Dispatcher")

        # Parallel risk analyst execution
        for risk_node in ("Risky Analyst", "Safe Analyst", "Neutral Analyst"):
            workflow.add_edge("Risk Dispatcher", risk_node)
            workflow.add_edge(risk_node, "Risk Aggregator")

        # Risk Aggregator hands the combined analyses to the Risk Judge.
        workflow.add_edge("Risk Aggregator", "Risk Judge")
        workflow.add_edge("Risk Judge", END)

    # ------------------------------------------------------------- dispatcher

    def _create_dispatcher(self):
        """Create the entry node that resets per-run bookkeeping and seeds message channels."""

        def dispatch(state: AgentState) -> dict:
            logger.info("=" * 80)
            logger.info("📋 NODE EXECUTING: DISPATCHER")
            logger.info("=" * 80)

            # The compiled graph may be invoked many times; without this reset,
            # call budgets, duplicate-call history and completion markers from a
            # previous run would leak into this one.
            self.tool_tracker.reset()
            self.completed_reports.clear()

            company = state.get("company_of_interest", "Unknown")
            date = state.get("trade_date", "Unknown")
            logger.info(f"📋 Dispatcher: Starting parallel analysis for {company} on {date}")

            # Each analyst gets its own channel seeded with the same request.
            initial_message = f"Analyze {company} on {date}"
            update = {
                message_key: [HumanMessage(content=initial_message)]
                for message_key, _ in ANALYST_CHANNELS.values()
            }

            logger.info("📋 Dispatcher: Initialized all analyst message channels")
            logger.info("📋 Dispatcher: Starting Market, Social, News, and Fundamentals analysts in parallel")
            logger.info("✅ DISPATCHER COMPLETE")
            return update

        return dispatch

    # --------------------------------------------------------------- analysts

    def _wrap_analyst_for_channel(self, analyst_node, message_key: str, report_key: str, analyst_type: str):
        """Wrap an analyst node to work with a specific message channel.

        The wrapped analyst sees its private channel as ``state["messages"]``.
        Its output messages are written back to *message_key*, and a final
        report (returned directly by the node or extracted from its last
        message) is written to *report_key*.
        """

        def wrapped_analyst(state: AgentState) -> dict:
            logger.info("-" * 60)
            logger.info(f"🧠 NODE EXECUTING: {analyst_type.upper()} ANALYST")
            logger.info("-" * 60)

            # Check if report already exists - prevent duplicate completion
            if state.get(report_key, ""):
                logger.info(f"🧠 {analyst_type} analyst: ✅ REPORT ALREADY EXISTS - skipping")
                report_id = f"{analyst_type}_report_completed"
                if report_id in self.completed_reports:
                    logger.info(f"🧠 {analyst_type} analyst: Report already marked as completed, skipping all updates")
                    return {}
                self.completed_reports.add(report_id)
                logger.info(f"🧠 {analyst_type} analyst: First time seeing completed report, allowing one update")
                return {message_key: state.get(message_key, [])}

            messages = state.get(message_key, [])
            logger.info(f"🧠 {analyst_type} analyst: Processing {len(messages)} messages")
            logger.info(f"🧠 {analyst_type} analyst: Tool messages in history: {_count_tool_messages(messages)}")

            # The wrapped analyst reads the generic "messages" key, so expose
            # this analyst's private channel under that name.
            temp_state = state.copy()
            temp_state["messages"] = messages

            try:
                logger.info(f"🧠 {analyst_type} analyst: Invoking LLM...")
                result = analyst_node(temp_state)
            except Exception as exc:
                logger.exception(f"❌ {analyst_type} analyst error: {exc}")
                raise
            logger.info(f"🧠 {analyst_type} analyst: LLM response received")

            updated_messages = result.get("messages", messages)
            logger.info(f"🧠 {analyst_type} analyst: Updated from {len(messages)} to {len(updated_messages)} messages")

            direct_report = result.get(report_key, "")
            if direct_report:
                logger.info(f"🧠 {analyst_type} analyst: ✅ DIRECT REPORT GENERATED ({len(direct_report)} chars)")
                report = direct_report
            else:
                report = self._report_from_messages(analyst_type, updated_messages, state)

            update = {message_key: updated_messages}
            if report:
                update[report_key] = report
                self.completed_reports.add(f"{analyst_type}_report_completed")
                logger.info(f"🧠 {analyst_type} analyst: ✅ SETTING {report_key}")

            logger.info(f"✅ {analyst_type.upper()} ANALYST COMPLETE")
            return update

        return wrapped_analyst

    def _report_from_messages(self, analyst_type: str, messages: Sequence[Any], state: AgentState) -> str:
        """Heuristically extract a final report from the analyst's latest message.

        Returns an empty string when the conversation does not look finished or
        the candidate content is too short to count as a report.
        """
        if not messages:
            return ""

        last_message = messages[-1]
        has_tool_calls = bool(getattr(last_message, "tool_calls", None))
        has_content = bool(getattr(last_message, "content", None))
        logger.info(f"🧠 {analyst_type} analyst: Last message has_tool_calls={has_tool_calls}, has_content={has_content}")

        content = str(last_message.content) if has_content else ""

        tool_result_count = _count_tool_messages(messages)
        logger.info(f"🧠 {analyst_type} analyst: Total tool results: {tool_result_count}")
        if tool_result_count == 0:
            # Fall back to isinstance in case messages lack a usable ``type``.
            tool_result_count = sum(1 for msg in messages if isinstance(msg, ToolMessage))
            logger.info(f"🧠 {analyst_type} analyst: ToolMessage instances: {tool_result_count}")

        should_generate_report = False
        if has_content and not has_tool_calls:
            # No more tool calls, has content - likely final response
            should_generate_report = True
            logger.info(f"🧠 {analyst_type} analyst: Final response detected (no tool calls)")
        elif has_content and tool_result_count > 0:
            # Content alongside further tool calls: accept it once the analyst
            # has gathered enough data (market >= 4 results, social >= 2, others >= 1).
            if analyst_type == "market":
                if tool_result_count >= 4:
                    should_generate_report = True
                    logger.info(f"🧠 {analyst_type} analyst: Market analyst with sufficient tool results ({tool_result_count})")
            elif analyst_type == "social":
                if tool_result_count >= 2:
                    should_generate_report = True
                    logger.info(f"🧠 {analyst_type} analyst: Social analyst with sufficient tool results ({tool_result_count})")
            else:
                should_generate_report = True
                logger.info(f"🧠 {analyst_type} analyst: {analyst_type} analyst with tool results ({tool_result_count})")
        elif not has_content and tool_result_count > 0:
            # Tool results but no content yet - force completion once the budget is spent.
            total_calls = self.tool_tracker.total_calls.get(analyst_type, 0)
            max_calls = self.tool_tracker.max_total_calls.get(analyst_type, 3)
            if total_calls >= max_calls or tool_result_count >= 4:
                logger.info(f"🧠 {analyst_type} analyst: Forcing completion with available data ({tool_result_count} tool results)")
                should_generate_report = True
                content = f"Analysis for {state.get('company_of_interest', 'unknown')} based on {tool_result_count} data sources and technical analysis."

        if analyst_type == "market" and has_tool_calls and not has_content:
            logger.info(f"🧠 {analyst_type} analyst: Market analyst making more tool calls")

        if not (should_generate_report and content):
            logger.info(f"🧠 {analyst_type} analyst: Not ready for final report yet")
            return ""

        # Only consider it a report if it has substantial content
        if len(content) > 100 or (tool_result_count > 0 and len(content) > 20):
            logger.info(f"🧠 {analyst_type} analyst: ✅ FINAL REPORT GENERATED FROM MESSAGE ({len(content)} chars)")
            return content

        logger.info(f"🧠 {analyst_type} analyst: Content too short for report ({len(content)} chars)")
        return ""

    # ------------------------------------------------------------------ tools

    @staticmethod
    def _parse_tool_call(tool_call):
        """Return ``(name, args, call_id)`` for a dict- or object-style tool call, or None."""
        if isinstance(tool_call, dict):
            return (
                tool_call.get("name", "") or "",
                tool_call.get("args") or {},
                tool_call.get("id") or "unknown",
            )
        if hasattr(tool_call, "name"):
            return (
                tool_call.name or "",
                getattr(tool_call, "args", None) or {},
                getattr(tool_call, "id", None) or "unknown",
            )
        return None

    def _execute_tool_call(self, tool_node, analyst_type: str, tool_name: str, tool_args: dict, position: int, total: int):
        """Run one tool call under the tracker's rules.

        Returns:
            ``(content, executed)``: the text to send back in a ToolMessage and
            whether the call was executed (and recorded) by the tracker.
        """
        if not tool_name:
            logger.error(f"❌ {analyst_type} tools: Empty tool name")
            self.tool_tracker.record_rejected_call(analyst_type)
            return "Error: tool call is missing a tool name.", False

        can_call, reason = self.tool_tracker.can_call_tool(analyst_type, tool_name, tool_args)
        if not can_call:
            logger.warning(f"🔧 {analyst_type} tools: SKIPPING - {reason}")
            self.tool_tracker.record_rejected_call(analyst_type)
            return f"Tool call skipped: {reason}", False

        logger.info(f"🔧 {analyst_type} tools: [{position}/{total}] Executing {tool_name}")
        tool = tool_node.tools_by_name.get(tool_name)
        if tool is None:
            logger.error(f"❌ {analyst_type} tools: Tool {tool_name} not found")
            tool_result = f"Error: Tool {tool_name} not found"
        else:
            tool_result = tool.invoke(tool_args)

        logger.info(f"🔧 {analyst_type} tools: ✅ Added ToolMessage for {tool_name}")
        self.tool_tracker.record_tool_call(analyst_type, tool_name, tool_args)
        return str(tool_result), True

    def _wrap_tool_node_for_channel(self, tool_node, message_key: str, analyst_type: str):
        """Wrap a tool node to work with a specific message channel with tool call limits.

        Every tool call that carries an id is answered with a ToolMessage, even
        when it is skipped (duplicate / over budget) or fails. Chat-completion
        APIs reject an assistant message whose tool calls are left unanswered.
        """

        def wrapped_tool_node(state: AgentState) -> dict:
            logger.info("-" * 60)
            logger.info(f"🔧 NODE EXECUTING: {analyst_type.upper()} TOOLS")
            logger.info("-" * 60)

            messages = state.get(message_key, [])
            logger.info(f"🔧 {analyst_type} tools: Processing {len(messages)} messages")

            if not messages:
                logger.error(f"❌ {analyst_type} tools: No messages found")
                return {message_key: messages}

            last_msg = messages[-1]
            logger.info(f"🔧 {analyst_type} tools: Last message type: {type(last_msg).__name__}")

            tool_calls = getattr(last_msg, "tool_calls", None)
            if not tool_calls:
                logger.error(f"❌ {analyst_type} tools: No tool calls found")
                return {message_key: messages}

            logger.info(f"🔧 {analyst_type} tools: Found {len(tool_calls)} tool calls")

            updated_messages = list(messages)
            tools_executed = 0
            for position, tool_call in enumerate(tool_calls, start=1):
                parsed = self._parse_tool_call(tool_call)
                if parsed is None:
                    # No id to answer; still charge the attempt so routing terminates.
                    logger.error(f"❌ {analyst_type} tools: Unknown tool call format")
                    self.tool_tracker.record_rejected_call(analyst_type)
                    continue
                tool_name, tool_args, tool_call_id = parsed

                try:
                    content, executed = self._execute_tool_call(
                        tool_node, analyst_type, tool_name, tool_args, position, len(tool_calls)
                    )
                except Exception as exc:
                    # Report the failure back to the LLM rather than dropping the call.
                    logger.exception(f"❌ {analyst_type} tools: Error executing {tool_name}: {exc}")
                    self.tool_tracker.record_rejected_call(analyst_type)
                    content, executed = f"Error executing tool {tool_name}: {exc}", False

                updated_messages.append(ToolMessage(content=content, tool_call_id=tool_call_id))
                if executed:
                    tools_executed += 1

            logger.info(f"🔧 {analyst_type} tools: Executed {tools_executed} tools")
            logger.info(f"🔧 {analyst_type} tools: Total calls for {analyst_type}: {self.tool_tracker.total_calls.get(analyst_type, 0)}")

            logger.info(f"✅ {analyst_type.upper()} TOOLS COMPLETE")
            return {message_key: updated_messages}

        return wrapped_tool_node

    # ------------------------------------------------------------- aggregator

    def _create_aggregator(self):
        """Create aggregator node that logs which analyst reports are present."""

        def aggregate(state: AgentState) -> dict:
            logger.info("=" * 80)
            logger.info("📊 NODE EXECUTING: AGGREGATOR")
            logger.info("=" * 80)

            reports = {
                report_key: state.get(report_key, "") or ""
                for _, report_key in ANALYST_CHANNELS.values()
            }

            logger.info("📊 Aggregator: Checking report status:")
            for report_name, report_content in reports.items():
                status = "✅ PRESENT" if report_content.strip() else "❌ MISSING"
                logger.info(f"  - {report_name}: {status} ({len(report_content)} chars)")

            completed_reports = [name for name, report in reports.items() if report.strip()]
            missing_reports = [name for name, report in reports.items() if not report.strip()]

            logger.info(f"📊 Aggregator: ✅ Completed reports: {completed_reports}")
            if missing_reports:
                logger.warning(f"📊 Aggregator: ❌ Missing reports: {missing_reports}")

            logger.info("📊 Aggregator: Marking analysis phase as complete")
            logger.info("✅ AGGREGATOR COMPLETE")

            # Don't initialize debate states here - let the Bull Researcher do it
            # This prevents concurrent update errors
            return {"analysis_complete": True}

        return aggregate

    # ------------------------------------------------------------------- risk

    def _wrap_risk_analyst_for_channel(self, risk_analyst_node, analyst_type: str):
        """Wrap a risk analyst node to work with risk debate state.

        NOTE(review): the three risk analysts run in the same superstep and each
        returns the whole ``risk_debate_state``. Unless AgentState declares a
        reducer for that key that merges their updates, LangGraph rejects the
        concurrent writes (or one analyst's response overwrites the others).
        Confirm against AgentState.
        """

        def wrapped_risk_analyst(state: AgentState) -> dict:
            logger.info("-" * 60)
            logger.info(f"⚡ NODE EXECUTING: {analyst_type.upper()} RISK ANALYST")
            logger.info("-" * 60)

            try:
                logger.info(f"⚡ {analyst_type} risk analyst: Invoking LLM...")
                result = risk_analyst_node(state)
            except Exception as exc:
                logger.exception(f"❌ {analyst_type} risk analyst error: {exc}")
                raise
            logger.info(f"⚡ {analyst_type} risk analyst: LLM response received")

            risk_debate_state = result.get("risk_debate_state", state.get("risk_debate_state", {}))

            response_key = f"current_{analyst_type}_response"
            if response_key in risk_debate_state:
                logger.info(f"⚡ {analyst_type} risk analyst: ✅ Analysis complete")
                logger.info(f"⚡ {analyst_type} risk analyst: Response length: {len(risk_debate_state[response_key] or '')} chars")

            logger.info(f"✅ {analyst_type.upper()} RISK ANALYST COMPLETE")
            return {"risk_debate_state": risk_debate_state}

        return wrapped_risk_analyst

    def _create_risk_dispatcher(self):
        """Create risk dispatcher node that initializes risk analysis phase."""

        def dispatch_risk(state: AgentState) -> dict:
            logger.info("=" * 80)
            logger.info("⚡ NODE EXECUTING: RISK DISPATCHER")
            logger.info("=" * 80)

            risk_debate_state = state.get("risk_debate_state", {}) or {}

            # Keep accumulated histories; clear per-round responses and verdict.
            initial_risk_debate = {
                "risky_history": risk_debate_state.get("risky_history", ""),
                "safe_history": risk_debate_state.get("safe_history", ""),
                "neutral_history": risk_debate_state.get("neutral_history", ""),
                "history": risk_debate_state.get("history", ""),
                "latest_speaker": "",
                "current_risky_response": "",
                "current_safe_response": "",
                "current_neutral_response": "",
                "judge_decision": "",
                "count": 0,
            }

            logger.info("⚡ Risk Dispatcher: Initializing parallel risk analysis")
            logger.info("⚡ Risk Dispatcher: Starting Risky, Safe, and Neutral analysts in parallel")
            logger.info("✅ RISK DISPATCHER COMPLETE")
            return {"risk_debate_state": initial_risk_debate}

        return dispatch_risk

    def _create_risk_aggregator(self):
        """Create risk aggregator node that combines the risk analyses for the Risk Judge."""

        def aggregate_risk(state: AgentState) -> dict:
            logger.info("=" * 80)
            logger.info("⚡ NODE EXECUTING: RISK AGGREGATOR")
            logger.info("=" * 80)

            risk_debate_state = state.get("risk_debate_state", {}) or {}

            responses = [
                ("Risky", risk_debate_state.get("current_risky_response", "") or ""),
                ("Safe", risk_debate_state.get("current_safe_response", "") or ""),
                ("Neutral", risk_debate_state.get("current_neutral_response", "") or ""),
            ]

            logger.info("⚡ Risk Aggregator: Checking risk analysis status:")
            for name, text in responses:
                logger.info(f"  - {name} analysis: {'✅ COMPLETE' if text else '❌ MISSING'} ({len(text)} chars)")

            present = [(name, text) for name, text in responses if text]
            missing_analysts = [name for name, text in responses if not text]

            if not present:
                logger.error("⚡ Risk Aggregator: ❌ NO RISK ANALYSES AVAILABLE")
                combined_history = "No risk analysis available from any analyst. Unable to provide risk assessment."
            else:
                combined_history = "".join(f"**{name} Analyst**: {text}\n\n" for name, text in present)
                if missing_analysts:
                    logger.warning(f"⚡ Risk Aggregator: ⚠️ Only {len(present)}/3 risk analyses available")
                    combined_history += f"**Note**: Missing analysis from {', '.join(missing_analysts)} analyst(s). Decision based on available data only."
                else:
                    logger.info("⚡ Risk Aggregator: ✅ All risk analyses complete")

            updated_risk_state = dict(risk_debate_state)
            updated_risk_state["history"] = combined_history
            updated_risk_state["count"] = 1  # Mark as ready for judgment

            logger.info(f"⚡ Risk Aggregator: Combined history length: {len(combined_history)} chars")
            logger.info("⚡ Risk Aggregator: Risk analyses aggregated for final judgment")
            logger.info("✅ RISK AGGREGATOR COMPLETE")
            return {"risk_debate_state": updated_risk_state}

        return aggregate_risk