feat: integrate akshare news data source and implement automatic vendor routing for A-share tickers
This commit is contained in:
parent
115a793103
commit
4fb458c156
|
|
@ -0,0 +1,50 @@
|
||||||
|
# tradingagents/dataflows/akshare_news.py
|
||||||
|
import akshare as ak
|
||||||
|
from datetime import datetime
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
def get_news_akshare(
|
||||||
|
ticker: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
使用 AkShare 获取 A 股特定股票的新闻 (东方财富数据源)
|
||||||
|
"""
|
||||||
|
# 1. A股代码清洗: yfinance 习惯用 '600519.SS', akshare 通常只需要 '600519'
|
||||||
|
clean_ticker = ticker.split('.')[0]
|
||||||
|
if not clean_ticker.isdigit():
|
||||||
|
return f"Invalid A-share ticker format: {ticker}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取个股新闻 (东方财富接口)
|
||||||
|
news_df = ak.stock_news_em(symbol=clean_ticker)
|
||||||
|
|
||||||
|
if news_df.empty:
|
||||||
|
return f"No news found for {ticker}"
|
||||||
|
|
||||||
|
# 将发布时间转换为 datetime 进行过滤
|
||||||
|
news_df['发布时间'] = pd.to_datetime(news_df['发布时间'])
|
||||||
|
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||||
|
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||||
|
|
||||||
|
# 过滤日期区间
|
||||||
|
mask = (news_df['发布时间'] >= start_dt) & (news_df['发布时间'] <= end_dt)
|
||||||
|
filtered_news = news_df.loc[mask]
|
||||||
|
|
||||||
|
if filtered_news.empty:
|
||||||
|
return f"No news found for {ticker} between {start_date} and {end_date}"
|
||||||
|
|
||||||
|
# 格式化输出为大模型易读的 Markdown 格式
|
||||||
|
news_str = ""
|
||||||
|
# 限制返回条数,避免 token 溢出
|
||||||
|
for _, row in filtered_news.head(20).iterrows():
|
||||||
|
pub_time = row['发布时间'].strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
news_str += f"### {row['新闻标题']} (时间: {pub_time})\n"
|
||||||
|
news_str += f"{row['新闻内容']}\n"
|
||||||
|
news_str += f"来源: {row['文章来源']} | Link: {row['新闻链接']}\n\n"
|
||||||
|
|
||||||
|
return f"## {ticker} A股新闻, 从 {start_date} 到 {end_date}:\n\n{news_str}"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error fetching A-share news for {ticker} via AkShare: {str(e)}"
|
||||||
|
|
@ -11,6 +11,7 @@ from .y_finance import (
|
||||||
get_insider_transactions as get_yfinance_insider_transactions,
|
get_insider_transactions as get_yfinance_insider_transactions,
|
||||||
)
|
)
|
||||||
from .yfinance_news import get_news_yfinance, get_global_news_yfinance
|
from .yfinance_news import get_news_yfinance, get_global_news_yfinance
|
||||||
|
from .akshare_news import get_news_akshare
|
||||||
from .alpha_vantage import (
|
from .alpha_vantage import (
|
||||||
get_stock as get_alpha_vantage_stock,
|
get_stock as get_alpha_vantage_stock,
|
||||||
get_indicator as get_alpha_vantage_indicator,
|
get_indicator as get_alpha_vantage_indicator,
|
||||||
|
|
@ -63,6 +64,7 @@ TOOLS_CATEGORIES = {
|
||||||
VENDOR_LIST = [
|
VENDOR_LIST = [
|
||||||
"yfinance",
|
"yfinance",
|
||||||
"alpha_vantage",
|
"alpha_vantage",
|
||||||
|
"akshare", # 新增 A股 数据源
|
||||||
]
|
]
|
||||||
|
|
||||||
# Mapping of methods to their vendor-specific implementations
|
# Mapping of methods to their vendor-specific implementations
|
||||||
|
|
@ -98,6 +100,7 @@ VENDOR_METHODS = {
|
||||||
"get_news": {
|
"get_news": {
|
||||||
"alpha_vantage": get_alpha_vantage_news,
|
"alpha_vantage": get_alpha_vantage_news,
|
||||||
"yfinance": get_news_yfinance,
|
"yfinance": get_news_yfinance,
|
||||||
|
"akshare": get_news_akshare, # 新增映射
|
||||||
},
|
},
|
||||||
"get_global_news": {
|
"get_global_news": {
|
||||||
"yfinance": get_global_news_yfinance,
|
"yfinance": get_global_news_yfinance,
|
||||||
|
|
@ -134,8 +137,18 @@ def get_vendor(category: str, method: str = None) -> str:
|
||||||
def route_to_vendor(method: str, *args, **kwargs):
|
def route_to_vendor(method: str, *args, **kwargs):
|
||||||
"""Route method calls to appropriate vendor implementation with fallback support."""
|
"""Route method calls to appropriate vendor implementation with fallback support."""
|
||||||
category = get_category_for_method(method)
|
category = get_category_for_method(method)
|
||||||
vendor_config = get_vendor(category, method)
|
if len(args) > 0 and isinstance(args[0], str):
|
||||||
primary_vendors = [v.strip() for v in vendor_config.split(',')]
|
print(f"AAAAAA args: {args}")
|
||||||
|
ticker = args[0].upper()
|
||||||
|
# 如果带有上海(.SS)或深圳(.SZ)后缀,强制优先使用 akshare
|
||||||
|
if ticker.endswith(".SS") or ticker.endswith(".SZ"):
|
||||||
|
primary_vendors = ["akshare"]
|
||||||
|
else:
|
||||||
|
vendor_config = get_vendor(category, method)
|
||||||
|
primary_vendors = [v.strip() for v in vendor_config.split(',')]
|
||||||
|
else:
|
||||||
|
vendor_config = get_vendor(category, method)
|
||||||
|
primary_vendors = [v.strip() for v in vendor_config.split(',')]
|
||||||
|
|
||||||
if method not in VENDOR_METHODS:
|
if method not in VENDOR_METHODS:
|
||||||
raise ValueError(f"Method '{method}' not supported")
|
raise ValueError(f"Method '{method}' not supported")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue