feat(memory): add trade history memory with outcome tracking - Fixes #19
This commit is contained in:
parent
d72c214d4d
commit
dbfcea3740
|
|
@ -0,0 +1,740 @@
|
||||||
|
"""Tests for Issue #19: Trade History Memory.
|
||||||
|
|
||||||
|
This module tests the trade history memory for tracking:
|
||||||
|
- Trade outcomes (profit/loss)
|
||||||
|
- Agent reasoning
|
||||||
|
- Market context
|
||||||
|
- Pattern finding
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from tradingagents.memory.trade_history import (
|
||||||
|
TradeHistoryMemory,
|
||||||
|
TradeRecord,
|
||||||
|
TradeOutcome,
|
||||||
|
TradeDirection,
|
||||||
|
SignalStrength,
|
||||||
|
AgentReasoning,
|
||||||
|
MarketContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Test Fixtures
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def memory():
|
||||||
|
"""Create a TradeHistoryMemory instance."""
|
||||||
|
return TradeHistoryMemory()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_reasoning():
|
||||||
|
"""Create sample agent reasoning."""
|
||||||
|
return AgentReasoning(
|
||||||
|
fundamentals="Strong earnings growth, P/E below industry average",
|
||||||
|
technical="Breaking out above 50-day MA with volume",
|
||||||
|
news="Positive analyst upgrades",
|
||||||
|
sentiment="Bullish social media sentiment",
|
||||||
|
bull_case="Undervalued with strong growth trajectory",
|
||||||
|
bear_case="Macro headwinds could impact demand",
|
||||||
|
research_conclusion="Buy on fundamental strength",
|
||||||
|
final_signal="STRONG_BUY",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_market_context():
|
||||||
|
"""Create sample market context."""
|
||||||
|
return MarketContext(
|
||||||
|
vix=18.5,
|
||||||
|
spy_return_1d=0.005,
|
||||||
|
sector_performance={"XLK": 0.01, "XLF": -0.005},
|
||||||
|
economic_regime="EXPANSION",
|
||||||
|
yield_curve_state="NORMAL",
|
||||||
|
macro_indicators={"gdp_growth": 2.5, "unemployment": 4.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_trade(sample_reasoning, sample_market_context):
|
||||||
|
"""Create a sample trade record."""
|
||||||
|
return TradeRecord.create(
|
||||||
|
symbol="AAPL",
|
||||||
|
direction=TradeDirection.LONG,
|
||||||
|
entry_price=150.0,
|
||||||
|
quantity=100,
|
||||||
|
signal_strength=SignalStrength.STRONG_BUY,
|
||||||
|
confidence=0.85,
|
||||||
|
reasoning=sample_reasoning,
|
||||||
|
market_context=sample_market_context,
|
||||||
|
tags=["tech", "earnings"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def multiple_trades():
|
||||||
|
"""Create multiple trade records for testing."""
|
||||||
|
now = datetime.now()
|
||||||
|
trades = []
|
||||||
|
|
||||||
|
# Profitable AAPL trade
|
||||||
|
trade1 = TradeRecord(
|
||||||
|
id="trade-1",
|
||||||
|
symbol="AAPL",
|
||||||
|
direction=TradeDirection.LONG,
|
||||||
|
entry_price=150.0,
|
||||||
|
entry_time=now - timedelta(days=10),
|
||||||
|
quantity=100,
|
||||||
|
reasoning=AgentReasoning(research_conclusion="Buy on earnings"),
|
||||||
|
tags=["tech", "earnings"],
|
||||||
|
)
|
||||||
|
trade1.close(165.0)
|
||||||
|
|
||||||
|
# Loss GOOGL trade
|
||||||
|
trade2 = TradeRecord(
|
||||||
|
id="trade-2",
|
||||||
|
symbol="GOOGL",
|
||||||
|
direction=TradeDirection.LONG,
|
||||||
|
entry_price=140.0,
|
||||||
|
entry_time=now - timedelta(days=5),
|
||||||
|
quantity=50,
|
||||||
|
reasoning=AgentReasoning(research_conclusion="Momentum buy"),
|
||||||
|
tags=["tech"],
|
||||||
|
)
|
||||||
|
trade2.close(130.0)
|
||||||
|
|
||||||
|
# Break-even MSFT trade
|
||||||
|
trade3 = TradeRecord(
|
||||||
|
id="trade-3",
|
||||||
|
symbol="MSFT",
|
||||||
|
direction=TradeDirection.LONG,
|
||||||
|
entry_price=350.0,
|
||||||
|
entry_time=now - timedelta(days=2),
|
||||||
|
quantity=20,
|
||||||
|
reasoning=AgentReasoning(research_conclusion="Range trade"),
|
||||||
|
tags=["tech"],
|
||||||
|
)
|
||||||
|
trade3.close(351.0)
|
||||||
|
|
||||||
|
# Open NVDA trade
|
||||||
|
trade4 = TradeRecord(
|
||||||
|
id="trade-4",
|
||||||
|
symbol="NVDA",
|
||||||
|
direction=TradeDirection.LONG,
|
||||||
|
entry_price=500.0,
|
||||||
|
entry_time=now - timedelta(hours=2),
|
||||||
|
quantity=10,
|
||||||
|
reasoning=AgentReasoning(research_conclusion="AI momentum"),
|
||||||
|
tags=["tech", "ai"],
|
||||||
|
)
|
||||||
|
|
||||||
|
trades = [trade1, trade2, trade3, trade4]
|
||||||
|
return trades
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# TradeDirection Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class TestTradeDirection:
|
||||||
|
"""Tests for TradeDirection enum."""
|
||||||
|
|
||||||
|
def test_long_value(self):
|
||||||
|
"""LONG should have correct value."""
|
||||||
|
assert TradeDirection.LONG.value == "long"
|
||||||
|
|
||||||
|
def test_short_value(self):
|
||||||
|
"""SHORT should have correct value."""
|
||||||
|
assert TradeDirection.SHORT.value == "short"
|
||||||
|
|
||||||
|
def test_hold_value(self):
|
||||||
|
"""HOLD should have correct value."""
|
||||||
|
assert TradeDirection.HOLD.value == "hold"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# TradeOutcome Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class TestTradeOutcome:
|
||||||
|
"""Tests for TradeOutcome enum."""
|
||||||
|
|
||||||
|
def test_profitable(self):
|
||||||
|
"""PROFITABLE should have correct value."""
|
||||||
|
assert TradeOutcome.PROFITABLE.value == "profitable"
|
||||||
|
|
||||||
|
def test_loss(self):
|
||||||
|
"""LOSS should have correct value."""
|
||||||
|
assert TradeOutcome.LOSS.value == "loss"
|
||||||
|
|
||||||
|
def test_break_even(self):
|
||||||
|
"""BREAK_EVEN should have correct value."""
|
||||||
|
assert TradeOutcome.BREAK_EVEN.value == "break_even"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# SignalStrength Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class TestSignalStrength:
|
||||||
|
"""Tests for SignalStrength enum."""
|
||||||
|
|
||||||
|
def test_signal_values(self):
|
||||||
|
"""All signal strengths should have correct values."""
|
||||||
|
assert SignalStrength.STRONG_BUY.value == "strong_buy"
|
||||||
|
assert SignalStrength.BUY.value == "buy"
|
||||||
|
assert SignalStrength.NEUTRAL.value == "neutral"
|
||||||
|
assert SignalStrength.SELL.value == "sell"
|
||||||
|
assert SignalStrength.STRONG_SELL.value == "strong_sell"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# AgentReasoning Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class TestAgentReasoning:
|
||||||
|
"""Tests for AgentReasoning class."""
|
||||||
|
|
||||||
|
def test_default_reasoning(self):
|
||||||
|
"""Default reasoning should have None values."""
|
||||||
|
reasoning = AgentReasoning()
|
||||||
|
assert reasoning.fundamentals is None
|
||||||
|
assert reasoning.technical is None
|
||||||
|
assert reasoning.research_conclusion is None
|
||||||
|
|
||||||
|
def test_reasoning_with_values(self, sample_reasoning):
|
||||||
|
"""Reasoning should store values correctly."""
|
||||||
|
assert sample_reasoning.fundamentals is not None
|
||||||
|
assert sample_reasoning.research_conclusion == "Buy on fundamental strength"
|
||||||
|
|
||||||
|
def test_to_dict(self, sample_reasoning):
|
||||||
|
"""To dict should serialize correctly."""
|
||||||
|
data = sample_reasoning.to_dict()
|
||||||
|
assert data["fundamentals"] == sample_reasoning.fundamentals
|
||||||
|
assert data["research_conclusion"] == sample_reasoning.research_conclusion
|
||||||
|
|
||||||
|
def test_from_dict(self, sample_reasoning):
|
||||||
|
"""From dict should deserialize correctly."""
|
||||||
|
data = sample_reasoning.to_dict()
|
||||||
|
restored = AgentReasoning.from_dict(data)
|
||||||
|
assert restored.fundamentals == sample_reasoning.fundamentals
|
||||||
|
assert restored.research_conclusion == sample_reasoning.research_conclusion
|
||||||
|
|
||||||
|
def test_summary(self, sample_reasoning):
|
||||||
|
"""Summary should generate text."""
|
||||||
|
summary = sample_reasoning.summary()
|
||||||
|
assert "Fundamentals" in summary
|
||||||
|
assert "Conclusion" in summary
|
||||||
|
|
||||||
|
def test_empty_summary(self):
|
||||||
|
"""Empty reasoning should have fallback summary."""
|
||||||
|
reasoning = AgentReasoning()
|
||||||
|
summary = reasoning.summary()
|
||||||
|
assert summary == "No reasoning recorded"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# MarketContext Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class TestMarketContext:
|
||||||
|
"""Tests for MarketContext class."""
|
||||||
|
|
||||||
|
def test_default_context(self):
|
||||||
|
"""Default context should have None/empty values."""
|
||||||
|
context = MarketContext()
|
||||||
|
assert context.vix is None
|
||||||
|
assert context.sector_performance == {}
|
||||||
|
|
||||||
|
def test_context_with_values(self, sample_market_context):
|
||||||
|
"""Context should store values correctly."""
|
||||||
|
assert sample_market_context.vix == 18.5
|
||||||
|
assert sample_market_context.economic_regime == "EXPANSION"
|
||||||
|
assert "XLK" in sample_market_context.sector_performance
|
||||||
|
|
||||||
|
def test_to_dict(self, sample_market_context):
|
||||||
|
"""To dict should serialize correctly."""
|
||||||
|
data = sample_market_context.to_dict()
|
||||||
|
assert data["vix"] == 18.5
|
||||||
|
assert data["economic_regime"] == "EXPANSION"
|
||||||
|
|
||||||
|
def test_from_dict(self, sample_market_context):
|
||||||
|
"""From dict should deserialize correctly."""
|
||||||
|
data = sample_market_context.to_dict()
|
||||||
|
restored = MarketContext.from_dict(data)
|
||||||
|
assert restored.vix == sample_market_context.vix
|
||||||
|
assert restored.economic_regime == sample_market_context.economic_regime
|
||||||
|
|
||||||
|
def test_summary(self, sample_market_context):
|
||||||
|
"""Summary should generate text."""
|
||||||
|
summary = sample_market_context.summary()
|
||||||
|
assert "VIX" in summary
|
||||||
|
assert "Regime" in summary
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# TradeRecord Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class TestTradeRecord:
|
||||||
|
"""Tests for TradeRecord class."""
|
||||||
|
|
||||||
|
def test_create_trade(self):
|
||||||
|
"""Create should generate a valid trade record."""
|
||||||
|
trade = TradeRecord.create(
|
||||||
|
symbol="AAPL",
|
||||||
|
direction=TradeDirection.LONG,
|
||||||
|
entry_price=150.0,
|
||||||
|
)
|
||||||
|
assert trade.symbol == "AAPL"
|
||||||
|
assert trade.direction == TradeDirection.LONG
|
||||||
|
assert trade.entry_price == 150.0
|
||||||
|
assert trade.id is not None
|
||||||
|
assert trade.entry_time is not None
|
||||||
|
assert trade.is_open()
|
||||||
|
|
||||||
|
def test_create_with_all_args(self, sample_reasoning, sample_market_context):
|
||||||
|
"""Create with all arguments should work."""
|
||||||
|
trade = TradeRecord.create(
|
||||||
|
symbol="AAPL",
|
||||||
|
direction=TradeDirection.LONG,
|
||||||
|
entry_price=150.0,
|
||||||
|
quantity=100,
|
||||||
|
signal_strength=SignalStrength.STRONG_BUY,
|
||||||
|
confidence=0.85,
|
||||||
|
reasoning=sample_reasoning,
|
||||||
|
market_context=sample_market_context,
|
||||||
|
tags=["tech"],
|
||||||
|
)
|
||||||
|
assert trade.quantity == 100
|
||||||
|
assert trade.signal_strength == SignalStrength.STRONG_BUY
|
||||||
|
assert trade.confidence == 0.85
|
||||||
|
assert trade.reasoning == sample_reasoning
|
||||||
|
assert trade.market_context == sample_market_context
|
||||||
|
|
||||||
|
def test_close_profitable(self, sample_trade):
|
||||||
|
"""Closing at higher price should be profitable."""
|
||||||
|
sample_trade.close(165.0)
|
||||||
|
|
||||||
|
assert sample_trade.exit_price == 165.0
|
||||||
|
assert sample_trade.exit_time is not None
|
||||||
|
assert sample_trade.returns is not None
|
||||||
|
assert sample_trade.returns > 0
|
||||||
|
assert sample_trade.outcome == TradeOutcome.PROFITABLE
|
||||||
|
assert not sample_trade.is_open()
|
||||||
|
|
||||||
|
def test_close_loss(self, sample_trade):
|
||||||
|
"""Closing at lower price should be a loss."""
|
||||||
|
sample_trade.close(140.0)
|
||||||
|
|
||||||
|
assert sample_trade.returns < 0
|
||||||
|
assert sample_trade.outcome == TradeOutcome.LOSS
|
||||||
|
|
||||||
|
def test_close_break_even(self, sample_trade):
|
||||||
|
"""Closing at same price should be break even."""
|
||||||
|
sample_trade.close(150.2) # Within 0.5% threshold
|
||||||
|
|
||||||
|
assert abs(sample_trade.returns) < 0.005
|
||||||
|
assert sample_trade.outcome == TradeOutcome.BREAK_EVEN
|
||||||
|
|
||||||
|
def test_close_short_profitable(self):
|
||||||
|
"""Short trade profitable when price goes down."""
|
||||||
|
trade = TradeRecord.create(
|
||||||
|
symbol="AAPL",
|
||||||
|
direction=TradeDirection.SHORT,
|
||||||
|
entry_price=150.0,
|
||||||
|
)
|
||||||
|
trade.close(140.0)
|
||||||
|
|
||||||
|
assert trade.returns > 0
|
||||||
|
assert trade.outcome == TradeOutcome.PROFITABLE
|
||||||
|
|
||||||
|
def test_close_short_loss(self):
|
||||||
|
"""Short trade loss when price goes up."""
|
||||||
|
trade = TradeRecord.create(
|
||||||
|
symbol="AAPL",
|
||||||
|
direction=TradeDirection.SHORT,
|
||||||
|
entry_price=150.0,
|
||||||
|
)
|
||||||
|
trade.close(160.0)
|
||||||
|
|
||||||
|
assert trade.returns < 0
|
||||||
|
assert trade.outcome == TradeOutcome.LOSS
|
||||||
|
|
||||||
|
def test_pnl_calculation(self, sample_trade):
|
||||||
|
"""PnL should be calculated correctly."""
|
||||||
|
sample_trade.close(165.0)
|
||||||
|
|
||||||
|
expected_return = (165 - 150) / 150
|
||||||
|
expected_pnl = expected_return * 150 * 100
|
||||||
|
|
||||||
|
assert abs(sample_trade.pnl - expected_pnl) < 0.01
|
||||||
|
|
||||||
|
def test_holding_period(self, sample_trade):
|
||||||
|
"""Holding period should be calculated correctly."""
|
||||||
|
assert sample_trade.holding_period_days() is None # Still open
|
||||||
|
|
||||||
|
sample_trade.close(165.0)
|
||||||
|
holding = sample_trade.holding_period_days()
|
||||||
|
assert holding is not None
|
||||||
|
assert holding >= 0
|
||||||
|
|
||||||
|
def test_to_memory_content(self, sample_trade):
|
||||||
|
"""Memory content should include key info."""
|
||||||
|
sample_trade.close(165.0)
|
||||||
|
content = sample_trade.to_memory_content()
|
||||||
|
|
||||||
|
assert "AAPL" in content
|
||||||
|
assert "long" in content
|
||||||
|
assert "profitable" in content
|
||||||
|
|
||||||
|
def test_to_dict(self, sample_trade):
|
||||||
|
"""To dict should serialize correctly."""
|
||||||
|
data = sample_trade.to_dict()
|
||||||
|
|
||||||
|
assert data["symbol"] == "AAPL"
|
||||||
|
assert data["direction"] == "long"
|
||||||
|
assert data["entry_price"] == 150.0
|
||||||
|
assert "reasoning" in data
|
||||||
|
assert "market_context" in data
|
||||||
|
|
||||||
|
def test_from_dict(self, sample_trade):
|
||||||
|
"""From dict should deserialize correctly."""
|
||||||
|
sample_trade.close(165.0)
|
||||||
|
data = sample_trade.to_dict()
|
||||||
|
restored = TradeRecord.from_dict(data)
|
||||||
|
|
||||||
|
assert restored.symbol == sample_trade.symbol
|
||||||
|
assert restored.direction == sample_trade.direction
|
||||||
|
assert restored.entry_price == sample_trade.entry_price
|
||||||
|
assert restored.exit_price == sample_trade.exit_price
|
||||||
|
assert restored.outcome == sample_trade.outcome
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# TradeHistoryMemory Basic Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class TestTradeHistoryMemoryBasic:
|
||||||
|
"""Basic tests for TradeHistoryMemory."""
|
||||||
|
|
||||||
|
def test_create_empty(self, memory):
|
||||||
|
"""Empty memory should have zero trades."""
|
||||||
|
assert memory.count() == 0
|
||||||
|
|
||||||
|
def test_record_trade(self, memory, sample_trade):
|
||||||
|
"""Recording a trade should increase count."""
|
||||||
|
trade_id = memory.record_trade(sample_trade)
|
||||||
|
assert trade_id == sample_trade.id
|
||||||
|
assert memory.count() == 1
|
||||||
|
|
||||||
|
def test_get_trade(self, memory, sample_trade):
|
||||||
|
"""Get should return recorded trade."""
|
||||||
|
memory.record_trade(sample_trade)
|
||||||
|
retrieved = memory.get_trade(sample_trade.id)
|
||||||
|
assert retrieved is not None
|
||||||
|
assert retrieved.symbol == sample_trade.symbol
|
||||||
|
|
||||||
|
def test_get_nonexistent(self, memory):
|
||||||
|
"""Get nonexistent trade should return None."""
|
||||||
|
result = memory.get_trade("nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_close_trade(self, memory, sample_trade):
|
||||||
|
"""Closing a trade should update it."""
|
||||||
|
memory.record_trade(sample_trade)
|
||||||
|
closed = memory.close_trade(sample_trade.id, exit_price=165.0)
|
||||||
|
|
||||||
|
assert closed is not None
|
||||||
|
assert closed.exit_price == 165.0
|
||||||
|
assert closed.outcome is not None
|
||||||
|
|
||||||
|
def test_close_with_lessons(self, memory, sample_trade):
|
||||||
|
"""Closing with lessons should store them."""
|
||||||
|
memory.record_trade(sample_trade)
|
||||||
|
memory.close_trade(
|
||||||
|
sample_trade.id,
|
||||||
|
exit_price=165.0,
|
||||||
|
lessons_learned="Wait for confirmation before entry",
|
||||||
|
)
|
||||||
|
|
||||||
|
trade = memory.get_trade(sample_trade.id)
|
||||||
|
assert trade.lessons_learned == "Wait for confirmation before entry"
|
||||||
|
|
||||||
|
def test_clear(self, memory, multiple_trades):
|
||||||
|
"""Clear should remove all trades."""
|
||||||
|
for trade in multiple_trades:
|
||||||
|
memory.record_trade(trade)
|
||||||
|
|
||||||
|
count = memory.clear()
|
||||||
|
assert count == len(multiple_trades)
|
||||||
|
assert memory.count() == 0
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# TradeHistoryMemory Query Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class TestTradeHistoryMemoryQueries:
|
||||||
|
"""Query tests for TradeHistoryMemory."""
|
||||||
|
|
||||||
|
def test_get_open_trades(self, memory, multiple_trades):
|
||||||
|
"""Get open trades should filter correctly."""
|
||||||
|
for trade in multiple_trades:
|
||||||
|
memory.record_trade(trade)
|
||||||
|
|
||||||
|
open_trades = memory.get_open_trades()
|
||||||
|
assert len(open_trades) == 1 # Only NVDA is open
|
||||||
|
assert open_trades[0].symbol == "NVDA"
|
||||||
|
|
||||||
|
def test_get_closed_trades(self, memory, multiple_trades):
|
||||||
|
"""Get closed trades should filter correctly."""
|
||||||
|
for trade in multiple_trades:
|
||||||
|
memory.record_trade(trade)
|
||||||
|
|
||||||
|
closed_trades = memory.get_closed_trades()
|
||||||
|
assert len(closed_trades) == 3 # AAPL, GOOGL, MSFT
|
||||||
|
|
||||||
|
def test_get_trades_by_symbol(self, memory, multiple_trades):
|
||||||
|
"""Get trades by symbol should filter correctly."""
|
||||||
|
for trade in multiple_trades:
|
||||||
|
memory.record_trade(trade)
|
||||||
|
|
||||||
|
aapl_trades = memory.get_trades_by_symbol("AAPL")
|
||||||
|
assert len(aapl_trades) == 1
|
||||||
|
assert aapl_trades[0].symbol == "AAPL"
|
||||||
|
|
||||||
|
def test_find_similar_trades(self, memory, multiple_trades):
|
||||||
|
"""Find similar trades should return relevant trades."""
|
||||||
|
for trade in multiple_trades:
|
||||||
|
memory.record_trade(trade)
|
||||||
|
|
||||||
|
similar = memory.find_similar_trades(
|
||||||
|
query="tech stock momentum earnings",
|
||||||
|
top_k=3,
|
||||||
|
)
|
||||||
|
assert len(similar) >= 1
|
||||||
|
|
||||||
|
def test_find_profitable_patterns(self, memory, multiple_trades):
|
||||||
|
"""Find profitable patterns should return winners."""
|
||||||
|
for trade in multiple_trades:
|
||||||
|
memory.record_trade(trade)
|
||||||
|
|
||||||
|
profitable = memory.find_profitable_patterns(
|
||||||
|
query="tech stock buy",
|
||||||
|
min_return=0.05,
|
||||||
|
top_k=5,
|
||||||
|
)
|
||||||
|
# AAPL had 10% return
|
||||||
|
assert len(profitable) >= 1
|
||||||
|
for trade in profitable:
|
||||||
|
assert trade.returns >= 0.05
|
||||||
|
|
||||||
|
def test_find_losing_patterns(self, memory, multiple_trades):
|
||||||
|
"""Find losing patterns should return losers."""
|
||||||
|
for trade in multiple_trades:
|
||||||
|
memory.record_trade(trade)
|
||||||
|
|
||||||
|
losers = memory.find_losing_patterns(
|
||||||
|
query="tech stock momentum",
|
||||||
|
max_return=-0.05,
|
||||||
|
top_k=5,
|
||||||
|
)
|
||||||
|
# GOOGL had -7% return
|
||||||
|
assert len(losers) >= 1
|
||||||
|
for trade in losers:
|
||||||
|
assert trade.returns <= -0.05
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# TradeHistoryMemory Statistics Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class TestTradeHistoryMemoryStats:
|
||||||
|
"""Statistics tests for TradeHistoryMemory."""
|
||||||
|
|
||||||
|
def test_empty_statistics(self, memory):
|
||||||
|
"""Empty memory should have zero stats."""
|
||||||
|
stats = memory.get_statistics()
|
||||||
|
assert stats["total_trades"] == 0
|
||||||
|
assert stats["win_rate"] == 0.0
|
||||||
|
|
||||||
|
def test_statistics_with_trades(self, memory, multiple_trades):
|
||||||
|
"""Statistics should reflect trade data."""
|
||||||
|
for trade in multiple_trades:
|
||||||
|
memory.record_trade(trade)
|
||||||
|
|
||||||
|
stats = memory.get_statistics()
|
||||||
|
|
||||||
|
assert stats["total_trades"] == 4
|
||||||
|
assert stats["open_trades"] == 1
|
||||||
|
assert stats["closed_trades"] == 3
|
||||||
|
assert stats["win_rate"] > 0 # At least AAPL was profitable
|
||||||
|
|
||||||
|
def test_symbol_statistics(self, memory, multiple_trades):
|
||||||
|
"""Symbol statistics should be calculated correctly."""
|
||||||
|
for trade in multiple_trades:
|
||||||
|
memory.record_trade(trade)
|
||||||
|
|
||||||
|
stats = memory.get_symbol_statistics("AAPL")
|
||||||
|
|
||||||
|
assert stats["symbol"] == "AAPL"
|
||||||
|
assert stats["total_trades"] == 1
|
||||||
|
assert stats["closed_trades"] == 1
|
||||||
|
assert stats["win_rate"] == 1.0 # AAPL was profitable
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# TradeHistoryMemory Serialization Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class TestTradeHistoryMemorySerialization:
|
||||||
|
"""Serialization tests for TradeHistoryMemory."""
|
||||||
|
|
||||||
|
def test_to_dict(self, memory, multiple_trades):
|
||||||
|
"""To dict should serialize correctly."""
|
||||||
|
for trade in multiple_trades:
|
||||||
|
memory.record_trade(trade)
|
||||||
|
|
||||||
|
data = memory.to_dict()
|
||||||
|
|
||||||
|
assert "trades" in data
|
||||||
|
assert "memory" in data
|
||||||
|
assert len(data["trades"]) == len(multiple_trades)
|
||||||
|
|
||||||
|
def test_from_dict(self, memory, multiple_trades):
|
||||||
|
"""From dict should deserialize correctly."""
|
||||||
|
for trade in multiple_trades:
|
||||||
|
memory.record_trade(trade)
|
||||||
|
|
||||||
|
data = memory.to_dict()
|
||||||
|
restored = TradeHistoryMemory.from_dict(data)
|
||||||
|
|
||||||
|
assert restored.count() == len(multiple_trades)
|
||||||
|
|
||||||
|
def test_roundtrip(self, memory, sample_trade):
|
||||||
|
"""Roundtrip serialization should preserve data."""
|
||||||
|
memory.record_trade(sample_trade)
|
||||||
|
memory.close_trade(sample_trade.id, exit_price=165.0)
|
||||||
|
|
||||||
|
data = memory.to_dict()
|
||||||
|
restored = TradeHistoryMemory.from_dict(data)
|
||||||
|
|
||||||
|
original = memory.get_trade(sample_trade.id)
|
||||||
|
restored_trade = restored.get_trade(sample_trade.id)
|
||||||
|
|
||||||
|
assert restored_trade is not None
|
||||||
|
assert restored_trade.symbol == original.symbol
|
||||||
|
assert restored_trade.exit_price == original.exit_price
|
||||||
|
assert restored_trade.outcome == original.outcome
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Integration Tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class TestIntegration:
|
||||||
|
"""Integration tests for trade history workflow."""
|
||||||
|
|
||||||
|
def test_full_trade_lifecycle(self, memory):
|
||||||
|
"""Test complete trade lifecycle."""
|
||||||
|
# 1. Create reasoning
|
||||||
|
reasoning = AgentReasoning(
|
||||||
|
fundamentals="Strong quarterly earnings",
|
||||||
|
technical="Golden cross on daily chart",
|
||||||
|
research_conclusion="Buy for momentum continuation",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Create market context
|
||||||
|
context = MarketContext(
|
||||||
|
vix=15.0,
|
||||||
|
economic_regime="EXPANSION",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Open trade
|
||||||
|
trade = TradeRecord.create(
|
||||||
|
symbol="MSFT",
|
||||||
|
direction=TradeDirection.LONG,
|
||||||
|
entry_price=400.0,
|
||||||
|
quantity=50,
|
||||||
|
signal_strength=SignalStrength.BUY,
|
||||||
|
confidence=0.75,
|
||||||
|
reasoning=reasoning,
|
||||||
|
market_context=context,
|
||||||
|
tags=["tech", "earnings"],
|
||||||
|
)
|
||||||
|
memory.record_trade(trade)
|
||||||
|
|
||||||
|
# 4. Verify open
|
||||||
|
assert memory.count() == 1
|
||||||
|
assert len(memory.get_open_trades()) == 1
|
||||||
|
|
||||||
|
# 5. Close with profit
|
||||||
|
memory.close_trade(
|
||||||
|
trade.id,
|
||||||
|
exit_price=440.0,
|
||||||
|
lessons_learned="Good timing on earnings play",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Verify closed
|
||||||
|
assert len(memory.get_closed_trades()) == 1
|
||||||
|
closed = memory.get_trade(trade.id)
|
||||||
|
assert closed.outcome == TradeOutcome.PROFITABLE
|
||||||
|
assert closed.returns == pytest.approx(0.10, rel=0.01)
|
||||||
|
|
||||||
|
# 7. Find similar trades
|
||||||
|
similar = memory.find_similar_trades(
|
||||||
|
query="tech earnings momentum",
|
||||||
|
top_k=1,
|
||||||
|
)
|
||||||
|
assert len(similar) == 1
|
||||||
|
assert similar[0].symbol == "MSFT"
|
||||||
|
|
||||||
|
# 8. Check statistics
|
||||||
|
stats = memory.get_statistics()
|
||||||
|
assert stats["win_rate"] == 1.0
|
||||||
|
assert stats["avg_return"] > 0
|
||||||
|
|
||||||
|
def test_learning_from_history(self, memory):
|
||||||
|
"""Test learning patterns from trade history."""
|
||||||
|
# Record several trades
|
||||||
|
trades = [
|
||||||
|
("AAPL", "earnings beat", 150.0, 165.0), # +10%
|
||||||
|
("GOOGL", "earnings beat", 140.0, 147.0), # +5%
|
||||||
|
("MSFT", "earnings miss", 350.0, 315.0), # -10%
|
||||||
|
("NVDA", "earnings beat", 500.0, 550.0), # +10%
|
||||||
|
]
|
||||||
|
|
||||||
|
for symbol, pattern, entry, exit_price in trades:
|
||||||
|
trade = TradeRecord.create(
|
||||||
|
symbol=symbol,
|
||||||
|
direction=TradeDirection.LONG,
|
||||||
|
entry_price=entry,
|
||||||
|
reasoning=AgentReasoning(research_conclusion=f"Trade on {pattern}"),
|
||||||
|
tags=["earnings"],
|
||||||
|
)
|
||||||
|
memory.record_trade(trade)
|
||||||
|
memory.close_trade(trade.id, exit_price=exit_price)
|
||||||
|
|
||||||
|
# Query for earnings beat patterns
|
||||||
|
profitable = memory.find_profitable_patterns(
|
||||||
|
query="earnings beat momentum",
|
||||||
|
min_return=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should find the profitable earnings beat trades
|
||||||
|
assert len(profitable) >= 2
|
||||||
|
|
||||||
|
# Query for losses to avoid
|
||||||
|
losers = memory.find_losing_patterns(
|
||||||
|
query="earnings trade",
|
||||||
|
max_return=-0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should find the earnings miss trade
|
||||||
|
assert len(losers) >= 1
|
||||||
|
|
@ -6,6 +6,7 @@ This module provides a layered memory system with three scoring dimensions:
|
||||||
- Importance: Significance weighting for impactful events
|
- Importance: Significance weighting for impactful events
|
||||||
|
|
||||||
Issue #18: Layered memory - recency, relevancy, importance scoring
|
Issue #18: Layered memory - recency, relevancy, importance scoring
|
||||||
|
Issue #19: Trade history memory - outcomes, agent reasoning
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .layered_memory import (
|
from .layered_memory import (
|
||||||
|
|
@ -17,11 +18,30 @@ from .layered_memory import (
|
||||||
ImportanceLevel,
|
ImportanceLevel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .trade_history import (
|
||||||
|
TradeHistoryMemory,
|
||||||
|
TradeRecord,
|
||||||
|
TradeOutcome,
|
||||||
|
TradeDirection,
|
||||||
|
SignalStrength,
|
||||||
|
AgentReasoning,
|
||||||
|
MarketContext,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Layered Memory (Issue #18)
|
||||||
"LayeredMemory",
|
"LayeredMemory",
|
||||||
"MemoryEntry",
|
"MemoryEntry",
|
||||||
"MemoryConfig",
|
"MemoryConfig",
|
||||||
"ScoringWeights",
|
"ScoringWeights",
|
||||||
"DecayFunction",
|
"DecayFunction",
|
||||||
"ImportanceLevel",
|
"ImportanceLevel",
|
||||||
|
# Trade History (Issue #19)
|
||||||
|
"TradeHistoryMemory",
|
||||||
|
"TradeRecord",
|
||||||
|
"TradeOutcome",
|
||||||
|
"TradeDirection",
|
||||||
|
"SignalStrength",
|
||||||
|
"AgentReasoning",
|
||||||
|
"MarketContext",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,820 @@
|
||||||
|
"""Trade History Memory for learning from past trade outcomes.
|
||||||
|
|
||||||
|
This module provides specialized memory for tracking and learning from trades:
|
||||||
|
- Trade outcomes (profit/loss, returns)
|
||||||
|
- Agent reasoning and signals
|
||||||
|
- Entry/exit conditions
|
||||||
|
- Market context at time of trade
|
||||||
|
|
||||||
|
Issue #19: [MEM-18] Trade history memory - outcomes, agent reasoning
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from .layered_memory import (
|
||||||
|
LayeredMemory,
|
||||||
|
MemoryEntry,
|
||||||
|
MemoryConfig,
|
||||||
|
ScoringWeights,
|
||||||
|
ImportanceLevel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TradeOutcome(Enum):
|
||||||
|
"""Trade outcome categories."""
|
||||||
|
PROFITABLE = "profitable" # Positive return
|
||||||
|
BREAK_EVEN = "break_even" # ~0% return
|
||||||
|
LOSS = "loss" # Negative return
|
||||||
|
STOPPED_OUT = "stopped_out" # Hit stop loss
|
||||||
|
TARGET_HIT = "target_hit" # Hit profit target
|
||||||
|
|
||||||
|
|
||||||
|
class TradeDirection(Enum):
|
||||||
|
"""Trade direction."""
|
||||||
|
LONG = "long"
|
||||||
|
SHORT = "short"
|
||||||
|
HOLD = "hold"
|
||||||
|
|
||||||
|
|
||||||
|
class SignalStrength(Enum):
|
||||||
|
"""Signal strength levels."""
|
||||||
|
STRONG_BUY = "strong_buy"
|
||||||
|
BUY = "buy"
|
||||||
|
NEUTRAL = "neutral"
|
||||||
|
SELL = "sell"
|
||||||
|
STRONG_SELL = "strong_sell"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentReasoning:
|
||||||
|
"""Captures reasoning from each agent in the trading workflow.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
fundamentals: Fundamentals analyst reasoning
|
||||||
|
technical: Technical/market analyst reasoning
|
||||||
|
news: News analyst reasoning
|
||||||
|
sentiment: Social media sentiment reasoning
|
||||||
|
momentum: Momentum analyst reasoning (if enabled)
|
||||||
|
macro: Macro analyst reasoning (if enabled)
|
||||||
|
correlation: Correlation analyst reasoning (if enabled)
|
||||||
|
bull_case: Bull researcher arguments
|
||||||
|
bear_case: Bear researcher arguments
|
||||||
|
research_conclusion: Research manager decision
|
||||||
|
risk_assessment: Risk manager assessment
|
||||||
|
final_signal: Final trading signal
|
||||||
|
"""
|
||||||
|
fundamentals: Optional[str] = None
|
||||||
|
technical: Optional[str] = None
|
||||||
|
news: Optional[str] = None
|
||||||
|
sentiment: Optional[str] = None
|
||||||
|
momentum: Optional[str] = None
|
||||||
|
macro: Optional[str] = None
|
||||||
|
correlation: Optional[str] = None
|
||||||
|
bull_case: Optional[str] = None
|
||||||
|
bear_case: Optional[str] = None
|
||||||
|
research_conclusion: Optional[str] = None
|
||||||
|
risk_assessment: Optional[str] = None
|
||||||
|
final_signal: Optional[str] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Optional[str]]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"fundamentals": self.fundamentals,
|
||||||
|
"technical": self.technical,
|
||||||
|
"news": self.news,
|
||||||
|
"sentiment": self.sentiment,
|
||||||
|
"momentum": self.momentum,
|
||||||
|
"macro": self.macro,
|
||||||
|
"correlation": self.correlation,
|
||||||
|
"bull_case": self.bull_case,
|
||||||
|
"bear_case": self.bear_case,
|
||||||
|
"research_conclusion": self.research_conclusion,
|
||||||
|
"risk_assessment": self.risk_assessment,
|
||||||
|
"final_signal": self.final_signal,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "AgentReasoning":
|
||||||
|
"""Create from dictionary."""
|
||||||
|
return cls(
|
||||||
|
fundamentals=data.get("fundamentals"),
|
||||||
|
technical=data.get("technical"),
|
||||||
|
news=data.get("news"),
|
||||||
|
sentiment=data.get("sentiment"),
|
||||||
|
momentum=data.get("momentum"),
|
||||||
|
macro=data.get("macro"),
|
||||||
|
correlation=data.get("correlation"),
|
||||||
|
bull_case=data.get("bull_case"),
|
||||||
|
bear_case=data.get("bear_case"),
|
||||||
|
research_conclusion=data.get("research_conclusion"),
|
||||||
|
risk_assessment=data.get("risk_assessment"),
|
||||||
|
final_signal=data.get("final_signal"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def summary(self) -> str:
|
||||||
|
"""Generate a text summary of the reasoning."""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
if self.fundamentals:
|
||||||
|
parts.append(f"Fundamentals: {self.fundamentals[:100]}...")
|
||||||
|
if self.technical:
|
||||||
|
parts.append(f"Technical: {self.technical[:100]}...")
|
||||||
|
if self.bull_case:
|
||||||
|
parts.append(f"Bull: {self.bull_case[:100]}...")
|
||||||
|
if self.bear_case:
|
||||||
|
parts.append(f"Bear: {self.bear_case[:100]}...")
|
||||||
|
if self.research_conclusion:
|
||||||
|
parts.append(f"Conclusion: {self.research_conclusion[:100]}...")
|
||||||
|
|
||||||
|
return " | ".join(parts) if parts else "No reasoning recorded"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MarketContext:
|
||||||
|
"""Market conditions at time of trade.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
vix: VIX volatility index level
|
||||||
|
spy_return_1d: SPY 1-day return
|
||||||
|
sector_performance: Dict of sector returns
|
||||||
|
economic_regime: Detected economic regime
|
||||||
|
yield_curve_state: Yield curve status
|
||||||
|
macro_indicators: Key macro indicators
|
||||||
|
"""
|
||||||
|
vix: Optional[float] = None
|
||||||
|
spy_return_1d: Optional[float] = None
|
||||||
|
sector_performance: Dict[str, float] = field(default_factory=dict)
|
||||||
|
economic_regime: Optional[str] = None
|
||||||
|
yield_curve_state: Optional[str] = None
|
||||||
|
macro_indicators: Dict[str, float] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"vix": self.vix,
|
||||||
|
"spy_return_1d": self.spy_return_1d,
|
||||||
|
"sector_performance": self.sector_performance,
|
||||||
|
"economic_regime": self.economic_regime,
|
||||||
|
"yield_curve_state": self.yield_curve_state,
|
||||||
|
"macro_indicators": self.macro_indicators,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "MarketContext":
|
||||||
|
"""Create from dictionary."""
|
||||||
|
return cls(
|
||||||
|
vix=data.get("vix"),
|
||||||
|
spy_return_1d=data.get("spy_return_1d"),
|
||||||
|
sector_performance=data.get("sector_performance", {}),
|
||||||
|
economic_regime=data.get("economic_regime"),
|
||||||
|
yield_curve_state=data.get("yield_curve_state"),
|
||||||
|
macro_indicators=data.get("macro_indicators", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
def summary(self) -> str:
|
||||||
|
"""Generate a text summary of market context."""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
if self.vix is not None:
|
||||||
|
parts.append(f"VIX: {self.vix:.1f}")
|
||||||
|
if self.economic_regime:
|
||||||
|
parts.append(f"Regime: {self.economic_regime}")
|
||||||
|
if self.yield_curve_state:
|
||||||
|
parts.append(f"Yield Curve: {self.yield_curve_state}")
|
||||||
|
|
||||||
|
return " | ".join(parts) if parts else "No market context"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TradeRecord:
|
||||||
|
"""Complete record of a trade including reasoning and outcome.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
id: Unique trade ID
|
||||||
|
symbol: Trading symbol
|
||||||
|
direction: Trade direction (long/short/hold)
|
||||||
|
entry_price: Entry price
|
||||||
|
exit_price: Exit price (if closed)
|
||||||
|
entry_time: Entry timestamp
|
||||||
|
exit_time: Exit timestamp (if closed)
|
||||||
|
quantity: Number of shares/contracts
|
||||||
|
returns: Percentage return
|
||||||
|
pnl: Dollar profit/loss
|
||||||
|
outcome: Trade outcome category
|
||||||
|
signal_strength: Original signal strength
|
||||||
|
confidence: Confidence score (0-1)
|
||||||
|
reasoning: Agent reasoning captured
|
||||||
|
market_context: Market conditions at entry
|
||||||
|
lessons_learned: Post-trade lessons (added later)
|
||||||
|
tags: Trade tags for filtering
|
||||||
|
"""
|
||||||
|
id: str
|
||||||
|
symbol: str
|
||||||
|
direction: TradeDirection
|
||||||
|
entry_price: float
|
||||||
|
entry_time: datetime
|
||||||
|
quantity: float = 1.0
|
||||||
|
exit_price: Optional[float] = None
|
||||||
|
exit_time: Optional[datetime] = None
|
||||||
|
returns: Optional[float] = None
|
||||||
|
pnl: Optional[float] = None
|
||||||
|
outcome: Optional[TradeOutcome] = None
|
||||||
|
signal_strength: SignalStrength = SignalStrength.NEUTRAL
|
||||||
|
confidence: float = 0.5
|
||||||
|
reasoning: AgentReasoning = field(default_factory=AgentReasoning)
|
||||||
|
market_context: MarketContext = field(default_factory=MarketContext)
|
||||||
|
lessons_learned: Optional[str] = None
|
||||||
|
tags: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
symbol: str,
|
||||||
|
direction: TradeDirection,
|
||||||
|
entry_price: float,
|
||||||
|
quantity: float = 1.0,
|
||||||
|
signal_strength: SignalStrength = SignalStrength.NEUTRAL,
|
||||||
|
confidence: float = 0.5,
|
||||||
|
reasoning: Optional[AgentReasoning] = None,
|
||||||
|
market_context: Optional[MarketContext] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
|
) -> "TradeRecord":
|
||||||
|
"""Create a new trade record."""
|
||||||
|
return cls(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
symbol=symbol,
|
||||||
|
direction=direction,
|
||||||
|
entry_price=entry_price,
|
||||||
|
entry_time=datetime.now(),
|
||||||
|
quantity=quantity,
|
||||||
|
signal_strength=signal_strength,
|
||||||
|
confidence=confidence,
|
||||||
|
reasoning=reasoning or AgentReasoning(),
|
||||||
|
market_context=market_context or MarketContext(),
|
||||||
|
tags=tags or [],
|
||||||
|
)
|
||||||
|
|
||||||
|
def close(
|
||||||
|
self,
|
||||||
|
exit_price: float,
|
||||||
|
exit_time: Optional[datetime] = None,
|
||||||
|
) -> "TradeRecord":
|
||||||
|
"""Close the trade and calculate returns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exit_price: Exit price
|
||||||
|
exit_time: Exit timestamp (default: now)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Self with updated exit info
|
||||||
|
"""
|
||||||
|
self.exit_price = exit_price
|
||||||
|
self.exit_time = exit_time or datetime.now()
|
||||||
|
|
||||||
|
# Calculate returns
|
||||||
|
if self.direction == TradeDirection.LONG:
|
||||||
|
self.returns = (exit_price - self.entry_price) / self.entry_price
|
||||||
|
elif self.direction == TradeDirection.SHORT:
|
||||||
|
self.returns = (self.entry_price - exit_price) / self.entry_price
|
||||||
|
else:
|
||||||
|
self.returns = 0.0
|
||||||
|
|
||||||
|
# Calculate PnL
|
||||||
|
self.pnl = self.returns * self.entry_price * self.quantity
|
||||||
|
|
||||||
|
# Determine outcome
|
||||||
|
if self.returns > 0.005: # > 0.5%
|
||||||
|
self.outcome = TradeOutcome.PROFITABLE
|
||||||
|
elif self.returns < -0.005: # < -0.5%
|
||||||
|
self.outcome = TradeOutcome.LOSS
|
||||||
|
else:
|
||||||
|
self.outcome = TradeOutcome.BREAK_EVEN
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def is_open(self) -> bool:
|
||||||
|
"""Check if trade is still open."""
|
||||||
|
return self.exit_time is None
|
||||||
|
|
||||||
|
def holding_period_days(self) -> Optional[float]:
|
||||||
|
"""Calculate holding period in days."""
|
||||||
|
if self.exit_time is None:
|
||||||
|
return None
|
||||||
|
delta = self.exit_time - self.entry_time
|
||||||
|
return delta.total_seconds() / 86400
|
||||||
|
|
||||||
|
def to_memory_content(self) -> str:
|
||||||
|
"""Generate memory content for this trade."""
|
||||||
|
parts = [
|
||||||
|
f"Trade {self.direction.value} {self.symbol} at ${self.entry_price:.2f}",
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.outcome:
|
||||||
|
parts.append(f"Outcome: {self.outcome.value}")
|
||||||
|
|
||||||
|
if self.returns is not None:
|
||||||
|
parts.append(f"Return: {self.returns * 100:.2f}%")
|
||||||
|
|
||||||
|
if self.reasoning.research_conclusion:
|
||||||
|
parts.append(f"Reasoning: {self.reasoning.research_conclusion}")
|
||||||
|
|
||||||
|
if self.market_context.economic_regime:
|
||||||
|
parts.append(f"Regime: {self.market_context.economic_regime}")
|
||||||
|
|
||||||
|
return " | ".join(parts)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"symbol": self.symbol,
|
||||||
|
"direction": self.direction.value,
|
||||||
|
"entry_price": self.entry_price,
|
||||||
|
"exit_price": self.exit_price,
|
||||||
|
"entry_time": self.entry_time.isoformat(),
|
||||||
|
"exit_time": self.exit_time.isoformat() if self.exit_time else None,
|
||||||
|
"quantity": self.quantity,
|
||||||
|
"returns": self.returns,
|
||||||
|
"pnl": self.pnl,
|
||||||
|
"outcome": self.outcome.value if self.outcome else None,
|
||||||
|
"signal_strength": self.signal_strength.value,
|
||||||
|
"confidence": self.confidence,
|
||||||
|
"reasoning": self.reasoning.to_dict(),
|
||||||
|
"market_context": self.market_context.to_dict(),
|
||||||
|
"lessons_learned": self.lessons_learned,
|
||||||
|
"tags": self.tags,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "TradeRecord":
|
||||||
|
"""Create from dictionary."""
|
||||||
|
return cls(
|
||||||
|
id=data["id"],
|
||||||
|
symbol=data["symbol"],
|
||||||
|
direction=TradeDirection(data["direction"]),
|
||||||
|
entry_price=data["entry_price"],
|
||||||
|
exit_price=data.get("exit_price"),
|
||||||
|
entry_time=datetime.fromisoformat(data["entry_time"]),
|
||||||
|
exit_time=datetime.fromisoformat(data["exit_time"]) if data.get("exit_time") else None,
|
||||||
|
quantity=data.get("quantity", 1.0),
|
||||||
|
returns=data.get("returns"),
|
||||||
|
pnl=data.get("pnl"),
|
||||||
|
outcome=TradeOutcome(data["outcome"]) if data.get("outcome") else None,
|
||||||
|
signal_strength=SignalStrength(data.get("signal_strength", "neutral")),
|
||||||
|
confidence=data.get("confidence", 0.5),
|
||||||
|
reasoning=AgentReasoning.from_dict(data.get("reasoning", {})),
|
||||||
|
market_context=MarketContext.from_dict(data.get("market_context", {})),
|
||||||
|
lessons_learned=data.get("lessons_learned"),
|
||||||
|
tags=data.get("tags", []),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TradeHistoryMemory:
|
||||||
|
"""Memory system specialized for trade history and learning.
|
||||||
|
|
||||||
|
This class combines trade record storage with the layered memory system
|
||||||
|
for intelligent retrieval of past trades based on similarity.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> memory = TradeHistoryMemory()
|
||||||
|
>>>
|
||||||
|
>>> # Record a trade
|
||||||
|
>>> trade = TradeRecord.create(
|
||||||
|
... symbol="AAPL",
|
||||||
|
... direction=TradeDirection.LONG,
|
||||||
|
... entry_price=150.0,
|
||||||
|
... reasoning=AgentReasoning(
|
||||||
|
... fundamentals="Strong earnings",
|
||||||
|
... research_conclusion="Buy on earnings momentum",
|
||||||
|
... ),
|
||||||
|
... )
|
||||||
|
>>> memory.record_trade(trade)
|
||||||
|
>>>
|
||||||
|
>>> # Close the trade
|
||||||
|
>>> memory.close_trade(trade.id, exit_price=165.0)
|
||||||
|
>>>
|
||||||
|
>>> # Find similar past trades
|
||||||
|
>>> similar = memory.find_similar_trades(
|
||||||
|
... query="Apple earnings momentum",
|
||||||
|
... top_k=5,
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Optional[MemoryConfig] = None,
|
||||||
|
embedding_function=None,
|
||||||
|
):
|
||||||
|
"""Initialize trade history memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Memory configuration
|
||||||
|
embedding_function: Optional embedding function for similarity
|
||||||
|
"""
|
||||||
|
# Default config with weights tuned for trade history
|
||||||
|
if config is None:
|
||||||
|
config = MemoryConfig(
|
||||||
|
weights=ScoringWeights(
|
||||||
|
recency=0.25, # Recent trades somewhat important
|
||||||
|
relevancy=0.45, # Similarity most important
|
||||||
|
importance=0.30, # Outcome importance matters
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._layered_memory = LayeredMemory(
|
||||||
|
config=config,
|
||||||
|
embedding_function=embedding_function,
|
||||||
|
)
|
||||||
|
self._trades: Dict[str, TradeRecord] = {}
|
||||||
|
|
||||||
|
def record_trade(self, trade: TradeRecord) -> str:
|
||||||
|
"""Record a new trade.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trade: The trade record to store
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Trade ID
|
||||||
|
"""
|
||||||
|
self._trades[trade.id] = trade
|
||||||
|
|
||||||
|
# Create memory entry for the trade
|
||||||
|
importance = self._calculate_trade_importance(trade)
|
||||||
|
entry = MemoryEntry.create(
|
||||||
|
content=trade.to_memory_content(),
|
||||||
|
metadata={
|
||||||
|
"trade_id": trade.id,
|
||||||
|
"symbol": trade.symbol,
|
||||||
|
"direction": trade.direction.value,
|
||||||
|
"outcome": trade.outcome.value if trade.outcome else None,
|
||||||
|
"returns": trade.returns,
|
||||||
|
"reasoning_summary": trade.reasoning.summary(),
|
||||||
|
},
|
||||||
|
importance=importance,
|
||||||
|
tags=trade.tags + [trade.symbol, trade.direction.value],
|
||||||
|
timestamp=trade.entry_time,
|
||||||
|
)
|
||||||
|
entry.id = trade.id # Use trade ID as memory ID
|
||||||
|
|
||||||
|
self._layered_memory.add(entry)
|
||||||
|
return trade.id
|
||||||
|
|
||||||
|
def close_trade(
|
||||||
|
self,
|
||||||
|
trade_id: str,
|
||||||
|
exit_price: float,
|
||||||
|
lessons_learned: Optional[str] = None,
|
||||||
|
) -> Optional[TradeRecord]:
|
||||||
|
"""Close an open trade and update memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trade_id: ID of the trade to close
|
||||||
|
exit_price: Exit price
|
||||||
|
lessons_learned: Optional lessons from the trade
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated trade record or None if not found
|
||||||
|
"""
|
||||||
|
trade = self._trades.get(trade_id)
|
||||||
|
if trade is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Close the trade
|
||||||
|
trade.close(exit_price)
|
||||||
|
|
||||||
|
if lessons_learned:
|
||||||
|
trade.lessons_learned = lessons_learned
|
||||||
|
|
||||||
|
# Update memory with outcome
|
||||||
|
importance = self._calculate_trade_importance(trade)
|
||||||
|
self._layered_memory.update_importance(trade_id, importance)
|
||||||
|
|
||||||
|
# Update memory entry content
|
||||||
|
entry = self._layered_memory.get(trade_id)
|
||||||
|
if entry:
|
||||||
|
entry.metadata["outcome"] = trade.outcome.value if trade.outcome else None
|
||||||
|
entry.metadata["returns"] = trade.returns
|
||||||
|
if lessons_learned:
|
||||||
|
entry.metadata["lessons_learned"] = lessons_learned
|
||||||
|
|
||||||
|
return trade
|
||||||
|
|
||||||
|
def get_trade(self, trade_id: str) -> Optional[TradeRecord]:
|
||||||
|
"""Get a trade by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trade_id: Trade ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Trade record or None
|
||||||
|
"""
|
||||||
|
return self._trades.get(trade_id)
|
||||||
|
|
||||||
|
def get_open_trades(self) -> List[TradeRecord]:
|
||||||
|
"""Get all open trades.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of open trade records
|
||||||
|
"""
|
||||||
|
return [t for t in self._trades.values() if t.is_open()]
|
||||||
|
|
||||||
|
def get_closed_trades(self) -> List[TradeRecord]:
|
||||||
|
"""Get all closed trades.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of closed trade records
|
||||||
|
"""
|
||||||
|
return [t for t in self._trades.values() if not t.is_open()]
|
||||||
|
|
||||||
|
def get_trades_by_symbol(self, symbol: str) -> List[TradeRecord]:
|
||||||
|
"""Get all trades for a symbol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading symbol
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of trade records for the symbol
|
||||||
|
"""
|
||||||
|
return [t for t in self._trades.values() if t.symbol == symbol]
|
||||||
|
|
||||||
|
def find_similar_trades(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
top_k: int = 5,
|
||||||
|
symbol: Optional[str] = None,
|
||||||
|
outcome: Optional[TradeOutcome] = None,
|
||||||
|
) -> List[TradeRecord]:
|
||||||
|
"""Find similar past trades.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query describing the current situation
|
||||||
|
top_k: Maximum number of results
|
||||||
|
symbol: Optional filter by symbol
|
||||||
|
outcome: Optional filter by outcome
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of similar trade records
|
||||||
|
"""
|
||||||
|
# Build tags filter
|
||||||
|
tags = []
|
||||||
|
if symbol:
|
||||||
|
tags.append(symbol)
|
||||||
|
|
||||||
|
# Retrieve from layered memory
|
||||||
|
results = self._layered_memory.retrieve(
|
||||||
|
query=query,
|
||||||
|
top_k=top_k * 2, # Get more to filter
|
||||||
|
tags=tags if tags else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to trade records and filter
|
||||||
|
trades = []
|
||||||
|
for scored in results:
|
||||||
|
trade_id = scored.entry.metadata.get("trade_id")
|
||||||
|
if trade_id and trade_id in self._trades:
|
||||||
|
trade = self._trades[trade_id]
|
||||||
|
|
||||||
|
# Apply outcome filter
|
||||||
|
if outcome and trade.outcome != outcome:
|
||||||
|
continue
|
||||||
|
|
||||||
|
trades.append(trade)
|
||||||
|
|
||||||
|
if len(trades) >= top_k:
|
||||||
|
break
|
||||||
|
|
||||||
|
return trades
|
||||||
|
|
||||||
|
def find_profitable_patterns(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
min_return: float = 0.0,
|
||||||
|
top_k: int = 5,
|
||||||
|
) -> List[TradeRecord]:
|
||||||
|
"""Find profitable trades similar to the query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query describing the current situation
|
||||||
|
min_return: Minimum return filter
|
||||||
|
top_k: Maximum number of results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of profitable trade records
|
||||||
|
"""
|
||||||
|
results = self._layered_memory.retrieve(query=query, top_k=top_k * 3)
|
||||||
|
|
||||||
|
trades = []
|
||||||
|
for scored in results:
|
||||||
|
trade_id = scored.entry.metadata.get("trade_id")
|
||||||
|
if trade_id and trade_id in self._trades:
|
||||||
|
trade = self._trades[trade_id]
|
||||||
|
|
||||||
|
if trade.returns is not None and trade.returns >= min_return:
|
||||||
|
trades.append(trade)
|
||||||
|
|
||||||
|
if len(trades) >= top_k:
|
||||||
|
break
|
||||||
|
|
||||||
|
return trades
|
||||||
|
|
||||||
|
def find_losing_patterns(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
max_return: float = 0.0,
|
||||||
|
top_k: int = 5,
|
||||||
|
) -> List[TradeRecord]:
|
||||||
|
"""Find losing trades similar to the query (for learning what to avoid).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query describing the current situation
|
||||||
|
max_return: Maximum return filter
|
||||||
|
top_k: Maximum number of results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of losing trade records
|
||||||
|
"""
|
||||||
|
results = self._layered_memory.retrieve(query=query, top_k=top_k * 3)
|
||||||
|
|
||||||
|
trades = []
|
||||||
|
for scored in results:
|
||||||
|
trade_id = scored.entry.metadata.get("trade_id")
|
||||||
|
if trade_id and trade_id in self._trades:
|
||||||
|
trade = self._trades[trade_id]
|
||||||
|
|
||||||
|
if trade.returns is not None and trade.returns <= max_return:
|
||||||
|
trades.append(trade)
|
||||||
|
|
||||||
|
if len(trades) >= top_k:
|
||||||
|
break
|
||||||
|
|
||||||
|
return trades
|
||||||
|
|
||||||
|
def get_statistics(self) -> Dict[str, Any]:
|
||||||
|
"""Get trade history statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with statistics
|
||||||
|
"""
|
||||||
|
closed_trades = self.get_closed_trades()
|
||||||
|
|
||||||
|
if not closed_trades:
|
||||||
|
return {
|
||||||
|
"total_trades": len(self._trades),
|
||||||
|
"open_trades": len(self.get_open_trades()),
|
||||||
|
"closed_trades": 0,
|
||||||
|
"win_rate": 0.0,
|
||||||
|
"avg_return": 0.0,
|
||||||
|
"total_pnl": 0.0,
|
||||||
|
"best_trade": None,
|
||||||
|
"worst_trade": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
returns = [t.returns for t in closed_trades if t.returns is not None]
|
||||||
|
pnls = [t.pnl for t in closed_trades if t.pnl is not None]
|
||||||
|
winners = [t for t in closed_trades if t.outcome == TradeOutcome.PROFITABLE]
|
||||||
|
|
||||||
|
# Find best and worst
|
||||||
|
best_trade = max(closed_trades, key=lambda t: t.returns or 0, default=None)
|
||||||
|
worst_trade = min(closed_trades, key=lambda t: t.returns or 0, default=None)
|
||||||
|
|
||||||
|
# Outcome distribution
|
||||||
|
outcome_dist = {}
|
||||||
|
for trade in closed_trades:
|
||||||
|
if trade.outcome:
|
||||||
|
outcome_dist[trade.outcome.value] = outcome_dist.get(trade.outcome.value, 0) + 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_trades": len(self._trades),
|
||||||
|
"open_trades": len(self.get_open_trades()),
|
||||||
|
"closed_trades": len(closed_trades),
|
||||||
|
"win_rate": len(winners) / len(closed_trades) if closed_trades else 0.0,
|
||||||
|
"avg_return": sum(returns) / len(returns) if returns else 0.0,
|
||||||
|
"total_pnl": sum(pnls) if pnls else 0.0,
|
||||||
|
"best_trade": best_trade.to_dict() if best_trade else None,
|
||||||
|
"worst_trade": worst_trade.to_dict() if worst_trade else None,
|
||||||
|
"outcome_distribution": outcome_dist,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_symbol_statistics(self, symbol: str) -> Dict[str, Any]:
|
||||||
|
"""Get statistics for a specific symbol.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading symbol
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with symbol-specific statistics
|
||||||
|
"""
|
||||||
|
symbol_trades = self.get_trades_by_symbol(symbol)
|
||||||
|
closed_trades = [t for t in symbol_trades if not t.is_open()]
|
||||||
|
|
||||||
|
if not closed_trades:
|
||||||
|
return {
|
||||||
|
"symbol": symbol,
|
||||||
|
"total_trades": len(symbol_trades),
|
||||||
|
"open_trades": len([t for t in symbol_trades if t.is_open()]),
|
||||||
|
"closed_trades": 0,
|
||||||
|
"win_rate": 0.0,
|
||||||
|
"avg_return": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
returns = [t.returns for t in closed_trades if t.returns is not None]
|
||||||
|
winners = [t for t in closed_trades if t.outcome == TradeOutcome.PROFITABLE]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"symbol": symbol,
|
||||||
|
"total_trades": len(symbol_trades),
|
||||||
|
"open_trades": len([t for t in symbol_trades if t.is_open()]),
|
||||||
|
"closed_trades": len(closed_trades),
|
||||||
|
"win_rate": len(winners) / len(closed_trades) if closed_trades else 0.0,
|
||||||
|
"avg_return": sum(returns) / len(returns) if returns else 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _calculate_trade_importance(self, trade: TradeRecord) -> float:
|
||||||
|
"""Calculate importance score for a trade.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trade: Trade record
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Importance score [0, 1]
|
||||||
|
"""
|
||||||
|
# Base on returns if closed
|
||||||
|
if trade.returns is not None:
|
||||||
|
abs_return = abs(trade.returns)
|
||||||
|
if abs_return >= 0.10: # 10%+
|
||||||
|
return ImportanceLevel.CRITICAL.value
|
||||||
|
elif abs_return >= 0.05: # 5%+
|
||||||
|
return ImportanceLevel.HIGH.value
|
||||||
|
elif abs_return >= 0.01: # 1%+
|
||||||
|
return ImportanceLevel.MEDIUM.value
|
||||||
|
else:
|
||||||
|
return ImportanceLevel.LOW.value
|
||||||
|
|
||||||
|
# For open trades, base on confidence
|
||||||
|
if trade.confidence >= 0.8:
|
||||||
|
return ImportanceLevel.HIGH.value
|
||||||
|
elif trade.confidence >= 0.5:
|
||||||
|
return ImportanceLevel.MEDIUM.value
|
||||||
|
else:
|
||||||
|
return ImportanceLevel.LOW.value
|
||||||
|
|
||||||
|
def count(self) -> int:
|
||||||
|
"""Return total number of trades."""
|
||||||
|
return len(self._trades)
|
||||||
|
|
||||||
|
def clear(self) -> int:
|
||||||
|
"""Clear all trades.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of trades cleared
|
||||||
|
"""
|
||||||
|
count = len(self._trades)
|
||||||
|
self._trades.clear()
|
||||||
|
self._layered_memory.clear()
|
||||||
|
return count
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Serialize to dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary representation
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"trades": [t.to_dict() for t in self._trades.values()],
|
||||||
|
"memory": self._layered_memory.to_dict(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(
|
||||||
|
cls,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
embedding_function=None,
|
||||||
|
) -> "TradeHistoryMemory":
|
||||||
|
"""Create from dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary representation
|
||||||
|
embedding_function: Optional embedding function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TradeHistoryMemory instance
|
||||||
|
"""
|
||||||
|
instance = cls(embedding_function=embedding_function)
|
||||||
|
|
||||||
|
# Restore trades
|
||||||
|
for trade_data in data.get("trades", []):
|
||||||
|
trade = TradeRecord.from_dict(trade_data)
|
||||||
|
instance._trades[trade.id] = trade
|
||||||
|
|
||||||
|
# Restore layered memory
|
||||||
|
if "memory" in data:
|
||||||
|
instance._layered_memory = LayeredMemory.from_dict(
|
||||||
|
data["memory"],
|
||||||
|
embedding_function=embedding_function,
|
||||||
|
)
|
||||||
|
|
||||||
|
return instance
|
||||||
Loading…
Reference in New Issue