diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index 3a859ea1..a4b6034f 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -1,10 +1,7 @@ -from typing import Annotated, Sequence -from datetime import date, timedelta, datetime -from typing_extensions import TypedDict, Optional -from langchain_openai import ChatOpenAI -from tradingagents.agents import * -from langgraph.prebuilt import ToolNode -from langgraph.graph import END, StateGraph, START, MessagesState +from typing import Annotated + +from langgraph.graph import MessagesState +from typing_extensions import TypedDict # Researcher team state diff --git a/tradingagents/graph/reflection.py b/tradingagents/graph/reflection.py index 33303231..b4cc6163 100644 --- a/tradingagents/graph/reflection.py +++ b/tradingagents/graph/reflection.py @@ -1,13 +1,14 @@ # TradingAgents/graph/reflection.py from typing import Dict, Any -from langchain_openai import ChatOpenAI + +from langchain_core.language_models.chat_models import BaseChatModel class Reflector: """Handles reflection on decisions and updating memory.""" - def __init__(self, quick_thinking_llm: ChatOpenAI): + def __init__(self, quick_thinking_llm: BaseChatModel): """Initialize the reflector with an LLM.""" self.quick_thinking_llm = quick_thinking_llm self.reflection_system_prompt = self._get_reflection_prompt() @@ -56,7 +57,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}" def _reflect_on_component( - self, component_type: str, report: str, situation: str, returns_losses + self, component_type: str, report: str, situation: str, returns_losses ) -> str: """Generate reflection for a component.""" messages = [ diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index a130b8b6..ea6f7e5d 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -1,13 +1,13 @@ # TradingAgents/graph/setup.py from typing import Dict, Any -from langchain_openai import ChatOpenAI + +from langchain_core.language_models.chat_models import BaseChatModel from langgraph.graph import END, StateGraph, START from langgraph.prebuilt import ToolNode from tradingagents.agents import * from tradingagents.agents.utils.agent_states import AgentState - from .conditional_logic import ConditionalLogic @@ -15,17 +15,17 @@ class GraphSetup: """Handles the setup and configuration of the agent graph.""" def __init__( - self, - quick_thinking_llm: ChatOpenAI, - deep_thinking_llm: ChatOpenAI, - tool_nodes: Dict[str, ToolNode], - bull_memory, - bear_memory, - trader_memory, - invest_judge_memory, - risk_manager_memory, - conditional_logic: ConditionalLogic, - config: Dict[str, Any], + self, + quick_thinking_llm: BaseChatModel, + deep_thinking_llm: BaseChatModel, + tool_nodes: Dict[str, ToolNode], + bull_memory, + bear_memory, + trader_memory, + invest_judge_memory, + risk_manager_memory, + conditional_logic: ConditionalLogic, + config: Dict[str, Any], ): """Initialize with required components.""" self.quick_thinking_llm = quick_thinking_llm @@ -40,7 +40,7 @@ class GraphSetup: self.config = config def setup_graph( - self, selected_analysts=["market", "social", "news", "fundamentals"] + self, selected_analysts=["market", "social", "news", "fundamentals"] ): """Set up and compile the agent workflow graph. @@ -149,7 +149,7 @@ class GraphSetup: # Connect to next analyst or to Bull Researcher if this is the last analyst if i < len(selected_analysts) - 1: - next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst" + next_analyst = f"{selected_analysts[i + 1].capitalize()} Analyst" workflow.add_edge(current_clear, next_analyst) else: workflow.add_edge(current_clear, "Bull Researcher") diff --git a/tradingagents/graph/signal_processing.py b/tradingagents/graph/signal_processing.py index 903e8529..bff77b4d 100644 --- a/tradingagents/graph/signal_processing.py +++ b/tradingagents/graph/signal_processing.py @@ -1,12 +1,13 @@ # TradingAgents/graph/signal_processing.py -from langchain_openai import ChatOpenAI + +from langchain_core.language_models.chat_models import BaseChatModel class SignalProcessor: """Processes trading signals to extract actionable decisions.""" - def __init__(self, quick_thinking_llm: ChatOpenAI): + def __init__(self, quick_thinking_llm: BaseChatModel): """Initialize with an LLM for processing.""" self.quick_thinking_llm = quick_thinking_llm diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 9428f50e..c9239aa5 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -1,28 +1,16 @@ # TradingAgents/graph/trading_graph.py +import json import os from pathlib import Path -import json -from datetime import date -from typing import Dict, Any, Tuple, List, Optional +from typing import Dict, Any -from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_openai import ChatOpenAI from langchain_xai import ChatXAI - from langgraph.prebuilt import ToolNode -from tradingagents.agents import * -from tradingagents.default_config import DEFAULT_CONFIG -from tradingagents.agents.utils.memory import FinancialSituationMemory -from tradingagents.agents.utils.agent_states import ( - AgentState, - InvestDebateState, - RiskDebateState, -) -from tradingagents.dataflows.config import set_config - # Import the new abstract tool methods from agent_utils from tradingagents.agents.utils.agent_utils import ( get_stock_data, @@ -36,11 +24,13 @@ from tradingagents.agents.utils.agent_utils import ( get_insider_transactions, get_global_news ) - +from tradingagents.agents.utils.memory import FinancialSituationMemory +from tradingagents.dataflows.config import set_config +from tradingagents.default_config import DEFAULT_CONFIG from .conditional_logic import ConditionalLogic -from .setup import GraphSetup from .propagation import Propagator from .reflection import Reflector +from .setup import GraphSetup from .signal_processing import SignalProcessor @@ -48,10 +38,10 @@ class TradingAgentsGraph: """Main class that orchestrates the trading agents framework.""" def __init__( - self, - selected_analysts=["market", "social", "news", "fundamentals"], - debug=False, - config: Dict[str, Any] = None, + self, + selected_analysts=["market", "social", "news", "fundamentals"], + debug=False, + config: Dict[str, Any] = None, ): """Initialize the trading agents graph and components. @@ -62,7 +52,7 @@ class TradingAgentsGraph: """ self.debug = debug self.config = config or DEFAULT_CONFIG - + # Update the interface's config set_config(self.config) @@ -73,12 +63,17 @@ class TradingAgentsGraph: ) # Initialize LLMs - if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter": - self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]) - self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]) + if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config[ + "llm_provider"] == "openrouter": + self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], + base_url=self.config["backend_url"]) + self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], + base_url=self.config["backend_url"]) elif self.config["llm_provider"].lower() == "anthropic": - self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]) - self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]) + self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], + base_url=self.config["backend_url"]) + self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], + base_url=self.config["backend_url"]) elif self.config["llm_provider"].lower() == "google": self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"]) self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"]) @@ -87,7 +82,7 @@ class TradingAgentsGraph: self.quick_thinking_llm = ChatXAI(model=self.config["quick_think_llm"]) else: raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}") - + # Initialize memories self.bull_memory = FinancialSituationMemory("bull_memory", self.config) self.bear_memory = FinancialSituationMemory("bear_memory", self.config) @@ -234,8 +229,8 @@ class TradingAgentsGraph: directory.mkdir(parents=True, exist_ok=True) with open( - f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json", - "w", + f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json", + "w", ) as f: json.dump(self.log_states_dict, f, indent=4)