From 2906e5a5706b08546f427123057e1803c3885c37 Mon Sep 17 00:00:00 2001 From: Slava Nikitin Date: Fri, 14 Nov 2025 20:49:09 +0700 Subject: [PATCH] feat: integrate Alpaca MCP orchestrator --- .dockerignore | 11 + .env.alpaca.example | 3 + .env.example | 61 +- .gitignore | 4 + Dockerfile | 23 + README.md | 106 +- docker-compose.yml | 27 + docs/auto_pilot_hypothesis_webhooks.md | 202 ++++ docs/fixtures/autopilot_seed.json | 208 ++++ main.py | 881 +++++++++++++++- pyproject.toml | 2 + requirements.txt | 2 + run_autopilot.sh | 6 + scripts/seed_autopilot_state.py | 155 +++ setup.py | 1 + tradingagents/agents/__init__.py | 2 + .../agents/analysts/fundamentals_analyst.py | 70 +- .../agents/analysts/market_analyst.py | 68 +- tradingagents/agents/analysts/news_analyst.py | 60 +- .../agents/analysts/social_media_analyst.py | 34 +- tradingagents/agents/managers/orchestrator.py | 619 +++++++++++ .../agents/managers/research_manager.py | 9 + tradingagents/agents/managers/risk_manager.py | 10 +- tradingagents/agents/trader/trader.py | 32 +- tradingagents/agents/utils/agent_states.py | 22 +- tradingagents/agents/utils/logging_utils.py | 15 + tradingagents/agents/utils/tool_runner.py | 101 ++ tradingagents/dataflows/interface.py | 42 +- tradingagents/default_config.py | 115 +++ tradingagents/graph/propagation.py | 37 +- tradingagents/graph/scheduler.py | 60 ++ tradingagents/graph/setup.py | 49 +- tradingagents/graph/signal_processing.py | 8 + tradingagents/graph/trading_graph.py | 585 ++++++++++- tradingagents/integrations/__init__.py | 3 + .../integrations/alpaca_mcp/__init__.py | 6 + .../integrations/alpaca_mcp/client.py | 221 ++++ .../integrations/alpaca_mcp/config.py | 51 + tradingagents/integrations/mcp_handshake.py | 190 ++++ .../integrations/sequential_mcp/__init__.py | 6 + .../integrations/sequential_mcp/client.py | 207 ++++ .../integrations/sequential_mcp/config.py | 49 + .../integrations/sequential_mcp/server.py | 179 ++++ tradingagents/services/__init__.py | 24 + tradingagents/services/account.py | 175 ++++ tradingagents/services/auto_trade.py | 628 +++++++++++ tradingagents/services/autopilot_broker.py | 102 ++ tradingagents/services/autopilot_events.py | 79 ++ tradingagents/services/autopilot_worker.py | 148 +++ tradingagents/services/hypothesis_store.py | 226 ++++ tradingagents/services/memory.py | 76 ++ tradingagents/services/realtime_broker.py | 167 +++ .../services/realtime_news_broker.py | 148 +++ .../services/responses_auto_trade.py | 977 ++++++++++++++++++ 54 files changed, 7133 insertions(+), 159 deletions(-) create mode 100644 .dockerignore create mode 100644 .env.alpaca.example create mode 100644 Dockerfile create mode 100644 docker-compose.yml create mode 100644 docs/auto_pilot_hypothesis_webhooks.md create mode 100644 docs/fixtures/autopilot_seed.json create mode 100755 run_autopilot.sh create mode 100644 scripts/seed_autopilot_state.py create mode 100644 tradingagents/agents/managers/orchestrator.py create mode 100644 tradingagents/agents/utils/logging_utils.py create mode 100644 tradingagents/agents/utils/tool_runner.py create mode 100644 tradingagents/graph/scheduler.py create mode 100644 tradingagents/integrations/__init__.py create mode 100644 tradingagents/integrations/alpaca_mcp/__init__.py create mode 100644 tradingagents/integrations/alpaca_mcp/client.py create mode 100644 tradingagents/integrations/alpaca_mcp/config.py create mode 100644 tradingagents/integrations/mcp_handshake.py create mode 100644 tradingagents/integrations/sequential_mcp/__init__.py create mode 100644 tradingagents/integrations/sequential_mcp/client.py create mode 100644 tradingagents/integrations/sequential_mcp/config.py create mode 100644 tradingagents/integrations/sequential_mcp/server.py create mode 100644 tradingagents/services/__init__.py create mode 100644 tradingagents/services/account.py create mode 100644 tradingagents/services/auto_trade.py create mode 100644 tradingagents/services/autopilot_broker.py create mode 100644 tradingagents/services/autopilot_events.py create mode 100644 tradingagents/services/autopilot_worker.py create mode 100644 tradingagents/services/hypothesis_store.py create mode 100644 tradingagents/services/memory.py create mode 100644 tradingagents/services/realtime_broker.py create mode 100644 tradingagents/services/realtime_news_broker.py create mode 100644 tradingagents/services/responses_auto_trade.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..f0313660 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,11 @@ +.git +.gitignore +.env +.env.* +.venv +__pycache__ +*.pyc +results +eval_results +assets +.codex diff --git a/.env.alpaca.example b/.env.alpaca.example new file mode 100644 index 00000000..1913bdbf --- /dev/null +++ b/.env.alpaca.example @@ -0,0 +1,3 @@ +ALPACA_API_KEY=your_alpaca_key +ALPACA_SECRET_KEY=your_alpaca_secret +ALPACA_PAPER_TRADE=True diff --git a/.env.example b/.env.example index 1e257c3c..21a29bf5 100644 --- a/.env.example +++ b/.env.example @@ -1,2 +1,61 @@ ALPHA_VANTAGE_API_KEY=alpha_vantage_api_key_placeholder -OPENAI_API_KEY=openai_api_key_placeholder \ No newline at end of file +OPENAI_API_KEY=openai_api_key_placeholder +PORTFOLIO_PROFILE_NAME=Balanced Multi-Asset +PORTFOLIO_MANDATE=Preserve capital while capturing medium-term growth opportunities in technology and consumer sectors. +PORTFOLIO_MAX_POSITION_PCT=0.15 +PORTFOLIO_MAX_SECTOR_PCT=0.35 +PORTFOLIO_NOTES=Prioritize liquid large-cap names. Avoid exceeding buying power and respect existing hedges. +PORTFOLIO_UNIVERSE=NVDA,AAPL,MSFT,AMZN,TSLA +PORTFOLIO_SENTIMENT_LOOKBACK=2 +PORTFOLIO_NEWS_LIMIT=5 +PORTFOLIO_HYPOTHESIS_THRESHOLD=0.6 +PORTFOLIO_MAX_HYPOTHESES=2 +PORTFOLIO_MARKET_LOOKBACK=30 +PORTFOLIO_TRADE_PRIORITY_THRESHOLD=0.8 +PORTFOLIO_TRADE_MIN_CASH=50000 +PORTFOLIO_TRADE_MIN_CASH_RATIO=0.1 +ALPACA_MCP_ENABLED=false +ALPACA_MCP_TRANSPORT=http +ALPACA_MCP_HOST=127.0.0.1 +ALPACA_MCP_BASE_URL=http://host.docker.internal:8000/mcp +ALPACA_MCP_PORT=8000 +ALPACA_MCP_COMMAND= +ALPACA_MCP_TIMEOUT_SECONDS=30 +TRADE_EXECUTION_ENABLED=false +TRADE_EXECUTION_DRY_RUN=true +TRADE_EXECUTION_DEFAULT_QTY=10 +TRADE_EXECUTION_TIF=day +AUTO_TRADE_MODE=graph +AUTO_TRADE_MAX_TICKERS=12 +AUTO_TRADE_SKIP_WHEN_MARKET_CLOSED=true +AUTO_TRADE_RESPONSES_MODEL=gpt-4.1-mini +# Optional: set AUTO_TRADE_RESPONSES_REASONING=medium (only for models that support reasoning) +AUTO_TRADE_RESPONSES_REASONING= +AUTO_TRADE_RESPONSES_MAX_TURNS=8 +AUTO_TRADE_MEMORY_ENABLED=true +AUTO_TRADE_MEMORY_DIR=./results/memory +AUTO_TRADE_MEMORY_MAX_ENTRIES=5 +TRADINGAGENTS_RESULTS_DIR=./results +APCA_API_KEY_ID= +APCA_API_SECRET_KEY= +ALPACA_DATA_FEED=iex +ALPACA_NEWS_STREAM_URL= +TRADINGAGENTS_AUTOPILOT=false +AUTOPILOT_SEED_AUTO_TRADE=true +AUTOPILOT_LOOP_SECONDS=10 +AUTOPILOT_PRICE_POLL_SECONDS=60 +AUTOPILOT_PREMARKET_MINUTES=30 +# Strategy presets (default values) +TRADINGAGENTS_DEFAULT_STRATEGY=swing +# Day Trade: 6h horizon, +2% target, -1% stop +TRADINGAGENTS_DAYTRADE_HOURS=6 +TRADINGAGENTS_DAYTRADE_TARGET=0.02 +TRADINGAGENTS_DAYTRADE_STOP=0.01 +# Swing Trade: 72h horizon (~3 days), +4% target, -2% stop +TRADINGAGENTS_SWING_HOURS=72 +TRADINGAGENTS_SWING_TARGET=0.04 +TRADINGAGENTS_SWING_STOP=0.02 +# Position Trade: 336h horizon (~2 weeks), +8% target, -4% stop +TRADINGAGENTS_POSITION_HOURS=336 +TRADINGAGENTS_POSITION_TARGET=0.08 +TRADINGAGENTS_POSITION_STOP=0.04 diff --git a/.gitignore b/.gitignore index 3369bad9..3496796d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ results env/ __pycache__/ +.pycache/ .DS_Store *.csv src/ @@ -9,3 +10,6 @@ eval_results/ eval_data/ *.egg-info/ .env +.env.alpaca +.codex/ +Library/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..c7b1fbdc --- /dev/null +++ b/Dockerfile @@ -0,0 +1,23 @@ +FROM python:3.11-slim + +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 + +WORKDIR /app + +# Install build deps required by some Python packages +RUN apt-get update \ + && apt-get install -y --no-install-recommends build-essential \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements.txt pyproject.toml README.md ./ +RUN pip install --upgrade pip \ + && pip install -r requirements.txt + +COPY tradingagents ./tradingagents +COPY main.py . + +# Provide a writable directory for results/logs +RUN mkdir -p /app/results + +CMD ["python", "main.py"] diff --git a/README.md b/README.md index 7e90c60f..43bd398a 100644 --- a/README.md +++ b/README.md @@ -159,45 +159,97 @@ We built TradingAgents with LangGraph to ensure flexibility and modularity. We u ### Python Usage -To use TradingAgents inside your code, you can import the `tradingagents` module and initialize a `TradingAgentsGraph()` object. The `.propagate()` function will return a decision. You can run `main.py`, here's also a quick example: +Run `python main.py` to launch the interactive CLI. On startup TradingAgents connects to Alpaca MCP, caches the current account snapshot, and presents a menu: + +- **Refresh Alpaca snapshot** – pull the latest account/position/order data. +- **Show account summary / positions / recent orders** – inspect the cached snapshot. +- **Run auto-trade** – execute the end-to-end workflow (market data fetch → hypothesis generation → sequential deep thinking per ticker) and display the reasoning trace for every ticker. + +Each auto-trade run saves a JSON summary to `results/auto_trade_.json`, making it easy to schedule cron jobs or other entrypoints that call the same logic programmatically. The CLI uses the new `AutoTradeService`, so you can reuse it directly: ```python -from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG +from tradingagents.graph.trading_graph import TradingAgentsGraph +from tradingagents.services.account import AccountService +from tradingagents.services.auto_trade import AutoTradeService -ta = TradingAgentsGraph(debug=True, config=DEFAULT_CONFIG.copy()) +config = DEFAULT_CONFIG.copy() +account_service = AccountService(config["alpaca_mcp"]) +snapshot = account_service.refresh() -# forward propagate -_, decision = ta.propagate("NVDA", "2024-05-10") -print(decision) +graph = TradingAgentsGraph(config=config, skip_initial_probes=True) +auto_trader = AutoTradeService(config=config, graph=graph) +result = auto_trader.run(snapshot) + +print(result.summary()) ``` -You can also adjust the default configuration to set your own choice of LLMs, debate rounds, etc. +#### Responses-driven orchestration (experimental) + +Set `AUTO_TRADE_MODE=responses` to let the OpenAI Responses API drive the run. The orchestrator narrates each step, calls the registered tools (snapshot, vendor data, order submission), and finishes with a JSON summary that the CLI renders. Configure the model via `AUTO_TRADE_RESPONSES_MODEL` (for example `gpt-4.1-mini`). If you are using a reasoning-capable model, you can optionally set `AUTO_TRADE_RESPONSES_REASONING=medium`; otherwise leave it blank. The guardrail `AUTO_TRADE_SKIP_WHEN_MARKET_CLOSED` still applies before the session is started, so you will not burn tokens while markets are closed unless you opt in. + +To help the LLM build on past context, enable the ticker memory tool (default). After each run the agent stores a short history of decisions per ticker under `AUTO_TRADE_MEMORY_DIR` (default `./results/memory`). On the next run it can call `get_ticker_memory` to recap recent actions. Control retention with `AUTO_TRADE_MEMORY_MAX_ENTRIES`. + +The Responses orchestrator can also call the same research agents that power the LangGraph pipeline via tools (`run_market_analyst`, `run_news_analyst`, `run_fundamentals_analyst`). Encourage this by leaving `AUTO_TRADE_MODE=responses` and the prompt will request those tools before making final trade recommendations. + +When the sequential-thinking planner promotes a ticker to `trade` or `escalate`, the CLI highlights the exact reasoning steps (confidence, capital checks, escalation path) so every decision is auditable. + +#### Resetting autopilot state for testing + +The autopilot loop reads hypotheses/memory from `results/`. When you want to start from a clean slate (no ticker memory) but still have a few ready-made hypotheses + triggers to exercise the realtime brokers, run the seeding helper: + +```bash +python scripts/seed_autopilot_state.py --force +# optional knobs: +# --skip-fixture only wipe existing data +# --auto-trade run a fresh auto-trade after seeding (requires MCP) +# --results-dir PATH override results directory (default ./results) +# --memory-dir PATH override AUTO_TRADE_MEMORY_DIR +# --fixture PATH load a custom seed fixture instead of docs/fixtures/autopilot_seed.json +``` + +By default the script wipes `results/hypotheses/`, `results/autopilot/`, and the configured `AUTO_TRADE_MEMORY_DIR`, then loads `docs/fixtures/autopilot_seed.json`. The fixture provides three sample hypotheses (NVDA/AAPL/TSLA) with price + news triggers and a couple of seed events, so the realtime price/news brokers immediately subscribe and the heartbeat shows non-zero symbols. After seeding you can launch the CLI in autopilot mode (`python main.py --autopilot`) and watch the worker consume the pre-created history before you generate new hypotheses. + +Each auto-trade decision now carries an explicit strategy directive. Configure the presets under `trading_strategies` in `default_config.py` (or via env vars such as `TRADINGAGENTS_DEFAULT_STRATEGY`, `TRADINGAGENTS_DAYTRADE_TARGET`, etc.). Strategies define horizon, target/stop percentages, urgency, and follow-up behavior, ensuring hypotheses always include measurable success/failure metrics and a deadline for reevaluation. + +Autopilot is market-aware: when `AUTO_TRADE_SKIP_WHEN_MARKET_CLOSED=true`, the orchestrator will skip baseline runs while the exchange is closed, wake itself up a configurable number of minutes before the next open (`AUTOPILOT_PREMARKET_MINUTES`, default 30) to refresh research, and immediately re-run once the bell rings. Websocket listeners remain active 24/7, so breaking news still triggers focused research runs even outside trading hours, but actual order placement is deferred until the market opens again. + +### Portfolio Orchestrator & Alpaca Execution (Optional) + +Set `ALPACA_MCP_ENABLED=true` and point the connection variables to your running Alpaca MCP server if you want the auto-trader to pull live account context. Most deployments expose the Model Context Protocol over JSON-RPC at `/mcp`, so in practice you will define: + +```env +ALPACA_MCP_ENABLED=true +ALPACA_MCP_BASE_URL=http://host.docker.internal:8000/mcp # from inside Docker +``` + +You can omit `ALPACA_MCP_HOST` when a `base_url` is provided. The orchestrator scans its configured universe (`PORTFOLIO_UNIVERSE`), drafts hypotheses, and selectively schedules analysts before escalating to the trader. When `TRADE_EXECUTION_ENABLED=true`, the trader will attempt to place a market order through the MCP server (respecting `TRADE_EXECUTION_DRY_RUN`, `TRADE_EXECUTION_DEFAULT_QTY`, and `TRADE_EXECUTION_TIF`). Leave the flags at their defaults to run analysis-only mode. + + +### Docker Quickstart + +1. Build the TradingAgents image: + ```bash + docker build -t tradingagents:latest . + ``` +2. Prepare environment files: + - `.env` – TradingAgents config (OpenAI key, portfolio settings, `ALPACA_MCP_BASE_URL=http://host.docker.internal:8000/mcp`, etc.) + - `.env.alpaca` – Alpaca MCP credentials if you run the server locally (see `.env.alpaca.example`). +3. Launch the Alpaca MCP + TradingAgents stack: + ```bash + docker compose up --build trading-agents + ``` + The orchestrator connects to Alpaca via the `alpaca-mcp` service and uses the configured LLM to produce the sequential plan. +4. Results are written to `./results` on the host. Toggle `TRADE_EXECUTION_DRY_RUN` when you’re ready for real orders. + +You can customise `DEFAULT_CONFIG` the same way as before (choice of LLMs, vendor overrides, trade thresholds). The CLI and `AutoTradeService` honour those settings. For example, to increase the market data window and the trade priority threshold: ```python -from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG -# Create a custom config config = DEFAULT_CONFIG.copy() -config["deep_think_llm"] = "gpt-4.1-nano" # Use a different model -config["quick_think_llm"] = "gpt-4.1-nano" # Use a different model -config["max_debate_rounds"] = 1 # Increase debate rounds - -# Configure data vendors (default uses yfinance and Alpha Vantage) -config["data_vendors"] = { - "core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local - "technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local - "fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local - "news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local -} - -# Initialize with custom config -ta = TradingAgentsGraph(debug=True, config=config) - -# forward propagate -_, decision = ta.propagate("NVDA", "2024-05-10") -print(decision) +config["portfolio_orchestrator"]["market_data_lookback_days"] = 90 +config["portfolio_orchestrator"]["trade_activation"]["priority_threshold"] = 0.85 ``` > The default configuration uses yfinance for stock price and technical data, and Alpha Vantage for fundamental and news data. For production use or if you encounter rate limits, consider upgrading to [Alpha Vantage Premium](https://www.alphavantage.co/premium/) for more stable and reliable data access. For offline experimentation, there's a local data vendor option that uses our **Tauric TradingDB**, a curated dataset for backtesting, though this is still in development. We're currently refining this dataset and plan to release it soon alongside our upcoming projects. Stay tuned! diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..3eb406e9 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,27 @@ +services: + alpaca-mcp: + image: mcp/alpaca:latest + container_name: alpaca-mcp + env_file: + - .env.alpaca + command: serve --transport http --host 0.0.0.0 --port 8000 + ports: + - "8000:8000" + + trading-agents: + build: . + container_name: trading-agents + env_file: + - .env + environment: + ALPACA_MCP_BASE_URL: ${ALPACA_MCP_BASE_URL:-http://alpaca-mcp:8000/mcp} + SEQUENTIAL_MCP_ENABLED: ${SEQUENTIAL_MCP_ENABLED:-false} + SEQUENTIAL_MCP_TRANSPORT: ${SEQUENTIAL_MCP_TRANSPORT:-http} + SEQUENTIAL_MCP_COMMAND: ${SEQUENTIAL_MCP_COMMAND:-} + SEQUENTIAL_MCP_HOST: ${SEQUENTIAL_MCP_HOST:-} + SEQUENTIAL_MCP_BASE_URL: ${SEQUENTIAL_MCP_BASE_URL:-} + SEQUENTIAL_MCP_PORT: ${SEQUENTIAL_MCP_PORT:-8001} + depends_on: + - alpaca-mcp + volumes: + - ./results:/app/results diff --git a/docs/auto_pilot_hypothesis_webhooks.md b/docs/auto_pilot_hypothesis_webhooks.md new file mode 100644 index 00000000..201cd5cd --- /dev/null +++ b/docs/auto_pilot_hypothesis_webhooks.md @@ -0,0 +1,202 @@ +# Auto-Pilot Hypothesis Webhook Architecture + +This document describes how to evolve TradingAgents from "fire-and-forget" auto-trade runs into a persistent, event-driven autopilot that keeps hypotheses alive, listens for webhooks (price/news/deadline), and reruns the decision pipeline whenever fresh context arrives. + +## Goals + +1. **Hypothesis lifecycle awareness** – record every hypothesis (ticker, rationale, triggers, expiry) so we can resume work without re-running the full CLI loop. +2. **Event subscriptions** – for each hypothesis, subscribe to market data and news webhooks (or schedule polling fallbacks) so updates push into the system. +3. **Autonomous reevaluation** – when an event or deadline fires, spin a focused Responses run for that ticker using the stored plan + new evidence, optionally executing trades. +4. **Operator visibility** – expose hypothesis status, last event, and upcoming triggers both in the CLI and via logs/metrics. + +## High-Level Components + +| Component | Responsibility | +| --- | --- | +| **Auto-Trade Runner (existing)** | Seeds new hypotheses via CLI run; persists plan/memory snapshots. | +| **Hypothesis Registry (new service / module)** | Stores hypotheses, plan tracker, triggers, active status, timestamps. Provides CRUD + query APIs. | +| **Subscription Broker** | Translates hypothesis triggers into concrete subscriptions: price alerts, scheduled polls, news feeds. Emits events back into the registry. | +| **Event Queue / Webhook Ingress** | Receives vendor callbacks (HTTP POST), normalizes payloads, and enqueues `HypothesisEvent` messages. | +| **Hypothesis Worker** | Consumes events, fetches the hypothesis state, runs a scoped Responses plan (or graph propagator), updates decisions/trades, and writes a new memory entry. | +| **Scheduler / Heartbeats** | Ensures every hypothesis is re-evaluated at deadlines even without external events. | +| **Notifications** | Optional module to push summaries to Slack/email once a hypothesis changes state or executes trades. | + +## Data Model + +```mermaid +classDiagram + class Hypothesis { + string id + string ticker + string rationale + float priority + list~PlanStep~ plan + list~Trigger~ triggers + string status // pending, monitoring, triggered, resolved, cancelled + datetime created_at + datetime updated_at + datetime expires_at + dict latest_context // cached indicators/news snippets + } + class PlanStep { + string id + string description + string status // pending, in_progress, done, blocked + dict metadata // e.g., tool args, evidence summary + } + class Trigger { + string type // price_threshold, news_keyword, heartbeat + dict params // symbol, operator, value, keywords, cadence + string subscription_id + datetime last_fired + } + class HypothesisEvent { + string id + string hypothesis_id + string trigger_type + dict payload + datetime received_at + } +``` + +## Event Flow + +1. **Registration** + - After the CLI auto-trade run emits decisions, call `POST /hypotheses` with ticker, plan, triggers, and expiry. The registry responds with `hypothesis_id`. + - For each trigger, call the Subscription Broker (`POST /subscriptions`) which either: + - Registers a vendor webhook (e.g., Polygon alert), storing the callback target (`/events/vendor`), or + - Schedules a polling job / cron for news feeds. + +2. **Event Ingress** + - Vendor posts to `/events/vendor` with its raw payload. + - Broker normalizes to `HypothesisEvent` and enqueues it (e.g., Redis stream / SQS / PostgreSQL table). + +3. **Hypothesis Worker** + - Worker dequeues the event, loads the hypothesis + plan, and constructs a mini conversation for Responses: + ```json + { + "hypothesis": {...}, + "last_plan": [...], + "event": {"type": "price_threshold", "payload": {...}} + } + ``` + - Responses returns updated plan_status, decisions, and optional trades. + - Worker executes trades via existing `submit_trade_order` tooling (respecting dry-run), updates plan statuses, and emits notifications. + +4. **Heartbeats / Deadlines** + - Scheduler queues `HypothesisEvent(type=heartbeat)` at `expires_at` or `next_check_at` to force reevaluation if no external trigger fires. + +5. **Resolution** + - When a hypothesis executes its plan (trade complete, catalyst resolved, or manually cancelled), mark status `resolved` and cancel all subscriptions. + +## API Sketch + +### Hypothesis Registry +``` +POST /hypotheses +{ + "ticker": "TSLA", + "rationale": "Watch for breakout > 435", + "priority": 0.7, + "plan": [ + {"id": "step_tsla_market_7d", "description": "Fetch 7-day market data", "status": "pending"}, + {"id": "step_tsla_news", "description": "Fetch news (5d)", "status": "pending"} + ], + "triggers": [ + {"type": "price_threshold", "params": {"symbol": "TSLA", "operator": ">=", "value": 435}}, + {"type": "heartbeat", "params": {"interval_minutes": 120}} + ], + "expires_at": "2025-11-14T21:00:00Z" +} +``` + +### Subscription Broker +``` +POST /subscriptions +{ + "hypothesis_id": "hypo_tsla_001", + "trigger_id": "trig_price_1", + "type": "price_threshold", + "params": {"symbol": "TSLA", "operator": ">=", "value": 435} +} +``` +_Response_: `{ "subscription_id": "sub_polygon_abc" }` + +### Event Ingress +``` +POST /events/vendor +{ + "subscription_id": "sub_polygon_abc", + "vendor": "polygon", + "payload": {"symbol": "TSLA", "price": 435.2, "timestamp": "..."} +} +``` + +### Hypothesis Worker Output +``` +PATCH /hypotheses/hypo_tsla_001 +{ + "status": "monitoring", + "plan": [ + {"id": "step_tsla_market_7d", "status": "done", "evidence": "Price dropped to 397"}, + ... + ], + "latest_context": {"price": 435.2, "news_sentiment": 0.18}, + "next_check_at": "2025-11-14T18:00:00Z" +} +``` + +## Implementation Plan + +1. **Data layer** + - Create `Hypothesis` + `PlanStep` models (SQLAlchemy/Prisma/SQLite) under `tradingagents/services/hypotheses/`. + - Provide `HypothesisStore` with CRUD and serialization to feed Responses runs. + +2. **CLI → Registry hook** + - After `AutoTradeResult` is produced, translate each decision into a hypothesis payload and POST to the registry (or call local store API). Include plan_steps derived from `decision.sequential_plan.actions` and triggers derived from `action_queue` or heuristics. + +3. **Subscription Broker** + - MVP: cron-based pollers (every N minutes) for price + news; advanced: integrate Polygon/IEX price alerts and Benzinga/Finhub news webhooks. + - Maintain mapping `subscription_id → hypothesis_id, trigger_id`. + +4. **Event Queue** + - Simple approach: Postgres table `hypothesis_events` + worker loop. + - Scalable option: Redis stream / AWS SQS. + +5. **Hypothesis Worker** + - Reuse `ResponsesAutoTradeService` but scope conversation to the single ticker + event. + - Provide `HypothesisContextBuilder` to gather memory, last plan, and event payload. + - Record new plan statuses + trade actions; append to memory store. + +6. **Scheduler** + - Use APScheduler/Celery beat to enqueue `heartbeat` events per hypothesis (e.g., every 2h and at market close). + +7. **Notifications / CLI integration** + - CLI command `tradingagents cli hypotheses list` to show status, last event, next trigger. + - Optional: Slack webhook when hypothesis transitions to `triggered` or executes a trade. + +### MVP status + +- ✅ Hypothesis store + CLI viewer +- ✅ Event queue + manual simulation +- ✅ Price-threshold broker (polling) that enqueues events +- ✅ Realtime brokers: stock trades via `StockDataStream` and Alpaca news (`wss://stream.data.alpaca.markets/v1beta1/news`) that forward triggers to the autopilot worker. +- ✅ Worker that marks plan steps done *and* re-runs the auto-trade pipeline (Responses/graph) for the affected ticker using `focus_override` to keep the reevaluation scoped. +- ⏳ Next: plug real webhook providers and extend the worker to place trades automatically once decisions flip to BUY/SELL. + +## Notes on Webhooks vs Polling + +- **Price Alerts** – Polygon.io, Alpaca, and Tiingo all support webhook alerts. We can fall back to `subscribe_price(symbol, operator, value)` that polls every minute if no webhook support. +- **News** – Few providers push webhooks. Run a background RSS/REST poller that hits Benzinga/Alpha Vantage news endpoints and emits synthetic webhook events. +- **Deadlines** – Heartbeats require no vendor support: schedule `HypothesisEvent(type="deadline")` using APScheduler or a managed cron (Temporal, Airflow). + +## Next Steps + +1. Implement the `HypothesisStore` and registry API (FastAPI blueprint or Flask blueprint under `tradingagents/services`). +2. Extend `AutoTradeResult` handling to register hypotheses + plan steps. +3. Build the subscription broker (start with polling, abstract interface for real vendor webhooks later). +4. Create the worker that consumes events and reuses the Responses orchestrator with focused prompts. +5. Update CLI to display hypothesis queue and allow manual cancellation. +6. Add config knobs for enabling autopilot mode, vendor credentials, heartbeat interval, and maximum concurrent hypotheses. + +With these pieces in place, hypothesized trades can stay "alive" outside the CLI session, react instantly to price/news events, and keep the plan tracker up to date via webhook-driven reevaluations. diff --git a/docs/fixtures/autopilot_seed.json b/docs/fixtures/autopilot_seed.json new file mode 100644 index 00000000..ef3a31c1 --- /dev/null +++ b/docs/fixtures/autopilot_seed.json @@ -0,0 +1,208 @@ +{ + "hypotheses": [ + { + "id": "seed_nvda_trigger", + "ticker": "NVDA", + "action": "HOLD", + "priority": 0.65, + "status": "monitoring", + "rationale": "Watching for NVDA trim/add levels to fire while earnings jitters persist.", + "notes": "Hold core 10 shares. Trim above 195 to lock gains; add below 184.50 to average down.", + "plan": [ + { + "id": "nvda_step_1", + "description": "Monitor intraday price for >= 195 to trim 5 shares", + "status": "pending", + "metadata": { + "next_decision": "Trim 5 shares if 195 prints" + } + }, + { + "id": "nvda_step_2", + "description": "Monitor intraday price for <= 184.50 to add 5 shares", + "status": "pending", + "metadata": { + "next_decision": "Add 5 shares if 184.50 breaks" + } + }, + { + "id": "nvda_step_3", + "description": "Re-evaluate at close if no triggers fire", + "status": "pending", + "metadata": {} + } + ], + "strategy": { + "name": "swing", + "horizon_hours": 72, + "target_pct": 0.04, + "stop_pct": 0.02, + "success_metric": "Gain at least +4.0% within 72h", + "failure_metric": "Drawdown beyond -2.0% or catalyst fades", + "follow_up": "reassess_every_close", + "urgency": "medium", + "deadline": "2025-11-16T15:30:00Z", + "notes": "Seed swing plan" + }, + "triggers": [ + "price >= 195", + "price <= 184.5", + "news:NVDA" + ], + "created_at": "2025-11-13T15:30:00.000000Z", + "updated_at": "2025-11-13T15:30:00.000000Z", + "source_snapshot": "2025-11-13T15:20:00" + }, + { + "id": "seed_aapl_breakout", + "ticker": "AAPL", + "action": "WATCH", + "priority": 0.45, + "status": "monitoring", + "rationale": "Breakout trade setup—only add if price clears resistance with volume.", + "notes": "Scale in 10 shares in two clips if > 178 with strong breadth.", + "plan": [ + { + "id": "aapl_step_1", + "description": "Alert on price >= 178 breakout", + "status": "pending", + "metadata": { + "next_decision": "Enter starter position if breakout confirmed" + } + }, + { + "id": "aapl_step_2", + "description": "Watch news sentiment for product-cycle surprises", + "status": "pending", + "metadata": { + "next_decision": "Hold off if negative news arrives" + } + } + ], + "strategy": { + "name": "swing", + "horizon_hours": 72, + "target_pct": 0.05, + "stop_pct": 0.025, + "success_metric": "Breakout extends +5% within horizon", + "failure_metric": "Breakout fails or drawdown exceeds -2.5%", + "follow_up": "reassess_every_close", + "urgency": "medium", + "deadline": "2025-11-16T15:31:00Z", + "notes": "Breakout ladder plan" + }, + "triggers": [ + "price >= 178", + "news:AAPL" + ], + "created_at": "2025-11-13T15:31:00.000000Z", + "updated_at": "2025-11-13T15:31:00.000000Z", + "source_snapshot": "2025-11-13T15:20:00" + }, + { + "id": "seed_tsla_breakout", + "ticker": "TSLA", + "action": "HOLD", + "priority": 0.4, + "status": "monitoring", + "rationale": "No position; wait for 420-435 breakout with positive catalysts before entering.", + "notes": "Starter 5-10 shares only if breakout + bullish news lines up.", + "plan": [ + { + "id": "tsla_step_1", + "description": "Price >= 425 breakout alert", + "status": "pending", + "metadata": { + "next_decision": "Consider starter position" + } + }, + { + "id": "tsla_step_2", + "description": "News sentiment scan for TSLA", + "status": "pending", + "metadata": {} + } + ], + "strategy": { + "name": "position", + "horizon_hours": 240, + "target_pct": 0.08, + "stop_pct": 0.04, + "success_metric": "Sustain breakout and gain 8% within horizon", + "failure_metric": "Breakout fails or macro turns", + "follow_up": "weekly_review", + "urgency": "low", + "deadline": "2025-11-23T15:32:00Z", + "notes": "Wait for confirmation" + }, + "triggers": [ + "price >= 425", + "news:TSLA" + ], + "created_at": "2025-11-13T15:32:00.000000Z", + "updated_at": "2025-11-13T15:32:00.000000Z", + "source_snapshot": "2025-11-13T15:20:00" + } + ], + "events": [ + { + "id": "evt_seed_nvda_news", + "hypothesis_id": "seed_nvda_trigger", + "event_type": "news", + "payload": { + "T": "n", + "headline": "Sample NVDA news", + "summary": "Seeded story to verify webhook ingestion.", + "symbols": ["NVDA"], + "source": "seed" + }, + "created_at": "2025-11-13T16:00:00.000000Z" + }, + { + "id": "evt_seed_aapl_price", + "hypothesis_id": "seed_aapl_breakout", + "event_type": "price_threshold", + "payload": { + "symbol": "AAPL", + "price": 178.12, + "operator": ">=", + "value": 178 + }, + "created_at": "2025-11-13T16:05:00.000000Z" + } + ], + "memory": { + "NVDA": [ + { + "timestamp": "2025-11-13T14:55:00Z", + "action": "HOLD", + "priority": 0.6, + "notes": "Holding 10 shares; trimming plan above 195.", + "plan_actions": [ + "monitor intraday price", + "trim above 195" + ], + "raw": { + "ticker": "NVDA", + "action": "HOLD" + } + } + ], + "AAPL": [ + { + "timestamp": "2025-11-13T14:58:00Z", + "action": "WATCH", + "priority": 0.45, + "notes": "Waiting for breakout before adding.", + "plan_actions": [ + "monitor breakout", + "ladder entries" + ], + "raw": { + "ticker": "AAPL", + "action": "WATCH" + } + } + ] + } +} diff --git a/main.py b/main.py index a85ee6ec..ffffcc42 100644 --- a/main.py +++ b/main.py @@ -1,31 +1,866 @@ -from tradingagents.graph.trading_graph import TradingAgentsGraph -from tradingagents.default_config import DEFAULT_CONFIG +from __future__ import annotations +import json +import logging +import os +import sys +import threading +import time +from contextlib import nullcontext +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional + +import questionary from dotenv import load_dotenv +from rich.console import Console +from rich.panel import Panel +from rich.table import Table +from rich.text import Text -# Load environment variables from .env file -load_dotenv() +load_dotenv(dotenv_path=Path(__file__).resolve().parent / ".env") -# Create a custom config -config = DEFAULT_CONFIG.copy() -config["deep_think_llm"] = "gpt-4o-mini" # Use a different model -config["quick_think_llm"] = "gpt-4o-mini" # Use a different model -config["max_debate_rounds"] = 1 # Increase debate rounds +from tradingagents.default_config import DEFAULT_CONFIG +from tradingagents.graph.trading_graph import TradingAgentsGraph +from tradingagents.services.account import AccountService, AccountSnapshot +from tradingagents.services.auto_trade import AutoTradeResult, AutoTradeService +from tradingagents.services.autopilot_worker import AutopilotWorker +from tradingagents.services.autopilot_broker import AutopilotBroker +from tradingagents.services.realtime_broker import RealtimeBroker +from tradingagents.services.realtime_news_broker import RealtimeNewsBroker +from tradingagents.services.hypothesis_store import HypothesisStore -# Configure data vendors (default uses yfinance and alpha_vantage) -config["data_vendors"] = { - "core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local - "technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local - "fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local - "news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local -} -# Initialize with custom config -ta = TradingAgentsGraph(debug=True, config=config) +console = Console() -# forward propagate -_, decision = ta.propagate("NVDA", "2024-05-10") -print(decision) -# Memorize mistakes and reflect -# ta.reflect_and_remember(1000) # parameter is the position returns +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(message)s") + + config = DEFAULT_CONFIG.copy() + console.print(Panel("[bold]TradingAgents CLI[/bold]\nConnect to Alpaca MCP and manage auto-trading.", title="TradingAgents", expand=False)) + + autopilot_requested = "--autopilot" in sys.argv or bool(config.get("autopilot", {}).get("enabled")) + if "--autopilot" in sys.argv: + sys.argv.remove("--autopilot") + + results_root = Path(config.get("results_dir", "./results")).resolve() + hypothesis_store = HypothesisStore(results_root / "hypotheses") + realtime_state: Dict[str, Any] = {} + news_state: Dict[str, Any] = {} + + account_service = AccountService(config.get("alpaca_mcp", {})) + snapshot = _refresh_account_snapshot(account_service) + + graph = TradingAgentsGraph(debug=False, config=config, skip_initial_probes=True) + auto_trader = AutoTradeService(config=config, graph=graph) + autopilot_worker = AutopilotWorker(results_root, auto_trader, account_service) + autopilot_broker = AutopilotBroker(hypothesis_store, autopilot_worker) + + if autopilot_requested: + _run_autopilot_loop( + config, + snapshot, + auto_trader, + hypothesis_store, + autopilot_worker, + autopilot_broker, + realtime_state, + news_state, + account_service, + ) + return + + if not sys.stdin.isatty(): + console.print( + "Detected non-interactive environment. Running auto-trade once and exiting.", + style="yellow", + ) + _execute_auto_trade(auto_trader, snapshot, hypothesis_store) + return + + while True: + choice = questionary.select( + "Select an option", + choices=[ + "Refresh Alpaca snapshot", + "Show account summary", + "Show positions", + "Show recent orders", + "Show hypotheses", + "Run auto-trade", + "Simulate hypothesis event", + "Process autopilot events", + "Run price-alert poll", + "Start realtime broker", + "Start news broker", + "Exit", + ], + ).ask() + + if choice is None or choice == "Exit": + console.print("Goodbye!") + break + + if choice == "Refresh Alpaca snapshot": + snapshot = _refresh_account_snapshot(account_service) + elif choice == "Show account summary": + _render_account_summary(snapshot) + elif choice == "Show positions": + _render_positions(snapshot) + elif choice == "Show recent orders": + _render_orders(snapshot) + elif choice == "Show hypotheses": + _render_hypotheses(hypothesis_store) + elif choice == "Run auto-trade": + _execute_auto_trade(auto_trader, snapshot, hypothesis_store) + elif choice == "Simulate hypothesis event": + _simulate_hypothesis_event(hypothesis_store, autopilot_worker) + elif choice == "Process autopilot events": + _process_autopilot_events(autopilot_worker) + elif choice == "Run price-alert poll": + _run_price_alert_poll(autopilot_broker) + elif choice == "Start realtime broker": + _start_realtime_broker(config, hypothesis_store, autopilot_worker, realtime_state) + elif choice == "Start news broker": + _start_news_broker(config, hypothesis_store, autopilot_worker, news_state) + + +def _refresh_account_snapshot(account_service: AccountService) -> AccountSnapshot: + console.print("Connecting to Alpaca MCP …", style="bold cyan") + try: + snapshot = account_service.refresh() + except RuntimeError as exc: + console.print(str(exc), style="red") + raise SystemExit(1) + + console.print( + f"Snapshot fetched at {snapshot.fetched_at.strftime('%Y-%m-%d %H:%M:%S UTC')}", + style="green", + ) + _render_account_summary(snapshot) + return snapshot + + +def _render_account_summary(snapshot: AccountSnapshot) -> None: + table = Table(title="Account Summary", box=None) + table.add_column("Field", justify="left", style="cyan") + table.add_column("Value", justify="right", style="magenta") + + for key, label in [ + ("account_id", "Account ID"), + ("status", "Status"), + ("currency", "Currency"), + ("buying_power", "Buying Power"), + ("cash", "Cash"), + ("portfolio_value", "Portfolio Value"), + ("equity", "Equity"), + ("long_market_value", "Long Market Value"), + ("short_market_value", "Short Market Value"), + ("pattern_day_trader", "Pattern Day Trader"), + ("day_trades_remaining", "Day Trades Remaining"), + ]: + value = snapshot.account.get(key) + if value is not None: + table.add_row(label, str(value)) + + console.print(table) + + +def _render_positions(snapshot: AccountSnapshot) -> None: + if not snapshot.positions: + console.print("No open positions.") + return + + table = Table(title="Open Positions", box=None) + table.add_column("Symbol", style="cyan") + table.add_column("Quantity", justify="right") + table.add_column("Market Value", justify="right") + table.add_column("Cost Basis", justify="right") + + for position in snapshot.positions: + table.add_row( + str(position.get("symbol") or position.get("symbol:") or ""), + str(position.get("quantity") or position.get("qty") or ""), + str(position.get("market_value") or ""), + str(position.get("cost_basis") or ""), + ) + + console.print(table) + + +def _render_orders(snapshot: AccountSnapshot) -> None: + if not snapshot.orders: + console.print("No recent orders.") + return + + table = Table(title="Recent Orders", box=None) + table.add_column("Order ID") + table.add_column("Symbol") + table.add_column("Side") + table.add_column("Qty") + table.add_column("Status") + + for order in snapshot.orders: + table.add_row( + str(order.get("order_id") or order.get("id") or ""), + str(order.get("symbol") or ""), + str(order.get("side") or ""), + str(order.get("qty") or order.get("quantity") or ""), + str(order.get("status") or ""), + ) + + console.print(table) + + +def _render_hypotheses(store: HypothesisStore) -> None: + records = store.list() + if not records: + console.print("No stored hypotheses yet.", style="yellow") + return + + table = Table(title="Stored Hypotheses", box=None) + table.add_column("ID", style="dim") + table.add_column("Ticker", style="cyan") + table.add_column("Status") + table.add_column("Action") + table.add_column("Priority", justify="right") + table.add_column("Next Step") + table.add_column("Created", style="dim") + + display_limit = 15 + for record in records[:display_limit]: + next_step = record.next_open_step() + next_desc = next_step.description if next_step else "" + created = record.created_at.split("T")[0] + table.add_row( + record.id[-6:], + record.ticker, + record.status, + record.action, + f"{record.priority:.2f}", + next_desc, + created, + ) + + console.print(table) + remaining = len(records) - display_limit + if remaining > 0: + console.print(f"(+{remaining} more stored hypotheses)", style="dim") + + +def _simulate_hypothesis_event(store: HypothesisStore, worker: AutopilotWorker) -> None: + records = store.list() + if not records: + console.print("No hypotheses to simulate events for.", style="yellow") + return + + record_choices = [ + questionary.Choice( + title=f"{record.ticker} ({record.status}) – id {record.id[-6:]}", + value=record.id, + ) + for record in records[:25] + ] + hypothesis_id = questionary.select("Select hypothesis", choices=record_choices).ask() + if not hypothesis_id: + return + + event_type = questionary.select( + "Select event type", + choices=["price_threshold", "news", "heartbeat", "manual"], + ).ask() + if not event_type: + return + + payload_text = questionary.text( + "Optional JSON payload (press Enter to skip)", + default="", + ).ask() or "" + payload = {} + if payload_text.strip(): + try: + payload = json.loads(payload_text) + except json.JSONDecodeError: + console.print("Invalid JSON payload; storing empty payload instead.", style="yellow") + + event = worker.enqueue_event(hypothesis_id, event_type, payload) + console.print( + f"Enqueued {event.event_type} event for hypothesis {hypothesis_id[-6:]}", + style="green", + ) + + +def _process_autopilot_events(worker: AutopilotWorker) -> None: + processed = worker.process_all() + if not processed: + console.print("No autopilot events queued.", style="yellow") + return + + table = Table(title="Autopilot Event Processing", box=None) + table.add_column("Event ID", style="dim") + table.add_column("Hypothesis", style="cyan") + table.add_column("Type") + table.add_column("Status") + table.add_column("Message") + + for result in processed: + table.add_row( + result.event.id[-8:], + result.event.hypothesis_id[-6:], + result.event.event_type, + result.status, + result.message, + ) + + console.print(table) + + +def _run_price_alert_poll(broker: AutopilotBroker) -> None: + outcomes = broker.poll_once() + if not outcomes: + console.print("No price triggers fired during this poll.", style="yellow") + return + table = Table(title="Price Trigger Poll", box=None) + table.add_column("Event ID", style="dim") + table.add_column("Result") + for event_id, message in outcomes.items(): + table.add_row(event_id[-8:], message) + console.print(table) + console.print("Processing newly queued events …", style="dim") + _process_autopilot_events(broker.worker) + + +def _start_realtime_broker( + config: Dict[str, Any], + store: HypothesisStore, + worker: AutopilotWorker, + state: Dict[str, Any], +) -> None: + if state.get("thread") and state["thread"].is_alive(): + console.print("Realtime broker already running.", style="yellow") + return + + api_key = os.getenv("APCA_API_KEY_ID") or config.get("market_data", {}).get("api_key") + secret_key = os.getenv("APCA_API_SECRET_KEY") or config.get("market_data", {}).get("secret_key") + feed = (config.get("market_data", {}) or {}).get("feed", "iex") + if not api_key or not secret_key: + console.print("Set APCA_API_KEY_ID / APCA_API_SECRET_KEY env vars to use realtime broker.", style="red") + return + + try: + broker = RealtimeBroker(store, worker, api_key, secret_key, feed=feed) + except RuntimeError as exc: + console.print(str(exc), style="red") + return + + def _run(): + broker.run_forever() + + thread = threading.Thread(target=_run, daemon=True) + thread.start() + state["thread"] = thread + state["broker"] = broker + console.print("Realtime broker started in background thread.", style="green") + + +def _start_news_broker( + config: Dict[str, Any], + store: HypothesisStore, + worker: AutopilotWorker, + state: Dict[str, Any], +) -> None: + thread = state.get("thread") + if thread and thread.is_alive(): + console.print("News broker already running.", style="yellow") + return + + api_key = os.getenv("APCA_API_KEY_ID") or config.get("market_data", {}).get("api_key") + secret_key = os.getenv("APCA_API_SECRET_KEY") or config.get("market_data", {}).get("secret_key") + url = (config.get("market_data", {}) or {}).get("news_stream_url") + if not api_key or not secret_key: + console.print("Set APCA_API_KEY_ID / APCA_API_SECRET_KEY to use news broker.", style="red") + return + + broker = RealtimeNewsBroker(store, worker, api_key, secret_key, url=url) + try: + broker.start() + except Exception as exc: # pragma: no cover - network bootstrap errors + console.print(f"Failed to start news broker: {exc}", style="red") + return + + state["broker"] = broker + state["thread"] = broker._thread + console.print("News broker started in background thread.", style="green") + + +def _execute_auto_trade( + auto_trader: AutoTradeService, + snapshot: AccountSnapshot, + hypothesis_store: HypothesisStore, + *, + compact: bool = False, + skip_if_market_closed: bool = False, + allow_market_closed: bool = False, +) -> bool: + should_skip = skip_if_market_closed and bool( + (auto_trader.config.get("auto_trade", {}) or {}).get("skip_when_market_closed", True) + ) + if should_skip: + is_open, reason = _market_is_open(auto_trader) + if not is_open: + suffix = f" ({reason})" if reason else "" + console.print(f"Skipping auto-trade: market is closed{suffix}.", style="yellow") + return False + + console.print("Running auto-trade …", style="bold cyan") + try: + result = auto_trader.run(snapshot, allow_market_closed=allow_market_closed) + except Exception as exc: # pragma: no cover - surfaced to CLI + console.print(f"Auto-trade failed: {exc}", style="red") + logging.exception("Auto-trade failed") + return False + _render_auto_trade_result(result, compact=compact) + results_dir = Path(auto_trader.config.get("results_dir", "./results")) + _persist_auto_trade_result(result, results_dir) + new_records = hypothesis_store.record_result(result) + if new_records: + console.print(f"Recorded {len(new_records)} hypothesis{'es' if len(new_records) != 1 else ''} for autopilot follow-up.", style="green") + return True + + +def _render_auto_trade_result(result: AutoTradeResult, *, compact: bool = False) -> None: + console.rule("Auto-Trade Result") + focus = ", ".join(result.focus_tickers) or "" + console.print( + f"Focus tickers: {focus}\n" + f"Buying Power: ${result.account_snapshot.buying_power():,.0f}\n" + f"Cash: ${result.account_snapshot.cash():,.0f}" + ) + skip_reason = (result.raw_state or {}).get("skip_reason") if result.raw_state is not None else None + if skip_reason: + console.print(skip_reason, style="yellow") + + if compact: + if not result.decisions: + console.print("No decisions produced.", style="yellow") + return + table = Table(title="Decisions", box=None) + table.add_column("Ticker", style="cyan") + table.add_column("Action", style="magenta") + table.add_column("Next", overflow="fold") + for decision in result.decisions: + action = (decision.final_decision or decision.immediate_action or "hold").upper() + next_hint = decision.sequential_plan.next_decision or decision.sequential_plan.notes or decision.final_notes or "" + table.add_row(decision.ticker, action, next_hint) + console.print(table) + return + + transcript = (result.raw_state or {}).get("responses_transcript") if result.raw_state is not None else None + if transcript: + console.rule("Narrative") + for idx, entry in enumerate(transcript, 1): + console.print(f"[bold]Step {idx}: [/bold]{entry}") + console.rule("Decisions") + + if not result.decisions: + console.print("No decisions produced.", style="yellow") + return + + for decision in result.decisions: + header = f"[bold]{decision.ticker}[/bold] – action: [cyan]{decision.immediate_action.upper()}[/cyan]" + table = Table(title=header, box=None) + table.add_column("Field", style="cyan") + table.add_column("Value", style="magenta") + + required = decision.hypothesis.get("required_analysts", []) + plan_next = ( + decision.sequential_plan.next_decision.upper() + if decision.sequential_plan.next_decision + else "" + ) + table.add_row("Priority", f"{decision.priority:.2f}") + table.add_row("Required Analysts", ", ".join(required) or "") + table.add_row("Plan Actions", " → ".join(decision.sequential_plan.actions) or "") + table.add_row("Plan Next Decision", plan_next) + table.add_row("Action Queue", " → ".join(decision.action_queue or []) or "") + table.add_row("Planner Notes", decision.sequential_plan.notes or "") + table.add_row("Final Decision", decision.final_decision or "") + table.add_row("Trader Plan", decision.trader_plan or "") + + console.print(table) + + if decision.sequential_plan.reasoning: + console.print(Text("Reasoning:", style="bold underline")) + for idx, step in enumerate(decision.sequential_plan.reasoning, 1): + console.print(f" {idx}. {step}") + + if decision.final_notes: + console.print(Text("Final Notes:", style="bold underline")) + console.print(decision.final_notes) + + console.print() + + +def _persist_auto_trade_result(result: AutoTradeResult, results_dir: Path) -> None: + try: + results_dir.mkdir(parents=True, exist_ok=True) + path = results_dir / f"auto_trade_{datetime.utcnow().strftime('%Y%m%dT%H%M%SZ')}.json" + with path.open("w", encoding="utf-8") as handle: + json.dump(result.summary(), handle, indent=2) + console.print(f"Saved auto-trade summary to {path}", style="green") + except Exception as exc: # pragma: no cover - persistence best effort + console.print(f"Failed to persist auto-trade summary: {exc}", style="red") + + +def _run_autopilot_loop( + config: Dict[str, Any], + snapshot: AccountSnapshot, + auto_trader: AutoTradeService, + hypothesis_store: HypothesisStore, + autopilot_worker: AutopilotWorker, + autopilot_broker: AutopilotBroker, + realtime_state: Dict[str, Any], + news_state: Dict[str, Any], + account_service: AccountService, +) -> None: + autopilot_cfg = config.get("autopilot", {}) or {} + event_interval = max(int(autopilot_cfg.get("event_loop_interval_seconds", 10)), 1) + price_poll_interval = max(int(autopilot_cfg.get("price_poll_interval_seconds", 60)), event_interval) + seed_run = bool(autopilot_cfg.get("auto_trade_on_start", True)) + premarket_window = max(int(autopilot_cfg.get("pre_market_research_minutes", 30)), 0) + + console.print("Autopilot mode enabled. Press Ctrl+C to stop.", style="bold cyan") + + latest_snapshot = snapshot + pending_market_open_run = False + premarket_marker: Optional[str] = None + + def _refresh_snapshot() -> Optional[AccountSnapshot]: + nonlocal latest_snapshot + try: + snap = account_service.refresh() + latest_snapshot = snap + return snap + except Exception as exc: + console.print(f"Failed to refresh account snapshot: {exc}", style="red") + return None + + market_status = _get_market_status(auto_trader) + last_market_check = time.time() + market_is_open = bool(market_status.get("is_open")) + + if seed_run: + if market_is_open: + snap = _refresh_snapshot() + if snap: + ran = _execute_auto_trade( + auto_trader, + snap, + hypothesis_store, + compact=True, + skip_if_market_closed=False, + allow_market_closed=False, + ) + pending_market_open_run = not ran + else: + pending_market_open_run = True + reason = market_status.get("clock_text") or market_status.get("reason") or "market closed" + console.print(f"Initial run skipped: {reason}.", style="yellow") + if premarket_window > 0 and _should_run_premarket(market_status, premarket_window): + snap = _refresh_snapshot() + if snap and _execute_auto_trade( + auto_trader, + snap, + hypothesis_store, + compact=True, + skip_if_market_closed=False, + allow_market_closed=True, + ): + premarket_marker = market_status.get("next_open") + console.print( + "Pre-market research run completed; awaiting opening bell.", + style="dim", + ) + else: + console.print("Skipping initial auto-trade seed (auto_trade_on_start=false).", style="yellow") + + _start_realtime_broker(config, hypothesis_store, autopilot_worker, realtime_state) + _start_news_broker(config, hypothesis_store, autopilot_worker, news_state) + + last_price_poll = 0.0 + last_signature = "" + last_heartbeat = 0.0 + heartbeat_interval = max(event_interval, 30) + market_check_interval = max(price_poll_interval, 60) + events_since_heartbeat = 0 + console.print( + f"Entering autopilot loop (event every {event_interval}s, price poll every {price_poll_interval}s)…", + style="dim", + ) + + try: + market_status = _get_market_status(auto_trader) + last_market_check = time.time() + while True: + events_since_heartbeat += _drain_autopilot_queue(autopilot_worker) + + records = hypothesis_store.list() + signature = _hypothesis_signature(records) + if signature != last_signature: + last_signature = signature + _refresh_stream_registrations(realtime_state, news_state, records) + + now = time.time() + if now - last_price_poll >= price_poll_interval: + events_since_heartbeat += _poll_price_alerts_quiet( + autopilot_broker, + autopilot_worker, + market_open=bool(market_status.get("is_open")), + ) + last_price_poll = now + + if now - last_heartbeat >= heartbeat_interval: + stats = _collect_stream_stats(realtime_state, news_state) + _print_autopilot_heartbeat(events_since_heartbeat, stats) + events_since_heartbeat = 0 + last_heartbeat = now + + if now - last_market_check >= market_check_interval: + market_status = _get_market_status(auto_trader) + last_market_check = now + is_open = bool(market_status.get("is_open")) + if is_open: + if pending_market_open_run: + snap = _refresh_snapshot() + if snap and _execute_auto_trade( + auto_trader, + snap, + hypothesis_store, + compact=True, + skip_if_market_closed=False, + allow_market_closed=False, + ): + pending_market_open_run = False + premarket_marker = None + else: + pending_market_open_run = True + if premarket_window > 0 and _should_run_premarket(market_status, premarket_window): + marker = market_status.get("next_open") + if marker and marker != premarket_marker: + snap = _refresh_snapshot() + if snap and _execute_auto_trade( + auto_trader, + snap, + hypothesis_store, + compact=True, + skip_if_market_closed=False, + allow_market_closed=True, + ): + console.print( + "Pre-market research run completed; awaiting opening bell.", + style="dim", + ) + premarket_marker = marker + + time.sleep(event_interval) + except KeyboardInterrupt: + console.print("Autopilot loop stopped by user request.", style="yellow") + + +def _drain_autopilot_queue(worker: AutopilotWorker) -> int: + try: + processed = worker.process_all() + except Exception as exc: # pragma: no cover - best effort logging + console.print(f"Autopilot event processing failed: {exc}", style="red") + return 0 + + if not processed: + return 0 + + table = Table(title="Autopilot Updates", box=None) + table.add_column("Hypothesis", style="cyan") + table.add_column("Event") + table.add_column("Status", style="green") + table.add_column("Message", style="magenta") + + max_rows = 10 + for item in processed[:max_rows]: + hypothesis = item.event.hypothesis_id[-6:] + table.add_row( + hypothesis, + item.event.event_type, + item.status, + item.message, + ) + + console.print(table) + if len(processed) > max_rows: + console.print(f"(+{len(processed) - max_rows} more events)", style="dim") + + return len(processed) + + +def _refresh_stream_registrations( + realtime_state: Dict[str, Any], + news_state: Dict[str, Any], + records: List[Any], +) -> None: + broker = realtime_state.get("broker") + if broker is not None: + try: + registered = broker.refresh_triggers(records, reset=True) + if registered: + console.print(f"Realtime broker tracking {registered} trigger(s).", style="dim") + except Exception as exc: # pragma: no cover - best effort logging + console.print(f"Realtime broker refresh failed: {exc}", style="red") + + news_broker = news_state.get("broker") + if news_broker is not None: + try: + watchers = news_broker.refresh_watchers(records) + if watchers: + console.print(f"News broker monitoring {watchers} symbol-link(s).", style="dim") + except Exception as exc: # pragma: no cover - best effort logging + console.print(f"News broker refresh failed: {exc}", style="red") + + +def _collect_stream_stats( + realtime_state: Dict[str, Any], + news_state: Dict[str, Any], +) -> Dict[str, Any]: + price_thread = realtime_state.get("thread") + price_connected = bool(price_thread and getattr(price_thread, "is_alive", lambda: False)()) + price_symbols = 0 + price_triggers = 0 + broker = realtime_state.get("broker") + if broker is not None: + lock = getattr(broker, "_lock", None) + context = lock or nullcontext() + with context: + trigger_map = getattr(broker, "triggers", {}) or {} + price_symbols = len(trigger_map) + price_triggers = sum(len(bucket) for bucket in trigger_map.values()) + + news_thread = news_state.get("thread") + news_connected = bool(news_thread and getattr(news_thread, "is_alive", lambda: False)()) + news_symbols = 0 + news_broker = news_state.get("broker") + if news_broker is not None: + lock = getattr(news_broker, "_lock", None) + context = lock or nullcontext() + with context: + watchers = getattr(news_broker, "watchers", {}) or {} + news_symbols = len(watchers) + + return { + "price_connected": price_connected, + "price_symbols": price_symbols, + "price_triggers": price_triggers, + "news_connected": news_connected, + "news_symbols": news_symbols, + } + + +def _print_autopilot_heartbeat(events_processed: int, stats: Dict[str, Any]) -> None: + price_status = "connected" if stats.get("price_connected") else "idle" + news_status = "connected" if stats.get("news_connected") else "idle" + message = ( + f"[dim]Heartbeat – events:{events_processed} | price stream {price_status} " + f"({stats.get('price_symbols', 0)} symbols, {stats.get('price_triggers', 0)} triggers) " + f"| news stream {news_status} ({stats.get('news_symbols', 0)} symbols).[/dim]" + ) + console.print(message) + + +def _hypothesis_signature(records: List[Any]) -> str: + if not records: + return "" + parts = [f"{getattr(rec, 'id', '')}:{getattr(rec, 'updated_at', '')}:{getattr(rec, 'status', '')}:{getattr(rec, 'action', '')}" for rec in records] + return "|".join(parts) + + +def _poll_price_alerts_quiet( + broker: AutopilotBroker, + worker: AutopilotWorker, + *, + market_open: bool, +) -> int: + if not market_open: + return 0 + try: + outcomes = broker.poll_once() + except Exception as exc: # pragma: no cover - best effort logging + console.print(f"Price alert poll failed: {exc}", style="red") + return 0 + + if not outcomes: + return 0 + + table = Table(title="Price Trigger Alerts", box=None) + table.add_column("Event", style="cyan") + table.add_column("Message", style="magenta") + for event_id, message in outcomes.items(): + table.add_row(event_id[-8:], message) + console.print(table) + + # Process the events immediately so hypotheses update promptly. + return _drain_autopilot_queue(worker) + + +def _market_is_open(auto_trader: AutoTradeService) -> (bool, Optional[str]): + checker = getattr(auto_trader.graph, "check_market_status", None) + if not callable(checker): + return True, None + try: + status = checker() or {} + except Exception: + return True, None + + is_open = status.get("is_open") + if is_open: + return True, None + reason = status.get("clock_text") or status.get("reason") + return False, reason + + +def _get_market_status(auto_trader: AutoTradeService) -> Dict[str, Any]: + checker = getattr(auto_trader.graph, "check_market_status", None) + if not callable(checker): + return {"is_open": True, "reason": "clock_unavailable"} + try: + return checker() or {} + except Exception as exc: + console.print(f"Failed to fetch market status: {exc}", style="red") + return {"is_open": False, "reason": "clock_error"} + + +def _should_run_premarket(status: Dict[str, Any], window_minutes: int) -> bool: + if window_minutes <= 0: + return False + next_open = _parse_market_time(status.get("next_open")) + if not next_open: + return False + now_utc = datetime.now(timezone.utc) + target = next_open.astimezone(timezone.utc) + minutes = (target - now_utc).total_seconds() / 60 + return 0 <= minutes <= window_minutes + + +def _parse_market_time(value: Optional[str]) -> Optional[datetime]: + if not value: + return None + text = value.strip() + if "T" not in text and " " in text: + text = text.replace(" ", "T", 1) + try: + dt = datetime.fromisoformat(text) + except ValueError: + return None + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 63af4721..cca91e5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,9 @@ dependencies = [ "langchain-experimental>=0.3.4", "langchain-google-genai>=2.1.5", "langchain-openai>=0.3.23", + "openai>=1.52.0", "langgraph>=0.4.8", + "mcp[servers]>=1.6.0", "pandas>=2.3.0", "parsel>=1.10.0", "praw>=7.8.1", diff --git a/requirements.txt b/requirements.txt index a6154cd2..4f45ad58 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ typing-extensions langchain-openai +openai>=1.52.0 langchain-experimental pandas yfinance @@ -8,6 +9,7 @@ feedparser stockstats eodhd langgraph +mcp[servers]>=1.6.0 chromadb setuptools backtrader diff --git a/run_autopilot.sh b/run_autopilot.sh new file mode 100755 index 00000000..62154b82 --- /dev/null +++ b/run_autopilot.sh @@ -0,0 +1,6 @@ +#!/bin/bash +VENV=/Users/slavanikitin/Documents/Projects/TradingAgents/.venv +if [ -d "$VENV" ]; then + source "$VENV/bin/activate" +fi +PYTHONPATH="$VENV/lib/python3.13/site-packages:$PYTHONPATH" python3 main.py --autopilot diff --git a/scripts/seed_autopilot_state.py b/scripts/seed_autopilot_state.py new file mode 100644 index 00000000..446d2dd3 --- /dev/null +++ b/scripts/seed_autopilot_state.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +"""Utility to reset autopilot state and seed sample history for testing.""" + +from __future__ import annotations + +import argparse +import json +import shutil +import sys +from pathlib import Path +from typing import Any, Dict + +# Ensure the project root (one level up from `scripts/`) is importable when the +# script is executed via `python scripts/...` without installing the package. +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from tradingagents.default_config import DEFAULT_CONFIG +from tradingagents.graph.trading_graph import TradingAgentsGraph +from tradingagents.services.account import AccountService +from tradingagents.services.auto_trade import AutoTradeResult, AutoTradeService +from tradingagents.services.hypothesis_store import HypothesisStore + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--results-dir", + default=DEFAULT_CONFIG.get("results_dir", "./results"), + help="Root directory for TradingAgents results (default: %(default)s)", + ) + parser.add_argument( + "--memory-dir", + default=((DEFAULT_CONFIG.get("auto_trade", {}) or {}).get("memory", {}) or {}).get("dir", "./results/memory"), + help="Directory that stores ticker memory snapshots (default pulled from config).", + ) + parser.add_argument( + "--fixture", + default="docs/fixtures/autopilot_seed.json", + help="JSON file containing seed hypotheses/events/memory.", + ) + parser.add_argument( + "--skip-fixture", + action="store_true", + help="Only wipe state; do not load the fixture data.", + ) + parser.add_argument( + "--auto-trade", + action="store_true", + help="Run a fresh auto-trade after seeding (requires Alpaca MCP).", + ) + parser.add_argument( + "--force", + action="store_true", + help="Skip confirmation prompt when deleting existing data.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + results_dir = Path(args.results_dir).expanduser().resolve() + hypothesis_dir = results_dir / "hypotheses" + autopilot_dir = results_dir / "autopilot" + memory_dir = Path(args.memory_dir).expanduser().resolve() + + targets = [hypothesis_dir, autopilot_dir, memory_dir] + existing = [path for path in targets if path.exists()] + if existing and not args.force: + response = input( + "This will delete existing hypotheses/autopilot/memory data. Continue? [y/N] " + ).strip() + if response.lower() not in {"y", "yes"}: + print("Aborted.") + sys.exit(1) + + for path in targets: + reset_directory(path) + + if args.skip_fixture: + print("State cleared. No fixture loaded (per --skip-fixture).") + else: + seed_fixture(Path(args.fixture), hypothesis_dir, autopilot_dir, memory_dir) + + if args.auto_trade: + run_auto_trade(results_dir) + + print("\nDone. You can now run `python main.py --autopilot` to test from the seeded state.") + + +def reset_directory(path: Path) -> None: + if path.exists(): + shutil.rmtree(path) + path.mkdir(parents=True, exist_ok=True) + + +def seed_fixture(fixture_path: Path, hypothesis_dir: Path, autopilot_dir: Path, memory_dir: Path) -> None: + if not fixture_path.exists(): + print(f"Fixture {fixture_path} not found; skipping seed.") + return + + with fixture_path.open("r", encoding="utf-8") as handle: + payload = json.load(handle) + + hypotheses = payload.get("hypotheses") or [] + if hypotheses: + out_path = hypothesis_dir / "hypotheses.json" + with out_path.open("w", encoding="utf-8") as handle: + json.dump(hypotheses, handle, indent=2) + print(f"Seeded {len(hypotheses)} hypothesis records → {out_path}") + + events = payload.get("events") or [] + if events: + out_path = autopilot_dir / "events.json" + autopilot_dir.mkdir(parents=True, exist_ok=True) + with out_path.open("w", encoding="utf-8") as handle: + json.dump(events, handle, indent=2) + print(f"Seeded {len(events)} autopilot events → {out_path}") + + memory_entries: Dict[str, Any] = payload.get("memory") or {} + if memory_entries: + memory_dir.mkdir(parents=True, exist_ok=True) + for ticker, entries in memory_entries.items(): + ticker_path = memory_dir / f"{ticker.upper()}.json" + with ticker_path.open("w", encoding="utf-8") as handle: + json.dump(entries, handle, indent=2) + print(f"Seeded memory for {len(memory_entries)} ticker(s) → {memory_dir}") + + +def run_auto_trade(results_dir: Path) -> None: + print("Running fresh auto-trade to capture new hypotheses …") + config = DEFAULT_CONFIG.copy() + graph = TradingAgentsGraph(debug=False, config=config, skip_initial_probes=True) + try: + account_service = AccountService(config.get("alpaca_mcp", {})) + snapshot = account_service.refresh() + except Exception as exc: # pragma: no cover - environment dependent + print(f"Failed to refresh account snapshot: {exc}") + return + + auto_trader = AutoTradeService(config=config, graph=graph) + try: + result: AutoTradeResult = auto_trader.run(snapshot) + except Exception as exc: # pragma: no cover - surfaced for operator + print(f"Auto-trade run failed: {exc}") + return + + store = HypothesisStore(results_dir / "hypotheses") + new_records = store.record_result(result) + print(f"Auto-trade run complete. Recorded {len(new_records)} new hypothesis entries.") + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 793df3e6..322716df 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ setup( "langchain-openai>=0.0.2", "langchain-experimental>=0.0.40", "langgraph>=0.0.20", + "mcp>=1.6.0", "numpy>=1.24.0", "pandas>=2.0.0", "praw>=7.7.0", diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index d84d9eb1..ffbe5cb4 100644 --- a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -16,6 +16,7 @@ from .risk_mgmt.neutral_debator import create_neutral_debator from .managers.research_manager import create_research_manager from .managers.risk_manager import create_risk_manager +from .managers.orchestrator import create_portfolio_orchestrator from .trader.trader import create_trader @@ -28,6 +29,7 @@ __all__ = [ "create_bear_researcher", "create_bull_researcher", "create_research_manager", + "create_portfolio_orchestrator", "create_fundamentals_analyst", "create_market_analyst", "create_neutral_debator", diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index e20139cb..47734254 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -1,27 +1,45 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json from tradingagents.agents.utils.agent_utils import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement, get_insider_sentiment, get_insider_transactions -from tradingagents.dataflows.config import get_config +from tradingagents.agents.utils.logging_utils import log_report_preview +from tradingagents.agents.utils.tool_runner import ( + extract_text_from_content, + run_chain_with_tools, +) def create_fundamentals_analyst(llm): def fundamentals_analyst_node(state): + scheduled_list = state.get("scheduled_analysts", []) or [] + scheduled_plan_list = state.get("scheduled_analysts_plan", []) or [] + scheduled_plan = {item.lower() for item in scheduled_plan_list} + action = state.get("orchestrator_action", "").lower() + ticker = state.get("target_ticker") or state["company_of_interest"] + if (scheduled_plan and "fundamentals" not in scheduled_plan) or action not in ("", "monitor", "escalate", "trade", "execute"): + try: + print(f"[Fundamentals Analyst] Skipping for {ticker} | action={action} | scheduled_plan={scheduled_plan}") + except Exception: + pass + return { + "fundamentals_report": "Fundamentals analyst skipped by orchestrator directive.", + "scheduled_analysts": [item for item in scheduled_list if item.lower() != "fundamentals"], + "scheduled_analysts_plan": scheduled_plan_list, + } current_date = state["trade_date"] - ticker = state["company_of_interest"] - company_name = state["company_of_interest"] + company_name = ticker tools = [ get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement, + get_insider_sentiment, + get_insider_transactions, ] system_message = ( "You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, and company financial history to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." + " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read." - + " Use the available tools: `get_fundamentals` for comprehensive company analysis, `get_balance_sheet`, `get_cashflow`, and `get_income_statement` for specific financial statements.", + + " Use the available tools: `get_fundamentals` for comprehensive company analysis, `get_balance_sheet`, `get_cashflow`, and `get_income_statement` for specific financial statements, plus `get_insider_sentiment` and `get_insider_transactions` for ownership context.", ) prompt = ChatPromptTemplate.from_messages( @@ -48,16 +66,44 @@ def create_fundamentals_analyst(llm): chain = prompt | llm.bind_tools(tools) - result = chain.invoke(state["messages"]) + try: + print(f"[Fundamentals Analyst] Running analysis for {ticker} | scheduled={scheduled_list}") + except Exception: + pass - report = "" + try: + result, message_history, tool_runs = run_chain_with_tools( + chain, + tools, + state["messages"], + ) + except Exception as exc: + updated_schedule = [item for item in scheduled_list if item.lower() != "fundamentals"] + report = f"Fundamentals analyst failed: {exc}" + return { + "messages": state["messages"], + "fundamentals_report": report, + "scheduled_analysts": updated_schedule, + "scheduled_analysts_plan": scheduled_plan_list, + } - if len(result.tool_calls) == 0: - report = result.content + report = extract_text_from_content(getattr(result, "content", "")) or "Fundamentals analyst produced no narrative." + updated_schedule = [item for item in scheduled_list if item.lower() != "fundamentals"] - return { - "messages": [result], + payload = { + "messages": message_history, "fundamentals_report": report, + "scheduled_analysts": updated_schedule, + "scheduled_analysts_plan": scheduled_plan_list, } + log_report_preview("Fundamentals Analyst", ticker, report) + try: + print( + f"[Fundamentals Analyst] Completed step for {ticker} | tool_runs={tool_runs} | report_len={len(report) if report else 0}" + ) + except Exception: + pass + + return payload return fundamentals_analyst_node diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index c955dd76..8f966876 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -1,16 +1,32 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json from tradingagents.agents.utils.agent_utils import get_stock_data, get_indicators -from tradingagents.dataflows.config import get_config +from tradingagents.agents.utils.logging_utils import log_report_preview +from tradingagents.agents.utils.tool_runner import ( + extract_text_from_content, + run_chain_with_tools, +) def create_market_analyst(llm): def market_analyst_node(state): + scheduled_list = state.get("scheduled_analysts", []) or [] + scheduled_plan_list = state.get("scheduled_analysts_plan", []) or [] + scheduled_plan = {item.lower() for item in scheduled_plan_list} + action = state.get("orchestrator_action", "").lower() + ticker = state.get("target_ticker") or state["company_of_interest"] + if (scheduled_plan and "market" not in scheduled_plan) or action not in ("", "monitor", "escalate", "trade", "execute"): + try: + print(f"[Market Analyst] Skipping for ticker {ticker} | action={action} | scheduled_plan={scheduled_plan}") + except Exception: + pass + return { + "market_report": "Market analyst skipped by orchestrator directive.", + "scheduled_analysts": [item for item in scheduled_list if item.lower() != "market"], + "scheduled_analysts_plan": scheduled_plan_list, + } current_date = state["trade_date"] - ticker = state["company_of_interest"] - company_name = state["company_of_interest"] + company_name = ticker tools = [ get_stock_data, @@ -70,16 +86,44 @@ Volume-Based Indicators: chain = prompt | llm.bind_tools(tools) - result = chain.invoke(state["messages"]) + try: + print(f"[Market Analyst] Running analysis for {ticker} | action={action} | scheduled={scheduled_list}") + except Exception: + pass - report = "" + try: + result, message_history, tool_runs = run_chain_with_tools( + chain, + tools, + state["messages"], + ) + except Exception as exc: + report = f"Market analyst failed: {exc}" + updated_schedule = [item for item in scheduled_list if item.lower() != "market"] + return { + "messages": state["messages"], + "market_report": report, + "scheduled_analysts": updated_schedule, + "scheduled_analysts_plan": scheduled_plan_list, + } - if len(result.tool_calls) == 0: - report = result.content - - return { - "messages": [result], + report = extract_text_from_content(getattr(result, "content", "")) or "Market analyst produced no narrative." + updated_schedule = [item for item in scheduled_list if item.lower() != "market"] + + payload = { + "messages": message_history, "market_report": report, + "scheduled_analysts": updated_schedule, + "scheduled_analysts_plan": scheduled_plan_list, } + log_report_preview("Market Analyst", ticker, report) + try: + print( + f"[Market Analyst] Completed step for {ticker} | tool_runs={tool_runs} | report_len={len(report) if report else 0}" + ) + except Exception: + pass + + return payload return market_analyst_node diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index 03b4fae4..eb16070c 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -1,14 +1,26 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json from tradingagents.agents.utils.agent_utils import get_news, get_global_news -from tradingagents.dataflows.config import get_config +from tradingagents.agents.utils.logging_utils import log_report_preview +from tradingagents.agents.utils.tool_runner import ( + extract_text_from_content, + run_chain_with_tools, +) def create_news_analyst(llm): def news_analyst_node(state): + scheduled_list = state.get("scheduled_analysts", []) or [] + scheduled_plan_list = state.get("scheduled_analysts_plan", []) or [] + scheduled_plan = {item.lower() for item in scheduled_plan_list} + action = state.get("orchestrator_action", "").lower() + ticker = state.get("target_ticker") or state["company_of_interest"] + if (scheduled_plan and "news" not in scheduled_plan) or action not in ("", "monitor", "escalate", "trade", "execute"): + return { + "news_report": "News analyst skipped by orchestrator directive.", + "scheduled_analysts": [item for item in scheduled_list if item.lower() != "news"], + "scheduled_analysts_plan": scheduled_plan_list, + } current_date = state["trade_date"] - ticker = state["company_of_interest"] tools = [ get_news, @@ -43,16 +55,44 @@ def create_news_analyst(llm): prompt = prompt.partial(ticker=ticker) chain = prompt | llm.bind_tools(tools) - result = chain.invoke(state["messages"]) + try: + print(f"[News Analyst] Running analysis for {ticker} | scheduled={scheduled_list}") + except Exception: + pass - report = "" + try: + result, message_history, tool_runs = run_chain_with_tools( + chain, + tools, + state["messages"], + ) + except Exception as exc: + updated_schedule = [item for item in scheduled_list if item.lower() != "news"] + report = f"News analyst failed: {exc}" + return { + "messages": state["messages"], + "news_report": report, + "scheduled_analysts": updated_schedule, + "scheduled_analysts_plan": scheduled_plan_list, + } - if len(result.tool_calls) == 0: - report = result.content + report = extract_text_from_content(getattr(result, "content", "")) or "News analyst produced no narrative." + updated_schedule = [item for item in scheduled_list if item.lower() != "news"] - return { - "messages": [result], + payload = { + "messages": message_history, "news_report": report, + "scheduled_analysts": updated_schedule, + "scheduled_analysts_plan": scheduled_plan_list, } + log_report_preview("News Analyst", ticker, report) + try: + print( + f"[News Analyst] Completed step for {ticker} | tool_runs={tool_runs} | report_len={len(report) if report else 0}" + ) + except Exception: + pass + + return payload return news_analyst_node diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index b25712d7..b93a7f8d 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -2,14 +2,25 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder import time import json from tradingagents.agents.utils.agent_utils import get_news +from tradingagents.agents.utils.logging_utils import log_report_preview from tradingagents.dataflows.config import get_config def create_social_media_analyst(llm): def social_media_analyst_node(state): + scheduled_list = state.get("scheduled_analysts", []) or [] + scheduled_plan_list = state.get("scheduled_analysts_plan", []) or [] + scheduled_plan = {item.lower() for item in scheduled_plan_list} + action = state.get("orchestrator_action", "").lower() + ticker = state.get("target_ticker") or state["company_of_interest"] + if (scheduled_plan and "social" not in scheduled_plan) or action not in ("", "monitor", "escalate", "trade", "execute"): + return { + "sentiment_report": "Social media analyst skipped by orchestrator directive.", + "scheduled_analysts": [item for item in scheduled_list if item.lower() != "social"], + "scheduled_analysts_plan": scheduled_plan_list, + } current_date = state["trade_date"] - ticker = state["company_of_interest"] - company_name = state["company_of_interest"] + company_name = ticker tools = [ get_news, @@ -44,6 +55,11 @@ def create_social_media_analyst(llm): chain = prompt | llm.bind_tools(tools) + try: + print(f"[Social Analyst] Running analysis for {ticker} | scheduled={scheduled_list}") + except Exception: + pass + result = chain.invoke(state["messages"]) report = "" @@ -51,9 +67,21 @@ def create_social_media_analyst(llm): if len(result.tool_calls) == 0: report = result.content - return { + updated_schedule = [item for item in scheduled_list if item.lower() != "social"] + + payload = { "messages": [result], "sentiment_report": report, + "scheduled_analysts": updated_schedule, + "scheduled_analysts_plan": scheduled_plan_list, } + if report is not None: + log_report_preview("Social Analyst", ticker, report) + try: + print(f"[Social Analyst] Completed step for {ticker} | tool_calls={len(result.tool_calls)} | report_len={len(report) if report else 0}") + except Exception: + pass + + return payload return social_media_analyst_node diff --git a/tradingagents/agents/managers/orchestrator.py b/tradingagents/agents/managers/orchestrator.py new file mode 100644 index 00000000..0b97b34b --- /dev/null +++ b/tradingagents/agents/managers/orchestrator.py @@ -0,0 +1,619 @@ +"""Advanced portfolio orchestrator that prioritizes hypotheses and schedules analysts.""" + +from __future__ import annotations + +import json +import re +from datetime import date, datetime, timedelta +from typing import Any, Callable, Dict, List, Optional + +from tradingagents.dataflows.interface import route_to_vendor + + +def create_portfolio_orchestrator( + llm: Any, + profile: Dict[str, object], + context_fetcher: Callable[[List[str]], List[Dict[str, str]]], + fast_news_fetcher: Callable[[str, str, int, int], Dict[str, str]], + plan_generator: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, +): + """Create an orchestrator node that generates and manages portfolio hypotheses.""" + + profile_name = profile.get("profile_name", "Portfolio") + mandate = profile.get("mandate", "") + risk_limits = profile.get("risk_limits", {}) + notes = profile.get("notes", "") + universe_raw = profile.get("universe", "") + universe = [symbol.strip().upper() for symbol in universe_raw.split(",") if symbol.strip()] + sentiment_lookback = profile.get("sentiment_lookback_days", 2) + headline_limit = profile.get("news_headline_limit", 5) + market_lookback = profile.get("market_data_lookback_days", 30) + trade_policy = profile.get("trade_activation", {}) + threshold = profile.get("hypothesis_threshold", 0.6) + max_hypotheses = profile.get("max_concurrent_hypotheses", 2) + + risk_blurb = ( + f"Max single position: {risk_limits.get('max_single_position_pct', 'n/a')} | " + f"Max sector: {risk_limits.get('max_sector_pct', 'n/a')}" + ) + + def _parse_account(account_text: str) -> Dict[str, str]: + summary = {} + pattern = re.compile(r"^([A-Za-z ]+):\s*\$?([\d\.,\-]+)") + for line in account_text.splitlines(): + match = pattern.match(line.strip()) + if match: + key = match.group(1).strip().lower().replace(" ", "_") + summary[key] = match.group(2).strip() + return summary + + def _parse_positions(positions_text: str) -> List[Dict[str, str]]: + positions: List[Dict[str, str]] = [] + current: Dict[str, str] = {} + for line in positions_text.splitlines(): + line = line.strip() + if not line: + continue + if line.startswith("Symbol:"): + if current: + positions.append(current) + current = {} + if ":" in line: + key, value = line.split(":", 1) + current[key.strip().lower().replace(" ", "_")] = value.strip() + if current: + positions.append(current) + return positions + + def _to_float(value: Any) -> float: + if value is None: + return 0.0 + if isinstance(value, (int, float)): + return float(value) + text = str(value).strip() + if not text: + return 0.0 + cleaned = text.replace("$", "").replace(",", "") + try: + return float(cleaned) + except ValueError: + return 0.0 + + def orchestrator_node(state): + override_symbols_raw = state.get("orchestrator_focus_override") or [] + override_symbols = [str(sym).upper() for sym in override_symbols_raw if str(sym).strip()] + + if override_symbols: + symbols = override_symbols.copy() + else: + symbols = [sym.upper() for sym in dict.fromkeys(universe)] + pending_override = [str(sym).upper() for sym in state.get("orchestrator_pending_tickers", []) if sym] + incumbent = state.get("company_of_interest") + if incumbent and incumbent.upper() not in symbols: + symbols.append(incumbent.upper()) + + # Step 1: Gather live portfolio context for the universe + snapshots = context_fetcher(symbols) + quick_signals: Dict[str, Dict[str, str]] = { + str(key).upper(): value for key, value in (state.get("orchestrator_quick_signals", {}) or {}).items() + } + market_data_cache: Dict[str, str] = { + str(key).upper(): value for key, value in (state.get("orchestrator_market_data", {}) or {}).items() + } + + first_snapshot = snapshots[0] if snapshots else {} + # Step 3: Ask the LLM to synthesize hypotheses and routing plan + system_prompt = ( + "You are the head of trading. Review portfolio context, mandate, and quick signals to propose trading " + "hypotheses. Each hypothesis must include: ticker, rationale, priority (0-1), required_analysts list (subset " + "of ['market','social','news','fundamentals']), and immediate actions (monitor, abandon, escalate). Limit to the " + f"strongest {max_hypotheses} hypotheses above priority {threshold}. Respond with valid JSON containing keys " + "'hypotheses' (list), 'summary' (string), and 'status' (string)." + ) + + account_summary = _parse_account(first_snapshot.get("account", "")) if first_snapshot else {} + positions_summary = _parse_positions(first_snapshot.get("positions", "")) if first_snapshot else [] + buying_power_value = _to_float(account_summary.get("buying_power") or account_summary.get("buying_power_usd")) + cash_value = _to_float(account_summary.get("cash") or account_summary.get("cash_usd")) + portfolio_value = _to_float( + account_summary.get("portfolio_value") + or account_summary.get("equity") + or account_summary.get("equity_value") + ) + current_holdings = set() + for item in positions_summary: + symbol = (item.get("symbol") or item.get("symbol:") or "").upper() + qty_raw = item.get("quantity", "") + qty_value = 0.0 + if qty_raw: + match = re.search(r"[-+]?\d*\.?\d+", qty_raw.replace(",", "")) + if match: + try: + qty_value = float(match.group(0)) + except ValueError: + qty_value = 0.0 + if symbol and qty_value != 0.0: + current_holdings.add(symbol) + universe_gap = [sym for sym in symbols if sym not in current_holdings] + + payload = { + "profile": { + "name": profile_name, + "mandate": mandate, + "risk_limits": risk_limits, + "notes": notes, + "risk_summary": risk_blurb, + }, + "portfolio_snapshots": snapshots, + "quick_signals": {}, + "current_hypotheses": state.get("orchestrator_hypotheses", []), + "account_summary": account_summary, + "positions": positions_summary, + "existing_holdings": list(current_holdings), + "universe_candidates": universe_gap, + "trade_policy": trade_policy, + } + + response = llm.invoke( + [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": json.dumps(payload)}, + ] + ) + + try: + parsed = json.loads(response.content if hasattr(response, "content") else str(response)) + except json.JSONDecodeError: + parsed = {"hypotheses": [], "notes": "Failed to parse orchestrator response."} + + hypotheses_raw = parsed.get("hypotheses", [])[:max_hypotheses] + if override_symbols: + allowed = set(override_symbols) + filtered = [item for item in hypotheses_raw if isinstance(item, dict) and str(item.get("ticker", "")).upper() in allowed] + if filtered: + hypotheses = filtered + else: + hypotheses = [hypotheses_raw[0]] if hypotheses_raw else [] + else: + hypotheses = hypotheses_raw + summary_text = parsed.get("summary") or first_snapshot.get("summary_prompt", "") + status_text = parsed.get("status") or first_snapshot.get("status", "ok") + + trade_priority_threshold = float(trade_policy.get("priority_threshold", 0.8)) + min_cash_abs = float(trade_policy.get("min_cash_absolute", 0)) + min_cash_ratio = float(trade_policy.get("min_cash_ratio", 0)) + min_cash_required = max(min_cash_abs, portfolio_value * min_cash_ratio) + + if not hypotheses: + fallback_ticker = None + if override_symbols: + fallback_ticker = override_symbols[0] + elif symbols: + fallback_ticker = symbols[0] + elif incumbent: + fallback_ticker = incumbent.upper() + elif state.get("company_of_interest"): + fallback_ticker = str(state.get("company_of_interest")).upper() + + if fallback_ticker: + fallback_required = ["market", "news", "fundamentals"] + can_trade = buying_power_value >= min_cash_required + fallback_action = "trade" if can_trade else "escalate" + fallback_priority = max(trade_priority_threshold, 0.8) + fallback = { + "ticker": fallback_ticker, + "rationale": ( + f"Generated fallback hypothesis for {fallback_ticker} to evaluate trading opportunities with available buying power " + f"of ${buying_power_value:,.0f}. Analysts should validate signals before execution." + ), + "priority": fallback_priority, + "required_analysts": fallback_required, + "immediate_actions": fallback_action, + } + hypotheses = [fallback] + if override_symbols and fallback_ticker not in override_symbols: + override_symbols.append(fallback_ticker) + if fallback_ticker not in symbols: + symbols.insert(0, fallback_ticker) + + candidate_pool = override_symbols or symbols + + try: + print(f"[Orchestrator] Override focus: {override_symbols if override_symbols else ''}") + except Exception: + pass + + focus_symbols: List[str] = [] + hypothesis_map: Dict[str, Dict[str, Any]] = {} + for item in hypotheses: + ticker_value = str(item.get("ticker", "")).upper() if isinstance(item, dict) else "" + if ticker_value and ticker_value not in focus_symbols: + focus_symbols.append(ticker_value) + if isinstance(item, dict): + hypothesis_map[ticker_value] = item + if ticker_value and isinstance(item, dict): + hypothesis_map.setdefault(ticker_value, item) + for holding in current_holdings: + if holding not in focus_symbols: + focus_symbols.append(holding) + if incumbent: + incumbent_u = incumbent.upper() + if incumbent_u and incumbent_u not in focus_symbols: + focus_symbols.append(incumbent_u) + if not focus_symbols: + focus_symbols = symbols[: max_hypotheses or 1] + else: + for symbol in symbols: + if symbol not in focus_symbols: + focus_symbols.append(symbol) + + if pending_override: + ordered_focus: List[str] = [] + for sym in pending_override: + if sym in focus_symbols and sym not in ordered_focus: + ordered_focus.append(sym) + for sym in focus_symbols: + if sym not in ordered_focus: + ordered_focus.append(sym) + focus_symbols = ordered_focus + + try: + print(f"[Orchestrator] Focus tickers: {', '.join(focus_symbols)}") + except Exception: # pragma: no cover - defensive logging + pass + + if pending_override: + active_ticker = pending_override[0] + elif focus_symbols: + active_ticker = focus_symbols[0] + else: + active_ticker = state.get("target_ticker", state.get("company_of_interest", "")).upper() + + active = hypothesis_map.get(active_ticker) + if active is None and hypotheses: + fallback = hypotheses[0] + if isinstance(fallback, dict) and fallback.get("ticker"): + active_ticker = str(fallback.get("ticker")).upper() + active = fallback + + scheduled_sequence: List[str] = [] + if active and isinstance(active, dict): + for analyst in active.get("required_analysts", []): + role = str(analyst).lower() + if role not in scheduled_sequence: + scheduled_sequence.append(role) + immediate_action = (active.get("immediate_actions") or active.get("action") or "").lower() if isinstance(active, dict) else "" + + priority_val = float(active.get("priority") or 0) if isinstance(active, dict) else 0.0 + + holding_active = active_ticker in current_holdings + if isinstance(active, dict) and immediate_action in {"", "monitor"}: + if priority_val >= trade_priority_threshold: + if not holding_active and buying_power_value >= min_cash_required: + immediate_action = "trade" + active["immediate_actions"] = "trade" + else: + immediate_action = "escalate" + active["immediate_actions"] = "escalate" + + default_analysis_order = ["market", "news", "social", "fundamentals"] + analysis_candidates = [ + item for item in scheduled_sequence if item in {"market", "social", "news", "fundamentals"} + ] + try: + if analysis_candidates: + print(f"[Orchestrator] Hypothesis requests analysts {analysis_candidates} for {active_ticker}.") + else: + print(f"[Orchestrator] No explicit analyst order from hypothesis; defaulting to {default_analysis_order}.") + except Exception: + pass + + try: + print( + f"[Orchestrator] Active {active_ticker or ''} priority {priority_val:.2f} -> {immediate_action.upper()} " + f"(cash ${cash_value:,.0f}, buying power ${buying_power_value:,.0f}, min cash ${min_cash_required:,.0f})" + ) + except Exception: + pass + + planner_raw: Dict[str, Any] = {} + planner_actions: List[str] = [] + planner_immediate: str = "" + planner_notes: str = "" + planner_next_directive: Optional[str] = None + ticker_plan_summaries: Dict[str, Any] = dict(state.get("orchestrator_ticker_plans", {}) or {}) + + ACTION_ALIASES = { + "market": "market", + "market_analyst": "market", + "run_market": "market", + "run_market_analyst": "market", + "news": "news", + "news_analyst": "news", + "run_news": "news", + "social": "social", + "social_analyst": "social", + "fundamentals": "fundamentals", + "fundamental": "fundamentals", + "fundamentals_analyst": "fundamentals", + "debate": "debate", + "research_debate": "debate", + "manager": "manager", + "research_manager": "manager", + "trader": "trader", + "risk": "risk", + "risk_manager": "risk", + "orchestrator": "orchestrator", + "stop": "end", + "end": "end", + } + + def normalise_action(value: str) -> Optional[str]: + key = (value or "").strip().lower() + return ACTION_ALIASES.get(key) + + def append_unique(target: List[str], items: List[str]) -> None: + seen = set(target) + for elem in items: + if elem not in seen: + target.append(elem) + seen.add(elem) + + trade_date = state.get("trade_date", "") + + def _coerce_trade_date(value: str) -> date: + if not value: + return date.today() + try: + dt_value = datetime.fromisoformat(value) + except ValueError: + try: + dt_value = datetime.fromisoformat(f"{value}T00:00:00") + except ValueError: + dt_value = datetime.today() + return dt_value.date() + + trade_date_obj = _coerce_trade_date(trade_date) + start_dt = trade_date_obj - timedelta(days=max(1, int(market_lookback))) + start_date_str = start_dt.isoformat() + end_date_str = trade_date_obj.isoformat() + + market_data: Dict[str, str] = dict(market_data_cache) + + for symbol in focus_symbols: + cache_key = symbol.upper() + if cache_key not in quick_signals or not quick_signals.get(cache_key): + try: + quick_signals[cache_key] = fast_news_fetcher(symbol, trade_date, sentiment_lookback, headline_limit) + except Exception as exc: # pragma: no cover - vendor failure is non-critical + quick_signals[cache_key] = {"error": str(exc)} + if cache_key not in market_data or not market_data.get(cache_key): + try: + market_data[cache_key] = route_to_vendor("get_stock_data", symbol, start_date_str, end_date_str) + except Exception as exc: # pragma: no cover - vendor failure is non-critical + market_data[cache_key] = f"Error fetching market data: {exc}" + + try: + print( + "[Orchestrator] Retrieved market data for: " + + ", ".join(sorted(market_data.keys())) + ) + except Exception: # pragma: no cover - defensive logging + pass + + existing_plan = ticker_plan_summaries.get(active_ticker or "", {}) if active_ticker else {} + if plan_generator is not None: + planner_payload = { + "profile": payload["profile"], + "hypotheses": hypotheses, + "active_hypothesis": active, + "summary": summary_text, + "status": status_text, + "account_summary": account_summary, + "positions_summary": positions_summary, + "portfolio_snapshots": snapshots, + "quick_signals": quick_signals, + "market_data": market_data, + "focus_symbols": focus_symbols, + "focus_symbol": active_ticker, + "trade_policy": trade_policy, + "buying_power": buying_power_value, + "cash_available": cash_value, + "portfolio_value": portfolio_value, + } + if existing_plan: + planner_raw = existing_plan + else: + planner_result = plan_generator(planner_payload) or {} + if isinstance(planner_result, dict): + planner_raw = planner_result + else: + planner_raw = {"text": str(planner_result)} + ticker_plan_summaries[active_ticker or ""] = planner_raw if planner_raw else {} + + if isinstance(planner_raw, dict): + plan_structured = planner_raw.get("structured") + planner_notes = planner_raw.get("text") or planner_raw.get("notes", "") + if plan_structured is None: + plan_text = planner_raw.get("text") + if plan_text: + try: + plan_structured = json.loads(plan_text) + except json.JSONDecodeError: + plan_structured = None + if plan_structured is None: + plan_structured = planner_raw.get("plan") + else: + plan_structured = None + + if isinstance(plan_structured, str): + try: + plan_structured = json.loads(plan_structured) + except json.JSONDecodeError: + plan_structured = None + + if isinstance(plan_structured, dict): + raw_actions = plan_structured.get("actions") or plan_structured.get("steps") or [] + planner_immediate = str(plan_structured.get("next_decision") or "").lower() + planner_notes = planner_notes or plan_structured.get("notes", "") + planner_next_directive = plan_structured.get("next_directive") + elif isinstance(plan_structured, list): + raw_actions = plan_structured + else: + raw_actions = [] + + if planner_immediate and not immediate_action: + immediate_action = planner_immediate + + for item in raw_actions: + if isinstance(item, str): + normalized = normalise_action(item) + elif isinstance(item, dict): + action_value = item.get("action") or item.get("name") or item.get("tool") + normalized = normalise_action(str(action_value)) if action_value else None + if not planner_notes and item.get("reason"): + planner_notes = str(item.get("reason")) + else: + normalized = None + if normalized: + planner_actions.append(normalized) + + for symbol in focus_symbols: + if symbol == active_ticker: + continue + if override_symbols and symbol not in override_symbols: + continue + extra_payload = { + "profile": payload["profile"], + "hypotheses": hypotheses, + "active_hypothesis": hypothesis_map.get(symbol), + "summary": summary_text, + "status": status_text, + "account_summary": account_summary, + "positions_summary": positions_summary, + "portfolio_snapshots": snapshots, + "quick_signals": quick_signals, + "market_data": market_data, + "focus_symbols": focus_symbols, + "focus_symbol": symbol, + "trade_policy": trade_policy, + "buying_power": buying_power_value, + "cash_available": cash_value, + "portfolio_value": portfolio_value, + } + if symbol in ticker_plan_summaries and ticker_plan_summaries[symbol]: + continue + try: + extra_result = plan_generator(extra_payload) or {} + except Exception as exc: # pragma: no cover - planner failures shouldn't halt orchestration + extra_result = {"error": str(exc)} + if isinstance(extra_result, dict): + ticker_plan_summaries[symbol] = extra_result + else: + ticker_plan_summaries[symbol] = {"text": str(extra_result)} + + action_queue: List[str] = [] + if planner_actions: + append_unique(action_queue, planner_actions) + try: + print(f"[Orchestrator] Sequential planner actions for {active_ticker}: {planner_actions}") + except Exception: + pass + else: + append_unique(action_queue, analysis_candidates or default_analysis_order) + + if immediate_action in {"escalate", "trade", "execute"}: + append_unique(action_queue, ["debate"]) + append_unique(action_queue, ["manager"]) + append_unique(action_queue, ["trader"]) + if immediate_action in {"trade", "execute"}: + append_unique(action_queue, ["risk"]) + + next_directive_value = planner_next_directive or "end" + if override_symbols: + remaining_focus = [] + next_directive_value = "end" + else: + remaining_focus = [sym for sym in focus_symbols if sym != active_ticker] + if remaining_focus and next_directive_value == "end": + next_directive_value = "orchestrator" + + try: + print( + "[Orchestrator] Queue:" + f" focus={focus_symbols} | active={active_ticker} | immediate={immediate_action} | queue={action_queue}" + ) + print( + "[Orchestrator] Ticker plans generated for: " + + ", ".join(sorted(ticker_plan_summaries.keys())) + ) + except Exception: # pragma: no cover - defensive logging + pass + + # Log initial hypothesis for debugging + try: + formatted = json.dumps( + { + "hypotheses": hypotheses, + "summary": summary_text, + "status": status_text, + "action": immediate_action, + "queue": action_queue, + "planner_plan": planner_raw, + }, + indent=2, + ) + print(f"[Orchestrator] Initial hypotheses:\n{formatted}") + except Exception: + pass + + analyst_schedule = [item for item in action_queue if item in {"market", "social", "news", "fundamentals"}] + if planner_actions: + plan_snapshot = planner_actions + else: + plan_snapshot = analysis_candidates or [item for item in default_analysis_order if item in action_queue] + + serializable_plan = planner_raw + if isinstance(planner_raw, dict): + serializable_plan = {k: v for k, v in planner_raw.items() if k != "raw"} + if "raw" in planner_raw: + try: + serializable_plan.setdefault("raw_repr", repr(planner_raw["raw"])) + except Exception: + serializable_plan.setdefault("raw_repr", "") + + state_update = { + "messages": [response], + "portfolio_profile": profile, + "portfolio_summary": summary_text, + "orchestrator_status": status_text, + "alpaca_account_text": first_snapshot.get("account", ""), + "alpaca_positions_text": first_snapshot.get("positions", ""), + "alpaca_orders_text": first_snapshot.get("orders", ""), + "orchestrator_hypotheses": hypotheses, + "active_hypothesis": active, + "scheduled_analysts": analyst_schedule, + "scheduled_analysts_plan": plan_snapshot, + "company_of_interest": active_ticker or state.get("company_of_interest"), + "target_ticker": active_ticker, + "orchestrator_action": immediate_action, + "portfolio_account_summary": account_summary, + "portfolio_positions_summary": positions_summary, + "orchestrator_focus_symbols": focus_symbols, + "orchestrator_quick_signals": quick_signals, + "orchestrator_market_data": market_data, + "orchestrator_ticker_plans": ticker_plan_summaries, + "orchestrator_pending_tickers": remaining_focus, + "orchestrator_buying_power": buying_power_value, + "orchestrator_cash_available": cash_value, + "orchestrator_portfolio_value": portfolio_value, + "action_queue": action_queue, + "next_directive": next_directive_value, + "planner_plan": serializable_plan, + "planner_notes": planner_notes, + "orchestrator_focus_override": override_symbols, + } + + return state_update + + return orchestrator_node diff --git a/tradingagents/agents/managers/research_manager.py b/tradingagents/agents/managers/research_manager.py index c537fa2f..37346849 100644 --- a/tradingagents/agents/managers/research_manager.py +++ b/tradingagents/agents/managers/research_manager.py @@ -4,7 +4,16 @@ import json def create_research_manager(llm, memory): def research_manager_node(state) -> dict: + active = state.get("active_hypothesis") + if active: + action = (active.get("immediate_action") or active.get("action") or "").lower() + if action not in {"escalate", "trade", "execute"}: + return { + "investment_debate_state": state.get("investment_debate_state", {}), + "investment_plan": state.get("investment_plan", ""), + } history = state["investment_debate_state"].get("history", "") + _ = state.get("target_ticker") or state.get("company_of_interest") market_research_report = state["market_report"] sentiment_report = state["sentiment_report"] news_report = state["news_report"] diff --git a/tradingagents/agents/managers/risk_manager.py b/tradingagents/agents/managers/risk_manager.py index fba763d6..e248e8a0 100644 --- a/tradingagents/agents/managers/risk_manager.py +++ b/tradingagents/agents/managers/risk_manager.py @@ -4,8 +4,16 @@ import json def create_risk_manager(llm, memory): def risk_manager_node(state) -> dict: + active = state.get("active_hypothesis") + if active: + action = (active.get("immediate_action") or active.get("action") or "").lower() + if action not in {"escalate", "trade", "execute"}: + return { + "risk_debate_state": state.get("risk_debate_state", {}), + "final_trade_decision": state.get("final_trade_decision", "Final decision withheld by orchestrator."), + } - company_name = state["company_of_interest"] + company_name = state.get("target_ticker") or state["company_of_interest"] history = state["risk_debate_state"]["history"] risk_debate_state = state["risk_debate_state"] diff --git a/tradingagents/agents/trader/trader.py b/tradingagents/agents/trader/trader.py index 1b05c35d..bd067647 100644 --- a/tradingagents/agents/trader/trader.py +++ b/tradingagents/agents/trader/trader.py @@ -5,14 +5,27 @@ import json def create_trader(llm, memory): def trader_node(state, name): - company_name = state["company_of_interest"] + active = state.get("active_hypothesis") + if active: + action = (active.get("immediate_action") or active.get("action") or "").lower() + if action not in {"escalate", "trade", "execute"}: + return { + "trader_investment_plan": state.get("trader_investment_plan", ""), + "sender": name, + } + company_name = state.get("target_ticker") or state["company_of_interest"] investment_plan = state["investment_plan"] market_research_report = state["market_report"] sentiment_report = state["sentiment_report"] news_report = state["news_report"] fundamentals_report = state["fundamentals_report"] + portfolio_summary = state.get("portfolio_summary", "") - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" + curr_situation = ( + f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" + ) + if portfolio_summary: + curr_situation += f"\n\nPortfolio Briefing:\n{portfolio_summary}" past_memories = memory.get_memories(curr_situation, n_matches=2) past_memory_str = "" @@ -22,9 +35,22 @@ def create_trader(llm, memory): else: past_memory_str = "No past memories found." + context_message = ( + f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. " + "This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. " + "Use this plan as a foundation for evaluating your next trading decision." + f"\n\nProposed Investment Plan: {investment_plan}" + ) + if portfolio_summary: + context_message += ( + "\n\nCurrent Portfolio Briefing:\n" + f"{portfolio_summary}\n" + "Ensure any trade recommendation respects buying power, risk limits, and existing exposures." + ) + context = { "role": "user", - "content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.", + "content": context_message, } messages = [ diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index 3a859ea1..d35ad4ca 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -1,4 +1,4 @@ -from typing import Annotated, Sequence +from typing import Annotated, Dict, List, Sequence from datetime import date, timedelta, datetime from typing_extensions import TypedDict, Optional from langchain_openai import ChatOpenAI @@ -50,6 +50,7 @@ class RiskDebateState(TypedDict): class AgentState(MessagesState): company_of_interest: Annotated[str, "Company that we are interested in trading"] trade_date: Annotated[str, "What date we are trading at"] + target_ticker: Annotated[str, "Ticker selected by the orchestrator for deep analysis"] sender: Annotated[str, "Agent that sent this message"] @@ -74,3 +75,22 @@ class AgentState(MessagesState): RiskDebateState, "Current state of the debate on evaluating risk" ] final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"] + portfolio_profile: Annotated[Dict[str, object], "Static portfolio configuration"] + portfolio_summary: Annotated[str, "Orchestrator briefing for the run"] + orchestrator_status: Annotated[str, "Status message produced by the orchestrator"] + alpaca_account_text: Annotated[str, "Raw Alpaca account summary text"] + alpaca_positions_text: Annotated[str, "Raw Alpaca positions text"] + alpaca_orders_text: Annotated[str, "Raw Alpaca recent orders text"] + orchestrator_hypotheses: Annotated[List[Dict[str, object]], "List of active hypotheses evaluated by the orchestrator"] + active_hypothesis: Annotated[Optional[Dict[str, object]], "The hypothesis currently under deep analysis"] + scheduled_analysts: Annotated[List[str], "Analyst roles the orchestrator requested to run for the active hypothesis"] + scheduled_analysts_plan: Annotated[List[str], "Full list of orchestrator-requested analysts for this cycle"] + orchestrator_action: Annotated[str, "Immediate action directive from the orchestrator for the active hypothesis"] + action_queue: Annotated[List[str], "Actions queued by the orchestrator for execution"] + next_directive: Annotated[str, "Directive for the scheduler when the queue is empty"] + next_node: Annotated[str, "Next node selected by the action scheduler"] + portfolio_account_summary: Annotated[Dict[str, object], "Parsed Alpaca account summary used by the orchestrator"] + portfolio_positions_summary: Annotated[List[Dict[str, object]], "Parsed Alpaca positions summary used by the orchestrator"] + planner_plan: Annotated[Dict[str, object], "Raw plan response produced by the sequential planner"] + planner_notes: Annotated[str, "Notes or commentary returned by the sequential planner"] + orchestrator_focus_override: Annotated[List[str], "Optional override forcing orchestrator to focus on specific tickers"] diff --git a/tradingagents/agents/utils/logging_utils.py b/tradingagents/agents/utils/logging_utils.py new file mode 100644 index 00000000..283d25f3 --- /dev/null +++ b/tradingagents/agents/utils/logging_utils.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +def log_report_preview(agent_name: str, ticker: str, report: str, *, max_chars: int = 800) -> None: + """Log the agent's report in full so the operator can read the entire analysis.""" + if not report: + try: + print(f"[{agent_name}] No direct narrative output returned for {ticker} (tool-only response).") + except Exception: + pass + return + + try: + print(f"[{agent_name}] Report for {ticker}:\n{report.strip()}\n") + except Exception: + pass diff --git a/tradingagents/agents/utils/tool_runner.py b/tradingagents/agents/utils/tool_runner.py new file mode 100644 index 00000000..5a2ef736 --- /dev/null +++ b/tradingagents/agents/utils/tool_runner.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import json +from typing import Any, Dict, Iterable, List, Tuple + +from langchain_core.messages import BaseMessage, ToolMessage +from langchain_core.runnables import Runnable + + +def run_chain_with_tools( + chain: Runnable, + tools: Iterable[Any], + initial_messages: Iterable[BaseMessage], + *, + max_iterations: int = 6, + logger: Any = None, +) -> Tuple[Any, List[BaseMessage], int]: + """Execute a LangChain Runnable, fulfilling any tool calls until text is produced.""" + + messages: List[BaseMessage] = list(initial_messages) + tool_map: Dict[str, Any] = { + getattr(tool, "name", ""): tool for tool in tools if getattr(tool, "name", None) + } + last_result: Any = None + tool_runs = 0 + + for _ in range(max_iterations): + last_result = chain.invoke(messages) + messages.append(last_result) + tool_calls = getattr(last_result, "tool_calls", None) or [] + if not tool_calls: + return last_result, messages, tool_runs + + for call in tool_calls: + tool_runs += 1 + tool_name = call.get("name") or call.get("tool_name") or "" + raw_args = call.get("args") or call.get("arguments") or {} + if isinstance(raw_args, str): + try: + tool_args = json.loads(raw_args) + except json.JSONDecodeError: + tool_args = {"raw": raw_args} + else: + tool_args = raw_args + + tool = tool_map.get(tool_name) + if not tool: + output = {"error": f"Tool '{tool_name}' unavailable."} + else: + try: + output = tool.invoke(tool_args) + except Exception as exc: # pragma: no cover - defensive logging + if logger: + logger.warning("Tool %s failed: %s", tool_name, exc) + output = {"error": str(exc)} + + messages.append( + ToolMessage( + content=_stringify(output), + tool_call_id=call.get("id") or call.get("tool_call_id") or tool_name or "tool-call", + name=tool_name or None, + ) + ) + + raise RuntimeError("Tool loop exceeded max iterations before producing a response.") + + +def extract_text_from_content(content: Any) -> str: + """Normalize structured message content into printable text.""" + + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: List[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict): + if "text" in item and isinstance(item["text"], str): + parts.append(item["text"]) + elif "content" in item and isinstance(item["content"], str): + parts.append(item["content"]) + else: + parts.append(str(item)) + return "\n".join(part.strip() for part in parts if part).strip() + if isinstance(content, dict): + return json.dumps(content, default=str) + return str(content) + + +def _stringify(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + try: + return json.dumps(value, default=str) + except TypeError: # pragma: no cover - fallback + return str(value) diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 4cd5ddef..916a7f54 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -138,10 +138,22 @@ def get_vendor(category: str, method: str = None) -> str: # Fall back to category-level configuration return config.get("data_vendors", {}).get(category, "default") +def _vendor_log(enabled: bool, message: str, level: str = "DEBUG") -> None: + if not enabled: + return + try: + print(f"{level}: {message}") + except Exception: + pass + + def route_to_vendor(method: str, *args, **kwargs): """Route method calls to appropriate vendor implementation with fallback support.""" category = get_category_for_method(method) vendor_config = get_vendor(category, method) + config_snapshot = get_config() + vendor_logging_cfg = config_snapshot.get("vendor_logging") or {} + log_verbose = bool(vendor_logging_cfg.get("verbose")) # Handle comma-separated vendors primary_vendors = [v.strip() for v in vendor_config.split(',')] @@ -161,7 +173,7 @@ def route_to_vendor(method: str, *args, **kwargs): # Debug: Print fallback ordering primary_str = " → ".join(primary_vendors) fallback_str = " → ".join(fallback_vendors) - print(f"DEBUG: {method} - Primary: [{primary_str}] | Full fallback order: [{fallback_str}]") + _vendor_log(log_verbose, f"{method} - Primary: [{primary_str}] | Full fallback order: [{fallback_str}]") # Track results and execution state results = [] @@ -172,7 +184,7 @@ def route_to_vendor(method: str, *args, **kwargs): for vendor in fallback_vendors: if vendor not in VENDOR_METHODS[method]: if vendor in primary_vendors: - print(f"INFO: Vendor '{vendor}' not supported for method '{method}', falling back to next vendor") + _vendor_log(log_verbose, f"Vendor '{vendor}' not supported for method '{method}', falling back to next vendor", "INFO") continue vendor_impl = VENDOR_METHODS[method][vendor] @@ -185,12 +197,12 @@ def route_to_vendor(method: str, *args, **kwargs): # Debug: Print current attempt vendor_type = "PRIMARY" if is_primary_vendor else "FALLBACK" - print(f"DEBUG: Attempting {vendor_type} vendor '{vendor}' for {method} (attempt #{vendor_attempt_count})") + _vendor_log(log_verbose, f"Attempting {vendor_type} vendor '{vendor}' for {method} (attempt #{vendor_attempt_count})") # Handle list of methods for a vendor if isinstance(vendor_impl, list): vendor_methods = [(impl, vendor) for impl in vendor_impl] - print(f"DEBUG: Vendor '{vendor}' has multiple implementations: {len(vendor_methods)} functions") + _vendor_log(log_verbose, f"Vendor '{vendor}' has multiple implementations: {len(vendor_methods)} functions") else: vendor_methods = [(vendor_impl, vendor)] @@ -198,20 +210,20 @@ def route_to_vendor(method: str, *args, **kwargs): vendor_results = [] for impl_func, vendor_name in vendor_methods: try: - print(f"DEBUG: Calling {impl_func.__name__} from vendor '{vendor_name}'...") + _vendor_log(log_verbose, f"Calling {impl_func.__name__} from vendor '{vendor_name}'...") result = impl_func(*args, **kwargs) vendor_results.append(result) - print(f"SUCCESS: {impl_func.__name__} from vendor '{vendor_name}' completed successfully") + _vendor_log(log_verbose, f"{impl_func.__name__} from vendor '{vendor_name}' completed successfully", "SUCCESS") except AlphaVantageRateLimitError as e: if vendor == "alpha_vantage": - print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded, falling back to next available vendor") - print(f"DEBUG: Rate limit details: {e}") + _vendor_log(log_verbose, "Alpha Vantage rate limit exceeded, falling back to next available vendor", "RATE_LIMIT") + _vendor_log(log_verbose, f"Rate limit details: {e}") # Continue to next vendor for fallback continue except Exception as e: # Log error but continue with other implementations - print(f"FAILED: {impl_func.__name__} from vendor '{vendor_name}' failed: {e}") + _vendor_log(log_verbose, f"{impl_func.__name__} from vendor '{vendor_name}' failed: {e}", "FAILED") continue # Add this vendor's results @@ -219,26 +231,26 @@ def route_to_vendor(method: str, *args, **kwargs): results.extend(vendor_results) successful_vendor = vendor result_summary = f"Got {len(vendor_results)} result(s)" - print(f"SUCCESS: Vendor '{vendor}' succeeded - {result_summary}") + _vendor_log(log_verbose, f"Vendor '{vendor}' succeeded - {result_summary}", "SUCCESS") # Stopping logic: Stop after first successful vendor for single-vendor configs # Multiple vendor configs (comma-separated) may want to collect from multiple sources if len(primary_vendors) == 1: - print(f"DEBUG: Stopping after successful vendor '{vendor}' (single-vendor config)") + _vendor_log(log_verbose, f"Stopping after successful vendor '{vendor}' (single-vendor config)") break else: - print(f"FAILED: Vendor '{vendor}' produced no results") + _vendor_log(log_verbose, f"Vendor '{vendor}' produced no results", "FAILED") # Final result summary if not results: - print(f"FAILURE: All {vendor_attempt_count} vendor attempts failed for method '{method}'") + _vendor_log(log_verbose, f"All {vendor_attempt_count} vendor attempts failed for method '{method}'", "FAILURE") raise RuntimeError(f"All vendor implementations failed for method '{method}'") else: - print(f"FINAL: Method '{method}' completed with {len(results)} result(s) from {vendor_attempt_count} vendor attempt(s)") + _vendor_log(log_verbose, f"Method '{method}' completed with {len(results)} result(s) from {vendor_attempt_count} vendor attempt(s)", "FINAL") # Return single result if only one, otherwise concatenate as string if len(results) == 1: return results[0] else: # Convert all results to strings and concatenate - return '\n'.join(str(result) for result in results) \ No newline at end of file + return '\n'.join(str(result) for result in results) diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 1f40a2a2..95754c80 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -30,4 +30,119 @@ DEFAULT_CONFIG = { # Example: "get_stock_data": "alpha_vantage", # Override category default # Example: "get_news": "openai", # Override category default }, + "portfolio_orchestrator": { + "profile_name": os.getenv("PORTFOLIO_PROFILE_NAME", "Balanced Multi-Asset"), + "mandate": os.getenv( + "PORTFOLIO_MANDATE", + "Preserve capital while capturing medium-term growth opportunities in technology and consumer sectors.", + ), + "risk_limits": { + "max_single_position_pct": float(os.getenv("PORTFOLIO_MAX_POSITION_PCT", "0.15")), + "max_sector_pct": float(os.getenv("PORTFOLIO_MAX_SECTOR_PCT", "0.35")), + }, + "notes": os.getenv( + "PORTFOLIO_NOTES", + "Prioritize liquid large-cap names. Avoid exceeding buying power and respect existing hedges.", + ), + "universe": os.getenv("PORTFOLIO_UNIVERSE", "NVDA,AAPL,MSFT,AMZN,TSLA"), + "sentiment_lookback_days": int(os.getenv("PORTFOLIO_SENTIMENT_LOOKBACK", "2")), + "news_headline_limit": int(os.getenv("PORTFOLIO_NEWS_LIMIT", "5")), + "hypothesis_threshold": float(os.getenv("PORTFOLIO_HYPOTHESIS_THRESHOLD", "0.6")), + "max_concurrent_hypotheses": int(os.getenv("PORTFOLIO_MAX_HYPOTHESES", "2")), + "market_data_lookback_days": int(os.getenv("PORTFOLIO_MARKET_LOOKBACK", "30")), + "trade_activation": { + "priority_threshold": float(os.getenv("PORTFOLIO_TRADE_PRIORITY_THRESHOLD", "0.8")), + "min_cash_absolute": float(os.getenv("PORTFOLIO_TRADE_MIN_CASH", "50000")), + "min_cash_ratio": float(os.getenv("PORTFOLIO_TRADE_MIN_CASH_RATIO", "0.1")), + }, + }, + "alpaca_mcp": { + "enabled": os.getenv("ALPACA_MCP_ENABLED", "false").lower() not in ("false", "0", "no"), + "transport": os.getenv("ALPACA_MCP_TRANSPORT", "http"), + "host": os.getenv("ALPACA_MCP_HOST", "127.0.0.1"), + "base_url": os.getenv("ALPACA_MCP_BASE_URL", ""), + "port": int(os.getenv("ALPACA_MCP_PORT", "8000")), + "command": os.getenv("ALPACA_MCP_COMMAND", ""), + "timeout_seconds": float(os.getenv("ALPACA_MCP_TIMEOUT_SECONDS", "30")), + "required_tools": [ + "get_account_info", + "get_positions", + "get_orders", + "get_market_clock", + ], + }, + "trade_execution": { + "enabled": os.getenv("TRADE_EXECUTION_ENABLED", "false").lower() not in ("false", "0", "no"), + "dry_run": os.getenv("TRADE_EXECUTION_DRY_RUN", "true").lower() not in ("false", "0", "no"), + "default_order_quantity": float(os.getenv("TRADE_EXECUTION_DEFAULT_QTY", "10")), + "time_in_force": os.getenv("TRADE_EXECUTION_TIF", "day"), + }, + "market_data": { + "api_key": os.getenv("APCA_API_KEY_ID", ""), + "secret_key": os.getenv("APCA_API_SECRET_KEY", ""), + "feed": os.getenv("ALPACA_DATA_FEED", "iex"), + "news_stream_url": os.getenv("ALPACA_NEWS_STREAM_URL", ""), + }, + "trading_strategies": { + "default": os.getenv("TRADINGAGENTS_DEFAULT_STRATEGY", "swing"), + "presets": { + "day_trade": { + "label": "Intraday momentum scalp", + "horizon_hours": float(os.getenv("TRADINGAGENTS_DAYTRADE_HOURS", "6")), + "target_pct": float(os.getenv("TRADINGAGENTS_DAYTRADE_TARGET", "0.02")), + "stop_pct": float(os.getenv("TRADINGAGENTS_DAYTRADE_STOP", "0.01")), + "follow_up": "close_before_market_close", + "urgency": "high", + "success_metric": "Hit target gain within the same session", + "failure_metric": "Trigger stop or reach session end without target", + "notes": "Used for rapid intraday reactions; requires strict discipline.", + }, + "swing": { + "label": "Multi-day swing trade", + "horizon_hours": float(os.getenv("TRADINGAGENTS_SWING_HOURS", "72")), + "target_pct": float(os.getenv("TRADINGAGENTS_SWING_TARGET", "0.04")), + "stop_pct": float(os.getenv("TRADINGAGENTS_SWING_STOP", "0.02")), + "follow_up": "reassess_every_close", + "urgency": "medium", + "success_metric": "Capture mid-term move within horizon", + "failure_metric": "Price violates stop or catalyst deteriorates", + "notes": "Default for most holdings; expects catalysts to resolve within a few days.", + }, + "position": { + "label": "Longer-term position build", + "horizon_hours": float(os.getenv("TRADINGAGENTS_POSITION_HOURS", "336")), + "target_pct": float(os.getenv("TRADINGAGENTS_POSITION_TARGET", "0.08")), + "stop_pct": float(os.getenv("TRADINGAGENTS_POSITION_STOP", "0.04")), + "follow_up": "weekly_review", + "urgency": "low", + "success_metric": "Fundamental thesis validated and price reaches target", + "failure_metric": "Thesis breaks or drawdown exceeds tolerance", + "notes": "Use for core holdings where narrative spans weeks.", + }, + }, + }, + "autopilot": { + "enabled": os.getenv("TRADINGAGENTS_AUTOPILOT", "false").lower() in ("1", "true", "yes"), + "auto_trade_on_start": os.getenv("AUTOPILOT_SEED_AUTO_TRADE", "true").lower() not in ("false", "0", "no"), + "event_loop_interval_seconds": int(os.getenv("AUTOPILOT_LOOP_SECONDS", "10")), + "price_poll_interval_seconds": int(os.getenv("AUTOPILOT_PRICE_POLL_SECONDS", "60")), + "pre_market_research_minutes": int(os.getenv("AUTOPILOT_PREMARKET_MINUTES", "30")), + }, + "auto_trade": { + "max_tickers": int(os.getenv("AUTO_TRADE_MAX_TICKERS", "12")), + "skip_when_market_closed": os.getenv("AUTO_TRADE_SKIP_WHEN_MARKET_CLOSED", "true").lower() + not in ("false", "0", "no"), + "mode": os.getenv("AUTO_TRADE_MODE", "graph"), + "responses_model": os.getenv("AUTO_TRADE_RESPONSES_MODEL", os.getenv("TRADINGAGENTS_RESPONSES_MODEL", "")), + "responses_reasoning_effort": os.getenv("AUTO_TRADE_RESPONSES_REASONING", ""), + "responses_max_turns": int(os.getenv("AUTO_TRADE_RESPONSES_MAX_TURNS", "8")), + "memory": { + "enabled": os.getenv("AUTO_TRADE_MEMORY_ENABLED", "true").lower() not in ("false", "0", "no"), + "dir": os.getenv("AUTO_TRADE_MEMORY_DIR", os.path.join(os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), "memory")), + "max_entries": int(os.getenv("AUTO_TRADE_MEMORY_MAX_ENTRIES", "5")), + }, + }, + "vendor_logging": { + "verbose": os.getenv("VENDOR_LOG_VERBOSE", "false").lower() in ("1", "true", "yes", "on") + }, } diff --git a/tradingagents/graph/propagation.py b/tradingagents/graph/propagation.py index 58ebd0a8..d9da2e70 100644 --- a/tradingagents/graph/propagation.py +++ b/tradingagents/graph/propagation.py @@ -19,10 +19,14 @@ class Propagator: self, company_name: str, trade_date: str ) -> Dict[str, Any]: """Create the initial state for the agent graph.""" + company_value = (company_name or "").strip() + trade_value = str(trade_date) if trade_date else "" + initial_prompt = company_value or "Portfolio orchestration start" return { - "messages": [("human", company_name)], - "company_of_interest": company_name, - "trade_date": str(trade_date), + "messages": [("human", initial_prompt)], + "company_of_interest": company_value, + "trade_date": trade_value, + "target_ticker": company_value, "investment_debate_state": InvestDebateState( {"history": "", "current_response": "", "count": 0} ), @@ -39,6 +43,33 @@ class Propagator: "fundamentals_report": "", "sentiment_report": "", "news_report": "", + "investment_plan": "", + "trader_investment_plan": "", + "final_trade_decision": "", + "portfolio_profile": {}, + "portfolio_summary": "", + "orchestrator_status": "not_started", + "alpaca_account_text": "", + "alpaca_positions_text": "", + "alpaca_orders_text": "", + "orchestrator_hypotheses": [], + "active_hypothesis": None, + "scheduled_analysts": [], + "scheduled_analysts_plan": [], + "orchestrator_action": "", + "action_queue": [], + "next_directive": "stop", + "next_node": "", + "portfolio_account_summary": {}, + "portfolio_positions_summary": [], + "planner_plan": {}, + "planner_notes": "", + "orchestrator_pending_tickers": [], + "orchestrator_focus_symbols": [], + "orchestrator_quick_signals": {}, + "orchestrator_market_data": {}, + "orchestrator_ticker_plans": {}, + "orchestrator_focus_override": [], } def get_graph_args(self) -> Dict[str, Any]: diff --git a/tradingagents/graph/scheduler.py b/tradingagents/graph/scheduler.py new file mode 100644 index 00000000..fd3529ea --- /dev/null +++ b/tradingagents/graph/scheduler.py @@ -0,0 +1,60 @@ +"""Action scheduler node for orchestrator-controlled execution.""" + +from __future__ import annotations + +from typing import Dict, Any + + +DISPATCH_MAP = { + "market": "Market Analyst", + "news": "News Analyst", + "social": "Social Analyst", + "fundamentals": "Fundamentals Analyst", + "debate": "Bull Researcher", + "manager": "Research Manager", + "trader": "Trader", + "risk": "Risky Analyst", + "orchestrator": "Portfolio Orchestrator", +} + + +def create_action_scheduler(): + """Return a node that routes execution based on the orchestrator's queue.""" + + def scheduler_node(state: Dict[str, Any]) -> Dict[str, Any]: + queue = list(state.get("action_queue") or []) + if queue: + action = str(queue.pop(0) or "").strip().lower() + next_node = DISPATCH_MAP.get(action, "Portfolio Orchestrator") + try: + print(f"[Action Scheduler] Dispatching '{action}' to {next_node}") + except Exception: + pass + return { + "action_queue": queue, + "next_node": next_node, + } + + directive = (state.get("next_directive") or "stop").lower() + if directive in {"continue", "orchestrator"}: + next_node = "Portfolio Orchestrator" + try: + print(f"[Action Scheduler] Queue empty; returning control to orchestrator (directive={directive}).") + except Exception: + pass + return { + "action_queue": queue, + "next_node": next_node, + } + + next_node = DISPATCH_MAP.get(directive, "end") + try: + print(f"[Action Scheduler] No pending actions; directive '{directive}' maps to {next_node}.") + except Exception: + pass + return { + "action_queue": queue, + "next_node": next_node, + } + + return scheduler_node diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index b270ffc0..a58153ed 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -25,6 +25,8 @@ class GraphSetup: invest_judge_memory, risk_manager_memory, conditional_logic: ConditionalLogic, + orchestrator_node, + action_scheduler_node, ): """Initialize with required components.""" self.quick_thinking_llm = quick_thinking_llm @@ -36,6 +38,8 @@ class GraphSetup: self.invest_judge_memory = invest_judge_memory self.risk_manager_memory = risk_manager_memory self.conditional_logic = conditional_logic + self.orchestrator_node = orchestrator_node + self.action_scheduler_node = action_scheduler_node def setup_graph( self, selected_analysts=["market", "social", "news", "fundamentals"] @@ -116,6 +120,8 @@ class GraphSetup: ) workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type]) + workflow.add_node("Action Scheduler", self.action_scheduler_node) + # Add other nodes workflow.add_node("Bull Researcher", bull_researcher_node) workflow.add_node("Bear Researcher", bear_researcher_node) @@ -126,13 +132,13 @@ class GraphSetup: workflow.add_node("Safe Analyst", safe_analyst) workflow.add_node("Risk Judge", risk_manager_node) - # Define edges - # Start with the first analyst - first_analyst = selected_analysts[0] - workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst") + # Define edges controlled by orchestrator + scheduler + workflow.add_node("Portfolio Orchestrator", self.orchestrator_node) + workflow.add_edge(START, "Portfolio Orchestrator") + workflow.add_edge("Portfolio Orchestrator", "Action Scheduler") # Connect analysts in sequence - for i, analyst_type in enumerate(selected_analysts): + for analyst_type in selected_analysts: current_analyst = f"{analyst_type.capitalize()} Analyst" current_tools = f"tools_{analyst_type}" current_clear = f"Msg Clear {analyst_type.capitalize()}" @@ -145,12 +151,7 @@ class GraphSetup: ) workflow.add_edge(current_tools, current_analyst) - # Connect to next analyst or to Bull Researcher if this is the last analyst - if i < len(selected_analysts) - 1: - next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst" - workflow.add_edge(current_clear, next_analyst) - else: - workflow.add_edge(current_clear, "Bull Researcher") + workflow.add_edge(current_clear, "Action Scheduler") # Add remaining edges workflow.add_conditional_edges( @@ -169,8 +170,8 @@ class GraphSetup: "Research Manager": "Research Manager", }, ) - workflow.add_edge("Research Manager", "Trader") - workflow.add_edge("Trader", "Risky Analyst") + workflow.add_edge("Research Manager", "Action Scheduler") + workflow.add_edge("Trader", "Action Scheduler") workflow.add_conditional_edges( "Risky Analyst", self.conditional_logic.should_continue_risk_analysis, @@ -196,7 +197,27 @@ class GraphSetup: }, ) - workflow.add_edge("Risk Judge", END) + workflow.add_edge("Risk Judge", "Action Scheduler") + + def scheduler_target(state: AgentState) -> str: + return state.get("next_node", "end") + + workflow.add_conditional_edges( + "Action Scheduler", + scheduler_target, + { + "Market Analyst": "Market Analyst", + "Social Analyst": "Social Analyst", + "News Analyst": "News Analyst", + "Fundamentals Analyst": "Fundamentals Analyst", + "Bull Researcher": "Bull Researcher", + "Research Manager": "Research Manager", + "Trader": "Trader", + "Risky Analyst": "Risky Analyst", + "Portfolio Orchestrator": "Portfolio Orchestrator", + "end": END, + }, + ) # Compile and return return workflow.compile() diff --git a/tradingagents/graph/signal_processing.py b/tradingagents/graph/signal_processing.py index 903e8529..7df97e0f 100644 --- a/tradingagents/graph/signal_processing.py +++ b/tradingagents/graph/signal_processing.py @@ -20,6 +20,14 @@ class SignalProcessor: Returns: Extracted decision (BUY, SELL, or HOLD) """ + if not full_signal: + return "HOLD" + + normalized = full_signal.strip().upper() + for keyword in ("BUY", "SELL", "HOLD", "TRADE"): + if keyword in normalized: + return "BUY" if keyword == "TRADE" else keyword + messages = [ ( "system", diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 40cdff75..b9809776 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -1,10 +1,11 @@ # TradingAgents/graph/trading_graph.py import os -from pathlib import Path import json -from datetime import date -from typing import Dict, Any, Tuple, List, Optional +import logging +from pathlib import Path +from datetime import date, datetime, timedelta +from typing import Dict, Any, Tuple, List, Optional, TYPE_CHECKING from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic @@ -15,12 +16,43 @@ from langgraph.prebuilt import ToolNode from tradingagents.agents import * from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.agents.utils.memory import FinancialSituationMemory +from tradingagents.agents.managers.orchestrator import create_portfolio_orchestrator from tradingagents.agents.utils.agent_states import ( AgentState, InvestDebateState, RiskDebateState, ) from tradingagents.dataflows.config import set_config +from tradingagents.integrations.alpaca_mcp import AlpacaMCPClient, AlpacaMCPConfig, AlpacaMCPError +from tradingagents.dataflows.interface import route_to_vendor +from tradingagents.graph.scheduler import create_action_scheduler + +if TYPE_CHECKING: + from tradingagents.services.account import AccountSnapshot + +def _extract_json_block(text: str) -> Dict[str, Any]: + """Attempt to locate a JSON object within free-form text.""" + + import json + if not text: + return {} + + text = text.strip() + if text.startswith("```"): + parts = text.split("```") + for part in parts: + candidate = part.strip() + if candidate.startswith("{") and candidate.endswith("}"): + try: + return json.loads(candidate) + except Exception: + continue + if text.startswith("{") and text.endswith("}"): + try: + return json.loads(text) + except Exception: + return {} + return {} # Import the new abstract tool methods from agent_utils from tradingagents.agents.utils.agent_utils import ( @@ -51,6 +83,8 @@ class TradingAgentsGraph: selected_analysts=["market", "social", "news", "fundamentals"], debug=False, config: Dict[str, Any] = None, + *, + skip_initial_probes: bool = False, ): """Initialize the trading agents graph and components. @@ -61,6 +95,7 @@ class TradingAgentsGraph: """ self.debug = debug self.config = config or DEFAULT_CONFIG + self.logger = logging.getLogger(__name__) # Update the interface's config set_config(self.config) @@ -83,7 +118,36 @@ class TradingAgentsGraph: self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"]) else: raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}") - + + # Portfolio + Alpaca configuration + self.portfolio_profile: Dict[str, Any] = self.config.get("portfolio_orchestrator", {}) + try: + self.alpaca_config = AlpacaMCPConfig.from_dict(self.config.get("alpaca_mcp", {})) + self.alpaca_config.validate() + except ValueError as exc: + self.logger.warning("Alpaca MCP configuration invalid: %s", exc) + self.alpaca_config = AlpacaMCPConfig.from_dict( + { + "enabled": False, + "transport": "http", + "host": "127.0.0.1", + "port": 8000, + "command": "", + "timeout_seconds": 30.0, + "required_tools": [], + } + ) + self._alpaca_client: Optional[AlpacaMCPClient] = None + self._manual_portfolio_snapshot: Optional[Dict[str, str]] = None + + self.trade_execution_config: Dict[str, Any] = self.config.get("trade_execution", {}) + + # Probe MCP connectivity early so any configuration issues are visible + # before the graph starts processing signals. + if not skip_initial_probes: + self._report_mcp_connectivity() + self._log_alpaca_account_overview() + # Initialize memories self.bull_memory = FinancialSituationMemory("bull_memory", self.config) self.bear_memory = FinancialSituationMemory("bear_memory", self.config) @@ -94,6 +158,15 @@ class TradingAgentsGraph: # Create tool nodes self.tool_nodes = self._create_tool_nodes() + orchestrator_node = create_portfolio_orchestrator( + self.quick_thinking_llm, + self.portfolio_profile, + self._collect_portfolio_context, + self._fetch_quick_signals, + self._generate_plan_with_llm, + ) + action_scheduler_node = create_action_scheduler() + # Initialize components self.conditional_logic = ConditionalLogic() self.graph_setup = GraphSetup( @@ -106,6 +179,8 @@ class TradingAgentsGraph: self.invest_judge_memory, self.risk_manager_memory, self.conditional_logic, + orchestrator_node, + action_scheduler_node, ) self.propagator = Propagator() @@ -115,11 +190,351 @@ class TradingAgentsGraph: # State tracking self.curr_state = None self.ticker = None + self.trade_date = None self.log_states_dict = {} # date to full state dict # Set up the graph self.graph = self.graph_setup.setup_graph(selected_analysts) + # ------------------------------------------------------------------ + # Portfolio context helpers + # ------------------------------------------------------------------ + def _report_mcp_connectivity(self) -> None: + if getattr(self.alpaca_config, "enabled", False): + client = self._get_alpaca_client() + if client and not client.verify_connection(): + self.logger.warning( + "Alpaca MCP connection check failed; see previous log messages for details." + ) + else: + self.logger.info("Alpaca MCP disabled; skipping connectivity check.") + + def _get_alpaca_client(self) -> Optional[AlpacaMCPClient]: + if not getattr(self.alpaca_config, "enabled", False): + return None + if self._alpaca_client is None: + self._alpaca_client = AlpacaMCPClient(self.alpaca_config, self.logger) + return self._alpaca_client + + def _log_alpaca_account_overview(self) -> None: + if not getattr(self.alpaca_config, "enabled", False): + msg = "Alpaca MCP disabled; skipping account overview." + self.logger.info(msg) + print(msg) + return + + client = self._get_alpaca_client() + if client is None: + return + + try: + account_text = (client.fetch_account_info() or "").strip() + positions_text = (client.fetch_positions() or "").strip() + orders_text = (client.fetch_orders() or "").strip() + except AlpacaMCPError as exc: + msg = f"Unable to retrieve Alpaca account overview: {exc}" + self.logger.warning(msg) + print(f"WARNING: {msg}") + return + except Exception as exc: # pragma: no cover - defensive logging + msg = f"Unexpected error while fetching Alpaca account overview: {exc}" + self.logger.warning(msg) + print(f"WARNING: {msg}") + return + + overview_lines = [ + "Alpaca account overview:", + account_text or "", + "", + "Open positions:", + positions_text or "", + "", + "Recent orders:", + orders_text or "", + ] + overview_message = "\n".join(overview_lines) + self.logger.info(overview_message) + print(overview_message) + + def set_manual_portfolio_snapshot(self, snapshot: "AccountSnapshot") -> None: + """Provide a pre-fetched Alpaca snapshot to reuse during orchestration.""" + self._manual_portfolio_snapshot = { + "account": snapshot.account_text, + "positions": snapshot.positions_text, + "orders": snapshot.orders_text, + } + + def clear_manual_portfolio_snapshot(self) -> None: + """Clear any cached snapshot so subsequent runs fetch live data.""" + self._manual_portfolio_snapshot = None + + def _collect_portfolio_context(self, symbols: List[str]) -> List[Dict[str, str]]: + symbols = symbols or [] + if self._manual_portfolio_snapshot: + cached = self._manual_portfolio_snapshot + snapshots: List[Dict[str, str]] = [] + for idx, symbol in enumerate(symbols): + snapshots.append( + { + "symbol": symbol.upper(), + "status": "alpaca_cached", + "account": cached.get("account", "") if idx == 0 else "", + "positions": cached.get("positions", "") if idx == 0 else "", + "orders": cached.get("orders", "") if idx == 0 else "", + "summary_prompt": "Cached Alpaca snapshot" if idx == 0 else "", + } + ) + return snapshots + + client = self._get_alpaca_client() + if client is None: + return [ + { + "symbol": symbol.upper(), + "status": "alpaca_disabled", + "account": "", + "positions": "", + "orders": "", + "summary_prompt": "Alpaca MCP disabled; using static portfolio profile only.", + } + for symbol in symbols + ] + + try: + account_text = client.fetch_account_info() + positions_text = client.fetch_positions() + orders_text = client.fetch_orders() + snapshots: List[Dict[str, str]] = [] + for idx, symbol in enumerate(symbols): + snapshots.append( + { + "symbol": symbol.upper(), + "status": "alpaca_connected", + "account": account_text if idx == 0 else "", + "positions": positions_text if idx == 0 else "", + "orders": orders_text if idx == 0 else "", + "summary_prompt": "Live Alpaca data available" if idx == 0 else "", + } + ) + return snapshots + except AlpacaMCPError as exc: + self.logger.warning("Alpaca MCP call failed: %s", exc) + return [ + { + "symbol": symbol.upper(), + "status": f"alpaca_error: {exc}", + "account": "" if idx else "", + "positions": "", + "orders": "", + "summary_prompt": "Unable to fetch Alpaca context." if idx == 0 else "", + } + for idx, symbol in enumerate(symbols) + ] + except Exception as exc: # pragma: no cover + self.logger.error("Unexpected Alpaca MCP failure: %s", exc) + return [ + { + "symbol": symbol.upper(), + "status": "alpaca_error", + "account": "", + "positions": "", + "orders": "", + "summary_prompt": "Unexpected error while fetching Alpaca context." if idx == 0 else "", + } + for idx, symbol in enumerate(symbols) + ] + + def _fetch_quick_signals(self, symbol: str, trade_date: str, lookback_days: int, limit: int) -> Dict[str, str]: + if not trade_date: + trade_date = date.today().isoformat() + + try: + trade_dt = datetime.fromisoformat(trade_date) + except ValueError: + try: + trade_dt = datetime.fromisoformat(f"{trade_date}T00:00:00") + except Exception: + trade_dt = datetime.today() + + trade_date_value = trade_dt.date() + start_dt = trade_date_value - timedelta(days=lookback_days) + + def safe_call(method: str, *args) -> str: + try: + return str(route_to_vendor(method, *args)) + except Exception as exc: # pragma: no cover + self.logger.debug("Quick signal fetch failed for %s: %s", symbol, exc) + return f"Failed to fetch {method}: {exc}" + + news_text = safe_call("get_news", symbol, start_dt.isoformat(), trade_date_value.isoformat()) + global_text = safe_call("get_global_news", trade_date_value.isoformat(), lookback_days, limit) + + def truncate(txt: str, max_chars: int = 2000) -> str: + if len(txt) <= max_chars: + return txt + return txt[: max_chars - 3] + "..." + + return { + "symbol": symbol.upper(), + "news": truncate(news_text, 1500), + "global": truncate(global_text, 1500), + } + + def check_market_status(self) -> Dict[str, Any]: + """Return the current Alpaca market clock status if available.""" + client = self._get_alpaca_client() + if client is None: + return {"is_open": True, "reason": "alpaca_disabled"} + try: + clock_text = client.fetch_market_clock() + except AlpacaMCPError as exc: + self.logger.warning("Unable to fetch market clock: %s", exc) + return {"is_open": False, "reason": f"clock_error: {exc}"} + except Exception as exc: # pragma: no cover + self.logger.warning("Unexpected market clock error: %s", exc) + return {"is_open": True, "reason": "clock_unavailable"} + + normalized = clock_text.lower() + is_open = "is open: yes" in normalized + return { + "is_open": is_open, + "clock_text": clock_text, + } + + def _generate_plan_with_llm(self, payload: Dict[str, Any]) -> Dict[str, Any]: + system_prompt = ( + "You are the sequential planning engine for TradingAgents. " + "Given the payload (account_summary, positions_summary, hypotheses, quick_signals, market_data, trade_policy), " + "recommend the next sequence of analysts/managers to involve and the immediate directive for the hypothesis. " + "Always reply with JSON containing: actions (array of role identifiers), next_decision (monitor|escalate|trade|execute), " + "notes (string), reasoning (array of short bullet explanations)." + ) + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": json.dumps(payload)}, + ] + + try: + response = self.quick_thinking_llm.invoke(messages) + content = getattr(response, "content", None) + if isinstance(content, list): + content = "".join( + chunk.get("text", "") if isinstance(chunk, dict) else str(chunk) + for chunk in content + ) + if not content: + content = str(response) + except Exception as exc: # pragma: no cover + self.logger.warning("Sequential plan generation failed: %s", exc) + return {"error": str(exc)} + + structured: Dict[str, Any] + try: + structured = json.loads(content) + except json.JSONDecodeError: + structured = _extract_json_block(content) + + if not isinstance(structured, dict): + structured = {} + + actions = structured.get("actions") + if not isinstance(actions, list): + actions = [str(actions)] if actions else [] + + reasoning = structured.get("reasoning") + if not isinstance(reasoning, list): + reasoning = [str(reasoning)] if reasoning else [] + + plan = { + "actions": [str(item).strip().lower() for item in actions if str(item).strip()], + "next_decision": str(structured.get("next_decision") or "monitor").lower(), + "notes": str(structured.get("notes") or ""), + "reasoning": [str(item) for item in reasoning if str(item)], + } + + return { + "structured": plan, + "text": content, + } + + + def _maybe_execute_trade(self, final_state: Dict[str, Any], decision_text: str) -> Dict[str, Any]: + exec_cfg = self.trade_execution_config or {} + if not exec_cfg.get("enabled"): + return {"status": "disabled", "reason": "trade_execution_disabled"} + + action = self._extract_action(decision_text) + symbol = final_state.get("company_of_interest", "") + if not symbol: + return {"status": "skipped", "reason": "missing_symbol"} + + if action not in {"BUY", "SELL"}: + return {"status": "skipped", "reason": f"action_{action}"} + + quantity = float(exec_cfg.get("default_order_quantity", 0)) + if quantity <= 0: + return {"status": "skipped", "reason": "invalid_quantity"} + + client = self._get_alpaca_client() + if client is None: + return {"status": "failed", "reason": "alpaca_disabled"} + + payload = { + "symbol": symbol, + "side": "buy" if action == "BUY" else "sell", + "order_type": "market", + "time_in_force": exec_cfg.get("time_in_force", "day").upper(), + "quantity": float(quantity), + } + + try: + clock_text = client.fetch_market_clock() + if "Is Open: Yes" not in clock_text: + return {"status": "market_closed", "payload": payload, "clock": clock_text} + except AlpacaMCPError as exc: + self.logger.warning("Unable to fetch market clock: %s", exc) + except Exception as exc: # pragma: no cover + self.logger.warning("Unexpected market clock error: %s", exc) + + if exec_cfg.get("dry_run", True): + self.logger.info("[DRY RUN] Would submit Alpaca order: %s", payload) + return {"status": "dry_run", "payload": payload} + + try: + response_text = client.place_stock_order(payload) + self.logger.info("Alpaca MCP order submitted: %s", response_text) + return { + "status": "executed", + "payload": payload, + "response": response_text, + } + except AlpacaMCPError as exc: + self.logger.error("Order submission failed: %s", exc) + return {"status": "failed", "reason": str(exc), "payload": payload} + except Exception as exc: # pragma: no cover + self.logger.exception("Unexpected error during order submission") + return {"status": "failed", "reason": str(exc), "payload": payload} + + def _extract_action(self, decision_text: str) -> str: + if not decision_text: + return "UNKNOWN" + normalized = decision_text.upper() + if "FINAL TRANSACTION PROPOSAL" in normalized: + if "**BUY**" in normalized: + return "BUY" + if "**SELL**" in normalized: + return "SELL" + if "**HOLD**" in normalized: + return "HOLD" + + for keyword in ("BUY", "SELL", "HOLD"): + if keyword in normalized: + return keyword + if "TRADE" in normalized: + return "BUY" + return "UNKNOWN" + def _create_tool_nodes(self) -> Dict[str, ToolNode]: """Create tool nodes for different data sources using abstract methods.""" return { @@ -157,15 +572,22 @@ class TradingAgentsGraph: ), } - def propagate(self, company_name, trade_date): + def propagate(self, company_name=None, trade_date=None, *, initial_overrides: Optional[Dict[str, Any]] = None): """Run the trading agents graph for a company on a specific date.""" - self.ticker = company_name + company_value = (company_name or "").strip() + trade_date_value = str(trade_date) if trade_date else date.today().isoformat() + + self.ticker = company_value or "portfolio" + self.trade_date = trade_date_value # Initialize state init_agent_state = self.propagator.create_initial_state( - company_name, trade_date + company_value, trade_date_value ) + if initial_overrides: + init_agent_state.update(initial_overrides) + init_agent_state["portfolio_profile"] = self.portfolio_profile args = self.propagator.get_graph_args() if self.debug: @@ -185,54 +607,149 @@ class TradingAgentsGraph: # Store current state for reflection self.curr_state = final_state + preferred_ticker = final_state.get("target_ticker") or final_state.get("company_of_interest") + if preferred_ticker: + self.ticker = preferred_ticker # Log state self._log_state(trade_date, final_state) # Return decision and processed signal - return final_state, self.process_signal(final_state["final_trade_decision"]) + decision_text = final_state.get("final_trade_decision", "") + if not decision_text: + orchestrator_action = str(final_state.get("orchestrator_action") or "").strip() + if orchestrator_action: + decision_text = orchestrator_action.upper() + final_state["final_trade_decision"] = decision_text + if decision_text: + processed_decision = self.process_signal(decision_text) + else: + processed_decision = "" + execution_result = self._maybe_execute_trade(final_state, decision_text) + processed_result = { + "decision": processed_decision, + "execution": execution_result, + } + final_state["execution_result"] = execution_result + self._write_run_summary(final_state, processed_result) + return final_state, processed_result + + def execute_trade_directive(self, symbol: str, action: str) -> Dict[str, Any]: + """Execute a trade directive issued outside the standard graph run.""" + directive = (action or "").strip().upper() + minimal_state = {"company_of_interest": symbol} + return self._maybe_execute_trade(minimal_state, directive) def _log_state(self, trade_date, final_state): """Log the final state to a JSON file.""" + ticker_for_logs = final_state.get("target_ticker") or final_state.get("company_of_interest") or "portfolio" + invest_state = final_state.get("investment_debate_state") or {} + risk_state = final_state.get("risk_debate_state") or {} self.log_states_dict[str(trade_date)] = { - "company_of_interest": final_state["company_of_interest"], - "trade_date": final_state["trade_date"], - "market_report": final_state["market_report"], - "sentiment_report": final_state["sentiment_report"], - "news_report": final_state["news_report"], - "fundamentals_report": final_state["fundamentals_report"], + "company_of_interest": final_state.get("company_of_interest"), + "target_ticker": final_state.get("target_ticker"), + "trade_date": final_state.get("trade_date"), + "market_report": final_state.get("market_report", ""), + "sentiment_report": final_state.get("sentiment_report", ""), + "news_report": final_state.get("news_report", ""), + "fundamentals_report": final_state.get("fundamentals_report", ""), "investment_debate_state": { - "bull_history": final_state["investment_debate_state"]["bull_history"], - "bear_history": final_state["investment_debate_state"]["bear_history"], - "history": final_state["investment_debate_state"]["history"], - "current_response": final_state["investment_debate_state"][ - "current_response" - ], - "judge_decision": final_state["investment_debate_state"][ - "judge_decision" - ], + "bull_history": invest_state.get("bull_history", ""), + "bear_history": invest_state.get("bear_history", ""), + "history": invest_state.get("history", ""), + "current_response": invest_state.get("current_response", ""), + "judge_decision": invest_state.get("judge_decision", ""), + "count": invest_state.get("count", 0), }, - "trader_investment_decision": final_state["trader_investment_plan"], + "trader_investment_decision": final_state.get("trader_investment_plan", ""), "risk_debate_state": { - "risky_history": final_state["risk_debate_state"]["risky_history"], - "safe_history": final_state["risk_debate_state"]["safe_history"], - "neutral_history": final_state["risk_debate_state"]["neutral_history"], - "history": final_state["risk_debate_state"]["history"], - "judge_decision": final_state["risk_debate_state"]["judge_decision"], + "risky_history": risk_state.get("risky_history", ""), + "safe_history": risk_state.get("safe_history", ""), + "neutral_history": risk_state.get("neutral_history", ""), + "history": risk_state.get("history", ""), + "judge_decision": risk_state.get("judge_decision", ""), + "latest_speaker": risk_state.get("latest_speaker", ""), + "count": risk_state.get("count", 0), }, - "investment_plan": final_state["investment_plan"], - "final_trade_decision": final_state["final_trade_decision"], + "investment_plan": final_state.get("investment_plan", ""), + "final_trade_decision": final_state.get("final_trade_decision", ""), + "portfolio_summary": final_state.get("portfolio_summary"), + "orchestrator_status": final_state.get("orchestrator_status"), + "alpaca_account_text": final_state.get("alpaca_account_text"), + "alpaca_positions_text": final_state.get("alpaca_positions_text"), + "alpaca_orders_text": final_state.get("alpaca_orders_text"), + "execution_result": final_state.get("execution_result"), + "orchestrator_hypotheses": final_state.get("orchestrator_hypotheses", []), + "active_hypothesis": final_state.get("active_hypothesis"), + "orchestrator_focus_symbols": final_state.get("orchestrator_focus_symbols", []), + "orchestrator_quick_signals": final_state.get("orchestrator_quick_signals", {}), + "orchestrator_market_data": final_state.get("orchestrator_market_data", {}), + "orchestrator_ticker_plans": final_state.get("orchestrator_ticker_plans", {}), + "orchestrator_pending_tickers": final_state.get("orchestrator_pending_tickers", []), + "orchestrator_buying_power": final_state.get("orchestrator_buying_power"), + "orchestrator_cash_available": final_state.get("orchestrator_cash_available"), + "orchestrator_portfolio_value": final_state.get("orchestrator_portfolio_value"), + "scheduled_analysts_plan": final_state.get("scheduled_analysts_plan", []), + "orchestrator_action": final_state.get("orchestrator_action"), + "action_queue": final_state.get("action_queue", []), + "next_directive": final_state.get("next_directive"), + "planner_plan": final_state.get("planner_plan", {}), + "planner_notes": final_state.get("planner_notes", ""), } # Save to file - directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/") + directory = Path(f"eval_results/{ticker_for_logs}/TradingAgentsStrategy_logs/") directory.mkdir(parents=True, exist_ok=True) with open( - f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json", + f"eval_results/{ticker_for_logs}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json", "w", ) as f: - json.dump(self.log_states_dict, f, indent=4) + json.dump(self.log_states_dict, f, indent=4, default=str) + + def _write_run_summary(self, final_state: Dict[str, Any], processed: Dict[str, Any]) -> None: + try: + results_dir = Path(self.config.get("results_dir", "./results")) + except Exception: + results_dir = Path("./results") + try: + results_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") + ticker_for_summary = final_state.get("target_ticker") or self.ticker or "portfolio" + summary_path = results_dir / f"run_{ticker_for_summary}_{timestamp}.json" + summary = { + "ticker": final_state.get("target_ticker"), + "trade_date": final_state.get("trade_date"), + "orchestrator_summary": final_state.get("portfolio_summary"), + "orchestrator_status": final_state.get("orchestrator_status"), + "orchestrator_action": final_state.get("orchestrator_action"), + "orchestrator_hypotheses": final_state.get("orchestrator_hypotheses"), + "orchestrator_focus_symbols": final_state.get("orchestrator_focus_symbols"), + "orchestrator_quick_signals": final_state.get("orchestrator_quick_signals"), + "orchestrator_market_data": final_state.get("orchestrator_market_data"), + "orchestrator_ticker_plans": final_state.get("orchestrator_ticker_plans"), + "orchestrator_pending_tickers": final_state.get("orchestrator_pending_tickers"), + "orchestrator_buying_power": final_state.get("orchestrator_buying_power"), + "orchestrator_cash_available": final_state.get("orchestrator_cash_available"), + "orchestrator_portfolio_value": final_state.get("orchestrator_portfolio_value"), + "active_hypothesis": final_state.get("active_hypothesis"), + "scheduled_analysts_plan": final_state.get("scheduled_analysts_plan"), + "action_queue": final_state.get("action_queue"), + "next_directive": final_state.get("next_directive"), + "planner_plan": final_state.get("planner_plan"), + "planner_notes": final_state.get("planner_notes"), + "execution": processed.get("execution"), + "decision": processed.get("decision"), + } + with open(summary_path, "w", encoding="utf-8") as handle: + json.dump(summary, handle, indent=2, default=str) + try: + print("[Run Summary] Final decision:", summary.get("decision")) + print("[Run Summary] Execution status:", summary.get("execution")) + except Exception: + pass + except Exception as exc: # pragma: no cover + self.logger.warning("Failed to write run summary: %s", exc) def reflect_and_remember(self, returns_losses): """Reflect on decisions and update memory based on returns.""" diff --git a/tradingagents/integrations/__init__.py b/tradingagents/integrations/__init__.py new file mode 100644 index 00000000..1277d881 --- /dev/null +++ b/tradingagents/integrations/__init__.py @@ -0,0 +1,3 @@ +"""Integration helpers for TradingAgents.""" + +__all__ = ["alpaca_mcp"] diff --git a/tradingagents/integrations/alpaca_mcp/__init__.py b/tradingagents/integrations/alpaca_mcp/__init__.py new file mode 100644 index 00000000..ce71f5fd --- /dev/null +++ b/tradingagents/integrations/alpaca_mcp/__init__.py @@ -0,0 +1,6 @@ +"""Public interface for the Alpaca MCP integration.""" + +from .client import AlpacaMCPClient, AlpacaMCPError +from .config import AlpacaMCPConfig + +__all__ = ["AlpacaMCPClient", "AlpacaMCPError", "AlpacaMCPConfig"] diff --git a/tradingagents/integrations/alpaca_mcp/client.py b/tradingagents/integrations/alpaca_mcp/client.py new file mode 100644 index 00000000..f1f8c728 --- /dev/null +++ b/tradingagents/integrations/alpaca_mcp/client.py @@ -0,0 +1,221 @@ +"""Minimal client for calling tools on an Alpaca MCP server.""" + +from __future__ import annotations + +import asyncio +import logging +import shlex +from contextlib import asynccontextmanager +from typing import Any, Dict, List, Optional + +from .config import AlpacaMCPConfig +from ..mcp_handshake import emit_console, perform_handshake + +try: # pragma: no cover - optional dependency during linting + from mcp.client.session import ClientSession + from mcp.client.streamable_http import streamablehttp_client + from mcp.client.stdio import StdioServerParameters, stdio_client +except ImportError: # pragma: no cover - surfaced at runtime with helpful error + ClientSession = None # type: ignore[assignment] + streamablehttp_client = None # type: ignore[assignment] + stdio_client = None # type: ignore[assignment] + StdioServerParameters = None # type: ignore[assignment] + + +class AlpacaMCPError(RuntimeError): + """Raised when the MCP client cannot satisfy a request.""" + + +class AlpacaMCPClient: + """Fire-and-forget interface for fetching portfolio context from Alpaca MCP.""" + + def __init__(self, config: AlpacaMCPConfig, logger: Optional[logging.Logger] = None) -> None: + self.config = config + self.logger = logger or logging.getLogger(__name__) + + def fetch_account_info(self) -> str: + return self._call_tool("get_account_info") + + def fetch_positions(self) -> str: + return self._call_tool("get_positions") + + def fetch_orders(self, limit: int = 25) -> str: + return self._call_tool("get_orders", {"status": "all", "limit": limit}) + + def place_stock_order(self, payload: Dict[str, Any]) -> str: + return self._call_tool("place_stock_order", payload) + + def close_position(self, payload: Dict[str, Any]) -> str: + return self._call_tool("close_position", payload) + + def fetch_market_clock(self) -> str: + return self._call_tool("get_market_clock") + + def verify_connection(self) -> bool: + """Check that the MCP server is reachable and exposes required tools.""" + + if not self.config.enabled: + msg = "Alpaca MCP disabled; skipping connectivity check." + self.logger.info(msg) + emit_console("INFO", msg) + return False + + try: + return asyncio.run(self._verify_async()) + except AlpacaMCPError as exc: + msg = f"Alpaca MCP connectivity probe failed: {exc}" + self.logger.warning(msg) + emit_console("WARNING", msg) + return False + except Exception as exc: # pragma: no cover - best-effort diagnostics + msg = f"Alpaca MCP connectivity probe failed: {exc}" + self.logger.warning(msg) + emit_console("WARNING", msg) + return False + + async def _verify_async(self) -> bool: + async with self._acquire_session() as session: + tools_response = await session.list_tools() + available = [getattr(tool, "name", "") for tool in getattr(tools_response, "tools", [])] + missing = self.config.required_toolset(available) + if missing: + msg = "Alpaca MCP connected but missing required tools: " + ", ".join(missing) + self.logger.warning(msg) + emit_console("WARNING", msg) + return False + tools_list = ", ".join(sorted(filter(None, available))) + msg = f"Alpaca MCP connectivity verified (tools={tools_list})" + self.logger.info(msg) + emit_console("INFO", msg) + return True + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _call_tool(self, tool_name: str, arguments: Optional[Dict[str, Any]] = None) -> str: + if not self.config.enabled: + raise AlpacaMCPError("Alpaca MCP integration is disabled.") + if ClientSession is None or ( + self.config.transport == "http" and streamablehttp_client is None + ): + raise AlpacaMCPError( + "Package 'mcp' is required to use the Alpaca MCP integration. Install it with `pip install mcp`." + ) + + self.config.validate() + payload = arguments or {} + return asyncio.run(self._call_tool_async(tool_name, payload)) + + async def _call_tool_async(self, tool_name: str, arguments: Dict[str, Any], *, session: Optional["ClientSession"] = None, validate: bool = True) -> str: + try: + if session is None: + async with self._acquire_session() as managed_session: + return await self._call_tool_async(tool_name, arguments, session=managed_session, validate=validate) + + if validate: + tools_response = await session.list_tools() + available = [getattr(tool, "name", "") for tool in getattr(tools_response, "tools", [])] + missing = self.config.required_toolset(available) + if missing: + raise AlpacaMCPError( + "Alpaca MCP server is missing required tools: " + ", ".join(missing) + ) + if tool_name not in available: + raise AlpacaMCPError(f"Alpaca MCP server does not expose tool '{tool_name}'.") + + result = await session.call_tool(tool_name, arguments) + return self._extract_text(result) + except AlpacaMCPError: + raise + except BaseExceptionGroup as exc_group: # pragma: no cover - requires Python 3.11+ + message = self._flatten_exception_message(exc_group) + raise AlpacaMCPError(f"Failed to call Alpaca MCP tool '{tool_name}': {message}") from exc_group + except Exception as exc: + raise AlpacaMCPError(f"Failed to call Alpaca MCP tool '{tool_name}': {exc}") from exc + + @asynccontextmanager + async def _acquire_session(self) -> "ClientSession": + if ClientSession is None: + raise AlpacaMCPError( + "Package 'mcp' is required to use the Alpaca MCP integration. Install it with `pip install mcp`." + ) + + if self.config.transport == "http": + if streamablehttp_client is None: + raise AlpacaMCPError( + "HTTP transport requires the 'mcp' package. Install it with `pip install mcp`." + ) + base_url = self._build_http_base() + self.logger.debug("Connecting to Alpaca MCP via HTTP at %s", base_url) + async with streamablehttp_client( + url=base_url, + timeout=self.config.timeout_seconds, + ) as (read_stream, write_stream, _session_id_cb): + async with ClientSession(read_stream, write_stream) as session: + await perform_handshake( + session, + client_label="Alpaca", + logger=self.logger, + ) + yield session + return + + if self.config.transport == "stdio": + if stdio_client is None or StdioServerParameters is None: + raise AlpacaMCPError( + "STDIO transport requires the 'mcp' package. Install it with `pip install mcp`." + ) + if not self.config.command: + raise AlpacaMCPError("STDIO transport requires a command to launch the server.") + args = shlex.split(self.config.command) + if not args: + raise AlpacaMCPError("STDIO command is empty.") + + params = StdioServerParameters(command=args[0], args=args[1:]) + self.logger.debug("Launching Alpaca MCP via STDIO: %s", args) + async with stdio_client(params) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await perform_handshake( + session, + client_label="Alpaca", + logger=self.logger, + ) + yield session + return + + raise AlpacaMCPError(f"Unsupported transport '{self.config.transport}'.") + + def _build_http_base(self) -> str: + if getattr(self.config, "base_url", ""): + return self.config.base_url.rstrip("/") + host = self.config.host + if host.startswith("http://") or host.startswith("https://"): + return host.rstrip("/") + return f"http://{host}:{self.config.port}" + + @staticmethod + def _extract_text(result: Any) -> str: + content = getattr(result, "content", None) + if content is None and isinstance(result, dict): + content = result.get("content") + if not content: + return str(result) + + fragments: List[str] = [] + for item in content: + text_value = getattr(item, "text", None) + if text_value is None and isinstance(item, dict): + text_value = item.get("text") + fragments.append(str(text_value) if text_value is not None else str(item)) + return "\n".join(fragment for fragment in fragments if fragment) + + @staticmethod + def _flatten_exception_message(exc: BaseException) -> str: + if isinstance(exc, BaseExceptionGroup): + parts: List[str] = [] + for item in exc.exceptions: + message = AlpacaMCPClient._flatten_exception_message(item) + if message: + parts.append(message) + return "; ".join(parts) + return str(exc) diff --git a/tradingagents/integrations/alpaca_mcp/config.py b/tradingagents/integrations/alpaca_mcp/config.py new file mode 100644 index 00000000..2e0e6880 --- /dev/null +++ b/tradingagents/integrations/alpaca_mcp/config.py @@ -0,0 +1,51 @@ +"""Configuration utilities for connecting to the Alpaca MCP server.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List + + +@dataclass +class AlpacaMCPConfig: + """Normalized connection information for the Alpaca MCP server.""" + + enabled: bool + transport: str + host: str + base_url: str + port: int + command: str + timeout_seconds: float + required_tools: List[str] + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AlpacaMCPConfig": + required = data or {} + transport = (required.get("transport") or "http").lower() + return cls( + enabled=bool(required.get("enabled", False)), + transport=transport, + host=required.get("host", "127.0.0.1"), + base_url=required.get("base_url", ""), + port=int(required.get("port", 8000)), + command=required.get("command", ""), + timeout_seconds=float(required.get("timeout_seconds", 30.0)), + required_tools=list(required.get("required_tools", [])), + ) + + def validate(self) -> None: + if not self.enabled: + return + if self.transport not in {"http", "stdio"}: + raise ValueError(f"Unsupported Alpaca MCP transport '{self.transport}'.") + if self.transport == "http" and not (self.base_url or self.host): + raise ValueError("HTTP transport requires a host or base_url value.") + if self.transport == "stdio" and not self.command: + raise ValueError("STDIO transport requires a command to launch the server.") + + def required_toolset(self, available: Iterable[str]) -> List[str]: + if not self.required_tools: + return [] + available_set = set(available) + return [tool for tool in self.required_tools if tool not in available_set] diff --git a/tradingagents/integrations/mcp_handshake.py b/tradingagents/integrations/mcp_handshake.py new file mode 100644 index 00000000..8bbafe87 --- /dev/null +++ b/tradingagents/integrations/mcp_handshake.py @@ -0,0 +1,190 @@ +"""Utilities shared by MCP clients for the TradingAgents project. + +This module contains a single public helper, :func:`perform_handshake`, which +wraps the Model Context Protocol session bootstrapping sequence. The helper +adds a thin compatibility layer around the official Python SDK so we can supply +client metadata, negotiate capabilities, and emit useful debug logs without +duplicating this logic in every integration. + +The implementation favours graceful degradation – if the optional ``mcp`` +package is not installed, or if the runtime SDK version does not implement a +specific API entry point, we simply skip that portion of the handshake while +providing informative logging. This keeps the behaviour predictable in +development environments where the dependency may not be available yet. +""" + +from __future__ import annotations + +import asyncio +import importlib.metadata +import logging +from dataclasses import asdict, is_dataclass +from typing import Any, Dict, Optional + +try: # pragma: no cover - optional dependency during linting + from mcp.client.session import ClientSession +except ImportError: # pragma: no cover - surfaced at runtime with helpful error + ClientSession = Any # type: ignore[misc,assignment] + +try: # pragma: no cover - optional dependency during linting + from mcp.types import ClientCapabilities, Implementation +except ImportError: # pragma: no cover - surfaced at runtime when missing + ClientCapabilities = None # type: ignore[misc,assignment] + Implementation = None # type: ignore[misc,assignment] + + +_DEFAULT_PROTOCOL_VERSION = "2025-06-18" + + +def emit_console(level: str, message: str) -> None: + """Mirror log output to stdout when no logging handlers are configured.""" + + root_logger = logging.getLogger() + if not root_logger.handlers: + print(f"{level.upper()}: {message}") + + +def _detect_package_version() -> str: + """Return the installed TradingAgents version, falling back to ``0.0.0``.""" + + try: + return importlib.metadata.version("tradingagents") + except importlib.metadata.PackageNotFoundError: + return "0.0.0" + + +async def perform_handshake( + session: "ClientSession", + *, + client_label: str, + logger: logging.Logger, + capabilities: Optional[Any] = None, + protocol_version: str = _DEFAULT_PROTOCOL_VERSION, +) -> Dict[str, Any]: + """Execute the MCP initialization handshake and emit diagnostic metadata. + + Parameters + ---------- + session + A connected :class:`mcp.client.session.ClientSession` instance. + client_label + Human friendly tag used in log messages to identify the integration. + logger + Logger used for status updates. ``perform_handshake`` only emits at + ``INFO`` and ``DEBUG`` levels, so callers can opt in/out via standard + logging configuration. + capabilities + Optional per-client capability overrides. When omitted we fall back to + an empty :class:`ClientCapabilities` object if the SDK exposes one. If + the caller supplies a plain mapping, the helper forwards it unchanged to + the ``initialize`` call – the Python SDK accepts either Pydantic models + or raw dictionaries. + protocol_version + Requested MCP protocol version. Defaults to the latest spec revision we + target in this repository. + + Returns + ------- + dict + A dictionary with the initial handshake result. The structure matches + the underlying ``InitializeResult`` object but is normalised to basic + Python types so it can be safely logged or inspected by tests if + desired. + """ + + if ClientSession is Any: # pragma: no cover - defensive runtime guard + raise RuntimeError("perform_handshake cannot run without the 'mcp' package installed.") + + initialize_kwargs: Dict[str, Any] = {} + + if protocol_version: + initialize_kwargs["protocol_version"] = protocol_version + + if Implementation is not None: + initialize_kwargs["client_info"] = Implementation( + name=f"TradingAgents::{client_label}", + version=_detect_package_version(), + ) + else: + initialize_kwargs["client_info"] = { + "name": f"TradingAgents::{client_label}", + "version": _detect_package_version(), + } + + if capabilities is not None: + initialize_kwargs["capabilities"] = capabilities + elif ClientCapabilities is not None: + initialize_kwargs["capabilities"] = ClientCapabilities() + + try: + result = await session.initialize(**initialize_kwargs) + except TypeError: + # Older SDK versions did not support keyword arguments. Retry using the + # minimal signature, while still surfacing the original failure in debug + # logs so developers understand why metadata was omitted. + logger.debug( + "MCP initialize signature did not accept kwargs for %s, falling back to defaults.", + client_label, + exc_info=True, + ) + result = await session.initialize() + + # Normalise the result into a dictionary for consistent downstream usage. + payload: Dict[str, Any] + if hasattr(result, "model_dump"): + payload = result.model_dump() # type: ignore[assignment] + elif is_dataclass(result): + payload = asdict(result) # type: ignore[arg-type] + elif isinstance(result, dict): + payload = dict(result) + else: + payload = { + "protocolVersion": getattr(result, "protocolVersion", None) + or getattr(result, "protocol_version", None), + "capabilities": getattr(result, "capabilities", None), + "serverInfo": getattr(result, "serverInfo", None) + or getattr(result, "server_info", None), + "instructions": getattr(result, "instructions", None), + } + + protocol = payload.get("protocolVersion") or payload.get("protocol_version") + server_info = payload.get("serverInfo") or payload.get("server_info") + msg = f"{client_label} MCP handshake complete (protocol={protocol or 'unknown'}, server={server_info or 'n/a'})" + logger.info(msg) + emit_console("INFO", msg) + + instructions = payload.get("instructions") + if instructions: + instr_msg = f"{client_label} MCP server instructions: {instructions}" + logger.debug(instr_msg) + emit_console("DEBUG", instr_msg) + + # Send notifications/initialized when the SDK exposes the helper. The + # attribute name changed between releases, so we probe the common forms. + notification_senders = [ + getattr(session, "notify_initialized", None), + getattr(session, "notifications_initialized", None), + ] + + for sender in notification_senders: + if callable(sender): + try: + maybe_coro = sender() + if asyncio.iscoroutine(maybe_coro): + await maybe_coro + note = f"Sent notifications/initialized for {client_label}" + logger.debug(note) + emit_console("DEBUG", note) + break + except Exception as exc: # pragma: no cover - best-effort notification + fail_msg = ( + f"Unable to send notifications/initialized for {client_label}: {exc}" + ) + logger.debug(fail_msg, exc_info=True) + emit_console("DEBUG", fail_msg) + break + + return payload + + +__all__ = ["perform_handshake", "emit_console"] diff --git a/tradingagents/integrations/sequential_mcp/__init__.py b/tradingagents/integrations/sequential_mcp/__init__.py new file mode 100644 index 00000000..c2f26d9e --- /dev/null +++ b/tradingagents/integrations/sequential_mcp/__init__.py @@ -0,0 +1,6 @@ +"""Public interface for the Sequential Thinking MCP integration.""" + +from .client import SequentialMCPClient, SequentialMCPError +from .config import SequentialMCPConfig + +__all__ = ["SequentialMCPClient", "SequentialMCPError", "SequentialMCPConfig"] diff --git a/tradingagents/integrations/sequential_mcp/client.py b/tradingagents/integrations/sequential_mcp/client.py new file mode 100644 index 00000000..fe4409a1 --- /dev/null +++ b/tradingagents/integrations/sequential_mcp/client.py @@ -0,0 +1,207 @@ +"""Client for delegating planning to the Sequential Thinking MCP server.""" + +from __future__ import annotations + +import asyncio +import logging +import shlex +from contextlib import asynccontextmanager +from typing import Any, Dict, List, Optional + +from .config import SequentialMCPConfig +from ..mcp_handshake import emit_console, perform_handshake + +try: # pragma: no cover - optional dependency during linting + from mcp.client.session import ClientSession + from mcp.client.streamable_http import streamablehttp_client + from mcp.client.stdio import StdioServerParameters, stdio_client +except ImportError: # pragma: no cover - surfaced at runtime with helpful error + ClientSession = None # type: ignore[assignment] + streamablehttp_client = None # type: ignore[assignment] + stdio_client = None # type: ignore[assignment] + StdioServerParameters = None # type: ignore[assignment] + + +class SequentialMCPError(RuntimeError): + """Raised when the Sequential Thinking MCP client cannot satisfy a request.""" + + +class SequentialMCPClient: + """Simple interface for requesting action plans from the Sequential Thinking MCP server.""" + + def __init__(self, config: SequentialMCPConfig, logger: Optional[logging.Logger] = None) -> None: + self.config = config + self.logger = logger or logging.getLogger(__name__) + + def generate_plan(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """Request a plan from the sequential thinking server.""" + return self._call_tool("sequential_thinking", payload) + + def verify_connection(self) -> bool: + """Check that the Sequential MCP server is reachable and exposes its tool.""" + + if not self.config.enabled: + msg = "Sequential MCP disabled; skipping connectivity check." + self.logger.info(msg) + emit_console("INFO", msg) + return False + + try: + return asyncio.run(self._verify_async()) + except SequentialMCPError as exc: + msg = f"Sequential MCP connectivity probe failed: {exc}" + self.logger.warning(msg) + emit_console("WARNING", msg) + return False + except Exception as exc: # pragma: no cover - diagnostic logging only + msg = f"Sequential MCP connectivity probe failed: {exc}" + self.logger.warning(msg) + emit_console("WARNING", msg) + return False + + async def _verify_async(self) -> bool: + async with self._acquire_session() as session: + tools_response = await session.list_tools() + available = [getattr(tool, "name", "") for tool in getattr(tools_response, "tools", [])] + missing = self.config.required_toolset(available) + if missing: + msg = "Sequential MCP connected but missing required tools: " + ", ".join(missing) + self.logger.warning(msg) + emit_console("WARNING", msg) + return False + tools_list = ", ".join(sorted(filter(None, available))) + msg = f"Sequential MCP connectivity verified (tools={tools_list})" + self.logger.info(msg) + emit_console("INFO", msg) + return True + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _call_tool(self, tool_name: str, arguments: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + if not self.config.enabled: + raise SequentialMCPError("Sequential MCP integration is disabled.") + if ClientSession is None or ( + self.config.transport == "http" and streamablehttp_client is None + ): + raise SequentialMCPError( + "Package 'mcp' is required to use the Sequential Thinking MCP integration. Install it with `pip install mcp`." + ) + + self.config.validate() + payload = arguments or {} + return asyncio.run(self._call_tool_async(tool_name, payload)) + + async def _call_tool_async(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + try: + async with self._acquire_session() as session: + tools_response = await session.list_tools() + available = [getattr(tool, "name", "") for tool in getattr(tools_response, "tools", [])] + missing = self.config.required_toolset(available) + if missing: + raise SequentialMCPError( + "Sequential MCP server is missing required tools: " + ", ".join(missing) + ) + if tool_name not in available: + raise SequentialMCPError(f"Sequential MCP server does not expose tool '{tool_name}'.") + + result = await session.call_tool(tool_name, arguments) + return self._extract_content(result) + except SequentialMCPError: + raise + except BaseExceptionGroup as exc_group: # pragma: no cover + message = self._flatten_exception_message(exc_group) + raise SequentialMCPError(f"Failed to call Sequential MCP tool '{tool_name}': {message}") from exc_group + except Exception as exc: + raise SequentialMCPError(f"Failed to call Sequential MCP tool '{tool_name}': {exc}") from exc + + @asynccontextmanager + async def _acquire_session(self) -> "ClientSession": + if ClientSession is None: + raise SequentialMCPError( + "Package 'mcp' is required to use the Sequential Thinking MCP integration. Install it with `pip install mcp`." + ) + + if self.config.transport == "http": + if streamablehttp_client is None: + raise SequentialMCPError( + "HTTP transport requires the 'mcp' package. Install it with `pip install mcp`." + ) + base_url = self._build_http_base() + self.logger.debug("Connecting to Sequential MCP via HTTP at %s", base_url) + async with streamablehttp_client( + url=base_url, + timeout=self.config.timeout_seconds, + ) as (read_stream, write_stream, _session_id_cb): + async with ClientSession(read_stream, write_stream) as session: + await perform_handshake( + session, + client_label="Sequential", + logger=self.logger, + ) + yield session + return + + if self.config.transport == "stdio": + if stdio_client is None or StdioServerParameters is None: + raise SequentialMCPError( + "STDIO transport requires the 'mcp' package. Install it with `pip install mcp`." + ) + command = self.config.command or "python -m tradingagents.integrations.sequential_mcp.server" + args = shlex.split(command) + if not args: + raise SequentialMCPError("STDIO command is empty.") + + params = StdioServerParameters(command=args[0], args=args[1:]) + self.logger.debug("Launching Sequential MCP via STDIO: %s", args) + async with stdio_client(params) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await perform_handshake( + session, + client_label="Sequential", + logger=self.logger, + ) + yield session + return + + raise SequentialMCPError(f"Unsupported transport '{self.config.transport}'.") + + def _build_http_base(self) -> str: + if self.config.base_url: + return self.config.base_url.rstrip("/") + host = self.config.host + if host.startswith("http://") or host.startswith("https://"): + return host.rstrip("/") + return f"http://{host}:{self.config.port}/mcp" + + @staticmethod + def _extract_content(result: Any) -> Dict[str, Any]: + content = getattr(result, "content", None) + if content is None and isinstance(result, dict): + content = result.get("content") + + fragments: List[str] = [] + if content: + for item in content: + text_value = getattr(item, "text", None) + if text_value is None and isinstance(item, dict): + text_value = item.get("text") + fragments.append(str(text_value) if text_value is not None else str(item)) + text = "\n".join(fragment for fragment in fragments if fragment) + + structured = getattr(result, "structured_content", None) + if structured is None and isinstance(result, dict): + structured = result.get("structured_content") or result.get("structuredContent") + + return {"text": text, "structured": structured, "raw": result} + + @staticmethod + def _flatten_exception_message(exc: BaseException) -> str: + if isinstance(exc, BaseExceptionGroup): + parts: List[str] = [] + for item in exc.exceptions: + message = SequentialMCPClient._flatten_exception_message(item) + if message: + parts.append(message) + return "; ".join(parts) + return str(exc) diff --git a/tradingagents/integrations/sequential_mcp/config.py b/tradingagents/integrations/sequential_mcp/config.py new file mode 100644 index 00000000..f8139b0d --- /dev/null +++ b/tradingagents/integrations/sequential_mcp/config.py @@ -0,0 +1,49 @@ +"""Configuration utilities for connecting to the Sequential Thinking MCP server.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List + + +@dataclass +class SequentialMCPConfig: + """Normalized connection information for the Sequential Thinking MCP server.""" + + enabled: bool + transport: str + host: str + base_url: str + port: int + command: str + timeout_seconds: float + required_tools: List[str] + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SequentialMCPConfig": + required = data or {} + transport = (required.get("transport") or "http").lower() + return cls( + enabled=bool(required.get("enabled", False)), + transport=transport, + host=required.get("host", "127.0.0.1"), + base_url=required.get("base_url", ""), + port=int(required.get("port", 8000)), + command=required.get("command", ""), + timeout_seconds=float(required.get("timeout_seconds", 30.0)), + required_tools=list(required.get("required_tools", [])), + ) + + def validate(self) -> None: + if not self.enabled: + return + if self.transport not in {"http", "stdio"}: + raise ValueError(f"Unsupported Sequential Thinking MCP transport '{self.transport}'.") + if self.transport == "http" and not (self.base_url or self.host): + raise ValueError("HTTP transport requires a host or base_url value.") + + def required_toolset(self, available: Iterable[str]) -> List[str]: + if not self.required_tools: + return [] + available_set = set(available) + return [tool for tool in self.required_tools if tool not in available_set] diff --git a/tradingagents/integrations/sequential_mcp/server.py b/tradingagents/integrations/sequential_mcp/server.py new file mode 100644 index 00000000..1dbd49fa --- /dev/null +++ b/tradingagents/integrations/sequential_mcp/server.py @@ -0,0 +1,179 @@ +"""Lightweight Sequential Thinking MCP server bundled with TradingAgents.""" + +from __future__ import annotations + +import asyncio +from typing import Any, Dict, List, Set + +from mcp.server.fastmcp import FastMCP, Context + +app = FastMCP( + name="Sequential Thinking Planner", + instructions=( + "Generate ordered action plans for the TradingAgents workflow. " + "Return a list of analyst/manager nodes to execute and any notes for the orchestrator." + ), +) + +_DEFAULT_ANALYST_ORDER = ["market", "news", "social", "fundamentals"] +_SUPPORT_STAGES = ["debate", "manager", "trader", "risk"] + + +def _to_float(value: Any) -> float: + if value is None: + return 0.0 + if isinstance(value, (int, float)): + return float(value) + text = str(value).strip() + if not text: + return 0.0 + cleaned = text.replace("$", "").replace(",", "") + try: + return float(cleaned) + except ValueError: + return 0.0 + + +def _normalise(role: str) -> str: + mapping = { + "market_analyst": "market", + "run_market": "market", + "news_analyst": "news", + "run_news": "news", + "social_analyst": "social", + "fundamental": "fundamentals", + "fundamentals_analyst": "fundamentals", + "research_manager": "manager", + "risk_manager": "risk", + "stop": "end", + } + key = (role or "").strip().lower() + return mapping.get(key, key) + + +def _append_unique(target: List[str], items: List[str]) -> None: + seen = set(target) + for value in items: + if value and value not in seen: + target.append(value) + seen.add(value) + + +@app.tool() +async def sequential_thinking(ctx: Context, request: Dict[str, Any]) -> Dict[str, Any]: + """Return a sequential plan for TradingAgents.""" + + active = (request or {}).get("active_hypothesis") or {} + immediate_raw = str(active.get("immediate_actions") or active.get("action") or "").lower() + immediate_action = immediate_raw if immediate_raw in {"monitor", "escalate", "trade", "execute"} else "monitor" + priority_val = float(active.get("priority") or 0) + focus_symbol = str(request.get("focus_symbol") or active.get("ticker") or "").upper() + + required = [ + _normalise(role) + for role in active.get("required_analysts", []) + ] + required = [role for role in required if role in {"market", "news", "social", "fundamentals"}] + + actions: List[str] = [] + if required: + _append_unique(actions, required) + else: + _append_unique(actions, _DEFAULT_ANALYST_ORDER) + + notes_parts = [] + reasoning: List[str] = [] + summary = request.get("summary") or "" + if summary: + notes_parts.append(summary) + reasoning.append(f"Initial directive: {immediate_action.upper()}") + if immediate_action in {"monitor", "escalate", "trade", "execute"}: + notes_parts.append(f"Directive: {immediate_action.upper()}") + portfolio = request.get("account_summary") or {} + buying_power = portfolio.get("buying_power") or portfolio.get("buying_power_usd") + cash = portfolio.get("cash") or portfolio.get("cash_usd") + buying_power_val = _to_float(buying_power) + cash_val = _to_float(cash) + portfolio_value = _to_float(portfolio.get("portfolio_value") or portfolio.get("equity")) + if buying_power or cash: + notes_parts.append( + "Capital -> " + + ", ".join(filter(None, [f"Cash: {cash}" if cash else "", f"Buying Power: {buying_power}" if buying_power else ""])) + ) + + trade_policy = request.get("trade_policy") or {} + priority_threshold = float(trade_policy.get("priority_threshold", 0.8)) + min_cash_abs = float(trade_policy.get("min_cash_absolute", 0)) + min_cash_ratio = float(trade_policy.get("min_cash_ratio", 0)) + min_cash_required = max(min_cash_abs, portfolio_value * min_cash_ratio) + + positions_summary = request.get("positions_summary") or [] + held_symbols: Set[str] = set() + for pos in positions_summary: + symbol = str(pos.get("symbol") or pos.get("symbol:") or "").upper() + qty_val = _to_float(pos.get("quantity") or pos.get("qty") or 0) + if symbol and qty_val != 0: + held_symbols.add(symbol) + + reasoning.append( + f"Policy thresholds -> priority >= {priority_threshold:.2f}, min cash ${min_cash_required:,.0f}" + ) + + if immediate_action in {"", "monitor"} and focus_symbol: + reasoning.append( + f"Evaluating {focus_symbol}: priority {priority_val:.2f}, buying power ${buying_power_val:,.0f}" + ) + if priority_val >= priority_threshold: + if focus_symbol not in held_symbols and buying_power_val >= min_cash_required: + immediate_action = "trade" + notes_parts.append(f"Auto-upgraded to TRADE for {focus_symbol}") + reasoning.append("Priority high and sufficient buying power -> promote to TRADE") + elif buying_power_val > 0: + immediate_action = "escalate" + notes_parts.append(f"Escalate {focus_symbol} due to priority {priority_val:.2f}") + reasoning.append("Priority high but capital reserved -> escalate to manager") + else: + notes_parts.append("Insufficient buying power to escalate") + reasoning.append("Insufficient buying power -> remain monitoring") + else: + reasoning.append("Priority below threshold -> remain monitoring") + + if immediate_action in {"trade", "execute"}: + _append_unique(actions, ["debate", "manager"]) + _append_unique(actions, ["trader"]) + _append_unique(actions, ["risk"]) + if focus_symbol: + notes_parts.append( + f"Queue trader for {focus_symbol} (buying power ${buying_power_val:,.0f}, cash ${cash_val:,.0f})" + ) + reasoning.append("Trader and risk review queued for execution") + elif immediate_action == "escalate": + _append_unique(actions, ["debate", "manager"]) + if focus_symbol: + notes_parts.append(f"Manager review requested for {focus_symbol}") + reasoning.append("Escalation path via manager") + else: + if required: + _append_unique(actions, required) + else: + _append_unique(actions, _DEFAULT_ANALYST_ORDER) + reasoning.append("Maintain analyst coverage with monitoring loop") + + return { + "actions": actions, + "next_decision": immediate_action, + "notes": "\n".join(notes_parts).strip(), + "reasoning": reasoning, + } + + +async def _main_async() -> None: + await app.run_stdio_async() + + +def main() -> None: + asyncio.run(_main_async()) + + +if __name__ == "__main__": + main() diff --git a/tradingagents/services/__init__.py b/tradingagents/services/__init__.py new file mode 100644 index 00000000..dfc264ff --- /dev/null +++ b/tradingagents/services/__init__.py @@ -0,0 +1,24 @@ +"""Service layer utilities for TradingAgents.""" + +from .account import AccountService, AccountSnapshot +from .auto_trade import ( + AutoTradeService, + AutoTradeResult, + TickerDecision, + SequentialPlan, + StrategyDirective, +) +from .responses_auto_trade import ResponsesAutoTradeService +from .memory import TickerMemoryStore + +__all__ = [ + "AccountService", + "AccountSnapshot", + "AutoTradeService", + "AutoTradeResult", + "ResponsesAutoTradeService", + "TickerMemoryStore", + "TickerDecision", + "SequentialPlan", + "StrategyDirective", +] diff --git a/tradingagents/services/account.py b/tradingagents/services/account.py new file mode 100644 index 00000000..462f3572 --- /dev/null +++ b/tradingagents/services/account.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import logging +import re +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Dict, List, Optional + +from tradingagents.integrations.alpaca_mcp import AlpacaMCPClient, AlpacaMCPConfig, AlpacaMCPError + + +@dataclass +class AccountSnapshot: + """Structured representation of the Alpaca account state.""" + + fetched_at: datetime + account_text: str + positions_text: str + orders_text: str + account: Dict[str, Any] + positions: List[Dict[str, Any]] + orders: List[Dict[str, Any]] + + def buying_power(self) -> float: + value = self.account.get("buying_power") or self.account.get("buying_power_usd") + return _as_float(value) + + def cash(self) -> float: + value = self.account.get("cash") or self.account.get("cash_usd") + return _as_float(value) + + def portfolio_value(self) -> float: + value = ( + self.account.get("portfolio_value") + or self.account.get("equity") + or self.account.get("equity_value") + ) + return _as_float(value) + + def position_symbols(self) -> List[str]: + symbols = [] + for position in self.positions: + symbol = str(position.get("symbol") or position.get("symbol:") or "").upper() + qty = _as_float(position.get("quantity") or position.get("qty") or 0) + if symbol and qty != 0: + symbols.append(symbol) + return symbols + + +class AccountService: + """Fetch and cache Alpaca MCP account information.""" + + def __init__(self, alpaca_config: Dict[str, Any], logger: Optional[logging.Logger] = None) -> None: + config = AlpacaMCPConfig.from_dict(alpaca_config or {}) + self.client = AlpacaMCPClient(config, logger=logger) + self.logger = logger or logging.getLogger(__name__) + self._snapshot: Optional[AccountSnapshot] = None + self.enabled = bool(getattr(self.client.config, "enabled", False)) + if not self.enabled: + self.logger.info("Alpaca MCP integration disabled; account snapshot will be unavailable.") + + def refresh(self) -> AccountSnapshot: + """Fetch the latest account snapshot from the Alpaca MCP server.""" + + if not self.enabled: + raise RuntimeError( + "Alpaca MCP integration is disabled. Set ALPACA_MCP_ENABLED=true (and related connection settings) to use the auto-trade workflow." + ) + + import asyncio + + async def _fetch_all() -> Dict[str, str]: + async with self.client._acquire_session() as session: # type: ignore[attr-defined] + account_text = await self.client._call_tool_async("get_account_info", {}, session=session) + positions_text = await self.client._call_tool_async("get_positions", {}, session=session, validate=False) + orders_text = await self.client._call_tool_async("get_orders", {"status": "all", "limit": 50}, session=session, validate=False) + return { + "account": account_text, + "positions": positions_text, + "orders": orders_text, + } + + try: + texts = asyncio.run(_fetch_all()) + except AlpacaMCPError as exc: + raise RuntimeError(f"Failed to retrieve Alpaca account snapshot: {exc}") from exc + except Exception as exc: + raise RuntimeError(f"Failed to retrieve Alpaca account snapshot: {exc}") from exc + + snapshot = AccountSnapshot( + fetched_at=datetime.utcnow(), + account_text=texts["account"], + positions_text=texts["positions"], + orders_text=texts["orders"], + account=_parse_key_values(texts["account"]), + positions=_parse_position_blocks(texts["positions"]), + orders=_parse_order_blocks(texts["orders"]), + ) + self._snapshot = snapshot + return snapshot + + @property + def snapshot(self) -> Optional[AccountSnapshot]: + return self._snapshot + + +def _parse_key_values(text: str) -> Dict[str, Any]: + data: Dict[str, Any] = {} + pattern = re.compile(r"^([A-Za-z0-9 _/-]+):\s*(.+)$") + for line in text.splitlines(): + line = line.strip() + if not line or line.endswith(":"): + continue + match = pattern.match(line) + if not match: + continue + key = match.group(1).strip().lower().replace(" ", "_") + value = match.group(2).strip() + data[key] = value + return data + + +def _parse_position_blocks(text: str) -> List[Dict[str, Any]]: + if not text or "No open positions" in text: + return [] + blocks = [] + current: Dict[str, Any] = {} + for raw_line in text.splitlines(): + line = raw_line.strip() + if not line: + continue + if line.startswith("Symbol:") and current: + blocks.append(current) + current = {} + if ":" in line: + key, value = line.split(":", 1) + current[key.strip().lower().replace(" ", "_")] = value.strip() + if current: + blocks.append(current) + return blocks + + +def _parse_order_blocks(text: str) -> List[Dict[str, Any]]: + if not text or "No all orders" in text or "No orders" in text: + return [] + blocks = [] + current: Dict[str, Any] = {} + for raw_line in text.splitlines(): + line = raw_line.strip() + if not line: + continue + if line.startswith("Order ID:") and current: + blocks.append(current) + current = {} + if ":" in line: + key, value = line.split(":", 1) + current[key.strip().lower().replace(" ", "_")] = value.strip() + if current: + blocks.append(current) + return blocks + + +def _as_float(value: Any) -> float: + if value is None: + return 0.0 + if isinstance(value, (int, float)): + return float(value) + text = str(value).strip() + if not text: + return 0.0 + cleaned = text.replace("$", "").replace(",", "") + try: + return float(cleaned) + except ValueError: + return 0.0 diff --git a/tradingagents/services/auto_trade.py b/tradingagents/services/auto_trade.py new file mode 100644 index 00000000..0f90f104 --- /dev/null +++ b/tradingagents/services/auto_trade.py @@ -0,0 +1,628 @@ +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass, field +from datetime import date, datetime, timedelta +from typing import Any, Dict, List, Optional, Set, Tuple + +from tradingagents.graph.propagation import Propagator +from tradingagents.graph.trading_graph import TradingAgentsGraph +from tradingagents.services.account import AccountSnapshot + + +@dataclass +class SequentialPlan: + actions: List[str] + next_decision: str + notes: str + reasoning: List[str] = field(default_factory=list) + + +@dataclass +class StrategyDirective: + name: str + horizon_hours: float + target_pct: float + stop_pct: float + success_metric: str + failure_metric: str + follow_up: str = "reevaluate" + urgency: str = "medium" + deadline: Optional[str] = "" + notes: str = "" + success_price: Optional[float] = None + failure_price: Optional[float] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "horizon_hours": self.horizon_hours, + "target_pct": self.target_pct, + "stop_pct": self.stop_pct, + "success_metric": self.success_metric, + "failure_metric": self.failure_metric, + "follow_up": self.follow_up, + "urgency": self.urgency, + "deadline": self.deadline, + "notes": self.notes, + "success_price": self.success_price, + "failure_price": self.failure_price, + } + + +@dataclass +class TickerDecision: + ticker: str + hypothesis: Dict[str, Any] + sequential_plan: SequentialPlan + action_queue: List[str] + immediate_action: str + priority: float + final_decision: str = "" + trader_plan: str = "" + final_notes: str = "" + strategy: StrategyDirective = None # type: ignore + + def to_dict(self) -> Dict[str, Any]: + return { + "ticker": self.ticker, + "hypothesis": self.hypothesis, + "sequential_plan": { + "actions": self.sequential_plan.actions, + "next_decision": self.sequential_plan.next_decision, + "notes": self.sequential_plan.notes, + "reasoning": self.sequential_plan.reasoning, + }, + "action_queue": self.action_queue, + "immediate_action": self.immediate_action, + "priority": self.priority, + "final_decision": self.final_decision, + "trader_plan": self.trader_plan, + "final_notes": self.final_notes, + "strategy": self.strategy.to_dict() if self.strategy else None, + } + + +@dataclass +class AutoTradeResult: + focus_tickers: List[str] + decisions: List[TickerDecision] + account_snapshot: AccountSnapshot + raw_state: Dict[str, Any] + + def summary(self) -> Dict[str, Any]: + return { + "focus_tickers": self.focus_tickers, + "decisions": [decision.to_dict() for decision in self.decisions], + "buying_power": self.account_snapshot.buying_power(), + "cash": self.account_snapshot.cash(), + "portfolio_value": self.account_snapshot.portfolio_value(), + "fetched_at": self.account_snapshot.fetched_at.isoformat(), + } + + +class AutoTradeService: + """High level orchestration of the auto-trade workflow.""" + + def __init__( + self, + config: Dict[str, Any], + graph: Optional[TradingAgentsGraph] = None, + logger: Optional[logging.Logger] = None, + ) -> None: + self.config = config + self.graph = graph or TradingAgentsGraph(config=config) + self.logger = logger or logging.getLogger(__name__) + self.propagator: Propagator = getattr(self.graph, "propagator", Propagator()) + + def run( + self, + snapshot: AccountSnapshot, + *, + focus_override: Optional[List[str]] = None, + allow_market_closed: bool = False, + ) -> AutoTradeResult: + """Execute the auto-trade workflow using the provided account snapshot.""" + + auto_trade_cfg = self.config.get("auto_trade", {}) or {} + skip_when_market_closed = bool(auto_trade_cfg.get("skip_when_market_closed")) + if skip_when_market_closed and not allow_market_closed: + checker = getattr(self.graph, "check_market_status", None) + status = checker() if callable(checker) else None + if status and not status.get("is_open", False): + clock_text = status.get("clock_text") + message = ( + "Auto-trade skipped: market is currently closed. " + "Set AUTO_TRADE_SKIP_WHEN_MARKET_CLOSED=false to override." + ) + if clock_text: + message = f"{message}\n{clock_text.strip()}" + elif status.get("reason"): + message = f"{message}\nReason: {status['reason']}" + self.logger.info(message) + return AutoTradeResult( + focus_tickers=[], + decisions=[], + account_snapshot=snapshot, + raw_state={ + "skip_reason": message, + "market_clock": clock_text, + }, + ) + + mode = str(auto_trade_cfg.get("mode") or "graph").lower() + if mode == "responses": + from .responses_auto_trade import ResponsesAutoTradeService + + responses_service = ResponsesAutoTradeService( + config=self.config, + graph=self.graph, + logger=self.logger, + ) + result = responses_service.run(snapshot, focus_override=focus_override) + if hasattr(self.graph, "clear_manual_portfolio_snapshot"): + try: + self.graph.clear_manual_portfolio_snapshot() # type: ignore[attr-defined] + except Exception: + pass + return result + + if hasattr(self.graph, "set_manual_portfolio_snapshot"): + try: + self.graph.set_manual_portfolio_snapshot(snapshot) # type: ignore[attr-defined] + except Exception: # pragma: no cover - defensive + self.logger.debug("Unable to seed manual portfolio snapshot", exc_info=True) + + seed_symbols = focus_override or self._determine_focus_tickers(snapshot) + worklist = list(dict.fromkeys((symbol or "").upper() for symbol in seed_symbols if symbol)) + if not worklist: + worklist = ["SPY"] + + max_runs = auto_trade_cfg.get("max_tickers") + try: + max_runs_int = int(max_runs) if max_runs is not None else 12 + except (TypeError, ValueError): + max_runs_int = 12 + if max_runs_int <= 0: + max_runs_int = len(worklist) or 1 + + decisions: List[TickerDecision] = [] + raw_state: Dict[str, Any] = {} + visited: Set[str] = set() + discovered_order: List[str] = [] + discovered_seen: Set[str] = set() + cached_quick_signals: Dict[str, Any] = {} + cached_market_data: Dict[str, Any] = {} + cached_ticker_plans: Dict[str, Any] = {} + cached_pending: List[str] = [] + + while worklist and len(visited) < max_runs_int: + ticker = (worklist.pop(0) or "").upper() + if not ticker or ticker in visited: + continue + + initial_overrides: Dict[str, Any] = {} + if cached_quick_signals: + initial_overrides["orchestrator_quick_signals"] = dict(cached_quick_signals) + if cached_market_data: + initial_overrides["orchestrator_market_data"] = dict(cached_market_data) + if cached_ticker_plans: + initial_overrides["orchestrator_ticker_plans"] = dict(cached_ticker_plans) + if cached_pending: + initial_overrides["orchestrator_pending_tickers"] = list(cached_pending) + if discovered_order: + initial_overrides["orchestrator_focus_symbols"] = list(discovered_order) + initial_overrides["orchestrator_focus_override"] = [ticker] + + overrides = initial_overrides or None + + try: + final_state, processed = self.graph.propagate( + ticker, + date.today().isoformat(), + initial_overrides=overrides, + ) + except Exception as exc: # pragma: no cover - surfaced to CLI + self.logger.exception("Graph propagation failed for %s", ticker) + raw_state[ticker] = {"error": str(exc)} + visited.add(ticker) + continue + + raw_state[ticker] = { + "final_state": final_state, + "processed": processed, + } + + try: + decision = self._decision_from_state(ticker, final_state, processed) + decisions.append(decision) + if decision.ticker not in discovered_seen: + discovered_order.append(decision.ticker) + discovered_seen.add(decision.ticker) + except Exception as exc: # pragma: no cover - best-effort diagnostics + self.logger.exception("Failed to build decision for %s", ticker) + finally: + visited.add(ticker) + + new_symbols = self._extract_focus_symbols(final_state) + for sym in new_symbols: + if sym not in discovered_seen: + discovered_order.append(sym) + discovered_seen.add(sym) + if sym and sym not in visited and sym not in worklist: + worklist.append(sym) + + quick_signals_state = final_state.get("orchestrator_quick_signals") or {} + if isinstance(quick_signals_state, dict): + for key, value in quick_signals_state.items(): + key_u = str(key).upper() + if value: + cached_quick_signals[key_u] = value + + market_data_state = final_state.get("orchestrator_market_data") or {} + if isinstance(market_data_state, dict): + for key, value in market_data_state.items(): + key_u = str(key).upper() + if value: + cached_market_data[key_u] = value + + ticker_plans_state = final_state.get("orchestrator_ticker_plans") or {} + if isinstance(ticker_plans_state, dict): + for key, value in ticker_plans_state.items(): + key_u = str(key).upper() + if value: + cached_ticker_plans[key_u] = value + + pending_state = final_state.get("orchestrator_pending_tickers") or [] + if isinstance(pending_state, list): + cached_pending = [str(item).upper() for item in pending_state if str(item).strip()] + + focus_tickers = list(dict.fromkeys(discovered_order)) + if not focus_tickers: + focus_tickers = list(dict.fromkeys(decision.ticker for decision in decisions)) + if not focus_tickers: + focus_tickers = worklist[:] + if not focus_tickers: + focus_tickers = list(dict.fromkeys((symbol or "").upper() for symbol in seed_symbols if symbol)) or ["SPY"] + + if hasattr(self.graph, "clear_manual_portfolio_snapshot"): + try: + self.graph.clear_manual_portfolio_snapshot() # type: ignore[attr-defined] + except Exception: # pragma: no cover - defensive + pass + + return AutoTradeResult( + focus_tickers=focus_tickers, + decisions=decisions, + account_snapshot=snapshot, + raw_state=raw_state, + ) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def _determine_focus_tickers(self, snapshot: AccountSnapshot) -> List[str]: + universe_raw = self.config.get("portfolio_orchestrator", {}).get("universe", "") + universe = [sym.strip().upper() for sym in universe_raw.split(",") if sym.strip()] + holdings = snapshot.position_symbols() + combined: List[str] = [] + for symbol in list(dict.fromkeys(universe + holdings)): + if symbol: + combined.append(symbol) + return combined or ["SPY"] + + def _decision_from_state( + self, + requested_ticker: str, + final_state: Dict[str, Any], + processed: Dict[str, Any], + ) -> TickerDecision: + focus_symbol = str(final_state.get("target_ticker") or requested_ticker or "").upper() + hypotheses = final_state.get("orchestrator_hypotheses") or [] + active = final_state.get("active_hypothesis") or {} + hypothesis = _select_hypothesis_for_ticker(focus_symbol, hypotheses, active) + + plans = final_state.get("orchestrator_ticker_plans") + plan_raw = plans.get(focus_symbol) if isinstance(plans, dict) else None + if plan_raw is None and isinstance(plans, dict): + plan_raw = plans.get(requested_ticker.upper()) + actions, next_decision, plan_notes, reasoning = self._parse_plan(plan_raw) + + immediate_action = str( + final_state.get("orchestrator_action") + or hypothesis.get("immediate_actions") + or next_decision + or "" + ).strip().lower() + if not immediate_action: + immediate_action = "monitor" + if not next_decision: + next_decision = immediate_action + + try: + priority = float(hypothesis.get("priority") or 0.0) + except (TypeError, ValueError): + priority = 0.0 + + action_queue = self._string_list(final_state.get("action_queue"), lower=True) + + processed_decision = processed.get("decision") if isinstance(processed, dict) else None + final_decision = str(processed_decision or "").strip().upper() + if not final_decision and immediate_action: + final_decision = immediate_action.upper() + if final_decision == "PLEASE PROVIDE THE PARAGRAPH OR FINANCIAL REPORT FOR ANALYSIS.": + final_decision = immediate_action.upper() + final_trade_text = str(final_state.get("final_trade_decision") or "").strip() + if final_trade_text: + if final_decision and final_trade_text.lower() not in final_decision.lower(): + final_decision = f"{final_decision} | {final_trade_text}" + elif not final_decision: + final_decision = final_trade_text + + trader_plan = str( + final_state.get("trader_investment_plan") + or final_state.get("investment_plan") + or "" + ).strip() + + notes_parts: List[str] = [] + for key in ("portfolio_summary", "planner_notes"): + value = final_state.get(key) + if value: + notes_parts.append(str(value).strip()) + if plan_notes: + notes_parts.append(plan_notes) + execution = processed.get("execution") if isinstance(processed, dict) else None + if execution: + notes_parts.append("Execution: " + json.dumps(execution, default=str)) + final_notes = "\n\n".join(dict.fromkeys(part for part in notes_parts if part)) + + sequential_plan = SequentialPlan( + actions=actions, + next_decision=next_decision, + notes=plan_notes, + reasoning=reasoning, + ) + + strategy_raw = final_state.get("orchestrator_strategy") or hypothesis.get("strategy") + strategy = resolve_strategy_directive(self.config, strategy_raw) + + return TickerDecision( + ticker=focus_symbol or requested_ticker.upper(), + hypothesis=hypothesis, + sequential_plan=sequential_plan, + action_queue=action_queue, + immediate_action=immediate_action, + priority=priority, + final_decision=final_decision, + trader_plan=trader_plan, + final_notes=final_notes, + strategy=strategy, + ) + + def _parse_plan(self, plan_raw: Any) -> Tuple[List[str], str, str, List[str]]: + actions: List[str] = [] + next_decision = "" + notes = "" + reasoning: List[str] = [] + + structured: Any = None + if isinstance(plan_raw, dict): + structured = plan_raw.get("structured") + if structured is None and any(key in plan_raw for key in ("actions", "steps", "reasoning", "next_decision")): + structured = { + key: plan_raw.get(key) + for key in ("actions", "steps", "reasoning", "next_decision", "notes") + if plan_raw.get(key) is not None + } + if structured is None: + structured = self._extract_json_candidate(plan_raw.get("text") or plan_raw.get("plan")) + elif isinstance(plan_raw, str): + structured = self._extract_json_candidate(plan_raw) + + if isinstance(structured, dict): + actions = self._string_list(structured.get("actions") or structured.get("steps"), lower=True) + next_decision = str(structured.get("next_decision") or "").strip().lower() + notes = str(structured.get("notes") or notes or "") + reasoning = self._string_list(structured.get("reasoning")) + elif isinstance(structured, list): + actions = self._string_list(structured, lower=True) + + if not actions and isinstance(plan_raw, dict): + actions = self._string_list(plan_raw.get("actions"), lower=True) + + if not notes and isinstance(plan_raw, dict): + note_candidate = plan_raw.get("notes") or plan_raw.get("text") or plan_raw.get("error") + if isinstance(note_candidate, str): + notes = note_candidate.strip() + + if not reasoning and isinstance(plan_raw, dict): + reasoning = self._string_list(plan_raw.get("reasoning")) + + return actions, next_decision, notes, reasoning + + @staticmethod + def _extract_json_candidate(text: Optional[str]) -> Optional[Any]: + if not text or not isinstance(text, str): + return None + candidate = text.strip() + if not candidate: + return None + if candidate.startswith("```"): + parts = candidate.split("```") + for part in parts: + segment = part.strip() + if not segment: + continue + try: + return json.loads(segment) + except json.JSONDecodeError: + continue + return None + try: + return json.loads(candidate) + except json.JSONDecodeError: + return None + + @staticmethod + def _string_list(value: Any, *, lower: bool = False) -> List[str]: + if value is None: + return [] + if isinstance(value, (list, tuple, set)): + iterable = list(value) + else: + iterable = [value] + + result: List[str] = [] + for item in iterable: + if item is None: + continue + if isinstance(item, dict): + text_value = ( + item.get("action") + or item.get("name") + or item.get("tool") + or item.get("role") + or item.get("value") + ) + text = str(text_value or "").strip() + else: + text = str(item).strip() + if not text: + continue + result.append(text.lower() if lower else text) + return result + + @staticmethod + def _as_iterable(value: Any) -> List[Any]: + if value is None: + return [] + if isinstance(value, (list, tuple, set)): + return list(value) + return [value] + + def _extract_focus_symbols(self, final_state: Dict[str, Any]) -> List[str]: + collected: List[str] = [] + seen: Set[str] = set() + + def push(symbol: Any) -> None: + text = str(symbol or "").strip().upper() + if text and text not in seen: + seen.add(text) + collected.append(text) + + push(final_state.get("target_ticker")) + push(final_state.get("company_of_interest")) + active = final_state.get("active_hypothesis") + if isinstance(active, dict): + push(active.get("ticker")) + + for key in ("orchestrator_focus_symbols", "orchestrator_pending_tickers"): + for item in self._as_iterable(final_state.get(key)): + if isinstance(item, dict): + push(item.get("ticker")) + else: + push(item) + + return collected + + +def _select_hypothesis_for_ticker( + ticker: str, + hypotheses: List[Dict[str, Any]], + active: Dict[str, Any], +) -> Dict[str, Any]: + ticker = ticker.upper() + if str(active.get("ticker") or "").upper() == ticker: + return active + for hypothesis in hypotheses: + if str(hypothesis.get("ticker") or "").upper() == ticker: + return hypothesis + return active + + +def resolve_strategy_directive( + config: Dict[str, Any], + overrides: Optional[Any] = None, +) -> StrategyDirective: + strategies_cfg = config.get("trading_strategies", {}) or {} + presets = strategies_cfg.get("presets", {}) or {} + + if isinstance(overrides, dict): + overrides_dict = overrides + elif isinstance(overrides, str): + overrides_dict = {"name": overrides} + else: + overrides_dict = {} + + def _preset_fallback() -> Dict[str, Any]: + default_key = str(strategies_cfg.get("default") or "").lower() + if default_key and default_key in presets: + return presets[default_key] + if "swing" in presets: + return presets["swing"] + if presets: + return next(iter(presets.values())) + return {} + + preset_name = str(overrides_dict.get("name") or strategies_cfg.get("default") or "swing").lower() + preset = presets.get(preset_name) or _preset_fallback() + + def _resolve_value(key: str, default: Any) -> Any: + value = overrides_dict.get(key) + if value not in (None, ""): + return value + if preset and preset.get(key) not in (None, ""): + return preset[key] + return default + + horizon_hours = float(_resolve_value("horizon_hours", 72)) + target_pct = float(_resolve_value("target_pct", 0.03)) + stop_pct = float(_resolve_value("stop_pct", 0.015)) + follow_up = str(_resolve_value("follow_up", "reevaluate")) + urgency = str(_resolve_value("urgency", "medium")) + success_metric = str( + _resolve_value( + "success_metric", + f"Gain at least +{target_pct * 100:.1f}% within {horizon_hours:.0f}h", + ) + ) + failure_metric = str( + _resolve_value( + "failure_metric", + f"Drawdown of -{stop_pct * 100:.1f}% or thesis invalidated before {horizon_hours:.0f}h", + ) + ) + notes = str(_resolve_value("notes", "")) + + deadline = _resolve_value("deadline", "") + if not deadline: + deadline_dt = datetime.utcnow() + timedelta(hours=horizon_hours) + deadline = deadline_dt.strftime("%Y-%m-%dT%H:%M:%SZ") + + success_price = overrides_dict.get("success_price") + failure_price = overrides_dict.get("failure_price") + try: + success_price = float(success_price) if success_price is not None else None + except (TypeError, ValueError): + success_price = None + try: + failure_price = float(failure_price) if failure_price is not None else None + except (TypeError, ValueError): + failure_price = None + + return StrategyDirective( + name=preset_name, + horizon_hours=horizon_hours, + target_pct=target_pct, + stop_pct=stop_pct, + success_metric=success_metric, + failure_metric=failure_metric, + follow_up=follow_up, + urgency=urgency, + deadline=str(deadline), + notes=notes, + success_price=success_price, + failure_price=failure_price, + ) diff --git a/tradingagents/services/autopilot_broker.py b/tradingagents/services/autopilot_broker.py new file mode 100644 index 00000000..77908c5f --- /dev/null +++ b/tradingagents/services/autopilot_broker.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from datetime import date, timedelta +from typing import Callable, Dict, Iterable, Optional + +from tradingagents.dataflows.interface import route_to_vendor +from tradingagents.services.autopilot_worker import AutopilotWorker +from tradingagents.services.hypothesis_store import HypothesisRecord, HypothesisStore + + +@dataclass +class PriceThreshold: + symbol: str + operator: str + value: float + + +class AutopilotBroker: + """Simple polling broker that watches price thresholds and enqueues events.""" + + def __init__( + self, + store: HypothesisStore, + worker: AutopilotWorker, + price_fetcher: Optional[Callable[[str], Optional[float]]] = None, + ) -> None: + self.store = store + self.worker = worker + self.price_fetcher = price_fetcher or default_price_fetcher + self.poll_interval = 60 # seconds + + def parse_triggers(self, record: HypothesisRecord) -> Iterable[PriceThreshold]: + for trigger in record.triggers: + trigger_str = str(trigger).strip().lower() + if trigger_str.startswith("price >="): + try: + symbol, value = self._parse_simple_trigger(record.ticker, trigger_str, ">=") + yield PriceThreshold(symbol=symbol, operator=">=", value=value) + except ValueError: + continue + elif trigger_str.startswith("price <="): + try: + symbol, value = self._parse_simple_trigger(record.ticker, trigger_str, "<=") + yield PriceThreshold(symbol=symbol, operator="<=", value=value) + except ValueError: + continue + + def poll_once(self) -> Dict[str, str]: + outcomes: Dict[str, str] = {} + records = self.store.list() + for record in records: + for trigger in self.parse_triggers(record): + latest_price = self.price_fetcher(trigger.symbol) + if latest_price is None: + continue + if self._evaluate(trigger, latest_price): + event = self.worker.enqueue_event( + record.id, + event_type="price_threshold", + payload={"symbol": trigger.symbol, "price": latest_price, "operator": trigger.operator, "value": trigger.value}, + ) + outcomes[event.id] = f"Triggered price alert for {trigger.symbol}" + return outcomes + + def _evaluate(self, trigger: PriceThreshold, price: float) -> bool: + if trigger.operator == ">=": + return price >= trigger.value + if trigger.operator == "<=": + return price <= trigger.value + return False + + def _parse_simple_trigger(self, default_symbol: str, trigger_str: str, operator: str) -> (str, float): + parts = trigger_str.replace("price", "").strip().split(operator) + if len(parts) != 2: + raise ValueError("invalid trigger format") + left = parts[0].strip().upper() + symbol = left if left else default_symbol.upper() + value = float(parts[1].strip()) + return symbol, value + + +def default_price_fetcher(symbol: str) -> Optional[float]: + symbol = symbol.upper() + today = date.today() + start = today - timedelta(days=2) + try: + csv_text = route_to_vendor("get_stock_data", symbol, start.isoformat(), today.isoformat()) + except Exception: + return None + lines = [line for line in str(csv_text).splitlines() if line and not line.startswith("#")] + if not lines: + return None + last_line = lines[-1] + parts = last_line.split(",") + if len(parts) < 5: + return None + try: + return float(parts[4]) # close price + except ValueError: + return None diff --git a/tradingagents/services/autopilot_events.py b/tradingagents/services/autopilot_events.py new file mode 100644 index 00000000..a25e21ed --- /dev/null +++ b/tradingagents/services/autopilot_events.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import json +import uuid +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List + +ISO_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" + + +def _utcnow() -> str: + return datetime.utcnow().strftime(ISO_FORMAT) + + +@dataclass +class HypothesisEvent: + id: str + hypothesis_id: str + event_type: str + payload: Dict[str, Any] + created_at: str + + def to_dict(self) -> Dict[str, Any]: + return { + "id": self.id, + "hypothesis_id": self.hypothesis_id, + "event_type": self.event_type, + "payload": self.payload, + "created_at": self.created_at, + } + + @classmethod + def from_dict(cls, payload: Dict[str, Any]) -> "HypothesisEvent": + return cls( + id=str(payload.get("id") or uuid.uuid4().hex), + hypothesis_id=str(payload.get("hypothesis_id") or ""), + event_type=str(payload.get("event_type") or "unknown"), + payload=dict(payload.get("payload") or {}), + created_at=str(payload.get("created_at") or _utcnow()), + ) + + +class AutopilotEventQueue: + def __init__(self, root_dir: Path) -> None: + self.root = Path(root_dir) + self.root.mkdir(parents=True, exist_ok=True) + self.path = self.root / "events.json" + + def enqueue(self, event: HypothesisEvent) -> None: + events = self._load() + events.append(event) + self._save(events) + + def dequeue_all(self) -> List[HypothesisEvent]: + events = self._load() + self._save([]) + return events + + def list(self) -> List[HypothesisEvent]: + return self._load() + + def _load(self) -> List[HypothesisEvent]: + if not self.path.exists(): + return [] + with self.path.open("r", encoding="utf-8") as handle: + try: + payload = json.load(handle) + except json.JSONDecodeError: + payload = [] + return [HypothesisEvent.from_dict(item) for item in payload or []] + + def _save(self, events: List[HypothesisEvent]) -> None: + serializable = [event.to_dict() for event in events] + tmp = self.path.with_suffix(".tmp") + with tmp.open("w", encoding="utf-8") as handle: + json.dump(serializable, handle, indent=2) + tmp.replace(self.path) diff --git a/tradingagents/services/autopilot_worker.py b/tradingagents/services/autopilot_worker.py new file mode 100644 index 00000000..f1afd796 --- /dev/null +++ b/tradingagents/services/autopilot_worker.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +from tradingagents.services.account import AccountService, AccountSnapshot +from tradingagents.services.auto_trade import AutoTradeResult, AutoTradeService, TickerDecision +from tradingagents.services.autopilot_events import AutopilotEventQueue, HypothesisEvent, _utcnow +from tradingagents.services.hypothesis_store import HypothesisRecord, HypothesisStore, PlanStepRecord + + +@dataclass +class ProcessedEvent: + event: HypothesisEvent + status: str + message: str + + +class AutopilotWorker: + """Processes hypothesis events and runs focused reevaluations.""" + + def __init__( + self, + results_root: Path, + auto_trader: AutoTradeService, + account_service: AccountService, + ) -> None: + self.results_root = Path(results_root) + self.store = HypothesisStore(self.results_root / "hypotheses") + self.queue = AutopilotEventQueue(self.results_root / "autopilot") + self.auto_trader = auto_trader + self.account_service = account_service + + def enqueue_event(self, hypothesis_id: str, event_type: str, payload: Optional[Dict[str, Any]] = None) -> HypothesisEvent: + event = HypothesisEvent( + id=f"evt_{hypothesis_id}_{event_type}_{_utcnow()}", + hypothesis_id=hypothesis_id, + event_type=event_type, + payload=payload or {}, + created_at=_utcnow(), + ) + self.queue.enqueue(event) + return event + + def list_events(self) -> List[HypothesisEvent]: + return self.queue.list() + + def process_all(self) -> List[ProcessedEvent]: + events = self.queue.dequeue_all() + processed: List[ProcessedEvent] = [] + for event in events: + status, message = self._handle_event(event) + processed.append(ProcessedEvent(event=event, status=status, message=message)) + return processed + + def _handle_event(self, event: HypothesisEvent) -> Tuple[str, str]: + record = self.store.get(event.hypothesis_id) + if not record: + return ("skipped", f"Hypothesis {event.hypothesis_id} not found") + + step = record.next_open_step() + if not step: + record.status = "completed" + record.updated_at = _utcnow() + self.store.upsert(record) + return ("completed", "Hypothesis already completed; marked as completed") + + step.status = "done" + step.metadata.setdefault("events", []).append( + { + "type": event.event_type, + "payload": event.payload, + "timestamp": event.created_at, + } + ) + record.updated_at = _utcnow() + if record.next_open_step() is None: + record.status = "completed" + self.store.upsert(record) + + reevaluation_msg = self._reevaluate(record, event) + return ("updated", f"Marked step '{step.description}' as done. {reevaluation_msg}") + + def _reevaluate(self, record: HypothesisRecord, event: HypothesisEvent) -> str: + try: + snapshot = self.account_service.refresh() + except Exception as exc: + return f"Failed to refresh account snapshot: {exc}" + + try: + result = self.auto_trader.run( + snapshot, + focus_override=[record.ticker], + allow_market_closed=True, + ) + except Exception as exc: + return f"Auto-trade reevaluation failed: {exc}" + + decision = self._extract_decision(result, record.ticker) + if not decision: + return "No decision returned for ticker" + + self._apply_decision(record, decision) + record.updated_at = _utcnow() + self.store.upsert(record) + return f"Reevaluated with action {record.action}" + + def _extract_decision(self, result: AutoTradeResult, ticker: str) -> Optional[TickerDecision]: + for decision in result.decisions: + if decision.ticker.upper() == ticker.upper(): + return decision + return None + + def _apply_decision(self, record: HypothesisRecord, decision: TickerDecision) -> None: + record.action = (decision.final_decision or decision.immediate_action or record.action).upper() + record.priority = decision.priority + record.notes = decision.final_notes or decision.sequential_plan.notes or record.notes + record.plan = self._build_plan_from_decision(record.ticker, decision) + record.triggers = decision.action_queue or record.triggers + record.status = "monitoring" + if getattr(decision, "strategy", None): + record.strategy = decision.strategy.to_dict() + + def _build_plan_from_decision(self, ticker: str, decision: TickerDecision) -> List[PlanStepRecord]: + steps: List[PlanStepRecord] = [] + actions = decision.sequential_plan.actions or [] + for idx, action in enumerate(actions, 1): + steps.append( + PlanStepRecord( + id=f"{ticker.lower()}_{idx}", + description=str(action), + status="pending", + metadata={ + "next_decision": decision.sequential_plan.next_decision, + "source": "autopilot", + }, + ) + ) + if not steps: + steps.append( + PlanStepRecord( + id=f"{ticker.lower()}_plan", + description=f"Monitor hypothesis for {ticker}", + status="pending", + ) + ) + return steps diff --git a/tradingagents/services/hypothesis_store.py b/tradingagents/services/hypothesis_store.py new file mode 100644 index 00000000..00a8a353 --- /dev/null +++ b/tradingagents/services/hypothesis_store.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +import json +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional + +from tradingagents.services.auto_trade import AutoTradeResult, TickerDecision + +ISO_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" + + +def _utcnow() -> str: + return datetime.utcnow().strftime(ISO_FORMAT) + + +@dataclass +class PlanStepRecord: + id: str + description: str + status: str = "pending" + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "id": self.id, + "description": self.description, + "status": self.status, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, payload: Dict[str, Any]) -> "PlanStepRecord": + return cls( + id=str(payload.get("id") or uuid.uuid4().hex[:8]), + description=str(payload.get("description") or ""), + status=str(payload.get("status") or "pending"), + metadata=dict(payload.get("metadata") or {}), + ) + + +@dataclass +class HypothesisRecord: + id: str + ticker: str + action: str + priority: float + status: str + rationale: str + notes: str + plan: List[PlanStepRecord] + created_at: str + updated_at: str + source_snapshot: str + strategy: Dict[str, Any] = field(default_factory=dict) + triggers: List[str] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "id": self.id, + "ticker": self.ticker, + "action": self.action, + "priority": self.priority, + "status": self.status, + "rationale": self.rationale, + "notes": self.notes, + "plan": [step.to_dict() for step in self.plan], + "created_at": self.created_at, + "updated_at": self.updated_at, + "source_snapshot": self.source_snapshot, + "strategy": dict(self.strategy), + "triggers": list(self.triggers), + } + + @classmethod + def from_dict(cls, payload: Dict[str, Any]) -> "HypothesisRecord": + plan_payload = payload.get("plan") or [] + plan_steps = [PlanStepRecord.from_dict(step) for step in plan_payload] + return cls( + id=str(payload.get("id") or uuid.uuid4().hex), + ticker=str(payload.get("ticker") or ""), + action=str(payload.get("action") or "HOLD"), + priority=float(payload.get("priority") or 0.0), + status=str(payload.get("status") or "monitoring"), + rationale=str(payload.get("rationale") or ""), + notes=str(payload.get("notes") or ""), + plan=plan_steps, + created_at=str(payload.get("created_at") or _utcnow()), + updated_at=str(payload.get("updated_at") or _utcnow()), + source_snapshot=str(payload.get("source_snapshot") or ""), + strategy=dict(payload.get("strategy") or {}), + triggers=[str(item) for item in payload.get("triggers") or []], + ) + + def next_open_step(self) -> Optional[PlanStepRecord]: + for step in self.plan: + if step.status.lower() not in {"done", "complete", "completed"}: + return step + return None + + +class HypothesisStore: + """Persist hypotheses derived from auto-trade runs for autopilot follow-up.""" + + def __init__(self, root_dir: Path) -> None: + self.root = Path(root_dir) + self.root.mkdir(parents=True, exist_ok=True) + self.path = self.root / "hypotheses.json" + + def list(self) -> List[HypothesisRecord]: + if not self.path.exists(): + return [] + with self.path.open("r", encoding="utf-8") as handle: + try: + payload = json.load(handle) + except json.JSONDecodeError: + payload = [] + records = [HypothesisRecord.from_dict(item) for item in payload or []] + records.sort(key=lambda rec: rec.created_at, reverse=True) + return records + + def record_result(self, result: AutoTradeResult) -> List[HypothesisRecord]: + records = self.list() + new_records: List[HypothesisRecord] = [] + for decision in result.decisions: + record = self._record_from_decision(decision, result) + records.append(record) + new_records.append(record) + self._save(records) + return new_records + + def get(self, hypothesis_id: str) -> Optional[HypothesisRecord]: + for record in self.list(): + if record.id == hypothesis_id: + return record + return None + + def upsert(self, updated_record: HypothesisRecord) -> None: + records = self.list() + replaced = False + for idx, record in enumerate(records): + if record.id == updated_record.id: + records[idx] = updated_record + replaced = True + break + if not replaced: + records.append(updated_record) + self._save(records) + + def update_plan_step( + self, + hypothesis_id: str, + step_id: str, + *, + status: Optional[str] = None, + metadata_patch: Optional[Dict[str, Any]] = None, + ) -> Optional[HypothesisRecord]: + record = self.get(hypothesis_id) + if not record: + return None + for step in record.plan: + if step.id == step_id: + if status: + step.status = status + if metadata_patch: + step.metadata.update(metadata_patch) + record.updated_at = _utcnow() + self.upsert(record) + return record + return None + + def _save(self, records: List[HypothesisRecord]) -> None: + records.sort(key=lambda rec: rec.created_at, reverse=True) + serializable = [record.to_dict() for record in records] + tmp_path = self.path.with_suffix(".tmp") + with tmp_path.open("w", encoding="utf-8") as handle: + json.dump(serializable, handle, indent=2) + tmp_path.replace(self.path) + + def _record_from_decision(self, decision: TickerDecision, result: AutoTradeResult) -> HypothesisRecord: + record_id = uuid.uuid4().hex + created = _utcnow() + plan_steps = self._plan_steps(decision) + notes = decision.final_notes or decision.sequential_plan.notes or "" + rationale = str(decision.hypothesis.get("rationale") or notes) + triggers = decision.action_queue or [] + return HypothesisRecord( + id=record_id, + ticker=decision.ticker, + action=(decision.final_decision or decision.immediate_action or "hold").upper(), + priority=decision.priority, + status="monitoring", + rationale=rationale, + notes=notes, + plan=plan_steps, + created_at=created, + updated_at=created, + source_snapshot=result.account_snapshot.fetched_at.isoformat(), + strategy=decision.strategy.to_dict() if decision.strategy else {}, + triggers=[trigger for trigger in triggers if trigger], + ) + + def _plan_steps(self, decision: TickerDecision) -> List[PlanStepRecord]: + steps: List[PlanStepRecord] = [] + actions = decision.sequential_plan.actions or [] + for idx, description in enumerate(actions, 1): + steps.append( + PlanStepRecord( + id=f"{decision.ticker.lower()}_{idx}", + description=str(description), + status="pending", + metadata={ + "next_decision": decision.sequential_plan.next_decision, + }, + ) + ) + if not steps: + steps.append( + PlanStepRecord( + id=f"{decision.ticker.lower()}_plan", + description=f"Monitor hypothesis for {decision.ticker} (auto-generated)", + ) + ) + return steps diff --git a/tradingagents/services/memory.py b/tradingagents/services/memory.py new file mode 100644 index 00000000..a89aea0b --- /dev/null +++ b/tradingagents/services/memory.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import json +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional + + +class TickerMemoryStore: + """Filesystem-backed short-term memory for per-ticker decisions.""" + + def __init__(self, base_dir: str, *, max_entries: int = 5, enabled: bool = True) -> None: + self.enabled = enabled + self.base_path = Path(base_dir) + self.base_path.mkdir(parents=True, exist_ok=True) + self.max_entries = max(1, int(max_entries)) + + def is_enabled(self) -> bool: + return self.enabled + + def _path(self, ticker: str) -> Path: + return self.base_path / f"{ticker.upper()}.json" + + def load(self, ticker: str, limit: Optional[int] = None) -> List[Dict[str, Any]]: + if not self.enabled: + return [] + path = self._path(ticker) + if not path.exists(): + return [] + try: + with path.open("r", encoding="utf-8") as handle: + data = json.load(handle) + except Exception: + return [] + if not isinstance(data, list): + return [] + limit_value = limit if limit is not None else self.max_entries + if limit_value <= 0: + return data[-self.max_entries :] + return data[-limit_value:] + + def append(self, ticker: str, entry: Dict[str, Any]) -> None: + if not self.enabled: + return + path = self._path(ticker) + try: + history = [] + if path.exists(): + with path.open("r", encoding="utf-8") as handle: + history = json.load(handle) or [] + if not isinstance(history, list): + history = [] + except Exception: + history = [] + history.append(entry) + history = history[-self.max_entries :] + with path.open("w", encoding="utf-8") as handle: + json.dump(history, handle, indent=2, default=str) + + def record_decisions(self, decisions: List[Dict[str, Any]]) -> None: + if not self.enabled: + return + timestamp = datetime.utcnow().isoformat() + for decision in decisions: + ticker = str(decision.get("ticker") or "").upper() + if not ticker: + continue + payload = { + "timestamp": timestamp, + "action": decision.get("action") or decision.get("final_decision") or "", + "priority": decision.get("priority"), + "notes": decision.get("notes") or decision.get("final_notes") or "", + "plan_actions": decision.get("plan_actions") or decision.get("sequential_plan", {}).get("actions"), + "raw": decision, + } + self.append(ticker, payload) diff --git a/tradingagents/services/realtime_broker.py b/tradingagents/services/realtime_broker.py new file mode 100644 index 00000000..fa306402 --- /dev/null +++ b/tradingagents/services/realtime_broker.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import asyncio +import logging +import threading +from dataclasses import dataclass +from typing import Dict, List, Optional + +try: # optional dependency + from alpaca.data.live import StockDataStream + from alpaca.data.enums import DataFeed +except ImportError: # pragma: no cover - only when dependency missing + StockDataStream = None # type: ignore + DataFeed = None # type: ignore + +from tradingagents.services.autopilot_worker import AutopilotWorker +from tradingagents.services.hypothesis_store import HypothesisStore, HypothesisRecord + + +@dataclass +class PriceTrigger: + hypothesis_id: str + symbol: str + operator: str + value: float + + +class RealtimeBroker: + """Subscribe to Alpaca stock data stream and enqueue autopilot events.""" + + def __init__( + self, + store: HypothesisStore, + worker: AutopilotWorker, + api_key: str, + secret_key: str, + *, + feed: str = "iex", + logger: Optional[logging.Logger] = None, + ) -> None: + self.store = store + self.worker = worker + self.logger = logger or logging.getLogger(__name__) + if StockDataStream is None or DataFeed is None: + raise RuntimeError( + "alpaca-py is required for realtime broker. Install with `pip install alpaca-py`." + ) + data_feed = DataFeed.IEX if feed.lower() == "iex" else DataFeed.SIP + self.stream = StockDataStream(api_key, secret_key, feed=data_feed) + self.triggers: Dict[str, List[PriceTrigger]] = {} + self._subscribed: set[str] = set() + self._registered_keys: set[str] = set() + self._lock = threading.Lock() + + def bootstrap_triggers(self) -> None: + self.refresh_triggers(reset=True) + + def refresh_triggers( + self, + records: Optional[List[HypothesisRecord]] = None, + *, + reset: bool = False, + ) -> int: + """Register triggers for stored hypotheses. + + When ``reset`` is True the in-memory trigger cache is rebuilt so updates to + hypothesis triggers take effect immediately. + Returns the number of triggers registered during this call (useful for logging). + """ + + dataset = records or self.store.list() + with self._lock: + if reset: + self.triggers.clear() + self._registered_keys.clear() + + registered = 0 + for record in dataset: + for trigger in self._parse_triggers(record): + if self._register_trigger_locked(trigger): + registered += 1 + return registered + + def _parse_triggers(self, record: HypothesisRecord) -> List[PriceTrigger]: + triggers: List[PriceTrigger] = [] + for raw in record.triggers: + text = str(raw).strip().lower() + if text.startswith("price >="): + try: + symbol, value = self._extract_symbol_value(record.ticker, text, ">=") + triggers.append( + PriceTrigger( + hypothesis_id=record.id, + symbol=symbol, + operator=">=", + value=value, + ) + ) + except ValueError: + continue + elif text.startswith("price <="): + try: + symbol, value = self._extract_symbol_value(record.ticker, text, "<=") + triggers.append( + PriceTrigger( + hypothesis_id=record.id, + symbol=symbol, + operator="<=", + value=value, + ) + ) + except ValueError: + continue + return triggers + + def _extract_symbol_value(self, default_symbol: str, text: str, operator: str) -> (str, float): + left, right = text.replace("price", "").split(operator, 1) + symbol = left.strip().upper() or default_symbol.upper() + value = float(right.strip()) + return symbol, value + + def _register_trigger_locked(self, trigger: PriceTrigger) -> bool: + key = self._trigger_key(trigger) + if key in self._registered_keys: + return False + + self._registered_keys.add(key) + symbol = trigger.symbol + bucket = self.triggers.setdefault(symbol, []) + bucket.append(trigger) + if symbol not in self._subscribed: + self.stream.subscribe_trades(self._trade_handler, symbol) + self._subscribed.add(symbol) + return True + + def _trigger_key(self, trigger: PriceTrigger) -> str: + return f"{trigger.hypothesis_id}:{trigger.symbol}:{trigger.operator}:{trigger.value}" + + async def _trade_handler(self, data) -> None: # pragma: no cover - network callback + symbol = getattr(data, "symbol", "") + price = getattr(data, "price", None) + if not symbol or price is None: + return + with self._lock: + triggers = list(self.triggers.get(symbol, [])) + for trigger in triggers: + if self._evaluate(trigger, price): + self.worker.enqueue_event( + trigger.hypothesis_id, + event_type="price_threshold", + payload={"symbol": symbol, "price": price, "operator": trigger.operator, "value": trigger.value}, + ) + + def _evaluate(self, trigger: PriceTrigger, price: float) -> bool: + if trigger.operator == ">=": + return price >= trigger.value + if trigger.operator == "<=": + return price <= trigger.value + return False + + def run_forever(self) -> None: + self.bootstrap_triggers() + self.logger.info("Starting Alpaca stock data stream …") + try: + self.stream.run() + except KeyboardInterrupt: # pragma: no cover - manual stop + self.logger.info("Realtime broker stopped") diff --git a/tradingagents/services/realtime_news_broker.py b/tradingagents/services/realtime_news_broker.py new file mode 100644 index 00000000..0967d0e7 --- /dev/null +++ b/tradingagents/services/realtime_news_broker.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import json +import logging +import threading +from collections import defaultdict +from typing import Dict, List, Optional, Set + +import websocket + +from tradingagents.services.autopilot_worker import AutopilotWorker +from tradingagents.services.hypothesis_store import HypothesisStore, HypothesisRecord + + +class RealtimeNewsBroker: + """Subscribe to Alpaca news WebSocket and forward events to the autopilot worker.""" + + DEFAULT_URL = "wss://stream.data.alpaca.markets/v1beta1/news" + + def __init__( + self, + store: HypothesisStore, + worker: AutopilotWorker, + api_key: str, + secret_key: str, + *, + url: Optional[str] = None, + logger: Optional[logging.Logger] = None, + ) -> None: + self.store = store + self.worker = worker + self.api_key = api_key + self.secret_key = secret_key + self.url = url or self.DEFAULT_URL + self.logger = logger or logging.getLogger(__name__) + self.watchers: Dict[str, Set[str]] = defaultdict(set) + self.ws: Optional[websocket.WebSocketApp] = None + self._thread: Optional[threading.Thread] = None + self._lock = threading.Lock() + + def bootstrap_watchers(self) -> None: + self.refresh_watchers() + + def refresh_watchers(self, records: Optional[List[HypothesisRecord]] = None) -> int: + """Rebuild the news watcher map from stored hypotheses.""" + + dataset = records or self.store.list() + with self._lock: + self.watchers.clear() + registered = 0 + for record in dataset: + registered += self._register_symbol_unlocked(record.ticker.upper(), record.id) + for symbol in self._extract_symbols_from_triggers(record): + registered += self._register_symbol_unlocked(symbol, record.id) + return registered + + def _extract_symbols_from_triggers(self, record: HypothesisRecord) -> List[str]: + symbols: List[str] = [] + for raw in record.triggers: + text = str(raw).strip().lower() + if text.startswith("news"): + parts = text.split(":", 1) + if len(parts) == 2 and parts[1].strip(): + symbols.append(parts[1].strip().upper()) + else: + symbols.append(record.ticker.upper()) + return symbols + + def _register_symbol(self, symbol: str, hypothesis_id: str) -> int: + if not symbol: + return 0 + symbol_key = symbol.upper() + with self._lock: + return self._register_symbol_unlocked(symbol_key, hypothesis_id) + + def _register_symbol_unlocked(self, symbol: str, hypothesis_id: str) -> int: + if not symbol: + return 0 + bucket = self.watchers[symbol] + before = len(bucket) + bucket.add(hypothesis_id) + return 1 if len(bucket) > before else 0 + + def start(self) -> None: + if self.ws is not None: + self.logger.info("News broker already running") + return + self.bootstrap_watchers() + + headers = [ + f"APCA-API-KEY-ID: {self.api_key}", + f"APCA-API-SECRET-KEY: {self.secret_key}", + ] + + self.ws = websocket.WebSocketApp( + self.url, + header=headers, + on_open=self._on_open, + on_message=self._on_message, + on_error=self._on_error, + on_close=self._on_close, + ) + + def _run(): # pragma: no cover - network behavior + self.logger.info("Connecting to Alpaca news stream …") + self.ws.run_forever() + + self._thread = threading.Thread(target=_run, daemon=True) + self._thread.start() + + def _on_open(self, ws): # pragma: no cover - network callback + self.logger.info("News stream connected; subscribing to all news…") + ws.send(json.dumps({"action": "subscribe", "news": ["*"]})) + + def _on_message(self, ws, message): # pragma: no cover - network callback + try: + data = json.loads(message) + except json.JSONDecodeError: + return + if isinstance(data, list): + for item in data: + self._handle_news(item) + else: + self._handle_news(data) + + def _handle_news(self, payload: Dict[str, object]) -> None: + if payload.get("T") != "n": + return + symbols = payload.get("symbols") or [] + if not isinstance(symbols, list): + return + for symbol in symbols: + symbol_key = str(symbol or "").upper() + with self._lock: + targets = list(self.watchers.get(symbol_key, ())) + for hypothesis_id in targets: + self.worker.enqueue_event( + hypothesis_id, + event_type="news", + payload=payload, + ) + + def _on_error(self, ws, error): # pragma: no cover - network callback + self.logger.error("News stream error: %s", error) + + def _on_close(self, ws, code, msg): # pragma: no cover - network callback + self.logger.info("News stream closed: %s %s", code, msg) + self.ws = None diff --git a/tradingagents/services/responses_auto_trade.py b/tradingagents/services/responses_auto_trade.py new file mode 100644 index 00000000..69ae6498 --- /dev/null +++ b/tradingagents/services/responses_auto_trade.py @@ -0,0 +1,977 @@ +from __future__ import annotations + +import json +import os +from dataclasses import dataclass +from datetime import date, timedelta +from typing import Any, Callable, Dict, List, Optional, Set, Tuple +import logging + +from langchain_core.messages import HumanMessage +from openai import OpenAI +from tradingagents.dataflows.interface import route_to_vendor +from tradingagents.graph.trading_graph import TradingAgentsGraph +from tradingagents.services.account import AccountSnapshot +from tradingagents.services.auto_trade import ( + AutoTradeResult, + SequentialPlan, + TickerDecision, + StrategyDirective, + resolve_strategy_directive, +) +from tradingagents.services.memory import TickerMemoryStore +from tradingagents.agents.analysts.market_analyst import create_market_analyst +from tradingagents.agents.analysts.news_analyst import create_news_analyst +from tradingagents.agents.analysts.fundamentals_analyst import create_fundamentals_analyst + + +def _extract_json_block(text: str) -> Dict[str, Any]: + if not text: + return {} + snippet = text.strip() + if snippet.startswith("```"): + parts = snippet.split("```") + for part in parts: + candidate = part.strip() + if candidate.startswith("{") and candidate.endswith("}"): + try: + return json.loads(candidate) + except json.JSONDecodeError: + continue + return {} + if snippet.startswith("{") and snippet.endswith("}"): + try: + return json.loads(snippet) + except json.JSONDecodeError: + return {} + # Fallback: scan for first JSON object within the text + decoder = json.JSONDecoder() + for idx, char in enumerate(snippet): + if char == "{": + try: + data, _ = decoder.raw_decode(snippet[idx:]) + return data + except json.JSONDecodeError: + continue + return {} + + +def _trimmed_json(payload: Any, *, limit: int = 400) -> str: + try: + text = json.dumps(payload, default=str) + except Exception: + text = str(payload) + if len(text) <= limit: + return text + return f"{text[: limit - 3]}..." + + +@dataclass +class ResponsesTool: + name: str + description: str + schema: Dict[str, Any] + handler: Callable[[Dict[str, Any]], Dict[str, Any]] + + def spec(self) -> Dict[str, Any]: + return { + "type": "function", + "name": self.name, + "description": self.description, + "parameters": self.schema, + } + + +class TradingToolbox: + """Wrap the existing TradingAgents capabilities as Responses-ready tools.""" + + def __init__( + self, + config: Dict[str, Any], + graph: TradingAgentsGraph, + snapshot: AccountSnapshot, + logger: Optional[logging.Logger] = None, + memory_store: Optional[TickerMemoryStore] = None, + ) -> None: + self.config = config + self.graph = graph + self.snapshot = snapshot + self.logger = logger or logging.getLogger(__name__) + self.memory_store = memory_store + self._agent_runners = self._init_agent_runners() + self._tools = self._build_tools() + + @property + def specs(self) -> List[Dict[str, Any]]: + return [tool.spec() for tool in self._tools.values()] + + def invoke(self, name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + if name not in self._tools: + raise ValueError(f"Unknown tool requested: {name}") + if self.logger: + self.logger.debug("Responses tool call: %s args=%s", name, arguments) + return self._tools[name].handler(arguments or {}) + + def _build_tools(self) -> Dict[str, ResponsesTool]: + tools: Dict[str, ResponsesTool] = { + "get_account_overview": ResponsesTool( + name="get_account_overview", + description="Return the cached Alpaca account, position, and order snapshots for context.", + schema={"type": "object", "properties": {}, "additionalProperties": False}, + handler=lambda _: { + "fetched_at": self.snapshot.fetched_at.isoformat(), + "account_text": self.snapshot.account_text, + "positions_text": self.snapshot.positions_text, + "orders_text": self.snapshot.orders_text, + }, + ), + "list_focus_tickers": ResponsesTool( + name="list_focus_tickers", + description="Return the configured trading universe merged with current holdings.", + schema={"type": "object", "properties": {}, "additionalProperties": False}, + handler=lambda _: { + "universe": self._determine_focus_tickers(), + }, + ), + "fetch_market_data": ResponsesTool( + name="fetch_market_data", + description="Fetch OHLCV market data for a symbol over the requested lookback window (days).", + schema={ + "type": "object", + "properties": { + "symbol": {"type": "string"}, + "lookback_days": {"type": "integer", "minimum": 1, "default": 30}, + }, + "required": ["symbol"], + "additionalProperties": False, + }, + handler=self._tool_fetch_market_data, + ), + "fetch_company_news": ResponsesTool( + name="fetch_company_news", + description="Fetch recent company-specific news articles for a symbol.", + schema={ + "type": "object", + "properties": { + "symbol": {"type": "string"}, + "lookback_days": {"type": "integer", "minimum": 1, "default": 5}, + }, + "required": ["symbol"], + "additionalProperties": False, + }, + handler=self._tool_fetch_company_news, + ), + "fetch_global_news": ResponsesTool( + name="fetch_global_news", + description="Fetch macro/global news context for the requested lookback horizon.", + schema={ + "type": "object", + "properties": { + "lookback_days": {"type": "integer", "minimum": 1, "default": 3}, + "limit": {"type": "integer", "minimum": 1, "default": 5}, + }, + "required": [], + "additionalProperties": False, + }, + handler=self._tool_fetch_global_news, + ), + "fetch_indicators": ResponsesTool( + name="fetch_indicators", + description="Fetch technical indicators for a symbol. Indicators should be provided as a list of canonical names (e.g., rsi, close_50_sma).", + schema={ + "type": "object", + "properties": { + "symbol": {"type": "string"}, + "indicators": { + "type": "array", + "items": {"type": "string"}, + "default": ["rsi", "close_50_sma", "close_200_sma"], + }, + "lookback_days": {"type": "integer", "minimum": 1, "default": 30}, + }, + "required": ["symbol"], + "additionalProperties": False, + }, + handler=self._tool_fetch_indicators, + ), + "submit_trade_order": ResponsesTool( + name="submit_trade_order", + description="Submit a trade directive (BUY/SELL/HOLD) for a ticker with optional reasoning. Honors dry-run and market-open checks.", + schema={ + "type": "object", + "properties": { + "symbol": {"type": "string"}, + "action": {"type": "string", "enum": ["BUY", "SELL", "HOLD"]}, + "notes": {"type": "string"}, + }, + "required": ["symbol", "action"], + "additionalProperties": False, + }, + handler=self._tool_submit_trade, + ), + } + if self.memory_store and self.memory_store.is_enabled(): + tools["get_ticker_memory"] = ResponsesTool( + name="get_ticker_memory", + description="Retrieve recent decision memory for a ticker.", + schema={ + "type": "object", + "properties": { + "symbol": {"type": "string"}, + "limit": {"type": "integer", "minimum": 1, "default": self.memory_store.max_entries}, + }, + "required": ["symbol"], + "additionalProperties": False, + }, + handler=self._tool_get_memory, + ) + for agent_key in (self._agent_runners or {}): + tool_name = f"run_{agent_key}_analyst" + tools[tool_name] = ResponsesTool( + name=tool_name, + description=f"Run the {agent_key} analyst to produce a detailed report for a ticker.", + schema={ + "type": "object", + "properties": {"symbol": {"type": "string"}}, + "required": ["symbol"], + "additionalProperties": False, + }, + handler=lambda args, agent=agent_key: self._tool_run_agent(agent, args or {}), + ) + return tools + + def _determine_focus_tickers(self) -> List[str]: + universe_raw = self.config.get("portfolio_orchestrator", {}).get("universe", "") + universe = [sym.strip().upper() for sym in universe_raw.split(",") if sym.strip()] + holdings = self.snapshot.position_symbols() + combined: List[str] = [] + for symbol in list(dict.fromkeys(universe + holdings)): + if symbol: + combined.append(symbol) + return combined or ["SPY"] + + def _tool_fetch_market_data(self, args: Dict[str, Any]) -> Dict[str, Any]: + symbol = str(args.get("symbol") or "").upper() + lookback_days = int(args.get("lookback_days") or 30) + end_date = date.today() + start_date = end_date - timedelta(days=max(lookback_days, 1)) + payload = route_to_vendor("get_stock_data", symbol, start_date.isoformat(), end_date.isoformat()) + return {"symbol": symbol, "start": start_date.isoformat(), "end": end_date.isoformat(), "data": payload} + + def _tool_fetch_company_news(self, args: Dict[str, Any]) -> Dict[str, Any]: + symbol = str(args.get("symbol") or "").upper() + lookback_days = int(args.get("lookback_days") or 5) + end_date = date.today() + start_date = end_date - timedelta(days=max(lookback_days, 1)) + payload = route_to_vendor("get_news", symbol, start_date.isoformat(), end_date.isoformat()) + return { + "symbol": symbol, + "start": start_date.isoformat(), + "end": end_date.isoformat(), + "data": payload, + } + def _tool_fetch_global_news(self, args: Dict[str, Any]) -> Dict[str, Any]: + lookback_days = int(args.get("lookback_days") or 3) + limit = int(args.get("limit") or 5) + payload = route_to_vendor("get_global_news", date.today().isoformat(), lookback_days, limit) + return {"lookback_days": lookback_days, "limit": limit, "data": payload} + + def _tool_fetch_indicators(self, args: Dict[str, Any]) -> Dict[str, Any]: + symbol = str(args.get("symbol") or "").upper() + lookback_days = int(args.get("lookback_days") or 30) + indicators = args.get("indicators") or [] + if not indicators: + indicators = ["rsi", "close_50_sma", "close_200_sma"] + end_date = date.today().isoformat() + payloads: Dict[str, Any] = {} + for indicator_name in indicators: + try: + payloads[indicator_name] = route_to_vendor( + "get_indicators", + symbol, + indicator_name, + end_date, + lookback_days, + ) + except Exception as exc: + payloads[indicator_name] = {"error": str(exc)} + return { + "symbol": symbol, + "indicators": indicators, + "as_of": end_date, + "lookback_days": lookback_days, + "data": payloads, + } + + def _tool_submit_trade(self, args: Dict[str, Any]) -> Dict[str, Any]: + symbol = str(args.get("symbol") or "").upper() + action = str(args.get("action") or "").upper() + status = self.graph.check_market_status() + market_open = bool(status.get("is_open", True)) + if not market_open: + return { + "status": "market_closed", + "clock": status.get("clock_text"), + } + result = self.graph.execute_trade_directive(symbol, action) + return {"status": result.get("status"), "response": result} + + def _call_vendor(self, method: str, *args) -> Any: + try: + return route_to_vendor(method, *args) + except Exception as exc: + if self.logger: + self.logger.warning("Vendor call %s failed: %s", method, exc) + return {"error": str(exc)} + + def _tool_get_memory(self, args: Dict[str, Any]) -> Dict[str, Any]: + if not self.memory_store: + return {"entries": []} + symbol = str(args.get("symbol") or "").upper() + limit = int(args.get("limit") or self.memory_store.max_entries) + entries = self.memory_store.load(symbol, limit) + return {"symbol": symbol, "entries": entries} + + def _init_agent_runners(self) -> Dict[str, Any]: + try: + market = create_market_analyst(self.graph.quick_thinking_llm) + news = create_news_analyst(self.graph.quick_thinking_llm) + fundamentals = create_fundamentals_analyst(self.graph.quick_thinking_llm) + except Exception as exc: + if self.logger: + self.logger.warning("Failed to initialize analyst agents: %s", exc) + return {} + return { + "market": market, + "news": news, + "fundamentals": fundamentals, + } + + def _tool_run_agent(self, agent_key: str, args: Dict[str, Any]) -> Dict[str, Any]: + symbol = str(args.get("symbol") or "").upper() + if not symbol: + return {"error": "Missing symbol"} + report = self._run_agent(agent_key, symbol) + return {"symbol": symbol, "agent": agent_key, "report": report} + + def _run_agent(self, agent_key: str, symbol: str) -> str: + runner = (self._agent_runners or {}).get(agent_key) + if not runner: + return f"{agent_key} analyst unavailable." + state = self._build_agent_state(agent_key, symbol) + try: + result = runner(state) + except Exception as exc: + if self.logger: + self.logger.warning("Analyst %s failed for %s: %s", agent_key, symbol, exc) + return f"{agent_key} analyst failed: {exc}" + report_key = { + "market": "market_report", + "news": "news_report", + "fundamentals": "fundamentals_report", + }.get(agent_key, "report") + report = result.get(report_key) + if not report: + messages = result.get("messages") or [] + if messages: + try: + report = messages[-1].content + except Exception: + report = str(messages[-1]) + return report or f"{agent_key} analyst produced no narrative." + + def _build_agent_state(self, agent_key: str, symbol: str) -> Dict[str, Any]: + today = date.today().isoformat() + return { + "messages": [HumanMessage(content=f"Provide {agent_key} analysis for {symbol} on {today}.")], + "company_of_interest": symbol, + "target_ticker": symbol, + "trade_date": today, + "scheduled_analysts": [agent_key], + "scheduled_analysts_plan": [agent_key], + "orchestrator_action": "execute", + } + + +class ResponsesAutoTradeService: + """Auto-trade orchestration powered by the OpenAI Responses API.""" + + def __init__( + self, + config: Dict[str, Any], + graph: Optional[TradingAgentsGraph] = None, + logger: Optional[logging.Logger] = None, + ) -> None: + self.config = config + self.graph = graph or TradingAgentsGraph(config=config, skip_initial_probes=True) + self.logger = logger or logging.getLogger(__name__) + backend_url = config.get("backend_url") + client_kwargs = {} + if backend_url: + client_kwargs["base_url"] = backend_url + self.client = OpenAI(**client_kwargs) + memory_cfg = (self.config.get("auto_trade") or {}).get("memory", {}) or {} + memory_enabled = bool(memory_cfg.get("enabled", True)) + memory_dir = memory_cfg.get( + "dir", + os.path.join(self.config.get("results_dir", "./results"), "memory"), + ) + max_entries = int(memory_cfg.get("max_entries", 5)) + self.memory_store = TickerMemoryStore(memory_dir, max_entries=max_entries, enabled=memory_enabled) + self._strategy_brief_cache = self._strategy_presets_brief() + + def run(self, snapshot: AccountSnapshot, *, focus_override: Optional[List[str]] = None) -> AutoTradeResult: + self._reference_prices = _snapshot_reference_prices(snapshot) + toolbox = TradingToolbox( + self.config, + self.graph, + snapshot, + logger=self.logger, + memory_store=self.memory_store, + ) + system_prompt = self._build_system_prompt() + focus_tickers = focus_override or toolbox._determine_focus_tickers() + conversation: List[Dict[str, Any]] = [ + {"role": "system", "content": system_prompt}, + { + "role": "user", + "content": json.dumps( + { + "account": snapshot.account, + "positions": snapshot.positions, + "orders": snapshot.orders, + "focus_tickers": focus_tickers, + "fetched_at": snapshot.fetched_at.isoformat(), + } + ), + }, + ] + if self.memory_store and self.memory_store.is_enabled(): + memory_payload = {} + for ticker in focus_tickers: + entries = self.memory_store.load(ticker, limit=3) + if entries: + memory_payload[ticker] = entries + if memory_payload: + conversation.append( + { + "role": "user", + "content": json.dumps( + { + "memory_hint": "Historical decisions per ticker. Use get_ticker_memory if deeper detail needed.", + "entries": memory_payload, + } + ), + } + ) + else: + memory_payload = {} + + lacking_memory = [ticker for ticker in focus_tickers if ticker not in memory_payload] + if lacking_memory: + conversation.append( + { + "role": "user", + "content": json.dumps( + { + "context_gap": "Some focus tickers currently have no stored memory.", + "tickers": lacking_memory, + "required_actions": ( + "Before finalizing decisions for these tickers, gather baseline context by " + "calling `fetch_market_data` with at least a 7-day lookback and " + "`fetch_company_news` (and optionally `fetch_global_news` if macro forces matter). " + "Summarize what you learned from those tools so the operator can review your reasoning." + ), + } + ), + } + ) + + conversation.append( + { + "role": "user", + "content": json.dumps( + { + "planning_protocol": ( + "Before calling additional tools, outline a numbered plan where each step names the ticker and the tool/data you intend to use. " + "Track each step's status (`pending`, `in_progress`, `done`). After every tool call, explicitly state which step changed status and why. " + "If the plan changes mid-run, update the list immediately so the operator sees the live state of each action." + ) + } + ), + } + ) + + if self._strategy_brief_cache.get("presets"): + conversation.append( + { + "role": "user", + "content": json.dumps( + { + "strategy_presets": self._strategy_brief_cache, + "instructions": ( + "Select whichever preset best matches each ticker's urgency; override target/stop only when necessary." + ), + } + ), + } + ) + + transcript: List[str] = [] + submitted_trades: Set[Tuple[str, str]] = set() + response = self._responses_call( + conversation, + toolbox, + transcript, + allow_tools=True, + submitted_trades=submitted_trades, + ) + final_text = self._response_text(response) + summary = _extract_json_block(final_text) + + if not summary.get("decisions"): + conversation.append( + { + "role": "user", + "content": ( + "Provide the final decision summary strictly as JSON with the schema:\n" + '{"decisions":[{"ticker": "...", "action": "...", "priority": 0.0, ' + '"plan_actions": [], "next_decision": "...", "notes": "...", ' + '"plan_status": {"step description": "pending"}, ' + '"strategy": {"name": "swing", "horizon_hours": 72, "target_pct": 0.04, ' + '"stop_pct": 0.02, "success_metric": "...", "failure_metric": "...", ' + '"follow_up": "reassess_every_close", "deadline": "2025-11-14T21:00:00Z", "urgency": "medium"}}]}' + " Do not include prose outside the JSON." + ), + } + ) + response = self._responses_call( + conversation, + toolbox, + transcript, + max_turns=2, + allow_tools=False, + submitted_trades=submitted_trades, + ) + final_text = self._response_text(response) + summary = _extract_json_block(final_text) + + decisions, focus = self._decisions_from_summary(summary) + raw_state = { + "responses_transcript": transcript, + "responses_summary": summary, + "responses_output_text": final_text, + } + if self.memory_store and self.memory_store.is_enabled() and decisions: + payload: List[Dict[str, Any]] = [] + for decision in decisions: + decision_dict = decision.to_dict() + decision_dict["action"] = decision.final_decision or decision.immediate_action + decision_dict["notes"] = decision.final_notes or decision_dict.get("final_notes") or "" + decision_dict["priority"] = decision.priority + payload.append(decision_dict) + self.memory_store.record_decisions(payload) + self._auto_execute_trades(decisions, submitted_trades) + + return AutoTradeResult( + focus_tickers=focus or focus_tickers, + decisions=decisions, + account_snapshot=snapshot, + raw_state=raw_state, + ) + + def _responses_call( + self, + conversation: List[Dict[str, Any]], + toolbox: TradingToolbox, + transcript: List[str], + *, + max_turns: Optional[int] = None, + allow_tools: bool = True, + submitted_trades: Optional[Set[Tuple[str, str]]] = None, + ): + model = ( + self.config.get("auto_trade", {}).get("responses_model") + or self.config.get("quick_think_llm") + ) + if not model: + raise RuntimeError("Missing responses model configuration.") + + reasoning_config = self.config.get("auto_trade", {}).get("responses_reasoning_effort", "") + reasoning_text = (reasoning_config or "").strip() + if not reasoning_text: + reasoning_text = "medium" + reasoning_enabled = reasoning_text and reasoning_text.lower() not in {"none", "off"} + remaining_turns = max_turns or int(self.config.get("auto_trade", {}).get("responses_max_turns") or 8) + + if remaining_turns <= 0: + raise RuntimeError("Responses conversation exceeded maximum turns without completion.") + + repeat_guard: Dict[str, int] = {} + narration_reminder_issued = False + + while remaining_turns > 0: + request_kwargs: Dict[str, Any] = { + "model": model, + "input": conversation, + "store": False, + } + if allow_tools: + request_kwargs["tools"] = toolbox.specs + if reasoning_enabled: + request_kwargs["reasoning"] = {"effort": reasoning_text} + + tool_call: Optional[Dict[str, Any]] = None + final_response: Any = None + + response = self.client.responses.create(**request_kwargs) + final_response = response + thinking_traces = self._extract_reasoning_traces(response) + for trace in thinking_traces: + if trace: + transcript.append(f"[Thinking] {trace}") + self._emit_narration(f"[Thinking] {trace}") + + assistant_message = self._response_text(response) + if assistant_message: + transcript.append(assistant_message) + self._emit_narration(assistant_message) + conversation.append({"role": "assistant", "content": assistant_message}) + narration_reminder_issued = False + + tool_calls = self._extract_tool_calls(response) + if tool_calls: + for tool_call in tool_calls: + args = self._safe_json(tool_call.get("arguments")) + name = tool_call.get("name") or "" + tool_error: Optional[str] = None + try: + result = toolbox.invoke(name, args) + except Exception as exc: # pragma: no cover - defensive + tool_error = f"{exc}" + result = {"error": tool_error} + self._emit_tool_event(name, args, result) + if ( + tool_error is None + and submitted_trades is not None + and name == "submit_trade_order" + ): + symbol = str(args.get("symbol") or "").upper() + action = str(args.get("action") or "").upper() + if symbol and action: + submitted_trades.add((symbol, action)) + conversation.append( + { + "role": "user", + "content": json.dumps( + { + "tool": name, + "tool_call_id": tool_call.get("id") + or tool_call.get("call_id") + or tool_call.get("item_id"), + "result": result, + }, + default=str, + ), + } + ) + guard_key = f"{name}:{json.dumps(args, sort_keys=True)}" + repeat_guard[guard_key] = repeat_guard.get(guard_key, 0) + 1 + if repeat_guard[guard_key] >= 2: + conversation.append( + { + "role": "user", + "content": ( + f"You have already called `{name}` with the same arguments " + f"{repeat_guard[guard_key]} times. Summarize the existing data and " + "move on to the next required tool or generate the decision summary instead of " + "repeating this call." + ), + } + ) + remaining_turns -= 1 + if not assistant_message and not narration_reminder_issued: + conversation.append( + { + "role": "user", + "content": ( + "Narrate what you are doing before issuing more tool calls so the CLI can show your " + "reasoning in real time." + ), + } + ) + narration_reminder_issued = True + continue + + if final_response is None: + raise RuntimeError("Streaming response did not complete.") + return final_response + + raise RuntimeError("Responses conversation exceeded maximum turns without completion.") + + def _decisions_from_summary(self, summary: Dict[str, Any]) -> Tuple[List[TickerDecision], List[str]]: + decisions_payload = summary.get("decisions") or [] + decisions: List[TickerDecision] = [] + focus: List[str] = [] + + for entry in decisions_payload: + ticker = str(entry.get("ticker") or "").upper() + if not ticker: + continue + focus.append(ticker) + priority = float(entry.get("priority") or entry.get("confidence") or 0) + action = str(entry.get("action") or entry.get("decision") or "monitor").upper() + plan_actions = entry.get("plan_actions") or entry.get("actions") or [] + if isinstance(plan_actions, str): + plan_actions = [plan_actions] + immediate = entry.get("immediate_action") or action.lower() + sequential_plan = SequentialPlan( + actions=[str(item).lower() for item in plan_actions], + next_decision=str(entry.get("next_decision") or immediate).lower(), + notes=str(entry.get("notes") or entry.get("rationale") or ""), + reasoning=entry.get("reasoning") or [], + ) + hypothesis = { + "ticker": ticker, + "rationale": entry.get("rationale") or entry.get("notes") or "", + "priority": priority, + "required_analysts": entry.get("required_analysts") or [], + "immediate_actions": immediate, + } + trade_notes = entry.get("execution_plan") or entry.get("notes") or "" + strategy = self._build_strategy(entry) + triggers = self._build_triggers(strategy, entry) + decision = TickerDecision( + ticker=ticker, + hypothesis=hypothesis, + sequential_plan=sequential_plan, + action_queue=triggers, + immediate_action=str(immediate), + priority=priority, + final_decision=action, + trader_plan=entry.get("trader_plan") or "", + final_notes=trade_notes, + strategy=strategy, + ) + decisions.append(decision) + + return decisions, focus + + def _build_strategy(self, entry: Dict[str, Any]) -> StrategyDirective: + overrides = entry.get("strategy") + if overrides and isinstance(overrides, dict): + overrides = {**overrides} # shallow copy so we can enrich with derived prices + strategy = resolve_strategy_directive(self.config, overrides) + entry["strategy"] = strategy.to_dict() + return strategy + + def _build_triggers(self, strategy: StrategyDirective, entry: Dict[str, Any]) -> List[str]: + base_triggers = [str(item).lower() for item in entry.get("action_queue", []) if str(item).strip()] + price = _extract_reference_price(entry) + if not price and hasattr(self, "_reference_prices"): + price = self._reference_prices.get(str(entry.get("ticker") or "").upper()) + derived: List[str] = [] + if price and strategy.target_pct: + success_price = price * (1 + strategy.target_pct) + strategy.success_price = success_price + derived.append(f"price >= {success_price:.2f}") + if price and strategy.stop_pct: + failure_price = price * (1 - strategy.stop_pct) + strategy.failure_price = failure_price + derived.append(f"price <= {failure_price:.2f}") + return base_triggers + derived + + def _extract_tool_calls(self, response: Any) -> List[Dict[str, Any]]: + calls: List[Dict[str, Any]] = [] + output_items = getattr(response, "output", []) or [] + for item in output_items: + if getattr(item, "type", None) != "function_call": + continue + call_id = getattr(item, "id", None) or getattr(item, "call_id", None) + arguments = getattr(item, "arguments", "") or "" + calls.append({"id": call_id, "name": getattr(item, "name", ""), "arguments": arguments}) + return calls + + def _extract_reasoning_traces(self, response: Any) -> List[str]: + traces: List[str] = [] + output_items = getattr(response, "output", []) or [] + for item in output_items: + if getattr(item, "type", None) != "reasoning": + continue + summary_bits: List[str] = [] + for summary in getattr(item, "summary", []) or []: + text = getattr(summary, "text", "") or "" + if text: + summary_bits.append(text.strip()) + detail_bits: List[str] = [] + for content in getattr(item, "content", []) or []: + text = getattr(content, "text", "") or "" + if text: + detail_bits.append(text.strip()) + summary_text = "; ".join(bit for bit in summary_bits if bit) + detail_text = " ".join(bit for bit in detail_bits if bit) + if detail_text and detail_text != summary_text: + snippet = f"{summary_text} — {detail_text}" if summary_text else detail_text + else: + snippet = summary_text or detail_text + if snippet: + traces.append(snippet) + return traces + + def _response_text(self, response: Any) -> str: + if not response: + return "" + if hasattr(response, "output_text") and response.output_text: + return response.output_text + pieces: List[str] = [] + for output in getattr(response, "output", []) or []: + if getattr(output, "type", None) == "message": + for content in getattr(output, "content", []) or []: + if getattr(content, "type", None) == "output_text": + pieces.append(getattr(content, "text", "") or "") + return "\n".join(pieces).strip() + + def _emit_narration(self, message: str) -> None: + snippet = message.strip() + if not snippet: + return + try: + print(f"[Responses Orchestrator] {snippet}") + except Exception: + pass + + def _emit_tool_event(self, name: str, args: Dict[str, Any], result: Dict[str, Any]) -> None: + try: + status = "OK" + if isinstance(result, dict) and result.get("error"): + status = "ERR" + args_str = _trimmed_json(args) + response_payload = result.get("report") if isinstance(result, dict) and isinstance(result.get("report"), str) else result + response_str = _trimmed_json(response_payload) + print(f"[Tool:{status}] {name}, Args:{args_str}") + print(f"[Tool:{status}] {name}, Response:{response_str}") + except Exception: + pass + + def _auto_execute_trades( + self, + decisions: List[TickerDecision], + submitted_trades: Set[Tuple[str, str]], + ) -> None: + exec_cfg = self.config.get("trade_execution", {}) or {} + if not exec_cfg.get("enabled"): + return + for decision in decisions: + action = (decision.final_decision or decision.immediate_action or "").upper() + if action not in {"BUY", "SELL"}: + continue + key = (decision.ticker.upper(), action) + if key in submitted_trades: + continue + result = self.graph.execute_trade_directive(decision.ticker, action) + try: + print(f"[Auto Execution] {decision.ticker} {action} -> {result.get('status')}") + except Exception: + pass + + def _build_system_prompt(self) -> str: + return ( + "You are the trading orchestrator for TradingAgents. Every run must begin by calling " + "`get_account_overview` exactly once (unless you explicitly refresh the Alpaca snapshot) and narrating the " + "current buying power, cash, open positions, and any recent orders. Reuse that overview for the remainder of " + "the run; do not call `get_account_overview` again until you intentionally refresh the snapshot.\n\n" + "With that snapshot, immediately synthesize or update a trading hypothesis for each focus ticker using the " + "account data, existing positions, buying power, cash, recent orders, and stored memory. Decide whether the " + "current hypothesis already justifies HOLD/BUY/SELL before touching high-latency tools. Only call " + "heavy-weight analysts or vendor data feeds when the hypothesis requires fresh evidence (e.g., preparing a " + "trade, validating a catalyst, or detecting a change since the last run). If the prior plan still applies, " + "log that decision and proceed without re-running every analyst.\n\n" + "When deeper context is required, call `get_ticker_memory`, then use `fetch_market_data`, " + "`fetch_indicators`, and `fetch_company_news` (plus `fetch_global_news` when macro context matters) before " + "invoking the specialist analysts (`run_market_analyst`, `run_news_analyst`, `run_fundamentals_analyst`). " + "For any ticker that lacks stored memory or an active position, you must at minimum gather the last 7 days of " + "market data and the latest company news before finalizing your hypothesis so you remain curious and well-grounded.\n\n" + "Every recommendation must map to a named strategy (e.g., day_trade, swing, position) taken from the provided presets. " + "For each ticker, specify a measurable `strategy` object containing `name`, `horizon_hours`, `target_pct`, `stop_pct`, " + "`success_metric`, `failure_metric`, `follow_up`, `urgency`, and an ISO8601 `deadline` that defines when the plan is reevaluated. " + "Customize the preset parameters only when the evidence demands it, and ensure the success/failure metrics describe the exact " + "conditions that complete or cancel the hypothesis so automation can act on them.\n\n" + "Narrate every step before you make the tool call so the CLI can display your thinking live, and summarize what " + "you learned from each tool. Be curious: when a ticker’s context is thin, proactively explore the smallest set " + "of tools needed to form a defendable hypothesis rather than defaulting to HOLD. Maintain an explicit plan tracker: list "+ + "each planned action (e.g., ‘Step 1 – Fetch TSLA market data’) along with its status (`pending`, `in_progress`, `done`), " + "and after every tool call, announce which step changed status and why. Consider market-open status before " + "trading, respect trade execution limits (but treat `day_trades_remaining` as informational—you may still " + "recommend buys/sells), and keep narration concise but informative.\n\n" + "Conclude with a JSON summary containing decisions for each ticker. The final assistant message must include a " + "JSON object with a `decisions` array where each entry specifies `ticker`, `action`, `priority`, " + "`plan_actions`, `next_decision`, `notes`, `plan_status` (a mapping of each plan action to its status), the `strategy` object described above, " + "and optional `action_queue` and `execution_plan` fields. After " + "producing the JSON, call `submit_trade_order` for every ticker whose action is BUY or SELL (subject to trade " + "execution settings)." + ) + + def _safe_json(self, raw: str) -> Dict[str, Any]: + if not raw: + return {} + try: + return json.loads(raw) + except json.JSONDecodeError: + return {"raw": raw} + + def _strategy_presets_brief(self) -> Dict[str, Any]: + cfg = self.config.get("trading_strategies", {}) or {} + presets = cfg.get("presets", {}) or {} + entries: List[Dict[str, Any]] = [] + for name, data in presets.items(): + entries.append( + { + "name": name, + "label": data.get("label"), + "horizon_hours": data.get("horizon_hours"), + "target_pct": data.get("target_pct"), + "stop_pct": data.get("stop_pct"), + "follow_up": data.get("follow_up"), + "urgency": data.get("urgency"), + } + ) + return {"default": cfg.get("default", "swing"), "presets": entries} +def _extract_reference_price(entry: Dict[str, Any]) -> Optional[float]: + candidates = [ + entry.get("reference_price"), + (entry.get("state") or {}).get("price"), + entry.get("last_price"), + ] + for value in candidates: + try: + price = float(value) + if price > 0: + return price + except (TypeError, ValueError): + continue + return None + + +def _snapshot_reference_prices(snapshot: AccountSnapshot) -> Dict[str, float]: + mapping: Dict[str, float] = {} + for position in snapshot.positions: + symbol = str(position.get("symbol") or position.get("symbol:") or "").upper() + if not symbol: + continue + price_fields = [ + position.get("current_price"), + position.get("price"), + position.get("market_value"), + ] + value = None + for field in price_fields: + try: + candidate = float(str(field).replace("$", "")) + if candidate > 0: + value = candidate + break + except (TypeError, ValueError, AttributeError): + continue + if value: + mapping[symbol] = value + return mapping