"""Backtest engine that sources trading decisions from a TradingAgentsGraph."""

import logging
import math
import re
from collections.abc import Mapping
from datetime import date, datetime
from decimal import Decimal
from typing import Any, Dict, Optional

from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.models.backtest import BacktestConfig, BacktestResult
from tradingagents.models.decisions import (
    AnalystReport,
    AnalystType,
    SignalType,
    TradingDecision,
)

from .engine import BacktestEngine

logger = logging.getLogger(__name__)

# Confidence used when the agent output carries no usable confidence value.
_DEFAULT_CONFIDENCE = 0.5
# Confidence strictly above this upgrades BUY/SELL to STRONG_BUY/STRONG_SELL.
_STRONG_SIGNAL_THRESHOLD = 0.8
# Truncation lengths for AnalystReport.summary and TradingDecision.rationale.
_SUMMARY_MAX_CHARS = 500
_RATIONALE_MAX_CHARS = 1000

# Phrases in the free-text decision that negate a mention of "BUY".
_BUY_NEGATIONS = ("DON'T BUY", "DON\u2019T BUY", "DO NOT BUY")

# Risk-judge verdict patterns (applied to upper-cased, whitespace-collapsed
# text). The word boundary keeps "DISAPPROVE"/"UNACCEPTABLE" from matching the
# approval words they contain; the lookbehinds keep "NOT APPROVED" and
# "DON'T APPROVE" from counting as approval.
_RISK_APPROVE_RE = re.compile(r"(?<!NOT )(?<!N'T )\b(?:APPROV|ACCEPT)")
_RISK_REJECT_RE = re.compile(
    r"\b(?:REJECT|DENY|DENIED|DENIES|DISAPPROV|UNACCEPT)"
    r"|(?:NOT|N'T) (?:APPROV|ACCEPT)"
)


