TradingAgents/tests/unit/test_pr106_review_fixes.py

468 lines
19 KiB
Python

"""Tests for PR#106 review fixes (ADR 016).
Covers:
- Fix 1: save_holding_review per-ticker iteration in run_portfolio
- Fix 2: contextvars-based RunLogger isolation
- Fix 3: list_pm_decisions excludes _id (ObjectId)
- Fix 4: ReflexionMemory created_at is native datetime for MongoDB
- Fix 5: write/read_latest_pointer respects base_dir parameter
- Fix 6: RunLogger callback wired into astream_events config
- Fix 7: MongoReportStore.ensure_indexes creates the expected indexes (index creation is now lazy, not done in __init__)
"""
from __future__ import annotations
import asyncio
import json
import os
import sys
import tempfile
import unittest
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import MagicMock, patch
# Make the repository root (two directories above this file) importable so the
# ``agent_os`` / ``tradingagents`` packages resolve when tests run from anywhere.
_project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _collect(agen):
    """Drain the async iterator *agen* and return every item it produced, in order."""
    return [evt async for evt in agen]
def _root_chain_end_event(output: dict) -> dict:
    """Return a fake top-level ``on_chain_end`` event in LangGraph v2 event shape.

    *output* is exposed under ``data["output"]``. ``parent_ids`` is empty and the
    name is ``"LangGraph"``; presumably that is how the engine recognises the
    root graph's completion event (confirm against the engine).
    """
    payload = {"output": output}
    return {
        "event": "on_chain_end",
        "name": "LangGraph",
        "parent_ids": [],
        "metadata": {},
        "data": payload,
        "run_id": "test-run-id",
        "tags": [],
    }
# ---------------------------------------------------------------------------
# Fix 1: save_holding_review per-ticker iteration
# ---------------------------------------------------------------------------
class TestSaveHoldingReviewIteration(unittest.TestCase):
    """Verify save_holding_review is called per-ticker, not once with portfolio_id."""

    _FINAL_STATE = {
        "holding_reviews": json.dumps({
            "AAPL": {"rating": "hold", "reason": "stable"},
            "MSFT": {"rating": "buy", "reason": "growth"},
        }),
        "risk_metrics": "",
        "pm_decision": "",
        "execution_result": "",
    }

    def _make_mock_portfolio_graph(self, final_state=None):
        """Return a fake PortfolioGraph whose graph emits one root on_chain_end event.

        The event's output is *final_state*, defaulting to ``_FINAL_STATE``.
        """
        if final_state is None:
            final_state = self._FINAL_STATE

        async def mock_astream(*args, **kwargs):
            yield _root_chain_end_event(final_state)

        mock_graph = MagicMock()
        mock_graph.astream_events = mock_astream
        mock_pg = MagicMock()
        mock_pg.graph = mock_graph
        return mock_pg

    def _run_portfolio(self, final_state=None):
        """Run ``LangGraphEngine.run_portfolio`` against mocks and return the store mock.

        ``PortfolioGraph``, ``create_report_store``, ``get_daily_dir`` and
        ``append_to_digest`` are patched in the engine module. No real graph or
        report store is used, and the daily directory is a mock.
        """
        from agent_os.backend.services.langgraph_engine import LangGraphEngine

        mock_pg = self._make_mock_portfolio_graph(final_state)
        engine = LangGraphEngine()
        mock_store = MagicMock()
        mock_store.load_scan.return_value = {}
        mock_store.load_analysis.return_value = None
        with patch("agent_os.backend.services.langgraph_engine.PortfolioGraph", return_value=mock_pg), \
                patch("agent_os.backend.services.langgraph_engine.create_report_store", return_value=mock_store), \
                patch("agent_os.backend.services.langgraph_engine.get_daily_dir") as mock_gdd, \
                patch("agent_os.backend.services.langgraph_engine.append_to_digest"):
            # Fake daily directory: neither it nor any child path reports as existing.
            fake_daily = MagicMock(spec=Path)
            fake_daily.exists.return_value = False
            fake_daily.__truediv__ = MagicMock(
                return_value=MagicMock(spec=Path, exists=MagicMock(return_value=False))
            )
            mock_gdd.return_value = fake_daily
            asyncio.run(_collect(engine.run_portfolio("run1", {
                "date": "2026-03-20",
                "portfolio_id": "pid-123",
            })))
        return mock_store

    def test_holding_reviews_saved_per_ticker(self):
        """run_portfolio should call save_holding_review once per ticker key."""
        mock_store = self._run_portfolio()
        calls = mock_store.save_holding_review.call_args_list
        # Positional args are (date, ticker, data); collect the ticker of each call.
        tickers_saved = {c.args[1] for c in calls}
        self.assertEqual(tickers_saved, {"AAPL", "MSFT"})
        self.assertEqual(len(calls), 2)

    def test_non_dict_reviews_logs_warning(self):
        """A non-dict holding_reviews payload must not crash run_portfolio nor be saved.

        NOTE(review): despite the test name, the warning itself is not asserted;
        add an ``assertLogs`` once the engine's logger name is confirmed.
        """
        state = dict(self._FINAL_STATE)
        state["holding_reviews"] = json.dumps(["not", "a", "dict"])
        mock_store = self._run_portfolio(state)
        # save_holding_review should NOT be called
        mock_store.save_holding_review.assert_not_called()
# ---------------------------------------------------------------------------
# Fix 2: contextvars-based RunLogger isolation
# ---------------------------------------------------------------------------
class TestContextVarRunLogger(unittest.TestCase):
    """Verify RunLogger uses contextvars (isolated per asyncio task)."""

    def test_set_get_returns_correct_logger(self):
        """set_run_logger/get_run_logger round-trip both a logger and None."""
        from tradingagents.observability import (
            RunLogger,
            get_run_logger,
            set_run_logger,
        )
        # Reset to None even if an assertion below fails. Otherwise a stale
        # logger would stay set in this thread's context and leak into later tests.
        self.addCleanup(set_run_logger, None)
        rl = RunLogger()
        set_run_logger(rl)
        self.assertIs(get_run_logger(), rl)
        set_run_logger(None)
        self.assertIsNone(get_run_logger())

    def test_context_isolation_across_async_tasks(self):
        """Each asyncio task should have its own RunLogger."""
        from tradingagents.observability import (
            RunLogger,
            get_run_logger,
            set_run_logger,
        )
        results = {}

        async def task(name: str):
            rl = RunLogger()
            set_run_logger(rl)
            # Yield to the event loop so the two tasks interleave before reading back.
            await asyncio.sleep(0.01)
            results[name] = get_run_logger()
            return rl

        async def run_concurrent():
            return await asyncio.gather(task("A"), task("B"))

        rl_a, rl_b = asyncio.run(run_concurrent())
        # Each task should get back its own logger, not the other's
        self.assertIs(results["A"], rl_a)
        self.assertIs(results["B"], rl_b)
        # They should be different instances
        self.assertIsNot(rl_a, rl_b)
# ---------------------------------------------------------------------------
# Fix 3: list_pm_decisions excludes _id
# ---------------------------------------------------------------------------
class TestListPmDecisionsExcludesId(unittest.TestCase):
    """list_pm_decisions must project away Mongo's ``_id`` field via ``{_id: 0}``."""

    def test_projection_excludes_object_id(self):
        """The projection handed to ``find`` is exactly ``{"_id": 0}``."""
        with patch("tradingagents.portfolio.mongo_report_store.MongoClient") as client_cls:
            # client[db][collection] resolves to the same collection mock.
            collection = MagicMock()
            database = MagicMock()
            database.__getitem__ = MagicMock(return_value=collection)
            client = MagicMock()
            client.__getitem__ = MagicMock(return_value=database)
            client_cls.return_value = client

            from tradingagents.portfolio.mongo_report_store import MongoReportStore

            store = MongoReportStore("mongodb://localhost:27017", run_id="test")
            store._col = collection
            collection.find.return_value = []
            store.list_pm_decisions("pid-123")

            find_args, find_kwargs = collection.find.call_args
            # The projection may arrive positionally (2nd argument) or by keyword.
            if len(find_args) > 1:
                projection = find_args[1]
            else:
                projection = find_kwargs.get("projection")
            self.assertEqual(projection, {"_id": 0})
# ---------------------------------------------------------------------------
# Fix 4: ReflexionMemory created_at is native datetime for MongoDB
# ---------------------------------------------------------------------------
class TestReflexionCreatedAtType(unittest.TestCase):
    """Verify created_at is native datetime for MongoDB, ISO string for local."""

    def test_mongodb_path_stores_native_datetime(self):
        """When writing to MongoDB, created_at should be a datetime object."""
        with patch("tradingagents.memory.reflexion.MongoClient", create=True) as mock_client_cls, \
                tempfile.TemporaryDirectory() as tmpdir:
            mock_col = MagicMock()
            mock_db = MagicMock()
            mock_db.__getitem__ = MagicMock(return_value=mock_col)
            mock_client = MagicMock()
            mock_client.__getitem__ = MagicMock(return_value=mock_db)
            mock_client_cls.return_value = mock_client
            from tradingagents.memory.reflexion import ReflexionMemory

            # Bypass __init__ and wire the collection mock in directly.
            mem = ReflexionMemory.__new__(ReflexionMemory)
            mem._col = mock_col
            # Use a throwaway directory rather than a hard-coded /tmp path, which
            # is not portable and could leave a file behind if the fallback is hit.
            mem._fallback_path = Path(tmpdir) / "test_reflexion.json"
            mem.record_decision("AAPL", "2026-03-20", "BUY", "test", "high")
            doc = mock_col.insert_one.call_args[0][0]
            self.assertIsInstance(doc["created_at"], datetime)

    def test_local_path_stores_iso_string(self):
        """When writing to local JSON, created_at should be an ISO string."""
        from tradingagents.memory.reflexion import ReflexionMemory

        with tempfile.TemporaryDirectory() as tmpdir:
            fb_path = Path(tmpdir) / "test_reflexion.json"
            mem = ReflexionMemory(fallback_path=fb_path)
            mem.record_decision("AAPL", "2026-03-20", "BUY", "test", "high")
            data = json.loads(fb_path.read_text())
            self.assertIsInstance(data[0]["created_at"], str)
            # Should be parseable as ISO datetime
            datetime.fromisoformat(data[0]["created_at"])
# ---------------------------------------------------------------------------
# Fix 5: write/read_latest_pointer respects base_dir parameter
# ---------------------------------------------------------------------------
class TestLatestPointerBaseDir(unittest.TestCase):
    """write_latest_pointer / read_latest_pointer must honour their ``base_dir``."""

    @staticmethod
    def _pointer_path(base: Path, date: str) -> Path:
        """Where the ``latest.json`` pointer for *date* is expected under *base*."""
        return base / "daily" / date / "latest.json"

    def test_pointer_uses_custom_base_dir(self):
        from tradingagents.report_paths import read_latest_pointer, write_latest_pointer

        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir) / "custom_reports"
            write_latest_pointer("2026-03-20", "run123", base_dir=base)

            # The pointer lands under the custom base, not REPORTS_ROOT.
            pointer = self._pointer_path(base, "2026-03-20")
            self.assertTrue(pointer.exists())
            self.assertEqual(json.loads(pointer.read_text())["run_id"], "run123")

            # Reading back through the same base resolves the same run.
            self.assertEqual(read_latest_pointer("2026-03-20", base_dir=base), "run123")

    def test_read_returns_none_with_wrong_base(self):
        from tradingagents.report_paths import read_latest_pointer, write_latest_pointer

        with tempfile.TemporaryDirectory() as tmpdir:
            root = Path(tmpdir)
            write_latest_pointer("2026-03-20", "run_a", base_dir=root / "a")
            # A pointer written under "a" must be invisible when reading from "b".
            self.assertIsNone(read_latest_pointer("2026-03-20", base_dir=root / "b"))

    def test_report_store_passes_base_dir(self):
        """ReportStore should forward its own base dir to the pointer helpers."""
        from tradingagents.portfolio.report_store import ReportStore

        with tempfile.TemporaryDirectory() as tmpdir:
            base = Path(tmpdir) / "custom"
            store = ReportStore(base_dir=base, run_id="abc123")
            # Saving a scan is what triggers the latest-pointer update.
            store.save_scan("2026-03-20", {"test": True})

            pointer = self._pointer_path(base, "2026-03-20")
            self.assertTrue(pointer.exists())
            self.assertEqual(json.loads(pointer.read_text())["run_id"], "abc123")
