# Source: TradingAgents/tradingagents/graph/trading_graph.py (Python; listed upstream as 333 lines, 12 KiB)

import logging
import os
import threading
from pathlib import Path
import json
from datetime import date, datetime
from typing import Dict, Any, Tuple, Optional
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.prebuilt import ToolNode
from tradingagents.dataflows.config import get_config
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.dataflows.config import set_config
from tradingagents.agents.utils.agent_utils import (
get_stock_data,
get_indicators,
get_fundamentals,
get_balance_sheet,
get_cashflow,
get_income_statement,
get_news,
get_insider_sentiment,
get_insider_transactions,
get_global_news
)
from tradingagents.agents.discovery import (
DiscoveryRequest,
DiscoveryResult,
DiscoveryStatus,
TrendingStock,
Sector,
EventCategory,
DiscoveryTimeoutError,
extract_entities,
calculate_trending_scores,
)
from tradingagents.dataflows.interface import get_bulk_news
from tradingagents.validation import validate_ticker, validate_date
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .signal_processing import SignalProcessor
logger = logging.getLogger(__name__)
class DiscoveryTimeoutException(Exception):
    """Raised by ``_timeout_handler`` to signal that discovery timed out.

    NOTE(review): ``discover_trending`` raises the imported
    ``DiscoveryTimeoutError`` instead; within this module this class is only
    used by ``_timeout_handler``.
    """
def _timeout_handler(signum, frame) -> None:
    """Abort the current operation by raising ``DiscoveryTimeoutException``.

    The ``(signum, frame)`` parameters match the calling convention of
    handlers registered with ``signal.signal``; both are ignored.

    Raises:
        DiscoveryTimeoutException: always.
    """
    raise DiscoveryTimeoutException("Discovery operation timed out")
