# TradingAgents/tests/discovery/test_sector_classifier.py

import pytest
from unittest.mock import patch, MagicMock
from tradingagents.dataflows.trending.sector_classifier import (
classify_sector,
TICKER_TO_SECTOR,
VALID_SECTORS,
_llm_classify_sector,
_sector_cache,
)
class TestStaticSectorMapping:
    """Tickers listed in the static ``TICKER_TO_SECTOR`` table.

    Every check runs with the LLM fallback patched out and then asserts the
    fallback was never consulted. A ticker that silently disappears from the
    static table therefore fails here, instead of being answered by a live
    LLM call that might happen to return the expected sector.
    """

    # Import path of the LLM fallback that classify_sector calls for
    # tickers it cannot resolve statically.
    _LLM_TARGET = (
        "tradingagents.dataflows.trending.sector_classifier._llm_classify_sector"
    )

    def _assert_static(self, tickers, expected):
        """Assert each ticker classifies as ``expected`` with no LLM call."""
        with patch(self._LLM_TARGET) as mock_llm:
            for ticker in tickers:
                # Include the ticker in the failure message so a loop
                # failure points at the offending symbol.
                assert classify_sector(ticker) == expected, ticker
        mock_llm.assert_not_called()

    def test_static_sector_mapping_for_known_technology_tickers(self):
        self._assert_static(("AAPL", "MSFT", "GOOGL", "NVDA"), "technology")

    def test_static_sector_mapping_for_known_healthcare_tickers(self):
        self._assert_static(("JNJ", "PFE", "UNH"), "healthcare")

    def test_static_sector_mapping_for_known_finance_tickers(self):
        self._assert_static(("JPM", "BAC", "GS"), "finance")

    def test_static_sector_mapping_for_known_energy_tickers(self):
        self._assert_static(("XOM", "CVX", "COP"), "energy")

    def test_static_sector_mapping_case_insensitive(self):
        # Lower, upper and mixed case must all hit the same static entry.
        self._assert_static(("aapl", "AAPL", "Aapl"), "technology")
class TestLLMFallback:
    """Tickers missing from the static table fall back to the LLM classifier.

    ``_sector_cache`` is module-level state shared by the whole test session.
    It is cleared before *and* after every test here, so the mocked
    classifications written by these tests never leak into other tests that
    happen to look up the same tickers.
    """

    def setup_method(self, method):
        """Start each test from an empty classification cache."""
        _sector_cache.clear()

    def teardown_method(self, method):
        """Discard any mocked classifications this test cached."""
        _sector_cache.clear()

    @patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector")
    def test_llm_fallback_for_unknown_tickers(self, mock_llm_classify):
        mock_llm_classify.return_value = "technology"
        result = classify_sector("UNKNOWNTICKER123")
        mock_llm_classify.assert_called_once_with("UNKNOWNTICKER123")
        assert result == "technology"

    @patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector")
    def test_llm_fallback_caches_results(self, mock_llm_classify):
        mock_llm_classify.return_value = "healthcare"
        result1 = classify_sector("NEWCO123")
        # The second lookup must be served from _sector_cache, not the LLM.
        result2 = classify_sector("NEWCO123")
        assert mock_llm_classify.call_count == 1
        assert result1 == "healthcare"
        assert result2 == "healthcare"

    @patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector")
    def test_llm_fallback_returns_other_on_error(self, mock_llm_classify):
        mock_llm_classify.side_effect = Exception("LLM error")
        result = classify_sector("ERRORCO")
        # Confirm the failing fallback was really exercised, so "other" is
        # not coming from some unrelated static-table hit.
        mock_llm_classify.assert_called_with("ERRORCO")
        assert result == "other"
class TestAllSectorCategories:
    """The sector vocabulary and the classifier's adherence to it."""

    def test_all_sector_categories_in_valid_sectors(self):
        expected_sectors = {
            "technology",
            "healthcare",
            "finance",
            "energy",
            "consumer_goods",
            "industrials",
            "other",
        }
        assert VALID_SECTORS == expected_sectors

    def test_static_mapping_covers_all_sector_categories(self):
        sectors_in_mapping = set(TICKER_TO_SECTOR.values())
        # No static entry may use a label outside the vocabulary.
        assert sectors_in_mapping.issubset(VALID_SECTORS)
        # As the test name promises, every concrete sector must be reachable
        # from the static table. "other" is excluded because it is the label
        # classify_sector returns when the LLM fallback fails, so the static
        # table need not list it.
        missing = VALID_SECTORS - {"other"} - sectors_in_mapping
        assert not missing, f"no static tickers for: {sorted(missing)}"

    def test_classify_sector_always_returns_valid_sector(self):
        # NOTE(review): WMT and CAT are assumed to be present in
        # TICKER_TO_SECTOR; if either is not, this lookup falls through to
        # the real LLM classifier and makes a live call — confirm.
        test_tickers = ["AAPL", "JPM", "XOM", "JNJ", "WMT", "CAT"]
        for ticker in test_tickers:
            result = classify_sector(ticker)
            assert result in VALID_SECTORS, (ticker, result)