update to use BaseChatModel instead of ChatOpenAI, to avoid importing OpenAI when using another LLM provider
This commit is contained in:
parent 44a5cb133a
commit 65ef62de17
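The change in outline: BaseChatModel is the abstract chat-model class in langchain_core, a package every LangChain provider integration already depends on, so annotating against it removes the hard langchain_openai dependency from modules that only need *some* chat model. A minimal sketch of the pattern (names mirror the diff below; this is an illustration, not the repository's code):

from langchain_core.language_models.chat_models import BaseChatModel


class Reflector:
    # Accepts ChatOpenAI, ChatAnthropic, ChatGoogleGenerativeAI, ChatXAI, or
    # any other BaseChatModel subclass; no provider package is imported here.
    def __init__(self, quick_thinking_llm: BaseChatModel):
        self.quick_thinking_llm = quick_thinking_llm

Only trading_graph.py, which actually constructs the models, keeps the concrete provider imports.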
@@ -1,10 +1,7 @@
-from typing import Annotated, Sequence
-from datetime import date, timedelta, datetime
-from typing_extensions import TypedDict, Optional
-from langchain_openai import ChatOpenAI
-from tradingagents.agents import *
-from langgraph.prebuilt import ToolNode
-from langgraph.graph import END, StateGraph, START, MessagesState
+from typing import Annotated
+
+from langgraph.graph import MessagesState
+from typing_extensions import TypedDict
 
 
 # Researcher team state
@@ -1,13 +1,14 @@
 # TradingAgents/graph/reflection.py
 
 from typing import Dict, Any
-from langchain_openai import ChatOpenAI
+
+from langchain_core.language_models.chat_models import BaseChatModel
 
 
 class Reflector:
     """Handles reflection on decisions and updating memory."""
 
-    def __init__(self, quick_thinking_llm: ChatOpenAI):
+    def __init__(self, quick_thinking_llm: BaseChatModel):
         """Initialize the reflector with an LLM."""
         self.quick_thinking_llm = quick_thinking_llm
         self.reflection_system_prompt = self._get_reflection_prompt()
@@ -56,7 +57,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
         return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
 
     def _reflect_on_component(
         self, component_type: str, report: str, situation: str, returns_losses
     ) -> str:
         """Generate reflection for a component."""
         messages = [
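With the annotation widened, Reflector can be driven by any provider without langchain_openai even being installed. A hedged usage sketch (module path inferred from the file header; model name hypothetical, credentials read from ANTHROPIC_API_KEY):

from langchain_anthropic import ChatAnthropic

from tradingagents.graph.reflection import Reflector  # path inferred from header comment

llm = ChatAnthropic(model="claude-3-5-sonnet-latest")  # hypothetical model name
reflector = Reflector(quick_thinking_llm=llm)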
@@ -1,13 +1,13 @@
 # TradingAgents/graph/setup.py
 
 from typing import Dict, Any
-from langchain_openai import ChatOpenAI
+
+from langchain_core.language_models.chat_models import BaseChatModel
 from langgraph.graph import END, StateGraph, START
 from langgraph.prebuilt import ToolNode
 
 from tradingagents.agents import *
 from tradingagents.agents.utils.agent_states import AgentState
 
 from .conditional_logic import ConditionalLogic
@@ -15,17 +15,17 @@ class GraphSetup:
     """Handles the setup and configuration of the agent graph."""
 
     def __init__(
         self,
-        quick_thinking_llm: ChatOpenAI,
-        deep_thinking_llm: ChatOpenAI,
+        quick_thinking_llm: BaseChatModel,
+        deep_thinking_llm: BaseChatModel,
         tool_nodes: Dict[str, ToolNode],
         bull_memory,
         bear_memory,
         trader_memory,
         invest_judge_memory,
         risk_manager_memory,
         conditional_logic: ConditionalLogic,
         config: Dict[str, Any],
     ):
         """Initialize with required components."""
         self.quick_thinking_llm = quick_thinking_llm
@@ -40,7 +40,7 @@ class GraphSetup:
         self.config = config
 
     def setup_graph(
         self, selected_analysts=["market", "social", "news", "fundamentals"]
     ):
         """Set up and compile the agent workflow graph.
 
@@ -149,7 +149,7 @@ class GraphSetup:
 
             # Connect to next analyst or to Bull Researcher if this is the last analyst
             if i < len(selected_analysts) - 1:
-                next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
+                next_analyst = f"{selected_analysts[i + 1].capitalize()} Analyst"
                 workflow.add_edge(current_clear, next_analyst)
             else:
                 workflow.add_edge(current_clear, "Bull Researcher")
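Because both parameters are now BaseChatModel, the quick and deep thinkers no longer have to come from the same provider. A sketch under that assumption (model names hypothetical; the remaining constructor arguments are elided):

from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI

quick = ChatGoogleGenerativeAI(model="gemini-1.5-flash")  # hypothetical choice
deep = ChatAnthropic(model="claude-3-5-sonnet-latest")    # hypothetical choice
# GraphSetup(quick_thinking_llm=quick, deep_thinking_llm=deep, ...) now satisfies
# the annotations without any langchain_openai import in setup.py.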
@@ -1,12 +1,13 @@
 # TradingAgents/graph/signal_processing.py
 
-from langchain_openai import ChatOpenAI
+
+from langchain_core.language_models.chat_models import BaseChatModel
 
 
 class SignalProcessor:
     """Processes trading signals to extract actionable decisions."""
 
-    def __init__(self, quick_thinking_llm: ChatOpenAI):
+    def __init__(self, quick_thinking_llm: BaseChatModel):
         """Initialize with an LLM for processing."""
         self.quick_thinking_llm = quick_thinking_llm
 
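Call sites need no changes because .invoke() is declared on BaseChatModel itself. A sketch of the contract downstream code relies on (the prompt is illustrative, not the repository's actual one):

from langchain_core.language_models.chat_models import BaseChatModel


def extract_decision(llm: BaseChatModel, full_signal: str) -> str:
    # Identical behavior for ChatOpenAI, ChatAnthropic,
    # ChatGoogleGenerativeAI, and ChatXAI.
    response = llm.invoke(
        [
            ("system", "Extract the final BUY/SELL/HOLD decision from the signal."),
            ("human", full_signal),
        ]
    )
    return response.content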
@@ -1,28 +1,16 @@
 # TradingAgents/graph/trading_graph.py
 
+import json
 import os
 from pathlib import Path
-import json
-from datetime import date
-from typing import Dict, Any, Tuple, List, Optional
+from typing import Dict, Any
 
-from langchain_openai import ChatOpenAI
 from langchain_anthropic import ChatAnthropic
 from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_openai import ChatOpenAI
 from langchain_xai import ChatXAI
-
 from langgraph.prebuilt import ToolNode
 
-from tradingagents.agents import *
-from tradingagents.default_config import DEFAULT_CONFIG
-from tradingagents.agents.utils.memory import FinancialSituationMemory
-from tradingagents.agents.utils.agent_states import (
-    AgentState,
-    InvestDebateState,
-    RiskDebateState,
-)
-from tradingagents.dataflows.config import set_config
-
 # Import the new abstract tool methods from agent_utils
 from tradingagents.agents.utils.agent_utils import (
     get_stock_data,
@@ -36,11 +24,13 @@ from tradingagents.agents.utils.agent_utils import (
     get_insider_transactions,
     get_global_news
 )
-
+from tradingagents.agents.utils.memory import FinancialSituationMemory
+from tradingagents.dataflows.config import set_config
+from tradingagents.default_config import DEFAULT_CONFIG
 from .conditional_logic import ConditionalLogic
-from .setup import GraphSetup
 from .propagation import Propagator
 from .reflection import Reflector
+from .setup import GraphSetup
 from .signal_processing import SignalProcessor
 
 
@@ -48,10 +38,10 @@ class TradingAgentsGraph:
     """Main class that orchestrates the trading agents framework."""
 
     def __init__(
         self,
         selected_analysts=["market", "social", "news", "fundamentals"],
         debug=False,
         config: Dict[str, Any] = None,
     ):
         """Initialize the trading agents graph and components.
 
@@ -62,7 +52,7 @@
         """
         self.debug = debug
         self.config = config or DEFAULT_CONFIG
 
         # Update the interface's config
         set_config(self.config)
 
@@ -73,12 +63,17 @@
         )
 
         # Initialize LLMs
-        if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
-            self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
-            self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
+        if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config[
+            "llm_provider"] == "openrouter":
+            self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"],
+                                                base_url=self.config["backend_url"])
+            self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"],
+                                                 base_url=self.config["backend_url"])
         elif self.config["llm_provider"].lower() == "anthropic":
-            self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
-            self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
+            self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"],
+                                                   base_url=self.config["backend_url"])
+            self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"],
+                                                    base_url=self.config["backend_url"])
         elif self.config["llm_provider"].lower() == "google":
             self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
             self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
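ChatOpenAI remains in use for the ollama and openrouter branches because both expose OpenAI-compatible endpoints reachable through base_url. A hedged configuration sketch (values illustrative; note that __init__ uses `config or DEFAULT_CONFIG`, which replaces rather than merges, so start from a copy of the defaults):

from tradingagents.default_config import DEFAULT_CONFIG

config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "ollama"
config["backend_url"] = "http://localhost:11434/v1"  # Ollama's OpenAI-compatible endpoint
config["deep_think_llm"] = "llama3.1"   # hypothetical local model names
config["quick_think_llm"] = "llama3.1"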
@@ -87,7 +82,7 @@
             self.quick_thinking_llm = ChatXAI(model=self.config["quick_think_llm"])
         else:
             raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
 
         # Initialize memories
         self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
         self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
@@ -234,8 +229,8 @@
         directory.mkdir(parents=True, exist_ok=True)
 
         with open(
             f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json",
             "w",
         ) as f:
             json.dump(self.log_states_dict, f, indent=4)