fixed long-short evaluation
This commit is contained in:
parent
e3952edf91
commit
55e0287ace
|
|
@ -30,72 +30,81 @@ class TradingAgentsBacktester:
|
|||
|
||||
def backtest(self, ticker: str, start_date: str, end_date: str, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Backtest TradingAgents using the same return calculation logic as rule-based strategies.
|
||||
|
||||
Process:
|
||||
1. Collect signals (actions: 1=BUY, 0=HOLD, -1=SELL) for all dates
|
||||
2. Convert actions to positions (0=flat, 1=long) using same logic as baselines
|
||||
3. Calculate returns as: strategy_return = position.shift(1) * market_return
|
||||
Backtest TradingAgents with realistic single-asset account simulation.
|
||||
Supports long, short, and flat positions with 1× leverage on shorts.
|
||||
"""
|
||||
# Restrict to window
|
||||
df = data.loc[start_date:end_date].copy()
|
||||
|
||||
decisions: List[Dict] = []
|
||||
decisions = []
|
||||
signals = pd.Series(0, index=df.index, dtype=float)
|
||||
|
||||
print(f"\nRunning TradingAgents backtest on {ticker} from {start_date} to {end_date}")
|
||||
print(f"Total trading days: {len(df)}")
|
||||
print("-" * 80)
|
||||
|
||||
# Step 1: Collect all signals/decisions
|
||||
# === Step 1: Collect signals ===
|
||||
for i, (date, row) in enumerate(df.iterrows()):
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
price = float(row["Close"])
|
||||
|
||||
# Get decision from TradingAgents graph
|
||||
try:
|
||||
print(f"\n[{i+1}/{len(df)}] {date_str} ... ", end="")
|
||||
final_state, decision = self.graph.propagate(ticker, date_str)
|
||||
print(f"Decision: {decision}")
|
||||
signal = self._parse_decision(decision)
|
||||
decisions.append({"date": date_str, "decision": decision, "signal": signal, "price": price})
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
signal = 0
|
||||
decisions.append({"date": date_str, "decision": "ERROR", "signal": 0, "price": price, "error": str(e)})
|
||||
|
||||
signals.loc[date] = signal
|
||||
|
||||
# Step 2: Convert actions to positions (same logic as baseline strategies)
|
||||
position = self._actions_to_position(signals)
|
||||
|
||||
# Step 3: Calculate returns using standardized logic
|
||||
# === Step 2: Run realistic cash+shares backtest ===
|
||||
close = pd.to_numeric(df["Close"], errors="coerce")
|
||||
market_ret = close.pct_change().fillna(0.0)
|
||||
exposure = position.shift(1).fillna(0.0) # Yesterday's position determines today's exposure
|
||||
strat_ret = (exposure * market_ret).astype(float)
|
||||
|
||||
cumret = (1.0 + strat_ret).cumprod()
|
||||
portval = self.initial_capital * cumret
|
||||
|
||||
# Build portfolio DataFrame with same structure as baseline strategies
|
||||
portfolio = pd.DataFrame(index=df.index)
|
||||
portfolio["action"] = signals # 1=BUY, 0=HOLD, -1=SELL
|
||||
portfolio["position"] = position # 1=long, 0=flat
|
||||
portfolio["close"] = close
|
||||
if "Volume" in df.columns:
|
||||
vol = df["Volume"]
|
||||
if isinstance(vol, pd.DataFrame) and vol.shape[1] == 1:
|
||||
vol = vol.iloc[:, 0]
|
||||
if isinstance(vol, pd.Series):
|
||||
portfolio["Volume"] = vol
|
||||
portfolio["market_return"] = market_ret
|
||||
portfolio["strategy_return"] = strat_ret
|
||||
portfolio["cumulative_return"] = cumret
|
||||
portfolio["portfolio_value"] = portval
|
||||
portfolio["trade_delta"] = portfolio["position"].diff().fillna(0.0) # +1=buy, -1=sell
|
||||
cash = self.initial_capital
|
||||
shares = 0.0
|
||||
prev_value = cash
|
||||
records = []
|
||||
|
||||
for i, (date, price) in enumerate(close.items()):
|
||||
action = signals.iloc[i]
|
||||
|
||||
# 先计算上一个交易日的组合价值
|
||||
portfolio_value = cash + shares * price
|
||||
|
||||
# === 若方向改变,先平仓 ===
|
||||
if (shares > 0 and action <= 0) or (shares < 0 and action >= 0):
|
||||
cash += shares * price # 卖出现有股票或回补空头
|
||||
shares = 0.0
|
||||
|
||||
# === 建仓逻辑 ===
|
||||
if action == 1 and shares == 0:
|
||||
# 做多
|
||||
shares = cash / price
|
||||
cash = 0.0
|
||||
elif action == -1 and shares == 0:
|
||||
# 做空(1倍杠杆)
|
||||
shares = -cash / price
|
||||
cash = 2 * cash # 保证金 + 卖出所得
|
||||
|
||||
# === 更新组合价值 ===
|
||||
new_value = cash + shares * price
|
||||
daily_return = (new_value / prev_value) - 1 if prev_value != 0 else 0.0
|
||||
prev_value = new_value
|
||||
|
||||
records.append({
|
||||
"date": date.strftime("%Y-%m-%d"),
|
||||
"action": action,
|
||||
"shares": shares,
|
||||
"close_price": price,
|
||||
"cash": cash,
|
||||
"portfolio_value": new_value,
|
||||
"strategy_return": daily_return,
|
||||
})
|
||||
|
||||
# === Step 3: 转为 DataFrame 并计算累计收益 ===
|
||||
portfolio = pd.DataFrame(records).set_index("date")
|
||||
portfolio["cumulative_return"] = (1 + portfolio["strategy_return"]).cumprod()
|
||||
portfolio["ticker"] = ticker
|
||||
self.latest_portfolio = portfolio
|
||||
self._save_decisions_log(ticker, decisions, start_date, end_date)
|
||||
return portfolio
|
||||
|
||||
|
|
@ -135,13 +144,37 @@ class TradingAgentsBacktester:
|
|||
return 0
|
||||
|
||||
def _save_decisions_log(self, ticker: str, decisions: List[Dict], start_date: str, end_date: str):
|
||||
# Use output_dir if provided, otherwise use default
|
||||
"""
|
||||
Save detailed TradingAgents decisions and portfolio state to JSON.
|
||||
Adds shares, cash, and cumulative return (cr) from the latest backtest results.
|
||||
"""
|
||||
if self.output_dir:
|
||||
out = Path(self.output_dir) / ticker / "TradingAgents"
|
||||
else:
|
||||
out = Path(f"eval_results/{ticker}/TradingAgents")
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
fp = out / f"decisions_{start_date}_to_{end_date}.json"
|
||||
|
||||
# Try to include computed portfolio metrics if available
|
||||
try:
|
||||
# Attempt to load the latest portfolio CSV/DF from memory
|
||||
if hasattr(self, "latest_portfolio"):
|
||||
port = self.latest_portfolio
|
||||
port = port.reset_index()
|
||||
port_dict = {d["date"]: d for d in port.to_dict(orient="records")}
|
||||
# Merge portfolio stats into each decision record
|
||||
for d in decisions:
|
||||
date = d["date"]
|
||||
if date in port_dict:
|
||||
d.update({
|
||||
"shares": port_dict[date].get("shares"),
|
||||
"cash": port_dict[date].get("cash"),
|
||||
"portfolio_value": port_dict[date].get("portfolio_value"),
|
||||
"cumulative_return": port_dict[date].get("cumulative_return"),
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"Warning: could not merge portfolio stats into log ({e})")
|
||||
|
||||
with open(fp, "w") as f:
|
||||
json.dump({
|
||||
"strategy": "TradingAgents",
|
||||
|
|
@ -151,6 +184,7 @@ class TradingAgentsBacktester:
|
|||
"total_days": len(decisions),
|
||||
"decisions": decisions
|
||||
}, f, indent=2)
|
||||
|
||||
print(f" ✓ Saved TradingAgents detailed decisions to: {fp}")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -151,7 +151,8 @@ def run_evaluation(
|
|||
print(f" Debate Rounds: {cfg['max_debate_rounds']}")
|
||||
|
||||
graph = TradingAgentsGraph(
|
||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||
selected_analysts=["news"],
|
||||
# selected_analysts=["market", "social", "news", "fundamentals"],
|
||||
debug=False,
|
||||
config=cfg
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue