# Source-viewer metadata: 470 lines, 16 KiB, Python.
import math
from datetime import datetime, timedelta
from unittest.mock import patch

import pytest

from tradingagents.agents.discovery import NewsArticle, EventCategory, Sector
from tradingagents.agents.discovery.entity_extractor import EntityMention


class TestFrequencyCalculation:
    """Frequency component of ``calculate_trending_scores``."""

    def test_frequency_calculation_unique_article_count(self):
        """Three Apple mentions spread over three articles collapse into a
        single AAPL entry whose ``mention_count`` is 3."""
        from tradingagents.agents.discovery.scorer import calculate_trending_scores

        now = datetime.now()

        # (title, source, url, age in hours, snippet) for each AAPL article.
        article_rows = [
            (
                "Apple Q4 Earnings",
                "Reuters",
                "https://reuters.com/article1",
                1,
                "Apple Inc reported strong earnings.",
            ),
            (
                "Apple iPhone Sales",
                "Bloomberg",
                "https://bloomberg.com/article2",
                2,
                "Apple saw record iPhone sales.",
            ),
            (
                "Apple AI Features",
                "WSJ",
                "https://wsj.com/article3",
                3,
                "Apple announced AI features.",
            ),
        ]
        articles = [
            NewsArticle(
                title=title,
                source=source,
                url=url,
                published_at=now - timedelta(hours=age),
                content_snippet=snippet,
                ticker_mentions=["AAPL"],
            )
            for title, source, url, age, snippet in article_rows
        ]

        # (company name, confidence, context, event); the i-th mention is
        # attributed to "article_<i>".
        mention_rows = [
            ("Apple Inc", 0.95, "Apple Inc reported strong earnings", EventCategory.EARNINGS),
            ("Apple", 0.90, "Apple saw record iPhone sales", EventCategory.EARNINGS),
            ("Apple Inc.", 0.92, "Apple announced AI features", EventCategory.PRODUCT_LAUNCH),
        ]
        mentions = [
            EntityMention(
                company_name=name,
                confidence=confidence,
                context_snippet=context,
                article_id=f"article_{index}",
                event_type=event,
            )
            for index, (name, confidence, context, event) in enumerate(mention_rows)
        ]

        # Every name variant resolves to AAPL, so all mentions must merge.
        scorer_module = "tradingagents.agents.discovery.scorer"
        with patch(f"{scorer_module}.resolve_ticker", return_value="AAPL"), patch(
            f"{scorer_module}.classify_sector", return_value="technology"
        ):
            result = calculate_trending_scores(mentions, articles)

        assert len(result) == 1
        top = result[0]
        assert top.ticker == "AAPL"
        assert top.mention_count == 3


