Modify the evaluation code to include the agentic system that uses the SFT model
This commit is contained in:
parent
7d3559665e
commit
eceb52e378
|
|
@ -10,3 +10,5 @@ eval_data/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
.env
|
.env
|
||||||
.history/
|
.history/
|
||||||
|
llama3_8b_dapt_transcripts_lora
|
||||||
|
dapt_sft_adapters_e4_60_20_20
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
{
|
||||||
|
// Use IntelliSense to learn about possible attributes.
|
||||||
|
// Hover to view descriptions of existing attributes.
|
||||||
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Python Debugger: Current File",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${file}",
|
||||||
|
"console": "integratedTerminal"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
@ -21,6 +21,17 @@ from evaluation_long_short.visualize import plot_cumulative_returns_from_results
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
|
||||||
|
def clear_chromadb_collections():
|
||||||
|
"""Clear any existing ChromaDB collections to avoid conflicts"""
|
||||||
|
try:
|
||||||
|
import chromadb
|
||||||
|
from chromadb.config import Settings
|
||||||
|
client = chromadb.Client(Settings(allow_reset=True))
|
||||||
|
client.reset()
|
||||||
|
print("[CLEANUP] ChromaDB collections cleared")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[CLEANUP] Warning: Could not clear ChromaDB: {e}")
|
||||||
|
|
||||||
def is_debugging() -> bool:
|
def is_debugging() -> bool:
|
||||||
try:
|
try:
|
||||||
import debugpy
|
import debugpy
|
||||||
|
|
@ -54,16 +65,29 @@ def save_strategy_actions_to_json(
|
||||||
# Build actions list with relevant daily info
|
# Build actions list with relevant daily info
|
||||||
actions = []
|
actions = []
|
||||||
for date, row in portfolio.iterrows():
|
for date, row in portfolio.iterrows():
|
||||||
|
# Handle both datetime and string dates
|
||||||
|
if isinstance(date, str):
|
||||||
|
date_str = date
|
||||||
|
else:
|
||||||
date_str = date.strftime("%Y-%m-%d")
|
date_str = date.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
# Handle different column names from different backtesting methods
|
||||||
|
# Baselines use: action, position, close
|
||||||
|
# TradingAgents use: action, shares, close_price
|
||||||
action_record = {
|
action_record = {
|
||||||
"date": date_str,
|
"date": date_str,
|
||||||
"action": int(row["action"]) if pd.notna(row["action"]) else 0, # 1=BUY, 0=HOLD, -1=SELL
|
"action": int(row["action"]) if "action" in row and pd.notna(row["action"]) else 0,
|
||||||
"position": int(row["position"]) if pd.notna(row["position"]) else 0, # 1=long, 0=flat
|
"position": int(row.get("position", 1 if row.get("shares", 0) > 0 else (-1 if row.get("shares", 0) < 0 else 0))),
|
||||||
"close_price": float(row["close"]) if pd.notna(row["close"]) else None,
|
"close_price": float(row.get("close_price") or row.get("close")) if ("close_price" in row or "close" in row) else None,
|
||||||
"portfolio_value": float(row["portfolio_value"]) if pd.notna(row["portfolio_value"]) else None,
|
"portfolio_value": float(row["portfolio_value"]) if pd.notna(row["portfolio_value"]) else None,
|
||||||
"strategy_return": float(row["strategy_return"]) if pd.notna(row["strategy_return"]) else 0.0,
|
"strategy_return": float(row["strategy_return"]) if pd.notna(row["strategy_return"]) else 0.0,
|
||||||
"cumulative_return": float(row["cumulative_return"]) if pd.notna(row["cumulative_return"]) else 1.0
|
"cumulative_return": float(row["cumulative_return"]) if pd.notna(row["cumulative_return"]) else 1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Add shares if available (TradingAgents specific)
|
||||||
|
if "shares" in row:
|
||||||
|
action_record["shares"] = float(row["shares"])
|
||||||
|
|
||||||
actions.append(action_record)
|
actions.append(action_record)
|
||||||
|
|
||||||
# Save to JSON
|
# Save to JSON
|
||||||
|
|
@ -87,11 +111,24 @@ def run_evaluation(
|
||||||
end_date: str,
|
end_date: str,
|
||||||
initial_capital: float = 100000,
|
initial_capital: float = 100000,
|
||||||
include_tradingagents: bool = True,
|
include_tradingagents: bool = True,
|
||||||
|
include_dapt: bool = True,
|
||||||
|
dapt_adapter_path: str = None,
|
||||||
output_dir: str = None,
|
output_dir: str = None,
|
||||||
config: dict = None
|
config: dict = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Run complete evaluation: baselines + TradingAgents for a single ticker.
|
Run complete evaluation: baselines + TradingAgents (original + DAPT variant) for a single ticker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ticker: Stock ticker symbol
|
||||||
|
start_date: Start date for evaluation
|
||||||
|
end_date: End date for evaluation
|
||||||
|
initial_capital: Initial capital for backtesting
|
||||||
|
include_tradingagents: Whether to include original TradingAgents
|
||||||
|
include_dapt: Whether to include DAPT-enhanced TradingAgents
|
||||||
|
dapt_adapter_path: Path to DAPT adapter (required if include_dapt=True)
|
||||||
|
output_dir: Output directory for results
|
||||||
|
config: Base configuration dictionary
|
||||||
"""
|
"""
|
||||||
print(f"\n{'='*80}")
|
print(f"\n{'='*80}")
|
||||||
print(f"EVALUATION: {ticker} from {start_date} to {end_date}")
|
print(f"EVALUATION: {ticker} from {start_date} to {end_date}")
|
||||||
|
|
@ -130,12 +167,15 @@ def run_evaluation(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Failed: {e}")
|
print(f"✗ Failed: {e}")
|
||||||
|
|
||||||
# TradingAgents
|
# TradingAgents - Original
|
||||||
if include_tradingagents:
|
if include_tradingagents:
|
||||||
print("\n" + "="*80)
|
print("\n" + "="*80)
|
||||||
print("STEP 3: Running TradingAgents")
|
print("STEP 3: Running TradingAgents (Original)")
|
||||||
print("="*80)
|
print("="*80)
|
||||||
try:
|
try:
|
||||||
|
# Clear any existing ChromaDB collections
|
||||||
|
clear_chromadb_collections()
|
||||||
|
|
||||||
cfg = (config or DEFAULT_CONFIG).copy()
|
cfg = (config or DEFAULT_CONFIG).copy()
|
||||||
# Fast eval defaults (you can override from CLI)
|
# Fast eval defaults (you can override from CLI)
|
||||||
cfg["deep_think_llm"] = cfg.get("deep_think_llm", "o4-mini")
|
cfg["deep_think_llm"] = cfg.get("deep_think_llm", "o4-mini")
|
||||||
|
|
@ -144,15 +184,20 @@ def run_evaluation(
|
||||||
cfg["max_risk_discuss_rounds"] = cfg.get("max_risk_discuss_rounds", 1)
|
cfg["max_risk_discuss_rounds"] = cfg.get("max_risk_discuss_rounds", 1)
|
||||||
# Deterministic-ish decoding for reproducibility
|
# Deterministic-ish decoding for reproducibility
|
||||||
cfg.setdefault("llm_params", {}).update({"temperature": 0.7, "top_p": 1.0, "seed": 42})
|
cfg.setdefault("llm_params", {}).update({"temperature": 0.7, "top_p": 1.0, "seed": 42})
|
||||||
|
# Disable ALL fine-tuned models for original TradingAgents
|
||||||
|
cfg["use_dapt_sentiment"] = False
|
||||||
|
cfg["use_sft_sentiment"] = False
|
||||||
|
|
||||||
print(f"\nInitializing TradingAgents...")
|
print(f"\nInitializing TradingAgents (Original)...")
|
||||||
print(f" Deep Thinking LLM: {cfg['deep_think_llm']}")
|
print(f" Deep Thinking LLM: {cfg['deep_think_llm']}")
|
||||||
print(f" Quick Thinking LLM: {cfg['quick_think_llm']}")
|
print(f" Quick Thinking LLM: {cfg['quick_think_llm']}")
|
||||||
print(f" Debate Rounds: {cfg['max_debate_rounds']}")
|
print(f" Debate Rounds: {cfg['max_debate_rounds']}")
|
||||||
|
print(f" DAPT Sentiment: {cfg.get('use_dapt_sentiment', False)}")
|
||||||
|
print(f" SFT Sentiment: {cfg.get('use_sft_sentiment', False)}")
|
||||||
|
|
||||||
graph = TradingAgentsGraph(
|
graph = TradingAgentsGraph(
|
||||||
selected_analysts=["news"],
|
# selected_analysts=["news"],
|
||||||
# selected_analysts=["market", "social", "news", "fundamentals"],
|
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||||
debug=False,
|
debug=False,
|
||||||
config=cfg
|
config=cfg
|
||||||
)
|
)
|
||||||
|
|
@ -160,19 +205,78 @@ def run_evaluation(
|
||||||
ta_portfolio = ta_backtester.backtest(ticker, start_date, end_date, data)
|
ta_portfolio = ta_backtester.backtest(ticker, start_date, end_date, data)
|
||||||
|
|
||||||
engine.results["TradingAgents"] = ta_portfolio
|
engine.results["TradingAgents"] = ta_portfolio
|
||||||
print("\n✓ TradingAgents backtest complete")
|
print("\n✓ TradingAgents (Original) backtest complete")
|
||||||
|
|
||||||
# Save TradingAgents actions to JSON (in consistent format with baselines)
|
# Save TradingAgents actions to JSON (in consistent format with baselines)
|
||||||
save_strategy_actions_to_json(ta_portfolio, "TradingAgents", ticker, start_date, end_date, output_dir)
|
save_strategy_actions_to_json(ta_portfolio, "TradingAgents", ticker, start_date, end_date, output_dir)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n✗ TradingAgents failed: {e}")
|
print(f"\n✗ TradingAgents (Original) failed: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# TradingAgents - DAPT Enhanced
|
||||||
|
if include_dapt:
|
||||||
|
print("\n" + "="*80)
|
||||||
|
print("STEP 4: Running TradingAgents (DAPT-Enhanced)")
|
||||||
|
print("="*80)
|
||||||
|
try:
|
||||||
|
# Clear any existing ChromaDB collections
|
||||||
|
clear_chromadb_collections()
|
||||||
|
|
||||||
|
if dapt_adapter_path is None:
|
||||||
|
# Default to the path from test_dapt.py
|
||||||
|
dapt_adapter_path = "D:/Quanliang/PhD_courses/CS769-TradingAgents/llama3_8b_dapt_transcripts_lora"
|
||||||
|
print(f" Using default DAPT adapter path: {dapt_adapter_path}")
|
||||||
|
|
||||||
|
cfg_dapt = (config or DEFAULT_CONFIG).copy()
|
||||||
|
# Fast eval defaults (you can override from CLI)
|
||||||
|
cfg_dapt["deep_think_llm"] = cfg_dapt.get("deep_think_llm", "o4-mini")
|
||||||
|
cfg_dapt["quick_think_llm"] = cfg_dapt.get("quick_think_llm", "gpt-4o-mini")
|
||||||
|
cfg_dapt["max_debate_rounds"] = cfg_dapt.get("max_debate_rounds", 1)
|
||||||
|
cfg_dapt["max_risk_discuss_rounds"] = cfg_dapt.get("max_risk_discuss_rounds", 1)
|
||||||
|
# Deterministic-ish decoding for reproducibility
|
||||||
|
cfg_dapt.setdefault("llm_params", {}).update({"temperature": 0.7, "top_p": 1.0, "seed": 42})
|
||||||
|
|
||||||
|
# Enable BOTH DAPT and SFT for complete fine-tuned pipeline
|
||||||
|
cfg_dapt["use_dapt_sentiment"] = True
|
||||||
|
cfg_dapt["dapt_adapter_path"] = dapt_adapter_path
|
||||||
|
cfg_dapt["use_sft_sentiment"] = True # Enable SFT for news sentiment
|
||||||
|
cfg_dapt["sft_adapter_path"] = cfg_dapt.get("sft_adapter_path", "D:/Quanliang/PhD_courses/CS769-TradingAgents/dapt_sft_adapters_e4_60_20_20")
|
||||||
|
cfg_dapt["llm_provider"] = cfg_dapt.get("llm_provider", "openai") # provider for other agents
|
||||||
|
|
||||||
|
print(f"\nInitializing TradingAgents (DAPT-Enhanced)...")
|
||||||
|
print(f" Deep Thinking LLM: {cfg_dapt['deep_think_llm']}")
|
||||||
|
print(f" Quick Thinking LLM: {cfg_dapt['quick_think_llm']}")
|
||||||
|
print(f" Debate Rounds: {cfg_dapt['max_debate_rounds']}")
|
||||||
|
print(f" DAPT Sentiment: {cfg_dapt['use_dapt_sentiment']}")
|
||||||
|
print(f" DAPT Adapter Path: {cfg_dapt['dapt_adapter_path']}")
|
||||||
|
print(f" SFT Sentiment: {cfg_dapt['use_sft_sentiment']}")
|
||||||
|
print(f" SFT Adapter Path: {cfg_dapt['sft_adapter_path']}")
|
||||||
|
|
||||||
|
graph_dapt = TradingAgentsGraph(
|
||||||
|
# selected_analysts=["news"],
|
||||||
|
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||||
|
debug=False,
|
||||||
|
config=cfg_dapt
|
||||||
|
)
|
||||||
|
ta_dapt_backtester = TradingAgentsBacktester(graph_dapt, initial_capital, output_dir)
|
||||||
|
ta_dapt_portfolio = ta_dapt_backtester.backtest(ticker, start_date, end_date, data)
|
||||||
|
|
||||||
|
engine.results["TradingAgents_DAPT"] = ta_dapt_portfolio
|
||||||
|
print("\n✓ TradingAgents (DAPT-Enhanced) backtest complete")
|
||||||
|
|
||||||
|
# Save TradingAgents_DAPT actions to JSON
|
||||||
|
save_strategy_actions_to_json(ta_dapt_portfolio, "TradingAgents_DAPT", ticker, start_date, end_date, output_dir)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n✗ TradingAgents (DAPT-Enhanced) failed: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
# Metrics
|
# Metrics
|
||||||
print("\n" + "="*80)
|
print("\n" + "="*80)
|
||||||
print("STEP 4: Calculating Performance Metrics")
|
print("STEP 5: Calculating Performance Metrics")
|
||||||
print("="*80)
|
print("="*80)
|
||||||
all_metrics = {}
|
all_metrics = {}
|
||||||
for name, portfolio in engine.results.items():
|
for name, portfolio in engine.results.items():
|
||||||
|
|
@ -182,7 +286,7 @@ def run_evaluation(
|
||||||
|
|
||||||
# Generate cumulative returns comparison plot
|
# Generate cumulative returns comparison plot
|
||||||
print("\n" + "="*80)
|
print("\n" + "="*80)
|
||||||
print("STEP 5: Generating Comparison Plot")
|
print("STEP 6: Generating Comparison Plot")
|
||||||
print("="*80)
|
print("="*80)
|
||||||
try:
|
try:
|
||||||
comparison_plot_path = str(out / ticker / "strategy_comparison.png")
|
comparison_plot_path = str(out / ticker / "strategy_comparison.png")
|
||||||
|
|
@ -223,7 +327,9 @@ def main():
|
||||||
parser.add_argument("--start-date", type=str, required=True, help="Start date (YYYY-MM-DD)")
|
parser.add_argument("--start-date", type=str, required=True, help="Start date (YYYY-MM-DD)")
|
||||||
parser.add_argument("--end-date", type=str, required=True, help="End date (YYYY-MM-DD)")
|
parser.add_argument("--end-date", type=str, required=True, help="End date (YYYY-MM-DD)")
|
||||||
parser.add_argument("--capital", type=float, default=100000, help="Initial capital (default: 100000)")
|
parser.add_argument("--capital", type=float, default=100000, help="Initial capital (default: 100000)")
|
||||||
parser.add_argument("--skip-tradingagents", action="store_true", help="Skip TradingAgents evaluation")
|
parser.add_argument("--skip-tradingagents", action="store_true", help="Skip original TradingAgents evaluation")
|
||||||
|
parser.add_argument("--skip-dapt", action="store_true", help="Skip DAPT-enhanced TradingAgents evaluation")
|
||||||
|
parser.add_argument("--dapt-adapter-path", type=str, default=None, help="Path to DAPT adapter (default: llama3_8b_dapt_transcripts_lora in workspace)")
|
||||||
parser.add_argument("--output-dir", type=str, default=None, help="Output directory for results")
|
parser.add_argument("--output-dir", type=str, default=None, help="Output directory for results")
|
||||||
parser.add_argument("--deep-llm", type=str, default="o4-mini", help="Deep thinking LLM model")
|
parser.add_argument("--deep-llm", type=str, default="o4-mini", help="Deep thinking LLM model")
|
||||||
parser.add_argument("--quick-llm", type=str, default="gpt-4o-mini", help="Quick thinking LLM model")
|
parser.add_argument("--quick-llm", type=str, default="gpt-4o-mini", help="Quick thinking LLM model")
|
||||||
|
|
@ -246,6 +352,8 @@ def main():
|
||||||
end_date="2024-01-10",
|
end_date="2024-01-10",
|
||||||
initial_capital=1000,
|
initial_capital=1000,
|
||||||
include_tradingagents=True,
|
include_tradingagents=True,
|
||||||
|
include_dapt=True,
|
||||||
|
dapt_adapter_path="D:/Quanliang/PhD_courses/CS769-TradingAgents/llama3_8b_dapt_transcripts_lora",
|
||||||
output_dir="./evaluation_long_short/results",
|
output_dir="./evaluation_long_short/results",
|
||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
|
|
@ -266,6 +374,8 @@ def main():
|
||||||
end_date=args.end_date,
|
end_date=args.end_date,
|
||||||
initial_capital=args.capital,
|
initial_capital=args.capital,
|
||||||
include_tradingagents=not args.skip_tradingagents,
|
include_tradingagents=not args.skip_tradingagents,
|
||||||
|
include_dapt=not args.skip_dapt,
|
||||||
|
dapt_adapter_path=args.dapt_adapter_path,
|
||||||
output_dir=args.output_dir,
|
output_dir=args.output_dir,
|
||||||
config=config
|
config=config
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from dotenv import load_dotenv
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config["use_dapt_sentiment"] = True
|
config["use_dapt_sentiment"] = True
|
||||||
config["dapt_adapter_path"] = "/u/v/d/vdhanuka/llama3_8b_dapt_transcripts_lora" # <- set your absolute path
|
config["dapt_adapter_path"] = "" # <- set your absolute path
|
||||||
config["llm_provider"] = "openai" # provider for the other agents; DAPT is used for News
|
config["llm_provider"] = "openai" # provider for the other agents; DAPT is used for News
|
||||||
config["backend_url"] = "https://api.openai.com/v1" # unused if DAPT loads fine
|
config["backend_url"] = "https://api.openai.com/v1" # unused if DAPT loads fine
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import sys
|
||||||
from typing import List, Dict, Any, Tuple, Optional
|
from typing import List, Dict, Any, Tuple, Optional
|
||||||
|
|
||||||
# Add external utilities path for confidence/relevance and LoRA scoring
|
# Add external utilities path for confidence/relevance and LoRA scoring
|
||||||
CONF_UTILS_PATH = "/u/v/d/vdhanuka/CS769-TradingAgents"
|
CONF_UTILS_PATH = "D:/Quanliang/PhD_courses/CS769-TradingAgents"
|
||||||
if CONF_UTILS_PATH not in sys.path:
|
if CONF_UTILS_PATH not in sys.path:
|
||||||
sys.path.append(CONF_UTILS_PATH)
|
sys.path.append(CONF_UTILS_PATH)
|
||||||
|
|
||||||
|
|
@ -27,10 +27,20 @@ def create_news_analyst(llm):
|
||||||
lora_loaded: Dict[str, Any] = {"tokenizer": None, "model": None, "embedder": None}
|
lora_loaded: Dict[str, Any] = {"tokenizer": None, "model": None, "embedder": None}
|
||||||
|
|
||||||
def _ensure_models():
|
def _ensure_models():
|
||||||
|
"""Load SFT LoRA model and embedder only if use_sft_sentiment is enabled"""
|
||||||
|
cfg = get_config()
|
||||||
|
use_sft = cfg.get("use_sft_sentiment", False) # Default to False for original behavior
|
||||||
|
|
||||||
|
if not use_sft:
|
||||||
|
# Skip loading SFT models if disabled
|
||||||
|
print("[NEWS_ANALYST] SFT sentiment disabled - using fallback sentiment analysis")
|
||||||
|
return False
|
||||||
|
|
||||||
if conf is None:
|
if conf is None:
|
||||||
raise RuntimeError("confidence.py utilities not available on sys.path.")
|
raise RuntimeError("confidence.py utilities not available on sys.path.")
|
||||||
if lora_loaded["tokenizer"] is None or lora_loaded["model"] is None:
|
if lora_loaded["tokenizer"] is None or lora_loaded["model"] is None:
|
||||||
adapters_path = "/u/v/d/vdhanuka/defeatbeta-api-main/dapt_sft_adapters_e4_60_20_20"
|
# Use configured SFT adapter path
|
||||||
|
adapters_path = cfg.get("sft_adapter_path", "D:/Quanliang/PhD_courses/CS769-TradingAgents/dapt_sft_adapters_e4_60_20_20")
|
||||||
base_model_id = "meta-llama/Llama-3.1-8B"
|
base_model_id = "meta-llama/Llama-3.1-8B"
|
||||||
print(f"[NEWS_ANALYST] Loading SFT LoRA model from: {adapters_path}")
|
print(f"[NEWS_ANALYST] Loading SFT LoRA model from: {adapters_path}")
|
||||||
tok, mdl = conf.load_lora_causal_model(base_model_id, adapters_path)
|
tok, mdl = conf.load_lora_causal_model(base_model_id, adapters_path)
|
||||||
|
|
@ -43,6 +53,7 @@ def create_news_analyst(llm):
|
||||||
print("[NEWS_ANALYST] Loading sentence transformer embedder...")
|
print("[NEWS_ANALYST] Loading sentence transformer embedder...")
|
||||||
lora_loaded["embedder"] = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
lora_loaded["embedder"] = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
||||||
print("[NEWS_ANALYST] Embedder loaded successfully")
|
print("[NEWS_ANALYST] Embedder loaded successfully")
|
||||||
|
return True
|
||||||
|
|
||||||
def _score_items(
|
def _score_items(
|
||||||
items: List[Dict[str, Any]],
|
items: List[Dict[str, Any]],
|
||||||
|
|
@ -55,11 +66,19 @@ def create_news_analyst(llm):
|
||||||
Score each item with sentiment (LoRA) + confidence and relevance, then compute
|
Score each item with sentiment (LoRA) + confidence and relevance, then compute
|
||||||
net sentiment as sum(w_i * S_i) / sum(w_i), where w_i = alpha*confidence + (1-alpha)*relevance.
|
net sentiment as sum(w_i * S_i) / sum(w_i), where w_i = alpha*confidence + (1-alpha)*relevance.
|
||||||
S_i in {-1, 0, 1}.
|
S_i in {-1, 0, 1}.
|
||||||
|
|
||||||
|
If SFT sentiment is disabled, returns the items unscored with net sentiment 0.0 and label "Neutral".
|
||||||
"""
|
"""
|
||||||
if not items:
|
if not items:
|
||||||
return [], 0.0, "Neutral"
|
return [], 0.0, "Neutral"
|
||||||
|
|
||||||
_ensure_models()
|
# Check if SFT models should be loaded
|
||||||
|
sft_enabled = _ensure_models()
|
||||||
|
if not sft_enabled:
|
||||||
|
# SFT disabled - return items without sentiment scoring
|
||||||
|
print("[NEWS_ANALYST] Returning items without SFT sentiment scores (disabled)")
|
||||||
|
return items, 0.0, "Neutral"
|
||||||
|
|
||||||
tokenizer = lora_loaded["tokenizer"]
|
tokenizer = lora_loaded["tokenizer"]
|
||||||
model = lora_loaded["model"]
|
model = lora_loaded["model"]
|
||||||
embedder = lora_loaded["embedder"]
|
embedder = lora_loaded["embedder"]
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,10 @@ DEFAULT_CONFIG = {
|
||||||
# Sentiment analysis model (DAPTed Llama 3.1 8B)
|
# Sentiment analysis model (DAPTed Llama 3.1 8B)
|
||||||
"use_dapt_sentiment": True, # Use DAPTed model for sentiment analysis (set False to use OpenAI backup)
|
"use_dapt_sentiment": True, # Use DAPTed model for sentiment analysis (set False to use OpenAI backup)
|
||||||
# Path to DAPT PEFT adapter (dynamically uses current username)
|
# Path to DAPT PEFT adapter (hard-coded absolute path; change it to match your machine)
|
||||||
"dapt_adapter_path": f"/u/v/d/{os.getenv('USER', 'vdhanuka')}/llama3_8b_dapt_transcripts_lora",
|
"dapt_adapter_path": "D:/Quanliang/PhD_courses/CS769-TradingAgents/llama3_8b_dapt_transcripts_lora",
|
||||||
|
# SFT adapter for news sentiment scoring: enable flag and adapter path
|
||||||
|
"use_sft_sentiment": True, # Use SFT fine-tuned model for news sentiment (set False for no fine-tuning)
|
||||||
|
"sft_adapter_path": "D:/Quanliang/PhD_courses/CS769-TradingAgents/dapt_sft_adapters_e4_60_20_20",
|
||||||
|
|
||||||
# Fallback: OpenAI model if DAPT is unavailable
|
# Fallback: OpenAI model if DAPT is unavailable
|
||||||
"sentiment_fallback_llm": "o4-mini", # OpenAI model for fallback
|
"sentiment_fallback_llm": "o4-mini", # OpenAI model for fallback
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue