# TradingAgents/tests/models/test_market_data.py
# Unit tests for the tradingagents.models.market_data data models.

from datetime import datetime, date
from decimal import Decimal
import pytest
from tradingagents.models.market_data import (
OHLCVBar,
OHLCV,
TechnicalIndicators,
MarketSnapshot,
HistoricalDataRequest,
HistoricalDataResponse,
)
class TestOHLCVBar:
    """Construction and field validation of a single OHLCVBar."""

    @staticmethod
    def _build(**overrides):
        """Return an OHLCVBar built from a baseline set of fields.

        The baseline is a 2024-01-15 bar (open 100, high 105, low 99,
        close 103, volume 1,000,000); any keyword in ``overrides`` replaces
        the corresponding baseline field or adds an extra one.
        """
        fields = {
            "timestamp": datetime(2024, 1, 15),
            "open": Decimal("100"),
            "high": Decimal("105"),
            "low": Decimal("99"),
            "close": Decimal("103"),
            "volume": 1000000,
        }
        fields.update(overrides)
        return OHLCVBar(**fields)

    def test_valid_bar(self):
        bar = self._build(
            timestamp=datetime(2024, 1, 15, 9, 30),
            open=Decimal("100.00"),
            high=Decimal("105.00"),
            low=Decimal("99.00"),
            close=Decimal("103.50"),
        )
        assert (bar.open, bar.high) == (Decimal("100.00"), Decimal("105.00"))
        assert bar.volume == 1000000

    def test_bar_with_adjusted_close(self):
        bar = self._build(adjusted_close=Decimal("102.50"))
        assert bar.adjusted_close == Decimal("102.50")

    def test_invalid_negative_price(self):
        # A negative open price must be rejected at construction time.
        with pytest.raises(ValueError):
            self._build(open=Decimal("-100"))

    def test_invalid_negative_volume(self):
        # A negative volume must be rejected at construction time.
        with pytest.raises(ValueError):
            self._build(volume=-1000)
class TestOHLCV:
    """Tests for the OHLCV bar-series container."""

    @pytest.fixture
    def sample_bars(self):
        """Three consecutive daily bars, 2024-01-15 through 2024-01-17, in order."""
        return [
            OHLCVBar(
                timestamp=datetime(2024, 1, 15),
                open=Decimal("100"),
                high=Decimal("105"),
                low=Decimal("99"),
                close=Decimal("103"),
                volume=1000000,
            ),
            OHLCVBar(
                timestamp=datetime(2024, 1, 16),
                open=Decimal("103"),
                high=Decimal("108"),
                low=Decimal("102"),
                close=Decimal("107"),
                volume=1200000,
            ),
            OHLCVBar(
                timestamp=datetime(2024, 1, 17),
                open=Decimal("107"),
                high=Decimal("110"),
                low=Decimal("105"),
                close=Decimal("109"),
                volume=900000,
            ),
        ]

    def test_ohlcv_creation(self, sample_bars):
        ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
        assert ohlcv.ticker == "AAPL"
        assert len(ohlcv.bars) == 3
        # Values expected when interval/currency are not passed explicitly.
        assert ohlcv.interval == "1d"
        assert ohlcv.currency == "USD"

    def test_start_end_dates(self, sample_bars):
        ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
        assert ohlcv.start_date == datetime(2024, 1, 15)
        assert ohlcv.end_date == datetime(2024, 1, 17)

    def test_empty_ohlcv(self):
        ohlcv = OHLCV(ticker="AAPL", bars=[])
        assert ohlcv.start_date is None
        assert ohlcv.end_date is None

    def test_get_bar(self, sample_bars):
        ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
        bar = ohlcv.get_bar(datetime(2024, 1, 16))
        assert bar is not None
        # Pin the identity of the returned bar, not just one of its prices.
        assert bar.timestamp == datetime(2024, 1, 16)
        assert bar.close == Decimal("107")

    def test_get_bar_not_found(self, sample_bars):
        ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
        bar = ohlcv.get_bar(datetime(2024, 1, 20))
        assert bar is None

    def test_slice(self, sample_bars):
        ohlcv = OHLCV(ticker="AAPL", bars=sample_bars)
        sliced = ohlcv.slice(datetime(2024, 1, 15), datetime(2024, 1, 16))
        assert len(sliced.bars) == 2
        # Checking the exact timestamps (not only the count) catches boundary
        # off-by-one errors: a slice that dropped the start bar but picked up
        # the 2024-01-17 bar would still have length 2.
        assert [b.timestamp for b in sliced.bars] == [
            datetime(2024, 1, 15),
            datetime(2024, 1, 16),
        ]
        assert sliced.ticker == "AAPL"

    def test_invalid_ticker(self):
        # An empty ticker symbol must be rejected.
        with pytest.raises(ValueError):
            OHLCV(ticker="", bars=[])
