TradingAgents/tradingagents/services/account.py

176 lines
6.0 KiB
Python

from __future__ import annotations
import logging
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional
from tradingagents.integrations.alpaca_mcp import AlpacaMCPClient, AlpacaMCPConfig, AlpacaMCPError
@dataclass
class AccountSnapshot:
    """Point-in-time capture of the Alpaca account.

    Keeps the raw text returned for the account, positions and orders
    queries next to the dictionaries parsed out of that text.
    """

    # Moment the snapshot was assembled.
    fetched_at: datetime
    # Unparsed tool output, kept verbatim.
    account_text: str
    positions_text: str
    orders_text: str
    # Parsed views of the text above.
    account: Dict[str, Any]
    positions: List[Dict[str, Any]]
    orders: List[Dict[str, Any]]

    def buying_power(self) -> float:
        """Return the buying power, read from ``buying_power`` or ``buying_power_usd``."""
        details = self.account
        return _as_float(details.get("buying_power") or details.get("buying_power_usd"))

    def cash(self) -> float:
        """Return the cash balance, read from ``cash`` or ``cash_usd``."""
        details = self.account
        return _as_float(details.get("cash") or details.get("cash_usd"))

    def portfolio_value(self) -> float:
        """Return the portfolio value, falling back to the equity fields."""
        details = self.account
        raw = details.get("portfolio_value")
        if not raw:
            raw = details.get("equity")
        if not raw:
            raw = details.get("equity_value")
        return _as_float(raw)

    def position_symbols(self) -> List[str]:
        """Return the upper-cased symbols of positions whose quantity is non-zero."""
        held: List[str] = []
        for entry in self.positions:
            raw_symbol = entry.get("symbol") or entry.get("symbol:") or ""
            raw_quantity = entry.get("quantity") or entry.get("qty") or 0
            ticker = str(raw_symbol).upper()
            amount = _as_float(raw_quantity)
            # Entries without a symbol or with a zero quantity are not holdings.
            if not ticker or amount == 0:
                continue
            held.append(ticker)
        return held
class AccountService:
    """Fetch and cache Alpaca MCP account information.

    Wraps an ``AlpacaMCPClient`` and remembers the most recent
    ``AccountSnapshot`` produced by :meth:`refresh`.
    """

    def __init__(self, alpaca_config: Dict[str, Any], logger: Optional[logging.Logger] = None) -> None:
        """Build the MCP client from ``alpaca_config``.

        Args:
            alpaca_config: Settings handed to ``AlpacaMCPConfig.from_dict``;
                a falsy value is treated as an empty mapping.
            logger: Logger used by this service; defaults to the module
                logger. The client receives the argument exactly as given.
        """
        config = AlpacaMCPConfig.from_dict(alpaca_config or {})
        self.client = AlpacaMCPClient(config, logger=logger)
        self.logger = logger or logging.getLogger(__name__)
        self._snapshot: Optional["AccountSnapshot"] = None
        # A client config without an ``enabled`` attribute counts as disabled.
        self.enabled = bool(getattr(self.client.config, "enabled", False))
        if not self.enabled:
            self.logger.info("Alpaca MCP integration disabled; account snapshot will be unavailable.")

    def refresh(self) -> "AccountSnapshot":
        """Fetch the latest account snapshot from the Alpaca MCP server.

        Account info, positions and orders are requested over a single
        client session, parsed, stored on the service and returned.

        Returns:
            The freshly built snapshot (also available via :attr:`snapshot`).

        Raises:
            RuntimeError: If the integration is disabled, or if fetching
                the data fails for any reason (the cause is chained).
        """
        if not self.enabled:
            raise RuntimeError(
                "Alpaca MCP integration is disabled. Set ALPACA_MCP_ENABLED=true (and related connection settings) to use the auto-trade workflow."
            )

        import asyncio
        from datetime import timezone

        async def _fetch_all() -> Dict[str, str]:
            # One session for all three calls.
            async with self.client._acquire_session() as session:  # type: ignore[attr-defined]
                account_text = await self.client._call_tool_async("get_account_info", {}, session=session)
                positions_text = await self.client._call_tool_async(
                    "get_positions", {}, session=session, validate=False
                )
                orders_text = await self.client._call_tool_async(
                    "get_orders", {"status": "all", "limit": 50}, session=session, validate=False
                )
            return {
                "account": account_text,
                "positions": positions_text,
                "orders": orders_text,
            }

        try:
            texts = asyncio.run(_fetch_all())
        except (AlpacaMCPError, Exception) as exc:
            # Single handler: the former two branches were identical.
            # AlpacaMCPError stays listed so it is wrapped regardless of
            # which base class it derives from.
            raise RuntimeError(f"Failed to retrieve Alpaca account snapshot: {exc}") from exc

        snapshot = AccountSnapshot(
            # Naive UTC, same value datetime.utcnow() produced, without
            # calling the API deprecated in Python 3.12.
            fetched_at=datetime.now(timezone.utc).replace(tzinfo=None),
            account_text=texts["account"],
            positions_text=texts["positions"],
            orders_text=texts["orders"],
            account=_parse_key_values(texts["account"]),
            positions=_parse_position_blocks(texts["positions"]),
            orders=_parse_order_blocks(texts["orders"]),
        )
        self._snapshot = snapshot
        return snapshot

    @property
    def snapshot(self) -> Optional["AccountSnapshot"]:
        """Return the last snapshot built by :meth:`refresh`, or ``None``."""
        return self._snapshot
def _parse_key_values(text: str) -> Dict[str, Any]:
data: Dict[str, Any] = {}
pattern = re.compile(r"^([A-Za-z0-9 _/-]+):\s*(.+)$")
for line in text.splitlines():
line = line.strip()
if not line or line.endswith(":"):
continue
match = pattern.match(line)
if not match:
continue
key = match.group(1).strip().lower().replace(" ", "_")
value = match.group(2).strip()
data[key] = value
return data
def _parse_position_blocks(text: str) -> List[Dict[str, Any]]:
if not text or "No open positions" in text:
return []
blocks = []
current: Dict[str, Any] = {}
for raw_line in text.splitlines():
line = raw_line.strip()
if not line:
continue
if line.startswith("Symbol:") and current:
blocks.append(current)
current = {}
if ":" in line:
key, value = line.split(":", 1)
current[key.strip().lower().replace(" ", "_")] = value.strip()
if current:
blocks.append(current)
return blocks
def _parse_order_blocks(text: str) -> List[Dict[str, Any]]:
if not text or "No all orders" in text or "No orders" in text:
return []
blocks = []
current: Dict[str, Any] = {}
for raw_line in text.splitlines():
line = raw_line.strip()
if not line:
continue
if line.startswith("Order ID:") and current:
blocks.append(current)
current = {}
if ":" in line:
key, value = line.split(":", 1)
current[key.strip().lower().replace(" ", "_")] = value.strip()
if current:
blocks.append(current)
return blocks
def _as_float(value: Any) -> float:
if value is None:
return 0.0
if isinstance(value, (int, float)):
return float(value)
text = str(value).strip()
if not text:
return 0.0
cleaned = text.replace("$", "").replace(",", "")
try:
return float(cleaned)
except ValueError:
return 0.0