# TradingAgents/graph/trading_graph.py
import json
import os
from datetime import date
from pathlib import Path
from typing import Dict, Any, Tuple, List, Optional

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_xai import ChatXAI

from langgraph.prebuilt import ToolNode

from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.utils.agent_states import (
    AgentState,
    InvestDebateState,
    RiskDebateState,
)
from tradingagents.dataflows.config import set_config

# Import the new abstract tool methods from agent_utils
from tradingagents.agents.utils.agent_utils import (
    get_stock_data,
    get_indicators,
    get_fundamentals,
    get_balance_sheet,
    get_cashflow,
    get_income_statement,
    get_news,
    get_insider_sentiment,
    get_insider_transactions,
    get_global_news
)

from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .signal_processing import SignalProcessor


class TradingAgentsGraph:
    """Main class that orchestrates the trading agents framework.

    Builds the LLM clients, the per-role memories, the data-tool nodes and the
    LangGraph workflow (via ``GraphSetup``), and exposes ``propagate`` to run
    that workflow for one company on one trade date.
    """

    # Analysts used when the caller does not pass ``selected_analysts``.
    DEFAULT_ANALYSTS = ("market", "social", "news", "fundamentals")

    # Providers served through ChatOpenAI pointed at ``config["backend_url"]``.
    OPENAI_COMPATIBLE_PROVIDERS = ("openai", "ollama", "openrouter")

    def __init__(
        self,
        selected_analysts: Optional[List[str]] = None,
        debug=False,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Initialize the trading agents graph and components.

        Args:
            selected_analysts: Analyst types to include. ``None`` (the default)
                selects ``DEFAULT_ANALYSTS``.
            debug: If True, ``propagate`` streams the graph and pretty-prints
                the newest message of each streamed chunk.
            config: Configuration dictionary. If None (or empty), uses
                ``DEFAULT_CONFIG``.

        Raises:
            ValueError: If ``config["llm_provider"]`` is not a supported provider.
        """
        if selected_analysts is None:
            # Fresh list per instance rather than one mutable default list
            # shared by every TradingAgentsGraph.
            selected_analysts = list(self.DEFAULT_ANALYSTS)

        self.debug = debug
        self.config = config or DEFAULT_CONFIG

        # Output language and the matching system-prompt instruction; the
        # instruction is "" when the language has no configured prompt.
        self.output_language = self.config.get("output_language", "en")
        language_prompts = self.config.get("language_system_prompts") or {}
        self.language_instruction = language_prompts.get(self.output_language, "")

        self._setup_agents()
        self._setup_graph(selected_analysts)

    def _setup_agents(self):
        """Prepare agent-level settings, currently only the language instruction.

        Sets ``self.language_system_message`` to a ``SystemMessage`` when a
        language instruction is configured, otherwise to ``None``.
        """
        self.language_system_message = None
        if self.language_instruction:
            self._inject_language_to_agents()

    def _inject_language_to_agents(self):
        """Wrap the configured language instruction in a ``SystemMessage``.

        NOTE(review): the message is only stored on
        ``self.language_system_message``; nothing in the code visible here
        passes it on to the agents built by ``GraphSetup`` — confirm where it
        is meant to be consumed.
        """
        self.language_system_message = SystemMessage(content=self.language_instruction)

    def _create_agent_with_language(self, agent_name, agent_prompt, llm):
        """Create a tool-less agent whose system prompt carries the language instruction.

        Args:
            agent_name: Agent identifier (currently unused by this helper).
            agent_prompt: Base system prompt for the agent.
            llm: Chat model passed to ``create_openai_functions_agent``.

        Returns:
            An ``AgentExecutor`` with no tools, verbose when ``self.debug`` is set.
        """
        # Imported lazily so the module does not require the `langchain`
        # package unless this helper is actually used.
        from langchain.agents import AgentExecutor, create_openai_functions_agent
        from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

        # Append the language instruction (if any) directly after the base prompt.
        if self.language_instruction:
            system_prompt = f"{agent_prompt}{self.language_instruction}"
        else:
            system_prompt = agent_prompt

        # NOTE(review): ChatPromptTemplate treats "{...}" in the system text as
        # input variables, so literal braces in either prompt would need to be
        # escaped as "{{...}}" — confirm the prompts never contain them.
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="chat_history", optional=True),
            ("human", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])

        agent = create_openai_functions_agent(llm, tools=[], prompt=prompt)
        return AgentExecutor(agent=agent, tools=[], verbose=self.debug)

    def _setup_graph(self, selected_analysts):
        """Build LLMs, memories, tool nodes and helpers, then compile the graph.

        Args:
            selected_analysts: Analyst types forwarded to ``GraphSetup.setup_graph``.
        """
        # Update the interface's config
        set_config(self.config)

        # Create necessary directories
        os.makedirs(
            os.path.join(self.config["project_dir"], "dataflows/data_cache"),
            exist_ok=True,
        )

        self._init_llms()

        # One memory store per role that reflects on past decisions.
        self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
        self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
        self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
        self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
        self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)

        self.tool_nodes = self._create_tool_nodes()

        # Graph-building and post-processing components.
        self.conditional_logic = ConditionalLogic()
        self.graph_setup = GraphSetup(
            self.quick_thinking_llm,
            self.deep_thinking_llm,
            self.tool_nodes,
            self.bull_memory,
            self.bear_memory,
            self.trader_memory,
            self.invest_judge_memory,
            self.risk_manager_memory,
            self.conditional_logic,
        )
        self.propagator = Propagator()
        self.reflector = Reflector(self.quick_thinking_llm)
        self.signal_processor = SignalProcessor(self.quick_thinking_llm)

        # State tracking
        self.curr_state = None
        self.ticker = None
        self.log_states_dict = {}  # date to full state dict

        self.graph = self.graph_setup.setup_graph(selected_analysts)

    def _init_llms(self):
        """Create ``self.deep_thinking_llm`` and ``self.quick_thinking_llm``.

        The provider name is matched case-insensitively. OpenAI-compatible
        providers and Anthropic are pointed at ``config["backend_url"]``.

        Raises:
            ValueError: If ``config["llm_provider"]`` is not supported.
        """
        provider = self.config["llm_provider"].lower()

        if provider in self.OPENAI_COMPATIBLE_PROVIDERS:
            llm_class, extra_kwargs = ChatOpenAI, {"base_url": self.config["backend_url"]}
        elif provider == "anthropic":
            llm_class, extra_kwargs = ChatAnthropic, {"base_url": self.config["backend_url"]}
        elif provider == "google":
            llm_class, extra_kwargs = ChatGoogleGenerativeAI, {}
        elif provider == "xai":
            llm_class, extra_kwargs = ChatXAI, {}
        else:
            raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")

        self.deep_thinking_llm = llm_class(model=self.config["deep_think_llm"], **extra_kwargs)
        self.quick_thinking_llm = llm_class(model=self.config["quick_think_llm"], **extra_kwargs)

    def _create_tool_nodes(self) -> "Dict[str, ToolNode]":
        """Create tool nodes for different data sources using abstract methods.

        Returns:
            Mapping of analyst type ("market", "social", "news", "fundamentals")
            to the ``ToolNode`` holding that analyst's data tools.
        """
        return {
            "market": ToolNode(
                [
                    # Core stock data tools
                    get_stock_data,
                    # Technical indicators
                    get_indicators,
                ]
            ),
            "social": ToolNode(
                [
                    # News tools for social media analysis
                    get_news,
                ]
            ),
            "news": ToolNode(
                [
                    # News and insider information
                    get_news,
                    get_global_news,
                    get_insider_sentiment,
                    get_insider_transactions,
                ]
            ),
            "fundamentals": ToolNode(
                [
                    # Fundamental analysis tools
                    get_fundamentals,
                    get_balance_sheet,
                    get_cashflow,
                    get_income_statement,
                ]
            ),
        }

    def propagate(self, company_name, trade_date):
        """Run the trading agents graph for a company on a specific date.

        Args:
            company_name: Ticker/company to analyse; also stored as ``self.ticker``.
            trade_date: Date forwarded to the propagator and used as the log key.

        Returns:
            Tuple of (final graph state, result of ``process_signal`` applied to
            ``final_state["final_trade_decision"]``).

        Raises:
            RuntimeError: In debug mode, if the stream yields no chunk with messages.
        """
        self.ticker = company_name

        init_agent_state = self.propagator.create_initial_state(company_name, trade_date)
        args = self.propagator.get_graph_args()

        if self.debug:
            # Stream so each new message can be printed as the graph runs. Only
            # the last chunk that carried messages is needed as the final state.
            final_state = None
            for chunk in self.graph.stream(init_agent_state, **args):
                if chunk["messages"]:
                    chunk["messages"][-1].pretty_print()
                    final_state = chunk
            if final_state is None:
                raise RuntimeError("Graph stream produced no state with messages")
        else:
            final_state = self.graph.invoke(init_agent_state, **args)

        # Store current state for reflection
        self.curr_state = final_state

        self._log_state(trade_date, final_state)

        return final_state, self.process_signal(final_state["final_trade_decision"])

    def _log_state(self, trade_date, final_state):
        """Record ``final_state`` under ``trade_date`` and dump all logged states to JSON.

        The file ``eval_results/<ticker>/TradingAgentsStrategy_logs/
        full_states_log_<trade_date>.json`` receives every state logged so far
        by this instance, not only the current one.
        """
        invest_debate = final_state["investment_debate_state"]
        risk_debate = final_state["risk_debate_state"]

        self.log_states_dict[str(trade_date)] = {
            "company_of_interest": final_state["company_of_interest"],
            "trade_date": final_state["trade_date"],
            "market_report": final_state["market_report"],
            "sentiment_report": final_state["sentiment_report"],
            "news_report": final_state["news_report"],
            "fundamentals_report": final_state["fundamentals_report"],
            "investment_debate_state": {
                "bull_history": invest_debate["bull_history"],
                "bear_history": invest_debate["bear_history"],
                "history": invest_debate["history"],
                "current_response": invest_debate["current_response"],
                "judge_decision": invest_debate["judge_decision"],
            },
            "trader_investment_decision": final_state["trader_investment_plan"],
            "risk_debate_state": {
                "risky_history": risk_debate["risky_history"],
                "safe_history": risk_debate["safe_history"],
                "neutral_history": risk_debate["neutral_history"],
                "history": risk_debate["history"],
                "judge_decision": risk_debate["judge_decision"],
            },
            "investment_plan": final_state["investment_plan"],
            "final_trade_decision": final_state["final_trade_decision"],
        }

        # Save to file (single source of truth for the log path).
        directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
        directory.mkdir(parents=True, exist_ok=True)

        log_path = directory / f"full_states_log_{trade_date}.json"
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(self.log_states_dict, f, indent=4)

    def reflect_and_remember(self, returns_losses):
        """Reflect on the last ``propagate`` run and update each role's memory.

        Args:
            returns_losses: Outcome of the last decision, passed unchanged to
                every ``Reflector.reflect_*`` call.

        Raises:
            RuntimeError: If no state is available yet (``propagate`` not run).
        """
        if self.curr_state is None:
            raise RuntimeError("reflect_and_remember() requires a prior propagate() call")

        self.reflector.reflect_bull_researcher(
            self.curr_state, returns_losses, self.bull_memory
        )
        self.reflector.reflect_bear_researcher(
            self.curr_state, returns_losses, self.bear_memory
        )
        self.reflector.reflect_trader(
            self.curr_state, returns_losses, self.trader_memory
        )
        self.reflector.reflect_invest_judge(
            self.curr_state, returns_losses, self.invest_judge_memory
        )
        self.reflector.reflect_risk_manager(
            self.curr_state, returns_losses, self.risk_manager_memory
        )

    def process_signal(self, full_signal):
        """Process a signal to extract the core decision via ``SignalProcessor``."""
        return self.signal_processor.process_signal(full_signal)