Visualization added
This commit is contained in:
parent
ffff3050c8
commit
e99138f5b9
|
|
@ -16,6 +16,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
from evaluation.baseline_strategies import get_all_baseline_strategies
|
from evaluation.baseline_strategies import get_all_baseline_strategies
|
||||||
from evaluation.backtest import BacktestEngine, TradingAgentsBacktester, load_stock_data, standardize_single_ticker
|
from evaluation.backtest import BacktestEngine, TradingAgentsBacktester, load_stock_data, standardize_single_ticker
|
||||||
from evaluation.metrics import calculate_all_metrics, create_comparison_table, print_metrics
|
from evaluation.metrics import calculate_all_metrics, create_comparison_table, print_metrics
|
||||||
|
from evaluation.visualize import plot_cumulative_returns_from_results
|
||||||
|
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
|
@ -178,6 +179,32 @@ def run_evaluation(
|
||||||
all_metrics[name] = metrics
|
all_metrics[name] = metrics
|
||||||
print_metrics(metrics, name)
|
print_metrics(metrics, name)
|
||||||
|
|
||||||
|
# Generate cumulative returns comparison plot
|
||||||
|
print("\n" + "="*80)
|
||||||
|
print("STEP 5: Generating Comparison Plot")
|
||||||
|
print("="*80)
|
||||||
|
try:
|
||||||
|
comparison_plot_path = str(out / ticker / "strategy_comparison.png")
|
||||||
|
plot_cumulative_returns_from_results(
|
||||||
|
results_dir=str(out / ticker),
|
||||||
|
ticker=ticker,
|
||||||
|
output_path=comparison_plot_path
|
||||||
|
)
|
||||||
|
# Also save as PDF
|
||||||
|
pdf_path = comparison_plot_path.replace('.png', '.pdf')
|
||||||
|
plot_cumulative_returns_from_results(
|
||||||
|
results_dir=str(out / ticker),
|
||||||
|
ticker=ticker,
|
||||||
|
output_path=pdf_path
|
||||||
|
)
|
||||||
|
print(f"\n✓ Comparison plot saved to:")
|
||||||
|
print(f" - {comparison_plot_path}")
|
||||||
|
print(f" - {pdf_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n✗ Failed to generate comparison plot: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
print("\n" + "="*80)
|
print("\n" + "="*80)
|
||||||
print("EVALUATION COMPLETE")
|
print("EVALUATION COMPLETE")
|
||||||
print("="*80)
|
print("="*80)
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import matplotlib.pyplot as plt
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
import warnings
|
import warnings
|
||||||
|
import json
|
||||||
|
|
||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
|
@ -388,11 +389,89 @@ def create_summary_report(
|
||||||
print(f"\n✓ All visualizations saved to: {output_dir}")
|
print(f"\n✓ All visualizations saved to: {output_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
def plot_cumulative_returns_from_results(
|
||||||
|
results_dir: str,
|
||||||
|
ticker: str,
|
||||||
|
output_path: str = None,
|
||||||
|
figsize: tuple = (12, 7)
|
||||||
|
) -> plt.Figure:
|
||||||
|
"""
|
||||||
|
Plot cumulative returns comparison from saved JSON results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results_dir: Directory containing strategy result folders
|
||||||
|
ticker: Stock ticker symbol
|
||||||
|
output_path: Path to save the figure (optional)
|
||||||
|
figsize: Figure size (width, height)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
matplotlib Figure object
|
||||||
|
"""
|
||||||
|
results_path = Path(results_dir)
|
||||||
|
|
||||||
|
# Define strategies to load
|
||||||
|
strategies = {
|
||||||
|
'BuyAndHold': 'BuyAndHoldStrategy',
|
||||||
|
'MACD': 'MACDStrategy',
|
||||||
|
'KDJ&RSI': 'KDJRSIStrategy',
|
||||||
|
'ZMR': 'ZMRStrategy',
|
||||||
|
'SMA': 'SMAStrategy',
|
||||||
|
'TradingAgents': 'TradingAgents'
|
||||||
|
}
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=figsize)
|
||||||
|
|
||||||
|
# Load and plot each strategy
|
||||||
|
for folder_name, display_name in strategies.items():
|
||||||
|
strategy_dir = results_path / folder_name
|
||||||
|
if not strategy_dir.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Find actions JSON file
|
||||||
|
action_files = list(strategy_dir.glob("actions_*.json"))
|
||||||
|
if not action_files:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load data
|
||||||
|
with open(action_files[0], 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
# Extract date and cumulative_return
|
||||||
|
dates = pd.to_datetime([action['date'] for action in data['actions']])
|
||||||
|
cumulative_returns = [action['cumulative_return'] for action in data['actions']]
|
||||||
|
|
||||||
|
# Plot
|
||||||
|
linewidth = 2.5 if display_name == 'TradingAgents' else 1.5
|
||||||
|
ax.plot(dates, cumulative_returns, label=display_name,
|
||||||
|
linewidth=linewidth, alpha=0.9)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Failed to load {display_name}: {e}")
|
||||||
|
|
||||||
|
ax.set_xlabel('Date', fontsize=12)
|
||||||
|
ax.set_ylabel('Cumulative Return', fontsize=12)
|
||||||
|
ax.set_title(f'Strategy Comparison - Cumulative Returns for {ticker}',
|
||||||
|
fontsize=14, fontweight='bold')
|
||||||
|
ax.legend(title='Strategies', loc='best', fontsize=10, framealpha=0.9)
|
||||||
|
ax.grid(True, alpha=0.3, linestyle='--')
|
||||||
|
ax.axhline(y=1.0, color='black', linestyle='--', linewidth=1, alpha=0.5)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
if output_path:
|
||||||
|
fig.savefig(output_path, dpi=300, bbox_inches='tight')
|
||||||
|
print(f"✓ Saved cumulative returns comparison to: {output_path}")
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Example usage / testing
|
# Example usage / testing
|
||||||
print("Visualization module loaded successfully!")
|
print("Visualization module loaded successfully!")
|
||||||
print("\nAvailable functions:")
|
print("\nAvailable functions:")
|
||||||
print(" - plot_cumulative_returns")
|
print(" - plot_cumulative_returns")
|
||||||
|
print(" - plot_cumulative_returns_from_results")
|
||||||
print(" - plot_transaction_history")
|
print(" - plot_transaction_history")
|
||||||
print(" - plot_metrics_comparison")
|
print(" - plot_metrics_comparison")
|
||||||
print(" - plot_drawdown")
|
print(" - plot_drawdown")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue