From 90c06f0d316cfa9c5d76bc9516d547eb57c7a94a Mon Sep 17 00:00:00 2001 From: Clayton Brown Date: Mon, 20 Apr 2026 22:45:02 +1000 Subject: [PATCH] feat: pre-computed indicator interpretations (#542) Review feedback applied: - compute.py: load OHLCV once, compute all indicators on single StockDataFrame (was 9 redundant disk reads, now 1) - SMA crossover detection implemented (was hardcoded None) - Removed unreachable support_resistance interpreter - Test date changed to known historical date (2024-01-10) Closes #542 --- tests/test_compute_vs_fetch.py | 56 ++++++ .../agents/analysts/market_analyst.py | 31 +++- tradingagents/indicators/__init__.py | 27 +++ tradingagents/indicators/compute.py | 114 ++++++++++++ tradingagents/indicators/interpret.py | 162 ++++++++++++++++++ 5 files changed, 389 insertions(+), 1 deletion(-) create mode 100644 tests/test_compute_vs_fetch.py create mode 100644 tradingagents/indicators/__init__.py create mode 100644 tradingagents/indicators/compute.py create mode 100644 tradingagents/indicators/interpret.py diff --git a/tests/test_compute_vs_fetch.py b/tests/test_compute_vs_fetch.py new file mode 100644 index 00000000..a5e69e7f --- /dev/null +++ b/tests/test_compute_vs_fetch.py @@ -0,0 +1,56 @@ +"""Verify fetch_indicators returns valid data matching stockstats directly. + +Since fetch_indicators now calls StockstatsUtils.get_stock_stats (the same +function the get_indicators LLM tool uses), these tests confirm the wrapper +correctly structures the results for interpret.py. +""" +import pytest + +try: + from tradingagents.indicators.compute import fetch_indicators + HAS_DEPS = True +except ImportError: + HAS_DEPS = False + + +@pytest.mark.skipif(not HAS_DEPS, reason="indicators module not available") +@pytest.mark.integration +class TestFetchIndicators: + + TICKER = "AAPL" + DATE = "2024-01-10" + + def test_returns_dict(self): + result = fetch_indicators(self.TICKER, self.DATE) + assert isinstance(result, dict) + + def test_rsi_in_range(self): + result = fetch_indicators(self.TICKER, self.DATE) + if "rsi" not in result: + pytest.skip("RSI not available") + assert 0 <= result["rsi"]["value"] <= 100 + + def test_macd_has_required_keys(self): + result = fetch_indicators(self.TICKER, self.DATE) + if "macd" not in result: + pytest.skip("MACD not available") + for key in ("value", "signal", "histogram"): + assert key in result["macd"], f"MACD missing '{key}'" + + def test_bollinger_bands_ordered(self): + result = fetch_indicators(self.TICKER, self.DATE) + if "bollinger" not in result: + pytest.skip("Bollinger not available") + b = result["bollinger"] + assert b["lower"] < b["upper"], "Lower band should be below upper" + + def test_sma_values_positive(self): + result = fetch_indicators(self.TICKER, self.DATE) + if "sma_crossover" not in result: + pytest.skip("SMA not available") + assert result["sma_crossover"]["sma50"] > 0 + assert result["sma_crossover"]["sma200"] > 0 + + def test_empty_on_bad_ticker(self): + result = fetch_indicators("ZZZZZZNOTREAL", self.DATE) + assert result == {} diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index fef8f751..acb265f7 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -7,6 +7,23 @@ from tradingagents.agents.utils.agent_utils import ( ) from tradingagents.dataflows.config import get_config +try: + from tradingagents.indicators import get_indicator_signals +except ImportError: + get_indicator_signals = None + + +def _format_preamble(signals: dict) -> str: + """Format pre-computed interpretations as a prompt preamble.""" + if not signals: + return "" + parts = ["\n## Pre-Computed Indicator Interpretations\n"] + for name, s in signals.items(): + parts.append(f"- **{name}**: {s['explanation']}") + parts.append("\nUse these alongside the indicator catalog below. " + "Do NOT recalculate these values.\n") + return "\n".join(parts) + def create_market_analyst(llm): @@ -14,13 +31,25 @@ def create_market_analyst(llm): current_date = state["trade_date"] instrument_context = build_instrument_context(state["company_of_interest"]) + # Pre-compute indicator interpretations when available + indicator_preamble = "" + if get_indicator_signals is not None: + try: + signals = get_indicator_signals( + state["company_of_interest"], current_date + ) + indicator_preamble = _format_preamble(signals) + except Exception: + pass + tools = [ get_stock_data, get_indicators, ] system_message = ( - """You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are: + indicator_preamble + + """You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are: Moving Averages: - close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals. diff --git a/tradingagents/indicators/__init__.py b/tradingagents/indicators/__init__.py new file mode 100644 index 00000000..1fd539ae --- /dev/null +++ b/tradingagents/indicators/__init__.py @@ -0,0 +1,27 @@ +"""Technical indicators: fetch via stockstats, interpret, return structured signals.""" + +from __future__ import annotations + +import logging +from typing import Any + +from tradingagents.indicators.compute import fetch_indicators +from tradingagents.indicators.interpret import interpret_indicators + +logger = logging.getLogger(__name__) + + +def get_indicator_signals(symbol: str, curr_date: str) -> dict[str, dict[str, Any]]: + """Return structured indicator signals for *symbol* as of *curr_date*. + + Uses the same stockstats pipeline as the get_indicators LLM tool. + Returns an empty dict on any failure. Never raises. + """ + try: + raw = fetch_indicators(symbol, curr_date) + if not raw: + return {} + return interpret_indicators(raw) + except Exception: + logger.warning("indicators: failed for %s@%s", symbol, curr_date, exc_info=True) + return {} diff --git a/tradingagents/indicators/compute.py b/tradingagents/indicators/compute.py new file mode 100644 index 00000000..b4a0d3c9 --- /dev/null +++ b/tradingagents/indicators/compute.py @@ -0,0 +1,114 @@ +"""Fetch indicator values using the existing stockstats pipeline. + +Loads OHLCV data once and computes all indicators on a single +StockDataFrame instance — no redundant disk I/O. +""" + +from __future__ import annotations + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + +# stockstats indicator names to fetch, grouped by interpretation key. +_STATS = [ + "rsi_14", "macd", "macds", "macdh", + "boll_ub", "boll_lb", "boll", + "close_50_sma", "close_200_sma", +] + + +def fetch_indicators(symbol: str, curr_date: str) -> dict[str, dict[str, Any]]: + """Fetch indicator values via stockstats for *symbol* as of *curr_date*. + + Loads data once and computes all indicators in a single pass. + Returns a dict keyed by indicator group name. Missing indicators omitted. + """ + try: + from tradingagents.dataflows.stockstats_utils import load_ohlcv + from stockstats import wrap + except ImportError: + logger.warning("indicators: stockstats or dataflows not available") + return {} + + try: + df = load_ohlcv(symbol, curr_date) + except Exception as exc: + logger.warning("indicators: load_ohlcv failed for %s: %s", symbol, exc) + return {} + + if df.empty or len(df) < 2: + return {} + + try: + ss = wrap(df) + # Trigger computation of all indicators in one pass + for stat in _STATS: + _ = ss[stat] + except Exception as exc: + logger.warning("indicators: stockstats computation failed for %s: %s", symbol, exc) + return {} + + # Extract last row values + import pandas as pd + curr_date_str = pd.to_datetime(curr_date).strftime("%Y-%m-%d") + ss["date_str"] = ss["date"].dt.strftime("%Y-%m-%d") if "date" in ss.columns else "" + row = ss[ss["date_str"] == curr_date_str] + if row.empty: + row = ss.iloc[[-1]] # fallback to last available row + + def _val(col: str) -> float | None: + try: + v = float(row[col].iloc[0]) + return None if pd.isna(v) else v + except (KeyError, IndexError): + return None + + results: dict[str, dict[str, Any]] = {} + + # RSI + rsi = _val("rsi_14") + if rsi is not None: + results["rsi"] = {"value": rsi, "period": 14} + + # MACD + macd_val = _val("macd") + if macd_val is not None: + results["macd"] = { + "value": macd_val, + "signal": _val("macds"), + "histogram": _val("macdh"), + } + + # Bollinger + upper = _val("boll_ub") + lower = _val("boll_lb") + if upper is not None and lower is not None: + results["bollinger"] = { + "value": _val("boll"), + "upper": upper, + "lower": lower, + } + + # SMA crossover + sma50 = _val("close_50_sma") + sma200 = _val("close_200_sma") + if sma50 is not None and sma200 is not None: + # Detect crossover from recent data + crossover = None + try: + diff = ss["close_50_sma"] - ss["close_200_sma"] + recent = diff.iloc[-5:] + signs = recent.dropna().apply(lambda x: 1 if x > 0 else -1) + if len(signs) >= 2 and signs.iloc[-1] != signs.iloc[0]: + crossover = "golden_cross" if signs.iloc[-1] > 0 else "death_cross" + except Exception: + pass + results["sma_crossover"] = { + "sma50": sma50, + "sma200": sma200, + "crossover": crossover, + } + + return results diff --git a/tradingagents/indicators/interpret.py b/tradingagents/indicators/interpret.py new file mode 100644 index 00000000..01a150e5 --- /dev/null +++ b/tradingagents/indicators/interpret.py @@ -0,0 +1,162 @@ +"""Rule-based interpretation of computed technical indicators. + +Converts raw numeric indicator values into structured signals +(bullish / bearish / neutral) with confidence and human-readable +explanations. No LLM involvement — pure deterministic rules. +""" + +from __future__ import annotations + +from typing import Any + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def interpret_indicators( + computed: dict[str, dict[str, Any]], +) -> dict[str, dict[str, Any]]: + """Interpret every indicator in *computed* and return structured signals. + + Each result dict contains: ``value``, ``signal``, ``confidence``, ``explanation``. + Unknown indicators are silently skipped. + """ + _INTERPRETERS: dict[str, Any] = { + "rsi": _interpret_rsi, + "macd": _interpret_macd, + "bollinger": _interpret_bollinger, + "sma_crossover": _interpret_sma_crossover, + } + + results: dict[str, dict[str, Any]] = {} + for name, data in computed.items(): + fn = _INTERPRETERS.get(name) + if fn is not None: + results[name] = fn(data) + return results + + +# --------------------------------------------------------------------------- +# Per-indicator interpreters +# --------------------------------------------------------------------------- + +def _interpret_rsi(data: dict[str, Any]) -> dict[str, Any]: + val = data.get("value") + if val is None: + return _neutral(val, "Insufficient data for RSI") + + if val >= 80: + return _signal(val, "bearish", 0.9, f"RSI {val:.1f} — strongly overbought") + if val >= 70: + return _signal(val, "bearish", 0.7, f"RSI {val:.1f} — overbought") + if val <= 20: + return _signal(val, "bullish", 0.9, f"RSI {val:.1f} — strongly oversold") + if val <= 30: + return _signal(val, "bullish", 0.7, f"RSI {val:.1f} — oversold") + return _neutral(val, f"RSI {val:.1f} — neutral range") + + +def _interpret_macd(data: dict[str, Any]) -> dict[str, Any]: + hist = data.get("histogram") + macd_val = data.get("value") + sig_val = data.get("signal") + + if hist is None or macd_val is None: + return _neutral(macd_val, "Insufficient data for MACD") + + direction = "bullish" if hist > 0 else "bearish" + # Confidence scales with histogram magnitude relative to signal line + ref = abs(sig_val) if sig_val else 1.0 + strength = min(abs(hist) / max(ref, 0.01), 1.0) + confidence = round(0.5 + 0.4 * strength, 2) + + crossing = "" + if abs(hist) < 0.05 * max(ref, 0.01): + crossing = " (near crossover)" + confidence = 0.5 + + return _signal( + macd_val, + direction, + confidence, + f"MACD histogram {hist:+.4f}{crossing} — {direction}", + ) + + +def _interpret_bollinger(data: dict[str, Any]) -> dict[str, Any]: + price = data.get("value") + upper = data.get("upper") + lower = data.get("lower") + + if price is None or upper is None or lower is None: + return _neutral(price, "Insufficient data for Bollinger Bands") + + band_width = upper - lower + if band_width <= 0: + return _neutral(price, "Bollinger band width is zero") + + position = (price - lower) / band_width # 0 = at lower, 1 = at upper + + if position >= 1.0: + return _signal(price, "bearish", 0.8, f"Price at/above upper Bollinger Band — overbought") + if position >= 0.8: + return _signal(price, "bearish", 0.6, f"Price near upper Bollinger Band ({position:.0%})") + if position <= 0.0: + return _signal(price, "bullish", 0.8, f"Price at/below lower Bollinger Band — oversold") + if position <= 0.2: + return _signal(price, "bullish", 0.6, f"Price near lower Bollinger Band ({position:.0%})") + return _neutral(price, f"Price within Bollinger Bands ({position:.0%})") + + +def _interpret_sma_crossover(data: dict[str, Any]) -> dict[str, Any]: + sma50 = data.get("sma50") + sma200 = data.get("sma200") + crossover = data.get("crossover") + + if sma50 is None or sma200 is None: + return _neutral(sma50, "Insufficient data for SMA crossover") + + if crossover == "golden_cross": + return _signal(sma50, "bullish", 0.85, "Golden cross — SMA50 crossed above SMA200") + if crossover == "death_cross": + return _signal(sma50, "bearish", 0.85, "Death cross — SMA50 crossed below SMA200") + + if sma50 > sma200: + return _signal(sma50, "bullish", 0.6, f"SMA50 ({sma50:.2f}) above SMA200 ({sma200:.2f}) — bullish trend") + return _signal(sma50, "bearish", 0.6, f"SMA50 ({sma50:.2f}) below SMA200 ({sma200:.2f}) — bearish trend") + + +def _interpret_support_resistance(data: dict[str, Any]) -> dict[str, Any]: + price = data.get("last_close") + resistance = data.get("resistance") + support = data.get("support") + + if price is None or resistance is None or support is None: + return _neutral(price, "Insufficient data for support/resistance") + + rng = resistance - support + if rng <= 0: + return _neutral(price, "Support equals resistance") + + position = (price - support) / rng + + if position >= 0.9: + return _signal(price, "bearish", 0.65, f"Price near resistance ({resistance:.2f})") + if position <= 0.1: + return _signal(price, "bullish", 0.65, f"Price near support ({support:.2f})") + return _neutral(price, f"Price between support ({support:.2f}) and resistance ({resistance:.2f})") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _signal( + value: Any, signal: str, confidence: float, explanation: str, +) -> dict[str, Any]: + return {"value": value, "signal": signal, "confidence": confidence, "explanation": explanation} + + +def _neutral(value: Any, explanation: str) -> dict[str, Any]: + return _signal(value, "neutral", 0.5, explanation)