support get price data from alpha vantage

This commit is contained in:
auwc 2025-07-08 14:00:50 +08:00
parent a438acdbbd
commit 86b357d77a
6 changed files with 380 additions and 8 deletions

View File

@ -123,6 +123,9 @@ You will need the OpenAI API for all the agents.
```bash
export OPENAI_API_KEY=$YOUR_OPENAI_API_KEY
```
**Optional** Stock Price Data
Since YFinance may not work, you can add alpha vantage api key in default_config.py.
### CLI Usage

View File

@ -14,6 +14,8 @@ def create_market_analyst(llm, toolkit):
tools = [
toolkit.get_YFin_data_online,
toolkit.get_stockstats_indicators_report_online,
toolkit.get_alpha_vantage_data,
toolkit.get_alpha_vantage_summary_signals,
]
else:
tools = [

View File

@ -161,6 +161,45 @@ class Toolkit:
return result_data
@staticmethod
@tool
def get_alpha_vantage_data(
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve the stock price data for a given ticker symbol from Alpha Vantage.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
"""
result_data = interface.get_alpha_vantage_data(symbol, start_date, end_date)
return result_data
@staticmethod
@tool
def get_alpha_vantage_summary_signals(
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"]
) -> str:
"""
Calculate technical indicators and signals based on stock price data from Alpha Vantage.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: Formatted technical indicators summary and signals.
"""
result_data = interface.get_alpha_vantage_summary_signals(symbol, start_date, end_date)
return result_data
@staticmethod
@tool
def get_stockstats_indicators_report(

View File

@ -805,3 +805,328 @@ def get_fundamentals_openai(ticker, curr_date):
)
return response.output[1].content[0].text
def get_alpha_vantage_data(symbol, start_date, end_date):
config = get_config()
api_key = config["alpha_vantage_api_key"]
url = "https://www.alphavantage.co/query"
params = {
'function': 'TIME_SERIES_DAILY',
'symbol': symbol,
'outputsize': 'full',
'apikey': api_key
}
response = requests.get(url, params=params)
data = response.json()
print("API 响应:", list(data.keys())) # 调试信息
if 'Time Series (Daily)' in data:
df = pd.DataFrame(data['Time Series (Daily)']).T
df.index = pd.to_datetime(df.index)
df = df.sort_index()
# 重命名列并转换数据类型
df.columns = ['Open', 'High', 'Low', 'Close', 'Volume']
for col in df.columns:
df[col] = pd.to_numeric(df[col])
# 过滤日期范围
if start_date:
df = df[df.index >= start_date]
if end_date:
df = df[df.index <= end_date]
# # Add header information
# header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n"
# header += f"# Total records: {len(data)}\n"
# header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
return df
else:
print("错误或限制:", data)
return None
def add_technical_indicators(df):
"""
为股票数据添加技术指标
参数:
df: DataFrame with columns ['Open', 'High', 'Low', 'Close', 'Volume']
返回:
DataFrame with additional technical indicator columns
"""
# 创建副本以避免修改原数据
data = df.copy()
# 确保数据按日期排序
data = data.sort_index()
# =============================================================================
# Moving Averages 移动平均线
# =============================================================================
# 50日简单移动平均线
data['close_50_sma'] = data['Close'].rolling(window=50).mean()
# 200日简单移动平均线
data['close_200_sma'] = data['Close'].rolling(window=200).mean()
# 10日指数移动平均线
data['close_10_ema'] = data['Close'].ewm(span=10).mean()
# =============================================================================
# MACD 指标
# =============================================================================
# 计算12日和26日EMA
ema_12 = data['Close'].ewm(span=12).mean()
ema_26 = data['Close'].ewm(span=26).mean()
# MACD线 = 12EMA - 26EMA
data['macd'] = ema_12 - ema_26
# MACD信号线 = MACD的9日EMA
data['macds'] = data['macd'].ewm(span=9).mean()
# MACD柱状图 = MACD - 信号线
data['macdh'] = data['macd'] - data['macds']
# =============================================================================
# RSI 相对强弱指数
# =============================================================================
def calculate_rsi(prices, period=14):
"""计算RSI"""
delta = prices.diff()
gain = delta.where(delta > 0, 0)
loss = -delta.where(delta < 0, 0)
avg_gain = gain.rolling(window=period).mean()
avg_loss = loss.rolling(window=period).mean()
rs = avg_gain / avg_loss
rsi = 100 - (100 / (1 + rs))
return rsi
data['rsi'] = calculate_rsi(data['Close'])
# =============================================================================
# Bollinger Bands 布林带
# =============================================================================
# 布林带中线20日SMA
data['boll'] = data['Close'].rolling(window=20).mean()
# 计算20日标准差
std_20 = data['Close'].rolling(window=20).std()
# 布林带上轨 = 中线 + 2倍标准差
data['boll_ub'] = data['boll'] + (2 * std_20)
# 布林带下轨 = 中线 - 2倍标准差
data['boll_lb'] = data['boll'] - (2 * std_20)
# =============================================================================
# ATR 平均真实波幅
# =============================================================================
def calculate_atr(high, low, close, period=14):
"""计算ATR"""
# 真实波幅的三种计算方式
tr1 = high - low
tr2 = abs(high - close.shift(1))
tr3 = abs(low - close.shift(1))
# 取最大值作为真实波幅
true_range = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
# ATR = 真实波幅的移动平均
atr = true_range.rolling(window=period).mean()
return atr
data['atr'] = calculate_atr(data['High'], data['Low'], data['Close'])
# =============================================================================
# VWMA 成交量加权移动平均
# =============================================================================
def calculate_vwma(close, volume, period=20):
"""计算成交量加权移动平均"""
volume_price = close * volume
vwma = volume_price.rolling(window=period).sum() / volume.rolling(window=period).sum()
return vwma
data['vwma'] = calculate_vwma(data['Close'], data['Volume'])
# =============================================================================
# 额外计算一些有用的信号
# =============================================================================
# 金叉死叉信号 (50 SMA vs 200 SMA)
data['golden_cross'] = (data['close_50_sma'] > data['close_200_sma']) & \
(data['close_50_sma'].shift(1) <= data['close_200_sma'].shift(1))
data['death_cross'] = (data['close_50_sma'] < data['close_200_sma']) & \
(data['close_50_sma'].shift(1) >= data['close_200_sma'].shift(1))
# MACD信号
data['macd_bullish'] = (data['macd'] > data['macds']) & \
(data['macd'].shift(1) <= data['macds'].shift(1))
data['macd_bearish'] = (data['macd'] < data['macds']) & \
(data['macd'].shift(1) >= data['macds'].shift(1))
# RSI信号
data['rsi_overbought'] = data['rsi'] > 70
data['rsi_oversold'] = data['rsi'] < 30
# 布林带信号
data['price_above_upper_band'] = data['Close'] > data['boll_ub']
data['price_below_lower_band'] = data['Close'] < data['boll_lb']
# 显示统计信息
total_indicators = 11 # 主要技术指标数量
valid_rows = len(data.dropna())
return data
def display_indicator_summary(df):
"""Generate technical indicator summary string"""
latest = df.iloc[-1] # Latest data
summary = f"\n📈 Latest Technical Indicator Summary ({latest.name.date()}):\n"
summary += "=" * 50 + "\n"
# Price information
summary += f"Close Price: ${latest['Close']:.2f}\n"
# Moving averages
summary += f"\n📊 Moving Averages:\n"
if not pd.isna(latest['close_10_ema']):
summary += f" 10-day EMA: ${latest['close_10_ema']:.2f}\n"
if not pd.isna(latest['close_50_sma']):
summary += f" 50-day SMA: ${latest['close_50_sma']:.2f}\n"
if not pd.isna(latest['close_200_sma']):
summary += f" 200-day SMA: ${latest['close_200_sma']:.2f}\n"
# Trend signals
if not pd.isna(latest['close_50_sma']) and not pd.isna(latest['close_200_sma']):
trend = "Uptrend 📈" if latest['close_50_sma'] > latest['close_200_sma'] else "Downtrend 📉"
summary += f" Long-term Trend: {trend}\n"
# MACD
summary += f"\n📈 MACD:\n"
if not pd.isna(latest['macd']):
summary += f" MACD: {latest['macd']:.4f}\n"
summary += f" Signal Line: {latest['macds']:.4f}\n"
summary += f" Histogram: {latest['macdh']:.4f}\n"
if latest['macd'] > latest['macds']:
summary += " 📈 MACD Bullish Signal\n"
else:
summary += " 📉 MACD Bearish Signal\n"
# RSI
summary += f"\n⚡ RSI:\n"
if not pd.isna(latest['rsi']):
summary += f" RSI(14): {latest['rsi']:.2f}\n"
if latest['rsi'] > 70:
summary += " ⚠️ Overbought Zone\n"
elif latest['rsi'] < 30:
summary += " ⚠️ Oversold Zone\n"
else:
summary += " ✅ Normal Zone\n"
# Bollinger Bands
summary += f"\n📊 Bollinger Bands:\n"
if not pd.isna(latest['boll']):
summary += f" Upper Band: ${latest['boll_ub']:.2f}\n"
summary += f" Middle Band: ${latest['boll']:.2f}\n"
summary += f" Lower Band: ${latest['boll_lb']:.2f}\n"
if latest['Close'] > latest['boll_ub']:
summary += " 📈 Price Above Upper Band\n"
elif latest['Close'] < latest['boll_lb']:
summary += " 📉 Price Below Lower Band\n"
else:
summary += " ✅ Price Within Bands\n"
# ATR and VWMA
summary += f"\n📊 Other Indicators:\n"
if not pd.isna(latest['atr']):
summary += f" ATR(14): {latest['atr']:.2f} (Volatility)\n"
if not pd.isna(latest['vwma']):
summary += f" VWMA(20): ${latest['vwma']:.2f} (Volume-Weighted MA)\n"
return summary
def get_trading_signals(df):
"""Get recent trading signals as string"""
recent_data = df.tail(5) # Last 5 days
signals = []
for date, row in recent_data.iterrows():
day_signals = []
if row.get('golden_cross', False):
day_signals.append("🟢 Golden Cross")
if row.get('death_cross', False):
day_signals.append("🔴 Death Cross")
if row.get('macd_bullish', False):
day_signals.append("📈 MACD Bullish")
if row.get('macd_bearish', False):
day_signals.append("📉 MACD Bearish")
if row.get('rsi_overbought', False):
day_signals.append("⚠️ RSI Overbought")
if row.get('rsi_oversold', False):
day_signals.append("⚠️ RSI Oversold")
if row.get('price_above_upper_band', False):
day_signals.append("📈 Above Upper Band")
if row.get('price_below_lower_band', False):
day_signals.append("📉 Below Lower Band")
if day_signals:
signals.append(f"{date.date()}: {', '.join(day_signals)}")
if signals:
result = "\n🚨 Recent Trading Signals:\n"
result += "=" * 30 + "\n"
for signal in signals:
result += f" {signal}\n"
return result
else:
return "\n✅ No significant trading signals in the last 5 days\n"
def get_alpha_vantage_summary_signals(symbol, start_date, end_date):
# 1. 获取股票数据 (使用前面的Alpha Vantage函数)
data = get_alpha_vantage_data(symbol, start_date, end_date)
# 2. 添加技术指标
data_with_indicators = add_technical_indicators(data)
# 3. 查看指标摘要
summary = display_indicator_summary(data_with_indicators)
# 4. 查看交易信号
signals = get_trading_signals(data_with_indicators)
# 5. 查看特定指标
df_tail = data_with_indicators[['Close', 'rsi', 'macd', 'close_50_sma']].tail()
tail_string = "\nRecent Prices\n" + df_tail.to_string()
return summary + signals + tail_string
if __name__ == "__main__":
text = get_alpha_vantage_summary_signals('NVDA', '2024-01-01', '2024-12-31')
print(text)

View File

@ -19,4 +19,5 @@ DEFAULT_CONFIG = {
"max_recur_limit": 100,
# Tool settings
"online_tools": True,
"alpha_vantage_api_key": "$YOUR_API_KEY"
}

View File

@ -117,6 +117,8 @@ class TradingAgentsGraph:
# online tools
self.toolkit.get_YFin_data_online,
self.toolkit.get_stockstats_indicators_report_online,
self.toolkit.get_alpha_vantage_data,
self.toolkit.get_alpha_vantage_summary_signals,
# offline tools
self.toolkit.get_YFin_data,
self.toolkit.get_stockstats_indicators_report,
@ -127,7 +129,7 @@ class TradingAgentsGraph:
# online tools
self.toolkit.get_stock_news_openai,
# offline tools
self.toolkit.get_reddit_stock_info,
# self.toolkit.get_reddit_stock_info,
]
),
"news": ToolNode(
@ -136,8 +138,8 @@ class TradingAgentsGraph:
self.toolkit.get_global_news_openai,
self.toolkit.get_google_news,
# offline tools
self.toolkit.get_finnhub_news,
self.toolkit.get_reddit_news,
# self.toolkit.get_finnhub_news,
# self.toolkit.get_reddit_news,
]
),
"fundamentals": ToolNode(
@ -145,11 +147,11 @@ class TradingAgentsGraph:
# online tools
self.toolkit.get_fundamentals_openai,
# offline tools
self.toolkit.get_finnhub_company_insider_sentiment,
self.toolkit.get_finnhub_company_insider_transactions,
self.toolkit.get_simfin_balance_sheet,
self.toolkit.get_simfin_cashflow,
self.toolkit.get_simfin_income_stmt,
# self.toolkit.get_finnhub_company_insider_sentiment,
# self.toolkit.get_finnhub_company_insider_transactions,
# self.toolkit.get_simfin_balance_sheet,
# self.toolkit.get_simfin_cashflow,
# self.toolkit.get_simfin_income_stmt,
]
),
}