236 lines
9.3 KiB
Python
236 lines
9.3 KiB
Python
"""Portfolio Manager workflow graph setup.
|
|
|
|
Sequential workflow:
|
|
START → load_portfolio → compute_risk → review_holdings
|
|
→ prioritize_candidates → pm_decision → execute_trades → END
|
|
|
|
Non-LLM nodes (load_portfolio, compute_risk, prioritize_candidates,
|
|
execute_trades) receive ``repo`` and ``config`` via closure.
|
|
LLM nodes (review_holdings, pm_decision) are created externally and passed in.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
from langgraph.graph import END, START, StateGraph
|
|
|
|
from tradingagents.portfolio.candidate_prioritizer import prioritize_candidates
|
|
from tradingagents.portfolio.portfolio_states import PortfolioManagerState
|
|
from tradingagents.portfolio.risk_evaluator import compute_portfolio_risk
|
|
from tradingagents.portfolio.trade_executor import TradeExecutor
|
|
|
|
logger = logging.getLogger(__name__)

# Fallback payload handed to ``Portfolio.from_dict`` when the serialized
# ``portfolio_data`` has no usable "portfolio" entry (missing or empty).
# Shared module-level object: treat it as read-only.
_EMPTY_PORTFOLIO_DICT = dict(
    portfolio_id="",
    name="",
    cash=0.0,
    initial_cash=0.0,
)
|
|
|
|
|
|
class PortfolioGraphSetup:
|
|
"""Builds the sequential Portfolio Manager workflow graph.
|
|
|
|
Args:
|
|
agents: Dict with keys ``review_holdings`` and ``pm_decision``
|
|
mapping to LLM agent node functions.
|
|
repo: PortfolioRepository instance (injected into closure nodes).
|
|
config: Portfolio config dict.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
agents: dict[str, Any],
|
|
repo=None,
|
|
config: dict[str, Any] | None = None,
|
|
) -> None:
|
|
self.agents = agents
|
|
self._repo = repo
|
|
self._config = config or {}
|
|
|
|
# ------------------------------------------------------------------
|
|
# Node factories (non-LLM)
|
|
# ------------------------------------------------------------------
|
|
|
|
def _make_load_portfolio_node(self):
|
|
repo = self._repo
|
|
config = self._config
|
|
|
|
def load_portfolio_node(state):
|
|
portfolio_id = state["portfolio_id"]
|
|
prices = state.get("prices") or {}
|
|
try:
|
|
if repo is None:
|
|
from tradingagents.portfolio.repository import PortfolioRepository
|
|
_repo = PortfolioRepository(config=config)
|
|
else:
|
|
_repo = repo
|
|
portfolio, holdings = _repo.get_portfolio_with_holdings(
|
|
portfolio_id, prices
|
|
)
|
|
data = {
|
|
"portfolio": portfolio.to_dict(),
|
|
"holdings": [h.to_dict() for h in holdings],
|
|
}
|
|
except Exception as exc:
|
|
logger.error("load_portfolio_node: %s", exc)
|
|
data = {"portfolio": {}, "holdings": [], "error": str(exc)}
|
|
return {
|
|
"portfolio_data": json.dumps(data),
|
|
"sender": "load_portfolio",
|
|
}
|
|
|
|
return load_portfolio_node
|
|
|
|
def _make_compute_risk_node(self):
|
|
def compute_risk_node(state):
|
|
portfolio_data_str = state.get("portfolio_data") or "{}"
|
|
prices = state.get("prices") or {}
|
|
try:
|
|
portfolio_data = json.loads(portfolio_data_str)
|
|
from tradingagents.portfolio.models import Holding, Portfolio
|
|
|
|
portfolio = Portfolio.from_dict(portfolio_data.get("portfolio") or _EMPTY_PORTFOLIO_DICT)
|
|
holdings = [
|
|
Holding.from_dict(h) for h in (portfolio_data.get("holdings") or [])
|
|
]
|
|
|
|
# Enrich holdings with prices so current_value is populated
|
|
if prices and portfolio.total_value is None:
|
|
equity = sum(prices.get(h.ticker, 0.0) * h.shares for h in holdings)
|
|
total_value = portfolio.cash + equity
|
|
for h in holdings:
|
|
if h.ticker in prices:
|
|
h.enrich(prices[h.ticker], total_value)
|
|
portfolio.enrich(holdings)
|
|
|
|
# Build simple price histories from single-point prices
|
|
# (real usage would pass historical prices via scan_summary or state)
|
|
price_histories: dict[str, list[float]] = {}
|
|
scan_summary = state.get("scan_summary") or {}
|
|
for h in holdings:
|
|
history = scan_summary.get("price_histories", {}).get(h.ticker)
|
|
if history:
|
|
price_histories[h.ticker] = history
|
|
elif h.ticker in prices:
|
|
# Single-point price — returns will be empty, metrics None
|
|
price_histories[h.ticker] = [prices[h.ticker]]
|
|
|
|
metrics = compute_portfolio_risk(portfolio, holdings, price_histories)
|
|
except Exception as exc:
|
|
logger.error("compute_risk_node: %s", exc)
|
|
metrics = {"error": str(exc)}
|
|
return {
|
|
"risk_metrics": json.dumps(metrics),
|
|
"sender": "compute_risk",
|
|
}
|
|
|
|
return compute_risk_node
|
|
|
|
def _make_prioritize_candidates_node(self):
|
|
config = self._config
|
|
|
|
def prioritize_candidates_node(state):
|
|
portfolio_data_str = state.get("portfolio_data") or "{}"
|
|
scan_summary = state.get("scan_summary") or {}
|
|
try:
|
|
portfolio_data = json.loads(portfolio_data_str)
|
|
from tradingagents.portfolio.models import Holding, Portfolio
|
|
|
|
portfolio = Portfolio.from_dict(portfolio_data.get("portfolio") or _EMPTY_PORTFOLIO_DICT)
|
|
holdings = [
|
|
Holding.from_dict(h) for h in (portfolio_data.get("holdings") or [])
|
|
]
|
|
candidates = scan_summary.get("stocks_to_investigate") or []
|
|
prices = state.get("prices") or {}
|
|
if prices:
|
|
equity = sum(prices.get(h.ticker, 0.0) * h.shares for h in holdings)
|
|
total_value = portfolio.cash + equity
|
|
for h in holdings:
|
|
if h.ticker in prices:
|
|
h.enrich(prices[h.ticker], total_value)
|
|
portfolio.enrich(holdings)
|
|
|
|
ranked = prioritize_candidates(candidates, portfolio, holdings, config)
|
|
except Exception as exc:
|
|
logger.error("prioritize_candidates_node: %s", exc)
|
|
ranked = []
|
|
return {
|
|
"prioritized_candidates": json.dumps(ranked),
|
|
"sender": "prioritize_candidates",
|
|
}
|
|
|
|
return prioritize_candidates_node
|
|
|
|
def _make_execute_trades_node(self):
|
|
repo = self._repo
|
|
config = self._config
|
|
|
|
def execute_trades_node(state):
|
|
portfolio_id = state["portfolio_id"]
|
|
analysis_date = state.get("analysis_date") or ""
|
|
prices = state.get("prices") or {}
|
|
pm_decision_str = state.get("pm_decision") or "{}"
|
|
try:
|
|
decisions = json.loads(pm_decision_str)
|
|
except (json.JSONDecodeError, TypeError):
|
|
decisions = {}
|
|
|
|
try:
|
|
if repo is None:
|
|
from tradingagents.portfolio.repository import PortfolioRepository
|
|
_repo = PortfolioRepository(config=config)
|
|
else:
|
|
_repo = repo
|
|
executor = TradeExecutor(repo=_repo, config=config)
|
|
result = executor.execute_decisions(
|
|
portfolio_id, decisions, prices, date=analysis_date
|
|
)
|
|
except Exception as exc:
|
|
logger.error("execute_trades_node: %s", exc)
|
|
result = {"error": str(exc), "executed_trades": [], "failed_trades": []}
|
|
return {
|
|
"execution_result": json.dumps(result),
|
|
"sender": "execute_trades",
|
|
}
|
|
|
|
return execute_trades_node
|
|
|
|
# ------------------------------------------------------------------
|
|
# Graph assembly
|
|
# ------------------------------------------------------------------
|
|
|
|
def setup_graph(self):
|
|
"""Build and compile the sequential portfolio workflow graph.
|
|
|
|
Returns:
|
|
A compiled LangGraph graph ready to invoke.
|
|
"""
|
|
workflow = StateGraph(PortfolioManagerState)
|
|
|
|
# Register non-LLM nodes
|
|
workflow.add_node("load_portfolio", self._make_load_portfolio_node())
|
|
workflow.add_node("compute_risk", self._make_compute_risk_node())
|
|
workflow.add_node("prioritize_candidates", self._make_prioritize_candidates_node())
|
|
workflow.add_node("execute_trades", self._make_execute_trades_node())
|
|
|
|
# Register LLM nodes
|
|
workflow.add_node("review_holdings", self.agents["review_holdings"])
|
|
workflow.add_node("pm_decision", self.agents["pm_decision"])
|
|
|
|
# Sequential edges
|
|
workflow.add_edge(START, "load_portfolio")
|
|
workflow.add_edge("load_portfolio", "compute_risk")
|
|
workflow.add_edge("compute_risk", "review_holdings")
|
|
workflow.add_edge("review_holdings", "prioritize_candidates")
|
|
workflow.add_edge("prioritize_candidates", "pm_decision")
|
|
workflow.add_edge("pm_decision", "execute_trades")
|
|
workflow.add_edge("execute_trades", END)
|
|
|
|
return workflow.compile()
|