class TradingAgentsGraph:
    """Build and run the multi-agent trading graph.

    On construction the class creates the LLM clients, the per-role
    ``FinancialSituationMemory`` stores, the analyst tool nodes and the
    compiled graph.  ``propagate`` then runs that graph for one ticker/date,
    ``reflect_and_remember`` feeds outcomes back into the memories, and
    ``discover_trending`` / ``analyze_trending`` provide news-driven stock
    discovery on top of it.
    """

    # Analysts used when the caller does not pick any.  Kept immutable so the
    # default can never be shared and mutated across instances.
    _DEFAULT_ANALYSTS = ("market", "social", "news", "fundamentals")

    def __init__(
        self,
        selected_analysts=None,
        debug=False,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Create LLM clients, memories, tool nodes and the agent graph.

        Args:
            selected_analysts: Analyst keys handed to
                ``GraphSetup.setup_graph``.  ``None`` selects
                ``["market", "social", "news", "fundamentals"]``.
            debug: When true, ``propagate`` streams the graph and logs the
                last message of every streamed state at DEBUG level.
            config: Configuration mapping; falls back to ``get_config()``
                when omitted or empty.

        Raises:
            ValueError: If ``config["llm_provider"]`` is not one of
                openai / ollama / openrouter / anthropic / google
                (compared case-insensitively).
        """
        if selected_analysts is None:
            # Fresh list per instance instead of a shared mutable default.
            selected_analysts = list(self._DEFAULT_ANALYSTS)

        self.debug = debug
        self.config = config or get_config()
        set_config(self.config)

        os.makedirs(
            os.path.join(self.config["project_dir"], "dataflows/data_cache"),
            exist_ok=True,
        )

        # Normalise once so every provider name is matched
        # case-insensitively, not just some of them.
        provider = self.config["llm_provider"].lower()
        if provider in ("openai", "ollama", "openrouter"):
            llm_class = ChatOpenAI
            llm_kwargs = {"base_url": self.config["backend_url"]}
        elif provider == "anthropic":
            llm_class = ChatAnthropic
            llm_kwargs = {"base_url": self.config["backend_url"]}
        elif provider == "google":
            # The Google client is built without a base_url.
            llm_class = ChatGoogleGenerativeAI
            llm_kwargs = {}
        else:
            raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")

        self.deep_thinking_llm = llm_class(
            model=self.config["deep_think_llm"], **llm_kwargs
        )
        self.quick_thinking_llm = llm_class(
            model=self.config["quick_think_llm"], **llm_kwargs
        )

        self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
        self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
        self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
        self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
        self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)

        self.tool_nodes = self._create_tool_nodes()
        self.conditional_logic = ConditionalLogic()
        self.graph_setup = GraphSetup(
            self.quick_thinking_llm,
            self.deep_thinking_llm,
            self.tool_nodes,
            self.bull_memory,
            self.bear_memory,
            self.trader_memory,
            self.invest_judge_memory,
            self.risk_manager_memory,
            self.conditional_logic,
        )

        self.propagator = Propagator()
        self.reflector = Reflector(self.quick_thinking_llm)
        self.signal_processor = SignalProcessor(self.quick_thinking_llm)

        # State of the most recent run; filled in by ``propagate``.
        self.curr_state = None
        self.ticker = None
        self.log_states_dict = {}

        self.graph = self.graph_setup.setup_graph(selected_analysts)

    def _create_tool_nodes(self) -> Dict[str, ToolNode]:
        """Return one ``ToolNode`` per analyst, keyed by analyst name."""
        return {
            "market": ToolNode(
                [
                    get_stock_data,
                    get_indicators,
                ]
            ),
            "social": ToolNode(
                [
                    get_news,
                ]
            ),
            "news": ToolNode(
                [
                    get_news,
                    get_global_news,
                    get_insider_sentiment,
                    get_insider_transactions,
                ]
            ),
            "fundamentals": ToolNode(
                [
                    get_fundamentals,
                    get_balance_sheet,
                    get_cashflow,
                    get_income_statement,
                ]
            ),
        }

    def propagate(self, company_name: str, trade_date) -> Tuple[Dict[str, Any], str]:
        """Run the agent graph for one ticker on one date.

        Args:
            company_name: Ticker symbol; replaced by the value returned from
                ``validate_ticker``.
            trade_date: Trade date.  It is always passed through
                ``validate_date(..., allow_future=False)``; when given as a
                string it is replaced by the validated value, otherwise the
                original object is used as-is.

        Returns:
            ``(final_state, signal)`` where ``signal`` is the result of
            ``process_signal`` applied to ``final_state["final_trade_decision"]``.

        Raises:
            RuntimeError: In debug mode, if the graph stream yields no state.
        """
        company_name = validate_ticker(company_name)
        validated_date = validate_date(trade_date, allow_future=False)
        if isinstance(trade_date, str):
            trade_date = validated_date

        self.ticker = company_name
        init_agent_state = self.propagator.create_initial_state(
            company_name, trade_date
        )
        args = self.propagator.get_graph_args()

        if self.debug:
            # Only the last streamed state is needed, so do not retain the
            # whole trace in memory.
            final_state = None
            for chunk in self.graph.stream(init_agent_state, **args):
                if chunk["messages"]:
                    logger.debug("Agent message: %s", chunk["messages"][-1])
                final_state = chunk
            if final_state is None:
                raise RuntimeError("Graph stream produced no states")
        else:
            final_state = self.graph.invoke(init_agent_state, **args)

        self.curr_state = final_state
        self._log_state(trade_date, final_state)
        return final_state, self.process_signal(final_state["final_trade_decision"])

    def _log_state(self, trade_date, final_state: Dict[str, Any]) -> None:
        """Record ``final_state`` under ``trade_date`` and dump the log to disk.

        The snapshot is stored in ``self.log_states_dict[str(trade_date)]``
        and the *entire* dictionary (all dates logged so far by this
        instance) is written to
        ``eval_results/<ticker>/TradingAgentsStrategy_logs/full_states_log_<trade_date>.json``
        relative to the current working directory.
        """
        invest_debate = final_state["investment_debate_state"]
        risk_debate = final_state["risk_debate_state"]

        self.log_states_dict[str(trade_date)] = {
            "company_of_interest": final_state["company_of_interest"],
            # str() keeps the entry JSON-serialisable when a ``date`` object
            # (e.g. from ``analyze_trending``) reached the state; it is a
            # no-op for values that are already strings.
            "trade_date": str(final_state["trade_date"]),
            "market_report": final_state["market_report"],
            "sentiment_report": final_state["sentiment_report"],
            "news_report": final_state["news_report"],
            "fundamentals_report": final_state["fundamentals_report"],
            "investment_debate_state": {
                "bull_history": invest_debate["bull_history"],
                "bear_history": invest_debate["bear_history"],
                "history": invest_debate["history"],
                "current_response": invest_debate["current_response"],
                "judge_decision": invest_debate["judge_decision"],
            },
            "trader_investment_decision": final_state["trader_investment_plan"],
            "risk_debate_state": {
                "risky_history": risk_debate["risky_history"],
                "safe_history": risk_debate["safe_history"],
                "neutral_history": risk_debate["neutral_history"],
                "history": risk_debate["history"],
                "judge_decision": risk_debate["judge_decision"],
            },
            "investment_plan": final_state["investment_plan"],
            "final_trade_decision": final_state["final_trade_decision"],
        }

        directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
        directory.mkdir(parents=True, exist_ok=True)
        log_path = directory / f"full_states_log_{trade_date}.json"
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(self.log_states_dict, f, indent=4)

    def reflect_and_remember(self, returns_losses) -> None:
        """Let every role reflect on the last run and update its memory.

        ``self.curr_state`` (set by ``propagate``) and ``returns_losses`` are
        passed to each reflector together with that role's memory store.
        """
        reflections = (
            (self.reflector.reflect_bull_researcher, self.bull_memory),
            (self.reflector.reflect_bear_researcher, self.bear_memory),
            (self.reflector.reflect_trader, self.trader_memory),
            (self.reflector.reflect_invest_judge, self.invest_judge_memory),
            (self.reflector.reflect_risk_manager, self.risk_manager_memory),
        )
        for reflect, memory in reflections:
            reflect(self.curr_state, returns_losses, memory)

    def process_signal(self, full_signal: str) -> str:
        """Delegate ``full_signal`` to the ``SignalProcessor`` and return its result."""
        return self.signal_processor.process_signal(full_signal)

    def discover_trending(
        self,
        request: Optional[DiscoveryRequest] = None,
    ) -> DiscoveryResult:
        """Find trending stocks from bulk news, bounded by a hard timeout.

        The fetch / entity-extraction / scoring pipeline runs in a worker
        thread so that the call can give up after
        ``config["discovery_hard_timeout"]`` seconds (default 120).

        Args:
            request: Discovery parameters.  When ``None`` a request with a
                ``"24h"`` lookback and ``config["discovery_max_results"]``
                (default 20) results is used.

        Returns:
            A ``DiscoveryResult`` with status ``COMPLETED`` and the (optionally
            sector/event filtered) stocks, or status ``FAILED`` with
            ``error_message`` set when the pipeline raised.

        Raises:
            DiscoveryTimeoutError: If the worker is still running after the
                hard timeout.
        """
        if request is None:
            request = DiscoveryRequest(
                lookback_period="24h",
                max_results=self.config.get("discovery_max_results", 20),
            )

        started_at = datetime.now()
        result = DiscoveryResult(
            request=request,
            trending_stocks=[],
            status=DiscoveryStatus.PROCESSING,
            started_at=started_at,
        )

        hard_timeout = self.config.get("discovery_hard_timeout", 120)
        # Written by the worker thread, read by this thread after join().
        outcome = {"stocks": [], "error": None}

        def run_discovery() -> None:
            try:
                articles = get_bulk_news(request.lookback_period)
                mentions = extract_entities(articles, self.config)

                min_mentions = self.config.get("discovery_min_mentions", 2)
                if len(articles) < 10:
                    # With very few articles, require only a single mention.
                    min_mentions = 1

                max_results = request.max_results or self.config.get(
                    "discovery_max_results", 20
                )
                outcome["stocks"] = calculate_trending_scores(
                    mentions,
                    articles,
                    max_results=max_results,
                    min_mentions=min_mentions,
                )
            except Exception as exc:
                # Thread boundary: anything escaping here would kill the
                # worker silently and be reported as an empty success.
                logger.exception("Trending discovery failed")
                # Never store an empty message, so the failure is always
                # distinguishable from "no error".
                outcome["error"] = str(exc) or type(exc).__name__

        # Daemon thread: a worker abandoned after a timeout must not keep the
        # interpreter alive at shutdown.
        worker = threading.Thread(
            target=run_discovery, name="trending-discovery", daemon=True
        )
        worker.start()
        worker.join(timeout=hard_timeout)

        if worker.is_alive():
            raise DiscoveryTimeoutError(
                f"Discovery operation exceeded {hard_timeout} second timeout"
            )

        if outcome["error"] is not None:
            result.status = DiscoveryStatus.FAILED
            result.error_message = outcome["error"]
            result.completed_at = datetime.now()
            return result

        trending_stocks = outcome["stocks"]

        if request.sector_filter:
            # Filters may hold enum members or their raw values; accept both.
            sector_values = {
                s.value if isinstance(s, Sector) else s for s in request.sector_filter
            }
            trending_stocks = [
                stock for stock in trending_stocks
                if stock.sector.value in sector_values or stock.sector in request.sector_filter
            ]

        if request.event_filter:
            event_values = {
                e.value if isinstance(e, EventCategory) else e for e in request.event_filter
            }
            trending_stocks = [
                stock for stock in trending_stocks
                if stock.event_type.value in event_values or stock.event_type in request.event_filter
            ]

        result.trending_stocks = trending_stocks
        result.status = DiscoveryStatus.COMPLETED
        result.completed_at = datetime.now()
        return result

    def analyze_trending(
        self,
        trending_stock: TrendingStock,
        trade_date: Optional[date] = None,
    ) -> Tuple[Dict[str, Any], str]:
        """Run ``propagate`` for a discovered stock.

        Args:
            trending_stock: Stock whose ``ticker`` is analysed.
            trade_date: Date to analyse; defaults to ``date.today()``.

        Returns:
            The ``(final_state, signal)`` tuple produced by ``propagate``.
        """
        if trade_date is None:
            trade_date = date.today()
        return self.propagate(trending_stock.ticker, trade_date)