commit 2c9de7a144
Author: Jifeng Ge (committed via GitHub)
Date:   2025-10-16 15:12:18 +08:00
5 changed files with 100 additions and 35 deletions


@@ -152,24 +152,40 @@ def select_shallow_thinking_agent(provider) -> str:
"ollama": [
("llama3.1 local", "llama3.1"),
("llama3.2 local", "llama3.2"),
]
],
"vllm": [],
}
choice = questionary.select(
"Select Your [Quick-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()
if provider == "vllm":
choice = questionary.text(
"Please input the vllm model name for shallow thinking (default: llama3.1):",
default="llama3.1",
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid model name.",
style=questionary.Style(
[
("text", "fg:green"),
("highlighted", "noinherit"),
]
),
).ask()
else:
# Use questionary to create an interactive selection menu for shallow thinking LLM engines
choice = questionary.select(
"Select Your [Quick-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()
if choice is None:
console.print(
@@ -214,24 +230,40 @@ def select_deep_thinking_agent(provider) -> str:
"ollama": [
("llama3.1 local", "llama3.1"),
("qwen3", "qwen3"),
]
}
choice = questionary.select(
"Select Your [Deep-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()
"vllm": [],
}
if provider == "vllm":
choice = questionary.text(
"Please input the vllm model name for deep thinking (default: llama3.1):",
default="llama3.1",
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid model name.",
style=questionary.Style(
[
("text", "fg:green"),
("highlighted", "noinherit"),
]
),
).ask()
else:
# Use questionary to create an interactive selection menu for deep thinking LLM engines
choice = questionary.select(
"Select Your [Deep-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()
if choice is None:
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
@@ -247,7 +279,8 @@ def select_llm_provider() -> tuple[str, str]:
("Anthropic", "https://api.anthropic.com/"),
("Google", "https://generativelanguage.googleapis.com/v1"),
("Openrouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"),
("Ollama", "http://localhost:11434/v1"),
("vllm", "http://localhost:8000/v1"),
]
choice = questionary.select(
@@ -271,6 +304,18 @@ def select_llm_provider() -> tuple[str, str]:
exit(1)
display_name, url = choice
if display_name == "vllm":
url = questionary.text(
"Please input the vllm api url (default: http://localhost:8000/v1):",
default="http://localhost:8000/v1",
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid URL.",
style=questionary.Style(
[
("text", "fg:green"),
("highlighted", "noinherit"),
]
),
).ask()
print(f"You selected: {display_name}\tURL: {url}")
return display_name, url
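
The vLLM branches above switch from a fixed selection menu to free-text prompts, since a vLLM server can expose arbitrary model names. The `len(x.strip()) > 0 or "..."` validators rely on questionary's callable-validator convention, where returning True accepts the input and a returned string is shown as the error message. A minimal standalone sketch of that prompt pattern (the prompt text and default here are illustrative, not part of the commit):

```python
import questionary

# questionary treats a True return value from the validator as "valid"; anything
# else (here, the fallback string) is surfaced as the validation error message.
model_name = questionary.text(
    "vLLM model name:",
    default="llama3.1",
    validate=lambda x: len(x.strip()) > 0 or "Please enter a valid model name.",
).ask()

print(f"Selected model: {model_name}")
```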


@@ -20,6 +20,8 @@ config["data_vendors"] = {
"news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
}
config["embeddings"] = "text-embedding-3-small"
# Initialize with custom config
ta = TradingAgentsGraph(debug=True, config=config)
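
For a vLLM backend, the same snippet would point the chat models and the embedding model at the local OpenAI-compatible endpoint. A hedged sketch, assuming the import paths from the project README and placeholder model names (none of this is part of the commit):

```python
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG

config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "vllm"
config["backend_url"] = "http://localhost:8000/v1"               # vLLM OpenAI-compatible server
config["deep_think_llm"] = "meta-llama/Llama-3.1-8B-Instruct"    # placeholder model names
config["quick_think_llm"] = "meta-llama/Llama-3.1-8B-Instruct"
config["embeddings"] = "BAAI/bge-small-en-v1.5"                  # an embedding model served by vLLM

ta = TradingAgentsGraph(debug=True, config=config)
```

Note that, as written in this commit, constructing the graph with a vLLM provider still prompts interactively for the embedding model name (see the trading graph changes below), so the `embeddings` value set here would be overwritten by that prompt.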


@@ -7,6 +7,8 @@ class FinancialSituationMemory:
def __init__(self, name, config):
if config["backend_url"] == "http://localhost:11434/v1":
self.embedding = "nomic-embed-text"
elif config["llm_provider"] == "vllm":
self.embedding = config["embeddings"]
else:
self.embedding = "text-embedding-3-small"
self.client = OpenAI(base_url=config["backend_url"])
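
With a vLLM provider, the memory therefore asks the same OpenAI-compatible endpoint for embeddings using whatever model name `config["embeddings"]` carries. A rough sketch of that request (the server URL, dummy key, and model name are placeholders; a single vLLM process typically serves one model, so in practice the embedding model may need its own server):

```python
from openai import OpenAI

# vLLM does not validate the API key, but the OpenAI client requires one to be set.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.embeddings.create(
    model="BAAI/bge-small-en-v1.5",   # stands in for config["embeddings"]
    input="AAPL beat earnings expectations but guided lower for Q4.",
)
vector = response.data[0].embedding
print(len(vector))
```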


@@ -30,4 +30,5 @@ DEFAULT_CONFIG = {
# Example: "get_stock_data": "alpha_vantage", # Override category default
# Example: "get_news": "openai", # Override category default
},
"embeddings": "text-embedding-3-small", # Options: text-embedding-3-small, nomic-embed-text, vllm model name
}
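
For the vLLM option, the value has to name a model the server actually exposes. One way to check, sketched against the standard `/v1/models` endpoint that vLLM implements (the URL and key handling are assumptions, not part of the commit):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # vLLM ignores the key
for model in client.models.list().data:
    print(model.id)   # pick the served embedding model's id for config["embeddings"]
```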


@@ -6,6 +6,8 @@ import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
import questionary
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
@@ -72,7 +74,7 @@ class TradingAgentsGraph:
)
# Initialize LLMs
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter" or self.config["llm_provider"] == "vllm":
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "anthropic":
@@ -85,6 +87,19 @@ class TradingAgentsGraph:
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
# Initialize memories
if self.config["llm_provider"] == "vllm":
self.config["embeddings"] = questionary.text(
"Please input the vllm embedding model name (default: None):",
default="None",
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid embedding model name.",
style=questionary.Style(
[
("text", "fg:green"),
("highlighted", "noinherit"),
]
),
).ask()
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
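
Because vLLM is routed through the same ChatOpenAI branch as OpenAI, Ollama, and OpenRouter, the deep and quick thinkers become plain ChatOpenAI clients pointed at the vLLM base URL. A minimal sketch of that construction outside the graph (the model names, URL, and dummy key are assumptions, not part of the commit):

```python
from langchain_openai import ChatOpenAI

deep_thinking_llm = ChatOpenAI(
    model="meta-llama/Llama-3.1-70B-Instruct",   # placeholder for config["deep_think_llm"]
    base_url="http://localhost:8000/v1",
    api_key="EMPTY",                             # vLLM does not validate the key
)
quick_thinking_llm = ChatOpenAI(
    model="meta-llama/Llama-3.1-8B-Instruct",    # placeholder for config["quick_think_llm"]
    base_url="http://localhost:8000/v1",
    api_key="EMPTY",
)

print(quick_thinking_llm.invoke("Summarize AAPL in one sentence.").content)
```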