# ---------------------------------------------------------------------------
# Fix 6: RunLogger callback wired into astream_events config
# ---------------------------------------------------------------------------
class TestRunLoggerCallbackWiring(unittest.TestCase):
    """Each engine entry point must hand astream_events the RunLogger callback via config."""

    def _make_mock_graph(self, final_state):
        """Return ``(graph, captured)``: a fake graph plus the config dict it records.

        The fake ``astream_events`` copies its ``config`` keyword argument into
        ``captured`` and then yields one root ``on_chain_end`` event carrying
        *final_state*.
        """
        seen_config = {}

        async def recording_astream(*args, **kwargs):
            seen_config.update(kwargs.get("config", {}))
            yield _root_chain_end_event(final_state)

        graph = MagicMock()
        graph.astream_events = recording_astream
        return graph, seen_config

    @staticmethod
    def _fake_output_dir():
        """Path-like mock whose ``/`` yields another Path mock and whose ``mkdir`` is a no-op mock."""
        out_dir = MagicMock(spec=Path)
        out_dir.__truediv__ = MagicMock(return_value=MagicMock(spec=Path))
        out_dir.mkdir = MagicMock()
        return out_dir

    @staticmethod
    def _fake_daily_dir():
        """Path-like mock that, like every child path under it, reports ``exists() is False``."""
        daily = MagicMock(spec=Path)
        daily.exists.return_value = False
        child = MagicMock(spec=Path, exists=MagicMock(return_value=False))
        daily.__truediv__ = MagicMock(return_value=child)
        return daily

    def _assert_single_callback(self, captured):
        """The captured config carries exactly one callback."""
        self.assertIn("callbacks", captured)
        self.assertEqual(len(captured["callbacks"]), 1)

    def test_run_scan_wires_callback(self):
        from agent_os.backend.services.langgraph_engine import LangGraphEngine

        graph, captured = self._make_mock_graph({
            "geopolitical_report": "",
            "market_movers_report": "",
            "sector_performance_report": "",
            "industry_deep_dive_report": "",
            "macro_scan_summary": "",
        })
        scanner = MagicMock()
        scanner.graph = graph
        engine = LangGraphEngine()
        store = MagicMock()

        with patch("agent_os.backend.services.langgraph_engine.ScannerGraph", return_value=scanner), \
                patch("agent_os.backend.services.langgraph_engine.create_report_store", return_value=store), \
                patch("agent_os.backend.services.langgraph_engine.get_market_dir") as get_market_dir, \
                patch("agent_os.backend.services.langgraph_engine.append_to_digest"), \
                patch("agent_os.backend.services.langgraph_engine.extract_json", return_value={}):
            get_market_dir.return_value = self._fake_output_dir()
            asyncio.run(_collect(engine.run_scan("run1", {"date": "2026-01-01"})))

        self._assert_single_callback(captured)

    def test_run_pipeline_wires_callback(self):
        from agent_os.backend.services.langgraph_engine import LangGraphEngine

        graph, captured = self._make_mock_graph({"final_trade_decision": "BUY"})
        propagator = MagicMock()
        propagator.max_recur_limit = 100
        propagator.create_initial_state.return_value = {"ticker": "AAPL"}
        wrapper = MagicMock()
        wrapper.graph = graph
        wrapper.propagator = propagator
        engine = LangGraphEngine()
        store = MagicMock()

        with patch("agent_os.backend.services.langgraph_engine.TradingAgentsGraph", return_value=wrapper), \
                patch("agent_os.backend.services.langgraph_engine.create_report_store", return_value=store), \
                patch("agent_os.backend.services.langgraph_engine.get_ticker_dir") as get_ticker_dir, \
                patch("agent_os.backend.services.langgraph_engine.append_to_digest"):
            get_ticker_dir.return_value = self._fake_output_dir()
            asyncio.run(_collect(engine.run_pipeline("run1", {
                "ticker": "AAPL",
                "date": "2026-01-01",
            })))

        self._assert_single_callback(captured)
        # The recursion limit from the propagator must survive alongside the callback.
        self.assertEqual(captured["recursion_limit"], 100)

    def test_run_portfolio_wires_callback(self):
        from agent_os.backend.services.langgraph_engine import LangGraphEngine

        graph, captured = self._make_mock_graph({
            "holding_reviews": "",
            "risk_metrics": "",
            "pm_decision": "",
            "execution_result": "",
        })
        portfolio_graph = MagicMock()
        portfolio_graph.graph = graph
        engine = LangGraphEngine()
        store = MagicMock()
        store.load_scan.return_value = {}
        store.load_analysis.return_value = None

        with patch("agent_os.backend.services.langgraph_engine.PortfolioGraph", return_value=portfolio_graph), \
                patch("agent_os.backend.services.langgraph_engine.create_report_store", return_value=store), \
                patch("agent_os.backend.services.langgraph_engine.get_daily_dir") as get_daily_dir, \
                patch("agent_os.backend.services.langgraph_engine.append_to_digest"):
            get_daily_dir.return_value = self._fake_daily_dir()
            asyncio.run(_collect(engine.run_portfolio("run1", {
                "date": "2026-01-01",
                "portfolio_id": "pid-123",
            })))

        self._assert_single_callback(captured)
