feat: vllm
This commit is contained in:
parent
817eaec872
commit
9e32a416ed
9
main.py
9
main.py
|
|
@ -1,15 +1,16 @@
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# Create a custom config
|
# Create a custom config
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config["llm_provider"] = "ollama" # Use a different model
|
config["llm_provider"] = "vllm" # Use a different model
|
||||||
config["backend_url"] = "http://localhost:11434/v1" # Use a different backend
|
config["backend_url"] = "http://localhost:8000/v1" # Use a different backend
|
||||||
config["deep_think_llm"] = "llama3.2" # Use a different model
|
config["deep_think_llm"] = "openai/gpt-oss-120b" # Use a different model
|
||||||
config["quick_think_llm"] = "llama3.2" # Use a different model
|
config["quick_think_llm"] = "openai/gpt-oss-120b" # Use a different model
|
||||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||||
config["online_tools"] = True # Increase debate rounds
|
config["online_tools"] = True # Increase debate rounds
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,13 @@ dependencies = [
|
||||||
"eodhd>=1.0.32",
|
"eodhd>=1.0.32",
|
||||||
"feedparser>=6.0.11",
|
"feedparser>=6.0.11",
|
||||||
"finnhub-python>=2.4.23",
|
"finnhub-python>=2.4.23",
|
||||||
|
"huggingface-hub>=0.34.4",
|
||||||
|
"kernels>=0.9.0",
|
||||||
"langchain-anthropic>=0.3.15",
|
"langchain-anthropic>=0.3.15",
|
||||||
|
"langchain-community>=0.3.25",
|
||||||
"langchain-experimental>=0.3.4",
|
"langchain-experimental>=0.3.4",
|
||||||
"langchain-google-genai>=2.1.5",
|
"langchain-google-genai>=2.1.5",
|
||||||
|
"langchain-ollama>=0.3.6",
|
||||||
"langchain-openai>=0.3.23",
|
"langchain-openai>=0.3.23",
|
||||||
"langgraph>=0.4.8",
|
"langgraph>=0.4.8",
|
||||||
"pandas>=2.3.0",
|
"pandas>=2.3.0",
|
||||||
|
|
@ -27,8 +31,11 @@ dependencies = [
|
||||||
"rich>=14.0.0",
|
"rich>=14.0.0",
|
||||||
"setuptools>=80.9.0",
|
"setuptools>=80.9.0",
|
||||||
"stockstats>=0.6.5",
|
"stockstats>=0.6.5",
|
||||||
|
"torch>=2.6.0",
|
||||||
"tqdm>=4.67.1",
|
"tqdm>=4.67.1",
|
||||||
|
"transformers>=4.53.3",
|
||||||
"tushare>=1.4.21",
|
"tushare>=1.4.21",
|
||||||
"typing-extensions>=4.14.0",
|
"typing-extensions>=4.14.0",
|
||||||
|
"vllm>=0.8.3",
|
||||||
"yfinance>=0.2.63",
|
"yfinance>=0.2.63",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@ from langchain_openai import ChatOpenAI
|
||||||
from langchain_anthropic import ChatAnthropic
|
from langchain_anthropic import ChatAnthropic
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
from langchain_ollama import ChatOllama
|
from langchain_ollama import ChatOllama
|
||||||
|
from langchain_community.llms import VLLM
|
||||||
|
|
||||||
|
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
|
|
@ -59,18 +61,44 @@ class TradingAgentsGraph:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize LLMs
|
# Initialize LLMs
|
||||||
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "openrouter":
|
if (
|
||||||
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
self.config["llm_provider"].lower() == "openai"
|
||||||
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
or self.config["llm_provider"] == "openrouter"
|
||||||
|
):
|
||||||
|
self.deep_thinking_llm = ChatOpenAI(
|
||||||
|
model=self.config["deep_think_llm"], base_url=self.config["backend_url"]
|
||||||
|
)
|
||||||
|
self.quick_thinking_llm = ChatOpenAI(
|
||||||
|
model=self.config["quick_think_llm"],
|
||||||
|
base_url=self.config["backend_url"],
|
||||||
|
)
|
||||||
elif self.config["llm_provider"].lower() == "ollama":
|
elif self.config["llm_provider"].lower() == "ollama":
|
||||||
self.deep_thinking_llm = ChatOllama(model=self.config["deep_think_llm"])
|
self.deep_thinking_llm = ChatOllama(model=self.config["deep_think_llm"])
|
||||||
self.quick_thinking_llm = ChatOllama(model=self.config["quick_think_llm"])
|
self.quick_thinking_llm = ChatOllama(model=self.config["quick_think_llm"])
|
||||||
elif self.config["llm_provider"].lower() == "anthropic":
|
elif self.config["llm_provider"].lower() == "anthropic":
|
||||||
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
self.deep_thinking_llm = ChatAnthropic(
|
||||||
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
model=self.config["deep_think_llm"], base_url=self.config["backend_url"]
|
||||||
|
)
|
||||||
|
self.quick_thinking_llm = ChatAnthropic(
|
||||||
|
model=self.config["quick_think_llm"],
|
||||||
|
base_url=self.config["backend_url"],
|
||||||
|
)
|
||||||
elif self.config["llm_provider"].lower() == "google":
|
elif self.config["llm_provider"].lower() == "google":
|
||||||
self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
|
self.deep_thinking_llm = ChatGoogleGenerativeAI(
|
||||||
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
|
model=self.config["deep_think_llm"]
|
||||||
|
)
|
||||||
|
self.quick_thinking_llm = ChatGoogleGenerativeAI(
|
||||||
|
model=self.config["quick_think_llm"]
|
||||||
|
)
|
||||||
|
elif self.config["llm_provider"].lower() == "vllm":
|
||||||
|
self.deep_thinking_llm = VLLM(
|
||||||
|
model=self.config["deep_think_llm"],
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
self.quick_thinking_llm = VLLM(
|
||||||
|
model=self.config["quick_think_llm"],
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
|
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
|
||||||
|
|
||||||
|
|
@ -80,8 +108,12 @@ class TradingAgentsGraph:
|
||||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
||||||
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
||||||
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
||||||
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
|
self.invest_judge_memory = FinancialSituationMemory(
|
||||||
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
|
"invest_judge_memory", self.config
|
||||||
|
)
|
||||||
|
self.risk_manager_memory = FinancialSituationMemory(
|
||||||
|
"risk_manager_memory", self.config
|
||||||
|
)
|
||||||
|
|
||||||
# Create tool nodes
|
# Create tool nodes
|
||||||
self.tool_nodes = self._create_tool_nodes()
|
self.tool_nodes = self._create_tool_nodes()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue