176 lines
6.0 KiB
Python
176 lines
6.0 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from tradingagents.integrations.alpaca_mcp import AlpacaMCPClient, AlpacaMCPConfig, AlpacaMCPError
|
|
|
|
|
|
@dataclass
class AccountSnapshot:
    """Point-in-time view of an Alpaca account as reported by the MCP server.

    Holds the raw tool replies (the ``*_text`` fields) alongside the
    dictionaries parsed from them, and exposes float accessors for the
    numeric account fields.
    """

    # Moment the data was fetched (AccountService.refresh stores naive UTC here).
    fetched_at: datetime
    # Raw replies from the account / positions / orders tools.
    account_text: str
    positions_text: str
    orders_text: str
    # Parsed "key: value" pairs; keys lower-cased with spaces turned into "_".
    account: Dict[str, Any]
    # One dict per parsed position / order record.
    positions: List[Dict[str, Any]]
    orders: List[Dict[str, Any]]

    def buying_power(self) -> float:
        """Return buying power as a float (``buying_power``, else ``buying_power_usd``).

        Missing or unparseable values yield ``0.0``.
        """
        acct = self.account
        raw = acct.get("buying_power") or acct.get("buying_power_usd")
        return _as_float(raw)

    def cash(self) -> float:
        """Return cash as a float (``cash``, else ``cash_usd``); ``0.0`` if absent."""
        acct = self.account
        raw = acct.get("cash") or acct.get("cash_usd")
        return _as_float(raw)

    def portfolio_value(self) -> float:
        """Return total value, preferring ``portfolio_value``, then ``equity``, then ``equity_value``.

        Falsy values (missing or empty) fall through to the next key;
        ``0.0`` is returned when none is usable.
        """
        acct = self.account
        raw = acct.get("portfolio_value") or acct.get("equity") or acct.get("equity_value")
        return _as_float(raw)

    def position_symbols(self) -> List[str]:
        """Return upper-cased symbols of positions whose quantity is non-zero.

        Quantity is read from ``quantity`` or ``qty``; entries with an empty
        symbol or a quantity that parses to 0 are skipped. Order follows
        ``self.positions``.
        """
        # Pair each position's normalised symbol with its numeric quantity.
        candidates = (
            (
                str(entry.get("symbol") or entry.get("symbol:") or "").upper(),
                _as_float(entry.get("quantity") or entry.get("qty") or 0),
            )
            for entry in self.positions
        )
        return [ticker for ticker, amount in candidates if ticker and amount != 0]
