feat: implement trend override, harden regime detection, and organize tests

**Core Logic (Safety Valve & Regime Detection):**

  *   Added "Momentum Override": `Overall Return > 30%` now forces `TRENDING_UP` (Bull) regime to capture volatile winners.
    *   Prioritized Trend Strength (ADX) over Volatility for single stocks.
    *   Fixed `Hurst Exponent` calculation to handle non-positive inputs safely.
*   **Data Reliability ([market_analyst.py]
This commit is contained in:
swj.premkumar 2026-01-11 11:18:46 -06:00
parent d0f229a444
commit a6e4c9b770
18 changed files with 462 additions and 42 deletions

62
tests/README.md Normal file
View File

@ -0,0 +1,62 @@
# Trading Agents Verification Suite
This folder contains unit tests and verification scripts to validate the functionality of the Trading Agents system.
## Available Tests
## Core Logic Tests
1. **`test_regime_detection.py`**
* **Purpose:** Validates mathematical components (ADX, Volatility, Hurst) of the `RegimeDetector`.
* **Usage:** `python tests/test_regime_detection.py`
2. **`test_market_node.py`**
* **Purpose:** End-to-end verification of `market_analyst_node`. Checks data fetching logic and regime integration.
* **Usage:** `python tests/test_market_node.py`
3. **`test_override.py`**
* **Purpose:** Unit tests for "Don't Fight the Tape" safety logic. Verifies protection of growth leaders.
* **Usage:** `python tests/test_override.py`
## Integration & API Tests
4. **`test_global_news.py`**
* **Purpose:** Verifies news fetching capabilities.
* **Usage:** `python tests/test_global_news.py`
5. **`test_google_api.py`** & **`verify_google_key.py`**
* **Purpose:** Validates Google Gemini API connectivity and key validity.
* **Usage:** `python tests/test_google_api.py`
6. **`verify_alpaca.py`**
* **Purpose:** Checks Alpaca trading API connection.
* **Usage:** `python tests/verify_alpaca.py`
## Infrastructure & Performance
7. **`verify_local_embeddings.py`** & **`verify_ollama_embeddings.py`**
* **Purpose:** Validates local embedding models (Ollama/TEI) for RAG.
* **Usage:** `python tests/verify_local_embeddings.py`
8. **`verify_tei_native.py`**
* **Purpose:** Tests Text Embeddings Inference (TEI) native endpoint.
* **Usage:** `python tests/verify_tei_native.py`
9. **`bench_yfinance.py`**
* **Purpose:** Benchmarks yfinance data fetch performance (latency/throughput).
* **Usage:** `python tests/bench_yfinance.py`
10. **`verify_regime_integration.py`**
* **Purpose:** Integration test for regime detection within the broader graph context.
* **Usage:** `python tests/verify_regime_integration.py`
## How to Run
Ensure your virtual environment is activated:
```bash
source .venv/bin/activate
export PYTHONPATH=$PYTHONPATH:.
python tests/test_market_node.py
```

61
tests/test_market_node.py Normal file
View File

@ -0,0 +1,61 @@
import os
import sys
import json
from pathlib import Path
from dotenv import load_dotenv
sys.path.append(str(Path(__file__).parent.parent))
# Load env before imports
load_dotenv()
from langchain_core.runnables import Runnable
from langchain_core.messages import AIMessage
from tradingagents.agents.analysts.market_analyst import create_market_analyst
class MockLLM(Runnable):
def bind_tools(self, tools, **kwargs):
return self
def invoke(self, input, config=None, **kwargs):
return AIMessage(content="Mock Market Analysis Report")
def test_market_analyst_node():
print("🔍 TESTING MARKET ANALYST NODE...")
# 1. Setup
mock_llm = MockLLM()
market_analyst_node = create_market_analyst(mock_llm)
# 2. Mock State
state = {
"company_of_interest": "PLTR",
"trade_date": "2026-01-11",
"messages": []
}
# 3. Execution
print(f" Executing node for {state['company_of_interest']}...")
try:
# Pass only state as fixed in previous steps
result = market_analyst_node(state)
# 4. Verification
regime = result.get("market_regime")
metrics = result.get("regime_metrics", {})
print(f"📊 RESULTING REGIME: {regime}")
print(f" METRICS: {json.dumps(metrics, indent=2)}")
if regime != "UNKNOWN" and metrics:
print("✅ PASS: Regime detected correctly!")
else:
print("❌ FAIL: Regime is UNKNOWN or metrics missing.")
except Exception as e:
print(f"❌ ERROR: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
test_market_analyst_node()

69
tests/test_override.py Normal file
View File

@ -0,0 +1,69 @@
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
from tradingagents.graph.trading_graph import TradingAgentsGraph
# Mock class to expose the method without full initialization
class MockGraph(TradingAgentsGraph):
def __init__(self):
# Skip super init to avoid API keys requirements
self.ticker = "MOCK_TICKER"
def test_trend_override():
print("🔍 TESTING TREND OVERRIDE LOGIC...")
agent = MockGraph()
# Test Case 1: PLTR Scenario (Sell in Bull Market)
print("\n[TEST 1] PLTR Scenario: Sell signal in Bull Market")
decision = "FINAL TRANSACTION PROPOSAL: SELL 75%"
hard_data = {
"status": "OK",
"current_price": 185.0,
"sma_200": 150.0,
"revenue_growth": 0.62
}
regime = "TRENDING_UP"
result = agent.apply_trend_override(decision, hard_data, regime)
print(f"Input: {decision}")
print(f"Regime: {regime}")
if isinstance(result, dict) and result.get("action") == "HOLD":
print("✅ PASS: Correctly recognized uptrend + growth to block SELL")
else:
print(f"❌ FAIL: Returned {result}")
# Test Case 2: Volatile Regime (Should still protect leader)
print("\n[TEST 2] Volatile Regime protection")
regime = "VOLATILE"
result = agent.apply_trend_override(decision, hard_data, regime)
print(f"Regime: {regime}")
if isinstance(result, dict) and result.get("action") == "HOLD":
print("✅ PASS: Protected leader in VOLATILE regime")
else:
print(f"❌ FAIL: Returned {result}")
# Test Case 3: Bear Market (Should allow sell)
print("\n[TEST 3] Bear Market (Should allow SELL)")
regime = "TRENDING_DOWN"
result = agent.apply_trend_override(decision, hard_data, regime)
print(f"Regime: {regime}")
if result == decision:
print("✅ PASS: Allowed SELL in Bear Market")
else:
print(f"❌ FAIL: Blocked SELL improperly: {result}")
# Test Case 4: Low Growth (Should allow sell)
print("\n[TEST 4] Low Growth (Should allow SELL)")
hard_data["revenue_growth"] = 0.10
regime = "TRENDING_UP"
result = agent.apply_trend_override(decision, hard_data, regime)
if result == decision:
print("✅ PASS: Allowed SELL for low growth stock")
else:
print(f"❌ FAIL: Blocked SELL for low growth: {result}")
if __name__ == "__main__":
test_trend_override()

View File

@ -0,0 +1,48 @@
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))
import pandas as pd
from io import StringIO
from datetime import datetime, timedelta
import yfinance as yf
from tradingagents.engines.regime_detector import RegimeDetector, DynamicIndicatorSelector
def test_regime_detection():
print("🧪 Testing Regime Detection for PLTR...")
ticker = "PLTR"
current_date = "2026-01-11"
# Simulate the same logic as market_analyst_node
dt_obj = datetime.strptime(current_date, "%Y-%m-%d")
start_date = (dt_obj - timedelta(days=365)).strftime("%Y-%m-%d")
print(f" Fetching data from {start_date} to {current_date}")
# 1. Fetch raw data (simulating the tool call)
ticker_obj = yf.Ticker(ticker)
data = ticker_obj.history(start=start_date, end=current_date)
if data.empty:
print("❌ FAILURE: No data retrieved from yfinance.")
return
# Check columns
print(f" Columns found: {list(data.columns)}")
# 2. Detect Regime
try:
prices = data['Close']
regime, metrics = RegimeDetector.detect_regime(prices)
print(f"✅ SUCCESS: Regime detected: {regime.value}")
print(f" Metrics: {metrics}")
# Check if it matches 'trending_up' (as it should for PLTR in this hypothetical 2026 bull scenario)
if regime.value == "trending_up":
print("🌟 PLTR is in a BULL TREND.")
except Exception as e:
print(f"❌ FAILURE: Regime detection failed: {e}")
if __name__ == "__main__":
test_regime_detection()

