"""Tests for observability integration in LangGraphEngine.

Covers:
- RunLogger lifecycle (_start_run_logger / _finish_run_logger)
- Enriched tool events (service, status, error fields)
- Run log JSONL persistence
"""
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import time
|
|
import unittest
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
# Put the directory two levels above this file's directory at the front of
# sys.path (once) so the project imports that follow can be resolved.
_project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)
|
|
|
|
from agent_os.backend.services.langgraph_engine import (
|
|
LangGraphEngine,
|
|
_TOOL_SERVICE_MAP,
|
|
)
|
|
from tradingagents.observability import RunLogger, get_run_logger, set_run_logger
|
|
|
|
|
|
class TestToolServiceMap(unittest.TestCase):
    """Sanity checks on the static ``_TOOL_SERVICE_MAP`` lookup table."""

    def test_known_tools_have_services(self):
        """Spot-check that selected tool names map to the expected service."""
        expected_pairs = (
            ("get_stock_data", "yfinance"),
            ("get_insider_transactions", "finnhub"),
            ("get_insider_buying_stocks", "finviz"),
            ("get_enriched_holdings", "local"),
        )
        for tool_name, service_name in expected_pairs:
            self.assertEqual(_TOOL_SERVICE_MAP[tool_name], service_name)

    def test_map_is_non_empty(self):
        """The map is expected to hold more than 20 entries."""
        entry_count = len(_TOOL_SERVICE_MAP)
        self.assertGreater(entry_count, 20)
|
|
|
|
|
|
class TestRunLoggerLifecycle(unittest.TestCase):
    """Exercise ``_start_run_logger`` and ``_finish_run_logger`` on the engine."""

    def setUp(self):
        """Create a fresh engine and reset the current run logger."""
        self.engine = LangGraphEngine()
        # Drop any run logger left behind in thread-local state by earlier tests.
        set_run_logger(None)

    def tearDown(self):
        """Reset the current run logger so later tests start clean."""
        set_run_logger(None)

    def test_start_creates_logger_and_sets_thread_local(self):
        """Starting a run yields a RunLogger that is tracked and made current."""
        run_id = "test-run-1"
        run_logger = self.engine._start_run_logger(run_id)

        self.assertIsInstance(run_logger, RunLogger)
        self.assertIs(self.engine._run_loggers.get(run_id), run_logger)
        self.assertIs(get_run_logger(), run_logger)

    def test_finish_writes_log_and_cleans_up(self):
        """Finishing a run untracks the logger and writes ``run_log.jsonl``."""
        run_id = "test-run-2"
        run_logger = self.engine._start_run_logger(run_id)
        # Record one synthetic tool call so the log holds an event line.
        run_logger.log_tool_call("get_stock_data", "AAPL", True, 123.4)

        # All file checks stay inside the context manager: the directory (and
        # the log written into it) is removed as soon as the block exits.
        with tempfile.TemporaryDirectory() as tmpdir:
            log_dir = Path(tmpdir) / "sub"
            self.engine._finish_run_logger(run_id, log_dir)

            # The engine no longer tracks the run.
            self.assertNotIn(run_id, self.engine._run_loggers)
            # No run logger is current any more.
            self.assertIsNone(get_run_logger())

            # The JSONL log exists in the requested directory.
            log_file = log_dir / "run_log.jsonl"
            self.assertTrue(log_file.exists())
            records = log_file.read_text().strip().split("\n")
            self.assertGreaterEqual(len(records), 2)  # event + summary

            # First record: the synthetic tool event.
            first_record = json.loads(records[0])
            self.assertEqual(first_record["kind"], "tool")
            self.assertEqual(first_record["tool"], "get_stock_data")

            # Final record: the run summary.
            last_record = json.loads(records[-1])
            self.assertEqual(last_record["kind"], "summary")
            self.assertEqual(last_record["tool_calls"], 1)

    def test_finish_noop_for_unknown_run(self):
        """Finishing an unknown run ID must not raise and must write nothing."""
        with tempfile.TemporaryDirectory() as tmpdir:
            target_dir = Path(tmpdir)
            self.engine._finish_run_logger("nonexistent", target_dir)
            # The directory must still be empty afterwards.
            leftovers = list(target_dir.iterdir())
            self.assertEqual(leftovers, [])
