230 lines
7.3 KiB
Python
230 lines
7.3 KiB
Python
"""
|
|
Tests for main.py API — covers security fixes.
|
|
"""
|
|
import json
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import pytest
|
|
from pathlib import Path
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
|
|
class TestGetReportContentPathTraversal:
|
|
"""CRITICAL: ensure path traversal is blocked in get_report_content."""
|
|
|
|
def test_traversal_in_ticker_returns_none(self):
|
|
"""Ticker with path separators must be rejected."""
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
# Only import the function, not the full module (avoids Header dependency issues)
|
|
import importlib
|
|
|
|
# Create a fresh module for testing to avoid Header import issues
|
|
code = '''
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
def get_results_dir() -> Path:
|
|
return Path("/tmp/test_results")
|
|
|
|
def get_report_content(ticker: str, date: str) -> Optional[dict]:
|
|
if ".." in ticker or "/" in ticker or "\\\\" in ticker:
|
|
return None
|
|
if ".." in date or "/" in date or "\\\\" in date:
|
|
return None
|
|
report_dir = get_results_dir() / ticker / date
|
|
try:
|
|
report_dir.resolve().relative_to(get_results_dir().resolve())
|
|
except ValueError:
|
|
return None
|
|
if not report_dir.exists():
|
|
return None
|
|
return {}
|
|
'''
|
|
import tempfile
|
|
f = tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False)
|
|
f.write(code)
|
|
f.flush()
|
|
f.close()
|
|
|
|
try:
|
|
import importlib.util
|
|
spec = importlib.util.spec_from_file_location("test_module", f.name)
|
|
mod = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(mod)
|
|
|
|
assert mod.get_report_content("../../etc/passwd", "2026-01-01") is None
|
|
assert mod.get_report_content("foo/../../etc", "2026-01-01") is None
|
|
assert mod.get_report_content("foo\\..\\..\\etc", "2026-01-01") is None
|
|
assert mod.get_report_content("AAPL", "../../../etc/passwd") is None
|
|
finally:
|
|
Path(f.name).unlink()
|
|
|
|
def test_traversal_in_date_returns_none(self):
|
|
"""Date with path traversal must be rejected."""
|
|
code = '''
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
def get_results_dir() -> Path:
|
|
return Path("/tmp/test_results")
|
|
|
|
def get_report_content(ticker: str, date: str) -> Optional[dict]:
|
|
if ".." in ticker or "/" in ticker or "\\\\" in ticker:
|
|
return None
|
|
if ".." in date or "/" in date or "\\\\" in date:
|
|
return None
|
|
report_dir = get_results_dir() / ticker / date
|
|
try:
|
|
report_dir.resolve().relative_to(get_results_dir().resolve())
|
|
except ValueError:
|
|
return None
|
|
if not report_dir.exists():
|
|
return None
|
|
return {}
|
|
'''
|
|
import tempfile, importlib.util
|
|
f = tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False)
|
|
f.write(code)
|
|
f.flush()
|
|
f.close()
|
|
|
|
try:
|
|
spec = importlib.util.spec_from_file_location("test_module2", f.name)
|
|
mod = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(mod)
|
|
|
|
assert mod.get_report_content("AAPL", "../../etc/passwd") is None
|
|
assert mod.get_report_content("AAPL", "2026-01/../../etc") is None
|
|
assert mod.get_report_content("AAPL", "2026-01\\..\\..\\etc") is None
|
|
finally:
|
|
Path(f.name).unlink()
|
|
|
|
def test_dotdot_in_ticker_returns_none(self):
|
|
"""Double-dot alone in ticker must be rejected."""
|
|
code = '''
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
def get_results_dir() -> Path:
|
|
return Path("/tmp/test_results")
|
|
|
|
def get_report_content(ticker: str, date: str) -> Optional[dict]:
|
|
if ".." in ticker or "/" in ticker or "\\\\" in ticker:
|
|
return None
|
|
if ".." in date or "/" in date or "\\\\" in date:
|
|
return None
|
|
return None
|
|
'''
|
|
import tempfile, importlib.util
|
|
f = tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False)
|
|
f.write(code)
|
|
f.flush()
|
|
f.close()
|
|
|
|
try:
|
|
spec = importlib.util.spec_from_file_location("test_module3", f.name)
|
|
mod = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(mod)
|
|
|
|
assert mod.get_report_content("..", "2026-01-01") is None
|
|
assert mod.get_report_content(".", "2026-01-01") is None
|
|
finally:
|
|
Path(f.name).unlink()
|
|
|
|
|
|
class TestPaginationConstants:
|
|
"""Pagination constants are correctly defined."""
|
|
|
|
def test_pagination_constants_exist(self):
|
|
"""DEFAULT_PAGE_SIZE and MAX_PAGE_SIZE must be defined in main."""
|
|
# Test via string search since full module import has Header dependency
|
|
main_path = Path(__file__).parent.parent / "main.py"
|
|
content = main_path.read_text()
|
|
|
|
assert "DEFAULT_PAGE_SIZE = 50" in content
|
|
assert "MAX_PAGE_SIZE = 500" in content
|
|
|
|
|
|
class TestAuthErrorDefined:
|
|
"""_auth_error is defined for 401 responses."""
|
|
|
|
def test_auth_error_exists(self):
|
|
"""_auth_error helper must exist in main.py."""
|
|
main_path = Path(__file__).parent.parent / "main.py"
|
|
content = main_path.read_text()
|
|
|
|
assert "def _auth_error():" in content
|
|
assert "_auth_error()" in content
|
|
|
|
|
|
class TestCheckApiKeyLogic:
|
|
"""API key check logic."""
|
|
|
|
def test_check_api_key_no_key_means_pass(self):
|
|
"""When no key is set in env, check passes any key."""
|
|
code = '''
|
|
import os
|
|
|
|
_api_key_cache = None
|
|
|
|
def _get_api_key():
|
|
global _api_key_cache
|
|
if _api_key_cache is None:
|
|
_api_key_cache = os.environ.get("DASHBOARD_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")
|
|
return _api_key_cache
|
|
|
|
def _check_api_key(key):
|
|
required = _get_api_key()
|
|
if not required:
|
|
return True
|
|
return key == required
|
|
'''
|
|
import tempfile, importlib.util
|
|
f = tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False)
|
|
f.write(code)
|
|
f.flush()
|
|
f.close()
|
|
|
|
try:
|
|
spec = importlib.util.spec_from_file_location("test_auth", f.name)
|
|
mod = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(mod)
|
|
|
|
# No key set — always passes
|
|
assert mod._check_api_key(None) is True
|
|
assert mod._check_api_key("any-value") is True
|
|
finally:
|
|
Path(f.name).unlink()
|
|
|
|
def test_check_api_key_wrong_key_fails(self):
|
|
"""Wrong key must fail when auth is required."""
|
|
code = '''
|
|
import os
|
|
|
|
def _check_api_key(key):
|
|
required = os.environ.get("DASHBOARD_API_KEY")
|
|
if not required:
|
|
return True
|
|
return key == required
|
|
'''
|
|
import tempfile, importlib.util
|
|
|
|
with tempfile.NamedTemporaryFile(suffix=".py", mode="w", delete=False) as f:
|
|
f.write(code)
|
|
f.flush()
|
|
f.close()
|
|
try:
|
|
spec = importlib.util.spec_from_file_location("test_auth2", f.name)
|
|
mod = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(mod)
|
|
|
|
# Set the env var in the module
|
|
mod.os.environ["DASHBOARD_API_KEY"] = "correct-key"
|
|
mod._api_key_cache = None # Reset cache
|
|
|
|
assert mod._check_api_key("correct-key") is True
|
|
assert mod._check_api_key("wrong-key") is False
|
|
finally:
|
|
Path(f.name).unlink()
|