Merge 90c06f0d31 into fa4d01c23a
This commit is contained in:
commit
e9b62022e2
|
|
@ -0,0 +1,56 @@
|
||||||
|
"""Verify fetch_indicators returns valid data matching stockstats directly.
|
||||||
|
|
||||||
|
Since fetch_indicators now calls StockstatsUtils.get_stock_stats (the same
|
||||||
|
function the get_indicators LLM tool uses), these tests confirm the wrapper
|
||||||
|
correctly structures the results for interpret.py.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tradingagents.indicators.compute import fetch_indicators
|
||||||
|
HAS_DEPS = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_DEPS = False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_DEPS, reason="indicators module not available")
|
||||||
|
@pytest.mark.integration
|
||||||
|
class TestFetchIndicators:
|
||||||
|
|
||||||
|
TICKER = "AAPL"
|
||||||
|
DATE = "2024-01-10"
|
||||||
|
|
||||||
|
def test_returns_dict(self):
|
||||||
|
result = fetch_indicators(self.TICKER, self.DATE)
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
|
||||||
|
def test_rsi_in_range(self):
|
||||||
|
result = fetch_indicators(self.TICKER, self.DATE)
|
||||||
|
if "rsi" not in result:
|
||||||
|
pytest.skip("RSI not available")
|
||||||
|
assert 0 <= result["rsi"]["value"] <= 100
|
||||||
|
|
||||||
|
def test_macd_has_required_keys(self):
|
||||||
|
result = fetch_indicators(self.TICKER, self.DATE)
|
||||||
|
if "macd" not in result:
|
||||||
|
pytest.skip("MACD not available")
|
||||||
|
for key in ("value", "signal", "histogram"):
|
||||||
|
assert key in result["macd"], f"MACD missing '{key}'"
|
||||||
|
|
||||||
|
def test_bollinger_bands_ordered(self):
|
||||||
|
result = fetch_indicators(self.TICKER, self.DATE)
|
||||||
|
if "bollinger" not in result:
|
||||||
|
pytest.skip("Bollinger not available")
|
||||||
|
b = result["bollinger"]
|
||||||
|
assert b["lower"] < b["upper"], "Lower band should be below upper"
|
||||||
|
|
||||||
|
def test_sma_values_positive(self):
|
||||||
|
result = fetch_indicators(self.TICKER, self.DATE)
|
||||||
|
if "sma_crossover" not in result:
|
||||||
|
pytest.skip("SMA not available")
|
||||||
|
assert result["sma_crossover"]["sma50"] > 0
|
||||||
|
assert result["sma_crossover"]["sma200"] > 0
|
||||||
|
|
||||||
|
def test_empty_on_bad_ticker(self):
|
||||||
|
result = fetch_indicators("ZZZZZZNOTREAL", self.DATE)
|
||||||
|
assert result == {}
|
||||||
|
|
@ -7,6 +7,23 @@ from tradingagents.agents.utils.agent_utils import (
|
||||||
)
|
)
|
||||||
from tradingagents.dataflows.config import get_config
|
from tradingagents.dataflows.config import get_config
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tradingagents.indicators import get_indicator_signals
|
||||||
|
except ImportError:
|
||||||
|
get_indicator_signals = None
|
||||||
|
|
||||||
|
|
||||||
|
def _format_preamble(signals: dict) -> str:
|
||||||
|
"""Format pre-computed interpretations as a prompt preamble."""
|
||||||
|
if not signals:
|
||||||
|
return ""
|
||||||
|
parts = ["\n## Pre-Computed Indicator Interpretations\n"]
|
||||||
|
for name, s in signals.items():
|
||||||
|
parts.append(f"- **{name}**: {s['explanation']}")
|
||||||
|
parts.append("\nUse these alongside the indicator catalog below. "
|
||||||
|
"Do NOT recalculate these values.\n")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
def create_market_analyst(llm):
|
def create_market_analyst(llm):
|
||||||
|
|
||||||
|
|
@ -14,13 +31,25 @@ def create_market_analyst(llm):
|
||||||
current_date = state["trade_date"]
|
current_date = state["trade_date"]
|
||||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||||
|
|
||||||
|
# Pre-compute indicator interpretations when available
|
||||||
|
indicator_preamble = ""
|
||||||
|
if get_indicator_signals is not None:
|
||||||
|
try:
|
||||||
|
signals = get_indicator_signals(
|
||||||
|
state["company_of_interest"], current_date
|
||||||
|
)
|
||||||
|
indicator_preamble = _format_preamble(signals)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
tools = [
|
tools = [
|
||||||
get_stock_data,
|
get_stock_data,
|
||||||
get_indicators,
|
get_indicators,
|
||||||
]
|
]
|
||||||
|
|
||||||
system_message = (
|
system_message = (
|
||||||
"""You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
|
indicator_preamble
|
||||||
|
+ """You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
|
||||||
|
|
||||||
Moving Averages:
|
Moving Averages:
|
||||||
- close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.
|
- close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
"""Technical indicators: fetch via stockstats, interpret, return structured signals."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from tradingagents.indicators.compute import fetch_indicators
|
||||||
|
from tradingagents.indicators.interpret import interpret_indicators
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_indicator_signals(symbol: str, curr_date: str) -> dict[str, dict[str, Any]]:
|
||||||
|
"""Return structured indicator signals for *symbol* as of *curr_date*.
|
||||||
|
|
||||||
|
Uses the same stockstats pipeline as the get_indicators LLM tool.
|
||||||
|
Returns an empty dict on any failure. Never raises.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
raw = fetch_indicators(symbol, curr_date)
|
||||||
|
if not raw:
|
||||||
|
return {}
|
||||||
|
return interpret_indicators(raw)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("indicators: failed for %s@%s", symbol, curr_date, exc_info=True)
|
||||||
|
return {}
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
"""Fetch indicator values using the existing stockstats pipeline.
|
||||||
|
|
||||||
|
Loads OHLCV data once and computes all indicators on a single
|
||||||
|
StockDataFrame instance — no redundant disk I/O.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# stockstats indicator names to fetch, grouped by interpretation key.
|
||||||
|
_STATS = [
|
||||||
|
"rsi_14", "macd", "macds", "macdh",
|
||||||
|
"boll_ub", "boll_lb", "boll",
|
||||||
|
"close_50_sma", "close_200_sma",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_indicators(symbol: str, curr_date: str) -> dict[str, dict[str, Any]]:
|
||||||
|
"""Fetch indicator values via stockstats for *symbol* as of *curr_date*.
|
||||||
|
|
||||||
|
Loads data once and computes all indicators in a single pass.
|
||||||
|
Returns a dict keyed by indicator group name. Missing indicators omitted.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from tradingagents.dataflows.stockstats_utils import load_ohlcv
|
||||||
|
from stockstats import wrap
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("indicators: stockstats or dataflows not available")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
df = load_ohlcv(symbol, curr_date)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("indicators: load_ohlcv failed for %s: %s", symbol, exc)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if df.empty or len(df) < 2:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
ss = wrap(df)
|
||||||
|
# Trigger computation of all indicators in one pass
|
||||||
|
for stat in _STATS:
|
||||||
|
_ = ss[stat]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("indicators: stockstats computation failed for %s: %s", symbol, exc)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Extract last row values
|
||||||
|
import pandas as pd
|
||||||
|
curr_date_str = pd.to_datetime(curr_date).strftime("%Y-%m-%d")
|
||||||
|
ss["date_str"] = ss["date"].dt.strftime("%Y-%m-%d") if "date" in ss.columns else ""
|
||||||
|
row = ss[ss["date_str"] == curr_date_str]
|
||||||
|
if row.empty:
|
||||||
|
row = ss.iloc[[-1]] # fallback to last available row
|
||||||
|
|
||||||
|
def _val(col: str) -> float | None:
|
||||||
|
try:
|
||||||
|
v = float(row[col].iloc[0])
|
||||||
|
return None if pd.isna(v) else v
|
||||||
|
except (KeyError, IndexError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
results: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
# RSI
|
||||||
|
rsi = _val("rsi_14")
|
||||||
|
if rsi is not None:
|
||||||
|
results["rsi"] = {"value": rsi, "period": 14}
|
||||||
|
|
||||||
|
# MACD
|
||||||
|
macd_val = _val("macd")
|
||||||
|
if macd_val is not None:
|
||||||
|
results["macd"] = {
|
||||||
|
"value": macd_val,
|
||||||
|
"signal": _val("macds"),
|
||||||
|
"histogram": _val("macdh"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Bollinger
|
||||||
|
upper = _val("boll_ub")
|
||||||
|
lower = _val("boll_lb")
|
||||||
|
if upper is not None and lower is not None:
|
||||||
|
results["bollinger"] = {
|
||||||
|
"value": _val("boll"),
|
||||||
|
"upper": upper,
|
||||||
|
"lower": lower,
|
||||||
|
}
|
||||||
|
|
||||||
|
# SMA crossover
|
||||||
|
sma50 = _val("close_50_sma")
|
||||||
|
sma200 = _val("close_200_sma")
|
||||||
|
if sma50 is not None and sma200 is not None:
|
||||||
|
# Detect crossover from recent data
|
||||||
|
crossover = None
|
||||||
|
try:
|
||||||
|
diff = ss["close_50_sma"] - ss["close_200_sma"]
|
||||||
|
recent = diff.iloc[-5:]
|
||||||
|
signs = recent.dropna().apply(lambda x: 1 if x > 0 else -1)
|
||||||
|
if len(signs) >= 2 and signs.iloc[-1] != signs.iloc[0]:
|
||||||
|
crossover = "golden_cross" if signs.iloc[-1] > 0 else "death_cross"
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
results["sma_crossover"] = {
|
||||||
|
"sma50": sma50,
|
||||||
|
"sma200": sma200,
|
||||||
|
"crossover": crossover,
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
@ -0,0 +1,162 @@
|
||||||
|
"""Rule-based interpretation of computed technical indicators.
|
||||||
|
|
||||||
|
Converts raw numeric indicator values into structured signals
|
||||||
|
(bullish / bearish / neutral) with confidence and human-readable
|
||||||
|
explanations. No LLM involvement — pure deterministic rules.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def interpret_indicators(
|
||||||
|
computed: dict[str, dict[str, Any]],
|
||||||
|
) -> dict[str, dict[str, Any]]:
|
||||||
|
"""Interpret every indicator in *computed* and return structured signals.
|
||||||
|
|
||||||
|
Each result dict contains: ``value``, ``signal``, ``confidence``, ``explanation``.
|
||||||
|
Unknown indicators are silently skipped.
|
||||||
|
"""
|
||||||
|
_INTERPRETERS: dict[str, Any] = {
|
||||||
|
"rsi": _interpret_rsi,
|
||||||
|
"macd": _interpret_macd,
|
||||||
|
"bollinger": _interpret_bollinger,
|
||||||
|
"sma_crossover": _interpret_sma_crossover,
|
||||||
|
}
|
||||||
|
|
||||||
|
results: dict[str, dict[str, Any]] = {}
|
||||||
|
for name, data in computed.items():
|
||||||
|
fn = _INTERPRETERS.get(name)
|
||||||
|
if fn is not None:
|
||||||
|
results[name] = fn(data)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-indicator interpreters
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _interpret_rsi(data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
val = data.get("value")
|
||||||
|
if val is None:
|
||||||
|
return _neutral(val, "Insufficient data for RSI")
|
||||||
|
|
||||||
|
if val >= 80:
|
||||||
|
return _signal(val, "bearish", 0.9, f"RSI {val:.1f} — strongly overbought")
|
||||||
|
if val >= 70:
|
||||||
|
return _signal(val, "bearish", 0.7, f"RSI {val:.1f} — overbought")
|
||||||
|
if val <= 20:
|
||||||
|
return _signal(val, "bullish", 0.9, f"RSI {val:.1f} — strongly oversold")
|
||||||
|
if val <= 30:
|
||||||
|
return _signal(val, "bullish", 0.7, f"RSI {val:.1f} — oversold")
|
||||||
|
return _neutral(val, f"RSI {val:.1f} — neutral range")
|
||||||
|
|
||||||
|
|
||||||
|
def _interpret_macd(data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
hist = data.get("histogram")
|
||||||
|
macd_val = data.get("value")
|
||||||
|
sig_val = data.get("signal")
|
||||||
|
|
||||||
|
if hist is None or macd_val is None:
|
||||||
|
return _neutral(macd_val, "Insufficient data for MACD")
|
||||||
|
|
||||||
|
direction = "bullish" if hist > 0 else "bearish"
|
||||||
|
# Confidence scales with histogram magnitude relative to signal line
|
||||||
|
ref = abs(sig_val) if sig_val else 1.0
|
||||||
|
strength = min(abs(hist) / max(ref, 0.01), 1.0)
|
||||||
|
confidence = round(0.5 + 0.4 * strength, 2)
|
||||||
|
|
||||||
|
crossing = ""
|
||||||
|
if abs(hist) < 0.05 * max(ref, 0.01):
|
||||||
|
crossing = " (near crossover)"
|
||||||
|
confidence = 0.5
|
||||||
|
|
||||||
|
return _signal(
|
||||||
|
macd_val,
|
||||||
|
direction,
|
||||||
|
confidence,
|
||||||
|
f"MACD histogram {hist:+.4f}{crossing} — {direction}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _interpret_bollinger(data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
price = data.get("value")
|
||||||
|
upper = data.get("upper")
|
||||||
|
lower = data.get("lower")
|
||||||
|
|
||||||
|
if price is None or upper is None or lower is None:
|
||||||
|
return _neutral(price, "Insufficient data for Bollinger Bands")
|
||||||
|
|
||||||
|
band_width = upper - lower
|
||||||
|
if band_width <= 0:
|
||||||
|
return _neutral(price, "Bollinger band width is zero")
|
||||||
|
|
||||||
|
position = (price - lower) / band_width # 0 = at lower, 1 = at upper
|
||||||
|
|
||||||
|
if position >= 1.0:
|
||||||
|
return _signal(price, "bearish", 0.8, f"Price at/above upper Bollinger Band — overbought")
|
||||||
|
if position >= 0.8:
|
||||||
|
return _signal(price, "bearish", 0.6, f"Price near upper Bollinger Band ({position:.0%})")
|
||||||
|
if position <= 0.0:
|
||||||
|
return _signal(price, "bullish", 0.8, f"Price at/below lower Bollinger Band — oversold")
|
||||||
|
if position <= 0.2:
|
||||||
|
return _signal(price, "bullish", 0.6, f"Price near lower Bollinger Band ({position:.0%})")
|
||||||
|
return _neutral(price, f"Price within Bollinger Bands ({position:.0%})")
|
||||||
|
|
||||||
|
|
||||||
|
def _interpret_sma_crossover(data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
sma50 = data.get("sma50")
|
||||||
|
sma200 = data.get("sma200")
|
||||||
|
crossover = data.get("crossover")
|
||||||
|
|
||||||
|
if sma50 is None or sma200 is None:
|
||||||
|
return _neutral(sma50, "Insufficient data for SMA crossover")
|
||||||
|
|
||||||
|
if crossover == "golden_cross":
|
||||||
|
return _signal(sma50, "bullish", 0.85, "Golden cross — SMA50 crossed above SMA200")
|
||||||
|
if crossover == "death_cross":
|
||||||
|
return _signal(sma50, "bearish", 0.85, "Death cross — SMA50 crossed below SMA200")
|
||||||
|
|
||||||
|
if sma50 > sma200:
|
||||||
|
return _signal(sma50, "bullish", 0.6, f"SMA50 ({sma50:.2f}) above SMA200 ({sma200:.2f}) — bullish trend")
|
||||||
|
return _signal(sma50, "bearish", 0.6, f"SMA50 ({sma50:.2f}) below SMA200 ({sma200:.2f}) — bearish trend")
|
||||||
|
|
||||||
|
|
||||||
|
def _interpret_support_resistance(data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
price = data.get("last_close")
|
||||||
|
resistance = data.get("resistance")
|
||||||
|
support = data.get("support")
|
||||||
|
|
||||||
|
if price is None or resistance is None or support is None:
|
||||||
|
return _neutral(price, "Insufficient data for support/resistance")
|
||||||
|
|
||||||
|
rng = resistance - support
|
||||||
|
if rng <= 0:
|
||||||
|
return _neutral(price, "Support equals resistance")
|
||||||
|
|
||||||
|
position = (price - support) / rng
|
||||||
|
|
||||||
|
if position >= 0.9:
|
||||||
|
return _signal(price, "bearish", 0.65, f"Price near resistance ({resistance:.2f})")
|
||||||
|
if position <= 0.1:
|
||||||
|
return _signal(price, "bullish", 0.65, f"Price near support ({support:.2f})")
|
||||||
|
return _neutral(price, f"Price between support ({support:.2f}) and resistance ({resistance:.2f})")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _signal(
|
||||||
|
value: Any, signal: str, confidence: float, explanation: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return {"value": value, "signal": signal, "confidence": confidence, "explanation": explanation}
|
||||||
|
|
||||||
|
|
||||||
|
def _neutral(value: Any, explanation: str) -> dict[str, Any]:
|
||||||
|
return _signal(value, "neutral", 0.5, explanation)
|
||||||
Loading…
Reference in New Issue