481 lines
16 KiB
Python
481 lines
16 KiB
Python
"""
|
|
Visualization tools for trading strategy evaluation.
|
|
Generates plots and reports for comparing TradingAgents with baseline strategies.
|
|
"""
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from pathlib import Path
|
|
from typing import Dict
|
|
import warnings
|
|
import json
|
|
|
|
# NOTE(review): this silences *every* warning, process-wide, as a side effect
# of importing this module (including warnings raised by caller code).
# Kept for backward compatibility; consider scoping it with
# ``warnings.catch_warnings()`` around the plotting calls instead.
warnings.filterwarnings('ignore')

# Try to import seaborn for better styling (optional)
try:
    import seaborn as sns
    HAS_SEABORN = True
except ImportError:
    HAS_SEABORN = False

# matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" (the
# old names were later removed). Passing a name the installed matplotlib does
# not ship makes plt.style.use() raise OSError, which is not an ImportError and
# would abort importing this module. Pick whichever name actually exists.
_SEABORN_STYLE = next(
    (style for style in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid')
     if style in plt.style.available),
    None,
)

if HAS_SEABORN and _SEABORN_STYLE is not None:
    plt.style.use(_SEABORN_STYLE)
else:
    # Use default matplotlib styling: white background with a grid.
    plt.rcParams['figure.facecolor'] = 'white'
    plt.rcParams['axes.facecolor'] = 'white'
    plt.rcParams['axes.grid'] = True

if HAS_SEABORN:
    # Applied after the style so the style cannot override the palette.
    sns.set_palette("husl")
|
|
|
|
|
|
def plot_cumulative_returns(
    results: Dict[str, pd.DataFrame],
    ticker: str,
    output_path: str = None,
    figsize: tuple = (14, 8)
) -> plt.Figure:
    """
    Draw every strategy's cumulative return curve on one set of axes.

    Each DataFrame in ``results`` that has a ``cumulative_return`` column is
    plotted against its index; frames without that column are skipped. The
    column is treated as a growth factor, so ``(value - 1) * 100`` is what
    appears on the y-axis as a percentage.

    Args:
        results: Mapping of strategy name -> portfolio DataFrame.
        ticker: Stock ticker symbol, used in the title.
        output_path: If truthy, the figure is also saved here at 300 dpi.
        figsize: Figure size as (width, height).

    Returns:
        The matplotlib Figure that was drawn.
    """
    fig, ax = plt.subplots(figsize=figsize)

    plottable = (
        (label, frame)
        for label, frame in results.items()
        if "cumulative_return" in frame.columns
    )
    for label, frame in plottable:
        # Growth factor -> percent gain/loss relative to the start.
        pct_gain = frame["cumulative_return"].sub(1).mul(100)
        ax.plot(frame.index, pct_gain, label=label, linewidth=2, alpha=0.8)

    ax.set_xlabel('Date', fontsize=12, fontweight='bold')
    ax.set_ylabel('Cumulative Return (%)', fontsize=12, fontweight='bold')
    ax.set_title(f'{ticker} - Cumulative Returns Comparison', fontsize=14, fontweight='bold')
    ax.legend(loc='best', fontsize=10, framealpha=0.9)
    ax.grid(True, alpha=0.3)
    # Break-even reference line.
    ax.axhline(y=0, color='black', linestyle='--', linewidth=1, alpha=0.5)

    def _as_percent(tick_value, _position):
        return f'{tick_value:.1f}%'

    ax.yaxis.set_major_formatter(plt.FuncFormatter(_as_percent))

    fig.tight_layout()

    if not output_path:
        return fig

    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"✓ Saved cumulative returns plot to: {output_path}")
    return fig
