TradingAgents/tradingagents/services/autopilot_broker.py

103 lines
3.8 KiB
Python

from __future__ import annotations
import time
from dataclasses import dataclass
from datetime import date, timedelta
from typing import Callable, Dict, Iterable, Optional
from tradingagents.dataflows.interface import route_to_vendor
from tradingagents.services.autopilot_worker import AutopilotWorker
from tradingagents.services.hypothesis_store import HypothesisRecord, HypothesisStore
@dataclass
class PriceThreshold:
symbol: str
operator: str
value: float
class AutopilotBroker:
"""Simple polling broker that watches price thresholds and enqueues events."""
def __init__(
self,
store: HypothesisStore,
worker: AutopilotWorker,
price_fetcher: Optional[Callable[[str], Optional[float]]] = None,
) -> None:
self.store = store
self.worker = worker
self.price_fetcher = price_fetcher or default_price_fetcher
self.poll_interval = 60 # seconds
def parse_triggers(self, record: HypothesisRecord) -> Iterable[PriceThreshold]:
for trigger in record.triggers:
trigger_str = str(trigger).strip().lower()
if trigger_str.startswith("price >="):
try:
symbol, value = self._parse_simple_trigger(record.ticker, trigger_str, ">=")
yield PriceThreshold(symbol=symbol, operator=">=", value=value)
except ValueError:
continue
elif trigger_str.startswith("price <="):
try:
symbol, value = self._parse_simple_trigger(record.ticker, trigger_str, "<=")
yield PriceThreshold(symbol=symbol, operator="<=", value=value)
except ValueError:
continue
def poll_once(self) -> Dict[str, str]:
outcomes: Dict[str, str] = {}
records = self.store.list()
for record in records:
for trigger in self.parse_triggers(record):
latest_price = self.price_fetcher(trigger.symbol)
if latest_price is None:
continue
if self._evaluate(trigger, latest_price):
event = self.worker.enqueue_event(
record.id,
event_type="price_threshold",
payload={"symbol": trigger.symbol, "price": latest_price, "operator": trigger.operator, "value": trigger.value},
)
outcomes[event.id] = f"Triggered price alert for {trigger.symbol}"
return outcomes
def _evaluate(self, trigger: PriceThreshold, price: float) -> bool:
if trigger.operator == ">=":
return price >= trigger.value
if trigger.operator == "<=":
return price <= trigger.value
return False
def _parse_simple_trigger(self, default_symbol: str, trigger_str: str, operator: str) -> (str, float):
parts = trigger_str.replace("price", "").strip().split(operator)
if len(parts) != 2:
raise ValueError("invalid trigger format")
left = parts[0].strip().upper()
symbol = left if left else default_symbol.upper()
value = float(parts[1].strip())
return symbol, value
def default_price_fetcher(symbol: str) -> Optional[float]:
symbol = symbol.upper()
today = date.today()
start = today - timedelta(days=2)
try:
csv_text = route_to_vendor("get_stock_data", symbol, start.isoformat(), today.isoformat())
except Exception:
return None
lines = [line for line in str(csv_text).splitlines() if line and not line.startswith("#")]
if not lines:
return None
last_line = lines[-1]
parts = last_line.split(",")
if len(parts) < 5:
return None
try:
return float(parts[4]) # close price
except ValueError:
return None