"""Tests for Issue #19: Trade History Memory. This module tests the trade history memory for tracking: - Trade outcomes (profit/loss) - Agent reasoning - Market context - Pattern finding """ import pytest from datetime import datetime, timedelta from unittest.mock import MagicMock from tradingagents.memory.trade_history import ( TradeHistoryMemory, TradeRecord, TradeOutcome, TradeDirection, SignalStrength, AgentReasoning, MarketContext, ) # ============================================================================= # Test Fixtures # ============================================================================= @pytest.fixture def memory(): """Create a TradeHistoryMemory instance.""" return TradeHistoryMemory() @pytest.fixture def sample_reasoning(): """Create sample agent reasoning.""" return AgentReasoning( fundamentals="Strong earnings growth, P/E below industry average", technical="Breaking out above 50-day MA with volume", news="Positive analyst upgrades", sentiment="Bullish social media sentiment", bull_case="Undervalued with strong growth trajectory", bear_case="Macro headwinds could impact demand", research_conclusion="Buy on fundamental strength", final_signal="STRONG_BUY", ) @pytest.fixture def sample_market_context(): """Create sample market context.""" return MarketContext( vix=18.5, spy_return_1d=0.005, sector_performance={"XLK": 0.01, "XLF": -0.005}, economic_regime="EXPANSION", yield_curve_state="NORMAL", macro_indicators={"gdp_growth": 2.5, "unemployment": 4.0}, ) @pytest.fixture def sample_trade(sample_reasoning, sample_market_context): """Create a sample trade record.""" return TradeRecord.create( symbol="AAPL", direction=TradeDirection.LONG, entry_price=150.0, quantity=100, signal_strength=SignalStrength.STRONG_BUY, confidence=0.85, reasoning=sample_reasoning, market_context=sample_market_context, tags=["tech", "earnings"], ) @pytest.fixture def multiple_trades(): """Create multiple trade records for testing.""" now = datetime.now() trades = [] # Profitable AAPL trade trade1 = TradeRecord( id="trade-1", symbol="AAPL", direction=TradeDirection.LONG, entry_price=150.0, entry_time=now - timedelta(days=10), quantity=100, reasoning=AgentReasoning(research_conclusion="Buy on earnings"), tags=["tech", "earnings"], ) trade1.close(165.0) # Loss GOOGL trade trade2 = TradeRecord( id="trade-2", symbol="GOOGL", direction=TradeDirection.LONG, entry_price=140.0, entry_time=now - timedelta(days=5), quantity=50, reasoning=AgentReasoning(research_conclusion="Momentum buy"), tags=["tech"], ) trade2.close(130.0) # Break-even MSFT trade trade3 = TradeRecord( id="trade-3", symbol="MSFT", direction=TradeDirection.LONG, entry_price=350.0, entry_time=now - timedelta(days=2), quantity=20, reasoning=AgentReasoning(research_conclusion="Range trade"), tags=["tech"], ) trade3.close(351.0) # Open NVDA trade trade4 = TradeRecord( id="trade-4", symbol="NVDA", direction=TradeDirection.LONG, entry_price=500.0, entry_time=now - timedelta(hours=2), quantity=10, reasoning=AgentReasoning(research_conclusion="AI momentum"), tags=["tech", "ai"], ) trades = [trade1, trade2, trade3, trade4] return trades # ============================================================================= # TradeDirection Tests # ============================================================================= class TestTradeDirection: """Tests for TradeDirection enum.""" def test_long_value(self): """LONG should have correct value.""" assert TradeDirection.LONG.value == "long" def test_short_value(self): """SHORT should have correct value.""" assert TradeDirection.SHORT.value == "short" def test_hold_value(self): """HOLD should have correct value.""" assert TradeDirection.HOLD.value == "hold" # ============================================================================= # TradeOutcome Tests # ============================================================================= class TestTradeOutcome: """Tests for TradeOutcome enum.""" def test_profitable(self): """PROFITABLE should have correct value.""" assert TradeOutcome.PROFITABLE.value == "profitable" def test_loss(self): """LOSS should have correct value.""" assert TradeOutcome.LOSS.value == "loss" def test_break_even(self): """BREAK_EVEN should have correct value.""" assert TradeOutcome.BREAK_EVEN.value == "break_even" # ============================================================================= # SignalStrength Tests # ============================================================================= class TestSignalStrength: """Tests for SignalStrength enum.""" def test_signal_values(self): """All signal strengths should have correct values.""" assert SignalStrength.STRONG_BUY.value == "strong_buy" assert SignalStrength.BUY.value == "buy" assert SignalStrength.NEUTRAL.value == "neutral" assert SignalStrength.SELL.value == "sell" assert SignalStrength.STRONG_SELL.value == "strong_sell" # ============================================================================= # AgentReasoning Tests # ============================================================================= class TestAgentReasoning: """Tests for AgentReasoning class.""" def test_default_reasoning(self): """Default reasoning should have None values.""" reasoning = AgentReasoning() assert reasoning.fundamentals is None assert reasoning.technical is None assert reasoning.research_conclusion is None def test_reasoning_with_values(self, sample_reasoning): """Reasoning should store values correctly.""" assert sample_reasoning.fundamentals is not None assert sample_reasoning.research_conclusion == "Buy on fundamental strength" def test_to_dict(self, sample_reasoning): """To dict should serialize correctly.""" data = sample_reasoning.to_dict() assert data["fundamentals"] == sample_reasoning.fundamentals assert data["research_conclusion"] == sample_reasoning.research_conclusion def test_from_dict(self, sample_reasoning): """From dict should deserialize correctly.""" data = sample_reasoning.to_dict() restored = AgentReasoning.from_dict(data) assert restored.fundamentals == sample_reasoning.fundamentals assert restored.research_conclusion == sample_reasoning.research_conclusion def test_summary(self, sample_reasoning): """Summary should generate text.""" summary = sample_reasoning.summary() assert "Fundamentals" in summary assert "Conclusion" in summary def test_empty_summary(self): """Empty reasoning should have fallback summary.""" reasoning = AgentReasoning() summary = reasoning.summary() assert summary == "No reasoning recorded" # ============================================================================= # MarketContext Tests # ============================================================================= class TestMarketContext: """Tests for MarketContext class.""" def test_default_context(self): """Default context should have None/empty values.""" context = MarketContext() assert context.vix is None assert context.sector_performance == {} def test_context_with_values(self, sample_market_context): """Context should store values correctly.""" assert sample_market_context.vix == 18.5 assert sample_market_context.economic_regime == "EXPANSION" assert "XLK" in sample_market_context.sector_performance def test_to_dict(self, sample_market_context): """To dict should serialize correctly.""" data = sample_market_context.to_dict() assert data["vix"] == 18.5 assert data["economic_regime"] == "EXPANSION" def test_from_dict(self, sample_market_context): """From dict should deserialize correctly.""" data = sample_market_context.to_dict() restored = MarketContext.from_dict(data) assert restored.vix == sample_market_context.vix assert restored.economic_regime == sample_market_context.economic_regime def test_summary(self, sample_market_context): """Summary should generate text.""" summary = sample_market_context.summary() assert "VIX" in summary assert "Regime" in summary # ============================================================================= # TradeRecord Tests # ============================================================================= class TestTradeRecord: """Tests for TradeRecord class.""" def test_create_trade(self): """Create should generate a valid trade record.""" trade = TradeRecord.create( symbol="AAPL", direction=TradeDirection.LONG, entry_price=150.0, ) assert trade.symbol == "AAPL" assert trade.direction == TradeDirection.LONG assert trade.entry_price == 150.0 assert trade.id is not None assert trade.entry_time is not None assert trade.is_open() def test_create_with_all_args(self, sample_reasoning, sample_market_context): """Create with all arguments should work.""" trade = TradeRecord.create( symbol="AAPL", direction=TradeDirection.LONG, entry_price=150.0, quantity=100, signal_strength=SignalStrength.STRONG_BUY, confidence=0.85, reasoning=sample_reasoning, market_context=sample_market_context, tags=["tech"], ) assert trade.quantity == 100 assert trade.signal_strength == SignalStrength.STRONG_BUY assert trade.confidence == 0.85 assert trade.reasoning == sample_reasoning assert trade.market_context == sample_market_context def test_close_profitable(self, sample_trade): """Closing at higher price should be profitable.""" sample_trade.close(165.0) assert sample_trade.exit_price == 165.0 assert sample_trade.exit_time is not None assert sample_trade.returns is not None assert sample_trade.returns > 0 assert sample_trade.outcome == TradeOutcome.PROFITABLE assert not sample_trade.is_open() def test_close_loss(self, sample_trade): """Closing at lower price should be a loss.""" sample_trade.close(140.0) assert sample_trade.returns < 0 assert sample_trade.outcome == TradeOutcome.LOSS def test_close_break_even(self, sample_trade): """Closing at same price should be break even.""" sample_trade.close(150.2) # Within 0.5% threshold assert abs(sample_trade.returns) < 0.005 assert sample_trade.outcome == TradeOutcome.BREAK_EVEN def test_close_short_profitable(self): """Short trade profitable when price goes down.""" trade = TradeRecord.create( symbol="AAPL", direction=TradeDirection.SHORT, entry_price=150.0, ) trade.close(140.0) assert trade.returns > 0 assert trade.outcome == TradeOutcome.PROFITABLE def test_close_short_loss(self): """Short trade loss when price goes up.""" trade = TradeRecord.create( symbol="AAPL", direction=TradeDirection.SHORT, entry_price=150.0, ) trade.close(160.0) assert trade.returns < 0 assert trade.outcome == TradeOutcome.LOSS def test_pnl_calculation(self, sample_trade): """PnL should be calculated correctly.""" sample_trade.close(165.0) expected_return = (165 - 150) / 150 expected_pnl = expected_return * 150 * 100 assert abs(sample_trade.pnl - expected_pnl) < 0.01 def test_holding_period(self, sample_trade): """Holding period should be calculated correctly.""" assert sample_trade.holding_period_days() is None # Still open sample_trade.close(165.0) holding = sample_trade.holding_period_days() assert holding is not None assert holding >= 0 def test_to_memory_content(self, sample_trade): """Memory content should include key info.""" sample_trade.close(165.0) content = sample_trade.to_memory_content() assert "AAPL" in content assert "long" in content assert "profitable" in content def test_to_dict(self, sample_trade): """To dict should serialize correctly.""" data = sample_trade.to_dict() assert data["symbol"] == "AAPL" assert data["direction"] == "long" assert data["entry_price"] == 150.0 assert "reasoning" in data assert "market_context" in data def test_from_dict(self, sample_trade): """From dict should deserialize correctly.""" sample_trade.close(165.0) data = sample_trade.to_dict() restored = TradeRecord.from_dict(data) assert restored.symbol == sample_trade.symbol assert restored.direction == sample_trade.direction assert restored.entry_price == sample_trade.entry_price assert restored.exit_price == sample_trade.exit_price assert restored.outcome == sample_trade.outcome # ============================================================================= # TradeHistoryMemory Basic Tests # ============================================================================= class TestTradeHistoryMemoryBasic: """Basic tests for TradeHistoryMemory.""" def test_create_empty(self, memory): """Empty memory should have zero trades.""" assert memory.count() == 0 def test_record_trade(self, memory, sample_trade): """Recording a trade should increase count.""" trade_id = memory.record_trade(sample_trade) assert trade_id == sample_trade.id assert memory.count() == 1 def test_get_trade(self, memory, sample_trade): """Get should return recorded trade.""" memory.record_trade(sample_trade) retrieved = memory.get_trade(sample_trade.id) assert retrieved is not None assert retrieved.symbol == sample_trade.symbol def test_get_nonexistent(self, memory): """Get nonexistent trade should return None.""" result = memory.get_trade("nonexistent") assert result is None def test_close_trade(self, memory, sample_trade): """Closing a trade should update it.""" memory.record_trade(sample_trade) closed = memory.close_trade(sample_trade.id, exit_price=165.0) assert closed is not None assert closed.exit_price == 165.0 assert closed.outcome is not None def test_close_with_lessons(self, memory, sample_trade): """Closing with lessons should store them.""" memory.record_trade(sample_trade) memory.close_trade( sample_trade.id, exit_price=165.0, lessons_learned="Wait for confirmation before entry", ) trade = memory.get_trade(sample_trade.id) assert trade.lessons_learned == "Wait for confirmation before entry" def test_clear(self, memory, multiple_trades): """Clear should remove all trades.""" for trade in multiple_trades: memory.record_trade(trade) count = memory.clear() assert count == len(multiple_trades) assert memory.count() == 0 # ============================================================================= # TradeHistoryMemory Query Tests # ============================================================================= class TestTradeHistoryMemoryQueries: """Query tests for TradeHistoryMemory.""" def test_get_open_trades(self, memory, multiple_trades): """Get open trades should filter correctly.""" for trade in multiple_trades: memory.record_trade(trade) open_trades = memory.get_open_trades() assert len(open_trades) == 1 # Only NVDA is open assert open_trades[0].symbol == "NVDA" def test_get_closed_trades(self, memory, multiple_trades): """Get closed trades should filter correctly.""" for trade in multiple_trades: memory.record_trade(trade) closed_trades = memory.get_closed_trades() assert len(closed_trades) == 3 # AAPL, GOOGL, MSFT def test_get_trades_by_symbol(self, memory, multiple_trades): """Get trades by symbol should filter correctly.""" for trade in multiple_trades: memory.record_trade(trade) aapl_trades = memory.get_trades_by_symbol("AAPL") assert len(aapl_trades) == 1 assert aapl_trades[0].symbol == "AAPL" def test_find_similar_trades(self, memory, multiple_trades): """Find similar trades should return relevant trades.""" for trade in multiple_trades: memory.record_trade(trade) similar = memory.find_similar_trades( query="tech stock momentum earnings", top_k=3, ) assert len(similar) >= 1 def test_find_profitable_patterns(self, memory, multiple_trades): """Find profitable patterns should return winners.""" for trade in multiple_trades: memory.record_trade(trade) profitable = memory.find_profitable_patterns( query="tech stock buy", min_return=0.05, top_k=5, ) # AAPL had 10% return assert len(profitable) >= 1 for trade in profitable: assert trade.returns >= 0.05 def test_find_losing_patterns(self, memory, multiple_trades): """Find losing patterns should return losers.""" for trade in multiple_trades: memory.record_trade(trade) losers = memory.find_losing_patterns( query="tech stock momentum", max_return=-0.05, top_k=5, ) # GOOGL had -7% return assert len(losers) >= 1 for trade in losers: assert trade.returns <= -0.05 # ============================================================================= # TradeHistoryMemory Statistics Tests # ============================================================================= class TestTradeHistoryMemoryStats: """Statistics tests for TradeHistoryMemory.""" def test_empty_statistics(self, memory): """Empty memory should have zero stats.""" stats = memory.get_statistics() assert stats["total_trades"] == 0 assert stats["win_rate"] == 0.0 def test_statistics_with_trades(self, memory, multiple_trades): """Statistics should reflect trade data.""" for trade in multiple_trades: memory.record_trade(trade) stats = memory.get_statistics() assert stats["total_trades"] == 4 assert stats["open_trades"] == 1 assert stats["closed_trades"] == 3 assert stats["win_rate"] > 0 # At least AAPL was profitable def test_symbol_statistics(self, memory, multiple_trades): """Symbol statistics should be calculated correctly.""" for trade in multiple_trades: memory.record_trade(trade) stats = memory.get_symbol_statistics("AAPL") assert stats["symbol"] == "AAPL" assert stats["total_trades"] == 1 assert stats["closed_trades"] == 1 assert stats["win_rate"] == 1.0 # AAPL was profitable # ============================================================================= # TradeHistoryMemory Serialization Tests # ============================================================================= class TestTradeHistoryMemorySerialization: """Serialization tests for TradeHistoryMemory.""" def test_to_dict(self, memory, multiple_trades): """To dict should serialize correctly.""" for trade in multiple_trades: memory.record_trade(trade) data = memory.to_dict() assert "trades" in data assert "memory" in data assert len(data["trades"]) == len(multiple_trades) def test_from_dict(self, memory, multiple_trades): """From dict should deserialize correctly.""" for trade in multiple_trades: memory.record_trade(trade) data = memory.to_dict() restored = TradeHistoryMemory.from_dict(data) assert restored.count() == len(multiple_trades) def test_roundtrip(self, memory, sample_trade): """Roundtrip serialization should preserve data.""" memory.record_trade(sample_trade) memory.close_trade(sample_trade.id, exit_price=165.0) data = memory.to_dict() restored = TradeHistoryMemory.from_dict(data) original = memory.get_trade(sample_trade.id) restored_trade = restored.get_trade(sample_trade.id) assert restored_trade is not None assert restored_trade.symbol == original.symbol assert restored_trade.exit_price == original.exit_price assert restored_trade.outcome == original.outcome # ============================================================================= # Integration Tests # ============================================================================= class TestIntegration: """Integration tests for trade history workflow.""" def test_full_trade_lifecycle(self, memory): """Test complete trade lifecycle.""" # 1. Create reasoning reasoning = AgentReasoning( fundamentals="Strong quarterly earnings", technical="Golden cross on daily chart", research_conclusion="Buy for momentum continuation", ) # 2. Create market context context = MarketContext( vix=15.0, economic_regime="EXPANSION", ) # 3. Open trade trade = TradeRecord.create( symbol="MSFT", direction=TradeDirection.LONG, entry_price=400.0, quantity=50, signal_strength=SignalStrength.BUY, confidence=0.75, reasoning=reasoning, market_context=context, tags=["tech", "earnings"], ) memory.record_trade(trade) # 4. Verify open assert memory.count() == 1 assert len(memory.get_open_trades()) == 1 # 5. Close with profit memory.close_trade( trade.id, exit_price=440.0, lessons_learned="Good timing on earnings play", ) # 6. Verify closed assert len(memory.get_closed_trades()) == 1 closed = memory.get_trade(trade.id) assert closed.outcome == TradeOutcome.PROFITABLE assert closed.returns == pytest.approx(0.10, rel=0.01) # 7. Find similar trades similar = memory.find_similar_trades( query="tech earnings momentum", top_k=1, ) assert len(similar) == 1 assert similar[0].symbol == "MSFT" # 8. Check statistics stats = memory.get_statistics() assert stats["win_rate"] == 1.0 assert stats["avg_return"] > 0 def test_learning_from_history(self, memory): """Test learning patterns from trade history.""" # Record several trades trades = [ ("AAPL", "earnings beat", 150.0, 165.0), # +10% ("GOOGL", "earnings beat", 140.0, 147.0), # +5% ("MSFT", "earnings miss", 350.0, 315.0), # -10% ("NVDA", "earnings beat", 500.0, 550.0), # +10% ] for symbol, pattern, entry, exit_price in trades: trade = TradeRecord.create( symbol=symbol, direction=TradeDirection.LONG, entry_price=entry, reasoning=AgentReasoning(research_conclusion=f"Trade on {pattern}"), tags=["earnings"], ) memory.record_trade(trade) memory.close_trade(trade.id, exit_price=exit_price) # Query for earnings beat patterns profitable = memory.find_profitable_patterns( query="earnings beat momentum", min_return=0.05, ) # Should find the profitable earnings beat trades assert len(profitable) >= 2 # Query for losses to avoid losers = memory.find_losing_patterns( query="earnings trade", max_return=-0.05, ) # Should find the earnings miss trade assert len(losers) >= 1