"""
Visualization tools for trading strategy evaluation.

Generates plots and reports for comparing TradingAgents with baseline strategies.
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Callable, Dict, Optional
import warnings
import json

# NOTE(review): this silences *every* warning in the whole process as soon as
# this module is imported, not only matplotlib noise. Kept for backward
# compatibility; consider scoping it with `warnings.catch_warnings()`.
warnings.filterwarnings('ignore')


def _apply_default_style() -> None:
    """Apply plain matplotlib styling: white backgrounds with a grid on every axes."""
    plt.rcParams['figure.facecolor'] = 'white'
    plt.rcParams['axes.facecolor'] = 'white'
    plt.rcParams['axes.grid'] = True


# Try to import seaborn for better styling (optional)
try:
    import seaborn as sns
    HAS_SEABORN = True
except ImportError:
    HAS_SEABORN = False

if HAS_SEABORN:
    try:
        # The 'seaborn-v0_8-*' style names were introduced in matplotlib 3.6.
        # On older versions plt.style.use raises OSError for an unknown name,
        # which must not abort importing this module.
        plt.style.use('seaborn-v0_8-darkgrid')
    except OSError:
        _apply_default_style()
    sns.set_palette("husl")
else:
    # Use default matplotlib styling
    _apply_default_style()


def _add_legend(ax: plt.Axes, **kwargs) -> None:
    """Draw a legend on ``ax`` only if at least one artist carries a label.

    When none of the input frames contained the plotted column there is
    nothing to put in the legend, so the call is skipped.
    """
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(**kwargs)


def _finalize_figure(fig: plt.Figure, output_path: Optional[str], description: str) -> None:
    """Tighten ``fig``'s layout and, when ``output_path`` is given, save it at 300 dpi.

    Args:
        fig: Figure to finalize.
        output_path: Destination file; nothing is written when falsy.
        description: Human-readable name used in the confirmation message.
    """
    fig.tight_layout()
    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved {description} to: {output_path}")


def plot_cumulative_returns(
    results: Dict[str, pd.DataFrame],
    ticker: str,
    output_path: Optional[str] = None,
    figsize: tuple = (14, 8)
) -> plt.Figure:
    """
    Plot cumulative returns comparison for all strategies.

    Strategies whose DataFrame has no ``cumulative_return`` column are skipped.

    Args:
        results: Dictionary mapping strategy name to portfolio DataFrame. The
            ``cumulative_return`` column is plotted as ``(value - 1) * 100``,
            i.e. it is treated as a growth factor where 1.0 means break-even.
        ticker: Stock ticker symbol
        output_path: Path to save the figure (optional)
        figsize: Figure size (width, height)

    Returns:
        matplotlib Figure object
    """
    fig, ax = plt.subplots(figsize=figsize)

    for name, portfolio in results.items():
        if "cumulative_return" in portfolio.columns:
            # Growth factor -> percentage gain/loss.
            cumulative = (portfolio["cumulative_return"] - 1) * 100
            ax.plot(portfolio.index, cumulative, label=name, linewidth=2, alpha=0.8)

    ax.set_xlabel('Date', fontsize=12, fontweight='bold')
    ax.set_ylabel('Cumulative Return (%)', fontsize=12, fontweight='bold')
    ax.set_title(f'{ticker} - Cumulative Returns Comparison', fontsize=14, fontweight='bold')
    _add_legend(ax, loc='best', fontsize=10, framealpha=0.9)
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='black', linestyle='--', linewidth=1, alpha=0.5)

    # Format y-axis as percentage
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.1f}%'))

    _finalize_figure(fig, output_path, "cumulative returns plot")
    return fig


def plot_transaction_history(
    portfolio: pd.DataFrame,
    ticker: str,
    strategy_name: str = "TradingAgents",
    output_path: Optional[str] = None,
    figsize: tuple = (14, 10)
) -> plt.Figure:
    """
    Plot transaction history with buy/sell signals overlaid on price chart.

    The top panel shows the close price with a marker at the first bar of each
    run of ``signal == 1`` (buy) or ``signal == -1`` (sell). The bottom panel
    shows ``portfolio_value``.

    Args:
        portfolio: Portfolio DataFrame with 'signal', 'close' and
            'portfolio_value' columns
        ticker: Stock ticker symbol
        strategy_name: Name of the strategy
        output_path: Path to save the figure (optional)
        figsize: Figure size (width, height)

    Returns:
        matplotlib Figure object

    Raises:
        KeyError: If one of the required columns is missing.
    """
    # gridspec_kw works on every matplotlib version; the ``height_ratios``
    # keyword of plt.subplots itself only exists from matplotlib 3.6.
    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=figsize, gridspec_kw={'height_ratios': [2, 1]}
    )

    # Price chart with signals
    ax1.plot(portfolio.index, portfolio["close"], label='Close Price',
             color='blue', linewidth=1.5, alpha=0.7)

    # A trade is marked on the first bar of a run of identical signals. The
    # first row compares against NaN from shift(), so it counts as a new run.
    signals = portfolio["signal"]
    previous = signals.shift(1)
    buy_signals = (signals == 1) & (previous != 1)
    sell_signals = (signals == -1) & (previous != -1)

    # Plot buy/sell markers
    for mask, marker, color, label in (
        (buy_signals, '^', 'green', 'Buy'),
        (sell_signals, 'v', 'red', 'Sell'),
    ):
        if mask.any():
            ax1.scatter(portfolio.index[mask], portfolio.loc[mask, "close"],
                        marker=marker, color=color, s=100, label=label,
                        zorder=5, alpha=0.8)

    ax1.set_ylabel('Price ($)', fontsize=12, fontweight='bold')
    ax1.set_title(f'{ticker} - {strategy_name} Transaction History',
                  fontsize=14, fontweight='bold')
    ax1.legend(loc='best', fontsize=10)
    ax1.grid(True, alpha=0.3)

    # Portfolio value
    ax2.plot(portfolio.index, portfolio["portfolio_value"], label='Portfolio Value',
             color='purple', linewidth=2)
    ax2.fill_between(portfolio.index, portfolio["portfolio_value"], alpha=0.3, color='purple')
    ax2.set_xlabel('Date', fontsize=12, fontweight='bold')
    ax2.set_ylabel('Portfolio Value ($)', fontsize=12, fontweight='bold')
    ax2.legend(loc='best', fontsize=10)
    ax2.grid(True, alpha=0.3)

    # Format y-axis as currency
    ax2.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'${y:,.0f}'))

    _finalize_figure(fig, output_path, "transaction history plot")
    return fig


def plot_metrics_comparison(
    comparison_df: pd.DataFrame,
    ticker: str,
    output_path: Optional[str] = None,
    figsize: tuple = (16, 10)
) -> plt.Figure:
    """
    Create bar charts comparing key metrics across strategies.

    One horizontal bar chart per available metric is drawn on a 2x2 grid;
    strategies are ordered by descending value (largest at the top) and the
    'TradingAgents' row is highlighted.

    Args:
        comparison_df: DataFrame with strategies as rows and metrics as columns
        ticker: Stock ticker symbol
        output_path: Path to save the figure (optional)
        figsize: Figure size (width, height)

    Returns:
        matplotlib Figure object

    Raises:
        ValueError: If none of the expected metric columns is present.
    """
    # Select key metrics (matching paper's Table 1)
    metrics_to_plot = [
        "Cumulative Return (%)",
        "Annualized Return (%)",
        "Sharpe Ratio",
        "Maximum Drawdown (%)"
    ]

    # Filter to available metrics
    available_metrics = [m for m in metrics_to_plot if m in comparison_df.columns]

    if not available_metrics:
        raise ValueError("No matching metrics found in comparison DataFrame")

    fig, axes = plt.subplots(2, 2, figsize=figsize)
    axes = axes.flatten()

    for ax, metric in zip(axes, available_metrics):
        data = comparison_df[metric].sort_values(ascending=False)

        # Color code: TradingAgents in different color
        colors = ['#FF6B6B' if name == 'TradingAgents' else '#4ECDC4' for name in data.index]

        positions = range(len(data))
        bars = ax.barh(positions, data.values, color=colors, alpha=0.8)
        ax.set_yticks(positions)
        ax.set_yticklabels(data.index, fontsize=10)
        # barh puts position 0 at the bottom; invert so the descending sort
        # reads top-to-bottom with the largest value first.
        ax.invert_yaxis()
        ax.set_xlabel(metric, fontsize=11, fontweight='bold')
        ax.set_title(metric, fontsize=12, fontweight='bold')
        ax.grid(True, alpha=0.3, axis='x')

        # Add value labels on bars
        for bar, value in zip(bars, data.values):
            if not np.isfinite(value):
                continue  # NaN/inf metric: there is no sensible place for a label
            label = f'{value:.2f}' if "Ratio" in metric else f'{value:.1f}%'
            ax.text(value, bar.get_y() + bar.get_height() / 2, f' {label}',
                    va='center', fontsize=9)

    # Hide unused subplots
    for ax in axes[len(available_metrics):]:
        ax.axis('off')

    fig.suptitle(f'{ticker} - Performance Metrics Comparison',
                 fontsize=16, fontweight='bold', y=0.995)

    _finalize_figure(fig, output_path, "metrics comparison plot")
    return fig


def plot_drawdown(
    results: Dict[str, pd.DataFrame],
    ticker: str,
    output_path: Optional[str] = None,
    figsize: tuple = (14, 8)
) -> plt.Figure:
    """
    Plot drawdown analysis for all strategies.

    Drawdown is ``(value - running_max) / running_max * 100`` computed on
    ``portfolio_value``; strategies without that column are skipped. Each
    curve is shaded down to zero in its own line color.

    Args:
        results: Dictionary mapping strategy name to portfolio DataFrame
        ticker: Stock ticker symbol
        output_path: Path to save the figure (optional)
        figsize: Figure size (width, height)

    Returns:
        matplotlib Figure object
    """
    fig, ax = plt.subplots(figsize=figsize)

    for name, portfolio in results.items():
        if "portfolio_value" not in portfolio.columns:
            continue
        values = portfolio["portfolio_value"]
        running_max = values.cummax()
        # A zero peak would divide by zero (inf/NaN); treat it as undefined.
        drawdown = (values - running_max) / running_max.replace(0, np.nan) * 100
        (line,) = ax.plot(portfolio.index, drawdown, label=name, linewidth=2, alpha=0.7)
        # Fill drawdown area in the same color as its line.
        ax.fill_between(portfolio.index, drawdown, 0, alpha=0.1, color=line.get_color())

    ax.set_xlabel('Date', fontsize=12, fontweight='bold')
    ax.set_ylabel('Drawdown (%)', fontsize=12, fontweight='bold')
    ax.set_title(f'{ticker} - Drawdown Analysis', fontsize=14, fontweight='bold')
    _add_legend(ax, loc='best', fontsize=10, framealpha=0.9)
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='black', linestyle='--', linewidth=1, alpha=0.5)

    # Format y-axis as percentage
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.1f}%'))

    _finalize_figure(fig, output_path, "drawdown plot")
    return fig


def plot_returns_distribution(
    results: Dict[str, pd.DataFrame],
    ticker: str,
    output_path: Optional[str] = None,
    figsize: tuple = (14, 8)
) -> plt.Figure:
    """
    Plot distribution of daily returns for all strategies.

    All strategies share one set of 50 bin edges so their density histograms
    are directly comparable. NaN/inf returns are dropped; strategies without a
    ``strategy_return`` column, or with no finite returns, are skipped.

    Args:
        results: Dictionary mapping strategy name to portfolio DataFrame
        ticker: Stock ticker symbol
        output_path: Path to save the figure (optional)
        figsize: Figure size (width, height)

    Returns:
        matplotlib Figure object
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Collect every strategy's finite returns (in %) first: calling hist with
    # bins=50 per strategy would give each one its own bin widths.
    returns_by_name = {}
    for name, portfolio in results.items():
        if "strategy_return" not in portfolio.columns:
            continue
        returns = portfolio["strategy_return"].dropna() * 100  # Convert to percentage
        returns = returns[np.isfinite(returns)]  # hist() rejects infinite values
        if not returns.empty:
            returns_by_name[name] = returns

    if returns_by_name:
        pooled = np.concatenate([r.to_numpy() for r in returns_by_name.values()])
        bin_edges = np.histogram_bin_edges(pooled, bins=50)
        for name, returns in returns_by_name.items():
            ax.hist(returns, bins=bin_edges, alpha=0.5, label=name, density=True)

    ax.set_xlabel('Daily Return (%)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Density', fontsize=12, fontweight='bold')
    ax.set_title(f'{ticker} - Returns Distribution', fontsize=14, fontweight='bold')
    _add_legend(ax, loc='best', fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.axvline(x=0, color='black', linestyle='--', linewidth=1, alpha=0.5)

    _finalize_figure(fig, output_path, "returns distribution plot")
    return fig


def _run_plot_step(description: str, make_figure: Callable[[], object]) -> bool:
    """Run one plotting step best-effort and release every figure it opened.

    Args:
        description: Name used in the failure message (e.g. "drawdown plot").
        make_figure: Zero-argument callable that builds and saves the figure.

    Returns:
        True if the step succeeded, False if it raised (the error is printed).
    """
    # Remember which figures already existed so only this step's figures are
    # closed; pyplot otherwise keeps every created figure alive.
    existing = set(plt.get_fignums())
    try:
        make_figure()
        succeeded = True
    except Exception as e:  # best-effort: one failed plot must not stop the report
        print(f"✗ Failed to generate {description}: {e}")
        succeeded = False
    finally:
        for number in set(plt.get_fignums()) - existing:
            plt.close(number)
    return succeeded


def create_summary_report(
    ticker: str,
    results: Dict[str, pd.DataFrame],
    comparison_df: pd.DataFrame,
    output_dir: str
) -> None:
    """
    Generate comprehensive visual summary report.

    Creates all standard plots and saves them to output directory as
    ``{ticker}_<plot>.png``. Each plot is generated independently: a failure
    is printed and the remaining plots are still produced. Figures are closed
    after saving, so they are not kept open in pyplot.

    Args:
        ticker: Stock ticker symbol
        results: Dictionary mapping strategy name to portfolio DataFrame
        comparison_df: DataFrame with performance metrics comparison
        output_dir: Directory to save output files (created if missing)
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print("\nGenerating visualizations...")

    def target(suffix: str) -> str:
        """Return the output file path for the plot with the given suffix."""
        return str(output_path / f"{ticker}_{suffix}.png")

    steps = [
        # 1. Cumulative Returns
        ("cumulative returns plot",
         lambda: plot_cumulative_returns(results, ticker, output_path=target("cumulative_returns"))),
        # 2. Metrics Comparison
        ("metrics comparison plot",
         lambda: plot_metrics_comparison(comparison_df, ticker, output_path=target("metrics_comparison"))),
        # 3. Drawdown Analysis
        ("drawdown plot",
         lambda: plot_drawdown(results, ticker, output_path=target("drawdown"))),
    ]

    # 4. Transaction History (if TradingAgents results available)
    if "TradingAgents" in results:
        steps.append(
            ("transaction history plot",
             lambda: plot_transaction_history(
                 results["TradingAgents"],
                 ticker,
                 strategy_name="TradingAgents",
                 output_path=target("TradingAgents_transactions"),
             ))
        )

    # 5. Returns Distribution
    steps.append(
        ("returns distribution plot",
         lambda: plot_returns_distribution(results, ticker, output_path=target("returns_distribution")))
    )

    failures = 0
    for description, make_figure in steps:
        if not _run_plot_step(description, make_figure):
            failures += 1

    if failures:
        print(f"\n⚠ {failures} visualization(s) failed; remaining plots saved to: {output_dir}")
    else:
        print(f"\n✓ All visualizations saved to: {output_dir}")


def plot_cumulative_returns_from_results(
    results_dir: str,
    ticker: str,
    output_path: Optional[str] = None,
    figsize: tuple = (12, 7)
) -> plt.Figure:
    """
    Plot cumulative returns comparison from saved JSON results.

    For each known strategy folder under ``results_dir``, reads an
    ``actions_*.json`` file whose ``actions`` list holds entries with ``date``
    and ``cumulative_return`` keys, and plots cumulative return over time.
    Missing folders/files are skipped; unreadable files print a warning.

    Args:
        results_dir: Directory containing strategy result folders
        ticker: Stock ticker symbol (used in the title)
        output_path: Path to save the figure (optional)
        figsize: Figure size (width, height)

    Returns:
        matplotlib Figure object
    """
    results_path = Path(results_dir)

    # Result sub-folder name -> legend label.
    strategies = {
        'BuyAndHold': 'BuyAndHoldStrategy',
        'MACD': 'MACDStrategy',
        'KDJ&RSI': 'KDJRSIStrategy',
        'ZMR': 'ZMRStrategy',
        'SMA': 'SMAStrategy',
        'TradingAgents': 'TradingAgents'
    }

    fig, ax = plt.subplots(figsize=figsize)

    # Load and plot each strategy
    for folder_name, display_name in strategies.items():
        strategy_dir = results_path / folder_name
        if not strategy_dir.exists():
            continue

        # glob() order depends on the filesystem; sort so the same file is
        # chosen on every run.
        # NOTE(review): if several actions_*.json files exist, the
        # lexicographically first one is used — confirm that is the intended run.
        action_files = sorted(strategy_dir.glob("actions_*.json"))
        if not action_files:
            continue

        try:
            with open(action_files[0], 'r', encoding='utf-8') as f:
                data = json.load(f)

            actions = data['actions']
            dates = pd.to_datetime([action['date'] for action in actions])
            cumulative_returns = [action['cumulative_return'] for action in actions]

            # Emphasize the TradingAgents curve.
            linewidth = 2.5 if display_name == 'TradingAgents' else 1.5
            ax.plot(dates, cumulative_returns, label=display_name, linewidth=linewidth, alpha=0.9)
        except (OSError, ValueError, KeyError, TypeError) as e:
            # Unreadable file, malformed JSON/dates, or unexpected structure.
            print(f"Warning: Failed to load {display_name}: {e}")

    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Cumulative Return', fontsize=12)
    ax.set_title(f'Strategy Comparison - Cumulative Returns for {ticker}',
                 fontsize=14, fontweight='bold')
    _add_legend(ax, title='Strategies', loc='best', fontsize=10, framealpha=0.9)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.axhline(y=1.0, color='black', linestyle='--', linewidth=1, alpha=0.5)

    _finalize_figure(fig, output_path, "cumulative returns comparison")
    return fig


if __name__ == "__main__":
    # Example usage / testing
    print("Visualization module loaded successfully!")
    print("\nAvailable functions:")
    print("  - plot_cumulative_returns")
    print("  - plot_cumulative_returns_from_results")
    print("  - plot_transaction_history")
    print("  - plot_metrics_comparison")
    print("  - plot_drawdown")
    print("  - plot_returns_distribution")
    print("  - create_summary_report")