fix(review): hmac.compare_digest for API key, ws/orchestrator auth, SignalMerger per-signal cap logic
This commit is contained in:
parent
28a95f34a7
commit
b50e5b4725
|
|
@ -75,9 +75,12 @@ class SignalMerger:
|
|||
)
|
||||
|
||||
# 两者都有:加权合并
|
||||
# Cap each signal's contribution before merging
|
||||
quant_conf = min(quant.confidence, self._config.quant_weight_cap)
|
||||
llm_conf = min(llm.confidence, self._config.llm_weight_cap)
|
||||
weighted_sum = (
|
||||
quant.direction * quant.confidence
|
||||
+ llm.direction * llm.confidence
|
||||
quant.direction * quant_conf
|
||||
+ llm.direction * llm_conf
|
||||
)
|
||||
final_direction = _sign(weighted_sum)
|
||||
if final_direction == 0:
|
||||
|
|
@ -85,10 +88,8 @@ class SignalMerger:
|
|||
"SignalMerger: weighted_sum=0 for %s — signals cancel out, HOLD",
|
||||
ticker,
|
||||
)
|
||||
total_conf = quant.confidence + llm.confidence
|
||||
raw_confidence = abs(weighted_sum) / total_conf if total_conf > 0 else 0.0
|
||||
final_confidence = min(raw_confidence, self._config.quant_weight_cap,
|
||||
self._config.llm_weight_cap)
|
||||
total_conf = quant_conf + llm_conf
|
||||
final_confidence = abs(weighted_sum) / total_conf if total_conf > 0 else 0.0
|
||||
|
||||
return FinalSignal(
|
||||
ticker=ticker,
|
||||
|
|
|
|||
|
|
@ -78,11 +78,12 @@ def test_merge_both_same_direction(merger):
|
|||
l = _make_signal(direction=1, confidence=0.8, source="llm")
|
||||
result = merger.merge(q, l)
|
||||
assert result.direction == 1
|
||||
weighted_sum = 1 * 0.6 + 1 * 0.8 # 1.4
|
||||
total_conf = 0.6 + 0.8 # 1.4
|
||||
raw_conf = abs(weighted_sum) / total_conf # 1.0
|
||||
# actual code caps at min(raw, quant_weight_cap, llm_weight_cap)
|
||||
expected_conf = min(raw_conf, cfg.quant_weight_cap, cfg.llm_weight_cap)
|
||||
# caps applied per-signal before merging
|
||||
quant_conf = min(0.6, cfg.quant_weight_cap) # 0.6
|
||||
llm_conf = min(0.8, cfg.llm_weight_cap) # 0.8
|
||||
weighted_sum = 1 * quant_conf + 1 * llm_conf # 1.4
|
||||
total_conf = quant_conf + llm_conf # 1.4
|
||||
expected_conf = abs(weighted_sum) / total_conf # 1.0
|
||||
assert math.isclose(result.confidence, expected_conf)
|
||||
|
||||
|
||||
|
|
@ -93,11 +94,13 @@ def test_merge_both_opposite_direction_quant_wins(merger):
|
|||
q = _make_signal(direction=1, confidence=0.9, source="quant")
|
||||
l = _make_signal(direction=-1, confidence=0.3, source="llm")
|
||||
result = merger.merge(q, l)
|
||||
weighted_sum = 1 * 0.9 + (-1) * 0.3 # 0.6
|
||||
assert result.direction == 1
|
||||
total_conf = 0.9 + 0.3
|
||||
raw_conf = abs(weighted_sum) / total_conf
|
||||
expected_conf = min(raw_conf, cfg.quant_weight_cap, cfg.llm_weight_cap)
|
||||
# caps applied per-signal before merging
|
||||
quant_conf = min(0.9, cfg.quant_weight_cap) # 0.8
|
||||
llm_conf = min(0.3, cfg.llm_weight_cap) # 0.3
|
||||
weighted_sum = 1 * quant_conf + (-1) * llm_conf # 0.5
|
||||
total_conf = quant_conf + llm_conf # 1.1
|
||||
expected_conf = abs(weighted_sum) / total_conf
|
||||
assert math.isclose(result.confidence, expected_conf)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ FastAPI REST API + WebSocket for real-time analysis progress
|
|||
"""
|
||||
import asyncio
|
||||
import fcntl
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
|
|
@ -105,7 +106,9 @@ def _check_api_key(api_key: Optional[str]) -> bool:
|
|||
required = _get_api_key()
|
||||
if not required:
|
||||
return True
|
||||
return api_key == required
|
||||
if not api_key:
|
||||
return False
|
||||
return hmac.compare_digest(api_key, required)
|
||||
|
||||
def _auth_error():
|
||||
raise HTTPException(status_code=401, detail="Unauthorized: valid X-API-Key header required")
|
||||
|
|
@ -1101,8 +1104,13 @@ async def root():
|
|||
|
||||
|
||||
@app.websocket("/ws/orchestrator")
|
||||
async def ws_orchestrator(websocket: WebSocket):
|
||||
async def ws_orchestrator(websocket: WebSocket, api_key: Optional[str] = None):
|
||||
"""WebSocket endpoint for orchestrator live signals."""
|
||||
# Auth check before accepting — reject unauthenticated connections
|
||||
if not _check_api_key(api_key):
|
||||
await websocket.close(code=4401)
|
||||
return
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
from orchestrator.config import OrchestratorConfig
|
||||
|
|
|
|||
Loading…
Reference in New Issue