diff --git a/pyproject.toml b/pyproject.toml
index 4c91a733..fb2a621a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,10 @@ dependencies = [
     "tqdm>=4.67.1",
     "typing-extensions>=4.14.0",
     "yfinance>=0.2.63",
+    "scipy>=1.11.0",
+    "scikit-learn>=1.3.0",
+    "PyWavelets>=1.4.0",
+    "networkx>=3.1",
 ]

 [project.scripts]
diff --git a/requirements.txt b/requirements.txt
index 184468b8..0c46097d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,3 +19,7 @@ typer
 questionary
 langchain_anthropic
 langchain-google-genai
+scipy
+scikit-learn
+PyWavelets
+networkx
diff --git a/tradingagents/cross_asset_correlation/__init__.py b/tradingagents/cross_asset_correlation/__init__.py
new file mode 100644
index 00000000..462bcad1
--- /dev/null
+++ b/tradingagents/cross_asset_correlation/__init__.py
@@ -0,0 +1,22 @@
+"""
+Cross-Asset Correlation Engine (CACE)
+Advanced correlation analysis for multi-asset trading strategies.
+
+This module implements sophisticated correlation analysis techniques
+for detecting relationships between different financial instruments.
+Based on academic research in multivariate time series analysis.
+
+Author: Research Team
+Date: 2026-02-07
+"""
+
+from .correlation_analyzer import CorrelationAnalyzer
+from .multi_asset_processor import MultiAssetProcessor
+from .correlation_regimes_fixed import CorrelationRegimeDetector
+
+__version__ = "1.0.0"
+__all__ = [
+    "CorrelationAnalyzer",
+    "MultiAssetProcessor",
+    "CorrelationRegimeDetector",
+]
\ No newline at end of file
diff --git a/tradingagents/cross_asset_correlation/correlation_analyzer.py b/tradingagents/cross_asset_correlation/correlation_analyzer.py
new file mode 100644
index 00000000..1c04c371
--- /dev/null
+++ b/tradingagents/cross_asset_correlation/correlation_analyzer.py
@@ -0,0 +1,412 @@
+"""
+Correlation Analyzer
+Core engine for cross-asset correlation analysis.
+
+Implements multiple correlation techniques:
+1. Pearson correlation for linear relationships
+2. Spearman rank correlation for monotonic relationships
+3. Dynamic Conditional Correlation (DCC-GARCH)
+4. Wavelet coherence for multi-timescale analysis
+5. Lead-lag correlation with time shifts
+
+Based on academic research:
+- Engle (2002): Dynamic Conditional Correlation
+- Grinsted et al. (2004): Wavelet coherence for geophysical time series
+- Alexander (2001): Correlation and cointegration in financial markets
+"""
+
+import numpy as np
+import pandas as pd
+from typing import Dict, List, Tuple, Optional, Union
+from dataclasses import dataclass
+from enum import Enum
+import warnings
+from scipy import stats
+from scipy.signal import correlate
+import pywt  # wavelet transform
+
+
+class CorrelationMethod(Enum):
+    """Correlation calculation methods."""
+    PEARSON = "pearson"
+    SPEARMAN = "spearman"
+    KENDALL = "kendall"
+    DCC_GARCH = "dcc_garch"
+    WAVELET = "wavelet"
+    LEAD_LAG = "lead_lag"
+
+
+@dataclass
+class CorrelationResult:
+    """Container for correlation analysis results."""
+    assets: Tuple[str, str]
+    method: CorrelationMethod
+    correlation: float
+    p_value: Optional[float] = None
+    confidence_interval: Optional[Tuple[float, float]] = None
+    lag: Optional[int] = None  # for lead-lag analysis
+    time_scale: Optional[float] = None  # for wavelet analysis
+    metadata: Optional[Dict] = None
+
+
+class CorrelationAnalyzer:
+    """Main correlation analysis engine."""
+
+    def __init__(self, min_data_points: int = 30, significance_level: float = 0.05):
+        """
+        Initialize the correlation analyzer.
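+
+        Example (an illustrative sketch; ``spy`` and ``tlt`` are assumed
+        to be date-indexed close-price pd.Series, not part of this module):
+
+            >>> analyzer = CorrelationAnalyzer(min_data_points=30)
+            >>> results = analyzer.analyze_pair(spy, tlt)
+            >>> [(r.method.value, round(r.correlation, 3)) for r in results]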
+ + Args: + min_data_points: Minimum data points required for analysis + significance_level: Statistical significance level for p-values + """ + self.min_data_points = min_data_points + self.significance_level = significance_level + + def analyze_pair( + self, + asset1_prices: pd.Series, + asset2_prices: pd.Series, + methods: List[CorrelationMethod] = None, + max_lag: int = 10 + ) -> List[CorrelationResult]: + """ + Analyze correlation between two asset price series. + + Args: + asset1_prices: Price series for first asset + asset2_prices: Price series for second asset + methods: List of correlation methods to apply + max_lag: Maximum lag for lead-lag analysis + + Returns: + List of correlation results for each method + """ + if methods is None: + methods = [ + CorrelationMethod.PEARSON, + CorrelationMethod.SPEARMAN, + CorrelationMethod.LEAD_LAG + ] + + # Validate input data + self._validate_input(asset1_prices, asset2_prices) + + # Align time series + aligned_data = self._align_series(asset1_prices, asset2_prices) + if aligned_data is None: + return [] + + asset1_aligned, asset2_aligned = aligned_data + + results = [] + + for method in methods: + try: + if method == CorrelationMethod.PEARSON: + result = self._pearson_correlation(asset1_aligned, asset2_aligned) + elif method == CorrelationMethod.SPEARMAN: + result = self._spearman_correlation(asset1_aligned, asset2_aligned) + elif method == CorrelationMethod.KENDALL: + result = self._kendall_correlation(asset1_aligned, asset2_aligned) + elif method == CorrelationMethod.LEAD_LAG: + result = self._lead_lag_correlation(asset1_aligned, asset2_aligned, max_lag) + elif method == CorrelationMethod.DCC_GARCH: + result = self._dcc_garch_correlation(asset1_aligned, asset2_aligned) + elif method == CorrelationMethod.WAVELET: + result = self._wavelet_coherence(asset1_aligned, asset2_aligned) + else: + continue + + results.append(result) + + except Exception as e: + warnings.warn(f"Failed to compute {method.value} correlation: {e}") + continue + + return results + + def analyze_portfolio( + self, + price_data: pd.DataFrame, + methods: List[CorrelationMethod] = None + ) -> pd.DataFrame: + """ + Analyze correlation matrix for multiple assets. + + Args: + price_data: DataFrame with assets as columns and prices as rows + methods: Correlation methods to apply + + Returns: + Correlation matrix DataFrame + """ + if methods is None: + methods = [CorrelationMethod.PEARSON] + + assets = price_data.columns.tolist() + n_assets = len(assets) + + # Initialize correlation matrix + corr_matrix = pd.DataFrame( + np.eye(n_assets), + index=assets, + columns=assets + ) + + # Fill correlation matrix + for i in range(n_assets): + for j in range(i + 1, n_assets): + asset_i = price_data.iloc[:, i] + asset_j = price_data.iloc[:, j] + + results = self.analyze_pair(asset_i, asset_j, methods) + if results: + # Use first method's correlation + corr_value = results[0].correlation + corr_matrix.iloc[i, j] = corr_value + corr_matrix.iloc[j, i] = corr_value + + return corr_matrix + + def rolling_correlation( + self, + asset1_prices: pd.Series, + asset2_prices: pd.Series, + window: int = 20, + method: CorrelationMethod = CorrelationMethod.PEARSON + ) -> pd.Series: + """ + Compute rolling correlation between two assets. 
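+
+        Example (illustrative; reuses the ``spy``/``tlt`` Series assumed
+        above):
+
+            >>> rc = analyzer.rolling_correlation(spy, tlt, window=20)
+            >>> rc.dropna().tail()  # one value per date once the window fills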
+ + Args: + asset1_prices: Price series for first asset + asset2_prices: Price series for second asset + window: Rolling window size + method: Correlation method to use + + Returns: + Series of rolling correlation values + """ + aligned_data = self._align_series(asset1_prices, asset2_prices) + if aligned_data is None: + return pd.Series([], dtype=float) + + asset1_aligned, asset2_aligned = aligned_data + + # Create DataFrame for rolling calculation + df = pd.DataFrame({ + 'asset1': asset1_aligned, + 'asset2': asset2_aligned + }) + + if method == CorrelationMethod.PEARSON: + return df['asset1'].rolling(window).corr(df['asset2']) + elif method == CorrelationMethod.SPEARMAN: + # Spearman rolling correlation + rolling_corr = [] + for i in range(len(df) - window + 1): + window_data = df.iloc[i:i + window] + corr, _ = stats.spearmanr(window_data['asset1'], window_data['asset2']) + rolling_corr.append(corr) + return pd.Series(rolling_corr, index=df.index[window - 1:]) + else: + raise ValueError(f"Rolling correlation not implemented for {method}") + + def _validate_input(self, series1: pd.Series, series2: pd.Series): + """Validate input time series.""" + if len(series1) < self.min_data_points or len(series2) < self.min_data_points: + raise ValueError( + f"Insufficient data points. Minimum required: {self.min_data_points}" + ) + + if series1.isnull().all() or series2.isnull().all(): + raise ValueError("Input series contain only NaN values") + + def _align_series(self, series1: pd.Series, series2: pd.Series) -> Optional[Tuple[pd.Series, pd.Series]]: + """Align two time series on their index.""" + # Convert to DataFrame and drop NaN + df = pd.DataFrame({'s1': series1, 's2': series2}) + df = df.dropna() + + if len(df) < self.min_data_points: + return None + + return df['s1'], df['s2'] + + def _pearson_correlation(self, series1: pd.Series, series2: pd.Series) -> CorrelationResult: + """Compute Pearson correlation.""" + corr, p_value = stats.pearsonr(series1, series2) + + # Calculate confidence interval using Fisher transformation + n = len(series1) + z = np.arctanh(corr) + se = 1 / np.sqrt(n - 3) + z_lower = z - 1.96 * se + z_upper = z + 1.96 * se + ci = (np.tanh(z_lower), np.tanh(z_upper)) + + return CorrelationResult( + assets=(series1.name, series2.name), + method=CorrelationMethod.PEARSON, + correlation=corr, + p_value=p_value, + confidence_interval=ci, + metadata={'n_observations': n} + ) + + def _spearman_correlation(self, series1: pd.Series, series2: pd.Series) -> CorrelationResult: + """Compute Spearman rank correlation.""" + corr, p_value = stats.spearmanr(series1, series2) + + return CorrelationResult( + assets=(series1.name, series2.name), + method=CorrelationMethod.SPEARMAN, + correlation=corr, + p_value=p_value, + metadata={'n_observations': len(series1)} + ) + + def _kendall_correlation(self, series1: pd.Series, series2: pd.Series) -> CorrelationResult: + """Compute Kendall's tau correlation.""" + corr, p_value = stats.kendalltau(series1, series2) + + return CorrelationResult( + assets=(series1.name, series2.name), + method=CorrelationMethod.KENDALL, + correlation=corr, + p_value=p_value, + metadata={'n_observations': len(series1)} + ) + + def _lead_lag_correlation( + self, + series1: pd.Series, + series2: pd.Series, + max_lag: int + ) -> CorrelationResult: + """ + Compute lead-lag correlation to find optimal time shift. + + Returns correlation at optimal lag where series1 leads series2. 
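+
+        A negative ``lag`` is reported as series1 leading series2 (see the
+        ``relationship`` string in the metadata). Illustrative call via the
+        public API, with ``s1`` and ``s2`` assumed to be aligned Series:
+
+            >>> res = analyzer.analyze_pair(
+            ...     s1, s2, methods=[CorrelationMethod.LEAD_LAG], max_lag=5
+            ... )[0]
+            >>> res.lag, res.metadata['relationship']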
+ """ + # Normalize series + s1_norm = (series1 - series1.mean()) / series1.std() + s2_norm = (series2 - series2.mean()) / series2.std() + + # Compute cross-correlation + corr_values = correlate(s1_norm, s2_norm, mode='full') + lags = np.arange(-max_lag, max_lag + 1) + + # Find lag with maximum correlation + max_corr_idx = np.argmax(np.abs(corr_values)) + optimal_lag = lags[max_corr_idx] + max_corr = corr_values[max_corr_idx] / len(series1) + + # Determine lead/lag relationship + if optimal_lag < 0: + relationship = f"Asset1 leads Asset2 by {-optimal_lag} periods" + elif optimal_lag > 0: + relationship = f"Asset2 leads Asset1 by {optimal_lag} periods" + else: + relationship = "No significant lead-lag relationship" + + return CorrelationResult( + assets=(series1.name, series2.name), + method=CorrelationMethod.LEAD_LAG, + correlation=max_corr, + lag=optimal_lag, + metadata={ + 'relationship': relationship, + 'max_lag_considered': max_lag, + 'n_observations': len(series1) + } + ) + + def _dcc_garch_correlation(self, series1: pd.Series, series2: pd.Series) -> CorrelationResult: + """ + Compute Dynamic Conditional Correlation using GARCH model. + + Simplified implementation - in production would use arch package. + """ + # Calculate returns + returns1 = series1.pct_change().dropna() + returns2 = series2.pct_change().dropna() + + # Align returns + returns_df = pd.DataFrame({'r1': returns1, 'r2': returns2}).dropna() + + if len(returns_df) < 50: # Need sufficient data for GARCH + raise ValueError("Insufficient data for DCC-GARCH estimation") + + # Simplified DCC calculation + # In practice, would use: from arch import arch_model + ewma_corr = returns_df['r1'].ewm(span=20).corr(returns_df['r2']).iloc[-1] + + return CorrelationResult( + assets=(series1.name, series2.name), + method=CorrelationMethod.DCC_GARCH, + correlation=ewma_corr, + metadata={ + 'method': 'simplified_ewma_approximation', + 'n_observations': len(returns_df) + } + ) + + def _wavelet_coherence(self, series1: pd.Series, series2: pd.Series) -> CorrelationResult: + """ + Compute wavelet coherence for multi-timescale analysis. + + Based on Grinsted et al. (2004) method. 
+ """ + try: + # Ensure equal length + min_len = min(len(series1), len(series2)) + s1 = series1.values[:min_len] + s2 = series2.values[:min_len] + + # Normalize + s1_norm = (s1 - np.mean(s1)) / np.std(s1) + s2_norm = (s2 - np.mean(s2)) / np.std(s2) + + # Continuous wavelet transform + scales = np.arange(1, 65) # 64 scales for multi-resolution + coefficients1, _ = pywt.cwt(s1_norm, scales, 'morl') + coefficients2, _ = pywt.cwt(s2_norm, scales, 'morl') + + # Wavelet coherence + power1 = np.abs(coefficients1) ** 2 + power2 = np.abs(coefficients2) ** 2 + cross_power = coefficients1 * np.conj(coefficients2) + + # Smoothing + smooth = lambda x: np.convolve(x, np.ones(3)/3, mode='same') + smooth_power1 = np.apply_along_axis(smooth, 1, power1) + smooth_power2 = np.apply_along_axis(smooth, 1, power2) + smooth_cross = np.apply_along_axis(smooth, 1, cross_power) + + # Coherence + coherence = np.abs(smooth_cross) ** 2 / (smooth_power1 * smooth_power2) + + # Average coherence across scales + avg_coherence = np.nanmean(coherence) + + # Find dominant time scale + scale_power = np.nanmean(coherence, axis=1) + dominant_scale_idx = np.nanargmax(scale_power) + dominant_scale = scales[dominant_scale_idx] + + return CorrelationResult( + assets=(series1.name, series2.name), + method=CorrelationMethod.WAVELET, + correlation=avg_coherence, + time_scale=dominant_scale, + metadata={ + 'n_scales': len(scales), + 'dominant_scale': dominant_scale, + 'n_observations': min_len + } + ) + + except Exception as e: + raise ValueError(f"Wavelet coherence computation failed: {e}") \ No newline at end of file diff --git a/tradingagents/cross_asset_correlation/correlation_regimes_fixed.py b/tradingagents/cross_asset_correlation/correlation_regimes_fixed.py new file mode 100644 index 00000000..3d88a16d --- /dev/null +++ b/tradingagents/cross_asset_correlation/correlation_regimes_fixed.py @@ -0,0 +1,676 @@ +""" +Correlation Regime Detector +Identifies changing correlation patterns and market regimes. + +Detects: +- Correlation breakdowns and regime shifts +- Crisis periods (flight to quality, contagion) +- Bull/bear market correlation patterns +- Seasonal correlation patterns +- Structural breaks in relationships + +Based on regime switching models and structural break detection. +""" + +import numpy as np +import pandas as pd +from typing import Dict, List, Optional, Tuple, Union +from dataclasses import dataclass +from enum import Enum +import warnings +from scipy import stats +from scipy.signal import find_peaks +from sklearn.cluster import KMeans + + +class CorrelationRegime(Enum): + """Correlation regime classifications.""" + NORMAL = "normal" # Stable, moderate correlations + CRISIS = "crisis" # High correlations (contagion) + DECOUPLING = "decoupling" # Low/negative correlations + TRENDING = "trending" # Strong positive trends + DIVERGING = "diverging" # Strong negative trends + VOLATILE = "volatile" # High volatility, unstable correlations + SEASONAL = "seasonal" # Seasonal pattern + + +@dataclass +class RegimeDetectionResult: + """Results of regime detection.""" + regime_type: CorrelationRegime + start_date: pd.Timestamp + end_date: pd.Timestamp + confidence: float + characteristics: Dict[str, float] + assets_affected: List[str] + + +class CorrelationRegimeDetector: + """Detects and analyzes correlation regimes.""" + + def __init__( + self, + min_regime_length: int = 10, + detection_method: str = "rolling_volatility" + ): + """ + Initialize regime detector. 
+ + Args: + min_regime_length: Minimum length of a regime in periods + detection_method: Primary detection method + """ + self.min_regime_length = min_regime_length + self.detection_method = detection_method + + def detect_regimes( + self, + correlation_series: pd.Series, + price_data: Optional[pd.DataFrame] = None, + methods: List[str] = None + ) -> List[RegimeDetectionResult]: + """ + Detect correlation regimes in time series. + + Args: + correlation_series: Time series of correlation values + price_data: Optional price data for additional context + methods: List of detection methods to use + + Returns: + List of detected regimes + """ + if methods is None: + methods = ["rolling_volatility", "changepoint", "clustering"] + + # Validate input + if len(correlation_series) < self.min_regime_length * 2: + raise ValueError( + f"Insufficient data. Need at least {self.min_regime_length * 2} periods" + ) + + # Apply each detection method + all_regimes = [] + + for method in methods: + try: + if method == "rolling_volatility": + regimes = self._detect_by_volatility(correlation_series) + elif method == "changepoint": + regimes = self._detect_changepoints(correlation_series) + elif method == "clustering": + regimes = self._detect_by_clustering(correlation_series) + elif method == "markov": + regimes = self._detect_markov_regimes(correlation_series) + else: + warnings.warn(f"Unknown detection method: {method}") + continue + + all_regimes.extend(regimes) + + except Exception as e: + warnings.warn(f"Method {method} failed: {e}") + continue + + # Merge overlapping regimes + merged_regimes = self._merge_regimes(all_regimes) + + # Classify regimes + classified_regimes = [] + for regime in merged_regimes: + classified = self._classify_regime(regime, correlation_series, price_data) + classified_regimes.append(classified) + + return classified_regimes + + def analyze_regime_transitions( + self, + regimes: List[RegimeDetectionResult] + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Analyze transitions between regimes. 
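+
+        Example (illustrative, continuing from ``detect_regimes`` above;
+        note the two return values):
+
+            >>> stats_df, trans_probs = detector.analyze_regime_transitions(regimes)
+            >>> trans_probs  # rows are from-regimes, columns are to-regime probabilities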
+ + Args: + regimes: List of detected regimes + + Returns: + DataFrame with transition probabilities and statistics + """ + if len(regimes) < 2: + return pd.DataFrame(), pd.DataFrame() + + # Create transition matrix + regime_types = [r.regime_type for r in regimes] + unique_regimes = list(set(regime_types)) + + # Initialize transition matrix + transition_matrix = pd.DataFrame( + 0, + index=unique_regimes, + columns=unique_regimes + ) + + # Count transitions + for i in range(len(regime_types) - 1): + from_regime = regime_types[i] + to_regime = regime_types[i + 1] + transition_matrix.loc[from_regime, to_regime] += 1 + + # Convert to probabilities + row_sums = transition_matrix.sum(axis=1) + transition_probs = transition_matrix.div(row_sums, axis=0) + + # Calculate regime statistics + regime_stats = [] + for regime_type in unique_regimes: + regime_instances = [r for r in regimes if r.regime_type == regime_type] + + if regime_instances: + durations = [ + (r.end_date - r.start_date).days + for r in regime_instances + ] + + stats_dict = { + 'regime_type': regime_type.value, + 'count': len(regime_instances), + 'avg_duration_days': np.mean(durations), + 'std_duration_days': np.std(durations), + 'min_duration_days': np.min(durations), + 'max_duration_days': np.max(durations), + 'total_days': np.sum(durations), + 'frequency': len(regime_instances) / len(regimes) + } + + regime_stats.append(stats_dict) + + return pd.DataFrame(regime_stats), transition_probs + + def predict_next_regime( + self, + current_regime: CorrelationRegime, + transition_matrix: pd.DataFrame, + market_conditions: Optional[Dict] = None + ) -> Tuple[CorrelationRegime, float]: + """ + Predict next regime based on transition probabilities. + + Args: + current_regime: Current regime + transition_matrix: Transition probability matrix + market_conditions: Optional market condition indicators + + Returns: + Tuple of (predicted regime, probability) + """ + if current_regime not in transition_matrix.index: + return current_regime, 1.0 + + # Get transition probabilities from current regime + probs = transition_matrix.loc[current_regime] + + # Adjust based on market conditions if provided + if market_conditions: + probs = self._adjust_probs_with_conditions(probs, market_conditions) + + # Find most likely next regime + next_regime = probs.idxmax() + probability = probs.max() + + return next_regime, probability + + def detect_crisis_periods( + self, + correlation_matrix_series: pd.DataFrame, + volatility_series: pd.Series, + threshold: float = 0.8 + ) -> List[Tuple[pd.Timestamp, pd.Timestamp]]: + """ + Detect crisis periods based on correlation and volatility. 
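+
+        Example (illustrative; ``corr_mats`` is assumed to map dates to
+        correlation-matrix DataFrames -- it is consumed via ``.items()`` --
+        and ``vol`` is a pd.Series of market volatility on the same dates):
+
+            >>> crises = detector.detect_crisis_periods(corr_mats, vol, threshold=0.8)
+            >>> [(str(start), str(end)) for start, end in crises]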
+ + Args: + correlation_matrix_series: Time series of correlation matrices + volatility_series: Time series of market volatility + threshold: Correlation threshold for crisis detection + + Returns: + List of (start_date, end_date) tuples for crisis periods + """ + # Calculate average correlation over time + avg_correlations = [] + dates = [] + + for date, corr_matrix in correlation_matrix_series.items(): + # Flatten matrix (excluding diagonal) + corr_values = corr_matrix.values[ + np.triu_indices_from(corr_matrix, k=1) + ] + avg_corr = np.mean(corr_values) + avg_correlations.append(avg_corr) + dates.append(date) + + avg_corr_series = pd.Series(avg_correlations, index=dates) + + # Detect crisis periods (high correlation + high volatility) + crisis_periods = [] + in_crisis = False + crisis_start = None + + for date in avg_corr_series.index: + avg_corr = avg_corr_series[date] + vol = volatility_series.get(date, 0) + + is_crisis = (avg_corr >= threshold) and (vol > volatility_series.median()) + + if is_crisis and not in_crisis: + # Start of crisis + in_crisis = True + crisis_start = date + elif not is_crisis and in_crisis: + # End of crisis + in_crisis = False + if crisis_start: + crisis_periods.append((crisis_start, date)) + crisis_start = None + + # Handle ongoing crisis at end + if in_crisis and crisis_start: + crisis_periods.append((crisis_start, avg_corr_series.index[-1])) + + return crisis_periods + + def _detect_by_volatility(self, correlation_series: pd.Series) -> List[Dict]: + """Detect regimes based on rolling volatility.""" + # Calculate rolling volatility + window = min(20, len(correlation_series) // 4) + rolling_vol = correlation_series.rolling(window).std() + + # Normalize volatility + vol_normalized = (rolling_vol - rolling_vol.mean()) / rolling_vol.std() + + # Detect high/low volatility periods + regimes = [] + current_regime = None + regime_start = None + + for date, vol in vol_normalized.items(): + if pd.isna(vol): + continue + + if vol > 1.0: + regime_type = "high_vol" + elif vol < -1.0: + regime_type = "low_vol" + else: + regime_type = "normal" + + if regime_type != current_regime: + if current_regime is not None and regime_start: + regimes.append({ + 'type': current_regime, + 'start': regime_start, + 'end': date + }) + current_regime = regime_type + regime_start = date + + # Add final regime + if current_regime is not None and regime_start: + regimes.append({ + 'type': current_regime, + 'start': regime_start, + 'end': correlation_series.index[-1] + }) + + return regimes + + def _detect_changepoints(self, correlation_series: pd.Series) -> List[Dict]: + """Detect regimes using changepoint detection.""" + # Clean data + clean_series = correlation_series.dropna() + + if len(clean_series) < 10: + return [] + + # Simplified changepoint detection using variance changes + regimes = [] + window = min(20, len(clean_series) // 5) + + # Calculate rolling variance + rolling_var = clean_series.rolling(window).var() + + # Detect significant variance changes + var_mean = rolling_var.mean() + var_std = rolling_var.std() + + current_regime = None + regime_start = None + + for date, var in rolling_var.items(): + if pd.isna(var): + continue + + if var > var_mean + var_std: + regime_type = "high_var" + elif var < var_mean - var_std: + regime_type = "low_var" + else: + regime_type = "normal_var" + + if regime_type != current_regime: + if current_regime is not None and regime_start: + regimes.append({ + 'type': current_regime, + 'start': regime_start, + 'end': date + }) + current_regime = 
regime_type + regime_start = date + + # Add final regime + if current_regime is not None and regime_start: + regimes.append({ + 'type': current_regime, + 'start': regime_start, + 'end': clean_series.index[-1] + }) + + return regimes + + def _detect_by_clustering(self, correlation_series: pd.Series) -> List[Dict]: + """Detect regimes using clustering.""" + # Prepare features for clustering + clean_series = correlation_series.dropna() + + if len(clean_series) < self.min_regime_length * 3: + return [] + + # Create feature matrix (value, rolling mean, rolling std) + window = min(10, len(clean_series) // 10) + features = pd.DataFrame({ + 'value': clean_series, + 'rolling_mean': clean_series.rolling(window).mean(), + 'rolling_std': clean_series.rolling(window).std() + }).dropna() + + # Determine optimal number of clusters (2-4) + n_clusters = min(4, max(2, len(features) // (self.min_regime_length * 2))) + + # Apply K-means clustering + kmeans = KMeans(n_clusters=n_clusters, random_state=42) + clusters = kmeans.fit_predict(features) + + # Convert clusters to regimes + regimes = [] + current_cluster = None + regime_start = None + + for idx, (date, cluster) in enumerate(zip(features.index, clusters)): + if cluster != current_cluster: + if current_cluster is not None and regime_start: + regimes.append({ + 'type': f'cluster_{current_cluster}', + 'start': regime_start, + 'end': features.index[idx - 1], + 'cluster': current_cluster, + 'center': kmeans.cluster_centers_[current_cluster] + }) + current_cluster = cluster + regime_start = date + + # Add final regime + if current_cluster is not None and regime_start: + regimes.append({ + 'type': f'cluster_{current_cluster}', + 'start': regime_start, + 'end': features.index[-1], + 'cluster': current_cluster, + 'center': kmeans.cluster_centers_[current_cluster] + }) + + return regimes + + def _detect_markov_regimes(self, correlation_series: pd.Series) -> List[Dict]: + """Detect regimes using Markov switching model (simplified).""" + # Simplified implementation - in production would use statsmodels + clean_series = correlation_series.dropna() + + if len(clean_series) < 50: + return [] + + # Simple threshold-based regime detection + mean_val = clean_series.mean() + std_val = clean_series.std() + + regimes = [] + current_regime = None + regime_start = None + + for date, value in clean_series.items(): + if value > mean_val + std_val: + regime_type = "high" + elif value < mean_val - std_val: + regime_type = "low" + else: + regime_type = "normal" + + if regime_type != current_regime: + if current_regime is not None and regime_start: + # Check minimum length + if (date - regime_start).days >= self.min_regime_length: + regimes.append({ + 'type': regime_type, + 'start': regime_start, + 'end': date + }) + current_regime = regime_type + regime_start = date + + # Add final regime + if current_regime is not None and regime_start: + regimes.append({ + 'type': current_regime, + 'start': regime_start, + 'end': clean_series.index[-1] + }) + + return regimes + + def _merge_regimes(self, regimes: List[Dict]) -> List[Dict]: + """Merge overlapping regimes.""" + if not regimes: + return [] + + # Sort by start date + regimes.sort(key=lambda x: x['start']) + + merged = [] + current = regimes[0] + + for regime in regimes[1:]: + # Check if regimes overlap or are adjacent + if regime['start'] <= current['end'] or ( + regime['start'] - current['end']).days <= 1: + # Merge regimes + current['end'] = max(current['end'], regime['end']) + # Combine type information + if 'type' in current and 
'type' in regime: + current['type'] = f"{current['type']}_{regime['type']}" + else: + merged.append(current) + current = regime + + merged.append(current) + + return merged + + def _classify_regime( + self, + regime: Dict, + correlation_series: pd.Series, + price_data: Optional[pd.DataFrame] = None + ) -> RegimeDetectionResult: + """Classify a regime based on its characteristics.""" + # Extract regime data + regime_data = correlation_series.loc[regime['start']:regime['end']] + + if regime_data.empty: + return RegimeDetectionResult( + regime_type=CorrelationRegime.NORMAL, + start_date=regime['start'], + end_date=regime['end'], + confidence=0.5, + characteristics={}, + assets_affected=[] + ) + + # Calculate regime characteristics + mean_corr = regime_data.mean() + std_corr = regime_data.std() + trend = self._calculate_trend(regime_data) + + # Classify based on characteristics + if std_corr > correlation_series.std() * 1.5: + regime_type = CorrelationRegime.VOLATILE + confidence = 0.7 + elif mean_corr > 0.7: + regime_type = CorrelationRegime.CRISIS + confidence = 0.8 + elif mean_corr < 0.2: + regime_type = CorrelationRegime.DECOUPLING + confidence = 0.6 + elif trend > 0.1: + regime_type = CorrelationRegime.TRENDING + confidence = 0.65 + elif trend < -0.1: + regime_type = CorrelationRegime.DIVERGING + confidence = 0.65 + else: + regime_type = CorrelationRegime.NORMAL + confidence = 0.5 + + # Check for seasonal patterns + if self._detect_seasonal_pattern(regime_data): + regime_type = CorrelationRegime.SEASONAL + confidence = 0.6 + + characteristics = { + 'mean_correlation': mean_corr, + 'std_correlation': std_corr, + 'trend': trend, + 'duration_days': (regime['end'] - regime['start']).days + } + + return RegimeDetectionResult( + regime_type=regime_type, + start_date=regime['start'], + end_date=regime['end'], + confidence=confidence, + characteristics=characteristics, + assets_affected=[] + ) + + def _calculate_trend(self, series: pd.Series) -> float: + """ + Calculate linear trend of a series. + + Returns slope normalized by standard deviation. + """ + if len(series) < 2: + return 0.0 + + x = np.arange(len(series)) + y = series.values + + # Remove NaN + mask = ~np.isnan(y) + if mask.sum() < 2: + return 0.0 + + slope, _, _, _, _ = stats.linregress(x[mask], y[mask]) + + # Normalize by std to get comparable trend magnitude + std = series.std() + if std == 0: + return 0.0 + + return slope * len(series) / std + + def _detect_seasonal_pattern(self, series: pd.Series) -> bool: + """ + Detect if a series exhibits seasonal patterns using autocorrelation peaks. + + Returns True if significant periodic pattern is found. + """ + if len(series) < 30: + return False + + try: + values = series.dropna().values + # Compute autocorrelation + n = len(values) + mean = np.mean(values) + var = np.var(values) + + if var == 0: + return False + + autocorr = np.correlate(values - mean, values - mean, mode='full') + autocorr = autocorr[n - 1:] / (var * n) + + # Look for significant peaks in autocorrelation (beyond lag 5) + if len(autocorr) > 10: + peaks, properties = find_peaks(autocorr[5:], height=0.3) + return len(peaks) > 0 + + except Exception: + pass + + return False + + def _adjust_probs_with_conditions( + self, + probs: pd.Series, + market_conditions: Dict + ) -> pd.Series: + """ + Adjust transition probabilities based on current market conditions. 
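+
+        Example (illustrative; this helper is called internally by
+        predict_next_regime): with volatility above 0.7, probability mass
+        shifts toward the crisis/volatile regimes before renormalization:
+
+            >>> probs = pd.Series({'normal': 0.6, 'crisis': 0.3, 'volatile': 0.1})
+            >>> detector._adjust_probs_with_conditions(probs, {'volatility': 0.8})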
+ + Args: + probs: Base transition probabilities + market_conditions: Dict with keys like 'volatility', 'trend', 'volume' + + Returns: + Adjusted probability series + """ + adjusted = probs.copy() + + # High volatility increases probability of crisis/volatile regimes + if market_conditions.get('volatility', 0) > 0.7: + for regime in adjusted.index: + regime_str = regime if isinstance(regime, str) else regime.value + if regime_str in ('crisis', 'volatile'): + adjusted[regime] *= 1.5 + elif regime_str == 'normal': + adjusted[regime] *= 0.7 + + # Strong negative trend increases probability of diverging + if market_conditions.get('trend', 0) < -0.5: + for regime in adjusted.index: + regime_str = regime if isinstance(regime, str) else regime.value + if regime_str == 'diverging': + adjusted[regime] *= 1.3 + + # Low volatility increases probability of normal/decoupling + if market_conditions.get('volatility', 0) < 0.3: + for regime in adjusted.index: + regime_str = regime if isinstance(regime, str) else regime.value + if regime_str in ('normal', 'decoupling'): + adjusted[regime] *= 1.3 + + # Renormalize to sum to 1 + total = adjusted.sum() + if total > 0: + adjusted = adjusted / total + + return adjusted \ No newline at end of file diff --git a/tradingagents/cross_asset_correlation/multi_asset_processor.py b/tradingagents/cross_asset_correlation/multi_asset_processor.py new file mode 100644 index 00000000..b794fe8e --- /dev/null +++ b/tradingagents/cross_asset_correlation/multi_asset_processor.py @@ -0,0 +1,465 @@ +""" +Multi-Asset Processor +Handles processing of multiple asset classes and data sources. + +Supports: +- Multiple asset classes (stocks, ETFs, commodities, currencies, crypto) +- Different data frequencies (daily, hourly, minute) +- Missing data imputation +- Returns calculation and normalization +- Asset classification and grouping + +Based on research in multi-asset portfolio optimization. +""" + +import pandas as pd +import numpy as np +from typing import Dict, List, Optional, Tuple, Union +from dataclasses import dataclass +from enum import Enum +import warnings +from datetime import datetime, timedelta + + +class AssetClass(Enum): + """Asset classification categories.""" + STOCK = "stock" + ETF = "etf" + COMMODITY = "commodity" + CURRENCY = "currency" + CRYPTO = "cryptocurrency" + BOND = "bond" + INDEX = "index" + FUTURE = "future" + OPTION = "option" + + +class DataFrequency(Enum): + """Data frequency options.""" + DAILY = "daily" + HOURLY = "hourly" + MINUTE_30 = "30min" + MINUTE_15 = "15min" + MINUTE_5 = "5min" + MINUTE_1 = "1min" + + +@dataclass +class AssetMetadata: + """Metadata for financial assets.""" + symbol: str + name: str + asset_class: AssetClass + sector: Optional[str] = None + industry: Optional[str] = None + country: Optional[str] = None + currency: Optional[str] = None + market_cap: Optional[float] = None + volume_avg: Optional[float] = None + data_source: Optional[str] = None + + +class MultiAssetProcessor: + """Processor for handling multiple asset data.""" + + def __init__(self, default_frequency: DataFrequency = DataFrequency.DAILY): + """ + Initialize multi-asset processor. + + Args: + default_frequency: Default data frequency for processing + """ + self.default_frequency = default_frequency + self.asset_metadata: Dict[str, AssetMetadata] = {} + + def load_price_data( + self, + price_data: pd.DataFrame, + metadata: Optional[Dict[str, AssetMetadata]] = None + ) -> pd.DataFrame: + """ + Load and validate price data for multiple assets. 
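+
+        Example (illustrative; ``raw`` is assumed to be a DataFrame of
+        close prices with one column per symbol):
+
+            >>> processor = MultiAssetProcessor()
+            >>> prices = processor.load_price_data(raw)
+            >>> returns = processor.calculate_returns(prices, return_type="log")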
+ + Args: + price_data: DataFrame with assets as columns and prices as rows + metadata: Optional metadata for each asset + + Returns: + Cleaned and validated price DataFrame + """ + # Validate input + if price_data.empty: + raise ValueError("Price data is empty") + + if price_data.isnull().all().all(): + raise ValueError("All price data is NaN") + + # Store metadata if provided + if metadata: + self.asset_metadata.update(metadata) + + # Fill missing metadata + for symbol in price_data.columns: + if symbol not in self.asset_metadata: + self.asset_metadata[symbol] = AssetMetadata( + symbol=symbol, + name=symbol, + asset_class=self._infer_asset_class(symbol) + ) + + # Clean data + cleaned_data = self._clean_price_data(price_data) + + return cleaned_data + + def calculate_returns( + self, + price_data: pd.DataFrame, + return_type: str = "log", + fill_na: bool = True + ) -> pd.DataFrame: + """ + Calculate returns from price data. + + Args: + price_data: Price DataFrame + return_type: Type of returns ('log' or 'simple') + fill_na: Whether to fill NaN returns with zeros + + Returns: + Returns DataFrame + """ + if return_type == "log": + returns = np.log(price_data / price_data.shift(1)) + elif return_type == "simple": + returns = price_data.pct_change() + else: + raise ValueError(f"Unknown return type: {return_type}") + + # Remove first row (NaN) + returns = returns.iloc[1:] + + if fill_na: + returns = returns.fillna(0) + + return returns + + def resample_data( + self, + data: pd.DataFrame, + target_frequency: DataFrequency, + aggregation: str = "last" + ) -> pd.DataFrame: + """ + Resample data to different frequency. + + Args: + data: Input DataFrame with datetime index + target_frequency: Target frequency for resampling + aggregation: Aggregation method ('last', 'mean', 'ohlc') + + Returns: + Resampled DataFrame + """ + if not isinstance(data.index, pd.DatetimeIndex): + raise ValueError("Data must have DatetimeIndex for resampling") + + # Map frequency to pandas offset + freq_map = { + DataFrequency.DAILY: 'D', + DataFrequency.HOURLY: 'H', + DataFrequency.MINUTE_30: '30min', + DataFrequency.MINUTE_15: '15min', + DataFrequency.MINUTE_5: '5min', + DataFrequency.MINUTE_1: '1min' + } + + freq_str = freq_map.get(target_frequency) + if not freq_str: + raise ValueError(f"Unsupported frequency: {target_frequency}") + + if aggregation == "last": + resampled = data.resample(freq_str).last() + elif aggregation == "mean": + resampled = data.resample(freq_str).mean() + elif aggregation == "ohlc": + # For OHLC resampling + resampled = pd.DataFrame() + for col in data.columns: + ohlc = data[col].resample(freq_str).ohlc() + resampled[f"{col}_open"] = ohlc['open'] + resampled[f"{col}_high"] = ohlc['high'] + resampled[f"{col}_low"] = ohlc['low'] + resampled[f"{col}_close"] = ohlc['close'] + else: + raise ValueError(f"Unknown aggregation: {aggregation}") + + return resampled.dropna() + + def align_time_series( + self, + dataframes: List[pd.DataFrame], + method: str = "inner" + ) -> List[pd.DataFrame]: + """ + Align multiple time series to common index. 
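+
+        Example (illustrative): an "inner" alignment keeps only the
+        timestamps present in every input frame:
+
+            >>> df_a_aligned, df_b_aligned = processor.align_time_series(
+            ...     [df_a, df_b], method="inner"
+            ... )
+            >>> df_a_aligned.index.equals(df_b_aligned.index)
+            True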
+ + Args: + dataframes: List of DataFrames to align + method: Alignment method ('inner' or 'outer') + + Returns: + List of aligned DataFrames + """ + if not dataframes: + return [] + + # Get common index + indices = [df.index for df in dataframes] + + if method == "inner": + common_index = indices[0] + for idx in indices[1:]: + common_index = common_index.intersection(idx) + elif method == "outer": + common_index = indices[0] + for idx in indices[1:]: + common_index = common_index.union(idx) + else: + raise ValueError(f"Unknown alignment method: {method}") + + # Align each DataFrame + aligned_dfs = [] + for df in dataframes: + aligned = df.reindex(common_index) + aligned_dfs.append(aligned) + + return aligned_dfs + + def detect_asset_groups( + self, + correlation_matrix: pd.DataFrame, + threshold: float = 0.7, + method: str = "hierarchical" + ) -> List[List[str]]: + """ + Detect groups of highly correlated assets. + + Args: + correlation_matrix: Correlation matrix DataFrame + threshold: Correlation threshold for grouping + method: Grouping method ('hierarchical' or 'connected_components') + + Returns: + List of asset groups (lists of symbols) + """ + if method == "hierarchical": + return self._hierarchical_clustering(correlation_matrix, threshold) + elif method == "connected_components": + return self._connected_components(correlation_matrix, threshold) + else: + raise ValueError(f"Unknown grouping method: {method}") + + def calculate_correlation_stability( + self, + price_data: pd.DataFrame, + window: int = 60, + step: int = 5 + ) -> pd.DataFrame: + """ + Calculate stability of correlations over time. + + Args: + price_data: Price DataFrame + window: Rolling window size + step: Step between windows + + Returns: + DataFrame with correlation stability metrics + """ + returns = self.calculate_returns(price_data) + n_periods = len(returns) + + stability_metrics = {} + + for i in range(0, n_periods - window, step): + window_data = returns.iloc[i:i + window] + corr_matrix = window_data.corr() + + # Flatten correlation matrix (excluding diagonal) + corr_values = corr_matrix.values[np.triu_indices_from(corr_matrix, k=1)] + + # Calculate stability metrics for this window + stability_metrics[f"window_{i}"] = { + 'mean_correlation': np.mean(corr_values), + 'std_correlation': np.std(corr_values), + 'min_correlation': np.min(corr_values), + 'max_correlation': np.max(corr_values), + 'positive_ratio': np.sum(corr_values > 0) / len(corr_values), + 'start_date': returns.index[i], + 'end_date': returns.index[i + window - 1] + } + + return pd.DataFrame(stability_metrics).T + + def get_asset_class_correlations( + self, + price_data: pd.DataFrame + ) -> pd.DataFrame: + """ + Calculate average correlations within and between asset classes. 
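+
+        Example (illustrative; symbols without registered metadata default
+        to AssetClass.STOCK, so load metadata first for meaningful groups):
+
+            >>> cc = processor.get_asset_class_correlations(prices)
+            >>> cc.loc[AssetClass.STOCK, AssetClass.COMMODITY]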
+ + Args: + price_data: Price DataFrame with assets as columns + + Returns: + DataFrame of asset class correlations + """ + # Get returns + returns = self.calculate_returns(price_data) + + # Get asset classes + asset_classes = {} + for symbol in returns.columns: + if symbol in self.asset_metadata: + asset_classes[symbol] = self.asset_metadata[symbol].asset_class + else: + asset_classes[symbol] = AssetClass.STOCK + + # Calculate correlation matrix + corr_matrix = returns.corr() + + # Group by asset class + unique_classes = set(asset_classes.values()) + class_correlations = pd.DataFrame( + index=unique_classes, + columns=unique_classes + ) + + for class1 in unique_classes: + for class2 in unique_classes: + # Get assets in each class + assets1 = [a for a, c in asset_classes.items() if c == class1] + assets2 = [a for a, c in asset_classes.items() if c == class2] + + if not assets1 or not assets2: + class_correlations.loc[class1, class2] = np.nan + continue + + # Get correlations between classes using vectorized operations + sub_matrix = corr_matrix.loc[assets1, assets2].values + + if class1 == class2: + # For intra-class correlation, use upper triangle (excluding diagonal) + if len(assets1) > 1: + corr_values = sub_matrix[np.triu_indices_from(sub_matrix, k=1)] + else: + corr_values = np.array([]) + else: + # For inter-class correlation, use the whole sub-matrix + corr_values = sub_matrix.flatten() + + # Calculate average correlation, ignoring NaNs + if corr_values.size > 0: + class_correlations.loc[class1, class2] = np.nanmean(corr_values) + else: + class_correlations.loc[class1, class2] = np.nan + + return class_correlations + + def _clean_price_data(self, price_data: pd.DataFrame) -> pd.DataFrame: + """Clean and validate price data.""" + # Remove columns with all NaN + price_data = price_data.dropna(axis=1, how='all') + + # Forward fill for small gaps (up to 2 periods) + price_data = price_data.ffill(limit=2) + + # Remove any remaining NaN + price_data = price_data.dropna() + + # Ensure prices are positive + if (price_data <= 0).any().any(): + warnings.warn("Non-positive prices found. 
Replacing with NaN.") + price_data[price_data <= 0] = np.nan + + return price_data + + def _infer_asset_class(self, symbol: str) -> AssetClass: + """Infer asset class from symbol.""" + symbol_lower = symbol.lower() + + # Common patterns + if any(x in symbol_lower for x in ['.us', '.nyse', '.nasdaq']): + return AssetClass.STOCK + elif symbol_lower.endswith('=x'): + return AssetClass.CURRENCY + elif any(x in symbol_lower for x in ['btc', 'eth', 'xrp', 'crypto']): + return AssetClass.CRYPTO + elif any(x in symbol_lower for x in ['etf', 'ivv', 'spy', 'qqq']): + return AssetClass.ETF + elif any(x in symbol_lower for x in ['gold', 'silver', 'oil', 'commodity']): + return AssetClass.COMMODITY + elif any(x in symbol_lower for x in ['^', 'index', '.indx']): + return AssetClass.INDEX + else: + return AssetClass.STOCK # Default + + def _hierarchical_clustering( + self, + correlation_matrix: pd.DataFrame, + threshold: float + ) -> List[List[str]]: + """Group assets using hierarchical clustering.""" + from scipy.cluster.hierarchy import linkage, fcluster + from scipy.spatial.distance import squareform + + # Convert correlation to distance (1 - |correlation|) + distance_matrix = 1 - np.abs(correlation_matrix.values) + + # Perform hierarchical clustering + linkage_matrix = linkage(squareform(distance_matrix), method='average') + + # Form clusters at threshold + clusters = fcluster(linkage_matrix, 1 - threshold, criterion='distance') + + # Group assets by cluster + asset_groups = {} + for asset, cluster_id in zip(correlation_matrix.columns, clusters): + if cluster_id not in asset_groups: + asset_groups[cluster_id] = [] + asset_groups[cluster_id].append(asset) + + return list(asset_groups.values()) + + def _connected_components( + self, + correlation_matrix: pd.DataFrame, + threshold: float + ) -> List[List[str]]: + """Group assets using graph connected components.""" + import networkx as nx + + # Create graph + G = nx.Graph() + + # Add nodes (assets) + for asset in correlation_matrix.columns: + G.add_node(asset) + + # Add edges for high correlations + n_assets = len(correlation_matrix.columns) + for i in range(n_assets): + for j in range(i + 1, n_assets): + asset_i = correlation_matrix.columns[i] + asset_j = correlation_matrix.columns[j] + corr = abs(correlation_matrix.iloc[i, j]) + + if corr >= threshold: + G.add_edge(asset_i, asset_j) + + # Find connected components + components = list(nx.connected_components(G)) + + # Convert to lists + return [list(comp) for comp in components] \ No newline at end of file