update to use BaseChatModel instead of ChatOpenAI to avoid importing Openai when using other LLM provider

This commit is contained in:
Jeffrey Chu 2025-10-23 19:15:34 +08:00
parent 44a5cb133a
commit 65ef62de17
5 changed files with 51 additions and 57 deletions

View File

@ -1,10 +1,7 @@
from typing import Annotated, Sequence
from datetime import date, timedelta, datetime
from typing_extensions import TypedDict, Optional
from langchain_openai import ChatOpenAI
from tradingagents.agents import *
from langgraph.prebuilt import ToolNode
from langgraph.graph import END, StateGraph, START, MessagesState
from typing import Annotated
from langgraph.graph import MessagesState
from typing_extensions import TypedDict
# Researcher team state

View File

@ -1,13 +1,14 @@
# TradingAgents/graph/reflection.py
from typing import Dict, Any
from langchain_openai import ChatOpenAI
from langchain_core.language_models.chat_models import BaseChatModel
class Reflector:
"""Handles reflection on decisions and updating memory."""
def __init__(self, quick_thinking_llm: ChatOpenAI):
def __init__(self, quick_thinking_llm: BaseChatModel):
"""Initialize the reflector with an LLM."""
self.quick_thinking_llm = quick_thinking_llm
self.reflection_system_prompt = self._get_reflection_prompt()
@ -56,7 +57,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
def _reflect_on_component(
self, component_type: str, report: str, situation: str, returns_losses
self, component_type: str, report: str, situation: str, returns_losses
) -> str:
"""Generate reflection for a component."""
messages = [

View File

@ -1,13 +1,13 @@
# TradingAgents/graph/setup.py
from typing import Dict, Any
from langchain_openai import ChatOpenAI
from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode
from tradingagents.agents import *
from tradingagents.agents.utils.agent_states import AgentState
from .conditional_logic import ConditionalLogic
@ -15,17 +15,17 @@ class GraphSetup:
"""Handles the setup and configuration of the agent graph."""
def __init__(
self,
quick_thinking_llm: ChatOpenAI,
deep_thinking_llm: ChatOpenAI,
tool_nodes: Dict[str, ToolNode],
bull_memory,
bear_memory,
trader_memory,
invest_judge_memory,
risk_manager_memory,
conditional_logic: ConditionalLogic,
config: Dict[str, Any],
self,
quick_thinking_llm: BaseChatModel,
deep_thinking_llm: BaseChatModel,
tool_nodes: Dict[str, ToolNode],
bull_memory,
bear_memory,
trader_memory,
invest_judge_memory,
risk_manager_memory,
conditional_logic: ConditionalLogic,
config: Dict[str, Any],
):
"""Initialize with required components."""
self.quick_thinking_llm = quick_thinking_llm
@ -40,7 +40,7 @@ class GraphSetup:
self.config = config
def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"]
self, selected_analysts=["market", "social", "news", "fundamentals"]
):
"""Set up and compile the agent workflow graph.
@ -149,7 +149,7 @@ class GraphSetup:
# Connect to next analyst or to Bull Researcher if this is the last analyst
if i < len(selected_analysts) - 1:
next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
next_analyst = f"{selected_analysts[i + 1].capitalize()} Analyst"
workflow.add_edge(current_clear, next_analyst)
else:
workflow.add_edge(current_clear, "Bull Researcher")

View File

@ -1,12 +1,13 @@
# TradingAgents/graph/signal_processing.py
from langchain_openai import ChatOpenAI
from langchain_core.language_models.chat_models import BaseChatModel
class SignalProcessor:
"""Processes trading signals to extract actionable decisions."""
def __init__(self, quick_thinking_llm: ChatOpenAI):
def __init__(self, quick_thinking_llm: BaseChatModel):
"""Initialize with an LLM for processing."""
self.quick_thinking_llm = quick_thinking_llm

View File

@ -1,28 +1,16 @@
# TradingAgents/graph/trading_graph.py
import json
import os
from pathlib import Path
import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
from typing import Dict, Any
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_xai import ChatXAI
from langgraph.prebuilt import ToolNode
from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
RiskDebateState,
)
from tradingagents.dataflows.config import set_config
# Import the new abstract tool methods from agent_utils
from tradingagents.agents.utils.agent_utils import (
get_stock_data,
@ -36,11 +24,13 @@ from tradingagents.agents.utils.agent_utils import (
get_insider_transactions,
get_global_news
)
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.dataflows.config import set_config
from tradingagents.default_config import DEFAULT_CONFIG
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .setup import GraphSetup
from .signal_processing import SignalProcessor
@ -48,10 +38,10 @@ class TradingAgentsGraph:
"""Main class that orchestrates the trading agents framework."""
def __init__(
self,
selected_analysts=["market", "social", "news", "fundamentals"],
debug=False,
config: Dict[str, Any] = None,
self,
selected_analysts=["market", "social", "news", "fundamentals"],
debug=False,
config: Dict[str, Any] = None,
):
"""Initialize the trading agents graph and components.
@ -62,7 +52,7 @@ class TradingAgentsGraph:
"""
self.debug = debug
self.config = config or DEFAULT_CONFIG
# Update the interface's config
set_config(self.config)
@ -73,12 +63,17 @@ class TradingAgentsGraph:
)
# Initialize LLMs
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config[
"llm_provider"] == "openrouter":
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"],
base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"],
base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "anthropic":
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"],
base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"],
base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "google":
self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
@ -87,7 +82,7 @@ class TradingAgentsGraph:
self.quick_thinking_llm = ChatXAI(model=self.config["quick_think_llm"])
else:
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
# Initialize memories
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
@ -234,8 +229,8 @@ class TradingAgentsGraph:
directory.mkdir(parents=True, exist_ok=True)
with open(
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json",
"w",
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json",
"w",
) as f:
json.dump(self.log_states_dict, f, indent=4)