support get price data from alpha vantage
This commit is contained in:
parent
a438acdbbd
commit
86b357d77a
|
|
@ -123,6 +123,9 @@ You will need the OpenAI API for all the agents.
|
|||
```bash
|
||||
export OPENAI_API_KEY=$YOUR_OPENAI_API_KEY
|
||||
```
|
||||
**Optional** Stock Price Data
|
||||
|
||||
Since YFinance may not work, you can add alpha vantage api key in default_config.py.
|
||||
|
||||
### CLI Usage
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ def create_market_analyst(llm, toolkit):
|
|||
tools = [
|
||||
toolkit.get_YFin_data_online,
|
||||
toolkit.get_stockstats_indicators_report_online,
|
||||
toolkit.get_alpha_vantage_data,
|
||||
toolkit.get_alpha_vantage_summary_signals,
|
||||
]
|
||||
else:
|
||||
tools = [
|
||||
|
|
|
|||
|
|
@ -161,6 +161,45 @@ class Toolkit:
|
|||
|
||||
return result_data
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_alpha_vantage_data(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the stock price data for a given ticker symbol from Alpha Vantage.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
|
||||
"""
|
||||
|
||||
result_data = interface.get_alpha_vantage_data(symbol, start_date, end_date)
|
||||
return result_data
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_alpha_vantage_summary_signals(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"]
|
||||
) -> str:
|
||||
"""
|
||||
Calculate technical indicators and signals based on stock price data from Alpha Vantage.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: Formatted technical indicators summary and signals.
|
||||
"""
|
||||
result_data = interface.get_alpha_vantage_summary_signals(symbol, start_date, end_date)
|
||||
return result_data
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_stockstats_indicators_report(
|
||||
|
|
|
|||
|
|
@ -805,3 +805,328 @@ def get_fundamentals_openai(ticker, curr_date):
|
|||
)
|
||||
|
||||
return response.output[1].content[0].text
|
||||
|
||||
def get_alpha_vantage_data(symbol, start_date, end_date):
|
||||
config = get_config()
|
||||
api_key = config["alpha_vantage_api_key"]
|
||||
|
||||
url = "https://www.alphavantage.co/query"
|
||||
params = {
|
||||
'function': 'TIME_SERIES_DAILY',
|
||||
'symbol': symbol,
|
||||
'outputsize': 'full',
|
||||
'apikey': api_key
|
||||
}
|
||||
|
||||
response = requests.get(url, params=params)
|
||||
data = response.json()
|
||||
|
||||
print("API 响应:", list(data.keys())) # 调试信息
|
||||
|
||||
if 'Time Series (Daily)' in data:
|
||||
df = pd.DataFrame(data['Time Series (Daily)']).T
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df.sort_index()
|
||||
|
||||
# 重命名列并转换数据类型
|
||||
df.columns = ['Open', 'High', 'Low', 'Close', 'Volume']
|
||||
for col in df.columns:
|
||||
df[col] = pd.to_numeric(df[col])
|
||||
|
||||
# 过滤日期范围
|
||||
if start_date:
|
||||
df = df[df.index >= start_date]
|
||||
if end_date:
|
||||
df = df[df.index <= end_date]
|
||||
|
||||
# # Add header information
|
||||
# header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n"
|
||||
# header += f"# Total records: {len(data)}\n"
|
||||
# header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||
|
||||
return df
|
||||
else:
|
||||
print("错误或限制:", data)
|
||||
return None
|
||||
|
||||
|
||||
def add_technical_indicators(df):
|
||||
"""
|
||||
为股票数据添加技术指标
|
||||
|
||||
参数:
|
||||
df: DataFrame with columns ['Open', 'High', 'Low', 'Close', 'Volume']
|
||||
|
||||
返回:
|
||||
DataFrame with additional technical indicator columns
|
||||
"""
|
||||
|
||||
# 创建副本以避免修改原数据
|
||||
data = df.copy()
|
||||
|
||||
# 确保数据按日期排序
|
||||
data = data.sort_index()
|
||||
|
||||
# =============================================================================
|
||||
# Moving Averages 移动平均线
|
||||
# =============================================================================
|
||||
|
||||
# 50日简单移动平均线
|
||||
data['close_50_sma'] = data['Close'].rolling(window=50).mean()
|
||||
|
||||
# 200日简单移动平均线
|
||||
data['close_200_sma'] = data['Close'].rolling(window=200).mean()
|
||||
|
||||
# 10日指数移动平均线
|
||||
data['close_10_ema'] = data['Close'].ewm(span=10).mean()
|
||||
|
||||
# =============================================================================
|
||||
# MACD 指标
|
||||
# =============================================================================
|
||||
|
||||
# 计算12日和26日EMA
|
||||
ema_12 = data['Close'].ewm(span=12).mean()
|
||||
ema_26 = data['Close'].ewm(span=26).mean()
|
||||
|
||||
# MACD线 = 12EMA - 26EMA
|
||||
data['macd'] = ema_12 - ema_26
|
||||
|
||||
# MACD信号线 = MACD的9日EMA
|
||||
data['macds'] = data['macd'].ewm(span=9).mean()
|
||||
|
||||
# MACD柱状图 = MACD - 信号线
|
||||
data['macdh'] = data['macd'] - data['macds']
|
||||
|
||||
# =============================================================================
|
||||
# RSI 相对强弱指数
|
||||
# =============================================================================
|
||||
|
||||
def calculate_rsi(prices, period=14):
|
||||
"""计算RSI"""
|
||||
delta = prices.diff()
|
||||
gain = delta.where(delta > 0, 0)
|
||||
loss = -delta.where(delta < 0, 0)
|
||||
|
||||
avg_gain = gain.rolling(window=period).mean()
|
||||
avg_loss = loss.rolling(window=period).mean()
|
||||
|
||||
rs = avg_gain / avg_loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
return rsi
|
||||
|
||||
data['rsi'] = calculate_rsi(data['Close'])
|
||||
|
||||
# =============================================================================
|
||||
# Bollinger Bands 布林带
|
||||
# =============================================================================
|
||||
|
||||
# 布林带中线(20日SMA)
|
||||
data['boll'] = data['Close'].rolling(window=20).mean()
|
||||
|
||||
# 计算20日标准差
|
||||
std_20 = data['Close'].rolling(window=20).std()
|
||||
|
||||
# 布林带上轨 = 中线 + 2倍标准差
|
||||
data['boll_ub'] = data['boll'] + (2 * std_20)
|
||||
|
||||
# 布林带下轨 = 中线 - 2倍标准差
|
||||
data['boll_lb'] = data['boll'] - (2 * std_20)
|
||||
|
||||
# =============================================================================
|
||||
# ATR 平均真实波幅
|
||||
# =============================================================================
|
||||
|
||||
def calculate_atr(high, low, close, period=14):
|
||||
"""计算ATR"""
|
||||
# 真实波幅的三种计算方式
|
||||
tr1 = high - low
|
||||
tr2 = abs(high - close.shift(1))
|
||||
tr3 = abs(low - close.shift(1))
|
||||
|
||||
# 取最大值作为真实波幅
|
||||
true_range = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
|
||||
|
||||
# ATR = 真实波幅的移动平均
|
||||
atr = true_range.rolling(window=period).mean()
|
||||
return atr
|
||||
|
||||
data['atr'] = calculate_atr(data['High'], data['Low'], data['Close'])
|
||||
|
||||
# =============================================================================
|
||||
# VWMA 成交量加权移动平均
|
||||
# =============================================================================
|
||||
|
||||
def calculate_vwma(close, volume, period=20):
|
||||
"""计算成交量加权移动平均"""
|
||||
volume_price = close * volume
|
||||
vwma = volume_price.rolling(window=period).sum() / volume.rolling(window=period).sum()
|
||||
return vwma
|
||||
|
||||
data['vwma'] = calculate_vwma(data['Close'], data['Volume'])
|
||||
|
||||
# =============================================================================
|
||||
# 额外计算一些有用的信号
|
||||
# =============================================================================
|
||||
|
||||
# 金叉死叉信号 (50 SMA vs 200 SMA)
|
||||
data['golden_cross'] = (data['close_50_sma'] > data['close_200_sma']) & \
|
||||
(data['close_50_sma'].shift(1) <= data['close_200_sma'].shift(1))
|
||||
|
||||
data['death_cross'] = (data['close_50_sma'] < data['close_200_sma']) & \
|
||||
(data['close_50_sma'].shift(1) >= data['close_200_sma'].shift(1))
|
||||
|
||||
# MACD信号
|
||||
data['macd_bullish'] = (data['macd'] > data['macds']) & \
|
||||
(data['macd'].shift(1) <= data['macds'].shift(1))
|
||||
|
||||
data['macd_bearish'] = (data['macd'] < data['macds']) & \
|
||||
(data['macd'].shift(1) >= data['macds'].shift(1))
|
||||
|
||||
# RSI信号
|
||||
data['rsi_overbought'] = data['rsi'] > 70
|
||||
data['rsi_oversold'] = data['rsi'] < 30
|
||||
|
||||
# 布林带信号
|
||||
data['price_above_upper_band'] = data['Close'] > data['boll_ub']
|
||||
data['price_below_lower_band'] = data['Close'] < data['boll_lb']
|
||||
|
||||
# 显示统计信息
|
||||
total_indicators = 11 # 主要技术指标数量
|
||||
valid_rows = len(data.dropna())
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def display_indicator_summary(df):
|
||||
"""Generate technical indicator summary string"""
|
||||
|
||||
latest = df.iloc[-1] # Latest data
|
||||
|
||||
summary = f"\n📈 Latest Technical Indicator Summary ({latest.name.date()}):\n"
|
||||
summary += "=" * 50 + "\n"
|
||||
|
||||
# Price information
|
||||
summary += f"Close Price: ${latest['Close']:.2f}\n"
|
||||
|
||||
# Moving averages
|
||||
summary += f"\n📊 Moving Averages:\n"
|
||||
if not pd.isna(latest['close_10_ema']):
|
||||
summary += f" 10-day EMA: ${latest['close_10_ema']:.2f}\n"
|
||||
if not pd.isna(latest['close_50_sma']):
|
||||
summary += f" 50-day SMA: ${latest['close_50_sma']:.2f}\n"
|
||||
if not pd.isna(latest['close_200_sma']):
|
||||
summary += f" 200-day SMA: ${latest['close_200_sma']:.2f}\n"
|
||||
|
||||
# Trend signals
|
||||
if not pd.isna(latest['close_50_sma']) and not pd.isna(latest['close_200_sma']):
|
||||
trend = "Uptrend 📈" if latest['close_50_sma'] > latest['close_200_sma'] else "Downtrend 📉"
|
||||
summary += f" Long-term Trend: {trend}\n"
|
||||
|
||||
# MACD
|
||||
summary += f"\n📈 MACD:\n"
|
||||
if not pd.isna(latest['macd']):
|
||||
summary += f" MACD: {latest['macd']:.4f}\n"
|
||||
summary += f" Signal Line: {latest['macds']:.4f}\n"
|
||||
summary += f" Histogram: {latest['macdh']:.4f}\n"
|
||||
|
||||
if latest['macd'] > latest['macds']:
|
||||
summary += " 📈 MACD Bullish Signal\n"
|
||||
else:
|
||||
summary += " 📉 MACD Bearish Signal\n"
|
||||
|
||||
# RSI
|
||||
summary += f"\n⚡ RSI:\n"
|
||||
if not pd.isna(latest['rsi']):
|
||||
summary += f" RSI(14): {latest['rsi']:.2f}\n"
|
||||
if latest['rsi'] > 70:
|
||||
summary += " ⚠️ Overbought Zone\n"
|
||||
elif latest['rsi'] < 30:
|
||||
summary += " ⚠️ Oversold Zone\n"
|
||||
else:
|
||||
summary += " ✅ Normal Zone\n"
|
||||
|
||||
# Bollinger Bands
|
||||
summary += f"\n📊 Bollinger Bands:\n"
|
||||
if not pd.isna(latest['boll']):
|
||||
summary += f" Upper Band: ${latest['boll_ub']:.2f}\n"
|
||||
summary += f" Middle Band: ${latest['boll']:.2f}\n"
|
||||
summary += f" Lower Band: ${latest['boll_lb']:.2f}\n"
|
||||
|
||||
if latest['Close'] > latest['boll_ub']:
|
||||
summary += " 📈 Price Above Upper Band\n"
|
||||
elif latest['Close'] < latest['boll_lb']:
|
||||
summary += " 📉 Price Below Lower Band\n"
|
||||
else:
|
||||
summary += " ✅ Price Within Bands\n"
|
||||
|
||||
# ATR and VWMA
|
||||
summary += f"\n📊 Other Indicators:\n"
|
||||
if not pd.isna(latest['atr']):
|
||||
summary += f" ATR(14): {latest['atr']:.2f} (Volatility)\n"
|
||||
if not pd.isna(latest['vwma']):
|
||||
summary += f" VWMA(20): ${latest['vwma']:.2f} (Volume-Weighted MA)\n"
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def get_trading_signals(df):
|
||||
"""Get recent trading signals as string"""
|
||||
|
||||
recent_data = df.tail(5) # Last 5 days
|
||||
signals = []
|
||||
|
||||
for date, row in recent_data.iterrows():
|
||||
day_signals = []
|
||||
|
||||
if row.get('golden_cross', False):
|
||||
day_signals.append("🟢 Golden Cross")
|
||||
if row.get('death_cross', False):
|
||||
day_signals.append("🔴 Death Cross")
|
||||
if row.get('macd_bullish', False):
|
||||
day_signals.append("📈 MACD Bullish")
|
||||
if row.get('macd_bearish', False):
|
||||
day_signals.append("📉 MACD Bearish")
|
||||
if row.get('rsi_overbought', False):
|
||||
day_signals.append("⚠️ RSI Overbought")
|
||||
if row.get('rsi_oversold', False):
|
||||
day_signals.append("⚠️ RSI Oversold")
|
||||
if row.get('price_above_upper_band', False):
|
||||
day_signals.append("📈 Above Upper Band")
|
||||
if row.get('price_below_lower_band', False):
|
||||
day_signals.append("📉 Below Lower Band")
|
||||
|
||||
if day_signals:
|
||||
signals.append(f"{date.date()}: {', '.join(day_signals)}")
|
||||
|
||||
if signals:
|
||||
result = "\n🚨 Recent Trading Signals:\n"
|
||||
result += "=" * 30 + "\n"
|
||||
for signal in signals:
|
||||
result += f" {signal}\n"
|
||||
return result
|
||||
else:
|
||||
return "\n✅ No significant trading signals in the last 5 days\n"
|
||||
|
||||
def get_alpha_vantage_summary_signals(symbol, start_date, end_date):
|
||||
# 1. 获取股票数据 (使用前面的Alpha Vantage函数)
|
||||
data = get_alpha_vantage_data(symbol, start_date, end_date)
|
||||
|
||||
# 2. 添加技术指标
|
||||
data_with_indicators = add_technical_indicators(data)
|
||||
|
||||
# 3. 查看指标摘要
|
||||
summary = display_indicator_summary(data_with_indicators)
|
||||
|
||||
# 4. 查看交易信号
|
||||
signals = get_trading_signals(data_with_indicators)
|
||||
|
||||
# 5. 查看特定指标
|
||||
df_tail = data_with_indicators[['Close', 'rsi', 'macd', 'close_50_sma']].tail()
|
||||
tail_string = "\nRecent Prices\n" + df_tail.to_string()
|
||||
|
||||
return summary + signals + tail_string
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = get_alpha_vantage_summary_signals('NVDA', '2024-01-01', '2024-12-31')
|
||||
print(text)
|
||||
|
|
|
|||
|
|
@ -19,4 +19,5 @@ DEFAULT_CONFIG = {
|
|||
"max_recur_limit": 100,
|
||||
# Tool settings
|
||||
"online_tools": True,
|
||||
"alpha_vantage_api_key": "$YOUR_API_KEY"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -117,6 +117,8 @@ class TradingAgentsGraph:
|
|||
# online tools
|
||||
self.toolkit.get_YFin_data_online,
|
||||
self.toolkit.get_stockstats_indicators_report_online,
|
||||
self.toolkit.get_alpha_vantage_data,
|
||||
self.toolkit.get_alpha_vantage_summary_signals,
|
||||
# offline tools
|
||||
self.toolkit.get_YFin_data,
|
||||
self.toolkit.get_stockstats_indicators_report,
|
||||
|
|
@ -127,7 +129,7 @@ class TradingAgentsGraph:
|
|||
# online tools
|
||||
self.toolkit.get_stock_news_openai,
|
||||
# offline tools
|
||||
self.toolkit.get_reddit_stock_info,
|
||||
# self.toolkit.get_reddit_stock_info,
|
||||
]
|
||||
),
|
||||
"news": ToolNode(
|
||||
|
|
@ -136,8 +138,8 @@ class TradingAgentsGraph:
|
|||
self.toolkit.get_global_news_openai,
|
||||
self.toolkit.get_google_news,
|
||||
# offline tools
|
||||
self.toolkit.get_finnhub_news,
|
||||
self.toolkit.get_reddit_news,
|
||||
# self.toolkit.get_finnhub_news,
|
||||
# self.toolkit.get_reddit_news,
|
||||
]
|
||||
),
|
||||
"fundamentals": ToolNode(
|
||||
|
|
@ -145,11 +147,11 @@ class TradingAgentsGraph:
|
|||
# online tools
|
||||
self.toolkit.get_fundamentals_openai,
|
||||
# offline tools
|
||||
self.toolkit.get_finnhub_company_insider_sentiment,
|
||||
self.toolkit.get_finnhub_company_insider_transactions,
|
||||
self.toolkit.get_simfin_balance_sheet,
|
||||
self.toolkit.get_simfin_cashflow,
|
||||
self.toolkit.get_simfin_income_stmt,
|
||||
# self.toolkit.get_finnhub_company_insider_sentiment,
|
||||
# self.toolkit.get_finnhub_company_insider_transactions,
|
||||
# self.toolkit.get_simfin_balance_sheet,
|
||||
# self.toolkit.get_simfin_cashflow,
|
||||
# self.toolkit.get_simfin_income_stmt,
|
||||
]
|
||||
),
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue