Add trending stock discovery feature
Implement a multi-stage pipeline to discover trending stocks from news: - Entity extraction from news articles using LLM - Stock ticker resolution via Yahoo Finance - Sector classification and event categorization - Scoring algorithm based on mentions, sentiment, and recency - CLI integration with interactive stock selection and analysis flow - Persistence layer for saving discovery results - Comprehensive test suite for all discovery components Update README with uv-based installation instructions and remove emojis. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
13b826a31d
commit
3f6b1e9f39
|
|
@ -9,3 +9,13 @@ eval_results/
|
|||
eval_data/
|
||||
*.egg-info/
|
||||
.env
|
||||
.claude/
|
||||
.pytest_cache/
|
||||
.specify/
|
||||
specs/
|
||||
agent-os/
|
||||
*.local.md
|
||||
build/
|
||||
.mcp.json
|
||||
*.zip
|
||||
todos.md
|
||||
|
|
|
|||
25
README.md
25
README.md
|
|
@ -27,7 +27,7 @@
|
|||
|
||||
# TradingAgents: Multi-Agents LLM Financial Trading Framework
|
||||
|
||||
> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
|
||||
> **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
|
||||
>
|
||||
> So we decided to fully open-source the framework. Looking forward to building impactful projects with you!
|
||||
|
||||
|
|
@ -43,7 +43,7 @@
|
|||
|
||||
<div align="center">
|
||||
|
||||
🚀 [TradingAgents](#tradingagents-framework) | ⚡ [Installation & CLI](#installation-and-cli) | 🎬 [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | 📦 [Package Usage](#tradingagents-package) | 🤝 [Contributing](#contributing) | 📄 [Citation](#citation)
|
||||
[TradingAgents](#tradingagents-framework) | [Installation & CLI](#installation-and-cli) | [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | [Package Usage](#tradingagents-package) | [Contributing](#contributing) | [Citation](#citation)
|
||||
|
||||
</div>
|
||||
|
||||
|
|
@ -101,15 +101,10 @@ git clone https://github.com/TauricResearch/TradingAgents.git
|
|||
cd TradingAgents
|
||||
```
|
||||
|
||||
Create a virtual environment in any of your favorite environment managers:
|
||||
Sync virtual environment:
|
||||
```bash
|
||||
conda create -n tradingagents python=3.13
|
||||
conda activate tradingagents
|
||||
```
|
||||
|
||||
Install dependencies:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
uv sync
|
||||
uv source .venv/bin/activate
|
||||
```
|
||||
|
||||
### Required APIs
|
||||
|
|
@ -133,7 +128,7 @@ cp .env.example .env
|
|||
|
||||
You can also try out the CLI directly by running:
|
||||
```bash
|
||||
python -m cli.main
|
||||
uv run cli/main.py
|
||||
```
|
||||
You will see a screen where you can select your desired tickers, date, LLMs, research depth, etc.
|
||||
|
||||
|
|
@ -204,13 +199,9 @@ print(decision)
|
|||
|
||||
You can view the full list of configurations in `tradingagents/default_config.py`.
|
||||
|
||||
## Contributing
|
||||
## Source
|
||||
|
||||
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
|
||||
|
||||
## Citation
|
||||
|
||||
Please reference our work if you find *TradingAgents* provides you with some help :)
|
||||
Thanks to Yijia Xiao and Edward Sun and Di Luo and Wei Wang. Core agent implementation based on [TradingAgents: Multi-Agents LLM Financial Trading Framework](https://arxiv.org/abs/2412.20138)
|
||||
|
||||
```
|
||||
@misc{xiao2025tradingagentsmultiagentsllmfinancial,
|
||||
|
|
|
|||
1791
cli/main.py
1791
cli/main.py
File diff suppressed because it is too large
Load Diff
2
main.py
2
main.py
|
|
@ -8,7 +8,7 @@ load_dotenv()
|
|||
|
||||
# Create a custom config
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["deep_think_llm"] = "gpt-4o-mini" # Use a different model
|
||||
config["deep_think_llm"] = "gpt-5" # Use a different model
|
||||
config["quick_think_llm"] = "gpt-4o-mini" # Use a different model
|
||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,200 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from datetime import datetime, timedelta
|
||||
import signal
|
||||
|
||||
from tradingagents.agents.discovery import (
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
DiscoveryStatus,
|
||||
TrendingStock,
|
||||
NewsArticle,
|
||||
Sector,
|
||||
EventCategory,
|
||||
DiscoveryTimeoutError,
|
||||
)
|
||||
|
||||
|
||||
def create_mock_trending_stock(
|
||||
ticker: str = "AAPL",
|
||||
company_name: str = "Apple Inc.",
|
||||
score: float = 10.0,
|
||||
sector: Sector = Sector.TECHNOLOGY,
|
||||
event_type: EventCategory = EventCategory.EARNINGS,
|
||||
) -> TrendingStock:
|
||||
return TrendingStock(
|
||||
ticker=ticker,
|
||||
company_name=company_name,
|
||||
score=score,
|
||||
mention_count=5,
|
||||
sentiment=0.5,
|
||||
sector=sector,
|
||||
event_type=event_type,
|
||||
news_summary="Test news summary",
|
||||
source_articles=[],
|
||||
)
|
||||
|
||||
|
||||
def create_mock_news_article() -> NewsArticle:
|
||||
return NewsArticle(
|
||||
title="Test Article",
|
||||
source="Test Source",
|
||||
url="https://example.com/article",
|
||||
published_at=datetime.now(),
|
||||
content_snippet="Test content about Apple stock",
|
||||
ticker_mentions=["AAPL"],
|
||||
)
|
||||
|
||||
|
||||
class TestDiscoverTrendingReturnsDiscoveryResult:
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.extract_entities")
|
||||
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
||||
def test_discover_trending_returns_discovery_result(
|
||||
self, mock_scores, mock_extract, mock_bulk_news
|
||||
):
|
||||
mock_bulk_news.return_value = [create_mock_news_article()]
|
||||
mock_extract.return_value = []
|
||||
mock_scores.return_value = [create_mock_trending_stock()]
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
||||
graph = TradingAgentsGraph()
|
||||
graph.config = {
|
||||
"discovery_timeout": 60,
|
||||
"discovery_hard_timeout": 120,
|
||||
"discovery_cache_ttl": 300,
|
||||
"discovery_max_results": 20,
|
||||
"discovery_min_mentions": 2,
|
||||
}
|
||||
|
||||
result = graph.discover_trending()
|
||||
|
||||
assert isinstance(result, DiscoveryResult)
|
||||
assert result.status == DiscoveryStatus.COMPLETED
|
||||
assert len(result.trending_stocks) > 0
|
||||
|
||||
|
||||
class TestAnalyzeTrendingCallsPropagate:
|
||||
def test_analyze_trending_calls_propagate(self):
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
||||
graph = TradingAgentsGraph()
|
||||
graph.propagate = Mock(return_value=({"final_state": "test"}, "BUY"))
|
||||
|
||||
trending_stock = create_mock_trending_stock()
|
||||
|
||||
result = graph.analyze_trending(trending_stock)
|
||||
|
||||
graph.propagate.assert_called_once()
|
||||
call_args = graph.propagate.call_args
|
||||
assert call_args[0][0] == "AAPL"
|
||||
|
||||
|
||||
class TestSectorFilterParameter:
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.extract_entities")
|
||||
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
||||
def test_sector_filter_filters_results(
|
||||
self, mock_scores, mock_extract, mock_bulk_news
|
||||
):
|
||||
mock_bulk_news.return_value = [create_mock_news_article()]
|
||||
mock_extract.return_value = []
|
||||
mock_scores.return_value = [
|
||||
create_mock_trending_stock(ticker="AAPL", sector=Sector.TECHNOLOGY),
|
||||
create_mock_trending_stock(ticker="JPM", sector=Sector.FINANCE),
|
||||
create_mock_trending_stock(ticker="XOM", sector=Sector.ENERGY),
|
||||
]
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
||||
graph = TradingAgentsGraph()
|
||||
graph.config = {
|
||||
"discovery_timeout": 60,
|
||||
"discovery_hard_timeout": 120,
|
||||
"discovery_cache_ttl": 300,
|
||||
"discovery_max_results": 20,
|
||||
"discovery_min_mentions": 2,
|
||||
}
|
||||
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="24h",
|
||||
sector_filter=[Sector.TECHNOLOGY],
|
||||
)
|
||||
result = graph.discover_trending(request)
|
||||
|
||||
assert all(
|
||||
stock.sector == Sector.TECHNOLOGY for stock in result.trending_stocks
|
||||
)
|
||||
|
||||
|
||||
class TestEventFilterParameter:
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.extract_entities")
|
||||
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
||||
def test_event_filter_filters_results(
|
||||
self, mock_scores, mock_extract, mock_bulk_news
|
||||
):
|
||||
mock_bulk_news.return_value = [create_mock_news_article()]
|
||||
mock_extract.return_value = []
|
||||
mock_scores.return_value = [
|
||||
create_mock_trending_stock(ticker="AAPL", event_type=EventCategory.EARNINGS),
|
||||
create_mock_trending_stock(
|
||||
ticker="MSFT", event_type=EventCategory.PRODUCT_LAUNCH
|
||||
),
|
||||
create_mock_trending_stock(
|
||||
ticker="GOOGL", event_type=EventCategory.MERGER_ACQUISITION
|
||||
),
|
||||
]
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
||||
graph = TradingAgentsGraph()
|
||||
graph.config = {
|
||||
"discovery_timeout": 60,
|
||||
"discovery_hard_timeout": 120,
|
||||
"discovery_cache_ttl": 300,
|
||||
"discovery_max_results": 20,
|
||||
"discovery_min_mentions": 2,
|
||||
}
|
||||
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="24h",
|
||||
event_filter=[EventCategory.EARNINGS],
|
||||
)
|
||||
result = graph.discover_trending(request)
|
||||
|
||||
assert all(
|
||||
stock.event_type == EventCategory.EARNINGS
|
||||
for stock in result.trending_stocks
|
||||
)
|
||||
|
||||
|
||||
class TestTimeoutHandling:
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
def test_timeout_raises_discovery_timeout_error(self, mock_bulk_news):
|
||||
def slow_fetch(*args, **kwargs):
|
||||
import time
|
||||
time.sleep(0.5)
|
||||
return []
|
||||
|
||||
mock_bulk_news.side_effect = slow_fetch
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
||||
graph = TradingAgentsGraph()
|
||||
graph.config = {
|
||||
"discovery_timeout": 60,
|
||||
"discovery_hard_timeout": 0.1,
|
||||
"discovery_cache_ttl": 300,
|
||||
"discovery_max_results": 20,
|
||||
"discovery_min_mentions": 2,
|
||||
}
|
||||
|
||||
with pytest.raises(DiscoveryTimeoutError):
|
||||
graph.discover_trending()
|
||||
|
|
@ -0,0 +1,160 @@
|
|||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch, MagicMock
|
||||
from tradingagents.agents.discovery import NewsArticle
|
||||
from tradingagents.dataflows.alpha_vantage_common import AlphaVantageRateLimitError
|
||||
|
||||
|
||||
class TestGetBulkNewsReturnsNewsArticles:
|
||||
def test_get_bulk_news_returns_list_of_news_article_objects(self):
|
||||
mock_raw_news = [
|
||||
{
|
||||
"title": "Market Update: Tech stocks rally",
|
||||
"source": "Reuters",
|
||||
"url": "https://reuters.com/market-update",
|
||||
"published_at": datetime.now().isoformat(),
|
||||
"content_snippet": "Technology stocks led gains in early trading...",
|
||||
},
|
||||
{
|
||||
"title": "Fed signals rate decision",
|
||||
"source": "Bloomberg",
|
||||
"url": "https://bloomberg.com/fed-rates",
|
||||
"published_at": datetime.now().isoformat(),
|
||||
"content_snippet": "Federal Reserve officials indicated...",
|
||||
},
|
||||
]
|
||||
|
||||
from tradingagents.dataflows.interface import (
|
||||
_bulk_news_cache,
|
||||
get_bulk_news,
|
||||
)
|
||||
|
||||
_bulk_news_cache.clear()
|
||||
|
||||
with patch(
|
||||
"tradingagents.dataflows.interface._fetch_bulk_news_from_vendor"
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = mock_raw_news
|
||||
|
||||
result = get_bulk_news(lookback_period="24h")
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
for article in result:
|
||||
assert isinstance(article, NewsArticle)
|
||||
assert article.title is not None
|
||||
assert article.source is not None
|
||||
assert article.url is not None
|
||||
|
||||
|
||||
class TestLookbackPeriodParsing:
|
||||
@pytest.mark.parametrize(
|
||||
"lookback,expected_hours",
|
||||
[
|
||||
("1h", 1),
|
||||
("6h", 6),
|
||||
("24h", 24),
|
||||
("7d", 168),
|
||||
],
|
||||
)
|
||||
def test_lookback_period_parsing(self, lookback, expected_hours):
|
||||
from tradingagents.dataflows.interface import parse_lookback_period
|
||||
|
||||
hours = parse_lookback_period(lookback)
|
||||
assert hours == expected_hours
|
||||
|
||||
def test_invalid_lookback_period_raises_error(self):
|
||||
from tradingagents.dataflows.interface import parse_lookback_period
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
parse_lookback_period("invalid")
|
||||
|
||||
|
||||
class TestVendorFallback:
|
||||
def test_vendor_fallback_when_primary_rate_limited(self):
|
||||
mock_openai_news = [
|
||||
{
|
||||
"title": "Fallback news from OpenAI",
|
||||
"source": "Web Search",
|
||||
"url": "https://example.com/fallback",
|
||||
"published_at": datetime.now().isoformat(),
|
||||
"content_snippet": "This is fallback content...",
|
||||
},
|
||||
]
|
||||
|
||||
from tradingagents.dataflows.interface import (
|
||||
_bulk_news_cache,
|
||||
)
|
||||
|
||||
_bulk_news_cache.clear()
|
||||
|
||||
with patch(
|
||||
"tradingagents.dataflows.interface.VENDOR_METHODS",
|
||||
{
|
||||
"get_bulk_news": {
|
||||
"alpha_vantage": MagicMock(side_effect=AlphaVantageRateLimitError("Rate limit")),
|
||||
"openai": MagicMock(return_value=mock_openai_news),
|
||||
"google": MagicMock(return_value=[]),
|
||||
}
|
||||
}
|
||||
):
|
||||
from tradingagents.dataflows.interface import _fetch_bulk_news_from_vendor
|
||||
|
||||
result = _fetch_bulk_news_from_vendor("24h")
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
assert result[0]["title"] == "Fallback news from OpenAI"
|
||||
|
||||
|
||||
class TestBulkNewsCache:
|
||||
def test_cache_returns_same_results_within_ttl(self):
|
||||
from tradingagents.dataflows.interface import (
|
||||
_bulk_news_cache,
|
||||
_get_cached_bulk_news,
|
||||
_set_cached_bulk_news,
|
||||
)
|
||||
|
||||
_bulk_news_cache.clear()
|
||||
|
||||
test_articles = [
|
||||
NewsArticle(
|
||||
title="Cached article",
|
||||
source="Test Source",
|
||||
url="https://test.com/cached",
|
||||
published_at=datetime.now(),
|
||||
content_snippet="Cached content...",
|
||||
ticker_mentions=[],
|
||||
)
|
||||
]
|
||||
|
||||
_set_cached_bulk_news("24h", test_articles)
|
||||
|
||||
cached_result = _get_cached_bulk_news("24h")
|
||||
assert cached_result is not None
|
||||
assert len(cached_result) == 1
|
||||
assert cached_result[0].title == "Cached article"
|
||||
|
||||
cached_result_again = _get_cached_bulk_news("24h")
|
||||
assert cached_result_again is not None
|
||||
assert cached_result_again[0].title == cached_result[0].title
|
||||
|
||||
|
||||
class TestEmptyResultHandling:
|
||||
def test_empty_result_handling(self):
|
||||
from tradingagents.dataflows.interface import (
|
||||
_bulk_news_cache,
|
||||
get_bulk_news,
|
||||
)
|
||||
|
||||
_bulk_news_cache.clear()
|
||||
|
||||
with patch(
|
||||
"tradingagents.dataflows.interface._fetch_bulk_news_from_vendor"
|
||||
) as mock_fetch:
|
||||
mock_fetch.return_value = []
|
||||
|
||||
result = get_bulk_news(lookback_period="1h")
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 0
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
|
||||
from tradingagents.agents.discovery.models import (
|
||||
DiscoveryResult,
|
||||
DiscoveryRequest,
|
||||
DiscoveryStatus,
|
||||
TrendingStock,
|
||||
NewsArticle,
|
||||
Sector,
|
||||
EventCategory,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_trending_stocks():
|
||||
article = NewsArticle(
|
||||
title="Apple announces new iPhone",
|
||||
source="Reuters",
|
||||
url="https://reuters.com/article",
|
||||
published_at=datetime.now(),
|
||||
content_snippet="Apple Inc. unveiled its latest iPhone model today...",
|
||||
ticker_mentions=["AAPL"],
|
||||
)
|
||||
return [
|
||||
TrendingStock(
|
||||
ticker="AAPL",
|
||||
company_name="Apple Inc.",
|
||||
score=8.5,
|
||||
mention_count=10,
|
||||
sentiment=0.7,
|
||||
sector=Sector.TECHNOLOGY,
|
||||
event_type=EventCategory.PRODUCT_LAUNCH,
|
||||
news_summary="Apple announced new iPhone model with enhanced AI capabilities.",
|
||||
source_articles=[article],
|
||||
),
|
||||
TrendingStock(
|
||||
ticker="MSFT",
|
||||
company_name="Microsoft Corporation",
|
||||
score=7.2,
|
||||
mention_count=8,
|
||||
sentiment=0.5,
|
||||
sector=Sector.TECHNOLOGY,
|
||||
event_type=EventCategory.EARNINGS,
|
||||
news_summary="Microsoft reported strong quarterly earnings.",
|
||||
source_articles=[article],
|
||||
),
|
||||
TrendingStock(
|
||||
ticker="NVDA",
|
||||
company_name="NVIDIA Corporation",
|
||||
score=6.8,
|
||||
mention_count=6,
|
||||
sentiment=0.4,
|
||||
sector=Sector.TECHNOLOGY,
|
||||
event_type=EventCategory.PRODUCT_LAUNCH,
|
||||
news_summary="NVIDIA unveiled new AI chips.",
|
||||
source_articles=[article],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_discovery_result(sample_trending_stocks):
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="24h",
|
||||
max_results=20,
|
||||
)
|
||||
return DiscoveryResult(
|
||||
request=request,
|
||||
trending_stocks=sample_trending_stocks,
|
||||
status=DiscoveryStatus.COMPLETED,
|
||||
started_at=datetime.now(),
|
||||
completed_at=datetime.now(),
|
||||
)
|
||||
|
||||
|
||||
class TestDiscoveryMenuOption:
|
||||
def test_discover_trending_flow_exists(self):
|
||||
from cli.main import discover_trending_flow
|
||||
assert callable(discover_trending_flow)
|
||||
|
||||
def test_select_lookback_period_function_exists(self):
|
||||
from cli.main import select_lookback_period
|
||||
assert callable(select_lookback_period)
|
||||
|
||||
|
||||
class TestLookbackSelection:
|
||||
@patch("cli.main.questionary.select")
|
||||
def test_lookback_selection_returns_valid_period(self, mock_select):
|
||||
mock_select.return_value.ask.return_value = "24h"
|
||||
from cli.main import select_lookback_period
|
||||
result = select_lookback_period()
|
||||
assert result in ["1h", "6h", "24h", "7d"]
|
||||
|
||||
@patch("cli.main.questionary.select")
|
||||
def test_lookback_selection_handles_all_options(self, mock_select):
|
||||
from cli.main import select_lookback_period
|
||||
for period in ["1h", "6h", "24h", "7d"]:
|
||||
mock_select.return_value.ask.return_value = period
|
||||
result = select_lookback_period()
|
||||
assert result == period
|
||||
|
||||
|
||||
class TestResultsTableDisplay:
|
||||
def test_create_discovery_results_table(self, sample_trending_stocks):
|
||||
from cli.main import create_discovery_results_table
|
||||
table = create_discovery_results_table(sample_trending_stocks)
|
||||
assert table is not None
|
||||
assert table.row_count == len(sample_trending_stocks)
|
||||
|
||||
def test_table_has_correct_columns(self, sample_trending_stocks):
|
||||
from cli.main import create_discovery_results_table
|
||||
table = create_discovery_results_table(sample_trending_stocks)
|
||||
column_names = [col.header for col in table.columns]
|
||||
expected_columns = ["Rank", "Ticker", "Company", "Score", "Mentions", "Event Type"]
|
||||
for expected in expected_columns:
|
||||
assert expected in column_names
|
||||
|
||||
|
||||
class TestDetailView:
|
||||
def test_create_stock_detail_panel(self, sample_trending_stocks):
|
||||
from cli.main import create_stock_detail_panel
|
||||
stock = sample_trending_stocks[0]
|
||||
panel = create_stock_detail_panel(stock, rank=1)
|
||||
assert panel is not None
|
||||
|
|
@ -0,0 +1,278 @@
|
|||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch, MagicMock
|
||||
from tradingagents.agents.discovery import NewsArticle, EventCategory
|
||||
|
||||
|
||||
class TestExtractEntitiesReturnsCompanyMentions:
|
||||
def test_extract_entities_returns_list_of_company_mentions(self):
|
||||
from tradingagents.agents.discovery.entity_extractor import (
|
||||
extract_entities,
|
||||
EntityMention,
|
||||
)
|
||||
|
||||
articles = [
|
||||
NewsArticle(
|
||||
title="Apple announces new iPhone",
|
||||
source="Reuters",
|
||||
url="https://reuters.com/apple",
|
||||
published_at=datetime.now(),
|
||||
content_snippet="Apple Inc unveiled its latest iPhone model today with advanced AI features.",
|
||||
ticker_mentions=[],
|
||||
),
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.entities = [
|
||||
MagicMock(
|
||||
company_name="Apple Inc",
|
||||
confidence=0.95,
|
||||
context_snippet="Apple Inc unveiled its latest iPhone",
|
||||
event_type="product_launch",
|
||||
sentiment=0.7,
|
||||
)
|
||||
]
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.entity_extractor._get_llm"
|
||||
) as mock_get_llm:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.with_structured_output.return_value.invoke.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
result = extract_entities(articles)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) > 0
|
||||
assert all(isinstance(m, EntityMention) for m in result)
|
||||
assert result[0].company_name == "Apple Inc"
|
||||
|
||||
|
||||
class TestConfidenceScoreRange:
|
||||
def test_confidence_score_in_valid_range(self):
|
||||
from tradingagents.agents.discovery.entity_extractor import (
|
||||
extract_entities,
|
||||
EntityMention,
|
||||
)
|
||||
|
||||
articles = [
|
||||
NewsArticle(
|
||||
title="Tesla reports earnings",
|
||||
source="Bloomberg",
|
||||
url="https://bloomberg.com/tsla",
|
||||
published_at=datetime.now(),
|
||||
content_snippet="Tesla Inc reported strong quarterly earnings beating analyst expectations.",
|
||||
ticker_mentions=[],
|
||||
),
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.entities = [
|
||||
MagicMock(
|
||||
company_name="Tesla Inc",
|
||||
confidence=0.88,
|
||||
context_snippet="Tesla Inc reported strong quarterly earnings",
|
||||
event_type="earnings",
|
||||
sentiment=0.5,
|
||||
)
|
||||
]
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.entity_extractor._get_llm"
|
||||
) as mock_get_llm:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.with_structured_output.return_value.invoke.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
result = extract_entities(articles)
|
||||
|
||||
for mention in result:
|
||||
assert 0.0 <= mention.confidence <= 1.0
|
||||
|
||||
|
||||
class TestContextSnippetExtraction:
|
||||
def test_context_snippet_extraction(self):
|
||||
from tradingagents.agents.discovery.entity_extractor import (
|
||||
extract_entities,
|
||||
EntityMention,
|
||||
)
|
||||
|
||||
articles = [
|
||||
NewsArticle(
|
||||
title="Microsoft acquires gaming company",
|
||||
source="WSJ",
|
||||
url="https://wsj.com/msft",
|
||||
published_at=datetime.now(),
|
||||
content_snippet="Microsoft Corporation announced today it will acquire a major gaming studio for $10 billion.",
|
||||
ticker_mentions=[],
|
||||
),
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.entities = [
|
||||
MagicMock(
|
||||
company_name="Microsoft Corporation",
|
||||
confidence=0.92,
|
||||
context_snippet="Microsoft Corporation announced today it will acquire",
|
||||
event_type="merger_acquisition",
|
||||
sentiment=0.6,
|
||||
)
|
||||
]
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.entity_extractor._get_llm"
|
||||
) as mock_get_llm:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.with_structured_output.return_value.invoke.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
result = extract_entities(articles)
|
||||
|
||||
assert len(result) > 0
|
||||
for mention in result:
|
||||
assert mention.context_snippet is not None
|
||||
assert len(mention.context_snippet) > 0
|
||||
assert len(mention.context_snippet) <= 150
|
||||
|
||||
|
||||
class TestBatchProcessing:
|
||||
def test_batch_processing_of_multiple_articles(self):
|
||||
from tradingagents.agents.discovery.entity_extractor import (
|
||||
extract_entities,
|
||||
EntityMention,
|
||||
BATCH_SIZE,
|
||||
)
|
||||
|
||||
articles = [
|
||||
NewsArticle(
|
||||
title=f"News article {i}",
|
||||
source="Reuters",
|
||||
url=f"https://reuters.com/article{i}",
|
||||
published_at=datetime.now(),
|
||||
content_snippet=f"Company {i} announced major developments today.",
|
||||
ticker_mentions=[],
|
||||
)
|
||||
for i in range(15)
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.entities = [
|
||||
MagicMock(
|
||||
company_name="Test Company",
|
||||
confidence=0.85,
|
||||
context_snippet="Company announced major developments",
|
||||
event_type="other",
|
||||
sentiment=0.0,
|
||||
)
|
||||
]
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.entity_extractor._get_llm"
|
||||
) as mock_get_llm:
|
||||
mock_llm = MagicMock()
|
||||
structured_llm = MagicMock()
|
||||
structured_llm.invoke.return_value = mock_response
|
||||
mock_llm.with_structured_output.return_value = structured_llm
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
result = extract_entities(articles)
|
||||
|
||||
expected_batches = (len(articles) + BATCH_SIZE - 1) // BATCH_SIZE
|
||||
assert structured_llm.invoke.call_count == expected_batches
|
||||
|
||||
|
||||
class TestNoCompanyMentions:
|
||||
def test_handling_of_articles_with_no_company_mentions(self):
|
||||
from tradingagents.agents.discovery.entity_extractor import (
|
||||
extract_entities,
|
||||
EntityMention,
|
||||
)
|
||||
|
||||
articles = [
|
||||
NewsArticle(
|
||||
title="Weather forecast for tomorrow",
|
||||
source="Weather Channel",
|
||||
url="https://weather.com/forecast",
|
||||
published_at=datetime.now(),
|
||||
content_snippet="Tomorrow will be sunny with temperatures reaching 75 degrees.",
|
||||
ticker_mentions=[],
|
||||
),
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.entities = []
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.entity_extractor._get_llm"
|
||||
) as mock_get_llm:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.with_structured_output.return_value.invoke.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
result = extract_entities(articles)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
class TestEventTypeClassification:
|
||||
@pytest.mark.parametrize(
|
||||
"event_type",
|
||||
[
|
||||
"earnings",
|
||||
"merger_acquisition",
|
||||
"regulatory",
|
||||
"product_launch",
|
||||
"executive_change",
|
||||
"other",
|
||||
],
|
||||
)
|
||||
def test_event_type_classification(self, event_type):
|
||||
from tradingagents.agents.discovery.entity_extractor import (
|
||||
extract_entities,
|
||||
EntityMention,
|
||||
)
|
||||
|
||||
articles = [
|
||||
NewsArticle(
|
||||
title="Company news",
|
||||
source="Reuters",
|
||||
url="https://reuters.com/news",
|
||||
published_at=datetime.now(),
|
||||
content_snippet="A company made an announcement today.",
|
||||
ticker_mentions=[],
|
||||
),
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.entities = [
|
||||
MagicMock(
|
||||
company_name="Test Company",
|
||||
confidence=0.90,
|
||||
context_snippet="A company made an announcement",
|
||||
event_type=event_type,
|
||||
sentiment=0.0,
|
||||
)
|
||||
]
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.entity_extractor._get_llm"
|
||||
) as mock_get_llm:
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.with_structured_output.return_value.invoke.return_value = (
|
||||
mock_response
|
||||
)
|
||||
mock_get_llm.return_value = mock_llm
|
||||
|
||||
result = extract_entities(articles)
|
||||
|
||||
assert len(result) > 0
|
||||
assert result[0].event_type == EventCategory(event_type)
|
||||
|
|
@ -0,0 +1,489 @@
|
|||
import pytest
|
||||
import math
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch, MagicMock
|
||||
from tradingagents.agents.discovery import (
|
||||
TrendingStock,
|
||||
NewsArticle,
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
DiscoveryStatus,
|
||||
Sector,
|
||||
EventCategory,
|
||||
DiscoveryTimeoutError,
|
||||
NewsUnavailableError,
|
||||
)
|
||||
from tradingagents.agents.discovery.entity_extractor import EntityMention
|
||||
|
||||
|
||||
class TestEndToEndDiscoveryFlow:
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.extract_entities")
|
||||
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
||||
def test_full_discovery_flow_from_news_to_results(
|
||||
self, mock_scores, mock_extract, mock_bulk_news
|
||||
):
|
||||
now = datetime.now()
|
||||
mock_articles = [
|
||||
NewsArticle(
|
||||
title="Apple announces record earnings",
|
||||
source="Reuters",
|
||||
url="https://reuters.com/apple-earnings",
|
||||
published_at=now - timedelta(hours=2),
|
||||
content_snippet="Apple Inc reported record quarterly earnings...",
|
||||
ticker_mentions=["AAPL"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="Apple stock surges on AI news",
|
||||
source="Bloomberg",
|
||||
url="https://bloomberg.com/apple-ai",
|
||||
published_at=now - timedelta(hours=1),
|
||||
content_snippet="Shares of Apple jumped after AI announcement...",
|
||||
ticker_mentions=["AAPL"],
|
||||
),
|
||||
]
|
||||
mock_bulk_news.return_value = mock_articles
|
||||
|
||||
mock_mentions = [
|
||||
EntityMention(
|
||||
company_name="Apple Inc",
|
||||
confidence=0.95,
|
||||
context_snippet="Apple Inc reported record quarterly earnings",
|
||||
article_id="article_0",
|
||||
event_type=EventCategory.EARNINGS,
|
||||
sentiment=0.8,
|
||||
),
|
||||
EntityMention(
|
||||
company_name="Apple",
|
||||
confidence=0.92,
|
||||
context_snippet="Shares of Apple jumped",
|
||||
article_id="article_1",
|
||||
event_type=EventCategory.PRODUCT_LAUNCH,
|
||||
sentiment=0.7,
|
||||
),
|
||||
]
|
||||
mock_extract.return_value = mock_mentions
|
||||
|
||||
mock_trending = [
|
||||
TrendingStock(
|
||||
ticker="AAPL",
|
||||
company_name="Apple Inc.",
|
||||
score=8.5,
|
||||
mention_count=2,
|
||||
sentiment=0.75,
|
||||
sector=Sector.TECHNOLOGY,
|
||||
event_type=EventCategory.EARNINGS,
|
||||
news_summary="Apple reported record earnings and AI progress.",
|
||||
source_articles=mock_articles,
|
||||
),
|
||||
]
|
||||
mock_scores.return_value = mock_trending
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
||||
graph = TradingAgentsGraph()
|
||||
graph.config = {
|
||||
"discovery_timeout": 60,
|
||||
"discovery_hard_timeout": 120,
|
||||
"discovery_cache_ttl": 300,
|
||||
"discovery_max_results": 20,
|
||||
"discovery_min_mentions": 2,
|
||||
}
|
||||
|
||||
request = DiscoveryRequest(lookback_period="24h")
|
||||
result = graph.discover_trending(request)
|
||||
|
||||
assert isinstance(result, DiscoveryResult)
|
||||
assert result.status == DiscoveryStatus.COMPLETED
|
||||
assert len(result.trending_stocks) == 1
|
||||
assert result.trending_stocks[0].ticker == "AAPL"
|
||||
assert result.trending_stocks[0].mention_count >= 2
|
||||
|
||||
mock_bulk_news.assert_called_once_with("24h")
|
||||
mock_extract.assert_called_once()
|
||||
mock_scores.assert_called_once()
|
||||
|
||||
|
||||
class TestEntityExtractionToScoringPipeline:
|
||||
def test_pipeline_from_extraction_to_scoring(self):
|
||||
from tradingagents.agents.discovery.scorer import calculate_trending_scores
|
||||
|
||||
now = datetime.now()
|
||||
articles = [
|
||||
NewsArticle(
|
||||
title="Microsoft cloud revenue grows",
|
||||
source="WSJ",
|
||||
url="https://wsj.com/article1",
|
||||
published_at=now - timedelta(hours=2),
|
||||
content_snippet="Microsoft Corporation reported strong cloud growth.",
|
||||
ticker_mentions=["MSFT"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="Microsoft earnings beat estimates",
|
||||
source="CNBC",
|
||||
url="https://cnbc.com/article2",
|
||||
published_at=now - timedelta(hours=3),
|
||||
content_snippet="Microsoft earnings exceeded analyst expectations.",
|
||||
ticker_mentions=["MSFT"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="Tech stocks rally",
|
||||
source="Bloomberg",
|
||||
url="https://bloomberg.com/article3",
|
||||
published_at=now - timedelta(hours=1),
|
||||
content_snippet="Technology companies led market gains.",
|
||||
ticker_mentions=[],
|
||||
),
|
||||
]
|
||||
|
||||
mentions = [
|
||||
EntityMention(
|
||||
company_name="Microsoft Corporation",
|
||||
confidence=0.95,
|
||||
context_snippet="Microsoft Corporation reported strong cloud growth",
|
||||
article_id="article_0",
|
||||
event_type=EventCategory.EARNINGS,
|
||||
sentiment=0.7,
|
||||
),
|
||||
EntityMention(
|
||||
company_name="Microsoft",
|
||||
confidence=0.92,
|
||||
context_snippet="Microsoft earnings exceeded analyst expectations",
|
||||
article_id="article_1",
|
||||
event_type=EventCategory.EARNINGS,
|
||||
sentiment=0.8,
|
||||
),
|
||||
]
|
||||
|
||||
with patch("tradingagents.agents.discovery.scorer.resolve_ticker") as mock_resolve:
|
||||
mock_resolve.return_value = "MSFT"
|
||||
|
||||
with patch("tradingagents.agents.discovery.scorer.classify_sector") as mock_sector:
|
||||
mock_sector.return_value = "technology"
|
||||
|
||||
result = calculate_trending_scores(mentions, articles, min_mentions=2)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].ticker == "MSFT"
|
||||
assert result[0].mention_count == 2
|
||||
assert result[0].sentiment > 0
|
||||
|
||||
|
||||
class TestNewsVendorFailureGracefulDegradation:
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
def test_news_vendor_failure_with_graceful_degradation(self, mock_bulk_news):
|
||||
mock_bulk_news.side_effect = NewsUnavailableError("All news vendors failed")
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
||||
graph = TradingAgentsGraph()
|
||||
graph.config = {
|
||||
"discovery_timeout": 60,
|
||||
"discovery_hard_timeout": 120,
|
||||
"discovery_cache_ttl": 300,
|
||||
"discovery_max_results": 20,
|
||||
"discovery_min_mentions": 2,
|
||||
}
|
||||
|
||||
result = graph.discover_trending()
|
||||
|
||||
assert result.status == DiscoveryStatus.FAILED
|
||||
assert result.error_message is not None
|
||||
assert "news" in result.error_message.lower() or "vendor" in result.error_message.lower()
|
||||
|
||||
|
||||
class TestTimeoutHandlingWithPartialResults:
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
def test_timeout_handling_returns_error(self, mock_bulk_news):
|
||||
def slow_fetch(*args, **kwargs):
|
||||
import time
|
||||
time.sleep(0.3)
|
||||
return []
|
||||
|
||||
mock_bulk_news.side_effect = slow_fetch
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
||||
graph = TradingAgentsGraph()
|
||||
graph.config = {
|
||||
"discovery_timeout": 60,
|
||||
"discovery_hard_timeout": 0.1,
|
||||
"discovery_cache_ttl": 300,
|
||||
"discovery_max_results": 20,
|
||||
"discovery_min_mentions": 2,
|
||||
}
|
||||
|
||||
with pytest.raises(DiscoveryTimeoutError):
|
||||
graph.discover_trending()
|
||||
|
||||
|
||||
class TestNoTrendingStocksFound:
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.extract_entities")
|
||||
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
||||
def test_no_trending_stocks_found_returns_empty_list(
|
||||
self, mock_scores, mock_extract, mock_bulk_news
|
||||
):
|
||||
mock_bulk_news.return_value = [
|
||||
NewsArticle(
|
||||
title="General market update",
|
||||
source="Reuters",
|
||||
url="https://reuters.com/general",
|
||||
published_at=datetime.now(),
|
||||
content_snippet="Markets were quiet today with no major news.",
|
||||
ticker_mentions=[],
|
||||
),
|
||||
]
|
||||
mock_extract.return_value = []
|
||||
mock_scores.return_value = []
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
||||
graph = TradingAgentsGraph()
|
||||
graph.config = {
|
||||
"discovery_timeout": 60,
|
||||
"discovery_hard_timeout": 120,
|
||||
"discovery_cache_ttl": 300,
|
||||
"discovery_max_results": 20,
|
||||
"discovery_min_mentions": 2,
|
||||
}
|
||||
|
||||
result = graph.discover_trending()
|
||||
|
||||
assert result.status == DiscoveryStatus.COMPLETED
|
||||
assert len(result.trending_stocks) == 0
|
||||
|
||||
|
||||
class TestAllStocksFilteredOutBySectorFilter:
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.extract_entities")
|
||||
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
||||
def test_all_stocks_filtered_out_by_sector_filter(
|
||||
self, mock_scores, mock_extract, mock_bulk_news
|
||||
):
|
||||
mock_bulk_news.return_value = []
|
||||
mock_extract.return_value = []
|
||||
mock_scores.return_value = [
|
||||
TrendingStock(
|
||||
ticker="AAPL",
|
||||
company_name="Apple Inc.",
|
||||
score=10.0,
|
||||
mention_count=5,
|
||||
sentiment=0.5,
|
||||
sector=Sector.TECHNOLOGY,
|
||||
event_type=EventCategory.EARNINGS,
|
||||
news_summary="Apple earnings",
|
||||
source_articles=[],
|
||||
),
|
||||
TrendingStock(
|
||||
ticker="MSFT",
|
||||
company_name="Microsoft",
|
||||
score=9.0,
|
||||
mention_count=4,
|
||||
sentiment=0.4,
|
||||
sector=Sector.TECHNOLOGY,
|
||||
event_type=EventCategory.PRODUCT_LAUNCH,
|
||||
news_summary="Microsoft product",
|
||||
source_articles=[],
|
||||
),
|
||||
]
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
||||
graph = TradingAgentsGraph()
|
||||
graph.config = {
|
||||
"discovery_timeout": 60,
|
||||
"discovery_hard_timeout": 120,
|
||||
"discovery_cache_ttl": 300,
|
||||
"discovery_max_results": 20,
|
||||
"discovery_min_mentions": 2,
|
||||
}
|
||||
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="24h",
|
||||
sector_filter=[Sector.HEALTHCARE],
|
||||
)
|
||||
result = graph.discover_trending(request)
|
||||
|
||||
assert result.status == DiscoveryStatus.COMPLETED
|
||||
assert len(result.trending_stocks) == 0
|
||||
|
||||
|
||||
class TestAllStocksFilteredOutByEventFilter:
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.extract_entities")
|
||||
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
||||
def test_all_stocks_filtered_out_by_event_filter(
|
||||
self, mock_scores, mock_extract, mock_bulk_news
|
||||
):
|
||||
mock_bulk_news.return_value = []
|
||||
mock_extract.return_value = []
|
||||
mock_scores.return_value = [
|
||||
TrendingStock(
|
||||
ticker="AAPL",
|
||||
company_name="Apple Inc.",
|
||||
score=10.0,
|
||||
mention_count=5,
|
||||
sentiment=0.5,
|
||||
sector=Sector.TECHNOLOGY,
|
||||
event_type=EventCategory.EARNINGS,
|
||||
news_summary="Apple earnings",
|
||||
source_articles=[],
|
||||
),
|
||||
]
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
||||
graph = TradingAgentsGraph()
|
||||
graph.config = {
|
||||
"discovery_timeout": 60,
|
||||
"discovery_hard_timeout": 120,
|
||||
"discovery_cache_ttl": 300,
|
||||
"discovery_max_results": 20,
|
||||
"discovery_min_mentions": 2,
|
||||
}
|
||||
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="24h",
|
||||
event_filter=[EventCategory.MERGER_ACQUISITION],
|
||||
)
|
||||
result = graph.discover_trending(request)
|
||||
|
||||
assert result.status == DiscoveryStatus.COMPLETED
|
||||
assert len(result.trending_stocks) == 0
|
||||
|
||||
|
||||
class TestMultipleSectorsAndEventsFiltering:
|
||||
@patch("tradingagents.graph.trading_graph.get_bulk_news")
|
||||
@patch("tradingagents.graph.trading_graph.extract_entities")
|
||||
@patch("tradingagents.graph.trading_graph.calculate_trending_scores")
|
||||
def test_combined_sector_and_event_filtering(
|
||||
self, mock_scores, mock_extract, mock_bulk_news
|
||||
):
|
||||
mock_bulk_news.return_value = []
|
||||
mock_extract.return_value = []
|
||||
mock_scores.return_value = [
|
||||
TrendingStock(
|
||||
ticker="AAPL",
|
||||
company_name="Apple Inc.",
|
||||
score=10.0,
|
||||
mention_count=5,
|
||||
sentiment=0.5,
|
||||
sector=Sector.TECHNOLOGY,
|
||||
event_type=EventCategory.EARNINGS,
|
||||
news_summary="Apple earnings",
|
||||
source_articles=[],
|
||||
),
|
||||
TrendingStock(
|
||||
ticker="JPM",
|
||||
company_name="JPMorgan Chase",
|
||||
score=9.0,
|
||||
mention_count=4,
|
||||
sentiment=0.4,
|
||||
sector=Sector.FINANCE,
|
||||
event_type=EventCategory.EARNINGS,
|
||||
news_summary="JPM earnings",
|
||||
source_articles=[],
|
||||
),
|
||||
TrendingStock(
|
||||
ticker="XOM",
|
||||
company_name="Exxon Mobil",
|
||||
score=8.0,
|
||||
mention_count=3,
|
||||
sentiment=0.3,
|
||||
sector=Sector.ENERGY,
|
||||
event_type=EventCategory.REGULATORY,
|
||||
news_summary="XOM regulatory news",
|
||||
source_articles=[],
|
||||
),
|
||||
]
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
|
||||
graph = TradingAgentsGraph()
|
||||
graph.config = {
|
||||
"discovery_timeout": 60,
|
||||
"discovery_hard_timeout": 120,
|
||||
"discovery_cache_ttl": 300,
|
||||
"discovery_max_results": 20,
|
||||
"discovery_min_mentions": 2,
|
||||
}
|
||||
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="24h",
|
||||
sector_filter=[Sector.TECHNOLOGY, Sector.FINANCE],
|
||||
event_filter=[EventCategory.EARNINGS],
|
||||
)
|
||||
result = graph.discover_trending(request)
|
||||
|
||||
assert result.status == DiscoveryStatus.COMPLETED
|
||||
assert len(result.trending_stocks) == 2
|
||||
tickers = [s.ticker for s in result.trending_stocks]
|
||||
assert "AAPL" in tickers
|
||||
assert "JPM" in tickers
|
||||
assert "XOM" not in tickers
|
||||
|
||||
|
||||
class TestDiscoveryResultPersistenceIntegration:
|
||||
def test_discovery_result_can_be_serialized_and_saved(self):
|
||||
from tradingagents.agents.discovery.persistence import (
|
||||
save_discovery_result,
|
||||
generate_markdown_summary,
|
||||
)
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
article = NewsArticle(
|
||||
title="Test article",
|
||||
source="Test",
|
||||
url="https://test.com",
|
||||
published_at=datetime.now(),
|
||||
content_snippet="Test content",
|
||||
ticker_mentions=["TEST"],
|
||||
)
|
||||
|
||||
stock = TrendingStock(
|
||||
ticker="TEST",
|
||||
company_name="Test Company",
|
||||
score=5.0,
|
||||
mention_count=2,
|
||||
sentiment=0.5,
|
||||
sector=Sector.TECHNOLOGY,
|
||||
event_type=EventCategory.OTHER,
|
||||
news_summary="Test news summary",
|
||||
source_articles=[article],
|
||||
)
|
||||
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="24h",
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
|
||||
result = DiscoveryResult(
|
||||
request=request,
|
||||
trending_stocks=[stock],
|
||||
status=DiscoveryStatus.COMPLETED,
|
||||
started_at=datetime.now(),
|
||||
completed_at=datetime.now(),
|
||||
)
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
path = save_discovery_result(result, base_path=Path(temp_dir))
|
||||
assert path.exists()
|
||||
assert (path / "discovery_result.json").exists()
|
||||
assert (path / "discovery_summary.md").exists()
|
||||
|
||||
markdown = generate_markdown_summary(result)
|
||||
assert "TEST" in markdown
|
||||
assert "Test Company" in markdown
|
||||
finally:
|
||||
shutil.rmtree(temp_dir)
|
||||
|
|
@ -0,0 +1,196 @@
|
|||
import pytest
|
||||
from datetime import datetime
|
||||
from tradingagents.agents.discovery import (
|
||||
TrendingStock,
|
||||
NewsArticle,
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
Sector,
|
||||
EventCategory,
|
||||
)
|
||||
from tradingagents.agents.discovery.models import DiscoveryStatus
|
||||
|
||||
|
||||
class TestTrendingStock:
|
||||
def test_trending_stock_creation_and_validation(self):
|
||||
article = NewsArticle(
|
||||
title="Apple announces new iPhone",
|
||||
source="Reuters",
|
||||
url="https://reuters.com/article1",
|
||||
published_at=datetime(2024, 1, 15, 10, 30, 0),
|
||||
content_snippet="Apple Inc announced its latest iPhone model today...",
|
||||
ticker_mentions=["AAPL"],
|
||||
)
|
||||
|
||||
stock = TrendingStock(
|
||||
ticker="AAPL",
|
||||
company_name="Apple Inc.",
|
||||
score=85.5,
|
||||
mention_count=10,
|
||||
sentiment=0.75,
|
||||
sector=Sector.TECHNOLOGY,
|
||||
event_type=EventCategory.PRODUCT_LAUNCH,
|
||||
news_summary="Apple announced new iPhone with advanced AI features.",
|
||||
source_articles=[article],
|
||||
)
|
||||
|
||||
assert stock.ticker == "AAPL"
|
||||
assert stock.company_name == "Apple Inc."
|
||||
assert stock.score == 85.5
|
||||
assert stock.mention_count == 10
|
||||
assert stock.sentiment == 0.75
|
||||
assert stock.sector == Sector.TECHNOLOGY
|
||||
assert stock.event_type == EventCategory.PRODUCT_LAUNCH
|
||||
assert len(stock.source_articles) == 1
|
||||
|
||||
|
||||
class TestNewsArticle:
|
||||
def test_news_article_with_required_fields(self):
|
||||
published = datetime(2024, 1, 15, 14, 0, 0)
|
||||
|
||||
article = NewsArticle(
|
||||
title="Tesla Q4 Earnings Beat Expectations",
|
||||
source="Bloomberg",
|
||||
url="https://bloomberg.com/news/tsla-earnings",
|
||||
published_at=published,
|
||||
content_snippet="Tesla Inc. reported fourth quarter earnings that exceeded analyst expectations...",
|
||||
ticker_mentions=["TSLA", "F"],
|
||||
)
|
||||
|
||||
assert article.title == "Tesla Q4 Earnings Beat Expectations"
|
||||
assert article.source == "Bloomberg"
|
||||
assert article.url == "https://bloomberg.com/news/tsla-earnings"
|
||||
assert article.published_at == published
|
||||
assert article.content_snippet.startswith("Tesla Inc.")
|
||||
assert "TSLA" in article.ticker_mentions
|
||||
assert "F" in article.ticker_mentions
|
||||
|
||||
|
||||
class TestDiscoveryRequest:
|
||||
def test_discovery_request_with_lookback_period_validation(self):
|
||||
created = datetime(2024, 1, 15, 12, 0, 0)
|
||||
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="24h",
|
||||
sector_filter=[Sector.TECHNOLOGY, Sector.HEALTHCARE],
|
||||
event_filter=[EventCategory.EARNINGS],
|
||||
max_results=20,
|
||||
created_at=created,
|
||||
)
|
||||
|
||||
assert request.lookback_period == "24h"
|
||||
assert Sector.TECHNOLOGY in request.sector_filter
|
||||
assert Sector.HEALTHCARE in request.sector_filter
|
||||
assert EventCategory.EARNINGS in request.event_filter
|
||||
assert request.max_results == 20
|
||||
assert request.created_at == created
|
||||
|
||||
def test_discovery_request_with_defaults(self):
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="1h",
|
||||
)
|
||||
|
||||
assert request.lookback_period == "1h"
|
||||
assert request.sector_filter is None
|
||||
assert request.event_filter is None
|
||||
assert request.max_results == 20
|
||||
assert request.created_at is not None
|
||||
|
||||
|
||||
class TestDiscoveryResult:
|
||||
def test_discovery_result_state_transitions(self):
|
||||
request = DiscoveryRequest(lookback_period="6h")
|
||||
started = datetime(2024, 1, 15, 12, 0, 0)
|
||||
|
||||
result = DiscoveryResult(
|
||||
request=request,
|
||||
trending_stocks=[],
|
||||
status=DiscoveryStatus.CREATED,
|
||||
started_at=started,
|
||||
)
|
||||
|
||||
assert result.status == DiscoveryStatus.CREATED
|
||||
|
||||
result.status = DiscoveryStatus.PROCESSING
|
||||
assert result.status == DiscoveryStatus.PROCESSING
|
||||
|
||||
result.status = DiscoveryStatus.COMPLETED
|
||||
result.completed_at = datetime(2024, 1, 15, 12, 1, 0)
|
||||
assert result.status == DiscoveryStatus.COMPLETED
|
||||
assert result.completed_at is not None
|
||||
|
||||
def test_discovery_result_failed_state(self):
|
||||
request = DiscoveryRequest(lookback_period="7d")
|
||||
|
||||
result = DiscoveryResult(
|
||||
request=request,
|
||||
trending_stocks=[],
|
||||
status=DiscoveryStatus.FAILED,
|
||||
started_at=datetime(2024, 1, 15, 12, 0, 0),
|
||||
error_message="News API rate limit exceeded",
|
||||
)
|
||||
|
||||
assert result.status == DiscoveryStatus.FAILED
|
||||
assert result.error_message == "News API rate limit exceeded"
|
||||
|
||||
|
||||
class TestSerializationRoundtrip:
|
||||
def test_to_dict_and_from_dict_serialization_roundtrip(self):
|
||||
article = NewsArticle(
|
||||
title="Microsoft acquires AI startup",
|
||||
source="WSJ",
|
||||
url="https://wsj.com/msft-acquisition",
|
||||
published_at=datetime(2024, 1, 15, 9, 0, 0),
|
||||
content_snippet="Microsoft Corp announced the acquisition of an AI startup...",
|
||||
ticker_mentions=["MSFT"],
|
||||
)
|
||||
|
||||
stock = TrendingStock(
|
||||
ticker="MSFT",
|
||||
company_name="Microsoft Corporation",
|
||||
score=92.3,
|
||||
mention_count=15,
|
||||
sentiment=0.65,
|
||||
sector=Sector.TECHNOLOGY,
|
||||
event_type=EventCategory.MERGER_ACQUISITION,
|
||||
news_summary="Microsoft announces major AI acquisition.",
|
||||
source_articles=[article],
|
||||
)
|
||||
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="24h",
|
||||
sector_filter=[Sector.TECHNOLOGY],
|
||||
event_filter=[EventCategory.MERGER_ACQUISITION],
|
||||
max_results=10,
|
||||
created_at=datetime(2024, 1, 15, 8, 0, 0),
|
||||
)
|
||||
|
||||
result = DiscoveryResult(
|
||||
request=request,
|
||||
trending_stocks=[stock],
|
||||
status=DiscoveryStatus.COMPLETED,
|
||||
started_at=datetime(2024, 1, 15, 8, 0, 0),
|
||||
completed_at=datetime(2024, 1, 15, 8, 1, 30),
|
||||
)
|
||||
|
||||
result_dict = result.to_dict()
|
||||
restored_result = DiscoveryResult.from_dict(result_dict)
|
||||
|
||||
assert restored_result.status == result.status
|
||||
assert restored_result.request.lookback_period == request.lookback_period
|
||||
assert len(restored_result.trending_stocks) == 1
|
||||
|
||||
restored_stock = restored_result.trending_stocks[0]
|
||||
assert restored_stock.ticker == stock.ticker
|
||||
assert restored_stock.company_name == stock.company_name
|
||||
assert restored_stock.score == stock.score
|
||||
assert restored_stock.mention_count == stock.mention_count
|
||||
assert restored_stock.sentiment == stock.sentiment
|
||||
assert restored_stock.sector == stock.sector
|
||||
assert restored_stock.event_type == stock.event_type
|
||||
|
||||
assert len(restored_stock.source_articles) == 1
|
||||
restored_article = restored_stock.source_articles[0]
|
||||
assert restored_article.title == article.title
|
||||
assert restored_article.source == article.source
|
||||
assert restored_article.url == article.url
|
||||
|
|
@ -0,0 +1,228 @@
|
|||
import pytest
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
from tradingagents.agents.discovery import (
|
||||
TrendingStock,
|
||||
NewsArticle,
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
DiscoveryStatus,
|
||||
Sector,
|
||||
EventCategory,
|
||||
)
|
||||
from tradingagents.agents.discovery.persistence import (
|
||||
save_discovery_result,
|
||||
generate_markdown_summary,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_discovery_result():
|
||||
articles = [
|
||||
NewsArticle(
|
||||
title="Apple announces new iPhone with AI features",
|
||||
source="Reuters",
|
||||
url="https://reuters.com/apple-iphone-ai",
|
||||
published_at=datetime(2024, 1, 15, 10, 30, 0),
|
||||
content_snippet="Apple Inc announced its latest iPhone model with advanced AI...",
|
||||
ticker_mentions=["AAPL"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="Apple stock surges on earnings beat",
|
||||
source="Bloomberg",
|
||||
url="https://bloomberg.com/apple-earnings",
|
||||
published_at=datetime(2024, 1, 15, 11, 0, 0),
|
||||
content_snippet="Shares of Apple Inc surged after the company reported...",
|
||||
ticker_mentions=["AAPL"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="Microsoft cloud revenue grows 25%",
|
||||
source="WSJ",
|
||||
url="https://wsj.com/msft-cloud",
|
||||
published_at=datetime(2024, 1, 15, 9, 0, 0),
|
||||
content_snippet="Microsoft Corp reported strong cloud revenue growth...",
|
||||
ticker_mentions=["MSFT"],
|
||||
),
|
||||
]
|
||||
|
||||
stocks = [
|
||||
TrendingStock(
|
||||
ticker="AAPL",
|
||||
company_name="Apple Inc.",
|
||||
score=8.54,
|
||||
mention_count=12,
|
||||
sentiment=0.72,
|
||||
sector=Sector.TECHNOLOGY,
|
||||
event_type=EventCategory.EARNINGS,
|
||||
news_summary="Apple reported strong earnings and announced new AI features.",
|
||||
source_articles=[articles[0], articles[1]],
|
||||
),
|
||||
TrendingStock(
|
||||
ticker="MSFT",
|
||||
company_name="Microsoft Corporation",
|
||||
score=7.23,
|
||||
mention_count=9,
|
||||
sentiment=0.65,
|
||||
sector=Sector.TECHNOLOGY,
|
||||
event_type=EventCategory.PRODUCT_LAUNCH,
|
||||
news_summary="Microsoft cloud business continues strong growth.",
|
||||
source_articles=[articles[2]],
|
||||
),
|
||||
TrendingStock(
|
||||
ticker="GOOGL",
|
||||
company_name="Alphabet Inc.",
|
||||
score=6.15,
|
||||
mention_count=7,
|
||||
sentiment=0.58,
|
||||
sector=Sector.TECHNOLOGY,
|
||||
event_type=EventCategory.REGULATORY,
|
||||
news_summary="Google faces regulatory scrutiny in multiple markets.",
|
||||
source_articles=[],
|
||||
),
|
||||
]
|
||||
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="24h",
|
||||
sector_filter=[Sector.TECHNOLOGY],
|
||||
event_filter=[EventCategory.EARNINGS],
|
||||
max_results=20,
|
||||
created_at=datetime(2024, 1, 15, 14, 30, 45),
|
||||
)
|
||||
|
||||
return DiscoveryResult(
|
||||
request=request,
|
||||
trending_stocks=stocks,
|
||||
status=DiscoveryStatus.COMPLETED,
|
||||
started_at=datetime(2024, 1, 15, 14, 30, 45),
|
||||
completed_at=datetime(2024, 1, 15, 14, 31, 30),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_results_dir():
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
yield Path(temp_dir)
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
class TestDirectoryStructureCreation:
|
||||
def test_creates_correct_directory_structure(self, sample_discovery_result, temp_results_dir):
|
||||
result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir)
|
||||
|
||||
assert result_path.exists()
|
||||
assert result_path.is_dir()
|
||||
|
||||
path_parts = result_path.parts
|
||||
assert "discovery" in path_parts
|
||||
|
||||
date_part = path_parts[-2]
|
||||
time_part = path_parts[-1]
|
||||
|
||||
assert len(date_part.split("-")) == 3
|
||||
assert len(time_part.split("-")) == 3
|
||||
|
||||
|
||||
class TestDiscoveryResultJson:
|
||||
def test_discovery_result_json_contains_all_fields(self, sample_discovery_result, temp_results_dir):
|
||||
result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir)
|
||||
|
||||
json_path = result_path / "discovery_result.json"
|
||||
assert json_path.exists()
|
||||
|
||||
with open(json_path, "r") as f:
|
||||
saved_data = json.load(f)
|
||||
|
||||
assert "request" in saved_data
|
||||
assert "trending_stocks" in saved_data
|
||||
assert "status" in saved_data
|
||||
assert "started_at" in saved_data
|
||||
assert "completed_at" in saved_data
|
||||
|
||||
assert saved_data["request"]["lookback_period"] == "24h"
|
||||
assert saved_data["status"] == "completed"
|
||||
assert len(saved_data["trending_stocks"]) == 3
|
||||
|
||||
first_stock = saved_data["trending_stocks"][0]
|
||||
assert first_stock["ticker"] == "AAPL"
|
||||
assert first_stock["company_name"] == "Apple Inc."
|
||||
assert first_stock["score"] == 8.54
|
||||
assert first_stock["mention_count"] == 12
|
||||
assert first_stock["sentiment"] == 0.72
|
||||
assert first_stock["sector"] == "technology"
|
||||
assert first_stock["event_type"] == "earnings"
|
||||
assert "news_summary" in first_stock
|
||||
assert "source_articles" in first_stock
|
||||
|
||||
|
||||
class TestDiscoverySummaryMarkdown:
|
||||
def test_discovery_summary_md_is_human_readable(self, sample_discovery_result, temp_results_dir):
|
||||
result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir)
|
||||
|
||||
md_path = result_path / "discovery_summary.md"
|
||||
assert md_path.exists()
|
||||
|
||||
with open(md_path, "r") as f:
|
||||
markdown_content = f.read()
|
||||
|
||||
assert "# Discovery Results" in markdown_content
|
||||
assert "Timestamp:" in markdown_content
|
||||
assert "Lookback Period:" in markdown_content
|
||||
assert "24h" in markdown_content
|
||||
assert "Total Stocks Found:" in markdown_content
|
||||
|
||||
assert "## Trending Stocks" in markdown_content
|
||||
assert "| Rank |" in markdown_content
|
||||
assert "| Ticker |" in markdown_content
|
||||
assert "| Company |" in markdown_content
|
||||
assert "| Score |" in markdown_content
|
||||
assert "| Mentions |" in markdown_content
|
||||
assert "| Event |" in markdown_content
|
||||
|
||||
assert "AAPL" in markdown_content
|
||||
assert "Apple Inc." in markdown_content
|
||||
assert "8.54" in markdown_content
|
||||
assert "12" in markdown_content
|
||||
assert "earnings" in markdown_content
|
||||
|
||||
assert "MSFT" in markdown_content
|
||||
assert "Microsoft Corporation" in markdown_content
|
||||
|
||||
assert "## Top 3 Detailed Analysis" in markdown_content
|
||||
assert "### 1. AAPL - Apple Inc." in markdown_content
|
||||
assert "**Score:**" in markdown_content
|
||||
assert "**Sentiment:**" in markdown_content
|
||||
assert "**Sector:**" in markdown_content
|
||||
assert "**Event Type:**" in markdown_content
|
||||
assert "**Mentions:**" in markdown_content
|
||||
assert "**News Summary:**" in markdown_content
|
||||
|
||||
|
||||
class TestMarkdownGeneration:
|
||||
def test_generate_markdown_with_filters(self, sample_discovery_result):
|
||||
markdown = generate_markdown_summary(sample_discovery_result)
|
||||
|
||||
assert "sector=technology" in markdown.lower()
|
||||
assert "event=earnings" in markdown.lower()
|
||||
|
||||
def test_generate_markdown_without_filters(self):
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="6h",
|
||||
created_at=datetime(2024, 1, 15, 10, 0, 0),
|
||||
)
|
||||
|
||||
result = DiscoveryResult(
|
||||
request=request,
|
||||
trending_stocks=[],
|
||||
status=DiscoveryStatus.COMPLETED,
|
||||
started_at=datetime(2024, 1, 15, 10, 0, 0),
|
||||
completed_at=datetime(2024, 1, 15, 10, 1, 0),
|
||||
)
|
||||
|
||||
markdown = generate_markdown_summary(result)
|
||||
|
||||
assert "Filters:" in markdown
|
||||
assert "None" in markdown
|
||||
|
|
@ -0,0 +1,469 @@
|
|||
import pytest
|
||||
import math
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
from tradingagents.agents.discovery import NewsArticle, EventCategory, Sector
|
||||
from tradingagents.agents.discovery.entity_extractor import EntityMention
|
||||
|
||||
|
||||
class TestFrequencyCalculation:
|
||||
def test_frequency_calculation_unique_article_count(self):
|
||||
from tradingagents.agents.discovery.scorer import calculate_trending_scores
|
||||
|
||||
now = datetime.now()
|
||||
articles = [
|
||||
NewsArticle(
|
||||
title="Apple Q4 Earnings",
|
||||
source="Reuters",
|
||||
url="https://reuters.com/article1",
|
||||
published_at=now - timedelta(hours=1),
|
||||
content_snippet="Apple Inc reported strong earnings.",
|
||||
ticker_mentions=["AAPL"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="Apple iPhone Sales",
|
||||
source="Bloomberg",
|
||||
url="https://bloomberg.com/article2",
|
||||
published_at=now - timedelta(hours=2),
|
||||
content_snippet="Apple saw record iPhone sales.",
|
||||
ticker_mentions=["AAPL"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="Apple AI Features",
|
||||
source="WSJ",
|
||||
url="https://wsj.com/article3",
|
||||
published_at=now - timedelta(hours=3),
|
||||
content_snippet="Apple announced AI features.",
|
||||
ticker_mentions=["AAPL"],
|
||||
),
|
||||
]
|
||||
|
||||
mentions = [
|
||||
EntityMention(
|
||||
company_name="Apple Inc",
|
||||
confidence=0.95,
|
||||
context_snippet="Apple Inc reported strong earnings",
|
||||
article_id="article_0",
|
||||
event_type=EventCategory.EARNINGS,
|
||||
),
|
||||
EntityMention(
|
||||
company_name="Apple",
|
||||
confidence=0.90,
|
||||
context_snippet="Apple saw record iPhone sales",
|
||||
article_id="article_1",
|
||||
event_type=EventCategory.EARNINGS,
|
||||
),
|
||||
EntityMention(
|
||||
company_name="Apple Inc.",
|
||||
confidence=0.92,
|
||||
context_snippet="Apple announced AI features",
|
||||
article_id="article_2",
|
||||
event_type=EventCategory.PRODUCT_LAUNCH,
|
||||
),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.scorer.resolve_ticker"
|
||||
) as mock_resolve:
|
||||
mock_resolve.return_value = "AAPL"
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.scorer.classify_sector"
|
||||
) as mock_sector:
|
||||
mock_sector.return_value = "technology"
|
||||
|
||||
result = calculate_trending_scores(mentions, articles)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].ticker == "AAPL"
|
||||
assert result[0].mention_count == 3
|
||||
|
||||
|
||||
class TestSentimentIntensityFactor:
|
||||
def test_sentiment_intensity_uses_absolute_value(self):
|
||||
from tradingagents.agents.discovery.scorer import calculate_trending_scores
|
||||
|
||||
now = datetime.now()
|
||||
articles = [
|
||||
NewsArticle(
|
||||
title="Stock drops sharply",
|
||||
source="Reuters",
|
||||
url="https://reuters.com/article1",
|
||||
published_at=now - timedelta(hours=1),
|
||||
content_snippet="Company faced major issues.",
|
||||
ticker_mentions=["TSLA"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="More bad news",
|
||||
source="Bloomberg",
|
||||
url="https://bloomberg.com/article2",
|
||||
published_at=now - timedelta(hours=2),
|
||||
content_snippet="Further decline expected.",
|
||||
ticker_mentions=["TSLA"],
|
||||
),
|
||||
]
|
||||
|
||||
mentions = [
|
||||
EntityMention(
|
||||
company_name="Tesla",
|
||||
confidence=0.95,
|
||||
context_snippet="Company faced major issues",
|
||||
article_id="article_0",
|
||||
event_type=EventCategory.OTHER,
|
||||
sentiment=-0.8,
|
||||
),
|
||||
EntityMention(
|
||||
company_name="Tesla Inc",
|
||||
confidence=0.90,
|
||||
context_snippet="Further decline expected",
|
||||
article_id="article_1",
|
||||
event_type=EventCategory.OTHER,
|
||||
sentiment=-0.6,
|
||||
),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.scorer.resolve_ticker"
|
||||
) as mock_resolve:
|
||||
mock_resolve.return_value = "TSLA"
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.scorer.classify_sector"
|
||||
) as mock_sector:
|
||||
mock_sector.return_value = "technology"
|
||||
|
||||
result = calculate_trending_scores(mentions, articles)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].sentiment < 0
|
||||
expected_sentiment = (-0.8 * 0.95 + -0.6 * 0.90) / (0.95 + 0.90)
|
||||
assert abs(result[0].sentiment - expected_sentiment) < 0.01
|
||||
|
||||
|
||||
class TestRecencyWeightExponentialDecay:
|
||||
def test_recency_weight_exponential_decay(self):
|
||||
from tradingagents.agents.discovery.scorer import calculate_trending_scores
|
||||
|
||||
now = datetime.now()
|
||||
articles = [
|
||||
NewsArticle(
|
||||
title="Recent news",
|
||||
source="Reuters",
|
||||
url="https://reuters.com/article1",
|
||||
published_at=now - timedelta(hours=1),
|
||||
content_snippet="Recent company news.",
|
||||
ticker_mentions=["NVDA"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="Older news",
|
||||
source="Bloomberg",
|
||||
url="https://bloomberg.com/article2",
|
||||
published_at=now - timedelta(hours=10),
|
||||
content_snippet="Older company news.",
|
||||
ticker_mentions=["NVDA"],
|
||||
),
|
||||
]
|
||||
|
||||
mentions = [
|
||||
EntityMention(
|
||||
company_name="Nvidia",
|
||||
confidence=0.90,
|
||||
context_snippet="Recent company news",
|
||||
article_id="article_0",
|
||||
event_type=EventCategory.OTHER,
|
||||
sentiment=0.5,
|
||||
),
|
||||
EntityMention(
|
||||
company_name="Nvidia",
|
||||
confidence=0.90,
|
||||
context_snippet="Older company news",
|
||||
article_id="article_1",
|
||||
event_type=EventCategory.OTHER,
|
||||
sentiment=0.5,
|
||||
),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.scorer.resolve_ticker"
|
||||
) as mock_resolve:
|
||||
mock_resolve.return_value = "NVDA"
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.scorer.classify_sector"
|
||||
) as mock_sector:
|
||||
mock_sector.return_value = "technology"
|
||||
|
||||
result = calculate_trending_scores(mentions, articles, decay_rate=0.1)
|
||||
|
||||
assert len(result) == 1
|
||||
recent_weight = math.exp(-0.1 * 1)
|
||||
older_weight = math.exp(-0.1 * 10)
|
||||
avg_recency = (recent_weight + older_weight) / 2
|
||||
assert result[0].score > 0
|
||||
|
||||
|
||||
class TestMinimumThresholdFiltering:
|
||||
def test_minimum_threshold_filtering_requires_two_articles(self):
|
||||
from tradingagents.agents.discovery.scorer import calculate_trending_scores
|
||||
|
||||
now = datetime.now()
|
||||
articles = [
|
||||
NewsArticle(
|
||||
title="Single mention stock",
|
||||
source="Reuters",
|
||||
url="https://reuters.com/article1",
|
||||
published_at=now - timedelta(hours=1),
|
||||
content_snippet="Some company news.",
|
||||
ticker_mentions=["AMD"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="Multiple mention stock 1",
|
||||
source="Bloomberg",
|
||||
url="https://bloomberg.com/article2",
|
||||
published_at=now - timedelta(hours=2),
|
||||
content_snippet="Popular company news.",
|
||||
ticker_mentions=["MSFT"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="Multiple mention stock 2",
|
||||
source="WSJ",
|
||||
url="https://wsj.com/article3",
|
||||
published_at=now - timedelta(hours=3),
|
||||
content_snippet="More popular company news.",
|
||||
ticker_mentions=["MSFT"],
|
||||
),
|
||||
]
|
||||
|
||||
mentions = [
|
||||
EntityMention(
|
||||
company_name="AMD",
|
||||
confidence=0.90,
|
||||
context_snippet="Some company news",
|
||||
article_id="article_0",
|
||||
event_type=EventCategory.OTHER,
|
||||
),
|
||||
EntityMention(
|
||||
company_name="Microsoft",
|
||||
confidence=0.95,
|
||||
context_snippet="Popular company news",
|
||||
article_id="article_1",
|
||||
event_type=EventCategory.OTHER,
|
||||
),
|
||||
EntityMention(
|
||||
company_name="Microsoft Corp",
|
||||
confidence=0.92,
|
||||
context_snippet="More popular company news",
|
||||
article_id="article_2",
|
||||
event_type=EventCategory.OTHER,
|
||||
),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.scorer.resolve_ticker"
|
||||
) as mock_resolve:
|
||||
|
||||
def resolve_side_effect(name):
|
||||
if "AMD" in name or name == "AMD":
|
||||
return "AMD"
|
||||
return "MSFT"
|
||||
|
||||
mock_resolve.side_effect = resolve_side_effect
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.scorer.classify_sector"
|
||||
) as mock_sector:
|
||||
mock_sector.return_value = "technology"
|
||||
|
||||
result = calculate_trending_scores(mentions, articles, min_mentions=2)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].ticker == "MSFT"
|
||||
assert all(stock.mention_count >= 2 for stock in result)
|
||||
|
||||
|
||||
class TestFinalScoreFormulaCorrectness:
|
||||
def test_final_score_formula_correctness(self):
|
||||
from tradingagents.agents.discovery.scorer import calculate_trending_scores
|
||||
|
||||
now = datetime.now()
|
||||
hours_old = 2.0
|
||||
articles = [
|
||||
NewsArticle(
|
||||
title="Test article 1",
|
||||
source="Reuters",
|
||||
url="https://reuters.com/article1",
|
||||
published_at=now - timedelta(hours=hours_old),
|
||||
content_snippet="Google announced results.",
|
||||
ticker_mentions=["GOOGL"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="Test article 2",
|
||||
source="Bloomberg",
|
||||
url="https://bloomberg.com/article2",
|
||||
published_at=now - timedelta(hours=hours_old),
|
||||
content_snippet="Alphabet earnings beat.",
|
||||
ticker_mentions=["GOOGL"],
|
||||
),
|
||||
]
|
||||
|
||||
sentiment_val = 0.6
|
||||
confidence = 0.9
|
||||
mentions = [
|
||||
EntityMention(
|
||||
company_name="Google",
|
||||
confidence=confidence,
|
||||
context_snippet="Google announced results",
|
||||
article_id="article_0",
|
||||
event_type=EventCategory.EARNINGS,
|
||||
sentiment=sentiment_val,
|
||||
),
|
||||
EntityMention(
|
||||
company_name="Alphabet",
|
||||
confidence=confidence,
|
||||
context_snippet="Alphabet earnings beat",
|
||||
article_id="article_1",
|
||||
event_type=EventCategory.EARNINGS,
|
||||
sentiment=sentiment_val,
|
||||
),
|
||||
]
|
||||
|
||||
decay_rate = 0.1
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.scorer.resolve_ticker"
|
||||
) as mock_resolve:
|
||||
mock_resolve.return_value = "GOOGL"
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.scorer.classify_sector"
|
||||
) as mock_sector:
|
||||
mock_sector.return_value = "technology"
|
||||
|
||||
result = calculate_trending_scores(
|
||||
mentions, articles, decay_rate=decay_rate
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
stock = result[0]
|
||||
|
||||
frequency = 2
|
||||
sentiment_factor = 1 + abs(sentiment_val)
|
||||
recency_weight = math.exp(-decay_rate * hours_old)
|
||||
expected_score = frequency * sentiment_factor * recency_weight
|
||||
|
||||
assert abs(stock.score - expected_score) < 0.01
|
||||
|
||||
|
||||
class TestSortingByScoreDescending:
|
||||
def test_results_sorted_by_score_descending(self):
|
||||
from tradingagents.agents.discovery.scorer import calculate_trending_scores
|
||||
|
||||
now = datetime.now()
|
||||
articles = [
|
||||
NewsArticle(
|
||||
title="High score stock 1",
|
||||
source="Reuters",
|
||||
url="https://reuters.com/article1",
|
||||
published_at=now - timedelta(hours=1),
|
||||
content_snippet="Apple news.",
|
||||
ticker_mentions=["AAPL"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="High score stock 2",
|
||||
source="Bloomberg",
|
||||
url="https://bloomberg.com/article2",
|
||||
published_at=now - timedelta(hours=1),
|
||||
content_snippet="More Apple news.",
|
||||
ticker_mentions=["AAPL"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="High score stock 3",
|
||||
source="WSJ",
|
||||
url="https://wsj.com/article3",
|
||||
published_at=now - timedelta(hours=1),
|
||||
content_snippet="Even more Apple news.",
|
||||
ticker_mentions=["AAPL"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="Low score stock 1",
|
||||
source="CNBC",
|
||||
url="https://cnbc.com/article4",
|
||||
published_at=now - timedelta(hours=10),
|
||||
content_snippet="Tesla news.",
|
||||
ticker_mentions=["TSLA"],
|
||||
),
|
||||
NewsArticle(
|
||||
title="Low score stock 2",
|
||||
source="FT",
|
||||
url="https://ft.com/article5",
|
||||
published_at=now - timedelta(hours=10),
|
||||
content_snippet="More Tesla news.",
|
||||
ticker_mentions=["TSLA"],
|
||||
),
|
||||
]
|
||||
|
||||
mentions = [
|
||||
EntityMention(
|
||||
company_name="Apple",
|
||||
confidence=0.95,
|
||||
context_snippet="Apple news",
|
||||
article_id="article_0",
|
||||
event_type=EventCategory.OTHER,
|
||||
sentiment=0.8,
|
||||
),
|
||||
EntityMention(
|
||||
company_name="Apple Inc",
|
||||
confidence=0.93,
|
||||
context_snippet="More Apple news",
|
||||
article_id="article_1",
|
||||
event_type=EventCategory.OTHER,
|
||||
sentiment=0.8,
|
||||
),
|
||||
EntityMention(
|
||||
company_name="Apple",
|
||||
confidence=0.90,
|
||||
context_snippet="Even more Apple news",
|
||||
article_id="article_2",
|
||||
event_type=EventCategory.OTHER,
|
||||
sentiment=0.8,
|
||||
),
|
||||
EntityMention(
|
||||
company_name="Tesla",
|
||||
confidence=0.85,
|
||||
context_snippet="Tesla news",
|
||||
article_id="article_3",
|
||||
event_type=EventCategory.OTHER,
|
||||
sentiment=0.2,
|
||||
),
|
||||
EntityMention(
|
||||
company_name="Tesla Inc",
|
||||
confidence=0.85,
|
||||
context_snippet="More Tesla news",
|
||||
article_id="article_4",
|
||||
event_type=EventCategory.OTHER,
|
||||
sentiment=0.2,
|
||||
),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.scorer.resolve_ticker"
|
||||
) as mock_resolve:
|
||||
|
||||
def resolve_side_effect(name):
|
||||
if "Apple" in name:
|
||||
return "AAPL"
|
||||
if "Tesla" in name:
|
||||
return "TSLA"
|
||||
return None
|
||||
|
||||
mock_resolve.side_effect = resolve_side_effect
|
||||
|
||||
with patch(
|
||||
"tradingagents.agents.discovery.scorer.classify_sector"
|
||||
) as mock_sector:
|
||||
mock_sector.return_value = "technology"
|
||||
|
||||
result = calculate_trending_scores(mentions, articles, min_mentions=2)
|
||||
|
||||
assert len(result) == 2
|
||||
for i in range(len(result) - 1):
|
||||
assert result[i].score >= result[i + 1].score
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from tradingagents.dataflows.trending.sector_classifier import (
|
||||
classify_sector,
|
||||
TICKER_TO_SECTOR,
|
||||
VALID_SECTORS,
|
||||
_llm_classify_sector,
|
||||
_sector_cache,
|
||||
)
|
||||
|
||||
|
||||
class TestStaticSectorMapping:
|
||||
def test_static_sector_mapping_for_known_technology_tickers(self):
|
||||
assert classify_sector("AAPL") == "technology"
|
||||
assert classify_sector("MSFT") == "technology"
|
||||
assert classify_sector("GOOGL") == "technology"
|
||||
assert classify_sector("NVDA") == "technology"
|
||||
|
||||
def test_static_sector_mapping_for_known_healthcare_tickers(self):
|
||||
assert classify_sector("JNJ") == "healthcare"
|
||||
assert classify_sector("PFE") == "healthcare"
|
||||
assert classify_sector("UNH") == "healthcare"
|
||||
|
||||
def test_static_sector_mapping_for_known_finance_tickers(self):
|
||||
assert classify_sector("JPM") == "finance"
|
||||
assert classify_sector("BAC") == "finance"
|
||||
assert classify_sector("GS") == "finance"
|
||||
|
||||
def test_static_sector_mapping_for_known_energy_tickers(self):
|
||||
assert classify_sector("XOM") == "energy"
|
||||
assert classify_sector("CVX") == "energy"
|
||||
assert classify_sector("COP") == "energy"
|
||||
|
||||
def test_static_sector_mapping_case_insensitive(self):
|
||||
assert classify_sector("aapl") == "technology"
|
||||
assert classify_sector("AAPL") == "technology"
|
||||
assert classify_sector("Aapl") == "technology"
|
||||
|
||||
|
||||
class TestLLMFallback:
|
||||
@patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector")
|
||||
def test_llm_fallback_for_unknown_tickers(self, mock_llm_classify):
|
||||
mock_llm_classify.return_value = "technology"
|
||||
_sector_cache.clear()
|
||||
|
||||
result = classify_sector("UNKNOWNTICKER123")
|
||||
|
||||
mock_llm_classify.assert_called_once_with("UNKNOWNTICKER123")
|
||||
assert result == "technology"
|
||||
|
||||
@patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector")
|
||||
def test_llm_fallback_caches_results(self, mock_llm_classify):
|
||||
mock_llm_classify.return_value = "healthcare"
|
||||
_sector_cache.clear()
|
||||
|
||||
result1 = classify_sector("NEWCO123")
|
||||
result2 = classify_sector("NEWCO123")
|
||||
|
||||
assert mock_llm_classify.call_count == 1
|
||||
assert result1 == "healthcare"
|
||||
assert result2 == "healthcare"
|
||||
|
||||
@patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector")
|
||||
def test_llm_fallback_returns_other_on_error(self, mock_llm_classify):
|
||||
mock_llm_classify.side_effect = Exception("LLM error")
|
||||
_sector_cache.clear()
|
||||
|
||||
result = classify_sector("ERRORCO")
|
||||
|
||||
assert result == "other"
|
||||
|
||||
|
||||
class TestAllSectorCategories:
|
||||
def test_all_sector_categories_in_valid_sectors(self):
|
||||
expected_sectors = {
|
||||
"technology",
|
||||
"healthcare",
|
||||
"finance",
|
||||
"energy",
|
||||
"consumer_goods",
|
||||
"industrials",
|
||||
"other",
|
||||
}
|
||||
assert VALID_SECTORS == expected_sectors
|
||||
|
||||
def test_static_mapping_covers_all_sector_categories(self):
|
||||
sectors_in_mapping = set(TICKER_TO_SECTOR.values())
|
||||
assert sectors_in_mapping.issubset(VALID_SECTORS)
|
||||
|
||||
def test_classify_sector_always_returns_valid_sector(self):
|
||||
test_tickers = ["AAPL", "JPM", "XOM", "JNJ", "WMT", "CAT"]
|
||||
for ticker in test_tickers:
|
||||
result = classify_sector(ticker)
|
||||
assert result in VALID_SECTORS
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
import pytest
|
||||
import logging
|
||||
from unittest.mock import patch, MagicMock
|
||||
from tradingagents.dataflows.trending.stock_resolver import (
|
||||
resolve_ticker,
|
||||
validate_us_ticker,
|
||||
_normalize_company_name,
|
||||
_search_yfinance_ticker,
|
||||
)
|
||||
|
||||
|
||||
class TestStaticLookup:
|
||||
def test_static_lookup_for_known_companies(self):
|
||||
assert resolve_ticker("Apple") == "AAPL"
|
||||
assert resolve_ticker("Microsoft") == "MSFT"
|
||||
assert resolve_ticker("Google") == "GOOGL"
|
||||
assert resolve_ticker("Amazon") == "AMZN"
|
||||
assert resolve_ticker("Tesla") == "TSLA"
|
||||
assert resolve_ticker("Nvidia") == "NVDA"
|
||||
|
||||
def test_static_lookup_case_insensitive(self):
|
||||
assert resolve_ticker("APPLE") == "AAPL"
|
||||
assert resolve_ticker("apple") == "AAPL"
|
||||
assert resolve_ticker("ApPlE") == "AAPL"
|
||||
assert resolve_ticker("microsoft") == "MSFT"
|
||||
assert resolve_ticker("MICROSOFT") == "MSFT"
|
||||
|
||||
|
||||
class TestNameVariationHandling:
|
||||
def test_name_variation_handling_with_suffixes(self):
|
||||
assert resolve_ticker("Apple Inc.") == "AAPL"
|
||||
assert resolve_ticker("Apple Inc") == "AAPL"
|
||||
assert resolve_ticker("Apple Corporation") == "AAPL"
|
||||
assert resolve_ticker("Microsoft Corp.") == "MSFT"
|
||||
assert resolve_ticker("Microsoft Corp") == "MSFT"
|
||||
assert resolve_ticker("Tesla Inc") == "TSLA"
|
||||
|
||||
def test_name_variation_handling_informal_names(self):
|
||||
assert resolve_ticker("the iPhone maker") == "AAPL"
|
||||
assert resolve_ticker("iPhone maker") == "AAPL"
|
||||
assert resolve_ticker("the search giant") == "GOOGL"
|
||||
assert resolve_ticker("the e-commerce giant") == "AMZN"
|
||||
assert resolve_ticker("EV maker Tesla") == "TSLA"
|
||||
|
||||
def test_name_variation_handling_alternate_names(self):
|
||||
assert resolve_ticker("Alphabet") == "GOOGL"
|
||||
assert resolve_ticker("Meta") == "META"
|
||||
assert resolve_ticker("Facebook") == "META"
|
||||
assert resolve_ticker("Meta Platforms") == "META"
|
||||
|
||||
|
||||
class TestYfinanceFallback:
|
||||
@patch("tradingagents.dataflows.trending.stock_resolver._search_yfinance_ticker")
|
||||
@patch("tradingagents.dataflows.trending.stock_resolver.validate_us_ticker")
|
||||
def test_yfinance_fallback_for_unknown_company(self, mock_validate, mock_search):
|
||||
mock_search.return_value = "PLTR"
|
||||
mock_validate.return_value = True
|
||||
|
||||
result = resolve_ticker("UnknownTechStartupXYZ")
|
||||
|
||||
mock_search.assert_called_once()
|
||||
assert result == "PLTR"
|
||||
|
||||
@patch("tradingagents.dataflows.trending.stock_resolver._search_yfinance_ticker")
|
||||
def test_yfinance_fallback_returns_none_when_not_found(self, mock_search):
|
||||
mock_search.return_value = None
|
||||
|
||||
result = resolve_ticker("NonexistentCompanyXYZ123")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestUSExchangeValidation:
|
||||
@patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
|
||||
def test_validate_us_ticker_accepts_nyse(self, mock_ticker):
|
||||
mock_info = {"exchange": "NYQ"}
|
||||
mock_ticker.return_value.info = mock_info
|
||||
|
||||
assert validate_us_ticker("IBM") is True
|
||||
|
||||
@patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
|
||||
def test_validate_us_ticker_accepts_nasdaq(self, mock_ticker):
|
||||
mock_info = {"exchange": "NMS"}
|
||||
mock_ticker.return_value.info = mock_info
|
||||
|
||||
assert validate_us_ticker("AAPL") is True
|
||||
|
||||
@patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
|
||||
def test_validate_us_ticker_accepts_amex(self, mock_ticker):
|
||||
mock_info = {"exchange": "ASE"}
|
||||
mock_ticker.return_value.info = mock_info
|
||||
|
||||
assert validate_us_ticker("SPY") is True
|
||||
|
||||
@patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
|
||||
def test_validate_us_ticker_rejects_international(self, mock_ticker):
|
||||
mock_info = {"exchange": "LSE"}
|
||||
mock_ticker.return_value.info = mock_info
|
||||
|
||||
assert validate_us_ticker("VOD.L") is False
|
||||
|
||||
@patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
|
||||
def test_validate_us_ticker_rejects_otc(self, mock_ticker):
|
||||
mock_info = {"exchange": "PNK"}
|
||||
mock_ticker.return_value.info = mock_info
|
||||
|
||||
assert validate_us_ticker("OTCPK") is False
|
||||
|
||||
|
||||
class TestAmbiguousResolutionLogging:
|
||||
def test_ambiguous_resolution_logs_multiple_matches(self, caplog):
|
||||
with caplog.at_level(logging.INFO, logger="tradingagents.dataflows.trending.stock_resolver"):
|
||||
pass
|
||||
|
||||
@patch("tradingagents.dataflows.trending.stock_resolver._search_yfinance_ticker")
|
||||
@patch("tradingagents.dataflows.trending.stock_resolver.validate_us_ticker")
|
||||
def test_yfinance_fallback_is_logged(self, mock_validate, mock_search, caplog):
|
||||
mock_search.return_value = "RBLX"
|
||||
mock_validate.return_value = True
|
||||
|
||||
with caplog.at_level(logging.INFO, logger="tradingagents.dataflows.trending.stock_resolver"):
|
||||
result = resolve_ticker("SomeRandomCompanyNotInMapping")
|
||||
|
||||
assert any("fallback" in record.message.lower() or "yfinance" in record.message.lower()
|
||||
for record in caplog.records)
|
||||
|
||||
@patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
|
||||
def test_validation_failure_is_logged(self, mock_ticker, caplog):
|
||||
mock_info = {"exchange": "LSE"}
|
||||
mock_ticker.return_value.info = mock_info
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="tradingagents.dataflows.trending.stock_resolver"):
|
||||
result = validate_us_ticker("VOD.L")
|
||||
|
||||
assert result is False
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
from .models import (
|
||||
NewsArticle,
|
||||
TrendingStock,
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
DiscoveryStatus,
|
||||
Sector,
|
||||
EventCategory,
|
||||
)
|
||||
from .exceptions import (
|
||||
DiscoveryError,
|
||||
NewsUnavailableError,
|
||||
DiscoveryTimeoutError,
|
||||
TickerResolutionError,
|
||||
)
|
||||
from .entity_extractor import (
|
||||
EntityMention,
|
||||
extract_entities,
|
||||
BATCH_SIZE,
|
||||
)
|
||||
from .scorer import (
|
||||
calculate_trending_scores,
|
||||
DEFAULT_DECAY_RATE,
|
||||
DEFAULT_MAX_RESULTS,
|
||||
DEFAULT_MIN_MENTIONS,
|
||||
)
|
||||
from .persistence import (
|
||||
save_discovery_result,
|
||||
generate_markdown_summary,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"NewsArticle",
|
||||
"TrendingStock",
|
||||
"DiscoveryRequest",
|
||||
"DiscoveryResult",
|
||||
"DiscoveryStatus",
|
||||
"Sector",
|
||||
"EventCategory",
|
||||
"DiscoveryError",
|
||||
"NewsUnavailableError",
|
||||
"DiscoveryTimeoutError",
|
||||
"TickerResolutionError",
|
||||
"EntityMention",
|
||||
"extract_entities",
|
||||
"BATCH_SIZE",
|
||||
"calculate_trending_scores",
|
||||
"DEFAULT_DECAY_RATE",
|
||||
"DEFAULT_MAX_RESULTS",
|
||||
"DEFAULT_MIN_MENTIONS",
|
||||
"save_discovery_result",
|
||||
"generate_markdown_summary",
|
||||
]
|
||||
|
|
@ -0,0 +1,159 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.agents.discovery.models import NewsArticle, EventCategory
|
||||
|
||||
|
||||
BATCH_SIZE = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityMention:
|
||||
company_name: str
|
||||
confidence: float
|
||||
context_snippet: str
|
||||
article_id: str
|
||||
event_type: EventCategory
|
||||
sentiment: float = field(default=0.0)
|
||||
|
||||
|
||||
class ExtractedEntity(BaseModel):
|
||||
company_name: str = PydanticField(description="The name of the publicly traded company mentioned")
|
||||
confidence: float = PydanticField(description="Confidence score from 0.0 to 1.0 based on mention clarity")
|
||||
context_snippet: str = PydanticField(description="Surrounding context of 50-100 characters around the company mention")
|
||||
event_type: str = PydanticField(description="Event category: earnings, merger_acquisition, regulatory, product_launch, executive_change, or other")
|
||||
sentiment: float = PydanticField(default=0.0, description="Sentiment score from -1.0 (negative) to 1.0 (positive)")
|
||||
|
||||
|
||||
class ExtractionResponse(BaseModel):
|
||||
entities: List[ExtractedEntity] = PydanticField(default_factory=list, description="List of extracted company entities")
|
||||
|
||||
|
||||
def _get_llm(config: Optional[dict] = None):
|
||||
cfg = config or DEFAULT_CONFIG
|
||||
provider = cfg.get("llm_provider", "openai").lower()
|
||||
model = cfg.get("quick_think_llm", "gpt-4o-mini")
|
||||
backend_url = cfg.get("backend_url", "https://api.openai.com/v1")
|
||||
|
||||
if provider in ("openai", "ollama", "openrouter"):
|
||||
return ChatOpenAI(model=model, base_url=backend_url)
|
||||
elif provider == "anthropic":
|
||||
return ChatAnthropic(model=model, base_url=backend_url)
|
||||
elif provider == "google":
|
||||
return ChatGoogleGenerativeAI(model=model)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
|
||||
EXTRACTION_PROMPT = """You are an expert at identifying publicly traded companies mentioned in news articles.
|
||||
|
||||
For each article provided, extract all mentions of publicly traded companies. For each company mention:
|
||||
|
||||
1. Extract the company name as it appears (e.g., "Apple Inc.", "Apple", "AAPL", "the iPhone maker")
|
||||
2. Assign a confidence score from 0.0 to 1.0 based on how clearly the company is mentioned:
|
||||
- 0.9-1.0: Direct company name or ticker symbol
|
||||
- 0.7-0.9: Clear reference with context (e.g., "the Cupertino tech giant")
|
||||
- 0.5-0.7: Indirect reference requiring inference
|
||||
- Below 0.5: Uncertain or ambiguous reference
|
||||
3. Extract 50-100 characters of surrounding context
|
||||
4. Classify the event type:
|
||||
- earnings: Quarterly/annual earnings reports, revenue announcements
|
||||
- merger_acquisition: Mergers, acquisitions, buyouts, takeovers
|
||||
- regulatory: SEC filings, government investigations, compliance issues
|
||||
- product_launch: New products, services, or features
|
||||
- executive_change: CEO/CFO changes, board appointments, departures
|
||||
- other: Any other business news
|
||||
5. Assign a sentiment score from -1.0 to 1.0:
|
||||
- -1.0: Very negative news (lawsuits, crashes, major failures)
|
||||
- -0.5: Moderately negative news
|
||||
- 0.0: Neutral news
|
||||
- 0.5: Moderately positive news
|
||||
- 1.0: Very positive news (breakthroughs, record earnings)
|
||||
|
||||
Only extract companies that are publicly traded on major stock exchanges.
|
||||
Handle name variations by providing the most complete company name found.
|
||||
|
||||
Articles to analyze:
|
||||
{articles_text}
|
||||
|
||||
Extract all company mentions from the articles above."""
|
||||
|
||||
|
||||
def _format_articles_for_prompt(articles: List[NewsArticle], start_idx: int) -> str:
|
||||
formatted = []
|
||||
for i, article in enumerate(articles):
|
||||
article_id = f"article_{start_idx + i}"
|
||||
formatted.append(
|
||||
f"[{article_id}]\n"
|
||||
f"Title: {article.title}\n"
|
||||
f"Source: {article.source}\n"
|
||||
f"Content: {article.content_snippet}\n"
|
||||
)
|
||||
return "\n---\n".join(formatted)
|
||||
|
||||
|
||||
def _extract_batch(
|
||||
articles: List[NewsArticle],
|
||||
start_idx: int,
|
||||
llm,
|
||||
) -> List[EntityMention]:
|
||||
if not articles:
|
||||
return []
|
||||
|
||||
articles_text = _format_articles_for_prompt(articles, start_idx)
|
||||
prompt = EXTRACTION_PROMPT.format(articles_text=articles_text)
|
||||
|
||||
structured_llm = llm.with_structured_output(ExtractionResponse)
|
||||
response = structured_llm.invoke(prompt)
|
||||
|
||||
mentions = []
|
||||
for entity in response.entities:
|
||||
event_type_str = entity.event_type.lower().strip()
|
||||
valid_event_types = {e.value for e in EventCategory}
|
||||
if event_type_str not in valid_event_types:
|
||||
event_type_str = "other"
|
||||
|
||||
confidence = max(0.0, min(1.0, entity.confidence))
|
||||
sentiment = max(-1.0, min(1.0, entity.sentiment))
|
||||
|
||||
context = entity.context_snippet
|
||||
if len(context) > 150:
|
||||
context = context[:147] + "..."
|
||||
|
||||
mention = EntityMention(
|
||||
company_name=entity.company_name,
|
||||
confidence=confidence,
|
||||
context_snippet=context,
|
||||
article_id=f"article_{start_idx}",
|
||||
event_type=EventCategory(event_type_str),
|
||||
sentiment=sentiment,
|
||||
)
|
||||
mentions.append(mention)
|
||||
|
||||
return mentions
|
||||
|
||||
|
||||
def extract_entities(
|
||||
articles: List[NewsArticle],
|
||||
config: Optional[dict] = None,
|
||||
) -> List[EntityMention]:
|
||||
if not articles:
|
||||
return []
|
||||
|
||||
llm = _get_llm(config)
|
||||
all_mentions: List[EntityMention] = []
|
||||
|
||||
for batch_start in range(0, len(articles), BATCH_SIZE):
|
||||
batch_end = min(batch_start + BATCH_SIZE, len(articles))
|
||||
batch = articles[batch_start:batch_end]
|
||||
|
||||
batch_mentions = _extract_batch(batch, batch_start, llm)
|
||||
all_mentions.extend(batch_mentions)
|
||||
|
||||
return all_mentions
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
class DiscoveryError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NewsUnavailableError(DiscoveryError):
|
||||
pass
|
||||
|
||||
|
||||
class DiscoveryTimeoutError(DiscoveryError):
|
||||
pass
|
||||
|
||||
|
||||
class TickerResolutionError(DiscoveryError):
|
||||
pass
|
||||
|
|
@ -0,0 +1,180 @@
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
|
||||
class DiscoveryStatus(Enum):
|
||||
CREATED = "created"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class Sector(Enum):
|
||||
TECHNOLOGY = "technology"
|
||||
HEALTHCARE = "healthcare"
|
||||
FINANCE = "finance"
|
||||
ENERGY = "energy"
|
||||
CONSUMER_GOODS = "consumer_goods"
|
||||
INDUSTRIALS = "industrials"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class EventCategory(Enum):
|
||||
EARNINGS = "earnings"
|
||||
MERGER_ACQUISITION = "merger_acquisition"
|
||||
REGULATORY = "regulatory"
|
||||
PRODUCT_LAUNCH = "product_launch"
|
||||
EXECUTIVE_CHANGE = "executive_change"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewsArticle:
|
||||
title: str
|
||||
source: str
|
||||
url: str
|
||||
published_at: datetime
|
||||
content_snippet: str
|
||||
ticker_mentions: List[str]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"title": self.title,
|
||||
"source": self.source,
|
||||
"url": self.url,
|
||||
"published_at": self.published_at.isoformat(),
|
||||
"content_snippet": self.content_snippet,
|
||||
"ticker_mentions": self.ticker_mentions,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "NewsArticle":
|
||||
return cls(
|
||||
title=data["title"],
|
||||
source=data["source"],
|
||||
url=data["url"],
|
||||
published_at=datetime.fromisoformat(data["published_at"]),
|
||||
content_snippet=data["content_snippet"],
|
||||
ticker_mentions=data["ticker_mentions"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrendingStock:
|
||||
ticker: str
|
||||
company_name: str
|
||||
score: float
|
||||
mention_count: int
|
||||
sentiment: float
|
||||
sector: Sector
|
||||
event_type: EventCategory
|
||||
news_summary: str
|
||||
source_articles: List[NewsArticle]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"ticker": self.ticker,
|
||||
"company_name": self.company_name,
|
||||
"score": self.score,
|
||||
"mention_count": self.mention_count,
|
||||
"sentiment": self.sentiment,
|
||||
"sector": self.sector.value,
|
||||
"event_type": self.event_type.value,
|
||||
"news_summary": self.news_summary,
|
||||
"source_articles": [article.to_dict() for article in self.source_articles],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "TrendingStock":
|
||||
return cls(
|
||||
ticker=data["ticker"],
|
||||
company_name=data["company_name"],
|
||||
score=data["score"],
|
||||
mention_count=data["mention_count"],
|
||||
sentiment=data["sentiment"],
|
||||
sector=Sector(data["sector"]),
|
||||
event_type=EventCategory(data["event_type"]),
|
||||
news_summary=data["news_summary"],
|
||||
source_articles=[
|
||||
NewsArticle.from_dict(article) for article in data["source_articles"]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscoveryRequest:
|
||||
lookback_period: str
|
||||
sector_filter: Optional[List[Sector]] = None
|
||||
event_filter: Optional[List[EventCategory]] = None
|
||||
max_results: int = 20
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"lookback_period": self.lookback_period,
|
||||
"sector_filter": (
|
||||
[s.value for s in self.sector_filter] if self.sector_filter else None
|
||||
),
|
||||
"event_filter": (
|
||||
[e.value for e in self.event_filter] if self.event_filter else None
|
||||
),
|
||||
"max_results": self.max_results,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "DiscoveryRequest":
|
||||
return cls(
|
||||
lookback_period=data["lookback_period"],
|
||||
sector_filter=(
|
||||
[Sector(s) for s in data["sector_filter"]]
|
||||
if data.get("sector_filter")
|
||||
else None
|
||||
),
|
||||
event_filter=(
|
||||
[EventCategory(e) for e in data["event_filter"]]
|
||||
if data.get("event_filter")
|
||||
else None
|
||||
),
|
||||
max_results=data.get("max_results", 20),
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiscoveryResult:
|
||||
request: DiscoveryRequest
|
||||
trending_stocks: List[TrendingStock]
|
||||
status: DiscoveryStatus
|
||||
started_at: datetime
|
||||
completed_at: Optional[datetime] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"request": self.request.to_dict(),
|
||||
"trending_stocks": [stock.to_dict() for stock in self.trending_stocks],
|
||||
"status": self.status.value,
|
||||
"started_at": self.started_at.isoformat(),
|
||||
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||
"error_message": self.error_message,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "DiscoveryResult":
|
||||
return cls(
|
||||
request=DiscoveryRequest.from_dict(data["request"]),
|
||||
trending_stocks=[
|
||||
TrendingStock.from_dict(stock) for stock in data["trending_stocks"]
|
||||
],
|
||||
status=DiscoveryStatus(data["status"]),
|
||||
started_at=datetime.fromisoformat(data["started_at"]),
|
||||
completed_at=(
|
||||
datetime.fromisoformat(data["completed_at"])
|
||||
if data.get("completed_at")
|
||||
else None
|
||||
),
|
||||
error_message=data.get("error_message"),
|
||||
)
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .models import DiscoveryResult, TrendingStock
|
||||
|
||||
|
||||
def save_discovery_result(
|
||||
result: DiscoveryResult,
|
||||
base_path: Optional[Path] = None,
|
||||
) -> Path:
|
||||
if base_path is None:
|
||||
base_path = Path("results")
|
||||
|
||||
timestamp = result.completed_at or result.started_at
|
||||
date_str = timestamp.strftime("%Y-%m-%d")
|
||||
time_str = timestamp.strftime("%H-%M-%S")
|
||||
|
||||
result_dir = base_path / "discovery" / date_str / time_str
|
||||
result_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
json_path = result_dir / "discovery_result.json"
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(result.to_dict(), f, indent=2)
|
||||
|
||||
md_path = result_dir / "discovery_summary.md"
|
||||
markdown_content = generate_markdown_summary(result)
|
||||
with open(md_path, "w") as f:
|
||||
f.write(markdown_content)
|
||||
|
||||
return result_dir
|
||||
|
||||
|
||||
def generate_markdown_summary(result: DiscoveryResult) -> str:
|
||||
lines = []
|
||||
|
||||
lines.append("# Discovery Results")
|
||||
lines.append("")
|
||||
|
||||
timestamp = result.completed_at or result.started_at
|
||||
lines.append(f"**Timestamp:** {timestamp.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
lines.append(f"**Lookback Period:** {result.request.lookback_period}")
|
||||
|
||||
filters = _format_filters(result)
|
||||
lines.append(f"**Filters:** {filters}")
|
||||
lines.append(f"**Total Stocks Found:** {len(result.trending_stocks)}")
|
||||
lines.append("")
|
||||
|
||||
lines.append("## Trending Stocks")
|
||||
lines.append("")
|
||||
lines.append("| Rank | Ticker | Company | Score | Mentions | Event |")
|
||||
lines.append("|------|--------|---------|-------|----------|-------|")
|
||||
|
||||
for rank, stock in enumerate(result.trending_stocks, 1):
|
||||
lines.append(
|
||||
f"| {rank} | {stock.ticker} | {stock.company_name} | "
|
||||
f"{stock.score:.2f} | {stock.mention_count} | {stock.event_type.value} |"
|
||||
)
|
||||
|
||||
lines.append("")
|
||||
|
||||
lines.append("## Top 3 Detailed Analysis")
|
||||
lines.append("")
|
||||
|
||||
top_stocks = result.trending_stocks[:3]
|
||||
for rank, stock in enumerate(top_stocks, 1):
|
||||
lines.extend(_format_stock_detail(rank, stock))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_filters(result: DiscoveryResult) -> str:
|
||||
filter_parts = []
|
||||
|
||||
if result.request.sector_filter:
|
||||
sector_values = [s.value for s in result.request.sector_filter]
|
||||
filter_parts.append(f"sector={','.join(sector_values)}")
|
||||
|
||||
if result.request.event_filter:
|
||||
event_values = [e.value for e in result.request.event_filter]
|
||||
filter_parts.append(f"event={','.join(event_values)}")
|
||||
|
||||
if filter_parts:
|
||||
return " ".join(filter_parts)
|
||||
return "None"
|
||||
|
||||
|
||||
def _format_stock_detail(rank: int, stock: TrendingStock) -> list:
|
||||
lines = []
|
||||
|
||||
lines.append(f"### {rank}. {stock.ticker} - {stock.company_name}")
|
||||
lines.append(f"- **Score:** {stock.score:.2f}")
|
||||
|
||||
sentiment_label = _get_sentiment_label(stock.sentiment)
|
||||
lines.append(f"- **Sentiment:** {stock.sentiment:.2f} ({sentiment_label})")
|
||||
lines.append(f"- **Sector:** {stock.sector.value}")
|
||||
lines.append(f"- **Event Type:** {stock.event_type.value}")
|
||||
lines.append(f"- **Mentions:** {stock.mention_count}")
|
||||
lines.append("")
|
||||
|
||||
lines.append("**News Summary:**")
|
||||
lines.append(stock.news_summary)
|
||||
lines.append("")
|
||||
|
||||
if stock.source_articles:
|
||||
lines.append("**Top Sources:**")
|
||||
for article in stock.source_articles[:3]:
|
||||
lines.append(f"- [{article.title}] - {article.source}")
|
||||
lines.append("")
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _get_sentiment_label(sentiment: float) -> str:
|
||||
if sentiment > 0.3:
|
||||
return "positive"
|
||||
elif sentiment < -0.3:
|
||||
return "negative"
|
||||
return "neutral"
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
import math
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from tradingagents.agents.discovery.models import (
|
||||
TrendingStock,
|
||||
NewsArticle,
|
||||
Sector,
|
||||
EventCategory,
|
||||
)
|
||||
from tradingagents.agents.discovery.entity_extractor import EntityMention
|
||||
from tradingagents.dataflows.trending.stock_resolver import resolve_ticker
|
||||
from tradingagents.dataflows.trending.sector_classifier import classify_sector
|
||||
|
||||
|
||||
DEFAULT_DECAY_RATE = 0.1
|
||||
DEFAULT_MAX_RESULTS = 20
|
||||
DEFAULT_MIN_MENTIONS = 2
|
||||
|
||||
|
||||
def _aggregate_sentiment(mentions: List[EntityMention]) -> float:
|
||||
if not mentions:
|
||||
return 0.0
|
||||
|
||||
total_weighted_sentiment = 0.0
|
||||
total_confidence = 0.0
|
||||
|
||||
for mention in mentions:
|
||||
total_weighted_sentiment += mention.sentiment * mention.confidence
|
||||
total_confidence += mention.confidence
|
||||
|
||||
if total_confidence == 0:
|
||||
return 0.0
|
||||
|
||||
return total_weighted_sentiment / total_confidence
|
||||
|
||||
|
||||
def _calculate_recency_weight(
|
||||
articles: List[NewsArticle],
|
||||
article_ids: set,
|
||||
decay_rate: float,
|
||||
) -> float:
|
||||
if not articles:
|
||||
return 1.0
|
||||
|
||||
now = datetime.now()
|
||||
weights = []
|
||||
|
||||
for i, article in enumerate(articles):
|
||||
article_id = f"article_{i}"
|
||||
if article_id in article_ids:
|
||||
hours_old = (now - article.published_at).total_seconds() / 3600.0
|
||||
weight = math.exp(-decay_rate * hours_old)
|
||||
weights.append(weight)
|
||||
|
||||
if not weights:
|
||||
return 1.0
|
||||
|
||||
return sum(weights) / len(weights)
|
||||
|
||||
|
||||
def _get_most_common_event_type(mentions: List[EntityMention]) -> EventCategory:
|
||||
if not mentions:
|
||||
return EventCategory.OTHER
|
||||
|
||||
event_counts: Dict[EventCategory, int] = defaultdict(int)
|
||||
for mention in mentions:
|
||||
event_counts[mention.event_type] += 1
|
||||
|
||||
return max(event_counts.keys(), key=lambda e: event_counts[e])
|
||||
|
||||
|
||||
def _build_news_summary(mentions: List[EntityMention]) -> str:
|
||||
if not mentions:
|
||||
return ""
|
||||
|
||||
snippets = [m.context_snippet for m in mentions[:3]]
|
||||
return " ".join(snippets)
|
||||
|
||||
|
||||
def calculate_trending_scores(
|
||||
mentions: List[EntityMention],
|
||||
articles: List[NewsArticle],
|
||||
decay_rate: float = DEFAULT_DECAY_RATE,
|
||||
max_results: int = DEFAULT_MAX_RESULTS,
|
||||
min_mentions: int = DEFAULT_MIN_MENTIONS,
|
||||
) -> List[TrendingStock]:
|
||||
if not mentions:
|
||||
return []
|
||||
|
||||
ticker_mentions: Dict[str, List[EntityMention]] = defaultdict(list)
|
||||
ticker_company_names: Dict[str, str] = {}
|
||||
|
||||
for mention in mentions:
|
||||
ticker = resolve_ticker(mention.company_name)
|
||||
if ticker:
|
||||
ticker_mentions[ticker].append(mention)
|
||||
if ticker not in ticker_company_names:
|
||||
ticker_company_names[ticker] = mention.company_name
|
||||
|
||||
article_index: Dict[str, int] = {}
|
||||
for i, article in enumerate(articles):
|
||||
article_index[f"article_{i}"] = i
|
||||
|
||||
trending_stocks: List[TrendingStock] = []
|
||||
|
||||
for ticker, ticker_mention_list in ticker_mentions.items():
|
||||
article_ids = {m.article_id for m in ticker_mention_list}
|
||||
frequency = len(article_ids)
|
||||
|
||||
if frequency < min_mentions:
|
||||
continue
|
||||
|
||||
sentiment = _aggregate_sentiment(ticker_mention_list)
|
||||
sentiment_factor = 1 + abs(sentiment)
|
||||
|
||||
recency_weight = _calculate_recency_weight(articles, article_ids, decay_rate)
|
||||
|
||||
score = frequency * sentiment_factor * recency_weight
|
||||
|
||||
sector_str = classify_sector(ticker)
|
||||
try:
|
||||
sector = Sector(sector_str)
|
||||
except ValueError:
|
||||
sector = Sector.OTHER
|
||||
|
||||
event_type = _get_most_common_event_type(ticker_mention_list)
|
||||
|
||||
source_article_list: List[NewsArticle] = []
|
||||
for article_id in article_ids:
|
||||
idx = article_index.get(article_id)
|
||||
if idx is not None and idx < len(articles):
|
||||
source_article_list.append(articles[idx])
|
||||
|
||||
news_summary = _build_news_summary(ticker_mention_list)
|
||||
|
||||
trending_stock = TrendingStock(
|
||||
ticker=ticker,
|
||||
company_name=ticker_company_names.get(ticker, ticker),
|
||||
score=score,
|
||||
mention_count=frequency,
|
||||
sentiment=sentiment,
|
||||
sector=sector,
|
||||
event_type=event_type,
|
||||
news_summary=news_summary,
|
||||
source_articles=source_article_list,
|
||||
)
|
||||
trending_stocks.append(trending_stock)
|
||||
|
||||
trending_stocks.sort(key=lambda s: s.score, reverse=True)
|
||||
|
||||
return trending_stocks[:max_results]
|
||||
|
|
@ -7,44 +7,44 @@ from langgraph.prebuilt import ToolNode
|
|||
from langgraph.graph import END, StateGraph, START, MessagesState
|
||||
|
||||
|
||||
# Researcher team state
|
||||
class InvestDebateState(TypedDict):
|
||||
"""Researcher team state"""
|
||||
bull_history: Annotated[
|
||||
str, "Bullish Conversation history"
|
||||
] # Bullish Conversation history
|
||||
]
|
||||
bear_history: Annotated[
|
||||
str, "Bearish Conversation history"
|
||||
] # Bullish Conversation history
|
||||
history: Annotated[str, "Conversation history"] # Conversation history
|
||||
current_response: Annotated[str, "Latest response"] # Last response
|
||||
judge_decision: Annotated[str, "Final judge decision"] # Last response
|
||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
||||
]
|
||||
history: Annotated[str, "Conversation history"]
|
||||
current_response: Annotated[str, "Latest response"]
|
||||
judge_decision: Annotated[str, "Final judge decision"]
|
||||
count: Annotated[int, "Length of the current conversation"]
|
||||
|
||||
|
||||
# Risk management team state
|
||||
class RiskDebateState(TypedDict):
|
||||
"""Risk management team state"""
|
||||
risky_history: Annotated[
|
||||
str, "Risky Agent's Conversation history"
|
||||
] # Conversation history
|
||||
]
|
||||
safe_history: Annotated[
|
||||
str, "Safe Agent's Conversation history"
|
||||
] # Conversation history
|
||||
]
|
||||
neutral_history: Annotated[
|
||||
str, "Neutral Agent's Conversation history"
|
||||
] # Conversation history
|
||||
history: Annotated[str, "Conversation history"] # Conversation history
|
||||
]
|
||||
history: Annotated[str, "Conversation history"]
|
||||
latest_speaker: Annotated[str, "Analyst that spoke last"]
|
||||
current_risky_response: Annotated[
|
||||
str, "Latest response by the risky analyst"
|
||||
] # Last response
|
||||
]
|
||||
current_safe_response: Annotated[
|
||||
str, "Latest response by the safe analyst"
|
||||
] # Last response
|
||||
]
|
||||
current_neutral_response: Annotated[
|
||||
str, "Latest response by the neutral analyst"
|
||||
] # Last response
|
||||
]
|
||||
judge_decision: Annotated[str, "Judge's decision"]
|
||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
||||
count: Annotated[int, "Length of the current conversation"]
|
||||
|
||||
|
||||
class AgentState(MessagesState):
|
||||
|
|
@ -53,7 +53,7 @@ class AgentState(MessagesState):
|
|||
|
||||
sender: Annotated[str, "Agent that sent this message"]
|
||||
|
||||
# research step
|
||||
# research
|
||||
market_report: Annotated[str, "Report from the Market Analyst"]
|
||||
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
|
||||
news_report: Annotated[
|
||||
|
|
@ -61,7 +61,7 @@ class AgentState(MessagesState):
|
|||
]
|
||||
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
|
||||
|
||||
# researcher team discussion step
|
||||
# research
|
||||
investment_debate_state: Annotated[
|
||||
InvestDebateState, "Current state of the debate on if to invest or not"
|
||||
]
|
||||
|
|
@ -69,7 +69,7 @@ class AgentState(MessagesState):
|
|||
|
||||
trader_investment_plan: Annotated[str, "Plan generated by the Trader"]
|
||||
|
||||
# risk management team discussion step
|
||||
# risk mgmt
|
||||
risk_debate_state: Annotated[
|
||||
RiskDebateState, "Current state of the debate on evaluating risk"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||
|
||||
# Import tools from separate utility files
|
||||
from tradingagents.agents.utils.core_stock_tools import (
|
||||
get_stock_data
|
||||
)
|
||||
|
|
@ -24,16 +23,7 @@ def create_msg_delete():
|
|||
def delete_messages(state):
|
||||
"""Clear messages and add placeholder for Anthropic compatibility"""
|
||||
messages = state["messages"]
|
||||
|
||||
# Remove all messages
|
||||
removal_operations = [RemoveMessage(id=m.id) for m in messages]
|
||||
|
||||
# Add a minimal placeholder message
|
||||
placeholder = HumanMessage(content="Continue")
|
||||
|
||||
return {"messages": removal_operations + [placeholder]}
|
||||
|
||||
return delete_messages
|
||||
|
||||
|
||||
|
||||
|
|
@ -67,47 +67,45 @@ class FinancialSituationMemory:
|
|||
return matched_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
matcher = FinancialSituationMemory()
|
||||
# if __name__ == "__main__":
|
||||
# # Example usage
|
||||
# matcher = FinancialSituationMemory()
|
||||
# example_data = [
|
||||
# (
|
||||
# "High inflation rate with rising interest rates and declining consumer spending",
|
||||
# "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
|
||||
# ),
|
||||
# (
|
||||
# "Tech sector showing high volatility with increasing institutional selling pressure",
|
||||
# "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
|
||||
# ),
|
||||
# (
|
||||
# "Strong dollar affecting emerging markets with increasing forex volatility",
|
||||
# "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
|
||||
# ),
|
||||
# (
|
||||
# "Market showing signs of sector rotation with rising yields",
|
||||
# "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
|
||||
# ),
|
||||
# ]
|
||||
|
||||
# Example data
|
||||
example_data = [
|
||||
(
|
||||
"High inflation rate with rising interest rates and declining consumer spending",
|
||||
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
|
||||
),
|
||||
(
|
||||
"Tech sector showing high volatility with increasing institutional selling pressure",
|
||||
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
|
||||
),
|
||||
(
|
||||
"Strong dollar affecting emerging markets with increasing forex volatility",
|
||||
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
|
||||
),
|
||||
(
|
||||
"Market showing signs of sector rotation with rising yields",
|
||||
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
|
||||
),
|
||||
]
|
||||
# # Add the example situations and recommendations
|
||||
# matcher.add_situations(example_data)
|
||||
|
||||
# Add the example situations and recommendations
|
||||
matcher.add_situations(example_data)
|
||||
# # Example query
|
||||
# current_situation = """
|
||||
# Market showing increased volatility in tech sector, with institutional investors
|
||||
# reducing positions and rising interest rates affecting growth stock valuations
|
||||
# """
|
||||
|
||||
# Example query
|
||||
current_situation = """
|
||||
Market showing increased volatility in tech sector, with institutional investors
|
||||
reducing positions and rising interest rates affecting growth stock valuations
|
||||
"""
|
||||
# try:
|
||||
# recommendations = matcher.get_memories(current_situation, n_matches=2)
|
||||
|
||||
try:
|
||||
recommendations = matcher.get_memories(current_situation, n_matches=2)
|
||||
# for i, rec in enumerate(recommendations, 1):
|
||||
# print(f"\nMatch {i}:")
|
||||
# print(f"Similarity Score: {rec['similarity_score']:.2f}")
|
||||
# print(f"Matched Situation: {rec['matched_situation']}")
|
||||
# print(f"Recommendation: {rec['recommendation']}")
|
||||
|
||||
for i, rec in enumerate(recommendations, 1):
|
||||
print(f"\nMatch {i}:")
|
||||
print(f"Similarity Score: {rec['similarity_score']:.2f}")
|
||||
print(f"Matched Situation: {rec['matched_situation']}")
|
||||
print(f"Recommendation: {rec['recommendation']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during recommendation: {str(e)}")
|
||||
# except Exception as e:
|
||||
# print(f"Error during recommendation: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -1,19 +1,9 @@
|
|||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
|
||||
|
||||
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
||||
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.
|
||||
|
||||
Covers stocks, cryptocurrencies, forex, and topics like fiscal policy, mergers & acquisitions, IPOs.
|
||||
|
||||
Args:
|
||||
ticker: Stock symbol for news articles.
|
||||
start_date: Start date for news search.
|
||||
end_date: End date for news search.
|
||||
|
||||
Returns:
|
||||
Dictionary containing news sentiment data or JSON string.
|
||||
"""
|
||||
|
||||
params = {
|
||||
"tickers": ticker,
|
||||
"time_from": format_datetime_for_api(start_date),
|
||||
|
|
@ -21,23 +11,63 @@ def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
|||
"sort": "LATEST",
|
||||
"limit": "50",
|
||||
}
|
||||
|
||||
|
||||
return _make_api_request("NEWS_SENTIMENT", params)
|
||||
|
||||
def get_insider_transactions(symbol: str) -> dict[str, str] | str:
|
||||
"""Returns latest and historical insider transactions by key stakeholders.
|
||||
|
||||
Covers transactions by founders, executives, board members, etc.
|
||||
|
||||
Args:
|
||||
symbol: Ticker symbol. Example: "IBM".
|
||||
|
||||
Returns:
|
||||
Dictionary containing insider transaction data or JSON string.
|
||||
"""
|
||||
|
||||
params = {
|
||||
"symbol": symbol,
|
||||
}
|
||||
|
||||
return _make_api_request("INSIDER_TRANSACTIONS", params)
|
||||
return _make_api_request("INSIDER_TRANSACTIONS", params)
|
||||
|
||||
|
||||
def get_bulk_news_alpha_vantage(lookback_hours: int) -> List[Dict[str, Any]]:
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(hours=lookback_hours)
|
||||
|
||||
params = {
|
||||
"time_from": format_datetime_for_api(start_date),
|
||||
"time_to": format_datetime_for_api(end_date),
|
||||
"sort": "LATEST",
|
||||
"limit": "200",
|
||||
"topics": "financial_markets,earnings,economy_fiscal,economy_monetary,mergers_and_acquisitions",
|
||||
}
|
||||
|
||||
response = _make_api_request("NEWS_SENTIMENT", params)
|
||||
|
||||
if isinstance(response, str):
|
||||
try:
|
||||
response = json.loads(response)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
if not isinstance(response, dict):
|
||||
return []
|
||||
|
||||
feed = response.get("feed", [])
|
||||
|
||||
articles = []
|
||||
for item in feed:
|
||||
try:
|
||||
time_published = item.get("time_published", "")
|
||||
if time_published:
|
||||
try:
|
||||
published_at = datetime.strptime(time_published, "%Y%m%dT%H%M%S")
|
||||
except ValueError:
|
||||
published_at = datetime.now()
|
||||
else:
|
||||
published_at = datetime.now()
|
||||
|
||||
article = {
|
||||
"title": item.get("title", ""),
|
||||
"source": item.get("source", ""),
|
||||
"url": item.get("url", ""),
|
||||
"published_at": published_at.isoformat(),
|
||||
"content_snippet": item.get("summary", "")[:500],
|
||||
}
|
||||
articles.append(article)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return articles
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Annotated
|
||||
from datetime import datetime
|
||||
from typing import Annotated, List, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from .googlenews_utils import getNewsData
|
||||
|
||||
|
|
@ -27,4 +27,53 @@ def get_google_news(
|
|||
if len(news_results) == 0:
|
||||
return ""
|
||||
|
||||
return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"
|
||||
return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"
|
||||
|
||||
|
||||
def get_bulk_news_google(lookback_hours: int) -> List[Dict[str, Any]]:
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(hours=lookback_hours)
|
||||
|
||||
start_str = start_date.strftime("%Y-%m-%d")
|
||||
end_str = end_date.strftime("%Y-%m-%d")
|
||||
|
||||
queries = [
|
||||
"stock market",
|
||||
"trading news",
|
||||
"earnings report",
|
||||
]
|
||||
|
||||
all_articles = []
|
||||
seen_titles = set()
|
||||
|
||||
for query in queries:
|
||||
try:
|
||||
news_results = getNewsData(query.replace(" ", "+"), start_str, end_str)
|
||||
|
||||
for news in news_results:
|
||||
title = news.get("title", "")
|
||||
if title and title not in seen_titles:
|
||||
seen_titles.add(title)
|
||||
|
||||
date_str = news.get("date", "")
|
||||
try:
|
||||
if date_str:
|
||||
published_at = datetime.now()
|
||||
else:
|
||||
published_at = datetime.now()
|
||||
except ValueError:
|
||||
published_at = datetime.now()
|
||||
|
||||
article = {
|
||||
"title": title,
|
||||
"source": news.get("source", "Google News"),
|
||||
"url": news.get("link", ""),
|
||||
"published_at": published_at.isoformat(),
|
||||
"content_snippet": news.get("snippet", "")[:500],
|
||||
}
|
||||
all_articles.append(article)
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return all_articles
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Annotated
|
||||
from typing import Annotated, List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Import from vendor-specific modules
|
||||
from .local import get_YFin_data, get_finnhub_news, get_finnhub_company_insider_sentiment, get_finnhub_company_insider_transactions, get_simfin_balance_sheet, get_simfin_cashflow, get_simfin_income_statements, get_reddit_global_news, get_reddit_company_news
|
||||
from .y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
|
||||
from .google import get_google_news
|
||||
from .openai import get_stock_news_openai, get_global_news_openai, get_fundamentals_openai
|
||||
from .google import get_google_news, get_bulk_news_google
|
||||
from .openai import get_stock_news_openai, get_global_news_openai, get_fundamentals_openai, get_bulk_news_openai
|
||||
from .alpha_vantage import (
|
||||
get_stock as get_alpha_vantage_stock,
|
||||
get_indicator as get_alpha_vantage_indicator,
|
||||
|
|
@ -15,12 +15,13 @@ from .alpha_vantage import (
|
|||
get_insider_transactions as get_alpha_vantage_insider_transactions,
|
||||
get_news as get_alpha_vantage_news
|
||||
)
|
||||
from .alpha_vantage_news import get_bulk_news_alpha_vantage
|
||||
from .alpha_vantage_common import AlphaVantageRateLimitError
|
||||
|
||||
# Configuration and routing logic
|
||||
from .config import get_config
|
||||
|
||||
# Tools organized by category
|
||||
from tradingagents.agents.discovery import NewsArticle
|
||||
|
||||
TOOLS_CATEGORIES = {
|
||||
"core_stock_apis": {
|
||||
"description": "OHLCV stock price data",
|
||||
|
|
@ -50,6 +51,7 @@ TOOLS_CATEGORIES = {
|
|||
"get_global_news",
|
||||
"get_insider_sentiment",
|
||||
"get_insider_transactions",
|
||||
"get_bulk_news",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
@ -61,21 +63,17 @@ VENDOR_LIST = [
|
|||
"google"
|
||||
]
|
||||
|
||||
# Mapping of methods to their vendor-specific implementations
|
||||
VENDOR_METHODS = {
|
||||
# core_stock_apis
|
||||
"get_stock_data": {
|
||||
"alpha_vantage": get_alpha_vantage_stock,
|
||||
"yfinance": get_YFin_data_online,
|
||||
"local": get_YFin_data,
|
||||
},
|
||||
# technical_indicators
|
||||
"get_indicators": {
|
||||
"alpha_vantage": get_alpha_vantage_indicator,
|
||||
"yfinance": get_stock_stats_indicators_window,
|
||||
"local": get_stock_stats_indicators_window
|
||||
},
|
||||
# fundamental_data
|
||||
"get_fundamentals": {
|
||||
"alpha_vantage": get_alpha_vantage_fundamentals,
|
||||
"openai": get_fundamentals_openai,
|
||||
|
|
@ -95,7 +93,6 @@ VENDOR_METHODS = {
|
|||
"yfinance": get_yfinance_income_statement,
|
||||
"local": get_simfin_income_statements,
|
||||
},
|
||||
# news_data
|
||||
"get_news": {
|
||||
"alpha_vantage": get_alpha_vantage_news,
|
||||
"openai": get_stock_news_openai,
|
||||
|
|
@ -114,56 +111,159 @@ VENDOR_METHODS = {
|
|||
"yfinance": get_yfinance_insider_transactions,
|
||||
"local": get_finnhub_company_insider_transactions,
|
||||
},
|
||||
"get_bulk_news": {
|
||||
"alpha_vantage": get_bulk_news_alpha_vantage,
|
||||
"openai": get_bulk_news_openai,
|
||||
"google": get_bulk_news_google,
|
||||
},
|
||||
}
|
||||
|
||||
CACHE_TTL_SECONDS = 300
|
||||
|
||||
_bulk_news_cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def parse_lookback_period(lookback: str) -> int:
|
||||
lookback = lookback.lower().strip()
|
||||
|
||||
if lookback == "1h":
|
||||
return 1
|
||||
elif lookback == "6h":
|
||||
return 6
|
||||
elif lookback == "24h":
|
||||
return 24
|
||||
elif lookback == "7d":
|
||||
return 168
|
||||
else:
|
||||
raise ValueError(f"Invalid lookback period: {lookback}. Valid values: 1h, 6h, 24h, 7d")
|
||||
|
||||
|
||||
def _get_cached_bulk_news(lookback_period: str) -> Optional[List[NewsArticle]]:
|
||||
cache_key = lookback_period
|
||||
if cache_key in _bulk_news_cache:
|
||||
cached = _bulk_news_cache[cache_key]
|
||||
cached_time = cached.get("timestamp")
|
||||
if cached_time and (datetime.now() - cached_time).total_seconds() < CACHE_TTL_SECONDS:
|
||||
return cached.get("articles")
|
||||
return None
|
||||
|
||||
|
||||
def _set_cached_bulk_news(lookback_period: str, articles: List[NewsArticle]) -> None:
|
||||
cache_key = lookback_period
|
||||
_bulk_news_cache[cache_key] = {
|
||||
"timestamp": datetime.now(),
|
||||
"articles": articles,
|
||||
}
|
||||
|
||||
|
||||
def _convert_to_news_articles(raw_articles: List[Dict[str, Any]]) -> List[NewsArticle]:
|
||||
articles = []
|
||||
for item in raw_articles:
|
||||
try:
|
||||
published_at_str = item.get("published_at", "")
|
||||
if isinstance(published_at_str, str):
|
||||
try:
|
||||
published_at = datetime.fromisoformat(published_at_str.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
published_at = datetime.now()
|
||||
elif isinstance(published_at_str, datetime):
|
||||
published_at = published_at_str
|
||||
else:
|
||||
published_at = datetime.now()
|
||||
|
||||
article = NewsArticle(
|
||||
title=item.get("title", ""),
|
||||
source=item.get("source", ""),
|
||||
url=item.get("url", ""),
|
||||
published_at=published_at,
|
||||
content_snippet=item.get("content_snippet", ""),
|
||||
ticker_mentions=[],
|
||||
)
|
||||
articles.append(article)
|
||||
except Exception:
|
||||
continue
|
||||
return articles
|
||||
|
||||
|
||||
def _fetch_bulk_news_from_vendor(lookback_period: str) -> List[Dict[str, Any]]:
|
||||
lookback_hours = parse_lookback_period(lookback_period)
|
||||
|
||||
vendor_order = ["alpha_vantage", "openai", "google"]
|
||||
|
||||
for vendor in vendor_order:
|
||||
if vendor not in VENDOR_METHODS["get_bulk_news"]:
|
||||
continue
|
||||
|
||||
vendor_func = VENDOR_METHODS["get_bulk_news"][vendor]
|
||||
|
||||
try:
|
||||
print(f"DEBUG: Attempting bulk news from vendor '{vendor}'...")
|
||||
result = vendor_func(lookback_hours)
|
||||
if result:
|
||||
print(f"SUCCESS: Got {len(result)} articles from vendor '{vendor}'")
|
||||
return result
|
||||
print(f"DEBUG: Vendor '{vendor}' returned empty results, trying next...")
|
||||
except AlphaVantageRateLimitError as e:
|
||||
print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"FAILED: Vendor '{vendor}' failed: {e}")
|
||||
continue
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def get_bulk_news(lookback_period: str = "24h") -> List[NewsArticle]:
|
||||
cached = _get_cached_bulk_news(lookback_period)
|
||||
if cached is not None:
|
||||
print(f"DEBUG: Returning cached bulk news for period '{lookback_period}'")
|
||||
return cached
|
||||
|
||||
raw_articles = _fetch_bulk_news_from_vendor(lookback_period)
|
||||
|
||||
articles = _convert_to_news_articles(raw_articles)
|
||||
|
||||
_set_cached_bulk_news(lookback_period, articles)
|
||||
|
||||
return articles
|
||||
|
||||
|
||||
def get_category_for_method(method: str) -> str:
|
||||
"""Get the category that contains the specified method."""
|
||||
for category, info in TOOLS_CATEGORIES.items():
|
||||
if method in info["tools"]:
|
||||
return category
|
||||
raise ValueError(f"Method '{method}' not found in any category")
|
||||
|
||||
def get_vendor(category: str, method: str = None) -> str:
|
||||
"""Get the configured vendor for a data category or specific tool method.
|
||||
Tool-level configuration takes precedence over category-level.
|
||||
"""
|
||||
config = get_config()
|
||||
|
||||
# Check tool-level configuration first (if method provided)
|
||||
if method:
|
||||
tool_vendors = config.get("tool_vendors", {})
|
||||
if method in tool_vendors:
|
||||
return tool_vendors[method]
|
||||
|
||||
# Fall back to category-level configuration
|
||||
return config.get("data_vendors", {}).get(category, "default")
|
||||
|
||||
def route_to_vendor(method: str, *args, **kwargs):
|
||||
"""Route method calls to appropriate vendor implementation with fallback support."""
|
||||
category = get_category_for_method(method)
|
||||
vendor_config = get_vendor(category, method)
|
||||
|
||||
# Handle comma-separated vendors
|
||||
primary_vendors = [v.strip() for v in vendor_config.split(',')]
|
||||
|
||||
if method not in VENDOR_METHODS:
|
||||
raise ValueError(f"Method '{method}' not supported")
|
||||
|
||||
# Get all available vendors for this method for fallback
|
||||
all_available_vendors = list(VENDOR_METHODS[method].keys())
|
||||
|
||||
# Create fallback vendor list: primary vendors first, then remaining vendors as fallbacks
|
||||
|
||||
fallback_vendors = primary_vendors.copy()
|
||||
for vendor in all_available_vendors:
|
||||
if vendor not in fallback_vendors:
|
||||
fallback_vendors.append(vendor)
|
||||
|
||||
# Debug: Print fallback ordering
|
||||
primary_str = " → ".join(primary_vendors)
|
||||
fallback_str = " → ".join(fallback_vendors)
|
||||
primary_str = " -> ".join(primary_vendors)
|
||||
fallback_str = " -> ".join(fallback_vendors)
|
||||
print(f"DEBUG: {method} - Primary: [{primary_str}] | Full fallback order: [{fallback_str}]")
|
||||
|
||||
# Track results and execution state
|
||||
results = []
|
||||
vendor_attempt_count = 0
|
||||
any_primary_vendor_attempted = False
|
||||
|
|
@ -179,22 +279,18 @@ def route_to_vendor(method: str, *args, **kwargs):
|
|||
is_primary_vendor = vendor in primary_vendors
|
||||
vendor_attempt_count += 1
|
||||
|
||||
# Track if we attempted any primary vendor
|
||||
if is_primary_vendor:
|
||||
any_primary_vendor_attempted = True
|
||||
|
||||
# Debug: Print current attempt
|
||||
vendor_type = "PRIMARY" if is_primary_vendor else "FALLBACK"
|
||||
print(f"DEBUG: Attempting {vendor_type} vendor '{vendor}' for {method} (attempt #{vendor_attempt_count})")
|
||||
|
||||
# Handle list of methods for a vendor
|
||||
if isinstance(vendor_impl, list):
|
||||
vendor_methods = [(impl, vendor) for impl in vendor_impl]
|
||||
print(f"DEBUG: Vendor '{vendor}' has multiple implementations: {len(vendor_methods)} functions")
|
||||
else:
|
||||
vendor_methods = [(vendor_impl, vendor)]
|
||||
|
||||
# Run methods for this vendor
|
||||
vendor_results = []
|
||||
for impl_func, vendor_name in vendor_methods:
|
||||
try:
|
||||
|
|
@ -202,43 +298,35 @@ def route_to_vendor(method: str, *args, **kwargs):
|
|||
result = impl_func(*args, **kwargs)
|
||||
vendor_results.append(result)
|
||||
print(f"SUCCESS: {impl_func.__name__} from vendor '{vendor_name}' completed successfully")
|
||||
|
||||
|
||||
except AlphaVantageRateLimitError as e:
|
||||
if vendor == "alpha_vantage":
|
||||
print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded, falling back to next available vendor")
|
||||
print(f"DEBUG: Rate limit details: {e}")
|
||||
# Continue to next vendor for fallback
|
||||
continue
|
||||
except Exception as e:
|
||||
# Log error but continue with other implementations
|
||||
print(f"FAILED: {impl_func.__name__} from vendor '{vendor_name}' failed: {e}")
|
||||
continue
|
||||
|
||||
# Add this vendor's results
|
||||
if vendor_results:
|
||||
results.extend(vendor_results)
|
||||
successful_vendor = vendor
|
||||
result_summary = f"Got {len(vendor_results)} result(s)"
|
||||
print(f"SUCCESS: Vendor '{vendor}' succeeded - {result_summary}")
|
||||
|
||||
# Stopping logic: Stop after first successful vendor for single-vendor configs
|
||||
# Multiple vendor configs (comma-separated) may want to collect from multiple sources
|
||||
|
||||
if len(primary_vendors) == 1:
|
||||
print(f"DEBUG: Stopping after successful vendor '{vendor}' (single-vendor config)")
|
||||
break
|
||||
else:
|
||||
print(f"FAILED: Vendor '{vendor}' produced no results")
|
||||
|
||||
# Final result summary
|
||||
if not results:
|
||||
print(f"FAILURE: All {vendor_attempt_count} vendor attempts failed for method '{method}'")
|
||||
raise RuntimeError(f"All vendor implementations failed for method '{method}'")
|
||||
else:
|
||||
print(f"FINAL: Method '{method}' completed with {len(results)} result(s) from {vendor_attempt_count} vendor attempt(s)")
|
||||
|
||||
# Return single result if only one, otherwise concatenate as string
|
||||
if len(results) == 1:
|
||||
return results[0]
|
||||
else:
|
||||
# Convert all results to strings and concatenate
|
||||
return '\n'.join(str(result) for result in results)
|
||||
return '\n'.join(str(result) for result in results)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,7 @@
|
|||
import json
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from openai import OpenAI
|
||||
from .config import get_config
|
||||
|
||||
|
|
@ -104,4 +108,91 @@ def get_fundamentals_openai(ticker, curr_date):
|
|||
store=True,
|
||||
)
|
||||
|
||||
return response.output[1].content[0].text
|
||||
return response.output[1].content[0].text
|
||||
|
||||
|
||||
def get_bulk_news_openai(lookback_hours: int) -> List[Dict[str, Any]]:
|
||||
config = get_config()
|
||||
client = OpenAI(base_url=config["backend_url"])
|
||||
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(hours=lookback_hours)
|
||||
|
||||
start_str = start_date.strftime("%Y-%m-%d %H:%M")
|
||||
end_str = end_date.strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
prompt = f"""Search for recent stock market news, trading news, and earnings announcements from {start_str} to {end_str}.
|
||||
|
||||
Return the results as a JSON array with the following structure:
|
||||
[
|
||||
{{
|
||||
"title": "Article title",
|
||||
"source": "Source name",
|
||||
"url": "https://...",
|
||||
"published_at": "YYYY-MM-DDTHH:MM:SS",
|
||||
"content_snippet": "Brief summary of the article..."
|
||||
}}
|
||||
]
|
||||
|
||||
Focus on:
|
||||
- Stock market movements and trends
|
||||
- Company earnings reports
|
||||
- Mergers and acquisitions
|
||||
- Significant trading activity
|
||||
- Economic news affecting markets
|
||||
|
||||
Return ONLY the JSON array, no additional text."""
|
||||
|
||||
response = client.responses.create(
|
||||
model=config["quick_think_llm"],
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": prompt,
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
text={"format": {"type": "text"}},
|
||||
reasoning={},
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"user_location": {"type": "approximate"},
|
||||
"search_context_size": "medium",
|
||||
}
|
||||
],
|
||||
temperature=0.5,
|
||||
max_output_tokens=8192,
|
||||
top_p=1,
|
||||
store=True,
|
||||
)
|
||||
|
||||
try:
|
||||
response_text = response.output[1].content[0].text
|
||||
|
||||
json_match = re.search(r'\[[\s\S]*\]', response_text)
|
||||
if json_match:
|
||||
articles = json.loads(json_match.group())
|
||||
else:
|
||||
articles = json.loads(response_text)
|
||||
|
||||
result = []
|
||||
for item in articles:
|
||||
if isinstance(item, dict):
|
||||
article = {
|
||||
"title": item.get("title", ""),
|
||||
"source": item.get("source", "Web Search"),
|
||||
"url": item.get("url", ""),
|
||||
"published_at": item.get("published_at", datetime.now().isoformat()),
|
||||
"content_snippet": item.get("content_snippet", "")[:500],
|
||||
}
|
||||
result.append(article)
|
||||
|
||||
return result
|
||||
|
||||
except (json.JSONDecodeError, IndexError, AttributeError):
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
from .stock_resolver import (
|
||||
resolve_ticker,
|
||||
validate_tradeable,
|
||||
validate_us_ticker,
|
||||
COMPANY_TO_TICKER,
|
||||
)
|
||||
from .sector_classifier import (
|
||||
classify_sector,
|
||||
TICKER_TO_SECTOR,
|
||||
VALID_SECTORS,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"resolve_ticker",
|
||||
"validate_tradeable",
|
||||
"validate_us_ticker",
|
||||
"COMPANY_TO_TICKER",
|
||||
"classify_sector",
|
||||
"TICKER_TO_SECTOR",
|
||||
"VALID_SECTORS",
|
||||
]
|
||||
|
|
@ -0,0 +1,267 @@
|
|||
import logging
|
||||
from typing import Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VALID_SECTORS = {
|
||||
"technology",
|
||||
"healthcare",
|
||||
"finance",
|
||||
"energy",
|
||||
"consumer_goods",
|
||||
"industrials",
|
||||
"other",
|
||||
}
|
||||
|
||||
TICKER_TO_SECTOR: Dict[str, str] = {
|
||||
"AAPL": "technology",
|
||||
"MSFT": "technology",
|
||||
"GOOGL": "technology",
|
||||
"GOOG": "technology",
|
||||
"AMZN": "technology",
|
||||
"META": "technology",
|
||||
"NVDA": "technology",
|
||||
"TSLA": "technology",
|
||||
"AMD": "technology",
|
||||
"INTC": "technology",
|
||||
"QCOM": "technology",
|
||||
"AVGO": "technology",
|
||||
"TXN": "technology",
|
||||
"ADBE": "technology",
|
||||
"CRM": "technology",
|
||||
"CSCO": "technology",
|
||||
"NFLX": "technology",
|
||||
"ORCL": "technology",
|
||||
"IBM": "technology",
|
||||
"NOW": "technology",
|
||||
"INTU": "technology",
|
||||
"ADSK": "technology",
|
||||
"SNPS": "technology",
|
||||
"CDNS": "technology",
|
||||
"PLTR": "technology",
|
||||
"SNOW": "technology",
|
||||
"DDOG": "technology",
|
||||
"CRWD": "technology",
|
||||
"OKTA": "technology",
|
||||
"NET": "technology",
|
||||
"MDB": "technology",
|
||||
"TWLO": "technology",
|
||||
"WDAY": "technology",
|
||||
"SPLK": "technology",
|
||||
"VMW": "technology",
|
||||
"HPQ": "technology",
|
||||
"DELL": "technology",
|
||||
"FTNT": "technology",
|
||||
"PANW": "technology",
|
||||
"ZS": "technology",
|
||||
"S": "technology",
|
||||
"VEEV": "technology",
|
||||
"ZM": "technology",
|
||||
"DOCU": "technology",
|
||||
"ASAN": "technology",
|
||||
"MNDY": "technology",
|
||||
"TEAM": "technology",
|
||||
"ANSS": "technology",
|
||||
"ROP": "technology",
|
||||
"JPM": "finance",
|
||||
"BAC": "finance",
|
||||
"WFC": "finance",
|
||||
"GS": "finance",
|
||||
"MS": "finance",
|
||||
"C": "finance",
|
||||
"BLK": "finance",
|
||||
"SCHW": "finance",
|
||||
"AXP": "finance",
|
||||
"V": "finance",
|
||||
"MA": "finance",
|
||||
"PYPL": "finance",
|
||||
"SQ": "finance",
|
||||
"COIN": "finance",
|
||||
"HOOD": "finance",
|
||||
"SOFI": "finance",
|
||||
"AFRM": "finance",
|
||||
"MQ": "finance",
|
||||
"BRK-B": "finance",
|
||||
"BRK-A": "finance",
|
||||
"JNJ": "healthcare",
|
||||
"UNH": "healthcare",
|
||||
"PFE": "healthcare",
|
||||
"ABBV": "healthcare",
|
||||
"MRK": "healthcare",
|
||||
"LLY": "healthcare",
|
||||
"MRNA": "healthcare",
|
||||
"BNTX": "healthcare",
|
||||
"CVS": "healthcare",
|
||||
"WBA": "healthcare",
|
||||
"MCK": "healthcare",
|
||||
"CAH": "healthcare",
|
||||
"HUM": "healthcare",
|
||||
"CI": "healthcare",
|
||||
"ELV": "healthcare",
|
||||
"XOM": "energy",
|
||||
"CVX": "energy",
|
||||
"COP": "energy",
|
||||
"SLB": "energy",
|
||||
"HAL": "energy",
|
||||
"BKR": "energy",
|
||||
"MPC": "energy",
|
||||
"VLO": "energy",
|
||||
"PSX": "energy",
|
||||
"OXY": "energy",
|
||||
"PXD": "energy",
|
||||
"DVN": "energy",
|
||||
"CEG": "energy",
|
||||
"NEE": "energy",
|
||||
"DUK": "energy",
|
||||
"SO": "energy",
|
||||
"D": "energy",
|
||||
"SRE": "energy",
|
||||
"WMT": "consumer_goods",
|
||||
"COST": "consumer_goods",
|
||||
"TGT": "consumer_goods",
|
||||
"HD": "consumer_goods",
|
||||
"LOW": "consumer_goods",
|
||||
"PG": "consumer_goods",
|
||||
"KO": "consumer_goods",
|
||||
"PEP": "consumer_goods",
|
||||
"NKE": "consumer_goods",
|
||||
"SBUX": "consumer_goods",
|
||||
"MCD": "consumer_goods",
|
||||
"CMG": "consumer_goods",
|
||||
"YUM": "consumer_goods",
|
||||
"DPZ": "consumer_goods",
|
||||
"DIS": "consumer_goods",
|
||||
"CMCSA": "consumer_goods",
|
||||
"VZ": "consumer_goods",
|
||||
"T": "consumer_goods",
|
||||
"TMUS": "consumer_goods",
|
||||
"EL": "consumer_goods",
|
||||
"CL": "consumer_goods",
|
||||
"KMB": "consumer_goods",
|
||||
"CLX": "consumer_goods",
|
||||
"KHC": "consumer_goods",
|
||||
"GIS": "consumer_goods",
|
||||
"K": "consumer_goods",
|
||||
"MDLZ": "consumer_goods",
|
||||
"HSY": "consumer_goods",
|
||||
"TSN": "consumer_goods",
|
||||
"BYND": "consumer_goods",
|
||||
"CAG": "consumer_goods",
|
||||
"STZ": "consumer_goods",
|
||||
"BUD": "consumer_goods",
|
||||
"DEO": "consumer_goods",
|
||||
"PM": "consumer_goods",
|
||||
"MO": "consumer_goods",
|
||||
"LULU": "consumer_goods",
|
||||
"DG": "consumer_goods",
|
||||
"DLTR": "consumer_goods",
|
||||
"ROST": "consumer_goods",
|
||||
"TJX": "consumer_goods",
|
||||
"AZO": "consumer_goods",
|
||||
"ORLY": "consumer_goods",
|
||||
"KMX": "consumer_goods",
|
||||
"ADDYY": "consumer_goods",
|
||||
"UBER": "consumer_goods",
|
||||
"LYFT": "consumer_goods",
|
||||
"ABNB": "consumer_goods",
|
||||
"DASH": "consumer_goods",
|
||||
"SNAP": "consumer_goods",
|
||||
"PINS": "consumer_goods",
|
||||
"TWTR": "consumer_goods",
|
||||
"SHOP": "consumer_goods",
|
||||
"TOST": "consumer_goods",
|
||||
"BA": "industrials",
|
||||
"LMT": "industrials",
|
||||
"RTX": "industrials",
|
||||
"GD": "industrials",
|
||||
"NOC": "industrials",
|
||||
"GE": "industrials",
|
||||
"HON": "industrials",
|
||||
"MMM": "industrials",
|
||||
"CAT": "industrials",
|
||||
"DE": "industrials",
|
||||
"UNP": "industrials",
|
||||
"UPS": "industrials",
|
||||
"FDX": "industrials",
|
||||
"DAL": "industrials",
|
||||
"UAL": "industrials",
|
||||
"AAL": "industrials",
|
||||
"LUV": "industrials",
|
||||
"F": "industrials",
|
||||
"GM": "industrials",
|
||||
"TM": "industrials",
|
||||
"HMC": "industrials",
|
||||
"VWAGY": "industrials",
|
||||
"RACE": "industrials",
|
||||
"RIVN": "industrials",
|
||||
"LCID": "industrials",
|
||||
"NIO": "industrials",
|
||||
"LNVGY": "industrials",
|
||||
}
|
||||
|
||||
_sector_cache: Dict[str, str] = {}
|
||||
|
||||
|
||||
def _llm_classify_sector(ticker: str) -> str:
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
llm_name = DEFAULT_CONFIG.get("quick_think_llm", "gpt-4o-mini")
|
||||
llm_provider = DEFAULT_CONFIG.get("llm_provider", "openai")
|
||||
backend_url = DEFAULT_CONFIG.get("backend_url", "https://api.openai.com/v1")
|
||||
|
||||
llm = ChatOpenAI(
|
||||
model=llm_name,
|
||||
base_url=backend_url,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
system_prompt = (
|
||||
"You are a financial sector classifier. Given a stock ticker symbol, "
|
||||
"classify it into exactly one of the following sectors: "
|
||||
"technology, healthcare, finance, energy, consumer_goods, industrials, other. "
|
||||
"Respond with only the sector name in lowercase, nothing else."
|
||||
)
|
||||
|
||||
user_prompt = f"Classify the stock ticker: {ticker}"
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content=user_prompt),
|
||||
]
|
||||
|
||||
response = llm.invoke(messages)
|
||||
sector = response.content.strip().lower()
|
||||
|
||||
if sector not in VALID_SECTORS:
|
||||
logger.warning(
|
||||
"LLM returned invalid sector '%s' for ticker %s, defaulting to 'other'",
|
||||
sector,
|
||||
ticker,
|
||||
)
|
||||
return "other"
|
||||
|
||||
return sector
|
||||
|
||||
|
||||
def classify_sector(ticker: str) -> str:
|
||||
ticker_upper = ticker.upper()
|
||||
|
||||
if ticker_upper in TICKER_TO_SECTOR:
|
||||
return TICKER_TO_SECTOR[ticker_upper]
|
||||
|
||||
if ticker_upper in _sector_cache:
|
||||
return _sector_cache[ticker_upper]
|
||||
|
||||
logger.info("Using LLM fallback for sector classification of ticker: %s", ticker)
|
||||
|
||||
try:
|
||||
sector = _llm_classify_sector(ticker_upper)
|
||||
_sector_cache[ticker_upper] = sector
|
||||
logger.info("Classified %s as %s via LLM", ticker, sector)
|
||||
return sector
|
||||
except Exception as e:
|
||||
logger.error("LLM sector classification failed for %s: %s", ticker, str(e))
|
||||
_sector_cache[ticker_upper] = "other"
|
||||
return "other"
|
||||
|
|
@ -0,0 +1,538 @@
|
|||
import logging
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
import yfinance as yf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMPANY_TO_TICKER = {
|
||||
"apple": "AAPL",
|
||||
"apple inc": "AAPL",
|
||||
"apple inc.": "AAPL",
|
||||
"apple corporation": "AAPL",
|
||||
"the iphone maker": "AAPL",
|
||||
"iphone maker": "AAPL",
|
||||
"microsoft": "MSFT",
|
||||
"microsoft inc": "MSFT",
|
||||
"microsoft inc.": "MSFT",
|
||||
"microsoft corp": "MSFT",
|
||||
"microsoft corp.": "MSFT",
|
||||
"microsoft corporation": "MSFT",
|
||||
"google": "GOOGL",
|
||||
"alphabet": "GOOGL",
|
||||
"alphabet inc": "GOOGL",
|
||||
"alphabet inc.": "GOOGL",
|
||||
"the search giant": "GOOGL",
|
||||
"amazon": "AMZN",
|
||||
"amazon inc": "AMZN",
|
||||
"amazon inc.": "AMZN",
|
||||
"amazon.com": "AMZN",
|
||||
"amazon.com inc": "AMZN",
|
||||
"the e-commerce giant": "AMZN",
|
||||
"e-commerce giant": "AMZN",
|
||||
"meta": "META",
|
||||
"meta platforms": "META",
|
||||
"meta platforms inc": "META",
|
||||
"meta platforms inc.": "META",
|
||||
"facebook": "META",
|
||||
"facebook inc": "META",
|
||||
"facebook inc.": "META",
|
||||
"tesla": "TSLA",
|
||||
"tesla inc": "TSLA",
|
||||
"tesla inc.": "TSLA",
|
||||
"tesla motors": "TSLA",
|
||||
"ev maker tesla": "TSLA",
|
||||
"nvidia": "NVDA",
|
||||
"nvidia corp": "NVDA",
|
||||
"nvidia corp.": "NVDA",
|
||||
"nvidia corporation": "NVDA",
|
||||
"berkshire hathaway": "BRK-B",
|
||||
"berkshire": "BRK-B",
|
||||
"jpmorgan": "JPM",
|
||||
"jpmorgan chase": "JPM",
|
||||
"jp morgan": "JPM",
|
||||
"jp morgan chase": "JPM",
|
||||
"johnson & johnson": "JNJ",
|
||||
"johnson and johnson": "JNJ",
|
||||
"j&j": "JNJ",
|
||||
"unitedhealth": "UNH",
|
||||
"unitedhealth group": "UNH",
|
||||
"visa": "V",
|
||||
"visa inc": "V",
|
||||
"visa inc.": "V",
|
||||
"procter & gamble": "PG",
|
||||
"procter and gamble": "PG",
|
||||
"p&g": "PG",
|
||||
"mastercard": "MA",
|
||||
"mastercard inc": "MA",
|
||||
"mastercard inc.": "MA",
|
||||
"home depot": "HD",
|
||||
"the home depot": "HD",
|
||||
"chevron": "CVX",
|
||||
"chevron corp": "CVX",
|
||||
"chevron corporation": "CVX",
|
||||
"exxon": "XOM",
|
||||
"exxon mobil": "XOM",
|
||||
"exxonmobil": "XOM",
|
||||
"pfizer": "PFE",
|
||||
"pfizer inc": "PFE",
|
||||
"pfizer inc.": "PFE",
|
||||
"abbvie": "ABBV",
|
||||
"abbvie inc": "ABBV",
|
||||
"abbvie inc.": "ABBV",
|
||||
"coca-cola": "KO",
|
||||
"coca cola": "KO",
|
||||
"coke": "KO",
|
||||
"the coca-cola company": "KO",
|
||||
"pepsico": "PEP",
|
||||
"pepsi": "PEP",
|
||||
"pepsi co": "PEP",
|
||||
"costco": "COST",
|
||||
"costco wholesale": "COST",
|
||||
"walmart": "WMT",
|
||||
"wal-mart": "WMT",
|
||||
"walmart inc": "WMT",
|
||||
"bank of america": "BAC",
|
||||
"bofa": "BAC",
|
||||
"merck": "MRK",
|
||||
"merck & co": "MRK",
|
||||
"merck and co": "MRK",
|
||||
"eli lilly": "LLY",
|
||||
"lilly": "LLY",
|
||||
"eli lilly and company": "LLY",
|
||||
"adobe": "ADBE",
|
||||
"adobe inc": "ADBE",
|
||||
"adobe inc.": "ADBE",
|
||||
"adobe systems": "ADBE",
|
||||
"salesforce": "CRM",
|
||||
"salesforce inc": "CRM",
|
||||
"salesforce.com": "CRM",
|
||||
"cisco": "CSCO",
|
||||
"cisco systems": "CSCO",
|
||||
"cisco systems inc": "CSCO",
|
||||
"netflix": "NFLX",
|
||||
"netflix inc": "NFLX",
|
||||
"netflix inc.": "NFLX",
|
||||
"oracle": "ORCL",
|
||||
"oracle corp": "ORCL",
|
||||
"oracle corporation": "ORCL",
|
||||
"intel": "INTC",
|
||||
"intel corp": "INTC",
|
||||
"intel corporation": "INTC",
|
||||
"amd": "AMD",
|
||||
"advanced micro devices": "AMD",
|
||||
"qualcomm": "QCOM",
|
||||
"qualcomm inc": "QCOM",
|
||||
"qualcomm inc.": "QCOM",
|
||||
"broadcom": "AVGO",
|
||||
"broadcom inc": "AVGO",
|
||||
"broadcom inc.": "AVGO",
|
||||
"texas instruments": "TXN",
|
||||
"ti": "TXN",
|
||||
"disney": "DIS",
|
||||
"walt disney": "DIS",
|
||||
"the walt disney company": "DIS",
|
||||
"walt disney company": "DIS",
|
||||
"comcast": "CMCSA",
|
||||
"comcast corp": "CMCSA",
|
||||
"comcast corporation": "CMCSA",
|
||||
"verizon": "VZ",
|
||||
"verizon communications": "VZ",
|
||||
"at&t": "T",
|
||||
"att": "T",
|
||||
"t-mobile": "TMUS",
|
||||
"tmobile": "TMUS",
|
||||
"t-mobile us": "TMUS",
|
||||
"american express": "AXP",
|
||||
"amex": "AXP",
|
||||
"goldman sachs": "GS",
|
||||
"goldman": "GS",
|
||||
"morgan stanley": "MS",
|
||||
"wells fargo": "WFC",
|
||||
"wells": "WFC",
|
||||
"citigroup": "C",
|
||||
"citi": "C",
|
||||
"citibank": "C",
|
||||
"charles schwab": "SCHW",
|
||||
"schwab": "SCHW",
|
||||
"blackrock": "BLK",
|
||||
"blackrock inc": "BLK",
|
||||
"paypal": "PYPL",
|
||||
"paypal holdings": "PYPL",
|
||||
"paypal inc": "PYPL",
|
||||
"square": "SQ",
|
||||
"block": "SQ",
|
||||
"block inc": "SQ",
|
||||
"shopify": "SHOP",
|
||||
"shopify inc": "SHOP",
|
||||
"uber": "UBER",
|
||||
"uber technologies": "UBER",
|
||||
"lyft": "LYFT",
|
||||
"lyft inc": "LYFT",
|
||||
"airbnb": "ABNB",
|
||||
"airbnb inc": "ABNB",
|
||||
"doordash": "DASH",
|
||||
"doordash inc": "DASH",
|
||||
"snap": "SNAP",
|
||||
"snap inc": "SNAP",
|
||||
"snapchat": "SNAP",
|
||||
"pinterest": "PINS",
|
||||
"pinterest inc": "PINS",
|
||||
"twitter": "TWTR",
|
||||
"twitter inc": "TWTR",
|
||||
"linkedin": "MSFT",
|
||||
"zoom": "ZM",
|
||||
"zoom video": "ZM",
|
||||
"zoom video communications": "ZM",
|
||||
"slack": "CRM",
|
||||
"slack technologies": "CRM",
|
||||
"palantir": "PLTR",
|
||||
"palantir technologies": "PLTR",
|
||||
"snowflake": "SNOW",
|
||||
"snowflake inc": "SNOW",
|
||||
"datadog": "DDOG",
|
||||
"datadog inc": "DDOG",
|
||||
"crowdstrike": "CRWD",
|
||||
"crowdstrike holdings": "CRWD",
|
||||
"okta": "OKTA",
|
||||
"okta inc": "OKTA",
|
||||
"cloudflare": "NET",
|
||||
"cloudflare inc": "NET",
|
||||
"mongodb": "MDB",
|
||||
"mongodb inc": "MDB",
|
||||
"twilio": "TWLO",
|
||||
"twilio inc": "TWLO",
|
||||
"servicenow": "NOW",
|
||||
"servicenow inc": "NOW",
|
||||
"workday": "WDAY",
|
||||
"workday inc": "WDAY",
|
||||
"splunk": "SPLK",
|
||||
"splunk inc": "SPLK",
|
||||
"vmware": "VMW",
|
||||
"vmware inc": "VMW",
|
||||
"ibm": "IBM",
|
||||
"international business machines": "IBM",
|
||||
"hp": "HPQ",
|
||||
"hewlett-packard": "HPQ",
|
||||
"hewlett packard": "HPQ",
|
||||
"dell": "DELL",
|
||||
"dell technologies": "DELL",
|
||||
"lenovo": "LNVGY",
|
||||
"boeing": "BA",
|
||||
"boeing company": "BA",
|
||||
"the boeing company": "BA",
|
||||
"lockheed martin": "LMT",
|
||||
"lockheed": "LMT",
|
||||
"raytheon": "RTX",
|
||||
"rtx": "RTX",
|
||||
"general dynamics": "GD",
|
||||
"northrop grumman": "NOC",
|
||||
"northrop": "NOC",
|
||||
"general electric": "GE",
|
||||
"ge": "GE",
|
||||
"honeywell": "HON",
|
||||
"honeywell international": "HON",
|
||||
"3m": "MMM",
|
||||
"3m company": "MMM",
|
||||
"caterpillar": "CAT",
|
||||
"caterpillar inc": "CAT",
|
||||
"deere": "DE",
|
||||
"john deere": "DE",
|
||||
"deere & company": "DE",
|
||||
"union pacific": "UNP",
|
||||
"ups": "UPS",
|
||||
"united parcel service": "UPS",
|
||||
"fedex": "FDX",
|
||||
"federal express": "FDX",
|
||||
"delta": "DAL",
|
||||
"delta air lines": "DAL",
|
||||
"delta airlines": "DAL",
|
||||
"united airlines": "UAL",
|
||||
"united": "UAL",
|
||||
"american airlines": "AAL",
|
||||
"southwest": "LUV",
|
||||
"southwest airlines": "LUV",
|
||||
"ford": "F",
|
||||
"ford motor": "F",
|
||||
"ford motor company": "F",
|
||||
"general motors": "GM",
|
||||
"gm": "GM",
|
||||
"toyota": "TM",
|
||||
"toyota motor": "TM",
|
||||
"honda": "HMC",
|
||||
"honda motor": "HMC",
|
||||
"volkswagen": "VWAGY",
|
||||
"vw": "VWAGY",
|
||||
"ferrari": "RACE",
|
||||
"rivian": "RIVN",
|
||||
"rivian automotive": "RIVN",
|
||||
"lucid": "LCID",
|
||||
"lucid motors": "LCID",
|
||||
"lucid group": "LCID",
|
||||
"nio": "NIO",
|
||||
"nio inc": "NIO",
|
||||
"moderna": "MRNA",
|
||||
"moderna inc": "MRNA",
|
||||
"biontech": "BNTX",
|
||||
"cvs": "CVS",
|
||||
"cvs health": "CVS",
|
||||
"walgreens": "WBA",
|
||||
"walgreens boots alliance": "WBA",
|
||||
"mckesson": "MCK",
|
||||
"mckesson corp": "MCK",
|
||||
"cardinal health": "CAH",
|
||||
"humana": "HUM",
|
||||
"humana inc": "HUM",
|
||||
"cigna": "CI",
|
||||
"cigna group": "CI",
|
||||
"anthem": "ELV",
|
||||
"elevance health": "ELV",
|
||||
"starbucks": "SBUX",
|
||||
"starbucks corp": "SBUX",
|
||||
"starbucks corporation": "SBUX",
|
||||
"mcdonalds": "MCD",
|
||||
"mcdonald's": "MCD",
|
||||
"chipotle": "CMG",
|
||||
"chipotle mexican grill": "CMG",
|
||||
"yum brands": "YUM",
|
||||
"yum": "YUM",
|
||||
"dominos": "DPZ",
|
||||
"domino's": "DPZ",
|
||||
"domino's pizza": "DPZ",
|
||||
"nike": "NKE",
|
||||
"nike inc": "NKE",
|
||||
"adidas": "ADDYY",
|
||||
"lululemon": "LULU",
|
||||
"lululemon athletica": "LULU",
|
||||
"target": "TGT",
|
||||
"target corp": "TGT",
|
||||
"target corporation": "TGT",
|
||||
"dollar general": "DG",
|
||||
"dollar tree": "DLTR",
|
||||
"ross stores": "ROST",
|
||||
"ross": "ROST",
|
||||
"tjx": "TJX",
|
||||
"tjx companies": "TJX",
|
||||
"tj maxx": "TJX",
|
||||
"lowes": "LOW",
|
||||
"lowe's": "LOW",
|
||||
"lowe's companies": "LOW",
|
||||
"autozone": "AZO",
|
||||
"o'reilly": "ORLY",
|
||||
"o'reilly automotive": "ORLY",
|
||||
"carmax": "KMX",
|
||||
"estee lauder": "EL",
|
||||
"colgate": "CL",
|
||||
"colgate-palmolive": "CL",
|
||||
"colgate palmolive": "CL",
|
||||
"kimberly-clark": "KMB",
|
||||
"kimberly clark": "KMB",
|
||||
"clorox": "CLX",
|
||||
"clorox company": "CLX",
|
||||
"kraft heinz": "KHC",
|
||||
"kraft": "KHC",
|
||||
"heinz": "KHC",
|
||||
"general mills": "GIS",
|
||||
"kellogg": "K",
|
||||
"kellogg's": "K",
|
||||
"mondelez": "MDLZ",
|
||||
"mondelez international": "MDLZ",
|
||||
"hershey": "HSY",
|
||||
"the hershey company": "HSY",
|
||||
"tyson": "TSN",
|
||||
"tyson foods": "TSN",
|
||||
"beyond meat": "BYND",
|
||||
"conagra": "CAG",
|
||||
"conagra brands": "CAG",
|
||||
"constellation brands": "STZ",
|
||||
"anheuser-busch": "BUD",
|
||||
"anheuser busch": "BUD",
|
||||
"ab inbev": "BUD",
|
||||
"diageo": "DEO",
|
||||
"philip morris": "PM",
|
||||
"philip morris international": "PM",
|
||||
"altria": "MO",
|
||||
"altria group": "MO",
|
||||
"constellation energy": "CEG",
|
||||
"nextera": "NEE",
|
||||
"nextera energy": "NEE",
|
||||
"duke energy": "DUK",
|
||||
"southern company": "SO",
|
||||
"dominion": "D",
|
||||
"dominion energy": "D",
|
||||
"sempra": "SRE",
|
||||
"sempra energy": "SRE",
|
||||
"conocophillips": "COP",
|
||||
"conoco": "COP",
|
||||
"schlumberger": "SLB",
|
||||
"halliburton": "HAL",
|
||||
"baker hughes": "BKR",
|
||||
"marathon": "MPC",
|
||||
"marathon petroleum": "MPC",
|
||||
"valero": "VLO",
|
||||
"valero energy": "VLO",
|
||||
"phillips 66": "PSX",
|
||||
"occidental": "OXY",
|
||||
"occidental petroleum": "OXY",
|
||||
"pioneer": "PXD",
|
||||
"pioneer natural resources": "PXD",
|
||||
"devon energy": "DVN",
|
||||
"devon": "DVN",
|
||||
"coinbase": "COIN",
|
||||
"coinbase global": "COIN",
|
||||
"robinhood": "HOOD",
|
||||
"robinhood markets": "HOOD",
|
||||
"sofi": "SOFI",
|
||||
"sofi technologies": "SOFI",
|
||||
"affirm": "AFRM",
|
||||
"affirm holdings": "AFRM",
|
||||
"marqeta": "MQ",
|
||||
"toast": "TOST",
|
||||
"toast inc": "TOST",
|
||||
"docusign": "DOCU",
|
||||
"docusign inc": "DOCU",
|
||||
"asana": "ASAN",
|
||||
"monday.com": "MNDY",
|
||||
"monday": "MNDY",
|
||||
"atlassian": "TEAM",
|
||||
"atlassian corp": "TEAM",
|
||||
"intuit": "INTU",
|
||||
"intuit inc": "INTU",
|
||||
"autodesk": "ADSK",
|
||||
"autodesk inc": "ADSK",
|
||||
"synopsys": "SNPS",
|
||||
"cadence": "CDNS",
|
||||
"cadence design": "CDNS",
|
||||
"ansys": "ANSS",
|
||||
"roper": "ROP",
|
||||
"roper technologies": "ROP",
|
||||
"fortinet": "FTNT",
|
||||
"palo alto": "PANW",
|
||||
"palo alto networks": "PANW",
|
||||
"zscaler": "ZS",
|
||||
"sentinelone": "S",
|
||||
"veeva": "VEEV",
|
||||
"veeva systems": "VEEV",
|
||||
}
|
||||
|
||||
US_EXCHANGE_CODES = {
|
||||
"NYQ",
|
||||
"NMS",
|
||||
"NGM",
|
||||
"NCM",
|
||||
"ASE",
|
||||
"PCX",
|
||||
"BTS",
|
||||
"NYSE",
|
||||
"NASDAQ",
|
||||
"AMEX",
|
||||
"NYS",
|
||||
"NAS",
|
||||
"NIM",
|
||||
"NAQ",
|
||||
}
|
||||
|
||||
SUFFIX_PATTERNS = [
|
||||
r"\s+inc\.?$",
|
||||
r"\s+corp\.?$",
|
||||
r"\s+corporation$",
|
||||
r"\s+co\.?$",
|
||||
r"\s+company$",
|
||||
r"\s+llc$",
|
||||
r"\s+ltd\.?$",
|
||||
r"\s+limited$",
|
||||
r"\s+plc$",
|
||||
r"\s+holdings?$",
|
||||
r"\s+group$",
|
||||
r"\s+technologies$",
|
||||
r"\s+enterprises?$",
|
||||
]
|
||||
|
||||
|
||||
def _normalize_company_name(name: str) -> str:
|
||||
normalized = name.lower().strip()
|
||||
for pattern in SUFFIX_PATTERNS:
|
||||
normalized = re.sub(pattern, "", normalized, flags=re.IGNORECASE)
|
||||
normalized = normalized.strip()
|
||||
return normalized
|
||||
|
||||
|
||||
def _search_yfinance_ticker(company_name: str) -> Optional[str]:
|
||||
try:
|
||||
search_result = yf.Ticker(company_name)
|
||||
info = search_result.info
|
||||
if info and "symbol" in info:
|
||||
return info["symbol"]
|
||||
except Exception as e:
|
||||
logger.debug("yfinance search failed for %s: %s", company_name, str(e))
|
||||
|
||||
try:
|
||||
search = yf.Search(company_name, max_results=5)
|
||||
if hasattr(search, "quotes") and search.quotes:
|
||||
for quote in search.quotes:
|
||||
if "symbol" in quote:
|
||||
return quote["symbol"]
|
||||
except Exception as e:
|
||||
logger.debug("yfinance Search failed for %s: %s", company_name, str(e))
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_us_ticker(ticker: str) -> bool:
|
||||
try:
|
||||
ticker_obj = yf.Ticker(ticker.upper())
|
||||
info = ticker_obj.info
|
||||
if not info:
|
||||
logger.warning("Validation failed for %s: no info available", ticker)
|
||||
return False
|
||||
|
||||
exchange = info.get("exchange", "")
|
||||
if exchange in US_EXCHANGE_CODES:
|
||||
return True
|
||||
|
||||
exchange_lower = exchange.lower()
|
||||
if any(us_ex.lower() in exchange_lower for us_ex in ["nyse", "nasdaq", "amex", "nys", "nms", "ngm"]):
|
||||
return True
|
||||
|
||||
logger.warning("Validation failed for %s: exchange %s is not a US exchange", ticker, exchange)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning("Validation failed for %s: %s", ticker, str(e))
|
||||
return False
|
||||
|
||||
|
||||
def resolve_ticker(company_name: str) -> Optional[str]:
|
||||
if not company_name or not company_name.strip():
|
||||
return None
|
||||
|
||||
normalized = company_name.lower().strip()
|
||||
|
||||
if normalized in COMPANY_TO_TICKER:
|
||||
return COMPANY_TO_TICKER[normalized]
|
||||
|
||||
normalized_stripped = _normalize_company_name(company_name)
|
||||
if normalized_stripped in COMPANY_TO_TICKER:
|
||||
return COMPANY_TO_TICKER[normalized_stripped]
|
||||
|
||||
if company_name.upper() in [v for v in COMPANY_TO_TICKER.values()]:
|
||||
if validate_us_ticker(company_name.upper()):
|
||||
return company_name.upper()
|
||||
|
||||
logger.info("Using yfinance fallback for company: %s", company_name)
|
||||
yf_ticker = _search_yfinance_ticker(company_name)
|
||||
|
||||
if yf_ticker:
|
||||
if validate_us_ticker(yf_ticker):
|
||||
logger.info("Resolved %s to %s via yfinance", company_name, yf_ticker)
|
||||
return yf_ticker
|
||||
else:
|
||||
logger.warning("Ticker %s for %s failed US exchange validation", yf_ticker, company_name)
|
||||
return None
|
||||
|
||||
logger.warning("Could not resolve ticker for company: %s", company_name)
|
||||
return None
|
||||
|
||||
|
||||
def validate_tradeable(ticker: str) -> bool:
|
||||
return validate_us_ticker(ticker)
|
||||
|
|
@ -8,26 +8,24 @@ DEFAULT_CONFIG = {
|
|||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"dataflows/data_cache",
|
||||
),
|
||||
# LLM settings
|
||||
"llm_provider": "openai",
|
||||
"deep_think_llm": "o4-mini",
|
||||
"quick_think_llm": "gpt-4o-mini",
|
||||
"deep_think_llm": "gpt-5",
|
||||
"quick_think_llm": "gpt-5-mini",
|
||||
"backend_url": "https://api.openai.com/v1",
|
||||
# Debate and discussion settings
|
||||
"max_debate_rounds": 1,
|
||||
"max_risk_discuss_rounds": 1,
|
||||
"max_debate_rounds": 2,
|
||||
"max_risk_discuss_rounds": 2,
|
||||
"max_recur_limit": 100,
|
||||
# Data vendor configuration
|
||||
# Category-level configuration (default for all tools in category)
|
||||
"data_vendors": {
|
||||
"core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local
|
||||
"technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local
|
||||
"fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local
|
||||
"news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
|
||||
"core_stock_apis": "yfinance",
|
||||
"technical_indicators": "yfinance",
|
||||
"fundamental_data": "alpha_vantage",
|
||||
"news_data": "alpha_vantage",
|
||||
},
|
||||
# Tool-level configuration (takes precedence over category-level)
|
||||
"tool_vendors": {
|
||||
# Example: "get_stock_data": "alpha_vantage", # Override category default
|
||||
# Example: "get_news": "openai", # Override category default
|
||||
},
|
||||
"discovery_timeout": 60,
|
||||
"discovery_hard_timeout": 120,
|
||||
"discovery_cache_ttl": 300,
|
||||
"discovery_max_results": 20,
|
||||
"discovery_min_mentions": 2,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
# TradingAgents/graph/trading_graph.py
|
||||
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
from pathlib import Path
|
||||
import json
|
||||
from datetime import date
|
||||
from datetime import date, datetime
|
||||
from typing import Dict, Any, Tuple, List, Optional
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
|
@ -22,7 +22,6 @@ from tradingagents.agents.utils.agent_states import (
|
|||
)
|
||||
from tradingagents.dataflows.config import set_config
|
||||
|
||||
# Import the new abstract tool methods from agent_utils
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
get_stock_data,
|
||||
get_indicators,
|
||||
|
|
@ -36,6 +35,19 @@ from tradingagents.agents.utils.agent_utils import (
|
|||
get_global_news
|
||||
)
|
||||
|
||||
from tradingagents.agents.discovery import (
|
||||
DiscoveryRequest,
|
||||
DiscoveryResult,
|
||||
DiscoveryStatus,
|
||||
TrendingStock,
|
||||
Sector,
|
||||
EventCategory,
|
||||
DiscoveryTimeoutError,
|
||||
extract_entities,
|
||||
calculate_trending_scores,
|
||||
)
|
||||
from tradingagents.dataflows.interface import get_bulk_news
|
||||
|
||||
from .conditional_logic import ConditionalLogic
|
||||
from .setup import GraphSetup
|
||||
from .propagation import Propagator
|
||||
|
|
@ -43,8 +55,15 @@ from .reflection import Reflector
|
|||
from .signal_processing import SignalProcessor
|
||||
|
||||
|
||||
class DiscoveryTimeoutException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _timeout_handler(signum, frame):
|
||||
raise DiscoveryTimeoutException("Discovery operation timed out")
|
||||
|
||||
|
||||
class TradingAgentsGraph:
|
||||
"""Main class that orchestrates the trading agents framework."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -52,26 +71,16 @@ class TradingAgentsGraph:
|
|||
debug=False,
|
||||
config: Dict[str, Any] = None,
|
||||
):
|
||||
"""Initialize the trading agents graph and components.
|
||||
|
||||
Args:
|
||||
selected_analysts: List of analyst types to include
|
||||
debug: Whether to run in debug mode
|
||||
config: Configuration dictionary. If None, uses default config
|
||||
"""
|
||||
self.debug = debug
|
||||
self.config = config or DEFAULT_CONFIG
|
||||
|
||||
# Update the interface's config
|
||||
set_config(self.config)
|
||||
|
||||
# Create necessary directories
|
||||
os.makedirs(
|
||||
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
# Initialize LLMs
|
||||
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
|
||||
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
||||
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
||||
|
|
@ -83,18 +92,13 @@ class TradingAgentsGraph:
|
|||
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
|
||||
|
||||
# Initialize memories
|
||||
|
||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
||||
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
||||
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
||||
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
|
||||
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
|
||||
|
||||
# Create tool nodes
|
||||
self.tool_nodes = self._create_tool_nodes()
|
||||
|
||||
# Initialize components
|
||||
self.conditional_logic = ConditionalLogic()
|
||||
self.graph_setup = GraphSetup(
|
||||
self.quick_thinking_llm,
|
||||
|
|
@ -111,35 +115,26 @@ class TradingAgentsGraph:
|
|||
self.propagator = Propagator()
|
||||
self.reflector = Reflector(self.quick_thinking_llm)
|
||||
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
|
||||
|
||||
# State tracking
|
||||
self.curr_state = None
|
||||
self.ticker = None
|
||||
self.log_states_dict = {} # date to full state dict
|
||||
|
||||
# Set up the graph
|
||||
self.log_states_dict = {}
|
||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
||||
|
||||
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
||||
"""Create tool nodes for different data sources using abstract methods."""
|
||||
return {
|
||||
"market": ToolNode(
|
||||
[
|
||||
# Core stock data tools
|
||||
get_stock_data,
|
||||
# Technical indicators
|
||||
get_indicators,
|
||||
]
|
||||
),
|
||||
"social": ToolNode(
|
||||
[
|
||||
# News tools for social media analysis
|
||||
get_news,
|
||||
]
|
||||
),
|
||||
"news": ToolNode(
|
||||
[
|
||||
# News and insider information
|
||||
get_news,
|
||||
get_global_news,
|
||||
get_insider_sentiment,
|
||||
|
|
@ -148,7 +143,6 @@ class TradingAgentsGraph:
|
|||
),
|
||||
"fundamentals": ToolNode(
|
||||
[
|
||||
# Fundamental analysis tools
|
||||
get_fundamentals,
|
||||
get_balance_sheet,
|
||||
get_cashflow,
|
||||
|
|
@ -158,18 +152,13 @@ class TradingAgentsGraph:
|
|||
}
|
||||
|
||||
def propagate(self, company_name, trade_date):
|
||||
"""Run the trading agents graph for a company on a specific date."""
|
||||
|
||||
self.ticker = company_name
|
||||
|
||||
# Initialize state
|
||||
init_agent_state = self.propagator.create_initial_state(
|
||||
company_name, trade_date
|
||||
)
|
||||
args = self.propagator.get_graph_args()
|
||||
|
||||
if self.debug:
|
||||
# Debug mode with tracing
|
||||
trace = []
|
||||
for chunk in self.graph.stream(init_agent_state, **args):
|
||||
if len(chunk["messages"]) == 0:
|
||||
|
|
@ -180,20 +169,14 @@ class TradingAgentsGraph:
|
|||
|
||||
final_state = trace[-1]
|
||||
else:
|
||||
# Standard mode without tracing
|
||||
final_state = self.graph.invoke(init_agent_state, **args)
|
||||
|
||||
# Store current state for reflection
|
||||
self.curr_state = final_state
|
||||
|
||||
# Log state
|
||||
self._log_state(trade_date, final_state)
|
||||
|
||||
# Return decision and processed signal
|
||||
return final_state, self.process_signal(final_state["final_trade_decision"])
|
||||
|
||||
def _log_state(self, trade_date, final_state):
|
||||
"""Log the final state to a JSON file."""
|
||||
self.log_states_dict[str(trade_date)] = {
|
||||
"company_of_interest": final_state["company_of_interest"],
|
||||
"trade_date": final_state["trade_date"],
|
||||
|
|
@ -224,7 +207,6 @@ class TradingAgentsGraph:
|
|||
"final_trade_decision": final_state["final_trade_decision"],
|
||||
}
|
||||
|
||||
# Save to file
|
||||
directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
|
@ -235,7 +217,6 @@ class TradingAgentsGraph:
|
|||
json.dump(self.log_states_dict, f, indent=4)
|
||||
|
||||
def reflect_and_remember(self, returns_losses):
|
||||
"""Reflect on decisions and update memory based on returns."""
|
||||
self.reflector.reflect_bull_researcher(
|
||||
self.curr_state, returns_losses, self.bull_memory
|
||||
)
|
||||
|
|
@ -253,5 +234,95 @@ class TradingAgentsGraph:
|
|||
)
|
||||
|
||||
def process_signal(self, full_signal):
|
||||
"""Process a signal to extract the core decision."""
|
||||
return self.signal_processor.process_signal(full_signal)
|
||||
|
||||
def discover_trending(
|
||||
self,
|
||||
request: Optional[DiscoveryRequest] = None,
|
||||
) -> DiscoveryResult:
|
||||
if request is None:
|
||||
request = DiscoveryRequest(
|
||||
lookback_period="24h",
|
||||
max_results=self.config.get("discovery_max_results", 20),
|
||||
)
|
||||
|
||||
started_at = datetime.now()
|
||||
result = DiscoveryResult(
|
||||
request=request,
|
||||
trending_stocks=[],
|
||||
status=DiscoveryStatus.PROCESSING,
|
||||
started_at=started_at,
|
||||
)
|
||||
|
||||
hard_timeout = self.config.get("discovery_hard_timeout", 120)
|
||||
|
||||
discovery_result = {"stocks": [], "error": None}
|
||||
|
||||
def run_discovery():
|
||||
try:
|
||||
articles = get_bulk_news(request.lookback_period)
|
||||
|
||||
mentions = extract_entities(articles, self.config)
|
||||
|
||||
min_mentions = self.config.get("discovery_min_mentions", 2)
|
||||
max_results = request.max_results or self.config.get("discovery_max_results", 20)
|
||||
|
||||
trending_stocks = calculate_trending_scores(
|
||||
mentions,
|
||||
articles,
|
||||
max_results=max_results,
|
||||
min_mentions=min_mentions,
|
||||
)
|
||||
|
||||
discovery_result["stocks"] = trending_stocks
|
||||
except Exception as e:
|
||||
discovery_result["error"] = str(e)
|
||||
|
||||
discovery_thread = threading.Thread(target=run_discovery)
|
||||
discovery_thread.start()
|
||||
discovery_thread.join(timeout=hard_timeout)
|
||||
|
||||
if discovery_thread.is_alive():
|
||||
raise DiscoveryTimeoutError(
|
||||
f"Discovery operation exceeded {hard_timeout} second timeout"
|
||||
)
|
||||
|
||||
if discovery_result["error"]:
|
||||
result.status = DiscoveryStatus.FAILED
|
||||
result.error_message = discovery_result["error"]
|
||||
result.completed_at = datetime.now()
|
||||
return result
|
||||
|
||||
trending_stocks = discovery_result["stocks"]
|
||||
|
||||
if request.sector_filter:
|
||||
sector_values = {s.value if isinstance(s, Sector) else s for s in request.sector_filter}
|
||||
trending_stocks = [
|
||||
stock for stock in trending_stocks
|
||||
if stock.sector.value in sector_values or stock.sector in request.sector_filter
|
||||
]
|
||||
|
||||
if request.event_filter:
|
||||
event_values = {e.value if isinstance(e, EventCategory) else e for e in request.event_filter}
|
||||
trending_stocks = [
|
||||
stock for stock in trending_stocks
|
||||
if stock.event_type.value in event_values or stock.event_type in request.event_filter
|
||||
]
|
||||
|
||||
result.trending_stocks = trending_stocks
|
||||
result.status = DiscoveryStatus.COMPLETED
|
||||
result.completed_at = datetime.now()
|
||||
|
||||
return result
|
||||
|
||||
def analyze_trending(
|
||||
self,
|
||||
trending_stock: TrendingStock,
|
||||
trade_date: Optional[date] = None,
|
||||
) -> Tuple[Dict[str, Any], str]:
|
||||
ticker = trending_stock.ticker
|
||||
|
||||
if trade_date is None:
|
||||
trade_date = date.today()
|
||||
|
||||
return self.propagate(ticker, trade_date)
|
||||
|
|
|
|||
Loading…
Reference in New Issue