diff --git a/README.md b/README.md index cac18691..32b285d8 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,9 @@ You will need the OpenAI API for all the agents. ```bash export OPENAI_API_KEY=$YOUR_OPENAI_API_KEY ``` +**Optional** Stock Price Data + +Since YFinance may not work, you can add alpha vantage api key in default_config.py. ### CLI Usage diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index 41ee944b..f7dab5ec 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -14,6 +14,8 @@ def create_market_analyst(llm, toolkit): tools = [ toolkit.get_YFin_data_online, toolkit.get_stockstats_indicators_report_online, + toolkit.get_alpha_vantage_data, + toolkit.get_alpha_vantage_summary_signals, ] else: tools = [ diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index 0b07f044..428c5fc0 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -161,6 +161,45 @@ class Toolkit: return result_data + @staticmethod + @tool + def get_alpha_vantage_data( + symbol: Annotated[str, "ticker symbol of the company"], + start_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], + ) -> str: + """ + Retrieve the stock price data for a given ticker symbol from Alpha Vantage. + Args: + symbol (str): Ticker symbol of the company, e.g. AAPL, TSM + start_date (str): Start date in yyyy-mm-dd format + end_date (str): End date in yyyy-mm-dd format + Returns: + str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range. + """ + + result_data = interface.get_alpha_vantage_data(symbol, start_date, end_date) + return result_data + + @staticmethod + @tool + def get_alpha_vantage_summary_signals( + symbol: Annotated[str, "ticker symbol of the company"], + start_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"] + ) -> str: + """ + Calculate technical indicators and signals based on stock price data from Alpha Vantage. + Args: + symbol (str): Ticker symbol of the company, e.g. AAPL, TSM + start_date (str): Start date in yyyy-mm-dd format + end_date (str): End date in yyyy-mm-dd format + Returns: + str: Formatted technical indicators summary and signals. + """ + result_data = interface.get_alpha_vantage_summary_signals(symbol, start_date, end_date) + return result_data + @staticmethod @tool def get_stockstats_indicators_report( diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 7fffbb4f..d0c0a01f 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -805,3 +805,328 @@ def get_fundamentals_openai(ticker, curr_date): ) return response.output[1].content[0].text + +def get_alpha_vantage_data(symbol, start_date, end_date): + config = get_config() + api_key = config["alpha_vantage_api_key"] + + url = "https://www.alphavantage.co/query" + params = { + 'function': 'TIME_SERIES_DAILY', + 'symbol': symbol, + 'outputsize': 'full', + 'apikey': api_key + } + + response = requests.get(url, params=params) + data = response.json() + + print("API 响应:", list(data.keys())) # 调试信息 + + if 'Time Series (Daily)' in data: + df = pd.DataFrame(data['Time Series (Daily)']).T + df.index = pd.to_datetime(df.index) + df = df.sort_index() + + # 重命名列并转换数据类型 + df.columns = ['Open', 'High', 'Low', 'Close', 'Volume'] + for col in df.columns: + df[col] = pd.to_numeric(df[col]) + + # 过滤日期范围 + if start_date: + df = df[df.index >= start_date] + if end_date: + df = df[df.index <= end_date] + + # # Add header information + # header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n" + # header += f"# Total records: {len(data)}\n" + # header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + + return df + else: + print("错误或限制:", data) + return None + + +def add_technical_indicators(df): + """ + 为股票数据添加技术指标 + + 参数: + df: DataFrame with columns ['Open', 'High', 'Low', 'Close', 'Volume'] + + 返回: + DataFrame with additional technical indicator columns + """ + + # 创建副本以避免修改原数据 + data = df.copy() + + # 确保数据按日期排序 + data = data.sort_index() + + # ============================================================================= + # Moving Averages 移动平均线 + # ============================================================================= + + # 50日简单移动平均线 + data['close_50_sma'] = data['Close'].rolling(window=50).mean() + + # 200日简单移动平均线 + data['close_200_sma'] = data['Close'].rolling(window=200).mean() + + # 10日指数移动平均线 + data['close_10_ema'] = data['Close'].ewm(span=10).mean() + + # ============================================================================= + # MACD 指标 + # ============================================================================= + + # 计算12日和26日EMA + ema_12 = data['Close'].ewm(span=12).mean() + ema_26 = data['Close'].ewm(span=26).mean() + + # MACD线 = 12EMA - 26EMA + data['macd'] = ema_12 - ema_26 + + # MACD信号线 = MACD的9日EMA + data['macds'] = data['macd'].ewm(span=9).mean() + + # MACD柱状图 = MACD - 信号线 + data['macdh'] = data['macd'] - data['macds'] + + # ============================================================================= + # RSI 相对强弱指数 + # ============================================================================= + + def calculate_rsi(prices, period=14): + """计算RSI""" + delta = prices.diff() + gain = delta.where(delta > 0, 0) + loss = -delta.where(delta < 0, 0) + + avg_gain = gain.rolling(window=period).mean() + avg_loss = loss.rolling(window=period).mean() + + rs = avg_gain / avg_loss + rsi = 100 - (100 / (1 + rs)) + return rsi + + data['rsi'] = calculate_rsi(data['Close']) + + # ============================================================================= + # Bollinger Bands 布林带 + # ============================================================================= + + # 布林带中线(20日SMA) + data['boll'] = data['Close'].rolling(window=20).mean() + + # 计算20日标准差 + std_20 = data['Close'].rolling(window=20).std() + + # 布林带上轨 = 中线 + 2倍标准差 + data['boll_ub'] = data['boll'] + (2 * std_20) + + # 布林带下轨 = 中线 - 2倍标准差 + data['boll_lb'] = data['boll'] - (2 * std_20) + + # ============================================================================= + # ATR 平均真实波幅 + # ============================================================================= + + def calculate_atr(high, low, close, period=14): + """计算ATR""" + # 真实波幅的三种计算方式 + tr1 = high - low + tr2 = abs(high - close.shift(1)) + tr3 = abs(low - close.shift(1)) + + # 取最大值作为真实波幅 + true_range = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1) + + # ATR = 真实波幅的移动平均 + atr = true_range.rolling(window=period).mean() + return atr + + data['atr'] = calculate_atr(data['High'], data['Low'], data['Close']) + + # ============================================================================= + # VWMA 成交量加权移动平均 + # ============================================================================= + + def calculate_vwma(close, volume, period=20): + """计算成交量加权移动平均""" + volume_price = close * volume + vwma = volume_price.rolling(window=period).sum() / volume.rolling(window=period).sum() + return vwma + + data['vwma'] = calculate_vwma(data['Close'], data['Volume']) + + # ============================================================================= + # 额外计算一些有用的信号 + # ============================================================================= + + # 金叉死叉信号 (50 SMA vs 200 SMA) + data['golden_cross'] = (data['close_50_sma'] > data['close_200_sma']) & \ + (data['close_50_sma'].shift(1) <= data['close_200_sma'].shift(1)) + + data['death_cross'] = (data['close_50_sma'] < data['close_200_sma']) & \ + (data['close_50_sma'].shift(1) >= data['close_200_sma'].shift(1)) + + # MACD信号 + data['macd_bullish'] = (data['macd'] > data['macds']) & \ + (data['macd'].shift(1) <= data['macds'].shift(1)) + + data['macd_bearish'] = (data['macd'] < data['macds']) & \ + (data['macd'].shift(1) >= data['macds'].shift(1)) + + # RSI信号 + data['rsi_overbought'] = data['rsi'] > 70 + data['rsi_oversold'] = data['rsi'] < 30 + + # 布林带信号 + data['price_above_upper_band'] = data['Close'] > data['boll_ub'] + data['price_below_lower_band'] = data['Close'] < data['boll_lb'] + + # 显示统计信息 + total_indicators = 11 # 主要技术指标数量 + valid_rows = len(data.dropna()) + + return data + + +def display_indicator_summary(df): + """Generate technical indicator summary string""" + + latest = df.iloc[-1] # Latest data + + summary = f"\n📈 Latest Technical Indicator Summary ({latest.name.date()}):\n" + summary += "=" * 50 + "\n" + + # Price information + summary += f"Close Price: ${latest['Close']:.2f}\n" + + # Moving averages + summary += f"\n📊 Moving Averages:\n" + if not pd.isna(latest['close_10_ema']): + summary += f" 10-day EMA: ${latest['close_10_ema']:.2f}\n" + if not pd.isna(latest['close_50_sma']): + summary += f" 50-day SMA: ${latest['close_50_sma']:.2f}\n" + if not pd.isna(latest['close_200_sma']): + summary += f" 200-day SMA: ${latest['close_200_sma']:.2f}\n" + + # Trend signals + if not pd.isna(latest['close_50_sma']) and not pd.isna(latest['close_200_sma']): + trend = "Uptrend 📈" if latest['close_50_sma'] > latest['close_200_sma'] else "Downtrend 📉" + summary += f" Long-term Trend: {trend}\n" + + # MACD + summary += f"\n📈 MACD:\n" + if not pd.isna(latest['macd']): + summary += f" MACD: {latest['macd']:.4f}\n" + summary += f" Signal Line: {latest['macds']:.4f}\n" + summary += f" Histogram: {latest['macdh']:.4f}\n" + + if latest['macd'] > latest['macds']: + summary += " 📈 MACD Bullish Signal\n" + else: + summary += " 📉 MACD Bearish Signal\n" + + # RSI + summary += f"\n⚡ RSI:\n" + if not pd.isna(latest['rsi']): + summary += f" RSI(14): {latest['rsi']:.2f}\n" + if latest['rsi'] > 70: + summary += " ⚠️ Overbought Zone\n" + elif latest['rsi'] < 30: + summary += " ⚠️ Oversold Zone\n" + else: + summary += " ✅ Normal Zone\n" + + # Bollinger Bands + summary += f"\n📊 Bollinger Bands:\n" + if not pd.isna(latest['boll']): + summary += f" Upper Band: ${latest['boll_ub']:.2f}\n" + summary += f" Middle Band: ${latest['boll']:.2f}\n" + summary += f" Lower Band: ${latest['boll_lb']:.2f}\n" + + if latest['Close'] > latest['boll_ub']: + summary += " 📈 Price Above Upper Band\n" + elif latest['Close'] < latest['boll_lb']: + summary += " 📉 Price Below Lower Band\n" + else: + summary += " ✅ Price Within Bands\n" + + # ATR and VWMA + summary += f"\n📊 Other Indicators:\n" + if not pd.isna(latest['atr']): + summary += f" ATR(14): {latest['atr']:.2f} (Volatility)\n" + if not pd.isna(latest['vwma']): + summary += f" VWMA(20): ${latest['vwma']:.2f} (Volume-Weighted MA)\n" + + return summary + + +def get_trading_signals(df): + """Get recent trading signals as string""" + + recent_data = df.tail(5) # Last 5 days + signals = [] + + for date, row in recent_data.iterrows(): + day_signals = [] + + if row.get('golden_cross', False): + day_signals.append("🟢 Golden Cross") + if row.get('death_cross', False): + day_signals.append("🔴 Death Cross") + if row.get('macd_bullish', False): + day_signals.append("📈 MACD Bullish") + if row.get('macd_bearish', False): + day_signals.append("📉 MACD Bearish") + if row.get('rsi_overbought', False): + day_signals.append("⚠️ RSI Overbought") + if row.get('rsi_oversold', False): + day_signals.append("⚠️ RSI Oversold") + if row.get('price_above_upper_band', False): + day_signals.append("📈 Above Upper Band") + if row.get('price_below_lower_band', False): + day_signals.append("📉 Below Lower Band") + + if day_signals: + signals.append(f"{date.date()}: {', '.join(day_signals)}") + + if signals: + result = "\n🚨 Recent Trading Signals:\n" + result += "=" * 30 + "\n" + for signal in signals: + result += f" {signal}\n" + return result + else: + return "\n✅ No significant trading signals in the last 5 days\n" + +def get_alpha_vantage_summary_signals(symbol, start_date, end_date): + # 1. 获取股票数据 (使用前面的Alpha Vantage函数) + data = get_alpha_vantage_data(symbol, start_date, end_date) + + # 2. 添加技术指标 + data_with_indicators = add_technical_indicators(data) + + # 3. 查看指标摘要 + summary = display_indicator_summary(data_with_indicators) + + # 4. 查看交易信号 + signals = get_trading_signals(data_with_indicators) + + # 5. 查看特定指标 + df_tail = data_with_indicators[['Close', 'rsi', 'macd', 'close_50_sma']].tail() + tail_string = "\nRecent Prices\n" + df_tail.to_string() + + return summary + signals + tail_string + + +if __name__ == "__main__": + text = get_alpha_vantage_summary_signals('NVDA', '2024-01-01', '2024-12-31') + print(text) diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 089e9c24..d70b1c10 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -19,4 +19,5 @@ DEFAULT_CONFIG = { "max_recur_limit": 100, # Tool settings "online_tools": True, + "alpha_vantage_api_key": "$YOUR_API_KEY" } diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 80a29e53..b5d83953 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -117,6 +117,8 @@ class TradingAgentsGraph: # online tools self.toolkit.get_YFin_data_online, self.toolkit.get_stockstats_indicators_report_online, + self.toolkit.get_alpha_vantage_data, + self.toolkit.get_alpha_vantage_summary_signals, # offline tools self.toolkit.get_YFin_data, self.toolkit.get_stockstats_indicators_report, @@ -127,7 +129,7 @@ class TradingAgentsGraph: # online tools self.toolkit.get_stock_news_openai, # offline tools - self.toolkit.get_reddit_stock_info, + # self.toolkit.get_reddit_stock_info, ] ), "news": ToolNode( @@ -136,8 +138,8 @@ class TradingAgentsGraph: self.toolkit.get_global_news_openai, self.toolkit.get_google_news, # offline tools - self.toolkit.get_finnhub_news, - self.toolkit.get_reddit_news, + # self.toolkit.get_finnhub_news, + # self.toolkit.get_reddit_news, ] ), "fundamentals": ToolNode( @@ -145,11 +147,11 @@ class TradingAgentsGraph: # online tools self.toolkit.get_fundamentals_openai, # offline tools - self.toolkit.get_finnhub_company_insider_sentiment, - self.toolkit.get_finnhub_company_insider_transactions, - self.toolkit.get_simfin_balance_sheet, - self.toolkit.get_simfin_cashflow, - self.toolkit.get_simfin_income_stmt, + # self.toolkit.get_finnhub_company_insider_sentiment, + # self.toolkit.get_finnhub_company_insider_transactions, + # self.toolkit.get_simfin_balance_sheet, + # self.toolkit.get_simfin_cashflow, + # self.toolkit.get_simfin_income_stmt, ] ), }