Visualization added

This commit is contained in:
Quanliang Liu 2025-11-07 08:53:46 -06:00
parent ffff3050c8
commit e99138f5b9
2 changed files with 106 additions and 0 deletions

View File

@ -16,6 +16,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
from evaluation.baseline_strategies import get_all_baseline_strategies
from evaluation.backtest import BacktestEngine, TradingAgentsBacktester, load_stock_data, standardize_single_ticker
from evaluation.metrics import calculate_all_metrics, create_comparison_table, print_metrics
from evaluation.visualize import plot_cumulative_returns_from_results
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
@ -178,6 +179,32 @@ def run_evaluation(
all_metrics[name] = metrics
print_metrics(metrics, name)
# Generate cumulative returns comparison plot
print("\n" + "="*80)
print("STEP 5: Generating Comparison Plot")
print("="*80)
try:
comparison_plot_path = str(out / ticker / "strategy_comparison.png")
plot_cumulative_returns_from_results(
results_dir=str(out / ticker),
ticker=ticker,
output_path=comparison_plot_path
)
# Also save as PDF
pdf_path = comparison_plot_path.replace('.png', '.pdf')
plot_cumulative_returns_from_results(
results_dir=str(out / ticker),
ticker=ticker,
output_path=pdf_path
)
print(f"\n✓ Comparison plot saved to:")
print(f" - {comparison_plot_path}")
print(f" - {pdf_path}")
except Exception as e:
print(f"\n✗ Failed to generate comparison plot: {e}")
import traceback
traceback.print_exc()
print("\n" + "="*80)
print("EVALUATION COMPLETE")
print("="*80)

View File

@ -9,6 +9,7 @@ import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict
import warnings
import json
warnings.filterwarnings('ignore')
@ -388,11 +389,89 @@ def create_summary_report(
print(f"\n✓ All visualizations saved to: {output_dir}")
def plot_cumulative_returns_from_results(
results_dir: str,
ticker: str,
output_path: str = None,
figsize: tuple = (12, 7)
) -> plt.Figure:
"""
Plot cumulative returns comparison from saved JSON results.
Args:
results_dir: Directory containing strategy result folders
ticker: Stock ticker symbol
output_path: Path to save the figure (optional)
figsize: Figure size (width, height)
Returns:
matplotlib Figure object
"""
results_path = Path(results_dir)
# Define strategies to load
strategies = {
'BuyAndHold': 'BuyAndHoldStrategy',
'MACD': 'MACDStrategy',
'KDJ&RSI': 'KDJRSIStrategy',
'ZMR': 'ZMRStrategy',
'SMA': 'SMAStrategy',
'TradingAgents': 'TradingAgents'
}
fig, ax = plt.subplots(figsize=figsize)
# Load and plot each strategy
for folder_name, display_name in strategies.items():
strategy_dir = results_path / folder_name
if not strategy_dir.exists():
continue
# Find actions JSON file
action_files = list(strategy_dir.glob("actions_*.json"))
if not action_files:
continue
try:
# Load data
with open(action_files[0], 'r') as f:
data = json.load(f)
# Extract date and cumulative_return
dates = pd.to_datetime([action['date'] for action in data['actions']])
cumulative_returns = [action['cumulative_return'] for action in data['actions']]
# Plot
linewidth = 2.5 if display_name == 'TradingAgents' else 1.5
ax.plot(dates, cumulative_returns, label=display_name,
linewidth=linewidth, alpha=0.9)
except Exception as e:
print(f"Warning: Failed to load {display_name}: {e}")
ax.set_xlabel('Date', fontsize=12)
ax.set_ylabel('Cumulative Return', fontsize=12)
ax.set_title(f'Strategy Comparison - Cumulative Returns for {ticker}',
fontsize=14, fontweight='bold')
ax.legend(title='Strategies', loc='best', fontsize=10, framealpha=0.9)
ax.grid(True, alpha=0.3, linestyle='--')
ax.axhline(y=1.0, color='black', linestyle='--', linewidth=1, alpha=0.5)
plt.tight_layout()
if output_path:
fig.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"✓ Saved cumulative returns comparison to: {output_path}")
return fig
if __name__ == "__main__":
# Example usage / testing
print("Visualization module loaded successfully!")
print("\nAvailable functions:")
print(" - plot_cumulative_returns")
print(" - plot_cumulative_returns_from_results")
print(" - plot_transaction_history")
print(" - plot_metrics_comparison")
print(" - plot_drawdown")