class AgentBacktestEngine(BacktestEngine):
    """Backtest engine whose decisions come from a multi-agent TradingAgentsGraph.

    For each (ticker, trading date) requested, the graph is propagated once.
    Its final state and signal are converted into a ``TradingDecision``, and
    successful decisions are cached for the lifetime of the engine.
    """

    def __init__(
        self,
        config: BacktestConfig,
        agent_config: dict[str, Any] | None = None,
    ):
        """Create the engine.

        Args:
            config: Backtest configuration passed to ``BacktestEngine``.
            agent_config: Options for the agent graph. If empty or None,
                ``config.agent_config`` is used instead, and ``{}`` if that
                is also empty.
        """
        super().__init__(config)
        # Copy the mapping so later mutation of the caller's dict doesn't leak
        # into the engine; never leave this as None (it is read with .get()).
        self.agent_config: dict[str, Any] = dict(
            agent_config or config.agent_config or {}
        )
        self.trading_graph: TradingAgentsGraph | None = None
        self._decision_cache: dict[str, TradingDecision] = {}

    def _initialize(self):
        """Initialize base engine state, then build the agent graph."""
        super()._initialize()
        graph_config = dict(self.agent_config)
        self.trading_graph = TradingAgentsGraph(
            selected_analysts=self.agent_config.get(
                "selected_analysts",
                ["market", "social", "news", "fundamentals"],
            ),
            debug=self.agent_config.get("debug", False),
            # An empty config is passed as None (presumably so the graph
            # falls back to its own defaults).
            config=graph_config or None,
        )

    def _get_decision(
        self,
        ticker: str,
        trading_date: date,
        day_index: int,
    ) -> TradingDecision | None:
        """Return the agent's decision for ``ticker`` on ``trading_date``.

        Successful decisions are memoized per (ticker, date). Failures are
        logged and not cached, so a later call for the same key retries.

        Args:
            ticker: Symbol to decide on.
            trading_date: Simulated trading day.
            day_index: Index of the trading day; not used by this engine.

        Returns:
            The parsed decision, or None if propagation or parsing failed with
            ValueError, KeyError, RuntimeError, ConnectionError or TimeoutError.

        Raises:
            RuntimeError: If the trading graph has not been built yet
                (``_initialize`` has not run).
        """
        cache_key = f"{ticker}_{trading_date}"
        cached = self._decision_cache.get(cache_key)
        if cached is not None:
            return cached

        if self.trading_graph is None:
            raise RuntimeError(
                "AgentBacktestEngine.trading_graph is not initialized; "
                "_initialize() must run before decisions are requested"
            )

        try:
            # NOTE(review): trading_date is passed as a date object; confirm
            # TradingAgentsGraph.propagate accepts that (vs. a "YYYY-MM-DD" str).
            final_state, signal_info = self.trading_graph.propagate(
                ticker, trading_date
            )
            decision = self._parse_agent_decision(
                ticker, trading_date, final_state, signal_info
            )
        except (
            ValueError,
            KeyError,
            RuntimeError,
            ConnectionError,
            TimeoutError,
        ) as e:
            logger.error(
                "Agent decision failed for %s on %s: %s", ticker, trading_date, e
            )
            return None

        self._decision_cache[cache_key] = decision
        return decision

    def _parse_agent_decision(
        self,
        ticker: str,
        trading_date: date,
        final_state: dict[str, Any],
        signal_info: dict[str, Any],
    ) -> TradingDecision:
        """Convert the graph's final state and signal into a TradingDecision.

        Args:
            ticker: Symbol the decision is for.
            trading_date: Trading day; stored as midnight of that day.
            final_state: Graph state. The keys read are the four ``*_report``
                strings, ``investment_debate_state``, ``risk_debate_state`` and
                ``final_trade_decision``. Missing or None values are tolerated.
            signal_info: Processed signal (a mapping, or a bare action string).

        Returns:
            The assembled ``TradingDecision``.
        """
        final_state = final_state or {}
        decision_dt = datetime.combine(trading_date, datetime.min.time())

        signal = self._extract_signal(signal_info)
        confidence = self._extract_confidence(signal_info)
        analyst_reports = self._build_analyst_reports(ticker, decision_dt, final_state)

        # `or {}` also covers keys that are present but set to None.
        debate_state = final_state.get("investment_debate_state") or {}
        bull_argument = self._latest_argument(debate_state.get("bull_history"))
        bear_argument = self._latest_argument(debate_state.get("bear_history"))

        risk_state = final_state.get("risk_debate_state") or {}
        risk_approved = self._extract_risk_approval(risk_state)

        final_decision_text = str(final_state.get("final_trade_decision") or "")
        recommended_action = self._extract_action(signal_info, final_decision_text)

        return TradingDecision(
            ticker=ticker,
            timestamp=datetime.now(),
            decision_date=decision_dt,
            signal=signal,
            confidence=confidence,
            recommended_action=recommended_action,
            analyst_reports=analyst_reports,
            bull_argument=bull_argument,
            bear_argument=bear_argument,
            debate_rounds=debate_state.get("count") or 0,
            risk_manager_approved=risk_approved,
            final_decision=recommended_action,
            rationale=final_decision_text[:_RATIONALE_MAX_CHARS],
        )

    @staticmethod
    def _build_analyst_reports(
        ticker: str,
        report_date: datetime,
        final_state: Mapping[str, Any],
    ) -> list[AnalystReport]:
        """Build one AnalystReport per non-empty ``*_report`` entry in the state.

        Reports are emitted in the order market, sentiment, news, fundamentals.
        """
        report_keys = (
            ("market_report", AnalystType.MARKET),
            ("sentiment_report", AnalystType.SENTIMENT),
            ("news_report", AnalystType.NEWS),
            ("fundamentals_report", AnalystType.FUNDAMENTALS),
        )
        return [
            AnalystReport(
                analyst_type=analyst_type,
                ticker=ticker,
                report_date=report_date,
                summary=content[:_SUMMARY_MAX_CHARS],
                raw_content=content,
            )
            for key, analyst_type in report_keys
            if (content := final_state.get(key))
        ]

    @staticmethod
    def _latest_argument(history: Any) -> Any:
        """Return the most recent entry of a debate history, or None if empty.

        For a list-like history, this is the last element. A plain string is
        returned whole, because indexing it with ``[-1]`` would yield only its
        final character. NOTE(review): upstream TradingAgents stores debate
        histories as concatenated strings; confirm which form this fork uses.
        """
        if not history:
            return None
        if isinstance(history, str):
            return history
        return history[-1]

    @staticmethod
    def _normalize_signal_info(signal_info: Any) -> Mapping[str, Any]:
        """Coerce the graph's signal output into a mapping.

        A mapping is returned unchanged. A bare string is treated as the
        action (NOTE(review): assumes the graph may return e.g. "BUY" instead
        of a dict; confirm). Anything else yields ``{}``, so the extractors
        fall back to their defaults.
        """
        if isinstance(signal_info, Mapping):
            return signal_info
        if isinstance(signal_info, str):
            return {"action": signal_info}
        return {}

    @staticmethod
    def _coerce_confidence(raw: Any) -> float:
        """Convert a raw confidence value into a float clamped to [0, 1].

        String values:

        - A value containing "%" is divided by 100 ("85%" -> 0.85).
        - A bare value above 1 is treated as a percentage ("85" -> 0.85).
        - A bare value in [0, 1] is kept as a fraction ("0.85" -> 0.85).

        Numeric values are used as-is, then clamped. Missing, unparsable or
        non-finite values fall back to 0.5.
        """
        if isinstance(raw, str):
            is_percent = "%" in raw
            try:
                value = float(raw.replace("%", "").strip())
            except ValueError:
                return _DEFAULT_CONFIDENCE
            if is_percent or value > 1:
                value /= 100
        else:
            try:
                value = float(raw)
            except (TypeError, ValueError):
                return _DEFAULT_CONFIDENCE
        if not math.isfinite(value):
            return _DEFAULT_CONFIDENCE
        return min(max(value, 0.0), 1.0)

    def _extract_signal(self, signal_info: dict[str, Any]) -> SignalType:
        """Map the signal's action/direction and confidence to a SignalType.

        An explicit action (BUY/SELL/HOLD) takes precedence. If there is none,
        the direction (BULLISH/BEARISH) is used. BUY/SELL becomes STRONG_*
        when confidence is above 0.8. Anything else maps to HOLD.
        """
        info = self._normalize_signal_info(signal_info)
        action = str(info.get("action") or "").strip().upper()
        direction = str(info.get("direction") or "").strip().upper()

        if action in ("BUY", "SELL", "HOLD"):
            side = action
        elif direction == "BULLISH":
            side = "BUY"
        elif direction == "BEARISH":
            side = "SELL"
        else:
            side = "HOLD"

        if side == "HOLD":
            return SignalType.HOLD

        confidence = self._coerce_confidence(
            info.get("confidence", _DEFAULT_CONFIDENCE)
        )
        strong = confidence > _STRONG_SIGNAL_THRESHOLD
        if side == "BUY":
            return SignalType.STRONG_BUY if strong else SignalType.BUY
        return SignalType.STRONG_SELL if strong else SignalType.SELL

    def _extract_confidence(self, signal_info: dict[str, Any]) -> Decimal:
        """Return the signal's confidence as a Decimal in [0, 1] (default 0.5)."""
        info = self._normalize_signal_info(signal_info)
        value = self._coerce_confidence(info.get("confidence", _DEFAULT_CONFIDENCE))
        return Decimal(str(value))

    def _extract_action(
        self,
        signal_info: dict[str, Any],
        final_decision_text: str,
    ) -> str:
        """Return the recommended action as an upper-case string.

        Uses the signal's explicit ``action`` when present. Otherwise it
        scans the free-text decision, in this order:

        - "BUY" (unless negated as "DON'T BUY" / "DO NOT BUY") gives BUY.
        - "SELL" gives SELL.
        - Anything else gives HOLD.
        """
        info = self._normalize_signal_info(signal_info)
        action = str(info.get("action") or "").strip()
        if action:
            return action.upper()

        text_upper = str(final_decision_text or "").upper()
        if "BUY" in text_upper and not any(
            negation in text_upper for negation in _BUY_NEGATIONS
        ):
            return "BUY"
        if "SELL" in text_upper:
            return "SELL"
        return "HOLD"

    def _extract_risk_approval(self, risk_state: dict[str, Any]) -> bool | None:
        """Interpret the risk judge's verdict text.

        Returns:
            True if the text approves/accepts, False if it rejects, denies or
            negates approval, and None if there is no verdict or it is
            ambiguous. Approval is checked first. Negated forms such as
            "NOT APPROVED" or "DISAPPROVE" never count as approval.
        """
        judge_decision = (risk_state or {}).get("judge_decision") or ""
        if not judge_decision:
            return None

        # Collapse whitespace so multi-space/newline gaps still match "NOT X".
        text = " ".join(str(judge_decision).upper().split())
        if _RISK_APPROVE_RE.search(text):
            return True
        if _RISK_REJECT_RE.search(text):
            return False
        return None


