feat: pre-computed indicator interpretations (#542)

Review feedback applied:
- compute.py: load OHLCV once, compute all indicators on single StockDataFrame
  (was 9 redundant disk reads, now 1)
- SMA crossover detection implemented (was hardcoded None)
- Removed unreachable support_resistance interpreter
- Test date changed to known historical date (2024-01-10)

Closes #542
This commit is contained in:
Clayton Brown 2026-04-20 22:45:02 +10:00
parent fa4d01c23a
commit 90c06f0d31
5 changed files with 389 additions and 1 deletions

View File

@ -0,0 +1,56 @@
"""Verify fetch_indicators returns valid data matching stockstats directly.
Since fetch_indicators now calls StockstatsUtils.get_stock_stats (the same
function the get_indicators LLM tool uses), these tests confirm the wrapper
correctly structures the results for interpret.py.
"""
import pytest
try:
from tradingagents.indicators.compute import fetch_indicators
HAS_DEPS = True
except ImportError:
HAS_DEPS = False
@pytest.mark.skipif(not HAS_DEPS, reason="indicators module not available")
@pytest.mark.integration
class TestFetchIndicators:
TICKER = "AAPL"
DATE = "2024-01-10"
def test_returns_dict(self):
result = fetch_indicators(self.TICKER, self.DATE)
assert isinstance(result, dict)
def test_rsi_in_range(self):
result = fetch_indicators(self.TICKER, self.DATE)
if "rsi" not in result:
pytest.skip("RSI not available")
assert 0 <= result["rsi"]["value"] <= 100
def test_macd_has_required_keys(self):
result = fetch_indicators(self.TICKER, self.DATE)
if "macd" not in result:
pytest.skip("MACD not available")
for key in ("value", "signal", "histogram"):
assert key in result["macd"], f"MACD missing '{key}'"
def test_bollinger_bands_ordered(self):
result = fetch_indicators(self.TICKER, self.DATE)
if "bollinger" not in result:
pytest.skip("Bollinger not available")
b = result["bollinger"]
assert b["lower"] < b["upper"], "Lower band should be below upper"
def test_sma_values_positive(self):
result = fetch_indicators(self.TICKER, self.DATE)
if "sma_crossover" not in result:
pytest.skip("SMA not available")
assert result["sma_crossover"]["sma50"] > 0
assert result["sma_crossover"]["sma200"] > 0
def test_empty_on_bad_ticker(self):
result = fetch_indicators("ZZZZZZNOTREAL", self.DATE)
assert result == {}

View File

@ -7,6 +7,23 @@ from tradingagents.agents.utils.agent_utils import (
)
from tradingagents.dataflows.config import get_config
try:
from tradingagents.indicators import get_indicator_signals
except ImportError:
get_indicator_signals = None
def _format_preamble(signals: dict) -> str:
"""Format pre-computed interpretations as a prompt preamble."""
if not signals:
return ""
parts = ["\n## Pre-Computed Indicator Interpretations\n"]
for name, s in signals.items():
parts.append(f"- **{name}**: {s['explanation']}")
parts.append("\nUse these alongside the indicator catalog below. "
"Do NOT recalculate these values.\n")
return "\n".join(parts)
def create_market_analyst(llm):
@ -14,13 +31,25 @@ def create_market_analyst(llm):
current_date = state["trade_date"]
instrument_context = build_instrument_context(state["company_of_interest"])
# Pre-compute indicator interpretations when available
indicator_preamble = ""
if get_indicator_signals is not None:
try:
signals = get_indicator_signals(
state["company_of_interest"], current_date
)
indicator_preamble = _format_preamble(signals)
except Exception:
pass
tools = [
get_stock_data,
get_indicators,
]
system_message = (
"""You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
indicator_preamble
+ """You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
Moving Averages:
- close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.

View File

@ -0,0 +1,27 @@
"""Technical indicators: fetch via stockstats, interpret, return structured signals."""
from __future__ import annotations
import logging
from typing import Any
from tradingagents.indicators.compute import fetch_indicators
from tradingagents.indicators.interpret import interpret_indicators
logger = logging.getLogger(__name__)
def get_indicator_signals(symbol: str, curr_date: str) -> dict[str, dict[str, Any]]:
"""Return structured indicator signals for *symbol* as of *curr_date*.
Uses the same stockstats pipeline as the get_indicators LLM tool.
Returns an empty dict on any failure. Never raises.
"""
try:
raw = fetch_indicators(symbol, curr_date)
if not raw:
return {}
return interpret_indicators(raw)
except Exception:
logger.warning("indicators: failed for %s@%s", symbol, curr_date, exc_info=True)
return {}

View File

@ -0,0 +1,114 @@
"""Fetch indicator values using the existing stockstats pipeline.
Loads OHLCV data once and computes all indicators on a single
StockDataFrame instance no redundant disk I/O.
"""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
# stockstats indicator names to fetch, grouped by interpretation key.
_STATS = [
"rsi_14", "macd", "macds", "macdh",
"boll_ub", "boll_lb", "boll",
"close_50_sma", "close_200_sma",
]
def fetch_indicators(symbol: str, curr_date: str) -> dict[str, dict[str, Any]]:
"""Fetch indicator values via stockstats for *symbol* as of *curr_date*.
Loads data once and computes all indicators in a single pass.
Returns a dict keyed by indicator group name. Missing indicators omitted.
"""
try:
from tradingagents.dataflows.stockstats_utils import load_ohlcv
from stockstats import wrap
except ImportError:
logger.warning("indicators: stockstats or dataflows not available")
return {}
try:
df = load_ohlcv(symbol, curr_date)
except Exception as exc:
logger.warning("indicators: load_ohlcv failed for %s: %s", symbol, exc)
return {}
if df.empty or len(df) < 2:
return {}
try:
ss = wrap(df)
# Trigger computation of all indicators in one pass
for stat in _STATS:
_ = ss[stat]
except Exception as exc:
logger.warning("indicators: stockstats computation failed for %s: %s", symbol, exc)
return {}
# Extract last row values
import pandas as pd
curr_date_str = pd.to_datetime(curr_date).strftime("%Y-%m-%d")
ss["date_str"] = ss["date"].dt.strftime("%Y-%m-%d") if "date" in ss.columns else ""
row = ss[ss["date_str"] == curr_date_str]
if row.empty:
row = ss.iloc[[-1]] # fallback to last available row
def _val(col: str) -> float | None:
try:
v = float(row[col].iloc[0])
return None if pd.isna(v) else v
except (KeyError, IndexError):
return None
results: dict[str, dict[str, Any]] = {}
# RSI
rsi = _val("rsi_14")
if rsi is not None:
results["rsi"] = {"value": rsi, "period": 14}
# MACD
macd_val = _val("macd")
if macd_val is not None:
results["macd"] = {
"value": macd_val,
"signal": _val("macds"),
"histogram": _val("macdh"),
}
# Bollinger
upper = _val("boll_ub")
lower = _val("boll_lb")
if upper is not None and lower is not None:
results["bollinger"] = {
"value": _val("boll"),
"upper": upper,
"lower": lower,
}
# SMA crossover
sma50 = _val("close_50_sma")
sma200 = _val("close_200_sma")
if sma50 is not None and sma200 is not None:
# Detect crossover from recent data
crossover = None
try:
diff = ss["close_50_sma"] - ss["close_200_sma"]
recent = diff.iloc[-5:]
signs = recent.dropna().apply(lambda x: 1 if x > 0 else -1)
if len(signs) >= 2 and signs.iloc[-1] != signs.iloc[0]:
crossover = "golden_cross" if signs.iloc[-1] > 0 else "death_cross"
except Exception:
pass
results["sma_crossover"] = {
"sma50": sma50,
"sma200": sma200,
"crossover": crossover,
}
return results

View File

@ -0,0 +1,162 @@
"""Rule-based interpretation of computed technical indicators.
Converts raw numeric indicator values into structured signals
(bullish / bearish / neutral) with confidence and human-readable
explanations. No LLM involvement pure deterministic rules.
"""
from __future__ import annotations
from typing import Any
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def interpret_indicators(
computed: dict[str, dict[str, Any]],
) -> dict[str, dict[str, Any]]:
"""Interpret every indicator in *computed* and return structured signals.
Each result dict contains: ``value``, ``signal``, ``confidence``, ``explanation``.
Unknown indicators are silently skipped.
"""
_INTERPRETERS: dict[str, Any] = {
"rsi": _interpret_rsi,
"macd": _interpret_macd,
"bollinger": _interpret_bollinger,
"sma_crossover": _interpret_sma_crossover,
}
results: dict[str, dict[str, Any]] = {}
for name, data in computed.items():
fn = _INTERPRETERS.get(name)
if fn is not None:
results[name] = fn(data)
return results
# ---------------------------------------------------------------------------
# Per-indicator interpreters
# ---------------------------------------------------------------------------
def _interpret_rsi(data: dict[str, Any]) -> dict[str, Any]:
val = data.get("value")
if val is None:
return _neutral(val, "Insufficient data for RSI")
if val >= 80:
return _signal(val, "bearish", 0.9, f"RSI {val:.1f} — strongly overbought")
if val >= 70:
return _signal(val, "bearish", 0.7, f"RSI {val:.1f} — overbought")
if val <= 20:
return _signal(val, "bullish", 0.9, f"RSI {val:.1f} — strongly oversold")
if val <= 30:
return _signal(val, "bullish", 0.7, f"RSI {val:.1f} — oversold")
return _neutral(val, f"RSI {val:.1f} — neutral range")
def _interpret_macd(data: dict[str, Any]) -> dict[str, Any]:
hist = data.get("histogram")
macd_val = data.get("value")
sig_val = data.get("signal")
if hist is None or macd_val is None:
return _neutral(macd_val, "Insufficient data for MACD")
direction = "bullish" if hist > 0 else "bearish"
# Confidence scales with histogram magnitude relative to signal line
ref = abs(sig_val) if sig_val else 1.0
strength = min(abs(hist) / max(ref, 0.01), 1.0)
confidence = round(0.5 + 0.4 * strength, 2)
crossing = ""
if abs(hist) < 0.05 * max(ref, 0.01):
crossing = " (near crossover)"
confidence = 0.5
return _signal(
macd_val,
direction,
confidence,
f"MACD histogram {hist:+.4f}{crossing}{direction}",
)
def _interpret_bollinger(data: dict[str, Any]) -> dict[str, Any]:
price = data.get("value")
upper = data.get("upper")
lower = data.get("lower")
if price is None or upper is None or lower is None:
return _neutral(price, "Insufficient data for Bollinger Bands")
band_width = upper - lower
if band_width <= 0:
return _neutral(price, "Bollinger band width is zero")
position = (price - lower) / band_width # 0 = at lower, 1 = at upper
if position >= 1.0:
return _signal(price, "bearish", 0.8, f"Price at/above upper Bollinger Band — overbought")
if position >= 0.8:
return _signal(price, "bearish", 0.6, f"Price near upper Bollinger Band ({position:.0%})")
if position <= 0.0:
return _signal(price, "bullish", 0.8, f"Price at/below lower Bollinger Band — oversold")
if position <= 0.2:
return _signal(price, "bullish", 0.6, f"Price near lower Bollinger Band ({position:.0%})")
return _neutral(price, f"Price within Bollinger Bands ({position:.0%})")
def _interpret_sma_crossover(data: dict[str, Any]) -> dict[str, Any]:
sma50 = data.get("sma50")
sma200 = data.get("sma200")
crossover = data.get("crossover")
if sma50 is None or sma200 is None:
return _neutral(sma50, "Insufficient data for SMA crossover")
if crossover == "golden_cross":
return _signal(sma50, "bullish", 0.85, "Golden cross — SMA50 crossed above SMA200")
if crossover == "death_cross":
return _signal(sma50, "bearish", 0.85, "Death cross — SMA50 crossed below SMA200")
if sma50 > sma200:
return _signal(sma50, "bullish", 0.6, f"SMA50 ({sma50:.2f}) above SMA200 ({sma200:.2f}) — bullish trend")
return _signal(sma50, "bearish", 0.6, f"SMA50 ({sma50:.2f}) below SMA200 ({sma200:.2f}) — bearish trend")
def _interpret_support_resistance(data: dict[str, Any]) -> dict[str, Any]:
price = data.get("last_close")
resistance = data.get("resistance")
support = data.get("support")
if price is None or resistance is None or support is None:
return _neutral(price, "Insufficient data for support/resistance")
rng = resistance - support
if rng <= 0:
return _neutral(price, "Support equals resistance")
position = (price - support) / rng
if position >= 0.9:
return _signal(price, "bearish", 0.65, f"Price near resistance ({resistance:.2f})")
if position <= 0.1:
return _signal(price, "bullish", 0.65, f"Price near support ({support:.2f})")
return _neutral(price, f"Price between support ({support:.2f}) and resistance ({resistance:.2f})")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _signal(
value: Any, signal: str, confidence: float, explanation: str,
) -> dict[str, Any]:
return {"value": value, "signal": signal, "confidence": confidence, "explanation": explanation}
def _neutral(value: Any, explanation: str) -> dict[str, Any]:
return _signal(value, "neutral", 0.5, explanation)