Add support for vllm
This commit is contained in:
parent
13b826a31d
commit
080ac8892f
111
cli/utils.py
111
cli/utils.py
|
|
@ -152,24 +152,43 @@ def select_shallow_thinking_agent(provider) -> str:
|
||||||
"ollama": [
|
"ollama": [
|
||||||
("llama3.1 local", "llama3.1"),
|
("llama3.1 local", "llama3.1"),
|
||||||
("llama3.2 local", "llama3.2"),
|
("llama3.2 local", "llama3.2"),
|
||||||
|
],
|
||||||
|
"vllm": [
|
||||||
|
("llama3.1 local", "llama3.1"),
|
||||||
|
("qwen3", "qwen3"),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
choice = questionary.select(
|
if provider == "vllm":
|
||||||
"Select Your [Quick-Thinking LLM Engine]:",
|
choice = questionary.text(
|
||||||
choices=[
|
"Please input the vllm model name for shallow thinking (default: llama3.1):",
|
||||||
questionary.Choice(display, value=value)
|
default="llama3.1",
|
||||||
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()]
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid model name.",
|
||||||
],
|
style=questionary.Style(
|
||||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
[
|
||||||
style=questionary.Style(
|
("text", "fg:green"),
|
||||||
[
|
("highlighted", "noinherit"),
|
||||||
("selected", "fg:magenta noinherit"),
|
]
|
||||||
("highlighted", "fg:magenta noinherit"),
|
),
|
||||||
("pointer", "fg:magenta noinherit"),
|
).ask()
|
||||||
]
|
|
||||||
),
|
else:
|
||||||
).ask()
|
# Use questionary to create an interactive selection menu for shallow thinking LLM engines
|
||||||
|
choice = questionary.select(
|
||||||
|
"Select Your [Quick-Thinking LLM Engine]:",
|
||||||
|
choices=[
|
||||||
|
questionary.Choice(display, value=value)
|
||||||
|
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()]
|
||||||
|
],
|
||||||
|
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||||
|
style=questionary.Style(
|
||||||
|
[
|
||||||
|
("selected", "fg:magenta noinherit"),
|
||||||
|
("highlighted", "fg:magenta noinherit"),
|
||||||
|
("pointer", "fg:magenta noinherit"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
if choice is None:
|
||||||
console.print(
|
console.print(
|
||||||
|
|
@ -214,24 +233,43 @@ def select_deep_thinking_agent(provider) -> str:
|
||||||
"ollama": [
|
"ollama": [
|
||||||
("llama3.1 local", "llama3.1"),
|
("llama3.1 local", "llama3.1"),
|
||||||
("qwen3", "qwen3"),
|
("qwen3", "qwen3"),
|
||||||
|
],
|
||||||
|
"vllm": [
|
||||||
|
("llama3.1 local", "llama3.1"),
|
||||||
|
("qwen3", "qwen3"),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
choice = questionary.select(
|
if provider == "vllm":
|
||||||
"Select Your [Deep-Thinking LLM Engine]:",
|
choice = questionary.text(
|
||||||
choices=[
|
"Please input the vllm model name for deep thinking (default: llama3.1):",
|
||||||
questionary.Choice(display, value=value)
|
default="llama3.1",
|
||||||
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid model name.",
|
||||||
],
|
style=questionary.Style(
|
||||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
[
|
||||||
style=questionary.Style(
|
("text", "fg:green"),
|
||||||
[
|
("highlighted", "noinherit"),
|
||||||
("selected", "fg:magenta noinherit"),
|
]
|
||||||
("highlighted", "fg:magenta noinherit"),
|
),
|
||||||
("pointer", "fg:magenta noinherit"),
|
).ask()
|
||||||
]
|
|
||||||
),
|
else:
|
||||||
).ask()
|
# Use questionary to create an interactive selection menu for deep thinking LLM engines
|
||||||
|
choice = questionary.select(
|
||||||
|
"Select Your [Deep-Thinking LLM Engine]:",
|
||||||
|
choices=[
|
||||||
|
questionary.Choice(display, value=value)
|
||||||
|
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
|
||||||
|
],
|
||||||
|
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||||
|
style=questionary.Style(
|
||||||
|
[
|
||||||
|
("selected", "fg:magenta noinherit"),
|
||||||
|
("highlighted", "fg:magenta noinherit"),
|
||||||
|
("pointer", "fg:magenta noinherit"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
if choice is None:
|
||||||
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
|
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
|
||||||
|
|
@ -248,6 +286,7 @@ def select_llm_provider() -> tuple[str, str]:
|
||||||
("Google", "https://generativelanguage.googleapis.com/v1"),
|
("Google", "https://generativelanguage.googleapis.com/v1"),
|
||||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
("Openrouter", "https://openrouter.ai/api/v1"),
|
||||||
("Ollama", "http://localhost:11434/v1"),
|
("Ollama", "http://localhost:11434/v1"),
|
||||||
|
("vllm", "http://localhost:8000/v1"),
|
||||||
]
|
]
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
|
|
@ -271,6 +310,18 @@ def select_llm_provider() -> tuple[str, str]:
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
display_name, url = choice
|
display_name, url = choice
|
||||||
|
if display_name == "vllm":
|
||||||
|
url = questionary.text(
|
||||||
|
"Please input the vllm api url (default: http://localhost:8000/v1):",
|
||||||
|
default="http://localhost:8000/v1",
|
||||||
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid URL.",
|
||||||
|
style=questionary.Style(
|
||||||
|
[
|
||||||
|
("text", "fg:green"),
|
||||||
|
("highlighted", "noinherit"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
).ask()
|
||||||
print(f"You selected: {display_name}\tURL: {url}")
|
print(f"You selected: {display_name}\tURL: {url}")
|
||||||
|
|
||||||
return display_name, url
|
return display_name, url
|
||||||
|
|
|
||||||
2
main.py
2
main.py
|
|
@ -20,6 +20,8 @@ config["data_vendors"] = {
|
||||||
"news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
|
"news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config["embeddings"] = "text-embedding-3-small"
|
||||||
|
|
||||||
# Initialize with custom config
|
# Initialize with custom config
|
||||||
ta = TradingAgentsGraph(debug=True, config=config)
|
ta = TradingAgentsGraph(debug=True, config=config)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,15 @@
|
||||||
import chromadb
|
import chromadb
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
import questionary
|
||||||
|
|
||||||
|
|
||||||
class FinancialSituationMemory:
|
class FinancialSituationMemory:
|
||||||
def __init__(self, name, config):
|
def __init__(self, name, config):
|
||||||
if config["backend_url"] == "http://localhost:11434/v1":
|
if config["backend_url"] == "http://localhost:11434/v1":
|
||||||
self.embedding = "nomic-embed-text"
|
self.embedding = "nomic-embed-text"
|
||||||
|
elif config["llm_provider"] == "vllm":
|
||||||
|
self.embedding = config["embeddings"]
|
||||||
else:
|
else:
|
||||||
self.embedding = "text-embedding-3-small"
|
self.embedding = "text-embedding-3-small"
|
||||||
self.client = OpenAI(base_url=config["backend_url"])
|
self.client = OpenAI(base_url=config["backend_url"])
|
||||||
|
|
|
||||||
|
|
@ -30,4 +30,5 @@ DEFAULT_CONFIG = {
|
||||||
# Example: "get_stock_data": "alpha_vantage", # Override category default
|
# Example: "get_stock_data": "alpha_vantage", # Override category default
|
||||||
# Example: "get_news": "openai", # Override category default
|
# Example: "get_news": "openai", # Override category default
|
||||||
},
|
},
|
||||||
|
"embeddings": "text-embedding-3-small", # Options: text-embedding-3-small, nomic-embed-text, vllm model name
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,8 @@ import json
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from typing import Dict, Any, Tuple, List, Optional
|
from typing import Dict, Any, Tuple, List, Optional
|
||||||
|
|
||||||
|
import questionary
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from langchain_anthropic import ChatAnthropic
|
from langchain_anthropic import ChatAnthropic
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
|
@ -72,7 +74,7 @@ class TradingAgentsGraph:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize LLMs
|
# Initialize LLMs
|
||||||
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
|
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter" or self.config["llm_provider"] == "vllm":
|
||||||
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
||||||
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
||||||
elif self.config["llm_provider"].lower() == "anthropic":
|
elif self.config["llm_provider"].lower() == "anthropic":
|
||||||
|
|
@ -85,11 +87,29 @@ class TradingAgentsGraph:
|
||||||
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
|
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
|
||||||
|
|
||||||
# Initialize memories
|
# Initialize memories
|
||||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
if self.config["llm_provider"] == "vllm":
|
||||||
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
questionary.text(
|
||||||
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
"Please input the vllm embedding model name (default: None):",
|
||||||
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
|
default="None",
|
||||||
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
|
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid embedding model name.",
|
||||||
|
style=questionary.Style(
|
||||||
|
[
|
||||||
|
("text", "fg:green"),
|
||||||
|
("highlighted", "noinherit"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
).ask()
|
||||||
|
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
||||||
|
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
||||||
|
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
||||||
|
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
|
||||||
|
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
|
||||||
|
else:
|
||||||
|
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
||||||
|
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
||||||
|
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
||||||
|
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
|
||||||
|
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
|
||||||
|
|
||||||
# Create tool nodes
|
# Create tool nodes
|
||||||
self.tool_nodes = self._create_tool_nodes()
|
self.tool_nodes = self._create_tool_nodes()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue