diff --git a/tradingagents/dataflows/akshare_news.py b/tradingagents/dataflows/akshare_news.py new file mode 100644 index 00000000..8398d572 --- /dev/null +++ b/tradingagents/dataflows/akshare_news.py @@ -0,0 +1,50 @@ +# tradingagents/dataflows/akshare_news.py +import akshare as ak +from datetime import datetime +import pandas as pd + +def get_news_akshare( + ticker: str, + start_date: str, + end_date: str, +) -> str: + """ + 使用 AkShare 获取 A 股特定股票的新闻 (东方财富数据源) + """ + # 1. A股代码清洗: yfinance 习惯用 '600519.SS', akshare 通常只需要 '600519' + clean_ticker = ticker.split('.')[0] + if not clean_ticker.isdigit(): + return f"Invalid A-share ticker format: {ticker}" + + try: + # 获取个股新闻 (东方财富接口) + news_df = ak.stock_news_em(symbol=clean_ticker) + + if news_df.empty: + return f"No news found for {ticker}" + + # 将发布时间转换为 datetime 进行过滤 + news_df['发布时间'] = pd.to_datetime(news_df['发布时间']) + start_dt = datetime.strptime(start_date, "%Y-%m-%d") + end_dt = datetime.strptime(end_date, "%Y-%m-%d") + + # 过滤日期区间 + mask = (news_df['发布时间'] >= start_dt) & (news_df['发布时间'] <= end_dt) + filtered_news = news_df.loc[mask] + + if filtered_news.empty: + return f"No news found for {ticker} between {start_date} and {end_date}" + + # 格式化输出为大模型易读的 Markdown 格式 + news_str = "" + # 限制返回条数,避免 token 溢出 + for _, row in filtered_news.head(20).iterrows(): + pub_time = row['发布时间'].strftime('%Y-%m-%d %H:%M:%S') + news_str += f"### {row['新闻标题']} (时间: {pub_time})\n" + news_str += f"{row['新闻内容']}\n" + news_str += f"来源: {row['文章来源']} | Link: {row['新闻链接']}\n\n" + + return f"## {ticker} A股新闻, 从 {start_date} 到 {end_date}:\n\n{news_str}" + + except Exception as e: + return f"Error fetching A-share news for {ticker} via AkShare: {str(e)}" \ No newline at end of file diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 0caf4b68..82523cb8 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -11,6 +11,7 @@ from .y_finance import ( get_insider_transactions as get_yfinance_insider_transactions, ) from .yfinance_news import get_news_yfinance, get_global_news_yfinance +from .akshare_news import get_news_akshare from .alpha_vantage import ( get_stock as get_alpha_vantage_stock, get_indicator as get_alpha_vantage_indicator, @@ -63,6 +64,7 @@ TOOLS_CATEGORIES = { VENDOR_LIST = [ "yfinance", "alpha_vantage", + "akshare", # 新增 A股 数据源 ] # Mapping of methods to their vendor-specific implementations @@ -98,6 +100,7 @@ VENDOR_METHODS = { "get_news": { "alpha_vantage": get_alpha_vantage_news, "yfinance": get_news_yfinance, + "akshare": get_news_akshare, # 新增映射 }, "get_global_news": { "yfinance": get_global_news_yfinance, @@ -134,8 +137,18 @@ def get_vendor(category: str, method: str = None) -> str: def route_to_vendor(method: str, *args, **kwargs): """Route method calls to appropriate vendor implementation with fallback support.""" category = get_category_for_method(method) - vendor_config = get_vendor(category, method) - primary_vendors = [v.strip() for v in vendor_config.split(',')] + if len(args) > 0 and isinstance(args[0], str): + print(f"AAAAAA args: {args}") + ticker = args[0].upper() + # 如果带有上海(.SS)或深圳(.SZ)后缀,强制优先使用 akshare + if ticker.endswith(".SS") or ticker.endswith(".SZ"): + primary_vendors = ["akshare"] + else: + vendor_config = get_vendor(category, method) + primary_vendors = [v.strip() for v in vendor_config.split(',')] + else: + vendor_config = get_vendor(category, method) + primary_vendors = [v.strip() for v in vendor_config.split(',')] if method not in VENDOR_METHODS: raise ValueError(f"Method '{method}' not supported")