|
|
|
|
|
|
def plot_transaction_history(
    portfolio: pd.DataFrame,
    ticker: str,
    strategy_name: str = "TradingAgents",
    output_path: str = None,
    figsize: tuple = (14, 10)
) -> plt.Figure:
    """
    Plot transaction history with buy/sell signals overlaid on price chart.

    Top panel (2/3 of the height): close price, with a green "^" on the
    first bar of every run where ``signal == 1`` and a red "v" on the first
    bar of every run where ``signal == -1``. Consecutive bars with the same
    signal are therefore marked only once.
    Bottom panel (1/3): portfolio value over time, filled down to zero.

    Args:
        portfolio: Portfolio DataFrame with 'close', 'signal' and
            'portfolio_value' columns (the x axis is labelled 'Date', so a
            date-like index is presumably expected)
        ticker: Stock ticker symbol
        strategy_name: Name of the strategy (used in the title)
        output_path: Path to save the figure (optional)
        figsize: Figure size (width, height)

    Returns:
        matplotlib Figure object

    Raises:
        KeyError: If any required column is missing from ``portfolio``.
    """
    required = ("close", "signal", "portfolio_value")
    missing = [col for col in required if col not in portfolio.columns]
    if missing:
        raise KeyError(f"portfolio is missing required column(s): {missing}")

    # gridspec_kw works on every matplotlib release; the height_ratios=
    # shortcut on plt.subplots only exists from matplotlib 3.6 and raises
    # TypeError on older versions.
    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=figsize, gridspec_kw={'height_ratios': [2, 1]}
    )

    # Price chart with signals
    ax1.plot(portfolio.index, portfolio["close"], label='Close Price',
             color='blue', linewidth=1.5, alpha=0.7)

    # Mark only the transition INTO a signal: the first row has no previous
    # value (NaN), which compares unequal, so a leading signal is marked too.
    signals = portfolio["signal"]
    previous = signals.shift(1)
    buy_signals = (signals == 1) & (previous != 1)
    sell_signals = (signals == -1) & (previous != -1)

    if buy_signals.any():
        ax1.scatter(portfolio.index[buy_signals],
                    portfolio.loc[buy_signals, "close"],
                    marker='^', color='green', s=100, label='Buy',
                    zorder=5, alpha=0.8)

    if sell_signals.any():
        ax1.scatter(portfolio.index[sell_signals],
                    portfolio.loc[sell_signals, "close"],
                    marker='v', color='red', s=100, label='Sell',
                    zorder=5, alpha=0.8)

    ax1.set_ylabel('Price ($)', fontsize=12, fontweight='bold')
    ax1.set_title(f'{ticker} - {strategy_name} Transaction History',
                  fontsize=14, fontweight='bold')
    ax1.legend(loc='best', fontsize=10)
    ax1.grid(True, alpha=0.3)

    # Portfolio value
    ax2.plot(portfolio.index, portfolio["portfolio_value"],
             label='Portfolio Value', color='purple', linewidth=2)
    ax2.fill_between(portfolio.index, portfolio["portfolio_value"],
                     alpha=0.3, color='purple')
    ax2.set_xlabel('Date', fontsize=12, fontweight='bold')
    ax2.set_ylabel('Portfolio Value ($)', fontsize=12, fontweight='bold')
    ax2.legend(loc='best', fontsize=10)
    ax2.grid(True, alpha=0.3)

    # Format y-axis as currency
    ax2.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'${y:,.0f}'))

    fig.tight_layout()

    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved transaction history plot to: {output_path}")

    return fig
