Use BaseChatModel instead of ChatOpenAI to avoid importing OpenAI when using other LLM providers
This commit is contained in:
parent 44a5cb133a
commit 65ef62de17
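
In short: the constructors in graph/reflection.py, graph/setup.py, and graph/signal_processing.py are now typed against langchain_core's provider-agnostic BaseChatModel base class, so those modules no longer import langchain_openai at load time. A minimal sketch of the pattern (the SignalProcessor shape is taken from the diff below; the Anthropic usage in the trailing comment is an illustrative assumption, not part of this commit):

from langchain_core.language_models.chat_models import BaseChatModel


class SignalProcessor:
    """Accepts any LangChain chat model, not just ChatOpenAI."""

    def __init__(self, quick_thinking_llm: BaseChatModel):
        # ChatOpenAI, ChatAnthropic, ChatGoogleGenerativeAI and ChatXAI all
        # subclass BaseChatModel, so no provider package is imported here.
        self.quick_thinking_llm = quick_thinking_llm


# e.g. with a non-OpenAI provider, langchain_openai never loads in this module:
#   from langchain_anthropic import ChatAnthropic
#   processor = SignalProcessor(ChatAnthropic(model="claude-3-5-sonnet-latest"))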
@@ -1,10 +1,7 @@
 from typing import Annotated, Sequence
 from datetime import date, timedelta, datetime
 from typing_extensions import TypedDict, Optional
-from langchain_openai import ChatOpenAI
-from tradingagents.agents import *
-from langgraph.prebuilt import ToolNode
 from langgraph.graph import END, StateGraph, START, MessagesState
 from typing import Annotated

 from langgraph.graph import MessagesState
 from typing_extensions import TypedDict


 # Researcher team state

@@ -1,13 +1,14 @@
 # TradingAgents/graph/reflection.py

 from typing import Dict, Any
-from langchain_openai import ChatOpenAI
+
+from langchain_core.language_models.chat_models import BaseChatModel


 class Reflector:
     """Handles reflection on decisions and updating memory."""

-    def __init__(self, quick_thinking_llm: ChatOpenAI):
+    def __init__(self, quick_thinking_llm: BaseChatModel):
         """Initialize the reflector with an LLM."""
         self.quick_thinking_llm = quick_thinking_llm
         self.reflection_system_prompt = self._get_reflection_prompt()

@@ -56,7 +57,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
         return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"

     def _reflect_on_component(
-        self, component_type: str, report: str, situation: str, returns_losses
+            self, component_type: str, report: str, situation: str, returns_losses
     ) -> str:
         """Generate reflection for a component."""
         messages = [

@@ -1,13 +1,13 @@
 # TradingAgents/graph/setup.py

 from typing import Dict, Any
-from langchain_openai import ChatOpenAI
+
+from langchain_core.language_models.chat_models import BaseChatModel
 from langgraph.graph import END, StateGraph, START
 from langgraph.prebuilt import ToolNode

 from tradingagents.agents import *
 from tradingagents.agents.utils.agent_states import AgentState

 from .conditional_logic import ConditionalLogic

@@ -15,17 +15,17 @@ class GraphSetup:
     """Handles the setup and configuration of the agent graph."""

     def __init__(
-        self,
-        quick_thinking_llm: ChatOpenAI,
-        deep_thinking_llm: ChatOpenAI,
-        tool_nodes: Dict[str, ToolNode],
-        bull_memory,
-        bear_memory,
-        trader_memory,
-        invest_judge_memory,
-        risk_manager_memory,
-        conditional_logic: ConditionalLogic,
-        config: Dict[str, Any],
+            self,
+            quick_thinking_llm: BaseChatModel,
+            deep_thinking_llm: BaseChatModel,
+            tool_nodes: Dict[str, ToolNode],
+            bull_memory,
+            bear_memory,
+            trader_memory,
+            invest_judge_memory,
+            risk_manager_memory,
+            conditional_logic: ConditionalLogic,
+            config: Dict[str, Any],
     ):
         """Initialize with required components."""
         self.quick_thinking_llm = quick_thinking_llm

@@ -40,7 +40,7 @@ class GraphSetup:
         self.config = config

     def setup_graph(
-        self, selected_analysts=["market", "social", "news", "fundamentals"]
+            self, selected_analysts=["market", "social", "news", "fundamentals"]
     ):
         """Set up and compile the agent workflow graph.

@@ -149,7 +149,7 @@ class GraphSetup:

             # Connect to next analyst or to Bull Researcher if this is the last analyst
             if i < len(selected_analysts) - 1:
-                next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
+                next_analyst = f"{selected_analysts[i + 1].capitalize()} Analyst"
                 workflow.add_edge(current_clear, next_analyst)
             else:
                 workflow.add_edge(current_clear, "Bull Researcher")

@@ -1,12 +1,13 @@
 # TradingAgents/graph/signal_processing.py

-from langchain_openai import ChatOpenAI
+
+from langchain_core.language_models.chat_models import BaseChatModel


 class SignalProcessor:
     """Processes trading signals to extract actionable decisions."""

-    def __init__(self, quick_thinking_llm: ChatOpenAI):
+    def __init__(self, quick_thinking_llm: BaseChatModel):
         """Initialize with an LLM for processing."""
         self.quick_thinking_llm = quick_thinking_llm

@@ -1,28 +1,16 @@
 # TradingAgents/graph/trading_graph.py

+import json
 import os
 from pathlib import Path
-import json
 from datetime import date
-from typing import Dict, Any, Tuple, List, Optional
+from typing import Dict, Any

-from langchain_openai import ChatOpenAI
 from langchain_anthropic import ChatAnthropic
 from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_openai import ChatOpenAI
 from langchain_xai import ChatXAI

 from langgraph.prebuilt import ToolNode

 from tradingagents.agents import *
-from tradingagents.default_config import DEFAULT_CONFIG
-from tradingagents.agents.utils.memory import FinancialSituationMemory
 from tradingagents.agents.utils.agent_states import (
     AgentState,
     InvestDebateState,
     RiskDebateState,
 )
-from tradingagents.dataflows.config import set_config

 # Import the new abstract tool methods from agent_utils
 from tradingagents.agents.utils.agent_utils import (
     get_stock_data,

@@ -36,11 +24,13 @@ from tradingagents.agents.utils.agent_utils import (
     get_insider_transactions,
     get_global_news
 )

+from tradingagents.agents.utils.memory import FinancialSituationMemory
+from tradingagents.dataflows.config import set_config
+from tradingagents.default_config import DEFAULT_CONFIG
 from .conditional_logic import ConditionalLogic
-from .setup import GraphSetup
 from .propagation import Propagator
 from .reflection import Reflector
+from .setup import GraphSetup
 from .signal_processing import SignalProcessor

@@ -48,10 +38,10 @@ class TradingAgentsGraph:
     """Main class that orchestrates the trading agents framework."""

     def __init__(
-        self,
-        selected_analysts=["market", "social", "news", "fundamentals"],
-        debug=False,
-        config: Dict[str, Any] = None,
+            self,
+            selected_analysts=["market", "social", "news", "fundamentals"],
+            debug=False,
+            config: Dict[str, Any] = None,
     ):
         """Initialize the trading agents graph and components.

@@ -62,7 +52,7 @@ class TradingAgentsGraph:
         """
         self.debug = debug
         self.config = config or DEFAULT_CONFIG
-
+
         # Update the interface's config
         set_config(self.config)

@@ -73,12 +63,17 @@ class TradingAgentsGraph:
         )

         # Initialize LLMs
-        if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
-            self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
-            self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
+        if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config[
+            "llm_provider"] == "openrouter":
+            self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"],
+                                                base_url=self.config["backend_url"])
+            self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"],
+                                                 base_url=self.config["backend_url"])
         elif self.config["llm_provider"].lower() == "anthropic":
-            self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
-            self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
+            self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"],
+                                                   base_url=self.config["backend_url"])
+            self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"],
+                                                    base_url=self.config["backend_url"])
         elif self.config["llm_provider"].lower() == "google":
             self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
             self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])

@@ -87,7 +82,7 @@ class TradingAgentsGraph:
             self.quick_thinking_llm = ChatXAI(model=self.config["quick_think_llm"])
         else:
             raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
-
+
         # Initialize memories
         self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
         self.bear_memory = FinancialSituationMemory("bear_memory", self.config)

@@ -234,8 +229,8 @@ class TradingAgentsGraph:
         directory.mkdir(parents=True, exist_ok=True)

         with open(
-            f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json",
-            "w",
+                f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json",
+                "w",
         ) as f:
             json.dump(self.log_states_dict, f, indent=4)
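
For reference, a hedged usage sketch of the provider switch above (the config keys llm_provider, deep_think_llm, quick_think_llm and backend_url appear in this diff; the Anthropic model names and URL are assumptions for illustration, not part of the commit):

from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph

config = dict(DEFAULT_CONFIG)
config["llm_provider"] = "anthropic"                    # routed to ChatAnthropic
config["backend_url"] = "https://api.anthropic.com"     # passed through as base_url
config["deep_think_llm"] = "claude-3-5-sonnet-latest"   # assumed model name
config["quick_think_llm"] = "claude-3-5-haiku-latest"   # assumed model name

# The helper modules patched above no longer reference ChatOpenAI;
# trading_graph.py still imports every provider class and instantiates
# only the configured one.
graph = TradingAgentsGraph(selected_analysts=["market", "news"], config=config)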