diff --git a/tests/test_market_gate.py b/tests/test_market_gate.py new file mode 100644 index 00000000..4555b6c5 --- /dev/null +++ b/tests/test_market_gate.py @@ -0,0 +1,153 @@ +"""Tests for the pre-trade market state verification gate. + +Imports the market_gate module directly to avoid pulling in the full +tradingagents dependency chain (yfinance, etc.) in CI environments +that only need to validate the gate logic. +""" + +import json +import sys +from pathlib import Path +from unittest.mock import patch, MagicMock +from io import BytesIO + +# Allow direct import without the full tradingagents package chain +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "tradingagents" / "agents" / "risk_mgmt")) +from market_gate import ( + _ticker_to_mic, + check_market_state, + create_market_gate, +) + + +class TestTickerToMic: + """Test ticker suffix to MIC resolution.""" + + def test_plain_us_ticker(self): + assert _ticker_to_mic("AAPL") == "XNYS" + + def test_london_suffix(self): + assert _ticker_to_mic("VOD.L") == "XLON" + + def test_tokyo_suffix(self): + assert _ticker_to_mic("7203.T") == "XJPX" + + def test_hong_kong_suffix(self): + assert _ticker_to_mic("0700.HK") == "XHKG" + + def test_case_insensitive(self): + assert _ticker_to_mic("vod.l") == "XLON" + + def test_unknown_suffix_defaults_to_xnys(self): + assert _ticker_to_mic("UNKNOWN.ZZ") == "XNYS" + + +class TestCheckMarketState: + """Test the oracle HTTP call and fail-closed behavior.""" + + @patch("market_gate.urlopen") + def test_open_market(self, mock_urlopen): + response = BytesIO(json.dumps({"status": "OPEN", "mic": "XNYS"}).encode()) + mock_urlopen.return_value.__enter__ = lambda s: response + mock_urlopen.return_value.__exit__ = MagicMock(return_value=False) + + result = check_market_state("AAPL") + assert result["status"] == "OPEN" + assert result["blocked"] is False + assert result["reason"] == "" + + @patch("market_gate.urlopen") + def test_closed_market(self, mock_urlopen): + response = BytesIO(json.dumps({"status": "CLOSED", "mic": "XNYS"}).encode()) + mock_urlopen.return_value.__enter__ = lambda s: response + mock_urlopen.return_value.__exit__ = MagicMock(return_value=False) + + result = check_market_state("AAPL") + assert result["status"] == "CLOSED" + assert result["blocked"] is True + assert "BLOCK TRADE" in result["reason"] + + @patch("market_gate.urlopen") + def test_halted_market(self, mock_urlopen): + response = BytesIO(json.dumps({"status": "HALTED", "mic": "XNYS"}).encode()) + mock_urlopen.return_value.__enter__ = lambda s: response + mock_urlopen.return_value.__exit__ = MagicMock(return_value=False) + + result = check_market_state("AAPL") + assert result["status"] == "HALTED" + assert result["blocked"] is True + + @patch("market_gate.urlopen") + def test_network_failure_defaults_to_unknown(self, mock_urlopen): + from urllib.error import URLError + mock_urlopen.side_effect = URLError("connection refused") + + result = check_market_state("AAPL") + assert result["status"] == "UNKNOWN" + assert result["blocked"] is True + assert "BLOCK TRADE" in result["reason"] + + +class TestMarketGateNode: + """Test the LangGraph node integration.""" + + @patch("market_gate.check_market_state") + def test_open_market_adds_safe_advisory(self, mock_check): + mock_check.return_value = { + "status": "OPEN", "mic": "XNYS", "blocked": False, "reason": "" + } + + node = create_market_gate() + state = { + "company_of_interest": "AAPL", + "risk_debate_state": { + "history": "Prior debate...", + "aggressive_history": "", + "conservative_history": "", + "neutral_history": "", + "latest_speaker": "Neutral", + "current_aggressive_response": "", + "current_conservative_response": "", + "current_neutral_response": "", + "judge_decision": "", + "count": 3, + }, + } + + result = node(state) + history = result["risk_debate_state"]["history"] + assert "[MARKET GATE]" in history + assert "OPEN" in history + assert "safe to proceed" in history + + @patch("market_gate.check_market_state") + def test_closed_market_adds_block_advisory(self, mock_check): + mock_check.return_value = { + "status": "CLOSED", + "mic": "XNYS", + "blocked": True, + "reason": "BLOCK TRADE — market XNYS is CLOSED", + } + + node = create_market_gate() + state = { + "company_of_interest": "AAPL", + "risk_debate_state": { + "history": "Prior debate...", + "aggressive_history": "", + "conservative_history": "", + "neutral_history": "", + "latest_speaker": "Neutral", + "current_aggressive_response": "", + "current_conservative_response": "", + "current_neutral_response": "", + "judge_decision": "", + "count": 3, + }, + } + + result = node(state) + history = result["risk_debate_state"]["history"] + assert "[MARKET GATE]" in history + assert "BLOCK TRADE" in history + assert "Do NOT approve execution" in history diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index 1f03642c..30ee8d20 100644 --- a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -13,6 +13,7 @@ from .researchers.bull_researcher import create_bull_researcher from .risk_mgmt.aggressive_debator import create_aggressive_debator from .risk_mgmt.conservative_debator import create_conservative_debator from .risk_mgmt.neutral_debator import create_neutral_debator +from .risk_mgmt.market_gate import create_market_gate from .managers.research_manager import create_research_manager from .managers.portfolio_manager import create_portfolio_manager @@ -37,4 +38,5 @@ __all__ = [ "create_conservative_debator", "create_social_media_analyst", "create_trader", + "create_market_gate", ] diff --git a/tradingagents/agents/risk_mgmt/market_gate.py b/tradingagents/agents/risk_mgmt/market_gate.py new file mode 100644 index 00000000..499b80e8 --- /dev/null +++ b/tradingagents/agents/risk_mgmt/market_gate.py @@ -0,0 +1,125 @@ +"""Pre-trade market state verification gate. + +Checks whether the target exchange is open before the Portfolio Manager +approves a trade. Uses the Headless Oracle free demo endpoint — no API +key or account required. + +Resolves: https://github.com/TauricResearch/TradingAgents/issues/514 +""" + +import json +import logging +from urllib.request import urlopen, Request +from urllib.error import URLError + +logger = logging.getLogger(__name__) + +# Map common ticker suffixes to ISO 10383 Market Identifier Codes. +# Tickers without a suffix are assumed to be US equities (XNYS). +SUFFIX_TO_MIC = { + "": "XNYS", # US equities (default) + ".TO": "XNYS", # TMX — route through NYSE hours as proxy + ".L": "XLON", # London Stock Exchange + ".HK": "XHKG", # Hong Kong + ".T": "XJPX", # Tokyo + ".PA": "XPAR", # Euronext Paris + ".SI": "XSES", # Singapore + ".AX": "XASX", # Australia + ".BO": "XBOM", # BSE India + ".NS": "XNSE", # NSE India + ".SS": "XSHG", # Shanghai + ".SZ": "XSHE", # Shenzhen + ".KS": "XKRX", # Korea + ".JO": "XJSE", # Johannesburg + ".SA": "XBSP", # B3 Brazil + ".SW": "XSWX", # SIX Swiss + ".MI": "XMIL", # Borsa Italiana + ".IS": "XIST", # Borsa Istanbul + ".SR": "XSAU", # Saudi Exchange + ".NZ": "XNZE", # New Zealand + ".HE": "XHEL", # Nasdaq Helsinki + ".ST": "XSTO", # Nasdaq Stockholm +} + +ORACLE_URL = "https://headlessoracle.com/v5/demo" + + +def _ticker_to_mic(ticker: str) -> str: + """Derive the exchange MIC from a ticker's suffix.""" + upper = ticker.upper() + for suffix, mic in SUFFIX_TO_MIC.items(): + if suffix and upper.endswith(suffix): + return mic + # No suffix -> US equity + return "XNYS" + + +def check_market_state(ticker: str, timeout: int = 10) -> dict: + """Fetch a signed market-state receipt for the ticker's exchange. + + Returns a dict with at least: + status - "OPEN", "CLOSED", "HALTED", or "UNKNOWN" + mic - the exchange MIC that was checked + blocked - True if the trade should not proceed + reason - human-readable explanation (empty string when OPEN) + + On network failure the status defaults to "UNKNOWN" (fail-closed). + """ + mic = _ticker_to_mic(ticker) + url = f"{ORACLE_URL}?mic={mic}" + + try: + req = Request(url, headers={"User-Agent": "TradingAgents/1.0"}) + with urlopen(req, timeout=timeout) as resp: + data = json.load(resp) + except (URLError, OSError, json.JSONDecodeError) as exc: + logger.warning("Market gate: oracle unreachable (%s), defaulting to UNKNOWN", exc) + data = {"status": "UNKNOWN", "mic": mic} + + status = data.get("status", "UNKNOWN") + blocked = status != "OPEN" + reason = "" if not blocked else f"BLOCK TRADE — market {mic} is {status}" + + return { + "status": status, + "mic": mic, + "blocked": blocked, + "reason": reason, + } + + +def create_market_gate(): + """Create a graph node that gates trade execution on market state. + + When the market is not OPEN, the node injects a blocking advisory into + the risk debate history so the Portfolio Manager sees it before deciding. + """ + + def market_gate_node(state) -> dict: + ticker = state["company_of_interest"] + result = check_market_state(ticker) + + risk_debate_state = state["risk_debate_state"] + history = risk_debate_state.get("history", "") + + if result["blocked"]: + advisory = ( + f"\n\n[MARKET GATE] {result['reason']}. " + f"Exchange {result['mic']} status is {result['status']}. " + "Do NOT approve execution — the market is not open for trading." + ) + logger.info("Market gate blocked trade: %s", result["reason"]) + else: + advisory = ( + f"\n\n[MARKET GATE] Exchange {result['mic']} is OPEN. " + "Market state verified — safe to proceed with execution." + ) + + new_risk_debate_state = { + **risk_debate_state, + "history": history + advisory, + } + + return {"risk_debate_state": new_risk_debate_state} + + return market_gate_node diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 26a4e4d2..622de552 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -23,6 +23,13 @@ DEFAULT_CONFIG = { "max_debate_rounds": 1, "max_risk_discuss_rounds": 1, "max_recur_limit": 100, + # Market state verification gate (pre-trade safety check) + # When enabled, verifies the target exchange is open before the + # Portfolio Manager approves execution. Uses the free Headless Oracle + # demo endpoint (no API key required). Covers 28 global exchanges. + # See: https://github.com/TauricResearch/TradingAgents/issues/514 + "use_market_gate": True, + # Data vendor configuration # Category-level configuration (default for all tools in category) "data_vendors": { diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index ae90489c..13c09c16 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -5,6 +5,7 @@ from langgraph.graph import END, START, StateGraph from langgraph.prebuilt import ToolNode from tradingagents.agents import * +from tradingagents.agents.risk_mgmt.market_gate import create_market_gate from tradingagents.agents.utils.agent_states import AgentState from .conditional_logic import ConditionalLogic @@ -24,8 +25,10 @@ class GraphSetup: invest_judge_memory, portfolio_manager_memory, conditional_logic: ConditionalLogic, + config: dict = None, ): """Initialize with required components.""" + self.config = config or {} self.quick_thinking_llm = quick_thinking_llm self.deep_thinking_llm = deep_thinking_llm self.tool_nodes = tool_nodes @@ -169,13 +172,24 @@ class GraphSetup: }, ) workflow.add_edge("Research Manager", "Trader") + # Market state verification gate (optional, enabled by default) + # When enabled, checks if the target exchange is open before the + # Portfolio Manager approves execution. + if self.config.get("use_market_gate", True): + market_gate_node = create_market_gate() + workflow.add_node("Market Gate", market_gate_node) + workflow.add_edge("Market Gate", "Portfolio Manager") + portfolio_entry = "Market Gate" + else: + portfolio_entry = "Portfolio Manager" + workflow.add_edge("Trader", "Aggressive Analyst") workflow.add_conditional_edges( "Aggressive Analyst", self.conditional_logic.should_continue_risk_analysis, { "Conservative Analyst": "Conservative Analyst", - "Portfolio Manager": "Portfolio Manager", + "Portfolio Manager": portfolio_entry, }, ) workflow.add_conditional_edges( @@ -183,7 +197,7 @@ class GraphSetup: self.conditional_logic.should_continue_risk_analysis, { "Neutral Analyst": "Neutral Analyst", - "Portfolio Manager": "Portfolio Manager", + "Portfolio Manager": portfolio_entry, }, ) workflow.add_conditional_edges( @@ -191,7 +205,7 @@ class GraphSetup: self.conditional_logic.should_continue_risk_analysis, { "Aggressive Analyst": "Aggressive Analyst", - "Portfolio Manager": "Portfolio Manager", + "Portfolio Manager": portfolio_entry, }, ) diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 8e18f9c4..0f2bafe7 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -119,6 +119,7 @@ class TradingAgentsGraph: self.invest_judge_memory, self.portfolio_manager_memory, self.conditional_logic, + config=self.config, ) self.propagator = Propagator()