wip
This commit is contained in:
parent
b2a09403fa
commit
c93ffb6452
166
AGENTS.md
166
AGENTS.md
|
|
@ -1,166 +0,0 @@
|
|||
# AGENTS.md - TradingAgents Development Guide
|
||||
|
||||
## What TradingAgents Does
|
||||
|
||||
**TradingAgents** is a multi-agent LLM financial trading framework that simulates a real-world trading firm using specialized AI agents. The system analyzes stocks through collaborative decision-making, mirroring professional trading teams.
|
||||
|
||||
### Core Architecture
|
||||
- Built on **LangGraph** with state-based workflows
|
||||
- **TradingAgentsGraph** orchestrates the entire process
|
||||
- Agents work sequentially and in parallel to analyze market conditions
|
||||
|
||||
### Agent Teams & Workflow
|
||||
1. **Analyst Team**: Market, Social Media, News, and Fundamentals analysts gather data
|
||||
2. **Research Team**: Bull/Bear researchers debate, Research Manager decides
|
||||
3. **Trading Team**: Trader develops detailed trading plans
|
||||
4. **Risk Management**: Risk analysts debate, Risk Manager makes final decision
|
||||
|
||||
### Data Sources
|
||||
- Yahoo Finance, FinnHub API, Reddit, Google News, StockStats
|
||||
- Supports both real-time online data and cached offline data for backtesting
|
||||
|
||||
### Decision Process
|
||||
Sequential analysis → Structured debate → Managerial oversight → Risk assessment → Memory & learning
|
||||
|
||||
## Build/Test Commands
|
||||
|
||||
This project uses [mise](https://mise.jdx.dev/) for tool and task management. All development tasks are managed through mise.
|
||||
|
||||
### Initial Setup
|
||||
- **First-time setup**: `mise run setup` - Install tools and dependencies
|
||||
- **Install tools only**: `mise install` - Install Python, uv, ruff, pyright
|
||||
- **Install dependencies**: `mise run install` - Install project dependencies with uv
|
||||
|
||||
### Development Workflow
|
||||
- **CLI Application**: `mise run dev` - Interactive CLI for running trading analysis
|
||||
- **Direct Python Usage**: `mise run run` - Run main.py programmatically
|
||||
- **Format code**: `mise run format` - Auto-format with ruff
|
||||
- **Lint code**: `mise run lint` - Check code quality with ruff
|
||||
- **Type checking**: `mise run typecheck` - Run pyright type checker
|
||||
- **Fix lint issues**: `mise run fix` - Auto-fix linting issues
|
||||
- **Run all checks**: `mise run all` - Format, lint, and typecheck
|
||||
- **Clean artifacts**: `mise run clean` - Remove cache and build files
|
||||
|
||||
### Testing
|
||||
- **Run tests**: `mise run test` - Run tests with pytest (when available)
|
||||
|
||||
### Configuration
|
||||
- **Environment Variables**: Create `.env` file with API keys (see `.env.example`)
|
||||
- **Tool Configuration**: `.mise.toml` manages Python 3.13, uv, ruff, pyright
|
||||
- **Code Quality**: `pyproject.toml` contains ruff and pyright configurations
|
||||
|
||||
## Configuration System
|
||||
|
||||
### Environment Variables
|
||||
Create `.env` file with API keys (see `.env.example`):
|
||||
|
||||
#### Core LLM APIs (Choose One)
|
||||
```bash
|
||||
# For OpenAI (default)
|
||||
export OPENAI_API_KEY="your_openai_api_key"
|
||||
|
||||
# For Anthropic Claude
|
||||
export ANTHROPIC_API_KEY="your_anthropic_api_key"
|
||||
|
||||
# For Google Gemini
|
||||
export GOOGLE_API_KEY="your_google_api_key"
|
||||
```
|
||||
|
||||
#### Data Sources (Optional)
|
||||
```bash
|
||||
# For financial data
|
||||
export FINNHUB_API_KEY="your_finnhub_api_key"
|
||||
|
||||
# For Reddit data
|
||||
export REDDIT_CLIENT_ID="your_reddit_client_id"
|
||||
export REDDIT_CLIENT_SECRET="your_reddit_client_secret"
|
||||
export REDDIT_USER_AGENT="your_app_name"
|
||||
```
|
||||
|
||||
### Configuration Management
|
||||
- **Config Class**: `TradingAgentsConfig` in `tradingagents/config.py` handles all configuration
|
||||
- Use `TradingAgentsConfig.from_env()` for environment-based configuration
|
||||
- Key settings: `max_debate_rounds`, `llm_provider`, `online_tools`
|
||||
- Results are saved to `results_dir/{ticker}/{date}/` with structured reports
|
||||
|
||||
### Configuration Examples
|
||||
|
||||
#### Anthropic Setup
|
||||
```python
|
||||
config = TradingAgentsConfig.from_env()
|
||||
config.llm_provider = "anthropic"
|
||||
config.deep_think_llm = "claude-3-5-sonnet-20241022"
|
||||
config.quick_think_llm = "claude-3-5-haiku-20241022"
|
||||
```
|
||||
|
||||
#### Google Gemini Setup
|
||||
```python
|
||||
config.llm_provider = "google"
|
||||
config.deep_think_llm = "gemini-2.0-flash"
|
||||
config.quick_think_llm = "gemini-2.0-flash"
|
||||
```
|
||||
|
||||
#### Data Mode Configuration
|
||||
- `config.online_tools = True` - Real-time data (requires API keys)
|
||||
- `config.online_tools = False` - Cached data (faster, historical only)
|
||||
|
||||
## Code Style Guidelines
|
||||
|
||||
- **Imports**: Standard library first, third-party, then local imports (langchain, tradingagents modules)
|
||||
- **Formatting**: Auto-formatted with ruff (`mise run format`)
|
||||
- **Linting**: Code quality checked with ruff (`mise run lint`)
|
||||
- **Type Checking**: Static analysis with pyright (`mise run typecheck`)
|
||||
- **Functions**: Snake_case naming (e.g., `fundamentals_analyst_node`, `create_fundamentals_analyst`)
|
||||
- **Classes**: PascalCase (e.g., `TradingAgentsGraph`, `MessageBuffer`)
|
||||
- **Variables**: Snake_case (e.g., `current_date`, `company_of_interest`)
|
||||
- **Constants**: UPPER_CASE (e.g., `DEFAULT_CONFIG`)
|
||||
|
||||
## Project Structure
|
||||
|
||||
- **Main entry**: `main.py` for package usage, `cli/main.py` for CLI
|
||||
- **Core logic**: `tradingagents/` package with agents, dataflows, graph modules
|
||||
- **Configuration**: `tradingagents/config.py` for LLM and system settings
|
||||
- **CLI interface**: `cli/` directory with rich-based terminal UI
|
||||
- **Tool Management**: `.mise.toml` for development tool configuration
|
||||
- **Dependencies**: `pyproject.toml` for project dependencies and tool settings
|
||||
|
||||
## Key Patterns
|
||||
|
||||
- **Agent creation**: Factory functions that return node functions (e.g., `create_fundamentals_analyst`)
|
||||
- **State management**: Dictionary-based state passed between graph nodes
|
||||
- **Tool integration**: LangChain tools bound to LLMs via `llm.bind_tools(tools)`
|
||||
- **Configuration**: Use `TradingAgentsConfig.from_env()` for environment-based configuration
|
||||
- **Debate-Driven Decision Making**: Critical decisions emerge from structured agent debates
|
||||
- **Memory-Augmented Learning**: Agents learn from past similar situations using vector similarity
|
||||
- **Dual-Mode Data Access**: Support for both live API calls and pre-processed cached data
|
||||
- **Factory Pattern**: Agent creation via factory functions for flexible configuration
|
||||
- **Signal Processing**: Final trading decisions processed into clean BUY/SELL/HOLD signals
|
||||
|
||||
## Development Guidelines
|
||||
|
||||
### Working with Agents
|
||||
- Each agent has its own memory instance in `FinancialSituationMemory`
|
||||
- Agents use the unified `Toolkit` for data access
|
||||
- Agent state is passed sequentially through the workflow
|
||||
- Configuration affects debate rounds, LLM selection, and data sources
|
||||
|
||||
### Working with Data Sources
|
||||
- All data utilities follow consistent date range patterns: `curr_date + look_back_days`
|
||||
- Interface functions return markdown-formatted strings for LLM consumption
|
||||
- Check `online_tools` config flag to determine live vs cached data usage
|
||||
- Data caching happens in `data_cache_dir` for online mode
|
||||
|
||||
### CLI Development
|
||||
- CLI uses Rich for terminal UI with live updating displays
|
||||
- Agent progress tracking through `MessageBuffer` class
|
||||
- Questionnaire-driven configuration collection
|
||||
- Real-time streaming of analysis results
|
||||
|
||||
### File Structure Context
|
||||
- **`cli/`**: Interactive command-line interface
|
||||
- **`tradingagents/agents/`**: All agent implementations
|
||||
- **`tradingagents/dataflows/`**: Data source integrations
|
||||
- **`tradingagents/graph/`**: LangGraph workflow orchestration
|
||||
- **`tradingagents/config.py`**: Configuration management
|
||||
- **`main.py`**: Direct Python usage example
|
||||
- **`CLAUDE.md`**: Guidance for Claude Code development
|
||||
327
CLAUDE.md
327
CLAUDE.md
|
|
@ -1,327 +0,0 @@
|
|||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Common Development Commands
|
||||
|
||||
This project uses [mise](https://mise.jdx.dev/) for tool and task management. All development tasks are managed through mise.
|
||||
|
||||
### Initial Setup
|
||||
- **First-time setup**: `mise run setup` - Install tools and dependencies
|
||||
- **Install tools only**: `mise install` - Install Python, uv, ruff, pyright
|
||||
- **Install dependencies**: `mise run install` - Install project dependencies with uv
|
||||
|
||||
### Development Workflow
|
||||
- **CLI Application**: `mise run dev` - Interactive CLI for running trading analysis
|
||||
- **Direct Python Usage**: `mise run run` - Run main.py programmatically
|
||||
- **Format code**: `mise run format` - Auto-format with ruff
|
||||
- **Lint code**: `mise run lint` - Check code quality with ruff
|
||||
- **Type checking**: `mise run typecheck` - Run pyright type checker
|
||||
- **Fix lint issues**: `mise run fix` - Auto-fix linting issues
|
||||
- **Run all checks**: `mise run all` - Format, lint, and typecheck
|
||||
- **Clean artifacts**: `mise run clean` - Remove cache and build files
|
||||
|
||||
### Testing
|
||||
|
||||
#### Running Tests
|
||||
- **Run all tests**: `mise run test` - Run tests with pytest
|
||||
- **Run specific test file**: `uv run pytest test_social_media_service.py` - Run individual test file
|
||||
- **Verbose output**: `uv run pytest -v` - Run tests with detailed output
|
||||
- **Run with output**: `uv run pytest -s` - Show print statements and debug output
|
||||
- **Test coverage**: `uv run pytest --cov=tradingagents` - Run tests with coverage report
|
||||
|
||||
#### Test Development (TDD Approach)
|
||||
This project follows **Test-Driven Development (TDD)** for service layer development:
|
||||
|
||||
1. **Write test first**: Create `test_{service_name}_service.py` with comprehensive test cases
|
||||
2. **Run test (should fail)**: Verify test fails with appropriate error messages
|
||||
3. **Implement minimum code**: Write just enough code to make the test pass
|
||||
4. **Refactor**: Improve code while keeping tests passing
|
||||
5. **Repeat**: Add more test cases and implement additional functionality
|
||||
|
||||
#### Test Structure and Conventions
|
||||
- **Test files**: Named `test_{component}_service.py` and placed next to source code (not in separate tests/ directory)
|
||||
- **Test functions**: Named `test_{functionality}()` and should not return values (use `assert` statements)
|
||||
- **Mock clients**: Create mock implementations of BaseClient for testing services
|
||||
- **Real repositories**: Use actual repository implementations (don't mock the repository layer)
|
||||
- **Test data**: Use realistic mock data that matches expected API responses
|
||||
- **Date handling**: Use fixed dates (e.g., `datetime(2024, 1, 2)`) in mocks for predictable filtering
|
||||
|
||||
#### Service Testing Pattern
|
||||
Example test structure for services:
|
||||
```python
|
||||
def test_online_mode_with_mock_client():
|
||||
"""Test service in online mode with mock client."""
|
||||
mock_client = MockServiceClient()
|
||||
real_repo = ServiceRepository("test_data")
|
||||
|
||||
service = ServiceClass(
|
||||
client=mock_client,
|
||||
repository=real_repo,
|
||||
online_mode=True
|
||||
)
|
||||
|
||||
context = service.get_context("TEST", "2024-01-01", "2024-01-05")
|
||||
|
||||
# Validate structure
|
||||
assert isinstance(context, ContextModel)
|
||||
assert context.symbol == "TEST"
|
||||
assert len(context.data) > 0
|
||||
|
||||
# Test JSON serialization
|
||||
json_output = context.model_dump_json()
|
||||
assert len(json_output) > 0
|
||||
```
|
||||
|
||||
#### Mock Client Guidelines
|
||||
- **Extend BaseClient**: All mock clients must implement the abstract `get_data()` method
|
||||
- **Realistic data**: Return data structures that match actual API responses
|
||||
- **Date consistency**: Use fixed dates that work with test date ranges
|
||||
- **Error simulation**: Create broken clients for testing error handling paths
|
||||
- **Multiple scenarios**: Provide different data for different test cases
|
||||
|
||||
### Configuration
|
||||
- **Environment Variables**: Create `.env` file with API keys (see `.env.example`)
|
||||
- **Config Class**: `TradingAgentsConfig` in `tradingagents/config.py` handles all configuration
|
||||
- **Tool Configuration**: `.mise.toml` manages Python 3.13, uv, ruff, pyright
|
||||
- **Code Quality**: `pyproject.toml` contains ruff and pyright configurations
|
||||
|
||||
#### Required Environment Variables
|
||||
|
||||
##### Core LLM APIs (Choose One)
|
||||
```bash
|
||||
# For OpenAI (default)
|
||||
export OPENAI_API_KEY="your_openai_api_key"
|
||||
|
||||
# For Anthropic Claude
|
||||
export ANTHROPIC_API_KEY="your_anthropic_api_key"
|
||||
|
||||
# For Google Gemini
|
||||
export GOOGLE_API_KEY="your_google_api_key"
|
||||
```
|
||||
|
||||
##### Data Sources (Optional)
|
||||
```bash
|
||||
# For financial data
|
||||
export FINNHUB_API_KEY="your_finnhub_api_key"
|
||||
|
||||
# For Reddit data
|
||||
export REDDIT_CLIENT_ID="your_reddit_client_id"
|
||||
export REDDIT_CLIENT_SECRET="your_reddit_client_secret"
|
||||
export REDDIT_USER_AGENT="your_app_name"
|
||||
```
|
||||
|
||||
## High-Level Architecture
|
||||
|
||||
### Multi-Agent Trading Framework
|
||||
TradingAgents implements a sophisticated multi-agent system that mirrors real-world trading firms with specialized roles and structured workflows.
|
||||
|
||||
### Core Architecture Components
|
||||
|
||||
#### 1. **Agent Teams** (Sequential Workflow)
|
||||
```
|
||||
Analyst Team → Research Team → Trading Team → Risk Management Team
|
||||
```
|
||||
|
||||
**Analyst Team** (`tradingagents/agents/analysts/`)
|
||||
- **Market Analyst**: Technical analysis using Yahoo Finance and StockStats
|
||||
- **Fundamentals Analyst**: Financial statements and company fundamentals via SimFin/Finnhub
|
||||
- **News Analyst**: News sentiment analysis and world affairs impact
|
||||
- **Social Media Analyst**: Reddit and social platform sentiment analysis
|
||||
|
||||
**Research Team** (`tradingagents/agents/researchers/`)
|
||||
- **Bull Researcher**: Advocates for investment opportunities and growth potential
|
||||
- **Bear Researcher**: Highlights risks and argues against investments
|
||||
- **Research Manager**: Synthesizes debates and creates investment recommendations
|
||||
|
||||
**Trading Team** (`tradingagents/agents/trader/`)
|
||||
- **Trader**: Converts investment plans into specific trading decisions
|
||||
|
||||
**Risk Management Team** (`tradingagents/agents/risk_mgmt/`)
|
||||
- **Aggressive/Conservative/Neutral Debators**: Different risk perspectives
|
||||
- **Risk Manager**: Final decision maker balancing risk and reward
|
||||
|
||||
#### 2. **Data Layer** (`tradingagents/services/` + `tradingagents/dataflows/`)
|
||||
**New Service-Based Architecture** (Current):
|
||||
- **Service Layer**: `MarketDataService`, `NewsService`, `SocialMediaService`, `FundamentalDataService`, `InsiderDataService`, `OpenAIDataService`
|
||||
- Orchestrate between clients and repositories with local-first data strategy
|
||||
- Provide structured JSON contexts via Pydantic models
|
||||
- Support `force_refresh=False` parameter for explicit data refreshing
|
||||
- Automatically check local cache first, fetch from APIs only when data is missing
|
||||
- **Client Layer**: Live API integrations - `YFinanceClient`, `FinnhubClient`, `RedditClient`, `GoogleNewsClient`, `SimFinClient`
|
||||
- **Repository Layer**: Smart cached data storage with gap detection - `MarketDataRepository`, `NewsRepository`, `SocialRepository`
|
||||
- **Context Models**: Pydantic models for structured JSON data - `MarketDataContext`, `NewsContext`, `SocialContext`, `FundamentalContext`
|
||||
- **Toolkit Integration**: `ServiceToolkit` provides 100% backward-compatible interface returning JSON contexts
|
||||
|
||||
**Legacy Data Integration System** (Being phased out):
|
||||
- **Yahoo Finance** (`yfin_utils.py`): Stock prices, financials, analyst recommendations
|
||||
- **Finnhub** (`finnhub_utils.py`): News, insider trading, SEC filings
|
||||
- **Reddit** (`reddit_utils.py`): Social sentiment from curated subreddits
|
||||
- **Google News** (`googlenews_utils.py`): Web-scraped news with retry logic
|
||||
- **SimFin**: Balance sheets, cash flow, income statements
|
||||
- **StockStats** (`stockstats_utils.py`): Technical indicators (MACD, RSI, etc.)
|
||||
- **Interface Layer** (`interface.py`): Standardized agent-facing APIs with markdown formatting
|
||||
|
||||
#### 3. **Graph Orchestration** (`tradingagents/graph/`)
|
||||
LangGraph-based workflow management:
|
||||
|
||||
- **TradingAgentsGraph**: Main orchestrator class
|
||||
- **State Management**: `AgentState`, `InvestDebateState`, `RiskDebateState` track workflow progress
|
||||
- **Conditional Logic**: Dynamic routing based on tool usage and debate completion
|
||||
- **Memory System**: ChromaDB-based vector memory for learning from past decisions
|
||||
|
||||
#### 4. **Configuration System**
|
||||
- **TradingAgentsConfig**: Centralized configuration with environment variable support
|
||||
- **Multi-LLM Support**: OpenAI, Anthropic, Google, Ollama, OpenRouter
|
||||
- **Data Modes**: Online (live APIs) vs offline (cached data)
|
||||
|
||||
### Key Design Patterns
|
||||
|
||||
1. **Debate-Driven Decision Making**: Critical decisions emerge from structured agent debates
|
||||
2. **Memory-Augmented Learning**: Agents learn from past similar situations using vector similarity
|
||||
3. **Local-First Data Strategy**: Automatically check local cache first, fetch from APIs only when needed
|
||||
4. **Structured JSON Contexts**: Replace error-prone string parsing with rich Pydantic models
|
||||
5. **Factory Pattern**: Agent creation via factory functions for flexible configuration
|
||||
6. **Signal Processing**: Final trading decisions processed into clean BUY/SELL/HOLD signals
|
||||
7. **Quality-Aware Data**: All contexts include quality metadata to help agents make better decisions
|
||||
|
||||
### Code Style Guidelines
|
||||
|
||||
#### General Style
|
||||
- **Functions**: Snake_case naming (e.g., `fundamentals_analyst_node`, `create_fundamentals_analyst`)
|
||||
- **Classes**: PascalCase (e.g., `TradingAgentsGraph`, `MessageBuffer`)
|
||||
- **Variables**: Snake_case (e.g., `current_date`, `company_of_interest`)
|
||||
- **Constants**: UPPER_CASE (e.g., `DEFAULT_CONFIG`)
|
||||
- **Imports**: Standard library first, third-party, then local imports (langchain, tradingagents modules)
|
||||
|
||||
#### Ruff Formatting & Linting Rules
|
||||
**Formatting** (`mise run format`):
|
||||
- **Line length**: 88 characters maximum
|
||||
- **Quote style**: Double quotes (`"string"`)
|
||||
- **Indentation**: 4 spaces (no tabs)
|
||||
- **Trailing commas**: Preserved for multi-line structures
|
||||
- **Line endings**: Auto-detected based on platform
|
||||
|
||||
**Linting** (`mise run lint`):
|
||||
- **Selected rules**:
|
||||
- `E`, `W`: pycodestyle errors and warnings
|
||||
- `F`: pyflakes (undefined names, unused imports)
|
||||
- `I`: isort (import sorting)
|
||||
- `B`: flake8-bugbear (common bugs)
|
||||
- `C4`: flake8-comprehensions (list/dict comprehensions)
|
||||
- `UP`: pyupgrade (Python syntax modernization)
|
||||
- `ARG`: flake8-unused-arguments
|
||||
- `SIM`: flake8-simplify (code simplification)
|
||||
- `TCH`: flake8-type-checking (type annotation imports)
|
||||
|
||||
- **Ignored rules**:
|
||||
- `E501`: Line too long (handled by formatter)
|
||||
- `B008`: Function calls in argument defaults (allowed for LangChain)
|
||||
- `C901`: Complex functions (legacy code tolerance)
|
||||
- `ARG001`, `ARG002`: Unused arguments (common in callbacks)
|
||||
|
||||
- **Import sorting**: `tradingagents` and `cli` treated as first-party modules
|
||||
|
||||
#### Pyright Type Checking Rules
|
||||
**Configuration** (`mise run typecheck`):
|
||||
- **Tool**: pyright 1.1.390+ with standard type checking mode
|
||||
- **Python version**: 3.10+ (configured for compatibility with modern syntax)
|
||||
- **Coverage**: Includes `tradingagents/`, `cli/`, and `main.py`
|
||||
- **Exclusions**: `__pycache__`, `node_modules`, `.venv`, `venv`, `build`, `dist`
|
||||
|
||||
**Configured Type Checking Rules**:
|
||||
- `reportMissingImports = true` - Catch undefined module imports
|
||||
- `reportMissingTypeStubs = false` - Allow libraries without type stubs
|
||||
- `reportGeneralTypeIssues = true` - General type inconsistencies
|
||||
- `reportOptionalMemberAccess = true` - Unsafe access to optional values
|
||||
- `reportOptionalCall = true` - Calling potentially None values
|
||||
- `reportOptionalIterable = true` - Iterating over potentially None values
|
||||
- `reportOptionalContextManager = true` - Using None in with statements
|
||||
- `reportOptionalOperand = true` - Operations on potentially None values
|
||||
- `reportTypedDictNotRequiredAccess = false` - Flexible TypedDict access
|
||||
- `reportPrivateImportUsage = false` - Allow importing private modules
|
||||
- `reportUnknownParameterType = false` - Allow untyped parameters
|
||||
- `reportUnknownArgumentType = false` - Allow untyped arguments
|
||||
- `reportUnknownLambdaType = false` - Allow untyped lambdas
|
||||
- `reportUnknownVariableType = false` - Allow untyped variables
|
||||
- `reportUnknownMemberType = false` - Allow untyped attributes
|
||||
|
||||
**Type Annotation Guidelines**:
|
||||
- Use modern Python 3.10+ union syntax: `str | None` instead of `Optional[str]`
|
||||
- Use built-in generics: `list[str]` instead of `List[str]`
|
||||
- Use `dict[str, Any]` for flexible dictionaries
|
||||
- Import `from typing import Any` for untyped data structures
|
||||
- Prefer explicit return types on public functions
|
||||
- Use `# type: ignore` sparingly with explanatory comments
|
||||
|
||||
### Development Guidelines
|
||||
|
||||
#### Working with Agents
|
||||
|
||||
**Current Approach** (JSON Contexts):
|
||||
- Import `ServiceToolkit` from `tradingagents.agents.utils.service_toolkit`
|
||||
- Use `context_helpers` for parsing structured JSON data (MarketDataParser, NewsParser, etc.)
|
||||
- All toolkit methods return structured JSON instead of markdown strings
|
||||
- Check data quality with `is_high_quality_data()` before analysis
|
||||
- Use `extract_latest_price()`, `extract_sentiment_score()` for quick data extraction
|
||||
|
||||
**Legacy Approach** (Being phased out):
|
||||
- Agents use the unified `Toolkit` for markdown-formatted data access
|
||||
|
||||
#### Working with Data Sources
|
||||
|
||||
**Current Service-Based Approach**:
|
||||
- **Local-First Strategy**: Services automatically check local cache first, only fetch from APIs when data is missing
|
||||
- **Force Refresh**: Use `force_refresh=True` parameter to bypass local data and get fresh API data
|
||||
- **Structured Contexts**: Services return Pydantic models (`MarketDataContext`, `NewsContext`, etc.) with rich metadata
|
||||
- **Quality Awareness**: All contexts include data_quality (HIGH/MEDIUM/LOW) and data_source information
|
||||
- **JSON Serialization**: All context models support `.model_dump_json()` for agent consumption
|
||||
- **Smart Caching**: Repositories detect data gaps and fetch only missing periods
|
||||
- **Cost Efficient**: Minimizes expensive API calls through intelligent local-first logic
|
||||
|
||||
**Service Configuration**:
|
||||
```python
|
||||
# Services use dependency injection
|
||||
service = MarketDataService(
|
||||
client=YFinanceClient(),
|
||||
repository=MarketDataRepository("cache_dir"),
|
||||
online_mode=True # Enable API fetching when local data insufficient
|
||||
)
|
||||
|
||||
# Get data with local-first strategy
|
||||
context = service.get_context("AAPL", "2024-01-01", "2024-01-31")
|
||||
|
||||
# Force fresh data when needed
|
||||
fresh_context = service.get_context("AAPL", "2024-01-01", "2024-01-31", force_refresh=True)
|
||||
```
|
||||
|
||||
**Legacy Approach** (Being phased out):
|
||||
- Interface functions return markdown-formatted strings for LLM consumption
|
||||
- All data utilities follow consistent date range patterns: `curr_date + look_back_days`
|
||||
- Check `online_tools` config flag to determine live vs cached data usage
|
||||
- Data caching happens in `data_cache_dir` for online mode
|
||||
|
||||
#### Configuration Management
|
||||
- Use `TradingAgentsConfig.from_env()` for environment-based configuration
|
||||
- Key settings: `max_debate_rounds`, `llm_provider`, `online_tools`
|
||||
- Results are saved to `results_dir/{ticker}/{date}/` with structured reports
|
||||
|
||||
#### CLI Development
|
||||
- CLI uses Rich for terminal UI with live updating displays
|
||||
- Agent progress tracking through `MessageBuffer` class
|
||||
- Questionnaire-driven configuration collection
|
||||
- Real-time streaming of analysis results
|
||||
|
||||
### File Structure Context
|
||||
- **`cli/`**: Interactive command-line interface
|
||||
- **`tradingagents/agents/`**: All agent implementations
|
||||
- **`utils/service_toolkit.py`**: New ServiceToolkit with JSON contexts (100% backward compatible)
|
||||
- **`utils/context_helpers.py`**: Helper functions for parsing structured JSON data
|
||||
- **`utils/agent_utils.py`**: Legacy Toolkit (being phased out)
|
||||
- **`tradingagents/services/`**: Service layer with local-first data strategy
|
||||
- **`tradingagents/dataflows/`**: Legacy data source integrations (being phased out)
|
||||
- **`tradingagents/graph/`**: LangGraph workflow orchestration
|
||||
- **`tradingagents/config.py`**: Configuration management
|
||||
- **`main.py`**: Direct Python usage example
|
||||
- **`AGENTS.md`**: Detailed agent documentation
|
||||
- **`examples/agent_json_migration.py`**: Migration guide from markdown to JSON contexts
|
||||
358
README.md
358
README.md
|
|
@ -192,6 +192,364 @@ print(decision)
|
|||
|
||||
You can view the full list of configurations in `tradingagents/default_config.py`.
|
||||
|
||||
## Development Guide
|
||||
|
||||
This section provides comprehensive development guidance for contributors working on the TradingAgents codebase.
|
||||
|
||||
### Common Development Commands
|
||||
|
||||
This project uses [mise](https://mise.jdx.dev/) for tool and task management. All development tasks are managed through mise.
|
||||
|
||||
#### Initial Setup
|
||||
- **First-time setup**: `mise run setup` - Install tools and dependencies
|
||||
- **Install tools only**: `mise install` - Install Python, uv, ruff, pyright
|
||||
- **Install dependencies**: `mise run install` - Install project dependencies with uv
|
||||
|
||||
#### Development Workflow
|
||||
- **CLI Application**: `mise run dev` - Interactive CLI for running trading analysis
|
||||
- **Direct Python Usage**: `mise run run` - Run main.py programmatically
|
||||
- **Format code**: `mise run format` - Auto-format with ruff
|
||||
- **Lint code**: `mise run lint` - Check code quality with ruff
|
||||
- **Type checking**: `mise run typecheck` - Run pyright type checker
|
||||
- **Fix lint issues**: `mise run fix` - Auto-fix linting issues
|
||||
- **Run all checks**: `mise run all` - Format, lint, and typecheck
|
||||
- **Clean artifacts**: `mise run clean` - Remove cache and build files
|
||||
|
||||
#### Testing
|
||||
|
||||
##### Running Tests
|
||||
- **Run all tests**: `mise run test` - Run tests with pytest
|
||||
- **Run specific test file**: `uv run pytest test_social_media_service.py` - Run individual test file
|
||||
- **Verbose output**: `uv run pytest -v` - Run tests with detailed output
|
||||
- **Run with output**: `uv run pytest -s` - Show print statements and debug output
|
||||
- **Test coverage**: `uv run pytest --cov=tradingagents` - Run tests with coverage report
|
||||
|
||||
##### Test Development (TDD Approach)
|
||||
This project follows **Test-Driven Development (TDD)** for service layer development:
|
||||
|
||||
1. **Write test first**: Create `{component}_service_test.py` with comprehensive test cases
|
||||
2. **Run test (should fail)**: Verify test fails with appropriate error messages
|
||||
3. **Implement minimum code**: Write just enough code to make the test pass
|
||||
4. **Refactor**: Improve code while keeping tests passing
|
||||
5. **Repeat**: Add more test cases and implement additional functionality
|
||||
|
||||
##### Test Structure and Conventions
|
||||
- **Test files**: Named `{component}_service_test.py` and placed next to source code (not in separate tests/ directory)
|
||||
- **Test functions**: Named `test_{functionality}()` and should not return values (use `assert` statements)
|
||||
- **Mock clients**: Use `unittest.mock.Mock()` objects for testing services
|
||||
- **Real repositories**: Use actual repository implementations (don't mock the repository layer)
|
||||
- **Test data**: Use realistic mock data that matches expected API responses
|
||||
- **Date handling**: Use fixed dates (e.g., `datetime(2024, 1, 2)`) in mocks for predictable filtering
|
||||
|
||||
##### Service Testing Pattern
|
||||
Example test structure for services:
|
||||
```python
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
def test_online_mode_with_mock_client():
|
||||
"""Test service in online mode with mock client."""
|
||||
# Mock the client
|
||||
mock_client = Mock()
|
||||
mock_client.get_data.return_value = {"data": [{"symbol": "TEST", "price": 100.0}]}
|
||||
|
||||
real_repo = ServiceRepository("test_data")
|
||||
|
||||
service = ServiceClass(
|
||||
client=mock_client,
|
||||
repository=real_repo,
|
||||
online_mode=True
|
||||
)
|
||||
|
||||
context = service.get_context("TEST", "2024-01-01", "2024-01-05")
|
||||
|
||||
# Validate structure
|
||||
assert isinstance(context, ContextModel)
|
||||
assert context.symbol == "TEST"
|
||||
assert len(context.data) > 0
|
||||
|
||||
# Test JSON serialization
|
||||
json_output = context.model_dump_json()
|
||||
assert len(json_output) > 0
|
||||
|
||||
# Verify client was called
|
||||
mock_client.get_data.assert_called_once()
|
||||
```
|
||||
|
||||
##### Mock Client Guidelines
|
||||
- **Use unittest.mock**: Use `Mock()` objects instead of custom mock classes
|
||||
- **Realistic data**: Return data structures that match actual API responses
|
||||
- **Date consistency**: Use fixed dates that work with test date ranges
|
||||
- **Error simulation**: Configure mocks to raise exceptions for testing error handling paths
|
||||
- **Multiple scenarios**: Use different return values for different test cases
|
||||
|
||||
### Configuration
|
||||
- **Environment Variables**: Create `.env` file with API keys (see `.env.example`)
|
||||
- **Config Class**: `TradingAgentsConfig` in `tradingagents/config.py` handles all configuration
|
||||
- **Tool Configuration**: `.mise.toml` manages Python 3.13, uv, ruff, pyright
|
||||
- **Code Quality**: `pyproject.toml` contains ruff and pyright configurations
|
||||
|
||||
#### Required Environment Variables
|
||||
|
||||
##### Core LLM APIs (Choose One)
|
||||
```bash
|
||||
# For OpenAI (default)
|
||||
export OPENAI_API_KEY="your_openai_api_key"
|
||||
|
||||
# For Anthropic Claude
|
||||
export ANTHROPIC_API_KEY="your_anthropic_api_key"
|
||||
|
||||
# For Google Gemini
|
||||
export GOOGLE_API_KEY="your_google_api_key"
|
||||
```
|
||||
|
||||
##### Data Sources (Optional)
|
||||
```bash
|
||||
# For financial data
|
||||
export FINNHUB_API_KEY="your_finnhub_api_key"
|
||||
|
||||
# For Reddit data
|
||||
export REDDIT_CLIENT_ID="your_reddit_client_id"
|
||||
export REDDIT_CLIENT_SECRET="your_reddit_client_secret"
|
||||
export REDDIT_USER_AGENT="your_app_name"
|
||||
```
|
||||
|
||||
## Architecture Deep Dive
|
||||
|
||||
### Multi-Agent Trading Framework
|
||||
TradingAgents implements a sophisticated multi-agent system that mirrors real-world trading firms with specialized roles and structured workflows.
|
||||
|
||||
### Core Architecture Components
|
||||
|
||||
#### 1. **Agent Teams** (Sequential Workflow)
|
||||
```
|
||||
Analyst Team → Research Team → Trading Team → Risk Management Team
|
||||
```
|
||||
|
||||
**Analyst Team** (`tradingagents/agents/analysts/`)
|
||||
- **Market Analyst**: Technical analysis using Yahoo Finance and StockStats
|
||||
- **Fundamentals Analyst**: Financial statements and company fundamentals via SimFin/Finnhub
|
||||
- **News Analyst**: News sentiment analysis and world affairs impact
|
||||
- **Social Media Analyst**: Reddit and social platform sentiment analysis
|
||||
|
||||
**Research Team** (`tradingagents/agents/researchers/`)
|
||||
- **Bull Researcher**: Advocates for investment opportunities and growth potential
|
||||
- **Bear Researcher**: Highlights risks and argues against investments
|
||||
- **Research Manager**: Synthesizes debates and creates investment recommendations
|
||||
|
||||
**Trading Team** (`tradingagents/agents/trader/`)
|
||||
- **Trader**: Converts investment plans into specific trading decisions
|
||||
|
||||
**Risk Management Team** (`tradingagents/agents/risk_mgmt/`)
|
||||
- **Aggressive/Conservative/Neutral Debators**: Different risk perspectives
|
||||
- **Risk Manager**: Final decision maker balancing risk and reward
|
||||
|
||||
#### 2. **Domain-Driven Architecture** (`tradingagents/domains/`)
|
||||
**Domain-Driven Design (DDD) Architecture** (Current):
|
||||
The system has been restructured using Domain-Driven Design principles with three main bounded contexts:
|
||||
|
||||
**Domain Boundaries & Bounded Contexts:**
|
||||
- **Financial Data Domain** (`tradingagents/domains/marketdata/`): Market prices, technical indicators, fundamentals, insider data
|
||||
- **News Domain** (`tradingagents/domains/news/`): News articles, sentiment analysis, content aggregation
|
||||
- **Social Media Domain** (`tradingagents/domains/socialmedia/`): Social media posts, engagement metrics, sentiment analysis
|
||||
|
||||
**DDD Tactical Patterns per Domain:**
|
||||
- **Domain Services**: Business logic encapsulated in domain-specific services (`MarketDataService`, `NewsService`, `SocialMediaService`)
|
||||
- **Value Objects**: Immutable data structures (`SentimentScore`, `TechnicalIndicatorData`, `PostMetadata`)
|
||||
- **Entities**: Objects with identity and lifecycle (`NewsArticle`, `PostData`)
|
||||
- **Repository Pattern**: Domain-specific data access with smart caching, deduplication, and gap detection
|
||||
- **Context Objects**: Structured domain data containers (`MarketDataContext`, `NewsContext`, `SocialContext`)
|
||||
|
||||
**Domain Infrastructure per Bounded Context:**
|
||||
```
|
||||
marketdata/
|
||||
├── clients/ # YFinanceClient, FinnhubClient (domain-specific)
|
||||
├── repos/ # MarketDataRepository, FundamentalRepository
|
||||
├── services/ # MarketDataService, FundamentalDataService, InsiderDataService
|
||||
└── models/ # Domain Value Objects and Entities
|
||||
|
||||
news/
|
||||
├── clients/ # GoogleNewsClient (domain-specific)
|
||||
├── repositories/ # NewsRepository with article deduplication
|
||||
├── services/ # NewsService with sentiment analysis
|
||||
└── models/ # NewsArticle, SentimentScore
|
||||
|
||||
socialmedia/
|
||||
├── clients/ # RedditClient (domain-specific)
|
||||
├── repositories/ # SocialMediaRepository with engagement tracking
|
||||
├── services/ # SocialMediaService with sentiment analysis
|
||||
└── models/ # PostData, EngagementMetrics
|
||||
```
|
||||
|
||||
**Agent Integration Strategy - Anti-Corruption Layer (ACL):**
|
||||
- **AgentToolkit as ACL**: Mediates between agents (string-based, procedural) and domains (object-oriented, rich models)
|
||||
- **Data Translation**: Converts rich Pydantic domain models to structured JSON strings for LLM consumption
|
||||
- **Parameter Adaptation**: Handles interface mismatches (single date → date ranges, etc.)
|
||||
- **Backward Compatibility**: Preserves existing agent tool interface while providing domain service benefits
|
||||
|
||||
#### 3. **Graph Orchestration** (`tradingagents/graph/`)
|
||||
LangGraph-based workflow management:
|
||||
|
||||
- **TradingAgentsGraph**: Main orchestrator class
|
||||
- **State Management**: `AgentState`, `InvestDebateState`, `RiskDebateState` track workflow progress
|
||||
- **Conditional Logic**: Dynamic routing based on tool usage and debate completion
|
||||
- **Memory System**: ChromaDB-based vector memory for learning from past decisions
|
||||
|
||||
#### 4. **Configuration System**
|
||||
- **TradingAgentsConfig**: Centralized configuration with environment variable support
|
||||
- **Multi-LLM Support**: OpenAI, Anthropic, Google, Ollama, OpenRouter
|
||||
- **Data Modes**: Online (live APIs) vs offline (cached data)
|
||||
|
||||
### Key Design Patterns
|
||||
|
||||
1. **Debate-Driven Decision Making**: Critical decisions emerge from structured agent debates
|
||||
2. **Memory-Augmented Learning**: Agents learn from past similar situations using vector similarity
|
||||
3. **Repository-First Data Strategy**: Services always read from repositories with separate update operations
|
||||
4. **Structured JSON Contexts**: Replace error-prone string parsing with rich Pydantic models
|
||||
5. **Factory Pattern**: Agent creation via factory functions for flexible configuration
|
||||
6. **Signal Processing**: Final trading decisions processed into clean BUY/SELL/HOLD signals
|
||||
7. **Quality-Aware Data**: All contexts include quality metadata to help agents make better decisions
|
||||
|
||||
### Code Style Guidelines
|
||||
|
||||
#### General Style
|
||||
- **Functions**: Snake_case naming (e.g., `fundamentals_analyst_node`, `create_fundamentals_analyst`)
|
||||
- **Classes**: PascalCase (e.g., `TradingAgentsGraph`, `MessageBuffer`)
|
||||
- **Variables**: Snake_case (e.g., `current_date`, `company_of_interest`)
|
||||
- **Constants**: UPPER_CASE (e.g., `DEFAULT_CONFIG`)
|
||||
- **Imports**: Standard library first, third-party, then local imports (langchain, tradingagents modules)
|
||||
|
||||
#### Ruff Formatting & Linting Rules
|
||||
**Formatting** (`mise run format`):
|
||||
- **Line length**: 88 characters maximum
|
||||
- **Quote style**: Double quotes (`"string"`)
|
||||
- **Indentation**: 4 spaces (no tabs)
|
||||
- **Trailing commas**: Preserved for multi-line structures
|
||||
- **Line endings**: Auto-detected based on platform
|
||||
|
||||
**Linting** (`mise run lint`):
|
||||
- **Selected rules**:
|
||||
- `E`, `W`: pycodestyle errors and warnings
|
||||
- `F`: pyflakes (undefined names, unused imports)
|
||||
- `I`: isort (import sorting)
|
||||
- `B`: flake8-bugbear (common bugs)
|
||||
- `C4`: flake8-comprehensions (list/dict comprehensions)
|
||||
- `UP`: pyupgrade (Python syntax modernization)
|
||||
- `ARG`: flake8-unused-arguments
|
||||
- `SIM`: flake8-simplify (code simplification)
|
||||
- `TCH`: flake8-type-checking (type annotation imports)
|
||||
|
||||
- **Ignored rules**:
|
||||
- `E501`: Line too long (handled by formatter)
|
||||
- `B008`: Function calls in argument defaults (allowed for LangChain)
|
||||
- `C901`: Complex functions (legacy code tolerance)
|
||||
- `ARG001`, `ARG002`: Unused arguments (common in callbacks)
|
||||
|
||||
- **Import sorting**: `tradingagents` and `cli` treated as first-party modules
|
||||
|
||||
#### Pyright Type Checking Rules
|
||||
**Configuration** (`mise run typecheck`):
|
||||
- **Tool**: pyright 1.1.390+ with standard type checking mode
|
||||
- **Python version**: 3.10+ (configured for compatibility with modern syntax)
|
||||
- **Coverage**: Includes `tradingagents/`, `cli/`, and `main.py`
|
||||
- **Exclusions**: `__pycache__`, `node_modules`, `.venv`, `venv`, `build`, `dist`
|
||||
|
||||
**Type Annotation Guidelines**:
|
||||
- Use modern Python 3.10+ union syntax: `str | None` instead of `Optional[str]`
|
||||
- Use built-in generics: `list[str]` instead of `List[str]`
|
||||
- Use `dict[str, Any]` for flexible dictionaries
|
||||
- Import `from typing import Any` for untyped data structures
|
||||
- Prefer explicit return types on public functions
|
||||
- Use `# type: ignore` sparingly with explanatory comments
|
||||
|
||||
### Development Guidelines
|
||||
|
||||
#### Working with Agents
|
||||
|
||||
**Current Approach** (AgentToolkit as Anti-Corruption Layer):
|
||||
- Use `AgentToolkit` from `tradingagents.agents.libs.agent_toolkit`
|
||||
- Toolkit injects all domain services via dependency injection
|
||||
- Provides LangChain `@tool` decorated methods for agent consumption
|
||||
- Returns rich Pydantic domain models directly to agents
|
||||
- Handles parameter validation, date calculations, and error handling
|
||||
|
||||
**Agent Integration Pattern**:
|
||||
```python
|
||||
from tradingagents.agents.libs.agent_toolkit import AgentToolkit
|
||||
|
||||
# AgentToolkit acts as Anti-Corruption Layer
|
||||
toolkit = AgentToolkit(
|
||||
news_service=news_service,
|
||||
marketdata_service=marketdata_service,
|
||||
fundamentaldata_service=fundamentaldata_service,
|
||||
socialmedia_service=socialmedia_service,
|
||||
insiderdata_service=insiderdata_service
|
||||
)
|
||||
|
||||
# Agents use toolkit tools that return rich domain contexts
|
||||
@tool
|
||||
def analyze_stock(symbol: str, date: str):
|
||||
# Get structured contexts from domain services via toolkit
|
||||
market_data = toolkit.get_market_data(symbol, start_date, end_date)
|
||||
social_data = toolkit.get_socialmedia_stock_info(symbol, date)
|
||||
news_data = toolkit.get_news(symbol, start_date, end_date)
|
||||
|
||||
# Work with rich Pydantic models
|
||||
price = market_data.latest_price
|
||||
sentiment = social_data.sentiment_summary.score
|
||||
article_count = news_data.article_count
|
||||
```
|
||||
|
||||
#### Working with Data Sources
|
||||
|
||||
**Current Domain Service Approach**:
|
||||
- **Repository-First**: Services always read data from repositories (local storage)
|
||||
- **Separate Update Operations**: Use dedicated update methods to fetch fresh data from APIs and store in repositories
|
||||
- **Clear Separation**: Reading data vs updating data are separate concerns
|
||||
- **Structured Contexts**: Services return rich Pydantic models with metadata
|
||||
- **Quality Awareness**: All contexts include data quality and source information
|
||||
|
||||
**Service Usage Pattern**:
|
||||
```python
|
||||
# Services use dependency injection
|
||||
service = MarketDataService(
|
||||
yfin_client=YFinanceClient(),
|
||||
repo=MarketDataRepository("cache_dir")
|
||||
)
|
||||
|
||||
# Always read from repository
|
||||
context = service.get_market_data_context("AAPL", "2024-01-01", "2024-01-31")
|
||||
|
||||
# Separate update operation to refresh repository data
|
||||
service.update_market_data("AAPL", "2024-01-01", "2024-01-31")
|
||||
```
|
||||
|
||||
#### Configuration Management
|
||||
- Use `TradingAgentsConfig.from_env()` for environment-based configuration
|
||||
- Key settings: `max_debate_rounds`, `llm_provider`, `online_tools`
|
||||
- Results are saved to `results_dir/{ticker}/{date}/` with structured reports
|
||||
|
||||
#### CLI Development
|
||||
- CLI uses Rich for terminal UI with live updating displays
|
||||
- Agent progress tracking through `MessageBuffer` class
|
||||
- Questionnaire-driven configuration collection
|
||||
- Real-time streaming of analysis results
|
||||
|
||||
### File Structure Context
|
||||
- **`cli/`**: Interactive command-line interface
|
||||
- **`tradingagents/agents/`**: All agent implementations
|
||||
- **`libs/agent_toolkit.py`**: AgentToolkit Anti-Corruption Layer with LangChain @tool decorators
|
||||
- **`libs/context_helpers.py`**: Helper functions for parsing structured JSON data
|
||||
- **`libs/agent_utils.py`**: Legacy Toolkit (being phased out)
|
||||
- **`tradingagents/domains/`**: Domain-Driven Design bounded contexts
|
||||
- **`marketdata/`**: Financial data domain (prices, indicators, fundamentals, insider data)
|
||||
- **`news/`**: News domain (articles, sentiment analysis)
|
||||
- **`socialmedia/`**: Social media domain (posts, engagement, sentiment)
|
||||
- **`tradingagents/dataflows/`**: Legacy data source integrations (being phased out)
|
||||
- **`tradingagents/graph/`**: LangGraph workflow orchestration
|
||||
- **`tradingagents/config.py`**: Configuration management
|
||||
- **`main.py`**: Direct Python usage example
|
||||
- **`AGENTS.md`**: Detailed agent documentation
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
|
||||
|
|
|
|||
|
|
@ -0,0 +1,585 @@
|
|||
{
|
||||
"russell1000_by_sector": {
|
||||
"summary": {
|
||||
"total_companies": 1000,
|
||||
"market_cap_representation": "93% of investable US equity market",
|
||||
"weighted_average_market_cap": "$1.013 trillion",
|
||||
"median_market_cap": "$15.7 billion",
|
||||
"minimum_market_cap": "$2.4 billion",
|
||||
"last_updated": "July 2025",
|
||||
"reconstitution_date": "June 27, 2025",
|
||||
"index_provider": "FTSE Russell (London Stock Exchange Group)"
|
||||
},
|
||||
"sector_weights": {
|
||||
"Information Technology": 30.52,
|
||||
"Financials": 14.31,
|
||||
"Health Care": 10.77,
|
||||
"Consumer Discretionary": 10.5,
|
||||
"Communication Services": 9.28,
|
||||
"Industrials": 8.42,
|
||||
"Consumer Staples": 5.94,
|
||||
"Energy": 3.24,
|
||||
"Real Estate": 2.57,
|
||||
"Utilities": 2.44,
|
||||
"Materials": 2.02
|
||||
},
|
||||
"sectors": {
|
||||
"Information Technology": {
|
||||
"weight": 30.52,
|
||||
"companies": [
|
||||
{ "ticker": "AAPL", "name": "Apple Inc." },
|
||||
{ "ticker": "MSFT", "name": "Microsoft Corporation" },
|
||||
{ "ticker": "NVDA", "name": "NVIDIA Corporation" },
|
||||
{ "ticker": "AVGO", "name": "Broadcom Inc." },
|
||||
{ "ticker": "ORCL", "name": "Oracle Corporation" },
|
||||
{ "ticker": "CRM", "name": "Salesforce, Inc." },
|
||||
{ "ticker": "ADBE", "name": "Adobe Inc." },
|
||||
{ "ticker": "CSCO", "name": "Cisco Systems, Inc." },
|
||||
{ "ticker": "ACN", "name": "Accenture plc" },
|
||||
{ "ticker": "NOW", "name": "ServiceNow, Inc." },
|
||||
{ "ticker": "INTU", "name": "Intuit Inc." },
|
||||
{
|
||||
"ticker": "IBM",
|
||||
"name": "International Business Machines Corporation"
|
||||
},
|
||||
{ "ticker": "AMD", "name": "Advanced Micro Devices, Inc." },
|
||||
{ "ticker": "QCOM", "name": "Qualcomm Incorporated" },
|
||||
{ "ticker": "TXN", "name": "Texas Instruments Incorporated" },
|
||||
{ "ticker": "AMAT", "name": "Applied Materials, Inc." },
|
||||
{ "ticker": "PLTR", "name": "Palantir Technologies Inc." },
|
||||
{ "ticker": "PANW", "name": "Palo Alto Networks, Inc." },
|
||||
{ "ticker": "ADI", "name": "Analog Devices, Inc." },
|
||||
{ "ticker": "CRWD", "name": "CrowdStrike Holdings, Inc." },
|
||||
{ "ticker": "DELL", "name": "Dell Technologies Inc." },
|
||||
{ "ticker": "HPQ", "name": "HP Inc." },
|
||||
{ "ticker": "ANET", "name": "Arista Networks, Inc." },
|
||||
{ "ticker": "SNOW", "name": "Snowflake Inc." },
|
||||
{ "ticker": "DDOG", "name": "Datadog, Inc." },
|
||||
{ "ticker": "TEAM", "name": "Atlassian Corporation" },
|
||||
{ "ticker": "ZS", "name": "Zscaler, Inc." },
|
||||
{ "ticker": "MRVL", "name": "Marvell Technology, Inc." },
|
||||
{
|
||||
"ticker": "CTSH",
|
||||
"name": "Cognizant Technology Solutions Corporation"
|
||||
},
|
||||
{ "ticker": "OKTA", "name": "Okta, Inc." },
|
||||
{ "ticker": "GLW", "name": "Corning Incorporated" },
|
||||
{ "ticker": "KEYS", "name": "Keysight Technologies, Inc." },
|
||||
{ "ticker": "TEL", "name": "TE Connectivity Ltd." },
|
||||
{ "ticker": "JNPR", "name": "Juniper Networks, Inc." },
|
||||
{ "ticker": "SMCI", "name": "Super Micro Computer, Inc." },
|
||||
{ "ticker": "NTAP", "name": "NetApp, Inc." },
|
||||
{ "ticker": "STX", "name": "Seagate Technology Holdings plc" },
|
||||
{ "ticker": "WDC", "name": "Western Digital Corporation" },
|
||||
{ "ticker": "EPAM", "name": "EPAM Systems, Inc." },
|
||||
{ "ticker": "FLEX", "name": "Flex Ltd." },
|
||||
{ "ticker": "DXC", "name": "DXC Technology Company" }
|
||||
]
|
||||
},
|
||||
"Financials": {
|
||||
"weight": 14.31,
|
||||
"companies": [
|
||||
{ "ticker": "BRK.B", "name": "Berkshire Hathaway Inc." },
|
||||
{ "ticker": "JPM", "name": "JPMorgan Chase & Co." },
|
||||
{ "ticker": "V", "name": "Visa Inc." },
|
||||
{ "ticker": "MA", "name": "Mastercard Incorporated" },
|
||||
{ "ticker": "BAC", "name": "Bank of America Corporation" },
|
||||
{ "ticker": "WFC", "name": "Wells Fargo & Company" },
|
||||
{ "ticker": "GS", "name": "The Goldman Sachs Group, Inc." },
|
||||
{ "ticker": "MS", "name": "Morgan Stanley" },
|
||||
{ "ticker": "SPGI", "name": "S&P Global Inc." },
|
||||
{ "ticker": "AXP", "name": "American Express Company" },
|
||||
{ "ticker": "BLK", "name": "BlackRock, Inc." },
|
||||
{ "ticker": "C", "name": "Citigroup Inc." },
|
||||
{ "ticker": "SCHW", "name": "The Charles Schwab Corporation" },
|
||||
{ "ticker": "CB", "name": "Chubb Limited" },
|
||||
{ "ticker": "PGR", "name": "The Progressive Corporation" },
|
||||
{ "ticker": "BX", "name": "Blackstone Inc." },
|
||||
{ "ticker": "ICE", "name": "Intercontinental Exchange, Inc." },
|
||||
{ "ticker": "CME", "name": "CME Group Inc." },
|
||||
{ "ticker": "PNC", "name": "The PNC Financial Services Group, Inc." },
|
||||
{ "ticker": "USB", "name": "U.S. Bancorp" },
|
||||
{ "ticker": "MCO", "name": "Moody's Corporation" },
|
||||
{ "ticker": "TFC", "name": "Truist Financial Corporation" },
|
||||
{ "ticker": "COF", "name": "Capital One Financial Corporation" },
|
||||
{ "ticker": "AIG", "name": "American International Group, Inc." },
|
||||
{ "ticker": "MET", "name": "MetLife, Inc." },
|
||||
{ "ticker": "PRU", "name": "Prudential Financial, Inc." },
|
||||
{ "ticker": "TRV", "name": "The Travelers Companies, Inc." },
|
||||
{ "ticker": "AFL", "name": "Aflac Incorporated" },
|
||||
{ "ticker": "ALL", "name": "The Allstate Corporation" },
|
||||
{ "ticker": "FI", "name": "Fiserv, Inc." },
|
||||
{ "ticker": "NDAQ", "name": "Nasdaq, Inc." },
|
||||
{ "ticker": "PYPL", "name": "PayPal Holdings, Inc." },
|
||||
{ "ticker": "KKR", "name": "KKR & Co. Inc." },
|
||||
{ "ticker": "MMC", "name": "Marsh & McLennan Companies, Inc." },
|
||||
{ "ticker": "APO", "name": "Apollo Global Management, Inc." },
|
||||
{ "ticker": "MSCI", "name": "MSCI Inc." },
|
||||
{ "ticker": "AON", "name": "Aon plc" },
|
||||
{
|
||||
"ticker": "FIS",
|
||||
"name": "Fidelity National Information Services, Inc."
|
||||
},
|
||||
{ "ticker": "SYF", "name": "Synchrony Financial" },
|
||||
{ "ticker": "AMP", "name": "Ameriprise Financial, Inc." },
|
||||
{ "ticker": "MTB", "name": "M&T Bank Corporation" },
|
||||
{ "ticker": "FITB", "name": "Fifth Third Bancorp" },
|
||||
{ "ticker": "HBAN", "name": "Huntington Bancshares Incorporated" },
|
||||
{ "ticker": "STT", "name": "State Street Corporation" },
|
||||
{ "ticker": "RF", "name": "Regions Financial Corporation" },
|
||||
{ "ticker": "NTRS", "name": "Northern Trust Corporation" },
|
||||
{ "ticker": "CFG", "name": "Citizens Financial Group, Inc." },
|
||||
{ "ticker": "CINF", "name": "Cincinnati Financial Corporation" },
|
||||
{ "ticker": "KEY", "name": "KeyCorp" },
|
||||
{ "ticker": "WRB", "name": "W. R. Berkley Corporation" },
|
||||
{ "ticker": "L", "name": "Loews Corporation" },
|
||||
{ "ticker": "ACGL", "name": "Arch Capital Group Ltd." },
|
||||
{ "ticker": "TROW", "name": "T. Rowe Price Group, Inc." },
|
||||
{ "ticker": "RJF", "name": "Raymond James Financial, Inc." },
|
||||
{ "ticker": "FDS", "name": "FactSet Research Systems Inc." },
|
||||
{ "ticker": "CBOE", "name": "Cboe Global Markets, Inc." },
|
||||
{ "ticker": "COIN", "name": "Coinbase Global, Inc." },
|
||||
{ "ticker": "CPAY", "name": "Corpay, Inc." },
|
||||
{ "ticker": "GPN", "name": "Global Payments Inc." },
|
||||
{ "ticker": "AJG", "name": "Arthur J. Gallagher & Co." },
|
||||
{
|
||||
"ticker": "WTW",
|
||||
"name": "Willis Towers Watson Public Limited Company"
|
||||
},
|
||||
{ "ticker": "BRO", "name": "Brown & Brown, Inc." },
|
||||
{ "ticker": "GL", "name": "Globe Life Inc." },
|
||||
{ "ticker": "EG", "name": "Everest Group, Ltd." },
|
||||
{ "ticker": "PFG", "name": "Principal Financial Group, Inc." },
|
||||
{ "ticker": "HIG", "name": "The Hartford Insurance Group, Inc." },
|
||||
{ "ticker": "BEN", "name": "Franklin Resources, Inc." },
|
||||
{ "ticker": "IVZ", "name": "Invesco Ltd." },
|
||||
{ "ticker": "AIZ", "name": "Assurant, Inc." },
|
||||
{ "ticker": "ERIE", "name": "Erie Indemnity Company" },
|
||||
{ "ticker": "JKHY", "name": "Jack Henry & Associates, Inc." },
|
||||
{ "ticker": "MKTX", "name": "MarketAxess Holdings Inc." }
|
||||
]
|
||||
},
|
||||
"Health Care": {
|
||||
"weight": 10.77,
|
||||
"companies": [
|
||||
{ "ticker": "LLY", "name": "Eli Lilly and Company" },
|
||||
{ "ticker": "UNH", "name": "UnitedHealth Group Incorporated" },
|
||||
{ "ticker": "JNJ", "name": "Johnson & Johnson" },
|
||||
{ "ticker": "ABBV", "name": "AbbVie Inc." },
|
||||
{ "ticker": "MRK", "name": "Merck & Co., Inc." },
|
||||
{ "ticker": "TMO", "name": "Thermo Fisher Scientific Inc." },
|
||||
{ "ticker": "ABT", "name": "Abbott Laboratories" },
|
||||
{ "ticker": "PFE", "name": "Pfizer Inc." },
|
||||
{ "ticker": "CVS", "name": "CVS Health Corporation" },
|
||||
{ "ticker": "ELV", "name": "Elevance Health, Inc." },
|
||||
{ "ticker": "AMGN", "name": "Amgen Inc." },
|
||||
{ "ticker": "DHR", "name": "Danaher Corporation" },
|
||||
{ "ticker": "ISRG", "name": "Intuitive Surgical, Inc." },
|
||||
{ "ticker": "BSX", "name": "Boston Scientific Corporation" },
|
||||
{ "ticker": "VRTX", "name": "Vertex Pharmaceuticals Incorporated" },
|
||||
{ "ticker": "SYK", "name": "Stryker Corporation" },
|
||||
{ "ticker": "GILD", "name": "Gilead Sciences, Inc." },
|
||||
{ "ticker": "MDT", "name": "Medtronic plc" },
|
||||
{ "ticker": "CI", "name": "Cigna Group" },
|
||||
{ "ticker": "REGN", "name": "Regeneron Pharmaceuticals, Inc." },
|
||||
{ "ticker": "BMY", "name": "Bristol-Myers Squibb Company" },
|
||||
{ "ticker": "MCK", "name": "McKesson Corporation" },
|
||||
{ "ticker": "ZTS", "name": "Zoetis Inc." },
|
||||
{ "ticker": "HCA", "name": "HCA Healthcare, Inc." },
|
||||
{ "ticker": "BDX", "name": "Becton, Dickinson and Company" },
|
||||
{ "ticker": "CNC", "name": "Centene Corporation" },
|
||||
{ "ticker": "HUM", "name": "Humana Inc." },
|
||||
{ "ticker": "EW", "name": "Edwards Lifesciences Corporation" },
|
||||
{ "ticker": "CAH", "name": "Cardinal Health, Inc." },
|
||||
{ "ticker": "BIIB", "name": "Biogen Inc." },
|
||||
{ "ticker": "A", "name": "Agilent Technologies, Inc." },
|
||||
{ "ticker": "IQV", "name": "IQVIA Holdings Inc." },
|
||||
{ "ticker": "DXCM", "name": "DexCom, Inc." },
|
||||
{ "ticker": "RMD", "name": "ResMed Inc." },
|
||||
{ "ticker": "IDXX", "name": "IDEXX Laboratories, Inc." },
|
||||
{ "ticker": "MTD", "name": "Mettler-Toledo International Inc." },
|
||||
{ "ticker": "WST", "name": "West Pharmaceutical Services, Inc." },
|
||||
{ "ticker": "MOH", "name": "Molina Healthcare, Inc." },
|
||||
{ "ticker": "LH", "name": "LabCorp" },
|
||||
{ "ticker": "PODD", "name": "Insulet Corporation" },
|
||||
{ "ticker": "STE", "name": "STERIS plc" },
|
||||
{ "ticker": "DGX", "name": "Quest Diagnostics Incorporated" },
|
||||
{ "ticker": "WAT", "name": "Waters Corporation" },
|
||||
{ "ticker": "ALGN", "name": "Align Technology, Inc." },
|
||||
{ "ticker": "ZBH", "name": "Zimmer Biomet Holdings, Inc." },
|
||||
{ "ticker": "COO", "name": "The Cooper Companies, Inc." },
|
||||
{ "ticker": "BAX", "name": "Baxter International Inc." },
|
||||
{ "ticker": "HOLX", "name": "Hologic, Inc." },
|
||||
{ "ticker": "GEHC", "name": "GE HealthCare Technologies Inc." },
|
||||
{ "ticker": "COR", "name": "Cencora, Inc." }
|
||||
]
|
||||
},
|
||||
"Consumer Discretionary": {
|
||||
"weight": 10.5,
|
||||
"companies": [
|
||||
{ "ticker": "AMZN", "name": "Amazon.com, Inc." },
|
||||
{ "ticker": "TSLA", "name": "Tesla, Inc." },
|
||||
{ "ticker": "HD", "name": "The Home Depot, Inc." },
|
||||
{ "ticker": "MCD", "name": "McDonald's Corporation" },
|
||||
{ "ticker": "BKNG", "name": "Booking Holdings Inc." },
|
||||
{ "ticker": "NKE", "name": "NIKE, Inc." },
|
||||
{ "ticker": "LOW", "name": "Lowe's Companies, Inc." },
|
||||
{ "ticker": "SBUX", "name": "Starbucks Corporation" },
|
||||
{ "ticker": "TJX", "name": "The TJX Companies, Inc." },
|
||||
{ "ticker": "ORLY", "name": "O'Reilly Automotive, Inc." },
|
||||
{ "ticker": "ABNB", "name": "Airbnb Inc." },
|
||||
{ "ticker": "CMG", "name": "Chipotle Mexican Grill" },
|
||||
{ "ticker": "HLT", "name": "Hilton Worldwide Holdings Inc." },
|
||||
{ "ticker": "AZO", "name": "AutoZone, Inc." },
|
||||
{ "ticker": "RCL", "name": "Royal Caribbean Cruises Ltd." },
|
||||
{ "ticker": "MAR", "name": "Marriott International, Inc." },
|
||||
{ "ticker": "GM", "name": "General Motors Company" },
|
||||
{ "ticker": "DASH", "name": "DoorDash, Inc." },
|
||||
{ "ticker": "F", "name": "Ford Motor Company" },
|
||||
{ "ticker": "ROST", "name": "Ross Stores, Inc." },
|
||||
{ "ticker": "DHI", "name": "D.R. Horton, Inc." },
|
||||
{ "ticker": "YUM", "name": "Yum! Brands, Inc." },
|
||||
{ "ticker": "LULU", "name": "Lululemon Athletica Inc." },
|
||||
{ "ticker": "LEN", "name": "Lennar Corporation" },
|
||||
{ "ticker": "GRMN", "name": "Garmin Ltd." },
|
||||
{ "ticker": "NVR", "name": "NVR, Inc." },
|
||||
{ "ticker": "EBAY", "name": "eBay Inc." },
|
||||
{ "ticker": "TSCO", "name": "Tractor Supply Company" },
|
||||
{ "ticker": "DRI", "name": "Darden Restaurants, Inc." },
|
||||
{ "ticker": "CCL", "name": "Carnival Corporation & plc" },
|
||||
{ "ticker": "ULTA", "name": "Ulta Beauty, Inc." },
|
||||
{ "ticker": "EXPE", "name": "Expedia Group, Inc." },
|
||||
{ "ticker": "DECK", "name": "Deckers Outdoor Corporation" },
|
||||
{ "ticker": "PHM", "name": "PulteGroup, Inc." },
|
||||
{ "ticker": "LVS", "name": "Las Vegas Sands Corp." },
|
||||
{ "ticker": "APTV", "name": "Aptiv PLC" },
|
||||
{ "ticker": "POOL", "name": "Pool Corporation" },
|
||||
{ "ticker": "BBY", "name": "Best Buy Co., Inc." },
|
||||
{ "ticker": "DPZ", "name": "Domino's Pizza, Inc." },
|
||||
{ "ticker": "WSM", "name": "Williams-Sonoma, Inc." },
|
||||
{ "ticker": "TPR", "name": "Tapestry, Inc." },
|
||||
{ "ticker": "GPC", "name": "Genuine Parts Company" },
|
||||
{ "ticker": "RL", "name": "Ralph Lauren Corporation" },
|
||||
{ "ticker": "KMX", "name": "CarMax, Inc." },
|
||||
{ "ticker": "LKQ", "name": "LKQ Corporation" },
|
||||
{ "ticker": "MGM", "name": "MGM Resorts International" },
|
||||
{ "ticker": "WYNN", "name": "Wynn Resorts, Limited" },
|
||||
{ "ticker": "CZR", "name": "Caesars Entertainment, Inc." },
|
||||
{ "ticker": "HAS", "name": "Hasbro, Inc." },
|
||||
{ "ticker": "NCLH", "name": "Norwegian Cruise Line Holdings Ltd." },
|
||||
{ "ticker": "MHK", "name": "Mohawk Industries, Inc." }
|
||||
]
|
||||
},
|
||||
"Communication Services": {
|
||||
"weight": 9.28,
|
||||
"companies": [
|
||||
{ "ticker": "GOOGL", "name": "Alphabet Inc. Class A" },
|
||||
{ "ticker": "GOOG", "name": "Alphabet Inc. Class C" },
|
||||
{ "ticker": "META", "name": "Meta Platforms, Inc." },
|
||||
{ "ticker": "NFLX", "name": "Netflix, Inc." },
|
||||
{ "ticker": "DIS", "name": "The Walt Disney Company" },
|
||||
{ "ticker": "CMCSA", "name": "Comcast Corporation" },
|
||||
{ "ticker": "VZ", "name": "Verizon Communications Inc." },
|
||||
{ "ticker": "T", "name": "AT&T Inc." },
|
||||
{ "ticker": "TMUS", "name": "T-Mobile US, Inc." },
|
||||
{ "ticker": "CHTR", "name": "Charter Communications, Inc." },
|
||||
{ "ticker": "UBER", "name": "Uber Technologies, Inc." },
|
||||
{ "ticker": "EA", "name": "Electronic Arts Inc." },
|
||||
{ "ticker": "TTWO", "name": "Take-Two Interactive Software, Inc." },
|
||||
{ "ticker": "WBD", "name": "Warner Bros. Discovery, Inc." },
|
||||
{ "ticker": "PARA", "name": "Paramount Global" },
|
||||
{ "ticker": "LYV", "name": "Live Nation Entertainment, Inc." },
|
||||
{ "ticker": "MTCH", "name": "Match Group, Inc." },
|
||||
{ "ticker": "ROKU", "name": "Roku, Inc." },
|
||||
{ "ticker": "SNAP", "name": "Snap Inc." },
|
||||
{ "ticker": "PINS", "name": "Pinterest, Inc." },
|
||||
{ "ticker": "FOX", "name": "Fox Corporation" },
|
||||
{ "ticker": "FOXA", "name": "Fox Corporation Class A" },
|
||||
{ "ticker": "LYFT", "name": "Lyft, Inc." },
|
||||
{ "ticker": "NWSA", "name": "News Corporation Class A" },
|
||||
{ "ticker": "NWS", "name": "News Corporation Class B" },
|
||||
{ "ticker": "NYT", "name": "The New York Times Company" },
|
||||
{ "ticker": "MSGS", "name": "Madison Square Garden Sports Corp." },
|
||||
{ "ticker": "SONY", "name": "Sony Group Corporation" }
|
||||
]
|
||||
},
|
||||
"Industrials": {
|
||||
"weight": 8.42,
|
||||
"companies": [
|
||||
{ "ticker": "UNP", "name": "Union Pacific Corporation" },
|
||||
{ "ticker": "RTX", "name": "RTX Corporation" },
|
||||
{ "ticker": "HON", "name": "Honeywell International Inc." },
|
||||
{ "ticker": "CAT", "name": "Caterpillar Inc." },
|
||||
{ "ticker": "BA", "name": "The Boeing Company" },
|
||||
{ "ticker": "GE", "name": "GE Aerospace" },
|
||||
{ "ticker": "DE", "name": "Deere & Company" },
|
||||
{ "ticker": "LMT", "name": "Lockheed Martin Corporation" },
|
||||
{ "ticker": "UPS", "name": "United Parcel Service, Inc." },
|
||||
{ "ticker": "ADP", "name": "Automatic Data Processing, Inc." },
|
||||
{ "ticker": "ETN", "name": "Eaton Corporation plc" },
|
||||
{ "ticker": "WM", "name": "Waste Management, Inc." },
|
||||
{ "ticker": "PH", "name": "Parker-Hannifin Corporation" },
|
||||
{ "ticker": "CTAS", "name": "Cintas Corporation" },
|
||||
{ "ticker": "ITW", "name": "Illinois Tool Works Inc." },
|
||||
{ "ticker": "GEV", "name": "GE Vernova Inc." },
|
||||
{ "ticker": "CSX", "name": "CSX Corporation" },
|
||||
{ "ticker": "GD", "name": "General Dynamics Corporation" },
|
||||
{ "ticker": "EMR", "name": "Emerson Electric Co." },
|
||||
{ "ticker": "NSC", "name": "Norfolk Southern Corporation" },
|
||||
{ "ticker": "NOC", "name": "Northrop Grumman Corporation" },
|
||||
{ "ticker": "MMM", "name": "3M Company" },
|
||||
{ "ticker": "TT", "name": "Trane Technologies plc" },
|
||||
{ "ticker": "TDG", "name": "TransDigm Group Incorporated" },
|
||||
{ "ticker": "CARR", "name": "Carrier Global Corporation" },
|
||||
{ "ticker": "PCAR", "name": "PACCAR Inc" },
|
||||
{ "ticker": "OTIS", "name": "Otis Worldwide Corporation" },
|
||||
{ "ticker": "JCI", "name": "Johnson Controls International plc" },
|
||||
{ "ticker": "PWR", "name": "Quanta Services, Inc." },
|
||||
{ "ticker": "FDX", "name": "FedEx Corporation" },
|
||||
{ "ticker": "RSG", "name": "Republic Services, Inc." },
|
||||
{ "ticker": "URI", "name": "United Rentals, Inc." },
|
||||
{ "ticker": "FAST", "name": "Fastenal Company" },
|
||||
{ "ticker": "DAL", "name": "Delta Air Lines, Inc." },
|
||||
{ "ticker": "CPRT", "name": "Copart, Inc." },
|
||||
{ "ticker": "HWM", "name": "Howmet Aerospace Inc." },
|
||||
{ "ticker": "LHX", "name": "L3Harris Technologies, Inc." },
|
||||
{ "ticker": "VRSK", "name": "Verisk Analytics, Inc." },
|
||||
{ "ticker": "PAYX", "name": "Paychex, Inc." },
|
||||
{ "ticker": "AXON", "name": "Axon Enterprise, Inc." },
|
||||
{ "ticker": "ROK", "name": "Rockwell Automation, Inc." },
|
||||
{ "ticker": "AME", "name": "AMETEK, Inc." },
|
||||
{ "ticker": "ODFL", "name": "Old Dominion Freight Line, Inc." },
|
||||
{ "ticker": "GWW", "name": "W.W. Grainger, Inc." },
|
||||
{ "ticker": "CMI", "name": "Cummins Inc." },
|
||||
{
|
||||
"ticker": "WAB",
|
||||
"name": "Westinghouse Air Brake Technologies Corporation"
|
||||
},
|
||||
{ "ticker": "IR", "name": "Ingersoll Rand Inc." },
|
||||
{ "ticker": "EFX", "name": "Equifax Inc." },
|
||||
{ "ticker": "XYL", "name": "Xylem Inc." }
|
||||
]
|
||||
},
|
||||
"Consumer Staples": {
|
||||
"weight": 5.94,
|
||||
"companies": [
|
||||
{ "ticker": "WMT", "name": "Walmart Inc." },
|
||||
{ "ticker": "PG", "name": "The Procter & Gamble Company" },
|
||||
{ "ticker": "COST", "name": "Costco Wholesale Corporation" },
|
||||
{ "ticker": "KO", "name": "The Coca-Cola Company" },
|
||||
{ "ticker": "PEP", "name": "PepsiCo, Inc." },
|
||||
{ "ticker": "PM", "name": "Philip Morris International Inc." },
|
||||
{ "ticker": "MDLZ", "name": "Mondelez International, Inc." },
|
||||
{ "ticker": "MO", "name": "Altria Group, Inc." },
|
||||
{ "ticker": "CL", "name": "Colgate-Palmolive Company" },
|
||||
{ "ticker": "KMB", "name": "Kimberly-Clark Corporation" },
|
||||
{ "ticker": "GIS", "name": "General Mills, Inc." },
|
||||
{ "ticker": "ADM", "name": "Archer-Daniels-Midland Company" },
|
||||
{ "ticker": "KR", "name": "The Kroger Co." },
|
||||
{ "ticker": "SYY", "name": "Sysco Corporation" },
|
||||
{ "ticker": "KDP", "name": "Keurig Dr Pepper Inc." },
|
||||
{ "ticker": "HSY", "name": "The Hershey Company" },
|
||||
{ "ticker": "KHC", "name": "The Kraft Heinz Company" },
|
||||
{ "ticker": "CHD", "name": "Church & Dwight Co., Inc." },
|
||||
{ "ticker": "TGT", "name": "Target Corporation" },
|
||||
{ "ticker": "MNST", "name": "Monster Beverage Corporation" },
|
||||
{ "ticker": "KVUE", "name": "Kenvue Inc." },
|
||||
{ "ticker": "K", "name": "Kellanova" },
|
||||
{ "ticker": "STZ", "name": "Constellation Brands, Inc." },
|
||||
{ "ticker": "CLX", "name": "The Clorox Company" },
|
||||
{ "ticker": "TSN", "name": "Tyson Foods, Inc." },
|
||||
{ "ticker": "DG", "name": "Dollar General Corporation" },
|
||||
{ "ticker": "EL", "name": "The Estée Lauder Companies Inc." },
|
||||
{ "ticker": "DLTR", "name": "Dollar Tree, Inc." },
|
||||
{ "ticker": "WBA", "name": "Walgreens Boots Alliance, Inc." },
|
||||
{ "ticker": "BG", "name": "Bunge Global SA" },
|
||||
{ "ticker": "MKC", "name": "McCormick & Company, Incorporated" },
|
||||
{ "ticker": "TAP", "name": "Molson Coors Beverage Company" },
|
||||
{ "ticker": "SJM", "name": "The J. M. Smucker Company" },
|
||||
{ "ticker": "LW", "name": "Lamb Weston Holdings, Inc." },
|
||||
{ "ticker": "BF.B", "name": "Brown-Forman Corporation Class B" },
|
||||
{ "ticker": "CAG", "name": "Conagra Brands, Inc." },
|
||||
{ "ticker": "CPB", "name": "Campbell Soup Company" },
|
||||
{ "ticker": "HRL", "name": "Hormel Foods Corporation" }
|
||||
]
|
||||
},
|
||||
"Energy": {
|
||||
"weight": 3.24,
|
||||
"companies": [
|
||||
{ "ticker": "XOM", "name": "Exxon Mobil Corporation" },
|
||||
{ "ticker": "CVX", "name": "Chevron Corporation" },
|
||||
{ "ticker": "COP", "name": "ConocoPhillips" },
|
||||
{ "ticker": "EOG", "name": "EOG Resources, Inc." },
|
||||
{ "ticker": "SLB", "name": "Schlumberger N.V." },
|
||||
{ "ticker": "MPC", "name": "Marathon Petroleum Corporation" },
|
||||
{ "ticker": "PSX", "name": "Phillips 66" },
|
||||
{ "ticker": "VLO", "name": "Valero Energy Corporation" },
|
||||
{ "ticker": "WMB", "name": "The Williams Companies, Inc." },
|
||||
{ "ticker": "OKE", "name": "ONEOK, Inc." },
|
||||
{ "ticker": "KMI", "name": "Kinder Morgan, Inc." },
|
||||
{ "ticker": "PXD", "name": "Pioneer Natural Resources Company" },
|
||||
{ "ticker": "BKR", "name": "Baker Hughes Company" },
|
||||
{ "ticker": "LNG", "name": "Cheniere Energy, Inc." },
|
||||
{ "ticker": "FANG", "name": "Diamondback Energy, Inc." },
|
||||
{ "ticker": "TRGP", "name": "Targa Resources Corp." },
|
||||
{ "ticker": "HAL", "name": "Halliburton Company" },
|
||||
{ "ticker": "OXY", "name": "Occidental Petroleum Corporation" },
|
||||
{ "ticker": "DVN", "name": "Devon Energy Corporation" },
|
||||
{ "ticker": "HES", "name": "Hess Corporation" },
|
||||
{ "ticker": "CTRA", "name": "Coterra Energy Inc." },
|
||||
{ "ticker": "EQT", "name": "EQT Corporation" },
|
||||
{ "ticker": "MRO", "name": "Marathon Oil Corporation" },
|
||||
{ "ticker": "CLR", "name": "Continental Resources, Inc." },
|
||||
{ "ticker": "APA", "name": "APA Corporation" },
|
||||
{ "ticker": "TPL", "name": "Texas Pacific Land Corporation" },
|
||||
{ "ticker": "CHK", "name": "Chesapeake Energy Corporation" },
|
||||
{ "ticker": "PTEN", "name": "Patterson-UTI Energy, Inc." },
|
||||
{ "ticker": "AR", "name": "Antero Resources Corporation" },
|
||||
{ "ticker": "SW", "name": "Southwestern Energy Company" },
|
||||
{ "ticker": "RRC", "name": "Range Resources Corporation" },
|
||||
{ "ticker": "HP", "name": "Helmerich & Payne, Inc." },
|
||||
{ "ticker": "NOV", "name": "National Oilwell Varco, Inc." },
|
||||
{ "ticker": "SM", "name": "SM Energy Company" },
|
||||
{ "ticker": "CVI", "name": "CVR Energy, Inc." },
|
||||
{ "ticker": "CEIX", "name": "Consol Energy Inc." },
|
||||
{ "ticker": "WKC", "name": "World Kinect Corporation" }
|
||||
]
|
||||
},
|
||||
"Real Estate": {
|
||||
"weight": 2.57,
|
||||
"companies": [
|
||||
{ "ticker": "PLD", "name": "Prologis, Inc." },
|
||||
{ "ticker": "AMT", "name": "American Tower Corporation" },
|
||||
{ "ticker": "EQIX", "name": "Equinix, Inc." },
|
||||
{ "ticker": "WELL", "name": "Welltower Inc." },
|
||||
{ "ticker": "SPG", "name": "Simon Property Group, Inc." },
|
||||
{ "ticker": "PSA", "name": "Public Storage" },
|
||||
{ "ticker": "O", "name": "Realty Income Corporation" },
|
||||
{ "ticker": "CCI", "name": "Crown Castle Inc." },
|
||||
{ "ticker": "VICI", "name": "VICI Properties Inc." },
|
||||
{ "ticker": "DLR", "name": "Digital Realty Trust, Inc." },
|
||||
{ "ticker": "CBRE", "name": "CBRE Group, Inc." },
|
||||
{ "ticker": "AVB", "name": "AvalonBay Communities, Inc." },
|
||||
{ "ticker": "EQR", "name": "Equity Residential" },
|
||||
{ "ticker": "WY", "name": "Weyerhaeuser Company" },
|
||||
{ "ticker": "SBAC", "name": "SBA Communications Corporation" },
|
||||
{ "ticker": "INVH", "name": "Invitation Homes Inc." },
|
||||
{ "ticker": "ARE", "name": "Alexandria Real Estate Equities, Inc." },
|
||||
{ "ticker": "VTR", "name": "Ventas, Inc." },
|
||||
{ "ticker": "EXR", "name": "Extra Space Storage Inc." },
|
||||
{ "ticker": "BXP", "name": "BXP, Inc." },
|
||||
{
|
||||
"ticker": "MAA",
|
||||
"name": "Mid-America Apartment Communities, Inc."
|
||||
},
|
||||
{ "ticker": "DOC", "name": "Healthpeak Properties, Inc." },
|
||||
{ "ticker": "HST", "name": "Host Hotels & Resorts, Inc." },
|
||||
{ "ticker": "KIM", "name": "Kimco Realty Corporation" },
|
||||
{ "ticker": "ESS", "name": "Essex Property Trust, Inc." },
|
||||
{ "ticker": "CSGP", "name": "CoStar Group, Inc." },
|
||||
{ "ticker": "UDR", "name": "UDR, Inc." },
|
||||
{ "ticker": "REG", "name": "Regency Centers Corporation" },
|
||||
{ "ticker": "IRM", "name": "Iron Mountain Incorporated" },
|
||||
{ "ticker": "CPT", "name": "Camden Property Trust" },
|
||||
{ "ticker": "FRT", "name": "Federal Realty Investment Trust" }
|
||||
]
|
||||
},
|
||||
"Utilities": {
|
||||
"weight": 2.44,
|
||||
"companies": [
|
||||
{ "ticker": "NEE", "name": "NextEra Energy, Inc." },
|
||||
{ "ticker": "SO", "name": "The Southern Company" },
|
||||
{ "ticker": "DUK", "name": "Duke Energy Corporation" },
|
||||
{ "ticker": "CEG", "name": "Constellation Energy Corporation" },
|
||||
{ "ticker": "SRE", "name": "Sempra" },
|
||||
{ "ticker": "AEP", "name": "American Electric Power Company, Inc." },
|
||||
{ "ticker": "D", "name": "Dominion Energy, Inc." },
|
||||
{ "ticker": "PCG", "name": "PG&E Corporation" },
|
||||
{
|
||||
"ticker": "PEG",
|
||||
"name": "Public Service Enterprise Group Incorporated"
|
||||
},
|
||||
{ "ticker": "EXC", "name": "Exelon Corporation" },
|
||||
{ "ticker": "XEL", "name": "Xcel Energy Inc." },
|
||||
{ "ticker": "ED", "name": "Consolidated Edison, Inc." },
|
||||
{ "ticker": "EIX", "name": "Edison International" },
|
||||
{ "ticker": "ETR", "name": "Entergy Corporation" },
|
||||
{ "ticker": "WEC", "name": "WEC Energy Group, Inc." },
|
||||
{ "ticker": "AWK", "name": "American Water Works Company, Inc." },
|
||||
{ "ticker": "DTE", "name": "DTE Energy Company" },
|
||||
{ "ticker": "PPL", "name": "PPL Corporation" },
|
||||
{ "ticker": "AEE", "name": "Ameren Corporation" },
|
||||
{ "ticker": "ATO", "name": "Atmos Energy Corporation" },
|
||||
{ "ticker": "ES", "name": "Eversource Energy" },
|
||||
{ "ticker": "CMS", "name": "CMS Energy Corporation" },
|
||||
{ "ticker": "CNP", "name": "CenterPoint Energy, Inc." },
|
||||
{ "ticker": "VST", "name": "Vistra Corp." },
|
||||
{ "ticker": "FE", "name": "FirstEnergy Corp." },
|
||||
{ "ticker": "LNT", "name": "Alliant Energy Corporation" },
|
||||
{ "ticker": "EVRG", "name": "Evergy, Inc." },
|
||||
{ "ticker": "NI", "name": "NiSource Inc." },
|
||||
{ "ticker": "AES", "name": "The AES Corporation" },
|
||||
{ "ticker": "PNW", "name": "Pinnacle West Capital Corporation" },
|
||||
{ "ticker": "NRG", "name": "NRG Energy, Inc." },
|
||||
{ "ticker": "OGE", "name": "OGE Energy Corp." },
|
||||
{ "ticker": "WTRG", "name": "Essential Utilities, Inc." },
|
||||
{ "ticker": "UGI", "name": "UGI Corporation" },
|
||||
{ "ticker": "NFG", "name": "National Fuel Gas Company" },
|
||||
{ "ticker": "HE", "name": "Hawaiian Electric Industries, Inc." }
|
||||
]
|
||||
},
|
||||
"Materials": {
|
||||
"weight": 2.02,
|
||||
"companies": [
|
||||
{ "ticker": "LIN", "name": "Linde plc" },
|
||||
{ "ticker": "SHW", "name": "The Sherwin-Williams Company" },
|
||||
{ "ticker": "APD", "name": "Air Products and Chemicals, Inc." },
|
||||
{ "ticker": "ECL", "name": "Ecolab Inc." },
|
||||
{ "ticker": "FCX", "name": "Freeport-McMoRan Inc." },
|
||||
{ "ticker": "NUE", "name": "Nucor Corporation" },
|
||||
{ "ticker": "NEM", "name": "Newmont Corporation" },
|
||||
{ "ticker": "CTVA", "name": "Corteva, Inc." },
|
||||
{ "ticker": "VMC", "name": "Vulcan Materials Company" },
|
||||
{ "ticker": "DOW", "name": "Dow Inc." },
|
||||
{ "ticker": "DD", "name": "DuPont de Nemours, Inc." },
|
||||
{ "ticker": "PPG", "name": "PPG Industries, Inc." },
|
||||
{ "ticker": "MLM", "name": "Martin Marietta Materials, Inc." },
|
||||
{ "ticker": "CF", "name": "CF Industries Holdings, Inc." },
|
||||
{ "ticker": "ALB", "name": "Albemarle Corporation" },
|
||||
{ "ticker": "LYB", "name": "LyondellBasell Industries N.V." },
|
||||
{ "ticker": "STLD", "name": "Steel Dynamics, Inc." },
|
||||
{ "ticker": "BALL", "name": "Ball Corporation" },
|
||||
{
|
||||
"ticker": "IFF",
|
||||
"name": "International Flavors & Fragrances Inc."
|
||||
},
|
||||
{ "ticker": "AMCR", "name": "Amcor plc" },
|
||||
{ "ticker": "AVY", "name": "Avery Dennison Corporation" },
|
||||
{ "ticker": "EMN", "name": "Eastman Chemical Company" },
|
||||
{ "ticker": "CE", "name": "Celanese Corporation" },
|
||||
{ "ticker": "IP", "name": "International Paper Company" },
|
||||
{ "ticker": "MOS", "name": "The Mosaic Company" },
|
||||
{ "ticker": "PKG", "name": "Packaging Corporation of America" },
|
||||
{ "ticker": "FMC", "name": "FMC Corporation" },
|
||||
{ "ticker": "OC", "name": "Owens Corning" },
|
||||
{ "ticker": "RS", "name": "Reliance Steel & Aluminum Co." },
|
||||
{ "ticker": "WLK", "name": "Westlake Chemical Corporation" },
|
||||
{ "ticker": "EXP", "name": "Eagle Materials Inc." },
|
||||
{ "ticker": "SW", "name": "Smurfit Westrock plc" },
|
||||
{ "ticker": "SEE", "name": "Sealed Air Corporation" },
|
||||
{ "ticker": "X", "name": "United States Steel Corporation" },
|
||||
{ "ticker": "RPM", "name": "RPM International Inc." },
|
||||
{ "ticker": "SCCO", "name": "Southern Copper Corporation" },
|
||||
{ "ticker": "HUN", "name": "Huntsman Corporation" },
|
||||
{ "ticker": "RGLD", "name": "Royal Gold, Inc." },
|
||||
{ "ticker": "SON", "name": "Sonoco Products Company" },
|
||||
{ "ticker": "SLGN", "name": "Silgan Holdings Inc." },
|
||||
{ "ticker": "NEU", "name": "NewMarket Corporation" },
|
||||
{ "ticker": "CBT", "name": "Cabot Corporation" },
|
||||
{ "ticker": "CMP", "name": "Compass Minerals International Inc." },
|
||||
{ "ticker": "SSRM", "name": "SSR Mining Inc." }
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,393 @@
|
|||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from tradingagents.config import DEFAULT_CONFIG, TradingAgentsConfig
|
||||
from tradingagents.domains.marketdata.fundamental_data_service import (
|
||||
BalanceSheetContext,
|
||||
CashFlowContext,
|
||||
FundamentalDataService,
|
||||
IncomeStatementContext,
|
||||
)
|
||||
from tradingagents.domains.marketdata.insider_data_service import (
|
||||
InsiderDataService,
|
||||
InsiderSentimentContext,
|
||||
InsiderTransactionContext,
|
||||
)
|
||||
from tradingagents.domains.marketdata.market_data_service import (
|
||||
MarketDataService,
|
||||
PriceDataContext,
|
||||
TAReportContext,
|
||||
)
|
||||
|
||||
# Import context models
|
||||
from tradingagents.domains.news.news_service import (
|
||||
GlobalNewsContext,
|
||||
NewsContext,
|
||||
NewsService,
|
||||
)
|
||||
from tradingagents.domains.socialmedia.social_media_service import (
|
||||
SocialMediaService,
|
||||
StockSocialContext,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentToolkit:
|
||||
def __init__(
|
||||
self,
|
||||
news_service: NewsService,
|
||||
marketdata_service: MarketDataService,
|
||||
fundamentaldata_service: FundamentalDataService,
|
||||
socialmedia_service: SocialMediaService,
|
||||
insiderdata_service: InsiderDataService,
|
||||
config: TradingAgentsConfig = DEFAULT_CONFIG,
|
||||
):
|
||||
self._news_service = news_service
|
||||
self._marketdata_service = marketdata_service
|
||||
self._fundamentaldata_service = fundamentaldata_service
|
||||
self._socialmedia_service = socialmedia_service
|
||||
self._insiderdata_service = insiderdata_service
|
||||
self._config = config
|
||||
|
||||
@tool
|
||||
def get_global_news(
|
||||
self,
|
||||
curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"],
|
||||
) -> GlobalNewsContext:
|
||||
"""
|
||||
Retrieve global news from Reddit within a specified time frame.
|
||||
Args:
|
||||
curr_date (str): Date you want to get news for in yyyy-mm-dd format
|
||||
Returns:
|
||||
GlobalNewsContext: Structured global news context with articles and sentiment analysis.
|
||||
"""
|
||||
# Calculate date range (current date only)
|
||||
start_date = curr_date
|
||||
end_date = curr_date
|
||||
|
||||
# Call specialized service method
|
||||
return self._news_service.get_global_news_context(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
categories=["general", "business", "politics"],
|
||||
)
|
||||
|
||||
@tool
|
||||
def get_news(
|
||||
self,
|
||||
ticker: Annotated[
|
||||
str,
|
||||
"Search query of a company, e.g. 'AAPL, TSM, etc.",
|
||||
],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> NewsContext:
|
||||
"""
|
||||
Retrieve the latest news about a given stock from Finnhub within a date range
|
||||
Args:
|
||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
NewsContext: Structured news context with articles and sentiment analysis for the company.
|
||||
"""
|
||||
try:
|
||||
ticker = self._validate_ticker(ticker)
|
||||
# Validate date formats
|
||||
datetime.strptime(start_date, "%Y-%m-%d")
|
||||
datetime.strptime(end_date, "%Y-%m-%d")
|
||||
|
||||
return self._news_service.get_context(
|
||||
query=ticker, start_date=start_date, end_date=end_date, symbol=ticker
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting news for {ticker}: {e}")
|
||||
raise
|
||||
|
||||
@tool
|
||||
def get_socialmedia_stock_info(
|
||||
self,
|
||||
ticker: Annotated[
|
||||
str,
|
||||
"Ticker of a company. e.g. AAPL, TSM",
|
||||
],
|
||||
curr_date: Annotated[str, "Current date you want to get news for"],
|
||||
) -> StockSocialContext:
|
||||
"""
|
||||
Retrieve the latest news about a given stock from Reddit, given the current date.
|
||||
Args:
|
||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
||||
curr_date (str): current date in yyyy-mm-dd format to get news for
|
||||
Returns:
|
||||
StockSocialContext: Structured social media context with posts and sentiment analysis for the stock.
|
||||
"""
|
||||
try:
|
||||
ticker = self._validate_ticker(ticker)
|
||||
# Validate date format
|
||||
datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
|
||||
return self._socialmedia_service.get_stock_social_context(
|
||||
symbol=ticker, start_date=curr_date, end_date=curr_date
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting social media info for {ticker}: {e}")
|
||||
raise
|
||||
|
||||
@tool
|
||||
def get_market_data(
|
||||
self,
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> PriceDataContext:
|
||||
"""
|
||||
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
PriceDataContext: Structured price data context with historical prices and key metrics.
|
||||
"""
|
||||
try:
|
||||
symbol = self._validate_ticker(symbol)
|
||||
# Validate date formats
|
||||
datetime.strptime(start_date, "%Y-%m-%d")
|
||||
datetime.strptime(end_date, "%Y-%m-%d")
|
||||
|
||||
return self._marketdata_service.get_market_data_context(
|
||||
symbol=symbol, start_date=start_date, end_date=end_date
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting market data for {symbol}: {e}")
|
||||
raise
|
||||
|
||||
@tool
|
||||
def get_ta_report(
|
||||
self,
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
indicator: Annotated[
|
||||
str, "technical indicator to get the analysis and report of"
|
||||
],
|
||||
curr_date: Annotated[
|
||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
||||
],
|
||||
look_back_days: Annotated[int, "how many days to look back"] = None,
|
||||
) -> TAReportContext:
|
||||
"""
|
||||
Retrieve stock stats indicators for a given ticker symbol and indicator.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
indicator (str): Technical indicator to get the analysis and report of
|
||||
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
|
||||
look_back_days (int): How many days to look back, uses config default if None
|
||||
Returns:
|
||||
TAReportContext: Structured technical analysis context with indicator data and signals.
|
||||
"""
|
||||
try:
|
||||
symbol = self._validate_ticker(symbol)
|
||||
if look_back_days is None:
|
||||
look_back_days = self._config.default_ta_lookback_days
|
||||
start_date, end_date = self._calculate_date_range(curr_date, look_back_days)
|
||||
|
||||
return self._marketdata_service.get_ta_report_context(
|
||||
symbol=symbol,
|
||||
indicator=indicator,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting TA report for {symbol}: {e}")
|
||||
raise
|
||||
|
||||
@tool
|
||||
def get_insider_sentiment(
|
||||
self,
|
||||
ticker: Annotated[str, "ticker symbol for the company"],
|
||||
curr_date: Annotated[
|
||||
str,
|
||||
"current date of you are trading at, yyyy-mm-dd",
|
||||
],
|
||||
) -> InsiderSentimentContext:
|
||||
"""
|
||||
Retrieve insider sentiment information about a company (retrieved from public SEC information) for the configured lookback period
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
InsiderSentimentContext: Structured insider sentiment analysis with transaction data and sentiment scores.
|
||||
"""
|
||||
try:
|
||||
ticker = self._validate_ticker(ticker)
|
||||
start_date, end_date = self._calculate_date_range(curr_date)
|
||||
|
||||
return self._insiderdata_service.get_insider_sentiment_context(
|
||||
symbol=ticker, start_date=start_date, end_date=end_date
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting insider sentiment for {ticker}: {e}")
|
||||
raise
|
||||
|
||||
@tool
|
||||
def get_insider_transactions(
|
||||
self,
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
curr_date: Annotated[
|
||||
str,
|
||||
"current date you are trading at, yyyy-mm-dd",
|
||||
],
|
||||
) -> InsiderTransactionContext:
|
||||
"""
|
||||
Retrieve insider transaction information about a company (retrieved from public SEC information) for the configured lookback period
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
InsiderTransactionContext: Structured insider transaction analysis with detailed transaction data.
|
||||
"""
|
||||
try:
|
||||
ticker = self._validate_ticker(ticker)
|
||||
start_date, end_date = self._calculate_date_range(curr_date)
|
||||
|
||||
return self._insiderdata_service.get_insider_transaction_context(
|
||||
symbol=ticker, start_date=start_date, end_date=end_date
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting insider transactions for {ticker}: {e}")
|
||||
raise
|
||||
|
||||
@tool
|
||||
def get_balance_sheet(
|
||||
self,
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
freq: Annotated[
|
||||
str,
|
||||
"reporting frequency of the company's financial history: annual/quarterly",
|
||||
],
|
||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
||||
) -> BalanceSheetContext:
|
||||
"""
|
||||
Retrieve the most recent balance sheet of a company
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
freq (str): reporting frequency of the company's financial history: annual / quarterly
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
BalanceSheetContext: Structured balance sheet analysis with key liquidity and debt metrics.
|
||||
"""
|
||||
return self._fundamentaldata_service.get_balance_sheet_context(
|
||||
symbol=ticker,
|
||||
start_date=curr_date,
|
||||
end_date=curr_date,
|
||||
frequency=freq.lower(),
|
||||
)
|
||||
|
||||
@tool
|
||||
def get_cashflow(
|
||||
self,
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
freq: Annotated[
|
||||
str,
|
||||
"reporting frequency of the company's financial history: annual/quarterly",
|
||||
],
|
||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
||||
) -> CashFlowContext:
|
||||
"""
|
||||
Retrieve the most recent cash flow statement of a company
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
freq (str): reporting frequency of the company's financial history: annual / quarterly
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
CashFlowContext: Structured cash flow analysis with operating cash flow metrics.
|
||||
"""
|
||||
return self._fundamentaldata_service.get_cashflow_context(
|
||||
symbol=ticker,
|
||||
start_date=curr_date,
|
||||
end_date=curr_date,
|
||||
frequency=freq.lower(),
|
||||
)
|
||||
|
||||
@tool
|
||||
def get_income_stmt(
|
||||
self,
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
freq: Annotated[
|
||||
str,
|
||||
"reporting frequency of the company's financial history: annual/quarterly",
|
||||
],
|
||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
||||
) -> IncomeStatementContext:
|
||||
"""
|
||||
Retrieve the most recent income statement of a company
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
freq (str): reporting frequency of the company's financial history: annual / quarterly
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
IncomeStatementContext: Structured income statement analysis with profitability metrics.
|
||||
"""
|
||||
return self._fundamentaldata_service.get_income_statement_context(
|
||||
symbol=ticker,
|
||||
start_date=curr_date,
|
||||
end_date=curr_date,
|
||||
frequency=freq.lower(),
|
||||
)
|
||||
|
||||
def _calculate_date_range(
|
||||
self, curr_date: str, lookback_days: int | None = None
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Calculate start and end dates based on current date and lookback period.
|
||||
|
||||
Args:
|
||||
curr_date: Current date in YYYY-MM-DD format
|
||||
lookback_days: Number of days to look back (uses config default if None)
|
||||
|
||||
Returns:
|
||||
Tuple of (start_date, end_date) in YYYY-MM-DD format
|
||||
|
||||
Raises:
|
||||
ValueError: If date format is invalid
|
||||
"""
|
||||
try:
|
||||
curr_date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid date format '{curr_date}': {e}")
|
||||
raise ValueError(f"Date must be in YYYY-MM-DD format, got: {curr_date}")
|
||||
|
||||
if lookback_days is None:
|
||||
lookback_days = self._config.default_lookback_days
|
||||
|
||||
start_date_obj = curr_date_obj - timedelta(days=lookback_days)
|
||||
return start_date_obj.strftime("%Y-%m-%d"), curr_date
|
||||
|
||||
def _validate_ticker(self, ticker: str) -> str:
|
||||
"""
|
||||
Validate and sanitize ticker symbol.
|
||||
|
||||
Args:
|
||||
ticker: Ticker symbol to validate
|
||||
|
||||
Returns:
|
||||
Sanitized ticker symbol
|
||||
|
||||
Raises:
|
||||
ValueError: If ticker is invalid
|
||||
"""
|
||||
if not ticker or not isinstance(ticker, str):
|
||||
raise ValueError("Ticker must be a non-empty string")
|
||||
|
||||
# Remove whitespace and convert to uppercase
|
||||
ticker = ticker.strip().upper()
|
||||
|
||||
# Basic validation: only letters, numbers, and common symbols
|
||||
if not re.match(r"^[A-Z0-9.-]{1,10}$", ticker):
|
||||
raise ValueError(f"Invalid ticker format: {ticker}")
|
||||
|
||||
return ticker
|
||||
|
|
@ -1,415 +0,0 @@
|
|||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
import tradingagents.dataflows.interface as interface
|
||||
from tradingagents.config import TradingAgentsConfig
|
||||
|
||||
DEFAULT_CONFIG = TradingAgentsConfig()
|
||||
|
||||
|
||||
def create_msg_delete():
|
||||
def delete_messages(state):
|
||||
"""Clear messages and add placeholder for Anthropic compatibility"""
|
||||
messages = state["messages"]
|
||||
|
||||
# Remove all messages
|
||||
removal_operations = [RemoveMessage(id=m.id) for m in messages]
|
||||
|
||||
# Add a minimal placeholder message
|
||||
placeholder = HumanMessage(content="Continue")
|
||||
|
||||
return {"messages": removal_operations + [placeholder]}
|
||||
|
||||
return delete_messages
|
||||
|
||||
|
||||
class Toolkit:
|
||||
_config = TradingAgentsConfig()
|
||||
|
||||
@classmethod
|
||||
def update_config(cls, config):
|
||||
"""Update the class-level configuration."""
|
||||
cls._config = config
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
"""Access the configuration."""
|
||||
return self._config
|
||||
|
||||
def __init__(self, config=None):
|
||||
if config:
|
||||
self.update_config(config)
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_reddit_news(
|
||||
curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve global news from Reddit within a specified time frame.
|
||||
Args:
|
||||
curr_date (str): Date you want to get news for in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A formatted dataframe containing the latest global news from Reddit in the specified time frame.
|
||||
"""
|
||||
|
||||
global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)
|
||||
|
||||
return global_news_result
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_finnhub_news(
|
||||
ticker: Annotated[
|
||||
str,
|
||||
"Search query of a company, e.g. 'AAPL, TSM, etc.",
|
||||
],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
):
|
||||
"""
|
||||
Retrieve the latest news about a given stock from Finnhub within a date range
|
||||
Args:
|
||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A formatted dataframe containing news about the company within the date range from start_date to end_date
|
||||
"""
|
||||
|
||||
end_date_str = end_date
|
||||
|
||||
end_date_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
start_date_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
look_back_days = (end_date_dt - start_date_dt).days
|
||||
|
||||
finnhub_news_result = interface.get_finnhub_news(
|
||||
ticker, end_date_str, look_back_days
|
||||
)
|
||||
|
||||
return finnhub_news_result
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_reddit_stock_info(
|
||||
ticker: Annotated[
|
||||
str,
|
||||
"Ticker of a company. e.g. AAPL, TSM",
|
||||
],
|
||||
curr_date: Annotated[str, "Current date you want to get news for"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the latest news about a given stock from Reddit, given the current date.
|
||||
Args:
|
||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
||||
curr_date (str): current date in yyyy-mm-dd format to get news for
|
||||
Returns:
|
||||
str: A formatted dataframe containing the latest news about the company on the given date
|
||||
"""
|
||||
|
||||
stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5)
|
||||
|
||||
return stock_news_results
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_YFin_data(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
|
||||
"""
|
||||
|
||||
result_data = interface.get_YFin_data(symbol, start_date, end_date)
|
||||
|
||||
# Convert DataFrame to string for tool output
|
||||
return result_data.to_string()
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_YFin_data_online(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
|
||||
"""
|
||||
|
||||
result_data = interface.get_YFin_data_online(symbol, start_date, end_date)
|
||||
|
||||
return result_data
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_stockstats_indicators_report(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
indicator: Annotated[
|
||||
str, "technical indicator to get the analysis and report of"
|
||||
],
|
||||
curr_date: Annotated[
|
||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
||||
],
|
||||
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve stock stats indicators for a given ticker symbol and indicator.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
indicator (str): Technical indicator to get the analysis and report of
|
||||
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
|
||||
look_back_days (int): How many days to look back, default is 30
|
||||
Returns:
|
||||
str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator.
|
||||
"""
|
||||
|
||||
result_stockstats = interface.get_stock_stats_indicators_window(
|
||||
symbol, indicator, curr_date, look_back_days, False
|
||||
)
|
||||
|
||||
return result_stockstats
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_stockstats_indicators_report_online(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
indicator: Annotated[
|
||||
str, "technical indicator to get the analysis and report of"
|
||||
],
|
||||
curr_date: Annotated[
|
||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
||||
],
|
||||
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve stock stats indicators for a given ticker symbol and indicator.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
indicator (str): Technical indicator to get the analysis and report of
|
||||
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
|
||||
look_back_days (int): How many days to look back, default is 30
|
||||
Returns:
|
||||
str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator.
|
||||
"""
|
||||
|
||||
result_stockstats = interface.get_stock_stats_indicators_window(
|
||||
symbol, indicator, curr_date, look_back_days, True
|
||||
)
|
||||
|
||||
return result_stockstats
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_finnhub_company_insider_sentiment(
|
||||
ticker: Annotated[str, "ticker symbol for the company"],
|
||||
curr_date: Annotated[
|
||||
str,
|
||||
"current date of you are trading at, yyyy-mm-dd",
|
||||
],
|
||||
):
|
||||
"""
|
||||
Retrieve insider sentiment information about a company (retrieved from public SEC information) for the past 30 days
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
str: a report of the sentiment in the past 30 days starting at curr_date
|
||||
"""
|
||||
|
||||
data_sentiment = interface.get_finnhub_company_insider_sentiment(
|
||||
ticker, curr_date, 30
|
||||
)
|
||||
|
||||
return data_sentiment
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_finnhub_company_insider_transactions(
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
curr_date: Annotated[
|
||||
str,
|
||||
"current date you are trading at, yyyy-mm-dd",
|
||||
],
|
||||
):
|
||||
"""
|
||||
Retrieve insider transaction information about a company (retrieved from public SEC information) for the past 30 days
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
str: a report of the company's insider transactions/trading information in the past 30 days
|
||||
"""
|
||||
|
||||
data_trans = interface.get_finnhub_company_insider_transactions(
|
||||
ticker, curr_date, 30
|
||||
)
|
||||
|
||||
return data_trans
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_simfin_balance_sheet(
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
freq: Annotated[
|
||||
str,
|
||||
"reporting frequency of the company's financial history: annual/quarterly",
|
||||
],
|
||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
||||
):
|
||||
"""
|
||||
Retrieve the most recent balance sheet of a company
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
freq (str): reporting frequency of the company's financial history: annual / quarterly
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
str: a report of the company's most recent balance sheet
|
||||
"""
|
||||
|
||||
data_balance_sheet = interface.get_simfin_balance_sheet(ticker, freq, curr_date)
|
||||
|
||||
return data_balance_sheet
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_simfin_cashflow(
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
freq: Annotated[
|
||||
str,
|
||||
"reporting frequency of the company's financial history: annual/quarterly",
|
||||
],
|
||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
||||
):
|
||||
"""
|
||||
Retrieve the most recent cash flow statement of a company
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
freq (str): reporting frequency of the company's financial history: annual / quarterly
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
str: a report of the company's most recent cash flow statement
|
||||
"""
|
||||
|
||||
data_cashflow = interface.get_simfin_cashflow(ticker, freq, curr_date)
|
||||
|
||||
return data_cashflow
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_simfin_income_stmt(
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
freq: Annotated[
|
||||
str,
|
||||
"reporting frequency of the company's financial history: annual/quarterly",
|
||||
],
|
||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
||||
):
|
||||
"""
|
||||
Retrieve the most recent income statement of a company
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
freq (str): reporting frequency of the company's financial history: annual / quarterly
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
str: a report of the company's most recent income statement
|
||||
"""
|
||||
|
||||
data_income_stmt = interface.get_simfin_income_statements(
|
||||
ticker, freq, curr_date
|
||||
)
|
||||
|
||||
return data_income_stmt
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_google_news(
|
||||
query: Annotated[str, "Query to search with"],
|
||||
curr_date: Annotated[str, "Curr date in yyyy-mm-dd format"],
|
||||
):
|
||||
"""
|
||||
Retrieve the latest news from Google News based on a query and date range.
|
||||
Args:
|
||||
query (str): Query to search with
|
||||
curr_date (str): Current date in yyyy-mm-dd format
|
||||
look_back_days (int): How many days to look back
|
||||
Returns:
|
||||
str: A formatted string containing the latest news from Google News based on the query and date range.
|
||||
"""
|
||||
|
||||
google_news_results = interface.get_google_news(query, curr_date, 7)
|
||||
|
||||
return google_news_results
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_stock_news_openai(
|
||||
ticker: Annotated[str, "the company's ticker"],
|
||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||
):
|
||||
"""
|
||||
Retrieve the latest news about a given stock by using OpenAI's news API.
|
||||
Args:
|
||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
||||
curr_date (str): Current date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A formatted string containing the latest news about the company on the given date.
|
||||
"""
|
||||
|
||||
openai_news_results = interface.get_stock_news_openai(ticker, curr_date)
|
||||
|
||||
return openai_news_results
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_global_news_openai(
|
||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||
):
|
||||
"""
|
||||
Retrieve the latest macroeconomics news on a given date using OpenAI's macroeconomics news API.
|
||||
Args:
|
||||
curr_date (str): Current date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A formatted string containing the latest macroeconomic news on the given date.
|
||||
"""
|
||||
|
||||
openai_news_results = interface.get_global_news_openai(curr_date)
|
||||
|
||||
return openai_news_results
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_fundamentals_openai(
|
||||
ticker: Annotated[str, "the company's ticker"],
|
||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||
):
|
||||
"""
|
||||
Retrieve the latest fundamental information about a given stock on a given date by using OpenAI's news API.
|
||||
Args:
|
||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
||||
curr_date (str): Current date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A formatted string containing the latest fundamental information about the company on the given date.
|
||||
"""
|
||||
|
||||
openai_fundamentals_results = interface.get_fundamentals_openai(
|
||||
ticker, curr_date
|
||||
)
|
||||
|
||||
return openai_fundamentals_results
|
||||
|
|
@ -1,312 +0,0 @@
|
|||
"""
|
||||
New Toolkit class using Service/Client/Repository architecture with JSON context.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Annotated, Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from tradingagents.config import TradingAgentsConfig
|
||||
from tradingagents.services.builders import build_toolkit_services
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tradingagents.services.market_data_service import MarketDataService
|
||||
from tradingagents.services.news_service import NewsService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_CONFIG = TradingAgentsConfig()
|
||||
|
||||
|
||||
def create_msg_delete():
|
||||
"""Create message deletion function for agents."""
|
||||
|
||||
def delete_messages(state):
|
||||
"""Clear messages and add placeholder for Anthropic compatibility"""
|
||||
messages = state["messages"]
|
||||
|
||||
# Remove all messages
|
||||
removal_operations = [RemoveMessage(id=m.id) for m in messages]
|
||||
|
||||
# Add a minimal placeholder message
|
||||
placeholder = HumanMessage(content="Continue")
|
||||
|
||||
return {"messages": removal_operations + [placeholder]}
|
||||
|
||||
return delete_messages
|
||||
|
||||
|
||||
class Toolkit:
|
||||
"""
|
||||
New Toolkit class that uses services to provide JSON context to agents.
|
||||
|
||||
This replaces the old interface.py approach with structured Pydantic models
|
||||
that agents can process more dynamically.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TradingAgentsConfig | None = None,
|
||||
services: dict[str, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize Toolkit with services.
|
||||
|
||||
Args:
|
||||
config: TradingAgents configuration
|
||||
services: Pre-built services dict, or None to build from config
|
||||
"""
|
||||
self.config = config or DEFAULT_CONFIG
|
||||
|
||||
if services:
|
||||
self.services = services
|
||||
else:
|
||||
logger.info("Building services from config")
|
||||
self.services = build_toolkit_services(self.config)
|
||||
|
||||
# Set up individual services
|
||||
self.market_service: MarketDataService | None = self.services.get("market_data")
|
||||
self.news_service: NewsService | None = self.services.get("news")
|
||||
|
||||
logger.info(f"Toolkit initialized with {len(self.services)} services")
|
||||
|
||||
# Market Data Tools
|
||||
def get_market_data(
|
||||
self,
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve market data context for a given ticker symbol.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: JSON context containing market data with price data and metadata
|
||||
"""
|
||||
if not self.market_service:
|
||||
return self._create_error_context("MarketDataService not available")
|
||||
|
||||
try:
|
||||
context = self.market_service.get_price_context(
|
||||
symbol, start_date, end_date
|
||||
)
|
||||
return context.model_dump_json(indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting market data for {symbol}: {e}")
|
||||
return self._create_error_context(f"Error fetching market data: {str(e)}")
|
||||
|
||||
@tool
|
||||
def get_market_data_with_indicators(
|
||||
self,
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
indicators: Annotated[
|
||||
str, "Comma-separated list of indicators (e.g. 'rsi,macd,close_50_sma')"
|
||||
] = "rsi,macd",
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve market data context with technical indicators.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
indicators (str): Comma-separated indicators
|
||||
Returns:
|
||||
str: JSON context containing market data with technical indicators
|
||||
"""
|
||||
if not self.market_service:
|
||||
return self._create_error_context("MarketDataService not available")
|
||||
|
||||
try:
|
||||
indicator_list = [i.strip() for i in indicators.split(",") if i.strip()]
|
||||
context = self.market_service.get_context(
|
||||
symbol, start_date, end_date, indicators=indicator_list
|
||||
)
|
||||
return context.model_dump_json(indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting market data with indicators for {symbol}: {e}")
|
||||
return self._create_error_context(
|
||||
f"Error fetching market data with indicators: {str(e)}"
|
||||
)
|
||||
|
||||
# News Tools
|
||||
@tool
|
||||
def get_company_news(
|
||||
self,
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve news context for a specific company.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: JSON context containing news articles, sentiment analysis, and metadata
|
||||
"""
|
||||
if not self.news_service:
|
||||
return self._create_error_context("NewsService not available")
|
||||
|
||||
try:
|
||||
context = self.news_service.get_company_news_context(
|
||||
symbol, start_date, end_date
|
||||
)
|
||||
return context.model_dump_json(indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting company news for {symbol}: {e}")
|
||||
return self._create_error_context(f"Error fetching company news: {str(e)}")
|
||||
|
||||
@tool
|
||||
def get_global_news(
|
||||
self,
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
categories: Annotated[
|
||||
str, "Comma-separated news categories (e.g. 'economy,markets,finance')"
|
||||
] = "economy,markets",
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve global/macro news context.
|
||||
Args:
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
categories (str): Comma-separated news categories
|
||||
Returns:
|
||||
str: JSON context containing global news articles and sentiment analysis
|
||||
"""
|
||||
if not self.news_service:
|
||||
return self._create_error_context("NewsService not available")
|
||||
|
||||
try:
|
||||
category_list = [c.strip() for c in categories.split(",") if c.strip()]
|
||||
context = self.news_service.get_global_news_context(
|
||||
start_date, end_date, categories=category_list
|
||||
)
|
||||
return context.model_dump_json(indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting global news: {e}")
|
||||
return self._create_error_context(f"Error fetching global news: {str(e)}")
|
||||
|
||||
@tool
|
||||
def get_news_by_query(
|
||||
self,
|
||||
query: Annotated[str, "Search query for news"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve news context for a specific query.
|
||||
Args:
|
||||
query (str): Search query for news
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: JSON context containing news articles and sentiment analysis
|
||||
"""
|
||||
if not self.news_service:
|
||||
return self._create_error_context("NewsService not available")
|
||||
|
||||
try:
|
||||
context = self.news_service.get_context(query, start_date, end_date)
|
||||
return context.model_dump_json(indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting news for query '{query}': {e}")
|
||||
return self._create_error_context(f"Error fetching news: {str(e)}")
|
||||
|
||||
# Legacy compatibility methods (return JSON instead of markdown)
|
||||
@tool
|
||||
def get_YFin_data(
|
||||
self,
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Legacy method: Retrieve market data (now returns JSON context).
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: JSON context containing market data
|
||||
"""
|
||||
return self.get_market_data(symbol, start_date, end_date)
|
||||
|
||||
@tool
|
||||
def get_finnhub_news(
|
||||
self,
|
||||
ticker: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Legacy method: Retrieve company news (now returns JSON context).
|
||||
Args:
|
||||
ticker (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: JSON context containing news data
|
||||
"""
|
||||
return self.get_company_news(ticker, start_date, end_date)
|
||||
|
||||
# Utility methods
|
||||
def _create_error_context(self, error_message: str) -> str:
|
||||
"""Create a JSON error context."""
|
||||
error_context = {
|
||||
"error": True,
|
||||
"message": error_message,
|
||||
"metadata": {
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"source": "toolkit",
|
||||
},
|
||||
}
|
||||
import json
|
||||
|
||||
return json.dumps(error_context, indent=2)
|
||||
|
||||
def get_available_tools(self) -> list:
|
||||
"""Get list of available tools based on configured services."""
|
||||
tools = []
|
||||
|
||||
if self.market_service:
|
||||
tools.extend(
|
||||
[
|
||||
"get_market_data",
|
||||
"get_market_data_with_indicators",
|
||||
"get_YFin_data", # legacy
|
||||
]
|
||||
)
|
||||
|
||||
if self.news_service:
|
||||
tools.extend(
|
||||
[
|
||||
"get_company_news",
|
||||
"get_global_news",
|
||||
"get_news_by_query",
|
||||
"get_finnhub_news", # legacy
|
||||
]
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
def get_toolkit_info(self) -> dict[str, Any]:
|
||||
"""Get information about the toolkit configuration."""
|
||||
return {
|
||||
"toolkit_type": "service_based",
|
||||
"config": {
|
||||
"online_mode": self.config.online_tools,
|
||||
"data_dir": self.config.data_dir,
|
||||
},
|
||||
"services": list(self.services.keys()),
|
||||
"available_tools": self.get_available_tools(),
|
||||
}
|
||||
|
|
@ -1,622 +0,0 @@
|
|||
"""
|
||||
Service-based toolkit for agents using the new JSON context services.
|
||||
Replaces the old markdown-based interface.py with structured service calls.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated, Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from tradingagents.config import TradingAgentsConfig
|
||||
from tradingagents.services.fundamental_data_service import FundamentalDataService
|
||||
from tradingagents.services.insider_data_service import InsiderDataService
|
||||
from tradingagents.services.market_data_service import MarketDataService
|
||||
from tradingagents.services.news_service import NewsService
|
||||
from tradingagents.services.openai_data_service import OpenAIDataService
|
||||
from tradingagents.services.social_media_service import SocialMediaService
|
||||
|
||||
DEFAULT_CONFIG = TradingAgentsConfig()
|
||||
|
||||
|
||||
def create_msg_delete():
|
||||
"""Create message deletion function for Anthropic compatibility."""
|
||||
|
||||
def delete_messages(state):
|
||||
"""Clear messages and add placeholder for Anthropic compatibility"""
|
||||
messages = state["messages"]
|
||||
|
||||
# Remove all messages
|
||||
removal_operations = [RemoveMessage(id=m.id) for m in messages]
|
||||
|
||||
# Add a minimal placeholder message
|
||||
placeholder = HumanMessage(content="Continue")
|
||||
|
||||
return {"messages": removal_operations + [placeholder]}
|
||||
|
||||
return delete_messages
|
||||
|
||||
|
||||
class ServiceToolkit:
|
||||
"""Service-based toolkit using the new JSON context services."""
|
||||
|
||||
def __init__(self, config: TradingAgentsConfig | None = None):
|
||||
"""
|
||||
Initialize the service toolkit.
|
||||
|
||||
Args:
|
||||
config: Configuration object for services
|
||||
"""
|
||||
self._config = config or DEFAULT_CONFIG
|
||||
|
||||
# Services will be lazily initialized
|
||||
self._market_service = None
|
||||
self._news_service = None
|
||||
self._social_service = None
|
||||
self._fundamental_service = None
|
||||
self._insider_service = None
|
||||
self._openai_service = None
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
"""Access the configuration."""
|
||||
return self._config
|
||||
|
||||
def update_config(self, config: TradingAgentsConfig):
|
||||
"""Update the configuration and reset services."""
|
||||
self._config = config
|
||||
# Reset services to force re-initialization with new config
|
||||
self._market_service = None
|
||||
self._news_service = None
|
||||
self._social_service = None
|
||||
self._fundamental_service = None
|
||||
self._insider_service = None
|
||||
self._openai_service = None
|
||||
|
||||
def _get_market_service(self) -> MarketDataService:
|
||||
"""Lazy initialization of market data service."""
|
||||
if self._market_service is None:
|
||||
# This would typically use a service factory/builder
|
||||
# For now, return a basic service
|
||||
from tradingagents.services.builders import create_market_data_service
|
||||
|
||||
self._market_service = create_market_data_service(self._config)
|
||||
return self._market_service
|
||||
|
||||
def _get_news_service(self) -> NewsService:
|
||||
"""Lazy initialization of news service."""
|
||||
if self._news_service is None:
|
||||
from tradingagents.services.builders import create_news_service
|
||||
|
||||
self._news_service = create_news_service(self._config)
|
||||
return self._news_service
|
||||
|
||||
def _get_social_service(self) -> SocialMediaService:
|
||||
"""Lazy initialization of social media service."""
|
||||
if self._social_service is None:
|
||||
from tradingagents.services.builders import create_social_media_service
|
||||
|
||||
self._social_service = create_social_media_service(self._config)
|
||||
return self._social_service
|
||||
|
||||
def _get_fundamental_service(self) -> FundamentalDataService:
|
||||
"""Lazy initialization of fundamental data service."""
|
||||
if self._fundamental_service is None:
|
||||
from tradingagents.services.builders import create_fundamental_data_service
|
||||
|
||||
self._fundamental_service = create_fundamental_data_service(self._config)
|
||||
return self._fundamental_service
|
||||
|
||||
def _get_insider_service(self) -> InsiderDataService:
|
||||
"""Lazy initialization of insider data service."""
|
||||
if self._insider_service is None:
|
||||
from tradingagents.services.builders import create_insider_data_service
|
||||
|
||||
self._insider_service = create_insider_data_service(self._config)
|
||||
return self._insider_service
|
||||
|
||||
def _get_openai_service(self) -> OpenAIDataService:
|
||||
"""Lazy initialization of OpenAI data service."""
|
||||
if self._openai_service is None:
|
||||
from tradingagents.services.builders import create_openai_data_service
|
||||
|
||||
self._openai_service = create_openai_data_service(self._config)
|
||||
return self._openai_service
|
||||
|
||||
def _context_to_string(self, context: Any) -> str:
|
||||
"""Convert a context object to a formatted string for agents."""
|
||||
# For now, convert to JSON string
|
||||
# In the future, we might want more sophisticated formatting
|
||||
return json.dumps(context.model_dump(), indent=2, default=str)
|
||||
|
||||
def _calculate_date_range(
|
||||
self, curr_date: str, look_back_days: int
|
||||
) -> tuple[str, str]:
|
||||
"""Calculate start and end dates from current date and lookback days."""
|
||||
end_date = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
start_date = end_date - timedelta(days=look_back_days)
|
||||
return start_date.strftime("%Y-%m-%d"), end_date.strftime("%Y-%m-%d")
|
||||
|
||||
@tool
|
||||
def get_reddit_news(
|
||||
self,
|
||||
curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve global news from Reddit within a specified time frame.
|
||||
Args:
|
||||
curr_date (str): Date you want to get news for in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A JSON-formatted context containing the latest global news from Reddit.
|
||||
"""
|
||||
start_date, end_date = self._calculate_date_range(curr_date, 7)
|
||||
|
||||
social_service = self._get_social_service()
|
||||
context = social_service.get_global_trends(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
subreddits=["news", "worldnews", "Economics"],
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_finnhub_news(
|
||||
self,
|
||||
ticker: Annotated[str, "Search query of a company, e.g. 'AAPL, TSM, etc."],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the latest news about a given stock from Finnhub within a date range.
|
||||
Args:
|
||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A JSON-formatted context containing news about the company.
|
||||
"""
|
||||
news_service = self._get_news_service()
|
||||
context = news_service.get_company_news_context(
|
||||
symbol=ticker, start_date=start_date, end_date=end_date, sources=["finnhub"]
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_reddit_stock_info(
|
||||
self,
|
||||
ticker: Annotated[str, "Ticker of a company. e.g. AAPL, TSM"],
|
||||
curr_date: Annotated[str, "Current date you want to get news for"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the latest news about a given stock from Reddit.
|
||||
Args:
|
||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
||||
curr_date (str): current date in yyyy-mm-dd format to get news for
|
||||
Returns:
|
||||
str: A JSON-formatted context containing the latest news about the company.
|
||||
"""
|
||||
start_date, end_date = self._calculate_date_range(curr_date, 7)
|
||||
|
||||
social_service = self._get_social_service()
|
||||
context = social_service.get_company_social_context(
|
||||
symbol=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
subreddits=["investing", "stocks", "wallstreetbets"],
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_YFin_data(
|
||||
self,
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A JSON-formatted context containing stock price data and technical indicators.
|
||||
"""
|
||||
market_service = self._get_market_service()
|
||||
context = market_service.get_context(
|
||||
symbol=symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
force_refresh=False, # Use local-first strategy
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_YFin_data_online(
|
||||
self,
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve fresh stock price data for a given ticker symbol from Yahoo Finance.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A JSON-formatted context containing fresh stock price data.
|
||||
"""
|
||||
market_service = self._get_market_service()
|
||||
context = market_service.get_context(
|
||||
symbol=symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
force_refresh=True, # Force fresh data
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_stockstats_indicators_report(
|
||||
self,
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
indicator: Annotated[
|
||||
str, "technical indicator to get the analysis and report of"
|
||||
],
|
||||
curr_date: Annotated[
|
||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
||||
],
|
||||
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve stock stats indicators for a given ticker symbol and indicator.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
indicator (str): Technical indicator to get the analysis and report of
|
||||
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
|
||||
look_back_days (int): How many days to look back, default is 30
|
||||
Returns:
|
||||
str: A JSON-formatted context containing technical indicators.
|
||||
"""
|
||||
start_date, end_date = self._calculate_date_range(curr_date, look_back_days)
|
||||
|
||||
market_service = self._get_market_service()
|
||||
context = market_service.get_context(
|
||||
symbol=symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indicators=[indicator],
|
||||
force_refresh=False,
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_stockstats_indicators_report_online(
|
||||
self,
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
indicator: Annotated[
|
||||
str, "technical indicator to get the analysis and report of"
|
||||
],
|
||||
curr_date: Annotated[
|
||||
str, "The current trading date you are trading on, YYYY-mm-dd"
|
||||
],
|
||||
look_back_days: Annotated[int, "how many days to look back"] = 30,
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve fresh stock stats indicators for a given ticker symbol and indicator.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
indicator (str): Technical indicator to get the analysis and report of
|
||||
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
|
||||
look_back_days (int): How many days to look back, default is 30
|
||||
Returns:
|
||||
str: A JSON-formatted context containing fresh technical indicators.
|
||||
"""
|
||||
start_date, end_date = self._calculate_date_range(curr_date, look_back_days)
|
||||
|
||||
market_service = self._get_market_service()
|
||||
context = market_service.get_context(
|
||||
symbol=symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indicators=[indicator],
|
||||
force_refresh=True,
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_finnhub_company_insider_sentiment(
|
||||
self,
|
||||
ticker: Annotated[str, "ticker symbol for the company"],
|
||||
curr_date: Annotated[str, "current date of you are trading at, yyyy-mm-dd"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve insider sentiment information about a company for the past 30 days.
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
str: A JSON-formatted context with insider trading sentiment analysis.
|
||||
"""
|
||||
start_date, end_date = self._calculate_date_range(curr_date, 30)
|
||||
|
||||
insider_service = self._get_insider_service()
|
||||
context = insider_service.get_insider_context(
|
||||
symbol=ticker, start_date=start_date, end_date=end_date
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_finnhub_company_insider_transactions(
|
||||
self,
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve insider transaction information about a company for the past 30 days.
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
str: A JSON-formatted context with insider transaction details.
|
||||
"""
|
||||
start_date, end_date = self._calculate_date_range(curr_date, 30)
|
||||
|
||||
insider_service = self._get_insider_service()
|
||||
context = insider_service.get_insider_context(
|
||||
symbol=ticker, start_date=start_date, end_date=end_date
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_simfin_balance_sheet(
|
||||
self,
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
freq: Annotated[str, "reporting frequency: annual/quarterly"],
|
||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the most recent balance sheet of a company.
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
freq (str): reporting frequency: annual / quarterly
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
str: A JSON-formatted context with the company's balance sheet.
|
||||
"""
|
||||
# Use a reasonable date range for fundamental data
|
||||
start_date = (
|
||||
datetime.strptime(curr_date, "%Y-%m-%d") - timedelta(days=365)
|
||||
).strftime("%Y-%m-%d")
|
||||
|
||||
fundamental_service = self._get_fundamental_service()
|
||||
context = fundamental_service.get_fundamental_context(
|
||||
symbol=ticker, start_date=start_date, end_date=curr_date, frequency=freq
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_simfin_cashflow(
|
||||
self,
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
freq: Annotated[str, "reporting frequency: annual/quarterly"],
|
||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the most recent cash flow statement of a company.
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
freq (str): reporting frequency: annual / quarterly
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
str: A JSON-formatted context with the company's cash flow statement.
|
||||
"""
|
||||
start_date = (
|
||||
datetime.strptime(curr_date, "%Y-%m-%d") - timedelta(days=365)
|
||||
).strftime("%Y-%m-%d")
|
||||
|
||||
fundamental_service = self._get_fundamental_service()
|
||||
context = fundamental_service.get_fundamental_context(
|
||||
symbol=ticker, start_date=start_date, end_date=curr_date, frequency=freq
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_simfin_income_stmt(
|
||||
self,
|
||||
ticker: Annotated[str, "ticker symbol"],
|
||||
freq: Annotated[str, "reporting frequency: annual/quarterly"],
|
||||
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the most recent income statement of a company.
|
||||
Args:
|
||||
ticker (str): ticker symbol of the company
|
||||
freq (str): reporting frequency: annual / quarterly
|
||||
curr_date (str): current date you are trading at, yyyy-mm-dd
|
||||
Returns:
|
||||
str: A JSON-formatted context with the company's income statement.
|
||||
"""
|
||||
start_date = (
|
||||
datetime.strptime(curr_date, "%Y-%m-%d") - timedelta(days=365)
|
||||
).strftime("%Y-%m-%d")
|
||||
|
||||
fundamental_service = self._get_fundamental_service()
|
||||
context = fundamental_service.get_fundamental_context(
|
||||
symbol=ticker, start_date=start_date, end_date=curr_date, frequency=freq
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_google_news(
|
||||
self,
|
||||
query: Annotated[str, "Query to search with"],
|
||||
curr_date: Annotated[str, "Curr date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve the latest news from Google News based on a query and date range.
|
||||
Args:
|
||||
query (str): Query to search with
|
||||
curr_date (str): Current date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A JSON-formatted context containing the latest news from Google News.
|
||||
"""
|
||||
start_date, end_date = self._calculate_date_range(curr_date, 7)
|
||||
|
||||
news_service = self._get_news_service()
|
||||
context = news_service.get_context(
|
||||
query=query, start_date=start_date, end_date=end_date, sources=["google"]
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_stock_news_openai(
|
||||
self,
|
||||
ticker: Annotated[str, "the company's ticker"],
|
||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve AI-generated news analysis about a given stock.
|
||||
Args:
|
||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
||||
curr_date (str): Current date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A JSON-formatted context with AI news analysis.
|
||||
"""
|
||||
# First get the news context
|
||||
start_date, end_date = self._calculate_date_range(curr_date, 7)
|
||||
news_service = self._get_news_service()
|
||||
news_context = news_service.get_company_news_context(
|
||||
symbol=ticker, start_date=start_date, end_date=end_date
|
||||
)
|
||||
|
||||
# Then get AI analysis of the news
|
||||
openai_service = self._get_openai_service()
|
||||
context = openai_service.get_news_impact_analysis(
|
||||
symbol=ticker, news_context=self._context_to_string(news_context)
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_global_news_openai(
|
||||
self,
|
||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve AI-generated macroeconomic news analysis.
|
||||
Args:
|
||||
curr_date (str): Current date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A JSON-formatted context with AI macroeconomic analysis.
|
||||
"""
|
||||
start_date, end_date = self._calculate_date_range(curr_date, 7)
|
||||
|
||||
# Get global news context
|
||||
news_service = self._get_news_service()
|
||||
news_context = news_service.get_global_news_context(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
categories=["economy", "markets", "finance"],
|
||||
)
|
||||
|
||||
# Get AI analysis
|
||||
openai_service = self._get_openai_service()
|
||||
context = openai_service.get_news_impact_analysis(
|
||||
symbol="GLOBAL", # Use a placeholder for global analysis
|
||||
news_context=self._context_to_string(news_context),
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
@tool
|
||||
def get_fundamentals_openai(
|
||||
self,
|
||||
ticker: Annotated[str, "the company's ticker"],
|
||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve AI-generated fundamental analysis about a given stock.
|
||||
Args:
|
||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
||||
curr_date (str): Current date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A JSON-formatted context with AI fundamental analysis.
|
||||
"""
|
||||
# Get fundamental data context
|
||||
start_date = (
|
||||
datetime.strptime(curr_date, "%Y-%m-%d") - timedelta(days=365)
|
||||
).strftime("%Y-%m-%d")
|
||||
fundamental_service = self._get_fundamental_service()
|
||||
fundamental_context = fundamental_service.get_fundamental_context(
|
||||
symbol=ticker,
|
||||
start_date=start_date,
|
||||
end_date=curr_date,
|
||||
frequency="quarterly",
|
||||
)
|
||||
|
||||
# Get market data for additional context
|
||||
market_start = (
|
||||
datetime.strptime(curr_date, "%Y-%m-%d") - timedelta(days=30)
|
||||
).strftime("%Y-%m-%d")
|
||||
market_service = self._get_market_service()
|
||||
market_context = market_service.get_context(
|
||||
symbol=ticker, start_date=market_start, end_date=curr_date
|
||||
)
|
||||
|
||||
# Combine contexts for AI analysis
|
||||
combined_context = {
|
||||
"fundamental_data": fundamental_context.model_dump(),
|
||||
"market_data": market_context.model_dump(),
|
||||
}
|
||||
|
||||
# Get AI analysis
|
||||
openai_service = self._get_openai_service()
|
||||
context = openai_service.get_market_sentiment_analysis(
|
||||
symbol=ticker, market_data_context=json.dumps(combined_context, default=str)
|
||||
)
|
||||
|
||||
return self._context_to_string(context)
|
||||
|
||||
|
||||
# Create a default instance for backward compatibility
|
||||
default_toolkit = ServiceToolkit()
|
||||
|
||||
# Export individual tools for use in agents
|
||||
get_reddit_news = default_toolkit.get_reddit_news
|
||||
get_finnhub_news = default_toolkit.get_finnhub_news
|
||||
get_reddit_stock_info = default_toolkit.get_reddit_stock_info
|
||||
get_YFin_data = default_toolkit.get_YFin_data
|
||||
get_YFin_data_online = default_toolkit.get_YFin_data_online
|
||||
get_stockstats_indicators_report = default_toolkit.get_stockstats_indicators_report
|
||||
get_stockstats_indicators_report_online = (
|
||||
default_toolkit.get_stockstats_indicators_report_online
|
||||
)
|
||||
get_finnhub_company_insider_sentiment = (
|
||||
default_toolkit.get_finnhub_company_insider_sentiment
|
||||
)
|
||||
get_finnhub_company_insider_transactions = (
|
||||
default_toolkit.get_finnhub_company_insider_transactions
|
||||
)
|
||||
get_simfin_balance_sheet = default_toolkit.get_simfin_balance_sheet
|
||||
get_simfin_cashflow = default_toolkit.get_simfin_cashflow
|
||||
get_simfin_income_stmt = default_toolkit.get_simfin_income_stmt
|
||||
get_google_news = default_toolkit.get_google_news
|
||||
get_stock_news_openai = default_toolkit.get_stock_news_openai
|
||||
get_global_news_openai = default_toolkit.get_global_news_openai
|
||||
get_fundamentals_openai = default_toolkit.get_fundamentals_openai
|
||||
|
|
@ -1,267 +0,0 @@
|
|||
"""
|
||||
Updated Toolkit class using Service/Client/Repository architecture with JSON context.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Annotated, Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, RemoveMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from tradingagents.config import TradingAgentsConfig
|
||||
from tradingagents.services.builders import build_toolkit_services
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tradingagents.services.market_data_service import MarketDataService
|
||||
from tradingagents.services.news_service import NewsService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_CONFIG = TradingAgentsConfig()
|
||||
|
||||
|
||||
def create_msg_delete():
|
||||
"""Create message deletion function for agents."""
|
||||
|
||||
def delete_messages(state):
|
||||
"""Clear messages and add placeholder for Anthropic compatibility"""
|
||||
messages = state["messages"]
|
||||
|
||||
# Remove all messages
|
||||
removal_operations = [RemoveMessage(id=m.id) for m in messages]
|
||||
|
||||
# Add a minimal placeholder message
|
||||
placeholder = HumanMessage(content="Continue")
|
||||
|
||||
return {"messages": removal_operations + [placeholder]}
|
||||
|
||||
return delete_messages
|
||||
|
||||
|
||||
class Toolkit:
|
||||
"""
|
||||
Toolkit class that uses services to provide JSON context to agents.
|
||||
|
||||
This replaces the old interface.py approach with structured Pydantic models
|
||||
that agents can process more dynamically.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TradingAgentsConfig | None = None,
|
||||
services: dict[str, Any] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize Toolkit with services.
|
||||
|
||||
Args:
|
||||
config: TradingAgents configuration
|
||||
services: Pre-built services dict, or None to build from config
|
||||
"""
|
||||
self.config = config or DEFAULT_CONFIG
|
||||
|
||||
if services:
|
||||
self.services = services
|
||||
else:
|
||||
logger.info("Building services from config")
|
||||
self.services = build_toolkit_services(self.config)
|
||||
|
||||
# Set up individual services
|
||||
self.market_service: MarketDataService | None = self.services.get("market_data")
|
||||
self.news_service: NewsService | None = self.services.get("news")
|
||||
|
||||
logger.info(f"Toolkit initialized with {len(self.services)} services")
|
||||
|
||||
# Create tool methods as static methods with service access via closure
|
||||
def _create_market_data_tool(self):
|
||||
"""Create market data tool with service access."""
|
||||
market_service = self.market_service
|
||||
|
||||
@tool
|
||||
def get_market_data(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve market data context for a given ticker symbol.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: JSON context containing market data with price data and metadata
|
||||
"""
|
||||
if not market_service:
|
||||
return _create_error_context("MarketDataService not available")
|
||||
|
||||
try:
|
||||
context = market_service.get_price_context(symbol, start_date, end_date)
|
||||
return context.model_dump_json(indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting market data for {symbol}: {e}")
|
||||
return _create_error_context(f"Error fetching market data: {str(e)}")
|
||||
|
||||
return get_market_data
|
||||
|
||||
def _create_market_indicators_tool(self):
|
||||
"""Create market data with indicators tool."""
|
||||
market_service = self.market_service
|
||||
|
||||
@tool
|
||||
def get_market_data_with_indicators(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
indicators: Annotated[
|
||||
str, "Comma-separated list of indicators (e.g. 'rsi,macd,close_50_sma')"
|
||||
] = "rsi,macd",
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve market data context with technical indicators.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
indicators (str): Comma-separated indicators
|
||||
Returns:
|
||||
str: JSON context containing market data with technical indicators
|
||||
"""
|
||||
if not market_service:
|
||||
return _create_error_context("MarketDataService not available")
|
||||
|
||||
try:
|
||||
indicator_list = [i.strip() for i in indicators.split(",") if i.strip()]
|
||||
context = market_service.get_context(
|
||||
symbol, start_date, end_date, indicators=indicator_list
|
||||
)
|
||||
return context.model_dump_json(indent=2)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting market data with indicators for {symbol}: {e}"
|
||||
)
|
||||
return _create_error_context(
|
||||
f"Error fetching market data with indicators: {str(e)}"
|
||||
)
|
||||
|
||||
return get_market_data_with_indicators
|
||||
|
||||
def _create_company_news_tool(self):
|
||||
"""Create company news tool."""
|
||||
news_service = self.news_service
|
||||
|
||||
@tool
|
||||
def get_company_news(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve news context for a specific company.
|
||||
Args:
|
||||
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: JSON context containing news articles, sentiment analysis, and metadata
|
||||
"""
|
||||
if not news_service:
|
||||
return _create_error_context("NewsService not available")
|
||||
|
||||
try:
|
||||
context = news_service.get_company_news_context(
|
||||
symbol, start_date, end_date
|
||||
)
|
||||
return context.model_dump_json(indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting company news for {symbol}: {e}")
|
||||
return _create_error_context(f"Error fetching company news: {str(e)}")
|
||||
|
||||
return get_company_news
|
||||
|
||||
def _create_global_news_tool(self):
|
||||
"""Create global news tool."""
|
||||
news_service = self.news_service
|
||||
|
||||
@tool
|
||||
def get_global_news(
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
categories: Annotated[
|
||||
str, "Comma-separated news categories (e.g. 'economy,markets,finance')"
|
||||
] = "economy,markets",
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve global/macro news context.
|
||||
Args:
|
||||
start_date (str): Start date in yyyy-mm-dd format
|
||||
end_date (str): End date in yyyy-mm-dd format
|
||||
categories (str): Comma-separated news categories
|
||||
Returns:
|
||||
str: JSON context containing global news articles and sentiment analysis
|
||||
"""
|
||||
if not news_service:
|
||||
return _create_error_context("NewsService not available")
|
||||
|
||||
try:
|
||||
category_list = [c.strip() for c in categories.split(",") if c.strip()]
|
||||
context = news_service.get_global_news_context(
|
||||
start_date, end_date, categories=category_list
|
||||
)
|
||||
return context.model_dump_json(indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting global news: {e}")
|
||||
return _create_error_context(f"Error fetching global news: {str(e)}")
|
||||
|
||||
return get_global_news
|
||||
|
||||
def get_tools(self):
|
||||
"""Get all available tools as LangChain tools."""
|
||||
tools = []
|
||||
|
||||
if self.market_service:
|
||||
tools.append(self._create_market_data_tool())
|
||||
tools.append(self._create_market_indicators_tool())
|
||||
|
||||
if self.news_service:
|
||||
tools.append(self._create_company_news_tool())
|
||||
tools.append(self._create_global_news_tool())
|
||||
|
||||
return tools
|
||||
|
||||
def get_available_tools(self) -> list:
|
||||
"""Get list of available tool names based on configured services."""
|
||||
tools = []
|
||||
|
||||
if self.market_service:
|
||||
tools.extend(["get_market_data", "get_market_data_with_indicators"])
|
||||
|
||||
if self.news_service:
|
||||
tools.extend(["get_company_news", "get_global_news"])
|
||||
|
||||
return tools
|
||||
|
||||
def get_toolkit_info(self) -> dict[str, Any]:
|
||||
"""Get information about the toolkit configuration."""
|
||||
return {
|
||||
"toolkit_type": "service_based",
|
||||
"config": {
|
||||
"online_mode": self.config.online_tools,
|
||||
"data_dir": self.config.data_dir,
|
||||
},
|
||||
"services": list(self.services.keys()),
|
||||
"available_tools": self.get_available_tools(),
|
||||
}
|
||||
|
||||
|
||||
def _create_error_context(error_message: str) -> str:
|
||||
"""Create a JSON error context."""
|
||||
import json
|
||||
|
||||
error_context = {
|
||||
"error": True,
|
||||
"message": error_message,
|
||||
"metadata": {"created_at": datetime.utcnow().isoformat(), "source": "toolkit"},
|
||||
}
|
||||
return json.dumps(error_context, indent=2)
|
||||
|
|
@ -2,18 +2,6 @@
|
|||
Client classes for live data access in TradingAgents.
|
||||
"""
|
||||
|
||||
# Re-export existing clients from dataflows
|
||||
from tradingagents.dataflows.reddit_utils import RedditClient
|
||||
|
||||
from .base import BaseClient
|
||||
from .finnhub_client import FinnhubClient
|
||||
from .google_news_client import GoogleNewsClient
|
||||
from .yfinance_client import YFinanceClient
|
||||
|
||||
__all__ = [
|
||||
"BaseClient",
|
||||
"YFinanceClient",
|
||||
"GoogleNewsClient",
|
||||
"FinnhubClient",
|
||||
"RedditClient",
|
||||
]
|
||||
__all__ = ["BaseClient"]
|
||||
|
|
|
|||
|
|
@ -1,100 +1,23 @@
|
|||
"""
|
||||
Base client abstraction for live data access.
|
||||
Base client interface for TradingAgents data sources.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseClient(ABC):
|
||||
"""
|
||||
Base class for all data clients that access live APIs.
|
||||
|
||||
Provides common interface for different data sources while allowing
|
||||
each client to implement its specific data fetching logic.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize client with configuration."""
|
||||
self.config = kwargs
|
||||
"""Abstract base class for all data clients."""
|
||||
|
||||
@abstractmethod
|
||||
def test_connection(self) -> bool:
|
||||
def get_data(self, **kwargs) -> dict[str, Any]:
|
||||
"""
|
||||
Test if the client can connect to its data source.
|
||||
|
||||
Returns:
|
||||
bool: True if connection is successful, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_data(self, *args, **kwargs) -> dict[str, Any]:
|
||||
"""
|
||||
Get data from the client's data source.
|
||||
Get data from the client source.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments
|
||||
**kwargs: Client-specific parameters
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Raw data from the source
|
||||
dict: Data dictionary with standardized structure
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_available_symbols(self) -> list[str]:
|
||||
"""
|
||||
Get list of available symbols/tickers from this data source.
|
||||
|
||||
Returns:
|
||||
List[str]: Available symbols, empty list if not supported
|
||||
"""
|
||||
return []
|
||||
|
||||
def get_data_range(
|
||||
self, start_date: str, end_date: str, **kwargs
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get data for a specific date range.
|
||||
|
||||
Args:
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
**kwargs: Additional client-specific parameters
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Data for the specified range
|
||||
"""
|
||||
return self.get_data(start_date=start_date, end_date=end_date, **kwargs)
|
||||
|
||||
def validate_date_range(self, start_date: str, end_date: str) -> bool:
|
||||
"""
|
||||
Validate that the date range is acceptable.
|
||||
|
||||
Args:
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
|
||||
Returns:
|
||||
bool: True if date range is valid
|
||||
"""
|
||||
try:
|
||||
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
return start_dt <= end_dt <= datetime.now()
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def get_client_info(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get information about this client.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Client metadata
|
||||
"""
|
||||
return {
|
||||
"client_type": self.__class__.__name__,
|
||||
"supports_symbols": len(self.get_available_symbols()) > 0,
|
||||
"config": self.config,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,32 +0,0 @@
|
|||
"""
|
||||
Pytest configuration for FinnhubClient tests with VCR.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_config():
|
||||
"""Configure VCR for recording/replaying HTTP interactions."""
|
||||
return {
|
||||
# Don't record the API key in cassettes
|
||||
"filter_headers": ["X-Finnhub-Token", "Authorization"],
|
||||
# Record once, then replay from cassettes
|
||||
"record_mode": "once",
|
||||
# Match requests on URI and method
|
||||
"match_on": ["uri", "method"],
|
||||
# Decode compressed responses for better readability
|
||||
"decode_compressed_response": True,
|
||||
# Store cassettes in the cassettes subdirectory
|
||||
"cassette_library_dir": "tradingagents/clients/cassettes",
|
||||
# Ignore localhost requests
|
||||
"ignore_localhost": True,
|
||||
# Custom serializer for better readability
|
||||
"serializer": "yaml",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def vcr_cassette_dir(tmp_path_factory):
|
||||
"""Create temporary directory for VCR cassettes during testing."""
|
||||
return tmp_path_factory.mktemp("cassettes")
|
||||
|
|
@ -42,6 +42,10 @@ class TradingAgentsConfig:
|
|||
# Tool settings
|
||||
online_tools: bool = True
|
||||
|
||||
# Data retrieval settings
|
||||
default_lookback_days: int = 30
|
||||
default_ta_lookback_days: int = 30
|
||||
|
||||
def __post_init__(self):
|
||||
"""Set computed fields after initialization."""
|
||||
self.data_cache_dir = os.path.join(self.project_dir, "dataflows/data_cache")
|
||||
|
|
@ -79,6 +83,8 @@ class TradingAgentsConfig:
|
|||
max_risk_discuss_rounds=int(os.getenv("MAX_RISK_DISCUSS_ROUNDS", "1")),
|
||||
max_recur_limit=int(os.getenv("MAX_RECUR_LIMIT", "100")),
|
||||
online_tools=os.getenv("ONLINE_TOOLS", "true").lower() == "true",
|
||||
default_lookback_days=int(os.getenv("DEFAULT_LOOKBACK_DAYS", "30")),
|
||||
default_ta_lookback_days=int(os.getenv("DEFAULT_TA_LOOKBACK_DAYS", "30")),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
|
|
@ -96,6 +102,8 @@ class TradingAgentsConfig:
|
|||
"max_risk_discuss_rounds": self.max_risk_discuss_rounds,
|
||||
"max_recur_limit": self.max_recur_limit,
|
||||
"online_tools": self.online_tools,
|
||||
"default_lookback_days": self.default_lookback_days,
|
||||
"default_ta_lookback_days": self.default_ta_lookback_days,
|
||||
}
|
||||
|
||||
def copy(self) -> "TradingAgentsConfig":
|
||||
|
|
@ -112,6 +120,8 @@ class TradingAgentsConfig:
|
|||
max_risk_discuss_rounds=self.max_risk_discuss_rounds,
|
||||
max_recur_limit=self.max_recur_limit,
|
||||
online_tools=self.online_tools,
|
||||
default_lookback_days=self.default_lookback_days,
|
||||
default_ta_lookback_days=self.default_ta_lookback_days,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,120 +0,0 @@
|
|||
import random
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_result,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
|
||||
def is_rate_limited(response):
|
||||
"""Check if the response indicates rate limiting (status code 429)"""
|
||||
return response.status_code == 429
|
||||
|
||||
|
||||
@retry(
|
||||
retry=(retry_if_result(is_rate_limited)),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
stop=stop_after_attempt(5),
|
||||
)
|
||||
def make_request(url, headers):
|
||||
"""Make a request with retry logic for rate limiting"""
|
||||
# Random delay before each request to avoid detection
|
||||
time.sleep(random.uniform(2, 6))
|
||||
response = requests.get(url, headers=headers)
|
||||
return response
|
||||
|
||||
|
||||
def getNewsData(query, start_date, end_date):
|
||||
"""
|
||||
Scrape Google News search results for a given query and date range.
|
||||
query: str - search query
|
||||
start_date: str - start date in the format yyyy-mm-dd or mm/dd/yyyy
|
||||
end_date: str - end date in the format yyyy-mm-dd or mm/dd/yyyy
|
||||
"""
|
||||
if "-" in start_date:
|
||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
start_date = start_date.strftime("%m/%d/%Y")
|
||||
if "-" in end_date:
|
||||
end_date = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
end_date = end_date.strftime("%m/%d/%Y")
|
||||
|
||||
headers = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/101.0.4951.54 Safari/537.36"
|
||||
)
|
||||
}
|
||||
|
||||
news_results = []
|
||||
page = 0
|
||||
while True:
|
||||
offset = page * 10
|
||||
url = (
|
||||
f"https://www.google.com/search?q={query}"
|
||||
f"&tbs=cdr:1,cd_min:{start_date},cd_max:{end_date}"
|
||||
f"&tbm=nws&start={offset}"
|
||||
)
|
||||
|
||||
try:
|
||||
response = make_request(url, headers)
|
||||
soup = BeautifulSoup(response.content, "html.parser")
|
||||
results_on_page = soup.select("div.SoaBEf")
|
||||
|
||||
if not results_on_page:
|
||||
break # No more results found
|
||||
|
||||
for el in results_on_page:
|
||||
try:
|
||||
link_elem = el.find("a")
|
||||
# Handle BeautifulSoup element access safely
|
||||
if link_elem:
|
||||
link = getattr(link_elem, "attrs", {}).get("href", "")
|
||||
else:
|
||||
link = ""
|
||||
|
||||
title_elem = el.select_one("div.MBeuO")
|
||||
title = title_elem.get_text() if title_elem else ""
|
||||
|
||||
snippet_elem = el.select_one(".GI74Re")
|
||||
snippet = snippet_elem.get_text() if snippet_elem else ""
|
||||
|
||||
date_elem = el.select_one(".LfVVr")
|
||||
date = date_elem.get_text() if date_elem else ""
|
||||
|
||||
source_elem = el.select_one(".NUnG9d span")
|
||||
source = source_elem.get_text() if source_elem else ""
|
||||
news_results.append(
|
||||
{
|
||||
"link": link,
|
||||
"title": title,
|
||||
"snippet": snippet,
|
||||
"date": date,
|
||||
"source": source,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error processing result: {e}")
|
||||
# If one of the fields is not found, skip this result
|
||||
continue
|
||||
|
||||
# Update the progress bar with the current count of results scraped
|
||||
|
||||
# Check for the "Next" link (pagination)
|
||||
next_link = soup.find("a", id="pnnext")
|
||||
if not next_link:
|
||||
break
|
||||
|
||||
page += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed after multiple retries: {e}")
|
||||
break
|
||||
|
||||
return news_results
|
||||
|
|
@ -1,383 +0,0 @@
|
|||
"""
|
||||
Reddit API integration for social sentiment analysis.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import praw
|
||||
|
||||
# from .api_clients import RateLimiter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Extended ticker to company mapping
|
||||
ticker_to_company = {
|
||||
"AAPL": "Apple",
|
||||
"MSFT": "Microsoft",
|
||||
"GOOGL": "Google",
|
||||
"AMZN": "Amazon",
|
||||
"TSLA": "Tesla",
|
||||
"NVDA": "Nvidia",
|
||||
"TSM": "Taiwan Semiconductor Manufacturing Company OR TSMC",
|
||||
"JPM": "JPMorgan Chase OR JP Morgan",
|
||||
"JNJ": "Johnson & Johnson OR JNJ",
|
||||
"V": "Visa",
|
||||
"WMT": "Walmart",
|
||||
"META": "Meta OR Facebook",
|
||||
"AMD": "AMD",
|
||||
"INTC": "Intel",
|
||||
"QCOM": "Qualcomm",
|
||||
"BABA": "Alibaba",
|
||||
"ADBE": "Adobe",
|
||||
"NFLX": "Netflix",
|
||||
"CRM": "Salesforce",
|
||||
"PYPL": "PayPal",
|
||||
"PLTR": "Palantir",
|
||||
"MU": "Micron",
|
||||
"SQ": "Block OR Square",
|
||||
"ZM": "Zoom",
|
||||
"CSCO": "Cisco",
|
||||
"SHOP": "Shopify",
|
||||
"ORCL": "Oracle",
|
||||
"X": "Twitter OR X",
|
||||
"SPOT": "Spotify",
|
||||
"AVGO": "Broadcom",
|
||||
"ASML": "ASML",
|
||||
"TWLO": "Twilio",
|
||||
"SNAP": "Snap Inc.",
|
||||
"TEAM": "Atlassian",
|
||||
"SQSP": "Squarespace",
|
||||
"UBER": "Uber",
|
||||
"ROKU": "Roku",
|
||||
"PINS": "Pinterest",
|
||||
# Additional popular tickers
|
||||
"SPY": "SPDR S&P 500 ETF",
|
||||
"QQQ": "Invesco QQQ Trust",
|
||||
"GME": "GameStop",
|
||||
"AMC": "AMC Entertainment",
|
||||
"BB": "BlackBerry",
|
||||
"NOK": "Nokia",
|
||||
"COIN": "Coinbase",
|
||||
"HOOD": "Robinhood",
|
||||
"RBLX": "Roblox",
|
||||
"DKNG": "DraftKings",
|
||||
"PENN": "Penn Entertainment",
|
||||
"SOFI": "SoFi Technologies",
|
||||
}
|
||||
|
||||
|
||||
class RedditClient:
|
||||
"""Client for Reddit API with rate limiting and caching."""
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, user_agent: str):
|
||||
"""
|
||||
Initialize Reddit client.
|
||||
|
||||
Args:
|
||||
client_id: Reddit application client ID
|
||||
client_secret: Reddit application client secret
|
||||
user_agent: User agent string for API requests
|
||||
"""
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.user_agent = user_agent
|
||||
|
||||
# Reddit allows 100 requests per minute for script applications
|
||||
# self.rate_limiter = RateLimiter(100, 60)
|
||||
|
||||
# Initialize Reddit instance
|
||||
self.reddit = praw.Reddit(
|
||||
client_id=client_id, client_secret=client_secret, user_agent=user_agent
|
||||
)
|
||||
|
||||
# Default subreddits for different categories
|
||||
self.subreddits = {
|
||||
"investing": ["investing", "SecurityAnalysis", "ValueInvesting", "stocks"],
|
||||
"trading": ["wallstreetbets", "StockMarket", "pennystocks", "options"],
|
||||
"global_news": ["news", "worldnews", "Economics", "business"],
|
||||
"company_news": ["investing", "stocks", "StockMarket", "SecurityAnalysis"],
|
||||
}
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""Test if the Reddit API connection is working."""
|
||||
try:
|
||||
# Test by fetching user info (read-only operation)
|
||||
_ = self.reddit.user.me()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Reddit connection test failed: {e}")
|
||||
return False
|
||||
|
||||
def search_posts(
|
||||
self,
|
||||
query: str,
|
||||
subreddit_names: list[str],
|
||||
limit: int = 25,
|
||||
time_filter: str = "week",
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Search for posts across multiple subreddits.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
subreddit_names: List of subreddit names to search
|
||||
limit: Maximum number of posts per subreddit
|
||||
time_filter: Time filter ('day', 'week', 'month', 'year', 'all')
|
||||
|
||||
Returns:
|
||||
List of post dictionaries
|
||||
"""
|
||||
posts = []
|
||||
|
||||
for subreddit_name in subreddit_names:
|
||||
try:
|
||||
# self.rate_limiter.wait_if_needed()
|
||||
|
||||
subreddit = self.reddit.subreddit(subreddit_name)
|
||||
|
||||
# Search posts in the subreddit
|
||||
search_results = subreddit.search(
|
||||
query=query, sort="relevance", time_filter=time_filter, limit=limit
|
||||
)
|
||||
|
||||
for submission in search_results:
|
||||
post_data = {
|
||||
"title": submission.title,
|
||||
"content": submission.selftext,
|
||||
"url": submission.url,
|
||||
"upvotes": submission.ups,
|
||||
"score": submission.score,
|
||||
"num_comments": submission.num_comments,
|
||||
"created_utc": submission.created_utc,
|
||||
"subreddit": subreddit_name,
|
||||
"author": str(submission.author)
|
||||
if submission.author
|
||||
else "[deleted]",
|
||||
}
|
||||
posts.append(post_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching subreddit {subreddit_name}: {e}")
|
||||
continue
|
||||
|
||||
# Sort by score (upvotes - downvotes) descending
|
||||
posts.sort(key=lambda x: x["score"], reverse=True)
|
||||
return posts
|
||||
|
||||
def get_top_posts(
|
||||
self, subreddit_names: list[str], limit: int = 25, time_filter: str = "week"
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get top posts from multiple subreddits.
|
||||
|
||||
Args:
|
||||
subreddit_names: List of subreddit names
|
||||
limit: Maximum number of posts per subreddit
|
||||
time_filter: Time filter ('day', 'week', 'month', 'year', 'all')
|
||||
|
||||
Returns:
|
||||
List of post dictionaries
|
||||
"""
|
||||
posts = []
|
||||
|
||||
for subreddit_name in subreddit_names:
|
||||
try:
|
||||
# self.rate_limiter.wait_if_needed()
|
||||
|
||||
subreddit = self.reddit.subreddit(subreddit_name)
|
||||
|
||||
# Get top posts from the subreddit
|
||||
top_posts = subreddit.top(time_filter=time_filter, limit=limit)
|
||||
|
||||
for submission in top_posts:
|
||||
post_data = {
|
||||
"title": submission.title,
|
||||
"content": submission.selftext,
|
||||
"url": submission.url,
|
||||
"upvotes": submission.ups,
|
||||
"score": submission.score,
|
||||
"num_comments": submission.num_comments,
|
||||
"created_utc": submission.created_utc,
|
||||
"subreddit": subreddit_name,
|
||||
"author": str(submission.author)
|
||||
if submission.author
|
||||
else "[deleted]",
|
||||
}
|
||||
posts.append(post_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching top posts from {subreddit_name}: {e}")
|
||||
continue
|
||||
|
||||
# Sort by score descending
|
||||
posts.sort(key=lambda x: x["score"], reverse=True)
|
||||
return posts
|
||||
|
||||
def filter_posts_by_date(
|
||||
self, posts: list[dict[str, Any]], start_date: str, end_date: str
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter posts by date range.
|
||||
|
||||
Args:
|
||||
posts: List of post dictionaries
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
|
||||
Returns:
|
||||
Filtered list of posts
|
||||
"""
|
||||
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(
|
||||
days=1
|
||||
) # Include end date
|
||||
|
||||
filtered_posts = []
|
||||
for post in posts:
|
||||
post_dt = datetime.fromtimestamp(post["created_utc"])
|
||||
if start_dt <= post_dt <= end_dt:
|
||||
# Add formatted date string
|
||||
post["posted_date"] = post_dt.strftime("%Y-%m-%d")
|
||||
filtered_posts.append(post)
|
||||
|
||||
return filtered_posts
|
||||
|
||||
def filter_posts_by_company(
|
||||
self, posts: list[dict[str, Any]], ticker: str
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter posts that mention a specific company/ticker.
|
||||
|
||||
Args:
|
||||
posts: List of post dictionaries
|
||||
ticker: Stock ticker symbol
|
||||
|
||||
Returns:
|
||||
Filtered list of posts that mention the company
|
||||
"""
|
||||
if ticker not in ticker_to_company:
|
||||
# If ticker not in mapping, search for ticker directly
|
||||
search_terms = [ticker.upper()]
|
||||
else:
|
||||
# Get company names and ticker
|
||||
company_names = ticker_to_company[ticker]
|
||||
if "OR" in company_names:
|
||||
search_terms = [name.strip() for name in company_names.split(" OR")]
|
||||
else:
|
||||
search_terms = [company_names]
|
||||
search_terms.append(ticker.upper())
|
||||
|
||||
filtered_posts = []
|
||||
for post in posts:
|
||||
title_text = post["title"].lower()
|
||||
content_text = post["content"].lower()
|
||||
|
||||
# Check if any search term appears in title or content
|
||||
found = False
|
||||
for term in search_terms:
|
||||
term_lower = term.lower()
|
||||
if re.search(
|
||||
r"\b" + re.escape(term_lower) + r"\b", title_text
|
||||
) or re.search(r"\b" + re.escape(term_lower) + r"\b", content_text):
|
||||
found = True
|
||||
break
|
||||
|
||||
if found:
|
||||
filtered_posts.append(post)
|
||||
|
||||
return filtered_posts
|
||||
|
||||
|
||||
def fetch_top_from_category(
|
||||
category: str,
|
||||
date: str,
|
||||
max_limit: int,
|
||||
query: str | None = None,
|
||||
data_path: str = "reddit_data",
|
||||
client_id: str | None = None,
|
||||
client_secret: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Legacy function to maintain backward compatibility.
|
||||
Now uses live Reddit API if credentials are provided.
|
||||
|
||||
Args:
|
||||
category: Category ('global_news', 'company_news', etc.)
|
||||
date: Date in YYYY-MM-DD format
|
||||
max_limit: Maximum number of posts
|
||||
query: Optional search query (ticker for company_news)
|
||||
data_path: Unused in API mode
|
||||
client_id: Reddit client ID
|
||||
client_secret: Reddit client secret
|
||||
user_agent: Reddit user agent
|
||||
|
||||
Returns:
|
||||
List of post dictionaries
|
||||
"""
|
||||
if not all([client_id, client_secret, user_agent]):
|
||||
logger.warning("Reddit API credentials not provided. Returning empty data.")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Type check ensures these are not None
|
||||
assert client_id is not None
|
||||
assert client_secret is not None
|
||||
assert user_agent is not None
|
||||
client = RedditClient(client_id, client_secret, user_agent)
|
||||
|
||||
# Determine subreddits based on category
|
||||
if category == "global_news":
|
||||
subreddit_names = client.subreddits["global_news"]
|
||||
elif category == "company_news":
|
||||
subreddit_names = client.subreddits["company_news"]
|
||||
else:
|
||||
# Default to investing subreddits
|
||||
subreddit_names = client.subreddits["investing"]
|
||||
|
||||
# Calculate time filter based on date (Reddit doesn't support exact date filtering)
|
||||
post_date = datetime.strptime(date, "%Y-%m-%d")
|
||||
days_ago = (datetime.now() - post_date).days
|
||||
|
||||
if days_ago <= 1:
|
||||
time_filter = "day"
|
||||
elif days_ago <= 7:
|
||||
time_filter = "week"
|
||||
elif days_ago <= 30:
|
||||
time_filter = "month"
|
||||
else:
|
||||
time_filter = "year"
|
||||
|
||||
# Get posts
|
||||
if query and category == "company_news":
|
||||
# Search for specific company
|
||||
posts = client.search_posts(
|
||||
query=query,
|
||||
subreddit_names=subreddit_names,
|
||||
limit=max_limit // len(subreddit_names),
|
||||
time_filter=time_filter,
|
||||
)
|
||||
# Filter by company mentions
|
||||
posts = client.filter_posts_by_company(posts, query)
|
||||
else:
|
||||
# Get top posts
|
||||
posts = client.get_top_posts(
|
||||
subreddit_names=subreddit_names,
|
||||
limit=max_limit // len(subreddit_names),
|
||||
time_filter=time_filter,
|
||||
)
|
||||
|
||||
# Filter by date (approximate)
|
||||
start_date = date
|
||||
end_date = date
|
||||
posts = client.filter_posts_by_date(posts, start_date, end_date)
|
||||
|
||||
# Limit results
|
||||
return posts[:max_limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Reddit data: {e}")
|
||||
return []
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
import os
|
||||
from typing import Annotated
|
||||
|
||||
import pandas as pd
|
||||
import yfinance as yf
|
||||
from stockstats import wrap
|
||||
|
||||
from tradingagents.config import DEFAULT_CONFIG
|
||||
|
||||
|
||||
class StockstatsUtils:
|
||||
@staticmethod
|
||||
def get_stock_stats(
|
||||
symbol: Annotated[str, "ticker symbol for the company"],
|
||||
indicator: Annotated[
|
||||
str, "quantitative indicators based off of the stock data for the company"
|
||||
],
|
||||
curr_date_str: Annotated[
|
||||
str, "curr date for retrieving stock price data, YYYY-mm-dd"
|
||||
],
|
||||
data_dir: Annotated[
|
||||
str,
|
||||
"directory where the stock data is stored.",
|
||||
],
|
||||
online: Annotated[
|
||||
bool,
|
||||
"whether to use online tools to fetch data or offline tools. If True, will use online tools.",
|
||||
] = False,
|
||||
):
|
||||
df = None
|
||||
data = None
|
||||
|
||||
if not online:
|
||||
try:
|
||||
data = pd.read_csv(
|
||||
os.path.join(
|
||||
data_dir,
|
||||
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
||||
)
|
||||
)
|
||||
df = wrap(data)
|
||||
except FileNotFoundError as err:
|
||||
raise Exception(
|
||||
"Stockstats fail: Yahoo Finance data not fetched yet!"
|
||||
) from err
|
||||
else:
|
||||
# Get today's date as YYYY-mm-dd to add to cache
|
||||
today_date = pd.Timestamp.today()
|
||||
curr_date = pd.to_datetime(curr_date_str)
|
||||
|
||||
end_date = today_date
|
||||
start_date = today_date - pd.DateOffset(years=15)
|
||||
start_date = start_date.strftime("%Y-%m-%d")
|
||||
end_date = end_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Get config and ensure cache directory exists
|
||||
os.makedirs(DEFAULT_CONFIG.data_cache_dir, exist_ok=True)
|
||||
|
||||
data_file = os.path.join(
|
||||
DEFAULT_CONFIG.data_cache_dir,
|
||||
f"{symbol}-YFin-data-{start_date}-{end_date}.csv",
|
||||
)
|
||||
|
||||
if os.path.exists(data_file):
|
||||
data = pd.read_csv(data_file)
|
||||
data["Date"] = pd.to_datetime(data["Date"])
|
||||
else:
|
||||
data = yf.download(
|
||||
symbol,
|
||||
start=start_date,
|
||||
end=end_date,
|
||||
multi_level_index=False,
|
||||
progress=False,
|
||||
auto_adjust=True,
|
||||
)
|
||||
|
||||
if data is None:
|
||||
raise ValueError(f"Failed to download data for {symbol}")
|
||||
|
||||
data = data.reset_index()
|
||||
data.to_csv(data_file, index=False)
|
||||
|
||||
df = wrap(data)
|
||||
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
||||
curr_date = curr_date.strftime("%Y-%m-%d")
|
||||
|
||||
df[indicator] # trigger stockstats to calculate the indicator
|
||||
matching_rows = df[df["Date"].str.startswith(curr_date_str)]
|
||||
|
||||
if not matching_rows.empty:
|
||||
indicator_value = matching_rows[indicator].values[0]
|
||||
return indicator_value
|
||||
else:
|
||||
return "N/A: Not a trading day (weekend or holiday)"
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
from datetime import date, datetime, timedelta
|
||||
from typing import Annotated
|
||||
|
||||
import pandas as pd
|
||||
|
||||
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
||||
|
||||
|
||||
def save_output(
|
||||
data: pd.DataFrame, tag: str, save_path: SavePathType | None = None
|
||||
) -> None:
|
||||
if save_path:
|
||||
data.to_csv(save_path)
|
||||
print(f"{tag} saved to {save_path}")
|
||||
|
||||
|
||||
def get_current_date():
|
||||
return date.today().strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
def decorate_all_methods(decorator):
|
||||
def class_decorator(cls):
|
||||
for attr_name, attr_value in cls.__dict__.items():
|
||||
if callable(attr_value):
|
||||
setattr(cls, attr_name, decorator(attr_value))
|
||||
return cls
|
||||
|
||||
return class_decorator
|
||||
|
||||
|
||||
def get_next_weekday(date):
|
||||
if not isinstance(date, datetime):
|
||||
date = datetime.strptime(date, "%Y-%m-%d")
|
||||
|
||||
if date.weekday() >= 5:
|
||||
days_to_add = 7 - date.weekday()
|
||||
next_weekday = date + timedelta(days=days_to_add)
|
||||
return next_weekday
|
||||
else:
|
||||
return date
|
||||
|
|
@ -1,142 +0,0 @@
|
|||
# gets data/stats
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Annotated, cast
|
||||
|
||||
import pandas as pd
|
||||
import yfinance as yf
|
||||
from pandas import DataFrame, Series
|
||||
|
||||
from .utils import SavePathType
|
||||
|
||||
|
||||
# Module-level cache to avoid memory leaks with instance methods
|
||||
@lru_cache(maxsize=100)
|
||||
def _get_cached_ticker(symbol: str) -> yf.Ticker:
|
||||
"""Get a cached yfinance Ticker instance."""
|
||||
return yf.Ticker(symbol)
|
||||
|
||||
|
||||
class YFinanceUtils:
|
||||
"""Clean YFinance utilities with ticker caching for better performance."""
|
||||
|
||||
def _get_ticker(self, symbol: str) -> yf.Ticker:
|
||||
"""Get a cached yfinance Ticker instance."""
|
||||
return _get_cached_ticker(symbol)
|
||||
|
||||
def get_stock_data(
|
||||
self,
|
||||
symbol: Annotated[str, "ticker symbol"],
|
||||
start_date: Annotated[
|
||||
str, "start date for retrieving stock price data, YYYY-mm-dd"
|
||||
],
|
||||
end_date: Annotated[
|
||||
str, "end date for retrieving stock price data, YYYY-mm-dd"
|
||||
],
|
||||
save_path: SavePathType | None = None,
|
||||
) -> DataFrame:
|
||||
"""Retrieve stock price data for designated ticker symbol."""
|
||||
ticker = self._get_ticker(symbol)
|
||||
|
||||
# Add one day to the end_date so that the data range is inclusive
|
||||
end_date_adjusted = pd.to_datetime(end_date) + pd.DateOffset(days=1)
|
||||
end_date_str = end_date_adjusted.strftime("%Y-%m-%d")
|
||||
|
||||
stock_data = ticker.history(start=start_date, end=end_date_str)
|
||||
return stock_data
|
||||
|
||||
def get_stock_info(self, symbol: Annotated[str, "ticker symbol"]) -> dict:
|
||||
"""Fetches and returns latest stock information."""
|
||||
ticker = self._get_ticker(symbol)
|
||||
return ticker.info
|
||||
|
||||
def get_company_info(
|
||||
self,
|
||||
symbol: Annotated[str, "ticker symbol"],
|
||||
save_path: str | None = None,
|
||||
) -> DataFrame:
|
||||
"""Fetches and returns company information as a DataFrame."""
|
||||
ticker = self._get_ticker(symbol)
|
||||
info = ticker.info
|
||||
|
||||
company_info = {
|
||||
"Company Name": info.get("shortName", "N/A"),
|
||||
"Industry": info.get("industry", "N/A"),
|
||||
"Sector": info.get("sector", "N/A"),
|
||||
"Country": info.get("country", "N/A"),
|
||||
"Website": info.get("website", "N/A"),
|
||||
}
|
||||
|
||||
company_info_df = DataFrame([company_info])
|
||||
|
||||
if save_path:
|
||||
company_info_df.to_csv(save_path)
|
||||
print(f"Company info for {symbol} saved to {save_path}")
|
||||
|
||||
return company_info_df
|
||||
|
||||
def get_stock_dividends(
|
||||
self,
|
||||
symbol: Annotated[str, "ticker symbol"],
|
||||
save_path: str | None = None,
|
||||
) -> Series:
|
||||
"""Fetches and returns the latest dividends data as a DataFrame."""
|
||||
ticker = self._get_ticker(symbol)
|
||||
dividends = ticker.dividends
|
||||
|
||||
if save_path:
|
||||
dividends.to_csv(save_path)
|
||||
print(f"Dividends for {symbol} saved to {save_path}")
|
||||
|
||||
return dividends
|
||||
|
||||
def get_income_stmt(self, symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
||||
"""Fetches and returns the latest income statement of the company as a DataFrame."""
|
||||
ticker = self._get_ticker(symbol)
|
||||
return ticker.financials
|
||||
|
||||
def get_balance_sheet(self, symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
||||
"""Fetches and returns the latest balance sheet of the company as a DataFrame."""
|
||||
ticker = self._get_ticker(symbol)
|
||||
return ticker.balance_sheet
|
||||
|
||||
def get_cash_flow(self, symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
||||
"""Fetches and returns the latest cash flow statement of the company as a DataFrame."""
|
||||
ticker = self._get_ticker(symbol)
|
||||
return ticker.cashflow
|
||||
|
||||
def get_analyst_recommendations(
|
||||
self, symbol: Annotated[str, "ticker symbol"]
|
||||
) -> tuple[str | None, int]:
|
||||
"""Fetches the latest analyst recommendations and returns the most common recommendation and its count."""
|
||||
ticker = self._get_ticker(symbol)
|
||||
recommendations = cast("DataFrame", ticker.recommendations)
|
||||
|
||||
if recommendations is None or recommendations.empty:
|
||||
return None, 0 # No recommendations available
|
||||
|
||||
# Get the most recent recommendation row (excluding 'period' column if it exists)
|
||||
try:
|
||||
row_0 = recommendations.iloc[0, 1:] # Skip first column (likely 'period')
|
||||
|
||||
# Find the maximum voting result
|
||||
max_votes = row_0.max()
|
||||
majority_voting_result = row_0[row_0 == max_votes].index.tolist()
|
||||
|
||||
return majority_voting_result[0], int(max_votes)
|
||||
except (IndexError, KeyError):
|
||||
return None, 0
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the ticker cache. Useful for testing or memory management."""
|
||||
_get_cached_ticker.cache_clear()
|
||||
|
||||
def cache_info(self) -> dict:
|
||||
"""Get information about the ticker cache."""
|
||||
info = _get_cached_ticker.cache_info()
|
||||
return {
|
||||
"hits": info.hits,
|
||||
"misses": info.misses,
|
||||
"maxsize": info.maxsize,
|
||||
"currsize": info.currsize,
|
||||
}
|
||||
|
|
@ -11,7 +11,10 @@ from datetime import date
|
|||
|
||||
import pytest
|
||||
|
||||
from tradingagents.clients.finnhub_client import FinnhubClient
|
||||
# NOTE: FinnhubClient was removed - this test needs to be updated
|
||||
# from tradingagents.clients.finnhub_client import FinnhubClient
|
||||
|
||||
pytest.skip("FinnhubClient implementation removed", allow_module_level=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -9,12 +9,10 @@ from typing import Any
|
|||
import pandas as pd
|
||||
import yfinance as yf
|
||||
|
||||
from .base import BaseClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class YFinanceClient(BaseClient):
|
||||
class YFinanceClient:
|
||||
"""Client for Yahoo Finance API using yfinance library."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
|
@ -52,9 +50,6 @@ class YFinanceClient(BaseClient):
|
|||
Returns:
|
||||
Dict[str, Any]: Price data with metadata
|
||||
"""
|
||||
if not self.validate_date_range(start_date, end_date):
|
||||
raise ValueError(f"Invalid date range: {start_date} to {end_date}")
|
||||
|
||||
try:
|
||||
ticker = yf.Ticker(symbol.upper())
|
||||
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
"""
|
||||
Fundamental Data Service for aggregating and analyzing financial statement data.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FundamentalDataService:
|
||||
"""Service for fundamental financial data aggregation and analysis."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
simfin_client: SimFinClient,
|
||||
repository: FundamentalDataRepository,
|
||||
):
|
||||
"""Initialize Fundamental Data Service.
|
||||
|
||||
Args:
|
||||
simfin_client: Client for SimFin/financial API access
|
||||
repository: Repository for cached fundamental data
|
||||
online_mode: Whether to fetch live data
|
||||
data_dir: Directory for data storage
|
||||
"""
|
||||
self.simfin_client = simfin_client
|
||||
self.repository = repository
|
||||
|
||||
def update_fundamental_data(
|
||||
self,
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
frequency: str = "quarterly",
|
||||
) -> FundamentalContext:
|
||||
pass # TODO: fetch fundementals from simfin, save in repo
|
||||
|
||||
def get_fundamental_context(
|
||||
self,
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
frequency: str = "quarterly",
|
||||
) -> FundamentalContext:
|
||||
"""Get fundamental analysis context for a company.
|
||||
|
||||
Args:
|
||||
symbol: Stock ticker symbol
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
frequency: Reporting frequency ('quarterly' or 'annual')
|
||||
force_refresh: If True, skip local data and fetch fresh from APIs
|
||||
|
||||
Returns:
|
||||
FundamentalContext with financial statements and key ratios
|
||||
"""
|
||||
balance_sheet = None
|
||||
income_statement = None
|
||||
cash_flow = None
|
||||
error_info = {}
|
||||
errors = []
|
||||
data_source = "unknown"
|
||||
|
||||
# return FundamentalContext(
|
||||
# symbol=symbol,
|
||||
# period={"start": start_date, "end": end_date},
|
||||
# balance_sheet=balance_sheet,
|
||||
# income_statement=income_statement,
|
||||
# cash_flow=cash_flow,
|
||||
# key_ratios=key_ratios,
|
||||
# metadata={
|
||||
# "data_quality": data_quality,
|
||||
# "service": "fundamental_data",
|
||||
# "online_mode": self.is_online(),
|
||||
# "frequency": frequency,
|
||||
# "data_source": data_source,
|
||||
# "force_refresh": force_refresh,
|
||||
# **error_info,
|
||||
# },
|
||||
# )
|
||||
|
||||
pass # TODO: read data from repo
|
||||
|
|
@ -9,13 +9,13 @@ from typing import Any
|
|||
import pytest
|
||||
|
||||
from tradingagents.clients.base import BaseClient
|
||||
from tradingagents.models.context import (
|
||||
from tradingagents.domains.marketdata.fundamental_data_service import (
|
||||
DataQuality,
|
||||
FinancialStatement,
|
||||
FundamentalContext,
|
||||
FundamentalDataService,
|
||||
)
|
||||
from tradingagents.repositories.fundamental_repository import FundamentalDataRepository
|
||||
from tradingagents.services.fundamental_data_service import FundamentalDataService
|
||||
|
||||
|
||||
class MockSimFinClient(BaseClient):
|
||||
|
|
@ -0,0 +1,190 @@
|
|||
"""
|
||||
Insider Data Service for aggregating and analyzing insider trading data.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from tradingagents.domains.marketdata.finnhub_client import FinnhubClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataQuality(Enum):
|
||||
"""Data quality levels for insider data."""
|
||||
|
||||
HIGH = "high"
|
||||
MEDIUM = "medium"
|
||||
LOW = "low"
|
||||
|
||||
|
||||
@dataclass
|
||||
class InsiderTransaction:
|
||||
"""Insider transaction data."""
|
||||
|
||||
date: str # YYYY-MM-DD format
|
||||
insider_name: str
|
||||
title: str
|
||||
transaction_type: str # "Purchase", "Sale", "Exercise", etc.
|
||||
shares: int
|
||||
price_per_share: float
|
||||
total_value: float
|
||||
shares_owned_after: int
|
||||
filing_date: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class InsiderSentimentContext:
|
||||
"""Insider sentiment context for trading analysis."""
|
||||
|
||||
symbol: str
|
||||
period: dict[str, str] # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
|
||||
transactions: list[InsiderTransaction]
|
||||
sentiment_score: float # -1.0 to 1.0 (-1: very bearish, 1: very bullish)
|
||||
net_buying_value: float # Net buying (positive) or selling (negative) in USD
|
||||
insider_count: int # Number of unique insiders in period
|
||||
transaction_count: int
|
||||
analysis_summary: str
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InsiderTransactionContext:
|
||||
"""Insider transaction context for detailed transaction analysis."""
|
||||
|
||||
symbol: str
|
||||
period: dict[str, str] # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
|
||||
transactions: list[InsiderTransaction]
|
||||
total_transaction_value: float
|
||||
net_insider_activity: float # Net buying/selling activity
|
||||
top_insiders: list[dict[str, Any]] # Top insiders by transaction value
|
||||
transaction_summary: dict[str, Any] # Summary stats by transaction type
|
||||
analysis_summary: str
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
class InsiderDataService:
|
||||
"""Service for insider trading data aggregation and analysis."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: FinnhubClient,
|
||||
repository: InsiderDataRepository,
|
||||
):
|
||||
"""
|
||||
Initialize insider data service.
|
||||
|
||||
Args:
|
||||
client: Client for insider data (e.g., FinnhubClient)
|
||||
repository: Repository for cached insider data
|
||||
online_mode: Whether to use live data
|
||||
**kwargs: Additional configuration
|
||||
"""
|
||||
self.client = client
|
||||
self.repository = repository
|
||||
|
||||
def get_insider_sentiment_context(
|
||||
self,
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> InsiderSentimentContext:
|
||||
"""
|
||||
Get insider sentiment context for a company.
|
||||
|
||||
Args:
|
||||
symbol: Stock ticker symbol
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
InsiderSentimentContext: Insider sentiment analysis
|
||||
"""
|
||||
# TODO: Implement insider sentiment analysis
|
||||
transactions = []
|
||||
sentiment_score = 0.0
|
||||
net_buying_value = 0.0
|
||||
analysis_summary = f"Insider sentiment analysis for {symbol}"
|
||||
|
||||
metadata = {
|
||||
"data_quality": DataQuality.HIGH if transactions else DataQuality.LOW,
|
||||
"service": "insider_data",
|
||||
"data_source": "placeholder",
|
||||
"analysis_method": "insider_sentiment",
|
||||
}
|
||||
|
||||
return InsiderSentimentContext(
|
||||
symbol=symbol,
|
||||
period={"start": start_date, "end": end_date},
|
||||
transactions=transactions,
|
||||
sentiment_score=sentiment_score,
|
||||
net_buying_value=net_buying_value,
|
||||
insider_count=0,
|
||||
transaction_count=len(transactions),
|
||||
analysis_summary=analysis_summary,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def get_insider_transaction_context(
|
||||
self,
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> InsiderTransactionContext:
|
||||
"""
|
||||
Get insider transaction context for detailed analysis.
|
||||
|
||||
Args:
|
||||
symbol: Stock ticker symbol
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
InsiderTransactionContext: Detailed insider transaction analysis
|
||||
"""
|
||||
# TODO: Implement insider transaction analysis
|
||||
transactions = []
|
||||
total_transaction_value = 0.0
|
||||
net_insider_activity = 0.0
|
||||
top_insiders = []
|
||||
transaction_summary = {}
|
||||
analysis_summary = f"Insider transaction analysis for {symbol}"
|
||||
|
||||
metadata = {
|
||||
"data_quality": DataQuality.HIGH if transactions else DataQuality.LOW,
|
||||
"service": "insider_data",
|
||||
"data_source": "placeholder",
|
||||
"analysis_method": "insider_transactions",
|
||||
}
|
||||
|
||||
return InsiderTransactionContext(
|
||||
symbol=symbol,
|
||||
period={"start": start_date, "end": end_date},
|
||||
transactions=transactions,
|
||||
total_transaction_value=total_transaction_value,
|
||||
net_insider_activity=net_insider_activity,
|
||||
top_insiders=top_insiders,
|
||||
transaction_summary=transaction_summary,
|
||||
analysis_summary=analysis_summary,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def update_insider_sentiment(
|
||||
self,
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
pass # TODO: fetch insider sentiment with finnhub client, save with repo
|
||||
|
||||
def update_insider_transactions(
|
||||
self,
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
pass # TODO: fetch insider transactions with finnhub client, save with repo
|
||||
|
|
@ -0,0 +1,149 @@
|
|||
"""
|
||||
Market data service that provides structured market context.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataQuality(Enum):
|
||||
"""Data quality levels for market data."""
|
||||
|
||||
HIGH = "high"
|
||||
MEDIUM = "medium"
|
||||
LOW = "low"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TechnicalIndicatorData:
|
||||
"""Technical indicator data point."""
|
||||
|
||||
date: str
|
||||
value: float | dict[str, Any]
|
||||
indicator_type: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketDataContext:
|
||||
"""Market data context for trading analysis."""
|
||||
|
||||
symbol: str
|
||||
period: dict[str, str] # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
|
||||
price_data: list[dict[str, Any]]
|
||||
technical_indicators: dict[str, list[TechnicalIndicatorData]]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TAReportContext:
|
||||
"""Technical Analysis Report context for specific indicators."""
|
||||
|
||||
symbol: str
|
||||
period: dict[str, str] # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
|
||||
indicator: str
|
||||
indicator_data: list[TechnicalIndicatorData]
|
||||
analysis_summary: str
|
||||
signal_strength: float # -1.0 to 1.0
|
||||
recommendation: str # "BUY", "SELL", "HOLD"
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PriceDataContext:
|
||||
"""Price Data context for historical price information."""
|
||||
|
||||
symbol: str
|
||||
period: dict[str, str] # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
|
||||
price_data: list[dict[str, Any]]
|
||||
latest_price: float
|
||||
price_change: float
|
||||
price_change_percent: float
|
||||
volume_info: dict[str, Any]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
class MarketDataService:
|
||||
"""Service for market data and technical indicators."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
yfin_client: YFinClient,
|
||||
repo: MarketdataRepository,
|
||||
):
|
||||
"""
|
||||
Initialize market data service.
|
||||
|
||||
Args:
|
||||
client: Client for live market data
|
||||
repository: Repository for historical market data
|
||||
online_mode: Whether to use live data
|
||||
**kwargs: Additional configuration
|
||||
"""
|
||||
self.finnhub_client = finnhub_client
|
||||
self.yfin_client = yfin_client
|
||||
self.repo = repo
|
||||
|
||||
def get_market_data_context(
|
||||
self, symbol: str, start_date: str, end_date: str
|
||||
) -> PriceDataContext:
|
||||
"""
|
||||
Get focused price data context with key metrics.
|
||||
|
||||
Args:
|
||||
symbol: Stock ticker symbol
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
PriceDataContext: Focused price data context
|
||||
"""
|
||||
# return PriceDataContext(
|
||||
# symbol=symbol,
|
||||
# period={"start": start_date, "end": end_date},
|
||||
# price_data=price_data.get("data", []),
|
||||
# latest_price=latest_price,
|
||||
# price_change=price_change,
|
||||
# price_change_percent=price_change_percent,
|
||||
# volume_info=volume_info,
|
||||
# metadata=metadata,
|
||||
# )
|
||||
|
||||
pass # TODO: get data from repo
|
||||
|
||||
def get_ta_report_context(
|
||||
self, symbol: str, indicator: str, start_date: str, end_date: str
|
||||
) -> TAReportContext:
|
||||
"""
|
||||
Get technical analysis report context for a specific indicator.
|
||||
|
||||
Args:
|
||||
symbol: Stock ticker symbol
|
||||
indicator: Technical indicator name (e.g., 'rsi', 'macd', 'sma')
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
TAReportContext: Focused technical analysis context
|
||||
"""
|
||||
|
||||
# return TAReportContext(
|
||||
# symbol=symbol,
|
||||
# period={"start": start_date, "end": end_date},
|
||||
# indicator=indicator,
|
||||
# indicator_data=indicator_data.get(indicator, []),
|
||||
# analysis_summary=analysis_summary,
|
||||
# signal_strength=signal_strength,
|
||||
# recommendation=recommendation,
|
||||
# metadata=metadata,
|
||||
# )
|
||||
|
||||
pass # TODO get data from repo and calculate indicator with TALib?
|
||||
|
||||
def update_market_data(self, symbol: str, start_date: str, end_date: str):
|
||||
pass # TODO: fetch market data and save
|
||||
|
|
@ -13,9 +13,12 @@ from typing import Any
|
|||
sys.path.insert(0, os.path.abspath("."))
|
||||
|
||||
from tradingagents.clients.base import BaseClient
|
||||
from tradingagents.models.context import DataQuality, MarketDataContext
|
||||
from tradingagents.domains.marketdata.market_data_service import (
|
||||
DataQuality,
|
||||
MarketDataContext,
|
||||
MarketDataService,
|
||||
)
|
||||
from tradingagents.repositories.market_data_repository import MarketDataRepository
|
||||
from tradingagents.services.market_data_service import MarketDataService
|
||||
|
||||
|
||||
class MockYFinanceClient(BaseClient):
|
||||
|
|
@ -3,11 +3,9 @@ Google News client for live news data via web scraping.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from tradingagents.dataflows.googlenews_utils import getNewsData
|
||||
|
||||
from .base import BaseClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -27,20 +25,6 @@ class GoogleNewsClient(BaseClient):
|
|||
self.max_retries = kwargs.get("max_retries", 3)
|
||||
self.delay_between_requests = kwargs.get("delay_between_requests", 1.0)
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
"""Test Google News connection by fetching a simple query."""
|
||||
try:
|
||||
# Test with a simple query for recent news
|
||||
end_date = datetime.now().strftime("%Y-%m-%d")
|
||||
start_date = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
|
||||
test_data = getNewsData("technology", start_date, end_date)
|
||||
return isinstance(test_data, list)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Google News connection test failed: {e}")
|
||||
return False
|
||||
|
||||
def get_data(
|
||||
self, query: str, start_date: str, end_date: str, **kwargs
|
||||
) -> dict[str, Any]:
|
||||
|
|
@ -3,15 +3,11 @@ News service that provides structured news context.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from tradingagents.clients.base import BaseClient
|
||||
from tradingagents.models.context import (
|
||||
ArticleData,
|
||||
NewsContext,
|
||||
SentimentScore,
|
||||
)
|
||||
from tradingagents.repositories.base import BaseRepository
|
||||
|
||||
from .base import BaseService
|
||||
|
|
@ -19,6 +15,64 @@ from .base import BaseService
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataQuality(Enum):
|
||||
"""Data quality levels for news data."""
|
||||
|
||||
HIGH = "high"
|
||||
MEDIUM = "medium"
|
||||
LOW = "low"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SentimentScore:
|
||||
"""Sentiment analysis score."""
|
||||
|
||||
score: float # -1.0 to 1.0
|
||||
confidence: float # 0.0 to 1.0
|
||||
label: str # positive/negative/neutral
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArticleData:
|
||||
"""News article data."""
|
||||
|
||||
title: str
|
||||
content: str
|
||||
author: str
|
||||
source: str
|
||||
date: str # YYYY-MM-DD format
|
||||
url: str
|
||||
sentiment: SentimentScore | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewsContext:
|
||||
"""News context for trading analysis."""
|
||||
|
||||
query: str
|
||||
symbol: str | None
|
||||
period: dict[str, str] # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
|
||||
articles: list[ArticleData]
|
||||
sentiment_summary: SentimentScore
|
||||
article_count: int
|
||||
sources: list[str]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GlobalNewsContext:
|
||||
"""Global news context for macro analysis."""
|
||||
|
||||
period: dict[str, str] # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
|
||||
categories: list[str]
|
||||
articles: list[ArticleData]
|
||||
sentiment_summary: SentimentScore
|
||||
article_count: int
|
||||
sources: list[str]
|
||||
trending_topics: list[str]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
class NewsService(BaseService):
|
||||
"""Service for news data and sentiment analysis."""
|
||||
|
||||
|
|
@ -95,7 +149,7 @@ class NewsService(BaseService):
|
|||
end_date: str,
|
||||
categories: list[str] | None = None,
|
||||
**kwargs,
|
||||
) -> NewsContext:
|
||||
) -> GlobalNewsContext:
|
||||
"""
|
||||
Get global/macro news context.
|
||||
|
||||
|
|
@ -106,6 +160,18 @@ class NewsService(BaseService):
|
|||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
NewsContext: Global news context
|
||||
GlobalNewsContext: Global news context
|
||||
"""
|
||||
pass
|
||||
# TODO: Implement global news fetching
|
||||
return GlobalNewsContext(
|
||||
period={"start": start_date, "end": end_date},
|
||||
categories=categories or [],
|
||||
articles=[],
|
||||
sentiment_summary=SentimentScore(
|
||||
score=0.0, confidence=0.0, label="neutral"
|
||||
),
|
||||
article_count=0,
|
||||
sources=[],
|
||||
trending_topics=[],
|
||||
metadata={"service": "news", "analysis_method": "global_news"},
|
||||
)
|
||||
|
|
@ -13,9 +13,12 @@ from typing import Any
|
|||
sys.path.insert(0, os.path.abspath("."))
|
||||
|
||||
from tradingagents.clients.base import BaseClient
|
||||
from tradingagents.models.context import NewsContext, SentimentScore
|
||||
from tradingagents.domains.news.news_service import (
|
||||
NewsContext,
|
||||
NewsService,
|
||||
SentimentScore,
|
||||
)
|
||||
from tradingagents.repositories.news_repository import NewsRepository
|
||||
from tradingagents.services.news_service import NewsService
|
||||
|
||||
|
||||
class MockFinnhubClient(BaseClient):
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
class RedditClient:
|
||||
pass
|
||||
|
|
@ -8,8 +8,6 @@ from dataclasses import asdict, dataclass, field
|
|||
from datetime import date
|
||||
from pathlib import Path
|
||||
|
||||
from .base import BaseRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -48,7 +46,7 @@ class SocialData:
|
|||
posts: list[SocialPost]
|
||||
|
||||
|
||||
class SocialRepository(BaseRepository):
|
||||
class SocialRepository:
|
||||
"""Repository for accessing cached social media data with source separation."""
|
||||
|
||||
def __init__(self, data_dir: str, **kwargs):
|
||||
|
|
@ -158,7 +156,6 @@ class SocialRepository(BaseRepository):
|
|||
"""
|
||||
# Create source/query directory
|
||||
source_dir = self.social_data_dir / source / query
|
||||
self._ensure_path_exists(source_dir)
|
||||
|
||||
# Create JSON file path
|
||||
file_path = source_dir / f"{date.isoformat()}.json"
|
||||
|
|
@ -302,3 +299,7 @@ class SocialRepository(BaseRepository):
|
|||
merged_posts.sort(key=lambda x: x.created_date, reverse=True) # Newest first
|
||||
|
||||
return merged_posts
|
||||
|
||||
|
||||
# Alias for backwards compatibility
|
||||
SocialMediaRepository = SocialRepository
|
||||
|
|
@ -0,0 +1,252 @@
|
|||
"""
|
||||
Social Media Service for aggregating and analyzing social media data.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .reddit_client import RedditClient
|
||||
from .social_media_repository import SocialMediaRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataQuality(Enum):
|
||||
"""Data quality levels for social media data."""
|
||||
|
||||
HIGH = "high"
|
||||
MEDIUM = "medium"
|
||||
LOW = "low"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SentimentScore:
|
||||
"""Sentiment analysis score."""
|
||||
|
||||
score: float # -1.0 to 1.0
|
||||
confidence: float # 0.0 to 1.0
|
||||
label: str # positive/negative/neutral
|
||||
|
||||
|
||||
@dataclass
|
||||
class PostMetadata:
|
||||
upvotes: int
|
||||
num_comments: int
|
||||
subreddit: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PostData:
|
||||
title: str
|
||||
content: str
|
||||
author: str
|
||||
source: str # subreddit name or "reddit"
|
||||
date: str # YYYY-MM-DD format
|
||||
url: str
|
||||
score: int # Reddit score/upvotes
|
||||
comments: int # Number of comments
|
||||
engagement_score: int # Calculated: upvotes + comments
|
||||
subreddit: str | None
|
||||
sentiment: SentimentScore | None = None # Added by sentiment analysis
|
||||
metadata: PostMetadata | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngagementMetrics:
|
||||
"""Engagement metrics for social media posts."""
|
||||
|
||||
total_engagement: float
|
||||
average_engagement: float
|
||||
max_engagement: float
|
||||
total_posts: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class SocialContext:
|
||||
"""Social media context data for trading analysis."""
|
||||
|
||||
symbol: str | None
|
||||
period: tuple[str, str] # (start_date, end_date)
|
||||
posts: list[PostData]
|
||||
engagement_metrics: EngagementMetrics
|
||||
sentiment_summary: SentimentScore
|
||||
post_count: int
|
||||
platforms: list[str]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class StockSocialContext:
|
||||
"""Stock-specific social media context for targeted analysis."""
|
||||
|
||||
symbol: str
|
||||
period: dict[str, str] # {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
|
||||
posts: list[PostData]
|
||||
engagement_metrics: EngagementMetrics
|
||||
sentiment_summary: SentimentScore
|
||||
post_count: int
|
||||
platforms: list[str]
|
||||
trending_topics: list[str]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
class SocialMediaService:
|
||||
"""Service for social media data aggregation and analysis."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reddit_client: RedditClient,
|
||||
repository: SocialMediaRepository,
|
||||
):
|
||||
"""Initialize Social Media Service.
|
||||
|
||||
Args:
|
||||
reddit_client: Client for Reddit API access
|
||||
repository: Repository for cached social data
|
||||
online_mode: Whether to fetch live data
|
||||
data_dir: Directory for data storage
|
||||
"""
|
||||
self.reddit_client = reddit_client
|
||||
self.repository = repository
|
||||
|
||||
def get_context(
|
||||
self,
|
||||
query: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
symbol: str,
|
||||
subreddits: list[str],
|
||||
force_refresh: bool = False,
|
||||
) -> SocialContext:
|
||||
"""Get social media context for a query.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
symbol: Optional stock symbol
|
||||
subreddits: Optional list of subreddits to search
|
||||
force_refresh: If True, skip local data and fetch fresh from APIs
|
||||
|
||||
Returns:
|
||||
SocialContext with posts and sentiment analysis
|
||||
"""
|
||||
posts = []
|
||||
error_info = {}
|
||||
data_source = "unknown"
|
||||
|
||||
try:
|
||||
# Local-first data strategy with force refresh option
|
||||
if force_refresh:
|
||||
# Skip local data, fetch fresh from APIs
|
||||
posts, data_source = self._fetch_and_cache_fresh_social_data(
|
||||
query, start_date, end_date, symbol, subreddits
|
||||
)
|
||||
else:
|
||||
# Check local data first, fetch missing if needed
|
||||
posts, data_source = self._get_social_data_local_first(
|
||||
query, start_date, end_date, symbol, subreddits
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching social media data: {e}")
|
||||
error_info = {"error": str(e)}
|
||||
|
||||
# Calculate sentiment and engagement metrics
|
||||
sentiment_summary = self._calculate_sentiment(posts)
|
||||
engagement_metrics = self._calculate_engagement_metrics(posts)
|
||||
|
||||
# Determine data quality based on data source
|
||||
data_quality = self._determine_data_quality(
|
||||
data_source=data_source,
|
||||
record_count=len(posts),
|
||||
has_errors=bool(error_info),
|
||||
)
|
||||
|
||||
# Create structured engagement metrics
|
||||
structured_metrics = EngagementMetrics(
|
||||
total_engagement=float(engagement_metrics.get("total_engagement", 0)),
|
||||
average_engagement=float(engagement_metrics.get("average_engagement", 0)),
|
||||
max_engagement=float(engagement_metrics.get("max_engagement", 0)),
|
||||
total_posts=int(engagement_metrics.get("total_posts", 0)),
|
||||
)
|
||||
|
||||
# Separate non-float metrics for metadata
|
||||
metadata_info = {
|
||||
k: v
|
||||
for k, v in engagement_metrics.items()
|
||||
if k
|
||||
not in [
|
||||
"total_engagement",
|
||||
"average_engagement",
|
||||
"max_engagement",
|
||||
"total_posts",
|
||||
]
|
||||
}
|
||||
|
||||
return SocialContext(
|
||||
symbol=symbol,
|
||||
period=(start_date, end_date),
|
||||
posts=posts,
|
||||
engagement_metrics=structured_metrics,
|
||||
sentiment_summary=sentiment_summary,
|
||||
post_count=len(posts),
|
||||
platforms=["reddit"],
|
||||
metadata={
|
||||
"data_quality": data_quality,
|
||||
"service": "social_media",
|
||||
"subreddits": subreddits or [],
|
||||
"data_source": data_source,
|
||||
"force_refresh": force_refresh,
|
||||
**metadata_info,
|
||||
**error_info,
|
||||
},
|
||||
)
|
||||
|
||||
def get_stock_social_context(
|
||||
self,
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
subreddits: list[str] | None = None,
|
||||
**kwargs,
|
||||
) -> StockSocialContext:
|
||||
"""
|
||||
Get stock-specific social media context with trending topics.
|
||||
|
||||
Args:
|
||||
symbol: Stock ticker symbol
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
subreddits: List of subreddits to search
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
StockSocialContext: Focused stock social media context
|
||||
"""
|
||||
# Use existing get_context method
|
||||
base_context = self.get_context(
|
||||
query=symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
symbol=symbol,
|
||||
subreddits=subreddits or [],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# TODO: Extract trending topics from posts
|
||||
trending_topics = []
|
||||
|
||||
return StockSocialContext(
|
||||
symbol=symbol,
|
||||
period={"start": start_date, "end": end_date},
|
||||
posts=base_context.posts,
|
||||
engagement_metrics=base_context.engagement_metrics,
|
||||
sentiment_summary=base_context.sentiment_summary,
|
||||
post_count=base_context.post_count,
|
||||
platforms=base_context.platforms,
|
||||
trending_topics=trending_topics,
|
||||
metadata=base_context.metadata,
|
||||
)
|
||||
|
|
@ -12,14 +12,14 @@ from typing import Any
|
|||
sys.path.insert(0, os.path.abspath("."))
|
||||
|
||||
from tradingagents.clients.base import BaseClient
|
||||
from tradingagents.models.context import (
|
||||
from tradingagents.domains.socialmedia.social_media_service import (
|
||||
DataQuality,
|
||||
PostData,
|
||||
SentimentScore,
|
||||
SocialContext,
|
||||
SocialMediaService,
|
||||
)
|
||||
from tradingagents.repositories.social_repository import SocialRepository
|
||||
from tradingagents.services.social_media_service import SocialMediaService
|
||||
|
||||
|
||||
class MockRedditClient(BaseClient):
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
"""
|
||||
Pydantic models for structured data context in TradingAgents.
|
||||
"""
|
||||
|
||||
from .context import (
|
||||
ArticleData,
|
||||
DataQuality,
|
||||
FinancialStatement,
|
||||
FundamentalContext,
|
||||
InsiderContext,
|
||||
InsiderTransaction,
|
||||
MarketDataContext,
|
||||
NewsContext,
|
||||
PostData,
|
||||
SentimentScore,
|
||||
SocialContext,
|
||||
TechnicalIndicatorData,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DataQuality",
|
||||
"SentimentScore",
|
||||
"MarketDataContext",
|
||||
"NewsContext",
|
||||
"SocialContext",
|
||||
"FundamentalContext",
|
||||
"InsiderContext",
|
||||
"TechnicalIndicatorData",
|
||||
"ArticleData",
|
||||
"PostData",
|
||||
"FinancialStatement",
|
||||
"InsiderTransaction",
|
||||
]
|
||||
|
|
@ -1,292 +0,0 @@
|
|||
"""
|
||||
Pydantic models for structured context objects in TradingAgents.
|
||||
|
||||
These models define the schema for JSON context objects that services
|
||||
provide to agents, replacing the previous markdown string approach.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
|
||||
class DataQuality(str, Enum):
|
||||
"""Data quality indicator for context metadata."""
|
||||
|
||||
HIGH = "high"
|
||||
MEDIUM = "medium"
|
||||
LOW = "low"
|
||||
|
||||
|
||||
class SentimentScore(BaseModel):
|
||||
"""Sentiment analysis result with confidence."""
|
||||
|
||||
score: float = Field(
|
||||
...,
|
||||
ge=-1.0,
|
||||
le=1.0,
|
||||
description="Sentiment score from -1 (negative) to 1 (positive)",
|
||||
)
|
||||
confidence: float = Field(
|
||||
..., ge=0.0, le=1.0, description="Confidence in sentiment score"
|
||||
)
|
||||
label: str | None = Field(
|
||||
default=None, description="Human-readable sentiment label"
|
||||
)
|
||||
|
||||
@validator("label", pre=True, always=True)
|
||||
def set_sentiment_label(cls, v, values):
|
||||
if v is not None:
|
||||
return v
|
||||
|
||||
score = values.get("score", 0)
|
||||
if score > 0.1:
|
||||
return "positive"
|
||||
elif score < -0.1:
|
||||
return "negative"
|
||||
else:
|
||||
return "neutral"
|
||||
|
||||
|
||||
class TechnicalIndicatorData(BaseModel):
|
||||
"""Technical indicator data point."""
|
||||
|
||||
date: str = Field(..., description="Date in YYYY-MM-DD format")
|
||||
value: float | dict[str, float] = Field(
|
||||
..., description="Indicator value or values"
|
||||
)
|
||||
indicator_type: str = Field(
|
||||
..., description="Type of indicator (e.g., 'rsi', 'macd', 'sma')"
|
||||
)
|
||||
|
||||
|
||||
class ArticleData(BaseModel):
|
||||
"""News article data."""
|
||||
|
||||
headline: str = Field(..., description="Article headline")
|
||||
summary: str | None = Field(default=None, description="Article summary or snippet")
|
||||
url: str | None = Field(default=None, description="Article URL")
|
||||
source: str = Field(..., description="News source")
|
||||
date: str = Field(..., description="Publication date in YYYY-MM-DD format")
|
||||
sentiment: SentimentScore | None = Field(
|
||||
default=None, description="Article sentiment analysis"
|
||||
)
|
||||
entities: list[str] = Field(
|
||||
default_factory=list, description="Named entities mentioned"
|
||||
)
|
||||
|
||||
|
||||
class PostData(BaseModel):
|
||||
"""Social media post data."""
|
||||
|
||||
title: str = Field(..., description="Post title")
|
||||
content: str | None = Field(default=None, description="Post content")
|
||||
author: str = Field(..., description="Post author")
|
||||
source: str = Field(..., description="Post source (e.g., subreddit name)")
|
||||
date: str = Field(..., description="Post date in YYYY-MM-DD format")
|
||||
url: str | None = Field(default=None, description="Post URL")
|
||||
score: int = Field(default=0, description="Post score/upvotes")
|
||||
comments: int = Field(default=0, description="Number of comments")
|
||||
engagement_score: int = Field(default=0, description="Combined engagement metric")
|
||||
subreddit: str | None = Field(default=None, description="Subreddit name")
|
||||
sentiment: SentimentScore | None = Field(
|
||||
default=None, description="Post sentiment analysis"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Additional post metadata"
|
||||
)
|
||||
|
||||
|
||||
class FinancialStatement(BaseModel):
|
||||
"""Financial statement data."""
|
||||
|
||||
period: str = Field(..., description="Reporting period (e.g., '2024-Q1', '2024')")
|
||||
report_date: str = Field(..., description="Report date in YYYY-MM-DD format")
|
||||
publish_date: str = Field(..., description="Publish date in YYYY-MM-DD format")
|
||||
currency: str = Field(default="USD", description="Currency of financial data")
|
||||
data: dict[str, float] = Field(..., description="Financial statement line items")
|
||||
|
||||
|
||||
class InsiderTransaction(BaseModel):
|
||||
"""Insider trading transaction data."""
|
||||
|
||||
filing_date: str = Field(..., description="Filing date in YYYY-MM-DD format")
|
||||
name: str = Field(..., description="Insider name")
|
||||
change: float = Field(..., description="Change in shares")
|
||||
shares: float = Field(..., description="Total shares after transaction")
|
||||
transaction_price: float = Field(..., description="Price per share")
|
||||
transaction_code: str = Field(
|
||||
..., description="Transaction code (e.g., 'S' for sale)"
|
||||
)
|
||||
|
||||
|
||||
class MarketDataContext(BaseModel):
|
||||
"""Market data context with price and technical indicators."""
|
||||
|
||||
symbol: str = Field(..., description="Stock ticker symbol")
|
||||
period: dict[str, str] = Field(
|
||||
..., description="Date range with start and end keys"
|
||||
)
|
||||
price_data: list[dict[str, Any]] = Field(..., description="Historical price data")
|
||||
technical_indicators: dict[str, list[TechnicalIndicatorData]] = Field(
|
||||
default_factory=dict, description="Technical indicators organized by type"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Context metadata including data quality and source info",
|
||||
)
|
||||
|
||||
@validator("period")
|
||||
def validate_period(cls, v):
|
||||
required_keys = {"start", "end"}
|
||||
if not required_keys.issubset(v.keys()):
|
||||
raise ValueError(f"Period must contain keys: {required_keys}")
|
||||
return v
|
||||
|
||||
|
||||
class NewsContext(BaseModel):
|
||||
"""News context with articles and sentiment analysis."""
|
||||
|
||||
symbol: str | None = Field(
|
||||
default=None, description="Stock ticker if company-specific"
|
||||
)
|
||||
period: dict[str, str] = Field(
|
||||
..., description="Date range with start and end keys"
|
||||
)
|
||||
articles: list[ArticleData] = Field(..., description="News articles")
|
||||
sentiment_summary: SentimentScore = Field(
|
||||
..., description="Overall sentiment across articles"
|
||||
)
|
||||
article_count: int = Field(..., description="Total number of articles")
|
||||
sources: list[str] = Field(default_factory=list, description="Unique news sources")
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Context metadata including data quality and coverage info",
|
||||
)
|
||||
|
||||
@validator("article_count", pre=True, always=True)
|
||||
def set_article_count(cls, v, values):
|
||||
if v is not None:
|
||||
return v
|
||||
return len(values.get("articles", []))
|
||||
|
||||
@validator("sources", pre=True, always=True)
|
||||
def set_sources(cls, v, values):
|
||||
if v:
|
||||
return v
|
||||
articles = values.get("articles", [])
|
||||
return list({article.source for article in articles})
|
||||
|
||||
|
||||
class SocialContext(BaseModel):
|
||||
"""Social media context with posts and engagement metrics."""
|
||||
|
||||
symbol: str | None = Field(
|
||||
default=None, description="Stock ticker if company-specific"
|
||||
)
|
||||
period: dict[str, str] = Field(
|
||||
..., description="Date range with start and end keys"
|
||||
)
|
||||
posts: list[PostData] = Field(..., description="Social media posts")
|
||||
engagement_metrics: dict[str, float] = Field(
|
||||
default_factory=dict, description="Aggregated engagement metrics"
|
||||
)
|
||||
sentiment_summary: SentimentScore = Field(
|
||||
..., description="Overall sentiment across posts"
|
||||
)
|
||||
post_count: int = Field(..., description="Total number of posts")
|
||||
platforms: list[str] = Field(
|
||||
default_factory=list, description="Social media platforms"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Context metadata including data quality and platform info",
|
||||
)
|
||||
|
||||
@validator("post_count", pre=True, always=True)
|
||||
def set_post_count(cls, v, values):
|
||||
if v is not None:
|
||||
return v
|
||||
return len(values.get("posts", []))
|
||||
|
||||
@property
|
||||
def platform(self) -> str | None:
|
||||
"""Primary platform for backward compatibility."""
|
||||
return self.platforms[0] if self.platforms else None
|
||||
|
||||
|
||||
class FundamentalContext(BaseModel):
|
||||
"""Fundamental analysis context with financial statements."""
|
||||
|
||||
symbol: str = Field(..., description="Stock ticker symbol")
|
||||
period: dict[str, str] = Field(
|
||||
..., description="Date range with start and end keys"
|
||||
)
|
||||
balance_sheet: FinancialStatement | None = Field(
|
||||
default=None, description="Balance sheet data"
|
||||
)
|
||||
income_statement: FinancialStatement | None = Field(
|
||||
default=None, description="Income statement data"
|
||||
)
|
||||
cash_flow: FinancialStatement | None = Field(
|
||||
default=None, description="Cash flow statement data"
|
||||
)
|
||||
key_ratios: dict[str, float] = Field(
|
||||
default_factory=dict, description="Calculated financial ratios"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Context metadata including data quality and completeness",
|
||||
)
|
||||
|
||||
|
||||
class InsiderContext(BaseModel):
|
||||
"""Insider trading context with transaction data and sentiment."""
|
||||
|
||||
symbol: str = Field(..., description="Stock ticker symbol")
|
||||
period: dict[str, str] = Field(
|
||||
..., description="Date range with start and end keys"
|
||||
)
|
||||
transactions: list[InsiderTransaction] = Field(
|
||||
..., description="Insider transactions"
|
||||
)
|
||||
sentiment_data: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Insider sentiment metrics"
|
||||
)
|
||||
transaction_count: int = Field(..., description="Total number of transactions")
|
||||
net_activity: dict[str, float] = Field(
|
||||
default_factory=dict, description="Net buying/selling activity metrics"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Context metadata including data quality and coverage",
|
||||
)
|
||||
|
||||
@validator("transaction_count", pre=True, always=True)
|
||||
def set_transaction_count(cls, v, values):
|
||||
if v is not None:
|
||||
return v
|
||||
return len(values.get("transactions", []))
|
||||
|
||||
|
||||
# Base context for extensibility
|
||||
class BaseContext(BaseModel):
|
||||
"""Base context model for common fields."""
|
||||
|
||||
period: dict[str, str] = Field(
|
||||
..., description="Date range with start and end keys"
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=lambda: {
|
||||
"data_quality": DataQuality.MEDIUM,
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"source": "unknown",
|
||||
},
|
||||
description="Context metadata",
|
||||
)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True # Serialize enums as values
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
"""
|
||||
Repository classes for historical data access in TradingAgents.
|
||||
"""
|
||||
|
||||
from .fundamental_repository import FundamentalDataRepository
|
||||
from .insider_repository import InsiderDataRepository
|
||||
from .llm_repository import LLMRepository
|
||||
from .market_data_repository import MarketDataRepository
|
||||
from .news_repository import NewsRepository
|
||||
from .openai_repository import OpenAIRepository
|
||||
from .social_repository import SocialRepository
|
||||
|
||||
__all__ = [
|
||||
"MarketDataRepository",
|
||||
"NewsRepository",
|
||||
"SocialRepository",
|
||||
"FundamentalDataRepository",
|
||||
"InsiderDataRepository",
|
||||
"OpenAIRepository",
|
||||
"LLMRepository",
|
||||
]
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
"""
|
||||
Base repository class with common utilities.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class BaseRepository:
|
||||
"""Base repository class with shared utility methods."""
|
||||
|
||||
def _ensure_path_exists(self, path: Path) -> None:
|
||||
"""
|
||||
Ensure a directory path exists, creating it if necessary.
|
||||
|
||||
Args:
|
||||
path: Path to ensure exists (can be file path - will create parent dirs)
|
||||
"""
|
||||
if path.suffix: # It's a file path, create parent directories
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
else: # It's a directory path, create the directory itself
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
"""
|
||||
Service classes for TradingAgents that return Pydantic context objects.
|
||||
"""
|
||||
|
||||
from .base import BaseService
|
||||
from .market_data_service import MarketDataService
|
||||
from .news_service import NewsService
|
||||
|
||||
__all__ = [
|
||||
"BaseService",
|
||||
"MarketDataService",
|
||||
"NewsService",
|
||||
]
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
"""
|
||||
Base service class for TradingAgents services.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseService:
|
||||
"""Base service class with common functionality."""
|
||||
|
||||
def __init__(self, online_mode: bool = True, data_dir: str = "data", **kwargs):
|
||||
"""Initialize base service.
|
||||
|
||||
Args:
|
||||
online_mode: Whether to use live APIs or cached data only
|
||||
data_dir: Directory for data storage
|
||||
"""
|
||||
self.online_mode = online_mode
|
||||
self.data_dir = data_dir
|
||||
|
||||
def is_online(self) -> bool:
|
||||
"""Check if service is in online mode."""
|
||||
return self.online_mode
|
||||
|
||||
def set_online_mode(self, online: bool) -> None:
|
||||
"""Set online mode for the service."""
|
||||
self.online_mode = online
|
||||
|
|
@ -1,364 +0,0 @@
|
|||
"""
|
||||
Simple builder functions for dependency injection in TradingAgents services.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from tradingagents.clients import (
|
||||
FinnhubClient,
|
||||
GoogleNewsClient,
|
||||
RedditClient,
|
||||
YFinanceClient,
|
||||
)
|
||||
from tradingagents.config import TradingAgentsConfig
|
||||
from tradingagents.repositories import (
|
||||
FundamentalDataRepository,
|
||||
InsiderDataRepository,
|
||||
MarketDataRepository,
|
||||
NewsRepository,
|
||||
OpenAIRepository,
|
||||
SocialRepository,
|
||||
)
|
||||
|
||||
from .fundamental_data_service import FundamentalDataService
|
||||
from .insider_data_service import InsiderDataService
|
||||
from .market_data_service import MarketDataService
|
||||
from .news_service import NewsService
|
||||
from .openai_data_service import OpenAIDataService
|
||||
from .social_media_service import SocialMediaService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_market_data_service(config: TradingAgentsConfig) -> MarketDataService:
|
||||
"""
|
||||
Build MarketDataService with appropriate client and repository.
|
||||
|
||||
Args:
|
||||
config: TradingAgents configuration
|
||||
|
||||
Returns:
|
||||
MarketDataService: Configured service
|
||||
"""
|
||||
client = None
|
||||
repository = None
|
||||
|
||||
# Create client for online mode
|
||||
if config.online_tools:
|
||||
try:
|
||||
client = YFinanceClient()
|
||||
logger.info("Created YFinanceClient for live data")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create YFinanceClient: {e}")
|
||||
|
||||
# Always create repository for fallback/offline mode
|
||||
try:
|
||||
repository = MarketDataRepository(
|
||||
data_dir=config.data_dir, cache_dir=config.data_cache_dir
|
||||
)
|
||||
logger.info(f"Created MarketDataRepository with data_dir: {config.data_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create MarketDataRepository: {e}")
|
||||
|
||||
return MarketDataService(
|
||||
client=client,
|
||||
repository=repository,
|
||||
online_mode=config.online_tools,
|
||||
data_dir=config.data_dir,
|
||||
)
|
||||
|
||||
|
||||
def build_news_service(config: TradingAgentsConfig) -> NewsService:
|
||||
"""
|
||||
Build NewsService with appropriate clients and repository.
|
||||
|
||||
Args:
|
||||
config: TradingAgents configuration
|
||||
|
||||
Returns:
|
||||
NewsService: Configured service
|
||||
"""
|
||||
finnhub_client = None
|
||||
google_client = None
|
||||
repository = None
|
||||
|
||||
# Create clients for online mode
|
||||
if config.online_tools:
|
||||
# Finnhub client
|
||||
if config.finnhub_api_key:
|
||||
try:
|
||||
finnhub_client = FinnhubClient(config.finnhub_api_key)
|
||||
logger.info("Created FinnhubClient for live news data")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create FinnhubClient: {e}")
|
||||
else:
|
||||
logger.info("No Finnhub API key provided, skipping FinnhubClient")
|
||||
|
||||
# Google News client
|
||||
try:
|
||||
google_client = GoogleNewsClient()
|
||||
logger.info("Created GoogleNewsClient for live news data")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create GoogleNewsClient: {e}")
|
||||
|
||||
# Always create repository for fallback/offline mode
|
||||
try:
|
||||
repository = NewsRepository(
|
||||
data_dir=config.data_dir, cache_dir=config.data_cache_dir
|
||||
)
|
||||
logger.info(f"Created NewsRepository with data_dir: {config.data_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create NewsRepository: {e}")
|
||||
|
||||
return NewsService(
|
||||
finnhub_client=finnhub_client,
|
||||
google_client=google_client,
|
||||
repository=repository,
|
||||
online_mode=config.online_tools,
|
||||
data_dir=config.data_dir,
|
||||
)
|
||||
|
||||
|
||||
def build_social_media_service(config: TradingAgentsConfig) -> SocialMediaService:
|
||||
"""
|
||||
Build SocialMediaService with appropriate client and repository.
|
||||
|
||||
Args:
|
||||
config: TradingAgents configuration
|
||||
|
||||
Returns:
|
||||
SocialMediaService: Configured service
|
||||
"""
|
||||
client = None
|
||||
repository = None
|
||||
|
||||
# Create client for online mode
|
||||
if config.online_tools:
|
||||
# Reddit client
|
||||
if config.reddit_client_id and config.reddit_client_secret:
|
||||
try:
|
||||
client = RedditClient(
|
||||
client_id=config.reddit_client_id,
|
||||
client_secret=config.reddit_client_secret,
|
||||
user_agent=config.reddit_user_agent,
|
||||
)
|
||||
logger.info("Created RedditClient for live social media data")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create RedditClient: {e}")
|
||||
else:
|
||||
logger.info("No Reddit credentials provided, skipping RedditClient")
|
||||
|
||||
# Always create repository for fallback/offline mode
|
||||
try:
|
||||
repository = SocialRepository(config.data_dir)
|
||||
logger.info(f"Created SocialRepository with data_dir: {config.data_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create SocialRepository: {e}")
|
||||
|
||||
return SocialMediaService(
|
||||
client=client,
|
||||
repository=repository,
|
||||
online_mode=config.online_tools,
|
||||
data_dir=config.data_dir,
|
||||
)
|
||||
|
||||
|
||||
def build_fundamental_service(config: TradingAgentsConfig) -> FundamentalDataService:
|
||||
"""
|
||||
Build FundamentalDataService with appropriate client and repository.
|
||||
|
||||
Args:
|
||||
config: TradingAgents configuration
|
||||
|
||||
Returns:
|
||||
FundamentalDataService: Configured service
|
||||
"""
|
||||
client = None
|
||||
repository = None
|
||||
|
||||
# Create client for online mode
|
||||
if config.online_tools:
|
||||
# SimFin client (would be implemented when SimFinClient is available)
|
||||
# try:
|
||||
# client = SimFinClient() # This would need API key configuration
|
||||
# logger.info("Created SimFinClient for live fundamental data")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to create SimFinClient: {e}")
|
||||
logger.info("SimFinClient not yet implemented, using repository only")
|
||||
|
||||
# Always create repository for fallback/offline mode
|
||||
try:
|
||||
repository = FundamentalDataRepository(config.data_dir)
|
||||
logger.info(
|
||||
f"Created FundamentalDataRepository with data_dir: {config.data_dir}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create FundamentalDataRepository: {e}")
|
||||
|
||||
return FundamentalDataService(
|
||||
simfin_client=client,
|
||||
repository=repository,
|
||||
online_mode=config.online_tools,
|
||||
data_dir=config.data_dir,
|
||||
)
|
||||
|
||||
|
||||
def build_insider_service(config: TradingAgentsConfig) -> InsiderDataService:
|
||||
"""
|
||||
Build InsiderDataService with appropriate client and repository.
|
||||
|
||||
Args:
|
||||
config: TradingAgents configuration
|
||||
|
||||
Returns:
|
||||
InsiderDataService: Configured service
|
||||
"""
|
||||
client = None
|
||||
repository = None
|
||||
|
||||
# Create client for online mode
|
||||
if config.online_tools:
|
||||
# Finnhub client for insider data
|
||||
if config.finnhub_api_key:
|
||||
try:
|
||||
client = FinnhubClient(config.finnhub_api_key)
|
||||
logger.info("Created FinnhubClient for live insider trading data")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create FinnhubClient for insider data: {e}")
|
||||
else:
|
||||
logger.info("No Finnhub API key provided, skipping insider data client")
|
||||
|
||||
# Always create repository for fallback/offline mode
|
||||
try:
|
||||
repository = InsiderDataRepository(config.data_dir)
|
||||
logger.info(f"Created InsiderDataRepository with data_dir: {config.data_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create InsiderDataRepository: {e}")
|
||||
|
||||
return InsiderDataService(
|
||||
finnhub_client=client,
|
||||
repository=repository,
|
||||
online_mode=config.online_tools,
|
||||
data_dir=config.data_dir,
|
||||
)
|
||||
|
||||
|
||||
def build_openai_service(config: TradingAgentsConfig) -> OpenAIDataService:
|
||||
"""
|
||||
Build OpenAIDataService with appropriate client and repository.
|
||||
|
||||
Args:
|
||||
config: TradingAgents configuration
|
||||
|
||||
Returns:
|
||||
OpenAIDataService: Configured service
|
||||
"""
|
||||
client = None
|
||||
repository = None
|
||||
|
||||
# Create client for online mode
|
||||
if config.online_tools:
|
||||
# OpenAI client (would be implemented when OpenAIClient is available)
|
||||
# if config.openai_api_key:
|
||||
# try:
|
||||
# client = OpenAIClient(api_key=config.openai_api_key)
|
||||
# logger.info("Created OpenAIClient for AI-powered analysis")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to create OpenAIClient: {e}")
|
||||
# else:
|
||||
# logger.info("No OpenAI API key provided, skipping OpenAI client")
|
||||
logger.info("OpenAIClient not yet implemented, using repository only")
|
||||
|
||||
# Always create repository for fallback/offline mode
|
||||
try:
|
||||
repository = OpenAIRepository(config.data_dir)
|
||||
logger.info(f"Created OpenAIRepository with data_dir: {config.data_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create OpenAIRepository: {e}")
|
||||
|
||||
return OpenAIDataService(
|
||||
openai_client=client,
|
||||
repository=repository,
|
||||
online_mode=config.online_tools,
|
||||
data_dir=config.data_dir,
|
||||
)
|
||||
|
||||
|
||||
def build_all_services(config: TradingAgentsConfig) -> dict:
|
||||
"""
|
||||
Build all available services.
|
||||
|
||||
Args:
|
||||
config: TradingAgents configuration
|
||||
|
||||
Returns:
|
||||
dict: Dictionary of service name to service instance
|
||||
"""
|
||||
services = {}
|
||||
|
||||
# Build MarketDataService
|
||||
try:
|
||||
services["market_data"] = build_market_data_service(config)
|
||||
logger.info("Built MarketDataService")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to build MarketDataService: {e}")
|
||||
|
||||
# Build NewsService
|
||||
try:
|
||||
services["news"] = build_news_service(config)
|
||||
logger.info("Built NewsService")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to build NewsService: {e}")
|
||||
|
||||
# Build SocialMediaService
|
||||
try:
|
||||
services["social_media"] = build_social_media_service(config)
|
||||
logger.info("Built SocialMediaService")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to build SocialMediaService: {e}")
|
||||
|
||||
# Build FundamentalDataService
|
||||
try:
|
||||
services["fundamental"] = build_fundamental_service(config)
|
||||
logger.info("Built FundamentalDataService")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to build FundamentalDataService: {e}")
|
||||
|
||||
# Build InsiderDataService
|
||||
try:
|
||||
services["insider"] = build_insider_service(config)
|
||||
logger.info("Built InsiderDataService")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to build InsiderDataService: {e}")
|
||||
|
||||
# Build OpenAIDataService
|
||||
try:
|
||||
services["openai"] = build_openai_service(config)
|
||||
logger.info("Built OpenAIDataService")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to build OpenAIDataService: {e}")
|
||||
|
||||
logger.info(f"Built {len(services)} services: {list(services.keys())}")
|
||||
return services
|
||||
|
||||
|
||||
def build_toolkit_services(config: TradingAgentsConfig) -> dict:
|
||||
"""
|
||||
Build services specifically configured for Toolkit usage.
|
||||
|
||||
Args:
|
||||
config: TradingAgents configuration
|
||||
|
||||
Returns:
|
||||
dict: Dictionary of services for Toolkit
|
||||
"""
|
||||
return build_all_services(config)
|
||||
|
||||
|
||||
# Aliases for the service toolkit
|
||||
create_market_data_service = build_market_data_service
|
||||
create_news_service = build_news_service
|
||||
create_social_media_service = build_social_media_service
|
||||
create_fundamental_data_service = build_fundamental_service
|
||||
create_insider_data_service = build_insider_service
|
||||
create_openai_data_service = build_openai_service
|
||||
|
|
@ -1,692 +0,0 @@
|
|||
"""
|
||||
Fundamental Data Service for aggregating and analyzing financial statement data.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import date, datetime
|
||||
from typing import Any
|
||||
|
||||
from tradingagents.clients import FinnhubClient
|
||||
from tradingagents.models.context import (
|
||||
DataQuality,
|
||||
FinancialStatement,
|
||||
FundamentalContext,
|
||||
)
|
||||
from tradingagents.repositories.fundamental_repository import FundamentalDataRepository
|
||||
from tradingagents.services.base import BaseService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FundamentalDataService(BaseService):
|
||||
"""Service for fundamental financial data aggregation and analysis."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
finnhub_client: FinnhubClient,
|
||||
repository: FundamentalDataRepository,
|
||||
data_dir: str = "data",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Fundamental Data Service.
|
||||
|
||||
Args:
|
||||
finnhub_client: Client for Finnhub/financial API access
|
||||
repository: Repository for cached fundamental data
|
||||
data_dir: Directory for data storage
|
||||
"""
|
||||
super().__init__(online_mode=True, data_dir=data_dir, **kwargs)
|
||||
self.finnhub_client = finnhub_client
|
||||
self.repository = repository
|
||||
|
||||
def get_fundamental_context(
|
||||
self,
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
frequency: str = "quarterly",
|
||||
force_refresh: bool = False,
|
||||
**kwargs,
|
||||
) -> FundamentalContext:
|
||||
"""Get fundamental analysis context for a company.
|
||||
|
||||
Args:
|
||||
symbol: Stock ticker symbol
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
frequency: Reporting frequency ('quarterly' or 'annual')
|
||||
force_refresh: If True, skip local data and fetch fresh from APIs
|
||||
|
||||
Returns:
|
||||
FundamentalContext with financial statements and key ratios
|
||||
"""
|
||||
# Validate date strings first
|
||||
try:
|
||||
start_dt = date.fromisoformat(start_date)
|
||||
end_dt = date.fromisoformat(end_date)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid date format: {e}")
|
||||
|
||||
# Check date order
|
||||
if end_dt < start_dt:
|
||||
raise ValueError(f"End date {end_date} is before start date {start_date}")
|
||||
|
||||
balance_sheet = None
|
||||
income_statement = None
|
||||
cash_flow = None
|
||||
error_info = {}
|
||||
errors = []
|
||||
data_source = "unknown"
|
||||
|
||||
try:
|
||||
# Local-first data strategy with force refresh option
|
||||
if force_refresh:
|
||||
# Skip local data, fetch fresh from APIs
|
||||
balance_sheet, income_statement, cash_flow, data_source = (
|
||||
self._fetch_and_cache_fresh_fundamental_data(
|
||||
symbol, start_date, end_date, frequency
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Check local data first, fetch missing if needed
|
||||
balance_sheet, income_statement, cash_flow, data_source = (
|
||||
self._get_fundamental_data_local_first(
|
||||
symbol, start_date, end_date, frequency
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching fundamental data: {e}")
|
||||
errors.append(str(e))
|
||||
|
||||
# Add error info if there were any errors
|
||||
if errors:
|
||||
error_info = {"error": "; ".join(errors)}
|
||||
|
||||
# Calculate key financial ratios
|
||||
key_ratios = self._calculate_key_ratios(
|
||||
balance_sheet, income_statement, cash_flow
|
||||
)
|
||||
|
||||
# Determine data quality based on data source
|
||||
data_quality = self._determine_data_quality(
|
||||
data_source=data_source,
|
||||
statement_count=sum(
|
||||
[
|
||||
balance_sheet is not None,
|
||||
income_statement is not None,
|
||||
cash_flow is not None,
|
||||
]
|
||||
),
|
||||
has_errors=bool(errors),
|
||||
)
|
||||
|
||||
# Handle partial data scenarios gracefully
|
||||
context = self._handle_partial_statements(
|
||||
symbol=symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
frequency=frequency,
|
||||
balance_sheet=balance_sheet,
|
||||
income_statement=income_statement,
|
||||
cash_flow=cash_flow,
|
||||
key_ratios=key_ratios,
|
||||
data_quality=data_quality,
|
||||
data_source=data_source,
|
||||
force_refresh=force_refresh,
|
||||
error_info=error_info,
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
def get_context(
|
||||
self,
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
frequency: str = "quarterly",
|
||||
**kwargs,
|
||||
) -> FundamentalContext:
|
||||
"""Alias for get_fundamental_context for consistency with other services."""
|
||||
return self.get_fundamental_context(
|
||||
symbol=symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
frequency=frequency,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _get_balance_sheet(
|
||||
self, symbol: str, frequency: str, report_date: date
|
||||
) -> FinancialStatement | None:
|
||||
"""Get balance sheet data from client."""
|
||||
try:
|
||||
data = self.finnhub_client.get_balance_sheet(symbol, frequency, report_date)
|
||||
return self._convert_to_financial_statement(data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get balance sheet for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _get_income_statement(
|
||||
self, symbol: str, frequency: str, report_date: date
|
||||
) -> FinancialStatement | None:
|
||||
"""Get income statement data from client."""
|
||||
try:
|
||||
data = self.finnhub_client.get_income_statement(
|
||||
symbol, frequency, report_date
|
||||
)
|
||||
return self._convert_to_financial_statement(data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get income statement for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _get_cash_flow(
|
||||
self, symbol: str, frequency: str, report_date: date
|
||||
) -> FinancialStatement | None:
|
||||
"""Get cash flow statement data from client."""
|
||||
try:
|
||||
data = self.finnhub_client.get_cash_flow(symbol, frequency, report_date)
|
||||
return self._convert_to_financial_statement(data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get cash flow for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _convert_to_financial_statement(
|
||||
self, data: dict[str, Any]
|
||||
) -> FinancialStatement | None:
|
||||
"""Convert raw financial data to FinancialStatement object."""
|
||||
if not data or "data" not in data or not data["data"]:
|
||||
return None
|
||||
|
||||
try:
|
||||
return FinancialStatement(
|
||||
period=data.get("period", "Unknown"),
|
||||
report_date=data.get("report_date", ""),
|
||||
publish_date=data.get("publish_date", ""),
|
||||
currency=data.get("currency", "USD"),
|
||||
data=data["data"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert financial statement: {e}")
|
||||
return None
|
||||
|
||||
def _parse_cached_statements(self, cached_data: dict[str, Any]) -> tuple:
|
||||
"""Parse cached repository data into financial statements."""
|
||||
balance_sheet = None
|
||||
income_statement = None
|
||||
cash_flow = None
|
||||
|
||||
if cached_data and "financial_statements" in cached_data:
|
||||
statements = cached_data["financial_statements"]
|
||||
|
||||
if "balance_sheet" in statements:
|
||||
balance_sheet = FinancialStatement(**statements["balance_sheet"])
|
||||
if "income_statement" in statements:
|
||||
income_statement = FinancialStatement(**statements["income_statement"])
|
||||
if "cash_flow" in statements:
|
||||
cash_flow = FinancialStatement(**statements["cash_flow"])
|
||||
|
||||
return balance_sheet, income_statement, cash_flow
|
||||
|
||||
def _get_fundamental_data_local_first(
|
||||
self, symbol: str, start_date: str, end_date: str, frequency: str
|
||||
) -> tuple[
|
||||
FinancialStatement | None,
|
||||
FinancialStatement | None,
|
||||
FinancialStatement | None,
|
||||
str,
|
||||
]:
|
||||
"""Get fundamental data using local-first strategy: check local data first, fetch missing if needed."""
|
||||
try:
|
||||
# Check if we have sufficient local data
|
||||
if self.repository.has_data_for_period(
|
||||
symbol, start_date, end_date, frequency=frequency
|
||||
):
|
||||
logger.info(
|
||||
f"Using local fundamental data for {symbol} ({start_date} to {end_date})"
|
||||
)
|
||||
cached_data = self.repository.get_data(
|
||||
symbol=symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
frequency=frequency,
|
||||
)
|
||||
balance_sheet, income_statement, cash_flow = (
|
||||
self._parse_cached_statements(cached_data)
|
||||
)
|
||||
return balance_sheet, income_statement, cash_flow, "local_cache"
|
||||
|
||||
# We don't have sufficient local data - need to fetch from APIs
|
||||
logger.info(
|
||||
f"Local data insufficient, fetching from APIs for {symbol} ({start_date} to {end_date})"
|
||||
)
|
||||
balance_sheet, income_statement, cash_flow, _ = (
|
||||
self._fetch_fresh_fundamental_data(
|
||||
symbol, start_date, end_date, frequency
|
||||
)
|
||||
)
|
||||
|
||||
# Cache the fresh data
|
||||
if any([balance_sheet, income_statement, cash_flow]):
|
||||
try:
|
||||
cache_data = {
|
||||
"symbol": symbol,
|
||||
"frequency": frequency,
|
||||
"financial_statements": {},
|
||||
"metadata": {"cached_at": datetime.utcnow().isoformat()},
|
||||
}
|
||||
|
||||
if balance_sheet:
|
||||
cache_data["financial_statements"]["balance_sheet"] = (
|
||||
balance_sheet.model_dump()
|
||||
)
|
||||
if income_statement:
|
||||
cache_data["financial_statements"]["income_statement"] = (
|
||||
income_statement.model_dump()
|
||||
)
|
||||
if cash_flow:
|
||||
cache_data["financial_statements"]["cash_flow"] = (
|
||||
cash_flow.model_dump()
|
||||
)
|
||||
|
||||
self.repository.store_data(symbol, cache_data, frequency=frequency)
|
||||
logger.debug(f"Cached fresh fundamental data for {symbol}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to cache fundamental data for {symbol}: {e}"
|
||||
)
|
||||
|
||||
return balance_sheet, income_statement, cash_flow, "live_api"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching fundamental data for {symbol}: {e}")
|
||||
return None, None, None, "error"
|
||||
|
||||
def _fetch_and_cache_fresh_fundamental_data(
|
||||
self, symbol: str, start_date: str, end_date: str, frequency: str
|
||||
) -> tuple[
|
||||
FinancialStatement | None,
|
||||
FinancialStatement | None,
|
||||
FinancialStatement | None,
|
||||
str,
|
||||
]:
|
||||
"""Force fetch fresh fundamental data from APIs and cache it, bypassing local data."""
|
||||
try:
|
||||
logger.info(
|
||||
f"Force refreshing fundamental data from APIs for {symbol} ({start_date} to {end_date})"
|
||||
)
|
||||
|
||||
# Clear existing data
|
||||
try:
|
||||
self.repository.clear_data(
|
||||
symbol, start_date, end_date, frequency=frequency
|
||||
)
|
||||
logger.debug(f"Cleared existing fundamental data for {symbol}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to clear existing fundamental data for {symbol}: {e}"
|
||||
)
|
||||
|
||||
# Fetch fresh data
|
||||
balance_sheet, income_statement, cash_flow, _ = (
|
||||
self._fetch_fresh_fundamental_data(
|
||||
symbol, start_date, end_date, frequency
|
||||
)
|
||||
)
|
||||
|
||||
# Cache the fresh data
|
||||
if any([balance_sheet, income_statement, cash_flow]):
|
||||
try:
|
||||
cache_data = {
|
||||
"symbol": symbol,
|
||||
"frequency": frequency,
|
||||
"financial_statements": {},
|
||||
"metadata": {"refreshed_at": datetime.utcnow().isoformat()},
|
||||
}
|
||||
|
||||
if balance_sheet:
|
||||
cache_data["financial_statements"]["balance_sheet"] = (
|
||||
balance_sheet.model_dump()
|
||||
)
|
||||
if income_statement:
|
||||
cache_data["financial_statements"]["income_statement"] = (
|
||||
income_statement.model_dump()
|
||||
)
|
||||
if cash_flow:
|
||||
cache_data["financial_statements"]["cash_flow"] = (
|
||||
cash_flow.model_dump()
|
||||
)
|
||||
|
||||
self.repository.store_data(
|
||||
symbol, cache_data, frequency=frequency, overwrite=True
|
||||
)
|
||||
logger.debug(f"Cached refreshed fundamental data for {symbol}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to cache refreshed fundamental data for {symbol}: {e}"
|
||||
)
|
||||
|
||||
return balance_sheet, income_statement, cash_flow, "live_api_refresh"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error force refreshing fundamental data for {symbol}: {e}")
|
||||
return None, None, None, "refresh_error"
|
||||
|
||||
def _fetch_fresh_fundamental_data(
|
||||
self, symbol: str, start_date: str, end_date: str, frequency: str
|
||||
) -> tuple[
|
||||
FinancialStatement | None,
|
||||
FinancialStatement | None,
|
||||
FinancialStatement | None,
|
||||
str,
|
||||
]:
|
||||
"""Fetch fresh fundamental data from APIs."""
|
||||
balance_sheet = None
|
||||
income_statement = None
|
||||
cash_flow = None
|
||||
|
||||
if self.is_online() and self.finnhub_client:
|
||||
# Parse end_date string to date object for client calls
|
||||
try:
|
||||
end_date_obj = date.fromisoformat(end_date)
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid end_date format '{end_date}': {e}")
|
||||
return balance_sheet, income_statement, cash_flow, "date_error"
|
||||
|
||||
# Get financial statements from Finnhub client
|
||||
balance_sheet = self._get_balance_sheet(symbol, frequency, end_date_obj)
|
||||
income_statement = self._get_income_statement(
|
||||
symbol, frequency, end_date_obj
|
||||
)
|
||||
cash_flow = self._get_cash_flow(symbol, frequency, end_date_obj)
|
||||
|
||||
return balance_sheet, income_statement, cash_flow, "live_api"
|
||||
|
||||
def _calculate_key_ratios(
|
||||
self,
|
||||
balance_sheet: FinancialStatement | None,
|
||||
income_statement: FinancialStatement | None,
|
||||
cash_flow: FinancialStatement | None,
|
||||
) -> dict[str, float]:
|
||||
"""Calculate key financial ratios from financial statements."""
|
||||
ratios = {}
|
||||
|
||||
try:
|
||||
# Extract data from statements
|
||||
bs_data = balance_sheet.data if balance_sheet else {}
|
||||
is_data = income_statement.data if income_statement else {}
|
||||
|
||||
# Liquidity Ratios
|
||||
if (
|
||||
"Total Current Assets" in bs_data
|
||||
and "Total Current Liabilities" in bs_data
|
||||
):
|
||||
current_liabilities = bs_data["Total Current Liabilities"]
|
||||
if current_liabilities > 0:
|
||||
ratios["current_ratio"] = (
|
||||
bs_data["Total Current Assets"] / current_liabilities
|
||||
)
|
||||
|
||||
# Quick ratio (more conservative)
|
||||
if all(
|
||||
k in bs_data
|
||||
for k in [
|
||||
"Cash and Cash Equivalents",
|
||||
"Short-term Investments",
|
||||
"Accounts Receivable",
|
||||
"Total Current Liabilities",
|
||||
]
|
||||
):
|
||||
quick_assets = (
|
||||
bs_data["Cash and Cash Equivalents"]
|
||||
+ bs_data.get("Short-term Investments", 0)
|
||||
+ bs_data["Accounts Receivable"]
|
||||
)
|
||||
current_liabilities = bs_data["Total Current Liabilities"]
|
||||
if current_liabilities > 0:
|
||||
ratios["quick_ratio"] = quick_assets / current_liabilities
|
||||
|
||||
# Cash ratio
|
||||
if (
|
||||
"Cash and Cash Equivalents" in bs_data
|
||||
and "Total Current Liabilities" in bs_data
|
||||
):
|
||||
current_liabilities = bs_data["Total Current Liabilities"]
|
||||
if current_liabilities > 0:
|
||||
cash_and_equivalents = bs_data[
|
||||
"Cash and Cash Equivalents"
|
||||
] + bs_data.get("Short-term Investments", 0)
|
||||
ratios["cash_ratio"] = cash_and_equivalents / current_liabilities
|
||||
|
||||
# Leverage Ratios
|
||||
if "Long-term Debt" in bs_data and "Total Shareholders Equity" in bs_data:
|
||||
equity = bs_data["Total Shareholders Equity"]
|
||||
if equity > 0:
|
||||
ratios["debt_to_equity"] = bs_data["Long-term Debt"] / equity
|
||||
|
||||
if "Long-term Debt" in bs_data and "Total Assets" in bs_data:
|
||||
assets = bs_data["Total Assets"]
|
||||
if assets > 0:
|
||||
ratios["debt_to_assets"] = bs_data["Long-term Debt"] / assets
|
||||
|
||||
if "Total Assets" in bs_data and "Total Shareholders Equity" in bs_data:
|
||||
equity = bs_data["Total Shareholders Equity"]
|
||||
if equity > 0:
|
||||
ratios["equity_multiplier"] = bs_data["Total Assets"] / equity
|
||||
|
||||
# Profitability Ratios
|
||||
if "Total Revenue" in is_data and "Cost of Revenue" in is_data:
|
||||
revenue = is_data["Total Revenue"]
|
||||
if revenue > 0:
|
||||
ratios["gross_margin"] = (
|
||||
revenue - is_data["Cost of Revenue"]
|
||||
) / revenue
|
||||
|
||||
if "Operating Income" in is_data and "Total Revenue" in is_data:
|
||||
revenue = is_data["Total Revenue"]
|
||||
if revenue > 0:
|
||||
ratios["operating_margin"] = is_data["Operating Income"] / revenue
|
||||
|
||||
if "Net Income" in is_data and "Total Revenue" in is_data:
|
||||
revenue = is_data["Total Revenue"]
|
||||
if revenue > 0:
|
||||
ratios["net_margin"] = is_data["Net Income"] / revenue
|
||||
|
||||
# Return on Equity (ROE)
|
||||
if "Net Income" in is_data and "Total Shareholders Equity" in bs_data:
|
||||
equity = bs_data["Total Shareholders Equity"]
|
||||
if equity > 0:
|
||||
ratios["roe"] = is_data["Net Income"] / equity
|
||||
|
||||
# Return on Assets (ROA)
|
||||
if "Net Income" in is_data and "Total Assets" in bs_data:
|
||||
assets = bs_data["Total Assets"]
|
||||
if assets > 0:
|
||||
ratios["roa"] = is_data["Net Income"] / assets
|
||||
|
||||
# Efficiency Ratios
|
||||
if "Total Revenue" in is_data and "Total Assets" in bs_data:
|
||||
assets = bs_data["Total Assets"]
|
||||
if assets > 0:
|
||||
ratios["asset_turnover"] = is_data["Total Revenue"] / assets
|
||||
|
||||
# Inventory turnover
|
||||
if "Cost of Revenue" in is_data and "Inventory" in bs_data:
|
||||
inventory = bs_data["Inventory"]
|
||||
if inventory > 0:
|
||||
ratios["inventory_turnover"] = (
|
||||
is_data["Cost of Revenue"] / inventory
|
||||
)
|
||||
|
||||
# Receivables turnover
|
||||
if "Total Revenue" in is_data and "Accounts Receivable" in bs_data:
|
||||
receivables = bs_data["Accounts Receivable"]
|
||||
if receivables > 0:
|
||||
ratios["receivables_turnover"] = (
|
||||
is_data["Total Revenue"] / receivables
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating financial ratios: {e}")
|
||||
|
||||
return ratios
|
||||
|
||||
def _handle_partial_statements(
|
||||
self,
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
frequency: str,
|
||||
balance_sheet: FinancialStatement | None,
|
||||
income_statement: FinancialStatement | None,
|
||||
cash_flow: FinancialStatement | None,
|
||||
key_ratios: dict[str, float],
|
||||
data_quality: DataQuality,
|
||||
data_source: str,
|
||||
force_refresh: bool,
|
||||
error_info: dict[str, Any],
|
||||
) -> FundamentalContext:
|
||||
"""Create context even if some statements are missing.
|
||||
|
||||
- If all statements fail: Raise exception
|
||||
- If some statements succeed: Return partial context
|
||||
- Mark missing statements in metadata
|
||||
"""
|
||||
statement_count = sum(
|
||||
[
|
||||
balance_sheet is not None,
|
||||
income_statement is not None,
|
||||
cash_flow is not None,
|
||||
]
|
||||
)
|
||||
|
||||
# If all statements failed, raise exception
|
||||
if statement_count == 0 and data_source not in ["local_cache"]:
|
||||
error_msg = f"Failed to fetch any financial statements for {symbol}"
|
||||
if error_info:
|
||||
error_msg += f": {error_info.get('error', 'Unknown error')}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Create metadata with partial data information
|
||||
metadata = {
|
||||
"data_quality": data_quality,
|
||||
"service": "fundamental_data",
|
||||
"online_mode": self.is_online(),
|
||||
"frequency": frequency,
|
||||
"data_source": data_source,
|
||||
"force_refresh": force_refresh,
|
||||
"has_balance_sheet": balance_sheet is not None,
|
||||
"has_income_statement": income_statement is not None,
|
||||
"has_cash_flow": cash_flow is not None,
|
||||
"partial_data": statement_count < 3,
|
||||
"statement_count": statement_count,
|
||||
**error_info,
|
||||
}
|
||||
|
||||
return FundamentalContext(
|
||||
symbol=symbol,
|
||||
period={"start": start_date, "end": end_date},
|
||||
balance_sheet=balance_sheet,
|
||||
income_statement=income_statement,
|
||||
cash_flow=cash_flow,
|
||||
key_ratios=key_ratios,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def detect_fundamental_gaps(
|
||||
self, symbol: str, start_date: str, end_date: str, frequency: str
|
||||
) -> list[str]:
|
||||
"""
|
||||
Returns list of report dates that need fetching.
|
||||
|
||||
Example: If requesting quarterly from 2024-01-01 to 2024-12-31
|
||||
and cache has Q1 and Q3, returns ["2024-06-30", "2024-09-30", "2024-12-31"]
|
||||
|
||||
For quarterly: Check for Q1 (Mar 31), Q2 (Jun 30), Q3 (Sep 30), Q4 (Dec 31)
|
||||
For annual: Check for fiscal year ends
|
||||
"""
|
||||
try:
|
||||
start_dt = date.fromisoformat(start_date)
|
||||
end_dt = date.fromisoformat(end_date)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
f"Invalid date format in gap detection: {start_date}, {end_date}"
|
||||
)
|
||||
return []
|
||||
|
||||
# Get existing data from repository
|
||||
try:
|
||||
cached_data = self.repository.get_data(
|
||||
symbol, start_date, end_date, frequency
|
||||
)
|
||||
existing_dates = set()
|
||||
|
||||
if cached_data and "financial_statements" in cached_data:
|
||||
for statement_type in [
|
||||
"balance_sheet",
|
||||
"income_statement",
|
||||
"cash_flow",
|
||||
]:
|
||||
if statement_type in cached_data["financial_statements"]:
|
||||
stmt = cached_data["financial_statements"][statement_type]
|
||||
if "report_date" in stmt:
|
||||
existing_dates.add(stmt["report_date"])
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking cached data for gap detection: {e}")
|
||||
existing_dates = set()
|
||||
|
||||
# Calculate expected report dates based on frequency
|
||||
expected_dates = []
|
||||
current_year = start_dt.year
|
||||
end_year = end_dt.year
|
||||
|
||||
if frequency == "quarterly":
|
||||
# Standard quarterly dates: Mar 31, Jun 30, Sep 30, Dec 31
|
||||
quarter_dates = [
|
||||
(3, 31), # Q1
|
||||
(6, 30), # Q2
|
||||
(9, 30), # Q3
|
||||
(12, 31), # Q4
|
||||
]
|
||||
|
||||
for year in range(current_year, end_year + 1):
|
||||
for month, day in quarter_dates:
|
||||
report_date = date(year, month, day)
|
||||
if start_dt <= report_date <= end_dt:
|
||||
expected_dates.append(report_date.isoformat())
|
||||
|
||||
elif frequency == "annual":
|
||||
# Standard fiscal year end: Dec 31
|
||||
for year in range(current_year, end_year + 1):
|
||||
report_date = date(year, 12, 31)
|
||||
if start_dt <= report_date <= end_dt:
|
||||
expected_dates.append(report_date.isoformat())
|
||||
|
||||
# Return dates that are expected but not in cache
|
||||
missing_dates = [d for d in expected_dates if d not in existing_dates]
|
||||
|
||||
if missing_dates:
|
||||
logger.info(
|
||||
f"Gap detection for {symbol}: missing {len(missing_dates)} report periods"
|
||||
)
|
||||
|
||||
return missing_dates
|
||||
|
||||
def _determine_data_quality(
|
||||
self, data_source: str, statement_count: int, has_errors: bool = False
|
||||
) -> DataQuality:
|
||||
"""Determine data quality based on source, statement count, and errors."""
|
||||
if has_errors or statement_count == 0:
|
||||
return DataQuality.LOW
|
||||
|
||||
if data_source in ["local_cache", "error", "refresh_error"]:
|
||||
return DataQuality.LOW
|
||||
elif data_source in ["live_api", "live_api_refresh"]:
|
||||
if statement_count == 3:
|
||||
return DataQuality.HIGH # All three statements available
|
||||
elif statement_count == 2:
|
||||
return DataQuality.MEDIUM # Two statements available
|
||||
else:
|
||||
return DataQuality.LOW # One or no statements
|
||||
else:
|
||||
return DataQuality.MEDIUM
|
||||
|
|
@ -1,346 +0,0 @@
|
|||
"""
|
||||
Market data service that provides structured market context.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from tradingagents.clients.base import BaseClient
|
||||
from tradingagents.dataflows.stockstats_utils import StockstatsUtils
|
||||
from tradingagents.models.context import (
|
||||
MarketDataContext,
|
||||
TechnicalIndicatorData,
|
||||
)
|
||||
from tradingagents.repositories.base import BaseRepository
|
||||
|
||||
from .base import BaseService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarketDataService(BaseService):
|
||||
"""Service for market data and technical indicators."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: BaseClient | None = None,
|
||||
repository: BaseRepository | None = None,
|
||||
online_mode: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize market data service.
|
||||
|
||||
Args:
|
||||
client: Client for live market data
|
||||
repository: Repository for historical market data
|
||||
online_mode: Whether to use live data
|
||||
**kwargs: Additional configuration
|
||||
"""
|
||||
super().__init__(online_mode, **kwargs)
|
||||
self.client = client
|
||||
self.repository = repository
|
||||
self.stockstats_utils = StockstatsUtils()
|
||||
|
||||
def get_context(
|
||||
self,
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
indicators: list[str] | None = None,
|
||||
force_refresh: bool = False,
|
||||
**kwargs,
|
||||
) -> MarketDataContext:
|
||||
"""
|
||||
Get market data context with price data and technical indicators.
|
||||
|
||||
Args:
|
||||
symbol: Stock ticker symbol
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
indicators: List of technical indicators to calculate
|
||||
force_refresh: If True, skip local data and fetch fresh from API
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
MarketDataContext: Structured market data context
|
||||
"""
|
||||
if indicators is None:
|
||||
indicators = ["rsi", "macd", "close_50_sma"]
|
||||
|
||||
# Local-first data strategy with force refresh option
|
||||
if force_refresh:
|
||||
# Skip local data, fetch fresh from API
|
||||
price_data = self._fetch_and_cache_fresh_data(symbol, start_date, end_date)
|
||||
data_source = "live_api_refresh"
|
||||
else:
|
||||
# Check local data first, fetch missing if needed
|
||||
price_data = self._get_price_data_local_first(symbol, start_date, end_date)
|
||||
data_source = price_data.get("metadata", {}).get("source", "unknown")
|
||||
|
||||
# Calculate technical indicators
|
||||
technical_indicators = self._calculate_indicators(
|
||||
symbol, start_date, end_date, indicators
|
||||
)
|
||||
|
||||
# Determine data quality
|
||||
data_quality = self._determine_data_quality(
|
||||
data_source=data_source,
|
||||
record_count=len(price_data.get("data", [])),
|
||||
has_errors="error" in price_data.get("metadata", {}),
|
||||
)
|
||||
|
||||
# Create metadata
|
||||
metadata = self._create_base_metadata(
|
||||
data_quality=data_quality,
|
||||
price_data_source=data_source,
|
||||
indicator_count=len(technical_indicators),
|
||||
symbol=symbol,
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
|
||||
return MarketDataContext(
|
||||
symbol=symbol,
|
||||
period={"start": start_date, "end": end_date},
|
||||
price_data=price_data.get("data", []),
|
||||
technical_indicators=technical_indicators,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def get_price_context(
|
||||
self, symbol: str, start_date: str, end_date: str, **kwargs
|
||||
) -> MarketDataContext:
|
||||
"""
|
||||
Get market data context with just price data (no indicators).
|
||||
|
||||
Args:
|
||||
symbol: Stock ticker symbol
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
MarketDataContext: Market context with price data only
|
||||
"""
|
||||
return self.get_context(symbol, start_date, end_date, indicators=[], **kwargs)
|
||||
|
||||
def _get_price_data_local_first(
|
||||
self, symbol: str, start_date: str, end_date: str
|
||||
) -> dict[str, Any]:
|
||||
"""Get price data using local-first strategy: check local data first, fetch missing if needed."""
|
||||
try:
|
||||
# Check if we have sufficient local data
|
||||
if self.repository and self.repository.has_data_for_period(
|
||||
symbol, start_date, end_date
|
||||
):
|
||||
logger.info(
|
||||
f"Using local data for {symbol} ({start_date} to {end_date})"
|
||||
)
|
||||
local_data = self.repository.get_data(
|
||||
symbol=symbol, start_date=start_date, end_date=end_date
|
||||
)
|
||||
local_data["metadata"] = local_data.get("metadata", {})
|
||||
local_data["metadata"]["source"] = "local_cache"
|
||||
return local_data
|
||||
|
||||
# We don't have sufficient local data - need to fetch from API
|
||||
if self.client:
|
||||
logger.info(
|
||||
f"Local data insufficient, fetching from API for {symbol} ({start_date} to {end_date})"
|
||||
)
|
||||
fresh_data = self.client.get_data(
|
||||
symbol=symbol, start_date=start_date, end_date=end_date
|
||||
)
|
||||
|
||||
# Cache the fresh data if we have a repository
|
||||
if fresh_data and self.repository:
|
||||
try:
|
||||
self.repository.store_data(symbol, fresh_data)
|
||||
logger.debug(f"Cached fresh data for {symbol}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache data for {symbol}: {e}")
|
||||
|
||||
fresh_data["metadata"] = fresh_data.get("metadata", {})
|
||||
fresh_data["metadata"]["source"] = "live_api"
|
||||
return fresh_data
|
||||
|
||||
# No client available, try repository as fallback
|
||||
elif self.repository:
|
||||
logger.warning(
|
||||
f"No API client available, using partial local data for {symbol}"
|
||||
)
|
||||
local_data = self.repository.get_data(
|
||||
symbol=symbol, start_date=start_date, end_date=end_date
|
||||
)
|
||||
local_data["metadata"] = local_data.get("metadata", {})
|
||||
local_data["metadata"]["source"] = "local_partial"
|
||||
return local_data
|
||||
|
||||
else:
|
||||
logger.warning(f"No data source available for {symbol}")
|
||||
return {
|
||||
"symbol": symbol,
|
||||
"data": [],
|
||||
"metadata": {
|
||||
"source": "none",
|
||||
"error": "No client or repository configured",
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching price data for {symbol}: {e}")
|
||||
return {
|
||||
"symbol": symbol,
|
||||
"data": [],
|
||||
"metadata": {"source": "error", "error": str(e)},
|
||||
}
|
||||
|
||||
def _fetch_and_cache_fresh_data(
|
||||
self, symbol: str, start_date: str, end_date: str
|
||||
) -> dict[str, Any]:
|
||||
"""Force fetch fresh data from API and cache it, bypassing local data."""
|
||||
try:
|
||||
if not self.client:
|
||||
logger.warning(f"No API client available for force refresh of {symbol}")
|
||||
return {
|
||||
"symbol": symbol,
|
||||
"data": [],
|
||||
"metadata": {
|
||||
"source": "no_client",
|
||||
"error": "No API client configured for force refresh",
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Force refreshing data from API for {symbol} ({start_date} to {end_date})"
|
||||
)
|
||||
|
||||
# Clear existing data if we have a repository
|
||||
if self.repository:
|
||||
try:
|
||||
self.repository.clear_data(symbol, start_date, end_date)
|
||||
logger.debug(f"Cleared existing data for {symbol}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clear existing data for {symbol}: {e}")
|
||||
|
||||
# Fetch fresh data
|
||||
fresh_data = self.client.get_data(
|
||||
symbol=symbol, start_date=start_date, end_date=end_date
|
||||
)
|
||||
|
||||
# Cache the fresh data
|
||||
if fresh_data and self.repository:
|
||||
try:
|
||||
self.repository.store_data(symbol, fresh_data, overwrite=True)
|
||||
logger.debug(f"Cached refreshed data for {symbol}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache refreshed data for {symbol}: {e}")
|
||||
|
||||
fresh_data["metadata"] = fresh_data.get("metadata", {})
|
||||
fresh_data["metadata"]["source"] = "live_api_refresh"
|
||||
return fresh_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error force refreshing data for {symbol}: {e}")
|
||||
return {
|
||||
"symbol": symbol,
|
||||
"data": [],
|
||||
"metadata": {"source": "refresh_error", "error": str(e)},
|
||||
}
|
||||
|
||||
def _calculate_indicators(
|
||||
self, symbol: str, start_date: str, end_date: str, indicators: list[str]
|
||||
) -> dict[str, list[TechnicalIndicatorData]]:
|
||||
"""Calculate technical indicators."""
|
||||
if not indicators:
|
||||
return {}
|
||||
|
||||
technical_data = {}
|
||||
|
||||
for indicator in indicators:
|
||||
try:
|
||||
logger.info(f"Calculating {indicator} for {symbol}")
|
||||
|
||||
# Use existing stockstats utility
|
||||
indicator_data = self._get_indicator_data(
|
||||
symbol, indicator, start_date, end_date
|
||||
)
|
||||
|
||||
if indicator_data:
|
||||
technical_data[indicator] = indicator_data
|
||||
else:
|
||||
logger.warning(f"No data returned for indicator {indicator}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating {indicator} for {symbol}: {e}")
|
||||
continue
|
||||
|
||||
return technical_data
|
||||
|
||||
def _get_indicator_data(
|
||||
self, symbol: str, indicator: str, start_date: str, end_date: str
|
||||
) -> list[TechnicalIndicatorData]:
|
||||
"""Get indicator data using StockstatsUtils."""
|
||||
try:
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Get data for the date range
|
||||
current_date = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
start_date_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
|
||||
indicator_points = []
|
||||
|
||||
# Iterate through date range
|
||||
while current_date >= start_date_dt:
|
||||
date_str = current_date.strftime("%Y-%m-%d")
|
||||
|
||||
try:
|
||||
# Use stockstats utility to get indicator value
|
||||
# This assumes the existing data directory structure
|
||||
data_dir = self.config.get("data_dir", "data")
|
||||
price_data_dir = f"{data_dir}/market_data/price_data"
|
||||
|
||||
indicator_value = StockstatsUtils.get_stock_stats(
|
||||
symbol,
|
||||
indicator,
|
||||
date_str,
|
||||
price_data_dir,
|
||||
online=self.online_mode,
|
||||
)
|
||||
|
||||
if indicator_value is not None and indicator_value != "":
|
||||
# Handle different indicator value types
|
||||
if isinstance(indicator_value, int | float):
|
||||
value = float(indicator_value)
|
||||
elif isinstance(indicator_value, str):
|
||||
try:
|
||||
value = float(indicator_value)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Could not parse indicator value: {indicator_value}"
|
||||
)
|
||||
current_date -= timedelta(days=1)
|
||||
continue
|
||||
else:
|
||||
# For complex indicators like MACD, this might be a dict
|
||||
value = indicator_value
|
||||
|
||||
indicator_points.append(
|
||||
TechnicalIndicatorData(
|
||||
date=date_str, value=value, indicator_type=indicator
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Could not get {indicator} for {symbol} on {date_str}: {e}"
|
||||
)
|
||||
|
||||
current_date -= timedelta(days=1)
|
||||
|
||||
# Return in chronological order
|
||||
return list(reversed(indicator_points))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting indicator data for {indicator}: {e}")
|
||||
return []
|
||||
|
|
@ -1,577 +0,0 @@
|
|||
"""
|
||||
Social Media Service for aggregating and analyzing social media data.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from tradingagents.clients.base import BaseClient
|
||||
from tradingagents.models.context import (
|
||||
DataQuality,
|
||||
PostData,
|
||||
SentimentScore,
|
||||
SocialContext,
|
||||
)
|
||||
from tradingagents.repositories.base import BaseRepository
|
||||
from tradingagents.services.base import BaseService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SocialMediaService(BaseService):
|
||||
"""Service for social media data aggregation and analysis."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reddit_client: BaseClient | None = None,
|
||||
repository: BaseRepository | None = None,
|
||||
online_mode: bool = True,
|
||||
data_dir: str = "data",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Social Media Service.
|
||||
|
||||
Args:
|
||||
reddit_client: Client for Reddit API access
|
||||
repository: Repository for cached social data
|
||||
online_mode: Whether to fetch live data
|
||||
data_dir: Directory for data storage
|
||||
"""
|
||||
super().__init__(online_mode=online_mode, data_dir=data_dir, **kwargs)
|
||||
self.reddit_client = reddit_client
|
||||
self.repository = repository
|
||||
|
||||
def get_context(
|
||||
self,
|
||||
query: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
symbol: str | None = None,
|
||||
subreddits: list[str] | None = None,
|
||||
force_refresh: bool = False,
|
||||
**kwargs,
|
||||
) -> SocialContext:
|
||||
"""Get social media context for a query.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
symbol: Optional stock symbol
|
||||
subreddits: Optional list of subreddits to search
|
||||
force_refresh: If True, skip local data and fetch fresh from APIs
|
||||
|
||||
Returns:
|
||||
SocialContext with posts and sentiment analysis
|
||||
"""
|
||||
posts = []
|
||||
error_info = {}
|
||||
data_source = "unknown"
|
||||
|
||||
try:
|
||||
# Local-first data strategy with force refresh option
|
||||
if force_refresh:
|
||||
# Skip local data, fetch fresh from APIs
|
||||
posts, data_source = self._fetch_and_cache_fresh_social_data(
|
||||
query, start_date, end_date, symbol, subreddits
|
||||
)
|
||||
else:
|
||||
# Check local data first, fetch missing if needed
|
||||
posts, data_source = self._get_social_data_local_first(
|
||||
query, start_date, end_date, symbol, subreddits
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching social media data: {e}")
|
||||
error_info = {"error": str(e)}
|
||||
|
||||
# Calculate sentiment and engagement metrics
|
||||
sentiment_summary = self._calculate_sentiment(posts)
|
||||
engagement_metrics = self._calculate_engagement_metrics(posts)
|
||||
|
||||
# Determine data quality based on data source
|
||||
data_quality = self._determine_data_quality(
|
||||
data_source=data_source,
|
||||
record_count=len(posts),
|
||||
has_errors=bool(error_info),
|
||||
)
|
||||
|
||||
# Separate float metrics from metadata
|
||||
float_metrics = {
|
||||
k: v for k, v in engagement_metrics.items() if isinstance(v, int | float)
|
||||
}
|
||||
metadata_info = {
|
||||
k: v
|
||||
for k, v in engagement_metrics.items()
|
||||
if not isinstance(v, int | float)
|
||||
}
|
||||
|
||||
return SocialContext(
|
||||
symbol=symbol,
|
||||
period={"start": start_date, "end": end_date},
|
||||
posts=posts,
|
||||
engagement_metrics=float_metrics,
|
||||
sentiment_summary=sentiment_summary,
|
||||
post_count=len(posts),
|
||||
platforms=["reddit"],
|
||||
metadata={
|
||||
"data_quality": data_quality,
|
||||
"service": "social_media",
|
||||
"online_mode": self.is_online(),
|
||||
"subreddits": subreddits or [],
|
||||
"data_source": data_source,
|
||||
"force_refresh": force_refresh,
|
||||
**metadata_info,
|
||||
**error_info,
|
||||
},
|
||||
)
|
||||
|
||||
def get_company_social_context(
|
||||
self,
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
subreddits: list[str] | None = None,
|
||||
**kwargs,
|
||||
) -> SocialContext:
|
||||
"""Get company-specific social media context.
|
||||
|
||||
Args:
|
||||
symbol: Stock ticker symbol
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
subreddits: Optional list of subreddits
|
||||
|
||||
Returns:
|
||||
SocialContext for the company
|
||||
"""
|
||||
return self.get_context(
|
||||
query=symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
symbol=symbol,
|
||||
subreddits=subreddits,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_global_trends(
|
||||
self,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
subreddits: list[str] | None = None,
|
||||
**kwargs,
|
||||
) -> SocialContext:
|
||||
"""Get global social media trends.
|
||||
|
||||
Args:
|
||||
start_date: Start date in YYYY-MM-DD format
|
||||
end_date: End date in YYYY-MM-DD format
|
||||
subreddits: Optional list of subreddits
|
||||
|
||||
Returns:
|
||||
SocialContext with global trends
|
||||
"""
|
||||
posts = []
|
||||
|
||||
try:
|
||||
if self.is_online() and self.reddit_client:
|
||||
subreddit_list = subreddits or ["news", "worldnews", "Economics"]
|
||||
|
||||
# Get top posts from subreddits
|
||||
raw_posts = self.reddit_client.get_top_posts(
|
||||
subreddit_names=subreddit_list, limit=50, time_filter="week"
|
||||
)
|
||||
|
||||
# Filter by date
|
||||
if hasattr(self.reddit_client, "filter_posts_by_date"):
|
||||
raw_posts = self.reddit_client.filter_posts_by_date(
|
||||
raw_posts, start_date, end_date
|
||||
)
|
||||
|
||||
posts = self._convert_to_post_data(raw_posts)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching global trends: {e}")
|
||||
|
||||
sentiment_summary = self._calculate_sentiment(posts)
|
||||
engagement_metrics = self._calculate_engagement_metrics(posts)
|
||||
|
||||
# Separate float metrics from metadata
|
||||
float_metrics = {
|
||||
k: v for k, v in engagement_metrics.items() if isinstance(v, int | float)
|
||||
}
|
||||
metadata_info = {
|
||||
k: v
|
||||
for k, v in engagement_metrics.items()
|
||||
if not isinstance(v, int | float)
|
||||
}
|
||||
|
||||
return SocialContext(
|
||||
symbol=None, # No specific symbol for global trends
|
||||
period={"start": start_date, "end": end_date},
|
||||
posts=posts,
|
||||
engagement_metrics=float_metrics,
|
||||
sentiment_summary=sentiment_summary,
|
||||
post_count=len(posts),
|
||||
platforms=["reddit"],
|
||||
metadata={
|
||||
"data_quality": self._determine_data_quality(
|
||||
data_source="live_api" if self.is_online() else "offline",
|
||||
record_count=len(posts),
|
||||
has_errors=False,
|
||||
),
|
||||
"service": "social_media",
|
||||
"type": "global_trends",
|
||||
"subreddits": subreddits or [],
|
||||
**metadata_info,
|
||||
},
|
||||
)
|
||||
|
||||
def _convert_to_post_data(self, raw_posts: list[dict[str, Any]]) -> list[PostData]:
|
||||
"""Convert raw Reddit posts to PostData objects."""
|
||||
posts = []
|
||||
|
||||
for post in raw_posts:
|
||||
try:
|
||||
# Calculate engagement score
|
||||
engagement = post.get("upvotes", 0) + post.get("num_comments", 0)
|
||||
|
||||
# Get posted date
|
||||
if "posted_date" in post:
|
||||
date_str = post["posted_date"]
|
||||
elif "created_utc" in post:
|
||||
date_str = datetime.fromtimestamp(post["created_utc"]).strftime(
|
||||
"%Y-%m-%d"
|
||||
)
|
||||
else:
|
||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
post_data = PostData(
|
||||
title=post.get("title", ""),
|
||||
content=post.get("content", ""),
|
||||
author=post.get("author", "unknown"),
|
||||
source=post.get("subreddit", "reddit"),
|
||||
date=date_str,
|
||||
url=post.get("url", ""),
|
||||
score=post.get("score", 0),
|
||||
comments=post.get("num_comments", 0),
|
||||
engagement_score=engagement,
|
||||
subreddit=post.get("subreddit"),
|
||||
metadata={
|
||||
"upvotes": post.get("upvotes", 0),
|
||||
"num_comments": post.get("num_comments", 0),
|
||||
"subreddit": post.get("subreddit", ""),
|
||||
},
|
||||
)
|
||||
posts.append(post_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error converting post: {e}")
|
||||
continue
|
||||
|
||||
return posts
|
||||
|
||||
def _convert_cached_to_posts(self, cached_data: dict[str, Any]) -> list[PostData]:
|
||||
"""Convert cached repository data to PostData objects."""
|
||||
posts = []
|
||||
|
||||
if not cached_data or "posts" not in cached_data:
|
||||
return posts
|
||||
|
||||
for post in cached_data.get("posts", []):
|
||||
try:
|
||||
posts.append(PostData(**post))
|
||||
except Exception as e:
|
||||
logger.warning(f"Error converting cached post: {e}")
|
||||
|
||||
return posts
|
||||
|
||||
def _get_social_data_local_first(
|
||||
self,
|
||||
query: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
symbol: str | None,
|
||||
subreddits: list[str] | None,
|
||||
) -> tuple[list[PostData], str]:
|
||||
"""Get social data using local-first strategy: check local data first, fetch missing if needed."""
|
||||
try:
|
||||
# Check if we have sufficient local data
|
||||
search_key = symbol or query
|
||||
if self.repository and self.repository.has_data_for_period(
|
||||
search_key, start_date, end_date, symbol=symbol
|
||||
):
|
||||
logger.info(
|
||||
f"Using local social data for {search_key} ({start_date} to {end_date})"
|
||||
)
|
||||
cached_data = self.repository.get_data(
|
||||
query=search_key, start_date=start_date, end_date=end_date
|
||||
)
|
||||
posts = self._convert_cached_to_posts(cached_data)
|
||||
return posts, "local_cache"
|
||||
|
||||
# We don't have sufficient local data - need to fetch from APIs
|
||||
logger.info(
|
||||
f"Local data insufficient, fetching from APIs for {search_key} ({start_date} to {end_date})"
|
||||
)
|
||||
posts, _ = self._fetch_fresh_social_data(
|
||||
query, start_date, end_date, symbol, subreddits
|
||||
)
|
||||
|
||||
# Cache the fresh data if we have a repository
|
||||
if posts and self.repository:
|
||||
try:
|
||||
posts_data = [post.model_dump() for post in posts]
|
||||
cache_data = {
|
||||
"query": query,
|
||||
"symbol": symbol,
|
||||
"posts": posts_data,
|
||||
"subreddits": subreddits,
|
||||
"metadata": {"cached_at": datetime.utcnow().isoformat()},
|
||||
}
|
||||
self.repository.store_data(search_key, cache_data, symbol=symbol)
|
||||
logger.debug(f"Cached fresh social data for {search_key}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache social data for {search_key}: {e}")
|
||||
|
||||
return posts, "live_api"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching social data for {query}: {e}")
|
||||
return [], "error"
|
||||
|
||||
def _fetch_and_cache_fresh_social_data(
|
||||
self,
|
||||
query: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
symbol: str | None,
|
||||
subreddits: list[str] | None,
|
||||
) -> tuple[list[PostData], str]:
|
||||
"""Force fetch fresh social data from APIs and cache it, bypassing local data."""
|
||||
try:
|
||||
search_key = symbol or query
|
||||
logger.info(
|
||||
f"Force refreshing social data from APIs for {search_key} ({start_date} to {end_date})"
|
||||
)
|
||||
|
||||
# Clear existing data if we have a repository
|
||||
if self.repository:
|
||||
try:
|
||||
self.repository.clear_data(
|
||||
search_key, start_date, end_date, symbol=symbol
|
||||
)
|
||||
logger.debug(f"Cleared existing social data for {search_key}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to clear existing social data for {search_key}: {e}"
|
||||
)
|
||||
|
||||
# Fetch fresh data
|
||||
posts, _ = self._fetch_fresh_social_data(
|
||||
query, start_date, end_date, symbol, subreddits
|
||||
)
|
||||
|
||||
# Cache the fresh data
|
||||
if posts and self.repository:
|
||||
try:
|
||||
posts_data = [post.model_dump() for post in posts]
|
||||
cache_data = {
|
||||
"query": query,
|
||||
"symbol": symbol,
|
||||
"posts": posts_data,
|
||||
"subreddits": subreddits,
|
||||
"metadata": {"refreshed_at": datetime.utcnow().isoformat()},
|
||||
}
|
||||
self.repository.store_data(
|
||||
search_key, cache_data, symbol=symbol, overwrite=True
|
||||
)
|
||||
logger.debug(f"Cached refreshed social data for {search_key}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to cache refreshed social data for {search_key}: {e}"
|
||||
)
|
||||
|
||||
return posts, "live_api_refresh"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error force refreshing social data for {query}: {e}")
|
||||
return [], "refresh_error"
|
||||
|
||||
def _fetch_fresh_social_data(
|
||||
self,
|
||||
query: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
symbol: str | None,
|
||||
subreddits: list[str] | None,
|
||||
) -> tuple[list[PostData], str]:
|
||||
"""Fetch fresh social data from APIs."""
|
||||
posts = []
|
||||
|
||||
if self.is_online() and self.reddit_client:
|
||||
# Get live Reddit data
|
||||
subreddit_list = subreddits or ["investing", "stocks", "wallstreetbets"]
|
||||
|
||||
# Search for posts
|
||||
raw_posts = self.reddit_client.search_posts(
|
||||
query=query,
|
||||
subreddit_names=subreddit_list,
|
||||
limit=50,
|
||||
time_filter="week",
|
||||
)
|
||||
|
||||
# Filter by date
|
||||
if hasattr(self.reddit_client, "filter_posts_by_date"):
|
||||
raw_posts = self.reddit_client.filter_posts_by_date(
|
||||
raw_posts, start_date, end_date
|
||||
)
|
||||
|
||||
# Convert to PostData objects
|
||||
posts = self._convert_to_post_data(raw_posts)
|
||||
|
||||
return posts, "live_api"
|
||||
|
||||
def _calculate_sentiment(self, posts: list[PostData]) -> SentimentScore:
|
||||
"""Calculate overall sentiment from posts."""
|
||||
if not posts:
|
||||
return SentimentScore(score=0.0, confidence=0.0, label="neutral")
|
||||
|
||||
total_score = 0.0
|
||||
total_weight = 0.0
|
||||
|
||||
for post in posts:
|
||||
# Simple sentiment analysis based on keywords and engagement
|
||||
sentiment_score = self._analyze_post_sentiment(post)
|
||||
|
||||
# Weight by engagement
|
||||
weight = 1 + (
|
||||
post.engagement_score / 1000
|
||||
) # Higher engagement = more weight
|
||||
total_score += sentiment_score * weight
|
||||
total_weight += weight
|
||||
|
||||
# Set individual post sentiment
|
||||
post.sentiment = SentimentScore(
|
||||
score=sentiment_score,
|
||||
confidence=0.7, # Moderate confidence for keyword-based analysis
|
||||
label="positive"
|
||||
if sentiment_score > 0.2
|
||||
else "negative"
|
||||
if sentiment_score < -0.2
|
||||
else "neutral",
|
||||
)
|
||||
|
||||
# Calculate weighted average
|
||||
avg_score = total_score / total_weight if total_weight > 0 else 0.0
|
||||
|
||||
# Determine label
|
||||
if avg_score > 0.2:
|
||||
label = "positive"
|
||||
elif avg_score < -0.2:
|
||||
label = "negative"
|
||||
else:
|
||||
label = "neutral"
|
||||
|
||||
# Confidence based on number of posts
|
||||
confidence = min(0.9, 0.5 + (len(posts) / 100))
|
||||
|
||||
return SentimentScore(score=avg_score, confidence=confidence, label=label)
|
||||
|
||||
def _analyze_post_sentiment(self, post: PostData) -> float:
|
||||
"""Analyze sentiment of a single post."""
|
||||
text = f"{post.title} {post.content or ''}".lower()
|
||||
|
||||
# Simple keyword-based sentiment
|
||||
positive_words = [
|
||||
"bullish",
|
||||
"moon",
|
||||
"gains",
|
||||
"buy",
|
||||
"hold",
|
||||
"amazing",
|
||||
"great",
|
||||
"excellent",
|
||||
"positive",
|
||||
"growth",
|
||||
"beat",
|
||||
"upgrade",
|
||||
"🚀",
|
||||
]
|
||||
negative_words = [
|
||||
"bearish",
|
||||
"crash",
|
||||
"sell",
|
||||
"loss",
|
||||
"decline",
|
||||
"terrible",
|
||||
"bad",
|
||||
"negative",
|
||||
"downgrade",
|
||||
"warning",
|
||||
"overvalued",
|
||||
]
|
||||
|
||||
positive_count = sum(1 for word in positive_words if word in text)
|
||||
negative_count = sum(1 for word in negative_words if word in text)
|
||||
|
||||
# Score from -1 to 1
|
||||
if positive_count + negative_count == 0:
|
||||
return 0.0
|
||||
|
||||
score = (positive_count - negative_count) / (positive_count + negative_count)
|
||||
|
||||
# Adjust for score ratio (upvotes vs downvotes implied)
|
||||
if post.score > 0:
|
||||
score_adjustment = min(0.2, post.score / 1000)
|
||||
score = score * 0.8 + score_adjustment * 0.2
|
||||
|
||||
return max(-1.0, min(1.0, score))
|
||||
|
||||
def _calculate_engagement_metrics(self, posts: list[PostData]) -> dict[str, float]:
|
||||
"""Calculate engagement metrics from posts."""
|
||||
if not posts:
|
||||
return {
|
||||
"total_engagement": 0,
|
||||
"average_engagement": 0,
|
||||
"max_engagement": 0,
|
||||
"total_posts": 0,
|
||||
}
|
||||
|
||||
engagements = [post.engagement_score for post in posts]
|
||||
|
||||
metrics = {
|
||||
"total_engagement": sum(engagements),
|
||||
"average_engagement": sum(engagements) / len(engagements),
|
||||
"max_engagement": max(engagements),
|
||||
"total_posts": len(posts),
|
||||
}
|
||||
|
||||
# Add top posts info
|
||||
sorted_posts = sorted(posts, key=lambda p: p.engagement_score, reverse=True)
|
||||
metrics["top_posts"] = [
|
||||
{"title": p.title[:100], "engagement": p.engagement_score}
|
||||
for p in sorted_posts[:3]
|
||||
]
|
||||
|
||||
return metrics
|
||||
|
||||
def _determine_data_quality(
|
||||
self, data_source: str, record_count: int, has_errors: bool = False
|
||||
) -> DataQuality:
|
||||
"""Determine data quality based on source, record count, and errors."""
|
||||
if has_errors or record_count == 0:
|
||||
return DataQuality.LOW
|
||||
|
||||
if data_source in ["local_cache", "error", "refresh_error"]:
|
||||
return DataQuality.LOW
|
||||
elif data_source in ["live_api", "live_api_refresh"]:
|
||||
if record_count >= 20:
|
||||
return DataQuality.HIGH
|
||||
elif record_count >= 5:
|
||||
return DataQuality.MEDIUM
|
||||
else:
|
||||
return DataQuality.LOW
|
||||
else:
|
||||
return DataQuality.MEDIUM
|
||||
|
|
@ -1,403 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test InsiderDataService with mock Finnhub client and real InsiderDataRepository.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
# Add the project root to the path
|
||||
sys.path.insert(0, os.path.abspath("."))
|
||||
|
||||
from tradingagents.clients.base import BaseClient
|
||||
from tradingagents.models.context import DataQuality, InsiderContext, InsiderTransaction
|
||||
from tradingagents.repositories.insider_repository import InsiderDataRepository
|
||||
from tradingagents.services.insider_data_service import InsiderDataService
|
||||
|
||||
|
||||
class MockFinnhubClient(BaseClient):
|
||||
"""Mock Finnhub client that returns sample insider trading data."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.connection_works = True
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
return self.connection_works
|
||||
|
||||
def get_data(self, *args, **kwargs) -> dict[str, Any]:
|
||||
"""Not used directly by InsiderDataService."""
|
||||
return {}
|
||||
|
||||
def get_insider_trading(
|
||||
self, ticker: str, start_date: str, end_date: str
|
||||
) -> dict[str, Any]:
|
||||
"""Return mock insider trading data."""
|
||||
# Use fixed dates within test range for predictable filtering
|
||||
base_date = datetime(2024, 6, 15) # Within our test range
|
||||
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"data_type": "insider_trading",
|
||||
"transactions": [
|
||||
{
|
||||
"filingDate": (base_date - timedelta(days=30)).strftime("%Y-%m-%d"),
|
||||
"name": "John Smith",
|
||||
"change": -50000,
|
||||
"sharesTotal": 150000,
|
||||
"transactionPrice": 180.50,
|
||||
"transactionCode": "S", # Sale
|
||||
},
|
||||
{
|
||||
"filingDate": (base_date - timedelta(days=20)).strftime("%Y-%m-%d"),
|
||||
"name": "Jane Doe",
|
||||
"change": 25000,
|
||||
"sharesTotal": 75000,
|
||||
"transactionPrice": 185.25,
|
||||
"transactionCode": "P", # Purchase
|
||||
},
|
||||
{
|
||||
"filingDate": (base_date - timedelta(days=10)).strftime("%Y-%m-%d"),
|
||||
"name": "Robert Johnson",
|
||||
"change": -10000,
|
||||
"sharesTotal": 40000,
|
||||
"transactionPrice": 178.75,
|
||||
"transactionCode": "S", # Sale
|
||||
},
|
||||
{
|
||||
"filingDate": (base_date - timedelta(days=5)).strftime("%Y-%m-%d"),
|
||||
"name": "Mary Wilson",
|
||||
"change": 15000,
|
||||
"sharesTotal": 65000,
|
||||
"transactionPrice": 182.00,
|
||||
"transactionCode": "P", # Purchase
|
||||
},
|
||||
],
|
||||
"metadata": {
|
||||
"source": "mock_finnhub",
|
||||
"retrieved_at": datetime(2024, 1, 2).isoformat(),
|
||||
"symbol": ticker,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_online_mode_with_mock_finnhub():
|
||||
"""Test InsiderDataService in online mode with mock Finnhub client."""
|
||||
# Create mock client and real repository
|
||||
mock_finnhub = MockFinnhubClient()
|
||||
real_repo = InsiderDataRepository("test_data")
|
||||
|
||||
# Create service in online mode
|
||||
service = InsiderDataService(
|
||||
finnhub_client=mock_finnhub,
|
||||
repository=real_repo,
|
||||
online_mode=True,
|
||||
data_dir="test_data",
|
||||
)
|
||||
|
||||
# Test getting insider context
|
||||
context = service.get_insider_context(
|
||||
symbol="AAPL",
|
||||
start_date="2024-01-01",
|
||||
end_date="2024-12-31",
|
||||
force_refresh=True,
|
||||
)
|
||||
|
||||
# Validate context structure
|
||||
assert isinstance(context, InsiderContext)
|
||||
assert context.symbol == "AAPL"
|
||||
assert context.period["start"] == "2024-01-01"
|
||||
assert context.period["end"] == "2024-12-31"
|
||||
|
||||
# Validate transactions
|
||||
assert len(context.transactions) == 4
|
||||
assert all(isinstance(tx, InsiderTransaction) for tx in context.transactions)
|
||||
|
||||
# Check transaction details
|
||||
first_tx = context.transactions[0]
|
||||
assert first_tx.name == "John Smith"
|
||||
assert first_tx.change == -50000 # Sale
|
||||
assert first_tx.transaction_code == "S"
|
||||
assert first_tx.transaction_price == 180.50
|
||||
|
||||
# Validate sentiment data and net activity
|
||||
assert "buy_sell_ratio" in context.sentiment_data
|
||||
assert "insider_sentiment_score" in context.sentiment_data
|
||||
|
||||
assert "net_shares_change" in context.net_activity
|
||||
assert "net_transaction_value" in context.net_activity
|
||||
assert "buy_transactions" in context.net_activity
|
||||
assert "sell_transactions" in context.net_activity
|
||||
|
||||
# Validate metadata
|
||||
assert context.transaction_count == 4
|
||||
assert "data_quality" in context.metadata
|
||||
assert context.metadata["service"] == "insider_data"
|
||||
|
||||
# Test JSON serialization
|
||||
json_output = context.model_dump_json(indent=2)
|
||||
assert len(json_output) > 0
|
||||
|
||||
|
||||
def test_insider_sentiment_analysis():
|
||||
"""Test insider sentiment calculation based on transactions."""
|
||||
mock_finnhub = MockFinnhubClient()
|
||||
service = InsiderDataService(
|
||||
finnhub_client=mock_finnhub, repository=None, online_mode=True
|
||||
)
|
||||
|
||||
context = service.get_insider_context("TSLA", "2024-01-01", "2024-12-31")
|
||||
|
||||
# Check sentiment calculations
|
||||
sentiment = context.sentiment_data
|
||||
|
||||
# Should have buy/sell ratio
|
||||
assert "buy_sell_ratio" in sentiment
|
||||
assert sentiment["buy_sell_ratio"] > 0
|
||||
|
||||
# Should have insider sentiment score (-1 to 1)
|
||||
assert "insider_sentiment_score" in sentiment
|
||||
assert -1.0 <= sentiment["insider_sentiment_score"] <= 1.0
|
||||
|
||||
# Check net activity calculations
|
||||
net_activity = context.net_activity
|
||||
|
||||
# Net shares change: sum of all changes
|
||||
expected_net_change = -50000 + 25000 + (-10000) + 15000 # -20000
|
||||
assert net_activity["net_shares_change"] == expected_net_change
|
||||
|
||||
# Should have buy/sell transaction counts
|
||||
assert net_activity["buy_transactions"] == 2 # Jane Doe and Mary Wilson
|
||||
assert net_activity["sell_transactions"] == 2 # John Smith and Robert Johnson
|
||||
|
||||
|
||||
def test_offline_mode():
|
||||
"""Test InsiderDataService in offline mode."""
|
||||
real_repo = InsiderDataRepository("test_data")
|
||||
|
||||
service = InsiderDataService(
|
||||
finnhub_client=None, repository=real_repo, online_mode=False
|
||||
)
|
||||
|
||||
# Should handle offline gracefully
|
||||
context = service.get_insider_context("AAPL", "2024-01-01", "2024-12-31")
|
||||
|
||||
assert context.symbol == "AAPL"
|
||||
assert len(context.transactions) == 0 # No data available offline
|
||||
assert context.transaction_count == 0
|
||||
# Sentiment data should have default values even with no transactions
|
||||
assert context.sentiment_data.get("insider_sentiment_score", 0) == 0
|
||||
assert context.sentiment_data.get("buy_sell_ratio", 0) == 0
|
||||
# Net activity should have default values
|
||||
assert context.net_activity.get("net_shares_change", 0) == 0
|
||||
assert context.net_activity.get("net_transaction_value", 0) == 0
|
||||
assert context.metadata.get("data_quality") == DataQuality.LOW
|
||||
|
||||
|
||||
def test_empty_data_handling():
|
||||
"""Test handling when no insider transactions are available."""
|
||||
|
||||
class EmptyDataClient(MockFinnhubClient):
|
||||
def get_insider_trading(self, ticker, start_date, end_date):
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"data_type": "insider_trading",
|
||||
"transactions": [],
|
||||
"metadata": {"source": "mock_finnhub", "empty": True},
|
||||
}
|
||||
|
||||
empty_client = EmptyDataClient()
|
||||
service = InsiderDataService(
|
||||
finnhub_client=empty_client, repository=None, online_mode=True
|
||||
)
|
||||
|
||||
context = service.get_insider_context("XYZ", "2024-01-01", "2024-12-31")
|
||||
|
||||
# Should handle empty data gracefully
|
||||
assert context.symbol == "XYZ"
|
||||
assert len(context.transactions) == 0
|
||||
assert context.transaction_count == 0
|
||||
|
||||
# Sentiment should be neutral with no data
|
||||
assert context.sentiment_data.get("insider_sentiment_score", 0) == 0
|
||||
assert context.net_activity.get("net_shares_change", 0) == 0
|
||||
assert context.metadata.get("data_quality") == DataQuality.LOW
|
||||
|
||||
|
||||
def test_error_handling():
|
||||
"""Test error handling with broken client."""
|
||||
|
||||
class BrokenFinnhubClient(BaseClient):
|
||||
def test_connection(self):
|
||||
return False
|
||||
|
||||
def get_data(self, *args, **kwargs):
|
||||
raise Exception("Finnhub API error")
|
||||
|
||||
def get_insider_trading(self, *args, **kwargs):
|
||||
raise Exception("Finnhub API error")
|
||||
|
||||
broken_client = BrokenFinnhubClient()
|
||||
service = InsiderDataService(
|
||||
finnhub_client=broken_client, repository=None, online_mode=True
|
||||
)
|
||||
|
||||
# Should handle errors gracefully
|
||||
context = service.get_insider_context(
|
||||
"FAIL", "2024-01-01", "2024-12-31", force_refresh=True
|
||||
)
|
||||
|
||||
assert context.symbol == "FAIL"
|
||||
assert len(context.transactions) == 0
|
||||
assert context.transaction_count == 0
|
||||
assert context.metadata.get("data_quality") == DataQuality.LOW
|
||||
# Service logs errors but doesn't include them in metadata
|
||||
|
||||
|
||||
def test_transaction_filtering():
|
||||
"""Test filtering transactions by date range."""
|
||||
|
||||
# Create a client that returns transactions outside the date range
|
||||
class DateFilterTestClient(MockFinnhubClient):
|
||||
def get_insider_trading(self, ticker, start_date, end_date):
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"data_type": "insider_trading",
|
||||
"transactions": [
|
||||
{
|
||||
"filingDate": "2023-12-15", # Before start date
|
||||
"name": "Old Transaction",
|
||||
"change": -1000,
|
||||
"sharesTotal": 10000,
|
||||
"transactionPrice": 100.0,
|
||||
"transactionCode": "S",
|
||||
},
|
||||
{
|
||||
"filingDate": "2024-06-15", # Within range
|
||||
"name": "Valid Transaction",
|
||||
"change": 5000,
|
||||
"sharesTotal": 15000,
|
||||
"transactionPrice": 110.0,
|
||||
"transactionCode": "P",
|
||||
},
|
||||
{
|
||||
"filingDate": "2025-01-15", # After end date
|
||||
"name": "Future Transaction",
|
||||
"change": -2000,
|
||||
"sharesTotal": 8000,
|
||||
"transactionPrice": 120.0,
|
||||
"transactionCode": "S",
|
||||
},
|
||||
],
|
||||
"metadata": {"source": "mock_finnhub"},
|
||||
}
|
||||
|
||||
filter_client = DateFilterTestClient()
|
||||
service = InsiderDataService(
|
||||
finnhub_client=filter_client, repository=None, online_mode=True
|
||||
)
|
||||
|
||||
context = service.get_insider_context("TEST", "2024-01-01", "2024-12-31")
|
||||
|
||||
# Should only include the transaction within the date range
|
||||
assert len(context.transactions) == 1
|
||||
assert context.transactions[0].name == "Valid Transaction"
|
||||
assert context.transaction_count == 1
|
||||
|
||||
|
||||
def test_json_structure():
|
||||
"""Test JSON structure of insider context."""
|
||||
mock_finnhub = MockFinnhubClient()
|
||||
service = InsiderDataService(
|
||||
finnhub_client=mock_finnhub, repository=None, online_mode=True
|
||||
)
|
||||
|
||||
context = service.get_insider_context("NVDA", "2024-01-01", "2024-12-31")
|
||||
json_data = context.model_dump()
|
||||
|
||||
# Validate required fields
|
||||
required_fields = [
|
||||
"symbol",
|
||||
"period",
|
||||
"transactions",
|
||||
"sentiment_data",
|
||||
"transaction_count",
|
||||
"net_activity",
|
||||
"metadata",
|
||||
]
|
||||
for field in required_fields:
|
||||
assert field in json_data
|
||||
|
||||
# Validate transaction structure
|
||||
if json_data["transactions"]:
|
||||
transaction = json_data["transactions"][0]
|
||||
required_tx_fields = [
|
||||
"filing_date",
|
||||
"name",
|
||||
"change",
|
||||
"shares",
|
||||
"transaction_price",
|
||||
"transaction_code",
|
||||
]
|
||||
for field in required_tx_fields:
|
||||
assert field in transaction
|
||||
|
||||
# Validate sentiment data structure
|
||||
sentiment = json_data["sentiment_data"]
|
||||
assert "buy_sell_ratio" in sentiment
|
||||
assert "insider_sentiment_score" in sentiment
|
||||
|
||||
# Validate net activity structure
|
||||
net_activity = json_data["net_activity"]
|
||||
expected_net_fields = [
|
||||
"net_shares_change",
|
||||
"net_transaction_value",
|
||||
"buy_transactions",
|
||||
"sell_transactions",
|
||||
]
|
||||
for field in expected_net_fields:
|
||||
assert field in net_activity
|
||||
|
||||
# Validate metadata
|
||||
metadata = json_data["metadata"]
|
||||
assert "data_quality" in metadata
|
||||
assert "service" in metadata
|
||||
|
||||
|
||||
def test_comprehensive_sentiment_calculation():
|
||||
"""Test comprehensive insider sentiment calculation."""
|
||||
mock_finnhub = MockFinnhubClient()
|
||||
service = InsiderDataService(
|
||||
finnhub_client=mock_finnhub, repository=None, online_mode=True
|
||||
)
|
||||
|
||||
context = service.get_insider_context("COMP", "2024-01-01", "2024-12-31")
|
||||
|
||||
# Validate sentiment calculations are reasonable
|
||||
sentiment = context.sentiment_data
|
||||
net_activity = context.net_activity
|
||||
|
||||
# Buy/sell ratio should be positive (we have both buys and sells)
|
||||
assert sentiment["buy_sell_ratio"] >= 0
|
||||
|
||||
# Insider sentiment score should be between -1 and 1
|
||||
assert -1.0 <= sentiment["insider_sentiment_score"] <= 1.0
|
||||
|
||||
# Net transaction value should be calculated correctly
|
||||
expected_value = (
|
||||
(-50000 * 180.50) # John Smith sale
|
||||
+ (25000 * 185.25) # Jane Doe purchase
|
||||
+ (-10000 * 178.75) # Robert Johnson sale
|
||||
+ (15000 * 182.00) # Mary Wilson purchase
|
||||
)
|
||||
assert abs(net_activity["net_transaction_value"] - expected_value) < 0.01
|
||||
|
||||
# Transaction counts should match
|
||||
assert net_activity["buy_transactions"] == 2
|
||||
assert net_activity["sell_transactions"] == 2
|
||||
|
||||
# Net shares change
|
||||
expected_net_shares = -50000 + 25000 + (-10000) + 15000 # -20000
|
||||
assert net_activity["net_shares_change"] == expected_net_shares
|
||||
Loading…
Reference in New Issue