diff --git a/cli/utils.py b/cli/utils.py index 7b9682a6..d2e7b23e 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -148,10 +148,14 @@ def select_shallow_thinking_agent(provider) -> str: ("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"), ("Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B", "meta-llama/llama-3.3-8b-instruct:free"), ("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"), + ("qwen/qwen3-coder:free - Qwen3-Coder-480B-A35B-Instruct is a Mixture-of-Experts (MoE) code generation model", "qwen/qwen3-coder:free"), ], "ollama": [ ("llama3.1 local", "llama3.1"), ("llama3.2 local", "llama3.2"), + ], + "xai": [ + ("grok-4-fast-non-reasoning", "grok-4-fast-non-reasoning") ] } @@ -209,11 +213,14 @@ def select_deep_thinking_agent(provider) -> str: ], "openrouter": [ ("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"), - ("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"), + ("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free") ], "ollama": [ ("llama3.1 local", "llama3.1"), ("qwen3", "qwen3"), + ], + "xai": [ + ("grok-4-fast-reasoning", "grok-4-fast-reasoning") ] } @@ -243,11 +250,12 @@ def select_llm_provider() -> tuple[str, str]: """Select the OpenAI api url using interactive selection.""" # Define OpenAI api options with their corresponding endpoints BASE_URLS = [ + ("XAI", "https://api.x.ai/v1"), ("OpenAI", "https://api.openai.com/v1"), ("Anthropic", "https://api.anthropic.com/"), ("Google", "https://generativelanguage.googleapis.com/v1"), ("Openrouter", "https://openrouter.ai/api/v1"), - ("Ollama", "http://localhost:11434/v1"), + ("Ollama", "http://localhost:11434/v1"), ] choice = questionary.select( diff --git a/requirements.txt b/requirements.txt index a6154cd2..57451615 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,3 +24,5 @@ rich questionary langchain_anthropic langchain-google-genai +langchain-xai +sentence-transformers \ No newline at end of file diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 69b8ab8c..e5709545 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,25 +1,44 @@ import chromadb from chromadb.config import Settings from openai import OpenAI +from sentence_transformers import SentenceTransformer class FinancialSituationMemory: def __init__(self, name, config): - if config["backend_url"] == "http://localhost:11434/v1": + self.client = OpenAI(base_url=config["backend_url"], api_key=config.get("api_key")) + + # 根據 backend 決定 embedding 策略 + if config["llm_provider"] == "ollama": self.embedding = "nomic-embed-text" + self.embedding_client = OpenAI(base_url="http://localhost:11434/v1") + self.use_local_embedding = False + elif config["llm_provider"] == "xai": + # Grok - 使用本地 Hugging Face 模型 + self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2') + self.use_local_embedding = True else: self.embedding = "text-embedding-3-small" - self.client = OpenAI(base_url=config["backend_url"]) + self.embedding_client = self.client + self.use_local_embedding = False + self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.situation_collection = self.chroma_client.create_collection(name=name) + def get_embedding(self, text): - """Get OpenAI embedding for a text""" - - response = self.client.embeddings.create( - model=self.embedding, input=text - ) - return response.data[0].embedding + """Get embedding for a text""" + if self.use_local_embedding: + # 使用本地 Hugging Face 模型 + return self.embedding_model.encode(text).tolist() + else: + # 使用 API + response = self.embedding_client.embeddings.create( + model=self.embedding, + input=text + ) + return response.data[0].embedding + def add_situations(self, situations_and_advice): """Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)""" diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 40cdff75..e472b143 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -9,6 +9,7 @@ from typing import Dict, Any, Tuple, List, Optional from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_xai import ChatXAI from langgraph.prebuilt import ToolNode @@ -62,6 +63,47 @@ class TradingAgentsGraph: self.debug = debug self.config = config or DEFAULT_CONFIG + # get the language + self.output_language = self.config.get("output_language", "en") + self.language_instruction = self.config.get("language_system_prompts", {}).get( + self.output_language, "" + ) + + self._setup_agents() + self._setup_graph() + + def _setup_agents(self): + """初始化所有 agents,並注入語言指示""" + # 為每個 agent 加入語言指示 + if self.language_instruction: + self._inject_language_to_agents() + + def _inject_language_to_agents(self): + """將語言指示注入所有 agents""" + # 這個方法會在每個 agent 的 system message 前加入語言指示 + self.language_system_message = SystemMessage(content=self.language_instruction) + + def _create_agent_with_language(self, agent_name, agent_prompt, llm): + """創建帶有語言指示的 agent""" + from langchain.agents import AgentExecutor, create_openai_functions_agent + from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder + + # 構建包含語言指示的 prompt + if self.language_instruction: + system_prompt = f"{agent_prompt}{self.language_instruction}" + else: + system_prompt = agent_prompt + + prompt = ChatPromptTemplate.from_messages([ + ("system", system_prompt), + MessagesPlaceholder(variable_name="chat_history", optional=True), + ("human", "{input}"), + MessagesPlaceholder(variable_name="agent_scratchpad"), + ]) + + agent = create_openai_functions_agent(llm, tools=[], prompt=prompt) + return AgentExecutor(agent=agent, tools=[], verbose=self.debug) + # Update the interface's config set_config(self.config) @@ -81,6 +123,9 @@ class TradingAgentsGraph: elif self.config["llm_provider"].lower() == "google": self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"]) self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"]) + elif self.config["llm_provider"].lower() == "xai": + self.deep_thinking_llm = ChatXAI(model=self.config["deep_think_llm"]) + self.quick_thinking_llm = ChatXAI(model=self.config["quick_think_llm"]) else: raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")