View File

@ -13,14 +13,13 @@ from io import StringIO
from datetime import datetime, timedelta
# Initialize anonymizer (shared instance appropriate here or inside)
anonymizer = TickerAnonymizer()
def create_market_analyst(llm):
def market_analyst_node(state):
current_date = state["trade_date"]
# Re-initialize or reload anonymizer state
anonymizer = TickerAnonymizer()
real_ticker = state["company_of_interest"]
# BLINDFIRE PROTOCOL: Anonymize Ticker
ticker = anonymizer.anonymize_ticker(real_ticker)
# NOTE: We continue to use 'ticker' variable name but it now holds 'ASSET_XXX'
@ -38,25 +37,45 @@ def create_market_analyst(llm):
start_date = (dt_obj - timedelta(days=365)).strftime("%Y-%m-%d")
# Fetch data for regime detection using the anonymized ticker
# This calls the tool which handles deanonymization internally if needed
# (assuming core_stock_tools.get_stock_data handles the 'ASSET_XXX' -> Real mapping)
# Use invoke for StructuredTool with ALL required args
raw_data = get_stock_data.invoke({
"symbol": ticker,
"symbol": real_ticker,
"start_date": start_date,
"end_date": current_date,
"format": "csv"
})
# Parse data
if isinstance(raw_data, str) and "Error" not in raw_data and "No data" not in raw_data:
if isinstance(raw_data, str) and len(raw_data.strip()) > 50 and "Error" not in raw_data and "No data" not in raw_data:
# Parse data (Standardized CSV format with # comments)
df = pd.read_csv(StringIO(raw_data), comment='#')
# Check for Close column
# Handle case-insensitive 'Close' column
if 'Close' not in df.columns:
# Try to find a column that matches 'close' case-insensitively
col_map = {c.lower(): c for c in df.columns}
if 'close' in col_map:
df.rename(columns={col_map['close']: 'Close'}, inplace=True)
# Clean index/date
if 'Date' in df.columns:
df['Date'] = pd.to_datetime(df['Date'])
df.set_index('Date', inplace=True)
# Sort by date
df.sort_index(inplace=True)
# Check for sufficient data
# Ensure 'Close' column exists after potential renaming
if 'Close' in df.columns:
price_data = df['Close']
else:
price_data = pd.Series([]) # Empty series if 'Close' column is not found
print(f"DEBUG: Regime Detection - Ticker: {real_ticker}, Rows: {len(price_data)}")
if not price_data.empty and len(price_data) >= 10:
# Detect Regime
regime, metrics = RegimeDetector.detect_regime(df['Close'])
regime, metrics = RegimeDetector.detect_regime(price_data)
optimal_params = DynamicIndicatorSelector.get_optimal_parameters(regime)
regime_val = regime.value
volatility_score = metrics.get("volatility", 0.0)
@ -67,6 +86,10 @@ def create_market_analyst(llm):
regime_context += f"RECOMMENDED STRATEGY: {optimal_params.get('strategy', 'N/A')}\n"
regime_context += f"RECOMMENDED INDICATORS: {json.dumps(optimal_params)}\n"
regime_context += f"RATIONALE: {optimal_params.get('rationale', '')}"
else:
print(f"WARNING: Insufficient price data for {ticker}. Columns: {list(df.columns)}, Len: {len(df)}")
else:
print(f"WARNING: Market data retrieval failed for regime detection for {ticker}. Data snippet: {str(raw_data)[:100]}")
except Exception as e:
print(f"WARNING: Regime detection failed for {ticker}: {e}")