class TestSentimentIntensityFactor:
    """Sentiment boosts a score by its magnitude, independent of its sign."""

    def test_sentiment_intensity_uses_absolute_value(self):
        """A bearish set of mentions must score like its bullish mirror image.

        The reported ``sentiment`` keeps its sign (confidence-weighted mean of
        the mention sentiments). The score, however, must depend only on the
        magnitude, so flipping the sign of every mention's sentiment has to
        leave the score unchanged.
        """
        from tradingagents.agents.discovery.scorer import calculate_trending_scores

        now = datetime.now()
        articles = [
            NewsArticle(
                title="Stock drops sharply",
                source="Reuters",
                url="https://reuters.com/article1",
                published_at=now - timedelta(hours=1),
                content_snippet="Company faced major issues.",
                ticker_mentions=["TSLA"],
            ),
            NewsArticle(
                title="More bad news",
                source="Bloomberg",
                url="https://bloomberg.com/article2",
                published_at=now - timedelta(hours=2),
                content_snippet="Further decline expected.",
                ticker_mentions=["TSLA"],
            ),
        ]

        def make_mentions(sign):
            # Same two Tesla mentions; ``sign`` flips every sentiment so the
            # bearish run can be compared against its bullish mirror image.
            return [
                EntityMention(
                    company_name="Tesla",
                    confidence=0.95,
                    context_snippet="Company faced major issues",
                    article_id="article_0",
                    event_type=EventCategory.OTHER,
                    sentiment=sign * 0.8,
                ),
                EntityMention(
                    company_name="Tesla Inc",
                    confidence=0.90,
                    context_snippet="Further decline expected",
                    article_id="article_1",
                    event_type=EventCategory.OTHER,
                    sentiment=sign * 0.6,
                ),
            ]

        scorer_module = "tradingagents.agents.discovery.scorer"
        with patch(f"{scorer_module}.resolve_ticker", return_value="TSLA"), patch(
            f"{scorer_module}.classify_sector", return_value="technology"
        ):
            bearish = calculate_trending_scores(make_mentions(-1), articles)
            bullish = calculate_trending_scores(make_mentions(1), articles)

        assert len(bearish) == 1
        assert len(bullish) == 1

        # The reported sentiment keeps its sign ...
        expected_sentiment = (-0.8 * 0.95 + -0.6 * 0.90) / (0.95 + 0.90)
        assert bearish[0].sentiment < 0
        assert abs(bearish[0].sentiment - expected_sentiment) < 0.01

        # ... but the score only sees its magnitude. The two runs are scored a
        # moment apart, so allow a tiny relative drift from the recency term.
        assert bearish[0].score > 0
        assert math.isclose(bearish[0].score, bullish[0].score, rel_tol=1e-4)


class TestRecencyWeightExponentialDecay:
    """Recency weight decays as ``exp(-decay_rate * age_in_hours)``."""

    def test_recency_weight_exponential_decay(self):
        """Two equal-confidence NVDA articles, 1h and 10h old, contribute the
        mean of their exponential recency weights to the score.

        The expected weights are asserted against the score rather than only
        computed, so a missing or non-exponential decay is actually caught.
        """
        from tradingagents.agents.discovery.scorer import calculate_trending_scores

        decay_rate = 0.1
        sentiment_val = 0.5
        recent_hours, older_hours = 1, 10

        now = datetime.now()
        articles = [
            NewsArticle(
                title="Recent news",
                source="Reuters",
                url="https://reuters.com/article1",
                published_at=now - timedelta(hours=recent_hours),
                content_snippet="Recent company news.",
                ticker_mentions=["NVDA"],
            ),
            NewsArticle(
                title="Older news",
                source="Bloomberg",
                url="https://bloomberg.com/article2",
                published_at=now - timedelta(hours=older_hours),
                content_snippet="Older company news.",
                ticker_mentions=["NVDA"],
            ),
        ]

        mentions = [
            EntityMention(
                company_name="Nvidia",
                confidence=0.90,
                context_snippet="Recent company news",
                article_id="article_0",
                event_type=EventCategory.OTHER,
                sentiment=sentiment_val,
            ),
            EntityMention(
                company_name="Nvidia",
                confidence=0.90,
                context_snippet="Older company news",
                article_id="article_1",
                event_type=EventCategory.OTHER,
                sentiment=sentiment_val,
            ),
        ]

        scorer_module = "tradingagents.agents.discovery.scorer"
        with patch(f"{scorer_module}.resolve_ticker", return_value="NVDA"), patch(
            f"{scorer_module}.classify_sector", return_value="technology"
        ):
            result = calculate_trending_scores(mentions, articles, decay_rate=decay_rate)

        assert len(result) == 1

        recent_weight = math.exp(-decay_rate * recent_hours)
        older_weight = math.exp(-decay_rate * older_hours)
        # Both mentions share the same confidence, so any confidence weighting
        # of the recency term reduces to a plain mean here.
        avg_recency = (recent_weight + older_weight) / 2

        # Same formula as the final-score test:
        # frequency * (1 + |sentiment|) * recency.
        frequency = 2
        expected_score = frequency * (1 + abs(sentiment_val)) * avg_recency

        assert result[0].score > 0
        assert abs(result[0].score - expected_score) < 0.01


class TestMinimumThresholdFiltering:
    """Stocks covered by fewer than ``min_mentions`` articles are dropped."""

    def test_minimum_threshold_filtering_requires_two_articles(self):
        """With ``min_mentions=2`` the single-article AMD entry is filtered out
        while the two-article MSFT entry survives."""
        from tradingagents.agents.discovery.scorer import calculate_trending_scores

        now = datetime.now()
        articles = [
            NewsArticle(
                title="Single mention stock",
                source="Reuters",
                url="https://reuters.com/article1",
                published_at=now - timedelta(hours=1),
                content_snippet="Some company news.",
                ticker_mentions=["AMD"],
            ),
            NewsArticle(
                title="Multiple mention stock 1",
                source="Bloomberg",
                url="https://bloomberg.com/article2",
                published_at=now - timedelta(hours=2),
                content_snippet="Popular company news.",
                ticker_mentions=["MSFT"],
            ),
            NewsArticle(
                title="Multiple mention stock 2",
                source="WSJ",
                url="https://wsj.com/article3",
                published_at=now - timedelta(hours=3),
                content_snippet="More popular company news.",
                ticker_mentions=["MSFT"],
            ),
        ]

        mentions = [
            EntityMention(
                company_name="AMD",
                confidence=0.90,
                context_snippet="Some company news",
                article_id="article_0",
                event_type=EventCategory.OTHER,
            ),
            EntityMention(
                company_name="Microsoft",
                confidence=0.95,
                context_snippet="Popular company news",
                article_id="article_1",
                event_type=EventCategory.OTHER,
            ),
            EntityMention(
                company_name="Microsoft Corp",
                confidence=0.92,
                context_snippet="More popular company news",
                article_id="article_2",
                event_type=EventCategory.OTHER,
            ),
        ]

        def resolve_side_effect(name):
            # Only the AMD mention resolves to AMD; every Microsoft variant
            # resolves to MSFT. (A substring check already covers name == "AMD".)
            return "AMD" if "AMD" in name else "MSFT"

        scorer_module = "tradingagents.agents.discovery.scorer"
        with patch(
            f"{scorer_module}.resolve_ticker", side_effect=resolve_side_effect
        ), patch(f"{scorer_module}.classify_sector", return_value="technology"):
            result = calculate_trending_scores(mentions, articles, min_mentions=2)

        tickers = [stock.ticker for stock in result]
        assert len(result) == 1
        assert tickers == ["MSFT"]
        assert "AMD" not in tickers
        assert all(stock.mention_count >= 2 for stock in result)


class TestFinalScoreFormulaCorrectness:
    """End-to-end check of score = frequency * (1 + |sentiment|) * recency."""

    def test_final_score_formula_correctness(self):
        """Two same-age GOOGL articles with identical sentiment and confidence
        yield ``2 * (1 + 0.6) * exp(-0.1 * 2)``."""
        from tradingagents.agents.discovery.scorer import calculate_trending_scores

        decay_rate = 0.1
        hours_old = 2.0
        sentiment_val = 0.6
        confidence = 0.9

        now = datetime.now()

        # (title, source, url, snippet, company name, mention context)
        fixtures = [
            (
                "Test article 1",
                "Reuters",
                "https://reuters.com/article1",
                "Google announced results.",
                "Google",
                "Google announced results",
            ),
            (
                "Test article 2",
                "Bloomberg",
                "https://bloomberg.com/article2",
                "Alphabet earnings beat.",
                "Alphabet",
                "Alphabet earnings beat",
            ),
        ]

        articles = []
        mentions = []
        for index, (title, source, url, snippet, company, context) in enumerate(fixtures):
            articles.append(
                NewsArticle(
                    title=title,
                    source=source,
                    url=url,
                    published_at=now - timedelta(hours=hours_old),
                    content_snippet=snippet,
                    ticker_mentions=["GOOGL"],
                )
            )
            mentions.append(
                EntityMention(
                    company_name=company,
                    confidence=confidence,
                    context_snippet=context,
                    article_id=f"article_{index}",
                    event_type=EventCategory.EARNINGS,
                    sentiment=sentiment_val,
                )
            )

        scorer_module = "tradingagents.agents.discovery.scorer"
        with patch(f"{scorer_module}.resolve_ticker", return_value="GOOGL"), patch(
            f"{scorer_module}.classify_sector", return_value="technology"
        ):
            result = calculate_trending_scores(
                mentions, articles, decay_rate=decay_rate
            )

        assert len(result) == 1
        stock = result[0]

        # Two articles -> frequency 2; every article shares the same age.
        recency_weight = math.exp(-decay_rate * hours_old)
        expected_score = 2 * (1 + abs(sentiment_val)) * recency_weight

        assert abs(stock.score - expected_score) < 0.01