class TestTechnicalIndicators:
    """Value storage and range validation for TechnicalIndicators."""

    @staticmethod
    def _make(**values):
        """Build indicators for AAPL at 2024-01-15 with the given indicator fields."""
        return TechnicalIndicators(
            timestamp=datetime(2024, 1, 15),
            ticker="AAPL",
            **values,
        )

    def test_valid_indicators(self):
        indicators = self._make(
            sma_50=Decimal("150.00"),
            rsi_14=Decimal("65.5"),
            macd=Decimal("2.5"),
        )
        assert indicators.sma_50 == Decimal("150.00")
        assert indicators.rsi_14 == Decimal("65.5")

    def test_rsi_bounds(self):
        # An RSI above 100 is out of range and must be rejected.
        with pytest.raises(ValueError):
            self._make(rsi_14=Decimal("150"))

    def test_mfi_bounds(self):
        # A negative MFI is out of range and must be rejected.
        with pytest.raises(ValueError):
            self._make(mfi_14=Decimal("-10"))
class TestMarketSnapshot:
    """Derived price-change fields on MarketSnapshot."""

    @staticmethod
    def _daily_bar():
        """Return the 2024-01-15 bar used by every snapshot test (close 103)."""
        return OHLCVBar(
            timestamp=datetime(2024, 1, 15),
            open=Decimal("100"),
            high=Decimal("105"),
            low=Decimal("99"),
            close=Decimal("103"),
            volume=1000000,
        )

    def test_snapshot_change_calculation(self):
        snapshot = MarketSnapshot(
            ticker="AAPL",
            timestamp=datetime(2024, 1, 15),
            bar=self._daily_bar(),
            prev_close=Decimal("100"),
        )
        # Close 103 against prev_close 100: +3 absolute and +3 percent.
        assert snapshot.change == Decimal("3")
        assert snapshot.change_percent == Decimal("3")

    def test_snapshot_no_prev_close(self):
        # Without prev_close, neither change field can be derived.
        snapshot = MarketSnapshot(
            ticker="AAPL",
            timestamp=datetime(2024, 1, 15),
            bar=self._daily_bar(),
        )
        assert snapshot.change is None
        assert snapshot.change_percent is None
class TestHistoricalDataRequest:
    """Field defaults and date-range validation for HistoricalDataRequest."""

    def test_valid_request(self):
        first_half_2024 = {
            "start_date": date(2024, 1, 1),
            "end_date": date(2024, 6, 30),
        }
        request = HistoricalDataRequest(ticker="AAPL", **first_half_2024)
        assert request.ticker == "AAPL"
        # include_indicators is True when not supplied by the caller.
        assert request.include_indicators is True

    def test_invalid_date_range(self):
        # A start_date later than end_date must be rejected.
        reversed_range = {
            "start_date": date(2024, 6, 30),
            "end_date": date(2024, 1, 1),
        }
        with pytest.raises(ValueError):
            HistoricalDataRequest(ticker="AAPL", **reversed_range)