View File

@ -2,6 +2,7 @@ from typing import Annotated
from datetime import datetime
from dateutil.relativedelta import relativedelta
import yfinance as yf
import pandas as pd
import os
from .stockstats_utils import StockstatsUtils
@ -313,7 +314,7 @@ def get_balance_sheet(
else:
data = ticker_obj.balance_sheet
if data.empty:
if data is None or (isinstance(data, pd.DataFrame) and data.empty):
return f"No balance sheet data found for symbol '{ticker}'"
# Convert to CSV string for consistency with other functions
@ -343,7 +344,7 @@ def get_cashflow(
else:
data = ticker_obj.cashflow
if data.empty:
if data is None or (isinstance(data, pd.DataFrame) and data.empty):
return f"No cash flow data found for symbol '{ticker}'"
# Convert to CSV string for consistency with other functions
@ -373,7 +374,7 @@ def get_income_statement(
else:
data = ticker_obj.income_stmt
if data.empty:
if data is None or (isinstance(data, pd.DataFrame) and data.empty):
return f"No income statement data found for symbol '{ticker}'"
# Convert to CSV string for consistency with other functions
@ -460,3 +461,35 @@ def get_fundamentals(
except Exception as e:
return f"Error retrieving fundamentals for {ticker}: {str(e)}"
def get_robust_revenue_growth(ticker: str) -> float:
"""
Retrieve revenue growth with fallback to manual calculation from quarterly financials.
Returns growth as a float (e.g., 0.63 for 63%).
"""
try:
ticker_obj = yf.Ticker(ticker.upper())
# 1. Try .info first (Quick but often unreliable/stale)
info = ticker_obj.info
growth = info.get('revenueGrowth')
if growth is not None and isinstance(growth, (int, float)) and growth != 0:
return float(growth)
# 2. Fallback: Manual calculation from Quarterly Financials
# Formula: (Revenue_Current_Q - Revenue_Year_Ago_Q) / Revenue_Year_Ago_Q
q_financials = ticker_obj.quarterly_financials
if q_financials is not None and not q_financials.empty and 'Total Revenue' in q_financials.index:
rev_series = q_financials.loc['Total Revenue']
if len(rev_series) >= 5: # Need at least 5 quarters to compare Q1 vs Q5 (Year Ago)
rev_now = rev_series.iloc[0]
rev_year_ago = rev_series.iloc[4]
if rev_year_ago and rev_year_ago > 0:
calc_growth = (rev_now - rev_year_ago) / rev_year_ago
return float(calc_growth)
return 0.0
except Exception as e:
print(f"Error calculating robust revenue growth for {ticker}: {e}")
return 0.0

View File

@ -51,28 +51,34 @@ class RegimeDetector:
# 3. Mean reversion tendency (Hurst exponent)
hurst = RegimeDetector._calculate_hurst_exponent(prices.tail(window))
# 4. Directional bias
cumulative_return = (prices.iloc[-1] / prices.iloc[-window]) - 1
# 4. Directional bias (Cumulative Return)
# We check both the specific window and the broader history to capture leaders in consolidation
window_return = (prices.iloc[-1] / prices.iloc[-window]) - 1
full_history_return = (prices.iloc[-1] / prices.iloc[0]) - 1
# Classify regime
metrics = {
"volatility": volatility,
"trend_strength": trend_strength,
"hurst_exponent": hurst,
"cumulative_return": cumulative_return,
"cumulative_return": window_return,
"overall_return": full_history_return
}
# Decision tree for regime classification
if volatility > 0.40: # High volatility (>40% annualized)
regime = MarketRegime.VOLATILE
elif trend_strength > 25: # Strong trend (ADX > 25)
if cumulative_return > 0:
# Decision tree for regime classification - Prioritize Trend & Momentum
# If ADX > 25, it's trending. We use the broader return to confirm if it's a leader.
if trend_strength > 25:
if window_return > 0 or full_history_return > 0.10: # Up on window OR strong long-term momentum
regime = MarketRegime.TRENDING_UP
else:
regime = MarketRegime.TRENDING_DOWN
elif hurst < 0.5: # Mean reverting (Hurst < 0.5)
elif full_history_return > 0.30: # Massive long-term momentum overrides Hurst/Volatility
regime = MarketRegime.TRENDING_UP
elif volatility > 0.80: # High volatility threshold for individual tech stocks
regime = MarketRegime.VOLATILE
elif not np.isnan(hurst) and hurst < 0.45: # Tighter mean reversion check
regime = MarketRegime.MEAN_REVERTING
else: # Low volatility, no clear trend
else:
regime = MarketRegime.SIDEWAYS
return regime, metrics
@ -111,21 +117,27 @@ class RegimeDetector:
@staticmethod
def _calculate_hurst_exponent(prices: pd.Series) -> float:
"""
Calculate Hurst exponent.
Returns:
H < 0.5: Mean reverting
H = 0.5: Random walk
H > 0.5: Trending
Calculate Hurst exponent with safety checks.
"""
lags = range(2, 20)
tau = [np.std(np.subtract(prices[lag:], prices[:-lag])) for lag in lags]
# Linear regression of log(tau) vs log(lags)
poly = np.polyfit(np.log(lags), np.log(tau), 1)
hurst = poly[0]
return hurst
try:
lags = range(2, 20)
tau = [np.std(np.subtract(prices[lag:], prices[:-lag].values)) for lag in lags]
# Filter out non-positive values to avoid log errors
valid_idx = [i for i, t in enumerate(tau) if t > 0]
if len(valid_idx) < 2:
return 0.5 # Random walk default
valid_lags = [lags[i] for i in valid_idx]
valid_tau = [tau[i] for i in valid_idx]
# Linear regression of log(tau) vs log(lags)
poly = np.polyfit(np.log(valid_lags), np.log(valid_tau), 1)
hurst = poly[0]
return hurst
except Exception:
return 0.5 # Default to random walk on error
class DynamicIndicatorSelector:

View File

@ -182,7 +182,10 @@ class TradingAgentsGraph:
self.ticker = company_name
# 2. Register real company name for anonymization
# 2. Get Hard Data Baseline (Trend Override & Reporting)
self.hard_data = self._get_hard_data_metrics(company_name, trade_date)
# 3. Register real company name for anonymization
try:
from tradingagents.utils.anonymizer import TickerAnonymizer
import yfinance as yf
@ -225,10 +228,26 @@ class TradingAgentsGraph:
self._log_state(trade_date, final_state)
# 3. FIX CRASH RISK: Handle Dead State gracefully
# First, extract raw decision from LLM text (The Agent Decision)
raw_llm_decision = final_state["final_trade_decision"]
# Apply Technical Override (Don't Fight the Tape)
regime_val = final_state.get("market_regime", "UNKNOWN").upper()
print(f"\n🔍 [DEBUG] APPLYING OVERRIDE: Regime='{regime_val}', Growth={self.hard_data.get('revenue_growth', 'N/A')}")
overridden_decision = self.apply_trend_override(
raw_llm_decision,
self.hard_data,
regime_val
)
# Update final state with potentially overridden decision
final_state["final_trade_decision"] = overridden_decision
trade_decision = final_state["final_trade_decision"]
# If trade was rejected by a Gate (Fact Check or Risk), return raw decision
if trade_decision.get("action") == "HOLD" and "REJECTED" in trade_decision.get("reasoning", ""):
if isinstance(trade_decision, dict) and trade_decision.get("action") == "HOLD" and "REJECTED" in trade_decision.get("reasoning", ""):
processed_signal = {
"action": "HOLD",
"quantity": 0,
@ -305,4 +324,91 @@ class TradingAgentsGraph:
def process_signal(self, full_signal):
"""Process a signal to extract the core decision."""
# Handle dict if signal was overridden, otherwise handle string from LLM
if isinstance(full_signal, dict):
return {
"action": full_signal.get("action", "HOLD"),
"quantity": full_signal.get("quantity", 0),
"reason": full_signal.get("reasoning", "OVERRIDDEN")
}
return self.signal_processor.process_signal(full_signal)
def _get_hard_data_metrics(self, ticker: str, trade_date: str) -> Dict[str, Any]:
"""Fetch raw technical and fundamental data for the override gate."""
try:
import yfinance as yf
from datetime import datetime, timedelta
from tradingagents.dataflows.y_finance import get_robust_revenue_growth
dt_obj = datetime.strptime(trade_date, "%Y-%m-%d")
# Fetch 300 days of history to ensure we can calculate 200 SMA
start_date = (dt_obj - timedelta(days=450)).strftime("%Y-%m-%d")
ticker_obj = yf.Ticker(ticker.upper())
history = ticker_obj.history(start=start_date, end=trade_date)
metrics = {
"current_price": 0.0,
"sma_200": 0.0,
"revenue_growth": 0.0,
"status": "ERROR"
}
if not history.empty and len(history) >= 200:
metrics["current_price"] = history["Close"].iloc[-1]
metrics["sma_200"] = history["Close"].rolling(200).mean().iloc[-1]
metrics["status"] = "OK"
metrics["revenue_growth"] = get_robust_revenue_growth(ticker)
return metrics
except Exception as e:
print(f"Error fetching hard data for {ticker} override: {e}")
return {"status": "ERROR", "error": str(e)}
def apply_trend_override(self, trade_decision_str: str, hard_data: Dict[str, Any], regime: str) -> Any:
"""
The 'Don't Fight the Tape' Safety Valve.
Prevents the system from shorting high-growth winners during a Bull Market.
"""
if hard_data.get("status") != "OK":
return trade_decision_str
regime = str(regime).strip().upper()
price = hard_data["current_price"]
sma_200 = hard_data["sma_200"]
growth = hard_data["revenue_growth"]
# 1. Technical Uptrend (Price > 200 SMA)
is_technical_uptrend = price > sma_200
# 2. Hyper-Growth (> 30% YoY)
is_hyper_growth = growth > 0.30
# 3. Supportive Regime (Protect leaders unless it's a clear TRENDING_DOWN regime)
is_bear_regime = regime in ["TRENDING_DOWN", "BEAR", "BEARISH"]
is_bull_regime = not is_bear_regime
# 4. Trigger Override if trying to SELL a leader in a bull market
if is_technical_uptrend and is_hyper_growth and is_bull_regime:
# We check if the decision string contains SELL or STRONG_SELL
# (llm output is usually messy text, so we check for the verdict)
decision_upper = trade_decision_str.upper()
if "SELL" in decision_upper:
print(f"\n🛑 TREND OVERRIDE TRIGGERED for {self.ticker}")
print(f" Reason: Stock (${price:.2f}) is > 200SMA (${sma_200:.2f}) and Growth is {growth:.1%}")
print(f" Action 'SELL' blocked. Converting to 'HOLD'.\n")
return {
"action": "HOLD",
"quantity": 0,
"reasoning": (
f"OVERRIDE: System attempted to short a Hyper-Growth stock ({growth:.1%}) "
f"above its 200-day trend (${sma_200:.2f}) in a Bull regime. "
f"Original Decision: {trade_decision_str[:100]}..."
),
"confidence": 1.0
}
return trade_decision_str

View File

@ -34,7 +34,7 @@ class TickerAnonymizer:
self.auto_persist = auto_persist
# Persistence path
self.map_file = Path("ticker_map.json")
self.map_file = Path(__file__).resolve().parent.parent.parent / "ticker_map.json"
if self.auto_persist:
self._load_from_file()
@ -299,8 +299,14 @@ class TickerAnonymizer:
print(f"✅ Loaded mapping from {input_path}")
def deanonymize_ticker(self, anon_ticker: str) -> Optional[str]:
"""Reverse mapping: ASSET_042 → AAPL."""
return self.reverse_map.get(anon_ticker)
"""Reverse mapping: ASSET_042 → AAPL. Robust to 'Company' prefixes."""
if not anon_ticker:
return None
# Strip common prefixes that LLMs might include from text
clean_ticker = anon_ticker.upper().replace("COMPANY", "").replace("COMPANY_", "").replace("CORPORATION", "").strip()
return self.reverse_map.get(clean_ticker)