Add trending stock discovery feature

Implement a multi-stage pipeline to discover trending stocks from news:
- Entity extraction from news articles using LLM
- Stock ticker resolution via Yahoo Finance
- Sector classification and event categorization
- Scoring algorithm based on mentions, sentiment, and recency (sketched below)
- CLI integration with interactive stock selection and analysis flow
- Persistence layer for saving discovery results
- Comprehensive test suite for all discovery components
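
A minimal sketch of the scoring step, matching the formula the scorer tests in this commit assert (unique-article frequency, sentiment intensity, exponential recency decay); the helper name is illustrative and the real implementation may differ:

```python
import math

def trending_score(mention_count: int, avg_sentiment: float,
                   hours_since_publication: float, decay_rate: float = 0.1) -> float:
    # frequency x sentiment intensity x recency decay, as asserted in the scorer tests
    sentiment_factor = 1 + abs(avg_sentiment)  # strong moves in either direction trend harder
    recency_weight = math.exp(-decay_rate * hours_since_publication)
    return mention_count * sentiment_factor * recency_weight
```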

Update README with uv-based installation instructions and remove emojis.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Joseph O'Brien 2025-12-02 20:19:34 -05:00
parent 13b826a31d
commit 3f6b1e9f39
33 changed files with 5609 additions and 815 deletions

.gitignore

@@ -9,3 +9,13 @@ eval_results/
eval_data/
*.egg-info/
.env
.claude/
.pytest_cache/
.specify/
specs/
agent-os/
*.local.md
build/
.mcp.json
*.zip
todos.md


@@ -27,7 +27,7 @@
# TradingAgents: Multi-Agents LLM Financial Trading Framework
> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
> **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
>
> So we decided to fully open-source the framework. Looking forward to building impactful projects with you!
@@ -43,7 +43,7 @@
<div align="center">
🚀 [TradingAgents](#tradingagents-framework) | [Installation & CLI](#installation-and-cli) | 🎬 [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | 📦 [Package Usage](#tradingagents-package) | 🤝 [Contributing](#contributing) | 📄 [Citation](#citation)
[TradingAgents](#tradingagents-framework) | [Installation & CLI](#installation-and-cli) | [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | [Package Usage](#tradingagents-package) | [Contributing](#contributing) | [Citation](#citation)
</div>
@@ -101,15 +101,10 @@ git clone https://github.com/TauricResearch/TradingAgents.git
cd TradingAgents
```
Create a virtual environment in any of your favorite environment managers:
Sync virtual environment:
```bash
conda create -n tradingagents python=3.13
conda activate tradingagents
```
Install dependencies:
```bash
pip install -r requirements.txt
uv sync
source .venv/bin/activate
```
### Required APIs
@@ -133,7 +128,7 @@ cp .env.example .env
You can also try out the CLI directly by running:
```bash
python -m cli.main
uv run cli/main.py
```
You will see a screen where you can select your desired tickers, date, LLMs, research depth, etc.
@@ -204,13 +199,9 @@ print(decision)
You can view the full list of configurations in `tradingagents/default_config.py`.
## Contributing
## Source
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
## Citation
Please reference our work if you find *TradingAgents* provides you with some help :)
Thanks to Yijia Xiao, Edward Sun, Di Luo, and Wei Wang. Core agent implementation is based on [TradingAgents: Multi-Agents LLM Financial Trading Framework](https://arxiv.org/abs/2412.20138)
```
@misc{xiao2025tradingagentsmultiagentsllmfinancial,

File diff suppressed because it is too large


@@ -8,7 +8,7 @@ load_dotenv()
# Create a custom config
config = DEFAULT_CONFIG.copy()
config["deep_think_llm"] = "gpt-4o-mini" # Use a different model
config["deep_think_llm"] = "gpt-5" # Use a different model
config["quick_think_llm"] = "gpt-4o-mini" # Use a different model
config["max_debate_rounds"] = 1 # Increase debate rounds

tests/discovery/test_api.py

@@ -0,0 +1,200 @@
import pytest
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime, timedelta
import signal
from tradingagents.agents.discovery import (
DiscoveryRequest,
DiscoveryResult,
DiscoveryStatus,
TrendingStock,
NewsArticle,
Sector,
EventCategory,
DiscoveryTimeoutError,
)
def create_mock_trending_stock(
ticker: str = "AAPL",
company_name: str = "Apple Inc.",
score: float = 10.0,
sector: Sector = Sector.TECHNOLOGY,
event_type: EventCategory = EventCategory.EARNINGS,
) -> TrendingStock:
return TrendingStock(
ticker=ticker,
company_name=company_name,
score=score,
mention_count=5,
sentiment=0.5,
sector=sector,
event_type=event_type,
news_summary="Test news summary",
source_articles=[],
)
def create_mock_news_article() -> NewsArticle:
return NewsArticle(
title="Test Article",
source="Test Source",
url="https://example.com/article",
published_at=datetime.now(),
content_snippet="Test content about Apple stock",
ticker_mentions=["AAPL"],
)
class TestDiscoverTrendingReturnsDiscoveryResult:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
def test_discover_trending_returns_discovery_result(
self, mock_scores, mock_extract, mock_bulk_news
):
mock_bulk_news.return_value = [create_mock_news_article()]
mock_extract.return_value = []
mock_scores.return_value = [create_mock_trending_stock()]
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
result = graph.discover_trending()
assert isinstance(result, DiscoveryResult)
assert result.status == DiscoveryStatus.COMPLETED
assert len(result.trending_stocks) > 0
class TestAnalyzeTrendingCallsPropagate:
def test_analyze_trending_calls_propagate(self):
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.propagate = Mock(return_value=({"final_state": "test"}, "BUY"))
trending_stock = create_mock_trending_stock()
result = graph.analyze_trending(trending_stock)
graph.propagate.assert_called_once()
call_args = graph.propagate.call_args
assert call_args[0][0] == "AAPL"
class TestSectorFilterParameter:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
def test_sector_filter_filters_results(
self, mock_scores, mock_extract, mock_bulk_news
):
mock_bulk_news.return_value = [create_mock_news_article()]
mock_extract.return_value = []
mock_scores.return_value = [
create_mock_trending_stock(ticker="AAPL", sector=Sector.TECHNOLOGY),
create_mock_trending_stock(ticker="JPM", sector=Sector.FINANCE),
create_mock_trending_stock(ticker="XOM", sector=Sector.ENERGY),
]
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
request = DiscoveryRequest(
lookback_period="24h",
sector_filter=[Sector.TECHNOLOGY],
)
result = graph.discover_trending(request)
assert all(
stock.sector == Sector.TECHNOLOGY for stock in result.trending_stocks
)
class TestEventFilterParameter:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
def test_event_filter_filters_results(
self, mock_scores, mock_extract, mock_bulk_news
):
mock_bulk_news.return_value = [create_mock_news_article()]
mock_extract.return_value = []
mock_scores.return_value = [
create_mock_trending_stock(ticker="AAPL", event_type=EventCategory.EARNINGS),
create_mock_trending_stock(
ticker="MSFT", event_type=EventCategory.PRODUCT_LAUNCH
),
create_mock_trending_stock(
ticker="GOOGL", event_type=EventCategory.MERGER_ACQUISITION
),
]
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
request = DiscoveryRequest(
lookback_period="24h",
event_filter=[EventCategory.EARNINGS],
)
result = graph.discover_trending(request)
assert all(
stock.event_type == EventCategory.EARNINGS
for stock in result.trending_stocks
)
class TestTimeoutHandling:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
def test_timeout_raises_discovery_timeout_error(self, mock_bulk_news):
def slow_fetch(*args, **kwargs):
import time
time.sleep(0.5)
return []
mock_bulk_news.side_effect = slow_fetch
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 0.1,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
with pytest.raises(DiscoveryTimeoutError):
graph.discover_trending()


@@ -0,0 +1,160 @@
import pytest
from datetime import datetime, timedelta
from unittest.mock import patch, MagicMock
from tradingagents.agents.discovery import NewsArticle
from tradingagents.dataflows.alpha_vantage_common import AlphaVantageRateLimitError
class TestGetBulkNewsReturnsNewsArticles:
def test_get_bulk_news_returns_list_of_news_article_objects(self):
mock_raw_news = [
{
"title": "Market Update: Tech stocks rally",
"source": "Reuters",
"url": "https://reuters.com/market-update",
"published_at": datetime.now().isoformat(),
"content_snippet": "Technology stocks led gains in early trading...",
},
{
"title": "Fed signals rate decision",
"source": "Bloomberg",
"url": "https://bloomberg.com/fed-rates",
"published_at": datetime.now().isoformat(),
"content_snippet": "Federal Reserve officials indicated...",
},
]
from tradingagents.dataflows.interface import (
_bulk_news_cache,
get_bulk_news,
)
_bulk_news_cache.clear()
with patch(
"tradingagents.dataflows.interface._fetch_bulk_news_from_vendor"
) as mock_fetch:
mock_fetch.return_value = mock_raw_news
result = get_bulk_news(lookback_period="24h")
assert isinstance(result, list)
assert len(result) == 2
for article in result:
assert isinstance(article, NewsArticle)
assert article.title is not None
assert article.source is not None
assert article.url is not None
class TestLookbackPeriodParsing:
@pytest.mark.parametrize(
"lookback,expected_hours",
[
("1h", 1),
("6h", 6),
("24h", 24),
("7d", 168),
],
)
def test_lookback_period_parsing(self, lookback, expected_hours):
from tradingagents.dataflows.interface import parse_lookback_period
hours = parse_lookback_period(lookback)
assert hours == expected_hours
def test_invalid_lookback_period_raises_error(self):
from tradingagents.dataflows.interface import parse_lookback_period
with pytest.raises(ValueError):
parse_lookback_period("invalid")
class TestVendorFallback:
def test_vendor_fallback_when_primary_rate_limited(self):
mock_openai_news = [
{
"title": "Fallback news from OpenAI",
"source": "Web Search",
"url": "https://example.com/fallback",
"published_at": datetime.now().isoformat(),
"content_snippet": "This is fallback content...",
},
]
from tradingagents.dataflows.interface import (
_bulk_news_cache,
)
_bulk_news_cache.clear()
with patch(
"tradingagents.dataflows.interface.VENDOR_METHODS",
{
"get_bulk_news": {
"alpha_vantage": MagicMock(side_effect=AlphaVantageRateLimitError("Rate limit")),
"openai": MagicMock(return_value=mock_openai_news),
"google": MagicMock(return_value=[]),
}
}
):
from tradingagents.dataflows.interface import _fetch_bulk_news_from_vendor
result = _fetch_bulk_news_from_vendor("24h")
assert isinstance(result, list)
assert len(result) == 1
assert result[0]["title"] == "Fallback news from OpenAI"
class TestBulkNewsCache:
def test_cache_returns_same_results_within_ttl(self):
from tradingagents.dataflows.interface import (
_bulk_news_cache,
_get_cached_bulk_news,
_set_cached_bulk_news,
)
_bulk_news_cache.clear()
test_articles = [
NewsArticle(
title="Cached article",
source="Test Source",
url="https://test.com/cached",
published_at=datetime.now(),
content_snippet="Cached content...",
ticker_mentions=[],
)
]
_set_cached_bulk_news("24h", test_articles)
cached_result = _get_cached_bulk_news("24h")
assert cached_result is not None
assert len(cached_result) == 1
assert cached_result[0].title == "Cached article"
cached_result_again = _get_cached_bulk_news("24h")
assert cached_result_again is not None
assert cached_result_again[0].title == cached_result[0].title
class TestEmptyResultHandling:
def test_empty_result_handling(self):
from tradingagents.dataflows.interface import (
_bulk_news_cache,
get_bulk_news,
)
_bulk_news_cache.clear()
with patch(
"tradingagents.dataflows.interface._fetch_bulk_news_from_vendor"
) as mock_fetch:
mock_fetch.return_value = []
result = get_bulk_news(lookback_period="1h")
assert isinstance(result, list)
assert len(result) == 0
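
The parsing tests above pin the behavior down: an "Nh" suffix maps to N hours, "Nd" to N * 24 hours, and anything else raises ValueError. A minimal sketch consistent with those tests; the actual parse_lookback_period in tradingagents/dataflows/interface.py may be written differently:

```python
def parse_lookback_period(lookback: str) -> int:
    # Convert a lookback string such as "24h" or "7d" into a number of hours.
    value, unit = lookback[:-1], lookback[-1:]
    if value.isdigit() and unit == "h":
        return int(value)
    if value.isdigit() and unit == "d":
        return int(value) * 24
    raise ValueError(f"Invalid lookback period: {lookback!r}")
```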

tests/discovery/test_cli.py

@@ -0,0 +1,127 @@
import pytest
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime
from io import StringIO
from tradingagents.agents.discovery.models import (
DiscoveryResult,
DiscoveryRequest,
DiscoveryStatus,
TrendingStock,
NewsArticle,
Sector,
EventCategory,
)
@pytest.fixture
def sample_trending_stocks():
article = NewsArticle(
title="Apple announces new iPhone",
source="Reuters",
url="https://reuters.com/article",
published_at=datetime.now(),
content_snippet="Apple Inc. unveiled its latest iPhone model today...",
ticker_mentions=["AAPL"],
)
return [
TrendingStock(
ticker="AAPL",
company_name="Apple Inc.",
score=8.5,
mention_count=10,
sentiment=0.7,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.PRODUCT_LAUNCH,
news_summary="Apple announced new iPhone model with enhanced AI capabilities.",
source_articles=[article],
),
TrendingStock(
ticker="MSFT",
company_name="Microsoft Corporation",
score=7.2,
mention_count=8,
sentiment=0.5,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Microsoft reported strong quarterly earnings.",
source_articles=[article],
),
TrendingStock(
ticker="NVDA",
company_name="NVIDIA Corporation",
score=6.8,
mention_count=6,
sentiment=0.4,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.PRODUCT_LAUNCH,
news_summary="NVIDIA unveiled new AI chips.",
source_articles=[article],
),
]
@pytest.fixture
def sample_discovery_result(sample_trending_stocks):
request = DiscoveryRequest(
lookback_period="24h",
max_results=20,
)
return DiscoveryResult(
request=request,
trending_stocks=sample_trending_stocks,
status=DiscoveryStatus.COMPLETED,
started_at=datetime.now(),
completed_at=datetime.now(),
)
class TestDiscoveryMenuOption:
def test_discover_trending_flow_exists(self):
from cli.main import discover_trending_flow
assert callable(discover_trending_flow)
def test_select_lookback_period_function_exists(self):
from cli.main import select_lookback_period
assert callable(select_lookback_period)
class TestLookbackSelection:
@patch("cli.main.questionary.select")
def test_lookback_selection_returns_valid_period(self, mock_select):
mock_select.return_value.ask.return_value = "24h"
from cli.main import select_lookback_period
result = select_lookback_period()
assert result in ["1h", "6h", "24h", "7d"]
@patch("cli.main.questionary.select")
def test_lookback_selection_handles_all_options(self, mock_select):
from cli.main import select_lookback_period
for period in ["1h", "6h", "24h", "7d"]:
mock_select.return_value.ask.return_value = period
result = select_lookback_period()
assert result == period
class TestResultsTableDisplay:
def test_create_discovery_results_table(self, sample_trending_stocks):
from cli.main import create_discovery_results_table
table = create_discovery_results_table(sample_trending_stocks)
assert table is not None
assert table.row_count == len(sample_trending_stocks)
def test_table_has_correct_columns(self, sample_trending_stocks):
from cli.main import create_discovery_results_table
table = create_discovery_results_table(sample_trending_stocks)
column_names = [col.header for col in table.columns]
expected_columns = ["Rank", "Ticker", "Company", "Score", "Mentions", "Event Type"]
for expected in expected_columns:
assert expected in column_names
class TestDetailView:
def test_create_stock_detail_panel(self, sample_trending_stocks):
from cli.main import create_stock_detail_panel
stock = sample_trending_stocks[0]
panel = create_stock_detail_panel(stock, rank=1)
assert panel is not None


@@ -0,0 +1,278 @@
import pytest
from datetime import datetime
from unittest.mock import patch, MagicMock
from tradingagents.agents.discovery import NewsArticle, EventCategory
class TestExtractEntitiesReturnsCompanyMentions:
def test_extract_entities_returns_list_of_company_mentions(self):
from tradingagents.agents.discovery.entity_extractor import (
extract_entities,
EntityMention,
)
articles = [
NewsArticle(
title="Apple announces new iPhone",
source="Reuters",
url="https://reuters.com/apple",
published_at=datetime.now(),
content_snippet="Apple Inc unveiled its latest iPhone model today with advanced AI features.",
ticker_mentions=[],
),
]
mock_response = MagicMock()
mock_response.entities = [
MagicMock(
company_name="Apple Inc",
confidence=0.95,
context_snippet="Apple Inc unveiled its latest iPhone",
event_type="product_launch",
sentiment=0.7,
)
]
with patch(
"tradingagents.agents.discovery.entity_extractor._get_llm"
) as mock_get_llm:
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value.invoke.return_value = (
mock_response
)
mock_get_llm.return_value = mock_llm
result = extract_entities(articles)
assert isinstance(result, list)
assert len(result) > 0
assert all(isinstance(m, EntityMention) for m in result)
assert result[0].company_name == "Apple Inc"
class TestConfidenceScoreRange:
def test_confidence_score_in_valid_range(self):
from tradingagents.agents.discovery.entity_extractor import (
extract_entities,
EntityMention,
)
articles = [
NewsArticle(
title="Tesla reports earnings",
source="Bloomberg",
url="https://bloomberg.com/tsla",
published_at=datetime.now(),
content_snippet="Tesla Inc reported strong quarterly earnings beating analyst expectations.",
ticker_mentions=[],
),
]
mock_response = MagicMock()
mock_response.entities = [
MagicMock(
company_name="Tesla Inc",
confidence=0.88,
context_snippet="Tesla Inc reported strong quarterly earnings",
event_type="earnings",
sentiment=0.5,
)
]
with patch(
"tradingagents.agents.discovery.entity_extractor._get_llm"
) as mock_get_llm:
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value.invoke.return_value = (
mock_response
)
mock_get_llm.return_value = mock_llm
result = extract_entities(articles)
for mention in result:
assert 0.0 <= mention.confidence <= 1.0
class TestContextSnippetExtraction:
def test_context_snippet_extraction(self):
from tradingagents.agents.discovery.entity_extractor import (
extract_entities,
EntityMention,
)
articles = [
NewsArticle(
title="Microsoft acquires gaming company",
source="WSJ",
url="https://wsj.com/msft",
published_at=datetime.now(),
content_snippet="Microsoft Corporation announced today it will acquire a major gaming studio for $10 billion.",
ticker_mentions=[],
),
]
mock_response = MagicMock()
mock_response.entities = [
MagicMock(
company_name="Microsoft Corporation",
confidence=0.92,
context_snippet="Microsoft Corporation announced today it will acquire",
event_type="merger_acquisition",
sentiment=0.6,
)
]
with patch(
"tradingagents.agents.discovery.entity_extractor._get_llm"
) as mock_get_llm:
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value.invoke.return_value = (
mock_response
)
mock_get_llm.return_value = mock_llm
result = extract_entities(articles)
assert len(result) > 0
for mention in result:
assert mention.context_snippet is not None
assert len(mention.context_snippet) > 0
assert len(mention.context_snippet) <= 150
class TestBatchProcessing:
def test_batch_processing_of_multiple_articles(self):
from tradingagents.agents.discovery.entity_extractor import (
extract_entities,
EntityMention,
BATCH_SIZE,
)
articles = [
NewsArticle(
title=f"News article {i}",
source="Reuters",
url=f"https://reuters.com/article{i}",
published_at=datetime.now(),
content_snippet=f"Company {i} announced major developments today.",
ticker_mentions=[],
)
for i in range(15)
]
mock_response = MagicMock()
mock_response.entities = [
MagicMock(
company_name="Test Company",
confidence=0.85,
context_snippet="Company announced major developments",
event_type="other",
sentiment=0.0,
)
]
with patch(
"tradingagents.agents.discovery.entity_extractor._get_llm"
) as mock_get_llm:
mock_llm = MagicMock()
structured_llm = MagicMock()
structured_llm.invoke.return_value = mock_response
mock_llm.with_structured_output.return_value = structured_llm
mock_get_llm.return_value = mock_llm
result = extract_entities(articles)
expected_batches = (len(articles) + BATCH_SIZE - 1) // BATCH_SIZE
assert structured_llm.invoke.call_count == expected_batches
class TestNoCompanyMentions:
def test_handling_of_articles_with_no_company_mentions(self):
from tradingagents.agents.discovery.entity_extractor import (
extract_entities,
EntityMention,
)
articles = [
NewsArticle(
title="Weather forecast for tomorrow",
source="Weather Channel",
url="https://weather.com/forecast",
published_at=datetime.now(),
content_snippet="Tomorrow will be sunny with temperatures reaching 75 degrees.",
ticker_mentions=[],
),
]
mock_response = MagicMock()
mock_response.entities = []
with patch(
"tradingagents.agents.discovery.entity_extractor._get_llm"
) as mock_get_llm:
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value.invoke.return_value = (
mock_response
)
mock_get_llm.return_value = mock_llm
result = extract_entities(articles)
assert isinstance(result, list)
assert len(result) == 0
class TestEventTypeClassification:
@pytest.mark.parametrize(
"event_type",
[
"earnings",
"merger_acquisition",
"regulatory",
"product_launch",
"executive_change",
"other",
],
)
def test_event_type_classification(self, event_type):
from tradingagents.agents.discovery.entity_extractor import (
extract_entities,
EntityMention,
)
articles = [
NewsArticle(
title="Company news",
source="Reuters",
url="https://reuters.com/news",
published_at=datetime.now(),
content_snippet="A company made an announcement today.",
ticker_mentions=[],
),
]
mock_response = MagicMock()
mock_response.entities = [
MagicMock(
company_name="Test Company",
confidence=0.90,
context_snippet="A company made an announcement",
event_type=event_type,
sentiment=0.0,
)
]
with patch(
"tradingagents.agents.discovery.entity_extractor._get_llm"
) as mock_get_llm:
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value.invoke.return_value = (
mock_response
)
mock_get_llm.return_value = mock_llm
result = extract_entities(articles)
assert len(result) > 0
assert result[0].event_type == EventCategory(event_type)


@@ -0,0 +1,489 @@
import pytest
import math
from datetime import datetime, timedelta
from unittest.mock import patch, MagicMock
from tradingagents.agents.discovery import (
TrendingStock,
NewsArticle,
DiscoveryRequest,
DiscoveryResult,
DiscoveryStatus,
Sector,
EventCategory,
DiscoveryTimeoutError,
NewsUnavailableError,
)
from tradingagents.agents.discovery.entity_extractor import EntityMention
class TestEndToEndDiscoveryFlow:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
def test_full_discovery_flow_from_news_to_results(
self, mock_scores, mock_extract, mock_bulk_news
):
now = datetime.now()
mock_articles = [
NewsArticle(
title="Apple announces record earnings",
source="Reuters",
url="https://reuters.com/apple-earnings",
published_at=now - timedelta(hours=2),
content_snippet="Apple Inc reported record quarterly earnings...",
ticker_mentions=["AAPL"],
),
NewsArticle(
title="Apple stock surges on AI news",
source="Bloomberg",
url="https://bloomberg.com/apple-ai",
published_at=now - timedelta(hours=1),
content_snippet="Shares of Apple jumped after AI announcement...",
ticker_mentions=["AAPL"],
),
]
mock_bulk_news.return_value = mock_articles
mock_mentions = [
EntityMention(
company_name="Apple Inc",
confidence=0.95,
context_snippet="Apple Inc reported record quarterly earnings",
article_id="article_0",
event_type=EventCategory.EARNINGS,
sentiment=0.8,
),
EntityMention(
company_name="Apple",
confidence=0.92,
context_snippet="Shares of Apple jumped",
article_id="article_1",
event_type=EventCategory.PRODUCT_LAUNCH,
sentiment=0.7,
),
]
mock_extract.return_value = mock_mentions
mock_trending = [
TrendingStock(
ticker="AAPL",
company_name="Apple Inc.",
score=8.5,
mention_count=2,
sentiment=0.75,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Apple reported record earnings and AI progress.",
source_articles=mock_articles,
),
]
mock_scores.return_value = mock_trending
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
request = DiscoveryRequest(lookback_period="24h")
result = graph.discover_trending(request)
assert isinstance(result, DiscoveryResult)
assert result.status == DiscoveryStatus.COMPLETED
assert len(result.trending_stocks) == 1
assert result.trending_stocks[0].ticker == "AAPL"
assert result.trending_stocks[0].mention_count >= 2
mock_bulk_news.assert_called_once_with("24h")
mock_extract.assert_called_once()
mock_scores.assert_called_once()
class TestEntityExtractionToScoringPipeline:
def test_pipeline_from_extraction_to_scoring(self):
from tradingagents.agents.discovery.scorer import calculate_trending_scores
now = datetime.now()
articles = [
NewsArticle(
title="Microsoft cloud revenue grows",
source="WSJ",
url="https://wsj.com/article1",
published_at=now - timedelta(hours=2),
content_snippet="Microsoft Corporation reported strong cloud growth.",
ticker_mentions=["MSFT"],
),
NewsArticle(
title="Microsoft earnings beat estimates",
source="CNBC",
url="https://cnbc.com/article2",
published_at=now - timedelta(hours=3),
content_snippet="Microsoft earnings exceeded analyst expectations.",
ticker_mentions=["MSFT"],
),
NewsArticle(
title="Tech stocks rally",
source="Bloomberg",
url="https://bloomberg.com/article3",
published_at=now - timedelta(hours=1),
content_snippet="Technology companies led market gains.",
ticker_mentions=[],
),
]
mentions = [
EntityMention(
company_name="Microsoft Corporation",
confidence=0.95,
context_snippet="Microsoft Corporation reported strong cloud growth",
article_id="article_0",
event_type=EventCategory.EARNINGS,
sentiment=0.7,
),
EntityMention(
company_name="Microsoft",
confidence=0.92,
context_snippet="Microsoft earnings exceeded analyst expectations",
article_id="article_1",
event_type=EventCategory.EARNINGS,
sentiment=0.8,
),
]
with patch("tradingagents.agents.discovery.scorer.resolve_ticker") as mock_resolve:
mock_resolve.return_value = "MSFT"
with patch("tradingagents.agents.discovery.scorer.classify_sector") as mock_sector:
mock_sector.return_value = "technology"
result = calculate_trending_scores(mentions, articles, min_mentions=2)
assert len(result) == 1
assert result[0].ticker == "MSFT"
assert result[0].mention_count == 2
assert result[0].sentiment > 0
class TestNewsVendorFailureGracefulDegradation:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
def test_news_vendor_failure_with_graceful_degradation(self, mock_bulk_news):
mock_bulk_news.side_effect = NewsUnavailableError("All news vendors failed")
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
result = graph.discover_trending()
assert result.status == DiscoveryStatus.FAILED
assert result.error_message is not None
assert "news" in result.error_message.lower() or "vendor" in result.error_message.lower()
class TestTimeoutHandlingWithPartialResults:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
def test_timeout_handling_returns_error(self, mock_bulk_news):
def slow_fetch(*args, **kwargs):
import time
time.sleep(0.3)
return []
mock_bulk_news.side_effect = slow_fetch
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 0.1,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
with pytest.raises(DiscoveryTimeoutError):
graph.discover_trending()
class TestNoTrendingStocksFound:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
def test_no_trending_stocks_found_returns_empty_list(
self, mock_scores, mock_extract, mock_bulk_news
):
mock_bulk_news.return_value = [
NewsArticle(
title="General market update",
source="Reuters",
url="https://reuters.com/general",
published_at=datetime.now(),
content_snippet="Markets were quiet today with no major news.",
ticker_mentions=[],
),
]
mock_extract.return_value = []
mock_scores.return_value = []
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
result = graph.discover_trending()
assert result.status == DiscoveryStatus.COMPLETED
assert len(result.trending_stocks) == 0
class TestAllStocksFilteredOutBySectorFilter:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
def test_all_stocks_filtered_out_by_sector_filter(
self, mock_scores, mock_extract, mock_bulk_news
):
mock_bulk_news.return_value = []
mock_extract.return_value = []
mock_scores.return_value = [
TrendingStock(
ticker="AAPL",
company_name="Apple Inc.",
score=10.0,
mention_count=5,
sentiment=0.5,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Apple earnings",
source_articles=[],
),
TrendingStock(
ticker="MSFT",
company_name="Microsoft",
score=9.0,
mention_count=4,
sentiment=0.4,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.PRODUCT_LAUNCH,
news_summary="Microsoft product",
source_articles=[],
),
]
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
request = DiscoveryRequest(
lookback_period="24h",
sector_filter=[Sector.HEALTHCARE],
)
result = graph.discover_trending(request)
assert result.status == DiscoveryStatus.COMPLETED
assert len(result.trending_stocks) == 0
class TestAllStocksFilteredOutByEventFilter:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
def test_all_stocks_filtered_out_by_event_filter(
self, mock_scores, mock_extract, mock_bulk_news
):
mock_bulk_news.return_value = []
mock_extract.return_value = []
mock_scores.return_value = [
TrendingStock(
ticker="AAPL",
company_name="Apple Inc.",
score=10.0,
mention_count=5,
sentiment=0.5,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Apple earnings",
source_articles=[],
),
]
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
request = DiscoveryRequest(
lookback_period="24h",
event_filter=[EventCategory.MERGER_ACQUISITION],
)
result = graph.discover_trending(request)
assert result.status == DiscoveryStatus.COMPLETED
assert len(result.trending_stocks) == 0
class TestMultipleSectorsAndEventsFiltering:
@patch("tradingagents.graph.trading_graph.get_bulk_news")
@patch("tradingagents.graph.trading_graph.extract_entities")
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
def test_combined_sector_and_event_filtering(
self, mock_scores, mock_extract, mock_bulk_news
):
mock_bulk_news.return_value = []
mock_extract.return_value = []
mock_scores.return_value = [
TrendingStock(
ticker="AAPL",
company_name="Apple Inc.",
score=10.0,
mention_count=5,
sentiment=0.5,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Apple earnings",
source_articles=[],
),
TrendingStock(
ticker="JPM",
company_name="JPMorgan Chase",
score=9.0,
mention_count=4,
sentiment=0.4,
sector=Sector.FINANCE,
event_type=EventCategory.EARNINGS,
news_summary="JPM earnings",
source_articles=[],
),
TrendingStock(
ticker="XOM",
company_name="Exxon Mobil",
score=8.0,
mention_count=3,
sentiment=0.3,
sector=Sector.ENERGY,
event_type=EventCategory.REGULATORY,
news_summary="XOM regulatory news",
source_articles=[],
),
]
from tradingagents.graph.trading_graph import TradingAgentsGraph
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
graph = TradingAgentsGraph()
graph.config = {
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
request = DiscoveryRequest(
lookback_period="24h",
sector_filter=[Sector.TECHNOLOGY, Sector.FINANCE],
event_filter=[EventCategory.EARNINGS],
)
result = graph.discover_trending(request)
assert result.status == DiscoveryStatus.COMPLETED
assert len(result.trending_stocks) == 2
tickers = [s.ticker for s in result.trending_stocks]
assert "AAPL" in tickers
assert "JPM" in tickers
assert "XOM" not in tickers
class TestDiscoveryResultPersistenceIntegration:
def test_discovery_result_can_be_serialized_and_saved(self):
from tradingagents.agents.discovery.persistence import (
save_discovery_result,
generate_markdown_summary,
)
import tempfile
import shutil
from pathlib import Path
article = NewsArticle(
title="Test article",
source="Test",
url="https://test.com",
published_at=datetime.now(),
content_snippet="Test content",
ticker_mentions=["TEST"],
)
stock = TrendingStock(
ticker="TEST",
company_name="Test Company",
score=5.0,
mention_count=2,
sentiment=0.5,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.OTHER,
news_summary="Test news summary",
source_articles=[article],
)
request = DiscoveryRequest(
lookback_period="24h",
created_at=datetime.now(),
)
result = DiscoveryResult(
request=request,
trending_stocks=[stock],
status=DiscoveryStatus.COMPLETED,
started_at=datetime.now(),
completed_at=datetime.now(),
)
temp_dir = tempfile.mkdtemp()
try:
path = save_discovery_result(result, base_path=Path(temp_dir))
assert path.exists()
assert (path / "discovery_result.json").exists()
assert (path / "discovery_summary.md").exists()
markdown = generate_markdown_summary(result)
assert "TEST" in markdown
assert "Test Company" in markdown
finally:
shutil.rmtree(temp_dir)


@@ -0,0 +1,196 @@
import pytest
from datetime import datetime
from tradingagents.agents.discovery import (
TrendingStock,
NewsArticle,
DiscoveryRequest,
DiscoveryResult,
Sector,
EventCategory,
)
from tradingagents.agents.discovery.models import DiscoveryStatus
class TestTrendingStock:
def test_trending_stock_creation_and_validation(self):
article = NewsArticle(
title="Apple announces new iPhone",
source="Reuters",
url="https://reuters.com/article1",
published_at=datetime(2024, 1, 15, 10, 30, 0),
content_snippet="Apple Inc announced its latest iPhone model today...",
ticker_mentions=["AAPL"],
)
stock = TrendingStock(
ticker="AAPL",
company_name="Apple Inc.",
score=85.5,
mention_count=10,
sentiment=0.75,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.PRODUCT_LAUNCH,
news_summary="Apple announced new iPhone with advanced AI features.",
source_articles=[article],
)
assert stock.ticker == "AAPL"
assert stock.company_name == "Apple Inc."
assert stock.score == 85.5
assert stock.mention_count == 10
assert stock.sentiment == 0.75
assert stock.sector == Sector.TECHNOLOGY
assert stock.event_type == EventCategory.PRODUCT_LAUNCH
assert len(stock.source_articles) == 1
class TestNewsArticle:
def test_news_article_with_required_fields(self):
published = datetime(2024, 1, 15, 14, 0, 0)
article = NewsArticle(
title="Tesla Q4 Earnings Beat Expectations",
source="Bloomberg",
url="https://bloomberg.com/news/tsla-earnings",
published_at=published,
content_snippet="Tesla Inc. reported fourth quarter earnings that exceeded analyst expectations...",
ticker_mentions=["TSLA", "F"],
)
assert article.title == "Tesla Q4 Earnings Beat Expectations"
assert article.source == "Bloomberg"
assert article.url == "https://bloomberg.com/news/tsla-earnings"
assert article.published_at == published
assert article.content_snippet.startswith("Tesla Inc.")
assert "TSLA" in article.ticker_mentions
assert "F" in article.ticker_mentions
class TestDiscoveryRequest:
def test_discovery_request_with_lookback_period_validation(self):
created = datetime(2024, 1, 15, 12, 0, 0)
request = DiscoveryRequest(
lookback_period="24h",
sector_filter=[Sector.TECHNOLOGY, Sector.HEALTHCARE],
event_filter=[EventCategory.EARNINGS],
max_results=20,
created_at=created,
)
assert request.lookback_period == "24h"
assert Sector.TECHNOLOGY in request.sector_filter
assert Sector.HEALTHCARE in request.sector_filter
assert EventCategory.EARNINGS in request.event_filter
assert request.max_results == 20
assert request.created_at == created
def test_discovery_request_with_defaults(self):
request = DiscoveryRequest(
lookback_period="1h",
)
assert request.lookback_period == "1h"
assert request.sector_filter is None
assert request.event_filter is None
assert request.max_results == 20
assert request.created_at is not None
class TestDiscoveryResult:
def test_discovery_result_state_transitions(self):
request = DiscoveryRequest(lookback_period="6h")
started = datetime(2024, 1, 15, 12, 0, 0)
result = DiscoveryResult(
request=request,
trending_stocks=[],
status=DiscoveryStatus.CREATED,
started_at=started,
)
assert result.status == DiscoveryStatus.CREATED
result.status = DiscoveryStatus.PROCESSING
assert result.status == DiscoveryStatus.PROCESSING
result.status = DiscoveryStatus.COMPLETED
result.completed_at = datetime(2024, 1, 15, 12, 1, 0)
assert result.status == DiscoveryStatus.COMPLETED
assert result.completed_at is not None
def test_discovery_result_failed_state(self):
request = DiscoveryRequest(lookback_period="7d")
result = DiscoveryResult(
request=request,
trending_stocks=[],
status=DiscoveryStatus.FAILED,
started_at=datetime(2024, 1, 15, 12, 0, 0),
error_message="News API rate limit exceeded",
)
assert result.status == DiscoveryStatus.FAILED
assert result.error_message == "News API rate limit exceeded"
class TestSerializationRoundtrip:
def test_to_dict_and_from_dict_serialization_roundtrip(self):
article = NewsArticle(
title="Microsoft acquires AI startup",
source="WSJ",
url="https://wsj.com/msft-acquisition",
published_at=datetime(2024, 1, 15, 9, 0, 0),
content_snippet="Microsoft Corp announced the acquisition of an AI startup...",
ticker_mentions=["MSFT"],
)
stock = TrendingStock(
ticker="MSFT",
company_name="Microsoft Corporation",
score=92.3,
mention_count=15,
sentiment=0.65,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.MERGER_ACQUISITION,
news_summary="Microsoft announces major AI acquisition.",
source_articles=[article],
)
request = DiscoveryRequest(
lookback_period="24h",
sector_filter=[Sector.TECHNOLOGY],
event_filter=[EventCategory.MERGER_ACQUISITION],
max_results=10,
created_at=datetime(2024, 1, 15, 8, 0, 0),
)
result = DiscoveryResult(
request=request,
trending_stocks=[stock],
status=DiscoveryStatus.COMPLETED,
started_at=datetime(2024, 1, 15, 8, 0, 0),
completed_at=datetime(2024, 1, 15, 8, 1, 30),
)
result_dict = result.to_dict()
restored_result = DiscoveryResult.from_dict(result_dict)
assert restored_result.status == result.status
assert restored_result.request.lookback_period == request.lookback_period
assert len(restored_result.trending_stocks) == 1
restored_stock = restored_result.trending_stocks[0]
assert restored_stock.ticker == stock.ticker
assert restored_stock.company_name == stock.company_name
assert restored_stock.score == stock.score
assert restored_stock.mention_count == stock.mention_count
assert restored_stock.sentiment == stock.sentiment
assert restored_stock.sector == stock.sector
assert restored_stock.event_type == stock.event_type
assert len(restored_stock.source_articles) == 1
restored_article = restored_stock.source_articles[0]
assert restored_article.title == article.title
assert restored_article.source == article.source
assert restored_article.url == article.url


@@ -0,0 +1,228 @@
import pytest
import json
from datetime import datetime
from pathlib import Path
import tempfile
import shutil
from tradingagents.agents.discovery import (
TrendingStock,
NewsArticle,
DiscoveryRequest,
DiscoveryResult,
DiscoveryStatus,
Sector,
EventCategory,
)
from tradingagents.agents.discovery.persistence import (
save_discovery_result,
generate_markdown_summary,
)
@pytest.fixture
def sample_discovery_result():
articles = [
NewsArticle(
title="Apple announces new iPhone with AI features",
source="Reuters",
url="https://reuters.com/apple-iphone-ai",
published_at=datetime(2024, 1, 15, 10, 30, 0),
content_snippet="Apple Inc announced its latest iPhone model with advanced AI...",
ticker_mentions=["AAPL"],
),
NewsArticle(
title="Apple stock surges on earnings beat",
source="Bloomberg",
url="https://bloomberg.com/apple-earnings",
published_at=datetime(2024, 1, 15, 11, 0, 0),
content_snippet="Shares of Apple Inc surged after the company reported...",
ticker_mentions=["AAPL"],
),
NewsArticle(
title="Microsoft cloud revenue grows 25%",
source="WSJ",
url="https://wsj.com/msft-cloud",
published_at=datetime(2024, 1, 15, 9, 0, 0),
content_snippet="Microsoft Corp reported strong cloud revenue growth...",
ticker_mentions=["MSFT"],
),
]
stocks = [
TrendingStock(
ticker="AAPL",
company_name="Apple Inc.",
score=8.54,
mention_count=12,
sentiment=0.72,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.EARNINGS,
news_summary="Apple reported strong earnings and announced new AI features.",
source_articles=[articles[0], articles[1]],
),
TrendingStock(
ticker="MSFT",
company_name="Microsoft Corporation",
score=7.23,
mention_count=9,
sentiment=0.65,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.PRODUCT_LAUNCH,
news_summary="Microsoft cloud business continues strong growth.",
source_articles=[articles[2]],
),
TrendingStock(
ticker="GOOGL",
company_name="Alphabet Inc.",
score=6.15,
mention_count=7,
sentiment=0.58,
sector=Sector.TECHNOLOGY,
event_type=EventCategory.REGULATORY,
news_summary="Google faces regulatory scrutiny in multiple markets.",
source_articles=[],
),
]
request = DiscoveryRequest(
lookback_period="24h",
sector_filter=[Sector.TECHNOLOGY],
event_filter=[EventCategory.EARNINGS],
max_results=20,
created_at=datetime(2024, 1, 15, 14, 30, 45),
)
return DiscoveryResult(
request=request,
trending_stocks=stocks,
status=DiscoveryStatus.COMPLETED,
started_at=datetime(2024, 1, 15, 14, 30, 45),
completed_at=datetime(2024, 1, 15, 14, 31, 30),
)
@pytest.fixture
def temp_results_dir():
temp_dir = tempfile.mkdtemp()
yield Path(temp_dir)
shutil.rmtree(temp_dir)
class TestDirectoryStructureCreation:
def test_creates_correct_directory_structure(self, sample_discovery_result, temp_results_dir):
result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir)
assert result_path.exists()
assert result_path.is_dir()
path_parts = result_path.parts
assert "discovery" in path_parts
date_part = path_parts[-2]
time_part = path_parts[-1]
assert len(date_part.split("-")) == 3
assert len(time_part.split("-")) == 3
class TestDiscoveryResultJson:
def test_discovery_result_json_contains_all_fields(self, sample_discovery_result, temp_results_dir):
result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir)
json_path = result_path / "discovery_result.json"
assert json_path.exists()
with open(json_path, "r") as f:
saved_data = json.load(f)
assert "request" in saved_data
assert "trending_stocks" in saved_data
assert "status" in saved_data
assert "started_at" in saved_data
assert "completed_at" in saved_data
assert saved_data["request"]["lookback_period"] == "24h"
assert saved_data["status"] == "completed"
assert len(saved_data["trending_stocks"]) == 3
first_stock = saved_data["trending_stocks"][0]
assert first_stock["ticker"] == "AAPL"
assert first_stock["company_name"] == "Apple Inc."
assert first_stock["score"] == 8.54
assert first_stock["mention_count"] == 12
assert first_stock["sentiment"] == 0.72
assert first_stock["sector"] == "technology"
assert first_stock["event_type"] == "earnings"
assert "news_summary" in first_stock
assert "source_articles" in first_stock
class TestDiscoverySummaryMarkdown:
def test_discovery_summary_md_is_human_readable(self, sample_discovery_result, temp_results_dir):
result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir)
md_path = result_path / "discovery_summary.md"
assert md_path.exists()
with open(md_path, "r") as f:
markdown_content = f.read()
assert "# Discovery Results" in markdown_content
assert "Timestamp:" in markdown_content
assert "Lookback Period:" in markdown_content
assert "24h" in markdown_content
assert "Total Stocks Found:" in markdown_content
assert "## Trending Stocks" in markdown_content
assert "| Rank |" in markdown_content
assert "| Ticker |" in markdown_content
assert "| Company |" in markdown_content
assert "| Score |" in markdown_content
assert "| Mentions |" in markdown_content
assert "| Event |" in markdown_content
assert "AAPL" in markdown_content
assert "Apple Inc." in markdown_content
assert "8.54" in markdown_content
assert "12" in markdown_content
assert "earnings" in markdown_content
assert "MSFT" in markdown_content
assert "Microsoft Corporation" in markdown_content
assert "## Top 3 Detailed Analysis" in markdown_content
assert "### 1. AAPL - Apple Inc." in markdown_content
assert "**Score:**" in markdown_content
assert "**Sentiment:**" in markdown_content
assert "**Sector:**" in markdown_content
assert "**Event Type:**" in markdown_content
assert "**Mentions:**" in markdown_content
assert "**News Summary:**" in markdown_content
class TestMarkdownGeneration:
def test_generate_markdown_with_filters(self, sample_discovery_result):
markdown = generate_markdown_summary(sample_discovery_result)
assert "sector=technology" in markdown.lower()
assert "event=earnings" in markdown.lower()
def test_generate_markdown_without_filters(self):
request = DiscoveryRequest(
lookback_period="6h",
created_at=datetime(2024, 1, 15, 10, 0, 0),
)
result = DiscoveryResult(
request=request,
trending_stocks=[],
status=DiscoveryStatus.COMPLETED,
started_at=datetime(2024, 1, 15, 10, 0, 0),
completed_at=datetime(2024, 1, 15, 10, 1, 0),
)
markdown = generate_markdown_summary(result)
assert "Filters:" in markdown
assert "None" in markdown


@@ -0,0 +1,469 @@
import pytest
import math
from datetime import datetime, timedelta
from unittest.mock import patch
from tradingagents.agents.discovery import NewsArticle, EventCategory, Sector
from tradingagents.agents.discovery.entity_extractor import EntityMention
class TestFrequencyCalculation:
def test_frequency_calculation_unique_article_count(self):
from tradingagents.agents.discovery.scorer import calculate_trending_scores
now = datetime.now()
articles = [
NewsArticle(
title="Apple Q4 Earnings",
source="Reuters",
url="https://reuters.com/article1",
published_at=now - timedelta(hours=1),
content_snippet="Apple Inc reported strong earnings.",
ticker_mentions=["AAPL"],
),
NewsArticle(
title="Apple iPhone Sales",
source="Bloomberg",
url="https://bloomberg.com/article2",
published_at=now - timedelta(hours=2),
content_snippet="Apple saw record iPhone sales.",
ticker_mentions=["AAPL"],
),
NewsArticle(
title="Apple AI Features",
source="WSJ",
url="https://wsj.com/article3",
published_at=now - timedelta(hours=3),
content_snippet="Apple announced AI features.",
ticker_mentions=["AAPL"],
),
]
mentions = [
EntityMention(
company_name="Apple Inc",
confidence=0.95,
context_snippet="Apple Inc reported strong earnings",
article_id="article_0",
event_type=EventCategory.EARNINGS,
),
EntityMention(
company_name="Apple",
confidence=0.90,
context_snippet="Apple saw record iPhone sales",
article_id="article_1",
event_type=EventCategory.EARNINGS,
),
EntityMention(
company_name="Apple Inc.",
confidence=0.92,
context_snippet="Apple announced AI features",
article_id="article_2",
event_type=EventCategory.PRODUCT_LAUNCH,
),
]
with patch(
"tradingagents.agents.discovery.scorer.resolve_ticker"
) as mock_resolve:
mock_resolve.return_value = "AAPL"
with patch(
"tradingagents.agents.discovery.scorer.classify_sector"
) as mock_sector:
mock_sector.return_value = "technology"
result = calculate_trending_scores(mentions, articles)
assert len(result) == 1
assert result[0].ticker == "AAPL"
assert result[0].mention_count == 3
class TestSentimentIntensityFactor:
def test_sentiment_intensity_uses_absolute_value(self):
from tradingagents.agents.discovery.scorer import calculate_trending_scores
now = datetime.now()
articles = [
NewsArticle(
title="Stock drops sharply",
source="Reuters",
url="https://reuters.com/article1",
published_at=now - timedelta(hours=1),
content_snippet="Company faced major issues.",
ticker_mentions=["TSLA"],
),
NewsArticle(
title="More bad news",
source="Bloomberg",
url="https://bloomberg.com/article2",
published_at=now - timedelta(hours=2),
content_snippet="Further decline expected.",
ticker_mentions=["TSLA"],
),
]
mentions = [
EntityMention(
company_name="Tesla",
confidence=0.95,
context_snippet="Company faced major issues",
article_id="article_0",
event_type=EventCategory.OTHER,
sentiment=-0.8,
),
EntityMention(
company_name="Tesla Inc",
confidence=0.90,
context_snippet="Further decline expected",
article_id="article_1",
event_type=EventCategory.OTHER,
sentiment=-0.6,
),
]
with patch(
"tradingagents.agents.discovery.scorer.resolve_ticker"
) as mock_resolve:
mock_resolve.return_value = "TSLA"
with patch(
"tradingagents.agents.discovery.scorer.classify_sector"
) as mock_sector:
mock_sector.return_value = "technology"
result = calculate_trending_scores(mentions, articles)
assert len(result) == 1
assert result[0].sentiment < 0
expected_sentiment = (-0.8 * 0.95 + -0.6 * 0.90) / (0.95 + 0.90)
assert abs(result[0].sentiment - expected_sentiment) < 0.01
class TestRecencyWeightExponentialDecay:
def test_recency_weight_exponential_decay(self):
from tradingagents.agents.discovery.scorer import calculate_trending_scores
now = datetime.now()
articles = [
NewsArticle(
title="Recent news",
source="Reuters",
url="https://reuters.com/article1",
published_at=now - timedelta(hours=1),
content_snippet="Recent company news.",
ticker_mentions=["NVDA"],
),
NewsArticle(
title="Older news",
source="Bloomberg",
url="https://bloomberg.com/article2",
published_at=now - timedelta(hours=10),
content_snippet="Older company news.",
ticker_mentions=["NVDA"],
),
]
mentions = [
EntityMention(
company_name="Nvidia",
confidence=0.90,
context_snippet="Recent company news",
article_id="article_0",
event_type=EventCategory.OTHER,
sentiment=0.5,
),
EntityMention(
company_name="Nvidia",
confidence=0.90,
context_snippet="Older company news",
article_id="article_1",
event_type=EventCategory.OTHER,
sentiment=0.5,
),
]
with patch(
"tradingagents.agents.discovery.scorer.resolve_ticker"
) as mock_resolve:
mock_resolve.return_value = "NVDA"
with patch(
"tradingagents.agents.discovery.scorer.classify_sector"
) as mock_sector:
mock_sector.return_value = "technology"
result = calculate_trending_scores(mentions, articles, decay_rate=0.1)
assert len(result) == 1
recent_weight = math.exp(-0.1 * 1)
older_weight = math.exp(-0.1 * 10)
avg_recency = (recent_weight + older_weight) / 2
assert result[0].score > 0
class TestMinimumThresholdFiltering:
def test_minimum_threshold_filtering_requires_two_articles(self):
from tradingagents.agents.discovery.scorer import calculate_trending_scores
now = datetime.now()
articles = [
NewsArticle(
title="Single mention stock",
source="Reuters",
url="https://reuters.com/article1",
published_at=now - timedelta(hours=1),
content_snippet="Some company news.",
ticker_mentions=["AMD"],
),
NewsArticle(
title="Multiple mention stock 1",
source="Bloomberg",
url="https://bloomberg.com/article2",
published_at=now - timedelta(hours=2),
content_snippet="Popular company news.",
ticker_mentions=["MSFT"],
),
NewsArticle(
title="Multiple mention stock 2",
source="WSJ",
url="https://wsj.com/article3",
published_at=now - timedelta(hours=3),
content_snippet="More popular company news.",
ticker_mentions=["MSFT"],
),
]
mentions = [
EntityMention(
company_name="AMD",
confidence=0.90,
context_snippet="Some company news",
article_id="article_0",
event_type=EventCategory.OTHER,
),
EntityMention(
company_name="Microsoft",
confidence=0.95,
context_snippet="Popular company news",
article_id="article_1",
event_type=EventCategory.OTHER,
),
EntityMention(
company_name="Microsoft Corp",
confidence=0.92,
context_snippet="More popular company news",
article_id="article_2",
event_type=EventCategory.OTHER,
),
]
with patch(
"tradingagents.agents.discovery.scorer.resolve_ticker"
) as mock_resolve:
def resolve_side_effect(name):
if "AMD" in name or name == "AMD":
return "AMD"
return "MSFT"
mock_resolve.side_effect = resolve_side_effect
with patch(
"tradingagents.agents.discovery.scorer.classify_sector"
) as mock_sector:
mock_sector.return_value = "technology"
result = calculate_trending_scores(mentions, articles, min_mentions=2)
assert len(result) == 1
assert result[0].ticker == "MSFT"
assert all(stock.mention_count >= 2 for stock in result)
class TestFinalScoreFormulaCorrectness:
def test_final_score_formula_correctness(self):
from tradingagents.agents.discovery.scorer import calculate_trending_scores
now = datetime.now()
hours_old = 2.0
articles = [
NewsArticle(
title="Test article 1",
source="Reuters",
url="https://reuters.com/article1",
published_at=now - timedelta(hours=hours_old),
content_snippet="Google announced results.",
ticker_mentions=["GOOGL"],
),
NewsArticle(
title="Test article 2",
source="Bloomberg",
url="https://bloomberg.com/article2",
published_at=now - timedelta(hours=hours_old),
content_snippet="Alphabet earnings beat.",
ticker_mentions=["GOOGL"],
),
]
sentiment_val = 0.6
confidence = 0.9
mentions = [
EntityMention(
company_name="Google",
confidence=confidence,
context_snippet="Google announced results",
article_id="article_0",
event_type=EventCategory.EARNINGS,
sentiment=sentiment_val,
),
EntityMention(
company_name="Alphabet",
confidence=confidence,
context_snippet="Alphabet earnings beat",
article_id="article_1",
event_type=EventCategory.EARNINGS,
sentiment=sentiment_val,
),
]
decay_rate = 0.1
with patch(
"tradingagents.agents.discovery.scorer.resolve_ticker"
) as mock_resolve:
mock_resolve.return_value = "GOOGL"
with patch(
"tradingagents.agents.discovery.scorer.classify_sector"
) as mock_sector:
mock_sector.return_value = "technology"
result = calculate_trending_scores(
mentions, articles, decay_rate=decay_rate
)
assert len(result) == 1
stock = result[0]
frequency = 2
sentiment_factor = 1 + abs(sentiment_val)
recency_weight = math.exp(-decay_rate * hours_old)
expected_score = frequency * sentiment_factor * recency_weight
assert abs(stock.score - expected_score) < 0.01
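# Sanity check of the arithmetic above: frequency=2, sentiment_factor=1+0.6=1.6 and
# recency_weight=exp(-0.1*2)~=0.8187, so the expected score is ~2 * 1.6 * 0.8187 ~= 2.62.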
class TestSortingByScoreDescending:
def test_results_sorted_by_score_descending(self):
from tradingagents.agents.discovery.scorer import calculate_trending_scores
now = datetime.now()
articles = [
NewsArticle(
title="High score stock 1",
source="Reuters",
url="https://reuters.com/article1",
published_at=now - timedelta(hours=1),
content_snippet="Apple news.",
ticker_mentions=["AAPL"],
),
NewsArticle(
title="High score stock 2",
source="Bloomberg",
url="https://bloomberg.com/article2",
published_at=now - timedelta(hours=1),
content_snippet="More Apple news.",
ticker_mentions=["AAPL"],
),
NewsArticle(
title="High score stock 3",
source="WSJ",
url="https://wsj.com/article3",
published_at=now - timedelta(hours=1),
content_snippet="Even more Apple news.",
ticker_mentions=["AAPL"],
),
NewsArticle(
title="Low score stock 1",
source="CNBC",
url="https://cnbc.com/article4",
published_at=now - timedelta(hours=10),
content_snippet="Tesla news.",
ticker_mentions=["TSLA"],
),
NewsArticle(
title="Low score stock 2",
source="FT",
url="https://ft.com/article5",
published_at=now - timedelta(hours=10),
content_snippet="More Tesla news.",
ticker_mentions=["TSLA"],
),
]
mentions = [
EntityMention(
company_name="Apple",
confidence=0.95,
context_snippet="Apple news",
article_id="article_0",
event_type=EventCategory.OTHER,
sentiment=0.8,
),
EntityMention(
company_name="Apple Inc",
confidence=0.93,
context_snippet="More Apple news",
article_id="article_1",
event_type=EventCategory.OTHER,
sentiment=0.8,
),
EntityMention(
company_name="Apple",
confidence=0.90,
context_snippet="Even more Apple news",
article_id="article_2",
event_type=EventCategory.OTHER,
sentiment=0.8,
),
EntityMention(
company_name="Tesla",
confidence=0.85,
context_snippet="Tesla news",
article_id="article_3",
event_type=EventCategory.OTHER,
sentiment=0.2,
),
EntityMention(
company_name="Tesla Inc",
confidence=0.85,
context_snippet="More Tesla news",
article_id="article_4",
event_type=EventCategory.OTHER,
sentiment=0.2,
),
]
with patch(
"tradingagents.agents.discovery.scorer.resolve_ticker"
) as mock_resolve:
def resolve_side_effect(name):
if "Apple" in name:
return "AAPL"
if "Tesla" in name:
return "TSLA"
return None
mock_resolve.side_effect = resolve_side_effect
with patch(
"tradingagents.agents.discovery.scorer.classify_sector"
) as mock_sector:
mock_sector.return_value = "technology"
result = calculate_trending_scores(mentions, articles, min_mentions=2)
assert len(result) == 2
for i in range(len(result) - 1):
assert result[i].score >= result[i + 1].score

View File

@ -0,0 +1,94 @@
import pytest
from unittest.mock import patch, MagicMock
from tradingagents.dataflows.trending.sector_classifier import (
classify_sector,
TICKER_TO_SECTOR,
VALID_SECTORS,
_llm_classify_sector,
_sector_cache,
)
class TestStaticSectorMapping:
def test_static_sector_mapping_for_known_technology_tickers(self):
assert classify_sector("AAPL") == "technology"
assert classify_sector("MSFT") == "technology"
assert classify_sector("GOOGL") == "technology"
assert classify_sector("NVDA") == "technology"
def test_static_sector_mapping_for_known_healthcare_tickers(self):
assert classify_sector("JNJ") == "healthcare"
assert classify_sector("PFE") == "healthcare"
assert classify_sector("UNH") == "healthcare"
def test_static_sector_mapping_for_known_finance_tickers(self):
assert classify_sector("JPM") == "finance"
assert classify_sector("BAC") == "finance"
assert classify_sector("GS") == "finance"
def test_static_sector_mapping_for_known_energy_tickers(self):
assert classify_sector("XOM") == "energy"
assert classify_sector("CVX") == "energy"
assert classify_sector("COP") == "energy"
def test_static_sector_mapping_case_insensitive(self):
assert classify_sector("aapl") == "technology"
assert classify_sector("AAPL") == "technology"
assert classify_sector("Aapl") == "technology"
class TestLLMFallback:
@patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector")
def test_llm_fallback_for_unknown_tickers(self, mock_llm_classify):
mock_llm_classify.return_value = "technology"
_sector_cache.clear()
result = classify_sector("UNKNOWNTICKER123")
mock_llm_classify.assert_called_once_with("UNKNOWNTICKER123")
assert result == "technology"
@patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector")
def test_llm_fallback_caches_results(self, mock_llm_classify):
mock_llm_classify.return_value = "healthcare"
_sector_cache.clear()
result1 = classify_sector("NEWCO123")
result2 = classify_sector("NEWCO123")
assert mock_llm_classify.call_count == 1
assert result1 == "healthcare"
assert result2 == "healthcare"
@patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector")
def test_llm_fallback_returns_other_on_error(self, mock_llm_classify):
mock_llm_classify.side_effect = Exception("LLM error")
_sector_cache.clear()
result = classify_sector("ERRORCO")
assert result == "other"
class TestAllSectorCategories:
def test_all_sector_categories_in_valid_sectors(self):
expected_sectors = {
"technology",
"healthcare",
"finance",
"energy",
"consumer_goods",
"industrials",
"other",
}
assert VALID_SECTORS == expected_sectors
def test_static_mapping_covers_all_sector_categories(self):
sectors_in_mapping = set(TICKER_TO_SECTOR.values())
assert sectors_in_mapping.issubset(VALID_SECTORS)
def test_classify_sector_always_returns_valid_sector(self):
test_tickers = ["AAPL", "JPM", "XOM", "JNJ", "WMT", "CAT"]
for ticker in test_tickers:
result = classify_sector(ticker)
assert result in VALID_SECTORS

View File

@ -0,0 +1,135 @@
import pytest
import logging
from unittest.mock import patch, MagicMock
from tradingagents.dataflows.trending.stock_resolver import (
resolve_ticker,
validate_us_ticker,
_normalize_company_name,
_search_yfinance_ticker,
)
class TestStaticLookup:
def test_static_lookup_for_known_companies(self):
assert resolve_ticker("Apple") == "AAPL"
assert resolve_ticker("Microsoft") == "MSFT"
assert resolve_ticker("Google") == "GOOGL"
assert resolve_ticker("Amazon") == "AMZN"
assert resolve_ticker("Tesla") == "TSLA"
assert resolve_ticker("Nvidia") == "NVDA"
def test_static_lookup_case_insensitive(self):
assert resolve_ticker("APPLE") == "AAPL"
assert resolve_ticker("apple") == "AAPL"
assert resolve_ticker("ApPlE") == "AAPL"
assert resolve_ticker("microsoft") == "MSFT"
assert resolve_ticker("MICROSOFT") == "MSFT"
class TestNameVariationHandling:
def test_name_variation_handling_with_suffixes(self):
assert resolve_ticker("Apple Inc.") == "AAPL"
assert resolve_ticker("Apple Inc") == "AAPL"
assert resolve_ticker("Apple Corporation") == "AAPL"
assert resolve_ticker("Microsoft Corp.") == "MSFT"
assert resolve_ticker("Microsoft Corp") == "MSFT"
assert resolve_ticker("Tesla Inc") == "TSLA"
def test_name_variation_handling_informal_names(self):
assert resolve_ticker("the iPhone maker") == "AAPL"
assert resolve_ticker("iPhone maker") == "AAPL"
assert resolve_ticker("the search giant") == "GOOGL"
assert resolve_ticker("the e-commerce giant") == "AMZN"
assert resolve_ticker("EV maker Tesla") == "TSLA"
def test_name_variation_handling_alternate_names(self):
assert resolve_ticker("Alphabet") == "GOOGL"
assert resolve_ticker("Meta") == "META"
assert resolve_ticker("Facebook") == "META"
assert resolve_ticker("Meta Platforms") == "META"
class TestYfinanceFallback:
@patch("tradingagents.dataflows.trending.stock_resolver._search_yfinance_ticker")
@patch("tradingagents.dataflows.trending.stock_resolver.validate_us_ticker")
def test_yfinance_fallback_for_unknown_company(self, mock_validate, mock_search):
mock_search.return_value = "PLTR"
mock_validate.return_value = True
result = resolve_ticker("UnknownTechStartupXYZ")
mock_search.assert_called_once()
assert result == "PLTR"
@patch("tradingagents.dataflows.trending.stock_resolver._search_yfinance_ticker")
def test_yfinance_fallback_returns_none_when_not_found(self, mock_search):
mock_search.return_value = None
result = resolve_ticker("NonexistentCompanyXYZ123")
assert result is None
class TestUSExchangeValidation:
@patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
def test_validate_us_ticker_accepts_nyse(self, mock_ticker):
mock_info = {"exchange": "NYQ"}
mock_ticker.return_value.info = mock_info
assert validate_us_ticker("IBM") is True
@patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
def test_validate_us_ticker_accepts_nasdaq(self, mock_ticker):
mock_info = {"exchange": "NMS"}
mock_ticker.return_value.info = mock_info
assert validate_us_ticker("AAPL") is True
@patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
def test_validate_us_ticker_accepts_amex(self, mock_ticker):
mock_info = {"exchange": "ASE"}
mock_ticker.return_value.info = mock_info
assert validate_us_ticker("SPY") is True
@patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
def test_validate_us_ticker_rejects_international(self, mock_ticker):
mock_info = {"exchange": "LSE"}
mock_ticker.return_value.info = mock_info
assert validate_us_ticker("VOD.L") is False
@patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
def test_validate_us_ticker_rejects_otc(self, mock_ticker):
mock_info = {"exchange": "PNK"}
mock_ticker.return_value.info = mock_info
assert validate_us_ticker("OTCPK") is False
class TestAmbiguousResolutionLogging:
def test_ambiguous_resolution_logs_multiple_matches(self, caplog):
# Placeholder: the ambiguous-resolution logging behaviour is not asserted here yet.
with caplog.at_level(logging.INFO, logger="tradingagents.dataflows.trending.stock_resolver"):
pass
@patch("tradingagents.dataflows.trending.stock_resolver._search_yfinance_ticker")
@patch("tradingagents.dataflows.trending.stock_resolver.validate_us_ticker")
def test_yfinance_fallback_is_logged(self, mock_validate, mock_search, caplog):
mock_search.return_value = "RBLX"
mock_validate.return_value = True
with caplog.at_level(logging.INFO, logger="tradingagents.dataflows.trending.stock_resolver"):
result = resolve_ticker("SomeRandomCompanyNotInMapping")
assert any("fallback" in record.message.lower() or "yfinance" in record.message.lower()
for record in caplog.records)
@patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
def test_validation_failure_is_logged(self, mock_ticker, caplog):
mock_info = {"exchange": "LSE"}
mock_ticker.return_value.info = mock_info
with caplog.at_level(logging.WARNING, logger="tradingagents.dataflows.trending.stock_resolver"):
result = validate_us_ticker("VOD.L")
assert result is False

View File

@ -0,0 +1,53 @@
from .models import (
NewsArticle,
TrendingStock,
DiscoveryRequest,
DiscoveryResult,
DiscoveryStatus,
Sector,
EventCategory,
)
from .exceptions import (
DiscoveryError,
NewsUnavailableError,
DiscoveryTimeoutError,
TickerResolutionError,
)
from .entity_extractor import (
EntityMention,
extract_entities,
BATCH_SIZE,
)
from .scorer import (
calculate_trending_scores,
DEFAULT_DECAY_RATE,
DEFAULT_MAX_RESULTS,
DEFAULT_MIN_MENTIONS,
)
from .persistence import (
save_discovery_result,
generate_markdown_summary,
)
__all__ = [
"NewsArticle",
"TrendingStock",
"DiscoveryRequest",
"DiscoveryResult",
"DiscoveryStatus",
"Sector",
"EventCategory",
"DiscoveryError",
"NewsUnavailableError",
"DiscoveryTimeoutError",
"TickerResolutionError",
"EntityMention",
"extract_entities",
"BATCH_SIZE",
"calculate_trending_scores",
"DEFAULT_DECAY_RATE",
"DEFAULT_MAX_RESULTS",
"DEFAULT_MIN_MENTIONS",
"save_discovery_result",
"generate_markdown_summary",
]
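# Illustrative end-to-end sketch (not executed): wiring the public API above into the
# discovery pipeline. Assumes `get_bulk_news` lives in tradingagents.dataflows.interface
# and that an LLM provider/API key is configured; the real CLI flow may differ.
#
#   from datetime import datetime
#   from tradingagents.dataflows.interface import get_bulk_news
#   from tradingagents.agents.discovery import (
#       DiscoveryRequest, DiscoveryResult, DiscoveryStatus,
#       extract_entities, calculate_trending_scores, save_discovery_result,
#   )
#
#   request = DiscoveryRequest(lookback_period="24h", max_results=10)
#   articles = get_bulk_news(request.lookback_period)
#   mentions = extract_entities(articles)
#   stocks = calculate_trending_scores(mentions, articles)
#   result = DiscoveryResult(request=request, trending_stocks=stocks,
#                            status=DiscoveryStatus.COMPLETED,
#                            started_at=datetime.now(), completed_at=datetime.now())
#   save_discovery_result(result)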

View File

@ -0,0 +1,159 @@
from dataclasses import dataclass, field
from typing import List, Optional
from pydantic import BaseModel, Field as PydanticField
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.discovery.models import NewsArticle, EventCategory
BATCH_SIZE = 10
@dataclass
class EntityMention:
company_name: str
confidence: float
context_snippet: str
article_id: str
event_type: EventCategory
sentiment: float = field(default=0.0)
class ExtractedEntity(BaseModel):
company_name: str = PydanticField(description="The name of the publicly traded company mentioned")
confidence: float = PydanticField(description="Confidence score from 0.0 to 1.0 based on mention clarity")
context_snippet: str = PydanticField(description="Surrounding context of 50-100 characters around the company mention")
event_type: str = PydanticField(description="Event category: earnings, merger_acquisition, regulatory, product_launch, executive_change, or other")
sentiment: float = PydanticField(default=0.0, description="Sentiment score from -1.0 (negative) to 1.0 (positive)")
article_id: str = PydanticField(default="", description="ID of the article the mention appears in (e.g., article_0)")
class ExtractionResponse(BaseModel):
entities: List[ExtractedEntity] = PydanticField(default_factory=list, description="List of extracted company entities")
def _get_llm(config: Optional[dict] = None):
cfg = config or DEFAULT_CONFIG
provider = cfg.get("llm_provider", "openai").lower()
model = cfg.get("quick_think_llm", "gpt-4o-mini")
backend_url = cfg.get("backend_url", "https://api.openai.com/v1")
if provider in ("openai", "ollama", "openrouter"):
return ChatOpenAI(model=model, base_url=backend_url)
elif provider == "anthropic":
return ChatAnthropic(model=model, base_url=backend_url)
elif provider == "google":
return ChatGoogleGenerativeAI(model=model)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
EXTRACTION_PROMPT = """You are an expert at identifying publicly traded companies mentioned in news articles.
For each article provided, extract all mentions of publicly traded companies. For each company mention:
1. Extract the company name as it appears (e.g., "Apple Inc.", "Apple", "AAPL", "the iPhone maker")
2. Assign a confidence score from 0.0 to 1.0 based on how clearly the company is mentioned:
- 0.9-1.0: Direct company name or ticker symbol
- 0.7-0.9: Clear reference with context (e.g., "the Cupertino tech giant")
- 0.5-0.7: Indirect reference requiring inference
- Below 0.5: Uncertain or ambiguous reference
3. Extract 50-100 characters of surrounding context
4. Classify the event type:
- earnings: Quarterly/annual earnings reports, revenue announcements
- merger_acquisition: Mergers, acquisitions, buyouts, takeovers
- regulatory: SEC filings, government investigations, compliance issues
- product_launch: New products, services, or features
- executive_change: CEO/CFO changes, board appointments, departures
- other: Any other business news
5. Assign a sentiment score from -1.0 to 1.0:
- -1.0: Very negative news (lawsuits, crashes, major failures)
- -0.5: Moderately negative news
- 0.0: Neutral news
- 0.5: Moderately positive news
- 1.0: Very positive news (breakthroughs, record earnings)
Only extract companies that are publicly traded on major stock exchanges.
Handle name variations by providing the most complete company name found.
For each mention, report the article ID (e.g., article_0) of the article it appears in.
Articles to analyze:
{articles_text}
Extract all company mentions from the articles above."""
def _format_articles_for_prompt(articles: List[NewsArticle], start_idx: int) -> str:
formatted = []
for i, article in enumerate(articles):
article_id = f"article_{start_idx + i}"
formatted.append(
f"[{article_id}]\n"
f"Title: {article.title}\n"
f"Source: {article.source}\n"
f"Content: {article.content_snippet}\n"
)
return "\n---\n".join(formatted)
def _extract_batch(
articles: List[NewsArticle],
start_idx: int,
llm,
) -> List[EntityMention]:
if not articles:
return []
articles_text = _format_articles_for_prompt(articles, start_idx)
prompt = EXTRACTION_PROMPT.format(articles_text=articles_text)
structured_llm = llm.with_structured_output(ExtractionResponse)
response = structured_llm.invoke(prompt)
mentions = []
for entity in response.entities:
event_type_str = entity.event_type.lower().strip()
valid_event_types = {e.value for e in EventCategory}
if event_type_str not in valid_event_types:
event_type_str = "other"
confidence = max(0.0, min(1.0, entity.confidence))
sentiment = max(-1.0, min(1.0, entity.sentiment))
context = entity.context_snippet
if len(context) > 150:
context = context[:147] + "..."
mention = EntityMention(
company_name=entity.company_name,
confidence=confidence,
context_snippet=context,
article_id=f"article_{start_idx}",
event_type=EventCategory(event_type_str),
sentiment=sentiment,
)
mentions.append(mention)
return mentions
def extract_entities(
articles: List[NewsArticle],
config: Optional[dict] = None,
) -> List[EntityMention]:
if not articles:
return []
llm = _get_llm(config)
all_mentions: List[EntityMention] = []
for batch_start in range(0, len(articles), BATCH_SIZE):
batch_end = min(batch_start + BATCH_SIZE, len(articles))
batch = articles[batch_start:batch_end]
batch_mentions = _extract_batch(batch, batch_start, llm)
all_mentions.extend(batch_mentions)
return all_mentions
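# Illustrative note: articles are processed in batches of BATCH_SIZE (10), each batch
# going through `llm.with_structured_output(ExtractionResponse)`, so 25 articles issue
# 3 LLM calls. A minimal sketch, assuming a provider is configured in DEFAULT_CONFIG:
#
#   mentions = extract_entities(articles)  # uses DEFAULT_CONFIG
#   mentions = extract_entities(articles, config={"llm_provider": "openai",
#                                                 "quick_think_llm": "gpt-4o-mini",
#                                                 "backend_url": "https://api.openai.com/v1"})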

View File

@ -0,0 +1,14 @@
class DiscoveryError(Exception):
pass
class NewsUnavailableError(DiscoveryError):
pass
class DiscoveryTimeoutError(DiscoveryError):
pass
class TickerResolutionError(DiscoveryError):
pass

View File

@ -0,0 +1,180 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import List, Optional, Dict, Any
class DiscoveryStatus(Enum):
CREATED = "created"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
class Sector(Enum):
TECHNOLOGY = "technology"
HEALTHCARE = "healthcare"
FINANCE = "finance"
ENERGY = "energy"
CONSUMER_GOODS = "consumer_goods"
INDUSTRIALS = "industrials"
OTHER = "other"
class EventCategory(Enum):
EARNINGS = "earnings"
MERGER_ACQUISITION = "merger_acquisition"
REGULATORY = "regulatory"
PRODUCT_LAUNCH = "product_launch"
EXECUTIVE_CHANGE = "executive_change"
OTHER = "other"
@dataclass
class NewsArticle:
title: str
source: str
url: str
published_at: datetime
content_snippet: str
ticker_mentions: List[str]
def to_dict(self) -> Dict[str, Any]:
return {
"title": self.title,
"source": self.source,
"url": self.url,
"published_at": self.published_at.isoformat(),
"content_snippet": self.content_snippet,
"ticker_mentions": self.ticker_mentions,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "NewsArticle":
return cls(
title=data["title"],
source=data["source"],
url=data["url"],
published_at=datetime.fromisoformat(data["published_at"]),
content_snippet=data["content_snippet"],
ticker_mentions=data["ticker_mentions"],
)
@dataclass
class TrendingStock:
ticker: str
company_name: str
score: float
mention_count: int
sentiment: float
sector: Sector
event_type: EventCategory
news_summary: str
source_articles: List[NewsArticle]
def to_dict(self) -> Dict[str, Any]:
return {
"ticker": self.ticker,
"company_name": self.company_name,
"score": self.score,
"mention_count": self.mention_count,
"sentiment": self.sentiment,
"sector": self.sector.value,
"event_type": self.event_type.value,
"news_summary": self.news_summary,
"source_articles": [article.to_dict() for article in self.source_articles],
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "TrendingStock":
return cls(
ticker=data["ticker"],
company_name=data["company_name"],
score=data["score"],
mention_count=data["mention_count"],
sentiment=data["sentiment"],
sector=Sector(data["sector"]),
event_type=EventCategory(data["event_type"]),
news_summary=data["news_summary"],
source_articles=[
NewsArticle.from_dict(article) for article in data["source_articles"]
],
)
@dataclass
class DiscoveryRequest:
lookback_period: str
sector_filter: Optional[List[Sector]] = None
event_filter: Optional[List[EventCategory]] = None
max_results: int = 20
created_at: datetime = field(default_factory=datetime.now)
def to_dict(self) -> Dict[str, Any]:
return {
"lookback_period": self.lookback_period,
"sector_filter": (
[s.value for s in self.sector_filter] if self.sector_filter else None
),
"event_filter": (
[e.value for e in self.event_filter] if self.event_filter else None
),
"max_results": self.max_results,
"created_at": self.created_at.isoformat(),
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DiscoveryRequest":
return cls(
lookback_period=data["lookback_period"],
sector_filter=(
[Sector(s) for s in data["sector_filter"]]
if data.get("sector_filter")
else None
),
event_filter=(
[EventCategory(e) for e in data["event_filter"]]
if data.get("event_filter")
else None
),
max_results=data.get("max_results", 20),
created_at=datetime.fromisoformat(data["created_at"]),
)
@dataclass
class DiscoveryResult:
request: DiscoveryRequest
trending_stocks: List[TrendingStock]
status: DiscoveryStatus
started_at: datetime
completed_at: Optional[datetime] = None
error_message: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
return {
"request": self.request.to_dict(),
"trending_stocks": [stock.to_dict() for stock in self.trending_stocks],
"status": self.status.value,
"started_at": self.started_at.isoformat(),
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"error_message": self.error_message,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DiscoveryResult":
return cls(
request=DiscoveryRequest.from_dict(data["request"]),
trending_stocks=[
TrendingStock.from_dict(stock) for stock in data["trending_stocks"]
],
status=DiscoveryStatus(data["status"]),
started_at=datetime.fromisoformat(data["started_at"]),
completed_at=(
datetime.fromisoformat(data["completed_at"])
if data.get("completed_at")
else None
),
error_message=data.get("error_message"),
)
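# Round-trip sketch (illustrative values): every model above serializes to a
# JSON-friendly dict and back, which is what the persistence layer relies on.
#
#   import json
#   from datetime import datetime
#
#   article = NewsArticle(title="Acme beats estimates", source="Example Wire",
#                         url="https://example.com/acme", published_at=datetime.now(),
#                         content_snippet="Acme Corp reported...", ticker_mentions=["ACME"])
#   payload = json.dumps(article.to_dict())
#   restored = NewsArticle.from_dict(json.loads(payload))
#   assert restored.title == article.title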

View File

@ -0,0 +1,120 @@
import json
from datetime import datetime
from pathlib import Path
from typing import Optional
from .models import DiscoveryResult, TrendingStock
def save_discovery_result(
result: DiscoveryResult,
base_path: Optional[Path] = None,
) -> Path:
if base_path is None:
base_path = Path("results")
timestamp = result.completed_at or result.started_at
date_str = timestamp.strftime("%Y-%m-%d")
time_str = timestamp.strftime("%H-%M-%S")
result_dir = base_path / "discovery" / date_str / time_str
result_dir.mkdir(parents=True, exist_ok=True)
json_path = result_dir / "discovery_result.json"
with open(json_path, "w") as f:
json.dump(result.to_dict(), f, indent=2)
md_path = result_dir / "discovery_summary.md"
markdown_content = generate_markdown_summary(result)
with open(md_path, "w") as f:
f.write(markdown_content)
return result_dir
def generate_markdown_summary(result: DiscoveryResult) -> str:
lines = []
lines.append("# Discovery Results")
lines.append("")
timestamp = result.completed_at or result.started_at
lines.append(f"**Timestamp:** {timestamp.strftime('%Y-%m-%d %H:%M:%S')}")
lines.append(f"**Lookback Period:** {result.request.lookback_period}")
filters = _format_filters(result)
lines.append(f"**Filters:** {filters}")
lines.append(f"**Total Stocks Found:** {len(result.trending_stocks)}")
lines.append("")
lines.append("## Trending Stocks")
lines.append("")
lines.append("| Rank | Ticker | Company | Score | Mentions | Event |")
lines.append("|------|--------|---------|-------|----------|-------|")
for rank, stock in enumerate(result.trending_stocks, 1):
lines.append(
f"| {rank} | {stock.ticker} | {stock.company_name} | "
f"{stock.score:.2f} | {stock.mention_count} | {stock.event_type.value} |"
)
lines.append("")
lines.append("## Top 3 Detailed Analysis")
lines.append("")
top_stocks = result.trending_stocks[:3]
for rank, stock in enumerate(top_stocks, 1):
lines.extend(_format_stock_detail(rank, stock))
return "\n".join(lines)
def _format_filters(result: DiscoveryResult) -> str:
filter_parts = []
if result.request.sector_filter:
sector_values = [s.value for s in result.request.sector_filter]
filter_parts.append(f"sector={','.join(sector_values)}")
if result.request.event_filter:
event_values = [e.value for e in result.request.event_filter]
filter_parts.append(f"event={','.join(event_values)}")
if filter_parts:
return " ".join(filter_parts)
return "None"
def _format_stock_detail(rank: int, stock: TrendingStock) -> list:
lines = []
lines.append(f"### {rank}. {stock.ticker} - {stock.company_name}")
lines.append(f"- **Score:** {stock.score:.2f}")
sentiment_label = _get_sentiment_label(stock.sentiment)
lines.append(f"- **Sentiment:** {stock.sentiment:.2f} ({sentiment_label})")
lines.append(f"- **Sector:** {stock.sector.value}")
lines.append(f"- **Event Type:** {stock.event_type.value}")
lines.append(f"- **Mentions:** {stock.mention_count}")
lines.append("")
lines.append("**News Summary:**")
lines.append(stock.news_summary)
lines.append("")
if stock.source_articles:
lines.append("**Top Sources:**")
for article in stock.source_articles[:3]:
lines.append(f"- [{article.title}] - {article.source}")
lines.append("")
return lines
def _get_sentiment_label(sentiment: float) -> str:
if sentiment > 0.3:
return "positive"
elif sentiment < -0.3:
return "negative"
return "neutral"

View File

@ -0,0 +1,153 @@
import math
from collections import defaultdict
from datetime import datetime
from typing import List, Dict, Optional
from tradingagents.agents.discovery.models import (
TrendingStock,
NewsArticle,
Sector,
EventCategory,
)
from tradingagents.agents.discovery.entity_extractor import EntityMention
from tradingagents.dataflows.trending.stock_resolver import resolve_ticker
from tradingagents.dataflows.trending.sector_classifier import classify_sector
DEFAULT_DECAY_RATE = 0.1
DEFAULT_MAX_RESULTS = 20
DEFAULT_MIN_MENTIONS = 2
def _aggregate_sentiment(mentions: List[EntityMention]) -> float:
if not mentions:
return 0.0
total_weighted_sentiment = 0.0
total_confidence = 0.0
for mention in mentions:
total_weighted_sentiment += mention.sentiment * mention.confidence
total_confidence += mention.confidence
if total_confidence == 0:
return 0.0
return total_weighted_sentiment / total_confidence
def _calculate_recency_weight(
articles: List[NewsArticle],
article_ids: set,
decay_rate: float,
) -> float:
if not articles:
return 1.0
now = datetime.now()
weights = []
for i, article in enumerate(articles):
article_id = f"article_{i}"
if article_id in article_ids:
hours_old = (now - article.published_at).total_seconds() / 3600.0
weight = math.exp(-decay_rate * hours_old)
weights.append(weight)
if not weights:
return 1.0
return sum(weights) / len(weights)
def _get_most_common_event_type(mentions: List[EntityMention]) -> EventCategory:
if not mentions:
return EventCategory.OTHER
event_counts: Dict[EventCategory, int] = defaultdict(int)
for mention in mentions:
event_counts[mention.event_type] += 1
return max(event_counts.keys(), key=lambda e: event_counts[e])
def _build_news_summary(mentions: List[EntityMention]) -> str:
if not mentions:
return ""
snippets = [m.context_snippet for m in mentions[:3]]
return " ".join(snippets)
def calculate_trending_scores(
mentions: List[EntityMention],
articles: List[NewsArticle],
decay_rate: float = DEFAULT_DECAY_RATE,
max_results: int = DEFAULT_MAX_RESULTS,
min_mentions: int = DEFAULT_MIN_MENTIONS,
) -> List[TrendingStock]:
if not mentions:
return []
ticker_mentions: Dict[str, List[EntityMention]] = defaultdict(list)
ticker_company_names: Dict[str, str] = {}
for mention in mentions:
ticker = resolve_ticker(mention.company_name)
if ticker:
ticker_mentions[ticker].append(mention)
if ticker not in ticker_company_names:
ticker_company_names[ticker] = mention.company_name
article_index: Dict[str, int] = {}
for i, article in enumerate(articles):
article_index[f"article_{i}"] = i
trending_stocks: List[TrendingStock] = []
for ticker, ticker_mention_list in ticker_mentions.items():
article_ids = {m.article_id for m in ticker_mention_list}
frequency = len(article_ids)
if frequency < min_mentions:
continue
sentiment = _aggregate_sentiment(ticker_mention_list)
sentiment_factor = 1 + abs(sentiment)
recency_weight = _calculate_recency_weight(articles, article_ids, decay_rate)
score = frequency * sentiment_factor * recency_weight
sector_str = classify_sector(ticker)
try:
sector = Sector(sector_str)
except ValueError:
sector = Sector.OTHER
event_type = _get_most_common_event_type(ticker_mention_list)
source_article_list: List[NewsArticle] = []
for article_id in article_ids:
idx = article_index.get(article_id)
if idx is not None and idx < len(articles):
source_article_list.append(articles[idx])
news_summary = _build_news_summary(ticker_mention_list)
trending_stock = TrendingStock(
ticker=ticker,
company_name=ticker_company_names.get(ticker, ticker),
score=score,
mention_count=frequency,
sentiment=sentiment,
sector=sector,
event_type=event_type,
news_summary=news_summary,
source_articles=source_article_list,
)
trending_stocks.append(trending_stock)
trending_stocks.sort(key=lambda s: s.score, reverse=True)
return trending_stocks[:max_results]
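# Worked example (illustrative numbers): a ticker mentioned in 3 articles with an
# aggregated sentiment of -0.4, all roughly 5 hours old, scores
#
#   score = 3 * (1 + abs(-0.4)) * math.exp(-0.1 * 5)   # ~= 3 * 1.4 * 0.6065 ~= 2.55
#
# so mention frequency dominates, strong sentiment in either direction boosts the score,
# and older coverage decays it exponentially.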

View File

@ -7,44 +7,44 @@ from langgraph.prebuilt import ToolNode
from langgraph.graph import END, StateGraph, START, MessagesState
# Researcher team state
class InvestDebateState(TypedDict):
"""Researcher team state"""
bull_history: Annotated[
str, "Bullish Conversation history"
] # Bullish Conversation history
]
bear_history: Annotated[
str, "Bearish Conversation history"
] # Bullish Conversation history
history: Annotated[str, "Conversation history"] # Conversation history
current_response: Annotated[str, "Latest response"] # Last response
judge_decision: Annotated[str, "Final judge decision"] # Last response
count: Annotated[int, "Length of the current conversation"] # Conversation length
]
history: Annotated[str, "Conversation history"]
current_response: Annotated[str, "Latest response"]
judge_decision: Annotated[str, "Final judge decision"]
count: Annotated[int, "Length of the current conversation"]
# Risk management team state
class RiskDebateState(TypedDict):
"""Risk management team state"""
risky_history: Annotated[
str, "Risky Agent's Conversation history"
] # Conversation history
]
safe_history: Annotated[
str, "Safe Agent's Conversation history"
] # Conversation history
]
neutral_history: Annotated[
str, "Neutral Agent's Conversation history"
] # Conversation history
history: Annotated[str, "Conversation history"] # Conversation history
]
history: Annotated[str, "Conversation history"]
latest_speaker: Annotated[str, "Analyst that spoke last"]
current_risky_response: Annotated[
str, "Latest response by the risky analyst"
] # Last response
]
current_safe_response: Annotated[
str, "Latest response by the safe analyst"
] # Last response
]
current_neutral_response: Annotated[
str, "Latest response by the neutral analyst"
] # Last response
]
judge_decision: Annotated[str, "Judge's decision"]
count: Annotated[int, "Length of the current conversation"] # Conversation length
count: Annotated[int, "Length of the current conversation"]
class AgentState(MessagesState):
@ -53,7 +53,7 @@ class AgentState(MessagesState):
sender: Annotated[str, "Agent that sent this message"]
# research step
# research
market_report: Annotated[str, "Report from the Market Analyst"]
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
news_report: Annotated[
@ -61,7 +61,7 @@ class AgentState(MessagesState):
]
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
# researcher team discussion step
# research
investment_debate_state: Annotated[
InvestDebateState, "Current state of the debate on if to invest or not"
]
@ -69,7 +69,7 @@ class AgentState(MessagesState):
trader_investment_plan: Annotated[str, "Plan generated by the Trader"]
# risk management team discussion step
# risk mgmt
risk_debate_state: Annotated[
RiskDebateState, "Current state of the debate on evaluating risk"
]

View File

@ -1,6 +1,5 @@
from langchain_core.messages import HumanMessage, RemoveMessage
# Import tools from separate utility files
from tradingagents.agents.utils.core_stock_tools import (
get_stock_data
)
@ -24,16 +23,7 @@ def create_msg_delete():
def delete_messages(state):
"""Clear messages and add placeholder for Anthropic compatibility"""
messages = state["messages"]
# Remove all messages
removal_operations = [RemoveMessage(id=m.id) for m in messages]
# Add a minimal placeholder message
placeholder = HumanMessage(content="Continue")
return {"messages": removal_operations + [placeholder]}
return delete_messages

View File

@ -67,47 +67,45 @@ class FinancialSituationMemory:
return matched_results
if __name__ == "__main__":
# Example usage
matcher = FinancialSituationMemory()
# if __name__ == "__main__":
# # Example usage
# matcher = FinancialSituationMemory()
# example_data = [
# (
# "High inflation rate with rising interest rates and declining consumer spending",
# "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
# ),
# (
# "Tech sector showing high volatility with increasing institutional selling pressure",
# "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
# ),
# (
# "Strong dollar affecting emerging markets with increasing forex volatility",
# "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
# ),
# (
# "Market showing signs of sector rotation with rising yields",
# "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
# ),
# ]
# Example data
example_data = [
(
"High inflation rate with rising interest rates and declining consumer spending",
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
),
(
"Tech sector showing high volatility with increasing institutional selling pressure",
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
),
(
"Strong dollar affecting emerging markets with increasing forex volatility",
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
),
(
"Market showing signs of sector rotation with rising yields",
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
),
]
# # Add the example situations and recommendations
# matcher.add_situations(example_data)
# Add the example situations and recommendations
matcher.add_situations(example_data)
# # Example query
# current_situation = """
# Market showing increased volatility in tech sector, with institutional investors
# reducing positions and rising interest rates affecting growth stock valuations
# """
# Example query
current_situation = """
Market showing increased volatility in tech sector, with institutional investors
reducing positions and rising interest rates affecting growth stock valuations
"""
# try:
# recommendations = matcher.get_memories(current_situation, n_matches=2)
try:
recommendations = matcher.get_memories(current_situation, n_matches=2)
# for i, rec in enumerate(recommendations, 1):
# print(f"\nMatch {i}:")
# print(f"Similarity Score: {rec['similarity_score']:.2f}")
# print(f"Matched Situation: {rec['matched_situation']}")
# print(f"Recommendation: {rec['recommendation']}")
for i, rec in enumerate(recommendations, 1):
print(f"\nMatch {i}:")
print(f"Similarity Score: {rec['similarity_score']:.2f}")
print(f"Matched Situation: {rec['matched_situation']}")
print(f"Recommendation: {rec['recommendation']}")
except Exception as e:
print(f"Error during recommendation: {str(e)}")
# except Exception as e:
# print(f"Error during recommendation: {str(e)}")

View File

@ -1,19 +1,9 @@
import json
from datetime import datetime, timedelta
from typing import List, Dict, Any
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.
Covers stocks, cryptocurrencies, forex, and topics like fiscal policy, mergers & acquisitions, IPOs.
Args:
ticker: Stock symbol for news articles.
start_date: Start date for news search.
end_date: End date for news search.
Returns:
Dictionary containing news sentiment data or JSON string.
"""
params = {
"tickers": ticker,
"time_from": format_datetime_for_api(start_date),
@ -21,23 +11,63 @@ def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
"sort": "LATEST",
"limit": "50",
}
return _make_api_request("NEWS_SENTIMENT", params)
def get_insider_transactions(symbol: str) -> dict[str, str] | str:
"""Returns latest and historical insider transactions by key stakeholders.
Covers transactions by founders, executives, board members, etc.
Args:
symbol: Ticker symbol. Example: "IBM".
Returns:
Dictionary containing insider transaction data or JSON string.
"""
params = {
"symbol": symbol,
}
return _make_api_request("INSIDER_TRANSACTIONS", params)
return _make_api_request("INSIDER_TRANSACTIONS", params)
def get_bulk_news_alpha_vantage(lookback_hours: int) -> List[Dict[str, Any]]:
end_date = datetime.now()
start_date = end_date - timedelta(hours=lookback_hours)
params = {
"time_from": format_datetime_for_api(start_date),
"time_to": format_datetime_for_api(end_date),
"sort": "LATEST",
"limit": "200",
"topics": "financial_markets,earnings,economy_fiscal,economy_monetary,mergers_and_acquisitions",
}
response = _make_api_request("NEWS_SENTIMENT", params)
if isinstance(response, str):
try:
response = json.loads(response)
except json.JSONDecodeError:
return []
if not isinstance(response, dict):
return []
feed = response.get("feed", [])
articles = []
for item in feed:
try:
time_published = item.get("time_published", "")
if time_published:
try:
published_at = datetime.strptime(time_published, "%Y%m%dT%H%M%S")
except ValueError:
published_at = datetime.now()
else:
published_at = datetime.now()
article = {
"title": item.get("title", ""),
"source": item.get("source", ""),
"url": item.get("url", ""),
"published_at": published_at.isoformat(),
"content_snippet": item.get("summary", "")[:500],
}
articles.append(article)
except Exception:
continue
return articles
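# Illustrative shape of each returned item (example values, not real data), normalized
# for the bulk-news conversion in the dataflows interface:
#
#   {
#       "title": "Example Corp beats Q3 estimates",
#       "source": "Example Newswire",
#       "url": "https://example.com/story",
#       "published_at": "2025-01-02T14:30:05",
#       "content_snippet": "Example Corp reported ...",   # truncated to 500 chars
#   }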

View File

@ -1,5 +1,5 @@
from typing import Annotated
from datetime import datetime
from typing import Annotated, List, Dict, Any
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from .googlenews_utils import getNewsData
@ -27,4 +27,53 @@ def get_google_news(
if len(news_results) == 0:
return ""
return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"
return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"
def get_bulk_news_google(lookback_hours: int) -> List[Dict[str, Any]]:
end_date = datetime.now()
start_date = end_date - timedelta(hours=lookback_hours)
start_str = start_date.strftime("%Y-%m-%d")
end_str = end_date.strftime("%Y-%m-%d")
queries = [
"stock market",
"trading news",
"earnings report",
]
all_articles = []
seen_titles = set()
for query in queries:
try:
news_results = getNewsData(query.replace(" ", "+"), start_str, end_str)
for news in news_results:
title = news.get("title", "")
if title and title not in seen_titles:
seen_titles.add(title)
date_str = news.get("date", "")
try:
if date_str:
# Google News dates are often relative ("2 hours ago") or loosely formatted;
# attempt to parse them and fall back to the current time on failure.
from dateutil import parser as date_parser
published_at = date_parser.parse(date_str)
else:
published_at = datetime.now()
except (ValueError, OverflowError):
published_at = datetime.now()
article = {
"title": title,
"source": news.get("source", "Google News"),
"url": news.get("link", ""),
"published_at": published_at.isoformat(),
"content_snippet": news.get("snippet", "")[:500],
}
all_articles.append(article)
except Exception:
continue
return all_articles

View File

@ -1,10 +1,10 @@
from typing import Annotated
from typing import Annotated, List, Dict, Any, Optional
from datetime import datetime, timedelta
# Import from vendor-specific modules
from .local import get_YFin_data, get_finnhub_news, get_finnhub_company_insider_sentiment, get_finnhub_company_insider_transactions, get_simfin_balance_sheet, get_simfin_cashflow, get_simfin_income_statements, get_reddit_global_news, get_reddit_company_news
from .y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
from .google import get_google_news
from .openai import get_stock_news_openai, get_global_news_openai, get_fundamentals_openai
from .google import get_google_news, get_bulk_news_google
from .openai import get_stock_news_openai, get_global_news_openai, get_fundamentals_openai, get_bulk_news_openai
from .alpha_vantage import (
get_stock as get_alpha_vantage_stock,
get_indicator as get_alpha_vantage_indicator,
@ -15,12 +15,13 @@ from .alpha_vantage import (
get_insider_transactions as get_alpha_vantage_insider_transactions,
get_news as get_alpha_vantage_news
)
from .alpha_vantage_news import get_bulk_news_alpha_vantage
from .alpha_vantage_common import AlphaVantageRateLimitError
# Configuration and routing logic
from .config import get_config
# Tools organized by category
from tradingagents.agents.discovery import NewsArticle
TOOLS_CATEGORIES = {
"core_stock_apis": {
"description": "OHLCV stock price data",
@ -50,6 +51,7 @@ TOOLS_CATEGORIES = {
"get_global_news",
"get_insider_sentiment",
"get_insider_transactions",
"get_bulk_news",
]
}
}
@ -61,21 +63,17 @@ VENDOR_LIST = [
"google"
]
# Mapping of methods to their vendor-specific implementations
VENDOR_METHODS = {
# core_stock_apis
"get_stock_data": {
"alpha_vantage": get_alpha_vantage_stock,
"yfinance": get_YFin_data_online,
"local": get_YFin_data,
},
# technical_indicators
"get_indicators": {
"alpha_vantage": get_alpha_vantage_indicator,
"yfinance": get_stock_stats_indicators_window,
"local": get_stock_stats_indicators_window
},
# fundamental_data
"get_fundamentals": {
"alpha_vantage": get_alpha_vantage_fundamentals,
"openai": get_fundamentals_openai,
@ -95,7 +93,6 @@ VENDOR_METHODS = {
"yfinance": get_yfinance_income_statement,
"local": get_simfin_income_statements,
},
# news_data
"get_news": {
"alpha_vantage": get_alpha_vantage_news,
"openai": get_stock_news_openai,
@ -114,56 +111,159 @@ VENDOR_METHODS = {
"yfinance": get_yfinance_insider_transactions,
"local": get_finnhub_company_insider_transactions,
},
"get_bulk_news": {
"alpha_vantage": get_bulk_news_alpha_vantage,
"openai": get_bulk_news_openai,
"google": get_bulk_news_google,
},
}
CACHE_TTL_SECONDS = 300
_bulk_news_cache: Dict[str, Dict[str, Any]] = {}
def parse_lookback_period(lookback: str) -> int:
lookback = lookback.lower().strip()
if lookback == "1h":
return 1
elif lookback == "6h":
return 6
elif lookback == "24h":
return 24
elif lookback == "7d":
return 168
else:
raise ValueError(f"Invalid lookback period: {lookback}. Valid values: 1h, 6h, 24h, 7d")
def _get_cached_bulk_news(lookback_period: str) -> Optional[List[NewsArticle]]:
cache_key = lookback_period
if cache_key in _bulk_news_cache:
cached = _bulk_news_cache[cache_key]
cached_time = cached.get("timestamp")
if cached_time and (datetime.now() - cached_time).total_seconds() < CACHE_TTL_SECONDS:
return cached.get("articles")
return None
def _set_cached_bulk_news(lookback_period: str, articles: List[NewsArticle]) -> None:
cache_key = lookback_period
_bulk_news_cache[cache_key] = {
"timestamp": datetime.now(),
"articles": articles,
}
def _convert_to_news_articles(raw_articles: List[Dict[str, Any]]) -> List[NewsArticle]:
articles = []
for item in raw_articles:
try:
published_at_str = item.get("published_at", "")
if isinstance(published_at_str, str):
try:
published_at = datetime.fromisoformat(published_at_str.replace("Z", "+00:00"))
except ValueError:
published_at = datetime.now()
elif isinstance(published_at_str, datetime):
published_at = published_at_str
else:
published_at = datetime.now()
article = NewsArticle(
title=item.get("title", ""),
source=item.get("source", ""),
url=item.get("url", ""),
published_at=published_at,
content_snippet=item.get("content_snippet", ""),
ticker_mentions=[],
)
articles.append(article)
except Exception:
continue
return articles
def _fetch_bulk_news_from_vendor(lookback_period: str) -> List[Dict[str, Any]]:
lookback_hours = parse_lookback_period(lookback_period)
vendor_order = ["alpha_vantage", "openai", "google"]
for vendor in vendor_order:
if vendor not in VENDOR_METHODS["get_bulk_news"]:
continue
vendor_func = VENDOR_METHODS["get_bulk_news"][vendor]
try:
print(f"DEBUG: Attempting bulk news from vendor '{vendor}'...")
result = vendor_func(lookback_hours)
if result:
print(f"SUCCESS: Got {len(result)} articles from vendor '{vendor}'")
return result
print(f"DEBUG: Vendor '{vendor}' returned empty results, trying next...")
except AlphaVantageRateLimitError as e:
print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded: {e}")
continue
except Exception as e:
print(f"FAILED: Vendor '{vendor}' failed: {e}")
continue
return []
def get_bulk_news(lookback_period: str = "24h") -> List[NewsArticle]:
cached = _get_cached_bulk_news(lookback_period)
if cached is not None:
print(f"DEBUG: Returning cached bulk news for period '{lookback_period}'")
return cached
raw_articles = _fetch_bulk_news_from_vendor(lookback_period)
articles = _convert_to_news_articles(raw_articles)
_set_cached_bulk_news(lookback_period, articles)
return articles
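# Illustrative usage: results are memoized per lookback period for CACHE_TTL_SECONDS
# (5 minutes), and vendors are tried in the hardcoded order
# alpha_vantage -> openai -> google until one returns articles.
#
#   articles = get_bulk_news("24h")   # hits the first vendor that responds
#   articles = get_bulk_news("24h")   # served from the in-process cache for ~5 minutes
#   get_bulk_news("2d")               # raises ValueError: only 1h, 6h, 24h, 7d are accepted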
def get_category_for_method(method: str) -> str:
"""Get the category that contains the specified method."""
for category, info in TOOLS_CATEGORIES.items():
if method in info["tools"]:
return category
raise ValueError(f"Method '{method}' not found in any category")
def get_vendor(category: str, method: str = None) -> str:
"""Get the configured vendor for a data category or specific tool method.
Tool-level configuration takes precedence over category-level.
"""
config = get_config()
# Check tool-level configuration first (if method provided)
if method:
tool_vendors = config.get("tool_vendors", {})
if method in tool_vendors:
return tool_vendors[method]
# Fall back to category-level configuration
return config.get("data_vendors", {}).get(category, "default")
def route_to_vendor(method: str, *args, **kwargs):
"""Route method calls to appropriate vendor implementation with fallback support."""
category = get_category_for_method(method)
vendor_config = get_vendor(category, method)
# Handle comma-separated vendors
primary_vendors = [v.strip() for v in vendor_config.split(',')]
if method not in VENDOR_METHODS:
raise ValueError(f"Method '{method}' not supported")
# Get all available vendors for this method for fallback
all_available_vendors = list(VENDOR_METHODS[method].keys())
# Create fallback vendor list: primary vendors first, then remaining vendors as fallbacks
fallback_vendors = primary_vendors.copy()
for vendor in all_available_vendors:
if vendor not in fallback_vendors:
fallback_vendors.append(vendor)
# Debug: Print fallback ordering
primary_str = "".join(primary_vendors)
fallback_str = "".join(fallback_vendors)
primary_str = " -> ".join(primary_vendors)
fallback_str = " -> ".join(fallback_vendors)
print(f"DEBUG: {method} - Primary: [{primary_str}] | Full fallback order: [{fallback_str}]")
# Track results and execution state
results = []
vendor_attempt_count = 0
any_primary_vendor_attempted = False
@ -179,22 +279,18 @@ def route_to_vendor(method: str, *args, **kwargs):
is_primary_vendor = vendor in primary_vendors
vendor_attempt_count += 1
# Track if we attempted any primary vendor
if is_primary_vendor:
any_primary_vendor_attempted = True
# Debug: Print current attempt
vendor_type = "PRIMARY" if is_primary_vendor else "FALLBACK"
print(f"DEBUG: Attempting {vendor_type} vendor '{vendor}' for {method} (attempt #{vendor_attempt_count})")
# Handle list of methods for a vendor
if isinstance(vendor_impl, list):
vendor_methods = [(impl, vendor) for impl in vendor_impl]
print(f"DEBUG: Vendor '{vendor}' has multiple implementations: {len(vendor_methods)} functions")
else:
vendor_methods = [(vendor_impl, vendor)]
# Run methods for this vendor
vendor_results = []
for impl_func, vendor_name in vendor_methods:
try:
@ -202,43 +298,35 @@ def route_to_vendor(method: str, *args, **kwargs):
result = impl_func(*args, **kwargs)
vendor_results.append(result)
print(f"SUCCESS: {impl_func.__name__} from vendor '{vendor_name}' completed successfully")
except AlphaVantageRateLimitError as e:
if vendor == "alpha_vantage":
print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded, falling back to next available vendor")
print(f"DEBUG: Rate limit details: {e}")
# Continue to next vendor for fallback
continue
except Exception as e:
# Log error but continue with other implementations
print(f"FAILED: {impl_func.__name__} from vendor '{vendor_name}' failed: {e}")
continue
# Add this vendor's results
if vendor_results:
results.extend(vendor_results)
successful_vendor = vendor
result_summary = f"Got {len(vendor_results)} result(s)"
print(f"SUCCESS: Vendor '{vendor}' succeeded - {result_summary}")
# Stopping logic: Stop after first successful vendor for single-vendor configs
# Multiple vendor configs (comma-separated) may want to collect from multiple sources
if len(primary_vendors) == 1:
print(f"DEBUG: Stopping after successful vendor '{vendor}' (single-vendor config)")
break
else:
print(f"FAILED: Vendor '{vendor}' produced no results")
# Final result summary
if not results:
print(f"FAILURE: All {vendor_attempt_count} vendor attempts failed for method '{method}'")
raise RuntimeError(f"All vendor implementations failed for method '{method}'")
else:
print(f"FINAL: Method '{method}' completed with {len(results)} result(s) from {vendor_attempt_count} vendor attempt(s)")
# Return single result if only one, otherwise concatenate as string
if len(results) == 1:
return results[0]
else:
# Convert all results to strings and concatenate
return '\n'.join(str(result) for result in results)

View File

@ -1,3 +1,7 @@
import json
import re
from datetime import datetime, timedelta
from typing import List, Dict, Any
from openai import OpenAI
from .config import get_config
@ -104,4 +108,91 @@ def get_fundamentals_openai(ticker, curr_date):
store=True,
)
return response.output[1].content[0].text
def get_bulk_news_openai(lookback_hours: int) -> List[Dict[str, Any]]:
config = get_config()
client = OpenAI(base_url=config["backend_url"])
end_date = datetime.now()
start_date = end_date - timedelta(hours=lookback_hours)
start_str = start_date.strftime("%Y-%m-%d %H:%M")
end_str = end_date.strftime("%Y-%m-%d %H:%M")
prompt = f"""Search for recent stock market news, trading news, and earnings announcements from {start_str} to {end_str}.
Return the results as a JSON array with the following structure:
[
{{
"title": "Article title",
"source": "Source name",
"url": "https://...",
"published_at": "YYYY-MM-DDTHH:MM:SS",
"content_snippet": "Brief summary of the article..."
}}
]
Focus on:
- Stock market movements and trends
- Company earnings reports
- Mergers and acquisitions
- Significant trading activity
- Economic news affecting markets
Return ONLY the JSON array, no additional text."""
response = client.responses.create(
model=config["quick_think_llm"],
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": prompt,
}
],
}
],
text={"format": {"type": "text"}},
reasoning={},
tools=[
{
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "medium",
}
],
temperature=0.5,
max_output_tokens=8192,
top_p=1,
store=True,
)
try:
response_text = response.output[1].content[0].text
json_match = re.search(r'\[[\s\S]*\]', response_text)
if json_match:
articles = json.loads(json_match.group())
else:
articles = json.loads(response_text)
result = []
for item in articles:
if isinstance(item, dict):
article = {
"title": item.get("title", ""),
"source": item.get("source", "Web Search"),
"url": item.get("url", ""),
"published_at": item.get("published_at", datetime.now().isoformat()),
"content_snippet": item.get("content_snippet", "")[:500],
}
result.append(article)
return result
except (json.JSONDecodeError, IndexError, AttributeError):
return []

View File

@ -0,0 +1,21 @@
from .stock_resolver import (
resolve_ticker,
validate_tradeable,
validate_us_ticker,
COMPANY_TO_TICKER,
)
from .sector_classifier import (
classify_sector,
TICKER_TO_SECTOR,
VALID_SECTORS,
)
__all__ = [
"resolve_ticker",
"validate_tradeable",
"validate_us_ticker",
"COMPANY_TO_TICKER",
"classify_sector",
"TICKER_TO_SECTOR",
"VALID_SECTORS",
]

View File

@ -0,0 +1,267 @@
import logging
from typing import Dict
logger = logging.getLogger(__name__)
VALID_SECTORS = {
"technology",
"healthcare",
"finance",
"energy",
"consumer_goods",
"industrials",
"other",
}
TICKER_TO_SECTOR: Dict[str, str] = {
"AAPL": "technology",
"MSFT": "technology",
"GOOGL": "technology",
"GOOG": "technology",
"AMZN": "technology",
"META": "technology",
"NVDA": "technology",
"TSLA": "technology",
"AMD": "technology",
"INTC": "technology",
"QCOM": "technology",
"AVGO": "technology",
"TXN": "technology",
"ADBE": "technology",
"CRM": "technology",
"CSCO": "technology",
"NFLX": "technology",
"ORCL": "technology",
"IBM": "technology",
"NOW": "technology",
"INTU": "technology",
"ADSK": "technology",
"SNPS": "technology",
"CDNS": "technology",
"PLTR": "technology",
"SNOW": "technology",
"DDOG": "technology",
"CRWD": "technology",
"OKTA": "technology",
"NET": "technology",
"MDB": "technology",
"TWLO": "technology",
"WDAY": "technology",
"SPLK": "technology",
"VMW": "technology",
"HPQ": "technology",
"DELL": "technology",
"FTNT": "technology",
"PANW": "technology",
"ZS": "technology",
"S": "technology",
"VEEV": "technology",
"ZM": "technology",
"DOCU": "technology",
"ASAN": "technology",
"MNDY": "technology",
"TEAM": "technology",
"ANSS": "technology",
"ROP": "technology",
"JPM": "finance",
"BAC": "finance",
"WFC": "finance",
"GS": "finance",
"MS": "finance",
"C": "finance",
"BLK": "finance",
"SCHW": "finance",
"AXP": "finance",
"V": "finance",
"MA": "finance",
"PYPL": "finance",
"SQ": "finance",
"COIN": "finance",
"HOOD": "finance",
"SOFI": "finance",
"AFRM": "finance",
"MQ": "finance",
"BRK-B": "finance",
"BRK-A": "finance",
"JNJ": "healthcare",
"UNH": "healthcare",
"PFE": "healthcare",
"ABBV": "healthcare",
"MRK": "healthcare",
"LLY": "healthcare",
"MRNA": "healthcare",
"BNTX": "healthcare",
"CVS": "healthcare",
"WBA": "healthcare",
"MCK": "healthcare",
"CAH": "healthcare",
"HUM": "healthcare",
"CI": "healthcare",
"ELV": "healthcare",
"XOM": "energy",
"CVX": "energy",
"COP": "energy",
"SLB": "energy",
"HAL": "energy",
"BKR": "energy",
"MPC": "energy",
"VLO": "energy",
"PSX": "energy",
"OXY": "energy",
"PXD": "energy",
"DVN": "energy",
"CEG": "energy",
"NEE": "energy",
"DUK": "energy",
"SO": "energy",
"D": "energy",
"SRE": "energy",
"WMT": "consumer_goods",
"COST": "consumer_goods",
"TGT": "consumer_goods",
"HD": "consumer_goods",
"LOW": "consumer_goods",
"PG": "consumer_goods",
"KO": "consumer_goods",
"PEP": "consumer_goods",
"NKE": "consumer_goods",
"SBUX": "consumer_goods",
"MCD": "consumer_goods",
"CMG": "consumer_goods",
"YUM": "consumer_goods",
"DPZ": "consumer_goods",
"DIS": "consumer_goods",
"CMCSA": "consumer_goods",
"VZ": "consumer_goods",
"T": "consumer_goods",
"TMUS": "consumer_goods",
"EL": "consumer_goods",
"CL": "consumer_goods",
"KMB": "consumer_goods",
"CLX": "consumer_goods",
"KHC": "consumer_goods",
"GIS": "consumer_goods",
"K": "consumer_goods",
"MDLZ": "consumer_goods",
"HSY": "consumer_goods",
"TSN": "consumer_goods",
"BYND": "consumer_goods",
"CAG": "consumer_goods",
"STZ": "consumer_goods",
"BUD": "consumer_goods",
"DEO": "consumer_goods",
"PM": "consumer_goods",
"MO": "consumer_goods",
"LULU": "consumer_goods",
"DG": "consumer_goods",
"DLTR": "consumer_goods",
"ROST": "consumer_goods",
"TJX": "consumer_goods",
"AZO": "consumer_goods",
"ORLY": "consumer_goods",
"KMX": "consumer_goods",
"ADDYY": "consumer_goods",
"UBER": "consumer_goods",
"LYFT": "consumer_goods",
"ABNB": "consumer_goods",
"DASH": "consumer_goods",
"SNAP": "consumer_goods",
"PINS": "consumer_goods",
"TWTR": "consumer_goods",
"SHOP": "consumer_goods",
"TOST": "consumer_goods",
"BA": "industrials",
"LMT": "industrials",
"RTX": "industrials",
"GD": "industrials",
"NOC": "industrials",
"GE": "industrials",
"HON": "industrials",
"MMM": "industrials",
"CAT": "industrials",
"DE": "industrials",
"UNP": "industrials",
"UPS": "industrials",
"FDX": "industrials",
"DAL": "industrials",
"UAL": "industrials",
"AAL": "industrials",
"LUV": "industrials",
"F": "industrials",
"GM": "industrials",
"TM": "industrials",
"HMC": "industrials",
"VWAGY": "industrials",
"RACE": "industrials",
"RIVN": "industrials",
"LCID": "industrials",
"NIO": "industrials",
"LNVGY": "industrials",
}
_sector_cache: Dict[str, str] = {}
def _llm_classify_sector(ticker: str) -> str:
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from tradingagents.default_config import DEFAULT_CONFIG
llm_name = DEFAULT_CONFIG.get("quick_think_llm", "gpt-4o-mini")
llm_provider = DEFAULT_CONFIG.get("llm_provider", "openai")
backend_url = DEFAULT_CONFIG.get("backend_url", "https://api.openai.com/v1")
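# NOTE: this fallback always routes through the OpenAI-compatible ChatOpenAI client at backend_url, regardless of the configured llm_provider.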
llm = ChatOpenAI(
model=llm_name,
base_url=backend_url,
temperature=0,
)
system_prompt = (
"You are a financial sector classifier. Given a stock ticker symbol, "
"classify it into exactly one of the following sectors: "
"technology, healthcare, finance, energy, consumer_goods, industrials, other. "
"Respond with only the sector name in lowercase, nothing else."
)
user_prompt = f"Classify the stock ticker: {ticker}"
messages = [
SystemMessage(content=system_prompt),
HumanMessage(content=user_prompt),
]
response = llm.invoke(messages)
sector = response.content.strip().lower()
if sector not in VALID_SECTORS:
logger.warning(
"LLM returned invalid sector '%s' for ticker %s, defaulting to 'other'",
sector,
ticker,
)
return "other"
return sector
def classify_sector(ticker: str) -> str:
ticker_upper = ticker.upper()
if ticker_upper in TICKER_TO_SECTOR:
return TICKER_TO_SECTOR[ticker_upper]
if ticker_upper in _sector_cache:
return _sector_cache[ticker_upper]
logger.info("Using LLM fallback for sector classification of ticker: %s", ticker)
try:
sector = _llm_classify_sector(ticker_upper)
_sector_cache[ticker_upper] = sector
logger.info("Classified %s as %s via LLM", ticker, sector)
return sector
except Exception as e:
logger.error("LLM sector classification failed for %s: %s", ticker, str(e))
_sector_cache[ticker_upper] = "other"
return "other"

View File

@ -0,0 +1,538 @@
import logging
import re
from typing import Optional
import yfinance as yf
logger = logging.getLogger(__name__)
COMPANY_TO_TICKER = {
"apple": "AAPL",
"apple inc": "AAPL",
"apple inc.": "AAPL",
"apple corporation": "AAPL",
"the iphone maker": "AAPL",
"iphone maker": "AAPL",
"microsoft": "MSFT",
"microsoft inc": "MSFT",
"microsoft inc.": "MSFT",
"microsoft corp": "MSFT",
"microsoft corp.": "MSFT",
"microsoft corporation": "MSFT",
"google": "GOOGL",
"alphabet": "GOOGL",
"alphabet inc": "GOOGL",
"alphabet inc.": "GOOGL",
"the search giant": "GOOGL",
"amazon": "AMZN",
"amazon inc": "AMZN",
"amazon inc.": "AMZN",
"amazon.com": "AMZN",
"amazon.com inc": "AMZN",
"the e-commerce giant": "AMZN",
"e-commerce giant": "AMZN",
"meta": "META",
"meta platforms": "META",
"meta platforms inc": "META",
"meta platforms inc.": "META",
"facebook": "META",
"facebook inc": "META",
"facebook inc.": "META",
"tesla": "TSLA",
"tesla inc": "TSLA",
"tesla inc.": "TSLA",
"tesla motors": "TSLA",
"ev maker tesla": "TSLA",
"nvidia": "NVDA",
"nvidia corp": "NVDA",
"nvidia corp.": "NVDA",
"nvidia corporation": "NVDA",
"berkshire hathaway": "BRK-B",
"berkshire": "BRK-B",
"jpmorgan": "JPM",
"jpmorgan chase": "JPM",
"jp morgan": "JPM",
"jp morgan chase": "JPM",
"johnson & johnson": "JNJ",
"johnson and johnson": "JNJ",
"j&j": "JNJ",
"unitedhealth": "UNH",
"unitedhealth group": "UNH",
"visa": "V",
"visa inc": "V",
"visa inc.": "V",
"procter & gamble": "PG",
"procter and gamble": "PG",
"p&g": "PG",
"mastercard": "MA",
"mastercard inc": "MA",
"mastercard inc.": "MA",
"home depot": "HD",
"the home depot": "HD",
"chevron": "CVX",
"chevron corp": "CVX",
"chevron corporation": "CVX",
"exxon": "XOM",
"exxon mobil": "XOM",
"exxonmobil": "XOM",
"pfizer": "PFE",
"pfizer inc": "PFE",
"pfizer inc.": "PFE",
"abbvie": "ABBV",
"abbvie inc": "ABBV",
"abbvie inc.": "ABBV",
"coca-cola": "KO",
"coca cola": "KO",
"coke": "KO",
"the coca-cola company": "KO",
"pepsico": "PEP",
"pepsi": "PEP",
"pepsi co": "PEP",
"costco": "COST",
"costco wholesale": "COST",
"walmart": "WMT",
"wal-mart": "WMT",
"walmart inc": "WMT",
"bank of america": "BAC",
"bofa": "BAC",
"merck": "MRK",
"merck & co": "MRK",
"merck and co": "MRK",
"eli lilly": "LLY",
"lilly": "LLY",
"eli lilly and company": "LLY",
"adobe": "ADBE",
"adobe inc": "ADBE",
"adobe inc.": "ADBE",
"adobe systems": "ADBE",
"salesforce": "CRM",
"salesforce inc": "CRM",
"salesforce.com": "CRM",
"cisco": "CSCO",
"cisco systems": "CSCO",
"cisco systems inc": "CSCO",
"netflix": "NFLX",
"netflix inc": "NFLX",
"netflix inc.": "NFLX",
"oracle": "ORCL",
"oracle corp": "ORCL",
"oracle corporation": "ORCL",
"intel": "INTC",
"intel corp": "INTC",
"intel corporation": "INTC",
"amd": "AMD",
"advanced micro devices": "AMD",
"qualcomm": "QCOM",
"qualcomm inc": "QCOM",
"qualcomm inc.": "QCOM",
"broadcom": "AVGO",
"broadcom inc": "AVGO",
"broadcom inc.": "AVGO",
"texas instruments": "TXN",
"ti": "TXN",
"disney": "DIS",
"walt disney": "DIS",
"the walt disney company": "DIS",
"walt disney company": "DIS",
"comcast": "CMCSA",
"comcast corp": "CMCSA",
"comcast corporation": "CMCSA",
"verizon": "VZ",
"verizon communications": "VZ",
"at&t": "T",
"att": "T",
"t-mobile": "TMUS",
"tmobile": "TMUS",
"t-mobile us": "TMUS",
"american express": "AXP",
"amex": "AXP",
"goldman sachs": "GS",
"goldman": "GS",
"morgan stanley": "MS",
"wells fargo": "WFC",
"wells": "WFC",
"citigroup": "C",
"citi": "C",
"citibank": "C",
"charles schwab": "SCHW",
"schwab": "SCHW",
"blackrock": "BLK",
"blackrock inc": "BLK",
"paypal": "PYPL",
"paypal holdings": "PYPL",
"paypal inc": "PYPL",
"square": "SQ",
"block": "SQ",
"block inc": "SQ",
"shopify": "SHOP",
"shopify inc": "SHOP",
"uber": "UBER",
"uber technologies": "UBER",
"lyft": "LYFT",
"lyft inc": "LYFT",
"airbnb": "ABNB",
"airbnb inc": "ABNB",
"doordash": "DASH",
"doordash inc": "DASH",
"snap": "SNAP",
"snap inc": "SNAP",
"snapchat": "SNAP",
"pinterest": "PINS",
"pinterest inc": "PINS",
"twitter": "TWTR",
"twitter inc": "TWTR",
"linkedin": "MSFT",
"zoom": "ZM",
"zoom video": "ZM",
"zoom video communications": "ZM",
"slack": "CRM",
"slack technologies": "CRM",
"palantir": "PLTR",
"palantir technologies": "PLTR",
"snowflake": "SNOW",
"snowflake inc": "SNOW",
"datadog": "DDOG",
"datadog inc": "DDOG",
"crowdstrike": "CRWD",
"crowdstrike holdings": "CRWD",
"okta": "OKTA",
"okta inc": "OKTA",
"cloudflare": "NET",
"cloudflare inc": "NET",
"mongodb": "MDB",
"mongodb inc": "MDB",
"twilio": "TWLO",
"twilio inc": "TWLO",
"servicenow": "NOW",
"servicenow inc": "NOW",
"workday": "WDAY",
"workday inc": "WDAY",
"splunk": "SPLK",
"splunk inc": "SPLK",
"vmware": "VMW",
"vmware inc": "VMW",
"ibm": "IBM",
"international business machines": "IBM",
"hp": "HPQ",
"hewlett-packard": "HPQ",
"hewlett packard": "HPQ",
"dell": "DELL",
"dell technologies": "DELL",
"lenovo": "LNVGY",
"boeing": "BA",
"boeing company": "BA",
"the boeing company": "BA",
"lockheed martin": "LMT",
"lockheed": "LMT",
"raytheon": "RTX",
"rtx": "RTX",
"general dynamics": "GD",
"northrop grumman": "NOC",
"northrop": "NOC",
"general electric": "GE",
"ge": "GE",
"honeywell": "HON",
"honeywell international": "HON",
"3m": "MMM",
"3m company": "MMM",
"caterpillar": "CAT",
"caterpillar inc": "CAT",
"deere": "DE",
"john deere": "DE",
"deere & company": "DE",
"union pacific": "UNP",
"ups": "UPS",
"united parcel service": "UPS",
"fedex": "FDX",
"federal express": "FDX",
"delta": "DAL",
"delta air lines": "DAL",
"delta airlines": "DAL",
"united airlines": "UAL",
"united": "UAL",
"american airlines": "AAL",
"southwest": "LUV",
"southwest airlines": "LUV",
"ford": "F",
"ford motor": "F",
"ford motor company": "F",
"general motors": "GM",
"gm": "GM",
"toyota": "TM",
"toyota motor": "TM",
"honda": "HMC",
"honda motor": "HMC",
"volkswagen": "VWAGY",
"vw": "VWAGY",
"ferrari": "RACE",
"rivian": "RIVN",
"rivian automotive": "RIVN",
"lucid": "LCID",
"lucid motors": "LCID",
"lucid group": "LCID",
"nio": "NIO",
"nio inc": "NIO",
"moderna": "MRNA",
"moderna inc": "MRNA",
"biontech": "BNTX",
"cvs": "CVS",
"cvs health": "CVS",
"walgreens": "WBA",
"walgreens boots alliance": "WBA",
"mckesson": "MCK",
"mckesson corp": "MCK",
"cardinal health": "CAH",
"humana": "HUM",
"humana inc": "HUM",
"cigna": "CI",
"cigna group": "CI",
"anthem": "ELV",
"elevance health": "ELV",
"starbucks": "SBUX",
"starbucks corp": "SBUX",
"starbucks corporation": "SBUX",
"mcdonalds": "MCD",
"mcdonald's": "MCD",
"chipotle": "CMG",
"chipotle mexican grill": "CMG",
"yum brands": "YUM",
"yum": "YUM",
"dominos": "DPZ",
"domino's": "DPZ",
"domino's pizza": "DPZ",
"nike": "NKE",
"nike inc": "NKE",
"adidas": "ADDYY",
"lululemon": "LULU",
"lululemon athletica": "LULU",
"target": "TGT",
"target corp": "TGT",
"target corporation": "TGT",
"dollar general": "DG",
"dollar tree": "DLTR",
"ross stores": "ROST",
"ross": "ROST",
"tjx": "TJX",
"tjx companies": "TJX",
"tj maxx": "TJX",
"lowes": "LOW",
"lowe's": "LOW",
"lowe's companies": "LOW",
"autozone": "AZO",
"o'reilly": "ORLY",
"o'reilly automotive": "ORLY",
"carmax": "KMX",
"estee lauder": "EL",
"colgate": "CL",
"colgate-palmolive": "CL",
"colgate palmolive": "CL",
"kimberly-clark": "KMB",
"kimberly clark": "KMB",
"clorox": "CLX",
"clorox company": "CLX",
"kraft heinz": "KHC",
"kraft": "KHC",
"heinz": "KHC",
"general mills": "GIS",
"kellogg": "K",
"kellogg's": "K",
"mondelez": "MDLZ",
"mondelez international": "MDLZ",
"hershey": "HSY",
"the hershey company": "HSY",
"tyson": "TSN",
"tyson foods": "TSN",
"beyond meat": "BYND",
"conagra": "CAG",
"conagra brands": "CAG",
"constellation brands": "STZ",
"anheuser-busch": "BUD",
"anheuser busch": "BUD",
"ab inbev": "BUD",
"diageo": "DEO",
"philip morris": "PM",
"philip morris international": "PM",
"altria": "MO",
"altria group": "MO",
"constellation energy": "CEG",
"nextera": "NEE",
"nextera energy": "NEE",
"duke energy": "DUK",
"southern company": "SO",
"dominion": "D",
"dominion energy": "D",
"sempra": "SRE",
"sempra energy": "SRE",
"conocophillips": "COP",
"conoco": "COP",
"schlumberger": "SLB",
"halliburton": "HAL",
"baker hughes": "BKR",
"marathon": "MPC",
"marathon petroleum": "MPC",
"valero": "VLO",
"valero energy": "VLO",
"phillips 66": "PSX",
"occidental": "OXY",
"occidental petroleum": "OXY",
"pioneer": "PXD",
"pioneer natural resources": "PXD",
"devon energy": "DVN",
"devon": "DVN",
"coinbase": "COIN",
"coinbase global": "COIN",
"robinhood": "HOOD",
"robinhood markets": "HOOD",
"sofi": "SOFI",
"sofi technologies": "SOFI",
"affirm": "AFRM",
"affirm holdings": "AFRM",
"marqeta": "MQ",
"toast": "TOST",
"toast inc": "TOST",
"docusign": "DOCU",
"docusign inc": "DOCU",
"asana": "ASAN",
"monday.com": "MNDY",
"monday": "MNDY",
"atlassian": "TEAM",
"atlassian corp": "TEAM",
"intuit": "INTU",
"intuit inc": "INTU",
"autodesk": "ADSK",
"autodesk inc": "ADSK",
"synopsys": "SNPS",
"cadence": "CDNS",
"cadence design": "CDNS",
"ansys": "ANSS",
"roper": "ROP",
"roper technologies": "ROP",
"fortinet": "FTNT",
"palo alto": "PANW",
"palo alto networks": "PANW",
"zscaler": "ZS",
"sentinelone": "S",
"veeva": "VEEV",
"veeva systems": "VEEV",
}
US_EXCHANGE_CODES = {
"NYQ",
"NMS",
"NGM",
"NCM",
"ASE",
"PCX",
"BTS",
"NYSE",
"NASDAQ",
"AMEX",
"NYS",
"NAS",
"NIM",
"NAQ",
}
SUFFIX_PATTERNS = [
r"\s+inc\.?$",
r"\s+corp\.?$",
r"\s+corporation$",
r"\s+co\.?$",
r"\s+company$",
r"\s+llc$",
r"\s+ltd\.?$",
r"\s+limited$",
r"\s+plc$",
r"\s+holdings?$",
r"\s+group$",
r"\s+technologies$",
r"\s+enterprises?$",
]
def _normalize_company_name(name: str) -> str:
normalized = name.lower().strip()
for pattern in SUFFIX_PATTERNS:
normalized = re.sub(pattern, "", normalized, flags=re.IGNORECASE)
normalized = normalized.strip()
return normalized
def _search_yfinance_ticker(company_name: str) -> Optional[str]:
try:
search_result = yf.Ticker(company_name)
info = search_result.info
if info and "symbol" in info:
return info["symbol"]
except Exception as e:
logger.debug("yfinance search failed for %s: %s", company_name, str(e))
try:
search = yf.Search(company_name, max_results=5)
if hasattr(search, "quotes") and search.quotes:
for quote in search.quotes:
if "symbol" in quote:
return quote["symbol"]
except Exception as e:
logger.debug("yfinance Search failed for %s: %s", company_name, str(e))
return None
def validate_us_ticker(ticker: str) -> bool:
try:
ticker_obj = yf.Ticker(ticker.upper())
info = ticker_obj.info
if not info:
logger.warning("Validation failed for %s: no info available", ticker)
return False
exchange = info.get("exchange", "")
if exchange in US_EXCHANGE_CODES:
return True
exchange_lower = exchange.lower()
if any(us_ex.lower() in exchange_lower for us_ex in ["nyse", "nasdaq", "amex", "nys", "nms", "ngm"]):
return True
logger.warning("Validation failed for %s: exchange %s is not a US exchange", ticker, exchange)
return False
except Exception as e:
logger.warning("Validation failed for %s: %s", ticker, str(e))
return False
def resolve_ticker(company_name: str) -> Optional[str]:
if not company_name or not company_name.strip():
return None
normalized = company_name.lower().strip()
if normalized in COMPANY_TO_TICKER:
return COMPANY_TO_TICKER[normalized]
normalized_stripped = _normalize_company_name(company_name)
if normalized_stripped in COMPANY_TO_TICKER:
return COMPANY_TO_TICKER[normalized_stripped]
if company_name.upper() in COMPANY_TO_TICKER.values():
if validate_us_ticker(company_name.upper()):
return company_name.upper()
logger.info("Using yfinance fallback for company: %s", company_name)
yf_ticker = _search_yfinance_ticker(company_name)
if yf_ticker:
if validate_us_ticker(yf_ticker):
logger.info("Resolved %s to %s via yfinance", company_name, yf_ticker)
return yf_ticker
else:
logger.warning("Ticker %s for %s failed US exchange validation", yf_ticker, company_name)
return None
logger.warning("Could not resolve ticker for company: %s", company_name)
return None
def validate_tradeable(ticker: str) -> bool:
return validate_us_ticker(ticker)
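A short sketch of the resolution flow defined above, again assuming a hypothetical import path inside the discovery package; both fallback branches hit the network via yfinance:

```python
# Hypothetical import path for illustration only.
from tradingagents.agents.discovery.ticker_resolver import resolve_ticker, validate_tradeable

resolve_ticker("the e-commerce giant")        # "AMZN" straight from COMPANY_TO_TICKER
resolve_ticker("Palantir Technologies Inc.")  # "PLTR" after suffix stripping matches the static map
resolve_ticker("Some Unknown Startup")        # falls back to yf.Search; may return None
validate_tradeable("AAPL")                    # True when Yahoo reports a US exchange code
```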

View File

@ -8,26 +8,24 @@ DEFAULT_CONFIG = {
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
),
# LLM settings
"llm_provider": "openai",
"deep_think_llm": "o4-mini",
"quick_think_llm": "gpt-4o-mini",
"deep_think_llm": "gpt-5",
"quick_think_llm": "gpt-5-mini",
"backend_url": "https://api.openai.com/v1",
# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,
"max_debate_rounds": 2,
"max_risk_discuss_rounds": 2,
"max_recur_limit": 100,
# Data vendor configuration
# Category-level configuration (default for all tools in category)
"data_vendors": {
"core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local
"technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local
"fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local
"news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
"core_stock_apis": "yfinance",
"technical_indicators": "yfinance",
"fundamental_data": "alpha_vantage",
"news_data": "alpha_vantage",
},
# Tool-level configuration (takes precedence over category-level)
"tool_vendors": {
# Example: "get_stock_data": "alpha_vantage", # Override category default
# Example: "get_news": "openai", # Override category default
},
"discovery_timeout": 60,
"discovery_hard_timeout": 120,
"discovery_cache_ttl": 300,
"discovery_max_results": 20,
"discovery_min_mentions": 2,
}
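The new discovery settings can be tuned per run by copying the defaults and passing the result into the graph, matching the existing config-override pattern; the values below are illustrative only:

```python
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph

config = DEFAULT_CONFIG.copy()
config["discovery_max_results"] = 10    # cap the trending list
config["discovery_min_mentions"] = 3    # require at least three article mentions
config["discovery_hard_timeout"] = 90   # abort discovery after 90 seconds

ta = TradingAgentsGraph(debug=False, config=config)
```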

View File

@ -1,9 +1,9 @@
# TradingAgents/graph/trading_graph.py
import os
import signal
import threading
from pathlib import Path
import json
from datetime import date
from datetime import date, datetime
from typing import Dict, Any, Tuple, List, Optional
from langchain_openai import ChatOpenAI
@ -22,7 +22,6 @@ from tradingagents.agents.utils.agent_states import (
)
from tradingagents.dataflows.config import set_config
# Import the new abstract tool methods from agent_utils
from tradingagents.agents.utils.agent_utils import (
get_stock_data,
get_indicators,
@ -36,6 +35,19 @@ from tradingagents.agents.utils.agent_utils import (
get_global_news
)
from tradingagents.agents.discovery import (
DiscoveryRequest,
DiscoveryResult,
DiscoveryStatus,
TrendingStock,
Sector,
EventCategory,
DiscoveryTimeoutError,
extract_entities,
calculate_trending_scores,
)
from tradingagents.dataflows.interface import get_bulk_news
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
@ -43,8 +55,15 @@ from .reflection import Reflector
from .signal_processing import SignalProcessor
class DiscoveryTimeoutException(Exception):
pass
def _timeout_handler(signum, frame):
raise DiscoveryTimeoutException("Discovery operation timed out")
class TradingAgentsGraph:
"""Main class that orchestrates the trading agents framework."""
def __init__(
self,
@ -52,26 +71,16 @@ class TradingAgentsGraph:
debug=False,
config: Dict[str, Any] = None,
):
"""Initialize the trading agents graph and components.
Args:
selected_analysts: List of analyst types to include
debug: Whether to run in debug mode
config: Configuration dictionary. If None, uses default config
"""
self.debug = debug
self.config = config or DEFAULT_CONFIG
# Update the interface's config
set_config(self.config)
# Create necessary directories
os.makedirs(
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
exist_ok=True,
)
# Initialize LLMs
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
@ -83,18 +92,13 @@ class TradingAgentsGraph:
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
else:
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
# Initialize memories
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
# Create tool nodes
self.tool_nodes = self._create_tool_nodes()
# Initialize components
self.conditional_logic = ConditionalLogic()
self.graph_setup = GraphSetup(
self.quick_thinking_llm,
@ -111,35 +115,26 @@ class TradingAgentsGraph:
self.propagator = Propagator()
self.reflector = Reflector(self.quick_thinking_llm)
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
# State tracking
self.curr_state = None
self.ticker = None
self.log_states_dict = {} # date to full state dict
# Set up the graph
self.log_states_dict = {}
self.graph = self.graph_setup.setup_graph(selected_analysts)
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
"""Create tool nodes for different data sources using abstract methods."""
return {
"market": ToolNode(
[
# Core stock data tools
get_stock_data,
# Technical indicators
get_indicators,
]
),
"social": ToolNode(
[
# News tools for social media analysis
get_news,
]
),
"news": ToolNode(
[
# News and insider information
get_news,
get_global_news,
get_insider_sentiment,
@ -148,7 +143,6 @@ class TradingAgentsGraph:
),
"fundamentals": ToolNode(
[
# Fundamental analysis tools
get_fundamentals,
get_balance_sheet,
get_cashflow,
@ -158,18 +152,13 @@ class TradingAgentsGraph:
}
def propagate(self, company_name, trade_date):
"""Run the trading agents graph for a company on a specific date."""
self.ticker = company_name
# Initialize state
init_agent_state = self.propagator.create_initial_state(
company_name, trade_date
)
args = self.propagator.get_graph_args()
if self.debug:
# Debug mode with tracing
trace = []
for chunk in self.graph.stream(init_agent_state, **args):
if len(chunk["messages"]) == 0:
@ -180,20 +169,14 @@ class TradingAgentsGraph:
final_state = trace[-1]
else:
# Standard mode without tracing
final_state = self.graph.invoke(init_agent_state, **args)
# Store current state for reflection
self.curr_state = final_state
# Log state
self._log_state(trade_date, final_state)
# Return decision and processed signal
return final_state, self.process_signal(final_state["final_trade_decision"])
def _log_state(self, trade_date, final_state):
"""Log the final state to a JSON file."""
self.log_states_dict[str(trade_date)] = {
"company_of_interest": final_state["company_of_interest"],
"trade_date": final_state["trade_date"],
@ -224,7 +207,6 @@ class TradingAgentsGraph:
"final_trade_decision": final_state["final_trade_decision"],
}
# Save to file
directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
directory.mkdir(parents=True, exist_ok=True)
@ -235,7 +217,6 @@ class TradingAgentsGraph:
json.dump(self.log_states_dict, f, indent=4)
def reflect_and_remember(self, returns_losses):
"""Reflect on decisions and update memory based on returns."""
self.reflector.reflect_bull_researcher(
self.curr_state, returns_losses, self.bull_memory
)
@ -253,5 +234,95 @@ class TradingAgentsGraph:
)
def process_signal(self, full_signal):
"""Process a signal to extract the core decision."""
return self.signal_processor.process_signal(full_signal)
def discover_trending(
self,
request: Optional[DiscoveryRequest] = None,
) -> DiscoveryResult:
if request is None:
request = DiscoveryRequest(
lookback_period="24h",
max_results=self.config.get("discovery_max_results", 20),
)
started_at = datetime.now()
result = DiscoveryResult(
request=request,
trending_stocks=[],
status=DiscoveryStatus.PROCESSING,
started_at=started_at,
)
hard_timeout = self.config.get("discovery_hard_timeout", 120)
discovery_result = {"stocks": [], "error": None}
def run_discovery():
try:
articles = get_bulk_news(request.lookback_period)
mentions = extract_entities(articles, self.config)
min_mentions = self.config.get("discovery_min_mentions", 2)
max_results = request.max_results or self.config.get("discovery_max_results", 20)
trending_stocks = calculate_trending_scores(
mentions,
articles,
max_results=max_results,
min_mentions=min_mentions,
)
discovery_result["stocks"] = trending_stocks
except Exception as e:
discovery_result["error"] = str(e)
# Use a daemon thread so a discovery run that outlives the hard timeout cannot block interpreter shutdown.
discovery_thread = threading.Thread(target=run_discovery, daemon=True)
discovery_thread.start()
discovery_thread.join(timeout=hard_timeout)
if discovery_thread.is_alive():
raise DiscoveryTimeoutError(
f"Discovery operation exceeded {hard_timeout} second timeout"
)
if discovery_result["error"]:
result.status = DiscoveryStatus.FAILED
result.error_message = discovery_result["error"]
result.completed_at = datetime.now()
return result
trending_stocks = discovery_result["stocks"]
if request.sector_filter:
sector_values = {s.value if isinstance(s, Sector) else s for s in request.sector_filter}
trending_stocks = [
stock for stock in trending_stocks
if stock.sector.value in sector_values or stock.sector in request.sector_filter
]
if request.event_filter:
event_values = {e.value if isinstance(e, EventCategory) else e for e in request.event_filter}
trending_stocks = [
stock for stock in trending_stocks
if stock.event_type.value in event_values or stock.event_type in request.event_filter
]
result.trending_stocks = trending_stocks
result.status = DiscoveryStatus.COMPLETED
result.completed_at = datetime.now()
return result
def analyze_trending(
self,
trending_stock: TrendingStock,
trade_date: Optional[date] = None,
) -> Tuple[Dict[str, Any], str]:
ticker = trending_stock.ticker
if trade_date is None:
trade_date = date.today()
return self.propagate(ticker, trade_date)
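Taken together, the two methods above compose into a discover-then-analyze flow. A minimal sketch, assuming default configuration and network access for the bulk-news and yfinance calls:

```python
from tradingagents.agents.discovery import DiscoveryRequest, DiscoveryStatus
from tradingagents.graph.trading_graph import TradingAgentsGraph

graph = TradingAgentsGraph(debug=False)

# Stage 1: bulk news -> entity extraction -> scored trending list (bounded by discovery_hard_timeout).
result = graph.discover_trending(DiscoveryRequest(lookback_period="24h", max_results=5))

if result.status == DiscoveryStatus.COMPLETED and result.trending_stocks:
    top = result.trending_stocks[0]
    # Stage 2: hand the top-ranked ticker to the full multi-agent pipeline via propagate().
    final_state, decision = graph.analyze_trending(top)
    print(top.ticker, decision)
else:
    print("Discovery failed:", result.error_message)
```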