support xai
This commit is contained in:
parent
13b826a31d
commit
108bab8402
12
cli/utils.py
12
cli/utils.py
|
|
@ -148,10 +148,14 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||
("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"),
|
||||
("Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B", "meta-llama/llama-3.3-8b-instruct:free"),
|
||||
("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"),
|
||||
("qwen/qwen3-coder:free - Qwen3-Coder-480B-A35B-Instruct is a Mixture-of-Experts (MoE) code generation model", "qwen/qwen3-coder:free"),
|
||||
],
|
||||
"ollama": [
|
||||
("llama3.1 local", "llama3.1"),
|
||||
("llama3.2 local", "llama3.2"),
|
||||
],
|
||||
"xai": [
|
||||
("grok-4-fast-non-reasoning", "grok-4-fast-non-reasoning")
|
||||
]
|
||||
}
|
||||
|
||||
|
|
@ -209,11 +213,14 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
],
|
||||
"openrouter": [
|
||||
("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"),
|
||||
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"),
|
||||
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free")
|
||||
],
|
||||
"ollama": [
|
||||
("llama3.1 local", "llama3.1"),
|
||||
("qwen3", "qwen3"),
|
||||
],
|
||||
"xai": [
|
||||
("grok-4-fast-reasoning", "grok-4-fast-reasoning")
|
||||
]
|
||||
}
|
||||
|
||||
|
|
@ -243,11 +250,12 @@ def select_llm_provider() -> tuple[str, str]:
|
|||
"""Select the OpenAI api url using interactive selection."""
|
||||
# Define the OpenAI-compatible API options with their corresponding endpoints.
# Each entry is a (display label, base URL) pair consumed by the interactive
# provider selector.
BASE_URLS = [
    ("XAI", "https://api.x.ai/v1"),
    ("OpenAI", "https://api.openai.com/v1"),
    ("Anthropic", "https://api.anthropic.com/"),
    ("Google", "https://generativelanguage.googleapis.com/v1"),
    ("Openrouter", "https://openrouter.ai/api/v1"),
    # Fix: the Ollama entry appeared twice, showing a duplicate choice in the
    # selection menu; keep a single entry.
    ("Ollama", "http://localhost:11434/v1"),
]
|
||||
|
||||
choice = questionary.select(
|
||||
|
|
|
|||
|
|
@ -24,3 +24,5 @@ rich
|
|||
questionary
|
||||
langchain_anthropic
|
||||
langchain-google-genai
|
||||
langchain-xai
|
||||
sentence-transformers
|
||||
|
|
@ -1,25 +1,44 @@
|
|||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from openai import OpenAI
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
class FinancialSituationMemory:
    """Vector-store-backed memory of financial situations and advice.

    The embedding backend is chosen from ``config["llm_provider"]``:
    Ollama's local OpenAI-compatible API, a local sentence-transformers
    model (for the xAI/Grok provider), or the OpenAI embeddings API.
    """

    def __init__(self, name, config):
        # name: Chroma collection name.
        # config: dict read here for "backend_url" and "llm_provider"
        # (and optionally "api_key") — assumes those keys exist; TODO confirm
        # against callers.
        if config["backend_url"] == "http://localhost:11434/v1":
            self.client = OpenAI(base_url=config["backend_url"], api_key=config.get("api_key"))

        # Decide the embedding strategy based on the backend.
        if config["llm_provider"] == "ollama":
            self.embedding = "nomic-embed-text"
            self.embedding_client = OpenAI(base_url="http://localhost:11434/v1")
            # NOTE(review): False looks intentional — embeddings still go
            # through the (local) OpenAI-compatible API rather than an
            # in-process model; confirm.
            self.use_local_embedding = False
        elif config["llm_provider"] == "xai":
            # Grok — use a local Hugging Face model for embeddings.
            # NOTE(review): on this path neither self.client nor
            # self.embedding_client is set unless backend_url was the
            # localhost Ollama URL above; get_embedding's local path does not
            # need them, but verify no other method does.
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            self.use_local_embedding = True
        else:
            self.embedding = "text-embedding-3-small"
            self.client = OpenAI(base_url=config["backend_url"])
            self.embedding_client = self.client
            self.use_local_embedding = False

        # In-memory Chroma store holding one collection per memory instance.
        self.chroma_client = chromadb.Client(Settings(allow_reset=True))
        self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||
|
||||
|
||||
def get_embedding(self, text):
    """Return the embedding vector for *text* as a list of floats.

    When ``self.use_local_embedding`` is set (xAI/Grok provider), the
    local sentence-transformers model is used; otherwise the configured
    OpenAI-compatible embeddings API is called with ``self.embedding``
    as the model name.
    """
    # Fix: this span contained both the pre-change body (old docstring plus
    # an unconditional self.client.embeddings.create call) and the new
    # backend-aware body fused together by the diff, leaving duplicate
    # docstrings and dead code. Only the backend-aware implementation is kept.
    if self.use_local_embedding:
        # Local Hugging Face model; encode() returns an array, so convert
        # to a plain list for Chroma.
        return self.embedding_model.encode(text).tolist()
    else:
        # Remote OpenAI-compatible embeddings API.
        response = self.embedding_client.embeddings.create(
            model=self.embedding,
            input=text
        )
        return response.data[0].embedding
|
||||
|
||||
|
||||
def add_situations(self, situations_and_advice):
|
||||
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import Dict, Any, Tuple, List, Optional
|
|||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_xai import ChatXAI
|
||||
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
|
|
@ -62,6 +63,47 @@ class TradingAgentsGraph:
|
|||
self.debug = debug
|
||||
self.config = config or DEFAULT_CONFIG
|
||||
|
||||
# get the language
|
||||
self.output_language = self.config.get("output_language", "en")
|
||||
self.language_instruction = self.config.get("language_system_prompts", {}).get(
|
||||
self.output_language, ""
|
||||
)
|
||||
|
||||
self._setup_agents()
|
||||
self._setup_graph()
|
||||
|
||||
def _setup_agents(self):
    """Initialize all agents, injecting the output-language instruction."""
    # Only inject when a language instruction was actually configured;
    # otherwise agents keep their unmodified system prompts.
    if self.language_instruction:
        self._inject_language_to_agents()
|
||||
|
||||
def _inject_language_to_agents(self):
    """Inject the output-language instruction into all agents."""
    # Builds the SystemMessage prepended to each agent's system prompt.
    # NOTE(review): SystemMessage is not imported in this hunk — confirm it
    # is imported at module level.
    self.language_system_message = SystemMessage(content=self.language_instruction)
|
||||
|
||||
def _create_agent_with_language(self, agent_name, agent_prompt, llm):
    """Create an agent whose system prompt includes the language instruction.

    Args:
        agent_name: Display name of the agent (unused in this body).
        agent_prompt: Base system prompt text for the agent.
        llm: Chat model handed to the agent constructor.

    Returns:
        An AgentExecutor wrapping an OpenAI-functions agent with no tools,
        verbose when self.debug is set.
    """
    # Function-scope imports keep langchain.agents out of module import time.
    from langchain.agents import AgentExecutor, create_openai_functions_agent
    from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

    # Build a system prompt that appends the language instruction (when
    # present) directly after the agent's own prompt.
    if self.language_instruction:
        system_prompt = f"{agent_prompt}{self.language_instruction}"
    else:
        system_prompt = agent_prompt

    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="chat_history", optional=True),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])

    agent = create_openai_functions_agent(llm, tools=[], prompt=prompt)
    return AgentExecutor(agent=agent, tools=[], verbose=self.debug)
|
||||
|
||||
# Update the interface's config
|
||||
set_config(self.config)
|
||||
|
||||
|
|
@ -81,6 +123,9 @@ class TradingAgentsGraph:
|
|||
elif self.config["llm_provider"].lower() == "google":
|
||||
self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
|
||||
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
|
||||
elif self.config["llm_provider"].lower() == "xai":
|
||||
self.deep_thinking_llm = ChatXAI(model=self.config["deep_think_llm"])
|
||||
self.quick_thinking_llm = ChatXAI(model=self.config["quick_think_llm"])
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue