feat: support for ollama users who run their models locally.

This commit is contained in:
autotntfan 2025-07-07 18:09:12 +08:00
parent a438acdbbd
commit 1b3a1ce126
10 changed files with 127 additions and 20 deletions

11
main.py
View File

@ -3,10 +3,11 @@ from tradingagents.default_config import DEFAULT_CONFIG
# Create a custom config
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "google" # Use a different model
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
config["deep_think_llm"] = "gemini-2.0-flash" # Use a different model
config["quick_think_llm"] = "gemini-2.0-flash" # Use a different model
config["llm_provider"] = "ollama" # Use a different model
config["backend_url"] = "http://localhost:11434" # Use a different backend
config["deep_think_llm"] = "mixtral:8x7b-instruct-v0.1-q4_K_M" # Use a different model
config["quick_think_llm"] = "phi3:mini" # Use a different model
config["embedding_model"] = "fingpt:7b" # Use a different embedding model
config["max_debate_rounds"] = 1 # Increase debate rounds
config["online_tools"] = True # Increase debate rounds
@ -14,7 +15,7 @@ config["online_tools"] = True # Increase debate rounds
ta = TradingAgentsGraph(debug=True, config=config)
# forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10")
_, decision = ta.propagate("NVDA", "2025-07-07")
print(decision)
# Memorize mistakes and reflect

View File

@ -1,6 +1,7 @@
from .utils.agent_utils import Toolkit, create_msg_delete
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.memory import FinancialSituationMemory
from .utils.safe_bind_tools import safe_bind_tools
from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst
@ -21,6 +22,7 @@ from .trader.trader import create_trader
__all__ = [
"FinancialSituationMemory",
"safe_bind_tools",
"Toolkit",
"AgentState",
"create_msg_delete",

View File

@ -1,3 +1,4 @@
from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
@ -47,7 +48,7 @@ def create_fundamentals_analyst(llm, toolkit):
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools)
chain = prompt | safe_bind_tools(llm, tools)
result = chain.invoke(state["messages"])

View File

@ -1,3 +1,4 @@
from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
@ -72,7 +73,7 @@ Volume-Based Indicators:
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools)
chain = prompt | safe_bind_tools(llm, tools)
result = chain.invoke(state["messages"])

View File

@ -1,3 +1,4 @@
from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
@ -44,7 +45,7 @@ def create_news_analyst(llm, toolkit):
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools)
chain = prompt | safe_bind_tools(llm, tools)
result = chain.invoke(state["messages"])
report = ""

View File

@ -1,3 +1,4 @@
from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
@ -43,7 +44,7 @@ def create_social_media_analyst(llm, toolkit):
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools)
chain = prompt | safe_bind_tools(llm, tools)
result = chain.invoke(state["messages"])

View File

@ -1,24 +1,29 @@
import chromadb
from chromadb.config import Settings
from openai import OpenAI
from langchain_ollama import OllamaEmbeddings
class FinancialSituationMemory:
def __init__(self, name, config):
if config["backend_url"] == "http://localhost:11434/v1":
self.embedding = "nomic-embed-text"
if config["backend_url"] == "http://localhost:11434":
self.embedding = OllamaEmbeddings(
model=config["embedding_model"],
base_url=config["backend_url"], # Remove trailing slash
)
else:
self.embedding = "text-embedding-3-small"
self.client = OpenAI(base_url=config["backend_url"])
self.embedding = config["embedding_model"]
self.client = OpenAI(base_url=config["backend_url"])
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name)
def get_embedding(self, text):
"""Get OpenAI embedding for a text"""
response = self.client.embeddings.create(
model=self.embedding, input=text
)
try:
response = self.client.embeddings.create(
model=self.embedding, input=text
)
except AttributeError:
return self.embedding.embed_query(text)
return response.data[0].embedding
def add_situations(self, situations_and_advice):

View File

@ -0,0 +1,82 @@
"""
safe_bind_tools.py
Attach tool schemas only when the underlying LLM truly
supports OpenAI-style function calling.
OpenAI / Anthropic / Google models always attach
ChatOllama models attach **only**
if the Ollama tag contains `"tools": true`
All other cases silently fall
back to plain text reasoning
"""
from __future__ import annotations
import logging
import shlex
import subprocess
from typing import Any, Sequence
from langchain_core.language_models.chat_models import BaseChatModel
log = logging.getLogger(__name__)
def _ollama_has_tools_flag(model_name: str) -> bool:
"""
Return True iff `ollama show <model_name>` contains `"tools": true`.
If the command fails (e.g. Windows, sandbox), fall back to False.
"""
try:
output = subprocess.check_output(
shlex.split(f"ollama show {model_name}"), text=True
)
return '"tools": true' in output
except (NotImplementedError, AttributeError) as e:
log.debug("Could not inspect model %s: %s", model_name, e)
return False
def safe_bind_tools(
llm: BaseChatModel, tools: Sequence[dict[str, Any]]
) -> BaseChatModel:
"""
Attach `tools` to an LLM **only** if the model can actually handle them.
Otherwise, return the original LLM unchanged.
Parameters
----------
llm
Any LangChain chat model instance.
tools
List of tool schemas compatible with OpenAI function calling.
Returns
-------
BaseChatModel
Either the bound LLM (when tool calling is available) or the
original LLM (fallback).
"""
# LLM has no bind_tools method at all → nothing to do
if not hasattr(llm, "bind_tools"):
return llm
# Special-case ChatOllama: check the `"tools": true` tag first
if isinstance(llm, BaseChatModel) and not _ollama_has_tools_flag(llm.model):
log.info(
"[safe_bind_tools] Model %s lacks tools support -- skipping.",
llm.model,
)
return llm
# Generic path: try to bind; fall back gracefully on failure
try:
return llm.bind_tools(tools)
except (NotImplementedError, AttributeError) as e:
log.debug(
"[safe_bind_tools] bind_tools failed for %s: %s "
"falling back to plain reasoning.",
llm.__class__.__name__,
e,
)
return llm

View File

@ -13,6 +13,7 @@ DEFAULT_CONFIG = {
"deep_think_llm": "o4-mini",
"quick_think_llm": "gpt-4o-mini",
"backend_url": "https://api.openai.com/v1",
"embedding_model": "text-embedding-3-small",
# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,

View File

@ -9,7 +9,7 @@ from typing import Dict, Any, Tuple, List, Optional
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama.chat_models import ChatOllama
from langgraph.prebuilt import ToolNode
from tradingagents.agents import *
@ -58,7 +58,19 @@ class TradingAgentsGraph:
)
# Initialize LLMs
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
if self.config.get("llm_provider") == "ollama":
self.deep_thinking_llm = ChatOllama(
model=self.config["deep_think_llm"],
base_url=self.config["backend_url"],
temperature=0.2,
gpu_layers=32, # ← 這裡就能塞 Ollama 特有參數
)
self.quick_thinking_llm = ChatOllama(
model=self.config["quick_think_llm"],
base_url=self.config["backend_url"].rstrip("/v1"), # Remove trailing slash
temperature=0.1,
)
elif self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "anthropic":