update to use BaseChatModel instead of ChatOpenAI to avoid importing Openai when using other LLM provider

This commit is contained in:
Jeffrey Chu 2025-10-23 19:15:34 +08:00
parent 44a5cb133a
commit 65ef62de17
5 changed files with 51 additions and 57 deletions

View File

@ -1,10 +1,7 @@
from typing import Annotated, Sequence from typing import Annotated
from datetime import date, timedelta, datetime
from typing_extensions import TypedDict, Optional from langgraph.graph import MessagesState
from langchain_openai import ChatOpenAI from typing_extensions import TypedDict
from tradingagents.agents import *
from langgraph.prebuilt import ToolNode
from langgraph.graph import END, StateGraph, START, MessagesState
# Researcher team state # Researcher team state

View File

@ -1,13 +1,14 @@
# TradingAgents/graph/reflection.py # TradingAgents/graph/reflection.py
from typing import Dict, Any from typing import Dict, Any
from langchain_openai import ChatOpenAI
from langchain_core.language_models.chat_models import BaseChatModel
class Reflector: class Reflector:
"""Handles reflection on decisions and updating memory.""" """Handles reflection on decisions and updating memory."""
def __init__(self, quick_thinking_llm: ChatOpenAI): def __init__(self, quick_thinking_llm: BaseChatModel):
"""Initialize the reflector with an LLM.""" """Initialize the reflector with an LLM."""
self.quick_thinking_llm = quick_thinking_llm self.quick_thinking_llm = quick_thinking_llm
self.reflection_system_prompt = self._get_reflection_prompt() self.reflection_system_prompt = self._get_reflection_prompt()
@ -56,7 +57,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}" return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
def _reflect_on_component( def _reflect_on_component(
self, component_type: str, report: str, situation: str, returns_losses self, component_type: str, report: str, situation: str, returns_losses
) -> str: ) -> str:
"""Generate reflection for a component.""" """Generate reflection for a component."""
messages = [ messages = [

View File

@ -1,13 +1,13 @@
# TradingAgents/graph/setup.py # TradingAgents/graph/setup.py
from typing import Dict, Any from typing import Dict, Any
from langchain_openai import ChatOpenAI
from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.graph import END, StateGraph, START from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode from langgraph.prebuilt import ToolNode
from tradingagents.agents import * from tradingagents.agents import *
from tradingagents.agents.utils.agent_states import AgentState from tradingagents.agents.utils.agent_states import AgentState
from .conditional_logic import ConditionalLogic from .conditional_logic import ConditionalLogic
@ -15,17 +15,17 @@ class GraphSetup:
"""Handles the setup and configuration of the agent graph.""" """Handles the setup and configuration of the agent graph."""
def __init__( def __init__(
self, self,
quick_thinking_llm: ChatOpenAI, quick_thinking_llm: BaseChatModel,
deep_thinking_llm: ChatOpenAI, deep_thinking_llm: BaseChatModel,
tool_nodes: Dict[str, ToolNode], tool_nodes: Dict[str, ToolNode],
bull_memory, bull_memory,
bear_memory, bear_memory,
trader_memory, trader_memory,
invest_judge_memory, invest_judge_memory,
risk_manager_memory, risk_manager_memory,
conditional_logic: ConditionalLogic, conditional_logic: ConditionalLogic,
config: Dict[str, Any], config: Dict[str, Any],
): ):
"""Initialize with required components.""" """Initialize with required components."""
self.quick_thinking_llm = quick_thinking_llm self.quick_thinking_llm = quick_thinking_llm
@ -40,7 +40,7 @@ class GraphSetup:
self.config = config self.config = config
def setup_graph( def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"] self, selected_analysts=["market", "social", "news", "fundamentals"]
): ):
"""Set up and compile the agent workflow graph. """Set up and compile the agent workflow graph.
@ -149,7 +149,7 @@ class GraphSetup:
# Connect to next analyst or to Bull Researcher if this is the last analyst # Connect to next analyst or to Bull Researcher if this is the last analyst
if i < len(selected_analysts) - 1: if i < len(selected_analysts) - 1:
next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst" next_analyst = f"{selected_analysts[i + 1].capitalize()} Analyst"
workflow.add_edge(current_clear, next_analyst) workflow.add_edge(current_clear, next_analyst)
else: else:
workflow.add_edge(current_clear, "Bull Researcher") workflow.add_edge(current_clear, "Bull Researcher")

View File

@ -1,12 +1,13 @@
# TradingAgents/graph/signal_processing.py # TradingAgents/graph/signal_processing.py
from langchain_openai import ChatOpenAI
from langchain_core.language_models.chat_models import BaseChatModel
class SignalProcessor: class SignalProcessor:
"""Processes trading signals to extract actionable decisions.""" """Processes trading signals to extract actionable decisions."""
def __init__(self, quick_thinking_llm: ChatOpenAI): def __init__(self, quick_thinking_llm: BaseChatModel):
"""Initialize with an LLM for processing.""" """Initialize with an LLM for processing."""
self.quick_thinking_llm = quick_thinking_llm self.quick_thinking_llm = quick_thinking_llm

View File

@ -1,28 +1,16 @@
# TradingAgents/graph/trading_graph.py # TradingAgents/graph/trading_graph.py
import json
import os import os
from pathlib import Path from pathlib import Path
import json from typing import Dict, Any
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_xai import ChatXAI from langchain_xai import ChatXAI
from langgraph.prebuilt import ToolNode from langgraph.prebuilt import ToolNode
from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
RiskDebateState,
)
from tradingagents.dataflows.config import set_config
# Import the new abstract tool methods from agent_utils # Import the new abstract tool methods from agent_utils
from tradingagents.agents.utils.agent_utils import ( from tradingagents.agents.utils.agent_utils import (
get_stock_data, get_stock_data,
@ -36,11 +24,13 @@ from tradingagents.agents.utils.agent_utils import (
get_insider_transactions, get_insider_transactions,
get_global_news get_global_news
) )
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.dataflows.config import set_config
from tradingagents.default_config import DEFAULT_CONFIG
from .conditional_logic import ConditionalLogic from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator from .propagation import Propagator
from .reflection import Reflector from .reflection import Reflector
from .setup import GraphSetup
from .signal_processing import SignalProcessor from .signal_processing import SignalProcessor
@ -48,10 +38,10 @@ class TradingAgentsGraph:
"""Main class that orchestrates the trading agents framework.""" """Main class that orchestrates the trading agents framework."""
def __init__( def __init__(
self, self,
selected_analysts=["market", "social", "news", "fundamentals"], selected_analysts=["market", "social", "news", "fundamentals"],
debug=False, debug=False,
config: Dict[str, Any] = None, config: Dict[str, Any] = None,
): ):
"""Initialize the trading agents graph and components. """Initialize the trading agents graph and components.
@ -62,7 +52,7 @@ class TradingAgentsGraph:
""" """
self.debug = debug self.debug = debug
self.config = config or DEFAULT_CONFIG self.config = config or DEFAULT_CONFIG
# Update the interface's config # Update the interface's config
set_config(self.config) set_config(self.config)
@ -73,12 +63,17 @@ class TradingAgentsGraph:
) )
# Initialize LLMs # Initialize LLMs
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter": if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config[
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]) "llm_provider"] == "openrouter":
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]) self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"],
base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"],
base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "anthropic": elif self.config["llm_provider"].lower() == "anthropic":
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]) self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"],
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]) base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"],
base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "google": elif self.config["llm_provider"].lower() == "google":
self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"]) self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"]) self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
@ -87,7 +82,7 @@ class TradingAgentsGraph:
self.quick_thinking_llm = ChatXAI(model=self.config["quick_think_llm"]) self.quick_thinking_llm = ChatXAI(model=self.config["quick_think_llm"])
else: else:
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}") raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
# Initialize memories # Initialize memories
self.bull_memory = FinancialSituationMemory("bull_memory", self.config) self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
self.bear_memory = FinancialSituationMemory("bear_memory", self.config) self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
@ -234,8 +229,8 @@ class TradingAgentsGraph:
directory.mkdir(parents=True, exist_ok=True) directory.mkdir(parents=True, exist_ok=True)
with open( with open(
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json", f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json",
"w", "w",
) as f: ) as f:
json.dump(self.log_states_dict, f, indent=4) json.dump(self.log_states_dict, f, indent=4)