Phase 2 changes
This commit is contained in:
parent
af82fd9601
commit
54a3395b37
|
|
@ -5,7 +5,12 @@ from tradingagents.agents.utils.agent_utils import get_stock_data, get_indicator
|
||||||
from tradingagents.dataflows.config import get_config
|
from tradingagents.dataflows.config import get_config
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
from tradingagents.engines.regime_detector import RegimeDetector, DynamicIndicatorSelector
|
||||||
from tradingagents.utils.anonymizer import TickerAnonymizer
|
from tradingagents.utils.anonymizer import TickerAnonymizer
|
||||||
|
import pandas as pd
|
||||||
|
from io import StringIO
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
# Initialize anonymizer (shared instance appropriate here or inside)
|
# Initialize anonymizer (shared instance appropriate here or inside)
|
||||||
anonymizer = TickerAnonymizer()
|
anonymizer = TickerAnonymizer()
|
||||||
|
|
@ -23,6 +28,51 @@ def create_market_analyst(llm):
|
||||||
|
|
||||||
# NOTE: We continue to use 'ticker' variable name but it now holds 'ASSET_XXX'
|
# NOTE: We continue to use 'ticker' variable name but it now holds 'ASSET_XXX'
|
||||||
|
|
||||||
|
# REGIME DETECTION LOGIC
|
||||||
|
regime_val = "UNKNOWN"
|
||||||
|
metrics = {}
|
||||||
|
optimal_params = {}
|
||||||
|
regime_context = "REGIME DETECTION FAILED or DATA UNAVAILABLE"
|
||||||
|
volatility_score = 0.0
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Calculate start date (1 year lookback for robust regime detection)
|
||||||
|
dt_obj = datetime.strptime(current_date, "%Y-%m-%d")
|
||||||
|
start_date = (dt_obj - timedelta(days=365)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
# Fetch data for regime detection using the anonymized ticker
|
||||||
|
# This calls the tool which handles deanonymization internally if needed
|
||||||
|
# (assuming core_stock_tools.get_stock_data handles the 'ASSET_XXX' -> Real mapping)
|
||||||
|
# Use invoke for StructuredTool with ALL required args
|
||||||
|
raw_data = get_stock_data.invoke({
|
||||||
|
"symbol": ticker,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": current_date,
|
||||||
|
"format": "csv"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Parse data
|
||||||
|
if isinstance(raw_data, str) and "Error" not in raw_data and "No data" not in raw_data:
|
||||||
|
# Parse data (Standardized CSV format with # comments)
|
||||||
|
df = pd.read_csv(StringIO(raw_data), comment='#')
|
||||||
|
|
||||||
|
# Check for Close column
|
||||||
|
if 'Close' in df.columns:
|
||||||
|
# Detect Regime
|
||||||
|
regime, metrics = RegimeDetector.detect_regime(df['Close'])
|
||||||
|
optimal_params = DynamicIndicatorSelector.get_optimal_parameters(regime)
|
||||||
|
regime_val = regime.value
|
||||||
|
volatility_score = metrics.get("volatility", 0.0)
|
||||||
|
|
||||||
|
# Construct Context String
|
||||||
|
regime_context = f"MARKET REGIME DETECTED: {regime_val}\n"
|
||||||
|
regime_context += f"METRICS: {json.dumps(metrics)}\n"
|
||||||
|
regime_context += f"RECOMMENDED STRATEGY: {optimal_params.get('strategy', 'N/A')}\n"
|
||||||
|
regime_context += f"RECOMMENDED INDICATORS: {json.dumps(optimal_params)}\n"
|
||||||
|
regime_context += f"RATIONALE: {optimal_params.get('rationale', '')}"
|
||||||
|
except Exception as e:
|
||||||
|
print(f"WARNING: Regime detection failed for {ticker}: {e}")
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
get_stock_data,
|
get_stock_data,
|
||||||
get_indicators,
|
get_indicators,
|
||||||
|
|
@ -36,7 +86,14 @@ CRITICAL DATA CONSTRAINT:
|
||||||
2. "Price 105.0" means +5% gain from start. It does NOT mean $105.00.
|
2. "Price 105.0" means +5% gain from start. It does NOT mean $105.00.
|
||||||
3. DO NOT hallucinate real-world ticker prices. Treat this as a pure mathematical time series.
|
3. DO NOT hallucinate real-world ticker prices. Treat this as a pure mathematical time series.
|
||||||
|
|
||||||
TASK: Select relevant indicators and analyze trends. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
|
DYNAMIC MARKET REGIME CONTEXT:
|
||||||
|
{regime_context}
|
||||||
|
|
||||||
|
TASK: Select relevant indicators and analyze trends.
|
||||||
|
Your role is to select the **most relevant indicators** for the DETECTED REGIME ({regime_val}).
|
||||||
|
The goal is to choose up to **8 indicators** that provide complementary insights without redundancy.
|
||||||
|
|
||||||
|
INDICATOR CATEGORIES:
|
||||||
|
|
||||||
Moving Averages:
|
Moving Averages:
|
||||||
- close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.
|
- close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.
|
||||||
|
|
@ -98,6 +155,9 @@ Volume-Based Indicators:
|
||||||
return {
|
return {
|
||||||
"messages": [result],
|
"messages": [result],
|
||||||
"market_report": report,
|
"market_report": report,
|
||||||
|
"market_regime": regime_val,
|
||||||
|
"regime_metrics": metrics,
|
||||||
|
"volatility_score": volatility_score
|
||||||
}
|
}
|
||||||
|
|
||||||
return market_analyst_node
|
return market_analyst_node
|
||||||
|
|
|
||||||
|
|
@ -63,6 +63,7 @@ class AgentState(MessagesState):
|
||||||
|
|
||||||
# regime data
|
# regime data
|
||||||
market_regime: Annotated[str, "Current Market Regime (e.g. VOLATILE, TRENDING_UP)"]
|
market_regime: Annotated[str, "Current Market Regime (e.g. VOLATILE, TRENDING_UP)"]
|
||||||
|
regime_metrics: Annotated[dict, "Metrics used to determine regime"]
|
||||||
volatility_score: Annotated[float, "Current Volatility Score"]
|
volatility_score: Annotated[float, "Current Volatility Score"]
|
||||||
|
|
||||||
# researcher team discussion step
|
# researcher team discussion step
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ def get_stock_data(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||||
|
format: Annotated[str, "Output format 'csv' or 'string' (default: 'string')"] = "string"
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Retrieve stock price data (OHLCV) for a given ticker symbol.
|
Retrieve stock price data (OHLCV) for a given ticker symbol.
|
||||||
|
|
@ -16,6 +17,7 @@ def get_stock_data(
|
||||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||||
start_date (str): Start date in yyyy-mm-dd format
|
start_date (str): Start date in yyyy-mm-dd format
|
||||||
end_date (str): End date in yyyy-mm-dd format
|
end_date (str): End date in yyyy-mm-dd format
|
||||||
|
format (str): 'csv' or 'string'
|
||||||
Returns:
|
Returns:
|
||||||
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
|
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
|
||||||
"""
|
"""
|
||||||
|
|
@ -28,7 +30,7 @@ def get_stock_data(
|
||||||
real_ticker = symbol # Fallback if not anonymized
|
real_ticker = symbol # Fallback if not anonymized
|
||||||
|
|
||||||
# 2. Get Data using Real Ticker
|
# 2. Get Data using Real Ticker
|
||||||
raw_data = route_to_vendor("get_stock_data", real_ticker, start_date, end_date)
|
raw_data = route_to_vendor("get_stock_data", real_ticker, start_date, end_date, format=format)
|
||||||
|
|
||||||
# 3. Anonymize Output (AAPL -> ASSET_XXX)
|
# 3. Anonymize Output (AAPL -> ASSET_XXX)
|
||||||
anonymized_data = anonymizer.anonymize_text(raw_data, real_ticker)
|
anonymized_data = anonymizer.anonymize_text(raw_data, real_ticker)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import pandas as pd
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
def get_stock_data(symbol: str, start_date: str = None, end_date: str = None) -> str:
|
def get_stock_data(symbol: str, start_date: str = None, end_date: str = None, format: str = "string") -> str:
|
||||||
"""
|
"""
|
||||||
Fetch historical stock data (OHLCV) from Alpaca Data API v2.
|
Fetch historical stock data (OHLCV) from Alpaca Data API v2.
|
||||||
|
|
||||||
|
|
@ -12,6 +12,7 @@ def get_stock_data(symbol: str, start_date: str = None, end_date: str = None) ->
|
||||||
symbol: Ticker symbol (e.g., "AAPL")
|
symbol: Ticker symbol (e.g., "AAPL")
|
||||||
start_date: Start date (YYYY-MM-DD), defaults to 1 year ago
|
start_date: Start date (YYYY-MM-DD), defaults to 1 year ago
|
||||||
end_date: End date (YYYY-MM-DD), defaults to today
|
end_date: End date (YYYY-MM-DD), defaults to today
|
||||||
|
format: Output format "string" (human readable) or "csv" (machine readable). Defaults to "string".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
String representation of the dataframe
|
String representation of the dataframe
|
||||||
|
|
@ -84,8 +85,12 @@ def get_stock_data(symbol: str, start_date: str = None, end_date: str = None) ->
|
||||||
df = pd.DataFrame(df_data)
|
df = pd.DataFrame(df_data)
|
||||||
|
|
||||||
# Format output string similar to yfinance output for consistency
|
# Format output string similar to yfinance output for consistency
|
||||||
result_str = f"Stock Data for {symbol} from {start_date} to {end_date}\n"
|
result_str = f"# Stock Data for {symbol} from {start_date} to {end_date}\n"
|
||||||
result_str += df.to_string(index=False)
|
|
||||||
|
if format.lower() == "csv":
|
||||||
|
result_str += df.to_csv(index=False)
|
||||||
|
else:
|
||||||
|
result_str += df.to_string(index=False)
|
||||||
|
|
||||||
return result_str
|
return result_str
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ def get_YFin_data_online(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||||
|
format: Annotated[str, "Output format 'csv' or 'string'"] = "csv"
|
||||||
):
|
):
|
||||||
|
|
||||||
datetime.strptime(start_date, "%Y-%m-%d")
|
datetime.strptime(start_date, "%Y-%m-%d")
|
||||||
|
|
@ -36,15 +37,20 @@ def get_YFin_data_online(
|
||||||
if col in data.columns:
|
if col in data.columns:
|
||||||
data[col] = data[col].round(2)
|
data[col] = data[col].round(2)
|
||||||
|
|
||||||
# Convert DataFrame to CSV string
|
# Convert DataFrame to string based on format
|
||||||
csv_string = data.to_csv()
|
if format.lower() == 'string':
|
||||||
|
# Use to_string for human readability
|
||||||
|
result_string = data.to_string()
|
||||||
|
else:
|
||||||
|
# Default to CSV
|
||||||
|
result_string = data.to_csv()
|
||||||
|
|
||||||
# Add header information
|
# Add header information
|
||||||
header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n"
|
header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n"
|
||||||
header += f"# Total records: {len(data)}\n"
|
header += f"# Total records: {len(data)}\n"
|
||||||
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||||
|
|
||||||
return header + csv_string
|
return header + result_string
|
||||||
|
|
||||||
def get_stock_stats_indicators_window(
|
def get_stock_stats_indicators_window(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,63 @@
|
||||||
|
import pandas as pd
|
||||||
|
from tradingagents.engines.regime_detector import RegimeDetector, DynamicIndicatorSelector
|
||||||
|
from tradingagents.agents.utils.agent_utils import get_stock_data
|
||||||
|
from io import StringIO
|
||||||
|
import json
|
||||||
|
|
||||||
|
def verify_regime():
|
||||||
|
print("🔬 Verifying Regime Detector Integration...")
|
||||||
|
|
||||||
|
ticker = "AAPL"
|
||||||
|
print(f"Fetching data for {ticker}...")
|
||||||
|
|
||||||
|
# Simulate what market_analyst_node does
|
||||||
|
try:
|
||||||
|
# Use invoke for StructuredTool
|
||||||
|
# Provide dates (using a recent 1-year window relative to now implicitly, or fixed dates if tool supports it)
|
||||||
|
# Assuming Alpaca data is available for this range
|
||||||
|
raw_data = get_stock_data.invoke({
|
||||||
|
"symbol": ticker,
|
||||||
|
"start_date": "2024-01-01",
|
||||||
|
"end_date": "2024-12-31",
|
||||||
|
"format": "csv"
|
||||||
|
})
|
||||||
|
|
||||||
|
if "Error" in raw_data:
|
||||||
|
print(f"❌ FAIL: Data fetch error: {raw_data}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Data has '#' comments in header, and is standard CSV
|
||||||
|
df = pd.read_csv(StringIO(raw_data), comment='#')
|
||||||
|
print(f"✅ Data fetched: {len(df)} rows")
|
||||||
|
print(f"COLUMNS: {df.columns.tolist()}")
|
||||||
|
print(f"HEAD:\n{df.head()}")
|
||||||
|
|
||||||
|
if 'Close' not in df.columns:
|
||||||
|
# Try case insensitive or check if it's in index?
|
||||||
|
# Sometimes to_string creates a weird header structure
|
||||||
|
pass
|
||||||
|
print("❌ FAIL: 'Close' column missing")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Run Detector
|
||||||
|
print("Running RegimeDetector...")
|
||||||
|
regime, metrics = RegimeDetector.detect_regime(df['Close'])
|
||||||
|
|
||||||
|
print(f"✅ DETECTED REGIME: {regime.value}")
|
||||||
|
print(f" Volatility: {metrics['volatility']:.2%}")
|
||||||
|
print(f" Trend Strength: {metrics['trend_strength']:.2f}")
|
||||||
|
|
||||||
|
# Run Selector
|
||||||
|
optimal_params = DynamicIndicatorSelector.get_optimal_parameters(regime)
|
||||||
|
print(f"✅ RECOMMENDED STRATEGY: {optimal_params['strategy']}")
|
||||||
|
print(f" Indicators: {[k for k in optimal_params.keys() if 'period' in k]}")
|
||||||
|
|
||||||
|
print("\n🎉 INTEGRATION VERIFIED: The engine is analyzing data correctly.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ FAIL: Exception: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
verify_regime()
|
||||||
Loading…
Reference in New Issue