diff --git a/evaluation_long_short/visualize.py b/evaluation_long_short/visualize.py index 416fa89a..a67d2d8d 100644 --- a/evaluation_long_short/visualize.py +++ b/evaluation_long_short/visualize.py @@ -416,7 +416,8 @@ def plot_cumulative_returns_from_results( 'KDJ&RSI': 'KDJRSIStrategy', 'ZMR': 'ZMRStrategy', 'SMA': 'SMAStrategy', - 'TradingAgents': 'TradingAgents' + 'TradingAgents': 'TradingAgents', + 'TradingAgents_DAPT': 'TradingAgents (DAPT+SFT)' } fig, ax = plt.subplots(figsize=figsize) @@ -441,10 +442,16 @@ def plot_cumulative_returns_from_results( dates = pd.to_datetime([action['date'] for action in data['actions']]) cumulative_returns = [action['cumulative_return'] for action in data['actions']] - # Plot - linewidth = 2.5 if display_name == 'TradingAgents' else 1.5 + # Plot with enhanced styling for TradingAgents variants + if 'TradingAgents' in display_name: + linewidth = 2.5 + alpha = 0.95 + else: + linewidth = 1.5 + alpha = 0.8 + ax.plot(dates, cumulative_returns, label=display_name, - linewidth=linewidth, alpha=0.9) + linewidth=linewidth, alpha=alpha) except Exception as e: print(f"Warning: Failed to load {display_name}: {e}")