This commit is contained in:
Peter Wong 2026-04-03 01:05:25 +01:00
parent 4641c03340
commit 6310eedd20
10 changed files with 367 additions and 7 deletions

View File

@ -55,6 +55,7 @@ class MessageBuffer:
"social": "Social Analyst",
"news": "News Analyst",
"fundamentals": "Fundamentals Analyst",
"quant": "Quant Analyst",
}
# Report section mapping: section -> (analyst_key for filtering, finalizing_agent)
@ -65,6 +66,7 @@ class MessageBuffer:
"sentiment_report": ("social", "Social Analyst"),
"news_report": ("news", "News Analyst"),
"fundamentals_report": ("fundamentals", "Fundamentals Analyst"),
"quant_report": ("quant", "Quant Analyst"),
"investment_plan": (None, "Research Manager"),
"trader_investment_plan": (None, "Trader"),
"final_trade_decision": (None, "Portfolio Manager"),
@ -173,6 +175,7 @@ class MessageBuffer:
"sentiment_report": "Social Sentiment",
"news_report": "News Analysis",
"fundamentals_report": "Fundamentals Analysis",
"quant_report": "Quantitative Analysis",
"investment_plan": "Research Team Decision",
"trader_investment_plan": "Trading Team Plan",
"final_trade_decision": "Portfolio Management Decision",
@ -188,7 +191,7 @@ class MessageBuffer:
report_parts = []
# Analyst Team Reports - use .get() to handle missing sections
analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report"]
analyst_sections = ["market_report", "sentiment_report", "news_report", "fundamentals_report", "quant_report"]
if any(self.report_sections.get(section) for section in analyst_sections):
report_parts.append("## Analyst Team Reports")
if self.report_sections.get("market_report"):
@ -207,6 +210,10 @@ class MessageBuffer:
report_parts.append(
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
)
if self.report_sections.get("quant_report"):
report_parts.append(
f"### Quantitative Analysis\n{self.report_sections['quant_report']}"
)
# Research Team Reports
if self.report_sections.get("investment_plan"):
@ -286,6 +293,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non
"Social Analyst",
"News Analyst",
"Fundamentals Analyst",
"Quant Analyst",
],
"Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
"Trading Team": ["Trader"],
@ -794,18 +802,20 @@ def update_research_team_status(status):
# Ordered list of analysts for status transitions
ANALYST_ORDER = ["market", "social", "news", "fundamentals"]
ANALYST_ORDER = ["market", "social", "news", "fundamentals", "quant"]
ANALYST_AGENT_NAMES = {
"market": "Market Analyst",
"social": "Social Analyst",
"news": "News Analyst",
"fundamentals": "Fundamentals Analyst",
"quant": "Quant Analyst",
}
ANALYST_REPORT_MAP = {
"market": "market_report",
"social": "sentiment_report",
"news": "news_report",
"fundamentals": "fundamentals_report",
"quant": "quant_report",
}

View File

@ -8,3 +8,4 @@ class AnalystType(str, Enum):
SOCIAL = "social"
NEWS = "news"
FUNDAMENTALS = "fundamentals"
QUANT = "quant"

View File

@ -15,6 +15,7 @@ ANALYST_ORDER = [
("Social Media Analyst", AnalystType.SOCIAL),
("News Analyst", AnalystType.NEWS),
("Fundamentals Analyst", AnalystType.FUNDAMENTALS),
("Quant Analyst", AnalystType.QUANT),
]

View File

@ -6,6 +6,7 @@ from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst
from .analysts.news_analyst import create_news_analyst
from .analysts.social_media_analyst import create_social_media_analyst
from .analysts.quant_analyst import create_quant_analyst
from .researchers.bear_researcher import create_bear_researcher
from .researchers.bull_researcher import create_bull_researcher
@ -36,5 +37,6 @@ __all__ = [
"create_portfolio_manager",
"create_conservative_debator",
"create_social_media_analyst",
"create_quant_analyst",
"create_trader",
]

View File

@ -0,0 +1,59 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.quant_tools import get_quant_analysis
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction
def create_quant_analyst(llm):
    """Create a LangGraph node for the quantitative analyst agent.

    Args:
        llm: A LangChain chat model that supports ``bind_tools``.

    Returns:
        A node callable ``quant_analyst_node(state) -> dict`` that appends the
        model's message to ``messages`` and, once the tool-call loop has
        finished, publishes the final text as ``quant_report``.
    """

    def quant_analyst_node(state):
        current_date = state["trade_date"]
        instrument_context = build_instrument_context(state["company_of_interest"])
        tools = [get_quant_analysis]

        # Role-specific instructions appended after the shared collaboration
        # preamble; get_language_instruction() adds the output-language rule.
        system_message = (
            "You are a quantitative analyst. Use the get_quant_analysis tool to retrieve "
            "statistical metrics for the stock. Interpret the results thoroughly: "
            "assess risk (annualised volatility, semideviation, tail risk via skewness and kurtosis), "
            "evaluate the return distribution (Jarque-Bera normality test), "
            "analyse the market relationship (beta, alpha, R², rolling correlation with SPY), "
            "and identify any structural concerns (fat tails, high downside deviation). "
            "Provide a structured Markdown report with a clear summary table and actionable insights."
            + get_language_instruction()
        )

        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "You are a helpful AI assistant, collaborating with other assistants."
                    " Use the provided tools to progress towards answering the question."
                    " If you are unable to fully answer, that's OK; another assistant with different tools"
                    " will help where you left off. Execute what you can to make progress."
                    " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
                    " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
                    " You have access to the following tools: {tool_names}.\n{system_message}"
                    " For your reference, the current date is {current_date}. {instrument_context}",
                ),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )
        prompt = prompt.partial(system_message=system_message)
        prompt = prompt.partial(tool_names=", ".join(tool.name for tool in tools))
        prompt = prompt.partial(current_date=current_date)
        prompt = prompt.partial(instrument_context=instrument_context)

        chain = prompt | llm.bind_tools(tools)
        result = chain.invoke(state["messages"])

        # Only publish the report when the model stops requesting tool calls;
        # intermediate tool-call turns leave the report empty.
        report = ""
        if not result.tool_calls:
            report = result.content

        return {
            "messages": [result],
            "quant_report": report,
        }

    return quant_analyst_node

View File

@ -1,6 +1,6 @@
from typing import Annotated, Sequence
from datetime import date, timedelta, datetime
from typing_extensions import TypedDict, Optional
from typing_extensions import TypedDict, Optional, NotRequired
from langchain_openai import ChatOpenAI
from tradingagents.agents import *
from langgraph.prebuilt import ToolNode
@ -60,6 +60,7 @@ class AgentState(MessagesState):
str, "Report from the News Researcher of current world affairs"
]
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
quant_report: NotRequired[Annotated[str, "Report from the Quant Analyst"]]
# researcher team discussion step
investment_debate_state: Annotated[

View File

@ -0,0 +1,267 @@
from langchain_core.tools import tool
from typing import Annotated
from datetime import datetime, timedelta
import logging
import numpy as np
import pandas as pd
import yfinance as yf
from scipy import stats
from statsmodels.stats.stattools import jarque_bera
from statsmodels.regression.linear_model import OLS
from statsmodels.tools import add_constant
logger = logging.getLogger(__name__)
def _fetch_close_prices(ticker: str, end_date: str, days: int = 400) -> pd.Series:
    """Download adjusted daily close prices for *ticker*.

    The window covers `days` calendar days ending at *end_date*
    (yfinance treats the ``end`` date as exclusive). Returns an empty
    float Series when no data comes back.
    """
    window_end = datetime.strptime(end_date, "%Y-%m-%d")
    window_start = window_end - timedelta(days=days)
    frame = yf.download(
        ticker,
        start=window_start.strftime("%Y-%m-%d"),
        end=window_end.strftime("%Y-%m-%d"),
        progress=False,
        auto_adjust=True,
    )
    if frame.empty:
        return pd.Series(dtype=float)
    prices = frame["Close"]
    # Multi-index downloads can yield a one-column DataFrame; keep its first column.
    if isinstance(prices, pd.DataFrame):
        prices = prices.iloc[:, 0]
    return prices.dropna()
def _breusch_pagan(residuals: np.ndarray, X: np.ndarray):
    """
    Manual Breusch-Pagan test for heteroskedasticity.
    H0: residuals have constant variance.
    Returns (LM statistic, p-value).
    """
    # Auxiliary regression of squared residuals on the original regressors;
    # under H0, LM = n * R² of that fit is chi²-distributed with k-1 dof
    # (k-1 = regressor count excluding the constant column).
    squared_resid = residuals ** 2
    aux_fit = OLS(squared_resid, X).fit()
    lm_stat = len(residuals) * aux_fit.rsquared
    dof = X.shape[1] - 1
    p_value = float(stats.chi2.sf(lm_stat, df=dof))
    return float(lm_stat), p_value
def _ljung_box(series: np.ndarray, lags: int = 10):
"""
Manual Ljung-Box test for autocorrelation in residuals.
H0: no autocorrelation up to `lags`.
Returns (Q statistic, p-value) for the given lag count.
"""
n = len(series)
acf_vals = [np.corrcoef(series[k:], series[:-k])[0, 1] for k in range(1, lags + 1)]
q = n * (n + 2) * sum(r ** 2 / (n - k) for k, r in enumerate(acf_vals, start=1))
p = float(stats.chi2.sf(q, df=lags))
return float(q), p
@tool
def get_quant_analysis(
    ticker: Annotated[str, "Ticker symbol of the stock to analyse, e.g. AAPL"],
    analysis_date: Annotated[str, "End date for the analysis window in yyyy-mm-dd format"],
) -> str:
    """
    Compute a comprehensive quantitative finance analysis for a stock including:
    dispersion/risk metrics, Sharpe ratio, statistical moments, correlation with SPY,
    OLS market beta regression, and regression diagnostics (Breusch-Pagan
    heteroskedasticity, Ljung-Box autocorrelation, Newey-West robust errors,
    structural break detection). Returns a formatted Markdown report.
    """
    close = _fetch_close_prices(ticker, analysis_date)
    if len(close) < 30:
        return f"Insufficient price data for {ticker} (only {len(close)} trading days available before {analysis_date})."

    log_returns = np.log(close / close.shift(1)).dropna().values
    n = len(log_returns)
    mean_ret = float(np.mean(log_returns))

    # ── 1. Dispersion & Risk ─────────────────────────────────────────────────
    ann_vol = float(np.std(log_returns, ddof=1) * np.sqrt(252))
    # Semideviation: divide by total N (Sortino convention)
    downside = log_returns[log_returns < mean_ret]
    semivar = float(np.sum((downside - mean_ret) ** 2) / n) if len(downside) > 0 else float("nan")
    ann_semidev = float(np.sqrt(semivar) * np.sqrt(252)) if not np.isnan(semivar) else float("nan")
    # Target semideviation (target = 0%, i.e. capital preservation)
    target = 0.0
    below_target = log_returns[log_returns < target]
    target_semivar = float(np.sum((below_target - target) ** 2) / n) if len(below_target) > 0 else float("nan")
    ann_target_semidev = float(np.sqrt(target_semivar) * np.sqrt(252)) if not np.isnan(target_semivar) else float("nan")
    ann_mad = float(np.mean(np.abs(log_returns - mean_ret)) * np.sqrt(252))
    ret_range = float(np.ptp(log_returns))  # peak-to-peak

    # Sharpe Ratio (rf = 0, annualised)
    sharpe = float((mean_ret / np.std(log_returns, ddof=1)) * np.sqrt(252))
    # Rolling 30-day Sharpe (last value); guard against zero-std window producing inf
    ret_series = pd.Series(log_returns)
    roll_std = ret_series.rolling(30).std().replace(0, float("nan"))
    roll_sharpe = (ret_series.rolling(30).mean() / roll_std) * np.sqrt(252)
    last_roll_sharpe_raw = roll_sharpe.iloc[-1] if not roll_sharpe.empty else float("nan")
    last_roll_sharpe = float(last_roll_sharpe_raw) if np.isfinite(last_roll_sharpe_raw) else float("nan")

    # ── 2. Statistical Moments ───────────────────────────────────────────────
    skewness = float(stats.skew(log_returns))
    excess_kurtosis = float(stats.kurtosis(log_returns))
    # jarque_bera also returns the sample skew/kurtosis; only the p-value is reported.
    _jb_stat, jb_pvalue, _, _ = jarque_bera(log_returns)
    jb_pvalue = float(jb_pvalue)
    normal_dist = "YES" if jb_pvalue > 0.05 else "NO"

    # ── 3. Market Relationship (SPY) ─────────────────────────────────────────
    mkt_section = ""
    diag_section = ""
    break_section = ""
    # Fetch SPY — network/data errors are expected and handled gracefully
    try:
        spy_close = _fetch_close_prices("SPY", analysis_date)
    except Exception:
        logger.exception("Failed to fetch SPY data for %s", ticker)
        spy_close = pd.Series(dtype=float)
    common_idx = close.index.intersection(spy_close.index)
    if len(common_idx) >= 30:
        stock_ret = np.log(close.loc[common_idx] / close.loc[common_idx].shift(1)).dropna()
        spy_ret = np.log(spy_close.loc[common_idx] / spy_close.loc[common_idx].shift(1)).dropna()
        cidx = stock_ret.index.intersection(spy_ret.index)
        sr = stock_ret.loc[cidx].values
        mr = spy_ret.loc[cidx].values
        # Correlation
        pearson_r, pearson_p = stats.pearsonr(sr, mr)
        s_s = pd.Series(sr, index=cidx)
        m_s = pd.Series(mr, index=cidx)
        last_roll_corr = float(s_s.rolling(30).corr(m_s).iloc[-1])
        # OLS beta (standard)
        X = add_constant(mr)
        ols = OLS(sr, X).fit()
        alpha = float(ols.params[0])
        beta = float(ols.params[1])
        r2 = float(ols.rsquared)
        adj_r2 = float(ols.rsquared_adj)
        beta_p = float(ols.pvalues[1])
        ci = ols.conf_int()  # ndarray shape (n_params, 2): rows=params, cols=[lower, upper]
        beta_ci_lo = float(ci[1][0])  # row 1 = beta, col 0 = lower bound
        beta_ci_hi = float(ci[1][1])  # row 1 = beta, col 1 = upper bound
        # Newey-West HAC robust beta
        hac = ols.get_robustcov_results(cov_type="HAC", maxlags=5)
        beta_hac_se = float(hac.bse[1])
        beta_hac_p = float(hac.pvalues[1])
        mkt_section = (
            f"\n| Pearson r vs SPY | {pearson_r:.4f} |"
            f"\n| Pearson p-value | {pearson_p:.4f} |"
            f"\n| 30-day Rolling Corr (last) | {last_roll_corr:.4f} |"
            f"\n| Market Beta (β) | {beta:.4f} |"
            f"\n| Beta 95% CI | [{beta_ci_lo:.4f}, {beta_ci_hi:.4f}] |"
            f"\n| Alpha (α, daily) | {alpha:.6f} |"
            f"\n| R² | {r2:.4f} |"
            f"\n| Adjusted R² | {adj_r2:.4f} |"
            f"\n| β p-value (OLS) | {beta_p:.4f} |"
            f"\n| β Newey-West SE | {beta_hac_se:.4f} |"
            f"\n| β p-value (HAC) | {beta_hac_p:.4f} |"
        )

        # ── 4. Regression Diagnostics ────────────────────────────────────
        residuals = ols.resid
        # Breusch-Pagan heteroskedasticity test
        bp_lm, bp_p = _breusch_pagan(residuals, X)
        bp_result = "Heteroskedastic" if bp_p < 0.05 else "Homoskedastic"
        # Ljung-Box autocorrelation test (10 lags)
        lb_q, lb_p = _ljung_box(residuals, lags=10)
        lb_result = "Autocorrelated" if lb_p < 0.05 else "No autocorrelation"
        diag_section = (
            f"\n\n### Regression Diagnostics"
            f"\n\n| Test | Statistic | p-value | Result |"
            f"\n|------|-----------|---------|--------|"
            f"\n| Breusch-Pagan (heteroskedasticity) | {bp_lm:.4f} | {bp_p:.4f} | {bp_result} |"
            f"\n| Ljung-Box Q(10) (autocorrelation) | {lb_q:.4f} | {lb_p:.4f} | {lb_result} |"
        )

        # ── 5. Structural Break ──────────────────────────────────────────
        if len(sr) >= 60:
            mid = len(sr) // 2
            sr1, mr1 = sr[:mid], mr[:mid]
            sr2, mr2 = sr[mid:], mr[mid:]
            res1 = OLS(sr1, add_constant(mr1)).fit()
            res2 = OLS(sr2, add_constant(mr2)).fit()
            beta1, beta2 = float(res1.params[1]), float(res2.params[1])
            alpha1, alpha2 = float(res1.params[0]), float(res2.params[0])
            r2_1, r2_2 = float(res1.rsquared), float(res2.rsquared)
            beta_shift = beta2 - beta1
            stability = "STABLE" if abs(beta_shift) < 0.3 else "UNSTABLE"
            break_section = (
                f"\n\n### Structural Break Analysis (first half vs second half)"
                f"\n\n| Period | Beta | Alpha (daily) | R² |"
                f"\n|--------|------|---------------|----|"
                f"\n| First half ({mid} days) | {beta1:.4f} | {alpha1:.6f} | {r2_1:.4f} |"
                f"\n| Second half ({len(sr)-mid} days) | {beta2:.4f} | {alpha2:.6f} | {r2_2:.4f} |"
                f"\n| β shift | {beta_shift:+.4f} | — | — |"
                f"\n\n**Beta stability**: {stability} (|Δβ| {'< 0.30' if stability == 'STABLE' else '≥ 0.30'})"
            )
    else:
        mkt_section = "\n| SPY Market Data | Insufficient overlapping data |"

    # ── Build report ──────────────────────────────────────────────────────────
    report = f"""## Quantitative Analysis: {ticker} (as of {analysis_date})

**Data window**: {n} trading days of log returns

### 1. Dispersion & Risk

| Metric | Value |
|--------|-------|
| Annualised Volatility | {ann_vol:.4f} ({ann_vol*100:.2f}%) |
| Annualised Semideviation (vs mean) | {ann_semidev:.4f} ({ann_semidev*100:.2f}%) |
| Annualised Semideviation (vs 0%) | {ann_target_semidev:.4f} ({ann_target_semidev*100:.2f}%) |
| Annualised MAD | {ann_mad:.4f} ({ann_mad*100:.2f}%) |
| Return Range (peak-to-peak) | {ret_range:.4f} ({ret_range*100:.2f}%) |
| Sharpe Ratio (annualised, rf=0) | {sharpe:.4f} |
| Rolling 30-day Sharpe (last) | {last_roll_sharpe:.4f} |

### 2. Statistical Moments

| Metric | Value |
|--------|-------|
| Skewness | {skewness:.4f} |
| Excess Kurtosis | {excess_kurtosis:.4f} |
| Jarque-Bera p-value | {jb_pvalue:.4f} |
| Normally Distributed (JB test) | {normal_dist} |

### 3. Market Relationship (vs SPY)

| Metric | Value |
|--------|-------|{mkt_section}
{diag_section}
{break_section}

### Interpretation Guide

- **Semideviation vs mean**: downside risk relative to average return (Sortino denominator).
- **Semideviation vs 0%**: downside risk relative to capital preservation threshold.
- **Sharpe > 1**: good risk-adjusted return; **< 0**: negative excess return.
- **Skewness < 0**: left tail (larger losses than gains); **> 0**: right tail.
- **Kurtosis > 0**: fat tails — more extreme moves than a normal distribution.
- **JB p < 0.05**: returns NOT normally distributed.
- **β > 1**: amplifies market; **β < 1**: defensive.
- **Beta CI**: narrow CI = precisely estimated; wide CI = unreliable beta.
- **HAC p-value**: robust to het + autocorrelation — prefer over OLS p-value if BP/LB flagged.
- **Breusch-Pagan p < 0.05**: heteroskedastic residuals — OLS standard errors unreliable.
- **Ljung-Box p < 0.05**: autocorrelated residuals — model missing serial structure.
- **β shift > 0.30**: regime change detected — beta from the full period may be misleading.
"""
    return report

View File

@ -43,6 +43,14 @@ class ConditionalLogic:
return "tools_fundamentals"
return "Msg Clear Fundamentals"
def should_continue_quant(self, state: AgentState):
    """Route the quant analyst: keep looping while tool calls are pending."""
    latest = state["messages"][-1]
    return "tools_quant" if latest.tool_calls else "Msg Clear Quant"
def should_continue_debate(self, state: AgentState) -> str:
"""Determine if debate should continue."""

View File

@ -37,9 +37,7 @@ class GraphSetup:
self.portfolio_manager_memory = portfolio_manager_memory
self.conditional_logic = conditional_logic
def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"]
):
def setup_graph(self, selected_analysts=None):
"""Set up and compile the agent workflow graph.
Args:
@ -48,7 +46,10 @@ class GraphSetup:
- "social": Social media analyst
- "news": News analyst
- "fundamentals": Fundamentals analyst
- "quant": Quantitative analyst
"""
if selected_analysts is None:
selected_analysts = ["market", "social", "news", "fundamentals", "quant"]
if len(selected_analysts) == 0:
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
@ -85,6 +86,11 @@ class GraphSetup:
delete_nodes["fundamentals"] = create_msg_delete()
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
if "quant" in selected_analysts:
analyst_nodes["quant"] = create_quant_analyst(self.quick_thinking_llm)
delete_nodes["quant"] = create_msg_delete()
tool_nodes["quant"] = self.tool_nodes["quant"]
# Create researcher and manager nodes
bull_researcher_node = create_bull_researcher(
self.quick_thinking_llm, self.bull_memory

View File

@ -32,6 +32,7 @@ from tradingagents.agents.utils.agent_utils import (
get_insider_transactions,
get_global_news
)
from tradingagents.agents.utils.quant_tools import get_quant_analysis
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
@ -45,7 +46,7 @@ class TradingAgentsGraph:
def __init__(
self,
selected_analysts=["market", "social", "news", "fundamentals"],
selected_analysts=None,
debug=False,
config: Dict[str, Any] = None,
callbacks: Optional[List] = None,
@ -58,6 +59,8 @@ class TradingAgentsGraph:
config: Configuration dictionary. If None, uses default config
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
"""
if selected_analysts is None:
selected_analysts = ["market", "social", "news", "fundamentals", "quant"]
self.debug = debug
self.config = config or DEFAULT_CONFIG
self.callbacks = callbacks or []
@ -189,6 +192,7 @@ class TradingAgentsGraph:
get_income_statement,
]
),
"quant": ToolNode([get_quant_analysis]),
}
def propagate(self, company_name, trade_date):
@ -235,6 +239,7 @@ class TradingAgentsGraph:
"sentiment_report": final_state["sentiment_report"],
"news_report": final_state["news_report"],
"fundamentals_report": final_state["fundamentals_report"],
"quant_report": final_state.get("quant_report", ""),
"investment_debate_state": {
"bull_history": final_state["investment_debate_state"]["bull_history"],
"bear_history": final_state["investment_debate_state"]["bear_history"],