feat: support for ollama users who run their models locally.

This commit is contained in:
autotntfan 2025-07-07 18:09:12 +08:00
parent a438acdbbd
commit 1b3a1ce126
10 changed files with 127 additions and 20 deletions

11
main.py
View File

@ -3,10 +3,11 @@ from tradingagents.default_config import DEFAULT_CONFIG
# Create a custom config # Create a custom config
config = DEFAULT_CONFIG.copy() config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "google" # Use a different model config["llm_provider"] = "ollama" # Use a different model
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend config["backend_url"] = "http://localhost:11434" # Use a different backend
config["deep_think_llm"] = "gemini-2.0-flash" # Use a different model config["deep_think_llm"] = "mixtral:8x7b-instruct-v0.1-q4_K_M" # Use a different model
config["quick_think_llm"] = "gemini-2.0-flash" # Use a different model config["quick_think_llm"] = "phi3:mini" # Use a different model
config["embedding_model"] = "fingpt:7b" # Use a different embedding model
config["max_debate_rounds"] = 1 # Increase debate rounds config["max_debate_rounds"] = 1 # Increase debate rounds
config["online_tools"] = True # Increase debate rounds config["online_tools"] = True # Increase debate rounds
@ -14,7 +15,7 @@ config["online_tools"] = True # Increase debate rounds
ta = TradingAgentsGraph(debug=True, config=config) ta = TradingAgentsGraph(debug=True, config=config)
# forward propagate # forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10") _, decision = ta.propagate("NVDA", "2025-07-07")
print(decision) print(decision)
# Memorize mistakes and reflect # Memorize mistakes and reflect

View File

@ -1,6 +1,7 @@
from .utils.agent_utils import Toolkit, create_msg_delete from .utils.agent_utils import Toolkit, create_msg_delete
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.memory import FinancialSituationMemory from .utils.memory import FinancialSituationMemory
from .utils.safe_bind_tools import safe_bind_tools
from .analysts.fundamentals_analyst import create_fundamentals_analyst from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst from .analysts.market_analyst import create_market_analyst
@ -21,6 +22,7 @@ from .trader.trader import create_trader
__all__ = [ __all__ = [
"FinancialSituationMemory", "FinancialSituationMemory",
"safe_bind_tools",
"Toolkit", "Toolkit",
"AgentState", "AgentState",
"create_msg_delete", "create_msg_delete",

View File

@ -1,3 +1,4 @@
from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time import time
import json import json
@ -47,7 +48,7 @@ def create_fundamentals_analyst(llm, toolkit):
prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker) prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools) chain = prompt | safe_bind_tools(llm, tools)
result = chain.invoke(state["messages"]) result = chain.invoke(state["messages"])

View File

@ -1,3 +1,4 @@
from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time import time
import json import json
@ -72,7 +73,7 @@ Volume-Based Indicators:
prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker) prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools) chain = prompt | safe_bind_tools(llm, tools)
result = chain.invoke(state["messages"]) result = chain.invoke(state["messages"])

View File

@ -1,3 +1,4 @@
from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time import time
import json import json
@ -44,7 +45,7 @@ def create_news_analyst(llm, toolkit):
prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker) prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools) chain = prompt | safe_bind_tools(llm, tools)
result = chain.invoke(state["messages"]) result = chain.invoke(state["messages"])
report = "" report = ""

View File

@ -1,3 +1,4 @@
from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time import time
import json import json
@ -43,7 +44,7 @@ def create_social_media_analyst(llm, toolkit):
prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker) prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools) chain = prompt | safe_bind_tools(llm, tools)
result = chain.invoke(state["messages"]) result = chain.invoke(state["messages"])

View File

@ -1,24 +1,29 @@
import chromadb import chromadb
from chromadb.config import Settings from chromadb.config import Settings
from openai import OpenAI from openai import OpenAI
from langchain_ollama import OllamaEmbeddings
class FinancialSituationMemory: class FinancialSituationMemory:
def __init__(self, name, config): def __init__(self, name, config):
if config["backend_url"] == "http://localhost:11434/v1": if config["backend_url"] == "http://localhost:11434":
self.embedding = "nomic-embed-text" self.embedding = OllamaEmbeddings(
model=config["embedding_model"],
base_url=config["backend_url"],  # e.g. "http://localhost:11434"
)
else: else:
self.embedding = "text-embedding-3-small" self.embedding = config["embedding_model"]
self.client = OpenAI(base_url=config["backend_url"]) self.client = OpenAI(base_url=config["backend_url"])
self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name) self.situation_collection = self.chroma_client.create_collection(name=name)
def get_embedding(self, text): def get_embedding(self, text):
"""Get OpenAI embedding for a text""" """Get OpenAI embedding for a text"""
try:
response = self.client.embeddings.create( response = self.client.embeddings.create(
model=self.embedding, input=text model=self.embedding, input=text
) )
except AttributeError:
return self.embedding.embed_query(text)
return response.data[0].embedding return response.data[0].embedding
def add_situations(self, situations_and_advice): def add_situations(self, situations_and_advice):

View File

@ -0,0 +1,82 @@
"""
safe_bind_tools.py
Attach tool schemas only when the underlying LLM truly
supports OpenAI-style function calling.

* OpenAI / Anthropic / Google models: always attach.
* ChatOllama models: attach only if the Ollama model tag
  reports `"tools": true`.
* All other cases silently fall back to plain-text reasoning.
"""
from __future__ import annotations
import logging
import shlex
import subprocess
from typing import Any, Sequence
from langchain_core.language_models.chat_models import BaseChatModel
log = logging.getLogger(__name__)
def _ollama_has_tools_flag(model_name: str) -> bool:
    """
    Return True iff `ollama show <model_name>` reports `"tools": true`.

    Falls back to False whenever the model cannot be inspected: the
    `ollama` binary is missing (e.g. Windows, sandbox -> FileNotFoundError,
    an OSError subclass), or the command exits non-zero (unknown model
    -> CalledProcessError).
    """
    try:
        # Pass argv as a list (shell=False): no shlex parsing needed, and a
        # model name containing quotes/spaces cannot break the command line.
        output = subprocess.check_output(
            ["ollama", "show", model_name], text=True
        )
        return '"tools": true' in output
    except (OSError, subprocess.CalledProcessError) as e:
        # The original caught (NotImplementedError, AttributeError), which
        # subprocess never raises here — a missing binary or failed command
        # would have propagated instead of returning False as documented.
        log.debug("Could not inspect model %s: %s", model_name, e)
        return False
def safe_bind_tools(
    llm: BaseChatModel, tools: Sequence[dict[str, Any]]
) -> BaseChatModel:
    """
    Attach `tools` to an LLM **only** if the model can actually handle them.
    Otherwise, return the original LLM unchanged.

    Parameters
    ----------
    llm
        Any LangChain chat model instance.
    tools
        List of tool schemas compatible with OpenAI function calling.

    Returns
    -------
    BaseChatModel
        Either the bound LLM (when tool calling is available) or the
        original LLM (fallback).
    """
    # LLM has no bind_tools method at all → nothing to do
    if not hasattr(llm, "bind_tools"):
        return llm

    # Special-case ChatOllama only. Every LangChain chat model is a
    # BaseChatModel, so the original `isinstance(llm, BaseChatModel)` guard
    # routed OpenAI/Anthropic/Google models through the `ollama show` probe
    # too (which fails for them → False) and wrongly skipped binding,
    # contradicting "OpenAI / Anthropic / Google models always attach".
    # Match on the class name so only local Ollama models are probed.
    if type(llm).__name__ == "ChatOllama" and not _ollama_has_tools_flag(
        llm.model
    ):
        log.info(
            "[safe_bind_tools] Model %s lacks tools support -- skipping.",
            llm.model,
        )
        return llm

    # Generic path: try to bind; fall back gracefully on failure
    try:
        return llm.bind_tools(tools)
    except (NotImplementedError, AttributeError) as e:
        log.debug(
            "[safe_bind_tools] bind_tools failed for %s: %s "
            "falling back to plain reasoning.",
            llm.__class__.__name__,
            e,
        )
        return llm

View File

@ -13,6 +13,7 @@ DEFAULT_CONFIG = {
"deep_think_llm": "o4-mini", "deep_think_llm": "o4-mini",
"quick_think_llm": "gpt-4o-mini", "quick_think_llm": "gpt-4o-mini",
"backend_url": "https://api.openai.com/v1", "backend_url": "https://api.openai.com/v1",
"embedding_model": "text-embedding-3-small",
# Debate and discussion settings # Debate and discussion settings
"max_debate_rounds": 1, "max_debate_rounds": 1,
"max_risk_discuss_rounds": 1, "max_risk_discuss_rounds": 1,

View File

@ -9,7 +9,7 @@ from typing import Dict, Any, Tuple, List, Optional
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama.chat_models import ChatOllama
from langgraph.prebuilt import ToolNode from langgraph.prebuilt import ToolNode
from tradingagents.agents import * from tradingagents.agents import *
@ -58,7 +58,19 @@ class TradingAgentsGraph:
) )
# Initialize LLMs # Initialize LLMs
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter": if self.config.get("llm_provider") == "ollama":
self.deep_thinking_llm = ChatOllama(
model=self.config["deep_think_llm"],
base_url=self.config["backend_url"],
temperature=0.2,
gpu_layers=32,  # Ollama-specific parameters (e.g. GPU offload layers) can be passed here
)
self.quick_thinking_llm = ChatOllama(
model=self.config["quick_think_llm"],
base_url=self.config["backend_url"].rstrip("/v1"),  # NOTE(review): rstrip("/v1") strips trailing '/', 'v', '1' CHARACTERS, not the literal "/v1" suffix — use removesuffix("/v1") instead
temperature=0.1,
)
elif self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]) self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]) self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "anthropic": elif self.config["llm_provider"].lower() == "anthropic":