|
|
|
|
|
|
def plot_metrics_comparison(
    comparison_df: pd.DataFrame,
    ticker: str,
    output_path: str = None,
    figsize: tuple = (16, 10)
) -> plt.Figure:
    """
    Create horizontal bar charts comparing key metrics across strategies.

    Up to four metrics (those listed in ``metrics_to_plot`` that are columns
    of ``comparison_df``) are drawn on a 2x2 grid, one panel per metric.
    Within a panel, strategies are ranked from highest value (top) to lowest
    (bottom), and the "TradingAgents" row is highlighted in a different
    colour. Unused panels are hidden.

    Args:
        comparison_df: DataFrame with strategies as rows (index) and metrics
            as columns; metric values are expected to be numeric
        ticker: Stock ticker symbol
        output_path: Path to save the figure (optional)
        figsize: Figure size (width, height)

    Returns:
        matplotlib Figure object

    Raises:
        ValueError: If none of the key metrics are present as columns.
    """
    # Select key metrics (matching paper's Table 1)
    metrics_to_plot = [
        "Cumulative Return (%)",
        "Annualized Return (%)",
        "Sharpe Ratio",
        "Maximum Drawdown (%)"
    ]

    # Filter to available metrics
    available_metrics = [m for m in metrics_to_plot if m in comparison_df.columns]

    if not available_metrics:
        raise ValueError("No matching metrics found in comparison DataFrame")

    n_metrics = len(available_metrics)
    fig, axes = plt.subplots(2, 2, figsize=figsize)
    axes = axes.flatten()

    for ax, metric in zip(axes, available_metrics):
        data = comparison_df[metric].sort_values(ascending=False)
        positions = range(len(data))

        # Color code: TradingAgents in different color
        colors = ['#FF6B6B' if name == 'TradingAgents' else '#4ECDC4'
                  for name in data.index]

        bars = ax.barh(positions, data.values, color=colors, alpha=0.8)
        ax.set_yticks(positions)
        ax.set_yticklabels(data.index, fontsize=10)
        # barh puts position 0 at the bottom; invert so the descending sort
        # reads top-down (highest value first).
        ax.invert_yaxis()
        ax.set_xlabel(metric, fontsize=11, fontweight='bold')
        ax.set_title(metric, fontsize=12, fontweight='bold')
        ax.grid(True, alpha=0.3, axis='x')

        # Add value labels on bars
        is_ratio = "Ratio" in metric
        for bar, value in zip(bars, data.values):
            # NaN/inf has no drawable bar end to anchor a label to.
            if not np.isfinite(value):
                continue
            label = f'{value:.2f}' if is_ratio else f'{value:.1f}%'
            # Put the label just beyond the bar end: right of positive bars,
            # left of negative ones (e.g. drawdowns), so it never sits on
            # top of the bar itself.
            if value < 0:
                text, ha = f'{label} ', 'right'
            else:
                text, ha = f' {label}', 'left'
            ax.text(value, bar.get_y() + bar.get_height() / 2,
                    text, va='center', ha=ha, fontsize=9)

    # Hide unused subplots
    for ax in axes[n_metrics:]:
        ax.axis('off')

    fig.suptitle(f'{ticker} - Performance Metrics Comparison',
                 fontsize=16, fontweight='bold', y=0.995)
    fig.tight_layout()

    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved metrics comparison plot to: {output_path}")

    return fig
|
|
|
|
|
|
def plot_drawdown(
    results: Dict[str, pd.DataFrame],
    ticker: str,
    output_path: str = None,
    figsize: tuple = (14, 8)
) -> plt.Figure:
    """
    Plot drawdown analysis for all strategies.

    For each DataFrame in ``results`` with a ``portfolio_value`` column, the
    drawdown is computed as ``(value - running peak) / running peak * 100``.
    It is drawn as a line, with a light fill of the same colour down to 0.
    Frames without that column are skipped.

    Args:
        results: Dictionary mapping strategy name to portfolio DataFrame
        ticker: Stock ticker symbol
        output_path: Path to save the figure (optional)
        figsize: Figure size (width, height)

    Returns:
        matplotlib Figure object
    """
    fig, ax = plt.subplots(figsize=figsize)

    for name, portfolio in results.items():
        if "portfolio_value" not in portfolio.columns:
            continue
        values = portfolio["portfolio_value"]
        # A non-positive running peak would divide by zero (±inf); leave
        # those points as NaN gaps instead.
        running_max = values.cummax().where(lambda peak: peak > 0)
        drawdown = (values - running_max) / running_max * 100

        (line,) = ax.plot(portfolio.index, drawdown, label=name, linewidth=2, alpha=0.7)
        # Reuse the line's colour explicitly rather than relying on the
        # separate patch colour cycle happening to stay in step.
        ax.fill_between(portfolio.index, drawdown, 0, color=line.get_color(), alpha=0.1)

    ax.set_xlabel('Date', fontsize=12, fontweight='bold')
    ax.set_ylabel('Drawdown (%)', fontsize=12, fontweight='bold')
    ax.set_title(f'{ticker} - Drawdown Analysis', fontsize=14, fontweight='bold')
    ax.legend(loc='best', fontsize=10, framealpha=0.9)
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='black', linestyle='--', linewidth=1, alpha=0.5)

    # Format y-axis as percentage
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.1f}%'))

    fig.tight_layout()

    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved drawdown plot to: {output_path}")

    return fig