# ---------------------------------------------------------------------------
# Fix 7: ensure_indexes called in MongoReportStore.__init__
# ---------------------------------------------------------------------------
class TestEnsureIndexesInInit(unittest.TestCase):
    """Verify MongoReportStore.ensure_indexes creates the expected indexes.

    The class and test names predate the switch to lazy index creation.
    Indexes are no longer built in ``__init__``, so the test calls
    ``ensure_indexes()`` explicitly. The names are kept so existing test
    selectors keep working.
    """

    def test_init_calls_ensure_indexes(self):
        """ensure_indexes() by itself issues at least four create_index calls."""
        with patch("tradingagents.portfolio.mongo_report_store.MongoClient") as mock_client_cls:
            # client[db][collection] resolves to the same collection mock.
            mock_col = MagicMock()
            mock_db = MagicMock()
            mock_db.__getitem__ = MagicMock(return_value=mock_col)
            mock_client = MagicMock()
            mock_client.__getitem__ = MagicMock(return_value=mock_db)
            mock_client_cls.return_value = mock_client
            from tradingagents.portfolio.mongo_report_store import MongoReportStore

            store = MongoReportStore("mongodb://localhost:27017", run_id="test")
            # Discard any create_index calls made during construction, so the
            # count below reflects ensure_indexes() alone rather than __init__.
            mock_col.create_index.reset_mock()
            store.ensure_indexes()
            self.assertGreaterEqual(mock_col.create_index.call_count, 4)
if __name__ == "__main__":
    # Allow running this module directly, in addition to test-runner discovery.
    unittest.main()