update to use BaseChatModel instead of ChatOpenAI to avoid importing Openai when using other LLM provider
This commit is contained in:
parent
44a5cb133a
commit
65ef62de17
|
|
@ -1,10 +1,7 @@
|
||||||
from typing import Annotated, Sequence
|
from typing import Annotated
|
||||||
from datetime import date, timedelta, datetime
|
|
||||||
from typing_extensions import TypedDict, Optional
|
from langgraph.graph import MessagesState
|
||||||
from langchain_openai import ChatOpenAI
|
from typing_extensions import TypedDict
|
||||||
from tradingagents.agents import *
|
|
||||||
from langgraph.prebuilt import ToolNode
|
|
||||||
from langgraph.graph import END, StateGraph, START, MessagesState
|
|
||||||
|
|
||||||
|
|
||||||
# Researcher team state
|
# Researcher team state
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,14 @@
|
||||||
# TradingAgents/graph/reflection.py
|
# TradingAgents/graph/reflection.py
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
|
|
||||||
|
|
||||||
class Reflector:
|
class Reflector:
|
||||||
"""Handles reflection on decisions and updating memory."""
|
"""Handles reflection on decisions and updating memory."""
|
||||||
|
|
||||||
def __init__(self, quick_thinking_llm: ChatOpenAI):
|
def __init__(self, quick_thinking_llm: BaseChatModel):
|
||||||
"""Initialize the reflector with an LLM."""
|
"""Initialize the reflector with an LLM."""
|
||||||
self.quick_thinking_llm = quick_thinking_llm
|
self.quick_thinking_llm = quick_thinking_llm
|
||||||
self.reflection_system_prompt = self._get_reflection_prompt()
|
self.reflection_system_prompt = self._get_reflection_prompt()
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
# TradingAgents/graph/setup.py
|
# TradingAgents/graph/setup.py
|
||||||
|
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
from langgraph.graph import END, StateGraph, START
|
from langgraph.graph import END, StateGraph, START
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
from tradingagents.agents import *
|
from tradingagents.agents import *
|
||||||
from tradingagents.agents.utils.agent_states import AgentState
|
from tradingagents.agents.utils.agent_states import AgentState
|
||||||
|
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -16,8 +16,8 @@ class GraphSetup:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
quick_thinking_llm: ChatOpenAI,
|
quick_thinking_llm: BaseChatModel,
|
||||||
deep_thinking_llm: ChatOpenAI,
|
deep_thinking_llm: BaseChatModel,
|
||||||
tool_nodes: Dict[str, ToolNode],
|
tool_nodes: Dict[str, ToolNode],
|
||||||
bull_memory,
|
bull_memory,
|
||||||
bear_memory,
|
bear_memory,
|
||||||
|
|
@ -149,7 +149,7 @@ class GraphSetup:
|
||||||
|
|
||||||
# Connect to next analyst or to Bull Researcher if this is the last analyst
|
# Connect to next analyst or to Bull Researcher if this is the last analyst
|
||||||
if i < len(selected_analysts) - 1:
|
if i < len(selected_analysts) - 1:
|
||||||
next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
|
next_analyst = f"{selected_analysts[i + 1].capitalize()} Analyst"
|
||||||
workflow.add_edge(current_clear, next_analyst)
|
workflow.add_edge(current_clear, next_analyst)
|
||||||
else:
|
else:
|
||||||
workflow.add_edge(current_clear, "Bull Researcher")
|
workflow.add_edge(current_clear, "Bull Researcher")
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
# TradingAgents/graph/signal_processing.py
|
# TradingAgents/graph/signal_processing.py
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
|
|
||||||
|
|
||||||
class SignalProcessor:
|
class SignalProcessor:
|
||||||
"""Processes trading signals to extract actionable decisions."""
|
"""Processes trading signals to extract actionable decisions."""
|
||||||
|
|
||||||
def __init__(self, quick_thinking_llm: ChatOpenAI):
|
def __init__(self, quick_thinking_llm: BaseChatModel):
|
||||||
"""Initialize with an LLM for processing."""
|
"""Initialize with an LLM for processing."""
|
||||||
self.quick_thinking_llm = quick_thinking_llm
|
self.quick_thinking_llm = quick_thinking_llm
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,28 +1,16 @@
|
||||||
# TradingAgents/graph/trading_graph.py
|
# TradingAgents/graph/trading_graph.py
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
from typing import Dict, Any
|
||||||
from datetime import date
|
|
||||||
from typing import Dict, Any, Tuple, List, Optional
|
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
from langchain_anthropic import ChatAnthropic
|
from langchain_anthropic import ChatAnthropic
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
from langchain_xai import ChatXAI
|
from langchain_xai import ChatXAI
|
||||||
|
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
|
|
||||||
from tradingagents.agents import *
|
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
|
||||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
|
||||||
from tradingagents.agents.utils.agent_states import (
|
|
||||||
AgentState,
|
|
||||||
InvestDebateState,
|
|
||||||
RiskDebateState,
|
|
||||||
)
|
|
||||||
from tradingagents.dataflows.config import set_config
|
|
||||||
|
|
||||||
# Import the new abstract tool methods from agent_utils
|
# Import the new abstract tool methods from agent_utils
|
||||||
from tradingagents.agents.utils.agent_utils import (
|
from tradingagents.agents.utils.agent_utils import (
|
||||||
get_stock_data,
|
get_stock_data,
|
||||||
|
|
@ -36,11 +24,13 @@ from tradingagents.agents.utils.agent_utils import (
|
||||||
get_insider_transactions,
|
get_insider_transactions,
|
||||||
get_global_news
|
get_global_news
|
||||||
)
|
)
|
||||||
|
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||||
|
from tradingagents.dataflows.config import set_config
|
||||||
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
from .conditional_logic import ConditionalLogic
|
from .conditional_logic import ConditionalLogic
|
||||||
from .setup import GraphSetup
|
|
||||||
from .propagation import Propagator
|
from .propagation import Propagator
|
||||||
from .reflection import Reflector
|
from .reflection import Reflector
|
||||||
|
from .setup import GraphSetup
|
||||||
from .signal_processing import SignalProcessor
|
from .signal_processing import SignalProcessor
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -73,12 +63,17 @@ class TradingAgentsGraph:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize LLMs
|
# Initialize LLMs
|
||||||
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
|
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config[
|
||||||
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
"llm_provider"] == "openrouter":
|
||||||
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"],
|
||||||
|
base_url=self.config["backend_url"])
|
||||||
|
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"],
|
||||||
|
base_url=self.config["backend_url"])
|
||||||
elif self.config["llm_provider"].lower() == "anthropic":
|
elif self.config["llm_provider"].lower() == "anthropic":
|
||||||
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"],
|
||||||
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
base_url=self.config["backend_url"])
|
||||||
|
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"],
|
||||||
|
base_url=self.config["backend_url"])
|
||||||
elif self.config["llm_provider"].lower() == "google":
|
elif self.config["llm_provider"].lower() == "google":
|
||||||
self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
|
self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
|
||||||
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
|
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue