feat: add macro analyst

This commit is contained in:
Garrick 2026-03-24 16:10:51 -07:00
parent cc5a322135
commit 5be6ca954a
9 changed files with 657 additions and 6 deletions

182
tests/test_macro_analyst.py Normal file
View File

@ -0,0 +1,182 @@
from tradingagents.graph.setup import GraphSetup
from tradingagents.graph.trading_graph import TradingAgentsGraph
class DummyStateGraph:
def __init__(self, _state_type):
self.nodes = {}
self.conditional_edges = {}
def add_node(self, name, node):
self.nodes[name] = node
def add_edge(self, *_args, **_kwargs):
return None
def add_conditional_edges(self, source, condition, destinations):
self.conditional_edges[source] = {
"condition": condition,
"destinations": destinations,
}
def compile(self):
return {
"nodes": self.nodes,
"conditional_edges": self.conditional_edges,
}
class DummyToolNode:
def __init__(self, tools):
self.tools = tools
def test_macro_tools_route_to_vendor(monkeypatch):
import tradingagents.dataflows.interface as interface
from tradingagents.agents.utils.macro_data_tools import (
get_economic_indicators,
get_fed_calendar,
get_yield_curve,
)
calls = []
def fake_route_to_vendor(method, *args, **kwargs):
calls.append((method, args, kwargs))
return f"{method}-result"
monkeypatch.setattr(interface, "route_to_vendor", fake_route_to_vendor)
assert (
get_economic_indicators.invoke(
{"curr_date": "2026-03-24", "lookback_days": 30}
)
== "get_economic_indicators-result"
)
assert (
get_yield_curve.invoke({"curr_date": "2026-03-24"})
== "get_yield_curve-result"
)
assert (
get_fed_calendar.invoke({"curr_date": "2026-03-24"})
== "get_fed_calendar-result"
)
assert calls == [
(
"get_economic_indicators",
(),
{"curr_date": "2026-03-24", "lookback_days": 30},
),
("get_yield_curve", (), {"curr_date": "2026-03-24"}),
("get_fed_calendar", (), {"curr_date": "2026-03-24"}),
]
def test_graph_setup_wires_macro_analyst_and_macro_tools(monkeypatch):
recorded_llms = {}
monkeypatch.setattr("tradingagents.graph.setup.StateGraph", DummyStateGraph)
monkeypatch.setattr("tradingagents.graph.setup.create_msg_delete", lambda: "delete")
def make_factory(node_name):
def factory(llm, *_args):
recorded_llms[node_name] = llm
return node_name
return factory
monkeypatch.setattr(
"tradingagents.graph.setup.create_market_analyst",
make_factory("Market Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_macro_analyst",
make_factory("Macro Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_social_media_analyst",
make_factory("Social Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_news_analyst",
make_factory("News Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_fundamentals_analyst",
make_factory("Fundamentals Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_bull_researcher",
make_factory("Bull Researcher"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_bear_researcher",
make_factory("Bear Researcher"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_research_manager",
make_factory("Research Manager"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_trader",
make_factory("Trader"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_aggressive_debator",
make_factory("Aggressive Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_neutral_debator",
make_factory("Neutral Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_conservative_debator",
make_factory("Conservative Analyst"),
)
monkeypatch.setattr(
"tradingagents.graph.setup.create_portfolio_manager",
make_factory("Portfolio Manager"),
)
class PartialConditionalLogic:
def should_continue_market(self, _state):
return "Msg Clear Market"
def should_continue_debate(self, _state):
return "Research Manager"
def should_continue_risk_analysis(self, _state):
return "Portfolio Manager"
setup = GraphSetup(
quick_thinking_llm="quick-llm",
deep_thinking_llm="deep-llm",
tool_nodes={"market": "market-tools", "macro": "macro-tools"},
bull_memory=object(),
bear_memory=object(),
trader_memory=object(),
invest_judge_memory=object(),
portfolio_manager_memory=object(),
conditional_logic=PartialConditionalLogic(),
role_llms={"macro": "macro-llm"},
)
graph = setup.setup_graph(selected_analysts=["market", "macro"])
assert recorded_llms["Macro Analyst"] == "macro-llm"
assert graph["nodes"]["Macro Analyst"] == "Macro Analyst"
assert graph["nodes"]["tools_macro"] == "macro-tools"
assert "Macro Analyst" in graph["conditional_edges"]
def test_trading_graph_creates_macro_tool_node(monkeypatch):
monkeypatch.setattr("tradingagents.graph.trading_graph.ToolNode", DummyToolNode)
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
tool_nodes = TradingAgentsGraph._create_tool_nodes(graph)
assert [tool.name for tool in tool_nodes["macro"].tools] == [
"get_economic_indicators",
"get_yield_curve",
"get_fed_calendar",
]

View File

@ -3,6 +3,7 @@ from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.memory import FinancialSituationMemory
from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.macro_analyst import create_macro_analyst
from .analysts.market_analyst import create_market_analyst
from .analysts.news_analyst import create_news_analyst
from .analysts.social_media_analyst import create_social_media_analyst
@ -29,6 +30,7 @@ __all__ = [
"create_bull_researcher",
"create_research_manager",
"create_fundamentals_analyst",
"create_macro_analyst",
"create_market_analyst",
"create_neutral_debator",
"create_news_analyst",

View File

@ -0,0 +1,77 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from tradingagents.agents.utils.agent_utils import (
build_instrument_context,
get_economic_indicators,
get_fed_calendar,
get_yield_curve,
)
def _merge_with_news_report(existing_report: str, macro_report: str) -> str:
if not macro_report:
return existing_report
if not existing_report:
return macro_report
return f"{existing_report.rstrip()}\n\n## Macro Economic Overlay\n\n{macro_report}"
def create_macro_analyst(llm):
def macro_analyst_node(state):
current_date = state["trade_date"]
instrument_context = build_instrument_context(state["company_of_interest"])
tools = [
get_economic_indicators,
get_yield_curve,
get_fed_calendar,
]
system_message = (
"You are a macroeconomic analyst responsible for turning Federal Reserve "
"data, inflation data, labor data, and the Treasury curve into a trading "
"usable macro view. Use `get_economic_indicators` to establish the growth, "
"inflation, and labor backdrop, `get_yield_curve` to explain the rates "
"curve and recession signal, and `get_fed_calendar` to summarize the policy "
"path. Focus on regime identification, likely policy direction, cross-asset "
"implications, and concrete risks that other analysts should incorporate."
" Append a Markdown table that summarizes the major macro signals and their "
"market implications."
)
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a helpful AI assistant, collaborating with other assistants."
" Use the provided tools to progress towards answering the question."
" If you are unable to fully answer, that's OK; another assistant with different tools"
" will help where you left off. Execute what you can to make progress."
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
" You have access to the following tools: {tool_names}.\n{system_message}"
"For your reference, the current date is {current_date}. {instrument_context}",
),
MessagesPlaceholder(variable_name="messages"),
]
)
prompt = prompt.partial(system_message=system_message)
prompt = prompt.partial(tool_names=", ".join(tool.name for tool in tools))
prompt = prompt.partial(current_date=current_date)
prompt = prompt.partial(instrument_context=instrument_context)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
report = ""
if len(result.tool_calls) == 0:
report = result.content
return {
"messages": [result],
"macro_report": report,
"news_report": _merge_with_news_report(state.get("news_report", ""), report),
}
return macro_analyst_node

View File

@ -18,6 +18,29 @@ from tradingagents.agents.utils.news_data_tools import (
get_insider_transactions,
get_global_news
)
from tradingagents.agents.utils.macro_data_tools import (
get_economic_indicators,
get_fed_calendar,
get_yield_curve,
)
__all__ = [
"build_instrument_context",
"create_msg_delete",
"get_balance_sheet",
"get_cashflow",
"get_economic_indicators",
"get_fed_calendar",
"get_fundamentals",
"get_global_news",
"get_income_statement",
"get_indicators",
"get_insider_transactions",
"get_news",
"get_stock_data",
"get_yield_curve",
]
def build_instrument_context(ticker: str) -> str:
@ -28,6 +51,7 @@ def build_instrument_context(ticker: str) -> str:
"preserving any exchange suffix (e.g. `.TO`, `.L`, `.HK`, `.T`)."
)
def create_msg_delete():
def delete_messages(state):
"""Clear messages and add placeholder for Anthropic compatibility"""
@ -42,6 +66,3 @@ def create_msg_delete():
return {"messages": removal_operations + [placeholder]}
return delete_messages

View File

@ -0,0 +1,38 @@
from typing import Annotated
from langchain_core.tools import tool
@tool
def get_economic_indicators(
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
lookback_days: Annotated[int, "how many days to look back for data"] = 90,
) -> str:
"""Retrieve a macro indicators report backed by the configured macro data vendor."""
from tradingagents.dataflows.interface import route_to_vendor
return route_to_vendor(
"get_economic_indicators",
curr_date=curr_date,
lookback_days=lookback_days,
)
@tool
def get_yield_curve(
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
) -> str:
"""Retrieve the US Treasury yield curve and spread analysis."""
from tradingagents.dataflows.interface import route_to_vendor
return route_to_vendor("get_yield_curve", curr_date=curr_date)
@tool
def get_fed_calendar(
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
) -> str:
"""Retrieve the recent Federal Reserve policy path summary."""
from tradingagents.dataflows.interface import route_to_vendor
return route_to_vendor("get_fed_calendar", curr_date=curr_date)

View File

@ -23,6 +23,11 @@ from .alpha_vantage import (
get_global_news as get_alpha_vantage_global_news,
)
from .alpha_vantage_common import AlphaVantageRateLimitError
from .macro_utils import (
get_economic_indicators_report,
get_fed_calendar_and_minutes,
get_treasury_yield_curve,
)
# Configuration and routing logic
from .config import get_config
@ -57,12 +62,21 @@ TOOLS_CATEGORIES = {
"get_global_news",
"get_insider_transactions",
]
},
"macro_data": {
"description": "Macroeconomic indicators and Federal Reserve data",
"tools": [
"get_economic_indicators",
"get_yield_curve",
"get_fed_calendar",
],
}
}
VENDOR_LIST = [
"yfinance",
"alpha_vantage",
"fred",
]
# Mapping of methods to their vendor-specific implementations
@ -107,6 +121,16 @@ VENDOR_METHODS = {
"alpha_vantage": get_alpha_vantage_insider_transactions,
"yfinance": get_yfinance_insider_transactions,
},
# macro_data
"get_economic_indicators": {
"fred": get_economic_indicators_report,
},
"get_yield_curve": {
"fred": get_treasury_yield_curve,
},
"get_fed_calendar": {
"fred": get_fed_calendar_and_minutes,
},
}
def get_category_for_method(method: str) -> str:
@ -159,4 +183,4 @@ def route_to_vendor(method: str, *args, **kwargs):
except AlphaVantageRateLimitError:
continue # Only rate limits trigger fallback
raise RuntimeError(f"No available vendor for '{method}'")
raise RuntimeError(f"No available vendor for '{method}'")

View File

@ -0,0 +1,269 @@
import os
from datetime import datetime, timedelta
import requests
FRED_OBSERVATIONS_URL = "https://api.stlouisfed.org/fred/series/observations"
def _get_fred_api_key() -> str | None:
return os.getenv("FRED_API_KEY")
def _get_fred_observations(
series_id: str,
start_date: str,
end_date: str,
*,
limit: int = 100,
):
api_key = _get_fred_api_key()
if not api_key:
return {
"error": (
"FRED API key not configured. Set the FRED_API_KEY environment "
"variable to enable macro data."
)
}
params = {
"series_id": series_id,
"api_key": api_key,
"file_type": "json",
"observation_start": start_date,
"observation_end": end_date,
"sort_order": "desc",
"limit": limit,
}
try:
response = requests.get(FRED_OBSERVATIONS_URL, params=params, timeout=30)
response.raise_for_status()
return response.json()
except requests.RequestException as exc:
return {"error": f"Failed to fetch FRED data for {series_id}: {exc}"}
except ValueError as exc:
return {"error": f"FRED returned invalid JSON for {series_id}: {exc}"}
def _valid_observations(payload):
observations = payload.get("observations", [])
return [obs for obs in observations if obs.get("value") not in (None, ".")]
def _window_start(curr_date: str, lookback_days: int) -> str:
return (
datetime.strptime(curr_date, "%Y-%m-%d") - timedelta(days=lookback_days)
).strftime("%Y-%m-%d")
def get_treasury_yield_curve(curr_date: str) -> str:
start_date = _window_start(curr_date, 30)
yield_series = [
("1 Month", "DGS1MO"),
("3 Month", "DGS3MO"),
("6 Month", "DGS6MO"),
("1 Year", "DGS1"),
("2 Year", "DGS2"),
("3 Year", "DGS3"),
("5 Year", "DGS5"),
("7 Year", "DGS7"),
("10 Year", "DGS10"),
("20 Year", "DGS20"),
("30 Year", "DGS30"),
]
rows = []
for maturity, series_id in yield_series:
payload = _get_fred_observations(series_id, start_date, curr_date)
if "error" in payload:
continue
observations = _valid_observations(payload)
if not observations:
continue
latest = observations[0]
rows.append((maturity, float(latest["value"]), latest["date"]))
if not rows:
return (
f"## Treasury Yield Curve as of {curr_date}\n\n"
"No Treasury yield data was available for the requested window."
)
lines = [
f"## Treasury Yield Curve as of {curr_date}",
"",
"| Maturity | Yield (%) | Observation Date |",
"| --- | ---: | --- |",
]
for maturity, rate, observation_date in rows:
lines.append(f"| {maturity} | {rate:.2f} | {observation_date} |")
two_year = next((rate for maturity, rate, _ in rows if maturity == "2 Year"), None)
ten_year = next((rate for maturity, rate, _ in rows if maturity == "10 Year"), None)
if two_year is not None and ten_year is not None:
spread = ten_year - two_year
lines.extend(
[
"",
"### Yield Curve Readout",
f"- 2Y-10Y spread: {spread:.2f} percentage points.",
]
)
if spread < 0:
lines.append("- Interpretation: the curve is inverted, a classic recession warning.")
elif spread < 0.5:
lines.append("- Interpretation: the curve is flat, pointing to tighter growth expectations.")
else:
lines.append("- Interpretation: the curve is upward sloping, consistent with normal growth expectations.")
return "\n".join(lines)
def get_economic_indicators_report(curr_date: str, lookback_days: int = 90) -> str:
start_date = _window_start(curr_date, lookback_days)
indicators = {
"Federal Funds Rate": {
"series": "FEDFUNDS",
"description": "Federal Reserve policy rate",
"unit": "%",
},
"Consumer Price Index": {
"series": "CPIAUCSL",
"description": "Headline consumer inflation index",
"unit": "index",
"year_over_year": True,
},
"Producer Price Index": {
"series": "PPIACO",
"description": "Producer-level inflation index",
"unit": "index",
"year_over_year": True,
},
"Unemployment Rate": {
"series": "UNRATE",
"description": "Share of the labor force that is unemployed",
"unit": "%",
},
"Nonfarm Payrolls": {
"series": "PAYEMS",
"description": "Total nonfarm payroll employment",
"unit": "thousands",
},
"GDP": {
"series": "GDP",
"description": "Gross domestic product, nominal level",
"unit": "billions",
},
"ISM Manufacturing PMI": {
"series": "NAPM",
"description": "Manufacturing activity diffusion index",
"unit": "index",
},
"Consumer Confidence": {
"series": "CSCICP03USM665S",
"description": "OECD consumer confidence measure for the US",
"unit": "index",
},
"VIX": {
"series": "VIXCLS",
"description": "CBOE market volatility index",
"unit": "index",
},
}
lines = [f"## Economic Indicators Report ({start_date} to {curr_date})", ""]
for name, metadata in indicators.items():
payload = _get_fred_observations(metadata["series"], start_date, curr_date)
lines.append(f"### {name}")
if "error" in payload:
lines.append(f"- Error: {payload['error']}")
lines.append("")
continue
observations = _valid_observations(payload)
if not observations:
lines.append("- No data available in the requested window.")
lines.append("")
continue
latest = observations[0]
latest_value = float(latest["value"])
lines.append(
f"- Latest value: {latest_value:.2f} {metadata['unit']} ({latest['date']})"
)
lines.append(f"- Description: {metadata['description']}")
if len(observations) >= 2:
previous = observations[1]
previous_value = float(previous["value"])
change = latest_value - previous_value
change_pct = 0.0 if previous_value == 0 else (change / previous_value) * 100
lines.append(
f"- Sequential change: {change:+.2f} {metadata['unit']} ({change_pct:+.2f}%)"
)
if metadata.get("year_over_year") and len(observations) >= 12:
year_ago = observations[11]
year_ago_value = float(year_ago["value"])
if year_ago_value != 0:
yoy_change = ((latest_value - year_ago_value) / year_ago_value) * 100
lines.append(f"- Year-over-year change: {yoy_change:+.2f}%")
lines.append("")
return "\n".join(lines).rstrip()
def get_fed_calendar_and_minutes(curr_date: str) -> str:
start_date = _window_start(curr_date, 365)
payload = _get_fred_observations("FEDFUNDS", start_date, curr_date)
lines = [
f"## Federal Reserve Policy Snapshot as of {curr_date}",
"",
"FRED does not provide the FOMC meeting calendar directly. This summary uses the recent policy-rate path as a proxy for the Fed backdrop.",
"",
]
if "error" in payload:
lines.append(f"- Error: {payload['error']}")
return "\n".join(lines)
observations = _valid_observations(payload)
if not observations:
lines.append("- No recent Federal Funds observations were available.")
return "\n".join(lines)
lines.extend(
[
"| Date | Fed Funds Rate (%) | Change vs Prior |",
"| --- | ---: | --- |",
]
)
recent_observations = observations[:6]
for index, observation in enumerate(recent_observations):
rate = float(observation["value"])
change_text = "-"
if index + 1 < len(observations):
prior_rate = float(observations[index + 1]["value"])
delta = rate - prior_rate
change_text = "unchanged" if delta == 0 else f"{delta:+.2f}"
lines.append(f"| {observation['date']} | {rate:.2f} | {change_text} |")
latest_rate = float(recent_observations[0]["value"])
lines.extend(
[
"",
"### Policy Readout",
f"- Latest effective Fed Funds rate in the series: {latest_rate:.2f}%.",
]
)
if latest_rate >= 4.0:
lines.append("- Interpretation: policy remains restrictive relative to the post-2008 norm.")
elif latest_rate <= 2.0:
lines.append("- Interpretation: policy is accommodative by recent historical standards.")
else:
lines.append("- Interpretation: policy is near a neutral zone by recent historical standards.")
return "\n".join(lines)

View File

@ -44,6 +44,7 @@ class GraphSetup:
self.fundamentals_analyst_llm = self._get_role_llm(
"fundamentals", self.quick_thinking_llm
)
self.macro_analyst_llm = self._get_role_llm("macro", self.quick_thinking_llm)
self.bull_researcher_llm = self._get_role_llm(
"bull_researcher", self.quick_thinking_llm
)
@ -70,6 +71,24 @@ class GraphSetup:
def _get_role_llm(self, role: str, fallback_llm: ChatOpenAI):
return self.role_llms.get(role, fallback_llm)
def _get_continue_handler(self, analyst_type: str):
specific_handler = getattr(
self.conditional_logic,
f"should_continue_{analyst_type}",
None,
)
if specific_handler is not None:
return specific_handler
def default_handler(state: AgentState):
messages = state["messages"]
last_message = messages[-1]
if getattr(last_message, "tool_calls", None):
return f"tools_{analyst_type}"
return f"Msg Clear {analyst_type.capitalize()}"
return default_handler
def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"]
):
@ -81,6 +100,7 @@ class GraphSetup:
- "social": Social media analyst
- "news": News analyst
- "fundamentals": Fundamentals analyst
- "macro": Macro analyst
"""
if len(selected_analysts) == 0:
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
@ -118,6 +138,11 @@ class GraphSetup:
delete_nodes["fundamentals"] = create_msg_delete()
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
if "macro" in selected_analysts:
analyst_nodes["macro"] = create_macro_analyst(self.macro_analyst_llm)
delete_nodes["macro"] = create_msg_delete()
tool_nodes["macro"] = self.tool_nodes["macro"]
# Create researcher and manager nodes
bull_researcher_node = create_bull_researcher(
self.bull_researcher_llm, self.bull_memory
@ -175,7 +200,7 @@ class GraphSetup:
# Add conditional edges for current analyst
workflow.add_conditional_edges(
current_analyst,
getattr(self.conditional_logic, f"should_continue_{analyst_type}"),
self._get_continue_handler(analyst_type),
[current_tools, current_clear],
)
workflow.add_edge(current_tools, current_analyst)

View File

@ -28,9 +28,12 @@ from tradingagents.agents.utils.agent_utils import (
get_balance_sheet,
get_cashflow,
get_income_statement,
get_economic_indicators,
get_fed_calendar,
get_news,
get_insider_transactions,
get_global_news
get_global_news,
get_yield_curve,
)
from .conditional_logic import ConditionalLogic
@ -58,6 +61,7 @@ class TradingAgentsGraph:
"social",
"news",
"fundamentals",
"macro",
"bull_researcher",
"bear_researcher",
"trader",
@ -299,6 +303,14 @@ class TradingAgentsGraph:
get_income_statement,
]
),
"macro": ToolNode(
[
# Macroeconomic analysis tools
get_economic_indicators,
get_yield_curve,
get_fed_calendar,
]
),
}
def propagate(self, company_name, trade_date):
@ -345,6 +357,7 @@ class TradingAgentsGraph:
"sentiment_report": final_state["sentiment_report"],
"news_report": final_state["news_report"],
"fundamentals_report": final_state["fundamentals_report"],
"macro_report": final_state.get("macro_report", ""),
"investment_debate_state": {
"bull_history": final_state["investment_debate_state"]["bull_history"],
"bear_history": final_state["investment_debate_state"]["bear_history"],