|
|
|
|
|
|
def plot_returns_distribution(
    results: Dict[str, pd.DataFrame],
    ticker: str,
    output_path: str = None,
    figsize: tuple = (14, 8)
) -> plt.Figure:
    """
    Plot distribution of daily returns for all strategies.

    Each DataFrame in ``results`` with a ``strategy_return`` column
    contributes one density-normalised histogram of its finite,
    non-missing returns (multiplied by 100 to show percent). All histograms
    share one set of 50 bin edges spanning the pooled data, so their
    shapes are directly comparable. Strategies with no usable returns are
    skipped.

    Args:
        results: Dictionary mapping strategy name to portfolio DataFrame
        ticker: Stock ticker symbol
        output_path: Path to save the figure (optional)
        figsize: Figure size (width, height)

    Returns:
        matplotlib Figure object
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Collect every strategy's returns first so the bins can be shared;
    # per-strategy bins=50 gives each histogram a different bin width.
    collected = []
    for name, portfolio in results.items():
        if "strategy_return" not in portfolio.columns:
            continue
        # Infinite returns (e.g. from dividing by a zero price) would make
        # the histogram range non-finite and raise inside hist().
        returns = (
            portfolio["strategy_return"].replace([np.inf, -np.inf], np.nan).dropna() * 100
        )
        if not returns.empty:
            collected.append((name, returns))

    if collected:
        pooled = np.concatenate([r.to_numpy(dtype=float) for _, r in collected])
        edges = np.histogram_bin_edges(pooled, bins=50)
        for name, returns in collected:
            ax.hist(returns, bins=edges, alpha=0.5, label=name, density=True)

    ax.set_xlabel('Daily Return (%)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Density', fontsize=12, fontweight='bold')
    ax.set_title(f'{ticker} - Returns Distribution', fontsize=14, fontweight='bold')
    ax.legend(loc='best', fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.axvline(x=0, color='black', linestyle='--', linewidth=1, alpha=0.5)

    fig.tight_layout()

    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved returns distribution plot to: {output_path}")

    return fig
|
|
|
|
|
|
def create_summary_report(
    ticker: str,
    results: Dict[str, pd.DataFrame],
    comparison_df: pd.DataFrame,
    output_dir: str
) -> None:
    """
    Generate comprehensive visual summary report.

    Creates all standard plots and saves them as PNG files in ``output_dir``
    (created if needed):
        1. cumulative returns, 2. metrics comparison, 3. drawdown,
        4. TradingAgents transaction history (only if "TradingAgents" is a
           key of ``results``), 5. returns distribution.

    Each plot is attempted independently: a failure is printed and the
    remaining plots still run. Every figure opened while producing a plot
    is closed afterwards (also on failure). This keeps repeated calls, e.g.
    one per ticker, from accumulating open figures. As a consequence the
    figures are written to disk only, not left open for display.

    Args:
        ticker: Stock ticker symbol
        results: Dictionary mapping strategy name to portfolio DataFrame
        comparison_df: DataFrame with performance metrics comparison
        output_dir: Directory to save output files
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print("\nGenerating visualizations...")

    # (description used in failure messages, output file name, plot call)
    jobs = [
        ("cumulative returns plot", f"{ticker}_cumulative_returns.png",
         lambda path: plot_cumulative_returns(results, ticker, output_path=path)),
        ("metrics comparison plot", f"{ticker}_metrics_comparison.png",
         lambda path: plot_metrics_comparison(comparison_df, ticker, output_path=path)),
        ("drawdown plot", f"{ticker}_drawdown.png",
         lambda path: plot_drawdown(results, ticker, output_path=path)),
    ]
    if "TradingAgents" in results:
        jobs.append((
            "transaction history plot", f"{ticker}_TradingAgents_transactions.png",
            lambda path: plot_transaction_history(
                results["TradingAgents"],
                ticker,
                strategy_name="TradingAgents",
                output_path=path,
            ),
        ))
    jobs.append((
        "returns distribution plot", f"{ticker}_returns_distribution.png",
        lambda path: plot_returns_distribution(results, ticker, output_path=path),
    ))

    failures = 0
    for description, filename, make_plot in jobs:
        figures_before = set(plt.get_fignums())
        try:
            make_plot(str(output_path / filename))
        except Exception as e:
            # Best-effort report: one broken plot must not stop the others.
            failures += 1
            print(f"✗ Failed to generate {description}: {e}")
        finally:
            # Close whatever this plot call opened (including a half-built
            # figure from a failed call) so figures do not leak.
            for num in set(plt.get_fignums()) - figures_before:
                plt.close(num)

    if failures:
        print(f"\n⚠ {failures} visualization(s) failed; the rest were saved to: {output_dir}")
    else:
        print(f"\n✓ All visualizations saved to: {output_dir}")
