diff --git a/cli/main.py b/cli/main.py index 64616ee1..7fd21d02 100644 --- a/cli/main.py +++ b/cli/main.py @@ -479,7 +479,8 @@ def get_user_selections(): ) selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider) - + selected_embedding_model = select_embedding_agent(selected_llm_provider) + return { "ticker": selected_ticker, "analysis_date": analysis_date, @@ -489,6 +490,7 @@ def get_user_selections(): "backend_url": backend_url, "shallow_thinker": selected_shallow_thinker, "deep_thinker": selected_deep_thinker, + "embedding_model": selected_embedding_model, } @@ -741,6 +743,7 @@ def run_analysis(): config["max_risk_discuss_rounds"] = selections["research_depth"] config["quick_think_llm"] = selections["shallow_thinker"] config["deep_think_llm"] = selections["deep_thinker"] + config["embedding_model"] = selections["embedding_model"] config["backend_url"] = selections["backend_url"] config["llm_provider"] = selections["llm_provider"].lower() diff --git a/cli/utils.py b/cli/utils.py index 7b9682a6..538b6a42 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,5 +1,5 @@ import questionary -from typing import List, Optional, Tuple, Dict +from typing import List, Optional, Tuple, Dict, Sequence from cli.models import AnalystType @@ -10,6 +10,55 @@ ANALYST_ORDER = [ ("Fundamentals Analyst", AnalystType.FUNDAMENTALS), ] +def _ask_custom_model(label: str) -> str: + """Prompt the user to type an arbitrary model name.""" + model_name = questionary.text( + f"Enter the exact Ollama model name for {label}:", + validate=lambda x: len(x.strip()) > 0 or "Model name cannot be empty.", + style=questionary.Style([("text", "fg:green")]), + ).ask() + if not model_name: + console.print(f"\n[red]No model name provided. 
Exiting...[/red]") + exit(1) + return model_name + +def _select_llm( + provider: str, + label: str, + options: Sequence[Tuple[str, str]], +) -> str: + """ + Generic interactive selector that optionally offers a 'custom' entry + for Ollama users. + """ + opts = list(options) + if provider.lower() == "ollama": + opts.append(("Custom model (type manually)", "__CUSTOM__")) + + choice = questionary.select( + f"Select Your [{label}] LLM Engine:", + choices=[questionary.Choice(d, v) for d, v in opts], + style=questionary.Style( + [ + ("selected", "fg:magenta noinherit"), + ("highlighted", "fg:magenta noinherit"), + ("pointer", "fg:magenta noinherit"), + ] + ), + ).ask() + + if choice is None: + console.print(f"\n[red]No {label.lower()} engine selected. Exiting...[/red]") + exit(1) + + if choice == "__CUSTOM__": + # ask for arbitrary name + model_name = _ask_custom_model(label) + if model_name is None: + console.print("\n[red]No model name provided. Exiting...[/red]") + exit(1) + return model_name.strip() + return choice def get_ticker() -> str: """Prompt the user to enter a ticker symbol.""" @@ -154,30 +203,7 @@ def select_shallow_thinking_agent(provider) -> str: ("llama3.2 local", "llama3.2"), ] } - - choice = questionary.select( - "Select Your [Quick-Thinking LLM Engine]:", - choices=[ - questionary.Choice(display, value=value) - for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()] - ], - instruction="\n- Use arrow keys to navigate\n- Press Enter to select", - style=questionary.Style( - [ - ("selected", "fg:magenta noinherit"), - ("highlighted", "fg:magenta noinherit"), - ("pointer", "fg:magenta noinherit"), - ] - ), - ).ask() - - if choice is None: - console.print( - "\n[red]No shallow thinking llm engine selected. 
Exiting...[/red]" - ) - exit(1) - - return choice + return _select_llm(provider, "Quick-Thinking LLM Engine", SHALLOW_AGENT_OPTIONS[provider.lower()]) def select_deep_thinking_agent(provider) -> str: @@ -217,27 +243,22 @@ def select_deep_thinking_agent(provider) -> str: ] } - choice = questionary.select( - "Select Your [Deep-Thinking LLM Engine]:", - choices=[ - questionary.Choice(display, value=value) - for display, value in DEEP_AGENT_OPTIONS[provider.lower()] + return _select_llm(provider, "Deep-Thinking LLM Engine", DEEP_AGENT_OPTIONS[provider.lower()]) + +def select_embedding_agent(provider) -> str: + """Select embedding llm engine using an interactive selection.""" + + # Define deep thinking llm engine options with their corresponding model names + EMBEDDING_AGENT_OPTIONS = { + "openai": [ + ("GPT", "text-embedding-3-small"), ], - instruction="\n- Use arrow keys to navigate\n- Press Enter to select", - style=questionary.Style( - [ - ("selected", "fg:magenta noinherit"), - ("highlighted", "fg:magenta noinherit"), - ("pointer", "fg:magenta noinherit"), - ] - ), - ).ask() - - if choice is None: - console.print("\n[red]No deep thinking llm engine selected. 
Exiting...[/red]") - exit(1) - - return choice + "ollama": [ + + ] + } + + return _select_llm(provider, "Embedding LLM Engine", EMBEDDING_AGENT_OPTIONS[provider.lower()]) def select_llm_provider() -> tuple[str, str]: """Select the OpenAI api url using interactive selection.""" @@ -247,7 +268,7 @@ def select_llm_provider() -> tuple[str, str]: ("Anthropic", "https://api.anthropic.com/"), ("Google", "https://generativelanguage.googleapis.com/v1"), ("Openrouter", "https://openrouter.ai/api/v1"), - ("Ollama", "http://localhost:11434/v1"), + ("Ollama", "http://localhost:11434"), ] choice = questionary.select( diff --git a/main.py b/main.py index 6c8ae3d9..84c11bfe 100644 --- a/main.py +++ b/main.py @@ -3,10 +3,11 @@ from tradingagents.default_config import DEFAULT_CONFIG # Create a custom config config = DEFAULT_CONFIG.copy() -config["llm_provider"] = "google" # Use a different model -config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend -config["deep_think_llm"] = "gemini-2.0-flash" # Use a different model -config["quick_think_llm"] = "gemini-2.0-flash" # Use a different model +config["llm_provider"] = "ollama" # Use a different model +config["backend_url"] = "http://localhost:11434" # Use a different backend +config["deep_think_llm"] = "mixtral:8x7b-instruct-v0.1-q4_K_M" # Use a different model +config["quick_think_llm"] = "phi3:mini" # Use a different model +config["embedding_model"] = "fingpt:7b" # Use a different embedding model config["max_debate_rounds"] = 1 # Increase debate rounds config["online_tools"] = True # Increase debate rounds @@ -14,7 +15,7 @@ config["online_tools"] = True # Increase debate rounds ta = TradingAgentsGraph(debug=True, config=config) # forward propagate -_, decision = ta.propagate("NVDA", "2024-05-10") +_, decision = ta.propagate("NVDA", "2025-07-07") print(decision) # Memorize mistakes and reflect diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index 
6f507651..0755cf4c 100644 --- a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -1,6 +1,7 @@ from .utils.agent_utils import Toolkit, create_msg_delete from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState from .utils.memory import FinancialSituationMemory +from .utils.safe_bind_tools import safe_bind_tools from .analysts.fundamentals_analyst import create_fundamentals_analyst from .analysts.market_analyst import create_market_analyst @@ -21,6 +22,7 @@ from .trader.trader import create_trader __all__ = [ "FinancialSituationMemory", + "safe_bind_tools", "Toolkit", "AgentState", "create_msg_delete", diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index 716d4de1..9da91cd2 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -1,3 +1,4 @@ +from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder import time import json @@ -47,7 +48,7 @@ def create_fundamentals_analyst(llm, toolkit): prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(ticker=ticker) - chain = prompt | llm.bind_tools(tools) + chain = prompt | safe_bind_tools(llm, tools) result = chain.invoke(state["messages"]) diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index 41ee944b..0670b2b7 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -1,3 +1,4 @@ +from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder import time import json @@ -72,7 +73,7 @@ Volume-Based Indicators: prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(ticker=ticker) - chain = prompt | 
llm.bind_tools(tools) + chain = prompt | safe_bind_tools(llm, tools) result = chain.invoke(state["messages"]) diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index e1f03aa4..8e96986d 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -1,3 +1,4 @@ +from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder import time import json @@ -44,7 +45,7 @@ def create_news_analyst(llm, toolkit): prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(ticker=ticker) - chain = prompt | llm.bind_tools(tools) + chain = prompt | safe_bind_tools(llm, tools) result = chain.invoke(state["messages"]) report = "" diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index d556f73a..5de0291c 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -1,3 +1,4 @@ +from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder import time import json @@ -43,7 +44,7 @@ def create_social_media_analyst(llm, toolkit): prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(ticker=ticker) - chain = prompt | llm.bind_tools(tools) + chain = prompt | safe_bind_tools(llm, tools) result = chain.invoke(state["messages"]) diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 69b8ab8c..9f739f4e 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,24 +1,29 @@ import chromadb from chromadb.config import Settings from openai import OpenAI - +from langchain_ollama import OllamaEmbeddings class FinancialSituationMemory: def __init__(self, name, config): - if 
config["backend_url"] == "http://localhost:11434/v1": - self.embedding = "nomic-embed-text" + if config["backend_url"] == "http://localhost:11434": + self.embedding = OllamaEmbeddings( + model=config["embedding_model"], + base_url=config["backend_url"], # Ollama's native embeddings API lives at the bare host URL (no /v1 path) + ) else: - self.embedding = "text-embedding-3-small" - self.client = OpenAI(base_url=config["backend_url"]) + self.embedding = config["embedding_model"] + self.client = OpenAI(base_url=config["backend_url"]) self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.situation_collection = self.chroma_client.create_collection(name=name) def get_embedding(self, text): """Get OpenAI embedding for a text""" - - response = self.client.embeddings.create( - model=self.embedding, input=text - ) + try: + response = self.client.embeddings.create( + model=self.embedding, input=text + ) + except AttributeError: + return self.embedding.embed_query(text) return response.data[0].embedding def add_situations(self, situations_and_advice): diff --git a/tradingagents/agents/utils/safe_bind_tools.py b/tradingagents/agents/utils/safe_bind_tools.py new file mode 100644 index 00000000..2db6c526 --- /dev/null +++ b/tradingagents/agents/utils/safe_bind_tools.py @@ -0,0 +1,115 @@ +""" +safe_bind_tools.py +──────────────────────────────────────────────────────── +Attach tool schemas only when the underlying LLM truly +supports OpenAI-style function calling. 
+ +• OpenAI / Anthropic / Google models → always attach +• ChatOllama models → attach **only** + if the Ollama tag contains `"tools": true` +• All other cases → silently fall + back to plain text reasoning +""" + +from __future__ import annotations + +import logging +import shlex +import subprocess +from typing import Any, Sequence + +from langchain_core.language_models.chat_models import BaseChatModel + + +log = logging.getLogger(__name__) + +def _ollama_has_tools_flag(model_name: str) -> bool: + """ + Return True iff `ollama show ` contains tools capability. + If the command fails (e.g. Windows, sandbox), fall back to False. + """ + try: + output = subprocess.check_output( + shlex.split(f"ollama show {model_name}"), text=True + ) + # Check for multiple possible tools indicators + tools_indicators = [ + '"tools": true', # Old format + 'tools ', # New format in Capabilities section + 'tools\n', # Alternative new format + 'tools\t', # Tab-separated format + ] + + # Also check if we're in the Capabilities section + lines = output.split('\n') + in_capabilities = False + for line in lines: + line_stripped = line.strip().lower() + if 'capabilities' in line_stripped: + in_capabilities = True + elif in_capabilities and line_stripped and not line.startswith(' '): + # We've left the capabilities section + in_capabilities = False + elif in_capabilities and 'tools' in line_stripped: + log.debug("Found tools capability for model %s", model_name) + return True + + # Fallback to checking for any tools indicator + for indicator in tools_indicators: + if indicator in output: + log.debug("Found tools indicator '%s' for model %s", indicator, model_name) + return True + + log.debug("No tools capability found for model %s", model_name) + return False + except (NotImplementedError, AttributeError, subprocess.CalledProcessError) as e: + log.debug("Could not inspect model %s: %s", model_name, e) + return False + +def safe_bind_tools( + llm: BaseChatModel, tools: Sequence[dict[str, Any]] 
+) -> BaseChatModel: + """ + Attach `tools` to an LLM **only** if the model can actually handle them. + Otherwise, return the original LLM unchanged. + + Parameters + ---------- + llm + Any LangChain chat model instance. + tools + List of tool schemas compatible with OpenAI function calling. + + Returns + ------- + BaseChatModel + Either the bound LLM (when tool calling is available) or the + original LLM (fallback). + """ + # LLM has no bind_tools method at all → nothing to do + if not hasattr(llm, "bind_tools"): + return llm + + # Special-case ChatOllama: check for tools capability + if llm.__class__.__name__ == 'ChatOllama': + # Get model name from different possible attributes + model_name = getattr(llm, 'model', None) or getattr(llm, 'model_name', None) + + if model_name and not _ollama_has_tools_flag(model_name): + log.info( + "[safe_bind_tools] Model %s lacks tools support -- skipping.", + model_name, + ) + return llm + + # Generic path: try to bind; fall back gracefully on failure + try: + return llm.bind_tools(tools) + except (NotImplementedError, AttributeError) as e: + log.debug( + "[safe_bind_tools] bind_tools failed for %s: %s – " + "falling back to plain reasoning.", + llm.__class__.__name__, + e, + ) + return llm diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 089e9c24..f4267148 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -13,6 +13,7 @@ DEFAULT_CONFIG = { "deep_think_llm": "o4-mini", "quick_think_llm": "gpt-4o-mini", "backend_url": "https://api.openai.com/v1", + "embedding_model": "text-embedding-3-small", # Debate and discussion settings "max_debate_rounds": 1, "max_risk_discuss_rounds": 1, diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 80a29e53..93ebc7f7 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -9,7 +9,7 @@ from typing import Dict, Any, Tuple, List, Optional from 
langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_google_genai import ChatGoogleGenerativeAI - +from langchain_ollama.chat_models import ChatOllama from langgraph.prebuilt import ToolNode from tradingagents.agents import * @@ -58,7 +58,19 @@ class TradingAgentsGraph: ) # Initialize LLMs - if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter": + if self.config.get("llm_provider", "").lower() == "ollama": + self.deep_thinking_llm = ChatOllama( + model=self.config["deep_think_llm"], + base_url=self.config["backend_url"].removesuffix("/v1"), # ChatOllama expects the bare host URL, not the /v1 OpenAI-compat path + temperature=0.2, + num_gpu=32, # Ollama-specific option (layers offloaded to GPU); note "gpu_layers" is not a ChatOllama field + ) + self.quick_thinking_llm = ChatOllama( + model=self.config["quick_think_llm"], + base_url=self.config["backend_url"].removesuffix("/v1"), # removesuffix, not rstrip: rstrip("/v1") strips characters, not the suffix + temperature=0.1, + ) + elif self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter": self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]) self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]) elif self.config["llm_provider"].lower() == "anthropic":