TradingAgents/tradingagents/graph/trading_graph.py
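
"""Trading agents graph orchestration.

Builds the LLM backends, per-agent memories, and the LangGraph workflow, and
exposes the entry points for per-ticker analysis (propagate), post-trade
reflection, signal extraction, and trending-stock discovery.
"""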

import json
import logging
import os
import threading
from datetime import date, datetime
from pathlib import Path
from typing import Any

from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode
from tradingagents.agents.discovery import (
    DiscoveryRequest,
    DiscoveryResult,
    DiscoveryStatus,
    DiscoveryTimeoutError,
    EventCategory,
    Sector,
    TrendingStock,
    calculate_trending_scores,
    extract_entities,
)
from tradingagents.agents.utils.agent_utils import (
    get_balance_sheet,
    get_cashflow,
    get_fundamentals,
    get_global_news,
    get_income_statement,
    get_indicators,
    get_insider_sentiment,
    get_insider_transactions,
    get_news,
    get_stock_data,
)
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.database import (
    AnalysisService,
    DiscoveryService,
    TradingService,
    get_db_session,
    init_database,
)
from tradingagents.dataflows.config import get_config, set_config
from tradingagents.dataflows.interface import get_bulk_news
from tradingagents.validation import validate_date, validate_ticker

from .conditional_logic import ConditionalLogic
from .propagation import Propagator
from .reflection import Reflector
from .setup import GraphSetup
from .signal_processing import SignalProcessor

logger = logging.getLogger(__name__)


class DiscoveryTimeoutException(Exception):
    """Raised by the signal-based timeout handler below."""


def _timeout_handler(signum, frame) -> None:
    # Signal-based timeout hook; note that discover_trending itself uses a
    # thread join with a timeout and raises DiscoveryTimeoutError instead.
    raise DiscoveryTimeoutException("Discovery operation timed out")


class TradingAgentsGraph:
    """Top-level orchestrator for the multi-agent trading workflow."""

    def __init__(
        self,
        selected_analysts=None,
        debug=False,
        config: dict[str, Any] | None = None,
    ):
        # Avoid a mutable default argument for the analyst list.
        if selected_analysts is None:
            selected_analysts = ["market", "social", "news", "fundamentals"]
        self.debug = debug
        self.config = config or get_config()
        set_config(self.config)
        os.makedirs(
            os.path.join(self.config["project_dir"], "dataflows/data_cache"),
            exist_ok=True,
        )
        self.db_enabled = self.config.get("database_enabled", False)
        if self.db_enabled:
            db_path = self.config.get("database_path")
            init_database(db_path)
        # Normalize once so provider matching is case-insensitive for every
        # backend, not just OpenAI.
        provider = self.config["llm_provider"].lower()
        if provider in ("openai", "ollama", "openrouter"):
            self.deep_thinking_llm = ChatOpenAI(
                model=self.config["deep_think_llm"],
                base_url=self.config["backend_url"],
            )
            self.quick_thinking_llm = ChatOpenAI(
                model=self.config["quick_think_llm"],
                base_url=self.config["backend_url"],
            )
        elif provider == "anthropic":
            self.deep_thinking_llm = ChatAnthropic(
                model=self.config["deep_think_llm"],
                base_url=self.config["backend_url"],
            )
            self.quick_thinking_llm = ChatAnthropic(
                model=self.config["quick_think_llm"],
                base_url=self.config["backend_url"],
            )
        elif provider == "google":
            self.deep_thinking_llm = ChatGoogleGenerativeAI(
                model=self.config["deep_think_llm"]
            )
            self.quick_thinking_llm = ChatGoogleGenerativeAI(
                model=self.config["quick_think_llm"]
            )
        else:
            raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
        # Per-role memories used for post-trade reflection.
        self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
        self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
        self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
        self.invest_judge_memory = FinancialSituationMemory(
            "invest_judge_memory", self.config
        )
        self.risk_manager_memory = FinancialSituationMemory(
            "risk_manager_memory", self.config
        )
        self.tool_nodes = self._create_tool_nodes()
        self.conditional_logic = ConditionalLogic()
        self.graph_setup = GraphSetup(
            self.quick_thinking_llm,
            self.deep_thinking_llm,
            self.tool_nodes,
            self.bull_memory,
            self.bear_memory,
            self.trader_memory,
            self.invest_judge_memory,
            self.risk_manager_memory,
            self.conditional_logic,
        )
        self.propagator = Propagator()
        self.reflector = Reflector(self.quick_thinking_llm)
        self.signal_processor = SignalProcessor(self.quick_thinking_llm)
        self.curr_state = None
        self.ticker = None
        self.log_states_dict = {}
        self.graph = self.graph_setup.setup_graph(selected_analysts)

    def _create_tool_nodes(self) -> dict[str, ToolNode]:
        """Group the data-fetching tools into one ToolNode per analyst type."""
        return {
            "market": ToolNode(
                [
                    get_stock_data,
                    get_indicators,
                ]
            ),
            "social": ToolNode(
                [
                    get_news,
                ]
            ),
            "news": ToolNode(
                [
                    get_news,
                    get_global_news,
                    get_insider_sentiment,
                    get_insider_transactions,
                ]
            ),
            "fundamentals": ToolNode(
                [
                    get_fundamentals,
                    get_balance_sheet,
                    get_cashflow,
                    get_income_statement,
                ]
            ),
        }

    def propagate(self, company_name: str, trade_date) -> tuple[dict[str, Any], str]:
        """Run the full agent graph for one ticker on one date and return the
        final state plus the processed trade signal."""
        company_name = validate_ticker(company_name)
        validated_date = validate_date(trade_date, allow_future=False)
        if isinstance(trade_date, str):
            trade_date = validated_date
        self.ticker = company_name
        init_agent_state = self.propagator.create_initial_state(
            company_name, trade_date
        )
        args = self.propagator.get_graph_args()
        if self.debug:
            trace = []
            for chunk in self.graph.stream(init_agent_state, **args):
                # Skip chunks without messages; log and trace the rest.
                if chunk["messages"]:
                    logger.debug("Agent message: %s", chunk["messages"][-1])
                    trace.append(chunk)
            final_state = trace[-1]
        else:
            final_state = self.graph.invoke(init_agent_state, **args)
        self.curr_state = final_state
        self._log_state(trade_date, final_state)
        return final_state, self.process_signal(final_state["final_trade_decision"])

    def _log_state(self, trade_date, final_state: dict[str, Any]) -> None:
        """Snapshot the final graph state into the per-ticker JSON log and,
        when enabled, the database."""
        self.log_states_dict[str(trade_date)] = {
            "company_of_interest": final_state["company_of_interest"],
            "trade_date": final_state["trade_date"],
            "market_report": final_state["market_report"],
            "sentiment_report": final_state["sentiment_report"],
            "news_report": final_state["news_report"],
            "fundamentals_report": final_state["fundamentals_report"],
            "investment_debate_state": {
                "bull_history": final_state["investment_debate_state"]["bull_history"],
                "bear_history": final_state["investment_debate_state"]["bear_history"],
                "history": final_state["investment_debate_state"]["history"],
                "current_response": final_state["investment_debate_state"][
                    "current_response"
                ],
                "judge_decision": final_state["investment_debate_state"][
                    "judge_decision"
                ],
            },
            "trader_investment_decision": final_state["trader_investment_plan"],
            "risk_debate_state": {
                "risky_history": final_state["risk_debate_state"]["risky_history"],
                "safe_history": final_state["risk_debate_state"]["safe_history"],
                "neutral_history": final_state["risk_debate_state"]["neutral_history"],
                "history": final_state["risk_debate_state"]["history"],
                "judge_decision": final_state["risk_debate_state"]["judge_decision"],
            },
            "investment_plan": final_state["investment_plan"],
            "final_trade_decision": final_state["final_trade_decision"],
        }
        directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
        directory.mkdir(parents=True, exist_ok=True)
        log_path = directory / f"full_states_log_{trade_date}.json"
        with open(log_path, "w") as f:
            json.dump(self.log_states_dict, f, indent=4)
        if self.db_enabled:
            self._persist_to_database(trade_date, final_state)

    def _persist_to_database(self, trade_date, final_state: dict[str, Any]) -> None:
        with get_db_session() as session:
            analysis_service = AnalysisService(session)
            trading_service = TradingService(session)
            analysis_session = analysis_service.save_full_state(
                self.ticker, str(trade_date), final_state
            )
            signal = self.process_signal(final_state.get("final_trade_decision", ""))
            trading_service.save_trading_decision(
                analysis_session.id, self.ticker, final_state, signal
            )

    def reflect_and_remember(self, returns_losses) -> None:
        """Feed realized returns/losses back into each agent's memory."""
        self.reflector.reflect_bull_researcher(
            self.curr_state, returns_losses, self.bull_memory
        )
        self.reflector.reflect_bear_researcher(
            self.curr_state, returns_losses, self.bear_memory
        )
        self.reflector.reflect_trader(
            self.curr_state, returns_losses, self.trader_memory
        )
        self.reflector.reflect_invest_judge(
            self.curr_state, returns_losses, self.invest_judge_memory
        )
        self.reflector.reflect_risk_manager(
            self.curr_state, returns_losses, self.risk_manager_memory
        )

    def process_signal(self, full_signal: str) -> str:
        """Extract a concise trade signal from a full decision text."""
        return self.signal_processor.process_signal(full_signal)

    def discover_trending(
        self,
        request: DiscoveryRequest | None = None,
    ) -> DiscoveryResult:
        """Scan recent news for trending tickers, under a hard wall-clock timeout."""
        if request is None:
            request = DiscoveryRequest(
                lookback_period="24h",
                max_results=self.config.get("discovery_max_results", 20),
            )
        started_at = datetime.now()
        result = DiscoveryResult(
            request=request,
            trending_stocks=[],
            status=DiscoveryStatus.PROCESSING,
            started_at=started_at,
        )
        hard_timeout = self.config.get("discovery_hard_timeout", 120)
        discovery_result = {"stocks": [], "error": None}

        def run_discovery():
            try:
                articles = get_bulk_news(request.lookback_period)
                mentions = extract_entities(articles, self.config)
                min_mentions = self.config.get("discovery_min_mentions", 2)
                if len(articles) < 10:
                    # With few articles, a two-mention floor would filter out
                    # nearly everything, so relax it.
                    min_mentions = 1
                max_results = request.max_results or self.config.get(
                    "discovery_max_results", 20
                )
                trending_stocks = calculate_trending_scores(
                    mentions,
                    articles,
                    max_results=max_results,
                    min_mentions=min_mentions,
                )
                discovery_result["stocks"] = trending_stocks
            except (
                ValueError,
                KeyError,
                RuntimeError,
                ConnectionError,
                TimeoutError,
            ) as e:
                discovery_result["error"] = str(e)

        # Run discovery in a daemon thread so a hung network call cannot block
        # the caller past hard_timeout; on timeout the worker is abandoned.
        discovery_thread = threading.Thread(target=run_discovery, daemon=True)
        discovery_thread.start()
        discovery_thread.join(timeout=hard_timeout)
        if discovery_thread.is_alive():
            raise DiscoveryTimeoutError(
                f"Discovery operation exceeded {hard_timeout} second timeout"
            )
        if discovery_result["error"]:
            result.status = DiscoveryStatus.FAILED
            result.error_message = discovery_result["error"]
            result.completed_at = datetime.now()
            return result
        trending_stocks = discovery_result["stocks"]
        # Filters accept enum members or raw string values interchangeably.
        if request.sector_filter:
            sector_values = {
                s.value if isinstance(s, Sector) else s for s in request.sector_filter
            }
            trending_stocks = [
                stock
                for stock in trending_stocks
                if stock.sector.value in sector_values
                or stock.sector in request.sector_filter
            ]
        if request.event_filter:
            event_values = {
                e.value if isinstance(e, EventCategory) else e
                for e in request.event_filter
            }
            trending_stocks = [
                stock
                for stock in trending_stocks
                if stock.event_type.value in event_values
                or stock.event_type in request.event_filter
            ]
        result.trending_stocks = trending_stocks
        result.status = DiscoveryStatus.COMPLETED
        result.completed_at = datetime.now()
        if self.db_enabled:
            self._persist_discovery_to_database(request, result)
        return result

    def _persist_discovery_to_database(
        self, request: DiscoveryRequest, result: DiscoveryResult
    ) -> None:
        with get_db_session() as session:
            discovery_service = DiscoveryService(session)
            run = discovery_service.create_run(
                request.lookback_period, request.max_results or 20
            )
            for stock in result.trending_stocks:
                discovery_service.save_trending_stock(
                    run.id,
                    stock.ticker,
                    stock.company_name,
                    stock.trending_score,
                    stock.mention_count,
                    stock.sector.value
                    if hasattr(stock.sector, "value")
                    else str(stock.sector),
                    stock.event_type.value
                    if hasattr(stock.event_type, "value")
                    else str(stock.event_type),
                    stock.summary,
                    [a.get("title", "") for a in stock.source_articles]
                    if stock.source_articles
                    else None,
                )
            if result.status == DiscoveryStatus.COMPLETED:
                discovery_service.complete_run(run.id, len(result.trending_stocks))
            else:
                discovery_service.fail_run(
                    run.id, result.error_message or "Unknown error"
                )

    def analyze_trending(
        self,
        trending_stock: TrendingStock,
        trade_date: date | None = None,
    ) -> tuple[dict[str, Any], str]:
        """Run the full analysis pipeline on a discovered trending stock,
        defaulting to today's date."""
        ticker = trending_stock.ticker
        if trade_date is None:
            trade_date = date.today()
        return self.propagate(ticker, trade_date)
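

if __name__ == "__main__":
    # Minimal usage sketch, assuming API keys for the configured LLM provider
    # and news backends are available in the environment; the ticker and date
    # below are illustrative only.
    graph = TradingAgentsGraph(debug=True)

    # Per-ticker analysis: run the full agent graph and extract a trade signal.
    final_state, signal = graph.propagate("AAPL", "2024-05-01")
    print("Signal:", signal)

    # Discovery flow: scan recent news for trending tickers, then analyze the
    # top-scoring one with the same pipeline.
    discovery = graph.discover_trending()
    if discovery.trending_stocks:
        _, trending_signal = graph.analyze_trending(discovery.trending_stocks[0])
        print("Trending signal:", trending_signal)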