From 25c31d5f5dcb712abad5cf57775e38be6d3bf845 Mon Sep 17 00:00:00 2001 From: Andrew Kaszubski Date: Fri, 26 Dec 2025 20:30:21 +1100 Subject: [PATCH] feat(memory): add risk profiles memory for user preferences - Fixes #20 --- tests/unit/memory/test_risk_profiles.py | 966 ++++++++++++++++++++++++ tradingagents/memory/__init__.py | 17 + tradingagents/memory/risk_profiles.py | 817 ++++++++++++++++++++ 3 files changed, 1800 insertions(+) create mode 100644 tests/unit/memory/test_risk_profiles.py create mode 100644 tradingagents/memory/risk_profiles.py diff --git a/tests/unit/memory/test_risk_profiles.py b/tests/unit/memory/test_risk_profiles.py new file mode 100644 index 00000000..7fdc3b36 --- /dev/null +++ b/tests/unit/memory/test_risk_profiles.py @@ -0,0 +1,966 @@ +"""Tests for Risk Profiles Memory module. + +Issue #20: [MEM-19] Risk profiles memory - user preferences over time +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import Mock, patch + +from tradingagents.memory.risk_profiles import ( + RiskProfileMemory, + RiskProfile, + RiskDecision, + RiskTolerance, + MarketRegime, + RiskCategory, +) + + +# ============================================================================= +# RiskTolerance Enum Tests +# ============================================================================= + + +class TestRiskTolerance: + """Tests for RiskTolerance enum.""" + + def test_from_score_conservative(self): + """Test conservative threshold (< 0.25).""" + assert RiskTolerance.from_score(0.0) == RiskTolerance.CONSERVATIVE + assert RiskTolerance.from_score(0.1) == RiskTolerance.CONSERVATIVE + assert RiskTolerance.from_score(0.24) == RiskTolerance.CONSERVATIVE + + def test_from_score_moderate(self): + """Test moderate threshold (0.25 - 0.50).""" + assert RiskTolerance.from_score(0.25) == RiskTolerance.MODERATE + assert RiskTolerance.from_score(0.37) == RiskTolerance.MODERATE + assert RiskTolerance.from_score(0.49) == RiskTolerance.MODERATE + + def test_from_score_aggressive(self): + """Test aggressive threshold (0.50 - 0.75).""" + assert RiskTolerance.from_score(0.50) == RiskTolerance.AGGRESSIVE + assert RiskTolerance.from_score(0.62) == RiskTolerance.AGGRESSIVE + assert RiskTolerance.from_score(0.74) == RiskTolerance.AGGRESSIVE + + def test_from_score_very_aggressive(self): + """Test very aggressive threshold (>= 0.75).""" + assert RiskTolerance.from_score(0.75) == RiskTolerance.VERY_AGGRESSIVE + assert RiskTolerance.from_score(0.9) == RiskTolerance.VERY_AGGRESSIVE + assert RiskTolerance.from_score(1.0) == RiskTolerance.VERY_AGGRESSIVE + + def test_to_score(self): + """Test conversion to numeric score.""" + assert RiskTolerance.CONSERVATIVE.to_score() == 0.125 + assert RiskTolerance.MODERATE.to_score() == 0.375 + assert RiskTolerance.AGGRESSIVE.to_score() == 0.625 + assert RiskTolerance.VERY_AGGRESSIVE.to_score() == 0.875 + + def test_roundtrip_approximate(self): + """Test from_score(to_score()) is consistent.""" + for tolerance in RiskTolerance: + score = tolerance.to_score() + recovered = RiskTolerance.from_score(score) + assert recovered == tolerance + + +# ============================================================================= +# MarketRegime Enum Tests +# ============================================================================= + + +class TestMarketRegime: + """Tests for MarketRegime enum.""" + + def test_all_regimes_defined(self): + """Test all expected regimes exist.""" + expected = ["bull", "bear", "sideways", "high_volatility", "low_volatility", "crisis"] + for regime_value in expected: + assert MarketRegime(regime_value) is not None + + def test_regime_values(self): + """Test regime string values.""" + assert MarketRegime.BULL.value == "bull" + assert MarketRegime.BEAR.value == "bear" + assert MarketRegime.CRISIS.value == "crisis" + + +# ============================================================================= +# RiskCategory Enum Tests +# ============================================================================= + + +class TestRiskCategory: + """Tests for RiskCategory enum.""" + + def test_all_categories_defined(self): + """Test all expected categories exist.""" + expected = [ + "position_size", "leverage", "diversification", + "hedging", "stop_loss", "sector_exposure", "asset_class" + ] + for cat_value in expected: + assert RiskCategory(cat_value) is not None + + +# ============================================================================= +# RiskDecision Tests +# ============================================================================= + + +class TestRiskDecision: + """Tests for RiskDecision dataclass.""" + + def test_create_decision(self): + """Test creating a risk decision.""" + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.6, + market_regime=MarketRegime.BULL, + context="Strong momentum in tech", + ) + + assert decision.id is not None + assert decision.category == RiskCategory.POSITION_SIZE + assert decision.risk_level == 0.6 + assert decision.market_regime == MarketRegime.BULL + assert decision.context == "Strong momentum in tech" + assert decision.outcome is None + assert decision.was_appropriate is None + + def test_create_with_vix(self): + """Test creating decision with VIX level.""" + decision = RiskDecision.create( + category=RiskCategory.LEVERAGE, + risk_level=0.3, + market_regime=MarketRegime.HIGH_VOLATILITY, + context="Market stress", + vix_level=32.5, + ) + + assert decision.vix_level == 32.5 + + def test_create_with_notes(self): + """Test creating decision with notes.""" + decision = RiskDecision.create( + category=RiskCategory.HEDGING, + risk_level=0.4, + market_regime=MarketRegime.BEAR, + context="Protective puts", + notes="Weekly expiration", + ) + + assert decision.notes == "Weekly expiration" + + def test_create_validates_risk_level_low(self): + """Test risk level validation - too low.""" + with pytest.raises(ValueError, match="Risk level must be between 0 and 1"): + RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=-0.1, + market_regime=MarketRegime.BULL, + context="Test", + ) + + def test_create_validates_risk_level_high(self): + """Test risk level validation - too high.""" + with pytest.raises(ValueError, match="Risk level must be between 0 and 1"): + RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=1.5, + market_regime=MarketRegime.BULL, + context="Test", + ) + + def test_create_boundary_values(self): + """Test boundary values for risk level.""" + decision_low = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.0, + market_regime=MarketRegime.BULL, + context="Minimum risk", + ) + assert decision_low.risk_level == 0.0 + + decision_high = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=1.0, + market_regime=MarketRegime.BULL, + context="Maximum risk", + ) + assert decision_high.risk_level == 1.0 + + def test_evaluate_decision(self): + """Test evaluating a decision with outcome.""" + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.6, + market_regime=MarketRegime.BULL, + context="Strong momentum", + ) + + result = decision.evaluate( + outcome="Profitable trade", + outcome_score=0.8, + was_appropriate=True, + ) + + assert result is decision + assert decision.outcome == "Profitable trade" + assert decision.outcome_score == 0.8 + assert decision.was_appropriate is True + + def test_evaluate_clamps_outcome_score(self): + """Test outcome score is clamped to [-1, 1].""" + decision = RiskDecision.create( + category=RiskCategory.LEVERAGE, + risk_level=0.9, + market_regime=MarketRegime.BULL, + context="High leverage", + ) + + decision.evaluate("Loss", -2.0, False) + assert decision.outcome_score == -1.0 + + decision.evaluate("Huge win", 5.0, True) + assert decision.outcome_score == 1.0 + + def test_to_dict(self): + """Test serialization to dictionary.""" + decision = RiskDecision.create( + category=RiskCategory.STOP_LOSS, + risk_level=0.3, + market_regime=MarketRegime.SIDEWAYS, + context="Range-bound market", + vix_level=18.0, + ) + decision.evaluate("Hit stop", -0.5, True) + + data = decision.to_dict() + + assert data["id"] == decision.id + assert data["category"] == "stop_loss" + assert data["risk_level"] == 0.3 + assert data["market_regime"] == "sideways" + assert data["context"] == "Range-bound market" + assert data["vix_level"] == 18.0 + assert data["outcome"] == "Hit stop" + assert data["outcome_score"] == -0.5 + assert data["was_appropriate"] is True + + def test_from_dict(self): + """Test deserialization from dictionary.""" + original = RiskDecision.create( + category=RiskCategory.DIVERSIFICATION, + risk_level=0.5, + market_regime=MarketRegime.LOW_VOLATILITY, + context="Adding sectors", + ) + original.evaluate("Good diversification", 0.3, True) + + data = original.to_dict() + restored = RiskDecision.from_dict(data) + + assert restored.id == original.id + assert restored.category == original.category + assert restored.risk_level == original.risk_level + assert restored.market_regime == original.market_regime + assert restored.outcome == original.outcome + assert restored.was_appropriate == original.was_appropriate + + +# ============================================================================= +# RiskProfile Tests +# ============================================================================= + + +class TestRiskProfile: + """Tests for RiskProfile dataclass.""" + + def test_create_default_profile(self): + """Test creating profile with defaults.""" + profile = RiskProfile(user_id="user1") + + assert profile.user_id == "user1" + assert profile.base_tolerance == RiskTolerance.MODERATE + assert profile.max_drawdown_tolerance == 0.20 + assert profile.volatility_preference == 0.15 + + def test_default_regime_adjustments(self): + """Test default regime adjustments are set.""" + profile = RiskProfile(user_id="user1") + + assert profile.regime_adjustments[MarketRegime.BULL.value] == 0.1 + assert profile.regime_adjustments[MarketRegime.BEAR.value] == -0.2 + assert profile.regime_adjustments[MarketRegime.CRISIS.value] == -0.5 + + def test_get_adjusted_risk_score_bull(self): + """Test adjusted score in bull market.""" + profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.MODERATE) + + score = profile.get_adjusted_risk_score(MarketRegime.BULL) + + # MODERATE = 0.375, BULL adjustment = 0.1 + expected = 0.375 + 0.1 + assert abs(score - expected) < 0.001 + + def test_get_adjusted_risk_score_crisis(self): + """Test adjusted score in crisis.""" + profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.AGGRESSIVE) + + score = profile.get_adjusted_risk_score(MarketRegime.CRISIS) + + # AGGRESSIVE = 0.625, CRISIS adjustment = -0.5 + expected = 0.625 - 0.5 + assert abs(score - expected) < 0.001 + + def test_get_adjusted_risk_score_clamped_low(self): + """Test adjusted score is clamped at 0.""" + profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.CONSERVATIVE) + + score = profile.get_adjusted_risk_score(MarketRegime.CRISIS) + + # CONSERVATIVE = 0.125, CRISIS = -0.5, would be negative + assert score == 0.0 + + def test_get_adjusted_risk_score_clamped_high(self): + """Test adjusted score is clamped at 1.""" + profile = RiskProfile( + user_id="user1", + base_tolerance=RiskTolerance.VERY_AGGRESSIVE, + regime_adjustments={MarketRegime.BULL.value: 0.5}, + ) + + score = profile.get_adjusted_risk_score(MarketRegime.BULL) + + # VERY_AGGRESSIVE = 0.875, BULL = 0.5, would exceed 1.0 + assert score == 1.0 + + def test_get_adjusted_tolerance(self): + """Test getting adjusted tolerance enum.""" + profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.AGGRESSIVE) + + # In crisis, aggressive becomes moderate + tolerance = profile.get_adjusted_tolerance(MarketRegime.CRISIS) + assert tolerance == RiskTolerance.CONSERVATIVE + + def test_update_regime_adjustment(self): + """Test updating regime adjustment.""" + profile = RiskProfile(user_id="user1") + original_updated = profile.updated_at + + profile.update_regime_adjustment(MarketRegime.BEAR, -0.4) + + assert profile.regime_adjustments[MarketRegime.BEAR.value] == -0.4 + assert profile.updated_at >= original_updated + + def test_update_regime_adjustment_clamped(self): + """Test adjustment is clamped to [-1, 1].""" + profile = RiskProfile(user_id="user1") + + profile.update_regime_adjustment(MarketRegime.BULL, 2.0) + assert profile.regime_adjustments[MarketRegime.BULL.value] == 1.0 + + profile.update_regime_adjustment(MarketRegime.CRISIS, -2.0) + assert profile.regime_adjustments[MarketRegime.CRISIS.value] == -1.0 + + def test_to_dict(self): + """Test serialization to dictionary.""" + profile = RiskProfile( + user_id="user1", + base_tolerance=RiskTolerance.AGGRESSIVE, + max_drawdown_tolerance=0.15, + ) + + data = profile.to_dict() + + assert data["user_id"] == "user1" + assert data["base_tolerance"] == "aggressive" + assert data["max_drawdown_tolerance"] == 0.15 + assert "regime_adjustments" in data + assert "created_at" in data + + def test_from_dict(self): + """Test deserialization from dictionary.""" + original = RiskProfile( + user_id="user1", + base_tolerance=RiskTolerance.CONSERVATIVE, + max_drawdown_tolerance=0.10, + ) + + data = original.to_dict() + restored = RiskProfile.from_dict(data) + + assert restored.user_id == original.user_id + assert restored.base_tolerance == original.base_tolerance + assert restored.max_drawdown_tolerance == original.max_drawdown_tolerance + + +# ============================================================================= +# RiskProfileMemory Tests +# ============================================================================= + + +class TestRiskProfileMemory: + """Tests for RiskProfileMemory class.""" + + def test_create_memory(self): + """Test creating memory instance.""" + memory = RiskProfileMemory() + assert memory.count() == 0 + + def test_set_and_get_profile(self): + """Test setting and getting a profile.""" + memory = RiskProfileMemory() + profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.AGGRESSIVE) + + memory.set_profile(profile) + retrieved = memory.get_profile("user1") + + assert retrieved is not None + assert retrieved.user_id == "user1" + assert retrieved.base_tolerance == RiskTolerance.AGGRESSIVE + + def test_get_profile_default_user(self): + """Test getting default user profile.""" + memory = RiskProfileMemory() + profile = RiskProfile(user_id="default") + memory.set_profile(profile) + + retrieved = memory.get_profile() + assert retrieved.user_id == "default" + + def test_get_nonexistent_profile(self): + """Test getting nonexistent profile returns None.""" + memory = RiskProfileMemory() + assert memory.get_profile("unknown") is None + + def test_get_or_create_profile_new(self): + """Test get_or_create creates new profile.""" + memory = RiskProfileMemory() + + profile = memory.get_or_create_profile("new_user", RiskTolerance.CONSERVATIVE) + + assert profile.user_id == "new_user" + assert profile.base_tolerance == RiskTolerance.CONSERVATIVE + + def test_get_or_create_profile_existing(self): + """Test get_or_create returns existing profile.""" + memory = RiskProfileMemory() + original = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.AGGRESSIVE) + memory.set_profile(original) + + profile = memory.get_or_create_profile("user1", RiskTolerance.CONSERVATIVE) + + assert profile.base_tolerance == RiskTolerance.AGGRESSIVE + + def test_record_decision(self): + """Test recording a decision.""" + memory = RiskProfileMemory() + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.5, + market_regime=MarketRegime.BULL, + context="Tech momentum", + ) + + decision_id = memory.record_decision(decision) + + assert decision_id == decision.id + assert memory.count() == 1 + + def test_get_decision(self): + """Test retrieving a decision by ID.""" + memory = RiskProfileMemory() + decision = RiskDecision.create( + category=RiskCategory.LEVERAGE, + risk_level=0.7, + market_regime=MarketRegime.LOW_VOLATILITY, + context="Low vol environment", + ) + memory.record_decision(decision) + + retrieved = memory.get_decision(decision.id) + + assert retrieved is not None + assert retrieved.risk_level == 0.7 + + def test_get_nonexistent_decision(self): + """Test getting nonexistent decision returns None.""" + memory = RiskProfileMemory() + assert memory.get_decision("unknown") is None + + def test_evaluate_decision(self): + """Test evaluating a recorded decision.""" + memory = RiskProfileMemory() + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.6, + market_regime=MarketRegime.BULL, + context="Strong momentum", + ) + memory.record_decision(decision) + + result = memory.evaluate_decision( + decision_id=decision.id, + outcome="Profitable", + outcome_score=0.5, + was_appropriate=True, + ) + + assert result is not None + assert result.outcome == "Profitable" + assert result.was_appropriate is True + + def test_evaluate_nonexistent_decision(self): + """Test evaluating nonexistent decision.""" + memory = RiskProfileMemory() + result = memory.evaluate_decision("unknown", "Test", 0.5, True) + assert result is None + + def test_find_similar_decisions(self): + """Test finding similar decisions.""" + memory = RiskProfileMemory() + + # Record several decisions + for i in range(5): + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.5 + i * 0.05, + market_regime=MarketRegime.BULL, + context=f"Tech momentum scenario {i}", + ) + memory.record_decision(decision) + + similar = memory.find_similar_decisions( + context="Tech momentum similar", + category=RiskCategory.POSITION_SIZE, + market_regime=MarketRegime.BULL, + top_k=3, + ) + + assert len(similar) <= 3 + for decision in similar: + assert decision.category == RiskCategory.POSITION_SIZE + + def test_find_similar_decisions_with_filters(self): + """Test finding similar decisions with regime filter.""" + memory = RiskProfileMemory() + + # Record decisions in different regimes + for regime in [MarketRegime.BULL, MarketRegime.BEAR]: + for i in range(3): + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.5, + market_regime=regime, + context=f"Scenario {i}", + ) + memory.record_decision(decision) + + # Filter by BULL regime only + similar = memory.find_similar_decisions( + context="Find scenario", + market_regime=MarketRegime.BULL, + top_k=10, + ) + + for decision in similar: + assert decision.market_regime == MarketRegime.BULL + + def test_recommend_risk_level_no_history(self): + """Test recommendation without history uses profile only.""" + memory = RiskProfileMemory() + profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE) + memory.set_profile(profile) + + risk_level, explanation = memory.recommend_risk_level( + category=RiskCategory.POSITION_SIZE, + market_regime=MarketRegime.BULL, + context="New situation", + use_history=False, + ) + + # MODERATE (0.375) + BULL adjustment (0.1) + expected = 0.375 + 0.1 + assert abs(risk_level - expected) < 0.01 + assert "Base risk from profile" in explanation + + def test_recommend_risk_level_with_history(self): + """Test recommendation with historical decisions.""" + memory = RiskProfileMemory() + profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE) + memory.set_profile(profile) + + # Record successful decisions + for i in range(3): + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.7, + market_regime=MarketRegime.BULL, + context=f"Momentum play {i}", + ) + decision.evaluate("Profitable", 0.6, True) + memory.record_decision(decision) + + risk_level, explanation = memory.recommend_risk_level( + category=RiskCategory.POSITION_SIZE, + market_regime=MarketRegime.BULL, + context="Similar momentum play", + use_history=True, + ) + + # Should be influenced by successful 0.7 risk decisions + assert risk_level > 0.5 # Should be higher than base moderate + assert "successful similar decisions" in explanation + + def test_recommend_warns_about_unsuccessful(self): + """Test recommendation when only unsuccessful decisions exist.""" + memory = RiskProfileMemory() + profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE) + memory.set_profile(profile) + + # Record unsuccessful decisions with matching context words + for i in range(3): + decision = RiskDecision.create( + category=RiskCategory.LEVERAGE, + risk_level=0.8, + market_regime=MarketRegime.HIGH_VOLATILITY, + context="High leverage situation volatility trade", + ) + decision.evaluate("Loss", -0.5, False) + memory.record_decision(decision) + + risk_level, explanation = memory.recommend_risk_level( + category=RiskCategory.LEVERAGE, + market_regime=MarketRegime.HIGH_VOLATILITY, + context="High leverage situation volatility trade", + use_history=True, + ) + + # Should either use base risk (no successful similar) or warn about unsuccessful + # The explanation should NOT include "successful similar decisions" since there are none + assert "successful similar decisions" not in explanation or "WARNING" in explanation + + def test_get_regime_statistics_empty(self): + """Test regime statistics with no decisions.""" + memory = RiskProfileMemory() + stats = memory.get_regime_statistics() + + for regime in MarketRegime: + assert stats[regime.value]["count"] == 0 + assert stats[regime.value]["avg_risk_level"] is None + + def test_get_regime_statistics(self): + """Test regime statistics with decisions.""" + memory = RiskProfileMemory() + + # Record decisions in bull market + for i in range(5): + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.6 + i * 0.02, + market_regime=MarketRegime.BULL, + context=f"Bull scenario {i}", + ) + if i < 3: + decision.evaluate("Win", 0.5, True) + else: + decision.evaluate("Loss", -0.3, False) + memory.record_decision(decision) + + stats = memory.get_regime_statistics() + + assert stats[MarketRegime.BULL.value]["count"] == 5 + assert 0.6 <= stats[MarketRegime.BULL.value]["avg_risk_level"] <= 0.7 + assert stats[MarketRegime.BULL.value]["success_rate"] == 0.6 # 3/5 + + def test_get_category_statistics(self): + """Test category statistics.""" + memory = RiskProfileMemory() + + for i in range(4): + decision = RiskDecision.create( + category=RiskCategory.STOP_LOSS, + risk_level=0.3 + i * 0.05, + market_regime=MarketRegime.SIDEWAYS, + context=f"Stop loss scenario {i}", + ) + decision.evaluate("Hit stop", -0.2, i % 2 == 0) + memory.record_decision(decision) + + stats = memory.get_category_statistics() + + assert stats[RiskCategory.STOP_LOSS.value]["count"] == 4 + assert stats[RiskCategory.STOP_LOSS.value]["success_rate"] == 0.5 + + def test_learn_regime_adjustments_insufficient_data(self): + """Test learning with insufficient data.""" + memory = RiskProfileMemory() + + # Only 2 decisions (below min_decisions=5) + for i in range(2): + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.5, + market_regime=MarketRegime.BULL, + context=f"Scenario {i}", + ) + decision.evaluate("Win", 0.5, True) + memory.record_decision(decision) + + suggestions = memory.learn_regime_adjustments(min_decisions=5) + + # Should not have suggestions for BULL (only 2 decisions) + assert MarketRegime.BULL.value not in suggestions + + def test_learn_regime_adjustments_successful(self): + """Test learning from successful decisions.""" + memory = RiskProfileMemory() + profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE) + memory.set_profile(profile) + + # Record successful high-risk decisions in bull + for i in range(6): + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.7, # Higher than base (0.375) + market_regime=MarketRegime.BULL, + context=f"Successful bull trade {i}", + ) + decision.evaluate("Profit", 0.6, True) + memory.record_decision(decision) + + suggestions = memory.learn_regime_adjustments(min_decisions=5) + + # Should suggest positive adjustment for bull (0.7 > 0.375) + assert MarketRegime.BULL.value in suggestions + assert suggestions[MarketRegime.BULL.value] > 0 + + def test_learn_regime_adjustments_unsuccessful(self): + """Test learning from unsuccessful decisions.""" + memory = RiskProfileMemory() + profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.AGGRESSIVE) + memory.set_profile(profile) + + # Record all unsuccessful decisions in crisis + for i in range(6): + decision = RiskDecision.create( + category=RiskCategory.LEVERAGE, + risk_level=0.8, + market_regime=MarketRegime.CRISIS, + context=f"Crisis loss {i}", + ) + decision.evaluate("Loss", -0.7, False) + memory.record_decision(decision) + + suggestions = memory.learn_regime_adjustments(min_decisions=5) + + # Should suggest negative adjustment (lower risk in crisis) + assert MarketRegime.CRISIS.value in suggestions + assert suggestions[MarketRegime.CRISIS.value] < 0 + + def test_count(self): + """Test decision count.""" + memory = RiskProfileMemory() + assert memory.count() == 0 + + for i in range(5): + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.5, + market_regime=MarketRegime.BULL, + context=f"Scenario {i}", + ) + memory.record_decision(decision) + + assert memory.count() == 5 + + def test_clear(self): + """Test clearing decisions (preserving profiles).""" + memory = RiskProfileMemory() + profile = RiskProfile(user_id="user1") + memory.set_profile(profile) + + for i in range(3): + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.5, + market_regime=MarketRegime.BULL, + context=f"Scenario {i}", + ) + memory.record_decision(decision) + + assert memory.count() == 3 + cleared = memory.clear() + + assert cleared == 3 + assert memory.count() == 0 + assert memory.get_profile("user1") is not None # Profile preserved + + def test_to_dict(self): + """Test serialization to dictionary.""" + memory = RiskProfileMemory() + profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.AGGRESSIVE) + memory.set_profile(profile) + + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.6, + market_regime=MarketRegime.BULL, + context="Test scenario", + ) + memory.record_decision(decision) + + data = memory.to_dict() + + assert "profiles" in data + assert "user1" in data["profiles"] + assert "decisions" in data + assert len(data["decisions"]) == 1 + assert "memory" in data + + def test_from_dict(self): + """Test deserialization from dictionary.""" + memory = RiskProfileMemory() + profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.CONSERVATIVE) + memory.set_profile(profile) + + decision = RiskDecision.create( + category=RiskCategory.HEDGING, + risk_level=0.4, + market_regime=MarketRegime.BEAR, + context="Protective strategy", + ) + decision.evaluate("Good protection", 0.5, True) + memory.record_decision(decision) + + data = memory.to_dict() + restored = RiskProfileMemory.from_dict(data) + + assert restored.get_profile("user1") is not None + assert restored.get_profile("user1").base_tolerance == RiskTolerance.CONSERVATIVE + assert restored.count() == 1 + assert restored.get_decision(decision.id) is not None + + def test_multiple_users(self): + """Test handling multiple user profiles.""" + memory = RiskProfileMemory() + + users = ["alice", "bob", "charlie"] + tolerances = [RiskTolerance.CONSERVATIVE, RiskTolerance.MODERATE, RiskTolerance.AGGRESSIVE] + + for user, tolerance in zip(users, tolerances): + profile = RiskProfile(user_id=user, base_tolerance=tolerance) + memory.set_profile(profile) + + for user, expected_tolerance in zip(users, tolerances): + profile = memory.get_profile(user) + assert profile.base_tolerance == expected_tolerance + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +class TestRiskProfileMemoryIntegration: + """Integration tests for RiskProfileMemory.""" + + def test_full_workflow(self): + """Test complete workflow: profile -> decisions -> learning.""" + memory = RiskProfileMemory() + + # 1. Create profile + profile = RiskProfile( + user_id="trader1", + base_tolerance=RiskTolerance.MODERATE, + max_drawdown_tolerance=0.15, + ) + memory.set_profile(profile) + + # 2. Record decisions across regimes + decisions_data = [ + (MarketRegime.BULL, 0.6, True), + (MarketRegime.BULL, 0.7, True), + (MarketRegime.BULL, 0.65, True), + (MarketRegime.BEAR, 0.3, True), + (MarketRegime.BEAR, 0.5, False), + (MarketRegime.CRISIS, 0.2, True), + ] + + for regime, risk, appropriate in decisions_data: + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=risk, + market_regime=regime, + context=f"Trading in {regime.value}", + ) + outcome_score = 0.5 if appropriate else -0.5 + decision.evaluate("Result", outcome_score, appropriate) + memory.record_decision(decision, user_id="trader1") + + # 3. Get recommendations + bull_risk, _ = memory.recommend_risk_level( + category=RiskCategory.POSITION_SIZE, + market_regime=MarketRegime.BULL, + context="Bull market opportunity", + user_id="trader1", + ) + + bear_risk, _ = memory.recommend_risk_level( + category=RiskCategory.POSITION_SIZE, + market_regime=MarketRegime.BEAR, + context="Bear market caution", + user_id="trader1", + ) + + # Bull should recommend higher risk than bear + assert bull_risk > bear_risk + + # 4. Get statistics + stats = memory.get_regime_statistics() + assert stats[MarketRegime.BULL.value]["count"] == 3 + assert stats[MarketRegime.BULL.value]["success_rate"] == 1.0 + + # 5. Serialize and restore + data = memory.to_dict() + restored = RiskProfileMemory.from_dict(data) + + assert restored.get_profile("trader1") is not None + assert restored.count() == 6 + + def test_recommendation_adapts_over_time(self): + """Test that recommendations adapt as more data is collected.""" + memory = RiskProfileMemory() + profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE) + memory.set_profile(profile) + + # Initial recommendation (no history) + initial_risk, _ = memory.recommend_risk_level( + category=RiskCategory.POSITION_SIZE, + market_regime=MarketRegime.BULL, + context="Starting out", + ) + + # Add successful high-risk decisions + for i in range(5): + decision = RiskDecision.create( + category=RiskCategory.POSITION_SIZE, + risk_level=0.8, + market_regime=MarketRegime.BULL, + context=f"Successful trade {i}", + ) + decision.evaluate("Win", 0.7, True) + memory.record_decision(decision) + + # Later recommendation should be influenced by history + later_risk, explanation = memory.recommend_risk_level( + category=RiskCategory.POSITION_SIZE, + market_regime=MarketRegime.BULL, + context="Similar opportunity", + ) + + # Should be higher after seeing successful high-risk trades + assert later_risk > initial_risk + assert "successful similar decisions" in explanation diff --git a/tradingagents/memory/__init__.py b/tradingagents/memory/__init__.py index 86243b2c..f4155d83 100644 --- a/tradingagents/memory/__init__.py +++ b/tradingagents/memory/__init__.py @@ -7,6 +7,7 @@ This module provides a layered memory system with three scoring dimensions: Issue #18: Layered memory - recency, relevancy, importance scoring Issue #19: Trade history memory - outcomes, agent reasoning +Issue #20: Risk profiles memory - user preferences over time """ from .layered_memory import ( @@ -28,6 +29,15 @@ from .trade_history import ( MarketContext, ) +from .risk_profiles import ( + RiskProfileMemory, + RiskProfile, + RiskDecision, + RiskTolerance, + MarketRegime, + RiskCategory, +) + __all__ = [ # Layered Memory (Issue #18) "LayeredMemory", @@ -44,4 +54,11 @@ __all__ = [ "SignalStrength", "AgentReasoning", "MarketContext", + # Risk Profiles (Issue #20) + "RiskProfileMemory", + "RiskProfile", + "RiskDecision", + "RiskTolerance", + "MarketRegime", + "RiskCategory", ] diff --git a/tradingagents/memory/risk_profiles.py b/tradingagents/memory/risk_profiles.py new file mode 100644 index 00000000..726a212a --- /dev/null +++ b/tradingagents/memory/risk_profiles.py @@ -0,0 +1,817 @@ +"""Risk Profiles Memory for tracking user risk preferences over time. + +This module provides memory for tracking and learning from risk preferences: +- Risk tolerance levels across different market conditions +- Risk preference evolution over time +- Market regime-specific risk adjustments +- Historical risk decisions and outcomes + +Issue #20: [MEM-19] Risk profiles memory - user preferences over time +""" + +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from enum import Enum +from typing import Dict, List, Optional, Any, Tuple +import statistics +import uuid + +from .layered_memory import ( + LayeredMemory, + MemoryEntry, + MemoryConfig, + ScoringWeights, + ImportanceLevel, +) + + +class RiskTolerance(Enum): + """User risk tolerance levels.""" + CONSERVATIVE = "conservative" # Low risk, capital preservation + MODERATE = "moderate" # Balanced risk/reward + AGGRESSIVE = "aggressive" # High risk, growth focused + VERY_AGGRESSIVE = "very_aggressive" # Maximum risk tolerance + + @classmethod + def from_score(cls, score: float) -> "RiskTolerance": + """Convert a risk score (0-1) to RiskTolerance. + + Args: + score: Risk score between 0 (conservative) and 1 (very aggressive) + + Returns: + Corresponding RiskTolerance + """ + if score < 0.25: + return cls.CONSERVATIVE + elif score < 0.50: + return cls.MODERATE + elif score < 0.75: + return cls.AGGRESSIVE + else: + return cls.VERY_AGGRESSIVE + + def to_score(self) -> float: + """Convert RiskTolerance to a numeric score. + + Returns: + Score between 0 and 1 + """ + mapping = { + RiskTolerance.CONSERVATIVE: 0.125, + RiskTolerance.MODERATE: 0.375, + RiskTolerance.AGGRESSIVE: 0.625, + RiskTolerance.VERY_AGGRESSIVE: 0.875, + } + return mapping[self] + + +class MarketRegime(Enum): + """Market regime classifications.""" + BULL = "bull" # Strong uptrend + BEAR = "bear" # Strong downtrend + SIDEWAYS = "sideways" # Range-bound + HIGH_VOLATILITY = "high_volatility" # VIX > 25 + LOW_VOLATILITY = "low_volatility" # VIX < 15 + CRISIS = "crisis" # Market stress/crash + + +class RiskCategory(Enum): + """Categories of risk decisions.""" + POSITION_SIZE = "position_size" # How much to invest + LEVERAGE = "leverage" # Use of leverage + DIVERSIFICATION = "diversification" # Portfolio spread + HEDGING = "hedging" # Protective positions + STOP_LOSS = "stop_loss" # Exit thresholds + SECTOR_EXPOSURE = "sector_exposure" # Sector concentration + ASSET_CLASS = "asset_class" # Asset allocation + + +@dataclass +class RiskDecision: + """A recorded risk decision with context and outcome. + + Attributes: + id: Unique decision ID + timestamp: When decision was made + category: Type of risk decision + risk_level: Risk level chosen (0-1 scale) + market_regime: Market conditions at decision time + context: Situation description + vix_level: VIX at decision time + outcome: Outcome description (added later) + outcome_score: Quantified outcome (-1 to 1) + was_appropriate: Whether decision was appropriate in hindsight + notes: Additional notes + """ + id: str + timestamp: datetime + category: RiskCategory + risk_level: float # 0 (min risk) to 1 (max risk) + market_regime: MarketRegime + context: str + vix_level: Optional[float] = None + outcome: Optional[str] = None + outcome_score: Optional[float] = None # -1 (bad) to 1 (good) + was_appropriate: Optional[bool] = None + notes: Optional[str] = None + + @classmethod + def create( + cls, + category: RiskCategory, + risk_level: float, + market_regime: MarketRegime, + context: str, + vix_level: Optional[float] = None, + notes: Optional[str] = None, + ) -> "RiskDecision": + """Create a new risk decision record. + + Args: + category: Type of risk decision + risk_level: Risk level chosen (0-1) + market_regime: Current market regime + context: Situation description + vix_level: Current VIX level + notes: Additional notes + + Returns: + New RiskDecision instance + """ + if not 0.0 <= risk_level <= 1.0: + raise ValueError(f"Risk level must be between 0 and 1, got {risk_level}") + + return cls( + id=str(uuid.uuid4()), + timestamp=datetime.now(), + category=category, + risk_level=risk_level, + market_regime=market_regime, + context=context, + vix_level=vix_level, + notes=notes, + ) + + def evaluate( + self, + outcome: str, + outcome_score: float, + was_appropriate: bool, + ) -> "RiskDecision": + """Evaluate the decision after the outcome is known. + + Args: + outcome: Description of what happened + outcome_score: Quantified outcome (-1 to 1) + was_appropriate: Whether the risk level was appropriate + + Returns: + Self with updated evaluation + """ + self.outcome = outcome + self.outcome_score = max(-1.0, min(1.0, outcome_score)) + self.was_appropriate = was_appropriate + return self + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + "id": self.id, + "timestamp": self.timestamp.isoformat(), + "category": self.category.value, + "risk_level": self.risk_level, + "market_regime": self.market_regime.value, + "context": self.context, + "vix_level": self.vix_level, + "outcome": self.outcome, + "outcome_score": self.outcome_score, + "was_appropriate": self.was_appropriate, + "notes": self.notes, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "RiskDecision": + """Create from dictionary.""" + return cls( + id=data["id"], + timestamp=datetime.fromisoformat(data["timestamp"]), + category=RiskCategory(data["category"]), + risk_level=data["risk_level"], + market_regime=MarketRegime(data["market_regime"]), + context=data["context"], + vix_level=data.get("vix_level"), + outcome=data.get("outcome"), + outcome_score=data.get("outcome_score"), + was_appropriate=data.get("was_appropriate"), + notes=data.get("notes"), + ) + + +@dataclass +class RiskProfile: + """User's risk profile with preferences and history. + + Attributes: + user_id: User identifier + base_tolerance: Baseline risk tolerance + regime_adjustments: Adjustments by market regime + category_preferences: Preferences by risk category + max_drawdown_tolerance: Maximum acceptable drawdown + volatility_preference: Preferred portfolio volatility + created_at: Profile creation time + updated_at: Last update time + """ + user_id: str + base_tolerance: RiskTolerance = RiskTolerance.MODERATE + regime_adjustments: Dict[str, float] = field(default_factory=dict) + category_preferences: Dict[str, float] = field(default_factory=dict) + max_drawdown_tolerance: float = 0.20 # 20% max drawdown + volatility_preference: float = 0.15 # 15% annual volatility + created_at: datetime = field(default_factory=datetime.now) + updated_at: datetime = field(default_factory=datetime.now) + + def __post_init__(self): + """Initialize default regime adjustments if empty.""" + if not self.regime_adjustments: + self.regime_adjustments = { + MarketRegime.BULL.value: 0.1, # Slightly more risk + MarketRegime.BEAR.value: -0.2, # Reduce risk + MarketRegime.SIDEWAYS.value: 0.0, # No change + MarketRegime.HIGH_VOLATILITY.value: -0.3, # Reduce significantly + MarketRegime.LOW_VOLATILITY.value: 0.1, # Slightly more risk + MarketRegime.CRISIS.value: -0.5, # Maximum reduction + } + + def get_adjusted_risk_score(self, market_regime: MarketRegime) -> float: + """Get risk score adjusted for current market regime. + + Args: + market_regime: Current market regime + + Returns: + Adjusted risk score (0-1) + """ + base_score = self.base_tolerance.to_score() + adjustment = self.regime_adjustments.get(market_regime.value, 0.0) + adjusted = base_score + adjustment + return max(0.0, min(1.0, adjusted)) + + def get_adjusted_tolerance(self, market_regime: MarketRegime) -> RiskTolerance: + """Get risk tolerance adjusted for market regime. + + Args: + market_regime: Current market regime + + Returns: + Adjusted RiskTolerance + """ + score = self.get_adjusted_risk_score(market_regime) + return RiskTolerance.from_score(score) + + def update_regime_adjustment( + self, + regime: MarketRegime, + adjustment: float, + ) -> None: + """Update the adjustment for a specific regime. + + Args: + regime: Market regime to update + adjustment: New adjustment value (-1 to 1) + """ + self.regime_adjustments[regime.value] = max(-1.0, min(1.0, adjustment)) + self.updated_at = datetime.now() + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + "user_id": self.user_id, + "base_tolerance": self.base_tolerance.value, + "regime_adjustments": self.regime_adjustments, + "category_preferences": self.category_preferences, + "max_drawdown_tolerance": self.max_drawdown_tolerance, + "volatility_preference": self.volatility_preference, + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "RiskProfile": + """Create from dictionary.""" + return cls( + user_id=data["user_id"], + base_tolerance=RiskTolerance(data["base_tolerance"]), + regime_adjustments=data.get("regime_adjustments", {}), + category_preferences=data.get("category_preferences", {}), + max_drawdown_tolerance=data.get("max_drawdown_tolerance", 0.20), + volatility_preference=data.get("volatility_preference", 0.15), + created_at=datetime.fromisoformat(data["created_at"]), + updated_at=datetime.fromisoformat(data["updated_at"]), + ) + + +class RiskProfileMemory: + """Memory system for tracking risk profiles and decisions. + + This class provides storage and retrieval for risk profiles and + historical risk decisions, enabling learning from past decisions. + + Example: + >>> memory = RiskProfileMemory() + >>> + >>> # Create a risk profile + >>> profile = RiskProfile(user_id="user1", base_tolerance=RiskTolerance.MODERATE) + >>> memory.set_profile(profile) + >>> + >>> # Record a risk decision + >>> decision = RiskDecision.create( + ... category=RiskCategory.POSITION_SIZE, + ... risk_level=0.6, + ... market_regime=MarketRegime.BULL, + ... context="Strong momentum in tech sector", + ... ) + >>> memory.record_decision(decision) + >>> + >>> # Get recommended risk level for similar situation + >>> recommended = memory.recommend_risk_level( + ... category=RiskCategory.POSITION_SIZE, + ... market_regime=MarketRegime.BULL, + ... context="Tech sector showing strength", + ... ) + """ + + def __init__( + self, + config: Optional[MemoryConfig] = None, + embedding_function=None, + ): + """Initialize risk profile memory. + + Args: + config: Memory configuration + embedding_function: Optional embedding function + """ + if config is None: + config = MemoryConfig( + weights=ScoringWeights( + recency=0.35, # Recent decisions more relevant + relevancy=0.40, # Similar situations important + importance=0.25, # Outcome importance + ), + ) + + self._layered_memory = LayeredMemory( + config=config, + embedding_function=embedding_function, + ) + self._profiles: Dict[str, RiskProfile] = {} + self._decisions: Dict[str, RiskDecision] = {} + self._default_user_id = "default" + + def set_profile(self, profile: RiskProfile) -> None: + """Set or update a user's risk profile. + + Args: + profile: Risk profile to store + """ + self._profiles[profile.user_id] = profile + + def get_profile(self, user_id: Optional[str] = None) -> Optional[RiskProfile]: + """Get a user's risk profile. + + Args: + user_id: User ID (default: "default") + + Returns: + RiskProfile or None + """ + user_id = user_id or self._default_user_id + return self._profiles.get(user_id) + + def get_or_create_profile( + self, + user_id: Optional[str] = None, + base_tolerance: RiskTolerance = RiskTolerance.MODERATE, + ) -> RiskProfile: + """Get existing profile or create a new one. + + Args: + user_id: User ID + base_tolerance: Default tolerance if creating new + + Returns: + RiskProfile + """ + user_id = user_id or self._default_user_id + profile = self._profiles.get(user_id) + + if profile is None: + profile = RiskProfile(user_id=user_id, base_tolerance=base_tolerance) + self._profiles[user_id] = profile + + return profile + + def record_decision( + self, + decision: RiskDecision, + user_id: Optional[str] = None, + ) -> str: + """Record a risk decision. + + Args: + decision: The risk decision to record + user_id: User ID (default: "default") + + Returns: + Decision ID + """ + user_id = user_id or self._default_user_id + self._decisions[decision.id] = decision + + # Calculate importance based on outcome if available + importance = 0.5 + if decision.outcome_score is not None: + importance = 0.5 + (abs(decision.outcome_score) * 0.5) + + # Create memory entry + content = ( + f"Risk decision: {decision.category.value} with risk level " + f"{decision.risk_level:.2f} in {decision.market_regime.value} market. " + f"Context: {decision.context}" + ) + + entry = MemoryEntry.create( + content=content, + metadata={ + "decision_id": decision.id, + "user_id": user_id, + "category": decision.category.value, + "risk_level": decision.risk_level, + "market_regime": decision.market_regime.value, + "vix_level": decision.vix_level, + "outcome": decision.outcome, + "outcome_score": decision.outcome_score, + "was_appropriate": decision.was_appropriate, + }, + importance=importance, + tags=[ + user_id, + decision.category.value, + decision.market_regime.value, + ], + timestamp=decision.timestamp, + ) + entry.id = decision.id + + self._layered_memory.add(entry) + return decision.id + + def evaluate_decision( + self, + decision_id: str, + outcome: str, + outcome_score: float, + was_appropriate: bool, + ) -> Optional[RiskDecision]: + """Evaluate a past decision with hindsight. + + Args: + decision_id: ID of the decision + outcome: What happened + outcome_score: Quantified outcome (-1 to 1) + was_appropriate: Whether decision was appropriate + + Returns: + Updated decision or None + """ + decision = self._decisions.get(decision_id) + if decision is None: + return None + + decision.evaluate(outcome, outcome_score, was_appropriate) + + # Update memory importance + importance = 0.5 + (abs(outcome_score) * 0.5) + self._layered_memory.update_importance(decision_id, importance) + + return decision + + def get_decision(self, decision_id: str) -> Optional[RiskDecision]: + """Get a decision by ID. + + Args: + decision_id: Decision ID + + Returns: + RiskDecision or None + """ + return self._decisions.get(decision_id) + + def find_similar_decisions( + self, + context: str, + category: Optional[RiskCategory] = None, + market_regime: Optional[MarketRegime] = None, + top_k: int = 5, + ) -> List[RiskDecision]: + """Find similar past decisions. + + Args: + context: Current situation context + category: Optional filter by category + market_regime: Optional filter by regime + top_k: Maximum results + + Returns: + List of similar decisions + """ + tags = [] + if category: + tags.append(category.value) + if market_regime: + tags.append(market_regime.value) + + results = self._layered_memory.retrieve( + query=context, + top_k=top_k * 2, + tags=tags if tags else None, + ) + + decisions = [] + for scored in results: + decision_id = scored.entry.metadata.get("decision_id") + if decision_id and decision_id in self._decisions: + decisions.append(self._decisions[decision_id]) + if len(decisions) >= top_k: + break + + return decisions + + def recommend_risk_level( + self, + category: RiskCategory, + market_regime: MarketRegime, + context: str, + user_id: Optional[str] = None, + use_history: bool = True, + ) -> Tuple[float, str]: + """Recommend a risk level based on profile and history. + + Args: + category: Risk category + market_regime: Current market regime + context: Current situation + user_id: User ID + use_history: Whether to consider past decisions + + Returns: + Tuple of (risk_level, explanation) + """ + user_id = user_id or self._default_user_id + profile = self.get_or_create_profile(user_id) + + # Start with profile-based recommendation + base_risk = profile.get_adjusted_risk_score(market_regime) + explanation_parts = [ + f"Base risk from profile: {base_risk:.2f} " + f"({profile.base_tolerance.value} adjusted for {market_regime.value})" + ] + + if not use_history: + return base_risk, " | ".join(explanation_parts) + + # Find similar past decisions + similar = self.find_similar_decisions( + context=context, + category=category, + market_regime=market_regime, + top_k=5, + ) + + if not similar: + explanation_parts.append("No similar past decisions found") + return base_risk, " | ".join(explanation_parts) + + # Analyze outcomes of similar decisions + successful_decisions = [ + d for d in similar + if d.was_appropriate is True + ] + unsuccessful_decisions = [ + d for d in similar + if d.was_appropriate is False + ] + + # Calculate weighted average of successful decisions + if successful_decisions: + successful_avg = statistics.mean([d.risk_level for d in successful_decisions]) + explanation_parts.append( + f"Avg risk level from {len(successful_decisions)} successful similar decisions: " + f"{successful_avg:.2f}" + ) + + # Blend with base risk (weight toward successful history) + adjusted_risk = (base_risk * 0.4) + (successful_avg * 0.6) + else: + adjusted_risk = base_risk + + # Warn about unsuccessful patterns + if unsuccessful_decisions: + unsuccessful_avg = statistics.mean([d.risk_level for d in unsuccessful_decisions]) + if abs(adjusted_risk - unsuccessful_avg) < 0.1: + explanation_parts.append( + f"WARNING: Similar risk level ({unsuccessful_avg:.2f}) was unsuccessful before" + ) + # Adjust away from unsuccessful pattern + if unsuccessful_avg > base_risk: + adjusted_risk = max(0.0, adjusted_risk - 0.1) + else: + adjusted_risk = min(1.0, adjusted_risk + 0.1) + + return adjusted_risk, " | ".join(explanation_parts) + + def get_regime_statistics( + self, + user_id: Optional[str] = None, + ) -> Dict[str, Dict[str, Any]]: + """Get statistics of risk decisions by market regime. + + Args: + user_id: Optional filter by user + + Returns: + Dictionary of statistics by regime + """ + stats: Dict[str, Dict[str, Any]] = {} + + for regime in MarketRegime: + regime_decisions = [ + d for d in self._decisions.values() + if d.market_regime == regime + ] + + if not regime_decisions: + stats[regime.value] = { + "count": 0, + "avg_risk_level": None, + "success_rate": None, + } + continue + + evaluated = [d for d in regime_decisions if d.was_appropriate is not None] + successful = [d for d in evaluated if d.was_appropriate is True] + + stats[regime.value] = { + "count": len(regime_decisions), + "avg_risk_level": statistics.mean([d.risk_level for d in regime_decisions]), + "success_rate": len(successful) / len(evaluated) if evaluated else None, + } + + return stats + + def get_category_statistics( + self, + user_id: Optional[str] = None, + ) -> Dict[str, Dict[str, Any]]: + """Get statistics of risk decisions by category. + + Args: + user_id: Optional filter by user + + Returns: + Dictionary of statistics by category + """ + stats: Dict[str, Dict[str, Any]] = {} + + for category in RiskCategory: + category_decisions = [ + d for d in self._decisions.values() + if d.category == category + ] + + if not category_decisions: + stats[category.value] = { + "count": 0, + "avg_risk_level": None, + "success_rate": None, + } + continue + + evaluated = [d for d in category_decisions if d.was_appropriate is not None] + successful = [d for d in evaluated if d.was_appropriate is True] + + stats[category.value] = { + "count": len(category_decisions), + "avg_risk_level": statistics.mean([d.risk_level for d in category_decisions]), + "success_rate": len(successful) / len(evaluated) if evaluated else None, + } + + return stats + + def learn_regime_adjustments( + self, + user_id: Optional[str] = None, + min_decisions: int = 5, + ) -> Dict[str, float]: + """Learn regime adjustments from historical decisions. + + Analyzes past decisions to suggest optimal regime adjustments. + + Args: + user_id: User ID + min_decisions: Minimum decisions per regime to learn from + + Returns: + Suggested regime adjustments + """ + user_id = user_id or self._default_user_id + profile = self.get_or_create_profile(user_id) + suggestions: Dict[str, float] = {} + + for regime in MarketRegime: + regime_decisions = [ + d for d in self._decisions.values() + if d.market_regime == regime + and d.was_appropriate is not None + ] + + if len(regime_decisions) < min_decisions: + continue + + # Find the risk level with best outcomes + successful = [d for d in regime_decisions if d.was_appropriate] + unsuccessful = [d for d in regime_decisions if not d.was_appropriate] + + if not successful: + # All decisions were unsuccessful - suggest lower risk + avg_failed_risk = statistics.mean([d.risk_level for d in unsuccessful]) + suggested_adjustment = -0.2 # Lower risk + elif not unsuccessful: + # All decisions were successful - keep similar + avg_success_risk = statistics.mean([d.risk_level for d in successful]) + base_score = profile.base_tolerance.to_score() + suggested_adjustment = avg_success_risk - base_score + else: + # Mixed results - prefer successful pattern + avg_success_risk = statistics.mean([d.risk_level for d in successful]) + base_score = profile.base_tolerance.to_score() + suggested_adjustment = avg_success_risk - base_score + + suggestions[regime.value] = max(-0.5, min(0.5, suggested_adjustment)) + + return suggestions + + def count(self) -> int: + """Return total number of decisions.""" + return len(self._decisions) + + def clear(self) -> int: + """Clear all decisions (preserves profiles). + + Returns: + Number of decisions cleared + """ + count = len(self._decisions) + self._decisions.clear() + self._layered_memory.clear() + return count + + def to_dict(self) -> Dict[str, Any]: + """Serialize to dictionary.""" + return { + "profiles": { + uid: p.to_dict() + for uid, p in self._profiles.items() + }, + "decisions": [d.to_dict() for d in self._decisions.values()], + "memory": self._layered_memory.to_dict(), + } + + @classmethod + def from_dict( + cls, + data: Dict[str, Any], + embedding_function=None, + ) -> "RiskProfileMemory": + """Create from dictionary.""" + instance = cls(embedding_function=embedding_function) + + # Restore profiles + for uid, profile_data in data.get("profiles", {}).items(): + profile = RiskProfile.from_dict(profile_data) + instance._profiles[uid] = profile + + # Restore decisions + for decision_data in data.get("decisions", []): + decision = RiskDecision.from_dict(decision_data) + instance._decisions[decision.id] = decision + + # Restore layered memory + if "memory" in data: + instance._layered_memory = LayeredMemory.from_dict( + data["memory"], + embedding_function=embedding_function, + ) + + return instance