# TradingAgents/graph/trading_graph.py import logging import os from pathlib import Path import json from datetime import datetime, timedelta from typing import Dict, Any, Tuple, List, Optional import yfinance as yf logger = logging.getLogger(__name__) from langgraph.prebuilt import ToolNode from tradingagents.llm_clients import create_llm_client from tradingagents.agents import * from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.agents.utils.memory import TradingMemoryLog from tradingagents.agents.utils.agent_states import ( AgentState, InvestDebateState, RiskDebateState, ) from tradingagents.dataflows.config import set_config # Import the new abstract tool methods from agent_utils from tradingagents.agents.utils.agent_utils import ( get_stock_data, get_indicators, get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement, get_news, get_insider_transactions, get_global_news ) from .conditional_logic import ConditionalLogic from .setup import GraphSetup from .propagation import Propagator from .reflection import Reflector from .signal_processing import SignalProcessor class TradingAgentsGraph: """Main class that orchestrates the trading agents framework.""" def __init__( self, selected_analysts=["market", "social", "news", "fundamentals"], debug=False, config: Dict[str, Any] = None, callbacks: Optional[List] = None, ): """Initialize the trading agents graph and components. Args: selected_analysts: List of analyst types to include debug: Whether to run in debug mode config: Configuration dictionary. If None, uses default config callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats) """ self.debug = debug self.config = config or DEFAULT_CONFIG self.callbacks = callbacks or [] # Update the interface's config set_config(self.config) # Create necessary directories os.makedirs(self.config["data_cache_dir"], exist_ok=True) os.makedirs(self.config["results_dir"], exist_ok=True) # Initialize LLMs with provider-specific thinking configuration llm_kwargs = self._get_provider_kwargs() # Add callbacks to kwargs if provided (passed to LLM constructor) if self.callbacks: llm_kwargs["callbacks"] = self.callbacks deep_client = create_llm_client( provider=self.config["llm_provider"], model=self.config["deep_think_llm"], base_url=self.config.get("backend_url"), **llm_kwargs, ) quick_client = create_llm_client( provider=self.config["llm_provider"], model=self.config["quick_think_llm"], base_url=self.config.get("backend_url"), **llm_kwargs, ) self.deep_thinking_llm = deep_client.get_llm() self.quick_thinking_llm = quick_client.get_llm() self.memory_log = TradingMemoryLog(self.config) # Create tool nodes self.tool_nodes = self._create_tool_nodes() # Initialize components self.conditional_logic = ConditionalLogic( max_debate_rounds=self.config["max_debate_rounds"], max_risk_discuss_rounds=self.config["max_risk_discuss_rounds"], ) self.graph_setup = GraphSetup( self.quick_thinking_llm, self.deep_thinking_llm, self.tool_nodes, self.conditional_logic, ) self.propagator = Propagator() self.reflector = Reflector(self.quick_thinking_llm) self.signal_processor = SignalProcessor(self.quick_thinking_llm) # State tracking self.curr_state = None self.ticker = None self.log_states_dict = {} # date to full state dict # Set up the graph self.graph = self.graph_setup.setup_graph(selected_analysts) def _get_provider_kwargs(self) -> Dict[str, Any]: """Get provider-specific kwargs for LLM client creation.""" kwargs = {} provider = self.config.get("llm_provider", "").lower() if provider == "google": thinking_level = self.config.get("google_thinking_level") if thinking_level: kwargs["thinking_level"] = thinking_level elif provider == "openai": reasoning_effort = self.config.get("openai_reasoning_effort") if reasoning_effort: kwargs["reasoning_effort"] = reasoning_effort elif provider == "anthropic": effort = self.config.get("anthropic_effort") if effort: kwargs["effort"] = effort return kwargs def _create_tool_nodes(self) -> Dict[str, ToolNode]: """Create tool nodes for different data sources using abstract methods.""" return { "market": ToolNode( [ # Core stock data tools get_stock_data, # Technical indicators get_indicators, ] ), "social": ToolNode( [ # News tools for social media analysis get_news, ] ), "news": ToolNode( [ # News and insider information get_news, get_global_news, get_insider_transactions, ] ), "fundamentals": ToolNode( [ # Fundamental analysis tools get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement, ] ), } def _fetch_returns( self, ticker: str, trade_date: str, holding_days: int = 5 ) -> Tuple[Optional[float], Optional[float], Optional[int]]: """Fetch raw and alpha return for ticker over holding_days from trade_date. Returns (raw_return, alpha_return, actual_holding_days) or (None, None, None) if price data is unavailable (too recent, delisted, or network error). """ try: start = datetime.strptime(trade_date, "%Y-%m-%d") end = start + timedelta(days=holding_days + 7) # buffer for weekends/holidays end_str = end.strftime("%Y-%m-%d") stock = yf.Ticker(ticker).history(start=trade_date, end=end_str) spy = yf.Ticker("SPY").history(start=trade_date, end=end_str) if len(stock) < 2 or len(spy) < 2: return None, None, None actual_days = min(holding_days, len(stock) - 1, len(spy) - 1) raw = float( (stock["Close"].iloc[actual_days] - stock["Close"].iloc[0]) / stock["Close"].iloc[0] ) spy_ret = float( (spy["Close"].iloc[actual_days] - spy["Close"].iloc[0]) / spy["Close"].iloc[0] ) alpha = raw - spy_ret return raw, alpha, actual_days except Exception as e: logger.debug("_fetch_returns failed for %s@%s: %s", ticker, trade_date, e) return None, None, None def _resolve_pending_entries(self, ticker: str) -> None: """Resolve pending log entries for ticker at the start of a new run. Fetches returns for each same-ticker pending entry, generates reflections, then writes all updates in a single atomic batch write to avoid redundant I/O. Skips entries whose price data is not yet available (too recent or delisted). Trade-off: only same-ticker entries are resolved per run. Entries for other tickers accumulate until that ticker is run again. """ pending = [e for e in self.memory_log.get_pending_entries() if e["ticker"] == ticker] if not pending: return updates = [] for entry in pending: raw, alpha, days = self._fetch_returns(ticker, entry["date"]) if raw is None: continue # price not available yet — try again next run reflection = self.reflector.reflect_on_final_decision( final_decision=entry.get("decision", ""), raw_return=raw, alpha_return=alpha, ) updates.append({ "ticker": ticker, "trade_date": entry["date"], "raw_return": raw, "alpha_return": alpha, "holding_days": days, "reflection": reflection, }) if updates: self.memory_log.batch_update_with_outcomes(updates) def propagate(self, company_name, trade_date): """Run the trading agents graph for a company on a specific date.""" self.ticker = company_name # Resolve any pending log entries for this ticker before the pipeline runs. # This adds the outcome + reflection from the previous run at zero latency cost. self._resolve_pending_entries(company_name) # Initialize state — inject memory log context for PM past_context = self.memory_log.get_past_context(company_name) init_agent_state = self.propagator.create_initial_state( company_name, trade_date, past_context=past_context ) args = self.propagator.get_graph_args() if self.debug: # Debug mode with tracing trace = [] for chunk in self.graph.stream(init_agent_state, **args): if len(chunk["messages"]) == 0: pass else: chunk["messages"][-1].pretty_print() trace.append(chunk) final_state = trace[-1] else: # Standard mode without tracing final_state = self.graph.invoke(init_agent_state, **args) # Store current state for reflection self.curr_state = final_state # Log state self._log_state(trade_date, final_state) # Store decision for deferred reflection. self.memory_log.store_decision( ticker=company_name, trade_date=trade_date, final_trade_decision=final_state["final_trade_decision"], ) # Return decision and processed signal return final_state, self.process_signal(final_state["final_trade_decision"]) def _log_state(self, trade_date, final_state): """Log the final state to a JSON file.""" self.log_states_dict[str(trade_date)] = { "company_of_interest": final_state["company_of_interest"], "trade_date": final_state["trade_date"], "market_report": final_state["market_report"], "sentiment_report": final_state["sentiment_report"], "news_report": final_state["news_report"], "fundamentals_report": final_state["fundamentals_report"], "investment_debate_state": { "bull_history": final_state["investment_debate_state"]["bull_history"], "bear_history": final_state["investment_debate_state"]["bear_history"], "history": final_state["investment_debate_state"]["history"], "current_response": final_state["investment_debate_state"][ "current_response" ], "judge_decision": final_state["investment_debate_state"][ "judge_decision" ], }, "trader_investment_decision": final_state["trader_investment_plan"], "risk_debate_state": { "aggressive_history": final_state["risk_debate_state"]["aggressive_history"], "conservative_history": final_state["risk_debate_state"]["conservative_history"], "neutral_history": final_state["risk_debate_state"]["neutral_history"], "history": final_state["risk_debate_state"]["history"], "judge_decision": final_state["risk_debate_state"]["judge_decision"], }, "investment_plan": final_state["investment_plan"], "final_trade_decision": final_state["final_trade_decision"], } # Save to file directory = Path(self.config["results_dir"]) / self.ticker / "TradingAgentsStrategy_logs" directory.mkdir(parents=True, exist_ok=True) log_path = directory / f"full_states_log_{trade_date}.json" with open(log_path, "w", encoding="utf-8") as f: json.dump(self.log_states_dict[str(trade_date)], f, indent=4) def process_signal(self, full_signal): """Process a signal to extract the core decision.""" return self.signal_processor.process_signal(full_signal)