From e99138f5b96b4b266394357d3f9f640bc410c795 Mon Sep 17 00:00:00 2001 From: Quanliang Liu Date: Fri, 7 Nov 2025 08:53:46 -0600 Subject: [PATCH] Visualization added --- evaluation/run_evaluation.py | 27 ++++++++++++ evaluation/visualize.py | 79 ++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/evaluation/run_evaluation.py b/evaluation/run_evaluation.py index 7e579bbe..b0d9e3fd 100644 --- a/evaluation/run_evaluation.py +++ b/evaluation/run_evaluation.py @@ -16,6 +16,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) from evaluation.baseline_strategies import get_all_baseline_strategies from evaluation.backtest import BacktestEngine, TradingAgentsBacktester, load_stock_data, standardize_single_ticker from evaluation.metrics import calculate_all_metrics, create_comparison_table, print_metrics +from evaluation.visualize import plot_cumulative_returns_from_results from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG @@ -178,6 +179,32 @@ def run_evaluation( all_metrics[name] = metrics print_metrics(metrics, name) + # Generate cumulative returns comparison plot + print("\n" + "="*80) + print("STEP 5: Generating Comparison Plot") + print("="*80) + try: + comparison_plot_path = str(out / ticker / "strategy_comparison.png") + plot_cumulative_returns_from_results( + results_dir=str(out / ticker), + ticker=ticker, + output_path=comparison_plot_path + ) + # Also save as PDF + pdf_path = comparison_plot_path.replace('.png', '.pdf') + plot_cumulative_returns_from_results( + results_dir=str(out / ticker), + ticker=ticker, + output_path=pdf_path + ) + print(f"\nāœ“ Comparison plot saved to:") + print(f" - {comparison_plot_path}") + print(f" - {pdf_path}") + except Exception as e: + print(f"\nāœ— Failed to generate comparison plot: {e}") + import traceback + traceback.print_exc() + print("\n" + "="*80) print("EVALUATION COMPLETE") print("="*80) diff --git a/evaluation/visualize.py b/evaluation/visualize.py index c8b55162..416fa89a 100644 --- a/evaluation/visualize.py +++ b/evaluation/visualize.py @@ -9,6 +9,7 @@ import matplotlib.pyplot as plt from pathlib import Path from typing import Dict import warnings +import json warnings.filterwarnings('ignore') @@ -388,11 +389,89 @@ def create_summary_report( print(f"\nāœ“ All visualizations saved to: {output_dir}") +def plot_cumulative_returns_from_results( + results_dir: str, + ticker: str, + output_path: str = None, + figsize: tuple = (12, 7) +) -> plt.Figure: + """ + Plot cumulative returns comparison from saved JSON results. + + Args: + results_dir: Directory containing strategy result folders + ticker: Stock ticker symbol + output_path: Path to save the figure (optional) + figsize: Figure size (width, height) + + Returns: + matplotlib Figure object + """ + results_path = Path(results_dir) + + # Define strategies to load + strategies = { + 'BuyAndHold': 'BuyAndHoldStrategy', + 'MACD': 'MACDStrategy', + 'KDJ&RSI': 'KDJRSIStrategy', + 'ZMR': 'ZMRStrategy', + 'SMA': 'SMAStrategy', + 'TradingAgents': 'TradingAgents' + } + + fig, ax = plt.subplots(figsize=figsize) + + # Load and plot each strategy + for folder_name, display_name in strategies.items(): + strategy_dir = results_path / folder_name + if not strategy_dir.exists(): + continue + + # Find actions JSON file + action_files = list(strategy_dir.glob("actions_*.json")) + if not action_files: + continue + + try: + # Load data + with open(action_files[0], 'r') as f: + data = json.load(f) + + # Extract date and cumulative_return + dates = pd.to_datetime([action['date'] for action in data['actions']]) + cumulative_returns = [action['cumulative_return'] for action in data['actions']] + + # Plot + linewidth = 2.5 if display_name == 'TradingAgents' else 1.5 + ax.plot(dates, cumulative_returns, label=display_name, + linewidth=linewidth, alpha=0.9) + + except Exception as e: + print(f"Warning: Failed to load {display_name}: {e}") + + ax.set_xlabel('Date', fontsize=12) + ax.set_ylabel('Cumulative Return', fontsize=12) + ax.set_title(f'Strategy Comparison - Cumulative Returns for {ticker}', + fontsize=14, fontweight='bold') + ax.legend(title='Strategies', loc='best', fontsize=10, framealpha=0.9) + ax.grid(True, alpha=0.3, linestyle='--') + ax.axhline(y=1.0, color='black', linestyle='--', linewidth=1, alpha=0.5) + + plt.tight_layout() + + if output_path: + fig.savefig(output_path, dpi=300, bbox_inches='tight') + print(f"āœ“ Saved cumulative returns comparison to: {output_path}") + + return fig + + if __name__ == "__main__": # Example usage / testing print("Visualization module loaded successfully!") print("\nAvailable functions:") print(" - plot_cumulative_returns") + print(" - plot_cumulative_returns_from_results") print(" - plot_transaction_history") print(" - plot_metrics_comparison") print(" - plot_drawdown")