From 2f89bf4fc7142268fe397ab893cd16ff38a68cc7 Mon Sep 17 00:00:00 2001 From: luceluo <> Date: Wed, 25 Jun 2025 17:15:52 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E9=98=BF=E9=87=8C=E5=8D=83?= =?UTF-8?q?=E9=97=AE=E6=A8=A1=E5=9E=8B=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 ++ cli/main.py | 10 ++++++++++ cli/utils.py | 15 +++++++++++++-- tradingagents/agents/utils/memory.py | 12 ++++++++++-- tradingagents/default_config.py | 2 +- tradingagents/graph/trading_graph.py | 12 ++++++++++++ 6 files changed, 48 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 8313619e..e864a6e9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ env/ __pycache__/ .DS_Store +.vscode/ +tradingagents-env/ *.csv src/ eval_results/ diff --git a/cli/main.py b/cli/main.py index 3f42f2e2..b0485d26 100644 --- a/cli/main.py +++ b/cli/main.py @@ -741,6 +741,16 @@ def run_analysis(): config["deep_think_llm"] = selections["deep_thinker"] config["backend_url"] = selections["backend_url"] config["llm_provider"] = selections["llm_provider"].lower() + + # Add API key based on provider + import os + if config["llm_provider"] == "alibaba": + config["api_key"] = os.getenv("DASHSCOPE_API_KEY", "") + if not config["api_key"]: + console.print("\n[red]Error: DASHSCOPE_API_KEY environment variable not set for Alibaba provider![/red]") + console.print("[yellow]Please set your Alibaba API key:[/yellow]") + console.print("[cyan]export DASHSCOPE_API_KEY='your-api-key-here'[/cyan]") + exit(1) # Initialize the graph graph = TradingAgentsGraph( diff --git a/cli/utils.py b/cli/utils.py index d3873360..c83335a2 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -151,6 +151,11 @@ def select_shallow_thinking_agent(provider) -> str: ], "ollama": [ ("llama3.2 local", "llama3.2"), + ], + "alibaba": [ + ("Qwen-Plus (阿里云) - 平衡性能和成本", "qwen-plus"), + ("Qwen-Max (阿里云) - 最强推理能力", "qwen-max"), + ("Qwen-Turbo (阿里云) - 快速低成本", "qwen-turbo"), ] } @@ -211,7 +216,12 @@ def select_deep_thinking_agent(provider) -> str: ("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"), ], "ollama": [ - ("qwen3", "qwen3"), + ("llama3.2 local", "llama3.2"), + ], + "alibaba": [ + ("Qwen-Plus (阿里云) - 平衡性能和成本", "qwen-plus"), + ("Qwen-Max (阿里云) - 最强推理能力", "qwen-max"), + ("Qwen-Turbo (阿里云) - 快速低成本", "qwen-turbo"), ] } @@ -245,7 +255,8 @@ def select_llm_provider() -> tuple[str, str]: ("Anthropic", "https://api.anthropic.com/"), ("Google", "https://generativelanguage.googleapis.com/v1"), ("Openrouter", "https://openrouter.ai/api/v1"), - ("Ollama", "http://localhost:11434/v1"), + ("Ollama", "http://localhost:11434/v1"), + ("Alibaba", "https://dashscope.aliyuncs.com/compatible-mode/v1"), ] choice = questionary.select( diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index f3415765..12b0a7d2 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -5,16 +5,24 @@ from openai import OpenAI class FinancialSituationMemory: def __init__(self, name, config): + self.config = config if config["backend_url"] == "http://localhost:11434/v1": self.embedding = "nomic-embed-text" + self.client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama") # Ollama client + elif config.get("llm_provider", "").lower() == "alibaba": + self.embedding = "text-embedding-v3" # 阿里云embedding模型 + self.client = OpenAI( + base_url=config["backend_url"], + api_key=config.get("api_key", "") + ) # 阿里云客户端 else: self.embedding = "text-embedding-3-small" - self.client = OpenAI() + self.client = OpenAI() # Standard OpenAI client self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.situation_collection = self.chroma_client.create_collection(name=name) def get_embedding(self, text): - """Get OpenAI embedding for a text""" + """Get embedding for a text using the configured client""" response = self.client.embeddings.create( model=self.embedding, input=text diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 2cf15b85..ee21840e 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -2,7 +2,7 @@ import os DEFAULT_CONFIG = { "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), - "data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data", + "data_dir": "/Users/luosibao/Documents/Code/ScAI/FR1-data", "data_cache_dir": os.path.join( os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "dataflows/data_cache", diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index eb06cf43..0e43b8c3 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -61,6 +61,18 @@ class TradingAgentsGraph: if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter": self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]) self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]) + elif self.config["llm_provider"].lower() == "alibaba": + # 阿里云需要特殊的API key配置 + self.deep_thinking_llm = ChatOpenAI( + model=self.config["deep_think_llm"], + base_url=self.config["backend_url"], + api_key=self.config.get("api_key", "") + ) + self.quick_thinking_llm = ChatOpenAI( + model=self.config["quick_think_llm"], + base_url=self.config["backend_url"], + api_key=self.config.get("api_key", "") + ) elif self.config["llm_provider"].lower() == "anthropic": self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]) self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])