def run_agent_backtest(
    tickers: list[str],
    start_date: date,
    end_date: date,
    initial_cash: Decimal = Decimal("100000"),
    agent_config: dict[str, Any] | None = None,
) -> BacktestResult:
    """Run an agent-driven backtest over ``tickers`` with preset portfolio settings.

    The portfolio uses ``commission_per_trade=1`` and
    ``slippage_percent=0.05``, and the backtest uses ``warmup_period=5``.

    Args:
        tickers: Symbols to trade.
        start_date: First date of the backtest window.
        end_date: Last date of the backtest window.
        initial_cash: Starting cash for the portfolio.
        agent_config: Options forwarded to the agent graph (``{}`` if None).

    Returns:
        The result of ``AgentBacktestEngine.run()``.
    """
    # Imported locally, as in the original (possibly to avoid an import cycle).
    from tradingagents.models.portfolio import PortfolioConfig

    config = BacktestConfig(
        name=f"Agent Backtest - {', '.join(tickers)}",
        tickers=tickers,
        start_date=start_date,
        end_date=end_date,
        portfolio_config=PortfolioConfig(
            initial_cash=initial_cash,
            commission_per_trade=Decimal("1"),
            slippage_percent=Decimal("0.05"),
        ),
        warmup_period=5,
        agent_config=agent_config or {},
    )
    engine = AgentBacktestEngine(config, agent_config)
    return engine.run()