# Litadel/graph/trading_graph.py
# Copyright Notice: Litadel is a successor of TradingAgents by TaurusResearch.
# This project builds upon and extends the original TradingAgents framework.

import os
from pathlib import Path
import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional

from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.prebuilt import ToolNode

from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.utils.agent_states import (
    AgentState,
    InvestDebateState,
    RiskDebateState,
)
from tradingagents.dataflows.config import set_config

# Import unified tools from agent_utils
from tradingagents.agents.utils.agent_utils import (
    get_market_data,
    get_indicators,
    get_asset_news,
    get_global_news_unified as get_global_news,
    get_fundamentals,
    get_balance_sheet,
    get_cashflow,
    get_income_statement,
    get_insider_sentiment,
    get_insider_transactions,
)

from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .signal_processing import SignalProcessor


class TradingAgentsGraph:
    """Main class that orchestrates the trading agents framework."""

    def __init__(
        self,
        selected_analysts=["market", "social", "news", "fundamentals"],
        debug=False,
        config: Optional[Dict[str, Any]] = None,
        analysis_id: Optional[str] = None,
    ):
        """Initialize the trading agents graph and components.

        Args:
            selected_analysts: List of analyst types to include
            debug: Whether to run in debug mode
            config: Configuration dictionary. If None, uses default config
            analysis_id: Optional unique identifier for this analysis (makes memory collections unique)
        """
        self.debug = debug
        self.config = config or DEFAULT_CONFIG
        self.analysis_id = analysis_id

        # Update the interface's config
        set_config(self.config)

        # Create necessary directories
        os.makedirs(
            os.path.join(self.config["project_dir"], "dataflows/data_cache"),
            exist_ok=True,
        )

        # Initialize LLMs (provider matching is case-insensitive)
        provider = self.config["llm_provider"].lower()
        if provider in ("openai", "ollama", "openrouter"):
            self.deep_thinking_llm = ChatOpenAI(
                model=self.config["deep_think_llm"], base_url=self.config["backend_url"]
            )
            self.quick_thinking_llm = ChatOpenAI(
                model=self.config["quick_think_llm"], base_url=self.config["backend_url"]
            )
        elif provider == "anthropic":
            self.deep_thinking_llm = ChatAnthropic(
                model=self.config["deep_think_llm"], base_url=self.config["backend_url"]
            )
            self.quick_thinking_llm = ChatAnthropic(
                model=self.config["quick_think_llm"], base_url=self.config["backend_url"]
            )
        elif provider == "google":
            self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
            self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
        else:
            raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")

        # Initialize memories with unique names per analysis
        # This prevents "Collection already exists" errors when running multiple analyses
        memory_suffix = f"_{analysis_id}" if analysis_id else ""
        self.bull_memory = FinancialSituationMemory(f"bull_memory{memory_suffix}", self.config)
        self.bear_memory = FinancialSituationMemory(f"bear_memory{memory_suffix}", self.config)
        self.trader_memory = FinancialSituationMemory(f"trader_memory{memory_suffix}", self.config)
        self.invest_judge_memory = FinancialSituationMemory(f"invest_judge_memory{memory_suffix}", self.config)
        self.risk_manager_memory = FinancialSituationMemory(f"risk_manager_memory{memory_suffix}", self.config)
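        # For example (hypothetical id), analysis_id="run_42" yields ChromaDB
        # collections named "bull_memory_run_42", "trader_memory_run_42", etc.,
        # while analysis_id=None falls back to the shared, unsuffixed names.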
FinancialSituationMemory(f"trader_memory{memory_suffix}", self.config) self.invest_judge_memory = FinancialSituationMemory(f"invest_judge_memory{memory_suffix}", self.config) self.risk_manager_memory = FinancialSituationMemory(f"risk_manager_memory{memory_suffix}", self.config) # Create tool nodes self.tool_nodes = self._create_tool_nodes() # Initialize components self.conditional_logic = ConditionalLogic() self.graph_setup = GraphSetup( self.quick_thinking_llm, self.deep_thinking_llm, self.tool_nodes, self.bull_memory, self.bear_memory, self.trader_memory, self.invest_judge_memory, self.risk_manager_memory, self.conditional_logic, ) self.propagator = Propagator() self.reflector = Reflector(self.quick_thinking_llm) self.signal_processor = SignalProcessor(self.quick_thinking_llm) # State tracking self.curr_state = None self.ticker = None self.log_states_dict = {} # date to full state dict # Set up the graph self.graph = self.graph_setup.setup_graph(selected_analysts) def _create_tool_nodes(self) -> Dict[str, ToolNode]: """Create tool nodes using unified tools that work across all asset classes.""" # Unified tools automatically route based on asset_class context # No more conditional logic needed here! return { "market": ToolNode([get_market_data, get_indicators]), "social": ToolNode([get_asset_news, get_global_news]), "news": ToolNode([get_asset_news, get_global_news]), "fundamentals": ToolNode([ get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement, get_insider_sentiment, get_insider_transactions, ]), } def propagate(self, company_name, trade_date): """Run the trading agents graph for a company on a specific date.""" self.ticker = company_name # Initialize state init_agent_state = self.propagator.create_initial_state( company_name, trade_date ) # Pass asset class into state for downstream branching init_agent_state["asset_class"] = self.config.get("asset_class", "equity") args = self.propagator.get_graph_args() if self.debug: # Debug mode with tracing trace = [] for chunk in self.graph.stream(init_agent_state, **args): if len(chunk["messages"]) == 0: pass else: chunk["messages"][-1].pretty_print() trace.append(chunk) final_state = trace[-1] else: # Standard mode without tracing final_state = self.graph.invoke(init_agent_state, **args) # Store current state for reflection self.curr_state = final_state # Log state self._log_state(trade_date, final_state) # Return decision and processed signal return final_state, self.process_signal(final_state["final_trade_decision"]) def _log_state(self, trade_date, final_state): """Log the final state to a JSON file.""" self.log_states_dict[str(trade_date)] = { "company_of_interest": final_state["company_of_interest"], "trade_date": final_state["trade_date"], "market_report": final_state["market_report"], "sentiment_report": final_state["sentiment_report"], "news_report": final_state["news_report"], "fundamentals_report": final_state["fundamentals_report"], "investment_debate_state": { "bull_history": final_state["investment_debate_state"]["bull_history"], "bear_history": final_state["investment_debate_state"]["bear_history"], "history": final_state["investment_debate_state"]["history"], "current_response": final_state["investment_debate_state"][ "current_response" ], "judge_decision": final_state["investment_debate_state"][ "judge_decision" ], }, "trader_investment_decision": final_state["trader_investment_plan"], "risk_debate_state": { "risky_history": final_state["risk_debate_state"]["risky_history"], "safe_history": 
final_state["risk_debate_state"]["safe_history"], "neutral_history": final_state["risk_debate_state"]["neutral_history"], "history": final_state["risk_debate_state"]["history"], "judge_decision": final_state["risk_debate_state"]["judge_decision"], }, "investment_plan": final_state["investment_plan"], "final_trade_decision": final_state["final_trade_decision"], } # Save to file directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/") directory.mkdir(parents=True, exist_ok=True) with open( f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json", "w", ) as f: json.dump(self.log_states_dict, f, indent=4) def reflect_and_remember(self, returns_losses): """Reflect on decisions and update memory based on returns.""" self.reflector.reflect_bull_researcher( self.curr_state, returns_losses, self.bull_memory ) self.reflector.reflect_bear_researcher( self.curr_state, returns_losses, self.bear_memory ) self.reflector.reflect_trader( self.curr_state, returns_losses, self.trader_memory ) self.reflector.reflect_invest_judge( self.curr_state, returns_losses, self.invest_judge_memory ) self.reflector.reflect_risk_manager( self.curr_state, returns_losses, self.risk_manager_memory ) def process_signal(self, full_signal): """Process a signal to extract the core decision.""" return self.signal_processor.process_signal(full_signal) def cleanup_memories(self): """Clean up ChromaDB collections for this analysis to prevent memory leaks.""" if not self.analysis_id: return # Only cleanup if we have a specific analysis_id try: memory_suffix = f"_{self.analysis_id}" collections_to_delete = [ f"bull_memory{memory_suffix}", f"bear_memory{memory_suffix}", f"trader_memory{memory_suffix}", f"invest_judge_memory{memory_suffix}", f"risk_manager_memory{memory_suffix}", ] # Get the chroma client from one of the memories chroma_client = self.bull_memory.chroma_client for collection_name in collections_to_delete: try: chroma_client.delete_collection(name=collection_name) except Exception as e: # Ignore errors if collection doesn't exist pass except Exception as e: # Don't fail the analysis if cleanup fails print(f"Warning: Failed to cleanup memory collections: {e}")