|
|
|
|
def test_finish_noop_for_unknown_run(self):
|
|
"""_finish_run_logger should silently do nothing for unknown run IDs."""
|
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
self.engine._finish_run_logger("nonexistent", Path(tmpdir))
|
|
# No file written, no crash
|
|
self.assertEqual(list(Path(tmpdir).iterdir()), [])
|
|
|
|
|
|
class TestToolEventMapping(unittest.TestCase):
    """Check the enriched tool events produced by ``_map_langgraph_event``."""

    def setUp(self):
        """Create an engine and seed the per-run state used by these tests."""
        self.engine = LangGraphEngine()
        self.run_id = "test-tool-run"
        self.engine._node_start_times[self.run_id] = {}
        self.engine._run_identifiers[self.run_id] = "AAPL"
        self.engine._node_prompts[self.run_id] = {}

    def tearDown(self):
        """Remove the per-run state seeded in ``setUp``."""
        per_run_state = (
            self.engine._node_start_times,
            self.engine._run_identifiers,
            self.engine._node_prompts,
        )
        for registry in per_run_state:
            registry.pop(self.run_id, None)

    @staticmethod
    def _tool_event(kind, tool_name, data, node):
        """Build a LangGraph-style tool event dict for the given node."""
        return {
            "event": kind,
            "name": tool_name,
            "data": data,
            "run_id": "abc123",
            "metadata": {"langgraph_node": node},
        }

    def _map(self, event):
        """Map *event* through the engine and require a non-None result."""
        mapped = self.engine._map_langgraph_event(self.run_id, event)
        self.assertIsNotNone(mapped)
        return mapped

    def test_tool_start_includes_service(self):
        """A tool-start event for a known tool carries its service and status."""
        mapped = self._map(
            self._tool_event(
                "on_tool_start",
                "get_stock_data",
                {"input": {"ticker": "AAPL"}},
                "market_analyst",
            )
        )
        self.assertEqual(mapped["type"], "tool")
        self.assertEqual(mapped["service"], "yfinance")
        self.assertEqual(mapped["status"], "running")

    def test_tool_start_unknown_tool_has_empty_service(self):
        """A tool-start event for an unmapped tool gets an empty service."""
        mapped = self._map(
            self._tool_event(
                "on_tool_start",
                "some_custom_tool",
                {"input": "test"},
                "custom_node",
            )
        )
        self.assertEqual(mapped["service"], "")

    def test_tool_end_success(self):
        """Ordinary tool output is reported as a successful tool result."""
        tool_output = MagicMock(content="PE ratio: 25.3, Revenue: $100B")
        mapped = self._map(
            self._tool_event(
                "on_tool_end",
                "get_fundamentals",
                {"output": tool_output},
                "fundamentals_analyst",
            )
        )
        self.assertEqual(mapped["type"], "tool_result")
        self.assertEqual(mapped["status"], "success")
        self.assertEqual(mapped["service"], "yfinance")
        self.assertIsNone(mapped["error"])
        self.assertIn("✓", mapped["message"])

    def test_tool_end_error_detected(self):
        """Output text describing a tool error is reported with error status."""
        tool_output = MagicMock()
        tool_output.content = "Error calling get_stock_data: ConnectionError: timeout"
        mapped = self._map(
            self._tool_event(
                "on_tool_end",
                "get_stock_data",
                {"output": tool_output},
                "market_analyst",
            )
        )
        self.assertEqual(mapped["status"], "error")
        self.assertIn("Error", mapped["error"])
        self.assertIn("✗", mapped["message"])

    def test_tool_end_graceful_skip(self):
        """Output mentioning a graceful skip is reported as ``graceful_skip``."""
        tool_output = MagicMock()
        tool_output.content = "Data gracefully skipped due to rate limit"
        mapped = self._map(
            self._tool_event(
                "on_tool_end",
                "get_insider_transactions",
                {"output": tool_output},
                "news_analyst",
            )
        )
        self.assertEqual(mapped["status"], "graceful_skip")
        self.assertEqual(mapped["service"], "finnhub")
        self.assertIn("⚠", mapped["message"])

    def test_tool_end_event_status_error(self):
        """An event whose data carries status='error' is reported as an error."""
        mapped = self._map(
            self._tool_event(
                "on_tool_end",
                "get_earnings_calendar",
                {"output": MagicMock(content=""), "status": "error"},
                "sector_scanner",
            )
        )
        self.assertEqual(mapped["status"], "error")
        self.assertEqual(mapped["service"], "finnhub")
|
|
|
|
|
|
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|