|
|
|
|
|
|
class AccountService:
    """Fetch and cache Alpaca MCP account information.

    ``refresh`` pulls account info, positions and orders from the Alpaca MCP
    server and stores the resulting ``AccountSnapshot``; the ``snapshot``
    property returns the most recently stored one (``None`` before the first
    successful refresh).
    """

    def __init__(self, alpaca_config: Dict[str, Any], logger: Optional[logging.Logger] = None) -> None:
        """Build the MCP client from ``alpaca_config``.

        Args:
            alpaca_config: Settings passed to ``AlpacaMCPConfig.from_dict``;
                ``None`` or an empty mapping is treated as ``{}``.
            logger: Optional logger, forwarded to the client; this service
                falls back to the module logger when it is ``None``.
        """
        config = AlpacaMCPConfig.from_dict(alpaca_config or {})
        self.client = AlpacaMCPClient(config, logger=logger)
        self.logger = logger or logging.getLogger(__name__)
        self._snapshot: Optional[AccountSnapshot] = None
        # A config without an "enabled" attribute is treated as disabled.
        self.enabled = bool(getattr(self.client.config, "enabled", False))
        if not self.enabled:
            self.logger.info("Alpaca MCP integration disabled; account snapshot will be unavailable.")

    def refresh(self) -> AccountSnapshot:
        """Fetch the latest account snapshot from the Alpaca MCP server and cache it.

        Calls the ``get_account_info``, ``get_positions`` and ``get_orders``
        (status "all", limit 50) tools over one MCP session, parses the text
        replies, and stores the result as the current snapshot.

        This is a synchronous wrapper around ``asyncio.run``, so it cannot be
        called from code that is already running inside an event loop.

        Returns:
            The newly fetched ``AccountSnapshot``.

        Raises:
            RuntimeError: if the integration is disabled, or if fetching fails
                (the underlying exception is chained as ``__cause__``).
        """
        if not self.enabled:
            raise RuntimeError(
                "Alpaca MCP integration is disabled. Set ALPACA_MCP_ENABLED=true (and related connection settings) to use the auto-trade workflow."
            )

        import asyncio
        from datetime import timezone

        async def _fetch_all() -> Dict[str, str]:
            # NOTE(review): relies on private client methods (_acquire_session,
            # _call_tool_async) to reuse a single session for all three calls.
            async with self.client._acquire_session() as session:  # type: ignore[attr-defined]
                account_text = await self.client._call_tool_async("get_account_info", {}, session=session)
                positions_text = await self.client._call_tool_async("get_positions", {}, session=session, validate=False)
                orders_text = await self.client._call_tool_async("get_orders", {"status": "all", "limit": 50}, session=session, validate=False)
                return {
                    "account": account_text,
                    "positions": positions_text,
                    "orders": orders_text,
                }

        fetch = _fetch_all()
        try:
            texts = asyncio.run(fetch)
        except (AlpacaMCPError, Exception) as exc:
            # If asyncio.run refused to start (e.g. a loop is already running)
            # the coroutine was never awaited; closing it avoids the resulting
            # RuntimeWarning. Closing an already-finished coroutine is a no-op.
            fetch.close()
            raise RuntimeError(f"Failed to retrieve Alpaca account snapshot: {exc}") from exc

        snapshot = AccountSnapshot(
            # Naive UTC timestamp, same value datetime.utcnow() (deprecated
            # since Python 3.12) would produce.
            fetched_at=datetime.now(timezone.utc).replace(tzinfo=None),
            account_text=texts["account"],
            positions_text=texts["positions"],
            orders_text=texts["orders"],
            account=_parse_key_values(texts["account"]),
            positions=_parse_position_blocks(texts["positions"]),
            orders=_parse_order_blocks(texts["orders"]),
        )
        self._snapshot = snapshot
        return snapshot

    @property
    def snapshot(self) -> Optional[AccountSnapshot]:
        """Most recent snapshot stored by ``refresh``, or ``None`` if none yet."""
        return self._snapshot
|
|
|
|
|
|
def _parse_key_values(text: str) -> Dict[str, Any]:
|
|
data: Dict[str, Any] = {}
|
|
pattern = re.compile(r"^([A-Za-z0-9 _/-]+):\s*(.+)$")
|
|
for line in text.splitlines():
|
|
line = line.strip()
|
|
if not line or line.endswith(":"):
|
|
continue
|
|
match = pattern.match(line)
|
|
if not match:
|
|
continue
|
|
key = match.group(1).strip().lower().replace(" ", "_")
|
|
value = match.group(2).strip()
|
|
data[key] = value
|
|
return data
|
|
|
|
|
|
def _parse_position_blocks(text: str) -> List[Dict[str, Any]]:
|
|
if not text or "No open positions" in text:
|
|
return []
|
|
blocks = []
|
|
current: Dict[str, Any] = {}
|
|
for raw_line in text.splitlines():
|
|
line = raw_line.strip()
|
|
if not line:
|
|
continue
|
|
if line.startswith("Symbol:") and current:
|
|
blocks.append(current)
|
|
current = {}
|
|
if ":" in line:
|
|
key, value = line.split(":", 1)
|
|
current[key.strip().lower().replace(" ", "_")] = value.strip()
|
|
if current:
|
|
blocks.append(current)
|
|
return blocks
|
|
|
|
|
|
def _parse_order_blocks(text: str) -> List[Dict[str, Any]]:
|
|
if not text or "No all orders" in text or "No orders" in text:
|
|
return []
|
|
blocks = []
|
|
current: Dict[str, Any] = {}
|
|
for raw_line in text.splitlines():
|
|
line = raw_line.strip()
|
|
if not line:
|
|
continue
|
|
if line.startswith("Order ID:") and current:
|
|
blocks.append(current)
|
|
current = {}
|
|
if ":" in line:
|
|
key, value = line.split(":", 1)
|
|
current[key.strip().lower().replace(" ", "_")] = value.strip()
|
|
if current:
|
|
blocks.append(current)
|
|
return blocks
|
|
|
|
|
|
def _as_float(value: Any) -> float:
|
|
if value is None:
|
|
return 0.0
|
|
if isinstance(value, (int, float)):
|
|
return float(value)
|
|
text = str(value).strip()
|
|
if not text:
|
|
return 0.0
|
|
cleaned = text.replace("$", "").replace(",", "")
|
|
try:
|
|
return float(cleaned)
|
|
except ValueError:
|
|
return 0.0
|