From 1b3a1ce126140ce4eae59283abf9aa6d39c5c970 Mon Sep 17 00:00:00 2001 From: autotntfan Date: Mon, 7 Jul 2025 18:09:12 +0800 Subject: [PATCH] feat: support for ollama users who run their models locally. --- main.py | 11 +-- tradingagents/agents/__init__.py | 2 + .../agents/analysts/fundamentals_analyst.py | 3 +- .../agents/analysts/market_analyst.py | 3 +- tradingagents/agents/analysts/news_analyst.py | 3 +- .../agents/analysts/social_media_analyst.py | 3 +- tradingagents/agents/utils/memory.py | 23 ++++-- tradingagents/agents/utils/safe_bind_tools.py | 82 +++++++++++++++++++ tradingagents/default_config.py | 1 + tradingagents/graph/trading_graph.py | 16 +++- 10 files changed, 127 insertions(+), 20 deletions(-) create mode 100644 tradingagents/agents/utils/safe_bind_tools.py diff --git a/main.py b/main.py index 6c8ae3d9..84c11bfe 100644 --- a/main.py +++ b/main.py @@ -3,10 +3,11 @@ from tradingagents.default_config import DEFAULT_CONFIG # Create a custom config config = DEFAULT_CONFIG.copy() -config["llm_provider"] = "google" # Use a different model -config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend -config["deep_think_llm"] = "gemini-2.0-flash" # Use a different model -config["quick_think_llm"] = "gemini-2.0-flash" # Use a different model +config["llm_provider"] = "ollama" # Use a different model +config["backend_url"] = "http://localhost:11434" # Use a different backend +config["deep_think_llm"] = "mixtral:8x7b-instruct-v0.1-q4_K_M" # Use a different model +config["quick_think_llm"] = "phi3:mini" # Use a different model +config["embedding_model"] = "fingpt:7b" # Use a different embedding model config["max_debate_rounds"] = 1 # Increase debate rounds config["online_tools"] = True # Increase debate rounds @@ -14,7 +15,7 @@ config["online_tools"] = True # Increase debate rounds ta = TradingAgentsGraph(debug=True, config=config) # forward propagate -_, decision = ta.propagate("NVDA", "2024-05-10") +_, decision = 
ta.propagate("NVDA", "2025-07-07") print(decision) # Memorize mistakes and reflect diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index 6f507651..0755cf4c 100644 --- a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -1,6 +1,7 @@ from .utils.agent_utils import Toolkit, create_msg_delete from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState from .utils.memory import FinancialSituationMemory +from .utils.safe_bind_tools import safe_bind_tools from .analysts.fundamentals_analyst import create_fundamentals_analyst from .analysts.market_analyst import create_market_analyst @@ -21,6 +22,7 @@ from .trader.trader import create_trader __all__ = [ "FinancialSituationMemory", + "safe_bind_tools", "Toolkit", "AgentState", "create_msg_delete", diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index 716d4de1..9da91cd2 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -1,3 +1,4 @@ +from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder import time import json @@ -47,7 +48,7 @@ def create_fundamentals_analyst(llm, toolkit): prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(ticker=ticker) - chain = prompt | llm.bind_tools(tools) + chain = prompt | safe_bind_tools(llm, tools) result = chain.invoke(state["messages"]) diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index 41ee944b..0670b2b7 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -1,3 +1,4 @@ +from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder import time 
import json @@ -72,7 +73,7 @@ Volume-Based Indicators: prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(ticker=ticker) - chain = prompt | llm.bind_tools(tools) + chain = prompt | safe_bind_tools(llm, tools) result = chain.invoke(state["messages"]) diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index e1f03aa4..8e96986d 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -1,3 +1,4 @@ +from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder import time import json @@ -44,7 +45,7 @@ def create_news_analyst(llm, toolkit): prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(ticker=ticker) - chain = prompt | llm.bind_tools(tools) + chain = prompt | safe_bind_tools(llm, tools) result = chain.invoke(state["messages"]) report = "" diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index d556f73a..5de0291c 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -1,3 +1,4 @@ +from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder import time import json @@ -43,7 +44,7 @@ def create_social_media_analyst(llm, toolkit): prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(ticker=ticker) - chain = prompt | llm.bind_tools(tools) + chain = prompt | safe_bind_tools(llm, tools) result = chain.invoke(state["messages"]) diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 69b8ab8c..9f739f4e 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,24 +1,29 @@ import chromadb from chromadb.config 
import Settings
 from openai import OpenAI
-
+from langchain_ollama import OllamaEmbeddings

 class FinancialSituationMemory:
     def __init__(self, name, config):
-        if config["backend_url"] == "http://localhost:11434/v1":
-            self.embedding = "nomic-embed-text"
+        if config["backend_url"] == "http://localhost:11434":
+            self.embedding = OllamaEmbeddings(
+                model=config["embedding_model"],
+                base_url=config["backend_url"],
+            )
         else:
-            self.embedding = "text-embedding-3-small"
-        self.client = OpenAI(base_url=config["backend_url"])
+            self.embedding = config["embedding_model"]
+            self.client = OpenAI(base_url=config["backend_url"])
         self.chroma_client = chromadb.Client(Settings(allow_reset=True))
         self.situation_collection = self.chroma_client.create_collection(name=name)

     def get_embedding(self, text):
         """Get OpenAI embedding for a text"""
-
-        response = self.client.embeddings.create(
-            model=self.embedding, input=text
-        )
+        # Ollama path stores an OllamaEmbeddings object; OpenAI path a model name.
+        if not isinstance(self.embedding, str):
+            return self.embedding.embed_query(text)
+        response = self.client.embeddings.create(
+            model=self.embedding, input=text
+        )
         return response.data[0].embedding

     def add_situations(self, situations_and_advice):
diff --git a/tradingagents/agents/utils/safe_bind_tools.py b/tradingagents/agents/utils/safe_bind_tools.py
new file mode 100644
index 00000000..e5b5383d
--- /dev/null
+++ b/tradingagents/agents/utils/safe_bind_tools.py
@@ -0,0 +1,82 @@
+"""
+safe_bind_tools.py
+────────────────────────────────────────────────────────
+Attach tool schemas only when the underlying LLM truly
+supports OpenAI-style function calling.
+
+• OpenAI / Anthropic / Google models → always attach
+• ChatOllama models → attach **only**
+  if the Ollama tag contains `"tools": true`
+• All other cases → silently fall
+  back to plain text reasoning
+"""
+
+from __future__ import annotations
+
+import logging
+import shlex
+import subprocess
+from typing import Any, Sequence
+
+from langchain_core.language_models.chat_models import BaseChatModel
+
+
+log = logging.getLogger(__name__)
+
+def _ollama_has_tools_flag(model_name: str) -> bool:
+    """
+    Return True iff `ollama show <model>` contains `"tools": true`.
+    If the command fails (e.g. Windows, sandbox), fall back to False.
+    """
+    try:
+        output = subprocess.check_output(
+            shlex.split(f"ollama show {model_name}"), text=True
+        )
+        return '"tools": true' in output
+    except (OSError, subprocess.SubprocessError) as e:
+        # Missing `ollama` binary → FileNotFoundError (an OSError);
+        # non-zero exit → CalledProcessError (a SubprocessError).
+        log.debug("Could not inspect model %s: %s", model_name, e)
+        return False
+
+def safe_bind_tools(
+    llm: BaseChatModel, tools: Sequence[dict[str, Any]]
+) -> BaseChatModel:
+    """
+    Attach `tools` to an LLM **only** if the model can actually handle them.
+    Otherwise, return the original LLM unchanged.
+
+    Parameters
+    ----------
+    llm
+        Any LangChain chat model instance.
+    tools
+        List of tool schemas compatible with OpenAI function calling.
+
+    Returns
+    -------
+    BaseChatModel
+        Either the bound LLM (when tool calling is available) or the
+        original LLM (fallback).
+ """ + # LLM has no bind_tools method at all → nothing to do + if not hasattr(llm, "bind_tools"): + return llm + + # Special-case ChatOllama: check the `"tools": true` tag first + if isinstance(llm, BaseChatModel) and not _ollama_has_tools_flag(llm.model): + log.info( + "[safe_bind_tools] Model %s lacks tools support -- skipping.", + llm.model, + ) + return llm + + # Generic path: try to bind; fall back gracefully on failure + try: + return llm.bind_tools(tools) + except (NotImplementedError, AttributeError) as e: + log.debug( + "[safe_bind_tools] bind_tools failed for %s: %s – " + "falling back to plain reasoning.", + llm.__class__.__name__, + e, + ) + return llm diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 089e9c24..f4267148 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -13,6 +13,7 @@ DEFAULT_CONFIG = { "deep_think_llm": "o4-mini", "quick_think_llm": "gpt-4o-mini", "backend_url": "https://api.openai.com/v1", + "embedding_model": "text-embedding-3-small", # Debate and discussion settings "max_debate_rounds": 1, "max_risk_discuss_rounds": 1, diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 80a29e53..93ebc7f7 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -9,7 +9,7 @@ from typing import Dict, Any, Tuple, List, Optional from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_google_genai import ChatGoogleGenerativeAI - +from langchain_ollama.chat_models import ChatOllama from langgraph.prebuilt import ToolNode from tradingagents.agents import * @@ -58,7 +58,19 @@ class TradingAgentsGraph: ) # Initialize LLMs - if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter": + if self.config.get("llm_provider") == "ollama": + self.deep_thinking_llm = ChatOllama( + 
model=self.config["deep_think_llm"], + base_url=self.config["backend_url"], + temperature=0.2, + gpu_layers=32, # ← 這裡就能塞 Ollama 特有參數 + ) + self.quick_thinking_llm = ChatOllama( + model=self.config["quick_think_llm"], + base_url=self.config["backend_url"].rstrip("/v1"), # Remove trailing slash + temperature=0.1, + ) + elif self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter": self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]) self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]) elif self.config["llm_provider"].lower() == "anthropic":