TradingAgents/tests/discovery/test_entity_extractor.py

276 lines
8.8 KiB
Python

from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from tradingagents.agents.discovery import EventCategory, NewsArticle
class TestExtractEntitiesReturnsCompanyMentions:
def test_extract_entities_returns_list_of_company_mentions(self):
from tradingagents.agents.discovery.entity_extractor import (
EntityMention,
extract_entities,
)
articles = [
NewsArticle(
title="Apple announces new iPhone",
source="Reuters",
url="https://reuters.com/apple",
published_at=datetime.now(),
content_snippet="Apple Inc unveiled its latest iPhone model today with advanced AI features.",
ticker_mentions=[],
),
]
mock_response = MagicMock()
mock_response.entities = [
MagicMock(
company_name="Apple Inc",
confidence=0.95,
context_snippet="Apple Inc unveiled its latest iPhone",
event_type="product_launch",
sentiment=0.7,
)
]
with patch(
"tradingagents.agents.discovery.entity_extractor._get_llm"
) as mock_get_llm:
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value.invoke.return_value = (
mock_response
)
mock_get_llm.return_value = mock_llm
result = extract_entities(articles)
assert isinstance(result, list)
assert len(result) > 0
assert all(isinstance(m, EntityMention) for m in result)
assert result[0].company_name == "Apple Inc"
class TestConfidenceScoreRange:
def test_confidence_score_in_valid_range(self):
from tradingagents.agents.discovery.entity_extractor import (
extract_entities,
)
articles = [
NewsArticle(
title="Tesla reports earnings",
source="Bloomberg",
url="https://bloomberg.com/tsla",
published_at=datetime.now(),
content_snippet="Tesla Inc reported strong quarterly earnings beating analyst expectations.",
ticker_mentions=[],
),
]
mock_response = MagicMock()
mock_response.entities = [
MagicMock(
company_name="Tesla Inc",
confidence=0.88,
context_snippet="Tesla Inc reported strong quarterly earnings",
event_type="earnings",
sentiment=0.5,
)
]
with patch(
"tradingagents.agents.discovery.entity_extractor._get_llm"
) as mock_get_llm:
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value.invoke.return_value = (
mock_response
)
mock_get_llm.return_value = mock_llm
result = extract_entities(articles)
for mention in result:
assert 0.0 <= mention.confidence <= 1.0
class TestContextSnippetExtraction:
def test_context_snippet_extraction(self):
from tradingagents.agents.discovery.entity_extractor import (
extract_entities,
)
articles = [
NewsArticle(
title="Microsoft acquires gaming company",
source="WSJ",
url="https://wsj.com/msft",
published_at=datetime.now(),
content_snippet="Microsoft Corporation announced today it will acquire a major gaming studio for $10 billion.",
ticker_mentions=[],
),
]
mock_response = MagicMock()
mock_response.entities = [
MagicMock(
company_name="Microsoft Corporation",
confidence=0.92,
context_snippet="Microsoft Corporation announced today it will acquire",
event_type="merger_acquisition",
sentiment=0.6,
)
]
with patch(
"tradingagents.agents.discovery.entity_extractor._get_llm"
) as mock_get_llm:
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value.invoke.return_value = (
mock_response
)
mock_get_llm.return_value = mock_llm
result = extract_entities(articles)
assert len(result) > 0
for mention in result:
assert mention.context_snippet is not None
assert len(mention.context_snippet) > 0
assert len(mention.context_snippet) <= 150
class TestBatchProcessing:
def test_batch_processing_of_multiple_articles(self):
from tradingagents.agents.discovery.entity_extractor import (
BATCH_SIZE,
extract_entities,
)
articles = [
NewsArticle(
title=f"News article {i}",
source="Reuters",
url=f"https://reuters.com/article{i}",
published_at=datetime.now(),
content_snippet=f"Company {i} announced major developments today.",
ticker_mentions=[],
)
for i in range(15)
]
mock_response = MagicMock()
mock_response.entities = [
MagicMock(
company_name="Test Company",
confidence=0.85,
context_snippet="Company announced major developments",
event_type="other",
sentiment=0.0,
)
]
with patch(
"tradingagents.agents.discovery.entity_extractor._get_llm"
) as mock_get_llm:
mock_llm = MagicMock()
structured_llm = MagicMock()
structured_llm.invoke.return_value = mock_response
mock_llm.with_structured_output.return_value = structured_llm
mock_get_llm.return_value = mock_llm
result = extract_entities(articles)
expected_batches = (len(articles) + BATCH_SIZE - 1) // BATCH_SIZE
assert structured_llm.invoke.call_count == expected_batches
class TestNoCompanyMentions:
def test_handling_of_articles_with_no_company_mentions(self):
from tradingagents.agents.discovery.entity_extractor import (
extract_entities,
)
articles = [
NewsArticle(
title="Weather forecast for tomorrow",
source="Weather Channel",
url="https://weather.com/forecast",
published_at=datetime.now(),
content_snippet="Tomorrow will be sunny with temperatures reaching 75 degrees.",
ticker_mentions=[],
),
]
mock_response = MagicMock()
mock_response.entities = []
with patch(
"tradingagents.agents.discovery.entity_extractor._get_llm"
) as mock_get_llm:
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value.invoke.return_value = (
mock_response
)
mock_get_llm.return_value = mock_llm
result = extract_entities(articles)
assert isinstance(result, list)
assert len(result) == 0
class TestEventTypeClassification:
@pytest.mark.parametrize(
"event_type",
[
"earnings",
"merger_acquisition",
"regulatory",
"product_launch",
"executive_change",
"other",
],
)
def test_event_type_classification(self, event_type):
from tradingagents.agents.discovery.entity_extractor import (
extract_entities,
)
articles = [
NewsArticle(
title="Company news",
source="Reuters",
url="https://reuters.com/news",
published_at=datetime.now(),
content_snippet="A company made an announcement today.",
ticker_mentions=[],
),
]
mock_response = MagicMock()
mock_response.entities = [
MagicMock(
company_name="Test Company",
confidence=0.90,
context_snippet="A company made an announcement",
event_type=event_type,
sentiment=0.0,
)
]
with patch(
"tradingagents.agents.discovery.entity_extractor._get_llm"
) as mock_get_llm:
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value.invoke.return_value = (
mock_response
)
mock_get_llm.return_value = mock_llm
result = extract_entities(articles)
assert len(result) > 0
assert result[0].event_type == EventCategory(event_type)