TradingAgents/tradingagents/dataflows/trending/sector_classifier.py

268 lines
6.9 KiB
Python

import logging
from typing import Dict
logger = logging.getLogger(__name__)
# Closed set of sector labels this module may return. The LLM fallback's
# reply is checked against this set before it is accepted.
VALID_SECTORS = {
    "technology",
    "healthcare",
    "finance",
    "energy",
    "consumer_goods",
    "industrials",
    "other",
}
# Hand-assigned tickers, grouped by sector. The order of the groups and the
# order of the symbols inside each group determine the insertion order of
# TICKER_TO_SECTOR below.
_TICKERS_BY_SECTOR: Dict[str, tuple] = {
    "technology": (
        "AAPL", "MSFT", "GOOGL", "GOOG", "AMZN", "META", "NVDA", "TSLA",
        "AMD", "INTC", "QCOM", "AVGO", "TXN", "ADBE", "CRM", "CSCO",
        "NFLX", "ORCL", "IBM", "NOW", "INTU", "ADSK", "SNPS", "CDNS",
        "PLTR", "SNOW", "DDOG", "CRWD", "OKTA", "NET", "MDB", "TWLO",
        "WDAY", "SPLK", "VMW", "HPQ", "DELL", "FTNT", "PANW", "ZS",
        "S", "VEEV", "ZM", "DOCU", "ASAN", "MNDY", "TEAM", "ANSS",
        "ROP",
    ),
    "finance": (
        "JPM", "BAC", "WFC", "GS", "MS", "C", "BLK", "SCHW",
        "AXP", "V", "MA", "PYPL", "SQ", "COIN", "HOOD", "SOFI",
        "AFRM", "MQ", "BRK-B", "BRK-A",
    ),
    "healthcare": (
        "JNJ", "UNH", "PFE", "ABBV", "MRK", "LLY", "MRNA", "BNTX",
        "CVS", "WBA", "MCK", "CAH", "HUM", "CI", "ELV",
    ),
    "energy": (
        "XOM", "CVX", "COP", "SLB", "HAL", "BKR", "MPC", "VLO",
        "PSX", "OXY", "PXD", "DVN", "CEG", "NEE", "DUK", "SO",
        "D", "SRE",
    ),
    "consumer_goods": (
        "WMT", "COST", "TGT", "HD", "LOW", "PG", "KO", "PEP",
        "NKE", "SBUX", "MCD", "CMG", "YUM", "DPZ", "DIS", "CMCSA",
        "VZ", "T", "TMUS", "EL", "CL", "KMB", "CLX", "KHC",
        "GIS", "K", "MDLZ", "HSY", "TSN", "BYND", "CAG", "STZ",
        "BUD", "DEO", "PM", "MO", "LULU", "DG", "DLTR", "ROST",
        "TJX", "AZO", "ORLY", "KMX", "ADDYY", "UBER", "LYFT", "ABNB",
        "DASH", "SNAP", "PINS", "TWTR", "SHOP", "TOST",
    ),
    "industrials": (
        "BA", "LMT", "RTX", "GD", "NOC", "GE", "HON", "MMM",
        "CAT", "DE", "UNP", "UPS", "FDX", "DAL", "UAL", "AAL",
        "LUV", "F", "GM", "TM", "HMC", "VWAGY", "RACE", "RIVN",
        "LCID", "NIO", "LNVGY",
    ),
}

# Static lookup table: upper-case ticker symbol -> sector label.
TICKER_TO_SECTOR: Dict[str, str] = {
    symbol: sector
    for sector, symbols in _TICKERS_BY_SECTOR.items()
    for symbol in symbols
}
# In-process memo of sectors resolved through the LLM fallback, keyed by
# upper-cased ticker. classify_sector() also stores "other" here when the
# fallback raises; nothing in the code visible here evicts entries.
_sector_cache: Dict[str, str] = {}
def _llm_classify_sector(ticker: str) -> str:
    """Ask the configured quick-think LLM which sector *ticker* belongs to.

    The model and endpoint come from ``DEFAULT_CONFIG`` (``quick_think_llm``
    and ``backend_url``). The client is always ``ChatOpenAI``; the
    ``llm_provider`` setting is not consulted here.

    Args:
        ticker: Ticker symbol to classify; it is inserted verbatim into the
            prompt.

    Returns:
        One of ``VALID_SECTORS``. A reply that does not match any of them,
        even after normalization, is logged and mapped to ``"other"``.

    Raises:
        Exception: Import errors and any error raised by the LLM call
            propagate to the caller unchanged.
    """
    # Imported inside the function, so these packages are only loaded when
    # the fallback actually runs.
    from langchain_openai import ChatOpenAI
    from langchain_core.messages import HumanMessage, SystemMessage
    from tradingagents.default_config import DEFAULT_CONFIG

    llm = ChatOpenAI(
        model=DEFAULT_CONFIG.get("quick_think_llm", "gpt-4o-mini"),
        base_url=DEFAULT_CONFIG.get("backend_url", "https://api.openai.com/v1"),
        temperature=0,
    )

    system_prompt = (
        "You are a financial sector classifier. Given a stock ticker symbol, "
        "classify it into exactly one of the following sectors: "
        "technology, healthcare, finance, energy, consumer_goods, industrials, other. "
        "Respond with only the sector name in lowercase, nothing else."
    )
    user_prompt = f"Classify the stock ticker: {ticker}"
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_prompt),
    ]

    response = llm.invoke(messages)
    reply = response.content.strip()

    # Tolerate cosmetic deviations from the requested format, such as
    # "Technology.", "`finance`" or "consumer goods", instead of discarding
    # an otherwise usable answer.
    sector = (
        reply.strip("\"'`.")
        .strip()
        .lower()
        .replace(" ", "_")
        .replace("-", "_")
    )

    if sector not in VALID_SECTORS:
        logger.warning(
            "LLM returned invalid sector '%s' for ticker %s, defaulting to 'other'",
            reply,
            ticker,
        )
        return "other"
    return sector
def classify_sector(ticker: str) -> str:
    """Return the sector label for *ticker*.

    Lookup order:
      1. the static ``TICKER_TO_SECTOR`` table,
      2. the in-process ``_sector_cache``,
      3. the LLM fallback ``_llm_classify_sector``.

    The ticker is matched case-insensitively and surrounding whitespace is
    ignored. Whatever the fallback yields is cached, including the
    ``"other"`` used when the fallback raises, so a ticker triggers at most
    one LLM call per process.

    Args:
        ticker: Ticker symbol, in any letter case.

    Returns:
        A sector label; ``"other"`` for a blank ticker or when the LLM
        fallback fails.
    """
    ticker_upper = ticker.strip().upper()
    if not ticker_upper:
        # Nothing to look up; do not spend an LLM call or a cache slot on it.
        return "other"

    known = TICKER_TO_SECTOR.get(ticker_upper)
    if known is not None:
        return known

    cached = _sector_cache.get(ticker_upper)
    if cached is not None:
        return cached

    logger.info("Using LLM fallback for sector classification of ticker: %s", ticker)
    try:
        sector = _llm_classify_sector(ticker_upper)
    except Exception as e:
        # Best-effort boundary: classification must never break the caller.
        logger.error("LLM sector classification failed for %s: %s", ticker, e)
        sector = "other"
    else:
        logger.info("Classified %s as %s via LLM", ticker, sector)

    # Failures are cached as well, so a ticker that cannot be classified
    # does not trigger a new LLM call on every lookup.
    _sector_cache[ticker_upper] = sector
    return sector