New updates
This commit is contained in:
parent
948ad66343
commit
6ad4e0243d
|
|
@ -0,0 +1,524 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Dict, List, Tuple, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModelForCausalLM
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from peft import PeftModel
|
||||||
|
|
||||||
|
|
||||||
|
def load_sft_model(model_name_or_path: str):
|
||||||
|
"""
|
||||||
|
Load the fine-tuned (SFT) sequence classification model and tokenizer.
|
||||||
|
"""
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
|
||||||
|
model.eval()
|
||||||
|
return tokenizer, model
|
||||||
|
|
||||||
|
|
||||||
|
def classify_with_confidence(
|
||||||
|
tokenizer: AutoTokenizer,
|
||||||
|
model: AutoModelForSequenceClassification,
|
||||||
|
texts: List[str],
|
||||||
|
) -> List[Tuple[str, float]]:
|
||||||
|
"""
|
||||||
|
Run sentiment classification and return (label, confidence) for each text.
|
||||||
|
Confidence is defined as max softmax(logits).
|
||||||
|
"""
|
||||||
|
results: List[Tuple[str, float]] = []
|
||||||
|
|
||||||
|
# Batch to speed up a bit
|
||||||
|
batch_size = 16
|
||||||
|
id2label = getattr(model.config, "id2label", None)
|
||||||
|
if not id2label:
|
||||||
|
# Align with finetune_dapt.py label set: Negative, Neutral, Positive
|
||||||
|
id2label = {0: "Negative", 1: "Neutral", 2: "Positive"}
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
for start in range(0, len(texts), batch_size):
|
||||||
|
chunk = texts[start : start + batch_size]
|
||||||
|
enc = tokenizer(
|
||||||
|
chunk,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=256,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
enc = {k: v.to(device) for k, v in enc.items()}
|
||||||
|
with torch.no_grad():
|
||||||
|
out = model(**enc)
|
||||||
|
logits = out.logits # [batch, num_labels]
|
||||||
|
probs = torch.softmax(logits, dim=-1)
|
||||||
|
confidences, indices = torch.max(probs, dim=-1)
|
||||||
|
for idx in range(len(chunk)):
|
||||||
|
label_idx = indices[idx].item()
|
||||||
|
label = id2label.get(label_idx, str(label_idx))
|
||||||
|
# normalize label casing (positive/negative/neutral)
|
||||||
|
label_norm = label.lower()
|
||||||
|
results.append((label_norm, float(confidences[idx].item())))
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def build_ticker_context(company: str, ticker: str) -> str:
|
||||||
|
"""
|
||||||
|
Build a short textual context for the ticker to be used for embeddings.
|
||||||
|
"""
|
||||||
|
# Very lightweight template; can be extended with sector/description if available
|
||||||
|
return f"{company}, {ticker}, company, stock, shares"
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize(text: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
Simple alphanumeric tokenization, lowercased.
|
||||||
|
"""
|
||||||
|
return re.findall(r"[A-Za-z0-9]+", text.lower())
|
||||||
|
|
||||||
|
|
||||||
|
def keyword_boost(title: str, ticker_context: str, company: Optional[str] = None, ticker: Optional[str] = None) -> float:
|
||||||
|
"""
|
||||||
|
Simple, interpretable keyword/meta boost combining:
|
||||||
|
- +0.4 if title explicitly mentions the company name or ticker
|
||||||
|
- +0.2 if title mentions competitor/sector keywords
|
||||||
|
- Base overlap from Jaccard(title_tokens, context_tokens)
|
||||||
|
- Reduce base if the title is macro-level (economy/markets-wide)
|
||||||
|
"""
|
||||||
|
title_tokens = set(tokenize(title))
|
||||||
|
context_tokens = set(tokenize(ticker_context))
|
||||||
|
|
||||||
|
# Add a small set of generic market keywords to context to better capture overlap
|
||||||
|
generic_keywords = {
|
||||||
|
"stock", "stocks", "share", "shares", "price", "profit", "profits", "loss",
|
||||||
|
"results", "earnings", "revenue", "deal", "merger", "acquisition", "jobs",
|
||||||
|
"cut", "cuts", "dividend", "rises", "falls", "up", "down", "guidance",
|
||||||
|
"forecast", "outlook", "sponsor", "sponsorship", "board", "turmoil",
|
||||||
|
}
|
||||||
|
context_tokens |= generic_keywords
|
||||||
|
|
||||||
|
# Base overlap via Jaccard
|
||||||
|
union = title_tokens | context_tokens
|
||||||
|
inter = title_tokens & context_tokens
|
||||||
|
base_overlap = float(len(inter) / len(union)) if union else 0.0
|
||||||
|
|
||||||
|
# Company/ticker explicit mention (+0.4)
|
||||||
|
title_lower = title.lower()
|
||||||
|
company_mention = False
|
||||||
|
if company:
|
||||||
|
if company.lower() in title_lower:
|
||||||
|
company_mention = True
|
||||||
|
if ticker:
|
||||||
|
# substring check to avoid tokenizer punctuation issues (e.g., BRK.B)
|
||||||
|
if ticker.lower() in title_lower:
|
||||||
|
company_mention = True
|
||||||
|
|
||||||
|
# Competitor/sector keywords (+0.2) — keep set small and generic
|
||||||
|
competitor_words = {
|
||||||
|
"competitor", "competitors", "rival", "rivals", "peer", "peers", "competition",
|
||||||
|
}
|
||||||
|
sector_words = {
|
||||||
|
"technology", "tech", "semiconductor", "chip", "software", "hardware",
|
||||||
|
"bank", "banks", "finance", "financials", "insurance",
|
||||||
|
"energy", "oil", "gas", "utilities",
|
||||||
|
"retail", "consumer", "automotive", "auto",
|
||||||
|
"healthcare", "pharma", "biotech",
|
||||||
|
"telecom", "communications", "media",
|
||||||
|
"aerospace", "defense", "industrial",
|
||||||
|
"mining", "metals",
|
||||||
|
"travel", "airline", "hospitality",
|
||||||
|
"ecommerce", "cloud", "ai", "artificial", "intelligence",
|
||||||
|
}
|
||||||
|
competitor_or_sector = bool(title_tokens & (competitor_words | sector_words))
|
||||||
|
|
||||||
|
# Macro-level hints → dampen base overlap
|
||||||
|
macro_words = {
|
||||||
|
"market", "markets", "economy", "economic", "macro", "inflation", "rates",
|
||||||
|
"interest", "fed", "federal", "policy", "geopolitical", "tariff", "trade",
|
||||||
|
"sector-wide", "industry-wide", "stocks", "equities",
|
||||||
|
}
|
||||||
|
is_macro = bool(title_tokens & macro_words)
|
||||||
|
|
||||||
|
kb = base_overlap
|
||||||
|
if is_macro:
|
||||||
|
kb *= 0.6 # dampen base if distant/macro
|
||||||
|
if company_mention:
|
||||||
|
kb += 0.4
|
||||||
|
if competitor_or_sector:
|
||||||
|
kb += 0.2
|
||||||
|
|
||||||
|
return float(np.clip(kb, 0.0, 1.0))
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||||
|
"""
|
||||||
|
Cosine similarity for L2-normalized vectors is their dot product.
|
||||||
|
Ensure inputs are 1D arrays.
|
||||||
|
"""
|
||||||
|
a = a.reshape(-1)
|
||||||
|
b = b.reshape(-1)
|
||||||
|
den = (np.linalg.norm(a) * np.linalg.norm(b))
|
||||||
|
if den == 0.0:
|
||||||
|
return 0.0
|
||||||
|
return float(np.dot(a, b) / den)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_relevance(
|
||||||
|
embedder: SentenceTransformer,
|
||||||
|
title: str,
|
||||||
|
company: str,
|
||||||
|
ticker: str,
|
||||||
|
beta: float = 0.7,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
By default:
|
||||||
|
relevance = 0.7 * cosine_sim(e_news, e_ticker) + 0.3 * keyword_boost
|
||||||
|
where e_* are sentence embeddings and keyword_boost includes simple meta/keyword rules.
|
||||||
|
"""
|
||||||
|
beta = float(np.clip(beta, 0.0, 1.0))
|
||||||
|
ticker_ctx = build_ticker_context(company, ticker)
|
||||||
|
|
||||||
|
embs = embedder.encode([title, ticker_ctx], normalize_embeddings=True)
|
||||||
|
e_news = embs[0]
|
||||||
|
e_ticker = embs[1]
|
||||||
|
|
||||||
|
cos_sim = cosine_similarity(np.asarray(e_news), np.asarray(e_ticker))
|
||||||
|
kb = keyword_boost(title, ticker_ctx, company=company, ticker=ticker)
|
||||||
|
relevance = beta * cos_sim + (1.0 - beta) * kb
|
||||||
|
# Clip to [0, 1] for interpretability
|
||||||
|
return float(np.clip(relevance, 0.0, 1.0))
|
||||||
|
|
||||||
|
|
||||||
|
def default_ticker_for_company(company: str) -> str:
|
||||||
|
"""
|
||||||
|
Approximate mapping from company names in the sample dataset to tickers.
|
||||||
|
Falls back to an uppercase abbreviation if unknown.
|
||||||
|
"""
|
||||||
|
mapping: Dict[str, str] = {
|
||||||
|
"Tesco": "TSCO",
|
||||||
|
"CRH": "CRH",
|
||||||
|
"Holcim Lafarge": "LHN",
|
||||||
|
"Reed Elsevier": "RELX",
|
||||||
|
"Kingfisher": "KGF",
|
||||||
|
"Mr Bricolage": "MRB",
|
||||||
|
"Glencore": "GLEN",
|
||||||
|
"Diageo": "DGE",
|
||||||
|
"Shell": "SHEL",
|
||||||
|
"Shire": "SHP",
|
||||||
|
"Baxalta": "BXLT",
|
||||||
|
"BP": "BP",
|
||||||
|
"HSBC": "HSBA",
|
||||||
|
"Standard Chartered": "STAN",
|
||||||
|
}
|
||||||
|
if company in mapping:
|
||||||
|
return mapping[company]
|
||||||
|
# Fallback: uppercase initials (e.g., "Reed Elsevier" -> "RE")
|
||||||
|
initials = "".join([w[0] for w in company.split() if w])
|
||||||
|
return initials.upper() or company.upper()
|
||||||
|
|
||||||
|
|
||||||
|
def round_float(value: float, ndigits: int = 2) -> float:
|
||||||
|
"""
|
||||||
|
Round float safely; ensures standard Python rounding and float type.
|
||||||
|
"""
|
||||||
|
return float(round(value, ndigits))
|
||||||
|
|
||||||
|
|
||||||
|
def label_to_numeric(label: str) -> int:
|
||||||
|
"""
|
||||||
|
Map textual sentiment to numeric scheme: Negative=-1, Neutral=0, Positive=1.
|
||||||
|
"""
|
||||||
|
mapping = {"negative": -1, "neutral": 0, "positive": 1}
|
||||||
|
return int(mapping.get(label.lower(), 0))
|
||||||
|
|
||||||
|
|
||||||
|
def build_instruction_prompt(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Match the finetune_dapt.py instruction template for consistent scoring.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
"### Instruction:\n"
|
||||||
|
"Classify the sentiment of the following financial text.\n\n"
|
||||||
|
f"### Text:\n{text}\n\n"
|
||||||
|
"### Response:\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_lora_causal_model(base_model_id: str, adapters_path: str, hf_token: str = None):
|
||||||
|
"""
|
||||||
|
Load base causal LM and attach LoRA adapters for SFT scoring via prompting.
|
||||||
|
"""
|
||||||
|
# Keep simple, no quantization by default here
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
base_model_id,
|
||||||
|
device_map="auto" if torch.cuda.is_available() else None,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
token=hf_token,
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=True, token=hf_token)
|
||||||
|
if tokenizer.pad_token is None:
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
try:
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
model = PeftModel.from_pretrained(model, adapters_path)
|
||||||
|
model.eval()
|
||||||
|
return tokenizer, model
|
||||||
|
|
||||||
|
|
||||||
|
def score_labels_with_lora(
|
||||||
|
tokenizer: AutoTokenizer,
|
||||||
|
model: AutoModelForCausalLM,
|
||||||
|
prompts: List[str],
|
||||||
|
label_texts: List[str],
|
||||||
|
) -> List[Tuple[str, float]]:
|
||||||
|
"""
|
||||||
|
Compute sentiment label and confidence using LoRA causal LM by scoring
|
||||||
|
log-likelihood of label strings conditioned on the prompt.
|
||||||
|
Returns (label_str_lowercase, confidence_softmax_over_labels).
|
||||||
|
"""
|
||||||
|
results: List[Tuple[str, float]] = []
|
||||||
|
batch_size = 2
|
||||||
|
|
||||||
|
# Pre-tokenize label targets
|
||||||
|
label_ids_list = [tokenizer.encode(lbl, add_special_tokens=False) for lbl in label_texts]
|
||||||
|
|
||||||
|
for start in range(0, len(prompts), batch_size):
|
||||||
|
chunk = prompts[start : start + batch_size]
|
||||||
|
enc = tokenizer(
|
||||||
|
chunk,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=512,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
# Determine embedding device similar to evaluation_sft.py to avoid full model move
|
||||||
|
try:
|
||||||
|
embed_device = model.base_model.get_input_embeddings().weight.device # type: ignore
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
embed_device = model.get_input_embeddings().weight.device # type: ignore
|
||||||
|
except Exception:
|
||||||
|
embed_device = next(model.parameters()).device
|
||||||
|
input_ids = enc["input_ids"].to(embed_device)
|
||||||
|
attention_mask = enc["attention_mask"].to(embed_device)
|
||||||
|
|
||||||
|
# For each sample in batch, score each label by teacher-forcing the label tokens
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in range(input_ids.size(0)):
|
||||||
|
prompt_ids = input_ids[i]
|
||||||
|
prompt_len = int(attention_mask[i].sum().item())
|
||||||
|
# Store log-likelihood per label
|
||||||
|
label_logps = []
|
||||||
|
for label_ids in label_ids_list:
|
||||||
|
# Concatenate prompt + label
|
||||||
|
target_ids = torch.tensor(label_ids, dtype=torch.long, device=embed_device)
|
||||||
|
concat_ids = torch.cat([prompt_ids[:prompt_len], target_ids], dim=0).unsqueeze(0)
|
||||||
|
concat_mask = torch.ones_like(concat_ids, device=embed_device)
|
||||||
|
out = model(input_ids=concat_ids, attention_mask=concat_mask)
|
||||||
|
logits = out.logits # [1, seq_len, vocab]
|
||||||
|
log_probs = torch.log_softmax(logits, dim=-1)
|
||||||
|
# Sum log-probs of each label token conditioned on preceding tokens
|
||||||
|
lp_sum = 0.0
|
||||||
|
for k, tok in enumerate(target_ids):
|
||||||
|
# Position of token is prompt_len + k; use logits at previous position
|
||||||
|
pos = prompt_len + k
|
||||||
|
prev_pos = pos - 1
|
||||||
|
if prev_pos < 0:
|
||||||
|
continue
|
||||||
|
lp = log_probs[0, prev_pos, tok.item()].item()
|
||||||
|
lp_sum += lp
|
||||||
|
label_logps.append(lp_sum)
|
||||||
|
# Softmax over label log-likelihoods to get confidence
|
||||||
|
logps_np = np.array(label_logps, dtype=np.float64)
|
||||||
|
# numerical stability
|
||||||
|
m = np.max(logps_np)
|
||||||
|
exp = np.exp(logps_np - m)
|
||||||
|
probs = exp / np.sum(exp)
|
||||||
|
best_idx = int(np.argmax(probs))
|
||||||
|
best_label = label_texts[best_idx].lower()
|
||||||
|
best_conf = float(probs[best_idx])
|
||||||
|
results.append((best_label, best_conf))
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def lora_diagnostics(model: AutoModelForCausalLM) -> Dict[str, object]:
|
||||||
|
"""
|
||||||
|
Return basic diagnostics about LoRA adapter loading.
|
||||||
|
"""
|
||||||
|
diag: Dict[str, object] = {}
|
||||||
|
try:
|
||||||
|
adapter_names = getattr(model, "active_adapters", None)
|
||||||
|
if adapter_names is None:
|
||||||
|
# newer peft exposes 'peft_config' dict and 'active_adapter'
|
||||||
|
peft_cfg = getattr(model, "peft_config", None)
|
||||||
|
if isinstance(peft_cfg, dict):
|
||||||
|
adapter_names = list(peft_cfg.keys())
|
||||||
|
diag["adapter_names"] = adapter_names
|
||||||
|
except Exception:
|
||||||
|
diag["adapter_names"] = None
|
||||||
|
|
||||||
|
# Count trainable LoRA parameters
|
||||||
|
total_params = 0
|
||||||
|
lora_trainable = 0
|
||||||
|
lora_total = 0
|
||||||
|
for name, p in model.named_parameters():
|
||||||
|
num = p.numel()
|
||||||
|
total_params += num
|
||||||
|
if "lora_" in name:
|
||||||
|
lora_total += num
|
||||||
|
if p.requires_grad:
|
||||||
|
lora_trainable += num
|
||||||
|
diag["total_params"] = total_params
|
||||||
|
diag["lora_total_params"] = lora_total
|
||||||
|
diag["lora_trainable_params"] = lora_trainable
|
||||||
|
diag["lora_trainable_pct"] = (float(lora_trainable) / float(total_params)) if total_params else 0.0
|
||||||
|
return diag
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Compute sentiment confidence and relevance for headlines.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
type=str,
|
||||||
|
default="/u/v/d/vdhanuka/defeatbeta-api-main/Headline_Trialdata.json",
|
||||||
|
help="Path to Headline_Trialdata.json",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output",
|
||||||
|
type=str,
|
||||||
|
default="/u/v/d/vdhanuka/defeatbeta-api-main/headline_results1.json",
|
||||||
|
help="Where to write the results JSON.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sft_model",
|
||||||
|
type=str,
|
||||||
|
default=os.environ.get("SFT_MODEL_NAME", "distilbert-base-uncased-finetuned-sst-2-english"),
|
||||||
|
help="Hugging Face model path/name for SFT classifier.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_lora_sft",
|
||||||
|
action="store_true",
|
||||||
|
help="Use LoRA SFT adapters on a causal LM (meta-llama) for scoring instead of a classifier.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--diagnose_lora",
|
||||||
|
action="store_true",
|
||||||
|
help="Print diagnostics about loaded LoRA adapters and run a quick probe.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--diagnose_only",
|
||||||
|
action="store_true",
|
||||||
|
help="If set with --use_lora_sft, run diagnostics/probe and exit without processing dataset.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--base_model_id",
|
||||||
|
type=str,
|
||||||
|
default=os.environ.get("BASE_MODEL_ID", "meta-llama/Llama-3.1-8B"),
|
||||||
|
help="Base model ID for LoRA SFT mode.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--adapters_path",
|
||||||
|
type=str,
|
||||||
|
default=os.environ.get("ADAPTERS_PATH", "/u/v/d/vdhanuka/defeatbeta-api-main/dapt_sft_adapters_e4_60_20_20"),
|
||||||
|
help="Path to LoRA adapters for LoRA SFT mode.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding_model",
|
||||||
|
type=str,
|
||||||
|
default=os.environ.get("EMBEDDING_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2"),
|
||||||
|
help="Sentence-Transformers model for embeddings.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--beta",
|
||||||
|
type=float,
|
||||||
|
default=float(os.environ.get("RELEVANCE_BETA", 0.8)),
|
||||||
|
help="Weight for semantic similarity in relevance calculation (0.7 - 0.9 recommended).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_items",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="If > 0, limit processing to first N items (useful for quick checks).",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Load dataset
|
||||||
|
with open(args.dataset, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
if not isinstance(data, list):
|
||||||
|
raise ValueError("Dataset must be a JSON list of headline objects.")
|
||||||
|
|
||||||
|
if args.max_items and args.max_items > 0:
|
||||||
|
data = data[: args.max_items]
|
||||||
|
|
||||||
|
# Prepare models
|
||||||
|
embedder = SentenceTransformer(args.embedding_model)
|
||||||
|
|
||||||
|
# Sentiment path: classifier or LoRA SFT
|
||||||
|
if args.use_lora_sft:
|
||||||
|
hf_token = (
|
||||||
|
os.getenv("HUGGING_FACE_HUB_TOKEN")
|
||||||
|
or os.getenv("HF_TOKEN")
|
||||||
|
or os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
||||||
|
)
|
||||||
|
causal_tokenizer, causal_model = load_lora_causal_model(args.base_model_id, args.adapters_path, hf_token)
|
||||||
|
if args.diagnose_lora:
|
||||||
|
diag = lora_diagnostics(causal_model)
|
||||||
|
print("[LoRA] Diagnostics:", json.dumps(diag, indent=2))
|
||||||
|
# Quick probe
|
||||||
|
probe_prompt = build_instruction_prompt("Stocks rose after strong earnings.")
|
||||||
|
label_texts = ["Positive", "Neutral", "Negative"]
|
||||||
|
probe = score_labels_with_lora(causal_tokenizer, causal_model, [probe_prompt], label_texts)
|
||||||
|
if probe:
|
||||||
|
lbl, conf = probe[0]
|
||||||
|
print(f"[LoRA] Probe prediction: {lbl} (confidence={conf:.3f})")
|
||||||
|
if args.diagnose_only:
|
||||||
|
return
|
||||||
|
prompts = [build_instruction_prompt(item.get("title", "")) for item in data]
|
||||||
|
label_texts = ["Positive", "Neutral", "Negative"]
|
||||||
|
sent_conf = score_labels_with_lora(causal_tokenizer, causal_model, prompts, label_texts)
|
||||||
|
else:
|
||||||
|
tokenizer, model = load_sft_model(args.sft_model)
|
||||||
|
# Collect texts for batch classification
|
||||||
|
texts = [item.get("title", "") for item in data]
|
||||||
|
sent_conf = classify_with_confidence(tokenizer, model, texts)
|
||||||
|
|
||||||
|
# Normalize mapping when using LoRA path (already lowercase strings returned)
|
||||||
|
def to_numeric(lbl: str) -> int:
|
||||||
|
return label_to_numeric(lbl)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for item, (label, conf) in zip(data, sent_conf):
|
||||||
|
title = item.get("title", "")
|
||||||
|
company = item.get("company", "")
|
||||||
|
ticker = default_ticker_for_company(company)
|
||||||
|
relevance = compute_relevance(embedder, title, company, ticker, beta=args.beta)
|
||||||
|
results.append({
|
||||||
|
"id": item.get("id"),
|
||||||
|
"title": title,
|
||||||
|
"company": company,
|
||||||
|
"sentiment": label,
|
||||||
|
"sentiment_score": to_numeric(label),
|
||||||
|
"confidence": round_float(conf, 2),
|
||||||
|
"relevance": round_float(relevance, 2),
|
||||||
|
"ticker": ticker,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Write output
|
||||||
|
with open(args.output, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(results, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
print(f"Wrote {len(results)} results to: {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -9,6 +9,6 @@ config["dapt_adapter_path"] = "/u/v/d/vdhanuka/llama3_8b_dapt_transcripts_lora"
|
||||||
config["llm_provider"] = "openai" # provider for the other agents; DAPT is used for News
|
config["llm_provider"] = "openai" # provider for the other agents; DAPT is used for News
|
||||||
config["backend_url"] = "https://api.openai.com/v1" # unused if DAPT loads fine
|
config["backend_url"] = "https://api.openai.com/v1" # unused if DAPT loads fine
|
||||||
|
|
||||||
graph = TradingAgentsGraph(selected_analysts=["news","market","social"], config=config, debug=True)
|
graph = TradingAgentsGraph(selected_analysts=["news","fundamentals"], config=config, debug=True)
|
||||||
_, decision = graph.propagate(company_name="AAPL", trade_date="2024-01-07")
|
_, decision = graph.propagate(company_name="AAPL", trade_date="2024-01-02")
|
||||||
print(decision)
|
print(decision)
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
from tradingagents.dataflows.openai import (
|
||||||
|
get_stock_news_openai,
|
||||||
|
get_global_news_openai,
|
||||||
|
)
|
||||||
|
from tradingagents.dataflows.news_parsers import (
|
||||||
|
parse_global_news,
|
||||||
|
parse_stock_news,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_text_from_global_item(item: Dict[str, Any]) -> str:
|
||||||
|
parts: List[str] = []
|
||||||
|
if item.get("date"):
|
||||||
|
parts.append(f"Date: {item['date']}")
|
||||||
|
if item.get("headline"):
|
||||||
|
parts.append(f"Headline: {item['headline']}")
|
||||||
|
if item.get("relevance"):
|
||||||
|
parts.append(f"Relevance: {item['relevance']}")
|
||||||
|
if item.get("sources"):
|
||||||
|
parts.append("Sources: " + ", ".join(item["sources"][:3]))
|
||||||
|
return "\n".join(parts).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def build_text_from_stock_item(item: Dict[str, Any]) -> str:
|
||||||
|
parts: List[str] = []
|
||||||
|
if item.get("title"):
|
||||||
|
parts.append(f"Title: {item['title']}")
|
||||||
|
if item.get("summary"):
|
||||||
|
parts.append(item["summary"])
|
||||||
|
if item.get("sources"):
|
||||||
|
parts.append("Sources: " + ", ".join(item["sources"][:3]))
|
||||||
|
return "\n".join(parts).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def write_json(path: str, rows: List[Dict[str, Any]]) -> None:
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(rows, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def run_company(ticker: str, start_date: str, end_date: str, out_path: str) -> None:
|
||||||
|
raw = get_stock_news_openai(ticker, start_date, end_date)
|
||||||
|
items = parse_stock_news(raw)
|
||||||
|
rows: List[Dict[str, Any]] = []
|
||||||
|
for it in items:
|
||||||
|
text = build_text_from_stock_item(it)
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"text": text,
|
||||||
|
"ticker": ticker,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
"sources": it.get("sources", []),
|
||||||
|
# Leave label absent; add later if you want supervised SFT
|
||||||
|
}
|
||||||
|
)
|
||||||
|
write_json(out_path, rows)
|
||||||
|
print(f"Wrote {len(rows)} company news items to: {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_global(curr_date: str, look_back_days: int, limit: int, out_path: str) -> None:
|
||||||
|
raw = get_global_news_openai(curr_date, look_back_days=look_back_days, limit=limit)
|
||||||
|
items = parse_global_news(raw)
|
||||||
|
rows: List[Dict[str, Any]] = []
|
||||||
|
for it in items:
|
||||||
|
text = build_text_from_global_item(it)
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"text": text,
|
||||||
|
"curr_date": curr_date,
|
||||||
|
"look_back_days": look_back_days,
|
||||||
|
"sources": it.get("sources", []),
|
||||||
|
# Leave label absent; add later if you want supervised SFT
|
||||||
|
}
|
||||||
|
)
|
||||||
|
write_json(out_path, rows)
|
||||||
|
print(f"Wrote {len(rows)} global news items to: {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Fetch, split, and save news as JSON dataset.")
|
||||||
|
mode = parser.add_mutually_exclusive_group(required=True)
|
||||||
|
mode.add_argument("--company", action="store_true", help="Fetch company-specific news")
|
||||||
|
mode.add_argument("--global-news", action="store_true", help="Fetch global/macro news")
|
||||||
|
|
||||||
|
parser.add_argument("--ticker", type=str, help="Ticker symbol for company news")
|
||||||
|
parser.add_argument("--start-date", type=str, help="Start date YYYY-MM-DD for company news")
|
||||||
|
parser.add_argument("--end-date", type=str, help="End date YYYY-MM-DD for company news")
|
||||||
|
|
||||||
|
parser.add_argument("--curr-date", type=str, help="Reference date YYYY-MM-DD for global news")
|
||||||
|
parser.add_argument("--look-back-days", type=int, default=7, help="Look-back window for global news")
|
||||||
|
parser.add_argument("--limit", type=int, default=5, help="Max number of global items")
|
||||||
|
|
||||||
|
parser.add_argument("--output", type=str, required=True, help="Output JSON path")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.company:
|
||||||
|
if not (args.ticker and args.start_date and args.end_date):
|
||||||
|
raise SystemExit("For --company, provide --ticker, --start-date, --end-date")
|
||||||
|
run_company(args.ticker, args.start_date, args.end_date, args.output)
|
||||||
|
else:
|
||||||
|
if not args.curr_date:
|
||||||
|
raise SystemExit("For --global-news, provide --curr-date")
|
||||||
|
run_global(args.curr_date, args.look_back_days, args.limit, args.output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
Loading…
Reference in New Issue