feat(memory): add risk profiles memory for user preferences - Fixes #20
This commit is contained in:
parent
dbfcea3740
commit
25c31d5f5d
|
|
@ -0,0 +1,966 @@
|
|||
"""Tests for Risk Profiles Memory module.
|
||||
|
||||
Issue #20: [MEM-19] Risk profiles memory - user preferences over time
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from tradingagents.memory.risk_profiles import (
|
||||
RiskProfileMemory,
|
||||
RiskProfile,
|
||||
RiskDecision,
|
||||
RiskTolerance,
|
||||
MarketRegime,
|
||||
RiskCategory,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# RiskTolerance Enum Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRiskTolerance:
|
||||
"""Tests for RiskTolerance enum."""
|
||||
|
||||
def test_from_score_conservative(self):
|
||||
"""Test conservative threshold (< 0.25)."""
|
||||
assert RiskTolerance.from_score(0.0) == RiskTolerance.CONSERVATIVE
|
||||
assert RiskTolerance.from_score(0.1) == RiskTolerance.CONSERVATIVE
|
||||
assert RiskTolerance.from_score(0.24) == RiskTolerance.CONSERVATIVE
|
||||
|
||||
def test_from_score_moderate(self):
|
||||
"""Test moderate threshold (0.25 - 0.50)."""
|
||||
assert RiskTolerance.from_score(0.25) == RiskTolerance.MODERATE
|
||||
assert RiskTolerance.from_score(0.37) == RiskTolerance.MODERATE
|
||||
assert RiskTolerance.from_score(0.49) == RiskTolerance.MODERATE
|
||||
|
||||
def test_from_score_aggressive(self):
|
||||
"""Test aggressive threshold (0.50 - 0.75)."""
|
||||
assert RiskTolerance.from_score(0.50) == RiskTolerance.AGGRESSIVE
|
||||
assert RiskTolerance.from_score(0.62) == RiskTolerance.AGGRESSIVE
|
||||
assert RiskTolerance.from_score(0.74) == RiskTolerance.AGGRESSIVE
|
||||
|
||||
def test_from_score_very_aggressive(self):
|
||||
"""Test very aggressive threshold (>= 0.75)."""
|
||||
assert RiskTolerance.from_score(0.75) == RiskTolerance.VERY_AGGRESSIVE
|
||||
assert RiskTolerance.from_score(0.9) == RiskTolerance.VERY_AGGRESSIVE
|
||||
assert RiskTolerance.from_score(1.0) == RiskTolerance.VERY_AGGRESSIVE
|
||||
|
||||
def test_to_score(self):
|
||||
"""Test conversion to numeric score."""
|
||||
assert RiskTolerance.CONSERVATIVE.to_score() == 0.125
|
||||
assert RiskTolerance.MODERATE.to_score() == 0.375
|
||||
assert RiskTolerance.AGGRESSIVE.to_score() == 0.625
|
||||
assert RiskTolerance.VERY_AGGRESSIVE.to_score() == 0.875
|
||||
|
||||
def test_roundtrip_approximate(self):
|
||||
"""Test from_score(to_score()) is consistent."""
|
||||
for tolerance in RiskTolerance:
|
||||
score = tolerance.to_score()
|
||||
recovered = RiskTolerance.from_score(score)
|
||||
assert recovered == tolerance
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MarketRegime Enum Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestMarketRegime:
|
||||
"""Tests for MarketRegime enum."""
|
||||
|
||||
def test_all_regimes_defined(self):
|
||||
"""Test all expected regimes exist."""
|
||||
expected = ["bull", "bear", "sideways", "high_volatility", "low_volatility", "crisis"]
|
||||
for regime_value in expected:
|
||||
assert MarketRegime(regime_value) is not None
|
||||
|
||||
def test_regime_values(self):
|
||||
"""Test regime string values."""
|
||||
assert MarketRegime.BULL.value == "bull"
|
||||
assert MarketRegime.BEAR.value == "bear"
|
||||
assert MarketRegime.CRISIS.value == "crisis"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# RiskCategory Enum Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRiskCategory:
|
||||
"""Tests for RiskCategory enum."""
|
||||
|
||||
def test_all_categories_defined(self):
|
||||
"""Test all expected categories exist."""
|
||||
expected = [
|
||||
"position_size", "leverage", "diversification",
|
||||
"hedging", "stop_loss", "sector_exposure", "asset_class"
|
||||
]
|
||||
for cat_value in expected:
|
||||
assert RiskCategory(cat_value) is not None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# RiskDecision Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRiskDecision:
|
||||
"""Tests for RiskDecision dataclass."""
|
||||
|
||||
def test_create_decision(self):
|
||||
"""Test creating a risk decision."""
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.6,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="Strong momentum in tech",
|
||||
)
|
||||
|
||||
assert decision.id is not None
|
||||
assert decision.category == RiskCategory.POSITION_SIZE
|
||||
assert decision.risk_level == 0.6
|
||||
assert decision.market_regime == MarketRegime.BULL
|
||||
assert decision.context == "Strong momentum in tech"
|
||||
assert decision.outcome is None
|
||||
assert decision.was_appropriate is None
|
||||
|
||||
def test_create_with_vix(self):
|
||||
"""Test creating decision with VIX level."""
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.LEVERAGE,
|
||||
risk_level=0.3,
|
||||
market_regime=MarketRegime.HIGH_VOLATILITY,
|
||||
context="Market stress",
|
||||
vix_level=32.5,
|
||||
)
|
||||
|
||||
assert decision.vix_level == 32.5
|
||||
|
||||
def test_create_with_notes(self):
|
||||
"""Test creating decision with notes."""
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.HEDGING,
|
||||
risk_level=0.4,
|
||||
market_regime=MarketRegime.BEAR,
|
||||
context="Protective puts",
|
||||
notes="Weekly expiration",
|
||||
)
|
||||
|
||||
assert decision.notes == "Weekly expiration"
|
||||
|
||||
def test_create_validates_risk_level_low(self):
|
||||
"""Test risk level validation - too low."""
|
||||
with pytest.raises(ValueError, match="Risk level must be between 0 and 1"):
|
||||
RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=-0.1,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="Test",
|
||||
)
|
||||
|
||||
def test_create_validates_risk_level_high(self):
|
||||
"""Test risk level validation - too high."""
|
||||
with pytest.raises(ValueError, match="Risk level must be between 0 and 1"):
|
||||
RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=1.5,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="Test",
|
||||
)
|
||||
|
||||
def test_create_boundary_values(self):
|
||||
"""Test boundary values for risk level."""
|
||||
decision_low = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.0,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="Minimum risk",
|
||||
)
|
||||
assert decision_low.risk_level == 0.0
|
||||
|
||||
decision_high = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=1.0,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="Maximum risk",
|
||||
)
|
||||
assert decision_high.risk_level == 1.0
|
||||
|
||||
def test_evaluate_decision(self):
|
||||
"""Test evaluating a decision with outcome."""
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.6,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="Strong momentum",
|
||||
)
|
||||
|
||||
result = decision.evaluate(
|
||||
outcome="Profitable trade",
|
||||
outcome_score=0.8,
|
||||
was_appropriate=True,
|
||||
)
|
||||
|
||||
assert result is decision
|
||||
assert decision.outcome == "Profitable trade"
|
||||
assert decision.outcome_score == 0.8
|
||||
assert decision.was_appropriate is True
|
||||
|
||||
def test_evaluate_clamps_outcome_score(self):
|
||||
"""Test outcome score is clamped to [-1, 1]."""
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.LEVERAGE,
|
||||
risk_level=0.9,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="High leverage",
|
||||
)
|
||||
|
||||
decision.evaluate("Loss", -2.0, False)
|
||||
assert decision.outcome_score == -1.0
|
||||
|
||||
decision.evaluate("Huge win", 5.0, True)
|
||||
assert decision.outcome_score == 1.0
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test serialization to dictionary."""
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.STOP_LOSS,
|
||||
risk_level=0.3,
|
||||
market_regime=MarketRegime.SIDEWAYS,
|
||||
context="Range-bound market",
|
||||
vix_level=18.0,
|
||||
)
|
||||
decision.evaluate("Hit stop", -0.5, True)
|
||||
|
||||
data = decision.to_dict()
|
||||
|
||||
assert data["id"] == decision.id
|
||||
assert data["category"] == "stop_loss"
|
||||
assert data["risk_level"] == 0.3
|
||||
assert data["market_regime"] == "sideways"
|
||||
assert data["context"] == "Range-bound market"
|
||||
assert data["vix_level"] == 18.0
|
||||
assert data["outcome"] == "Hit stop"
|
||||
assert data["outcome_score"] == -0.5
|
||||
assert data["was_appropriate"] is True
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test deserialization from dictionary."""
|
||||
original = RiskDecision.create(
|
||||
category=RiskCategory.DIVERSIFICATION,
|
||||
risk_level=0.5,
|
||||
market_regime=MarketRegime.LOW_VOLATILITY,
|
||||
context="Adding sectors",
|
||||
)
|
||||
original.evaluate("Good diversification", 0.3, True)
|
||||
|
||||
data = original.to_dict()
|
||||
restored = RiskDecision.from_dict(data)
|
||||
|
||||
assert restored.id == original.id
|
||||
assert restored.category == original.category
|
||||
assert restored.risk_level == original.risk_level
|
||||
assert restored.market_regime == original.market_regime
|
||||
assert restored.outcome == original.outcome
|
||||
assert restored.was_appropriate == original.was_appropriate
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# RiskProfile Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRiskProfile:
|
||||
"""Tests for RiskProfile dataclass."""
|
||||
|
||||
def test_create_default_profile(self):
|
||||
"""Test creating profile with defaults."""
|
||||
profile = RiskProfile(user_id="user1")
|
||||
|
||||
assert profile.user_id == "user1"
|
||||
assert profile.base_tolerance == RiskTolerance.MODERATE
|
||||
assert profile.max_drawdown_tolerance == 0.20
|
||||
assert profile.volatility_preference == 0.15
|
||||
|
||||
def test_default_regime_adjustments(self):
|
||||
"""Test default regime adjustments are set."""
|
||||
profile = RiskProfile(user_id="user1")
|
||||
|
||||
assert profile.regime_adjustments[MarketRegime.BULL.value] == 0.1
|
||||
assert profile.regime_adjustments[MarketRegime.BEAR.value] == -0.2
|
||||
assert profile.regime_adjustments[MarketRegime.CRISIS.value] == -0.5
|
||||
|
||||
def test_get_adjusted_risk_score_bull(self):
|
||||
"""Test adjusted score in bull market."""
|
||||
profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.MODERATE)
|
||||
|
||||
score = profile.get_adjusted_risk_score(MarketRegime.BULL)
|
||||
|
||||
# MODERATE = 0.375, BULL adjustment = 0.1
|
||||
expected = 0.375 + 0.1
|
||||
assert abs(score - expected) < 0.001
|
||||
|
||||
def test_get_adjusted_risk_score_crisis(self):
|
||||
"""Test adjusted score in crisis."""
|
||||
profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.AGGRESSIVE)
|
||||
|
||||
score = profile.get_adjusted_risk_score(MarketRegime.CRISIS)
|
||||
|
||||
# AGGRESSIVE = 0.625, CRISIS adjustment = -0.5
|
||||
expected = 0.625 - 0.5
|
||||
assert abs(score - expected) < 0.001
|
||||
|
||||
def test_get_adjusted_risk_score_clamped_low(self):
|
||||
"""Test adjusted score is clamped at 0."""
|
||||
profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.CONSERVATIVE)
|
||||
|
||||
score = profile.get_adjusted_risk_score(MarketRegime.CRISIS)
|
||||
|
||||
# CONSERVATIVE = 0.125, CRISIS = -0.5, would be negative
|
||||
assert score == 0.0
|
||||
|
||||
def test_get_adjusted_risk_score_clamped_high(self):
|
||||
"""Test adjusted score is clamped at 1."""
|
||||
profile = RiskProfile(
|
||||
user_id="user1",
|
||||
base_tolerance=RiskTolerance.VERY_AGGRESSIVE,
|
||||
regime_adjustments={MarketRegime.BULL.value: 0.5},
|
||||
)
|
||||
|
||||
score = profile.get_adjusted_risk_score(MarketRegime.BULL)
|
||||
|
||||
# VERY_AGGRESSIVE = 0.875, BULL = 0.5, would exceed 1.0
|
||||
assert score == 1.0
|
||||
|
||||
def test_get_adjusted_tolerance(self):
|
||||
"""Test getting adjusted tolerance enum."""
|
||||
profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.AGGRESSIVE)
|
||||
|
||||
# In crisis, aggressive becomes moderate
|
||||
tolerance = profile.get_adjusted_tolerance(MarketRegime.CRISIS)
|
||||
assert tolerance == RiskTolerance.CONSERVATIVE
|
||||
|
||||
def test_update_regime_adjustment(self):
|
||||
"""Test updating regime adjustment."""
|
||||
profile = RiskProfile(user_id="user1")
|
||||
original_updated = profile.updated_at
|
||||
|
||||
profile.update_regime_adjustment(MarketRegime.BEAR, -0.4)
|
||||
|
||||
assert profile.regime_adjustments[MarketRegime.BEAR.value] == -0.4
|
||||
assert profile.updated_at >= original_updated
|
||||
|
||||
def test_update_regime_adjustment_clamped(self):
|
||||
"""Test adjustment is clamped to [-1, 1]."""
|
||||
profile = RiskProfile(user_id="user1")
|
||||
|
||||
profile.update_regime_adjustment(MarketRegime.BULL, 2.0)
|
||||
assert profile.regime_adjustments[MarketRegime.BULL.value] == 1.0
|
||||
|
||||
profile.update_regime_adjustment(MarketRegime.CRISIS, -2.0)
|
||||
assert profile.regime_adjustments[MarketRegime.CRISIS.value] == -1.0
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test serialization to dictionary."""
|
||||
profile = RiskProfile(
|
||||
user_id="user1",
|
||||
base_tolerance=RiskTolerance.AGGRESSIVE,
|
||||
max_drawdown_tolerance=0.15,
|
||||
)
|
||||
|
||||
data = profile.to_dict()
|
||||
|
||||
assert data["user_id"] == "user1"
|
||||
assert data["base_tolerance"] == "aggressive"
|
||||
assert data["max_drawdown_tolerance"] == 0.15
|
||||
assert "regime_adjustments" in data
|
||||
assert "created_at" in data
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test deserialization from dictionary."""
|
||||
original = RiskProfile(
|
||||
user_id="user1",
|
||||
base_tolerance=RiskTolerance.CONSERVATIVE,
|
||||
max_drawdown_tolerance=0.10,
|
||||
)
|
||||
|
||||
data = original.to_dict()
|
||||
restored = RiskProfile.from_dict(data)
|
||||
|
||||
assert restored.user_id == original.user_id
|
||||
assert restored.base_tolerance == original.base_tolerance
|
||||
assert restored.max_drawdown_tolerance == original.max_drawdown_tolerance
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# RiskProfileMemory Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRiskProfileMemory:
|
||||
"""Tests for RiskProfileMemory class."""
|
||||
|
||||
def test_create_memory(self):
|
||||
"""Test creating memory instance."""
|
||||
memory = RiskProfileMemory()
|
||||
assert memory.count() == 0
|
||||
|
||||
def test_set_and_get_profile(self):
|
||||
"""Test setting and getting a profile."""
|
||||
memory = RiskProfileMemory()
|
||||
profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.AGGRESSIVE)
|
||||
|
||||
memory.set_profile(profile)
|
||||
retrieved = memory.get_profile("user1")
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.user_id == "user1"
|
||||
assert retrieved.base_tolerance == RiskTolerance.AGGRESSIVE
|
||||
|
||||
def test_get_profile_default_user(self):
|
||||
"""Test getting default user profile."""
|
||||
memory = RiskProfileMemory()
|
||||
profile = RiskProfile(user_id="default")
|
||||
memory.set_profile(profile)
|
||||
|
||||
retrieved = memory.get_profile()
|
||||
assert retrieved.user_id == "default"
|
||||
|
||||
def test_get_nonexistent_profile(self):
|
||||
"""Test getting nonexistent profile returns None."""
|
||||
memory = RiskProfileMemory()
|
||||
assert memory.get_profile("unknown") is None
|
||||
|
||||
def test_get_or_create_profile_new(self):
|
||||
"""Test get_or_create creates new profile."""
|
||||
memory = RiskProfileMemory()
|
||||
|
||||
profile = memory.get_or_create_profile("new_user", RiskTolerance.CONSERVATIVE)
|
||||
|
||||
assert profile.user_id == "new_user"
|
||||
assert profile.base_tolerance == RiskTolerance.CONSERVATIVE
|
||||
|
||||
def test_get_or_create_profile_existing(self):
|
||||
"""Test get_or_create returns existing profile."""
|
||||
memory = RiskProfileMemory()
|
||||
original = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.AGGRESSIVE)
|
||||
memory.set_profile(original)
|
||||
|
||||
profile = memory.get_or_create_profile("user1", RiskTolerance.CONSERVATIVE)
|
||||
|
||||
assert profile.base_tolerance == RiskTolerance.AGGRESSIVE
|
||||
|
||||
def test_record_decision(self):
|
||||
"""Test recording a decision."""
|
||||
memory = RiskProfileMemory()
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.5,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="Tech momentum",
|
||||
)
|
||||
|
||||
decision_id = memory.record_decision(decision)
|
||||
|
||||
assert decision_id == decision.id
|
||||
assert memory.count() == 1
|
||||
|
||||
def test_get_decision(self):
|
||||
"""Test retrieving a decision by ID."""
|
||||
memory = RiskProfileMemory()
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.LEVERAGE,
|
||||
risk_level=0.7,
|
||||
market_regime=MarketRegime.LOW_VOLATILITY,
|
||||
context="Low vol environment",
|
||||
)
|
||||
memory.record_decision(decision)
|
||||
|
||||
retrieved = memory.get_decision(decision.id)
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.risk_level == 0.7
|
||||
|
||||
def test_get_nonexistent_decision(self):
|
||||
"""Test getting nonexistent decision returns None."""
|
||||
memory = RiskProfileMemory()
|
||||
assert memory.get_decision("unknown") is None
|
||||
|
||||
def test_evaluate_decision(self):
|
||||
"""Test evaluating a recorded decision."""
|
||||
memory = RiskProfileMemory()
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.6,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="Strong momentum",
|
||||
)
|
||||
memory.record_decision(decision)
|
||||
|
||||
result = memory.evaluate_decision(
|
||||
decision_id=decision.id,
|
||||
outcome="Profitable",
|
||||
outcome_score=0.5,
|
||||
was_appropriate=True,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.outcome == "Profitable"
|
||||
assert result.was_appropriate is True
|
||||
|
||||
def test_evaluate_nonexistent_decision(self):
|
||||
"""Test evaluating nonexistent decision."""
|
||||
memory = RiskProfileMemory()
|
||||
result = memory.evaluate_decision("unknown", "Test", 0.5, True)
|
||||
assert result is None
|
||||
|
||||
def test_find_similar_decisions(self):
|
||||
"""Test finding similar decisions."""
|
||||
memory = RiskProfileMemory()
|
||||
|
||||
# Record several decisions
|
||||
for i in range(5):
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.5 + i * 0.05,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context=f"Tech momentum scenario {i}",
|
||||
)
|
||||
memory.record_decision(decision)
|
||||
|
||||
similar = memory.find_similar_decisions(
|
||||
context="Tech momentum similar",
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
market_regime=MarketRegime.BULL,
|
||||
top_k=3,
|
||||
)
|
||||
|
||||
assert len(similar) <= 3
|
||||
for decision in similar:
|
||||
assert decision.category == RiskCategory.POSITION_SIZE
|
||||
|
||||
def test_find_similar_decisions_with_filters(self):
|
||||
"""Test finding similar decisions with regime filter."""
|
||||
memory = RiskProfileMemory()
|
||||
|
||||
# Record decisions in different regimes
|
||||
for regime in [MarketRegime.BULL, MarketRegime.BEAR]:
|
||||
for i in range(3):
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.5,
|
||||
market_regime=regime,
|
||||
context=f"Scenario {i}",
|
||||
)
|
||||
memory.record_decision(decision)
|
||||
|
||||
# Filter by BULL regime only
|
||||
similar = memory.find_similar_decisions(
|
||||
context="Find scenario",
|
||||
market_regime=MarketRegime.BULL,
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
for decision in similar:
|
||||
assert decision.market_regime == MarketRegime.BULL
|
||||
|
||||
def test_recommend_risk_level_no_history(self):
|
||||
"""Test recommendation without history uses profile only."""
|
||||
memory = RiskProfileMemory()
|
||||
profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE)
|
||||
memory.set_profile(profile)
|
||||
|
||||
risk_level, explanation = memory.recommend_risk_level(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="New situation",
|
||||
use_history=False,
|
||||
)
|
||||
|
||||
# MODERATE (0.375) + BULL adjustment (0.1)
|
||||
expected = 0.375 + 0.1
|
||||
assert abs(risk_level - expected) < 0.01
|
||||
assert "Base risk from profile" in explanation
|
||||
|
||||
def test_recommend_risk_level_with_history(self):
|
||||
"""Test recommendation with historical decisions."""
|
||||
memory = RiskProfileMemory()
|
||||
profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE)
|
||||
memory.set_profile(profile)
|
||||
|
||||
# Record successful decisions
|
||||
for i in range(3):
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.7,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context=f"Momentum play {i}",
|
||||
)
|
||||
decision.evaluate("Profitable", 0.6, True)
|
||||
memory.record_decision(decision)
|
||||
|
||||
risk_level, explanation = memory.recommend_risk_level(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="Similar momentum play",
|
||||
use_history=True,
|
||||
)
|
||||
|
||||
# Should be influenced by successful 0.7 risk decisions
|
||||
assert risk_level > 0.5 # Should be higher than base moderate
|
||||
assert "successful similar decisions" in explanation
|
||||
|
||||
def test_recommend_warns_about_unsuccessful(self):
|
||||
"""Test recommendation when only unsuccessful decisions exist."""
|
||||
memory = RiskProfileMemory()
|
||||
profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE)
|
||||
memory.set_profile(profile)
|
||||
|
||||
# Record unsuccessful decisions with matching context words
|
||||
for i in range(3):
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.LEVERAGE,
|
||||
risk_level=0.8,
|
||||
market_regime=MarketRegime.HIGH_VOLATILITY,
|
||||
context="High leverage situation volatility trade",
|
||||
)
|
||||
decision.evaluate("Loss", -0.5, False)
|
||||
memory.record_decision(decision)
|
||||
|
||||
risk_level, explanation = memory.recommend_risk_level(
|
||||
category=RiskCategory.LEVERAGE,
|
||||
market_regime=MarketRegime.HIGH_VOLATILITY,
|
||||
context="High leverage situation volatility trade",
|
||||
use_history=True,
|
||||
)
|
||||
|
||||
# Should either use base risk (no successful similar) or warn about unsuccessful
|
||||
# The explanation should NOT include "successful similar decisions" since there are none
|
||||
assert "successful similar decisions" not in explanation or "WARNING" in explanation
|
||||
|
||||
def test_get_regime_statistics_empty(self):
|
||||
"""Test regime statistics with no decisions."""
|
||||
memory = RiskProfileMemory()
|
||||
stats = memory.get_regime_statistics()
|
||||
|
||||
for regime in MarketRegime:
|
||||
assert stats[regime.value]["count"] == 0
|
||||
assert stats[regime.value]["avg_risk_level"] is None
|
||||
|
||||
def test_get_regime_statistics(self):
|
||||
"""Test regime statistics with decisions."""
|
||||
memory = RiskProfileMemory()
|
||||
|
||||
# Record decisions in bull market
|
||||
for i in range(5):
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.6 + i * 0.02,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context=f"Bull scenario {i}",
|
||||
)
|
||||
if i < 3:
|
||||
decision.evaluate("Win", 0.5, True)
|
||||
else:
|
||||
decision.evaluate("Loss", -0.3, False)
|
||||
memory.record_decision(decision)
|
||||
|
||||
stats = memory.get_regime_statistics()
|
||||
|
||||
assert stats[MarketRegime.BULL.value]["count"] == 5
|
||||
assert 0.6 <= stats[MarketRegime.BULL.value]["avg_risk_level"] <= 0.7
|
||||
assert stats[MarketRegime.BULL.value]["success_rate"] == 0.6 # 3/5
|
||||
|
||||
def test_get_category_statistics(self):
|
||||
"""Test category statistics."""
|
||||
memory = RiskProfileMemory()
|
||||
|
||||
for i in range(4):
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.STOP_LOSS,
|
||||
risk_level=0.3 + i * 0.05,
|
||||
market_regime=MarketRegime.SIDEWAYS,
|
||||
context=f"Stop loss scenario {i}",
|
||||
)
|
||||
decision.evaluate("Hit stop", -0.2, i % 2 == 0)
|
||||
memory.record_decision(decision)
|
||||
|
||||
stats = memory.get_category_statistics()
|
||||
|
||||
assert stats[RiskCategory.STOP_LOSS.value]["count"] == 4
|
||||
assert stats[RiskCategory.STOP_LOSS.value]["success_rate"] == 0.5
|
||||
|
||||
def test_learn_regime_adjustments_insufficient_data(self):
|
||||
"""Test learning with insufficient data."""
|
||||
memory = RiskProfileMemory()
|
||||
|
||||
# Only 2 decisions (below min_decisions=5)
|
||||
for i in range(2):
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.5,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context=f"Scenario {i}",
|
||||
)
|
||||
decision.evaluate("Win", 0.5, True)
|
||||
memory.record_decision(decision)
|
||||
|
||||
suggestions = memory.learn_regime_adjustments(min_decisions=5)
|
||||
|
||||
# Should not have suggestions for BULL (only 2 decisions)
|
||||
assert MarketRegime.BULL.value not in suggestions
|
||||
|
||||
def test_learn_regime_adjustments_successful(self):
|
||||
"""Test learning from successful decisions."""
|
||||
memory = RiskProfileMemory()
|
||||
profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE)
|
||||
memory.set_profile(profile)
|
||||
|
||||
# Record successful high-risk decisions in bull
|
||||
for i in range(6):
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.7, # Higher than base (0.375)
|
||||
market_regime=MarketRegime.BULL,
|
||||
context=f"Successful bull trade {i}",
|
||||
)
|
||||
decision.evaluate("Profit", 0.6, True)
|
||||
memory.record_decision(decision)
|
||||
|
||||
suggestions = memory.learn_regime_adjustments(min_decisions=5)
|
||||
|
||||
# Should suggest positive adjustment for bull (0.7 > 0.375)
|
||||
assert MarketRegime.BULL.value in suggestions
|
||||
assert suggestions[MarketRegime.BULL.value] > 0
|
||||
|
||||
def test_learn_regime_adjustments_unsuccessful(self):
|
||||
"""Test learning from unsuccessful decisions."""
|
||||
memory = RiskProfileMemory()
|
||||
profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.AGGRESSIVE)
|
||||
memory.set_profile(profile)
|
||||
|
||||
# Record all unsuccessful decisions in crisis
|
||||
for i in range(6):
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.LEVERAGE,
|
||||
risk_level=0.8,
|
||||
market_regime=MarketRegime.CRISIS,
|
||||
context=f"Crisis loss {i}",
|
||||
)
|
||||
decision.evaluate("Loss", -0.7, False)
|
||||
memory.record_decision(decision)
|
||||
|
||||
suggestions = memory.learn_regime_adjustments(min_decisions=5)
|
||||
|
||||
# Should suggest negative adjustment (lower risk in crisis)
|
||||
assert MarketRegime.CRISIS.value in suggestions
|
||||
assert suggestions[MarketRegime.CRISIS.value] < 0
|
||||
|
||||
def test_count(self):
|
||||
"""Test decision count."""
|
||||
memory = RiskProfileMemory()
|
||||
assert memory.count() == 0
|
||||
|
||||
for i in range(5):
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.5,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context=f"Scenario {i}",
|
||||
)
|
||||
memory.record_decision(decision)
|
||||
|
||||
assert memory.count() == 5
|
||||
|
||||
def test_clear(self):
|
||||
"""Test clearing decisions (preserving profiles)."""
|
||||
memory = RiskProfileMemory()
|
||||
profile = RiskProfile(user_id="user1")
|
||||
memory.set_profile(profile)
|
||||
|
||||
for i in range(3):
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.5,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context=f"Scenario {i}",
|
||||
)
|
||||
memory.record_decision(decision)
|
||||
|
||||
assert memory.count() == 3
|
||||
cleared = memory.clear()
|
||||
|
||||
assert cleared == 3
|
||||
assert memory.count() == 0
|
||||
assert memory.get_profile("user1") is not None # Profile preserved
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test serialization to dictionary."""
|
||||
memory = RiskProfileMemory()
|
||||
profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.AGGRESSIVE)
|
||||
memory.set_profile(profile)
|
||||
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.6,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="Test scenario",
|
||||
)
|
||||
memory.record_decision(decision)
|
||||
|
||||
data = memory.to_dict()
|
||||
|
||||
assert "profiles" in data
|
||||
assert "user1" in data["profiles"]
|
||||
assert "decisions" in data
|
||||
assert len(data["decisions"]) == 1
|
||||
assert "memory" in data
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test deserialization from dictionary."""
|
||||
memory = RiskProfileMemory()
|
||||
profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.CONSERVATIVE)
|
||||
memory.set_profile(profile)
|
||||
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.HEDGING,
|
||||
risk_level=0.4,
|
||||
market_regime=MarketRegime.BEAR,
|
||||
context="Protective strategy",
|
||||
)
|
||||
decision.evaluate("Good protection", 0.5, True)
|
||||
memory.record_decision(decision)
|
||||
|
||||
data = memory.to_dict()
|
||||
restored = RiskProfileMemory.from_dict(data)
|
||||
|
||||
assert restored.get_profile("user1") is not None
|
||||
assert restored.get_profile("user1").base_tolerance == RiskTolerance.CONSERVATIVE
|
||||
assert restored.count() == 1
|
||||
assert restored.get_decision(decision.id) is not None
|
||||
|
||||
def test_multiple_users(self):
|
||||
"""Test handling multiple user profiles."""
|
||||
memory = RiskProfileMemory()
|
||||
|
||||
users = ["alice", "bob", "charlie"]
|
||||
tolerances = [RiskTolerance.CONSERVATIVE, RiskTolerance.MODERATE, RiskTolerance.AGGRESSIVE]
|
||||
|
||||
for user, tolerance in zip(users, tolerances):
|
||||
profile = RiskProfile(user_id=user, base_tolerance=tolerance)
|
||||
memory.set_profile(profile)
|
||||
|
||||
for user, expected_tolerance in zip(users, tolerances):
|
||||
profile = memory.get_profile(user)
|
||||
assert profile.base_tolerance == expected_tolerance
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRiskProfileMemoryIntegration:
|
||||
"""Integration tests for RiskProfileMemory."""
|
||||
|
||||
def test_full_workflow(self):
|
||||
"""Test complete workflow: profile -> decisions -> learning."""
|
||||
memory = RiskProfileMemory()
|
||||
|
||||
# 1. Create profile
|
||||
profile = RiskProfile(
|
||||
user_id="trader1",
|
||||
base_tolerance=RiskTolerance.MODERATE,
|
||||
max_drawdown_tolerance=0.15,
|
||||
)
|
||||
memory.set_profile(profile)
|
||||
|
||||
# 2. Record decisions across regimes
|
||||
decisions_data = [
|
||||
(MarketRegime.BULL, 0.6, True),
|
||||
(MarketRegime.BULL, 0.7, True),
|
||||
(MarketRegime.BULL, 0.65, True),
|
||||
(MarketRegime.BEAR, 0.3, True),
|
||||
(MarketRegime.BEAR, 0.5, False),
|
||||
(MarketRegime.CRISIS, 0.2, True),
|
||||
]
|
||||
|
||||
for regime, risk, appropriate in decisions_data:
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=risk,
|
||||
market_regime=regime,
|
||||
context=f"Trading in {regime.value}",
|
||||
)
|
||||
outcome_score = 0.5 if appropriate else -0.5
|
||||
decision.evaluate("Result", outcome_score, appropriate)
|
||||
memory.record_decision(decision, user_id="trader1")
|
||||
|
||||
# 3. Get recommendations
|
||||
bull_risk, _ = memory.recommend_risk_level(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="Bull market opportunity",
|
||||
user_id="trader1",
|
||||
)
|
||||
|
||||
bear_risk, _ = memory.recommend_risk_level(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
market_regime=MarketRegime.BEAR,
|
||||
context="Bear market caution",
|
||||
user_id="trader1",
|
||||
)
|
||||
|
||||
# Bull should recommend higher risk than bear
|
||||
assert bull_risk > bear_risk
|
||||
|
||||
# 4. Get statistics
|
||||
stats = memory.get_regime_statistics()
|
||||
assert stats[MarketRegime.BULL.value]["count"] == 3
|
||||
assert stats[MarketRegime.BULL.value]["success_rate"] == 1.0
|
||||
|
||||
# 5. Serialize and restore
|
||||
data = memory.to_dict()
|
||||
restored = RiskProfileMemory.from_dict(data)
|
||||
|
||||
assert restored.get_profile("trader1") is not None
|
||||
assert restored.count() == 6
|
||||
|
||||
def test_recommendation_adapts_over_time(self):
|
||||
"""Test that recommendations adapt as more data is collected."""
|
||||
memory = RiskProfileMemory()
|
||||
profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE)
|
||||
memory.set_profile(profile)
|
||||
|
||||
# Initial recommendation (no history)
|
||||
initial_risk, _ = memory.recommend_risk_level(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="Starting out",
|
||||
)
|
||||
|
||||
# Add successful high-risk decisions
|
||||
for i in range(5):
|
||||
decision = RiskDecision.create(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
risk_level=0.8,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context=f"Successful trade {i}",
|
||||
)
|
||||
decision.evaluate("Win", 0.7, True)
|
||||
memory.record_decision(decision)
|
||||
|
||||
# Later recommendation should be influenced by history
|
||||
later_risk, explanation = memory.recommend_risk_level(
|
||||
category=RiskCategory.POSITION_SIZE,
|
||||
market_regime=MarketRegime.BULL,
|
||||
context="Similar opportunity",
|
||||
)
|
||||
|
||||
# Should be higher after seeing successful high-risk trades
|
||||
assert later_risk > initial_risk
|
||||
assert "successful similar decisions" in explanation
|
||||
|
|
@ -7,6 +7,7 @@ This module provides a layered memory system with three scoring dimensions:
|
|||
|
||||
Issue #18: Layered memory - recency, relevancy, importance scoring
|
||||
Issue #19: Trade history memory - outcomes, agent reasoning
|
||||
Issue #20: Risk profiles memory - user preferences over time
|
||||
"""
|
||||
|
||||
from .layered_memory import (
|
||||
|
|
@ -28,6 +29,15 @@ from .trade_history import (
|
|||
MarketContext,
|
||||
)
|
||||
|
||||
from .risk_profiles import (
|
||||
RiskProfileMemory,
|
||||
RiskProfile,
|
||||
RiskDecision,
|
||||
RiskTolerance,
|
||||
MarketRegime,
|
||||
RiskCategory,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Layered Memory (Issue #18)
|
||||
"LayeredMemory",
|
||||
|
|
@ -44,4 +54,11 @@ __all__ = [
|
|||
"SignalStrength",
|
||||
"AgentReasoning",
|
||||
"MarketContext",
|
||||
# Risk Profiles (Issue #20)
|
||||
"RiskProfileMemory",
|
||||
"RiskProfile",
|
||||
"RiskDecision",
|
||||
"RiskTolerance",
|
||||
"MarketRegime",
|
||||
"RiskCategory",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,817 @@
|
|||
"""Risk Profiles Memory for tracking user risk preferences over time.
|
||||
|
||||
This module provides memory for tracking and learning from risk preferences:
|
||||
- Risk tolerance levels across different market conditions
|
||||
- Risk preference evolution over time
|
||||
- Market regime-specific risk adjustments
|
||||
- Historical risk decisions and outcomes
|
||||
|
||||
Issue #20: [MEM-19] Risk profiles memory - user preferences over time
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
import statistics
|
||||
import uuid
|
||||
|
||||
from .layered_memory import (
|
||||
LayeredMemory,
|
||||
MemoryEntry,
|
||||
MemoryConfig,
|
||||
ScoringWeights,
|
||||
ImportanceLevel,
|
||||
)
|
||||
|
||||
|
||||
class RiskTolerance(Enum):
|
||||
"""User risk tolerance levels."""
|
||||
CONSERVATIVE = "conservative" # Low risk, capital preservation
|
||||
MODERATE = "moderate" # Balanced risk/reward
|
||||
AGGRESSIVE = "aggressive" # High risk, growth focused
|
||||
VERY_AGGRESSIVE = "very_aggressive" # Maximum risk tolerance
|
||||
|
||||
@classmethod
|
||||
def from_score(cls, score: float) -> "RiskTolerance":
|
||||
"""Convert a risk score (0-1) to RiskTolerance.
|
||||
|
||||
Args:
|
||||
score: Risk score between 0 (conservative) and 1 (very aggressive)
|
||||
|
||||
Returns:
|
||||
Corresponding RiskTolerance
|
||||
"""
|
||||
if score < 0.25:
|
||||
return cls.CONSERVATIVE
|
||||
elif score < 0.50:
|
||||
return cls.MODERATE
|
||||
elif score < 0.75:
|
||||
return cls.AGGRESSIVE
|
||||
else:
|
||||
return cls.VERY_AGGRESSIVE
|
||||
|
||||
def to_score(self) -> float:
|
||||
"""Convert RiskTolerance to a numeric score.
|
||||
|
||||
Returns:
|
||||
Score between 0 and 1
|
||||
"""
|
||||
mapping = {
|
||||
RiskTolerance.CONSERVATIVE: 0.125,
|
||||
RiskTolerance.MODERATE: 0.375,
|
||||
RiskTolerance.AGGRESSIVE: 0.625,
|
||||
RiskTolerance.VERY_AGGRESSIVE: 0.875,
|
||||
}
|
||||
return mapping[self]
|
||||
|
||||
|
||||
class MarketRegime(Enum):
|
||||
"""Market regime classifications."""
|
||||
BULL = "bull" # Strong uptrend
|
||||
BEAR = "bear" # Strong downtrend
|
||||
SIDEWAYS = "sideways" # Range-bound
|
||||
HIGH_VOLATILITY = "high_volatility" # VIX > 25
|
||||
LOW_VOLATILITY = "low_volatility" # VIX < 15
|
||||
CRISIS = "crisis" # Market stress/crash
|
||||
|
||||
|
||||
class RiskCategory(Enum):
|
||||
"""Categories of risk decisions."""
|
||||
POSITION_SIZE = "position_size" # How much to invest
|
||||
LEVERAGE = "leverage" # Use of leverage
|
||||
DIVERSIFICATION = "diversification" # Portfolio spread
|
||||
HEDGING = "hedging" # Protective positions
|
||||
STOP_LOSS = "stop_loss" # Exit thresholds
|
||||
SECTOR_EXPOSURE = "sector_exposure" # Sector concentration
|
||||
ASSET_CLASS = "asset_class" # Asset allocation
|
||||
|
||||
|
||||
@dataclass
|
||||
class RiskDecision:
|
||||
"""A recorded risk decision with context and outcome.
|
||||
|
||||
Attributes:
|
||||
id: Unique decision ID
|
||||
timestamp: When decision was made
|
||||
category: Type of risk decision
|
||||
risk_level: Risk level chosen (0-1 scale)
|
||||
market_regime: Market conditions at decision time
|
||||
context: Situation description
|
||||
vix_level: VIX at decision time
|
||||
outcome: Outcome description (added later)
|
||||
outcome_score: Quantified outcome (-1 to 1)
|
||||
was_appropriate: Whether decision was appropriate in hindsight
|
||||
notes: Additional notes
|
||||
"""
|
||||
id: str
|
||||
timestamp: datetime
|
||||
category: RiskCategory
|
||||
risk_level: float # 0 (min risk) to 1 (max risk)
|
||||
market_regime: MarketRegime
|
||||
context: str
|
||||
vix_level: Optional[float] = None
|
||||
outcome: Optional[str] = None
|
||||
outcome_score: Optional[float] = None # -1 (bad) to 1 (good)
|
||||
was_appropriate: Optional[bool] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
category: RiskCategory,
|
||||
risk_level: float,
|
||||
market_regime: MarketRegime,
|
||||
context: str,
|
||||
vix_level: Optional[float] = None,
|
||||
notes: Optional[str] = None,
|
||||
) -> "RiskDecision":
|
||||
"""Create a new risk decision record.
|
||||
|
||||
Args:
|
||||
category: Type of risk decision
|
||||
risk_level: Risk level chosen (0-1)
|
||||
market_regime: Current market regime
|
||||
context: Situation description
|
||||
vix_level: Current VIX level
|
||||
notes: Additional notes
|
||||
|
||||
Returns:
|
||||
New RiskDecision instance
|
||||
"""
|
||||
if not 0.0 <= risk_level <= 1.0:
|
||||
raise ValueError(f"Risk level must be between 0 and 1, got {risk_level}")
|
||||
|
||||
return cls(
|
||||
id=str(uuid.uuid4()),
|
||||
timestamp=datetime.now(),
|
||||
category=category,
|
||||
risk_level=risk_level,
|
||||
market_regime=market_regime,
|
||||
context=context,
|
||||
vix_level=vix_level,
|
||||
notes=notes,
|
||||
)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
outcome: str,
|
||||
outcome_score: float,
|
||||
was_appropriate: bool,
|
||||
) -> "RiskDecision":
|
||||
"""Evaluate the decision after the outcome is known.
|
||||
|
||||
Args:
|
||||
outcome: Description of what happened
|
||||
outcome_score: Quantified outcome (-1 to 1)
|
||||
was_appropriate: Whether the risk level was appropriate
|
||||
|
||||
Returns:
|
||||
Self with updated evaluation
|
||||
"""
|
||||
self.outcome = outcome
|
||||
self.outcome_score = max(-1.0, min(1.0, outcome_score))
|
||||
self.was_appropriate = was_appropriate
|
||||
return self
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"category": self.category.value,
|
||||
"risk_level": self.risk_level,
|
||||
"market_regime": self.market_regime.value,
|
||||
"context": self.context,
|
||||
"vix_level": self.vix_level,
|
||||
"outcome": self.outcome,
|
||||
"outcome_score": self.outcome_score,
|
||||
"was_appropriate": self.was_appropriate,
|
||||
"notes": self.notes,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "RiskDecision":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
id=data["id"],
|
||||
timestamp=datetime.fromisoformat(data["timestamp"]),
|
||||
category=RiskCategory(data["category"]),
|
||||
risk_level=data["risk_level"],
|
||||
market_regime=MarketRegime(data["market_regime"]),
|
||||
context=data["context"],
|
||||
vix_level=data.get("vix_level"),
|
||||
outcome=data.get("outcome"),
|
||||
outcome_score=data.get("outcome_score"),
|
||||
was_appropriate=data.get("was_appropriate"),
|
||||
notes=data.get("notes"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RiskProfile:
|
||||
"""User's risk profile with preferences and history.
|
||||
|
||||
Attributes:
|
||||
user_id: User identifier
|
||||
base_tolerance: Baseline risk tolerance
|
||||
regime_adjustments: Adjustments by market regime
|
||||
category_preferences: Preferences by risk category
|
||||
max_drawdown_tolerance: Maximum acceptable drawdown
|
||||
volatility_preference: Preferred portfolio volatility
|
||||
created_at: Profile creation time
|
||||
updated_at: Last update time
|
||||
"""
|
||||
user_id: str
|
||||
base_tolerance: RiskTolerance = RiskTolerance.MODERATE
|
||||
regime_adjustments: Dict[str, float] = field(default_factory=dict)
|
||||
category_preferences: Dict[str, float] = field(default_factory=dict)
|
||||
max_drawdown_tolerance: float = 0.20 # 20% max drawdown
|
||||
volatility_preference: float = 0.15 # 15% annual volatility
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize default regime adjustments if empty."""
|
||||
if not self.regime_adjustments:
|
||||
self.regime_adjustments = {
|
||||
MarketRegime.BULL.value: 0.1, # Slightly more risk
|
||||
MarketRegime.BEAR.value: -0.2, # Reduce risk
|
||||
MarketRegime.SIDEWAYS.value: 0.0, # No change
|
||||
MarketRegime.HIGH_VOLATILITY.value: -0.3, # Reduce significantly
|
||||
MarketRegime.LOW_VOLATILITY.value: 0.1, # Slightly more risk
|
||||
MarketRegime.CRISIS.value: -0.5, # Maximum reduction
|
||||
}
|
||||
|
||||
def get_adjusted_risk_score(self, market_regime: MarketRegime) -> float:
|
||||
"""Get risk score adjusted for current market regime.
|
||||
|
||||
Args:
|
||||
market_regime: Current market regime
|
||||
|
||||
Returns:
|
||||
Adjusted risk score (0-1)
|
||||
"""
|
||||
base_score = self.base_tolerance.to_score()
|
||||
adjustment = self.regime_adjustments.get(market_regime.value, 0.0)
|
||||
adjusted = base_score + adjustment
|
||||
return max(0.0, min(1.0, adjusted))
|
||||
|
||||
def get_adjusted_tolerance(self, market_regime: MarketRegime) -> RiskTolerance:
|
||||
"""Get risk tolerance adjusted for market regime.
|
||||
|
||||
Args:
|
||||
market_regime: Current market regime
|
||||
|
||||
Returns:
|
||||
Adjusted RiskTolerance
|
||||
"""
|
||||
score = self.get_adjusted_risk_score(market_regime)
|
||||
return RiskTolerance.from_score(score)
|
||||
|
||||
def update_regime_adjustment(
|
||||
self,
|
||||
regime: MarketRegime,
|
||||
adjustment: float,
|
||||
) -> None:
|
||||
"""Update the adjustment for a specific regime.
|
||||
|
||||
Args:
|
||||
regime: Market regime to update
|
||||
adjustment: New adjustment value (-1 to 1)
|
||||
"""
|
||||
self.regime_adjustments[regime.value] = max(-1.0, min(1.0, adjustment))
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"base_tolerance": self.base_tolerance.value,
|
||||
"regime_adjustments": self.regime_adjustments,
|
||||
"category_preferences": self.category_preferences,
|
||||
"max_drawdown_tolerance": self.max_drawdown_tolerance,
|
||||
"volatility_preference": self.volatility_preference,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "RiskProfile":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
user_id=data["user_id"],
|
||||
base_tolerance=RiskTolerance(data["base_tolerance"]),
|
||||
regime_adjustments=data.get("regime_adjustments", {}),
|
||||
category_preferences=data.get("category_preferences", {}),
|
||||
max_drawdown_tolerance=data.get("max_drawdown_tolerance", 0.20),
|
||||
volatility_preference=data.get("volatility_preference", 0.15),
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
updated_at=datetime.fromisoformat(data["updated_at"]),
|
||||
)
|
||||
|
||||
|
||||
class RiskProfileMemory:
|
||||
"""Memory system for tracking risk profiles and decisions.
|
||||
|
||||
This class provides storage and retrieval for risk profiles and
|
||||
historical risk decisions, enabling learning from past decisions.
|
||||
|
||||
Example:
|
||||
>>> memory = RiskProfileMemory()
|
||||
>>>
|
||||
>>> # Create a risk profile
|
||||
>>> profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.MODERATE)
|
||||
>>> memory.set_profile(profile)
|
||||
>>>
|
||||
>>> # Record a risk decision
|
||||
>>> decision = RiskDecision.create(
|
||||
... category=RiskCategory.POSITION_SIZE,
|
||||
... risk_level=0.6,
|
||||
... market_regime=MarketRegime.BULL,
|
||||
... context="Strong momentum in tech sector",
|
||||
... )
|
||||
>>> memory.record_decision(decision)
|
||||
>>>
|
||||
>>> # Get recommended risk level for similar situation
|
||||
>>> recommended = memory.recommend_risk_level(
|
||||
... category=RiskCategory.POSITION_SIZE,
|
||||
... market_regime=MarketRegime.BULL,
|
||||
... context="Tech sector showing strength",
|
||||
... )
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[MemoryConfig] = None,
|
||||
embedding_function=None,
|
||||
):
|
||||
"""Initialize risk profile memory.
|
||||
|
||||
Args:
|
||||
config: Memory configuration
|
||||
embedding_function: Optional embedding function
|
||||
"""
|
||||
if config is None:
|
||||
config = MemoryConfig(
|
||||
weights=ScoringWeights(
|
||||
recency=0.35, # Recent decisions more relevant
|
||||
relevancy=0.40, # Similar situations important
|
||||
importance=0.25, # Outcome importance
|
||||
),
|
||||
)
|
||||
|
||||
self._layered_memory = LayeredMemory(
|
||||
config=config,
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
self._profiles: Dict[str, RiskProfile] = {}
|
||||
self._decisions: Dict[str, RiskDecision] = {}
|
||||
self._default_user_id = "default"
|
||||
|
||||
def set_profile(self, profile: RiskProfile) -> None:
|
||||
"""Set or update a user's risk profile.
|
||||
|
||||
Args:
|
||||
profile: Risk profile to store
|
||||
"""
|
||||
self._profiles[profile.user_id] = profile
|
||||
|
||||
def get_profile(self, user_id: Optional[str] = None) -> Optional[RiskProfile]:
|
||||
"""Get a user's risk profile.
|
||||
|
||||
Args:
|
||||
user_id: User ID (default: "default")
|
||||
|
||||
Returns:
|
||||
RiskProfile or None
|
||||
"""
|
||||
user_id = user_id or self._default_user_id
|
||||
return self._profiles.get(user_id)
|
||||
|
||||
def get_or_create_profile(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
base_tolerance: RiskTolerance = RiskTolerance.MODERATE,
|
||||
) -> RiskProfile:
|
||||
"""Get existing profile or create a new one.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
base_tolerance: Default tolerance if creating new
|
||||
|
||||
Returns:
|
||||
RiskProfile
|
||||
"""
|
||||
user_id = user_id or self._default_user_id
|
||||
profile = self._profiles.get(user_id)
|
||||
|
||||
if profile is None:
|
||||
profile = RiskProfile(user_id=user_id, base_tolerance=base_tolerance)
|
||||
self._profiles[user_id] = profile
|
||||
|
||||
return profile
|
||||
|
||||
def record_decision(
|
||||
self,
|
||||
decision: RiskDecision,
|
||||
user_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Record a risk decision.
|
||||
|
||||
Args:
|
||||
decision: The risk decision to record
|
||||
user_id: User ID (default: "default")
|
||||
|
||||
Returns:
|
||||
Decision ID
|
||||
"""
|
||||
user_id = user_id or self._default_user_id
|
||||
self._decisions[decision.id] = decision
|
||||
|
||||
# Calculate importance based on outcome if available
|
||||
importance = 0.5
|
||||
if decision.outcome_score is not None:
|
||||
importance = 0.5 + (abs(decision.outcome_score) * 0.5)
|
||||
|
||||
# Create memory entry
|
||||
content = (
|
||||
f"Risk decision: {decision.category.value} with risk level "
|
||||
f"{decision.risk_level:.2f} in {decision.market_regime.value} market. "
|
||||
f"Context: {decision.context}"
|
||||
)
|
||||
|
||||
entry = MemoryEntry.create(
|
||||
content=content,
|
||||
metadata={
|
||||
"decision_id": decision.id,
|
||||
"user_id": user_id,
|
||||
"category": decision.category.value,
|
||||
"risk_level": decision.risk_level,
|
||||
"market_regime": decision.market_regime.value,
|
||||
"vix_level": decision.vix_level,
|
||||
"outcome": decision.outcome,
|
||||
"outcome_score": decision.outcome_score,
|
||||
"was_appropriate": decision.was_appropriate,
|
||||
},
|
||||
importance=importance,
|
||||
tags=[
|
||||
user_id,
|
||||
decision.category.value,
|
||||
decision.market_regime.value,
|
||||
],
|
||||
timestamp=decision.timestamp,
|
||||
)
|
||||
entry.id = decision.id
|
||||
|
||||
self._layered_memory.add(entry)
|
||||
return decision.id
|
||||
|
||||
def evaluate_decision(
|
||||
self,
|
||||
decision_id: str,
|
||||
outcome: str,
|
||||
outcome_score: float,
|
||||
was_appropriate: bool,
|
||||
) -> Optional[RiskDecision]:
|
||||
"""Evaluate a past decision with hindsight.
|
||||
|
||||
Args:
|
||||
decision_id: ID of the decision
|
||||
outcome: What happened
|
||||
outcome_score: Quantified outcome (-1 to 1)
|
||||
was_appropriate: Whether decision was appropriate
|
||||
|
||||
Returns:
|
||||
Updated decision or None
|
||||
"""
|
||||
decision = self._decisions.get(decision_id)
|
||||
if decision is None:
|
||||
return None
|
||||
|
||||
decision.evaluate(outcome, outcome_score, was_appropriate)
|
||||
|
||||
# Update memory importance
|
||||
importance = 0.5 + (abs(outcome_score) * 0.5)
|
||||
self._layered_memory.update_importance(decision_id, importance)
|
||||
|
||||
return decision
|
||||
|
||||
def get_decision(self, decision_id: str) -> Optional[RiskDecision]:
|
||||
"""Get a decision by ID.
|
||||
|
||||
Args:
|
||||
decision_id: Decision ID
|
||||
|
||||
Returns:
|
||||
RiskDecision or None
|
||||
"""
|
||||
return self._decisions.get(decision_id)
|
||||
|
||||
def find_similar_decisions(
|
||||
self,
|
||||
context: str,
|
||||
category: Optional[RiskCategory] = None,
|
||||
market_regime: Optional[MarketRegime] = None,
|
||||
top_k: int = 5,
|
||||
) -> List[RiskDecision]:
|
||||
"""Find similar past decisions.
|
||||
|
||||
Args:
|
||||
context: Current situation context
|
||||
category: Optional filter by category
|
||||
market_regime: Optional filter by regime
|
||||
top_k: Maximum results
|
||||
|
||||
Returns:
|
||||
List of similar decisions
|
||||
"""
|
||||
tags = []
|
||||
if category:
|
||||
tags.append(category.value)
|
||||
if market_regime:
|
||||
tags.append(market_regime.value)
|
||||
|
||||
results = self._layered_memory.retrieve(
|
||||
query=context,
|
||||
top_k=top_k * 2,
|
||||
tags=tags if tags else None,
|
||||
)
|
||||
|
||||
decisions = []
|
||||
for scored in results:
|
||||
decision_id = scored.entry.metadata.get("decision_id")
|
||||
if decision_id and decision_id in self._decisions:
|
||||
decisions.append(self._decisions[decision_id])
|
||||
if len(decisions) >= top_k:
|
||||
break
|
||||
|
||||
return decisions
|
||||
|
||||
def recommend_risk_level(
|
||||
self,
|
||||
category: RiskCategory,
|
||||
market_regime: MarketRegime,
|
||||
context: str,
|
||||
user_id: Optional[str] = None,
|
||||
use_history: bool = True,
|
||||
) -> Tuple[float, str]:
|
||||
"""Recommend a risk level based on profile and history.
|
||||
|
||||
Args:
|
||||
category: Risk category
|
||||
market_regime: Current market regime
|
||||
context: Current situation
|
||||
user_id: User ID
|
||||
use_history: Whether to consider past decisions
|
||||
|
||||
Returns:
|
||||
Tuple of (risk_level, explanation)
|
||||
"""
|
||||
user_id = user_id or self._default_user_id
|
||||
profile = self.get_or_create_profile(user_id)
|
||||
|
||||
# Start with profile-based recommendation
|
||||
base_risk = profile.get_adjusted_risk_score(market_regime)
|
||||
explanation_parts = [
|
||||
f"Base risk from profile: {base_risk:.2f} "
|
||||
f"({profile.base_tolerance.value} adjusted for {market_regime.value})"
|
||||
]
|
||||
|
||||
if not use_history:
|
||||
return base_risk, " | ".join(explanation_parts)
|
||||
|
||||
# Find similar past decisions
|
||||
similar = self.find_similar_decisions(
|
||||
context=context,
|
||||
category=category,
|
||||
market_regime=market_regime,
|
||||
top_k=5,
|
||||
)
|
||||
|
||||
if not similar:
|
||||
explanation_parts.append("No similar past decisions found")
|
||||
return base_risk, " | ".join(explanation_parts)
|
||||
|
||||
# Analyze outcomes of similar decisions
|
||||
successful_decisions = [
|
||||
d for d in similar
|
||||
if d.was_appropriate is True
|
||||
]
|
||||
unsuccessful_decisions = [
|
||||
d for d in similar
|
||||
if d.was_appropriate is False
|
||||
]
|
||||
|
||||
# Calculate weighted average of successful decisions
|
||||
if successful_decisions:
|
||||
successful_avg = statistics.mean([d.risk_level for d in successful_decisions])
|
||||
explanation_parts.append(
|
||||
f"Avg risk level from {len(successful_decisions)} successful similar decisions: "
|
||||
f"{successful_avg:.2f}"
|
||||
)
|
||||
|
||||
# Blend with base risk (weight toward successful history)
|
||||
adjusted_risk = (base_risk * 0.4) + (successful_avg * 0.6)
|
||||
else:
|
||||
adjusted_risk = base_risk
|
||||
|
||||
# Warn about unsuccessful patterns
|
||||
if unsuccessful_decisions:
|
||||
unsuccessful_avg = statistics.mean([d.risk_level for d in unsuccessful_decisions])
|
||||
if abs(adjusted_risk - unsuccessful_avg) < 0.1:
|
||||
explanation_parts.append(
|
||||
f"WARNING: Similar risk level ({unsuccessful_avg:.2f}) was unsuccessful before"
|
||||
)
|
||||
# Adjust away from unsuccessful pattern
|
||||
if unsuccessful_avg > base_risk:
|
||||
adjusted_risk = max(0.0, adjusted_risk - 0.1)
|
||||
else:
|
||||
adjusted_risk = min(1.0, adjusted_risk + 0.1)
|
||||
|
||||
return adjusted_risk, " | ".join(explanation_parts)
|
||||
|
||||
def get_regime_statistics(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get statistics of risk decisions by market regime.
|
||||
|
||||
Args:
|
||||
user_id: Optional filter by user
|
||||
|
||||
Returns:
|
||||
Dictionary of statistics by regime
|
||||
"""
|
||||
stats: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
for regime in MarketRegime:
|
||||
regime_decisions = [
|
||||
d for d in self._decisions.values()
|
||||
if d.market_regime == regime
|
||||
]
|
||||
|
||||
if not regime_decisions:
|
||||
stats[regime.value] = {
|
||||
"count": 0,
|
||||
"avg_risk_level": None,
|
||||
"success_rate": None,
|
||||
}
|
||||
continue
|
||||
|
||||
evaluated = [d for d in regime_decisions if d.was_appropriate is not None]
|
||||
successful = [d for d in evaluated if d.was_appropriate is True]
|
||||
|
||||
stats[regime.value] = {
|
||||
"count": len(regime_decisions),
|
||||
"avg_risk_level": statistics.mean([d.risk_level for d in regime_decisions]),
|
||||
"success_rate": len(successful) / len(evaluated) if evaluated else None,
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
def get_category_statistics(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get statistics of risk decisions by category.
|
||||
|
||||
Args:
|
||||
user_id: Optional filter by user
|
||||
|
||||
Returns:
|
||||
Dictionary of statistics by category
|
||||
"""
|
||||
stats: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
for category in RiskCategory:
|
||||
category_decisions = [
|
||||
d for d in self._decisions.values()
|
||||
if d.category == category
|
||||
]
|
||||
|
||||
if not category_decisions:
|
||||
stats[category.value] = {
|
||||
"count": 0,
|
||||
"avg_risk_level": None,
|
||||
"success_rate": None,
|
||||
}
|
||||
continue
|
||||
|
||||
evaluated = [d for d in category_decisions if d.was_appropriate is not None]
|
||||
successful = [d for d in evaluated if d.was_appropriate is True]
|
||||
|
||||
stats[category.value] = {
|
||||
"count": len(category_decisions),
|
||||
"avg_risk_level": statistics.mean([d.risk_level for d in category_decisions]),
|
||||
"success_rate": len(successful) / len(evaluated) if evaluated else None,
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
def learn_regime_adjustments(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
min_decisions: int = 5,
|
||||
) -> Dict[str, float]:
|
||||
"""Learn regime adjustments from historical decisions.
|
||||
|
||||
Analyzes past decisions to suggest optimal regime adjustments.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
min_decisions: Minimum decisions per regime to learn from
|
||||
|
||||
Returns:
|
||||
Suggested regime adjustments
|
||||
"""
|
||||
user_id = user_id or self._default_user_id
|
||||
profile = self.get_or_create_profile(user_id)
|
||||
suggestions: Dict[str, float] = {}
|
||||
|
||||
for regime in MarketRegime:
|
||||
regime_decisions = [
|
||||
d for d in self._decisions.values()
|
||||
if d.market_regime == regime
|
||||
and d.was_appropriate is not None
|
||||
]
|
||||
|
||||
if len(regime_decisions) < min_decisions:
|
||||
continue
|
||||
|
||||
# Find the risk level with best outcomes
|
||||
successful = [d for d in regime_decisions if d.was_appropriate]
|
||||
unsuccessful = [d for d in regime_decisions if not d.was_appropriate]
|
||||
|
||||
if not successful:
|
||||
# All decisions were unsuccessful - suggest lower risk
|
||||
avg_failed_risk = statistics.mean([d.risk_level for d in unsuccessful])
|
||||
suggested_adjustment = -0.2 # Lower risk
|
||||
elif not unsuccessful:
|
||||
# All decisions were successful - keep similar
|
||||
avg_success_risk = statistics.mean([d.risk_level for d in successful])
|
||||
base_score = profile.base_tolerance.to_score()
|
||||
suggested_adjustment = avg_success_risk - base_score
|
||||
else:
|
||||
# Mixed results - prefer successful pattern
|
||||
avg_success_risk = statistics.mean([d.risk_level for d in successful])
|
||||
base_score = profile.base_tolerance.to_score()
|
||||
suggested_adjustment = avg_success_risk - base_score
|
||||
|
||||
suggestions[regime.value] = max(-0.5, min(0.5, suggested_adjustment))
|
||||
|
||||
return suggestions
|
||||
|
||||
def count(self) -> int:
|
||||
"""Return total number of decisions."""
|
||||
return len(self._decisions)
|
||||
|
||||
def clear(self) -> int:
|
||||
"""Clear all decisions (preserves profiles).
|
||||
|
||||
Returns:
|
||||
Number of decisions cleared
|
||||
"""
|
||||
count = len(self._decisions)
|
||||
self._decisions.clear()
|
||||
self._layered_memory.clear()
|
||||
return count
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize to dictionary."""
|
||||
return {
|
||||
"profiles": {
|
||||
uid: p.to_dict()
|
||||
for uid, p in self._profiles.items()
|
||||
},
|
||||
"decisions": [d.to_dict() for d in self._decisions.values()],
|
||||
"memory": self._layered_memory.to_dict(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
data: Dict[str, Any],
|
||||
embedding_function=None,
|
||||
) -> "RiskProfileMemory":
|
||||
"""Create from dictionary."""
|
||||
instance = cls(embedding_function=embedding_function)
|
||||
|
||||
# Restore profiles
|
||||
for uid, profile_data in data.get("profiles", {}).items():
|
||||
profile = RiskProfile.from_dict(profile_data)
|
||||
instance._profiles[uid] = profile
|
||||
|
||||
# Restore decisions
|
||||
for decision_data in data.get("decisions", []):
|
||||
decision = RiskDecision.from_dict(decision_data)
|
||||
instance._decisions[decision.id] = decision
|
||||
|
||||
# Restore layered memory
|
||||
if "memory" in data:
|
||||
instance._layered_memory = LayeredMemory.from_dict(
|
||||
data["memory"],
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
|
||||
return instance
|
||||
Loading…
Reference in New Issue