Martin C. Richards 2025-07-31 18:00:01 +02:00
parent b677288876
commit b2a09403fa
85 changed files with 13867 additions and 3103 deletions

68
.env.example Normal file

@@ -0,0 +1,68 @@
# =============================================================================
# TradingAgents Configuration
# =============================================================================
# -----------------------------------------------------------------------------
# LLM Provider Configuration (Choose One)
# -----------------------------------------------------------------------------
# For OpenAI
# OPENAI_API_KEY=your_openai_api_key_here
# For Anthropic Claude (Currently Active)
ANTHROPIC_API_KEY=your_key_here
# For Google Gemini
# GOOGLE_API_KEY=your_google_api_key_here
# -----------------------------------------------------------------------------
# LLM Settings
# -----------------------------------------------------------------------------
LLM_PROVIDER=anthropic
DEEP_THINK_LLM=claude-3-5-sonnet-20241022
QUICK_THINK_LLM=claude-3-5-haiku-20241022
BACKEND_URL=https://api.anthropic.com/v1
# -----------------------------------------------------------------------------
# Data Sources (Optional but Recommended)
# -----------------------------------------------------------------------------
# FinnHub API for financial data (free tier available)
# Sign up at: https://finnhub.io/
FINNHUB_API_KEY=your_finnhub_api_key_here
# Reddit API for social sentiment analysis
# Create app at: https://www.reddit.com/prefs/apps/
# Choose "script" application type
REDDIT_CLIENT_ID=your_reddit_client_id_here
REDDIT_CLIENT_SECRET=your_reddit_client_secret_here
REDDIT_USER_AGENT=TradingAgents/1.0 by YourUsername
# -----------------------------------------------------------------------------
# System Configuration
# -----------------------------------------------------------------------------
# Data and results directories
TRADINGAGENTS_RESULTS_DIR=./results
TRADINGAGENTS_DATA_DIR=/Users/martinrichards/Documents/TradingAgents/data
# Analysis settings
MAX_DEBATE_ROUNDS=5
MAX_RISK_DISCUSS_ROUNDS=5
MAX_RECUR_LIMIT=100
# Tool settings
ONLINE_TOOLS=true
# -----------------------------------------------------------------------------
# Setup Instructions:
# -----------------------------------------------------------------------------
# 1. Copy this file to .env: `cp .env.example .env`
# 2. Replace placeholder values with your actual API keys
# 3. Set ONLINE_TOOLS=true to fetch live data from APIs
# 4. Set ONLINE_TOOLS=false to use cached data only (requires local data files)
#
# API Setup Guide:
# ✅ LLM Provider - Choose one: OpenAI, Anthropic, or Google
# ✅ FinnHub - Sign up at https://finnhub.io/ for financial data
# ✅ Reddit - Create app at https://www.reddit.com/prefs/apps/ for social sentiment
# ✅ Google News - No API key required (web scraping)
# ✅ Yahoo Finance - No API key required (free access via yfinance)
# -----------------------------------------------------------------------------

75
.mise.toml Normal file

@@ -0,0 +1,75 @@
[tools]
python = "3.13"
uv = "latest"
ruff = "latest"
"npm:pyright" = "latest"
[env]
# Python environment settings
PYTHONPATH = "."
PYTHONDONTWRITEBYTECODE = "1"
PYTHONUNBUFFERED = "1"
# TradingAgents specific environment variables
TRADINGAGENTS_RESULTS_DIR = "./results"
TRADINGAGENTS_DATA_DIR = "./data"
[tasks.install]
description = "Install dependencies using uv"
run = "uv sync --dev"
[tasks.dev]
description = "Run the CLI application"
run = "uv run python -m cli.main"
[tasks.run]
description = "Run the main application"
run = "uv run python main.py"
[tasks.test]
description = "Run tests with pytest"
run = "uv run pytest"
[tasks.lint]
description = "Run ruff linting"
run = "ruff check ."
[tasks.format]
description = "Format code with ruff"
run = "ruff format ."
[tasks.typecheck]
description = "Run pyright type checking"
run = "pyright"
[tasks.fix]
description = "Auto-fix linting issues"
run = "ruff check --fix ."
[tasks.all]
description = "Run format, lint, and typecheck"
run = [
"ruff format .",
"ruff check .",
"pyright"
]
[tasks.clean]
description = "Clean up cache and build artifacts"
run = [
"find . -type d -name __pycache__ -exec rm -rf {} +",
"find . -type f -name '*.pyc' -delete",
"rm -rf .pytest_cache",
"rm -rf .ruff_cache",
"rm -rf dist",
"rm -rf build",
"rm -rf *.egg-info"
]
[tasks.setup]
description = "Initial project setup"
run = [
"mise install",
"uv sync --dev",
"echo 'Setup complete! Run mise run --help to see available tasks.'"
]

.python-version

@@ -1 +1 @@
3.10
3.13

166
AGENTS.md Normal file

@@ -0,0 +1,166 @@
# AGENTS.md - TradingAgents Development Guide
## What TradingAgents Does
**TradingAgents** is a multi-agent LLM financial trading framework that simulates a real-world trading firm using specialized AI agents. The system analyzes stocks through collaborative decision-making, mirroring professional trading teams.
### Core Architecture
- Built on **LangGraph** with state-based workflows
- **TradingAgentsGraph** orchestrates the entire process
- Agents work sequentially and in parallel to analyze market conditions
### Agent Teams & Workflow
1. **Analyst Team**: Market, Social Media, News, and Fundamentals analysts gather data
2. **Research Team**: Bull/Bear researchers debate, Research Manager decides
3. **Trading Team**: Trader develops detailed trading plans
4. **Risk Management**: Risk analysts debate, Risk Manager makes final decision
### Data Sources
- Yahoo Finance, FinnHub API, Reddit, Google News, StockStats
- Supports both real-time online data and cached offline data for backtesting
### Decision Process
Sequential analysis → Structured debate → Managerial oversight → Risk assessment → Memory & learning
## Build/Test Commands
This project uses [mise](https://mise.jdx.dev/) for tool and task management. All development tasks are managed through mise.
### Initial Setup
- **First-time setup**: `mise run setup` - Install tools and dependencies
- **Install tools only**: `mise install` - Install Python, uv, ruff, pyright
- **Install dependencies**: `mise run install` - Install project dependencies with uv
### Development Workflow
- **CLI Application**: `mise run dev` - Interactive CLI for running trading analysis
- **Direct Python Usage**: `mise run run` - Run main.py programmatically
- **Format code**: `mise run format` - Auto-format with ruff
- **Lint code**: `mise run lint` - Check code quality with ruff
- **Type checking**: `mise run typecheck` - Run pyright type checker
- **Fix lint issues**: `mise run fix` - Auto-fix linting issues
- **Run all checks**: `mise run all` - Format, lint, and typecheck
- **Clean artifacts**: `mise run clean` - Remove cache and build files
### Testing
- **Run tests**: `mise run test` - Run tests with pytest (when available)
### Configuration
- **Environment Variables**: Create `.env` file with API keys (see `.env.example`)
- **Tool Configuration**: `.mise.toml` manages Python 3.13, uv, ruff, pyright
- **Code Quality**: `pyproject.toml` contains ruff and pyright configurations
## Configuration System
### Environment Variables
Create `.env` file with API keys (see `.env.example`):
#### Core LLM APIs (Choose One)
```bash
# For OpenAI (default)
export OPENAI_API_KEY="your_openai_api_key"
# For Anthropic Claude
export ANTHROPIC_API_KEY="your_anthropic_api_key"
# For Google Gemini
export GOOGLE_API_KEY="your_google_api_key"
```
#### Data Sources (Optional)
```bash
# For financial data
export FINNHUB_API_KEY="your_finnhub_api_key"
# For Reddit data
export REDDIT_CLIENT_ID="your_reddit_client_id"
export REDDIT_CLIENT_SECRET="your_reddit_client_secret"
export REDDIT_USER_AGENT="your_app_name"
```
### Configuration Management
- **Config Class**: `TradingAgentsConfig` in `tradingagents/config.py` handles all configuration
- Use `TradingAgentsConfig.from_env()` for environment-based configuration
- Key settings: `max_debate_rounds`, `llm_provider`, `online_tools`
- Results are saved to `results_dir/{ticker}/{date}/` with structured reports
### Configuration Examples
#### Anthropic Setup
```python
config = TradingAgentsConfig.from_env()
config.llm_provider = "anthropic"
config.deep_think_llm = "claude-3-5-sonnet-20241022"
config.quick_think_llm = "claude-3-5-haiku-20241022"
```
#### Google Gemini Setup
```python
config.llm_provider = "google"
config.deep_think_llm = "gemini-2.0-flash"
config.quick_think_llm = "gemini-2.0-flash"
```
#### Data Mode Configuration
- `config.online_tools = True` - Real-time data (requires API keys)
- `config.online_tools = False` - Cached data (faster, historical only); a combined usage sketch follows
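For orientation, here is a minimal end-to-end sketch combining the pieces above. The `TradingAgentsGraph` import path and the `propagate()` entry point are assumptions to verify against `main.py`; everything else follows the configuration examples:
```python
from tradingagents.config import TradingAgentsConfig
from tradingagents.graph.trading_graph import TradingAgentsGraph  # assumed module path

config = TradingAgentsConfig.from_env()
config.llm_provider = "anthropic"
config.online_tools = False  # cached data only; no data-source API keys needed

graph = TradingAgentsGraph(config=config)
_, decision = graph.propagate("AAPL", "2024-05-10")  # assumed signature; see main.py
print(decision)  # processed into a BUY/SELL/HOLD signal
```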
## Code Style Guidelines
- **Imports**: Standard library first, third-party, then local imports (langchain, tradingagents modules)
- **Formatting**: Auto-formatted with ruff (`mise run format`)
- **Linting**: Code quality checked with ruff (`mise run lint`)
- **Type Checking**: Static analysis with pyright (`mise run typecheck`)
- **Functions**: Snake_case naming (e.g., `fundamentals_analyst_node`, `create_fundamentals_analyst`)
- **Classes**: PascalCase (e.g., `TradingAgentsGraph`, `MessageBuffer`)
- **Variables**: Snake_case (e.g., `current_date`, `company_of_interest`)
- **Constants**: UPPER_CASE (e.g., `DEFAULT_CONFIG`)
## Project Structure
- **Main entry**: `main.py` for package usage, `cli/main.py` for CLI
- **Core logic**: `tradingagents/` package with agents, dataflows, graph modules
- **Configuration**: `tradingagents/config.py` for LLM and system settings
- **CLI interface**: `cli/` directory with rich-based terminal UI
- **Tool Management**: `.mise.toml` for development tool configuration
- **Dependencies**: `pyproject.toml` for project dependencies and tool settings
## Key Patterns
- **Agent creation**: Factory functions that return node functions (e.g., `create_fundamentals_analyst`); a sketch follows this list
- **State management**: Dictionary-based state passed between graph nodes
- **Tool integration**: LangChain tools bound to LLMs via `llm.bind_tools(tools)`
- **Configuration**: Use `TradingAgentsConfig.from_env()` for environment-based configuration
- **Debate-Driven Decision Making**: Critical decisions emerge from structured agent debates
- **Memory-Augmented Learning**: Agents learn from past similar situations using vector similarity
- **Dual-Mode Data Access**: Support for both live API calls and pre-processed cached data
- **Factory Pattern**: Agent creation via factory functions for flexible configuration
- **Signal Processing**: Final trading decisions processed into clean BUY/SELL/HOLD signals
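A minimal sketch of the factory pattern and tool binding named above; the tool name and state keys are illustrative, not the project's actual ones:
```python
def create_fundamentals_analyst(llm, toolkit):
    """Factory: returns a node function bound to this LLM and toolkit."""
    tools = [toolkit.get_fundamentals]  # hypothetical toolkit attribute
    chain = llm.bind_tools(tools)

    def fundamentals_analyst_node(state: dict) -> dict:
        # Read the shared state, invoke the LLM, write results back into state
        result = chain.invoke(state["messages"])
        return {"messages": [result], "fundamentals_report": result.content}

    return fundamentals_analyst_node
```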
## Development Guidelines
### Working with Agents
- Each agent has its own memory instance in `FinancialSituationMemory`
- Agents use the unified `Toolkit` for data access
- Agent state is passed sequentially through the workflow
- Configuration affects debate rounds, LLM selection, and data sources
### Working with Data Sources
- All data utilities follow consistent date range patterns: `curr_date + look_back_days`
- Interface functions return markdown-formatted strings for LLM consumption
- Check `online_tools` config flag to determine live vs cached data usage
- Data caching happens in `data_cache_dir` for online mode
### CLI Development
- CLI uses Rich for terminal UI with live updating displays
- Agent progress tracking through `MessageBuffer` class
- Questionnaire-driven configuration collection
- Real-time streaming of analysis results
### File Structure Context
- **`cli/`**: Interactive command-line interface
- **`tradingagents/agents/`**: All agent implementations
- **`tradingagents/dataflows/`**: Data source integrations
- **`tradingagents/graph/`**: LangGraph workflow orchestration
- **`tradingagents/config.py`**: Configuration management
- **`main.py`**: Direct Python usage example
- **`CLAUDE.md`**: Guidance for Claude Code development

327
CLAUDE.md Normal file

@@ -0,0 +1,327 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Common Development Commands
This project uses [mise](https://mise.jdx.dev/) for tool and task management. All development tasks are managed through mise.
### Initial Setup
- **First-time setup**: `mise run setup` - Install tools and dependencies
- **Install tools only**: `mise install` - Install Python, uv, ruff, pyright
- **Install dependencies**: `mise run install` - Install project dependencies with uv
### Development Workflow
- **CLI Application**: `mise run dev` - Interactive CLI for running trading analysis
- **Direct Python Usage**: `mise run run` - Run main.py programmatically
- **Format code**: `mise run format` - Auto-format with ruff
- **Lint code**: `mise run lint` - Check code quality with ruff
- **Type checking**: `mise run typecheck` - Run pyright type checker
- **Fix lint issues**: `mise run fix` - Auto-fix linting issues
- **Run all checks**: `mise run all` - Format, lint, and typecheck
- **Clean artifacts**: `mise run clean` - Remove cache and build files
### Testing
#### Running Tests
- **Run all tests**: `mise run test` - Run tests with pytest
- **Run specific test file**: `uv run pytest test_social_media_service.py` - Run individual test file
- **Verbose output**: `uv run pytest -v` - Run tests with detailed output
- **Run with output**: `uv run pytest -s` - Show print statements and debug output
- **Test coverage**: `uv run pytest --cov=tradingagents` - Run tests with coverage report
#### Test Development (TDD Approach)
This project follows **Test-Driven Development (TDD)** for service layer development:
1. **Write test first**: Create `test_{service_name}_service.py` with comprehensive test cases
2. **Run test (should fail)**: Verify test fails with appropriate error messages
3. **Implement minimum code**: Write just enough code to make the test pass
4. **Refactor**: Improve code while keeping tests passing
5. **Repeat**: Add more test cases and implement additional functionality
#### Test Structure and Conventions
- **Test files**: Named `test_{component}_service.py` and placed next to the source code (not in a separate `tests/` directory)
- **Test functions**: Named `test_{functionality}()` and should not return values (use `assert` statements)
- **Mock clients**: Create mock implementations of BaseClient for testing services
- **Real repositories**: Use actual repository implementations (don't mock the repository layer)
- **Test data**: Use realistic mock data that matches expected API responses
- **Date handling**: Use fixed dates (e.g., `datetime(2024, 1, 2)`) in mocks for predictable filtering
#### Service Testing Pattern
Example test structure for services:
```python
def test_online_mode_with_mock_client():
"""Test service in online mode with mock client."""
mock_client = MockServiceClient()
real_repo = ServiceRepository("test_data")
service = ServiceClass(
client=mock_client,
repository=real_repo,
online_mode=True
)
context = service.get_context("TEST", "2024-01-01", "2024-01-05")
# Validate structure
assert isinstance(context, ContextModel)
assert context.symbol == "TEST"
assert len(context.data) > 0
# Test JSON serialization
json_output = context.model_dump_json()
assert len(json_output) > 0
```
#### Mock Client Guidelines
- **Extend BaseClient**: All mock clients must implement the abstract `get_data()` method (see the sketch after this list)
- **Realistic data**: Return data structures that match actual API responses
- **Date consistency**: Use fixed dates that work with test date ranges
- **Error simulation**: Create broken clients for testing error handling paths
- **Multiple scenarios**: Provide different data for different test cases
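A minimal mock client sketch under these guidelines; the `BaseClient` import path and the `get_data()` signature are assumptions taken from the description above:
```python
from datetime import datetime
from typing import Any

from tradingagents.clients.base_client import BaseClient  # assumed path

class MockServiceClient(BaseClient):
    """Returns fixed, realistic-looking data for predictable tests."""

    def get_data(self, symbol: str, start_date: str, end_date: str) -> dict[str, Any]:
        # Fixed date keeps range filtering deterministic in tests
        return {
            "symbol": symbol,
            "data": [{"date": datetime(2024, 1, 2).strftime("%Y-%m-%d"), "close": 154.0}],
        }

class BrokenClient(BaseClient):
    """Simulates an API outage for error-handling tests."""

    def get_data(self, symbol: str, start_date: str, end_date: str) -> dict[str, Any]:
        raise ConnectionError("simulated API failure")
```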
### Configuration
- **Environment Variables**: Create `.env` file with API keys (see `.env.example`)
- **Config Class**: `TradingAgentsConfig` in `tradingagents/config.py` handles all configuration
- **Tool Configuration**: `.mise.toml` manages Python 3.13, uv, ruff, pyright
- **Code Quality**: `pyproject.toml` contains ruff and pyright configurations
#### Required Environment Variables
##### Core LLM APIs (Choose One)
```bash
# For OpenAI (default)
export OPENAI_API_KEY="your_openai_api_key"
# For Anthropic Claude
export ANTHROPIC_API_KEY="your_anthropic_api_key"
# For Google Gemini
export GOOGLE_API_KEY="your_google_api_key"
```
##### Data Sources (Optional)
```bash
# For financial data
export FINNHUB_API_KEY="your_finnhub_api_key"
# For Reddit data
export REDDIT_CLIENT_ID="your_reddit_client_id"
export REDDIT_CLIENT_SECRET="your_reddit_client_secret"
export REDDIT_USER_AGENT="your_app_name"
```
## High-Level Architecture
### Multi-Agent Trading Framework
TradingAgents implements a sophisticated multi-agent system that mirrors real-world trading firms with specialized roles and structured workflows.
### Core Architecture Components
#### 1. **Agent Teams** (Sequential Workflow)
```
Analyst Team → Research Team → Trading Team → Risk Management Team
```
**Analyst Team** (`tradingagents/agents/analysts/`)
- **Market Analyst**: Technical analysis using Yahoo Finance and StockStats
- **Fundamentals Analyst**: Financial statements and company fundamentals via SimFin/Finnhub
- **News Analyst**: News sentiment analysis and world affairs impact
- **Social Media Analyst**: Reddit and social platform sentiment analysis
**Research Team** (`tradingagents/agents/researchers/`)
- **Bull Researcher**: Advocates for investment opportunities and growth potential
- **Bear Researcher**: Highlights risks and argues against investments
- **Research Manager**: Synthesizes debates and creates investment recommendations
**Trading Team** (`tradingagents/agents/trader/`)
- **Trader**: Converts investment plans into specific trading decisions
**Risk Management Team** (`tradingagents/agents/risk_mgmt/`)
- **Aggressive/Conservative/Neutral Debators**: Different risk perspectives
- **Risk Manager**: Final decision maker balancing risk and reward
#### 2. **Data Layer** (`tradingagents/services/` + `tradingagents/dataflows/`)
**New Service-Based Architecture** (Current):
- **Service Layer**: `MarketDataService`, `NewsService`, `SocialMediaService`, `FundamentalDataService`, `InsiderDataService`, `OpenAIDataService`
- Orchestrate between clients and repositories with local-first data strategy
- Provide structured JSON contexts via Pydantic models
- Support a `force_refresh=True` parameter for explicit data refreshing
- Automatically check local cache first, fetch from APIs only when data is missing
- **Client Layer**: Live API integrations - `YFinanceClient`, `FinnhubClient`, `RedditClient`, `GoogleNewsClient`, `SimFinClient`
- **Repository Layer**: Smart cached data storage with gap detection - `MarketDataRepository`, `NewsRepository`, `SocialRepository`
- **Context Models**: Pydantic models for structured JSON data - `MarketDataContext`, `NewsContext`, `SocialContext`, `FundamentalContext`
- **Toolkit Integration**: `ServiceToolkit` provides a 100% backward-compatible interface returning JSON contexts
**Legacy Data Integration System** (Being phased out):
- **Yahoo Finance** (`yfin_utils.py`): Stock prices, financials, analyst recommendations
- **Finnhub** (`finnhub_utils.py`): News, insider trading, SEC filings
- **Reddit** (`reddit_utils.py`): Social sentiment from curated subreddits
- **Google News** (`googlenews_utils.py`): Web-scraped news with retry logic
- **SimFin**: Balance sheets, cash flow, income statements
- **StockStats** (`stockstats_utils.py`): Technical indicators (MACD, RSI, etc.)
- **Interface Layer** (`interface.py`): Standardized agent-facing APIs with markdown formatting
#### 3. **Graph Orchestration** (`tradingagents/graph/`)
LangGraph-based workflow management (a minimal wiring sketch follows this list):
- **TradingAgentsGraph**: Main orchestrator class
- **State Management**: `AgentState`, `InvestDebateState`, `RiskDebateState` track workflow progress
- **Conditional Logic**: Dynamic routing based on tool usage and debate completion
- **Memory System**: ChromaDB-based vector memory for learning from past decisions
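Illustrative-only sketch of the LangGraph wiring just described; the real nodes, state fields, and routing live in `tradingagents/graph/` and differ in detail:
```python
from typing import TypedDict

from langgraph.graph import END, StateGraph

class AgentState(TypedDict):  # stand-in for the project's AgentState
    report: str
    debate_rounds: int

def market_analyst_node(state: AgentState) -> dict:
    return {"report": "technical analysis..."}

def bull_researcher_node(state: AgentState) -> dict:
    return {"debate_rounds": state["debate_rounds"] + 1}

def should_continue_debate(state: AgentState) -> str:
    # Conditional logic: route based on debate completion
    return "continue" if state["debate_rounds"] < 2 else "done"

workflow = StateGraph(AgentState)
workflow.add_node("market_analyst", market_analyst_node)
workflow.add_node("bull_researcher", bull_researcher_node)
workflow.set_entry_point("market_analyst")
workflow.add_edge("market_analyst", "bull_researcher")
workflow.add_conditional_edges(
    "bull_researcher",
    should_continue_debate,
    {"continue": "bull_researcher", "done": END},
)
graph = workflow.compile()
result = graph.invoke({"report": "", "debate_rounds": 0})
```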
#### 4. **Configuration System**
- **TradingAgentsConfig**: Centralized configuration with environment variable support
- **Multi-LLM Support**: OpenAI, Anthropic, Google, Ollama, OpenRouter
- **Data Modes**: Online (live APIs) vs offline (cached data)
### Key Design Patterns
1. **Debate-Driven Decision Making**: Critical decisions emerge from structured agent debates
2. **Memory-Augmented Learning**: Agents learn from past similar situations using vector similarity
3. **Local-First Data Strategy**: Automatically check local cache first, fetch from APIs only when needed
4. **Structured JSON Contexts**: Replace error-prone string parsing with rich Pydantic models
5. **Factory Pattern**: Agent creation via factory functions for flexible configuration
6. **Signal Processing**: Final trading decisions processed into clean BUY/SELL/HOLD signals
7. **Quality-Aware Data**: All contexts include quality metadata to help agents make better decisions
### Code Style Guidelines
#### General Style
- **Functions**: Snake_case naming (e.g., `fundamentals_analyst_node`, `create_fundamentals_analyst`)
- **Classes**: PascalCase (e.g., `TradingAgentsGraph`, `MessageBuffer`)
- **Variables**: Snake_case (e.g., `current_date`, `company_of_interest`)
- **Constants**: UPPER_CASE (e.g., `DEFAULT_CONFIG`)
- **Imports**: Standard library first, third-party, then local imports (langchain, tradingagents modules)
#### Ruff Formatting & Linting Rules
**Formatting** (`mise run format`):
- **Line length**: 88 characters maximum
- **Quote style**: Double quotes (`"string"`)
- **Indentation**: 4 spaces (no tabs)
- **Trailing commas**: Preserved for multi-line structures
- **Line endings**: Auto-detected based on platform
**Linting** (`mise run lint`):
- **Selected rules**:
- `E`, `W`: pycodestyle errors and warnings
- `F`: pyflakes (undefined names, unused imports)
- `I`: isort (import sorting)
- `B`: flake8-bugbear (common bugs)
- `C4`: flake8-comprehensions (list/dict comprehensions)
- `UP`: pyupgrade (Python syntax modernization)
- `ARG`: flake8-unused-arguments
- `SIM`: flake8-simplify (code simplification)
- `TCH`: flake8-type-checking (type annotation imports)
- **Ignored rules**:
- `E501`: Line too long (handled by formatter)
- `B008`: Function calls in argument defaults (allowed for LangChain)
- `C901`: Complex functions (legacy code tolerance)
- `ARG001`, `ARG002`: Unused arguments (common in callbacks)
- **Import sorting**: `tradingagents` and `cli` treated as first-party modules
#### Pyright Type Checking Rules
**Configuration** (`mise run typecheck`):
- **Tool**: pyright 1.1.390+ with standard type checking mode
- **Python version**: 3.10+ (configured for compatibility with modern syntax)
- **Coverage**: Includes `tradingagents/`, `cli/`, and `main.py`
- **Exclusions**: `__pycache__`, `node_modules`, `.venv`, `venv`, `build`, `dist`
**Configured Type Checking Rules**:
- `reportMissingImports = true` - Catch undefined module imports
- `reportMissingTypeStubs = false` - Allow libraries without type stubs
- `reportGeneralTypeIssues = true` - General type inconsistencies
- `reportOptionalMemberAccess = true` - Unsafe access to optional values
- `reportOptionalCall = true` - Calling potentially None values
- `reportOptionalIterable = true` - Iterating over potentially None values
- `reportOptionalContextManager = true` - Using None in with statements
- `reportOptionalOperand = true` - Operations on potentially None values
- `reportTypedDictNotRequiredAccess = false` - Flexible TypedDict access
- `reportPrivateImportUsage = false` - Allow importing private modules
- `reportUnknownParameterType = false` - Allow untyped parameters
- `reportUnknownArgumentType = false` - Allow untyped arguments
- `reportUnknownLambdaType = false` - Allow untyped lambdas
- `reportUnknownVariableType = false` - Allow untyped variables
- `reportUnknownMemberType = false` - Allow untyped attributes
**Type Annotation Guidelines** (example after this list):
- Use modern Python 3.10+ union syntax: `str | None` instead of `Optional[str]`
- Use built-in generics: `list[str]` instead of `List[str]`
- Use `dict[str, Any]` for flexible dictionaries
- Import `from typing import Any` for untyped data structures
- Prefer explicit return types on public functions
- Use `# type: ignore` sparingly with explanatory comments
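A short example in the prescribed style:
```python
from typing import Any

def latest_close(prices: list[dict[str, Any]], symbol: str | None = None) -> float | None:
    """Return the most recent close, or None when no data is available."""
    if not prices:
        return None
    return float(prices[-1]["close"])
```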
### Development Guidelines
#### Working with Agents
**Current Approach** (JSON Contexts):
- Import `ServiceToolkit` from `tradingagents.agents.utils.service_toolkit`
- Use `context_helpers` for parsing structured JSON data (MarketDataParser, NewsParser, etc.)
- All toolkit methods return structured JSON instead of markdown strings
- Check data quality with `is_high_quality_data()` before analysis
- Use `extract_latest_price()`, `extract_sentiment_score()` for quick data extraction (sketched below)
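A hedged sketch of that flow; the toolkit method name and helper signatures below are assumptions to check against `service_toolkit.py` and `context_helpers.py`:
```python
from tradingagents.agents.utils.service_toolkit import ServiceToolkit
from tradingagents.agents.utils.context_helpers import (  # assumed names
    extract_latest_price,
    is_high_quality_data,
)

toolkit = ServiceToolkit()
market_json = toolkit.get_market_data("AAPL", "2024-01-01", "2024-01-31")  # assumed method

if is_high_quality_data(market_json):
    price = extract_latest_price(market_json)
```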
**Legacy Approach** (Being phased out):
- Agents use the unified `Toolkit` for markdown-formatted data access
#### Working with Data Sources
**Current Service-Based Approach**:
- **Local-First Strategy**: Services automatically check local cache first, only fetch from APIs when data is missing
- **Force Refresh**: Use `force_refresh=True` parameter to bypass local data and get fresh API data
- **Structured Contexts**: Services return Pydantic models (`MarketDataContext`, `NewsContext`, etc.) with rich metadata
- **Quality Awareness**: All contexts include data_quality (HIGH/MEDIUM/LOW) and data_source information
- **JSON Serialization**: All context models support `.model_dump_json()` for agent consumption
- **Smart Caching**: Repositories detect data gaps and fetch only missing periods
- **Cost Efficient**: Minimizes expensive API calls through intelligent local-first logic
**Service Configuration**:
```python
# Services use dependency injection
service = MarketDataService(
client=YFinanceClient(),
repository=MarketDataRepository("cache_dir"),
online_mode=True # Enable API fetching when local data insufficient
)
# Get data with local-first strategy
context = service.get_context("AAPL", "2024-01-01", "2024-01-31")
# Force fresh data when needed
fresh_context = service.get_context("AAPL", "2024-01-01", "2024-01-31", force_refresh=True)
```
**Legacy Approach** (Being phased out):
- Interface functions return markdown-formatted strings for LLM consumption
- All data utilities follow consistent date range patterns: `curr_date + look_back_days`
- Check `online_tools` config flag to determine live vs cached data usage
- Data caching happens in `data_cache_dir` for online mode
#### Configuration Management
- Use `TradingAgentsConfig.from_env()` for environment-based configuration
- Key settings: `max_debate_rounds`, `llm_provider`, `online_tools`
- Results are saved to `results_dir/{ticker}/{date}/` with structured reports
#### CLI Development
- CLI uses Rich for terminal UI with live updating displays
- Agent progress tracking through `MessageBuffer` class
- Questionnaire-driven configuration collection
- Real-time streaming of analysis results
### File Structure Context
- **`cli/`**: Interactive command-line interface
- **`tradingagents/agents/`**: All agent implementations
- **`utils/service_toolkit.py`**: New ServiceToolkit with JSON contexts (100% backward compatible)
- **`utils/context_helpers.py`**: Helper functions for parsing structured JSON data
- **`utils/agent_utils.py`**: Legacy Toolkit (being phased out)
- **`tradingagents/services/`**: Service layer with local-first data strategy
- **`tradingagents/dataflows/`**: Legacy data source integrations (being phased out)
- **`tradingagents/graph/`**: LangGraph workflow orchestration
- **`tradingagents/config.py`**: Configuration management
- **`main.py`**: Direct Python usage example
- **`AGENTS.md`**: Detailed agent documentation
- **`examples/agent_json_migration.py`**: Migration guide from markdown to JSON contexts

289
FundamentalDataService_PRD.md Normal file

@@ -0,0 +1,289 @@
# Product Requirements Document: FundamentalDataService Completion
## Overview
Complete the `FundamentalDataService` to provide strongly-typed fundamental financial data to trading agents using a local-first data strategy with gap detection and intelligent caching.
## Current State Analysis
### Issues to Fix
- **CRITICAL**: Service calls `FinnhubClient` methods with string dates but client expects `date` objects
- **CRITICAL**: References non-existent `self.simfin_client` instead of `self.finnhub_client`
- Missing strongly-typed interfaces between components
- Incomplete local-first strategy implementation
- No concrete gap detection logic
- Missing error recovery for partial data
### What Works
- ✅ `FinnhubClient` fully implemented with strict `date` object interface
- ✅ `FundamentalDataRepository` with dataclass-based storage
- ✅ `FundamentalContext` Pydantic model for agent consumption
- ✅ Basic service structure and error handling
## Technical Requirements
### 1. Strongly-Typed Interfaces
#### Client → Service Interface
```python
# FinnhubClient methods (already implemented)
def get_balance_sheet(symbol: str, frequency: str, report_date: date) -> dict[str, Any]
def get_income_statement(symbol: str, frequency: str, report_date: date) -> dict[str, Any]
def get_cash_flow(symbol: str, frequency: str, report_date: date) -> dict[str, Any]
```
#### Service → Repository Interface
```python
# Repository methods (already implemented)
def has_data_for_period(symbol: str, start_date: str, end_date: str, frequency: str) -> bool
def get_data(symbol: str, start_date: str, end_date: str, frequency: str) -> dict[str, Any]
def store_data(symbol: str, cache_data: dict, frequency: str, overwrite: bool) -> bool
def clear_data(symbol: str, start_date: str, end_date: str, frequency: str) -> bool
```
#### Service → Agent Interface
```python
# Service output (already defined)
def get_context(symbol: str, start_date: str, end_date: str, frequency: str, force_refresh: bool) -> FundamentalContext
```
### 2. Local-First Data Strategy
#### Flow
1. **Repository Lookup**: Check `FundamentalDataRepository.has_data_for_period()`
2. **Gap Detection**: Identify missing data periods using `detect_fundamental_gaps()`
3. **Selective Fetching**: Fetch only missing data from `FinnhubClient`
4. **Cache Updates**: Store new data via `repository.store_data()`
5. **Context Assembly**: Return validated `FundamentalContext`
#### Gap Detection Implementation
```python
def detect_fundamental_gaps(self, symbol: str, start_date: str, end_date: str, frequency: str) -> list[str]:
"""
Returns list of report dates that need fetching.
Example: If requesting quarterly from 2024-01-01 to 2024-12-31
and cache has Q1 and Q3, returns ["2024-06-30", "2024-12-31"] (the missing Q2 and Q4 reports)
For quarterly: Check for Q1 (Mar 31), Q2 (Jun 30), Q3 (Sep 30), Q4 (Dec 31)
For annual: Check for fiscal year ends
"""
# Implementation should:
# 1. Get existing report dates from repository
# 2. Calculate expected report dates in requested period
# 3. Return difference between expected and existing
```
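One possible realization of those three steps, under the stated calendar-quarter assumption (the repository lookup is stubbed as a set of ISO date strings):
```python
from datetime import date

QUARTER_ENDS = [(3, 31), (6, 30), (9, 30), (12, 31)]

def expected_report_dates(start: date, end: date, frequency: str) -> list[str]:
    """Expected report dates inside [start, end] for calendar fiscal years."""
    expected = []
    period_ends = QUARTER_ENDS if frequency == "quarterly" else [(12, 31)]
    for year in range(start.year, end.year + 1):
        for month, day in period_ends:
            d = date(year, month, day)
            if start <= d <= end:
                expected.append(d.isoformat())
    return expected

def missing_report_dates(expected: list[str], existing: set[str]) -> list[str]:
    """Step 3: the difference, kept in chronological order."""
    return [d for d in expected if d not in existing]
```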
#### Force Refresh Support
- `force_refresh=True` bypasses local data completely
- Clears existing cache before fetching fresh data
- Stores refreshed data with metadata indicating refresh
#### Cache Invalidation Strategy
- **Fundamental data is immutable**: Once a report is filed, it doesn't change
- **No staleness checks needed**: Reports are valid indefinitely
- **Only fetch if missing**: Never re-fetch existing reports
### 3. Date Object Conversion
#### Service Boundary Conversion
```python
# Service receives string dates from agents
def get_context(self, symbol: str, start_date: str, end_date: str, ...) -> FundamentalContext:
# Validate date strings
try:
start_dt = date.fromisoformat(start_date)
end_dt = date.fromisoformat(end_date)
except ValueError as e:
raise ValueError(f"Invalid date format: {e}")
# Check date order
if end_dt < start_dt:
raise ValueError(f"End date {end_date} is before start date {start_date}")
# Use date objects when calling FinnhubClient
data = self.finnhub_client.get_balance_sheet(symbol, frequency, end_dt)
```
### 4. Error Recovery and Partial Data
```python
def handle_partial_statements(
self,
balance_sheet: dict | None,
income_statement: dict | None,
cash_flow: dict | None
) -> FundamentalContext:
"""
Create context even if some statements are missing.
- If all statements fail: Raise exception
- If some statements succeed: Return partial context
- Mark missing statements in metadata
"""
metadata = {
"has_balance_sheet": balance_sheet is not None,
"has_income_statement": income_statement is not None,
"has_cash_flow": cash_flow is not None,
"partial_data": any(s is None for s in [balance_sheet, income_statement, cash_flow])
}
# Convert available statements to FinancialStatement objects
# Return FundamentalContext with available data
```
### 5. Pydantic Validation
#### Context Structure
```python
class FundamentalContext(BaseModel):
symbol: str
period: dict[str, str] # {"start": "2024-01-01", "end": "2024-01-31"}
balance_sheet: FinancialStatement | None
income_statement: FinancialStatement | None
cash_flow: FinancialStatement | None
key_ratios: dict[str, float]
metadata: dict[str, Any]
    @field_validator("period")
    @classmethod
    def validate_period(cls, v):
        # Ensure start and end dates are present and valid
        return v
```
## Implementation Tasks
### Phase 1: Fix Critical Issues
1. **Date Conversion Fix**
- Add `date.fromisoformat()` conversion in service methods
- Add date validation (format, order)
- Update all `FinnhubClient` method calls to use `date` objects
- File: `tradingagents/services/fundamental_data_service.py:153, 164, 175`
2. **Client Reference Fix**
- Replace `self.simfin_client` with `self.finnhub_client`
- File: `tradingagents/services/fundamental_data_service.py:375`
### Phase 2: Enhanced Local-First Strategy
3. **Gap Detection Logic**
- Implement `detect_fundamental_gaps()` method
- Calculate expected report dates based on frequency
- Compare with cached data to find gaps
- Handle fiscal year variations
4. **Partial Data Handling**
- Implement `handle_partial_statements()` method
- Continue processing if some statements succeed
- Mark missing data in metadata
- Only fail if all statements fail
### Phase 3: Type Safety & Validation
5. **Comprehensive Type Checking**
- Run `mise run typecheck` - must pass with 0 errors
- Validate all `date` object conversions
- Ensure Pydantic model compliance
6. **Enhanced Testing**
- Update existing tests for new date handling
- Add gap detection test scenarios
- Test partial data scenarios
- Test force refresh behavior
- Test date validation edge cases
## Testing Scenarios
### Integration Tests
1. **Gap Detection**
- Test with empty cache (should fetch all)
- Test with partial cache (should fetch only missing)
- Test with complete cache (should fetch none)
2. **Partial Data Recovery**
- Test when balance sheet API fails but others succeed
- Test when only one statement type is available
- Test when all APIs fail (should raise exception)
3. **Date Handling**
- Test invalid date formats
- Test end_date < start_date
- Test boundary conditions (year start/end)
4. **Force Refresh**
- Test that force_refresh=True clears cache
- Test that new data is fetched and stored
## Success Criteria
### Functional Requirements
- ✅ Service successfully calls `FinnhubClient` with `date` objects
- ✅ Gap detection correctly identifies missing reports
- ✅ Partial data scenarios handled gracefully
- ✅ Local-first strategy works: checks cache → identifies gaps → fetches missing → stores updates
- ✅ Returns properly validated `FundamentalContext` to agents
- ✅ Force refresh bypasses cache and refreshes data
### Technical Requirements
- ✅ Zero type checking errors: `mise run typecheck`
- ✅ Zero linting errors: `mise run lint`
- ✅ All existing tests pass
- ✅ No runtime errors with date conversions
- ✅ Proper error messages for validation failures
### Quality Requirements
- ✅ Strongly-typed interfaces between all components
- ✅ Comprehensive error handling and logging
- ✅ Efficient caching with minimal API calls
- ✅ Clear separation of concerns between service, client, and repository
## Dependencies
### Completed
- ✅ `FinnhubClient` with `date` object interface
- ✅ `FundamentalDataRepository` with dataclass storage
- ✅ `FundamentalContext` Pydantic model
### Required
- Working `FinnhubClient` instance with valid API key
- Writable data directory for repository storage
## Timeline
### Immediate (Today)
- Fix critical date conversion and reference issues
- Implement basic gap detection
- Add date validation
### Next Steps
- Implement partial data handling
- Comprehensive testing
- Integration with agent workflows
## Acceptance Criteria
### Must Have
1. **Type Safety**: Service passes `mise run typecheck` with zero errors
2. **Client Integration**: All `FinnhubClient` calls use `date` objects correctly
3. **Gap Detection**: Correctly identifies missing report periods
4. **Partial Data**: Service returns partial context when some statements fail
5. **Local-First**: Service checks repository before API calls
6. **Context Validation**: Returns valid `FundamentalContext` with Pydantic validation
7. **Error Handling**: Graceful handling of API failures and missing data
### Should Have
1. **Cache Efficiency**: Minimal redundant API calls
2. **Force Refresh**: Complete cache bypass when requested
3. **Data Quality**: Metadata indicating data completeness
4. **Clear Error Messages**: Informative errors for date validation failures
### Nice to Have
1. **Performance Metrics**: Timing and cache hit rate logging
2. **Fiscal Year Handling**: Support for non-calendar fiscal years
3. **Bulk Operations**: Fetch multiple symbols efficiently
---
This PRD focuses on completing the `FundamentalDataService` as a strongly-typed, local-first data service that seamlessly integrates with the existing `FinnhubClient` and `FundamentalDataRepository` components while providing robust gap detection and partial data handling.

502
MarketDataService_PRD.md Normal file

@@ -0,0 +1,502 @@
# Product Requirements Document: MarketDataService Completion
## Overview
Complete the `MarketDataService` to provide strongly-typed market data and technical indicators to trading agents using a local-first data strategy with gap detection and intelligent caching.
## Current State Analysis
### Issues to Fix
- **CRITICAL**: Service uses generic `BaseClient` inheritance; the existing `YFinanceClient` needs refactoring to the `FinnhubClient` standard
- **CRITICAL**: Service calls client methods with string dates instead of date objects
- **CRITICAL**: Need to integrate `stockstats` library for technical analysis calculations instead of legacy utils
- **CRITICAL**: `MarketDataRepository` exists but is missing the service interface methods
- Missing strongly-typed interface between YFinanceClient and service
- YFinanceClient uses BaseClient inheritance and string dates (needs refactoring)
- No concrete gap detection logic
- Missing technical indicator data sufficiency validation
### What Works
- ✅ Local-first data strategy implementation (`_get_price_data_local_first`)
- ✅ Force refresh logic (`_fetch_and_cache_fresh_data`)
- ✅ `MarketDataContext` Pydantic model for agent consumption
- ✅ Error handling and metadata creation patterns
- ✅ `YFinanceClient` exists with yfinance SDK integration and comprehensive methods
- ✅ `MarketDataRepository` exists with CSV storage and pandas DataFrame operations
- ✅ Service structure ready for `stockstats` integration for technical analysis
## Technical Requirements
### 1. Strongly-Typed Interfaces
#### Client → Service Interface
```python
# YFinanceClient methods (to be refactored)
def get_historical_data(symbol: str, start_date: date, end_date: date) -> dict[str, Any]
def get_price_data(symbol: str, start_date: date, end_date: date) -> dict[str, Any]
# Technical analysis handled in service layer using stockstats
# No get_technical_indicator method needed in client - calculated from OHLCV data
```
#### Service → Repository Interface
```python
# MarketDataRepository methods (to be implemented)
def has_data_for_period(symbol: str, start_date: str, end_date: str) -> bool
def get_data(symbol: str, start_date: str, end_date: str) -> dict[str, Any]
def store_data(symbol: str, cache_data: dict, overwrite: bool) -> bool
def clear_data(symbol: str, start_date: str, end_date: str) -> bool
```
#### Service → Agent Interface
```python
# Service output (already defined)
def get_context(symbol: str, start_date: str, end_date: str, indicators: list[str], force_refresh: bool) -> MarketDataContext
```
### 2. Local-First Data Strategy
#### Flow
1. **Repository Lookup**: Check `MarketDataRepository.has_data_for_period()`
2. **Gap Detection**: Identify missing price data periods using `detect_market_gaps()`
3. **Data Sufficiency Check**: Ensure enough historical data for requested indicators
4. **Selective Fetching**: Fetch only missing data from `YFinanceClient`
5. **Cache Updates**: Store new data via `repository.store_data()`
6. **Context Assembly**: Return validated `MarketDataContext`
#### Gap Detection Implementation
```python
def detect_market_gaps(self, cached_dates: list[str], requested_start: str, requested_end: str) -> list[tuple[str, str]]:
"""
Returns list of (start, end) tuples for missing periods.
Example: If requesting 2024-01-01 to 2024-01-31 and cache has:
- 2024-01-01 to 2024-01-10
- 2024-01-20 to 2024-01-25
Returns: [("2024-01-11", "2024-01-19"), ("2024-01-26", "2024-01-31")]
Accounts for:
- Weekends (Saturday/Sunday)
- Market holidays
- Continuous date ranges to minimize API calls
"""
# Implementation should use pandas business day logic
```
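A sketch of that business-day logic with pandas; real market-holiday handling would plug in an exchange calendar:
```python
import pandas as pd

def find_missing_business_days(cached_dates: list[str], start: str, end: str) -> list[str]:
    expected = pd.bdate_range(start, end)  # skips Saturdays/Sundays
    cached = set(pd.to_datetime(cached_dates))
    return [d.strftime("%Y-%m-%d") for d in expected if d not in cached]

def group_into_ranges(missing: list[str]) -> list[tuple[str, str]]:
    """Collapse consecutive missing days into (start, end) spans to minimize API calls."""
    if not missing:
        return []
    ranges, span_start, prev = [], missing[0], missing[0]
    for d in missing[1:]:
        # More than a weekend apart -> start a new span
        if (pd.Timestamp(d) - pd.Timestamp(prev)).days > 3:
            ranges.append((span_start, prev))
            span_start = d
        prev = d
    ranges.append((span_start, prev))
    return ranges
```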
#### Force Refresh Support
- `force_refresh=True` bypasses local data completely
- Clears existing cache before fetching fresh data
- Stores refreshed data with metadata indicating refresh
#### Cache Invalidation Strategy
- **Historical data is immutable**: Data older than yesterday never changes
- **Today's data needs updates**: During market hours, refresh every 15 minutes
- **After market close**: Today's data becomes immutable
```python
def is_data_stale(self, data_date: date, last_updated: datetime) -> bool:
today = date.today()
if data_date < today:
return False # Historical data never stale
# For today's data, check if market is open and last update > 15 min
    if is_market_open() and (datetime.now() - last_updated).total_seconds() > 15 * 60:
return True
return False
```
### 3. Date Object Conversion
#### Service Boundary Conversion
```python
# Service receives string dates from agents
def get_context(self, symbol: str, start_date: str, end_date: str, ...) -> MarketDataContext:
# Validate date strings
try:
start_dt = date.fromisoformat(start_date)
end_dt = date.fromisoformat(end_date)
except ValueError as e:
raise ValueError(f"Invalid date format: {e}")
# Check date order
if end_dt < start_dt:
raise ValueError(f"End date {end_date} is before start date {start_date}")
# Expand date range for technical indicators
expanded_start = self._calculate_lookback_start(start_dt, indicators)
# Use date objects when calling YFinanceClient
price_data = self.yfinance_client.get_historical_data(symbol, expanded_start, end_dt)
# Calculate technical indicators using stockstats library
technical_indicators = self._calculate_technical_indicators(price_data, indicators)
```
### 4. Technical Analysis with Stockstats
#### Data Sufficiency Validation
```python
# Minimum data points required for each indicator
INDICATOR_REQUIREMENTS = {
"sma_20": 20,
"sma_200": 200,
"ema_12": 24, # 2x for exponential smoothing
"ema_200": 400,
"rsi_14": 28, # 2x period for warm-up
"macd": 34, # 26 + 8 for signal line
"bb_upper": 20, # Based on 20-period SMA
"atr_14": 28, # 2x period for accuracy
"stochrsi_14": 42, # 3x period for double smoothing
}
def _calculate_lookback_start(self, start_date: date, indicators: list[str]) -> date:
"""Calculate how far back we need data to compute indicators accurately."""
max_lookback = 0
for indicator in indicators:
lookback = INDICATOR_REQUIREMENTS.get(indicator, 0)
max_lookback = max(max_lookback, lookback)
# Add buffer for weekends/holidays
business_days_back = max_lookback * 1.5
return start_date - timedelta(days=int(business_days_back))
def _validate_data_sufficiency(self, data_points: int, indicators: list[str]) -> dict[str, bool]:
"""Check if we have enough data for each indicator."""
return {
indicator: data_points >= INDICATOR_REQUIREMENTS.get(indicator, 0)
for indicator in indicators
}
```
#### Stockstats Integration
```python
def _calculate_technical_indicators(self, price_data: list[dict], indicators: list[str]) -> dict[str, list[dict]]:
"""
Calculate technical indicators using stockstats library.
Args:
price_data: OHLCV data from YFinanceClient
indicators: List of requested indicators (e.g., ['rsi_14', 'macd', 'bb_upper', 'sma_20'])
Returns:
Dict mapping indicator names to time series data
"""
import pandas as pd
from stockstats import StockDataFrame
# Convert price data to pandas DataFrame
df = pd.DataFrame(price_data)
df['date'] = pd.to_datetime(df['date'])
df.set_index('date', inplace=True)
# Check data sufficiency
sufficiency = self._validate_data_sufficiency(len(df), indicators)
# Create StockDataFrame for technical analysis
sdf = StockDataFrame.retype(df)
# Calculate requested indicators
indicator_data = {}
for indicator in indicators:
if not sufficiency[indicator]:
logger.warning(f"Insufficient data for {indicator}, need {INDICATOR_REQUIREMENTS[indicator]} points")
indicator_data[indicator] = []
continue
        try:
            # Accessing the column triggers stockstats' lazy indicator calculation
            values = sdf[indicator].dropna()
            indicator_data[indicator] = [
                {"date": idx.strftime("%Y-%m-%d"), "value": float(val)}
                for idx, val in values.items()
            ]
except Exception as e:
logger.warning(f"Failed to calculate {indicator}: {e}")
indicator_data[indicator] = []
return indicator_data
```
### 5. Error Recovery and Partial Data
```python
def handle_partial_price_data(
    self,
    symbol: str,
    requested_start: str,
    requested_end: str,
    available_data: list[dict]
) -> MarketDataContext:
"""
Handle cases where only partial date range is available.
- If no data available: Raise exception
- If partial data: Return what's available with metadata
- Mark gaps in metadata
"""
if not available_data:
raise ValueError(f"No market data available for {symbol}")
actual_start = min(d['date'] for d in available_data)
actual_end = max(d['date'] for d in available_data)
metadata = {
"requested_period": {"start": requested_start, "end": requested_end},
"actual_period": {"start": actual_start, "end": actual_end},
"partial_data": actual_start > requested_start or actual_end < requested_end,
"data_points": len(available_data)
}
# Return context with available data and metadata
```
### 6. Pydantic Validation
#### Context Structure
```python
class MarketDataContext(BaseModel):
symbol: str
period: dict[str, str] # {"start": "2024-01-01", "end": "2024-01-31"}
price_data: list[dict[str, Any]] # OHLCV records
technical_indicators: dict[str, list[TechnicalIndicatorData]]
metadata: dict[str, Any]
    @field_validator("price_data")
    @classmethod
    def validate_price_data(cls, v):
        # Ensure OHLCV fields are present and valid
        required_fields = {'date', 'open', 'high', 'low', 'close', 'volume'}
        for record in v:
            if not all(field in record for field in required_fields):
                raise ValueError("Missing required OHLCV fields")
        return v
```
## Implementation Tasks
### Phase 1: Refactor YFinanceClient
1. **YFinanceClient Refactoring**
- **Refactor existing** `tradingagents/clients/yfinance_client.py`
- Remove BaseClient inheritance
- Update all method signatures to accept `date` objects instead of strings
- Keep all existing functionality intact
- Example changes:
```python
# Current (wrong)
def get_historical_data(self, symbol: str, start_date: str, end_date: str) -> dict[str, Any]:
# Updated (correct)
def get_historical_data(self, symbol: str, start_date: date, end_date: date) -> dict[str, Any]:
```
2. **Comprehensive Testing**
- Update `tradingagents/clients/test_yfinance_client.py`
- Test with date objects
- Use pytest-vcr for HTTP interaction recording
- Test error handling and edge cases
### Phase 2: Update MarketDataRepository
3. **Repository Interface Enhancement**
- Update existing `tradingagents/repositories/market_data_repository.py`
- Add missing service interface methods: `has_data_for_period()`, `get_data()`, `store_data()`, `clear_data()`
- Maintain existing CSV/pandas functionality while adding service compatibility
- Support gap detection and partial data scenarios
### Phase 3: Update MarketDataService
4. **Client Integration Fix**
- Replace `BaseClient` dependency with `YFinanceClient`
- File: `tradingagents/services/market_data_service.py:8, 26`
- Update constructor to accept `yfinance_client: YFinanceClient`
5. **Date Conversion and Validation**
- Add `date.fromisoformat()` conversion in service methods
- Add date validation (format, order)
- Update client calls to use date objects instead of strings
- File: `tradingagents/services/market_data_service.py:151, 227`
6. **Technical Indicator Integration with Stockstats**
- Implement `_calculate_technical_indicators()` method using `stockstats` library
- Add `_calculate_lookback_start()` for data sufficiency
- Add `_validate_data_sufficiency()` to check if enough data
- Replace legacy `StockstatsUtils` integration with direct stockstats usage
- File: `tradingagents/services/market_data_service.py:9, 43, 280-346`
### Phase 4: Type Safety & Validation
7. **Comprehensive Type Checking**
- Run `mise run typecheck` - must pass with 0 errors
- Validate all date object conversions
- Ensure MarketDataContext compliance
8. **Enhanced Testing**
- Update existing service tests for new YFinanceClient interface
- Add gap detection test scenarios
- Test technical indicator data sufficiency
- Test partial data handling
## Testing Scenarios
### Integration Tests
1. **Gap Detection**
- Test with empty cache (should fetch all)
- Test with partial cache (should fetch only missing periods)
- Test weekend/holiday handling
2. **Technical Indicator Sufficiency**
- Test SMA_200 with only 100 days of data (should skip indicator)
- Test RSI_14 with exactly 28 days (should calculate)
- Test mixed indicators with varying data requirements
3. **Partial Data Recovery**
- Test when API returns less data than requested
- Test when some dates are missing (holidays)
- Test metadata accuracy for partial data
4. **Date Handling**
- Test invalid date formats
- Test end_date < start_date
- Test future dates
- Test weekend date handling
5. **Cache Staleness**
- Test historical data (should never refresh)
- Test today's data during market hours (should refresh if > 15 min)
- Test today's data after market close (should not refresh)
## Success Criteria
### Functional Requirements
- ✅ Service successfully calls refactored `YFinanceClient` with `date` objects
- ✅ Gap detection correctly identifies missing trading days
- ✅ Technical indicators validate data sufficiency before calculation
- ✅ Partial data scenarios handled gracefully
- ✅ Local-first strategy works: checks cache → identifies gaps → fetches missing → stores updates
- ✅ Returns properly validated `MarketDataContext` to agents
- ✅ Technical indicators calculated from OHLCV data using stockstats library
- ✅ Force refresh bypasses cache and refreshes data
### Technical Requirements
- ✅ Zero type checking errors: `mise run typecheck`
- ✅ Zero linting errors: `mise run lint`
- ✅ All existing tests pass with updated architecture
- ✅ No runtime errors with date conversions
- ✅ Proper error messages for validation failures
### Quality Requirements
- ✅ Strongly-typed interfaces between all components
- ✅ Official yfinance SDK and stockstats library usage
- ✅ Comprehensive error handling and logging
- ✅ Efficient caching with minimal API calls
- ✅ Clear separation of concerns between service, client, and repository
## Data Architecture
### YFinanceClient Response Format
```python
{
"symbol": "AAPL",
"period": {"start": "2024-01-01", "end": "2024-01-31"},
"data": [
{
"date": "2024-01-02", # Note: Jan 1 was a holiday
"open": 150.0,
"high": 155.0,
"low": 149.0,
"close": 154.0,
"volume": 1000000,
"adj_close": 154.0
},
...
],
"metadata": {
"source": "yfinance",
"retrieved_at": "2024-01-31T10:00:00Z",
"data_quality": "HIGH",
"missing_dates": ["2024-01-01", "2024-01-15"] # Holidays
}
}
```
### Technical Indicator Data Format
```python
# MarketDataContext.technical_indicators structure
{
"rsi_14": [
{"date": "2024-01-29", "value": 65.5}, # First valid after 28 days
{"date": "2024-01-30", "value": 67.2},
...
],
"sma_200": [], # Empty if insufficient data
"macd": [
{"date": "2024-01-31", "value": {"macd": 2.1, "signal": 1.8, "histogram": 0.3}}
],
"_metadata": {
"indicators_calculated": ["rsi_14", "macd"],
"indicators_skipped": {
"sma_200": "Insufficient data: need 200 points, have 31"
}
}
}
```
## Dependencies
### Existing Components (Need Updates)
- ✅ `YFinanceClient` exists but needs refactoring (remove BaseClient, use date objects)
- ✅ `MarketDataRepository` exists with CSV storage but needs service interface methods
- ✅ Tests exist but need updates for new interfaces
### Required
- Official `yfinance` library for market data fetching
- `stockstats` library for technical analysis calculations
- `pandas` for date/time handling and business day calculations
- Working internet connection for live data fetching
- Writable data directory for repository storage
## Timeline
### Immediate (Phase 1)
- Refactor existing YFinanceClient to use date objects
- Remove BaseClient inheritance
- Update tests for new interface
### Phase 2-3
- Add service interface methods to MarketDataRepository
- Update MarketDataService to use refactored YFinanceClient
- Implement data sufficiency validation
- Integrate stockstats library for technical indicators
### Phase 4
- Comprehensive type checking and validation
- Integration testing with gap detection
- Performance optimization and caching efficiency
## Acceptance Criteria
### Must Have
1. **Type Safety**: Service passes `mise run typecheck` with zero errors
2. **Client Refactoring**: YFinanceClient uses date objects, no BaseClient
3. **Gap Detection**: Correctly identifies missing trading days
4. **Data Sufficiency**: Validates enough data for technical indicators
5. **Partial Data**: Service handles incomplete data gracefully
6. **Local-First**: Service checks repository before API calls
7. **Context Validation**: Returns valid `MarketDataContext` with Pydantic validation
8. **Technical Indicators**: Calculated using stockstats with proper validation
### Should Have
1. **Cache Efficiency**: Minimal redundant API calls to Yahoo Finance
2. **Force Refresh**: Complete cache bypass when requested
3. **Stale Data Handling**: Refresh today's data during market hours
4. **Clear Error Messages**: Informative errors for validation failures
### Nice to Have
1. **Performance Metrics**: Timing and cache hit rate logging
2. **Extended Indicators**: Support for 50+ technical indicators
3. **Real-time Data**: WebSocket integration for live prices
4. **Bulk Symbol Support**: Fetch multiple symbols efficiently
---
This PRD focuses on completing the `MarketDataService` as a strongly-typed, local-first data service that integrates OHLCV price data from a refactored `YFinanceClient` and calculates comprehensive technical indicators using the `stockstats` library, with robust gap detection and data sufficiency validation.

779
NewsService_PRD.md Normal file

@@ -0,0 +1,779 @@
# Product Requirements Document: NewsService Completion
## Overview
Complete the `NewsService` to provide strongly-typed news data and sentiment analysis to trading agents using a local-first data strategy with RSS feed integration, article content extraction, and LLM-powered sentiment analysis.
## Current State Analysis
### Issues to Fix
- **CRITICAL**: Service is currently an empty placeholder with only method stubs
- **CRITICAL**: Need to implement GoogleNewsClient to read RSS feeds
- **CRITICAL**: Need RSS article fetching with fallback to Internet Archive
- **CRITICAL**: Need LLM-powered sentiment analysis integration
- **CRITICAL**: Service uses `BaseClient` inheritance instead of typed clients
- **CRITICAL**: `NewsRepository` has different interface than service expectations
- Missing strongly-typed interfaces between components
- No concrete approach for article content extraction
### What Works
- ✅ `NewsContext` and `ArticleData` Pydantic models for agent consumption
- ✅ `SentimentScore` model for structured sentiment data
- ✅ `FinnhubClient` with `get_company_news()` method using date objects
- ✅ `NewsRepository` with dataclass-based storage and deduplication
- ✅ Service structure placeholder ready for implementation
## Technical Requirements
### 1. Strongly-Typed Interfaces
#### Client → Service Interface
```python
# FinnhubClient methods (already implemented)
def get_company_news(symbol: str, start_date: date, end_date: date) -> dict[str, Any]
# GoogleNewsClient methods (to be implemented)
def fetch_rss_feed(query: str, start_date: date, end_date: date) -> dict[str, Any]
def fetch_article_content(url: str, use_archive_fallback: bool = True) -> dict[str, Any]
def get_company_news(symbol: str, start_date: date, end_date: date) -> dict[str, Any]
def get_global_news(start_date: date, end_date: date, categories: list[str]) -> dict[str, Any]
```
#### Service → Repository Interface
```python
# NewsRepository methods (to be implemented/bridged)
def has_data_for_period(query: str, start_date: str, end_date: str, symbol: str | None) -> bool
def get_data(query: str, start_date: str, end_date: str, symbol: str | None) -> dict[str, Any]
def store_data(query: str, cache_data: dict, symbol: str | None, overwrite: bool) -> bool
def clear_data(query: str, start_date: str, end_date: str, symbol: str | None) -> bool
```
#### Service → Agent Interface
```python
# Service output (already defined)
def get_context(query: str, start_date: str, end_date: str, symbol: str | None, sources: list[str], force_refresh: bool) -> NewsContext
```
### 2. Local-First Data Strategy
#### Flow
1. **Repository Lookup**: Check `NewsRepository.has_data_for_period()`
2. **Freshness Check**: Determine if cache needs updating (news is append-only)
3. **RSS Feed Fetching**: Fetch RSS feeds from Google News
4. **Content Extraction**: Extract full article content with Internet Archive fallback
5. **LLM Analysis**: Perform sentiment analysis using LLM
6. **Cache Updates**: Store enriched articles via `repository.store_data()`
7. **Context Assembly**: Return validated `NewsContext`
#### News-Specific Gap Detection
```python
def should_fetch_new_articles(self, last_fetch_time: datetime | None, current_time: datetime) -> bool:
    """
    News doesn't have "gaps" - it's append-only. Check if enough time has passed for new articles.
    Returns True if:
    - Last fetch was at least 6 hours ago
    - User requested force_refresh
    - No data exists for the query/period
    """
    if not last_fetch_time:
        return True
    hours_since_fetch = (current_time - last_fetch_time).total_seconds() / 3600
    return hours_since_fetch >= 6  # Fetch new articles every 6 hours
```
#### Force Refresh Support
- `force_refresh=True` fetches all articles fresh from sources
- Does NOT clear existing cache (news is immutable)
- Deduplicates against existing articles before storing (see the merge sketch below)
#### Cache Invalidation Strategy
- **Articles are immutable**: Once published, articles don't change
- **Cache grows append-only**: New articles are added, old ones retained
- **Freshness check**: Re-fetch every 6 hours for new articles
- **No deletion**: Articles are never removed from cache
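A minimal sketch of the append-only merge described above; `_merge_new_articles` is a hypothetical helper that deduplicates by URL before the service calls `repository.store_data()`:

```python
def _merge_new_articles(existing: list[dict], fetched: list[dict]) -> list[dict]:
    """Append-only merge: keep every cached article, add only unseen URLs."""
    seen = {article["url"] for article in existing}
    return existing + [article for article in fetched if article["url"] not in seen]
```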
### 3. RSS Feed Processing & Article Fetching
#### GoogleNewsClient RSS Implementation
```python
import feedparser
from newspaper import Article
import requests
from datetime import date, datetime
from typing import Any, Optional
from urllib.parse import quote_plus
class GoogleNewsClient:
"""Google News RSS client following FinnhubClient standard."""
def __init__(self):
self.base_rss_url = "https://news.google.com/rss"
self.archive_base_url = "https://archive.org/wayback/available"
def fetch_rss_feed(self, query: str, start_date: date, end_date: date) -> dict[str, Any]:
"""
Fetch RSS feed data for news articles.
Args:
query: Search query or company symbol
start_date: Start date for filtering articles
end_date: End date for filtering articles
Returns:
Dict containing RSS feed articles with metadata
"""
        # Construct RSS feed URL (query must be URL-encoded)
        rss_url = f"{self.base_rss_url}/search?q={quote_plus(query)}&hl=en-US&gl=US&ceid=US:en"
# Parse RSS feed
feed = feedparser.parse(rss_url)
# Filter and structure articles
articles = []
        for entry in feed.entries:
            # Parse publication date (skip entries that lack one)
            if not getattr(entry, "published_parsed", None):
                continue
            pub_date = datetime(*entry.published_parsed[:6]).date()
# Filter by date range
if start_date <= pub_date <= end_date:
articles.append({
"headline": entry.title,
"url": entry.link,
"source": entry.source.get('title', 'Google News'),
"date": pub_date.isoformat(),
"summary": entry.get('summary', ''),
})
return {
"query": query,
"period": {"start": start_date.isoformat(), "end": end_date.isoformat()},
"articles": articles,
"metadata": {
"source": "google_news_rss",
"rss_feed_url": rss_url,
"article_count": len(articles)
}
}
def fetch_article_content(self, url: str, use_archive_fallback: bool = True) -> dict[str, Any]:
"""
Fetch full article content from URL with Internet Archive fallback.
Args:
url: Article URL to fetch
use_archive_fallback: Whether to try Internet Archive if direct fetch fails
Returns:
Dict containing article content, title, publication date
"""
try:
# Try direct fetch
article = Article(url)
article.download()
article.parse()
return {
"content": article.text,
"title": article.title,
"authors": article.authors,
"publish_date": article.publish_date.isoformat() if article.publish_date else None,
"extracted_via": "direct_fetch",
"extraction_success": True
}
except Exception as e:
if use_archive_fallback:
# Try Internet Archive
archive_url = self._get_archive_url(url)
if archive_url:
try:
article = Article(archive_url)
article.download()
article.parse()
return {
"content": article.text,
"title": article.title,
"authors": article.authors,
"publish_date": article.publish_date.isoformat() if article.publish_date else None,
"extracted_via": "internet_archive",
"extraction_success": True
}
except Exception:
pass
# Return failure
return {
"content": "",
"title": "",
"extracted_via": "failed",
"extraction_success": False,
"error": str(e)
}
def _get_archive_url(self, url: str) -> Optional[str]:
"""Get Internet Archive URL for a given URL."""
try:
response = requests.get(f"{self.archive_base_url}?url={url}")
data = response.json()
if data.get("archived_snapshots", {}).get("closest", {}).get("available"):
return data["archived_snapshots"]["closest"]["url"]
except Exception:
pass
return None
```
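Usage might look like the following (live RSS results vary, so this is illustrative rather than a fixed contract):

```python
from datetime import date

client = GoogleNewsClient()
feed = client.fetch_rss_feed("AAPL earnings", date(2024, 1, 1), date(2024, 1, 31))
print(feed["metadata"]["article_count"])

# Enrich the first few hits with full text; archive fallback stays enabled.
for item in feed["articles"][:3]:
    detail = client.fetch_article_content(item["url"])
    if detail["extraction_success"]:
        item.update(detail)
```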
### 4. LLM-Powered Sentiment Analysis
#### Sentiment Analysis Integration
```python
import json
import time

class LLMSentimentAnalyzer:
"""LLM-based sentiment analyzer for financial news."""
def __init__(self, llm_client):
self.llm_client = llm_client
self.sentiment_prompt = """
Analyze the sentiment of this financial news article for trading purposes.
Article:
Title: {headline}
Content: {content}
Provide your analysis in the following JSON format:
{{
"score": <float between -1.0 (very negative) and 1.0 (very positive)>,
"confidence": <float between 0.0 and 1.0>,
"label": <"positive", "negative", or "neutral">,
"reasoning": <brief explanation>,
"key_themes": <list of key financial themes>,
"financial_entities": <list of mentioned companies/tickers>
}}
Focus on the financial and market implications of the news.
"""
def analyze_sentiment(self, article: ArticleData) -> SentimentScore:
"""
Analyze article sentiment using LLM.
Args:
article: Article data with headline and content
Returns:
SentimentScore with score, confidence, and label
"""
# Prepare prompt
prompt = self.sentiment_prompt.format(
headline=article.headline,
content=article.content[:2000] # Limit content length
)
# Get LLM response
response = self.llm_client.complete(prompt)
# Parse response
try:
result = json.loads(response)
# Convert to SentimentScore
score = result.get("score", 0.0)
return SentimentScore(
positive=max(0, score),
negative=abs(min(0, score)),
neutral=1.0 - abs(score),
metadata={
"confidence": result.get("confidence", 0.5),
"label": result.get("label", "neutral"),
"reasoning": result.get("reasoning", ""),
"key_themes": result.get("key_themes", []),
"financial_entities": result.get("financial_entities", [])
}
)
except Exception as e:
# Return neutral sentiment on error
return SentimentScore(
positive=0.0,
negative=0.0,
neutral=1.0,
metadata={"error": str(e)}
)
def batch_analyze(self, articles: list[ArticleData], batch_size: int = 5) -> list[SentimentScore]:
"""
Batch process sentiment analysis for multiple articles.
Args:
articles: List of articles to analyze
batch_size: Number of articles to process in parallel
Returns:
List of sentiment scores corresponding to input articles
"""
results = []
for i in range(0, len(articles), batch_size):
batch = articles[i:i + batch_size]
# Process batch (could be parallelized)
for article in batch:
sentiment = self.analyze_sentiment(article)
results.append(sentiment)
# Add small delay to respect rate limits
time.sleep(0.1)
return results
```
### 5. Date Object Conversion
#### Service Boundary Conversion
```python
# Service receives string dates from agents
def get_context(self, query: str, start_date: str, end_date: str, ...) -> NewsContext:
# Validate date strings
try:
start_dt = date.fromisoformat(start_date)
end_dt = date.fromisoformat(end_date)
except ValueError as e:
raise ValueError(f"Invalid date format: {e}")
# Check date order
if end_dt < start_dt:
raise ValueError(f"End date {end_date} is before start date {start_date}")
# Fetch from multiple sources
finnhub_data = self.finnhub_client.get_company_news(symbol, start_dt, end_dt) if symbol else None
google_rss = self.google_client.fetch_rss_feed(query, start_dt, end_dt)
# Fetch full article content for RSS articles
for article in google_rss.get('articles', []):
content_data = self.google_client.fetch_article_content(article['url'])
article.update(content_data)
# Combine all articles
all_articles = self._combine_and_deduplicate(finnhub_data, google_rss)
# Perform LLM sentiment analysis
enriched_articles = []
for article in all_articles:
article_data = ArticleData(**article)
article_data.sentiment = self.sentiment_analyzer.analyze_sentiment(article_data)
enriched_articles.append(article_data)
# Create and return context
return self._create_news_context(enriched_articles, start_date, end_date)
```
### 6. Error Recovery and Partial Data
```python
def handle_source_failure(
self,
finnhub_data: dict | None,
google_data: dict | None,
errors: dict[str, Exception]
) -> NewsContext:
"""
Handle cases where one or more news sources fail.
- If all sources fail: Raise exception
- If some sources succeed: Return partial data with metadata
- Track content extraction failures separately
"""
if not finnhub_data and not google_data:
raise ValueError("All news sources failed to return data")
# Track extraction statistics
extraction_stats = {
"total_articles": 0,
"successful_extractions": 0,
"archive_fallbacks": 0,
"failed_extractions": 0
}
# Process available articles
all_articles = []
successful_sources = []
if finnhub_data:
all_articles.extend(finnhub_data.get('articles', []))
successful_sources.append('finnhub')
if google_data:
articles = google_data.get('articles', [])
for article in articles:
extraction_stats["total_articles"] += 1
if article.get("extraction_success"):
extraction_stats["successful_extractions"] += 1
if article.get("extracted_via") == "internet_archive":
extraction_stats["archive_fallbacks"] += 1
else:
extraction_stats["failed_extractions"] += 1
all_articles.extend(articles)
successful_sources.append('google_news')
metadata = {
"sources_requested": ["finnhub", "google_news"],
"sources_successful": successful_sources,
"sources_failed": {source: str(error) for source, error in errors.items()},
"extraction_stats": extraction_stats,
"partial_data": len(successful_sources) < 2
}
# Deduplicate and return context
return self._create_context(all_articles, metadata)
```
### 7. Repository Method Bridging
```python
# Add these bridge methods to NewsRepository
def has_data_for_period(self, query: str, start_date: str, end_date: str, symbol: str | None = None) -> bool:
"""Bridge to existing get_news_data method."""
existing_data = self.get_news_data(
symbol=symbol or query,
start_date=start_date,
end_date=end_date
)
return len(existing_data.get('articles', [])) > 0
def get_data(self, query: str, start_date: str, end_date: str, symbol: str | None = None) -> dict[str, Any]:
"""Bridge to existing get_news_data method."""
return self.get_news_data(
symbol=symbol or query,
start_date=start_date,
end_date=end_date
)
def store_data(self, query: str, cache_data: dict, symbol: str | None = None, overwrite: bool = False) -> bool:
"""Bridge to existing store_news_articles method."""
articles = cache_data.get('articles', [])
if not articles:
return False
# Convert to expected format
news_articles = [
NewsArticle(
symbol=symbol or query,
headline=a['headline'],
summary=a.get('summary', ''),
content=a.get('content', ''),
url=a['url'],
source=a['source'],
date=a['date'],
entities=a.get('entities', []),
sentiment_score=a.get('sentiment', {}).get('score', 0.0),
sentiment_metadata=a.get('sentiment', {})
)
for a in articles
]
return self.store_news_articles(news_articles)
def clear_data(self, query: str, start_date: str, end_date: str, symbol: str | None = None) -> bool:
"""News is append-only, so this just marks data as stale for re-fetch."""
# Implementation depends on repository design
# Could update metadata to trigger re-fetch
return True
```
### 8. Pydantic Validation
#### Context Structure
```python
class NewsContext(BaseModel):
symbol: str | None
period: dict[str, str] # {"start": "2024-01-01", "end": "2024-01-31"}
articles: list[ArticleData]
sentiment_summary: SentimentScore
article_count: int
sources: list[str]
metadata: dict[str, Any]
@validator('period')
def validate_period(cls, v):
# Ensure start and end dates are present and valid
if 'start' not in v or 'end' not in v:
raise ValueError("Period must have 'start' and 'end' dates")
return v
@validator('articles')
def validate_articles(cls, v):
# Ensure no duplicate URLs
urls = [a.url for a in v]
if len(urls) != len(set(urls)):
raise ValueError("Duplicate articles detected")
return v
```
## Implementation Tasks
### Phase 1: Create GoogleNewsClient
1. **GoogleNewsClient Implementation**
- Create `tradingagents/clients/google_news_client.py` following FinnhubClient standard
- Implement RSS feed parsing using `feedparser` library
- Add `fetch_rss_feed()` method with Google News RSS integration
- Add `fetch_article_content()` method with `newspaper3k` and Internet Archive fallback
- Use `date` objects for all date parameters
- No BaseClient inheritance
2. **Article Content Extraction**
- Implement robust article content extraction using `newspaper3k`
- Add fallback to Internet Archive Wayback Machine for failed fetches
- Handle paywall detection and alternative content sources
- Extract clean text, title, publication date, and metadata
3. **Comprehensive Testing**
- Create test suite for GoogleNewsClient
- Test RSS parsing with various queries
- Test content extraction with real and archived URLs
- Use pytest-vcr for HTTP interaction recording
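A sketch of what one such test could look like, assuming the pytest-vcr plugin and the planned module path from task 1:

```python
import pytest
from datetime import date

from tradingagents.clients.google_news_client import GoogleNewsClient

@pytest.mark.vcr  # records HTTP traffic to a cassette on first run, replays it afterwards
def test_fetch_rss_feed_filters_by_date():
    client = GoogleNewsClient()
    result = client.fetch_rss_feed("AAPL", date(2024, 1, 1), date(2024, 1, 31))
    assert result["metadata"]["source"] == "google_news_rss"
    for article in result["articles"]:
        assert "2024-01-01" <= article["date"] <= "2024-01-31"  # ISO strings sort lexically
```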
### Phase 2: Bridge NewsRepository Interface
4. **Repository Interface Standardization**
- Add standard service interface methods to `NewsRepository`
- Bridge existing methods without changing underlying storage
- File: `tradingagents/repositories/news_repository.py`
- Maintain backward compatibility
### Phase 3: Implement NewsService
5. **Service Core Implementation**
- Replace method stubs with full implementation
- Implement `get_context()`, `get_company_news_context()`, `get_global_news_context()`
- Add local-first data strategy with freshness checking
- Replace `BaseClient` dependencies with typed clients
- File: `tradingagents/services/news_service.py`
6. **LLM Sentiment Analysis Integration**
- Implement `LLMSentimentAnalyzer` class
- Create financial news sentiment prompts
- Add batch processing for efficiency
- Handle LLM rate limiting and errors
7. **Date Conversion and Article Processing**
- Add date validation and conversion
- Implement RSS article fetching pipeline
- Add content extraction with fallback
- Combine articles from multiple sources
- Implement deduplication by URL
### Phase 4: Type Safety & Validation
8. **Comprehensive Type Checking**
- Run `mise run typecheck` - must pass with 0 errors
- Validate all date object conversions
- Ensure NewsContext compliance
9. **Enhanced Testing**
- Test RSS feed parsing edge cases
- Test content extraction failures and fallbacks
- Test LLM sentiment analysis with various article types
- Test multi-source aggregation and deduplication
## Testing Scenarios
### Integration Tests
1. **RSS Feed Processing**
- Test with various search queries
- Test date filtering in RSS results
- Test handling of malformed RSS feeds
2. **Content Extraction**
- Test direct fetch success
- Test Internet Archive fallback
- Test paywall detection
- Test extraction failure handling
3. **LLM Sentiment Analysis**
- Test positive news sentiment
- Test negative earnings reports
- Test neutral market updates
- Test batch processing
- Test LLM error handling
4. **Multi-Source Aggregation**
- Test both sources succeed
- Test Finnhub fails, Google succeeds
- Test Google fails, Finnhub succeeds
- Test both sources fail
5. **Date Handling**
- Test invalid date formats
- Test end_date < start_date
- Test date filtering in RSS feeds
## Success Criteria
### Functional Requirements
- ✅ Service successfully implements all placeholder methods
- ✅ GoogleNewsClient reads and parses RSS feeds correctly
- ✅ Article content extraction works with Internet Archive fallback
- ✅ LLM sentiment analysis provides structured financial sentiment
- ✅ Local-first strategy with proper freshness checking
- ✅ Multi-source aggregation with deduplication
- ✅ Returns properly validated `NewsContext` to agents
- ✅ Force refresh fetches fresh articles without clearing cache
### Technical Requirements
- ✅ Zero type checking errors: `mise run typecheck`
- ✅ Zero linting errors: `mise run lint`
- ✅ All tests pass with new implementation
- ✅ No runtime errors with date conversions
- ✅ Proper error messages for validation failures
### Quality Requirements
- ✅ Strongly-typed interfaces between all components
- ✅ RSS feed parsing with robust error handling
- ✅ Article content extraction with fallback strategy
- ✅ LLM integration with proper prompt engineering
- ✅ Efficient caching with minimal external calls
- ✅ Clear separation of concerns
## Data Architecture
### GoogleNewsClient RSS Response Format
```python
{
"query": "Apple stock",
"period": {"start": "2024-01-01", "end": "2024-01-31"},
"articles": [
{
"headline": "Apple Stock Soars on New Product Launch",
"summary": "Brief summary from RSS feed...",
"content": "Full article text extracted from source...",
"url": "https://www.cnbc.com/2024/01/20/apple-stock.html",
"source": "CNBC",
"date": "2024-01-20",
"authors": ["Tech Reporter"],
"publish_date": "2024-01-20T14:30:00Z",
"extracted_via": "direct_fetch", # or "internet_archive"
"extraction_success": true
}
],
"metadata": {
"source": "google_news_rss",
"article_count": 25,
"rss_feed_url": "https://news.google.com/rss/search?q=Apple+stock",
"extraction_stats": {
"successful": 22,
"archive_fallback": 2,
"failed": 3
}
}
}
```
### LLM Sentiment Analysis Response Format
```python
{
"article_url": "https://www.cnbc.com/2024/01/20/apple-stock.html",
"sentiment": {
"positive": 0.7,
"negative": 0.1,
"neutral": 0.2,
"metadata": {
"score": 0.7,
"confidence": 0.85,
"label": "positive",
"reasoning": "Article discusses positive earnings and growth outlook",
"key_themes": ["earnings_beat", "product_launch", "revenue_growth"],
"financial_entities": ["AAPL", "Apple Inc.", "iPhone 15"]
}
}
}
```
### Aggregate Sentiment Summary
```python
{
"sentiment_summary": {
"positive": 0.65, # Average across all articles
"negative": 0.20,
"neutral": 0.15,
"metadata": {
"dominant_sentiment": "positive",
"confidence": 0.82,
"article_count": 25,
"themes": {
"earnings": 8,
"product_launch": 5,
"market_analysis": 12
}
}
}
}
```
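A sketch of how the per-article scores could be rolled up into this summary; `summarize_sentiment` is a hypothetical helper over the `SentimentScore` fields used earlier:

```python
def summarize_sentiment(scores: list[SentimentScore]) -> SentimentScore:
    """Average per-article sentiment into the aggregate summary shown above."""
    n = len(scores) or 1  # avoid division by zero on an empty article set
    return SentimentScore(
        positive=sum(s.positive for s in scores) / n,
        negative=sum(s.negative for s in scores) / n,
        neutral=sum(s.neutral for s in scores) / n,
        metadata={"article_count": len(scores)},
    )
```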
## Dependencies
### Components to Create
- ⏳ `GoogleNewsClient` - Full implementation with RSS and content extraction
- ⏳ `LLMSentimentAnalyzer` - LLM integration for sentiment analysis
- ⏳ `NewsService` - Replace stubs with full implementation
### Existing Components
- ✅ `FinnhubClient` with company news using date objects
- ✅ `NewsRepository` with dataclass storage
- ✅ `NewsContext` and related Pydantic models
### Required Libraries
- `feedparser` - RSS feed parsing
- `newspaper3k` - Article content extraction
- `requests` - HTTP requests and Internet Archive API
- `beautifulsoup4` - HTML parsing fallback
- LLM client library (OpenAI, Anthropic, etc.)
## Timeline
### Immediate (Phase 1)
- Create GoogleNewsClient with RSS and content extraction
- Implement feedparser integration
- Add Internet Archive fallback
- Create comprehensive test suite
### Phase 2-3
- Add repository bridge methods
- Implement full NewsService
- Integrate LLM sentiment analysis
- Handle multi-source aggregation
### Phase 4
- Type checking and validation
- Integration testing
- Performance optimization
- Documentation
## Acceptance Criteria
### Must Have
1. **Type Safety**: Service passes `mise run typecheck` with zero errors
2. **RSS Integration**: Successfully parse Google News RSS feeds
3. **Content Extraction**: Extract full articles with fallback
4. **LLM Sentiment**: Financial sentiment analysis for all articles
5. **Service Implementation**: All stubs replaced with working code
6. **Local-First**: Check cache before fetching new data
7. **Multi-Source**: Aggregate Finnhub and Google News
### Should Have
1. **Extraction Stats**: Track success/failure rates
2. **Batch Processing**: Efficient LLM sentiment analysis
3. **Force Refresh**: Fetch new articles on demand
4. **Error Recovery**: Handle partial failures gracefully
### Nice to Have
1. **Additional Sources**: Support more news providers
2. **Real-time Monitoring**: WebSocket for breaking news
3. **Advanced Extraction**: Handle PDFs, videos
4. **Sentiment Trends**: Track sentiment over time
---
This PRD focuses on completing the currently empty `NewsService` with a full implementation including RSS feed integration, article content extraction with Internet Archive fallback, and LLM-powered sentiment analysis for financial news.

424
SocialMediaService_PRD.md Normal file
View File

@ -0,0 +1,424 @@
# Product Requirements Document: SocialMediaService Completion
## Overview
Complete the `SocialMediaService` to provide strongly-typed social media data and sentiment analysis to trading agents using a local-first data strategy with gap detection and intelligent caching.
## Current State Analysis
### Issues to Fix
- **CRITICAL**: Missing `RedditClient` implementation - service calls non-existent client methods
- **CRITICAL**: Service uses `BaseClient` inheritance but needs typed `RedditClient`
- **CRITICAL**: `SocialRepository` has different interface than standard service pattern
- **CRITICAL**: Repository uses `date` objects internally, but the service expects a string-date interface
- Missing strongly-typed interfaces between components
- Service calls `reddit_client.search_posts()`, `get_top_posts()`, and `filter_posts_by_date()`, methods that do not yet exist
### What Works
- ✅ Local-first data strategy implementation (`_get_social_data_local_first`)
- ✅ Force refresh logic (`_fetch_and_cache_fresh_social_data`)
- ✅ `SocialContext` Pydantic model for agent consumption
- ✅ Comprehensive sentiment analysis with keyword-based scoring
- ✅ Engagement metrics calculation and post ranking
- ✅ Error handling and metadata creation patterns
- ✅ `SocialRepository` with JSON storage and post deduplication
- ✅ `PostData` and `SentimentScore` models for structured data
- ✅ Real-time sentiment analysis with weighted scoring
## Technical Requirements
### 1. Strongly-Typed Interfaces
#### Client → Service Interface
```python
# RedditClient methods (to be implemented)
def search_posts(query: str, subreddit_names: list[str], start_date: date, end_date: date, limit: int, time_filter: str) -> dict[str, Any]
def get_top_posts(subreddit_names: list[str], start_date: date, end_date: date, limit: int, time_filter: str) -> dict[str, Any]
def get_company_posts(symbol: str, subreddit_names: list[str], start_date: date, end_date: date, limit: int) -> dict[str, Any]
```
#### Service → Repository Interface
```python
# SocialRepository methods (to be implemented/bridged)
def has_data_for_period(query: str, start_date: str, end_date: str, symbol: str | None) -> bool
def get_data(query: str, start_date: str, end_date: str, symbol: str | None) -> dict[str, Any]
def store_data(query: str, cache_data: dict, symbol: str | None, overwrite: bool) -> bool
def clear_data(query: str, start_date: str, end_date: str, symbol: str | None) -> bool
```
#### Service → Agent Interface
```python
# Service output (already defined)
def get_context(query: str, start_date: str, end_date: str, symbol: str | None, subreddits: list[str], force_refresh: bool) -> SocialContext
def get_company_social_context(symbol: str, start_date: str, end_date: str, subreddits: list[str]) -> SocialContext
def get_global_trends(start_date: str, end_date: str, subreddits: list[str]) -> SocialContext
```
### 2. Local-First Data Strategy
#### Flow
1. **Repository Lookup**: Check `SocialRepository.has_data_for_period()`
2. **Gap Detection**: Identify missing social media data periods
3. **Selective Fetching**: Fetch only missing data from `RedditClient`
4. **Cache Updates**: Store new data via `repository.store_data()`
5. **Context Assembly**: Return validated `SocialContext`
#### Force Refresh Support
- `force_refresh=True` bypasses local data completely
- Clears existing cache before fetching fresh data
- Stores refreshed data with metadata indicating refresh
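Condensed, the refresh path above might look like this sketch (method and argument names follow the interfaces defined in this PRD, but the body is illustrative):

```python
from datetime import date

def _fetch_and_cache_fresh_social_data(self, query, start_date, end_date, symbol, subreddits):
    """Force-refresh path: clear the cache, then fetch and store fresh posts."""
    self.repository.clear_data(query, start_date, end_date, symbol)
    start_dt = date.fromisoformat(start_date)
    end_dt = date.fromisoformat(end_date)
    fresh = self.reddit_client.search_posts(query, subreddits, start_dt, end_dt,
                                            limit=50, time_filter="month")
    fresh.setdefault("metadata", {})["force_refreshed"] = True
    self.repository.store_data(query, fresh, symbol, overwrite=True)
    return fresh
```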
### 3. Date Object Conversion
#### Service Boundary Conversion
```python
# Service receives string dates from agents
def get_context(self, query: str, start_date: str, end_date: str, ...) -> SocialContext:
# Convert to date objects for client calls
start_dt = date.fromisoformat(start_date)
end_dt = date.fromisoformat(end_date)
# Use date objects when calling RedditClient
posts_data = self.reddit_client.search_posts(query, subreddits, start_dt, end_dt, limit, time_filter)
# Repository bridge handles string to date conversion internally
cached_data = self.repository.get_data(query, start_date, end_date, symbol)
```
### 4. Reddit API Integration
#### RedditClient Implementation Strategy
```python
# RedditClient following FinnhubClient standard
class RedditClient:
"""Client for Reddit API access with PRAW library integration."""
def __init__(self, client_id: str, client_secret: str, user_agent: str):
"""Initialize Reddit client with PRAW."""
import praw
self.reddit = praw.Reddit(
client_id=client_id,
client_secret=client_secret,
user_agent=user_agent
)
def search_posts(self, query: str, subreddit_names: list[str],
start_date: date, end_date: date, limit: int = 50,
time_filter: str = "week") -> dict[str, Any]:
"""Search for posts across subreddits within date range."""
def get_top_posts(self, subreddit_names: list[str],
start_date: date, end_date: date, limit: int = 50,
time_filter: str = "week") -> dict[str, Any]:
"""Get top posts from subreddits within date range."""
def get_company_posts(self, symbol: str, subreddit_names: list[str],
start_date: date, end_date: date, limit: int = 50) -> dict[str, Any]:
"""Get company-specific posts from subreddits."""
```
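One possible body for `search_posts`, shown as a hedged sketch using standard PRAW calls (Reddit's search cannot filter by date server-side, so filtering happens client-side):

```python
from datetime import date, datetime, timezone
from typing import Any

def search_posts(self, query: str, subreddit_names: list[str],
                 start_date: date, end_date: date, limit: int = 50,
                 time_filter: str = "week") -> dict[str, Any]:
    """Search subreddits via PRAW and keep only posts inside the date range."""
    posts = []
    for name in subreddit_names:
        for submission in self.reddit.subreddit(name).search(query, time_filter=time_filter, limit=limit):
            created = datetime.fromtimestamp(submission.created_utc, tz=timezone.utc).date()
            if start_date <= created <= end_date:
                posts.append({
                    "title": submission.title,
                    "content": submission.selftext,
                    "author": str(submission.author) if submission.author else "[deleted]",
                    "subreddit": name,
                    "created_utc": submission.created_utc,
                    "score": submission.score,
                    "num_comments": submission.num_comments,
                    "upvote_ratio": submission.upvote_ratio,
                    "url": f"https://reddit.com{submission.permalink}",
                    "id": submission.id,
                })
    return {
        "query": query,
        "period": {"start": start_date.isoformat(), "end": end_date.isoformat()},
        "posts": posts,
        "metadata": {"source": "reddit", "subreddits": subreddit_names, "total_posts": len(posts)},
    }
```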
#### Reddit Response Format
```python
{
"query": "AAPL",
"period": {"start": "2024-01-01", "end": "2024-01-31"},
"posts": [
{
"title": "Apple earnings discussion",
"content": "What do you think about...",
"author": "redditor123",
"subreddit": "investing",
"created_utc": 1704067200,
"score": 125,
"num_comments": 45,
"upvote_ratio": 0.87,
"url": "https://reddit.com/r/investing/comments/abc123",
"id": "abc123"
}
],
"metadata": {
"source": "reddit",
"retrieved_at": "2024-01-31T10:00:00Z",
"data_quality": "HIGH",
"subreddits": ["investing", "stocks"],
"total_posts": 25
}
}
```
### 5. Sentiment Analysis Enhancement
#### Advanced Sentiment Features
- **Weighted Scoring**: High-engagement posts have more influence on overall sentiment
- **Keyword Analysis**: Comprehensive positive/negative keyword detection
- **Score Adjustment**: Reddit score (upvotes) influences sentiment confidence
- **Confidence Metrics**: Based on post count and engagement levels
- **Multi-level Analysis**: Individual post sentiment + overall summary sentiment
#### Sentiment Calculation Strategy
```python
def _calculate_advanced_sentiment(self, posts: list[PostData]) -> SentimentScore:
"""Enhanced sentiment analysis with multiple factors."""
# Weight by engagement score (upvotes + comments)
# Adjust for subreddit context (WSB vs investing)
# Consider temporal patterns (recent posts weighted higher)
# Apply confidence scoring based on data volume
```
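A sketch of one way to realize those comments, weighting each post's sentiment by engagement (it assumes the `PostData` fields defined in the next section and the positive/negative/neutral `SentimentScore` layout used in the NewsService PRD):

```python
def _calculate_advanced_sentiment(self, posts: list[PostData]) -> SentimentScore:
    """Engagement-weighted average of per-post sentiment (illustrative)."""
    scored = [p for p in posts if p.sentiment is not None]
    if not scored:
        return SentimentScore(positive=0.0, negative=0.0, neutral=1.0, metadata={"post_count": 0})
    weights = [max(p.engagement_score, 1) for p in scored]  # never zero-weight a post
    total = sum(weights)
    return SentimentScore(
        positive=sum(w * p.sentiment.positive for w, p in zip(weights, scored)) / total,
        negative=sum(w * p.sentiment.negative for w, p in zip(weights, scored)) / total,
        neutral=sum(w * p.sentiment.neutral for w, p in zip(weights, scored)) / total,
        metadata={"post_count": len(scored), "confidence": min(1.0, len(scored) / 25)},
    )
```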
### 6. Pydantic Validation
#### Context Structure
```python
class SocialContext(BaseModel):
symbol: str | None
period: dict[str, str] # {"start": "2024-01-01", "end": "2024-01-31"}
posts: list[PostData]
engagement_metrics: dict[str, float]
sentiment_summary: SentimentScore
post_count: int
platforms: list[str] # ["reddit"]
metadata: dict[str, Any]
```
#### PostData Format
```python
class PostData(BaseModel):
title: str
content: str
author: str
source: str # subreddit name
date: str
url: str
score: int
comments: int
engagement_score: int
subreddit: str | None
sentiment: SentimentScore | None
metadata: dict[str, Any]
```
## Implementation Tasks
### Phase 1: Create RedditClient
1. **RedditClient Implementation**
- Create `tradingagents/clients/reddit_client.py`
- Follow FinnhubClient standard: no BaseClient inheritance, date objects, proper error handling
- Use PRAW (Python Reddit API Wrapper) library for Reddit API access
- Methods: `search_posts()`, `get_top_posts()`, `get_company_posts()`
- Implement date filtering for posts within specified ranges
- Handle Reddit API rate limits and authentication
2. **Comprehensive Testing**
- Create `tradingagents/clients/test_reddit_client.py`
- Use pytest-vcr for Reddit API interaction recording
- Test all client methods with multiple queries and subreddits
- Test error handling and API rate limit scenarios
- Mock Reddit API responses for consistent testing
### Phase 2: Bridge SocialRepository Interface
3. **Repository Interface Standardization**
- Add standard service interface methods to `SocialRepository`
- Bridge existing `get_social_data()` with `get_data()`
- Bridge existing `store_social_posts()` with `store_data()`
- Add missing `has_data_for_period()` and `clear_data()` methods
- File: `tradingagents/repositories/social_repository.py`
- Maintain existing dataclass functionality while adding service compatibility
4. **Repository Method Implementation**
```python
# Add these methods to SocialRepository
def has_data_for_period(self, query: str, start_date: str, end_date: str, symbol: str | None = None) -> bool
def get_data(self, query: str, start_date: str, end_date: str, symbol: str | None = None) -> dict[str, Any]
def store_data(self, query: str, cache_data: dict, symbol: str | None = None, overwrite: bool = False) -> bool
def clear_data(self, query: str, start_date: str, end_date: str, symbol: str | None = None) -> bool
```
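For example, `has_data_for_period()` might bridge the string-date service interface onto the repository's existing date-based lookup (a sketch; the exact `get_social_data()` signature is assumed):

```python
from datetime import date

def has_data_for_period(self, query: str, start_date: str, end_date: str,
                        symbol: str | None = None) -> bool:
    """Bridge the service's string-date interface to the repository's date objects."""
    start = date.fromisoformat(start_date)
    end = date.fromisoformat(end_date)
    data = self.get_social_data(symbol=symbol or query, start_date=start, end_date=end)
    return bool(data.get("posts"))
```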
### Phase 3: Update SocialMediaService
5. **Client Integration Fix**
- Replace `BaseClient` dependency with `RedditClient`
- File: `tradingagents/services/social_media_service.py:27`
- Update constructor: `reddit_client: RedditClient`
6. **Date Conversion Fix**
- Add `date.fromisoformat()` conversion in service methods
- Update all client calls to use date objects instead of strings
- File: `tradingagents/services/social_media_service.py:182-190, 418-429`
7. **Repository Interface Integration**
- Update repository method calls to use new standard interface
- Ensure proper error handling for repository operations
- File: `tradingagents/services/social_media_service.py:302-311, 325-337`
### Phase 4: Type Safety & Validation
8. **Comprehensive Type Checking**
- Run `mise run typecheck` - must pass with 0 errors
- Validate all date object conversions
- Ensure SocialContext compliance
9. **Enhanced Testing**
- Update existing service tests for new RedditClient interface
- Add gap detection test scenarios
- Test sentiment analysis accuracy with known datasets
- Test multi-subreddit aggregation and deduplication
## Success Criteria
### Functional Requirements
- ✅ Service successfully calls `RedditClient` with `date` objects
- ✅ Local-first strategy works: checks cache → identifies gaps → fetches missing → stores updates
- ✅ Returns properly validated `SocialContext` to agents
- ✅ Sentiment analysis provides accurate scores with confidence metrics
- ✅ Multi-subreddit support with post deduplication
- ✅ Force refresh bypasses cache and refreshes data
### Technical Requirements
- ✅ Zero type checking errors: `mise run typecheck`
- ✅ Zero linting errors: `mise run lint`
- ✅ All existing tests pass with updated architecture
- ✅ No runtime errors with date conversions
### Quality Requirements
- ✅ Strongly-typed interfaces between all components
- ✅ PRAW library integration for reliable Reddit API access
- ✅ Comprehensive error handling and logging
- ✅ Efficient caching with minimal API calls
- ✅ Clear separation of concerns between service, client, and repository
- ✅ Accurate sentiment analysis with engagement weighting
## Data Architecture
### RedditClient Response Format
```python
{
"query": "Tesla",
"period": {"start": "2024-01-01", "end": "2024-01-31"},
"posts": [
{
"title": "Tesla Q4 earnings beat expectations",
"content": "Tesla reported strong Q4 results...",
"author": "teslaInvestor",
"subreddit": "TeslaInvestors",
"created_utc": 1704067200,
"score": 245,
"num_comments": 67,
"upvote_ratio": 0.92,
"url": "https://reddit.com/r/TeslaInvestors/comments/xyz789",
"id": "xyz789"
}
],
"metadata": {
"source": "reddit",
"retrieved_at": "2024-01-31T10:00:00Z",
"data_quality": "HIGH",
"subreddits": ["TeslaInvestors", "stocks"],
"post_count": 25,
"api_calls": 3
}
}
```
### SocialRepository Data Bridge Format
```python
# Repository stores data in existing SocialPost format but provides service interface
{
"query": "Tesla",
"symbol": "TSLA",
"posts": [
{
"title": "Tesla Q4 earnings beat expectations",
"content": "Tesla reported strong Q4 results...",
"author": "teslaInvestor",
"source": "TeslaInvestors",
"date": "2024-01-15",
"url": "https://reddit.com/r/TeslaInvestors/comments/xyz789",
"score": 245,
"comments": 67,
"engagement_score": 312,
"subreddit": "TeslaInvestors",
"sentiment": {
"score": 0.7,
"confidence": 0.8,
"label": "positive"
},
"metadata": {
"platform_id": "xyz789",
"upvote_ratio": 0.92
}
}
],
"metadata": {
"cached_at": "2024-01-31T10:00:00Z",
"post_count": 25,
"sources": ["reddit"]
}
}
```
## Dependencies
### Missing Components (Need Creation)
- ⏳ `RedditClient` needs full implementation from scratch
- ⏳ Service interface bridge methods for `SocialRepository`
- ⏳ Comprehensive pytest-vcr test suites for Reddit API
### Existing Components (Ready)
- ✅ `SocialRepository` with JSON storage and deduplication
- ✅ `SocialContext` and `PostData` Pydantic models
- ✅ Sentiment analysis and engagement metrics logic
### Required
- PRAW (Python Reddit API Wrapper) library for Reddit integration
- Valid Reddit API credentials (client_id, client_secret, user_agent)
- Working internet connection for live data fetching
- Writable data directory for repository storage
## Timeline
### Immediate (Phase 1)
- Create RedditClient following FinnhubClient standard with PRAW integration
- Implement comprehensive testing with pytest-vcr for Reddit API
- Validate client functionality with multiple subreddits and queries
### Phase 2-3
- Add standard service interface methods to SocialRepository
- Update SocialMediaService to use RedditClient with date objects
- Bridge repository interfaces while maintaining existing functionality
### Phase 4
- Comprehensive type checking and validation
- Integration testing with sentiment analysis workflows
- Performance optimization and caching efficiency
## Acceptance Criteria
### Must Have
1. **Type Safety**: Service passes `mise run typecheck` with zero errors
2. **Client Integration**: All `RedditClient` calls use `date` objects correctly
3. **Local-First**: Service checks repository before Reddit API calls
4. **Context Validation**: Returns valid `SocialContext` with Pydantic validation
5. **Sentiment Analysis**: Provides accurate sentiment scores with confidence metrics
6. **Multi-Platform**: Aggregates Reddit social data with an extensible design for additional platforms
### Should Have
1. **Gap Detection**: Intelligent identification of missing data periods
2. **Cache Efficiency**: Minimal redundant API calls to Reddit
3. **Force Refresh**: Complete cache bypass when requested
4. **Data Quality**: Metadata indicating data source and quality metrics
5. **Deduplication**: Automatic removal of duplicate posts by platform_id
### Nice to Have
1. **Performance Metrics**: Timing and cache hit rate logging
2. **Data Staleness**: Automatic refresh of old cached social data
3. **Enhanced Sentiment**: Integration with advanced NLP libraries (TextBlob, VADER)
4. **Real-time Social**: Support for live social media feeds and alerts
5. **Platform Expansion**: Easy addition of Twitter, Discord, other social platforms
---
This PRD focuses on completing the `SocialMediaService` as a strongly-typed, local-first data service that integrates Reddit social media data through a new `RedditClient` following the established FinnhubClient standard patterns, while providing comprehensive sentiment analysis and engagement metrics to trading agents.

View File

@ -0,0 +1,502 @@
{
"russell2000_by_sector": {
"Information Technology": {
"sector_weight": "16.15%",
"companies": [
{ "ticker": "TWOU", "name": "2U Inc." },
{ "ticker": "EGHT", "name": "8x8 Inc." },
{ "ticker": "ATEN", "name": "A10 Networks Inc" },
{ "ticker": "ACIW", "name": "ACI Worldwide Inc" },
{ "ticker": "ADTN", "name": "ADTRAN Inc." },
{ "ticker": "AGYS", "name": "Agilysys Inc." },
{ "ticker": "AMBA", "name": "Ambarella Inc" },
{ "ticker": "AMSWA", "name": "American Software Inc." },
{ "ticker": "AMKR", "name": "Amkor Technology Inc." },
{ "ticker": "AAOI", "name": "Applied Optoelectronics Inc" },
{ "ticker": "APPS", "name": "Digital Turbine Inc" },
{ "ticker": "AZPN", "name": "Aspen Technology, Inc." },
{ "ticker": "ATNI", "name": "Atlantic Tele-Network Inc." },
{ "ticker": "AVID", "name": "Avid Technology, Inc." },
{ "ticker": "ACLS", "name": "Axcelius Technologies Inc." },
{ "ticker": "BELFB", "name": "Bel Fuse Inc." },
{ "ticker": "BHE", "name": "Benchmark Electronics Inc." },
{ "ticker": "BNFT", "name": "Benefitfocus Inc" },
{ "ticker": "BLKB", "name": "Blackbaud Inc." },
{ "ticker": "EPAY", "name": "Bottomline Technologies Inc." },
{ "ticker": "BOX", "name": "Box Inc" },
{ "ticker": "BCOV", "name": "Brightcove Inc" },
{ "ticker": "CACI", "name": "CACI International Inc." },
{ "ticker": "CAMP", "name": "CalAmp Corp." },
{ "ticker": "CALX", "name": "Calix Inc." },
{ "ticker": "CARB", "name": "Carbonite Inc" },
{ "ticker": "CRCM", "name": "Care.com Inc" },
{ "ticker": "CSLT", "name": "Castlight Health Inc" },
{ "ticker": "CTS", "name": "CTS Corporation" },
{ "ticker": "CEVA", "name": "CEVA Inc." },
{ "ticker": "ECOM", "name": "ChannelAdvisor Corp" },
{ "ticker": "CIEN", "name": "Ciena Corporation" },
{ "ticker": "CRUS", "name": "Cirrus Logic Inc." },
{ "ticker": "CLFD", "name": "Clearfield Inc" },
{ "ticker": "COHR", "name": "Coherent, Inc." },
{ "ticker": "COHU", "name": "Cohu Inc." },
{ "ticker": "CVLT", "name": "CommVault Systems, Inc." },
{ "ticker": "CMTL", "name": "Comtech Telecommunications Corp." },
{
"ticker": "CNSL",
"name": "Consolidated Communications Holdings, Inc."
},
{ "ticker": "CSGS", "name": "CSG Systems International Inc." },
{ "ticker": "DAKT", "name": "Daktronics Inc." },
{ "ticker": "DBD", "name": "Diebold Nixdorf, Incorporated" },
{ "ticker": "DMRC", "name": "Digimarc Corporation" },
{ "ticker": "DIOD", "name": "Diodes Incorporated" },
{ "ticker": "DSPG", "name": "DSP Group Inc." }
]
},
"Financials": {
"sector_weight": "17.76%",
"companies": [
{ "ticker": "SRCE", "name": "1st Source Corporation" },
{ "ticker": "ALEX", "name": "Alexander & Baldwin, Inc." },
{
"ticker": "AEL",
"name": "American Equity Investment Life Holding Co."
},
{ "ticker": "AMNB", "name": "American National Bankshares Inc." },
{ "ticker": "AMSF", "name": "Amerisafe, Inc." },
{ "ticker": "ABCB", "name": "Ameris Bancorp" },
{ "ticker": "ATLO", "name": "Ames National Corporation" },
{ "ticker": "AMBC", "name": "Ambac Financial Group, Inc." },
{ "ticker": "AAMC", "name": "Altisource Asset Management Corp" },
{ "ticker": "ARGO", "name": "Argo Group International Holdings Ltd." },
{ "ticker": "BANF", "name": "BancFirst Corporation" },
{ "ticker": "BANC", "name": "Banc of California Inc" },
{ "ticker": "BANR", "name": "Banner Corporation" },
{ "ticker": "BHB", "name": "Bar Harbor Bankshares" },
{ "ticker": "BDGE", "name": "Bridge Bancorp Inc." },
{ "ticker": "BRKL", "name": "Brookline Bancorp Inc." },
{ "ticker": "BMRC", "name": "Bryn Mawr Bank Corp." },
{ "ticker": "CLMS", "name": "Calamos Asset Management Inc." },
{ "ticker": "CAC", "name": "Camden National Corporation" },
{ "ticker": "CCBG", "name": "Capital City Bank Group Inc." },
{ "ticker": "CATY", "name": "Cathay General Bancorp" },
{ "ticker": "CSFL", "name": "CenterState Banks, Inc." },
{ "ticker": "CPF", "name": "Central Pacific Financial Corp." },
{ "ticker": "CNS", "name": "Cohen & Steers Inc." },
{ "ticker": "COR", "name": "CoreSite Realty Corp" },
{ "ticker": "CORR", "name": "Corenergy Infrastructure Trust Inc" },
{ "ticker": "CUBE", "name": "CubeSmart" },
{ "ticker": "CONE", "name": "CyrusOne Inc" },
{ "ticker": "DRH", "name": "DiamondRock Hospitality Co." },
{ "ticker": "DGICA", "name": "Donegal Group Inc." },
{ "ticker": "EXR", "name": "EastGroup Properties" },
{ "ticker": "EMCI", "name": "EMC Insurance Group Inc." },
{ "ticker": "EPR", "name": "EPR Properties" },
{ "ticker": "ESNT", "name": "Essent Group Ltd" },
{ "ticker": "FFG", "name": "FBL Financial Group Inc." },
{ "ticker": "FNHC", "name": "Federated National Holding Co" },
{ "ticker": "GTY", "name": "Getty Realty Corp." },
{ "ticker": "GBLI", "name": "Global Indemnity plc" },
{ "ticker": "HALL", "name": "Hallmark Financial Services Inc." },
{ "ticker": "HCI", "name": "HCI Group Inc" },
{ "ticker": "HRTG", "name": "Heritage Insurance Holdings Inc" },
{ "ticker": "HTH", "name": "Hilltop Holdings Inc." },
{ "ticker": "HMN", "name": "Horace Mann Educators Corp." }
]
},
"Health Care": {
"sector_weight": "15.74%",
"companies": [
{ "ticker": "ACHC", "name": "Acadia Healthcare Company, Inc." },
{ "ticker": "ADMA", "name": "ADMA Biologics" },
{ "ticker": "AERI", "name": "Aerie Pharmaceuticals Inc" },
{ "ticker": "AGIO", "name": "Agios Pharmaceuticals" },
{ "ticker": "AKRO", "name": "Akero Therapeutics" },
{ "ticker": "ALKS", "name": "Alkermes" },
{ "ticker": "AMAG", "name": "AMAG Pharmaceuticals, Inc." },
{ "ticker": "AMPH", "name": "Amphastar Pharmaceuticals Inc" },
{ "ticker": "ANGI", "name": "AngioDynamics Inc" },
{ "ticker": "ATRC", "name": "AtriCure, Inc." },
{ "ticker": "AVID", "name": "Avidity Biosciences" },
{ "ticker": "AXSM", "name": "Axsome Therapeutics" },
{ "ticker": "BBIO", "name": "BridgeBio Pharma" },
{ "ticker": "BPMC", "name": "Blueprint Medicines Corp" },
{ "ticker": "CORT", "name": "Corcept Therapeutics" },
{ "ticker": "CRNX", "name": "Crinetics Pharmaceuticals" },
{ "ticker": "CYTK", "name": "Cytokinetics" },
{ "ticker": "ENSG", "name": "The Ensign Group, Inc." },
{ "ticker": "GKOS", "name": "Glaukos" },
{ "ticker": "GH", "name": "Guardant Health" },
{ "ticker": "HALO", "name": "Halozyme Therapeutics" },
{ "ticker": "HQY", "name": "HealthEquity, Inc." },
{ "ticker": "HIMS", "name": "Hims & Hers Health, Inc." },
{ "ticker": "NARI", "name": "Inari Medical" },
{ "ticker": "INSM", "name": "Insmed Incorporated" },
{ "ticker": "ITCI", "name": "Intra-Cellular Therapies, Inc." },
{ "ticker": "IRTC", "name": "iRhythm Technologies" },
{ "ticker": "LNTH", "name": "Lantheus" },
{ "ticker": "MDGL", "name": "Madrigal Pharmaceuticals" },
{ "ticker": "MMSI", "name": "Merit Medical Systems" },
{ "ticker": "NUVL", "name": "Nuvalent" },
{ "ticker": "OPCH", "name": "Option Care Health" },
{ "ticker": "RDNT", "name": "RadNet" },
{ "ticker": "RVMD", "name": "Revolution Medicines" },
{ "ticker": "RYTM", "name": "Rhythm Pharmaceuticals" },
{ "ticker": "SMMT", "name": "Summit Therapeutics Inc." },
{ "ticker": "TGTX", "name": "TG Therapeutics" },
{ "ticker": "TMDX", "name": "TransMedics Group" },
{ "ticker": "PCVX", "name": "Vaxcyte" },
{ "ticker": "RNA", "name": "Avidity Biosciences" }
]
},
"Industrials": {
"sector_weight": "15.31%",
"companies": [
{ "ticker": "AAON", "name": "AAON Inc." },
{ "ticker": "AAR", "name": "AAR Corporation" },
{ "ticker": "AVAV", "name": "AeroVironment Inc." },
{ "ticker": "ATI", "name": "Allegheny Technologies Inc" },
{ "ticker": "AMRC", "name": "Ameresco Inc." },
{ "ticker": "APG", "name": "APi Group Corporation" },
{ "ticker": "APOG", "name": "Apogee Enterprises" },
{ "ticker": "AIT", "name": "Applied Industrial Technologies" },
{ "ticker": "ASTE", "name": "Astec Industries" },
{ "ticker": "B", "name": "Barnes Group Inc." },
{ "ticker": "CRS", "name": "Carpenter Technology Corp" },
{ "ticker": "GTLS", "name": "Chart Industries" },
{ "ticker": "EME", "name": "EMCOR Group" },
{ "ticker": "FSS", "name": "Federal Signal Corp." },
{ "ticker": "GVA", "name": "Granite Construction" },
{ "ticker": "HEIA", "name": "HEICO Corp." },
{ "ticker": "LNN", "name": "Lindsay Corporation" },
{ "ticker": "MTZ", "name": "MasTec Inc." },
{ "ticker": "MOG.A", "name": "Moog Inc." },
{ "ticker": "MSA", "name": "MSA Safety Inc." },
{ "ticker": "MLI", "name": "Mueller Industries" },
{ "ticker": "TNC", "name": "Tennant Company" },
{ "ticker": "WTS", "name": "Watts Water Technologies" }
]
},
"Consumer Discretionary": {
"sector_weight": "10.01%",
"companies": [
{ "ticker": "ANF", "name": "Abercrombie & Fitch" },
{ "ticker": "ABG", "name": "Asbury Automotive Group, Inc." },
{ "ticker": "BOOT", "name": "Boot Barn" },
{ "ticker": "EAT", "name": "Brinker International" },
{ "ticker": "CVNA", "name": "Carvana Co." },
{ "ticker": "CNK", "name": "Cinemark" },
{ "ticker": "DDS", "name": "Dillard's" },
{ "ticker": "ELF", "name": "e.l.f. Beauty" },
{ "ticker": "GPI", "name": "Group 1 Automotive" },
{ "ticker": "LTH", "name": "Life Time Group" },
{ "ticker": "RRR", "name": "Red Rock Resorts" },
{ "ticker": "RUSH", "name": "Rush Enterprises" },
{ "ticker": "SHAK", "name": "Shake Shack" },
{ "ticker": "URBN", "name": "Urban Outfitters" }
]
},
"Real Estate": {
"sector_weight": "7.06%",
"companies": [
{ "ticker": "AAT", "name": "American Assets Trust" },
{ "ticker": "ADC", "name": "Agree Realty Corporation" },
{ "ticker": "AKR", "name": "Acadia Realty Trust" },
{ "ticker": "ALX", "name": "Alexander's Inc" },
{ "ticker": "APTS", "name": "Preferred Apartment Communities Inc" },
{ "ticker": "AHH", "name": "Armada Hoffler Properties Inc" },
{ "ticker": "AHT", "name": "Ashford Hospitality Trust Inc" },
{
"ticker": "ARI",
"name": "Apollo Commercial Real Estate Finance, Inc."
},
{ "ticker": "ACRE", "name": "Ares Commercial Real Estate Corp" },
{ "ticker": "AAIC", "name": "Arlington Asset Investment Corp" },
{ "ticker": "ARR", "name": "ARMOUR Residential REIT, Inc." },
{ "ticker": "BRG", "name": "Bluerock Residential Growth REIT Inc" },
{ "ticker": "CTRE", "name": "Caretrust REIT Inc" },
{ "ticker": "CTT", "name": "Catchmark Timber Trust Inc" },
{ "ticker": "CLDT", "name": "Chatham Lodging Trust" },
{ "ticker": "CUZ", "name": "Cousins Properties" },
{ "ticker": "CTO", "name": "CTO Realty Growth, Inc." },
{ "ticker": "DEA", "name": "Easterly Government Properties Inc" },
{ "ticker": "ELME", "name": "Elme Communities" },
{ "ticker": "FR", "name": "First Industrial Realty Trust Inc" },
{ "ticker": "FOR", "name": "Forestar Group Inc" },
{ "ticker": "FBRT", "name": "Franklin BSP Realty Trust" },
{ "ticker": "FSP", "name": "Franklin Street Properties Corp" },
{ "ticker": "FRPH", "name": "FRP Holdings Inc" },
{ "ticker": "GRBK", "name": "Green Brick Partners Inc" },
{ "ticker": "HASI", "name": "HASI" },
{ "ticker": "HR", "name": "Healthcare Realty Trust Inc" },
{ "ticker": "HT", "name": "Hersha Hospitality Trust" },
{ "ticker": "HIW", "name": "Highwoods Properties Inc" },
{ "ticker": "HPP", "name": "Hudson Pacific Properties Inc" },
{ "ticker": "IRT", "name": "Independence Realty Trust Inc" },
{ "ticker": "IRC", "name": "Inland Real Estate Corporation" },
{ "ticker": "INN", "name": "Summit Hotel Properties Inc" },
{ "ticker": "STAR", "name": "iStar" },
{ "ticker": "IRET", "name": "Investors Real Estate Trust" },
{ "ticker": "KW", "name": "Kennedy-Wilson Holdings Inc" },
{ "ticker": "LADR", "name": "Ladder Capital Corp" },
{ "ticker": "LGIH", "name": "LGI Homes Inc" },
{ "ticker": "LSI", "name": "Life Storage, Inc" },
{ "ticker": "LTC", "name": "LTC Properties Inc" },
{ "ticker": "LXP", "name": "LXP Industrial Trust" },
{ "ticker": "MMI", "name": "Marcus & Millichap Inc" },
{ "ticker": "MPW", "name": "Medical Properties Trust Inc" },
{ "ticker": "MITT", "name": "AG Mortgage Investment Trust Inc" },
{
"ticker": "MNR",
"name": "Monmouth Real Estate Investment Corporation"
},
{ "ticker": "NHI", "name": "National Health Investors Inc" },
{ "ticker": "SNR", "name": "New Senior Investment Group Inc" },
{ "ticker": "NYMT", "name": "New York Mortgage Trust Inc" },
{ "ticker": "NYRT", "name": "New York REIT Inc" },
{ "ticker": "NXRT", "name": "NexPoint Residential Trust, Inc." },
{ "ticker": "OLP", "name": "One Liberty Properties" },
{ "ticker": "ORC", "name": "Orchid Island Capital Inc" },
{ "ticker": "PEB", "name": "Pebblebrook Hotel Trust" },
{ "ticker": "PCH", "name": "PotlatchDeltic" },
{ "ticker": "PMT", "name": "PennyMac Mortgage Investment Trust" },
{
"ticker": "PEI",
"name": "Pennsylvania Real Estate Investment Trust"
},
{ "ticker": "QTS", "name": "QTS Realty Trust Inc" },
{ "ticker": "RAS", "name": "RAIT Financial Trust" },
{ "ticker": "RWT", "name": "Redwood Trust Inc" },
{ "ticker": "RITM", "name": "Rithm Capital Corp" },
{ "ticker": "RPRT", "name": "Rithm Property Trust" },
{ "ticker": "RPT", "name": "RPT Realty" },
{ "ticker": "SBRA", "name": "Sabra Health Care REIT" },
{ "ticker": "BFS", "name": "Saul Centers Inc" },
{ "ticker": "SUI", "name": "Sun Communities" },
{ "ticker": "SHO", "name": "Sunstone Hotel Investors" },
{ "ticker": "TRC", "name": "Tejon Ranch Co" },
{ "ticker": "UMH", "name": "UMH Properties Inc" },
{ "ticker": "UHT", "name": "Universal Health Realty Income Trust" },
{ "ticker": "UBA", "name": "Urstadt Biddle Properties Inc" },
{ "ticker": "VRE", "name": "Veris Residential" },
{ "ticker": "XHR", "name": "Xenia Hotels & Resorts Inc" }
]
},
"Energy": {
"sector_weight": "4.74%",
"companies": [
{ "ticker": "AXAS", "name": "Abraxas Petroleum Corp." },
{ "ticker": "AE", "name": "Adams Resources & Energy Inc." },
{ "ticker": "BAS", "name": "Basic Energy Services" },
{ "ticker": "BRS", "name": "Bristow Group Inc." },
{ "ticker": "CJ", "name": "C&J Energy Services" },
{ "ticker": "CPE", "name": "Callon Petroleum Company" },
{ "ticker": "CRR", "name": "CARBO Ceramics Inc." },
{ "ticker": "CHRD", "name": "Chord Energy Corporation" },
{ "ticker": "CLNE", "name": "Clean Energy Fuels Corp." },
{ "ticker": "DK", "name": "Delek US Holdings" },
{ "ticker": "DEC", "name": "Diversified Energy Company" },
{ "ticker": "ESOA", "name": "Energy Services of America" },
{ "ticker": "FET", "name": "Forum Energy Technologies" },
{ "ticker": "HLX", "name": "Helix Energy Solutions Group" },
{ "ticker": "KEG", "name": "Key Energy Services" },
{ "ticker": "MRC", "name": "MRC Global Inc." },
{ "ticker": "NGS", "name": "Natural Gas Services Group" },
{ "ticker": "NR", "name": "Newpark Resources" },
{ "ticker": "OIS", "name": "Oil States International" },
{ "ticker": "PDCE", "name": "PDC Energy" },
{ "ticker": "PHX", "name": "PHX Minerals Inc." },
{ "ticker": "WTI", "name": "W&T Offshore" }
]
},
"Materials": {
"sector_weight": "4.10%",
"companies": [
{ "ticker": "AAON", "name": "AAON Inc" },
{ "ticker": "ACET", "name": "Aceto Corp" },
{ "ticker": "ALTO", "name": "Alto Ingredients, Inc." },
{ "ticker": "AVD", "name": "American Vanguard Corp" },
{ "ticker": "APOG", "name": "Apogee Enterprises, Inc" },
{ "ticker": "AREX", "name": "Approach Resources, Inc" },
{ "ticker": "AVNT", "name": "Avient Corporation" },
{ "ticker": "BCPC", "name": "Balchem Corporation" },
{ "ticker": "BAS", "name": "Basic Energy Services, Inc" },
{ "ticker": "BERY", "name": "Berry Global Group Inc" },
{ "ticker": "BCC", "name": "Boise Cascade Co" },
{ "ticker": "BLDR", "name": "Builders FirstSource Inc" },
{ "ticker": "CBT", "name": "Cabot Corporation" },
{ "ticker": "CRR", "name": "CARBO Ceramics Inc" },
{ "ticker": "CENT", "name": "Central Garden & Pet Company" },
{ "ticker": "CENX", "name": "Century Aluminum Company" },
{ "ticker": "CCF", "name": "Chase Corporation" },
{ "ticker": "CLW", "name": "Clearwater Paper Corp" },
{ "ticker": "CLF", "name": "Cliffs Natural Resources Inc" },
{ "ticker": "CDE", "name": "Coeur Mining Inc" },
{ "ticker": "CMC", "name": "Commercial Metals Co" },
{ "ticker": "CBPX", "name": "Continental Building Products Inc" },
{ "ticker": "DAR", "name": "Darling Ingredients Inc" },
{ "ticker": "ESTE", "name": "Earthstone Energy Inc" },
{ "ticker": "WIRE", "name": "Encore Wire Corp" },
{ "ticker": "FOE", "name": "Ferro Corporation" },
{ "ticker": "FTK", "name": "Flotek Industries Inc" },
{ "ticker": "FET", "name": "Forum Energy Technologies Inc" },
{ "ticker": "FF", "name": "FutureFuel Corp" },
{ "ticker": "GEOS", "name": "Geospace Technologies Corporation" },
{ "ticker": "ROCK", "name": "Gibraltar Industries, Inc." },
{ "ticker": "GPRE", "name": "Green Plains Inc" },
{ "ticker": "GEF", "name": "Greif Inc" },
{ "ticker": "GFF", "name": "Griffon Corporation" },
{ "ticker": "HWKN", "name": "Hawkins Inc" },
{ "ticker": "HAYN", "name": "Haynes International Inc" },
{ "ticker": "HL", "name": "Hecla Mining Company" },
{ "ticker": "HLX", "name": "Helix Energy Solutions Group, Inc" },
{ "ticker": "IPHS", "name": "Innophos Holdings Inc" },
{ "ticker": "IOSP", "name": "Innospec Inc" },
{ "ticker": "IBP", "name": "Installed Building Products Inc" },
{ "ticker": "IIIN", "name": "Insteel Industries Inc" },
{ "ticker": "IO", "name": "Ion Geophysical Corp" },
{ "ticker": "KALU", "name": "Kaiser Aluminum Corporation" },
{ "ticker": "KEG", "name": "Key Energy Services Inc" },
{ "ticker": "KTOS", "name": "Kratos Defense & Security Solutions" },
{ "ticker": "LFCR", "name": "Lifecore Biomedical, Inc." },
{ "ticker": "LPX", "name": "Louisiana-Pacific Corp" },
{ "ticker": "LXU", "name": "LSB Industries Inc" },
{ "ticker": "DOOR", "name": "Masonite International Corp" },
{ "ticker": "MTRN", "name": "Materion Corporation" },
{ "ticker": "MTX", "name": "Minerals Technologies Inc" },
{ "ticker": "MRC", "name": "MRC Global Inc" },
{ "ticker": "NGS", "name": "Natural Gas Services Group Inc" },
{ "ticker": "NP", "name": "Neenah Paper, Inc." },
{ "ticker": "NR", "name": "Newpark Resources Inc" },
{ "ticker": "NL", "name": "NL Industries Inc" },
{ "ticker": "NTK", "name": "Nortek Inc" },
{ "ticker": "NWPX", "name": "Northwest Pipe Co" },
{ "ticker": "ODC", "name": "Oil-Dri Corporation of America" },
{ "ticker": "OIS", "name": "Oil States International Inc" },
{ "ticker": "OLN", "name": "Olin Corp" },
{ "ticker": "OMN", "name": "OMNOVA Solutions Inc" },
{ "ticker": "PHX", "name": "PHX Minerals Inc" },
{ "ticker": "KWR", "name": "Quaker Chemical Corporation" },
{ "ticker": "NX", "name": "Quanex Building Products Corporation" },
{ "ticker": "REX", "name": "REX American Resources Corporation" },
{ "ticker": "SCHN", "name": "Schnitzer Steel Industries Inc" },
{ "ticker": "SWM", "name": "Schweitzer-Mauduit International Inc" },
{ "ticker": "SXT", "name": "Sensient Technologies" },
{ "ticker": "SCL", "name": "Stepan Company" },
{ "ticker": "SXC", "name": "SunCoke Energy" },
{ "ticker": "TTI", "name": "TETRA Technologies Inc" },
{ "ticker": "TREX", "name": "Trex Co. Inc" },
{ "ticker": "UFPI", "name": "UFP Industries Inc" },
{ "ticker": "USLM", "name": "United States Lime & Minerals Inc" },
{ "ticker": "UEC", "name": "Uranium Energy Corp" },
{ "ticker": "VHI", "name": "Valhi, Inc" },
{ "ticker": "WDFC", "name": "WD-40 Company" },
{ "ticker": "WOR", "name": "Worthington Industries, Inc." }
]
},
"Consumer Staples": {
"sector_weight": "3.51%",
"companies": [
{ "ticker": "BRBR", "name": "BellRing Brands" },
{ "ticker": "CALM", "name": "Cal-Maine Foods" },
{ "ticker": "COKE", "name": "Coca-Cola Consolidated, Inc." },
{ "ticker": "FIZZ", "name": "National Beverage" },
{ "ticker": "HIMS", "name": "Hims & Hers Health, Inc." },
{ "ticker": "IPAR", "name": "Inter Parfums" },
{ "ticker": "LANC", "name": "Lancaster Colony" },
{ "ticker": "PBH", "name": "Prestige Consumer Healthcare" },
{ "ticker": "SFM", "name": "Sprouts Farmers Market, Inc." }
]
},
"Utilities": {
"sector_weight": "3.14%",
"companies": [
{ "ticker": "ALE", "name": "ALLETE Inc" },
{ "ticker": "AWR", "name": "American States Water Company" },
{ "ticker": "ARTNA", "name": "Artesian Resources Corporation" },
{ "ticker": "AT", "name": "Atlantic Power Corporation" },
{ "ticker": "AY", "name": "Atlantica Yield PLC" },
{ "ticker": "AVA", "name": "Avista Corp" },
{ "ticker": "BKH", "name": "Black Hills Corporation" },
{ "ticker": "CWT", "name": "California Water Service Group" },
{ "ticker": "CPK", "name": "Chesapeake Utilities Corporation" },
{ "ticker": "CWCO", "name": "Consolidated Water Co. Ltd" },
{ "ticker": "EE", "name": "El Paso Electric Company" },
{ "ticker": "IDA", "name": "IDACORP" },
{ "ticker": "MGEE", "name": "MGE Energy Inc" },
{ "ticker": "MSEX", "name": "Middlesex Water Company" },
{ "ticker": "NWN", "name": "Northwest Natural Gas Company" },
{ "ticker": "NWE", "name": "NorthWestern Corp" },
{ "ticker": "OGS", "name": "ONE Gas Inc" },
{ "ticker": "ORA", "name": "Ormat Technologies Inc" },
{ "ticker": "OTTR", "name": "Otter Tail Corporation" },
{ "ticker": "PEGI", "name": "Pattern Energy Group Inc" },
{ "ticker": "POR", "name": "Portland General Electric Company" },
{ "ticker": "SJW", "name": "SJW Group" },
{ "ticker": "SJI", "name": "South Jersey Industries" },
{ "ticker": "SWX", "name": "Southwest Gas Corp" },
{ "ticker": "UTL", "name": "Unitil Corporation" },
{ "ticker": "WGL", "name": "WGL Holdings Inc" },
{ "ticker": "YORW", "name": "York Water Co" }
]
},
"Communication Services": {
"sector_weight": "2.48%",
"companies": [
{ "ticker": "ADTN", "name": "ADTRAN Inc." },
{ "ticker": "ASTS", "name": "AST SpaceMobile" },
{ "ticker": "ATNI", "name": "Atlantic Tele-Network" },
{ "ticker": "CAMP", "name": "CalAmp Corp." },
{ "ticker": "CALX", "name": "Calix Inc." },
{ "ticker": "CIEN", "name": "Ciena Corporation" },
{ "ticker": "CCOI", "name": "Cogent Communications Group" },
{ "ticker": "CMTL", "name": "Comtech Telecommunications" },
{ "ticker": "CNSL", "name": "Consolidated Communications Holdings" },
{ "ticker": "DSPG", "name": "DSP Group Inc." },
{ "ticker": "SATS", "name": "EchoStar Corporation" },
{ "ticker": "GCI", "name": "General Communication Inc." },
{ "ticker": "GSAT", "name": "Globalstar Inc." },
{ "ticker": "GOGO", "name": "Gogo Inc." },
{ "ticker": "HLIT", "name": "Harmonic Inc." },
{ "ticker": "IDT", "name": "IDT Corporation" },
{ "ticker": "INFN", "name": "Infinera Corp." },
{ "ticker": "IDCC", "name": "InterDigital Communications" },
{ "ticker": "IRDM", "name": "Iridium Communications" },
{ "ticker": "KVHI", "name": "KVH Industries" },
{ "ticker": "LORL", "name": "Loral Space & Communications" },
{ "ticker": "NTGR", "name": "Netgear Inc." },
{ "ticker": "NSR", "name": "NeuStar Inc." },
        { "ticker": "ORBC", "name": "ORBCOMM Inc." },
{ "ticker": "SHEN", "name": "Shenandoah Telecommunications" },
{ "ticker": "VSAT", "name": "ViaSat Inc." },
{ "ticker": "VG", "name": "Vonage Holdings" },
{ "ticker": "WIN", "name": "Windstream Holdings" }
]
}
},
"summary": {
"total_companies": 1978,
"market_cap_representation": "$3.0 trillion",
"percentage_of_russell_3000": "7%",
"last_updated": "July 30, 2025",
"last_reconstitution": "June 27, 2025",
"minimum_market_cap": "$150.4 million",
"largest_company": {
"name": "Applied Industrial Tech",
"market_cap": "$7.1 billion"
},
"median_market_cap": "$840 million",
"average_market_cap": "$4.3 billion",
"russell_1000_2000_breakpoint": "$4.6 billion",
"rebalancing_frequency": "Annual (moving to semi-annual in 2026)",
"assets_benchmarked": "$8.5 trillion",
"passive_tracking": "$2.0 trillion",
"notes": [
"The Russell 2000 represents the small-cap segment of the U.S. equity universe",
"Sector weights and constituents change with annual reconstitution",
"Technology and Consumer Discretionary sectors rebounded in Q2 2025",
"Healthcare IPOs represented 7 of 11 new additions in 2025",
"Semi-annual reconstitution begins in 2026 (June and November)"
]
}
}

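The sector maps in these index files share one shape: a root object keyed by sector, where each sector holds a `sector_weight` string and a `companies` array of `{ticker, name}` pairs. A minimal sketch of flattening that shape into a ticker lookup, assuming a `russell2000.json` path alongside the `snp500.json` file below and a root key mirroring its `sp500_by_sector` (neither is visible in this excerpt):

```python
import json

# Hypothetical path and root key -- both are assumptions; adjust to the real file.
with open("assets/indexes/russell2000.json") as f:
    index = json.load(f)

sectors = index["russell2000_by_sector"]  # assumed root key, sibling of "summary"
ticker_to_sector = {
    company["ticker"]: sector_name
    for sector_name, sector in sectors.items()
    for company in sector["companies"]
}
print(ticker_to_sector["CALM"])  # -> "Consumer Staples"
```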
595
assets/indexes/snp500.json Normal file
View File

@ -0,0 +1,595 @@
{
"sp500_by_sector": {
"information_technology": {
"sector_weight": "31.7%",
"companies": [
{ "ticker": "AAPL", "name": "Apple Inc." },
{ "ticker": "MSFT", "name": "Microsoft Corporation" },
{ "ticker": "NVDA", "name": "NVIDIA Corporation" },
{ "ticker": "AVGO", "name": "Broadcom Inc." },
{ "ticker": "ORCL", "name": "Oracle Corporation" },
{ "ticker": "CRM", "name": "Salesforce, Inc." },
{ "ticker": "ADBE", "name": "Adobe Inc." },
{ "ticker": "AMD", "name": "Advanced Micro Devices, Inc." },
{ "ticker": "CSCO", "name": "Cisco Systems, Inc." },
{ "ticker": "TXN", "name": "Texas Instruments Incorporated" },
{ "ticker": "AMAT", "name": "Applied Materials, Inc." },
{ "ticker": "QCOM", "name": "QUALCOMM Incorporated" },
{ "ticker": "NOW", "name": "ServiceNow, Inc." },
{ "ticker": "INTU", "name": "Intuit Inc." },
{
"ticker": "IBM",
"name": "International Business Machines Corporation"
},
{ "ticker": "SNPS", "name": "Synopsys, Inc." },
{ "ticker": "CDNS", "name": "Cadence Design Systems, Inc." },
{ "ticker": "PANW", "name": "Palo Alto Networks, Inc." },
{ "ticker": "ANET", "name": "Arista Networks Inc" },
{ "ticker": "LRCX", "name": "Lam Research Corporation" },
{ "ticker": "KLAC", "name": "KLA Corporation" },
{ "ticker": "MCHP", "name": "Microchip Technology Incorporated" },
{ "ticker": "ADI", "name": "Analog Devices, Inc." },
{ "ticker": "MRVL", "name": "Marvell Technology, Inc." },
{ "ticker": "ADSK", "name": "Autodesk, Inc." },
{ "ticker": "MU", "name": "Micron Technology, Inc." },
{ "ticker": "ANSS", "name": "ANSYS, Inc." },
{ "ticker": "FTNT", "name": "Fortinet, Inc." },
{ "ticker": "NTAP", "name": "NetApp, Inc." },
{ "ticker": "PLTR", "name": "Palantir Technologies Inc." },
{ "ticker": "SNOW", "name": "Snowflake Inc." },
{ "ticker": "CRWD", "name": "CrowdStrike Holdings, Inc." },
{ "ticker": "DDOG", "name": "Datadog, Inc." },
{ "ticker": "ZS", "name": "Zscaler, Inc." },
{ "ticker": "MDB", "name": "MongoDB, Inc." },
{ "ticker": "NET", "name": "Cloudflare, Inc." },
{ "ticker": "TER", "name": "Teradyne, Inc." },
{ "ticker": "KEYS", "name": "Keysight Technologies, Inc." },
{ "ticker": "TYL", "name": "Tyler Technologies, Inc." },
{ "ticker": "FICO", "name": "Fair Isaac Corporation" },
{ "ticker": "ZBRA", "name": "Zebra Technologies Corporation" },
{ "ticker": "PTC", "name": "PTC Inc." },
{
"ticker": "CTSH",
"name": "Cognizant Technology Solutions Corporation"
},
{ "ticker": "ACN", "name": "Accenture plc" },
{ "ticker": "EPAM", "name": "EPAM Systems, Inc." },
{ "ticker": "IT", "name": "Gartner, Inc." },
{ "ticker": "CDW", "name": "CDW Corporation" },
{ "ticker": "WDC", "name": "Western Digital Corporation" },
{ "ticker": "STX", "name": "Seagate Technology Holdings plc" },
{ "ticker": "VRSN", "name": "VeriSign, Inc." },
{ "ticker": "GDDY", "name": "GoDaddy Inc." },
{ "ticker": "BR", "name": "Broadridge Financial Solutions, Inc." },
{ "ticker": "HPE", "name": "Hewlett Packard Enterprise Company" },
{ "ticker": "HPQ", "name": "HP Inc." },
{ "ticker": "TDY", "name": "Teledyne Technologies Incorporated" },
{ "ticker": "SMCI", "name": "Super Micro Computer, Inc." },
{ "ticker": "FSLR", "name": "First Solar, Inc." },
{ "ticker": "ON", "name": "ON Semiconductor Corporation" },
{ "ticker": "TRMB", "name": "Trimble Inc." },
{ "ticker": "FFIV", "name": "F5, Inc." },
{ "ticker": "NXPI", "name": "NXP Semiconductors N.V." },
{ "ticker": "TEL", "name": "TE Connectivity plc" },
{ "ticker": "APH", "name": "Amphenol Corporation" },
{ "ticker": "GLW", "name": "Corning Incorporated" },
{ "ticker": "MPWR", "name": "Monolithic Power Systems, Inc." },
{ "ticker": "ROP", "name": "Roper Technologies, Inc." },
{ "ticker": "MSI", "name": "Motorola Solutions, Inc." },
{ "ticker": "WDAY", "name": "Workday, Inc." },
{ "ticker": "INTC", "name": "Intel Corporation" },
{ "ticker": "JBL", "name": "Jabil Inc." },
{ "ticker": "GEN", "name": "Gen Digital Inc." }
]
},
"financials": {
"sector_weight": "14.0%",
"companies": [
{ "ticker": "JPM", "name": "JPMorgan Chase & Co." },
{ "ticker": "BRK.B", "name": "Berkshire Hathaway Inc. (Class B)" },
{ "ticker": "V", "name": "Visa Inc." },
{ "ticker": "MA", "name": "Mastercard Incorporated" },
{ "ticker": "BAC", "name": "Bank of America Corporation" },
{ "ticker": "WFC", "name": "Wells Fargo & Company" },
{ "ticker": "GS", "name": "The Goldman Sachs Group, Inc." },
{ "ticker": "AXP", "name": "American Express Company" },
{ "ticker": "MS", "name": "Morgan Stanley" },
{ "ticker": "SCHW", "name": "The Charles Schwab Corporation" },
{ "ticker": "BX", "name": "Blackstone Inc." },
{ "ticker": "PGR", "name": "The Progressive Corporation" },
{ "ticker": "CB", "name": "Chubb Limited" },
{ "ticker": "SPGI", "name": "S&P Global Inc." },
{ "ticker": "MMC", "name": "Marsh & McLennan Companies, Inc." },
{ "ticker": "BLK", "name": "BlackRock, Inc." },
{ "ticker": "CME", "name": "CME Group Inc." },
{ "ticker": "KKR", "name": "KKR & Co. Inc." },
{ "ticker": "AON", "name": "Aon plc" },
{ "ticker": "COF", "name": "Capital One Financial Corporation" },
{ "ticker": "PNC", "name": "The PNC Financial Services Group, Inc." },
{ "ticker": "USB", "name": "U.S. Bancorp" },
{ "ticker": "TFC", "name": "Truist Financial Corporation" },
{ "ticker": "C", "name": "Citigroup Inc." },
{ "ticker": "ICE", "name": "Intercontinental Exchange, Inc." },
{ "ticker": "MCO", "name": "Moody's Corporation" },
{ "ticker": "MSCI", "name": "MSCI Inc." },
{ "ticker": "TRV", "name": "The Travelers Companies, Inc." },
{ "ticker": "AFL", "name": "Aflac Incorporated" },
{ "ticker": "MET", "name": "MetLife, Inc." },
{ "ticker": "PRU", "name": "Prudential Financial, Inc." },
{ "ticker": "ALL", "name": "The Allstate Corporation" },
{ "ticker": "AIG", "name": "American International Group, Inc." },
{
"ticker": "WTW",
"name": "Willis Towers Watson Public Limited Company"
},
{ "ticker": "AJG", "name": "Arthur J. Gallagher & Co." },
{ "ticker": "AMP", "name": "Ameriprise Financial, Inc." },
{ "ticker": "TROW", "name": "T. Rowe Price Group, Inc." },
{ "ticker": "STT", "name": "State Street Corporation" },
{ "ticker": "NTRS", "name": "Northern Trust Corporation" },
{ "ticker": "RJF", "name": "Raymond James Financial, Inc." },
{ "ticker": "BEN", "name": "Franklin Resources, Inc." },
{ "ticker": "IVZ", "name": "Invesco Ltd." },
{ "ticker": "APO", "name": "Apollo Global Management, Inc." },
{ "ticker": "ARES", "name": "Ares Management Corporation" },
{ "ticker": "COIN", "name": "Coinbase Global, Inc." },
{ "ticker": "RF", "name": "Regions Financial Corporation" },
{ "ticker": "FITB", "name": "Fifth Third Bancorp" },
{ "ticker": "HBAN", "name": "Huntington Bancshares Incorporated" },
{ "ticker": "KEY", "name": "KeyCorp" },
{ "ticker": "CFG", "name": "Citizens Financial Group, Inc." },
{ "ticker": "MTB", "name": "M&T Bank Corporation" },
{ "ticker": "ZION", "name": "Zions Bancorporation" },
{ "ticker": "PYPL", "name": "PayPal Holdings, Inc." },
{
"ticker": "FIS",
"name": "Fidelity National Information Services, Inc."
},
{ "ticker": "FI", "name": "Fiserv, Inc." },
{ "ticker": "GPN", "name": "Global Payments Inc." },
{ "ticker": "SYF", "name": "Synchrony Financial" },
{ "ticker": "CPAY", "name": "Corpay, Inc." },
{ "ticker": "NDAQ", "name": "Nasdaq, Inc." },
{ "ticker": "CBOE", "name": "Cboe Global Markets, Inc." },
{ "ticker": "CINF", "name": "Cincinnati Financial Corporation" },
{ "ticker": "WRB", "name": "W. R. Berkley Corporation" },
{ "ticker": "L", "name": "Loews Corporation" },
{ "ticker": "PFG", "name": "Principal Financial Group, Inc." },
{
"ticker": "HIG",
"name": "The Hartford Financial Services Group, Inc."
},
{ "ticker": "ACGL", "name": "Arch Capital Group Ltd." },
{ "ticker": "EG", "name": "Everest Group, Ltd." },
{ "ticker": "ERIE", "name": "Erie Indemnity Company" },
{ "ticker": "BRO", "name": "Brown & Brown, Inc." },
{ "ticker": "FDS", "name": "FactSet Research Systems Inc." },
{ "ticker": "BK", "name": "The Bank of New York Mellon Corporation" }
]
},
"health_care": {
"sector_weight": "10.8%",
"companies": [
{ "ticker": "UNH", "name": "UnitedHealth Group Incorporated" },
{ "ticker": "JNJ", "name": "Johnson & Johnson" },
{ "ticker": "LLY", "name": "Eli Lilly and Company" },
{ "ticker": "PFE", "name": "Pfizer Inc." },
{ "ticker": "ABBV", "name": "AbbVie Inc." },
{ "ticker": "MRK", "name": "Merck & Co., Inc." },
{ "ticker": "ABT", "name": "Abbott Laboratories" },
{ "ticker": "TMO", "name": "Thermo Fisher Scientific Inc." },
{ "ticker": "DHR", "name": "Danaher Corporation" },
{ "ticker": "BMY", "name": "Bristol-Myers Squibb Company" },
{ "ticker": "AMGN", "name": "Amgen Inc." },
{ "ticker": "VRTX", "name": "Vertex Pharmaceuticals Incorporated" },
{ "ticker": "GILD", "name": "Gilead Sciences, Inc." },
{ "ticker": "BSX", "name": "Boston Scientific Corporation" },
{ "ticker": "SYK", "name": "Stryker Corporation" },
{ "ticker": "MDT", "name": "Medtronic plc" },
{ "ticker": "ISRG", "name": "Intuitive Surgical, Inc." },
{ "ticker": "ZTS", "name": "Zoetis Inc." },
{ "ticker": "CVS", "name": "CVS Health Corporation" },
{ "ticker": "ELV", "name": "Elevance Health, Inc." },
{ "ticker": "HUM", "name": "Humana Inc." },
{ "ticker": "CNC", "name": "Centene Corporation" },
{ "ticker": "HCA", "name": "HCA Healthcare, Inc." },
{ "ticker": "REGN", "name": "Regeneron Pharmaceuticals, Inc." },
{ "ticker": "MRNA", "name": "Moderna, Inc." },
{ "ticker": "EW", "name": "Edwards Lifesciences Corporation" },
{ "ticker": "BDX", "name": "Becton, Dickinson and Company" },
{ "ticker": "A", "name": "Agilent Technologies, Inc." },
{ "ticker": "IQV", "name": "IQVIA Holdings Inc." },
{ "ticker": "LH", "name": "Labcorp Holdings Inc." },
{ "ticker": "DGX", "name": "Quest Diagnostics Incorporated" },
{ "ticker": "MTD", "name": "Mettler-Toledo International Inc." },
{ "ticker": "WAT", "name": "Waters Corporation" },
{ "ticker": "TECH", "name": "Bio-Techne Corporation" },
{ "ticker": "WST", "name": "West Pharmaceutical Services, Inc." },
{ "ticker": "STE", "name": "STERIS plc" },
{ "ticker": "RMD", "name": "ResMed Inc." },
{ "ticker": "IDXX", "name": "IDEXX Laboratories, Inc." },
{ "ticker": "ALGN", "name": "Align Technology, Inc." },
{ "ticker": "COO", "name": "The Cooper Companies, Inc." },
{ "ticker": "ZBH", "name": "Zimmer Biomet Holdings, Inc." },
{ "ticker": "BAX", "name": "Baxter International Inc." },
{ "ticker": "VTRS", "name": "Viatris Inc." },
{ "ticker": "CTLT", "name": "Catalent, Inc." },
{
"ticker": "CRL",
"name": "Charles River Laboratories International, Inc."
},
{ "ticker": "BIO", "name": "Bio-Rad Laboratories, Inc." },
{ "ticker": "GEHC", "name": "GE HealthCare Technologies Inc." },
{ "ticker": "CI", "name": "The Cigna Group" },
{ "ticker": "CAH", "name": "Cardinal Health, Inc." },
{ "ticker": "COR", "name": "Cencora, Inc." },
{ "ticker": "MCK", "name": "McKesson Corporation" },
{ "ticker": "MOH", "name": "Molina Healthcare, Inc." },
{ "ticker": "DXCM", "name": "DexCom, Inc." },
{ "ticker": "PODD", "name": "Insulet Corporation" },
{ "ticker": "HOLX", "name": "Hologic, Inc." },
{ "ticker": "BIIB", "name": "Biogen Inc." }
]
},
"consumer_discretionary": {
"sector_weight": "10.4%",
"companies": [
{ "ticker": "AMZN", "name": "Amazon.com, Inc." },
{ "ticker": "TSLA", "name": "Tesla, Inc." },
{ "ticker": "HD", "name": "The Home Depot, Inc." },
{ "ticker": "MCD", "name": "McDonald's Corporation" },
{ "ticker": "NKE", "name": "NIKE, Inc." },
{ "ticker": "LOW", "name": "Lowe's Companies, Inc." },
{ "ticker": "TJX", "name": "The TJX Companies, Inc." },
{ "ticker": "BKNG", "name": "Booking Holdings Inc." },
{ "ticker": "GM", "name": "General Motors Company" },
{ "ticker": "F", "name": "Ford Motor Company" },
{ "ticker": "SBUX", "name": "Starbucks Corporation" },
{ "ticker": "TGT", "name": "Target Corporation" },
{ "ticker": "ROST", "name": "Ross Stores, Inc." },
{ "ticker": "DG", "name": "Dollar General Corporation" },
{ "ticker": "CMG", "name": "Chipotle Mexican Grill, Inc." },
{ "ticker": "MAR", "name": "Marriott International, Inc." },
{ "ticker": "HLT", "name": "Hilton Worldwide Holdings Inc." },
{ "ticker": "RCL", "name": "Royal Caribbean Cruises Ltd." },
{ "ticker": "ABNB", "name": "Airbnb, Inc." },
{ "ticker": "ORLY", "name": "O'Reilly Automotive, Inc." },
{ "ticker": "AZO", "name": "AutoZone, Inc." },
{ "ticker": "YUM", "name": "Yum! Brands, Inc." },
{ "ticker": "LVS", "name": "Las Vegas Sands Corp." },
{ "ticker": "WYNN", "name": "Wynn Resorts, Limited" },
{ "ticker": "MGM", "name": "MGM Resorts International" },
{ "ticker": "CZR", "name": "Caesars Entertainment, Inc." },
{ "ticker": "DHI", "name": "D.R. Horton, Inc." },
{ "ticker": "LEN", "name": "Lennar Corporation" },
{ "ticker": "PHM", "name": "PulteGroup, Inc." },
{ "ticker": "GPC", "name": "Genuine Parts Company" },
{ "ticker": "KMX", "name": "CarMax, Inc." },
{ "ticker": "TSCO", "name": "Tractor Supply Company" },
{ "ticker": "BBY", "name": "Best Buy Co., Inc." },
{ "ticker": "POOL", "name": "Pool Corporation" },
{ "ticker": "NCLH", "name": "Norwegian Cruise Line Holdings Ltd." },
{ "ticker": "CCL", "name": "Carnival Corporation & plc" },
{ "ticker": "DASH", "name": "DoorDash, Inc." },
{ "ticker": "ETSY", "name": "Etsy, Inc." },
{ "ticker": "EBAY", "name": "eBay Inc." },
{ "ticker": "CVNA", "name": "Carvana Co." },
{ "ticker": "WSM", "name": "Williams-Sonoma, Inc." },
{ "ticker": "GRMN", "name": "Garmin Ltd." },
{ "ticker": "RL", "name": "Ralph Lauren Corporation" },
{ "ticker": "TPR", "name": "Tapestry, Inc." },
{ "ticker": "HAS", "name": "Hasbro, Inc." },
{ "ticker": "MAT", "name": "Mattel, Inc." },
{ "ticker": "LULU", "name": "lululemon athletica inc." },
{ "ticker": "ULTA", "name": "Ulta Beauty, Inc." },
{ "ticker": "DECK", "name": "Deckers Outdoor Corporation" },
{ "ticker": "NVR", "name": "NVR, Inc." },
{ "ticker": "DRI", "name": "Darden Restaurants, Inc." },
{ "ticker": "DPZ", "name": "Domino's Pizza, Inc." },
{ "ticker": "EXPE", "name": "Expedia Group, Inc." },
{ "ticker": "APTV", "name": "Aptiv PLC" }
]
},
"communication_services": {
"sector_weight": "9.5%",
"companies": [
{ "ticker": "GOOGL", "name": "Alphabet Inc. (Class A)" },
{ "ticker": "GOOG", "name": "Alphabet Inc. (Class C)" },
{ "ticker": "META", "name": "Meta Platforms, Inc." },
{ "ticker": "NFLX", "name": "Netflix, Inc." },
{ "ticker": "CMCSA", "name": "Comcast Corporation" },
{ "ticker": "VZ", "name": "Verizon Communications Inc." },
{ "ticker": "T", "name": "AT&T Inc." },
{ "ticker": "TMUS", "name": "T-Mobile US, Inc." },
{ "ticker": "DIS", "name": "The Walt Disney Company" },
{ "ticker": "CHTR", "name": "Charter Communications, Inc." },
{ "ticker": "WBD", "name": "Warner Bros. Discovery, Inc." },
{ "ticker": "EA", "name": "Electronic Arts Inc." },
{ "ticker": "TTWO", "name": "Take-Two Interactive Software, Inc." },
{ "ticker": "ATVI", "name": "Activision Blizzard, Inc." },
{ "ticker": "PARA", "name": "Paramount Global" },
{ "ticker": "FOXA", "name": "Fox Corporation (Class A)" },
{ "ticker": "FOX", "name": "Fox Corporation (Class B)" },
{ "ticker": "NWSA", "name": "News Corporation (Class A)" },
{ "ticker": "NWS", "name": "News Corporation (Class B)" },
{ "ticker": "IPG", "name": "The Interpublic Group of Companies, Inc." },
{ "ticker": "OMC", "name": "Omnicom Group Inc." },
{ "ticker": "LYV", "name": "Live Nation Entertainment, Inc." },
{ "ticker": "TKO", "name": "TKO Group Holdings, Inc." },
{ "ticker": "MTCH", "name": "Match Group, Inc." },
{ "ticker": "PINS", "name": "Pinterest, Inc." },
{ "ticker": "SNAP", "name": "Snap Inc." },
{ "ticker": "ROKU", "name": "Roku, Inc." }
]
},
"industrials": {
"sector_weight": "7.7%",
"companies": [
{ "ticker": "GE", "name": "General Electric Company" },
{ "ticker": "CAT", "name": "Caterpillar Inc." },
{ "ticker": "RTX", "name": "RTX Corporation" },
{ "ticker": "HON", "name": "Honeywell International Inc." },
{ "ticker": "UNP", "name": "Union Pacific Corporation" },
{ "ticker": "BA", "name": "The Boeing Company" },
{ "ticker": "LMT", "name": "Lockheed Martin Corporation" },
{ "ticker": "GD", "name": "General Dynamics Corporation" },
{ "ticker": "NOC", "name": "Northrop Grumman Corporation" },
{ "ticker": "DE", "name": "Deere & Company" },
{ "ticker": "WM", "name": "Waste Management, Inc." },
{ "ticker": "UPS", "name": "United Parcel Service, Inc." },
{ "ticker": "FDX", "name": "FedEx Corporation" },
{ "ticker": "NSC", "name": "Norfolk Southern Corporation" },
{ "ticker": "CSX", "name": "CSX Corporation" },
{ "ticker": "ETN", "name": "Eaton Corporation plc" },
{ "ticker": "EMR", "name": "Emerson Electric Co." },
{ "ticker": "MMM", "name": "3M Company" },
{ "ticker": "ITW", "name": "Illinois Tool Works Inc." },
{ "ticker": "PH", "name": "Parker-Hannifin Corporation" },
{ "ticker": "JCI", "name": "Johnson Controls International plc" },
{ "ticker": "CARR", "name": "Carrier Global Corporation" },
{ "ticker": "OTIS", "name": "Otis Worldwide Corporation" },
{ "ticker": "CMI", "name": "Cummins Inc." },
{ "ticker": "PCAR", "name": "PACCAR Inc" },
{ "ticker": "DOV", "name": "Dover Corporation" },
{ "ticker": "FTV", "name": "Fortive Corporation" },
{ "ticker": "XYL", "name": "Xylem Inc." },
{ "ticker": "IEX", "name": "IDEX Corporation" },
{ "ticker": "IR", "name": "Ingersoll Rand Inc." },
{ "ticker": "RSG", "name": "Republic Services, Inc." },
{ "ticker": "CTAS", "name": "Cintas Corporation" },
{ "ticker": "FAST", "name": "Fastenal Company" },
{ "ticker": "GWW", "name": "W.W. Grainger, Inc." },
{ "ticker": "VRSK", "name": "Verisk Analytics, Inc." },
{ "ticker": "EFX", "name": "Equifax Inc." },
{ "ticker": "TDG", "name": "TransDigm Group Incorporated" },
{ "ticker": "LHX", "name": "L3Harris Technologies, Inc." },
{ "ticker": "TXT", "name": "Textron Inc." },
{ "ticker": "HWM", "name": "Howmet Aerospace Inc." },
{ "ticker": "HII", "name": "Huntington Ingalls Industries, Inc." },
{ "ticker": "LUV", "name": "Southwest Airlines Co." },
{ "ticker": "DAL", "name": "Delta Air Lines, Inc." },
{ "ticker": "UAL", "name": "United Airlines Holdings, Inc." },
{ "ticker": "AAL", "name": "American Airlines Group Inc." },
{ "ticker": "ALK", "name": "Alaska Air Group, Inc." },
{ "ticker": "ROL", "name": "Rollins, Inc." },
{ "ticker": "VLTO", "name": "Veralto Corporation" },
{ "ticker": "LII", "name": "Lennox International Inc." },
{ "ticker": "AME", "name": "AMETEK, Inc." },
{ "ticker": "PWR", "name": "Quanta Services, Inc." },
{ "ticker": "ROK", "name": "Rockwell Automation, Inc." },
{
"ticker": "WAB",
"name": "Westinghouse Air Brake Technologies Corporation"
},
{ "ticker": "AXON", "name": "Axon Enterprise, Inc." },
{ "ticker": "ODFL", "name": "Old Dominion Freight Line, Inc." },
{ "ticker": "CPRT", "name": "Copart, Inc." },
{ "ticker": "URI", "name": "United Rentals, Inc." },
{
"ticker": "EXPD",
"name": "Expeditors International of Washington, Inc."
},
{ "ticker": "J", "name": "Jacobs Solutions Inc." },
{ "ticker": "JBHT", "name": "J.B. Hunt Transport Services, Inc." },
{ "ticker": "LDOS", "name": "Leidos Holdings, Inc." },
{ "ticker": "HUBB", "name": "Hubbell Incorporated" },
{ "ticker": "PNR", "name": "Pentair plc" },
{ "ticker": "SNA", "name": "Snap-on Incorporated" },
{ "ticker": "TT", "name": "Trane Technologies plc" },
{ "ticker": "UBER", "name": "Uber Technologies, Inc." },
{ "ticker": "ADP", "name": "Automatic Data Processing, Inc." },
{ "ticker": "PAYX", "name": "Paychex, Inc." },
{ "ticker": "PAYC", "name": "Paycom Software, Inc." },
{ "ticker": "GEV", "name": "GE Vernova Inc." }
]
},
"consumer_staples": {
"sector_weight": "6.2%",
"companies": [
{ "ticker": "PG", "name": "The Procter & Gamble Company" },
{ "ticker": "KO", "name": "The Coca-Cola Company" },
{ "ticker": "PEP", "name": "PepsiCo, Inc." },
{ "ticker": "WMT", "name": "Walmart Inc." },
{ "ticker": "COST", "name": "Costco Wholesale Corporation" },
{ "ticker": "PM", "name": "Philip Morris International Inc." },
{ "ticker": "MO", "name": "Altria Group, Inc." },
{ "ticker": "MDLZ", "name": "Mondelez International, Inc." },
{ "ticker": "CL", "name": "Colgate-Palmolive Company" },
{ "ticker": "GIS", "name": "General Mills, Inc." },
{ "ticker": "KMB", "name": "Kimberly-Clark Corporation" },
{ "ticker": "CHD", "name": "Church & Dwight Co., Inc." },
{ "ticker": "CLX", "name": "The Clorox Company" },
{ "ticker": "SJM", "name": "The J.M. Smucker Company" },
{ "ticker": "HRL", "name": "Hormel Foods Corporation" },
{ "ticker": "CPB", "name": "Campbell Soup Company" },
{ "ticker": "CAG", "name": "Conagra Brands, Inc." },
{ "ticker": "KHC", "name": "The Kraft Heinz Company" },
{ "ticker": "TSN", "name": "Tyson Foods, Inc." },
{ "ticker": "KDP", "name": "Keurig Dr Pepper Inc." },
{ "ticker": "MNST", "name": "Monster Beverage Corporation" },
{ "ticker": "EL", "name": "The Estée Lauder Companies Inc." },
{ "ticker": "K", "name": "Kellanova" },
{ "ticker": "HSY", "name": "The Hershey Company" },
{ "ticker": "MKC", "name": "McCormick & Company, Incorporated" },
{ "ticker": "LW", "name": "Lamb Weston Holdings, Inc." },
{ "ticker": "KR", "name": "The Kroger Co." },
{ "ticker": "SYY", "name": "Sysco Corporation" },
{ "ticker": "ADM", "name": "Archer-Daniels-Midland Company" },
{ "ticker": "BG", "name": "Bunge Limited" },
{ "ticker": "DG", "name": "Dollar General Corporation" },
{ "ticker": "DLTR", "name": "Dollar Tree, Inc." },
{ "ticker": "KVUE", "name": "Kenvue Inc." },
{ "ticker": "TGT", "name": "Target Corporation" }
]
},
"energy": {
"sector_weight": "3.2%",
"companies": [
{ "ticker": "XOM", "name": "Exxon Mobil Corporation" },
{ "ticker": "CVX", "name": "Chevron Corporation" },
{ "ticker": "COP", "name": "ConocoPhillips" },
{ "ticker": "EOG", "name": "EOG Resources, Inc." },
{ "ticker": "MPC", "name": "Marathon Petroleum Corporation" },
{ "ticker": "PSX", "name": "Phillips 66" },
{ "ticker": "VLO", "name": "Valero Energy Corporation" },
{ "ticker": "SLB", "name": "Schlumberger Limited" },
{ "ticker": "BKR", "name": "Baker Hughes Company" },
{ "ticker": "HAL", "name": "Halliburton Company" },
{ "ticker": "KMI", "name": "Kinder Morgan, Inc." },
{ "ticker": "OKE", "name": "ONEOK, Inc." },
{ "ticker": "WMB", "name": "The Williams Companies, Inc." },
{ "ticker": "HES", "name": "Hess Corporation" },
{ "ticker": "FANG", "name": "Diamondback Energy, Inc." },
{ "ticker": "CTRA", "name": "Coterra Energy Inc." },
{ "ticker": "DVN", "name": "Devon Energy Corporation" },
{ "ticker": "OXY", "name": "Occidental Petroleum Corporation" },
{ "ticker": "APA", "name": "APA Corporation" },
{ "ticker": "TRGP", "name": "Targa Resources Corp." },
{ "ticker": "MRO", "name": "Marathon Oil Corporation" },
{ "ticker": "OVV", "name": "Ovintiv Inc." },
{ "ticker": "EQT", "name": "EQT Corporation" },
{ "ticker": "EXE", "name": "Expand Energy Corporation" },
{ "ticker": "TPL", "name": "Texas Pacific Land Corporation" }
]
},
"utilities": {
"sector_weight": "2.6%",
"companies": [
{ "ticker": "NEE", "name": "NextEra Energy, Inc." },
{ "ticker": "SO", "name": "The Southern Company" },
{ "ticker": "DUK", "name": "Duke Energy Corporation" },
{ "ticker": "D", "name": "Dominion Energy, Inc." },
{ "ticker": "AEP", "name": "American Electric Power Company, Inc." },
{ "ticker": "EXC", "name": "Exelon Corporation" },
{ "ticker": "XEL", "name": "Xcel Energy Inc." },
{ "ticker": "SRE", "name": "Sempra" },
{
"ticker": "PEG",
"name": "Public Service Enterprise Group Incorporated"
},
{ "ticker": "ED", "name": "Consolidated Edison, Inc." },
{ "ticker": "WEC", "name": "WEC Energy Group, Inc." },
{ "ticker": "ETR", "name": "Entergy Corporation" },
{ "ticker": "CMS", "name": "CMS Energy Corporation" },
{ "ticker": "FE", "name": "FirstEnergy Corp." },
{ "ticker": "EVRG", "name": "Evergy, Inc." },
{ "ticker": "AEE", "name": "Ameren Corporation" },
{ "ticker": "PPL", "name": "PPL Corporation" },
{ "ticker": "EIX", "name": "Edison International" },
{ "ticker": "NI", "name": "NiSource Inc." },
{ "ticker": "PNW", "name": "Pinnacle West Capital Corporation" },
{ "ticker": "AES", "name": "The AES Corporation" },
{ "ticker": "LNT", "name": "Alliant Energy Corporation" },
{ "ticker": "DTE", "name": "DTE Energy Company" },
{ "ticker": "ES", "name": "Eversource Energy" },
{ "ticker": "CNP", "name": "CenterPoint Energy, Inc." },
{ "ticker": "ATO", "name": "Atmos Energy Corporation" },
{ "ticker": "NRG", "name": "NRG Energy, Inc." },
{ "ticker": "AWK", "name": "American Water Works Company, Inc." },
{ "ticker": "VST", "name": "Vistra Corp." },
{ "ticker": "CEG", "name": "Constellation Energy Corporation" },
{ "ticker": "PCG", "name": "PG&E Corporation" }
]
},
"real_estate": {
"sector_weight": "2.2%",
"companies": [
{ "ticker": "AMT", "name": "American Tower Corporation" },
{ "ticker": "PLD", "name": "Prologis, Inc." },
{ "ticker": "CCI", "name": "Crown Castle Inc." },
{ "ticker": "EQIX", "name": "Equinix, Inc." },
{ "ticker": "PSA", "name": "Public Storage" },
{ "ticker": "WELL", "name": "Welltower Inc." },
{ "ticker": "DLR", "name": "Digital Realty Trust, Inc." },
{ "ticker": "SPG", "name": "Simon Property Group, Inc." },
{ "ticker": "O", "name": "Realty Income Corporation" },
{ "ticker": "VICI", "name": "VICI Properties Inc." },
{ "ticker": "EXR", "name": "Extra Space Storage Inc." },
{ "ticker": "AVB", "name": "AvalonBay Communities, Inc." },
{ "ticker": "EQR", "name": "Equity Residential" },
{ "ticker": "VTR", "name": "Ventas, Inc." },
{ "ticker": "BXP", "name": "Boston Properties, Inc." },
{ "ticker": "HST", "name": "Host Hotels & Resorts, Inc." },
{ "ticker": "MAA", "name": "Mid-America Apartment Communities, Inc." },
{ "ticker": "ESS", "name": "Essex Property Trust, Inc." },
{ "ticker": "INVH", "name": "Invitation Homes Inc." },
{ "ticker": "CPT", "name": "Camden Property Trust" },
{ "ticker": "ARE", "name": "Alexandria Real Estate Equities, Inc." },
{ "ticker": "IRM", "name": "Iron Mountain Incorporated" },
{ "ticker": "KIM", "name": "Kimco Realty Corporation" },
{ "ticker": "REG", "name": "Regency Centers Corporation" },
{ "ticker": "FRT", "name": "Federal Realty Investment Trust" },
{ "ticker": "SBAC", "name": "SBA Communications Corporation" },
{ "ticker": "UDR", "name": "UDR, Inc." },
{ "ticker": "PEAK", "name": "Healthpeak Properties, Inc." },
{ "ticker": "VNO", "name": "Vornado Realty Trust" },
{ "ticker": "DRE", "name": "Duke Realty Corporation" },
{ "ticker": "WY", "name": "Weyerhaeuser Company" },
{ "ticker": "CSGP", "name": "CoStar Group, Inc." },
{ "ticker": "CBRE", "name": "CBRE Group, Inc." }
]
},
"materials": {
"sector_weight": "1.8%",
"companies": [
{ "ticker": "LIN", "name": "Linde plc" },
{ "ticker": "APD", "name": "Air Products and Chemicals, Inc." },
{ "ticker": "SHW", "name": "The Sherwin-Williams Company" },
{ "ticker": "ECL", "name": "Ecolab Inc." },
{ "ticker": "FCX", "name": "Freeport-McMoRan Inc." },
{ "ticker": "NEM", "name": "Newmont Corporation" },
{ "ticker": "NUE", "name": "Nucor Corporation" },
{ "ticker": "STLD", "name": "Steel Dynamics, Inc." },
{ "ticker": "VMC", "name": "Vulcan Materials Company" },
{ "ticker": "MLM", "name": "Martin Marietta Materials, Inc." },
{ "ticker": "IP", "name": "International Paper Company" },
{ "ticker": "PKG", "name": "Packaging Corporation of America" },
{ "ticker": "BALL", "name": "Ball Corporation" },
{ "ticker": "CF", "name": "CF Industries Holdings, Inc." },
{ "ticker": "DOW", "name": "Dow Inc." },
{ "ticker": "LYB", "name": "LyondellBasell Industries N.V." },
{ "ticker": "PPG", "name": "PPG Industries, Inc." },
{ "ticker": "DD", "name": "DuPont de Nemours, Inc." },
{ "ticker": "EMN", "name": "Eastman Chemical Company" },
{ "ticker": "MOS", "name": "The Mosaic Company" },
{ "ticker": "ALB", "name": "Albemarle Corporation" },
{ "ticker": "CE", "name": "Celanese Corporation" },
{ "ticker": "WRK", "name": "WestRock Company" },
{ "ticker": "CTVA", "name": "Corteva, Inc." },
{ "ticker": "IFF", "name": "International Flavors & Fragrances Inc." },
{ "ticker": "RPM", "name": "RPM International Inc." },
{ "ticker": "AXTA", "name": "Axalta Coating Systems Ltd." },
{ "ticker": "AVY", "name": "Avery Dennison Corporation" },
{ "ticker": "AMCR", "name": "Amcor plc" },
{ "ticker": "SW", "name": "Smurfit Westrock Plc" }
]
}
},
"summary": {
"total_companies": 503,
"total_market_cap_representation": "~80% of US stock market",
"last_updated": "July 2025",
"minimum_market_cap_for_inclusion": "$22.7 billion",
"top_10_companies_weight": "~34% of total index value"
}
}

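Because `sector_weight` is stored as a display string rather than a number, consumers have to parse it; a small sketch of doing so and sanity-checking the result (the eleven weights above sum to roughly 100%):

```python
import json

with open("assets/indexes/snp500.json") as f:
    sp500 = json.load(f)

# Strip the trailing "%" and convert each sector weight to a float.
weights = {
    name: float(sector["sector_weight"].rstrip("%"))
    for name, sector in sp500["sp500_by_sector"].items()
}
print(f"{len(weights)} sectors, weights sum to {sum(weights.values()):.1f}%")
```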
View File

@ -1,29 +1,32 @@
from typing import Optional
import datetime
import typer
from pathlib import Path
from functools import wraps
from rich.console import Console
from rich.panel import Panel
from rich.spinner import Spinner
from rich.live import Live
from rich.columns import Columns
from rich.markdown import Markdown
from rich.layout import Layout
from rich.text import Text
from rich.live import Live
from rich.table import Table
from collections import deque
import time
from rich.tree import Tree
from functools import wraps
from pathlib import Path
import typer
from rich import box
from rich.align import Align
from rich.rule import Rule
from rich.columns import Columns
from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from rich.markdown import Markdown
from rich.panel import Panel
from rich.spinner import Spinner
from rich.table import Table
from rich.text import Text
from cli.utils import (
get_analysis_date,
get_ticker,
select_analysts,
select_deep_thinking_agent,
select_llm_provider,
select_research_depth,
select_shallow_thinking_agent,
)
from tradingagents.config import TradingAgentsConfig
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from cli.models import AnalystType
from cli.utils import *
console = Console()
@ -99,7 +102,7 @@ class MessageBuffer:
if content is not None:
latest_section = section
latest_content = content
if latest_section and latest_content:
# Format the current section for display
section_titles = {
@ -304,16 +307,16 @@ def update_display(layout, spinner_text=None):
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get('type') == 'text':
text_parts.append(item.get('text', ''))
elif item.get('type') == 'tool_use':
if item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif item.get("type") == "tool_use":
text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
else:
text_parts.append(str(item))
content_str = ' '.join(text_parts)
content_str = " ".join(text_parts)
elif not isinstance(content_str, str):
content_str = str(content)
# Truncate message content if too long
if len(content_str) > 200:
content_str = content_str[:197] + "..."
@ -338,10 +341,12 @@ def update_display(layout, spinner_text=None):
if spinner_text:
messages_table.add_row("", "Spinner", spinner_text)
# Add a footer to indicate if messages were truncated
# Add a footer row to indicate if messages were truncated
if len(all_messages) > max_messages:
messages_table.footer = (
f"[dim]Showing last {max_messages} of {len(all_messages)} messages[/dim]"
messages_table.add_row(
"",
"",
f"[dim]Showing last {max_messages} of {len(all_messages)} messages[/dim]",
)
layout["messages"].update(
@ -394,7 +399,7 @@ def update_display(layout, spinner_text=None):
def get_user_selections():
"""Get all user selections before starting the analysis display."""
# Display ASCII art welcome message
with open("./cli/static/welcome.txt", "r") as f:
with open("./cli/static/welcome.txt") as f:
welcome_ascii = f.read()
# Create welcome box content
@ -465,12 +470,10 @@ def get_user_selections():
# Step 5: OpenAI backend
console.print(
create_question_box(
"Step 5: OpenAI backend", "Select which service to talk to"
)
create_question_box("Step 5: OpenAI backend", "Select which service to talk to")
)
selected_llm_provider, backend_url = select_llm_provider()
# Step 6: Thinking agents
console.print(
create_question_box(
@ -492,30 +495,6 @@ def get_user_selections():
}
def get_ticker():
"""Get ticker symbol from user input."""
return typer.prompt("", default="SPY")
def get_analysis_date():
"""Get the analysis date from user input."""
while True:
date_str = typer.prompt(
"", default=datetime.datetime.now().strftime("%Y-%m-%d")
)
try:
# Validate date format and ensure it's not in the future
analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d")
if analysis_date.date() > datetime.datetime.now().date():
console.print("[red]Error: Analysis date cannot be in the future[/red]")
continue
return date_str
except ValueError:
console.print(
"[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]"
)
def display_complete_report(final_state):
"""Display the complete analysis report with team-based panels."""
console.print("\n[bold green]Complete Analysis Report[/bold green]\n")
@ -712,6 +691,7 @@ def update_research_team_status(status):
for agent in research_team:
message_buffer.update_agent_status(agent, status)
def extract_content_string(content):
"""Extract string content from various message formats."""
if isinstance(content, str):
@ -721,28 +701,30 @@ def extract_content_string(content):
text_parts = []
for item in content:
if isinstance(item, dict):
if item.get('type') == 'text':
text_parts.append(item.get('text', ''))
elif item.get('type') == 'tool_use':
if item.get("type") == "text":
text_parts.append(item.get("text", ""))
elif item.get("type") == "tool_use":
text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
else:
text_parts.append(str(item))
return ' '.join(text_parts)
return " ".join(text_parts)
else:
return str(content)
def run_analysis():
# First get all user selections
selections = get_user_selections()
# Create config with selected research depth
config = DEFAULT_CONFIG.copy()
config["max_debate_rounds"] = selections["research_depth"]
config["max_risk_discuss_rounds"] = selections["research_depth"]
config["quick_think_llm"] = selections["shallow_thinker"]
config["deep_think_llm"] = selections["deep_thinker"]
config["backend_url"] = selections["backend_url"]
config["llm_provider"] = selections["llm_provider"].lower()
config = TradingAgentsConfig(
max_debate_rounds=selections["research_depth"],
max_risk_discuss_rounds=selections["research_depth"],
quick_think_llm=selections["shallow_thinker"],
deep_think_llm=selections["deep_thinker"],
backend_url=selections["backend_url"],
llm_provider=selections["llm_provider"].lower(),
)
# Initialize the graph
graph = TradingAgentsGraph(
@ -750,7 +732,9 @@ def run_analysis():
)
# Create result directory
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
results_dir = (
Path(config.results_dir) / selections["ticker"] / selections["analysis_date"]
)
results_dir.mkdir(parents=True, exist_ok=True)
report_dir = results_dir / "reports"
report_dir.mkdir(parents=True, exist_ok=True)
@ -759,6 +743,7 @@ def run_analysis():
def save_message_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)
def wrapper(*args, **kwargs):
func(*args, **kwargs)
@ -766,10 +751,12 @@ def run_analysis():
content = content.replace("\n", " ") # Replace newlines with spaces
with open(log_file, "a") as f:
f.write(f"{timestamp} [{message_type}] {content}\n")
return wrapper
def save_tool_call_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)
def wrapper(*args, **kwargs):
func(*args, **kwargs)
@ -777,29 +764,39 @@ def run_analysis():
args_str = ", ".join(f"{k}={v}" for k, v in args.items())
with open(log_file, "a") as f:
f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
return wrapper
def save_report_section_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)
def wrapper(section_name, content):
func(section_name, content)
if section_name in obj.report_sections and obj.report_sections[section_name] is not None:
if (
section_name in obj.report_sections
and obj.report_sections[section_name] is not None
):
content = obj.report_sections[section_name]
if content:
file_name = f"{section_name}.md"
with open(report_dir / file_name, "w") as f:
f.write(content)
return wrapper
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call")
message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section")
message_buffer.add_tool_call = save_tool_call_decorator(
message_buffer, "add_tool_call"
)
message_buffer.update_report_section = save_report_section_decorator(
message_buffer, "update_report_section"
)
# Now start the display layout
layout = create_layout()
with Live(layout, refresh_per_second=4) as live:
with Live(layout, refresh_per_second=4):
# Initial display
update_display(layout)
@ -850,14 +847,16 @@ def run_analysis():
# Extract message content and type
if hasattr(last_message, "content"):
content = extract_content_string(last_message.content) # Use the helper function
content = extract_content_string(
last_message.content
) # Use the helper function
msg_type = "Reasoning"
else:
content = str(last_message)
msg_type = "System"
# Add message to buffer
message_buffer.add_message(msg_type, content)
message_buffer.add_message(msg_type, content)
# If it's a tool call, add it to tool calls
if hasattr(last_message, "tool_calls"):
@ -1075,7 +1074,7 @@ def run_analysis():
# Get final state and decision
final_state = trace[-1]
decision = graph.process_signal(final_state["final_trade_decision"])
_decision = graph.process_signal(final_state["final_trade_decision"])
# Update all agent statuses to completed
for agent in message_buffer.agent_status:
@ -1086,7 +1085,7 @@ def run_analysis():
)
# Update final report sections
for section in message_buffer.report_sections.keys():
for section in message_buffer.report_sections:
if section in final_state:
message_buffer.update_report_section(section, final_state[section])

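The `extract_content_string` helper above exists because Anthropic-style messages deliver `content` as a list of typed blocks rather than a plain string. A short usage sketch, assuming the function is in scope and behaves exactly as the hunk shows:

```python
# Mixed text and tool_use blocks, as an Anthropic response might carry them.
content = [
    {"type": "text", "text": "Fetching fundamentals for AAPL."},
    {"type": "tool_use", "name": "get_fundamentals_openai", "input": {"ticker": "AAPL"}},
]
print(extract_content_string(content))
# -> "Fetching fundamentals for AAPL. [Tool: get_fundamentals_openai]"
```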
View File

@ -1,6 +1,4 @@
from enum import Enum
from typing import List, Optional, Dict
from pydantic import BaseModel
class AnalystType(str, Enum):

View File

@ -1,8 +1,10 @@
import questionary
from typing import List, Optional, Tuple, Dict
from rich.console import Console
from cli.models import AnalystType
console = Console()
ANALYST_ORDER = [
("Market Analyst", AnalystType.MARKET),
("Social Media Analyst", AnalystType.SOCIAL),
@ -64,7 +66,7 @@ def get_analysis_date() -> str:
return date.strip()
def select_analysts() -> List[AnalystType]:
def select_analysts() -> list[AnalystType]:
"""Select analysts using an interactive checkbox."""
choices = questionary.checkbox(
"Select Your [Analysts Team]:",
@ -129,30 +131,60 @@ def select_shallow_thinking_agent(provider) -> str:
SHALLOW_AGENT_OPTIONS = {
"openai": [
("GPT-4o-mini - Fast and efficient for quick tasks", "gpt-4o-mini"),
("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
(
"GPT-4.1-nano - Ultra-lightweight model for basic operations",
"gpt-4.1-nano",
),
("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"),
("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
],
"anthropic": [
("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"),
("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"),
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
(
"Claude Haiku 3.5 - Fast inference and standard capabilities",
"claude-3-5-haiku-latest",
),
(
"Claude Sonnet 3.5 - Highly capable standard model",
"claude-3-5-sonnet-latest",
),
(
"Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities",
"claude-3-7-sonnet-latest",
),
(
"Claude Sonnet 4 - High performance and excellent reasoning",
"claude-sonnet-4-0",
),
],
"google": [
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
(
"Gemini 2.0 Flash-Lite - Cost efficiency and low latency",
"gemini-2.0-flash-lite",
),
(
"Gemini 2.0 Flash - Next generation features, speed, and thinking",
"gemini-2.0-flash",
),
(
"Gemini 2.5 Flash - Adaptive thinking, cost efficiency",
"gemini-2.5-flash-preview-05-20",
),
],
"openrouter": [
("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"),
("Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B", "meta-llama/llama-3.3-8b-instruct:free"),
("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"),
(
"Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B",
"meta-llama/llama-3.3-8b-instruct:free",
),
(
"google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token",
"google/gemini-2.0-flash-exp:free",
),
],
"ollama": [
("llama3.1 local", "llama3.1"),
("llama3.2 local", "llama3.2"),
]
],
}
choice = questionary.select(
@ -186,7 +218,10 @@ def select_deep_thinking_agent(provider) -> str:
# Define deep thinking llm engine options with their corresponding model names
DEEP_AGENT_OPTIONS = {
"openai": [
("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
(
"GPT-4.1-nano - Ultra-lightweight model for basic operations",
"gpt-4.1-nano",
),
("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"),
("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
("o4-mini - Specialized reasoning model (compact)", "o4-mini"),
@ -195,28 +230,55 @@ def select_deep_thinking_agent(provider) -> str:
("o1 - Premier reasoning and problem-solving model", "o1"),
],
"anthropic": [
("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"),
("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"),
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
(
"Claude Haiku 3.5 - Fast inference and standard capabilities",
"claude-3-5-haiku-latest",
),
(
"Claude Sonnet 3.5 - Highly capable standard model",
"claude-3-5-sonnet-latest",
),
(
"Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities",
"claude-3-7-sonnet-latest",
),
(
"Claude Sonnet 4 - High performance and excellent reasoning",
"claude-sonnet-4-0",
),
        ("Claude Opus 4 - Most powerful Anthropic model", "claude-opus-4-0"),
],
"google": [
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
(
"Gemini 2.0 Flash-Lite - Cost efficiency and low latency",
"gemini-2.0-flash-lite",
),
(
"Gemini 2.0 Flash - Next generation features, speed, and thinking",
"gemini-2.0-flash",
),
(
"Gemini 2.5 Flash - Adaptive thinking, cost efficiency",
"gemini-2.5-flash-preview-05-20",
),
("Gemini 2.5 Pro", "gemini-2.5-pro-preview-06-05"),
],
"openrouter": [
("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"),
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"),
(
"DeepSeek V3 - a 685B-parameter, mixture-of-experts model",
"deepseek/deepseek-chat-v3-0324:free",
),
(
"Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.",
"deepseek/deepseek-chat-v3-0324:free",
),
],
"ollama": [
("llama3.1 local", "llama3.1"),
("qwen3", "qwen3"),
]
],
}
choice = questionary.select(
"Select Your [Deep-Thinking LLM Engine]:",
choices=[
@ -239,6 +301,7 @@ def select_deep_thinking_agent(provider) -> str:
return choice
def select_llm_provider() -> tuple[str, str]:
    """Select the LLM provider and its backend URL using interactive selection."""
# Define OpenAI api options with their corresponding endpoints
@ -247,9 +310,9 @@ def select_llm_provider() -> tuple[str, str]:
("Anthropic", "https://api.anthropic.com/"),
("Google", "https://generativelanguage.googleapis.com/v1"),
("Openrouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"),
("Ollama", "http://localhost:11434/v1"),
]
choice = questionary.select(
"Select your LLM Provider:",
choices=[
@ -265,12 +328,12 @@ def select_llm_provider() -> tuple[str, str]:
]
),
).ask()
if choice is None:
        console.print("\n[red]No LLM provider selected. Exiting...[/red]")
exit(1)
display_name, url = choice
print(f"You selected: {display_name}\tURL: {url}")
return display_name, url

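The selection helpers above all share one pattern: tables of `(display, value)` tuples turned into questionary choices whose `.value` is returned. A minimal sketch of that pattern, reusing two option rows from the hunks (the prompt text here is illustrative):

```python
import questionary

OPTIONS = [
    ("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
    ("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
]

# Each tuple becomes a Choice; .ask() returns the selected value (the model id).
choice = questionary.select(
    "Select Your [Quick-Thinking LLM Engine]:",
    choices=[questionary.Choice(display, value=value) for display, value in OPTIONS],
).ask()
print(choice)  # e.g. "claude-3-5-haiku-latest"
```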
18
main.py
View File

@ -1,14 +1,14 @@
from tradingagents.config import TradingAgentsConfig
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
# Create a custom config
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "google" # Use a different model
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
config["deep_think_llm"] = "gemini-2.0-flash" # Use a different model
config["quick_think_llm"] = "gemini-2.0-flash" # Use a different model
config["max_debate_rounds"] = 1 # Increase debate rounds
config["online_tools"] = True # Increase debate rounds
# Create a custom config using Anthropic
config = TradingAgentsConfig(
llm_provider="anthropic",
deep_think_llm="claude-3-5-sonnet-20241022",
quick_think_llm="claude-3-5-haiku-20241022",
max_debate_rounds=1,
online_tools=True,
)
# Initialize with custom config
ta = TradingAgentsGraph(debug=True, config=config)

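`TradingAgentsConfig` itself does not appear in this diff; inferred only from the keyword arguments above and the attribute access elsewhere in the commit (`config.results_dir`, `toolkit.config.online_tools`), it plausibly looks something like the sketch below. Pydantic and every default value are guesses:

```python
from pydantic import BaseModel

class TradingAgentsConfig(BaseModel):
    # Hypothetical reconstruction -- field set inferred from call sites only.
    llm_provider: str = "anthropic"
    deep_think_llm: str = "claude-3-5-sonnet-20241022"
    quick_think_llm: str = "claude-3-5-haiku-20241022"
    backend_url: str = "https://api.anthropic.com/v1"
    max_debate_rounds: int = 1
    max_risk_discuss_rounds: int = 1
    online_tools: bool = True
    results_dir: str = "./results"
```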
View File

@ -1,11 +1,10 @@
[project]
name = "tradingagents"
version = "0.1.0"
description = "Add your description here"
description = "Multi-Agents LLM Financial Trading Framework"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.13"
dependencies = [
"akshare>=1.16.98",
"backtrader>=1.9.78.123",
"chainlit>=2.5.5",
"chromadb>=1.0.12",
@ -20,6 +19,7 @@ dependencies = [
"pandas>=2.3.0",
"parsel>=1.10.0",
"praw>=7.8.1",
"python-dotenv>=1.1.0",
"pytz>=2025.2",
"questionary>=2.1.0",
"redis>=6.2.0",
@ -29,6 +29,119 @@ dependencies = [
"stockstats>=0.6.5",
"tqdm>=4.67.1",
"tushare>=1.4.21",
"typer>=0.12.0",
"typing-extensions>=4.14.0",
"yfinance>=0.2.63",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-cov>=5.0.0",
"pytest-asyncio>=0.24.0",
"ruff>=0.8.0",
"pyright>=1.1.390",
]
[project.scripts]
tradingagents = "cli.main:app"
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[tool.ruff]
line-length = 88
target-version = "py310"
extend-exclude = [
"migrations",
"venv",
".venv",
"build",
"dist",
"*.egg-info",
]
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
"ARG", # flake8-unused-arguments
"SIM", # flake8-simplify
"TCH", # flake8-type-checking
]
ignore = [
"E501", # line too long, handled by formatter
"B008", # do not perform function calls in argument defaults
"C901", # too complex
"ARG002", # unused method argument
"ARG001", # unused function argument
]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"] # unused imports in __init__.py
"tests/**/*" = ["ARG", "SIM"] # test files can be more flexible
[tool.ruff.lint.isort]
known-first-party = ["tradingagents", "cli"]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
[tool.pyright]
include = ["tradingagents", "cli", "main.py"]
exclude = [
"**/__pycache__",
"**/node_modules",
".venv",
"venv",
"build",
"dist",
]
pythonVersion = "3.10"
pythonPlatform = "All"
typeCheckingMode = "standard"
reportMissingImports = true
reportMissingTypeStubs = false
reportGeneralTypeIssues = true
reportOptionalMemberAccess = true
reportOptionalCall = true
reportOptionalIterable = true
reportOptionalContextManager = true
reportOptionalOperand = true
reportTypedDictNotRequiredAccess = false
reportPrivateImportUsage = false
reportUnknownParameterType = false
reportUnknownArgumentType = false
reportUnknownLambdaType = false
reportUnknownVariableType = false
reportUnknownMemberType = false
[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra -q --strict-markers --strict-config"
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"integration: marks tests as integration tests",
"unit: marks tests as unit tests",
]
[dependency-groups]
dev = [
"pytest>=8.4.1",
"pytest-asyncio>=1.1.0",
"pytest-cov>=6.2.1",
"pytest-vcr>=1.0.2",
"ruff>=0.12.5",
]

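The `[tool.pytest.ini_options]` block above registers three custom markers and passes `--strict-markers`, so any unregistered marker fails collection. A minimal example of a test opting into them (the test name and body are placeholders):

```python
import pytest

@pytest.mark.slow
@pytest.mark.integration
def test_full_analysis_pipeline():
    # A real test would drive the trading graph end to end;
    # deselect with: pytest -m "not slow"
    assert True
```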
11
pyrightconfig.json Normal file
View File

@ -0,0 +1,11 @@
{
"venvPath": ".",
"venv": ".venv",
"pythonVersion": "3.13",
"typeCheckingMode": "standard",
"reportMissingImports": true,
"reportMissingTypeStubs": false,
"useLibraryCodeForTypes": true,
"autoSearchPaths": true,
"extraPaths": []
}

View File

@ -1,26 +0,0 @@
typing-extensions
langchain-openai
langchain-experimental
pandas
yfinance
praw
feedparser
stockstats
eodhd
langgraph
chromadb
setuptools
backtrader
akshare
tushare
finnhub-python
parsel
requests
tqdm
pytz
redis
chainlit
rich
questionary
langchain_anthropic
langchain-google-genai

View File

@ -2,7 +2,7 @@
Setup script for the TradingAgents package.
"""
from setuptools import setup, find_packages
from setuptools import find_packages, setup
setup(
name="tradingagents",

View File

@ -1,27 +1,24 @@
from .utils.agent_utils import Toolkit, create_msg_delete
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.memory import FinancialSituationMemory
from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst
from .analysts.news_analyst import create_news_analyst
from .analysts.social_media_analyst import create_social_media_analyst
from .managers.research_manager import create_research_manager
from .managers.risk_manager import create_risk_manager
from .researchers.bear_researcher import create_bear_researcher
from .researchers.bull_researcher import create_bull_researcher
from .risk_mgmt.aggresive_debator import create_risky_debator
from .risk_mgmt.conservative_debator import create_safe_debator
from .risk_mgmt.neutral_debator import create_neutral_debator
from .managers.research_manager import create_research_manager
from .managers.risk_manager import create_risk_manager
from .trader.trader import create_trader
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.agent_utils import Toolkit, create_msg_delete
from .utils.memory import FinancialSituationMemory
from .utils.service_toolkit import ServiceToolkit
__all__ = [
"FinancialSituationMemory",
"Toolkit",
"ServiceToolkit",
"AgentState",
"create_msg_delete",
"InvestDebateState",

View File

@ -1,15 +1,12 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
def create_fundamentals_analyst(llm, toolkit):
def fundamentals_analyst_node(state):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
company_name = state["company_of_interest"]
if toolkit.config["online_tools"]:
if toolkit.config.online_tools:
tools = [toolkit.get_fundamentals_openai]
else:
tools = [

View File

@ -1,16 +1,12 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
def create_market_analyst(llm, toolkit):
def market_analyst_node(state):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
company_name = state["company_of_interest"]
if toolkit.config["online_tools"]:
if toolkit.config.online_tools:
tools = [
toolkit.get_YFin_data_online,
toolkit.get_stockstats_indicators_report_online,
@ -80,7 +76,7 @@ Volume-Based Indicators:
if len(result.tool_calls) == 0:
report = result.content
return {
"messages": [result],
"market_report": report,

View File

@ -1,6 +1,4 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
def create_news_analyst(llm, toolkit):
@ -8,7 +6,7 @@ def create_news_analyst(llm, toolkit):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
if toolkit.config["online_tools"]:
if toolkit.config.online_tools:
tools = [toolkit.get_global_news_openai, toolkit.get_google_news]
else:
tools = [

View File

@ -1,15 +1,12 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
def create_social_media_analyst(llm, toolkit):
def social_media_analyst_node(state):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
company_name = state["company_of_interest"]
if toolkit.config["online_tools"]:
if toolkit.config.online_tools:
tools = [toolkit.get_stock_news_openai]
else:
tools = [

View File

@ -1,7 +1,3 @@
import time
import json
def create_research_manager(llm, memory):
def research_manager_node(state) -> dict:
history = state["investment_debate_state"].get("history", "")
@ -16,7 +12,7 @@ def create_research_manager(llm, memory):
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
for _, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
@ -28,7 +24,7 @@ Additionally, develop a detailed investment plan for the trader. This should inc
Your Recommendation: A decisive stance supported by the most convincing arguments.
Rationale: An explanation of why these arguments lead to your conclusion.
Strategic Actions: Concrete steps for implementing the recommendation.
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
Here are your past reflections on mistakes:
\"{past_memory_str}\"

View File

@ -1,12 +1,5 @@
import time
import json
def create_risk_manager(llm, memory):
def risk_manager_node(state) -> dict:
company_name = state["company_of_interest"]
history = state["risk_debate_state"]["history"]
risk_debate_state = state["risk_debate_state"]
market_research_report = state["market_report"]
@ -19,7 +12,7 @@ def create_risk_manager(llm, memory):
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
for _, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
@ -36,7 +29,7 @@ Deliverables:
---
**Analysts Debate History:**
{history}
---

View File

@ -1,8 +1,3 @@
from langchain_core.messages import AIMessage
import time
import json
def create_bear_researcher(llm, memory):
def bear_node(state) -> dict:
investment_debate_state = state["investment_debate_state"]
@ -19,7 +14,7 @@ def create_bear_researcher(llm, memory):
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
for _i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.

View File

@ -1,8 +1,3 @@
from langchain_core.messages import AIMessage
import time
import json
def create_bull_researcher(llm, memory):
def bull_node(state) -> dict:
investment_debate_state = state["investment_debate_state"]
@ -19,7 +14,7 @@ def create_bull_researcher(llm, memory):
past_memories = memory.get_memories(curr_situation, n_matches=2)
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
for _i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.

View File

@ -1,7 +1,3 @@
import time
import json
def create_risky_debator(llm):
def risky_node(state) -> dict:
risk_debate_state = state["risk_debate_state"]

View File

@ -1,8 +1,3 @@
from langchain_core.messages import AIMessage
import time
import json
def create_safe_debator(llm):
def safe_node(state) -> dict:
risk_debate_state = state["risk_debate_state"]

View File

@ -1,7 +1,3 @@
import time
import json
def create_neutral_debator(llm):
def neutral_node(state) -> dict:
risk_debate_state = state["risk_debate_state"]

View File

@ -1,6 +1,4 @@
import functools
import time
import json
def create_trader(llm, memory):
@ -17,7 +15,7 @@ def create_trader(llm, memory):
past_memory_str = ""
if past_memories:
for i, rec in enumerate(past_memories, 1):
for _i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
else:
past_memory_str = "No past memories found."

View File

@ -1,10 +1,7 @@
from typing import Annotated, Sequence
from datetime import date, timedelta, datetime
from typing_extensions import TypedDict, Optional
from langchain_openai import ChatOpenAI
from tradingagents.agents import *
from langgraph.prebuilt import ToolNode
from langgraph.graph import END, StateGraph, START, MessagesState
from typing import Annotated
from langgraph.graph import MessagesState
from typing_extensions import TypedDict
# Researcher team state

View File

@ -1,43 +1,38 @@
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage, AIMessage
from typing import List
from datetime import datetime
from typing import Annotated
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import RemoveMessage
from langchain_core.messages import HumanMessage, RemoveMessage
from langchain_core.tools import tool
from datetime import date, timedelta, datetime
import functools
import pandas as pd
import os
from dateutil.relativedelta import relativedelta
from langchain_openai import ChatOpenAI
import tradingagents.dataflows.interface as interface
from tradingagents.default_config import DEFAULT_CONFIG
from langchain_core.messages import HumanMessage
from tradingagents.config import TradingAgentsConfig
DEFAULT_CONFIG = TradingAgentsConfig()
def create_msg_delete():
def delete_messages(state):
"""Clear messages and add placeholder for Anthropic compatibility"""
messages = state["messages"]
# Remove all messages
removal_operations = [RemoveMessage(id=m.id) for m in messages]
# Add a minimal placeholder message
placeholder = HumanMessage(content="Continue")
return {"messages": removal_operations + [placeholder]}
return delete_messages
class Toolkit:
_config = DEFAULT_CONFIG.copy()
_config = TradingAgentsConfig()
@classmethod
def update_config(cls, config):
"""Update the class-level configuration."""
cls._config.update(config)
cls._config = config
@property
def config(self):
@ -60,7 +55,7 @@ class Toolkit:
Returns:
str: A formatted string containing the latest global news from Reddit in the specified time frame.
"""
global_news_result = interface.get_reddit_global_news(curr_date, 7, 5)
return global_news_result
@ -87,9 +82,9 @@ class Toolkit:
end_date_str = end_date
end_date = datetime.strptime(end_date, "%Y-%m-%d")
start_date = datetime.strptime(start_date, "%Y-%m-%d")
look_back_days = (end_date - start_date).days
end_date_dt = datetime.strptime(end_date, "%Y-%m-%d")
start_date_dt = datetime.strptime(start_date, "%Y-%m-%d")
look_back_days = (end_date_dt - start_date_dt).days
finnhub_news_result = interface.get_finnhub_news(
ticker, end_date_str, look_back_days
@ -138,7 +133,8 @@ class Toolkit:
result_data = interface.get_YFin_data(symbol, start_date, end_date)
return result_data
# Convert DataFrame to string for tool output
return result_data.to_string()
@staticmethod
@tool

View File

@ -0,0 +1,312 @@
"""
Helper functions for agents to work with JSON contexts from the new ServiceToolkit.
Provides utilities to parse and extract data from structured contexts.
"""
import json
import logging
from typing import Any
logger = logging.getLogger(__name__)
class ContextParser:
"""Helper class to parse and extract data from JSON contexts."""
@staticmethod
def parse_context(context_json: str) -> dict[str, Any]:
"""
Parse JSON context string into dictionary.
Args:
context_json: JSON string from toolkit method
Returns:
Dictionary representation of the context
"""
try:
return json.loads(context_json)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON context: {e}")
return {}
@staticmethod
def get_data_quality(context: dict[str, Any]) -> str:
"""Extract data quality from context metadata."""
return context.get("metadata", {}).get("data_quality", "UNKNOWN")
@staticmethod
def get_data_source(context: dict[str, Any]) -> str:
"""Extract data source from context metadata."""
return context.get("metadata", {}).get("data_source", "unknown")
@staticmethod
def is_high_quality(context: dict[str, Any]) -> bool:
"""Check if context has high quality data."""
return ContextParser.get_data_quality(context) == "HIGH"
@staticmethod
def is_fresh_data(context: dict[str, Any]) -> bool:
"""Check if context contains fresh (non-cached) data."""
source = ContextParser.get_data_source(context)
return source in ["live_api", "live_api_refresh"]
class MarketDataParser(ContextParser):
"""Parser for MarketDataContext objects."""
@staticmethod
def get_price_data(context: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract price data from market context."""
return context.get("price_data", [])
@staticmethod
def get_latest_price(context: dict[str, Any]) -> float | None:
"""Get the most recent closing price."""
price_data = MarketDataParser.get_price_data(context)
if price_data:
return price_data[-1].get("Close")
return None
@staticmethod
def get_technical_indicators(context: dict[str, Any]) -> dict[str, Any]:
"""Extract technical indicators from context."""
return context.get("technical_indicators", {})
@staticmethod
def get_indicator_value(context: dict[str, Any], indicator: str) -> float | None:
"""Get the latest value for a specific technical indicator."""
indicators = MarketDataParser.get_technical_indicators(context)
indicator_data = indicators.get(indicator, {})
if isinstance(indicator_data, dict) and "values" in indicator_data:
values = indicator_data["values"]
if values:
return values[-1] # Get latest value
return None
@staticmethod
def format_price_summary(context: dict[str, Any]) -> str:
"""Create a formatted summary of price data for agents."""
symbol = context.get("symbol", "UNKNOWN")
period = context.get("period", {})
price_data = MarketDataParser.get_price_data(context)
if not price_data:
return f"No price data available for {symbol}"
latest = price_data[-1]
first = price_data[0]
latest_price = latest.get("Close", 0)
start_price = first.get("Close", 0)
change = latest_price - start_price
change_pct = (change / start_price * 100) if start_price else 0
summary = f"""
Market Data Summary for {symbol}:
- Period: {period.get("start")} to {period.get("end")}
- Latest Price: ${latest_price:.2f}
- Period Change: ${change:.2f} ({change_pct:+.2f}%)
- Data Points: {len(price_data)}
- Data Quality: {MarketDataParser.get_data_quality(context)}
- Data Source: {MarketDataParser.get_data_source(context)}
""".strip()
return summary
class NewsParser(ContextParser):
"""Parser for NewsContext objects."""
@staticmethod
def get_articles(context: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract articles from news context."""
return context.get("articles", [])
@staticmethod
def get_sentiment_summary(context: dict[str, Any]) -> dict[str, Any]:
"""Extract overall sentiment summary."""
return context.get("sentiment_summary", {})
@staticmethod
def get_sentiment_score(context: dict[str, Any]) -> float:
"""Get overall sentiment score (-1 to 1)."""
sentiment = NewsParser.get_sentiment_summary(context)
return sentiment.get("score", 0.0)
@staticmethod
def get_sentiment_label(context: dict[str, Any]) -> str:
"""Get sentiment label (positive/negative/neutral)."""
sentiment = NewsParser.get_sentiment_summary(context)
return sentiment.get("label", "neutral")
@staticmethod
def format_news_summary(context: dict[str, Any]) -> str:
"""Create a formatted summary of news data for agents."""
symbol = context.get("symbol", "GLOBAL")
period = context.get("period", {})
articles = NewsParser.get_articles(context)
sentiment = NewsParser.get_sentiment_summary(context)
summary = f"""
News Analysis for {symbol}:
- Period: {period.get("start")} to {period.get("end")}
- Articles: {len(articles)}
- Overall Sentiment: {sentiment.get("label", "neutral").upper()} (score: {sentiment.get("score", 0):.2f})
- Confidence: {sentiment.get("confidence", 0):.2f}
- Data Quality: {NewsParser.get_data_quality(context)}
- Sources: {", ".join(context.get("sources", []))}
""".strip()
return summary
@staticmethod
def get_recent_headlines(context: dict[str, Any], limit: int = 5) -> list[str]:
"""Get recent headlines for quick overview."""
articles = NewsParser.get_articles(context)
return [article.get("headline", "") for article in articles[:limit]]
class SocialParser(ContextParser):
"""Parser for SocialContext objects."""
@staticmethod
def get_posts(context: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract posts from social context."""
return context.get("posts", [])
@staticmethod
def get_engagement_metrics(context: dict[str, Any]) -> dict[str, float]:
"""Extract engagement metrics."""
return context.get("engagement_metrics", {})
@staticmethod
def format_social_summary(context: dict[str, Any]) -> str:
"""Create a formatted summary of social media data for agents."""
symbol = context.get("symbol", "GLOBAL")
period = context.get("period", {})
posts = SocialParser.get_posts(context)
engagement = SocialParser.get_engagement_metrics(context)
summary = f"""
Social Media Analysis for {symbol}:
- Period: {period.get("start")} to {period.get("end")}
- Posts: {len(posts)}
- Total Engagement: {engagement.get("total_engagement", 0)}
- Average Engagement: {engagement.get("average_engagement", 0):.1f}
- Data Quality: {SocialParser.get_data_quality(context)}
- Platforms: {", ".join(context.get("platforms", []))}
""".strip()
return summary
class FundamentalParser(ContextParser):
"""Parser for FundamentalContext objects."""
@staticmethod
def get_key_ratios(context: dict[str, Any]) -> dict[str, float]:
"""Extract key financial ratios."""
return context.get("key_ratios", {})
@staticmethod
def get_balance_sheet(context: dict[str, Any]) -> dict[str, Any] | None:
"""Extract balance sheet data."""
return context.get("balance_sheet")
@staticmethod
def get_income_statement(context: dict[str, Any]) -> dict[str, Any] | None:
"""Extract income statement data."""
return context.get("income_statement")
@staticmethod
def format_fundamental_summary(context: dict[str, Any]) -> str:
"""Create a formatted summary of fundamental data for agents."""
symbol = context.get("symbol", "UNKNOWN")
ratios = FundamentalParser.get_key_ratios(context)
key_metrics = []
if "current_ratio" in ratios:
key_metrics.append(f"Current Ratio: {ratios['current_ratio']:.2f}")
if "debt_to_equity" in ratios:
key_metrics.append(f"D/E Ratio: {ratios['debt_to_equity']:.2f}")
if "roe" in ratios:
key_metrics.append(f"ROE: {ratios['roe']:.2%}")
summary = f"""
Fundamental Analysis for {symbol}:
- Key Ratios: {len(ratios)} available
{chr(10).join(["- " + metric for metric in key_metrics[:5]])}
- Data Quality: {FundamentalParser.get_data_quality(context)}
""".strip()
return summary
def create_context_summary(context_json: str, context_type: str = "auto") -> str:
"""
Create a human-readable summary of any context.
Args:
context_json: JSON string from toolkit
context_type: Type of context (auto-detect if not specified)
Returns:
Formatted summary string
"""
try:
context = ContextParser.parse_context(context_json)
if context_type == "auto":
# Auto-detect context type based on fields
if "price_data" in context:
context_type = "market"
elif "articles" in context:
context_type = "news"
elif "posts" in context:
context_type = "social"
elif "key_ratios" in context:
context_type = "fundamental"
# Generate appropriate summary
if context_type == "market":
return MarketDataParser.format_price_summary(context)
elif context_type == "news":
return NewsParser.format_news_summary(context)
elif context_type == "social":
return SocialParser.format_social_summary(context)
elif context_type == "fundamental":
return FundamentalParser.format_fundamental_summary(context)
else:
# Generic summary
symbol = context.get("symbol", "N/A")
data_quality = ContextParser.get_data_quality(context)
data_source = ContextParser.get_data_source(context)
return (
f"Context for {symbol} - Quality: {data_quality}, Source: {data_source}"
)
except Exception as e:
logger.error(f"Error creating context summary: {e}")
return f"Error parsing context: {e}"
# Convenience functions for common operations
def extract_latest_price(market_context_json: str) -> float | None:
"""Quick extraction of latest price from market context."""
context = ContextParser.parse_context(market_context_json)
return MarketDataParser.get_latest_price(context)
def extract_sentiment_score(news_context_json: str) -> float:
"""Quick extraction of sentiment score from news context."""
context = ContextParser.parse_context(news_context_json)
return NewsParser.get_sentiment_score(context)
def is_high_quality_data(context_json: str) -> bool:
"""Quick check if context contains high quality data."""
context = ContextParser.parse_context(context_json)
return ContextParser.is_high_quality(context)
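The convenience functions above close out the module; a short usage sketch, assuming a market context JSON shaped like the services emit (symbol, period, price_data, metadata):

import json

sample = json.dumps({
    "symbol": "AAPL",
    "period": {"start": "2025-07-01", "end": "2025-07-31"},
    "price_data": [{"Close": 210.0}, {"Close": 215.5}],
    "metadata": {"data_quality": "HIGH", "data_source": "live_api"},
})

print(extract_latest_price(sample))    # 215.5
print(is_high_quality_data(sample))    # True
# The "price_data" key makes the auto-detection pick the market formatter.
print(create_context_summary(sample))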

View File

@ -5,20 +5,18 @@ from openai import OpenAI
class FinancialSituationMemory:
def __init__(self, name, config):
if config["backend_url"] == "http://localhost:11434/v1":
if config.backend_url == "http://localhost:11434/v1":
self.embedding = "nomic-embed-text"
else:
self.embedding = "text-embedding-3-small"
self.client = OpenAI(base_url=config["backend_url"])
self.client = OpenAI(base_url=config.backend_url)
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name)
def get_embedding(self, text):
"""Get OpenAI embedding for a text"""
response = self.client.embeddings.create(
model=self.embedding, input=text
)
response = self.client.embeddings.create(model=self.embedding, input=text)
return response.data[0].embedding
def add_situations(self, situations_and_advice):
@ -55,21 +53,42 @@ class FinancialSituationMemory:
)
matched_results = []
for i in range(len(results["documents"][0])):
matched_results.append(
{
"matched_situation": results["documents"][0][i],
"recommendation": results["metadatas"][0][i]["recommendation"],
"similarity_score": 1 - results["distances"][0][i],
}
)
if (
results
and "documents" in results
and results["documents"]
and len(results["documents"]) > 0
):
for i in range(len(results["documents"][0])):
if (
"metadatas" in results
and results["metadatas"]
and len(results["metadatas"]) > 0
and i < len(results["metadatas"][0])
and "distances" in results
and results["distances"]
and len(results["distances"]) > 0
and i < len(results["distances"][0])
):
matched_results.append(
{
"matched_situation": results["documents"][0][i],
"recommendation": results["metadatas"][0][i].get(
"recommendation", ""
),
"similarity_score": 1 - results["distances"][0][i],
}
)
return matched_results
if __name__ == "__main__":
# Example usage
matcher = FinancialSituationMemory()
from tradingagents.config import TradingAgentsConfig
config = TradingAgentsConfig()
matcher = FinancialSituationMemory("example_memory", config)
# Example data
example_data = [
@ -96,7 +115,7 @@ if __name__ == "__main__":
# Example query
current_situation = """
Market showing increased volatility in tech sector, with institutional investors
reducing positions and rising interest rates affecting growth stock valuations
"""

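A hedged round-trip sketch of the memory class as reworked above; it assumes add_situations takes (situation, recommendation) tuples, as the surrounding example data suggests, and that an OpenAI-compatible embeddings backend is reachable at config.backend_url:

from tradingagents.config import TradingAgentsConfig

config = TradingAgentsConfig()
memory = FinancialSituationMemory("example_memory", config)

# Store advice keyed by situation, then retrieve by embedding similarity.
memory.add_situations([
    ("High tech-sector volatility with rising rates",
     "Reduce growth exposure; favor quality balance sheets."),
])
for match in memory.get_memories("Tech sell-off amid rate fears", n_matches=1):
    print(f"{match['similarity_score']:.2f}", match["recommendation"])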
View File

@ -0,0 +1,312 @@
"""
New Toolkit class using Service/Client/Repository architecture with JSON context.
"""
import logging
from datetime import datetime
from typing import TYPE_CHECKING, Annotated, Any
from langchain_core.messages import HumanMessage, RemoveMessage
from langchain_core.tools import tool
from tradingagents.config import TradingAgentsConfig
from tradingagents.services.builders import build_toolkit_services
if TYPE_CHECKING:
from tradingagents.services.market_data_service import MarketDataService
from tradingagents.services.news_service import NewsService
logger = logging.getLogger(__name__)
DEFAULT_CONFIG = TradingAgentsConfig()
def create_msg_delete():
"""Create message deletion function for agents."""
def delete_messages(state):
"""Clear messages and add placeholder for Anthropic compatibility"""
messages = state["messages"]
# Remove all messages
removal_operations = [RemoveMessage(id=m.id) for m in messages]
# Add a minimal placeholder message
placeholder = HumanMessage(content="Continue")
return {"messages": removal_operations + [placeholder]}
return delete_messages
class Toolkit:
"""
New Toolkit class that uses services to provide JSON context to agents.
This replaces the old interface.py approach with structured Pydantic models
that agents can process more dynamically.
"""
def __init__(
self,
config: TradingAgentsConfig | None = None,
services: dict[str, Any] | None = None,
):
"""
Initialize Toolkit with services.
Args:
config: TradingAgents configuration
services: Pre-built services dict, or None to build from config
"""
self.config = config or DEFAULT_CONFIG
if services:
self.services = services
else:
logger.info("Building services from config")
self.services = build_toolkit_services(self.config)
# Set up individual services
self.market_service: MarketDataService | None = self.services.get("market_data")
self.news_service: NewsService | None = self.services.get("news")
logger.info(f"Toolkit initialized with {len(self.services)} services")
# Market Data Tools
def get_market_data(
self,
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve market data context for a given ticker symbol.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: JSON context containing market data with price data and metadata
"""
if not self.market_service:
return self._create_error_context("MarketDataService not available")
try:
context = self.market_service.get_price_context(
symbol, start_date, end_date
)
return context.model_dump_json(indent=2)
except Exception as e:
logger.error(f"Error getting market data for {symbol}: {e}")
return self._create_error_context(f"Error fetching market data: {str(e)}")
@tool
def get_market_data_with_indicators(
self,
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
indicators: Annotated[
str, "Comma-separated list of indicators (e.g. 'rsi,macd,close_50_sma')"
] = "rsi,macd",
) -> str:
"""
Retrieve market data context with technical indicators.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
indicators (str): Comma-separated indicators
Returns:
str: JSON context containing market data with technical indicators
"""
if not self.market_service:
return self._create_error_context("MarketDataService not available")
try:
indicator_list = [i.strip() for i in indicators.split(",") if i.strip()]
context = self.market_service.get_context(
symbol, start_date, end_date, indicators=indicator_list
)
return context.model_dump_json(indent=2)
except Exception as e:
logger.error(f"Error getting market data with indicators for {symbol}: {e}")
return self._create_error_context(
f"Error fetching market data with indicators: {str(e)}"
)
# News Tools
@tool
def get_company_news(
self,
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve news context for a specific company.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: JSON context containing news articles, sentiment analysis, and metadata
"""
if not self.news_service:
return self._create_error_context("NewsService not available")
try:
context = self.news_service.get_company_news_context(
symbol, start_date, end_date
)
return context.model_dump_json(indent=2)
except Exception as e:
logger.error(f"Error getting company news for {symbol}: {e}")
return self._create_error_context(f"Error fetching company news: {str(e)}")
@tool
def get_global_news(
self,
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
categories: Annotated[
str, "Comma-separated news categories (e.g. 'economy,markets,finance')"
] = "economy,markets",
) -> str:
"""
Retrieve global/macro news context.
Args:
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
categories (str): Comma-separated news categories
Returns:
str: JSON context containing global news articles and sentiment analysis
"""
if not self.news_service:
return self._create_error_context("NewsService not available")
try:
category_list = [c.strip() for c in categories.split(",") if c.strip()]
context = self.news_service.get_global_news_context(
start_date, end_date, categories=category_list
)
return context.model_dump_json(indent=2)
except Exception as e:
logger.error(f"Error getting global news: {e}")
return self._create_error_context(f"Error fetching global news: {str(e)}")
@tool
def get_news_by_query(
self,
query: Annotated[str, "Search query for news"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve news context for a specific query.
Args:
query (str): Search query for news
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: JSON context containing news articles and sentiment analysis
"""
if not self.news_service:
return self._create_error_context("NewsService not available")
try:
context = self.news_service.get_context(query, start_date, end_date)
return context.model_dump_json(indent=2)
except Exception as e:
logger.error(f"Error getting news for query '{query}': {e}")
return self._create_error_context(f"Error fetching news: {str(e)}")
# Legacy compatibility methods (return JSON instead of markdown)
@tool
def get_YFin_data(
self,
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Legacy method: Retrieve market data (now returns JSON context).
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: JSON context containing market data
"""
return self.get_market_data(symbol, start_date, end_date)
@tool
def get_finnhub_news(
self,
ticker: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Legacy method: Retrieve company news (now returns JSON context).
Args:
ticker (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: JSON context containing news data
"""
return self.get_company_news(ticker, start_date, end_date)
# Utility methods
def _create_error_context(self, error_message: str) -> str:
"""Create a JSON error context."""
error_context = {
"error": True,
"message": error_message,
"metadata": {
"created_at": datetime.utcnow().isoformat(),
"source": "toolkit",
},
}
import json
return json.dumps(error_context, indent=2)
def get_available_tools(self) -> list:
"""Get list of available tools based on configured services."""
tools = []
if self.market_service:
tools.extend(
[
"get_market_data",
"get_market_data_with_indicators",
"get_YFin_data", # legacy
]
)
if self.news_service:
tools.extend(
[
"get_company_news",
"get_global_news",
"get_news_by_query",
"get_finnhub_news", # legacy
]
)
return tools
def get_toolkit_info(self) -> dict[str, Any]:
"""Get information about the toolkit configuration."""
return {
"toolkit_type": "service_based",
"config": {
"online_mode": self.config.online_tools,
"data_dir": self.config.data_dir,
},
"services": list(self.services.keys()),
"available_tools": self.get_available_tools(),
}
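A quick smoke test for the service-based Toolkit above, assuming build_toolkit_services can assemble the market_data and news services from a default config:

from tradingagents.config import TradingAgentsConfig

toolkit = Toolkit(TradingAgentsConfig())

info = toolkit.get_toolkit_info()
print(info["services"])           # e.g. ['market_data', 'news']
print(info["available_tools"])    # includes the legacy names when present

# Every tool returns a JSON string; failures come back as structured
# error contexts rather than exceptions, so agents can still parse them.
print(toolkit.get_market_data("AAPL", "2025-07-01", "2025-07-31")[:200])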

View File

@ -0,0 +1,622 @@
"""
Service-based toolkit for agents using the new JSON context services.
Replaces the old markdown-based interface.py with structured service calls.
"""
import json
from datetime import datetime, timedelta
from typing import Annotated, Any
from langchain_core.messages import HumanMessage, RemoveMessage
from langchain_core.tools import tool
from tradingagents.config import TradingAgentsConfig
from tradingagents.services.fundamental_data_service import FundamentalDataService
from tradingagents.services.insider_data_service import InsiderDataService
from tradingagents.services.market_data_service import MarketDataService
from tradingagents.services.news_service import NewsService
from tradingagents.services.openai_data_service import OpenAIDataService
from tradingagents.services.social_media_service import SocialMediaService
DEFAULT_CONFIG = TradingAgentsConfig()
def create_msg_delete():
"""Create message deletion function for Anthropic compatibility."""
def delete_messages(state):
"""Clear messages and add placeholder for Anthropic compatibility"""
messages = state["messages"]
# Remove all messages
removal_operations = [RemoveMessage(id=m.id) for m in messages]
# Add a minimal placeholder message
placeholder = HumanMessage(content="Continue")
return {"messages": removal_operations + [placeholder]}
return delete_messages
class ServiceToolkit:
"""Service-based toolkit using the new JSON context services."""
def __init__(self, config: TradingAgentsConfig | None = None):
"""
Initialize the service toolkit.
Args:
config: Configuration object for services
"""
self._config = config or DEFAULT_CONFIG
# Services will be lazily initialized
self._market_service = None
self._news_service = None
self._social_service = None
self._fundamental_service = None
self._insider_service = None
self._openai_service = None
@property
def config(self):
"""Access the configuration."""
return self._config
def update_config(self, config: TradingAgentsConfig):
"""Update the configuration and reset services."""
self._config = config
# Reset services to force re-initialization with new config
self._market_service = None
self._news_service = None
self._social_service = None
self._fundamental_service = None
self._insider_service = None
self._openai_service = None
def _get_market_service(self) -> MarketDataService:
"""Lazy initialization of market data service."""
if self._market_service is None:
# This would typically use a service factory/builder
# For now, return a basic service
from tradingagents.services.builders import create_market_data_service
self._market_service = create_market_data_service(self._config)
return self._market_service
def _get_news_service(self) -> NewsService:
"""Lazy initialization of news service."""
if self._news_service is None:
from tradingagents.services.builders import create_news_service
self._news_service = create_news_service(self._config)
return self._news_service
def _get_social_service(self) -> SocialMediaService:
"""Lazy initialization of social media service."""
if self._social_service is None:
from tradingagents.services.builders import create_social_media_service
self._social_service = create_social_media_service(self._config)
return self._social_service
def _get_fundamental_service(self) -> FundamentalDataService:
"""Lazy initialization of fundamental data service."""
if self._fundamental_service is None:
from tradingagents.services.builders import create_fundamental_data_service
self._fundamental_service = create_fundamental_data_service(self._config)
return self._fundamental_service
def _get_insider_service(self) -> InsiderDataService:
"""Lazy initialization of insider data service."""
if self._insider_service is None:
from tradingagents.services.builders import create_insider_data_service
self._insider_service = create_insider_data_service(self._config)
return self._insider_service
def _get_openai_service(self) -> OpenAIDataService:
"""Lazy initialization of OpenAI data service."""
if self._openai_service is None:
from tradingagents.services.builders import create_openai_data_service
self._openai_service = create_openai_data_service(self._config)
return self._openai_service
def _context_to_string(self, context: Any) -> str:
"""Convert a context object to a formatted string for agents."""
# For now, convert to JSON string
# In the future, we might want more sophisticated formatting
return json.dumps(context.model_dump(), indent=2, default=str)
def _calculate_date_range(
self, curr_date: str, look_back_days: int
) -> tuple[str, str]:
"""Calculate start and end dates from current date and lookback days."""
end_date = datetime.strptime(curr_date, "%Y-%m-%d")
start_date = end_date - timedelta(days=look_back_days)
return start_date.strftime("%Y-%m-%d"), end_date.strftime("%Y-%m-%d")
@tool
def get_reddit_news(
self,
curr_date: Annotated[str, "Date you want to get news for in yyyy-mm-dd format"],
) -> str:
"""
Retrieve global news from Reddit within a specified time frame.
Args:
curr_date (str): Date you want to get news for in yyyy-mm-dd format
Returns:
str: A JSON-formatted context containing the latest global news from Reddit.
"""
start_date, end_date = self._calculate_date_range(curr_date, 7)
social_service = self._get_social_service()
context = social_service.get_global_trends(
start_date=start_date,
end_date=end_date,
subreddits=["news", "worldnews", "Economics"],
)
return self._context_to_string(context)
@tool
def get_finnhub_news(
self,
ticker: Annotated[str, "Ticker of a company, e.g. AAPL, TSM"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve the latest news about a given stock from Finnhub within a date range.
Args:
ticker (str): Ticker of a company. e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: A JSON-formatted context containing news about the company.
"""
news_service = self._get_news_service()
context = news_service.get_company_news_context(
symbol=ticker, start_date=start_date, end_date=end_date, sources=["finnhub"]
)
return self._context_to_string(context)
@tool
def get_reddit_stock_info(
self,
ticker: Annotated[str, "Ticker of a company. e.g. AAPL, TSM"],
curr_date: Annotated[str, "Current date you want to get news for"],
) -> str:
"""
Retrieve the latest news about a given stock from Reddit.
Args:
ticker (str): Ticker of a company. e.g. AAPL, TSM
curr_date (str): current date in yyyy-mm-dd format to get news for
Returns:
str: A JSON-formatted context containing the latest news about the company.
"""
start_date, end_date = self._calculate_date_range(curr_date, 7)
social_service = self._get_social_service()
context = social_service.get_company_social_context(
symbol=ticker,
start_date=start_date,
end_date=end_date,
subreddits=["investing", "stocks", "wallstreetbets"],
)
return self._context_to_string(context)
@tool
def get_YFin_data(
self,
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve the stock price data for a given ticker symbol from Yahoo Finance.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: A JSON-formatted context containing stock price data and technical indicators.
"""
market_service = self._get_market_service()
context = market_service.get_context(
symbol=symbol,
start_date=start_date,
end_date=end_date,
force_refresh=False, # Use local-first strategy
)
return self._context_to_string(context)
@tool
def get_YFin_data_online(
self,
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve fresh stock price data for a given ticker symbol from Yahoo Finance.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: A JSON-formatted context containing fresh stock price data.
"""
market_service = self._get_market_service()
context = market_service.get_context(
symbol=symbol,
start_date=start_date,
end_date=end_date,
force_refresh=True, # Force fresh data
)
return self._context_to_string(context)
@tool
def get_stockstats_indicators_report(
self,
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[
str, "technical indicator to get the analysis and report of"
],
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd"
],
look_back_days: Annotated[int, "how many days to look back"] = 30,
) -> str:
"""
Retrieve stock stats indicators for a given ticker symbol and indicator.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
indicator (str): Technical indicator to get the analysis and report of
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
look_back_days (int): How many days to look back, default is 30
Returns:
str: A JSON-formatted context containing technical indicators.
"""
start_date, end_date = self._calculate_date_range(curr_date, look_back_days)
market_service = self._get_market_service()
context = market_service.get_context(
symbol=symbol,
start_date=start_date,
end_date=end_date,
indicators=[indicator],
force_refresh=False,
)
return self._context_to_string(context)
@tool
def get_stockstats_indicators_report_online(
self,
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[
str, "technical indicator to get the analysis and report of"
],
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd"
],
look_back_days: Annotated[int, "how many days to look back"] = 30,
) -> str:
"""
Retrieve fresh stock stats indicators for a given ticker symbol and indicator.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
indicator (str): Technical indicator to get the analysis and report of
curr_date (str): The current trading date you are trading on, YYYY-mm-dd
look_back_days (int): How many days to look back, default is 30
Returns:
str: A JSON-formatted context containing fresh technical indicators.
"""
start_date, end_date = self._calculate_date_range(curr_date, look_back_days)
market_service = self._get_market_service()
context = market_service.get_context(
symbol=symbol,
start_date=start_date,
end_date=end_date,
indicators=[indicator],
force_refresh=True,
)
return self._context_to_string(context)
@tool
def get_finnhub_company_insider_sentiment(
self,
ticker: Annotated[str, "ticker symbol for the company"],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
) -> str:
"""
Retrieve insider sentiment information about a company for the past 30 days.
Args:
ticker (str): ticker symbol of the company
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: A JSON-formatted context with insider trading sentiment analysis.
"""
start_date, end_date = self._calculate_date_range(curr_date, 30)
insider_service = self._get_insider_service()
context = insider_service.get_insider_context(
symbol=ticker, start_date=start_date, end_date=end_date
)
return self._context_to_string(context)
@tool
def get_finnhub_company_insider_transactions(
self,
ticker: Annotated[str, "ticker symbol"],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
) -> str:
"""
Retrieve insider transaction information about a company for the past 30 days.
Args:
ticker (str): ticker symbol of the company
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: A JSON-formatted context with insider transaction details.
"""
start_date, end_date = self._calculate_date_range(curr_date, 30)
insider_service = self._get_insider_service()
context = insider_service.get_insider_context(
symbol=ticker, start_date=start_date, end_date=end_date
)
return self._context_to_string(context)
@tool
def get_simfin_balance_sheet(
self,
ticker: Annotated[str, "ticker symbol"],
freq: Annotated[str, "reporting frequency: annual/quarterly"],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
) -> str:
"""
Retrieve the most recent balance sheet of a company.
Args:
ticker (str): ticker symbol of the company
freq (str): reporting frequency: annual / quarterly
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: A JSON-formatted context with the company's balance sheet.
"""
# Use a reasonable date range for fundamental data
start_date = (
datetime.strptime(curr_date, "%Y-%m-%d") - timedelta(days=365)
).strftime("%Y-%m-%d")
fundamental_service = self._get_fundamental_service()
context = fundamental_service.get_fundamental_context(
symbol=ticker, start_date=start_date, end_date=curr_date, frequency=freq
)
return self._context_to_string(context)
@tool
def get_simfin_cashflow(
self,
ticker: Annotated[str, "ticker symbol"],
freq: Annotated[str, "reporting frequency: annual/quarterly"],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
) -> str:
"""
Retrieve the most recent cash flow statement of a company.
Args:
ticker (str): ticker symbol of the company
freq (str): reporting frequency: annual / quarterly
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: A JSON-formatted context with the company's cash flow statement.
"""
start_date = (
datetime.strptime(curr_date, "%Y-%m-%d") - timedelta(days=365)
).strftime("%Y-%m-%d")
fundamental_service = self._get_fundamental_service()
context = fundamental_service.get_fundamental_context(
symbol=ticker, start_date=start_date, end_date=curr_date, frequency=freq
)
return self._context_to_string(context)
@tool
def get_simfin_income_stmt(
self,
ticker: Annotated[str, "ticker symbol"],
freq: Annotated[str, "reporting frequency: annual/quarterly"],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
) -> str:
"""
Retrieve the most recent income statement of a company.
Args:
ticker (str): ticker symbol of the company
freq (str): reporting frequency: annual / quarterly
curr_date (str): current date you are trading at, yyyy-mm-dd
Returns:
str: A JSON-formatted context with the company's income statement.
"""
start_date = (
datetime.strptime(curr_date, "%Y-%m-%d") - timedelta(days=365)
).strftime("%Y-%m-%d")
fundamental_service = self._get_fundamental_service()
context = fundamental_service.get_fundamental_context(
symbol=ticker, start_date=start_date, end_date=curr_date, frequency=freq
)
return self._context_to_string(context)
@tool
def get_google_news(
self,
query: Annotated[str, "Query to search with"],
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve the latest news from Google News based on a query and date range.
Args:
query (str): Query to search with
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: A JSON-formatted context containing the latest news from Google News.
"""
start_date, end_date = self._calculate_date_range(curr_date, 7)
news_service = self._get_news_service()
context = news_service.get_context(
query=query, start_date=start_date, end_date=end_date, sources=["google"]
)
return self._context_to_string(context)
@tool
def get_stock_news_openai(
self,
ticker: Annotated[str, "the company's ticker"],
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve AI-generated news analysis about a given stock.
Args:
ticker (str): Ticker of a company. e.g. AAPL, TSM
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: A JSON-formatted context with AI news analysis.
"""
# First get the news context
start_date, end_date = self._calculate_date_range(curr_date, 7)
news_service = self._get_news_service()
news_context = news_service.get_company_news_context(
symbol=ticker, start_date=start_date, end_date=end_date
)
# Then get AI analysis of the news
openai_service = self._get_openai_service()
context = openai_service.get_news_impact_analysis(
symbol=ticker, news_context=self._context_to_string(news_context)
)
return self._context_to_string(context)
@tool
def get_global_news_openai(
self,
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve AI-generated macroeconomic news analysis.
Args:
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: A JSON-formatted context with AI macroeconomic analysis.
"""
start_date, end_date = self._calculate_date_range(curr_date, 7)
# Get global news context
news_service = self._get_news_service()
news_context = news_service.get_global_news_context(
start_date=start_date,
end_date=end_date,
categories=["economy", "markets", "finance"],
)
# Get AI analysis
openai_service = self._get_openai_service()
context = openai_service.get_news_impact_analysis(
symbol="GLOBAL", # Use a placeholder for global analysis
news_context=self._context_to_string(news_context),
)
return self._context_to_string(context)
@tool
def get_fundamentals_openai(
self,
ticker: Annotated[str, "the company's ticker"],
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve AI-generated fundamental analysis about a given stock.
Args:
ticker (str): Ticker of a company. e.g. AAPL, TSM
curr_date (str): Current date in yyyy-mm-dd format
Returns:
str: A JSON-formatted context with AI fundamental analysis.
"""
# Get fundamental data context
start_date = (
datetime.strptime(curr_date, "%Y-%m-%d") - timedelta(days=365)
).strftime("%Y-%m-%d")
fundamental_service = self._get_fundamental_service()
fundamental_context = fundamental_service.get_fundamental_context(
symbol=ticker,
start_date=start_date,
end_date=curr_date,
frequency="quarterly",
)
# Get market data for additional context
market_start = (
datetime.strptime(curr_date, "%Y-%m-%d") - timedelta(days=30)
).strftime("%Y-%m-%d")
market_service = self._get_market_service()
market_context = market_service.get_context(
symbol=ticker, start_date=market_start, end_date=curr_date
)
# Combine contexts for AI analysis
combined_context = {
"fundamental_data": fundamental_context.model_dump(),
"market_data": market_context.model_dump(),
}
# Get AI analysis
openai_service = self._get_openai_service()
context = openai_service.get_market_sentiment_analysis(
symbol=ticker, market_data_context=json.dumps(combined_context, default=str)
)
return self._context_to_string(context)
# Create a default instance for backward compatibility
default_toolkit = ServiceToolkit()
# Export individual tools for use in agents
get_reddit_news = default_toolkit.get_reddit_news
get_finnhub_news = default_toolkit.get_finnhub_news
get_reddit_stock_info = default_toolkit.get_reddit_stock_info
get_YFin_data = default_toolkit.get_YFin_data
get_YFin_data_online = default_toolkit.get_YFin_data_online
get_stockstats_indicators_report = default_toolkit.get_stockstats_indicators_report
get_stockstats_indicators_report_online = (
default_toolkit.get_stockstats_indicators_report_online
)
get_finnhub_company_insider_sentiment = (
default_toolkit.get_finnhub_company_insider_sentiment
)
get_finnhub_company_insider_transactions = (
default_toolkit.get_finnhub_company_insider_transactions
)
get_simfin_balance_sheet = default_toolkit.get_simfin_balance_sheet
get_simfin_cashflow = default_toolkit.get_simfin_cashflow
get_simfin_income_stmt = default_toolkit.get_simfin_income_stmt
get_google_news = default_toolkit.get_google_news
get_stock_news_openai = default_toolkit.get_stock_news_openai
get_global_news_openai = default_toolkit.get_global_news_openai
get_fundamentals_openai = default_toolkit.get_fundamentals_openai
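Because the exports above are bound methods of default_toolkit, swapping its config redirects every module-level tool without re-importing. A small sketch (the assert deliberately pokes a private helper just to show the lazy caching):

from tradingagents.config import TradingAgentsConfig

default_toolkit.update_config(TradingAgentsConfig())  # resets cached services

# Services build on first use and are reused until the next update_config.
market = default_toolkit._get_market_service()
assert market is default_toolkit._get_market_service()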

View File

@ -0,0 +1,267 @@
"""
Updated Toolkit class using Service/Client/Repository architecture with JSON context.
"""
import logging
from datetime import datetime
from typing import TYPE_CHECKING, Annotated, Any
from langchain_core.messages import HumanMessage, RemoveMessage
from langchain_core.tools import tool
from tradingagents.config import TradingAgentsConfig
from tradingagents.services.builders import build_toolkit_services
if TYPE_CHECKING:
from tradingagents.services.market_data_service import MarketDataService
from tradingagents.services.news_service import NewsService
logger = logging.getLogger(__name__)
DEFAULT_CONFIG = TradingAgentsConfig()
def create_msg_delete():
"""Create message deletion function for agents."""
def delete_messages(state):
"""Clear messages and add placeholder for Anthropic compatibility"""
messages = state["messages"]
# Remove all messages
removal_operations = [RemoveMessage(id=m.id) for m in messages]
# Add a minimal placeholder message
placeholder = HumanMessage(content="Continue")
return {"messages": removal_operations + [placeholder]}
return delete_messages
class Toolkit:
"""
Toolkit class that uses services to provide JSON context to agents.
This replaces the old interface.py approach with structured Pydantic models
that agents can process more dynamically.
"""
def __init__(
self,
config: TradingAgentsConfig | None = None,
services: dict[str, Any] | None = None,
):
"""
Initialize Toolkit with services.
Args:
config: TradingAgents configuration
services: Pre-built services dict, or None to build from config
"""
self.config = config or DEFAULT_CONFIG
if services:
self.services = services
else:
logger.info("Building services from config")
self.services = build_toolkit_services(self.config)
# Set up individual services
self.market_service: MarketDataService | None = self.services.get("market_data")
self.news_service: NewsService | None = self.services.get("news")
logger.info(f"Toolkit initialized with {len(self.services)} services")
# Create tool methods as static methods with service access via closure
def _create_market_data_tool(self):
"""Create market data tool with service access."""
market_service = self.market_service
@tool
def get_market_data(
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve market data context for a given ticker symbol.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: JSON context containing market data with price data and metadata
"""
if not market_service:
return _create_error_context("MarketDataService not available")
try:
context = market_service.get_price_context(symbol, start_date, end_date)
return context.model_dump_json(indent=2)
except Exception as e:
logger.error(f"Error getting market data for {symbol}: {e}")
return _create_error_context(f"Error fetching market data: {str(e)}")
return get_market_data
def _create_market_indicators_tool(self):
"""Create market data with indicators tool."""
market_service = self.market_service
@tool
def get_market_data_with_indicators(
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
indicators: Annotated[
str, "Comma-separated list of indicators (e.g. 'rsi,macd,close_50_sma')"
] = "rsi,macd",
) -> str:
"""
Retrieve market data context with technical indicators.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
indicators (str): Comma-separated indicators
Returns:
str: JSON context containing market data with technical indicators
"""
if not market_service:
return _create_error_context("MarketDataService not available")
try:
indicator_list = [i.strip() for i in indicators.split(",") if i.strip()]
context = market_service.get_context(
symbol, start_date, end_date, indicators=indicator_list
)
return context.model_dump_json(indent=2)
except Exception as e:
logger.error(
f"Error getting market data with indicators for {symbol}: {e}"
)
return _create_error_context(
f"Error fetching market data with indicators: {str(e)}"
)
return get_market_data_with_indicators
def _create_company_news_tool(self):
"""Create company news tool."""
news_service = self.news_service
@tool
def get_company_news(
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> str:
"""
Retrieve news context for a specific company.
Args:
symbol (str): Ticker symbol of the company, e.g. AAPL, TSM
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns:
str: JSON context containing news articles, sentiment analysis, and metadata
"""
if not news_service:
return _create_error_context("NewsService not available")
try:
context = news_service.get_company_news_context(
symbol, start_date, end_date
)
return context.model_dump_json(indent=2)
except Exception as e:
logger.error(f"Error getting company news for {symbol}: {e}")
return _create_error_context(f"Error fetching company news: {str(e)}")
return get_company_news
def _create_global_news_tool(self):
"""Create global news tool."""
news_service = self.news_service
@tool
def get_global_news(
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
categories: Annotated[
str, "Comma-separated news categories (e.g. 'economy,markets,finance')"
] = "economy,markets",
) -> str:
"""
Retrieve global/macro news context.
Args:
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
categories (str): Comma-separated news categories
Returns:
str: JSON context containing global news articles and sentiment analysis
"""
if not news_service:
return _create_error_context("NewsService not available")
try:
category_list = [c.strip() for c in categories.split(",") if c.strip()]
context = news_service.get_global_news_context(
start_date, end_date, categories=category_list
)
return context.model_dump_json(indent=2)
except Exception as e:
logger.error(f"Error getting global news: {e}")
return _create_error_context(f"Error fetching global news: {str(e)}")
return get_global_news
def get_tools(self):
"""Get all available tools as LangChain tools."""
tools = []
if self.market_service:
tools.append(self._create_market_data_tool())
tools.append(self._create_market_indicators_tool())
if self.news_service:
tools.append(self._create_company_news_tool())
tools.append(self._create_global_news_tool())
return tools
def get_available_tools(self) -> list:
"""Get list of available tool names based on configured services."""
tools = []
if self.market_service:
tools.extend(["get_market_data", "get_market_data_with_indicators"])
if self.news_service:
tools.extend(["get_company_news", "get_global_news"])
return tools
def get_toolkit_info(self) -> dict[str, Any]:
"""Get information about the toolkit configuration."""
return {
"toolkit_type": "service_based",
"config": {
"online_mode": self.config.online_tools,
"data_dir": self.config.data_dir,
},
"services": list(self.services.keys()),
"available_tools": self.get_available_tools(),
}
def _create_error_context(error_message: str) -> str:
"""Create a JSON error context."""
import json
error_context = {
"error": True,
"message": error_message,
"metadata": {"created_at": datetime.utcnow().isoformat(), "source": "toolkit"},
}
return json.dumps(error_context, indent=2)
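For reference, a failed tool call returns a compact JSON payload; a minimal sketch of the output (not part of the commit; the timestamp is illustrative):
# Illustrative only: the payload produced when a backing service is missing.
print(_create_error_context("NewsService not available"))
# {
#   "error": true,
#   "message": "NewsService not available",
#   "metadata": {
#     "created_at": "2025-07-31T16:00:01.000000",
#     "source": "toolkit"
#   }
# }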

View File

@ -0,0 +1,19 @@
"""
Client classes for live data access in TradingAgents.
"""
# Re-export existing clients from dataflows
from tradingagents.dataflows.reddit_utils import RedditClient
from .base import BaseClient
from .finnhub_client import FinnhubClient
from .google_news_client import GoogleNewsClient
from .yfinance_client import YFinanceClient
__all__ = [
"BaseClient",
"YFinanceClient",
"GoogleNewsClient",
"FinnhubClient",
"RedditClient",
]
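A minimal usage sketch for the re-exported clients (not part of this commit; FINNHUB_API_KEY is read the same way the tests below read it):
import os

from tradingagents.clients import FinnhubClient, GoogleNewsClient, YFinanceClient

# Instantiate each client and report which data sources are reachable.
clients = [
    FinnhubClient(api_key=os.getenv("FINNHUB_API_KEY", "")),
    GoogleNewsClient(),
    YFinanceClient(),
]
for client in clients:
    info = client.get_client_info()
    print(f"{info['client_type']}: reachable={client.test_connection()}")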

View File

@ -0,0 +1,100 @@
"""
Base client abstraction for live data access.
"""
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any
class BaseClient(ABC):
"""
Base class for all data clients that access live APIs.
Provides common interface for different data sources while allowing
each client to implement its specific data fetching logic.
"""
def __init__(self, **kwargs):
"""Initialize client with configuration."""
self.config = kwargs
@abstractmethod
def test_connection(self) -> bool:
"""
Test if the client can connect to its data source.
Returns:
bool: True if connection is successful, False otherwise
"""
pass
@abstractmethod
def get_data(self, *args, **kwargs) -> dict[str, Any]:
"""
Get data from the client's data source.
Args:
*args: Positional arguments
**kwargs: Client-specific parameters
Returns:
Dict[str, Any]: Raw data from the source
"""
pass
def get_available_symbols(self) -> list[str]:
"""
Get list of available symbols/tickers from this data source.
Returns:
List[str]: Available symbols, empty list if not supported
"""
return []
def get_data_range(
self, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
"""
Get data for a specific date range.
Args:
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
**kwargs: Additional client-specific parameters
Returns:
Dict[str, Any]: Data for the specified range
"""
return self.get_data(start_date=start_date, end_date=end_date, **kwargs)
def validate_date_range(self, start_date: str, end_date: str) -> bool:
"""
Validate that the date range is acceptable.
Args:
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
Returns:
bool: True if date range is valid
"""
try:
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
return start_dt <= end_dt <= datetime.now()
except ValueError:
return False
def get_client_info(self) -> dict[str, Any]:
"""
Get information about this client.
Returns:
Dict[str, Any]: Client metadata
"""
return {
"client_type": self.__class__.__name__,
"supports_symbols": len(self.get_available_symbols()) > 0,
"config": self.config,
}
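To illustrate the contract, here is a minimal concrete subclass; a sketch assuming a hypothetical in-memory source. Only test_connection() and get_data() must be implemented; the range and validation helpers come for free:
from typing import Any

from tradingagents.clients.base import BaseClient


class StaticClient(BaseClient):
    """Hypothetical client that echoes requests instead of fetching data."""

    def test_connection(self) -> bool:
        return True  # nothing to connect to

    def get_data(self, *args, **kwargs) -> dict[str, Any]:
        # A real client would call its API here.
        return {"args": args, "params": kwargs}


client = StaticClient(source="in-memory")
assert client.validate_date_range("2024-01-01", "2024-01-31")
print(client.get_data_range("2024-01-01", "2024-01-31", symbol="AAPL"))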

View File

@ -0,0 +1,32 @@
"""
Pytest configuration for FinnhubClient tests with VCR.
"""
import pytest
@pytest.fixture(scope="module")
def vcr_config():
"""Configure VCR for recording/replaying HTTP interactions."""
return {
# Don't record the API key in cassettes
"filter_headers": ["X-Finnhub-Token", "Authorization"],
# Record once, then replay from cassettes
"record_mode": "once",
# Match requests on URI and method
"match_on": ["uri", "method"],
# Decode compressed responses for better readability
"decode_compressed_response": True,
# Store cassettes in the cassettes subdirectory
"cassette_library_dir": "tradingagents/clients/cassettes",
# Ignore localhost requests
"ignore_localhost": True,
# Custom serializer for better readability
"serializer": "yaml",
}
@pytest.fixture(scope="session")
def vcr_cassette_dir(tmp_path_factory):
"""Create temporary directory for VCR cassettes during testing."""
return tmp_path_factory.mktemp("cassettes")

View File

@ -0,0 +1,238 @@
"""
Finnhub client for financial data access.
"""
import logging
from datetime import date
from typing import Any
import finnhub
logger = logging.getLogger(__name__)
class FinnhubClient:
"""
Finnhub API client for accessing financial data including fundamental data.
Provides access to:
- Company news
- Insider transactions
- Insider sentiment
- Real-time quotes
- Company profiles
- Financial statements (balance sheet, income statement, cash flow)
"""
def __init__(self, api_key: str):
"""
Initialize Finnhub client with official SDK.
Args:
api_key: Finnhub API key
"""
self.api_key = api_key
self.client = finnhub.Client(api_key=api_key)
def test_connection(self) -> bool:
"""Test if the Finnhub API connection is working."""
try:
# Test with a simple quote request
response = self.client.quote("AAPL")
return "c" in response # 'c' is current price field
except Exception as e:
logger.error(f"Finnhub connection test failed: {e}")
return False
def get_balance_sheet(
self, symbol: str, frequency: str, report_date: date
) -> dict[str, Any]:
"""
Get balance sheet data from Finnhub.
Args:
symbol: Stock symbol (e.g., 'AAPL')
frequency: Reporting frequency ('quarterly' or 'annual')
report_date: Report date as date object
Returns:
Balance sheet data from Finnhub API
"""
try:
# Finnhub SDK expects frequency as 'quarterly' or 'annual'.
# Note: financials_reported returns the full reported filings; report_date
# is accepted for interface parity with the service layer but is not passed
# to the API call.
freq = "quarterly" if frequency.lower() in ["quarterly", "q"] else "annual"
response = self.client.financials_reported(symbol=symbol.upper(), freq=freq)
return response if isinstance(response, dict) else {"data": []}
except Exception as e:
logger.error(f"Error fetching balance sheet for {symbol}: {e}")
return {"data": []}
def get_income_statement(
self, symbol: str, frequency: str, report_date: date
) -> dict[str, Any]:
"""
Get income statement data from Finnhub.
Args:
symbol: Stock symbol (e.g., 'AAPL')
frequency: Reporting frequency ('quarterly' or 'annual')
report_date: Report date as date object
Returns:
Income statement data from Finnhub API
"""
try:
freq = "quarterly" if frequency.lower() in ["quarterly", "q"] else "annual"
response = self.client.financials_reported(symbol=symbol.upper(), freq=freq)
return response if isinstance(response, dict) else {"data": []}
except Exception as e:
logger.error(f"Error fetching income statement for {symbol}: {e}")
return {"data": []}
def get_cash_flow(
self, symbol: str, frequency: str, report_date: date
) -> dict[str, Any]:
"""
Get cash flow statement data from Finnhub.
Args:
symbol: Stock symbol (e.g., 'AAPL')
frequency: Reporting frequency ('quarterly' or 'annual')
report_date: Report date as date object
Returns:
Cash flow statement data from Finnhub API
"""
try:
freq = "quarterly" if frequency.lower() in ["quarterly", "q"] else "annual"
response = self.client.financials_reported(symbol=symbol.upper(), freq=freq)
return response if isinstance(response, dict) else {"data": []}
except Exception as e:
logger.error(f"Error fetching cash flow for {symbol}: {e}")
return {"data": []}
def get_company_news(
self, symbol: str, start_date: date, end_date: date
) -> list[dict[str, Any]]:
"""
Get company news for a specific symbol and date range.
Args:
symbol: Stock symbol (e.g., 'AAPL')
start_date: Start date as date object
end_date: End date as date object
Returns:
List of news articles
"""
# Convert date objects to strings for API
start_str = start_date.isoformat()
end_str = end_date.isoformat()
try:
response = self.client.company_news(
symbol.upper(), _from=start_str, to=end_str
)
return response if isinstance(response, list) else []
except Exception as e:
logger.error(f"Error fetching news for {symbol}: {e}")
return []
def get_insider_transactions(
self, symbol: str, start_date: date, end_date: date
) -> dict[str, Any]:
"""
Get insider transactions for a company.
Args:
symbol: Stock symbol
start_date: Start date as date object
end_date: End date as date object
Returns:
Insider transaction data
"""
start_str = start_date.isoformat()
end_str = end_date.isoformat()
try:
response = self.client.stock_insider_transactions(
symbol.upper(), _from=start_str, to=end_str
)
return response if isinstance(response, dict) else {"data": []}
except Exception as e:
logger.error(f"Error fetching insider transactions for {symbol}: {e}")
return {"data": []}
def get_insider_sentiment(
self, symbol: str, start_date: date, end_date: date
) -> dict[str, Any]:
"""
Get insider sentiment data for a company.
Args:
symbol: Stock symbol
start_date: Start date as date object
end_date: End date as date object
Returns:
Insider sentiment data
"""
start_str = start_date.isoformat()
end_str = end_date.isoformat()
try:
response = self.client.stock_insider_sentiment(
symbol.upper(), _from=start_str, to=end_str
)
return response if isinstance(response, dict) else {"data": []}
except Exception as e:
logger.error(f"Error fetching insider sentiment for {symbol}: {e}")
return {"data": []}
def get_quote(self, symbol: str) -> dict[str, Any]:
"""
Get current quote for a symbol.
Args:
symbol: Stock symbol
Returns:
Quote data with current price, change, etc.
"""
try:
response = self.client.quote(symbol.upper())
return response if isinstance(response, dict) else {}
except Exception as e:
logger.error(f"Error fetching quote for {symbol}: {e}")
return {}
def get_company_profile(self, symbol: str) -> dict[str, Any]:
"""
Get company profile information.
Args:
symbol: Stock symbol
Returns:
Company profile data
"""
try:
response = self.client.company_profile2(symbol=symbol.upper())
return response if isinstance(response, dict) else {}
except Exception as e:
logger.error(f"Error fetching company profile for {symbol}: {e}")
return {}
def get_client_info(self) -> dict[str, Any]:
"""
Get information about this client.
Returns:
Dict[str, Any]: Client metadata
"""
return {
"client_type": self.__class__.__name__,
"api_key_set": bool(self.api_key),
"sdk_version": getattr(finnhub, "__version__", "unknown"),
}
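A usage sketch (not part of the commit; assumes FINNHUB_API_KEY is set in the environment):
import os
from datetime import date

from tradingagents.clients.finnhub_client import FinnhubClient

client = FinnhubClient(api_key=os.environ["FINNHUB_API_KEY"])
if client.test_connection():
    quote = client.get_quote("AAPL")  # e.g. {"c": <last price>, "d": <change>, ...}
    news = client.get_company_news("AAPL", date(2024, 1, 1), date(2024, 1, 31))
    print(f"last price: {quote.get('c')}, {len(news)} January articles")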

View File

@ -0,0 +1,210 @@
"""
Google News client for live news data via web scraping.
"""
import logging
from datetime import datetime, timedelta
from typing import Any
from tradingagents.dataflows.googlenews_utils import getNewsData
from .base import BaseClient
logger = logging.getLogger(__name__)
class GoogleNewsClient(BaseClient):
"""Client for Google News data via web scraping."""
def __init__(self, **kwargs):
"""
Initialize Google News client.
Args:
**kwargs: Configuration options including rate limits
"""
super().__init__(**kwargs)
self.max_retries = kwargs.get("max_retries", 3)
self.delay_between_requests = kwargs.get("delay_between_requests", 1.0)
def test_connection(self) -> bool:
"""Test Google News connection by fetching a simple query."""
try:
# Test with a simple query for recent news
end_date = datetime.now().strftime("%Y-%m-%d")
start_date = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d")
test_data = getNewsData("technology", start_date, end_date)
return isinstance(test_data, list)
except Exception as e:
logger.error(f"Google News connection test failed: {e}")
return False
def get_data(
self, query: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
"""
Get news data for a query and date range.
Args:
query: Search query
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
**kwargs: Additional parameters
Returns:
Dict[str, Any]: News data with metadata
"""
if not self.validate_date_range(start_date, end_date):
raise ValueError(f"Invalid date range: {start_date} to {end_date}")
try:
# Replace spaces with + for URL encoding
formatted_query = query.replace(" ", "+")
logger.info(
f"Fetching Google News for query: {query} from {start_date} to {end_date}"
)
news_results = getNewsData(formatted_query, start_date, end_date)
if not news_results:
logger.warning(f"No news found for query: {query}")
return {
"query": query,
"period": {"start": start_date, "end": end_date},
"articles": [],
"metadata": {
"source": "google_news",
"empty": True,
"reason": "no_articles_found",
},
}
# Process and standardize article data
processed_articles = []
for article in news_results:
processed_article = {
"headline": article.get("title", ""),
"summary": article.get("snippet", ""),
"url": article.get("link", ""),
"source": article.get("source", "Unknown"),
"date": article.get(
"date", end_date
), # Fallback to end_date if no date
"entities": article.get("entities", []),
}
processed_articles.append(processed_article)
return {
"query": query,
"period": {"start": start_date, "end": end_date},
"articles": processed_articles,
"metadata": {
"source": "google_news",
"article_count": len(processed_articles),
"retrieved_at": datetime.utcnow().isoformat(),
"search_query": formatted_query,
},
}
except Exception as e:
logger.error(f"Error fetching Google News for query '{query}': {e}")
raise
def get_company_news(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
"""
Get news data specific to a company symbol.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
**kwargs: Additional parameters
Returns:
Dict[str, Any]: Company-specific news data
"""
# Create company-focused search query
company_query = f"{symbol} stock"
result = self.get_data(company_query, start_date, end_date, **kwargs)
result["symbol"] = symbol
result["metadata"]["query_type"] = "company_specific"
return result
def get_global_news(
self,
start_date: str,
end_date: str,
categories: list[str] | None = None,
**kwargs,
) -> dict[str, Any]:
"""
Get global/macro news that might affect markets.
Args:
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
categories: List of news categories to search
**kwargs: Additional parameters
Returns:
Dict[str, Any]: Global news data
"""
if categories is None:
categories = ["economy", "finance", "markets", "business"]
all_articles = []
for category in categories:
try:
category_data = self.get_data(category, start_date, end_date, **kwargs)
# Add category tag to each article
for article in category_data.get("articles", []):
article["category"] = category
all_articles.extend(category_data.get("articles", []))
except Exception as e:
logger.warning(f"Failed to fetch news for category '{category}': {e}")
continue
return {
"query": "global_news",
"categories": categories,
"period": {"start": start_date, "end": end_date},
"articles": all_articles,
"metadata": {
"source": "google_news",
"article_count": len(all_articles),
"categories_searched": categories,
"retrieved_at": datetime.utcnow().isoformat(),
"query_type": "global_news",
},
}
def get_available_categories(self) -> list[str]:
"""
Get list of commonly used news categories.
Returns:
List[str]: News categories
"""
return [
"business",
"economy",
"finance",
"markets",
"technology",
"politics",
"world",
"healthcare",
"energy",
"crypto",
]
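A usage sketch (not part of the commit; dates are illustrative and results depend on what Google News serves at scrape time):
from tradingagents.clients.google_news_client import GoogleNewsClient

client = GoogleNewsClient(max_retries=3, delay_between_requests=1.0)

company = client.get_company_news("AAPL", "2024-01-01", "2024-01-31")
print(company["symbol"], company["metadata"]["article_count"], "articles")

macro = client.get_global_news("2024-01-01", "2024-01-31", categories=["economy"])
print(len(macro["articles"]), "macro articles")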

View File

@ -0,0 +1,317 @@
#!/usr/bin/env python3
"""
Comprehensive tests for FinnhubClient using pytest-vcr.
This test suite records real API interactions with Finnhub and replays them
in subsequent runs, providing realistic testing without network dependencies.
"""
import os
from datetime import date
import pytest
from tradingagents.clients.finnhub_client import FinnhubClient
@pytest.fixture
def finnhub_client():
"""Create FinnhubClient with test API key."""
# Use environment variable or test key for VCR recording
api_key = os.getenv("FINNHUB_API_KEY", "test_api_key")
return FinnhubClient(api_key=api_key)
@pytest.fixture
def vcr_config():
"""Configure VCR with proper settings."""
return {
"filter_headers": ["X-Finnhub-Token"], # Filter out API key from recordings
"record_mode": "once", # Record once, then replay
"match_on": ["uri", "method"],
"decode_compressed_response": True,
}
class TestFinnhubClientConnection:
"""Test client connection and initialization."""
@pytest.mark.vcr
def test_client_initialization(self, finnhub_client):
"""Test that client initializes correctly."""
assert finnhub_client.api_key is not None
assert finnhub_client.client is not None
@pytest.mark.vcr
def test_connection_success(self, finnhub_client):
"""Test successful API connection."""
assert finnhub_client.test_connection() is True
def test_client_info(self, finnhub_client):
"""Test client metadata."""
info = finnhub_client.get_client_info()
assert info["client_type"] == "FinnhubClient"
assert info["api_key_set"] is True
class TestFundamentalDataMethods:
"""Test the fundamental data methods needed by FundamentalDataService."""
@pytest.mark.vcr
def test_get_balance_sheet_quarterly(self, finnhub_client):
"""Test balance sheet retrieval with date object."""
test_date = date(2024, 1, 1)
data = finnhub_client.get_balance_sheet("AAPL", "quarterly", test_date)
assert isinstance(data, dict)
# Finnhub financials_reported returns data structure with 'data' key
assert "data" in data or len(data) > 0
@pytest.mark.vcr
def test_get_balance_sheet_annual(self, finnhub_client):
"""Test annual balance sheet retrieval."""
test_date = date(2024, 1, 1)
data = finnhub_client.get_balance_sheet("AAPL", "annual", test_date)
assert isinstance(data, dict)
@pytest.mark.vcr
def test_get_income_statement_quarterly(self, finnhub_client):
"""Test income statement retrieval."""
test_date = date(2024, 1, 1)
data = finnhub_client.get_income_statement("AAPL", "quarterly", test_date)
assert isinstance(data, dict)
@pytest.mark.vcr
def test_get_income_statement_annual(self, finnhub_client):
"""Test annual income statement retrieval."""
test_date = date(2024, 1, 1)
data = finnhub_client.get_income_statement("AAPL", "annual", test_date)
assert isinstance(data, dict)
@pytest.mark.vcr
def test_get_cash_flow_with_date_object(self, finnhub_client):
"""Test cash flow retrieval with date object."""
test_date = date(2024, 3, 31)
data = finnhub_client.get_cash_flow("AAPL", "quarterly", test_date)
assert isinstance(data, dict)
@pytest.mark.vcr
def test_fundamental_data_different_symbols(self, finnhub_client):
"""Test fundamental data for different symbols."""
symbols = ["AAPL", "MSFT", "GOOGL"]
test_date = date(2024, 1, 1)
for symbol in symbols:
data = finnhub_client.get_balance_sheet(symbol, "quarterly", test_date)
assert isinstance(data, dict)
class TestExistingMethods:
"""Test existing methods with enhanced date support."""
@pytest.mark.vcr
def test_get_company_news_january(self, finnhub_client):
"""Test company news for January 2024."""
start_date = date(2024, 1, 1)
end_date = date(2024, 1, 31)
news = finnhub_client.get_company_news("AAPL", start_date, end_date)
assert isinstance(news, list)
@pytest.mark.vcr
def test_get_company_news_date_objects(self, finnhub_client):
"""Test company news with date objects."""
start_date = date(2024, 1, 1)
end_date = date(2024, 1, 31)
news = finnhub_client.get_company_news("AAPL", start_date, end_date)
assert isinstance(news, list)
@pytest.mark.vcr
def test_get_insider_transactions_date_objects(self, finnhub_client):
"""Test insider transactions with date objects."""
start_date = date(2024, 1, 1)
end_date = date(2024, 1, 31)
data = finnhub_client.get_insider_transactions("AAPL", start_date, end_date)
assert isinstance(data, dict)
assert "data" in data
@pytest.mark.vcr
def test_get_insider_sentiment(self, finnhub_client):
"""Test insider sentiment with date objects."""
start_date = date(2024, 1, 1)
end_date = date(2024, 1, 31)
data = finnhub_client.get_insider_sentiment("AAPL", start_date, end_date)
assert isinstance(data, dict)
@pytest.mark.vcr
def test_get_quote(self, finnhub_client):
"""Test stock quote retrieval."""
quote = finnhub_client.get_quote("AAPL")
assert isinstance(quote, dict)
# Quote should contain current price 'c' field
if quote: # Only check if we got data
assert "c" in quote
@pytest.mark.vcr
def test_get_company_profile(self, finnhub_client):
"""Test company profile retrieval."""
profile = finnhub_client.get_company_profile("AAPL")
assert isinstance(profile, dict)
class TestErrorHandling:
"""Test error handling and edge cases."""
@pytest.mark.vcr
def test_invalid_symbol_balance_sheet(self, finnhub_client):
"""Test balance sheet with invalid symbol."""
test_date = date(2024, 1, 1)
data = finnhub_client.get_balance_sheet(
"INVALID_SYMBOL_XYZ", "quarterly", test_date
)
# Should return dict with empty data structure, not raise exception
assert isinstance(data, dict)
assert "data" in data
@pytest.mark.vcr
def test_invalid_symbol_news(self, finnhub_client):
"""Test news with invalid symbol."""
start_date = date(2024, 1, 1)
end_date = date(2024, 1, 31)
news = finnhub_client.get_company_news(
"INVALID_SYMBOL_XYZ", start_date, end_date
)
# Should return empty list, not raise exception
assert isinstance(news, list)
def test_connection_with_invalid_api_key(self):
"""Test connection failure with invalid API key."""
client = FinnhubClient("invalid_api_key")
# This should return False, not raise an exception
assert client.test_connection() is False
@pytest.mark.vcr
def test_frequency_normalization(self, finnhub_client):
"""Test that frequency parameters are normalized correctly."""
# Test different frequency formats
frequencies = ["quarterly", "QUARTERLY", "q", "Q", "annual", "ANNUAL", "a", "A"]
test_date = date(2024, 1, 1)
for freq in frequencies:
data = finnhub_client.get_balance_sheet("AAPL", freq, test_date)
assert isinstance(data, dict)
def test_date_edge_cases(self, finnhub_client):
"""Test date edge cases."""
# Test with year-end date
test_date = date(2024, 12, 31)
# This shouldn't raise exceptions
# (actual API calls are mocked/recorded)
data = finnhub_client.get_balance_sheet("AAPL", "quarterly", test_date)
assert isinstance(data, dict)
class TestMultipleSymbolsAndTimeframes:
"""Test with multiple symbols and different timeframes."""
@pytest.mark.vcr
def test_multiple_symbols_balance_sheet(self, finnhub_client):
"""Test balance sheet for multiple major symbols."""
symbols = ["AAPL", "MSFT", "GOOGL", "TSLA"]
test_date = date(2024, 1, 1)
for symbol in symbols:
data = finnhub_client.get_balance_sheet(symbol, "quarterly", test_date)
assert isinstance(data, dict)
@pytest.mark.vcr
def test_quarterly_vs_annual_frequency(self, finnhub_client):
"""Test quarterly vs annual frequency."""
symbol = "AAPL"
test_date = date(2024, 1, 1)
quarterly_data = finnhub_client.get_balance_sheet(
symbol, "quarterly", test_date
)
annual_data = finnhub_client.get_balance_sheet(symbol, "annual", test_date)
assert isinstance(quarterly_data, dict)
assert isinstance(annual_data, dict)
@pytest.mark.vcr
def test_all_fundamental_methods_same_symbol(self, finnhub_client):
"""Test all fundamental methods for the same symbol."""
symbol = "AAPL"
frequency = "quarterly"
report_date = date(2024, 1, 1)
balance_sheet = finnhub_client.get_balance_sheet(symbol, frequency, report_date)
income_statement = finnhub_client.get_income_statement(
symbol, frequency, report_date
)
cash_flow = finnhub_client.get_cash_flow(symbol, frequency, report_date)
assert isinstance(balance_sheet, dict)
assert isinstance(income_statement, dict)
assert isinstance(cash_flow, dict)
# Integration test to verify the client works with service expectations
class TestServiceIntegration:
"""Test that client works as expected by FundamentalDataService."""
def test_service_expected_methods_exist(self, finnhub_client):
"""Test that all methods expected by FundamentalDataService exist."""
# These are the methods the service calls
assert hasattr(finnhub_client, "get_balance_sheet")
assert hasattr(finnhub_client, "get_income_statement")
assert hasattr(finnhub_client, "get_cash_flow")
# Verify method signatures accept the expected parameters
import inspect
for method_name in [
"get_balance_sheet",
"get_income_statement",
"get_cash_flow",
]:
method = getattr(finnhub_client, method_name)
sig = inspect.signature(method)
params = list(sig.parameters.keys())
# Service calls: method(symbol, frequency, date)
assert "symbol" in params
assert "frequency" in params
assert "report_date" in params
@pytest.mark.vcr
def test_data_format_for_service_conversion(self, finnhub_client):
"""Test that returned data format can be processed by service conversion method."""
test_date = date(2024, 1, 1)
data = finnhub_client.get_balance_sheet("AAPL", "quarterly", test_date)
# Service expects either empty dict or dict with data
assert isinstance(data, dict)
# The service's _convert_to_financial_statement expects either:
# 1. Empty/falsy data -> returns None
# 2. Dict with "data" key containing the actual financial data
if data: # If we got data
# Should be able to access data["data"] or similar structure
# This validates the format is compatible with service expectations
assert isinstance(data, dict)

View File

@ -0,0 +1,245 @@
"""
Yahoo Finance client for live market data.
"""
import logging
from datetime import datetime, timedelta
from typing import Any
import pandas as pd
import yfinance as yf
from .base import BaseClient
logger = logging.getLogger(__name__)
class YFinanceClient(BaseClient):
"""Client for Yahoo Finance API using yfinance library."""
def __init__(self, **kwargs):
"""
Initialize Yahoo Finance client.
Args:
**kwargs: Configuration options
"""
super().__init__(**kwargs)
self.session = None
def test_connection(self) -> bool:
"""Test Yahoo Finance connection by fetching a known ticker."""
try:
ticker = yf.Ticker("AAPL")
info = ticker.info
return bool(info and "symbol" in info)
except Exception as e:
logger.error(f"Yahoo Finance connection test failed: {e}")
return False
def get_data(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
"""
Get historical price data for a symbol.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
**kwargs: Additional parameters
Returns:
Dict[str, Any]: Price data with metadata
"""
if not self.validate_date_range(start_date, end_date):
raise ValueError(f"Invalid date range: {start_date} to {end_date}")
try:
ticker = yf.Ticker(symbol.upper())
# Add one day to end_date to make it inclusive
end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
end_date_adjusted = end_date_obj + timedelta(days=1)
end_date_str = end_date_adjusted.strftime("%Y-%m-%d")
data = ticker.history(start=start_date, end=end_date_str)
if data.empty:
logger.warning(
f"No data found for {symbol} between {start_date} and {end_date}"
)
return {
"symbol": symbol,
"data": [],
"metadata": {
"source": "yahoo_finance",
"empty": True,
"reason": "no_data_available",
},
}
# Remove timezone info and format data
if isinstance(data.index, pd.DatetimeIndex) and data.index.tz is not None:
data.index = data.index.tz_localize(None)
# Reset index to make Date a column
data = data.reset_index()
data["Date"] = data["Date"].dt.strftime("%Y-%m-%d %H:%M:%S")
# Round numerical values
numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"]
for col in numeric_columns:
if col in data.columns:
data[col] = data[col].round(2)
# Convert to list of dictionaries
records = data.to_dict("records")
return {
"symbol": symbol,
"period": {"start": start_date, "end": end_date},
"data": records,
"metadata": {
"source": "yahoo_finance",
"record_count": len(records),
"columns": list(data.columns),
"retrieved_at": datetime.utcnow().isoformat(),
},
}
except Exception as e:
logger.error(f"Error fetching Yahoo Finance data for {symbol}: {e}")
raise
def get_company_info(self, symbol: str) -> dict[str, Any]:
"""
Get company information for a symbol.
Args:
symbol: Stock ticker symbol
Returns:
Dict[str, Any]: Company information
"""
try:
ticker = yf.Ticker(symbol.upper())
info = ticker.info
return {
"symbol": symbol,
"info": info,
"metadata": {
"source": "yahoo_finance",
"retrieved_at": datetime.utcnow().isoformat(),
},
}
except Exception as e:
logger.error(f"Error fetching company info for {symbol}: {e}")
return {
"symbol": symbol,
"info": {},
"metadata": {
"source": "yahoo_finance",
"error": str(e),
"retrieved_at": datetime.utcnow().isoformat(),
},
}
def get_financials(
self, symbol: str, statement_type: str = "income"
) -> dict[str, Any]:
"""
Get financial statements for a symbol.
Args:
symbol: Stock ticker symbol
statement_type: Type of statement ("income", "balance", "cashflow")
Returns:
Dict[str, Any]: Financial statement data
"""
try:
ticker = yf.Ticker(symbol.upper())
if statement_type == "income":
annual = ticker.financials
quarterly = ticker.quarterly_financials
elif statement_type == "balance":
annual = ticker.balance_sheet
quarterly = ticker.quarterly_balance_sheet
elif statement_type == "cashflow":
annual = ticker.cashflow
quarterly = ticker.quarterly_cashflow
else:
raise ValueError(f"Unknown statement type: {statement_type}")
result = {
"symbol": symbol,
"statement_type": statement_type,
"annual": {},
"quarterly": {},
"metadata": {
"source": "yahoo_finance",
"retrieved_at": datetime.utcnow().isoformat(),
},
}
# Process annual data
if not annual.empty:
annual_data = annual.copy()
if isinstance(annual_data.columns, pd.DatetimeIndex):
annual_data.columns = annual_data.columns.strftime("%Y-%m-%d")
result["annual"] = annual_data.to_dict()
# Process quarterly data
if not quarterly.empty:
quarterly_data = quarterly.copy()
if isinstance(quarterly_data.columns, pd.DatetimeIndex):
quarterly_data.columns = quarterly_data.columns.strftime("%Y-%m-%d")
result["quarterly"] = quarterly_data.to_dict()
return result
except Exception as e:
logger.error(
f"Error fetching {statement_type} financials for {symbol}: {e}"
)
return {
"symbol": symbol,
"statement_type": statement_type,
"annual": {},
"quarterly": {},
"metadata": {
"source": "yahoo_finance",
"error": str(e),
"retrieved_at": datetime.utcnow().isoformat(),
},
}
def get_available_symbols(self) -> list[str]:
"""
Yahoo Finance doesn't provide a direct way to list all symbols.
Return common major symbols as examples.
"""
return [
"AAPL",
"MSFT",
"GOOGL",
"AMZN",
"TSLA",
"META",
"NVDA",
"AMD",
"JPM",
"JNJ",
"V",
"WMT",
"PG",
"UNH",
"HD",
"MA",
"BAC",
"DIS",
]
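A usage sketch (not part of the commit; symbol and dates are illustrative):
from tradingagents.clients.yfinance_client import YFinanceClient

client = YFinanceClient()

prices = client.get_data("AAPL", "2024-01-02", "2024-01-31")
print(prices["metadata"].get("record_count", 0), "daily bars")

income = client.get_financials("AAPL", statement_type="income")
print("annual report dates:", sorted(income["annual"].keys()))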

119
tradingagents/config.py Normal file
View File

@ -0,0 +1,119 @@
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal, cast
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
# dotenv not installed, skip loading
pass
@dataclass
class TradingAgentsConfig:
"""Configuration for TradingAgents system with type safety and validation."""
# Directory settings
project_dir: str = field(
default_factory=lambda: str(Path(__file__).parent.absolute())
)
results_dir: str = field(
default_factory=lambda: os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results")
)
data_dir: str = "/Users/yluo/Documents/Code/ScAI/FR1-data"
data_cache_dir: str = field(init=False)
# LLM settings
llm_provider: Literal["openai", "anthropic", "google", "ollama", "openrouter"] = (
"openai"
)
deep_think_llm: str = "o4-mini"
quick_think_llm: str = "gpt-4o-mini"
backend_url: str = "https://api.openai.com/v1"
# Debate and discussion settings
max_debate_rounds: int = 1
max_risk_discuss_rounds: int = 1
max_recur_limit: int = 100
# Tool settings
online_tools: bool = True
def __post_init__(self):
"""Set computed fields after initialization."""
self.data_cache_dir = os.path.join(self.project_dir, "dataflows/data_cache")
@classmethod
def _get_llm_provider(
cls, default: str = "openai"
) -> Literal["openai", "anthropic", "google", "ollama", "openrouter"]:
"""Get and validate LLM provider from environment."""
valid_providers = ["openai", "anthropic", "google", "ollama", "openrouter"]
provider = os.getenv("LLM_PROVIDER", default)
if provider not in valid_providers:
raise ValueError(
f"Invalid LLM_PROVIDER: {provider}. Must be one of: {', '.join(valid_providers)}"
)
return cast(
"Literal['openai', 'anthropic', 'google', 'ollama', 'openrouter']", provider
)
@classmethod
def from_env(cls) -> "TradingAgentsConfig":
"""Create config with environment variable overrides."""
return cls(
results_dir=os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
data_dir=os.getenv(
"TRADINGAGENTS_DATA_DIR", "/Users/yluo/Documents/Code/ScAI/FR1-data"
),
llm_provider=cls._get_llm_provider(),
deep_think_llm=os.getenv("DEEP_THINK_LLM", "o4-mini"),
quick_think_llm=os.getenv("QUICK_THINK_LLM", "gpt-4o-mini"),
backend_url=os.getenv("BACKEND_URL", "https://api.openai.com/v1"),
max_debate_rounds=int(os.getenv("MAX_DEBATE_ROUNDS", "1")),
max_risk_discuss_rounds=int(os.getenv("MAX_RISK_DISCUSS_ROUNDS", "1")),
max_recur_limit=int(os.getenv("MAX_RECUR_LIMIT", "100")),
online_tools=os.getenv("ONLINE_TOOLS", "true").lower() == "true",
)
def to_dict(self) -> dict:
"""Convert to dictionary for backward compatibility."""
return {
"project_dir": self.project_dir,
"results_dir": self.results_dir,
"data_dir": self.data_dir,
"data_cache_dir": self.data_cache_dir,
"llm_provider": self.llm_provider,
"deep_think_llm": self.deep_think_llm,
"quick_think_llm": self.quick_think_llm,
"backend_url": self.backend_url,
"max_debate_rounds": self.max_debate_rounds,
"max_risk_discuss_rounds": self.max_risk_discuss_rounds,
"max_recur_limit": self.max_recur_limit,
"online_tools": self.online_tools,
}
def copy(self) -> "TradingAgentsConfig":
"""Create a copy of the configuration."""
return TradingAgentsConfig(
project_dir=self.project_dir,
results_dir=self.results_dir,
data_dir=self.data_dir,
llm_provider=self.llm_provider,
deep_think_llm=self.deep_think_llm,
quick_think_llm=self.quick_think_llm,
backend_url=self.backend_url,
max_debate_rounds=self.max_debate_rounds,
max_risk_discuss_rounds=self.max_risk_discuss_rounds,
max_recur_limit=self.max_recur_limit,
online_tools=self.online_tools,
)
# For backward compatibility, create a default instance
DEFAULT_CONFIG = TradingAgentsConfig()
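A sketch of how the typed config is meant to be consumed (not part of the commit): build it from the environment and hand the dict form to legacy call sites:
from tradingagents.config import TradingAgentsConfig

config = TradingAgentsConfig.from_env()  # honors LLM_PROVIDER, BACKEND_URL, etc.
legacy = config.to_dict()                # for code still expecting a plain dict

print(config.llm_provider, config.deep_think_llm)
print("online tools enabled:", legacy["online_tools"])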

View File

@ -1,46 +0,0 @@
from .finnhub_utils import get_data_in_range
from .googlenews_utils import getNewsData
from .yfin_utils import YFinanceUtils
from .reddit_utils import fetch_top_from_category
from .stockstats_utils import StockstatsUtils
from .yfin_utils import YFinanceUtils
from .interface import (
# News and sentiment functions
get_finnhub_news,
get_finnhub_company_insider_sentiment,
get_finnhub_company_insider_transactions,
get_google_news,
get_reddit_global_news,
get_reddit_company_news,
# Financial statements functions
get_simfin_balance_sheet,
get_simfin_cashflow,
get_simfin_income_statements,
# Technical analysis functions
get_stock_stats_indicators_window,
get_stockstats_indicator,
# Market data functions
get_YFin_data_window,
get_YFin_data,
)
__all__ = [
# News and sentiment functions
"get_finnhub_news",
"get_finnhub_company_insider_sentiment",
"get_finnhub_company_insider_transactions",
"get_google_news",
"get_reddit_global_news",
"get_reddit_company_news",
# Financial statements functions
"get_simfin_balance_sheet",
"get_simfin_cashflow",
"get_simfin_income_statements",
# Technical analysis functions
"get_stock_stats_indicators_window",
"get_stockstats_indicator",
# Market data functions
"get_YFin_data_window",
"get_YFin_data",
]

View File

@ -1,34 +0,0 @@
import tradingagents.default_config as default_config
from typing import Dict, Optional
# Use default config but allow it to be overridden
_config: Optional[Dict] = None
DATA_DIR: Optional[str] = None
def initialize_config():
"""Initialize the configuration with default values."""
global _config, DATA_DIR
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
DATA_DIR = _config["data_dir"]
def set_config(config: Dict):
"""Update the configuration with custom values."""
global _config, DATA_DIR
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
_config.update(config)
DATA_DIR = _config["data_dir"]
def get_config() -> Dict:
"""Get the current configuration."""
if _config is None:
initialize_config()
return _config.copy()
# Initialize with default config
initialize_config()

View File

@ -1,36 +0,0 @@
import json
import os
def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=None):
"""
Gets finnhub data saved and processed on disk.
Args:
ticker (str): Ticker symbol of the company.
start_date (str): Start date in YYYY-MM-DD format.
end_date (str): End date in YYYY-MM-DD format.
data_type (str): Type of data from finnhub to fetch. Can be insider_trans, SEC_filings, news_data, insider_senti, or fin_as_reported.
data_dir (str): Directory where the data is saved.
period (str): Default to none, if there is a period specified, should be annual or quarterly.
"""
if period:
data_path = os.path.join(
data_dir,
"finnhub_data",
data_type,
f"{ticker}_{period}_data_formatted.json",
)
else:
data_path = os.path.join(
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json"
)
data = open(data_path, "r")
data = json.load(data)
# filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD)
filtered_data = {}
for key, value in data.items():
if start_date <= key <= end_date and len(value) > 0:
filtered_data[key] = value
return filtered_data

View File

@ -1,15 +1,14 @@
 import json
+import random
+import time
+from datetime import datetime
 import requests
 from bs4 import BeautifulSoup
-from datetime import datetime
-import time
-import random
 from tenacity import (
     retry,
+    retry_if_result,
     stop_after_attempt,
     wait_exponential,
-    retry_if_exception_type,
-    retry_if_result,
 )
@ -73,11 +72,24 @@ def getNewsData(query, start_date, end_date):
 for el in results_on_page:
     try:
-        link = el.find("a")["href"]
-        title = el.select_one("div.MBeuO").get_text()
-        snippet = el.select_one(".GI74Re").get_text()
-        date = el.select_one(".LfVVr").get_text()
-        source = el.select_one(".NUnG9d span").get_text()
+        link_elem = el.find("a")
+        # Handle BeautifulSoup element access safely
+        if link_elem:
+            link = getattr(link_elem, "attrs", {}).get("href", "")
+        else:
+            link = ""
+        title_elem = el.select_one("div.MBeuO")
+        title = title_elem.get_text() if title_elem else ""
+        snippet_elem = el.select_one(".GI74Re")
+        snippet = snippet_elem.get_text() if snippet_elem else ""
+        date_elem = el.select_one(".LfVVr")
+        date = date_elem.get_text() if date_elem else ""
+        source_elem = el.select_one(".NUnG9d span")
+        source = source_elem.get_text() if source_elem else ""
         news_results.append(
             {
                 "link": link,

View File

@ -1,807 +0,0 @@
from typing import Annotated, Dict
from .reddit_utils import fetch_top_from_category
from .yfin_utils import *
from .stockstats_utils import *
from .googlenews_utils import *
from .finnhub_utils import get_data_in_range
from dateutil.relativedelta import relativedelta
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import json
import os
import pandas as pd
from tqdm import tqdm
import yfinance as yf
from openai import OpenAI
from .config import get_config, set_config, DATA_DIR
def get_finnhub_news(
ticker: Annotated[
str,
"Search query of a company's, e.g. 'AAPL, TSM, etc.",
],
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
):
"""
Retrieve news about a company within a time frame
Args
ticker (str): ticker for the company you are interested in
start_date (str): Start date in yyyy-mm-dd format
end_date (str): End date in yyyy-mm-dd format
Returns
str: dataframe containing the news of the company in the time frame
"""
start_date = datetime.strptime(curr_date, "%Y-%m-%d")
before = start_date - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
result = get_data_in_range(ticker, before, curr_date, "news_data", DATA_DIR)
if len(result) == 0:
return ""
combined_result = ""
for day, data in result.items():
if len(data) == 0:
continue
for entry in data:
current_news = (
"### " + entry["headline"] + f" ({day})" + "\n" + entry["summary"]
)
combined_result += current_news + "\n\n"
return f"## {ticker} News, from {before} to {curr_date}:\n" + str(combined_result)
def get_finnhub_company_insider_sentiment(
ticker: Annotated[str, "ticker symbol for the company"],
curr_date: Annotated[
str,
"current date of you are trading at, yyyy-mm-dd",
],
look_back_days: Annotated[int, "number of days to look back"],
):
"""
Retrieve insider sentiment about a company (retrieved from public SEC information) over a look-back window
Args:
ticker (str): ticker symbol of the company
curr_date (str): current date you are trading on, yyyy-mm-dd
look_back_days (int): number of days to look back
Returns:
str: a report of the sentiment in the look_back_days days ending at curr_date
"""
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
before = date_obj - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
data = get_data_in_range(ticker, before, curr_date, "insider_senti", DATA_DIR)
if len(data) == 0:
return ""
result_str = ""
seen_dicts = []
for date, senti_list in data.items():
for entry in senti_list:
if entry not in seen_dicts:
result_str += f"### {entry['year']}-{entry['month']}:\nChange: {entry['change']}\nMonthly Share Purchase Ratio: {entry['mspr']}\n\n"
seen_dicts.append(entry)
return (
f"## {ticker} Insider Sentiment Data for {before} to {curr_date}:\n"
+ result_str
+ "The change field refers to the net buying/selling from all insiders' transactions. The mspr field refers to monthly share purchase ratio."
)
def get_finnhub_company_insider_transactions(
ticker: Annotated[str, "ticker symbol"],
curr_date: Annotated[
str,
"current date you are trading at, yyyy-mm-dd",
],
look_back_days: Annotated[int, "how many days to look back"],
):
"""
Retrieve insider transaction information about a company (retrieved from public SEC information) over a look-back window
Args:
ticker (str): ticker symbol of the company
curr_date (str): current date you are trading at, yyyy-mm-dd
look_back_days (int): how many days to look back
Returns:
str: a report of the company's insider transaction/trading information in the look_back_days days ending at curr_date
"""
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
before = date_obj - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
data = get_data_in_range(ticker, before, curr_date, "insider_trans", DATA_DIR)
if len(data) == 0:
return ""
result_str = ""
seen_dicts = []
for date, senti_list in data.items():
for entry in senti_list:
if entry not in seen_dicts:
result_str += f"### Filing Date: {entry['filingDate']}, {entry['name']}:\nChange:{entry['change']}\nShares: {entry['share']}\nTransaction Price: {entry['transactionPrice']}\nTransaction Code: {entry['transactionCode']}\n\n"
seen_dicts.append(entry)
return (
f"## {ticker} insider transactions from {before} to {curr_date}:\n"
+ result_str
+ "The change field reflects the variation in share count—here a negative number indicates a reduction in holdings—while share specifies the total number of shares involved. The transactionPrice denotes the per-share price at which the trade was executed, and transactionDate marks when the transaction occurred. The name field identifies the insider making the trade, and transactionCode (e.g., S for sale) clarifies the nature of the transaction. FilingDate records when the transaction was officially reported, and the unique id links to the specific SEC filing, as indicated by the source. Additionally, the symbol ties the transaction to a particular company, isDerivative flags whether the trade involves derivative securities, and currency notes the currency context of the transaction."
)
def get_simfin_balance_sheet(
ticker: Annotated[str, "ticker symbol"],
freq: Annotated[
str,
"reporting frequency of the company's financial history: annual / quarterly",
],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
data_path = os.path.join(
DATA_DIR,
"fundamental_data",
"simfin_data_all",
"balance_sheet",
"companies",
"us",
f"us-balance-{freq}.csv",
)
df = pd.read_csv(data_path, sep=";")
# Convert date strings to datetime objects and remove any time components
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
# Convert the current date to datetime and normalize
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No balance sheet available before the given current date.")
return ""
# Get the most recent balance sheet by selecting the row with the latest Publish Date
latest_balance_sheet = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
# drop the SimFinID column
latest_balance_sheet = latest_balance_sheet.drop("SimFinId")
return (
f"## {freq} balance sheet for {ticker} released on {str(latest_balance_sheet['Publish Date'])[0:10]}: \n"
+ str(latest_balance_sheet)
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a breakdown of assets, liabilities, and equity. Assets are grouped as current (liquid items like cash and receivables) and noncurrent (long-term investments and property). Liabilities are split between short-term obligations and long-term debts, while equity reflects shareholder funds such as paid-in capital and retained earnings. Together, these components ensure that total assets equal the sum of liabilities and equity."
)
def get_simfin_cashflow(
ticker: Annotated[str, "ticker symbol"],
freq: Annotated[
str,
"reporting frequency of the company's financial history: annual / quarterly",
],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
data_path = os.path.join(
DATA_DIR,
"fundamental_data",
"simfin_data_all",
"cash_flow",
"companies",
"us",
f"us-cashflow-{freq}.csv",
)
df = pd.read_csv(data_path, sep=";")
# Convert date strings to datetime objects and remove any time components
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
# Convert the current date to datetime and normalize
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No cash flow statement available before the given current date.")
return ""
# Get the most recent cash flow statement by selecting the row with the latest Publish Date
latest_cash_flow = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
# drop the SimFinID column
latest_cash_flow = latest_cash_flow.drop("SimFinId")
return (
f"## {freq} cash flow statement for {ticker} released on {str(latest_cash_flow['Publish Date'])[0:10]}: \n"
+ str(latest_cash_flow)
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a breakdown of cash movements. Operating activities show cash generated from core business operations, including net income adjustments for non-cash items and working capital changes. Investing activities cover asset acquisitions/disposals and investments. Financing activities include debt transactions, equity issuances/repurchases, and dividend payments. The net change in cash represents the overall increase or decrease in the company's cash position during the reporting period."
)
def get_simfin_income_statements(
ticker: Annotated[str, "ticker symbol"],
freq: Annotated[
str,
"reporting frequency of the company's financial history: annual / quarterly",
],
curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"],
):
data_path = os.path.join(
DATA_DIR,
"fundamental_data",
"simfin_data_all",
"income_statements",
"companies",
"us",
f"us-income-{freq}.csv",
)
df = pd.read_csv(data_path, sep=";")
# Convert date strings to datetime objects and remove any time components
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
# Convert the current date to datetime and normalize
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
# Filter the DataFrame for the given ticker and for reports that were published on or before the current date
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No income statement available before the given current date.")
return ""
# Get the most recent income statement by selecting the row with the latest Publish Date
latest_income = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
# drop the SimFinID column
latest_income = latest_income.drop("SimFinId")
return (
f"## {freq} income statement for {ticker} released on {str(latest_income['Publish Date'])[0:10]}: \n"
+ str(latest_income)
+ "\n\nThis includes metadata like reporting dates and currency, share details, and a comprehensive breakdown of the company's financial performance. Starting with Revenue, it shows Cost of Revenue and resulting Gross Profit. Operating Expenses are detailed, including SG&A, R&D, and Depreciation. The statement then shows Operating Income, followed by non-operating items and Interest Expense, leading to Pretax Income. After accounting for Income Tax and any Extraordinary items, it concludes with Net Income, representing the company's bottom-line profit or loss for the period."
)
def get_google_news(
query: Annotated[str, "Query to search with"],
curr_date: Annotated[str, "Curr date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
) -> str:
query = query.replace(" ", "+")
start_date = datetime.strptime(curr_date, "%Y-%m-%d")
before = start_date - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
news_results = getNewsData(query, before, curr_date)
news_str = ""
for news in news_results:
news_str += (
f"### {news['title']} (source: {news['source']}) \n\n{news['snippet']}\n\n"
)
if len(news_results) == 0:
return ""
return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"
def get_reddit_global_news(
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
max_limit_per_day: Annotated[int, "Maximum number of news per day"],
) -> str:
"""
Retrieve the latest top reddit news
Args:
start_date: Start date in yyyy-mm-dd format
look_back_days: how many days to look back
max_limit_per_day: Maximum number of news items per day
Returns:
str: A formatted string of the latest Reddit news posts with these fields: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
"""
start_date = datetime.strptime(start_date, "%Y-%m-%d")
before = start_date - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
posts = []
# iterate from start_date to end_date
curr_date = datetime.strptime(before, "%Y-%m-%d")
total_iterations = (start_date - curr_date).days + 1
pbar = tqdm(desc=f"Getting Global News on {start_date}", total=total_iterations)
while curr_date <= start_date:
curr_date_str = curr_date.strftime("%Y-%m-%d")
fetch_result = fetch_top_from_category(
"global_news",
curr_date_str,
max_limit_per_day,
data_path=os.path.join(DATA_DIR, "reddit_data"),
)
posts.extend(fetch_result)
curr_date += relativedelta(days=1)
pbar.update(1)
pbar.close()
if len(posts) == 0:
return ""
news_str = ""
for post in posts:
if post["content"] == "":
news_str += f"### {post['title']}\n\n"
else:
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
return f"## Global News Reddit, from {before} to {curr_date}:\n{news_str}"
def get_reddit_company_news(
ticker: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
max_limit_per_day: Annotated[int, "Maximum number of news per day"],
) -> str:
"""
Retrieve the latest top reddit news
Args:
ticker: ticker symbol of the company
start_date: Start date in yyyy-mm-dd format
look_back_days: how many days to look back
max_limit_per_day: Maximum number of news items per day
Returns:
str: A formatted string of the latest Reddit news posts with these fields: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
"""
start_date = datetime.strptime(start_date, "%Y-%m-%d")
before = start_date - relativedelta(days=look_back_days)
before = before.strftime("%Y-%m-%d")
posts = []
# iterate from start_date to end_date
curr_date = datetime.strptime(before, "%Y-%m-%d")
total_iterations = (start_date - curr_date).days + 1
pbar = tqdm(
desc=f"Getting Company News for {ticker} on {start_date}",
total=total_iterations,
)
while curr_date <= start_date:
curr_date_str = curr_date.strftime("%Y-%m-%d")
fetch_result = fetch_top_from_category(
"company_news",
curr_date_str,
max_limit_per_day,
ticker,
data_path=os.path.join(DATA_DIR, "reddit_data"),
)
posts.extend(fetch_result)
curr_date += relativedelta(days=1)
pbar.update(1)
pbar.close()
if len(posts) == 0:
return ""
news_str = ""
for post in posts:
if post["content"] == "":
news_str += f"### {post['title']}\n\n"
else:
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
return f"## {ticker} News Reddit, from {before} to {curr_date}:\n\n{news_str}"
def get_stock_stats_indicators_window(
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd"
],
look_back_days: Annotated[int, "how many days to look back"],
online: Annotated[bool, "to fetch data online or offline"],
) -> str:
best_ind_params = {
# Moving Averages
"close_50_sma": (
"50 SMA: A medium-term trend indicator. "
"Usage: Identify trend direction and serve as dynamic support/resistance. "
"Tips: It lags price; combine with faster indicators for timely signals."
),
"close_200_sma": (
"200 SMA: A long-term trend benchmark. "
"Usage: Confirm overall market trend and identify golden/death cross setups. "
"Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries."
),
"close_10_ema": (
"10 EMA: A responsive short-term average. "
"Usage: Capture quick shifts in momentum and potential entry points. "
"Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals."
),
# MACD Related
"macd": (
"MACD: Computes momentum via differences of EMAs. "
"Usage: Look for crossovers and divergence as signals of trend changes. "
"Tips: Confirm with other indicators in low-volatility or sideways markets."
),
"macds": (
"MACD Signal: An EMA smoothing of the MACD line. "
"Usage: Use crossovers with the MACD line to trigger trades. "
"Tips: Should be part of a broader strategy to avoid false positives."
),
"macdh": (
"MACD Histogram: Shows the gap between the MACD line and its signal. "
"Usage: Visualize momentum strength and spot divergence early. "
"Tips: Can be volatile; complement with additional filters in fast-moving markets."
),
# Momentum Indicators
"rsi": (
"RSI: Measures momentum to flag overbought/oversold conditions. "
"Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. "
"Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis."
),
# Volatility Indicators
"boll": (
"Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. "
"Usage: Acts as a dynamic benchmark for price movement. "
"Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals."
),
"boll_ub": (
"Bollinger Upper Band: Typically 2 standard deviations above the middle line. "
"Usage: Signals potential overbought conditions and breakout zones. "
"Tips: Confirm signals with other tools; prices may ride the band in strong trends."
),
"boll_lb": (
"Bollinger Lower Band: Typically 2 standard deviations below the middle line. "
"Usage: Indicates potential oversold conditions. "
"Tips: Use additional analysis to avoid false reversal signals."
),
"atr": (
"ATR: Averages true range to measure volatility. "
"Usage: Set stop-loss levels and adjust position sizes based on current market volatility. "
"Tips: It's a reactive measure, so use it as part of a broader risk management strategy."
),
# Volume-Based Indicators
"vwma": (
"VWMA: A moving average weighted by volume. "
"Usage: Confirm trends by integrating price action with volume data. "
"Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses."
),
"mfi": (
"MFI: The Money Flow Index is a momentum indicator that uses both price and volume to measure buying and selling pressure. "
"Usage: Identify overbought (>80) or oversold (<20) conditions and confirm the strength of trends or reversals. "
"Tips: Use alongside RSI or MACD to confirm signals; divergence between price and MFI can indicate potential reversals."
),
}
if indicator not in best_ind_params:
raise ValueError(
f"Indicator {indicator} is not supported. Please choose from: {list(best_ind_params.keys())}"
)
end_date = curr_date
curr_date = datetime.strptime(curr_date, "%Y-%m-%d")
before = curr_date - relativedelta(days=look_back_days)
if not online:
# read from YFin data
data = pd.read_csv(
os.path.join(
DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
data["Date"] = pd.to_datetime(data["Date"], utc=True)
dates_in_df = data["Date"].astype(str).str[:10]
ind_string = ""
while curr_date >= before:
# only do the trading dates
if curr_date.strftime("%Y-%m-%d") in dates_in_df.values:
indicator_value = get_stockstats_indicator(
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online
)
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
curr_date = curr_date - relativedelta(days=1)
else:
# online gathering
ind_string = ""
while curr_date >= before:
indicator_value = get_stockstats_indicator(
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online
)
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
curr_date = curr_date - relativedelta(days=1)
result_str = (
f"## {indicator} values from {before.strftime('%Y-%m-%d')} to {end_date}:\n\n"
+ ind_string
+ "\n\n"
+ best_ind_params.get(indicator, "No description available.")
)
return result_str
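# Usage sketch (illustrative, not part of the module). The enclosing function
# name is assumed to be get_stock_stats_indicators_window from its def line
# above this excerpt, and the offline CSV under DATA_DIR must cover the window.
# report = get_stock_stats_indicators_window(
#     symbol="AAPL",
#     indicator="rsi",
#     curr_date="2025-03-20",
#     look_back_days=30,
#     online=False,  # read from the cached YFin CSV rather than fetching
# )
# print(report)  # "## rsi values from 2025-02-18 to 2025-03-20: ..." plus usage tips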
def get_stockstats_indicator(
symbol: Annotated[str, "ticker symbol of the company"],
indicator: Annotated[str, "technical indicator to get the analysis and report of"],
curr_date: Annotated[
str, "The current trading date you are trading on, YYYY-mm-dd"
],
online: Annotated[bool, "whether to fetch data online or use cached offline data"],
) -> str:
curr_date = datetime.strptime(curr_date, "%Y-%m-%d")
curr_date = curr_date.strftime("%Y-%m-%d")
try:
indicator_value = StockstatsUtils.get_stock_stats(
symbol,
indicator,
curr_date,
os.path.join(DATA_DIR, "market_data", "price_data"),
online=online,
)
except Exception as e:
print(
f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
)
return ""
return str(indicator_value)
def get_YFin_data_window(
symbol: Annotated[str, "ticker symbol of the company"],
curr_date: Annotated[str, "Start date in yyyy-mm-dd format"],
look_back_days: Annotated[int, "how many days to look back"],
) -> str:
# calculate past days
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
before = date_obj - relativedelta(days=look_back_days)
start_date = before.strftime("%Y-%m-%d")
# read in data
data = pd.read_csv(
os.path.join(
DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
# Extract just the date part for comparison
data["DateOnly"] = data["Date"].str[:10]
# Filter data between the start and end dates (inclusive)
filtered_data = data[
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= curr_date)
]
# Drop the temporary column we created
filtered_data = filtered_data.drop("DateOnly", axis=1)
# Set pandas display options to show the full DataFrame
with pd.option_context(
"display.max_rows", None, "display.max_columns", None, "display.width", None
):
df_string = filtered_data.to_string()
return (
f"## Raw Market Data for {symbol} from {start_date} to {curr_date}:\n\n"
+ df_string
)
def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
):
datetime.strptime(start_date, "%Y-%m-%d")
datetime.strptime(end_date, "%Y-%m-%d")
# Create ticker object
ticker = yf.Ticker(symbol.upper())
# Fetch historical data for the specified date range
data = ticker.history(start=start_date, end=end_date)
# Check if data is empty
if data.empty:
return (
f"No data found for symbol '{symbol}' between {start_date} and {end_date}"
)
# Remove timezone info from index for cleaner output
if data.index.tz is not None:
data.index = data.index.tz_localize(None)
# Round numerical values to 2 decimal places for cleaner display
numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"]
for col in numeric_columns:
if col in data.columns:
data[col] = data[col].round(2)
# Convert DataFrame to CSV string
csv_string = data.to_csv()
# Add header information
header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n"
header += f"# Total records: {len(data)}\n"
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
return header + csv_string
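# Sketch (illustrative): exercising the online fetcher directly; requires
# network access for yfinance.
# csv_report = get_YFin_data_online("NVDA", "2025-01-02", "2025-01-10")
# print(csv_report.splitlines()[0])  # "# Stock data for NVDA from 2025-01-02 to 2025-01-10"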
def get_YFin_data(
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
) -> pd.DataFrame:
# read in data
data = pd.read_csv(
os.path.join(
DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
if end_date > "2025-03-25":
raise Exception(
f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25"
)
# Extract just the date part for comparison
data["DateOnly"] = data["Date"].str[:10]
# Filter data between the start and end dates (inclusive)
filtered_data = data[
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= end_date)
]
# Drop the temporary column we created
filtered_data = filtered_data.drop("DateOnly", axis=1)
# remove the index from the dataframe
filtered_data = filtered_data.reset_index(drop=True)
return filtered_data
def get_stock_news_openai(ticker, curr_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
response = client.responses.create(
model=config["quick_think_llm"],
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.",
}
],
}
],
text={"format": {"type": "text"}},
reasoning={},
tools=[
{
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
],
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
)
return response.output[1].content[0].text
def get_global_news_openai(curr_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
response = client.responses.create(
model=config["quick_think_llm"],
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.",
}
],
}
],
text={"format": {"type": "text"}},
reasoning={},
tools=[
{
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
],
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
)
return response.output[1].content[0].text
def get_fundamentals_openai(ticker, curr_date):
config = get_config()
client = OpenAI(base_url=config["backend_url"])
response = client.responses.create(
model=config["quick_think_llm"],
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
}
],
}
],
text={"format": {"type": "text"}},
reasoning={},
tools=[
{
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
],
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
)
return response.output[1].content[0].text
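The three helpers above are identical except for their prompt text; a refactoring sketch (not part of the commit) that factors the shared Responses-API call into one function, reusing exactly the parameters shown above:

def _openai_web_search(prompt: str) -> str:
    """Sketch: shared web-search call used by the three helpers above."""
    config = get_config()
    client = OpenAI(base_url=config["backend_url"])
    response = client.responses.create(
        model=config["quick_think_llm"],
        input=[
            {"role": "system", "content": [{"type": "input_text", "text": prompt}]}
        ],
        text={"format": {"type": "text"}},
        reasoning={},
        tools=[
            {
                "type": "web_search_preview",
                "user_location": {"type": "approximate"},
                "search_context_size": "low",
            }
        ],
        temperature=1,
        max_output_tokens=4096,
        top_p=1,
        store=True,
    )
    return response.output[1].content[0].text

# Each helper would then reduce to a one-line wrapper around its prompt, e.g.
#   return _openai_web_search(f"Can you search social media for {ticker} ...")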

View File

@ -1,12 +1,19 @@
import requests
import time
import json
from datetime import datetime, timedelta
from contextlib import contextmanager
from typing import Annotated
import os
import re
"""
Reddit API integration for social sentiment analysis.
"""
import logging
import re
from datetime import datetime, timedelta
from typing import Any
import praw
# from .api_clients import RateLimiter
logger = logging.getLogger(__name__)
# Extended ticker to company mapping
ticker_to_company = {
"AAPL": "Apple",
"MSFT": "Microsoft",
@ -38,7 +45,7 @@ ticker_to_company = {
"X": "Twitter OR X",
"SPOT": "Spotify",
"AVGO": "Broadcom",
"ASML": "ASML ",
"ASML": "ASML",
"TWLO": "Twilio",
"SNAP": "Snap Inc.",
"TEAM": "Atlassian",
@ -46,90 +53,331 @@ ticker_to_company = {
"UBER": "Uber",
"ROKU": "Roku",
"PINS": "Pinterest",
# Additional popular tickers
"SPY": "SPDR S&P 500 ETF",
"QQQ": "Invesco QQQ Trust",
"GME": "GameStop",
"AMC": "AMC Entertainment",
"BB": "BlackBerry",
"NOK": "Nokia",
"COIN": "Coinbase",
"HOOD": "Robinhood",
"RBLX": "Roblox",
"DKNG": "DraftKings",
"PENN": "Penn Entertainment",
"SOFI": "SoFi Technologies",
}
def fetch_top_from_category(
category: Annotated[
str, "Category to fetch top post from. Collection of subreddits."
],
date: Annotated[str, "Date to fetch top posts from."],
max_limit: Annotated[int, "Maximum number of posts to fetch."],
query: Annotated[str, "Optional query to search for in the subreddit."] = None,
data_path: Annotated[
str,
"Path to the data folder. Default is 'reddit_data'.",
] = "reddit_data",
):
base_path = data_path
    all_content = []

    if max_limit < len(os.listdir(os.path.join(base_path, category))):
        raise ValueError(
            "REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts"
        )

    limit_per_subreddit = max_limit // len(
        os.listdir(os.path.join(base_path, category))
    )

    for data_file in os.listdir(os.path.join(base_path, category)):
        # check if data_file is a .jsonl file
        if not data_file.endswith(".jsonl"):
            continue

        all_content_curr_subreddit = []

        with open(os.path.join(base_path, category, data_file), "rb") as f:
            for i, line in enumerate(f):
                # skip empty lines
                if not line.strip():
                    continue

                parsed_line = json.loads(line)

                # select only lines that are from the date
                post_date = datetime.utcfromtimestamp(
                    parsed_line["created_utc"]
                ).strftime("%Y-%m-%d")
                if post_date != date:
                    continue

                # if is company_news, check that the title or the content has the company's name (query) mentioned
                if "company" in category and query:
                    search_terms = []
                    if "OR" in ticker_to_company[query]:
                        search_terms = ticker_to_company[query].split(" OR ")
                    else:
                        search_terms = [ticker_to_company[query]]

                    search_terms.append(query)

                    found = False
                    for term in search_terms:
                        if re.search(
                            term, parsed_line["title"], re.IGNORECASE
                        ) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
                            found = True
                            break

                    if not found:
                        continue

                post = {
                    "title": parsed_line["title"],
                    "content": parsed_line["selftext"],
                    "url": parsed_line["url"],
                    "upvotes": parsed_line["ups"],
                    "posted_date": post_date,
                }

                all_content_curr_subreddit.append(post)

        # sort all_content_curr_subreddit by upvote_ratio in descending order
        all_content_curr_subreddit.sort(key=lambda x: x["upvotes"], reverse=True)

        all_content.extend(all_content_curr_subreddit[:limit_per_subreddit])

    return all_content

class RedditClient:
    """Client for Reddit API with rate limiting and caching."""

    def __init__(self, client_id: str, client_secret: str, user_agent: str):
        """
        Initialize Reddit client.

        Args:
            client_id: Reddit application client ID
            client_secret: Reddit application client secret
            user_agent: User agent string for API requests
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.user_agent = user_agent
        # Reddit allows 100 requests per minute for script applications
        # self.rate_limiter = RateLimiter(100, 60)
        # Initialize Reddit instance
        self.reddit = praw.Reddit(
            client_id=client_id, client_secret=client_secret, user_agent=user_agent
        )
        # Default subreddits for different categories
        self.subreddits = {
            "investing": ["investing", "SecurityAnalysis", "ValueInvesting", "stocks"],
            "trading": ["wallstreetbets", "StockMarket", "pennystocks", "options"],
            "global_news": ["news", "worldnews", "Economics", "business"],
            "company_news": ["investing", "stocks", "StockMarket", "SecurityAnalysis"],
        }

    def test_connection(self) -> bool:
        """Test if the Reddit API connection is working."""
        try:
            # Test by fetching user info (read-only operation)
            _ = self.reddit.user.me()
            return True
        except Exception as e:
            logger.error(f"Reddit connection test failed: {e}")
            return False

    def search_posts(
        self,
        query: str,
        subreddit_names: list[str],
        limit: int = 25,
        time_filter: str = "week",
    ) -> list[dict[str, Any]]:
        """
        Search for posts across multiple subreddits.

        Args:
            query: Search query
            subreddit_names: List of subreddit names to search
            limit: Maximum number of posts per subreddit
            time_filter: Time filter ('day', 'week', 'month', 'year', 'all')

        Returns:
            List of post dictionaries
        """
        posts = []
        for subreddit_name in subreddit_names:
            try:
                # self.rate_limiter.wait_if_needed()
                subreddit = self.reddit.subreddit(subreddit_name)
                # Search posts in the subreddit
                search_results = subreddit.search(
                    query=query, sort="relevance", time_filter=time_filter, limit=limit
                )
                for submission in search_results:
                    post_data = {
                        "title": submission.title,
                        "content": submission.selftext,
                        "url": submission.url,
                        "upvotes": submission.ups,
                        "score": submission.score,
                        "num_comments": submission.num_comments,
                        "created_utc": submission.created_utc,
                        "subreddit": subreddit_name,
                        "author": str(submission.author)
                        if submission.author
                        else "[deleted]",
                    }
                    posts.append(post_data)
            except Exception as e:
                logger.error(f"Error searching subreddit {subreddit_name}: {e}")
                continue
        # Sort by score (upvotes - downvotes) descending
        posts.sort(key=lambda x: x["score"], reverse=True)
        return posts

    def get_top_posts(
        self, subreddit_names: list[str], limit: int = 25, time_filter: str = "week"
    ) -> list[dict[str, Any]]:
        """
        Get top posts from multiple subreddits.

        Args:
            subreddit_names: List of subreddit names
            limit: Maximum number of posts per subreddit
            time_filter: Time filter ('day', 'week', 'month', 'year', 'all')

        Returns:
            List of post dictionaries
        """
        posts = []
for subreddit_name in subreddit_names:
try:
# self.rate_limiter.wait_if_needed()
subreddit = self.reddit.subreddit(subreddit_name)
# Get top posts from the subreddit
top_posts = subreddit.top(time_filter=time_filter, limit=limit)
for submission in top_posts:
post_data = {
"title": submission.title,
"content": submission.selftext,
"url": submission.url,
"upvotes": submission.ups,
"score": submission.score,
"num_comments": submission.num_comments,
"created_utc": submission.created_utc,
"subreddit": subreddit_name,
"author": str(submission.author)
if submission.author
else "[deleted]",
}
posts.append(post_data)
except Exception as e:
logger.error(f"Error fetching top posts from {subreddit_name}: {e}")
continue
# Sort by score descending
posts.sort(key=lambda x: x["score"], reverse=True)
return posts
def filter_posts_by_date(
self, posts: list[dict[str, Any]], start_date: str, end_date: str
) -> list[dict[str, Any]]:
"""
Filter posts by date range.
Args:
posts: List of post dictionaries
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
Returns:
Filtered list of posts
"""
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(
days=1
) # Include end date
filtered_posts = []
for post in posts:
post_dt = datetime.fromtimestamp(post["created_utc"])
if start_dt <= post_dt <= end_dt:
# Add formatted date string
post["posted_date"] = post_dt.strftime("%Y-%m-%d")
filtered_posts.append(post)
return filtered_posts
def filter_posts_by_company(
self, posts: list[dict[str, Any]], ticker: str
) -> list[dict[str, Any]]:
"""
Filter posts that mention a specific company/ticker.
Args:
posts: List of post dictionaries
ticker: Stock ticker symbol
Returns:
Filtered list of posts that mention the company
"""
if ticker not in ticker_to_company:
# If ticker not in mapping, search for ticker directly
search_terms = [ticker.upper()]
else:
# Get company names and ticker
company_names = ticker_to_company[ticker]
if "OR" in company_names:
search_terms = [name.strip() for name in company_names.split(" OR")]
else:
search_terms = [company_names]
search_terms.append(ticker.upper())
filtered_posts = []
for post in posts:
title_text = post["title"].lower()
content_text = post["content"].lower()
# Check if any search term appears in title or content
found = False
for term in search_terms:
term_lower = term.lower()
if re.search(
r"\b" + re.escape(term_lower) + r"\b", title_text
) or re.search(r"\b" + re.escape(term_lower) + r"\b", content_text):
found = True
break
if found:
filtered_posts.append(post)
return filtered_posts
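# Usage sketch for the client above (not part of the commit); the environment
# variable names are assumptions, and credentials must be a Reddit script app.
# import os
# client = RedditClient(
#     client_id=os.environ["REDDIT_CLIENT_ID"],
#     client_secret=os.environ["REDDIT_CLIENT_SECRET"],
#     user_agent="TradingAgents/1.0",
# )
# if client.test_connection():
#     posts = client.search_posts("NVDA", client.subreddits["company_news"], limit=10)
#     recent = client.filter_posts_by_date(posts, "2025-07-20", "2025-07-27")
#     nvda_only = client.filter_posts_by_company(recent, "NVDA")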
def fetch_top_from_category(
category: str,
date: str,
max_limit: int,
query: str | None = None,
data_path: str = "reddit_data",
client_id: str | None = None,
client_secret: str | None = None,
user_agent: str | None = None,
) -> list[dict[str, Any]]:
"""
Legacy function to maintain backward compatibility.
Now uses live Reddit API if credentials are provided.
Args:
category: Category ('global_news', 'company_news', etc.)
date: Date in YYYY-MM-DD format
max_limit: Maximum number of posts
query: Optional search query (ticker for company_news)
data_path: Unused in API mode
client_id: Reddit client ID
client_secret: Reddit client secret
user_agent: Reddit user agent
Returns:
List of post dictionaries
"""
if not all([client_id, client_secret, user_agent]):
logger.warning("Reddit API credentials not provided. Returning empty data.")
return []
try:
# Type check ensures these are not None
assert client_id is not None
assert client_secret is not None
assert user_agent is not None
client = RedditClient(client_id, client_secret, user_agent)
# Determine subreddits based on category
if category == "global_news":
subreddit_names = client.subreddits["global_news"]
elif category == "company_news":
subreddit_names = client.subreddits["company_news"]
else:
# Default to investing subreddits
subreddit_names = client.subreddits["investing"]
# Calculate time filter based on date (Reddit doesn't support exact date filtering)
post_date = datetime.strptime(date, "%Y-%m-%d")
days_ago = (datetime.now() - post_date).days
if days_ago <= 1:
time_filter = "day"
elif days_ago <= 7:
time_filter = "week"
elif days_ago <= 30:
time_filter = "month"
else:
time_filter = "year"
# Get posts
if query and category == "company_news":
# Search for specific company
posts = client.search_posts(
query=query,
subreddit_names=subreddit_names,
limit=max_limit // len(subreddit_names),
time_filter=time_filter,
)
# Filter by company mentions
posts = client.filter_posts_by_company(posts, query)
else:
# Get top posts
posts = client.get_top_posts(
subreddit_names=subreddit_names,
limit=max_limit // len(subreddit_names),
time_filter=time_filter,
)
# Filter by date (approximate)
start_date = date
end_date = date
posts = client.filter_posts_by_date(posts, start_date, end_date)
# Limit results
return posts[:max_limit]
except Exception as e:
logger.error(f"Error fetching Reddit data: {e}")
return []
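# Sketch of both paths through the legacy entry point (values illustrative).
# Without credentials it logs a warning and returns [] (offline-safe default):
#     posts = fetch_top_from_category("global_news", "2025-07-30", max_limit=20)
# With script-app credentials it routes through RedditClient, searches,
# company-filters, and date-filters:
#     posts = fetch_top_from_category(
#         "company_news",
#         "2025-07-30",
#         max_limit=20,
#         query="AAPL",
#         client_id="your_client_id",
#         client_secret="your_client_secret",
#         user_agent="TradingAgents/1.0",
#     )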

View File

@ -1,9 +1,11 @@
import os
from typing import Annotated
import pandas as pd
import yfinance as yf
from stockstats import wrap
from typing import Annotated
import os
from .config import get_config
from tradingagents.config import DEFAULT_CONFIG
class StockstatsUtils:
@ -13,7 +15,7 @@ class StockstatsUtils:
indicator: Annotated[
str, "quantitative indicators based off of the stock data for the company"
],
curr_date: Annotated[
curr_date_str: Annotated[
str, "curr date for retrieving stock price data, YYYY-mm-dd"
],
data_dir: Annotated[
@ -37,12 +39,14 @@ class StockstatsUtils:
)
)
df = wrap(data)
except FileNotFoundError:
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
except FileNotFoundError as err:
raise Exception(
"Stockstats fail: Yahoo Finance data not fetched yet!"
) from err
else:
# Get today's date as YYYY-mm-dd to add to cache
today_date = pd.Timestamp.today()
curr_date = pd.to_datetime(curr_date)
curr_date = pd.to_datetime(curr_date_str)
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
@ -50,11 +54,10 @@ class StockstatsUtils:
end_date = end_date.strftime("%Y-%m-%d")
# Get config and ensure cache directory exists
config = get_config()
os.makedirs(config["data_cache_dir"], exist_ok=True)
os.makedirs(DEFAULT_CONFIG.data_cache_dir, exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
DEFAULT_CONFIG.data_cache_dir,
f"{symbol}-YFin-data-{start_date}-{end_date}.csv",
)
@ -70,6 +73,10 @@ class StockstatsUtils:
progress=False,
auto_adjust=True,
)
if data is None:
raise ValueError(f"Failed to download data for {symbol}")
data = data.reset_index()
data.to_csv(data_file, index=False)
@ -78,7 +85,7 @@ class StockstatsUtils:
curr_date = curr_date.strftime("%Y-%m-%d")
df[indicator] # trigger stockstats to calculate the indicator
matching_rows = df[df["Date"].str.startswith(curr_date)]
matching_rows = df[df["Date"].str.startswith(curr_date_str)]
if not matching_rows.empty:
indicator_value = matching_rows[indicator].values[0]

View File

@ -1,12 +1,14 @@
import os
import json
import pandas as pd
from datetime import date, timedelta, datetime
from datetime import date, datetime, timedelta
from typing import Annotated
import pandas as pd
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
def save_output(
data: pd.DataFrame, tag: str, save_path: SavePathType | None = None
) -> None:
if save_path:
data.to_csv(save_path)
print(f"{tag} saved to {save_path}")
@ -27,7 +29,6 @@ def decorate_all_methods(decorator):
def get_next_weekday(date):
if not isinstance(date, datetime):
date = datetime.strptime(date, "%Y-%m-%d")

View File

@ -1,29 +1,31 @@
# gets data/stats
import yfinance as yf
from typing import Annotated, Callable, Any, Optional
from pandas import DataFrame
from functools import lru_cache
from typing import Annotated, cast
import pandas as pd
from functools import wraps
import yfinance as yf
from pandas import DataFrame, Series
from .utils import save_output, SavePathType, decorate_all_methods
from .utils import SavePathType
def init_ticker(func: Callable) -> Callable:
"""Decorator to initialize yf.Ticker and pass it to the function."""
@wraps(func)
def wrapper(symbol: Annotated[str, "ticker symbol"], *args, **kwargs) -> Any:
ticker = yf.Ticker(symbol)
return func(ticker, *args, **kwargs)
return wrapper
# Module-level cache to avoid memory leaks with instance methods
@lru_cache(maxsize=100)
def _get_cached_ticker(symbol: str) -> yf.Ticker:
"""Get a cached yfinance Ticker instance."""
return yf.Ticker(symbol)
@decorate_all_methods(init_ticker)
class YFinanceUtils:
"""Clean YFinance utilities with ticker caching for better performance."""
def _get_ticker(self, symbol: str) -> yf.Ticker:
"""Get a cached yfinance Ticker instance."""
return _get_cached_ticker(symbol)
def get_stock_data(
self,
symbol: Annotated[str, "ticker symbol"],
start_date: Annotated[
str, "start date for retrieving stock price data, YYYY-mm-dd"
@ -31,32 +33,32 @@ class YFinanceUtils:
end_date: Annotated[
str, "end date for retrieving stock price data, YYYY-mm-dd"
],
save_path: SavePathType = None,
save_path: SavePathType | None = None,
) -> DataFrame:
"""retrieve stock price data for designated ticker symbol"""
ticker = symbol
# add one day to the end_date so that the data range is inclusive
end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1)
end_date = end_date.strftime("%Y-%m-%d")
stock_data = ticker.history(start=start_date, end=end_date)
# save_output(stock_data, f"Stock data for {ticker.ticker}", save_path)
"""Retrieve stock price data for designated ticker symbol."""
ticker = self._get_ticker(symbol)
# Add one day to the end_date so that the data range is inclusive
end_date_adjusted = pd.to_datetime(end_date) + pd.DateOffset(days=1)
end_date_str = end_date_adjusted.strftime("%Y-%m-%d")
stock_data = ticker.history(start=start_date, end=end_date_str)
return stock_data
def get_stock_info(
symbol: Annotated[str, "ticker symbol"],
) -> dict:
def get_stock_info(self, symbol: Annotated[str, "ticker symbol"]) -> dict:
"""Fetches and returns latest stock information."""
ticker = symbol
stock_info = ticker.info
return stock_info
ticker = self._get_ticker(symbol)
return ticker.info
def get_company_info(
self,
symbol: Annotated[str, "ticker symbol"],
save_path: Optional[str] = None,
save_path: str | None = None,
) -> DataFrame:
"""Fetches and returns company information as a DataFrame."""
ticker = symbol
ticker = self._get_ticker(symbol)
info = ticker.info
company_info = {
"Company Name": info.get("shortName", "N/A"),
"Industry": info.get("industry", "N/A"),
@ -64,54 +66,77 @@ class YFinanceUtils:
"Country": info.get("country", "N/A"),
"Website": info.get("website", "N/A"),
}
company_info_df = DataFrame([company_info])
if save_path:
company_info_df.to_csv(save_path)
print(f"Company info for {ticker.ticker} saved to {save_path}")
print(f"Company info for {symbol} saved to {save_path}")
return company_info_df
def get_stock_dividends(
self,
symbol: Annotated[str, "ticker symbol"],
save_path: Optional[str] = None,
) -> DataFrame:
save_path: str | None = None,
) -> Series:
"""Fetches and returns the latest dividends data as a DataFrame."""
ticker = symbol
ticker = self._get_ticker(symbol)
dividends = ticker.dividends
if save_path:
dividends.to_csv(save_path)
print(f"Dividends for {ticker.ticker} saved to {save_path}")
print(f"Dividends for {symbol} saved to {save_path}")
return dividends
def get_income_stmt(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
def get_income_stmt(self, symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
"""Fetches and returns the latest income statement of the company as a DataFrame."""
ticker = symbol
income_stmt = ticker.financials
return income_stmt
ticker = self._get_ticker(symbol)
return ticker.financials
def get_balance_sheet(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
def get_balance_sheet(self, symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
"""Fetches and returns the latest balance sheet of the company as a DataFrame."""
ticker = symbol
balance_sheet = ticker.balance_sheet
return balance_sheet
ticker = self._get_ticker(symbol)
return ticker.balance_sheet
def get_cash_flow(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
def get_cash_flow(self, symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
"""Fetches and returns the latest cash flow statement of the company as a DataFrame."""
ticker = symbol
cash_flow = ticker.cashflow
return cash_flow
ticker = self._get_ticker(symbol)
return ticker.cashflow
def get_analyst_recommendations(symbol: Annotated[str, "ticker symbol"]) -> tuple:
def get_analyst_recommendations(
self, symbol: Annotated[str, "ticker symbol"]
) -> tuple[str | None, int]:
"""Fetches the latest analyst recommendations and returns the most common recommendation and its count."""
ticker = symbol
recommendations = ticker.recommendations
if recommendations.empty:
ticker = self._get_ticker(symbol)
recommendations = cast("DataFrame", ticker.recommendations)
if recommendations is None or recommendations.empty:
return None, 0 # No recommendations available
# Assuming 'period' column exists and needs to be excluded
row_0 = recommendations.iloc[0, 1:] # Exclude 'period' column if necessary
# Get the most recent recommendation row (excluding 'period' column if it exists)
try:
row_0 = recommendations.iloc[0, 1:] # Skip first column (likely 'period')
# Find the maximum voting result
max_votes = row_0.max()
majority_voting_result = row_0[row_0 == max_votes].index.tolist()
# Find the maximum voting result
max_votes = row_0.max()
majority_voting_result = row_0[row_0 == max_votes].index.tolist()
return majority_voting_result[0], max_votes
return majority_voting_result[0], int(max_votes)
except (IndexError, KeyError):
return None, 0
def clear_cache(self) -> None:
"""Clear the ticker cache. Useful for testing or memory management."""
_get_cached_ticker.cache_clear()
def cache_info(self) -> dict:
"""Get information about the ticker cache."""
info = _get_cached_ticker.cache_info()
return {
"hits": info.hits,
"misses": info.misses,
"maxsize": info.maxsize,
"currsize": info.currsize,
}
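# Sketch of the cache behaviour: repeated lookups for the same symbol reuse
# one Ticker instance, which cache_info() makes observable.
# utils = YFinanceUtils()
# utils.get_stock_info("MSFT")       # first call: cache miss
# utils.get_income_stmt("MSFT")      # second call: cache hit, same Ticker
# print(utils.cache_info())          # e.g. {"hits": 1, "misses": 1, ...}
# utils.clear_cache()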

View File

@ -1,22 +0,0 @@
import os
DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
"data_cache_dir": os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
),
# LLM settings
"llm_provider": "openai",
"deep_think_llm": "o4-mini",
"quick_think_llm": "gpt-4o-mini",
"backend_url": "https://api.openai.com/v1",
# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,
"max_recur_limit": 100,
# Tool settings
"online_tools": True,
}

View File

@ -1,11 +1,11 @@
# TradingAgents/graph/__init__.py
from .trading_graph import TradingAgentsGraph
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .setup import GraphSetup
from .signal_processing import SignalProcessor
from .trading_graph import TradingAgentsGraph
__all__ = [
"TradingAgentsGraph",

View File

@ -15,7 +15,9 @@ class ConditionalLogic:
"""Determine if market analysis should continue."""
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
if hasattr(last_message, "tool_calls") and getattr(
last_message, "tool_calls", None
):
return "tools_market"
return "Msg Clear Market"
@ -23,7 +25,9 @@ class ConditionalLogic:
"""Determine if social media analysis should continue."""
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
if hasattr(last_message, "tool_calls") and getattr(
last_message, "tool_calls", None
):
return "tools_social"
return "Msg Clear Social"
@ -31,7 +35,9 @@ class ConditionalLogic:
"""Determine if news analysis should continue."""
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
if hasattr(last_message, "tool_calls") and getattr(
last_message, "tool_calls", None
):
return "tools_news"
return "Msg Clear News"
@ -39,7 +45,9 @@ class ConditionalLogic:
"""Determine if fundamentals analysis should continue."""
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
if hasattr(last_message, "tool_calls") and getattr(
last_message, "tool_calls", None
):
return "tools_fundamentals"
return "Msg Clear Fundamentals"

View File

@ -1,8 +1,8 @@
# TradingAgents/graph/propagation.py
from typing import Dict, Any
from typing import Any
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
RiskDebateState,
)
@ -17,21 +17,33 @@ class Propagator:
def create_initial_state(
self, company_name: str, trade_date: str
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Create the initial state for the agent graph."""
return {
"messages": [("human", company_name)],
"company_of_interest": company_name,
"trade_date": str(trade_date),
"investment_debate_state": InvestDebateState(
{"history": "", "current_response": "", "count": 0}
{
"bull_history": "",
"bear_history": "",
"history": "",
"current_response": "",
"judge_decision": "",
"count": 0,
}
),
"risk_debate_state": RiskDebateState(
{
"risky_history": "",
"safe_history": "",
"neutral_history": "",
"history": "",
"latest_speaker": "",
"current_risky_response": "",
"current_safe_response": "",
"current_neutral_response": "",
"judge_decision": "",
"count": 0,
}
),
@ -41,7 +53,7 @@ class Propagator:
"news_report": "",
}
def get_graph_args(self) -> Dict[str, Any]:
def get_graph_args(self) -> dict[str, Any]:
"""Get arguments for the graph invocation."""
return {
"stream_mode": "values",

View File

@ -1,13 +1,19 @@
# TradingAgents/graph/reflection.py
from typing import Dict, Any
from typing import Any
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
class Reflector:
"""Handles reflection on decisions and updating memory."""
def __init__(self, quick_thinking_llm: ChatOpenAI):
def __init__(
self,
quick_thinking_llm: ChatOpenAI | ChatAnthropic | ChatGoogleGenerativeAI,
):
"""Initialize the reflector with an LLM."""
self.quick_thinking_llm = quick_thinking_llm
self.reflection_system_prompt = self._get_reflection_prompt()
@ -15,7 +21,7 @@ class Reflector:
def _get_reflection_prompt(self) -> str:
"""Get the system prompt for reflection."""
return """
You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis.
You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis.
Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines:
1. Reasoning:
@ -25,7 +31,7 @@ Your goal is to deliver detailed insights into investment decisions and highligh
- Technical indicators.
- Technical signals.
- Price movement analysis.
- Overall market data analysis
- Overall market data analysis
- News analysis.
- Social media and sentiment analysis.
- Fundamental data analysis.
@ -46,7 +52,7 @@ Your goal is to deliver detailed insights into investment decisions and highligh
Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis.
"""
def _extract_current_situation(self, current_state: Dict[str, Any]) -> str:
def _extract_current_situation(self, current_state: dict[str, Any]) -> str:
"""Extract the current market situation from the state."""
curr_market_report = current_state["market_report"]
curr_sentiment_report = current_state["sentiment_report"]
@ -68,7 +74,13 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur
]
result = self.quick_thinking_llm.invoke(messages).content
return result
# Ensure we return a string
if isinstance(result, str):
return result
elif isinstance(result, list):
return str(result)
else:
return str(result)
def reflect_bull_researcher(self, current_state, returns_losses, bull_memory):
"""Reflect on bull researcher's analysis and update memory."""

View File

@ -1,11 +1,27 @@
# TradingAgents/graph/setup.py
from typing import Dict, Any
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph, START
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode
from tradingagents.agents import *
from tradingagents.agents import (
create_bear_researcher,
create_bull_researcher,
create_fundamentals_analyst,
create_market_analyst,
create_msg_delete,
create_neutral_debator,
create_news_analyst,
create_research_manager,
create_risk_manager,
create_risky_debator,
create_safe_debator,
create_social_media_analyst,
create_trader,
)
from tradingagents.agents.utils.agent_states import AgentState
from tradingagents.agents.utils.agent_utils import Toolkit
@ -17,10 +33,10 @@ class GraphSetup:
def __init__(
self,
quick_thinking_llm: ChatOpenAI,
deep_thinking_llm: ChatOpenAI,
quick_thinking_llm: ChatOpenAI | ChatAnthropic | ChatGoogleGenerativeAI,
deep_thinking_llm: ChatOpenAI | ChatAnthropic | ChatGoogleGenerativeAI,
toolkit: Toolkit,
tool_nodes: Dict[str, ToolNode],
tool_nodes: dict[str, ToolNode],
bull_memory,
bear_memory,
trader_memory,
@ -40,9 +56,7 @@ class GraphSetup:
self.risk_manager_memory = risk_manager_memory
self.conditional_logic = conditional_logic
def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"]
):
def setup_graph(self, selected_analysts=None):
"""Set up and compile the agent workflow graph.
Args:
@ -52,6 +66,9 @@ class GraphSetup:
- "news": News analyst
- "fundamentals": Fundamentals analyst
"""
if selected_analysts is None:
selected_analysts = ["market", "social", "news", "fundamentals"]
if len(selected_analysts) == 0:
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
@ -150,7 +167,7 @@ class GraphSetup:
# Connect to next analyst or to Bull Researcher if this is the last analyst
if i < len(selected_analysts) - 1:
next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
next_analyst = f"{selected_analysts[i + 1].capitalize()} Analyst"
workflow.add_edge(current_clear, next_analyst)
else:
workflow.add_edge(current_clear, "Bull Researcher")

View File

@ -1,12 +1,18 @@
# TradingAgents/graph/signal_processing.py
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
class SignalProcessor:
"""Processes trading signals to extract actionable decisions."""
def __init__(self, quick_thinking_llm: ChatOpenAI):
def __init__(
self,
quick_thinking_llm: ChatOpenAI | ChatAnthropic | ChatGoogleGenerativeAI,
):
"""Initialize with an LLM for processing."""
self.quick_thinking_llm = quick_thinking_llm
@ -28,4 +34,11 @@ class SignalProcessor:
("human", full_signal),
]
return self.quick_thinking_llm.invoke(messages).content
result = self.quick_thinking_llm.invoke(messages).content
# Ensure we return a string
if isinstance(result, str):
return result
elif isinstance(result, list):
return str(result)
else:
return str(result)
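# Usage sketch; the method name process_signal and the model alias below are
# assumptions, not confirmed by this excerpt.
# processor = SignalProcessor(
#     ChatAnthropic(model_name="claude-3-5-haiku-latest", timeout=60, stop=[])
# )
# decision = processor.process_signal("FINAL TRANSACTION PROPOSAL: **BUY** ...")
# print(decision)  # expected to be a bare action string such as "BUY"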

View File

@ -1,31 +1,22 @@
# TradingAgents/graph/trading_graph.py
import json
import os
from pathlib import Path
import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode
from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.agent_utils import Toolkit
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
RiskDebateState,
)
from tradingagents.dataflows.interface import set_config
from tradingagents.config import TradingAgentsConfig
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .setup import GraphSetup
from .signal_processing import SignalProcessor
@ -34,50 +25,70 @@ class TradingAgentsGraph:
def __init__(
self,
selected_analysts=["market", "social", "news", "fundamentals"],
selected_analysts=None,
debug=False,
config: Dict[str, Any] = None,
config: TradingAgentsConfig | None = None,
):
"""Initialize the trading agents graph and components.
Args:
selected_analysts: List of analyst types to include
debug: Whether to run in debug mode
config: Configuration dictionary. If None, uses default config
config: Configuration object. If None, uses default config
"""
self.debug = debug
self.config = config or DEFAULT_CONFIG
if selected_analysts is None:
selected_analysts = ["market", "social", "news", "fundamentals"]
# Update the interface's config
set_config(self.config)
self.debug = debug
self.config = config or TradingAgentsConfig()
# Create necessary directories
os.makedirs(
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
os.path.join(self.config.project_dir, "dataflows/data_cache"),
exist_ok=True,
)
# Initialize LLMs
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "anthropic":
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "google":
self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
if (
self.config.llm_provider.lower() == "openai"
or self.config.llm_provider == "ollama"
or self.config.llm_provider == "openrouter"
):
self.deep_thinking_llm = ChatOpenAI(
model=self.config.deep_think_llm, base_url=self.config.backend_url
)
self.quick_thinking_llm = ChatOpenAI(
model=self.config.quick_think_llm, base_url=self.config.backend_url
)
elif self.config.llm_provider.lower() == "anthropic":
self.deep_thinking_llm = ChatAnthropic(
model_name=self.config.deep_think_llm, timeout=60, stop=[]
)
self.quick_thinking_llm = ChatAnthropic(
model_name=self.config.quick_think_llm, timeout=60, stop=[]
)
elif self.config.llm_provider.lower() == "google":
self.deep_thinking_llm = ChatGoogleGenerativeAI(
model=self.config.deep_think_llm
)
self.quick_thinking_llm = ChatGoogleGenerativeAI(
model=self.config.quick_think_llm
)
else:
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
raise ValueError(f"Unsupported LLM provider: {self.config.llm_provider}")
self.toolkit = Toolkit(config=self.config)
# Initialize memories
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
self.invest_judge_memory = FinancialSituationMemory(
"invest_judge_memory", self.config
)
self.risk_manager_memory = FinancialSituationMemory(
"risk_manager_memory", self.config
)
# Create tool nodes
self.tool_nodes = self._create_tool_nodes()
@ -109,7 +120,7 @@ class TradingAgentsGraph:
# Set up the graph
self.graph = self.graph_setup.setup_graph(selected_analysts)
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
def _create_tool_nodes(self) -> dict[str, ToolNode]:
"""Create tool nodes for different data sources."""
return {
"market": ToolNode(

View File

@ -0,0 +1,33 @@
"""
Pydantic models for structured data context in TradingAgents.
"""
from .context import (
ArticleData,
DataQuality,
FinancialStatement,
FundamentalContext,
InsiderContext,
InsiderTransaction,
MarketDataContext,
NewsContext,
PostData,
SentimentScore,
SocialContext,
TechnicalIndicatorData,
)
__all__ = [
"DataQuality",
"SentimentScore",
"MarketDataContext",
"NewsContext",
"SocialContext",
"FundamentalContext",
"InsiderContext",
"TechnicalIndicatorData",
"ArticleData",
"PostData",
"FinancialStatement",
"InsiderTransaction",
]

View File

@ -0,0 +1,292 @@
"""
Pydantic models for structured context objects in TradingAgents.
These models define the schema for JSON context objects that services
provide to agents, replacing the previous markdown string approach.
"""
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field, validator
class DataQuality(str, Enum):
"""Data quality indicator for context metadata."""
HIGH = "high"
MEDIUM = "medium"
LOW = "low"
class SentimentScore(BaseModel):
"""Sentiment analysis result with confidence."""
score: float = Field(
...,
ge=-1.0,
le=1.0,
description="Sentiment score from -1 (negative) to 1 (positive)",
)
confidence: float = Field(
..., ge=0.0, le=1.0, description="Confidence in sentiment score"
)
label: str | None = Field(
default=None, description="Human-readable sentiment label"
)
@validator("label", pre=True, always=True)
def set_sentiment_label(cls, v, values):
if v is not None:
return v
score = values.get("score", 0)
if score > 0.1:
return "positive"
elif score < -0.1:
return "negative"
else:
return "neutral"
class TechnicalIndicatorData(BaseModel):
"""Technical indicator data point."""
date: str = Field(..., description="Date in YYYY-MM-DD format")
value: float | dict[str, float] = Field(
..., description="Indicator value or values"
)
indicator_type: str = Field(
..., description="Type of indicator (e.g., 'rsi', 'macd', 'sma')"
)
class ArticleData(BaseModel):
"""News article data."""
headline: str = Field(..., description="Article headline")
summary: str | None = Field(default=None, description="Article summary or snippet")
url: str | None = Field(default=None, description="Article URL")
source: str = Field(..., description="News source")
date: str = Field(..., description="Publication date in YYYY-MM-DD format")
sentiment: SentimentScore | None = Field(
default=None, description="Article sentiment analysis"
)
entities: list[str] = Field(
default_factory=list, description="Named entities mentioned"
)
class PostData(BaseModel):
"""Social media post data."""
title: str = Field(..., description="Post title")
content: str | None = Field(default=None, description="Post content")
author: str = Field(..., description="Post author")
source: str = Field(..., description="Post source (e.g., subreddit name)")
date: str = Field(..., description="Post date in YYYY-MM-DD format")
url: str | None = Field(default=None, description="Post URL")
score: int = Field(default=0, description="Post score/upvotes")
comments: int = Field(default=0, description="Number of comments")
engagement_score: int = Field(default=0, description="Combined engagement metric")
subreddit: str | None = Field(default=None, description="Subreddit name")
sentiment: SentimentScore | None = Field(
default=None, description="Post sentiment analysis"
)
metadata: dict[str, Any] = Field(
default_factory=dict, description="Additional post metadata"
)
class FinancialStatement(BaseModel):
"""Financial statement data."""
period: str = Field(..., description="Reporting period (e.g., '2024-Q1', '2024')")
report_date: str = Field(..., description="Report date in YYYY-MM-DD format")
publish_date: str = Field(..., description="Publish date in YYYY-MM-DD format")
currency: str = Field(default="USD", description="Currency of financial data")
data: dict[str, float] = Field(..., description="Financial statement line items")
class InsiderTransaction(BaseModel):
"""Insider trading transaction data."""
filing_date: str = Field(..., description="Filing date in YYYY-MM-DD format")
name: str = Field(..., description="Insider name")
change: float = Field(..., description="Change in shares")
shares: float = Field(..., description="Total shares after transaction")
transaction_price: float = Field(..., description="Price per share")
transaction_code: str = Field(
..., description="Transaction code (e.g., 'S' for sale)"
)
class MarketDataContext(BaseModel):
"""Market data context with price and technical indicators."""
symbol: str = Field(..., description="Stock ticker symbol")
period: dict[str, str] = Field(
..., description="Date range with start and end keys"
)
price_data: list[dict[str, Any]] = Field(..., description="Historical price data")
technical_indicators: dict[str, list[TechnicalIndicatorData]] = Field(
default_factory=dict, description="Technical indicators organized by type"
)
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Context metadata including data quality and source info",
)
@validator("period")
def validate_period(cls, v):
required_keys = {"start", "end"}
if not required_keys.issubset(v.keys()):
raise ValueError(f"Period must contain keys: {required_keys}")
return v
class NewsContext(BaseModel):
"""News context with articles and sentiment analysis."""
symbol: str | None = Field(
default=None, description="Stock ticker if company-specific"
)
period: dict[str, str] = Field(
..., description="Date range with start and end keys"
)
articles: list[ArticleData] = Field(..., description="News articles")
sentiment_summary: SentimentScore = Field(
..., description="Overall sentiment across articles"
)
article_count: int = Field(..., description="Total number of articles")
sources: list[str] = Field(default_factory=list, description="Unique news sources")
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Context metadata including data quality and coverage info",
)
@validator("article_count", pre=True, always=True)
def set_article_count(cls, v, values):
if v is not None:
return v
return len(values.get("articles", []))
@validator("sources", pre=True, always=True)
def set_sources(cls, v, values):
if v:
return v
articles = values.get("articles", [])
return list({article.source for article in articles})
class SocialContext(BaseModel):
"""Social media context with posts and engagement metrics."""
symbol: str | None = Field(
default=None, description="Stock ticker if company-specific"
)
period: dict[str, str] = Field(
..., description="Date range with start and end keys"
)
posts: list[PostData] = Field(..., description="Social media posts")
engagement_metrics: dict[str, float] = Field(
default_factory=dict, description="Aggregated engagement metrics"
)
sentiment_summary: SentimentScore = Field(
..., description="Overall sentiment across posts"
)
post_count: int = Field(..., description="Total number of posts")
platforms: list[str] = Field(
default_factory=list, description="Social media platforms"
)
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Context metadata including data quality and platform info",
)
@validator("post_count", pre=True, always=True)
def set_post_count(cls, v, values):
if v is not None:
return v
return len(values.get("posts", []))
@property
def platform(self) -> str | None:
"""Primary platform for backward compatibility."""
return self.platforms[0] if self.platforms else None
class FundamentalContext(BaseModel):
"""Fundamental analysis context with financial statements."""
symbol: str = Field(..., description="Stock ticker symbol")
period: dict[str, str] = Field(
..., description="Date range with start and end keys"
)
balance_sheet: FinancialStatement | None = Field(
default=None, description="Balance sheet data"
)
income_statement: FinancialStatement | None = Field(
default=None, description="Income statement data"
)
cash_flow: FinancialStatement | None = Field(
default=None, description="Cash flow statement data"
)
key_ratios: dict[str, float] = Field(
default_factory=dict, description="Calculated financial ratios"
)
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Context metadata including data quality and completeness",
)
class InsiderContext(BaseModel):
"""Insider trading context with transaction data and sentiment."""
symbol: str = Field(..., description="Stock ticker symbol")
period: dict[str, str] = Field(
..., description="Date range with start and end keys"
)
transactions: list[InsiderTransaction] = Field(
..., description="Insider transactions"
)
sentiment_data: dict[str, Any] = Field(
default_factory=dict, description="Insider sentiment metrics"
)
transaction_count: int = Field(..., description="Total number of transactions")
net_activity: dict[str, float] = Field(
default_factory=dict, description="Net buying/selling activity metrics"
)
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Context metadata including data quality and coverage",
)
@validator("transaction_count", pre=True, always=True)
def set_transaction_count(cls, v, values):
if v is not None:
return v
return len(values.get("transactions", []))
# Base context for extensibility
class BaseContext(BaseModel):
"""Base context model for common fields."""
period: dict[str, str] = Field(
..., description="Date range with start and end keys"
)
metadata: dict[str, Any] = Field(
default_factory=lambda: {
"data_quality": DataQuality.MEDIUM,
"created_at": datetime.utcnow().isoformat(),
"source": "unknown",
},
description="Context metadata",
)
class Config:
use_enum_values = True # Serialize enums as values
json_encoders = {datetime: lambda v: v.isoformat()}
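A construction sketch showing the validators at work (pydantic v1-style, as written above); all data values are illustrative:

# article = ArticleData(
#     headline="Apple beats estimates",
#     source="Example Wire",
#     date="2025-07-29",
#     sentiment=SentimentScore(score=0.6, confidence=0.8),  # label -> "positive"
# )
# ctx = NewsContext(
#     symbol="AAPL",
#     period={"start": "2025-07-23", "end": "2025-07-30"},
#     articles=[article],
#     sentiment_summary=SentimentScore(score=0.6, confidence=0.8),
#     article_count=None,  # derived as len(articles) by the validator
# )
# print(ctx.sources)  # ["Example Wire"], filled in by set_sources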

View File

@ -0,0 +1,21 @@
"""
Repository classes for historical data access in TradingAgents.
"""
from .fundamental_repository import FundamentalDataRepository
from .insider_repository import InsiderDataRepository
from .llm_repository import LLMRepository
from .market_data_repository import MarketDataRepository
from .news_repository import NewsRepository
from .openai_repository import OpenAIRepository
from .social_repository import SocialRepository
__all__ = [
"MarketDataRepository",
"NewsRepository",
"SocialRepository",
"FundamentalDataRepository",
"InsiderDataRepository",
"OpenAIRepository",
"LLMRepository",
]

View File

@ -0,0 +1,21 @@
"""
Base repository class with common utilities.
"""
from pathlib import Path
class BaseRepository:
"""Base repository class with shared utility methods."""
def _ensure_path_exists(self, path: Path) -> None:
"""
Ensure a directory path exists, creating it if necessary.
Args:
path: Path to ensure exists (can be file path - will create parent dirs)
"""
if path.suffix: # It's a file path, create parent directories
path.parent.mkdir(parents=True, exist_ok=True)
else: # It's a directory path, create the directory itself
path.mkdir(parents=True, exist_ok=True)
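# Behaviour sketch for the helper above (paths are illustrative):
# repo = BaseRepository()
# File path: parent directories are created, not the file itself.
# repo._ensure_path_exists(Path("cache/fundamental_data/AAPL/2025-07-30.json"))
# Directory path (no suffix): the directory itself is created.
# repo._ensure_path_exists(Path("cache/market_data"))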

View File

@ -0,0 +1,313 @@
"""
Repository for fundamental financial data (balance sheets, income statements, cash flow).
"""
import json
import logging
from dataclasses import asdict, dataclass
from datetime import date
from pathlib import Path
from typing import Any
from .base import BaseRepository
logger = logging.getLogger(__name__)
@dataclass
class FinancialStatement:
"""Represents a financial statement with standardized structure matching the service."""
period: str
report_date: str
publish_date: str
currency: str
data: dict[str, float]
@dataclass
class FinancialData:
"""Container for all financial statements for a symbol and date."""
symbol: str
date: date
financial_statements: dict[
str, FinancialStatement
] # "balance_sheet", "income_statement", "cash_flow"
class FundamentalDataRepository(BaseRepository):
"""Repository for accessing cached fundamental financial data as a KV store."""
def __init__(self, data_dir: str, **kwargs):
"""
Initialize fundamental data repository.
Args:
data_dir: Base directory for fundamental data storage
**kwargs: Additional configuration
"""
self.fundamental_data_dir = Path(data_dir) / "fundamental_data"
self.fundamental_data_dir.mkdir(parents=True, exist_ok=True)
def get_financial_data(
self, symbol: str, start_date: date, end_date: date
) -> dict[date, FinancialData]:
"""
Get cached fundamental data for a symbol and date range.
Args:
symbol: Stock ticker symbol
start_date: Start date
end_date: End date
Returns:
Dict[date, FinancialData]: Financial data keyed by date
"""
symbol_dir = self.fundamental_data_dir / symbol
if not symbol_dir.exists():
logger.warning(f"No data directory found for symbol {symbol}")
return {}
financial_data = {}
# Scan for JSON files in the symbol directory
for json_file in symbol_dir.glob("*.json"):
try:
# Parse date from filename (YYYY-MM-DD.json)
date_str = json_file.stem
file_date = date.fromisoformat(date_str)
# Filter by date range
if start_date <= file_date <= end_date:
with open(json_file) as f:
data = json.load(f)
# Create FinancialStatement objects from JSON data
financial_statements = {}
for statement_type, statement_data in data.get(
"financial_statements", {}
).items():
financial_statements[statement_type] = FinancialStatement(
**statement_data
)
# Create FinancialData container
financial_data[file_date] = FinancialData(
symbol=symbol,
date=file_date,
financial_statements=financial_statements,
)
except (ValueError, json.JSONDecodeError, KeyError) as e:
logger.error(f"Error reading financial data from {json_file}: {e}")
continue
logger.info(f"Retrieved {len(financial_data)} financial records for {symbol}")
return financial_data
def has_data_for_period(
self, symbol: str, start_date: str, end_date: str, frequency: str = "quarterly"
) -> bool:
"""
Check if we have sufficient data for the given period.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
frequency: Reporting frequency (not used in this implementation)
Returns:
bool: True if we have any data in the period
"""
start_dt = date.fromisoformat(start_date)
end_dt = date.fromisoformat(end_date)
financial_data = self.get_financial_data(symbol, start_dt, end_dt)
return len(financial_data) > 0
def get_data(
self,
symbol: str,
start_date: str,
end_date: str,
frequency: str = "quarterly",
**kwargs,
) -> dict[str, Any]:
"""
Get data in the format expected by the service.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
frequency: Reporting frequency
**kwargs: Additional parameters
Returns:
Dict with financial_statements structure expected by service
"""
start_dt = date.fromisoformat(start_date)
end_dt = date.fromisoformat(end_date)
financial_data = self.get_financial_data(symbol, start_dt, end_dt)
if not financial_data:
return {}
# Get the most recent data
latest_date = max(financial_data.keys())
latest_data = financial_data[latest_date]
return {
"financial_statements": {
statement_type: asdict(statement)
for statement_type, statement in latest_data.financial_statements.items()
}
}
def store_data(
self,
symbol: str,
cache_data: dict,
frequency: str = "quarterly",
overwrite: bool = False,
) -> bool:
"""
Store data in the format expected by the service.
Args:
symbol: Stock ticker symbol
cache_data: Data dictionary with financial_statements
frequency: Reporting frequency (not used)
overwrite: Whether to overwrite existing data
Returns:
bool: True if successful
"""
try:
# Extract financial statements from cache_data
statements_data = cache_data.get("financial_statements", {})
if not statements_data:
logger.warning(
f"No financial statements found in cache_data for {symbol}"
)
return False
# Use today's date as the storage date
storage_date = date.today()
# Convert to FinancialStatement objects
financial_statements = {}
for statement_type, statement_dict in statements_data.items():
financial_statements[statement_type] = FinancialStatement(
**statement_dict
)
# Store the statements
self.store_financial_statements(symbol, storage_date, financial_statements)
return True
except Exception as e:
logger.error(f"Error storing data for {symbol}: {e}")
return False
def clear_data(
self, symbol: str, start_date: str, end_date: str, frequency: str = "quarterly"
) -> bool:
"""
Clear data for a symbol and date range.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
frequency: Reporting frequency (not used)
Returns:
bool: True if successful
"""
try:
symbol_dir = self.fundamental_data_dir / symbol
if not symbol_dir.exists():
return True # Nothing to clear
start_dt = date.fromisoformat(start_date)
end_dt = date.fromisoformat(end_date)
# Remove files in the date range
for json_file in symbol_dir.glob("*.json"):
try:
date_str = json_file.stem
file_date = date.fromisoformat(date_str)
if start_dt <= file_date <= end_dt:
json_file.unlink()
logger.debug(f"Removed {json_file}")
except (ValueError, OSError) as e:
logger.warning(f"Error removing {json_file}: {e}")
continue
return True
except Exception as e:
logger.error(f"Error clearing data for {symbol}: {e}")
return False
def store_financial_statements(
self,
symbol: str,
date: date,
statements: dict[str, FinancialStatement],
) -> tuple[date, dict[str, FinancialStatement]]:
"""
Store financial statements for a symbol and date.
Args:
symbol: Stock ticker symbol
date: Date of the financial statements
statements: Dictionary of statements keyed by type ("balance_sheet", "income_statement", "cash_flow")
Returns:
Tuple[date, dict[str, FinancialStatement]]: The stored date and statements
"""
# Create symbol directory
symbol_dir = self.fundamental_data_dir / symbol
self._ensure_path_exists(symbol_dir)
# Create JSON file path
file_path = symbol_dir / f"{date.isoformat()}.json"
try:
# Prepare data for JSON serialization
statements_data = {}
for statement_type, statement in statements.items():
statements_data[statement_type] = asdict(statement)
data = {
"symbol": symbol,
"date": date.isoformat(),
"financial_statements": statements_data,
"metadata": {
"stored_at": date.today().isoformat(),
"repository": "fundamental_data_repository",
"statement_count": len(statements),
},
}
# Write to JSON file
with open(file_path, "w") as f:
json.dump(data, f, indent=2, default=str)
logger.info(
f"Stored {len(statements)} financial statements for {symbol} on {date}"
)
return (date, statements)
except Exception as e:
logger.error(
f"Error storing financial statements for {symbol} on {date}: {e}"
)
raise

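A minimal round-trip sketch for the repository above. The symbol, dates, figures, and `./data` base directory are illustrative assumptions, not part of the committed code:

```python
# Hypothetical usage of FundamentalDataRepository; all values are made up.
from datetime import date

repo = FundamentalDataRepository(data_dir="./data")

statement = FinancialStatement(
    period="Q1",
    report_date="2024-03-31",
    publish_date="2024-04-25",
    currency="USD",
    data={"Total Assets": 1_000_000.0, "Total Current Liabilities": 250_000.0},
)

# Writes ./data/fundamental_data/AAPL/2024-03-31.json
repo.store_financial_statements("AAPL", date(2024, 3, 31), {"balance_sheet": statement})

# The result is keyed by the per-file dates that fall inside the range
cached = repo.get_financial_data("AAPL", date(2024, 1, 1), date(2024, 12, 31))
for file_date, record in cached.items():
    print(file_date, list(record.financial_statements))
```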
View File

@@ -0,0 +1,136 @@
"""
Repository for historical market data (CSV files).
"""
import logging
from datetime import date
from pathlib import Path
import pandas as pd
from .base import BaseRepository
logger = logging.getLogger(__name__)
class MarketDataRepository(BaseRepository):
"""Repository for accessing historical market data from CSV files."""
def __init__(self, data_dir: str, **kwargs):
"""
Initialize market data repository.
Args:
data_dir: Base directory for market data storage
**kwargs: Additional configuration
"""
self.market_data_dir = Path(data_dir) / "market_data"
self.market_data_dir.mkdir(parents=True, exist_ok=True)
def get_market_data_df(
self, symbol: str, start_date: date, end_date: date
) -> pd.DataFrame:
"""
Get historical market data as DataFrame for a symbol and date range.
Args:
symbol: Stock ticker symbol
start_date: Start date
end_date: End date
Returns:
pd.DataFrame: Market data filtered by date range
"""
csv_path = self.market_data_dir / f"{symbol}.csv"
if not csv_path.exists():
logger.warning(f"No CSV file found for symbol {symbol} at {csv_path}")
return pd.DataFrame()
try:
# Read CSV file
df = pd.read_csv(csv_path)
if df.empty:
logger.warning(f"Empty CSV file for symbol {symbol}")
return pd.DataFrame()
# Convert Date column to date objects for filtering
if "Date" in df.columns:
df["Date"] = pd.to_datetime(df["Date"]).dt.date
# Filter by date range
filtered_df = df[
(df["Date"] >= start_date) & (df["Date"] <= end_date)
].copy()
logger.info(
f"Retrieved {len(filtered_df)} records for {symbol} from {start_date} to {end_date}"
)
return filtered_df
else:
logger.warning(f"No 'Date' column found in {csv_path}")
return df
except Exception as e:
logger.error(f"Error reading CSV file {csv_path}: {e}")
return pd.DataFrame()
def store_marketdata(self, symbol: str, marketdata: pd.DataFrame) -> pd.DataFrame:
"""
Store market data DataFrame to CSV file, appending or replacing existing data.
Args:
symbol: Stock ticker symbol
marketdata: DataFrame with market data to store
Returns:
pd.DataFrame: The combined DataFrame that was stored
"""
if marketdata.empty:
logger.warning(f"Empty DataFrame provided for {symbol}")
return marketdata
csv_path = self.market_data_dir / f"{symbol}.csv"
try:
if csv_path.exists():
# Load existing data
existing_df = pd.read_csv(csv_path)
if not existing_df.empty and "Date" in existing_df.columns:
# Ensure Date columns are in same format for comparison
existing_df["Date"] = pd.to_datetime(
existing_df["Date"]
).dt.strftime("%Y-%m-%d")
marketdata_copy = marketdata.copy()
marketdata_copy["Date"] = pd.to_datetime(
marketdata_copy["Date"]
).dt.strftime("%Y-%m-%d")
# Combine and remove duplicates by Date, keeping newer data
combined_df = pd.concat(
[existing_df, marketdata_copy], ignore_index=True
)
combined_df = combined_df.drop_duplicates(
subset=["Date"], keep="last"
)
combined_df = combined_df.sort_values("Date").reset_index(drop=True)
else:
# Existing file is empty or malformed, use new data
combined_df = marketdata.copy()
else:
# No existing file, use new data
combined_df = marketdata.copy()
# Save to CSV
combined_df.to_csv(csv_path, index=False)
logger.info(
f"Stored {len(marketdata)} records for {symbol}, total records: {len(combined_df)}"
)
return combined_df
except Exception as e:
logger.error(f"Error storing market data for {symbol}: {e}")
raise

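An illustrative store/read cycle for MarketDataRepository; the OHLCV rows below are fabricated sample data, not real quotes:

```python
from datetime import date

import pandas as pd

repo = MarketDataRepository(data_dir="./data")

df = pd.DataFrame(
    {
        "Date": ["2024-01-02", "2024-01-03"],
        "Open": [187.15, 184.22],
        "Close": [185.64, 184.25],
        "Volume": [82_488_700, 58_414_500],
    }
)

# Creates or merges ./data/market_data/AAPL.csv, deduplicating on Date
repo.store_marketdata("AAPL", df)

window = repo.get_market_data_df("AAPL", date(2024, 1, 1), date(2024, 1, 31))
print(window)
```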
View File

@@ -0,0 +1,305 @@
"""
Repository for historical news data (cached files).
"""
import json
import logging
from dataclasses import asdict, dataclass, field
from datetime import date
from pathlib import Path
from .base import BaseRepository
logger = logging.getLogger(__name__)
@dataclass
class NewsArticle:
"""Represents a news article."""
headline: str
url: str # Unique identifier for deduplication
source: str # "Finnhub", "Google News", etc.
published_date: date
# Optional fields
summary: str | None = None
entities: list[str] = field(default_factory=list)
sentiment_score: float | None = None
author: str | None = None
category: str | None = None
@dataclass
class NewsData:
"""Container for news data with metadata."""
query: str
date: date
source: str # "finnhub", "google_news"
articles: list[NewsArticle]
class NewsRepository(BaseRepository):
"""Repository for accessing cached news data with source separation."""
def __init__(self, data_dir: str, **kwargs):
"""
Initialize news repository.
Args:
data_dir: Base directory for news data storage
**kwargs: Additional configuration
"""
self.news_data_dir = Path(data_dir) / "news_data"
self.news_data_dir.mkdir(parents=True, exist_ok=True)
def get_news_data(
self,
query: str,
start_date: date,
end_date: date,
sources: list[str] | None = None,
) -> dict[date, list[NewsData]]:
"""
Get cached news data for a query and date range across sources.
Args:
query: Search query or symbol
start_date: Start date
end_date: End date
sources: List of sources to check (default: ["finnhub", "google_news"])
Returns:
Dict[date, list[NewsData]]: News data keyed by date, with list of source data
"""
if sources is None:
sources = ["finnhub", "google_news"]
news_data = {}
for source in sources:
source_dir = self.news_data_dir / source / query
if not source_dir.exists():
logger.debug(f"No data directory found for {source}/{query}")
continue
# Scan for JSON files in the source/query directory
for json_file in source_dir.glob("*.json"):
try:
# Parse date from filename (YYYY-MM-DD.json)
date_str = json_file.stem
file_date = date.fromisoformat(date_str)
# Filter by date range
if start_date <= file_date <= end_date:
with open(json_file) as f:
data = json.load(f)
# Create NewsArticle objects from JSON data
articles = []
for article_data in data.get("articles", []):
# Convert date strings back to date objects
article_data_copy = article_data.copy()
if "published_date" in article_data_copy:
article_data_copy["published_date"] = (
date.fromisoformat(
article_data_copy["published_date"]
)
)
article = NewsArticle(**article_data_copy)
articles.append(article)
# Create NewsData container
news_data_item = NewsData(
query=query,
date=file_date,
source=source,
articles=articles,
)
# Group by date (multiple sources per date)
if file_date not in news_data:
news_data[file_date] = []
news_data[file_date].append(news_data_item)
except (ValueError, json.JSONDecodeError, KeyError, TypeError) as e:
logger.error(f"Error reading news data from {json_file}: {e}")
continue
logger.info(
f"Retrieved news data for {len(news_data)} dates for query '{query}'"
)
return news_data
def store_news_articles(
self,
query: str,
date: date,
source: str,
articles: list[NewsArticle],
) -> tuple[date, NewsData]:
"""
Store news articles for a query, date, and source, merging with existing data.
Args:
query: Search query or symbol
date: Date of the news articles
source: News source ("finnhub", "google_news", etc.)
articles: List of news articles
Returns:
Tuple[date, NewsData]: The stored date and news data
"""
# Create source/query directory
source_dir = self.news_data_dir / source / query
self._ensure_path_exists(source_dir)
# Create JSON file path
file_path = source_dir / f"{date.isoformat()}.json"
try:
# Merge with existing articles if file exists
merged_articles = self._merge_articles_with_existing(file_path, articles)
# Prepare data for JSON serialization
articles_data = []
for article in merged_articles:
article_dict = asdict(article)
# Convert date objects to ISO format strings for JSON
if article_dict.get("published_date"):
article_dict["published_date"] = article_dict[
"published_date"
].isoformat()
articles_data.append(article_dict)
data = {
"query": query,
"date": date.isoformat(),
"source": source,
"articles": articles_data,
"metadata": {
"article_count": len(merged_articles),
"stored_at": date.today().isoformat(),
"repository": "news_repository",
},
}
# Write to JSON file
with open(file_path, "w") as f:
json.dump(data, f, indent=2, default=str)
# Create NewsData result
news_data = NewsData(
query=query, date=date, source=source, articles=merged_articles
)
logger.info(
f"Stored {len(articles)} new articles for {query} on {date} from {source} (total: {len(merged_articles)})"
)
return (date, news_data)
except Exception as e:
logger.error(
f"Error storing news articles for {query} on {date} from {source}: {e}"
)
raise
def store_news_data_batch(
self,
query: str,
news_data_by_source: dict[str, dict[date, list[NewsArticle]]],
) -> dict[date, list[NewsData]]:
"""
Store multiple news data sets for a query across sources.
Args:
query: Search query or symbol
news_data_by_source: Nested dict of {source: {date: [articles]}}
Returns:
Dict[date, list[NewsData]]: The stored news data organized by date
"""
stored_data = {}
for source, date_articles in news_data_by_source.items():
for article_date, articles in date_articles.items():
try:
stored_date, stored_news_data = self.store_news_articles(
query, article_date, source, articles
)
# Group by date
if stored_date not in stored_data:
stored_data[stored_date] = []
stored_data[stored_date].append(stored_news_data)
except Exception as e:
logger.error(
f"Failed to store news data for {query} on {article_date} from {source}: {e}"
)
continue
total_dates = len(stored_data)
total_sources = sum(len(news_list) for news_list in stored_data.values())
logger.info(
f"Stored news data for {total_dates} dates, {total_sources} source entries for query '{query}'"
)
return stored_data
def _merge_articles_with_existing(
self, file_path: Path, new_articles: list[NewsArticle]
) -> list[NewsArticle]:
"""
Merge new articles with existing articles, deduplicating by URL.
Args:
file_path: Path to existing JSON file
new_articles: New articles to merge
Returns:
List[NewsArticle]: Merged and deduplicated articles
"""
existing_articles = []
# Load existing articles if file exists
if file_path.exists():
try:
with open(file_path) as f:
data = json.load(f)
for existing_data in data.get("articles", []):
# Convert date strings back to date objects
existing_data_copy = existing_data.copy()
if "published_date" in existing_data_copy:
existing_data_copy["published_date"] = date.fromisoformat(
existing_data_copy["published_date"]
)
existing_article = NewsArticle(**existing_data_copy)
existing_articles.append(existing_article)
except (json.JSONDecodeError, KeyError, ValueError, TypeError) as e:
logger.warning(f"Error reading existing file {file_path}: {e}")
existing_articles = []
# Merge articles, deduplicating by URL (keep newer data)
articles_by_url = {}
# Add existing articles
for article in existing_articles:
articles_by_url[article.url] = article
# Add/update with new articles (they take precedence)
for article in new_articles:
articles_by_url[article.url] = article
# Return as sorted list
merged_articles = list(articles_by_url.values())
merged_articles.sort(
key=lambda x: x.published_date, reverse=True
) # Newest first
return merged_articles

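A short sketch of the URL-based deduplication above: writing the same article twice merges rather than duplicates. The headline and URL are invented for illustration:

```python
from datetime import date

repo = NewsRepository(data_dir="./data")
day = date(2024, 6, 3)

article = NewsArticle(
    headline="Example Corp beats estimates",
    url="https://news.example.com/example-corp-q2",
    source="Google News",
    published_date=day,
)

repo.store_news_articles("EXMP", day, "google_news", [article])
_, merged = repo.store_news_articles("EXMP", day, "google_news", [article])
assert len(merged.articles) == 1  # second write deduplicated by URL
```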
View File

@@ -0,0 +1,304 @@
"""
Repository for social media data (Reddit posts and social media content).
"""
import json
import logging
from dataclasses import asdict, dataclass, field
from datetime import date
from pathlib import Path
from .base import BaseRepository
logger = logging.getLogger(__name__)
@dataclass
class SocialPost:
"""Represents a social media post."""
title: str
content: str
author: str
source: str # "reddit", "twitter", etc.
    platform_id: str  # Unique identifier for deduplication (Reddit post ID, tweet ID, etc.)
created_date: date
# Optional fields
subreddit: str | None = None # Reddit-specific
score: int = 0
comments_count: int = 0
upvote_ratio: float | None = None # Reddit-specific
url: str | None = None
sentiment_score: float | None = None
engagement_score: int = 0
hashtags: list[str] = field(default_factory=list)
mentions: list[str] = field(default_factory=list)
@dataclass
class SocialData:
"""Container for social media data with metadata."""
query: str
date: date
source: str # "reddit", "twitter", etc.
posts: list[SocialPost]
class SocialRepository(BaseRepository):
"""Repository for accessing cached social media data with source separation."""
def __init__(self, data_dir: str, **kwargs):
"""
Initialize social media repository.
Args:
data_dir: Base directory for social media data storage
**kwargs: Additional configuration
"""
self.social_data_dir = Path(data_dir) / "social_data"
self.social_data_dir.mkdir(parents=True, exist_ok=True)
def get_social_data(
self,
query: str,
start_date: date,
end_date: date,
sources: list[str] | None = None,
) -> dict[date, list[SocialData]]:
"""
Get cached social media data for a query and date range across sources.
Args:
query: Search query or symbol
start_date: Start date
end_date: End date
sources: List of sources to check (default: ["reddit"])
Returns:
Dict[date, list[SocialData]]: Social data keyed by date, with list of source data
"""
if sources is None:
sources = ["reddit"]
social_data = {}
for source in sources:
source_dir = self.social_data_dir / source / query
if not source_dir.exists():
logger.debug(f"No data directory found for {source}/{query}")
continue
# Scan for JSON files in the source/query directory
for json_file in source_dir.glob("*.json"):
try:
# Parse date from filename (YYYY-MM-DD.json)
date_str = json_file.stem
file_date = date.fromisoformat(date_str)
# Filter by date range
if start_date <= file_date <= end_date:
with open(json_file) as f:
data = json.load(f)
# Create SocialPost objects from JSON data
posts = []
for post_data in data.get("posts", []):
# Convert date strings back to date objects
post_data_copy = post_data.copy()
if "created_date" in post_data_copy:
post_data_copy["created_date"] = date.fromisoformat(
post_data_copy["created_date"]
)
post = SocialPost(**post_data_copy)
posts.append(post)
# Create SocialData container
social_data_item = SocialData(
query=query, date=file_date, source=source, posts=posts
)
# Group by date (multiple sources per date)
if file_date not in social_data:
social_data[file_date] = []
social_data[file_date].append(social_data_item)
except (ValueError, json.JSONDecodeError, KeyError, TypeError) as e:
logger.error(f"Error reading social data from {json_file}: {e}")
continue
logger.info(
f"Retrieved social data for {len(social_data)} dates for query '{query}'"
)
return social_data
def store_social_posts(
self,
query: str,
date: date,
source: str,
posts: list[SocialPost],
) -> tuple[date, SocialData]:
"""
Store social media posts for a query, date, and source, merging with existing data.
Args:
query: Search query or symbol
date: Date of the social media posts
source: Social media source ("reddit", "twitter", etc.)
posts: List of social media posts
Returns:
Tuple[date, SocialData]: The stored date and social data
"""
# Create source/query directory
source_dir = self.social_data_dir / source / query
self._ensure_path_exists(source_dir)
# Create JSON file path
file_path = source_dir / f"{date.isoformat()}.json"
try:
# Merge with existing posts if file exists
merged_posts = self._merge_posts_with_existing(file_path, posts)
# Prepare data for JSON serialization
posts_data = []
for post in merged_posts:
post_dict = asdict(post)
# Convert date objects to ISO format strings for JSON
if post_dict.get("created_date"):
post_dict["created_date"] = post_dict["created_date"].isoformat()
posts_data.append(post_dict)
data = {
"query": query,
"date": date.isoformat(),
"source": source,
"posts": posts_data,
"metadata": {
"post_count": len(merged_posts),
"stored_at": date.today().isoformat(),
"repository": "social_repository",
},
}
# Write to JSON file
with open(file_path, "w") as f:
json.dump(data, f, indent=2, default=str)
# Create SocialData result
social_data = SocialData(
query=query, date=date, source=source, posts=merged_posts
)
logger.info(
f"Stored {len(posts)} new posts for {query} on {date} from {source} (total: {len(merged_posts)})"
)
return (date, social_data)
except Exception as e:
logger.error(
f"Error storing social posts for {query} on {date} from {source}: {e}"
)
raise
def store_social_data_batch(
self,
query: str,
social_data_by_source: dict[str, dict[date, list[SocialPost]]],
) -> dict[date, list[SocialData]]:
"""
Store multiple social media data sets for a query across sources.
Args:
query: Search query or symbol
social_data_by_source: Nested dict of {source: {date: [posts]}}
Returns:
Dict[date, list[SocialData]]: The stored social data organized by date
"""
stored_data = {}
for source, date_posts in social_data_by_source.items():
for post_date, posts in date_posts.items():
try:
stored_date, stored_social_data = self.store_social_posts(
query, post_date, source, posts
)
# Group by date
if stored_date not in stored_data:
stored_data[stored_date] = []
stored_data[stored_date].append(stored_social_data)
except Exception as e:
logger.error(
f"Failed to store social data for {query} on {post_date} from {source}: {e}"
)
continue
total_dates = len(stored_data)
total_sources = sum(len(social_list) for social_list in stored_data.values())
logger.info(
f"Stored social data for {total_dates} dates, {total_sources} source entries for query '{query}'"
)
return stored_data
def _merge_posts_with_existing(
self, file_path: Path, new_posts: list[SocialPost]
) -> list[SocialPost]:
"""
Merge new posts with existing posts, deduplicating by platform_id.
Args:
file_path: Path to existing JSON file
new_posts: New posts to merge
Returns:
List[SocialPost]: Merged and deduplicated posts
"""
existing_posts = []
# Load existing posts if file exists
if file_path.exists():
try:
with open(file_path) as f:
data = json.load(f)
for existing_data in data.get("posts", []):
# Convert date strings back to date objects
existing_data_copy = existing_data.copy()
if "created_date" in existing_data_copy:
existing_data_copy["created_date"] = date.fromisoformat(
existing_data_copy["created_date"]
)
existing_post = SocialPost(**existing_data_copy)
existing_posts.append(existing_post)
except (json.JSONDecodeError, KeyError, ValueError, TypeError) as e:
logger.warning(f"Error reading existing file {file_path}: {e}")
existing_posts = []
# Merge posts, deduplicating by platform_id (keep newer data)
posts_by_id = {}
# Add existing posts
for post in existing_posts:
posts_by_id[post.platform_id] = post
# Add/update with new posts (they take precedence)
for post in new_posts:
posts_by_id[post.platform_id] = post
# Return as sorted list
merged_posts = list(posts_by_id.values())
merged_posts.sort(key=lambda x: x.created_date, reverse=True) # Newest first
return merged_posts

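The batch API above expects a `{source: {date: [posts]}}` nesting. A sketch with one fabricated Reddit post:

```python
from datetime import date

repo = SocialRepository(data_dir="./data")

post = SocialPost(
    title="DD on EXMP",
    content="Long thesis...",
    author="u/example",
    source="reddit",
    platform_id="t3_abc123",  # deduplication key
    created_date=date(2024, 6, 3),
    subreddit="wallstreetbets",
    score=512,
)

stored = repo.store_social_data_batch("EXMP", {"reddit": {date(2024, 6, 3): [post]}})
print(stored)  # {datetime.date(2024, 6, 3): [SocialData(...)]}
```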
View File

@@ -0,0 +1,13 @@
"""
Service classes for TradingAgents that return Pydantic context objects.
"""
from .base import BaseService
from .market_data_service import MarketDataService
from .news_service import NewsService
__all__ = [
"BaseService",
"MarketDataService",
"NewsService",
]

View File

@@ -0,0 +1,30 @@
"""
Base service class for TradingAgents services.
"""
import logging
from typing import Any
logger = logging.getLogger(__name__)
class BaseService:
"""Base service class with common functionality."""
def __init__(self, online_mode: bool = True, data_dir: str = "data", **kwargs):
"""Initialize base service.
Args:
online_mode: Whether to use live APIs or cached data only
data_dir: Directory for data storage
"""
self.online_mode = online_mode
self.data_dir = data_dir
def is_online(self) -> bool:
"""Check if service is in online mode."""
return self.online_mode
def set_online_mode(self, online: bool) -> None:
"""Set online mode for the service."""
self.online_mode = online

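Concrete services subclass this base and branch on the online flag. A minimal sketch with a hypothetical subclass:

```python
class EchoService(BaseService):
    """Hypothetical service used only to illustrate the online/offline toggle."""

    def get_context(self, symbol: str) -> str:
        if self.is_online():
            return f"{symbol}: would fetch live data"
        return f"{symbol}: would read cached data from {self.data_dir}"


svc = EchoService(online_mode=False, data_dir="./data")
print(svc.get_context("AAPL"))  # AAPL: would read cached data from ./data
```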
View File

@@ -0,0 +1,364 @@
"""
Simple builder functions for dependency injection in TradingAgents services.
"""
import logging
from tradingagents.clients import (
FinnhubClient,
GoogleNewsClient,
RedditClient,
YFinanceClient,
)
from tradingagents.config import TradingAgentsConfig
from tradingagents.repositories import (
FundamentalDataRepository,
InsiderDataRepository,
MarketDataRepository,
NewsRepository,
OpenAIRepository,
SocialRepository,
)
from .fundamental_data_service import FundamentalDataService
from .insider_data_service import InsiderDataService
from .market_data_service import MarketDataService
from .news_service import NewsService
from .openai_data_service import OpenAIDataService
from .social_media_service import SocialMediaService
logger = logging.getLogger(__name__)
def build_market_data_service(config: TradingAgentsConfig) -> MarketDataService:
"""
Build MarketDataService with appropriate client and repository.
Args:
config: TradingAgents configuration
Returns:
MarketDataService: Configured service
"""
client = None
repository = None
# Create client for online mode
if config.online_tools:
try:
client = YFinanceClient()
logger.info("Created YFinanceClient for live data")
except Exception as e:
logger.warning(f"Failed to create YFinanceClient: {e}")
# Always create repository for fallback/offline mode
try:
repository = MarketDataRepository(
data_dir=config.data_dir, cache_dir=config.data_cache_dir
)
logger.info(f"Created MarketDataRepository with data_dir: {config.data_dir}")
except Exception as e:
logger.error(f"Failed to create MarketDataRepository: {e}")
return MarketDataService(
client=client,
repository=repository,
online_mode=config.online_tools,
data_dir=config.data_dir,
)
def build_news_service(config: TradingAgentsConfig) -> NewsService:
"""
Build NewsService with appropriate clients and repository.
Args:
config: TradingAgents configuration
Returns:
NewsService: Configured service
"""
finnhub_client = None
google_client = None
repository = None
# Create clients for online mode
if config.online_tools:
# Finnhub client
if config.finnhub_api_key:
try:
finnhub_client = FinnhubClient(config.finnhub_api_key)
logger.info("Created FinnhubClient for live news data")
except Exception as e:
logger.warning(f"Failed to create FinnhubClient: {e}")
else:
logger.info("No Finnhub API key provided, skipping FinnhubClient")
# Google News client
try:
google_client = GoogleNewsClient()
logger.info("Created GoogleNewsClient for live news data")
except Exception as e:
logger.warning(f"Failed to create GoogleNewsClient: {e}")
# Always create repository for fallback/offline mode
try:
repository = NewsRepository(
data_dir=config.data_dir, cache_dir=config.data_cache_dir
)
logger.info(f"Created NewsRepository with data_dir: {config.data_dir}")
except Exception as e:
logger.error(f"Failed to create NewsRepository: {e}")
return NewsService(
finnhub_client=finnhub_client,
google_client=google_client,
repository=repository,
online_mode=config.online_tools,
data_dir=config.data_dir,
)
def build_social_media_service(config: TradingAgentsConfig) -> SocialMediaService:
"""
Build SocialMediaService with appropriate client and repository.
Args:
config: TradingAgents configuration
Returns:
SocialMediaService: Configured service
"""
client = None
repository = None
# Create client for online mode
if config.online_tools:
# Reddit client
if config.reddit_client_id and config.reddit_client_secret:
try:
client = RedditClient(
client_id=config.reddit_client_id,
client_secret=config.reddit_client_secret,
user_agent=config.reddit_user_agent,
)
logger.info("Created RedditClient for live social media data")
except Exception as e:
logger.warning(f"Failed to create RedditClient: {e}")
else:
logger.info("No Reddit credentials provided, skipping RedditClient")
# Always create repository for fallback/offline mode
try:
repository = SocialRepository(config.data_dir)
logger.info(f"Created SocialRepository with data_dir: {config.data_dir}")
except Exception as e:
logger.error(f"Failed to create SocialRepository: {e}")
return SocialMediaService(
client=client,
repository=repository,
online_mode=config.online_tools,
data_dir=config.data_dir,
)
def build_fundamental_service(config: TradingAgentsConfig) -> FundamentalDataService:
"""
Build FundamentalDataService with appropriate client and repository.
Args:
config: TradingAgents configuration
Returns:
FundamentalDataService: Configured service
"""
client = None
repository = None
# Create client for online mode
if config.online_tools:
# SimFin client (would be implemented when SimFinClient is available)
# try:
# client = SimFinClient() # This would need API key configuration
# logger.info("Created SimFinClient for live fundamental data")
# except Exception as e:
# logger.warning(f"Failed to create SimFinClient: {e}")
logger.info("SimFinClient not yet implemented, using repository only")
# Always create repository for fallback/offline mode
try:
repository = FundamentalDataRepository(config.data_dir)
logger.info(
f"Created FundamentalDataRepository with data_dir: {config.data_dir}"
)
except Exception as e:
logger.error(f"Failed to create FundamentalDataRepository: {e}")
return FundamentalDataService(
        finnhub_client=client,
repository=repository,
online_mode=config.online_tools,
data_dir=config.data_dir,
)
def build_insider_service(config: TradingAgentsConfig) -> InsiderDataService:
"""
Build InsiderDataService with appropriate client and repository.
Args:
config: TradingAgents configuration
Returns:
InsiderDataService: Configured service
"""
client = None
repository = None
# Create client for online mode
if config.online_tools:
# Finnhub client for insider data
if config.finnhub_api_key:
try:
client = FinnhubClient(config.finnhub_api_key)
logger.info("Created FinnhubClient for live insider trading data")
except Exception as e:
logger.warning(f"Failed to create FinnhubClient for insider data: {e}")
else:
logger.info("No Finnhub API key provided, skipping insider data client")
# Always create repository for fallback/offline mode
try:
repository = InsiderDataRepository(config.data_dir)
logger.info(f"Created InsiderDataRepository with data_dir: {config.data_dir}")
except Exception as e:
logger.error(f"Failed to create InsiderDataRepository: {e}")
return InsiderDataService(
finnhub_client=client,
repository=repository,
online_mode=config.online_tools,
data_dir=config.data_dir,
)
def build_openai_service(config: TradingAgentsConfig) -> OpenAIDataService:
"""
Build OpenAIDataService with appropriate client and repository.
Args:
config: TradingAgents configuration
Returns:
OpenAIDataService: Configured service
"""
client = None
repository = None
# Create client for online mode
if config.online_tools:
# OpenAI client (would be implemented when OpenAIClient is available)
# if config.openai_api_key:
# try:
# client = OpenAIClient(api_key=config.openai_api_key)
# logger.info("Created OpenAIClient for AI-powered analysis")
# except Exception as e:
# logger.warning(f"Failed to create OpenAIClient: {e}")
# else:
# logger.info("No OpenAI API key provided, skipping OpenAI client")
logger.info("OpenAIClient not yet implemented, using repository only")
# Always create repository for fallback/offline mode
try:
repository = OpenAIRepository(config.data_dir)
logger.info(f"Created OpenAIRepository with data_dir: {config.data_dir}")
except Exception as e:
logger.error(f"Failed to create OpenAIRepository: {e}")
return OpenAIDataService(
openai_client=client,
repository=repository,
online_mode=config.online_tools,
data_dir=config.data_dir,
)
def build_all_services(config: TradingAgentsConfig) -> dict:
"""
Build all available services.
Args:
config: TradingAgents configuration
Returns:
dict: Dictionary of service name to service instance
"""
services = {}
# Build MarketDataService
try:
services["market_data"] = build_market_data_service(config)
logger.info("Built MarketDataService")
except Exception as e:
logger.error(f"Failed to build MarketDataService: {e}")
# Build NewsService
try:
services["news"] = build_news_service(config)
logger.info("Built NewsService")
except Exception as e:
logger.error(f"Failed to build NewsService: {e}")
# Build SocialMediaService
try:
services["social_media"] = build_social_media_service(config)
logger.info("Built SocialMediaService")
except Exception as e:
logger.error(f"Failed to build SocialMediaService: {e}")
# Build FundamentalDataService
try:
services["fundamental"] = build_fundamental_service(config)
logger.info("Built FundamentalDataService")
except Exception as e:
logger.error(f"Failed to build FundamentalDataService: {e}")
# Build InsiderDataService
try:
services["insider"] = build_insider_service(config)
logger.info("Built InsiderDataService")
except Exception as e:
logger.error(f"Failed to build InsiderDataService: {e}")
# Build OpenAIDataService
try:
services["openai"] = build_openai_service(config)
logger.info("Built OpenAIDataService")
except Exception as e:
logger.error(f"Failed to build OpenAIDataService: {e}")
logger.info(f"Built {len(services)} services: {list(services.keys())}")
return services
def build_toolkit_services(config: TradingAgentsConfig) -> dict:
"""
Build services specifically configured for Toolkit usage.
Args:
config: TradingAgents configuration
Returns:
dict: Dictionary of services for Toolkit
"""
return build_all_services(config)
# Aliases for the service toolkit
create_market_data_service = build_market_data_service
create_news_service = build_news_service
create_social_media_service = build_social_media_service
create_fundamental_data_service = build_fundamental_service
create_insider_data_service = build_insider_service
create_openai_data_service = build_openai_service

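A sketch of the intended wiring at application startup, assuming `TradingAgentsConfig` can be constructed with defaults (it is expected to pick up settings such as ONLINE_TOOLS from the environment):

```python
from tradingagents.config import TradingAgentsConfig

config = TradingAgentsConfig()  # assumption: defaults read from env/.env
services = build_all_services(config)

# Builders log and skip services that fail to construct, so check membership
market = services.get("market_data")
if market is not None:
    context = market.get_context("AAPL", "2024-01-01", "2024-03-31")
    print(context.metadata)
```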
View File

@@ -0,0 +1,692 @@
"""
Fundamental Data Service for aggregating and analyzing financial statement data.
"""
import logging
from datetime import UTC, date, datetime
from typing import Any
from tradingagents.clients import FinnhubClient
from tradingagents.models.context import (
DataQuality,
FinancialStatement,
FundamentalContext,
)
from tradingagents.repositories.fundamental_repository import FundamentalDataRepository
from tradingagents.services.base import BaseService
logger = logging.getLogger(__name__)
class FundamentalDataService(BaseService):
"""Service for fundamental financial data aggregation and analysis."""
    def __init__(
        self,
        finnhub_client: FinnhubClient | None,
        repository: FundamentalDataRepository | None,
        online_mode: bool = True,
        data_dir: str = "data",
        **kwargs,
    ):
        """Initialize Fundamental Data Service.
        Args:
            finnhub_client: Client for Finnhub/financial API access
            repository: Repository for cached fundamental data
            online_mode: Whether to use live APIs or cached data only
            data_dir: Directory for data storage
        """
        super().__init__(online_mode=online_mode, data_dir=data_dir, **kwargs)
self.finnhub_client = finnhub_client
self.repository = repository
def get_fundamental_context(
self,
symbol: str,
start_date: str,
end_date: str,
frequency: str = "quarterly",
force_refresh: bool = False,
**kwargs,
) -> FundamentalContext:
"""Get fundamental analysis context for a company.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
frequency: Reporting frequency ('quarterly' or 'annual')
force_refresh: If True, skip local data and fetch fresh from APIs
Returns:
FundamentalContext with financial statements and key ratios
"""
# Validate date strings first
try:
start_dt = date.fromisoformat(start_date)
end_dt = date.fromisoformat(end_date)
except ValueError as e:
raise ValueError(f"Invalid date format: {e}")
# Check date order
if end_dt < start_dt:
raise ValueError(f"End date {end_date} is before start date {start_date}")
balance_sheet = None
income_statement = None
cash_flow = None
error_info = {}
errors = []
data_source = "unknown"
try:
# Local-first data strategy with force refresh option
if force_refresh:
# Skip local data, fetch fresh from APIs
balance_sheet, income_statement, cash_flow, data_source = (
self._fetch_and_cache_fresh_fundamental_data(
symbol, start_date, end_date, frequency
)
)
else:
# Check local data first, fetch missing if needed
balance_sheet, income_statement, cash_flow, data_source = (
self._get_fundamental_data_local_first(
symbol, start_date, end_date, frequency
)
)
except Exception as e:
logger.error(f"Error fetching fundamental data: {e}")
errors.append(str(e))
# Add error info if there were any errors
if errors:
error_info = {"error": "; ".join(errors)}
# Calculate key financial ratios
key_ratios = self._calculate_key_ratios(
balance_sheet, income_statement, cash_flow
)
# Determine data quality based on data source
data_quality = self._determine_data_quality(
data_source=data_source,
statement_count=sum(
[
balance_sheet is not None,
income_statement is not None,
cash_flow is not None,
]
),
has_errors=bool(errors),
)
# Handle partial data scenarios gracefully
context = self._handle_partial_statements(
symbol=symbol,
start_date=start_date,
end_date=end_date,
frequency=frequency,
balance_sheet=balance_sheet,
income_statement=income_statement,
cash_flow=cash_flow,
key_ratios=key_ratios,
data_quality=data_quality,
data_source=data_source,
force_refresh=force_refresh,
error_info=error_info,
)
return context
def get_context(
self,
symbol: str,
start_date: str,
end_date: str,
frequency: str = "quarterly",
**kwargs,
) -> FundamentalContext:
"""Alias for get_fundamental_context for consistency with other services."""
return self.get_fundamental_context(
symbol=symbol,
start_date=start_date,
end_date=end_date,
frequency=frequency,
**kwargs,
)
def _get_balance_sheet(
self, symbol: str, frequency: str, report_date: date
) -> FinancialStatement | None:
"""Get balance sheet data from client."""
try:
data = self.finnhub_client.get_balance_sheet(symbol, frequency, report_date)
return self._convert_to_financial_statement(data)
except Exception as e:
logger.warning(f"Failed to get balance sheet for {symbol}: {e}")
return None
def _get_income_statement(
self, symbol: str, frequency: str, report_date: date
) -> FinancialStatement | None:
"""Get income statement data from client."""
try:
data = self.finnhub_client.get_income_statement(
symbol, frequency, report_date
)
return self._convert_to_financial_statement(data)
except Exception as e:
logger.warning(f"Failed to get income statement for {symbol}: {e}")
return None
def _get_cash_flow(
self, symbol: str, frequency: str, report_date: date
) -> FinancialStatement | None:
"""Get cash flow statement data from client."""
try:
data = self.finnhub_client.get_cash_flow(symbol, frequency, report_date)
return self._convert_to_financial_statement(data)
except Exception as e:
logger.warning(f"Failed to get cash flow for {symbol}: {e}")
return None
def _convert_to_financial_statement(
self, data: dict[str, Any]
) -> FinancialStatement | None:
"""Convert raw financial data to FinancialStatement object."""
if not data or "data" not in data or not data["data"]:
return None
try:
return FinancialStatement(
period=data.get("period", "Unknown"),
report_date=data.get("report_date", ""),
publish_date=data.get("publish_date", ""),
currency=data.get("currency", "USD"),
data=data["data"],
)
except Exception as e:
logger.warning(f"Failed to convert financial statement: {e}")
return None
def _parse_cached_statements(self, cached_data: dict[str, Any]) -> tuple:
"""Parse cached repository data into financial statements."""
balance_sheet = None
income_statement = None
cash_flow = None
if cached_data and "financial_statements" in cached_data:
statements = cached_data["financial_statements"]
if "balance_sheet" in statements:
balance_sheet = FinancialStatement(**statements["balance_sheet"])
if "income_statement" in statements:
income_statement = FinancialStatement(**statements["income_statement"])
if "cash_flow" in statements:
cash_flow = FinancialStatement(**statements["cash_flow"])
return balance_sheet, income_statement, cash_flow
def _get_fundamental_data_local_first(
self, symbol: str, start_date: str, end_date: str, frequency: str
) -> tuple[
FinancialStatement | None,
FinancialStatement | None,
FinancialStatement | None,
str,
]:
"""Get fundamental data using local-first strategy: check local data first, fetch missing if needed."""
try:
# Check if we have sufficient local data
if self.repository.has_data_for_period(
symbol, start_date, end_date, frequency=frequency
):
logger.info(
f"Using local fundamental data for {symbol} ({start_date} to {end_date})"
)
cached_data = self.repository.get_data(
symbol=symbol,
start_date=start_date,
end_date=end_date,
frequency=frequency,
)
balance_sheet, income_statement, cash_flow = (
self._parse_cached_statements(cached_data)
)
return balance_sheet, income_statement, cash_flow, "local_cache"
# We don't have sufficient local data - need to fetch from APIs
logger.info(
f"Local data insufficient, fetching from APIs for {symbol} ({start_date} to {end_date})"
)
balance_sheet, income_statement, cash_flow, _ = (
self._fetch_fresh_fundamental_data(
symbol, start_date, end_date, frequency
)
)
# Cache the fresh data
if any([balance_sheet, income_statement, cash_flow]):
try:
                    cache_data = {
                        "symbol": symbol,
                        "frequency": frequency,
                        "financial_statements": {},
                        "metadata": {"cached_at": datetime.now(UTC).isoformat()},
}
if balance_sheet:
cache_data["financial_statements"]["balance_sheet"] = (
balance_sheet.model_dump()
)
if income_statement:
cache_data["financial_statements"]["income_statement"] = (
income_statement.model_dump()
)
if cash_flow:
cache_data["financial_statements"]["cash_flow"] = (
cash_flow.model_dump()
)
self.repository.store_data(symbol, cache_data, frequency=frequency)
logger.debug(f"Cached fresh fundamental data for {symbol}")
except Exception as e:
logger.warning(
f"Failed to cache fundamental data for {symbol}: {e}"
)
return balance_sheet, income_statement, cash_flow, "live_api"
except Exception as e:
logger.error(f"Error fetching fundamental data for {symbol}: {e}")
return None, None, None, "error"
def _fetch_and_cache_fresh_fundamental_data(
self, symbol: str, start_date: str, end_date: str, frequency: str
) -> tuple[
FinancialStatement | None,
FinancialStatement | None,
FinancialStatement | None,
str,
]:
"""Force fetch fresh fundamental data from APIs and cache it, bypassing local data."""
try:
logger.info(
f"Force refreshing fundamental data from APIs for {symbol} ({start_date} to {end_date})"
)
# Clear existing data
try:
self.repository.clear_data(
symbol, start_date, end_date, frequency=frequency
)
logger.debug(f"Cleared existing fundamental data for {symbol}")
except Exception as e:
logger.warning(
f"Failed to clear existing fundamental data for {symbol}: {e}"
)
# Fetch fresh data
balance_sheet, income_statement, cash_flow, _ = (
self._fetch_fresh_fundamental_data(
symbol, start_date, end_date, frequency
)
)
# Cache the fresh data
if any([balance_sheet, income_statement, cash_flow]):
try:
                    cache_data = {
                        "symbol": symbol,
                        "frequency": frequency,
                        "financial_statements": {},
                        "metadata": {"refreshed_at": datetime.now(UTC).isoformat()},
}
if balance_sheet:
cache_data["financial_statements"]["balance_sheet"] = (
balance_sheet.model_dump()
)
if income_statement:
cache_data["financial_statements"]["income_statement"] = (
income_statement.model_dump()
)
if cash_flow:
cache_data["financial_statements"]["cash_flow"] = (
cash_flow.model_dump()
)
self.repository.store_data(
symbol, cache_data, frequency=frequency, overwrite=True
)
logger.debug(f"Cached refreshed fundamental data for {symbol}")
except Exception as e:
logger.warning(
f"Failed to cache refreshed fundamental data for {symbol}: {e}"
)
return balance_sheet, income_statement, cash_flow, "live_api_refresh"
except Exception as e:
logger.error(f"Error force refreshing fundamental data for {symbol}: {e}")
return None, None, None, "refresh_error"
def _fetch_fresh_fundamental_data(
self, symbol: str, start_date: str, end_date: str, frequency: str
) -> tuple[
FinancialStatement | None,
FinancialStatement | None,
FinancialStatement | None,
str,
]:
"""Fetch fresh fundamental data from APIs."""
balance_sheet = None
income_statement = None
cash_flow = None
if self.is_online() and self.finnhub_client:
# Parse end_date string to date object for client calls
try:
end_date_obj = date.fromisoformat(end_date)
except ValueError as e:
logger.error(f"Invalid end_date format '{end_date}': {e}")
return balance_sheet, income_statement, cash_flow, "date_error"
# Get financial statements from Finnhub client
balance_sheet = self._get_balance_sheet(symbol, frequency, end_date_obj)
income_statement = self._get_income_statement(
symbol, frequency, end_date_obj
)
cash_flow = self._get_cash_flow(symbol, frequency, end_date_obj)
            return balance_sheet, income_statement, cash_flow, "live_api"
        return balance_sheet, income_statement, cash_flow, "offline"
def _calculate_key_ratios(
self,
balance_sheet: FinancialStatement | None,
income_statement: FinancialStatement | None,
cash_flow: FinancialStatement | None,
) -> dict[str, float]:
"""Calculate key financial ratios from financial statements."""
ratios = {}
try:
# Extract data from statements
bs_data = balance_sheet.data if balance_sheet else {}
is_data = income_statement.data if income_statement else {}
# Liquidity Ratios
if (
"Total Current Assets" in bs_data
and "Total Current Liabilities" in bs_data
):
current_liabilities = bs_data["Total Current Liabilities"]
if current_liabilities > 0:
ratios["current_ratio"] = (
bs_data["Total Current Assets"] / current_liabilities
)
# Quick ratio (more conservative)
if all(
k in bs_data
for k in [
"Cash and Cash Equivalents",
"Short-term Investments",
"Accounts Receivable",
"Total Current Liabilities",
]
):
quick_assets = (
bs_data["Cash and Cash Equivalents"]
+ bs_data.get("Short-term Investments", 0)
+ bs_data["Accounts Receivable"]
)
current_liabilities = bs_data["Total Current Liabilities"]
if current_liabilities > 0:
ratios["quick_ratio"] = quick_assets / current_liabilities
# Cash ratio
if (
"Cash and Cash Equivalents" in bs_data
and "Total Current Liabilities" in bs_data
):
current_liabilities = bs_data["Total Current Liabilities"]
if current_liabilities > 0:
cash_and_equivalents = bs_data[
"Cash and Cash Equivalents"
] + bs_data.get("Short-term Investments", 0)
ratios["cash_ratio"] = cash_and_equivalents / current_liabilities
# Leverage Ratios
if "Long-term Debt" in bs_data and "Total Shareholders Equity" in bs_data:
equity = bs_data["Total Shareholders Equity"]
if equity > 0:
ratios["debt_to_equity"] = bs_data["Long-term Debt"] / equity
if "Long-term Debt" in bs_data and "Total Assets" in bs_data:
assets = bs_data["Total Assets"]
if assets > 0:
ratios["debt_to_assets"] = bs_data["Long-term Debt"] / assets
if "Total Assets" in bs_data and "Total Shareholders Equity" in bs_data:
equity = bs_data["Total Shareholders Equity"]
if equity > 0:
ratios["equity_multiplier"] = bs_data["Total Assets"] / equity
# Profitability Ratios
if "Total Revenue" in is_data and "Cost of Revenue" in is_data:
revenue = is_data["Total Revenue"]
if revenue > 0:
ratios["gross_margin"] = (
revenue - is_data["Cost of Revenue"]
) / revenue
if "Operating Income" in is_data and "Total Revenue" in is_data:
revenue = is_data["Total Revenue"]
if revenue > 0:
ratios["operating_margin"] = is_data["Operating Income"] / revenue
if "Net Income" in is_data and "Total Revenue" in is_data:
revenue = is_data["Total Revenue"]
if revenue > 0:
ratios["net_margin"] = is_data["Net Income"] / revenue
# Return on Equity (ROE)
if "Net Income" in is_data and "Total Shareholders Equity" in bs_data:
equity = bs_data["Total Shareholders Equity"]
if equity > 0:
ratios["roe"] = is_data["Net Income"] / equity
# Return on Assets (ROA)
if "Net Income" in is_data and "Total Assets" in bs_data:
assets = bs_data["Total Assets"]
if assets > 0:
ratios["roa"] = is_data["Net Income"] / assets
# Efficiency Ratios
if "Total Revenue" in is_data and "Total Assets" in bs_data:
assets = bs_data["Total Assets"]
if assets > 0:
ratios["asset_turnover"] = is_data["Total Revenue"] / assets
# Inventory turnover
if "Cost of Revenue" in is_data and "Inventory" in bs_data:
inventory = bs_data["Inventory"]
if inventory > 0:
ratios["inventory_turnover"] = (
is_data["Cost of Revenue"] / inventory
)
# Receivables turnover
if "Total Revenue" in is_data and "Accounts Receivable" in bs_data:
receivables = bs_data["Accounts Receivable"]
if receivables > 0:
ratios["receivables_turnover"] = (
is_data["Total Revenue"] / receivables
)
except Exception as e:
logger.warning(f"Error calculating financial ratios: {e}")
return ratios
def _handle_partial_statements(
self,
symbol: str,
start_date: str,
end_date: str,
frequency: str,
balance_sheet: FinancialStatement | None,
income_statement: FinancialStatement | None,
cash_flow: FinancialStatement | None,
key_ratios: dict[str, float],
data_quality: DataQuality,
data_source: str,
force_refresh: bool,
error_info: dict[str, Any],
) -> FundamentalContext:
"""Create context even if some statements are missing.
- If all statements fail: Raise exception
- If some statements succeed: Return partial context
- Mark missing statements in metadata
"""
statement_count = sum(
[
balance_sheet is not None,
income_statement is not None,
cash_flow is not None,
]
)
# If all statements failed, raise exception
        if statement_count == 0 and data_source != "local_cache":
error_msg = f"Failed to fetch any financial statements for {symbol}"
if error_info:
error_msg += f": {error_info.get('error', 'Unknown error')}"
raise ValueError(error_msg)
# Create metadata with partial data information
metadata = {
"data_quality": data_quality,
"service": "fundamental_data",
"online_mode": self.is_online(),
"frequency": frequency,
"data_source": data_source,
"force_refresh": force_refresh,
"has_balance_sheet": balance_sheet is not None,
"has_income_statement": income_statement is not None,
"has_cash_flow": cash_flow is not None,
"partial_data": statement_count < 3,
"statement_count": statement_count,
**error_info,
}
return FundamentalContext(
symbol=symbol,
period={"start": start_date, "end": end_date},
balance_sheet=balance_sheet,
income_statement=income_statement,
cash_flow=cash_flow,
key_ratios=key_ratios,
metadata=metadata,
)
def detect_fundamental_gaps(
self, symbol: str, start_date: str, end_date: str, frequency: str
) -> list[str]:
"""
Returns list of report dates that need fetching.
        Example: If requesting quarterly from 2024-01-01 to 2024-12-31
        and the cache has Q1 and Q3, returns ["2024-06-30", "2024-12-31"]
For quarterly: Check for Q1 (Mar 31), Q2 (Jun 30), Q3 (Sep 30), Q4 (Dec 31)
For annual: Check for fiscal year ends
"""
try:
start_dt = date.fromisoformat(start_date)
end_dt = date.fromisoformat(end_date)
except ValueError:
logger.error(
f"Invalid date format in gap detection: {start_date}, {end_date}"
)
return []
# Get existing data from repository
try:
cached_data = self.repository.get_data(
symbol, start_date, end_date, frequency
)
existing_dates = set()
if cached_data and "financial_statements" in cached_data:
for statement_type in [
"balance_sheet",
"income_statement",
"cash_flow",
]:
if statement_type in cached_data["financial_statements"]:
stmt = cached_data["financial_statements"][statement_type]
if "report_date" in stmt:
existing_dates.add(stmt["report_date"])
except Exception as e:
logger.warning(f"Error checking cached data for gap detection: {e}")
existing_dates = set()
# Calculate expected report dates based on frequency
expected_dates = []
current_year = start_dt.year
end_year = end_dt.year
if frequency == "quarterly":
# Standard quarterly dates: Mar 31, Jun 30, Sep 30, Dec 31
quarter_dates = [
(3, 31), # Q1
(6, 30), # Q2
(9, 30), # Q3
(12, 31), # Q4
]
for year in range(current_year, end_year + 1):
for month, day in quarter_dates:
report_date = date(year, month, day)
if start_dt <= report_date <= end_dt:
expected_dates.append(report_date.isoformat())
elif frequency == "annual":
# Standard fiscal year end: Dec 31
for year in range(current_year, end_year + 1):
report_date = date(year, 12, 31)
if start_dt <= report_date <= end_dt:
expected_dates.append(report_date.isoformat())
# Return dates that are expected but not in cache
missing_dates = [d for d in expected_dates if d not in existing_dates]
if missing_dates:
logger.info(
f"Gap detection for {symbol}: missing {len(missing_dates)} report periods"
)
return missing_dates
def _determine_data_quality(
self, data_source: str, statement_count: int, has_errors: bool = False
) -> DataQuality:
"""Determine data quality based on source, statement count, and errors."""
if has_errors or statement_count == 0:
return DataQuality.LOW
if data_source in ["local_cache", "error", "refresh_error"]:
return DataQuality.LOW
elif data_source in ["live_api", "live_api_refresh"]:
if statement_count == 3:
return DataQuality.HIGH # All three statements available
elif statement_count == 2:
return DataQuality.MEDIUM # Two statements available
else:
return DataQuality.LOW # One or no statements
else:
return DataQuality.MEDIUM

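A worked example of the ratio arithmetic in `_calculate_key_ratios`, using fabricated statement data (and assuming the Pydantic `FinancialStatement` context model carries the same five fields the service builds in `_convert_to_financial_statement`). The private method is called directly only to show the math:

```python
bs = FinancialStatement(
    period="Q1", report_date="2024-03-31", publish_date="2024-04-25",
    currency="USD",
    data={
        "Total Current Assets": 500_000.0,
        "Total Current Liabilities": 250_000.0,
        "Total Shareholders Equity": 600_000.0,
    },
)
inc = FinancialStatement(
    period="Q1", report_date="2024-03-31", publish_date="2024-04-25",
    currency="USD",
    data={"Total Revenue": 1_000_000.0, "Net Income": 120_000.0},
)

service = FundamentalDataService(finnhub_client=None, repository=None)
ratios = service._calculate_key_ratios(bs, inc, None)
# current_ratio = 500_000 / 250_000   = 2.0
# net_margin    = 120_000 / 1_000_000 = 0.12
# roe           = 120_000 / 600_000   = 0.2
print(ratios)  # {'current_ratio': 2.0, 'net_margin': 0.12, 'roe': 0.2}
```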
View File

@@ -0,0 +1,346 @@
"""
Market data service that provides structured market context.
"""
import logging
from typing import Any
from tradingagents.clients.base import BaseClient
from tradingagents.dataflows.stockstats_utils import StockstatsUtils
from tradingagents.models.context import (
MarketDataContext,
TechnicalIndicatorData,
)
from tradingagents.repositories.base import BaseRepository
from .base import BaseService
logger = logging.getLogger(__name__)
class MarketDataService(BaseService):
"""Service for market data and technical indicators."""
def __init__(
self,
client: BaseClient | None = None,
repository: BaseRepository | None = None,
online_mode: bool = True,
**kwargs,
):
"""
Initialize market data service.
Args:
client: Client for live market data
repository: Repository for historical market data
online_mode: Whether to use live data
**kwargs: Additional configuration
"""
super().__init__(online_mode, **kwargs)
self.client = client
self.repository = repository
self.stockstats_utils = StockstatsUtils()
def get_context(
self,
symbol: str,
start_date: str,
end_date: str,
indicators: list[str] | None = None,
force_refresh: bool = False,
**kwargs,
) -> MarketDataContext:
"""
Get market data context with price data and technical indicators.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
indicators: List of technical indicators to calculate
force_refresh: If True, skip local data and fetch fresh from API
**kwargs: Additional parameters
Returns:
MarketDataContext: Structured market data context
"""
if indicators is None:
indicators = ["rsi", "macd", "close_50_sma"]
# Local-first data strategy with force refresh option
if force_refresh:
# Skip local data, fetch fresh from API
price_data = self._fetch_and_cache_fresh_data(symbol, start_date, end_date)
            data_source = price_data.get("metadata", {}).get("source", "live_api_refresh")
else:
# Check local data first, fetch missing if needed
price_data = self._get_price_data_local_first(symbol, start_date, end_date)
data_source = price_data.get("metadata", {}).get("source", "unknown")
# Calculate technical indicators
technical_indicators = self._calculate_indicators(
symbol, start_date, end_date, indicators
)
# Determine data quality
data_quality = self._determine_data_quality(
data_source=data_source,
record_count=len(price_data.get("data", [])),
has_errors="error" in price_data.get("metadata", {}),
)
# Create metadata
metadata = self._create_base_metadata(
data_quality=data_quality,
price_data_source=data_source,
indicator_count=len(technical_indicators),
symbol=symbol,
force_refresh=force_refresh,
)
return MarketDataContext(
symbol=symbol,
period={"start": start_date, "end": end_date},
price_data=price_data.get("data", []),
technical_indicators=technical_indicators,
metadata=metadata,
)
def get_price_context(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> MarketDataContext:
"""
Get market data context with just price data (no indicators).
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
**kwargs: Additional parameters
Returns:
MarketDataContext: Market context with price data only
"""
return self.get_context(symbol, start_date, end_date, indicators=[], **kwargs)
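    # Usage sketch (assumed wiring; see the service builders for how client and
    # repository are normally constructed):
    #
    #     service = MarketDataService(
    #         client=YFinanceClient(),
    #         repository=MarketDataRepository("./data"),
    #         data_dir="./data",
    #     )
    #     ctx = service.get_context("AAPL", "2024-01-01", "2024-03-31",
    #                               indicators=["rsi", "macd"])
    #
    # ctx.metadata then records the price data source ("local_cache",
    # "live_api", ...) and the number of indicators computed.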
def _get_price_data_local_first(
self, symbol: str, start_date: str, end_date: str
) -> dict[str, Any]:
"""Get price data using local-first strategy: check local data first, fetch missing if needed."""
try:
# Check if we have sufficient local data
if self.repository and self.repository.has_data_for_period(
symbol, start_date, end_date
):
logger.info(
f"Using local data for {symbol} ({start_date} to {end_date})"
)
local_data = self.repository.get_data(
symbol=symbol, start_date=start_date, end_date=end_date
)
local_data["metadata"] = local_data.get("metadata", {})
local_data["metadata"]["source"] = "local_cache"
return local_data
# We don't have sufficient local data - need to fetch from API
if self.client:
logger.info(
f"Local data insufficient, fetching from API for {symbol} ({start_date} to {end_date})"
)
fresh_data = self.client.get_data(
symbol=symbol, start_date=start_date, end_date=end_date
)
# Cache the fresh data if we have a repository
if fresh_data and self.repository:
try:
self.repository.store_data(symbol, fresh_data)
logger.debug(f"Cached fresh data for {symbol}")
except Exception as e:
logger.warning(f"Failed to cache data for {symbol}: {e}")
fresh_data["metadata"] = fresh_data.get("metadata", {})
fresh_data["metadata"]["source"] = "live_api"
return fresh_data
# No client available, try repository as fallback
elif self.repository:
logger.warning(
f"No API client available, using partial local data for {symbol}"
)
local_data = self.repository.get_data(
symbol=symbol, start_date=start_date, end_date=end_date
)
local_data["metadata"] = local_data.get("metadata", {})
local_data["metadata"]["source"] = "local_partial"
return local_data
else:
logger.warning(f"No data source available for {symbol}")
return {
"symbol": symbol,
"data": [],
"metadata": {
"source": "none",
"error": "No client or repository configured",
},
}
except Exception as e:
logger.error(f"Error fetching price data for {symbol}: {e}")
return {
"symbol": symbol,
"data": [],
"metadata": {"source": "error", "error": str(e)},
}
def _fetch_and_cache_fresh_data(
self, symbol: str, start_date: str, end_date: str
) -> dict[str, Any]:
"""Force fetch fresh data from API and cache it, bypassing local data."""
try:
if not self.client:
logger.warning(f"No API client available for force refresh of {symbol}")
return {
"symbol": symbol,
"data": [],
"metadata": {
"source": "no_client",
"error": "No API client configured for force refresh",
},
}
logger.info(
f"Force refreshing data from API for {symbol} ({start_date} to {end_date})"
)
# Clear existing data if we have a repository
if self.repository:
try:
self.repository.clear_data(symbol, start_date, end_date)
logger.debug(f"Cleared existing data for {symbol}")
except Exception as e:
logger.warning(f"Failed to clear existing data for {symbol}: {e}")
# Fetch fresh data
fresh_data = self.client.get_data(
symbol=symbol, start_date=start_date, end_date=end_date
)
# Cache the fresh data
if fresh_data and self.repository:
try:
self.repository.store_data(symbol, fresh_data, overwrite=True)
logger.debug(f"Cached refreshed data for {symbol}")
except Exception as e:
logger.warning(f"Failed to cache refreshed data for {symbol}: {e}")
fresh_data["metadata"] = fresh_data.get("metadata", {})
fresh_data["metadata"]["source"] = "live_api_refresh"
return fresh_data
except Exception as e:
logger.error(f"Error force refreshing data for {symbol}: {e}")
return {
"symbol": symbol,
"data": [],
"metadata": {"source": "refresh_error", "error": str(e)},
}
def _calculate_indicators(
self, symbol: str, start_date: str, end_date: str, indicators: list[str]
) -> dict[str, list[TechnicalIndicatorData]]:
"""Calculate technical indicators."""
if not indicators:
return {}
technical_data = {}
for indicator in indicators:
try:
logger.info(f"Calculating {indicator} for {symbol}")
# Use existing stockstats utility
indicator_data = self._get_indicator_data(
symbol, indicator, start_date, end_date
)
if indicator_data:
technical_data[indicator] = indicator_data
else:
logger.warning(f"No data returned for indicator {indicator}")
except Exception as e:
logger.error(f"Error calculating {indicator} for {symbol}: {e}")
continue
return technical_data
def _get_indicator_data(
self, symbol: str, indicator: str, start_date: str, end_date: str
) -> list[TechnicalIndicatorData]:
"""Get indicator data using StockstatsUtils."""
try:
from datetime import datetime, timedelta
# Get data for the date range
current_date = datetime.strptime(end_date, "%Y-%m-%d")
start_date_dt = datetime.strptime(start_date, "%Y-%m-%d")
indicator_points = []
# Iterate through date range
while current_date >= start_date_dt:
date_str = current_date.strftime("%Y-%m-%d")
try:
# Use stockstats utility to get indicator value
# This assumes the existing data directory structure
data_dir = self.config.get("data_dir", "data")
price_data_dir = f"{data_dir}/market_data/price_data"
indicator_value = StockstatsUtils.get_stock_stats(
symbol,
indicator,
date_str,
price_data_dir,
online=self.online_mode,
)
if indicator_value is not None and indicator_value != "":
# Handle different indicator value types
if isinstance(indicator_value, int | float):
value = float(indicator_value)
elif isinstance(indicator_value, str):
try:
value = float(indicator_value)
except ValueError:
logger.warning(
f"Could not parse indicator value: {indicator_value}"
)
current_date -= timedelta(days=1)
continue
else:
# For complex indicators like MACD, this might be a dict
value = indicator_value
indicator_points.append(
TechnicalIndicatorData(
date=date_str, value=value, indicator_type=indicator
)
)
except Exception as e:
logger.debug(
f"Could not get {indicator} for {symbol} on {date_str}: {e}"
)
current_date -= timedelta(days=1)
# Return in chronological order
return list(reversed(indicator_points))
except Exception as e:
logger.error(f"Error getting indicator data for {indicator}: {e}")
return []
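
A minimal usage sketch of the local-first flow implemented above. The client import path is an assumption for illustration; any BaseClient implementation (including the mock clients in the tests later in this commit) can be substituted:

from tradingagents.repositories.market_data_repository import MarketDataRepository
from tradingagents.services.market_data_service import MarketDataService
from tradingagents.clients.yfinance_client import YFinanceClient  # assumed path, illustration only

service = MarketDataService(
    client=YFinanceClient(),
    repository=MarketDataRepository("data"),
    online_mode=True,
)

# The first call fetches from the API and caches locally; repeat calls for the
# same window are served from the repository unless force_refresh=True.
context = service.get_context(
    "AAPL", "2024-01-01", "2024-03-31", indicators=["rsi", "macd"]
)
print(context.metadata["price_data_source"])  # "live_api" first, "local_cache" after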

View File

@ -0,0 +1,111 @@
"""
News service that provides structured news context.
"""
import logging
from datetime import datetime
from typing import Any
from tradingagents.clients.base import BaseClient
from tradingagents.models.context import (
ArticleData,
NewsContext,
SentimentScore,
)
from tradingagents.repositories.base import BaseRepository
from .base import BaseService
logger = logging.getLogger(__name__)
class NewsService(BaseService):
"""Service for news data and sentiment analysis."""
def __init__(
self,
finnhub_client: BaseClient | None = None,
google_client: BaseClient | None = None,
repository: BaseRepository | None = None,
online_mode: bool = True,
**kwargs,
):
"""
Initialize news service.
Args:
finnhub_client: Client for Finnhub news data
google_client: Client for Google News data
repository: Repository for cached news data
online_mode: Whether to use live data
**kwargs: Additional configuration
"""
super().__init__(online_mode, **kwargs)
self.finnhub_client = finnhub_client
self.google_client = google_client
self.repository = repository
def get_context(
self,
query: str,
start_date: str,
end_date: str,
symbol: str | None = None,
sources: list[str] | None = None,
force_refresh: bool = False,
**kwargs,
) -> NewsContext:
"""
Get news context for a query and date range.
Args:
query: Search query
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
symbol: Stock ticker symbol if company-specific
sources: List of sources to use ('finnhub', 'google', or both)
force_refresh: If True, skip local data and fetch fresh from APIs
**kwargs: Additional parameters
Returns:
NewsContext: Structured news context
"""
        raise NotImplementedError("NewsService.get_context is not implemented yet")
def get_company_news_context(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> NewsContext:
"""
Get news context specific to a company.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
**kwargs: Additional parameters
Returns:
NewsContext: Company-specific news context
"""
        raise NotImplementedError("get_company_news_context is not implemented yet")
def get_global_news_context(
self,
start_date: str,
end_date: str,
categories: list[str] | None = None,
**kwargs,
) -> NewsContext:
"""
Get global/macro news context.
Args:
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
categories: News categories to search
**kwargs: Additional parameters
Returns:
NewsContext: Global news context
"""
        raise NotImplementedError("get_global_news_context is not implemented yet")
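
The three methods above are committed as stubs. As a sketch only (not part of this commit), get_context could follow the same local-first shape as the sibling services; the NewsContext field names and the raw "articles" key below are assumptions modeled on SocialContext:

    # Hypothetical sketch -- field and helper names are assumptions.
    def get_context(self, query, start_date, end_date, symbol=None,
                    sources=None, force_refresh=False, **kwargs):
        articles: list[ArticleData] = []
        error_info: dict[str, str] = {}
        try:
            for name, client in (("finnhub", self.finnhub_client),
                                 ("google", self.google_client)):
                if client is None or (sources and name not in sources):
                    continue
                raw = client.get_data(
                    query=query, start_date=start_date, end_date=end_date
                )
                articles.extend(ArticleData(**a) for a in raw.get("articles", []))
        except Exception as e:
            logger.error(f"Error fetching news for {query}: {e}")
            error_info = {"error": str(e)}
        return NewsContext(
            symbol=symbol,
            period={"start": start_date, "end": end_date},
            articles=articles,
            metadata={"service": "news", "query": query, **error_info},
        )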

View File

@ -0,0 +1,577 @@
"""
Social Media Service for aggregating and analyzing social media data.
"""
import logging
from datetime import UTC, datetime
from typing import Any
from tradingagents.clients.base import BaseClient
from tradingagents.models.context import (
DataQuality,
PostData,
SentimentScore,
SocialContext,
)
from tradingagents.repositories.base import BaseRepository
from tradingagents.services.base import BaseService
logger = logging.getLogger(__name__)
class SocialMediaService(BaseService):
"""Service for social media data aggregation and analysis."""
def __init__(
self,
reddit_client: BaseClient | None = None,
repository: BaseRepository | None = None,
online_mode: bool = True,
data_dir: str = "data",
**kwargs,
):
"""Initialize Social Media Service.
Args:
reddit_client: Client for Reddit API access
repository: Repository for cached social data
online_mode: Whether to fetch live data
data_dir: Directory for data storage
"""
super().__init__(online_mode=online_mode, data_dir=data_dir, **kwargs)
self.reddit_client = reddit_client
self.repository = repository
def get_context(
self,
query: str,
start_date: str,
end_date: str,
symbol: str | None = None,
subreddits: list[str] | None = None,
force_refresh: bool = False,
**kwargs,
) -> SocialContext:
"""Get social media context for a query.
Args:
query: Search query
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
symbol: Optional stock symbol
subreddits: Optional list of subreddits to search
force_refresh: If True, skip local data and fetch fresh from APIs
Returns:
SocialContext with posts and sentiment analysis
"""
posts = []
error_info = {}
data_source = "unknown"
try:
# Local-first data strategy with force refresh option
if force_refresh:
# Skip local data, fetch fresh from APIs
posts, data_source = self._fetch_and_cache_fresh_social_data(
query, start_date, end_date, symbol, subreddits
)
else:
# Check local data first, fetch missing if needed
posts, data_source = self._get_social_data_local_first(
query, start_date, end_date, symbol, subreddits
)
except Exception as e:
logger.error(f"Error fetching social media data: {e}")
error_info = {"error": str(e)}
# Calculate sentiment and engagement metrics
sentiment_summary = self._calculate_sentiment(posts)
engagement_metrics = self._calculate_engagement_metrics(posts)
# Determine data quality based on data source
data_quality = self._determine_data_quality(
data_source=data_source,
record_count=len(posts),
has_errors=bool(error_info),
)
# Separate float metrics from metadata
float_metrics = {
k: v for k, v in engagement_metrics.items() if isinstance(v, int | float)
}
metadata_info = {
k: v
for k, v in engagement_metrics.items()
if not isinstance(v, int | float)
}
return SocialContext(
symbol=symbol,
period={"start": start_date, "end": end_date},
posts=posts,
engagement_metrics=float_metrics,
sentiment_summary=sentiment_summary,
post_count=len(posts),
platforms=["reddit"],
metadata={
"data_quality": data_quality,
"service": "social_media",
"online_mode": self.is_online(),
"subreddits": subreddits or [],
"data_source": data_source,
"force_refresh": force_refresh,
**metadata_info,
**error_info,
},
)
def get_company_social_context(
self,
symbol: str,
start_date: str,
end_date: str,
subreddits: list[str] | None = None,
**kwargs,
) -> SocialContext:
"""Get company-specific social media context.
Args:
symbol: Stock ticker symbol
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
subreddits: Optional list of subreddits
Returns:
SocialContext for the company
"""
return self.get_context(
query=symbol,
start_date=start_date,
end_date=end_date,
symbol=symbol,
subreddits=subreddits,
**kwargs,
)
def get_global_trends(
self,
start_date: str,
end_date: str,
subreddits: list[str] | None = None,
**kwargs,
) -> SocialContext:
"""Get global social media trends.
Args:
start_date: Start date in YYYY-MM-DD format
end_date: End date in YYYY-MM-DD format
subreddits: Optional list of subreddits
Returns:
SocialContext with global trends
"""
posts = []
try:
if self.is_online() and self.reddit_client:
subreddit_list = subreddits or ["news", "worldnews", "Economics"]
# Get top posts from subreddits
raw_posts = self.reddit_client.get_top_posts(
subreddit_names=subreddit_list, limit=50, time_filter="week"
)
# Filter by date
if hasattr(self.reddit_client, "filter_posts_by_date"):
raw_posts = self.reddit_client.filter_posts_by_date(
raw_posts, start_date, end_date
)
posts = self._convert_to_post_data(raw_posts)
except Exception as e:
logger.error(f"Error fetching global trends: {e}")
sentiment_summary = self._calculate_sentiment(posts)
engagement_metrics = self._calculate_engagement_metrics(posts)
# Separate float metrics from metadata
float_metrics = {
k: v for k, v in engagement_metrics.items() if isinstance(v, int | float)
}
metadata_info = {
k: v
for k, v in engagement_metrics.items()
if not isinstance(v, int | float)
}
return SocialContext(
symbol=None, # No specific symbol for global trends
period={"start": start_date, "end": end_date},
posts=posts,
engagement_metrics=float_metrics,
sentiment_summary=sentiment_summary,
post_count=len(posts),
platforms=["reddit"],
metadata={
"data_quality": self._determine_data_quality(
data_source="live_api" if self.is_online() else "offline",
record_count=len(posts),
has_errors=False,
),
"service": "social_media",
"type": "global_trends",
"subreddits": subreddits or [],
**metadata_info,
},
)
def _convert_to_post_data(self, raw_posts: list[dict[str, Any]]) -> list[PostData]:
"""Convert raw Reddit posts to PostData objects."""
posts = []
for post in raw_posts:
try:
# Calculate engagement score
engagement = post.get("upvotes", 0) + post.get("num_comments", 0)
# Get posted date
if "posted_date" in post:
date_str = post["posted_date"]
elif "created_utc" in post:
date_str = datetime.fromtimestamp(post["created_utc"]).strftime(
"%Y-%m-%d"
)
else:
date_str = datetime.now().strftime("%Y-%m-%d")
post_data = PostData(
title=post.get("title", ""),
content=post.get("content", ""),
author=post.get("author", "unknown"),
source=post.get("subreddit", "reddit"),
date=date_str,
url=post.get("url", ""),
score=post.get("score", 0),
comments=post.get("num_comments", 0),
engagement_score=engagement,
subreddit=post.get("subreddit"),
metadata={
"upvotes": post.get("upvotes", 0),
"num_comments": post.get("num_comments", 0),
"subreddit": post.get("subreddit", ""),
},
)
posts.append(post_data)
except Exception as e:
logger.warning(f"Error converting post: {e}")
continue
return posts
def _convert_cached_to_posts(self, cached_data: dict[str, Any]) -> list[PostData]:
"""Convert cached repository data to PostData objects."""
posts = []
if not cached_data or "posts" not in cached_data:
return posts
for post in cached_data.get("posts", []):
try:
posts.append(PostData(**post))
except Exception as e:
logger.warning(f"Error converting cached post: {e}")
return posts
def _get_social_data_local_first(
self,
query: str,
start_date: str,
end_date: str,
symbol: str | None,
subreddits: list[str] | None,
) -> tuple[list[PostData], str]:
"""Get social data using local-first strategy: check local data first, fetch missing if needed."""
try:
# Check if we have sufficient local data
search_key = symbol or query
if self.repository and self.repository.has_data_for_period(
search_key, start_date, end_date, symbol=symbol
):
logger.info(
f"Using local social data for {search_key} ({start_date} to {end_date})"
)
cached_data = self.repository.get_data(
query=search_key, start_date=start_date, end_date=end_date
)
posts = self._convert_cached_to_posts(cached_data)
return posts, "local_cache"
# We don't have sufficient local data - need to fetch from APIs
logger.info(
f"Local data insufficient, fetching from APIs for {search_key} ({start_date} to {end_date})"
)
posts, _ = self._fetch_fresh_social_data(
query, start_date, end_date, symbol, subreddits
)
# Cache the fresh data if we have a repository
if posts and self.repository:
try:
posts_data = [post.model_dump() for post in posts]
cache_data = {
"query": query,
"symbol": symbol,
"posts": posts_data,
"subreddits": subreddits,
"metadata": {"cached_at": datetime.utcnow().isoformat()},
}
self.repository.store_data(search_key, cache_data, symbol=symbol)
logger.debug(f"Cached fresh social data for {search_key}")
except Exception as e:
logger.warning(f"Failed to cache social data for {search_key}: {e}")
return posts, "live_api"
except Exception as e:
logger.error(f"Error fetching social data for {query}: {e}")
return [], "error"
def _fetch_and_cache_fresh_social_data(
self,
query: str,
start_date: str,
end_date: str,
symbol: str | None,
subreddits: list[str] | None,
) -> tuple[list[PostData], str]:
"""Force fetch fresh social data from APIs and cache it, bypassing local data."""
try:
search_key = symbol or query
logger.info(
f"Force refreshing social data from APIs for {search_key} ({start_date} to {end_date})"
)
# Clear existing data if we have a repository
if self.repository:
try:
self.repository.clear_data(
search_key, start_date, end_date, symbol=symbol
)
logger.debug(f"Cleared existing social data for {search_key}")
except Exception as e:
logger.warning(
f"Failed to clear existing social data for {search_key}: {e}"
)
# Fetch fresh data
posts, _ = self._fetch_fresh_social_data(
query, start_date, end_date, symbol, subreddits
)
# Cache the fresh data
if posts and self.repository:
try:
posts_data = [post.model_dump() for post in posts]
cache_data = {
"query": query,
"symbol": symbol,
"posts": posts_data,
"subreddits": subreddits,
"metadata": {"refreshed_at": datetime.utcnow().isoformat()},
}
self.repository.store_data(
search_key, cache_data, symbol=symbol, overwrite=True
)
logger.debug(f"Cached refreshed social data for {search_key}")
except Exception as e:
logger.warning(
f"Failed to cache refreshed social data for {search_key}: {e}"
)
return posts, "live_api_refresh"
except Exception as e:
logger.error(f"Error force refreshing social data for {query}: {e}")
return [], "refresh_error"
def _fetch_fresh_social_data(
self,
query: str,
start_date: str,
end_date: str,
symbol: str | None,
subreddits: list[str] | None,
) -> tuple[list[PostData], str]:
"""Fetch fresh social data from APIs."""
posts = []
if self.is_online() and self.reddit_client:
# Get live Reddit data
subreddit_list = subreddits or ["investing", "stocks", "wallstreetbets"]
# Search for posts
raw_posts = self.reddit_client.search_posts(
query=query,
subreddit_names=subreddit_list,
limit=50,
time_filter="week",
)
# Filter by date
if hasattr(self.reddit_client, "filter_posts_by_date"):
raw_posts = self.reddit_client.filter_posts_by_date(
raw_posts, start_date, end_date
)
# Convert to PostData objects
posts = self._convert_to_post_data(raw_posts)
return posts, "live_api"
def _calculate_sentiment(self, posts: list[PostData]) -> SentimentScore:
"""Calculate overall sentiment from posts."""
if not posts:
return SentimentScore(score=0.0, confidence=0.0, label="neutral")
total_score = 0.0
total_weight = 0.0
for post in posts:
# Simple sentiment analysis based on keywords and engagement
sentiment_score = self._analyze_post_sentiment(post)
# Weight by engagement
weight = 1 + (
post.engagement_score / 1000
) # Higher engagement = more weight
total_score += sentiment_score * weight
total_weight += weight
# Set individual post sentiment
post.sentiment = SentimentScore(
score=sentiment_score,
confidence=0.7, # Moderate confidence for keyword-based analysis
label="positive"
if sentiment_score > 0.2
else "negative"
if sentiment_score < -0.2
else "neutral",
)
# Calculate weighted average
avg_score = total_score / total_weight if total_weight > 0 else 0.0
# Determine label
if avg_score > 0.2:
label = "positive"
elif avg_score < -0.2:
label = "negative"
else:
label = "neutral"
# Confidence based on number of posts
confidence = min(0.9, 0.5 + (len(posts) / 100))
return SentimentScore(score=avg_score, confidence=confidence, label=label)
def _analyze_post_sentiment(self, post: PostData) -> float:
"""Analyze sentiment of a single post."""
text = f"{post.title} {post.content or ''}".lower()
# Simple keyword-based sentiment
positive_words = [
"bullish",
"moon",
"gains",
"buy",
"hold",
"amazing",
"great",
"excellent",
"positive",
"growth",
"beat",
"upgrade",
"🚀",
]
negative_words = [
"bearish",
"crash",
"sell",
"loss",
"decline",
"terrible",
"bad",
"negative",
"downgrade",
"warning",
"overvalued",
]
positive_count = sum(1 for word in positive_words if word in text)
negative_count = sum(1 for word in negative_words if word in text)
# Score from -1 to 1
if positive_count + negative_count == 0:
return 0.0
score = (positive_count - negative_count) / (positive_count + negative_count)
# Adjust for score ratio (upvotes vs downvotes implied)
if post.score > 0:
score_adjustment = min(0.2, post.score / 1000)
score = score * 0.8 + score_adjustment * 0.2
return max(-1.0, min(1.0, score))
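    # Worked example (illustrative): a post titled "AAPL bullish, expecting
    # gains" matches two positive words and no negative ones, so the keyword
    # score is (2 - 0) / (2 + 0) = 1.0; with post.score = 500 the adjustment
    # blend gives 1.0 * 0.8 + min(0.2, 500 / 1000) * 0.2 = 0.84.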
    def _calculate_engagement_metrics(self, posts: list[PostData]) -> dict[str, Any]:
"""Calculate engagement metrics from posts."""
if not posts:
return {
"total_engagement": 0,
"average_engagement": 0,
"max_engagement": 0,
"total_posts": 0,
}
engagements = [post.engagement_score for post in posts]
metrics = {
"total_engagement": sum(engagements),
"average_engagement": sum(engagements) / len(engagements),
"max_engagement": max(engagements),
"total_posts": len(posts),
}
# Add top posts info
sorted_posts = sorted(posts, key=lambda p: p.engagement_score, reverse=True)
metrics["top_posts"] = [
{"title": p.title[:100], "engagement": p.engagement_score}
for p in sorted_posts[:3]
]
return metrics
def _determine_data_quality(
self, data_source: str, record_count: int, has_errors: bool = False
) -> DataQuality:
"""Determine data quality based on source, record count, and errors."""
if has_errors or record_count == 0:
return DataQuality.LOW
if data_source in ["local_cache", "error", "refresh_error"]:
return DataQuality.LOW
elif data_source in ["live_api", "live_api_refresh"]:
if record_count >= 20:
return DataQuality.HIGH
elif record_count >= 5:
return DataQuality.MEDIUM
else:
return DataQuality.LOW
else:
return DataQuality.MEDIUM
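
A small standalone check of the engagement-weighted average used by _calculate_sentiment above (plain arithmetic, no service imports):

# Two posts: keyword score +0.5 with 1000 engagement, -1.0 with none.
scores_and_engagement = [(0.5, 1000), (-1.0, 0)]
total_score = total_weight = 0.0
for score, engagement in scores_and_engagement:
    weight = 1 + engagement / 1000  # same weighting as _calculate_sentiment
    total_score += score * weight
    total_weight += weight
print(total_score / total_weight)  # (0.5*2 - 1.0*1) / 3 = 0.0 -> "neutral"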

View File

@ -0,0 +1,495 @@
"""
Test FundamentalDataService with mock SimFin clients and real FundamentalDataRepository.
"""
import tempfile
from datetime import datetime
from typing import Any
import pytest
from tradingagents.clients.base import BaseClient
from tradingagents.models.context import (
DataQuality,
FinancialStatement,
FundamentalContext,
)
from tradingagents.repositories.fundamental_repository import FundamentalDataRepository
from tradingagents.services.fundamental_data_service import FundamentalDataService
class MockSimFinClient(BaseClient):
"""Mock SimFin client that returns sample financial statement data."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connection_works = True
def test_connection(self) -> bool:
return self.connection_works
def get_data(self, *args, **kwargs) -> dict[str, Any]:
"""Not used directly by FundamentalDataService."""
return {}
def get_balance_sheet(
self, ticker: str, freq: str, curr_date: str
) -> dict[str, Any]:
"""Return mock balance sheet data."""
return {
"ticker": ticker,
"statement_type": "balance_sheet",
"frequency": freq,
"period": "Q3-2024" if freq == "quarterly" else "2024",
"report_date": "2024-09-30",
"publish_date": "2024-10-30",
"currency": "USD",
"data": {
"Total Assets": 365725000000.0,
"Total Current Assets": 143566000000.0,
"Cash and Cash Equivalents": 28965000000.0,
"Short-term Investments": 31590000000.0,
"Accounts Receivable": 13348000000.0,
"Inventory": 6511000000.0,
"Total Non-Current Assets": 222159000000.0,
"Property, Plant & Equipment": 43715000000.0,
"Intangible Assets": 11235000000.0,
"Total Liabilities": 279414000000.0,
"Total Current Liabilities": 136817000000.0,
"Accounts Payable": 58146000000.0,
"Short-term Debt": 20208000000.0,
"Total Non-Current Liabilities": 142597000000.0,
"Long-term Debt": 106550000000.0,
"Total Shareholders Equity": 86311000000.0,
"Retained Earnings": 672000000.0,
"Common Stock": 77958000000.0,
},
"metadata": {
"source": "mock_simfin",
"retrieved_at": datetime(2024, 1, 2).isoformat(),
},
}
def get_income_statement(
self, ticker: str, freq: str, curr_date: str
) -> dict[str, Any]:
"""Return mock income statement data."""
return {
"ticker": ticker,
"statement_type": "income_statement",
"frequency": freq,
"period": "Q3-2024" if freq == "quarterly" else "2024",
"report_date": "2024-09-30",
"publish_date": "2024-10-30",
"currency": "USD",
"data": {
"Total Revenue": 94930000000.0,
"Cost of Revenue": 55720000000.0,
"Gross Profit": 39210000000.0,
"Operating Expenses": 15706000000.0,
"Research and Development": 8067000000.0,
"Sales, General & Administrative": 7639000000.0,
"Operating Income": 23504000000.0,
"Interest Expense": 1013000000.0,
"Other Income": 269000000.0,
"Income Before Tax": 22760000000.0,
"Tax Provision": 4438000000.0,
"Net Income": 18322000000.0,
"Basic EPS": 1.18,
"Diluted EPS": 1.15,
"Shares Outstanding": 15550193000,
},
"metadata": {
"source": "mock_simfin",
"retrieved_at": datetime(2024, 1, 2).isoformat(),
},
}
def get_cash_flow(self, ticker: str, freq: str, curr_date: str) -> dict[str, Any]:
"""Return mock cash flow statement data."""
return {
"ticker": ticker,
"statement_type": "cash_flow",
"frequency": freq,
"period": "Q3-2024" if freq == "quarterly" else "2024",
"report_date": "2024-09-30",
"publish_date": "2024-10-30",
"currency": "USD",
"data": {
"Net Income": 18322000000.0,
"Depreciation & Amortization": 2871000000.0,
"Changes in Working Capital": -1684000000.0,
"Operating Cash Flow": 23302000000.0,
"Capital Expenditures": -2736000000.0,
"Acquisitions": -1800000000.0,
"Asset Sales": 234000000.0,
"Investing Cash Flow": -4302000000.0,
"Dividends Paid": -3746000000.0,
"Share Repurchases": -24979000000.0,
"Debt Proceeds": 750000000.0,
"Debt Repayment": -1500000000.0,
"Financing Cash Flow": -28475000000.0,
"Free Cash Flow": 20566000000.0,
"Net Change in Cash": -9475000000.0,
},
"metadata": {
"source": "mock_simfin",
"retrieved_at": datetime(2024, 1, 2).isoformat(),
},
}
@pytest.fixture
def temp_data_dir():
"""Create a temporary directory for test data and clean up after test."""
with tempfile.TemporaryDirectory(prefix="fundamental_test_") as temp_dir:
yield temp_dir
@pytest.fixture
def mock_simfin_client():
"""Create a mock SimFin client for testing."""
return MockSimFinClient()
@pytest.fixture
def broken_simfin_client():
"""Create a broken SimFin client for error testing."""
class BrokenSimFinClient(BaseClient):
def test_connection(self):
return False
def get_data(self, *args, **kwargs):
raise Exception("SimFin API error")
def get_balance_sheet(self, *args, **kwargs):
raise Exception("SimFin API error")
def get_income_statement(self, *args, **kwargs):
raise Exception("SimFin API error")
def get_cash_flow(self, *args, **kwargs):
raise Exception("SimFin API error")
return BrokenSimFinClient()
@pytest.fixture
def partial_data_client():
"""Create a client that returns partial data for testing."""
class PartialDataClient(MockSimFinClient):
def get_cash_flow(self, ticker, freq, curr_date):
# Simulate missing cash flow data
raise Exception("Cash flow data not available")
def get_income_statement(self, ticker, freq, curr_date):
# Simulate missing income statement
return {"data": {}, "metadata": {"error": "No data found"}}
return PartialDataClient()
def test_online_mode_with_mock_simfin(temp_data_dir, mock_simfin_client):
"""Test FundamentalDataService in online mode with mock SimFin client."""
# Create real repository with temporary directory
real_repo = FundamentalDataRepository(temp_data_dir)
# Create service with mock client and real repository
service = FundamentalDataService(
simfin_client=mock_simfin_client,
repository=real_repo,
data_dir=temp_data_dir,
)
# Test getting fundamental context with all three statements
context = service.get_fundamental_context(
symbol="AAPL",
start_date="2024-01-01",
end_date="2024-12-31",
frequency="quarterly",
force_refresh=True, # Force using mock client instead of cache
)
# Validate context structure
assert isinstance(context, FundamentalContext)
assert context.symbol == "AAPL"
assert context.period["start"] == "2024-01-01"
assert context.period["end"] == "2024-12-31"
# Validate financial statements
assert context.balance_sheet is not None
assert isinstance(context.balance_sheet, FinancialStatement)
assert context.balance_sheet.period == "Q3-2024"
assert context.balance_sheet.currency == "USD"
assert "Total Assets" in context.balance_sheet.data
assert context.income_statement is not None
assert isinstance(context.income_statement, FinancialStatement)
assert "Total Revenue" in context.income_statement.data
assert "Net Income" in context.income_statement.data
assert context.cash_flow is not None
assert isinstance(context.cash_flow, FinancialStatement)
assert "Operating Cash Flow" in context.cash_flow.data
assert "Free Cash Flow" in context.cash_flow.data
# Validate key ratios calculation
assert len(context.key_ratios) > 0
assert "current_ratio" in context.key_ratios
assert "debt_to_equity" in context.key_ratios
assert "roe" in context.key_ratios # Return on Equity
assert "gross_margin" in context.key_ratios
# Validate metadata
assert "data_quality" in context.metadata
assert context.metadata["service"] == "fundamental_data"
# Test JSON serialization
json_output = context.model_dump_json(indent=2)
assert len(json_output) > 0
def test_annual_vs_quarterly_frequency(temp_data_dir, mock_simfin_client):
"""Test different reporting frequencies."""
real_repo = FundamentalDataRepository(temp_data_dir)
service = FundamentalDataService(
simfin_client=mock_simfin_client, repository=real_repo, data_dir=temp_data_dir
)
# Test quarterly
quarterly_context = service.get_fundamental_context(
symbol="MSFT",
start_date="2024-01-01",
end_date="2024-12-31",
frequency="quarterly",
)
assert quarterly_context.balance_sheet is not None
assert quarterly_context.balance_sheet.period == "Q3-2024"
# Test annual
annual_context = service.get_fundamental_context(
symbol="MSFT",
start_date="2024-01-01",
end_date="2024-12-31",
frequency="annual",
)
assert annual_context.balance_sheet is not None
assert annual_context.balance_sheet.period == "2024"
def test_financial_ratio_calculations(temp_data_dir, mock_simfin_client):
"""Test calculation of key financial ratios."""
real_repo = FundamentalDataRepository(temp_data_dir)
service = FundamentalDataService(
simfin_client=mock_simfin_client, repository=real_repo, data_dir=temp_data_dir
)
context = service.get_fundamental_context("TSLA", "2024-01-01", "2024-12-31")
# Check that key ratios are calculated
ratios = context.key_ratios
# Liquidity ratios
assert "current_ratio" in ratios
assert ratios["current_ratio"] > 0
# Leverage ratios
assert "debt_to_equity" in ratios
assert ratios["debt_to_equity"] >= 0
# Profitability ratios
assert "gross_margin" in ratios
assert "operating_margin" in ratios
assert "net_margin" in ratios
assert "roe" in ratios # Return on Equity
assert "roa" in ratios # Return on Assets
# Efficiency ratios
assert "asset_turnover" in ratios
# Validate ratio calculations are reasonable
assert 0 <= ratios["gross_margin"] <= 1
assert 0 <= ratios["net_margin"] <= 1
def test_offline_mode(temp_data_dir):
"""Test FundamentalDataService without a client (offline mode)."""
real_repo = FundamentalDataRepository(temp_data_dir)
service = FundamentalDataService(
simfin_client=None, repository=real_repo, data_dir=temp_data_dir
)
# Should handle offline gracefully
context = service.get_fundamental_context("AAPL", "2024-01-01", "2024-12-31")
assert context.symbol == "AAPL"
assert context.balance_sheet is None # No data available offline
assert context.income_statement is None
assert context.cash_flow is None
assert len(context.key_ratios) == 0
assert context.metadata.get("data_quality") == DataQuality.LOW
def test_partial_data_handling(partial_data_client):
    """Test handling when only some financial statements are available."""
    service = FundamentalDataService(
        simfin_client=partial_data_client, repository=None, online_mode=True
    )
context = service.get_fundamental_context("XYZ", "2024-01-01", "2024-12-31")
# Should have balance sheet but not others
assert context.balance_sheet is not None
assert context.income_statement is None # Failed to load
assert context.cash_flow is None # Failed to load
# Ratios should be limited without full data (only balance sheet ratios available)
assert (
len(context.key_ratios) <= 8
) # Only balance sheet ratios possible, no profitability ratios
assert context.metadata.get("data_quality") == DataQuality.LOW
def test_error_handling(broken_simfin_client):
    """Test error handling with a broken client."""
    service = FundamentalDataService(
        simfin_client=broken_simfin_client, repository=None, online_mode=True
    )
# Should handle errors gracefully
context = service.get_fundamental_context(
"FAIL", "2024-01-01", "2024-12-31", force_refresh=True
)
assert context.symbol == "FAIL"
assert context.balance_sheet is None
assert context.income_statement is None
assert context.cash_flow is None
assert len(context.key_ratios) == 0
assert context.metadata.get("data_quality") == DataQuality.LOW
# Service logs errors but doesn't include them in metadata
def test_json_structure():
"""Test JSON structure of fundamental context."""
mock_simfin = MockSimFinClient()
service = FundamentalDataService(
simfin_client=mock_simfin, repository=None, online_mode=True
)
context = service.get_fundamental_context("NVDA", "2024-01-01", "2024-12-31")
json_data = context.model_dump()
# Validate required fields
required_fields = [
"symbol",
"period",
"balance_sheet",
"income_statement",
"cash_flow",
"key_ratios",
"metadata",
]
for field in required_fields:
assert field in json_data
# Validate financial statement structure
if json_data["balance_sheet"]:
balance_sheet = json_data["balance_sheet"]
required_statement_fields = [
"period",
"report_date",
"publish_date",
"currency",
"data",
]
for field in required_statement_fields:
assert field in balance_sheet
# Check some key balance sheet items
bs_data = balance_sheet["data"]
assert "Total Assets" in bs_data
assert "Total Liabilities" in bs_data
assert "Total Shareholders Equity" in bs_data
# Validate key ratios
ratios = json_data["key_ratios"]
assert isinstance(ratios, dict)
assert len(ratios) > 0
# Validate metadata
metadata = json_data["metadata"]
assert "data_quality" in metadata
assert "service" in metadata
def test_comprehensive_ratio_calculation():
"""Test comprehensive financial ratio calculations."""
mock_simfin = MockSimFinClient()
service = FundamentalDataService(
simfin_client=mock_simfin, repository=None, online_mode=True
)
context = service.get_fundamental_context("COMP", "2024-01-01", "2024-12-31")
ratios = context.key_ratios
    # Not all ratios may be calculable depending on available data
calculated_ratios = set(ratios.keys())
core_ratios = {
"current_ratio",
"debt_to_equity",
"gross_margin",
"net_margin",
"roe",
"roa",
}
# At least the core ratios should be present
assert core_ratios.issubset(calculated_ratios), (
f"Missing core ratios: {core_ratios - calculated_ratios}"
)
# All ratio values should be numbers
for ratio_name, ratio_value in ratios.items():
assert isinstance(ratio_value, int | float), (
f"{ratio_name} should be numeric, got {type(ratio_value)}"
)
assert ratio_value == ratio_value, (
f"{ratio_name} should not be NaN"
) # NaN check
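
For reference, applying the conventional ratio definitions to MockSimFinClient's figures yields values consistent with the assertions above; the service's exact formulas are not shown in this excerpt, so these are the standard textbook definitions:

# Standard definitions applied to the mock statements (USD).
current_ratio = 143_566_000_000.0 / 136_817_000_000.0  # ~1.05
debt_to_equity = 279_414_000_000.0 / 86_311_000_000.0  # ~3.24
gross_margin = 39_210_000_000.0 / 94_930_000_000.0     # ~0.41
net_margin = 18_322_000_000.0 / 94_930_000_000.0       # ~0.19
roe = 18_322_000_000.0 / 86_311_000_000.0              # ~0.21 (net income / equity)
print(current_ratio, debt_to_equity, gross_margin, net_margin, roe)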

View File

@ -0,0 +1,403 @@
#!/usr/bin/env python3
"""
Test InsiderDataService with mock Finnhub client and real InsiderDataRepository.
"""
import os
import sys
from datetime import datetime, timedelta
from typing import Any
# Add the project root to the path
sys.path.insert(0, os.path.abspath("."))
from tradingagents.clients.base import BaseClient
from tradingagents.models.context import DataQuality, InsiderContext, InsiderTransaction
from tradingagents.repositories.insider_repository import InsiderDataRepository
from tradingagents.services.insider_data_service import InsiderDataService
class MockFinnhubClient(BaseClient):
"""Mock Finnhub client that returns sample insider trading data."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connection_works = True
def test_connection(self) -> bool:
return self.connection_works
def get_data(self, *args, **kwargs) -> dict[str, Any]:
"""Not used directly by InsiderDataService."""
return {}
def get_insider_trading(
self, ticker: str, start_date: str, end_date: str
) -> dict[str, Any]:
"""Return mock insider trading data."""
# Use fixed dates within test range for predictable filtering
base_date = datetime(2024, 6, 15) # Within our test range
return {
"ticker": ticker,
"data_type": "insider_trading",
"transactions": [
{
"filingDate": (base_date - timedelta(days=30)).strftime("%Y-%m-%d"),
"name": "John Smith",
"change": -50000,
"sharesTotal": 150000,
"transactionPrice": 180.50,
"transactionCode": "S", # Sale
},
{
"filingDate": (base_date - timedelta(days=20)).strftime("%Y-%m-%d"),
"name": "Jane Doe",
"change": 25000,
"sharesTotal": 75000,
"transactionPrice": 185.25,
"transactionCode": "P", # Purchase
},
{
"filingDate": (base_date - timedelta(days=10)).strftime("%Y-%m-%d"),
"name": "Robert Johnson",
"change": -10000,
"sharesTotal": 40000,
"transactionPrice": 178.75,
"transactionCode": "S", # Sale
},
{
"filingDate": (base_date - timedelta(days=5)).strftime("%Y-%m-%d"),
"name": "Mary Wilson",
"change": 15000,
"sharesTotal": 65000,
"transactionPrice": 182.00,
"transactionCode": "P", # Purchase
},
],
"metadata": {
"source": "mock_finnhub",
"retrieved_at": datetime(2024, 1, 2).isoformat(),
"symbol": ticker,
},
}
def test_online_mode_with_mock_finnhub():
"""Test InsiderDataService in online mode with mock Finnhub client."""
# Create mock client and real repository
mock_finnhub = MockFinnhubClient()
real_repo = InsiderDataRepository("test_data")
# Create service in online mode
service = InsiderDataService(
finnhub_client=mock_finnhub,
repository=real_repo,
online_mode=True,
data_dir="test_data",
)
# Test getting insider context
context = service.get_insider_context(
symbol="AAPL",
start_date="2024-01-01",
end_date="2024-12-31",
force_refresh=True,
)
# Validate context structure
assert isinstance(context, InsiderContext)
assert context.symbol == "AAPL"
assert context.period["start"] == "2024-01-01"
assert context.period["end"] == "2024-12-31"
# Validate transactions
assert len(context.transactions) == 4
assert all(isinstance(tx, InsiderTransaction) for tx in context.transactions)
# Check transaction details
first_tx = context.transactions[0]
assert first_tx.name == "John Smith"
assert first_tx.change == -50000 # Sale
assert first_tx.transaction_code == "S"
assert first_tx.transaction_price == 180.50
# Validate sentiment data and net activity
assert "buy_sell_ratio" in context.sentiment_data
assert "insider_sentiment_score" in context.sentiment_data
assert "net_shares_change" in context.net_activity
assert "net_transaction_value" in context.net_activity
assert "buy_transactions" in context.net_activity
assert "sell_transactions" in context.net_activity
# Validate metadata
assert context.transaction_count == 4
assert "data_quality" in context.metadata
assert context.metadata["service"] == "insider_data"
# Test JSON serialization
json_output = context.model_dump_json(indent=2)
assert len(json_output) > 0
def test_insider_sentiment_analysis():
"""Test insider sentiment calculation based on transactions."""
mock_finnhub = MockFinnhubClient()
service = InsiderDataService(
finnhub_client=mock_finnhub, repository=None, online_mode=True
)
context = service.get_insider_context("TSLA", "2024-01-01", "2024-12-31")
# Check sentiment calculations
sentiment = context.sentiment_data
# Should have buy/sell ratio
assert "buy_sell_ratio" in sentiment
assert sentiment["buy_sell_ratio"] > 0
# Should have insider sentiment score (-1 to 1)
assert "insider_sentiment_score" in sentiment
assert -1.0 <= sentiment["insider_sentiment_score"] <= 1.0
# Check net activity calculations
net_activity = context.net_activity
# Net shares change: sum of all changes
expected_net_change = -50000 + 25000 + (-10000) + 15000 # -20000
assert net_activity["net_shares_change"] == expected_net_change
# Should have buy/sell transaction counts
assert net_activity["buy_transactions"] == 2 # Jane Doe and Mary Wilson
assert net_activity["sell_transactions"] == 2 # John Smith and Robert Johnson
def test_offline_mode():
"""Test InsiderDataService in offline mode."""
real_repo = InsiderDataRepository("test_data")
service = InsiderDataService(
finnhub_client=None, repository=real_repo, online_mode=False
)
# Should handle offline gracefully
context = service.get_insider_context("AAPL", "2024-01-01", "2024-12-31")
assert context.symbol == "AAPL"
assert len(context.transactions) == 0 # No data available offline
assert context.transaction_count == 0
# Sentiment data should have default values even with no transactions
assert context.sentiment_data.get("insider_sentiment_score", 0) == 0
assert context.sentiment_data.get("buy_sell_ratio", 0) == 0
# Net activity should have default values
assert context.net_activity.get("net_shares_change", 0) == 0
assert context.net_activity.get("net_transaction_value", 0) == 0
assert context.metadata.get("data_quality") == DataQuality.LOW
def test_empty_data_handling():
"""Test handling when no insider transactions are available."""
class EmptyDataClient(MockFinnhubClient):
def get_insider_trading(self, ticker, start_date, end_date):
return {
"ticker": ticker,
"data_type": "insider_trading",
"transactions": [],
"metadata": {"source": "mock_finnhub", "empty": True},
}
empty_client = EmptyDataClient()
service = InsiderDataService(
finnhub_client=empty_client, repository=None, online_mode=True
)
context = service.get_insider_context("XYZ", "2024-01-01", "2024-12-31")
# Should handle empty data gracefully
assert context.symbol == "XYZ"
assert len(context.transactions) == 0
assert context.transaction_count == 0
# Sentiment should be neutral with no data
assert context.sentiment_data.get("insider_sentiment_score", 0) == 0
assert context.net_activity.get("net_shares_change", 0) == 0
assert context.metadata.get("data_quality") == DataQuality.LOW
def test_error_handling():
"""Test error handling with broken client."""
class BrokenFinnhubClient(BaseClient):
def test_connection(self):
return False
def get_data(self, *args, **kwargs):
raise Exception("Finnhub API error")
def get_insider_trading(self, *args, **kwargs):
raise Exception("Finnhub API error")
broken_client = BrokenFinnhubClient()
service = InsiderDataService(
finnhub_client=broken_client, repository=None, online_mode=True
)
# Should handle errors gracefully
context = service.get_insider_context(
"FAIL", "2024-01-01", "2024-12-31", force_refresh=True
)
assert context.symbol == "FAIL"
assert len(context.transactions) == 0
assert context.transaction_count == 0
assert context.metadata.get("data_quality") == DataQuality.LOW
# Service logs errors but doesn't include them in metadata
def test_transaction_filtering():
"""Test filtering transactions by date range."""
# Create a client that returns transactions outside the date range
class DateFilterTestClient(MockFinnhubClient):
def get_insider_trading(self, ticker, start_date, end_date):
return {
"ticker": ticker,
"data_type": "insider_trading",
"transactions": [
{
"filingDate": "2023-12-15", # Before start date
"name": "Old Transaction",
"change": -1000,
"sharesTotal": 10000,
"transactionPrice": 100.0,
"transactionCode": "S",
},
{
"filingDate": "2024-06-15", # Within range
"name": "Valid Transaction",
"change": 5000,
"sharesTotal": 15000,
"transactionPrice": 110.0,
"transactionCode": "P",
},
{
"filingDate": "2025-01-15", # After end date
"name": "Future Transaction",
"change": -2000,
"sharesTotal": 8000,
"transactionPrice": 120.0,
"transactionCode": "S",
},
],
"metadata": {"source": "mock_finnhub"},
}
filter_client = DateFilterTestClient()
service = InsiderDataService(
finnhub_client=filter_client, repository=None, online_mode=True
)
context = service.get_insider_context("TEST", "2024-01-01", "2024-12-31")
# Should only include the transaction within the date range
assert len(context.transactions) == 1
assert context.transactions[0].name == "Valid Transaction"
assert context.transaction_count == 1
def test_json_structure():
"""Test JSON structure of insider context."""
mock_finnhub = MockFinnhubClient()
service = InsiderDataService(
finnhub_client=mock_finnhub, repository=None, online_mode=True
)
context = service.get_insider_context("NVDA", "2024-01-01", "2024-12-31")
json_data = context.model_dump()
# Validate required fields
required_fields = [
"symbol",
"period",
"transactions",
"sentiment_data",
"transaction_count",
"net_activity",
"metadata",
]
for field in required_fields:
assert field in json_data
# Validate transaction structure
if json_data["transactions"]:
transaction = json_data["transactions"][0]
required_tx_fields = [
"filing_date",
"name",
"change",
"shares",
"transaction_price",
"transaction_code",
]
for field in required_tx_fields:
assert field in transaction
# Validate sentiment data structure
sentiment = json_data["sentiment_data"]
assert "buy_sell_ratio" in sentiment
assert "insider_sentiment_score" in sentiment
# Validate net activity structure
net_activity = json_data["net_activity"]
expected_net_fields = [
"net_shares_change",
"net_transaction_value",
"buy_transactions",
"sell_transactions",
]
for field in expected_net_fields:
assert field in net_activity
# Validate metadata
metadata = json_data["metadata"]
assert "data_quality" in metadata
assert "service" in metadata
def test_comprehensive_sentiment_calculation():
"""Test comprehensive insider sentiment calculation."""
mock_finnhub = MockFinnhubClient()
service = InsiderDataService(
finnhub_client=mock_finnhub, repository=None, online_mode=True
)
context = service.get_insider_context("COMP", "2024-01-01", "2024-12-31")
# Validate sentiment calculations are reasonable
sentiment = context.sentiment_data
net_activity = context.net_activity
# Buy/sell ratio should be positive (we have both buys and sells)
assert sentiment["buy_sell_ratio"] >= 0
# Insider sentiment score should be between -1 and 1
assert -1.0 <= sentiment["insider_sentiment_score"] <= 1.0
# Net transaction value should be calculated correctly
expected_value = (
(-50000 * 180.50) # John Smith sale
+ (25000 * 185.25) # Jane Doe purchase
+ (-10000 * 178.75) # Robert Johnson sale
+ (15000 * 182.00) # Mary Wilson purchase
)
assert abs(net_activity["net_transaction_value"] - expected_value) < 0.01
# Transaction counts should match
assert net_activity["buy_transactions"] == 2
assert net_activity["sell_transactions"] == 2
# Net shares change
expected_net_shares = -50000 + 25000 + (-10000) + 15000 # -20000
assert net_activity["net_shares_change"] == expected_net_shares

View File

@ -0,0 +1,543 @@
#!/usr/bin/env python3
"""
Test MarketDataService with mock YFinanceClient and real MarketDataRepository.
"""
import json
import os
import sys
from datetime import UTC, datetime, timedelta
from typing import Any
# Add the project root to the path
sys.path.insert(0, os.path.abspath("."))
from tradingagents.clients.base import BaseClient
from tradingagents.models.context import DataQuality, MarketDataContext
from tradingagents.repositories.market_data_repository import MarketDataRepository
from tradingagents.services.market_data_service import MarketDataService
class MockYFinanceClient(BaseClient):
"""Mock Yahoo Finance client that returns predictable test data."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connection_works = True
def test_connection(self) -> bool:
return self.connection_works
def get_data(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
"""Return realistic mock market data."""
# Generate realistic price data
base_price = {"AAPL": 180.0, "TSLA": 250.0, "MSFT": 400.0}.get(symbol, 100.0)
mock_data = []
current_date = datetime.strptime(start_date, "%Y-%m-%d")
end_date_dt = datetime.strptime(end_date, "%Y-%m-%d")
price = base_price
while current_date <= end_date_dt:
            # Simulate a deterministic daily price move in the range -5% to +4%
            price_change = (
                hash(current_date.strftime("%Y-%m-%d")) % 10 - 5
            ) / 100
            price *= 1 + price_change
mock_data.append(
{
"Date": current_date.strftime("%Y-%m-%d %H:%M:%S"),
"Open": round(price * 0.99, 2),
"High": round(price * 1.02, 2),
"Low": round(price * 0.98, 2),
"Close": round(price, 2),
"Adj Close": round(price, 2),
"Volume": 45000000 + (hash(symbol) % 20000000),
}
)
current_date += timedelta(days=1)
return {
"symbol": symbol,
"period": {"start": start_date, "end": end_date},
"data": mock_data,
"metadata": {
"source": "mock_yahoo_finance",
"record_count": len(mock_data),
"columns": [
"Date",
"Open",
"High",
"Low",
"Close",
"Adj Close",
"Volume",
],
"retrieved_at": datetime.utcnow().isoformat(),
},
}
def test_online_mode_with_mock_client():
"""Test MarketDataService in online mode with mock client."""
print("📈 Testing MarketDataService - Online Mode")
# Create mock client and real repository
mock_client = MockYFinanceClient()
real_repo = MarketDataRepository("test_data")
# Create service in online mode
service = MarketDataService(
client=mock_client, repository=real_repo, online_mode=True, data_dir="test_data"
)
try:
# Test basic price context
context = service.get_price_context(
symbol="AAPL", start_date="2024-01-01", end_date="2024-01-05"
)
print(f"✅ Price context created: {context.__class__.__name__}")
print(f" Symbol: {context.symbol}")
print(f" Period: {context.period}")
print(f" Price data records: {len(context.price_data)}")
print(f" Technical indicators: {len(context.technical_indicators)}")
# Validate required fields
assert context.symbol == "AAPL"
assert context.period["start"] == "2024-01-01"
assert context.period["end"] == "2024-01-05"
assert len(context.price_data) > 0
assert "data_quality" in context.metadata
print("✅ Basic validation passed")
# Test JSON serialization
json_output = context.model_dump_json(indent=2)
parsed = json.loads(json_output)
print(f"✅ JSON serialization: {len(json_output)} characters")
print(f" Top-level keys: {list(parsed.keys())}")
# Test with technical indicators
context_with_indicators = service.get_context(
symbol="TSLA",
start_date="2024-01-01",
end_date="2024-01-03",
indicators=["rsi", "macd"],
)
print("✅ Context with indicators created")
print(" Requested indicators: ['rsi', 'macd']")
print(
f" Available indicators: {list(context_with_indicators.technical_indicators.keys())}"
)
return True
except Exception as e:
print(f"❌ Online mode test failed: {e}")
return False
def test_offline_mode_with_real_repository():
"""Test MarketDataService in offline mode with real repository."""
print("\n💾 Testing MarketDataService - Offline Mode")
# Create service in offline mode (no client)
real_repo = MarketDataRepository("test_data")
service = MarketDataService(
client=None, repository=real_repo, online_mode=False, data_dir="test_data"
)
try:
# Test offline context (will likely return empty data)
context = service.get_price_context(
symbol="AAPL", start_date="2024-01-01", end_date="2024-01-05"
)
print(f"✅ Offline context created: {context.__class__.__name__}")
print(f" Symbol: {context.symbol}")
print(f" Price data records: {len(context.price_data)}")
print(f" Data quality: {context.metadata.get('data_quality')}")
print(f" Service mode: online={service.is_online()}")
# Should handle empty data gracefully
assert context.symbol == "AAPL"
assert isinstance(context.price_data, list)
assert "data_quality" in context.metadata
print("✅ Offline mode graceful handling verified")
return True
except Exception as e:
print(f"❌ Offline mode test failed: {e}")
return False
def test_error_handling():
"""Test error handling scenarios."""
print("\n⚠️ Testing Error Handling")
# Test with broken client
class BrokenClient(BaseClient):
def test_connection(self):
return False
def get_data(self, *args, **kwargs):
raise Exception("Simulated client failure")
broken_client = BrokenClient()
real_repo = MarketDataRepository("test_data")
service = MarketDataService(
client=broken_client,
repository=real_repo,
online_mode=True, # Online mode but client will fail
data_dir="test_data",
)
try:
context = service.get_price_context("AAPL", "2024-01-01", "2024-01-05")
print("✅ Error handling worked")
print(f" Symbol: {context.symbol}")
print(f" Price data records: {len(context.price_data)}")
print(f" Data quality: {context.metadata.get('data_quality')}")
# Should fallback to repository or return empty data
assert context.symbol == "AAPL"
assert isinstance(context.price_data, list)
return True
except Exception as e:
print(f"❌ Error handling test failed: {e}")
return False
def test_data_quality_assessment():
"""Test data quality determination logic."""
print("\n🔍 Testing Data Quality Assessment")
mock_client = MockYFinanceClient()
real_repo = MarketDataRepository("test_data")
service = MarketDataService(
client=mock_client, repository=real_repo, online_mode=True, data_dir="test_data"
)
try:
# Test with good data
context = service.get_context("AAPL", "2024-01-01", "2024-01-10")
data_quality = context.metadata.get("data_quality")
print(f"✅ Data quality assessment: {data_quality}")
print(f" Records: {len(context.price_data)}")
print(f" Online mode: {service.is_online()}")
# Should be medium or high quality for mock data
assert data_quality in [DataQuality.MEDIUM, DataQuality.HIGH]
return True
except Exception as e:
print(f"❌ Data quality test failed: {e}")
return False
def test_json_structure_validation():
"""Test detailed JSON structure validation."""
print("\n📄 Testing JSON Structure")
mock_client = MockYFinanceClient()
service = MarketDataService(client=mock_client, repository=None, online_mode=True)
try:
context = service.get_price_context("MSFT", "2024-01-01", "2024-01-03")
json_str = context.model_dump_json(indent=2)
data = json.loads(json_str)
# Validate required structure
required_fields = [
"symbol",
"period",
"price_data",
"technical_indicators",
"metadata",
]
for field in required_fields:
assert field in data, f"Missing field: {field}"
# Validate period structure
period = data["period"]
assert "start" in period and "end" in period
# Validate price data structure
assert isinstance(data["price_data"], list)
if data["price_data"]:
first_record = data["price_data"][0]
required_price_fields = ["Date", "Open", "High", "Low", "Close", "Volume"]
for field in required_price_fields:
assert field in first_record, f"Missing price field: {field}"
# Validate metadata
metadata = data["metadata"]
assert "data_quality" in metadata
assert "service" in metadata
print("✅ JSON structure validation passed")
print(f" Fields: {list(data.keys())}")
print(f" Price records: {len(data['price_data'])}")
print(f" Metadata keys: {list(metadata.keys())}")
return True
except Exception as e:
print(f"❌ JSON structure test failed: {e}")
return False
def test_force_refresh_parameter():
"""Test the force_refresh parameter functionality."""
try:
mock_client = MockYFinanceClient()
real_repo = MarketDataRepository("test_data")
service = MarketDataService(
client=mock_client, repository=real_repo, online_mode=True
)
# Test normal flow (should use repository if available)
normal_context = service.get_context(
"AAPL", "2024-01-01", "2024-01-31", force_refresh=False
)
# Test force refresh (should bypass repository and use client)
refresh_context = service.get_context(
"AAPL", "2024-01-01", "2024-01-31", force_refresh=True
)
# Both should return valid contexts
assert isinstance(normal_context, MarketDataContext)
assert isinstance(refresh_context, MarketDataContext)
assert normal_context.symbol == "AAPL"
assert refresh_context.symbol == "AAPL"
# Check metadata indicates source
refresh_metadata = refresh_context.metadata
assert "force_refresh" in refresh_metadata
assert refresh_metadata["force_refresh"]
print("✅ Force refresh parameter test passed")
return True
except Exception as e:
print(f"❌ Force refresh test failed: {e}")
return False
def test_local_first_strategy():
"""Test that the service checks local data first when available."""
try:
class MockRepositoryWithData(MarketDataRepository):
def has_data_for_period(
self, identifier: str, start_date: str, end_date: str, **kwargs
) -> bool:
return True # Pretend we have the data
            def get_data(
                self, symbol: str, start_date: str, end_date: str, **kwargs
            ) -> dict[str, Any]:
                return {
                    "symbol": symbol,
                    "data": [
                        {"date": "2024-01-01", "close": 150.0},
                        {"date": "2024-01-02", "close": 151.0},
                    ],
                    "metadata": {"source": "test_repository"},
                }
mock_client = MockYFinanceClient()
mock_repo = MockRepositoryWithData("test_data")
service = MarketDataService(
client=mock_client, repository=mock_repo, online_mode=True
)
# Should use local data since repository has_data_for_period returns True
context = service.get_context("TEST", "2024-01-01", "2024-01-31")
# Verify we used local data
assert context.metadata.get("price_data_source") == "local_cache"
assert len(context.price_data) == 2 # From mock repository
print("✅ Local-first strategy test passed")
return True
except Exception as e:
print(f"❌ Local-first strategy test failed: {e}")
return False
def test_local_first_fallback_to_api():
"""Test that service falls back to API when local data is insufficient."""
try:
class MockRepositoryWithoutData(MarketDataRepository):
def has_data_for_period(
self, identifier: str, start_date: str, end_date: str, **kwargs
) -> bool:
return False # Pretend we don't have the data
def get_data(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
return {
"symbol": kwargs.get("symbol", "TEST"),
"data": [],
"metadata": {},
}
def store_data(
self,
symbol: str,
data: dict[str, Any],
overwrite: bool = False,
**kwargs,
) -> bool:
return True # Pretend storage was successful
mock_client = MockYFinanceClient()
mock_repo = MockRepositoryWithoutData("test_data")
service = MarketDataService(
client=mock_client, repository=mock_repo, online_mode=True
)
# Should fall back to API since repository doesn't have data
context = service.get_context("TEST", "2024-01-01", "2024-01-31")
# Verify we used API data
assert context.metadata.get("price_data_source") == "live_api"
assert len(context.price_data) > 0 # From mock client
print("✅ Local-first fallback to API test passed")
return True
except Exception as e:
print(f"❌ Local-first fallback test failed: {e}")
return False
def test_force_refresh_bypasses_local_data():
"""Test that force_refresh=True bypasses local data even when available."""
try:
class MockRepositoryAlwaysHasData(MarketDataRepository):
def has_data_for_period(
self, identifier: str, start_date: str, end_date: str, **kwargs
) -> bool:
return True # Always claim we have data
def get_data(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
return {
"symbol": kwargs.get("symbol", "TEST"),
"data": [
{"date": "2024-01-01", "close": 100.0}
], # Different from client
"metadata": {"source": "local"},
}
def clear_data(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> bool:
return True
def store_data(
self,
symbol: str,
data: dict[str, Any],
overwrite: bool = False,
**kwargs,
) -> bool:
return True
mock_client = MockYFinanceClient()
mock_repo = MockRepositoryAlwaysHasData("test_data")
service = MarketDataService(
client=mock_client, repository=mock_repo, online_mode=True
)
# Force refresh should bypass local data
context = service.get_context(
"TEST", "2024-01-01", "2024-01-31", force_refresh=True
)
# Verify we used API data (force refresh)
assert context.metadata.get("price_data_source") == "live_api_refresh"
assert context.metadata.get("force_refresh")
# Should have more data from client than the single point from repository
assert len(context.price_data) > 1
print("✅ Force refresh bypasses local data test passed")
return True
except Exception as e:
print(f"❌ Force refresh bypass test failed: {e}")
return False
def main():
"""Run all MarketDataService tests."""
print("🧪 Testing MarketDataService\n")
tests = [
test_online_mode_with_mock_client,
test_offline_mode_with_real_repository,
test_error_handling,
test_data_quality_assessment,
test_json_structure_validation,
test_force_refresh_parameter,
test_local_first_strategy,
test_local_first_fallback_to_api,
test_force_refresh_bypasses_local_data,
]
passed = 0
failed = 0
for test in tests:
try:
if test():
passed += 1
else:
failed += 1
except Exception as e:
print(f"❌ Test {test.__name__} crashed: {e}")
failed += 1
print("\n📊 MarketDataService Test Results:")
print(f" ✅ Passed: {passed}")
print(f" ❌ Failed: {failed}")
if failed == 0:
print("🎉 All MarketDataService tests passed!")
else:
print("⚠️ Some tests failed - check output above")
return failed == 0
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)
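# A minimal sketch (an assumption, not the shipped implementation) of the
# local-first dispatch the tests above pin down; the helper name and the
# (data, source_tag) return shape are hypothetical.
def _sketch_resolve_price_data(client, repository, symbol, start, end, force_refresh=False):
    """Illustrative only: resolve data local-first, falling back to the API."""
    if force_refresh:
        # Bypass the cache entirely and tag the result accordingly.
        return client.get_data(symbol, start, end), "live_api_refresh"
    if repository and repository.has_data_for_period(symbol, start, end):
        return repository.get_data(symbol, start, end), "local_cache"
    data = client.get_data(symbol, start, end)
    if repository:
        repository.store_data(symbol, data)  # best-effort cache fill
    return data, "live_api"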

View File

@ -0,0 +1,737 @@
#!/usr/bin/env python3
"""
Test NewsService with mock clients and real NewsRepository.
"""
import json
import os
import sys
from datetime import datetime, timedelta, timezone
from typing import Any
# Add the project root to the path
sys.path.insert(0, os.path.abspath("."))
from tradingagents.clients.base import BaseClient
from tradingagents.models.context import NewsContext, SentimentScore
from tradingagents.repositories.news_repository import NewsRepository
from tradingagents.services.news_service import NewsService
class MockFinnhubClient(BaseClient):
"""Mock Finnhub client that returns sample news data."""
def test_connection(self) -> bool:
return True
def get_data(self, *args, **kwargs) -> dict[str, Any]:
"""Not used directly by NewsService."""
return {}
def get_company_news(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
"""Return mock Finnhub company news."""
return {
"symbol": symbol,
"period": {"start": start_date, "end": end_date},
"articles": [
{
"headline": f"{symbol} Beats Q4 Earnings Expectations",
"summary": f"{symbol} reported earnings of $2.50 per share, beating analyst estimates of $2.25.",
"url": f"https://example.com/finnhub/{symbol.lower()}-earnings",
"source": "Finnhub Financial",
"date": start_date,
"entities": [symbol],
},
{
"headline": f"Insider Trading Activity at {symbol}",
"summary": f"Company executives at {symbol} have increased their holdings by 15% this quarter.",
"url": f"https://example.com/finnhub/{symbol.lower()}-insider",
"source": "Finnhub SEC Filings",
"date": end_date,
"entities": [symbol, "insider trading"],
},
],
"metadata": {
"source": "mock_finnhub",
"article_count": 2,
"retrieved_at": datetime.utcnow().isoformat(),
},
}
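# Both mocks emit the same {"articles": [...], "metadata": {...}} envelope the
# real clients are assumed to produce, so NewsService exercises identical paths.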
class MockGoogleNewsClient(BaseClient):
"""Mock Google News client that returns sample articles."""
def test_connection(self) -> bool:
return True
def get_data(
self, query: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
"""Return mock Google News data."""
article_templates = [
{
"template": "{query} Stock Surges on Positive Outlook",
"summary": "Shares of {query} rose 5% in after-hours trading following strong guidance for next quarter.",
"source": "Mock Market News",
},
{
"template": "Analysts Recommend Buy Rating for {query}",
"summary": "Three major investment firms upgraded {query} to 'Buy' with improved price targets.",
"source": "Mock Investment Daily",
},
{
"template": "{query} Announces Strategic Partnership",
"summary": "The company revealed a new collaboration that could expand market reach significantly.",
"source": "Mock Business Wire",
},
]
articles = []
for i, template in enumerate(article_templates):
current_date = datetime.strptime(start_date, "%Y-%m-%d") + timedelta(days=i)
articles.append(
{
"headline": template["template"].format(query=query),
"summary": template["summary"].format(query=query),
"url": f"https://example.com/google/{query.lower()}-{i}",
"source": template["source"],
"date": current_date.strftime("%Y-%m-%d"),
"entities": [query],
}
)
return {
"query": query,
"period": {"start": start_date, "end": end_date},
"articles": articles,
"metadata": {
"source": "mock_google_news",
"article_count": len(articles),
"retrieved_at": datetime.utcnow().isoformat(),
},
}
def test_online_mode_with_mock_clients():
"""Test NewsService in online mode with mock clients."""
print("📰 Testing NewsService - Online Mode")
# Create mock clients and real repository
mock_finnhub = MockFinnhubClient()
mock_google = MockGoogleNewsClient()
real_repo = NewsRepository("test_data")
# Create service in online mode
service = NewsService(
finnhub_client=mock_finnhub,
google_client=mock_google,
repository=real_repo,
online_mode=True,
data_dir="test_data",
)
try:
# Test company news context
context = service.get_company_news_context(
symbol="AAPL", start_date="2024-01-01", end_date="2024-01-05"
)
print(f"✅ Company news context created: {context.__class__.__name__}")
print(f" Symbol: {context.symbol}")
print(f" Period: {context.period}")
print(f" Articles: {len(context.articles)}")
print(f" Sentiment score: {context.sentiment_summary.score:.3f}")
print(f" Sentiment confidence: {context.sentiment_summary.confidence:.3f}")
print(f" Sources: {context.sources}")
# Validate required fields
assert context.symbol == "AAPL"
assert context.period["start"] == "2024-01-01"
assert context.period["end"] == "2024-01-05"
assert len(context.articles) > 0
        assert -1.0 <= context.sentiment_summary.score <= 1.0
assert "data_quality" in context.metadata
print("✅ Basic validation passed")
# Test JSON serialization
json_output = context.model_dump_json(indent=2)
parsed = json.loads(json_output)
print(f"✅ JSON serialization: {len(json_output)} characters")
print(f" Top-level keys: {list(parsed.keys())}")
return True
except Exception as e:
print(f"❌ Online mode test failed: {e}")
return False
def test_global_news_context():
"""Test global news functionality."""
print("\n🌍 Testing Global News Context")
mock_google = MockGoogleNewsClient()
real_repo = NewsRepository("test_data")
service = NewsService(
finnhub_client=None,
google_client=mock_google,
repository=real_repo,
online_mode=True,
data_dir="test_data",
)
try:
# Test global news with categories
context = service.get_global_news_context(
start_date="2024-01-01",
end_date="2024-01-03",
categories=["economy", "markets"],
)
print("✅ Global news context created")
print(f" Symbol: {context.symbol}") # Should be None for global news
print(f" Articles: {len(context.articles)}")
print(f" Categories searched: {context.metadata.get('categories', [])}")
print(f" Sentiment score: {context.sentiment_summary.score:.3f}")
# Validate global news structure
assert context.symbol is None # Global news shouldn't have a symbol
assert len(context.articles) > 0
assert "categories" in context.metadata
print("✅ Global news validation passed")
return True
except Exception as e:
print(f"❌ Global news test failed: {e}")
return False
def test_offline_mode_with_real_repository():
"""Test NewsService in offline mode with real repository."""
print("\n💾 Testing NewsService - Offline Mode")
# Create service in offline mode (no clients)
real_repo = NewsRepository("test_data")
service = NewsService(
finnhub_client=None,
google_client=None,
repository=real_repo,
online_mode=False,
data_dir="test_data",
)
try:
# Test offline context (will likely return empty data)
context = service.get_company_news_context(
symbol="AAPL", start_date="2024-01-01", end_date="2024-01-05"
)
print(f"✅ Offline context created: {context.__class__.__name__}")
print(f" Symbol: {context.symbol}")
print(f" Articles: {len(context.articles)}")
print(f" Data quality: {context.metadata.get('data_quality')}")
print(f" Service mode: online={service.is_online()}")
# Should handle empty data gracefully
assert context.symbol == "AAPL"
assert isinstance(context.articles, list)
assert isinstance(context.sentiment_summary, SentimentScore)
assert "data_quality" in context.metadata
print("✅ Offline mode graceful handling verified")
return True
except Exception as e:
print(f"❌ Offline mode test failed: {e}")
return False
def test_sentiment_analysis():
"""Test sentiment analysis functionality."""
print("\n😊 Testing Sentiment Analysis")
# Create service with custom articles for sentiment testing
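    # One clearly positive and one clearly negative article are supplied, so
    # only the documented bounds are asserted below ([-1, 1] score, [0, 1]
    # confidence), not exact scorer output.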
class SentimentTestClient(BaseClient):
def test_connection(self):
return True
def get_data(self, query, start_date, end_date, **kwargs):
return {
"query": query,
"articles": [
{
"headline": f"{query} Soars on Excellent Earnings Report",
"summary": "Great performance with strong growth and positive outlook for investors.",
"source": "Positive News",
"date": start_date,
"entities": [query],
},
{
"headline": f"{query} Faces Challenges in Market Downturn",
"summary": "Concerns about declining revenue and poor market conditions affecting performance.",
"source": "Negative News",
"date": end_date,
"entities": [query],
},
],
}
sentiment_client = SentimentTestClient()
service = NewsService(
finnhub_client=None,
google_client=sentiment_client,
repository=None,
online_mode=True,
)
try:
context = service.get_context(
"TEST", "2024-01-01", "2024-01-02", sources=["google"]
)
print("✅ Sentiment analysis completed")
print(f" Articles analyzed: {len(context.articles)}")
print(f" Overall sentiment: {context.sentiment_summary.score:.3f}")
print(f" Confidence: {context.sentiment_summary.confidence:.3f}")
print(f" Label: {context.sentiment_summary.label}")
# Validate sentiment processing
assert len(context.articles) == 2
        assert -1.0 <= context.sentiment_summary.score <= 1.0
        assert 0.0 <= context.sentiment_summary.confidence <= 1.0
assert context.sentiment_summary.label in ["positive", "negative", "neutral"]
# Check individual article sentiments
for article in context.articles:
if article.sentiment:
                assert -1.0 <= article.sentiment.score <= 1.0
print("✅ Sentiment analysis validation passed")
return True
except Exception as e:
print(f"❌ Sentiment analysis test failed: {e}")
return False
def test_multiple_source_aggregation():
"""Test aggregation from multiple news sources."""
print("\n🔄 Testing Multiple Source Aggregation")
mock_finnhub = MockFinnhubClient()
mock_google = MockGoogleNewsClient()
real_repo = NewsRepository("test_data")
service = NewsService(
finnhub_client=mock_finnhub,
google_client=mock_google,
repository=real_repo,
online_mode=True,
)
try:
# Test with both sources
context = service.get_context(
query="MSFT",
start_date="2024-01-01",
end_date="2024-01-03",
symbol="MSFT",
sources=["finnhub", "google"],
)
print("✅ Multi-source aggregation completed")
print(f" Total articles: {len(context.articles)}")
print(f" Unique sources: {context.sources}")
print(f" Sources used: {context.metadata.get('sources_used', [])}")
# Should have articles from both sources
assert len(context.articles) > 0
assert len(context.sources) > 0
# Check that articles from different sources are present
source_counts = {}
for article in context.articles:
source = article.source
source_counts[source] = source_counts.get(source, 0) + 1
print(f" Source distribution: {source_counts}")
print("✅ Multi-source aggregation validated")
return True
except Exception as e:
print(f"❌ Multi-source test failed: {e}")
return False
def test_json_structure_validation():
"""Test detailed JSON structure validation."""
print("\n📄 Testing JSON Structure")
mock_google = MockGoogleNewsClient()
service = NewsService(
finnhub_client=None,
google_client=mock_google,
repository=None,
online_mode=True,
)
try:
context = service.get_context(
"TSLA", "2024-01-01", "2024-01-03", sources=["google"]
)
json_str = context.model_dump_json(indent=2)
data = json.loads(json_str)
# Validate required structure
required_fields = [
"symbol",
"period",
"articles",
"sentiment_summary",
"article_count",
"sources",
"metadata",
]
for field in required_fields:
assert field in data, f"Missing field: {field}"
# Validate period structure
period = data["period"]
assert "start" in period and "end" in period
# Validate articles structure
assert isinstance(data["articles"], list)
if data["articles"]:
first_article = data["articles"][0]
required_article_fields = ["headline", "source", "date"]
for field in required_article_fields:
assert field in first_article, f"Missing article field: {field}"
# Validate sentiment structure
sentiment = data["sentiment_summary"]
        assert "score" in sentiment
        assert "confidence" in sentiment
        assert "label" in sentiment
assert -1.0 <= sentiment["score"] <= 1.0
assert 0.0 <= sentiment["confidence"] <= 1.0
# Validate metadata
metadata = data["metadata"]
assert "data_quality" in metadata
assert "service" in metadata
print("✅ JSON structure validation passed")
print(f" Fields: {list(data.keys())}")
print(f" Articles: {len(data['articles'])}")
print(f" Sentiment score: {sentiment['score']:.3f}")
return True
except Exception as e:
print(f"❌ JSON structure test failed: {e}")
return False
def test_force_refresh_parameter():
"""Test the force_refresh parameter functionality."""
print("\n🔄 Testing Force Refresh Parameter")
try:
mock_google = MockGoogleNewsClient()
real_repo = NewsRepository("test_data")
service = NewsService(
finnhub_client=None,
google_client=mock_google,
repository=real_repo,
online_mode=True,
)
# Test normal flow (should use repository if available)
normal_context = service.get_context(
"AAPL", "2024-01-01", "2024-01-31", sources=["google"], force_refresh=False
)
# Test force refresh (should bypass repository and use client)
refresh_context = service.get_context(
"AAPL", "2024-01-01", "2024-01-31", sources=["google"], force_refresh=True
)
# Both should return valid contexts
assert isinstance(normal_context, NewsContext)
assert isinstance(refresh_context, NewsContext)
assert normal_context.symbol == "AAPL"
assert refresh_context.symbol == "AAPL"
# Check metadata indicates source
refresh_metadata = refresh_context.metadata
assert "force_refresh" in refresh_metadata
assert refresh_metadata["force_refresh"]
print("✅ Force refresh parameter test passed")
return True
except Exception as e:
print(f"❌ Force refresh test failed: {e}")
return False
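# As in the MarketDataService tests, the three tests below pin down the assumed
# local-first contract for news retrieval.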
def test_local_first_strategy():
"""Test that the service checks local data first when available."""
print("\n🏠 Testing Local-First Strategy")
try:
class MockRepositoryWithData(NewsRepository):
def has_data_for_period(
self, identifier: str, start_date: str, end_date: str, **kwargs
) -> bool:
return True # Pretend we have the data
def get_data(
self, query: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
return {
"query": kwargs.get("query", "TEST"),
"symbol": kwargs.get("symbol"),
"articles": [
{
"headline": "Test Article from Local Cache",
"summary": "This article came from local repository",
"source": "Local Cache",
"date": "2024-01-01",
"url": "https://local.cache/test",
"entities": ["TEST"],
}
],
"metadata": {"source": "test_repository"},
}
mock_client = MockGoogleNewsClient()
mock_repo = MockRepositoryWithData("test_data")
service = NewsService(
finnhub_client=None,
google_client=mock_client,
repository=mock_repo,
online_mode=True,
)
# Should use local data since repository has_data_for_period returns True
context = service.get_context(
"TEST", "2024-01-01", "2024-01-31", sources=["google"]
)
# Verify we used local data
assert context.metadata.get("data_source") == "local_cache"
assert len(context.articles) == 1 # From mock repository
assert context.articles[0].headline == "Test Article from Local Cache"
print("✅ Local-first strategy test passed")
return True
except Exception as e:
print(f"❌ Local-first strategy test failed: {e}")
return False
def test_local_first_fallback_to_api():
"""Test that service falls back to API when local data is insufficient."""
print("\n🔄 Testing Local-First Fallback to API")
try:
class MockRepositoryWithoutData(NewsRepository):
def has_data_for_period(
self, identifier: str, start_date: str, end_date: str, **kwargs
) -> bool:
return False # Pretend we don't have the data
def get_data(
self, query: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
return {
"query": kwargs.get("query", "TEST"),
"articles": [],
"metadata": {},
}
def store_data(
self,
symbol: str,
data: dict[str, Any],
overwrite: bool = False,
**kwargs,
) -> bool:
return True # Pretend storage was successful
mock_client = MockGoogleNewsClient()
mock_repo = MockRepositoryWithoutData("test_data")
service = NewsService(
finnhub_client=None,
google_client=mock_client,
repository=mock_repo,
online_mode=True,
)
# Should fall back to API since repository doesn't have data
context = service.get_context(
"TEST", "2024-01-01", "2024-01-31", sources=["google"]
)
# Verify we used API data
assert context.metadata.get("data_source") == "live_api"
assert len(context.articles) > 0 # From mock client
print("✅ Local-first fallback to API test passed")
return True
except Exception as e:
print(f"❌ Local-first fallback test failed: {e}")
return False
def test_force_refresh_bypasses_local_data():
"""Test that force_refresh=True bypasses local data even when available."""
print("\n⚡ Testing Force Refresh Bypasses Local Data")
try:
class MockRepositoryAlwaysHasData(NewsRepository):
def has_data_for_period(
self, identifier: str, start_date: str, end_date: str, **kwargs
) -> bool:
return True # Always claim we have data
def get_data(
self, query: str, start_date: str, end_date: str, **kwargs
) -> dict[str, Any]:
return {
"query": kwargs.get("query", "TEST"),
"symbol": kwargs.get("symbol"),
"articles": [
{
"headline": "Old Cached Article",
"summary": "This is from local cache",
"source": "Local Cache",
"date": "2024-01-01",
"url": "https://cache.local/old",
"entities": ["TEST"],
}
],
"metadata": {"source": "local"},
}
def clear_data(
self, symbol: str, start_date: str, end_date: str, **kwargs
) -> bool:
return True
def store_data(
self,
symbol: str,
data: dict[str, Any],
overwrite: bool = False,
**kwargs,
) -> bool:
return True
mock_client = MockGoogleNewsClient()
mock_repo = MockRepositoryAlwaysHasData("test_data")
service = NewsService(
finnhub_client=None,
google_client=mock_client,
repository=mock_repo,
online_mode=True,
)
# Force refresh should bypass local data
context = service.get_context(
"TEST", "2024-01-01", "2024-01-31", sources=["google"], force_refresh=True
)
# Verify we used API data (force refresh)
assert context.metadata.get("data_source") == "live_api_refresh"
assert context.metadata.get("force_refresh")
# Should have fresh data from client, not the old cached article
assert len(context.articles) > 1 # Client returns multiple articles
print("✅ Force refresh bypasses local data test passed")
return True
except Exception as e:
print(f"❌ Force refresh bypass test failed: {e}")
return False
def main():
"""Run all NewsService tests."""
print("🧪 Testing NewsService\n")
tests = [
test_online_mode_with_mock_clients,
test_global_news_context,
test_offline_mode_with_real_repository,
test_sentiment_analysis,
test_multiple_source_aggregation,
test_json_structure_validation,
test_force_refresh_parameter,
test_local_first_strategy,
test_local_first_fallback_to_api,
test_force_refresh_bypasses_local_data,
]
passed = 0
failed = 0
for test in tests:
try:
if test():
passed += 1
else:
failed += 1
except Exception as e:
print(f"❌ Test {test.__name__} crashed: {e}")
failed += 1
print("\n📊 NewsService Test Results:")
print(f" ✅ Passed: {passed}")
print(f" ❌ Failed: {failed}")
if failed == 0:
print("🎉 All NewsService tests passed!")
else:
print("⚠️ Some tests failed - check output above")
return failed == 0
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)
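# Rough shape of the aggregation these tests assume NewsService.get_context()
# performs (a sketch, not the actual implementation; aggregate_sentiment is a
# hypothetical helper):
#
#     articles = []
#     if "finnhub" in sources and self.finnhub_client:
#         articles += self.finnhub_client.get_company_news(symbol, start, end)["articles"]
#     if "google" in sources and self.google_client:
#         articles += self.google_client.get_data(query, start, end)["articles"]
#     context.sentiment_summary = aggregate_sentiment(articles)
#     context.metadata["sources_used"] = sources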

View File

@ -0,0 +1,378 @@
#!/usr/bin/env python3
"""
Test SocialMediaService with mock RedditClient and real SocialRepository.
"""
import os
import sys
from datetime import datetime, timedelta
from typing import Any
# Add the project root to the path
sys.path.insert(0, os.path.abspath("."))
from tradingagents.clients.base import BaseClient
from tradingagents.models.context import (
DataQuality,
PostData,
SentimentScore,
SocialContext,
)
from tradingagents.repositories.social_repository import SocialRepository
from tradingagents.services.social_media_service import SocialMediaService
class MockRedditClient(BaseClient):
"""Mock Reddit client that returns sample social media data."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.connection_works = True
def test_connection(self) -> bool:
return self.connection_works
def get_data(self, *args, **kwargs) -> dict[str, Any]:
"""Not used directly by SocialMediaService."""
return {}
def search_posts(
self,
query: str,
subreddit_names: list[str],
limit: int = 25,
time_filter: str = "week",
) -> list[dict[str, Any]]:
"""Return mock Reddit search results."""
posts = []
# Use fixed dates that will work with our date filter
base_date = datetime(2024, 1, 2) # Within our test range
for i, subreddit in enumerate(
subreddit_names[:2]
): # Limit to 2 subreddits for testing
posts.extend(
[
{
"title": f"{query} to the moon! 🚀",
"content": f"DD on {query}: Strong fundamentals, great earnings beat. Buy and hold!",
"url": f"https://reddit.com/r/{subreddit}/post1",
"upvotes": 1500 - (i * 100),
"score": 1450 - (i * 100),
"num_comments": 234,
"created_utc": (base_date + timedelta(hours=i)).timestamp(),
"subreddit": subreddit,
"author": f"WSBtrader{i}",
"posted_date": (base_date + timedelta(hours=i)).strftime(
"%Y-%m-%d"
),
},
{
"title": f"Why I'm bearish on {query}",
"content": f"Overvalued, competition increasing, margins declining. Time to sell {query}.",
"url": f"https://reddit.com/r/{subreddit}/post2",
"upvotes": 800 - (i * 50),
"score": 750 - (i * 50),
"num_comments": 156,
"created_utc": (base_date + timedelta(hours=i + 1)).timestamp(),
"subreddit": subreddit,
"author": f"BearishTrader{i}",
"posted_date": (base_date + timedelta(hours=i + 1)).strftime(
"%Y-%m-%d"
),
},
]
)
return posts
def get_top_posts(
self, subreddit_names: list[str], limit: int = 25, time_filter: str = "week"
) -> list[dict[str, Any]]:
"""Return mock top posts from subreddits."""
posts = []
# Use fixed dates that will work with our date filter
base_date = datetime(2024, 1, 2) # Within our test range
for subreddit in subreddit_names[:2]:
posts.append(
{
"title": "Market Update: Tech stocks rally continues",
"content": "FAANG stocks leading the charge. SPY hit new ATH. Bull market confirmed.",
"url": f"https://reddit.com/r/{subreddit}/top1",
"upvotes": 2500,
"score": 2400,
"num_comments": 456,
"created_utc": base_date.timestamp(),
"subreddit": subreddit,
"author": "MarketWatcher",
"posted_date": base_date.strftime("%Y-%m-%d"),
}
)
return posts
def filter_posts_by_date(
self, posts: list[dict[str, Any]], start_date: str, end_date: str
) -> list[dict[str, Any]]:
"""Filter posts by date range."""
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
        # Add one day so posts dated on end_date itself are kept (exclusive upper bound).
        end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
        filtered = []
        for post in posts:
            if "posted_date" in post:
                post_dt = datetime.strptime(post["posted_date"], "%Y-%m-%d")
                if start_dt <= post_dt < end_dt:
filtered.append(post)
return filtered
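# search_posts() deliberately mixes one bullish and one bearish post per
# subreddit so the sentiment tests below see both polarities.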
def test_online_mode_with_mock_reddit():
"""Test SocialMediaService in online mode with mock Reddit client."""
# Create mock client and real repository
mock_reddit = MockRedditClient()
real_repo = SocialRepository("test_data")
# Create service in online mode
service = SocialMediaService(
reddit_client=mock_reddit,
repository=real_repo,
online_mode=True,
data_dir="test_data",
)
# Test company-specific social context
context = service.get_company_social_context(
symbol="TSLA",
start_date="2024-01-01",
end_date="2024-01-05",
subreddits=["wallstreetbets", "stocks"],
force_refresh=True,
)
# Validate context structure
assert isinstance(context, SocialContext)
assert context.symbol == "TSLA"
assert context.period["start"] == "2024-01-01"
assert context.period["end"] == "2024-01-05"
assert len(context.posts) > 0
assert isinstance(context.sentiment_summary, SentimentScore)
assert context.post_count == len(context.posts)
assert "data_quality" in context.metadata
# Test JSON serialization
json_output = context.model_dump_json(indent=2)
assert len(json_output) > 0
# Validate individual posts
for post in context.posts:
assert isinstance(post, PostData)
assert post.title
assert post.author
assert post.date
assert post.score >= 0
def test_global_social_trends():
"""Test global social media trends functionality."""
mock_reddit = MockRedditClient()
real_repo = SocialRepository("test_data")
service = SocialMediaService(
reddit_client=mock_reddit, repository=real_repo, online_mode=True
)
# Test global trends
context = service.get_global_trends(
start_date="2024-01-01",
end_date="2024-01-03",
subreddits=["investing", "stocks", "wallstreetbets"],
force_refresh=True,
)
# Validate global context
assert context.symbol is None # Global trends have no specific symbol
assert len(context.posts) > 0
assert "reddit" in context.platforms
assert "subreddits" in context.metadata
def test_sentiment_analysis():
"""Test sentiment analysis on social posts."""
# Create service with posts that have clear sentiment
class SentimentTestClient(MockRedditClient):
def search_posts(self, query, subreddit_names, limit=25, time_filter="week"):
return [
{
"title": f"{query} is the best investment ever! 🚀🚀🚀",
"content": "Amazing earnings, incredible growth, bullish AF!",
"upvotes": 5000,
"score": 4900,
"num_comments": 500,
"subreddit": "wallstreetbets",
"author": "BullishTrader",
"posted_date": "2024-01-01",
},
{
"title": f"WARNING: {query} is about to crash hard",
"content": "Terrible fundamentals, overvalued, sell now before it's too late!",
"upvotes": 100,
"score": 50,
"num_comments": 30,
"subreddit": "stocks",
"author": "BearishAnalyst",
"posted_date": "2024-01-01",
},
]
sentiment_client = SentimentTestClient()
service = SocialMediaService(
reddit_client=sentiment_client, repository=None, online_mode=True
)
context = service.get_context("GME", "2024-01-01", "2024-01-02")
# Check sentiment analysis
assert context.sentiment_summary.score != 0 # Should have some sentiment
assert context.sentiment_summary.confidence > 0
assert context.sentiment_summary.label in ["positive", "negative", "neutral"]
# Check individual post sentiments
for post in context.posts:
if post.sentiment:
assert -1.0 <= post.sentiment.score <= 1.0
def test_offline_mode():
"""Test SocialMediaService in offline mode."""
real_repo = SocialRepository("test_data")
service = SocialMediaService(
reddit_client=None, repository=real_repo, online_mode=False
)
# Should handle offline gracefully
context = service.get_context("AAPL", "2024-01-01", "2024-01-05", symbol="AAPL")
assert context.symbol == "AAPL"
assert isinstance(context.posts, list)
assert context.metadata.get("data_quality") == DataQuality.LOW
def test_engagement_metrics():
"""Test calculation of engagement metrics."""
mock_reddit = MockRedditClient()
service = SocialMediaService(
reddit_client=mock_reddit, repository=None, online_mode=True
)
context = service.get_company_social_context(
symbol="NVDA",
start_date="2024-01-01",
end_date="2024-01-02",
subreddits=["nvidia", "stocks"],
)
# Check engagement metrics in the context
assert len(context.engagement_metrics) > 0
assert (
"total_engagement" in context.engagement_metrics
or "total_engagement" in context.metadata
)
# Verify post scores
for post in context.posts:
# Posts should have score and comments
assert post.score >= 0
assert post.comments >= 0
def test_subreddit_filtering():
"""Test filtering by specific subreddits."""
mock_reddit = MockRedditClient()
service = SocialMediaService(
reddit_client=mock_reddit, repository=None, online_mode=True
)
# Test with specific subreddits
context = service.get_company_social_context(
symbol="AMD",
start_date="2024-01-01",
end_date="2024-01-02",
subreddits=["AMD_Stock", "wallstreetbets"],
)
# Check that posts are from requested subreddits
subreddit_set = set()
for post in context.posts:
if post.subreddit:
subreddit_set.add(post.subreddit)
assert len(subreddit_set) > 0
assert all(sub in ["AMD_Stock", "wallstreetbets"] for sub in subreddit_set)
def test_error_handling():
"""Test error handling with broken client."""
class BrokenRedditClient(BaseClient):
def test_connection(self):
return False
def get_data(self, *args, **kwargs):
raise Exception("Reddit API error")
def search_posts(self, *args, **kwargs):
raise Exception("Reddit API error")
def get_top_posts(self, *args, **kwargs):
raise Exception("Reddit API error")
broken_client = BrokenRedditClient()
service = SocialMediaService(
reddit_client=broken_client, repository=None, online_mode=True
)
# Should handle errors gracefully
context = service.get_context("TSLA", "2024-01-01", "2024-01-02", symbol="TSLA")
assert context.symbol == "TSLA"
assert len(context.posts) == 0
assert context.metadata.get("data_quality") == DataQuality.LOW
def test_json_structure():
"""Test JSON structure of social context."""
mock_reddit = MockRedditClient()
service = SocialMediaService(
reddit_client=mock_reddit, repository=None, online_mode=True
)
context = service.get_context("PLTR", "2024-01-01", "2024-01-02")
json_data = context.model_dump()
# Validate required fields
required_fields = [
"symbol",
"period",
"posts",
"sentiment_summary",
"post_count",
"platforms",
"metadata",
]
for field in required_fields:
assert field in json_data
# Validate posts structure
if json_data["posts"]:
first_post = json_data["posts"][0]
post_fields = ["title", "author", "date", "score"]
for field in post_fields:
assert field in first_post
# Validate sentiment structure
sentiment = json_data["sentiment_summary"]
assert "score" in sentiment
assert "confidence" in sentiment
assert "label" in sentiment

1920
uv.lock

File diff suppressed because it is too large