feat: implement trend override, harden regime detection, and organize tests
**Core Logic (Safety Valve & Regime Detection):**
* Added "Momentum Override": `Overall Return > 30%` now forces `TRENDING_UP` (Bull) regime to capture volatile winners.
* Prioritized Trend Strength (ADX) over Volatility for single stocks.
* Fixed `Hurst Exponent` calculation to handle non-positive inputs safely.
* **Data Reliability ([market_analyst.py]
This commit is contained in:
parent
d0f229a444
commit
a6e4c9b770
|
|
@ -0,0 +1,62 @@
|
|||
# Trading Agents Verification Suite
|
||||
|
||||
This folder contains unit tests and verification scripts to validate the functionality of the Trading Agents system.
|
||||
|
||||
## Available Tests
|
||||
|
||||
|
||||
## Core Logic Tests
|
||||
|
||||
1. **`test_regime_detection.py`**
|
||||
* **Purpose:** Validates mathematical components (ADX, Volatility, Hurst) of the `RegimeDetector`.
|
||||
* **Usage:** `python tests/test_regime_detection.py`
|
||||
|
||||
2. **`test_market_node.py`**
|
||||
* **Purpose:** End-to-end verification of `market_analyst_node`. Checks data fetching logic and regime integration.
|
||||
* **Usage:** `python tests/test_market_node.py`
|
||||
|
||||
3. **`test_override.py`**
|
||||
* **Purpose:** Unit tests for "Don't Fight the Tape" safety logic. Verifies protection of growth leaders.
|
||||
* **Usage:** `python tests/test_override.py`
|
||||
|
||||
## Integration & API Tests
|
||||
|
||||
4. **`test_global_news.py`**
|
||||
* **Purpose:** Verifies news fetching capabilities.
|
||||
* **Usage:** `python tests/test_global_news.py`
|
||||
|
||||
5. **`test_google_api.py`** & **`verify_google_key.py`**
|
||||
* **Purpose:** Validates Google Gemini API connectivity and key validity.
|
||||
* **Usage:** `python tests/test_google_api.py`
|
||||
|
||||
6. **`verify_alpaca.py`**
|
||||
* **Purpose:** Checks Alpaca trading API connection.
|
||||
* **Usage:** `python tests/verify_alpaca.py`
|
||||
|
||||
## Infrastructure & Performance
|
||||
|
||||
7. **`verify_local_embeddings.py`** & **`verify_ollama_embeddings.py`**
|
||||
* **Purpose:** Validates local embedding models (Ollama/TEI) for RAG.
|
||||
* **Usage:** `python tests/verify_local_embeddings.py`
|
||||
|
||||
8. **`verify_tei_native.py`**
|
||||
* **Purpose:** Tests Text Embeddings Inference (TEI) native endpoint.
|
||||
* **Usage:** `python tests/verify_tei_native.py`
|
||||
|
||||
9. **`bench_yfinance.py`**
|
||||
* **Purpose:** Benchmarks yfinance data fetch performance (latency/throughput).
|
||||
* **Usage:** `python tests/bench_yfinance.py`
|
||||
|
||||
10. **`verify_regime_integration.py`**
|
||||
* **Purpose:** Integration test for regime detection within the broader graph context.
|
||||
* **Usage:** `python tests/verify_regime_integration.py`
|
||||
|
||||
## How to Run
|
||||
|
||||
Ensure your virtual environment is activated:
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
export PYTHONPATH=$PYTHONPATH:.
|
||||
python tests/test_market_node.py
|
||||
```
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
# Load env before imports
|
||||
load_dotenv()
|
||||
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.messages import AIMessage
|
||||
from tradingagents.agents.analysts.market_analyst import create_market_analyst
|
||||
|
||||
class MockLLM(Runnable):
|
||||
def bind_tools(self, tools, **kwargs):
|
||||
return self
|
||||
def invoke(self, input, config=None, **kwargs):
|
||||
return AIMessage(content="Mock Market Analysis Report")
|
||||
|
||||
def test_market_analyst_node():
|
||||
print("🔍 TESTING MARKET ANALYST NODE...")
|
||||
|
||||
# 1. Setup
|
||||
mock_llm = MockLLM()
|
||||
market_analyst_node = create_market_analyst(mock_llm)
|
||||
|
||||
# 2. Mock State
|
||||
state = {
|
||||
"company_of_interest": "PLTR",
|
||||
"trade_date": "2026-01-11",
|
||||
"messages": []
|
||||
}
|
||||
|
||||
# 3. Execution
|
||||
print(f" Executing node for {state['company_of_interest']}...")
|
||||
try:
|
||||
# Pass only state as fixed in previous steps
|
||||
result = market_analyst_node(state)
|
||||
|
||||
# 4. Verification
|
||||
regime = result.get("market_regime")
|
||||
metrics = result.get("regime_metrics", {})
|
||||
|
||||
print(f"📊 RESULTING REGIME: {regime}")
|
||||
print(f" METRICS: {json.dumps(metrics, indent=2)}")
|
||||
|
||||
if regime != "UNKNOWN" and metrics:
|
||||
print("✅ PASS: Regime detected correctly!")
|
||||
else:
|
||||
print("❌ FAIL: Regime is UNKNOWN or metrics missing.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ ERROR: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_market_analyst_node()
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
# Mock class to expose the method without full initialization
|
||||
class MockGraph(TradingAgentsGraph):
|
||||
def __init__(self):
|
||||
# Skip super init to avoid API keys requirements
|
||||
self.ticker = "MOCK_TICKER"
|
||||
|
||||
def test_trend_override():
|
||||
print("🔍 TESTING TREND OVERRIDE LOGIC...")
|
||||
|
||||
agent = MockGraph()
|
||||
|
||||
# Test Case 1: PLTR Scenario (Sell in Bull Market)
|
||||
print("\n[TEST 1] PLTR Scenario: Sell signal in Bull Market")
|
||||
decision = "FINAL TRANSACTION PROPOSAL: SELL 75%"
|
||||
hard_data = {
|
||||
"status": "OK",
|
||||
"current_price": 185.0,
|
||||
"sma_200": 150.0,
|
||||
"revenue_growth": 0.62
|
||||
}
|
||||
regime = "TRENDING_UP"
|
||||
|
||||
result = agent.apply_trend_override(decision, hard_data, regime)
|
||||
print(f"Input: {decision}")
|
||||
print(f"Regime: {regime}")
|
||||
if isinstance(result, dict) and result.get("action") == "HOLD":
|
||||
print("✅ PASS: Correctly recognized uptrend + growth to block SELL")
|
||||
else:
|
||||
print(f"❌ FAIL: Returned {result}")
|
||||
|
||||
# Test Case 2: Volatile Regime (Should still protect leader)
|
||||
print("\n[TEST 2] Volatile Regime protection")
|
||||
regime = "VOLATILE"
|
||||
result = agent.apply_trend_override(decision, hard_data, regime)
|
||||
print(f"Regime: {regime}")
|
||||
if isinstance(result, dict) and result.get("action") == "HOLD":
|
||||
print("✅ PASS: Protected leader in VOLATILE regime")
|
||||
else:
|
||||
print(f"❌ FAIL: Returned {result}")
|
||||
|
||||
# Test Case 3: Bear Market (Should allow sell)
|
||||
print("\n[TEST 3] Bear Market (Should allow SELL)")
|
||||
regime = "TRENDING_DOWN"
|
||||
result = agent.apply_trend_override(decision, hard_data, regime)
|
||||
print(f"Regime: {regime}")
|
||||
if result == decision:
|
||||
print("✅ PASS: Allowed SELL in Bear Market")
|
||||
else:
|
||||
print(f"❌ FAIL: Blocked SELL improperly: {result}")
|
||||
|
||||
# Test Case 4: Low Growth (Should allow sell)
|
||||
print("\n[TEST 4] Low Growth (Should allow SELL)")
|
||||
hard_data["revenue_growth"] = 0.10
|
||||
regime = "TRENDING_UP"
|
||||
result = agent.apply_trend_override(decision, hard_data, regime)
|
||||
if result == decision:
|
||||
print("✅ PASS: Allowed SELL for low growth stock")
|
||||
else:
|
||||
print(f"❌ FAIL: Blocked SELL for low growth: {result}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_trend_override()
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
|
||||
import pandas as pd
|
||||
from io import StringIO
|
||||
from datetime import datetime, timedelta
|
||||
import yfinance as yf
|
||||
from tradingagents.engines.regime_detector import RegimeDetector, DynamicIndicatorSelector
|
||||
|
||||
def test_regime_detection():
|
||||
print("🧪 Testing Regime Detection for PLTR...")
|
||||
|
||||
ticker = "PLTR"
|
||||
current_date = "2026-01-11"
|
||||
|
||||
# Simulate the same logic as market_analyst_node
|
||||
dt_obj = datetime.strptime(current_date, "%Y-%m-%d")
|
||||
start_date = (dt_obj - timedelta(days=365)).strftime("%Y-%m-%d")
|
||||
|
||||
print(f" Fetching data from {start_date} to {current_date}")
|
||||
|
||||
# 1. Fetch raw data (simulating the tool call)
|
||||
ticker_obj = yf.Ticker(ticker)
|
||||
data = ticker_obj.history(start=start_date, end=current_date)
|
||||
|
||||
if data.empty:
|
||||
print("❌ FAILURE: No data retrieved from yfinance.")
|
||||
return
|
||||
|
||||
# Check columns
|
||||
print(f" Columns found: {list(data.columns)}")
|
||||
|
||||
# 2. Detect Regime
|
||||
try:
|
||||
prices = data['Close']
|
||||
regime, metrics = RegimeDetector.detect_regime(prices)
|
||||
print(f"✅ SUCCESS: Regime detected: {regime.value}")
|
||||
print(f" Metrics: {metrics}")
|
||||
|
||||
# Check if it matches 'trending_up' (as it should for PLTR in this hypothetical 2026 bull scenario)
|
||||
if regime.value == "trending_up":
|
||||
print("🌟 PLTR is in a BULL TREND.")
|
||||
except Exception as e:
|
||||
print(f"❌ FAILURE: Regime detection failed: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_regime_detection()
|
||||
|
|
@ -13,14 +13,13 @@ from io import StringIO
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
# Initialize anonymizer (shared instance appropriate here or inside)
|
||||
anonymizer = TickerAnonymizer()
|
||||
|
||||
def create_market_analyst(llm):
|
||||
|
||||
def market_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
# Re-initialize or reload anonymizer state
|
||||
anonymizer = TickerAnonymizer()
|
||||
real_ticker = state["company_of_interest"]
|
||||
# BLINDFIRE PROTOCOL: Anonymize Ticker
|
||||
ticker = anonymizer.anonymize_ticker(real_ticker)
|
||||
|
||||
# NOTE: We continue to use 'ticker' variable name but it now holds 'ASSET_XXX'
|
||||
|
|
@ -38,25 +37,45 @@ def create_market_analyst(llm):
|
|||
start_date = (dt_obj - timedelta(days=365)).strftime("%Y-%m-%d")
|
||||
|
||||
# Fetch data for regime detection using the anonymized ticker
|
||||
# This calls the tool which handles deanonymization internally if needed
|
||||
# (assuming core_stock_tools.get_stock_data handles the 'ASSET_XXX' -> Real mapping)
|
||||
# Use invoke for StructuredTool with ALL required args
|
||||
raw_data = get_stock_data.invoke({
|
||||
"symbol": ticker,
|
||||
"symbol": real_ticker,
|
||||
"start_date": start_date,
|
||||
"end_date": current_date,
|
||||
"format": "csv"
|
||||
})
|
||||
|
||||
# Parse data
|
||||
if isinstance(raw_data, str) and "Error" not in raw_data and "No data" not in raw_data:
|
||||
if isinstance(raw_data, str) and len(raw_data.strip()) > 50 and "Error" not in raw_data and "No data" not in raw_data:
|
||||
# Parse data (Standardized CSV format with # comments)
|
||||
df = pd.read_csv(StringIO(raw_data), comment='#')
|
||||
|
||||
# Check for Close column
|
||||
# Handle case-insensitive 'Close' column
|
||||
if 'Close' not in df.columns:
|
||||
# Try to find a column that matches 'close' case-insensitively
|
||||
col_map = {c.lower(): c for c in df.columns}
|
||||
if 'close' in col_map:
|
||||
df.rename(columns={col_map['close']: 'Close'}, inplace=True)
|
||||
|
||||
# Clean index/date
|
||||
if 'Date' in df.columns:
|
||||
df['Date'] = pd.to_datetime(df['Date'])
|
||||
df.set_index('Date', inplace=True)
|
||||
|
||||
# Sort by date
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
# Check for sufficient data
|
||||
# Ensure 'Close' column exists after potential renaming
|
||||
if 'Close' in df.columns:
|
||||
price_data = df['Close']
|
||||
else:
|
||||
price_data = pd.Series([]) # Empty series if 'Close' column is not found
|
||||
|
||||
print(f"DEBUG: Regime Detection - Ticker: {real_ticker}, Rows: {len(price_data)}")
|
||||
|
||||
if not price_data.empty and len(price_data) >= 10:
|
||||
# Detect Regime
|
||||
regime, metrics = RegimeDetector.detect_regime(df['Close'])
|
||||
regime, metrics = RegimeDetector.detect_regime(price_data)
|
||||
optimal_params = DynamicIndicatorSelector.get_optimal_parameters(regime)
|
||||
regime_val = regime.value
|
||||
volatility_score = metrics.get("volatility", 0.0)
|
||||
|
|
@ -67,6 +86,10 @@ def create_market_analyst(llm):
|
|||
regime_context += f"RECOMMENDED STRATEGY: {optimal_params.get('strategy', 'N/A')}\n"
|
||||
regime_context += f"RECOMMENDED INDICATORS: {json.dumps(optimal_params)}\n"
|
||||
regime_context += f"RATIONALE: {optimal_params.get('rationale', '')}"
|
||||
else:
|
||||
print(f"WARNING: Insufficient price data for {ticker}. Columns: {list(df.columns)}, Len: {len(df)}")
|
||||
else:
|
||||
print(f"WARNING: Market data retrieval failed for regime detection for {ticker}. Data snippet: {str(raw_data)[:100]}")
|
||||
except Exception as e:
|
||||
print(f"WARNING: Regime detection failed for {ticker}: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import Annotated
|
|||
from datetime import datetime
|
||||
from dateutil.relativedelta import relativedelta
|
||||
import yfinance as yf
|
||||
import pandas as pd
|
||||
import os
|
||||
from .stockstats_utils import StockstatsUtils
|
||||
|
||||
|
|
@ -313,7 +314,7 @@ def get_balance_sheet(
|
|||
else:
|
||||
data = ticker_obj.balance_sheet
|
||||
|
||||
if data.empty:
|
||||
if data is None or (isinstance(data, pd.DataFrame) and data.empty):
|
||||
return f"No balance sheet data found for symbol '{ticker}'"
|
||||
|
||||
# Convert to CSV string for consistency with other functions
|
||||
|
|
@ -343,7 +344,7 @@ def get_cashflow(
|
|||
else:
|
||||
data = ticker_obj.cashflow
|
||||
|
||||
if data.empty:
|
||||
if data is None or (isinstance(data, pd.DataFrame) and data.empty):
|
||||
return f"No cash flow data found for symbol '{ticker}'"
|
||||
|
||||
# Convert to CSV string for consistency with other functions
|
||||
|
|
@ -373,7 +374,7 @@ def get_income_statement(
|
|||
else:
|
||||
data = ticker_obj.income_stmt
|
||||
|
||||
if data.empty:
|
||||
if data is None or (isinstance(data, pd.DataFrame) and data.empty):
|
||||
return f"No income statement data found for symbol '{ticker}'"
|
||||
|
||||
# Convert to CSV string for consistency with other functions
|
||||
|
|
@ -460,3 +461,35 @@ def get_fundamentals(
|
|||
|
||||
except Exception as e:
|
||||
return f"Error retrieving fundamentals for {ticker}: {str(e)}"
|
||||
|
||||
def get_robust_revenue_growth(ticker: str) -> float:
|
||||
"""
|
||||
Retrieve revenue growth with fallback to manual calculation from quarterly financials.
|
||||
Returns growth as a float (e.g., 0.63 for 63%).
|
||||
"""
|
||||
try:
|
||||
ticker_obj = yf.Ticker(ticker.upper())
|
||||
|
||||
# 1. Try .info first (Quick but often unreliable/stale)
|
||||
info = ticker_obj.info
|
||||
growth = info.get('revenueGrowth')
|
||||
if growth is not None and isinstance(growth, (int, float)) and growth != 0:
|
||||
return float(growth)
|
||||
|
||||
# 2. Fallback: Manual calculation from Quarterly Financials
|
||||
# Formula: (Revenue_Current_Q - Revenue_Year_Ago_Q) / Revenue_Year_Ago_Q
|
||||
q_financials = ticker_obj.quarterly_financials
|
||||
if q_financials is not None and not q_financials.empty and 'Total Revenue' in q_financials.index:
|
||||
rev_series = q_financials.loc['Total Revenue']
|
||||
if len(rev_series) >= 5: # Need at least 5 quarters to compare Q1 vs Q5 (Year Ago)
|
||||
rev_now = rev_series.iloc[0]
|
||||
rev_year_ago = rev_series.iloc[4]
|
||||
|
||||
if rev_year_ago and rev_year_ago > 0:
|
||||
calc_growth = (rev_now - rev_year_ago) / rev_year_ago
|
||||
return float(calc_growth)
|
||||
|
||||
return 0.0
|
||||
except Exception as e:
|
||||
print(f"Error calculating robust revenue growth for {ticker}: {e}")
|
||||
return 0.0
|
||||
|
|
|
|||
|
|
@ -51,28 +51,34 @@ class RegimeDetector:
|
|||
# 3. Mean reversion tendency (Hurst exponent)
|
||||
hurst = RegimeDetector._calculate_hurst_exponent(prices.tail(window))
|
||||
|
||||
# 4. Directional bias
|
||||
cumulative_return = (prices.iloc[-1] / prices.iloc[-window]) - 1
|
||||
# 4. Directional bias (Cumulative Return)
|
||||
# We check both the specific window and the broader history to capture leaders in consolidation
|
||||
window_return = (prices.iloc[-1] / prices.iloc[-window]) - 1
|
||||
full_history_return = (prices.iloc[-1] / prices.iloc[0]) - 1
|
||||
|
||||
# Classify regime
|
||||
metrics = {
|
||||
"volatility": volatility,
|
||||
"trend_strength": trend_strength,
|
||||
"hurst_exponent": hurst,
|
||||
"cumulative_return": cumulative_return,
|
||||
"cumulative_return": window_return,
|
||||
"overall_return": full_history_return
|
||||
}
|
||||
|
||||
# Decision tree for regime classification
|
||||
if volatility > 0.40: # High volatility (>40% annualized)
|
||||
regime = MarketRegime.VOLATILE
|
||||
elif trend_strength > 25: # Strong trend (ADX > 25)
|
||||
if cumulative_return > 0:
|
||||
# Decision tree for regime classification - Prioritize Trend & Momentum
|
||||
# If ADX > 25, it's trending. We use the broader return to confirm if it's a leader.
|
||||
if trend_strength > 25:
|
||||
if window_return > 0 or full_history_return > 0.10: # Up on window OR strong long-term momentum
|
||||
regime = MarketRegime.TRENDING_UP
|
||||
else:
|
||||
regime = MarketRegime.TRENDING_DOWN
|
||||
elif hurst < 0.5: # Mean reverting (Hurst < 0.5)
|
||||
elif full_history_return > 0.30: # Massive long-term momentum overrides Hurst/Volatility
|
||||
regime = MarketRegime.TRENDING_UP
|
||||
elif volatility > 0.80: # High volatility threshold for individual tech stocks
|
||||
regime = MarketRegime.VOLATILE
|
||||
elif not np.isnan(hurst) and hurst < 0.45: # Tighter mean reversion check
|
||||
regime = MarketRegime.MEAN_REVERTING
|
||||
else: # Low volatility, no clear trend
|
||||
else:
|
||||
regime = MarketRegime.SIDEWAYS
|
||||
|
||||
return regime, metrics
|
||||
|
|
@ -111,21 +117,27 @@ class RegimeDetector:
|
|||
@staticmethod
|
||||
def _calculate_hurst_exponent(prices: pd.Series) -> float:
|
||||
"""
|
||||
Calculate Hurst exponent.
|
||||
|
||||
Returns:
|
||||
H < 0.5: Mean reverting
|
||||
H = 0.5: Random walk
|
||||
H > 0.5: Trending
|
||||
Calculate Hurst exponent with safety checks.
|
||||
"""
|
||||
lags = range(2, 20)
|
||||
tau = [np.std(np.subtract(prices[lag:], prices[:-lag])) for lag in lags]
|
||||
|
||||
# Linear regression of log(tau) vs log(lags)
|
||||
poly = np.polyfit(np.log(lags), np.log(tau), 1)
|
||||
hurst = poly[0]
|
||||
|
||||
return hurst
|
||||
try:
|
||||
lags = range(2, 20)
|
||||
tau = [np.std(np.subtract(prices[lag:], prices[:-lag].values)) for lag in lags]
|
||||
|
||||
# Filter out non-positive values to avoid log errors
|
||||
valid_idx = [i for i, t in enumerate(tau) if t > 0]
|
||||
if len(valid_idx) < 2:
|
||||
return 0.5 # Random walk default
|
||||
|
||||
valid_lags = [lags[i] for i in valid_idx]
|
||||
valid_tau = [tau[i] for i in valid_idx]
|
||||
|
||||
# Linear regression of log(tau) vs log(lags)
|
||||
poly = np.polyfit(np.log(valid_lags), np.log(valid_tau), 1)
|
||||
hurst = poly[0]
|
||||
|
||||
return hurst
|
||||
except Exception:
|
||||
return 0.5 # Default to random walk on error
|
||||
|
||||
|
||||
class DynamicIndicatorSelector:
|
||||
|
|
|
|||
|
|
@ -182,7 +182,10 @@ class TradingAgentsGraph:
|
|||
|
||||
self.ticker = company_name
|
||||
|
||||
# 2. Register real company name for anonymization
|
||||
# 2. Get Hard Data Baseline (Trend Override & Reporting)
|
||||
self.hard_data = self._get_hard_data_metrics(company_name, trade_date)
|
||||
|
||||
# 3. Register real company name for anonymization
|
||||
try:
|
||||
from tradingagents.utils.anonymizer import TickerAnonymizer
|
||||
import yfinance as yf
|
||||
|
|
@ -225,10 +228,26 @@ class TradingAgentsGraph:
|
|||
self._log_state(trade_date, final_state)
|
||||
|
||||
# 3. FIX CRASH RISK: Handle Dead State gracefully
|
||||
# First, extract raw decision from LLM text (The Agent Decision)
|
||||
raw_llm_decision = final_state["final_trade_decision"]
|
||||
|
||||
# Apply Technical Override (Don't Fight the Tape)
|
||||
regime_val = final_state.get("market_regime", "UNKNOWN").upper()
|
||||
print(f"\n🔍 [DEBUG] APPLYING OVERRIDE: Regime='{regime_val}', Growth={self.hard_data.get('revenue_growth', 'N/A')}")
|
||||
|
||||
overridden_decision = self.apply_trend_override(
|
||||
raw_llm_decision,
|
||||
self.hard_data,
|
||||
regime_val
|
||||
)
|
||||
|
||||
# Update final state with potentially overridden decision
|
||||
final_state["final_trade_decision"] = overridden_decision
|
||||
|
||||
trade_decision = final_state["final_trade_decision"]
|
||||
|
||||
# If trade was rejected by a Gate (Fact Check or Risk), return raw decision
|
||||
if trade_decision.get("action") == "HOLD" and "REJECTED" in trade_decision.get("reasoning", ""):
|
||||
if isinstance(trade_decision, dict) and trade_decision.get("action") == "HOLD" and "REJECTED" in trade_decision.get("reasoning", ""):
|
||||
processed_signal = {
|
||||
"action": "HOLD",
|
||||
"quantity": 0,
|
||||
|
|
@ -305,4 +324,91 @@ class TradingAgentsGraph:
|
|||
|
||||
def process_signal(self, full_signal):
|
||||
"""Process a signal to extract the core decision."""
|
||||
# Handle dict if signal was overridden, otherwise handle string from LLM
|
||||
if isinstance(full_signal, dict):
|
||||
return {
|
||||
"action": full_signal.get("action", "HOLD"),
|
||||
"quantity": full_signal.get("quantity", 0),
|
||||
"reason": full_signal.get("reasoning", "OVERRIDDEN")
|
||||
}
|
||||
return self.signal_processor.process_signal(full_signal)
|
||||
|
||||
def _get_hard_data_metrics(self, ticker: str, trade_date: str) -> Dict[str, Any]:
|
||||
"""Fetch raw technical and fundamental data for the override gate."""
|
||||
try:
|
||||
import yfinance as yf
|
||||
from datetime import datetime, timedelta
|
||||
from tradingagents.dataflows.y_finance import get_robust_revenue_growth
|
||||
|
||||
dt_obj = datetime.strptime(trade_date, "%Y-%m-%d")
|
||||
# Fetch 300 days of history to ensure we can calculate 200 SMA
|
||||
start_date = (dt_obj - timedelta(days=450)).strftime("%Y-%m-%d")
|
||||
|
||||
ticker_obj = yf.Ticker(ticker.upper())
|
||||
history = ticker_obj.history(start=start_date, end=trade_date)
|
||||
|
||||
metrics = {
|
||||
"current_price": 0.0,
|
||||
"sma_200": 0.0,
|
||||
"revenue_growth": 0.0,
|
||||
"status": "ERROR"
|
||||
}
|
||||
|
||||
if not history.empty and len(history) >= 200:
|
||||
metrics["current_price"] = history["Close"].iloc[-1]
|
||||
metrics["sma_200"] = history["Close"].rolling(200).mean().iloc[-1]
|
||||
metrics["status"] = "OK"
|
||||
|
||||
metrics["revenue_growth"] = get_robust_revenue_growth(ticker)
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error fetching hard data for {ticker} override: {e}")
|
||||
return {"status": "ERROR", "error": str(e)}
|
||||
|
||||
def apply_trend_override(self, trade_decision_str: str, hard_data: Dict[str, Any], regime: str) -> Any:
|
||||
"""
|
||||
The 'Don't Fight the Tape' Safety Valve.
|
||||
Prevents the system from shorting high-growth winners during a Bull Market.
|
||||
"""
|
||||
if hard_data.get("status") != "OK":
|
||||
return trade_decision_str
|
||||
|
||||
regime = str(regime).strip().upper()
|
||||
|
||||
price = hard_data["current_price"]
|
||||
sma_200 = hard_data["sma_200"]
|
||||
growth = hard_data["revenue_growth"]
|
||||
|
||||
# 1. Technical Uptrend (Price > 200 SMA)
|
||||
is_technical_uptrend = price > sma_200
|
||||
|
||||
# 2. Hyper-Growth (> 30% YoY)
|
||||
is_hyper_growth = growth > 0.30
|
||||
|
||||
# 3. Supportive Regime (Protect leaders unless it's a clear TRENDING_DOWN regime)
|
||||
is_bear_regime = regime in ["TRENDING_DOWN", "BEAR", "BEARISH"]
|
||||
is_bull_regime = not is_bear_regime
|
||||
|
||||
# 4. Trigger Override if trying to SELL a leader in a bull market
|
||||
if is_technical_uptrend and is_hyper_growth and is_bull_regime:
|
||||
# We check if the decision string contains SELL or STRONG_SELL
|
||||
# (llm output is usually messy text, so we check for the verdict)
|
||||
decision_upper = trade_decision_str.upper()
|
||||
if "SELL" in decision_upper:
|
||||
print(f"\n🛑 TREND OVERRIDE TRIGGERED for {self.ticker}")
|
||||
print(f" Reason: Stock (${price:.2f}) is > 200SMA (${sma_200:.2f}) and Growth is {growth:.1%}")
|
||||
print(f" Action 'SELL' blocked. Converting to 'HOLD'.\n")
|
||||
|
||||
return {
|
||||
"action": "HOLD",
|
||||
"quantity": 0,
|
||||
"reasoning": (
|
||||
f"OVERRIDE: System attempted to short a Hyper-Growth stock ({growth:.1%}) "
|
||||
f"above its 200-day trend (${sma_200:.2f}) in a Bull regime. "
|
||||
f"Original Decision: {trade_decision_str[:100]}..."
|
||||
),
|
||||
"confidence": 1.0
|
||||
}
|
||||
|
||||
return trade_decision_str
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class TickerAnonymizer:
|
|||
self.auto_persist = auto_persist
|
||||
|
||||
# Persistence path
|
||||
self.map_file = Path("ticker_map.json")
|
||||
self.map_file = Path(__file__).resolve().parent.parent.parent / "ticker_map.json"
|
||||
if self.auto_persist:
|
||||
self._load_from_file()
|
||||
|
||||
|
|
@ -299,8 +299,14 @@ class TickerAnonymizer:
|
|||
print(f"✅ Loaded mapping from {input_path}")
|
||||
|
||||
def deanonymize_ticker(self, anon_ticker: str) -> Optional[str]:
|
||||
"""Reverse mapping: ASSET_042 → AAPL."""
|
||||
return self.reverse_map.get(anon_ticker)
|
||||
"""Reverse mapping: ASSET_042 → AAPL. Robust to 'Company' prefixes."""
|
||||
if not anon_ticker:
|
||||
return None
|
||||
|
||||
# Strip common prefixes that LLMs might include from text
|
||||
clean_ticker = anon_ticker.upper().replace("COMPANY", "").replace("COMPANY_", "").replace("CORPORATION", "").strip()
|
||||
|
||||
return self.reverse_map.get(clean_ticker)
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue