omfg basic stub out with passable lint and types

parent c93ffb6452
commit 775258b950

README.md: 343 lines changed

@@ -418,6 +418,42 @@ LangGraph-based workflow management:
- **Constants**: UPPER_CASE (e.g., `DEFAULT_CONFIG`)
- **Imports**: Standard library first, then third-party, then local imports (langchain, tradingagents modules)

#### Data Structure Guidelines

**MANDATORY: Always use dataclasses for method returns**

- **Never return**: `dict`, `str`, `Any`, or unstructured data from public methods
- **Always return**: Properly typed dataclasses with clear field definitions
- **Rationale**: Provides type safety, IDE support, clear contracts, and prevents runtime errors

**Examples**:

```python
# ❌ BAD - Dictionary returns
def update_news() -> dict[str, Any]:
    return {"status": "completed", "count": 5}


# ✅ GOOD - Dataclass returns
@dataclass
class NewsUpdateResult:
    status: str
    articles_found: int
    articles_scraped: int
    articles_failed: int


def update_news() -> NewsUpdateResult:
    return NewsUpdateResult(
        status="completed",
        articles_found=10,
        articles_scraped=8,
        articles_failed=2,
    )
```

**Dataclass Best Practices**:
- Use the `@dataclass` decorator for all return value structures
- Include type hints for all fields
- Use `| None` for optional fields (modern Python 3.10+ syntax)
- Group related dataclasses in the same module
- Prefer immutable dataclasses with `frozen=True` for value objects (see the sketch below)
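A minimal sketch combining these practices; the class and field names here are illustrative, not from the codebase:

```python
from dataclasses import dataclass


@dataclass(frozen=True)  # immutable value object
class QuoteSnapshot:
    symbol: str
    price: float
    previous_close: float | None = None  # optional field via `| None`


snapshot = QuoteSnapshot(symbol="AAPL", price=189.95)
# snapshot.price = 0.0  # would raise dataclasses.FrozenInstanceError
```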

#### Ruff Formatting & Linting Rules

**Formatting** (`mise run format`):
- **Line length**: 88 characters maximum

@@ -534,6 +570,313 @@ service.update_market_data("AAPL", "2024-01-01", "2024-01-31")
- Questionnaire-driven configuration collection
- Real-time streaming of analysis results

### Progressive Development Framework

This framework ensures agents create type-safe, testable code through incremental development. It emphasizes building one component at a time, with proper testing and type safety at each step.

#### Core Principles

1. **Service-First Development**: Start with the business logic in the service layer
2. **Stub Dependencies**: Create placeholder methods that return proper dataclasses
3. **Progressive Implementation**: Implement one dependency (client OR repository) at a time
4. **Constructor Injection**: Pass dependencies through the constructor for testability
5. **Dataclass Returns**: All public methods return properly typed dataclasses
6. **Test-Driven Development**: Write tests first, then implement to make them pass

#### Development Process

**Step 1: Design Domain Models**

```python
# models.py - Define all dataclasses first
from dataclasses import dataclass
from datetime import datetime
from typing import Any


@dataclass
class DomainEntity:
    id: str
    name: str
    created_at: datetime


@dataclass
class DomainContext:
    entities: list[DomainEntity]
    metadata: dict[str, Any]
    quality_score: float


@dataclass
class UpdateResult:
    status: str
    entities_processed: int
    entities_failed: int
```

**Step 2: Create Service with Business Logic**

```python
# service.py - Main business logic with stub dependencies
# (imports assume the flat module layout shown under "Directory Structure" below)
from datetime import datetime
from typing import Any

from client import DomainClient
from models import DomainContext, DomainEntity, UpdateResult
from repository import DomainRepository


class DomainService:
    def __init__(self, client: DomainClient, repository: DomainRepository):
        self.client = client
        self.repository = repository

    def get_context(self, symbol: str, start_date: str, end_date: str) -> DomainContext:
        # Implement the business logic flow
        entities = self.repository.get_entities(symbol, start_date, end_date)

        # Process and transform data
        processed_entities = self._process_entities(entities)

        # Calculate quality metrics
        quality_score = self._calculate_quality(processed_entities)

        return DomainContext(
            entities=processed_entities,
            metadata={"symbol": symbol, "date_range": f"{start_date} to {end_date}"},
            quality_score=quality_score,
        )

    def update_data(self, symbol: str, start_date: str, end_date: str) -> UpdateResult:
        # Business logic for updating data
        raw_data = self.client.fetch_data(symbol, start_date, end_date)
        entities = self._transform_raw_data(raw_data)

        processed = 0
        failed = 0
        for entity in entities:
            try:
                self.repository.save_entity(entity)
                processed += 1
            except Exception:
                failed += 1

        return UpdateResult(
            status="completed",
            entities_processed=processed,
            entities_failed=failed,
        )

    def _transform_raw_data(self, raw_data: list[dict[str, Any]]) -> list[DomainEntity]:
        # Helper referenced by update_data (missing above; minimal assumed implementation)
        return [
            DomainEntity(
                id=item["id"],
                name=item["name"],
                created_at=datetime.fromisoformat(item["created_at"].replace("Z", "+00:00")),
            )
            for item in raw_data
        ]

    def _process_entities(self, entities: list[DomainEntity]) -> list[DomainEntity]:
        # Private method for business logic
        return entities  # Stub implementation

    def _calculate_quality(self, entities: list[DomainEntity]) -> float:
        # Private method for quality calculation
        return 1.0  # Stub implementation
```

**Step 3: Create Stub Dependencies**

```python
# client.py - Stub client that returns realistic raw payloads
class DomainClient:
    def fetch_data(self, symbol: str, start_date: str, end_date: str) -> list[dict[str, Any]]:
        # Stub implementation - returns a realistic structure
        return [
            {"id": "1", "name": f"{symbol}_entity", "created_at": "2024-01-01T00:00:00Z"},
            {"id": "2", "name": f"{symbol}_entity_2", "created_at": "2024-01-02T00:00:00Z"},
        ]


# repository.py - Stub repository that returns proper dataclasses
class DomainRepository:
    def __init__(self, cache_dir: str):
        self.cache_dir = cache_dir

    def get_entities(self, symbol: str, start_date: str, end_date: str) -> list[DomainEntity]:
        # Stub implementation - returns proper dataclasses
        return [
            DomainEntity(id="1", name=f"{symbol}_cached", created_at=datetime.now()),
            DomainEntity(id="2", name=f"{symbol}_cached_2", created_at=datetime.now()),
        ]

    def save_entity(self, entity: DomainEntity) -> None:
        # Stub implementation
        pass
```

**Step 4: Write Comprehensive Tests**

```python
# service_test.py - Test the service with mock dependencies
from datetime import datetime
from unittest.mock import Mock

from models import DomainContext, DomainEntity, UpdateResult
from service import DomainService


def test_get_context_with_mock_dependencies():
    """Test service business logic with mocked dependencies."""
    # Mock the dependencies
    mock_client = Mock()
    mock_repository = Mock()

    # Configure mock returns
    mock_repository.get_entities.return_value = [
        DomainEntity(id="1", name="TEST_entity", created_at=datetime(2024, 1, 1))
    ]

    # Create the service with mocks
    service = DomainService(client=mock_client, repository=mock_repository)

    # Exercise the business logic
    context = service.get_context("TEST", "2024-01-01", "2024-01-31")

    # Validate structure and business logic
    assert isinstance(context, DomainContext)
    assert context.metadata["symbol"] == "TEST"
    assert context.quality_score > 0
    assert len(context.entities) > 0

    # Verify the repository was called correctly
    mock_repository.get_entities.assert_called_once_with("TEST", "2024-01-01", "2024-01-31")


def test_update_data_with_mock_dependencies():
    """Test update business logic with mocked dependencies."""
    mock_client = Mock()
    mock_repository = Mock()

    # Configure the mock client to return raw data
    mock_client.fetch_data.return_value = [
        {"id": "1", "name": "TEST_raw", "created_at": "2024-01-01T00:00:00Z"}
    ]

    service = DomainService(client=mock_client, repository=mock_repository)

    result = service.update_data("TEST", "2024-01-01", "2024-01-31")

    # Validate business logic results
    assert isinstance(result, UpdateResult)
    assert result.status == "completed"
    assert result.entities_processed >= 0

    # Verify client and repository interactions
    mock_client.fetch_data.assert_called_once()
    mock_repository.save_entity.assert_called()
```

**Step 5: Implement One Dependency at a Time**

Choose either the client OR the repository to implement first:

```python
import json
from pathlib import Path

import requests


class DomainClientError(Exception):
    """Raised when the external API call fails."""


class DomainRepositoryError(Exception):
    """Raised when cached data cannot be loaded."""


# Option A: Implement the client first
class DomainClient:
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.session = requests.Session()
        self.session.headers.update({"User-Agent": "TradingAgents/1.0"})

    def fetch_data(self, symbol: str, start_date: str, end_date: str) -> list[dict[str, Any]]:
        # Real implementation with error handling
        try:
            response = self.session.get(
                f"https://api.example.com/data/{symbol}",
                params={"start": start_date, "end": end_date},
                timeout=30,
            )
            response.raise_for_status()
            return response.json()["data"]
        except requests.RequestException as e:
            raise DomainClientError(f"Failed to fetch data: {e}") from e


# Option B: Implement the repository first
class DomainRepository:
    def __init__(self, cache_dir: str):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def get_entities(self, symbol: str, start_date: str, end_date: str) -> list[DomainEntity]:
        # Real implementation with file I/O
        cache_file = self.cache_dir / f"{symbol}_{start_date}_{end_date}.json"

        if not cache_file.exists():
            return []

        try:
            with open(cache_file) as f:
                data = json.load(f)

            return [
                DomainEntity(
                    id=item["id"],
                    name=item["name"],
                    created_at=datetime.fromisoformat(item["created_at"]),
                )
                for item in data
            ]
        except (json.JSONDecodeError, KeyError) as e:
            raise DomainRepositoryError(f"Failed to load cached data: {e}") from e
```

**Step 6: Test Real Implementation**

```python
import tempfile

import responses


def test_real_client_integration():
    """Test the real client implementation."""
    client = DomainClient(api_key="test_key")

    # Test with real HTTP calls (or use the responses library for mocking)
    with responses.RequestsMock() as rsps:
        rsps.add(
            responses.GET,
            "https://api.example.com/data/TEST",
            json={"data": [{"id": "1", "name": "TEST", "created_at": "2024-01-01T00:00:00Z"}]},
            status=200,
        )

        result = client.fetch_data("TEST", "2024-01-01", "2024-01-31")

        assert len(result) == 1
        assert result[0]["id"] == "1"


def test_real_repository_integration():
    """Test the real repository implementation."""
    with tempfile.TemporaryDirectory() as temp_dir:
        repo = DomainRepository(temp_dir)

        # Test saving and loading
        entity = DomainEntity(id="1", name="TEST", created_at=datetime.now())
        repo.save_entity(entity)

        entities = repo.get_entities("TEST", "2024-01-01", "2024-01-31")
        assert len(entities) == 1
        assert entities[0].id == "1"
```

**Step 7: Iterate and Refine**

1. Run the tests after each implementation
2. Refactor business logic as needed
3. Add error handling and edge cases
4. Implement the remaining dependency
5. Add integration tests with both real dependencies, as sketched below
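A minimal sketch of item 5, wiring the real client and repository together; the `responses` stub and the asserted counts assume the raw-data transformation shown in Step 2:

```python
import tempfile

import responses


def test_service_with_both_real_dependencies():
    """End-to-end: real client (HTTP mocked) plus real repository on a temp dir."""
    with tempfile.TemporaryDirectory() as temp_dir, responses.RequestsMock() as rsps:
        rsps.add(
            responses.GET,
            "https://api.example.com/data/TEST",
            json={"data": [{"id": "1", "name": "TEST", "created_at": "2024-01-01T00:00:00"}]},
            status=200,
        )

        service = DomainService(
            client=DomainClient(api_key="test_key"),
            repository=DomainRepository(temp_dir),
        )

        result = service.update_data("TEST", "2024-01-01", "2024-01-31")

        assert result.status == "completed"
        assert result.entities_processed == 1
        assert result.entities_failed == 0
```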

#### Directory Structure

```
domain_name/
├── models.py          # Dataclasses only - no business logic
├── client.py          # External API integration
├── repository.py      # Data persistence and caching
├── service.py         # Main business logic coordinator
└── service_test.py    # Comprehensive test suite
```

#### Benefits of This Approach

1. **Type Safety**: All interfaces are defined upfront with dataclasses
2. **Testability**: Business logic is tested independently of external dependencies
3. **Incremental Development**: One component at a time reduces complexity
4. **Clear Contracts**: Dataclass returns make interfaces explicit
5. **Error Isolation**: Issues stay contained within single components
6. **Refactoring Safety**: The type system catches interface changes
7. **Documentation**: Dataclasses serve as living documentation

#### Anti-Patterns to Avoid

❌ **Don't return dictionaries or strings from public methods**
❌ **Don't implement all dependencies simultaneously**
❌ **Don't skip writing tests first**
❌ **Don't mix business logic with I/O operations**
❌ **Don't use inheritance for dependency injection** (see the sketch after this list)
❌ **Don't create circular dependencies between components**

✅ **Do use dataclasses for all return values**
✅ **Do implement one dependency at a time**
✅ **Do write tests before implementation**
✅ **Do separate business logic from I/O**
✅ **Do use constructor injection**
✅ **Do maintain clear separation of concerns**
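To make the injection point concrete, a minimal sketch; the names are illustrative, and `Protocol` is one way to type the injected dependency, not necessarily what the codebase uses:

```python
from typing import Protocol


class DataFetcher(Protocol):
    # Structural interface: any object with a matching fetch_data() conforms
    def fetch_data(self, symbol: str) -> list[dict]: ...


class ReportService:
    def __init__(self, fetcher: DataFetcher):  # injected, not inherited
        self.fetcher = fetcher

    def build_report(self, symbol: str) -> list[dict]:
        return self.fetcher.fetch_data(symbol)
```

Because the dependency arrives as a constructor parameter, a `Mock()` or stub slots in without touching any inheritance hierarchy.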

### File Structure Context
- **`cli/`**: Interactive command-line interface
- **`tradingagents/agents/`**: All agent implementations
@@ -32,6 +32,8 @@ dependencies = [
    "typer>=0.12.0",
    "typing-extensions>=4.14.0",
    "yfinance>=0.2.63",
    "TA-Lib>=0.4.28",
    "newspaper3k>=0.2.8",
]

[project.optional-dependencies]
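The two added dependencies presumably back the indicator and news-scraping work elsewhere in this commit; a minimal sketch of their standard public APIs (not code from this diff):

```python
import numpy as np
import talib  # TA-Lib>=0.4.28
from newspaper import Article  # newspaper3k>=0.2.8

# Technical indicator over a series of closing prices
close = np.random.random(100)
rsi = talib.RSI(close, timeperiod=14)

# Article download and parsing (URL is illustrative)
article = Article("https://example.com/some-news-story")
article.download()
article.parse()
print(article.title, article.text[:200])
```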
@@ -2,6 +2,9 @@ from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst
from .analysts.news_analyst import create_news_analyst
from .analysts.social_media_analyst import create_social_media_analyst
from .libs.agent_states import AgentState, InvestDebateState, RiskDebateState
from .libs.context_helpers import create_msg_delete
from .libs.memory import FinancialSituationMemory
from .managers.research_manager import create_research_manager
from .managers.risk_manager import create_risk_manager
from .researchers.bear_researcher import create_bear_researcher

@@ -10,26 +13,20 @@ from .risk_mgmt.aggresive_debator import create_risky_debator
from .risk_mgmt.conservative_debator import create_safe_debator
from .risk_mgmt.neutral_debator import create_neutral_debator
from .trader.trader import create_trader
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.agent_utils import Toolkit, create_msg_delete
from .utils.memory import FinancialSituationMemory
from .utils.service_toolkit import ServiceToolkit

__all__ = [
    "FinancialSituationMemory",
    "Toolkit",
    "ServiceToolkit",
    "AgentState",
    "create_msg_delete",
    "InvestDebateState",
    "RiskDebateState",
    "create_bear_researcher",
    "create_bull_researcher",
    "create_research_manager",
    "create_fundamentals_analyst",
    "create_market_analyst",
    "create_msg_delete",
    "create_neutral_debator",
    "create_news_analyst",
    "create_research_manager",
    "create_risky_debator",
    "create_risk_manager",
    "create_safe_debator",
@@ -1,6 +1,6 @@
import logging
import re
from datetime import datetime, timedelta
from datetime import date, datetime, timedelta
from typing import Annotated

from langchain_core.tools import tool
@@ -41,9 +41,9 @@ class AgentToolkit:
    def __init__(
        self,
        news_service: NewsService,
        socialmedia_service: SocialMediaService,
        marketdata_service: MarketDataService,
        fundamentaldata_service: FundamentalDataService,
        socialmedia_service: SocialMediaService,
        insiderdata_service: InsiderDataService,
        config: TradingAgentsConfig = DEFAULT_CONFIG,
    ):
@@ -102,8 +102,8 @@ class AgentToolkit:
            datetime.strptime(start_date, "%Y-%m-%d")
            datetime.strptime(end_date, "%Y-%m-%d")

            return self._news_service.get_context(
                query=ticker, start_date=start_date, end_date=end_date, symbol=ticker
            return self._news_service.get_company_news_context(
                symbol=ticker, start_date=start_date, end_date=end_date
            )
        except Exception as e:
            logger.error(f"Error getting news for {ticker}: {e}")
@@ -177,7 +177,7 @@ class AgentToolkit:
        curr_date: Annotated[
            str, "The current trading date you are trading on, YYYY-mm-dd"
        ],
        look_back_days: Annotated[int, "how many days to look back"] = None,
        look_back_days: Annotated[int, "how many days to look back"],
    ) -> TAReportContext:
        """
        Retrieve stock stats indicators for a given ticker symbol and indicator.
@@ -280,11 +280,11 @@ class AgentToolkit:
        Returns:
            BalanceSheetContext: Structured balance sheet analysis with key liquidity and debt metrics.
        """
        curr_date_obj = self._parse_date(curr_date)
        return self._fundamentaldata_service.get_balance_sheet_context(
            symbol=ticker,
            start_date=curr_date,
            end_date=curr_date,
            frequency=freq.lower(),
            start_date=curr_date_obj,
            end_date=curr_date_obj,
        )

    @tool
@@ -306,11 +306,11 @@ class AgentToolkit:
        Returns:
            CashFlowContext: Structured cash flow analysis with operating cash flow metrics.
        """
        curr_date_obj = self._parse_date(curr_date)
        return self._fundamentaldata_service.get_cashflow_context(
            symbol=ticker,
            start_date=curr_date,
            end_date=curr_date,
            frequency=freq.lower(),
            start_date=curr_date_obj,
            end_date=curr_date_obj,
        )

    @tool
@@ -332,11 +332,11 @@ class AgentToolkit:
        Returns:
            IncomeStatementContext: Structured income statement analysis with profitability metrics.
        """
        curr_date_obj = self._parse_date(curr_date)
        return self._fundamentaldata_service.get_income_statement_context(
            symbol=ticker,
            start_date=curr_date,
            end_date=curr_date,
            frequency=freq.lower(),
            start_date=curr_date_obj,
            end_date=curr_date_obj,
        )

    def _calculate_date_range(
@@ -359,7 +359,9 @@ class AgentToolkit:
            curr_date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
        except ValueError as e:
            logger.error(f"Invalid date format '{curr_date}': {e}")
            raise ValueError(f"Date must be in YYYY-MM-DD format, got: {curr_date}")
            raise ValueError(
                f"Date must be in YYYY-MM-DD format, got: {curr_date}"
            ) from e

        if lookback_days is None:
            lookback_days = self._config.default_lookback_days
@@ -367,6 +369,27 @@ class AgentToolkit:
        start_date_obj = curr_date_obj - timedelta(days=lookback_days)
        return start_date_obj.strftime("%Y-%m-%d"), curr_date

    def _parse_date(self, date_str: str) -> date:
        """
        Convert string date to date object.

        Args:
            date_str: Date string in YYYY-MM-DD format

        Returns:
            date object

        Raises:
            ValueError: If date format is invalid
        """
        try:
            return datetime.strptime(date_str, "%Y-%m-%d").date()
        except ValueError as e:
            logger.error(f"Invalid date format '{date_str}': {e}")
            raise ValueError(
                f"Date must be in YYYY-MM-DD format, got '{date_str}'"
            ) from e

    def _validate_ticker(self, ticker: str) -> str:
        """
        Validate and sanitize ticker symbol.
@@ -310,3 +310,24 @@ def is_high_quality_data(context_json: str) -> bool:
    """Quick check if context contains high quality data."""
    context = ContextParser.parse_context(context_json)
    return ContextParser.is_high_quality(context)


def create_msg_delete():
    """
    Create a message deletion node function for LangGraph workflows.

    This function returns a node that clears all messages from the state,
    which is useful for preventing context pollution between different
    phases of multi-agent workflows.

    Returns:
        Callable: A function that can be used as a LangGraph node
    """
    from langchain_core.messages import RemoveMessage
    from langgraph.graph.message import REMOVE_ALL_MESSAGES

    def delete_messages(state):
        """Delete all messages from the current state."""
        return {"messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES)]}

    return delete_messages
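A minimal usage sketch for the node above; the graph shape and node name are illustrative, not part of this commit:

```python
from langgraph.graph import END, START, MessagesState, StateGraph

workflow = StateGraph(MessagesState)
workflow.add_node("clear_messages", create_msg_delete())  # wipe accumulated messages
workflow.add_edge(START, "clear_messages")
workflow.add_edge("clear_messages", END)
app = workflow.compile()
```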

@@ -1,7 +0,0 @@
"""
Client classes for live data access in TradingAgents.
"""

from .base import BaseClient

__all__ = ["BaseClient"]

@@ -1,23 +0,0 @@
"""
Base client interface for TradingAgents data sources.
"""

from abc import ABC, abstractmethod
from typing import Any


class BaseClient(ABC):
    """Abstract base class for all data clients."""

    @abstractmethod
    def get_data(self, **kwargs) -> dict[str, Any]:
        """
        Get data from the client source.

        Args:
            **kwargs: Client-specific parameters

        Returns:
            dict: Data dictionary with standardized structure
        """
        pass
@@ -4,10 +4,16 @@ Finnhub client for financial data access.

import logging
from datetime import date
from typing import Any

import finnhub

from ..models import (
    CompanyProfile,
    InsiderSentimentResponse,
    InsiderTransactionsResponse,
    ReportedFinancialsResponse,
)

logger = logging.getLogger(__name__)
@@ -31,116 +37,33 @@ class FinnhubClient:
        Args:
            api_key: Finnhub API key
        """
        self.api_key = api_key
        self.client = finnhub.Client(api_key=api_key)

    def test_connection(self) -> bool:
        """Test if the Finnhub API connection is working."""
        try:
            # Test with a simple quote request
            response = self.client.quote("AAPL")
            return "c" in response  # 'c' is current price field
        except Exception as e:
            logger.error(f"Finnhub connection test failed: {e}")
            return False

    def get_balance_sheet(
        self, symbol: str, frequency: str, report_date: date
    ) -> dict[str, Any]:
    def get_reported_financials(
        self, symbol: str, frequency: str
    ) -> ReportedFinancialsResponse:
        """
        Get balance sheet data from Finnhub.
        Get reported financials data from Finnhub.

        Args:
            symbol: Stock symbol (e.g., 'AAPL')
            frequency: Reporting frequency ('quarterly' or 'annual')
            report_date: Report date as date object

        Returns:
            Balance sheet data from Finnhub API
            Reported financials data from Finnhub API
        """
        try:
            # Finnhub SDK expects frequency as 'quarterly' or 'annual'
            freq = "quarterly" if frequency.lower() in ["quarterly", "q"] else "annual"
            response = self.client.financials_reported(symbol=symbol.upper(), freq=freq)
            return response if isinstance(response, dict) else {"data": []}
            return ReportedFinancialsResponse.from_dict(response)
        except Exception as e:
            logger.error(f"Error fetching balance sheet for {symbol}: {e}")
            return {"data": []}

    def get_income_statement(
        self, symbol: str, frequency: str, report_date: date
    ) -> dict[str, Any]:
        """
        Get income statement data from Finnhub.

        Args:
            symbol: Stock symbol (e.g., 'AAPL')
            frequency: Reporting frequency ('quarterly' or 'annual')
            report_date: Report date as date object

        Returns:
            Income statement data from Finnhub API
        """
        try:
            freq = "quarterly" if frequency.lower() in ["quarterly", "q"] else "annual"
            response = self.client.financials_reported(symbol=symbol.upper(), freq=freq)
            return response if isinstance(response, dict) else {"data": []}
        except Exception as e:
            logger.error(f"Error fetching income statement for {symbol}: {e}")
            return {"data": []}

    def get_cash_flow(
        self, symbol: str, frequency: str, report_date: date
    ) -> dict[str, Any]:
        """
        Get cash flow statement data from Finnhub.

        Args:
            symbol: Stock symbol (e.g., 'AAPL')
            frequency: Reporting frequency ('quarterly' or 'annual')
            report_date: Report date as date object

        Returns:
            Cash flow statement data from Finnhub API
        """
        try:
            freq = "quarterly" if frequency.lower() in ["quarterly", "q"] else "annual"
            response = self.client.financials_reported(symbol=symbol.upper(), freq=freq)
            return response if isinstance(response, dict) else {"data": []}
        except Exception as e:
            logger.error(f"Error fetching cash flow for {symbol}: {e}")
            return {"data": []}

    def get_company_news(
        self, symbol: str, start_date: date, end_date: date
    ) -> list[dict[str, Any]]:
        """
        Get company news for a specific symbol and date range.

        Args:
            symbol: Stock symbol (e.g., 'AAPL')
            start_date: Start date as date object
            end_date: End date as date object

        Returns:
            List of news articles
        """
        # Convert date objects to strings for API
        start_str = start_date.isoformat()
        end_str = end_date.isoformat()

        try:
            response = self.client.company_news(
                symbol.upper(), _from=start_str, to=end_str
            )
            return response if isinstance(response, list) else []
        except Exception as e:
            logger.error(f"Error fetching news for {symbol}: {e}")
            return []
            logger.error(f"Error fetching reported financials for {symbol}: {e}")
            raise

    def get_insider_transactions(
        self, symbol: str, start_date: date, end_date: date
    ) -> dict[str, Any]:
    ) -> InsiderTransactionsResponse:
        """
        Get insider transactions for a company.

@@ -159,14 +82,18 @@ class FinnhubClient:
            response = self.client.stock_insider_transactions(
                symbol.upper(), _from=start_str, to=end_str
            )
            return response if isinstance(response, dict) else {"data": []}
            if isinstance(response, dict):
                return InsiderTransactionsResponse.from_dict(response)
            else:
                # Return empty response if API returns unexpected format
                return InsiderTransactionsResponse(data=[], symbol=symbol.upper())
        except Exception as e:
            logger.error(f"Error fetching insider transactions for {symbol}: {e}")
            return {"data": []}
            raise

    def get_insider_sentiment(
        self, symbol: str, start_date: date, end_date: date
    ) -> dict[str, Any]:
    ) -> InsiderSentimentResponse:
        """
        Get insider sentiment data for a company.

@@ -185,29 +112,16 @@ class FinnhubClient:
            response = self.client.stock_insider_sentiment(
                symbol.upper(), _from=start_str, to=end_str
            )
            return response if isinstance(response, dict) else {"data": []}
            if isinstance(response, dict):
                return InsiderSentimentResponse.from_dict(response)
            else:
                # Return empty response if API returns unexpected format
                return InsiderSentimentResponse(data=[], symbol=symbol.upper())
        except Exception as e:
            logger.error(f"Error fetching insider sentiment for {symbol}: {e}")
            return {"data": []}
            raise

    def get_quote(self, symbol: str) -> dict[str, Any]:
        """
        Get current quote for a symbol.

        Args:
            symbol: Stock symbol

        Returns:
            Quote data with current price, change, etc.
        """
        try:
            response = self.client.quote(symbol.upper())
            return response if isinstance(response, dict) else {}
        except Exception as e:
            logger.error(f"Error fetching quote for {symbol}: {e}")
            return {}

    def get_company_profile(self, symbol: str) -> dict[str, Any]:
    def get_company_profile(self, symbol: str) -> CompanyProfile:
        """
        Get company profile information.

@@ -219,20 +133,7 @@ class FinnhubClient:
        """
        try:
            response = self.client.company_profile2(symbol=symbol.upper())
            return response if isinstance(response, dict) else {}
            return CompanyProfile.from_dict(response)
        except Exception as e:
            logger.error(f"Error fetching company profile for {symbol}: {e}")
            return {}

    def get_client_info(self) -> dict[str, Any]:
        """
        Get information about this client.

        Returns:
            Dict[str, Any]: Client metadata
        """
        return {
            "client_type": self.__class__.__name__,
            "api_key_set": bool(self.api_key),
            "sdk_version": getattr(finnhub, "__version__", "unknown"),
        }
            raise
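The new return types lean on `from_dict` constructors defined on the models in `..models`, which are not shown in this diff. A hedged sketch of the shape such a constructor might take; every field name here is an assumption:

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ReportedFinancialsResponse:
    # Field names are assumptions; the real model lives in ..models
    data: list[dict[str, Any]] = field(default_factory=list)
    symbol: str = ""

    @classmethod
    def from_dict(cls, raw: dict[str, Any]) -> "ReportedFinancialsResponse":
        return cls(data=raw.get("data", []), symbol=raw.get("symbol", ""))
```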
@@ -1,320 +0,0 @@
#!/usr/bin/env python3
"""
Comprehensive tests for FinnhubClient using pytest-vcr.

This test suite records real API interactions with Finnhub and replays them
in subsequent runs, providing realistic testing without network dependencies.
"""

import os
from datetime import date

import pytest

# NOTE: FinnhubClient was removed - this test needs to be updated
# from tradingagents.clients.finnhub_client import FinnhubClient

pytest.skip("FinnhubClient implementation removed", allow_module_level=True)


@pytest.fixture
def finnhub_client():
    """Create FinnhubClient with test API key."""
    # Use environment variable or test key for VCR recording
    api_key = os.getenv("FINNHUB_API_KEY", "test_api_key")
    return FinnhubClient(api_key=api_key)


@pytest.fixture
def vcr_config():
    """Configure VCR with proper settings."""
    return {
        "filter_headers": ["X-Finnhub-Token"],  # Filter out API key from recordings
        "record_mode": "once",  # Record once, then replay
        "match_on": ["uri", "method"],
        "decode_compressed_response": True,
    }


class TestFinnhubClientConnection:
    """Test client connection and initialization."""

    @pytest.mark.vcr
    def test_client_initialization(self, finnhub_client):
        """Test that client initializes correctly."""
        assert finnhub_client.api_key is not None
        assert finnhub_client.client is not None

    @pytest.mark.vcr
    def test_connection_success(self, finnhub_client):
        """Test successful API connection."""
        assert finnhub_client.test_connection() is True

    def test_client_info(self, finnhub_client):
        """Test client metadata."""
        info = finnhub_client.get_client_info()
        assert info["client_type"] == "FinnhubClient"
        assert info["api_key_set"] is True


class TestFundamentalDataMethods:
    """Test the fundamental data methods needed by FundamentalDataService."""

    @pytest.mark.vcr
    def test_get_balance_sheet_quarterly(self, finnhub_client):
        """Test balance sheet retrieval with date object."""
        test_date = date(2024, 1, 1)
        data = finnhub_client.get_balance_sheet("AAPL", "quarterly", test_date)

        assert isinstance(data, dict)
        # Finnhub financials_reported returns data structure with 'data' key
        assert "data" in data or len(data) > 0

    @pytest.mark.vcr
    def test_get_balance_sheet_annual(self, finnhub_client):
        """Test annual balance sheet retrieval."""
        test_date = date(2024, 1, 1)
        data = finnhub_client.get_balance_sheet("AAPL", "annual", test_date)

        assert isinstance(data, dict)

    @pytest.mark.vcr
    def test_get_income_statement_quarterly(self, finnhub_client):
        """Test income statement retrieval."""
        test_date = date(2024, 1, 1)
        data = finnhub_client.get_income_statement("AAPL", "quarterly", test_date)

        assert isinstance(data, dict)

    @pytest.mark.vcr
    def test_get_income_statement_annual(self, finnhub_client):
        """Test annual income statement retrieval."""
        test_date = date(2024, 1, 1)
        data = finnhub_client.get_income_statement("AAPL", "annual", test_date)

        assert isinstance(data, dict)

    @pytest.mark.vcr
    def test_get_cash_flow_with_date_object(self, finnhub_client):
        """Test cash flow retrieval with date object."""
        test_date = date(2024, 3, 31)
        data = finnhub_client.get_cash_flow("AAPL", "quarterly", test_date)

        assert isinstance(data, dict)

    @pytest.mark.vcr
    def test_fundamental_data_different_symbols(self, finnhub_client):
        """Test fundamental data for different symbols."""
        symbols = ["AAPL", "MSFT", "GOOGL"]
        test_date = date(2024, 1, 1)

        for symbol in symbols:
            data = finnhub_client.get_balance_sheet(symbol, "quarterly", test_date)
            assert isinstance(data, dict)


class TestExistingMethods:
    """Test existing methods with enhanced date support."""

    @pytest.mark.vcr
    def test_get_company_news_january(self, finnhub_client):
        """Test company news for January 2024."""
        start_date = date(2024, 1, 1)
        end_date = date(2024, 1, 31)
        news = finnhub_client.get_company_news("AAPL", start_date, end_date)

        assert isinstance(news, list)

    @pytest.mark.vcr
    def test_get_company_news_date_objects(self, finnhub_client):
        """Test company news with date objects."""
        start_date = date(2024, 1, 1)
        end_date = date(2024, 1, 31)

        news = finnhub_client.get_company_news("AAPL", start_date, end_date)

        assert isinstance(news, list)

    @pytest.mark.vcr
    def test_get_insider_transactions_date_objects(self, finnhub_client):
        """Test insider transactions with date objects."""
        start_date = date(2024, 1, 1)
        end_date = date(2024, 1, 31)

        data = finnhub_client.get_insider_transactions("AAPL", start_date, end_date)

        assert isinstance(data, dict)
        assert "data" in data

    @pytest.mark.vcr
    def test_get_insider_sentiment(self, finnhub_client):
        """Test insider sentiment with date objects."""
        start_date = date(2024, 1, 1)
        end_date = date(2024, 1, 31)

        data = finnhub_client.get_insider_sentiment("AAPL", start_date, end_date)

        assert isinstance(data, dict)

    @pytest.mark.vcr
    def test_get_quote(self, finnhub_client):
        """Test stock quote retrieval."""
        quote = finnhub_client.get_quote("AAPL")

        assert isinstance(quote, dict)
        # Quote should contain current price 'c' field
        if quote:  # Only check if we got data
            assert "c" in quote

    @pytest.mark.vcr
    def test_get_company_profile(self, finnhub_client):
        """Test company profile retrieval."""
        profile = finnhub_client.get_company_profile("AAPL")

        assert isinstance(profile, dict)


class TestErrorHandling:
    """Test error handling and edge cases."""

    @pytest.mark.vcr
    def test_invalid_symbol_balance_sheet(self, finnhub_client):
        """Test balance sheet with invalid symbol."""
        test_date = date(2024, 1, 1)
        data = finnhub_client.get_balance_sheet(
            "INVALID_SYMBOL_XYZ", "quarterly", test_date
        )

        # Should return dict with empty data structure, not raise exception
        assert isinstance(data, dict)
        assert "data" in data

    @pytest.mark.vcr
    def test_invalid_symbol_news(self, finnhub_client):
        """Test news with invalid symbol."""
        start_date = date(2024, 1, 1)
        end_date = date(2024, 1, 31)
        news = finnhub_client.get_company_news(
            "INVALID_SYMBOL_XYZ", start_date, end_date
        )

        # Should return empty list, not raise exception
        assert isinstance(news, list)

    def test_connection_with_invalid_api_key(self):
        """Test connection failure with invalid API key."""
        client = FinnhubClient("invalid_api_key")
        # This should return False, not raise an exception
        assert client.test_connection() is False

    @pytest.mark.vcr
    def test_frequency_normalization(self, finnhub_client):
        """Test that frequency parameters are normalized correctly."""
        # Test different frequency formats
        frequencies = ["quarterly", "QUARTERLY", "q", "Q", "annual", "ANNUAL", "a", "A"]
        test_date = date(2024, 1, 1)

        for freq in frequencies:
            data = finnhub_client.get_balance_sheet("AAPL", freq, test_date)
            assert isinstance(data, dict)

    def test_date_edge_cases(self, finnhub_client):
        """Test date edge cases."""
        # Test with year-end date
        test_date = date(2024, 12, 31)

        # This shouldn't raise exceptions
        # (actual API calls are mocked/recorded)
        data = finnhub_client.get_balance_sheet("AAPL", "quarterly", test_date)
        assert isinstance(data, dict)


class TestMultipleSymbolsAndTimeframes:
    """Test with multiple symbols and different timeframes."""

    @pytest.mark.vcr
    def test_multiple_symbols_balance_sheet(self, finnhub_client):
        """Test balance sheet for multiple major symbols."""
        symbols = ["AAPL", "MSFT", "GOOGL", "TSLA"]
        test_date = date(2024, 1, 1)

        for symbol in symbols:
            data = finnhub_client.get_balance_sheet(symbol, "quarterly", test_date)
            assert isinstance(data, dict)

    @pytest.mark.vcr
    def test_quarterly_vs_annual_frequency(self, finnhub_client):
        """Test quarterly vs annual frequency."""
        symbol = "AAPL"
        test_date = date(2024, 1, 1)

        quarterly_data = finnhub_client.get_balance_sheet(
            symbol, "quarterly", test_date
        )
        annual_data = finnhub_client.get_balance_sheet(symbol, "annual", test_date)

        assert isinstance(quarterly_data, dict)
        assert isinstance(annual_data, dict)

    @pytest.mark.vcr
    def test_all_fundamental_methods_same_symbol(self, finnhub_client):
        """Test all fundamental methods for the same symbol."""
        symbol = "AAPL"
        frequency = "quarterly"
        report_date = date(2024, 1, 1)

        balance_sheet = finnhub_client.get_balance_sheet(symbol, frequency, report_date)
        income_statement = finnhub_client.get_income_statement(
            symbol, frequency, report_date
        )
        cash_flow = finnhub_client.get_cash_flow(symbol, frequency, report_date)

        assert isinstance(balance_sheet, dict)
        assert isinstance(income_statement, dict)
        assert isinstance(cash_flow, dict)


# Integration test to verify the client works with service expectations
class TestServiceIntegration:
    """Test that client works as expected by FundamentalDataService."""

    def test_service_expected_methods_exist(self, finnhub_client):
        """Test that all methods expected by FundamentalDataService exist."""
        # These are the methods the service calls
        assert hasattr(finnhub_client, "get_balance_sheet")
        assert hasattr(finnhub_client, "get_income_statement")
        assert hasattr(finnhub_client, "get_cash_flow")

        # Verify method signatures accept the expected parameters
        import inspect

        for method_name in [
            "get_balance_sheet",
            "get_income_statement",
            "get_cash_flow",
        ]:
            method = getattr(finnhub_client, method_name)
            sig = inspect.signature(method)
            params = list(sig.parameters.keys())

            # Service calls: method(symbol, frequency, date)
            assert "symbol" in params
            assert "frequency" in params
            assert "report_date" in params

    @pytest.mark.vcr
    def test_data_format_for_service_conversion(self, finnhub_client):
        """Test that returned data format can be processed by service conversion method."""
        test_date = date(2024, 1, 1)
        data = finnhub_client.get_balance_sheet("AAPL", "quarterly", test_date)

        # Service expects either empty dict or dict with data
        assert isinstance(data, dict)

        # The service's _convert_to_financial_statement expects either:
        # 1. Empty/falsy data -> returns None
        # 2. Dict with "data" key containing the actual financial data
        if data:  # If we got data
            # Should be able to access data["data"] or similar structure
            # This validates the format is compatible with service expectations
            assert isinstance(data, dict)
@@ -3,6 +3,19 @@ Fundamental Data Service for aggregating and analyzing financial statement data.
"""

import logging
from datetime import date

from tradingagents.config import TradingAgentsConfig

from .clients.finnhub_client import FinnhubClient
from .models import (
    BalanceSheetContext,
    CashFlowContext,
    DataQuality,
    FundamentalContext,
    IncomeStatementContext,
)
from .repos.fundamental_data_repository import FundamentalDataRepository

logger = logging.getLogger(__name__)
@@ -12,35 +25,30 @@ class FundamentalDataService:

    def __init__(
        self,
        simfin_client: SimFinClient,
        finnhub_client: FinnhubClient,
        repository: FundamentalDataRepository,
    ):
        """Initialize Fundamental Data Service.

        Args:
            simfin_client: Client for SimFin/financial API access
            finnhub_client: Client for FinnHub financial API access
            repository: Repository for cached fundamental data
            online_mode: Whether to fetch live data
            data_dir: Directory for data storage
            online_mode: Whether to fetch live data or use cached data
        """
        self.simfin_client = simfin_client
        self.finnhub_client = finnhub_client
        self.repository = repository

    def update_fundamental_data(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        frequency: str = "quarterly",
    ) -> FundamentalContext:
        pass  # TODO: fetch fundamentals from simfin, save in repo
    @staticmethod
    def build(_config: TradingAgentsConfig):
        client = FinnhubClient("")
        repo = FundamentalDataRepository("")
        return FundamentalDataService(client, repo)

    def get_fundamental_context(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        frequency: str = "quarterly",
        start_date: date,
        end_date: date,
    ) -> FundamentalContext:
        """Get fundamental analysis context for a company.
@@ -49,34 +57,135 @@ class FundamentalDataService:
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            frequency: Reporting frequency ('quarterly' or 'annual')
            force_refresh: If True, skip local data and fetch fresh from APIs

        Returns:
            FundamentalContext with financial statements and key ratios
        """
        balance_sheet = None
        income_statement = None
        cash_flow = None
        error_info = {}
        errors = []
        data_source = "unknown"
        # TODO: implement
        return FundamentalContext(
            symbol=symbol, start_date=start_date, end_date=end_date
        )

        # return FundamentalContext(
        #     symbol=symbol,
        #     period={"start": start_date, "end": end_date},
        #     balance_sheet=balance_sheet,
        #     income_statement=income_statement,
        #     cash_flow=cash_flow,
        #     key_ratios=key_ratios,
        #     metadata={
        #         "data_quality": data_quality,
        #         "service": "fundamental_data",
        #         "online_mode": self.is_online(),
        #         "frequency": frequency,
        #         "data_source": data_source,
        #         "force_refresh": force_refresh,
        #         **error_info,
        #     },
        # )
    def get_balance_sheet_context(
        self,
        symbol: str,
        start_date: date,
        end_date: date,
    ) -> BalanceSheetContext:
        """Get balance sheet context for a company.

        pass  # TODO: read data from repo
        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            frequency: Reporting frequency ('quarterly' or 'annual')

        Returns:
            BalanceSheetContext with balance sheet data and ratios
        """
        # TODO: implement

        # Return empty context if no data
        return BalanceSheetContext(
            symbol=symbol,
            start_date=start_date,
            end_date=end_date,
            balance_sheet_data=[],
            key_ratios=[],
            data_quality=DataQuality.LOW,
            source="none",
            metadata={"error": "No balance sheet data available"},
        )

    def get_income_statement_context(
        self,
        symbol: str,
        start_date: date,
        end_date: date,
    ) -> IncomeStatementContext:
        """Get income statement context for a company.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            frequency: Reporting frequency ('quarterly' or 'annual')

        Returns:
            IncomeStatementContext with income statement data and ratios
        """
        # TODO: implement

        # Return empty context if no data
        return IncomeStatementContext(
            symbol=symbol,
            start_date=start_date,
            end_date=end_date,
            income_statement_data=[],
            key_ratios=[],
            data_quality=DataQuality.LOW,
            source="none",
            metadata={"error": "No income statement data available"},
        )

    def get_cashflow_context(
        self,
        symbol: str,
        start_date: date,
        end_date: date,
    ) -> CashFlowContext:
        """Get cash flow context for a company.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            frequency: Reporting frequency ('quarterly' or 'annual')

        Returns:
            CashFlowContext with cash flow data and ratios
        """
        # TODO: implement

        # Return empty context if no data
        return CashFlowContext(
            symbol=symbol,
            start_date=start_date,
            end_date=end_date,
            cash_flow_data=[],
            key_ratios=[],
            data_quality=DataQuality.LOW,
            source="none",
            metadata={"error": "No cash flow data available"},
        )

    def update_fundamental_data(
        self,
        symbol: str,
        date: date,
        frequency: str = "quarterly",
    ) -> bool:
        """Update fundamental data by fetching from FinnHub and storing in repository.

        Args:
            symbol: Stock ticker symbol
            date: Date for the financial data
            frequency: Reporting frequency ('quarterly' or 'annual')

        Returns:
            bool: True if update was successful
        """
        try:
            # Fetch reported financials data from FinnHub using the unified method
            reported_financials = self.finnhub_client.get_reported_financials(
                symbol, frequency
            )

            # Store the reported financials data in repository
            return self.repository.store_reported_financials(
                symbol=symbol, date=date, reported_financials=reported_financials
            )

        except Exception as e:
            logger.error(f"Error updating fundamental data for {symbol}: {e}")
            return False
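A quick usage sketch of the new service wiring; `build` and the method signatures follow the diff above, while the `DEFAULT_CONFIG` import path and the dates are assumptions:

```python
from datetime import date

from tradingagents.config import DEFAULT_CONFIG  # import path assumed

service = FundamentalDataService.build(DEFAULT_CONFIG)
ok = service.update_fundamental_data("AAPL", date(2024, 1, 1), frequency="quarterly")
context = service.get_balance_sheet_context(
    symbol="AAPL", start_date=date(2024, 1, 1), end_date=date(2024, 1, 1)
)
```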
@@ -1,495 +0,0 @@
"""
Test FundamentalDataService with mock SimFin clients and real FundamentalDataRepository.
"""

import tempfile
from datetime import datetime
from typing import Any

import pytest

from tradingagents.clients.base import BaseClient
from tradingagents.domains.marketdata.fundamental_data_service import (
    DataQuality,
    FinancialStatement,
    FundamentalContext,
    FundamentalDataService,
)
from tradingagents.repositories.fundamental_repository import FundamentalDataRepository


class MockSimFinClient(BaseClient):
    """Mock SimFin client that returns sample financial statement data."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.connection_works = True

    def test_connection(self) -> bool:
        return self.connection_works

    def get_data(self, *args, **kwargs) -> dict[str, Any]:
        """Not used directly by FundamentalDataService."""
        return {}

    def get_balance_sheet(
        self, ticker: str, freq: str, curr_date: str
    ) -> dict[str, Any]:
        """Return mock balance sheet data."""
        return {
            "ticker": ticker,
            "statement_type": "balance_sheet",
            "frequency": freq,
            "period": "Q3-2024" if freq == "quarterly" else "2024",
            "report_date": "2024-09-30",
            "publish_date": "2024-10-30",
            "currency": "USD",
            "data": {
                "Total Assets": 365725000000.0,
                "Total Current Assets": 143566000000.0,
                "Cash and Cash Equivalents": 28965000000.0,
                "Short-term Investments": 31590000000.0,
                "Accounts Receivable": 13348000000.0,
                "Inventory": 6511000000.0,
                "Total Non-Current Assets": 222159000000.0,
                "Property, Plant & Equipment": 43715000000.0,
                "Intangible Assets": 11235000000.0,
                "Total Liabilities": 279414000000.0,
                "Total Current Liabilities": 136817000000.0,
                "Accounts Payable": 58146000000.0,
                "Short-term Debt": 20208000000.0,
                "Total Non-Current Liabilities": 142597000000.0,
                "Long-term Debt": 106550000000.0,
                "Total Shareholders Equity": 86311000000.0,
                "Retained Earnings": 672000000.0,
                "Common Stock": 77958000000.0,
            },
            "metadata": {
                "source": "mock_simfin",
                "retrieved_at": datetime(2024, 1, 2).isoformat(),
            },
        }

    def get_income_statement(
        self, ticker: str, freq: str, curr_date: str
    ) -> dict[str, Any]:
        """Return mock income statement data."""
        return {
            "ticker": ticker,
            "statement_type": "income_statement",
            "frequency": freq,
            "period": "Q3-2024" if freq == "quarterly" else "2024",
            "report_date": "2024-09-30",
            "publish_date": "2024-10-30",
            "currency": "USD",
            "data": {
                "Total Revenue": 94930000000.0,
                "Cost of Revenue": 55720000000.0,
                "Gross Profit": 39210000000.0,
                "Operating Expenses": 15706000000.0,
                "Research and Development": 8067000000.0,
                "Sales, General & Administrative": 7639000000.0,
                "Operating Income": 23504000000.0,
                "Interest Expense": 1013000000.0,
                "Other Income": 269000000.0,
                "Income Before Tax": 22760000000.0,
                "Tax Provision": 4438000000.0,
                "Net Income": 18322000000.0,
                "Basic EPS": 1.18,
                "Diluted EPS": 1.15,
                "Shares Outstanding": 15550193000,
            },
            "metadata": {
                "source": "mock_simfin",
                "retrieved_at": datetime(2024, 1, 2).isoformat(),
            },
        }

    def get_cash_flow(self, ticker: str, freq: str, curr_date: str) -> dict[str, Any]:
        """Return mock cash flow statement data."""
        return {
            "ticker": ticker,
            "statement_type": "cash_flow",
            "frequency": freq,
            "period": "Q3-2024" if freq == "quarterly" else "2024",
            "report_date": "2024-09-30",
            "publish_date": "2024-10-30",
            "currency": "USD",
            "data": {
                "Net Income": 18322000000.0,
                "Depreciation & Amortization": 2871000000.0,
                "Changes in Working Capital": -1684000000.0,
                "Operating Cash Flow": 23302000000.0,
                "Capital Expenditures": -2736000000.0,
                "Acquisitions": -1800000000.0,
                "Asset Sales": 234000000.0,
                "Investing Cash Flow": -4302000000.0,
                "Dividends Paid": -3746000000.0,
                "Share Repurchases": -24979000000.0,
                "Debt Proceeds": 750000000.0,
                "Debt Repayment": -1500000000.0,
                "Financing Cash Flow": -28475000000.0,
                "Free Cash Flow": 20566000000.0,
                "Net Change in Cash": -9475000000.0,
            },
            "metadata": {
                "source": "mock_simfin",
                "retrieved_at": datetime(2024, 1, 2).isoformat(),
            },
        }


@pytest.fixture
def temp_data_dir():
    """Create a temporary directory for test data and clean up after test."""
    with tempfile.TemporaryDirectory(prefix="fundamental_test_") as temp_dir:
        yield temp_dir


@pytest.fixture
def mock_simfin_client():
    """Create a mock SimFin client for testing."""
    return MockSimFinClient()


@pytest.fixture
def broken_simfin_client():
    """Create a broken SimFin client for error testing."""

    class BrokenSimFinClient(BaseClient):
        def test_connection(self):
            return False

        def get_data(self, *args, **kwargs):
            raise Exception("SimFin API error")

        def get_balance_sheet(self, *args, **kwargs):
            raise Exception("SimFin API error")

        def get_income_statement(self, *args, **kwargs):
            raise Exception("SimFin API error")

        def get_cash_flow(self, *args, **kwargs):
            raise Exception("SimFin API error")

    return BrokenSimFinClient()


@pytest.fixture
def partial_data_client():
    """Create a client that returns partial data for testing."""

    class PartialDataClient(MockSimFinClient):
        def get_cash_flow(self, ticker, freq, curr_date):
            # Simulate missing cash flow data
            raise Exception("Cash flow data not available")

        def get_income_statement(self, ticker, freq, curr_date):
            # Simulate missing income statement
            return {"data": {}, "metadata": {"error": "No data found"}}

    return PartialDataClient()


def test_online_mode_with_mock_simfin(temp_data_dir, mock_simfin_client):
    """Test FundamentalDataService in online mode with mock SimFin client."""
    # Create real repository with temporary directory
    real_repo = FundamentalDataRepository(temp_data_dir)

    # Create service with mock client and real repository
    service = FundamentalDataService(
        simfin_client=mock_simfin_client,
        repository=real_repo,
        data_dir=temp_data_dir,
    )

    # Test getting fundamental context with all three statements
    context = service.get_fundamental_context(
        symbol="AAPL",
        start_date="2024-01-01",
        end_date="2024-12-31",
        frequency="quarterly",
        force_refresh=True,  # Force using mock client instead of cache
    )

    # Validate context structure
    assert isinstance(context, FundamentalContext)
    assert context.symbol == "AAPL"
    assert context.period["start"] == "2024-01-01"
    assert context.period["end"] == "2024-12-31"

    # Validate financial statements
    assert context.balance_sheet is not None
    assert isinstance(context.balance_sheet, FinancialStatement)
    assert context.balance_sheet.period == "Q3-2024"
    assert context.balance_sheet.currency == "USD"
    assert "Total Assets" in context.balance_sheet.data

    assert context.income_statement is not None
    assert isinstance(context.income_statement, FinancialStatement)
    assert "Total Revenue" in context.income_statement.data
    assert "Net Income" in context.income_statement.data

    assert context.cash_flow is not None
    assert isinstance(context.cash_flow, FinancialStatement)
    assert "Operating Cash Flow" in context.cash_flow.data
    assert "Free Cash Flow" in context.cash_flow.data

    # Validate key ratios calculation
    assert len(context.key_ratios) > 0
    assert "current_ratio" in context.key_ratios
    assert "debt_to_equity" in context.key_ratios
    assert "roe" in context.key_ratios  # Return on Equity
    assert "gross_margin" in context.key_ratios

    # Validate metadata
    assert "data_quality" in context.metadata
    assert context.metadata["service"] == "fundamental_data"

    # Test JSON serialization
    json_output = context.model_dump_json(indent=2)
    assert len(json_output) > 0


def test_annual_vs_quarterly_frequency(temp_data_dir, mock_simfin_client):
    """Test different reporting frequencies."""
    real_repo = FundamentalDataRepository(temp_data_dir)
    service = FundamentalDataService(
        simfin_client=mock_simfin_client, repository=real_repo, data_dir=temp_data_dir
    )

    # Test quarterly
    quarterly_context = service.get_fundamental_context(
        symbol="MSFT",
        start_date="2024-01-01",
        end_date="2024-12-31",
        frequency="quarterly",
    )

    assert quarterly_context.balance_sheet is not None
    assert quarterly_context.balance_sheet.period == "Q3-2024"

    # Test annual
    annual_context = service.get_fundamental_context(
        symbol="MSFT",
        start_date="2024-01-01",
        end_date="2024-12-31",
        frequency="annual",
    )

    assert annual_context.balance_sheet is not None
    assert annual_context.balance_sheet.period == "2024"


def test_financial_ratio_calculations(temp_data_dir, mock_simfin_client):
    """Test calculation of key financial ratios."""
    real_repo = FundamentalDataRepository(temp_data_dir)
    service = FundamentalDataService(
        simfin_client=mock_simfin_client, repository=real_repo, data_dir=temp_data_dir
    )

    context = service.get_fundamental_context("TSLA", "2024-01-01", "2024-12-31")

    # Check that key ratios are calculated
    ratios = context.key_ratios

    # Liquidity ratios
    assert "current_ratio" in ratios
    assert ratios["current_ratio"] > 0

    # Leverage ratios
    assert "debt_to_equity" in ratios
    assert ratios["debt_to_equity"] >= 0

    # Profitability ratios
    assert "gross_margin" in ratios
    assert "operating_margin" in ratios
    assert "net_margin" in ratios
    assert "roe" in ratios  # Return on Equity
    assert "roa" in ratios  # Return on Assets

    # Efficiency ratios
    assert "asset_turnover" in ratios
||||
|
||||
# Validate ratio calculations are reasonable
|
||||
assert 0 <= ratios["gross_margin"] <= 1
|
||||
assert 0 <= ratios["net_margin"] <= 1
|
||||
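
# For reference, the ratio names asserted above follow standard definitions.
# The service's exact formulas are not visible in this diff, so the sketch
# below is illustrative only, using the conventional textbook calculations:
#
#   current_ratio  = current_assets / current_liabilities
#   debt_to_equity = total_liabilities / total_shareholders_equity
#   gross_margin   = gross_profit / total_revenue
#   net_margin     = net_income / total_revenue
#   roe            = net_income / total_shareholders_equity
#   roa            = net_income / total_assets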


def test_offline_mode(temp_data_dir):
    """Test FundamentalDataService without a client (offline mode)."""
    real_repo = FundamentalDataRepository(temp_data_dir)

    service = FundamentalDataService(
        simfin_client=None, repository=real_repo, data_dir=temp_data_dir
    )

    # Should handle being offline gracefully
    context = service.get_fundamental_context("AAPL", "2024-01-01", "2024-12-31")

    assert context.symbol == "AAPL"
    assert context.balance_sheet is None  # No data available offline
    assert context.income_statement is None
    assert context.cash_flow is None
    assert len(context.key_ratios) == 0
    assert context.metadata.get("data_quality") == DataQuality.LOW


def test_partial_data_handling(partial_data_client):
    """Test handling when only some financial statements are available."""
    service = FundamentalDataService(
        simfin_client=partial_data_client, repository=None, online_mode=True
    )

    context = service.get_fundamental_context("XYZ", "2024-01-01", "2024-12-31")

    # Should have the balance sheet but not the other statements
    assert context.balance_sheet is not None
    assert context.income_statement is None  # Failed to load
    assert context.cash_flow is None  # Failed to load

    # Ratios should be limited without full data (only balance sheet ratios available)
    assert (
        len(context.key_ratios) <= 8
    )  # Only balance sheet ratios are possible; no profitability ratios
    assert context.metadata.get("data_quality") == DataQuality.LOW


def test_error_handling(broken_simfin_client):
    """Test error handling with a broken client."""
    service = FundamentalDataService(
        simfin_client=broken_simfin_client, repository=None, online_mode=True
    )

    # Should handle errors gracefully
    context = service.get_fundamental_context(
        "FAIL", "2024-01-01", "2024-12-31", force_refresh=True
    )

    assert context.symbol == "FAIL"
    assert context.balance_sheet is None
    assert context.income_statement is None
    assert context.cash_flow is None
    assert len(context.key_ratios) == 0
    assert context.metadata.get("data_quality") == DataQuality.LOW
    # The service logs errors but doesn't include them in metadata


def test_json_structure():
    """Test the JSON structure of the fundamental context."""
    mock_simfin = MockSimFinClient()
    service = FundamentalDataService(
        simfin_client=mock_simfin, repository=None, online_mode=True
    )

    context = service.get_fundamental_context("NVDA", "2024-01-01", "2024-12-31")
    json_data = context.model_dump()

    # Validate the required fields
    required_fields = [
        "symbol",
        "period",
        "balance_sheet",
        "income_statement",
        "cash_flow",
        "key_ratios",
        "metadata",
    ]
    for field in required_fields:
        assert field in json_data

    # Validate the financial statement structure
    if json_data["balance_sheet"]:
        balance_sheet = json_data["balance_sheet"]
        required_statement_fields = [
            "period",
            "report_date",
            "publish_date",
            "currency",
            "data",
        ]
        for field in required_statement_fields:
            assert field in balance_sheet

        # Check some key balance sheet items
        bs_data = balance_sheet["data"]
        assert "Total Assets" in bs_data
        assert "Total Liabilities" in bs_data
        assert "Total Shareholders Equity" in bs_data

    # Validate the key ratios
    ratios = json_data["key_ratios"]
    assert isinstance(ratios, dict)
    assert len(ratios) > 0

    # Validate metadata
    metadata = json_data["metadata"]
    assert "data_quality" in metadata
    assert "service" in metadata


def test_comprehensive_ratio_calculation():
    """Test comprehensive financial ratio calculations."""
    mock_simfin = MockSimFinClient()
    service = FundamentalDataService(
        simfin_client=mock_simfin, repository=None, online_mode=True
    )

    context = service.get_fundamental_context("COMP", "2024-01-01", "2024-12-31")
    ratios = context.key_ratios

    # Not all ratios may be calculable depending on the available data
    calculated_ratios = set(ratios.keys())
    core_ratios = {
        "current_ratio",
        "debt_to_equity",
        "gross_margin",
        "net_margin",
        "roe",
        "roa",
    }

    # At least the core ratios should be present
    assert core_ratios.issubset(calculated_ratios), (
        f"Missing core ratios: {core_ratios - calculated_ratios}"
    )

    # All ratio values should be numbers
    for ratio_name, ratio_value in ratios.items():
        assert isinstance(ratio_value, int | float), (
            f"{ratio_name} should be numeric, got {type(ratio_value)}"
        )
        assert ratio_value == ratio_value, (
            f"{ratio_name} should not be NaN"
        )  # NaN is the only value that is not equal to itself
@@ -7,11 +7,26 @@ from dataclasses import dataclass
from enum import Enum
from typing import Any

from tradingagents.config import TradingAgentsConfig

from .clients.finnhub_client import FinnhubClient

logger = logging.getLogger(__name__)


class InsiderDataRepository:
    """Simple repository for insider data - placeholder implementation."""

    def __init__(self, data_dir: str):
        self.data_dir = data_dir

    def get_data(self, symbol: str, start_date: str, end_date: str) -> dict:
        return {}

    def store_data(self, symbol: str, data: dict) -> bool:
        return True


class DataQuality(Enum):
    """Data quality levels for insider data."""


@@ -85,6 +100,12 @@ class InsiderDataService:
        self.client = client
        self.repository = repository

    @staticmethod
    def build(_config: TradingAgentsConfig):
        client = FinnhubClient("")
        repo = InsiderDataRepository("")
        return InsiderDataService(client, repo)

    def get_insider_sentiment_context(
        self,
        symbol: str,
@@ -3,90 +3,57 @@ Market data service that provides structured market context.
"""

import logging
from datetime import datetime
from typing import Any

import pandas as pd
import talib

from tradingagents.config import TradingAgentsConfig
from tradingagents.domains.marketdata.clients.yfinance_client import YFinanceClient
from tradingagents.domains.marketdata.models import (
    INDICATOR_DEFINITIONS,
    DataQuality,
    IndicatorConfig,
    IndicatorParamValue,
    IndicatorPresets,
    InputSpec,
    PriceDataContext,
    TAReportContext,
    TechnicalAnalysisError,
    TechnicalIndicatorData,
)
from tradingagents.domains.marketdata.repos.market_data_repository import (
    MarketDataRepository,
)

logger = logging.getLogger(__name__)


class MarketDataService:
    """Service for market data and technical indicators."""

    def __init__(
        self,
        yfin_client: YFinanceClient,
        repo: MarketDataRepository,
    ):
        """
        Initialize the market data service.

        Args:
            yfin_client: Client for live market data
            repo: Repository for historical market data
        """
        self.yfin_client = yfin_client
        self.repo = repo

    @staticmethod
    def build(_config: TradingAgentsConfig):
        client = YFinanceClient()
        repo = MarketDataRepository("")
        return MarketDataService(client, repo)

    def get_market_data_context(
        self, symbol: str, start_date: str, end_date: str
    ) -> PriceDataContext:

@@ -97,26 +64,106 @@ class MarketDataService:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format

        Returns:
            PriceDataContext: Focused price data context
        """
        try:
            # Convert the string dates to date objects
            start_date_obj = datetime.strptime(start_date, "%Y-%m-%d").date()
            end_date_obj = datetime.strptime(end_date, "%Y-%m-%d").date()

            # Try the repository first
            df = self.repo.get_market_data_df(symbol, start_date_obj, end_date_obj)
            source = "repository"

            if df.empty:
                # No data in the repository, so fetch from the client
                logger.info(f"No local data for {symbol}, fetching from client")
                client_data = self.yfin_client.get_data(symbol, start_date, end_date)
                fetched = client_data.get("data", [])
                source = "client"

                # Convert to a DataFrame and store it in the repository
                if fetched:
                    df = pd.DataFrame(fetched)
                    self.repo.store_marketdata(symbol, df)

            # Calculate metrics
            latest_price = 0.0
            price_change = 0.0
            price_change_percent = 0.0
            volume_info = {"average_volume": 0, "latest_volume": 0}

            if not df.empty and "Close" in df.columns:
                latest_price = float(df["Close"].iloc[-1])
                if len(df) > 1:
                    previous_price = float(df["Close"].iloc[-2])
                    price_change = latest_price - previous_price
                    price_change_percent = (
                        (price_change / previous_price) * 100
                        if previous_price != 0
                        else 0.0
                    )

            if not df.empty and "Volume" in df.columns:
                volume_info = {
                    "average_volume": int(df["Volume"].mean()),
                    "latest_volume": int(df["Volume"].iloc[-1]),
                }

            # Convert the DataFrame to a list of dicts for price_data
            price_data = df.to_dict("records") if not df.empty else []

            # Assess data quality
            data_quality = DataQuality.HIGH if len(price_data) > 0 else DataQuality.LOW

            metadata = {
                "data_quality": data_quality.value,
                "service": "market_data",
                "record_count": len(price_data),
                "source": source,
                "retrieved_at": datetime.utcnow().isoformat(),
            }

            return PriceDataContext(
                symbol=symbol,
                period={"start": start_date, "end": end_date},
                price_data=price_data,
                latest_price=latest_price,
                price_change=price_change,
                price_change_percent=price_change_percent,
                volume_info=volume_info,
                metadata=metadata,
            )

        except Exception as e:
            logger.error(f"Error getting market data context for {symbol}: {e}")
            return PriceDataContext(
                symbol=symbol,
                period={"start": start_date, "end": end_date},
                price_data=[],
                latest_price=0.0,
                price_change=0.0,
                price_change_percent=0.0,
                volume_info={"average_volume": 0, "latest_volume": 0},
                metadata={
                    "data_quality": DataQuality.LOW.value,
                    "service": "market_data",
                    "error": str(e),
                    "retrieved_at": datetime.utcnow().isoformat(),
                },
            )

    def get_ta_report_context(
        self,
        symbol: str,
        indicator: str,
        start_date: str,
        end_date: str,
        custom_params: dict[str, IndicatorParamValue] | None = None,
    ) -> TAReportContext:
        """
        Get a technical analysis report context for a specific indicator.

@@ -126,24 +173,548 @@ class MarketDataService:
            indicator: Technical indicator name (e.g., 'rsi', 'macd', 'sma')
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format

        Returns:
            TAReportContext: Focused technical analysis context
        """
        try:
            # Get the price data first
            price_context = self.get_market_data_context(symbol, start_date, end_date)

            if not price_context.price_data:
                # Create an empty indicator config for the no-data case
                no_data_config = IndicatorConfig(
                    name=indicator.upper(),
                    parameters={},
                    input_types=["close"],
                    output_format="single",
                    param_ranges={},
                    default_params={},
                    talib_function="",
                    description="",
                )

                return TAReportContext(
                    symbol=symbol,
                    period={"start": start_date, "end": end_date},
                    indicator=indicator,
                    indicator_data=[],
                    analysis_summary="No price data available for technical analysis",
                    signal_strength=0.0,
                    recommendation="HOLD",
                    indicator_config=no_data_config,
                    parameter_summary="",
                    metadata={
                        "data_quality": DataQuality.LOW.value,
                        "service": "technical_analysis",
                        "error": "no_price_data",
                    },
                )

            # Calculate the technical indicator using TA-Lib
            indicator_data = self._calculate_indicator_talib(
                price_context.price_data, indicator, custom_params
            )

            # Generate the analysis and recommendation
            signal_strength = self._calculate_signal_strength(indicator_data, indicator)
            recommendation = self._get_recommendation(signal_strength)
            analysis_summary = self._generate_analysis_summary(
                indicator, signal_strength, recommendation
            )

            # Create the indicator config from the calculation
            definition = INDICATOR_DEFINITIONS.get(indicator.upper(), {})
            indicator_config = IndicatorConfig(
                name=indicator.upper(),
                parameters=indicator_data[0].parameters if indicator_data else {},
                input_types=definition.get("input_types", ["close"]),
                output_format=definition.get("output_format", "single"),
                param_ranges=definition.get("param_ranges", {}),
                default_params=definition.get("default_params", {}),
                talib_function=definition.get("talib_function", ""),
                description=definition.get("description", ""),
            )

            # Generate the parameter summary
            params = indicator_data[0].parameters if indicator_data else {}
            parameter_summary = ", ".join(f"{k}={v}" for k, v in params.items())

            return TAReportContext(
                symbol=symbol,
                period={"start": start_date, "end": end_date},
                indicator=indicator,
                indicator_data=indicator_data,
                analysis_summary=analysis_summary,
                signal_strength=signal_strength,
                recommendation=recommendation,
                indicator_config=indicator_config,
                parameter_summary=parameter_summary,
                metadata={
                    "data_quality": DataQuality.HIGH.value,
                    "service": "technical_analysis",
                    "indicator_count": len(indicator_data),
                    "retrieved_at": datetime.utcnow().isoformat(),
                },
            )

        except Exception as e:
            logger.error(f"Error getting TA report for {symbol} {indicator}: {e}")
            # Create an empty indicator config for the error case
            error_config = IndicatorConfig(
                name=indicator.upper(),
                parameters={},
                input_types=["close"],
                output_format="single",
                param_ranges={},
                default_params={},
                talib_function="",
                description="",
            )

            return TAReportContext(
                symbol=symbol,
                period={"start": start_date, "end": end_date},
                indicator=indicator,
                indicator_data=[],
                analysis_summary=f"Error calculating {indicator}: {str(e)}",
                signal_strength=0.0,
                recommendation="HOLD",
                indicator_config=error_config,
                parameter_summary="",
                metadata={
                    "data_quality": DataQuality.LOW.value,
                    "service": "technical_analysis",
                    "error": str(e),
                },
            )

    def _validate_parameters(
        self, indicator: str, params: dict[str, IndicatorParamValue]
    ) -> None:
        """Validate indicator parameters against their defined ranges."""
        if indicator.upper() not in INDICATOR_DEFINITIONS:
            raise TechnicalAnalysisError(f"Unknown indicator: {indicator}")

        definition = INDICATOR_DEFINITIONS[indicator.upper()]
        param_ranges = definition.get("param_ranges", {})

        for param_name, value in params.items():
            if param_name in param_ranges:
                min_val, max_val = param_ranges[param_name]
                if not isinstance(value, int | float):
                    raise TechnicalAnalysisError(
                        f"Parameter {param_name} must be numeric"
                    )
                if not (min_val <= value <= max_val):
                    raise TechnicalAnalysisError(
                        f"Parameter {param_name}={value} out of range [{min_val}, {max_val}]"
                    )

    def _prepare_price_arrays(
        self, price_data: list[dict[str, Any]], input_types: list[InputSpec]
    ) -> dict[str, Any]:
        """Prepare numpy price arrays for the TA-Lib functions."""
        if not price_data:
            raise TechnicalAnalysisError("No price data provided")

        df = pd.DataFrame(price_data)
        required_columns = []

        for input_type in input_types:
            if input_type == "close":
                required_columns.extend(["Close"])
            elif input_type == "ohlc":
                required_columns.extend(["Open", "High", "Low", "Close"])
            elif input_type == "ohlcv":
                required_columns.extend(["Open", "High", "Low", "Close", "Volume"])
            elif input_type == "hl":
                required_columns.extend(["High", "Low"])

        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise TechnicalAnalysisError(f"Missing required columns: {missing_columns}")

        # Convert to numpy arrays for TA-Lib
        arrays = {}
        if "Open" in df.columns:
            arrays["open"] = df["Open"].astype(float).values
        if "High" in df.columns:
            arrays["high"] = df["High"].astype(float).values
        if "Low" in df.columns:
            arrays["low"] = df["Low"].astype(float).values
        if "Close" in df.columns:
            arrays["close"] = df["Close"].astype(float).values
        if "Volume" in df.columns:
            arrays["volume"] = df["Volume"].astype(float).values

        arrays["dates"] = df["Date"].astype(str).values

        return arrays
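
    # Shape sketch (illustrative, not part of the API): for an "ohlcv" input
    # spec and two rows of price data, the returned dict looks roughly like:
    #
    #   {
    #       "open":   array([189.2, 190.1]),
    #       "high":   array([191.0, 192.3]),
    #       "low":    array([188.5, 189.7]),
    #       "close":  array([190.4, 191.8]),
    #       "volume": array([4.5e7, 5.1e7]),
    #       "dates":  array(["2024-01-02 00:00:00", "2024-01-03 00:00:00"]),
    #   }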

    def _calculate_indicator_talib(
        self,
        price_data: list[dict[str, Any]],
        indicator: str,
        params: dict[str, IndicatorParamValue] | None = None,
    ) -> list[TechnicalIndicatorData]:
        """Calculate a technical indicator using TA-Lib."""
        if not price_data:
            return []

        # Get the indicator definition
        indicator_upper = indicator.upper()
        if indicator_upper not in INDICATOR_DEFINITIONS:
            raise TechnicalAnalysisError(f"Unknown indicator: {indicator}")

        definition = INDICATOR_DEFINITIONS[indicator_upper]

        # Start from the defaults and merge in any provided overrides
        final_params: dict[str, IndicatorParamValue] = definition["default_params"].copy()
        if params is not None:
            final_params.update(params)

        # Validate the parameters
        self._validate_parameters(indicator, final_params)

        # Prepare the price arrays
        arrays = self._prepare_price_arrays(price_data, definition["input_types"])

        # Resolve the TA-Lib function (e.g., "talib.RSI" -> RSI)
        talib_func_name = definition["talib_function"].split(".")[-1]
        talib_func = getattr(talib, talib_func_name)

        # Prepare the function arguments: price arrays positionally,
        # parameters as keyword arguments
        func_args = []
        func_kwargs = dict(final_params)

        for input_type in definition["input_types"]:
            if input_type == "close":
                func_args.append(arrays["close"])
            elif input_type == "ohlc":
                func_args.extend([arrays["high"], arrays["low"], arrays["close"]])
            elif input_type == "ohlcv":
                func_args.extend(
                    [arrays["high"], arrays["low"], arrays["close"], arrays["volume"]]
                )
            elif input_type == "hl":
                func_args.extend([arrays["high"], arrays["low"]])

        # Calculate the indicator
        try:
            ta_result = talib_func(*func_args, **func_kwargs)
        except Exception as e:
            raise TechnicalAnalysisError(
                f"TA-Lib calculation failed for {indicator}: {str(e)}"
            ) from e

        # Process the results based on the output format
        result = []
        dates = arrays["dates"]
        output_format = definition["output_format"]

        if output_format == "single":
            # One output array
            for date, value in zip(dates, ta_result, strict=False):
                if not pd.isna(value):
                    result.append(
                        TechnicalIndicatorData(
                            date=date,
                            value=float(value),
                            indicator_type=indicator.lower(),
                            parameters=final_params,
                        )
                    )

        elif output_format == "double":
            # Two output arrays (e.g., STOCH, AROON)
            for date, val1, val2 in zip(
                dates, ta_result[0], ta_result[1], strict=False
            ):
                if not pd.isna(val1) and not pd.isna(val2):
                    # Name the outputs based on the indicator
                    if indicator_upper == "STOCH":
                        value_dict = {"slowk": float(val1), "slowd": float(val2)}
                    elif indicator_upper == "AROON":
                        value_dict = {"aroondown": float(val1), "aroonup": float(val2)}
                    else:
                        value_dict = {"output1": float(val1), "output2": float(val2)}

                    result.append(
                        TechnicalIndicatorData(
                            date=date,
                            value=value_dict,
                            indicator_type=indicator.lower(),
                            parameters=final_params,
                        )
                    )

        elif output_format == "triple":
            # Three output arrays (e.g., MACD, BBANDS)
            for date, val1, val2, val3 in zip(
                dates, ta_result[0], ta_result[1], ta_result[2], strict=False
            ):
                if not pd.isna(val1):
                    # Name the outputs based on the indicator
                    if indicator_upper == "MACD":
                        value_dict = {
                            "macd": float(val1),
                            "signal": float(val2) if not pd.isna(val2) else 0.0,
                            "histogram": float(val3) if not pd.isna(val3) else 0.0,
                        }
                    elif indicator_upper == "BBANDS":
                        value_dict = {
                            "upper": float(val1),
                            "middle": float(val2) if not pd.isna(val2) else 0.0,
                            "lower": float(val3) if not pd.isna(val3) else 0.0,
                        }
                    else:
                        value_dict = {
                            "output1": float(val1),
                            "output2": float(val2) if not pd.isna(val2) else 0.0,
                            "output3": float(val3) if not pd.isna(val3) else 0.0,
                        }

                    result.append(
                        TechnicalIndicatorData(
                            date=date,
                            value=value_dict,
                            indicator_type=indicator.lower(),
                            parameters=final_params,
                        )
                    )

        return result

    def calculate_indicator(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        indicator: str | dict[str, IndicatorParamValue],
        params: dict[str, IndicatorParamValue] | None = None,
    ) -> TAReportContext:
        """
        Three-tier API for technical indicator calculation.

        Usage:
            1. String: calculate_indicator("AAPL", "2024-01-01", "2024-01-31", "RSI")
            2. Preset: calculate_indicator("AAPL", "2024-01-01", "2024-01-31", "RSI_SCALPING")
            3. Custom: calculate_indicator("AAPL", "2024-01-01", "2024-01-31", "RSI", {"timeperiod": 21})

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            indicator: Indicator name (string), preset name, or custom config dict
            params: Optional custom parameters (for string indicators only)

        Returns:
            TAReportContext: Complete technical analysis context
        """
        if isinstance(indicator, dict):
            # A custom configuration was provided as a dict
            if "name" not in indicator:
                raise TechnicalAnalysisError(
                    "Custom indicator dict must contain 'name' field"
                )
            indicator_name = str(indicator["name"])  # Ensure it's a string
            custom_params = {k: v for k, v in indicator.items() if k != "name"}
            return self.get_ta_report_context(
                symbol, indicator_name, start_date, end_date, custom_params
            )

        # Check whether it's a preset name
        all_presets = IndicatorPresets.get_all_presets()
        if indicator in all_presets:
            # Extract the indicator name and parameters from the preset
            preset_params = all_presets[indicator]
            # Determine the base indicator from the preset name
            for base_indicator in INDICATOR_DEFINITIONS:
                if indicator.startswith(base_indicator):
                    return self.get_ta_report_context(
                        symbol,
                        base_indicator.lower(),
                        start_date,
                        end_date,
                        preset_params,
                    )

            # If no match was found, fall back to the preset name's prefix
            indicator_name = indicator.split("_")[0].lower()
            return self.get_ta_report_context(
                symbol, indicator_name, start_date, end_date, preset_params
            )

        # A regular indicator name (string)
        return self.get_ta_report_context(
            symbol, indicator, start_date, end_date, params
        )
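
    # Usage sketch for the three tiers (illustrative; assumes a wired-up
    # service and that "RSI_SCALPING" exists among the presets, as in the
    # docstring above):
    #
    #   service = MarketDataService.build(TradingAgentsConfig())
    #   by_name = service.calculate_indicator("AAPL", "2024-01-01", "2024-03-31", "RSI")
    #   by_preset = service.calculate_indicator("AAPL", "2024-01-01", "2024-03-31", "RSI_SCALPING")
    #   custom = service.calculate_indicator(
    #       "AAPL", "2024-01-01", "2024-03-31", "RSI", {"timeperiod": 21}
    #   )
    #   print(custom.recommendation, custom.signal_strength)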

    def get_available_indicators(self) -> dict[str, str]:
        """Get all available indicators with their descriptions."""
        return {
            name: info["description"] for name, info in INDICATOR_DEFINITIONS.items()
        }

    def get_available_presets(
        self, style: str | None = None
    ) -> dict[str, dict[str, IndicatorParamValue]]:
        """
        Get the available indicator presets.

        Args:
            style: Optional trading style filter ("scalping", "day_trading", "swing", "position")

        Returns:
            Dict of preset names to parameter configurations
        """
        if style:
            return IndicatorPresets.get_preset_for_style(style)
        return IndicatorPresets.get_all_presets()

    def get_indicator_info(self, indicator: str) -> IndicatorConfig:
        """
        Get detailed information about a specific indicator.

        Args:
            indicator: Indicator name

        Returns:
            IndicatorConfig with the full indicator specification
        """
        indicator_upper = indicator.upper()
        if indicator_upper not in INDICATOR_DEFINITIONS:
            raise TechnicalAnalysisError(f"Unknown indicator: {indicator}")

        definition = INDICATOR_DEFINITIONS[indicator_upper]
        return IndicatorConfig(
            name=indicator_upper,
            parameters=definition["default_params"],
            input_types=definition["input_types"],
            output_format=definition["output_format"],
            param_ranges=definition["param_ranges"],
            default_params=definition["default_params"],
            talib_function=definition["talib_function"],
            description=definition["description"],
        )

    def _calculate_signal_strength(
        self, indicator_data: list[TechnicalIndicatorData], indicator: str
    ) -> float:
        """Calculate the signal strength from indicator data."""
        if not indicator_data:
            return 0.0

        latest = indicator_data[-1]

        if indicator.lower() == "rsi":
            rsi_value = latest.value
            if isinstance(rsi_value, int | float):
                if rsi_value > 70:
                    return -0.8  # Overbought - sell signal
                elif rsi_value < 30:
                    return 0.8  # Oversold - buy signal
                else:
                    return (50 - rsi_value) / 50  # Normalized between -1 and 1

        elif indicator.lower() == "macd":
            if isinstance(latest.value, dict):
                macd_val = latest.value.get("macd", 0)
                signal_val = latest.value.get("signal", 0)
                if macd_val > signal_val:
                    return 0.6  # Bullish
                else:
                    return -0.6  # Bearish

        elif indicator.lower() == "sma":
            # Would need the current price to compare against the SMA
            return 0.0  # Neutral for now

        return 0.0

    def _get_recommendation(self, signal_strength: float) -> str:
        """Convert a signal strength to a recommendation."""
        if signal_strength > 0.5:
            return "BUY"
        elif signal_strength < -0.5:
            return "SELL"
        else:
            return "HOLD"

    def _generate_analysis_summary(
        self, indicator: str, signal_strength: float, recommendation: str
    ) -> str:
        """Generate a human-readable analysis summary."""
        strength_desc = (
            "strong"
            if abs(signal_strength) > 0.7
            else "moderate"
            if abs(signal_strength) > 0.3
            else "weak"
        )
        direction = (
            "bullish"
            if signal_strength > 0
            else "bearish"
            if signal_strength < 0
            else "neutral"
        )

        return f"{indicator.upper()} indicator shows a {strength_desc} {direction} signal. Signal strength: {signal_strength:.2f}. Recommendation: {recommendation}."

    def update_market_data(self, symbol: str, start_date: str, end_date: str):
        """
        Update market data by fetching fresh data from the client and storing it in the repository.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
        """
        try:
            logger.info(
                f"Updating market data for {symbol} from {start_date} to {end_date}"
            )

            # Fetch fresh data from the client
            client_data = self.yfin_client.get_data(symbol, start_date, end_date)
            price_data = client_data.get("data", [])

            if price_data:
                # Convert to a DataFrame
                df = pd.DataFrame(price_data)

                # Store it in the repository
                self.repo.store_marketdata(symbol, df)
                logger.info(
                    f"Successfully stored {len(price_data)} records for {symbol}"
                )
            else:
                logger.warning(f"No data received for {symbol}")

        except Exception as e:
            logger.error(f"Error updating market data for {symbol}: {e}")
            raise

@@ -1,546 +0,0 @@
#!/usr/bin/env python3
"""
Test MarketDataService with a mock YFinanceClient and a real MarketDataRepository.
"""

import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any

# Add the project root to the path
sys.path.insert(0, os.path.abspath("."))

from tradingagents.clients.base import BaseClient
from tradingagents.domains.marketdata.market_data_service import (
    DataQuality,
    MarketDataContext,
    MarketDataService,
)
from tradingagents.repositories.market_data_repository import MarketDataRepository


class MockYFinanceClient(BaseClient):
    """Mock Yahoo Finance client that returns predictable test data."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.connection_works = True

    def test_connection(self) -> bool:
        return self.connection_works

    def get_data(
        self, symbol: str, start_date: str, end_date: str, **kwargs
    ) -> dict[str, Any]:
        """Return realistic mock market data."""
        # Generate realistic price data
        base_price = {"AAPL": 180.0, "TSLA": 250.0, "MSFT": 400.0}.get(symbol, 100.0)

        mock_data = []
        current_date = datetime.strptime(start_date, "%Y-%m-%d")
        end_date_dt = datetime.strptime(end_date, "%Y-%m-%d")

        price = base_price
        while current_date <= end_date_dt:
            # Simulate deterministic pseudo-random price movement
            price_change = (
                hash(current_date.strftime("%Y-%m-%d")) % 10 - 5
            ) / 100  # -5% to +4%
            price *= 1 + price_change * 0.01

            mock_data.append(
                {
                    "Date": current_date.strftime("%Y-%m-%d %H:%M:%S"),
                    "Open": round(price * 0.99, 2),
                    "High": round(price * 1.02, 2),
                    "Low": round(price * 0.98, 2),
                    "Close": round(price, 2),
                    "Adj Close": round(price, 2),
                    "Volume": 45000000 + (hash(symbol) % 20000000),
                }
            )

            current_date += timedelta(days=1)

        return {
            "symbol": symbol,
            "period": {"start": start_date, "end": end_date},
            "data": mock_data,
            "metadata": {
                "source": "mock_yahoo_finance",
                "record_count": len(mock_data),
                "columns": [
                    "Date",
                    "Open",
                    "High",
                    "Low",
                    "Close",
                    "Adj Close",
                    "Volume",
                ],
                "retrieved_at": datetime.utcnow().isoformat(),
            },
        }


def test_online_mode_with_mock_client():
    """Test MarketDataService in online mode with a mock client."""
    print("📈 Testing MarketDataService - Online Mode")

    # Create a mock client and a real repository
    mock_client = MockYFinanceClient()
    real_repo = MarketDataRepository("test_data")

    # Create the service in online mode
    service = MarketDataService(
        client=mock_client, repository=real_repo, online_mode=True, data_dir="test_data"
    )

    try:
        # Test the basic price context
        context = service.get_price_context(
            symbol="AAPL", start_date="2024-01-01", end_date="2024-01-05"
        )

        print(f"✅ Price context created: {context.__class__.__name__}")
        print(f"   Symbol: {context.symbol}")
        print(f"   Period: {context.period}")
        print(f"   Price data records: {len(context.price_data)}")
        print(f"   Technical indicators: {len(context.technical_indicators)}")

        # Validate the required fields
        assert context.symbol == "AAPL"
        assert context.period["start"] == "2024-01-01"
        assert context.period["end"] == "2024-01-05"
        assert len(context.price_data) > 0
        assert "data_quality" in context.metadata

        print("✅ Basic validation passed")

        # Test JSON serialization
        json_output = context.model_dump_json(indent=2)
        parsed = json.loads(json_output)

        print(f"✅ JSON serialization: {len(json_output)} characters")
        print(f"   Top-level keys: {list(parsed.keys())}")

        # Test with technical indicators
        context_with_indicators = service.get_context(
            symbol="TSLA",
            start_date="2024-01-01",
            end_date="2024-01-03",
            indicators=["rsi", "macd"],
        )

        print("✅ Context with indicators created")
        print("   Requested indicators: ['rsi', 'macd']")
        print(
            f"   Available indicators: {list(context_with_indicators.technical_indicators.keys())}"
        )

        return True

    except Exception as e:
        print(f"❌ Online mode test failed: {e}")
        return False


def test_offline_mode_with_real_repository():
    """Test MarketDataService in offline mode with a real repository."""
    print("\n💾 Testing MarketDataService - Offline Mode")

    # Create the service in offline mode (no client)
    real_repo = MarketDataRepository("test_data")
    service = MarketDataService(
        client=None, repository=real_repo, online_mode=False, data_dir="test_data"
    )

    try:
        # Test the offline context (will likely return empty data)
        context = service.get_price_context(
            symbol="AAPL", start_date="2024-01-01", end_date="2024-01-05"
        )

        print(f"✅ Offline context created: {context.__class__.__name__}")
        print(f"   Symbol: {context.symbol}")
        print(f"   Price data records: {len(context.price_data)}")
        print(f"   Data quality: {context.metadata.get('data_quality')}")
        print(f"   Service mode: online={service.is_online()}")

        # Should handle empty data gracefully
        assert context.symbol == "AAPL"
        assert isinstance(context.price_data, list)
        assert "data_quality" in context.metadata

        print("✅ Offline mode graceful handling verified")

        return True

    except Exception as e:
        print(f"❌ Offline mode test failed: {e}")
        return False


def test_error_handling():
    """Test error handling scenarios."""
    print("\n⚠️ Testing Error Handling")

    # Test with a broken client
    class BrokenClient(BaseClient):
        def test_connection(self):
            return False

        def get_data(self, *args, **kwargs):
            raise Exception("Simulated client failure")

    broken_client = BrokenClient()
    real_repo = MarketDataRepository("test_data")

    service = MarketDataService(
        client=broken_client,
        repository=real_repo,
        online_mode=True,  # Online mode, but the client will fail
        data_dir="test_data",
    )

    try:
        context = service.get_price_context("AAPL", "2024-01-01", "2024-01-05")

        print("✅ Error handling worked")
        print(f"   Symbol: {context.symbol}")
        print(f"   Price data records: {len(context.price_data)}")
        print(f"   Data quality: {context.metadata.get('data_quality')}")

        # Should fall back to the repository or return empty data
        assert context.symbol == "AAPL"
        assert isinstance(context.price_data, list)

        return True

    except Exception as e:
        print(f"❌ Error handling test failed: {e}")
        return False


def test_data_quality_assessment():
    """Test the data quality determination logic."""
    print("\n🔍 Testing Data Quality Assessment")

    mock_client = MockYFinanceClient()
    real_repo = MarketDataRepository("test_data")

    service = MarketDataService(
        client=mock_client, repository=real_repo, online_mode=True, data_dir="test_data"
    )

    try:
        # Test with good data
        context = service.get_context("AAPL", "2024-01-01", "2024-01-10")

        data_quality = context.metadata.get("data_quality")
        print(f"✅ Data quality assessment: {data_quality}")
        print(f"   Records: {len(context.price_data)}")
        print(f"   Online mode: {service.is_online()}")

        # Should be medium or high quality for mock data
        assert data_quality in [DataQuality.MEDIUM, DataQuality.HIGH]

        return True

    except Exception as e:
        print(f"❌ Data quality test failed: {e}")
        return False


def test_json_structure_validation():
    """Test detailed JSON structure validation."""
    print("\n📄 Testing JSON Structure")

    mock_client = MockYFinanceClient()
    service = MarketDataService(client=mock_client, repository=None, online_mode=True)

    try:
        context = service.get_price_context("MSFT", "2024-01-01", "2024-01-03")
        json_str = context.model_dump_json(indent=2)
        data = json.loads(json_str)

        # Validate the required structure
        required_fields = [
            "symbol",
            "period",
            "price_data",
            "technical_indicators",
            "metadata",
        ]
        for field in required_fields:
            assert field in data, f"Missing field: {field}"

        # Validate the period structure
        period = data["period"]
        assert "start" in period and "end" in period

        # Validate the price data structure
        assert isinstance(data["price_data"], list)
        if data["price_data"]:
            first_record = data["price_data"][0]
            required_price_fields = ["Date", "Open", "High", "Low", "Close", "Volume"]
            for field in required_price_fields:
                assert field in first_record, f"Missing price field: {field}"

        # Validate metadata
        metadata = data["metadata"]
        assert "data_quality" in metadata
        assert "service" in metadata

        print("✅ JSON structure validation passed")
        print(f"   Fields: {list(data.keys())}")
        print(f"   Price records: {len(data['price_data'])}")
        print(f"   Metadata keys: {list(metadata.keys())}")

        return True

    except Exception as e:
        print(f"❌ JSON structure test failed: {e}")
        return False


def test_force_refresh_parameter():
    """Test the force_refresh parameter functionality."""
    try:
        mock_client = MockYFinanceClient()
        real_repo = MarketDataRepository("test_data")

        service = MarketDataService(
            client=mock_client, repository=real_repo, online_mode=True
        )

        # Normal flow (should use the repository if data is available)
        normal_context = service.get_context(
            "AAPL", "2024-01-01", "2024-01-31", force_refresh=False
        )

        # Force refresh (should bypass the repository and use the client)
        refresh_context = service.get_context(
            "AAPL", "2024-01-01", "2024-01-31", force_refresh=True
        )

        # Both should return valid contexts
        assert isinstance(normal_context, MarketDataContext)
        assert isinstance(refresh_context, MarketDataContext)
        assert normal_context.symbol == "AAPL"
        assert refresh_context.symbol == "AAPL"

        # Check that the metadata indicates the source
        refresh_metadata = refresh_context.metadata
        assert "force_refresh" in refresh_metadata
        assert refresh_metadata["force_refresh"]

        print("✅ Force refresh parameter test passed")
        return True

    except Exception as e:
        print(f"❌ Force refresh test failed: {e}")
        return False


def test_local_first_strategy():
    """Test that the service checks local data first when available."""
    try:

        class MockRepositoryWithData(MarketDataRepository):
            def has_data_for_period(
                self, identifier: str, start_date: str, end_date: str, **kwargs
            ) -> bool:
                return True  # Pretend we have the data

            def get_data(
                self, symbol: str, start_date: str, end_date: str, **kwargs
            ) -> dict[str, Any]:
                return {
                    "symbol": kwargs.get("symbol", "TEST"),
                    "data": [
                        {"date": "2024-01-01", "close": 150.0},
                        {"date": "2024-01-02", "close": 151.0},
                    ],
                    "metadata": {"source": "test_repository"},
                }

        mock_client = MockYFinanceClient()
        mock_repo = MockRepositoryWithData("test_data")

        service = MarketDataService(
            client=mock_client, repository=mock_repo, online_mode=True
        )

        # Should use local data since has_data_for_period returns True
        context = service.get_context("TEST", "2024-01-01", "2024-01-31")

        # Verify that we used local data
        assert context.metadata.get("price_data_source") == "local_cache"
        assert len(context.price_data) == 2  # From the mock repository

        print("✅ Local-first strategy test passed")
        return True

    except Exception as e:
        print(f"❌ Local-first strategy test failed: {e}")
        return False


def test_local_first_fallback_to_api():
    """Test that the service falls back to the API when local data is insufficient."""
    try:

        class MockRepositoryWithoutData(MarketDataRepository):
            def has_data_for_period(
                self, identifier: str, start_date: str, end_date: str, **kwargs
            ) -> bool:
                return False  # Pretend we don't have the data

            def get_data(
                self, symbol: str, start_date: str, end_date: str, **kwargs
            ) -> dict[str, Any]:
                return {
                    "symbol": kwargs.get("symbol", "TEST"),
                    "data": [],
                    "metadata": {},
                }

            def store_data(
                self,
                symbol: str,
                data: dict[str, Any],
                overwrite: bool = False,
                **kwargs,
            ) -> bool:
                return True  # Pretend the storage was successful

        mock_client = MockYFinanceClient()
        mock_repo = MockRepositoryWithoutData("test_data")

        service = MarketDataService(
            client=mock_client, repository=mock_repo, online_mode=True
        )

        # Should fall back to the API since the repository doesn't have the data
        context = service.get_context("TEST", "2024-01-01", "2024-01-31")

        # Verify that we used API data
        assert context.metadata.get("price_data_source") == "live_api"
        assert len(context.price_data) > 0  # From the mock client

        print("✅ Local-first fallback to API test passed")
        return True

    except Exception as e:
        print(f"❌ Local-first fallback test failed: {e}")
        return False


def test_force_refresh_bypasses_local_data():
    """Test that force_refresh=True bypasses local data even when it is available."""
    try:

        class MockRepositoryAlwaysHasData(MarketDataRepository):
            def has_data_for_period(
                self, identifier: str, start_date: str, end_date: str, **kwargs
            ) -> bool:
                return True  # Always claim that we have data

            def get_data(
                self, symbol: str, start_date: str, end_date: str, **kwargs
            ) -> dict[str, Any]:
                return {
                    "symbol": kwargs.get("symbol", "TEST"),
                    "data": [
                        {"date": "2024-01-01", "close": 100.0}
                    ],  # Different from the client
                    "metadata": {"source": "local"},
                }

            def clear_data(
                self, symbol: str, start_date: str, end_date: str, **kwargs
            ) -> bool:
                return True

            def store_data(
                self,
                symbol: str,
                data: dict[str, Any],
                overwrite: bool = False,
                **kwargs,
            ) -> bool:
                return True

        mock_client = MockYFinanceClient()
        mock_repo = MockRepositoryAlwaysHasData("test_data")

        service = MarketDataService(
            client=mock_client, repository=mock_repo, online_mode=True
        )

        # Force refresh should bypass the local data
        context = service.get_context(
            "TEST", "2024-01-01", "2024-01-31", force_refresh=True
        )

        # Verify that we used API data (force refresh)
        assert context.metadata.get("price_data_source") == "live_api_refresh"
        assert context.metadata.get("force_refresh")
        # Should have more data from the client than the single point from the repository
        assert len(context.price_data) > 1

        print("✅ Force refresh bypasses local data test passed")
        return True

    except Exception as e:
        print(f"❌ Force refresh bypass test failed: {e}")
        return False


def main():
    """Run all MarketDataService tests."""
    print("🧪 Testing MarketDataService\n")

    tests = [
        test_online_mode_with_mock_client,
        test_offline_mode_with_real_repository,
        test_error_handling,
        test_data_quality_assessment,
        test_json_structure_validation,
        test_force_refresh_parameter,
        test_local_first_strategy,
        test_local_first_fallback_to_api,
        test_force_refresh_bypasses_local_data,
    ]

    passed = 0
    failed = 0

    for test in tests:
        try:
            if test():
                passed += 1
            else:
                failed += 1
        except Exception as e:
            print(f"❌ Test {test.__name__} crashed: {e}")
            failed += 1

    print("\n📊 MarketDataService Test Results:")
    print(f"   ✅ Passed: {passed}")
    print(f"   ❌ Failed: {failed}")

    if failed == 0:
        print("🎉 All MarketDataService tests passed!")
    else:
        print("⚠️ Some tests failed - check the output above")

    return failed == 0


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)

@@ -0,0 +1,662 @@
"""
|
||||
Market data models and type definitions for technical analysis.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import date
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
from proto.message import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Proper type definitions to eliminate Any types
IndicatorParamValue = int | float | str | bool
InputSpec = Literal["close", "ohlc", "ohlcv", "hl"]
OutputSpec = Literal["single", "double", "triple"]
ParamRanges = dict[str, tuple[int | float, int | float]]


class IndicatorConfig(BaseModel):
    """Configuration for technical indicators with proper typing."""

    name: str
    parameters: dict[str, IndicatorParamValue]  # No more Any
    input_types: list[InputSpec]  # Specific requirements
    output_format: OutputSpec  # Precise specification
    param_ranges: ParamRanges  # Type-safe validation
    default_params: dict[str, IndicatorParamValue]
    talib_function: str  # Direct TA-Lib function name
    description: str


class TechnicalAnalysisError(Exception):
    """Clear, actionable error messages for TA-Lib issues."""
class DataQuality(Enum):
    """Data quality levels for market data."""

    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"


class TechnicalIndicatorData(BaseModel):
    """Technical indicator data point with proper typing."""

    date: str
    value: float | dict[str, float]  # No more Any
    indicator_type: str
    parameters: dict[str, IndicatorParamValue]  # Parameter context
    confidence: float = 0.0  # Signal confidence
    source: str = "talib"  # Always TA-Lib


class MarketDataContext(BaseModel):
    """Market data context for trading analysis."""

    symbol: str
    period: dict[str, str]  # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
    price_data: list[dict[str, Any]]
    technical_indicators: dict[str, list[TechnicalIndicatorData]]
    metadata: dict[str, Any]


class TAReportContext(BaseModel):
    """Technical Analysis Report context with enhanced configuration."""

    symbol: str
    period: dict[str, str]  # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
    indicator: str
    indicator_data: list[TechnicalIndicatorData]
    analysis_summary: str
    signal_strength: float  # -1.0 to 1.0
    recommendation: str  # "BUY", "SELL", "HOLD"
    indicator_config: IndicatorConfig  # Full config used
    parameter_summary: str  # Human-readable params
    metadata: dict[str, IndicatorParamValue]  # Properly typed


class PriceDataContext(BaseModel):
    """Price Data context for historical price information."""

    symbol: str
    period: dict[str, str]  # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
    price_data: list[dict[str, Any]]
    latest_price: float
    price_change: float
    price_change_percent: float
    volume_info: dict[str, Any]
    metadata: dict[str, Any]
# Fundamental Data Models
class FinancialRatio(BaseModel):
    """Financial ratio with calculation metadata."""

    name: str
    value: float | None
    formula: str
    category: str  # "profitability", "liquidity", "leverage", "efficiency"
    interpretation: str


class BalanceSheetData(BaseModel):
    """Balance sheet line items."""

    date: str
    total_assets: float | None = None
    current_assets: float | None = None
    cash_and_equivalents: float | None = None
    accounts_receivable: float | None = None
    inventory: float | None = None
    total_liabilities: float | None = None
    current_liabilities: float | None = None
    accounts_payable: float | None = None
    short_term_debt: float | None = None
    long_term_debt: float | None = None
    total_equity: float | None = None
    retained_earnings: float | None = None


class IncomeStatementData(BaseModel):
    """Income statement line items."""

    date: str
    revenue: float | None = None
    gross_profit: float | None = None
    operating_income: float | None = None
    net_income: float | None = None
    ebitda: float | None = None
    cost_of_revenue: float | None = None
    operating_expenses: float | None = None
    interest_expense: float | None = None
    tax_expense: float | None = None
    shares_outstanding: float | None = None
    eps: float | None = None


class CashFlowData(BaseModel):
    """Cash flow statement line items."""

    date: str
    operating_cash_flow: float | None = None
    investing_cash_flow: float | None = None
    financing_cash_flow: float | None = None
    free_cash_flow: float | None = None
    capital_expenditures: float | None = None
    dividends_paid: float | None = None
    stock_repurchases: float | None = None


class BalanceSheetContext(BaseModel):
    """Balance sheet context for fundamental analysis."""

    symbol: str
    start_date: date
    end_date: date
    balance_sheet_data: list[BalanceSheetData]
    key_ratios: list[FinancialRatio]
    data_quality: DataQuality
    source: str
    metadata: dict[str, str]


class IncomeStatementContext(BaseModel):
    """Income statement context for fundamental analysis."""

    symbol: str
    start_date: date
    end_date: date
    income_statement_data: list[IncomeStatementData]
    key_ratios: list[FinancialRatio]
    data_quality: DataQuality
    source: str
    metadata: dict[str, Any]


class CashFlowContext(BaseModel):
    """Cash flow context for fundamental analysis."""

    symbol: str
    start_date: date
    end_date: date
    cash_flow_data: list[CashFlowData]
    key_ratios: list[FinancialRatio]
    data_quality: DataQuality
    source: str
    metadata: dict[str, Any]
class FundamentalContext(BaseModel):
    """Comprehensive fundamental analysis context."""

    symbol: str
    start_date: date
    end_date: date
    balance_sheet: BalanceSheetContext | None = None
    income_statement: IncomeStatementContext | None = None
    cash_flow: CashFlowContext | None = None
    comprehensive_ratios: list[FinancialRatio] | None = None
    valuation_metrics: dict[str, float | None] | None = None
    financial_health_score: float = 0.0  # 0-100 composite score
    data_quality: DataQuality = DataQuality.LOW
    source: str | None = None
    metadata: dict[str, str] | None = None
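The composite `financial_health_score` is declared above but never derived in this file. One plausible scoring rule, purely as an illustration — the category weights and the clamp-to-[0, 1] step are assumptions, not project code:

```python
# Illustrative only: fold available ratios into a 0-100 health score.
def score_financial_health(ratios: list[FinancialRatio]) -> float:
    weights = {"profitability": 0.4, "liquidity": 0.3, "leverage": 0.3}
    score = 0.0
    for category, weight in weights.items():
        values = [
            r.value for r in ratios if r.category == category and r.value is not None
        ]
        if values:
            # Clamp each category's mean into [0, 1] before weighting (assumed scale).
            mean = sum(values) / len(values)
            score += weight * max(0.0, min(mean, 1.0))
    return round(score * 100, 1)
```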
# Reported Financials Models (for Finnhub API responses)
@dataclass
class FinancialLineItem:
    """Individual financial statement line item from reported financials."""

    concept: str
    unit: str
    label: str
    value: float | int


@dataclass
class ReportedFinancialsData:
    """Financial statements data containing balance sheet, income statement, and cash flow."""

    bs: list[FinancialLineItem]  # Balance Sheet
    ic: list[FinancialLineItem]  # Income Statement
    cf: list[FinancialLineItem]  # Cash Flow Statement


@dataclass
class ReportedFinancialsResponse:
    """Complete response from Finnhub reported financials API."""

    start_date: str
    end_date: str
    year: int
    quarter: int
    access_number: str
    data: ReportedFinancialsData

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ReportedFinancialsResponse":
        """Create ReportedFinancialsResponse from API response dictionary."""
        financial_data = data.get("data", {})

        # Convert line items for each statement type
        bs_items = [
            FinancialLineItem(
                concept=item["concept"],
                unit=item["unit"],
                label=item["label"],
                value=item["value"],
            )
            for item in financial_data.get("bs", [])
        ]

        ic_items = [
            FinancialLineItem(
                concept=item["concept"],
                unit=item["unit"],
                label=item["label"],
                value=item["value"],
            )
            for item in financial_data.get("ic", [])
        ]

        cf_items = [
            FinancialLineItem(
                concept=item["concept"],
                unit=item["unit"],
                label=item["label"],
                value=item["value"],
            )
            for item in financial_data.get("cf", [])
        ]

        return cls(
            start_date=data["start_date"],
            end_date=data["end_date"],
            year=data["year"],
            quarter=data["quarter"],
            access_number=data["access_number"],
            data=ReportedFinancialsData(bs=bs_items, ic=ic_items, cf=cf_items),
        )
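A quick usage sketch for the parser above. The payload uses the snake_case keys this `from_dict` reads; whether the upstream Finnhub response arrives pre-normalized to those names is an assumption here, and the values are made up:

```python
# Sample payload shaped the way from_dict expects; values are illustrative.
sample = {
    "start_date": "2024-01-01",
    "end_date": "2024-03-31",
    "year": 2024,
    "quarter": 1,
    "access_number": "0000320193-24-000001",
    "data": {
        "bs": [
            {
                "concept": "us-gaap_Assets",
                "unit": "usd",
                "label": "Total assets",
                "value": 3.5e11,
            }
        ],
        "ic": [],
        "cf": [],
    },
}

report = ReportedFinancialsResponse.from_dict(sample)
assert report.quarter == 1
assert report.data.bs[0].label == "Total assets"
```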
# Insider Transactions Models
@dataclass
class InsiderTransaction:
    """Individual insider transaction record."""

    name: str
    share: int
    change: int
    filing_date: str
    transaction_date: str
    transaction_code: str
    transaction_price: float


@dataclass
class InsiderTransactionsResponse:
    """Complete response from Finnhub insider transactions API."""

    data: list[InsiderTransaction]
    symbol: str

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "InsiderTransactionsResponse":
        """Create InsiderTransactionsResponse from API response dictionary."""
        transactions = [
            InsiderTransaction(
                name=item["name"],
                share=item["share"],
                change=item["change"],
                filing_date=item["filingDate"],
                transaction_date=item["transactionDate"],
                transaction_code=item["transactionCode"],
                transaction_price=item["transactionPrice"],
            )
            for item in data.get("data", [])
        ]

        return cls(data=transactions, symbol=data["symbol"])


# Insider Sentiment Models
@dataclass
class InsiderSentimentData:
    """Individual insider sentiment data point."""

    symbol: str
    year: int
    month: int
    change: int
    mspr: float  # Monthly Share Purchase Ratio


@dataclass
class InsiderSentimentResponse:
    """Complete response from Finnhub insider sentiment API."""

    data: list[InsiderSentimentData]
    symbol: str

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "InsiderSentimentResponse":
        """Create InsiderSentimentResponse from API response dictionary."""
        sentiment_data = [
            InsiderSentimentData(
                symbol=item["symbol"],
                year=item["year"],
                month=item["month"],
                change=item["change"],
                mspr=item["mspr"],
            )
            for item in data.get("data", [])
        ]

        return cls(data=sentiment_data, symbol=data["symbol"])
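A small sketch of consuming the sentiment model. Treating positive MSPR (net insider buying) as bullish, and the ±10 thresholds, are assumptions made for the example, not anything defined in this module:

```python
# Sketch: collapse monthly MSPR points into one directional flag.
def insider_tilt(response: InsiderSentimentResponse) -> str:
    if not response.data:
        return "UNKNOWN"
    avg_mspr = sum(point.mspr for point in response.data) / len(response.data)
    if avg_mspr > 10:  # threshold is arbitrary for the sketch
        return "BULLISH"
    if avg_mspr < -10:
        return "BEARISH"
    return "NEUTRAL"
```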
# Company Profile Models
@dataclass
class CompanyProfile:
    """Company profile information from Finnhub."""

    country: str
    currency: str
    exchange: str
    ipo: str
    market_capitalization: float
    name: str
    phone: str
    share_outstanding: float
    ticker: str
    weburl: str
    logo: str
    finnhub_industry: str

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "CompanyProfile":
        """Create CompanyProfile from API response dictionary."""
        return cls(
            country=data.get("country", ""),
            currency=data.get("currency", ""),
            exchange=data.get("exchange", ""),
            ipo=data.get("ipo", ""),
            market_capitalization=data.get("marketCapitalization", 0.0),
            name=data.get("name", ""),
            phone=data.get("phone", ""),
            share_outstanding=data.get("shareOutstanding", 0.0),
            ticker=data.get("ticker", ""),
            weburl=data.get("weburl", ""),
            logo=data.get("logo", ""),
            finnhub_industry=data.get("finnhubIndustry", ""),
        )
# Complete indicator definitions for 20 professional indicators
INDICATOR_DEFINITIONS = {
    # Momentum Indicators (7)
    "RSI": {
        "talib_function": "talib.RSI",
        "input_types": ["close"],
        "default_params": {"timeperiod": 14},
        "param_ranges": {"timeperiod": (2, 100)},
        "output_format": "single",
        "description": "Relative Strength Index",
    },
    "MACD": {
        "talib_function": "talib.MACD",
        "input_types": ["close"],
        "default_params": {"fastperiod": 12, "slowperiod": 26, "signalperiod": 9},
        "param_ranges": {
            "fastperiod": (2, 50),
            "slowperiod": (10, 200),
            "signalperiod": (2, 50),
        },
        "output_format": "triple",
        "description": "Moving Average Convergence Divergence",
    },
    "STOCH": {
        "talib_function": "talib.STOCH",
        "input_types": ["ohlc"],
        "default_params": {"fastk_period": 14, "slowk_period": 3, "slowd_period": 3},
        "param_ranges": {
            "fastk_period": (1, 100),
            "slowk_period": (1, 50),
            "slowd_period": (1, 50),
        },
        "output_format": "double",
        "description": "Stochastic Oscillator",
    },
    "WILLR": {
        "talib_function": "talib.WILLR",
        "input_types": ["ohlc"],
        "default_params": {"timeperiod": 14},
        "param_ranges": {"timeperiod": (2, 100)},
        "output_format": "single",
        "description": "Williams %R",
    },
    "CCI": {
        "talib_function": "talib.CCI",
        "input_types": ["ohlc"],
        "default_params": {"timeperiod": 20},
        "param_ranges": {"timeperiod": (2, 100)},
        "output_format": "single",
        "description": "Commodity Channel Index",
    },
    "ROC": {
        "talib_function": "talib.ROC",
        "input_types": ["close"],
        "default_params": {"timeperiod": 12},
        "param_ranges": {"timeperiod": (1, 100)},
        "output_format": "single",
        "description": "Rate of Change",
    },
    "MFI": {
        "talib_function": "talib.MFI",
        "input_types": ["ohlcv"],
        "default_params": {"timeperiod": 14},
        "param_ranges": {"timeperiod": (2, 100)},
        "output_format": "single",
        "description": "Money Flow Index",
    },
    # Trend Indicators (7)
    "SMA": {
        "talib_function": "talib.SMA",
        "input_types": ["close"],
        "default_params": {"timeperiod": 20},
        "param_ranges": {"timeperiod": (2, 200)},
        "output_format": "single",
        "description": "Simple Moving Average",
    },
    "EMA": {
        "talib_function": "talib.EMA",
        "input_types": ["close"],
        "default_params": {"timeperiod": 20},
        "param_ranges": {"timeperiod": (2, 200)},
        "output_format": "single",
        "description": "Exponential Moving Average",
    },
    "BBANDS": {
        "talib_function": "talib.BBANDS",
        "input_types": ["close"],
        "default_params": {"timeperiod": 20, "nbdevup": 2.0, "nbdevdn": 2.0},
        "param_ranges": {
            "timeperiod": (2, 100),
            "nbdevup": (0.1, 5.0),
            "nbdevdn": (0.1, 5.0),
        },
        "output_format": "triple",
        "description": "Bollinger Bands",
    },
    "SAR": {
        "talib_function": "talib.SAR",
        "input_types": ["ohlc"],
        "default_params": {"acceleration": 0.02, "maximum": 0.2},
        "param_ranges": {"acceleration": (0.01, 0.1), "maximum": (0.1, 1.0)},
        "output_format": "single",
        "description": "Parabolic SAR",
    },
    "ADX": {
        "talib_function": "talib.ADX",
        "input_types": ["ohlc"],
        "default_params": {"timeperiod": 14},
        "param_ranges": {"timeperiod": (2, 100)},
        "output_format": "single",
        "description": "Average Directional Index",
    },
    "AROON": {
        "talib_function": "talib.AROON",
        "input_types": ["hl"],
        "default_params": {"timeperiod": 25},
        "param_ranges": {"timeperiod": (2, 100)},
        "output_format": "double",
        "description": "Aroon Oscillator",
    },
    "TEMA": {
        "talib_function": "talib.TEMA",
        "input_types": ["close"],
        "default_params": {"timeperiod": 20},
        "param_ranges": {"timeperiod": (2, 200)},
        "output_format": "single",
        "description": "Triple Exponential Moving Average",
    },
    # Volume Indicators (3)
    "OBV": {
        "talib_function": "talib.OBV",
        "input_types": ["ohlcv"],
        "default_params": {},
        "param_ranges": {},
        "output_format": "single",
        "description": "On Balance Volume",
    },
    "AD": {
        "talib_function": "talib.AD",
        "input_types": ["ohlcv"],
        "default_params": {},
        "param_ranges": {},
        "output_format": "single",
        "description": "Accumulation/Distribution",
    },
    "ADOSC": {
        "talib_function": "talib.ADOSC",
        "input_types": ["ohlcv"],
        "default_params": {"fastperiod": 3, "slowperiod": 10},
        "param_ranges": {"fastperiod": (2, 50), "slowperiod": (5, 100)},
        "output_format": "single",
        "description": "Accumulation/Distribution Oscillator",
    },
    # Volatility Indicators (3)
    "ATR": {
        "talib_function": "talib.ATR",
        "input_types": ["ohlc"],
        "default_params": {"timeperiod": 14},
        "param_ranges": {"timeperiod": (1, 100)},
        "output_format": "single",
        "description": "Average True Range",
    },
    "NATR": {
        "talib_function": "talib.NATR",
        "input_types": ["ohlc"],
        "default_params": {"timeperiod": 14},
        "param_ranges": {"timeperiod": (1, 100)},
        "output_format": "single",
        "description": "Normalized Average True Range",
    },
    "TRANGE": {
        "talib_function": "talib.TRANGE",
        "input_types": ["ohlc"],
        "default_params": {},
        "param_ranges": {},
        "output_format": "single",
        "description": "True Range",
    },
}
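The definitions table lines up field-for-field with `IndicatorConfig`, so a small builder can validate requested parameters against `param_ranges` before constructing one. A sketch — `build_indicator_config` is not a function in this module:

```python
# Sketch: turn an INDICATOR_DEFINITIONS entry into a validated IndicatorConfig.
def build_indicator_config(
    name: str, overrides: dict[str, IndicatorParamValue] | None = None
) -> IndicatorConfig:
    definition = INDICATOR_DEFINITIONS[name]
    params = {**definition["default_params"], **(overrides or {})}
    for key, value in params.items():
        bounds = definition["param_ranges"].get(key)
        if bounds is None:
            raise TechnicalAnalysisError(f"{name} has no parameter named {key!r}")
        low, high = bounds
        if not (low <= value <= high):
            raise TechnicalAnalysisError(
                f"{name}.{key}={value} outside allowed range ({low}, {high})"
            )
    return IndicatorConfig(
        name=name,
        parameters=params,
        input_types=definition["input_types"],
        output_format=definition["output_format"],
        param_ranges=definition["param_ranges"],
        default_params=definition["default_params"],
        talib_function=definition["talib_function"],
        description=definition["description"],
    )


rsi_config = build_indicator_config("RSI", {"timeperiod": 21})
```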
class IndicatorPresets:
    """Professional indicator presets for different trading styles."""

    @staticmethod
    def get_scalping_presets() -> dict[str, dict[str, IndicatorParamValue]]:
        """Fast scalping presets (1-5 minute timeframes)."""
        return {
            "RSI_SCALPING": {"timeperiod": 5},
            "MACD_SCALPING": {"fastperiod": 5, "slowperiod": 13, "signalperiod": 5},
            "STOCH_SCALPING": {"fastk_period": 5, "slowk_period": 3, "slowd_period": 3},
            "EMA_SCALPING": {"timeperiod": 9},
            "BBANDS_TIGHT": {"timeperiod": 10, "nbdevup": 1.5, "nbdevdn": 1.5},
            "ATR_SCALPING": {"timeperiod": 5},
        }

    @staticmethod
    def get_day_trading_presets() -> dict[str, dict[str, IndicatorParamValue]]:
        """Day trading presets (5-60 minute timeframes)."""
        return {
            "RSI_DAY_TRADING": {"timeperiod": 14},
            "MACD_DAY_TRADING": {"fastperiod": 12, "slowperiod": 26, "signalperiod": 9},
            "STOCH_DAY_TRADING": {
                "fastk_period": 14,
                "slowk_period": 3,
                "slowd_period": 3,
            },
            "EMA_DAY_TRADING": {"timeperiod": 20},
            "BBANDS_STANDARD": {"timeperiod": 20, "nbdevup": 2.0, "nbdevdn": 2.0},
            "ADX_DAY_TRADING": {"timeperiod": 14},
            "ATR_DAY_TRADING": {"timeperiod": 14},
        }

    @staticmethod
    def get_swing_trading_presets() -> dict[str, dict[str, IndicatorParamValue]]:
        """Swing trading presets (daily timeframes)."""
        return {
            "RSI_SWING": {"timeperiod": 21},
            "MACD_SWING": {"fastperiod": 12, "slowperiod": 26, "signalperiod": 9},
            "STOCH_SWING": {"fastk_period": 21, "slowk_period": 5, "slowd_period": 5},
            "SMA_SWING_SHORT": {"timeperiod": 50},
            "SMA_SWING_LONG": {"timeperiod": 200},
            "EMA_SWING": {"timeperiod": 50},
            "BBANDS_SWING": {"timeperiod": 20, "nbdevup": 2.5, "nbdevdn": 2.5},
            "ADX_SWING": {"timeperiod": 21},
            "AROON_SWING": {"timeperiod": 25},
        }

    @staticmethod
    def get_position_trading_presets() -> dict[str, dict[str, IndicatorParamValue]]:
        """Position trading presets (weekly/monthly timeframes)."""
        return {
            "RSI_POSITION": {"timeperiod": 30},
            "MACD_POSITION": {"fastperiod": 20, "slowperiod": 50, "signalperiod": 15},
            "SMA_POSITION_SHORT": {"timeperiod": 100},
            "SMA_POSITION_LONG": {"timeperiod": 300},
            "EMA_POSITION": {"timeperiod": 100},
            "BBANDS_POSITION": {"timeperiod": 50, "nbdevup": 3.0, "nbdevdn": 3.0},
            "ADX_POSITION": {"timeperiod": 30},
            "AROON_POSITION": {"timeperiod": 50},
        }

    @staticmethod
    def get_all_presets() -> dict[str, dict[str, IndicatorParamValue]]:
        """Get all available presets combined."""
        all_presets = {}
        all_presets.update(IndicatorPresets.get_scalping_presets())
        all_presets.update(IndicatorPresets.get_day_trading_presets())
        all_presets.update(IndicatorPresets.get_swing_trading_presets())
        all_presets.update(IndicatorPresets.get_position_trading_presets())
        return all_presets

    @staticmethod
    def get_preset_for_style(style: str) -> dict[str, dict[str, IndicatorParamValue]]:
        """Get presets for a specific trading style."""
        style_map = {
            "scalping": IndicatorPresets.get_scalping_presets(),
            "day_trading": IndicatorPresets.get_day_trading_presets(),
            "swing": IndicatorPresets.get_swing_trading_presets(),
            "position": IndicatorPresets.get_position_trading_presets(),
        }
        return style_map.get(style.lower(), {})
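Preset names encode indicator plus style, so pairing a preset back with its `INDICATOR_DEFINITIONS` entry takes a little string handling. The splitting convention below (first `_`-separated token is the indicator name) is an assumption inferred from the names above:

```python
# Sketch: resolve swing-trading presets back to their base indicator entries.
for preset_name, params in IndicatorPresets.get_preset_for_style("swing").items():
    indicator = preset_name.split("_")[0]
    definition = INDICATOR_DEFINITIONS.get(indicator)
    if definition is None:
        continue  # skip any name that does not map to a single indicator
    merged = {**definition["default_params"], **params}
    print(f"{preset_name}: talib={definition['talib_function']} params={merged}")
```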
@ -0,0 +1,183 @@
"""
Repository for fundamental financial data (balance sheets, income statements, cash flow).
"""

import json
import logging
from dataclasses import asdict, dataclass
from datetime import date
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from ..models import ReportedFinancialsResponse

# Base repository functionality inline

logger = logging.getLogger(__name__)


@dataclass
class FinancialStatement:
    """Represents a financial statement with standardized structure matching the service."""

    period: str
    report_date: str
    publish_date: str
    currency: str
    data: dict[str, float]


@dataclass
class FinancialData:
    """Container for all financial statements for a symbol and date."""

    symbol: str
    date: date
    financial_statements: dict[
        str, FinancialStatement
    ]  # "balance_sheet", "income_statement", "cash_flow"


class FundamentalDataRepository:
    """Repository for accessing cached fundamental financial data as a KV store."""

    def __init__(self, data_dir: str, **kwargs):
        """
        Initialize fundamental data repository.

        Args:
            data_dir: Base directory for fundamental data storage
            **kwargs: Additional configuration
        """
        self.fundamental_data_dir = Path(data_dir) / "fundamental_data"
        self.fundamental_data_dir.mkdir(parents=True, exist_ok=True)

    def store_financial_statements(
        self,
        symbol: str,
        date: date,
        statements: dict[str, FinancialStatement],
    ) -> tuple[date, dict[str, FinancialStatement]]:
        """
        Store financial statements for a symbol and date.

        Args:
            symbol: Stock ticker symbol
            date: Date of the financial statements
            statements: Dictionary of statements keyed by type ("balance_sheet", "income_statement", "cash_flow")

        Returns:
            Tuple[date, dict[str, FinancialStatement]]: The stored date and statements
        """
        # Create symbol directory
        symbol_dir = self.fundamental_data_dir / symbol
        self._ensure_path_exists(symbol_dir)

        # Create JSON file path
        file_path = symbol_dir / f"{date.isoformat()}.json"

        try:
            # Prepare data for JSON serialization
            statements_data = {}
            for statement_type, statement in statements.items():
                statements_data[statement_type] = asdict(statement)

            data = {
                "symbol": symbol,
                "date": date.isoformat(),
                "financial_statements": statements_data,
                "metadata": {
                    # Note: the `date` parameter shadows datetime.date here;
                    # calling .today() through the instance still reaches the
                    # classmethod, so this records today's date, not the
                    # statement date.
                    "stored_at": date.today().isoformat(),
                    "repository": "fundamental_data_repository",
                    "statement_count": len(statements),
                },
            }

            # Write to JSON file
            with open(file_path, "w") as f:
                json.dump(data, f, indent=2, default=str)

            logger.info(
                f"Stored {len(statements)} financial statements for {symbol} on {date}"
            )
            return (date, statements)

        except Exception as e:
            logger.error(
                f"Error storing financial statements for {symbol} on {date}: {e}"
            )
            raise

    def store_reported_financials(
        self,
        symbol: str,
        date: date,
        reported_financials: "ReportedFinancialsResponse",
    ) -> bool:
        """
        Store reported financials data from Finnhub API.

        Args:
            symbol: Stock ticker symbol
            date: Date for the financial data
            reported_financials: ReportedFinancialsResponse from Finnhub

        Returns:
            bool: True if storage was successful
        """
        # Create symbol directory
        symbol_dir = self.fundamental_data_dir / symbol
        self._ensure_path_exists(symbol_dir)

        # Create JSON file path
        file_path = symbol_dir / f"{date.isoformat()}_reported_financials.json"

        try:
            # Convert dataclass to dict for JSON serialization
            data = {
                "symbol": symbol,
                "date": date.isoformat(),
                "reported_financials": {
                    "start_date": reported_financials.start_date,
                    "end_date": reported_financials.end_date,
                    "year": reported_financials.year,
                    "quarter": reported_financials.quarter,
                    "access_number": reported_financials.access_number,
                    "data": {
                        "bs": [asdict(item) for item in reported_financials.data.bs],
                        "ic": [asdict(item) for item in reported_financials.data.ic],
                        "cf": [asdict(item) for item in reported_financials.data.cf],
                    },
                },
                "metadata": {
                    "stored_at": date.today().isoformat(),
                    "repository": "fundamental_data_repository",
                    "data_source": "finnhub_reported_financials",
                    "bs_items": len(reported_financials.data.bs),
                    "ic_items": len(reported_financials.data.ic),
                    "cf_items": len(reported_financials.data.cf),
                },
            }

            # Write to JSON file
            with open(file_path, "w") as f:
                json.dump(data, f, indent=2, default=str)

            logger.info(
                f"Stored reported financials for {symbol} on {date} "
                f"(BS: {len(reported_financials.data.bs)}, "
                f"IC: {len(reported_financials.data.ic)}, "
                f"CF: {len(reported_financials.data.cf)} items)"
            )
            return True

        except Exception as e:
            logger.error(
                f"Error storing reported financials for {symbol} on {date}: {e}"
            )
            return False

    def _ensure_path_exists(self, path: Path) -> None:
        """Ensure a directory path exists."""
        path.mkdir(parents=True, exist_ok=True)
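A round-trip check of the repository above, as a sketch; the tempfile scaffolding and the sample statement values are the only things not taken from the repository's own API:

```python
# Sketch: store one statement set and read the resulting JSON back.
import json
import tempfile
from datetime import date as date_cls
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    repo = FundamentalDataRepository(tmp)
    statements = {
        "income_statement": FinancialStatement(
            period="Q1-2024",
            report_date="2024-03-31",
            publish_date="2024-04-25",
            currency="USD",
            data={"revenue": 1_000_000.0, "net_income": 120_000.0},
        )
    }
    stored_date, _ = repo.store_financial_statements(
        "TEST", date_cls(2024, 3, 31), statements
    )
    payload = json.loads(
        (Path(tmp) / "fundamental_data" / "TEST" / "2024-03-31.json").read_text()
    )
    assert payload["financial_statements"]["income_statement"]["currency"] == "USD"
```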
@ -1,313 +0,0 @@
"""
Repository for fundamental financial data (balance sheets, income statements, cash flow).
"""

import json
import logging
from dataclasses import asdict, dataclass
from datetime import date
from pathlib import Path

from .base import BaseRepository

logger = logging.getLogger(__name__)


@dataclass
class FinancialStatement:
    """Represents a financial statement with standardized structure matching the service."""

    period: str
    report_date: str
    publish_date: str
    currency: str
    data: dict[str, float]


@dataclass
class FinancialData:
    """Container for all financial statements for a symbol and date."""

    symbol: str
    date: date
    financial_statements: dict[
        str, FinancialStatement
    ]  # "balance_sheet", "income_statement", "cash_flow"


class FundamentalDataRepository(BaseRepository):
    """Repository for accessing cached fundamental financial data as a KV store."""

    def __init__(self, data_dir: str, **kwargs):
        """
        Initialize fundamental data repository.

        Args:
            data_dir: Base directory for fundamental data storage
            **kwargs: Additional configuration
        """
        self.fundamental_data_dir = Path(data_dir) / "fundamental_data"
        self.fundamental_data_dir.mkdir(parents=True, exist_ok=True)

    def get_financial_data(
        self, symbol: str, start_date: date, end_date: date
    ) -> dict[date, FinancialData]:
        """
        Get cached fundamental data for a symbol and date range.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date
            end_date: End date

        Returns:
            Dict[date, FinancialData]: Financial data keyed by date
        """
        symbol_dir = self.fundamental_data_dir / symbol

        if not symbol_dir.exists():
            logger.warning(f"No data directory found for symbol {symbol}")
            return {}

        financial_data = {}

        # Scan for JSON files in the symbol directory
        for json_file in symbol_dir.glob("*.json"):
            try:
                # Parse date from filename (YYYY-MM-DD.json)
                date_str = json_file.stem
                file_date = date.fromisoformat(date_str)

                # Filter by date range
                if start_date <= file_date <= end_date:
                    with open(json_file) as f:
                        data = json.load(f)

                    # Create FinancialStatement objects from JSON data
                    financial_statements = {}
                    for statement_type, statement_data in data.get(
                        "financial_statements", {}
                    ).items():
                        financial_statements[statement_type] = FinancialStatement(
                            **statement_data
                        )

                    # Create FinancialData container
                    financial_data[file_date] = FinancialData(
                        symbol=symbol,
                        date=file_date,
                        financial_statements=financial_statements,
                    )

            except (ValueError, json.JSONDecodeError, KeyError) as e:
                logger.error(f"Error reading financial data from {json_file}: {e}")
                continue

        logger.info(f"Retrieved {len(financial_data)} financial records for {symbol}")
        return financial_data

    def has_data_for_period(
        self, symbol: str, start_date: str, end_date: str, frequency: str = "quarterly"
    ) -> bool:
        """
        Check if we have sufficient data for the given period.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            frequency: Reporting frequency (not used in this implementation)

        Returns:
            bool: True if we have any data in the period
        """
        start_dt = date.fromisoformat(start_date)
        end_dt = date.fromisoformat(end_date)
        financial_data = self.get_financial_data(symbol, start_dt, end_dt)
        return len(financial_data) > 0

    def get_data(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
        frequency: str = "quarterly",
        **kwargs,
    ) -> dict[str, any]:
        """
        Get data in the format expected by the service.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            frequency: Reporting frequency
            **kwargs: Additional parameters

        Returns:
            Dict with financial_statements structure expected by service
        """
        start_dt = date.fromisoformat(start_date)
        end_dt = date.fromisoformat(end_date)
        financial_data = self.get_financial_data(symbol, start_dt, end_dt)

        if not financial_data:
            return {}

        # Get the most recent data
        latest_date = max(financial_data.keys())
        latest_data = financial_data[latest_date]

        return {
            "financial_statements": {
                statement_type: asdict(statement)
                for statement_type, statement in latest_data.financial_statements.items()
            }
        }

    def store_data(
        self,
        symbol: str,
        cache_data: dict,
        frequency: str = "quarterly",
        overwrite: bool = False,
    ) -> bool:
        """
        Store data in the format expected by the service.

        Args:
            symbol: Stock ticker symbol
            cache_data: Data dictionary with financial_statements
            frequency: Reporting frequency (not used)
            overwrite: Whether to overwrite existing data

        Returns:
            bool: True if successful
        """
        try:
            # Extract financial statements from cache_data
            statements_data = cache_data.get("financial_statements", {})

            if not statements_data:
                logger.warning(
                    f"No financial statements found in cache_data for {symbol}"
                )
                return False

            # Use today's date as the storage date
            storage_date = date.today()

            # Convert to FinancialStatement objects
            financial_statements = {}
            for statement_type, statement_dict in statements_data.items():
                financial_statements[statement_type] = FinancialStatement(
                    **statement_dict
                )

            # Store the statements
            self.store_financial_statements(symbol, storage_date, financial_statements)
            return True

        except Exception as e:
            logger.error(f"Error storing data for {symbol}: {e}")
            return False

    def clear_data(
        self, symbol: str, start_date: str, end_date: str, frequency: str = "quarterly"
    ) -> bool:
        """
        Clear data for a symbol and date range.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            frequency: Reporting frequency (not used)

        Returns:
            bool: True if successful
        """
        try:
            symbol_dir = self.fundamental_data_dir / symbol

            if not symbol_dir.exists():
                return True  # Nothing to clear

            start_dt = date.fromisoformat(start_date)
            end_dt = date.fromisoformat(end_date)

            # Remove files in the date range
            for json_file in symbol_dir.glob("*.json"):
                try:
                    date_str = json_file.stem
                    file_date = date.fromisoformat(date_str)

                    if start_dt <= file_date <= end_dt:
                        json_file.unlink()
                        logger.debug(f"Removed {json_file}")

                except (ValueError, OSError) as e:
                    logger.warning(f"Error removing {json_file}: {e}")
                    continue

            return True

        except Exception as e:
            logger.error(f"Error clearing data for {symbol}: {e}")
            return False

    def store_financial_statements(
        self,
        symbol: str,
        date: date,
        statements: dict[str, FinancialStatement],
    ) -> tuple[date, dict[str, FinancialStatement]]:
        """
        Store financial statements for a symbol and date.

        Args:
            symbol: Stock ticker symbol
            date: Date of the financial statements
            statements: Dictionary of statements keyed by type ("balance_sheet", "income_statement", "cash_flow")

        Returns:
            Tuple[date, dict[str, FinancialStatement]]: The stored date and statements
        """
        # Create symbol directory
        symbol_dir = self.fundamental_data_dir / symbol
        self._ensure_path_exists(symbol_dir)

        # Create JSON file path
        file_path = symbol_dir / f"{date.isoformat()}.json"

        try:
            # Prepare data for JSON serialization
            statements_data = {}
            for statement_type, statement in statements.items():
                statements_data[statement_type] = asdict(statement)

            data = {
                "symbol": symbol,
                "date": date.isoformat(),
                "financial_statements": statements_data,
                "metadata": {
                    "stored_at": date.today().isoformat(),
                    "repository": "fundamental_data_repository",
                    "statement_count": len(statements),
                },
            }

            # Write to JSON file
            with open(file_path, "w") as f:
                json.dump(data, f, indent=2, default=str)

            logger.info(
                f"Stored {len(statements)} financial statements for {symbol} on {date}"
            )
            return (date, statements)

        except Exception as e:
            logger.error(
                f"Error storing financial statements for {symbol} on {date}: {e}"
            )
            raise
@ -8,12 +8,12 @@ from pathlib import Path

import pandas as pd

from .base import BaseRepository
# from .base import BaseRepository  # Not found, removing import

logger = logging.getLogger(__name__)


class MarketDataRepository(BaseRepository):
class MarketDataRepository:
    """Repository for accessing historical market data from CSV files."""

    def __init__(self, data_dir: str, **kwargs):

@ -60,9 +60,8 @@ class MarketDataRepository(BaseRepository):
        df["Date"] = pd.to_datetime(df["Date"]).dt.date

        # Filter by date range
        filtered_df = df[
            (df["Date"] >= start_date) & (df["Date"] <= end_date)
        ].copy()
        mask = (df["Date"] >= start_date) & (df["Date"] <= end_date)
        filtered_df: pd.DataFrame = df.loc[mask].copy()

        logger.info(
            f"Retrieved {len(filtered_df)} records for {symbol} from {start_date} to {end_date}"
@ -0,0 +1,184 @@
"""
Article scraper client for extracting full content from news URLs.
"""

import logging
import time
from dataclasses import dataclass
from datetime import datetime
from urllib.parse import urlparse

import newspaper

logger = logging.getLogger(__name__)


@dataclass
class ScrapeResult:
    """Result of article scraping operation."""

    status: str  # 'SUCCESS', 'SCRAPE_FAILED', 'ARCHIVE_SUCCESS', 'NOT_FOUND'
    content: str = ""
    author: str = ""
    final_url: str = ""
    title: str = ""
    publish_date: str = ""


class ArticleScraperClient:
    """Client for scraping article content with Internet Archive fallback."""

    def __init__(self, user_agent: str, delay: float = 1.0):
        """
        Initialize article scraper.

        Args:
            user_agent: User agent string for requests
            delay: Delay between requests in seconds
        """
        self.user_agent = user_agent or (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
            "(KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        )
        self.delay = delay

    def scrape_article(self, url: str) -> ScrapeResult:
        """
        Scrape article content from URL with fallback to Internet Archive.

        Args:
            url: Article URL to scrape

        Returns:
            ScrapeResult: Scraping result with content and metadata
        """
        if not url or not self._is_valid_url(url):
            return ScrapeResult(status="NOT_FOUND", final_url=url)

        # Try original source first
        result = self._scrape_from_source(url)
        if result.status == "SUCCESS":
            return result

        # Fallback to Internet Archive
        logger.info(f"Original scraping failed for {url}, trying Internet Archive")
        return self._scrape_from_wayback(url)

    def _scrape_from_source(self, url: str) -> ScrapeResult:
        """Scrape article from original source using newspaper3k."""
        try:
            # Add delay to be respectful
            time.sleep(self.delay)

            # Configure newspaper article
            article = newspaper.Article(url)
            article.config.browser_user_agent = self.user_agent
            article.config.request_timeout = 10

            # Download and parse
            article.download()
            article.parse()

            # Validate content
            if not article.text or len(article.text.strip()) < 100:
                logger.warning(f"Article content too short or empty for {url}")
                return ScrapeResult(status="SCRAPE_FAILED", final_url=url)

            # Handle publish_date which can be datetime or string
            publish_date_str = ""
            if article.publish_date:
                if isinstance(article.publish_date, datetime):
                    publish_date_str = article.publish_date.strftime("%Y-%m-%d")
                elif isinstance(article.publish_date, str):
                    publish_date_str = article.publish_date
                else:
                    # Try to convert to string
                    publish_date_str = str(article.publish_date)

            return ScrapeResult(
                status="SUCCESS",
                content=article.text.strip(),
                author=", ".join(article.authors) if article.authors else "",
                final_url=url,
                title=article.title or "",
                publish_date=publish_date_str,
            )

        except Exception as e:
            logger.warning(f"Error scraping article from {url}: {e}")
            return ScrapeResult(status="SCRAPE_FAILED", final_url=url)

    def _scrape_from_wayback(self, url: str) -> ScrapeResult:
        """Scrape article from Internet Archive Wayback Machine."""
        try:
            import requests
        except ImportError:
            logger.error("requests not installed. Install with: pip install requests")
            return ScrapeResult(status="NOT_FOUND", final_url=url)

        try:
            # Query Wayback Machine CDX API for snapshots
            cdx_url = "http://web.archive.org/cdx/search/cdx"
            params = {
                "url": url,
                "output": "json",
                "fl": "timestamp,original",
                "filter": "statuscode:200",
                "limit": "1",
            }

            response = requests.get(cdx_url, params=params, timeout=10)
            response.raise_for_status()

            data = response.json()
            if len(data) < 2:  # First row is headers
                logger.warning(f"No archived snapshots found for {url}")
                return ScrapeResult(status="NOT_FOUND", final_url=url)

            # Get the most recent snapshot
            timestamp, original_url = data[1]
            archive_url = f"https://web.archive.org/web/{timestamp}/{original_url}"

            logger.info(f"Found archived snapshot: {archive_url}")

            # Scrape from archive URL
            result = self._scrape_from_source(archive_url)
            if result.status == "SUCCESS":
                result.status = "ARCHIVE_SUCCESS"
                result.final_url = archive_url

            return result

        except Exception as e:
            logger.warning(f"Error accessing Internet Archive for {url}: {e}")
            return ScrapeResult(status="NOT_FOUND", final_url=url)

    def _is_valid_url(self, url: str) -> bool:
        """Check if URL is valid and accessible."""
        try:
            parsed = urlparse(url)
            return bool(parsed.netloc) and parsed.scheme in ("http", "https")
        except Exception:
            return False

    def scrape_multiple_articles(self, urls: list[str]) -> dict[str, ScrapeResult]:
        """
        Scrape multiple articles sequentially.

        Args:
            urls: List of article URLs to scrape

        Returns:
            Dict mapping URLs to ScrapeResults
        """
        results = {}

        for i, url in enumerate(urls):
            logger.info(f"Scraping article {i + 1}/{len(urls)}: {url}")
            results[url] = self.scrape_article(url)

            # Add delay between requests
            if i < len(urls) - 1:
                time.sleep(self.delay)

        return results
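Typical use of the scraper above, shown as a sketch; the URL is a placeholder, and passing an empty `user_agent` deliberately exercises the built-in Chrome fallback:

```python
# Sketch: scrape a small batch and split successes from fallbacks/failures.
scraper = ArticleScraperClient(user_agent="", delay=1.0)
urls = ["https://example.com/markets-story"]  # placeholder URL
results = scraper.scrape_multiple_articles(urls)
for url, result in results.items():
    if result.status in ("SUCCESS", "ARCHIVE_SUCCESS"):
        print(f"{result.title} ({result.status}): {len(result.content)} chars")
    else:
        print(f"failed: {url} -> {result.status}")
```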
@ -1,194 +1,188 @@
"""
Google News client for live news data via web scraping.
Google News client for live news data via RSS feeds.
"""

import logging
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from urllib.parse import quote

from .base import BaseClient
import feedparser
import requests
from dateutil import parser as date_parser

logger = logging.getLogger(__name__)


class GoogleNewsClient(BaseClient):
@dataclass
class FeedEntry:
    """Structured representation of a feedparser entry."""

    title: str
    link: str
    published: str
    published_parsed: time.struct_time | None
    summary: str
    guid: str

    @classmethod
    def from_feedparser_dict(cls, entry: feedparser.FeedParserDict) -> "FeedEntry":
        """Convert a FeedParserDict to a structured FeedEntry."""
        return cls(
            title=getattr(entry, "title", "Untitled"),
            link=getattr(entry, "link", ""),
            published=getattr(entry, "published", ""),
            published_parsed=getattr(entry, "published_parsed", None),
            summary=getattr(entry, "summary", ""),
            guid=getattr(entry, "id", getattr(entry, "link", "")),
        )


@dataclass
class GoogleNewsArticle:
    """Represents a news article from Google News RSS feed."""

    title: str
    link: str
    published: datetime
    summary: str
    source: str
    guid: str


class GoogleNewsClient:
    """Client for Google News data via web scraping."""

    def __init__(self, **kwargs):
        """
        Initialize Google News client.

        Args:
            **kwargs: Configuration options including rate limits
        """
        super().__init__(**kwargs)
        self.max_retries = kwargs.get("max_retries", 3)
        self.delay_between_requests = kwargs.get("delay_between_requests", 1.0)

    def get_data(
        self, query: str, start_date: str, end_date: str, **kwargs
    ) -> dict[str, Any]:
        """
        Get news data for a query and date range.

        Args:
            query: Search query
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            **kwargs: Additional parameters

        Returns:
            Dict[str, Any]: News data with metadata
        """
        if not self.validate_date_range(start_date, end_date):
            raise ValueError(f"Invalid date range: {start_date} to {end_date}")

        try:
            # Replace spaces with + for URL encoding
            formatted_query = query.replace(" ", "+")

            logger.info(
                f"Fetching Google News for query: {query} from {start_date} to {end_date}"
            )

            news_results = getNewsData(formatted_query, start_date, end_date)

            if not news_results:
                logger.warning(f"No news found for query: {query}")
                return {
                    "query": query,
                    "period": {"start": start_date, "end": end_date},
                    "articles": [],
                    "metadata": {
                        "source": "google_news",
                        "empty": True,
                        "reason": "no_articles_found",
                    },
                }

            # Process and standardize article data
            processed_articles = []
            for article in news_results:
                processed_article = {
                    "headline": article.get("title", ""),
                    "summary": article.get("snippet", ""),
                    "url": article.get("link", ""),
                    "source": article.get("source", "Unknown"),
                    "date": article.get(
                        "date", end_date
                    ),  # Fallback to end_date if no date
                    "entities": article.get("entities", []),
                }
                processed_articles.append(processed_article)

            return {
                "query": query,
                "period": {"start": start_date, "end": end_date},
                "articles": processed_articles,
                "metadata": {
                    "source": "google_news",
                    "article_count": len(processed_articles),
                    "retrieved_at": datetime.utcnow().isoformat(),
                    "search_query": formatted_query,
                },
            }

        except Exception as e:
            logger.error(f"Error fetching Google News for query '{query}': {e}")
            raise

    def get_company_news(
        self, symbol: str, start_date: str, end_date: str, **kwargs
    ) -> dict[str, Any]:
    def get_company_news(self, symbol: str) -> list[GoogleNewsArticle]:
        """
        Get news data specific to a company symbol.

        Args:
            symbol: Stock ticker symbol
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            **kwargs: Additional parameters

        Returns:
            Dict[str, Any]: Company-specific news data
            list[GoogleNewsArticle]: Company-specific news data
        """
        # Create company-focused search query
        company_query = f"{symbol} stock"

        result = self.get_data(company_query, start_date, end_date, **kwargs)
        result["symbol"] = symbol
        result["metadata"]["query_type"] = "company_specific"

        return result
        return self._get_rss_feed(symbol)

    def get_global_news(
        self,
        start_date: str,
        end_date: str,
        categories: list[str] | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        categories: list[str],
    ) -> list[GoogleNewsArticle]:
        """
        Get global/macro news that might affect markets.

        Args:
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format
            categories: List of news categories to search
            **kwargs: Additional parameters

        Returns:
            Dict[str, Any]: Global news data
            list[GoogleNewsArticle]: Global news data
        """
        if categories is None:
            categories = ["economy", "finance", "markets", "business"]

        # Get RSS queries for categories
        all_articles = []

        for category in categories:
            try:
                category_data = self.get_data(category, start_date, end_date, **kwargs)

                # Add category tag to each article
                for article in category_data.get("articles", []):
                    article["category"] = category

                all_articles.extend(category_data.get("articles", []))
                articles = self._get_rss_feed(category)
                all_articles.extend(articles)

            except Exception as e:
                logger.warning(f"Failed to fetch news for category '{category}': {e}")
                continue

        return {
            "query": "global_news",
            "categories": categories,
            "period": {"start": start_date, "end": end_date},
            "articles": all_articles,
            "metadata": {
                "source": "google_news",
                "article_count": len(all_articles),
                "categories_searched": categories,
                "retrieved_at": datetime.utcnow().isoformat(),
                "query_type": "global_news",
            },
        }
        return all_articles

    def get_available_categories(self) -> list[str]:
    def _get_rss_feed(self, query: str) -> list[GoogleNewsArticle]:
        """
        Get list of commonly used news categories.
        Fetch RSS feed from Google News for a given query.

        Args:
            query: Search query (company symbol or news category)

        Returns:
            List[str]: News categories
            list[GoogleNewsArticle]: Parsed articles from RSS feed
        """
        return [
            "business",
            "economy",
            "finance",
            "markets",
            "technology",
            "politics",
            "world",
            "healthcare",
            "energy",
            "crypto",
        ]
        try:
            # Construct Google News RSS URL
            encoded_query = quote(query)
            rss_url = f"https://news.google.com/rss/search?q={encoded_query}&hl=en-US&gl=US&ceid=US:en"

            logger.info(f"Fetching RSS feed for query: {query}")

            # Use requests with timeout and User-Agent header
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
            }

            response = requests.get(rss_url, timeout=10, headers=headers)
            response.raise_for_status()

            # Use feedparser to parse the fetched content
            feed = feedparser.parse(response.content)

            # Check if feed was parsed successfully
            if feed.bozo:
                logger.warning(
                    f"Feed parsing had issues for query '{query}': {feed.bozo_exception}"
                )

            articles = []
            for raw_entry in feed.entries:
                try:
                    # Convert FeedParserDict to structured dataclass
                    entry = FeedEntry.from_feedparser_dict(raw_entry)
                    article = self._convert_entry_to_article(entry)
                    articles.append(article)
                except Exception as e:
                    logger.warning(f"Failed to parse article entry: {e}")
                    continue

            logger.info(
                f"Successfully fetched {len(articles)} articles for query: {query}"
            )
            return articles

        except requests.exceptions.RequestException as e:
            logger.error(f"Network error fetching RSS feed for query '{query}': {e}")
            return []
        except Exception as e:
            logger.error(
                f"Unexpected error fetching RSS feed for query '{query}': {e}",
                exc_info=True,
            )
            return []

    def _convert_entry_to_article(self, entry: FeedEntry) -> GoogleNewsArticle:
        """
        Convert a structured FeedEntry to a GoogleNewsArticle.

        Args:
            entry: Structured FeedEntry dataclass

        Returns:
            GoogleNewsArticle: Converted article object
        """
        # Parse published date with fallback to current time
        try:
            published = (
                date_parser.parse(entry.published)
                if entry.published
                else datetime.utcnow()
            )
        except (ValueError, OverflowError, TypeError):
            published = datetime.utcnow()

        # Extract source from title (Google News format: "Title - Source")
        title_parts = entry.title.split(" - ")
        title = title_parts[0] if title_parts else entry.title
        source = title_parts[-1] if len(title_parts) > 1 else "Unknown"

        return GoogleNewsArticle(
            title=title,
            link=entry.link,
            published=published,
            summary=entry.summary,
            source=source,
            guid=entry.guid,
        )
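After this change the client returns typed `GoogleNewsArticle` objects rather than dicts. A usage sketch of the new surface — the ticker is a placeholder and the output formatting is illustrative:

```python
# Sketch: fetch company headlines via the RSS-backed client.
client = GoogleNewsClient()
articles = client.get_company_news("AAPL")  # placeholder ticker
for article in articles[:5]:
    print(f"{article.published:%Y-%m-%d} [{article.source}] {article.title}")
```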
@ -8,8 +8,6 @@ from dataclasses import asdict, dataclass, field
from datetime import date
from pathlib import Path

from .base import BaseRepository

logger = logging.getLogger(__name__)


@ -40,10 +38,10 @@ class NewsData:
    articles: list[NewsArticle]


class NewsRepository(BaseRepository):
class NewsRepository:
    """Repository for accessing cached news data with source separation."""

    def __init__(self, data_dir: str, **kwargs):
    def __init__(self, data_dir: str):
        """
        Initialize news repository.


@ -155,7 +153,6 @@ class NewsRepository(BaseRepository):
        """
        # Create source/query directory
        source_dir = self.news_data_dir / source / query
        self._ensure_path_exists(source_dir)

        # Create JSON file path
        file_path = source_dir / f"{date.isoformat()}.json"
@ -7,10 +7,11 @@ from dataclasses import dataclass
|
|||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from tradingagents.clients.base import BaseClient
|
||||
from tradingagents.repositories.base import BaseRepository
|
||||
from tradingagents.config import TradingAgentsConfig
|
||||
from tradingagents.domains.news.google_news_client import GoogleNewsClient
|
||||
from tradingagents.domains.news.news_repository import NewsRepository
|
||||
|
||||
from .base import BaseService
|
||||
from .article_scraper_client import ArticleScraperClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@@ -73,16 +74,27 @@ class GlobalNewsContext:
    metadata: dict[str, Any]


-class NewsService(BaseService):
+@dataclass
+class NewsUpdateResult:
+    """Result of news update operation."""
+
+    status: str
+    articles_found: int
+    articles_scraped: int
+    articles_failed: int
+    symbol: str | None = None
+    categories: list[str] | None = None
+    date_range: dict[str, str] | None = None
+
+
+class NewsService:
    """Service for news data and sentiment analysis."""

    def __init__(
        self,
-        finnhub_client: BaseClient | None = None,
-        google_client: BaseClient | None = None,
-        repository: BaseRepository | None = None,
-        online_mode: bool = True,
-        **kwargs,
+        google_client: GoogleNewsClient,
+        repository: NewsRepository,
+        article_scraper: ArticleScraperClient,
    ):
        """
        Initialize news service.
@@ -91,46 +103,24 @@ class NewsService(BaseService):
-            finnhub_client: Client for Finnhub news data
            google_client: Client for Google News data
            repository: Repository for cached news data
-            online_mode: Whether to use live data
-            **kwargs: Additional configuration
+            article_scraper: Client for scraping article content
        """
-        super().__init__(online_mode, **kwargs)
-        self.finnhub_client = finnhub_client
        self.google_client = google_client
        self.repository = repository
+        self.article_scraper = article_scraper

-    def get_context(
-        self,
-        query: str,
-        start_date: str,
-        end_date: str,
-        symbol: str | None = None,
-        sources: list[str] | None = None,
-        force_refresh: bool = False,
-        **kwargs,
-    ) -> NewsContext:
-        """
-        Get news context for a query and date range.
-
-        Args:
-            query: Search query
-            start_date: Start date in YYYY-MM-DD format
-            end_date: End date in YYYY-MM-DD format
-            symbol: Stock ticker symbol if company-specific
-            sources: List of sources to use ('finnhub', 'google', or both)
-            force_refresh: If True, skip local data and fetch fresh from APIs
-            **kwargs: Additional parameters
-
-        Returns:
-            NewsContext: Structured news context
-        """
-        pass
+    @staticmethod
+    def build(_config: TradingAgentsConfig):
+        google_client = GoogleNewsClient()
+        repository = NewsRepository("")
+        article_scraper = ArticleScraperClient("")
+        return NewsService(google_client, repository, article_scraper)

    def get_company_news_context(
        self, symbol: str, start_date: str, end_date: str, **kwargs
    ) -> NewsContext:
        """
-        Get news context specific to a company.
+        Get news context specific to a company from repository (no API calls).

        Args:
            symbol: Stock ticker symbol
@@ -141,7 +131,65 @@ class NewsService(BaseService):
        Returns:
            NewsContext: Company-specific news context
        """
-        pass
+        try:
+            logger.info(f"Getting company news context for {symbol} from repository")
+
+            # Get articles from repository
+            articles = []
+            if self.repository:
+                try:
+                    # This would depend on the actual repository interface
+                    # For now, return empty list - repository integration needs to be completed
+                    articles = []
+                    logger.debug(
+                        f"Retrieved {len(articles)} articles from repository for {symbol}"
+                    )
+                except Exception as e:
+                    logger.warning(f"Error retrieving articles from repository: {e}")
+                    articles = []
+
+            # Calculate sentiment summary from articles
+            sentiment_summary = self._calculate_sentiment_summary(articles)
+
+            # Extract unique sources
+            sources = list(
+                {article.source for article in articles if hasattr(article, "source")}
+            )
+
+            return NewsContext(
+                query=symbol,
+                symbol=symbol,
+                period={"start": start_date, "end": end_date},
+                articles=articles,
+                sentiment_summary=sentiment_summary,
+                article_count=len(articles),
+                sources=sources,
+                metadata={
+                    "service": "news",
+                    "data_source": "repository",
+                    "method": "get_company_news_context",
+                },
+            )
+
+        except Exception as e:
+            logger.error(f"Error getting company news context for {symbol}: {e}")
+            # Return empty context on error
+            return NewsContext(
+                query=symbol,
+                symbol=symbol,
+                period={"start": start_date, "end": end_date},
+                articles=[],
+                sentiment_summary=SentimentScore(
+                    score=0.0, confidence=0.0, label="neutral"
+                ),
+                article_count=0,
+                sources=[],
+                metadata={
+                    "service": "news",
+                    "data_source": "repository",
+                    "error": str(e),
+                },
+            )

    def get_global_news_context(
        self,
@@ -151,7 +199,7 @@ class NewsService(BaseService):
        **kwargs,
    ) -> GlobalNewsContext:
        """
-        Get global/macro news context.
+        Get global/macro news context from repository (no API calls).

        Args:
            start_date: Start date in YYYY-MM-DD format
@@ -162,16 +210,425 @@ class NewsService(BaseService):
        Returns:
            GlobalNewsContext: Global news context
        """
-        # TODO: Implement global news fetching
-        return GlobalNewsContext(
-            period={"start": start_date, "end": end_date},
-            categories=categories or [],
-            articles=[],
-            sentiment_summary=SentimentScore(
-                score=0.0, confidence=0.0, label="neutral"
-            ),
-            article_count=0,
-            sources=[],
-            trending_topics=[],
-            metadata={"service": "news", "analysis_method": "global_news"},
-        )
+        try:
+            if categories is None:
+                categories = ["general", "business", "politics"]
+
+            logger.info(
+                f"Getting global news context from repository for categories: {categories}"
+            )
+
+            # Get articles from repository
+            articles = []
+            if self.repository:
+                try:
+                    # This would depend on the actual repository interface
+                    # For now, return empty list - repository integration needs to be completed
+                    articles = []
+                    logger.debug(
+                        f"Retrieved {len(articles)} global articles from repository"
+                    )
+                except Exception as e:
+                    logger.warning(
+                        f"Error retrieving global articles from repository: {e}"
+                    )
+                    articles = []
+
+            # Calculate sentiment summary from articles
+            sentiment_summary = self._calculate_sentiment_summary(articles)
+
+            # Extract unique sources
+            sources = list(
+                {article.source for article in articles if hasattr(article, "source")}
+            )
+
+            # Extract trending topics (simplified implementation)
+            trending_topics = self._extract_trending_topics(articles)
+
+            return GlobalNewsContext(
+                period={"start": start_date, "end": end_date},
+                categories=categories,
+                articles=articles,
+                sentiment_summary=sentiment_summary,
+                article_count=len(articles),
+                sources=sources,
+                trending_topics=trending_topics,
+                metadata={
+                    "service": "news",
+                    "data_source": "repository",
+                    "method": "get_global_news_context",
+                },
+            )
+
+        except Exception as e:
+            logger.error(f"Error getting global news context: {e}")
+            # Return empty context on error
+            return GlobalNewsContext(
+                period={"start": start_date, "end": end_date},
+                categories=categories or [],
+                articles=[],
+                sentiment_summary=SentimentScore(
+                    score=0.0, confidence=0.0, label="neutral"
+                ),
+                article_count=0,
+                sources=[],
+                trending_topics=[],
+                metadata={
+                    "service": "news",
+                    "data_source": "repository",
+                    "error": str(e),
+                },
+            )
+
+    def update_company_news(self, symbol: str) -> NewsUpdateResult:
+        """
+        Update company news by fetching RSS feeds and scraping article content.
+
+        Args:
+            symbol: Stock ticker symbol
+
+        Returns:
+            NewsUpdateResult with update status and statistics
+        """
+        try:
+            logger.info(f"Updating company news for {symbol}")
+
+            if not self.google_client:
+                raise ValueError("Google client not configured")
+
+            # 1. Get RSS feed data
+            google_articles = self.google_client.get_company_news(symbol)
+
+            if not google_articles:
+                logger.warning(f"No articles found in RSS feed for {symbol}")
+                return NewsUpdateResult(
+                    status="completed",
+                    articles_found=0,
+                    articles_scraped=0,
+                    articles_failed=0,
+                    symbol=symbol,
+                )
+
+            # 2. Scrape each article content and convert to ArticleData
+            article_data_list = []
+            scraped_count = 0
+            failed_count = 0
+
+            for i, google_article in enumerate(google_articles):
+                if not google_article.link:
+                    failed_count += 1
+                    continue
+
+                logger.info(
+                    f"Scraping article {i + 1}/{len(google_articles)}: {google_article.link}"
+                )
+                scrape_result = self.article_scraper.scrape_article(google_article.link)
+
+                # Create ArticleData with scraped content
+                if scrape_result.status in ["SUCCESS", "ARCHIVE_SUCCESS"]:
+                    article_data = ArticleData(
+                        title=scrape_result.title or google_article.title,
+                        content=scrape_result.content,
+                        author=scrape_result.author,
+                        source=google_article.source,
+                        date=scrape_result.publish_date
+                        or google_article.published.strftime("%Y-%m-%d"),
+                        url=google_article.link,
+                        sentiment=None,  # Will be calculated later
+                    )
+                    scraped_count += 1
+                else:
+                    # Create ArticleData with just RSS data if scraping failed
+                    article_data = ArticleData(
+                        title=google_article.title,
+                        content=google_article.summary,  # Use summary as fallback content
+                        author="",
+                        source=google_article.source,
+                        date=google_article.published.strftime("%Y-%m-%d"),
+                        url=google_article.link,
+                        sentiment=None,
+                    )
+                    failed_count += 1
+
+                article_data_list.append(article_data)
+
+            # 3. Store in repository
+            try:
+                logger.info(f"Storing {len(article_data_list)} articles for {symbol}")
+                # Store articles (implementation depends on repository interface)
+
+            except Exception as e:
+                logger.error(f"Error storing articles in repository: {e}")
+
+            logger.info(
+                f"Company news update completed for {symbol}: {scraped_count} scraped, {failed_count} failed"
+            )
+
+            return NewsUpdateResult(
+                status="completed",
+                articles_found=len(google_articles),
+                articles_scraped=scraped_count,
+                articles_failed=failed_count,
+                symbol=symbol,
+            )
+
+        except Exception as e:
+            logger.error(f"Error updating company news for {symbol}: {e}")
+            raise
+
+    def update_global_news(
+        self, start_date: str, end_date: str, categories: list[str] | None = None
+    ) -> NewsUpdateResult:
+        """
+        Update global/macro news by fetching RSS feeds and scraping article content.
+
+        Args:
+            start_date: Start date in YYYY-MM-DD format
+            end_date: End date in YYYY-MM-DD format
+            categories: List of news categories to search
+
+        Returns:
+            NewsUpdateResult with update status and statistics
+        """
+        try:
+            if categories is None:
+                categories = ["general", "business", "politics"]
+
+            logger.info(
+                f"Updating global news from {start_date} to {end_date} for categories: {categories}"
+            )
+
+            if not self.google_client:
+                raise ValueError("Google client not configured")
+
+            # 1. Get RSS feed data for all categories
+            google_articles = self.google_client.get_global_news(categories)
+
+            if not google_articles:
+                logger.warning("No articles found in RSS feeds for global news")
+                return NewsUpdateResult(
+                    status="completed",
+                    articles_found=0,
+                    articles_scraped=0,
+                    articles_failed=0,
+                    categories=categories,
+                    date_range={"start": start_date, "end": end_date},
+                )
+
+            # 2. Scrape each article content and convert to ArticleData
+            article_data_list = []
+            scraped_count = 0
+            failed_count = 0
+
+            for i, google_article in enumerate(google_articles):
+                if not google_article.link:
+                    failed_count += 1
+                    continue
+
+                logger.info(
+                    f"Scraping global article {i + 1}/{len(google_articles)}: {google_article.link}"
+                )
+                scrape_result = self.article_scraper.scrape_article(google_article.link)
+
+                # Create ArticleData with scraped content
+                if scrape_result.status in ["SUCCESS", "ARCHIVE_SUCCESS"]:
+                    article_data = ArticleData(
+                        title=scrape_result.title or google_article.title,
+                        content=scrape_result.content,
+                        author=scrape_result.author,
+                        source=google_article.source,
+                        date=scrape_result.publish_date
+                        or google_article.published.strftime("%Y-%m-%d"),
+                        url=google_article.link,
+                        sentiment=None,  # Will be calculated later
+                    )
+                    scraped_count += 1
+                else:
+                    # Create ArticleData with just RSS data if scraping failed
+                    article_data = ArticleData(
+                        title=google_article.title,
+                        content=google_article.summary,  # Use summary as fallback content
+                        author="",
+                        source=google_article.source,
+                        date=google_article.published.strftime("%Y-%m-%d"),
+                        url=google_article.link,
+                        sentiment=None,
+                    )
+                    failed_count += 1
+
+                article_data_list.append(article_data)
+
+            # 3. Store in repository
+            try:
+                logger.info(f"Storing {len(article_data_list)} global articles")
+                # Store articles (implementation depends on repository interface)
+
+            except Exception as e:
+                logger.error(f"Error storing global articles in repository: {e}")
+
+            logger.info(
+                f"Global news update completed: {scraped_count} scraped, {failed_count} failed"
+            )
+
+            return NewsUpdateResult(
+                status="completed",
+                articles_found=len(google_articles),
+                articles_scraped=scraped_count,
+                articles_failed=failed_count,
+                categories=categories,
+                date_range={"start": start_date, "end": end_date},
+            )
+
+        except Exception as e:
+            logger.error(f"Error updating global news: {e}")
+            raise
+
+    def _calculate_sentiment_summary(
+        self, articles: list[ArticleData]
+    ) -> SentimentScore:
+        """
+        Calculate aggregate sentiment from article list.
+
+        Args:
+            articles: List of ArticleData objects
+
+        Returns:
+            SentimentScore: Aggregate sentiment score
+        """
+        if not articles:
+            return SentimentScore(score=0.0, confidence=0.0, label="neutral")
+
+        # Simple keyword-based sentiment analysis
+        positive_words = {
+            "good",
+            "great",
+            "excellent",
+            "positive",
+            "up",
+            "rise",
+            "gain",
+            "profit",
+            "growth",
+            "success",
+            "strong",
+            "bullish",
+            "optimistic",
+            "boost",
+            "surge",
+        }
+        negative_words = {
+            "bad",
+            "terrible",
+            "negative",
+            "down",
+            "fall",
+            "loss",
+            "decline",
+            "weak",
+            "bearish",
+            "pessimistic",
+            "crash",
+            "drop",
+            "plunge",
+            "concern",
+        }
+
+        total_score = 0.0
+        scored_articles = 0
+
+        for article in articles:
+            if not hasattr(article, "content") or not article.content:
+                continue
+
+            content_lower = article.content.lower()
+            words = content_lower.split()
+
+            positive_count = sum(1 for word in words if word in positive_words)
+            negative_count = sum(1 for word in words if word in negative_words)
+
+            if positive_count + negative_count > 0:
+                article_score = (positive_count - negative_count) / len(words)
+                total_score += article_score
+                scored_articles += 1
+
+        if scored_articles == 0:
+            return SentimentScore(score=0.0, confidence=0.0, label="neutral")
+
+        avg_score = total_score / scored_articles
+        confidence = min(scored_articles / len(articles), 1.0)
+
+        # Normalize score to -1.0 to 1.0 range
+        normalized_score = max(-1.0, min(1.0, avg_score * 10))
+
+        # Determine label
+        if normalized_score > 0.1:
+            label = "positive"
+        elif normalized_score < -0.1:
+            label = "negative"
+        else:
+            label = "neutral"
+
+        return SentimentScore(
+            score=normalized_score, confidence=confidence, label=label
+        )
+
+    def _extract_trending_topics(self, articles: list[ArticleData]) -> list[str]:
+        """
+        Extract trending topics from article titles and content.
+
+        Args:
+            articles: List of ArticleData objects
+
+        Returns:
+            List of trending topic strings
+        """
+        if not articles:
+            return []
+
+        # Simple keyword extraction from titles
+        word_counts = {}
+        stop_words = {
+            "the",
+            "a",
+            "an",
+            "and",
+            "or",
+            "but",
+            "in",
+            "on",
+            "at",
+            "to",
+            "for",
+            "of",
+            "with",
+            "by",
+            "is",
+            "are",
+            "was",
+            "were",
+            "be",
+            "been",
+            "have",
+            "has",
+            "had",
+            "do",
+            "does",
+            "did",
+            "will",
+            "would",
+            "could",
+            "should",
+        }
+
+        for article in articles:
+            if hasattr(article, "title") and article.title:
+                words = article.title.lower().split()
+                for word in words:
+                    # Clean word
+                    word = "".join(c for c in word if c.isalnum())
+                    if len(word) > 3 and word not in stop_words:
+                        word_counts[word] = word_counts.get(word, 0) + 1
+
+        # Get top trending words
+        trending = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)[:5]
+        return [word for word, count in trending if count > 1]
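Taken together, the new update flow can be exercised end to end. A minimal sketch, assuming `TradingAgentsConfig` is default-constructible and that the empty-string paths wired by the stub `build()` above are acceptable placeholders at this stage:

```python
# Sketch only: build() currently wires placeholder paths ("") for the
# repository and scraper, so this reflects the stubbed scaffolding.
from tradingagents.config import TradingAgentsConfig
from tradingagents.domains.news.news_service import NewsService

config = TradingAgentsConfig()  # assumed default-constructible
service = NewsService.build(config)
result = service.update_company_news("AAPL")
print(result.status, result.articles_found, result.articles_scraped)
```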
@@ -1,740 +0,0 @@
#!/usr/bin/env python3
"""
Test NewsService with mock clients and real NewsRepository.
"""

import json
import os
import sys
from datetime import datetime, timedelta
from typing import Any

# Add the project root to the path
sys.path.insert(0, os.path.abspath("."))

from tradingagents.clients.base import BaseClient
from tradingagents.domains.news.news_service import (
    NewsContext,
    NewsService,
    SentimentScore,
)
from tradingagents.repositories.news_repository import NewsRepository


class MockFinnhubClient(BaseClient):
    """Mock Finnhub client that returns sample news data."""

    def test_connection(self) -> bool:
        return True

    def get_data(self, *args, **kwargs) -> dict[str, Any]:
        """Not used directly by NewsService."""
        return {}

    def get_company_news(
        self, symbol: str, start_date: str, end_date: str, **kwargs
    ) -> dict[str, Any]:
        """Return mock Finnhub company news."""
        return {
            "symbol": symbol,
            "period": {"start": start_date, "end": end_date},
            "articles": [
                {
                    "headline": f"{symbol} Beats Q4 Earnings Expectations",
                    "summary": f"{symbol} reported earnings of $2.50 per share, beating analyst estimates of $2.25.",
                    "url": f"https://example.com/finnhub/{symbol.lower()}-earnings",
                    "source": "Finnhub Financial",
                    "date": start_date,
                    "entities": [symbol],
                },
                {
                    "headline": f"Insider Trading Activity at {symbol}",
                    "summary": f"Company executives at {symbol} have increased their holdings by 15% this quarter.",
                    "url": f"https://example.com/finnhub/{symbol.lower()}-insider",
                    "source": "Finnhub SEC Filings",
                    "date": end_date,
                    "entities": [symbol, "insider trading"],
                },
            ],
            "metadata": {
                "source": "mock_finnhub",
                "article_count": 2,
                "retrieved_at": datetime.utcnow().isoformat(),
            },
        }


class MockGoogleNewsClient(BaseClient):
    """Mock Google News client that returns sample articles."""

    def test_connection(self) -> bool:
        return True

    def get_data(
        self, query: str, start_date: str, end_date: str, **kwargs
    ) -> dict[str, Any]:
        """Return mock Google News data."""
        article_templates = [
            {
                "template": "{query} Stock Surges on Positive Outlook",
                "summary": "Shares of {query} rose 5% in after-hours trading following strong guidance for next quarter.",
                "source": "Mock Market News",
            },
            {
                "template": "Analysts Recommend Buy Rating for {query}",
                "summary": "Three major investment firms upgraded {query} to 'Buy' with improved price targets.",
                "source": "Mock Investment Daily",
            },
            {
                "template": "{query} Announces Strategic Partnership",
                "summary": "The company revealed a new collaboration that could expand market reach significantly.",
                "source": "Mock Business Wire",
            },
        ]

        articles = []
        for i, template in enumerate(article_templates):
            current_date = datetime.strptime(start_date, "%Y-%m-%d") + timedelta(days=i)

            articles.append(
                {
                    "headline": template["template"].format(query=query),
                    "summary": template["summary"].format(query=query),
                    "url": f"https://example.com/google/{query.lower()}-{i}",
                    "source": template["source"],
                    "date": current_date.strftime("%Y-%m-%d"),
                    "entities": [query],
                }
            )

        return {
            "query": query,
            "period": {"start": start_date, "end": end_date},
            "articles": articles,
            "metadata": {
                "source": "mock_google_news",
                "article_count": len(articles),
                "retrieved_at": datetime.utcnow().isoformat(),
            },
        }


def test_online_mode_with_mock_clients():
    """Test NewsService in online mode with mock clients."""
    print("📰 Testing NewsService - Online Mode")

    # Create mock clients and real repository
    mock_finnhub = MockFinnhubClient()
    mock_google = MockGoogleNewsClient()
    real_repo = NewsRepository("test_data")

    # Create service in online mode
    service = NewsService(
        finnhub_client=mock_finnhub,
        google_client=mock_google,
        repository=real_repo,
        online_mode=True,
        data_dir="test_data",
    )

    try:
        # Test company news context
        context = service.get_company_news_context(
            symbol="AAPL", start_date="2024-01-01", end_date="2024-01-05"
        )

        print(f"✅ Company news context created: {context.__class__.__name__}")
        print(f"   Symbol: {context.symbol}")
        print(f"   Period: {context.period}")
        print(f"   Articles: {len(context.articles)}")
        print(f"   Sentiment score: {context.sentiment_summary.score:.3f}")
        print(f"   Sentiment confidence: {context.sentiment_summary.confidence:.3f}")
        print(f"   Sources: {context.sources}")

        # Validate required fields
        assert context.symbol == "AAPL"
        assert context.period["start"] == "2024-01-01"
        assert context.period["end"] == "2024-01-05"
        assert len(context.articles) > 0
        assert (
            context.sentiment_summary.score >= -1.0
            and context.sentiment_summary.score <= 1.0
        )
        assert "data_quality" in context.metadata

        print("✅ Basic validation passed")

        # Test JSON serialization
        json_output = context.model_dump_json(indent=2)
        parsed = json.loads(json_output)

        print(f"✅ JSON serialization: {len(json_output)} characters")
        print(f"   Top-level keys: {list(parsed.keys())}")

        return True

    except Exception as e:
        print(f"❌ Online mode test failed: {e}")
        return False


def test_global_news_context():
    """Test global news functionality."""
    print("\n🌍 Testing Global News Context")

    mock_google = MockGoogleNewsClient()
    real_repo = NewsRepository("test_data")

    service = NewsService(
        finnhub_client=None,
        google_client=mock_google,
        repository=real_repo,
        online_mode=True,
        data_dir="test_data",
    )

    try:
        # Test global news with categories
        context = service.get_global_news_context(
            start_date="2024-01-01",
            end_date="2024-01-03",
            categories=["economy", "markets"],
        )

        print("✅ Global news context created")
        print(f"   Symbol: {context.symbol}")  # Should be None for global news
        print(f"   Articles: {len(context.articles)}")
        print(f"   Categories searched: {context.metadata.get('categories', [])}")
        print(f"   Sentiment score: {context.sentiment_summary.score:.3f}")

        # Validate global news structure
        assert context.symbol is None  # Global news shouldn't have a symbol
        assert len(context.articles) > 0
        assert "categories" in context.metadata

        print("✅ Global news validation passed")

        return True

    except Exception as e:
        print(f"❌ Global news test failed: {e}")
        return False


def test_offline_mode_with_real_repository():
    """Test NewsService in offline mode with real repository."""
    print("\n💾 Testing NewsService - Offline Mode")

    # Create service in offline mode (no clients)
    real_repo = NewsRepository("test_data")
    service = NewsService(
        finnhub_client=None,
        google_client=None,
        repository=real_repo,
        online_mode=False,
        data_dir="test_data",
    )

    try:
        # Test offline context (will likely return empty data)
        context = service.get_company_news_context(
            symbol="AAPL", start_date="2024-01-01", end_date="2024-01-05"
        )

        print(f"✅ Offline context created: {context.__class__.__name__}")
        print(f"   Symbol: {context.symbol}")
        print(f"   Articles: {len(context.articles)}")
        print(f"   Data quality: {context.metadata.get('data_quality')}")
        print(f"   Service mode: online={service.is_online()}")

        # Should handle empty data gracefully
        assert context.symbol == "AAPL"
        assert isinstance(context.articles, list)
        assert isinstance(context.sentiment_summary, SentimentScore)
        assert "data_quality" in context.metadata

        print("✅ Offline mode graceful handling verified")

        return True

    except Exception as e:
        print(f"❌ Offline mode test failed: {e}")
        return False


def test_sentiment_analysis():
    """Test sentiment analysis functionality."""
    print("\n😊 Testing Sentiment Analysis")

    # Create service with custom articles for sentiment testing
    class SentimentTestClient(BaseClient):
        def test_connection(self):
            return True

        def get_data(self, query, start_date, end_date, **kwargs):
            return {
                "query": query,
                "articles": [
                    {
                        "headline": f"{query} Soars on Excellent Earnings Report",
                        "summary": "Great performance with strong growth and positive outlook for investors.",
                        "source": "Positive News",
                        "date": start_date,
                        "entities": [query],
                    },
                    {
                        "headline": f"{query} Faces Challenges in Market Downturn",
                        "summary": "Concerns about declining revenue and poor market conditions affecting performance.",
                        "source": "Negative News",
                        "date": end_date,
                        "entities": [query],
                    },
                ],
            }

    sentiment_client = SentimentTestClient()
    service = NewsService(
        finnhub_client=None,
        google_client=sentiment_client,
        repository=None,
        online_mode=True,
    )

    try:
        context = service.get_context(
            "TEST", "2024-01-01", "2024-01-02", sources=["google"]
        )

        print("✅ Sentiment analysis completed")
        print(f"   Articles analyzed: {len(context.articles)}")
        print(f"   Overall sentiment: {context.sentiment_summary.score:.3f}")
        print(f"   Confidence: {context.sentiment_summary.confidence:.3f}")
        print(f"   Label: {context.sentiment_summary.label}")

        # Validate sentiment processing
        assert len(context.articles) == 2
        assert (
            context.sentiment_summary.score >= -1.0
            and context.sentiment_summary.score <= 1.0
        )
        assert (
            context.sentiment_summary.confidence >= 0.0
            and context.sentiment_summary.confidence <= 1.0
        )
        assert context.sentiment_summary.label in ["positive", "negative", "neutral"]

        # Check individual article sentiments
        for article in context.articles:
            if article.sentiment:
                assert (
                    article.sentiment.score >= -1.0 and article.sentiment.score <= 1.0
                )

        print("✅ Sentiment analysis validation passed")

        return True

    except Exception as e:
        print(f"❌ Sentiment analysis test failed: {e}")
        return False


def test_multiple_source_aggregation():
    """Test aggregation from multiple news sources."""
    print("\n🔄 Testing Multiple Source Aggregation")

    mock_finnhub = MockFinnhubClient()
    mock_google = MockGoogleNewsClient()
    real_repo = NewsRepository("test_data")

    service = NewsService(
        finnhub_client=mock_finnhub,
        google_client=mock_google,
        repository=real_repo,
        online_mode=True,
    )

    try:
        # Test with both sources
        context = service.get_context(
            query="MSFT",
            start_date="2024-01-01",
            end_date="2024-01-03",
            symbol="MSFT",
            sources=["finnhub", "google"],
        )

        print("✅ Multi-source aggregation completed")
        print(f"   Total articles: {len(context.articles)}")
        print(f"   Unique sources: {context.sources}")
        print(f"   Sources used: {context.metadata.get('sources_used', [])}")

        # Should have articles from both sources
        assert len(context.articles) > 0
        assert len(context.sources) > 0

        # Check that articles from different sources are present
        source_counts = {}
        for article in context.articles:
            source = article.source
            source_counts[source] = source_counts.get(source, 0) + 1

        print(f"   Source distribution: {source_counts}")

        print("✅ Multi-source aggregation validated")

        return True

    except Exception as e:
        print(f"❌ Multi-source test failed: {e}")
        return False


def test_json_structure_validation():
    """Test detailed JSON structure validation."""
    print("\n📄 Testing JSON Structure")

    mock_google = MockGoogleNewsClient()
    service = NewsService(
        finnhub_client=None,
        google_client=mock_google,
        repository=None,
        online_mode=True,
    )

    try:
        context = service.get_context(
            "TSLA", "2024-01-01", "2024-01-03", sources=["google"]
        )
        json_str = context.model_dump_json(indent=2)
        data = json.loads(json_str)

        # Validate required structure
        required_fields = [
            "symbol",
            "period",
            "articles",
            "sentiment_summary",
            "article_count",
            "sources",
            "metadata",
        ]
        for field in required_fields:
            assert field in data, f"Missing field: {field}"

        # Validate period structure
        period = data["period"]
        assert "start" in period and "end" in period

        # Validate articles structure
        assert isinstance(data["articles"], list)
        if data["articles"]:
            first_article = data["articles"][0]
            required_article_fields = ["headline", "source", "date"]
            for field in required_article_fields:
                assert field in first_article, f"Missing article field: {field}"

        # Validate sentiment structure
        sentiment = data["sentiment_summary"]
        assert (
            "score" in sentiment and "confidence" in sentiment and "label" in sentiment
        )
        assert -1.0 <= sentiment["score"] <= 1.0
        assert 0.0 <= sentiment["confidence"] <= 1.0

        # Validate metadata
        metadata = data["metadata"]
        assert "data_quality" in metadata
        assert "service" in metadata

        print("✅ JSON structure validation passed")
        print(f"   Fields: {list(data.keys())}")
        print(f"   Articles: {len(data['articles'])}")
        print(f"   Sentiment score: {sentiment['score']:.3f}")

        return True

    except Exception as e:
        print(f"❌ JSON structure test failed: {e}")
        return False


def test_force_refresh_parameter():
    """Test the force_refresh parameter functionality."""
    print("\n🔄 Testing Force Refresh Parameter")

    try:
        mock_google = MockGoogleNewsClient()
        real_repo = NewsRepository("test_data")

        service = NewsService(
            finnhub_client=None,
            google_client=mock_google,
            repository=real_repo,
            online_mode=True,
        )

        # Test normal flow (should use repository if available)
        normal_context = service.get_context(
            "AAPL", "2024-01-01", "2024-01-31", sources=["google"], force_refresh=False
        )

        # Test force refresh (should bypass repository and use client)
        refresh_context = service.get_context(
            "AAPL", "2024-01-01", "2024-01-31", sources=["google"], force_refresh=True
        )

        # Both should return valid contexts
        assert isinstance(normal_context, NewsContext)
        assert isinstance(refresh_context, NewsContext)
        assert normal_context.symbol == "AAPL"
        assert refresh_context.symbol == "AAPL"

        # Check metadata indicates source
        refresh_metadata = refresh_context.metadata
        assert "force_refresh" in refresh_metadata
        assert refresh_metadata["force_refresh"]

        print("✅ Force refresh parameter test passed")
        return True

    except Exception as e:
        print(f"❌ Force refresh test failed: {e}")
        return False


def test_local_first_strategy():
    """Test that the service checks local data first when available."""
    print("\n🏠 Testing Local-First Strategy")

    try:

        class MockRepositoryWithData(NewsRepository):
            def has_data_for_period(
                self, identifier: str, start_date: str, end_date: str, **kwargs
            ) -> bool:
                return True  # Pretend we have the data

            def get_data(
                self, query: str, start_date: str, end_date: str, **kwargs
            ) -> dict[str, Any]:
                return {
                    "query": kwargs.get("query", "TEST"),
                    "symbol": kwargs.get("symbol"),
                    "articles": [
                        {
                            "headline": "Test Article from Local Cache",
                            "summary": "This article came from local repository",
                            "source": "Local Cache",
                            "date": "2024-01-01",
                            "url": "https://local.cache/test",
                            "entities": ["TEST"],
                        }
                    ],
                    "metadata": {"source": "test_repository"},
                }

        mock_client = MockGoogleNewsClient()
        mock_repo = MockRepositoryWithData("test_data")

        service = NewsService(
            finnhub_client=None,
            google_client=mock_client,
            repository=mock_repo,
            online_mode=True,
        )

        # Should use local data since repository has_data_for_period returns True
        context = service.get_context(
            "TEST", "2024-01-01", "2024-01-31", sources=["google"]
        )

        # Verify we used local data
        assert context.metadata.get("data_source") == "local_cache"
        assert len(context.articles) == 1  # From mock repository
        assert context.articles[0].headline == "Test Article from Local Cache"

        print("✅ Local-first strategy test passed")
        return True

    except Exception as e:
        print(f"❌ Local-first strategy test failed: {e}")
        return False


def test_local_first_fallback_to_api():
    """Test that service falls back to API when local data is insufficient."""
    print("\n🔄 Testing Local-First Fallback to API")

    try:

        class MockRepositoryWithoutData(NewsRepository):
            def has_data_for_period(
                self, identifier: str, start_date: str, end_date: str, **kwargs
            ) -> bool:
                return False  # Pretend we don't have the data

            def get_data(
                self, query: str, start_date: str, end_date: str, **kwargs
            ) -> dict[str, Any]:
                return {
                    "query": kwargs.get("query", "TEST"),
                    "articles": [],
                    "metadata": {},
                }

            def store_data(
                self,
                symbol: str,
                data: dict[str, Any],
                overwrite: bool = False,
                **kwargs,
            ) -> bool:
                return True  # Pretend storage was successful

        mock_client = MockGoogleNewsClient()
        mock_repo = MockRepositoryWithoutData("test_data")

        service = NewsService(
            finnhub_client=None,
            google_client=mock_client,
            repository=mock_repo,
            online_mode=True,
        )

        # Should fall back to API since repository doesn't have data
        context = service.get_context(
            "TEST", "2024-01-01", "2024-01-31", sources=["google"]
        )

        # Verify we used API data
        assert context.metadata.get("data_source") == "live_api"
        assert len(context.articles) > 0  # From mock client

        print("✅ Local-first fallback to API test passed")
        return True

    except Exception as e:
        print(f"❌ Local-first fallback test failed: {e}")
        return False


def test_force_refresh_bypasses_local_data():
    """Test that force_refresh=True bypasses local data even when available."""
    print("\n⚡ Testing Force Refresh Bypasses Local Data")

    try:

        class MockRepositoryAlwaysHasData(NewsRepository):
            def has_data_for_period(
                self, identifier: str, start_date: str, end_date: str, **kwargs
            ) -> bool:
                return True  # Always claim we have data

            def get_data(
                self, query: str, start_date: str, end_date: str, **kwargs
            ) -> dict[str, Any]:
                return {
                    "query": kwargs.get("query", "TEST"),
                    "symbol": kwargs.get("symbol"),
                    "articles": [
                        {
                            "headline": "Old Cached Article",
                            "summary": "This is from local cache",
                            "source": "Local Cache",
                            "date": "2024-01-01",
                            "url": "https://cache.local/old",
                            "entities": ["TEST"],
                        }
                    ],
                    "metadata": {"source": "local"},
                }

            def clear_data(
                self, symbol: str, start_date: str, end_date: str, **kwargs
            ) -> bool:
                return True

            def store_data(
                self,
                symbol: str,
                data: dict[str, Any],
                overwrite: bool = False,
                **kwargs,
            ) -> bool:
                return True

        mock_client = MockGoogleNewsClient()
        mock_repo = MockRepositoryAlwaysHasData("test_data")

        service = NewsService(
            finnhub_client=None,
            google_client=mock_client,
            repository=mock_repo,
            online_mode=True,
        )

        # Force refresh should bypass local data
        context = service.get_context(
            "TEST", "2024-01-01", "2024-01-31", sources=["google"], force_refresh=True
        )

        # Verify we used API data (force refresh)
        assert context.metadata.get("data_source") == "live_api_refresh"
        assert context.metadata.get("force_refresh")
        # Should have fresh data from client, not the old cached article
        assert len(context.articles) > 1  # Client returns multiple articles

        print("✅ Force refresh bypasses local data test passed")
        return True

    except Exception as e:
        print(f"❌ Force refresh bypass test failed: {e}")
        return False


def main():
    """Run all NewsService tests."""
    print("🧪 Testing NewsService\n")

    tests = [
        test_online_mode_with_mock_clients,
        test_global_news_context,
        test_offline_mode_with_real_repository,
        test_sentiment_analysis,
        test_multiple_source_aggregation,
        test_json_structure_validation,
        test_force_refresh_parameter,
        test_local_first_strategy,
        test_local_first_fallback_to_api,
        test_force_refresh_bypasses_local_data,
    ]

    passed = 0
    failed = 0

    for test in tests:
        try:
            if test():
                passed += 1
            else:
                failed += 1
        except Exception as e:
            print(f"❌ Test {test.__name__} crashed: {e}")
            failed += 1

    print("\n📊 NewsService Test Results:")
    print(f"   ✅ Passed: {passed}")
    print(f"   ❌ Failed: {failed}")

    if failed == 0:
        print("🎉 All NewsService tests passed!")
    else:
        print("⚠️ Some tests failed - check output above")

    return failed == 0


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)
@@ -7,6 +7,8 @@ from dataclasses import dataclass
from enum import Enum
from typing import Any

+from tradingagents.config import TradingAgentsConfig
+
from .reddit_client import RedditClient
from .social_media_repository import SocialMediaRepository
@@ -111,6 +113,12 @@ class SocialMediaService:
        self.reddit_client = reddit_client
        self.repository = repository

+    @staticmethod
+    def build(_config: TradingAgentsConfig):
+        client = RedditClient()
+        repo = SocialMediaRepository("")
+        return SocialMediaService(client, repo)
+
    def get_context(
        self,
        query: str,
@@ -134,74 +142,29 @@ class SocialMediaService:
            SocialContext with posts and sentiment analysis
        """
        posts = []
-        error_info = {}
-        data_source = "unknown"
-
-        try:
-            # Local-first data strategy with force refresh option
-            if force_refresh:
-                # Skip local data, fetch fresh from APIs
-                posts, data_source = self._fetch_and_cache_fresh_social_data(
-                    query, start_date, end_date, symbol, subreddits
-                )
-            else:
-                # Check local data first, fetch missing if needed
-                posts, data_source = self._get_social_data_local_first(
-                    query, start_date, end_date, symbol, subreddits
-                )
-
-        except Exception as e:
-            logger.error(f"Error fetching social media data: {e}")
-            error_info = {"error": str(e)}
-
-        # Calculate sentiment and engagement metrics
-        sentiment_summary = self._calculate_sentiment(posts)
-        engagement_metrics = self._calculate_engagement_metrics(posts)
-
-        # Determine data quality based on data source
-        data_quality = self._determine_data_quality(
-            data_source=data_source,
-            record_count=len(posts),
-            has_errors=bool(error_info),
-        )

        # Create structured engagement metrics
        structured_metrics = EngagementMetrics(
-            total_engagement=float(engagement_metrics.get("total_engagement", 0)),
-            average_engagement=float(engagement_metrics.get("average_engagement", 0)),
-            max_engagement=float(engagement_metrics.get("max_engagement", 0)),
-            total_posts=int(engagement_metrics.get("total_posts", 0)),
+            total_engagement=0,
+            average_engagement=0,
+            max_engagement=0,
+            total_posts=0,
        )

-        # Separate non-float metrics for metadata
-        metadata_info = {
-            k: v
-            for k, v in engagement_metrics.items()
-            if k
-            not in [
-                "total_engagement",
-                "average_engagement",
-                "max_engagement",
-                "total_posts",
-            ]
-        }
-
        return SocialContext(
            symbol=symbol,
            period=(start_date, end_date),
            posts=posts,
            engagement_metrics=structured_metrics,
-            sentiment_summary=sentiment_summary,
+            sentiment_summary=SentimentScore(score=0, confidence=0, label=""),
            post_count=len(posts),
            platforms=["reddit"],
            metadata={
-                "data_quality": data_quality,
                "service": "social_media",
                "subreddits": subreddits or [],
-                "data_source": data_source,
                "force_refresh": force_refresh,
-                **metadata_info,
-                **error_info,
            },
        )
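With the metrics and sentiment stubbed to placeholders above, callers still receive a fully typed `SocialContext`. A minimal sketch of what the stub currently yields, assuming the `build()` factory from the previous hunk and a default-constructible `TradingAgentsConfig`:

```python
# Sketch only: build() wires a placeholder path ("") for the repository, and
# the stubbed service returns zeroed EngagementMetrics and an empty sentiment.
service = SocialMediaService.build(TradingAgentsConfig())
ctx = service.get_context("AAPL", "2024-01-01", "2024-01-05")
assert ctx.engagement_metrics.total_posts == 0
assert ctx.sentiment_summary.score == 0
```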
@ -1,378 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test SocialMediaService with mock RedditClient and real SocialRepository.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
# Add the project root to the path
|
||||
sys.path.insert(0, os.path.abspath("."))
|
||||
|
||||
from tradingagents.clients.base import BaseClient
|
||||
from tradingagents.domains.socialmedia.social_media_service import (
|
||||
DataQuality,
|
||||
PostData,
|
||||
SentimentScore,
|
||||
SocialContext,
|
||||
SocialMediaService,
|
||||
)
|
||||
from tradingagents.repositories.social_repository import SocialRepository
|
||||
|
||||
|
||||
class MockRedditClient(BaseClient):
|
||||
"""Mock Reddit client that returns sample social media data."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.connection_works = True
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
return self.connection_works
|
||||
|
||||
def get_data(self, *args, **kwargs) -> dict[str, Any]:
|
||||
"""Not used directly by SocialMediaService."""
|
||||
return {}
|
||||
|
||||
def search_posts(
|
||||
self,
|
||||
query: str,
|
||||
subreddit_names: list[str],
|
||||
limit: int = 25,
|
||||
time_filter: str = "week",
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return mock Reddit search results."""
|
||||
posts = []
|
||||
# Use fixed dates that will work with our date filter
|
||||
base_date = datetime(2024, 1, 2) # Within our test range
|
||||
|
||||
for i, subreddit in enumerate(
|
||||
subreddit_names[:2]
|
||||
): # Limit to 2 subreddits for testing
|
||||
posts.extend(
|
||||
[
|
||||
{
|
||||
"title": f"{query} to the moon! 🚀",
|
||||
"content": f"DD on {query}: Strong fundamentals, great earnings beat. Buy and hold!",
|
||||
"url": f"https://reddit.com/r/{subreddit}/post1",
|
||||
"upvotes": 1500 - (i * 100),
|
||||
"score": 1450 - (i * 100),
|
||||
"num_comments": 234,
|
||||
"created_utc": (base_date + timedelta(hours=i)).timestamp(),
|
||||
"subreddit": subreddit,
|
||||
"author": f"WSBtrader{i}",
|
||||
"posted_date": (base_date + timedelta(hours=i)).strftime(
|
||||
"%Y-%m-%d"
|
||||
),
|
||||
},
|
||||
{
|
||||
"title": f"Why I'm bearish on {query}",
|
||||
"content": f"Overvalued, competition increasing, margins declining. Time to sell {query}.",
|
||||
"url": f"https://reddit.com/r/{subreddit}/post2",
|
||||
"upvotes": 800 - (i * 50),
|
||||
"score": 750 - (i * 50),
|
||||
"num_comments": 156,
|
||||
"created_utc": (base_date + timedelta(hours=i + 1)).timestamp(),
|
||||
"subreddit": subreddit,
|
||||
"author": f"BearishTrader{i}",
|
||||
"posted_date": (base_date + timedelta(hours=i + 1)).strftime(
|
||||
"%Y-%m-%d"
|
||||
),
|
||||
},
|
||||
]
|
||||
)
|
||||
return posts
|
||||
|
||||
def get_top_posts(
|
||||
self, subreddit_names: list[str], limit: int = 25, time_filter: str = "week"
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return mock top posts from subreddits."""
|
||||
posts = []
|
||||
# Use fixed dates that will work with our date filter
|
||||
base_date = datetime(2024, 1, 2) # Within our test range
|
||||
|
||||
for subreddit in subreddit_names[:2]:
|
||||
posts.append(
|
||||
{
|
||||
"title": "Market Update: Tech stocks rally continues",
|
||||
"content": "FAANG stocks leading the charge. SPY hit new ATH. Bull market confirmed.",
|
||||
"url": f"https://reddit.com/r/{subreddit}/top1",
|
||||
"upvotes": 2500,
|
||||
"score": 2400,
|
||||
"num_comments": 456,
|
||||
"created_utc": base_date.timestamp(),
|
||||
"subreddit": subreddit,
|
||||
"author": "MarketWatcher",
|
||||
"posted_date": base_date.strftime("%Y-%m-%d"),
|
||||
}
|
||||
)
|
||||
return posts
|
||||
|
||||
def filter_posts_by_date(
|
||||
self, posts: list[dict[str, Any]], start_date: str, end_date: str
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Filter posts by date range."""
|
||||
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
|
||||
|
||||
filtered = []
|
||||
for post in posts:
|
||||
if "posted_date" in post:
|
||||
post_dt = datetime.strptime(post["posted_date"], "%Y-%m-%d")
|
||||
if start_dt <= post_dt <= end_dt:
|
||||
filtered.append(post)
|
||||
return filtered
|
||||
|
||||
|
||||
def test_online_mode_with_mock_reddit():
|
||||
"""Test SocialMediaService in online mode with mock Reddit client."""
|
||||
# Create mock client and real repository
|
||||
mock_reddit = MockRedditClient()
|
||||
real_repo = SocialRepository("test_data")
|
||||
|
||||
# Create service in online mode
|
||||
service = SocialMediaService(
|
||||
reddit_client=mock_reddit,
|
||||
repository=real_repo,
|
||||
online_mode=True,
|
||||
data_dir="test_data",
|
||||
)
|
||||
|
||||
# Test company-specific social context
|
||||
context = service.get_company_social_context(
|
||||
symbol="TSLA",
|
||||
start_date="2024-01-01",
|
||||
end_date="2024-01-05",
|
||||
subreddits=["wallstreetbets", "stocks"],
|
||||
force_refresh=True,
|
||||
)
|
||||
|
||||
# Validate context structure
|
||||
assert isinstance(context, SocialContext)
|
||||
assert context.symbol == "TSLA"
|
||||
assert context.period["start"] == "2024-01-01"
|
||||
assert context.period["end"] == "2024-01-05"
|
||||
assert len(context.posts) > 0
|
||||
assert isinstance(context.sentiment_summary, SentimentScore)
|
||||
assert context.post_count == len(context.posts)
|
||||
assert "data_quality" in context.metadata
|
||||
|
||||
# Test JSON serialization
|
||||
json_output = context.model_dump_json(indent=2)
|
||||
assert len(json_output) > 0
|
||||
|
||||
# Validate individual posts
|
||||
for post in context.posts:
|
||||
assert isinstance(post, PostData)
|
||||
assert post.title
|
||||
assert post.author
|
||||
assert post.date
|
||||
assert post.score >= 0
|
||||
|
||||
|
||||
def test_global_social_trends():
|
||||
"""Test global social media trends functionality."""
|
||||
mock_reddit = MockRedditClient()
|
||||
real_repo = SocialRepository("test_data")
|
||||
|
||||
service = SocialMediaService(
|
||||
reddit_client=mock_reddit, repository=real_repo, online_mode=True
|
||||
)
|
||||
|
||||
# Test global trends
|
||||
context = service.get_global_trends(
|
||||
start_date="2024-01-01",
|
||||
end_date="2024-01-03",
|
||||
subreddits=["investing", "stocks", "wallstreetbets"],
|
||||
force_refresh=True,
|
||||
)
|
||||
|
||||
# Validate global context
|
||||
assert context.symbol is None # Global trends have no specific symbol
|
||||
assert len(context.posts) > 0
|
||||
assert "reddit" in context.platforms
|
||||
assert "subreddits" in context.metadata
|
||||
|
||||
|
||||
def test_sentiment_analysis():
|
||||
"""Test sentiment analysis on social posts."""
|
||||
|
||||
# Create service with posts that have clear sentiment
|
||||
class SentimentTestClient(MockRedditClient):
|
||||
def search_posts(self, query, subreddit_names, limit=25, time_filter="week"):
|
||||
return [
|
||||
{
|
||||
"title": f"{query} is the best investment ever! 🚀🚀🚀",
|
||||
"content": "Amazing earnings, incredible growth, bullish AF!",
|
||||
"upvotes": 5000,
|
||||
"score": 4900,
|
||||
"num_comments": 500,
|
||||
"subreddit": "wallstreetbets",
|
||||
"author": "BullishTrader",
|
||||
"posted_date": "2024-01-01",
|
||||
},
|
||||
{
|
||||
"title": f"WARNING: {query} is about to crash hard",
|
||||
"content": "Terrible fundamentals, overvalued, sell now before it's too late!",
|
||||
"upvotes": 100,
|
||||
"score": 50,
|
||||
"num_comments": 30,
|
||||
"subreddit": "stocks",
|
||||
"author": "BearishAnalyst",
|
||||
"posted_date": "2024-01-01",
|
||||
},
|
||||
]
|
||||
|
||||
sentiment_client = SentimentTestClient()
|
||||
service = SocialMediaService(
|
||||
reddit_client=sentiment_client, repository=None, online_mode=True
|
||||
)
|
||||
|
||||
context = service.get_context("GME", "2024-01-01", "2024-01-02")
|
||||
|
||||
# Check sentiment analysis
|
||||
assert context.sentiment_summary.score != 0 # Should have some sentiment
|
||||
assert context.sentiment_summary.confidence > 0
|
||||
assert context.sentiment_summary.label in ["positive", "negative", "neutral"]
|
||||
|
||||
# Check individual post sentiments
|
||||
for post in context.posts:
|
||||
if post.sentiment:
|
||||
assert -1.0 <= post.sentiment.score <= 1.0
|
||||
|
||||
|
||||
def test_offline_mode():
|
||||
"""Test SocialMediaService in offline mode."""
|
||||
real_repo = SocialRepository("test_data")
|
||||
|
||||
service = SocialMediaService(
|
||||
reddit_client=None, repository=real_repo, online_mode=False
|
||||
)
|
||||
|
||||
# Should handle offline gracefully
|
||||
context = service.get_context("AAPL", "2024-01-01", "2024-01-05", symbol="AAPL")
|
||||
|
||||
assert context.symbol == "AAPL"
|
||||
assert isinstance(context.posts, list)
|
||||
assert context.metadata.get("data_quality") == DataQuality.LOW
|
||||
|
||||
|
||||
def test_engagement_metrics():
|
||||
"""Test calculation of engagement metrics."""
|
||||
mock_reddit = MockRedditClient()
|
||||
service = SocialMediaService(
|
||||
reddit_client=mock_reddit, repository=None, online_mode=True
|
||||
)
|
||||
|
||||
context = service.get_company_social_context(
|
||||
symbol="NVDA",
|
||||
start_date="2024-01-01",
|
||||
end_date="2024-01-02",
|
||||
subreddits=["nvidia", "stocks"],
|
||||
)
|
||||
|
||||
# Check engagement metrics in the context
|
||||
assert len(context.engagement_metrics) > 0
|
||||
assert (
|
||||
"total_engagement" in context.engagement_metrics
|
||||
or "total_engagement" in context.metadata
|
||||
)
|
||||
|
||||
# Verify post scores
|
||||
for post in context.posts:
|
||||
# Posts should have score and comments
|
||||
assert post.score >= 0
|
||||
assert post.comments >= 0
|
||||
|
||||
|
||||
def test_subreddit_filtering():
|
||||
"""Test filtering by specific subreddits."""
|
||||
mock_reddit = MockRedditClient()
|
||||
service = SocialMediaService(
|
||||
reddit_client=mock_reddit, repository=None, online_mode=True
|
||||
)
|
||||
|
||||
# Test with specific subreddits
|
||||
context = service.get_company_social_context(
|
||||
symbol="AMD",
|
||||
start_date="2024-01-01",
|
||||
end_date="2024-01-02",
|
||||
subreddits=["AMD_Stock", "wallstreetbets"],
|
||||
)
|
||||
|
||||
# Check that posts are from requested subreddits
|
||||
subreddit_set = set()
|
||||
for post in context.posts:
|
||||
if post.subreddit:
|
||||
subreddit_set.add(post.subreddit)
|
||||
|
||||
assert len(subreddit_set) > 0
|
||||
assert all(sub in ["AMD_Stock", "wallstreetbets"] for sub in subreddit_set)
|
||||
|
||||
|
||||
def test_error_handling():
|
||||
"""Test error handling with broken client."""
|
||||
|
||||
class BrokenRedditClient(BaseClient):
|
||||
def test_connection(self):
|
||||
return False
|
||||
|
||||
def get_data(self, *args, **kwargs):
|
||||
raise Exception("Reddit API error")
|
||||
|
||||
def search_posts(self, *args, **kwargs):
|
||||
raise Exception("Reddit API error")
|
||||
|
||||
def get_top_posts(self, *args, **kwargs):
|
||||
raise Exception("Reddit API error")
|
||||
|
||||
broken_client = BrokenRedditClient()
|
||||
service = SocialMediaService(
|
||||
reddit_client=broken_client, repository=None, online_mode=True
|
||||
)
|
||||
|
||||
# Should handle errors gracefully
|
||||
context = service.get_context("TSLA", "2024-01-01", "2024-01-02", symbol="TSLA")
|
||||
|
||||
assert context.symbol == "TSLA"
|
||||
assert len(context.posts) == 0
|
||||
assert context.metadata.get("data_quality") == DataQuality.LOW
|
||||
|
||||
|
||||
def test_json_structure():
|
||||
"""Test JSON structure of social context."""
|
||||
mock_reddit = MockRedditClient()
|
||||
service = SocialMediaService(
|
||||
reddit_client=mock_reddit, repository=None, online_mode=True
|
||||
)
|
||||
|
||||
context = service.get_context("PLTR", "2024-01-01", "2024-01-02")
|
||||
json_data = context.model_dump()
|
||||
|
||||
# Validate required fields
|
||||
required_fields = [
|
||||
"symbol",
|
||||
"period",
|
||||
"posts",
|
||||
"sentiment_summary",
|
||||
"post_count",
|
||||
"platforms",
|
||||
"metadata",
|
||||
]
|
||||
for field in required_fields:
|
||||
assert field in json_data
|
||||
|
||||
# Validate posts structure
|
||||
if json_data["posts"]:
|
||||
first_post = json_data["posts"][0]
|
||||
post_fields = ["title", "author", "date", "score"]
|
||||
for field in post_fields:
|
||||
assert field in first_post
|
||||
|
||||
# Validate sentiment structure
|
||||
sentiment = json_data["sentiment_summary"]
|
||||
assert "score" in sentiment
|
||||
assert "confidence" in sentiment
|
||||
assert "label" in sentiment
|
||||
|
|
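The tests above rely on a `MockRedditClient` test double that is not shown in this excerpt. As a rough sketch only, assuming it implements the same `BaseClient` surface that `BrokenRedditClient` exercises (`test_connection`, `get_data`, `search_posts`, `get_top_posts`) and returns post dicts shaped like the assertions expect, it might look like:

```python
# Hypothetical test double; the real MockRedditClient is defined elsewhere.
# Field names (score, comments, subreddit, posted_date) mirror the assertions
# in the tests above. Everything else here is an assumption.
class MockRedditClient(BaseClient):
    def test_connection(self) -> bool:
        return True

    def get_data(self, *args, **kwargs) -> list[dict]:
        return self.get_top_posts(*args, **kwargs)

    def search_posts(self, *args, **kwargs) -> list[dict]:
        return self.get_top_posts(*args, **kwargs)

    def get_top_posts(self, *args, **kwargs) -> list[dict]:
        # Echo back one fake post per requested subreddit.
        subreddits = kwargs.get("subreddits") or ["stocks"]
        return [
            {
                "title": f"Daily discussion in r/{sub}",
                "author": "test_user",
                "posted_date": "2024-01-01",
                "score": 42,
                "comments": 7,
                "subreddit": sub,
            }
            for sub in subreddits
        ]
```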
@@ -1,6 +1,6 @@
 # TradingAgents/graph/conditional_logic.py

-from tradingagents.agents.utils.agent_states import AgentState
+from tradingagents.agents.libs.agent_states import AgentState


 class ConditionalLogic:
@@ -2,7 +2,7 @@

 from typing import Any

-from tradingagents.agents.utils.agent_states import (
+from tradingagents.agents.libs.agent_states import (
     InvestDebateState,
     RiskDebateState,
 )
@@ -22,8 +22,8 @@ from tradingagents.agents import (
     create_social_media_analyst,
     create_trader,
 )
-from tradingagents.agents.utils.agent_states import AgentState
-from tradingagents.agents.utils.agent_utils import Toolkit
+from tradingagents.agents.libs.agent_states import AgentState
+from tradingagents.agents.libs.agent_toolkit import AgentToolkit

 from .conditional_logic import ConditionalLogic
@@ -35,7 +35,7 @@ class GraphSetup:
         self,
         quick_thinking_llm: ChatOpenAI | ChatAnthropic | ChatGoogleGenerativeAI,
         deep_thinking_llm: ChatOpenAI | ChatAnthropic | ChatGoogleGenerativeAI,
-        toolkit: Toolkit,
+        toolkit: AgentToolkit,
         tool_nodes: dict[str, ToolNode],
         bull_memory,
         bear_memory,
@@ -9,9 +9,16 @@ from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_openai import ChatOpenAI
 from langgraph.prebuilt import ToolNode

-from tradingagents.agents.utils.agent_utils import Toolkit
-from tradingagents.agents.utils.memory import FinancialSituationMemory
+from tradingagents.agents.libs.agent_toolkit import AgentToolkit
+from tradingagents.agents.libs.memory import FinancialSituationMemory
 from tradingagents.config import TradingAgentsConfig
+from tradingagents.domains.marketdata.fundamental_data_service import (
+    FundamentalDataService,
+)
+from tradingagents.domains.marketdata.insider_data_service import InsiderDataService
+from tradingagents.domains.marketdata.market_data_service import MarketDataService
+from tradingagents.domains.news.news_service import NewsService
+from tradingagents.domains.socialmedia.social_media_service import SocialMediaService

 from .conditional_logic import ConditionalLogic
 from .propagation import Propagator
@@ -77,7 +84,18 @@ class TradingAgentsGraph:
         else:
             raise ValueError(f"Unsupported LLM provider: {self.config.llm_provider}")

-        self.toolkit = Toolkit(config=self.config)
+        news_service = NewsService.build(self.config)
+        social_media_service = SocialMediaService.build(self.config)
+        market_data_service = MarketDataService.build(self.config)
+        fundamental_data_service = FundamentalDataService.build(self.config)
+        insider_data_service = InsiderDataService.build(self.config)
+        self.toolkit = AgentToolkit(
+            news_service,
+            social_media_service,
+            market_data_service,
+            fundamental_data_service,
+            insider_data_service,
+        )

         # Initialize memories
         self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
@@ -125,42 +143,28 @@ class TradingAgentsGraph:
         return {
             "market": ToolNode(
                 [
-                    # online tools
-                    self.toolkit.get_YFin_data_online,
-                    self.toolkit.get_stockstats_indicators_report_online,
-                    # offline tools
-                    self.toolkit.get_YFin_data,
-                    self.toolkit.get_stockstats_indicators_report,
+                    self.toolkit.get_market_data,
+                    self.toolkit.get_ta_report,
                 ]
             ),
             "social": ToolNode(
                 [
-                    # online tools
-                    self.toolkit.get_stock_news_openai,
-                    # offline tools
-                    self.toolkit.get_reddit_stock_info,
+                    self.toolkit.get_socialmedia_stock_info,
                 ]
             ),
             "news": ToolNode(
                 [
-                    # online tools
-                    self.toolkit.get_global_news_openai,
-                    self.toolkit.get_google_news,
-                    # offline tools
-                    self.toolkit.get_finnhub_news,
-                    self.toolkit.get_reddit_news,
+                    self.toolkit.get_global_news,
+                    self.toolkit.get_news,
                 ]
             ),
             "fundamentals": ToolNode(
                 [
-                    # online tools
-                    self.toolkit.get_fundamentals_openai,
-                    # offline tools
-                    self.toolkit.get_finnhub_company_insider_sentiment,
-                    self.toolkit.get_finnhub_company_insider_transactions,
-                    self.toolkit.get_simfin_balance_sheet,
-                    self.toolkit.get_simfin_cashflow,
-                    self.toolkit.get_simfin_income_stmt,
+                    self.toolkit.get_insider_sentiment,
+                    self.toolkit.get_insider_transactions,
+                    self.toolkit.get_balance_sheet,
+                    self.toolkit.get_cashflow,
+                    self.toolkit.get_income_stmt,
                 ]
             ),
         }
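The hunk above collapses the separate online/offline tool variants into single service-backed toolkit methods. As a minimal sketch only (the `@tool` wiring, signatures, and `get_context` call are assumptions; only the names `AgentToolkit` and `get_market_data` come from the diff), each toolkit attribute could be a LangChain tool that closes over an injected service, which is what lets `ToolNode` consume it directly and lets tests substitute stub services:

```python
# Illustrative sketch, not the actual implementation.
from langchain_core.tools import tool


class AgentToolkit:
    def __init__(self, news_service, social_media_service, market_data_service,
                 fundamental_data_service, insider_data_service):
        @tool
        def get_market_data(symbol: str, start_date: str, end_date: str) -> str:
            """Fetch market data for a symbol between two dates."""
            # get_context is an assumed service method; the real name may differ.
            return str(market_data_service.get_context(symbol, start_date, end_date))

        self.get_market_data = get_market_data
        # ...the other injected services would be wrapped the same way.
```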
112 uv.lock
@@ -633,6 +633,17 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/32/b6/7517af5234378518f27ad35a7b24af9591bc500b8c1780929c1295999eb6/fastapi-0.115.9-py3-none-any.whl", hash = "sha256:4a439d7923e4de796bcc88b64e9754340fcd1574673cbd865ba8a99fe0d28c56", size = 94919, upload-time = "2025-02-27T16:43:40.537Z" },
 ]

+[[package]]
+name = "feedfinder2"
+version = "0.0.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "beautifulsoup4" },
+    { name = "requests" },
+    { name = "six" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/35/82/1251fefec3bb4b03fd966c7e7f7a41c9fc2bb00d823a34c13f847fd61406/feedfinder2-0.0.4.tar.gz", hash = "sha256:3701ee01a6c85f8b865a049c30ba0b4608858c803fe8e30d1d289fdbe89d0efe", size = 3297, upload-time = "2016-01-25T15:09:17.492Z" }
+
 [[package]]
 name = "feedparser"
 version = "6.0.11"
@@ -1038,6 +1049,12 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" },
 ]

+[[package]]
+name = "jieba3k"
+version = "0.35.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/a9/cb/2c8332bcdc14d33b0bedd18ae0a4981a069c3513e445120da3c3f23a8aaa/jieba3k-0.35.1.zip", hash = "sha256:980a4f2636b778d312518066be90c7697d410dd5a472385f5afced71a2db1c10", size = 7423646, upload-time = "2014-11-15T05:47:47.978Z" }
+
 [[package]]
 name = "jinja2"
 version = "3.1.6"
@@ -1095,6 +1112,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" },
 ]

+[[package]]
+name = "joblib"
+version = "1.5.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/dc/fe/0f5a938c54105553436dbff7a61dc4fed4b1b2c98852f8833beaf4d5968f/joblib-1.5.1.tar.gz", hash = "sha256:f4f86e351f39fe3d0d32a9f2c3d8af1ee4cec285aafcb27003dda5205576b444", size = 330475, upload-time = "2025-05-23T12:04:37.097Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/7d/4f/1195bbac8e0c2acc5f740661631d8d750dc38d4a32b23ee5df3cde6f4e0d/joblib-1.5.1-py3-none-any.whl", hash = "sha256:4719a31f054c7d766948dcd83e9613686b27114f190f717cec7eaa2084f8a74a", size = 307746, upload-time = "2025-05-23T12:04:35.124Z" },
+]
+
 [[package]]
 name = "jsonpatch"
 version = "1.33"
@@ -1673,6 +1699,45 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" },
 ]

+[[package]]
+name = "newspaper3k"
+version = "0.2.8"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "beautifulsoup4" },
+    { name = "cssselect" },
+    { name = "feedfinder2" },
+    { name = "feedparser" },
+    { name = "jieba3k" },
+    { name = "lxml" },
+    { name = "nltk" },
+    { name = "pillow" },
+    { name = "python-dateutil" },
+    { name = "pyyaml" },
+    { name = "requests" },
+    { name = "tinysegmenter" },
+    { name = "tldextract" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/ce/fb/8f8525be0cafa48926e85b0c06a7cb3e2a892d340b8036f8c8b1b572df1c/newspaper3k-0.2.8.tar.gz", hash = "sha256:9f1bd3e1fb48f400c715abf875cc7b0a67b7ddcd87f50c9aeeb8fcbbbd9004fb", size = 205685, upload-time = "2018-09-28T04:58:23.53Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/d7/b9/51afecb35bb61b188a4b44868001de348a0e8134b4dfa00ffc191567c4b9/newspaper3k-0.2.8-py3-none-any.whl", hash = "sha256:44a864222633d3081113d1030615991c3dbba87239f6bbf59d91240f71a22e3e", size = 211132, upload-time = "2018-09-28T04:58:18.847Z" },
+]
+
+[[package]]
+name = "nltk"
+version = "3.9.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "click" },
+    { name = "joblib" },
+    { name = "regex" },
+    { name = "tqdm" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/3c/87/db8be88ad32c2d042420b6fd9ffd4a149f9a0d7f0e86b3f543be2eeeedd2/nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868", size = 2904691, upload-time = "2024-08-18T19:48:37.769Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442, upload-time = "2024-08-18T19:48:21.909Z" },
+]
+
 [[package]]
 name = "nodeenv"
 version = "1.9.1"
@@ -3058,6 +3123,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c", size = 64847, upload-time = "2025-06-09T16:43:05.728Z" },
 ]

+[[package]]
+name = "requests-file"
+version = "2.1.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "requests" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/72/97/bf44e6c6bd8ddbb99943baf7ba8b1a8485bcd2fe0e55e5708d7fee4ff1ae/requests_file-2.1.0.tar.gz", hash = "sha256:0f549a3f3b0699415ac04d167e9cb39bccfb730cb832b4d20be3d9867356e658", size = 6891, upload-time = "2024-05-21T16:28:00.24Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/d7/25/dd878a121fcfdf38f52850f11c512e13ec87c2ea72385933818e5b6c15ce/requests_file-2.1.0-py2.py3-none-any.whl", hash = "sha256:cf270de5a4c5874e84599fc5778303d496c10ae5e870bfa378818f35d21bda5c", size = 4244, upload-time = "2024-05-21T16:27:57.733Z" },
+]
+
 [[package]]
 name = "requests-oauthlib"
 version = "2.0.0"
@@ -3329,6 +3406,16 @@ version = "2.0.3"
 source = { registry = "https://pypi.org/simple" }
 sdist = { url = "https://files.pythonhosted.org/packages/8d/dd/d4dd75843692690d81f0a4b929212a1614b25d4896aa7c72f4c3546c7e3d/syncer-2.0.3.tar.gz", hash = "sha256:4340eb54b54368724a78c5c0763824470201804fe9180129daf3635cb500550f", size = 11512, upload-time = "2023-05-08T07:50:17.963Z" }

+[[package]]
+name = "ta-lib"
+version = "0.6.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "numpy" },
+    { name = "setuptools" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/ba/97/a49816dd468a18ee080cf3a04640772a9f6321790d4049cece2490c4b7ad/ta_lib-0.6.4.tar.gz", hash = "sha256:08f55bc5771a6d1ceb1a2b713aad7b05f04eb0061e980c9113571c532d32e9cb", size = 381774, upload-time = "2025-06-08T15:28:15.452Z" }
+
 [[package]]
 name = "tenacity"
 version = "9.1.2"
@@ -3356,6 +3443,27 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/de/a8/8f499c179ec900783ffe133e9aab10044481679bb9aad78436d239eee716/tiktoken-0.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:5ea0edb6f83dc56d794723286215918c1cde03712cbbafa0348b33448faf5b95", size = 894669, upload-time = "2025-02-14T06:02:47.341Z" },
 ]

+[[package]]
+name = "tinysegmenter"
+version = "0.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/17/82/86982e4b6d16e4febc79c2a1d68ee3b707e8a020c5d2bc4af8052d0f136a/tinysegmenter-0.3.tar.gz", hash = "sha256:ed1f6d2e806a4758a73be589754384cbadadc7e1a414c81a166fc9adf2d40c6d", size = 16893, upload-time = "2017-07-23T11:18:29.85Z" }
+
+[[package]]
+name = "tldextract"
+version = "5.3.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "filelock" },
+    { name = "idna" },
+    { name = "requests" },
+    { name = "requests-file" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/97/78/182641ea38e3cfd56e9c7b3c0d48a53d432eea755003aa544af96403d4ac/tldextract-5.3.0.tar.gz", hash = "sha256:b3d2b70a1594a0ecfa6967d57251527d58e00bb5a91a74387baa0d87a0678609", size = 128502, upload-time = "2025-04-22T06:19:37.491Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/67/7c/ea488ef48f2f544566947ced88541bc45fae9e0e422b2edbf165ee07da99/tldextract-5.3.0-py3-none-any.whl", hash = "sha256:f70f31d10b55c83993f55e91ecb7c5d84532a8972f22ec578ecfbe5ea2292db2", size = 107384, upload-time = "2025-04-22T06:19:36.304Z" },
+]
+
 [[package]]
 name = "tokenizers"
 version = "0.21.1"
@@ -3483,6 +3591,7 @@ dependencies = [
     { name = "langchain-google-genai" },
     { name = "langchain-openai" },
     { name = "langgraph" },
+    { name = "newspaper3k" },
     { name = "pandas" },
     { name = "parsel" },
     { name = "praw" },
@@ -3494,6 +3603,7 @@ dependencies = [
     { name = "rich" },
     { name = "setuptools" },
     { name = "stockstats" },
+    { name = "ta-lib" },
     { name = "tqdm" },
     { name = "tushare" },
     { name = "typer" },
@@ -3532,6 +3642,7 @@ requires-dist = [
     { name = "langchain-google-genai", specifier = ">=2.1.5" },
     { name = "langchain-openai", specifier = ">=0.3.23" },
     { name = "langgraph", specifier = ">=0.4.8" },
+    { name = "newspaper3k", specifier = ">=0.2.8" },
     { name = "pandas", specifier = ">=2.3.0" },
     { name = "parsel", specifier = ">=1.10.0" },
     { name = "praw", specifier = ">=7.8.1" },
@@ -3548,6 +3659,7 @@ requires-dist = [
     { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.8.0" },
     { name = "setuptools", specifier = ">=80.9.0" },
     { name = "stockstats", specifier = ">=0.6.5" },
+    { name = "ta-lib", specifier = ">=0.4.28" },
     { name = "tqdm", specifier = ">=4.67.1" },
     { name = "tushare", specifier = ">=1.4.21" },
     { name = "typer", specifier = ">=0.12.0" },
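The lock-file additions above all trace back to two new direct dependencies: `ta-lib` for technical-analysis indicators and `newspaper3k` for article extraction, which pulls in `feedfinder2`, `feedparser`, `jieba3k`, `nltk`, `joblib`, `tinysegmenter`, `tldextract`, and `requests-file` transitively. As an illustrative aside (the actual scraping call site is not part of this diff), typical newspaper3k usage looks like:

```python
# Illustrative newspaper3k usage; the URL is a placeholder, and the news
# domain's real scraping code is not shown in this commit.
from newspaper import Article

article = Article("https://example.com/some-market-news")
article.download()  # fetch the raw HTML
article.parse()     # extract title, text, authors, publish date
print(article.title)
print(article.text[:200])
```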