diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index 95de76ae..b23d21fe 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -5,7 +5,12 @@ from tradingagents.agents.utils.agent_utils import get_stock_data, get_indicator from tradingagents.dataflows.config import get_config + +from tradingagents.engines.regime_detector import RegimeDetector, DynamicIndicatorSelector from tradingagents.utils.anonymizer import TickerAnonymizer +import pandas as pd +from io import StringIO +from datetime import datetime, timedelta # Initialize anonymizer (shared instance appropriate here or inside) anonymizer = TickerAnonymizer() @@ -23,6 +28,51 @@ def create_market_analyst(llm): # NOTE: We continue to use 'ticker' variable name but it now holds 'ASSET_XXX' + # REGIME DETECTION LOGIC + regime_val = "UNKNOWN" + metrics = {} + optimal_params = {} + regime_context = "REGIME DETECTION FAILED or DATA UNAVAILABLE" + volatility_score = 0.0 + + try: + # Calculate start date (1 year lookback for robust regime detection) + dt_obj = datetime.strptime(current_date, "%Y-%m-%d") + start_date = (dt_obj - timedelta(days=365)).strftime("%Y-%m-%d") + + # Fetch data for regime detection using the anonymized ticker + # This calls the tool which handles deanonymization internally if needed + # (assuming core_stock_tools.get_stock_data handles the 'ASSET_XXX' -> Real mapping) + # Use invoke for StructuredTool with ALL required args + raw_data = get_stock_data.invoke({ + "symbol": ticker, + "start_date": start_date, + "end_date": current_date, + "format": "csv" + }) + + # Parse data + if isinstance(raw_data, str) and "Error" not in raw_data and "No data" not in raw_data: + # Parse data (Standardized CSV format with # comments) + df = pd.read_csv(StringIO(raw_data), comment='#') + + # Check for Close column + if 'Close' in df.columns: + # Detect Regime + regime, metrics = 
RegimeDetector.detect_regime(df['Close']) + optimal_params = DynamicIndicatorSelector.get_optimal_parameters(regime) + regime_val = regime.value + volatility_score = metrics.get("volatility", 0.0) + + # Construct Context String + regime_context = f"MARKET REGIME DETECTED: {regime_val}\n" + regime_context += f"METRICS: {json.dumps(metrics)}\n" + regime_context += f"RECOMMENDED STRATEGY: {optimal_params.get('strategy', 'N/A')}\n" + regime_context += f"RECOMMENDED INDICATORS: {json.dumps(optimal_params)}\n" + regime_context += f"RATIONALE: {optimal_params.get('rationale', '')}" + except Exception as e: + print(f"WARNING: Regime detection failed for {ticker}: {e}") + tools = [ get_stock_data, get_indicators, @@ -36,7 +86,14 @@ CRITICAL DATA CONSTRAINT: 2. "Price 105.0" means +5% gain from start. It does NOT mean $105.00. 3. DO NOT hallucinate real-world ticker prices. Treat this as a pure mathematical time series. -TASK: Select relevant indicators and analyze trends. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are: +DYNAMIC MARKET REGIME CONTEXT: +{regime_context} + +TASK: Select relevant indicators and analyze trends. +Your role is to select the **most relevant indicators** for the DETECTED REGIME ({regime_val}). +The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. + +INDICATOR CATEGORIES: Moving Averages: - close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals. 
@@ -98,6 +155,9 @@ Volume-Based Indicators: return { "messages": [result], "market_report": report, + "market_regime": regime_val, + "regime_metrics": metrics, + "volatility_score": volatility_score } return market_analyst_node diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index 271e212f..f3f830e3 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -63,6 +63,7 @@ class AgentState(MessagesState): # regime data market_regime: Annotated[str, "Current Market Regime (e.g. VOLATILE, TRENDING_UP)"] + regime_metrics: Annotated[dict, "Metrics used to determine regime"] volatility_score: Annotated[float, "Current Volatility Score"] # researcher team discussion step diff --git a/tradingagents/agents/utils/core_stock_tools.py b/tradingagents/agents/utils/core_stock_tools.py index 5092f7d3..77b0bf66 100644 --- a/tradingagents/agents/utils/core_stock_tools.py +++ b/tradingagents/agents/utils/core_stock_tools.py @@ -8,6 +8,7 @@ def get_stock_data( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"], + format: Annotated[str, "Output format 'csv' or 'string' (default: 'string')"] = "string" ) -> str: """ Retrieve stock price data (OHLCV) for a given ticker symbol. @@ -16,6 +17,7 @@ def get_stock_data( symbol (str): Ticker symbol of the company, e.g. AAPL, TSM start_date (str): Start date in yyyy-mm-dd format end_date (str): End date in yyyy-mm-dd format + format (str): 'csv' or 'string' Returns: str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range. """ @@ -28,7 +30,7 @@ def get_stock_data( real_ticker = symbol # Fallback if not anonymized # 2. 
Get Data using Real Ticker - raw_data = route_to_vendor("get_stock_data", real_ticker, start_date, end_date) + raw_data = route_to_vendor("get_stock_data", real_ticker, start_date, end_date, format=format) # 3. Anonymize Output (AAPL -> ASSET_XXX) anonymized_data = anonymizer.anonymize_text(raw_data, real_ticker) diff --git a/tradingagents/dataflows/alpaca.py b/tradingagents/dataflows/alpaca.py index c8d6739d..28c4db42 100644 --- a/tradingagents/dataflows/alpaca.py +++ b/tradingagents/dataflows/alpaca.py @@ -4,7 +4,7 @@ import pandas as pd from typing import Optional from datetime import datetime, timedelta -def get_stock_data(symbol: str, start_date: str = None, end_date: str = None) -> str: +def get_stock_data(symbol: str, start_date: str = None, end_date: str = None, format: str = "string") -> str: """ Fetch historical stock data (OHLCV) from Alpaca Data API v2. @@ -12,6 +12,7 @@ def get_stock_data(symbol: str, start_date: str = None, end_date: str = None) -> symbol: Ticker symbol (e.g., "AAPL") start_date: Start date (YYYY-MM-DD), defaults to 1 year ago end_date: End date (YYYY-MM-DD), defaults to today + format: Output format "string" (human readable) or "csv" (machine readable). Defaults to "string". 
Returns: String representation of the dataframe @@ -84,8 +85,12 @@ def get_stock_data(symbol: str, start_date: str = None, end_date: str = None) -> df = pd.DataFrame(df_data) # Format output string similar to yfinance output for consistency - result_str = f"Stock Data for {symbol} from {start_date} to {end_date}\n" - result_str += df.to_string(index=False) + result_str = f"# Stock Data for {symbol} from {start_date} to {end_date}\n" + + if format.lower() == "csv": + result_str += df.to_csv(index=False) + else: + result_str += df.to_string(index=False) return result_str diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py index 1fcfd335..e75eae3c 100644 --- a/tradingagents/dataflows/y_finance.py +++ b/tradingagents/dataflows/y_finance.py @@ -9,6 +9,7 @@ def get_YFin_data_online( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"], + format: Annotated[str, "Output format 'csv' or 'string'"] = "csv" ): datetime.strptime(start_date, "%Y-%m-%d") @@ -36,15 +37,20 @@ def get_YFin_data_online( if col in data.columns: data[col] = data[col].round(2) - # Convert DataFrame to CSV string - csv_string = data.to_csv() + # Convert DataFrame to string based on format + if format.lower() == 'string': + # Use to_string for human readability + result_string = data.to_string() + else: + # Default to CSV + result_string = data.to_csv() # Add header information header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n" header += f"# Total records: {len(data)}\n" header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - return header + csv_string + return header + result_string def get_stock_stats_indicators_window( symbol: Annotated[str, "ticker symbol of the company"], diff --git a/verify_regime_integration.py b/verify_regime_integration.py new file mode 100644 index 
00000000..1576cccd
--- /dev/null
+++ b/verify_regime_integration.py
@@ -0,0 +1,75 @@
+import pandas as pd
+from tradingagents.engines.regime_detector import RegimeDetector, DynamicIndicatorSelector
+from tradingagents.agents.utils.agent_utils import get_stock_data
+from io import StringIO
+import json
+import traceback
+
+
+def verify_regime():
+    """Smoke-test the regime detector against data fetched via get_stock_data.
+
+    Mirrors the fetch -> parse -> detect -> select sequence performed by the
+    market analyst node, printing progress along the way.
+
+    Returns:
+        bool: True when every step succeeded, False otherwise.
+    """
+    print("🔬 Verifying Regime Detector Integration...")
+
+    ticker = "AAPL"
+    print(f"Fetching data for {ticker}...")
+
+    try:
+        # get_stock_data is invoked as a tool, so it is called through
+        # .invoke() with every argument supplied by name.
+        raw_data = get_stock_data.invoke({
+            "symbol": ticker,
+            "start_date": "2024-01-01",
+            "end_date": "2024-12-31",
+            "format": "csv",
+        })
+
+        # Same payload checks the market analyst node applies before parsing.
+        if (
+            not isinstance(raw_data, str)
+            or "Error" in raw_data
+            or "No data" in raw_data
+        ):
+            print(f"❌ FAIL: Data fetch error: {raw_data}")
+            return False
+
+        # Lines starting with '#' are skipped; the remainder is parsed as CSV.
+        df = pd.read_csv(StringIO(raw_data), comment='#')
+        print(f"✅ Data fetched: {len(df)} rows")
+        print(f"COLUMNS: {df.columns.tolist()}")
+        print(f"HEAD:\n{df.head()}")
+
+        if 'Close' not in df.columns:
+            print("❌ FAIL: 'Close' column missing")
+            return False
+
+        print("Running RegimeDetector...")
+        regime, metrics = RegimeDetector.detect_regime(df['Close'])
+
+        print(f"✅ DETECTED REGIME: {regime.value}")
+        print(f"   Volatility: {metrics['volatility']:.2%}")
+        print(f"   Trend Strength: {metrics['trend_strength']:.2f}")
+
+        optimal_params = DynamicIndicatorSelector.get_optimal_parameters(regime)
+        print(f"✅ RECOMMENDED STRATEGY: {optimal_params['strategy']}")
+        print(f"   Indicators: {[k for k in optimal_params if 'period' in k]}")
+
+        print("\n🎉 INTEGRATION VERIFIED: The engine is analyzing data correctly.")
+        return True
+
+    except Exception as e:
+        # Top-level boundary of a diagnostic script: report and fail.
+        print(f"❌ FAIL: Exception: {e}")
+        traceback.print_exc()
+        return False
+
+
+if __name__ == "__main__":
+    # Non-zero exit status lets shells and CI notice a failed verification.
+    raise SystemExit(0 if verify_regime() else 1)