Visualization added
This commit is contained in:
parent
ffff3050c8
commit
e99138f5b9
|
|
@ -16,6 +16,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
|
|||
from evaluation.baseline_strategies import get_all_baseline_strategies
|
||||
from evaluation.backtest import BacktestEngine, TradingAgentsBacktester, load_stock_data, standardize_single_ticker
|
||||
from evaluation.metrics import calculate_all_metrics, create_comparison_table, print_metrics
|
||||
from evaluation.visualize import plot_cumulative_returns_from_results
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
|
@ -178,6 +179,32 @@ def run_evaluation(
|
|||
all_metrics[name] = metrics
|
||||
print_metrics(metrics, name)
|
||||
|
||||
# Generate cumulative returns comparison plot
|
||||
print("\n" + "="*80)
|
||||
print("STEP 5: Generating Comparison Plot")
|
||||
print("="*80)
|
||||
try:
|
||||
comparison_plot_path = str(out / ticker / "strategy_comparison.png")
|
||||
plot_cumulative_returns_from_results(
|
||||
results_dir=str(out / ticker),
|
||||
ticker=ticker,
|
||||
output_path=comparison_plot_path
|
||||
)
|
||||
# Also save as PDF
|
||||
pdf_path = comparison_plot_path.replace('.png', '.pdf')
|
||||
plot_cumulative_returns_from_results(
|
||||
results_dir=str(out / ticker),
|
||||
ticker=ticker,
|
||||
output_path=pdf_path
|
||||
)
|
||||
print(f"\n✓ Comparison plot saved to:")
|
||||
print(f" - {comparison_plot_path}")
|
||||
print(f" - {pdf_path}")
|
||||
except Exception as e:
|
||||
print(f"\n✗ Failed to generate comparison plot: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("EVALUATION COMPLETE")
|
||||
print("="*80)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import matplotlib.pyplot as plt
|
|||
from pathlib import Path
|
||||
from typing import Dict
|
||||
import warnings
|
||||
import json
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
|
@ -388,11 +389,89 @@ def create_summary_report(
|
|||
print(f"\n✓ All visualizations saved to: {output_dir}")
|
||||
|
||||
|
||||
def plot_cumulative_returns_from_results(
|
||||
results_dir: str,
|
||||
ticker: str,
|
||||
output_path: str = None,
|
||||
figsize: tuple = (12, 7)
|
||||
) -> plt.Figure:
|
||||
"""
|
||||
Plot cumulative returns comparison from saved JSON results.
|
||||
|
||||
Args:
|
||||
results_dir: Directory containing strategy result folders
|
||||
ticker: Stock ticker symbol
|
||||
output_path: Path to save the figure (optional)
|
||||
figsize: Figure size (width, height)
|
||||
|
||||
Returns:
|
||||
matplotlib Figure object
|
||||
"""
|
||||
results_path = Path(results_dir)
|
||||
|
||||
# Define strategies to load
|
||||
strategies = {
|
||||
'BuyAndHold': 'BuyAndHoldStrategy',
|
||||
'MACD': 'MACDStrategy',
|
||||
'KDJ&RSI': 'KDJRSIStrategy',
|
||||
'ZMR': 'ZMRStrategy',
|
||||
'SMA': 'SMAStrategy',
|
||||
'TradingAgents': 'TradingAgents'
|
||||
}
|
||||
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
|
||||
# Load and plot each strategy
|
||||
for folder_name, display_name in strategies.items():
|
||||
strategy_dir = results_path / folder_name
|
||||
if not strategy_dir.exists():
|
||||
continue
|
||||
|
||||
# Find actions JSON file
|
||||
action_files = list(strategy_dir.glob("actions_*.json"))
|
||||
if not action_files:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Load data
|
||||
with open(action_files[0], 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Extract date and cumulative_return
|
||||
dates = pd.to_datetime([action['date'] for action in data['actions']])
|
||||
cumulative_returns = [action['cumulative_return'] for action in data['actions']]
|
||||
|
||||
# Plot
|
||||
linewidth = 2.5 if display_name == 'TradingAgents' else 1.5
|
||||
ax.plot(dates, cumulative_returns, label=display_name,
|
||||
linewidth=linewidth, alpha=0.9)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load {display_name}: {e}")
|
||||
|
||||
ax.set_xlabel('Date', fontsize=12)
|
||||
ax.set_ylabel('Cumulative Return', fontsize=12)
|
||||
ax.set_title(f'Strategy Comparison - Cumulative Returns for {ticker}',
|
||||
fontsize=14, fontweight='bold')
|
||||
ax.legend(title='Strategies', loc='best', fontsize=10, framealpha=0.9)
|
||||
ax.grid(True, alpha=0.3, linestyle='--')
|
||||
ax.axhline(y=1.0, color='black', linestyle='--', linewidth=1, alpha=0.5)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if output_path:
|
||||
fig.savefig(output_path, dpi=300, bbox_inches='tight')
|
||||
print(f"✓ Saved cumulative returns comparison to: {output_path}")
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage / testing
|
||||
print("Visualization module loaded successfully!")
|
||||
print("\nAvailable functions:")
|
||||
print(" - plot_cumulative_returns")
|
||||
print(" - plot_cumulative_returns_from_results")
|
||||
print(" - plot_transaction_history")
|
||||
print(" - plot_metrics_comparison")
|
||||
print(" - plot_drawdown")
|
||||
|
|
|
|||
Loading…
Reference in New Issue