# TradingAgents/graph/trading_graph.py

import os
from pathlib import Path
import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
import logging

from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI

from langgraph.prebuilt import ToolNode

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.utils.agent_states import (
    AgentState,
    InvestDebateState,
    RiskDebateState,
)
from tradingagents.dataflows.interface import set_config

from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .signal_processing import SignalProcessor


class TradingAgentsGraph:
    """Main class that orchestrates the trading agents framework."""

    # Analyst types used when the caller does not pass ``selected_analysts``.
    _DEFAULT_ANALYSTS = ("market", "social", "news", "fundamentals")

    # Per-analyst message channels (the ``messages_key`` of each tool node);
    # inspected when logging streamed chunks in debug mode.
    _MESSAGE_CHANNELS = (
        "market_messages",
        "social_messages",
        "news_messages",
        "fundamentals_messages",
    )

    # Analyst report fields logged in debug mode once they become non-empty.
    _REPORT_KEYS = (
        "market_report",
        "sentiment_report",
        "news_report",
        "fundamentals_report",
    )

    def __init__(
        self,
        selected_analysts: Optional[List[str]] = None,
        debug: bool = False,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Initialize the trading agents graph and components.

        Args:
            selected_analysts: List of analyst types to include. If None, all of
                "market", "social", "news" and "fundamentals" are used.
            debug: Whether to run in debug mode (stream the graph and log
                every chunk instead of a single ``invoke``).
            config: Configuration dictionary. If None (or empty), uses
                DEFAULT_CONFIG.

        Raises:
            ValueError: If ``config["llm_provider"]`` is not supported.
        """
        if selected_analysts is None:
            # Build a fresh list per instance: a list literal as the default
            # value would be one object shared by every call.
            selected_analysts = list(self._DEFAULT_ANALYSTS)

        self.debug = debug
        self.config = config or DEFAULT_CONFIG

        # Update the interface's config
        set_config(self.config)

        # Create necessary directories
        os.makedirs(
            os.path.join(self.config["project_dir"], "dataflows/data_cache"),
            exist_ok=True,
        )

        # Initialize LLMs
        self.deep_thinking_llm, self.quick_thinking_llm = self._create_llms()

        self.toolkit = Toolkit(config=self.config)

        # Initialize memories
        self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
        self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
        self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
        self.invest_judge_memory = FinancialSituationMemory(
            "invest_judge_memory", self.config
        )
        self.risk_manager_memory = FinancialSituationMemory(
            "risk_manager_memory", self.config
        )

        # Create tool nodes
        self.tool_nodes = self._create_tool_nodes()

        # Initialize components
        self.conditional_logic = ConditionalLogic()
        self.graph_setup = GraphSetup(
            self.quick_thinking_llm,
            self.deep_thinking_llm,
            self.toolkit,
            self.tool_nodes,
            self.bull_memory,
            self.bear_memory,
            self.trader_memory,
            self.invest_judge_memory,
            self.risk_manager_memory,
            self.conditional_logic,
        )

        self.propagator = Propagator()
        self.reflector = Reflector(self.quick_thinking_llm)
        self.signal_processor = SignalProcessor(self.quick_thinking_llm)

        # State tracking
        self.curr_state = None
        self.ticker = None
        self.log_states_dict = {}  # date to full state dict

        # Set up the graph
        self.graph = self.graph_setup.setup_graph(selected_analysts)

    def _create_llms(self) -> Tuple[Any, Any]:
        """Create the (deep-thinking, quick-thinking) chat models.

        The provider name from ``config["llm_provider"]`` is compared
        case-insensitively. "openai", "ollama" and "openrouter" all use
        ``ChatOpenAI`` with the configured ``backend_url``.

        Returns:
            A ``(deep_thinking_llm, quick_thinking_llm)`` pair.

        Raises:
            ValueError: If the provider is not supported.
        """
        provider = self.config["llm_provider"].lower()
        if provider in ("openai", "ollama", "openrouter"):
            deep = ChatOpenAI(
                model=self.config["deep_think_llm"],
                base_url=self.config["backend_url"],
            )
            quick = ChatOpenAI(
                model=self.config["quick_think_llm"],
                base_url=self.config["backend_url"],
            )
        elif provider == "anthropic":
            deep = ChatAnthropic(
                model=self.config["deep_think_llm"],
                base_url=self.config["backend_url"],
            )
            quick = ChatAnthropic(
                model=self.config["quick_think_llm"],
                base_url=self.config["backend_url"],
            )
        elif provider == "google":
            # No backend_url is passed for the Google client.
            deep = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
            quick = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
        else:
            raise ValueError(
                f"Unsupported LLM provider: {self.config['llm_provider']}"
            )
        return deep, quick

    def _create_tool_nodes(self) -> Dict[str, ToolNode]:
        """Create tool nodes for different data sources with specific message channels.

        Each analyst gets its own ``ToolNode`` whose ``messages_key`` names a
        dedicated state channel, so tool traffic for different analysts is kept
        in separate message lists.

        Returns:
            Mapping of analyst type ("market", "social", "news",
            "fundamentals") to its ToolNode.
        """
        logger.info("🔧 Creating tool nodes with message channels")

        tool_nodes = {
            "market": ToolNode(
                [
                    # online tools
                    self.toolkit.get_YFin_data_online,
                    self.toolkit.get_stockstats_indicators_report_online,
                    # offline tools
                    self.toolkit.get_YFin_data,
                    self.toolkit.get_stockstats_indicators_report,
                ],
                messages_key="market_messages",
            ),
            "social": ToolNode(
                [
                    # online tools
                    self.toolkit.get_stock_news_openai,
                    # offline tools
                    self.toolkit.get_reddit_stock_info,
                ],
                messages_key="social_messages",
            ),
            "news": ToolNode(
                [
                    # online tools
                    self.toolkit.get_global_news_openai,
                    self.toolkit.get_google_news,
                    # offline tools
                    self.toolkit.get_finnhub_news,
                    self.toolkit.get_reddit_news,
                ],
                messages_key="news_messages",
            ),
            "fundamentals": ToolNode(
                [
                    # online tools
                    self.toolkit.get_fundamentals_openai,
                    # offline tools
                    self.toolkit.get_finnhub_company_insider_sentiment,
                    self.toolkit.get_finnhub_company_insider_transactions,
                    self.toolkit.get_simfin_balance_sheet,
                    self.toolkit.get_simfin_cashflow,
                    self.toolkit.get_simfin_income_stmt,
                ],
                messages_key="fundamentals_messages",
            ),
        }

        for tool_type, node in tool_nodes.items():
            logger.info(f"  ✅ {tool_type}: {len(node.tools_by_name)} tools")

        return tool_nodes

    def propagate(self, company_name, trade_date) -> Tuple[Dict[str, Any], Any]:
        """Run the trading agents graph for a company on a specific date.

        Side effects: sets ``self.ticker`` and ``self.curr_state`` (used later by
        ``reflect_and_remember``) and appends this run to the JSON state log.

        Args:
            company_name: Ticker / company identifier passed to the propagator.
            trade_date: Trade date; its ``str()`` is the key in the state log.

        Returns:
            ``(final_state, processed_signal)``, where ``processed_signal`` is
            the output of ``process_signal`` on the final trade decision.
        """
        self.ticker = company_name

        # Initialize state
        init_agent_state = self.propagator.create_initial_state(
            company_name, trade_date
        )
        args = self.propagator.get_graph_args()

        if self.debug:
            final_state = self._stream_with_tracing(init_agent_state, args)
        else:
            final_state = self._invoke_graph(init_agent_state, args)

        # Store current state for reflection
        self.curr_state = final_state

        # Log state
        logger.info("💾 Logging final state")
        self._log_state(trade_date, final_state)

        # Process final decision
        final_decision = final_state.get("final_trade_decision", "No decision made")
        processed_signal = self.process_signal(final_decision)

        logger.info(f"🎯 Analysis complete for {company_name}")
        logger.info(f"📊 Final decision: {final_decision[:100]}...")
        logger.info(f"🔄 Processed signal: {processed_signal}")

        # Return decision and processed signal
        return final_state, processed_signal

    def _stream_with_tracing(self, init_agent_state, args):
        """Stream the graph, logging every chunk, and return the last chunk.

        The last streamed chunk is treated as the final state, which assumes the
        graph args make ``stream`` yield full-state snapshots (e.g.
        ``stream_mode="values"``) — TODO confirm against
        ``Propagator.get_graph_args``. If nothing is streamed, the initial
        state is returned unchanged.
        """
        logger.info("🐛 Running in debug mode with full tracing")
        final_state = init_agent_state
        chunk_count = 0
        for chunk_count, chunk in enumerate(
            self.graph.stream(init_agent_state, **args), start=1
        ):
            self._log_debug_chunk(chunk_count, chunk)
            # Only the latest chunk is needed; don't retain the whole trace.
            final_state = chunk
        logger.info(f"✅ Debug execution complete. Processed {chunk_count} chunks")
        return final_state

    def _invoke_graph(self, init_agent_state, args):
        """Run the graph to completion in one call; log and re-raise on failure."""
        logger.info("🏃 Running in standard mode")
        try:
            final_state = self.graph.invoke(init_agent_state, **args)
        except Exception as e:
            logger.error(f"❌ Error during graph execution: {str(e)}")
            logger.error(f"❌ Error type: {type(e).__name__}")
            raise
        logger.info("✅ Standard execution complete")
        return final_state

    def _log_debug_chunk(self, chunk_count: int, chunk: Dict[str, Any]) -> None:
        """Log a summary of one streamed chunk: keys, channel messages, reports."""
        logger.info(f"🔄 Processing chunk {chunk_count}")
        logger.info(f"📋 Chunk keys: {list(chunk.keys())}")

        # Check for any message updates in analyst channels
        for channel in self._MESSAGE_CHANNELS:
            messages = chunk.get(channel)
            if not messages:
                continue
            logger.info(f"💬 Updated {channel}: {len(messages)} messages")
            last_msg = messages[-1]
            logger.info(f"📝 Last {channel} message type: {type(last_msg).__name__}")
            if hasattr(last_msg, "content"):
                logger.info(f"📝 Content preview: {str(last_msg.content)[:200]}...")
            tool_calls = getattr(last_msg, "tool_calls", None)
            if tool_calls:
                names = [self._tool_call_name(tc) for tc in tool_calls]
                logger.info(f"🔧 Tool calls: {names}")

        # Check for report updates
        for report_key in self._REPORT_KEYS:
            report = chunk.get(report_key)
            if report:
                logger.info(f"📊 Report generated: {report_key} ({len(report)} chars)")

    @staticmethod
    def _tool_call_name(tool_call: Any) -> str:
        """Return a readable name for a tool call.

        LangChain's ``AIMessage.tool_calls`` entries are dicts with a ``"name"``
        key, so a plain ``hasattr(tc, "name")`` check never matches them; fall
        back to attribute access, then ``str()``, for other shapes.
        """
        if isinstance(tool_call, dict) and "name" in tool_call:
            return str(tool_call["name"])
        return getattr(tool_call, "name", str(tool_call))

    def _log_state(self, trade_date, final_state):
        """Record the final state under ``str(trade_date)`` and rewrite the JSON log.

        The whole ``log_states_dict`` (every date logged by this instance) is
        written to ``eval_results/<ticker>/TradingAgentsStrategy_logs/full_states_log.json``,
        overwriting any previous file.
        """
        invest_debate = final_state["investment_debate_state"]
        risk_debate = final_state["risk_debate_state"]
        self.log_states_dict[str(trade_date)] = {
            "company_of_interest": final_state["company_of_interest"],
            "trade_date": final_state["trade_date"],
            "market_report": final_state["market_report"],
            "sentiment_report": final_state["sentiment_report"],
            "news_report": final_state["news_report"],
            "fundamentals_report": final_state["fundamentals_report"],
            "investment_debate_state": {
                key: invest_debate[key]
                for key in (
                    "bull_history",
                    "bear_history",
                    "history",
                    "current_response",
                    "judge_decision",
                )
            },
            "trader_investment_decision": final_state["trader_investment_plan"],
            "risk_debate_state": {
                key: risk_debate[key]
                for key in (
                    "risky_history",
                    "safe_history",
                    "neutral_history",
                    "history",
                    "judge_decision",
                )
            },
            "investment_plan": final_state["investment_plan"],
            "final_trade_decision": final_state["final_trade_decision"],
        }

        # Save to file
        directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs")
        directory.mkdir(parents=True, exist_ok=True)
        with open(directory / "full_states_log.json", "w", encoding="utf-8") as f:
            json.dump(self.log_states_dict, f, indent=4)

    def reflect_and_remember(self, returns_losses):
        """Reflect on the last propagated decision and update each agent's memory.

        Uses ``self.curr_state`` as set by the most recent ``propagate`` call.

        Args:
            returns_losses: Realized returns/losses forwarded to every reflector.
        """
        reflections = (
            (self.reflector.reflect_bull_researcher, self.bull_memory),
            (self.reflector.reflect_bear_researcher, self.bear_memory),
            (self.reflector.reflect_trader, self.trader_memory),
            (self.reflector.reflect_invest_judge, self.invest_judge_memory),
            (self.reflector.reflect_risk_manager, self.risk_manager_memory),
        )
        for reflect, memory in reflections:
            reflect(self.curr_state, returns_losses, memory)

    def process_signal(self, full_signal):
        """Process a signal to extract the core decision."""
        return self.signal_processor.process_signal(full_signal)