Phase 2 changes
This commit is contained in:
parent
af82fd9601
commit
54a3395b37
|
|
@ -5,7 +5,12 @@ from tradingagents.agents.utils.agent_utils import get_stock_data, get_indicator
|
|||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
||||
from tradingagents.engines.regime_detector import RegimeDetector, DynamicIndicatorSelector
|
||||
from tradingagents.utils.anonymizer import TickerAnonymizer
|
||||
import pandas as pd
|
||||
from io import StringIO
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Initialize anonymizer (shared instance appropriate here or inside)
|
||||
anonymizer = TickerAnonymizer()
|
||||
|
|
@ -23,6 +28,51 @@ def create_market_analyst(llm):
|
|||
|
||||
# NOTE: We continue to use 'ticker' variable name but it now holds 'ASSET_XXX'
|
||||
|
||||
# REGIME DETECTION LOGIC
|
||||
regime_val = "UNKNOWN"
|
||||
metrics = {}
|
||||
optimal_params = {}
|
||||
regime_context = "REGIME DETECTION FAILED or DATA UNAVAILABLE"
|
||||
volatility_score = 0.0
|
||||
|
||||
try:
|
||||
# Calculate start date (1 year lookback for robust regime detection)
|
||||
dt_obj = datetime.strptime(current_date, "%Y-%m-%d")
|
||||
start_date = (dt_obj - timedelta(days=365)).strftime("%Y-%m-%d")
|
||||
|
||||
# Fetch data for regime detection using the anonymized ticker
|
||||
# This calls the tool which handles deanonymization internally if needed
|
||||
# (assuming core_stock_tools.get_stock_data handles the 'ASSET_XXX' -> Real mapping)
|
||||
# Use invoke for StructuredTool with ALL required args
|
||||
raw_data = get_stock_data.invoke({
|
||||
"symbol": ticker,
|
||||
"start_date": start_date,
|
||||
"end_date": current_date,
|
||||
"format": "csv"
|
||||
})
|
||||
|
||||
# Parse data
|
||||
if isinstance(raw_data, str) and "Error" not in raw_data and "No data" not in raw_data:
|
||||
# Parse data (Standardized CSV format with # comments)
|
||||
df = pd.read_csv(StringIO(raw_data), comment='#')
|
||||
|
||||
# Check for Close column
|
||||
if 'Close' in df.columns:
|
||||
# Detect Regime
|
||||
regime, metrics = RegimeDetector.detect_regime(df['Close'])
|
||||
optimal_params = DynamicIndicatorSelector.get_optimal_parameters(regime)
|
||||
regime_val = regime.value
|
||||
volatility_score = metrics.get("volatility", 0.0)
|
||||
|
||||
# Construct Context String
|
||||
regime_context = f"MARKET REGIME DETECTED: {regime_val}\n"
|
||||
regime_context += f"METRICS: {json.dumps(metrics)}\n"
|
||||
regime_context += f"RECOMMENDED STRATEGY: {optimal_params.get('strategy', 'N/A')}\n"
|
||||
regime_context += f"RECOMMENDED INDICATORS: {json.dumps(optimal_params)}\n"
|
||||
regime_context += f"RATIONALE: {optimal_params.get('rationale', '')}"
|
||||
except Exception as e:
|
||||
print(f"WARNING: Regime detection failed for {ticker}: {e}")
|
||||
|
||||
tools = [
|
||||
get_stock_data,
|
||||
get_indicators,
|
||||
|
|
@ -36,7 +86,14 @@ CRITICAL DATA CONSTRAINT:
|
|||
2. "Price 105.0" means +5% gain from start. It does NOT mean $105.00.
|
||||
3. DO NOT hallucinate real-world ticker prices. Treat this as a pure mathematical time series.
|
||||
|
||||
TASK: Select relevant indicators and analyze trends. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
|
||||
DYNAMIC MARKET REGIME CONTEXT:
|
||||
{regime_context}
|
||||
|
||||
TASK: Select relevant indicators and analyze trends.
|
||||
Your role is to select the **most relevant indicators** for the DETECTED REGIME ({regime_val}).
|
||||
The goal is to choose up to **8 indicators** that provide complementary insights without redundancy.
|
||||
|
||||
INDICATOR CATEGORIES:
|
||||
|
||||
Moving Averages:
|
||||
- close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.
|
||||
|
|
@ -98,6 +155,9 @@ Volume-Based Indicators:
|
|||
return {
|
||||
"messages": [result],
|
||||
"market_report": report,
|
||||
"market_regime": regime_val,
|
||||
"regime_metrics": metrics,
|
||||
"volatility_score": volatility_score
|
||||
}
|
||||
|
||||
return market_analyst_node
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ class AgentState(MessagesState):
|
|||
|
||||
# regime data
|
||||
market_regime: Annotated[str, "Current Market Regime (e.g. VOLATILE, TRENDING_UP)"]
|
||||
regime_metrics: Annotated[dict, "Metrics used to determine regime"]
|
||||
volatility_score: Annotated[float, "Current Volatility Score"]
|
||||
|
||||
# researcher team discussion step
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ def get_stock_data(
|
|||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
format: Annotated[str, "Output format 'csv' or 'string' (default: 'string')"] = "string"
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve stock price data (OHLCV) for a given ticker symbol.
|
||||
|
|
@ -16,6 +17,7 @@ def get_stock_data(
|
|||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
format (str): 'csv' or 'string'
|
||||
Returns:
|
||||
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
|
||||
"""
|
||||
|
|
@ -28,7 +30,7 @@ def get_stock_data(
|
|||
real_ticker = symbol # Fallback if not anonymized
|
||||
|
||||
# 2. Get Data using Real Ticker
|
||||
raw_data = route_to_vendor("get_stock_data", real_ticker, start_date, end_date)
|
||||
raw_data = route_to_vendor("get_stock_data", real_ticker, start_date, end_date, format=format)
|
||||
|
||||
# 3. Anonymize Output (AAPL -> ASSET_XXX)
|
||||
anonymized_data = anonymizer.anonymize_text(raw_data, real_ticker)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import pandas as pd
|
|||
from typing import Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
def get_stock_data(symbol: str, start_date: str = None, end_date: str = None) -> str:
|
||||
def get_stock_data(symbol: str, start_date: str = None, end_date: str = None, format: str = "string") -> str:
|
||||
"""
|
||||
Fetch historical stock data (OHLCV) from Alpaca Data API v2.
|
||||
|
||||
|
|
@ -12,6 +12,7 @@ def get_stock_data(symbol: str, start_date: str = None, end_date: str = None) ->
|
|||
symbol: Ticker symbol (e.g., "AAPL")
|
||||
start_date: Start date (YYYY-MM-DD), defaults to 1 year ago
|
||||
end_date: End date (YYYY-MM-DD), defaults to today
|
||||
format: Output format "string" (human readable) or "csv" (machine readable). Defaults to "string".
|
||||
|
||||
Returns:
|
||||
String representation of the dataframe
|
||||
|
|
@ -84,8 +85,12 @@ def get_stock_data(symbol: str, start_date: str = None, end_date: str = None) ->
|
|||
df = pd.DataFrame(df_data)
|
||||
|
||||
# Format output string similar to yfinance output for consistency
|
||||
result_str = f"Stock Data for {symbol} from {start_date} to {end_date}\n"
|
||||
result_str += df.to_string(index=False)
|
||||
result_str = f"# Stock Data for {symbol} from {start_date} to {end_date}\n"
|
||||
|
||||
if format.lower() == "csv":
|
||||
result_str += df.to_csv(index=False)
|
||||
else:
|
||||
result_str += df.to_string(index=False)
|
||||
|
||||
return result_str
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ def get_YFin_data_online(
|
|||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
format: Annotated[str, "Output format 'csv' or 'string'"] = "csv"
|
||||
):
|
||||
|
||||
datetime.strptime(start_date, "%Y-%m-%d")
|
||||
|
|
@ -36,15 +37,20 @@ def get_YFin_data_online(
|
|||
if col in data.columns:
|
||||
data[col] = data[col].round(2)
|
||||
|
||||
# Convert DataFrame to CSV string
|
||||
csv_string = data.to_csv()
|
||||
# Convert DataFrame to string based on format
|
||||
if format.lower() == 'string':
|
||||
# Use to_string for human readability
|
||||
result_string = data.to_string()
|
||||
else:
|
||||
# Default to CSV
|
||||
result_string = data.to_csv()
|
||||
|
||||
# Add header information
|
||||
header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n"
|
||||
header += f"# Total records: {len(data)}\n"
|
||||
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||
|
||||
return header + csv_string
|
||||
return header + result_string
|
||||
|
||||
def get_stock_stats_indicators_window(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
|
|
|
|||
|
|
@ -0,0 +1,63 @@
|
|||
import pandas as pd
|
||||
from tradingagents.engines.regime_detector import RegimeDetector, DynamicIndicatorSelector
|
||||
from tradingagents.agents.utils.agent_utils import get_stock_data
|
||||
from io import StringIO
|
||||
import json
|
||||
|
||||
def verify_regime():
|
||||
print("🔬 Verifying Regime Detector Integration...")
|
||||
|
||||
ticker = "AAPL"
|
||||
print(f"Fetching data for {ticker}...")
|
||||
|
||||
# Simulate what market_analyst_node does
|
||||
try:
|
||||
# Use invoke for StructuredTool
|
||||
# Provide dates (using a recent 1-year window relative to now implicitly, or fixed dates if tool supports it)
|
||||
# Assuming Alpaca data is available for this range
|
||||
raw_data = get_stock_data.invoke({
|
||||
"symbol": ticker,
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-12-31",
|
||||
"format": "csv"
|
||||
})
|
||||
|
||||
if "Error" in raw_data:
|
||||
print(f"❌ FAIL: Data fetch error: {raw_data}")
|
||||
return
|
||||
|
||||
# Data has '#' comments in header, and is standard CSV
|
||||
df = pd.read_csv(StringIO(raw_data), comment='#')
|
||||
print(f"✅ Data fetched: {len(df)} rows")
|
||||
print(f"COLUMNS: {df.columns.tolist()}")
|
||||
print(f"HEAD:\n{df.head()}")
|
||||
|
||||
if 'Close' not in df.columns:
|
||||
# Try case insensitive or check if it's in index?
|
||||
# Sometimes to_string creates a weird header structure
|
||||
pass
|
||||
print("❌ FAIL: 'Close' column missing")
|
||||
return
|
||||
|
||||
# Run Detector
|
||||
print("Running RegimeDetector...")
|
||||
regime, metrics = RegimeDetector.detect_regime(df['Close'])
|
||||
|
||||
print(f"✅ DETECTED REGIME: {regime.value}")
|
||||
print(f" Volatility: {metrics['volatility']:.2%}")
|
||||
print(f" Trend Strength: {metrics['trend_strength']:.2f}")
|
||||
|
||||
# Run Selector
|
||||
optimal_params = DynamicIndicatorSelector.get_optimal_parameters(regime)
|
||||
print(f"✅ RECOMMENDED STRATEGY: {optimal_params['strategy']}")
|
||||
print(f" Indicators: {[k for k in optimal_params.keys() if 'period' in k]}")
|
||||
|
||||
print("\n🎉 INTEGRATION VERIFIED: The engine is analyzing data correctly.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ FAIL: Exception: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
verify_regime()
|
||||
Loading…
Reference in New Issue