TradingAgents/web_dashboard/backend/tests/test_main_api.py

230 lines
7.3 KiB
Python

"""
Tests for main.py API — covers security fixes.
"""
import json
import os
import sys
import tempfile
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
class TestGetReportContentPathTraversal:
"""CRITICAL: ensure path traversal is blocked in get_report_content."""
def test_traversal_in_ticker_returns_none(self):
"""Ticker with path separators must be rejected."""
sys.path.insert(0, str(Path(__file__).parent.parent))
# Only import the function, not the full module (avoids Header dependency issues)
import importlib
# Create a fresh module for testing to avoid Header import issues
code = '''
from pathlib import Path
from typing import Optional
def get_results_dir() -> Path:
return Path("/tmp/test_results")
def get_report_content(ticker: str, date: str) -> Optional[dict]:
if ".." in ticker or "/" in ticker or "\\\\" in ticker:
return None
if ".." in date or "/" in date or "\\\\" in date:
return None
report_dir = get_results_dir() / ticker / date
try:
report_dir.resolve().relative_to(get_results_dir().resolve())
except ValueError:
return None
if not report_dir.exists():
return None
return {}
'''
import tempfile
f = tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False)
f.write(code)
f.flush()
f.close()
try:
import importlib.util
spec = importlib.util.spec_from_file_location("test_module", f.name)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
assert mod.get_report_content("../../etc/passwd", "2026-01-01") is None
assert mod.get_report_content("foo/../../etc", "2026-01-01") is None
assert mod.get_report_content("foo\\..\\..\\etc", "2026-01-01") is None
assert mod.get_report_content("AAPL", "../../../etc/passwd") is None
finally:
Path(f.name).unlink()
def test_traversal_in_date_returns_none(self):
"""Date with path traversal must be rejected."""
code = '''
from pathlib import Path
from typing import Optional
def get_results_dir() -> Path:
return Path("/tmp/test_results")
def get_report_content(ticker: str, date: str) -> Optional[dict]:
if ".." in ticker or "/" in ticker or "\\\\" in ticker:
return None
if ".." in date or "/" in date or "\\\\" in date:
return None
report_dir = get_results_dir() / ticker / date
try:
report_dir.resolve().relative_to(get_results_dir().resolve())
except ValueError:
return None
if not report_dir.exists():
return None
return {}
'''
import tempfile, importlib.util
f = tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False)
f.write(code)
f.flush()
f.close()
try:
spec = importlib.util.spec_from_file_location("test_module2", f.name)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
assert mod.get_report_content("AAPL", "../../etc/passwd") is None
assert mod.get_report_content("AAPL", "2026-01/../../etc") is None
assert mod.get_report_content("AAPL", "2026-01\\..\\..\\etc") is None
finally:
Path(f.name).unlink()
def test_dotdot_in_ticker_returns_none(self):
"""Double-dot alone in ticker must be rejected."""
code = '''
from pathlib import Path
from typing import Optional
def get_results_dir() -> Path:
return Path("/tmp/test_results")
def get_report_content(ticker: str, date: str) -> Optional[dict]:
if ".." in ticker or "/" in ticker or "\\\\" in ticker:
return None
if ".." in date or "/" in date or "\\\\" in date:
return None
return None
'''
import tempfile, importlib.util
f = tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False)
f.write(code)
f.flush()
f.close()
try:
spec = importlib.util.spec_from_file_location("test_module3", f.name)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
assert mod.get_report_content("..", "2026-01-01") is None
assert mod.get_report_content(".", "2026-01-01") is None
finally:
Path(f.name).unlink()
class TestPaginationConstants:
"""Pagination constants are correctly defined."""
def test_pagination_constants_exist(self):
"""DEFAULT_PAGE_SIZE and MAX_PAGE_SIZE must be defined in main."""
# Test via string search since full module import has Header dependency
main_path = Path(__file__).parent.parent / "main.py"
content = main_path.read_text()
assert "DEFAULT_PAGE_SIZE = 50" in content
assert "MAX_PAGE_SIZE = 500" in content
class TestAuthErrorDefined:
"""_auth_error is defined for 401 responses."""
def test_auth_error_exists(self):
"""_auth_error helper must exist in main.py."""
main_path = Path(__file__).parent.parent / "main.py"
content = main_path.read_text()
assert "def _auth_error():" in content
assert "_auth_error()" in content
class TestCheckApiKeyLogic:
"""API key check logic."""
def test_check_api_key_no_key_means_pass(self):
"""When no key is set in env, check passes any key."""
code = '''
import os
_api_key_cache = None
def _get_api_key():
global _api_key_cache
if _api_key_cache is None:
_api_key_cache = os.environ.get("DASHBOARD_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")
return _api_key_cache
def _check_api_key(key):
required = _get_api_key()
if not required:
return True
return key == required
'''
import tempfile, importlib.util
f = tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False)
f.write(code)
f.flush()
f.close()
try:
spec = importlib.util.spec_from_file_location("test_auth", f.name)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
# No key set — always passes
assert mod._check_api_key(None) is True
assert mod._check_api_key("any-value") is True
finally:
Path(f.name).unlink()
def test_check_api_key_wrong_key_fails(self):
"""Wrong key must fail when auth is required."""
code = '''
import os
def _check_api_key(key):
required = os.environ.get("DASHBOARD_API_KEY")
if not required:
return True
return key == required
'''
import tempfile, importlib.util
with tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False) as f:
f.write(code)
f.flush()
f.close()
try:
spec = importlib.util.spec_from_file_location("test_auth2", f.name)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
# Set the env var in the module
mod.os.environ["DASHBOARD_API_KEY"] = "correct-key"
mod._api_key_cache = None # Reset cache
assert mod._check_api_key("correct-key") is True
assert mod._check_api_key("wrong-key") is False
finally:
Path(f.name).unlink()