YUAN LIN 2025-07-13 23:03:52 +08:00 committed by GitHub
commit 17d305aaf6
12 changed files with 231 additions and 67 deletions

cli/main.py

@ -479,7 +479,8 @@ def get_user_selections():
)
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
selected_embedding_model = select_embedding_agent(selected_llm_provider)
return {
"ticker": selected_ticker,
"analysis_date": analysis_date,
@ -489,6 +490,7 @@ def get_user_selections():
"backend_url": backend_url,
"shallow_thinker": selected_shallow_thinker,
"deep_thinker": selected_deep_thinker,
"embedding_model": selected_embedding_model,
}
@ -741,6 +743,7 @@ def run_analysis():
config["max_risk_discuss_rounds"] = selections["research_depth"]
config["quick_think_llm"] = selections["shallow_thinker"]
config["deep_think_llm"] = selections["deep_thinker"]
config["embedding_model"] = selections["embedding_model"]
config["backend_url"] = selections["backend_url"]
config["llm_provider"] = selections["llm_provider"].lower()

cli/utils.py

@ -1,5 +1,5 @@
import questionary
from typing import List, Optional, Tuple, Dict
from typing import List, Optional, Tuple, Dict, Sequence
from cli.models import AnalystType
@ -10,6 +10,55 @@ ANALYST_ORDER = [
("Fundamentals Analyst", AnalystType.FUNDAMENTALS),
]
def _ask_custom_model(label: str) -> str:
"""Prompt the user to type an arbitrary model name."""
model_name = questionary.text(
f"Enter the exact Ollama model name for {label}:",
validate=lambda x: len(x.strip()) > 0 or "Model name cannot be empty.",
style=questionary.Style([("text", "fg:green")]),
).ask()
if not model_name:
console.print(f"\n[red]No model name provided. Exiting...[/red]")
exit(1)
return model_name
def _select_llm(
provider: str,
label: str,
options: Sequence[Tuple[str, str]],
) -> str:
"""
Generic interactive selector that optionally offers a 'custom' entry
for Ollama users.
"""
opts = list(options)
if provider.lower() == "ollama":
opts.append(("Custom model (type manually)", "__CUSTOM__"))
choice = questionary.select(
f"Select Your [{label}] LLM Engine:",
choices=[questionary.Choice(d, v) for d, v in opts],
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()
if choice is None:
console.print(f"\n[red]No {label.lower()} engine selected. Exiting...[/red]")
exit(1)
if choice == "__CUSTOM__":
# ask for arbitrary name
model_name = _ask_custom_model(label)
if model_name is None:
console.print("\n[red]No model name provided. Exiting...[/red]")
exit(1)
return model_name.strip()
return choice
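For orientation, a minimal usage sketch of the helper above; the provider string and the option list are illustrative assumptions, not part of the commit:
# Hypothetical call: with provider "ollama", _select_llm appends a
# "Custom model (type manually)" entry that falls through to _ask_custom_model.
model = _select_llm(
    "ollama",
    "Quick-Thinking LLM Engine",
    [("llama3.1 local", "llama3.1")],  # assumed example option
)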
def get_ticker() -> str:
"""Prompt the user to enter a ticker symbol."""
@ -154,30 +203,7 @@ def select_shallow_thinking_agent(provider) -> str:
("llama3.2 local", "llama3.2"),
]
}
choice = questionary.select(
"Select Your [Quick-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()
if choice is None:
console.print(
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
)
exit(1)
return choice
return _select_llm(provider, "Quick-Thinking LLM Engine", SHALLOW_AGENT_OPTIONS[provider.lower()])
def select_deep_thinking_agent(provider) -> str:
@ -217,27 +243,22 @@ def select_deep_thinking_agent(provider) -> str:
]
}
choice = questionary.select(
"Select Your [Deep-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()
if choice is None:
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
exit(1)
return choice
return _select_llm(provider, "Deep-Thinking LLM Engine", DEEP_AGENT_OPTIONS[provider.lower()])
def select_embedding_agent(provider) -> str:
"""Select embedding llm engine using an interactive selection."""
# Define embedding llm engine options with their corresponding model names
EMBEDDING_AGENT_OPTIONS = {
"openai": [
("GPT", "text-embedding-3-small"),
],
"ollama": [
],
}
return _select_llm(provider, "Embedding LLM Engine", EMBEDDING_AGENT_OPTIONS[provider.lower()])
def select_llm_provider() -> tuple[str, str]:
"""Select the OpenAI api url using interactive selection."""
@ -247,7 +268,7 @@ def select_llm_provider() -> tuple[str, str]:
("Anthropic", "https://api.anthropic.com/"),
("Google", "https://generativelanguage.googleapis.com/v1"),
("Openrouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"),
("Ollama", "http://localhost:11434"),
]
choice = questionary.select(

main.py

@ -3,10 +3,11 @@ from tradingagents.default_config import DEFAULT_CONFIG
# Create a custom config
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "google" # Use a different model
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
config["deep_think_llm"] = "gemini-2.0-flash" # Use a different model
config["quick_think_llm"] = "gemini-2.0-flash" # Use a different model
config["llm_provider"] = "ollama" # Use a different model
config["backend_url"] = "http://localhost:11434" # Use a different backend
config["deep_think_llm"] = "mixtral:8x7b-instruct-v0.1-q4_K_M" # Use a different model
config["quick_think_llm"] = "phi3:mini" # Use a different model
config["embedding_model"] = "fingpt:7b" # Use a different embedding model
config["max_debate_rounds"] = 1 # Increase debate rounds
config["online_tools"] = True # Increase debate rounds
@ -14,7 +15,7 @@ config["online_tools"] = True # Increase debate rounds
ta = TradingAgentsGraph(debug=True, config=config)
# forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10")
_, decision = ta.propagate("NVDA", "2025-07-07")
print(decision)
# Memorize mistakes and reflect

tradingagents/agents/__init__.py

@ -1,6 +1,7 @@
from .utils.agent_utils import Toolkit, create_msg_delete
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.memory import FinancialSituationMemory
from .utils.safe_bind_tools import safe_bind_tools
from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst
@ -21,6 +22,7 @@ from .trader.trader import create_trader
__all__ = [
"FinancialSituationMemory",
"safe_bind_tools",
"Toolkit",
"AgentState",
"create_msg_delete",

tradingagents/agents/analysts/fundamentals_analyst.py

@ -1,3 +1,4 @@
from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
@ -47,7 +48,7 @@ def create_fundamentals_analyst(llm, toolkit):
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools)
chain = prompt | safe_bind_tools(llm, tools)
result = chain.invoke(state["messages"])

tradingagents/agents/analysts/market_analyst.py

@ -1,3 +1,4 @@
from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
@ -72,7 +73,7 @@ Volume-Based Indicators:
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools)
chain = prompt | safe_bind_tools(llm, tools)
result = chain.invoke(state["messages"])

tradingagents/agents/analysts/news_analyst.py

@ -1,3 +1,4 @@
from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
@ -44,7 +45,7 @@ def create_news_analyst(llm, toolkit):
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools)
chain = prompt | safe_bind_tools(llm, tools)
result = chain.invoke(state["messages"])
report = ""

tradingagents/agents/analysts/social_media_analyst.py

@ -1,3 +1,4 @@
from tradingagents.agents.utils.safe_bind_tools import safe_bind_tools
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
@ -43,7 +44,7 @@ def create_social_media_analyst(llm, toolkit):
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(ticker=ticker)
chain = prompt | llm.bind_tools(tools)
chain = prompt | safe_bind_tools(llm, tools)
result = chain.invoke(state["messages"])

tradingagents/agents/utils/memory.py

@ -1,24 +1,29 @@
import chromadb
from chromadb.config import Settings
from openai import OpenAI
from langchain_ollama import OllamaEmbeddings
class FinancialSituationMemory:
def __init__(self, name, config):
if config["backend_url"] == "http://localhost:11434/v1":
self.embedding = "nomic-embed-text"
if config["backend_url"] == "http://localhost:11434":
self.embedding = OllamaEmbeddings(
model=config["embedding_model"],
base_url=config["backend_url"], # Remove trailing slash
)
else:
self.embedding = "text-embedding-3-small"
self.client = OpenAI(base_url=config["backend_url"])
self.embedding = config["embedding_model"]
self.client = OpenAI(base_url=config["backend_url"])
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name)
def get_embedding(self, text):
"""Get an embedding for a text via the OpenAI API or a local Ollama model."""
response = self.client.embeddings.create(
model=self.embedding, input=text
)
try:
response = self.client.embeddings.create(
model=self.embedding, input=text
)
except AttributeError:
# Ollama path: self.client was never set, so fall back to the
# OllamaEmbeddings instance stored in self.embedding.
return self.embedding.embed_query(text)
return response.data[0].embedding
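A minimal construction sketch of the two paths above; the memory name and config values are illustrative assumptions, not part of the commit:
ollama_config = {
    "backend_url": "http://localhost:11434",  # matches the Ollama branch
    "embedding_model": "nomic-embed-text",    # assumed local embedding model
}
memory = FinancialSituationMemory("bull_memory", ollama_config)
vector = memory.get_embedding("Rates rose and tech sold off.")  # embed_query path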
def add_situations(self, situations_and_advice):

tradingagents/agents/utils/safe_bind_tools.py

@ -0,0 +1,115 @@
"""
safe_bind_tools.py
Attach tool schemas only when the underlying LLM truly
supports OpenAI-style function calling.
OpenAI / Anthropic / Google models always attach
ChatOllama models attach **only**
if the Ollama tag contains `"tools": true`
All other cases silently fall
back to plain text reasoning
"""
from __future__ import annotations
import logging
import shlex
import subprocess
from typing import Any, Sequence
from langchain_core.language_models.chat_models import BaseChatModel
log = logging.getLogger(__name__)
def _ollama_has_tools_flag(model_name: str) -> bool:
"""
Return True iff `ollama show <model_name>` contains tools capability.
If the command fails (e.g. Windows, sandbox), fall back to False.
"""
try:
output = subprocess.check_output(
shlex.split(f"ollama show {model_name}"), text=True
)
# Check for multiple possible tools indicators
tools_indicators = [
'"tools": true', # Old format
'tools ', # New format in Capabilities section
'tools\n', # Alternative new format
'tools\t', # Tab-separated format
]
# Also check if we're in the Capabilities section
lines = output.split('\n')
in_capabilities = False
for line in lines:
line_stripped = line.strip().lower()
if 'capabilities' in line_stripped:
in_capabilities = True
elif in_capabilities and line_stripped and not line.startswith(' '):
# We've left the capabilities section
in_capabilities = False
elif in_capabilities and 'tools' in line_stripped:
log.debug("Found tools capability for model %s", model_name)
return True
# Fallback to checking for any tools indicator
for indicator in tools_indicators:
if indicator in output:
log.debug("Found tools indicator '%s' for model %s", indicator, model_name)
return True
log.debug("No tools capability found for model %s", model_name)
return False
except (NotImplementedError, AttributeError, FileNotFoundError, subprocess.CalledProcessError) as e:
log.debug("Could not inspect model %s: %s", model_name, e)
return False
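For reference, a hedged approximation of the newer `ollama show` layout that the Capabilities scan above targets (exact wording varies by Ollama version; the model name is an assumption):
# $ ollama show llama3.1
#   ...
#   Capabilities
#     completion
#     tools          <- matched by the in_capabilities branch above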
def safe_bind_tools(
llm: BaseChatModel, tools: Sequence[dict[str, Any]]
) -> BaseChatModel:
"""
Attach `tools` to an LLM **only** if the model can actually handle them.
Otherwise, return the original LLM unchanged.
Parameters
----------
llm
Any LangChain chat model instance.
tools
List of tool schemas compatible with OpenAI function calling.
Returns
-------
BaseChatModel
Either the bound LLM (when tool calling is available) or the
original LLM (fallback).
"""
# LLM has no bind_tools method at all → nothing to do
if not hasattr(llm, "bind_tools"):
return llm
# Special-case ChatOllama: check for tools capability
if llm.__class__.__name__ == 'ChatOllama':
# Get model name from different possible attributes
model_name = getattr(llm, 'model', None) or getattr(llm, 'model_name', None)
if model_name and not _ollama_has_tools_flag(model_name):
log.info(
"[safe_bind_tools] Model %s lacks tools support -- skipping.",
model_name,
)
return llm
# Generic path: try to bind; fall back gracefully on failure
try:
return llm.bind_tools(tools)
except (NotImplementedError, AttributeError) as e:
log.debug(
"[safe_bind_tools] bind_tools failed for %s: %s "
"falling back to plain reasoning.",
llm.__class__.__name__,
e,
)
return llm
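Putting it together, a minimal usage sketch mirroring the analyst changes above; the model name and the tool schema are assumptions for illustration:
from langchain_ollama.chat_models import ChatOllama

llm = ChatOllama(model="phi3:mini", base_url="http://localhost:11434")
price_tool = {
    "type": "function",
    "function": {
        "name": "get_price",  # hypothetical tool
        "description": "Fetch the latest close price for a ticker.",
        "parameters": {
            "type": "object",
            "properties": {"ticker": {"type": "string"}},
            "required": ["ticker"],
        },
    },
}
# Binds only if `ollama show phi3:mini` reports the tools capability;
# otherwise the original llm is returned and the chain stays plain-text.
llm_with_tools = safe_bind_tools(llm, [price_tool])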

tradingagents/default_config.py

@ -13,6 +13,7 @@ DEFAULT_CONFIG = {
"deep_think_llm": "o4-mini",
"quick_think_llm": "gpt-4o-mini",
"backend_url": "https://api.openai.com/v1",
"embedding_model": "text-embedding-3-small",
# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,

tradingagents/graph/trading_graph.py

@ -9,7 +9,7 @@ from typing import Dict, Any, Tuple, List, Optional
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama.chat_models import ChatOllama
from langgraph.prebuilt import ToolNode
from tradingagents.agents import *
@ -58,7 +58,19 @@ class TradingAgentsGraph:
)
# Initialize LLMs
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
if self.config.get("llm_provider") == "ollama":
self.deep_thinking_llm = ChatOllama(
model=self.config["deep_think_llm"],
base_url=self.config["backend_url"],
temperature=0.2,
num_gpu=32, # Ollama-specific parameters (e.g. how many layers to offload to GPU) can be set here
)
self.quick_thinking_llm = ChatOllama(
model=self.config["quick_think_llm"],
base_url=self.config["backend_url"], # backend_url no longer carries a /v1 suffix
temperature=0.1,
)
elif self.config["llm_provider"].lower() in ("openai", "openrouter"):
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "anthropic":