268 lines
6.9 KiB
Python
268 lines
6.9 KiB
Python
import logging
|
|
from typing import Dict
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
# Closed set of sector labels this module can return. Replies from the LLM
# fallback are checked against this set; anything else is reported as "other".
VALID_SECTORS = {
    "technology",
    "healthcare",
    "finance",
    "energy",
    "consumer_goods",
    "industrials",
    "other",
}
|
|
|
|
# Static ticker -> sector table, consulted before any LLM call is made.
# Keys are upper-case ticker symbols; values are sector labels.
TICKER_TO_SECTOR: Dict[str, str] = {
    # --- technology ---
    "AAPL": "technology",
    "MSFT": "technology",
    "GOOGL": "technology",
    "GOOG": "technology",
    "AMZN": "technology",
    "META": "technology",
    "NVDA": "technology",
    "TSLA": "technology",
    "AMD": "technology",
    "INTC": "technology",
    "QCOM": "technology",
    "AVGO": "technology",
    "TXN": "technology",
    "ADBE": "technology",
    "CRM": "technology",
    "CSCO": "technology",
    "NFLX": "technology",
    "ORCL": "technology",
    "IBM": "technology",
    "NOW": "technology",
    "INTU": "technology",
    "ADSK": "technology",
    "SNPS": "technology",
    "CDNS": "technology",
    "PLTR": "technology",
    "SNOW": "technology",
    "DDOG": "technology",
    "CRWD": "technology",
    "OKTA": "technology",
    "NET": "technology",
    "MDB": "technology",
    "TWLO": "technology",
    "WDAY": "technology",
    "SPLK": "technology",
    "VMW": "technology",
    "HPQ": "technology",
    "DELL": "technology",
    "FTNT": "technology",
    "PANW": "technology",
    "ZS": "technology",
    "S": "technology",
    "VEEV": "technology",
    "ZM": "technology",
    "DOCU": "technology",
    "ASAN": "technology",
    "MNDY": "technology",
    "TEAM": "technology",
    "ANSS": "technology",
    "ROP": "technology",
    # --- finance ---
    "JPM": "finance",
    "BAC": "finance",
    "WFC": "finance",
    "GS": "finance",
    "MS": "finance",
    "C": "finance",
    "BLK": "finance",
    "SCHW": "finance",
    "AXP": "finance",
    "V": "finance",
    "MA": "finance",
    "PYPL": "finance",
    "SQ": "finance",
    "COIN": "finance",
    "HOOD": "finance",
    "SOFI": "finance",
    "AFRM": "finance",
    "MQ": "finance",
    "BRK-B": "finance",
    "BRK-A": "finance",
    # --- healthcare ---
    "JNJ": "healthcare",
    "UNH": "healthcare",
    "PFE": "healthcare",
    "ABBV": "healthcare",
    "MRK": "healthcare",
    "LLY": "healthcare",
    "MRNA": "healthcare",
    "BNTX": "healthcare",
    "CVS": "healthcare",
    "WBA": "healthcare",
    "MCK": "healthcare",
    "CAH": "healthcare",
    "HUM": "healthcare",
    "CI": "healthcare",
    "ELV": "healthcare",
    # --- energy ---
    "XOM": "energy",
    "CVX": "energy",
    "COP": "energy",
    "SLB": "energy",
    "HAL": "energy",
    "BKR": "energy",
    "MPC": "energy",
    "VLO": "energy",
    "PSX": "energy",
    "OXY": "energy",
    "PXD": "energy",
    "DVN": "energy",
    "CEG": "energy",
    "NEE": "energy",
    "DUK": "energy",
    "SO": "energy",
    "D": "energy",
    "SRE": "energy",
    # --- consumer_goods ---
    "WMT": "consumer_goods",
    "COST": "consumer_goods",
    "TGT": "consumer_goods",
    "HD": "consumer_goods",
    "LOW": "consumer_goods",
    "PG": "consumer_goods",
    "KO": "consumer_goods",
    "PEP": "consumer_goods",
    "NKE": "consumer_goods",
    "SBUX": "consumer_goods",
    "MCD": "consumer_goods",
    "CMG": "consumer_goods",
    "YUM": "consumer_goods",
    "DPZ": "consumer_goods",
    "DIS": "consumer_goods",
    "CMCSA": "consumer_goods",
    "VZ": "consumer_goods",
    "T": "consumer_goods",
    "TMUS": "consumer_goods",
    "EL": "consumer_goods",
    "CL": "consumer_goods",
    "KMB": "consumer_goods",
    "CLX": "consumer_goods",
    "KHC": "consumer_goods",
    "GIS": "consumer_goods",
    "K": "consumer_goods",
    "MDLZ": "consumer_goods",
    "HSY": "consumer_goods",
    "TSN": "consumer_goods",
    "BYND": "consumer_goods",
    "CAG": "consumer_goods",
    "STZ": "consumer_goods",
    "BUD": "consumer_goods",
    "DEO": "consumer_goods",
    "PM": "consumer_goods",
    "MO": "consumer_goods",
    "LULU": "consumer_goods",
    "DG": "consumer_goods",
    "DLTR": "consumer_goods",
    "ROST": "consumer_goods",
    "TJX": "consumer_goods",
    "AZO": "consumer_goods",
    "ORLY": "consumer_goods",
    "KMX": "consumer_goods",
    "ADDYY": "consumer_goods",
    "UBER": "consumer_goods",
    "LYFT": "consumer_goods",
    "ABNB": "consumer_goods",
    "DASH": "consumer_goods",
    "SNAP": "consumer_goods",
    "PINS": "consumer_goods",
    "TWTR": "consumer_goods",
    "SHOP": "consumer_goods",
    "TOST": "consumer_goods",
    # --- industrials ---
    "BA": "industrials",
    "LMT": "industrials",
    "RTX": "industrials",
    "GD": "industrials",
    "NOC": "industrials",
    "GE": "industrials",
    "HON": "industrials",
    "MMM": "industrials",
    "CAT": "industrials",
    "DE": "industrials",
    "UNP": "industrials",
    "UPS": "industrials",
    "FDX": "industrials",
    "DAL": "industrials",
    "UAL": "industrials",
    "AAL": "industrials",
    "LUV": "industrials",
    "F": "industrials",
    "GM": "industrials",
    "TM": "industrials",
    "HMC": "industrials",
    "VWAGY": "industrials",
    "RACE": "industrials",
    "RIVN": "industrials",
    "LCID": "industrials",
    "NIO": "industrials",
    # NOTE(review): LNVGY is presumably Lenovo; other PC makers in this table
    # (HPQ, DELL) are mapped to "technology" -- confirm "industrials" is intended.
    "LNVGY": "industrials",
}
|
|
|
|
# In-process memo of sectors obtained through the LLM fallback, keyed by
# upper-cased ticker. classify_sector also stores "other" here when the LLM
# call raises, so a failing ticker is not re-queried on later calls.
_sector_cache: Dict[str, str] = {}
|
|
|
|
|
|
def _normalize_sector_reply(content: object) -> str:
    """Reduce a raw LLM reply to a bare, lowercase, underscore-joined token.

    Handles the common near-misses of an otherwise correct answer: surrounding
    whitespace, quotes, markdown emphasis, trailing punctuation, and spaces or
    hyphens used instead of underscores (``"Consumer Goods."`` becomes
    ``"consumer_goods"``).
    """
    if isinstance(content, list):
        # Message content may arrive as a list of parts (plain strings or
        # dicts carrying a "text" entry) rather than a single string.
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict):
                pieces.append(str(part.get("text", "")))
        content = "".join(pieces)

    text = str(content).strip().lower()
    text = text.strip(" \t\r\n\"'`*.,:;!")
    # Collapse any run of whitespace / hyphens into a single underscore.
    return "_".join(text.replace("-", " ").split())