class TestSortingByScoreDescending:
    """``calculate_trending_scores`` returns stocks ordered by score, highest first."""

    def test_results_sorted_by_score_descending(self):
        """AAPL (3 fresh articles, sentiment 0.8) must rank above TSLA
        (2 ten-hour-old articles, sentiment 0.2), and the scores must be
        non-increasing."""
        from tradingagents.agents.discovery.scorer import calculate_trending_scores

        now = datetime.now()

        # (title, source, url, age in hours, snippet, ticker)
        article_rows = [
            ("High score stock 1", "Reuters", "https://reuters.com/article1", 1, "Apple news.", "AAPL"),
            ("High score stock 2", "Bloomberg", "https://bloomberg.com/article2", 1, "More Apple news.", "AAPL"),
            ("High score stock 3", "WSJ", "https://wsj.com/article3", 1, "Even more Apple news.", "AAPL"),
            ("Low score stock 1", "CNBC", "https://cnbc.com/article4", 10, "Tesla news.", "TSLA"),
            ("Low score stock 2", "FT", "https://ft.com/article5", 10, "More Tesla news.", "TSLA"),
        ]
        articles = [
            NewsArticle(
                title=title,
                source=source,
                url=url,
                published_at=now - timedelta(hours=age),
                content_snippet=snippet,
                ticker_mentions=[ticker],
            )
            for title, source, url, age, snippet, ticker in article_rows
        ]

        # (company name, confidence, context, sentiment); the i-th mention is
        # attributed to "article_<i>".
        mention_rows = [
            ("Apple", 0.95, "Apple news", 0.8),
            ("Apple Inc", 0.93, "More Apple news", 0.8),
            ("Apple", 0.90, "Even more Apple news", 0.8),
            ("Tesla", 0.85, "Tesla news", 0.2),
            ("Tesla Inc", 0.85, "More Tesla news", 0.2),
        ]
        mentions = [
            EntityMention(
                company_name=name,
                confidence=confidence,
                context_snippet=context,
                article_id=f"article_{index}",
                event_type=EventCategory.OTHER,
                sentiment=sentiment,
            )
            for index, (name, confidence, context, sentiment) in enumerate(mention_rows)
        ]

        ticker_by_keyword = {"Apple": "AAPL", "Tesla": "TSLA"}

        def resolve_side_effect(name):
            # Resolve by keyword; any other name stays unresolved (None).
            return next(
                (ticker for keyword, ticker in ticker_by_keyword.items() if keyword in name),
                None,
            )

        scorer_module = "tradingagents.agents.discovery.scorer"
        with patch(
            f"{scorer_module}.resolve_ticker", side_effect=resolve_side_effect
        ), patch(f"{scorer_module}.classify_sector", return_value="technology"):
            result = calculate_trending_scores(mentions, articles, min_mentions=2)

        assert len(result) == 2

        scores = [stock.score for stock in result]
        assert all(higher >= lower for higher, lower in zip(scores, scores[1:]))

        # AAPL beats TSLA on frequency, sentiment magnitude and recency alike,
        # so it must come first.
        assert [stock.ticker for stock in result] == ["AAPL", "TSLA"]