464 lines
19 KiB
Python
464 lines
19 KiB
Python
"""Tests for PR#106 review fixes (ADR 016).

Covers:
- Fix 1: save_holding_review per-ticker iteration in run_portfolio
- Fix 2: contextvars-based RunLogger isolation
- Fix 3: list_pm_decisions excludes _id (ObjectId)
- Fix 4: ReflexionMemory created_at is native datetime for MongoDB
- Fix 5: write/read_latest_pointer respects base_dir parameter
- Fix 6: RunLogger callback wired into astream_events config
- Fix 7: ensure_indexes called in MongoReportStore.__init__
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
# Put the directory two levels above this file at the front of sys.path so the
# absolute package imports used inside the tests resolve when this module is
# run directly.
_project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def _collect(agen):
    """Drain an async generator and return everything it yielded as a list.

    Args:
        agen: Any async iterable (typically an async generator).

    Returns:
        A list of the yielded items, in the order they were produced.
    """
    return [item async for item in agen]
|
|
|
|
|
|
def _root_chain_end_event(output: dict) -> dict:
    """Return a synthetic root-level ``on_chain_end`` LangGraph v2 event.

    Args:
        output: The state dict to place under ``data["output"]``.

    Returns:
        An event dict named ``"LangGraph"`` with an empty ``parent_ids`` list,
        carrying *output* (by reference, not copied) as its output payload.
    """
    event = dict(
        event="on_chain_end",
        name="LangGraph",
        parent_ids=[],
        metadata={},
        data={"output": output},
        run_id="test-run-id",
        tags=[],
    )
    return event
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fix 1: save_holding_review per-ticker iteration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestSaveHoldingReviewIteration(unittest.TestCase):
    """Verify save_holding_review is called per-ticker, not once with portfolio_id."""

    # Default final graph state streamed by the mocked portfolio graph.
    # ``holding_reviews`` is a JSON-encoded mapping of ticker -> review payload.
    _FINAL_STATE = {
        "holding_reviews": json.dumps({
            "AAPL": {"rating": "hold", "reason": "stable"},
            "MSFT": {"rating": "buy", "reason": "growth"},
        }),
        "risk_metrics": "",
        "pm_decision": "",
        "execution_result": "",
    }

    def _make_mock_portfolio_graph(self, final_state=None):
        """Return a mock PortfolioGraph that streams one root end event.

        Args:
            final_state: State dict emitted as the root ``on_chain_end``
                output; defaults to ``_FINAL_STATE``.

        Returns:
            A ``MagicMock`` whose ``graph.astream_events`` is an async
            generator yielding a single synthetic root event.
        """
        if final_state is None:
            final_state = self._FINAL_STATE

        async def mock_astream(*args, **kwargs):
            yield _root_chain_end_event(final_state)

        mock_graph = MagicMock()
        mock_graph.astream_events = mock_astream
        mock_pg = MagicMock()
        mock_pg.graph = mock_graph
        return mock_pg

    def _run_portfolio(self, final_state=None):
        """Drive ``LangGraphEngine.run_portfolio`` against mocked collaborators.

        The graph, the report store, the daily-directory lookup and the digest
        writer are all patched so that no real I/O takes place.

        Args:
            final_state: Forwarded to ``_make_mock_portfolio_graph``.

        Returns:
            The mock report store, so callers can inspect the calls made on it.
        """
        from agent_os.backend.services.langgraph_engine import LangGraphEngine

        mock_pg = self._make_mock_portfolio_graph(final_state)
        engine = LangGraphEngine()
        mock_store = MagicMock()
        mock_store.load_scan.return_value = {}
        mock_store.load_analysis.return_value = None

        with patch("agent_os.backend.services.langgraph_engine.PortfolioGraph", return_value=mock_pg), \
             patch("agent_os.backend.services.langgraph_engine.create_report_store", return_value=mock_store), \
             patch("agent_os.backend.services.langgraph_engine.get_daily_dir") as mock_gdd, \
             patch("agent_os.backend.services.langgraph_engine.append_to_digest"):
            # A daily directory that reports itself, and any child path, as missing.
            fake_daily = MagicMock(spec=Path)
            fake_daily.exists.return_value = False
            fake_daily.__truediv__ = MagicMock(
                return_value=MagicMock(spec=Path, exists=MagicMock(return_value=False))
            )
            mock_gdd.return_value = fake_daily

            asyncio.run(_collect(engine.run_portfolio("run1", {
                "date": "2026-03-20",
                "portfolio_id": "pid-123",
            })))

        return mock_store

    def test_holding_reviews_saved_per_ticker(self):
        """run_portfolio should call save_holding_review once per ticker key."""
        mock_store = self._run_portfolio()

        # save_holding_review should be called once per ticker
        calls = mock_store.save_holding_review.call_args_list
        tickers_saved = {c.args[1] for c in calls}  # (date, ticker, data)
        self.assertEqual(tickers_saved, {"AAPL", "MSFT"})
        self.assertEqual(len(calls), 2)

    def test_non_dict_reviews_logs_warning(self):
        """When holding_reviews is not a dict, the run must not crash or save.

        NOTE(review): the warning named in the test title is not asserted
        here; only the absence of a crash and of any save call is checked.
        """
        state = dict(self._FINAL_STATE)
        state["holding_reviews"] = json.dumps(["not", "a", "dict"])

        mock_store = self._run_portfolio(state)

        # save_holding_review should NOT be called
        mock_store.save_holding_review.assert_not_called()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fix 2: contextvars-based RunLogger isolation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestContextVarRunLogger(unittest.TestCase):
    """Verify RunLogger uses contextvars (isolated per asyncio task)."""

    def test_set_get_returns_correct_logger(self):
        """set_run_logger/get_run_logger round-trip a logger and accept None."""
        from tradingagents.observability import (
            RunLogger,
            get_run_logger,
            set_run_logger,
        )

        # Registered before the logger is installed so that a failing
        # assertion below cannot leave it set for later tests.
        self.addCleanup(set_run_logger, None)

        rl = RunLogger()
        set_run_logger(rl)
        self.assertIs(get_run_logger(), rl)
        set_run_logger(None)
        self.assertIsNone(get_run_logger())

    def test_context_isolation_across_async_tasks(self):
        """Each asyncio task should have its own RunLogger."""
        from tradingagents.observability import (
            RunLogger,
            get_run_logger,
            set_run_logger,
        )

        # Logger observed by each task *after* it has yielded to the other.
        results = {}

        async def task(name: str):
            rl = RunLogger()
            set_run_logger(rl)
            # Yield control so both tasks have set their logger before
            # either one reads it back.
            await asyncio.sleep(0.01)
            results[name] = get_run_logger()
            return rl

        async def run_concurrent():
            rl_a, rl_b = await asyncio.gather(task("A"), task("B"))
            return rl_a, rl_b

        rl_a, rl_b = asyncio.run(run_concurrent())

        # Each task should get back its own logger, not the other's
        self.assertIs(results["A"], rl_a)
        self.assertIs(results["B"], rl_b)
        # They should be different instances
        self.assertIsNot(rl_a, rl_b)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fix 3: list_pm_decisions excludes _id
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestListPmDecisionsExcludesId(unittest.TestCase):
    """Verify list_pm_decisions uses {_id: 0} projection."""

    def test_projection_excludes_object_id(self):
        """The collection ``find`` call must carry a ``{"_id": 0}`` projection."""
        with patch("tradingagents.portfolio.mongo_report_store.MongoClient") as mock_client_cls:
            # client[db_name][collection_name] resolves to mock_col.
            mock_col = MagicMock()
            mock_db = MagicMock()
            mock_db.__getitem__ = MagicMock(return_value=mock_col)
            mock_client = MagicMock()
            mock_client.__getitem__ = MagicMock(return_value=mock_db)
            mock_client_cls.return_value = mock_client

            from tradingagents.portfolio.mongo_report_store import MongoReportStore

            store = MongoReportStore("mongodb://localhost:27017", run_id="test")
            store._col = mock_col

            mock_col.find.return_value = []
            store.list_pm_decisions("pid-123")

            # Fail with a clear message if find() was never reached; otherwise
            # call_args would be None and indexing it raises a bare TypeError.
            mock_col.find.assert_called()

            # Verify the projection argument includes _id: 0, whether it was
            # passed positionally (second argument) or as ``projection=``.
            find_call = mock_col.find.call_args
            if len(find_call.args) > 1:
                projection = find_call.args[1]
            else:
                projection = find_call.kwargs.get("projection")
            self.assertEqual(projection, {"_id": 0})
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fix 4: ReflexionMemory created_at is native datetime for MongoDB
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestReflexionCreatedAtType(unittest.TestCase):
    """Verify created_at is native datetime for MongoDB, ISO string for local."""

    def test_mongodb_path_stores_native_datetime(self):
        """When writing to MongoDB, created_at should be a datetime object."""
        with patch("tradingagents.memory.reflexion.MongoClient", create=True) as mock_client_cls, \
             tempfile.TemporaryDirectory() as tmpdir:
            mock_col = MagicMock()
            mock_db = MagicMock()
            mock_db.__getitem__ = MagicMock(return_value=mock_col)
            mock_client = MagicMock()
            mock_client.__getitem__ = MagicMock(return_value=mock_db)
            mock_client_cls.return_value = mock_client

            from tradingagents.memory.reflexion import ReflexionMemory

            # Bypass __init__ and inject the mocked collection directly.
            mem = ReflexionMemory.__new__(ReflexionMemory)
            mem._col = mock_col
            # Keep the fallback file inside a throwaway directory instead of
            # a fixed /tmp location, so nothing can be written outside it.
            mem._fallback_path = Path(tmpdir) / "test_reflexion.json"

            mem.record_decision("AAPL", "2026-03-20", "BUY", "test", "high")

            doc = mock_col.insert_one.call_args[0][0]
            self.assertIsInstance(doc["created_at"], datetime)

    def test_local_path_stores_iso_string(self):
        """When writing to local JSON, created_at should be an ISO string."""
        from tradingagents.memory.reflexion import ReflexionMemory

        with tempfile.TemporaryDirectory() as tmpdir:
            fb_path = Path(tmpdir) / "test_reflexion.json"
            mem = ReflexionMemory(fallback_path=fb_path)

            mem.record_decision("AAPL", "2026-03-20", "BUY", "test", "high")

            data = json.loads(fb_path.read_text())
            self.assertIsInstance(data[0]["created_at"], str)
            # Should be parseable as ISO datetime (raises ValueError otherwise)
            datetime.fromisoformat(data[0]["created_at"])
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fix 5: write/read_latest_pointer respects base_dir parameter
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestLatestPointerBaseDir(unittest.TestCase):
    """Check that the latest-pointer helpers honour an explicit ``base_dir``."""

    def test_pointer_uses_custom_base_dir(self):
        """A pointer written under a custom base is stored and read back there."""
        from tradingagents.report_paths import read_latest_pointer, write_latest_pointer

        with tempfile.TemporaryDirectory() as scratch:
            reports_root = Path(scratch) / "custom_reports"
            write_latest_pointer("2026-03-20", "run123", base_dir=reports_root)

            # The pointer file must live below the custom base, not REPORTS_ROOT.
            pointer_file = reports_root.joinpath("daily", "2026-03-20", "latest.json")
            self.assertTrue(pointer_file.exists())
            payload = json.loads(pointer_file.read_text())
            self.assertEqual(payload["run_id"], "run123")

            # Reading with the same base must return the run id just written.
            found = read_latest_pointer("2026-03-20", base_dir=reports_root)
            self.assertEqual(found, "run123")

    def test_read_returns_none_with_wrong_base(self):
        """A pointer written under one base is invisible from another base."""
        from tradingagents.report_paths import read_latest_pointer, write_latest_pointer

        with tempfile.TemporaryDirectory() as scratch:
            scratch_path = Path(scratch)
            written_base = scratch_path / "a"
            other_base = scratch_path / "b"
            write_latest_pointer("2026-03-20", "run_a", base_dir=written_base)

            # Looking under the other base must find nothing.
            found = read_latest_pointer("2026-03-20", base_dir=other_base)
            self.assertIsNone(found)

    def test_report_store_passes_base_dir(self):
        """ReportStore should hand its own base directory to the pointer helpers."""
        from tradingagents.portfolio.report_store import ReportStore

        with tempfile.TemporaryDirectory() as scratch:
            reports_root = Path(scratch) / "custom"
            report_store = ReportStore(base_dir=reports_root, run_id="abc123")

            # Saving a scan is expected to refresh the latest pointer.
            report_store.save_scan("2026-03-20", {"test": True})

            # The refreshed pointer must sit below the store's custom base.
            pointer_file = reports_root.joinpath("daily", "2026-03-20", "latest.json")
            self.assertTrue(pointer_file.exists())
            payload = json.loads(pointer_file.read_text())
            self.assertEqual(payload["run_id"], "abc123")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fix 6: RunLogger callback wired into astream_events config
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestRunLoggerCallbackWiring(unittest.TestCase):
    """Verify astream_events receives the RunLogger callback in config."""

    def _make_mock_graph(self, final_state):
        """Create a mock graph that captures the config passed to astream_events.

        Args:
            final_state: State dict emitted as the root ``on_chain_end`` output.

        Returns:
            ``(mock_graph, captured_config)``; ``captured_config`` is filled in
            place with the ``config`` keyword argument once the stream runs.
        """
        captured_config = {}

        async def mock_astream(*args, **kwargs):
            # ``config`` may be missing or explicitly None; treat both as
            # empty so the test fails on its assertions, not with a TypeError.
            captured_config.update(kwargs.get("config") or {})
            yield _root_chain_end_event(final_state)

        mock_graph = MagicMock()
        mock_graph.astream_events = mock_astream
        return mock_graph, captured_config

    @staticmethod
    def _make_fake_dir():
        """Return a Path-spec'd mock directory with ``/`` and ``mkdir`` stubbed."""
        fake_dir = MagicMock(spec=Path)
        fake_dir.__truediv__ = MagicMock(return_value=MagicMock(spec=Path))
        fake_dir.mkdir = MagicMock()
        return fake_dir

    @staticmethod
    def _make_missing_daily_dir():
        """Return a Path-spec'd mock directory that, like its children, does not exist."""
        fake_daily = MagicMock(spec=Path)
        fake_daily.exists.return_value = False
        fake_daily.__truediv__ = MagicMock(
            return_value=MagicMock(spec=Path, exists=MagicMock(return_value=False))
        )
        return fake_daily

    def test_run_scan_wires_callback(self):
        """run_scan must pass exactly one callback in the graph config."""
        from agent_os.backend.services.langgraph_engine import LangGraphEngine

        mock_graph, captured = self._make_mock_graph({
            "geopolitical_report": "", "market_movers_report": "",
            "sector_performance_report": "", "industry_deep_dive_report": "",
            "macro_scan_summary": "",
        })
        mock_scanner = MagicMock()
        mock_scanner.graph = mock_graph

        engine = LangGraphEngine()
        mock_store = MagicMock()

        with patch("agent_os.backend.services.langgraph_engine.ScannerGraph", return_value=mock_scanner), \
             patch("agent_os.backend.services.langgraph_engine.create_report_store", return_value=mock_store), \
             patch("agent_os.backend.services.langgraph_engine.get_market_dir") as mock_gmd, \
             patch("agent_os.backend.services.langgraph_engine.append_to_digest"), \
             patch("agent_os.backend.services.langgraph_engine.extract_json", return_value={}):
            mock_gmd.return_value = self._make_fake_dir()

            asyncio.run(_collect(engine.run_scan("run1", {"date": "2026-01-01"})))

        self.assertIn("callbacks", captured)
        self.assertEqual(len(captured["callbacks"]), 1)

    def test_run_pipeline_wires_callback(self):
        """run_pipeline must pass one callback and keep the recursion limit."""
        from agent_os.backend.services.langgraph_engine import LangGraphEngine

        mock_graph, captured = self._make_mock_graph({"final_trade_decision": "BUY"})
        mock_propagator = MagicMock()
        mock_propagator.max_recur_limit = 100
        mock_propagator.create_initial_state.return_value = {"ticker": "AAPL"}
        mock_wrapper = MagicMock()
        mock_wrapper.graph = mock_graph
        mock_wrapper.propagator = mock_propagator

        engine = LangGraphEngine()
        mock_store = MagicMock()

        with patch("agent_os.backend.services.langgraph_engine.TradingAgentsGraph", return_value=mock_wrapper), \
             patch("agent_os.backend.services.langgraph_engine.create_report_store", return_value=mock_store), \
             patch("agent_os.backend.services.langgraph_engine.get_ticker_dir") as mock_gtd, \
             patch("agent_os.backend.services.langgraph_engine.append_to_digest"):
            mock_gtd.return_value = self._make_fake_dir()

            asyncio.run(_collect(engine.run_pipeline("run1", {
                "ticker": "AAPL", "date": "2026-01-01",
            })))

        self.assertIn("callbacks", captured)
        self.assertEqual(len(captured["callbacks"]), 1)
        # Also verify recursion_limit is still set
        self.assertEqual(captured["recursion_limit"], 100)

    def test_run_portfolio_wires_callback(self):
        """run_portfolio must pass exactly one callback in the graph config."""
        from agent_os.backend.services.langgraph_engine import LangGraphEngine

        mock_graph, captured = self._make_mock_graph({
            "holding_reviews": "", "risk_metrics": "",
            "pm_decision": "", "execution_result": "",
        })
        mock_pg = MagicMock()
        mock_pg.graph = mock_graph

        engine = LangGraphEngine()
        mock_store = MagicMock()
        mock_store.load_scan.return_value = {}
        mock_store.load_analysis.return_value = None

        with patch("agent_os.backend.services.langgraph_engine.PortfolioGraph", return_value=mock_pg), \
             patch("agent_os.backend.services.langgraph_engine.create_report_store", return_value=mock_store), \
             patch("agent_os.backend.services.langgraph_engine.get_daily_dir") as mock_gdd, \
             patch("agent_os.backend.services.langgraph_engine.append_to_digest"):
            mock_gdd.return_value = self._make_missing_daily_dir()

            asyncio.run(_collect(engine.run_portfolio("run1", {
                "date": "2026-01-01", "portfolio_id": "pid-123",
            })))

        self.assertIn("callbacks", captured)
        self.assertEqual(len(captured["callbacks"]), 1)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fix 7: ensure_indexes called in MongoReportStore.__init__
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestEnsureIndexesInInit(unittest.TestCase):
    """Verify ensure_indexes is called during __init__, not just via factory."""

    def test_init_calls_ensure_indexes(self):
        """Constructing the store directly must create the collection indexes."""
        with patch("tradingagents.portfolio.mongo_report_store.MongoClient") as mock_client_cls:
            # client[db_name][collection_name] resolves to mock_col.
            mock_col = MagicMock()
            mock_db = MagicMock()
            mock_db.__getitem__ = MagicMock(return_value=mock_col)
            mock_client = MagicMock()
            mock_client.__getitem__ = MagicMock(return_value=mock_db)
            mock_client_cls.return_value = mock_client

            from tradingagents.portfolio.mongo_report_store import MongoReportStore

            # Constructed only for its side effect; the instance is not needed.
            MongoReportStore("mongodb://localhost:27017", run_id="test")

            # create_index should have been called at least 4 times
            # (the indexes from ensure_indexes)
            self.assertGreaterEqual(mock_col.create_index.call_count, 4)
|
|
|
|
|
|
# Allow running this module directly as a script through unittest's own runner.
if __name__ == "__main__":
    unittest.main()
|