diff --git a/.env.example b/.env.example index 1328b838..bf065cc8 100644 --- a/.env.example +++ b/.env.example @@ -4,3 +4,4 @@ GOOGLE_API_KEY= ANTHROPIC_API_KEY= XAI_API_KEY= OPENROUTER_API_KEY= +SILICONFLOW_API_KEY= \ No newline at end of file diff --git a/.gitignore b/.gitignore index 9a2904a9..0eb63b91 100644 --- a/.gitignore +++ b/.gitignore @@ -217,3 +217,6 @@ __marimo__/ # Cache **/data_cache/ + +results/* +reports/* \ No newline at end of file diff --git a/cli/utils.py b/cli/utils.py index e071ce06..297d22ea 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -242,6 +242,7 @@ def select_llm_provider() -> tuple[str, str | None]: ("xAI", "https://api.x.ai/v1"), ("Openrouter", "https://openrouter.ai/api/v1"), ("Ollama", "http://localhost:11434/v1"), + ("SiliconFlow", "https://api.siliconflow.cn/v1"), ] choice = questionary.select( diff --git a/tradingagents/dataflows/akshare_news.py b/tradingagents/dataflows/akshare_news.py new file mode 100644 index 00000000..8398d572 --- /dev/null +++ b/tradingagents/dataflows/akshare_news.py @@ -0,0 +1,50 @@ +# tradingagents/dataflows/akshare_news.py +import akshare as ak +from datetime import datetime +import pandas as pd + +def get_news_akshare( + ticker: str, + start_date: str, + end_date: str, +) -> str: + """ + 使用 AkShare 获取 A 股特定股票的新闻 (东方财富数据源) + """ + # 1. A股代码清洗: yfinance 习惯用 '600519.SS', akshare 通常只需要 '600519' + clean_ticker = ticker.split('.')[0] + if not clean_ticker.isdigit(): + return f"Invalid A-share ticker format: {ticker}" + + try: + # 获取个股新闻 (东方财富接口) + news_df = ak.stock_news_em(symbol=clean_ticker) + + if news_df.empty: + return f"No news found for {ticker}" + + # 将发布时间转换为 datetime 进行过滤 + news_df['发布时间'] = pd.to_datetime(news_df['发布时间']) + start_dt = datetime.strptime(start_date, "%Y-%m-%d") + end_dt = datetime.strptime(end_date, "%Y-%m-%d") + + # 过滤日期区间 + mask = (news_df['发布时间'] >= start_dt) & (news_df['发布时间'] <= end_dt) + filtered_news = news_df.loc[mask] + + if filtered_news.empty: + return f"No news found for {ticker} between {start_date} and {end_date}" + + # 格式化输出为大模型易读的 Markdown 格式 + news_str = "" + # 限制返回条数,避免 token 溢出 + for _, row in filtered_news.head(20).iterrows(): + pub_time = row['发布时间'].strftime('%Y-%m-%d %H:%M:%S') + news_str += f"### {row['新闻标题']} (时间: {pub_time})\n" + news_str += f"{row['新闻内容']}\n" + news_str += f"来源: {row['文章来源']} | Link: {row['新闻链接']}\n\n" + + return f"## {ticker} A股新闻, 从 {start_date} 到 {end_date}:\n\n{news_str}" + + except Exception as e: + return f"Error fetching A-share news for {ticker} via AkShare: {str(e)}" \ No newline at end of file diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 0caf4b68..82523cb8 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -11,6 +11,7 @@ from .y_finance import ( get_insider_transactions as get_yfinance_insider_transactions, ) from .yfinance_news import get_news_yfinance, get_global_news_yfinance +from .akshare_news import get_news_akshare from .alpha_vantage import ( get_stock as get_alpha_vantage_stock, get_indicator as get_alpha_vantage_indicator, @@ -63,6 +64,7 @@ TOOLS_CATEGORIES = { VENDOR_LIST = [ "yfinance", "alpha_vantage", + "akshare", # 新增 A股 数据源 ] # Mapping of methods to their vendor-specific implementations @@ -98,6 +100,7 @@ VENDOR_METHODS = { "get_news": { "alpha_vantage": get_alpha_vantage_news, "yfinance": get_news_yfinance, + "akshare": get_news_akshare, # 新增映射 }, "get_global_news": { "yfinance": get_global_news_yfinance, @@ -134,8 +137,18 @@ def get_vendor(category: str, method: str = None) -> str: def route_to_vendor(method: str, *args, **kwargs): """Route method calls to appropriate vendor implementation with fallback support.""" category = get_category_for_method(method) - vendor_config = get_vendor(category, method) - primary_vendors = [v.strip() for v in vendor_config.split(',')] + if len(args) > 0 and isinstance(args[0], str): + print(f"AAAAAA args: {args}") + ticker = args[0].upper() + # 如果带有上海(.SS)或深圳(.SZ)后缀,强制优先使用 akshare + if ticker.endswith(".SS") or ticker.endswith(".SZ"): + primary_vendors = ["akshare"] + else: + vendor_config = get_vendor(category, method) + primary_vendors = [v.strip() for v in vendor_config.split(',')] + else: + vendor_config = get_vendor(category, method) + primary_vendors = [v.strip() for v in vendor_config.split(',')] if method not in VENDOR_METHODS: raise ValueError(f"Method '{method}' not supported") diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index 93c2a7d3..9198ba81 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -34,7 +34,7 @@ def create_llm_client( """ provider_lower = provider.lower() - if provider_lower in ("openai", "ollama", "openrouter"): + if provider_lower in ("openai", "ollama", "openrouter", "siliconflow"): return OpenAIClient(model, base_url, provider=provider_lower, **kwargs) if provider_lower == "xai": diff --git a/tradingagents/llm_clients/model_catalog.py b/tradingagents/llm_clients/model_catalog.py index fd91c66d..3f364ffb 100644 --- a/tradingagents/llm_clients/model_catalog.py +++ b/tradingagents/llm_clients/model_catalog.py @@ -77,6 +77,18 @@ MODEL_OPTIONS: ProviderModeOptions = { ("Qwen3:latest (8B, local)", "qwen3:latest"), ], }, + "siliconflow": { + "quick": [ + ("Qwen/Qwen3.5-35B-A3B", "Qwen/Qwen3.5-35B-A3B"), + ("Qwen/Qwen3.5-27B", "Qwen/Qwen3.5-27B"), + ("Qwen/Qwen3.5-122B-A10B", "Qwen/Qwen3.5-122B-A10B"), + ], + "deep": [ + ("Qwen/Qwen3.5-122B-A10B", "Qwen/Qwen3.5-122B-A10B"), + ("Pro/zai-org/GLM-5.1", "Pro/zai-org/GLM-5.1"), + ("Pro/MiniMaxAI/MiniMax-M2.5", "Pro/MiniMaxAI/MiniMax-M2.5"), + ], + }, } diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index 4f2e1b32..9165d3e7 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -29,6 +29,7 @@ _PROVIDER_CONFIG = { "xai": ("https://api.x.ai/v1", "XAI_API_KEY"), "openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"), "ollama": ("http://localhost:11434/v1", None), + "siliconflow": ("https://api.siliconflow.cn/v1", "SILICONFLOW_API_KEY"), }