103 lines
3.8 KiB
Python
103 lines
3.8 KiB
Python
from __future__ import annotations
|
|
|
|
import time
|
|
from dataclasses import dataclass
|
|
from datetime import date, timedelta
|
|
from typing import Callable, Dict, Iterable, Optional
|
|
|
|
from tradingagents.dataflows.interface import route_to_vendor
|
|
from tradingagents.services.autopilot_worker import AutopilotWorker
|
|
from tradingagents.services.hypothesis_store import HypothesisRecord, HypothesisStore
|
|
|
|
|
|
@dataclass
class PriceThreshold:
    """One price condition to watch: ``<symbol> price <operator> <value>``.

    ``AutopilotBroker.parse_triggers`` builds these from a hypothesis record's
    textual triggers, and ``AutopilotBroker._evaluate`` checks them against the
    latest fetched price.
    """

    # Ticker the condition applies to (parse_triggers upper-cases it).
    symbol: str
    # Comparison operator as text; the broker only evaluates ">=" and "<=".
    operator: str
    # Threshold that the latest price is compared against.
    value: float
|
|
|
|
|
|
class AutopilotBroker:
    """Simple polling broker that watches price thresholds and enqueues events.

    Each call to ``poll_once`` reads every hypothesis record from the store,
    turns its textual triggers into ``PriceThreshold`` conditions, fetches the
    latest price per symbol, and enqueues a ``"price_threshold"`` event on the
    worker for every condition that currently holds.
    """

    # Keyword every price trigger must start with (matched as a whole word).
    _PRICE_KEYWORD = "price"
    # Comparison operators understood by the parser and by _evaluate.
    _OPERATORS = (">=", "<=")

    def __init__(
        self,
        store: HypothesisStore,
        worker: AutopilotWorker,
        price_fetcher: Optional[Callable[[str], Optional[float]]] = None,
    ) -> None:
        """Create a broker.

        Args:
            store: Source of hypothesis records; only ``store.list()`` is used.
            worker: Receives events through ``worker.enqueue_event(...)``.
            price_fetcher: Maps a symbol to its latest price, or ``None`` when no
                price is available. Defaults to ``default_price_fetcher``.
        """
        self.store = store
        self.worker = worker
        self.price_fetcher = price_fetcher or default_price_fetcher
        # Seconds between polls. No method of this class reads it; presumably
        # used by whatever drives the polling loop -- confirm with callers.
        self.poll_interval = 60

    def parse_triggers(self, record: HypothesisRecord) -> Iterable[PriceThreshold]:
        """Yield a ``PriceThreshold`` for each recognised price trigger on *record*.

        Triggers are matched case-insensitively and may take either form::

            price >= 150          # watches record.ticker
            price msft <= 99.5    # watches an explicitly named symbol

        Triggers that contain neither or both operators, or that otherwise fail
        to parse, are skipped silently.
        """
        for trigger in record.triggers:
            trigger_str = str(trigger).strip().lower()
            # Exactly one supported operator must appear; anything else is not
            # a simple threshold trigger.
            operators = [op for op in self._OPERATORS if op in trigger_str]
            if len(operators) != 1:
                continue
            operator = operators[0]
            try:
                symbol, value = self._parse_simple_trigger(record.ticker, trigger_str, operator)
            except ValueError:
                continue
            yield PriceThreshold(symbol=symbol, operator=operator, value=value)

    def poll_once(self) -> Dict[str, str]:
        """Evaluate every price trigger once and enqueue events for those that hold.

        The price for each symbol is fetched at most once per call, so all
        triggers on the same symbol are judged against the same quote. Symbols
        for which the fetcher returns ``None`` are skipped.

        No de-duplication happens here: a trigger whose condition keeps holding
        enqueues a new event on every call. NOTE(review): confirm the worker or
        caller tolerates repeated alerts.

        Returns:
            Mapping of enqueued event id to a short human-readable description.
        """
        outcomes: Dict[str, str] = {}
        # Per-poll quote cache; None results are cached too so a failing symbol
        # is not re-queried within the same poll.
        prices: Dict[str, Optional[float]] = {}
        for record in self.store.list():
            for trigger in self.parse_triggers(record):
                if trigger.symbol not in prices:
                    prices[trigger.symbol] = self.price_fetcher(trigger.symbol)
                latest_price = prices[trigger.symbol]
                if latest_price is None or not self._evaluate(trigger, latest_price):
                    continue
                event = self.worker.enqueue_event(
                    record.id,
                    event_type="price_threshold",
                    payload={"symbol": trigger.symbol, "price": latest_price, "operator": trigger.operator, "value": trigger.value},
                )
                outcomes[event.id] = f"Triggered price alert for {trigger.symbol}"
        return outcomes

    def _evaluate(self, trigger: PriceThreshold, price: float) -> bool:
        """Return whether *price* satisfies *trigger*; unknown operators never match."""
        if trigger.operator == ">=":
            return price >= trigger.value
        if trigger.operator == "<=":
            return price <= trigger.value
        return False

    def _parse_simple_trigger(
        self, default_symbol: str, trigger_str: str, operator: str
    ) -> tuple[str, float]:
        """Split ``price [SYMBOL] <operator> <value>`` into ``(SYMBOL, value)``.

        Args:
            default_symbol: Symbol used when the trigger names none; upper-cased.
            trigger_str: Trigger text, expected to start with the word ``price``.
            operator: Comparison operator to split on.

        Raises:
            ValueError: If the operator is missing or repeated, the text does not
                start with the ``price`` keyword, more than one word sits between
                the keyword and the operator, no symbol can be determined, or the
                value is not a number.
        """
        left, found, right = trigger_str.partition(operator)
        if not found or operator in right:
            raise ValueError("invalid trigger format")
        words = left.split()
        # Compare whole words (not str.replace) so only the leading keyword is
        # consumed and e.g. "pricey >= 5" is rejected.
        if not words or words[0].lower() != self._PRICE_KEYWORD:
            raise ValueError("invalid trigger format")
        named = words[1:]
        if len(named) > 1:
            raise ValueError("invalid trigger format")
        # A missing/empty ticker becomes a ValueError (trigger skipped) instead of
        # an AttributeError that would abort the whole poll.
        symbol = (named[0] if named else (default_symbol or "")).upper()
        if not symbol:
            raise ValueError("trigger has no symbol")
        return symbol, float(right.strip())
