95 lines
3.4 KiB
Python
95 lines
3.4 KiB
Python
import pytest
|
|
from unittest.mock import patch, MagicMock
|
|
from tradingagents.dataflows.trending.sector_classifier import (
|
|
classify_sector,
|
|
TICKER_TO_SECTOR,
|
|
VALID_SECTORS,
|
|
_llm_classify_sector,
|
|
_sector_cache,
|
|
)
|
|
|
|
|
|
class TestStaticSectorMapping:
    """Tickers expected to be served by the static ticker-to-sector table."""

    @staticmethod
    def _expect_sector(sector, *tickers):
        # Check the tickers in the given order; the first mismatch fails the
        # test, just as a run of individual asserts would.
        for symbol in tickers:
            assert classify_sector(symbol) == sector

    def test_static_sector_mapping_for_known_technology_tickers(self):
        self._expect_sector("technology", "AAPL", "MSFT", "GOOGL", "NVDA")

    def test_static_sector_mapping_for_known_healthcare_tickers(self):
        self._expect_sector("healthcare", "JNJ", "PFE", "UNH")

    def test_static_sector_mapping_for_known_finance_tickers(self):
        self._expect_sector("finance", "JPM", "BAC", "GS")

    def test_static_sector_mapping_for_known_energy_tickers(self):
        self._expect_sector("energy", "XOM", "CVX", "COP")

    def test_static_sector_mapping_case_insensitive(self):
        # Lower-, upper- and mixed-case spellings must all classify the same.
        self._expect_sector("technology", "aapl", "AAPL", "Aapl")
|
|
|
|
|
|
class TestLLMFallback:
    """Tests for the LLM fallback used for tickers outside the static map.

    ``_llm_classify_sector`` is patched in every test, so no real model is
    called. ``_sector_cache`` is module-level state, so it is cleared both
    before AND after each test. Otherwise the mocked classifications written
    here (e.g. ``UNKNOWNTICKER123 -> technology``) would outlive the test and
    could leak into any later test that touches the same ticker.
    """

    # Patch the name inside the classifier module, where classify_sector
    # looks it up.
    _LLM_TARGET = (
        "tradingagents.dataflows.trending.sector_classifier._llm_classify_sector"
    )

    def setup_method(self, method):
        _sector_cache.clear()

    def teardown_method(self, method):
        _sector_cache.clear()

    @patch(_LLM_TARGET)
    def test_llm_fallback_for_unknown_tickers(self, mock_llm_classify):
        mock_llm_classify.return_value = "technology"

        result = classify_sector("UNKNOWNTICKER123")

        mock_llm_classify.assert_called_once_with("UNKNOWNTICKER123")
        assert result == "technology"

    @patch(_LLM_TARGET)
    def test_llm_fallback_caches_results(self, mock_llm_classify):
        mock_llm_classify.return_value = "healthcare"

        result1 = classify_sector("NEWCO123")
        result2 = classify_sector("NEWCO123")

        # The second lookup must be answered from the cache, so the LLM is
        # consulted exactly once, and with the ticker being classified.
        mock_llm_classify.assert_called_once_with("NEWCO123")
        assert result1 == "healthcare"
        assert result2 == "healthcare"

    @patch(_LLM_TARGET)
    def test_llm_fallback_returns_other_on_error(self, mock_llm_classify):
        mock_llm_classify.side_effect = Exception("LLM error")

        result = classify_sector("ERRORCO")

        # Without this check the test would also pass if an unknown ticker
        # were mapped to "other" without ever reaching the LLM path, i.e.
        # the error handling would go untested.
        assert mock_llm_classify.called
        assert result == "other"
|
|
|
|
|
|
class TestAllSectorCategories:
    """Consistency checks between VALID_SECTORS, the static map and classify_sector."""

    def test_all_sector_categories_in_valid_sectors(self):
        expected_sectors = {
            "technology",
            "healthcare",
            "finance",
            "energy",
            "consumer_goods",
            "industrials",
            "other",
        }
        assert VALID_SECTORS == expected_sectors

    def test_static_mapping_covers_all_sector_categories(self):
        sectors_in_mapping = set(TICKER_TO_SECTOR.values())

        # Every value in the static map must be a recognised sector...
        assert sectors_in_mapping.issubset(VALID_SECTORS)

        # ...and, as the test name promises, every concrete category must be
        # reachable through the static map. "other" is exempt: it is the
        # catch-all result (e.g. when the LLM fallback errors), not a
        # category tickers are expected to be mapped to.
        missing = (VALID_SECTORS - {"other"}) - sectors_in_mapping
        assert not missing, f"sectors with no static ticker: {sorted(missing)}"

    def test_classify_sector_always_returns_valid_sector(self):
        test_tickers = ["AAPL", "JPM", "XOM", "JNJ", "WMT", "CAT"]
        for ticker in test_tickers:
            result = classify_sector(ticker)
            # Name the offending ticker so a failure inside the loop is
            # immediately attributable.
            assert result in VALID_SECTORS, f"{ticker!r} -> {result!r}"
|