def _llm_classify_sector(ticker: str) -> str:
    """Ask the configured quick-think LLM which sector *ticker* belongs to.

    The model name and endpoint are read from ``DEFAULT_CONFIG``
    (``quick_think_llm`` and ``backend_url``). The reply is normalized and
    checked against ``VALID_SECTORS``.

    Args:
        ticker: Ticker symbol to classify; it is inserted verbatim into the
            user prompt.

    Returns:
        A member of ``VALID_SECTORS``. A reply that does not normalize to one
        of them is logged as a warning and reported as ``"other"``.

    Raises:
        Whatever the imports or the LLM client raise (missing packages,
        network or authentication errors); nothing is caught here.
    """
    # Function-local imports: the LLM packages are only loaded when this
    # fallback actually runs.
    from langchain_openai import ChatOpenAI
    from langchain_core.messages import HumanMessage, SystemMessage
    from tradingagents.default_config import DEFAULT_CONFIG

    llm_name = DEFAULT_CONFIG.get("quick_think_llm", "gpt-4o-mini")
    backend_url = DEFAULT_CONFIG.get("backend_url", "https://api.openai.com/v1")

    # NOTE: the client is always ChatOpenAI; the "llm_provider" config entry
    # is not consulted here.
    llm = ChatOpenAI(
        model=llm_name,
        base_url=backend_url,
        temperature=0,
    )

    system_prompt = (
        "You are a financial sector classifier. Given a stock ticker symbol, "
        "classify it into exactly one of the following sectors: "
        "technology, healthcare, finance, energy, consumer_goods, industrials, other. "
        "Respond with only the sector name in lowercase, nothing else."
    )
    user_prompt = f"Classify the stock ticker: {ticker}"

    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_prompt),
    ]

    response = llm.invoke(messages)
    sector = _normalize_sector_reply(response.content)

    if sector not in VALID_SECTORS:
        logger.warning(
            "LLM returned invalid sector '%s' for ticker %s, defaulting to 'other'",
            sector,
            ticker,
        )
        return "other"

    return sector