|
|
|
|
|
|
def default_price_fetcher(symbol: str, lookback_days: int = 7) -> Optional[float]:
    """Return the latest closing price for *symbol*, or ``None`` if unavailable.

    Requests ``get_stock_data`` for the window ``[today - lookback_days, today]``
    through ``route_to_vendor`` and parses the CSV-like text it returns.
    Presumably one OHLCV row per trading day; confirm against the vendor.
    Whether the end date is inclusive also depends on the vendor.

    Args:
        symbol: Ticker to look up; upper-cased before the request.
        lookback_days: Calendar days to look back (at least 1). The default of 7
            keeps the most recent session in range across weekends and holidays,
            where the previous 2-day window could contain no trading day at all.

    Returns:
        The close of the most recent row whose close parses as a float, or
        ``None`` if the vendor call fails or no such row exists.
    """
    symbol = symbol.upper()
    today = date.today()
    start = today - timedelta(days=max(1, lookback_days))
    try:
        csv_text = route_to_vendor("get_stock_data", symbol, start.isoformat(), today.isoformat())
    except Exception:
        # Deliberate best-effort: any vendor failure means "no price this poll".
        return None

    # Drop blank / whitespace-only lines and '#' comment lines. The previous
    # `if line` check let whitespace-only lines through, and a trailing one made
    # the whole fetch return None.
    rows = [
        line.strip()
        for line in str(csv_text).splitlines()
        if line.strip() and not line.lstrip().startswith("#")
    ]
    if not rows:
        return None

    # Locate the Close column from the header when there is one; otherwise fall
    # back to the fixed position (index 4) the original layout assumed.
    close_idx = 4
    header = [cell.strip().lower() for cell in rows[0].split(",")]
    if "close" in header:
        close_idx = header.index("close")
        rows = rows[1:]

    # Walk backwards so a trailing row with a missing/non-numeric close (e.g. an
    # incomplete session) does not hide an earlier valid price.
    for row in reversed(rows):
        cells = row.split(",")
        if len(cells) <= close_idx:
            continue
        try:
            return float(cells[close_idx])
        except ValueError:
            continue
    return None
|