fix(review): hmac.compare_digest for API key, ws/orchestrator auth, SignalMerger per-signal cap logic
This commit is contained in:
parent
28a95f34a7
commit
b50e5b4725
|
|
@ -75,9 +75,12 @@ class SignalMerger:
|
||||||
)
|
)
|
||||||
|
|
||||||
# 两者都有:加权合并
|
# 两者都有:加权合并
|
||||||
|
# Cap each signal's contribution before merging
|
||||||
|
quant_conf = min(quant.confidence, self._config.quant_weight_cap)
|
||||||
|
llm_conf = min(llm.confidence, self._config.llm_weight_cap)
|
||||||
weighted_sum = (
|
weighted_sum = (
|
||||||
quant.direction * quant.confidence
|
quant.direction * quant_conf
|
||||||
+ llm.direction * llm.confidence
|
+ llm.direction * llm_conf
|
||||||
)
|
)
|
||||||
final_direction = _sign(weighted_sum)
|
final_direction = _sign(weighted_sum)
|
||||||
if final_direction == 0:
|
if final_direction == 0:
|
||||||
|
|
@ -85,10 +88,8 @@ class SignalMerger:
|
||||||
"SignalMerger: weighted_sum=0 for %s — signals cancel out, HOLD",
|
"SignalMerger: weighted_sum=0 for %s — signals cancel out, HOLD",
|
||||||
ticker,
|
ticker,
|
||||||
)
|
)
|
||||||
total_conf = quant.confidence + llm.confidence
|
total_conf = quant_conf + llm_conf
|
||||||
raw_confidence = abs(weighted_sum) / total_conf if total_conf > 0 else 0.0
|
final_confidence = abs(weighted_sum) / total_conf if total_conf > 0 else 0.0
|
||||||
final_confidence = min(raw_confidence, self._config.quant_weight_cap,
|
|
||||||
self._config.llm_weight_cap)
|
|
||||||
|
|
||||||
return FinalSignal(
|
return FinalSignal(
|
||||||
ticker=ticker,
|
ticker=ticker,
|
||||||
|
|
|
||||||
|
|
@ -78,11 +78,12 @@ def test_merge_both_same_direction(merger):
|
||||||
l = _make_signal(direction=1, confidence=0.8, source="llm")
|
l = _make_signal(direction=1, confidence=0.8, source="llm")
|
||||||
result = merger.merge(q, l)
|
result = merger.merge(q, l)
|
||||||
assert result.direction == 1
|
assert result.direction == 1
|
||||||
weighted_sum = 1 * 0.6 + 1 * 0.8 # 1.4
|
# caps applied per-signal before merging
|
||||||
total_conf = 0.6 + 0.8 # 1.4
|
quant_conf = min(0.6, cfg.quant_weight_cap) # 0.6
|
||||||
raw_conf = abs(weighted_sum) / total_conf # 1.0
|
llm_conf = min(0.8, cfg.llm_weight_cap) # 0.8
|
||||||
# actual code caps at min(raw, quant_weight_cap, llm_weight_cap)
|
weighted_sum = 1 * quant_conf + 1 * llm_conf # 1.4
|
||||||
expected_conf = min(raw_conf, cfg.quant_weight_cap, cfg.llm_weight_cap)
|
total_conf = quant_conf + llm_conf # 1.4
|
||||||
|
expected_conf = abs(weighted_sum) / total_conf # 1.0
|
||||||
assert math.isclose(result.confidence, expected_conf)
|
assert math.isclose(result.confidence, expected_conf)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -93,11 +94,13 @@ def test_merge_both_opposite_direction_quant_wins(merger):
|
||||||
q = _make_signal(direction=1, confidence=0.9, source="quant")
|
q = _make_signal(direction=1, confidence=0.9, source="quant")
|
||||||
l = _make_signal(direction=-1, confidence=0.3, source="llm")
|
l = _make_signal(direction=-1, confidence=0.3, source="llm")
|
||||||
result = merger.merge(q, l)
|
result = merger.merge(q, l)
|
||||||
weighted_sum = 1 * 0.9 + (-1) * 0.3 # 0.6
|
|
||||||
assert result.direction == 1
|
assert result.direction == 1
|
||||||
total_conf = 0.9 + 0.3
|
# caps applied per-signal before merging
|
||||||
raw_conf = abs(weighted_sum) / total_conf
|
quant_conf = min(0.9, cfg.quant_weight_cap) # 0.8
|
||||||
expected_conf = min(raw_conf, cfg.quant_weight_cap, cfg.llm_weight_cap)
|
llm_conf = min(0.3, cfg.llm_weight_cap) # 0.3
|
||||||
|
weighted_sum = 1 * quant_conf + (-1) * llm_conf # 0.5
|
||||||
|
total_conf = quant_conf + llm_conf # 1.1
|
||||||
|
expected_conf = abs(weighted_sum) / total_conf
|
||||||
assert math.isclose(result.confidence, expected_conf)
|
assert math.isclose(result.confidence, expected_conf)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ FastAPI REST API + WebSocket for real-time analysis progress
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import fcntl
|
import fcntl
|
||||||
|
import hmac
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
@ -105,7 +106,9 @@ def _check_api_key(api_key: Optional[str]) -> bool:
|
||||||
required = _get_api_key()
|
required = _get_api_key()
|
||||||
if not required:
|
if not required:
|
||||||
return True
|
return True
|
||||||
return api_key == required
|
if not api_key:
|
||||||
|
return False
|
||||||
|
return hmac.compare_digest(api_key, required)
|
||||||
|
|
||||||
def _auth_error():
|
def _auth_error():
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized: valid X-API-Key header required")
|
raise HTTPException(status_code=401, detail="Unauthorized: valid X-API-Key header required")
|
||||||
|
|
@ -1101,8 +1104,13 @@ async def root():
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/ws/orchestrator")
|
@app.websocket("/ws/orchestrator")
|
||||||
async def ws_orchestrator(websocket: WebSocket):
|
async def ws_orchestrator(websocket: WebSocket, api_key: Optional[str] = None):
|
||||||
"""WebSocket endpoint for orchestrator live signals."""
|
"""WebSocket endpoint for orchestrator live signals."""
|
||||||
|
# Auth check before accepting — reject unauthenticated connections
|
||||||
|
if not _check_api_key(api_key):
|
||||||
|
await websocket.close(code=4401)
|
||||||
|
return
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
sys.path.insert(0, str(REPO_ROOT))
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
from orchestrator.config import OrchestratorConfig
|
from orchestrator.config import OrchestratorConfig
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue