feat(memory): add trade history memory with outcome tracking - Fixes #19

This commit is contained in:
Andrew Kaszubski 2025-12-26 20:22:03 +11:00
parent d72c214d4d
commit dbfcea3740
3 changed files with 1580 additions and 0 deletions

View File

@ -0,0 +1,740 @@
"""Tests for Issue #19: Trade History Memory.
This module tests the trade history memory for tracking:
- Trade outcomes (profit/loss)
- Agent reasoning
- Market context
- Pattern finding
"""
import pytest
from datetime import datetime, timedelta
from unittest.mock import MagicMock
from tradingagents.memory.trade_history import (
TradeHistoryMemory,
TradeRecord,
TradeOutcome,
TradeDirection,
SignalStrength,
AgentReasoning,
MarketContext,
)
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def memory():
"""Create a TradeHistoryMemory instance."""
return TradeHistoryMemory()
@pytest.fixture
def sample_reasoning():
"""Create sample agent reasoning."""
return AgentReasoning(
fundamentals="Strong earnings growth, P/E below industry average",
technical="Breaking out above 50-day MA with volume",
news="Positive analyst upgrades",
sentiment="Bullish social media sentiment",
bull_case="Undervalued with strong growth trajectory",
bear_case="Macro headwinds could impact demand",
research_conclusion="Buy on fundamental strength",
final_signal="STRONG_BUY",
)
@pytest.fixture
def sample_market_context():
"""Create sample market context."""
return MarketContext(
vix=18.5,
spy_return_1d=0.005,
sector_performance={"XLK": 0.01, "XLF": -0.005},
economic_regime="EXPANSION",
yield_curve_state="NORMAL",
macro_indicators={"gdp_growth": 2.5, "unemployment": 4.0},
)
@pytest.fixture
def sample_trade(sample_reasoning, sample_market_context):
"""Create a sample trade record."""
return TradeRecord.create(
symbol="AAPL",
direction=TradeDirection.LONG,
entry_price=150.0,
quantity=100,
signal_strength=SignalStrength.STRONG_BUY,
confidence=0.85,
reasoning=sample_reasoning,
market_context=sample_market_context,
tags=["tech", "earnings"],
)
@pytest.fixture
def multiple_trades():
"""Create multiple trade records for testing."""
now = datetime.now()
trades = []
# Profitable AAPL trade
trade1 = TradeRecord(
id="trade-1",
symbol="AAPL",
direction=TradeDirection.LONG,
entry_price=150.0,
entry_time=now - timedelta(days=10),
quantity=100,
reasoning=AgentReasoning(research_conclusion="Buy on earnings"),
tags=["tech", "earnings"],
)
trade1.close(165.0)
# Loss GOOGL trade
trade2 = TradeRecord(
id="trade-2",
symbol="GOOGL",
direction=TradeDirection.LONG,
entry_price=140.0,
entry_time=now - timedelta(days=5),
quantity=50,
reasoning=AgentReasoning(research_conclusion="Momentum buy"),
tags=["tech"],
)
trade2.close(130.0)
# Break-even MSFT trade
trade3 = TradeRecord(
id="trade-3",
symbol="MSFT",
direction=TradeDirection.LONG,
entry_price=350.0,
entry_time=now - timedelta(days=2),
quantity=20,
reasoning=AgentReasoning(research_conclusion="Range trade"),
tags=["tech"],
)
trade3.close(351.0)
# Open NVDA trade
trade4 = TradeRecord(
id="trade-4",
symbol="NVDA",
direction=TradeDirection.LONG,
entry_price=500.0,
entry_time=now - timedelta(hours=2),
quantity=10,
reasoning=AgentReasoning(research_conclusion="AI momentum"),
tags=["tech", "ai"],
)
trades = [trade1, trade2, trade3, trade4]
return trades
# =============================================================================
# TradeDirection Tests
# =============================================================================
class TestTradeDirection:
"""Tests for TradeDirection enum."""
def test_long_value(self):
"""LONG should have correct value."""
assert TradeDirection.LONG.value == "long"
def test_short_value(self):
"""SHORT should have correct value."""
assert TradeDirection.SHORT.value == "short"
def test_hold_value(self):
"""HOLD should have correct value."""
assert TradeDirection.HOLD.value == "hold"
# =============================================================================
# TradeOutcome Tests
# =============================================================================
class TestTradeOutcome:
"""Tests for TradeOutcome enum."""
def test_profitable(self):
"""PROFITABLE should have correct value."""
assert TradeOutcome.PROFITABLE.value == "profitable"
def test_loss(self):
"""LOSS should have correct value."""
assert TradeOutcome.LOSS.value == "loss"
def test_break_even(self):
"""BREAK_EVEN should have correct value."""
assert TradeOutcome.BREAK_EVEN.value == "break_even"
# =============================================================================
# SignalStrength Tests
# =============================================================================
class TestSignalStrength:
"""Tests for SignalStrength enum."""
def test_signal_values(self):
"""All signal strengths should have correct values."""
assert SignalStrength.STRONG_BUY.value == "strong_buy"
assert SignalStrength.BUY.value == "buy"
assert SignalStrength.NEUTRAL.value == "neutral"
assert SignalStrength.SELL.value == "sell"
assert SignalStrength.STRONG_SELL.value == "strong_sell"
# =============================================================================
# AgentReasoning Tests
# =============================================================================
class TestAgentReasoning:
"""Tests for AgentReasoning class."""
def test_default_reasoning(self):
"""Default reasoning should have None values."""
reasoning = AgentReasoning()
assert reasoning.fundamentals is None
assert reasoning.technical is None
assert reasoning.research_conclusion is None
def test_reasoning_with_values(self, sample_reasoning):
"""Reasoning should store values correctly."""
assert sample_reasoning.fundamentals is not None
assert sample_reasoning.research_conclusion == "Buy on fundamental strength"
def test_to_dict(self, sample_reasoning):
"""To dict should serialize correctly."""
data = sample_reasoning.to_dict()
assert data["fundamentals"] == sample_reasoning.fundamentals
assert data["research_conclusion"] == sample_reasoning.research_conclusion
def test_from_dict(self, sample_reasoning):
"""From dict should deserialize correctly."""
data = sample_reasoning.to_dict()
restored = AgentReasoning.from_dict(data)
assert restored.fundamentals == sample_reasoning.fundamentals
assert restored.research_conclusion == sample_reasoning.research_conclusion
def test_summary(self, sample_reasoning):
"""Summary should generate text."""
summary = sample_reasoning.summary()
assert "Fundamentals" in summary
assert "Conclusion" in summary
def test_empty_summary(self):
"""Empty reasoning should have fallback summary."""
reasoning = AgentReasoning()
summary = reasoning.summary()
assert summary == "No reasoning recorded"
# =============================================================================
# MarketContext Tests
# =============================================================================
class TestMarketContext:
"""Tests for MarketContext class."""
def test_default_context(self):
"""Default context should have None/empty values."""
context = MarketContext()
assert context.vix is None
assert context.sector_performance == {}
def test_context_with_values(self, sample_market_context):
"""Context should store values correctly."""
assert sample_market_context.vix == 18.5
assert sample_market_context.economic_regime == "EXPANSION"
assert "XLK" in sample_market_context.sector_performance
def test_to_dict(self, sample_market_context):
"""To dict should serialize correctly."""
data = sample_market_context.to_dict()
assert data["vix"] == 18.5
assert data["economic_regime"] == "EXPANSION"
def test_from_dict(self, sample_market_context):
"""From dict should deserialize correctly."""
data = sample_market_context.to_dict()
restored = MarketContext.from_dict(data)
assert restored.vix == sample_market_context.vix
assert restored.economic_regime == sample_market_context.economic_regime
def test_summary(self, sample_market_context):
"""Summary should generate text."""
summary = sample_market_context.summary()
assert "VIX" in summary
assert "Regime" in summary
# =============================================================================
# TradeRecord Tests
# =============================================================================
class TestTradeRecord:
"""Tests for TradeRecord class."""
def test_create_trade(self):
"""Create should generate a valid trade record."""
trade = TradeRecord.create(
symbol="AAPL",
direction=TradeDirection.LONG,
entry_price=150.0,
)
assert trade.symbol == "AAPL"
assert trade.direction == TradeDirection.LONG
assert trade.entry_price == 150.0
assert trade.id is not None
assert trade.entry_time is not None
assert trade.is_open()
def test_create_with_all_args(self, sample_reasoning, sample_market_context):
"""Create with all arguments should work."""
trade = TradeRecord.create(
symbol="AAPL",
direction=TradeDirection.LONG,
entry_price=150.0,
quantity=100,
signal_strength=SignalStrength.STRONG_BUY,
confidence=0.85,
reasoning=sample_reasoning,
market_context=sample_market_context,
tags=["tech"],
)
assert trade.quantity == 100
assert trade.signal_strength == SignalStrength.STRONG_BUY
assert trade.confidence == 0.85
assert trade.reasoning == sample_reasoning
assert trade.market_context == sample_market_context
def test_close_profitable(self, sample_trade):
"""Closing at higher price should be profitable."""
sample_trade.close(165.0)
assert sample_trade.exit_price == 165.0
assert sample_trade.exit_time is not None
assert sample_trade.returns is not None
assert sample_trade.returns > 0
assert sample_trade.outcome == TradeOutcome.PROFITABLE
assert not sample_trade.is_open()
def test_close_loss(self, sample_trade):
"""Closing at lower price should be a loss."""
sample_trade.close(140.0)
assert sample_trade.returns < 0
assert sample_trade.outcome == TradeOutcome.LOSS
def test_close_break_even(self, sample_trade):
"""Closing at same price should be break even."""
sample_trade.close(150.2) # Within 0.5% threshold
assert abs(sample_trade.returns) < 0.005
assert sample_trade.outcome == TradeOutcome.BREAK_EVEN
def test_close_short_profitable(self):
"""Short trade profitable when price goes down."""
trade = TradeRecord.create(
symbol="AAPL",
direction=TradeDirection.SHORT,
entry_price=150.0,
)
trade.close(140.0)
assert trade.returns > 0
assert trade.outcome == TradeOutcome.PROFITABLE
def test_close_short_loss(self):
"""Short trade loss when price goes up."""
trade = TradeRecord.create(
symbol="AAPL",
direction=TradeDirection.SHORT,
entry_price=150.0,
)
trade.close(160.0)
assert trade.returns < 0
assert trade.outcome == TradeOutcome.LOSS
def test_pnl_calculation(self, sample_trade):
"""PnL should be calculated correctly."""
sample_trade.close(165.0)
expected_return = (165 - 150) / 150
expected_pnl = expected_return * 150 * 100
assert abs(sample_trade.pnl - expected_pnl) < 0.01
def test_holding_period(self, sample_trade):
"""Holding period should be calculated correctly."""
assert sample_trade.holding_period_days() is None # Still open
sample_trade.close(165.0)
holding = sample_trade.holding_period_days()
assert holding is not None
assert holding >= 0
def test_to_memory_content(self, sample_trade):
"""Memory content should include key info."""
sample_trade.close(165.0)
content = sample_trade.to_memory_content()
assert "AAPL" in content
assert "long" in content
assert "profitable" in content
def test_to_dict(self, sample_trade):
"""To dict should serialize correctly."""
data = sample_trade.to_dict()
assert data["symbol"] == "AAPL"
assert data["direction"] == "long"
assert data["entry_price"] == 150.0
assert "reasoning" in data
assert "market_context" in data
def test_from_dict(self, sample_trade):
"""From dict should deserialize correctly."""
sample_trade.close(165.0)
data = sample_trade.to_dict()
restored = TradeRecord.from_dict(data)
assert restored.symbol == sample_trade.symbol
assert restored.direction == sample_trade.direction
assert restored.entry_price == sample_trade.entry_price
assert restored.exit_price == sample_trade.exit_price
assert restored.outcome == sample_trade.outcome
# =============================================================================
# TradeHistoryMemory Basic Tests
# =============================================================================
class TestTradeHistoryMemoryBasic:
"""Basic tests for TradeHistoryMemory."""
def test_create_empty(self, memory):
"""Empty memory should have zero trades."""
assert memory.count() == 0
def test_record_trade(self, memory, sample_trade):
"""Recording a trade should increase count."""
trade_id = memory.record_trade(sample_trade)
assert trade_id == sample_trade.id
assert memory.count() == 1
def test_get_trade(self, memory, sample_trade):
"""Get should return recorded trade."""
memory.record_trade(sample_trade)
retrieved = memory.get_trade(sample_trade.id)
assert retrieved is not None
assert retrieved.symbol == sample_trade.symbol
def test_get_nonexistent(self, memory):
"""Get nonexistent trade should return None."""
result = memory.get_trade("nonexistent")
assert result is None
def test_close_trade(self, memory, sample_trade):
"""Closing a trade should update it."""
memory.record_trade(sample_trade)
closed = memory.close_trade(sample_trade.id, exit_price=165.0)
assert closed is not None
assert closed.exit_price == 165.0
assert closed.outcome is not None
def test_close_with_lessons(self, memory, sample_trade):
"""Closing with lessons should store them."""
memory.record_trade(sample_trade)
memory.close_trade(
sample_trade.id,
exit_price=165.0,
lessons_learned="Wait for confirmation before entry",
)
trade = memory.get_trade(sample_trade.id)
assert trade.lessons_learned == "Wait for confirmation before entry"
def test_clear(self, memory, multiple_trades):
"""Clear should remove all trades."""
for trade in multiple_trades:
memory.record_trade(trade)
count = memory.clear()
assert count == len(multiple_trades)
assert memory.count() == 0
# =============================================================================
# TradeHistoryMemory Query Tests
# =============================================================================
class TestTradeHistoryMemoryQueries:
"""Query tests for TradeHistoryMemory."""
def test_get_open_trades(self, memory, multiple_trades):
"""Get open trades should filter correctly."""
for trade in multiple_trades:
memory.record_trade(trade)
open_trades = memory.get_open_trades()
assert len(open_trades) == 1 # Only NVDA is open
assert open_trades[0].symbol == "NVDA"
def test_get_closed_trades(self, memory, multiple_trades):
"""Get closed trades should filter correctly."""
for trade in multiple_trades:
memory.record_trade(trade)
closed_trades = memory.get_closed_trades()
assert len(closed_trades) == 3 # AAPL, GOOGL, MSFT
def test_get_trades_by_symbol(self, memory, multiple_trades):
"""Get trades by symbol should filter correctly."""
for trade in multiple_trades:
memory.record_trade(trade)
aapl_trades = memory.get_trades_by_symbol("AAPL")
assert len(aapl_trades) == 1
assert aapl_trades[0].symbol == "AAPL"
def test_find_similar_trades(self, memory, multiple_trades):
"""Find similar trades should return relevant trades."""
for trade in multiple_trades:
memory.record_trade(trade)
similar = memory.find_similar_trades(
query="tech stock momentum earnings",
top_k=3,
)
assert len(similar) >= 1
def test_find_profitable_patterns(self, memory, multiple_trades):
"""Find profitable patterns should return winners."""
for trade in multiple_trades:
memory.record_trade(trade)
profitable = memory.find_profitable_patterns(
query="tech stock buy",
min_return=0.05,
top_k=5,
)
# AAPL had 10% return
assert len(profitable) >= 1
for trade in profitable:
assert trade.returns >= 0.05
def test_find_losing_patterns(self, memory, multiple_trades):
"""Find losing patterns should return losers."""
for trade in multiple_trades:
memory.record_trade(trade)
losers = memory.find_losing_patterns(
query="tech stock momentum",
max_return=-0.05,
top_k=5,
)
# GOOGL had -7% return
assert len(losers) >= 1
for trade in losers:
assert trade.returns <= -0.05
# =============================================================================
# TradeHistoryMemory Statistics Tests
# =============================================================================
class TestTradeHistoryMemoryStats:
"""Statistics tests for TradeHistoryMemory."""
def test_empty_statistics(self, memory):
"""Empty memory should have zero stats."""
stats = memory.get_statistics()
assert stats["total_trades"] == 0
assert stats["win_rate"] == 0.0
def test_statistics_with_trades(self, memory, multiple_trades):
"""Statistics should reflect trade data."""
for trade in multiple_trades:
memory.record_trade(trade)
stats = memory.get_statistics()
assert stats["total_trades"] == 4
assert stats["open_trades"] == 1
assert stats["closed_trades"] == 3
assert stats["win_rate"] > 0 # At least AAPL was profitable
def test_symbol_statistics(self, memory, multiple_trades):
"""Symbol statistics should be calculated correctly."""
for trade in multiple_trades:
memory.record_trade(trade)
stats = memory.get_symbol_statistics("AAPL")
assert stats["symbol"] == "AAPL"
assert stats["total_trades"] == 1
assert stats["closed_trades"] == 1
assert stats["win_rate"] == 1.0 # AAPL was profitable
# =============================================================================
# TradeHistoryMemory Serialization Tests
# =============================================================================
class TestTradeHistoryMemorySerialization:
"""Serialization tests for TradeHistoryMemory."""
def test_to_dict(self, memory, multiple_trades):
"""To dict should serialize correctly."""
for trade in multiple_trades:
memory.record_trade(trade)
data = memory.to_dict()
assert "trades" in data
assert "memory" in data
assert len(data["trades"]) == len(multiple_trades)
def test_from_dict(self, memory, multiple_trades):
"""From dict should deserialize correctly."""
for trade in multiple_trades:
memory.record_trade(trade)
data = memory.to_dict()
restored = TradeHistoryMemory.from_dict(data)
assert restored.count() == len(multiple_trades)
def test_roundtrip(self, memory, sample_trade):
"""Roundtrip serialization should preserve data."""
memory.record_trade(sample_trade)
memory.close_trade(sample_trade.id, exit_price=165.0)
data = memory.to_dict()
restored = TradeHistoryMemory.from_dict(data)
original = memory.get_trade(sample_trade.id)
restored_trade = restored.get_trade(sample_trade.id)
assert restored_trade is not None
assert restored_trade.symbol == original.symbol
assert restored_trade.exit_price == original.exit_price
assert restored_trade.outcome == original.outcome
# =============================================================================
# Integration Tests
# =============================================================================
class TestIntegration:
"""Integration tests for trade history workflow."""
def test_full_trade_lifecycle(self, memory):
"""Test complete trade lifecycle."""
# 1. Create reasoning
reasoning = AgentReasoning(
fundamentals="Strong quarterly earnings",
technical="Golden cross on daily chart",
research_conclusion="Buy for momentum continuation",
)
# 2. Create market context
context = MarketContext(
vix=15.0,
economic_regime="EXPANSION",
)
# 3. Open trade
trade = TradeRecord.create(
symbol="MSFT",
direction=TradeDirection.LONG,
entry_price=400.0,
quantity=50,
signal_strength=SignalStrength.BUY,
confidence=0.75,
reasoning=reasoning,
market_context=context,
tags=["tech", "earnings"],
)
memory.record_trade(trade)
# 4. Verify open
assert memory.count() == 1
assert len(memory.get_open_trades()) == 1
# 5. Close with profit
memory.close_trade(
trade.id,
exit_price=440.0,
lessons_learned="Good timing on earnings play",
)
# 6. Verify closed
assert len(memory.get_closed_trades()) == 1
closed = memory.get_trade(trade.id)
assert closed.outcome == TradeOutcome.PROFITABLE
assert closed.returns == pytest.approx(0.10, rel=0.01)
# 7. Find similar trades
similar = memory.find_similar_trades(
query="tech earnings momentum",
top_k=1,
)
assert len(similar) == 1
assert similar[0].symbol == "MSFT"
# 8. Check statistics
stats = memory.get_statistics()
assert stats["win_rate"] == 1.0
assert stats["avg_return"] > 0
def test_learning_from_history(self, memory):
"""Test learning patterns from trade history."""
# Record several trades
trades = [
("AAPL", "earnings beat", 150.0, 165.0), # +10%
("GOOGL", "earnings beat", 140.0, 147.0), # +5%
("MSFT", "earnings miss", 350.0, 315.0), # -10%
("NVDA", "earnings beat", 500.0, 550.0), # +10%
]
for symbol, pattern, entry, exit_price in trades:
trade = TradeRecord.create(
symbol=symbol,
direction=TradeDirection.LONG,
entry_price=entry,
reasoning=AgentReasoning(research_conclusion=f"Trade on {pattern}"),
tags=["earnings"],
)
memory.record_trade(trade)
memory.close_trade(trade.id, exit_price=exit_price)
# Query for earnings beat patterns
profitable = memory.find_profitable_patterns(
query="earnings beat momentum",
min_return=0.05,
)
# Should find the profitable earnings beat trades
assert len(profitable) >= 2
# Query for losses to avoid
losers = memory.find_losing_patterns(
query="earnings trade",
max_return=-0.05,
)
# Should find the earnings miss trade
assert len(losers) >= 1

View File

@ -6,6 +6,7 @@ This module provides a layered memory system with three scoring dimensions:
- Importance: Significance weighting for impactful events
Issue #18: Layered memory - recency, relevancy, importance scoring
Issue #19: Trade history memory - outcomes, agent reasoning
"""
from .layered_memory import (
@ -17,11 +18,30 @@ from .layered_memory import (
ImportanceLevel,
)
from .trade_history import (
TradeHistoryMemory,
TradeRecord,
TradeOutcome,
TradeDirection,
SignalStrength,
AgentReasoning,
MarketContext,
)
__all__ = [
# Layered Memory (Issue #18)
"LayeredMemory",
"MemoryEntry",
"MemoryConfig",
"ScoringWeights",
"DecayFunction",
"ImportanceLevel",
# Trade History (Issue #19)
"TradeHistoryMemory",
"TradeRecord",
"TradeOutcome",
"TradeDirection",
"SignalStrength",
"AgentReasoning",
"MarketContext",
]

View File

@ -0,0 +1,820 @@
"""Trade History Memory for learning from past trade outcomes.
This module provides specialized memory for tracking and learning from trades:
- Trade outcomes (profit/loss, returns)
- Agent reasoning and signals
- Entry/exit conditions
- Market context at time of trade
Issue #19: [MEM-18] Trade history memory - outcomes, agent reasoning
"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Any
import uuid
from .layered_memory import (
LayeredMemory,
MemoryEntry,
MemoryConfig,
ScoringWeights,
ImportanceLevel,
)
class TradeOutcome(Enum):
"""Trade outcome categories."""
PROFITABLE = "profitable" # Positive return
BREAK_EVEN = "break_even" # ~0% return
LOSS = "loss" # Negative return
STOPPED_OUT = "stopped_out" # Hit stop loss
TARGET_HIT = "target_hit" # Hit profit target
class TradeDirection(Enum):
"""Trade direction."""
LONG = "long"
SHORT = "short"
HOLD = "hold"
class SignalStrength(Enum):
"""Signal strength levels."""
STRONG_BUY = "strong_buy"
BUY = "buy"
NEUTRAL = "neutral"
SELL = "sell"
STRONG_SELL = "strong_sell"
@dataclass
class AgentReasoning:
"""Captures reasoning from each agent in the trading workflow.
Attributes:
fundamentals: Fundamentals analyst reasoning
technical: Technical/market analyst reasoning
news: News analyst reasoning
sentiment: Social media sentiment reasoning
momentum: Momentum analyst reasoning (if enabled)
macro: Macro analyst reasoning (if enabled)
correlation: Correlation analyst reasoning (if enabled)
bull_case: Bull researcher arguments
bear_case: Bear researcher arguments
research_conclusion: Research manager decision
risk_assessment: Risk manager assessment
final_signal: Final trading signal
"""
fundamentals: Optional[str] = None
technical: Optional[str] = None
news: Optional[str] = None
sentiment: Optional[str] = None
momentum: Optional[str] = None
macro: Optional[str] = None
correlation: Optional[str] = None
bull_case: Optional[str] = None
bear_case: Optional[str] = None
research_conclusion: Optional[str] = None
risk_assessment: Optional[str] = None
final_signal: Optional[str] = None
def to_dict(self) -> Dict[str, Optional[str]]:
"""Convert to dictionary."""
return {
"fundamentals": self.fundamentals,
"technical": self.technical,
"news": self.news,
"sentiment": self.sentiment,
"momentum": self.momentum,
"macro": self.macro,
"correlation": self.correlation,
"bull_case": self.bull_case,
"bear_case": self.bear_case,
"research_conclusion": self.research_conclusion,
"risk_assessment": self.risk_assessment,
"final_signal": self.final_signal,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AgentReasoning":
"""Create from dictionary."""
return cls(
fundamentals=data.get("fundamentals"),
technical=data.get("technical"),
news=data.get("news"),
sentiment=data.get("sentiment"),
momentum=data.get("momentum"),
macro=data.get("macro"),
correlation=data.get("correlation"),
bull_case=data.get("bull_case"),
bear_case=data.get("bear_case"),
research_conclusion=data.get("research_conclusion"),
risk_assessment=data.get("risk_assessment"),
final_signal=data.get("final_signal"),
)
def summary(self) -> str:
"""Generate a text summary of the reasoning."""
parts = []
if self.fundamentals:
parts.append(f"Fundamentals: {self.fundamentals[:100]}...")
if self.technical:
parts.append(f"Technical: {self.technical[:100]}...")
if self.bull_case:
parts.append(f"Bull: {self.bull_case[:100]}...")
if self.bear_case:
parts.append(f"Bear: {self.bear_case[:100]}...")
if self.research_conclusion:
parts.append(f"Conclusion: {self.research_conclusion[:100]}...")
return " | ".join(parts) if parts else "No reasoning recorded"
@dataclass
class MarketContext:
"""Market conditions at time of trade.
Attributes:
vix: VIX volatility index level
spy_return_1d: SPY 1-day return
sector_performance: Dict of sector returns
economic_regime: Detected economic regime
yield_curve_state: Yield curve status
macro_indicators: Key macro indicators
"""
vix: Optional[float] = None
spy_return_1d: Optional[float] = None
sector_performance: Dict[str, float] = field(default_factory=dict)
economic_regime: Optional[str] = None
yield_curve_state: Optional[str] = None
macro_indicators: Dict[str, float] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"vix": self.vix,
"spy_return_1d": self.spy_return_1d,
"sector_performance": self.sector_performance,
"economic_regime": self.economic_regime,
"yield_curve_state": self.yield_curve_state,
"macro_indicators": self.macro_indicators,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MarketContext":
"""Create from dictionary."""
return cls(
vix=data.get("vix"),
spy_return_1d=data.get("spy_return_1d"),
sector_performance=data.get("sector_performance", {}),
economic_regime=data.get("economic_regime"),
yield_curve_state=data.get("yield_curve_state"),
macro_indicators=data.get("macro_indicators", {}),
)
def summary(self) -> str:
"""Generate a text summary of market context."""
parts = []
if self.vix is not None:
parts.append(f"VIX: {self.vix:.1f}")
if self.economic_regime:
parts.append(f"Regime: {self.economic_regime}")
if self.yield_curve_state:
parts.append(f"Yield Curve: {self.yield_curve_state}")
return " | ".join(parts) if parts else "No market context"
@dataclass
class TradeRecord:
"""Complete record of a trade including reasoning and outcome.
Attributes:
id: Unique trade ID
symbol: Trading symbol
direction: Trade direction (long/short/hold)
entry_price: Entry price
exit_price: Exit price (if closed)
entry_time: Entry timestamp
exit_time: Exit timestamp (if closed)
quantity: Number of shares/contracts
returns: Percentage return
pnl: Dollar profit/loss
outcome: Trade outcome category
signal_strength: Original signal strength
confidence: Confidence score (0-1)
reasoning: Agent reasoning captured
market_context: Market conditions at entry
lessons_learned: Post-trade lessons (added later)
tags: Trade tags for filtering
"""
id: str
symbol: str
direction: TradeDirection
entry_price: float
entry_time: datetime
quantity: float = 1.0
exit_price: Optional[float] = None
exit_time: Optional[datetime] = None
returns: Optional[float] = None
pnl: Optional[float] = None
outcome: Optional[TradeOutcome] = None
signal_strength: SignalStrength = SignalStrength.NEUTRAL
confidence: float = 0.5
reasoning: AgentReasoning = field(default_factory=AgentReasoning)
market_context: MarketContext = field(default_factory=MarketContext)
lessons_learned: Optional[str] = None
tags: List[str] = field(default_factory=list)
@classmethod
def create(
cls,
symbol: str,
direction: TradeDirection,
entry_price: float,
quantity: float = 1.0,
signal_strength: SignalStrength = SignalStrength.NEUTRAL,
confidence: float = 0.5,
reasoning: Optional[AgentReasoning] = None,
market_context: Optional[MarketContext] = None,
tags: Optional[List[str]] = None,
) -> "TradeRecord":
"""Create a new trade record."""
return cls(
id=str(uuid.uuid4()),
symbol=symbol,
direction=direction,
entry_price=entry_price,
entry_time=datetime.now(),
quantity=quantity,
signal_strength=signal_strength,
confidence=confidence,
reasoning=reasoning or AgentReasoning(),
market_context=market_context or MarketContext(),
tags=tags or [],
)
def close(
self,
exit_price: float,
exit_time: Optional[datetime] = None,
) -> "TradeRecord":
"""Close the trade and calculate returns.
Args:
exit_price: Exit price
exit_time: Exit timestamp (default: now)
Returns:
Self with updated exit info
"""
self.exit_price = exit_price
self.exit_time = exit_time or datetime.now()
# Calculate returns
if self.direction == TradeDirection.LONG:
self.returns = (exit_price - self.entry_price) / self.entry_price
elif self.direction == TradeDirection.SHORT:
self.returns = (self.entry_price - exit_price) / self.entry_price
else:
self.returns = 0.0
# Calculate PnL
self.pnl = self.returns * self.entry_price * self.quantity
# Determine outcome
if self.returns > 0.005: # > 0.5%
self.outcome = TradeOutcome.PROFITABLE
elif self.returns < -0.005: # < -0.5%
self.outcome = TradeOutcome.LOSS
else:
self.outcome = TradeOutcome.BREAK_EVEN
return self
def is_open(self) -> bool:
"""Check if trade is still open."""
return self.exit_time is None
def holding_period_days(self) -> Optional[float]:
"""Calculate holding period in days."""
if self.exit_time is None:
return None
delta = self.exit_time - self.entry_time
return delta.total_seconds() / 86400
def to_memory_content(self) -> str:
"""Generate memory content for this trade."""
parts = [
f"Trade {self.direction.value} {self.symbol} at ${self.entry_price:.2f}",
]
if self.outcome:
parts.append(f"Outcome: {self.outcome.value}")
if self.returns is not None:
parts.append(f"Return: {self.returns * 100:.2f}%")
if self.reasoning.research_conclusion:
parts.append(f"Reasoning: {self.reasoning.research_conclusion}")
if self.market_context.economic_regime:
parts.append(f"Regime: {self.market_context.economic_regime}")
return " | ".join(parts)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return {
"id": self.id,
"symbol": self.symbol,
"direction": self.direction.value,
"entry_price": self.entry_price,
"exit_price": self.exit_price,
"entry_time": self.entry_time.isoformat(),
"exit_time": self.exit_time.isoformat() if self.exit_time else None,
"quantity": self.quantity,
"returns": self.returns,
"pnl": self.pnl,
"outcome": self.outcome.value if self.outcome else None,
"signal_strength": self.signal_strength.value,
"confidence": self.confidence,
"reasoning": self.reasoning.to_dict(),
"market_context": self.market_context.to_dict(),
"lessons_learned": self.lessons_learned,
"tags": self.tags,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "TradeRecord":
"""Create from dictionary."""
return cls(
id=data["id"],
symbol=data["symbol"],
direction=TradeDirection(data["direction"]),
entry_price=data["entry_price"],
exit_price=data.get("exit_price"),
entry_time=datetime.fromisoformat(data["entry_time"]),
exit_time=datetime.fromisoformat(data["exit_time"]) if data.get("exit_time") else None,
quantity=data.get("quantity", 1.0),
returns=data.get("returns"),
pnl=data.get("pnl"),
outcome=TradeOutcome(data["outcome"]) if data.get("outcome") else None,
signal_strength=SignalStrength(data.get("signal_strength", "neutral")),
confidence=data.get("confidence", 0.5),
reasoning=AgentReasoning.from_dict(data.get("reasoning", {})),
market_context=MarketContext.from_dict(data.get("market_context", {})),
lessons_learned=data.get("lessons_learned"),
tags=data.get("tags", []),
)
class TradeHistoryMemory:
"""Memory system specialized for trade history and learning.
This class combines trade record storage with the layered memory system
for intelligent retrieval of past trades based on similarity.
Example:
>>> memory = TradeHistoryMemory()
>>>
>>> # Record a trade
>>> trade = TradeRecord.create(
... symbol="AAPL",
... direction=TradeDirection.LONG,
... entry_price=150.0,
... reasoning=AgentReasoning(
... fundamentals="Strong earnings",
... research_conclusion="Buy on earnings momentum",
... ),
... )
>>> memory.record_trade(trade)
>>>
>>> # Close the trade
>>> memory.close_trade(trade.id, exit_price=165.0)
>>>
>>> # Find similar past trades
>>> similar = memory.find_similar_trades(
... query="Apple earnings momentum",
... top_k=5,
... )
"""
def __init__(
self,
config: Optional[MemoryConfig] = None,
embedding_function=None,
):
"""Initialize trade history memory.
Args:
config: Memory configuration
embedding_function: Optional embedding function for similarity
"""
# Default config with weights tuned for trade history
if config is None:
config = MemoryConfig(
weights=ScoringWeights(
recency=0.25, # Recent trades somewhat important
relevancy=0.45, # Similarity most important
importance=0.30, # Outcome importance matters
),
)
self._layered_memory = LayeredMemory(
config=config,
embedding_function=embedding_function,
)
self._trades: Dict[str, TradeRecord] = {}
def record_trade(self, trade: TradeRecord) -> str:
"""Record a new trade.
Args:
trade: The trade record to store
Returns:
Trade ID
"""
self._trades[trade.id] = trade
# Create memory entry for the trade
importance = self._calculate_trade_importance(trade)
entry = MemoryEntry.create(
content=trade.to_memory_content(),
metadata={
"trade_id": trade.id,
"symbol": trade.symbol,
"direction": trade.direction.value,
"outcome": trade.outcome.value if trade.outcome else None,
"returns": trade.returns,
"reasoning_summary": trade.reasoning.summary(),
},
importance=importance,
tags=trade.tags + [trade.symbol, trade.direction.value],
timestamp=trade.entry_time,
)
entry.id = trade.id # Use trade ID as memory ID
self._layered_memory.add(entry)
return trade.id
def close_trade(
self,
trade_id: str,
exit_price: float,
lessons_learned: Optional[str] = None,
) -> Optional[TradeRecord]:
"""Close an open trade and update memory.
Args:
trade_id: ID of the trade to close
exit_price: Exit price
lessons_learned: Optional lessons from the trade
Returns:
Updated trade record or None if not found
"""
trade = self._trades.get(trade_id)
if trade is None:
return None
# Close the trade
trade.close(exit_price)
if lessons_learned:
trade.lessons_learned = lessons_learned
# Update memory with outcome
importance = self._calculate_trade_importance(trade)
self._layered_memory.update_importance(trade_id, importance)
# Update memory entry content
entry = self._layered_memory.get(trade_id)
if entry:
entry.metadata["outcome"] = trade.outcome.value if trade.outcome else None
entry.metadata["returns"] = trade.returns
if lessons_learned:
entry.metadata["lessons_learned"] = lessons_learned
return trade
def get_trade(self, trade_id: str) -> Optional[TradeRecord]:
"""Get a trade by ID.
Args:
trade_id: Trade ID
Returns:
Trade record or None
"""
return self._trades.get(trade_id)
def get_open_trades(self) -> List[TradeRecord]:
"""Get all open trades.
Returns:
List of open trade records
"""
return [t for t in self._trades.values() if t.is_open()]
def get_closed_trades(self) -> List[TradeRecord]:
"""Get all closed trades.
Returns:
List of closed trade records
"""
return [t for t in self._trades.values() if not t.is_open()]
def get_trades_by_symbol(self, symbol: str) -> List[TradeRecord]:
"""Get all trades for a symbol.
Args:
symbol: Trading symbol
Returns:
List of trade records for the symbol
"""
return [t for t in self._trades.values() if t.symbol == symbol]
def find_similar_trades(
self,
query: str,
top_k: int = 5,
symbol: Optional[str] = None,
outcome: Optional[TradeOutcome] = None,
) -> List[TradeRecord]:
"""Find similar past trades.
Args:
query: Query describing the current situation
top_k: Maximum number of results
symbol: Optional filter by symbol
outcome: Optional filter by outcome
Returns:
List of similar trade records
"""
# Build tags filter
tags = []
if symbol:
tags.append(symbol)
# Retrieve from layered memory
results = self._layered_memory.retrieve(
query=query,
top_k=top_k * 2, # Get more to filter
tags=tags if tags else None,
)
# Convert to trade records and filter
trades = []
for scored in results:
trade_id = scored.entry.metadata.get("trade_id")
if trade_id and trade_id in self._trades:
trade = self._trades[trade_id]
# Apply outcome filter
if outcome and trade.outcome != outcome:
continue
trades.append(trade)
if len(trades) >= top_k:
break
return trades
def find_profitable_patterns(
self,
query: str,
min_return: float = 0.0,
top_k: int = 5,
) -> List[TradeRecord]:
"""Find profitable trades similar to the query.
Args:
query: Query describing the current situation
min_return: Minimum return filter
top_k: Maximum number of results
Returns:
List of profitable trade records
"""
results = self._layered_memory.retrieve(query=query, top_k=top_k * 3)
trades = []
for scored in results:
trade_id = scored.entry.metadata.get("trade_id")
if trade_id and trade_id in self._trades:
trade = self._trades[trade_id]
if trade.returns is not None and trade.returns >= min_return:
trades.append(trade)
if len(trades) >= top_k:
break
return trades
def find_losing_patterns(
self,
query: str,
max_return: float = 0.0,
top_k: int = 5,
) -> List[TradeRecord]:
"""Find losing trades similar to the query (for learning what to avoid).
Args:
query: Query describing the current situation
max_return: Maximum return filter
top_k: Maximum number of results
Returns:
List of losing trade records
"""
results = self._layered_memory.retrieve(query=query, top_k=top_k * 3)
trades = []
for scored in results:
trade_id = scored.entry.metadata.get("trade_id")
if trade_id and trade_id in self._trades:
trade = self._trades[trade_id]
if trade.returns is not None and trade.returns <= max_return:
trades.append(trade)
if len(trades) >= top_k:
break
return trades
def get_statistics(self) -> Dict[str, Any]:
"""Get trade history statistics.
Returns:
Dictionary with statistics
"""
closed_trades = self.get_closed_trades()
if not closed_trades:
return {
"total_trades": len(self._trades),
"open_trades": len(self.get_open_trades()),
"closed_trades": 0,
"win_rate": 0.0,
"avg_return": 0.0,
"total_pnl": 0.0,
"best_trade": None,
"worst_trade": None,
}
returns = [t.returns for t in closed_trades if t.returns is not None]
pnls = [t.pnl for t in closed_trades if t.pnl is not None]
winners = [t for t in closed_trades if t.outcome == TradeOutcome.PROFITABLE]
# Find best and worst
best_trade = max(closed_trades, key=lambda t: t.returns or 0, default=None)
worst_trade = min(closed_trades, key=lambda t: t.returns or 0, default=None)
# Outcome distribution
outcome_dist = {}
for trade in closed_trades:
if trade.outcome:
outcome_dist[trade.outcome.value] = outcome_dist.get(trade.outcome.value, 0) + 1
return {
"total_trades": len(self._trades),
"open_trades": len(self.get_open_trades()),
"closed_trades": len(closed_trades),
"win_rate": len(winners) / len(closed_trades) if closed_trades else 0.0,
"avg_return": sum(returns) / len(returns) if returns else 0.0,
"total_pnl": sum(pnls) if pnls else 0.0,
"best_trade": best_trade.to_dict() if best_trade else None,
"worst_trade": worst_trade.to_dict() if worst_trade else None,
"outcome_distribution": outcome_dist,
}
def get_symbol_statistics(self, symbol: str) -> Dict[str, Any]:
"""Get statistics for a specific symbol.
Args:
symbol: Trading symbol
Returns:
Dictionary with symbol-specific statistics
"""
symbol_trades = self.get_trades_by_symbol(symbol)
closed_trades = [t for t in symbol_trades if not t.is_open()]
if not closed_trades:
return {
"symbol": symbol,
"total_trades": len(symbol_trades),
"open_trades": len([t for t in symbol_trades if t.is_open()]),
"closed_trades": 0,
"win_rate": 0.0,
"avg_return": 0.0,
}
returns = [t.returns for t in closed_trades if t.returns is not None]
winners = [t for t in closed_trades if t.outcome == TradeOutcome.PROFITABLE]
return {
"symbol": symbol,
"total_trades": len(symbol_trades),
"open_trades": len([t for t in symbol_trades if t.is_open()]),
"closed_trades": len(closed_trades),
"win_rate": len(winners) / len(closed_trades) if closed_trades else 0.0,
"avg_return": sum(returns) / len(returns) if returns else 0.0,
}
def _calculate_trade_importance(self, trade: TradeRecord) -> float:
"""Calculate importance score for a trade.
Args:
trade: Trade record
Returns:
Importance score [0, 1]
"""
# Base on returns if closed
if trade.returns is not None:
abs_return = abs(trade.returns)
if abs_return >= 0.10: # 10%+
return ImportanceLevel.CRITICAL.value
elif abs_return >= 0.05: # 5%+
return ImportanceLevel.HIGH.value
elif abs_return >= 0.01: # 1%+
return ImportanceLevel.MEDIUM.value
else:
return ImportanceLevel.LOW.value
# For open trades, base on confidence
if trade.confidence >= 0.8:
return ImportanceLevel.HIGH.value
elif trade.confidence >= 0.5:
return ImportanceLevel.MEDIUM.value
else:
return ImportanceLevel.LOW.value
def count(self) -> int:
"""Return total number of trades."""
return len(self._trades)
def clear(self) -> int:
"""Clear all trades.
Returns:
Number of trades cleared
"""
count = len(self._trades)
self._trades.clear()
self._layered_memory.clear()
return count
def to_dict(self) -> Dict[str, Any]:
"""Serialize to dictionary.
Returns:
Dictionary representation
"""
return {
"trades": [t.to_dict() for t in self._trades.values()],
"memory": self._layered_memory.to_dict(),
}
@classmethod
def from_dict(
cls,
data: Dict[str, Any],
embedding_function=None,
) -> "TradeHistoryMemory":
"""Create from dictionary.
Args:
data: Dictionary representation
embedding_function: Optional embedding function
Returns:
TradeHistoryMemory instance
"""
instance = cls(embedding_function=embedding_function)
# Restore trades
for trade_data in data.get("trades", []):
trade = TradeRecord.from_dict(trade_data)
instance._trades[trade.id] = trade
# Restore layered memory
if "memory" in data:
instance._layered_memory = LayeredMemory.from_dict(
data["memory"],
embedding_function=embedding_function,
)
return instance