|
|
|
|
|
|
def classify_sector(ticker: str) -> str:
    """Return the sector label for *ticker*.

    Resolution order:

    1. the static ``TICKER_TO_SECTOR`` table;
    2. ``_sector_cache`` (results of earlier LLM lookups);
    3. a fresh ``_llm_classify_sector`` call, whose result is then cached.

    The symbol is stripped of surrounding whitespace and upper-cased before
    any lookup, so ``" aapl "`` resolves from the static table instead of
    triggering an LLM call.

    Args:
        ticker: Ticker symbol in any letter case.

    Returns:
        The sector label. A blank symbol yields ``"other"`` without
        contacting the LLM. If the LLM call raises, the error is logged,
        ``"other"`` is cached for the symbol and returned, so the failing
        lookup is not retried on later calls.
    """
    symbol = ticker.strip().upper()
    if not symbol:
        # Nothing to classify; avoid spending an LLM call on an empty prompt.
        return "other"

    known = TICKER_TO_SECTOR.get(symbol)
    if known is not None:
        return known

    cached = _sector_cache.get(symbol)
    if cached is not None:
        return cached

    logger.info("Using LLM fallback for sector classification of ticker: %s", symbol)

    try:
        sector = _llm_classify_sector(symbol)
    except Exception as e:
        # Best-effort fallback: record the failure (with traceback) and
        # degrade to "other" rather than propagating to the caller.
        logger.error(
            "LLM sector classification failed for %s: %s",
            symbol,
            e,
            exc_info=True,
        )
        sector = "other"
    else:
        logger.info("Classified %s as %s via LLM", symbol, sector)

    _sector_cache[symbol] = sector
    return sector
|