feat: pre-computed indicator interpretations (#542)
Review feedback applied: - compute.py: load OHLCV once, compute all indicators on single StockDataFrame (was 9 redundant disk reads, now 1) - SMA crossover detection implemented (was hardcoded None) - Removed unreachable support_resistance interpreter - Test date changed to known historical date (2024-01-10) Closes #542
This commit is contained in:
parent
fa4d01c23a
commit
90c06f0d31
|
|
@ -0,0 +1,56 @@
|
|||
"""Verify fetch_indicators returns valid data matching stockstats directly.
|
||||
|
||||
Since fetch_indicators now calls StockstatsUtils.get_stock_stats (the same
|
||||
function the get_indicators LLM tool uses), these tests confirm the wrapper
|
||||
correctly structures the results for interpret.py.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from tradingagents.indicators.compute import fetch_indicators
|
||||
HAS_DEPS = True
|
||||
except ImportError:
|
||||
HAS_DEPS = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_DEPS, reason="indicators module not available")
|
||||
@pytest.mark.integration
|
||||
class TestFetchIndicators:
|
||||
|
||||
TICKER = "AAPL"
|
||||
DATE = "2024-01-10"
|
||||
|
||||
def test_returns_dict(self):
|
||||
result = fetch_indicators(self.TICKER, self.DATE)
|
||||
assert isinstance(result, dict)
|
||||
|
||||
def test_rsi_in_range(self):
|
||||
result = fetch_indicators(self.TICKER, self.DATE)
|
||||
if "rsi" not in result:
|
||||
pytest.skip("RSI not available")
|
||||
assert 0 <= result["rsi"]["value"] <= 100
|
||||
|
||||
def test_macd_has_required_keys(self):
|
||||
result = fetch_indicators(self.TICKER, self.DATE)
|
||||
if "macd" not in result:
|
||||
pytest.skip("MACD not available")
|
||||
for key in ("value", "signal", "histogram"):
|
||||
assert key in result["macd"], f"MACD missing '{key}'"
|
||||
|
||||
def test_bollinger_bands_ordered(self):
|
||||
result = fetch_indicators(self.TICKER, self.DATE)
|
||||
if "bollinger" not in result:
|
||||
pytest.skip("Bollinger not available")
|
||||
b = result["bollinger"]
|
||||
assert b["lower"] < b["upper"], "Lower band should be below upper"
|
||||
|
||||
def test_sma_values_positive(self):
|
||||
result = fetch_indicators(self.TICKER, self.DATE)
|
||||
if "sma_crossover" not in result:
|
||||
pytest.skip("SMA not available")
|
||||
assert result["sma_crossover"]["sma50"] > 0
|
||||
assert result["sma_crossover"]["sma200"] > 0
|
||||
|
||||
def test_empty_on_bad_ticker(self):
|
||||
result = fetch_indicators("ZZZZZZNOTREAL", self.DATE)
|
||||
assert result == {}
|
||||
|
|
@ -7,6 +7,23 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
)
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
try:
|
||||
from tradingagents.indicators import get_indicator_signals
|
||||
except ImportError:
|
||||
get_indicator_signals = None
|
||||
|
||||
|
||||
def _format_preamble(signals: dict) -> str:
|
||||
"""Format pre-computed interpretations as a prompt preamble."""
|
||||
if not signals:
|
||||
return ""
|
||||
parts = ["\n## Pre-Computed Indicator Interpretations\n"]
|
||||
for name, s in signals.items():
|
||||
parts.append(f"- **{name}**: {s['explanation']}")
|
||||
parts.append("\nUse these alongside the indicator catalog below. "
|
||||
"Do NOT recalculate these values.\n")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def create_market_analyst(llm):
|
||||
|
||||
|
|
@ -14,13 +31,25 @@ def create_market_analyst(llm):
|
|||
current_date = state["trade_date"]
|
||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||
|
||||
# Pre-compute indicator interpretations when available
|
||||
indicator_preamble = ""
|
||||
if get_indicator_signals is not None:
|
||||
try:
|
||||
signals = get_indicator_signals(
|
||||
state["company_of_interest"], current_date
|
||||
)
|
||||
indicator_preamble = _format_preamble(signals)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
tools = [
|
||||
get_stock_data,
|
||||
get_indicators,
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"""You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
|
||||
indicator_preamble
|
||||
+ """You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
|
||||
|
||||
Moving Averages:
|
||||
- close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,27 @@
|
|||
"""Technical indicators: fetch via stockstats, interpret, return structured signals."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from tradingagents.indicators.compute import fetch_indicators
|
||||
from tradingagents.indicators.interpret import interpret_indicators
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_indicator_signals(symbol: str, curr_date: str) -> dict[str, dict[str, Any]]:
|
||||
"""Return structured indicator signals for *symbol* as of *curr_date*.
|
||||
|
||||
Uses the same stockstats pipeline as the get_indicators LLM tool.
|
||||
Returns an empty dict on any failure. Never raises.
|
||||
"""
|
||||
try:
|
||||
raw = fetch_indicators(symbol, curr_date)
|
||||
if not raw:
|
||||
return {}
|
||||
return interpret_indicators(raw)
|
||||
except Exception:
|
||||
logger.warning("indicators: failed for %s@%s", symbol, curr_date, exc_info=True)
|
||||
return {}
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
"""Fetch indicator values using the existing stockstats pipeline.
|
||||
|
||||
Loads OHLCV data once and computes all indicators on a single
|
||||
StockDataFrame instance — no redundant disk I/O.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# stockstats indicator names to fetch, grouped by interpretation key.
|
||||
_STATS = [
|
||||
"rsi_14", "macd", "macds", "macdh",
|
||||
"boll_ub", "boll_lb", "boll",
|
||||
"close_50_sma", "close_200_sma",
|
||||
]
|
||||
|
||||
|
||||
def fetch_indicators(symbol: str, curr_date: str) -> dict[str, dict[str, Any]]:
|
||||
"""Fetch indicator values via stockstats for *symbol* as of *curr_date*.
|
||||
|
||||
Loads data once and computes all indicators in a single pass.
|
||||
Returns a dict keyed by indicator group name. Missing indicators omitted.
|
||||
"""
|
||||
try:
|
||||
from tradingagents.dataflows.stockstats_utils import load_ohlcv
|
||||
from stockstats import wrap
|
||||
except ImportError:
|
||||
logger.warning("indicators: stockstats or dataflows not available")
|
||||
return {}
|
||||
|
||||
try:
|
||||
df = load_ohlcv(symbol, curr_date)
|
||||
except Exception as exc:
|
||||
logger.warning("indicators: load_ohlcv failed for %s: %s", symbol, exc)
|
||||
return {}
|
||||
|
||||
if df.empty or len(df) < 2:
|
||||
return {}
|
||||
|
||||
try:
|
||||
ss = wrap(df)
|
||||
# Trigger computation of all indicators in one pass
|
||||
for stat in _STATS:
|
||||
_ = ss[stat]
|
||||
except Exception as exc:
|
||||
logger.warning("indicators: stockstats computation failed for %s: %s", symbol, exc)
|
||||
return {}
|
||||
|
||||
# Extract last row values
|
||||
import pandas as pd
|
||||
curr_date_str = pd.to_datetime(curr_date).strftime("%Y-%m-%d")
|
||||
ss["date_str"] = ss["date"].dt.strftime("%Y-%m-%d") if "date" in ss.columns else ""
|
||||
row = ss[ss["date_str"] == curr_date_str]
|
||||
if row.empty:
|
||||
row = ss.iloc[[-1]] # fallback to last available row
|
||||
|
||||
def _val(col: str) -> float | None:
|
||||
try:
|
||||
v = float(row[col].iloc[0])
|
||||
return None if pd.isna(v) else v
|
||||
except (KeyError, IndexError):
|
||||
return None
|
||||
|
||||
results: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# RSI
|
||||
rsi = _val("rsi_14")
|
||||
if rsi is not None:
|
||||
results["rsi"] = {"value": rsi, "period": 14}
|
||||
|
||||
# MACD
|
||||
macd_val = _val("macd")
|
||||
if macd_val is not None:
|
||||
results["macd"] = {
|
||||
"value": macd_val,
|
||||
"signal": _val("macds"),
|
||||
"histogram": _val("macdh"),
|
||||
}
|
||||
|
||||
# Bollinger
|
||||
upper = _val("boll_ub")
|
||||
lower = _val("boll_lb")
|
||||
if upper is not None and lower is not None:
|
||||
results["bollinger"] = {
|
||||
"value": _val("boll"),
|
||||
"upper": upper,
|
||||
"lower": lower,
|
||||
}
|
||||
|
||||
# SMA crossover
|
||||
sma50 = _val("close_50_sma")
|
||||
sma200 = _val("close_200_sma")
|
||||
if sma50 is not None and sma200 is not None:
|
||||
# Detect crossover from recent data
|
||||
crossover = None
|
||||
try:
|
||||
diff = ss["close_50_sma"] - ss["close_200_sma"]
|
||||
recent = diff.iloc[-5:]
|
||||
signs = recent.dropna().apply(lambda x: 1 if x > 0 else -1)
|
||||
if len(signs) >= 2 and signs.iloc[-1] != signs.iloc[0]:
|
||||
crossover = "golden_cross" if signs.iloc[-1] > 0 else "death_cross"
|
||||
except Exception:
|
||||
pass
|
||||
results["sma_crossover"] = {
|
||||
"sma50": sma50,
|
||||
"sma200": sma200,
|
||||
"crossover": crossover,
|
||||
}
|
||||
|
||||
return results
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
"""Rule-based interpretation of computed technical indicators.
|
||||
|
||||
Converts raw numeric indicator values into structured signals
|
||||
(bullish / bearish / neutral) with confidence and human-readable
|
||||
explanations. No LLM involvement — pure deterministic rules.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def interpret_indicators(
|
||||
computed: dict[str, dict[str, Any]],
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Interpret every indicator in *computed* and return structured signals.
|
||||
|
||||
Each result dict contains: ``value``, ``signal``, ``confidence``, ``explanation``.
|
||||
Unknown indicators are silently skipped.
|
||||
"""
|
||||
_INTERPRETERS: dict[str, Any] = {
|
||||
"rsi": _interpret_rsi,
|
||||
"macd": _interpret_macd,
|
||||
"bollinger": _interpret_bollinger,
|
||||
"sma_crossover": _interpret_sma_crossover,
|
||||
}
|
||||
|
||||
results: dict[str, dict[str, Any]] = {}
|
||||
for name, data in computed.items():
|
||||
fn = _INTERPRETERS.get(name)
|
||||
if fn is not None:
|
||||
results[name] = fn(data)
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-indicator interpreters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _interpret_rsi(data: dict[str, Any]) -> dict[str, Any]:
|
||||
val = data.get("value")
|
||||
if val is None:
|
||||
return _neutral(val, "Insufficient data for RSI")
|
||||
|
||||
if val >= 80:
|
||||
return _signal(val, "bearish", 0.9, f"RSI {val:.1f} — strongly overbought")
|
||||
if val >= 70:
|
||||
return _signal(val, "bearish", 0.7, f"RSI {val:.1f} — overbought")
|
||||
if val <= 20:
|
||||
return _signal(val, "bullish", 0.9, f"RSI {val:.1f} — strongly oversold")
|
||||
if val <= 30:
|
||||
return _signal(val, "bullish", 0.7, f"RSI {val:.1f} — oversold")
|
||||
return _neutral(val, f"RSI {val:.1f} — neutral range")
|
||||
|
||||
|
||||
def _interpret_macd(data: dict[str, Any]) -> dict[str, Any]:
|
||||
hist = data.get("histogram")
|
||||
macd_val = data.get("value")
|
||||
sig_val = data.get("signal")
|
||||
|
||||
if hist is None or macd_val is None:
|
||||
return _neutral(macd_val, "Insufficient data for MACD")
|
||||
|
||||
direction = "bullish" if hist > 0 else "bearish"
|
||||
# Confidence scales with histogram magnitude relative to signal line
|
||||
ref = abs(sig_val) if sig_val else 1.0
|
||||
strength = min(abs(hist) / max(ref, 0.01), 1.0)
|
||||
confidence = round(0.5 + 0.4 * strength, 2)
|
||||
|
||||
crossing = ""
|
||||
if abs(hist) < 0.05 * max(ref, 0.01):
|
||||
crossing = " (near crossover)"
|
||||
confidence = 0.5
|
||||
|
||||
return _signal(
|
||||
macd_val,
|
||||
direction,
|
||||
confidence,
|
||||
f"MACD histogram {hist:+.4f}{crossing} — {direction}",
|
||||
)
|
||||
|
||||
|
||||
def _interpret_bollinger(data: dict[str, Any]) -> dict[str, Any]:
|
||||
price = data.get("value")
|
||||
upper = data.get("upper")
|
||||
lower = data.get("lower")
|
||||
|
||||
if price is None or upper is None or lower is None:
|
||||
return _neutral(price, "Insufficient data for Bollinger Bands")
|
||||
|
||||
band_width = upper - lower
|
||||
if band_width <= 0:
|
||||
return _neutral(price, "Bollinger band width is zero")
|
||||
|
||||
position = (price - lower) / band_width # 0 = at lower, 1 = at upper
|
||||
|
||||
if position >= 1.0:
|
||||
return _signal(price, "bearish", 0.8, f"Price at/above upper Bollinger Band — overbought")
|
||||
if position >= 0.8:
|
||||
return _signal(price, "bearish", 0.6, f"Price near upper Bollinger Band ({position:.0%})")
|
||||
if position <= 0.0:
|
||||
return _signal(price, "bullish", 0.8, f"Price at/below lower Bollinger Band — oversold")
|
||||
if position <= 0.2:
|
||||
return _signal(price, "bullish", 0.6, f"Price near lower Bollinger Band ({position:.0%})")
|
||||
return _neutral(price, f"Price within Bollinger Bands ({position:.0%})")
|
||||
|
||||
|
||||
def _interpret_sma_crossover(data: dict[str, Any]) -> dict[str, Any]:
|
||||
sma50 = data.get("sma50")
|
||||
sma200 = data.get("sma200")
|
||||
crossover = data.get("crossover")
|
||||
|
||||
if sma50 is None or sma200 is None:
|
||||
return _neutral(sma50, "Insufficient data for SMA crossover")
|
||||
|
||||
if crossover == "golden_cross":
|
||||
return _signal(sma50, "bullish", 0.85, "Golden cross — SMA50 crossed above SMA200")
|
||||
if crossover == "death_cross":
|
||||
return _signal(sma50, "bearish", 0.85, "Death cross — SMA50 crossed below SMA200")
|
||||
|
||||
if sma50 > sma200:
|
||||
return _signal(sma50, "bullish", 0.6, f"SMA50 ({sma50:.2f}) above SMA200 ({sma200:.2f}) — bullish trend")
|
||||
return _signal(sma50, "bearish", 0.6, f"SMA50 ({sma50:.2f}) below SMA200 ({sma200:.2f}) — bearish trend")
|
||||
|
||||
|
||||
def _interpret_support_resistance(data: dict[str, Any]) -> dict[str, Any]:
|
||||
price = data.get("last_close")
|
||||
resistance = data.get("resistance")
|
||||
support = data.get("support")
|
||||
|
||||
if price is None or resistance is None or support is None:
|
||||
return _neutral(price, "Insufficient data for support/resistance")
|
||||
|
||||
rng = resistance - support
|
||||
if rng <= 0:
|
||||
return _neutral(price, "Support equals resistance")
|
||||
|
||||
position = (price - support) / rng
|
||||
|
||||
if position >= 0.9:
|
||||
return _signal(price, "bearish", 0.65, f"Price near resistance ({resistance:.2f})")
|
||||
if position <= 0.1:
|
||||
return _signal(price, "bullish", 0.65, f"Price near support ({support:.2f})")
|
||||
return _neutral(price, f"Price between support ({support:.2f}) and resistance ({resistance:.2f})")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _signal(
|
||||
value: Any, signal: str, confidence: float, explanation: str,
|
||||
) -> dict[str, Any]:
|
||||
return {"value": value, "signal": signal, "confidence": confidence, "explanation": explanation}
|
||||
|
||||
|
||||
def _neutral(value: Any, explanation: str) -> dict[str, Any]:
|
||||
return _signal(value, "neutral", 0.5, explanation)
|
||||
Loading…
Reference in New Issue