TradingAgents/tradingagents/dataflows/twitter_data.py

219 lines
6.6 KiB
Python

import json
import time
from datetime import datetime
from pathlib import Path
import tweepy
from dotenv import load_dotenv
from tradingagents.config import config
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
# Load environment variables
load_dotenv()
# Constants
DATA_DIR = Path("data")
CACHE_FILE = DATA_DIR / ".twitter_cache.json"
USAGE_FILE = DATA_DIR / ".twitter_usage.json"
MONTHLY_LIMIT = 200
CACHE_DURATION_HOURS = 4
def _ensure_data_dir():
"""Ensure the data directory exists."""
DATA_DIR.mkdir(exist_ok=True)
def _load_json(file_path: Path) -> dict:
"""Load JSON data from a file, returning empty dict if not found."""
if not file_path.exists():
return {}
try:
with open(file_path, "r") as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
return {}
def _save_json(file_path: Path, data: dict):
"""Save dictionary to a JSON file."""
_ensure_data_dir()
try:
with open(file_path, "w") as f:
json.dump(data, f, indent=2)
except IOError as e:
logger.warning(f"Could not save to {file_path}: {e}")
def _get_cache_key(prefix: str, identifier: str) -> str:
"""Generate a cache key."""
return f"{prefix}:{identifier}"
def _is_cache_valid(timestamp: float) -> bool:
"""Check if the cached entry is still valid."""
age_hours = (time.time() - timestamp) / 3600
return age_hours < CACHE_DURATION_HOURS
def _check_usage_limit() -> bool:
"""Check if the monthly usage limit has been reached."""
usage_data = _load_json(USAGE_FILE)
current_month = datetime.now().strftime("%Y-%m")
# Reset usage if it's a new month
if usage_data.get("month") != current_month:
usage_data = {"month": current_month, "count": 0}
_save_json(USAGE_FILE, usage_data)
return True
return usage_data.get("count", 0) < MONTHLY_LIMIT
def _increment_usage():
"""Increment the usage counter."""
usage_data = _load_json(USAGE_FILE)
current_month = datetime.now().strftime("%Y-%m")
if usage_data.get("month") != current_month:
usage_data = {"month": current_month, "count": 0}
usage_data["count"] = usage_data.get("count", 0) + 1
_save_json(USAGE_FILE, usage_data)
def get_tweets(query: str, count: int = 10) -> str:
"""
Fetches recent tweets matching the query using Twitter API v2.
Includes caching and rate limiting.
Args:
query (str): The search query (e.g., "AAPL", "Bitcoin").
count (int): Number of tweets to retrieve (default 10).
Returns:
str: A formatted string containing the tweets or an error message.
"""
# 1. Check Cache
cache_key = _get_cache_key("search", query)
cache = _load_json(CACHE_FILE)
if cache_key in cache:
entry = cache[cache_key]
if _is_cache_valid(entry["timestamp"]):
return entry["data"] + "\n\n(Source: Local Cache)"
# 2. Check Rate Limit
if not _check_usage_limit():
return "Error: Monthly Twitter API usage limit (200 calls) reached."
bearer_token = config.validate_key("twitter_bearer_token", "Twitter")
try:
client = tweepy.Client(bearer_token=bearer_token)
# Search for recent tweets
safe_count = max(10, min(count, 100))
response = client.search_recent_tweets(
query=query,
max_results=safe_count,
tweet_fields=["created_at", "author_id", "public_metrics"],
)
# 3. Increment Usage
_increment_usage()
if not response.data:
result = f"No tweets found for query: {query}"
else:
formatted_tweets = f"## Recent Tweets for '{query}'\n\n"
for tweet in response.data:
metrics = tweet.public_metrics
formatted_tweets += f"- **{tweet.created_at}**: {tweet.text}\n"
if metrics:
formatted_tweets += f" (Likes: {metrics.get('like_count', 0)}, Retweets: {metrics.get('retweet_count', 0)})\n"
formatted_tweets += "\n"
result = formatted_tweets
# 4. Save to Cache
cache[cache_key] = {"timestamp": time.time(), "data": result}
_save_json(CACHE_FILE, cache)
return result
except Exception as e:
return f"Error fetching tweets: {str(e)}"
def get_tweets_from_user(username: str, count: int = 10) -> str:
"""
Fetches recent tweets from a specific user using Twitter API v2.
Includes caching and rate limiting.
Args:
username (str): The Twitter username (without @).
count (int): Number of tweets to retrieve (default 10).
Returns:
str: A formatted string containing the tweets or an error message.
"""
# 1. Check Cache
cache_key = _get_cache_key("user", username)
cache = _load_json(CACHE_FILE)
if cache_key in cache:
entry = cache[cache_key]
if _is_cache_valid(entry["timestamp"]):
return entry["data"] + "\n\n(Source: Local Cache)"
# 2. Check Rate Limit
if not _check_usage_limit():
return "Error: Monthly Twitter API usage limit (200 calls) reached."
bearer_token = config.validate_key("twitter_bearer_token", "Twitter")
try:
client = tweepy.Client(bearer_token=bearer_token)
# First, get the user ID
user = client.get_user(username=username)
if not user.data:
return f"Error: User '@{username}' not found."
user_id = user.data.id
# max_results must be between 5 and 100 for get_users_tweets
safe_count = max(5, min(count, 100))
response = client.get_users_tweets(
id=user_id, max_results=safe_count, tweet_fields=["created_at", "public_metrics"]
)
# 3. Increment Usage
_increment_usage()
if not response.data:
result = f"No recent tweets found for user: @{username}"
else:
formatted_tweets = f"## Recent Tweets from @{username}\n\n"
for tweet in response.data:
metrics = tweet.public_metrics
formatted_tweets += f"- **{tweet.created_at}**: {tweet.text}\n"
if metrics:
formatted_tweets += f" (Likes: {metrics.get('like_count', 0)}, Retweets: {metrics.get('retweet_count', 0)})\n"
formatted_tweets += "\n"
result = formatted_tweets
# 4. Save to Cache
cache[cache_key] = {"timestamp": time.time(), "data": result}
_save_json(CACHE_FILE, cache)
return result
except Exception as e:
return f"Error fetching tweets from user @{username}: {str(e)}"