diff --git a/tests/unit/memory/test_trade_history.py b/tests/unit/memory/test_trade_history.py new file mode 100644 index 00000000..be7eed98 --- /dev/null +++ b/tests/unit/memory/test_trade_history.py @@ -0,0 +1,740 @@ +"""Tests for Issue #19: Trade History Memory. + +This module tests the trade history memory for tracking: +- Trade outcomes (profit/loss) +- Agent reasoning +- Market context +- Pattern finding +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import MagicMock + +from tradingagents.memory.trade_history import ( + TradeHistoryMemory, + TradeRecord, + TradeOutcome, + TradeDirection, + SignalStrength, + AgentReasoning, + MarketContext, +) + + +# ============================================================================= +# Test Fixtures +# ============================================================================= + +@pytest.fixture +def memory(): + """Create a TradeHistoryMemory instance.""" + return TradeHistoryMemory() + + +@pytest.fixture +def sample_reasoning(): + """Create sample agent reasoning.""" + return AgentReasoning( + fundamentals="Strong earnings growth, P/E below industry average", + technical="Breaking out above 50-day MA with volume", + news="Positive analyst upgrades", + sentiment="Bullish social media sentiment", + bull_case="Undervalued with strong growth trajectory", + bear_case="Macro headwinds could impact demand", + research_conclusion="Buy on fundamental strength", + final_signal="STRONG_BUY", + ) + + +@pytest.fixture +def sample_market_context(): + """Create sample market context.""" + return MarketContext( + vix=18.5, + spy_return_1d=0.005, + sector_performance={"XLK": 0.01, "XLF": -0.005}, + economic_regime="EXPANSION", + yield_curve_state="NORMAL", + macro_indicators={"gdp_growth": 2.5, "unemployment": 4.0}, + ) + + +@pytest.fixture +def sample_trade(sample_reasoning, sample_market_context): + """Create a sample trade record.""" + return TradeRecord.create( + symbol="AAPL", + direction=TradeDirection.LONG, + entry_price=150.0, + quantity=100, + signal_strength=SignalStrength.STRONG_BUY, + confidence=0.85, + reasoning=sample_reasoning, + market_context=sample_market_context, + tags=["tech", "earnings"], + ) + + +@pytest.fixture +def multiple_trades(): + """Create multiple trade records for testing.""" + now = datetime.now() + trades = [] + + # Profitable AAPL trade + trade1 = TradeRecord( + id="trade-1", + symbol="AAPL", + direction=TradeDirection.LONG, + entry_price=150.0, + entry_time=now - timedelta(days=10), + quantity=100, + reasoning=AgentReasoning(research_conclusion="Buy on earnings"), + tags=["tech", "earnings"], + ) + trade1.close(165.0) + + # Loss GOOGL trade + trade2 = TradeRecord( + id="trade-2", + symbol="GOOGL", + direction=TradeDirection.LONG, + entry_price=140.0, + entry_time=now - timedelta(days=5), + quantity=50, + reasoning=AgentReasoning(research_conclusion="Momentum buy"), + tags=["tech"], + ) + trade2.close(130.0) + + # Break-even MSFT trade + trade3 = TradeRecord( + id="trade-3", + symbol="MSFT", + direction=TradeDirection.LONG, + entry_price=350.0, + entry_time=now - timedelta(days=2), + quantity=20, + reasoning=AgentReasoning(research_conclusion="Range trade"), + tags=["tech"], + ) + trade3.close(351.0) + + # Open NVDA trade + trade4 = TradeRecord( + id="trade-4", + symbol="NVDA", + direction=TradeDirection.LONG, + entry_price=500.0, + entry_time=now - timedelta(hours=2), + quantity=10, + reasoning=AgentReasoning(research_conclusion="AI momentum"), + tags=["tech", "ai"], + ) + + trades = [trade1, trade2, trade3, trade4] + return trades + + +# ============================================================================= +# TradeDirection Tests +# ============================================================================= + +class TestTradeDirection: + """Tests for TradeDirection enum.""" + + def test_long_value(self): + """LONG should have correct value.""" + assert TradeDirection.LONG.value == "long" + + def test_short_value(self): + """SHORT should have correct value.""" + assert TradeDirection.SHORT.value == "short" + + def test_hold_value(self): + """HOLD should have correct value.""" + assert TradeDirection.HOLD.value == "hold" + + +# ============================================================================= +# TradeOutcome Tests +# ============================================================================= + +class TestTradeOutcome: + """Tests for TradeOutcome enum.""" + + def test_profitable(self): + """PROFITABLE should have correct value.""" + assert TradeOutcome.PROFITABLE.value == "profitable" + + def test_loss(self): + """LOSS should have correct value.""" + assert TradeOutcome.LOSS.value == "loss" + + def test_break_even(self): + """BREAK_EVEN should have correct value.""" + assert TradeOutcome.BREAK_EVEN.value == "break_even" + + +# ============================================================================= +# SignalStrength Tests +# ============================================================================= + +class TestSignalStrength: + """Tests for SignalStrength enum.""" + + def test_signal_values(self): + """All signal strengths should have correct values.""" + assert SignalStrength.STRONG_BUY.value == "strong_buy" + assert SignalStrength.BUY.value == "buy" + assert SignalStrength.NEUTRAL.value == "neutral" + assert SignalStrength.SELL.value == "sell" + assert SignalStrength.STRONG_SELL.value == "strong_sell" + + +# ============================================================================= +# AgentReasoning Tests +# ============================================================================= + +class TestAgentReasoning: + """Tests for AgentReasoning class.""" + + def test_default_reasoning(self): + """Default reasoning should have None values.""" + reasoning = AgentReasoning() + assert reasoning.fundamentals is None + assert reasoning.technical is None + assert reasoning.research_conclusion is None + + def test_reasoning_with_values(self, sample_reasoning): + """Reasoning should store values correctly.""" + assert sample_reasoning.fundamentals is not None + assert sample_reasoning.research_conclusion == "Buy on fundamental strength" + + def test_to_dict(self, sample_reasoning): + """To dict should serialize correctly.""" + data = sample_reasoning.to_dict() + assert data["fundamentals"] == sample_reasoning.fundamentals + assert data["research_conclusion"] == sample_reasoning.research_conclusion + + def test_from_dict(self, sample_reasoning): + """From dict should deserialize correctly.""" + data = sample_reasoning.to_dict() + restored = AgentReasoning.from_dict(data) + assert restored.fundamentals == sample_reasoning.fundamentals + assert restored.research_conclusion == sample_reasoning.research_conclusion + + def test_summary(self, sample_reasoning): + """Summary should generate text.""" + summary = sample_reasoning.summary() + assert "Fundamentals" in summary + assert "Conclusion" in summary + + def test_empty_summary(self): + """Empty reasoning should have fallback summary.""" + reasoning = AgentReasoning() + summary = reasoning.summary() + assert summary == "No reasoning recorded" + + +# ============================================================================= +# MarketContext Tests +# ============================================================================= + +class TestMarketContext: + """Tests for MarketContext class.""" + + def test_default_context(self): + """Default context should have None/empty values.""" + context = MarketContext() + assert context.vix is None + assert context.sector_performance == {} + + def test_context_with_values(self, sample_market_context): + """Context should store values correctly.""" + assert sample_market_context.vix == 18.5 + assert sample_market_context.economic_regime == "EXPANSION" + assert "XLK" in sample_market_context.sector_performance + + def test_to_dict(self, sample_market_context): + """To dict should serialize correctly.""" + data = sample_market_context.to_dict() + assert data["vix"] == 18.5 + assert data["economic_regime"] == "EXPANSION" + + def test_from_dict(self, sample_market_context): + """From dict should deserialize correctly.""" + data = sample_market_context.to_dict() + restored = MarketContext.from_dict(data) + assert restored.vix == sample_market_context.vix + assert restored.economic_regime == sample_market_context.economic_regime + + def test_summary(self, sample_market_context): + """Summary should generate text.""" + summary = sample_market_context.summary() + assert "VIX" in summary + assert "Regime" in summary + + +# ============================================================================= +# TradeRecord Tests +# ============================================================================= + +class TestTradeRecord: + """Tests for TradeRecord class.""" + + def test_create_trade(self): + """Create should generate a valid trade record.""" + trade = TradeRecord.create( + symbol="AAPL", + direction=TradeDirection.LONG, + entry_price=150.0, + ) + assert trade.symbol == "AAPL" + assert trade.direction == TradeDirection.LONG + assert trade.entry_price == 150.0 + assert trade.id is not None + assert trade.entry_time is not None + assert trade.is_open() + + def test_create_with_all_args(self, sample_reasoning, sample_market_context): + """Create with all arguments should work.""" + trade = TradeRecord.create( + symbol="AAPL", + direction=TradeDirection.LONG, + entry_price=150.0, + quantity=100, + signal_strength=SignalStrength.STRONG_BUY, + confidence=0.85, + reasoning=sample_reasoning, + market_context=sample_market_context, + tags=["tech"], + ) + assert trade.quantity == 100 + assert trade.signal_strength == SignalStrength.STRONG_BUY + assert trade.confidence == 0.85 + assert trade.reasoning == sample_reasoning + assert trade.market_context == sample_market_context + + def test_close_profitable(self, sample_trade): + """Closing at higher price should be profitable.""" + sample_trade.close(165.0) + + assert sample_trade.exit_price == 165.0 + assert sample_trade.exit_time is not None + assert sample_trade.returns is not None + assert sample_trade.returns > 0 + assert sample_trade.outcome == TradeOutcome.PROFITABLE + assert not sample_trade.is_open() + + def test_close_loss(self, sample_trade): + """Closing at lower price should be a loss.""" + sample_trade.close(140.0) + + assert sample_trade.returns < 0 + assert sample_trade.outcome == TradeOutcome.LOSS + + def test_close_break_even(self, sample_trade): + """Closing at same price should be break even.""" + sample_trade.close(150.2) # Within 0.5% threshold + + assert abs(sample_trade.returns) < 0.005 + assert sample_trade.outcome == TradeOutcome.BREAK_EVEN + + def test_close_short_profitable(self): + """Short trade profitable when price goes down.""" + trade = TradeRecord.create( + symbol="AAPL", + direction=TradeDirection.SHORT, + entry_price=150.0, + ) + trade.close(140.0) + + assert trade.returns > 0 + assert trade.outcome == TradeOutcome.PROFITABLE + + def test_close_short_loss(self): + """Short trade loss when price goes up.""" + trade = TradeRecord.create( + symbol="AAPL", + direction=TradeDirection.SHORT, + entry_price=150.0, + ) + trade.close(160.0) + + assert trade.returns < 0 + assert trade.outcome == TradeOutcome.LOSS + + def test_pnl_calculation(self, sample_trade): + """PnL should be calculated correctly.""" + sample_trade.close(165.0) + + expected_return = (165 - 150) / 150 + expected_pnl = expected_return * 150 * 100 + + assert abs(sample_trade.pnl - expected_pnl) < 0.01 + + def test_holding_period(self, sample_trade): + """Holding period should be calculated correctly.""" + assert sample_trade.holding_period_days() is None # Still open + + sample_trade.close(165.0) + holding = sample_trade.holding_period_days() + assert holding is not None + assert holding >= 0 + + def test_to_memory_content(self, sample_trade): + """Memory content should include key info.""" + sample_trade.close(165.0) + content = sample_trade.to_memory_content() + + assert "AAPL" in content + assert "long" in content + assert "profitable" in content + + def test_to_dict(self, sample_trade): + """To dict should serialize correctly.""" + data = sample_trade.to_dict() + + assert data["symbol"] == "AAPL" + assert data["direction"] == "long" + assert data["entry_price"] == 150.0 + assert "reasoning" in data + assert "market_context" in data + + def test_from_dict(self, sample_trade): + """From dict should deserialize correctly.""" + sample_trade.close(165.0) + data = sample_trade.to_dict() + restored = TradeRecord.from_dict(data) + + assert restored.symbol == sample_trade.symbol + assert restored.direction == sample_trade.direction + assert restored.entry_price == sample_trade.entry_price + assert restored.exit_price == sample_trade.exit_price + assert restored.outcome == sample_trade.outcome + + +# ============================================================================= +# TradeHistoryMemory Basic Tests +# ============================================================================= + +class TestTradeHistoryMemoryBasic: + """Basic tests for TradeHistoryMemory.""" + + def test_create_empty(self, memory): + """Empty memory should have zero trades.""" + assert memory.count() == 0 + + def test_record_trade(self, memory, sample_trade): + """Recording a trade should increase count.""" + trade_id = memory.record_trade(sample_trade) + assert trade_id == sample_trade.id + assert memory.count() == 1 + + def test_get_trade(self, memory, sample_trade): + """Get should return recorded trade.""" + memory.record_trade(sample_trade) + retrieved = memory.get_trade(sample_trade.id) + assert retrieved is not None + assert retrieved.symbol == sample_trade.symbol + + def test_get_nonexistent(self, memory): + """Get nonexistent trade should return None.""" + result = memory.get_trade("nonexistent") + assert result is None + + def test_close_trade(self, memory, sample_trade): + """Closing a trade should update it.""" + memory.record_trade(sample_trade) + closed = memory.close_trade(sample_trade.id, exit_price=165.0) + + assert closed is not None + assert closed.exit_price == 165.0 + assert closed.outcome is not None + + def test_close_with_lessons(self, memory, sample_trade): + """Closing with lessons should store them.""" + memory.record_trade(sample_trade) + memory.close_trade( + sample_trade.id, + exit_price=165.0, + lessons_learned="Wait for confirmation before entry", + ) + + trade = memory.get_trade(sample_trade.id) + assert trade.lessons_learned == "Wait for confirmation before entry" + + def test_clear(self, memory, multiple_trades): + """Clear should remove all trades.""" + for trade in multiple_trades: + memory.record_trade(trade) + + count = memory.clear() + assert count == len(multiple_trades) + assert memory.count() == 0 + + +# ============================================================================= +# TradeHistoryMemory Query Tests +# ============================================================================= + +class TestTradeHistoryMemoryQueries: + """Query tests for TradeHistoryMemory.""" + + def test_get_open_trades(self, memory, multiple_trades): + """Get open trades should filter correctly.""" + for trade in multiple_trades: + memory.record_trade(trade) + + open_trades = memory.get_open_trades() + assert len(open_trades) == 1 # Only NVDA is open + assert open_trades[0].symbol == "NVDA" + + def test_get_closed_trades(self, memory, multiple_trades): + """Get closed trades should filter correctly.""" + for trade in multiple_trades: + memory.record_trade(trade) + + closed_trades = memory.get_closed_trades() + assert len(closed_trades) == 3 # AAPL, GOOGL, MSFT + + def test_get_trades_by_symbol(self, memory, multiple_trades): + """Get trades by symbol should filter correctly.""" + for trade in multiple_trades: + memory.record_trade(trade) + + aapl_trades = memory.get_trades_by_symbol("AAPL") + assert len(aapl_trades) == 1 + assert aapl_trades[0].symbol == "AAPL" + + def test_find_similar_trades(self, memory, multiple_trades): + """Find similar trades should return relevant trades.""" + for trade in multiple_trades: + memory.record_trade(trade) + + similar = memory.find_similar_trades( + query="tech stock momentum earnings", + top_k=3, + ) + assert len(similar) >= 1 + + def test_find_profitable_patterns(self, memory, multiple_trades): + """Find profitable patterns should return winners.""" + for trade in multiple_trades: + memory.record_trade(trade) + + profitable = memory.find_profitable_patterns( + query="tech stock buy", + min_return=0.05, + top_k=5, + ) + # AAPL had 10% return + assert len(profitable) >= 1 + for trade in profitable: + assert trade.returns >= 0.05 + + def test_find_losing_patterns(self, memory, multiple_trades): + """Find losing patterns should return losers.""" + for trade in multiple_trades: + memory.record_trade(trade) + + losers = memory.find_losing_patterns( + query="tech stock momentum", + max_return=-0.05, + top_k=5, + ) + # GOOGL had -7% return + assert len(losers) >= 1 + for trade in losers: + assert trade.returns <= -0.05 + + +# ============================================================================= +# TradeHistoryMemory Statistics Tests +# ============================================================================= + +class TestTradeHistoryMemoryStats: + """Statistics tests for TradeHistoryMemory.""" + + def test_empty_statistics(self, memory): + """Empty memory should have zero stats.""" + stats = memory.get_statistics() + assert stats["total_trades"] == 0 + assert stats["win_rate"] == 0.0 + + def test_statistics_with_trades(self, memory, multiple_trades): + """Statistics should reflect trade data.""" + for trade in multiple_trades: + memory.record_trade(trade) + + stats = memory.get_statistics() + + assert stats["total_trades"] == 4 + assert stats["open_trades"] == 1 + assert stats["closed_trades"] == 3 + assert stats["win_rate"] > 0 # At least AAPL was profitable + + def test_symbol_statistics(self, memory, multiple_trades): + """Symbol statistics should be calculated correctly.""" + for trade in multiple_trades: + memory.record_trade(trade) + + stats = memory.get_symbol_statistics("AAPL") + + assert stats["symbol"] == "AAPL" + assert stats["total_trades"] == 1 + assert stats["closed_trades"] == 1 + assert stats["win_rate"] == 1.0 # AAPL was profitable + + +# ============================================================================= +# TradeHistoryMemory Serialization Tests +# ============================================================================= + +class TestTradeHistoryMemorySerialization: + """Serialization tests for TradeHistoryMemory.""" + + def test_to_dict(self, memory, multiple_trades): + """To dict should serialize correctly.""" + for trade in multiple_trades: + memory.record_trade(trade) + + data = memory.to_dict() + + assert "trades" in data + assert "memory" in data + assert len(data["trades"]) == len(multiple_trades) + + def test_from_dict(self, memory, multiple_trades): + """From dict should deserialize correctly.""" + for trade in multiple_trades: + memory.record_trade(trade) + + data = memory.to_dict() + restored = TradeHistoryMemory.from_dict(data) + + assert restored.count() == len(multiple_trades) + + def test_roundtrip(self, memory, sample_trade): + """Roundtrip serialization should preserve data.""" + memory.record_trade(sample_trade) + memory.close_trade(sample_trade.id, exit_price=165.0) + + data = memory.to_dict() + restored = TradeHistoryMemory.from_dict(data) + + original = memory.get_trade(sample_trade.id) + restored_trade = restored.get_trade(sample_trade.id) + + assert restored_trade is not None + assert restored_trade.symbol == original.symbol + assert restored_trade.exit_price == original.exit_price + assert restored_trade.outcome == original.outcome + + +# ============================================================================= +# Integration Tests +# ============================================================================= + +class TestIntegration: + """Integration tests for trade history workflow.""" + + def test_full_trade_lifecycle(self, memory): + """Test complete trade lifecycle.""" + # 1. Create reasoning + reasoning = AgentReasoning( + fundamentals="Strong quarterly earnings", + technical="Golden cross on daily chart", + research_conclusion="Buy for momentum continuation", + ) + + # 2. Create market context + context = MarketContext( + vix=15.0, + economic_regime="EXPANSION", + ) + + # 3. Open trade + trade = TradeRecord.create( + symbol="MSFT", + direction=TradeDirection.LONG, + entry_price=400.0, + quantity=50, + signal_strength=SignalStrength.BUY, + confidence=0.75, + reasoning=reasoning, + market_context=context, + tags=["tech", "earnings"], + ) + memory.record_trade(trade) + + # 4. Verify open + assert memory.count() == 1 + assert len(memory.get_open_trades()) == 1 + + # 5. Close with profit + memory.close_trade( + trade.id, + exit_price=440.0, + lessons_learned="Good timing on earnings play", + ) + + # 6. Verify closed + assert len(memory.get_closed_trades()) == 1 + closed = memory.get_trade(trade.id) + assert closed.outcome == TradeOutcome.PROFITABLE + assert closed.returns == pytest.approx(0.10, rel=0.01) + + # 7. Find similar trades + similar = memory.find_similar_trades( + query="tech earnings momentum", + top_k=1, + ) + assert len(similar) == 1 + assert similar[0].symbol == "MSFT" + + # 8. Check statistics + stats = memory.get_statistics() + assert stats["win_rate"] == 1.0 + assert stats["avg_return"] > 0 + + def test_learning_from_history(self, memory): + """Test learning patterns from trade history.""" + # Record several trades + trades = [ + ("AAPL", "earnings beat", 150.0, 165.0), # +10% + ("GOOGL", "earnings beat", 140.0, 147.0), # +5% + ("MSFT", "earnings miss", 350.0, 315.0), # -10% + ("NVDA", "earnings beat", 500.0, 550.0), # +10% + ] + + for symbol, pattern, entry, exit_price in trades: + trade = TradeRecord.create( + symbol=symbol, + direction=TradeDirection.LONG, + entry_price=entry, + reasoning=AgentReasoning(research_conclusion=f"Trade on {pattern}"), + tags=["earnings"], + ) + memory.record_trade(trade) + memory.close_trade(trade.id, exit_price=exit_price) + + # Query for earnings beat patterns + profitable = memory.find_profitable_patterns( + query="earnings beat momentum", + min_return=0.05, + ) + + # Should find the profitable earnings beat trades + assert len(profitable) >= 2 + + # Query for losses to avoid + losers = memory.find_losing_patterns( + query="earnings trade", + max_return=-0.05, + ) + + # Should find the earnings miss trade + assert len(losers) >= 1 diff --git a/tradingagents/memory/__init__.py b/tradingagents/memory/__init__.py index 7bff2785..86243b2c 100644 --- a/tradingagents/memory/__init__.py +++ b/tradingagents/memory/__init__.py @@ -6,6 +6,7 @@ This module provides a layered memory system with three scoring dimensions: - Importance: Significance weighting for impactful events Issue #18: Layered memory - recency, relevancy, importance scoring +Issue #19: Trade history memory - outcomes, agent reasoning """ from .layered_memory import ( @@ -17,11 +18,30 @@ from .layered_memory import ( ImportanceLevel, ) +from .trade_history import ( + TradeHistoryMemory, + TradeRecord, + TradeOutcome, + TradeDirection, + SignalStrength, + AgentReasoning, + MarketContext, +) + __all__ = [ + # Layered Memory (Issue #18) "LayeredMemory", "MemoryEntry", "MemoryConfig", "ScoringWeights", "DecayFunction", "ImportanceLevel", + # Trade History (Issue #19) + "TradeHistoryMemory", + "TradeRecord", + "TradeOutcome", + "TradeDirection", + "SignalStrength", + "AgentReasoning", + "MarketContext", ] diff --git a/tradingagents/memory/trade_history.py b/tradingagents/memory/trade_history.py new file mode 100644 index 00000000..1d05b86d --- /dev/null +++ b/tradingagents/memory/trade_history.py @@ -0,0 +1,820 @@ +"""Trade History Memory for learning from past trade outcomes. + +This module provides specialized memory for tracking and learning from trades: +- Trade outcomes (profit/loss, returns) +- Agent reasoning and signals +- Entry/exit conditions +- Market context at time of trade + +Issue #19: [MEM-18] Trade history memory - outcomes, agent reasoning +""" + +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Dict, List, Optional, Any +import uuid + +from .layered_memory import ( + LayeredMemory, + MemoryEntry, + MemoryConfig, + ScoringWeights, + ImportanceLevel, +) + + +class TradeOutcome(Enum): + """Trade outcome categories.""" + PROFITABLE = "profitable" # Positive return + BREAK_EVEN = "break_even" # ~0% return + LOSS = "loss" # Negative return + STOPPED_OUT = "stopped_out" # Hit stop loss + TARGET_HIT = "target_hit" # Hit profit target + + +class TradeDirection(Enum): + """Trade direction.""" + LONG = "long" + SHORT = "short" + HOLD = "hold" + + +class SignalStrength(Enum): + """Signal strength levels.""" + STRONG_BUY = "strong_buy" + BUY = "buy" + NEUTRAL = "neutral" + SELL = "sell" + STRONG_SELL = "strong_sell" + + +@dataclass +class AgentReasoning: + """Captures reasoning from each agent in the trading workflow. + + Attributes: + fundamentals: Fundamentals analyst reasoning + technical: Technical/market analyst reasoning + news: News analyst reasoning + sentiment: Social media sentiment reasoning + momentum: Momentum analyst reasoning (if enabled) + macro: Macro analyst reasoning (if enabled) + correlation: Correlation analyst reasoning (if enabled) + bull_case: Bull researcher arguments + bear_case: Bear researcher arguments + research_conclusion: Research manager decision + risk_assessment: Risk manager assessment + final_signal: Final trading signal + """ + fundamentals: Optional[str] = None + technical: Optional[str] = None + news: Optional[str] = None + sentiment: Optional[str] = None + momentum: Optional[str] = None + macro: Optional[str] = None + correlation: Optional[str] = None + bull_case: Optional[str] = None + bear_case: Optional[str] = None + research_conclusion: Optional[str] = None + risk_assessment: Optional[str] = None + final_signal: Optional[str] = None + + def to_dict(self) -> Dict[str, Optional[str]]: + """Convert to dictionary.""" + return { + "fundamentals": self.fundamentals, + "technical": self.technical, + "news": self.news, + "sentiment": self.sentiment, + "momentum": self.momentum, + "macro": self.macro, + "correlation": self.correlation, + "bull_case": self.bull_case, + "bear_case": self.bear_case, + "research_conclusion": self.research_conclusion, + "risk_assessment": self.risk_assessment, + "final_signal": self.final_signal, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AgentReasoning": + """Create from dictionary.""" + return cls( + fundamentals=data.get("fundamentals"), + technical=data.get("technical"), + news=data.get("news"), + sentiment=data.get("sentiment"), + momentum=data.get("momentum"), + macro=data.get("macro"), + correlation=data.get("correlation"), + bull_case=data.get("bull_case"), + bear_case=data.get("bear_case"), + research_conclusion=data.get("research_conclusion"), + risk_assessment=data.get("risk_assessment"), + final_signal=data.get("final_signal"), + ) + + def summary(self) -> str: + """Generate a text summary of the reasoning.""" + parts = [] + + if self.fundamentals: + parts.append(f"Fundamentals: {self.fundamentals[:100]}...") + if self.technical: + parts.append(f"Technical: {self.technical[:100]}...") + if self.bull_case: + parts.append(f"Bull: {self.bull_case[:100]}...") + if self.bear_case: + parts.append(f"Bear: {self.bear_case[:100]}...") + if self.research_conclusion: + parts.append(f"Conclusion: {self.research_conclusion[:100]}...") + + return " | ".join(parts) if parts else "No reasoning recorded" + + +@dataclass +class MarketContext: + """Market conditions at time of trade. + + Attributes: + vix: VIX volatility index level + spy_return_1d: SPY 1-day return + sector_performance: Dict of sector returns + economic_regime: Detected economic regime + yield_curve_state: Yield curve status + macro_indicators: Key macro indicators + """ + vix: Optional[float] = None + spy_return_1d: Optional[float] = None + sector_performance: Dict[str, float] = field(default_factory=dict) + economic_regime: Optional[str] = None + yield_curve_state: Optional[str] = None + macro_indicators: Dict[str, float] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + "vix": self.vix, + "spy_return_1d": self.spy_return_1d, + "sector_performance": self.sector_performance, + "economic_regime": self.economic_regime, + "yield_curve_state": self.yield_curve_state, + "macro_indicators": self.macro_indicators, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MarketContext": + """Create from dictionary.""" + return cls( + vix=data.get("vix"), + spy_return_1d=data.get("spy_return_1d"), + sector_performance=data.get("sector_performance", {}), + economic_regime=data.get("economic_regime"), + yield_curve_state=data.get("yield_curve_state"), + macro_indicators=data.get("macro_indicators", {}), + ) + + def summary(self) -> str: + """Generate a text summary of market context.""" + parts = [] + + if self.vix is not None: + parts.append(f"VIX: {self.vix:.1f}") + if self.economic_regime: + parts.append(f"Regime: {self.economic_regime}") + if self.yield_curve_state: + parts.append(f"Yield Curve: {self.yield_curve_state}") + + return " | ".join(parts) if parts else "No market context" + + +@dataclass +class TradeRecord: + """Complete record of a trade including reasoning and outcome. + + Attributes: + id: Unique trade ID + symbol: Trading symbol + direction: Trade direction (long/short/hold) + entry_price: Entry price + exit_price: Exit price (if closed) + entry_time: Entry timestamp + exit_time: Exit timestamp (if closed) + quantity: Number of shares/contracts + returns: Percentage return + pnl: Dollar profit/loss + outcome: Trade outcome category + signal_strength: Original signal strength + confidence: Confidence score (0-1) + reasoning: Agent reasoning captured + market_context: Market conditions at entry + lessons_learned: Post-trade lessons (added later) + tags: Trade tags for filtering + """ + id: str + symbol: str + direction: TradeDirection + entry_price: float + entry_time: datetime + quantity: float = 1.0 + exit_price: Optional[float] = None + exit_time: Optional[datetime] = None + returns: Optional[float] = None + pnl: Optional[float] = None + outcome: Optional[TradeOutcome] = None + signal_strength: SignalStrength = SignalStrength.NEUTRAL + confidence: float = 0.5 + reasoning: AgentReasoning = field(default_factory=AgentReasoning) + market_context: MarketContext = field(default_factory=MarketContext) + lessons_learned: Optional[str] = None + tags: List[str] = field(default_factory=list) + + @classmethod + def create( + cls, + symbol: str, + direction: TradeDirection, + entry_price: float, + quantity: float = 1.0, + signal_strength: SignalStrength = SignalStrength.NEUTRAL, + confidence: float = 0.5, + reasoning: Optional[AgentReasoning] = None, + market_context: Optional[MarketContext] = None, + tags: Optional[List[str]] = None, + ) -> "TradeRecord": + """Create a new trade record.""" + return cls( + id=str(uuid.uuid4()), + symbol=symbol, + direction=direction, + entry_price=entry_price, + entry_time=datetime.now(), + quantity=quantity, + signal_strength=signal_strength, + confidence=confidence, + reasoning=reasoning or AgentReasoning(), + market_context=market_context or MarketContext(), + tags=tags or [], + ) + + def close( + self, + exit_price: float, + exit_time: Optional[datetime] = None, + ) -> "TradeRecord": + """Close the trade and calculate returns. + + Args: + exit_price: Exit price + exit_time: Exit timestamp (default: now) + + Returns: + Self with updated exit info + """ + self.exit_price = exit_price + self.exit_time = exit_time or datetime.now() + + # Calculate returns + if self.direction == TradeDirection.LONG: + self.returns = (exit_price - self.entry_price) / self.entry_price + elif self.direction == TradeDirection.SHORT: + self.returns = (self.entry_price - exit_price) / self.entry_price + else: + self.returns = 0.0 + + # Calculate PnL + self.pnl = self.returns * self.entry_price * self.quantity + + # Determine outcome + if self.returns > 0.005: # > 0.5% + self.outcome = TradeOutcome.PROFITABLE + elif self.returns < -0.005: # < -0.5% + self.outcome = TradeOutcome.LOSS + else: + self.outcome = TradeOutcome.BREAK_EVEN + + return self + + def is_open(self) -> bool: + """Check if trade is still open.""" + return self.exit_time is None + + def holding_period_days(self) -> Optional[float]: + """Calculate holding period in days.""" + if self.exit_time is None: + return None + delta = self.exit_time - self.entry_time + return delta.total_seconds() / 86400 + + def to_memory_content(self) -> str: + """Generate memory content for this trade.""" + parts = [ + f"Trade {self.direction.value} {self.symbol} at ${self.entry_price:.2f}", + ] + + if self.outcome: + parts.append(f"Outcome: {self.outcome.value}") + + if self.returns is not None: + parts.append(f"Return: {self.returns * 100:.2f}%") + + if self.reasoning.research_conclusion: + parts.append(f"Reasoning: {self.reasoning.research_conclusion}") + + if self.market_context.economic_regime: + parts.append(f"Regime: {self.market_context.economic_regime}") + + return " | ".join(parts) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + "id": self.id, + "symbol": self.symbol, + "direction": self.direction.value, + "entry_price": self.entry_price, + "exit_price": self.exit_price, + "entry_time": self.entry_time.isoformat(), + "exit_time": self.exit_time.isoformat() if self.exit_time else None, + "quantity": self.quantity, + "returns": self.returns, + "pnl": self.pnl, + "outcome": self.outcome.value if self.outcome else None, + "signal_strength": self.signal_strength.value, + "confidence": self.confidence, + "reasoning": self.reasoning.to_dict(), + "market_context": self.market_context.to_dict(), + "lessons_learned": self.lessons_learned, + "tags": self.tags, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TradeRecord": + """Create from dictionary.""" + return cls( + id=data["id"], + symbol=data["symbol"], + direction=TradeDirection(data["direction"]), + entry_price=data["entry_price"], + exit_price=data.get("exit_price"), + entry_time=datetime.fromisoformat(data["entry_time"]), + exit_time=datetime.fromisoformat(data["exit_time"]) if data.get("exit_time") else None, + quantity=data.get("quantity", 1.0), + returns=data.get("returns"), + pnl=data.get("pnl"), + outcome=TradeOutcome(data["outcome"]) if data.get("outcome") else None, + signal_strength=SignalStrength(data.get("signal_strength", "neutral")), + confidence=data.get("confidence", 0.5), + reasoning=AgentReasoning.from_dict(data.get("reasoning", {})), + market_context=MarketContext.from_dict(data.get("market_context", {})), + lessons_learned=data.get("lessons_learned"), + tags=data.get("tags", []), + ) + + +class TradeHistoryMemory: + """Memory system specialized for trade history and learning. + + This class combines trade record storage with the layered memory system + for intelligent retrieval of past trades based on similarity. + + Example: + >>> memory = TradeHistoryMemory() + >>> + >>> # Record a trade + >>> trade = TradeRecord.create( + ... symbol="AAPL", + ... direction=TradeDirection.LONG, + ... entry_price=150.0, + ... reasoning=AgentReasoning( + ... fundamentals="Strong earnings", + ... research_conclusion="Buy on earnings momentum", + ... ), + ... ) + >>> memory.record_trade(trade) + >>> + >>> # Close the trade + >>> memory.close_trade(trade.id, exit_price=165.0) + >>> + >>> # Find similar past trades + >>> similar = memory.find_similar_trades( + ... query="Apple earnings momentum", + ... top_k=5, + ... ) + """ + + def __init__( + self, + config: Optional[MemoryConfig] = None, + embedding_function=None, + ): + """Initialize trade history memory. + + Args: + config: Memory configuration + embedding_function: Optional embedding function for similarity + """ + # Default config with weights tuned for trade history + if config is None: + config = MemoryConfig( + weights=ScoringWeights( + recency=0.25, # Recent trades somewhat important + relevancy=0.45, # Similarity most important + importance=0.30, # Outcome importance matters + ), + ) + + self._layered_memory = LayeredMemory( + config=config, + embedding_function=embedding_function, + ) + self._trades: Dict[str, TradeRecord] = {} + + def record_trade(self, trade: TradeRecord) -> str: + """Record a new trade. + + Args: + trade: The trade record to store + + Returns: + Trade ID + """ + self._trades[trade.id] = trade + + # Create memory entry for the trade + importance = self._calculate_trade_importance(trade) + entry = MemoryEntry.create( + content=trade.to_memory_content(), + metadata={ + "trade_id": trade.id, + "symbol": trade.symbol, + "direction": trade.direction.value, + "outcome": trade.outcome.value if trade.outcome else None, + "returns": trade.returns, + "reasoning_summary": trade.reasoning.summary(), + }, + importance=importance, + tags=trade.tags + [trade.symbol, trade.direction.value], + timestamp=trade.entry_time, + ) + entry.id = trade.id # Use trade ID as memory ID + + self._layered_memory.add(entry) + return trade.id + + def close_trade( + self, + trade_id: str, + exit_price: float, + lessons_learned: Optional[str] = None, + ) -> Optional[TradeRecord]: + """Close an open trade and update memory. + + Args: + trade_id: ID of the trade to close + exit_price: Exit price + lessons_learned: Optional lessons from the trade + + Returns: + Updated trade record or None if not found + """ + trade = self._trades.get(trade_id) + if trade is None: + return None + + # Close the trade + trade.close(exit_price) + + if lessons_learned: + trade.lessons_learned = lessons_learned + + # Update memory with outcome + importance = self._calculate_trade_importance(trade) + self._layered_memory.update_importance(trade_id, importance) + + # Update memory entry content + entry = self._layered_memory.get(trade_id) + if entry: + entry.metadata["outcome"] = trade.outcome.value if trade.outcome else None + entry.metadata["returns"] = trade.returns + if lessons_learned: + entry.metadata["lessons_learned"] = lessons_learned + + return trade + + def get_trade(self, trade_id: str) -> Optional[TradeRecord]: + """Get a trade by ID. + + Args: + trade_id: Trade ID + + Returns: + Trade record or None + """ + return self._trades.get(trade_id) + + def get_open_trades(self) -> List[TradeRecord]: + """Get all open trades. + + Returns: + List of open trade records + """ + return [t for t in self._trades.values() if t.is_open()] + + def get_closed_trades(self) -> List[TradeRecord]: + """Get all closed trades. + + Returns: + List of closed trade records + """ + return [t for t in self._trades.values() if not t.is_open()] + + def get_trades_by_symbol(self, symbol: str) -> List[TradeRecord]: + """Get all trades for a symbol. + + Args: + symbol: Trading symbol + + Returns: + List of trade records for the symbol + """ + return [t for t in self._trades.values() if t.symbol == symbol] + + def find_similar_trades( + self, + query: str, + top_k: int = 5, + symbol: Optional[str] = None, + outcome: Optional[TradeOutcome] = None, + ) -> List[TradeRecord]: + """Find similar past trades. + + Args: + query: Query describing the current situation + top_k: Maximum number of results + symbol: Optional filter by symbol + outcome: Optional filter by outcome + + Returns: + List of similar trade records + """ + # Build tags filter + tags = [] + if symbol: + tags.append(symbol) + + # Retrieve from layered memory + results = self._layered_memory.retrieve( + query=query, + top_k=top_k * 2, # Get more to filter + tags=tags if tags else None, + ) + + # Convert to trade records and filter + trades = [] + for scored in results: + trade_id = scored.entry.metadata.get("trade_id") + if trade_id and trade_id in self._trades: + trade = self._trades[trade_id] + + # Apply outcome filter + if outcome and trade.outcome != outcome: + continue + + trades.append(trade) + + if len(trades) >= top_k: + break + + return trades + + def find_profitable_patterns( + self, + query: str, + min_return: float = 0.0, + top_k: int = 5, + ) -> List[TradeRecord]: + """Find profitable trades similar to the query. + + Args: + query: Query describing the current situation + min_return: Minimum return filter + top_k: Maximum number of results + + Returns: + List of profitable trade records + """ + results = self._layered_memory.retrieve(query=query, top_k=top_k * 3) + + trades = [] + for scored in results: + trade_id = scored.entry.metadata.get("trade_id") + if trade_id and trade_id in self._trades: + trade = self._trades[trade_id] + + if trade.returns is not None and trade.returns >= min_return: + trades.append(trade) + + if len(trades) >= top_k: + break + + return trades + + def find_losing_patterns( + self, + query: str, + max_return: float = 0.0, + top_k: int = 5, + ) -> List[TradeRecord]: + """Find losing trades similar to the query (for learning what to avoid). + + Args: + query: Query describing the current situation + max_return: Maximum return filter + top_k: Maximum number of results + + Returns: + List of losing trade records + """ + results = self._layered_memory.retrieve(query=query, top_k=top_k * 3) + + trades = [] + for scored in results: + trade_id = scored.entry.metadata.get("trade_id") + if trade_id and trade_id in self._trades: + trade = self._trades[trade_id] + + if trade.returns is not None and trade.returns <= max_return: + trades.append(trade) + + if len(trades) >= top_k: + break + + return trades + + def get_statistics(self) -> Dict[str, Any]: + """Get trade history statistics. + + Returns: + Dictionary with statistics + """ + closed_trades = self.get_closed_trades() + + if not closed_trades: + return { + "total_trades": len(self._trades), + "open_trades": len(self.get_open_trades()), + "closed_trades": 0, + "win_rate": 0.0, + "avg_return": 0.0, + "total_pnl": 0.0, + "best_trade": None, + "worst_trade": None, + } + + returns = [t.returns for t in closed_trades if t.returns is not None] + pnls = [t.pnl for t in closed_trades if t.pnl is not None] + winners = [t for t in closed_trades if t.outcome == TradeOutcome.PROFITABLE] + + # Find best and worst + best_trade = max(closed_trades, key=lambda t: t.returns or 0, default=None) + worst_trade = min(closed_trades, key=lambda t: t.returns or 0, default=None) + + # Outcome distribution + outcome_dist = {} + for trade in closed_trades: + if trade.outcome: + outcome_dist[trade.outcome.value] = outcome_dist.get(trade.outcome.value, 0) + 1 + + return { + "total_trades": len(self._trades), + "open_trades": len(self.get_open_trades()), + "closed_trades": len(closed_trades), + "win_rate": len(winners) / len(closed_trades) if closed_trades else 0.0, + "avg_return": sum(returns) / len(returns) if returns else 0.0, + "total_pnl": sum(pnls) if pnls else 0.0, + "best_trade": best_trade.to_dict() if best_trade else None, + "worst_trade": worst_trade.to_dict() if worst_trade else None, + "outcome_distribution": outcome_dist, + } + + def get_symbol_statistics(self, symbol: str) -> Dict[str, Any]: + """Get statistics for a specific symbol. + + Args: + symbol: Trading symbol + + Returns: + Dictionary with symbol-specific statistics + """ + symbol_trades = self.get_trades_by_symbol(symbol) + closed_trades = [t for t in symbol_trades if not t.is_open()] + + if not closed_trades: + return { + "symbol": symbol, + "total_trades": len(symbol_trades), + "open_trades": len([t for t in symbol_trades if t.is_open()]), + "closed_trades": 0, + "win_rate": 0.0, + "avg_return": 0.0, + } + + returns = [t.returns for t in closed_trades if t.returns is not None] + winners = [t for t in closed_trades if t.outcome == TradeOutcome.PROFITABLE] + + return { + "symbol": symbol, + "total_trades": len(symbol_trades), + "open_trades": len([t for t in symbol_trades if t.is_open()]), + "closed_trades": len(closed_trades), + "win_rate": len(winners) / len(closed_trades) if closed_trades else 0.0, + "avg_return": sum(returns) / len(returns) if returns else 0.0, + } + + def _calculate_trade_importance(self, trade: TradeRecord) -> float: + """Calculate importance score for a trade. + + Args: + trade: Trade record + + Returns: + Importance score [0, 1] + """ + # Base on returns if closed + if trade.returns is not None: + abs_return = abs(trade.returns) + if abs_return >= 0.10: # 10%+ + return ImportanceLevel.CRITICAL.value + elif abs_return >= 0.05: # 5%+ + return ImportanceLevel.HIGH.value + elif abs_return >= 0.01: # 1%+ + return ImportanceLevel.MEDIUM.value + else: + return ImportanceLevel.LOW.value + + # For open trades, base on confidence + if trade.confidence >= 0.8: + return ImportanceLevel.HIGH.value + elif trade.confidence >= 0.5: + return ImportanceLevel.MEDIUM.value + else: + return ImportanceLevel.LOW.value + + def count(self) -> int: + """Return total number of trades.""" + return len(self._trades) + + def clear(self) -> int: + """Clear all trades. + + Returns: + Number of trades cleared + """ + count = len(self._trades) + self._trades.clear() + self._layered_memory.clear() + return count + + def to_dict(self) -> Dict[str, Any]: + """Serialize to dictionary. + + Returns: + Dictionary representation + """ + return { + "trades": [t.to_dict() for t in self._trades.values()], + "memory": self._layered_memory.to_dict(), + } + + @classmethod + def from_dict( + cls, + data: Dict[str, Any], + embedding_function=None, + ) -> "TradeHistoryMemory": + """Create from dictionary. + + Args: + data: Dictionary representation + embedding_function: Optional embedding function + + Returns: + TradeHistoryMemory instance + """ + instance = cls(embedding_function=embedding_function) + + # Restore trades + for trade_data in data.get("trades", []): + trade = TradeRecord.from_dict(trade_data) + instance._trades[trade.id] = trade + + # Restore layered memory + if "memory" in data: + instance._layered_memory = LayeredMemory.from_dict( + data["memory"], + embedding_function=embedding_function, + ) + + return instance