feat(memory): add trade history memory with outcome tracking - Fixes #19
This commit is contained in:
parent
d72c214d4d
commit
dbfcea3740
|
|
@ -0,0 +1,740 @@
|
|||
"""Tests for Issue #19: Trade History Memory.
|
||||
|
||||
This module tests the trade history memory for tracking:
|
||||
- Trade outcomes (profit/loss)
|
||||
- Agent reasoning
|
||||
- Market context
|
||||
- Pattern finding
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from tradingagents.memory.trade_history import (
|
||||
TradeHistoryMemory,
|
||||
TradeRecord,
|
||||
TradeOutcome,
|
||||
TradeDirection,
|
||||
SignalStrength,
|
||||
AgentReasoning,
|
||||
MarketContext,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Fixtures
|
||||
# =============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def memory():
|
||||
"""Create a TradeHistoryMemory instance."""
|
||||
return TradeHistoryMemory()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_reasoning():
|
||||
"""Create sample agent reasoning."""
|
||||
return AgentReasoning(
|
||||
fundamentals="Strong earnings growth, P/E below industry average",
|
||||
technical="Breaking out above 50-day MA with volume",
|
||||
news="Positive analyst upgrades",
|
||||
sentiment="Bullish social media sentiment",
|
||||
bull_case="Undervalued with strong growth trajectory",
|
||||
bear_case="Macro headwinds could impact demand",
|
||||
research_conclusion="Buy on fundamental strength",
|
||||
final_signal="STRONG_BUY",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_market_context():
|
||||
"""Create sample market context."""
|
||||
return MarketContext(
|
||||
vix=18.5,
|
||||
spy_return_1d=0.005,
|
||||
sector_performance={"XLK": 0.01, "XLF": -0.005},
|
||||
economic_regime="EXPANSION",
|
||||
yield_curve_state="NORMAL",
|
||||
macro_indicators={"gdp_growth": 2.5, "unemployment": 4.0},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_trade(sample_reasoning, sample_market_context):
|
||||
"""Create a sample trade record."""
|
||||
return TradeRecord.create(
|
||||
symbol="AAPL",
|
||||
direction=TradeDirection.LONG,
|
||||
entry_price=150.0,
|
||||
quantity=100,
|
||||
signal_strength=SignalStrength.STRONG_BUY,
|
||||
confidence=0.85,
|
||||
reasoning=sample_reasoning,
|
||||
market_context=sample_market_context,
|
||||
tags=["tech", "earnings"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multiple_trades():
|
||||
"""Create multiple trade records for testing."""
|
||||
now = datetime.now()
|
||||
trades = []
|
||||
|
||||
# Profitable AAPL trade
|
||||
trade1 = TradeRecord(
|
||||
id="trade-1",
|
||||
symbol="AAPL",
|
||||
direction=TradeDirection.LONG,
|
||||
entry_price=150.0,
|
||||
entry_time=now - timedelta(days=10),
|
||||
quantity=100,
|
||||
reasoning=AgentReasoning(research_conclusion="Buy on earnings"),
|
||||
tags=["tech", "earnings"],
|
||||
)
|
||||
trade1.close(165.0)
|
||||
|
||||
# Loss GOOGL trade
|
||||
trade2 = TradeRecord(
|
||||
id="trade-2",
|
||||
symbol="GOOGL",
|
||||
direction=TradeDirection.LONG,
|
||||
entry_price=140.0,
|
||||
entry_time=now - timedelta(days=5),
|
||||
quantity=50,
|
||||
reasoning=AgentReasoning(research_conclusion="Momentum buy"),
|
||||
tags=["tech"],
|
||||
)
|
||||
trade2.close(130.0)
|
||||
|
||||
# Break-even MSFT trade
|
||||
trade3 = TradeRecord(
|
||||
id="trade-3",
|
||||
symbol="MSFT",
|
||||
direction=TradeDirection.LONG,
|
||||
entry_price=350.0,
|
||||
entry_time=now - timedelta(days=2),
|
||||
quantity=20,
|
||||
reasoning=AgentReasoning(research_conclusion="Range trade"),
|
||||
tags=["tech"],
|
||||
)
|
||||
trade3.close(351.0)
|
||||
|
||||
# Open NVDA trade
|
||||
trade4 = TradeRecord(
|
||||
id="trade-4",
|
||||
symbol="NVDA",
|
||||
direction=TradeDirection.LONG,
|
||||
entry_price=500.0,
|
||||
entry_time=now - timedelta(hours=2),
|
||||
quantity=10,
|
||||
reasoning=AgentReasoning(research_conclusion="AI momentum"),
|
||||
tags=["tech", "ai"],
|
||||
)
|
||||
|
||||
trades = [trade1, trade2, trade3, trade4]
|
||||
return trades
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TradeDirection Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestTradeDirection:
|
||||
"""Tests for TradeDirection enum."""
|
||||
|
||||
def test_long_value(self):
|
||||
"""LONG should have correct value."""
|
||||
assert TradeDirection.LONG.value == "long"
|
||||
|
||||
def test_short_value(self):
|
||||
"""SHORT should have correct value."""
|
||||
assert TradeDirection.SHORT.value == "short"
|
||||
|
||||
def test_hold_value(self):
|
||||
"""HOLD should have correct value."""
|
||||
assert TradeDirection.HOLD.value == "hold"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TradeOutcome Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestTradeOutcome:
|
||||
"""Tests for TradeOutcome enum."""
|
||||
|
||||
def test_profitable(self):
|
||||
"""PROFITABLE should have correct value."""
|
||||
assert TradeOutcome.PROFITABLE.value == "profitable"
|
||||
|
||||
def test_loss(self):
|
||||
"""LOSS should have correct value."""
|
||||
assert TradeOutcome.LOSS.value == "loss"
|
||||
|
||||
def test_break_even(self):
|
||||
"""BREAK_EVEN should have correct value."""
|
||||
assert TradeOutcome.BREAK_EVEN.value == "break_even"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SignalStrength Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestSignalStrength:
|
||||
"""Tests for SignalStrength enum."""
|
||||
|
||||
def test_signal_values(self):
|
||||
"""All signal strengths should have correct values."""
|
||||
assert SignalStrength.STRONG_BUY.value == "strong_buy"
|
||||
assert SignalStrength.BUY.value == "buy"
|
||||
assert SignalStrength.NEUTRAL.value == "neutral"
|
||||
assert SignalStrength.SELL.value == "sell"
|
||||
assert SignalStrength.STRONG_SELL.value == "strong_sell"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# AgentReasoning Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestAgentReasoning:
|
||||
"""Tests for AgentReasoning class."""
|
||||
|
||||
def test_default_reasoning(self):
|
||||
"""Default reasoning should have None values."""
|
||||
reasoning = AgentReasoning()
|
||||
assert reasoning.fundamentals is None
|
||||
assert reasoning.technical is None
|
||||
assert reasoning.research_conclusion is None
|
||||
|
||||
def test_reasoning_with_values(self, sample_reasoning):
|
||||
"""Reasoning should store values correctly."""
|
||||
assert sample_reasoning.fundamentals is not None
|
||||
assert sample_reasoning.research_conclusion == "Buy on fundamental strength"
|
||||
|
||||
def test_to_dict(self, sample_reasoning):
|
||||
"""To dict should serialize correctly."""
|
||||
data = sample_reasoning.to_dict()
|
||||
assert data["fundamentals"] == sample_reasoning.fundamentals
|
||||
assert data["research_conclusion"] == sample_reasoning.research_conclusion
|
||||
|
||||
def test_from_dict(self, sample_reasoning):
|
||||
"""From dict should deserialize correctly."""
|
||||
data = sample_reasoning.to_dict()
|
||||
restored = AgentReasoning.from_dict(data)
|
||||
assert restored.fundamentals == sample_reasoning.fundamentals
|
||||
assert restored.research_conclusion == sample_reasoning.research_conclusion
|
||||
|
||||
def test_summary(self, sample_reasoning):
|
||||
"""Summary should generate text."""
|
||||
summary = sample_reasoning.summary()
|
||||
assert "Fundamentals" in summary
|
||||
assert "Conclusion" in summary
|
||||
|
||||
def test_empty_summary(self):
|
||||
"""Empty reasoning should have fallback summary."""
|
||||
reasoning = AgentReasoning()
|
||||
summary = reasoning.summary()
|
||||
assert summary == "No reasoning recorded"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MarketContext Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestMarketContext:
|
||||
"""Tests for MarketContext class."""
|
||||
|
||||
def test_default_context(self):
|
||||
"""Default context should have None/empty values."""
|
||||
context = MarketContext()
|
||||
assert context.vix is None
|
||||
assert context.sector_performance == {}
|
||||
|
||||
def test_context_with_values(self, sample_market_context):
|
||||
"""Context should store values correctly."""
|
||||
assert sample_market_context.vix == 18.5
|
||||
assert sample_market_context.economic_regime == "EXPANSION"
|
||||
assert "XLK" in sample_market_context.sector_performance
|
||||
|
||||
def test_to_dict(self, sample_market_context):
|
||||
"""To dict should serialize correctly."""
|
||||
data = sample_market_context.to_dict()
|
||||
assert data["vix"] == 18.5
|
||||
assert data["economic_regime"] == "EXPANSION"
|
||||
|
||||
def test_from_dict(self, sample_market_context):
|
||||
"""From dict should deserialize correctly."""
|
||||
data = sample_market_context.to_dict()
|
||||
restored = MarketContext.from_dict(data)
|
||||
assert restored.vix == sample_market_context.vix
|
||||
assert restored.economic_regime == sample_market_context.economic_regime
|
||||
|
||||
def test_summary(self, sample_market_context):
|
||||
"""Summary should generate text."""
|
||||
summary = sample_market_context.summary()
|
||||
assert "VIX" in summary
|
||||
assert "Regime" in summary
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TradeRecord Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestTradeRecord:
|
||||
"""Tests for TradeRecord class."""
|
||||
|
||||
def test_create_trade(self):
|
||||
"""Create should generate a valid trade record."""
|
||||
trade = TradeRecord.create(
|
||||
symbol="AAPL",
|
||||
direction=TradeDirection.LONG,
|
||||
entry_price=150.0,
|
||||
)
|
||||
assert trade.symbol == "AAPL"
|
||||
assert trade.direction == TradeDirection.LONG
|
||||
assert trade.entry_price == 150.0
|
||||
assert trade.id is not None
|
||||
assert trade.entry_time is not None
|
||||
assert trade.is_open()
|
||||
|
||||
def test_create_with_all_args(self, sample_reasoning, sample_market_context):
|
||||
"""Create with all arguments should work."""
|
||||
trade = TradeRecord.create(
|
||||
symbol="AAPL",
|
||||
direction=TradeDirection.LONG,
|
||||
entry_price=150.0,
|
||||
quantity=100,
|
||||
signal_strength=SignalStrength.STRONG_BUY,
|
||||
confidence=0.85,
|
||||
reasoning=sample_reasoning,
|
||||
market_context=sample_market_context,
|
||||
tags=["tech"],
|
||||
)
|
||||
assert trade.quantity == 100
|
||||
assert trade.signal_strength == SignalStrength.STRONG_BUY
|
||||
assert trade.confidence == 0.85
|
||||
assert trade.reasoning == sample_reasoning
|
||||
assert trade.market_context == sample_market_context
|
||||
|
||||
def test_close_profitable(self, sample_trade):
|
||||
"""Closing at higher price should be profitable."""
|
||||
sample_trade.close(165.0)
|
||||
|
||||
assert sample_trade.exit_price == 165.0
|
||||
assert sample_trade.exit_time is not None
|
||||
assert sample_trade.returns is not None
|
||||
assert sample_trade.returns > 0
|
||||
assert sample_trade.outcome == TradeOutcome.PROFITABLE
|
||||
assert not sample_trade.is_open()
|
||||
|
||||
def test_close_loss(self, sample_trade):
|
||||
"""Closing at lower price should be a loss."""
|
||||
sample_trade.close(140.0)
|
||||
|
||||
assert sample_trade.returns < 0
|
||||
assert sample_trade.outcome == TradeOutcome.LOSS
|
||||
|
||||
def test_close_break_even(self, sample_trade):
|
||||
"""Closing at same price should be break even."""
|
||||
sample_trade.close(150.2) # Within 0.5% threshold
|
||||
|
||||
assert abs(sample_trade.returns) < 0.005
|
||||
assert sample_trade.outcome == TradeOutcome.BREAK_EVEN
|
||||
|
||||
def test_close_short_profitable(self):
|
||||
"""Short trade profitable when price goes down."""
|
||||
trade = TradeRecord.create(
|
||||
symbol="AAPL",
|
||||
direction=TradeDirection.SHORT,
|
||||
entry_price=150.0,
|
||||
)
|
||||
trade.close(140.0)
|
||||
|
||||
assert trade.returns > 0
|
||||
assert trade.outcome == TradeOutcome.PROFITABLE
|
||||
|
||||
def test_close_short_loss(self):
|
||||
"""Short trade loss when price goes up."""
|
||||
trade = TradeRecord.create(
|
||||
symbol="AAPL",
|
||||
direction=TradeDirection.SHORT,
|
||||
entry_price=150.0,
|
||||
)
|
||||
trade.close(160.0)
|
||||
|
||||
assert trade.returns < 0
|
||||
assert trade.outcome == TradeOutcome.LOSS
|
||||
|
||||
def test_pnl_calculation(self, sample_trade):
|
||||
"""PnL should be calculated correctly."""
|
||||
sample_trade.close(165.0)
|
||||
|
||||
expected_return = (165 - 150) / 150
|
||||
expected_pnl = expected_return * 150 * 100
|
||||
|
||||
assert abs(sample_trade.pnl - expected_pnl) < 0.01
|
||||
|
||||
def test_holding_period(self, sample_trade):
|
||||
"""Holding period should be calculated correctly."""
|
||||
assert sample_trade.holding_period_days() is None # Still open
|
||||
|
||||
sample_trade.close(165.0)
|
||||
holding = sample_trade.holding_period_days()
|
||||
assert holding is not None
|
||||
assert holding >= 0
|
||||
|
||||
def test_to_memory_content(self, sample_trade):
|
||||
"""Memory content should include key info."""
|
||||
sample_trade.close(165.0)
|
||||
content = sample_trade.to_memory_content()
|
||||
|
||||
assert "AAPL" in content
|
||||
assert "long" in content
|
||||
assert "profitable" in content
|
||||
|
||||
def test_to_dict(self, sample_trade):
|
||||
"""To dict should serialize correctly."""
|
||||
data = sample_trade.to_dict()
|
||||
|
||||
assert data["symbol"] == "AAPL"
|
||||
assert data["direction"] == "long"
|
||||
assert data["entry_price"] == 150.0
|
||||
assert "reasoning" in data
|
||||
assert "market_context" in data
|
||||
|
||||
def test_from_dict(self, sample_trade):
|
||||
"""From dict should deserialize correctly."""
|
||||
sample_trade.close(165.0)
|
||||
data = sample_trade.to_dict()
|
||||
restored = TradeRecord.from_dict(data)
|
||||
|
||||
assert restored.symbol == sample_trade.symbol
|
||||
assert restored.direction == sample_trade.direction
|
||||
assert restored.entry_price == sample_trade.entry_price
|
||||
assert restored.exit_price == sample_trade.exit_price
|
||||
assert restored.outcome == sample_trade.outcome
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TradeHistoryMemory Basic Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestTradeHistoryMemoryBasic:
|
||||
"""Basic tests for TradeHistoryMemory."""
|
||||
|
||||
def test_create_empty(self, memory):
|
||||
"""Empty memory should have zero trades."""
|
||||
assert memory.count() == 0
|
||||
|
||||
def test_record_trade(self, memory, sample_trade):
|
||||
"""Recording a trade should increase count."""
|
||||
trade_id = memory.record_trade(sample_trade)
|
||||
assert trade_id == sample_trade.id
|
||||
assert memory.count() == 1
|
||||
|
||||
def test_get_trade(self, memory, sample_trade):
|
||||
"""Get should return recorded trade."""
|
||||
memory.record_trade(sample_trade)
|
||||
retrieved = memory.get_trade(sample_trade.id)
|
||||
assert retrieved is not None
|
||||
assert retrieved.symbol == sample_trade.symbol
|
||||
|
||||
def test_get_nonexistent(self, memory):
|
||||
"""Get nonexistent trade should return None."""
|
||||
result = memory.get_trade("nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_close_trade(self, memory, sample_trade):
|
||||
"""Closing a trade should update it."""
|
||||
memory.record_trade(sample_trade)
|
||||
closed = memory.close_trade(sample_trade.id, exit_price=165.0)
|
||||
|
||||
assert closed is not None
|
||||
assert closed.exit_price == 165.0
|
||||
assert closed.outcome is not None
|
||||
|
||||
def test_close_with_lessons(self, memory, sample_trade):
|
||||
"""Closing with lessons should store them."""
|
||||
memory.record_trade(sample_trade)
|
||||
memory.close_trade(
|
||||
sample_trade.id,
|
||||
exit_price=165.0,
|
||||
lessons_learned="Wait for confirmation before entry",
|
||||
)
|
||||
|
||||
trade = memory.get_trade(sample_trade.id)
|
||||
assert trade.lessons_learned == "Wait for confirmation before entry"
|
||||
|
||||
def test_clear(self, memory, multiple_trades):
|
||||
"""Clear should remove all trades."""
|
||||
for trade in multiple_trades:
|
||||
memory.record_trade(trade)
|
||||
|
||||
count = memory.clear()
|
||||
assert count == len(multiple_trades)
|
||||
assert memory.count() == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TradeHistoryMemory Query Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestTradeHistoryMemoryQueries:
|
||||
"""Query tests for TradeHistoryMemory."""
|
||||
|
||||
def test_get_open_trades(self, memory, multiple_trades):
|
||||
"""Get open trades should filter correctly."""
|
||||
for trade in multiple_trades:
|
||||
memory.record_trade(trade)
|
||||
|
||||
open_trades = memory.get_open_trades()
|
||||
assert len(open_trades) == 1 # Only NVDA is open
|
||||
assert open_trades[0].symbol == "NVDA"
|
||||
|
||||
def test_get_closed_trades(self, memory, multiple_trades):
|
||||
"""Get closed trades should filter correctly."""
|
||||
for trade in multiple_trades:
|
||||
memory.record_trade(trade)
|
||||
|
||||
closed_trades = memory.get_closed_trades()
|
||||
assert len(closed_trades) == 3 # AAPL, GOOGL, MSFT
|
||||
|
||||
def test_get_trades_by_symbol(self, memory, multiple_trades):
|
||||
"""Get trades by symbol should filter correctly."""
|
||||
for trade in multiple_trades:
|
||||
memory.record_trade(trade)
|
||||
|
||||
aapl_trades = memory.get_trades_by_symbol("AAPL")
|
||||
assert len(aapl_trades) == 1
|
||||
assert aapl_trades[0].symbol == "AAPL"
|
||||
|
||||
def test_find_similar_trades(self, memory, multiple_trades):
|
||||
"""Find similar trades should return relevant trades."""
|
||||
for trade in multiple_trades:
|
||||
memory.record_trade(trade)
|
||||
|
||||
similar = memory.find_similar_trades(
|
||||
query="tech stock momentum earnings",
|
||||
top_k=3,
|
||||
)
|
||||
assert len(similar) >= 1
|
||||
|
||||
def test_find_profitable_patterns(self, memory, multiple_trades):
|
||||
"""Find profitable patterns should return winners."""
|
||||
for trade in multiple_trades:
|
||||
memory.record_trade(trade)
|
||||
|
||||
profitable = memory.find_profitable_patterns(
|
||||
query="tech stock buy",
|
||||
min_return=0.05,
|
||||
top_k=5,
|
||||
)
|
||||
# AAPL had 10% return
|
||||
assert len(profitable) >= 1
|
||||
for trade in profitable:
|
||||
assert trade.returns >= 0.05
|
||||
|
||||
def test_find_losing_patterns(self, memory, multiple_trades):
|
||||
"""Find losing patterns should return losers."""
|
||||
for trade in multiple_trades:
|
||||
memory.record_trade(trade)
|
||||
|
||||
losers = memory.find_losing_patterns(
|
||||
query="tech stock momentum",
|
||||
max_return=-0.05,
|
||||
top_k=5,
|
||||
)
|
||||
# GOOGL had -7% return
|
||||
assert len(losers) >= 1
|
||||
for trade in losers:
|
||||
assert trade.returns <= -0.05
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TradeHistoryMemory Statistics Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestTradeHistoryMemoryStats:
|
||||
"""Statistics tests for TradeHistoryMemory."""
|
||||
|
||||
def test_empty_statistics(self, memory):
|
||||
"""Empty memory should have zero stats."""
|
||||
stats = memory.get_statistics()
|
||||
assert stats["total_trades"] == 0
|
||||
assert stats["win_rate"] == 0.0
|
||||
|
||||
def test_statistics_with_trades(self, memory, multiple_trades):
|
||||
"""Statistics should reflect trade data."""
|
||||
for trade in multiple_trades:
|
||||
memory.record_trade(trade)
|
||||
|
||||
stats = memory.get_statistics()
|
||||
|
||||
assert stats["total_trades"] == 4
|
||||
assert stats["open_trades"] == 1
|
||||
assert stats["closed_trades"] == 3
|
||||
assert stats["win_rate"] > 0 # At least AAPL was profitable
|
||||
|
||||
def test_symbol_statistics(self, memory, multiple_trades):
|
||||
"""Symbol statistics should be calculated correctly."""
|
||||
for trade in multiple_trades:
|
||||
memory.record_trade(trade)
|
||||
|
||||
stats = memory.get_symbol_statistics("AAPL")
|
||||
|
||||
assert stats["symbol"] == "AAPL"
|
||||
assert stats["total_trades"] == 1
|
||||
assert stats["closed_trades"] == 1
|
||||
assert stats["win_rate"] == 1.0 # AAPL was profitable
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TradeHistoryMemory Serialization Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestTradeHistoryMemorySerialization:
|
||||
"""Serialization tests for TradeHistoryMemory."""
|
||||
|
||||
def test_to_dict(self, memory, multiple_trades):
|
||||
"""To dict should serialize correctly."""
|
||||
for trade in multiple_trades:
|
||||
memory.record_trade(trade)
|
||||
|
||||
data = memory.to_dict()
|
||||
|
||||
assert "trades" in data
|
||||
assert "memory" in data
|
||||
assert len(data["trades"]) == len(multiple_trades)
|
||||
|
||||
def test_from_dict(self, memory, multiple_trades):
|
||||
"""From dict should deserialize correctly."""
|
||||
for trade in multiple_trades:
|
||||
memory.record_trade(trade)
|
||||
|
||||
data = memory.to_dict()
|
||||
restored = TradeHistoryMemory.from_dict(data)
|
||||
|
||||
assert restored.count() == len(multiple_trades)
|
||||
|
||||
def test_roundtrip(self, memory, sample_trade):
|
||||
"""Roundtrip serialization should preserve data."""
|
||||
memory.record_trade(sample_trade)
|
||||
memory.close_trade(sample_trade.id, exit_price=165.0)
|
||||
|
||||
data = memory.to_dict()
|
||||
restored = TradeHistoryMemory.from_dict(data)
|
||||
|
||||
original = memory.get_trade(sample_trade.id)
|
||||
restored_trade = restored.get_trade(sample_trade.id)
|
||||
|
||||
assert restored_trade is not None
|
||||
assert restored_trade.symbol == original.symbol
|
||||
assert restored_trade.exit_price == original.exit_price
|
||||
assert restored_trade.outcome == original.outcome
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests for trade history workflow."""
|
||||
|
||||
def test_full_trade_lifecycle(self, memory):
|
||||
"""Test complete trade lifecycle."""
|
||||
# 1. Create reasoning
|
||||
reasoning = AgentReasoning(
|
||||
fundamentals="Strong quarterly earnings",
|
||||
technical="Golden cross on daily chart",
|
||||
research_conclusion="Buy for momentum continuation",
|
||||
)
|
||||
|
||||
# 2. Create market context
|
||||
context = MarketContext(
|
||||
vix=15.0,
|
||||
economic_regime="EXPANSION",
|
||||
)
|
||||
|
||||
# 3. Open trade
|
||||
trade = TradeRecord.create(
|
||||
symbol="MSFT",
|
||||
direction=TradeDirection.LONG,
|
||||
entry_price=400.0,
|
||||
quantity=50,
|
||||
signal_strength=SignalStrength.BUY,
|
||||
confidence=0.75,
|
||||
reasoning=reasoning,
|
||||
market_context=context,
|
||||
tags=["tech", "earnings"],
|
||||
)
|
||||
memory.record_trade(trade)
|
||||
|
||||
# 4. Verify open
|
||||
assert memory.count() == 1
|
||||
assert len(memory.get_open_trades()) == 1
|
||||
|
||||
# 5. Close with profit
|
||||
memory.close_trade(
|
||||
trade.id,
|
||||
exit_price=440.0,
|
||||
lessons_learned="Good timing on earnings play",
|
||||
)
|
||||
|
||||
# 6. Verify closed
|
||||
assert len(memory.get_closed_trades()) == 1
|
||||
closed = memory.get_trade(trade.id)
|
||||
assert closed.outcome == TradeOutcome.PROFITABLE
|
||||
assert closed.returns == pytest.approx(0.10, rel=0.01)
|
||||
|
||||
# 7. Find similar trades
|
||||
similar = memory.find_similar_trades(
|
||||
query="tech earnings momentum",
|
||||
top_k=1,
|
||||
)
|
||||
assert len(similar) == 1
|
||||
assert similar[0].symbol == "MSFT"
|
||||
|
||||
# 8. Check statistics
|
||||
stats = memory.get_statistics()
|
||||
assert stats["win_rate"] == 1.0
|
||||
assert stats["avg_return"] > 0
|
||||
|
||||
def test_learning_from_history(self, memory):
|
||||
"""Test learning patterns from trade history."""
|
||||
# Record several trades
|
||||
trades = [
|
||||
("AAPL", "earnings beat", 150.0, 165.0), # +10%
|
||||
("GOOGL", "earnings beat", 140.0, 147.0), # +5%
|
||||
("MSFT", "earnings miss", 350.0, 315.0), # -10%
|
||||
("NVDA", "earnings beat", 500.0, 550.0), # +10%
|
||||
]
|
||||
|
||||
for symbol, pattern, entry, exit_price in trades:
|
||||
trade = TradeRecord.create(
|
||||
symbol=symbol,
|
||||
direction=TradeDirection.LONG,
|
||||
entry_price=entry,
|
||||
reasoning=AgentReasoning(research_conclusion=f"Trade on {pattern}"),
|
||||
tags=["earnings"],
|
||||
)
|
||||
memory.record_trade(trade)
|
||||
memory.close_trade(trade.id, exit_price=exit_price)
|
||||
|
||||
# Query for earnings beat patterns
|
||||
profitable = memory.find_profitable_patterns(
|
||||
query="earnings beat momentum",
|
||||
min_return=0.05,
|
||||
)
|
||||
|
||||
# Should find the profitable earnings beat trades
|
||||
assert len(profitable) >= 2
|
||||
|
||||
# Query for losses to avoid
|
||||
losers = memory.find_losing_patterns(
|
||||
query="earnings trade",
|
||||
max_return=-0.05,
|
||||
)
|
||||
|
||||
# Should find the earnings miss trade
|
||||
assert len(losers) >= 1
|
||||
|
|
@ -6,6 +6,7 @@ This module provides a layered memory system with three scoring dimensions:
|
|||
- Importance: Significance weighting for impactful events
|
||||
|
||||
Issue #18: Layered memory - recency, relevancy, importance scoring
|
||||
Issue #19: Trade history memory - outcomes, agent reasoning
|
||||
"""
|
||||
|
||||
from .layered_memory import (
|
||||
|
|
@ -17,11 +18,30 @@ from .layered_memory import (
|
|||
ImportanceLevel,
|
||||
)
|
||||
|
||||
from .trade_history import (
|
||||
TradeHistoryMemory,
|
||||
TradeRecord,
|
||||
TradeOutcome,
|
||||
TradeDirection,
|
||||
SignalStrength,
|
||||
AgentReasoning,
|
||||
MarketContext,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Layered Memory (Issue #18)
|
||||
"LayeredMemory",
|
||||
"MemoryEntry",
|
||||
"MemoryConfig",
|
||||
"ScoringWeights",
|
||||
"DecayFunction",
|
||||
"ImportanceLevel",
|
||||
# Trade History (Issue #19)
|
||||
"TradeHistoryMemory",
|
||||
"TradeRecord",
|
||||
"TradeOutcome",
|
||||
"TradeDirection",
|
||||
"SignalStrength",
|
||||
"AgentReasoning",
|
||||
"MarketContext",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,820 @@
|
|||
"""Trade History Memory for learning from past trade outcomes.
|
||||
|
||||
This module provides specialized memory for tracking and learning from trades:
|
||||
- Trade outcomes (profit/loss, returns)
|
||||
- Agent reasoning and signals
|
||||
- Entry/exit conditions
|
||||
- Market context at time of trade
|
||||
|
||||
Issue #19: [MEM-18] Trade history memory - outcomes, agent reasoning
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Any
|
||||
import uuid
|
||||
|
||||
from .layered_memory import (
|
||||
LayeredMemory,
|
||||
MemoryEntry,
|
||||
MemoryConfig,
|
||||
ScoringWeights,
|
||||
ImportanceLevel,
|
||||
)
|
||||
|
||||
|
||||
class TradeOutcome(Enum):
|
||||
"""Trade outcome categories."""
|
||||
PROFITABLE = "profitable" # Positive return
|
||||
BREAK_EVEN = "break_even" # ~0% return
|
||||
LOSS = "loss" # Negative return
|
||||
STOPPED_OUT = "stopped_out" # Hit stop loss
|
||||
TARGET_HIT = "target_hit" # Hit profit target
|
||||
|
||||
|
||||
class TradeDirection(Enum):
|
||||
"""Trade direction."""
|
||||
LONG = "long"
|
||||
SHORT = "short"
|
||||
HOLD = "hold"
|
||||
|
||||
|
||||
class SignalStrength(Enum):
|
||||
"""Signal strength levels."""
|
||||
STRONG_BUY = "strong_buy"
|
||||
BUY = "buy"
|
||||
NEUTRAL = "neutral"
|
||||
SELL = "sell"
|
||||
STRONG_SELL = "strong_sell"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentReasoning:
|
||||
"""Captures reasoning from each agent in the trading workflow.
|
||||
|
||||
Attributes:
|
||||
fundamentals: Fundamentals analyst reasoning
|
||||
technical: Technical/market analyst reasoning
|
||||
news: News analyst reasoning
|
||||
sentiment: Social media sentiment reasoning
|
||||
momentum: Momentum analyst reasoning (if enabled)
|
||||
macro: Macro analyst reasoning (if enabled)
|
||||
correlation: Correlation analyst reasoning (if enabled)
|
||||
bull_case: Bull researcher arguments
|
||||
bear_case: Bear researcher arguments
|
||||
research_conclusion: Research manager decision
|
||||
risk_assessment: Risk manager assessment
|
||||
final_signal: Final trading signal
|
||||
"""
|
||||
fundamentals: Optional[str] = None
|
||||
technical: Optional[str] = None
|
||||
news: Optional[str] = None
|
||||
sentiment: Optional[str] = None
|
||||
momentum: Optional[str] = None
|
||||
macro: Optional[str] = None
|
||||
correlation: Optional[str] = None
|
||||
bull_case: Optional[str] = None
|
||||
bear_case: Optional[str] = None
|
||||
research_conclusion: Optional[str] = None
|
||||
risk_assessment: Optional[str] = None
|
||||
final_signal: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Optional[str]]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"fundamentals": self.fundamentals,
|
||||
"technical": self.technical,
|
||||
"news": self.news,
|
||||
"sentiment": self.sentiment,
|
||||
"momentum": self.momentum,
|
||||
"macro": self.macro,
|
||||
"correlation": self.correlation,
|
||||
"bull_case": self.bull_case,
|
||||
"bear_case": self.bear_case,
|
||||
"research_conclusion": self.research_conclusion,
|
||||
"risk_assessment": self.risk_assessment,
|
||||
"final_signal": self.final_signal,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "AgentReasoning":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
fundamentals=data.get("fundamentals"),
|
||||
technical=data.get("technical"),
|
||||
news=data.get("news"),
|
||||
sentiment=data.get("sentiment"),
|
||||
momentum=data.get("momentum"),
|
||||
macro=data.get("macro"),
|
||||
correlation=data.get("correlation"),
|
||||
bull_case=data.get("bull_case"),
|
||||
bear_case=data.get("bear_case"),
|
||||
research_conclusion=data.get("research_conclusion"),
|
||||
risk_assessment=data.get("risk_assessment"),
|
||||
final_signal=data.get("final_signal"),
|
||||
)
|
||||
|
||||
def summary(self) -> str:
|
||||
"""Generate a text summary of the reasoning."""
|
||||
parts = []
|
||||
|
||||
if self.fundamentals:
|
||||
parts.append(f"Fundamentals: {self.fundamentals[:100]}...")
|
||||
if self.technical:
|
||||
parts.append(f"Technical: {self.technical[:100]}...")
|
||||
if self.bull_case:
|
||||
parts.append(f"Bull: {self.bull_case[:100]}...")
|
||||
if self.bear_case:
|
||||
parts.append(f"Bear: {self.bear_case[:100]}...")
|
||||
if self.research_conclusion:
|
||||
parts.append(f"Conclusion: {self.research_conclusion[:100]}...")
|
||||
|
||||
return " | ".join(parts) if parts else "No reasoning recorded"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketContext:
|
||||
"""Market conditions at time of trade.
|
||||
|
||||
Attributes:
|
||||
vix: VIX volatility index level
|
||||
spy_return_1d: SPY 1-day return
|
||||
sector_performance: Dict of sector returns
|
||||
economic_regime: Detected economic regime
|
||||
yield_curve_state: Yield curve status
|
||||
macro_indicators: Key macro indicators
|
||||
"""
|
||||
vix: Optional[float] = None
|
||||
spy_return_1d: Optional[float] = None
|
||||
sector_performance: Dict[str, float] = field(default_factory=dict)
|
||||
economic_regime: Optional[str] = None
|
||||
yield_curve_state: Optional[str] = None
|
||||
macro_indicators: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"vix": self.vix,
|
||||
"spy_return_1d": self.spy_return_1d,
|
||||
"sector_performance": self.sector_performance,
|
||||
"economic_regime": self.economic_regime,
|
||||
"yield_curve_state": self.yield_curve_state,
|
||||
"macro_indicators": self.macro_indicators,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "MarketContext":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
vix=data.get("vix"),
|
||||
spy_return_1d=data.get("spy_return_1d"),
|
||||
sector_performance=data.get("sector_performance", {}),
|
||||
economic_regime=data.get("economic_regime"),
|
||||
yield_curve_state=data.get("yield_curve_state"),
|
||||
macro_indicators=data.get("macro_indicators", {}),
|
||||
)
|
||||
|
||||
def summary(self) -> str:
|
||||
"""Generate a text summary of market context."""
|
||||
parts = []
|
||||
|
||||
if self.vix is not None:
|
||||
parts.append(f"VIX: {self.vix:.1f}")
|
||||
if self.economic_regime:
|
||||
parts.append(f"Regime: {self.economic_regime}")
|
||||
if self.yield_curve_state:
|
||||
parts.append(f"Yield Curve: {self.yield_curve_state}")
|
||||
|
||||
return " | ".join(parts) if parts else "No market context"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TradeRecord:
|
||||
"""Complete record of a trade including reasoning and outcome.
|
||||
|
||||
Attributes:
|
||||
id: Unique trade ID
|
||||
symbol: Trading symbol
|
||||
direction: Trade direction (long/short/hold)
|
||||
entry_price: Entry price
|
||||
exit_price: Exit price (if closed)
|
||||
entry_time: Entry timestamp
|
||||
exit_time: Exit timestamp (if closed)
|
||||
quantity: Number of shares/contracts
|
||||
returns: Percentage return
|
||||
pnl: Dollar profit/loss
|
||||
outcome: Trade outcome category
|
||||
signal_strength: Original signal strength
|
||||
confidence: Confidence score (0-1)
|
||||
reasoning: Agent reasoning captured
|
||||
market_context: Market conditions at entry
|
||||
lessons_learned: Post-trade lessons (added later)
|
||||
tags: Trade tags for filtering
|
||||
"""
|
||||
id: str
|
||||
symbol: str
|
||||
direction: TradeDirection
|
||||
entry_price: float
|
||||
entry_time: datetime
|
||||
quantity: float = 1.0
|
||||
exit_price: Optional[float] = None
|
||||
exit_time: Optional[datetime] = None
|
||||
returns: Optional[float] = None
|
||||
pnl: Optional[float] = None
|
||||
outcome: Optional[TradeOutcome] = None
|
||||
signal_strength: SignalStrength = SignalStrength.NEUTRAL
|
||||
confidence: float = 0.5
|
||||
reasoning: AgentReasoning = field(default_factory=AgentReasoning)
|
||||
market_context: MarketContext = field(default_factory=MarketContext)
|
||||
lessons_learned: Optional[str] = None
|
||||
tags: List[str] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
symbol: str,
|
||||
direction: TradeDirection,
|
||||
entry_price: float,
|
||||
quantity: float = 1.0,
|
||||
signal_strength: SignalStrength = SignalStrength.NEUTRAL,
|
||||
confidence: float = 0.5,
|
||||
reasoning: Optional[AgentReasoning] = None,
|
||||
market_context: Optional[MarketContext] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> "TradeRecord":
|
||||
"""Create a new trade record."""
|
||||
return cls(
|
||||
id=str(uuid.uuid4()),
|
||||
symbol=symbol,
|
||||
direction=direction,
|
||||
entry_price=entry_price,
|
||||
entry_time=datetime.now(),
|
||||
quantity=quantity,
|
||||
signal_strength=signal_strength,
|
||||
confidence=confidence,
|
||||
reasoning=reasoning or AgentReasoning(),
|
||||
market_context=market_context or MarketContext(),
|
||||
tags=tags or [],
|
||||
)
|
||||
|
||||
def close(
|
||||
self,
|
||||
exit_price: float,
|
||||
exit_time: Optional[datetime] = None,
|
||||
) -> "TradeRecord":
|
||||
"""Close the trade and calculate returns.
|
||||
|
||||
Args:
|
||||
exit_price: Exit price
|
||||
exit_time: Exit timestamp (default: now)
|
||||
|
||||
Returns:
|
||||
Self with updated exit info
|
||||
"""
|
||||
self.exit_price = exit_price
|
||||
self.exit_time = exit_time or datetime.now()
|
||||
|
||||
# Calculate returns
|
||||
if self.direction == TradeDirection.LONG:
|
||||
self.returns = (exit_price - self.entry_price) / self.entry_price
|
||||
elif self.direction == TradeDirection.SHORT:
|
||||
self.returns = (self.entry_price - exit_price) / self.entry_price
|
||||
else:
|
||||
self.returns = 0.0
|
||||
|
||||
# Calculate PnL
|
||||
self.pnl = self.returns * self.entry_price * self.quantity
|
||||
|
||||
# Determine outcome
|
||||
if self.returns > 0.005: # > 0.5%
|
||||
self.outcome = TradeOutcome.PROFITABLE
|
||||
elif self.returns < -0.005: # < -0.5%
|
||||
self.outcome = TradeOutcome.LOSS
|
||||
else:
|
||||
self.outcome = TradeOutcome.BREAK_EVEN
|
||||
|
||||
return self
|
||||
|
||||
def is_open(self) -> bool:
|
||||
"""Check if trade is still open."""
|
||||
return self.exit_time is None
|
||||
|
||||
def holding_period_days(self) -> Optional[float]:
|
||||
"""Calculate holding period in days."""
|
||||
if self.exit_time is None:
|
||||
return None
|
||||
delta = self.exit_time - self.entry_time
|
||||
return delta.total_seconds() / 86400
|
||||
|
||||
def to_memory_content(self) -> str:
|
||||
"""Generate memory content for this trade."""
|
||||
parts = [
|
||||
f"Trade {self.direction.value} {self.symbol} at ${self.entry_price:.2f}",
|
||||
]
|
||||
|
||||
if self.outcome:
|
||||
parts.append(f"Outcome: {self.outcome.value}")
|
||||
|
||||
if self.returns is not None:
|
||||
parts.append(f"Return: {self.returns * 100:.2f}%")
|
||||
|
||||
if self.reasoning.research_conclusion:
|
||||
parts.append(f"Reasoning: {self.reasoning.research_conclusion}")
|
||||
|
||||
if self.market_context.economic_regime:
|
||||
parts.append(f"Regime: {self.market_context.economic_regime}")
|
||||
|
||||
return " | ".join(parts)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"symbol": self.symbol,
|
||||
"direction": self.direction.value,
|
||||
"entry_price": self.entry_price,
|
||||
"exit_price": self.exit_price,
|
||||
"entry_time": self.entry_time.isoformat(),
|
||||
"exit_time": self.exit_time.isoformat() if self.exit_time else None,
|
||||
"quantity": self.quantity,
|
||||
"returns": self.returns,
|
||||
"pnl": self.pnl,
|
||||
"outcome": self.outcome.value if self.outcome else None,
|
||||
"signal_strength": self.signal_strength.value,
|
||||
"confidence": self.confidence,
|
||||
"reasoning": self.reasoning.to_dict(),
|
||||
"market_context": self.market_context.to_dict(),
|
||||
"lessons_learned": self.lessons_learned,
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "TradeRecord":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
id=data["id"],
|
||||
symbol=data["symbol"],
|
||||
direction=TradeDirection(data["direction"]),
|
||||
entry_price=data["entry_price"],
|
||||
exit_price=data.get("exit_price"),
|
||||
entry_time=datetime.fromisoformat(data["entry_time"]),
|
||||
exit_time=datetime.fromisoformat(data["exit_time"]) if data.get("exit_time") else None,
|
||||
quantity=data.get("quantity", 1.0),
|
||||
returns=data.get("returns"),
|
||||
pnl=data.get("pnl"),
|
||||
outcome=TradeOutcome(data["outcome"]) if data.get("outcome") else None,
|
||||
signal_strength=SignalStrength(data.get("signal_strength", "neutral")),
|
||||
confidence=data.get("confidence", 0.5),
|
||||
reasoning=AgentReasoning.from_dict(data.get("reasoning", {})),
|
||||
market_context=MarketContext.from_dict(data.get("market_context", {})),
|
||||
lessons_learned=data.get("lessons_learned"),
|
||||
tags=data.get("tags", []),
|
||||
)
|
||||
|
||||
|
||||
class TradeHistoryMemory:
|
||||
"""Memory system specialized for trade history and learning.
|
||||
|
||||
This class combines trade record storage with the layered memory system
|
||||
for intelligent retrieval of past trades based on similarity.
|
||||
|
||||
Example:
|
||||
>>> memory = TradeHistoryMemory()
|
||||
>>>
|
||||
>>> # Record a trade
|
||||
>>> trade = TradeRecord.create(
|
||||
... symbol="AAPL",
|
||||
... direction=TradeDirection.LONG,
|
||||
... entry_price=150.0,
|
||||
... reasoning=AgentReasoning(
|
||||
... fundamentals="Strong earnings",
|
||||
... research_conclusion="Buy on earnings momentum",
|
||||
... ),
|
||||
... )
|
||||
>>> memory.record_trade(trade)
|
||||
>>>
|
||||
>>> # Close the trade
|
||||
>>> memory.close_trade(trade.id, exit_price=165.0)
|
||||
>>>
|
||||
>>> # Find similar past trades
|
||||
>>> similar = memory.find_similar_trades(
|
||||
... query="Apple earnings momentum",
|
||||
... top_k=5,
|
||||
... )
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[MemoryConfig] = None,
|
||||
embedding_function=None,
|
||||
):
|
||||
"""Initialize trade history memory.
|
||||
|
||||
Args:
|
||||
config: Memory configuration
|
||||
embedding_function: Optional embedding function for similarity
|
||||
"""
|
||||
# Default config with weights tuned for trade history
|
||||
if config is None:
|
||||
config = MemoryConfig(
|
||||
weights=ScoringWeights(
|
||||
recency=0.25, # Recent trades somewhat important
|
||||
relevancy=0.45, # Similarity most important
|
||||
importance=0.30, # Outcome importance matters
|
||||
),
|
||||
)
|
||||
|
||||
self._layered_memory = LayeredMemory(
|
||||
config=config,
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
self._trades: Dict[str, TradeRecord] = {}
|
||||
|
||||
def record_trade(self, trade: TradeRecord) -> str:
|
||||
"""Record a new trade.
|
||||
|
||||
Args:
|
||||
trade: The trade record to store
|
||||
|
||||
Returns:
|
||||
Trade ID
|
||||
"""
|
||||
self._trades[trade.id] = trade
|
||||
|
||||
# Create memory entry for the trade
|
||||
importance = self._calculate_trade_importance(trade)
|
||||
entry = MemoryEntry.create(
|
||||
content=trade.to_memory_content(),
|
||||
metadata={
|
||||
"trade_id": trade.id,
|
||||
"symbol": trade.symbol,
|
||||
"direction": trade.direction.value,
|
||||
"outcome": trade.outcome.value if trade.outcome else None,
|
||||
"returns": trade.returns,
|
||||
"reasoning_summary": trade.reasoning.summary(),
|
||||
},
|
||||
importance=importance,
|
||||
tags=trade.tags + [trade.symbol, trade.direction.value],
|
||||
timestamp=trade.entry_time,
|
||||
)
|
||||
entry.id = trade.id # Use trade ID as memory ID
|
||||
|
||||
self._layered_memory.add(entry)
|
||||
return trade.id
|
||||
|
||||
def close_trade(
|
||||
self,
|
||||
trade_id: str,
|
||||
exit_price: float,
|
||||
lessons_learned: Optional[str] = None,
|
||||
) -> Optional[TradeRecord]:
|
||||
"""Close an open trade and update memory.
|
||||
|
||||
Args:
|
||||
trade_id: ID of the trade to close
|
||||
exit_price: Exit price
|
||||
lessons_learned: Optional lessons from the trade
|
||||
|
||||
Returns:
|
||||
Updated trade record or None if not found
|
||||
"""
|
||||
trade = self._trades.get(trade_id)
|
||||
if trade is None:
|
||||
return None
|
||||
|
||||
# Close the trade
|
||||
trade.close(exit_price)
|
||||
|
||||
if lessons_learned:
|
||||
trade.lessons_learned = lessons_learned
|
||||
|
||||
# Update memory with outcome
|
||||
importance = self._calculate_trade_importance(trade)
|
||||
self._layered_memory.update_importance(trade_id, importance)
|
||||
|
||||
# Update memory entry content
|
||||
entry = self._layered_memory.get(trade_id)
|
||||
if entry:
|
||||
entry.metadata["outcome"] = trade.outcome.value if trade.outcome else None
|
||||
entry.metadata["returns"] = trade.returns
|
||||
if lessons_learned:
|
||||
entry.metadata["lessons_learned"] = lessons_learned
|
||||
|
||||
return trade
|
||||
|
||||
def get_trade(self, trade_id: str) -> Optional[TradeRecord]:
|
||||
"""Get a trade by ID.
|
||||
|
||||
Args:
|
||||
trade_id: Trade ID
|
||||
|
||||
Returns:
|
||||
Trade record or None
|
||||
"""
|
||||
return self._trades.get(trade_id)
|
||||
|
||||
def get_open_trades(self) -> List[TradeRecord]:
|
||||
"""Get all open trades.
|
||||
|
||||
Returns:
|
||||
List of open trade records
|
||||
"""
|
||||
return [t for t in self._trades.values() if t.is_open()]
|
||||
|
||||
def get_closed_trades(self) -> List[TradeRecord]:
|
||||
"""Get all closed trades.
|
||||
|
||||
Returns:
|
||||
List of closed trade records
|
||||
"""
|
||||
return [t for t in self._trades.values() if not t.is_open()]
|
||||
|
||||
def get_trades_by_symbol(self, symbol: str) -> List[TradeRecord]:
|
||||
"""Get all trades for a symbol.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
|
||||
Returns:
|
||||
List of trade records for the symbol
|
||||
"""
|
||||
return [t for t in self._trades.values() if t.symbol == symbol]
|
||||
|
||||
def find_similar_trades(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
symbol: Optional[str] = None,
|
||||
outcome: Optional[TradeOutcome] = None,
|
||||
) -> List[TradeRecord]:
|
||||
"""Find similar past trades.
|
||||
|
||||
Args:
|
||||
query: Query describing the current situation
|
||||
top_k: Maximum number of results
|
||||
symbol: Optional filter by symbol
|
||||
outcome: Optional filter by outcome
|
||||
|
||||
Returns:
|
||||
List of similar trade records
|
||||
"""
|
||||
# Build tags filter
|
||||
tags = []
|
||||
if symbol:
|
||||
tags.append(symbol)
|
||||
|
||||
# Retrieve from layered memory
|
||||
results = self._layered_memory.retrieve(
|
||||
query=query,
|
||||
top_k=top_k * 2, # Get more to filter
|
||||
tags=tags if tags else None,
|
||||
)
|
||||
|
||||
# Convert to trade records and filter
|
||||
trades = []
|
||||
for scored in results:
|
||||
trade_id = scored.entry.metadata.get("trade_id")
|
||||
if trade_id and trade_id in self._trades:
|
||||
trade = self._trades[trade_id]
|
||||
|
||||
# Apply outcome filter
|
||||
if outcome and trade.outcome != outcome:
|
||||
continue
|
||||
|
||||
trades.append(trade)
|
||||
|
||||
if len(trades) >= top_k:
|
||||
break
|
||||
|
||||
return trades
|
||||
|
||||
def find_profitable_patterns(
|
||||
self,
|
||||
query: str,
|
||||
min_return: float = 0.0,
|
||||
top_k: int = 5,
|
||||
) -> List[TradeRecord]:
|
||||
"""Find profitable trades similar to the query.
|
||||
|
||||
Args:
|
||||
query: Query describing the current situation
|
||||
min_return: Minimum return filter
|
||||
top_k: Maximum number of results
|
||||
|
||||
Returns:
|
||||
List of profitable trade records
|
||||
"""
|
||||
results = self._layered_memory.retrieve(query=query, top_k=top_k * 3)
|
||||
|
||||
trades = []
|
||||
for scored in results:
|
||||
trade_id = scored.entry.metadata.get("trade_id")
|
||||
if trade_id and trade_id in self._trades:
|
||||
trade = self._trades[trade_id]
|
||||
|
||||
if trade.returns is not None and trade.returns >= min_return:
|
||||
trades.append(trade)
|
||||
|
||||
if len(trades) >= top_k:
|
||||
break
|
||||
|
||||
return trades
|
||||
|
||||
def find_losing_patterns(
|
||||
self,
|
||||
query: str,
|
||||
max_return: float = 0.0,
|
||||
top_k: int = 5,
|
||||
) -> List[TradeRecord]:
|
||||
"""Find losing trades similar to the query (for learning what to avoid).
|
||||
|
||||
Args:
|
||||
query: Query describing the current situation
|
||||
max_return: Maximum return filter
|
||||
top_k: Maximum number of results
|
||||
|
||||
Returns:
|
||||
List of losing trade records
|
||||
"""
|
||||
results = self._layered_memory.retrieve(query=query, top_k=top_k * 3)
|
||||
|
||||
trades = []
|
||||
for scored in results:
|
||||
trade_id = scored.entry.metadata.get("trade_id")
|
||||
if trade_id and trade_id in self._trades:
|
||||
trade = self._trades[trade_id]
|
||||
|
||||
if trade.returns is not None and trade.returns <= max_return:
|
||||
trades.append(trade)
|
||||
|
||||
if len(trades) >= top_k:
|
||||
break
|
||||
|
||||
return trades
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""Get trade history statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics
|
||||
"""
|
||||
closed_trades = self.get_closed_trades()
|
||||
|
||||
if not closed_trades:
|
||||
return {
|
||||
"total_trades": len(self._trades),
|
||||
"open_trades": len(self.get_open_trades()),
|
||||
"closed_trades": 0,
|
||||
"win_rate": 0.0,
|
||||
"avg_return": 0.0,
|
||||
"total_pnl": 0.0,
|
||||
"best_trade": None,
|
||||
"worst_trade": None,
|
||||
}
|
||||
|
||||
returns = [t.returns for t in closed_trades if t.returns is not None]
|
||||
pnls = [t.pnl for t in closed_trades if t.pnl is not None]
|
||||
winners = [t for t in closed_trades if t.outcome == TradeOutcome.PROFITABLE]
|
||||
|
||||
# Find best and worst
|
||||
best_trade = max(closed_trades, key=lambda t: t.returns or 0, default=None)
|
||||
worst_trade = min(closed_trades, key=lambda t: t.returns or 0, default=None)
|
||||
|
||||
# Outcome distribution
|
||||
outcome_dist = {}
|
||||
for trade in closed_trades:
|
||||
if trade.outcome:
|
||||
outcome_dist[trade.outcome.value] = outcome_dist.get(trade.outcome.value, 0) + 1
|
||||
|
||||
return {
|
||||
"total_trades": len(self._trades),
|
||||
"open_trades": len(self.get_open_trades()),
|
||||
"closed_trades": len(closed_trades),
|
||||
"win_rate": len(winners) / len(closed_trades) if closed_trades else 0.0,
|
||||
"avg_return": sum(returns) / len(returns) if returns else 0.0,
|
||||
"total_pnl": sum(pnls) if pnls else 0.0,
|
||||
"best_trade": best_trade.to_dict() if best_trade else None,
|
||||
"worst_trade": worst_trade.to_dict() if worst_trade else None,
|
||||
"outcome_distribution": outcome_dist,
|
||||
}
|
||||
|
||||
def get_symbol_statistics(self, symbol: str) -> Dict[str, Any]:
|
||||
"""Get statistics for a specific symbol.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
|
||||
Returns:
|
||||
Dictionary with symbol-specific statistics
|
||||
"""
|
||||
symbol_trades = self.get_trades_by_symbol(symbol)
|
||||
closed_trades = [t for t in symbol_trades if not t.is_open()]
|
||||
|
||||
if not closed_trades:
|
||||
return {
|
||||
"symbol": symbol,
|
||||
"total_trades": len(symbol_trades),
|
||||
"open_trades": len([t for t in symbol_trades if t.is_open()]),
|
||||
"closed_trades": 0,
|
||||
"win_rate": 0.0,
|
||||
"avg_return": 0.0,
|
||||
}
|
||||
|
||||
returns = [t.returns for t in closed_trades if t.returns is not None]
|
||||
winners = [t for t in closed_trades if t.outcome == TradeOutcome.PROFITABLE]
|
||||
|
||||
return {
|
||||
"symbol": symbol,
|
||||
"total_trades": len(symbol_trades),
|
||||
"open_trades": len([t for t in symbol_trades if t.is_open()]),
|
||||
"closed_trades": len(closed_trades),
|
||||
"win_rate": len(winners) / len(closed_trades) if closed_trades else 0.0,
|
||||
"avg_return": sum(returns) / len(returns) if returns else 0.0,
|
||||
}
|
||||
|
||||
def _calculate_trade_importance(self, trade: TradeRecord) -> float:
|
||||
"""Calculate importance score for a trade.
|
||||
|
||||
Args:
|
||||
trade: Trade record
|
||||
|
||||
Returns:
|
||||
Importance score [0, 1]
|
||||
"""
|
||||
# Base on returns if closed
|
||||
if trade.returns is not None:
|
||||
abs_return = abs(trade.returns)
|
||||
if abs_return >= 0.10: # 10%+
|
||||
return ImportanceLevel.CRITICAL.value
|
||||
elif abs_return >= 0.05: # 5%+
|
||||
return ImportanceLevel.HIGH.value
|
||||
elif abs_return >= 0.01: # 1%+
|
||||
return ImportanceLevel.MEDIUM.value
|
||||
else:
|
||||
return ImportanceLevel.LOW.value
|
||||
|
||||
# For open trades, base on confidence
|
||||
if trade.confidence >= 0.8:
|
||||
return ImportanceLevel.HIGH.value
|
||||
elif trade.confidence >= 0.5:
|
||||
return ImportanceLevel.MEDIUM.value
|
||||
else:
|
||||
return ImportanceLevel.LOW.value
|
||||
|
||||
def count(self) -> int:
|
||||
"""Return total number of trades."""
|
||||
return len(self._trades)
|
||||
|
||||
def clear(self) -> int:
|
||||
"""Clear all trades.
|
||||
|
||||
Returns:
|
||||
Number of trades cleared
|
||||
"""
|
||||
count = len(self._trades)
|
||||
self._trades.clear()
|
||||
self._layered_memory.clear()
|
||||
return count
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize to dictionary.
|
||||
|
||||
Returns:
|
||||
Dictionary representation
|
||||
"""
|
||||
return {
|
||||
"trades": [t.to_dict() for t in self._trades.values()],
|
||||
"memory": self._layered_memory.to_dict(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
data: Dict[str, Any],
|
||||
embedding_function=None,
|
||||
) -> "TradeHistoryMemory":
|
||||
"""Create from dictionary.
|
||||
|
||||
Args:
|
||||
data: Dictionary representation
|
||||
embedding_function: Optional embedding function
|
||||
|
||||
Returns:
|
||||
TradeHistoryMemory instance
|
||||
"""
|
||||
instance = cls(embedding_function=embedding_function)
|
||||
|
||||
# Restore trades
|
||||
for trade_data in data.get("trades", []):
|
||||
trade = TradeRecord.from_dict(trade_data)
|
||||
instance._trades[trade.id] = trade
|
||||
|
||||
# Restore layered memory
|
||||
if "memory" in data:
|
||||
instance._layered_memory = LayeredMemory.from_dict(
|
||||
data["memory"],
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
|
||||
return instance
|
||||
Loading…
Reference in New Issue