omfg basic stub out with passable lint and types

Martin C. Richards 2025-08-03 22:07:24 +02:00
parent c93ffb6452
commit 775258b950
30 changed files with 3125 additions and 3404 deletions

343
README.md
View File

@@ -418,6 +418,42 @@ LangGraph-based workflow management:
- **Constants**: UPPER_CASE (e.g., `DEFAULT_CONFIG`)
- **Imports**: Standard library first, third-party, then local imports (langchain, tradingagents modules)
#### Data Structure Guidelines
**MANDATORY: Always use dataclasses for method returns**
- **Never return**: `dict`, `str`, `Any`, or unstructured data from public methods
- **Always return**: Properly typed dataclasses with clear field definitions
- **Rationale**: Provides type safety, IDE support, clear contracts, and prevents runtime errors
**Examples**:
```python
# ❌ BAD - Dictionary returns
def update_news() -> dict[str, Any]:
return {"status": "completed", "count": 5}
# ✅ GOOD - Dataclass returns
@dataclass
class NewsUpdateResult:
status: str
articles_found: int
articles_scraped: int
articles_failed: int
def update_news() -> NewsUpdateResult:
return NewsUpdateResult(
status="completed",
articles_found=10,
articles_scraped=8,
articles_failed=2
)
```
**Dataclass Best Practices**:
- Use `@dataclass` decorator for all return value structures
- Include type hints for all fields
- Use `| None` for optional fields (modern Python 3.10+ syntax)
- Group related dataclasses in the same module
- Prefer immutable dataclasses with `frozen=True` for value objects (see the sketch below)
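A minimal sketch of these practices together (the `Quote` dataclass is a hypothetical example, not a type from this codebase):
```python
from dataclasses import dataclass
from datetime import datetime


@dataclass(frozen=True)
class Quote:
    """Immutable value object: typed fields, `| None` for optional data."""

    symbol: str
    price: float
    retrieved_at: datetime
    previous_close: float | None = None  # optional field, Python 3.10+ union syntax
```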
#### Ruff Formatting & Linting Rules
**Formatting** (`mise run format`):
- **Line length**: 88 characters maximum
@@ -534,6 +570,313 @@ service.update_market_data("AAPL", "2024-01-01", "2024-01-31")
- Questionnaire-driven configuration collection
- Real-time streaming of analysis results
### Progressive Development Framework
This framework guides agents toward type-safe, testable code through incremental development: build one component at a time, backed by tests and type checks at every step.
#### Core Principles
1. **Service-First Development**: Start with business logic in the service layer
2. **Stub Dependencies**: Create placeholder methods that return proper dataclasses
3. **Progressive Implementation**: Implement one dependency (client OR repository) at a time
4. **Constructor Injection**: Dependencies passed through constructor for testability
5. **Dataclass Returns**: All public methods return properly typed dataclasses
6. **Test-Driven Development**: Write tests first, implement to make them pass
#### Development Process
**Step 1: Design Domain Models**
```python
# models.py - Define all dataclasses first
@dataclass
class DomainEntity:
id: str
name: str
created_at: datetime
@dataclass
class DomainContext:
entities: list[DomainEntity]
metadata: dict[str, Any]
quality_score: float
@dataclass
class UpdateResult:
status: str
entities_processed: int
entities_failed: int
```
**Step 2: Create Service with Business Logic**
```python
# service.py - Main business logic with stub dependencies
class DomainService:
def __init__(self, client: DomainClient, repository: DomainRepository):
self.client = client
self.repository = repository
def get_context(self, symbol: str, start_date: str, end_date: str) -> DomainContext:
# Implement business logic flow
entities = self.repository.get_entities(symbol, start_date, end_date)
# Process and transform data
processed_entities = self._process_entities(entities)
# Calculate quality metrics
quality_score = self._calculate_quality(processed_entities)
return DomainContext(
entities=processed_entities,
metadata={"symbol": symbol, "date_range": f"{start_date} to {end_date}"},
quality_score=quality_score
)
def update_data(self, symbol: str, start_date: str, end_date: str) -> UpdateResult:
# Business logic for updating data
raw_data = self.client.fetch_data(symbol, start_date, end_date)
entities = self._transform_raw_data(raw_data)
processed = 0
failed = 0
for entity in entities:
try:
self.repository.save_entity(entity)
processed += 1
except Exception:
failed += 1
return UpdateResult(
status="completed",
entities_processed=processed,
entities_failed=failed
)
def _process_entities(self, entities: list[DomainEntity]) -> list[DomainEntity]:
# Private method for business logic
return entities # Stub implementation
def _calculate_quality(self, entities: list[DomainEntity]) -> float:
# Private method for quality calculation
return 1.0 # Stub implementation
```
**Step 3: Create Stub Dependencies**
```python
# client.py - Stub client that returns proper dataclasses
class DomainClient:
def fetch_data(self, symbol: str, start_date: str, end_date: str) -> list[dict[str, Any]]:
# Stub implementation - returns realistic structure
return [
{"id": "1", "name": f"{symbol}_entity", "created_at": "2024-01-01T00:00:00Z"},
{"id": "2", "name": f"{symbol}_entity_2", "created_at": "2024-01-02T00:00:00Z"}
]
# repository.py - Stub repository that returns proper dataclasses
class DomainRepository:
def __init__(self, cache_dir: str):
self.cache_dir = cache_dir
def get_entities(self, symbol: str, start_date: str, end_date: str) -> list[DomainEntity]:
# Stub implementation - returns proper dataclasses
return [
DomainEntity(id="1", name=f"{symbol}_cached", created_at=datetime.now()),
DomainEntity(id="2", name=f"{symbol}_cached_2", created_at=datetime.now())
]
def save_entity(self, entity: DomainEntity) -> None:
# Stub implementation
pass
```
**Step 4: Write Comprehensive Tests**
```python
# service_test.py - Test the service with mock dependencies
from unittest.mock import Mock
import pytest
def test_get_context_with_mock_dependencies():
"""Test service business logic with mocked dependencies."""
# Mock the dependencies
mock_client = Mock()
mock_repository = Mock()
# Configure mock returns
mock_repository.get_entities.return_value = [
DomainEntity(id="1", name="TEST_entity", created_at=datetime(2024, 1, 1))
]
# Create service with mocks
service = DomainService(client=mock_client, repository=mock_repository)
# Test the business logic
context = service.get_context("TEST", "2024-01-01", "2024-01-31")
# Validate structure and business logic
assert isinstance(context, DomainContext)
assert context.metadata["symbol"] == "TEST"
assert context.quality_score > 0
assert len(context.entities) > 0
# Verify repository was called correctly
mock_repository.get_entities.assert_called_once_with("TEST", "2024-01-01", "2024-01-31")
def test_update_data_with_mock_dependencies():
"""Test update business logic with mocked dependencies."""
mock_client = Mock()
mock_repository = Mock()
# Configure mock client to return raw data
mock_client.fetch_data.return_value = [
{"id": "1", "name": "TEST_raw", "created_at": "2024-01-01T00:00:00Z"}
]
service = DomainService(client=mock_client, repository=mock_repository)
result = service.update_data("TEST", "2024-01-01", "2024-01-31")
# Validate business logic results
assert isinstance(result, UpdateResult)
assert result.status == "completed"
assert result.entities_processed >= 0
# Verify client and repository interactions
mock_client.fetch_data.assert_called_once()
mock_repository.save_entity.assert_called()
```
**Step 5: Implement One Dependency at a Time**
Choose either client OR repository to implement first:
```python
# Option A: Implement client first
class DomainClient:
def __init__(self, api_key: str):
self.api_key = api_key
self.session = requests.Session()
self.session.headers.update({"User-Agent": "TradingAgents/1.0"})
def fetch_data(self, symbol: str, start_date: str, end_date: str) -> list[dict[str, Any]]:
# Real implementation with error handling
try:
response = self.session.get(
f"https://api.example.com/data/{symbol}",
params={"start": start_date, "end": end_date},
timeout=30
)
response.raise_for_status()
return response.json()["data"]
except requests.RequestException as e:
raise DomainClientError(f"Failed to fetch data: {e}")
# Option B: Implement repository first
class DomainRepository:
def __init__(self, cache_dir: str):
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
def get_entities(self, symbol: str, start_date: str, end_date: str) -> list[DomainEntity]:
# Real implementation with file I/O
cache_file = self.cache_dir / f"{symbol}_{start_date}_{end_date}.json"
if not cache_file.exists():
return []
try:
with open(cache_file, 'r') as f:
data = json.load(f)
return [
DomainEntity(
id=item["id"],
name=item["name"],
created_at=datetime.fromisoformat(item["created_at"])
)
for item in data
]
except (json.JSONDecodeError, KeyError) as e:
raise DomainRepositoryError(f"Failed to load cached data: {e}")
```
**Step 6: Test Real Implementation**
```python
def test_real_client_integration():
"""Test real client implementation."""
client = DomainClient(api_key="test_key")
# Test with real HTTP calls (or use responses library for mocking)
with responses.RequestsMock() as rsps:
rsps.add(
responses.GET,
"https://api.example.com/data/TEST",
json={"data": [{"id": "1", "name": "TEST", "created_at": "2024-01-01T00:00:00Z"}]},
status=200
)
result = client.fetch_data("TEST", "2024-01-01", "2024-01-31")
assert len(result) == 1
assert result[0]["id"] == "1"
def test_real_repository_integration():
"""Test real repository implementation."""
with tempfile.TemporaryDirectory() as temp_dir:
repo = DomainRepository(temp_dir)
# Test saving and loading
entity = DomainEntity(id="1", name="TEST", created_at=datetime.now())
repo.save_entity(entity)
entities = repo.get_entities("TEST", "2024-01-01", "2024-01-31")
assert len(entities) == 1
assert entities[0].id == "1"
```
**Step 7: Iterate and Refine**
1. Run tests after each implementation
2. Refactor business logic as needed
3. Add error handling and edge cases
4. Implement the remaining dependency
5. Add integration tests with both real dependencies (see the sketch below)
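Once both dependencies are real, a final end-to-end test can exercise the full stack. A sketch reusing the pieces above (the mocked endpoint follows the earlier examples, and it assumes `_transform_raw_data` from Step 2 has been implemented):
```python
import tempfile

import responses

# DomainClient, DomainRepository, DomainService, UpdateResult come from
# the modules defined in the steps above.


def test_service_with_real_dependencies():
    """End-to-end: real client and repository wired into the service."""
    with tempfile.TemporaryDirectory() as temp_dir:
        service = DomainService(
            client=DomainClient(api_key="test_key"),
            repository=DomainRepository(temp_dir),
        )
        with responses.RequestsMock() as rsps:
            rsps.add(
                responses.GET,
                "https://api.example.com/data/TEST",
                json={"data": [{"id": "1", "name": "TEST", "created_at": "2024-01-01T00:00:00Z"}]},
                status=200,
            )
            result = service.update_data("TEST", "2024-01-01", "2024-01-31")
        # Status is always "completed" in the Step 2 implementation;
        # entity counts depend on how _transform_raw_data maps raw records.
        assert isinstance(result, UpdateResult)
        assert result.status == "completed"
```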
#### Directory Structure
```
domain_name/
├── models.py # Dataclasses only - no business logic
├── client.py # External API integration
├── repository.py # Data persistence and caching
├── service.py # Main business logic coordinator
└── service_test.py # Comprehensive test suite
```
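Once the modules exist, a thin composition root can wire them together. A sketch assuming the constructors shown in the steps above:
```python
# main.py - hypothetical composition root for the layout above
from domain_name.client import DomainClient
from domain_name.repository import DomainRepository
from domain_name.service import DomainService


def build_service(api_key: str, cache_dir: str) -> DomainService:
    """Construct the service with real dependencies via constructor injection."""
    client = DomainClient(api_key=api_key)
    repository = DomainRepository(cache_dir)
    return DomainService(client=client, repository=repository)
```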
#### Benefits of This Approach
1. **Type Safety**: All interfaces defined upfront with dataclasses
2. **Testability**: Business logic tested independently of external dependencies
3. **Incremental Development**: One component at a time reduces complexity
4. **Clear Contracts**: Dataclass returns make interfaces explicit
5. **Error Isolation**: Issues contained within single components
6. **Refactoring Safety**: Type system catches interface changes
7. **Documentation**: Dataclasses serve as living documentation
#### Anti-Patterns to Avoid
❌ **Don't return dictionaries or strings from public methods**
❌ **Don't implement all dependencies simultaneously**
❌ **Don't skip writing tests first**
❌ **Don't mix business logic with I/O operations**
❌ **Don't use inheritance for dependency injection**
❌ **Don't create circular dependencies between components**
✅ **Do use dataclasses for all return values**
✅ **Do implement one dependency at a time**
✅ **Do write tests before implementation**
✅ **Do separate business logic from I/O**
✅ **Do use constructor injection** (sketched after this list)
✅ **Do maintain clear separation of concerns**
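The constructor-injection point deserves emphasis; a sketch contrasting the two styles using the classes from the steps above:
```python
# ❌ BAD - inheriting from the client hides the dependency and
# makes it impossible to substitute a mock in tests
class DomainService(DomainClient):
    ...


# ✅ GOOD - constructor injection keeps the dependency explicit and swappable
class DomainService:
    def __init__(self, client: DomainClient, repository: DomainRepository):
        self.client = client
        self.repository = repository
```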
### File Structure Context
- **`cli/`**: Interactive command-line interface
- **`tradingagents/agents/`**: All agent implementations

View File

@@ -32,6 +32,8 @@ dependencies = [
"typer>=0.12.0",
"typing-extensions>=4.14.0",
"yfinance>=0.2.63",
"TA-Lib>=0.4.28",
"newspaper3k>=0.2.8",
]
[project.optional-dependencies]

View File

@@ -2,6 +2,9 @@ from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst
from .analysts.news_analyst import create_news_analyst
from .analysts.social_media_analyst import create_social_media_analyst
from .libs.agent_states import AgentState, InvestDebateState, RiskDebateState
from .libs.context_helpers import create_msg_delete
from .libs.memory import FinancialSituationMemory
from .managers.research_manager import create_research_manager
from .managers.risk_manager import create_risk_manager
from .researchers.bear_researcher import create_bear_researcher
@@ -10,26 +13,20 @@ from .risk_mgmt.aggresive_debator import create_risky_debator
from .risk_mgmt.conservative_debator import create_safe_debator
from .risk_mgmt.neutral_debator import create_neutral_debator
from .trader.trader import create_trader
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.agent_utils import Toolkit, create_msg_delete
from .utils.memory import FinancialSituationMemory
from .utils.service_toolkit import ServiceToolkit
__all__ = [
"FinancialSituationMemory",
"Toolkit",
"ServiceToolkit",
"AgentState",
"create_msg_delete",
"InvestDebateState",
"RiskDebateState",
"create_bear_researcher",
"create_bull_researcher",
"create_research_manager",
"create_fundamentals_analyst",
"create_market_analyst",
"create_msg_delete",
"create_neutral_debator",
"create_news_analyst",
"create_research_manager",
"create_risky_debator",
"create_risk_manager",
"create_safe_debator",

View File

@@ -1,6 +1,6 @@
import logging
import re
from datetime import datetime, timedelta
from datetime import date, datetime, timedelta
from typing import Annotated
from langchain_core.tools import tool
@@ -41,9 +41,9 @@ class AgentToolkit:
def __init__(
self,
news_service: NewsService,
socialmedia_service: SocialMediaService,
marketdata_service: MarketDataService,
fundamentaldata_service: FundamentalDataService,
socialmedia_service: SocialMediaService,
insiderdata_service: InsiderDataService,
config: TradingAgentsConfig = DEFAULT_CONFIG,
):
@@ -102,8 +102,8 @@ class AgentToolkit:
datetime.strptime(start_date, "%Y-%m-%d")
datetime.strptime(end_date, "%Y-%m-%d")
return self._news_service.get_context(
query=ticker, start_date=start_date, end_date=end_date, symbol=ticker
return self._news_service.get_company_news_context(
symbol=ticker, start_date=start_date, end_date=end_date
)
except Exception as e:
logger.error(f"Error getting news for {ticker}: {e}")
@@ -177,7 +177,7 @@ class AgentToolkit:
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd"
],
look_back_days: Annotated[int, "how many days to look back"] = None,
look_back_days: Annotated[int, "how many days to look back"],
) -> TAReportContext:
"""
Retrieve stock stats indicators for a given ticker symbol and indicator.
@@ -280,11 +280,11 @@ class AgentToolkit:
Returns:
BalanceSheetContext: Structured balance sheet analysis with key liquidity and debt metrics.
"""
curr_date_obj = self._parse_date(curr_date)
return self._fundamentaldata_service.get_balance_sheet_context(
symbol=ticker,
start_date=curr_date,
end_date=curr_date,
frequency=freq.lower(),
start_date=curr_date_obj,
end_date=curr_date_obj,
)
@tool
@@ -306,11 +306,11 @@ class AgentToolkit:
Returns:
CashFlowContext: Structured cash flow analysis with operating cash flow metrics.
"""
curr_date_obj = self._parse_date(curr_date)
return self._fundamentaldata_service.get_cashflow_context(
symbol=ticker,
start_date=curr_date,
end_date=curr_date,
frequency=freq.lower(),
start_date=curr_date_obj,
end_date=curr_date_obj,
)
@tool
@@ -332,11 +332,11 @@ class AgentToolkit:
Returns:
IncomeStatementContext: Structured income statement analysis with profitability metrics.
"""
curr_date_obj = self._parse_date(curr_date)
return self._fundamentaldata_service.get_income_statement_context(
symbol=ticker,
start_date=curr_date,
end_date=curr_date,
frequency=freq.lower(),
start_date=curr_date_obj,
end_date=curr_date_obj,
)
def _calculate_date_range(
@@ -359,7 +359,9 @@ class AgentToolkit:
curr_date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
except ValueError as e:
logger.error(f"Invalid date format '{curr_date}': {e}")
raise ValueError(f"Date must be in YYYY-MM-DD format, got: {curr_date}")
raise ValueError(
f"Date must be in YYYY-MM-DD format, got: {curr_date}"
) from e
if lookback_days is None:
lookback_days = self._config.default_lookback_days
@@ -367,6 +369,27 @@ class AgentToolkit:
start_date_obj = curr_date_obj - timedelta(days=lookback_days)
return start_date_obj.strftime("%Y-%m-%d"), curr_date
def _parse_date(self, date_str: str) -> date:
"""
Convert string date to date object.
Args:
date_str: Date string in YYYY-MM-DD format
Returns:
date object
Raises:
ValueError: If date format is invalid
"""
try:
return datetime.strptime(date_str, "%Y-%m-%d").date()
except ValueError as e:
logger.error(f"Invalid date format '{date_str}': {e}")
raise ValueError(
f"Date must be in YYYY-MM-DD format, got '{date_str}'"
) from e
def _validate_ticker(self, ticker: str) -> str:
"""
Validate and sanitize ticker symbol.

View File

@@ -310,3 +310,24 @@ def is_high_quality_data(context_json: str) -> bool:
"""Quick check if context contains high quality data."""
context = ContextParser.parse_context(context_json)
return ContextParser.is_high_quality(context)
def create_msg_delete():
"""
Create a message deletion node function for LangGraph workflows.
This function returns a node that clears all messages from the state,
which is useful for preventing context pollution between different
phases of multi-agent workflows.
Returns:
Callable: A function that can be used as a LangGraph node
"""
from langchain_core.messages import RemoveMessage
from langgraph.graph.message import REMOVE_ALL_MESSAGES
def delete_messages(state):
"""Delete all messages from the current state."""
return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)]}
return delete_messages
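# Usage sketch (an assumption, not part of this commit): register the node in a
# LangGraph StateGraph between workflow phases, e.g.
#   graph.add_node("clear_messages", create_msg_delete())
# so each phase starts from a clean message list.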

View File

@@ -1,7 +0,0 @@
"""
Client classes for live data access in TradingAgents.
"""
from .base import BaseClient
__all__ = ["BaseClient"]

View File

@@ -1,23 +0,0 @@
"""
Base client interface for TradingAgents data sources.
"""
from abc import ABC, abstractmethod
from typing import Any
class BaseClient(ABC):
"""Abstract base class for all data clients."""
@abstractmethod
def get_data(self, **kwargs) -> dict[str, Any]:
"""
Get data from the client source.
Args:
**kwargs: Client-specific parameters
Returns:
dict: Data dictionary with standardized structure
"""
pass

View File

@@ -4,10 +4,16 @@ Finnhub client for financial data access.
import logging
from datetime import date
from typing import Any
import finnhub
from ..models import (
CompanyProfile,
InsiderSentimentResponse,
InsiderTransactionsResponse,
ReportedFinancialsResponse,
)
logger = logging.getLogger(__name__)
@@ -31,116 +37,33 @@ class FinnhubClient:
Args:
api_key: Finnhub API key
"""
self.api_key = api_key
self.client = finnhub.Client(api_key=api_key)
def test_connection(self) -> bool:
"""Test if the Finnhub API connection is working."""
try:
# Test with a simple quote request
response = self.client.quote("AAPL")
return "c" in response # 'c' is current price field
except Exception as e:
logger.error(f"Finnhub connection test failed: {e}")
return False
def get_balance_sheet(
self, symbol: str, frequency: str, report_date: date
) -> dict[str, Any]:
def get_reported_financials(
self, symbol: str, frequency: str
) -> ReportedFinancialsResponse:
"""
Get balance sheet data from Finnhub.
Get reported financials data from Finnhub.
Args:
symbol: Stock symbol (e.g., 'AAPL')
frequency: Reporting frequency ('quarterly' or 'annual')
report_date: Report date as date object
Returns:
Balance sheet data from Finnhub API
Reported financials data from Finnhub API
"""
try:
# Finnhub SDK expects frequency as 'quarterly' or 'annual'
freq = "quarterly" if frequency.lower() in ["quarterly", "q"] else "annual"
response = self.client.financials_reported(symbol=symbol.upper(), freq=freq)
return response if isinstance(response, dict) else {"data": []}
return ReportedFinancialsResponse.from_dict(response)
except Exception as e:
logger.error(f"Error fetching balance sheet for {symbol}: {e}")
return {"data": []}
def get_income_statement(
self, symbol: str, frequency: str, report_date: date
) -> dict[str, Any]:
"""
Get income statement data from Finnhub.
Args:
symbol: Stock symbol (e.g., 'AAPL')
frequency: Reporting frequency ('quarterly' or 'annual')
report_date: Report date as date object
Returns:
Income statement data from Finnhub API
"""
try:
freq = "quarterly" if frequency.lower() in ["quarterly", "q"] else "annual"
response = self.client.financials_reported(symbol=symbol.upper(), freq=freq)
return response if isinstance(response, dict) else {"data": []}
except Exception as e:
logger.error(f"Error fetching income statement for {symbol}: {e}")
return {"data": []}
def get_cash_flow(
self, symbol: str, frequency: str, report_date: date
) -> dict[str, Any]:
"""
Get cash flow statement data from Finnhub.
Args:
symbol: Stock symbol (e.g., 'AAPL')
frequency: Reporting frequency ('quarterly' or 'annual')
report_date: Report date as date object
Returns:
Cash flow statement data from Finnhub API
"""
try:
freq = "quarterly" if frequency.lower() in ["quarterly", "q"] else "annual"
response = self.client.financials_reported(symbol=symbol.upper(), freq=freq)
return response if isinstance(response, dict) else {"data": []}
except Exception as e:
logger.error(f"Error fetching cash flow for {symbol}: {e}")
return {"data": []}
def get_company_news(
self, symbol: str, start_date: date, end_date: date
) -> list[dict[str, Any]]:
"""
Get company news for a specific symbol and date range.
Args:
symbol: Stock symbol (e.g., 'AAPL')
start_date: Start date as date object
end_date: End date as date object
Returns:
List of news articles
"""
# Convert date objects to strings for API
start_str = start_date.isoformat()
end_str = end_date.isoformat()
try:
response = self.client.company_news(
symbol.upper(), _from=start_str, to=end_str
)
return response if isinstance(response, list) else []
except Exception as e:
logger.error(f"Error fetching news for {symbol}: {e}")
return []
logger.error(f"Error fetching reported financials for {symbol}: {e}")
raise
def get_insider_transactions(
self, symbol: str, start_date: date, end_date: date
) -> dict[str, Any]:
) -> InsiderTransactionsResponse:
"""
Get insider transactions for a company.
@@ -159,14 +82,18 @@ class FinnhubClient:
response = self.client.stock_insider_transactions(
symbol.upper(), _from=start_str, to=end_str
)
return response if isinstance(response, dict) else {"data": []}
if isinstance(response, dict):
return InsiderTransactionsResponse.from_dict(response)
else:
# Return empty response if API returns unexpected format
return InsiderTransactionsResponse(data=[], symbol=symbol.upper())
except Exception as e:
logger.error(f"Error fetching insider transactions for {symbol}: {e}")
return {"data": []}
raise
def get_insider_sentiment(
self, symbol: str, start_date: date, end_date: date
) -> dict[str, Any]:
) -> InsiderSentimentResponse:
"""
Get insider sentiment data for a company.
@@ -185,29 +112,16 @@ class FinnhubClient:
response = self.client.stock_insider_sentiment(
symbol.upper(), _from=start_str, to=end_str
)
return response if isinstance(response, dict) else {"data": []}
if isinstance(response, dict):
return InsiderSentimentResponse.from_dict(response)
else:
# Return empty response if API returns unexpected format
return InsiderSentimentResponse(data=[], symbol=symbol.upper())
except Exception as e:
logger.error(f"Error fetching insider sentiment for {symbol}: {e}")
return {"data": []}
raise
def get_quote(self, symbol: str) -> dict[str, Any]:
"""
Get current quote for a symbol.
Args:
symbol: Stock symbol
Returns:
Quote data with current price, change, etc.
"""
try:
response = self.client.quote(symbol.upper())
return response if isinstance(response, dict) else {}
except Exception as e:
logger.error(f"Error fetching quote for {symbol}: {e}")
return {}
def get_company_profile(self, symbol: str) -> dict[str, Any]:
def get_company_profile(self, symbol: str) -> CompanyProfile:
"""
Get company profile information.
@@ -219,20 +133,7 @@ class FinnhubClient:
"""
try:
response = self.client.company_profile2(symbol=symbol.upper())
return response if isinstance(response, dict) else {}
return CompanyProfile.from_dict(response)
except Exception as e:
logger.error(f"Error fetching company profile for {symbol}: {e}")
return {}
def get_client_info(self) -> dict[str, Any]:
"""
Get information about this client.
Returns:
Dict[str, Any]: Client metadata
"""
return {
"client_type": self.__class__.__name__,
"api_key_set": bool(self.api_key),
"sdk_version": getattr(finnhub, "__version__", "unknown"),
}
raise

View File

@@ -1,320 +0,0 @@
#!/usr/bin/env python3
"""
Comprehensive tests for FinnhubClient using pytest-vcr.
This test suite records real API interactions with Finnhub and replays them
in subsequent runs, providing realistic testing without network dependencies.
"""
import os
from datetime import date
import pytest
# NOTE: FinnhubClient was removed - this test needs to be updated
# from tradingagents.clients.finnhub_client import FinnhubClient
pytest.skip("FinnhubClient implementation removed", allow_module_level=True)
@pytest.fixture
def finnhub_client():
"""Create FinnhubClient with test API key."""
# Use environment variable or test key for VCR recording
api_key = os.getenv("FINNHUB_API_KEY", "test_api_key")
return FinnhubClient(api_key=api_key)
@pytest.fixture
def vcr_config():
"""Configure VCR with proper settings."""
return {
"filter_headers": ["X-Finnhub-Token"], # Filter out API key from recordings
"record_mode": "once", # Record once, then replay
"match_on": ["uri", "method"],
"decode_compressed_response": True,
}
class TestFinnhubClientConnection:
"""Test client connection and initialization."""
@pytest.mark.vcr
def test_client_initialization(self, finnhub_client):
"""Test that client initializes correctly."""
assert finnhub_client.api_key is not None
assert finnhub_client.client is not None
@pytest.mark.vcr
def test_connection_success(self, finnhub_client):
"""Test successful API connection."""
assert finnhub_client.test_connection() is True
def test_client_info(self, finnhub_client):
"""Test client metadata."""
info = finnhub_client.get_client_info()
assert info["client_type"] == "FinnhubClient"
assert info["api_key_set"] is True
class TestFundamentalDataMethods:
"""Test the fundamental data methods needed by FundamentalDataService."""
@pytest.mark.vcr
def test_get_balance_sheet_quarterly(self, finnhub_client):
"""Test balance sheet retrieval with date object."""
test_date = date(2024, 1, 1)
data = finnhub_client.get_balance_sheet("AAPL", "quarterly", test_date)
assert isinstance(data, dict)
# Finnhub financials_reported returns data structure with 'data' key
assert "data" in data or len(data) > 0
@pytest.mark.vcr
def test_get_balance_sheet_annual(self, finnhub_client):
"""Test annual balance sheet retrieval."""
test_date = date(2024, 1, 1)
data = finnhub_client.get_balance_sheet("AAPL", "annual", test_date)
assert isinstance(data, dict)
@pytest.mark.vcr
def test_get_income_statement_quarterly(self, finnhub_client):
"""Test income statement retrieval."""
test_date = date(2024, 1, 1)
data = finnhub_client.get_income_statement("AAPL", "quarterly", test_date)
assert isinstance(data, dict)
@pytest.mark.vcr
def test_get_income_statement_annual(self, finnhub_client):
"""Test annual income statement retrieval."""
test_date = date(2024, 1, 1)
data = finnhub_client.get_income_statement("AAPL", "annual", test_date)
assert isinstance(data, dict)
@pytest.mark.vcr
def test_get_cash_flow_with_date_object(self, finnhub_client):
"""Test cash flow retrieval with date object."""
test_date = date(2024, 3, 31)
data = finnhub_client.get_cash_flow("AAPL", "quarterly", test_date)
assert isinstance(data, dict)
@pytest.mark.vcr
def test_fundamental_data_different_symbols(self, finnhub_client):
"""Test fundamental data for different symbols."""
symbols = ["AAPL", "MSFT", "GOOGL"]
test_date = date(2024, 1, 1)
for symbol in symbols:
data = finnhub_client.get_balance_sheet(symbol, "quarterly", test_date)
assert isinstance(data, dict)
class TestExistingMethods:
"""Test existing methods with enhanced date support."""
@pytest.mark.vcr
def test_get_company_news_january(self, finnhub_client):
"""Test company news for January 2024."""
start_date = date(2024, 1, 1)
end_date = date(2024, 1, 31)
news = finnhub_client.get_company_news("AAPL", start_date, end_date)
assert isinstance(news, list)
@pytest.mark.vcr
def test_get_company_news_date_objects(self, finnhub_client):
"""Test company news with date objects."""
start_date = date(2024, 1, 1)
end_date = date(2024, 1, 31)
news = finnhub_client.get_company_news("AAPL", start_date, end_date)
assert isinstance(news, list)
@pytest.mark.vcr
def test_get_insider_transactions_date_objects(self, finnhub_client):
"""Test insider transactions with date objects."""
start_date = date(2024, 1, 1)
end_date = date(2024, 1, 31)
data = finnhub_client.get_insider_transactions("AAPL", start_date, end_date)
assert isinstance(data, dict)
assert "data" in data
@pytest.mark.vcr
def test_get_insider_sentiment(self, finnhub_client):
"""Test insider sentiment with date objects."""
start_date = date(2024, 1, 1)
end_date = date(2024, 1, 31)
data = finnhub_client.get_insider_sentiment("AAPL", start_date, end_date)
assert isinstance(data, dict)
@pytest.mark.vcr
def test_get_quote(self, finnhub_client):
"""Test stock quote retrieval."""
quote = finnhub_client.get_quote("AAPL")
assert isinstance(quote, dict)
# Quote should contain current price 'c' field
if quote: # Only check if we got data
assert "c" in quote
@pytest.mark.vcr
def test_get_company_profile(self, finnhub_client):
"""Test company profile retrieval."""
profile = finnhub_client.get_company_profile("AAPL")
assert isinstance(profile, dict)
class TestErrorHandling:
"""Test error handling and edge cases."""
@pytest.mark.vcr
def test_invalid_symbol_balance_sheet(self, finnhub_client):
"""Test balance sheet with invalid symbol."""
test_date = date(2024, 1, 1)
data = finnhub_client.get_balance_sheet(
"INVALID_SYMBOL_XYZ", "quarterly", test_date
)
# Should return dict with empty data structure, not raise exception
assert isinstance(data, dict)
assert "data" in data
@pytest.mark.vcr
def test_invalid_symbol_news(self, finnhub_client):
"""Test news with invalid symbol."""
start_date = date(2024, 1, 1)
end_date = date(2024, 1, 31)
news = finnhub_client.get_company_news(
"INVALID_SYMBOL_XYZ", start_date, end_date
)
# Should return empty list, not raise exception
assert isinstance(news, list)
def test_connection_with_invalid_api_key(self):
"""Test connection failure with invalid API key."""
client = FinnhubClient("invalid_api_key")
# This should return False, not raise an exception
assert client.test_connection() is False
@pytest.mark.vcr
def test_frequency_normalization(self, finnhub_client):
"""Test that frequency parameters are normalized correctly."""
# Test different frequency formats
frequencies = ["quarterly", "QUARTERLY", "q", "Q", "annual", "ANNUAL", "a", "A"]
test_date = date(2024, 1, 1)
for freq in frequencies:
data = finnhub_client.get_balance_sheet("AAPL", freq, test_date)
assert isinstance(data, dict)
def test_date_edge_cases(self, finnhub_client):
"""Test date edge cases."""
# Test with year-end date
test_date = date(2024, 12, 31)
# This shouldn't raise exceptions
# (actual API calls are mocked/recorded)
data = finnhub_client.get_balance_sheet("AAPL", "quarterly", test_date)
assert isinstance(data, dict)
class TestMultipleSymbolsAndTimeframes:
"""Test with multiple symbols and different timeframes."""
@pytest.mark.vcr
def test_multiple_symbols_balance_sheet(self, finnhub_client):
"""Test balance sheet for multiple major symbols."""
symbols = ["AAPL", "MSFT", "GOOGL", "TSLA"]
test_date = date(2024, 1, 1)
for symbol in symbols:
data = finnhub_client.get_balance_sheet(symbol, "quarterly", test_date)
assert isinstance(data, dict)
@pytest.mark.vcr
def test_quarterly_vs_annual_frequency(self, finnhub_client):
"""Test quarterly vs annual frequency."""
symbol = "AAPL"
test_date = date(2024, 1, 1)
quarterly_data = finnhub_client.get_balance_sheet(
symbol, "quarterly", test_date
)
annual_data = finnhub_client.get_balance_sheet(symbol, "annual", test_date)
assert isinstance(quarterly_data, dict)
assert isinstance(annual_data, dict)
@pytest.mark.vcr
def test_all_fundamental_methods_same_symbol(self, finnhub_client):
"""Test all fundamental methods for the same symbol."""
symbol = "AAPL"
frequency = "quarterly"
report_date = date(2024, 1, 1)
balance_sheet = finnhub_client.get_balance_sheet(symbol, frequency, report_date)
income_statement = finnhub_client.get_income_statement(
symbol, frequency, report_date
)
cash_flow = finnhub_client.get_cash_flow(symbol, frequency, report_date)
assert isinstance(balance_sheet, dict)
assert isinstance(income_statement, dict)
assert isinstance(cash_flow, dict)
# Integration test to verify the client works with service expectations
class TestServiceIntegration:
"""Test that client works as expected by FundamentalDataService."""
def test_service_expected_methods_exist(self, finnhub_client):
"""Test that all methods expected by FundamentalDataService exist."""
# These are the methods the service calls
assert hasattr(finnhub_client, "get_balance_sheet")
assert hasattr(finnhub_client, "get_income_statement")
assert hasattr(finnhub_client, "get_cash_flow")
# Verify method signatures accept the expected parameters
import inspect
for method_name in [
"get_balance_sheet",
"get_income_statement",
"get_cash_flow",
]:
method = getattr(finnhub_client, method_name)
sig = inspect.signature(method)
params = list(sig.parameters.keys())
# Service calls: method(symbol, frequency, date)
assert "symbol" in params
assert "frequency" in params
assert "report_date" in params
@pytest.mark.vcr
def test_data_format_for_service_conversion(self, finnhub_client):
"""Test that returned data format can be processed by service conversion method."""
test_date = date(2024, 1, 1)
data = finnhub_client.get_balance_sheet("AAPL", "quarterly", test_date)
# Service expects either empty dict or dict with data
assert isinstance(data, dict)
# The service's _convert_to_financial_statement expects either:
# 1. Empty/falsy data -> returns None
# 2. Dict with "data" key containing the actual financial data
if data: # If we got data
# Should be able to access data["data"] or similar structure
# This validates the format is compatible with service expectations
assert isinstance(data, dict)

View File

@@ -3,6 +3,19 @@ Fundamental Data Service for aggregating and analyzing financial statement data.
"""
import logging
from datetime import date
from tradingagents.config import TradingAgentsConfig
from .clients.finnhub_client import FinnhubClient
from .models import (
BalanceSheetContext,
CashFlowContext,
DataQuality,
FundamentalContext,
IncomeStatementContext,
)
from .repos.fundamental_data_repository import FundamentalDataRepository
logger = logging.getLogger(__name__)
@@ -12,35 +25,30 @@ class FundamentalDataService:
def __init__(
self,
simfin_client: SimFinClient,
finnhub_client: FinnhubClient,
repository: FundamentalDataRepository,
):
"""Initialize Fundamental Data Service.
Args:
simfin_client: Client for SimFin/financial API access
finnhub_client: Client for FinnHub financial API access
repository: Repository for cached fundamental data
online_mode: Whether to fetch live data
data_dir: Directory for data storage
online_mode: Whether to fetch live data or use cached data
"""
self.simfin_client = simfin_client
self.finnhub_client = finnhub_client
self.repository = repository
def update_fundamental_data(
self,
symbol: str,
start_date: str,
end_date: str,
frequency: str = "quarterly",
) -> FundamentalContext:
pass # TODO: fetch fundamentals from simfin, save in repo
@staticmethod
def build(_config: TradingAgentsConfig):
client = FinnhubClient("")
repo = FundamentalDataRepository("")
return FundamentalDataService(client, repo)
def get_fundamental_context(
self,
symbol: str,
start_date: str,
end_date: str,
frequency: str = "quarterly",
start_date: date,
end_date: date,
) -> FundamentalContext:
"""Get fundamental analysis context for a company.
@@ -49,34 +57,135 @@ class FundamentalDataService:
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
frequency: Reporting frequency ('quarterly' or 'annual')
force_refresh: If True, skip local data and fetch fresh from APIs
Returns:
FundamentalContext with financial statements and key ratios
"""
balance_sheet = None
income_statement = None
cash_flow = None
error_info = {}
errors = []
data_source = "unknown"
# TODO: implement
return FundamentalContext(
symbol=symbol, start_date=start_date, end_date=end_date
)
# return FundamentalContext(
# symbol=symbol,
# period={"start": start_date, "end": end_date},
# balance_sheet=balance_sheet,
# income_statement=income_statement,
# cash_flow=cash_flow,
# key_ratios=key_ratios,
# metadata={
# "data_quality": data_quality,
# "service": "fundamental_data",
# "online_mode": self.is_online(),
# "frequency": frequency,
# "data_source": data_source,
# "force_refresh": force_refresh,
# **error_info,
# },
# )
def get_balance_sheet_context(
self,
symbol: str,
start_date: date,
end_date: date,
) -> BalanceSheetContext:
"""Get balance sheet context for a company.
pass # TODO: read data from repo
Args:
symbol: Stock ticker symbol
start_date: Start date as a date object
end_date: End date as a date object
Returns:
BalanceSheetContext with balance sheet data and ratios
"""
# TODO: implement
# Return empty context if no data
return BalanceSheetContext(
symbol=symbol,
start_date=start_date,
end_date=end_date,
balance_sheet_data=[],
key_ratios=[],
data_quality=DataQuality.LOW,
source="none",
metadata={"error": "No balance sheet data available"},
)
def get_income_statement_context(
self,
symbol: str,
start_date: date,
end_date: date,
) -> IncomeStatementContext:
"""Get income statement context for a company.
Args:
symbol: Stock ticker symbol
start_date: Start date as a date object
end_date: End date as a date object
Returns:
IncomeStatementContext with income statement data and ratios
"""
# TODO: implement
# Return empty context if no data
return IncomeStatementContext(
symbol=symbol,
start_date=start_date,
end_date=end_date,
income_statement_data=[],
key_ratios=[],
data_quality=DataQuality.LOW,
source="none",
metadata={"error": "No income statement data available"},
)
def get_cashflow_context(
self,
symbol: str,
start_date: date,
end_date: date,
) -> CashFlowContext:
"""Get cash flow context for a company.
Args:
symbol: Stock ticker symbol
start_date: Start date as a date object
end_date: End date as a date object
Returns:
CashFlowContext with cash flow data and ratios
"""
# TODO: implement
# Return empty context if no data
return CashFlowContext(
symbol=symbol,
start_date=start_date,
end_date=end_date,
cash_flow_data=[],
key_ratios=[],
data_quality=DataQuality.LOW,
source="none",
metadata={"error": "No cash flow data available"},
)
def update_fundamental_data(
self,
symbol: str,
date: date,
frequency: str = "quarterly",
) -> bool:
"""Update fundamental data by fetching from FinnHub and storing in repository.
Args:
symbol: Stock ticker symbol
date: Date for the financial data
frequency: Reporting frequency ('quarterly' or 'annual')
Returns:
bool: True if update was successful
"""
try:
# Fetch reported financials data from FinnHub using the unified method
reported_financials = self.finnhub_client.get_reported_financials(
symbol, frequency
)
# Store the reported financials data in repository
return self.repository.store_reported_financials(
symbol=symbol, date=date, reported_financials=reported_financials
)
except Exception as e:
logger.error(f"Error updating fundamental data for {symbol}: {e}")
return False

View File

@@ -1,495 +0,0 @@
"""
Test FundamentalDataService with mock SimFin clients and real FundamentalDataRepository.
"""
import tempfile
from datetime import datetime
from typing import Any
import pytest
from tradingagents.clients.base import BaseClient
from tradingagents.domains.marketdata.fundamental_data_service import (
DataQuality,
FinancialStatement,
FundamentalContext,
FundamentalDataService,
)
from tradingagents.repositories.fundamental_repository import FundamentalDataRepository
class MockSimFinClient(BaseClient):
"""Mock SimFin client that returns sample financial statement data."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connection_works = True
def test_connection(self) -> bool:
return self.connection_works
def get_data(self, *args, **kwargs) -> dict[str, Any]:
"""Not used directly by FundamentalDataService."""
return {}
def get_balance_sheet(
self, ticker: str, freq: str, curr_date: str
) -> dict[str, Any]:
"""Return mock balance sheet data."""
return {
"ticker": ticker,
"statement_type": "balance_sheet",
"frequency": freq,
"period": "Q3-2024" if freq == "quarterly" else "2024",
"report_date": "2024-09-30",
"publish_date": "2024-10-30",
"currency": "USD",
"data": {
"Total Assets": 365725000000.0,
"Total Current Assets": 143566000000.0,
"Cash and Cash Equivalents": 28965000000.0,
"Short-term Investments": 31590000000.0,
"Accounts Receivable": 13348000000.0,
"Inventory": 6511000000.0,
"Total Non-Current Assets": 222159000000.0,
"Property, Plant & Equipment": 43715000000.0,
"Intangible Assets": 11235000000.0,
"Total Liabilities": 279414000000.0,
"Total Current Liabilities": 136817000000.0,
"Accounts Payable": 58146000000.0,
"Short-term Debt": 20208000000.0,
"Total Non-Current Liabilities": 142597000000.0,
"Long-term Debt": 106550000000.0,
"Total Shareholders Equity": 86311000000.0,
"Retained Earnings": 672000000.0,
"Common Stock": 77958000000.0,
},
"metadata": {
"source": "mock_simfin",
"retrieved_at": datetime(2024, 1, 2).isoformat(),
},
}
def get_income_statement(
self, ticker: str, freq: str, curr_date: str
) -> dict[str, Any]:
"""Return mock income statement data."""
return {
"ticker": ticker,
"statement_type": "income_statement",
"frequency": freq,
"period": "Q3-2024" if freq == "quarterly" else "2024",
"report_date": "2024-09-30",
"publish_date": "2024-10-30",
"currency": "USD",
"data": {
"Total Revenue": 94930000000.0,
"Cost of Revenue": 55720000000.0,
"Gross Profit": 39210000000.0,
"Operating Expenses": 15706000000.0,
"Research and Development": 8067000000.0,
"Sales, General & Administrative": 7639000000.0,
"Operating Income": 23504000000.0,
"Interest Expense": 1013000000.0,
"Other Income": 269000000.0,
"Income Before Tax": 22760000000.0,
"Tax Provision": 4438000000.0,
"Net Income": 18322000000.0,
"Basic EPS": 1.18,
"Diluted EPS": 1.15,
"Shares Outstanding": 15550193000,
},
"metadata": {
"source": "mock_simfin",
"retrieved_at": datetime(2024, 1, 2).isoformat(),
},
}
def get_cash_flow(self, ticker: str, freq: str, curr_date: str) -> dict[str, Any]:
"""Return mock cash flow statement data."""
return {
"ticker": ticker,
"statement_type": "cash_flow",
"frequency": freq,
"period": "Q3-2024" if freq == "quarterly" else "2024",
"report_date": "2024-09-30",
"publish_date": "2024-10-30",
"currency": "USD",
"data": {
"Net Income": 18322000000.0,
"Depreciation & Amortization": 2871000000.0,
"Changes in Working Capital": -1684000000.0,
"Operating Cash Flow": 23302000000.0,
"Capital Expenditures": -2736000000.0,
"Acquisitions": -1800000000.0,
"Asset Sales": 234000000.0,
"Investing Cash Flow": -4302000000.0,
"Dividends Paid": -3746000000.0,
"Share Repurchases": -24979000000.0,
"Debt Proceeds": 750000000.0,
"Debt Repayment": -1500000000.0,
"Financing Cash Flow": -28475000000.0,
"Free Cash Flow": 20566000000.0,
"Net Change in Cash": -9475000000.0,
},
"metadata": {
"source": "mock_simfin",
"retrieved_at": datetime(2024, 1, 2).isoformat(),
},
}
@pytest.fixture
def temp_data_dir():
"""Create a temporary directory for test data and clean up after test."""
with tempfile.TemporaryDirectory(prefix="fundamental_test_") as temp_dir:
yield temp_dir
@pytest.fixture
def mock_simfin_client():
"""Create a mock SimFin client for testing."""
return MockSimFinClient()
@pytest.fixture
def broken_simfin_client():
"""Create a broken SimFin client for error testing."""
class BrokenSimFinClient(BaseClient):
def test_connection(self):
return False
def get_data(self, *args, **kwargs):
raise Exception("SimFin API error")
def get_balance_sheet(self, *args, **kwargs):
raise Exception("SimFin API error")
def get_income_statement(self, *args, **kwargs):
raise Exception("SimFin API error")
def get_cash_flow(self, *args, **kwargs):
raise Exception("SimFin API error")
return BrokenSimFinClient()
@pytest.fixture
def partial_data_client():
"""Create a client that returns partial data for testing."""
class PartialDataClient(MockSimFinClient):
def get_cash_flow(self, ticker, freq, curr_date):
# Simulate missing cash flow data
raise Exception("Cash flow data not available")
def get_income_statement(self, ticker, freq, curr_date):
# Simulate missing income statement
return {"data": {}, "metadata": {"error": "No data found"}}
return PartialDataClient()
def test_online_mode_with_mock_simfin(temp_data_dir, mock_simfin_client):
"""Test FundamentalDataService in online mode with mock SimFin client."""
# Create real repository with temporary directory
real_repo = FundamentalDataRepository(temp_data_dir)
# Create service with mock client and real repository
service = FundamentalDataService(
simfin_client=mock_simfin_client,
repository=real_repo,
data_dir=temp_data_dir,
)
# Test getting fundamental context with all three statements
context = service.get_fundamental_context(
symbol="AAPL",
start_date="2024-01-01",
end_date="2024-12-31",
frequency="quarterly",
force_refresh=True, # Force using mock client instead of cache
)
# Validate context structure
assert isinstance(context, FundamentalContext)
assert context.symbol == "AAPL"
assert context.period["start"] == "2024-01-01"
assert context.period["end"] == "2024-12-31"
# Validate financial statements
assert context.balance_sheet is not None
assert isinstance(context.balance_sheet, FinancialStatement)
assert context.balance_sheet.period == "Q3-2024"
assert context.balance_sheet.currency == "USD"
assert "Total Assets" in context.balance_sheet.data
assert context.income_statement is not None
assert isinstance(context.income_statement, FinancialStatement)
assert "Total Revenue" in context.income_statement.data
assert "Net Income" in context.income_statement.data
assert context.cash_flow is not None
assert isinstance(context.cash_flow, FinancialStatement)
assert "Operating Cash Flow" in context.cash_flow.data
assert "Free Cash Flow" in context.cash_flow.data
# Validate key ratios calculation
assert len(context.key_ratios) > 0
assert "current_ratio" in context.key_ratios
assert "debt_to_equity" in context.key_ratios
assert "roe" in context.key_ratios # Return on Equity
assert "gross_margin" in context.key_ratios
# Validate metadata
assert "data_quality" in context.metadata
assert context.metadata["service"] == "fundamental_data"
# Test JSON serialization
json_output = context.model_dump_json(indent=2)
assert len(json_output) > 0
def test_annual_vs_quarterly_frequency(temp_data_dir, mock_simfin_client):
"""Test different reporting frequencies."""
real_repo = FundamentalDataRepository(temp_data_dir)
service = FundamentalDataService(
simfin_client=mock_simfin_client, repository=real_repo, data_dir=temp_data_dir
)
# Test quarterly
quarterly_context = service.get_fundamental_context(
symbol="MSFT",
start_date="2024-01-01",
end_date="2024-12-31",
frequency="quarterly",
)
assert quarterly_context.balance_sheet is not None
assert quarterly_context.balance_sheet.period == "Q3-2024"
# Test annual
annual_context = service.get_fundamental_context(
symbol="MSFT",
start_date="2024-01-01",
end_date="2024-12-31",
frequency="annual",
)
assert annual_context.balance_sheet is not None
assert annual_context.balance_sheet.period == "2024"
def test_financial_ratio_calculations(temp_data_dir, mock_simfin_client):
"""Test calculation of key financial ratios."""
real_repo = FundamentalDataRepository(temp_data_dir)
service = FundamentalDataService(
simfin_client=mock_simfin_client, repository=real_repo, data_dir=temp_data_dir
)
context = service.get_fundamental_context("TSLA", "2024-01-01", "2024-12-31")
# Check that key ratios are calculated
ratios = context.key_ratios
# Liquidity ratios
assert "current_ratio" in ratios
assert ratios["current_ratio"] > 0
# Leverage ratios
assert "debt_to_equity" in ratios
assert ratios["debt_to_equity"] >= 0
# Profitability ratios
assert "gross_margin" in ratios
assert "operating_margin" in ratios
assert "net_margin" in ratios
assert "roe" in ratios # Return on Equity
assert "roa" in ratios # Return on Assets
# Efficiency ratios
assert "asset_turnover" in ratios
# Validate ratio calculations are reasonable
assert 0 <= ratios["gross_margin"] <= 1
assert 0 <= ratios["net_margin"] <= 1
def test_offline_mode(temp_data_dir):
"""Test FundamentalDataService without a client (offline mode)."""
real_repo = FundamentalDataRepository(temp_data_dir)
service = FundamentalDataService(
simfin_client=None, repository=real_repo, data_dir=temp_data_dir
)
# Should handle offline gracefully
context = service.get_fundamental_context("AAPL", "2024-01-01", "2024-12-31")
assert context.symbol == "AAPL"
assert context.balance_sheet is None # No data available offline
assert context.income_statement is None
assert context.cash_flow is None
assert len(context.key_ratios) == 0
assert context.metadata.get("data_quality") == DataQuality.LOW
def test_partial_data_handling():
"""Test handling when only some financial statements are available."""
class PartialDataClient(MockSimFinClient):
def get_cash_flow(self, ticker, freq, curr_date):
# Simulate missing cash flow data
raise Exception("Cash flow data not available")
def get_income_statement(self, ticker, freq, curr_date):
# Simulate missing income statement
return {"data": {}, "metadata": {"error": "No data found"}}
partial_client = PartialDataClient()
service = FundamentalDataService(
simfin_client=partial_client, repository=None, online_mode=True
)
context = service.get_fundamental_context("XYZ", "2024-01-01", "2024-12-31")
# Should have balance sheet but not others
assert context.balance_sheet is not None
assert context.income_statement is None # Failed to load
assert context.cash_flow is None # Failed to load
# Ratios should be limited without full data (only balance sheet ratios available)
assert (
len(context.key_ratios) <= 8
) # Only balance sheet ratios possible, no profitability ratios
assert context.metadata.get("data_quality") == DataQuality.LOW
def test_error_handling():
"""Test error handling with broken client."""
class BrokenSimFinClient(BaseClient):
def test_connection(self):
return False
def get_data(self, *args, **kwargs):
raise Exception("SimFin API error")
def get_balance_sheet(self, *args, **kwargs):
raise Exception("SimFin API error")
def get_income_statement(self, *args, **kwargs):
raise Exception("SimFin API error")
def get_cash_flow(self, *args, **kwargs):
raise Exception("SimFin API error")
broken_client = BrokenSimFinClient()
service = FundamentalDataService(
simfin_client=broken_client, repository=None, online_mode=True
)
# Should handle errors gracefully
context = service.get_fundamental_context(
"FAIL", "2024-01-01", "2024-12-31", force_refresh=True
)
assert context.symbol == "FAIL"
assert context.balance_sheet is None
assert context.income_statement is None
assert context.cash_flow is None
assert len(context.key_ratios) == 0
assert context.metadata.get("data_quality") == DataQuality.LOW
# Service logs errors but doesn't include them in metadata
def test_json_structure():
"""Test JSON structure of fundamental context."""
mock_simfin = MockSimFinClient()
service = FundamentalDataService(
simfin_client=mock_simfin, repository=None, online_mode=True
)
context = service.get_fundamental_context("NVDA", "2024-01-01", "2024-12-31")
json_data = context.model_dump()
# Validate required fields
required_fields = [
"symbol",
"period",
"balance_sheet",
"income_statement",
"cash_flow",
"key_ratios",
"metadata",
]
for field in required_fields:
assert field in json_data
# Validate financial statement structure
if json_data["balance_sheet"]:
balance_sheet = json_data["balance_sheet"]
required_statement_fields = [
"period",
"report_date",
"publish_date",
"currency",
"data",
]
for field in required_statement_fields:
assert field in balance_sheet
# Check some key balance sheet items
bs_data = balance_sheet["data"]
assert "Total Assets" in bs_data
assert "Total Liabilities" in bs_data
assert "Total Shareholders Equity" in bs_data
# Validate key ratios
ratios = json_data["key_ratios"]
assert isinstance(ratios, dict)
assert len(ratios) > 0
# Validate metadata
metadata = json_data["metadata"]
assert "data_quality" in metadata
assert "service" in metadata
def test_comprehensive_ratio_calculation():
"""Test comprehensive financial ratio calculations."""
mock_simfin = MockSimFinClient()
service = FundamentalDataService(
simfin_client=mock_simfin, repository=None, online_mode=True
)
context = service.get_fundamental_context("COMP", "2024-01-01", "2024-12-31")
ratios = context.key_ratios
# Liquidity ratios
# Not all ratios may be calculable depending on available data
calculated_ratios = set(ratios.keys())
core_ratios = {
"current_ratio",
"debt_to_equity",
"gross_margin",
"net_margin",
"roe",
"roa",
}
# At least the core ratios should be present
assert core_ratios.issubset(calculated_ratios), (
f"Missing core ratios: {core_ratios - calculated_ratios}"
)
# All ratio values should be numbers
for ratio_name, ratio_value in ratios.items():
assert isinstance(ratio_value, int | float), (
f"{ratio_name} should be numeric, got {type(ratio_value)}"
)
assert ratio_value == ratio_value, (
f"{ratio_name} should not be NaN"
) # NaN check

View File

@@ -7,11 +7,26 @@ from dataclasses import dataclass
from enum import Enum
from typing import Any
from tradingagents.domains.marketdata.finnhub_client import FinnhubClient
from tradingagents.config import TradingAgentsConfig
from .clients.finnhub_client import FinnhubClient
logger = logging.getLogger(__name__)
class InsiderDataRepository:
"""Simple repository for insider data - placeholder implementation."""
def __init__(self, data_dir: str):
self.data_dir = data_dir
def get_data(self, symbol: str, start_date: str, end_date: str) -> dict:
return {}
def store_data(self, symbol: str, data: dict) -> bool:
return True
class DataQuality(Enum):
"""Data quality levels for insider data."""
@@ -85,6 +100,12 @@ class InsiderDataService:
self.client = client
self.repository = repository
@staticmethod
def build(_config: TradingAgentsConfig):
client = FinnhubClient("")
repo = InsiderDataRepository("")
return InsiderDataService(client, repo)
def get_insider_sentiment_context(
self,
symbol: str,

View File

@@ -3,90 +3,57 @@ Market data service that provides structured market context.
"""
import logging
from dataclasses import dataclass
from enum import Enum
from datetime import datetime
from typing import Any
import pandas as pd
import talib
from tradingagents.config import TradingAgentsConfig
from tradingagents.domains.marketdata.clients.yfinance_client import YFinanceClient
from tradingagents.domains.marketdata.models import (
INDICATOR_DEFINITIONS,
DataQuality,
IndicatorConfig,
IndicatorParamValue,
IndicatorPresets,
InputSpec,
PriceDataContext,
TAReportContext,
TechnicalAnalysisError,
TechnicalIndicatorData,
)
from tradingagents.domains.marketdata.repos.market_data_repository import (
MarketDataRepository,
)
logger = logging.getLogger(__name__)
class DataQuality(Enum):
"""Data quality levels for market data."""
HIGH = "high"
MEDIUM = "medium"
LOW = "low"
@dataclass
class TechnicalIndicatorData:
"""Technical indicator data point."""
date: str
value: float | dict[str, Any]
indicator_type: str
@dataclass
class MarketDataContext:
"""Market data context for trading analysis."""
symbol: str
period: dict[str, str] # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
price_data: list[dict[str, Any]]
technical_indicators: dict[str, list[TechnicalIndicatorData]]
metadata: dict[str, Any]
@dataclass
class TAReportContext:
"""Technical Analysis Report context for specific indicators."""
symbol: str
period: dict[str, str] # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
indicator: str
indicator_data: list[TechnicalIndicatorData]
analysis_summary: str
signal_strength: float # -1.0 to 1.0
recommendation: str # "BUY", "SELL", "HOLD"
metadata: dict[str, Any]
@dataclass
class PriceDataContext:
"""Price Data context for historical price information."""
symbol: str
period: dict[str, str] # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
price_data: list[dict[str, Any]]
latest_price: float
price_change: float
price_change_percent: float
volume_info: dict[str, Any]
metadata: dict[str, Any]
class MarketDataService:
"""Service for market data and technical indicators."""
def __init__(
self,
yfin_client: YFinanceClient,
repo: MarketDataRepository,
):
"""
Initialize market data service.
Args:
yfin_client: Client for live market data
repo: Repository for historical market data
"""
self.yfin_client = yfin_client
self.repo = repo
@staticmethod
def build(_config: TradingAgentsConfig) -> "MarketDataService":
client = YFinanceClient()
repo = MarketDataRepository("")
return MarketDataService(client, repo)
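# Usage sketch, assuming a populated TradingAgentsConfig named `config`;
# build() above wires an empty data_dir, so this only illustrates the call shape:
#     service = MarketDataService.build(config)
#     ctx = service.get_market_data_context("AAPL", "2024-01-01", "2024-01-31")
#     print(ctx.latest_price, ctx.metadata["data_quality"])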
def get_market_data_context(
self, symbol: str, start_date: str, end_date: str
) -> PriceDataContext:
@ -97,26 +64,106 @@ class MarketDataService:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
Returns:
PriceDataContext: Focused price data context
"""
try:
# Convert string dates to date objects
start_date_obj = datetime.strptime(start_date, "%Y-%m-%d").date()
end_date_obj = datetime.strptime(end_date, "%Y-%m-%d").date()
# Get data from repository first
df = self.repo.get_market_data_df(symbol, start_date_obj, end_date_obj)
source = "repository"
if df.empty:
# No data in repository, try to fetch from client
logger.info(f"No local data for {symbol}, fetching from client")
client_data = self.yfin_client.get_data(symbol, start_date, end_date)
price_data = client_data.get("data", [])
source = "client"
# Convert to DataFrame and store in repository
if price_data:
df_to_store = pd.DataFrame(price_data)
self.repo.store_marketdata(symbol, df_to_store)
df = df_to_store
else:
# Convert DataFrame to list of dictionaries
price_data = df.to_dict("records")
# Calculate metrics
latest_price = 0.0
price_change = 0.0
price_change_percent = 0.0
volume_info = {"average_volume": 0, "latest_volume": 0}
if not df.empty and "Close" in df.columns:
latest_price = float(df["Close"].iloc[-1])
if len(df) > 1:
previous_price = float(df["Close"].iloc[-2])
price_change = latest_price - previous_price
price_change_percent = (
(price_change / previous_price) * 100
if previous_price != 0
else 0.0
)
if "Volume" in df.columns:
volume_info = {
"average_volume": int(df["Volume"].mean()),
"latest_volume": int(df["Volume"].iloc[-1]),
}
# Convert DataFrame back to list of dicts for price_data
price_data = df.to_dict("records") if not df.empty else []
# Assess data quality
data_quality = DataQuality.HIGH if len(price_data) > 0 else DataQuality.LOW
metadata = {
"data_quality": data_quality.value,
"service": "market_data",
"record_count": len(price_data),
"source": "repository" if not df.empty else "client",
"retrieved_at": datetime.utcnow().isoformat(),
}
return PriceDataContext(
symbol=symbol,
period={"start": start_date, "end": end_date},
price_data=price_data,
latest_price=latest_price,
price_change=price_change,
price_change_percent=price_change_percent,
volume_info=volume_info,
metadata=metadata,
)
except Exception as e:
logger.error(f"Error getting market data context for {symbol}: {e}")
return PriceDataContext(
symbol=symbol,
period={"start": start_date, "end": end_date},
price_data=[],
latest_price=0.0,
price_change=0.0,
price_change_percent=0.0,
volume_info={"average_volume": 0, "latest_volume": 0},
metadata={
"data_quality": DataQuality.LOW.value,
"service": "market_data",
"error": str(e),
"retrieved_at": datetime.utcnow().isoformat(),
},
)
def get_ta_report_context(
self, symbol: str, indicator: str, start_date: str, end_date: str
self,
symbol: str,
indicator: str,
start_date: str,
end_date: str,
custom_params: dict[str, IndicatorParamValue] | None = None,
) -> TAReportContext:
"""
Get technical analysis report context for a specific indicator.
@ -126,24 +173,548 @@ class MarketDataService:
indicator: Technical indicator name (e.g., 'rsi', 'macd', 'sma')
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
custom_params: Optional parameter overrides for the indicator
Returns:
TAReportContext: Focused technical analysis context
"""
try:
# Get price data first
price_context = self.get_market_data_context(symbol, start_date, end_date)
if not price_context.price_data:
# Create empty indicator config for no data case
no_data_config = IndicatorConfig(
name=indicator.upper(),
parameters={},
input_types=["close"],
output_format="single",
param_ranges={},
default_params={},
talib_function="",
description="",
)
return TAReportContext(
symbol=symbol,
period={"start": start_date, "end": end_date},
indicator=indicator,
indicator_data=[],
analysis_summary="No price data available for technical analysis",
signal_strength=0.0,
recommendation="HOLD",
indicator_config=no_data_config,
parameter_summary="",
metadata={
"data_quality": DataQuality.LOW.value,
"service": "technical_analysis",
"error": "no_price_data",
},
)
# Calculate technical indicator using TA-Lib
indicator_data = self._calculate_indicator_talib(
price_context.price_data, indicator, custom_params
)
# Generate analysis and recommendations
signal_strength = self._calculate_signal_strength(indicator_data, indicator)
recommendation = self._get_recommendation(signal_strength)
analysis_summary = self._generate_analysis_summary(
indicator, signal_strength, recommendation
)
# Create indicator config from the calculation
indicator_config = IndicatorConfig(
name=indicator.upper(),
parameters=indicator_data[0].parameters if indicator_data else {},
input_types=INDICATOR_DEFINITIONS.get(indicator.upper(), {}).get(
"input_types", ["close"]
),
output_format=INDICATOR_DEFINITIONS.get(indicator.upper(), {}).get(
"output_format", "single"
),
param_ranges=INDICATOR_DEFINITIONS.get(indicator.upper(), {}).get(
"param_ranges", {}
),
default_params=INDICATOR_DEFINITIONS.get(indicator.upper(), {}).get(
"default_params", {}
),
talib_function=INDICATOR_DEFINITIONS.get(indicator.upper(), {}).get(
"talib_function", ""
),
description=INDICATOR_DEFINITIONS.get(indicator.upper(), {}).get(
"description", ""
),
)
# Generate parameter summary
params = indicator_data[0].parameters if indicator_data else {}
parameter_summary = ", ".join([f"{k}={v}" for k, v in params.items()])
return TAReportContext(
symbol=symbol,
period={"start": start_date, "end": end_date},
indicator=indicator,
indicator_data=indicator_data,
analysis_summary=analysis_summary,
signal_strength=signal_strength,
recommendation=recommendation,
indicator_config=indicator_config,
parameter_summary=parameter_summary,
metadata={
"data_quality": DataQuality.HIGH.value,
"service": "technical_analysis",
"indicator_count": len(indicator_data),
"retrieved_at": datetime.utcnow().isoformat(),
},
)
except Exception as e:
logger.error(f"Error getting TA report for {symbol} {indicator}: {e}")
# Create empty indicator config for error case
error_config = IndicatorConfig(
name=indicator.upper(),
parameters={},
input_types=["close"],
output_format="single",
param_ranges={},
default_params={},
talib_function="",
description="",
)
return TAReportContext(
symbol=symbol,
period={"start": start_date, "end": end_date},
indicator=indicator,
indicator_data=[],
analysis_summary=f"Error calculating {indicator}: {str(e)}",
signal_strength=0.0,
recommendation="HOLD",
indicator_config=error_config,
parameter_summary="",
metadata={
"data_quality": DataQuality.LOW.value,
"service": "technical_analysis",
"error": str(e),
},
)
def _validate_parameters(
self, indicator: str, params: dict[str, IndicatorParamValue]
) -> None:
"""Validate indicator parameters against defined ranges."""
if indicator.upper() not in INDICATOR_DEFINITIONS:
raise TechnicalAnalysisError(f"Unknown indicator: {indicator}")
definition = INDICATOR_DEFINITIONS[indicator.upper()]
param_ranges = definition.get("param_ranges", {})
for param_name, value in params.items():
if param_name in param_ranges:
min_val, max_val = param_ranges[param_name]
if not isinstance(value, int | float):
raise TechnicalAnalysisError(
f"Parameter {param_name} must be numeric"
)
if not (min_val <= value <= max_val):
raise TechnicalAnalysisError(
f"Parameter {param_name}={value} out of range [{min_val}, {max_val}]"
)
def _prepare_price_arrays(
self, price_data: list[dict[str, Any]], input_types: list[InputSpec]
) -> dict[str, Any]:
"""Prepare price arrays for TA-Lib functions."""
if not price_data:
raise TechnicalAnalysisError("No price data provided")
df = pd.DataFrame(price_data)
required_columns = []
for input_type in input_types:
if input_type == "close":
required_columns.extend(["Close"])
elif input_type == "ohlc":
required_columns.extend(["Open", "High", "Low", "Close"])
elif input_type == "ohlcv":
required_columns.extend(["Open", "High", "Low", "Close", "Volume"])
elif input_type == "hl":
required_columns.extend(["High", "Low"])
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
raise TechnicalAnalysisError(f"Missing required columns: {missing_columns}")
# Convert to numpy arrays for TA-Lib
arrays = {}
if "Open" in df.columns:
arrays["open"] = df["Open"].astype(float).values
if "High" in df.columns:
arrays["high"] = df["High"].astype(float).values
if "Low" in df.columns:
arrays["low"] = df["Low"].astype(float).values
if "Close" in df.columns:
arrays["close"] = df["Close"].astype(float).values
if "Volume" in df.columns:
arrays["volume"] = df["Volume"].astype(float).values
arrays["dates"] = df["Date"].astype(str).values
return arrays
def _calculate_indicator_talib(
self,
price_data: list[dict[str, Any]],
indicator: str,
params: dict[str, IndicatorParamValue] | None = None,
) -> list[TechnicalIndicatorData]:
"""Calculate technical indicator using TA-Lib."""
if not price_data:
return []
# Get indicator definition
indicator_upper = indicator.upper()
if indicator_upper not in INDICATOR_DEFINITIONS:
raise TechnicalAnalysisError(f"Unknown indicator: {indicator}")
definition = INDICATOR_DEFINITIONS[indicator_upper]
# Use provided params or defaults
final_params: dict[str, IndicatorParamValue]
if params is None:
final_params = definition["default_params"].copy()
else:
# Merge with defaults for missing parameters
final_params = definition["default_params"].copy()
final_params.update(params)
# Validate parameters
self._validate_parameters(indicator, final_params)
# Prepare price arrays
arrays = self._prepare_price_arrays(price_data, definition["input_types"])
# Get TA-Lib function
talib_func_name = definition["talib_function"].split(".")[
-1
] # Extract function name
talib_func = getattr(talib, talib_func_name)
# Prepare function arguments
func_args = []
func_kwargs = {}
# Add required price arrays based on input types
for input_type in definition["input_types"]:
if input_type == "close":
func_args.append(arrays["close"])
elif input_type == "ohlc":
func_args.extend([arrays["high"], arrays["low"], arrays["close"]])
elif input_type == "ohlcv":
func_args.extend(
[arrays["high"], arrays["low"], arrays["close"], arrays["volume"]]
)
elif input_type == "hl":
func_args.extend([arrays["high"], arrays["low"]])
# Add parameters as keyword arguments
for param_name, param_value in final_params.items():
func_kwargs[param_name] = param_value
# Calculate indicator
try:
ta_result = talib_func(*func_args, **func_kwargs)
except Exception as e:
raise TechnicalAnalysisError(
f"TA-Lib calculation failed for {indicator}: {str(e)}"
) from e
# Process results based on output format
result = []
dates = arrays["dates"]
output_format = definition["output_format"]
if output_format == "single":
# Single output array
for date, value in zip(dates, ta_result, strict=False):
if not pd.isna(value):
result.append(
TechnicalIndicatorData(
date=date,
value=float(value),
indicator_type=indicator.lower(),
parameters=final_params,
)
)
elif output_format == "double":
# Two output arrays (e.g., STOCH, AROON)
for date, val1, val2 in zip(dates, ta_result[0], ta_result[1], strict=False):
if not pd.isna(val1) and not pd.isna(val2):
# Name outputs based on indicator
if indicator_upper == "STOCH":
value_dict = {"slowk": float(val1), "slowd": float(val2)}
elif indicator_upper == "AROON":
value_dict = {"aroondown": float(val1), "aroonup": float(val2)}
else:
value_dict = {"output1": float(val1), "output2": float(val2)}
result.append(
TechnicalIndicatorData(
date=date,
value=value_dict,
indicator_type=indicator.lower(),
parameters=final_params,
)
)
elif output_format == "triple":
# Three output arrays (e.g., MACD, BBANDS)
for date, val1, val2, val3 in zip(
dates, ta_result[0], ta_result[1], ta_result[2], strict=False
):
if not pd.isna(val1):
# Name outputs based on indicator
if indicator_upper == "MACD":
value_dict = {
"macd": float(val1),
"signal": float(val2) if not pd.isna(val2) else 0.0,
"histogram": float(val3) if not pd.isna(val3) else 0.0,
}
elif indicator_upper == "BBANDS":
value_dict = {
"upper": float(val1),
"middle": float(val2) if not pd.isna(val2) else 0.0,
"lower": float(val3) if not pd.isna(val3) else 0.0,
}
else:
value_dict = {
"output1": float(val1),
"output2": float(val2) if not pd.isna(val2) else 0.0,
"output3": float(val3) if not pd.isna(val3) else 0.0,
}
result.append(
TechnicalIndicatorData(
date=date,
value=value_dict,
indicator_type=indicator.lower(),
parameters=final_params,
)
)
return result
def calculate_indicator(
self,
symbol: str,
start_date: str,
end_date: str,
indicator: str | dict[str, IndicatorParamValue],
params: dict[str, IndicatorParamValue] | None = None,
) -> TAReportContext:
"""
Three-tier API for technical indicator calculation.
Usage:
1. String: calculate_indicator("AAPL", "2024-01-01", "2024-01-31", "RSI")
2. Preset: calculate_indicator("AAPL", "2024-01-01", "2024-01-31", "RSI_SCALPING")
3. Custom: calculate_indicator("AAPL", "2024-01-01", "2024-01-31", "RSI", {"timeperiod": 21})
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
indicator: Indicator name (string), preset name, or custom config dict
params: Optional custom parameters (for string indicators only)
Returns:
TAReportContext: Complete technical analysis context
"""
if isinstance(indicator, dict):
# Custom configuration provided as dict
if "name" not in indicator:
raise TechnicalAnalysisError(
"Custom indicator dict must contain 'name' field"
)
indicator_name = str(indicator["name"]) # Ensure it's a string
custom_params = {k: v for k, v in indicator.items() if k != "name"}
return self.get_ta_report_context(
symbol, indicator_name, start_date, end_date, custom_params
)
# Check if it's a preset name
all_presets = IndicatorPresets.get_all_presets()
if indicator in all_presets:
# Extract indicator name and parameters from preset
preset_params = all_presets[indicator]
# Determine base indicator from preset name
for base_indicator in INDICATOR_DEFINITIONS:
if indicator.startswith(base_indicator):
return self.get_ta_report_context(
symbol,
base_indicator.lower(),
start_date,
end_date,
preset_params,
)
# If no match found, try to extract from preset name
indicator_name = indicator.split("_")[0].lower()
return self.get_ta_report_context(
symbol, indicator_name, start_date, end_date, preset_params
)
# Regular indicator name (string)
return self.get_ta_report_context(
symbol, indicator, start_date, end_date, params
)
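# Usage sketch for the three-tier API (symbol and dates are illustrative):
#     service.calculate_indicator("AAPL", "2024-01-01", "2024-01-31", "RSI")
#     service.calculate_indicator("AAPL", "2024-01-01", "2024-01-31", "RSI_SCALPING")
#     service.calculate_indicator(
#         "AAPL", "2024-01-01", "2024-01-31", "RSI", {"timeperiod": 21}
#     )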
def get_available_indicators(self) -> dict[str, str]:
"""Get list of all available indicators with descriptions."""
return {
name: info["description"] for name, info in INDICATOR_DEFINITIONS.items()
}
def get_available_presets(
self, style: str | None = None
) -> dict[str, dict[str, IndicatorParamValue]]:
"""
Get available indicator presets.
Args:
style: Optional trading style filter ("scalping", "day_trading", "swing", "position")
Returns:
Dict of preset names to parameter configurations
"""
if style:
return IndicatorPresets.get_preset_for_style(style)
return IndicatorPresets.get_all_presets()
def get_indicator_info(self, indicator: str) -> IndicatorConfig:
"""
Get detailed information about a specific indicator.
Args:
indicator: Indicator name
Returns:
IndicatorConfig with full indicator specifications
"""
indicator_upper = indicator.upper()
if indicator_upper not in INDICATOR_DEFINITIONS:
raise TechnicalAnalysisError(f"Unknown indicator: {indicator}")
definition = INDICATOR_DEFINITIONS[indicator_upper]
return IndicatorConfig(
name=indicator_upper,
parameters=definition["default_params"],
input_types=definition["input_types"],
output_format=definition["output_format"],
param_ranges=definition["param_ranges"],
default_params=definition["default_params"],
talib_function=definition["talib_function"],
description=definition["description"],
)
def _calculate_signal_strength(
self, indicator_data: list[TechnicalIndicatorData], indicator: str
) -> float:
"""Calculate signal strength from indicator data."""
if not indicator_data:
return 0.0
latest = indicator_data[-1]
if indicator.lower() == "rsi":
rsi_value = latest.value
if isinstance(rsi_value, int | float):
if rsi_value > 70:
return -0.8 # Overbought - sell signal
elif rsi_value < 30:
return 0.8 # Oversold - buy signal
else:
return (50 - rsi_value) / 50 # Normalized between -1 and 1
elif indicator.lower() == "macd":
if isinstance(latest.value, dict):
macd_val = latest.value.get("macd", 0)
signal_val = latest.value.get("signal", 0)
if macd_val > signal_val:
return 0.6 # Bullish
else:
return -0.6 # Bearish
elif indicator.lower() == "sma":
# Would need current price to compare with SMA
return 0.0 # Neutral for now
return 0.0
def _get_recommendation(self, signal_strength: float) -> str:
"""Convert signal strength to recommendation."""
if signal_strength > 0.5:
return "BUY"
elif signal_strength < -0.5:
return "SELL"
else:
return "HOLD"
def _generate_analysis_summary(
self, indicator: str, signal_strength: float, recommendation: str
) -> str:
"""Generate human-readable analysis summary."""
strength_desc = (
"strong"
if abs(signal_strength) > 0.7
else "moderate"
if abs(signal_strength) > 0.3
else "weak"
)
direction = (
"bullish"
if signal_strength > 0
else "bearish"
if signal_strength < 0
else "neutral"
)
return f"{indicator.upper()} indicator shows {strength_desc} {direction} signal. Signal strength: {signal_strength:.2f}. Recommendation: {recommendation}."
def update_market_data(self, symbol: str, start_date: str, end_date: str) -> None:
"""
Update market data by fetching fresh data from client and storing in repository.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
"""
try:
logger.info(
f"Updating market data for {symbol} from {start_date} to {end_date}"
)
# Fetch fresh data from client
client_data = self.yfin_client.get_data(symbol, start_date, end_date)
price_data = client_data.get("data", [])
if price_data:
# Convert to DataFrame
df = pd.DataFrame(price_data)
# Store in repository
self.repo.store_marketdata(symbol, df)
logger.info(
f"Successfully stored {len(price_data)} records for {symbol}"
)
else:
logger.warning(f"No data received for {symbol}")
except Exception as e:
logger.error(f"Error updating market data for {symbol}: {e}")
raise

View File

@ -1,546 +0,0 @@
#!/usr/bin/env python3
"""
Test MarketDataService with mock YFinanceClient and real MarketDataRepository.
"""
import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any
# Add the project root to the path
sys.path.insert(0, os.path.abspath("."))
from tradingagents.clients.base import BaseClient
from tradingagents.domains.marketdata.market_data_service import (
DataQuality,
MarketDataContext,
MarketDataService,
)
from tradingagents.repositories.market_data_repository import MarketDataRepository
class MockYFinanceClient(BaseClient):
"""Mock Yahoo Finance client that returns predictable test data."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connection_works = True
def test_connection(self) -> bool:
return self.connection_works
def get_data(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
"""Return realistic mock market data."""
# Generate realistic price data
base_price = {"AAPL": 180.0, "TSLA": 250.0, "MSFT": 400.0}.get(symbol, 100.0)
mock_data = []
current_date = datetime.strptime(start_date, "%Y-%m-%d")
end_date_dt = datetime.strptime(end_date, "%Y-%m-%d")
price = base_price
while current_date <= end_date_dt:
# Simulate some price movement
price_change = (
hash(current_date.strftime("%Y-%m-%d")) % 10 - 5
) / 100 # -5% to +5%
price *= 1 + price_change * 0.01
mock_data.append(
{
"Date": current_date.strftime("%Y-%m-%d %H:%M:%S"),
"Open": round(price * 0.99, 2),
"High": round(price * 1.02, 2),
"Low": round(price * 0.98, 2),
"Close": round(price, 2),
"Adj Close": round(price, 2),
"Volume": 45000000 + (hash(symbol) % 20000000),
}
)
current_date += timedelta(days=1)
return {
"symbol": symbol,
"period": {"start": start_date, "end": end_date},
"data": mock_data,
"metadata": {
"source": "mock_yahoo_finance",
"record_count": len(mock_data),
"columns": [
"Date",
"Open",
"High",
"Low",
"Close",
"Adj Close",
"Volume",
],
"retrieved_at": datetime.utcnow().isoformat(),
},
}
def test_online_mode_with_mock_client():
"""Test MarketDataService in online mode with mock client."""
print("📈 Testing MarketDataService - Online Mode")
# Create mock client and real repository
mock_client = MockYFinanceClient()
real_repo = MarketDataRepository("test_data")
# Create service in online mode
service = MarketDataService(
client=mock_client, repository=real_repo, online_mode=True, data_dir="test_data"
)
try:
# Test basic price context
context = service.get_price_context(
symbol="AAPL", start_date="2024-01-01", end_date="2024-01-05"
)
print(f"✅ Price context created: {context.__class__.__name__}")
print(f" Symbol: {context.symbol}")
print(f" Period: {context.period}")
print(f" Price data records: {len(context.price_data)}")
print(f" Technical indicators: {len(context.technical_indicators)}")
# Validate required fields
assert context.symbol == "AAPL"
assert context.period["start"] == "2024-01-01"
assert context.period["end"] == "2024-01-05"
assert len(context.price_data) > 0
assert "data_quality" in context.metadata
print("✅ Basic validation passed")
# Test JSON serialization
json_output = context.model_dump_json(indent=2)
parsed = json.loads(json_output)
print(f"✅ JSON serialization: {len(json_output)} characters")
print(f" Top-level keys: {list(parsed.keys())}")
# Test with technical indicators
context_with_indicators = service.get_context(
symbol="TSLA",
start_date="2024-01-01",
end_date="2024-01-03",
indicators=["rsi", "macd"],
)
print("✅ Context with indicators created")
print(" Requested indicators: ['rsi', 'macd']")
print(
f" Available indicators: {list(context_with_indicators.technical_indicators.keys())}"
)
return True
except Exception as e:
print(f"❌ Online mode test failed: {e}")
return False
def test_offline_mode_with_real_repository():
"""Test MarketDataService in offline mode with real repository."""
print("\n💾 Testing MarketDataService - Offline Mode")
# Create service in offline mode (no client)
real_repo = MarketDataRepository("test_data")
service = MarketDataService(
client=None, repository=real_repo, online_mode=False, data_dir="test_data"
)
try:
# Test offline context (will likely return empty data)
context = service.get_price_context(
symbol="AAPL", start_date="2024-01-01", end_date="2024-01-05"
)
print(f"✅ Offline context created: {context.__class__.__name__}")
print(f" Symbol: {context.symbol}")
print(f" Price data records: {len(context.price_data)}")
print(f" Data quality: {context.metadata.get('data_quality')}")
print(f" Service mode: online={service.is_online()}")
# Should handle empty data gracefully
assert context.symbol == "AAPL"
assert isinstance(context.price_data, list)
assert "data_quality" in context.metadata
print("✅ Offline mode graceful handling verified")
return True
except Exception as e:
print(f"❌ Offline mode test failed: {e}")
return False
def test_error_handling():
"""Test error handling scenarios."""
print("\n⚠️ Testing Error Handling")
# Test with broken client
class BrokenClient(BaseClient):
def test_connection(self):
return False
def get_data(self, *args, **kwargs):
raise Exception("Simulated client failure")
broken_client = BrokenClient()
real_repo = MarketDataRepository("test_data")
service = MarketDataService(
client=broken_client,
repository=real_repo,
online_mode=True, # Online mode but client will fail
data_dir="test_data",
)
try:
context = service.get_price_context("AAPL", "2024-01-01", "2024-01-05")
print("✅ Error handling worked")
print(f" Symbol: {context.symbol}")
print(f" Price data records: {len(context.price_data)}")
print(f" Data quality: {context.metadata.get('data_quality')}")
# Should fallback to repository or return empty data
assert context.symbol == "AAPL"
assert isinstance(context.price_data, list)
return True
except Exception as e:
print(f"❌ Error handling test failed: {e}")
return False
def test_data_quality_assessment():
"""Test data quality determination logic."""
print("\n🔍 Testing Data Quality Assessment")
mock_client = MockYFinanceClient()
real_repo = MarketDataRepository("test_data")
service = MarketDataService(
client=mock_client, repository=real_repo, online_mode=True, data_dir="test_data"
)
try:
# Test with good data
context = service.get_context("AAPL", "2024-01-01", "2024-01-10")
data_quality = context.metadata.get("data_quality")
print(f"✅ Data quality assessment: {data_quality}")
print(f" Records: {len(context.price_data)}")
print(f" Online mode: {service.is_online()}")
# Should be medium or high quality for mock data
assert data_quality in [DataQuality.MEDIUM, DataQuality.HIGH]
return True
except Exception as e:
print(f"❌ Data quality test failed: {e}")
return False
def test_json_structure_validation():
"""Test detailed JSON structure validation."""
print("\n📄 Testing JSON Structure")
mock_client = MockYFinanceClient()
service = MarketDataService(client=mock_client, repository=None, online_mode=True)
try:
context = service.get_price_context("MSFT", "2024-01-01", "2024-01-03")
json_str = context.model_dump_json(indent=2)
data = json.loads(json_str)
# Validate required structure
required_fields = [
"symbol",
"period",
"price_data",
"technical_indicators",
"metadata",
]
for field in required_fields:
assert field in data, f"Missing field: {field}"
# Validate period structure
period = data["period"]
assert "start" in period and "end" in period
# Validate price data structure
assert isinstance(data["price_data"], list)
if data["price_data"]:
first_record = data["price_data"][0]
required_price_fields = ["Date", "Open", "High", "Low", "Close", "Volume"]
for field in required_price_fields:
assert field in first_record, f"Missing price field: {field}"
# Validate metadata
metadata = data["metadata"]
assert "data_quality" in metadata
assert "service" in metadata
print("✅ JSON structure validation passed")
print(f" Fields: {list(data.keys())}")
print(f" Price records: {len(data['price_data'])}")
print(f" Metadata keys: {list(metadata.keys())}")
return True
except Exception as e:
print(f"❌ JSON structure test failed: {e}")
return False
def test_force_refresh_parameter():
"""Test the force_refresh parameter functionality."""
try:
mock_client = MockYFinanceClient()
real_repo = MarketDataRepository("test_data")
service = MarketDataService(
client=mock_client, repository=real_repo, online_mode=True
)
# Test normal flow (should use repository if available)
normal_context = service.get_context(
"AAPL", "2024-01-01", "2024-01-31", force_refresh=False
)
# Test force refresh (should bypass repository and use client)
refresh_context = service.get_context(
"AAPL", "2024-01-01", "2024-01-31", force_refresh=True
)
# Both should return valid contexts
assert isinstance(normal_context, MarketDataContext)
assert isinstance(refresh_context, MarketDataContext)
assert normal_context.symbol == "AAPL"
assert refresh_context.symbol == "AAPL"
# Check metadata indicates source
refresh_metadata = refresh_context.metadata
assert "force_refresh" in refresh_metadata
assert refresh_metadata["force_refresh"]
print("✅ Force refresh parameter test passed")
return True
except Exception as e:
print(f"❌ Force refresh test failed: {e}")
return False
def test_local_first_strategy():
"""Test that the service checks local data first when available."""
try:
class MockRepositoryWithData(MarketDataRepository):
def has_data_for_period(
self, identifier: str, start_date: str, end_date: str, **kwargs
) -> bool:
return True # Pretend we have the data
def get_data(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
return {
"symbol": kwargs.get("symbol", "TEST"),
"data": [
{"date": "2024-01-01", "close": 150.0},
{"date": "2024-01-02", "close": 151.0},
],
"metadata": {"source": "test_repository"},
}
mock_client = MockYFinanceClient()
mock_repo = MockRepositoryWithData("test_data")
service = MarketDataService(
client=mock_client, repository=mock_repo, online_mode=True
)
# Should use local data since repository has_data_for_period returns True
context = service.get_context("TEST", "2024-01-01", "2024-01-31")
# Verify we used local data
assert context.metadata.get("price_data_source") == "local_cache"
assert len(context.price_data) == 2 # From mock repository
print("✅ Local-first strategy test passed")
return True
except Exception as e:
print(f"❌ Local-first strategy test failed: {e}")
return False
def test_local_first_fallback_to_api():
"""Test that service falls back to API when local data is insufficient."""
try:
class MockRepositoryWithoutData(MarketDataRepository):
def has_data_for_period(
self, identifier: str, start_date: str, end_date: str, **kwargs
) -> bool:
return False # Pretend we don't have the data
def get_data(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
return {
"symbol": kwargs.get("symbol", "TEST"),
"data": [],
"metadata": {},
}
def store_data(
self,
symbol: str,
data: dict[str, Any],
overwrite: bool = False,
**kwargs,
) -> bool:
return True # Pretend storage was successful
mock_client = MockYFinanceClient()
mock_repo = MockRepositoryWithoutData("test_data")
service = MarketDataService(
client=mock_client, repository=mock_repo, online_mode=True
)
# Should fall back to API since repository doesn't have data
context = service.get_context("TEST", "2024-01-01", "2024-01-31")
# Verify we used API data
assert context.metadata.get("price_data_source") == "live_api"
assert len(context.price_data) > 0 # From mock client
print("✅ Local-first fallback to API test passed")
return True
except Exception as e:
print(f"❌ Local-first fallback test failed: {e}")
return False
def test_force_refresh_bypasses_local_data():
"""Test that force_refresh=True bypasses local data even when available."""
try:
class MockRepositoryAlwaysHasData(MarketDataRepository):
def has_data_for_period(
self, identifier: str, start_date: str, end_date: str, **kwargs
) -> bool:
return True # Always claim we have data
def get_data(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
return {
"symbol": kwargs.get("symbol", "TEST"),
"data": [
{"date": "2024-01-01", "close": 100.0}
], # Different from client
"metadata": {"source": "local"},
}
def clear_data(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> bool:
return True
def store_data(
self,
symbol: str,
data: dict[str, Any],
overwrite: bool = False,
**kwargs,
) -> bool:
return True
mock_client = MockYFinanceClient()
mock_repo = MockRepositoryAlwaysHasData("test_data")
service = MarketDataService(
client=mock_client, repository=mock_repo, online_mode=True
)
# Force refresh should bypass local data
context = service.get_context(
"TEST", "2024-01-01", "2024-01-31", force_refresh=True
)
# Verify we used API data (force refresh)
assert context.metadata.get("price_data_source") == "live_api_refresh"
assert context.metadata.get("force_refresh")
# Should have more data from client than the single point from repository
assert len(context.price_data) > 1
print("✅ Force refresh bypasses local data test passed")
return True
except Exception as e:
print(f"❌ Force refresh bypass test failed: {e}")
return False
def main():
"""Run all MarketDataService tests."""
print("🧪 Testing MarketDataService\n")
tests = [
test_online_mode_with_mock_client,
test_offline_mode_with_real_repository,
test_error_handling,
test_data_quality_assessment,
test_json_structure_validation,
test_force_refresh_parameter,
test_local_first_strategy,
test_local_first_fallback_to_api,
test_force_refresh_bypasses_local_data,
]
passed = 0
failed = 0
for test in tests:
try:
if test():
passed += 1
else:
failed += 1
except Exception as e:
print(f"❌ Test {test.__name__} crashed: {e}")
failed += 1
print("\n📊 MarketDataService Test Results:")
print(f" ✅ Passed: {passed}")
print(f" ❌ Failed: {failed}")
if failed == 0:
print("🎉 All MarketDataService tests passed!")
else:
print("⚠️ Some tests failed - check output above")
return failed == 0
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

View File

@ -0,0 +1,662 @@
"""
Market data models and type definitions for technical analysis.
"""
from dataclasses import dataclass
from datetime import date
from enum import Enum
from typing import Any, Literal
from pydantic import BaseModel
# Proper type definitions to eliminate Any types
IndicatorParamValue = int | float | str | bool
InputSpec = Literal["close", "ohlc", "ohlcv", "hl"]
OutputSpec = Literal["single", "double", "triple"]
ParamRanges = dict[str, tuple[int | float, int | float]]
class IndicatorConfig(BaseModel):
"""Configuration for technical indicators with proper typing."""
name: str
parameters: dict[str, IndicatorParamValue] # No more Any
input_types: list[InputSpec] # Specific requirements
output_format: OutputSpec # Precise specification
param_ranges: ParamRanges # Type-safe validation
default_params: dict[str, IndicatorParamValue]
talib_function: str # Direct TA-Lib function name
description: str
class TechnicalAnalysisError(Exception):
"""Clear, actionable error messages for TA-Lib issues."""
pass
class DataQuality(Enum):
"""Data quality levels for market data."""
HIGH = "high"
MEDIUM = "medium"
LOW = "low"
class TechnicalIndicatorData(BaseModel):
"""Technical indicator data point with proper typing."""
date: str
value: float | dict[str, float] # No more Any
indicator_type: str
parameters: dict[str, IndicatorParamValue] # Parameter context
confidence: float = 0.0 # Signal confidence
source: str = "talib" # Always TA-Lib
class MarketDataContext(BaseModel):
"""Market data context for trading analysis."""
symbol: str
period: dict[str, str] # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
price_data: list[dict[str, Any]]
technical_indicators: dict[str, list[TechnicalIndicatorData]]
metadata: dict[str, Any]
class TAReportContext(BaseModel):
"""Technical Analysis Report context with enhanced configuration."""
symbol: str
period: dict[str, str] # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
indicator: str
indicator_data: list[TechnicalIndicatorData]
analysis_summary: str
signal_strength: float # -1.0 to 1.0
recommendation: str # "BUY", "SELL", "HOLD"
indicator_config: IndicatorConfig # Full config used
parameter_summary: str # Human-readable params
metadata: dict[str, IndicatorParamValue] # Properly typed
class PriceDataContext(BaseModel):
"""Price Data context for historical price information."""
symbol: str
period: dict[str, str] # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
price_data: list[dict[str, Any]]
latest_price: float
price_change: float
price_change_percent: float
volume_info: dict[str, Any]
metadata: dict[str, Any]
# Fundamental Data Models
class FinancialRatio(BaseModel):
"""Financial ratio with calculation metadata."""
name: str
value: float | None
formula: str
category: str # "profitability", "liquidity", "leverage", "efficiency"
interpretation: str
class BalanceSheetData(BaseModel):
"""Balance sheet line items."""
date: str
total_assets: float | None = None
current_assets: float | None = None
cash_and_equivalents: float | None = None
accounts_receivable: float | None = None
inventory: float | None = None
total_liabilities: float | None = None
current_liabilities: float | None = None
accounts_payable: float | None = None
short_term_debt: float | None = None
long_term_debt: float | None = None
total_equity: float | None = None
retained_earnings: float | None = None
class IncomeStatementData(BaseModel):
"""Income statement line items."""
date: str
revenue: float | None = None
gross_profit: float | None = None
operating_income: float | None = None
net_income: float | None = None
ebitda: float | None = None
cost_of_revenue: float | None = None
operating_expenses: float | None = None
interest_expense: float | None = None
tax_expense: float | None = None
shares_outstanding: float | None = None
eps: float | None = None
class CashFlowData(BaseModel):
"""Cash flow statement line items."""
date: str
operating_cash_flow: float | None = None
investing_cash_flow: float | None = None
financing_cash_flow: float | None = None
free_cash_flow: float | None = None
capital_expenditures: float | None = None
dividends_paid: float | None = None
stock_repurchases: float | None = None
class BalanceSheetContext(BaseModel):
"""Balance sheet context for fundamental analysis."""
symbol: str
start_date: date
end_date: date
balance_sheet_data: list[BalanceSheetData]
key_ratios: list[FinancialRatio]
data_quality: DataQuality
source: str
metadata: dict[str, str]
class IncomeStatementContext(BaseModel):
"""Income statement context for fundamental analysis."""
symbol: str
start_date: date
end_date: date
income_statement_data: list[IncomeStatementData]
key_ratios: list[FinancialRatio]
data_quality: DataQuality
source: str
metadata: dict[str, Any]
class CashFlowContext(BaseModel):
"""Cash flow context for fundamental analysis."""
symbol: str
start_date: date
end_date: date
cash_flow_data: list[CashFlowData]
key_ratios: list[FinancialRatio]
data_quality: DataQuality
source: str
metadata: dict[str, Any]
class FundamentalContext(BaseModel):
"""Comprehensive fundamental analysis context."""
symbol: str
start_date: date
end_date: date
balance_sheet: BalanceSheetContext | None = None
income_statement: IncomeStatementContext | None = None
cash_flow: CashFlowContext | None = None
comprehensive_ratios: list[FinancialRatio] | None = None
valuation_metrics: dict[str, float | None] | None = None
financial_health_score: float = 0.0  # 0-100 composite score
data_quality: DataQuality = DataQuality.LOW
source: str | None = None
metadata: dict[str, str] | None = None
# Reported Financials Models (for Finnhub API responses)
@dataclass
class FinancialLineItem:
"""Individual financial statement line item from reported financials."""
concept: str
unit: str
label: str
value: float | int
@dataclass
class ReportedFinancialsData:
"""Financial statements data containing balance sheet, income statement, and cash flow."""
bs: list[FinancialLineItem] # Balance Sheet
ic: list[FinancialLineItem] # Income Statement
cf: list[FinancialLineItem] # Cash Flow Statement
@dataclass
class ReportedFinancialsResponse:
"""Complete response from Finnhub reported financials API."""
start_date: str
end_date: str
year: int
quarter: int
access_number: str
data: ReportedFinancialsData
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "ReportedFinancialsResponse":
"""Create ReportedFinancialsResponse from API response dictionary."""
financial_data = data.get("data", {})
# Convert line items for each statement type
bs_items = [
FinancialLineItem(
concept=item["concept"],
unit=item["unit"],
label=item["label"],
value=item["value"],
)
for item in financial_data.get("bs", [])
]
ic_items = [
FinancialLineItem(
concept=item["concept"],
unit=item["unit"],
label=item["label"],
value=item["value"],
)
for item in financial_data.get("ic", [])
]
cf_items = [
FinancialLineItem(
concept=item["concept"],
unit=item["unit"],
label=item["label"],
value=item["value"],
)
for item in financial_data.get("cf", [])
]
return cls(
start_date=data["start_date"],
end_date=data["end_date"],
year=data["year"],
quarter=data["quarter"],
access_number=data["access_number"],
data=ReportedFinancialsData(bs=bs_items, ic=ic_items, cf=cf_items),
)
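# Parsing sketch; the payload shape mirrors from_dict above and the values
# are made up:
#     payload = {
#         "start_date": "2024-01-01", "end_date": "2024-03-31",
#         "year": 2024, "quarter": 1, "access_number": "0001-24",
#         "data": {"bs": [{"concept": "us-gaap_Assets", "unit": "usd",
#                          "label": "Total assets", "value": 1.0}],
#                  "ic": [], "cf": []},
#     }
#     report = ReportedFinancialsResponse.from_dict(payload)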
# Insider Transactions Models
@dataclass
class InsiderTransaction:
"""Individual insider transaction record."""
name: str
share: int
change: int
filing_date: str
transaction_date: str
transaction_code: str
transaction_price: float
@dataclass
class InsiderTransactionsResponse:
"""Complete response from Finnhub insider transactions API."""
data: list[InsiderTransaction]
symbol: str
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "InsiderTransactionsResponse":
"""Create InsiderTransactionsResponse from API response dictionary."""
transactions = [
InsiderTransaction(
name=item["name"],
share=item["share"],
change=item["change"],
filing_date=item["filingDate"],
transaction_date=item["transactionDate"],
transaction_code=item["transactionCode"],
transaction_price=item["transactionPrice"],
)
for item in data.get("data", [])
]
return cls(data=transactions, symbol=data["symbol"])
# Insider Sentiment Models
@dataclass
class InsiderSentimentData:
"""Individual insider sentiment data point."""
symbol: str
year: int
month: int
change: int
mspr: float # Monthly Share Purchase Ratio
@dataclass
class InsiderSentimentResponse:
"""Complete response from Finnhub insider sentiment API."""
data: list[InsiderSentimentData]
symbol: str
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "InsiderSentimentResponse":
"""Create InsiderSentimentResponse from API response dictionary."""
sentiment_data = [
InsiderSentimentData(
symbol=item["symbol"],
year=item["year"],
month=item["month"],
change=item["change"],
mspr=item["mspr"],
)
for item in data.get("data", [])
]
return cls(data=sentiment_data, symbol=data["symbol"])
# Company Profile Models
@dataclass
class CompanyProfile:
"""Company profile information from Finnhub."""
country: str
currency: str
exchange: str
ipo: str
market_capitalization: float
name: str
phone: str
share_outstanding: float
ticker: str
weburl: str
logo: str
finnhub_industry: str
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "CompanyProfile":
"""Create CompanyProfile from API response dictionary."""
return cls(
country=data.get("country", ""),
currency=data.get("currency", ""),
exchange=data.get("exchange", ""),
ipo=data.get("ipo", ""),
market_capitalization=data.get("marketCapitalization", 0.0),
name=data.get("name", ""),
phone=data.get("phone", ""),
share_outstanding=data.get("shareOutstanding", 0.0),
ticker=data.get("ticker", ""),
weburl=data.get("weburl", ""),
logo=data.get("logo", ""),
finnhub_industry=data.get("finnhubIndustry", ""),
)
# Complete indicator definitions for 20 professional indicators
INDICATOR_DEFINITIONS = {
# Momentum Indicators (7)
"RSI": {
"talib_function": "talib.RSI",
"input_types": ["close"],
"default_params": {"timeperiod": 14},
"param_ranges": {"timeperiod": (2, 100)},
"output_format": "single",
"description": "Relative Strength Index",
},
"MACD": {
"talib_function": "talib.MACD",
"input_types": ["close"],
"default_params": {"fastperiod": 12, "slowperiod": 26, "signalperiod": 9},
"param_ranges": {
"fastperiod": (2, 50),
"slowperiod": (10, 200),
"signalperiod": (2, 50),
},
"output_format": "triple",
"description": "Moving Average Convergence Divergence",
},
"STOCH": {
"talib_function": "talib.STOCH",
"input_types": ["ohlc"],
"default_params": {"fastk_period": 14, "slowk_period": 3, "slowd_period": 3},
"param_ranges": {
"fastk_period": (1, 100),
"slowk_period": (1, 50),
"slowd_period": (1, 50),
},
"output_format": "double",
"description": "Stochastic Oscillator",
},
"WILLR": {
"talib_function": "talib.WILLR",
"input_types": ["ohlc"],
"default_params": {"timeperiod": 14},
"param_ranges": {"timeperiod": (2, 100)},
"output_format": "single",
"description": "Williams %R",
},
"CCI": {
"talib_function": "talib.CCI",
"input_types": ["ohlc"],
"default_params": {"timeperiod": 20},
"param_ranges": {"timeperiod": (2, 100)},
"output_format": "single",
"description": "Commodity Channel Index",
},
"ROC": {
"talib_function": "talib.ROC",
"input_types": ["close"],
"default_params": {"timeperiod": 12},
"param_ranges": {"timeperiod": (1, 100)},
"output_format": "single",
"description": "Rate of Change",
},
"MFI": {
"talib_function": "talib.MFI",
"input_types": ["ohlcv"],
"default_params": {"timeperiod": 14},
"param_ranges": {"timeperiod": (2, 100)},
"output_format": "single",
"description": "Money Flow Index",
},
# Trend Indicators (7)
"SMA": {
"talib_function": "talib.SMA",
"input_types": ["close"],
"default_params": {"timeperiod": 20},
"param_ranges": {"timeperiod": (2, 200)},
"output_format": "single",
"description": "Simple Moving Average",
},
"EMA": {
"talib_function": "talib.EMA",
"input_types": ["close"],
"default_params": {"timeperiod": 20},
"param_ranges": {"timeperiod": (2, 200)},
"output_format": "single",
"description": "Exponential Moving Average",
},
"BBANDS": {
"talib_function": "talib.BBANDS",
"input_types": ["close"],
"default_params": {"timeperiod": 20, "nbdevup": 2.0, "nbdevdn": 2.0},
"param_ranges": {
"timeperiod": (2, 100),
"nbdevup": (0.1, 5.0),
"nbdevdn": (0.1, 5.0),
},
"output_format": "triple",
"description": "Bollinger Bands",
},
"SAR": {
"talib_function": "talib.SAR",
"input_types": ["ohlc"],
"default_params": {"acceleration": 0.02, "maximum": 0.2},
"param_ranges": {"acceleration": (0.01, 0.1), "maximum": (0.1, 1.0)},
"output_format": "single",
"description": "Parabolic SAR",
},
"ADX": {
"talib_function": "talib.ADX",
"input_types": ["ohlc"],
"default_params": {"timeperiod": 14},
"param_ranges": {"timeperiod": (2, 100)},
"output_format": "single",
"description": "Average Directional Index",
},
"AROON": {
"talib_function": "talib.AROON",
"input_types": ["hl"],
"default_params": {"timeperiod": 25},
"param_ranges": {"timeperiod": (2, 100)},
"output_format": "double",
"description": "Aroon Oscillator",
},
"TEMA": {
"talib_function": "talib.TEMA",
"input_types": ["close"],
"default_params": {"timeperiod": 20},
"param_ranges": {"timeperiod": (2, 200)},
"output_format": "single",
"description": "Triple Exponential Moving Average",
},
# Volume Indicators (3)
"OBV": {
"talib_function": "talib.OBV",
"input_types": ["ohlcv"],
"default_params": {},
"param_ranges": {},
"output_format": "single",
"description": "On Balance Volume",
},
"AD": {
"talib_function": "talib.AD",
"input_types": ["ohlcv"],
"default_params": {},
"param_ranges": {},
"output_format": "single",
"description": "Accumulation/Distribution",
},
"ADOSC": {
"talib_function": "talib.ADOSC",
"input_types": ["ohlcv"],
"default_params": {"fastperiod": 3, "slowperiod": 10},
"param_ranges": {"fastperiod": (2, 50), "slowperiod": (5, 100)},
"output_format": "single",
"description": "Accumulation/Distribution Oscillator",
},
# Volatility Indicators (3)
"ATR": {
"talib_function": "talib.ATR",
"input_types": ["ohlc"],
"default_params": {"timeperiod": 14},
"param_ranges": {"timeperiod": (1, 100)},
"output_format": "single",
"description": "Average True Range",
},
"NATR": {
"talib_function": "talib.NATR",
"input_types": ["ohlc"],
"default_params": {"timeperiod": 14},
"param_ranges": {"timeperiod": (1, 100)},
"output_format": "single",
"description": "Normalized Average True Range",
},
"TRANGE": {
"talib_function": "talib.TRANGE",
"input_types": ["ohlc"],
"default_params": {},
"param_ranges": {},
"output_format": "single",
"description": "True Range",
},
}
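# How a definition drives a call (sketch): for "RSI", talib_function
# "talib.RSI" resolves to talib.RSI, input_types ["close"] selects the close
# array, and default_params become keyword arguments:
#     talib.RSI(close_array, timeperiod=14)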
class IndicatorPresets:
"""Professional indicator presets for different trading styles."""
@staticmethod
def get_scalping_presets() -> dict[str, dict[str, IndicatorParamValue]]:
"""Fast scalping presets (1-5 minute timeframes)."""
return {
"RSI_SCALPING": {"timeperiod": 5},
"MACD_SCALPING": {"fastperiod": 5, "slowperiod": 13, "signalperiod": 5},
"STOCH_SCALPING": {"fastk_period": 5, "slowk_period": 3, "slowd_period": 3},
"EMA_SCALPING": {"timeperiod": 9},
"BBANDS_TIGHT": {"timeperiod": 10, "nbdevup": 1.5, "nbdevdn": 1.5},
"ATR_SCALPING": {"timeperiod": 5},
}
@staticmethod
def get_day_trading_presets() -> dict[str, dict[str, IndicatorParamValue]]:
"""Day trading presets (5-60 minute timeframes)."""
return {
"RSI_DAY_TRADING": {"timeperiod": 14},
"MACD_DAY_TRADING": {"fastperiod": 12, "slowperiod": 26, "signalperiod": 9},
"STOCH_DAY_TRADING": {
"fastk_period": 14,
"slowk_period": 3,
"slowd_period": 3,
},
"EMA_DAY_TRADING": {"timeperiod": 20},
"BBANDS_STANDARD": {"timeperiod": 20, "nbdevup": 2.0, "nbdevdn": 2.0},
"ADX_DAY_TRADING": {"timeperiod": 14},
"ATR_DAY_TRADING": {"timeperiod": 14},
}
@staticmethod
def get_swing_trading_presets() -> dict[str, dict[str, IndicatorParamValue]]:
"""Swing trading presets (daily timeframes)."""
return {
"RSI_SWING": {"timeperiod": 21},
"MACD_SWING": {"fastperiod": 12, "slowperiod": 26, "signalperiod": 9},
"STOCH_SWING": {"fastk_period": 21, "slowk_period": 5, "slowd_period": 5},
"SMA_SWING_SHORT": {"timeperiod": 50},
"SMA_SWING_LONG": {"timeperiod": 200},
"EMA_SWING": {"timeperiod": 50},
"BBANDS_SWING": {"timeperiod": 20, "nbdevup": 2.5, "nbdevdn": 2.5},
"ADX_SWING": {"timeperiod": 21},
"AROON_SWING": {"timeperiod": 25},
}
@staticmethod
def get_position_trading_presets() -> dict[str, dict[str, IndicatorParamValue]]:
"""Position trading presets (weekly/monthly timeframes)."""
return {
"RSI_POSITION": {"timeperiod": 30},
"MACD_POSITION": {"fastperiod": 20, "slowperiod": 50, "signalperiod": 15},
"SMA_POSITION_SHORT": {"timeperiod": 100},
"SMA_POSITION_LONG": {"timeperiod": 300},
"EMA_POSITION": {"timeperiod": 100},
"BBANDS_POSITION": {"timeperiod": 50, "nbdevup": 3.0, "nbdevdn": 3.0},
"ADX_POSITION": {"timeperiod": 30},
"AROON_POSITION": {"timeperiod": 50},
}
@staticmethod
def get_all_presets() -> dict[str, dict[str, IndicatorParamValue]]:
"""Get all available presets combined."""
all_presets = {}
all_presets.update(IndicatorPresets.get_scalping_presets())
all_presets.update(IndicatorPresets.get_day_trading_presets())
all_presets.update(IndicatorPresets.get_swing_trading_presets())
all_presets.update(IndicatorPresets.get_position_trading_presets())
return all_presets
@staticmethod
def get_preset_for_style(style: str) -> dict[str, dict[str, IndicatorParamValue]]:
"""Get presets for a specific trading style."""
style_map = {
"scalping": IndicatorPresets.get_scalping_presets(),
"day_trading": IndicatorPresets.get_day_trading_presets(),
"swing": IndicatorPresets.get_swing_trading_presets(),
"position": IndicatorPresets.get_position_trading_presets(),
}
return style_map.get(style.lower(), {})
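# Lookup sketch (preset names and values as defined above):
#     IndicatorPresets.get_preset_for_style("swing")["RSI_SWING"]
#     # -> {"timeperiod": 21}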

View File

@ -0,0 +1,183 @@
"""
Repository for fundamental financial data (balance sheets, income statements, cash flow).
"""
import json
import logging
from dataclasses import asdict, dataclass
from datetime import date
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..models import ReportedFinancialsResponse
# Base repository functionality inline
logger = logging.getLogger(__name__)
@dataclass
class FinancialStatement:
"""Represents a financial statement with standardized structure matching the service."""
period: str
report_date: str
publish_date: str
currency: str
data: dict[str, float]
@dataclass
class FinancialData:
"""Container for all financial statements for a symbol and date."""
symbol: str
date: date
financial_statements: dict[
str, FinancialStatement
] # "balance_sheet", "income_statement", "cash_flow"
class FundamentalDataRepository:
"""Repository for accessing cached fundamental financial data as a KV store."""
def __init__(self, data_dir: str, **kwargs):
"""
Initialize fundamental data repository.
Args:
data_dir: Base directory for fundamental data storage
**kwargs: Additional configuration
"""
self.fundamental_data_dir = Path(data_dir) / "fundamental_data"
self.fundamental_data_dir.mkdir(parents=True, exist_ok=True)
def store_financial_statements(
self,
symbol: str,
date: date,
statements: dict[str, FinancialStatement],
) -> tuple[date, dict[str, FinancialStatement]]:
"""
Store financial statements for a symbol and date.
Args:
symbol: Stock ticker symbol
date: Date of the financial statements
statements: Dictionary of statements keyed by type ("balance_sheet", "income_statement", "cash_flow")
Returns:
Tuple[date, dict[str, FinancialStatement]]: The stored date and statements
"""
# Create symbol directory
symbol_dir = self.fundamental_data_dir / symbol
self._ensure_path_exists(symbol_dir)
# Create JSON file path
file_path = symbol_dir / f"{date.isoformat()}.json"
try:
# Prepare data for JSON serialization
statements_data = {}
for statement_type, statement in statements.items():
statements_data[statement_type] = asdict(statement)
data = {
"symbol": symbol,
"date": date.isoformat(),
"financial_statements": statements_data,
"metadata": {
"stored_at": date.today().isoformat(),
"repository": "fundamental_data_repository",
"statement_count": len(statements),
},
}
# Write to JSON file
with open(file_path, "w") as f:
json.dump(data, f, indent=2, default=str)
logger.info(
f"Stored {len(statements)} financial statements for {symbol} on {date}"
)
return (date, statements)
except Exception as e:
logger.error(
f"Error storing financial statements for {symbol} on {date}: {e}"
)
raise
def store_reported_financials(
self,
symbol: str,
date: date,
reported_financials: "ReportedFinancialsResponse",
) -> bool:
"""
Store reported financials data from Finnhub API.
Args:
symbol: Stock ticker symbol
date: Date for the financial data
reported_financials: ReportedFinancialsResponse from Finnhub
Returns:
bool: True if storage was successful
"""
# Create symbol directory
symbol_dir = self.fundamental_data_dir / symbol
self._ensure_path_exists(symbol_dir)
# Create JSON file path
file_path = symbol_dir / f"{date.isoformat()}_reported_financials.json"
try:
# Convert dataclass to dict for JSON serialization
data = {
"symbol": symbol,
"date": date.isoformat(),
"reported_financials": {
"start_date": reported_financials.start_date,
"end_date": reported_financials.end_date,
"year": reported_financials.year,
"quarter": reported_financials.quarter,
"access_number": reported_financials.access_number,
"data": {
"bs": [asdict(item) for item in reported_financials.data.bs],
"ic": [asdict(item) for item in reported_financials.data.ic],
"cf": [asdict(item) for item in reported_financials.data.cf],
},
},
"metadata": {
"stored_at": date.today().isoformat(),
"repository": "fundamental_data_repository",
"data_source": "finnhub_reported_financials",
"bs_items": len(reported_financials.data.bs),
"ic_items": len(reported_financials.data.ic),
"cf_items": len(reported_financials.data.cf),
},
}
# Write to JSON file
with open(file_path, "w") as f:
json.dump(data, f, indent=2, default=str)
logger.info(
f"Stored reported financials for {symbol} on {date} "
f"(BS: {len(reported_financials.data.bs)}, "
f"IC: {len(reported_financials.data.ic)}, "
f"CF: {len(reported_financials.data.cf)} items)"
)
return True
except Exception as e:
logger.error(
f"Error storing reported financials for {symbol} on {date}: {e}"
)
return False
def _ensure_path_exists(self, path: Path) -> None:
"""Ensure a directory path exists."""
path.mkdir(parents=True, exist_ok=True)
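# Storage sketch with made-up statement values:
#     repo = FundamentalDataRepository("data")
#     stmt = FinancialStatement(
#         period="Q1-2024", report_date="2024-03-31",
#         publish_date="2024-04-15", currency="USD",
#         data={"total_assets": 1_000_000.0},
#     )
#     repo.store_financial_statements("AAPL", date(2024, 3, 31),
#                                     {"balance_sheet": stmt})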

View File

@ -1,313 +0,0 @@
"""
Repository for fundamental financial data (balance sheets, income statements, cash flow).
"""
import json
import logging
from dataclasses import asdict, dataclass
from datetime import date
from pathlib import Path
from .base import BaseRepository
logger = logging.getLogger(__name__)
@dataclass
class FinancialStatement:
"""Represents a financial statement with standardized structure matching the service."""
period: str
report_date: str
publish_date: str
currency: str
data: dict[str, float]
@dataclass
class FinancialData:
"""Container for all financial statements for a symbol and date."""
symbol: str
date: date
financial_statements: dict[
str, FinancialStatement
] # "balance_sheet", "income_statement", "cash_flow"
class FundamentalDataRepository(BaseRepository):
"""Repository for accessing cached fundamental financial data as a KV store."""
def __init__(self, data_dir: str, **kwargs):
"""
Initialize fundamental data repository.
Args:
data_dir: Base directory for fundamental data storage
**kwargs: Additional configuration
"""
self.fundamental_data_dir = Path(data_dir) / "fundamental_data"
self.fundamental_data_dir.mkdir(parents=True, exist_ok=True)
def get_financial_data(
self, symbol: str, start_date: date, end_date: date
) -> dict[date, FinancialData]:
"""
Get cached fundamental data for a symbol and date range.
Args:
symbol: Stock ticker symbol
start_date: Start date
end_date: End date
Returns:
Dict[date, FinancialData]: Financial data keyed by date
"""
symbol_dir = self.fundamental_data_dir / symbol
if not symbol_dir.exists():
logger.warning(f"No data directory found for symbol {symbol}")
return {}
financial_data = {}
# Scan for JSON files in the symbol directory
for json_file in symbol_dir.glob("*.json"):
try:
# Parse date from filename (YYYY-MM-DD.json)
date_str = json_file.stem
file_date = date.fromisoformat(date_str)
# Filter by date range
if start_date <= file_date <= end_date:
with open(json_file) as f:
data = json.load(f)
# Create FinancialStatement objects from JSON data
financial_statements = {}
for statement_type, statement_data in data.get(
"financial_statements", {}
).items():
financial_statements[statement_type] = FinancialStatement(
**statement_data
)
# Create FinancialData container
financial_data[file_date] = FinancialData(
symbol=symbol,
date=file_date,
financial_statements=financial_statements,
)
except (ValueError, json.JSONDecodeError, KeyError) as e:
logger.error(f"Error reading financial data from {json_file}: {e}")
continue
logger.info(f"Retrieved {len(financial_data)} financial records for {symbol}")
return financial_data
def has_data_for_period(
self, symbol: str, start_date: str, end_date: str, frequency: str = "quarterly"
) -> bool:
"""
Check if we have sufficient data for the given period.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
frequency: Reporting frequency (not used in this implementation)
Returns:
bool: True if we have any data in the period
"""
start_dt = date.fromisoformat(start_date)
end_dt = date.fromisoformat(end_date)
financial_data = self.get_financial_data(symbol, start_dt, end_dt)
return len(financial_data) > 0
def get_data(
self,
symbol: str,
start_date: str,
end_date: str,
frequency: str = "quarterly",
**kwargs,
) -> dict[str, Any]:
"""
Get data in the format expected by the service.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
frequency: Reporting frequency
**kwargs: Additional parameters
Returns:
Dict with financial_statements structure expected by service
"""
start_dt = date.fromisoformat(start_date)
end_dt = date.fromisoformat(end_date)
financial_data = self.get_financial_data(symbol, start_dt, end_dt)
if not financial_data:
return {}
# Get the most recent data
latest_date = max(financial_data.keys())
latest_data = financial_data[latest_date]
return {
"financial_statements": {
statement_type: asdict(statement)
for statement_type, statement in latest_data.financial_statements.items()
}
}
def store_data(
self,
symbol: str,
cache_data: dict,
frequency: str = "quarterly",
overwrite: bool = False,
) -> bool:
"""
Store data in the format expected by the service.
Args:
symbol: Stock ticker symbol
cache_data: Data dictionary with financial_statements
frequency: Reporting frequency (not used)
overwrite: Whether to overwrite existing data
Returns:
bool: True if successful
"""
try:
# Extract financial statements from cache_data
statements_data = cache_data.get("financial_statements", {})
if not statements_data:
logger.warning(
f"No financial statements found in cache_data for {symbol}"
)
return False
# Use today's date as the storage date
storage_date = date.today()
# Convert to FinancialStatement objects
financial_statements = {}
for statement_type, statement_dict in statements_data.items():
financial_statements[statement_type] = FinancialStatement(
**statement_dict
)
# Store the statements
self.store_financial_statements(symbol, storage_date, financial_statements)
return True
except Exception as e:
logger.error(f"Error storing data for {symbol}: {e}")
return False
def clear_data(
self, symbol: str, start_date: str, end_date: str, frequency: str = "quarterly"
) -> bool:
"""
Clear data for a symbol and date range.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
frequency: Reporting frequency (not used)
Returns:
bool: True if successful
"""
try:
symbol_dir = self.fundamental_data_dir / symbol
if not symbol_dir.exists():
return True # Nothing to clear
start_dt = date.fromisoformat(start_date)
end_dt = date.fromisoformat(end_date)
# Remove files in the date range
for json_file in symbol_dir.glob("*.json"):
try:
date_str = json_file.stem
file_date = date.fromisoformat(date_str)
if start_dt <= file_date <= end_dt:
json_file.unlink()
logger.debug(f"Removed {json_file}")
except (ValueError, OSError) as e:
logger.warning(f"Error removing {json_file}: {e}")
continue
return True
except Exception as e:
logger.error(f"Error clearing data for {symbol}: {e}")
return False
def store_financial_statements(
self,
symbol: str,
date: date,
statements: dict[str, FinancialStatement],
) -> tuple[date, dict[str, FinancialStatement]]:
"""
Store financial statements for a symbol and date.
Args:
symbol: Stock ticker symbol
date: Date of the financial statements
statements: Dictionary of statements keyed by type ("balance_sheet", "income_statement", "cash_flow")
Returns:
Tuple[date, dict[str, FinancialStatement]]: The stored date and statements
"""
# Create symbol directory
symbol_dir = self.fundamental_data_dir / symbol
self._ensure_path_exists(symbol_dir)
# Create JSON file path
file_path = symbol_dir / f"{date.isoformat()}.json"
try:
# Prepare data for JSON serialization
statements_data = {}
for statement_type, statement in statements.items():
statements_data[statement_type] = asdict(statement)
data = {
"symbol": symbol,
"date": date.isoformat(),
"financial_statements": statements_data,
"metadata": {
"stored_at": date.today().isoformat(),
"repository": "fundamental_data_repository",
"statement_count": len(statements),
},
}
# Write to JSON file
with open(file_path, "w") as f:
json.dump(data, f, indent=2, default=str)
logger.info(
f"Stored {len(statements)} financial statements for {symbol} on {date}"
)
return (date, statements)
except Exception as e:
logger.error(
f"Error storing financial statements for {symbol} on {date}: {e}"
)
raise

View File

@ -8,12 +8,12 @@ from pathlib import Path
import pandas as pd
from .base import BaseRepository
# from .base import BaseRepository  # Removed: BaseRepository inheritance dropped in this commit
logger = logging.getLogger(__name__)
class MarketDataRepository(BaseRepository):
class MarketDataRepository:
"""Repository for accessing historical market data from CSV files."""
def __init__(self, data_dir: str, **kwargs):
@ -60,9 +60,8 @@ class MarketDataRepository(BaseRepository):
df["Date"] = pd.to_datetime(df["Date"]).dt.date
# Filter by date range
filtered_df = df[
(df["Date"] >= start_date) & (df["Date"] <= end_date)
].copy()
mask = (df["Date"] >= start_date) & (df["Date"] <= end_date)
filtered_df: pd.DataFrame = df.loc[mask].copy()
logger.info(
f"Retrieved {len(filtered_df)} records for {symbol} from {start_date} to {end_date}"

View File

@ -0,0 +1,184 @@
"""
Article scraper client for extracting full content from news URLs.
"""
import logging
import time
from dataclasses import dataclass
from datetime import datetime
from urllib.parse import urlparse
import newspaper
logger = logging.getLogger(__name__)
@dataclass
class ScrapeResult:
"""Result of article scraping operation."""
status: str # 'SUCCESS', 'SCRAPE_FAILED', 'ARCHIVE_SUCCESS', 'NOT_FOUND'
content: str = ""
author: str = ""
final_url: str = ""
title: str = ""
publish_date: str = ""
class ArticleScraperClient:
"""Client for scraping article content with Internet Archive fallback."""
def __init__(self, user_agent: str, delay: float = 1.0):
"""
Initialize article scraper.
Args:
user_agent: User agent string for requests
delay: Delay between requests in seconds
"""
self.user_agent = user_agent or (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
"(KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
)
self.delay = delay
def scrape_article(self, url: str) -> ScrapeResult:
"""
Scrape article content from URL with fallback to Internet Archive.
Args:
url: Article URL to scrape
Returns:
ScrapeResult: Scraping result with content and metadata
"""
if not url or not self._is_valid_url(url):
return ScrapeResult(status="NOT_FOUND", final_url=url)
# Try original source first
result = self._scrape_from_source(url)
if result.status == "SUCCESS":
return result
# Fallback to Internet Archive
logger.info(f"Original scraping failed for {url}, trying Internet Archive")
return self._scrape_from_wayback(url)
def _scrape_from_source(self, url: str) -> ScrapeResult:
"""Scrape article from original source using newspaper3k."""
try:
# Add delay to be respectful
time.sleep(self.delay)
# Configure newspaper article
article = newspaper.Article(url)
article.config.browser_user_agent = self.user_agent
article.config.request_timeout = 10
# Download and parse
article.download()
article.parse()
# Validate content
if not article.text or len(article.text.strip()) < 100:
logger.warning(f"Article content too short or empty for {url}")
return ScrapeResult(status="SCRAPE_FAILED", final_url=url)
# Handle publish_date which can be datetime or string
publish_date_str = ""
if article.publish_date:
if isinstance(article.publish_date, datetime):
publish_date_str = article.publish_date.strftime("%Y-%m-%d")
elif isinstance(article.publish_date, str):
publish_date_str = article.publish_date
else:
# Try to convert to string
publish_date_str = str(article.publish_date)
return ScrapeResult(
status="SUCCESS",
content=article.text.strip(),
author=", ".join(article.authors) if article.authors else "",
final_url=url,
title=article.title or "",
publish_date=publish_date_str,
)
except Exception as e:
logger.warning(f"Error scraping article from {url}: {e}")
return ScrapeResult(status="SCRAPE_FAILED", final_url=url)
def _scrape_from_wayback(self, url: str) -> ScrapeResult:
"""Scrape article from Internet Archive Wayback Machine."""
try:
import requests
except ImportError:
logger.error("requests not installed. Install with: pip install requests")
return ScrapeResult(status="NOT_FOUND", final_url=url)
try:
# Query Wayback Machine CDX API for snapshots
cdx_url = "http://web.archive.org/cdx/search/cdx"
params = {
"url": url,
"output": "json",
"fl": "timestamp,original",
"filter": "statuscode:200",
"limit": "1",
}
response = requests.get(cdx_url, params=params, timeout=10)
response.raise_for_status()
data = response.json()
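# The CDX API returns row-oriented JSON, e.g.
# [["timestamp", "original"], ["20240101000000", "https://example.com/page"]]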
if len(data) < 2: # First row is headers
logger.warning(f"No archived snapshots found for {url}")
return ScrapeResult(status="NOT_FOUND", final_url=url)
# Get the most recent snapshot
timestamp, original_url = data[1]
archive_url = f"https://web.archive.org/web/{timestamp}/{original_url}"
logger.info(f"Found archived snapshot: {archive_url}")
# Scrape from archive URL
result = self._scrape_from_source(archive_url)
if result.status == "SUCCESS":
result.status = "ARCHIVE_SUCCESS"
result.final_url = archive_url
return result
except Exception as e:
logger.warning(f"Error accessing Internet Archive for {url}: {e}")
return ScrapeResult(status="NOT_FOUND", final_url=url)
def _is_valid_url(self, url: str) -> bool:
"""Check if URL is valid and accessible."""
try:
parsed = urlparse(url)
return bool(parsed.netloc) and parsed.scheme in ("http", "https")
except Exception:
return False
def scrape_multiple_articles(self, urls: list[str]) -> dict[str, ScrapeResult]:
"""
Scrape multiple articles sequentially.
Args:
urls: List of article URLs to scrape
Returns:
Dict mapping URLs to ScrapeResults
"""
results = {}
for i, url in enumerate(urls):
logger.info(f"Scraping article {i + 1}/{len(urls)}: {url}")
results[url] = self.scrape_article(url)
# Add delay between requests
if i < len(urls) - 1:
time.sleep(self.delay)
return results
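A short usage sketch (illustrative; the URL is a placeholder and both calls hit the live network):

```python
scraper = ArticleScraperClient(user_agent="", delay=1.0)  # empty UA falls back to the default
result = scraper.scrape_article("https://example.com/some-article")
if result.status in ("SUCCESS", "ARCHIVE_SUCCESS"):
    print(result.title, "via", result.final_url)
else:
    print("Scrape failed:", result.status)
```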

View File

@ -1,194 +1,188 @@
"""
Google News client for live news data via web scraping.
Google News client for live news data via RSS feeds.
"""
import logging
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from urllib.parse import quote
from .base import BaseClient
import feedparser
import requests
from dateutil import parser as date_parser
logger = logging.getLogger(__name__)
class GoogleNewsClient(BaseClient):
@dataclass
class FeedEntry:
"""Structured representation of a feedparser entry."""
title: str
link: str
published: str
published_parsed: time.struct_time | None
summary: str
guid: str
@classmethod
def from_feedparser_dict(cls, entry: feedparser.FeedParserDict) -> "FeedEntry":
"""Convert a FeedParserDict to a structured FeedEntry."""
return cls(
title=getattr(entry, "title", "Untitled"),
link=getattr(entry, "link", ""),
published=getattr(entry, "published", ""),
published_parsed=getattr(entry, "published_parsed", None),
summary=getattr(entry, "summary", ""),
guid=getattr(entry, "id", getattr(entry, "link", "")),
)
@dataclass
class GoogleNewsArticle:
"""Represents a news article from Google News RSS feed."""
title: str
link: str
published: datetime
summary: str
source: str
guid: str
class GoogleNewsClient:
"""Client for Google News data via web scraping."""
def __init__(self, **kwargs):
"""
Initialize Google News client.
Args:
**kwargs: Configuration options including rate limits
"""
self.max_retries = kwargs.get("max_retries", 3)
self.delay_between_requests = kwargs.get("delay_between_requests", 1.0)
def get_data(
self, query: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
"""
Get news data for a query and date range.
Args:
query: Search query
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
**kwargs: Additional parameters
Returns:
Dict[str, Any]: News data with metadata
"""
if not self.validate_date_range(start_date, end_date):
raise ValueError(f"Invalid date range: {start_date} to {end_date}")
try:
# Replace spaces with + for URL encoding
formatted_query = query.replace(" ", "+")
logger.info(
f"Fetching Google News for query: {query} from {start_date} to {end_date}"
)
news_results = getNewsData(formatted_query, start_date, end_date)
if not news_results:
logger.warning(f"No news found for query: {query}")
return {
"query": query,
"period": {"start": start_date, "end": end_date},
"articles": [],
"metadata": {
"source": "google_news",
"empty": True,
"reason": "no_articles_found",
},
}
# Process and standardize article data
processed_articles = []
for article in news_results:
processed_article = {
"headline": article.get("title", ""),
"summary": article.get("snippet", ""),
"url": article.get("link", ""),
"source": article.get("source", "Unknown"),
"date": article.get(
"date", end_date
), # Fallback to end_date if no date
"entities": article.get("entities", []),
}
processed_articles.append(processed_article)
return {
"query": query,
"period": {"start": start_date, "end": end_date},
"articles": processed_articles,
"metadata": {
"source": "google_news",
"article_count": len(processed_articles),
"retrieved_at": datetime.utcnow().isoformat(),
"search_query": formatted_query,
},
}
except Exception as e:
logger.error(f"Error fetching Google News for query '{query}': {e}")
raise
def get_company_news(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
def get_company_news(self, symbol: str) -> list[GoogleNewsArticle]:
"""
Get news data specific to a company symbol.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
**kwargs: Additional parameters
Returns:
Dict[str, Any]: Company-specific news data
list[GoogleNewsArticle]: Company-specific news data
"""
# Create company-focused search query
company_query = f"{symbol} stock"
result = self.get_data(company_query, start_date, end_date, **kwargs)
result["symbol"] = symbol
result["metadata"]["query_type"] = "company_specific"
return result
return self._get_rss_feed(symbol)
def get_global_news(
self,
start_date: str,
end_date: str,
categories: list[str] | None = None,
**kwargs,
) -> dict[str, Any]:
categories: list[str],
) -> list[GoogleNewsArticle]:
"""
Get global/macro news that might affect markets.
Args:
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
categories: List of news categories to search
**kwargs: Additional parameters
Returns:
Dict[str, Any]: Global news data
list[GoogleNewsArticle]: Global news data
"""
if categories is None:
categories = ["economy", "finance", "markets", "business"]
# Get RSS queries for categories
all_articles = []
for category in categories:
try:
category_data = self.get_data(category, start_date, end_date, **kwargs)
# Add category tag to each article
for article in category_data.get("articles", []):
article["category"] = category
all_articles.extend(category_data.get("articles", []))
articles = self._get_rss_feed(category)
all_articles.extend(articles)
except Exception as e:
logger.warning(f"Failed to fetch news for category '{category}': {e}")
continue
return {
"query": "global_news",
"categories": categories,
"period": {"start": start_date, "end": end_date},
"articles": all_articles,
"metadata": {
"source": "google_news",
"article_count": len(all_articles),
"categories_searched": categories,
"retrieved_at": datetime.utcnow().isoformat(),
"query_type": "global_news",
},
}
return all_articles
def get_available_categories(self) -> list[str]:
def _get_rss_feed(self, query: str) -> list[GoogleNewsArticle]:
"""
Get list of commonly used news categories.
Fetch RSS feed from Google News for a given query.
Args:
query: Search query (company symbol or news category)
Returns:
List[str]: News categories
list[GoogleNewsArticle]: Parsed articles from RSS feed
"""
return [
"business",
"economy",
"finance",
"markets",
"technology",
"politics",
"world",
"healthcare",
"energy",
"crypto",
]
try:
# Construct Google News RSS URL
encoded_query = quote(query)
rss_url = f"https://news.google.com/rss/search?q={encoded_query}&hl=en-US&gl=US&ceid=US:en"
logger.info(f"Fetching RSS feed for query: {query}")
# Use requests with timeout and User-Agent header
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
response = requests.get(rss_url, timeout=10, headers=headers)
response.raise_for_status()
# Use feedparser to parse the fetched content
feed = feedparser.parse(response.content)
# Check if feed was parsed successfully
if feed.bozo:
logger.warning(
f"Feed parsing had issues for query '{query}': {feed.bozo_exception}"
)
articles = []
for raw_entry in feed.entries:
try:
# Convert FeedParserDict to structured dataclass
entry = FeedEntry.from_feedparser_dict(raw_entry)
article = self._convert_entry_to_article(entry)
articles.append(article)
except Exception as e:
logger.warning(f"Failed to parse article entry: {e}")
continue
logger.info(
f"Successfully fetched {len(articles)} articles for query: {query}"
)
return articles
except requests.exceptions.RequestException as e:
logger.error(f"Network error fetching RSS feed for query '{query}': {e}")
return []
except Exception as e:
logger.error(
f"Unexpected error fetching RSS feed for query '{query}': {e}",
exc_info=True,
)
return []
def _convert_entry_to_article(self, entry: FeedEntry) -> GoogleNewsArticle:
"""
Convert a structured FeedEntry to a GoogleNewsArticle.
Args:
entry: Structured FeedEntry dataclass
Returns:
GoogleNewsArticle: Converted article object
"""
# Parse published date with fallback to current time
try:
published = (
date_parser.parse(entry.published)
if entry.published
else datetime.utcnow()
)
except (ValueError, OverflowError, TypeError):
published = datetime.utcnow()
# Extract source from title (Google News format: "Title - Source")
title_parts = entry.title.split(" - ")
title = title_parts[0] if title_parts else entry.title
source = title_parts[-1] if len(title_parts) > 1 else "Unknown"
return GoogleNewsArticle(
title=title,
link=entry.link,
published=published,
summary=entry.summary,
source=source,
guid=entry.guid,
)
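A minimal sketch of driving the RSS client (illustrative; performs live network calls):

```python
client = GoogleNewsClient()
for article in client.get_company_news("AAPL")[:5]:
    print(article.published.date(), article.source, "-", article.title)

macro_articles = client.get_global_news(["economy", "markets"])
```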

View File

@ -8,8 +8,6 @@ from dataclasses import asdict, dataclass, field
from datetime import date
from pathlib import Path
from .base import BaseRepository
logger = logging.getLogger(__name__)
@ -40,10 +38,10 @@ class NewsData:
articles: list[NewsArticle]
class NewsRepository(BaseRepository):
class NewsRepository:
"""Repository for accessing cached news data with source separation."""
def __init__(self, data_dir: str, **kwargs):
def __init__(self, data_dir: str):
"""
Initialize news repository.
@ -155,7 +153,6 @@ class NewsRepository(BaseRepository):
"""
# Create source/query directory
source_dir = self.news_data_dir / source / query
self._ensure_path_exists(source_dir)
# Create JSON file path
file_path = source_dir / f"{date.isoformat()}.json"

View File

@ -7,10 +7,11 @@ from dataclasses import dataclass
from enum import Enum
from typing import Any
from tradingagents.clients.base import BaseClient
from tradingagents.repositories.base import BaseRepository
from tradingagents.config import TradingAgentsConfig
from tradingagents.domains.news.google_news_client import GoogleNewsClient
from tradingagents.domains.news.news_repository import NewsRepository
from .base import BaseService
from .article_scraper_client import ArticleScraperClient
logger = logging.getLogger(__name__)
@ -73,16 +74,27 @@ class GlobalNewsContext:
metadata: dict[str, Any]
class NewsService(BaseService):
@dataclass
class NewsUpdateResult:
"""Result of news update operation."""
status: str
articles_found: int
articles_scraped: int
articles_failed: int
symbol: str | None = None
categories: list[str] | None = None
date_range: dict[str, str] | None = None
class NewsService:
"""Service for news data and sentiment analysis."""
def __init__(
self,
finnhub_client: BaseClient | None = None,
google_client: BaseClient | None = None,
repository: BaseRepository | None = None,
online_mode: bool = True,
**kwargs,
google_client: GoogleNewsClient,
repository: NewsRepository,
article_scraper: ArticleScraperClient,
):
"""
Initialize news service.
@ -91,46 +103,24 @@ class NewsService(BaseService):
finnhub_client: Client for Finnhub news data
google_client: Client for Google News data
repository: Repository for cached news data
online_mode: Whether to use live data
**kwargs: Additional configuration
article_scraper: Client for scraping article content
"""
super().__init__(online_mode, **kwargs)
self.finnhub_client = finnhub_client
self.google_client = google_client
self.repository = repository
self.article_scraper = article_scraper
def get_context(
self,
query: str,
start_date: str,
end_date: str,
symbol: str | None = None,
sources: list[str] | None = None,
force_refresh: bool = False,
**kwargs,
) -> NewsContext:
"""
Get news context for a query and date range.
Args:
query: Search query
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
symbol: Stock ticker symbol if company-specific
sources: List of sources to use ('finnhub', 'google', or both)
force_refresh: If True, skip local data and fetch fresh from APIs
**kwargs: Additional parameters
Returns:
NewsContext: Structured news context
"""
pass
@staticmethod
def build(_config: TradingAgentsConfig):
google_client = GoogleNewsClient()
repository = NewsRepository("")
article_scraper = ArticleScraperClient("")
return NewsService(google_client, repository, article_scraper)
def get_company_news_context(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> NewsContext:
"""
Get news context specific to a company.
Get news context specific to a company from repository (no API calls).
Args:
symbol: Stock ticker symbol
@ -141,7 +131,65 @@ class NewsService(BaseService):
Returns:
NewsContext: Company-specific news context
"""
pass
try:
logger.info(f"Getting company news context for {symbol} from repository")
# Get articles from repository
articles = []
if self.repository:
try:
# This would depend on the actual repository interface
# For now, return empty list - repository integration needs to be completed
articles = []
logger.debug(
f"Retrieved {len(articles)} articles from repository for {symbol}"
)
except Exception as e:
logger.warning(f"Error retrieving articles from repository: {e}")
articles = []
# Calculate sentiment summary from articles
sentiment_summary = self._calculate_sentiment_summary(articles)
# Extract unique sources
sources = list(
{article.source for article in articles if hasattr(article, "source")}
)
return NewsContext(
query=symbol,
symbol=symbol,
period={"start": start_date, "end": end_date},
articles=articles,
sentiment_summary=sentiment_summary,
article_count=len(articles),
sources=sources,
metadata={
"service": "news",
"data_source": "repository",
"method": "get_company_news_context",
},
)
except Exception as e:
logger.error(f"Error getting company news context for {symbol}: {e}")
# Return empty context on error
return NewsContext(
query=symbol,
symbol=symbol,
period={"start": start_date, "end": end_date},
articles=[],
sentiment_summary=SentimentScore(
score=0.0, confidence=0.0, label="neutral"
),
article_count=0,
sources=[],
metadata={
"service": "news",
"data_source": "repository",
"error": str(e),
},
)
def get_global_news_context(
self,
@ -151,7 +199,7 @@ class NewsService(BaseService):
**kwargs,
) -> GlobalNewsContext:
"""
Get global/macro news context.
Get global/macro news context from repository (no API calls).
Args:
start_date: Start date in YYYY-MM-DD format
@ -162,16 +210,425 @@ class NewsService(BaseService):
Returns:
GlobalNewsContext: Global news context
"""
# TODO: Implement global news fetching
return GlobalNewsContext(
period={"start": start_date, "end": end_date},
categories=categories or [],
articles=[],
sentiment_summary=SentimentScore(
score=0.0, confidence=0.0, label="neutral"
),
article_count=0,
sources=[],
trending_topics=[],
metadata={"service": "news", "analysis_method": "global_news"},
try:
if categories is None:
categories = ["general", "business", "politics"]
logger.info(
f"Getting global news context from repository for categories: {categories}"
)
# Get articles from repository
articles = []
if self.repository:
try:
# This would depend on the actual repository interface
# For now, return empty list - repository integration needs to be completed
articles = []
logger.debug(
f"Retrieved {len(articles)} global articles from repository"
)
except Exception as e:
logger.warning(
f"Error retrieving global articles from repository: {e}"
)
articles = []
# Calculate sentiment summary from articles
sentiment_summary = self._calculate_sentiment_summary(articles)
# Extract unique sources
sources = list(
{article.source for article in articles if hasattr(article, "source")}
)
# Extract trending topics (simplified implementation)
trending_topics = self._extract_trending_topics(articles)
return GlobalNewsContext(
period={"start": start_date, "end": end_date},
categories=categories,
articles=articles,
sentiment_summary=sentiment_summary,
article_count=len(articles),
sources=sources,
trending_topics=trending_topics,
metadata={
"service": "news",
"data_source": "repository",
"method": "get_global_news_context",
},
)
except Exception as e:
logger.error(f"Error getting global news context: {e}")
# Return empty context on error
return GlobalNewsContext(
period={"start": start_date, "end": end_date},
categories=categories or [],
articles=[],
sentiment_summary=SentimentScore(
score=0.0, confidence=0.0, label="neutral"
),
article_count=0,
sources=[],
trending_topics=[],
metadata={
"service": "news",
"data_source": "repository",
"error": str(e),
},
)
def update_company_news(self, symbol: str) -> NewsUpdateResult:
"""
Update company news by fetching RSS feeds and scraping article content.
Args:
symbol: Stock ticker symbol
Returns:
NewsUpdateResult with update status and statistics
"""
try:
logger.info(f"Updating company news for {symbol}")
if not self.google_client:
raise ValueError("Google client not configured")
# 1. Get RSS feed data
google_articles = self.google_client.get_company_news(symbol)
if not google_articles:
logger.warning(f"No articles found in RSS feed for {symbol}")
return NewsUpdateResult(
status="completed",
articles_found=0,
articles_scraped=0,
articles_failed=0,
symbol=symbol,
)
# 2. Scrape each article content and convert to ArticleData
article_data_list = []
scraped_count = 0
failed_count = 0
for i, google_article in enumerate(google_articles):
if not google_article.link:
failed_count += 1
continue
logger.info(
f"Scraping article {i + 1}/{len(google_articles)}: {google_article.link}"
)
scrape_result = self.article_scraper.scrape_article(google_article.link)
# Create ArticleData with scraped content
if scrape_result.status in ["SUCCESS", "ARCHIVE_SUCCESS"]:
article_data = ArticleData(
title=scrape_result.title or google_article.title,
content=scrape_result.content,
author=scrape_result.author,
source=google_article.source,
date=scrape_result.publish_date
or google_article.published.strftime("%Y-%m-%d"),
url=google_article.link,
sentiment=None, # Will be calculated later
)
scraped_count += 1
else:
# Create ArticleData with just RSS data if scraping failed
article_data = ArticleData(
title=google_article.title,
content=google_article.summary, # Use summary as fallback content
author="",
source=google_article.source,
date=google_article.published.strftime("%Y-%m-%d"),
url=google_article.link,
sentiment=None,
)
failed_count += 1
article_data_list.append(article_data)
# 3. Store in repository
try:
logger.info(f"Storing {len(article_data_list)} articles for {symbol}")
# Store articles (implementation depends on repository interface)
except Exception as e:
logger.error(f"Error storing articles in repository: {e}")
logger.info(
f"Company news update completed for {symbol}: {scraped_count} scraped, {failed_count} failed"
)
return NewsUpdateResult(
status="completed",
articles_found=len(google_articles),
articles_scraped=scraped_count,
articles_failed=failed_count,
symbol=symbol,
)
except Exception as e:
logger.error(f"Error updating company news for {symbol}: {e}")
raise
def update_global_news(
self, start_date: str, end_date: str, categories: list[str] | None = None
) -> NewsUpdateResult:
"""
Update global/macro news by fetching RSS feeds and scraping article content.
Args:
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
categories: List of news categories to search
Returns:
NewsUpdateResult with update status and statistics
"""
try:
if categories is None:
categories = ["general", "business", "politics"]
logger.info(
f"Updating global news from {start_date} to {end_date} for categories: {categories}"
)
if not self.google_client:
raise ValueError("Google client not configured")
# 1. Get RSS feed data for all categories
google_articles = self.google_client.get_global_news(categories)
if not google_articles:
logger.warning("No articles found in RSS feeds for global news")
return NewsUpdateResult(
status="completed",
articles_found=0,
articles_scraped=0,
articles_failed=0,
categories=categories,
date_range={"start": start_date, "end": end_date},
)
# 2. Scrape each article content and convert to ArticleData
article_data_list = []
scraped_count = 0
failed_count = 0
for i, google_article in enumerate(google_articles):
if not google_article.link:
failed_count += 1
continue
logger.info(
f"Scraping global article {i + 1}/{len(google_articles)}: {google_article.link}"
)
scrape_result = self.article_scraper.scrape_article(google_article.link)
# Create ArticleData with scraped content
if scrape_result.status in ["SUCCESS", "ARCHIVE_SUCCESS"]:
article_data = ArticleData(
title=scrape_result.title or google_article.title,
content=scrape_result.content,
author=scrape_result.author,
source=google_article.source,
date=scrape_result.publish_date
or google_article.published.strftime("%Y-%m-%d"),
url=google_article.link,
sentiment=None, # Will be calculated later
)
scraped_count += 1
else:
# Create ArticleData with just RSS data if scraping failed
article_data = ArticleData(
title=google_article.title,
content=google_article.summary, # Use summary as fallback content
author="",
source=google_article.source,
date=google_article.published.strftime("%Y-%m-%d"),
url=google_article.link,
sentiment=None,
)
failed_count += 1
article_data_list.append(article_data)
# 3. Store in repository
try:
logger.info(f"Storing {len(article_data_list)} global articles")
# Store articles (implementation depends on repository interface)
except Exception as e:
logger.error(f"Error storing global articles in repository: {e}")
logger.info(
f"Global news update completed: {scraped_count} scraped, {failed_count} failed"
)
return NewsUpdateResult(
status="completed",
articles_found=len(google_articles),
articles_scraped=scraped_count,
articles_failed=failed_count,
categories=categories,
date_range={"start": start_date, "end": end_date},
)
except Exception as e:
logger.error(f"Error updating global news: {e}")
raise
def _calculate_sentiment_summary(
self, articles: list[ArticleData]
) -> SentimentScore:
"""
Calculate aggregate sentiment from article list.
Args:
articles: List of ArticleData objects
Returns:
SentimentScore: Aggregate sentiment score
"""
if not articles:
return SentimentScore(score=0.0, confidence=0.0, label="neutral")
# Simple keyword-based sentiment analysis
positive_words = {
"good",
"great",
"excellent",
"positive",
"up",
"rise",
"gain",
"profit",
"growth",
"success",
"strong",
"bullish",
"optimistic",
"boost",
"surge",
}
negative_words = {
"bad",
"terrible",
"negative",
"down",
"fall",
"loss",
"decline",
"weak",
"bearish",
"pessimistic",
"crash",
"drop",
"plunge",
"concern",
}
total_score = 0.0
scored_articles = 0
for article in articles:
if not hasattr(article, "content") or not article.content:
continue
content_lower = article.content.lower()
words = content_lower.split()
positive_count = sum(1 for word in words if word in positive_words)
negative_count = sum(1 for word in words if word in negative_words)
if positive_count + negative_count > 0:
article_score = (positive_count - negative_count) / len(words)
total_score += article_score
scored_articles += 1
if scored_articles == 0:
return SentimentScore(score=0.0, confidence=0.0, label="neutral")
avg_score = total_score / scored_articles
confidence = min(scored_articles / len(articles), 1.0)
# Normalize score to -1.0 to 1.0 range
normalized_score = max(-1.0, min(1.0, avg_score * 10))
# Determine label
if normalized_score > 0.1:
label = "positive"
elif normalized_score < -0.1:
label = "negative"
else:
label = "neutral"
return SentimentScore(
score=normalized_score, confidence=confidence, label=label
)
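# Worked example of the scoring above (illustrative): a 20-word article with two
# positive and one negative keywords scores (2 - 1) / 20 = 0.05; multiplied by 10
# and clamped to [-1.0, 1.0] that is 0.5, which maps to the "positive" label.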
def _extract_trending_topics(self, articles: list[ArticleData]) -> list[str]:
"""
Extract trending topics from article titles and content.
Args:
articles: List of ArticleData objects
Returns:
List of trending topic strings
"""
if not articles:
return []
# Simple keyword extraction from titles
word_counts = {}
stop_words = {
"the",
"a",
"an",
"and",
"or",
"but",
"in",
"on",
"at",
"to",
"for",
"of",
"with",
"by",
"is",
"are",
"was",
"were",
"be",
"been",
"have",
"has",
"had",
"do",
"does",
"did",
"will",
"would",
"could",
"should",
}
for article in articles:
if hasattr(article, "title") and article.title:
words = article.title.lower().split()
for word in words:
# Clean word
word = "".join(c for c in word if c.isalnum())
if len(word) > 3 and word not in stop_words:
word_counts[word] = word_counts.get(word, 0) + 1
# Get top trending words
trending = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)[:5]
return [word for word, count in trending if count > 1]
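A minimal wiring sketch for the update flow above, mirroring the `build()` factory but with a concrete data directory and scraper delay substituted (illustrative only):

```python
service = NewsService(
    google_client=GoogleNewsClient(),
    repository=NewsRepository("data"),
    article_scraper=ArticleScraperClient(user_agent="", delay=1.0),
)
update = service.update_company_news("AAPL")
print(update.status, f"{update.articles_scraped}/{update.articles_found} scraped")
context = service.get_company_news_context("AAPL", "2024-01-01", "2024-01-31")
```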

View File

@ -1,740 +0,0 @@
#!/usr/bin/env python3
"""
Test NewsService with mock clients and real NewsRepository.
"""
import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any
# Add the project root to the path
sys.path.insert(0, os.path.abspath("."))
from tradingagents.clients.base import BaseClient
from tradingagents.domains.news.news_service import (
NewsContext,
NewsService,
SentimentScore,
)
from tradingagents.repositories.news_repository import NewsRepository
class MockFinnhubClient(BaseClient):
"""Mock Finnhub client that returns sample news data."""
def test_connection(self) -> bool:
return True
def get_data(self, *args, **kwargs) -> dict[str, Any]:
"""Not used directly by NewsService."""
return {}
def get_company_news(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
"""Return mock Finnhub company news."""
return {
"symbol": symbol,
"period": {"start": start_date, "end": end_date},
"articles": [
{
"headline": f"{symbol} Beats Q4 Earnings Expectations",
"summary": f"{symbol} reported earnings of $2.50 per share, beating analyst estimates of $2.25.",
"url": f"https://example.com/finnhub/{symbol.lower()}-earnings",
"source": "Finnhub Financial",
"date": start_date,
"entities": [symbol],
},
{
"headline": f"Insider Trading Activity at {symbol}",
"summary": f"Company executives at {symbol} have increased their holdings by 15% this quarter.",
"url": f"https://example.com/finnhub/{symbol.lower()}-insider",
"source": "Finnhub SEC Filings",
"date": end_date,
"entities": [symbol, "insider trading"],
},
],
"metadata": {
"source": "mock_finnhub",
"article_count": 2,
"retrieved_at": datetime.utcnow().isoformat(),
},
}
class MockGoogleNewsClient(BaseClient):
"""Mock Google News client that returns sample articles."""
def test_connection(self) -> bool:
return True
def get_data(
self, query: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
"""Return mock Google News data."""
article_templates = [
{
"template": "{query} Stock Surges on Positive Outlook",
"summary": "Shares of {query} rose 5% in after-hours trading following strong guidance for next quarter.",
"source": "Mock Market News",
},
{
"template": "Analysts Recommend Buy Rating for {query}",
"summary": "Three major investment firms upgraded {query} to 'Buy' with improved price targets.",
"source": "Mock Investment Daily",
},
{
"template": "{query} Announces Strategic Partnership",
"summary": "The company revealed a new collaboration that could expand market reach significantly.",
"source": "Mock Business Wire",
},
]
articles = []
for i, template in enumerate(article_templates):
current_date = datetime.strptime(start_date, "%Y-%m-%d") + timedelta(days=i)
articles.append(
{
"headline": template["template"].format(query=query),
"summary": template["summary"].format(query=query),
"url": f"https://example.com/google/{query.lower()}-{i}",
"source": template["source"],
"date": current_date.strftime("%Y-%m-%d"),
"entities": [query],
}
)
return {
"query": query,
"period": {"start": start_date, "end": end_date},
"articles": articles,
"metadata": {
"source": "mock_google_news",
"article_count": len(articles),
"retrieved_at": datetime.utcnow().isoformat(),
},
}
def test_online_mode_with_mock_clients():
"""Test NewsService in online mode with mock clients."""
print("📰 Testing NewsService - Online Mode")
# Create mock clients and real repository
mock_finnhub = MockFinnhubClient()
mock_google = MockGoogleNewsClient()
real_repo = NewsRepository("test_data")
# Create service in online mode
service = NewsService(
finnhub_client=mock_finnhub,
google_client=mock_google,
repository=real_repo,
online_mode=True,
data_dir="test_data",
)
try:
# Test company news context
context = service.get_company_news_context(
symbol="AAPL", start_date="2024-01-01", end_date="2024-01-05"
)
print(f"✅ Company news context created: {context.__class__.__name__}")
print(f" Symbol: {context.symbol}")
print(f" Period: {context.period}")
print(f" Articles: {len(context.articles)}")
print(f" Sentiment score: {context.sentiment_summary.score:.3f}")
print(f" Sentiment confidence: {context.sentiment_summary.confidence:.3f}")
print(f" Sources: {context.sources}")
# Validate required fields
assert context.symbol == "AAPL"
assert context.period["start"] == "2024-01-01"
assert context.period["end"] == "2024-01-05"
assert len(context.articles) > 0
assert (
context.sentiment_summary.score >= -1.0
and context.sentiment_summary.score <= 1.0
)
assert "data_quality" in context.metadata
print("✅ Basic validation passed")
# Test JSON serialization
json_output = context.model_dump_json(indent=2)
parsed = json.loads(json_output)
print(f"✅ JSON serialization: {len(json_output)} characters")
print(f" Top-level keys: {list(parsed.keys())}")
return True
except Exception as e:
print(f"❌ Online mode test failed: {e}")
return False
def test_global_news_context():
"""Test global news functionality."""
print("\n🌍 Testing Global News Context")
mock_google = MockGoogleNewsClient()
real_repo = NewsRepository("test_data")
service = NewsService(
finnhub_client=None,
google_client=mock_google,
repository=real_repo,
online_mode=True,
data_dir="test_data",
)
try:
# Test global news with categories
context = service.get_global_news_context(
start_date="2024-01-01",
end_date="2024-01-03",
categories=["economy", "markets"],
)
print("✅ Global news context created")
print(f" Symbol: {context.symbol}") # Should be None for global news
print(f" Articles: {len(context.articles)}")
print(f" Categories searched: {context.metadata.get('categories', [])}")
print(f" Sentiment score: {context.sentiment_summary.score:.3f}")
# Validate global news structure
assert context.symbol is None # Global news shouldn't have a symbol
assert len(context.articles) > 0
assert "categories" in context.metadata
print("✅ Global news validation passed")
return True
except Exception as e:
print(f"❌ Global news test failed: {e}")
return False
def test_offline_mode_with_real_repository():
"""Test NewsService in offline mode with real repository."""
print("\n💾 Testing NewsService - Offline Mode")
# Create service in offline mode (no clients)
real_repo = NewsRepository("test_data")
service = NewsService(
finnhub_client=None,
google_client=None,
repository=real_repo,
online_mode=False,
data_dir="test_data",
)
try:
# Test offline context (will likely return empty data)
context = service.get_company_news_context(
symbol="AAPL", start_date="2024-01-01", end_date="2024-01-05"
)
print(f"✅ Offline context created: {context.__class__.__name__}")
print(f" Symbol: {context.symbol}")
print(f" Articles: {len(context.articles)}")
print(f" Data quality: {context.metadata.get('data_quality')}")
print(f" Service mode: online={service.is_online()}")
# Should handle empty data gracefully
assert context.symbol == "AAPL"
assert isinstance(context.articles, list)
assert isinstance(context.sentiment_summary, SentimentScore)
assert "data_quality" in context.metadata
print("✅ Offline mode graceful handling verified")
return True
except Exception as e:
print(f"❌ Offline mode test failed: {e}")
return False
def test_sentiment_analysis():
"""Test sentiment analysis functionality."""
print("\n😊 Testing Sentiment Analysis")
# Create service with custom articles for sentiment testing
class SentimentTestClient(BaseClient):
def test_connection(self):
return True
def get_data(self, query, start_date, end_date, **kwargs):
return {
"query": query,
"articles": [
{
"headline": f"{query} Soars on Excellent Earnings Report",
"summary": "Great performance with strong growth and positive outlook for investors.",
"source": "Positive News",
"date": start_date,
"entities": [query],
},
{
"headline": f"{query} Faces Challenges in Market Downturn",
"summary": "Concerns about declining revenue and poor market conditions affecting performance.",
"source": "Negative News",
"date": end_date,
"entities": [query],
},
],
}
sentiment_client = SentimentTestClient()
service = NewsService(
finnhub_client=None,
google_client=sentiment_client,
repository=None,
online_mode=True,
)
try:
context = service.get_context(
"TEST", "2024-01-01", "2024-01-02", sources=["google"]
)
print("✅ Sentiment analysis completed")
print(f" Articles analyzed: {len(context.articles)}")
print(f" Overall sentiment: {context.sentiment_summary.score:.3f}")
print(f" Confidence: {context.sentiment_summary.confidence:.3f}")
print(f" Label: {context.sentiment_summary.label}")
# Validate sentiment processing
assert len(context.articles) == 2
assert (
context.sentiment_summary.score >= -1.0
and context.sentiment_summary.score <= 1.0
)
assert (
context.sentiment_summary.confidence >= 0.0
and context.sentiment_summary.confidence <= 1.0
)
assert context.sentiment_summary.label in ["positive", "negative", "neutral"]
# Check individual article sentiments
for article in context.articles:
if article.sentiment:
assert (
article.sentiment.score >= -1.0 and article.sentiment.score <= 1.0
)
print("✅ Sentiment analysis validation passed")
return True
except Exception as e:
print(f"❌ Sentiment analysis test failed: {e}")
return False
def test_multiple_source_aggregation():
"""Test aggregation from multiple news sources."""
print("\n🔄 Testing Multiple Source Aggregation")
mock_finnhub = MockFinnhubClient()
mock_google = MockGoogleNewsClient()
real_repo = NewsRepository("test_data")
service = NewsService(
finnhub_client=mock_finnhub,
google_client=mock_google,
repository=real_repo,
online_mode=True,
)
try:
# Test with both sources
context = service.get_context(
query="MSFT",
start_date="2024-01-01",
end_date="2024-01-03",
symbol="MSFT",
sources=["finnhub", "google"],
)
print("✅ Multi-source aggregation completed")
print(f" Total articles: {len(context.articles)}")
print(f" Unique sources: {context.sources}")
print(f" Sources used: {context.metadata.get('sources_used', [])}")
# Should have articles from both sources
assert len(context.articles) > 0
assert len(context.sources) > 0
# Check that articles from different sources are present
source_counts = {}
for article in context.articles:
source = article.source
source_counts[source] = source_counts.get(source, 0) + 1
print(f" Source distribution: {source_counts}")
print("✅ Multi-source aggregation validated")
return True
except Exception as e:
print(f"❌ Multi-source test failed: {e}")
return False
def test_json_structure_validation():
"""Test detailed JSON structure validation."""
print("\n📄 Testing JSON Structure")
mock_google = MockGoogleNewsClient()
service = NewsService(
finnhub_client=None,
google_client=mock_google,
repository=None,
online_mode=True,
)
try:
context = service.get_context(
"TSLA", "2024-01-01", "2024-01-03", sources=["google"]
)
json_str = context.model_dump_json(indent=2)
data = json.loads(json_str)
# Validate required structure
required_fields = [
"symbol",
"period",
"articles",
"sentiment_summary",
"article_count",
"sources",
"metadata",
]
for field in required_fields:
assert field in data, f"Missing field: {field}"
# Validate period structure
period = data["period"]
assert "start" in period and "end" in period
# Validate articles structure
assert isinstance(data["articles"], list)
if data["articles"]:
first_article = data["articles"][0]
required_article_fields = ["headline", "source", "date"]
for field in required_article_fields:
assert field in first_article, f"Missing article field: {field}"
# Validate sentiment structure
sentiment = data["sentiment_summary"]
assert (
"score" in sentiment and "confidence" in sentiment and "label" in sentiment
)
assert -1.0 <= sentiment["score"] <= 1.0
assert 0.0 <= sentiment["confidence"] <= 1.0
# Validate metadata
metadata = data["metadata"]
assert "data_quality" in metadata
assert "service" in metadata
print("✅ JSON structure validation passed")
print(f" Fields: {list(data.keys())}")
print(f" Articles: {len(data['articles'])}")
print(f" Sentiment score: {sentiment['score']:.3f}")
return True
except Exception as e:
print(f"❌ JSON structure test failed: {e}")
return False
def test_force_refresh_parameter():
"""Test the force_refresh parameter functionality."""
print("\n🔄 Testing Force Refresh Parameter")
try:
mock_google = MockGoogleNewsClient()
real_repo = NewsRepository("test_data")
service = NewsService(
finnhub_client=None,
google_client=mock_google,
repository=real_repo,
online_mode=True,
)
# Test normal flow (should use repository if available)
normal_context = service.get_context(
"AAPL", "2024-01-01", "2024-01-31", sources=["google"], force_refresh=False
)
# Test force refresh (should bypass repository and use client)
refresh_context = service.get_context(
"AAPL", "2024-01-01", "2024-01-31", sources=["google"], force_refresh=True
)
# Both should return valid contexts
assert isinstance(normal_context, NewsContext)
assert isinstance(refresh_context, NewsContext)
assert normal_context.symbol == "AAPL"
assert refresh_context.symbol == "AAPL"
# Check metadata indicates source
refresh_metadata = refresh_context.metadata
assert "force_refresh" in refresh_metadata
assert refresh_metadata["force_refresh"]
print("✅ Force refresh parameter test passed")
return True
except Exception as e:
print(f"❌ Force refresh test failed: {e}")
return False
def test_local_first_strategy():
"""Test that the service checks local data first when available."""
print("\n🏠 Testing Local-First Strategy")
try:
class MockRepositoryWithData(NewsRepository):
def has_data_for_period(
self, identifier: str, start_date: str, end_date: str, **kwargs
) -> bool:
return True # Pretend we have the data
def get_data(
self, query: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
return {
"query": kwargs.get("query", "TEST"),
"symbol": kwargs.get("symbol"),
"articles": [
{
"headline": "Test Article from Local Cache",
"summary": "This article came from local repository",
"source": "Local Cache",
"date": "2024-01-01",
"url": "https://local.cache/test",
"entities": ["TEST"],
}
],
"metadata": {"source": "test_repository"},
}
mock_client = MockGoogleNewsClient()
mock_repo = MockRepositoryWithData("test_data")
service = NewsService(
finnhub_client=None,
google_client=mock_client,
repository=mock_repo,
online_mode=True,
)
# Should use local data since repository has_data_for_period returns True
context = service.get_context(
"TEST", "2024-01-01", "2024-01-31", sources=["google"]
)
# Verify we used local data
assert context.metadata.get("data_source") == "local_cache"
assert len(context.articles) == 1 # From mock repository
assert context.articles[0].headline == "Test Article from Local Cache"
print("✅ Local-first strategy test passed")
return True
except Exception as e:
print(f"❌ Local-first strategy test failed: {e}")
return False
def test_local_first_fallback_to_api():
"""Test that service falls back to API when local data is insufficient."""
print("\n🔄 Testing Local-First Fallback to API")
try:
class MockRepositoryWithoutData(NewsRepository):
def has_data_for_period(
self, identifier: str, start_date: str, end_date: str, **kwargs
) -> bool:
return False # Pretend we don't have the data
def get_data(
self, query: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
return {
"query": kwargs.get("query", "TEST"),
"articles": [],
"metadata": {},
}
def store_data(
self,
symbol: str,
data: dict[str, Any],
overwrite: bool = False,
**kwargs,
) -> bool:
return True # Pretend storage was successful
mock_client = MockGoogleNewsClient()
mock_repo = MockRepositoryWithoutData("test_data")
service = NewsService(
finnhub_client=None,
google_client=mock_client,
repository=mock_repo,
online_mode=True,
)
# Should fall back to API since repository doesn't have data
context = service.get_context(
"TEST", "2024-01-01", "2024-01-31", sources=["google"]
)
# Verify we used API data
assert context.metadata.get("data_source") == "live_api"
assert len(context.articles) > 0 # From mock client
print("✅ Local-first fallback to API test passed")
return True
except Exception as e:
print(f"❌ Local-first fallback test failed: {e}")
return False
def test_force_refresh_bypasses_local_data():
"""Test that force_refresh=True bypasses local data even when available."""
print("\n⚡ Testing Force Refresh Bypasses Local Data")
try:
class MockRepositoryAlwaysHasData(NewsRepository):
def has_data_for_period(
self, identifier: str, start_date: str, end_date: str, **kwargs
) -> bool:
return True # Always claim we have data
def get_data(
self, query: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
return {
"query": kwargs.get("query", "TEST"),
"symbol": kwargs.get("symbol"),
"articles": [
{
"headline": "Old Cached Article",
"summary": "This is from local cache",
"source": "Local Cache",
"date": "2024-01-01",
"url": "https://cache.local/old",
"entities": ["TEST"],
}
],
"metadata": {"source": "local"},
}
def clear_data(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> bool:
return True
def store_data(
self,
symbol: str,
data: dict[str, Any],
overwrite: bool = False,
**kwargs,
) -> bool:
return True
mock_client = MockGoogleNewsClient()
mock_repo = MockRepositoryAlwaysHasData("test_data")
service = NewsService(
finnhub_client=None,
google_client=mock_client,
repository=mock_repo,
online_mode=True,
)
# Force refresh should bypass local data
context = service.get_context(
"TEST", "2024-01-01", "2024-01-31", sources=["google"], force_refresh=True
)
# Verify we used API data (force refresh)
assert context.metadata.get("data_source") == "live_api_refresh"
assert context.metadata.get("force_refresh")
# Should have fresh data from client, not the old cached article
assert len(context.articles) > 1 # Client returns multiple articles
print("✅ Force refresh bypasses local data test passed")
return True
except Exception as e:
print(f"❌ Force refresh bypass test failed: {e}")
return False
def main():
"""Run all NewsService tests."""
print("🧪 Testing NewsService\n")
tests = [
test_online_mode_with_mock_clients,
test_global_news_context,
test_offline_mode_with_real_repository,
test_sentiment_analysis,
test_multiple_source_aggregation,
test_json_structure_validation,
test_force_refresh_parameter,
test_local_first_strategy,
test_local_first_fallback_to_api,
test_force_refresh_bypasses_local_data,
]
passed = 0
failed = 0
for test in tests:
try:
if test():
passed += 1
else:
failed += 1
except Exception as e:
print(f"❌ Test {test.__name__} crashed: {e}")
failed += 1
print("\n📊 NewsService Test Results:")
print(f" ✅ Passed: {passed}")
print(f" ❌ Failed: {failed}")
if failed == 0:
print("🎉 All NewsService tests passed!")
else:
print("⚠️ Some tests failed - check output above")
return failed == 0
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

View File

@ -7,6 +7,8 @@ from dataclasses import dataclass
from enum import Enum
from typing import Any
from tradingagents.config import TradingAgentsConfig
from .reddit_client import RedditClient
from .social_media_repository import SocialMediaRepository
@ -111,6 +113,12 @@ class SocialMediaService:
self.reddit_client = reddit_client
self.repository = repository
@staticmethod
def build(_config: TradingAgentsConfig):
client = RedditClient()
repo = SocialMediaRepository("")
return SocialMediaService(client, repo)
def get_context(
self,
query: str,
@ -134,74 +142,29 @@ class SocialMediaService:
SocialContext with posts and sentiment analysis
"""
posts = []
error_info = {}
data_source = "unknown"
try:
# Local-first data strategy with force refresh option
if force_refresh:
# Skip local data, fetch fresh from APIs
posts, data_source = self._fetch_and_cache_fresh_social_data(
query, start_date, end_date, symbol, subreddits
)
else:
# Check local data first, fetch missing if needed
posts, data_source = self._get_social_data_local_first(
query, start_date, end_date, symbol, subreddits
)
except Exception as e:
logger.error(f"Error fetching social media data: {e}")
error_info = {"error": str(e)}
# Calculate sentiment and engagement metrics
sentiment_summary = self._calculate_sentiment(posts)
engagement_metrics = self._calculate_engagement_metrics(posts)
# Determine data quality based on data source
data_quality = self._determine_data_quality(
data_source=data_source,
record_count=len(posts),
has_errors=bool(error_info),
)
# Create structured engagement metrics
structured_metrics = EngagementMetrics(
total_engagement=float(engagement_metrics.get("total_engagement", 0)),
average_engagement=float(engagement_metrics.get("average_engagement", 0)),
max_engagement=float(engagement_metrics.get("max_engagement", 0)),
total_posts=int(engagement_metrics.get("total_posts", 0)),
total_engagement=0.0,
average_engagement=0.0,
max_engagement=0.0,
total_posts=0,
)
# Separate non-float metrics for metadata
metadata_info = {
k: v
for k, v in engagement_metrics.items()
if k
not in [
"total_engagement",
"average_engagement",
"max_engagement",
"total_posts",
]
}
return SocialContext(
symbol=symbol,
period=(start_date, end_date),
posts=posts,
engagement_metrics=structured_metrics,
sentiment_summary=sentiment_summary,
sentiment_summary=SentimentScore(score=0, confidence=0, label=""),
post_count=len(posts),
platforms=["reddit"],
metadata={
"data_quality": data_quality,
"service": "social_media",
"subreddits": subreddits or [],
"data_source": data_source,
"force_refresh": force_refresh,
**metadata_info,
**error_info,
},
)
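The new `build()` factory wires default collaborators so callers no longer construct the client and repository by hand, and this commit also stubs sentiment and engagement to zeroed placeholders, so downstream consumers see neutral values until the real calculations return. A usage sketch, assuming `TradingAgentsConfig` is default-constructible and the keyword names follow the signature shown above:

```python
from tradingagents.config import TradingAgentsConfig
from tradingagents.domains.socialmedia.social_media_service import SocialMediaService

config = TradingAgentsConfig()  # assumption: default-constructible
service = SocialMediaService.build(config)

context = service.get_context(
    query="NVDA",
    start_date="2024-01-01",
    end_date="2024-01-31",
)
# While stubbed, sentiment_summary is SentimentScore(score=0, confidence=0, label="")
print(context.post_count, context.sentiment_summary)
```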

View File

@ -1,378 +0,0 @@
#!/usr/bin/env python3
"""
Test SocialMediaService with mock RedditClient and real SocialRepository.
"""
import os
import sys
from datetime import datetime, timedelta
from typing import Any
# Add the project root to the path
sys.path.insert(0, os.path.abspath("."))
from tradingagents.clients.base import BaseClient
from tradingagents.domains.socialmedia.social_media_service import (
DataQuality,
PostData,
SentimentScore,
SocialContext,
SocialMediaService,
)
from tradingagents.repositories.social_repository import SocialRepository
class MockRedditClient(BaseClient):
"""Mock Reddit client that returns sample social media data."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connection_works = True
def test_connection(self) -> bool:
return self.connection_works
def get_data(self, *args, **kwargs) -> dict[str, Any]:
"""Not used directly by SocialMediaService."""
return {}
def search_posts(
self,
query: str,
subreddit_names: list[str],
limit: int = 25,
time_filter: str = "week",
) -> list[dict[str, Any]]:
"""Return mock Reddit search results."""
posts = []
# Use fixed dates that will work with our date filter
base_date = datetime(2024, 1, 2) # Within our test range
for i, subreddit in enumerate(
subreddit_names[:2]
): # Limit to 2 subreddits for testing
posts.extend(
[
{
"title": f"{query} to the moon! 🚀",
"content": f"DD on {query}: Strong fundamentals, great earnings beat. Buy and hold!",
"url": f"https://reddit.com/r/{subreddit}/post1",
"upvotes": 1500 - (i * 100),
"score": 1450 - (i * 100),
"num_comments": 234,
"created_utc": (base_date + timedelta(hours=i)).timestamp(),
"subreddit": subreddit,
"author": f"WSBtrader{i}",
"posted_date": (base_date + timedelta(hours=i)).strftime(
"%Y-%m-%d"
),
},
{
"title": f"Why I'm bearish on {query}",
"content": f"Overvalued, competition increasing, margins declining. Time to sell {query}.",
"url": f"https://reddit.com/r/{subreddit}/post2",
"upvotes": 800 - (i * 50),
"score": 750 - (i * 50),
"num_comments": 156,
"created_utc": (base_date + timedelta(hours=i + 1)).timestamp(),
"subreddit": subreddit,
"author": f"BearishTrader{i}",
"posted_date": (base_date + timedelta(hours=i + 1)).strftime(
"%Y-%m-%d"
),
},
]
)
return posts
def get_top_posts(
self, subreddit_names: list[str], limit: int = 25, time_filter: str = "week"
) -> list[dict[str, Any]]:
"""Return mock top posts from subreddits."""
posts = []
# Use fixed dates that will work with our date filter
base_date = datetime(2024, 1, 2) # Within our test range
for subreddit in subreddit_names[:2]:
posts.append(
{
"title": "Market Update: Tech stocks rally continues",
"content": "FAANG stocks leading the charge. SPY hit new ATH. Bull market confirmed.",
"url": f"https://reddit.com/r/{subreddit}/top1",
"upvotes": 2500,
"score": 2400,
"num_comments": 456,
"created_utc": base_date.timestamp(),
"subreddit": subreddit,
"author": "MarketWatcher",
"posted_date": base_date.strftime("%Y-%m-%d"),
}
)
return posts
def filter_posts_by_date(
self, posts: list[dict[str, Any]], start_date: str, end_date: str
) -> list[dict[str, Any]]:
"""Filter posts by date range."""
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
filtered = []
for post in posts:
if "posted_date" in post:
post_dt = datetime.strptime(post["posted_date"], "%Y-%m-%d")
if start_dt <= post_dt < end_dt:  # +1 day above makes end_date inclusive; strict < avoids an off-by-one
filtered.append(post)
return filtered
def test_online_mode_with_mock_reddit():
"""Test SocialMediaService in online mode with mock Reddit client."""
# Create mock client and real repository
mock_reddit = MockRedditClient()
real_repo = SocialRepository("test_data")
# Create service in online mode
service = SocialMediaService(
reddit_client=mock_reddit,
repository=real_repo,
online_mode=True,
data_dir="test_data",
)
# Test company-specific social context
context = service.get_company_social_context(
symbol="TSLA",
start_date="2024-01-01",
end_date="2024-01-05",
subreddits=["wallstreetbets", "stocks"],
force_refresh=True,
)
# Validate context structure
assert isinstance(context, SocialContext)
assert context.symbol == "TSLA"
assert context.period["start"] == "2024-01-01"
assert context.period["end"] == "2024-01-05"
assert len(context.posts) > 0
assert isinstance(context.sentiment_summary, SentimentScore)
assert context.post_count == len(context.posts)
assert "data_quality" in context.metadata
# Test JSON serialization
json_output = context.model_dump_json(indent=2)
assert len(json_output) > 0
# Validate individual posts
for post in context.posts:
assert isinstance(post, PostData)
assert post.title
assert post.author
assert post.date
assert post.score >= 0
def test_global_social_trends():
"""Test global social media trends functionality."""
mock_reddit = MockRedditClient()
real_repo = SocialRepository("test_data")
service = SocialMediaService(
reddit_client=mock_reddit, repository=real_repo, online_mode=True
)
# Test global trends
context = service.get_global_trends(
start_date="2024-01-01",
end_date="2024-01-03",
subreddits=["investing", "stocks", "wallstreetbets"],
force_refresh=True,
)
# Validate global context
assert context.symbol is None # Global trends have no specific symbol
assert len(context.posts) > 0
assert "reddit" in context.platforms
assert "subreddits" in context.metadata
def test_sentiment_analysis():
"""Test sentiment analysis on social posts."""
# Create service with posts that have clear sentiment
class SentimentTestClient(MockRedditClient):
def search_posts(self, query, subreddit_names, limit=25, time_filter="week"):
return [
{
"title": f"{query} is the best investment ever! 🚀🚀🚀",
"content": "Amazing earnings, incredible growth, bullish AF!",
"upvotes": 5000,
"score": 4900,
"num_comments": 500,
"subreddit": "wallstreetbets",
"author": "BullishTrader",
"posted_date": "2024-01-01",
},
{
"title": f"WARNING: {query} is about to crash hard",
"content": "Terrible fundamentals, overvalued, sell now before it's too late!",
"upvotes": 100,
"score": 50,
"num_comments": 30,
"subreddit": "stocks",
"author": "BearishAnalyst",
"posted_date": "2024-01-01",
},
]
sentiment_client = SentimentTestClient()
service = SocialMediaService(
reddit_client=sentiment_client, repository=None, online_mode=True
)
context = service.get_context("GME", "2024-01-01", "2024-01-02")
# Check sentiment analysis
assert context.sentiment_summary.score != 0 # Should have some sentiment
assert context.sentiment_summary.confidence > 0
assert context.sentiment_summary.label in ["positive", "negative", "neutral"]
# Check individual post sentiments
for post in context.posts:
if post.sentiment:
assert -1.0 <= post.sentiment.score <= 1.0
def test_offline_mode():
"""Test SocialMediaService in offline mode."""
real_repo = SocialRepository("test_data")
service = SocialMediaService(
reddit_client=None, repository=real_repo, online_mode=False
)
# Should handle offline gracefully
context = service.get_context("AAPL", "2024-01-01", "2024-01-05", symbol="AAPL")
assert context.symbol == "AAPL"
assert isinstance(context.posts, list)
assert context.metadata.get("data_quality") == DataQuality.LOW
def test_engagement_metrics():
"""Test calculation of engagement metrics."""
mock_reddit = MockRedditClient()
service = SocialMediaService(
reddit_client=mock_reddit, repository=None, online_mode=True
)
context = service.get_company_social_context(
symbol="NVDA",
start_date="2024-01-01",
end_date="2024-01-02",
subreddits=["nvidia", "stocks"],
)
# Check engagement metrics in the context
assert len(context.engagement_metrics) > 0
assert (
"total_engagement" in context.engagement_metrics
or "total_engagement" in context.metadata
)
# Verify post scores
for post in context.posts:
# Posts should have score and comments
assert post.score >= 0
assert post.comments >= 0
def test_subreddit_filtering():
"""Test filtering by specific subreddits."""
mock_reddit = MockRedditClient()
service = SocialMediaService(
reddit_client=mock_reddit, repository=None, online_mode=True
)
# Test with specific subreddits
context = service.get_company_social_context(
symbol="AMD",
start_date="2024-01-01",
end_date="2024-01-02",
subreddits=["AMD_Stock", "wallstreetbets"],
)
# Check that posts are from requested subreddits
subreddit_set = set()
for post in context.posts:
if post.subreddit:
subreddit_set.add(post.subreddit)
assert len(subreddit_set) > 0
assert all(sub in ["AMD_Stock", "wallstreetbets"] for sub in subreddit_set)
def test_error_handling():
"""Test error handling with broken client."""
class BrokenRedditClient(BaseClient):
def test_connection(self):
return False
def get_data(self, *args, **kwargs):
raise Exception("Reddit API error")
def search_posts(self, *args, **kwargs):
raise Exception("Reddit API error")
def get_top_posts(self, *args, **kwargs):
raise Exception("Reddit API error")
broken_client = BrokenRedditClient()
service = SocialMediaService(
reddit_client=broken_client, repository=None, online_mode=True
)
# Should handle errors gracefully
context = service.get_context("TSLA", "2024-01-01", "2024-01-02", symbol="TSLA")
assert context.symbol == "TSLA"
assert len(context.posts) == 0
assert context.metadata.get("data_quality") == DataQuality.LOW
def test_json_structure():
"""Test JSON structure of social context."""
mock_reddit = MockRedditClient()
service = SocialMediaService(
reddit_client=mock_reddit, repository=None, online_mode=True
)
context = service.get_context("PLTR", "2024-01-01", "2024-01-02")
json_data = context.model_dump()
# Validate required fields
required_fields = [
"symbol",
"period",
"posts",
"sentiment_summary",
"post_count",
"platforms",
"metadata",
]
for field in required_fields:
assert field in json_data
# Validate posts structure
if json_data["posts"]:
first_post = json_data["posts"][0]
post_fields = ["title", "author", "date", "score"]
for field in post_fields:
assert field in first_post
# Validate sentiment structure
sentiment = json_data["sentiment_summary"]
assert "score" in sentiment
assert "confidence" in sentiment
assert "label" in sentiment

View File

@ -1,6 +1,6 @@
# TradingAgents/graph/conditional_logic.py
from tradingagents.agents.utils.agent_states import AgentState
from tradingagents.agents.libs.agent_states import AgentState
class ConditionalLogic:

View File

@ -2,7 +2,7 @@
from typing import Any
from tradingagents.agents.utils.agent_states import (
from tradingagents.agents.libs.agent_states import (
InvestDebateState,
RiskDebateState,
)

View File

@ -22,8 +22,8 @@ from tradingagents.agents import (
create_social_media_analyst,
create_trader,
)
from tradingagents.agents.utils.agent_states import AgentState
from tradingagents.agents.utils.agent_utils import Toolkit
from tradingagents.agents.libs.agent_states import AgentState
from tradingagents.agents.libs.agent_toolkit import AgentToolkit
from .conditional_logic import ConditionalLogic
@ -35,7 +35,7 @@ class GraphSetup:
self,
quick_thinking_llm: ChatOpenAI | ChatAnthropic | ChatGoogleGenerativeAI,
deep_thinking_llm: ChatOpenAI | ChatAnthropic | ChatGoogleGenerativeAI,
toolkit: Toolkit,
toolkit: AgentToolkit,
tool_nodes: dict[str, ToolNode],
bull_memory,
bear_memory,

View File

@ -9,9 +9,16 @@ from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode
from tradingagents.agents.utils.agent_utils import Toolkit
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.libs.agent_toolkit import AgentToolkit
from tradingagents.agents.libs.memory import FinancialSituationMemory
from tradingagents.config import TradingAgentsConfig
from tradingagents.domains.marketdata.fundamental_data_service import (
FundamentalDataService,
)
from tradingagents.domains.marketdata.insider_data_service import InsiderDataService
from tradingagents.domains.marketdata.market_data_service import MarketDataService
from tradingagents.domains.news.news_service import NewsService
from tradingagents.domains.socialmedia.social_media_service import SocialMediaService
from .conditional_logic import ConditionalLogic
from .propagation import Propagator
@ -77,7 +84,18 @@ class TradingAgentsGraph:
else:
raise ValueError(f"Unsupported LLM provider: {self.config.llm_provider}")
self.toolkit = Toolkit(config=self.config)
news_service = NewsService.build(self.config)
social_media_service = SocialMediaService.build(self.config)
market_data_service = MarketDataService.build(self.config)
fundamental_data_service = FundamentalDataService.build(self.config)
insider_data_service = InsiderDataService.build(self.config)
self.toolkit = AgentToolkit(
news_service,
social_media_service,
market_data_service,
fundamental_data_service,
insider_data_service,
)
# Initialize memories
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
@ -125,42 +143,28 @@ class TradingAgentsGraph:
return {
"market": ToolNode(
[
# online tools
self.toolkit.get_YFin_data_online,
self.toolkit.get_stockstats_indicators_report_online,
# offline tools
self.toolkit.get_YFin_data,
self.toolkit.get_stockstats_indicators_report,
self.toolkit.get_market_data,
self.toolkit.get_ta_report,
]
),
"social": ToolNode(
[
# online tools
self.toolkit.get_stock_news_openai,
# offline tools
self.toolkit.get_reddit_stock_info,
self.toolkit.get_socialmedia_stock_info,
]
),
"news": ToolNode(
[
# online tools
self.toolkit.get_global_news_openai,
self.toolkit.get_google_news,
# offline tools
self.toolkit.get_finnhub_news,
self.toolkit.get_reddit_news,
self.toolkit.get_global_news,
self.toolkit.get_news,
]
),
"fundamentals": ToolNode(
[
# online tools
self.toolkit.get_fundamentals_openai,
# offline tools
self.toolkit.get_finnhub_company_insider_sentiment,
self.toolkit.get_finnhub_company_insider_transactions,
self.toolkit.get_simfin_balance_sheet,
self.toolkit.get_simfin_cashflow,
self.toolkit.get_simfin_income_stmt,
self.toolkit.get_insider_sentiment,
self.toolkit.get_insider_transactions,
self.toolkit.get_balance_sheet,
self.toolkit.get_cashflow,
self.toolkit.get_income_stmt,
]
),
}
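The diff above collapses the per-provider online/offline tool pairs into single provider-agnostic toolkit methods, with each domain service assembled via its `build()` factory. A sketch of the resulting wiring, assuming `TradingAgentsConfig` is default-constructible (imports and the tool grouping mirror the hunks above):

```python
from langgraph.prebuilt import ToolNode

from tradingagents.agents.libs.agent_toolkit import AgentToolkit
from tradingagents.config import TradingAgentsConfig
from tradingagents.domains.marketdata.fundamental_data_service import (
    FundamentalDataService,
)
from tradingagents.domains.marketdata.insider_data_service import InsiderDataService
from tradingagents.domains.marketdata.market_data_service import MarketDataService
from tradingagents.domains.news.news_service import NewsService
from tradingagents.domains.socialmedia.social_media_service import SocialMediaService

config = TradingAgentsConfig()  # assumption: default-constructible
toolkit = AgentToolkit(
    NewsService.build(config),
    SocialMediaService.build(config),
    MarketDataService.build(config),
    FundamentalDataService.build(config),
    InsiderDataService.build(config),
)

# Each graph node now binds one small, provider-agnostic tool set:
news_node = ToolNode([toolkit.get_global_news, toolkit.get_news])
```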

112
uv.lock
View File

@ -633,6 +633,17 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/32/b6/7517af5234378518f27ad35a7b24af9591bc500b8c1780929c1295999eb6/fastapi-0.115.9-py3-none-any.whl", hash = "sha256:4a439d7923e4de796bcc88b64e9754340fcd1574673cbd865ba8a99fe0d28c56", size = 94919, upload-time = "2025-02-27T16:43:40.537Z" },
]
[[package]]
name = "feedfinder2"
version = "0.0.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "beautifulsoup4" },
{ name = "requests" },
{ name = "six" },
]
sdist = { url = "https://files.pythonhosted.org/packages/35/82/1251fefec3bb4b03fd966c7e7f7a41c9fc2bb00d823a34c13f847fd61406/feedfinder2-0.0.4.tar.gz", hash = "sha256:3701ee01a6c85f8b865a049c30ba0b4608858c803fe8e30d1d289fdbe89d0efe", size = 3297, upload-time = "2016-01-25T15:09:17.492Z" }
[[package]]
name = "feedparser"
version = "6.0.11"
@ -1038,6 +1049,12 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" },
]
[[package]]
name = "jieba3k"
version = "0.35.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a9/cb/2c8332bcdc14d33b0bedd18ae0a4981a069c3513e445120da3c3f23a8aaa/jieba3k-0.35.1.zip", hash = "sha256:980a4f2636b778d312518066be90c7697d410dd5a472385f5afced71a2db1c10", size = 7423646, upload-time = "2014-11-15T05:47:47.978Z" }
[[package]]
name = "jinja2"
version = "3.1.6"
@ -1095,6 +1112,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" },
]
[[package]]
name = "joblib"
version = "1.5.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/dc/fe/0f5a938c54105553436dbff7a61dc4fed4b1b2c98852f8833beaf4d5968f/joblib-1.5.1.tar.gz", hash = "sha256:f4f86e351f39fe3d0d32a9f2c3d8af1ee4cec285aafcb27003dda5205576b444", size = 330475, upload-time = "2025-05-23T12:04:37.097Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7d/4f/1195bbac8e0c2acc5f740661631d8d750dc38d4a32b23ee5df3cde6f4e0d/joblib-1.5.1-py3-none-any.whl", hash = "sha256:4719a31f054c7d766948dcd83e9613686b27114f190f717cec7eaa2084f8a74a", size = 307746, upload-time = "2025-05-23T12:04:35.124Z" },
]
[[package]]
name = "jsonpatch"
version = "1.33"
@ -1673,6 +1699,45 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" },
]
[[package]]
name = "newspaper3k"
version = "0.2.8"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "beautifulsoup4" },
{ name = "cssselect" },
{ name = "feedfinder2" },
{ name = "feedparser" },
{ name = "jieba3k" },
{ name = "lxml" },
{ name = "nltk" },
{ name = "pillow" },
{ name = "python-dateutil" },
{ name = "pyyaml" },
{ name = "requests" },
{ name = "tinysegmenter" },
{ name = "tldextract" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ce/fb/8f8525be0cafa48926e85b0c06a7cb3e2a892d340b8036f8c8b1b572df1c/newspaper3k-0.2.8.tar.gz", hash = "sha256:9f1bd3e1fb48f400c715abf875cc7b0a67b7ddcd87f50c9aeeb8fcbbbd9004fb", size = 205685, upload-time = "2018-09-28T04:58:23.53Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d7/b9/51afecb35bb61b188a4b44868001de348a0e8134b4dfa00ffc191567c4b9/newspaper3k-0.2.8-py3-none-any.whl", hash = "sha256:44a864222633d3081113d1030615991c3dbba87239f6bbf59d91240f71a22e3e", size = 211132, upload-time = "2018-09-28T04:58:18.847Z" },
]
[[package]]
name = "nltk"
version = "3.9.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
{ name = "joblib" },
{ name = "regex" },
{ name = "tqdm" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3c/87/db8be88ad32c2d042420b6fd9ffd4a149f9a0d7f0e86b3f543be2eeeedd2/nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868", size = 2904691, upload-time = "2024-08-18T19:48:37.769Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442, upload-time = "2024-08-18T19:48:21.909Z" },
]
[[package]]
name = "nodeenv"
version = "1.9.1"
@ -3058,6 +3123,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c", size = 64847, upload-time = "2025-06-09T16:43:05.728Z" },
]
[[package]]
name = "requests-file"
version = "2.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "requests" },
]
sdist = { url = "https://files.pythonhosted.org/packages/72/97/bf44e6c6bd8ddbb99943baf7ba8b1a8485bcd2fe0e55e5708d7fee4ff1ae/requests_file-2.1.0.tar.gz", hash = "sha256:0f549a3f3b0699415ac04d167e9cb39bccfb730cb832b4d20be3d9867356e658", size = 6891, upload-time = "2024-05-21T16:28:00.24Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d7/25/dd878a121fcfdf38f52850f11c512e13ec87c2ea72385933818e5b6c15ce/requests_file-2.1.0-py2.py3-none-any.whl", hash = "sha256:cf270de5a4c5874e84599fc5778303d496c10ae5e870bfa378818f35d21bda5c", size = 4244, upload-time = "2024-05-21T16:27:57.733Z" },
]
[[package]]
name = "requests-oauthlib"
version = "2.0.0"
@ -3329,6 +3406,16 @@ version = "2.0.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8d/dd/d4dd75843692690d81f0a4b929212a1614b25d4896aa7c72f4c3546c7e3d/syncer-2.0.3.tar.gz", hash = "sha256:4340eb54b54368724a78c5c0763824470201804fe9180129daf3635cb500550f", size = 11512, upload-time = "2023-05-08T07:50:17.963Z" }
[[package]]
name = "ta-lib"
version = "0.6.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
{ name = "setuptools" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ba/97/a49816dd468a18ee080cf3a04640772a9f6321790d4049cece2490c4b7ad/ta_lib-0.6.4.tar.gz", hash = "sha256:08f55bc5771a6d1ceb1a2b713aad7b05f04eb0061e980c9113571c532d32e9cb", size = 381774, upload-time = "2025-06-08T15:28:15.452Z" }
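The `ta-lib` entry backs the technical-analysis report tooling; note that this Python binding wraps the native TA-Lib C library, which must be installed on the machine for the sdist to build. A quick sketch of the binding's indicator calls on a float64 array:

```python
import numpy as np
import talib  # requires the native TA-Lib C library to be installed

close = np.random.random(100)            # stand-in for a closing-price series
sma = talib.SMA(close, timeperiod=30)    # simple moving average
rsi = talib.RSI(close, timeperiod=14)    # relative strength index
```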
[[package]]
name = "tenacity"
version = "9.1.2"
@ -3356,6 +3443,27 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/de/a8/8f499c179ec900783ffe133e9aab10044481679bb9aad78436d239eee716/tiktoken-0.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:5ea0edb6f83dc56d794723286215918c1cde03712cbbafa0348b33448faf5b95", size = 894669, upload-time = "2025-02-14T06:02:47.341Z" },
]
[[package]]
name = "tinysegmenter"
version = "0.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/17/82/86982e4b6d16e4febc79c2a1d68ee3b707e8a020c5d2bc4af8052d0f136a/tinysegmenter-0.3.tar.gz", hash = "sha256:ed1f6d2e806a4758a73be589754384cbadadc7e1a414c81a166fc9adf2d40c6d", size = 16893, upload-time = "2017-07-23T11:18:29.85Z" }
[[package]]
name = "tldextract"
version = "5.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock" },
{ name = "idna" },
{ name = "requests" },
{ name = "requests-file" },
]
sdist = { url = "https://files.pythonhosted.org/packages/97/78/182641ea38e3cfd56e9c7b3c0d48a53d432eea755003aa544af96403d4ac/tldextract-5.3.0.tar.gz", hash = "sha256:b3d2b70a1594a0ecfa6967d57251527d58e00bb5a91a74387baa0d87a0678609", size = 128502, upload-time = "2025-04-22T06:19:37.491Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/67/7c/ea488ef48f2f544566947ced88541bc45fae9e0e422b2edbf165ee07da99/tldextract-5.3.0-py3-none-any.whl", hash = "sha256:f70f31d10b55c83993f55e91ecb7c5d84532a8972f22ec578ecfbe5ea2292db2", size = 107384, upload-time = "2025-04-22T06:19:36.304Z" },
]
[[package]]
name = "tokenizers"
version = "0.21.1"
@ -3483,6 +3591,7 @@ dependencies = [
{ name = "langchain-google-genai" },
{ name = "langchain-openai" },
{ name = "langgraph" },
{ name = "newspaper3k" },
{ name = "pandas" },
{ name = "parsel" },
{ name = "praw" },
@ -3494,6 +3603,7 @@ dependencies = [
{ name = "rich" },
{ name = "setuptools" },
{ name = "stockstats" },
{ name = "ta-lib" },
{ name = "tqdm" },
{ name = "tushare" },
{ name = "typer" },
@ -3532,6 +3642,7 @@ requires-dist = [
{ name = "langchain-google-genai", specifier = ">=2.1.5" },
{ name = "langchain-openai", specifier = ">=0.3.23" },
{ name = "langgraph", specifier = ">=0.4.8" },
{ name = "newspaper3k", specifier = ">=0.2.8" },
{ name = "pandas", specifier = ">=2.3.0" },
{ name = "parsel", specifier = ">=1.10.0" },
{ name = "praw", specifier = ">=7.8.1" },
@ -3548,6 +3659,7 @@ requires-dist = [
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.8.0" },
{ name = "setuptools", specifier = ">=80.9.0" },
{ name = "stockstats", specifier = ">=0.6.5" },
{ name = "ta-lib", specifier = ">=0.4.28" },
{ name = "tqdm", specifier = ">=4.67.1" },
{ name = "tushare", specifier = ">=1.4.21" },
{ name = "typer", specifier = ">=0.12.0" },