|
|
|
|
|
|
def plot_cumulative_returns_from_results(
    results_dir: str,
    ticker: str,
    output_path: str = None,
    figsize: tuple = (12, 7),
    strategies: Dict[str, str] = None,
) -> plt.Figure:
    """
    Plot cumulative returns comparison from saved JSON results.

    For each ``folder_name -> display_name`` entry in ``strategies``, reads
    one ``<results_dir>/<folder_name>/actions_*.json`` file. If several
    match, the most recently modified one is used (ties broken by file
    name). The file must contain an ``"actions"`` list whose items have
    ``"date"`` and ``"cumulative_return"`` keys. ``cumulative_return`` is
    plotted as-is, with a reference line at 1.0.

    Missing folders or files are skipped silently. Files that fail to
    load or parse are skipped with a printed warning.

    Args:
        results_dir: Directory containing strategy result folders
        ticker: Stock ticker symbol (used in the title only; it does not
            filter which files are read)
        output_path: Path to save the figure (optional)
        figsize: Figure size (width, height)
        strategies: Optional mapping of result folder name -> legend label.
            Defaults to the built-in baselines plus TradingAgents.

    Returns:
        matplotlib Figure object
    """
    results_path = Path(results_dir)

    if strategies is None:
        # NOTE(review): keys are used as folder names and values as legend
        # labels; pairs like 'KDJ&RSI' -> 'KDJRSIStrategy' read as if they
        # might be the other way round — confirm against the writer side.
        strategies = {
            'BuyAndHold': 'BuyAndHoldStrategy',
            'MACD': 'MACDStrategy',
            'KDJ&RSI': 'KDJRSIStrategy',
            'ZMR': 'ZMRStrategy',
            'SMA': 'SMAStrategy',
            'TradingAgents': 'TradingAgents'
        }

    fig, ax = plt.subplots(figsize=figsize)

    # Load and plot each strategy
    for folder_name, display_name in strategies.items():
        strategy_dir = results_path / folder_name
        if not strategy_dir.exists():
            continue

        action_files = list(strategy_dir.glob("actions_*.json"))
        if not action_files:
            continue

        try:
            # glob() order is filesystem-dependent, so "the first match" was
            # arbitrary when several runs exist; pick the newest file instead.
            action_file = max(action_files, key=lambda p: (p.stat().st_mtime, p.name))
            with open(action_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            actions = data['actions']
            dates = pd.to_datetime([action['date'] for action in actions])
            cumulative_returns = [action['cumulative_return'] for action in actions]

            # Emphasise TradingAgents with a thicker line.
            linewidth = 2.5 if display_name == 'TradingAgents' else 1.5
            ax.plot(dates, cumulative_returns, label=display_name,
                    linewidth=linewidth, alpha=0.9)

        except Exception as e:
            # Best-effort: one unreadable strategy must not abort the chart.
            print(f"Warning: Failed to load {display_name}: {e}")

    ax.set_xlabel('Date', fontsize=12)
    ax.set_ylabel('Cumulative Return', fontsize=12)
    ax.set_title(f'Strategy Comparison - Cumulative Returns for {ticker}',
                 fontsize=14, fontweight='bold')
    ax.legend(title='Strategies', loc='best', fontsize=10, framealpha=0.9)
    ax.grid(True, alpha=0.3, linestyle='--')
    # Break-even reference for a growth-factor series.
    ax.axhline(y=1.0, color='black', linestyle='--', linewidth=1, alpha=0.5)

    fig.tight_layout()

    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved cumulative returns comparison to: {output_path}")

    return fig
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly is only a smoke check: reaching this point
    # means every import and the module-level styling succeeded, so just
    # list the public entry points.
    _public_api = (
        "plot_cumulative_returns",
        "plot_cumulative_returns_from_results",
        "plot_transaction_history",
        "plot_metrics_comparison",
        "plot_drawdown",
        "plot_returns_distribution",
        "create_summary_report",
    )
    print("Visualization module loaded successfully!")
    print("\nAvailable functions:")
    for _func_name in _public_api:
        print(" - " + _func_name)
|
|
|