feat: discovery pipeline enhancements with ML signal scanner

Major additions:
- ML win probability scanner: scans ticker universe using trained
  LightGBM/TabPFN model, surfaces candidates with P(WIN) above threshold
- 30-feature engineering pipeline (20 base + 10 interaction features)
  computed from OHLCV data via stockstats + pandas
- Triple-barrier labeling for training data generation
- Dataset builder and training script with calibration analysis
- Discovery enrichment: confluence scoring, short interest extraction,
  earnings estimates, options signal normalization, quant pre-score
- Configurable prompt logging (log_prompts_console flag)
- Enhanced ranker investment thesis (4-6 sentence reasoning)
- Typed DiscoveryConfig dataclass for all discovery settings
- Console price charts for visual ticker analysis

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Youssef Aitousarrah 2026-02-09 22:53:42 -08:00
parent 1d78271ef4
commit 43bdd6de11
133 changed files with 15720 additions and 7384 deletions

View File

@ -6,14 +6,32 @@ cd "$ROOT_DIR"
echo "Running pre-commit checks..."
python -m compileall -q tradingagents
# Run black formatter (auto-fix)
if python - <<'PY'
import importlib.util
raise SystemExit(0 if importlib.util.find_spec("pytest") else 1)
raise SystemExit(0 if importlib.util.find_spec("black") else 1)
PY
then
python -m pytest -q
echo "🎨 Running black formatter..."
python -m black tradingagents/ cli/ scripts/ --quiet
else
echo "pytest not installed; skipping test run."
echo "⚠️ black not installed; skipping formatting."
fi
# Run ruff linter (auto-fix, but don't fail on warnings)
if python - <<'PY'
import importlib.util
raise SystemExit(0 if importlib.util.find_spec("ruff") else 1)
PY
then
echo "🔍 Running ruff linter..."
python -m ruff check tradingagents/ cli/ scripts/ --fix --exit-zero
else
echo "⚠️ ruff not installed; skipping linting."
fi
# CRITICAL: Check for syntax errors (this will fail the commit)
echo "🐍 Checking for syntax errors..."
python -m compileall -q tradingagents cli scripts
echo "✅ Pre-commit checks passed!"

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,4 @@
from enum import Enum
from typing import List, Optional, Dict
from pydantic import BaseModel
class AnalystType(str, Enum):

View File

@ -1,7 +1,11 @@
from typing import List
import questionary
from typing import List, Optional, Tuple, Dict
from cli.models import AnalystType
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
ANALYST_ORDER = [
("Market Analyst", AnalystType.MARKET),
@ -68,9 +72,7 @@ def select_analysts() -> List[AnalystType]:
"""Select analysts using an interactive checkbox."""
choices = questionary.checkbox(
"Select Your [Analysts Team]:",
choices=[
questionary.Choice(display, value=value) for display, value in ANALYST_ORDER
],
choices=[questionary.Choice(display, value=value) for display, value in ANALYST_ORDER],
instruction="\n- Press Space to select/unselect analysts\n- Press 'a' to select/unselect all\n- Press Enter when done",
validate=lambda x: len(x) > 0 or "You must select at least one analyst.",
style=questionary.Style(
@ -102,9 +104,7 @@ def select_research_depth() -> int:
choice = questionary.select(
"Select Your [Research Depth]:",
choices=[
questionary.Choice(display, value=value) for display, value in DEPTH_OPTIONS
],
choices=[questionary.Choice(display, value=value) for display, value in DEPTH_OPTIONS],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
[
@ -135,28 +135,44 @@ def select_shallow_thinking_agent(provider) -> str:
("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
],
"anthropic": [
("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
(
"Claude Haiku 3.5 - Fast inference and standard capabilities",
"claude-3-5-haiku-latest",
),
("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"),
("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"),
(
"Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities",
"claude-3-7-sonnet-latest",
),
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
],
"google": [
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
(
"Gemini 2.0 Flash - Next generation features, speed, and thinking",
"gemini-2.0-flash",
),
("Gemini 2.5 Flash-Lite - Ultra-fast and cost-effective", "gemini-2.5-flash-lite"),
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash"),
("Gemini 2.5 Pro - Most capable Gemini model", "gemini-2.5-pro"),
("Gemini 3.0 Pro Preview - Next generation preview", "gemini-3-pro-preview"),
("Gemini 3.0 Flash Preview - Latest generation preview", "gemini-3-flash-preview"),
],
"openrouter": [
("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"),
("Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B", "meta-llama/llama-3.3-8b-instruct:free"),
("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"),
(
"Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B",
"meta-llama/llama-3.3-8b-instruct:free",
),
(
"google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token",
"google/gemini-2.0-flash-exp:free",
),
],
"ollama": [
("llama3.1 local", "llama3.1"),
("llama3.2 local", "llama3.2"),
]
],
}
choice = questionary.select(
@ -176,9 +192,7 @@ def select_shallow_thinking_agent(provider) -> str:
).ask()
if choice is None:
console.print(
"\n[red]No shallow thinking llm engine selected. Exiting...[/red]"
)
console.print("\n[red]No shallow thinking llm engine selected. Exiting...[/red]")
exit(1)
return choice
@ -200,30 +214,46 @@ def select_deep_thinking_agent(provider) -> str:
("o1 - Premier reasoning and problem-solving model", "o1"),
],
"anthropic": [
("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
(
"Claude Haiku 3.5 - Fast inference and standard capabilities",
"claude-3-5-haiku-latest",
),
("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"),
("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"),
(
"Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities",
"claude-3-7-sonnet-latest",
),
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
("Claude Opus 4 - Most powerful Anthropic model", " claude-opus-4-0"),
],
"google": [
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
(
"Gemini 2.0 Flash - Next generation features, speed, and thinking",
"gemini-2.0-flash",
),
("Gemini 2.5 Flash-Lite - Ultra-fast and cost-effective", "gemini-2.5-flash-lite"),
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash"),
("Gemini 2.5 Pro - Most capable Gemini model", "gemini-2.5-pro"),
("Gemini 3.0 Pro Preview - Next generation preview", "gemini-3-pro-preview"),
("Gemini 3.0 Flash Preview - Latest generation preview", "gemini-3-flash-preview"),
],
"openrouter": [
("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"),
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"),
(
"DeepSeek V3 - a 685B-parameter, mixture-of-experts model",
"deepseek/deepseek-chat-v3-0324:free",
),
(
"Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.",
"deepseek/deepseek-chat-v3-0324:free",
),
],
"ollama": [
("llama3.1 local", "llama3.1"),
("qwen3", "qwen3"),
]
],
}
choice = questionary.select(
"Select Your [Deep-Thinking LLM Engine]:",
choices=[
@ -246,6 +276,7 @@ def select_deep_thinking_agent(provider) -> str:
return choice
def select_llm_provider() -> tuple[str, str]:
"""Select the OpenAI api url using interactive selection."""
# Define OpenAI api options with their corresponding endpoints
@ -254,14 +285,13 @@ def select_llm_provider() -> tuple[str, str]:
("Anthropic", "https://api.anthropic.com/"),
("Google", "https://generativelanguage.googleapis.com/v1"),
("Openrouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"),
("Ollama", "http://localhost:11434/v1"),
]
choice = questionary.select(
"Select your LLM Provider:",
choices=[
questionary.Choice(display, value=(display, value))
for display, value in BASE_URLS
questionary.Choice(display, value=(display, value)) for display, value in BASE_URLS
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
@ -272,12 +302,12 @@ def select_llm_provider() -> tuple[str, str]:
]
),
).ask()
if choice is None:
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
exit(1)
display_name, url = choice
print(f"You selected: {display_name}\tURL: {url}")
logger.info(f"You selected: {display_name}\tURL: {url}")
return display_name, url

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,44 @@
# ML Win Probability Model — TabPFN + Triple-Barrier
## Overview
Add an ML model that predicts win probability for each discovery candidate.
- **Training data**: Universe-wide historical simulation (~375K labeled samples)
- **Model**: TabPFN (foundation model for tabular data) with LightGBM fallback
- **Labels**: Triple-barrier method (+5% profit, -3% stop loss, 7-day timeout)
- **Integration**: Adds `ml_win_probability` field during enrichment
## Components
### 1. Feature Engineering (`tradingagents/ml/feature_engineering.py`)
Shared feature extraction used by both training and inference.
20 features computed locally from OHLCV via stockstats + pandas.
### 2. Dataset Builder (`scripts/build_ml_dataset.py`)
- Fetches OHLCV for ~500 stocks × 3 years
- Computes features locally (no API calls for indicators)
- Applies triple-barrier labels
- Outputs `data/ml/training_dataset.parquet`
### 3. Model Trainer (`scripts/train_ml_model.py`)
- Time-based train/validation split
- TabPFN or LightGBM training
- Walk-forward evaluation
- Outputs `data/ml/tabpfn_model.pkl` + `data/ml/metrics.json`
### 4. Pipeline Integration
- `tradingagents/ml/predictor.py` — model loading + inference
- `tradingagents/dataflows/discovery/filter.py` — call predictor during enrichment
- `tradingagents/dataflows/discovery/ranker.py` — surface in LLM prompt
## Triple-Barrier Labels
```
+1 (WIN): Price hits +5% within 7 trading days
-1 (LOSS): Price hits -3% within 7 trading days
0 (TIMEOUT): Neither barrier hit
```
## Features (20)
All computed locally from OHLCV — zero API calls for indicators.
rsi_14, macd, macd_signal, macd_hist, atr_pct, bb_width_pct, bb_position,
adx, mfi, stoch_k, volume_ratio_5d, volume_ratio_20d, return_1d, return_5d,
return_20d, sma50_distance, sma200_distance, high_low_range, gap_pct, log_market_cap

11
main.py
View File

@ -1,11 +1,14 @@
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from dotenv import load_dotenv
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.utils.logger import get_logger
# Load environment variables from .env file
load_dotenv()
logger = get_logger(__name__)
# Create a custom config
config = DEFAULT_CONFIG.copy()
config["deep_think_llm"] = "gpt-4o-mini" # Use a different model
@ -25,7 +28,7 @@ ta = TradingAgentsGraph(debug=True, config=config)
# forward propagate
_, decision = ta.propagate("NVDA", "2024-05-10")
print(decision)
logger.info(decision)
# Memorize mistakes and reflect
# ta.reflect_and_remember(1000) # parameter is the position returns

View File

@ -12,6 +12,7 @@ dependencies = [
"eodhd>=1.0.32",
"feedparser>=6.0.11",
"finnhub-python>=2.4.23",
"google-genai>=1.60.0",
"grip>=4.6.2",
"langchain-anthropic>=0.3.15",
"langchain-experimental>=0.3.4",
@ -23,13 +24,50 @@ dependencies = [
"praw>=7.8.1",
"pytz>=2025.2",
"questionary>=2.1.0",
"rapidfuzz>=3.14.3",
"redis>=6.2.0",
"requests>=2.32.4",
"rich>=14.0.0",
"plotext>=5.2.8",
"plotille>=5.0.0",
"setuptools>=80.9.0",
"stockstats>=0.6.5",
"tavily>=1.1.0",
"tqdm>=4.67.1",
"tushare>=1.4.21",
"typing-extensions>=4.14.0",
"yfinance>=0.2.63",
"streamlit>=1.40.0",
"plotly>=5.18.0",
"lightgbm>=4.6.0",
"tabpfn>=2.1.3",
]
[dependency-groups]
dev = [
"black>=24.0.0",
"ruff>=0.8.0",
"pytest>=8.0.0",
]
[tool.black]
line-length = 100
target-version = ['py310']
include = '\.pyi?$'
[tool.ruff]
line-length = 100
target-version = "py310"
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
]
ignore = [
"E501", # line too long (handled by black)
]

View File

@ -25,3 +25,7 @@ questionary
langchain_anthropic
langchain-google-genai
tweepy
plotext
plotille
streamlit>=1.40.0
plotly>=5.18.0

View File

@ -13,140 +13,174 @@ Usage:
python scripts/analyze_insider_transactions.py AAPL --csv # Save to CSV
"""
import yfinance as yf
import pandas as pd
import sys
import os
import sys
from datetime import datetime
from pathlib import Path
import pandas as pd
import yfinance as yf
sys.path.insert(0, str(Path(__file__).parent.parent))
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
def classify_transaction(text):
"""Classify transaction type based on text description."""
if pd.isna(text) or text == '':
return 'Grant/Exercise'
if pd.isna(text) or text == "":
return "Grant/Exercise"
text_lower = str(text).lower()
if 'sale' in text_lower:
return 'Sale'
elif 'purchase' in text_lower or 'buy' in text_lower:
return 'Purchase'
elif 'gift' in text_lower:
return 'Gift'
if "sale" in text_lower:
return "Sale"
elif "purchase" in text_lower or "buy" in text_lower:
return "Purchase"
elif "gift" in text_lower:
return "Gift"
else:
return 'Other'
return "Other"
def analyze_insider_transactions(ticker: str, save_csv: bool = False, output_dir: str = None):
"""Analyze and aggregate insider transactions for a given ticker.
Args:
ticker: Stock ticker symbol
save_csv: Whether to save results to CSV files
output_dir: Directory to save CSV files (default: current directory)
Returns:
Dictionary with DataFrames: 'by_position', 'yearly', 'sentiment'
"""
print(f"\n{'='*80}")
print(f"INSIDER TRANSACTIONS ANALYSIS: {ticker.upper()}")
print(f"{'='*80}")
result = {'by_position': None, 'by_person': None, 'yearly': None, 'sentiment': None}
logger.info(f"\n{'='*80}")
logger.info(f"INSIDER TRANSACTIONS ANALYSIS: {ticker.upper()}")
logger.info(f"{'='*80}")
result = {"by_position": None, "by_person": None, "yearly": None, "sentiment": None}
try:
ticker_obj = yf.Ticker(ticker.upper())
data = ticker_obj.insider_transactions
if data is None or data.empty:
print(f"No insider transaction data found for {ticker}")
logger.warning(f"No insider transaction data found for {ticker}")
return result
# Parse transaction type and year
data['Transaction'] = data['Text'].apply(classify_transaction)
data['Year'] = pd.to_datetime(data['Start Date']).dt.year
data["Transaction"] = data["Text"].apply(classify_transaction)
data["Year"] = pd.to_datetime(data["Start Date"]).dt.year
# ============================================================
# BY POSITION, YEAR, TRANSACTION TYPE
# ============================================================
print(f"\n## BY POSITION\n")
agg = data.groupby(['Position', 'Year', 'Transaction']).agg({
'Shares': 'sum',
'Value': 'sum'
}).reset_index()
agg['Ticker'] = ticker.upper()
result['by_position'] = agg
for position in sorted(agg['Position'].unique()):
print(f"\n### {position}")
print("-" * 50)
pos_data = agg[agg['Position'] == position].sort_values(['Year', 'Transaction'], ascending=[False, True])
logger.info("\n## BY POSITION\n")
agg = (
data.groupby(["Position", "Year", "Transaction"])
.agg({"Shares": "sum", "Value": "sum"})
.reset_index()
)
agg["Ticker"] = ticker.upper()
result["by_position"] = agg
for position in sorted(agg["Position"].unique()):
logger.info(f"\n### {position}")
logger.info("-" * 50)
pos_data = agg[agg["Position"] == position].sort_values(
["Year", "Transaction"], ascending=[False, True]
)
for _, row in pos_data.iterrows():
value_str = f"${row['Value']:>15,.0f}" if pd.notna(row['Value']) and row['Value'] > 0 else f"{'N/A':>16}"
print(f" {row['Year']} | {row['Transaction']:15} | {row['Shares']:>12,.0f} shares | {value_str}")
value_str = (
f"${row['Value']:>15,.0f}"
if pd.notna(row["Value"]) and row["Value"] > 0
else f"{'N/A':>16}"
)
logger.info(
f" {row['Year']} | {row['Transaction']:15} | {row['Shares']:>12,.0f} shares | {value_str}"
)
# ============================================================
# BY INSIDER
# ============================================================
print(f"\n\n{'='*80}")
print("INSIDER TRANSACTIONS BY PERSON")
print(f"{'='*80}")
logger.info(f"\n\n{'='*80}")
logger.info("INSIDER TRANSACTIONS BY PERSON")
logger.info(f"{'='*80}")
insider_col = "Insider"
if insider_col not in data.columns and "Name" in data.columns:
insider_col = "Name"
insider_col = 'Insider'
if insider_col not in data.columns and 'Name' in data.columns:
insider_col = 'Name'
if insider_col in data.columns:
agg_person = data.groupby([insider_col, 'Position', 'Year', 'Transaction']).agg({
'Shares': 'sum',
'Value': 'sum'
}).reset_index()
agg_person['Ticker'] = ticker.upper()
result['by_person'] = agg_person
agg_person = (
data.groupby([insider_col, "Position", "Year", "Transaction"])
.agg({"Shares": "sum", "Value": "sum"})
.reset_index()
)
agg_person["Ticker"] = ticker.upper()
result["by_person"] = agg_person
for person in sorted(agg_person[insider_col].unique()):
print(f"\n### {str(person)}")
print("-" * 50)
p_data = agg_person[agg_person[insider_col] == person].sort_values(['Year', 'Transaction'], ascending=[False, True])
logger.info(f"\n### {str(person)}")
logger.info("-" * 50)
p_data = agg_person[agg_person[insider_col] == person].sort_values(
["Year", "Transaction"], ascending=[False, True]
)
for _, row in p_data.iterrows():
value_str = f"${row['Value']:>15,.0f}" if pd.notna(row['Value']) and row['Value'] > 0 else f"{'N/A':>16}"
pos_str = str(row['Position'])[:25]
print(f" {row['Year']} | {pos_str:25} | {row['Transaction']:15} | {row['Shares']:>12,.0f} shares | {value_str}")
value_str = (
f"${row['Value']:>15,.0f}"
if pd.notna(row["Value"]) and row["Value"] > 0
else f"{'N/A':>16}"
)
pos_str = str(row["Position"])[:25]
logger.info(
f" {row['Year']} | {pos_str:25} | {row['Transaction']:15} | {row['Shares']:>12,.0f} shares | {value_str}"
)
else:
print(f"Warning: Could not find 'Insider' or 'Name' column in data. Columns: {data.columns.tolist()}")
logger.warning(
f"Warning: Could not find 'Insider' or 'Name' column in data. Columns: {data.columns.tolist()}"
)
# ============================================================
# YEARLY SUMMARY
# ============================================================
print(f"\n\n{'='*80}")
print("YEARLY SUMMARY BY TRANSACTION TYPE")
print(f"{'='*80}")
yearly = data.groupby(['Year', 'Transaction']).agg({
'Shares': 'sum',
'Value': 'sum'
}).reset_index()
yearly['Ticker'] = ticker.upper()
result['yearly'] = yearly
for year in sorted(yearly['Year'].unique(), reverse=True):
print(f"\n{year}:")
year_data = yearly[yearly['Year'] == year].sort_values('Transaction')
logger.info(f"\n\n{'='*80}")
logger.info("YEARLY SUMMARY BY TRANSACTION TYPE")
logger.info(f"{'='*80}")
yearly = (
data.groupby(["Year", "Transaction"])
.agg({"Shares": "sum", "Value": "sum"})
.reset_index()
)
yearly["Ticker"] = ticker.upper()
result["yearly"] = yearly
for year in sorted(yearly["Year"].unique(), reverse=True):
logger.info(f"\n{year}:")
year_data = yearly[yearly["Year"] == year].sort_values("Transaction")
for _, row in year_data.iterrows():
value_str = f"${row['Value']:>15,.0f}" if pd.notna(row['Value']) and row['Value'] > 0 else f"{'N/A':>16}"
print(f" {row['Transaction']:15} | {row['Shares']:>12,.0f} shares | {value_str}")
value_str = (
f"${row['Value']:>15,.0f}"
if pd.notna(row["Value"]) and row["Value"] > 0
else f"{'N/A':>16}"
)
logger.info(f" {row['Transaction']:15} | {row['Shares']:>12,.0f} shares | {value_str}")
# ============================================================
# OVERALL SENTIMENT
# ============================================================
print(f"\n\n{'='*80}")
print("INSIDER SENTIMENT SUMMARY")
print(f"{'='*80}\n")
total_sales = data[data['Transaction'] == 'Sale']['Value'].sum()
total_purchases = data[data['Transaction'] == 'Purchase']['Value'].sum()
sales_count = len(data[data['Transaction'] == 'Sale'])
purchases_count = len(data[data['Transaction'] == 'Purchase'])
logger.info(f"\n\n{'='*80}")
logger.info("INSIDER SENTIMENT SUMMARY")
logger.info(f"{'='*80}\n")
total_sales = data[data["Transaction"] == "Sale"]["Value"].sum()
total_purchases = data[data["Transaction"] == "Purchase"]["Value"].sum()
sales_count = len(data[data["Transaction"] == "Sale"])
purchases_count = len(data[data["Transaction"] == "Purchase"])
net_value = total_purchases - total_sales
# Determine sentiment
if total_purchases > total_sales:
sentiment = "BULLISH"
@ -156,134 +190,158 @@ def analyze_insider_transactions(ticker: str, save_csv: bool = False, output_dir
sentiment = "SLIGHTLY_BEARISH"
else:
sentiment = "NEUTRAL"
result['sentiment'] = pd.DataFrame([{
'Ticker': ticker.upper(),
'Total_Sales_Count': sales_count,
'Total_Sales_Value': total_sales,
'Total_Purchases_Count': purchases_count,
'Total_Purchases_Value': total_purchases,
'Net_Value': net_value,
'Sentiment': sentiment
}])
print(f"Total Sales: {sales_count:>5} transactions | ${total_sales:>15,.0f}")
print(f"Total Purchases: {purchases_count:>5} transactions | ${total_purchases:>15,.0f}")
result["sentiment"] = pd.DataFrame(
[
{
"Ticker": ticker.upper(),
"Total_Sales_Count": sales_count,
"Total_Sales_Value": total_sales,
"Total_Purchases_Count": purchases_count,
"Total_Purchases_Value": total_purchases,
"Net_Value": net_value,
"Sentiment": sentiment,
}
]
)
logger.info(f"Total Sales: {sales_count:>5} transactions | ${total_sales:>15,.0f}")
logger.info(f"Total Purchases: {purchases_count:>5} transactions | ${total_purchases:>15,.0f}")
if sentiment == "BULLISH":
print(f"\n⚡ BULLISH: Insiders are net BUYERS (${net_value:,.0f} net buying)")
logger.info(f"\n⚡ BULLISH: Insiders are net BUYERS (${net_value:,.0f} net buying)")
elif sentiment == "BEARISH":
print(f"\n⚠️ BEARISH: Significant insider SELLING (${-net_value:,.0f} net selling)")
logger.info(f"\n⚠️ BEARISH: Significant insider SELLING (${-net_value:,.0f} net selling)")
elif sentiment == "SLIGHTLY_BEARISH":
print(f"\n⚠️ SLIGHTLY BEARISH: More selling than buying (${-net_value:,.0f} net selling)")
logger.info(
f"\n⚠️ SLIGHTLY BEARISH: More selling than buying (${-net_value:,.0f} net selling)"
)
else:
print(f"\n📊 NEUTRAL: Balanced insider activity")
logger.info("\n📊 NEUTRAL: Balanced insider activity")
# Save to CSV if requested
if save_csv:
if output_dir is None:
output_dir = os.getcwd()
os.makedirs(output_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Save by position
by_pos_file = os.path.join(output_dir, f"insider_by_position_{ticker.upper()}_{timestamp}.csv")
by_pos_file = os.path.join(
output_dir, f"insider_by_position_{ticker.upper()}_{timestamp}.csv"
)
agg.to_csv(by_pos_file, index=False)
print(f"\n📁 Saved: {by_pos_file}")
logger.info(f"\n📁 Saved: {by_pos_file}")
# Save by person
if result['by_person'] is not None:
by_person_file = os.path.join(output_dir, f"insider_by_person_{ticker.upper()}_{timestamp}.csv")
result['by_person'].to_csv(by_person_file, index=False)
print(f"📁 Saved: {by_person_file}")
if result["by_person"] is not None:
by_person_file = os.path.join(
output_dir, f"insider_by_person_{ticker.upper()}_{timestamp}.csv"
)
result["by_person"].to_csv(by_person_file, index=False)
logger.info(f"📁 Saved: {by_person_file}")
# Save yearly summary
yearly_file = os.path.join(output_dir, f"insider_yearly_{ticker.upper()}_{timestamp}.csv")
yearly_file = os.path.join(
output_dir, f"insider_yearly_{ticker.upper()}_{timestamp}.csv"
)
yearly.to_csv(yearly_file, index=False)
print(f"📁 Saved: {yearly_file}")
logger.info(f"📁 Saved: {yearly_file}")
# Save sentiment summary
sentiment_file = os.path.join(output_dir, f"insider_sentiment_{ticker.upper()}_{timestamp}.csv")
result['sentiment'].to_csv(sentiment_file, index=False)
print(f"📁 Saved: {sentiment_file}")
sentiment_file = os.path.join(
output_dir, f"insider_sentiment_{ticker.upper()}_{timestamp}.csv"
)
result["sentiment"].to_csv(sentiment_file, index=False)
logger.info(f"📁 Saved: {sentiment_file}")
except Exception as e:
print(f"Error analyzing {ticker}: {str(e)}")
logger.error(f"Error analyzing {ticker}: {str(e)}")
return result
if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: python analyze_insider_transactions.py TICKER [TICKER2 ...] [--csv] [--output-dir DIR]")
print("Example: python analyze_insider_transactions.py AAPL TSLA NVDA")
print(" python analyze_insider_transactions.py AAPL --csv")
print(" python analyze_insider_transactions.py AAPL --csv --output-dir ./output")
logger.info(
"Usage: python analyze_insider_transactions.py TICKER [TICKER2 ...] [--csv] [--output-dir DIR]"
)
logger.info("Example: python analyze_insider_transactions.py AAPL TSLA NVDA")
logger.info(" python analyze_insider_transactions.py AAPL --csv")
logger.info(" python analyze_insider_transactions.py AAPL --csv --output-dir ./output")
sys.exit(1)
# Parse arguments
args = sys.argv[1:]
save_csv = '--csv' in args
save_csv = "--csv" in args
output_dir = None
if '--output-dir' in args:
idx = args.index('--output-dir')
if "--output-dir" in args:
idx = args.index("--output-dir")
if idx + 1 < len(args):
output_dir = args[idx + 1]
args = args[:idx] + args[idx+2:]
args = args[:idx] + args[idx + 2 :]
else:
print("Error: --output-dir requires a directory path")
logger.error("Error: --output-dir requires a directory path")
sys.exit(1)
if save_csv:
args.remove('--csv')
tickers = [t for t in args if not t.startswith('--')]
args.remove("--csv")
tickers = [t for t in args if not t.startswith("--")]
# Collect all results for combined CSV
all_by_position = []
all_by_person = []
all_yearly = []
all_sentiment = []
for ticker in tickers:
result = analyze_insider_transactions(ticker, save_csv=save_csv, output_dir=output_dir)
if result['by_position'] is not None:
all_by_position.append(result['by_position'])
if result['by_person'] is not None:
all_by_person.append(result['by_person'])
if result['yearly'] is not None:
all_yearly.append(result['yearly'])
if result['sentiment'] is not None:
all_sentiment.append(result['sentiment'])
if result["by_position"] is not None:
all_by_position.append(result["by_position"])
if result["by_person"] is not None:
all_by_person.append(result["by_person"])
if result["yearly"] is not None:
all_yearly.append(result["yearly"])
if result["sentiment"] is not None:
all_sentiment.append(result["sentiment"])
# If multiple tickers and CSV mode, also save combined files
if save_csv and len(tickers) > 1:
if output_dir is None:
output_dir = os.getcwd()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
if all_by_position:
combined_pos = pd.concat(all_by_position, ignore_index=True)
combined_pos_file = os.path.join(output_dir, f"insider_by_position_combined_{timestamp}.csv")
combined_pos_file = os.path.join(
output_dir, f"insider_by_position_combined_{timestamp}.csv"
)
combined_pos.to_csv(combined_pos_file, index=False)
print(f"\n📁 Combined: {combined_pos_file}")
logger.info(f"\n📁 Combined: {combined_pos_file}")
if all_by_person:
combined_person = pd.concat(all_by_person, ignore_index=True)
combined_person_file = os.path.join(output_dir, f"insider_by_person_combined_{timestamp}.csv")
combined_person_file = os.path.join(
output_dir, f"insider_by_person_combined_{timestamp}.csv"
)
combined_person.to_csv(combined_person_file, index=False)
print(f"📁 Combined: {combined_person_file}")
logger.info(f"📁 Combined: {combined_person_file}")
if all_yearly:
combined_yearly = pd.concat(all_yearly, ignore_index=True)
combined_yearly_file = os.path.join(output_dir, f"insider_yearly_combined_{timestamp}.csv")
combined_yearly_file = os.path.join(
output_dir, f"insider_yearly_combined_{timestamp}.csv"
)
combined_yearly.to_csv(combined_yearly_file, index=False)
print(f"📁 Combined: {combined_yearly_file}")
logger.info(f"📁 Combined: {combined_yearly_file}")
if all_sentiment:
combined_sentiment = pd.concat(all_sentiment, ignore_index=True)
combined_sentiment_file = os.path.join(output_dir, f"insider_sentiment_combined_{timestamp}.csv")
combined_sentiment_file = os.path.join(
output_dir, f"insider_sentiment_combined_{timestamp}.csv"
)
combined_sentiment.to_csv(combined_sentiment_file, index=False)
print(f"📁 Combined: {combined_sentiment_file}")
logger.info(f"📁 Combined: {combined_sentiment_file}")

View File

@ -11,18 +11,23 @@ Usage:
python scripts/build_historical_memories.py
"""
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.historical_memory_builder import HistoricalMemoryBuilder
import pickle
from datetime import datetime, timedelta
from tradingagents.agents.utils.historical_memory_builder import HistoricalMemoryBuilder
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
def main():
print("""
logger.info("""
TradingAgents - Historical Memory Builder
@ -30,25 +35,34 @@ def main():
# Configuration
tickers = [
"AAPL", "GOOGL", "MSFT", "NVDA", "TSLA", # Tech
"JPM", "BAC", "GS", # Finance
"XOM", "CVX", # Energy
"JNJ", "PFE", # Healthcare
"WMT", "AMZN" # Retail
"AAPL",
"GOOGL",
"MSFT",
"NVDA",
"TSLA", # Tech
"JPM",
"BAC",
"GS", # Finance
"XOM",
"CVX", # Energy
"JNJ",
"PFE", # Healthcare
"WMT",
"AMZN", # Retail
]
# Date range - last 2 years
end_date = datetime.now()
start_date = end_date - timedelta(days=730) # 2 years
print(f"Tickers: {', '.join(tickers)}")
print(f"Period: {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}")
print(f"Lookforward: 7 days (1 week returns)")
print(f"Sample interval: 30 days (monthly)\n")
logger.info(f"Tickers: {', '.join(tickers)}")
logger.info(f"Period: {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}")
logger.info("Lookforward: 7 days (1 week returns)")
logger.info("Sample interval: 30 days (monthly)\n")
proceed = input("Proceed with memory building? (y/n): ")
if proceed.lower() != 'y':
print("Aborted.")
if proceed.lower() != "y":
logger.info("Aborted.")
return
# Build memories
@ -59,7 +73,7 @@ def main():
start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"),
lookforward_days=7,
interval_days=30
interval_days=30,
)
# Save to disk
@ -74,39 +88,36 @@ def main():
# Save the ChromaDB collection data
# Note: ChromaDB doesn't serialize well, so we extract the data
collection = memory.situation_collection
data = {
"documents": [],
"metadatas": [],
"embeddings": [],
"ids": []
}
# Get all items from collection
results = collection.get(include=["documents", "metadatas", "embeddings"])
with open(filename, 'wb') as f:
pickle.dump({
"documents": results["documents"],
"metadatas": results["metadatas"],
"embeddings": results["embeddings"],
"ids": results["ids"],
"created_at": timestamp,
"tickers": tickers,
"config": {
"start_date": start_date.strftime("%Y-%m-%d"),
"end_date": end_date.strftime("%Y-%m-%d"),
"lookforward_days": 7,
"interval_days": 30
}
}, f)
with open(filename, "wb") as f:
pickle.dump(
{
"documents": results["documents"],
"metadatas": results["metadatas"],
"embeddings": results["embeddings"],
"ids": results["ids"],
"created_at": timestamp,
"tickers": tickers,
"config": {
"start_date": start_date.strftime("%Y-%m-%d"),
"end_date": end_date.strftime("%Y-%m-%d"),
"lookforward_days": 7,
"interval_days": 30,
},
},
f,
)
print(f"✅ Saved {agent_type} memory to {filename}")
logger.info(f"✅ Saved {agent_type} memory to {filename}")
print(f"\n🎉 Memory building complete!")
print(f" Memories saved to: {memory_dir}")
print(f"\n📝 To use these memories, update DEFAULT_CONFIG with:")
print(f' "memory_dir": "{memory_dir}"')
print(f' "load_historical_memories": True')
logger.info("\n🎉 Memory building complete!")
logger.info(f" Memories saved to: {memory_dir}")
logger.info("\n📝 To use these memories, update DEFAULT_CONFIG with:")
logger.info(f' "memory_dir": "{memory_dir}"')
logger.info(' "load_historical_memories": True')
if __name__ == "__main__":

278
scripts/build_ml_dataset.py Normal file
View File

@ -0,0 +1,278 @@
#!/usr/bin/env python3
"""Build ML training dataset from historical OHLCV data.
Fetches price data for a universe of liquid stocks, computes features
locally via stockstats, and applies triple-barrier labels.
Usage:
python scripts/build_ml_dataset.py
python scripts/build_ml_dataset.py --stocks 100 --years 2
python scripts/build_ml_dataset.py --ticker-file data/tickers_top50.txt
"""
from __future__ import annotations
import argparse
import os
import sys
import time
from pathlib import Path
import numpy as np
import pandas as pd
# Add project root to path
project_root = str(Path(__file__).resolve().parent.parent)
if project_root not in sys.path:
sys.path.insert(0, project_root)
from tradingagents.ml.feature_engineering import (
FEATURE_COLUMNS,
MIN_HISTORY_ROWS,
apply_triple_barrier_labels,
compute_features_bulk,
)
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
# Default universe: S&P 500 most liquid by volume (top ~200)
# Can be overridden via --ticker-file
DEFAULT_TICKERS = [
# Mega-cap tech
"AAPL", "MSFT", "GOOGL", "AMZN", "NVDA", "META", "TSLA", "AVGO", "ORCL", "CRM",
"AMD", "INTC", "CSCO", "ADBE", "NFLX", "QCOM", "TXN", "AMAT", "MU", "LRCX",
"KLAC", "MRVL", "SNPS", "CDNS", "PANW", "CRWD", "FTNT", "NOW", "UBER", "ABNB",
# Financials
"JPM", "BAC", "WFC", "GS", "MS", "C", "SCHW", "BLK", "AXP", "USB",
"PNC", "TFC", "COF", "BK", "STT", "FITB", "HBAN", "RF", "CFG", "KEY",
# Healthcare
"UNH", "JNJ", "LLY", "PFE", "ABBV", "MRK", "TMO", "ABT", "DHR", "BMY",
"AMGN", "GILD", "ISRG", "VRTX", "REGN", "MDT", "SYK", "BSX", "EW", "ZTS",
# Consumer
"WMT", "PG", "KO", "PEP", "COST", "MCD", "NKE", "SBUX", "TGT", "LOW",
"HD", "TJX", "ROST", "DG", "DLTR", "EL", "CL", "KMB", "GIS", "K",
# Energy
"XOM", "CVX", "COP", "EOG", "SLB", "MPC", "PSX", "VLO", "OXY", "DVN",
"HAL", "FANG", "HES", "BKR", "KMI", "WMB", "OKE", "ET", "TRGP", "LNG",
# Industrials
"CAT", "DE", "UNP", "UPS", "HON", "RTX", "BA", "LMT", "GD", "NOC",
"GE", "MMM", "EMR", "ITW", "PH", "ROK", "ETN", "SWK", "CMI", "PCAR",
# Materials & Utilities
"LIN", "APD", "ECL", "SHW", "DD", "NEM", "FCX", "VMC", "MLM", "NUE",
"NEE", "DUK", "SO", "D", "AEP", "EXC", "SRE", "XEL", "WEC", "ES",
# REITs & Telecom
"AMT", "PLD", "CCI", "EQIX", "SPG", "O", "PSA", "DLR", "WELL", "AVB",
"T", "VZ", "TMUS", "CHTR", "CMCSA",
# High-volatility / popular retail
"COIN", "MARA", "RIOT", "PLTR", "SOFI", "HOOD", "RBLX", "SNAP", "PINS", "SQ",
"SHOP", "SE", "ROKU", "DKNG", "PENN", "WYNN", "MGM", "LVS", "DASH", "TTD",
# Biotech
"MRNA", "BNTX", "BIIB", "SGEN", "ALNY", "BMRN", "EXAS", "DXCM", "HZNP", "INCY",
]
OUTPUT_DIR = Path("data/ml")
def fetch_ohlcv(ticker: str, start: str, end: str) -> pd.DataFrame:
"""Fetch OHLCV data for a single ticker via yfinance."""
from tradingagents.dataflows.y_finance import download_history
df = download_history(
ticker,
start=start,
end=end,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
if df.empty:
return df
df = df.reset_index()
return df
def get_market_cap(ticker: str) -> float | None:
"""Get current market cap for a ticker (snapshot — used as static feature)."""
try:
import yfinance as yf
info = yf.Ticker(ticker).info
return info.get("marketCap")
except Exception:
return None
def process_ticker(
ticker: str,
start: str,
end: str,
profit_target: float,
stop_loss: float,
max_holding_days: int,
market_cap: float | None = None,
) -> pd.DataFrame | None:
"""Process a single ticker: fetch data, compute features, apply labels."""
try:
ohlcv = fetch_ohlcv(ticker, start, end)
if ohlcv.empty or len(ohlcv) < MIN_HISTORY_ROWS + max_holding_days:
logger.debug(f"{ticker}: insufficient data ({len(ohlcv)} rows), skipping")
return None
# Compute features
features = compute_features_bulk(ohlcv, market_cap=market_cap)
if features.empty:
logger.debug(f"{ticker}: feature computation failed, skipping")
return None
# Compute triple-barrier labels
close = ohlcv.set_index("Date")["Close"] if "Date" in ohlcv.columns else ohlcv["Close"]
if isinstance(close.index, pd.DatetimeIndex):
pass
else:
close.index = pd.to_datetime(close.index)
labels = apply_triple_barrier_labels(
close,
profit_target=profit_target,
stop_loss=stop_loss,
max_holding_days=max_holding_days,
)
# Align features and labels by date
combined = features.join(labels, how="inner")
# Drop rows with NaN features or labels
combined = combined.dropna(subset=["label"] + FEATURE_COLUMNS)
if combined.empty:
logger.debug(f"{ticker}: no valid rows after alignment, skipping")
return None
# Add metadata columns
combined["ticker"] = ticker
combined["date"] = combined.index
logger.info(
f"{ticker}: {len(combined)} samples "
f"(WIN={int((combined['label'] == 1).sum())}, "
f"LOSS={int((combined['label'] == -1).sum())}, "
f"TIMEOUT={int((combined['label'] == 0).sum())})"
)
return combined
except Exception as e:
logger.warning(f"{ticker}: error processing — {e}")
return None
def build_dataset(
tickers: list[str],
start: str = "2022-01-01",
end: str = "2025-12-31",
profit_target: float = 0.05,
stop_loss: float = 0.03,
max_holding_days: int = 7,
) -> pd.DataFrame:
"""Build the full training dataset across all tickers."""
all_data = []
total = len(tickers)
logger.info(f"Building ML dataset: {total} tickers, {start} to {end}")
logger.info(
f"Triple-barrier: +{profit_target*100:.0f}% profit, "
f"-{stop_loss*100:.0f}% stop, {max_holding_days}d timeout"
)
# Batch-fetch market caps
logger.info("Fetching market caps...")
market_caps = {}
for ticker in tickers:
market_caps[ticker] = get_market_cap(ticker)
time.sleep(0.05) # rate limit courtesy
for i, ticker in enumerate(tickers):
logger.info(f"[{i+1}/{total}] Processing {ticker}...")
result = process_ticker(
ticker=ticker,
start=start,
end=end,
profit_target=profit_target,
stop_loss=stop_loss,
max_holding_days=max_holding_days,
market_cap=market_caps.get(ticker),
)
if result is not None:
all_data.append(result)
# Brief pause between tickers to be polite to yfinance
if (i + 1) % 50 == 0:
logger.info(f"Progress: {i+1}/{total} tickers processed, pausing 2s...")
time.sleep(2)
if not all_data:
logger.error("No data collected — check tickers and date range")
return pd.DataFrame()
dataset = pd.concat(all_data, ignore_index=True)
logger.info(f"\n{'='*60}")
logger.info(f"Dataset built: {len(dataset)} total samples from {len(all_data)} tickers")
logger.info(f"Label distribution:")
logger.info(f" WIN (+1): {int((dataset['label'] == 1).sum()):>7} ({(dataset['label'] == 1).mean()*100:.1f}%)")
logger.info(f" LOSS (-1): {int((dataset['label'] == -1).sum()):>7} ({(dataset['label'] == -1).mean()*100:.1f}%)")
logger.info(f" TIMEOUT: {int((dataset['label'] == 0).sum()):>7} ({(dataset['label'] == 0).mean()*100:.1f}%)")
logger.info(f"Features: {len(FEATURE_COLUMNS)}")
logger.info(f"{'='*60}")
return dataset
def main():
parser = argparse.ArgumentParser(description="Build ML training dataset")
parser.add_argument("--stocks", type=int, default=None, help="Limit to N stocks from default universe")
parser.add_argument("--ticker-file", type=str, default=None, help="File with tickers (one per line)")
parser.add_argument("--start", type=str, default="2022-01-01", help="Start date (YYYY-MM-DD)")
parser.add_argument("--end", type=str, default="2025-12-31", help="End date (YYYY-MM-DD)")
parser.add_argument("--profit-target", type=float, default=0.05, help="Profit target fraction (default: 0.05)")
parser.add_argument("--stop-loss", type=float, default=0.03, help="Stop loss fraction (default: 0.03)")
parser.add_argument("--holding-days", type=int, default=7, help="Max holding days (default: 7)")
parser.add_argument("--output", type=str, default=None, help="Output parquet path")
args = parser.parse_args()
# Determine ticker list
if args.ticker_file:
with open(args.ticker_file) as f:
tickers = [line.strip().upper() for line in f if line.strip() and not line.startswith("#")]
logger.info(f"Loaded {len(tickers)} tickers from {args.ticker_file}")
else:
tickers = DEFAULT_TICKERS
if args.stocks:
tickers = tickers[: args.stocks]
# Build dataset
dataset = build_dataset(
tickers=tickers,
start=args.start,
end=args.end,
profit_target=args.profit_target,
stop_loss=args.stop_loss,
max_holding_days=args.holding_days,
)
if dataset.empty:
logger.error("Empty dataset — aborting")
sys.exit(1)
# Save
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
output_path = args.output or str(OUTPUT_DIR / "training_dataset.parquet")
dataset.to_parquet(output_path, index=False)
logger.info(f"Saved dataset to {output_path} ({os.path.getsize(output_path) / 1e6:.1f} MB)")
if __name__ == "__main__":
main()

View File

@ -9,41 +9,78 @@ This script creates memory sets optimized for:
- Long-term investing (90-day horizon, quarterly samples)
"""
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import sys
from pathlib import Path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.historical_memory_builder import HistoricalMemoryBuilder
import pickle
from datetime import datetime, timedelta
from tradingagents.agents.utils.historical_memory_builder import HistoricalMemoryBuilder
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
# Strategy configurations
STRATEGIES = {
"day_trading": {
"lookforward_days": 1, # Next day returns
"interval_days": 1, # Sample daily
"lookforward_days": 1, # Next day returns
"interval_days": 1, # Sample daily
"description": "Day Trading - Capture intraday momentum and next-day moves",
"tickers": ["SPY", "QQQ", "AAPL", "TSLA", "NVDA", "AMD", "AMZN"], # High volume
},
"swing_trading": {
"lookforward_days": 7, # Weekly returns
"interval_days": 7, # Sample weekly
"lookforward_days": 7, # Weekly returns
"interval_days": 7, # Sample weekly
"description": "Swing Trading - Capture week-long trends and momentum",
"tickers": ["AAPL", "GOOGL", "MSFT", "NVDA", "TSLA", "META", "AMZN", "AMD", "NFLX"],
"tickers": [
"AAPL",
"GOOGL",
"MSFT",
"NVDA",
"TSLA",
"META",
"AMZN",
"AMD",
"NFLX",
],
},
"position_trading": {
"lookforward_days": 30, # Monthly returns
"interval_days": 30, # Sample monthly
"lookforward_days": 30, # Monthly returns
"interval_days": 30, # Sample monthly
"description": "Position Trading - Capture monthly trends and fundamentals",
"tickers": ["AAPL", "GOOGL", "MSFT", "NVDA", "TSLA", "JPM", "BAC", "XOM", "JNJ", "WMT"],
"tickers": [
"AAPL",
"GOOGL",
"MSFT",
"NVDA",
"TSLA",
"JPM",
"BAC",
"XOM",
"JNJ",
"WMT",
],
},
"long_term_investing": {
"lookforward_days": 90, # Quarterly returns
"interval_days": 90, # Sample quarterly
"lookforward_days": 90, # Quarterly returns
"interval_days": 90, # Sample quarterly
"description": "Long-term Investing - Capture fundamental value and trends",
"tickers": ["AAPL", "GOOGL", "MSFT", "BRK.B", "JPM", "JNJ", "PG", "KO", "DIS", "V"],
"tickers": [
"AAPL",
"GOOGL",
"MSFT",
"BRK.B",
"JPM",
"JNJ",
"PG",
"KO",
"DIS",
"V",
],
},
}
@ -53,7 +90,7 @@ def build_strategy_memories(strategy_name: str, config: dict):
strategy = STRATEGIES[strategy_name]
print(f"""
logger.info(f"""
Building Memories: {strategy_name.upper().replace('_', ' ')}
@ -72,11 +109,11 @@ Tickers: {', '.join(strategy['tickers'])}
builder = HistoricalMemoryBuilder(DEFAULT_CONFIG)
memories = builder.populate_agent_memories(
tickers=strategy['tickers'],
tickers=strategy["tickers"],
start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"),
lookforward_days=strategy['lookforward_days'],
interval_days=strategy['interval_days']
lookforward_days=strategy["lookforward_days"],
interval_days=strategy["interval_days"],
)
# Save to disk
@ -92,33 +129,36 @@ Tickers: {', '.join(strategy['tickers'])}
collection = memory.situation_collection
results = collection.get(include=["documents", "metadatas", "embeddings"])
with open(filename, 'wb') as f:
pickle.dump({
"documents": results["documents"],
"metadatas": results["metadatas"],
"embeddings": results["embeddings"],
"ids": results["ids"],
"created_at": timestamp,
"strategy": strategy_name,
"tickers": strategy['tickers'],
"config": {
"start_date": start_date.strftime("%Y-%m-%d"),
"end_date": end_date.strftime("%Y-%m-%d"),
"lookforward_days": strategy['lookforward_days'],
"interval_days": strategy['interval_days']
}
}, f)
with open(filename, "wb") as f:
pickle.dump(
{
"documents": results["documents"],
"metadatas": results["metadatas"],
"embeddings": results["embeddings"],
"ids": results["ids"],
"created_at": timestamp,
"strategy": strategy_name,
"tickers": strategy["tickers"],
"config": {
"start_date": start_date.strftime("%Y-%m-%d"),
"end_date": end_date.strftime("%Y-%m-%d"),
"lookforward_days": strategy["lookforward_days"],
"interval_days": strategy["interval_days"],
},
},
f,
)
print(f"✅ Saved {agent_type} memory to {filename}")
logger.info(f"✅ Saved {agent_type} memory to {filename}")
print(f"\n🎉 {strategy_name.replace('_', ' ').title()} memories complete!")
print(f" Saved to: {memory_dir}\n")
logger.info(f"\n🎉 {strategy_name.replace('_', ' ').title()} memories complete!")
logger.info(f" Saved to: {memory_dir}\n")
return memory_dir
def main():
print("""
logger.info("""
TradingAgents - Strategy-Specific Memory Builder
@ -131,29 +171,31 @@ This script builds optimized memories for different trading styles:
4. Long-term - 90-day returns, quarterly samples
""")
print("Available strategies:")
logger.info("Available strategies:")
for i, (name, config) in enumerate(STRATEGIES.items(), 1):
print(f" {i}. {name.replace('_', ' ').title()}")
print(f" {config['description']}")
print(f" Horizon: {config['lookforward_days']} days, Interval: {config['interval_days']} days\n")
logger.info(f" {i}. {name.replace('_', ' ').title()}")
logger.info(f" {config['description']}")
logger.info(
f" Horizon: {config['lookforward_days']} days, Interval: {config['interval_days']} days\n"
)
choice = input("Choose strategy (1-4, or 'all' for all strategies): ").strip()
if choice.lower() == 'all':
if choice.lower() == "all":
strategies_to_build = list(STRATEGIES.keys())
else:
try:
idx = int(choice) - 1
strategies_to_build = [list(STRATEGIES.keys())[idx]]
except (ValueError, IndexError):
print("Invalid choice. Exiting.")
logger.error("Invalid choice. Exiting.")
return
print(f"\nWill build memories for: {', '.join(strategies_to_build)}")
logger.info(f"\nWill build memories for: {', '.join(strategies_to_build)}")
proceed = input("Proceed? (y/n): ")
if proceed.lower() != 'y':
print("Aborted.")
if proceed.lower() != "y":
logger.info("Aborted.")
return
# Build memories for each selected strategy
@ -163,19 +205,19 @@ This script builds optimized memories for different trading styles:
results[strategy_name] = memory_dir
# Print summary
print("\n" + "="*70)
print("📊 MEMORY BUILDING COMPLETE")
print("="*70)
logger.info("\n" + "=" * 70)
logger.info("📊 MEMORY BUILDING COMPLETE")
logger.info("=" * 70)
for strategy_name, memory_dir in results.items():
print(f"\n{strategy_name.replace('_', ' ').title()}:")
print(f" Location: {memory_dir}")
print(f" Config to use:")
print(f' "memory_dir": "{memory_dir}"')
print(f' "load_historical_memories": True')
logger.info(f"\n{strategy_name.replace('_', ' ').title()}:")
logger.info(f" Location: {memory_dir}")
logger.info(" Config to use:")
logger.info(f' "memory_dir": "{memory_dir}"')
logger.info(' "load_historical_memories": True')
print("\n" + "="*70)
print("\n💡 TIP: To use a specific strategy's memories, update your config:")
print("""
logger.info("\n" + "=" * 70)
logger.info("\n💡 TIP: To use a specific strategy's memories, update your config:")
logger.info("""
config = DEFAULT_CONFIG.copy()
config["memory_dir"] = "data/memories/swing_trading" # or your strategy
config["load_historical_memories"] = True

View File

@ -12,31 +12,58 @@ Examples:
python scripts/scan_reddit_dd.py --output reports/reddit_dd_2024_01_15.md
"""
import argparse
import os
import sys
import argparse
from datetime import datetime
from pathlib import Path
from dotenv import load_dotenv
from tradingagents.utils.logger import get_logger
load_dotenv()
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from tradingagents.dataflows.reddit_api import get_reddit_undiscovered_dd
logger = get_logger(__name__)
from langchain_openai import ChatOpenAI
from tradingagents.dataflows.reddit_api import get_reddit_undiscovered_dd
def main():
parser = argparse.ArgumentParser(description='Scan Reddit for high-quality DD posts')
parser.add_argument('--hours', type=int, default=72, help='Hours to look back (default: 72)')
parser.add_argument('--limit', type=int, default=100, help='Number of posts to scan (default: 100)')
parser.add_argument('--top', type=int, default=15, help='Number of top DD to include (default: 15)')
parser.add_argument('--output', type=str, help='Output markdown file (default: reports/reddit_dd_YYYY_MM_DD.md)')
parser.add_argument('--min-score', type=int, default=55, help='Minimum quality score (default: 55)')
parser.add_argument('--model', type=str, default='gpt-4o-mini', help='LLM model to use (default: gpt-4o-mini)')
parser.add_argument('--temperature', type=float, default=0, help='LLM temperature (default: 0)')
parser.add_argument('--comments', type=int, default=10, help='Number of top comments to include (default: 10)')
parser = argparse.ArgumentParser(description="Scan Reddit for high-quality DD posts")
parser.add_argument("--hours", type=int, default=72, help="Hours to look back (default: 72)")
parser.add_argument(
"--limit", type=int, default=100, help="Number of posts to scan (default: 100)"
)
parser.add_argument(
"--top", type=int, default=15, help="Number of top DD to include (default: 15)"
)
parser.add_argument(
"--output",
type=str,
help="Output markdown file (default: reports/reddit_dd_YYYY_MM_DD.md)",
)
parser.add_argument(
"--min-score", type=int, default=55, help="Minimum quality score (default: 55)"
)
parser.add_argument(
"--model",
type=str,
default="gpt-4o-mini",
help="LLM model to use (default: gpt-4o-mini)",
)
parser.add_argument("--temperature", type=float, default=0, help="LLM temperature (default: 0)")
parser.add_argument(
"--comments",
type=int,
default=10,
help="Number of top comments to include (default: 10)",
)
args = parser.parse_args()
@ -51,36 +78,36 @@ def main():
timestamp = datetime.now().strftime("%Y_%m_%d_%H%M")
output_file = reports_dir / f"reddit_dd_{timestamp}.md"
print("=" * 70)
print("📊 REDDIT DD SCANNER")
print("=" * 70)
print(f"Lookback: {args.hours} hours")
print(f"Scan limit: {args.limit} posts")
print(f"Top results: {args.top}")
print(f"Min quality score: {args.min_score}")
print(f"LLM model: {args.model}")
print(f"Temperature: {args.temperature}")
print(f"Output: {output_file}")
print("=" * 70)
print()
logger.info("=" * 70)
logger.info("📊 REDDIT DD SCANNER")
logger.info("=" * 70)
logger.info(f"Lookback: {args.hours} hours")
logger.info(f"Scan limit: {args.limit} posts")
logger.info(f"Top results: {args.top}")
logger.info(f"Min quality score: {args.min_score}")
logger.info(f"LLM model: {args.model}")
logger.info(f"Temperature: {args.temperature}")
logger.info(f"Output: {output_file}")
logger.info("=" * 70)
logger.info("")
# Initialize LLM
print("Initializing LLM...")
logger.info("Initializing LLM...")
llm = ChatOpenAI(
model=args.model,
temperature=args.temperature,
api_key=os.getenv("OPENAI_API_KEY")
api_key=os.getenv("OPENAI_API_KEY"),
)
# Scan Reddit
print(f"\n🔍 Scanning Reddit (last {args.hours} hours)...\n")
logger.info(f"\n🔍 Scanning Reddit (last {args.hours} hours)...\n")
dd_report = get_reddit_undiscovered_dd(
lookback_hours=args.hours,
scan_limit=args.limit,
top_n=args.top,
num_comments=args.comments,
llm_evaluator=llm
llm_evaluator=llm,
)
# Add header with metadata
@ -98,47 +125,49 @@ def main():
full_report = header + dd_report
# Save to file
with open(output_file, 'w') as f:
with open(output_file, "w") as f:
f.write(full_report)
print("\n" + "=" * 70)
print(f"✅ Report saved to: {output_file}")
print("=" * 70)
logger.info("\n" + "=" * 70)
logger.info(f"✅ Report saved to: {output_file}")
logger.info("=" * 70)
# Print summary
print("\n📈 SUMMARY:")
logger.info("\n📈 SUMMARY:")
# Count quality posts by parsing the report
import re
quality_match = re.search(r'\*\*High Quality:\*\* (\d+) DD posts', dd_report)
scanned_match = re.search(r'\*\*Scanned:\*\* (\d+) posts', dd_report)
quality_match = re.search(r"\*\*High Quality:\*\* (\d+) DD posts", dd_report)
scanned_match = re.search(r"\*\*Scanned:\*\* (\d+) posts", dd_report)
if scanned_match and quality_match:
scanned = int(scanned_match.group(1))
quality = int(quality_match.group(1))
print(f" • Posts scanned: {scanned}")
print(f" • Quality DD found: {quality}")
logger.info(f" • Posts scanned: {scanned}")
logger.info(f" • Quality DD found: {quality}")
if scanned > 0:
print(f" • Quality rate: {(quality/scanned)*100:.1f}%")
logger.info(f" • Quality rate: {(quality/scanned)*100:.1f}%")
# Extract tickers
ticker_matches = re.findall(r'\*\*Ticker:\*\* \$([A-Z]+)', dd_report)
ticker_matches = re.findall(r"\*\*Ticker:\*\* \$([A-Z]+)", dd_report)
if ticker_matches:
unique_tickers = list(set(ticker_matches))
print(f" • Tickers mentioned: {', '.join(['$' + t for t in unique_tickers])}")
logger.info(f" • Tickers mentioned: {', '.join(['$' + t for t in unique_tickers])}")
print()
print("💡 TIP: Review the report and investigate promising opportunities!")
logger.info("")
logger.info("💡 TIP: Review the report and investigate promising opportunities!")
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\n\n⚠️ Scan interrupted by user")
logger.warning("\n\n⚠️ Scan interrupted by user")
sys.exit(1)
except Exception as e:
print(f"\n❌ Error: {str(e)}")
logger.error(f"\n❌ Error: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@ -0,0 +1,313 @@
#!/usr/bin/env python3
"""
Daily Performance Tracker
Tracks the performance of historical recommendations and updates the database.
Run this daily (via cron or manually) to monitor how recommendations perform over time.
Usage:
python scripts/track_recommendation_performance.py
Cron example (runs daily at 5pm after market close):
0 17 * * 1-5 cd /path/to/TradingAgents && python scripts/track_recommendation_performance.py
"""
import glob
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List
# Add parent directory to path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from tradingagents.dataflows.y_finance import get_stock_price
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
def load_recommendations() -> List[Dict[str, Any]]:
"""Load all historical recommendations from the recommendations directory."""
recommendations_dir = "data/recommendations"
if not os.path.exists(recommendations_dir):
logger.warning(f"No recommendations directory found at {recommendations_dir}")
return []
all_recs = []
pattern = os.path.join(recommendations_dir, "*.json")
for filepath in glob.glob(pattern):
try:
with open(filepath, "r") as f:
data = json.load(f)
# Each file contains recommendations from one discovery run
recs = data.get("recommendations", [])
run_date = data.get("date", os.path.basename(filepath).replace(".json", ""))
for rec in recs:
rec["discovery_date"] = run_date
all_recs.append(rec)
except Exception as e:
logger.error(f"Error loading {filepath}: {e}")
return all_recs
def update_performance(recommendations: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Update performance metrics for all recommendations."""
today = datetime.now().strftime("%Y-%m-%d")
for rec in recommendations:
ticker = rec.get("ticker")
discovery_date = rec.get("discovery_date")
entry_price = rec.get("entry_price")
if not all([ticker, discovery_date, entry_price]):
continue
# Skip if already marked as closed
if rec.get("status") == "closed":
continue
try:
# Get current price
current_price_data = get_stock_price(ticker, curr_date=today)
# Parse the price from the response (it returns a markdown report)
# Format is typically: "**Current Price**: $XXX.XX"
import re
price_match = re.search(r"\$([0-9,.]+)", current_price_data)
if price_match:
current_price = float(price_match.group(1).replace(",", ""))
else:
logger.warning(f"Could not parse price for {ticker}")
continue
# Calculate days since recommendation
rec_date = datetime.strptime(discovery_date, "%Y-%m-%d")
days_held = (datetime.now() - rec_date).days
# Calculate return
return_pct = ((current_price - entry_price) / entry_price) * 100
# Update metrics
rec["current_price"] = current_price
rec["return_pct"] = round(return_pct, 2)
rec["days_held"] = days_held
rec["last_updated"] = today
# Check specific time periods
if days_held >= 7 and "return_7d" not in rec:
rec["return_7d"] = round(return_pct, 2)
if days_held >= 30 and "return_30d" not in rec:
rec["return_30d"] = round(return_pct, 2)
rec["status"] = "closed" # Mark as complete after 30 days
# Determine win/loss for completed periods
if "return_7d" in rec:
rec["win_7d"] = rec["return_7d"] > 0
if "return_30d" in rec:
rec["win_30d"] = rec["return_30d"] > 0
logger.info(
f"{ticker}: Entry ${entry_price:.2f} → Current ${current_price:.2f} ({return_pct:+.1f}%) [{days_held}d]"
)
except Exception as e:
logger.error(f"✗ Error tracking {ticker}: {e}")
return recommendations
def save_performance_database(recommendations: List[Dict[str, Any]]):
"""Save the updated performance database."""
db_path = "data/recommendations/performance_database.json"
# Group by discovery date for organized storage
by_date = {}
for rec in recommendations:
date = rec.get("discovery_date", "unknown")
if date not in by_date:
by_date[date] = []
by_date[date].append(rec)
database = {
"last_updated": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"total_recommendations": len(recommendations),
"recommendations_by_date": by_date,
}
with open(db_path, "w") as f:
json.dump(database, f, indent=2)
logger.info(f"\n💾 Saved performance database to {db_path}")
def calculate_statistics(recommendations: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Calculate aggregate statistics from historical performance."""
stats = {
"total_recommendations": len(recommendations),
"by_strategy": {},
"overall_7d": {"count": 0, "wins": 0, "avg_return": 0},
"overall_30d": {"count": 0, "wins": 0, "avg_return": 0},
}
# Calculate by strategy
for rec in recommendations:
strategy = rec.get("strategy_match", "unknown")
if strategy not in stats["by_strategy"]:
stats["by_strategy"][strategy] = {
"count": 0,
"wins_7d": 0,
"losses_7d": 0,
"wins_30d": 0,
"losses_30d": 0,
"avg_return_7d": 0,
"avg_return_30d": 0,
}
stats["by_strategy"][strategy]["count"] += 1
# 7-day stats
if "return_7d" in rec:
stats["overall_7d"]["count"] += 1
if rec.get("win_7d"):
stats["overall_7d"]["wins"] += 1
stats["by_strategy"][strategy]["wins_7d"] += 1
else:
stats["by_strategy"][strategy]["losses_7d"] += 1
stats["overall_7d"]["avg_return"] += rec["return_7d"]
# 30-day stats
if "return_30d" in rec:
stats["overall_30d"]["count"] += 1
if rec.get("win_30d"):
stats["overall_30d"]["wins"] += 1
stats["by_strategy"][strategy]["wins_30d"] += 1
else:
stats["by_strategy"][strategy]["losses_30d"] += 1
stats["overall_30d"]["avg_return"] += rec["return_30d"]
# Calculate averages and win rates
if stats["overall_7d"]["count"] > 0:
stats["overall_7d"]["win_rate"] = round(
(stats["overall_7d"]["wins"] / stats["overall_7d"]["count"]) * 100, 1
)
stats["overall_7d"]["avg_return"] = round(
stats["overall_7d"]["avg_return"] / stats["overall_7d"]["count"], 2
)
if stats["overall_30d"]["count"] > 0:
stats["overall_30d"]["win_rate"] = round(
(stats["overall_30d"]["wins"] / stats["overall_30d"]["count"]) * 100, 1
)
stats["overall_30d"]["avg_return"] = round(
stats["overall_30d"]["avg_return"] / stats["overall_30d"]["count"], 2
)
# Calculate per-strategy stats
for strategy, data in stats["by_strategy"].items():
total_7d = data["wins_7d"] + data["losses_7d"]
total_30d = data["wins_30d"] + data["losses_30d"]
if total_7d > 0:
data["win_rate_7d"] = round((data["wins_7d"] / total_7d) * 100, 1)
if total_30d > 0:
data["win_rate_30d"] = round((data["wins_30d"] / total_30d) * 100, 1)
return stats
def print_statistics(stats: Dict[str, Any]):
"""Print formatted statistics report."""
logger.info("\n" + "=" * 60)
logger.info("RECOMMENDATION PERFORMANCE STATISTICS")
logger.info("=" * 60)
logger.info(f"\nTotal Recommendations Tracked: {stats['total_recommendations']}")
# Overall stats
logger.info("\n📊 OVERALL PERFORMANCE")
logger.info("-" * 60)
if stats["overall_7d"]["count"] > 0:
logger.info("7-Day Performance:")
logger.info(f" • Tracked: {stats['overall_7d']['count']} recommendations")
logger.info(f" • Win Rate: {stats['overall_7d']['win_rate']}%")
logger.info(f" • Avg Return: {stats['overall_7d']['avg_return']:+.2f}%")
if stats["overall_30d"]["count"] > 0:
logger.info("\n30-Day Performance:")
logger.info(f" • Tracked: {stats['overall_30d']['count']} recommendations")
logger.info(f" • Win Rate: {stats['overall_30d']['win_rate']}%")
logger.info(f" • Avg Return: {stats['overall_30d']['avg_return']:+.2f}%")
# By strategy
if stats["by_strategy"]:
logger.info("\n📈 PERFORMANCE BY STRATEGY")
logger.info("-" * 60)
# Sort by win rate (if available)
sorted_strategies = sorted(
stats["by_strategy"].items(), key=lambda x: x[1].get("win_rate_7d", 0), reverse=True
)
for strategy, data in sorted_strategies:
logger.info(f"\n{strategy}:")
logger.info(f" • Total: {data['count']} recommendations")
if data.get("win_rate_7d"):
logger.info(
f" • 7-Day Win Rate: {data['win_rate_7d']}% ({data['wins_7d']}W/{data['losses_7d']}L)"
)
if data.get("win_rate_30d"):
logger.info(
f" • 30-Day Win Rate: {data['win_rate_30d']}% ({data['wins_30d']}W/{data['losses_30d']}L)"
)
def main():
"""Main execution function."""
logger.info("🔍 Loading historical recommendations...")
recommendations = load_recommendations()
if not recommendations:
logger.warning("No recommendations found to track.")
return
logger.info(f"Found {len(recommendations)} total recommendations")
# Filter to only track open positions (not closed after 30 days)
open_recs = [r for r in recommendations if r.get("status") != "closed"]
logger.info(f"Tracking {len(open_recs)} open positions...")
logger.info("\n📊 Updating performance metrics...\n")
updated_recs = update_performance(recommendations)
logger.info("\n📈 Calculating statistics...")
stats = calculate_statistics(updated_recs)
print_statistics(stats)
save_performance_database(updated_recs)
# Also save stats separately
stats_path = "data/recommendations/statistics.json"
with open(stats_path, "w") as f:
json.dump(stats, f, indent=2)
logger.info(f"💾 Saved statistics to {stats_path}")
logger.info("\n✅ Performance tracking complete!")
if __name__ == "__main__":
main()

370
scripts/train_ml_model.py Normal file
View File

@ -0,0 +1,370 @@
#!/usr/bin/env python3
"""Train ML model on the generated dataset.
Supports TabPFN (recommended, requires GPU or API) and LightGBM (fallback).
Uses time-based train/validation split to prevent data leakage.
Usage:
python scripts/train_ml_model.py
python scripts/train_ml_model.py --model lightgbm
python scripts/train_ml_model.py --model tabpfn --dataset data/ml/training_dataset.parquet
python scripts/train_ml_model.py --max-train-samples 5000
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.metrics import (
accuracy_score,
classification_report,
confusion_matrix,
)
# Add project root to path
project_root = str(Path(__file__).resolve().parent.parent)
if project_root not in sys.path:
sys.path.insert(0, project_root)
from tradingagents.ml.feature_engineering import FEATURE_COLUMNS
from tradingagents.ml.predictor import LGBMWrapper, MLPredictor
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
DATA_DIR = Path("data/ml")
LABEL_NAMES = {-1: "LOSS", 0: "TIMEOUT", 1: "WIN"}
def load_dataset(path: str) -> pd.DataFrame:
"""Load and validate the training dataset."""
df = pd.read_parquet(path)
logger.info(f"Loaded {len(df)} samples from {path}")
# Validate columns
missing = [c for c in FEATURE_COLUMNS if c not in df.columns]
if missing:
raise ValueError(f"Missing feature columns: {missing}")
if "label" not in df.columns:
raise ValueError("Missing 'label' column")
if "date" not in df.columns:
raise ValueError("Missing 'date' column")
# Show label distribution
for label, name in LABEL_NAMES.items():
count = (df["label"] == label).sum()
pct = count / len(df) * 100
logger.info(f" {name:>7} ({label:+d}): {count:>7} ({pct:.1f}%)")
return df
def time_split(
df: pd.DataFrame,
val_start: str = "2024-07-01",
max_train_samples: int | None = None,
) -> tuple:
"""Split dataset by time — train on older data, validate on newer."""
df["date"] = pd.to_datetime(df["date"])
val_start_dt = pd.Timestamp(val_start)
train = df[df["date"] < val_start_dt].copy()
val = df[df["date"] >= val_start_dt].copy()
if max_train_samples is not None and len(train) > max_train_samples:
train = train.sort_values("date").tail(max_train_samples)
logger.info(
f"Limiting training samples to most recent {max_train_samples} "
f"before {val_start}"
)
logger.info(f"Time-based split at {val_start}:")
logger.info(f" Train: {len(train)} samples ({train['date'].min().date()} to {train['date'].max().date()})")
logger.info(f" Val: {len(val)} samples ({val['date'].min().date()} to {val['date'].max().date()})")
X_train = train[FEATURE_COLUMNS].values
y_train = train["label"].values.astype(int)
X_val = val[FEATURE_COLUMNS].values
y_val = val["label"].values.astype(int)
return X_train, y_train, X_val, y_val
def train_tabpfn(X_train, y_train, X_val, y_val):
"""Train using TabPFN foundation model."""
try:
from tabpfn import TabPFNClassifier
except ImportError:
logger.error("TabPFN not installed. Install with: pip install tabpfn")
logger.error("Falling back to LightGBM...")
return train_lightgbm(X_train, y_train, X_val, y_val)
logger.info("Training TabPFN classifier...")
# TabPFN handles NaN values natively
# For large datasets, subsample training data (TabPFN works best with <10K samples)
max_train = 10_000
if len(X_train) > max_train:
logger.info(f"Subsampling training data: {len(X_train)}{max_train}")
idx = np.random.RandomState(42).choice(len(X_train), max_train, replace=False)
X_train_sub = X_train[idx]
y_train_sub = y_train[idx]
else:
X_train_sub = X_train
y_train_sub = y_train
try:
clf = TabPFNClassifier()
clf.fit(X_train_sub, y_train_sub)
return clf, "tabpfn"
except Exception as e:
logger.error(f"TabPFN training failed: {e}")
logger.error("Falling back to LightGBM...")
return train_lightgbm(X_train, y_train, X_val, y_val)
def train_lightgbm(X_train, y_train, X_val, y_val):
"""Train using LightGBM (fallback when TabPFN unavailable)."""
try:
import lightgbm as lgb
except ImportError:
logger.error("LightGBM not installed. Install with: pip install lightgbm")
sys.exit(1)
logger.info("Training LightGBM classifier...")
# Remap labels: {-1, 0, 1} → {0, 1, 2} for LightGBM
y_train_mapped = y_train + 1 # -1→0, 0→1, 1→2
y_val_mapped = y_val + 1
# Compute class weights to handle imbalanced labels
from collections import Counter
class_counts = Counter(y_train_mapped)
total = len(y_train_mapped)
n_classes = len(class_counts)
class_weight = {c: total / (n_classes * count) for c, count in class_counts.items()}
sample_weights = np.array([class_weight[y] for y in y_train_mapped])
train_data = lgb.Dataset(X_train, label=y_train_mapped, weight=sample_weights, feature_name=FEATURE_COLUMNS)
val_data = lgb.Dataset(X_val, label=y_val_mapped, feature_name=FEATURE_COLUMNS, reference=train_data)
params = {
"objective": "multiclass",
"num_class": 3,
"metric": "multi_logloss",
# Lower LR + more rounds = smoother learning on noisy data
"learning_rate": 0.01,
# More capacity to find feature interactions
"num_leaves": 63,
"max_depth": 8,
"min_child_samples": 100,
# Aggressive subsampling to reduce overfitting on noise
"subsample": 0.7,
"subsample_freq": 1,
"colsample_bytree": 0.7,
# Stronger regularization for financial data
"reg_alpha": 1.0,
"reg_lambda": 1.0,
"min_gain_to_split": 0.01,
"path_smooth": 1.0,
"verbose": -1,
"seed": 42,
}
callbacks = [
lgb.log_evaluation(period=100),
lgb.early_stopping(stopping_rounds=100),
]
booster = lgb.train(
params,
train_data,
num_boost_round=2000,
valid_sets=[val_data],
callbacks=callbacks,
)
# Wrap in sklearn-compatible interface
clf = LGBMWrapper(booster, y_train)
return clf, "lightgbm"
def evaluate(model, X_val, y_val, model_type: str) -> dict:
"""Evaluate model and return metrics dict."""
if isinstance(X_val, np.ndarray):
X_df = pd.DataFrame(X_val, columns=FEATURE_COLUMNS)
else:
X_df = X_val
y_pred = model.predict(X_df)
probas = model.predict_proba(X_df)
accuracy = accuracy_score(y_val, y_pred)
report = classification_report(
y_val, y_pred,
target_names=["LOSS (-1)", "TIMEOUT (0)", "WIN (+1)"],
output_dict=True,
)
cm = confusion_matrix(y_val, y_pred)
# Win-class specific metrics
win_mask = y_val == 1
if win_mask.sum() > 0:
win_probs = probas[win_mask]
win_col_idx = list(model.classes_).index(1)
avg_win_prob_for_actual_wins = float(win_probs[:, win_col_idx].mean())
else:
avg_win_prob_for_actual_wins = 0.0
# High-confidence win precision
win_col_idx = list(model.classes_).index(1)
high_conf_mask = probas[:, win_col_idx] >= 0.6
if high_conf_mask.sum() > 0:
high_conf_precision = float((y_val[high_conf_mask] == 1).mean())
high_conf_count = int(high_conf_mask.sum())
else:
high_conf_precision = 0.0
high_conf_count = 0
# Calibration analysis: do higher P(WIN) quintiles actually win more?
win_probs_all = probas[:, win_col_idx]
quintile_labels = pd.qcut(win_probs_all, q=5, labels=False, duplicates="drop")
calibration = {}
for q in sorted(set(quintile_labels)):
mask = quintile_labels == q
q_probs = win_probs_all[mask]
q_actual_win_rate = float((y_val[mask] == 1).mean())
q_actual_loss_rate = float((y_val[mask] == -1).mean())
calibration[f"Q{q+1}"] = {
"mean_predicted_win_prob": round(float(q_probs.mean()), 4),
"actual_win_rate": round(q_actual_win_rate, 4),
"actual_loss_rate": round(q_actual_loss_rate, 4),
"count": int(mask.sum()),
}
# Top decile (top 10% by P(WIN)) — most actionable metric
top_decile_threshold = np.percentile(win_probs_all, 90)
top_decile_mask = win_probs_all >= top_decile_threshold
top_decile_win_rate = float((y_val[top_decile_mask] == 1).mean()) if top_decile_mask.sum() > 0 else 0.0
top_decile_loss_rate = float((y_val[top_decile_mask] == -1).mean()) if top_decile_mask.sum() > 0 else 0.0
metrics = {
"model_type": model_type,
"accuracy": round(accuracy, 4),
"per_class": {k: {kk: round(vv, 4) for kk, vv in v.items()} for k, v in report.items() if isinstance(v, dict)},
"confusion_matrix": cm.tolist(),
"avg_win_prob_for_actual_wins": round(avg_win_prob_for_actual_wins, 4),
"high_confidence_win_precision": round(high_conf_precision, 4),
"high_confidence_win_count": high_conf_count,
"calibration_quintiles": calibration,
"top_decile_win_rate": round(top_decile_win_rate, 4),
"top_decile_loss_rate": round(top_decile_loss_rate, 4),
"top_decile_threshold": round(float(top_decile_threshold), 4),
"top_decile_count": int(top_decile_mask.sum()),
"val_samples": len(y_val),
}
# Print summary
logger.info(f"\n{'='*60}")
logger.info(f"Model: {model_type}")
logger.info(f"Overall Accuracy: {accuracy:.1%}")
logger.info(f"\nPer-class metrics:")
logger.info(f"{'':>15} {'Precision':>10} {'Recall':>10} {'F1':>10} {'Support':>10}")
for label, name in [(-1, "LOSS"), (0, "TIMEOUT"), (1, "WIN")]:
key = f"{name} ({label:+d})"
if key in report:
r = report[key]
logger.info(f"{name:>15} {r['precision']:>10.3f} {r['recall']:>10.3f} {r['f1-score']:>10.3f} {r['support']:>10.0f}")
logger.info(f"\nConfusion Matrix (rows=actual, cols=predicted):")
logger.info(f"{'':>10} {'LOSS':>8} {'TIMEOUT':>8} {'WIN':>8}")
for i, name in enumerate(["LOSS", "TIMEOUT", "WIN"]):
logger.info(f"{name:>10} {cm[i][0]:>8} {cm[i][1]:>8} {cm[i][2]:>8}")
logger.info(f"\nWin-class insights:")
logger.info(f" Avg P(WIN) for actual winners: {avg_win_prob_for_actual_wins:.1%}")
logger.info(f" High-confidence (>60%) precision: {high_conf_precision:.1%} ({high_conf_count} samples)")
logger.info("\nCalibration (does higher P(WIN) = more actual wins?):")
logger.info(f"{'Quintile':>10} {'Avg P(WIN)':>12} {'Actual WIN%':>12} {'Actual LOSS%':>13} {'Count':>8}")
for q_name, q_data in calibration.items():
logger.info(
f"{q_name:>10} {q_data['mean_predicted_win_prob']:>12.1%} "
f"{q_data['actual_win_rate']:>12.1%} {q_data['actual_loss_rate']:>13.1%} "
f"{q_data['count']:>8}"
)
logger.info("\nTop decile (top 10% by P(WIN)):")
logger.info(f" Threshold: P(WIN) >= {top_decile_threshold:.1%}")
logger.info(f" Actual win rate: {top_decile_win_rate:.1%} ({int(top_decile_mask.sum())} samples)")
logger.info(f" Actual loss rate: {top_decile_loss_rate:.1%}")
baseline_win = float((y_val == 1).mean())
logger.info(f" Baseline win rate: {baseline_win:.1%}")
if baseline_win > 0:
logger.info(f" Lift over baseline: {top_decile_win_rate / baseline_win:.2f}x")
logger.info(f"{'='*60}")
return metrics
def main():
parser = argparse.ArgumentParser(description="Train ML model for win probability")
parser.add_argument("--dataset", type=str, default="data/ml/training_dataset.parquet")
parser.add_argument("--model", type=str, choices=["tabpfn", "lightgbm", "auto"], default="auto",
help="Model type (auto tries TabPFN first, falls back to LightGBM)")
parser.add_argument("--val-start", type=str, default="2024-07-01",
help="Validation split date (default: 2024-07-01)")
parser.add_argument("--max-train-samples", type=int, default=None,
help="Limit training samples to the most recent N before val-start")
parser.add_argument("--output-dir", type=str, default="data/ml")
args = parser.parse_args()
if args.max_train_samples is not None and args.max_train_samples <= 0:
logger.error("--max-train-samples must be a positive integer")
sys.exit(1)
# Load dataset
df = load_dataset(args.dataset)
# Split
X_train, y_train, X_val, y_val = time_split(
df,
val_start=args.val_start,
max_train_samples=args.max_train_samples,
)
if len(X_val) == 0:
logger.error(f"No validation data after {args.val_start} — adjust --val-start")
sys.exit(1)
# Train
if args.model == "tabpfn" or args.model == "auto":
model, model_type = train_tabpfn(X_train, y_train, X_val, y_val)
else:
model, model_type = train_lightgbm(X_train, y_train, X_val, y_val)
# Evaluate
metrics = evaluate(model, X_val, y_val, model_type)
# Save model
predictor = MLPredictor(model=model, feature_columns=FEATURE_COLUMNS, model_type=model_type)
model_path = predictor.save(args.output_dir)
logger.info(f"Model saved to {model_path}")
# Save metrics
metrics_path = os.path.join(args.output_dir, "metrics.json")
with open(metrics_path, "w") as f:
json.dump(metrics, f, indent=2)
logger.info(f"Metrics saved to {metrics_path}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,39 @@
#!/bin/bash
# Script to extract consistently failing tickers from the delisted cache
# These are candidates for adding to PERMANENTLY_DELISTED after manual verification
CACHE_FILE="data/delisted_cache.json"
REVIEW_FILE="data/delisted_review.txt"
echo "Analyzing delisted cache for consistently failing tickers..."
if [ ! -f "$CACHE_FILE" ]; then
echo "No delisted cache found at $CACHE_FILE"
echo "Run discovery flow at least once to populate the cache."
exit 0
fi
# Check if jq is installed
if ! command -v jq &> /dev/null; then
echo "Error: jq is required but not installed."
echo "Install it with: brew install jq (macOS) or apt-get install jq (Linux)"
exit 1
fi
# Extract tickers with high fail counts (3+ failures across multiple days)
echo ""
echo "Tickers that have failed 3+ times:"
echo "=================================="
jq -r 'to_entries[] | select(.value.fail_count >= 3) | "\(.key): \(.value.fail_count) failures across \(.value.fail_dates | length) days - \(.value.reason)"' "$CACHE_FILE"
echo ""
echo "---"
echo "Review the tickers above and verify their status using:"
echo " 1. Yahoo Finance: https://finance.yahoo.com/quote/TICKER"
echo " 2. SEC EDGAR: https://www.sec.gov/cgi-bin/browse-edgar"
echo " 3. Google search: 'TICKER stock delisted'"
echo ""
echo "For CONFIRMED permanent delistings, add them to PERMANENTLY_DELISTED in:"
echo " tradingagents/graph/discovery_graph.py"
echo ""
echo "Detailed review list has been exported to: $REVIEW_FILE"

203
scripts/update_positions.py Executable file
View File

@ -0,0 +1,203 @@
#!/usr/bin/env python3
"""
Position Updater Script
This script:
1. Fetches current prices for all open positions
2. Updates positions with latest price data
3. Calculates return % for each position
4. Can be run manually or via cron for continuous monitoring
Usage:
python scripts/update_positions.py
"""
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from datetime import datetime
import yfinance as yf
from tradingagents.dataflows.discovery.performance.position_tracker import PositionTracker
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
def fetch_current_prices(tickers):
"""
Fetch current prices for given tickers using yfinance.
Handles both single and multiple tickers with appropriate error handling.
Args:
tickers: List of ticker symbols
Returns:
Dictionary mapping ticker to current price (or None if fetch failed)
"""
prices = {}
if not tickers:
return prices
# Try to download all tickers at once for efficiency
try:
if len(tickers) == 1:
# Single ticker - yfinance returns Series instead of DataFrame
ticker = tickers[0]
data = yf.download(
ticker,
period="1d",
progress=False,
auto_adjust=True,
)
if not data.empty:
# For single ticker with period='1d', get the latest close
prices[ticker] = float(data["Close"].iloc[-1])
else:
logger.warning(f"Could not fetch data for {ticker}")
prices[ticker] = None
else:
# Multiple tickers - yfinance returns DataFrame with MultiIndex
data = yf.download(
tickers,
period="1d",
progress=False,
auto_adjust=True,
)
if not data.empty:
# Get the latest close for each ticker
if len(tickers) > 1:
for ticker in tickers:
if ticker in data.columns:
close_price = data[ticker]["Close"]
if not close_price.empty:
prices[ticker] = float(close_price.iloc[-1])
else:
prices[ticker] = None
else:
prices[ticker] = None
else:
# Edge case: single ticker in batch download
if "Close" in data.columns:
prices[tickers[0]] = float(data["Close"].iloc[-1])
else:
prices[tickers[0]] = None
else:
for ticker in tickers:
prices[ticker] = None
except Exception as e:
logger.warning(f"Batch download failed: {e}")
# Fall back to per-ticker download
for ticker in tickers:
try:
data = yf.download(
ticker,
period="1d",
progress=False,
auto_adjust=True,
)
if not data.empty:
prices[ticker] = float(data["Close"].iloc[-1])
else:
prices[ticker] = None
except Exception as e:
logger.error(f"Failed to fetch price for {ticker}: {e}")
prices[ticker] = None
return prices
def main():
"""
Main function to update all open positions with current prices.
Process:
1. Initialize PositionTracker
2. Load all open positions
3. Get unique tickers
4. Fetch current prices via yfinance
5. Update each position with new price
6. Save updated positions
7. Print progress messages
"""
logger.info("""
TradingAgents - Position Updater
""".strip())
# Initialize position tracker
tracker = PositionTracker(data_dir="data")
# Load all open positions
logger.info("📂 Loading open positions...")
positions = tracker.load_all_open_positions()
if not positions:
logger.info("✅ No open positions to update.")
return
logger.info(f"✅ Found {len(positions)} open position(s)")
# Get unique tickers
tickers = list({pos["ticker"] for pos in positions})
logger.info(f"📊 Fetching current prices for {len(tickers)} unique ticker(s)...")
logger.info(f"Tickers: {', '.join(sorted(tickers))}")
# Fetch current prices
prices = fetch_current_prices(tickers)
# Update positions and track results
updated_count = 0
failed_count = 0
for position in positions:
ticker = position["ticker"]
current_price = prices.get(ticker)
if current_price is None:
logger.error(f"{ticker}: Failed to fetch price - position not updated")
failed_count += 1
continue
# Update position with new price
entry_price = position["entry_price"]
return_pct = ((current_price - entry_price) / entry_price) * 100
# Update the position
position = tracker.update_position_price(position, current_price)
# Save the updated position
tracker.save_position(position)
# Log progress
return_symbol = "📈" if return_pct >= 0 else "📉"
logger.info(
f"{return_symbol} {ticker:6} | Price: ${current_price:8.2f} | Return: {return_pct:+7.2f}%"
)
updated_count += 1
# Summary
logger.info("=" * 60)
logger.info("✅ Update Summary:")
logger.info(f"Updated: {updated_count}/{len(positions)} positions")
logger.info(f"Failed: {failed_count}/{len(positions)} positions")
logger.info(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S UTC')}")
logger.info("=" * 60)
if updated_count > 0:
logger.info("🎉 Position update complete!")
else:
logger.warning("No positions were successfully updated.")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,305 @@
#!/usr/bin/env python3
"""
Ticker Database Updater
Maintains and augments the ticker list in data/tickers.txt
Usage:
python scripts/update_ticker_database.py [OPTIONS]
Examples:
# Validate and clean existing list
python scripts/update_ticker_database.py --validate
# Add specific tickers
python scripts/update_ticker_database.py --add NVDA,PLTR,HOOD
# Fetch latest from Alpha Vantage
python scripts/update_ticker_database.py --fetch-alphavantage
"""
import argparse
import os
import sys
from pathlib import Path
from typing import Set
import requests
from dotenv import load_dotenv
from tradingagents.utils.logger import get_logger
load_dotenv()
sys.path.insert(0, str(Path(__file__).parent.parent))
logger = get_logger(__name__)
class TickerDatabaseUpdater:
def __init__(self, ticker_file: str = "data/tickers.txt"):
self.ticker_file = ticker_file
self.tickers: Set[str] = set()
self.added_count = 0
self.removed_count = 0
def load_tickers(self) -> Set[str]:
"""Load existing tickers from file."""
logger.info(f"📖 Loading tickers from {self.ticker_file}...")
try:
with open(self.ticker_file, "r") as f:
for line in f:
symbol = line.strip()
if symbol and symbol.isalpha():
self.tickers.add(symbol.upper())
logger.info(f" ✓ Loaded {len(self.tickers)} tickers")
return self.tickers
except FileNotFoundError:
logger.info(" File not found, starting fresh")
return set()
except Exception as e:
logger.warning(f" ⚠️ Error loading: {str(e)}")
return set()
def add_tickers(self, new_tickers: list):
"""Add new tickers to the database."""
logger.info(f"\n Adding tickers: {', '.join(new_tickers)}")
for ticker in new_tickers:
ticker = ticker.strip().upper()
if ticker and ticker.isalpha():
if ticker not in self.tickers:
self.tickers.add(ticker)
self.added_count += 1
logger.info(f" ✓ Added {ticker}")
else:
logger.info(f" {ticker} already exists")
def validate_and_clean(self, remove_warrants=False, remove_preferred=False):
"""Validate tickers and remove invalid ones."""
logger.info(f"\n🔍 Validating {len(self.tickers)} tickers...")
invalid = set()
for ticker in self.tickers:
# Remove if not alphabetic or too long
if not ticker.isalpha() or len(ticker) > 5 or len(ticker) < 1:
invalid.add(ticker)
continue
# Optionally remove warrants (ending in W)
if remove_warrants and ticker.endswith("W") and len(ticker) > 1:
invalid.add(ticker)
continue
# Optionally remove preferred shares (ending in P after checking it's not a regular stock)
if remove_preferred and ticker.endswith("P") and len(ticker) > 1:
invalid.add(ticker)
if invalid:
logger.warning(f" ⚠️ Found {len(invalid)} problematic tickers")
# Categorize for reporting
warrants = [t for t in invalid if t.endswith("W")]
preferred = [t for t in invalid if t.endswith("P")]
other_invalid = [t for t in invalid if not (t.endswith("W") or t.endswith("P"))]
if warrants and remove_warrants:
logger.info(f" Warrants (ending in W): {len(warrants)}")
if preferred and remove_preferred:
logger.info(f" Preferred shares (ending in P): {len(preferred)}")
if other_invalid:
logger.info(f" Other invalid: {len(other_invalid)}")
for ticker in list(other_invalid)[:10]:
logger.debug(f" - {ticker}")
if len(other_invalid) > 10:
logger.debug(f" ... and {len(other_invalid) - 10} more")
for ticker in invalid:
self.tickers.remove(ticker)
self.removed_count += 1
else:
logger.info(" ✓ All tickers valid")
def fetch_from_alphavantage(self):
"""Fetch tickers from Alpha Vantage LISTING_STATUS endpoint."""
logger.info("\n📥 Fetching from Alpha Vantage...")
api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
if not api_key or "placeholder" in api_key:
logger.warning(" ⚠️ ALPHA_VANTAGE_API_KEY not configured")
logger.info(" 💡 Set in .env file to use this feature")
return
try:
url = f"https://www.alphavantage.co/query?function=LISTING_STATUS&apikey={api_key}"
logger.info(" Downloading listing data...")
response = requests.get(url, timeout=60)
if response.status_code != 200:
logger.error(f" ❌ Failed: HTTP {response.status_code}")
return
# Parse CSV response
lines = response.text.strip().split("\n")
if len(lines) < 2:
logger.error(" ❌ Invalid response format")
return
header = lines[0].split(",")
logger.debug(f" Columns: {', '.join(header)}")
# Find symbol and status columns
try:
symbol_idx = header.index("symbol")
status_idx = header.index("status")
except ValueError:
# Try without quotes
symbol_idx = 0 # Usually first column
status_idx = None
initial_count = len(self.tickers)
for line in lines[1:]:
parts = line.split(",")
if len(parts) > symbol_idx:
symbol = parts[symbol_idx].strip().strip('"')
# Check if active (if status column exists)
if status_idx and len(parts) > status_idx:
status = parts[status_idx].strip().strip('"')
if status != "Active":
continue
# Only add alphabetic symbols
if symbol and symbol.isalpha() and len(symbol) <= 5:
self.tickers.add(symbol.upper())
new_count = len(self.tickers) - initial_count
self.added_count += new_count
logger.info(f" ✓ Added {new_count} new tickers from Alpha Vantage")
except Exception as e:
logger.error(f" ❌ Error: {str(e)}")
def save_tickers(self):
"""Save tickers back to file (sorted)."""
output_path = Path(self.ticker_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
sorted_tickers = sorted(self.tickers)
with open(output_path, "w") as f:
for symbol in sorted_tickers:
f.write(f"{symbol}\n")
logger.info(f"\n✅ Saved {len(sorted_tickers)} tickers to: {self.ticker_file}")
def print_summary(self):
"""Print summary."""
logger.info("\n" + "=" * 70)
logger.info("📊 SUMMARY")
logger.info("=" * 70)
logger.info(f"Total Tickers: {len(self.tickers):,}")
if self.added_count > 0:
logger.info(f"Added: {self.added_count}")
if self.removed_count > 0:
logger.info(f"Removed: {self.removed_count}")
logger.info("=" * 70 + "\n")
def main():
parser = argparse.ArgumentParser(description="Update and maintain ticker database")
parser.add_argument(
"--file",
type=str,
default="data/tickers.txt",
help="Ticker file path (default: data/tickers.txt)",
)
parser.add_argument(
"--add", type=str, help="Comma-separated list of tickers to add (e.g., NVDA,PLTR,HOOD)"
)
parser.add_argument(
"--validate", action="store_true", help="Validate and clean existing tickers"
)
parser.add_argument(
"--remove-warrants",
action="store_true",
help="Remove warrants (tickers ending in W) during validation",
)
parser.add_argument(
"--remove-preferred",
action="store_true",
help="Remove preferred shares (tickers ending in P) during validation",
)
parser.add_argument(
"--fetch-alphavantage", action="store_true", help="Fetch latest tickers from Alpha Vantage"
)
args = parser.parse_args()
logger.info("=" * 70)
logger.info("🔄 TICKER DATABASE UPDATER")
logger.info("=" * 70)
logger.info(f"File: {args.file}")
logger.info("=" * 70 + "\n")
updater = TickerDatabaseUpdater(args.file)
# Load existing tickers
updater.load_tickers()
# Perform requested operations
if args.add:
new_tickers = [t.strip() for t in args.add.split(",")]
updater.add_tickers(new_tickers)
if args.validate or args.remove_warrants or args.remove_preferred:
updater.validate_and_clean(
remove_warrants=args.remove_warrants, remove_preferred=args.remove_preferred
)
if args.fetch_alphavantage:
updater.fetch_from_alphavantage()
# If no operations specified, just validate
if not (
args.add
or args.validate
or args.remove_warrants
or args.remove_preferred
or args.fetch_alphavantage
):
logger.info("No operations specified. Use --help for options.")
logger.info("\nRunning basic validation...")
updater.validate_and_clean(remove_warrants=False, remove_preferred=False)
# Save if any changes were made
if updater.added_count > 0 or updater.removed_count > 0:
updater.save_tickers()
else:
logger.info("\n No changes made")
# Print summary
updater.print_summary()
logger.info("💡 Usage examples:")
logger.info(" python scripts/update_ticker_database.py --add NVDA,PLTR")
logger.info(" python scripts/update_ticker_database.py --validate")
logger.info(" python scripts/update_ticker_database.py --remove-warrants")
logger.info(" python scripts/update_ticker_database.py --fetch-alphavantage\n")
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
logger.warning("\n\n⚠️ Interrupted by user")
sys.exit(1)
except Exception as e:
logger.error(f"\n❌ Error: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@ -2,7 +2,7 @@
Setup script for the TradingAgents package.
"""
from setuptools import setup, find_packages
from setuptools import find_packages, setup
setup(
name="tradingagents",

42
tests/conftest.py Normal file
View File

@ -0,0 +1,42 @@
import os
from unittest.mock import patch
import pytest
from tradingagents.config import Config
@pytest.fixture
def mock_env_vars():
"""Mock environment variables for testing."""
with patch.dict(os.environ, {
"OPENAI_API_KEY": "test-openai-key",
"ALPHA_VANTAGE_API_KEY": "test-alpha-key",
"FINNHUB_API_KEY": "test-finnhub-key",
"TRADIER_API_KEY": "test-tradier-key",
"GOOGLE_API_KEY": "test-google-key",
"REDDIT_CLIENT_ID": "test-reddit-id",
"REDDIT_CLIENT_SECRET": "test-reddit-secret",
"TWITTER_BEARER_TOKEN": "test-twitter-token"
}, clear=True):
yield
@pytest.fixture
def mock_config(mock_env_vars):
"""Return a Config instance with mocked env vars."""
# Reset singleton
Config._instance = None
return Config()
@pytest.fixture
def sample_stock_data():
"""Return a sample DataFrame for technical analysis."""
import pandas as pd
data = {
"close": [100, 102, 101, 103, 105, 108, 110, 109, 112, 115],
"high": [105, 106, 105, 107, 108, 112, 115, 113, 116, 118],
"low": [95, 98, 99, 100, 102, 105, 108, 106, 108, 111],
"volume": [1000] * 10
}
return pd.DataFrame(data)

View File

@ -0,0 +1,45 @@
from unittest.mock import patch
import pytest
from tradingagents.dataflows.news_semantic_scanner import NewsSemanticScanner
class TestNewsSemanticScanner:
@pytest.fixture
def scanner(self, mock_config):
# Allow instantiation by mocking __init__ dependencies if needed?
# The class uses OpenAI in init.
with patch('tradingagents.dataflows.news_semantic_scanner.OpenAI') as MockOpenAI:
scanner = NewsSemanticScanner(config=mock_config)
return scanner
def test_filter_by_time(self, scanner):
from datetime import datetime
# Test data
news = [
{"published_at": "2025-01-01T12:00:00Z", "title": "Old News"},
{"published_at": datetime.now().isoformat(), "title": "New News"}
]
# We need to set scanner.cutoff_time manually or check its logic
# current logic sets it to now - lookback
# This is a bit tricky without mocking datetime or adjusting cutoff,
# so let's trust the logic for now or do a simple structural test.
assert hasattr(scanner, "scan_news")
@patch('tradingagents.dataflows.news_semantic_scanner.NewsSemanticScanner._fetch_openai_news')
def test_scan_news_aggregates(self, mock_fetch_openai, scanner):
mock_fetch_openai.return_value = [{"title": "OpenAI News", "importance": 8}]
# Configure to only use openai
scanner.news_sources = ["openai"]
result = scanner.scan_news()
assert len(result) == 1
assert result[0]["title"] == "OpenAI News"

View File

@ -0,0 +1,31 @@
import pandas as pd
from stockstats import wrap
from tradingagents.dataflows.technical_analyst import TechnicalAnalyst
def test_technical_analyst_report_generation(sample_stock_data):
df = wrap(sample_stock_data)
current_price = 115.0
analyst = TechnicalAnalyst(df, current_price)
report = analyst.generate_report("TEST", "2025-01-01")
assert "# Technical Analysis for TEST" in report
assert "**Current Price:** $115.00" in report
assert "## Price Action" in report
assert "Daily Change" in report
assert "## RSI" in report
assert "## MACD" in report
def test_technical_analyst_empty_data():
empty_df = pd.DataFrame()
# It might raise an error or handle it, usually logic handles standard DF but let's check
# The class expects columns, so let's pass empty with columns
df = pd.DataFrame(columns=["close", "high", "low", "volume"])
# Wrapping empty might fail or produce empty wrapped
# Our TechnicalAnalyst assumes valid data somewhat, but we should make sure it doesn't just crash blindly
# Actually, y_finance.py checks for empty before calling, so the class itself assumes data.
pass

View File

@ -0,0 +1,25 @@
"""
Quick ticker matcher validation
"""
from tradingagents.dataflows.discovery.ticker_matcher import match_company_to_ticker, load_ticker_universe
# Load universe
print("Loading ticker universe...")
universe = load_ticker_universe()
print(f"Loaded {len(universe)} tickers\n")
# Test cases
tests = [
("Apple Inc", "AAPL"),
("MICROSOFT CORP", "MSFT"),
("Amazon.com, Inc.", "AMZN"),
("TESLA INC", "TSLA"),
("META PLATFORMS INC", "META"),
("NVIDIA CORPORATION", "NVDA"),
]
print("Testing ticker matching:")
for company, expected in tests:
result = match_company_to_ticker(company)
status = "" if result and result.startswith(expected[:3]) else ""
print(f"{status} '{company}' -> {result} (expected {expected})")

42
tests/test_config.py Normal file
View File

@ -0,0 +1,42 @@
import os
from unittest.mock import patch
import pytest
from tradingagents.config import Config
class TestConfig:
def test_singleton(self):
Config._instance = None
c1 = Config()
c2 = Config()
assert c1 is c2
def test_validate_key_success(self, mock_env_vars):
Config._instance = None
config = Config()
key = config.validate_key("openai_api_key", "OpenAI")
assert key == "test-openai-key"
def test_validate_key_failure(self):
Config._instance = None
with patch.dict(os.environ, {}, clear=True):
config = Config()
with pytest.raises(ValueError) as excinfo:
config.validate_key("openai_api_key", "OpenAI")
assert "OpenAI API Key not found" in str(excinfo.value)
def test_get_method(self):
Config._instance = None
config = Config()
# Test getting real property
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
assert config.get("openai_api_key") == "test-key"
# Test getting default value
assert config.get("results_dir") == "./results"
# Test fallback to provided default
assert config.get("non_existent_key", "default") == "default"

View File

@ -0,0 +1,249 @@
#!/usr/bin/env python3
"""
Test script to verify DiscoveryGraph refactoring.
Tests: LLM Factory, TraditionalScanner, CandidateFilter, CandidateRanker
"""
import os
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
def test_llm_factory():
"""Test LLM factory initialization."""
print("\n=== Testing LLM Factory ===")
try:
from tradingagents.utils.llm_factory import create_llms
# Mock API key
os.environ.setdefault("OPENAI_API_KEY", "sk-test-key")
config = {
"llm_provider": "openai",
"deep_think_llm": "gpt-4",
"quick_think_llm": "gpt-3.5-turbo"
}
deep_llm, quick_llm = create_llms(config)
assert deep_llm is not None, "Deep LLM should be initialized"
assert quick_llm is not None, "Quick LLM should be initialized"
print("✅ LLM Factory: Successfully creates LLMs")
return True
except Exception as e:
print(f"❌ LLM Factory: Failed - {e}")
return False
def test_traditional_scanner():
"""Test TraditionalScanner class."""
print("\n=== Testing TraditionalScanner ===")
try:
from unittest.mock import MagicMock
from tradingagents.dataflows.discovery.scanners import TraditionalScanner
config = {"discovery": {}}
mock_llm = MagicMock()
mock_executor = MagicMock()
scanner = TraditionalScanner(config, mock_llm, mock_executor)
assert hasattr(scanner, 'scan'), "Scanner should have scan method"
assert scanner.execute_tool == mock_executor, "Should store executor"
print("✅ TraditionalScanner: Successfully initialized")
return True
except Exception as e:
print(f"❌ TraditionalScanner: Failed - {e}")
import traceback
traceback.print_exc()
return False
def test_candidate_filter():
"""Test CandidateFilter class."""
print("\n=== Testing CandidateFilter ===")
try:
from unittest.mock import MagicMock
from tradingagents.dataflows.discovery.filter import CandidateFilter
config = {"discovery": {}}
mock_executor = MagicMock()
filter_obj = CandidateFilter(config, mock_executor)
assert hasattr(filter_obj, 'filter'), "Filter should have filter method"
assert filter_obj.execute_tool == mock_executor, "Should store executor"
print("✅ CandidateFilter: Successfully initialized")
return True
except Exception as e:
print(f"❌ CandidateFilter: Failed - {e}")
import traceback
traceback.print_exc()
return False
def test_candidate_ranker():
"""Test CandidateRanker class."""
print("\n=== Testing CandidateRanker ===")
try:
from unittest.mock import MagicMock
from tradingagents.dataflows.discovery.ranker import CandidateRanker
config = {"discovery": {}}
mock_llm = MagicMock()
mock_analytics = MagicMock()
ranker = CandidateRanker(config, mock_llm, mock_analytics)
assert hasattr(ranker, 'rank'), "Ranker should have rank method"
assert ranker.llm == mock_llm, "Should store LLM"
print("✅ CandidateRanker: Successfully initialized")
return True
except Exception as e:
print(f"❌ CandidateRanker: Failed - {e}")
import traceback
traceback.print_exc()
return False
def test_discovery_graph_import():
"""Test that DiscoveryGraph still imports correctly."""
print("\n=== Testing DiscoveryGraph Import ===")
try:
from tradingagents.graph.discovery_graph import DiscoveryGraph
# Mock API key
os.environ.setdefault("OPENAI_API_KEY", "sk-test-key")
config = {
"llm_provider": "openai",
"deep_think_llm": "gpt-4",
"quick_think_llm": "gpt-3.5-turbo",
"backend_url": "https://api.openai.com/v1",
"discovery": {}
}
graph = DiscoveryGraph(config=config)
assert hasattr(graph, 'deep_thinking_llm'), "Should have deep LLM"
assert hasattr(graph, 'quick_thinking_llm'), "Should have quick LLM"
assert hasattr(graph, 'analytics'), "Should have analytics"
assert hasattr(graph, 'graph'), "Should have graph"
print("✅ DiscoveryGraph: Successfully initialized with refactored components")
return True
except Exception as e:
print(f"❌ DiscoveryGraph: Failed - {e}")
import traceback
traceback.print_exc()
return False
def test_trading_graph_import():
"""Test that TradingAgentsGraph still imports correctly."""
print("\n=== Testing TradingAgentsGraph Import ===")
try:
from tradingagents.graph.trading_graph import TradingAgentsGraph
# Mock API key
os.environ.setdefault("OPENAI_API_KEY", "sk-test-key")
config = {
"llm_provider": "openai",
"deep_think_llm": "gpt-4",
"quick_think_llm": "gpt-3.5-turbo",
"project_dir": str(project_root),
"enable_memory": False
}
graph = TradingAgentsGraph(config=config)
assert hasattr(graph, 'deep_thinking_llm'), "Should have deep LLM"
assert hasattr(graph, 'quick_thinking_llm'), "Should have quick LLM"
print("✅ TradingAgentsGraph: Successfully initialized with LLM factory")
return True
except Exception as e:
print(f"❌ TradingAgentsGraph: Failed - {e}")
import traceback
traceback.print_exc()
return False
def test_utils():
"""Test utility functions."""
print("\n=== Testing Utilities ===")
try:
from tradingagents.dataflows.discovery.utils import (
extract_technical_summary,
is_valid_ticker,
)
# Test ticker validation
assert is_valid_ticker("AAPL") == True, "AAPL should be valid"
assert is_valid_ticker("AAPL.WS") == False, "Warrant should be invalid"
assert is_valid_ticker("AAPL-RT") == False, "Rights should be invalid"
# Test technical summary extraction
tech_report = "RSI Value: 45.5"
summary = extract_technical_summary(tech_report)
assert "RSI:45" in summary or "RSI:46" in summary, "Should extract RSI"
print("✅ Utils: All utility functions work correctly")
return True
except Exception as e:
print(f"❌ Utils: Failed - {e}")
import traceback
traceback.print_exc()
return False
def main():
"""Run all tests."""
print("=" * 60)
print("DISCOVERY GRAPH REFACTORING VERIFICATION")
print("=" * 60)
results = []
# Run all tests
results.append(("LLM Factory", test_llm_factory()))
results.append(("Traditional Scanner", test_traditional_scanner()))
results.append(("Candidate Filter", test_candidate_filter()))
results.append(("Candidate Ranker", test_candidate_ranker()))
results.append(("Utils", test_utils()))
results.append(("DiscoveryGraph", test_discovery_graph_import()))
results.append(("TradingAgentsGraph", test_trading_graph_import()))
# Summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
passed = sum(1 for _, result in results if result)
total = len(results)
for name, result in results:
status = "✅ PASS" if result else "❌ FAIL"
print(f"{status}: {name}")
print(f"\n{passed}/{total} tests passed")
if passed == total:
print("\n🎉 All refactoring tests passed!")
return 0
else:
print(f"\n⚠️ {total - passed} test(s) failed")
return 1
if __name__ == "__main__":
sys.exit(main())

View File

@ -0,0 +1,159 @@
#!/usr/bin/env python3
"""
Test SEC 13F Parser with Ticker Matching
This script tests the refactored SEC 13F parser to verify:
1. Ticker matcher module loads successfully
2. Fuzzy matching works correctly
3. SEC 13F parsing integrates with ticker matcher
"""
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
print("=" * 60)
print("Testing SEC 13F Parser Refactor")
print("=" * 60)
# Test 1: Ticker Matcher Module
print("\n[1/3] Testing Ticker Matcher Module...")
try:
from tradingagents.dataflows.discovery.ticker_matcher import (
match_company_to_ticker,
load_ticker_universe,
get_match_confidence,
)
# Load universe
universe = load_ticker_universe()
print(f"✓ Loaded {len(universe)} tickers")
# Test exact matches
test_cases = [
("Apple Inc", "AAPL"),
("MICROSOFT CORP", "MSFT"),
("Amazon.com, Inc.", "AMZN"),
("Alphabet Inc", "GOOGL"), # or GOOG
("TESLA INC", "TSLA"),
("META PLATFORMS INC", "META"),
("NVIDIA CORPORATION", "NVDA"),
("Berkshire Hathaway Inc", "BRK.B"), # or BRK.A
]
passed = 0
for company, expected_prefix in test_cases:
result = match_company_to_ticker(company)
if result and result.startswith(expected_prefix[:3]):
passed += 1
print(f"'{company}' -> {result}")
else:
print(f"'{company}' -> {result} (expected {expected_prefix})")
print(f"\nPassed {passed}/{len(test_cases)} exact match tests")
# Test fuzzy matching
print("\nTesting fuzzy matching...")
fuzzy_cases = [
"APPLE COMPUTER INC",
"Microsoft Corporation",
"Amazon Com Inc",
"Tesla Motors",
]
for company in fuzzy_cases:
result = match_company_to_ticker(company, min_confidence=70.0)
confidence = get_match_confidence(company, result) if result else 0
print(f" '{company}' -> {result} (confidence: {confidence:.1f})")
print("✓ Ticker matcher working correctly")
except Exception as e:
print(f"✗ Error testing ticker matcher: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
# Test 2: SEC 13F Integration
print("\n[2/3] Testing SEC 13F Integration...")
try:
from tradingagents.dataflows.sec_13f import get_recent_13f_changes
print("Fetching recent 13F filings (this may take 30-60 seconds)...")
results = get_recent_13f_changes(
days_lookback=14, # Last 2 weeks
min_position_value=50, # $50M+
notable_only=False,
top_n=10,
return_structured=True,
)
if results:
print(f"\n✓ Found {len(results)} institutional holdings")
print("\nTop 5 holdings:")
print(f"{'Issuer':<40} {'Ticker':<8} {'Institutions':<12} {'Match Method'}")
print("-" * 80)
for i, r in enumerate(results[:5]):
issuer = r['issuer'][:38]
ticker = r.get('ticker', 'N/A')
inst_count = r.get('institution_count', 0)
match_method = r.get('match_method', 'unknown')
print(f"{issuer:<40} {ticker:<8} {inst_count:<12} {match_method}")
# Calculate match statistics
fuzzy_matches = sum(1 for r in results if r.get('match_method') == 'fuzzy')
regex_matches = sum(1 for r in results if r.get('match_method') == 'regex')
unmatched = sum(1 for r in results if r.get('match_method') == 'unmatched')
print(f"\nMatch Statistics:")
print(f" Fuzzy matches: {fuzzy_matches}/{len(results)} ({100*fuzzy_matches/len(results):.1f}%)")
print(f" Regex fallback: {regex_matches}/{len(results)} ({100*regex_matches/len(results):.1f}%)")
print(f" Unmatched: {unmatched}/{len(results)} ({100*unmatched/len(results):.1f}%)")
if fuzzy_matches > 0:
print("\n✓ SEC 13F parser successfully using ticker matcher!")
else:
print("\n⚠ Warning: No fuzzy matches found, matcher may not be integrated")
else:
print("⚠ No results found (may be weekend/no recent filings)")
except Exception as e:
print(f"✗ Error testing SEC 13F integration: {e}")
import traceback
traceback.print_exc()
# Don't exit, this might fail due to network issues
# Test 3: Scanner Interface
print("\n[3/3] Testing Scanner Interface...")
try:
from tradingagents.dataflows.sec_13f import scan_13f_changes
config = {
"discovery": {
"13f_lookback_days": 7,
"13f_min_position_value": 25,
}
}
candidates = scan_13f_changes(config)
if candidates:
print(f"✓ Scanner returned {len(candidates)} candidates")
print(f"\nSample candidates:")
for c in candidates[:3]:
print(f" {c['ticker']}: {c['context']} [{c['priority']}]")
else:
print("⚠ Scanner returned no candidates (may be normal)")
except Exception as e:
print(f"✗ Error testing scanner interface: {e}")
import traceback
traceback.print_exc()
print("\n" + "=" * 60)
print("Testing Complete!")
print("=" * 60)

View File

@ -0,0 +1,27 @@
import logging
from io import StringIO
from tradingagents.utils.logger import get_logger
def test_logger_formatting():
# Capture stdout
capture = StringIO()
handler = logging.StreamHandler(capture)
handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
logger = get_logger("test_logger_unit")
logger.setLevel(logging.INFO)
# Remove existing handlers to avoid cluttering output or double logging
for h in logger.handlers[:]:
logger.removeHandler(h)
logger.addHandler(handler)
logger.info("Test Info")
logger.error("Test Error")
output = capture.getvalue()
print(f"Captured: {output}") # For debugging
assert "INFO: Test Info" in output
assert "ERROR: Test Error" in output

73
tests/verify_refactor.py Normal file
View File

@ -0,0 +1,73 @@
import os
import shutil
import sys
from unittest.mock import MagicMock
# Add project root to path
sys.path.append(os.getcwd())
from tradingagents.dataflows.discovery.scanners import TraditionalScanner
from tradingagents.graph.discovery_graph import DiscoveryGraph
def test_graph_init_with_factory():
print("Testing DiscoveryGraph initialization with LLM Factory...")
config = {
"llm_provider": "openai",
"deep_think_llm": "gpt-4-turbo",
"quick_think_llm": "gpt-3.5-turbo",
"backend_url": "https://api.openai.com/v1",
"discovery": {},
"results_dir": "tests/temp_results"
}
# Mock API key so factory works
if not os.getenv("OPENAI_API_KEY"):
os.environ["OPENAI_API_KEY"] = "sk-mock-key"
try:
graph = DiscoveryGraph(config=config)
assert hasattr(graph, 'deep_thinking_llm')
assert hasattr(graph, 'quick_thinking_llm')
assert graph.deep_thinking_llm is not None
print("✅ DiscoveryGraph initialized LLMs via Factory")
except Exception as e:
print(f"❌ DiscoveryGraph initialization failed: {e}")
def test_traditional_scanner_init():
print("Testing TraditionalScanner initialization...")
config = {"discovery": {}}
mock_llm = MagicMock()
mock_executor = MagicMock()
try:
scanner = TraditionalScanner(config, mock_llm, mock_executor)
assert scanner.execute_tool == mock_executor
print("✅ TraditionalScanner initialized")
# Test scan (mocking tools)
mock_executor.return_value = {"valid": ["AAPL"], "invalid": []}
state = {"trade_date": "2023-10-27"}
# We expect some errors printed because we didn't mock everything perfect,
# but it shouldn't crash.
print(" Running scan (expecting some print errors due to missing tools)...")
candidates = scanner.scan(state)
print(f" Scan returned {len(candidates)} candidates")
print("✅ TraditionalScanner scan() ran without crash")
except Exception as e:
print(f"❌ TraditionalScanner failed: {e}")
def cleanup():
if os.path.exists("tests/temp_results"):
shutil.rmtree("tests/temp_results")
if __name__ == "__main__":
try:
test_graph_init_with_factory()
test_traditional_scanner_init()
print("\nAll checks passed!")
finally:
cleanup()

View File

@ -1,23 +1,18 @@
from .utils.agent_utils import create_msg_delete
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.memory import FinancialSituationMemory
from .analysts.fundamentals_analyst import create_fundamentals_analyst
from .analysts.market_analyst import create_market_analyst
from .analysts.news_analyst import create_news_analyst
from .analysts.social_media_analyst import create_social_media_analyst
from .managers.research_manager import create_research_manager
from .managers.risk_manager import create_risk_manager
from .researchers.bear_researcher import create_bear_researcher
from .researchers.bull_researcher import create_bull_researcher
from .risk_mgmt.aggresive_debator import create_risky_debator
from .risk_mgmt.conservative_debator import create_safe_debator
from .risk_mgmt.neutral_debator import create_neutral_debator
from .managers.research_manager import create_research_manager
from .managers.risk_manager import create_risk_manager
from .trader.trader import create_trader
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
from .utils.agent_utils import create_msg_delete
from .utils.memory import FinancialSituationMemory
__all__ = [
"FinancialSituationMemory",

View File

@ -1,23 +1,10 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.tools.generator import get_agent_tools
from tradingagents.dataflows.config import get_config
from tradingagents.agents.utils.prompt_templates import (
BASE_COLLABORATIVE_BOILERPLATE,
get_date_awareness_section,
)
from tradingagents.agents.utils.agent_utils import create_analyst_node
from tradingagents.agents.utils.prompt_templates import get_date_awareness_section
def create_fundamentals_analyst(llm):
def fundamentals_analyst_node(state):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
company_name = state["company_of_interest"]
tools = get_agent_tools("fundamentals")
system_message = f"""You are a Fundamental Analyst assessing {ticker}'s financial health with SHORT-TERM trading relevance.
def _build_prompt(ticker, current_date):
return f"""You are a Fundamental Analyst assessing {ticker}'s financial health with SHORT-TERM trading relevance.
{get_date_awareness_section(current_date)}
@ -91,31 +78,4 @@ For each fundamental metric, ask:
Date: {current_date} | Ticker: {ticker}"""
tool_names_str = ", ".join([tool.name for tool in tools])
full_system_message = (
f"{BASE_COLLABORATIVE_BOILERPLATE}\n\n{system_message}\n\n"
f"Context: {ticker} | Date: {current_date} | Tools: {tool_names_str}"
)
prompt = ChatPromptTemplate.from_messages(
[
("system", full_system_message),
MessagesPlaceholder(variable_name="messages"),
]
)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
report = ""
if len(result.tool_calls) == 0:
report = result.content
return {
"messages": [result],
"fundamentals_report": report,
}
return fundamentals_analyst_node
return create_analyst_node(llm, "fundamentals", "fundamentals_report", _build_prompt)

View File

@ -1,24 +1,10 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.tools.generator import get_agent_tools
from tradingagents.dataflows.config import get_config
from tradingagents.agents.utils.prompt_templates import (
BASE_COLLABORATIVE_BOILERPLATE,
get_date_awareness_section,
)
from tradingagents.agents.utils.agent_utils import create_analyst_node
from tradingagents.agents.utils.prompt_templates import get_date_awareness_section
def create_market_analyst(llm):
def market_analyst_node(state):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
company_name = state["company_of_interest"]
tools = get_agent_tools("market")
system_message = f"""You are a Market Technical Analyst specializing in identifying actionable short-term trading signals through technical indicators.
def _build_prompt(ticker, current_date):
return f"""You are a Market Technical Analyst specializing in identifying actionable short-term trading signals through technical indicators.
## YOUR MISSION
Analyze {ticker}'s technical setup and identify the 3-5 most relevant trading signals for short-term opportunities (days to weeks, not months).
@ -103,32 +89,4 @@ Available Indicators:
Current date: {current_date} | Ticker: {ticker}"""
tool_names_str = ", ".join([tool.name for tool in tools])
full_system_message = (
f"{BASE_COLLABORATIVE_BOILERPLATE}\n\n{system_message}\n\n"
f"Context: {ticker} | Date: {current_date} | Tools: {tool_names_str}"
)
prompt = ChatPromptTemplate.from_messages(
[
("system", full_system_message),
MessagesPlaceholder(variable_name="messages"),
]
)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
report = ""
if len(result.tool_calls) == 0:
report = result.content
return {
"messages": [result],
"market_report": report,
}
return market_analyst_node
return create_analyst_node(llm, "market", "market_report", _build_prompt)

View File

@ -1,23 +1,10 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.tools.generator import get_agent_tools
from tradingagents.dataflows.config import get_config
from tradingagents.agents.utils.prompt_templates import (
BASE_COLLABORATIVE_BOILERPLATE,
get_date_awareness_section,
)
from tradingagents.agents.utils.agent_utils import create_analyst_node
from tradingagents.agents.utils.prompt_templates import get_date_awareness_section
def create_news_analyst(llm):
def news_analyst_node(state):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
from tradingagents.tools.generator import get_agent_tools
tools = get_agent_tools("news")
system_message = f"""You are a News Intelligence Analyst finding SHORT-TERM catalysts for {ticker}.
def _build_prompt(ticker, current_date):
return f"""You are a News Intelligence Analyst finding SHORT-TERM catalysts for {ticker}.
{get_date_awareness_section(current_date)}
@ -78,30 +65,4 @@ For each:
Date: {current_date} | Ticker: {ticker}"""
tool_names_str = ", ".join([tool.name for tool in tools])
full_system_message = (
f"{BASE_COLLABORATIVE_BOILERPLATE}\n\n{system_message}\n\n"
f"Context: {ticker} | Date: {current_date} | Tools: {tool_names_str}"
)
prompt = ChatPromptTemplate.from_messages(
[
("system", full_system_message),
MessagesPlaceholder(variable_name="messages"),
]
)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
report = ""
if len(result.tool_calls) == 0:
report = result.content
return {
"messages": [result],
"news_report": report,
}
return news_analyst_node
return create_analyst_node(llm, "news", "news_report", _build_prompt)

View File

@ -1,23 +1,10 @@
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import time
import json
from tradingagents.tools.generator import get_agent_tools
from tradingagents.dataflows.config import get_config
from tradingagents.agents.utils.prompt_templates import (
BASE_COLLABORATIVE_BOILERPLATE,
get_date_awareness_section,
)
from tradingagents.agents.utils.agent_utils import create_analyst_node
from tradingagents.agents.utils.prompt_templates import get_date_awareness_section
def create_social_media_analyst(llm):
def social_media_analyst_node(state):
current_date = state["trade_date"]
ticker = state["company_of_interest"]
company_name = state["company_of_interest"]
tools = get_agent_tools("social")
system_message = f"""You are a Social Sentiment Analyst tracking {ticker}'s retail momentum for SHORT-TERM signals.
def _build_prompt(ticker, current_date):
return f"""You are a Social Sentiment Analyst tracking {ticker}'s retail momentum for SHORT-TERM signals.
{get_date_awareness_section(current_date)}
@ -76,31 +63,4 @@ When aggregating sentiment, weight sources by credibility:
Date: {current_date} | Ticker: {ticker}"""
tool_names_str = ", ".join([tool.name for tool in tools])
full_system_message = (
f"{BASE_COLLABORATIVE_BOILERPLATE}\n\n{system_message}\n\n"
f"Context: {ticker} | Date: {current_date} | Tools: {tool_names_str}"
)
prompt = ChatPromptTemplate.from_messages(
[
("system", full_system_message),
MessagesPlaceholder(variable_name="messages"),
]
)
chain = prompt | llm.bind_tools(tools)
result = chain.invoke(state["messages"])
report = ""
if len(result.tool_calls) == 0:
report = result.content
return {
"messages": [result],
"sentiment_report": report,
}
return social_media_analyst_node
return create_analyst_node(llm, "social", "sentiment_report", _build_prompt)

View File

@ -1,5 +1,5 @@
import time
import json
from tradingagents.agents.utils.agent_utils import format_memory_context
from tradingagents.agents.utils.llm_utils import parse_llm_response
def create_research_manager(llm, memory):
@ -12,25 +12,10 @@ def create_research_manager(llm, memory):
investment_debate_state = state["investment_debate_state"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
if memory:
past_memories = memory.get_memories(curr_situation, n_matches=2)
else:
past_memories = []
past_memory_str = format_memory_context(memory, state)
if past_memories:
past_memory_str = "### Past Lessons Applied\\n**Reflections from Similar Situations:**\\n"
for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\\n\\n"
past_memory_str += "\\n\\n**How I'm Using These Lessons:**\\n"
past_memory_str += "- [Specific adjustment based on past mistake/success]\\n"
past_memory_str += "- [Impact on current conviction level]\\n"
else:
past_memory_str = "" # Don't include placeholder when no memories
prompt = f"""You are the Trade Judge for {state["company_of_interest"]}. Decide if there is a SHORT-TERM edge to trade this stock (1-2 weeks).
prompt = (
f"""You are the Trade Judge for {state["company_of_interest"]}. Decide if there is a SHORT-TERM edge to trade this stock (1-2 weeks).
## CORE RULES (CRITICAL)
- Evaluate this ticker IN ISOLATION (no portfolio sizing, no portfolio impact, no correlation talk).
@ -64,13 +49,19 @@ Choose the direction with the higher score. If tied, choose BUY.
### What Could Break It
- [2 bullets max: key risks]
""" + (f"""
"""
+ (
f"""
## PAST LESSONS
Here are reflections on past mistakes - apply these lessons:
{past_memory_str}
**Learning Check:** How are you adjusting based on these past situations?
""" if past_memory_str else "") + f"""
"""
if past_memory_str
else ""
)
+ f"""
---
**DEBATE TO JUDGE:**
@ -81,20 +72,22 @@ Technical: {market_research_report}
Sentiment: {sentiment_report}
News: {news_report}
Fundamentals: {fundamentals_report}"""
)
response = llm.invoke(prompt)
response_text = parse_llm_response(response.content)
new_investment_debate_state = {
"judge_decision": response.content,
"judge_decision": response_text,
"history": investment_debate_state.get("history", ""),
"bear_history": investment_debate_state.get("bear_history", ""),
"bull_history": investment_debate_state.get("bull_history", ""),
"current_response": response.content,
"current_response": response_text,
"count": investment_debate_state["count"],
}
return {
"investment_debate_state": new_investment_debate_state,
"investment_plan": response.content,
"investment_plan": response_text,
}
return research_manager_node

View File

@ -1,5 +1,5 @@
import time
import json
from tradingagents.agents.utils.agent_utils import format_memory_context
from tradingagents.agents.utils.llm_utils import parse_llm_response
def create_risk_manager(llm, memory):
@ -15,25 +15,10 @@ def create_risk_manager(llm, memory):
sentiment_report = state["sentiment_report"]
trader_plan = state.get("trader_investment_plan") or state.get("investment_plan", "")
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
if memory:
past_memories = memory.get_memories(curr_situation, n_matches=2)
else:
past_memories = []
past_memory_str = format_memory_context(memory, state)
if past_memories:
past_memory_str = "### Past Lessons Applied\\n**Reflections from Similar Situations:**\\n"
for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\\n\\n"
past_memory_str += "\\n\\n**How I'm Using These Lessons:**\\n"
past_memory_str += "- [Specific adjustment based on past mistake/success]\\n"
past_memory_str += "- [Impact on current conviction level]\\n"
else:
past_memory_str = "" # Don't include placeholder when no memories
prompt = f"""You are the Final Trade Decider for {company_name}. Make the final SHORT-TERM call (5-14 days) based on the risk debate and the provided data.
prompt = (
f"""You are the Final Trade Decider for {company_name}. Make the final SHORT-TERM call (5-14 days) based on the risk debate and the provided data.
## CORE RULES (CRITICAL)
- Evaluate this ticker IN ISOLATION (no portfolio sizing, no portfolio impact, no correlation analysis).
@ -66,13 +51,19 @@ If evidence is contradictory, still choose BUY or SELL and set conviction to Low
### Key Risks
- [2 bullets max: main ways it fails]
""" + (f"""
"""
+ (
f"""
## PAST LESSONS - CRITICAL
Review past mistakes to avoid repeating trade-setup errors:
{past_memory_str}
**Self-Check:** Have similar setups failed before? What was the key mistake (timing, catalyst read, or stop placement)?
""" if past_memory_str else "") + f"""
"""
if past_memory_str
else ""
)
+ f"""
---
**RISK DEBATE TO JUDGE:**
@ -84,11 +75,13 @@ Sentiment: {sentiment_report}
News: {news_report}
Fundamentals: {fundamentals_report}
"""
)
response = llm.invoke(prompt)
response_text = parse_llm_response(response.content)
new_risk_debate_state = {
"judge_decision": response.content,
"judge_decision": response_text,
"history": risk_debate_state["history"],
"risky_history": risk_debate_state["risky_history"],
"safe_history": risk_debate_state["safe_history"],
@ -102,7 +95,7 @@ Fundamentals: {fundamentals_report}
return {
"risk_debate_state": new_risk_debate_state,
"final_trade_decision": response.content,
"final_trade_decision": response_text,
}
return risk_manager_node

View File

@ -1,6 +1,5 @@
from langchain_core.messages import AIMessage
import time
import json
from tradingagents.agents.utils.agent_utils import format_memory_context
from tradingagents.agents.utils.llm_utils import create_and_invoke_chain, parse_llm_response
def create_bear_researcher(llm, memory):
@ -15,23 +14,7 @@ def create_bear_researcher(llm, memory):
news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
if memory:
past_memories = memory.get_memories(curr_situation, n_matches=2)
else:
past_memories = []
if past_memories:
past_memory_str = "### Past Lessons Applied\n**Reflections from Similar Situations:**\n"
for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\n\n"
past_memory_str += "\n\n**How I'm Using These Lessons:**\n"
past_memory_str += "- [Specific adjustment based on past mistake/success]\n"
past_memory_str += "- [Impact on current conviction level]\n"
else:
past_memory_str = ""
past_memory_str = format_memory_context(memory, state)
prompt = f"""You are the Bear Analyst making the case for SHORT-TERM SELL/AVOID (1-2 weeks).
@ -87,7 +70,8 @@ Fundamentals: {fundamentals_report}
**DEBATE:**
History: {history}
Last Bull: {current_response}
""" + (f"""
""" + (
f"""
## PAST LESSONS APPLICATION (Review BEFORE making arguments)
{past_memory_str}
@ -97,11 +81,16 @@ Last Bull: {current_response}
3. **How I'm Adjusting:** [Specific change to current argument based on lesson]
4. **Impact on Conviction:** [Increases/Decreases/No change to conviction level]
Apply lessons: How are you adjusting?""" if past_memory_str else "")
Apply lessons: How are you adjusting?"""
if past_memory_str
else ""
)
response = llm.invoke(prompt)
response = create_and_invoke_chain(llm, [], prompt, [])
argument = f"Bear Analyst: {response.content}"
response_text = parse_llm_response(response.content)
argument = f"Bear Analyst: {response_text}"
new_investment_debate_state = {
"history": history + "\n" + argument,

View File

@ -1,6 +1,5 @@
from langchain_core.messages import AIMessage
import time
import json
from tradingagents.agents.utils.agent_utils import format_memory_context
from tradingagents.agents.utils.llm_utils import create_and_invoke_chain, parse_llm_response
def create_bull_researcher(llm, memory):
@ -15,23 +14,7 @@ def create_bull_researcher(llm, memory):
news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
if memory:
past_memories = memory.get_memories(curr_situation, n_matches=2)
else:
past_memories = []
if past_memories:
past_memory_str = "### Past Lessons Applied\\n**Reflections from Similar Situations:**\\n"
for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\\n\\n"
past_memory_str += "\\n\\n**How I'm Using These Lessons:**\\n"
past_memory_str += "- [Specific adjustment based on past mistake/success]\\n"
past_memory_str += "- [Impact on current conviction level]\\n"
else:
past_memory_str = "" # Don't include placeholder when no memories
past_memory_str = format_memory_context(memory, state)
prompt = f"""You are the Bull Analyst making the case for a SHORT-TERM BUY (1-2 weeks).
@ -86,7 +69,8 @@ Fundamentals: {fundamentals_report}
**DEBATE:**
History: {history}
Last Bear: {current_response}
""" + (f"""
""" + (
f"""
## PAST LESSONS APPLICATION (Review BEFORE making arguments)
{past_memory_str}
@ -96,11 +80,16 @@ Last Bear: {current_response}
3. **How I'm Adjusting:** [Specific change to current argument based on lesson]
4. **Impact on Conviction:** [Increases/Decreases/No change to conviction level]
Apply past lessons: How are you adjusting based on similar situations?""" if past_memory_str else "")
Apply past lessons: How are you adjusting based on similar situations?"""
if past_memory_str
else ""
)
response = llm.invoke(prompt)
response = create_and_invoke_chain(llm, [], prompt, [])
argument = f"Bull Analyst: {response.content}"
response_text = parse_llm_response(response.content)
argument = f"Bull Analyst: {response_text}"
new_investment_debate_state = {
"history": history + "\n" + argument,

View File

@ -1,12 +1,11 @@
import time
import json
from tradingagents.agents.utils.agent_utils import update_risk_debate_state
from tradingagents.agents.utils.llm_utils import parse_llm_response
def create_risky_debator(llm):
def risky_node(state) -> dict:
risk_debate_state = state["risk_debate_state"]
history = risk_debate_state.get("history", "")
risky_history = risk_debate_state.get("risky_history", "")
current_safe_response = risk_debate_state.get("current_safe_response", "")
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
@ -67,23 +66,9 @@ State whether you agree with the Trader's direction (BUY/SELL) or flip it (no HO
**If no other arguments yet:** Present your strongest case for why this trade can work soon, using only the provided data."""
response = llm.invoke(prompt)
response_text = parse_llm_response(response.content)
argument = f"Risky Analyst: {response_text}"
argument = f"Risky Analyst: {response.content}"
new_risk_debate_state = {
"history": history + "\n" + argument,
"risky_history": risky_history + "\n" + argument,
"safe_history": risk_debate_state.get("safe_history", ""),
"neutral_history": risk_debate_state.get("neutral_history", ""),
"latest_speaker": "Risky",
"current_risky_response": argument,
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
"current_neutral_response": risk_debate_state.get(
"current_neutral_response", ""
),
"count": risk_debate_state["count"] + 1,
}
return {"risk_debate_state": new_risk_debate_state}
return {"risk_debate_state": update_risk_debate_state(risk_debate_state, argument, "Risky")}
return risky_node

View File

@ -1,13 +1,11 @@
from langchain_core.messages import AIMessage
import time
import json
from tradingagents.agents.utils.agent_utils import update_risk_debate_state
from tradingagents.agents.utils.llm_utils import parse_llm_response
def create_safe_debator(llm):
def safe_node(state) -> dict:
risk_debate_state = state["risk_debate_state"]
history = risk_debate_state.get("history", "")
safe_history = risk_debate_state.get("safe_history", "")
current_risky_response = risk_debate_state.get("current_risky_response", "")
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
@ -69,25 +67,9 @@ Choose BUY or SELL (no HOLD). If the setup looks poor, still pick the less-bad s
**If no other arguments yet:** Identify trade invalidation and the key risks using only the provided data."""
response = llm.invoke(prompt)
response_text = parse_llm_response(response.content)
argument = f"Safe Analyst: {response_text}"
argument = f"Safe Analyst: {response.content}"
new_risk_debate_state = {
"history": history + "\n" + argument,
"risky_history": risk_debate_state.get("risky_history", ""),
"safe_history": safe_history + "\n" + argument,
"neutral_history": risk_debate_state.get("neutral_history", ""),
"latest_speaker": "Safe",
"current_risky_response": risk_debate_state.get(
"current_risky_response", ""
),
"current_safe_response": argument,
"current_neutral_response": risk_debate_state.get(
"current_neutral_response", ""
),
"count": risk_debate_state["count"] + 1,
}
return {"risk_debate_state": new_risk_debate_state}
return {"risk_debate_state": update_risk_debate_state(risk_debate_state, argument, "Safe")}
return safe_node

View File

@ -1,12 +1,11 @@
import time
import json
from tradingagents.agents.utils.agent_utils import update_risk_debate_state
from tradingagents.agents.utils.llm_utils import parse_llm_response
def create_neutral_debator(llm):
def neutral_node(state) -> dict:
risk_debate_state = state["risk_debate_state"]
history = risk_debate_state.get("history", "")
neutral_history = risk_debate_state.get("neutral_history", "")
current_risky_response = risk_debate_state.get("current_risky_response", "")
current_safe_response = risk_debate_state.get("current_safe_response", "")
@ -66,23 +65,9 @@ Choose BUY or SELL (no HOLD). If the edge is unclear, pick the less-bad side and
**If no other arguments yet:** Provide a simple base-case view using only the provided data."""
response = llm.invoke(prompt)
response_text = parse_llm_response(response.content)
argument = f"Neutral Analyst: {response_text}"
argument = f"Neutral Analyst: {response.content}"
new_risk_debate_state = {
"history": history + "\n" + argument,
"risky_history": risk_debate_state.get("risky_history", ""),
"safe_history": risk_debate_state.get("safe_history", ""),
"neutral_history": neutral_history + "\n" + argument,
"latest_speaker": "Neutral",
"current_risky_response": risk_debate_state.get(
"current_risky_response", ""
),
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
"current_neutral_response": argument,
"count": risk_debate_state["count"] + 1,
}
return {"risk_debate_state": new_risk_debate_state}
return {"risk_debate_state": update_risk_debate_state(risk_debate_state, argument, "Neutral")}
return neutral_node

View File

@ -1,6 +1,7 @@
import functools
import time
import json
from tradingagents.agents.utils.agent_utils import format_memory_context
from tradingagents.agents.utils.llm_utils import parse_llm_response
def create_trader(llm, memory):
@ -12,22 +13,7 @@ def create_trader(llm, memory):
news_report = state["news_report"]
fundamentals_report = state["fundamentals_report"]
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
if memory:
past_memories = memory.get_memories(curr_situation, n_matches=2)
else:
past_memories = []
if past_memories:
past_memory_str = "### Past Lessons Applied\\n**Reflections from Similar Situations:**\\n"
for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\\n\\n"
past_memory_str += "\\n\\n**How I'm Using These Lessons:**\\n"
past_memory_str += "- [Specific adjustment based on past mistake/success]\\n"
past_memory_str += "- [Impact on current conviction level]\\n"
else:
past_memory_str = "" # Don't include placeholder when no memories
past_memory_str = format_memory_context(memory, state)
context = {
"role": "user",
@ -80,10 +66,11 @@ def create_trader(llm, memory):
]
result = llm.invoke(messages)
trader_plan = parse_llm_response(result.content)
return {
"messages": [result],
"trader_investment_plan": result.content,
"trader_investment_plan": trader_plan,
"sender": name,
}

View File

@ -1,20 +1,15 @@
from typing import Annotated, Sequence
from datetime import date, timedelta, datetime
from typing_extensions import TypedDict, Optional
from langchain_openai import ChatOpenAI
from typing import Annotated
from langgraph.graph import MessagesState
from typing_extensions import TypedDict
from tradingagents.agents import *
from langgraph.prebuilt import ToolNode
from langgraph.graph import END, StateGraph, START, MessagesState
# Researcher team state
class InvestDebateState(TypedDict):
bull_history: Annotated[
str, "Bullish Conversation history"
] # Bullish Conversation history
bear_history: Annotated[
str, "Bearish Conversation history"
] # Bullish Conversation history
bull_history: Annotated[str, "Bullish Conversation history"] # Bullish Conversation history
bear_history: Annotated[str, "Bearish Conversation history"] # Bullish Conversation history
history: Annotated[str, "Conversation history"] # Conversation history
current_response: Annotated[str, "Latest response"] # Last response
judge_decision: Annotated[str, "Final judge decision"] # Last response
@ -23,23 +18,13 @@ class InvestDebateState(TypedDict):
# Risk management team state
class RiskDebateState(TypedDict):
risky_history: Annotated[
str, "Risky Agent's Conversation history"
] # Conversation history
safe_history: Annotated[
str, "Safe Agent's Conversation history"
] # Conversation history
neutral_history: Annotated[
str, "Neutral Agent's Conversation history"
] # Conversation history
risky_history: Annotated[str, "Risky Agent's Conversation history"] # Conversation history
safe_history: Annotated[str, "Safe Agent's Conversation history"] # Conversation history
neutral_history: Annotated[str, "Neutral Agent's Conversation history"] # Conversation history
history: Annotated[str, "Conversation history"] # Conversation history
latest_speaker: Annotated[str, "Analyst that spoke last"]
current_risky_response: Annotated[
str, "Latest response by the risky analyst"
] # Last response
current_safe_response: Annotated[
str, "Latest response by the safe analyst"
] # Last response
current_risky_response: Annotated[str, "Latest response by the risky analyst"] # Last response
current_safe_response: Annotated[str, "Latest response by the safe analyst"] # Last response
current_neutral_response: Annotated[
str, "Latest response by the neutral analyst"
] # Last response
@ -56,9 +41,7 @@ class AgentState(MessagesState):
# research step
market_report: Annotated[str, "Report from the Market Analyst"]
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
news_report: Annotated[
str, "Report from the News Researcher of current world affairs"
]
news_report: Annotated[str, "Report from the News Researcher of current world affairs"]
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
# researcher team discussion step
@ -70,9 +53,7 @@ class AgentState(MessagesState):
trader_investment_plan: Annotated[str, "Plan generated by the Trader"]
# risk management team discussion step
risk_debate_state: Annotated[
RiskDebateState, "Current state of the debate on evaluating risk"
]
risk_debate_state: Annotated[RiskDebateState, "Current state of the debate on evaluating risk"]
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
@ -84,5 +65,6 @@ class DiscoveryState(TypedDict):
opportunities: Annotated[list[dict], "List of final opportunities with rationale"]
final_ranking: Annotated[str, "Final ranking from LLM"]
status: Annotated[str, "Current status of discovery"]
tool_logs: Annotated[list[dict], "Detailed logs of all tool calls across all nodes (scanner, filter, deep_dive)"]
tool_logs: Annotated[
list[dict], "Detailed logs of all tool calls across all nodes (scanner, filter, deep_dive)"
]

View File

@ -1,7 +1,13 @@
from typing import Any, Callable, Dict, List
from langchain_core.messages import HumanMessage, RemoveMessage
# Import all tools from the new registry-based system
from tradingagents.tools.generator import ALL_TOOLS
from tradingagents.agents.utils.llm_utils import (
create_and_invoke_chain,
parse_llm_response,
)
from tradingagents.agents.utils.prompt_templates import format_analyst_prompt
from tradingagents.tools.generator import ALL_TOOLS, get_agent_tools
# Re-export tools for backward compatibility
get_stock_data = ALL_TOOLS["get_stock_data"]
@ -20,20 +26,112 @@ get_insider_transactions = ALL_TOOLS["get_insider_transactions"]
# Legacy alias for backward compatibility
validate_ticker_tool = validate_ticker
def create_msg_delete():
def delete_messages(state):
"""Clear messages and add placeholder for Anthropic compatibility"""
messages = state["messages"]
# Remove all messages
removal_operations = [RemoveMessage(id=m.id) for m in messages]
# Add a minimal placeholder message
placeholder = HumanMessage(content="Continue")
return {"messages": removal_operations + [placeholder]}
return delete_messages
def format_memory_context(memory: Any, state: Dict[str, Any], n_matches: int = 2) -> str:
"""Fetch and format past memories into a prompt section.
Returns the formatted memory string, or "" if no memories available.
Identical logic previously duplicated across 5 agent files.
"""
reports = (
state["market_report"],
state["sentiment_report"],
state["news_report"],
state["fundamentals_report"],
)
curr_situation = "\n\n".join(reports)
if not memory:
return ""
past_memories = memory.get_memories(curr_situation, n_matches=n_matches)
if not past_memories:
return ""
past_memory_str = "### Past Lessons Applied\\n**Reflections from Similar Situations:**\\n"
for i, rec in enumerate(past_memories, 1):
past_memory_str += rec["recommendation"] + "\\n\\n"
past_memory_str += "\\n\\n**How I'm Using These Lessons:**\\n"
past_memory_str += "- [Specific adjustment based on past mistake/success]\\n"
past_memory_str += "- [Impact on current conviction level]\\n"
return past_memory_str
def update_risk_debate_state(
debate_state: Dict[str, Any], argument: str, role: str
) -> Dict[str, Any]:
"""Build updated risk debate state after a debator speaks.
Args:
debate_state: Current risk_debate_state dict.
argument: The formatted argument string (e.g. "Safe Analyst: ...").
role: One of "Safe", "Risky", "Neutral".
"""
role_key = role.lower() # "safe", "risky", "neutral"
new_state = {
"history": debate_state.get("history", "") + "\n" + argument,
"risky_history": debate_state.get("risky_history", ""),
"safe_history": debate_state.get("safe_history", ""),
"neutral_history": debate_state.get("neutral_history", ""),
"latest_speaker": role,
"current_risky_response": debate_state.get("current_risky_response", ""),
"current_safe_response": debate_state.get("current_safe_response", ""),
"current_neutral_response": debate_state.get("current_neutral_response", ""),
"count": debate_state["count"] + 1,
}
# Append to the speaker's own history and set their current response
new_state[f"{role_key}_history"] = (
debate_state.get(f"{role_key}_history", "") + "\n" + argument
)
new_state[f"current_{role_key}_response"] = argument
return new_state
def create_analyst_node(
llm: Any,
tool_group: str,
output_key: str,
prompt_builder: Callable[[str, str], str],
) -> Callable:
"""Factory for analyst graph nodes.
Args:
llm: The LLM to use.
tool_group: Tool group name for ``get_agent_tools`` (e.g. "fundamentals").
output_key: State key for the report (e.g. "fundamentals_report").
prompt_builder: ``(ticker, current_date) -> system_message`` callable.
"""
def analyst_node(state: Dict[str, Any]) -> Dict[str, Any]:
ticker = state["company_of_interest"]
current_date = state["trade_date"]
tools = get_agent_tools(tool_group)
system_message = prompt_builder(ticker, current_date)
tool_names_str = ", ".join(tool.name for tool in tools)
full_message = format_analyst_prompt(system_message, current_date, ticker, tool_names_str)
result = create_and_invoke_chain(llm, tools, full_message, state["messages"])
report = ""
if len(result.tool_calls) == 0:
report = parse_llm_response(result.content)
return {"messages": [result], output_key: report}
return analyst_node

View File

@ -9,15 +9,16 @@ This module creates agent memories from historical stock data by:
5. Storing memories in ChromaDB for future retrieval
"""
import os
import re
import json
import yfinance as yf
import pandas as pd
from datetime import datetime, timedelta
from typing import List, Dict, Tuple, Optional, Any
from tradingagents.tools.executor import execute_tool
from typing import Any, Dict, List, Optional, Tuple
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.dataflows.y_finance import get_ticker_history
from tradingagents.tools.executor import execute_tool
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
class HistoricalMemoryBuilder:
@ -35,7 +36,7 @@ class HistoricalMemoryBuilder:
"bear": 0,
"trader": 0,
"invest_judge": 0,
"risk_manager": 0
"risk_manager": 0,
}
def get_tickers_from_alpha_vantage(self, limit: int = 20) -> List[str]:
@ -48,7 +49,7 @@ class HistoricalMemoryBuilder:
Returns:
List of ticker symbols from top gainers and losers
"""
print(f"\n🔍 Fetching top movers from Alpha Vantage...")
logger.info("🔍 Fetching top movers from Alpha Vantage...")
try:
# Use execute_tool to call the alpha vantage function
@ -57,13 +58,13 @@ class HistoricalMemoryBuilder:
# Parse the markdown table response to extract tickers
tickers = set()
lines = response.split('\n')
lines = response.split("\n")
for line in lines:
# Look for table rows with ticker data
if '|' in line and not line.strip().startswith('|---'):
parts = [p.strip() for p in line.split('|')]
if "|" in line and not line.strip().startswith("|---"):
parts = [p.strip() for p in line.split("|")]
# Table format: | Ticker | Price | Change % | Volume |
if len(parts) >= 2 and parts[1] and parts[1] not in ['Ticker', '']:
if len(parts) >= 2 and parts[1] and parts[1] not in ["Ticker", ""]:
ticker = parts[1].strip()
# Filter out warrants, units, and problematic tickers
@ -71,14 +72,16 @@ class HistoricalMemoryBuilder:
tickers.add(ticker)
ticker_list = sorted(list(tickers))
print(f" ✅ Found {len(ticker_list)} unique tickers from Alpha Vantage")
print(f" Tickers: {', '.join(ticker_list[:10])}{'...' if len(ticker_list) > 10 else ''}")
logger.info(f"✅ Found {len(ticker_list)} unique tickers from Alpha Vantage")
logger.debug(
f"Tickers: {', '.join(ticker_list[:10])}{'...' if len(ticker_list) > 10 else ''}"
)
return ticker_list
except Exception as e:
print(f" ⚠️ Error fetching from Alpha Vantage: {e}")
print(f" Falling back to empty list")
logger.warning(f"⚠️ Error fetching from Alpha Vantage: {e}")
logger.warning("Falling back to empty list")
return []
def _is_valid_ticker(self, ticker: str) -> bool:
@ -102,23 +105,23 @@ class HistoricalMemoryBuilder:
return False
# Must be uppercase letters and numbers only
if not re.match(r'^[A-Z]{1,5}$', ticker):
if not re.match(r"^[A-Z]{1,5}$", ticker):
return False
# Filter out warrants (W, WW, WS suffix)
if ticker.endswith('W') or ticker.endswith('WW') or ticker.endswith('WS'):
if ticker.endswith("W") or ticker.endswith("WW") or ticker.endswith("WS"):
return False
# Filter out units
if ticker.endswith('U'):
if ticker.endswith("U"):
return False
# Filter out rights
if ticker.endswith('R') and len(ticker) > 1:
if ticker.endswith("R") and len(ticker) > 1:
return False
# Filter out other suffixes that indicate derivatives
if ticker.endswith('Z'): # Often used for special situations
if ticker.endswith("Z"): # Often used for special situations
return False
return True
@ -129,7 +132,7 @@ class HistoricalMemoryBuilder:
start_date: str,
end_date: str,
min_move_pct: float = 15.0,
window_days: int = 5
window_days: int = 5,
) -> List[Dict[str, Any]]:
"""
Find stocks that had significant moves (>15% in 5 days).
@ -153,67 +156,66 @@ class HistoricalMemoryBuilder:
"""
high_movers = []
print(f"\n🔍 Scanning for high movers ({min_move_pct}%+ in {window_days} days)")
print(f" Period: {start_date} to {end_date}")
print(f" Tickers: {len(tickers)}\n")
logger.info(f"🔍 Scanning for high movers ({min_move_pct}%+ in {window_days} days)")
logger.info(f"Period: {start_date} to {end_date}")
logger.info(f"Tickers: {len(tickers)}")
for ticker in tickers:
try:
print(f" Scanning {ticker}...", end=" ")
logger.info(f"Scanning {ticker}...")
# Download historical data using yfinance
stock = yf.Ticker(ticker)
df = stock.history(start=start_date, end=end_date)
df = get_ticker_history(ticker, start=start_date, end=end_date)
if df.empty:
print("No data")
logger.debug(f"{ticker}: No data")
continue
# Calculate rolling returns over window_days
df['rolling_return'] = df['Close'].pct_change(periods=window_days) * 100
df["rolling_return"] = df["Close"].pct_change(periods=window_days) * 100
# Find periods with moves >= min_move_pct
significant_moves = df[abs(df['rolling_return']) >= min_move_pct]
significant_moves = df[abs(df["rolling_return"]) >= min_move_pct]
if not significant_moves.empty:
for idx, row in significant_moves.iterrows():
# Get the start date (window_days before this date)
move_end_date = idx.strftime('%Y-%m-%d')
move_start_date = (idx - timedelta(days=window_days)).strftime('%Y-%m-%d')
move_end_date = idx.strftime("%Y-%m-%d")
move_start_date = (idx - timedelta(days=window_days)).strftime("%Y-%m-%d")
# Get prices
try:
start_price = df.loc[df.index >= move_start_date, 'Close'].iloc[0]
end_price = row['Close']
move_pct = row['rolling_return']
start_price = df.loc[df.index >= move_start_date, "Close"].iloc[0]
end_price = row["Close"]
move_pct = row["rolling_return"]
high_movers.append({
'ticker': ticker,
'move_start_date': move_start_date,
'move_end_date': move_end_date,
'move_pct': move_pct,
'direction': 'up' if move_pct > 0 else 'down',
'start_price': start_price,
'end_price': end_price
})
high_movers.append(
{
"ticker": ticker,
"move_start_date": move_start_date,
"move_end_date": move_end_date,
"move_pct": move_pct,
"direction": "up" if move_pct > 0 else "down",
"start_price": start_price,
"end_price": end_price,
}
)
except (IndexError, KeyError):
continue
print(f"Found {len([m for m in high_movers if m['ticker'] == ticker])} moves")
logger.info(f"Found {len([m for m in high_movers if m['ticker'] == ticker])} moves for {ticker}")
else:
print("No significant moves")
logger.debug(f"{ticker}: No significant moves")
except Exception as e:
print(f"Error: {e}")
logger.error(f"Error scanning {ticker}: {e}")
continue
print(f"\n✅ Total high movers found: {len(high_movers)}\n")
logger.info(f"✅ Total high movers found: {len(high_movers)}")
return high_movers
def run_retrospective_analysis(
self,
ticker: str,
analysis_date: str
self, ticker: str, analysis_date: str
) -> Optional[Dict[str, Any]]:
"""
Run the trading graph analysis for a ticker at a specific historical date.
@ -238,47 +240,48 @@ class HistoricalMemoryBuilder:
# Import here to avoid circular imports
from tradingagents.graph.trading_graph import TradingAgentsGraph
print(f" Running analysis for {ticker} on {analysis_date}...")
logger.info(f"Running analysis for {ticker} on {analysis_date}...")
# Create trading graph instance
# Use fewer analysts to reduce token usage
graph = TradingAgentsGraph(
selected_analysts=["market", "fundamentals"], # Skip social/news to reduce tokens
config=self.config,
debug=False
debug=False,
)
# Run the analysis (returns tuple: final_state, processed_signal)
final_state, _ = graph.propagate(ticker, analysis_date)
# Extract reports and decisions (with type safety)
def safe_get_str(d, key, default=''):
def safe_get_str(d, key, default=""):
"""Safely extract string from state, handling lists or other types."""
value = d.get(key, default)
if isinstance(value, list):
# If it's a list, try to extract text from messages
return ' '.join(str(item) for item in value)
return " ".join(str(item) for item in value)
return str(value) if value else default
# Extract reports and decisions
analysis_data = {
'market_report': safe_get_str(final_state, 'market_report'),
'sentiment_report': safe_get_str(final_state, 'sentiment_report'),
'news_report': safe_get_str(final_state, 'news_report'),
'fundamentals_report': safe_get_str(final_state, 'fundamentals_report'),
'investment_plan': safe_get_str(final_state, 'investment_plan'),
'final_decision': safe_get_str(final_state, 'final_trade_decision'),
"market_report": safe_get_str(final_state, "market_report"),
"sentiment_report": safe_get_str(final_state, "sentiment_report"),
"news_report": safe_get_str(final_state, "news_report"),
"fundamentals_report": safe_get_str(final_state, "fundamentals_report"),
"investment_plan": safe_get_str(final_state, "investment_plan"),
"final_decision": safe_get_str(final_state, "final_trade_decision"),
}
# Extract structured signals from reports
analysis_data['structured_signals'] = self.extract_structured_signals(analysis_data)
analysis_data["structured_signals"] = self.extract_structured_signals(analysis_data)
return analysis_data
except Exception as e:
print(f" Error running analysis: {e}")
logger.error(f"Error running analysis: {e}")
import traceback
print(f" Traceback: {traceback.format_exc()}")
logger.debug(f"Traceback: {traceback.format_exc()}")
return None
def extract_structured_signals(self, reports: Dict[str, str]) -> Dict[str, Any]:
@ -300,63 +303,101 @@ class HistoricalMemoryBuilder:
"""
signals = {}
market_report = reports.get('market_report', '')
sentiment_report = reports.get('sentiment_report', '')
news_report = reports.get('news_report', '')
fundamentals_report = reports.get('fundamentals_report', '')
market_report = reports.get("market_report", "")
sentiment_report = reports.get("sentiment_report", "")
news_report = reports.get("news_report", "")
fundamentals_report = reports.get("fundamentals_report", "")
# Extract volume signals
signals['unusual_volume'] = bool(
re.search(r'(unusual volume|volume spike|high volume|increased volume)', market_report, re.IGNORECASE)
signals["unusual_volume"] = bool(
re.search(
r"(unusual volume|volume spike|high volume|increased volume)",
market_report,
re.IGNORECASE,
)
)
# Extract sentiment
if re.search(r'(bullish|positive outlook|strong buy|buy)', sentiment_report + news_report, re.IGNORECASE):
signals['analyst_sentiment'] = 'bullish'
elif re.search(r'(bearish|negative outlook|strong sell|sell)', sentiment_report + news_report, re.IGNORECASE):
signals['analyst_sentiment'] = 'bearish'
if re.search(
r"(bullish|positive outlook|strong buy|buy)",
sentiment_report + news_report,
re.IGNORECASE,
):
signals["analyst_sentiment"] = "bullish"
elif re.search(
r"(bearish|negative outlook|strong sell|sell)",
sentiment_report + news_report,
re.IGNORECASE,
):
signals["analyst_sentiment"] = "bearish"
else:
signals['analyst_sentiment'] = 'neutral'
signals["analyst_sentiment"] = "neutral"
# Extract news sentiment
if re.search(r'(positive|good news|beat expectations|upgrade|growth)', news_report, re.IGNORECASE):
signals['news_sentiment'] = 'positive'
elif re.search(r'(negative|bad news|miss expectations|downgrade|decline)', news_report, re.IGNORECASE):
signals['news_sentiment'] = 'negative'
if re.search(
r"(positive|good news|beat expectations|upgrade|growth)", news_report, re.IGNORECASE
):
signals["news_sentiment"] = "positive"
elif re.search(
r"(negative|bad news|miss expectations|downgrade|decline)", news_report, re.IGNORECASE
):
signals["news_sentiment"] = "negative"
else:
signals['news_sentiment'] = 'neutral'
signals["news_sentiment"] = "neutral"
# Extract short interest
if re.search(r'(high short interest|heavily shorted|short squeeze)', market_report + news_report, re.IGNORECASE):
signals['short_interest'] = 'high'
elif re.search(r'(low short interest|minimal short)', market_report, re.IGNORECASE):
signals['short_interest'] = 'low'
if re.search(
r"(high short interest|heavily shorted|short squeeze)",
market_report + news_report,
re.IGNORECASE,
):
signals["short_interest"] = "high"
elif re.search(r"(low short interest|minimal short)", market_report, re.IGNORECASE):
signals["short_interest"] = "low"
else:
signals['short_interest'] = 'medium'
signals["short_interest"] = "medium"
# Extract insider activity
if re.search(r'(insider buying|executive purchased|insider purchases)', news_report + fundamentals_report, re.IGNORECASE):
signals['insider_activity'] = 'buying'
elif re.search(r'(insider selling|executive sold|insider sales)', news_report + fundamentals_report, re.IGNORECASE):
signals['insider_activity'] = 'selling'
if re.search(
r"(insider buying|executive purchased|insider purchases)",
news_report + fundamentals_report,
re.IGNORECASE,
):
signals["insider_activity"] = "buying"
elif re.search(
r"(insider selling|executive sold|insider sales)",
news_report + fundamentals_report,
re.IGNORECASE,
):
signals["insider_activity"] = "selling"
else:
signals['insider_activity'] = 'none'
signals["insider_activity"] = "none"
# Extract price trend
if re.search(r'(uptrend|bullish trend|rising|moving higher|higher highs)', market_report, re.IGNORECASE):
signals['price_trend'] = 'uptrend'
elif re.search(r'(downtrend|bearish trend|falling|moving lower|lower lows)', market_report, re.IGNORECASE):
signals['price_trend'] = 'downtrend'
if re.search(
r"(uptrend|bullish trend|rising|moving higher|higher highs)",
market_report,
re.IGNORECASE,
):
signals["price_trend"] = "uptrend"
elif re.search(
r"(downtrend|bearish trend|falling|moving lower|lower lows)",
market_report,
re.IGNORECASE,
):
signals["price_trend"] = "downtrend"
else:
signals['price_trend'] = 'sideways'
signals["price_trend"] = "sideways"
# Extract volatility
if re.search(r'(high volatility|volatile|wild swings|sharp movements)', market_report, re.IGNORECASE):
signals['volatility'] = 'high'
elif re.search(r'(low volatility|stable|steady)', market_report, re.IGNORECASE):
signals['volatility'] = 'low'
if re.search(
r"(high volatility|volatile|wild swings|sharp movements)", market_report, re.IGNORECASE
):
signals["volatility"] = "high"
elif re.search(r"(low volatility|stable|steady)", market_report, re.IGNORECASE):
signals["volatility"] = "low"
else:
signals['volatility'] = 'medium'
signals["volatility"] = "medium"
return signals
@ -368,7 +409,7 @@ class HistoricalMemoryBuilder:
min_move_pct: float = 15.0,
analysis_windows: List[int] = [7, 30],
max_samples: int = 50,
sample_strategy: str = "diverse"
sample_strategy: str = "diverse",
) -> Dict[str, FinancialSituationMemory]:
"""
Build memories by finding high movers and running retrospective analyses.
@ -391,25 +432,24 @@ class HistoricalMemoryBuilder:
Returns:
Dictionary of populated memory instances for each agent type
"""
print("=" * 70)
print("🏗️ BUILDING MEMORIES FROM HIGH MOVERS")
print("=" * 70)
logger.info("=" * 70)
logger.info("🏗️ BUILDING MEMORIES FROM HIGH MOVERS")
logger.info("=" * 70)
# Step 1: Find high movers
high_movers = self.find_high_movers(tickers, start_date, end_date, min_move_pct)
if not high_movers:
print("⚠️ No high movers found. Try a different date range or lower threshold.")
logger.warning("⚠️ No high movers found. Try a different date range or lower threshold.")
return {}
# Step 1.5: Sample/filter high movers based on strategy
sampled_movers = self._sample_high_movers(high_movers, max_samples, sample_strategy)
print(f"\n📊 Sampling Strategy: {sample_strategy}")
print(f" Total high movers found: {len(high_movers)}")
print(f" Samples to analyze: {len(sampled_movers)}")
print(f" Estimated runtime: ~{len(sampled_movers) * len(analysis_windows) * 2} minutes")
print()
logger.info(f"📊 Sampling Strategy: {sample_strategy}")
logger.info(f"Total high movers found: {len(high_movers)}")
logger.info(f"Samples to analyze: {len(sampled_movers)}")
logger.info(f"Estimated runtime: ~{len(sampled_movers) * len(analysis_windows) * 2} minutes")
# Initialize memory stores
agent_memories = {
@ -417,35 +457,35 @@ class HistoricalMemoryBuilder:
"bear": FinancialSituationMemory("bear_memory", self.config),
"trader": FinancialSituationMemory("trader_memory", self.config),
"invest_judge": FinancialSituationMemory("invest_judge_memory", self.config),
"risk_manager": FinancialSituationMemory("risk_manager_memory", self.config)
"risk_manager": FinancialSituationMemory("risk_manager_memory", self.config),
}
# Step 2: For each high mover, run retrospective analyses
print("\n📊 Running retrospective analyses...\n")
logger.info("📊 Running retrospective analyses...")
for idx, mover in enumerate(sampled_movers, 1):
ticker = mover['ticker']
move_pct = mover['move_pct']
direction = mover['direction']
move_start_date = mover['move_start_date']
ticker = mover["ticker"]
move_pct = mover["move_pct"]
direction = mover["direction"]
move_start_date = mover["move_start_date"]
print(f" [{idx}/{len(sampled_movers)}] {ticker}: {move_pct:+.1f}% {direction}")
logger.info(f"[{idx}/{len(sampled_movers)}] {ticker}: {move_pct:+.1f}% {direction}")
# Run analyses at different time windows before the move
for days_before in analysis_windows:
# Calculate analysis date
try:
analysis_date = (
datetime.strptime(move_start_date, '%Y-%m-%d') - timedelta(days=days_before)
).strftime('%Y-%m-%d')
datetime.strptime(move_start_date, "%Y-%m-%d") - timedelta(days=days_before)
).strftime("%Y-%m-%d")
print(f" Analyzing T-{days_before} days ({analysis_date})...")
logger.info(f"Analyzing T-{days_before} days ({analysis_date})...")
# Run trading graph analysis
analysis = self.run_retrospective_analysis(ticker, analysis_date)
if not analysis:
print(f" ⚠️ Analysis failed, skipping...")
logger.warning("⚠️ Analysis failed, skipping...")
continue
# Create combined situation text
@ -469,8 +509,7 @@ class HistoricalMemoryBuilder:
# Extract agent recommendation from investment plan and final decision
agent_recommendation = self._extract_recommendation(
analysis.get('investment_plan', ''),
analysis.get('final_decision', '')
analysis.get("investment_plan", ""), analysis.get("final_decision", "")
)
# Determine if agent was correct
@ -478,18 +517,22 @@ class HistoricalMemoryBuilder:
# Create metadata
metadata = {
'ticker': ticker,
'analysis_date': analysis_date,
'days_before_move': days_before,
'move_pct': abs(move_pct),
'move_direction': direction,
'agent_recommendation': agent_recommendation,
'was_correct': was_correct,
'structured_signals': analysis['structured_signals']
"ticker": ticker,
"analysis_date": analysis_date,
"days_before_move": days_before,
"move_pct": abs(move_pct),
"move_direction": direction,
"agent_recommendation": agent_recommendation,
"was_correct": was_correct,
"structured_signals": analysis["structured_signals"],
}
# Create recommendation text
lesson_text = f"This signal combination is reliable for predicting {direction} moves." if was_correct else "This signal combination can be misleading. Need to consider other factors."
lesson_text = (
f"This signal combination is reliable for predicting {direction} moves."
if was_correct
else "This signal combination can be misleading. Need to consider other factors."
)
recommendation_text = f"""
Agent Decision: {agent_recommendation}
@ -507,38 +550,40 @@ Lesson: {lesson_text}
# Store in all agent memories
for agent_type, memory in agent_memories.items():
memory.add_situations_with_metadata([
(situation_text, recommendation_text, metadata)
])
memory.add_situations_with_metadata(
[(situation_text, recommendation_text, metadata)]
)
self.memories_created[agent_type] = self.memories_created.get(agent_type, 0) + 1
print(f" ✅ Memory created: {agent_recommendation} -> {direction} ({was_correct})")
logger.info(
f"✅ Memory created: {agent_recommendation} -> {direction} ({was_correct})"
)
except Exception as e:
print(f" ⚠️ Error: {e}")
logger.warning(f"⚠️ Error: {e}")
continue
# Print summary
print("\n" + "=" * 70)
print("📊 MEMORY CREATION SUMMARY")
print("=" * 70)
print(f" High movers analyzed: {len(sampled_movers)}")
print(f" Analysis windows: {analysis_windows} days before move")
# Log summary
logger.info("=" * 70)
logger.info("📊 MEMORY CREATION SUMMARY")
logger.info("=" * 70)
logger.info(f" High movers analyzed: {len(sampled_movers)}")
logger.info(f" Analysis windows: {analysis_windows} days before move")
for agent_type, count in self.memories_created.items():
print(f" {agent_type.ljust(15)}: {count} memories")
logger.info(f" {agent_type.ljust(15)}: {count} memories")
# Print statistics
print("\n📈 MEMORY BANK STATISTICS")
print("=" * 70)
# Log statistics
logger.info("\n📈 MEMORY BANK STATISTICS")
logger.info("=" * 70)
for agent_type, memory in agent_memories.items():
stats = memory.get_statistics()
print(f"\n {agent_type.upper()}:")
print(f" Total memories: {stats['total_memories']}")
print(f" Accuracy rate: {stats['accuracy_rate']:.1f}%")
print(f" Avg move: {stats['avg_move_pct']:.1f}%")
logger.info(f"\n {agent_type.upper()}:")
logger.info(f" Total memories: {stats['total_memories']}")
logger.info(f" Accuracy rate: {stats['accuracy_rate']:.1f}%")
logger.info(f" Avg move: {stats['avg_move_pct']:.1f}%")
print("=" * 70 + "\n")
logger.info("=" * 70)
return agent_memories
@ -551,11 +596,13 @@ Lesson: {lesson_text}
combined_text = (investment_plan + " " + final_decision).lower()
# Check for clear buy/sell/hold signals
if re.search(r'\b(strong buy|buy|long position|bullish|recommend buying)\b', combined_text):
if re.search(r"\b(strong buy|buy|long position|bullish|recommend buying)\b", combined_text):
return "buy"
elif re.search(r'\b(strong sell|sell|short position|bearish|recommend selling)\b', combined_text):
elif re.search(
r"\b(strong sell|sell|short position|bearish|recommend selling)\b", combined_text
):
return "sell"
elif re.search(r'\b(hold|neutral|wait|avoid)\b', combined_text):
elif re.search(r"\b(hold|neutral|wait|avoid)\b", combined_text):
return "hold"
else:
return "unclear"
@ -589,10 +636,7 @@ Lesson: {lesson_text}
return "\n".join(lines)
def _sample_high_movers(
self,
high_movers: List[Dict[str, Any]],
max_samples: int,
strategy: str
self, high_movers: List[Dict[str, Any]], max_samples: int, strategy: str
) -> List[Dict[str, Any]]:
"""
Sample high movers based on strategy to reduce analysis time.
@ -612,12 +656,12 @@ Lesson: {lesson_text}
if strategy == "diverse":
# Get balanced mix of up/down moves across different magnitudes
up_moves = [m for m in high_movers if m['direction'] == 'up']
down_moves = [m for m in high_movers if m['direction'] == 'down']
up_moves = [m for m in high_movers if m["direction"] == "up"]
down_moves = [m for m in high_movers if m["direction"] == "down"]
# Sort each by magnitude
up_moves.sort(key=lambda x: abs(x['move_pct']), reverse=True)
down_moves.sort(key=lambda x: abs(x['move_pct']), reverse=True)
up_moves.sort(key=lambda x: abs(x["move_pct"]), reverse=True)
down_moves.sort(key=lambda x: abs(x["move_pct"]), reverse=True)
# Take half from each direction (or proportional if imbalanced)
up_count = min(len(up_moves), max_samples // 2)
@ -637,14 +681,14 @@ Lesson: {lesson_text}
# Divide into 3 buckets by magnitude
bucket_size = len(moves) // 3
large = moves[:bucket_size]
medium = moves[bucket_size:bucket_size*2]
small = moves[bucket_size*2:]
medium = moves[bucket_size : bucket_size * 2]
small = moves[bucket_size * 2 :]
# Sample proportionally from each bucket
samples = []
samples.extend(large[:count // 3])
samples.extend(medium[:count // 3])
samples.extend(small[:count - (2 * (count // 3))])
samples.extend(large[: count // 3])
samples.extend(medium[: count // 3])
samples.extend(small[: count - (2 * (count // 3))])
return samples
sampled = []
@ -655,12 +699,12 @@ Lesson: {lesson_text}
elif strategy == "largest":
# Take the largest absolute moves
sorted_movers = sorted(high_movers, key=lambda x: abs(x['move_pct']), reverse=True)
sorted_movers = sorted(high_movers, key=lambda x: abs(x["move_pct"]), reverse=True)
return sorted_movers[:max_samples]
elif strategy == "recent":
# Take the most recent moves
sorted_movers = sorted(high_movers, key=lambda x: x['move_end_date'], reverse=True)
sorted_movers = sorted(high_movers, key=lambda x: x["move_end_date"], reverse=True)
return sorted_movers[:max_samples]
elif strategy == "random":
@ -687,7 +731,9 @@ Lesson: {lesson_text}
# Get technical/price data (what Market Analyst sees)
stock_data = execute_tool("get_stock_data", symbol=ticker, start_date=date)
indicators = execute_tool("get_indicators", symbol=ticker, curr_date=date)
data["market_report"] = f"Stock Data:\n{stock_data}\n\nTechnical Indicators:\n{indicators}"
data["market_report"] = (
f"Stock Data:\n{stock_data}\n\nTechnical Indicators:\n{indicators}"
)
except Exception as e:
data["market_report"] = f"Error fetching market data: {e}"
@ -700,7 +746,9 @@ Lesson: {lesson_text}
try:
# Get sentiment (what Social Analyst sees)
sentiment = execute_tool("get_reddit_discussions", symbol=ticker, from_date=date, to_date=date)
sentiment = execute_tool(
"get_reddit_discussions", symbol=ticker, from_date=date, to_date=date
)
data["sentiment_report"] = sentiment
except Exception as e:
data["sentiment_report"] = f"Error fetching sentiment: {e}"
@ -727,14 +775,19 @@ Lesson: {lesson_text}
"""
try:
# Get stock prices for both dates
start_data = execute_tool("get_stock_data", symbol=ticker, start_date=start_date, end_date=start_date)
end_data = execute_tool("get_stock_data", symbol=ticker, start_date=end_date, end_date=end_date)
start_data = execute_tool(
"get_stock_data", symbol=ticker, start_date=start_date, end_date=start_date
)
end_data = execute_tool(
"get_stock_data", symbol=ticker, start_date=end_date, end_date=end_date
)
# Parse prices (this is simplified - you'd need to parse the actual response)
# Assuming response has close price - adjust based on actual API response
import re
start_match = re.search(r'Close[:\s]+\$?([\d.]+)', str(start_data))
end_match = re.search(r'Close[:\s]+\$?([\d.]+)', str(end_data))
start_match = re.search(r"Close[:\s]+\$?([\d.]+)", str(start_data))
end_match = re.search(r"Close[:\s]+\$?([\d.]+)", str(end_data))
if start_match and end_match:
start_price = float(start_match.group(1))
@ -743,10 +796,12 @@ Lesson: {lesson_text}
return None
except Exception as e:
print(f"Error calculating returns: {e}")
logger.error(f"Error calculating returns: {e}")
return None
def _create_bull_researcher_memory(self, situation: str, returns: float, ticker: str, date: str) -> str:
def _create_bull_researcher_memory(
self, situation: str, returns: float, ticker: str, date: str
) -> str:
"""Create memory for bull researcher based on outcome.
Returns lesson learned from bullish perspective.
@ -780,7 +835,9 @@ Stock moved {returns:.2f}%, indicating mixed signals.
Lesson: This pattern of indicators doesn't provide strong directional conviction. Look for clearer signals before making strong bullish arguments.
"""
def _create_bear_researcher_memory(self, situation: str, returns: float, ticker: str, date: str) -> str:
def _create_bear_researcher_memory(
self, situation: str, returns: float, ticker: str, date: str
) -> str:
"""Create memory for bear researcher based on outcome."""
if returns < -5:
return f"""SUCCESSFUL BEARISH ANALYSIS for {ticker} on {date}:
@ -842,7 +899,9 @@ Trading lesson:
Recommendation: Pattern recognition suggests {action} in similar future scenarios.
"""
def _create_invest_judge_memory(self, situation: str, returns: float, ticker: str, date: str) -> str:
def _create_invest_judge_memory(
self, situation: str, returns: float, ticker: str, date: str
) -> str:
"""Create memory for investment judge/research manager."""
if returns > 5:
verdict = "Strong BUY recommendation was warranted"
@ -868,7 +927,9 @@ When synthesizing bull/bear arguments in similar conditions:
Recommendation for similar situations: {verdict}
"""
def _create_risk_manager_memory(self, situation: str, returns: float, ticker: str, date: str) -> str:
def _create_risk_manager_memory(
self, situation: str, returns: float, ticker: str, date: str
) -> str:
"""Create memory for risk manager."""
volatility = "HIGH" if abs(returns) > 10 else "MEDIUM" if abs(returns) > 5 else "LOW"
@ -901,7 +962,7 @@ Recommendation: {risk_assessment}
start_date: str,
end_date: str,
lookforward_days: int = 7,
interval_days: int = 30
interval_days: int = 30,
) -> Dict[str, List[Tuple[str, str]]]:
"""Build historical memories for a stock across a date range.
@ -915,28 +976,22 @@ Recommendation: {risk_assessment}
Returns:
Dictionary mapping agent type to list of (situation, lesson) tuples
"""
memories = {
"bull": [],
"bear": [],
"trader": [],
"invest_judge": [],
"risk_manager": []
}
memories = {"bull": [], "bear": [], "trader": [], "invest_judge": [], "risk_manager": []}
current_date = datetime.strptime(start_date, "%Y-%m-%d")
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
print(f"\n🧠 Building historical memories for {ticker}")
print(f" Period: {start_date} to {end_date}")
print(f" Lookforward: {lookforward_days} days")
print(f" Sampling interval: {interval_days} days\n")
logger.info(f"🧠 Building historical memories for {ticker}")
logger.info(f"Period: {start_date} to {end_date}")
logger.info(f"Lookforward: {lookforward_days} days")
logger.info(f"Sampling interval: {interval_days} days")
sample_count = 0
while current_date <= end_dt:
date_str = current_date.strftime("%Y-%m-%d")
future_date_str = (current_date + timedelta(days=lookforward_days)).strftime("%Y-%m-%d")
print(f" 📊 Sampling {date_str}...", end=" ")
logger.info(f"📊 Sampling {date_str}...")
# Get historical data for this period
data = self._get_stock_data_for_period(ticker, date_str)
@ -946,42 +1001,49 @@ Recommendation: {risk_assessment}
returns = self._calculate_returns(ticker, date_str, future_date_str)
if returns is not None:
print(f"Return: {returns:+.2f}%")
logger.info(f"Return: {returns:+.2f}%")
# Create agent-specific memories
memories["bull"].append((
situation,
self._create_bull_researcher_memory(situation, returns, ticker, date_str)
))
memories["bull"].append(
(
situation,
self._create_bull_researcher_memory(situation, returns, ticker, date_str),
)
)
memories["bear"].append((
situation,
self._create_bear_researcher_memory(situation, returns, ticker, date_str)
))
memories["bear"].append(
(
situation,
self._create_bear_researcher_memory(situation, returns, ticker, date_str),
)
)
memories["trader"].append((
situation,
self._create_trader_memory(situation, returns, ticker, date_str)
))
memories["trader"].append(
(situation, self._create_trader_memory(situation, returns, ticker, date_str))
)
memories["invest_judge"].append((
situation,
self._create_invest_judge_memory(situation, returns, ticker, date_str)
))
memories["invest_judge"].append(
(
situation,
self._create_invest_judge_memory(situation, returns, ticker, date_str),
)
)
memories["risk_manager"].append((
situation,
self._create_risk_manager_memory(situation, returns, ticker, date_str)
))
memories["risk_manager"].append(
(
situation,
self._create_risk_manager_memory(situation, returns, ticker, date_str),
)
)
sample_count += 1
else:
print("⚠️ No data")
logger.warning("⚠️ No data")
# Move to next interval
current_date += timedelta(days=interval_days)
print(f"\n✅ Created {sample_count} memory samples for {ticker}")
logger.info(f"✅ Created {sample_count} memory samples for {ticker}")
for agent_type in memories:
self.memories_created[agent_type] += len(memories[agent_type])
@ -993,7 +1055,7 @@ Recommendation: {risk_assessment}
start_date: str,
end_date: str,
lookforward_days: int = 7,
interval_days: int = 30
interval_days: int = 30,
) -> Dict[str, FinancialSituationMemory]:
"""Build and populate memories for all agent types across multiple stocks.
@ -1013,12 +1075,12 @@ Recommendation: {risk_assessment}
"bear": FinancialSituationMemory("bear_memory", self.config),
"trader": FinancialSituationMemory("trader_memory", self.config),
"invest_judge": FinancialSituationMemory("invest_judge_memory", self.config),
"risk_manager": FinancialSituationMemory("risk_manager_memory", self.config)
"risk_manager": FinancialSituationMemory("risk_manager_memory", self.config),
}
print("=" * 70)
print("🏗️ HISTORICAL MEMORY BUILDER")
print("=" * 70)
logger.info("=" * 70)
logger.info("🏗️ HISTORICAL MEMORY BUILDER")
logger.info("=" * 70)
# Build memories for each ticker
for ticker in tickers:
@ -1027,7 +1089,7 @@ Recommendation: {risk_assessment}
start_date=start_date,
end_date=end_date,
lookforward_days=lookforward_days,
interval_days=interval_days
interval_days=interval_days,
)
# Add memories to each agent's memory store
@ -1036,12 +1098,12 @@ Recommendation: {risk_assessment}
agent_memories[agent_type].add_situations(memory_list)
# Print summary
print("\n" + "=" * 70)
print("📊 MEMORY CREATION SUMMARY")
print("=" * 70)
logger.info("=" * 70)
logger.info("📊 MEMORY CREATION SUMMARY")
logger.info("=" * 70)
for agent_type, count in self.memories_created.items():
print(f" {agent_type.ljust(15)}: {count} memories")
print("=" * 70 + "\n")
logger.info(f"{agent_type.ljust(15)}: {count} memories")
logger.info("=" * 70)
return agent_memories
@ -1060,19 +1122,19 @@ if __name__ == "__main__":
tickers=tickers,
start_date="2024-01-01",
end_date="2024-12-01",
lookforward_days=7, # 1-week returns
interval_days=30 # Sample monthly
lookforward_days=7, # 1-week returns
interval_days=30, # Sample monthly
)
# Test retrieval
test_situation = "Strong earnings beat with positive sentiment and bullish technical indicators in tech sector"
print("\n🔍 Testing memory retrieval...")
print(f"Query: {test_situation}\n")
logger.info("🔍 Testing memory retrieval...")
logger.info(f"Query: {test_situation}")
for agent_type, memory in memories.items():
print(f"\n{agent_type.upper()} MEMORIES:")
logger.info(f"\n{agent_type.upper()} MEMORIES:")
results = memory.get_memories(test_situation, n_matches=2)
for i, result in enumerate(results, 1):
print(f"\n Match {i} (similarity: {result['similarity_score']:.2f}):")
print(f" {result['recommendation'][:200]}...")
logger.info(f"\n Match {i} (similarity: {result['similarity_score']:.2f}):")
logger.info(f" {result['recommendation'][:200]}...")

View File

@ -0,0 +1,59 @@
from typing import Any, Dict, List, Union
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
def parse_llm_response(response_content: Union[str, List[Union[str, Dict[str, Any]]]]) -> str:
"""
Parse content from an LLM response, handling both string and list formats.
This function standardizes extraction of text from various LLM provider response formats
(e.g., standard strings vs Anthropic's block format).
Args:
response_content: The raw content field from an LLM response object.
Returns:
The extracted text content as a string.
"""
if isinstance(response_content, list):
return "\n".join(
block.get("text", str(block)) if isinstance(block, dict) else str(block)
for block in response_content
)
return str(response_content) if response_content is not None else ""
def create_and_invoke_chain(
llm: Any, tools: List[Any], system_message: str, messages: List[BaseMessage]
) -> Any:
"""
Create and invoke a standard agent chain with tools.
Args:
llm: The Language Model to use
tools: List of tools to bind to the LLM
system_message: The system prompt content
messages: The chat history messages
Returns:
The LLM response (AIMessage)
"""
prompt = ChatPromptTemplate.from_messages(
[
("system", system_message),
MessagesPlaceholder(variable_name="messages"),
]
)
# Ensure at least one non-system message for Gemini compatibility
# Gemini API requires at least one HumanMessage in addition to SystemMessage
if not messages:
messages = [
HumanMessage(content="Please provide your analysis based on the context above.")
]
chain = prompt | llm.bind_tools(tools)
return chain.invoke({"messages": messages})

View File

@ -1,8 +1,12 @@
import os
from typing import Any, Dict, List, Optional, Tuple
import chromadb
from chromadb.config import Settings
from openai import OpenAI
from typing import List, Dict, Any, Optional, Tuple
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
class FinancialSituationMemory:
@ -17,7 +21,7 @@ class FinancialSituationMemory:
self.embedding_backend = "https://api.openai.com/v1"
self.embedding = "text-embedding-3-small"
self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
self.client = OpenAI(api_key=config.validate_key("openai_api_key", "OpenAI"))
# Use persistent storage in project directory
persist_directory = os.path.join(config.get("project_dir", "."), "memory_db")
@ -28,43 +32,52 @@ class FinancialSituationMemory:
# Get or create collection
try:
self.situation_collection = self.chroma_client.get_collection(name=name)
except:
except Exception:
self.situation_collection = self.chroma_client.create_collection(name=name)
def get_embedding(self, text):
"""Get OpenAI embedding for a text"""
response = self.client.embeddings.create(
model=self.embedding, input=text
)
response = self.client.embeddings.create(model=self.embedding, input=text)
return response.data[0].embedding
def add_situations(self, situations_and_advice):
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
def _batch_add(
self,
documents: List[str],
metadatas: List[Dict[str, Any]],
embeddings: List[List[float]],
ids: List[str] = None,
):
"""Internal helper to batch add documents to ChromaDB."""
if not documents:
return
situations = []
advice = []
ids = []
embeddings = []
offset = self.situation_collection.count()
for i, (situation, recommendation) in enumerate(situations_and_advice):
situations.append(situation)
advice.append(recommendation)
ids.append(str(offset + i))
embeddings.append(self.get_embedding(situation))
if ids is None:
offset = self.situation_collection.count()
ids = [str(offset + i) for i in range(len(documents))]
self.situation_collection.add(
documents=situations,
metadatas=[{"recommendation": rec} for rec in advice],
documents=documents,
metadatas=metadatas,
embeddings=embeddings,
ids=ids,
)
def add_situations(self, situations_and_advice):
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
situations = []
metadatas = []
embeddings = []
for situation, recommendation in situations_and_advice:
situations.append(situation)
metadatas.append({"recommendation": recommendation})
embeddings.append(self.get_embedding(situation))
self._batch_add(situations, metadatas, embeddings)
def add_situations_with_metadata(
self,
situations_and_outcomes: List[Tuple[str, str, Dict[str, Any]]]
self, situations_and_outcomes: List[Tuple[str, str, Dict[str, Any]]]
):
"""
Add financial situations with enhanced metadata for learning system.
@ -88,15 +101,11 @@ class FinancialSituationMemory:
- etc.
"""
situations = []
ids = []
embeddings = []
metadatas = []
embeddings = []
offset = self.situation_collection.count()
for i, (situation, recommendation, metadata) in enumerate(situations_and_outcomes):
for situation, recommendation, metadata in situations_and_outcomes:
situations.append(situation)
ids.append(str(offset + i))
embeddings.append(self.get_embedding(situation))
# Merge recommendation with metadata
@ -107,12 +116,7 @@ class FinancialSituationMemory:
full_metadata = self._sanitize_metadata(full_metadata)
metadatas.append(full_metadata)
self.situation_collection.add(
documents=situations,
metadatas=metadatas,
embeddings=embeddings,
ids=ids,
)
self._batch_add(situations, metadatas, embeddings)
def _sanitize_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
"""
@ -164,7 +168,7 @@ class FinancialSituationMemory:
current_situation: str,
signal_filters: Optional[Dict[str, Any]] = None,
n_matches: int = 3,
min_similarity: float = 0.5
min_similarity: float = 0.5,
) -> List[Dict[str, Any]]:
"""
Hybrid search: Filter by structured signals, then rank by embedding similarity.
@ -216,18 +220,20 @@ class FinancialSituationMemory:
metadata = results["metadatas"][0][i]
matched_results.append({
"matched_situation": results["documents"][0][i],
"recommendation": metadata.get("recommendation", ""),
"similarity_score": similarity_score,
"metadata": metadata,
# Extract key fields for convenience
"ticker": metadata.get("ticker", ""),
"move_pct": metadata.get("move_pct", 0),
"move_direction": metadata.get("move_direction", ""),
"was_correct": metadata.get("was_correct", False),
"days_before_move": metadata.get("days_before_move", 0),
})
matched_results.append(
{
"matched_situation": results["documents"][0][i],
"recommendation": metadata.get("recommendation", ""),
"similarity_score": similarity_score,
"metadata": metadata,
# Extract key fields for convenience
"ticker": metadata.get("ticker", ""),
"move_pct": metadata.get("move_pct", 0),
"move_direction": metadata.get("move_direction", ""),
"was_correct": metadata.get("was_correct", False),
"days_before_move": metadata.get("days_before_move", 0),
}
)
# Return top n_matches
return matched_results[:n_matches]
@ -250,13 +256,11 @@ class FinancialSituationMemory:
"total_memories": 0,
"accuracy_rate": 0.0,
"avg_move_pct": 0.0,
"signal_distribution": {}
"signal_distribution": {},
}
# Get all memories
all_results = self.situation_collection.get(
include=["metadatas"]
)
all_results = self.situation_collection.get(include=["metadatas"])
metadatas = all_results["metadatas"]
@ -283,7 +287,7 @@ class FinancialSituationMemory:
"total_memories": total_count,
"accuracy_rate": accuracy_rate,
"avg_move_pct": avg_move_pct,
"signal_distribution": signal_distribution
"signal_distribution": signal_distribution,
}
@ -324,10 +328,10 @@ if __name__ == "__main__":
recommendations = matcher.get_memories(current_situation, n_matches=2)
for i, rec in enumerate(recommendations, 1):
print(f"\nMatch {i}:")
print(f"Similarity Score: {rec['similarity_score']:.2f}")
print(f"Matched Situation: {rec['matched_situation']}")
print(f"Recommendation: {rec['recommendation']}")
logger.info(f"Match {i}:")
logger.info(f"Similarity Score: {rec['similarity_score']:.2f}")
logger.info(f"Matched Situation: {rec['matched_situation']}")
logger.info(f"Recommendation: {rec['recommendation']}")
except Exception as e:
print(f"Error during recommendation: {str(e)}")
logger.error(f"Error during recommendation: {str(e)}")

View File

@ -38,11 +38,11 @@ def get_date_awareness_section(current_date: str) -> str:
def validate_analyst_output(report: str, required_sections: list) -> dict:
"""
Validate that report contains all required sections.
Args:
report: The analyst report text to validate
required_sections: List of section names to check for
Returns:
Dictionary mapping section names to boolean (True if found)
"""
@ -50,28 +50,23 @@ def validate_analyst_output(report: str, required_sections: list) -> dict:
for section in required_sections:
# Check if section header exists (with ### or ##)
validation[section] = (
f"### {section}" in report
or f"## {section}" in report
or f"**{section}**" in report
f"### {section}" in report or f"## {section}" in report or f"**{section}**" in report
)
return validation
def format_analyst_prompt(
system_message: str,
current_date: str,
ticker: str,
tool_names: str
system_message: str, current_date: str, ticker: str, tool_names: str
) -> str:
"""
Format a complete analyst prompt with boilerplate and context.
Args:
system_message: The agent-specific system message
current_date: Current analysis date
ticker: Stock ticker symbol
tool_names: Comma-separated list of tool names
Returns:
Formatted prompt string
"""
@ -79,4 +74,3 @@ def format_analyst_prompt(
f"{BASE_COLLABORATIVE_BOILERPLATE}\n\n{system_message}\n\n"
f"Context: {ticker} | Date: {current_date} | Tools: {tool_names}"
)

View File

@ -1,7 +1,10 @@
from langchain_core.tools import tool
from typing import Annotated
from langchain_core.tools import tool
from tradingagents.tools.executor import execute_tool
@tool
def get_tweets(
query: Annotated[str, "Search query for tweets (e.g. ticker symbol or topic)"],
@ -18,6 +21,7 @@ def get_tweets(
"""
return execute_tool("get_tweets", query=query, count=count)
@tool
def get_tweets_from_user(
username: Annotated[str, "Twitter username (without @) to fetch tweets from"],

121
tradingagents/config.py Normal file
View File

@ -0,0 +1,121 @@
import os
from typing import Any, Optional
from dotenv import load_dotenv
from tradingagents.default_config import DEFAULT_CONFIG
# Load environment variables from .env file
load_dotenv()
class Config:
"""
Centralized configuration management.
Merges environment variables with default configuration.
"""
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(Config, cls).__new__(cls)
cls._instance._initialize()
return cls._instance
def _initialize(self):
self._defaults = DEFAULT_CONFIG
self._env_cache = {}
def _get_env(self, key: str, default: Any = None) -> Any:
"""Helper to get env var with optional default from config dictionary."""
val = os.getenv(key)
if val is not None:
return val
return default
# --- API Keys ---
@property
def openai_api_key(self) -> Optional[str]:
return self._get_env("OPENAI_API_KEY")
@property
def alpha_vantage_api_key(self) -> Optional[str]:
return self._get_env("ALPHA_VANTAGE_API_KEY")
@property
def finnhub_api_key(self) -> Optional[str]:
return self._get_env("FINNHUB_API_KEY")
@property
def tradier_api_key(self) -> Optional[str]:
return self._get_env("TRADIER_API_KEY")
@property
def fmp_api_key(self) -> Optional[str]:
return self._get_env("FMP_API_KEY")
@property
def reddit_client_id(self) -> Optional[str]:
return self._get_env("REDDIT_CLIENT_ID")
@property
def reddit_client_secret(self) -> Optional[str]:
return self._get_env("REDDIT_CLIENT_SECRET")
@property
def reddit_user_agent(self) -> str:
return self._get_env("REDDIT_USER_AGENT", "TradingAgents/1.0")
@property
def twitter_bearer_token(self) -> Optional[str]:
return self._get_env("TWITTER_BEARER_TOKEN")
@property
def serper_api_key(self) -> Optional[str]:
return self._get_env("SERPER_API_KEY")
@property
def gemini_api_key(self) -> Optional[str]:
return self._get_env("GEMINI_API_KEY")
# --- Paths and Settings ---
@property
def results_dir(self) -> str:
return self._defaults.get("results_dir", "./results")
@property
def user_workspace(self) -> str:
return self._get_env("USER_WORKSPACE", self._defaults.get("project_dir"))
# --- Methods ---
def validate_key(self, key_property: str, service_name: str) -> str:
"""
Validate that a specific API key property is set.
Returns the key if valid, raises ValueError otherwise.
"""
key = getattr(self, key_property)
if not key:
raise ValueError(
f"{service_name} API Key not found. Please set correct environment variable."
)
return key
def get(self, key: str, default: Any = None) -> Any:
"""
Get configuration value.
Checks properties first, then defaults.
"""
if hasattr(self, key):
val = getattr(self, key)
if val is not None:
return val
return self._defaults.get(key, default)
# Global config instance
config = Config()

View File

@ -1,5 +1,28 @@
# Import functions from specialized modules
from .alpha_vantage_fundamentals import (
get_balance_sheet,
get_cashflow,
get_fundamentals,
get_income_statement,
)
from .alpha_vantage_news import (
get_global_news,
get_insider_sentiment,
get_insider_transactions,
get_news,
)
from .alpha_vantage_stock import get_stock, get_top_gainers_losers
from .alpha_vantage_indicator import get_indicator
from .alpha_vantage_fundamentals import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement
from .alpha_vantage_news import get_news, get_insider_transactions, get_insider_sentiment, get_global_news
__all__ = [
"get_stock",
"get_top_gainers_losers",
"get_fundamentals",
"get_balance_sheet",
"get_cashflow",
"get_income_statement",
"get_news",
"get_global_news",
"get_insider_transactions",
"get_insider_sentiment",
]

View File

@ -3,17 +3,19 @@ Alpha Vantage Analyst Rating Changes Detection
Tracks recent analyst upgrades/downgrades and price target changes
"""
import os
import requests
import json
from datetime import datetime, timedelta
from typing import Annotated, List
from typing import Annotated, Dict, List, Union
from .alpha_vantage_common import _make_api_request
def get_analyst_rating_changes(
lookback_days: Annotated[int, "Number of days to look back for rating changes"] = 7,
change_types: Annotated[List[str], "Types of changes to track"] = None,
top_n: Annotated[int, "Number of top results to return"] = 20,
) -> str:
return_structured: Annotated[bool, "Return list of dicts instead of markdown"] = False,
) -> Union[List[Dict], str]:
"""
Track recent analyst upgrades/downgrades and rating changes.
@ -23,14 +25,12 @@ def get_analyst_rating_changes(
lookback_days: Number of days to look back (default 7)
change_types: Types of changes ["upgrade", "downgrade", "initiated", "reiterated"]
top_n: Maximum number of results to return
return_structured: If True, returns list of dicts instead of markdown
Returns:
Formatted markdown report of recent analyst rating changes
If return_structured=True: list of analyst change dicts
If return_structured=False: Formatted markdown report
"""
api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
if not api_key:
return "Error: ALPHA_VANTAGE_API_KEY not set in environment variables"
if change_types is None:
change_types = ["upgrade", "downgrade", "initiated"]
@ -38,26 +38,31 @@ def get_analyst_rating_changes(
# We'll use news sentiment API which includes analyst actions
# For production, consider using Financial Modeling Prep or Benzinga API
url = "https://www.alphavantage.co/query"
try:
# Get market news which includes analyst actions
params = {
"function": "NEWS_SENTIMENT",
"topics": "earnings,technology,finance",
"sort": "LATEST",
"limit": 200, # Get more news to find analyst actions
"apikey": api_key,
"limit": "200", # Get more news to find analyst actions
}
response = requests.get(url, params=params, timeout=30)
response.raise_for_status()
data = response.json()
response_text = _make_api_request("NEWS_SENTIMENT", params)
try:
data = json.loads(response_text)
except json.JSONDecodeError:
if return_structured:
return []
return f"API Error: Failed to parse JSON response: {response_text[:100]}"
if "Note" in data:
if return_structured:
return []
return f"API Rate Limit: {data['Note']}"
if "Error Message" in data:
if return_structured:
return []
return f"API Error: {data['Error Message']}"
# Parse news for analyst actions
@ -79,10 +84,21 @@ def get_analyst_rating_changes(
text = f"{title} {summary}"
# Look for analyst action keywords
is_upgrade = any(word in text for word in ["upgrade", "upgrades", "raised", "raises rating"])
is_downgrade = any(word in text for word in ["downgrade", "downgrades", "lowered", "lowers rating"])
is_initiated = any(word in text for word in ["initiates", "initiated", "coverage", "starts coverage"])
is_reiterated = any(word in text for word in ["reiterates", "reiterated", "maintains", "confirms"])
is_upgrade = any(
word in text for word in ["upgrade", "upgrades", "raised", "raises rating"]
)
is_downgrade = any(
word in text
for word in ["downgrade", "downgrades", "lowered", "lowers rating"]
)
is_initiated = any(
word in text
for word in ["initiates", "initiated", "coverage", "starts coverage"]
)
is_reiterated = any(
word in text
for word in ["reiterates", "reiterated", "maintains", "confirms"]
)
# Extract tickers from article
tickers = []
@ -108,36 +124,44 @@ def get_analyst_rating_changes(
hours_old = (datetime.now() - article_date).total_seconds() / 3600
for ticker in tickers[:3]: # Max 3 tickers per article
analyst_changes.append({
"ticker": ticker,
"action": action_type,
"date": time_published[:8],
"hours_old": int(hours_old),
"headline": article.get("title", "")[:100],
"source": article.get("source", "Unknown"),
"url": article.get("url", ""),
})
analyst_changes.append(
{
"ticker": ticker,
"action": action_type,
"date": time_published[:8],
"hours_old": int(hours_old),
"headline": article.get("title", "")[:100],
"source": article.get("source", "Unknown"),
"url": article.get("url", ""),
}
)
except (ValueError, KeyError) as e:
except (ValueError, KeyError):
continue
# Remove duplicates (keep most recent per ticker)
seen_tickers = {}
for change in analyst_changes:
ticker = change["ticker"]
if ticker not in seen_tickers or change["hours_old"] < seen_tickers[ticker]["hours_old"]:
if (
ticker not in seen_tickers
or change["hours_old"] < seen_tickers[ticker]["hours_old"]
):
seen_tickers[ticker] = change
# Sort by freshness (most recent first)
sorted_changes = sorted(
seen_tickers.values(),
key=lambda x: x["hours_old"]
)[:top_n]
sorted_changes = sorted(seen_tickers.values(), key=lambda x: x["hours_old"])[:top_n]
# Format output
if not sorted_changes:
if return_structured:
return []
return f"No analyst rating changes found in the last {lookback_days} days"
# Return structured data if requested
if return_structured:
return sorted_changes
report = f"# Analyst Rating Changes - Last {lookback_days} Days\n\n"
report += f"**Tracking**: {', '.join(change_types)}\n\n"
report += f"**Found**: {len(sorted_changes)} recent analyst actions\n\n"
@ -146,7 +170,11 @@ def get_analyst_rating_changes(
report += "|--------|--------|--------|-----------|----------|\n"
for change in sorted_changes:
freshness = "🔥 FRESH" if change["hours_old"] < 24 else "🟢 Recent" if change["hours_old"] < 72 else "Older"
freshness = (
"🔥 FRESH"
if change["hours_old"] < 24
else "🟢 Recent" if change["hours_old"] < 72 else "Older"
)
report += f"| {change['ticker']} | "
report += f"{change['action'].upper()} | "
@ -161,9 +189,9 @@ def get_analyst_rating_changes(
return report
except requests.exceptions.RequestException as e:
return f"Error fetching analyst rating changes: {str(e)}"
except Exception as e:
if return_structured:
return []
return f"Unexpected error in analyst rating detection: {str(e)}"

View File

@ -1,25 +1,29 @@
import os
import requests
import pandas as pd
import json
from datetime import datetime
from io import StringIO
from typing import Union
import pandas as pd
import requests
from tradingagents.config import config
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
API_BASE_URL = "https://www.alphavantage.co/query"
def get_api_key() -> str:
"""Retrieve the API key for Alpha Vantage from environment variables."""
api_key = os.getenv("ALPHA_VANTAGE_API_KEY")
if not api_key:
raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.")
return api_key
return config.validate_key("alpha_vantage_api_key", "Alpha Vantage")
def format_datetime_for_api(date_input) -> str:
"""Convert various date formats to YYYYMMDDTHHMM format required by Alpha Vantage API."""
if isinstance(date_input, str):
# If already in correct format, return as-is
if len(date_input) == 13 and 'T' in date_input:
if len(date_input) == 13 and "T" in date_input:
return date_input
# Try to parse common date formats
try:
@ -36,39 +40,44 @@ def format_datetime_for_api(date_input) -> str:
else:
raise ValueError(f"Date must be string or datetime object, got {type(date_input)}")
class AlphaVantageRateLimitError(Exception):
"""Exception raised when Alpha Vantage API rate limit is exceeded."""
pass
def _make_api_request(function_name: str, params: dict) -> Union[dict, str]:
"""Helper function to make API requests and handle responses.
Raises:
AlphaVantageRateLimitError: When API rate limit is exceeded
"""
# Create a copy of params to avoid modifying the original
api_params = params.copy()
api_params.update({
"function": function_name,
"apikey": get_api_key(),
"source": "trading_agents",
})
api_params.update(
{
"function": function_name,
"apikey": get_api_key(),
"source": "trading_agents",
}
)
# Handle entitlement parameter if present in params or global variable
current_entitlement = globals().get('_current_entitlement')
current_entitlement = globals().get("_current_entitlement")
entitlement = api_params.get("entitlement") or current_entitlement
if entitlement:
api_params["entitlement"] = entitlement
elif "entitlement" in api_params:
# Remove entitlement if it's None or empty
api_params.pop("entitlement", None)
response = requests.get(API_BASE_URL, params=api_params)
response.raise_for_status()
response_text = response.text
# Check if response is JSON (error responses are typically JSON)
try:
response_json = json.loads(response_text)
@ -76,7 +85,9 @@ def _make_api_request(function_name: str, params: dict) -> Union[dict, str]:
if "Information" in response_json:
info_message = response_json["Information"]
if "rate limit" in info_message.lower() or "api key" in info_message.lower():
raise AlphaVantageRateLimitError(f"Alpha Vantage rate limit exceeded: {info_message}")
raise AlphaVantageRateLimitError(
f"Alpha Vantage rate limit exceeded: {info_message}"
)
except json.JSONDecodeError:
# Response is not JSON (likely CSV data), which is normal
pass
@ -84,7 +95,6 @@ def _make_api_request(function_name: str, params: dict) -> Union[dict, str]:
return response_text
def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) -> str:
"""
Filter CSV data to include only rows within the specified date range.
@ -119,5 +129,5 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) ->
except Exception as e:
# If filtering fails, return original data with a warning
print(f"Warning: Failed to filter CSV data by date range: {e}")
logger.warning(f"Failed to filter CSV data by date range: {e}")
return csv_data

View File

@ -74,4 +74,3 @@ def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str =
}
return _make_api_request("INCOME_STATEMENT", params)

View File

@ -1,5 +1,10 @@
from tradingagents.utils.logger import get_logger
from .alpha_vantage_common import _make_api_request
logger = get_logger(__name__)
def get_indicator(
symbol: str,
indicator: str,
@ -7,7 +12,7 @@ def get_indicator(
look_back_days: int,
interval: str = "daily",
time_period: int = 14,
series_type: str = "close"
series_type: str = "close",
) -> str:
"""
Returns Alpha Vantage technical indicator values over a time window.
@ -25,6 +30,7 @@ def get_indicator(
String containing indicator values and description
"""
from datetime import datetime
from dateutil.relativedelta import relativedelta
supported_indicators = {
@ -39,7 +45,7 @@ def get_indicator(
"boll_ub": ("Bollinger Upper Band", "close"),
"boll_lb": ("Bollinger Lower Band", "close"),
"atr": ("ATR", None),
"vwma": ("VWMA", "close")
"vwma": ("VWMA", "close"),
}
indicator_descriptions = {
@ -54,7 +60,7 @@ def get_indicator(
"boll_ub": "Bollinger Upper Band: Typically 2 standard deviations above the middle line. Usage: Signals potential overbought conditions and breakout zones. Tips: Confirm signals with other tools; prices may ride the band in strong trends.",
"boll_lb": "Bollinger Lower Band: Typically 2 standard deviations below the middle line. Usage: Indicates potential oversold conditions. Tips: Use additional analysis to avoid false reversal signals.",
"atr": "ATR: Averages true range to measure volatility. Usage: Set stop-loss levels and adjust position sizes based on current market volatility. Tips: It's a reactive measure, so use it as part of a broader risk management strategy.",
"vwma": "VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses."
"vwma": "VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.",
}
if indicator not in supported_indicators:
@ -75,73 +81,100 @@ def get_indicator(
try:
# Get indicator data for the period
if indicator == "close_50_sma":
data = _make_api_request("SMA", {
"symbol": symbol,
"interval": interval,
"time_period": "50",
"series_type": series_type,
"datatype": "csv"
})
data = _make_api_request(
"SMA",
{
"symbol": symbol,
"interval": interval,
"time_period": "50",
"series_type": series_type,
"datatype": "csv",
},
)
elif indicator == "close_200_sma":
data = _make_api_request("SMA", {
"symbol": symbol,
"interval": interval,
"time_period": "200",
"series_type": series_type,
"datatype": "csv"
})
data = _make_api_request(
"SMA",
{
"symbol": symbol,
"interval": interval,
"time_period": "200",
"series_type": series_type,
"datatype": "csv",
},
)
elif indicator == "close_10_ema":
data = _make_api_request("EMA", {
"symbol": symbol,
"interval": interval,
"time_period": "10",
"series_type": series_type,
"datatype": "csv"
})
data = _make_api_request(
"EMA",
{
"symbol": symbol,
"interval": interval,
"time_period": "10",
"series_type": series_type,
"datatype": "csv",
},
)
elif indicator == "macd":
data = _make_api_request("MACD", {
"symbol": symbol,
"interval": interval,
"series_type": series_type,
"datatype": "csv"
})
data = _make_api_request(
"MACD",
{
"symbol": symbol,
"interval": interval,
"series_type": series_type,
"datatype": "csv",
},
)
elif indicator == "macds":
data = _make_api_request("MACD", {
"symbol": symbol,
"interval": interval,
"series_type": series_type,
"datatype": "csv"
})
data = _make_api_request(
"MACD",
{
"symbol": symbol,
"interval": interval,
"series_type": series_type,
"datatype": "csv",
},
)
elif indicator == "macdh":
data = _make_api_request("MACD", {
"symbol": symbol,
"interval": interval,
"series_type": series_type,
"datatype": "csv"
})
data = _make_api_request(
"MACD",
{
"symbol": symbol,
"interval": interval,
"series_type": series_type,
"datatype": "csv",
},
)
elif indicator == "rsi":
data = _make_api_request("RSI", {
"symbol": symbol,
"interval": interval,
"time_period": str(time_period),
"series_type": series_type,
"datatype": "csv"
})
data = _make_api_request(
"RSI",
{
"symbol": symbol,
"interval": interval,
"time_period": str(time_period),
"series_type": series_type,
"datatype": "csv",
},
)
elif indicator in ["boll", "boll_ub", "boll_lb"]:
data = _make_api_request("BBANDS", {
"symbol": symbol,
"interval": interval,
"time_period": "20",
"series_type": series_type,
"datatype": "csv"
})
data = _make_api_request(
"BBANDS",
{
"symbol": symbol,
"interval": interval,
"time_period": "20",
"series_type": series_type,
"datatype": "csv",
},
)
elif indicator == "atr":
data = _make_api_request("ATR", {
"symbol": symbol,
"interval": interval,
"time_period": str(time_period),
"datatype": "csv"
})
data = _make_api_request(
"ATR",
{
"symbol": symbol,
"interval": interval,
"time_period": str(time_period),
"datatype": "csv",
},
)
elif indicator == "vwma":
# Alpha Vantage doesn't have direct VWMA, so we'll return an informative message
# In a real implementation, this would need to be calculated from OHLCV data
@ -150,23 +183,30 @@ def get_indicator(
return f"Error: Indicator {indicator} not implemented yet."
# Parse CSV data and extract values for the date range
lines = data.strip().split('\n')
lines = data.strip().split("\n")
if len(lines) < 2:
return f"Error: No data returned for {indicator}"
# Parse header and data
header = [col.strip() for col in lines[0].split(',')]
header = [col.strip() for col in lines[0].split(",")]
try:
date_col_idx = header.index('time')
date_col_idx = header.index("time")
except ValueError:
return f"Error: 'time' column not found in data for {indicator}. Available columns: {header}"
# Map internal indicator names to expected CSV column names from Alpha Vantage
col_name_map = {
"macd": "MACD", "macds": "MACD_Signal", "macdh": "MACD_Hist",
"boll": "Real Middle Band", "boll_ub": "Real Upper Band", "boll_lb": "Real Lower Band",
"rsi": "RSI", "atr": "ATR", "close_10_ema": "EMA",
"close_50_sma": "SMA", "close_200_sma": "SMA"
"macd": "MACD",
"macds": "MACD_Signal",
"macdh": "MACD_Hist",
"boll": "Real Middle Band",
"boll_ub": "Real Upper Band",
"boll_lb": "Real Lower Band",
"rsi": "RSI",
"atr": "ATR",
"close_10_ema": "EMA",
"close_50_sma": "SMA",
"close_200_sma": "SMA",
}
target_col_name = col_name_map.get(indicator)
@ -184,7 +224,7 @@ def get_indicator(
for line in lines[1:]:
if not line.strip():
continue
values = line.split(',')
values = line.split(",")
if len(values) > value_col_idx:
try:
date_str = values[date_col_idx].strip()
@ -218,5 +258,5 @@ def get_indicator(
return result_str
except Exception as e:
print(f"Error getting Alpha Vantage indicator data for {indicator}: {e}")
logger.error(f"Error getting Alpha Vantage indicator data for {indicator}: {e}")
return f"Error retrieving {indicator} data: {str(e)}"

View File

@ -1,7 +1,11 @@
from typing import Union, Dict, Optional
from typing import Dict, Union
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
def get_news(ticker: str = None, start_date: str = None, end_date: str = None, query: str = None) -> Union[Dict[str, str], str]:
def get_news(
ticker: str = None, start_date: str = None, end_date: str = None, query: str = None
) -> Union[Dict[str, str], str]:
"""Returns live and historical market news & sentiment data.
Args:
@ -25,11 +29,13 @@ def get_news(ticker: str = None, start_date: str = None, end_date: str = None, q
"sort": "LATEST",
"limit": "50",
}
return _make_api_request("NEWS_SENTIMENT", params)
def get_global_news(date: str, look_back_days: int = 7, limit: int = 5) -> Union[Dict[str, str], str]:
def get_global_news(
date: str, look_back_days: int = 7, limit: int = 5
) -> Union[Dict[str, str], str]:
"""Returns global market news & sentiment data.
Args:
@ -49,7 +55,41 @@ def get_global_news(date: str, look_back_days: int = 7, limit: int = 5) -> Union
return _make_api_request("NEWS_SENTIMENT", params)
def get_insider_transactions(symbol: str = None, ticker: str = None, curr_date: str = None) -> Union[Dict[str, str], str]:
def get_alpha_vantage_news_feed(
topics: str = None, time_from: str = None, limit: int = 50
) -> Union[Dict[str, str], str]:
"""Returns news feed from Alpha Vantage with optional topic filtering.
Args:
topics: Comma-separated topics (e.g., "technology,finance,earnings").
Valid topics: blockchain, earnings, ipo, mergers_and_acquisitions,
financial_markets, economy_fiscal, economy_monetary, economy_macro,
energy_transportation, finance, life_sciences, manufacturing,
real_estate, retail_wholesale, technology
time_from: Start time in format YYYYMMDDTHHMM (e.g., "20240101T0000").
limit: Maximum number of articles to return.
Returns:
Dictionary containing news sentiment data or JSON string.
"""
params = {
"sort": "LATEST",
"limit": str(limit),
}
if topics:
params["topics"] = topics
if time_from:
params["time_from"] = time_from
return _make_api_request("NEWS_SENTIMENT", params)
def get_insider_transactions(
symbol: str = None, ticker: str = None, curr_date: str = None
) -> Union[Dict[str, str], str]:
"""Returns latest and historical insider transactions.
Args:
@ -70,14 +110,15 @@ def get_insider_transactions(symbol: str = None, ticker: str = None, curr_date:
return _make_api_request("INSIDER_TRANSACTIONS", params)
def get_insider_sentiment(symbol: str = None, ticker: str = None, curr_date: str = None) -> str:
"""Returns insider sentiment data derived from Alpha Vantage transactions.
Args:
symbol: Ticker symbol.
ticker: Alias for symbol.
curr_date: Current date.
Returns:
Formatted string containing insider sentiment analysis.
"""
@ -87,24 +128,24 @@ def get_insider_sentiment(symbol: str = None, ticker: str = None, curr_date: str
import json
from datetime import datetime, timedelta
# Fetch transactions
params = {
"symbol": target_symbol,
}
response_text = _make_api_request("INSIDER_TRANSACTIONS", params)
try:
data = json.loads(response_text)
if "Information" in data:
return f"Error: {data['Information']}"
# Alpha Vantage INSIDER_TRANSACTIONS returns a dictionary with "symbol" and "data" (list)
# or sometimes just the list depending on the endpoint version, but usually it's under a key.
# Let's handle the standard response structure.
# Based on docs, it returns CSV by default? No, _make_api_request handles JSON.
# Actually, Alpha Vantage INSIDER_TRANSACTIONS returns JSON by default.
# Structure check
transactions = []
if "data" in data:
@ -114,16 +155,16 @@ def get_insider_sentiment(symbol: str = None, ticker: str = None, curr_date: str
else:
# If we can't find the list, return the raw text
return f"Raw Data: {str(data)[:500]}"
# Filter and Aggregate
# We want recent transactions (e.g. last 3 months)
if curr_date:
curr_dt = datetime.strptime(curr_date, "%Y-%m-%d")
else:
curr_dt = datetime.now()
start_dt = curr_dt - timedelta(days=90)
relevant_txs = []
for tx in transactions:
# Date format in AV is usually YYYY-MM-DD
@ -132,44 +173,44 @@ def get_insider_sentiment(symbol: str = None, ticker: str = None, curr_date: str
if not tx_date_str:
continue
tx_date = datetime.strptime(tx_date_str, "%Y-%m-%d")
if start_dt <= tx_date <= curr_dt:
relevant_txs.append(tx)
except ValueError:
continue
if not relevant_txs:
return f"No insider transactions found for {symbol} in the 90 days before {curr_date}."
# Calculate metrics
total_bought = 0
total_sold = 0
net_shares = 0
for tx in relevant_txs:
shares = int(float(tx.get("shares", 0)))
# acquisition_or_disposal: "A" (Acquisition) or "D" (Disposal)
# transaction_code: "P" (Purchase), "S" (Sale)
# We can use acquisition_or_disposal if available, or transaction_code
code = tx.get("acquisition_or_disposal")
if not code:
# Fallback to transaction code logic if needed, but A/D is standard for AV
pass
if code == "A":
total_bought += shares
net_shares += shares
elif code == "D":
total_sold += shares
net_shares -= shares
sentiment = "NEUTRAL"
if net_shares > 0:
sentiment = "POSITIVE"
elif net_shares < 0:
sentiment = "NEGATIVE"
report = f"## Insider Sentiment for {symbol} (Last 90 Days)\n"
report += f"**Overall Sentiment:** {sentiment}\n"
report += f"**Net Shares:** {net_shares:,}\n"
@ -177,13 +218,13 @@ def get_insider_sentiment(symbol: str = None, ticker: str = None, curr_date: str
report += f"**Total Sold:** {total_sold:,}\n"
report += f"**Transaction Count:** {len(relevant_txs)}\n\n"
report += "### Recent Transactions:\n"
# List top 5 recent
relevant_txs.sort(key=lambda x: x.get("transaction_date", ""), reverse=True)
for tx in relevant_txs[:5]:
report += f"- {tx.get('transaction_date')}: {tx.get('executive')} - {tx.get('acquisition_or_disposal')} {tx.get('shares')} shares at ${tx.get('transaction_price')}\n"
return report
except Exception as e:
return f"Error processing insider sentiment: {str(e)}\nRaw response: {response_text[:200]}"
return f"Error processing insider sentiment: {str(e)}\nRaw response: {response_text[:200]}"

View File

@ -1,11 +1,9 @@
from datetime import datetime
from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range
def get_stock(
symbol: str,
start_date: str,
end_date: str
) -> str:
from .alpha_vantage_common import _filter_csv_by_date_range, _make_api_request
def get_stock(symbol: str, start_date: str, end_date: str) -> str:
"""
Returns raw daily OHLCV values, adjusted close values, and historical split/dividend events
filtered to the specified date range.
@ -38,48 +36,77 @@ def get_stock(
return _filter_csv_by_date_range(response, start_date, end_date)
def get_top_gainers_losers(limit: int = 10) -> str:
def get_top_gainers_losers(limit: int = 10, return_structured: bool = False):
"""
Returns the top gainers, losers, and most active stocks from Alpha Vantage.
Args:
limit: Maximum number of items per category
return_structured: If True, returns dict with raw data instead of markdown
Returns:
If return_structured=True: dict with 'gainers', 'losers', 'most_active' lists
If return_structured=False: Formatted markdown string
"""
params = {}
# This returns a JSON string
response_text = _make_api_request("TOP_GAINERS_LOSERS", params)
try:
import json
data = json.loads(response_text)
if "top_gainers" not in data:
if return_structured:
return {"error": f"Unexpected response format: {response_text[:200]}..."}
return f"Error: Unexpected response format: {response_text[:200]}..."
# Apply limit to data
gainers = data.get("top_gainers", [])[:limit]
losers = data.get("top_losers", [])[:limit]
most_active = data.get("most_actively_traded", [])[:limit]
# Return structured data if requested
if return_structured:
return {
"gainers": gainers,
"losers": losers,
"most_active": most_active,
}
# Format as markdown report
report = "## Top Market Movers (Alpha Vantage)\n\n"
# Top Gainers
report += "### Top Gainers\n"
report += "| Ticker | Price | Change % | Volume |\n"
report += "|--------|-------|----------|--------|\n"
for item in data.get("top_gainers", [])[:limit]:
for item in gainers:
report += f"| {item['ticker']} | {item['price']} | {item['change_percentage']} | {item['volume']} |\n"
# Top Losers
report += "\n### Top Losers\n"
report += "| Ticker | Price | Change % | Volume |\n"
report += "|--------|-------|----------|--------|\n"
for item in data.get("top_losers", [])[:limit]:
for item in losers:
report += f"| {item['ticker']} | {item['price']} | {item['change_percentage']} | {item['volume']} |\n"
# Most Active
report += "\n### Most Active\n"
report += "| Ticker | Price | Change % | Volume |\n"
report += "|--------|-------|----------|--------|\n"
for item in data.get("most_actively_traded", [])[:limit]:
for item in most_active:
report += f"| {item['ticker']} | {item['price']} | {item['change_percentage']} | {item['volume']} |\n"
return report
except json.JSONDecodeError:
if return_structured:
return {"error": f"Failed to parse JSON response: {response_text[:200]}..."}
return f"Error: Failed to parse JSON response: {response_text[:200]}..."
except Exception as e:
return f"Error processing market movers: {str(e)}"
if return_structured:
return {"error": str(e)}
return f"Error processing market movers: {str(e)}"

View File

@ -3,26 +3,27 @@ Unusual Volume Detection using yfinance
Identifies stocks with unusual volume but minimal price movement (accumulation signal)
"""
from datetime import datetime
from typing import Annotated, List, Dict, Optional, Union
import hashlib
import pandas as pd
import yfinance as yf
import json
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from tradingagents.dataflows.y_finance import _get_ticker_universe
from datetime import datetime
from pathlib import Path
from typing import Annotated, Dict, List, Optional, Union
import pandas as pd
from tradingagents.dataflows.y_finance import _get_ticker_universe, get_ticker_history
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
def _get_cache_path(
ticker_universe: Union[str, List[str]]
) -> Path:
def _get_cache_path(ticker_universe: Union[str, List[str]]) -> Path:
"""
Get the cache file path for unusual volume raw data.
Args:
ticker_universe: Universe identifier
Returns:
Path to cache file
"""
@ -30,7 +31,7 @@ def _get_cache_path(
current_file = Path(__file__)
cache_dir = current_file.parent / "data_cache"
cache_dir.mkdir(exist_ok=True)
# Create cache key from universe only (thresholds are applied later)
if isinstance(ticker_universe, str):
universe_key = ticker_universe
@ -40,38 +41,38 @@ def _get_cache_path(
hash_suffix = hashlib.md5(",".join(sorted(clean_tickers)).encode()).hexdigest()[:8]
universe_key = f"custom_{hash_suffix}"
cache_key = f"unusual_volume_raw_{universe_key}".replace(".", "_")
return cache_dir / f"{cache_key}.json"
def _load_cache(cache_path: Path) -> Optional[Dict]:
"""
Load cached unusual volume raw data if it exists and is from today.
Args:
cache_path: Path to cache file
Returns:
Cached results dict if valid, None otherwise
"""
if not cache_path.exists():
return None
try:
with open(cache_path, 'r') as f:
with open(cache_path, "r") as f:
cache_data = json.load(f)
# Check if cache is from today
cache_date = cache_data.get('date')
today = datetime.now().strftime('%Y-%m-%d')
has_raw_data = bool(cache_data.get('raw_data'))
cache_date = cache_data.get("date")
today = datetime.now().strftime("%Y-%m-%d")
has_raw_data = bool(cache_data.get("raw_data"))
if cache_date == today and has_raw_data:
return cache_data
else:
# Cache is stale, return None to trigger recompute
return None
except Exception:
# If cache is corrupted, return None to trigger recompute
return None
@ -80,35 +81,38 @@ def _load_cache(cache_path: Path) -> Optional[Dict]:
def _save_cache(cache_path: Path, raw_data: Dict[str, List[Dict]], date: str):
"""
Save unusual volume raw data to cache.
Args:
cache_path: Path to cache file
raw_data: Raw ticker data to cache
date: Date string (YYYY-MM-DD)
"""
try:
cache_data = {
'date': date,
'raw_data': raw_data,
'timestamp': datetime.now().isoformat()
}
with open(cache_path, 'w') as f:
cache_data = {"date": date, "raw_data": raw_data, "timestamp": datetime.now().isoformat()}
with open(cache_path, "w") as f:
json.dump(cache_data, f, indent=2)
except Exception as e:
# If caching fails, just continue without cache
print(f"Warning: Could not save cache: {e}")
logger.warning(f"Could not save cache: {e}")
def _history_to_records(hist: pd.DataFrame) -> List[Dict[str, Union[str, float, int]]]:
"""Convert a yfinance history DataFrame to a cache-friendly list of dicts."""
hist_for_cache = hist[["Close", "Volume"]].copy()
# Include Open price for intraday direction analysis (accumulation vs distribution)
cols_to_use = ["Close", "Volume"]
if "Open" in hist.columns:
cols_to_use = ["Open", "Close", "Volume"]
hist_for_cache = hist[cols_to_use].copy()
hist_for_cache = hist_for_cache.reset_index()
date_col = "Date" if "Date" in hist_for_cache.columns else hist_for_cache.columns[0]
hist_for_cache.rename(columns={date_col: "Date"}, inplace=True)
hist_for_cache["Date"] = pd.to_datetime(hist_for_cache["Date"]).dt.strftime('%Y-%m-%d')
hist_for_cache = hist_for_cache[["Date", "Close", "Volume"]]
hist_for_cache["Date"] = pd.to_datetime(hist_for_cache["Date"]).dt.strftime("%Y-%m-%d")
final_cols = ["Date"] + cols_to_use
hist_for_cache = hist_for_cache[final_cols]
return hist_for_cache.to_dict(orient="records")
@ -122,23 +126,194 @@ def _records_to_dataframe(history_records: List[Dict[str, Union[str, float, int]
return hist_df
def get_cached_average_volume(
symbol: str,
lookback_days: int = 20,
curr_date: Optional[str] = None,
cache_key: str = "default",
fallback_download: bool = True,
) -> Dict[str, Union[str, float, int, None]]:
"""Get average volume using cached unusual-volume data, with optional fallback download."""
symbol = symbol.upper()
cache_path = _get_cache_path(cache_key)
cache_date = None
history_records = None
if cache_path.exists():
try:
with open(cache_path, "r") as f:
cache_data = json.load(f)
cache_date = cache_data.get("date")
raw_data = cache_data.get("raw_data") or {}
history_records = raw_data.get(symbol)
except Exception:
history_records = None
source = "cache"
if not history_records and fallback_download:
history_records = _download_ticker_history(
symbol, history_period_days=max(90, lookback_days * 2)
)
source = "download"
if not history_records:
return {
"symbol": symbol,
"average_volume": None,
"latest_volume": None,
"lookback_days": lookback_days,
"source": source,
"cache_date": cache_date,
"error": "No volume data found",
}
hist_df = _records_to_dataframe(history_records)
if hist_df.empty or "Volume" not in hist_df.columns:
return {
"symbol": symbol,
"average_volume": None,
"latest_volume": None,
"lookback_days": lookback_days,
"source": source,
"cache_date": cache_date,
"error": "No volume data found",
}
if curr_date:
curr_dt = pd.to_datetime(curr_date)
hist_df = hist_df[hist_df["Date"] <= curr_dt]
recent = hist_df.tail(lookback_days)
if recent.empty:
return {
"symbol": symbol,
"average_volume": None,
"latest_volume": None,
"lookback_days": lookback_days,
"source": source,
"cache_date": cache_date,
"error": "No recent volume data found",
}
average_volume = float(recent["Volume"].mean())
latest_volume = float(recent["Volume"].iloc[-1])
return {
"symbol": symbol,
"average_volume": average_volume,
"latest_volume": latest_volume,
"lookback_days": lookback_days,
"source": source,
"cache_date": cache_date,
}
def get_cached_average_volume_batch(
symbols: List[str],
lookback_days: int = 20,
curr_date: Optional[str] = None,
cache_key: str = "default",
fallback_download: bool = True,
) -> Dict[str, Dict[str, Union[str, float, int, None]]]:
"""Get average volumes for multiple tickers using the cache once."""
cache_path = _get_cache_path(cache_key)
cache_date = None
raw_data = {}
if cache_path.exists():
try:
with open(cache_path, "r") as f:
cache_data = json.load(f)
cache_date = cache_data.get("date")
raw_data = cache_data.get("raw_data") or {}
except Exception:
raw_data = {}
results: Dict[str, Dict[str, Union[str, float, int, None]]] = {}
symbols_upper = [s.upper() for s in symbols if isinstance(s, str)]
def compute_from_records(symbol: str, history_records: List[Dict[str, Union[str, float, int]]]):
hist_df = _records_to_dataframe(history_records)
if hist_df.empty or "Volume" not in hist_df.columns:
return None, None, "No volume data found"
if curr_date:
curr_dt = pd.to_datetime(curr_date)
hist_df = hist_df[hist_df["Date"] <= curr_dt]
recent = hist_df.tail(lookback_days)
if recent.empty:
return None, None, "No recent volume data found"
avg_volume = float(recent["Volume"].mean())
latest_volume = float(recent["Volume"].iloc[-1])
return avg_volume, latest_volume, None
missing = []
for symbol in symbols_upper:
history_records = raw_data.get(symbol)
if history_records:
avg_volume, latest_volume, error = compute_from_records(symbol, history_records)
results[symbol] = {
"symbol": symbol,
"average_volume": avg_volume,
"latest_volume": latest_volume,
"lookback_days": lookback_days,
"source": "cache",
"cache_date": cache_date,
"error": error,
}
else:
missing.append(symbol)
if fallback_download and missing:
for symbol in missing:
history_records = _download_ticker_history(
symbol, history_period_days=max(90, lookback_days * 2)
)
if history_records:
avg_volume, latest_volume, error = compute_from_records(symbol, history_records)
results[symbol] = {
"symbol": symbol,
"average_volume": avg_volume,
"latest_volume": latest_volume,
"lookback_days": lookback_days,
"source": "download",
"cache_date": cache_date,
"error": error,
}
else:
results[symbol] = {
"symbol": symbol,
"average_volume": None,
"latest_volume": None,
"lookback_days": lookback_days,
"source": "download",
"cache_date": cache_date,
"error": "No volume data found",
}
return results
def _evaluate_unusual_volume_from_history(
ticker: str,
history_records: List[Dict[str, Union[str, float, int]]],
min_volume_multiple: float,
max_price_change: float,
lookback_days: int = 30
lookback_days: int = 30,
) -> Optional[Dict]:
"""
Evaluate a ticker's cached history for unusual volume patterns.
Now includes DIRECTION ANALYSIS to distinguish:
- Accumulation (high volume + price holds/rises) = BULLISH - keep
- Distribution (high volume + price drops) = BEARISH - skip
Args:
ticker: Stock ticker symbol
history_records: Cached price/volume history records
min_volume_multiple: Minimum volume multiple vs average
max_price_change: Maximum absolute price change percentage
lookback_days: Days to look back for average volume calculation
Returns:
Dict with ticker data if unusual volume detected, None otherwise
"""
@ -148,48 +323,76 @@ def _evaluate_unusual_volume_from_history(
return None
current_data = hist.iloc[-1]
current_volume = current_data['Volume']
current_price = current_data['Close']
current_volume = current_data["Volume"]
current_price = current_data["Close"]
avg_volume = hist['Volume'].iloc[-(lookback_days+1):-1].mean()
avg_volume = hist["Volume"].iloc[-(lookback_days + 1) : -1].mean()
if pd.isna(avg_volume) or avg_volume <= 0:
return None
volume_ratio = current_volume / avg_volume
price_start = hist['Close'].iloc[-(lookback_days+1)]
price_start = hist["Close"].iloc[-(lookback_days + 1)]
price_end = current_price
price_change_pct = ((price_end - price_start) / price_start) * 100
# === DIRECTION ANALYSIS (NEW) ===
# Check intraday direction to distinguish accumulation from distribution
intraday_change_pct = 0.0
direction = "neutral"
if "Open" in current_data and pd.notna(current_data["Open"]):
open_price = current_data["Open"]
if open_price > 0:
intraday_change_pct = ((current_price - open_price) / open_price) * 100
# Classify direction based on intraday movement
if intraday_change_pct > 0.5:
direction = "bullish" # Closed higher than open
elif intraday_change_pct < -1.5:
direction = "bearish" # Closed significantly lower than open
else:
direction = "neutral" # Flat intraday
# === DISTRIBUTION FILTER (NEW) ===
# Skip if high volume + bearish direction = likely distribution (selling)
if volume_ratio >= min_volume_multiple and direction == "bearish":
# This is likely DISTRIBUTION - smart money selling, not accumulation
# Return None to filter it out
return None
# Filter: High volume multiple AND low price change (accumulation signal)
if volume_ratio >= min_volume_multiple and abs(price_change_pct) < max_price_change:
# Determine signal type
if abs(price_change_pct) < 2.0:
# Determine signal type with direction context
if direction == "bullish" and abs(price_change_pct) < 3.0:
signal = "strong_accumulation" # Best signal: high volume, rising intraday
elif abs(price_change_pct) < 2.0:
signal = "accumulation"
elif abs(price_change_pct) < 5.0:
signal = "moderate_activity"
else:
signal = "building_momentum"
return {
"ticker": ticker.upper(),
"volume": int(current_volume),
"price": round(float(current_price), 2),
"price_change_pct": round(price_change_pct, 2),
"intraday_change_pct": round(intraday_change_pct, 2),
"direction": direction,
"volume_ratio": round(volume_ratio, 2),
"avg_volume": int(avg_volume),
"signal": signal
"signal": signal,
}
return None
except Exception:
return None
def _download_ticker_history(
ticker: str,
history_period_days: int = 90
ticker: str, history_period_days: int = 90
) -> Optional[List[Dict[str, Union[str, float, int]]]]:
"""
Download raw history for a ticker and return cache-friendly records.
@ -202,8 +405,7 @@ def _download_ticker_history(
List of history records or None if insufficient data
"""
try:
stock = yf.Ticker(ticker.upper())
hist = stock.history(period=f"{history_period_days}d")
hist = get_ticker_history(ticker, period=f"{history_period_days}d")
if hist.empty:
return None
@ -239,7 +441,7 @@ def download_volume_data(
Returns:
Dict mapping ticker symbols to their history records
"""
today = datetime.now().strftime('%Y-%m-%d')
today = datetime.now().strftime("%Y-%m-%d")
# Get cache path (we always need it for saving)
cache_path = _get_cache_path(cache_key)
@ -249,16 +451,16 @@ def download_volume_data(
cached_data = _load_cache(cache_path)
# Check if cache is fresh (from today)
if cached_data and cached_data.get('date') == today:
print(f" Using cached volume data from {cached_data['date']}")
return cached_data['raw_data']
if cached_data and cached_data.get("date") == today:
logger.info(f"Using cached volume data from {cached_data['date']}")
return cached_data["raw_data"]
elif cached_data:
print(f" Cache is stale (from {cached_data.get('date')}), re-downloading...")
logger.info(f"Cache is stale (from {cached_data.get('date')}), re-downloading...")
else:
print(f" Skipping cache (use_cache=False), forcing fresh download...")
logger.info("Skipping cache (use_cache=False), forcing fresh download...")
# Download fresh data
print(f" Downloading {history_period_days} days of volume data for {len(tickers)} tickers...")
logger.info(f"Downloading {history_period_days} days of volume data for {len(tickers)} tickers...")
raw_data = {}
with ThreadPoolExecutor(max_workers=15) as executor:
@ -271,7 +473,7 @@ def download_volume_data(
for future in as_completed(futures):
completed += 1
if completed % 50 == 0:
print(f" Progress: {completed}/{len(tickers)} tickers downloaded...")
logger.info(f"Progress: {completed}/{len(tickers)} tickers downloaded...")
ticker_symbol = futures[future].upper()
history_records = future.result()
@ -280,7 +482,7 @@ def download_volume_data(
# Always save fresh data to cache (so it's available next time)
if cache_path and raw_data:
print(f" Saving {len(raw_data)} tickers to cache...")
logger.info(f"Saving {len(raw_data)} tickers to cache...")
_save_cache(cache_path, raw_data, today)
return raw_data
@ -294,7 +496,8 @@ def get_unusual_volume(
tickers: Annotated[Optional[List[str]], "Custom ticker list or None to use config file"] = None,
max_tickers_to_scan: Annotated[int, "Maximum number of tickers to scan"] = 3000,
use_cache: Annotated[bool, "Use cached raw data when available"] = True,
) -> str:
return_structured: Annotated[bool, "Return list of dicts instead of markdown"] = False,
):
"""
Find stocks with unusual volume but minimal price movement.
@ -309,13 +512,15 @@ def get_unusual_volume(
tickers: Custom list of ticker symbols, or None to load from config file
max_tickers_to_scan: Maximum number of tickers to scan (default: 3000, scans all)
use_cache: Whether to reuse/save cached raw data
return_structured: If True, returns list of candidate dicts instead of markdown
Returns:
Formatted markdown report of stocks with unusual volume
If return_structured=True: list of candidate dicts with ticker, volume_ratio, signal, etc.
If return_structured=False: Formatted markdown report
"""
try:
lookback_days = 30
today = datetime.now().strftime('%Y-%m-%d')
today = datetime.now().strftime("%Y-%m-%d")
analysis_date = date or today
ticker_list = _get_ticker_universe(tickers=tickers, max_tickers=max_tickers_to_scan)
@ -327,15 +532,13 @@ def get_unusual_volume(
# Create cache key from ticker list or "default"
if isinstance(tickers, list):
import hashlib
cache_key = "custom_" + hashlib.md5(",".join(sorted(tickers)).encode()).hexdigest()[:8]
else:
cache_key = "default"
raw_data = download_volume_data(
tickers=ticker_list,
history_period_days=90,
use_cache=use_cache,
cache_key=cache_key
tickers=ticker_list, history_period_days=90, use_cache=use_cache, cache_key=cache_key
)
if not raw_data:
@ -352,38 +555,52 @@ def get_unusual_volume(
history_records,
min_volume_multiple,
max_price_change,
lookback_days=lookback_days
lookback_days=lookback_days,
)
if candidate:
unusual_candidates.append(candidate)
if not unusual_candidates:
if return_structured:
return []
return f"No stocks found with unusual volume patterns matching criteria\n\nScanned {len(ticker_list)} tickers."
# Sort by volume ratio (highest first)
sorted_candidates = sorted(
unusual_candidates,
key=lambda x: (x.get("volume_ratio", 0), x["volume"]),
reverse=True
unusual_candidates, key=lambda x: (x.get("volume_ratio", 0), x["volume"]), reverse=True
)
# Take top N for display
sorted_candidates = sorted_candidates[:top_n]
# Return structured data if requested
if return_structured:
return sorted_candidates
# Format output
report = f"# Unusual Volume Detected - {analysis_date}\n\n"
report += f"**Criteria**: \n"
report += "**Criteria**: \n"
report += f"- Price Change: <{max_price_change}% (accumulation pattern)\n"
report += f"- Volume Multiple: Current volume ≥ {min_volume_multiple}x 30-day average\n"
report += f"- Tickers Scanned: {ticker_count}\n\n"
report += f"**Found**: {len(sorted_candidates)} stocks with unusual activity\n\n"
report += "## Top Unusual Volume Candidates\n\n"
report += "| Ticker | Price | Volume | Avg Volume | Volume Ratio | Price Change % | Signal |\n"
report += "|--------|-------|--------|------------|--------------|----------------|--------|\n"
report += (
"| Ticker | Price | Volume | Avg Volume | Volume Ratio | Price Change % | Signal |\n"
)
report += (
"|--------|-------|--------|------------|--------------|----------------|--------|\n"
)
for candidate in sorted_candidates:
volume_ratio_str = f"{candidate.get('volume_ratio', 'N/A')}x" if candidate.get('volume_ratio') else "N/A"
avg_vol_str = f"{candidate.get('avg_volume', 0):,}" if candidate.get('avg_volume') else "N/A"
volume_ratio_str = (
f"{candidate.get('volume_ratio', 'N/A')}x"
if candidate.get("volume_ratio")
else "N/A"
)
avg_vol_str = (
f"{candidate.get('avg_volume', 0):,}" if candidate.get("avg_volume") else "N/A"
)
report += f"| {candidate['ticker']} | "
report += f"${candidate['price']:.2f} | "
report += f"{candidate['volume']:,} | "
@ -393,13 +610,19 @@ def get_unusual_volume(
report += f"{candidate['signal']} |\n"
report += "\n\n## Signal Definitions\n\n"
report += "- **strong_accumulation**: High volume + bullish intraday direction - Strongest buy signal\n"
report += "- **accumulation**: High volume, minimal price change (<2%) - Smart money building position\n"
report += "- **moderate_activity**: Elevated volume with 2-5% price change - Early momentum\n"
report += (
"- **moderate_activity**: Elevated volume with 2-5% price change - Early momentum\n"
)
report += "- **building_momentum**: High volume with moderate price change - Conviction building\n"
report += "\n**Note**: Distribution patterns (high volume + bearish direction) are automatically filtered out.\n"
return report
except Exception as e:
if return_structured:
return []
return f"Unexpected error in unusual volume detection: {str(e)}"
@ -414,11 +637,5 @@ def get_alpha_vantage_unusual_volume(
) -> str:
"""Alias for get_unusual_volume to match registry naming convention"""
return get_unusual_volume(
date,
min_volume_multiple,
max_price_change,
top_n,
tickers,
max_tickers_to_scan,
use_cache
date, min_volume_multiple, max_price_change, top_n, tickers, max_tickers_to_scan, use_cache
)

View File

@ -1,6 +1,7 @@
import tradingagents.default_config as default_config
from typing import Dict, Optional
import tradingagents.default_config as default_config
# Use default config but allow it to be overridden
_config: Optional[Dict] = None
DATA_DIR: Optional[str] = None

View File

@ -0,0 +1,147 @@
"""
Delisted Cache System
---------------------
Track tickers that consistently fail data fetches (likely delisted).
SAFETY: Only cache tickers that:
- Passed initial format validation (not units/warrants/common words)
- Failed multiple times over multiple days
- Have consistent failure patterns (not temporary API issues)
"""
import json
from datetime import datetime
from pathlib import Path
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
class DelistedCache:
"""
Track tickers that consistently fail data fetches (likely delisted).
SAFETY: Only cache tickers that:
- Passed initial format validation (not units/warrants/common words)
- Failed multiple times over multiple days
- Have consistent failure patterns (not temporary API issues)
"""
def __init__(self, cache_file="data/delisted_cache.json"):
self.cache_file = Path(cache_file)
self.cache = self._load_cache()
def _load_cache(self):
if self.cache_file.exists():
with open(self.cache_file, "r") as f:
return json.load(f)
return {}
def mark_failed(self, ticker, reason="no_data", error_code=None):
"""
Record a failed data fetch for a ticker.
Args:
ticker: Stock symbol
reason: Human-readable failure reason
error_code: Specific error (e.g., "404", "no_price_data", "empty_history")
"""
ticker = ticker.upper()
if ticker not in self.cache:
self.cache[ticker] = {
"first_failed": datetime.now().isoformat(),
"last_failed": datetime.now().isoformat(),
"fail_count": 1,
"reason": reason,
"error_code": error_code,
"fail_dates": [datetime.now().date().isoformat()],
}
else:
self.cache[ticker]["fail_count"] += 1
self.cache[ticker]["last_failed"] = datetime.now().isoformat()
self.cache[ticker]["reason"] = reason # Update to latest reason
# Track unique failure dates
today = datetime.now().date().isoformat()
if today not in self.cache[ticker].get("fail_dates", []):
self.cache[ticker].setdefault("fail_dates", []).append(today)
self._save_cache()
def is_likely_delisted(self, ticker, fail_threshold=5, days_threshold=14, min_unique_days=3):
"""
Conservative check: ticker must fail multiple times across multiple days.
Args:
fail_threshold: Minimum number of total failures (default: 5)
days_threshold: Must have failed within this many days (default: 14)
min_unique_days: Must have failed on at least this many different days (default: 3)
Returns:
bool: True if ticker is likely delisted
"""
ticker = ticker.upper()
if ticker not in self.cache:
return False
data = self.cache[ticker]
last_failed = datetime.fromisoformat(data["last_failed"])
days_since = (datetime.now() - last_failed).days
# Count unique failure days
unique_fail_days = len(set(data.get("fail_dates", [])))
# Conservative criteria:
# - Must have failed at least 5 times
# - Must have failed on at least 3 different days (not just repeated same-day attempts)
# - Last failure within 14 days (don't cache stale data)
return (
data["fail_count"] >= fail_threshold
and unique_fail_days >= min_unique_days
and days_since <= days_threshold
)
def get_failure_summary(self, ticker):
"""Get detailed failure info for manual review."""
ticker = ticker.upper()
if ticker not in self.cache:
return None
data = self.cache[ticker]
return {
"ticker": ticker,
"fail_count": data["fail_count"],
"unique_days": len(set(data.get("fail_dates", []))),
"first_failed": data["first_failed"],
"last_failed": data["last_failed"],
"reason": data["reason"],
"is_likely_delisted": self.is_likely_delisted(ticker),
}
def _save_cache(self):
self.cache_file.parent.mkdir(parents=True, exist_ok=True)
with open(self.cache_file, "w") as f:
json.dump(self.cache, f, indent=2)
def export_review_list(self, output_file="data/delisted_review.txt"):
"""Export tickers that need manual review to add to DELISTED_TICKERS."""
likely_delisted = [
ticker for ticker in self.cache.keys() if self.is_likely_delisted(ticker)
]
if not likely_delisted:
return
with open(output_file, "w") as f:
f.write(
"# Tickers that have failed consistently (review before adding to DELISTED_TICKERS)\n\n"
)
for ticker in sorted(likely_delisted):
summary = self.get_failure_summary(ticker)
f.write(
f"{ticker:8s} - Failed {summary['fail_count']:2d} times across {summary['unique_days']} days - {summary['reason']}\n"
)
logger.info(f"📝 Review list exported to: {output_file}")

View File

@ -5,6 +5,10 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
class DiscoveryAnalytics:
"""
@ -18,10 +22,10 @@ class DiscoveryAnalytics:
def update_performance_tracking(self):
"""Update performance metrics for all open recommendations."""
print("📊 Updating recommendation performance tracking...")
logger.info("📊 Updating recommendation performance tracking...")
if not self.recommendations_dir.exists():
print(" No historical recommendations to track yet.")
logger.info("No historical recommendations to track yet.")
return
# Load all recommendations
@ -44,15 +48,15 @@ class DiscoveryAnalytics:
)
all_recs.append(rec)
except Exception as e:
print(f" Warning: Error loading {filepath}: {e}")
logger.warning(f"Error loading {filepath}: {e}")
if not all_recs:
print(" No recommendations found to track.")
logger.info("No recommendations found to track.")
return
# Filter to only track open positions
open_recs = [r for r in all_recs if r.get("status") != "closed"]
print(f" Tracking {len(open_recs)} open positions (out of {len(all_recs)} total)...")
logger.info(f"Tracking {len(open_recs)} open positions (out of {len(all_recs)} total)...")
# Update performance
today = datetime.now().strftime("%Y-%m-%d")
@ -109,10 +113,10 @@ class DiscoveryAnalytics:
pass
if updated_count > 0:
print(f" Updated {updated_count} positions")
logger.info(f"Updated {updated_count} positions")
self._save_performance_db(all_recs)
else:
print(" No updates needed")
logger.info("No updates needed")
def _save_performance_db(self, all_recs: List[Dict]):
"""Save the aggregated performance database and recalculate stats."""
@ -142,7 +146,7 @@ class DiscoveryAnalytics:
with open(stats_path, "w") as f:
json.dump(stats, f, indent=2)
print(" 💾 Updated performance database and statistics")
logger.info("💾 Updated performance database and statistics")
def calculate_statistics(self, recommendations: list) -> dict:
"""Calculate aggregate statistics from historical performance."""
@ -259,7 +263,7 @@ class DiscoveryAnalytics:
return insights
except Exception as e:
print(f" Warning: Could not load historical stats: {e}")
logger.warning(f"Could not load historical stats: {e}")
return {"available": False, "message": "Error loading historical data"}
def format_stats_summary(self, stats: dict) -> str:
@ -315,7 +319,7 @@ class DiscoveryAnalytics:
try:
entry_price = get_stock_price(ticker, curr_date=trade_date)
except Exception as e:
print(f" Warning: Could not get entry price for {ticker}: {e}")
logger.warning(f"Could not get entry price for {ticker}: {e}")
entry_price = None
enriched_rankings.append(
@ -345,7 +349,7 @@ class DiscoveryAnalytics:
indent=2,
)
print(f" 📊 Saved {len(enriched_rankings)} recommendations for tracking: {output_file}")
logger.info(f" 📊 Saved {len(enriched_rankings)} recommendations for tracking: {output_file}")
def save_discovery_results(self, state: dict, trade_date: str, config: Dict[str, Any]):
"""Save full discovery results and tool logs."""
@ -390,7 +394,7 @@ class DiscoveryAnalytics:
f.write(f"- **{ticker}** ({strategy})\n")
except Exception as e:
print(f" Error saving results: {e}")
logger.error(f"Error saving results: {e}")
# Save as JSON
try:
@ -404,19 +408,17 @@ class DiscoveryAnalytics:
}
json.dump(json_state, f, indent=2)
except Exception as e:
print(f" Error saving JSON: {e}")
logger.error(f"Error saving JSON: {e}")
# Save tool logs
tool_logs = state.get("tool_logs", [])
if tool_logs:
tool_log_max_chars = (
config.get("discovery", {}).get("tool_log_max_chars", 10_000)
if config
else 10_000
config.get("discovery", {}).get("tool_log_max_chars", 10_000) if config else 10_000
)
self._save_tool_logs(results_dir, tool_logs, trade_date, tool_log_max_chars)
print(f" Results saved to: {results_dir}")
logger.info(f" Results saved to: {results_dir}")
def _write_ranking_md(self, f, final_ranking):
try:
@ -513,4 +515,4 @@ class DiscoveryAnalytics:
f.write(f"### Output\n```\n{output}\n```\n\n")
f.write("---\n\n")
except Exception as e:
print(f" Error saving tool logs: {e}")
logger.error(f"Error saving tool logs: {e}")

View File

@ -1,9 +1,11 @@
"""Common utilities for discovery scanners."""
import re
import logging
from typing import List, Set, Optional
logger = logging.getLogger(__name__)
import re
from typing import List, Optional, Set
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
def get_common_stopwords() -> Set[str]:
@ -14,23 +16,84 @@ def get_common_stopwords() -> Set[str]:
"""
return {
# Common words
'THE', 'AND', 'FOR', 'ARE', 'BUT', 'NOT', 'YOU', 'ALL', 'CAN',
'HER', 'WAS', 'ONE', 'OUR', 'OUT', 'DAY', 'WHO', 'HAS', 'HAD',
'NEW', 'NOW', 'GET', 'GOT', 'PUT', 'SET', 'RUN', 'TOP', 'BIG',
"THE",
"AND",
"FOR",
"ARE",
"BUT",
"NOT",
"YOU",
"ALL",
"CAN",
"HER",
"WAS",
"ONE",
"OUR",
"OUT",
"DAY",
"WHO",
"HAS",
"HAD",
"NEW",
"NOW",
"GET",
"GOT",
"PUT",
"SET",
"RUN",
"TOP",
"BIG",
# Financial terms
'CEO', 'CFO', 'CTO', 'COO', 'USD', 'USA', 'SEC', 'IPO', 'ETF',
'NYSE', 'NASDAQ', 'WSB', 'DD', 'YOLO', 'FD', 'ATH', 'ATL', 'GDP',
'STOCK', 'STOCKS', 'MARKET', 'NEWS', 'PRICE', 'TRADE', 'SALES',
"CEO",
"CFO",
"CTO",
"COO",
"USD",
"USA",
"SEC",
"IPO",
"ETF",
"NYSE",
"NASDAQ",
"WSB",
"DD",
"YOLO",
"FD",
"ATH",
"ATL",
"GDP",
"STOCK",
"STOCKS",
"MARKET",
"NEWS",
"PRICE",
"TRADE",
"SALES",
# Time
'JAN', 'FEB', 'MAR', 'APR', 'MAY', 'JUN', 'JUL', 'AUG', 'SEP',
'OCT', 'NOV', 'DEC', 'MON', 'TUE', 'WED', 'THU', 'FRI', 'SAT', 'SUN',
"JAN",
"FEB",
"MAR",
"APR",
"MAY",
"JUN",
"JUL",
"AUG",
"SEP",
"OCT",
"NOV",
"DEC",
"MON",
"TUE",
"WED",
"THU",
"FRI",
"SAT",
"SUN",
}
def extract_tickers_from_text(
text: str,
stop_words: Optional[Set[str]] = None,
max_text_length: int = 100_000
text: str, stop_words: Optional[Set[str]] = None, max_text_length: int = 100_000
) -> List[str]:
"""Extract valid ticker symbols from text.
@ -51,13 +114,11 @@ def extract_tickers_from_text(
"""
# Truncate oversized text to prevent ReDoS
if len(text) > max_text_length:
logger.warning(
f"Truncating oversized text from {len(text)} to {max_text_length} chars"
)
logger.warning(f"Truncating oversized text from {len(text)} to {max_text_length} chars")
text = text[:max_text_length]
# Match: $TICKER or standalone TICKER (2-5 uppercase letters)
ticker_pattern = r'\b([A-Z]{2,5})\b|\$([A-Z]{2,5})'
ticker_pattern = r"\b([A-Z]{2,5})\b|\$([A-Z]{2,5})"
matches = re.findall(ticker_pattern, text)
# Flatten tuples and deduplicate
@ -82,7 +143,7 @@ def validate_ticker_format(ticker: str) -> bool:
if not ticker or not isinstance(ticker, str):
return False
return bool(re.match(r'^[A-Z]{2,5}$', ticker.strip().upper()))
return bool(re.match(r"^[A-Z]{2,5}$", ticker.strip().upper()))
def validate_candidate_structure(candidate: dict) -> bool:
@ -94,7 +155,7 @@ def validate_candidate_structure(candidate: dict) -> bool:
Returns:
True if candidate has all required keys with valid types
"""
required_keys = {'ticker', 'source', 'context', 'priority'}
required_keys = {"ticker", "source", "context", "priority"}
if not isinstance(candidate, dict):
return False
@ -105,12 +166,12 @@ def validate_candidate_structure(candidate: dict) -> bool:
return False
# Validate ticker format
if not validate_ticker_format(candidate.get('ticker', '')):
if not validate_ticker_format(candidate.get("ticker", "")):
logger.warning(f"Invalid ticker format: {candidate.get('ticker')}")
return False
# Validate priority is string
if not isinstance(candidate.get('priority'), str):
if not isinstance(candidate.get("priority"), str):
logger.warning(f"Invalid priority type: {type(candidate.get('priority'))}")
return False

View File

@ -0,0 +1,210 @@
"""Typed discovery configuration — single source of truth for all discovery consumers."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List
@dataclass
class FilterConfig:
"""Filter-stage settings (from discovery.filters.*)."""
min_average_volume: int = 500_000
volume_lookback_days: int = 10
filter_same_day_movers: bool = True
intraday_movement_threshold: float = 10.0
filter_recent_movers: bool = True
recent_movement_lookback_days: int = 7
recent_movement_threshold: float = 10.0
recent_mover_action: str = "filter"
# Volume / compression detection
volume_cache_key: str = "default"
min_market_cap: int = 0
compression_atr_pct_max: float = 2.0
compression_bb_width_max: float = 6.0
compression_min_volume_ratio: float = 1.3
@dataclass
class EnrichmentConfig:
"""Enrichment-stage settings (from discovery.enrichment.*)."""
batch_news_vendor: str = "google"
batch_news_batch_size: int = 150
news_lookback_days: float = 0.5
context_max_snippets: int = 2
context_snippet_max_chars: int = 140
earnings_lookforward_days: int = 30
@dataclass
class RankerConfig:
"""Ranker settings (from discovery root level)."""
max_candidates_to_analyze: int = 200
analyze_all_candidates: bool = False
final_recommendations: int = 15
truncate_ranking_context: bool = False
max_news_chars: int = 500
max_insider_chars: int = 300
max_recommendations_chars: int = 300
@dataclass
class ChartConfig:
"""Console price chart settings (from discovery root level)."""
enabled: bool = True
library: str = "plotille"
windows: List[str] = field(default_factory=lambda: ["1d", "7d", "1m", "6m", "1y"])
lookback_days: int = 30
width: int = 60
height: int = 12
max_tickers: int = 10
show_movement_stats: bool = True
@dataclass
class LoggingConfig:
"""Tool execution logging settings (from discovery root level)."""
log_tool_calls: bool = True
log_tool_calls_console: bool = False
log_prompts_console: bool = False # Show LLM prompts in console (always saved to log file)
tool_log_max_chars: int = 10_000
tool_log_exclude: List[str] = field(default_factory=lambda: ["validate_ticker"])
@dataclass
class DiscoveryConfig:
"""
Consolidated discovery configuration.
All defaults match ``default_config.py``. Consumers should create an
instance via ``DiscoveryConfig.from_config(raw_config)`` rather than
reaching into the raw dict themselves.
"""
# Nested configs
filters: FilterConfig = field(default_factory=FilterConfig)
enrichment: EnrichmentConfig = field(default_factory=EnrichmentConfig)
ranker: RankerConfig = field(default_factory=RankerConfig)
charts: ChartConfig = field(default_factory=ChartConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig)
# Flat settings at discovery root level
deep_dive_max_workers: int = 1
discovery_mode: str = "hybrid"
@classmethod
def from_config(cls, raw_config: Dict[str, Any]) -> DiscoveryConfig:
"""Build a ``DiscoveryConfig`` from the raw application config dict."""
disc = raw_config.get("discovery", {})
# Default instances — used to read fallback values for fields that
# use default_factory (which aren't available as class-level attrs).
_fd = FilterConfig()
_ed = EnrichmentConfig()
_rd = RankerConfig()
_cd = ChartConfig()
_ld = LoggingConfig()
# Filters — nested under "filters" key, fallback to root for old configs
f = disc.get("filters", disc)
filters = FilterConfig(
min_average_volume=f.get("min_average_volume", _fd.min_average_volume),
volume_lookback_days=f.get("volume_lookback_days", _fd.volume_lookback_days),
filter_same_day_movers=f.get("filter_same_day_movers", _fd.filter_same_day_movers),
intraday_movement_threshold=f.get(
"intraday_movement_threshold", _fd.intraday_movement_threshold
),
filter_recent_movers=f.get("filter_recent_movers", _fd.filter_recent_movers),
recent_movement_lookback_days=f.get(
"recent_movement_lookback_days", _fd.recent_movement_lookback_days
),
recent_movement_threshold=f.get(
"recent_movement_threshold", _fd.recent_movement_threshold
),
recent_mover_action=f.get("recent_mover_action", _fd.recent_mover_action),
volume_cache_key=f.get("volume_cache_key", _fd.volume_cache_key),
min_market_cap=f.get("min_market_cap", _fd.min_market_cap),
compression_atr_pct_max=f.get("compression_atr_pct_max", _fd.compression_atr_pct_max),
compression_bb_width_max=f.get(
"compression_bb_width_max", _fd.compression_bb_width_max
),
compression_min_volume_ratio=f.get(
"compression_min_volume_ratio", _fd.compression_min_volume_ratio
),
)
# Enrichment — nested under "enrichment" key, fallback to root
e = disc.get("enrichment", disc)
enrichment = EnrichmentConfig(
batch_news_vendor=e.get("batch_news_vendor", _ed.batch_news_vendor),
batch_news_batch_size=e.get("batch_news_batch_size", _ed.batch_news_batch_size),
news_lookback_days=e.get("news_lookback_days", _ed.news_lookback_days),
context_max_snippets=e.get("context_max_snippets", _ed.context_max_snippets),
context_snippet_max_chars=e.get(
"context_snippet_max_chars", _ed.context_snippet_max_chars
),
earnings_lookforward_days=e.get(
"earnings_lookforward_days", _ed.earnings_lookforward_days
),
)
# Ranker
ranker = RankerConfig(
max_candidates_to_analyze=disc.get(
"max_candidates_to_analyze", _rd.max_candidates_to_analyze
),
analyze_all_candidates=disc.get(
"analyze_all_candidates", _rd.analyze_all_candidates
),
final_recommendations=disc.get("final_recommendations", _rd.final_recommendations),
truncate_ranking_context=disc.get(
"truncate_ranking_context", _rd.truncate_ranking_context
),
max_news_chars=disc.get("max_news_chars", _rd.max_news_chars),
max_insider_chars=disc.get("max_insider_chars", _rd.max_insider_chars),
max_recommendations_chars=disc.get(
"max_recommendations_chars", _rd.max_recommendations_chars
),
)
# Charts — keys prefixed with "price_chart_" at discovery root level
charts = ChartConfig(
enabled=disc.get("console_price_charts", _cd.enabled),
library=disc.get("price_chart_library", _cd.library),
windows=disc.get("price_chart_windows", _cd.windows),
lookback_days=disc.get("price_chart_lookback_days", _cd.lookback_days),
width=disc.get("price_chart_width", _cd.width),
height=disc.get("price_chart_height", _cd.height),
max_tickers=disc.get("price_chart_max_tickers", _cd.max_tickers),
show_movement_stats=disc.get(
"price_chart_show_movement_stats", _cd.show_movement_stats
),
)
# Logging
logging_cfg = LoggingConfig(
log_tool_calls=disc.get("log_tool_calls", _ld.log_tool_calls),
log_tool_calls_console=disc.get(
"log_tool_calls_console", _ld.log_tool_calls_console
),
log_prompts_console=disc.get(
"log_prompts_console", _ld.log_prompts_console
),
tool_log_max_chars=disc.get("tool_log_max_chars", _ld.tool_log_max_chars),
tool_log_exclude=disc.get("tool_log_exclude", _ld.tool_log_exclude),
)
return cls(
filters=filters,
enrichment=enrichment,
ranker=ranker,
charts=charts,
logging=logging_cfg,
deep_dive_max_workers=disc.get("deep_dive_max_workers", 1),
discovery_mode=disc.get("discovery_mode", "hybrid"),
)

View File

@ -1,15 +1,21 @@
import json
import re
from datetime import timedelta
from datetime import datetime, timedelta
from typing import Any, Callable, Dict, List
import pandas as pd
from tradingagents.dataflows.discovery.candidate import Candidate
from tradingagents.dataflows.discovery.discovery_config import DiscoveryConfig
from tradingagents.dataflows.discovery.utils import (
PRIORITY_ORDER,
Strategy,
is_valid_ticker,
resolve_trade_date,
)
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
def _parse_market_cap_to_billions(value: Any) -> Any:
@ -107,34 +113,35 @@ class CandidateFilter:
self.config = config
self.execute_tool = tool_executor
# Discovery Settings
discovery_config = config.get("discovery", {})
dc = DiscoveryConfig.from_config(config)
# Filter settings (nested under "filters" section, with backward compatibility)
filter_config = discovery_config.get("filters", discovery_config) # Fallback to root for old configs
self.filter_same_day_movers = filter_config.get("filter_same_day_movers", True)
self.intraday_movement_threshold = filter_config.get("intraday_movement_threshold", 10.0)
self.filter_recent_movers = filter_config.get("filter_recent_movers", True)
self.recent_movement_lookback_days = filter_config.get("recent_movement_lookback_days", 7)
self.recent_movement_threshold = filter_config.get("recent_movement_threshold", 10.0)
self.recent_mover_action = filter_config.get("recent_mover_action", "filter")
self.min_average_volume = filter_config.get("min_average_volume", 500_000)
self.volume_lookback_days = filter_config.get("volume_lookback_days", 10)
# Filter settings
self.filter_same_day_movers = dc.filters.filter_same_day_movers
self.intraday_movement_threshold = dc.filters.intraday_movement_threshold
self.filter_recent_movers = dc.filters.filter_recent_movers
self.recent_movement_lookback_days = dc.filters.recent_movement_lookback_days
self.recent_movement_threshold = dc.filters.recent_movement_threshold
self.recent_mover_action = dc.filters.recent_mover_action
self.min_average_volume = dc.filters.min_average_volume
self.volume_lookback_days = dc.filters.volume_lookback_days
# Enrichment settings (nested under "enrichment" section, with backward compatibility)
enrichment_config = discovery_config.get("enrichment", discovery_config) # Fallback to root
self.batch_news_vendor = enrichment_config.get("batch_news_vendor", "openai")
self.batch_news_batch_size = enrichment_config.get("batch_news_batch_size", 50)
# Filter extras (volume/compression detection)
self.volume_cache_key = dc.filters.volume_cache_key
self.min_market_cap = dc.filters.min_market_cap
self.compression_atr_pct_max = dc.filters.compression_atr_pct_max
self.compression_bb_width_max = dc.filters.compression_bb_width_max
self.compression_min_volume_ratio = dc.filters.compression_min_volume_ratio
# Other settings (remain at discovery level)
self.news_lookback_days = discovery_config.get("news_lookback_days", 3)
self.volume_cache_key = discovery_config.get("volume_cache_key", "avg_volume_cache")
self.min_market_cap = discovery_config.get("min_market_cap", 0)
self.compression_atr_pct_max = discovery_config.get("compression_atr_pct_max", 2.0)
self.compression_bb_width_max = discovery_config.get("compression_bb_width_max", 6.0)
self.compression_min_volume_ratio = discovery_config.get("compression_min_volume_ratio", 1.3)
self.context_max_snippets = discovery_config.get("context_max_snippets", 2)
self.context_snippet_max_chars = discovery_config.get("context_snippet_max_chars", 140)
# Enrichment settings
self.batch_news_vendor = dc.enrichment.batch_news_vendor
self.batch_news_batch_size = dc.enrichment.batch_news_batch_size
self.news_lookback_days = dc.enrichment.news_lookback_days
self.context_max_snippets = dc.enrichment.context_max_snippets
self.context_snippet_max_chars = dc.enrichment.context_snippet_max_chars
# ML predictor (loaded lazily — None if no model file exists)
self._ml_predictor = None
self._ml_predictor_loaded = False
def filter(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""Filter candidates based on strategy and enrich with additional data."""
@ -150,7 +157,7 @@ class CandidateFilter:
start_date = start_date_obj.strftime("%Y-%m-%d")
end_date = end_date_obj.strftime("%Y-%m-%d")
print(f"🔍 Filtering and enriching {len(candidates)} candidates...")
logger.info(f"🔍 Filtering and enriching {len(candidates)} candidates...")
priority_order = self._priority_order()
candidates = self._dedupe_candidates(candidates, priority_order)
@ -178,12 +185,12 @@ class CandidateFilter:
# Print consolidated list of failed tickers
if failed_tickers:
print(f"\n ⚠️ {len(failed_tickers)} tickers failed data fetch (possibly delisted)")
logger.warning(f"⚠️ {len(failed_tickers)} tickers failed data fetch (possibly delisted)")
if len(failed_tickers) <= 10:
print(f" {', '.join(failed_tickers)}")
logger.warning(f"{', '.join(failed_tickers)}")
else:
print(
f" {', '.join(failed_tickers[:10])} ... and {len(failed_tickers)-10} more"
logger.warning(
f"{', '.join(failed_tickers[:10])} ... and {len(failed_tickers)-10} more"
)
# Export review list
delisted_cache.export_review_list()
@ -255,6 +262,16 @@ class CandidateFilter:
unique_candidates[ticker] = primary
# Compute confluence scores and boost priority for multi-source candidates
for candidate in unique_candidates.values():
source_count = len(candidate.all_sources)
candidate.extras["confluence_score"] = source_count
if source_count >= 3 and candidate.priority != "critical":
candidate.priority = "critical"
elif source_count >= 2 and candidate.priority in ("medium", "low", "unknown"):
candidate.priority = "high"
return [candidate.to_dict() for candidate in unique_candidates.values()]
def _sort_by_priority(
@ -268,8 +285,8 @@ class CandidateFilter:
high_priority = sum(1 for c in candidates if c.get("priority") == "high")
medium_priority = sum(1 for c in candidates if c.get("priority") == "medium")
low_priority = sum(1 for c in candidates if c.get("priority") == "low")
print(
f" Priority breakdown: {critical_priority} critical, {high_priority} high, {medium_priority} medium, {low_priority} low"
logger.info(
f"Priority breakdown: {critical_priority} critical, {high_priority} high, {medium_priority} medium, {low_priority} low"
)
def _fetch_batch_volume(
@ -299,7 +316,7 @@ class CandidateFilter:
if self.batch_news_vendor == "google":
from tradingagents.dataflows.openai import get_batch_stock_news_google
print(f" 📰 Batch fetching news (Google) for {len(all_tickers)} tickers...")
logger.info(f"📰 Batch fetching news (Google) for {len(all_tickers)} tickers...")
news_by_ticker = self._run_call(
"batch fetching news (Google)",
get_batch_stock_news_google,
@ -312,7 +329,7 @@ class CandidateFilter:
else: # Default to OpenAI
from tradingagents.dataflows.openai import get_batch_stock_news_openai
print(f" 📰 Batch fetching news (OpenAI) for {len(all_tickers)} tickers...")
logger.info(f"📰 Batch fetching news (OpenAI) for {len(all_tickers)} tickers...")
news_by_ticker = self._run_call(
"batch fetching news (OpenAI)",
get_batch_stock_news_openai,
@ -322,10 +339,10 @@ class CandidateFilter:
end_date=end_date,
batch_size=self.batch_news_batch_size,
)
print(f" ✓ Batch news fetched for {len(news_by_ticker)} tickers")
logger.info(f"✓ Batch news fetched for {len(news_by_ticker)} tickers")
return news_by_ticker
except Exception as e:
print(f" Warning: Batch news fetch failed, will skip news enrichment: {e}")
logger.warning(f"Batch news fetch failed, will skip news enrichment: {e}")
return {}
def _filter_and_enrich_candidates(
@ -368,8 +385,8 @@ class CandidateFilter:
if intraday_check.get("already_moved"):
filtered_reasons["intraday_moved"] += 1
intraday_pct = intraday_check.get("intraday_change_pct", 0)
print(
f" Filtered {ticker}: Already moved {intraday_pct:+.1f}% today (stale)"
logger.info(
f"Filtered {ticker}: Already moved {intraday_pct:+.1f}% today (stale)"
)
continue
@ -378,7 +395,7 @@ class CandidateFilter:
except Exception as e:
# Don't filter out if check fails, just log
print(f" Warning: Could not check intraday movement for {ticker}: {e}")
logger.warning(f"Could not check intraday movement for {ticker}: {e}")
# Recent multi-day mover filter (avoid stocks that already ran)
if self.filter_recent_movers:
@ -397,8 +414,8 @@ class CandidateFilter:
if self.recent_mover_action == "filter":
filtered_reasons["recent_moved"] += 1
change_pct = reaction.get("price_change_pct", 0)
print(
f" Filtered {ticker}: Already moved {change_pct:+.1f}% in last "
logger.info(
f"Filtered {ticker}: Already moved {change_pct:+.1f}% in last "
f"{self.recent_movement_lookback_days} days"
)
continue
@ -411,7 +428,7 @@ class CandidateFilter:
f"over {self.recent_movement_lookback_days}d"
)
except Exception as e:
print(f" Warning: Could not check recent movement for {ticker}: {e}")
logger.warning(f"Could not check recent movement for {ticker}: {e}")
# Liquidity filter based on average volume
if self.min_average_volume:
@ -482,13 +499,37 @@ class CandidateFilter:
cand["business_description"] = (
f"{company_name} - Business description not available."
)
# Extract short interest from fundamentals (no extra API call)
short_pct_raw = fund.get("ShortPercentOfFloat", fund.get("ShortPercentFloat"))
short_interest_pct = None
if short_pct_raw and short_pct_raw != "N/A":
try:
short_interest_pct = round(float(short_pct_raw) * 100, 2)
except (ValueError, TypeError):
pass
cand["short_interest_pct"] = short_interest_pct
cand["high_short_interest"] = (
short_interest_pct is not None and short_interest_pct > 15.0
)
short_ratio_raw = fund.get("ShortRatio")
if short_ratio_raw and short_ratio_raw != "N/A":
try:
cand["short_ratio"] = float(short_ratio_raw)
except (ValueError, TypeError):
cand["short_ratio"] = None
else:
cand["short_ratio"] = None
else:
cand["fundamentals"] = {}
cand["business_description"] = (
f"{ticker} - Business description not available."
)
cand["short_interest_pct"] = None
cand["high_short_interest"] = False
cand["short_ratio"] = None
except Exception as e:
print(f" Warning: Could not fetch fundamentals for {ticker}: {e}")
logger.warning(f"Could not fetch fundamentals for {ticker}: {e}")
delisted_cache.mark_failed(ticker, str(e))
failed_tickers.append(ticker)
cand["current_price"] = None
@ -630,10 +671,59 @@ class CandidateFilter:
else:
cand["has_bullish_options_flow"] = False
# Normalize options signal for quantitative scoring
cand["options_signal"] = cand.get("options_flow", {}).get("signal", "neutral")
# 5. Earnings Estimate Enrichment
from tradingagents.dataflows.finnhub_api import get_ticker_earnings_estimate
earnings_to = (
datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=30)
).strftime("%Y-%m-%d")
earnings_data = self._run_call(
"fetching earnings estimate",
get_ticker_earnings_estimate,
default={},
ticker=ticker,
from_date=end_date,
to_date=earnings_to,
)
if earnings_data.get("has_upcoming_earnings"):
cand["has_upcoming_earnings"] = True
cand["days_to_earnings"] = earnings_data.get("days_to_earnings")
cand["eps_estimate"] = earnings_data.get("eps_estimate")
cand["revenue_estimate"] = earnings_data.get("revenue_estimate")
cand["earnings_date"] = earnings_data.get("earnings_date")
else:
cand["has_upcoming_earnings"] = False
# Extract derived signals for quant scoring
tech_report = cand.get("technical_indicators", "")
rsi_match = re.search(
r"RSI.*?Value[:\s]*(\d+\.?\d*)", tech_report, re.IGNORECASE | re.DOTALL
)
if rsi_match:
cand["rsi_value"] = float(rsi_match.group(1))
insider_text = cand.get("insider_transactions", "")
cand["has_insider_buying"] = (
isinstance(insider_text, str) and "Purchase" in insider_text
)
# Compute quantitative pre-score
cand["quant_score"] = self._compute_quant_score(cand)
# ML win probability prediction (if model available)
ml_result = self._predict_ml(cand, ticker, end_date)
if ml_result:
cand["ml_win_probability"] = ml_result["win_prob"]
cand["ml_prediction"] = ml_result["prediction"]
cand["ml_loss_probability"] = ml_result["loss_prob"]
filtered_candidates.append(cand)
except Exception as e:
print(f" Error checking {ticker}: {e}")
logger.error(f"Error checking {ticker}: {e}")
return filtered_candidates, filtered_reasons, failed_tickers, delisted_cache
@ -643,19 +733,116 @@ class CandidateFilter:
filtered_candidates: List[Dict[str, Any]],
filtered_reasons: Dict[str, int],
) -> None:
print("\n 📊 Filtering Summary:")
print(f" Starting candidates: {len(candidates)}")
logger.info("\n 📊 Filtering Summary:")
logger.info(f" Starting candidates: {len(candidates)}")
if filtered_reasons.get("intraday_moved", 0) > 0:
print(f" ❌ Same-day movers: {filtered_reasons['intraday_moved']}")
logger.info(f" ❌ Same-day movers: {filtered_reasons['intraday_moved']}")
if filtered_reasons.get("recent_moved", 0) > 0:
print(f" ❌ Recent movers: {filtered_reasons['recent_moved']}")
logger.info(f" ❌ Recent movers: {filtered_reasons['recent_moved']}")
if filtered_reasons.get("volume", 0) > 0:
print(f" ❌ Low volume: {filtered_reasons['volume']}")
logger.info(f" ❌ Low volume: {filtered_reasons['volume']}")
if filtered_reasons.get("market_cap", 0) > 0:
print(f" ❌ Below market cap: {filtered_reasons['market_cap']}")
logger.info(f" ❌ Below market cap: {filtered_reasons['market_cap']}")
if filtered_reasons.get("no_data", 0) > 0:
print(f" ❌ No data available: {filtered_reasons['no_data']}")
print(f" ✅ Passed filters: {len(filtered_candidates)}")
logger.info(f" ❌ No data available: {filtered_reasons['no_data']}")
logger.info(f" ✅ Passed filters: {len(filtered_candidates)}")
def _predict_ml(
self, cand: Dict[str, Any], ticker: str, end_date: str
) -> Any:
"""Run ML win probability prediction for a candidate."""
# Lazy-load predictor on first call
if not self._ml_predictor_loaded:
self._ml_predictor_loaded = True
try:
from tradingagents.ml.predictor import MLPredictor
self._ml_predictor = MLPredictor.load()
if self._ml_predictor:
logger.info("ML predictor loaded — will add win probabilities")
except Exception as e:
logger.debug(f"ML predictor not available: {e}")
if self._ml_predictor is None:
return None
try:
from tradingagents.ml.feature_engineering import (
compute_features_single,
)
from tradingagents.dataflows.y_finance import download_history
# Fetch OHLCV for feature computation (needs ~210 rows of history)
ohlcv = download_history(
ticker,
start=pd.Timestamp(end_date) - pd.DateOffset(years=2),
end=end_date,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
if ohlcv.empty:
return None
ohlcv = ohlcv.reset_index()
market_cap = cand.get("market_cap_bil", 0)
market_cap_usd = market_cap * 1e9 if market_cap else None
features = compute_features_single(ohlcv, end_date, market_cap=market_cap_usd)
if features is None:
return None
return self._ml_predictor.predict(features)
except Exception as e:
logger.debug(f"ML prediction failed for {ticker}: {e}")
return None
def _compute_quant_score(self, cand: Dict[str, Any]) -> int:
"""Compute a 0-100 quantitative pre-score from hard data."""
score = 0
# Volume ratio (max +15)
vol_ratio = cand.get("volume_ratio")
if vol_ratio is not None:
if vol_ratio >= 2.0:
score += 15
elif vol_ratio >= 1.5:
score += 10
elif vol_ratio >= 1.3:
score += 5
# Confluence — per independent source, max 3 (max +30)
confluence = cand.get("confluence_score", 1)
score += min(confluence, 3) * 10
# Options flow signal (max +20)
options_signal = cand.get("options_signal", "neutral")
if options_signal == "very_bullish":
score += 20
elif options_signal == "bullish":
score += 15
# Insider buying detected (max +10)
if cand.get("has_insider_buying"):
score += 10
# Volatility compression with volume uptick (max +10)
if cand.get("has_volatility_compression"):
score += 10
# Healthy RSI momentum: 40-65 range (max +5)
rsi = cand.get("rsi_value")
if rsi is not None and 40 <= rsi <= 65:
score += 5
# Short squeeze potential: 5-20% short interest (max +5)
short_pct = cand.get("short_interest_pct")
if short_pct is not None and 5.0 <= short_pct <= 20.0:
score += 5
return min(score, 100)
def _run_tool(
self,
@ -674,7 +861,7 @@ class CandidateFilter:
**params,
)
except Exception as e:
print(f" Error during {step}: {e}")
logger.error(f"Error during {step}: {e}")
return default
def _run_call(
@ -687,7 +874,7 @@ class CandidateFilter:
try:
return func(**kwargs)
except Exception as e:
print(f" Error {label}: {e}")
logger.error(f"Error {label}: {e}")
return default
def _assign_strategy(self, cand: Dict[str, Any]):

View File

@ -6,9 +6,12 @@ Maintains complete price time-series and calculates real-time metrics.
"""
import json
import os
from datetime import datetime
from pathlib import Path
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
from typing import Any, Dict, List, Optional
@ -189,6 +192,6 @@ class PositionTracker:
open_positions.append(position)
except (json.JSONDecodeError, IOError) as e:
# Log error but continue loading other positions
print(f"Error loading position from {filepath}: {e}")
logger.error(f"Error loading position from {filepath}: {e}")
return open_positions

View File

@ -7,6 +7,7 @@ from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from tradingagents.dataflows.discovery.discovery_config import DiscoveryConfig
from tradingagents.dataflows.discovery.utils import append_llm_log, resolve_llm_name
from tradingagents.utils.logger import get_logger
@ -51,7 +52,7 @@ class StockRanking(BaseModel):
strategy_match: str = Field(description="Strategy that matched")
final_score: int = Field(description="Score 0-100")
confidence: int = Field(description="Confidence 1-10")
reason: str = Field(description="Investment thesis")
reason: str = Field(description="Detailed investment thesis (4-6 sentences) defending the trade with specific catalysts, risk/reward, and timing")
description: str = Field(description="Company description")
@ -71,15 +72,18 @@ class CandidateRanker:
self.llm = llm
self.analytics = analytics
discovery_config = config.get("discovery", {})
self.max_candidates_to_analyze = discovery_config.get("max_candidates_to_analyze", 30)
self.final_recommendations = discovery_config.get("final_recommendations", 3)
dc = DiscoveryConfig.from_config(config)
self.max_candidates_to_analyze = dc.ranker.max_candidates_to_analyze
self.final_recommendations = dc.ranker.final_recommendations
# Truncation settings
self.truncate_context = discovery_config.get("truncate_ranking_context", False)
self.max_news_chars = discovery_config.get("max_news_chars", 500)
self.max_insider_chars = discovery_config.get("max_insider_chars", 300)
self.max_recommendations_chars = discovery_config.get("max_recommendations_chars", 300)
self.truncate_context = dc.ranker.truncate_ranking_context
self.max_news_chars = dc.ranker.max_news_chars
self.max_insider_chars = dc.ranker.max_insider_chars
self.max_recommendations_chars = dc.ranker.max_recommendations_chars
# Prompt logging
self.log_prompts_console = dc.logging.log_prompts_console
def rank(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""Rank all filtered candidates and select the top opportunities."""
@ -87,7 +91,7 @@ class CandidateRanker:
trade_date = state.get("trade_date", datetime.now().strftime("%Y-%m-%d"))
if len(candidates) == 0:
print("⚠️ No candidates to rank.")
logger.warning("⚠️ No candidates to rank.")
return {
"opportunities": [],
"final_ranking": "[]",
@ -98,20 +102,20 @@ class CandidateRanker:
# Limit candidates to prevent token overflow
max_candidates = min(self.max_candidates_to_analyze, 200)
if len(candidates) > max_candidates:
print(
f" ⚠️ Too many candidates ({len(candidates)}), limiting to top {max_candidates} by priority"
logger.warning(
f"⚠️ Too many candidates ({len(candidates)}), limiting to top {max_candidates} by priority"
)
candidates = candidates[:max_candidates]
print(
logger.info(
f"🏆 Ranking {len(candidates)} candidates to select top {self.final_recommendations}..."
)
# Load historical performance statistics
historical_stats = self.analytics.load_historical_stats()
if historical_stats.get("available"):
print(
f" 📊 Loaded historical stats: {historical_stats.get('total_tracked', 0)} tracked recommendations"
logger.info(
f"📊 Loaded historical stats: {historical_stats.get('total_tracked', 0)} tracked recommendations"
)
# Build RICH context for each candidate
@ -213,10 +217,41 @@ class CandidateRanker:
recommendations_text[: self.max_recommendations_chars] + "..."
)
# New enrichment fields
confluence_score = cand.get("confluence_score", 1)
quant_score = cand.get("quant_score", "N/A")
# ML prediction
ml_win_prob = cand.get("ml_win_probability")
ml_prediction = cand.get("ml_prediction")
if ml_win_prob is not None:
ml_str = f"{ml_win_prob:.1%} (Predicted: {ml_prediction})"
else:
ml_str = "N/A"
short_interest_pct = cand.get("short_interest_pct")
high_short = cand.get("high_short_interest", False)
short_str = f"{short_interest_pct:.1f}%" if short_interest_pct else "N/A"
if high_short:
short_str += " (HIGH)"
# Earnings estimate
if cand.get("has_upcoming_earnings"):
days = cand.get("days_to_earnings", "?")
eps_est = cand.get("eps_estimate")
rev_est = cand.get("revenue_estimate")
earnings_date = cand.get("earnings_date", "N/A")
eps_str = f"${eps_est:.2f}" if isinstance(eps_est, (int, float)) else "N/A"
rev_str = f"${rev_est:,.0f}" if isinstance(rev_est, (int, float)) else "N/A"
earnings_section = f"Earnings in {days} days ({earnings_date}): EPS Est {eps_str}, Rev Est {rev_str}"
else:
earnings_section = "No upcoming earnings within 30 days"
summary = f"""### {ticker} (Priority: {priority.upper()})
- **Strategy Match**: {strategy}
- **Sources**: {source_str}
- **Sources**: {source_str} | **Confluence**: {confluence_score} source(s)
- **Quant Pre-Score**: {quant_score}/100 | **ML Win Probability**: {ml_str}
- **Price**: {price_str} | **Current Price (numeric)**: {current_price if isinstance(current_price, (int, float)) else "N/A"} | **Intraday**: {intraday_str} | **Avg Volume**: {volume_str}
- **Short Interest**: {short_str}
- **Discovery Context**: {context}
- **Business**: {business_description}
- **News**: {news_summary}
@ -234,6 +269,8 @@ class CandidateRanker:
**Options Activity**:
{options_activity if options_activity else "N/A"}
**Upcoming Earnings**: {earnings_section}
"""
candidate_summaries.append(summary)
@ -256,12 +293,14 @@ CANDIDATES FOR REVIEW:
INSTRUCTIONS:
1. Analyze each candidate's "Discovery Context" (why it was found) and "Strategy Match".
2. Cross-reference with Technicals (RSI, etc.) and Fundamentals.
3. Prioritize "LEADING" indicators (Undiscovered DD, Earnings Accumulation, Insider Buying) over lagging ones.
4. Select exactly {self.final_recommendations} winners.
5. Use ONLY the information provided in the candidates section; do NOT invent catalysts, prices, or metrics.
6. If a required field is missing, set it to null (do not guess).
7. Rank only tickers from the candidates list.
8. Reasons must reference at least two concrete facts from the candidate context.
3. Use the Quantitative Pre-Score as an objective baseline. Scores above 50 indicate strong multi-factor alignment.
4. The ML Win Probability is a trained model's estimate that this stock hits +5% within 7 days. Treat scores above 60% as strong ML confirmation.
5. Prioritize "LEADING" indicators (Undiscovered DD, Earnings Accumulation, Insider Buying) over lagging ones.
6. Select exactly {self.final_recommendations} winners.
7. Use ONLY the information provided in the candidates section; do NOT invent catalysts, prices, or metrics.
8. If a required field is missing, set it to null (do not guess).
9. Rank only tickers from the candidates list.
10. Reasons must reference at least two concrete facts from the candidate context.
Output a JSON object with a 'rankings' list. Each item should have:
- rank: 1 to {self.final_recommendations}
@ -271,17 +310,20 @@ Output a JSON object with a 'rankings' list. Each item should have:
- strategy_match: main strategy
- final_score: 0-100 score
- confidence: 1-10 confidence level
- reason: Detailed investment thesis (2-3 sentences) explaining WHY this will move NOW.
- reason: Detailed investment thesis (4-6 sentences). Defend the trade: (1) what is the catalyst/edge, (2) why NOW and not later, (3) what does the risk/reward look like, (4) what could go wrong. Reference specific data points from the candidate context.
- description: Brief company description.
JSON FORMAT ONLY. No markdown, no extra text. All numeric fields must be numbers (not strings)."""
# Invoke LLM with structured output
print(" 🧠 Deep Thinking Ranker analyzing opportunities...")
logger.info("🧠 Deep Thinking Ranker analyzing opportunities...")
logger.info(
f"Invoking ranking LLM with {len(candidates)} candidates, prompt length: {len(prompt)} chars"
)
logger.debug(f"Full ranking prompt:\n{prompt}")
if self.log_prompts_console:
logger.info(f"Full ranking prompt:\n{prompt}")
else:
logger.debug(f"Full ranking prompt:\n{prompt}")
try:
# Use structured output with include_raw for debugging
@ -364,7 +406,7 @@ JSON FORMAT ONLY. No markdown, no extra text. All numeric fields must be numbers
final_ranking_list = [ranking.model_dump() for ranking in result.rankings]
print(f" ✅ Selected {len(final_ranking_list)} top recommendations")
logger.info(f"✅ Selected {len(final_ranking_list)} top recommendations")
logger.info(
f"Successfully ranked {len(final_ranking_list)} opportunities: "
f"{[r['ticker'] for r in final_ranking_list]}"
@ -407,7 +449,7 @@ JSON FORMAT ONLY. No markdown, no extra text. All numeric fields must be numbers
)
state["tool_logs"] = tool_logs
# Structured output validation failed
print(f" ❌ Error: {e}")
logger.error(f"❌ Error: {e}")
logger.error(f"Structured output validation error: {e}")
return {"final_ranking": [], "opportunities": [], "status": "ranking_failed"}
@ -423,7 +465,7 @@ JSON FORMAT ONLY. No markdown, no extra text. All numeric fields must be numbers
error=str(e),
)
state["tool_logs"] = tool_logs
print(f" ❌ Error during ranking: {e}")
logger.error(f"❌ Error during ranking: {e}")
logger.exception(f"Unexpected error during ranking: {e}")
return {"final_ranking": [], "opportunities": [], "status": "error"}

View File

@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type
import logging
logger = logging.getLogger(__name__)
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
class BaseScanner(ABC):
@ -43,9 +44,7 @@ class BaseScanner(ABC):
candidates = self.scan(state)
if not isinstance(candidates, list):
logger.error(
f"{self.name}: scan() returned {type(candidates)}, expected list"
)
logger.error(f"{self.name}: scan() returned {type(candidates)}, expected list")
return []
# Validate each candidate
@ -58,7 +57,7 @@ class BaseScanner(ABC):
else:
logger.warning(
f"{self.name}: Invalid candidate #{i}: {candidate}",
extra={"scanner": self.name, "pipeline": self.pipeline}
extra={"scanner": self.name, "pipeline": self.pipeline},
)
if len(valid_candidates) < len(candidates):
@ -76,8 +75,8 @@ class BaseScanner(ABC):
extra={
"scanner": self.name,
"pipeline": self.pipeline,
"error_type": type(e).__name__
}
"error_type": type(e).__name__,
},
)
return []
@ -101,12 +100,12 @@ class ScannerRegistry:
# Check for duplicate registration
if scanner_class.name in self.scanners:
logger.warning(
f"Scanner '{scanner_class.name}' already registered, overwriting"
)
logger.warning(f"Scanner '{scanner_class.name}' already registered, overwriting")
self.scanners[scanner_class.name] = scanner_class
logger.info(f"Registered scanner: {scanner_class.name} (pipeline: {scanner_class.pipeline})")
logger.info(
f"Registered scanner: {scanner_class.name} (pipeline: {scanner_class.pipeline})"
)
def get_scanners_by_pipeline(self, pipeline: str) -> List[Type[BaseScanner]]:
return [sc for sc in self.scanners.values() if sc.pipeline == pipeline]

View File

@ -12,6 +12,9 @@ from tradingagents.dataflows.discovery.utils import (
resolve_trade_date_str,
)
from tradingagents.schemas import RedditTickerList
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
@dataclass
@ -129,7 +132,7 @@ class TraditionalScanner:
try:
return spec.handler(state)
except Exception as e:
print(f" Error running scanner '{spec.name}': {e}")
logger.error(f"Error running scanner '{spec.name}': {e}")
return []
def _run_tool(
@ -149,7 +152,7 @@ class TraditionalScanner:
**params,
)
except Exception as e:
print(f" Error during {step}: {e}")
logger.error(f"Error during {step}: {e}")
return default
def _run_call(
@ -162,7 +165,7 @@ class TraditionalScanner:
try:
return func(**kwargs)
except Exception as e:
print(f" Error {label}: {e}")
logger.error(f"Error {label}: {e}")
return default
def _scan_reddit(self, state: Dict[str, Any]) -> List[Dict[str, Any]]:
@ -183,7 +186,7 @@ class TraditionalScanner:
try:
from tradingagents.dataflows.reddit_api import get_reddit_undiscovered_dd
print(" 🔍 Scanning Reddit for undiscovered DD...")
logger.info("🔍 Scanning Reddit for undiscovered DD...")
# Note: get_reddit_undiscovered_dd is not a tool in strict sense but a direct function call
# that uses an LLM. We call it directly here as in original code.
reddit_dd_report = self._run_call(
@ -195,7 +198,7 @@ class TraditionalScanner:
llm_evaluator=self.llm, # Use fast LLM for evaluation
)
except Exception as e:
print(f" Error fetching undiscovered DD: {e}")
logger.error(f"Error fetching undiscovered DD: {e}")
# BATCHED LLM CALL: Extract tickers from both Reddit sources in ONE call
# Uses proper Pydantic structured output for clean, validated results
@ -220,7 +223,9 @@ IMPORTANT RULES:
{reddit_dd_report}
"""
combined_prompt += """Extract ALL mentioned stock tickers with their source and context."""
combined_prompt += (
"""Extract ALL mentioned stock tickers with their source and context."""
)
# Use proper Pydantic structured output (not raw JSON schema)
structured_llm = self.llm.with_structured_output(RedditTickerList)
@ -276,8 +281,8 @@ IMPORTANT RULES:
)
trending_count += 1
print(
f" Found {trending_count} trending + {dd_count} DD tickers from Reddit "
logger.info(
f"Found {trending_count} trending + {dd_count} DD tickers from Reddit "
f"(skipped {skipped_low_confidence} low-confidence)"
)
except Exception as e:
@ -292,7 +297,7 @@ IMPORTANT RULES:
error=str(e),
)
state["tool_logs"] = tool_logs
print(f" Error extracting Reddit tickers: {e}")
logger.error(f"Error extracting Reddit tickers: {e}")
return candidates
@ -301,7 +306,7 @@ IMPORTANT RULES:
candidates: List[Dict[str, Any]] = []
from tradingagents.dataflows.alpha_vantage_stock import get_top_gainers_losers
print(" 📊 Fetching market movers (direct parsing)...")
logger.info("📊 Fetching market movers (direct parsing)...")
movers_data = self._run_call(
"fetching market movers",
get_top_gainers_losers,
@ -343,9 +348,9 @@ IMPORTANT RULES:
)
movers_count += 1
print(f" Found {movers_count} market movers (direct)")
logger.info(f"Found {movers_count} market movers (direct)")
else:
print(" Market movers returned error or empty")
logger.warning("Market movers returned error or empty")
return candidates
@ -361,7 +366,7 @@ IMPORTANT RULES:
from_date = today.strftime("%Y-%m-%d")
to_date = (today + timedelta(days=self.max_days_until_earnings)).strftime("%Y-%m-%d")
print(f" 📅 Fetching earnings calendar (next {self.max_days_until_earnings} days)...")
logger.info(f"📅 Fetching earnings calendar (next {self.max_days_until_earnings} days)...")
earnings_data = self._run_call(
"fetching earnings calendar",
get_earnings_calendar,
@ -465,8 +470,8 @@ IMPORTANT RULES:
}
)
print(
f" Found {len(earnings_candidates)} earnings candidates (filtered from {len(earnings_data)} total, cap: {self.max_earnings_candidates})"
logger.info(
f"Found {len(earnings_candidates)} earnings candidates (filtered from {len(earnings_data)} total, cap: {self.max_earnings_candidates})"
)
return candidates
@ -474,7 +479,7 @@ IMPORTANT RULES:
def _scan_ipo(self, state: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Fetch IPO calendar."""
candidates: List[Dict[str, Any]] = []
from datetime import datetime, timedelta
from datetime import timedelta
from tradingagents.dataflows.finnhub_api import get_ipo_calendar
@ -482,7 +487,7 @@ IMPORTANT RULES:
from_date = (today - timedelta(days=7)).strftime("%Y-%m-%d")
to_date = (today + timedelta(days=14)).strftime("%Y-%m-%d")
print(" 🆕 Fetching IPO calendar (direct parsing)...")
logger.info("🆕 Fetching IPO calendar (direct parsing)...")
ipo_data = self._run_call(
"fetching IPO calendar",
get_ipo_calendar,
@ -515,7 +520,7 @@ IMPORTANT RULES:
)
ipo_count += 1
print(f" Found {ipo_count} IPO candidates (direct)")
logger.info(f"Found {ipo_count} IPO candidates (direct)")
return candidates
@ -524,7 +529,7 @@ IMPORTANT RULES:
candidates: List[Dict[str, Any]] = []
from tradingagents.dataflows.finviz_scraper import get_short_interest
print(" 🩳 Fetching short interest (direct parsing)...")
logger.info("🩳 Fetching short interest (direct parsing)...")
short_data = self._run_call(
"fetching short interest",
get_short_interest,
@ -554,7 +559,7 @@ IMPORTANT RULES:
)
short_count += 1
print(f" Found {short_count} short squeeze candidates (direct)")
logger.info(f"Found {short_count} short squeeze candidates (direct)")
return candidates
@ -565,7 +570,7 @@ IMPORTANT RULES:
today = resolve_trade_date_str(state)
print(" 📈 Fetching unusual volume (direct parsing)...")
logger.info("📈 Fetching unusual volume (direct parsing)...")
volume_data = self._run_call(
"fetching unusual volume",
get_unusual_volume,
@ -593,7 +598,9 @@ IMPORTANT RULES:
# Build context with direction info
direction_emoji = "🟢" if direction == "bullish" else ""
context = f"Volume: {vol_ratio}x avg, Price: {price_change:+.1f}%, "
context += f"Intraday: {intraday_change:+.1f}% {direction_emoji}, Signal: {signal}"
context += (
f"Intraday: {intraday_change:+.1f}% {direction_emoji}, Signal: {signal}"
)
# Strong accumulation gets highest priority
priority = "critical" if signal == "strong_accumulation" else "high"
@ -608,7 +615,9 @@ IMPORTANT RULES:
)
volume_count += 1
print(f" Found {volume_count} unusual volume candidates (direct, distribution filtered)")
logger.info(
f"Found {volume_count} unusual volume candidates (direct, distribution filtered)"
)
return candidates
@ -618,7 +627,7 @@ IMPORTANT RULES:
from tradingagents.dataflows.alpha_vantage_analysts import get_analyst_rating_changes
from tradingagents.dataflows.y_finance import check_if_price_reacted
print(" 📊 Fetching analyst rating changes (direct parsing)...")
logger.info("📊 Fetching analyst rating changes (direct parsing)...")
analyst_data = self._run_call(
"fetching analyst rating changes",
get_analyst_rating_changes,
@ -639,9 +648,7 @@ IMPORTANT RULES:
hours_old = entry.get("hours_old") or 0
freshness = (
"🔥 FRESH"
if hours_old < 24
else "🟢 Recent" if hours_old < 72 else "Older"
"🔥 FRESH" if hours_old < 24 else "🟢 Recent" if hours_old < 72 else "Older"
)
context = f"{action.upper()} from {source} ({freshness}, {hours_old}h ago)"
@ -651,12 +658,12 @@ IMPORTANT RULES:
ticker, lookback_days=3, reaction_threshold=10.0
)
if reaction["status"] == "leading":
context += (
f" | 💎 EARLY: Price {reaction['price_change_pct']:+.1f}%"
)
context += f" | 💎 EARLY: Price {reaction['price_change_pct']:+.1f}%"
priority = "high"
elif reaction["status"] == "lagging":
context += f" | ⚠️ LATE: Already moved {reaction['price_change_pct']:+.1f}%"
context += (
f" | ⚠️ LATE: Already moved {reaction['price_change_pct']:+.1f}%"
)
priority = "low"
else:
priority = "medium"
@ -673,7 +680,7 @@ IMPORTANT RULES:
)
analyst_count += 1
print(f" Found {analyst_count} analyst upgrade candidates (direct)")
logger.info(f"Found {analyst_count} analyst upgrade candidates (direct)")
return candidates
@ -682,7 +689,7 @@ IMPORTANT RULES:
candidates: List[Dict[str, Any]] = []
from tradingagents.dataflows.finviz_scraper import get_insider_buying_screener
print(" 💰 Fetching insider buying (direct parsing)...")
logger.info("💰 Fetching insider buying (direct parsing)...")
insider_data = self._run_call(
"fetching insider buying",
get_insider_buying_screener,
@ -718,7 +725,7 @@ IMPORTANT RULES:
)
insider_count += 1
print(f" Found {insider_count} insider buying candidates (direct)")
logger.info(f"Found {insider_count} insider buying candidates (direct)")
return candidates
@ -749,10 +756,10 @@ IMPORTANT RULES:
]
removed = before_count - len(candidates)
if removed:
print(f" Removed {removed} invalid tickers after batch validation.")
logger.info(f"Removed {removed} invalid tickers after batch validation.")
else:
print(" Batch validation returned no valid tickers; skipping filter.")
logger.warning("Batch validation returned no valid tickers; skipping filter.")
except Exception as e:
print(f" Error during batch validation: {e}")
logger.error(f"Error during batch validation: {e}")
return candidates

View File

@ -1,11 +1,14 @@
"""Discovery scanners for modular pipeline architecture."""
# Import all scanners to trigger registration
from . import insider_buying # noqa: F401
from . import options_flow # noqa: F401
from . import reddit_trending # noqa: F401
from . import market_movers # noqa: F401
from . import volume_accumulation # noqa: F401
from . import semantic_news # noqa: F401
from . import reddit_dd # noqa: F401
from . import earnings_calendar # noqa: F401
from . import (
earnings_calendar, # noqa: F401
insider_buying, # noqa: F401
market_movers, # noqa: F401
options_flow, # noqa: F401
reddit_dd, # noqa: F401
reddit_trending, # noqa: F401
semantic_news, # noqa: F401
volume_accumulation, # noqa: F401
ml_signal, # noqa: F401
)

View File

@ -1,10 +1,14 @@
"""Earnings calendar scanner for upcoming earnings events."""
from typing import Any, Dict, List
from datetime import datetime, timedelta
from tradingagents.dataflows.discovery.scanner_registry import BaseScanner, SCANNER_REGISTRY
from datetime import datetime, timedelta
from typing import Any, Dict, List
from tradingagents.dataflows.discovery.scanner_registry import SCANNER_REGISTRY, BaseScanner
from tradingagents.dataflows.discovery.utils import Priority
from tradingagents.tools.executor import execute_tool
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
class EarningsCalendarScanner(BaseScanner):
@ -23,17 +27,19 @@ class EarningsCalendarScanner(BaseScanner):
if not self.is_enabled():
return []
print(f" 📅 Scanning earnings calendar (next {self.max_days_until_earnings} days)...")
logger.info(f"📅 Scanning earnings calendar (next {self.max_days_until_earnings} days)...")
try:
# Get earnings calendar from Finnhub or Alpha Vantage
from_date = datetime.now().strftime("%Y-%m-%d")
to_date = (datetime.now() + timedelta(days=self.max_days_until_earnings)).strftime("%Y-%m-%d")
to_date = (datetime.now() + timedelta(days=self.max_days_until_earnings)).strftime(
"%Y-%m-%d"
)
result = execute_tool("get_earnings_calendar", from_date=from_date, to_date=to_date)
if not result:
print(f" Found 0 earnings events")
logger.info("Found 0 earnings events")
return []
candidates = []
@ -55,21 +61,23 @@ class EarningsCalendarScanner(BaseScanner):
candidates.sort(key=lambda x: x.get("days_until", 999))
# Apply limit
candidates = candidates[:self.limit]
candidates = candidates[: self.limit]
print(f" Found {len(candidates)} upcoming earnings")
logger.info(f"Found {len(candidates)} upcoming earnings")
return candidates
except Exception as e:
print(f" ⚠️ Earnings calendar failed: {e}")
logger.warning(f"⚠️ Earnings calendar failed: {e}")
return []
def _parse_structured_earnings(self, earnings_list: List[Dict], seen_tickers: set) -> List[Dict[str, Any]]:
def _parse_structured_earnings(
self, earnings_list: List[Dict], seen_tickers: set
) -> List[Dict[str, Any]]:
"""Parse structured earnings data."""
candidates = []
today = datetime.now().date()
for event in earnings_list[:self.max_candidates * 2]:
for event in earnings_list[: self.max_candidates * 2]:
ticker = event.get("ticker", event.get("symbol", "")).upper()
if not ticker or ticker in seen_tickers:
continue
@ -82,7 +90,9 @@ class EarningsCalendarScanner(BaseScanner):
try:
# Parse date (handle different formats)
if isinstance(earnings_date_str, str):
earnings_date = datetime.strptime(earnings_date_str.split()[0], "%Y-%m-%d").date()
earnings_date = datetime.strptime(
earnings_date_str.split()[0], "%Y-%m-%d"
).date()
else:
earnings_date = earnings_date_str
@ -107,15 +117,19 @@ class EarningsCalendarScanner(BaseScanner):
else:
priority = Priority.LOW.value
candidates.append({
"ticker": ticker,
"source": self.name,
"context": f"Earnings in {days_until} day(s) on {earnings_date_str}",
"priority": priority,
"strategy": "pre_earnings_accumulation" if days_until > 1 else "earnings_play",
"days_until": days_until,
"earnings_date": earnings_date_str,
})
candidates.append(
{
"ticker": ticker,
"source": self.name,
"context": f"Earnings in {days_until} day(s) on {earnings_date_str}",
"priority": priority,
"strategy": (
"pre_earnings_accumulation" if days_until > 1 else "earnings_play"
),
"days_until": days_until,
"earnings_date": earnings_date_str,
}
)
if len(candidates) >= self.max_candidates:
break
@ -133,12 +147,12 @@ class EarningsCalendarScanner(BaseScanner):
today = datetime.now().date()
# Split by date sections (### 2026-02-05)
date_sections = re.split(r'###\s+(\d{4}-\d{2}-\d{2})', text)
date_sections = re.split(r"###\s+(\d{4}-\d{2}-\d{2})", text)
current_date = None
for i, section in enumerate(date_sections):
# Check if this is a date line
if re.match(r'\d{4}-\d{2}-\d{2}', section):
if re.match(r"\d{4}-\d{2}-\d{2}", section):
current_date = section
continue
@ -146,7 +160,7 @@ class EarningsCalendarScanner(BaseScanner):
continue
# Find tickers in this section (format: **TICKER** (timing))
ticker_pattern = r'\*\*([A-Z]{2,5})\*\*\s*\(([^\)]+)\)'
ticker_pattern = r"\*\*([A-Z]{2,5})\*\*\s*\(([^\)]+)\)"
ticker_matches = re.findall(ticker_pattern, section)
for ticker, timing in ticker_matches:
@ -174,20 +188,24 @@ class EarningsCalendarScanner(BaseScanner):
if timing == "bmo": # Before market open
strategy = "earnings_play"
elif timing == "amc": # After market close
strategy = "pre_earnings_accumulation" if days_until > 0 else "earnings_play"
strategy = (
"pre_earnings_accumulation" if days_until > 0 else "earnings_play"
)
else:
strategy = "pre_earnings_accumulation"
candidates.append({
"ticker": ticker,
"source": self.name,
"context": f"Earnings {timing} in {days_until} day(s) on {current_date}",
"priority": priority,
"strategy": strategy,
"days_until": days_until,
"earnings_date": current_date,
"timing": timing,
})
candidates.append(
{
"ticker": ticker,
"source": self.name,
"context": f"Earnings {timing} in {days_until} day(s) on {current_date}",
"priority": priority,
"strategy": strategy,
"days_until": days_until,
"earnings_date": current_date,
"timing": timing,
}
)
if len(candidates) >= self.max_candidates:
return candidates

View File

@ -1,10 +1,12 @@
"""SEC Form 4 insider buying scanner."""
import re
from datetime import datetime, timedelta
from typing import Any, Dict, List
from tradingagents.dataflows.discovery.scanner_registry import BaseScanner, SCANNER_REGISTRY
from tradingagents.dataflows.discovery.scanner_registry import SCANNER_REGISTRY, BaseScanner
from tradingagents.dataflows.discovery.utils import Priority
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
class InsiderBuyingScanner(BaseScanner):
@ -22,7 +24,7 @@ class InsiderBuyingScanner(BaseScanner):
if not self.is_enabled():
return []
print(f" 💼 Scanning insider buying (last {self.lookback_days} days)...")
logger.info(f"💼 Scanning insider buying (last {self.lookback_days} days)...")
try:
# Use Finviz insider buying screener
@ -32,11 +34,11 @@ class InsiderBuyingScanner(BaseScanner):
transaction_type="buy",
lookback_days=self.lookback_days,
min_value=self.min_transaction_value,
top_n=self.limit
top_n=self.limit,
)
if not result or not isinstance(result, str):
print(f" Found 0 insider purchases")
logger.info("Found 0 insider purchases")
return []
# Parse the markdown result
@ -45,12 +47,13 @@ class InsiderBuyingScanner(BaseScanner):
# Extract tickers from markdown table
import re
lines = result.split('\n')
lines = result.split("\n")
for line in lines:
if '|' not in line or 'Ticker' in line or '---' in line:
if "|" not in line or "Ticker" in line or "---" in line:
continue
parts = [p.strip() for p in line.split('|')]
parts = [p.strip() for p in line.split("|")]
if len(parts) < 3:
continue
@ -61,29 +64,30 @@ class InsiderBuyingScanner(BaseScanner):
continue
# Validate ticker format
if not re.match(r'^[A-Z]{1,5}$', ticker):
if not re.match(r"^[A-Z]{1,5}$", ticker):
continue
seen_tickers.add(ticker)
candidates.append({
"ticker": ticker,
"source": self.name,
"context": f"Insider purchase detected (Finviz)",
"priority": Priority.HIGH.value,
"strategy": "insider_buying",
})
candidates.append(
{
"ticker": ticker,
"source": self.name,
"context": "Insider purchase detected (Finviz)",
"priority": Priority.HIGH.value,
"strategy": "insider_buying",
}
)
if len(candidates) >= self.limit:
break
print(f" Found {len(candidates)} insider purchases")
logger.info(f"Found {len(candidates)} insider purchases")
return candidates
except Exception as e:
print(f" ⚠️ Insider buying failed: {e}")
logger.warning(f"⚠️ Insider buying failed: {e}")
return []
SCANNER_REGISTRY.register(InsiderBuyingScanner)

View File

@ -1,8 +1,12 @@
"""Market movers scanner - migrated from legacy TraditionalScanner."""
from typing import Any, Dict, List
from tradingagents.dataflows.discovery.scanner_registry import BaseScanner, SCANNER_REGISTRY
from tradingagents.dataflows.discovery.scanner_registry import SCANNER_REGISTRY, BaseScanner
from tradingagents.dataflows.discovery.utils import Priority
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
class MarketMoversScanner(BaseScanner):
@ -18,58 +22,59 @@ class MarketMoversScanner(BaseScanner):
if not self.is_enabled():
return []
print(f" 📈 Scanning market movers...")
logger.info("📈 Scanning market movers...")
from tradingagents.tools.executor import execute_tool
try:
result = execute_tool(
"get_market_movers",
return_structured=True
)
result = execute_tool("get_market_movers", return_structured=True)
if not result or not isinstance(result, dict):
return []
if "error" in result:
print(f" ⚠️ API error: {result['error']}")
logger.warning(f"⚠️ API error: {result['error']}")
return []
candidates = []
# Process gainers
for gainer in result.get("gainers", [])[:self.limit // 2]:
for gainer in result.get("gainers", [])[: self.limit // 2]:
ticker = gainer.get("ticker", "").upper()
if not ticker:
continue
candidates.append({
"ticker": ticker,
"source": self.name,
"context": f"Top gainer: {gainer.get('change_percentage', 0)} change",
"priority": Priority.MEDIUM.value,
"strategy": "momentum",
})
candidates.append(
{
"ticker": ticker,
"source": self.name,
"context": f"Top gainer: {gainer.get('change_percentage', 0)} change",
"priority": Priority.MEDIUM.value,
"strategy": "momentum",
}
)
# Process losers (potential reversal plays)
for loser in result.get("losers", [])[:self.limit // 2]:
for loser in result.get("losers", [])[: self.limit // 2]:
ticker = loser.get("ticker", "").upper()
if not ticker:
continue
candidates.append({
"ticker": ticker,
"source": self.name,
"context": f"Top loser: {loser.get('change_percentage', 0)} change (reversal play)",
"priority": Priority.LOW.value,
"strategy": "oversold_reversal",
})
candidates.append(
{
"ticker": ticker,
"source": self.name,
"context": f"Top loser: {loser.get('change_percentage', 0)} change (reversal play)",
"priority": Priority.LOW.value,
"strategy": "oversold_reversal",
}
)
print(f" Found {len(candidates)} market movers")
logger.info(f"Found {len(candidates)} market movers")
return candidates
except Exception as e:
print(f" ⚠️ Market movers failed: {e}")
logger.warning(f"⚠️ Market movers failed: {e}")
return []

View File

@ -0,0 +1,295 @@
"""ML signal scanner — surfaces high P(WIN) setups from a ticker universe.
Universe is loaded from a text file (one ticker per line, # comments allowed).
Default: data/tickers.txt. Override via config: discovery.scanners.ml_signal.ticker_file
"""
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional
import numpy as np
import pandas as pd
from tradingagents.dataflows.discovery.scanner_registry import SCANNER_REGISTRY, BaseScanner
from tradingagents.dataflows.discovery.utils import Priority
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
# Default ticker file path (relative to project root)
DEFAULT_TICKER_FILE = "data/tickers.txt"
def _load_tickers_from_file(path: str) -> List[str]:
"""Load ticker symbols from a text file (one per line, # comments allowed)."""
try:
with open(path) as f:
tickers = [
line.strip().upper()
for line in f
if line.strip() and not line.strip().startswith("#")
]
if tickers:
logger.info(f"ML scanner: loaded {len(tickers)} tickers from {path}")
return tickers
except FileNotFoundError:
logger.warning(f"Ticker file not found: {path}")
except Exception as e:
logger.warning(f"Failed to load ticker file {path}: {e}")
return []
class MLSignalScanner(BaseScanner):
"""Scan a ticker universe for high ML win-probability setups.
Loads the trained LightGBM/TabPFN model, fetches recent OHLCV data
for a universe of tickers, computes technical features, and returns
candidates whose predicted P(WIN) exceeds a configurable threshold.
Optimized for large universes (500+ tickers):
- Single batch yfinance download (1 HTTP request)
- Parallel feature computation via ThreadPoolExecutor
- Market cap skipped by default (1 NaN feature out of 30)
"""
name = "ml_signal"
pipeline = "momentum"
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
self.min_win_prob = self.scanner_config.get("min_win_prob", 0.35)
self.lookback_period = self.scanner_config.get("lookback_period", "1y")
self.max_workers = self.scanner_config.get("max_workers", 8)
self.fetch_market_cap = self.scanner_config.get("fetch_market_cap", False)
# Load universe: config list > config file > default tickers file
if "ticker_universe" in self.scanner_config:
self.universe = self.scanner_config["ticker_universe"]
else:
ticker_file = self.scanner_config.get(
"ticker_file",
config.get("tickers_file", DEFAULT_TICKER_FILE),
)
self.universe = _load_tickers_from_file(ticker_file)
if not self.universe:
logger.warning(f"No tickers loaded from {ticker_file} — scanner will be empty")
def scan(self, state: Dict[str, Any]) -> List[Dict[str, Any]]:
if not self.is_enabled():
return []
logger.info(
f"Running ML signal scanner on {len(self.universe)} tickers "
f"(min P(WIN) = {self.min_win_prob:.0%})..."
)
# 1. Load ML model
predictor = self._load_predictor()
if predictor is None:
logger.warning("No ML model available — skipping ml_signal scanner")
return []
# 2. Batch-fetch OHLCV data (single HTTP request)
ohlcv_by_ticker = self._fetch_universe_ohlcv()
if not ohlcv_by_ticker:
logger.warning("No OHLCV data fetched — skipping ml_signal scanner")
return []
# 3. Compute features and predict in parallel
candidates = self._predict_universe(predictor, ohlcv_by_ticker)
# 4. Sort by P(WIN) descending and apply limit
candidates.sort(key=lambda c: c.get("ml_win_prob", 0), reverse=True)
candidates = candidates[: self.limit]
logger.info(
f"ML signal scanner: {len(candidates)} candidates above "
f"{self.min_win_prob:.0%} threshold (from {len(ohlcv_by_ticker)} tickers)"
)
# Log individual candidate results
if candidates:
header = f"{'Ticker':<8} {'P(WIN)':>8} {'P(LOSS)':>9} {'Prediction':>12} {'Priority':>10}"
separator = "-" * len(header)
lines = ["\n ML Signal Scanner Results:", f" {header}", f" {separator}"]
for c in candidates:
lines.append(
f" {c['ticker']:<8} {c.get('ml_win_prob', 0):>7.1%} "
f"{c.get('ml_loss_prob', 0):>9.1%} "
f"{c.get('ml_prediction', 'N/A'):>12} "
f"{c.get('priority', 'N/A'):>10}"
)
lines.append(f" {separator}")
logger.info("\n".join(lines))
return candidates
def _load_predictor(self):
"""Load the trained ML model."""
try:
from tradingagents.ml.predictor import MLPredictor
return MLPredictor.load()
except Exception as e:
logger.warning(f"Failed to load ML predictor: {e}")
return None
def _fetch_universe_ohlcv(self) -> Dict[str, pd.DataFrame]:
"""Batch-fetch OHLCV data for the entire ticker universe.
Uses yfinance batch download a single HTTP request regardless of
universe size. This is the key optimization for large universes.
"""
try:
from tradingagents.dataflows.y_finance import download_history
logger.info(f"Batch-downloading {len(self.universe)} tickers ({self.lookback_period})...")
# yfinance batch download — single HTTP request for all tickers
raw = download_history(
" ".join(self.universe),
period=self.lookback_period,
auto_adjust=True,
progress=False,
)
if raw.empty:
return {}
# Handle multi-level columns from batch download
result = {}
if isinstance(raw.columns, pd.MultiIndex):
# Multi-ticker: columns are (Price, Ticker)
tickers_in_data = raw.columns.get_level_values(1).unique()
for ticker in tickers_in_data:
try:
ticker_df = raw.xs(ticker, level=1, axis=1).copy()
ticker_df = ticker_df.reset_index()
if len(ticker_df) > 0:
result[ticker] = ticker_df
except (KeyError, ValueError):
continue
else:
# Single ticker fallback
raw = raw.reset_index()
if len(self.universe) == 1:
result[self.universe[0]] = raw
logger.info(f"Fetched OHLCV for {len(result)} tickers")
return result
except Exception as e:
logger.warning(f"OHLCV batch fetch failed: {e}")
return {}
def _predict_universe(
self, predictor, ohlcv_by_ticker: Dict[str, pd.DataFrame]
) -> List[Dict[str, Any]]:
"""Predict P(WIN) for all tickers using parallel feature computation."""
candidates = []
if self.max_workers <= 1 or len(ohlcv_by_ticker) <= 10:
# Serial execution for small universes
for ticker, ohlcv in ohlcv_by_ticker.items():
result = self._predict_ticker(predictor, ticker, ohlcv)
if result is not None:
candidates.append(result)
else:
# Parallel feature computation for large universes
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = {
executor.submit(self._predict_ticker, predictor, ticker, ohlcv): ticker
for ticker, ohlcv in ohlcv_by_ticker.items()
}
for future in as_completed(futures):
try:
result = future.result(timeout=10)
if result is not None:
candidates.append(result)
except Exception as e:
ticker = futures[future]
logger.debug(f"{ticker}: prediction timed out or failed — {e}")
return candidates
def _predict_ticker(
self, predictor, ticker: str, ohlcv: pd.DataFrame
) -> Optional[Dict[str, Any]]:
"""Compute features and predict P(WIN) for a single ticker."""
try:
from tradingagents.ml.feature_engineering import (
MIN_HISTORY_ROWS,
compute_features_single,
)
if len(ohlcv) < MIN_HISTORY_ROWS:
return None
# Market cap: skip by default for speed (1 NaN out of 30 features)
market_cap = self._get_market_cap(ticker) if self.fetch_market_cap else None
# Compute features for the most recent date
latest_date = pd.to_datetime(ohlcv["Date"]).max().strftime("%Y-%m-%d")
features = compute_features_single(ohlcv, latest_date, market_cap=market_cap)
if features is None:
return None
# Run ML prediction
prediction = predictor.predict(features)
if prediction is None:
return None
win_prob = prediction.get("win_prob", 0)
loss_prob = prediction.get("loss_prob", 0)
if win_prob < self.min_win_prob:
return None
# Determine priority from P(WIN)
if win_prob >= 0.50:
priority = Priority.CRITICAL.value
elif win_prob >= 0.40:
priority = Priority.HIGH.value
else:
priority = Priority.MEDIUM.value
return {
"ticker": ticker,
"source": self.name,
"context": (
f"ML model: {win_prob:.0%} win probability, "
f"{loss_prob:.0%} loss probability "
f"({prediction.get('prediction', 'N/A')})"
),
"priority": priority,
"strategy": "ml_signal",
"ml_win_prob": win_prob,
"ml_loss_prob": loss_prob,
"ml_prediction": prediction.get("prediction", "N/A"),
}
except Exception as e:
logger.debug(f"{ticker}: ML prediction failed — {e}")
return None
def _get_market_cap(self, ticker: str) -> Optional[float]:
"""Get market cap (best-effort, cached in memory for the scan)."""
if not hasattr(self, "_market_cap_cache"):
self._market_cap_cache: Dict[str, Optional[float]] = {}
if ticker in self._market_cap_cache:
return self._market_cap_cache[ticker]
try:
from tradingagents.dataflows.y_finance import get_ticker_info
info = get_ticker_info(ticker)
cap = info.get("marketCap")
self._market_cap_cache[ticker] = cap
return cap
except Exception:
self._market_cap_cache[ticker] = None
return None
SCANNER_REGISTRY.register(MLSignalScanner)

View File

@ -1,8 +1,12 @@
"""Unusual options activity scanner."""
from typing import Any, Dict, List
import yfinance as yf
from tradingagents.dataflows.discovery.scanner_registry import BaseScanner, SCANNER_REGISTRY
from typing import Any, Dict, List
from tradingagents.dataflows.discovery.scanner_registry import SCANNER_REGISTRY, BaseScanner
from tradingagents.dataflows.y_finance import get_option_chain, get_ticker_options
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
class OptionsFlowScanner(BaseScanner):
@ -16,15 +20,15 @@ class OptionsFlowScanner(BaseScanner):
self.min_volume_oi_ratio = self.scanner_config.get("unusual_volume_multiple", 2.0)
self.min_volume = self.scanner_config.get("min_volume", 1000)
self.min_premium = self.scanner_config.get("min_premium", 25000)
self.ticker_universe = self.scanner_config.get("ticker_universe", [
"AAPL", "MSFT", "GOOGL", "AMZN", "META", "NVDA", "AMD", "TSLA"
])
self.ticker_universe = self.scanner_config.get(
"ticker_universe", ["AAPL", "MSFT", "GOOGL", "AMZN", "META", "NVDA", "AMD", "TSLA"]
)
def scan(self, state: Dict[str, Any]) -> List[Dict[str, Any]]:
if not self.is_enabled():
return []
print(f" Scanning unusual options activity...")
logger.info("Scanning unusual options activity...")
candidates = []
@ -38,17 +42,16 @@ class OptionsFlowScanner(BaseScanner):
except Exception:
continue
print(f" Found {len(candidates)} unusual options flows")
logger.info(f"Found {len(candidates)} unusual options flows")
return candidates
def _analyze_ticker_options(self, ticker: str) -> Dict[str, Any]:
try:
stock = yf.Ticker(ticker)
expirations = stock.options
expirations = get_ticker_options(ticker)
if not expirations:
return None
options = stock.option_chain(expirations[0])
options = get_option_chain(ticker, expirations[0])
calls = options.calls
puts = options.puts
@ -58,12 +61,9 @@ class OptionsFlowScanner(BaseScanner):
vol = opt.get("volume", 0)
oi = opt.get("openInterest", 0)
if oi > 0 and vol > self.min_volume and (vol / oi) >= self.min_volume_oi_ratio:
unusual_strikes.append({
"type": "call",
"strike": opt["strike"],
"volume": vol,
"oi": oi
})
unusual_strikes.append(
{"type": "call", "strike": opt["strike"], "volume": vol, "oi": oi}
)
if not unusual_strikes:
return None
@ -81,7 +81,7 @@ class OptionsFlowScanner(BaseScanner):
"context": f"Unusual options: {len(unusual_strikes)} strikes, P/C={pc_ratio:.2f} ({sentiment})",
"priority": "high" if sentiment == "bullish" else "medium",
"strategy": "options_flow",
"put_call_ratio": round(pc_ratio, 2)
"put_call_ratio": round(pc_ratio, 2),
}
except Exception:

View File

@ -1,9 +1,13 @@
"""Reddit DD (Due Diligence) scanner."""
from typing import Any, Dict, List
from tradingagents.dataflows.discovery.scanner_registry import BaseScanner, SCANNER_REGISTRY
from tradingagents.dataflows.discovery.scanner_registry import SCANNER_REGISTRY, BaseScanner
from tradingagents.dataflows.discovery.utils import Priority
from tradingagents.tools.executor import execute_tool
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
class RedditDDScanner(BaseScanner):
@ -19,17 +23,14 @@ class RedditDDScanner(BaseScanner):
if not self.is_enabled():
return []
print(f" 📝 Scanning Reddit DD posts...")
logger.info("📝 Scanning Reddit DD posts...")
try:
# Use Reddit DD scanner tool
result = execute_tool(
"scan_reddit_dd",
limit=self.limit
)
result = execute_tool("scan_reddit_dd", limit=self.limit)
if not result:
print(f" Found 0 DD posts")
logger.info("Found 0 DD posts")
return []
candidates = []
@ -37,7 +38,7 @@ class RedditDDScanner(BaseScanner):
# Handle different result formats
if isinstance(result, list):
# Structured result with DD posts
for post in result[:self.limit]:
for post in result[: self.limit]:
ticker = post.get("ticker", "").upper()
if not ticker:
continue
@ -48,39 +49,43 @@ class RedditDDScanner(BaseScanner):
# Higher score = higher priority
priority = Priority.HIGH.value if score > 1000 else Priority.MEDIUM.value
candidates.append({
"ticker": ticker,
"source": self.name,
"context": f"Reddit DD: {title[:80]}... (score: {score})",
"priority": priority,
"strategy": "undiscovered_dd",
"dd_score": score,
})
candidates.append(
{
"ticker": ticker,
"source": self.name,
"context": f"Reddit DD: {title[:80]}... (score: {score})",
"priority": priority,
"strategy": "undiscovered_dd",
"dd_score": score,
}
)
elif isinstance(result, dict):
# Dict format
for ticker_data in result.get("posts", [])[:self.limit]:
for ticker_data in result.get("posts", [])[: self.limit]:
ticker = ticker_data.get("ticker", "").upper()
if not ticker:
continue
candidates.append({
"ticker": ticker,
"source": self.name,
"context": f"Reddit DD post",
"priority": Priority.MEDIUM.value,
"strategy": "undiscovered_dd",
})
candidates.append(
{
"ticker": ticker,
"source": self.name,
"context": "Reddit DD post",
"priority": Priority.MEDIUM.value,
"strategy": "undiscovered_dd",
}
)
elif isinstance(result, str):
# Text result - extract tickers
candidates = self._parse_text_result(result)
print(f" Found {len(candidates)} DD posts")
logger.info(f"Found {len(candidates)} DD posts")
return candidates
except Exception as e:
print(f" ⚠️ Reddit DD scan failed, using fallback: {e}")
logger.warning(f"⚠️ Reddit DD scan failed, using fallback: {e}")
return self._fallback_dd_scan()
def _fallback_dd_scan(self) -> List[Dict[str, Any]]:
@ -99,7 +104,8 @@ class RedditDDScanner(BaseScanner):
for submission in subreddit.search("flair:DD", limit=self.limit * 2):
# Extract ticker from title
import re
ticker_pattern = r'\$([A-Z]{2,5})\b|^([A-Z]{2,5})\s'
ticker_pattern = r"\$([A-Z]{2,5})\b|^([A-Z]{2,5})\s"
matches = re.findall(ticker_pattern, submission.title)
if not matches:
@ -111,19 +117,21 @@ class RedditDDScanner(BaseScanner):
seen_tickers.add(ticker)
candidates.append({
"ticker": ticker,
"source": self.name,
"context": f"Reddit DD: {submission.title[:80]}...",
"priority": Priority.MEDIUM.value,
"strategy": "undiscovered_dd",
})
candidates.append(
{
"ticker": ticker,
"source": self.name,
"context": f"Reddit DD: {submission.title[:80]}...",
"priority": Priority.MEDIUM.value,
"strategy": "undiscovered_dd",
}
)
if len(candidates) >= self.limit:
break
return candidates
except:
except Exception:
return []
def _parse_text_result(self, text: str) -> List[Dict[str, Any]]:
@ -131,19 +139,21 @@ class RedditDDScanner(BaseScanner):
import re
candidates = []
ticker_pattern = r'\$([A-Z]{2,5})\b|^([A-Z]{2,5})\s'
ticker_pattern = r"\$([A-Z]{2,5})\b|^([A-Z]{2,5})\s"
matches = re.findall(ticker_pattern, text)
tickers = list(set([t[0] or t[1] for t in matches if t[0] or t[1]]))
for ticker in tickers[:self.limit]:
candidates.append({
"ticker": ticker,
"source": self.name,
"context": "Reddit DD post",
"priority": Priority.MEDIUM.value,
"strategy": "undiscovered_dd",
})
for ticker in tickers[: self.limit]:
candidates.append(
{
"ticker": ticker,
"source": self.name,
"context": "Reddit DD post",
"priority": Priority.MEDIUM.value,
"strategy": "undiscovered_dd",
}
)
return candidates

View File

@ -1,8 +1,12 @@
"""Reddit trending scanner - migrated from legacy TraditionalScanner."""
from typing import Any, Dict, List
from tradingagents.dataflows.discovery.scanner_registry import BaseScanner, SCANNER_REGISTRY
from tradingagents.dataflows.discovery.scanner_registry import SCANNER_REGISTRY, BaseScanner
from tradingagents.dataflows.discovery.utils import Priority
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
class RedditTrendingScanner(BaseScanner):
@ -18,21 +22,18 @@ class RedditTrendingScanner(BaseScanner):
if not self.is_enabled():
return []
print(f" 📱 Scanning Reddit trending...")
logger.info("📱 Scanning Reddit trending...")
from tradingagents.tools.executor import execute_tool
try:
result = execute_tool(
"get_trending_tickers",
limit=self.limit
)
result = execute_tool("get_trending_tickers", limit=self.limit)
if not result or not isinstance(result, str):
return []
if "Error" in result or "No trending" in result:
print(f" ⚠️ {result}")
logger.warning(f"⚠️ {result}")
return []
# Extract tickers using common utility
@ -41,20 +42,22 @@ class RedditTrendingScanner(BaseScanner):
tickers_found = extract_tickers_from_text(result)
candidates = []
for ticker in tickers_found[:self.limit]:
candidates.append({
"ticker": ticker,
"source": self.name,
"context": f"Reddit trending discussion",
"priority": Priority.MEDIUM.value,
"strategy": "social_hype",
})
for ticker in tickers_found[: self.limit]:
candidates.append(
{
"ticker": ticker,
"source": self.name,
"context": "Reddit trending discussion",
"priority": Priority.MEDIUM.value,
"strategy": "social_hype",
}
)
print(f" Found {len(candidates)} Reddit trending tickers")
logger.info(f"Found {len(candidates)} Reddit trending tickers")
return candidates
except Exception as e:
print(f" ⚠️ Reddit trending failed: {e}")
logger.warning(f"⚠️ Reddit trending failed: {e}")
return []

View File

@ -1,8 +1,12 @@
"""Semantic news scanner for early catalyst detection."""
from typing import Any, Dict, List
from tradingagents.dataflows.discovery.scanner_registry import BaseScanner, SCANNER_REGISTRY
from tradingagents.dataflows.discovery.scanner_registry import SCANNER_REGISTRY, BaseScanner
from tradingagents.dataflows.discovery.utils import Priority
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
class SemanticNewsScanner(BaseScanner):
@ -22,12 +26,13 @@ class SemanticNewsScanner(BaseScanner):
if not self.is_enabled():
return []
print(f" 📰 Scanning news catalysts...")
logger.info("📰 Scanning news catalysts...")
try:
from tradingagents.tools.executor import execute_tool
from datetime import datetime
from tradingagents.tools.executor import execute_tool
# Get recent global news
date_str = datetime.now().strftime("%Y-%m-%d")
result = execute_tool("get_global_news", date=date_str)
@ -37,30 +42,44 @@ class SemanticNewsScanner(BaseScanner):
# Extract tickers mentioned in news
import re
ticker_pattern = r'\b([A-Z]{2,5})\b|\$([A-Z]{2,5})'
ticker_pattern = r"\b([A-Z]{2,5})\b|\$([A-Z]{2,5})"
matches = re.findall(ticker_pattern, result)
tickers = list(set([t[0] or t[1] for t in matches if t[0] or t[1]]))
stop_words = {'NYSE', 'NASDAQ', 'CEO', 'CFO', 'IPO', 'ETF', 'USA', 'SEC', 'NEWS', 'STOCK', 'MARKET'}
stop_words = {
"NYSE",
"NASDAQ",
"CEO",
"CFO",
"IPO",
"ETF",
"USA",
"SEC",
"NEWS",
"STOCK",
"MARKET",
}
tickers = [t for t in tickers if t not in stop_words]
candidates = []
for ticker in tickers[:self.limit]:
candidates.append({
"ticker": ticker,
"source": self.name,
"context": "Mentioned in recent market news",
"priority": Priority.MEDIUM.value,
"strategy": "news_catalyst",
})
for ticker in tickers[: self.limit]:
candidates.append(
{
"ticker": ticker,
"source": self.name,
"context": "Mentioned in recent market news",
"priority": Priority.MEDIUM.value,
"strategy": "news_catalyst",
}
)
print(f" Found {len(candidates)} news mentions")
logger.info(f"Found {len(candidates)} news mentions")
return candidates
except Exception as e:
print(f" ⚠️ News scan failed: {e}")
logger.warning(f"⚠️ News scan failed: {e}")
return []
SCANNER_REGISTRY.register(SemanticNewsScanner)

View File

@ -1,9 +1,13 @@
"""Volume accumulation and compression scanner."""
from typing import Any, Dict, List
from tradingagents.dataflows.discovery.scanner_registry import BaseScanner, SCANNER_REGISTRY
from tradingagents.dataflows.discovery.scanner_registry import SCANNER_REGISTRY, BaseScanner
from tradingagents.dataflows.discovery.utils import Priority
from tradingagents.tools.executor import execute_tool
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
class VolumeAccumulationScanner(BaseScanner):
@ -21,18 +25,18 @@ class VolumeAccumulationScanner(BaseScanner):
if not self.is_enabled():
return []
print(f" 📊 Scanning volume accumulation...")
logger.info("📊 Scanning volume accumulation...")
try:
# Use volume scanner tool
result = execute_tool(
"get_unusual_volume",
min_volume_multiple=self.unusual_volume_multiple,
top_n=self.limit
top_n=self.limit,
)
if not result:
print(f" Found 0 volume accumulation candidates")
logger.info("Found 0 volume accumulation candidates")
return []
candidates = []
@ -43,7 +47,7 @@ class VolumeAccumulationScanner(BaseScanner):
candidates = self._parse_text_result(result)
elif isinstance(result, list):
# Structured result
for item in result[:self.limit]:
for item in result[: self.limit]:
ticker = item.get("ticker", "").upper()
if not ticker:
continue
@ -51,29 +55,35 @@ class VolumeAccumulationScanner(BaseScanner):
volume_ratio = item.get("volume_ratio", 0)
avg_volume = item.get("avg_volume", 0)
candidates.append({
"ticker": ticker,
"source": self.name,
"context": f"Unusual volume: {volume_ratio:.1f}x average ({avg_volume:,})",
"priority": Priority.MEDIUM.value if volume_ratio < 3.0 else Priority.HIGH.value,
"strategy": "volume_accumulation",
})
candidates.append(
{
"ticker": ticker,
"source": self.name,
"context": f"Unusual volume: {volume_ratio:.1f}x average ({avg_volume:,})",
"priority": (
Priority.MEDIUM.value if volume_ratio < 3.0 else Priority.HIGH.value
),
"strategy": "volume_accumulation",
}
)
elif isinstance(result, dict):
# Dict with tickers list
for ticker in result.get("tickers", [])[:self.limit]:
candidates.append({
"ticker": ticker.upper(),
"source": self.name,
"context": f"Unusual volume accumulation",
"priority": Priority.MEDIUM.value,
"strategy": "volume_accumulation",
})
for ticker in result.get("tickers", [])[: self.limit]:
candidates.append(
{
"ticker": ticker.upper(),
"source": self.name,
"context": "Unusual volume accumulation",
"priority": Priority.MEDIUM.value,
"strategy": "volume_accumulation",
}
)
print(f" Found {len(candidates)} volume accumulation candidates")
logger.info(f"Found {len(candidates)} volume accumulation candidates")
return candidates
except Exception as e:
print(f" ⚠️ Volume accumulation failed: {e}")
logger.warning(f"⚠️ Volume accumulation failed: {e}")
return []
def _parse_text_result(self, text: str) -> List[Dict[str, Any]]:
@ -83,14 +93,16 @@ class VolumeAccumulationScanner(BaseScanner):
candidates = []
tickers = extract_tickers_from_text(text)
for ticker in tickers[:self.limit]:
candidates.append({
"ticker": ticker,
"source": self.name,
"context": "Unusual volume detected",
"priority": Priority.MEDIUM.value,
"strategy": "volume_accumulation",
})
for ticker in tickers[: self.limit]:
candidates.append(
{
"ticker": ticker,
"source": self.name,
"context": "Unusual volume detected",
"priority": Priority.MEDIUM.value,
"strategy": "volume_accumulation",
}
)
return candidates

View File

@ -6,7 +6,7 @@ with the ticker universe CSV.
Usage:
from tradingagents.dataflows.discovery.ticker_matcher import match_company_to_ticker
ticker = match_company_to_ticker("Apple Inc")
# Returns: "AAPL"
"""
@ -14,108 +14,116 @@ Usage:
import csv
import re
from pathlib import Path
from typing import Dict, Optional, Tuple
from typing import Dict, Optional
from rapidfuzz import fuzz, process
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
# Global cache
_TICKER_UNIVERSE: Optional[Dict[str, str]] = None # ticker -> name
_NAME_TO_TICKER: Optional[Dict[str, str]] = None # normalized_name -> ticker
_MATCH_CACHE: Dict[str, Optional[str]] = {} # company_name -> ticker
_NAME_TO_TICKER: Optional[Dict[str, str]] = None # normalized_name -> ticker
_MATCH_CACHE: Dict[str, Optional[str]] = {} # company_name -> ticker
def _normalize_company_name(name: str) -> str:
"""
Normalize company name for matching.
Removes common suffixes, punctuation, and standardizes format.
"""
if not name:
return ""
# Convert to uppercase
name = name.upper()
# Remove common suffixes
suffixes = [
r'\s+INC\.?',
r'\s+INCORPORATED',
r'\s+CORP\.?',
r'\s+CORPORATION',
r'\s+LTD\.?',
r'\s+LIMITED',
r'\s+LLC',
r'\s+L\.?L\.?C\.?',
r'\s+PLC',
r'\s+CO\.?',
r'\s+COMPANY',
r'\s+CLASS [A-Z]',
r'\s+COMMON STOCK',
r'\s+ORDINARY SHARES?',
r'\s+-\s+.*$', # Remove everything after dash
r'\s+\(.*?\)', # Remove parenthetical
r"\s+INC\.?",
r"\s+INCORPORATED",
r"\s+CORP\.?",
r"\s+CORPORATION",
r"\s+LTD\.?",
r"\s+LIMITED",
r"\s+LLC",
r"\s+L\.?L\.?C\.?",
r"\s+PLC",
r"\s+CO\.?",
r"\s+COMPANY",
r"\s+CLASS [A-Z]",
r"\s+COMMON STOCK",
r"\s+ORDINARY SHARES?",
r"\s+-\s+.*$", # Remove everything after dash
r"\s+\(.*?\)", # Remove parenthetical
]
for suffix in suffixes:
name = re.sub(suffix, '', name, flags=re.IGNORECASE)
name = re.sub(suffix, "", name, flags=re.IGNORECASE)
# Remove punctuation except spaces
name = re.sub(r'[^\w\s]', '', name)
name = re.sub(r"[^\w\s]", "", name)
# Normalize whitespace
name = ' '.join(name.split())
name = " ".join(name.split())
return name.strip()
def load_ticker_universe(force_reload: bool = False) -> Dict[str, str]:
"""
Load ticker universe from CSV.
Args:
force_reload: Force reload even if already loaded
Returns:
Dict mapping ticker -> company name
"""
global _TICKER_UNIVERSE, _NAME_TO_TICKER
if _TICKER_UNIVERSE is not None and not force_reload:
return _TICKER_UNIVERSE
# Find CSV file
project_root = Path(__file__).parent.parent.parent.parent
csv_path = project_root / "data" / "ticker_universe.csv"
if not csv_path.exists():
raise FileNotFoundError(f"Ticker universe not found: {csv_path}")
ticker_universe = {}
name_to_ticker = {}
with open(csv_path, 'r', encoding='utf-8') as f:
with open(csv_path, "r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
ticker = row['ticker']
name = row['name']
ticker = row["ticker"]
name = row["name"]
# Store ticker -> name mapping
ticker_universe[ticker] = name
# Build reverse index (normalized name -> ticker)
normalized = _normalize_company_name(name)
if normalized:
# If multiple tickers have same normalized name, prefer common stocks
if normalized not in name_to_ticker:
name_to_ticker[normalized] = ticker
elif "COMMON" in name.upper() and "COMMON" not in ticker_universe.get(name_to_ticker[normalized], "").upper():
elif (
"COMMON" in name.upper()
and "COMMON" not in ticker_universe.get(name_to_ticker[normalized], "").upper()
):
# Prefer common stock over other securities
name_to_ticker[normalized] = ticker
_TICKER_UNIVERSE = ticker_universe
_NAME_TO_TICKER = name_to_ticker
print(f" Loaded {len(ticker_universe)} tickers from universe")
logger.info(f"Loaded {len(ticker_universe)} tickers from universe")
return ticker_universe
@ -126,15 +134,15 @@ def match_company_to_ticker(
) -> Optional[str]:
"""
Match a company name to a ticker symbol using fuzzy matching.
Args:
company_name: Company name from 13F filing
min_confidence: Minimum fuzzy match score (0-100)
use_cache: Use cached results
Returns:
Ticker symbol or None if no good match found
Examples:
>>> match_company_to_ticker("Apple Inc")
'AAPL'
@ -145,51 +153,48 @@ def match_company_to_ticker(
"""
if not company_name:
return None
# Check cache
if use_cache and company_name in _MATCH_CACHE:
return _MATCH_CACHE[company_name]
# Ensure universe is loaded
if _TICKER_UNIVERSE is None or _NAME_TO_TICKER is None:
load_ticker_universe()
# Normalize input
normalized_input = _normalize_company_name(company_name)
if not normalized_input:
return None
# Try exact match first
if normalized_input in _NAME_TO_TICKER:
result = _NAME_TO_TICKER[normalized_input]
_MATCH_CACHE[company_name] = result
return result
# Fuzzy match against all normalized names
choices = list(_NAME_TO_TICKER.keys())
# Use token_sort_ratio for best results with company names
match_result = process.extractOne(
normalized_input,
choices,
scorer=fuzz.token_sort_ratio,
score_cutoff=min_confidence
normalized_input, choices, scorer=fuzz.token_sort_ratio, score_cutoff=min_confidence
)
if match_result:
matched_name, score, _ = match_result
ticker = _NAME_TO_TICKER[matched_name]
# Log match for debugging
if score < 95:
print(f" Fuzzy match: '{company_name}' -> {ticker} (score: {score:.1f})")
logger.info(f"Fuzzy match: '{company_name}' -> {ticker} (score: {score:.1f})")
_MATCH_CACHE[company_name] = ticker
return ticker
# No match found
print(f" No ticker match for: '{company_name}'")
logger.info(f"No ticker match for: '{company_name}'")
_MATCH_CACHE[company_name] = None
return None
@ -197,26 +202,26 @@ def match_company_to_ticker(
def get_match_confidence(company_name: str, ticker: str) -> float:
"""
Get confidence score for a company name -> ticker match.
Args:
company_name: Company name
ticker: Ticker symbol
Returns:
Confidence score (0-100)
"""
if _TICKER_UNIVERSE is None:
load_ticker_universe()
if ticker not in _TICKER_UNIVERSE:
return 0.0
ticker_name = _TICKER_UNIVERSE[ticker]
# Normalize both names
norm_input = _normalize_company_name(company_name)
norm_ticker = _normalize_company_name(ticker_name)
# Calculate similarity
return fuzz.token_sort_ratio(norm_input, norm_ticker)

View File

@ -22,6 +22,7 @@ PERMANENTLY_DELISTED = {
"SVIVU",
}
# Priority and strategy enums for consistent labeling.
class Priority(str, Enum):
CRITICAL = "critical"
@ -123,6 +124,7 @@ def append_llm_log(
tool_logs.append(entry)
return entry
def get_delisted_tickers() -> Set[str]:
"""Get combined list of delisted tickers from permanent list + dynamic cache."""
# Local import to avoid circular dependencies if any

View File

@ -1,50 +1,55 @@
import os
from datetime import datetime
from typing import Annotated, Any, Dict
import finnhub
from typing import Annotated
from dotenv import load_dotenv
from tradingagents.utils.logger import get_logger
from tradingagents.config import config
load_dotenv()
logger = get_logger(__name__)
def get_finnhub_client():
"""Get authenticated Finnhub client."""
api_key = os.getenv("FINNHUB_API_KEY")
if not api_key:
raise ValueError("FINNHUB_API_KEY not found in environment variables.")
api_key = config.validate_key("finnhub_api_key", "Finnhub")
return finnhub.Client(api_key=api_key)
def get_recommendation_trends(
ticker: Annotated[str, "Ticker symbol of the company"]
) -> str:
def get_recommendation_trends(ticker: Annotated[str, "Ticker symbol of the company"]) -> str:
"""
Get analyst recommendation trends for a stock.
Shows the distribution of buy/hold/sell recommendations over time.
Args:
ticker: Stock ticker symbol (e.g., "AAPL", "TSLA")
Returns:
str: Formatted report of recommendation trends
"""
try:
client = get_finnhub_client()
data = client.recommendation_trends(ticker.upper())
if not data:
return f"No recommendation trends data found for {ticker}"
# Format the response
result = f"## Analyst Recommendation Trends for {ticker.upper()}\n\n"
for entry in data:
period = entry.get('period', 'N/A')
strong_buy = entry.get('strongBuy', 0)
buy = entry.get('buy', 0)
hold = entry.get('hold', 0)
sell = entry.get('sell', 0)
strong_sell = entry.get('strongSell', 0)
period = entry.get("period", "N/A")
strong_buy = entry.get("strongBuy", 0)
buy = entry.get("buy", 0)
hold = entry.get("hold", 0)
sell = entry.get("sell", 0)
strong_sell = entry.get("strongSell", 0)
total = strong_buy + buy + hold + sell + strong_sell
result += f"### {period}\n"
result += f"- **Strong Buy**: {strong_buy}\n"
result += f"- **Buy**: {buy}\n"
@ -52,32 +57,37 @@ def get_recommendation_trends(
result += f"- **Sell**: {sell}\n"
result += f"- **Strong Sell**: {strong_sell}\n"
result += f"- **Total Analysts**: {total}\n\n"
# Calculate sentiment
if total > 0:
bullish_pct = ((strong_buy + buy) / total) * 100
bearish_pct = ((sell + strong_sell) / total) * 100
result += f"**Sentiment**: {bullish_pct:.1f}% Bullish, {bearish_pct:.1f}% Bearish\n\n"
result += (
f"**Sentiment**: {bullish_pct:.1f}% Bullish, {bearish_pct:.1f}% Bearish\n\n"
)
return result
except Exception as e:
return f"Error fetching recommendation trends for {ticker}: {str(e)}"
def get_earnings_calendar(
from_date: Annotated[str, "Start date in yyyy-mm-dd format"],
to_date: Annotated[str, "End date in yyyy-mm-dd format"]
) -> str:
to_date: Annotated[str, "End date in yyyy-mm-dd format"],
return_structured: Annotated[bool, "Return list of dicts instead of markdown"] = False,
):
"""
Get earnings calendar for stocks with upcoming earnings announcements.
Args:
from_date: Start date in yyyy-mm-dd format
to_date: End date in yyyy-mm-dd format
return_structured: If True, returns list of earnings dicts instead of markdown
Returns:
str: Formatted report of upcoming earnings
If return_structured=True: list of earnings dicts with symbol, date, epsEstimate, etc.
If return_structured=False: Formatted markdown report
"""
try:
client = get_finnhub_client()
@ -85,17 +95,25 @@ def get_earnings_calendar(
_from=from_date,
to=to_date,
symbol="", # Empty string returns all stocks
international=False
international=False,
)
if not data or 'earningsCalendar' not in data:
if not data or "earningsCalendar" not in data:
if return_structured:
return []
return f"No earnings data found for period {from_date} to {to_date}"
earnings = data['earningsCalendar']
earnings = data["earningsCalendar"]
if not earnings:
if return_structured:
return []
return f"No earnings scheduled between {from_date} and {to_date}"
# Return structured data if requested
if return_structured:
return earnings
# Format the response
result = f"## Earnings Calendar ({from_date} to {to_date})\n\n"
result += f"**Total Companies**: {len(earnings)}\n\n"
@ -103,7 +121,7 @@ def get_earnings_calendar(
# Group by date
by_date = {}
for entry in earnings:
date = entry.get('date', 'Unknown')
date = entry.get("date", "Unknown")
if date not in by_date:
by_date[date] = []
by_date[date].append(entry)
@ -113,28 +131,44 @@ def get_earnings_calendar(
result += f"### {date}\n\n"
for entry in by_date[date]:
symbol = entry.get('symbol', 'N/A')
eps_estimate = entry.get('epsEstimate', 'N/A')
eps_actual = entry.get('epsActual', 'N/A')
revenue_estimate = entry.get('revenueEstimate', 'N/A')
revenue_actual = entry.get('revenueActual', 'N/A')
hour = entry.get('hour', 'N/A')
symbol = entry.get("symbol", "N/A")
eps_estimate = entry.get("epsEstimate", "N/A")
eps_actual = entry.get("epsActual", "N/A")
revenue_estimate = entry.get("revenueEstimate", "N/A")
revenue_actual = entry.get("revenueActual", "N/A")
hour = entry.get("hour", "N/A")
result += f"**{symbol}**"
if hour != 'N/A':
if hour != "N/A":
result += f" ({hour})"
result += "\n"
if eps_estimate != 'N/A':
result += f" - EPS Estimate: ${eps_estimate:.2f}" if isinstance(eps_estimate, (int, float)) else f" - EPS Estimate: {eps_estimate}"
if eps_actual != 'N/A':
result += f" | Actual: ${eps_actual:.2f}" if isinstance(eps_actual, (int, float)) else f" | Actual: {eps_actual}"
if eps_estimate != "N/A":
result += (
f" - EPS Estimate: ${eps_estimate:.2f}"
if isinstance(eps_estimate, (int, float))
else f" - EPS Estimate: {eps_estimate}"
)
if eps_actual != "N/A":
result += (
f" | Actual: ${eps_actual:.2f}"
if isinstance(eps_actual, (int, float))
else f" | Actual: {eps_actual}"
)
result += "\n"
if revenue_estimate != 'N/A':
result += f" - Revenue Estimate: ${revenue_estimate:,.0f}M" if isinstance(revenue_estimate, (int, float)) else f" - Revenue Estimate: {revenue_estimate}"
if revenue_actual != 'N/A':
result += f" | Actual: ${revenue_actual:,.0f}M" if isinstance(revenue_actual, (int, float)) else f" | Actual: {revenue_actual}"
if revenue_estimate != "N/A":
result += (
f" - Revenue Estimate: ${revenue_estimate:,.0f}M"
if isinstance(revenue_estimate, (int, float))
else f" - Revenue Estimate: {revenue_estimate}"
)
if revenue_actual != "N/A":
result += (
f" | Actual: ${revenue_actual:,.0f}M"
if isinstance(revenue_actual, (int, float))
else f" | Actual: {revenue_actual}"
)
result += "\n"
result += "\n"
@ -142,38 +176,105 @@ def get_earnings_calendar(
return result
except Exception as e:
if return_structured:
return []
return f"Error fetching earnings calendar: {str(e)}"
def get_ticker_earnings_estimate(
ticker: str,
from_date: str,
to_date: str,
) -> Dict[str, Any]:
"""
Get upcoming earnings estimate for a single ticker.
Returns dict with: has_upcoming_earnings, days_to_earnings,
eps_estimate, revenue_estimate, earnings_date, hour.
"""
result: Dict[str, Any] = {
"has_upcoming_earnings": False,
"days_to_earnings": None,
"eps_estimate": None,
"revenue_estimate": None,
"earnings_date": None,
"hour": None,
}
try:
client = get_finnhub_client()
data = client.earnings_calendar(
_from=from_date,
to=to_date,
symbol=ticker.upper(),
international=False,
)
if not data or "earningsCalendar" not in data:
return result
earnings = data["earningsCalendar"]
if not earnings:
return result
# Take the nearest upcoming entry
entry = earnings[0]
earnings_date = entry.get("date")
if earnings_date:
try:
ed = datetime.strptime(earnings_date, "%Y-%m-%d")
fd = datetime.strptime(from_date, "%Y-%m-%d")
result["days_to_earnings"] = (ed - fd).days
except ValueError:
pass
result["has_upcoming_earnings"] = True
result["earnings_date"] = earnings_date
result["eps_estimate"] = entry.get("epsEstimate")
result["revenue_estimate"] = entry.get("revenueEstimate")
result["hour"] = entry.get("hour")
return result
except Exception as e:
logger.warning(f"Could not fetch earnings estimate for {ticker}: {e}")
return result
def get_ipo_calendar(
from_date: Annotated[str, "Start date in yyyy-mm-dd format"],
to_date: Annotated[str, "End date in yyyy-mm-dd format"]
) -> str:
to_date: Annotated[str, "End date in yyyy-mm-dd format"],
return_structured: Annotated[bool, "Return list of dicts instead of markdown"] = False,
):
"""
Get IPO calendar for upcoming and recent initial public offerings.
Args:
from_date: Start date in yyyy-mm-dd format
to_date: End date in yyyy-mm-dd format
return_structured: If True, returns list of IPO dicts instead of markdown
Returns:
str: Formatted report of IPOs
If return_structured=True: list of IPO dicts with symbol, name, date, etc.
If return_structured=False: Formatted markdown report
"""
try:
client = get_finnhub_client()
data = client.ipo_calendar(
_from=from_date,
to=to_date
)
data = client.ipo_calendar(_from=from_date, to=to_date)
if not data or 'ipoCalendar' not in data:
if not data or "ipoCalendar" not in data:
if return_structured:
return []
return f"No IPO data found for period {from_date} to {to_date}"
ipos = data['ipoCalendar']
ipos = data["ipoCalendar"]
if not ipos:
if return_structured:
return []
return f"No IPOs scheduled between {from_date} and {to_date}"
# Return structured data if requested
if return_structured:
return ipos
# Format the response
result = f"## IPO Calendar ({from_date} to {to_date})\n\n"
result += f"**Total IPOs**: {len(ipos)}\n\n"
@ -181,7 +282,7 @@ def get_ipo_calendar(
# Group by date
by_date = {}
for entry in ipos:
date = entry.get('date', 'Unknown')
date = entry.get("date", "Unknown")
if date not in by_date:
by_date[date] = []
by_date[date].append(entry)
@ -191,29 +292,39 @@ def get_ipo_calendar(
result += f"### {date}\n\n"
for entry in by_date[date]:
symbol = entry.get('symbol', 'N/A')
name = entry.get('name', 'N/A')
exchange = entry.get('exchange', 'N/A')
price = entry.get('price', 'N/A')
shares = entry.get('numberOfShares', 'N/A')
total_shares = entry.get('totalSharesValue', 'N/A')
status = entry.get('status', 'N/A')
symbol = entry.get("symbol", "N/A")
name = entry.get("name", "N/A")
exchange = entry.get("exchange", "N/A")
price = entry.get("price", "N/A")
shares = entry.get("numberOfShares", "N/A")
total_shares = entry.get("totalSharesValue", "N/A")
status = entry.get("status", "N/A")
result += f"**{symbol}** - {name}\n"
result += f" - Exchange: {exchange}\n"
if price != 'N/A':
if price != "N/A":
result += f" - Price: ${price}\n"
if shares != 'N/A':
result += f" - Shares Offered: {shares:,}\n" if isinstance(shares, (int, float)) else f" - Shares Offered: {shares}\n"
if shares != "N/A":
result += (
f" - Shares Offered: {shares:,}\n"
if isinstance(shares, (int, float))
else f" - Shares Offered: {shares}\n"
)
if total_shares != 'N/A':
result += f" - Total Value: ${total_shares:,.0f}M\n" if isinstance(total_shares, (int, float)) else f" - Total Value: {total_shares}\n"
if total_shares != "N/A":
result += (
f" - Total Value: ${total_shares:,.0f}M\n"
if isinstance(total_shares, (int, float))
else f" - Total Value: {total_shares}\n"
)
result += f" - Status: {status}\n\n"
return result
except Exception as e:
if return_structured:
return []
return f"Error fetching IPO calendar: {str(e)}"

View File

@ -3,19 +3,25 @@ Finviz + Yahoo Finance Hybrid - Short Interest Discovery
Uses Finviz to discover tickers with high short interest, then Yahoo Finance for exact data
"""
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Annotated
import requests
from bs4 import BeautifulSoup
from typing import Annotated
import re
import yfinance as yf
from concurrent.futures import ThreadPoolExecutor, as_completed
from tradingagents.dataflows.y_finance import get_ticker_info
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
def get_short_interest(
min_short_interest_pct: Annotated[float, "Minimum short interest % of float"] = 10.0,
min_days_to_cover: Annotated[float, "Minimum days to cover ratio"] = 2.0,
top_n: Annotated[int, "Number of top results to return"] = 20,
) -> str:
return_structured: Annotated[bool, "Return dict with raw data instead of markdown"] = False,
):
"""
Discover stocks with high short interest using Finviz + Yahoo Finance.
@ -29,13 +35,17 @@ def get_short_interest(
min_short_interest_pct: Minimum short interest as % of float
min_days_to_cover: Minimum days to cover ratio
top_n: Number of top results to return
return_structured: If True, returns list of dicts instead of markdown
Returns:
Formatted markdown report of discovered high short interest stocks
If return_structured=True: list of candidate dicts with ticker, short_interest_pct, signal, etc.
If return_structured=False: Formatted markdown report
"""
try:
# Step 1: Use Finviz screener to DISCOVER tickers with high short interest
print(f" Discovering tickers with short interest >{min_short_interest_pct}% from Finviz...")
logger.info(
f"Discovering tickers with short interest >{min_short_interest_pct}% from Finviz..."
)
# Determine Finviz filter
if min_short_interest_pct >= 20:
@ -51,8 +61,8 @@ def get_short_interest(
base_url = f"https://finviz.com/screener.ashx?v=152&f={short_filter}"
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36',
'Accept': 'text/html',
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
"Accept": "text/html",
}
discovered_tickers = []
@ -68,31 +78,32 @@ def get_short_interest(
response = requests.get(url, headers=headers, timeout=30)
response.raise_for_status()
soup = BeautifulSoup(response.text, 'html.parser')
soup = BeautifulSoup(response.text, "html.parser")
# Find ticker links in the page
ticker_links = soup.find_all('a', href=re.compile(r'quote\.ashx\?t='))
ticker_links = soup.find_all("a", href=re.compile(r"quote\.ashx\?t="))
for link in ticker_links:
ticker = link.get_text(strip=True)
# Validate it's a ticker (1-5 uppercase letters)
if re.match(r'^[A-Z]{1,5}$', ticker) and ticker not in discovered_tickers:
if re.match(r"^[A-Z]{1,5}$", ticker) and ticker not in discovered_tickers:
discovered_tickers.append(ticker)
if not discovered_tickers:
if return_structured:
return []
return f"No stocks discovered with short interest >{min_short_interest_pct}% on Finviz."
print(f" Discovered {len(discovered_tickers)} tickers from Finviz")
print(f" Fetching detailed short interest data from Yahoo Finance...")
logger.info(f"Discovered {len(discovered_tickers)} tickers from Finviz")
logger.info("Fetching detailed short interest data from Yahoo Finance...")
# Step 2: Use Yahoo Finance to get EXACT short interest data for discovered tickers
def fetch_short_data(ticker):
try:
stock = yf.Ticker(ticker)
info = stock.info
info = get_ticker_info(ticker)
# Get short interest data
short_pct = info.get('shortPercentOfFloat', info.get('sharesPercentSharesOut', 0))
short_pct = info.get("shortPercentOfFloat", info.get("sharesPercentSharesOut", 0))
if short_pct and isinstance(short_pct, (int, float)):
short_pct = short_pct * 100 # Convert to percentage
else:
@ -100,9 +111,9 @@ def get_short_interest(
# Verify it meets criteria (Finviz filter might be outdated)
if short_pct >= min_short_interest_pct:
price = info.get('currentPrice', info.get('regularMarketPrice', 0))
market_cap = info.get('marketCap', 0)
volume = info.get('volume', info.get('regularMarketVolume', 0))
price = info.get("currentPrice", info.get("regularMarketPrice", 0))
market_cap = info.get("marketCap", 0)
volume = info.get("volume", info.get("regularMarketVolume", 0))
# Categorize squeeze potential
if short_pct >= 30:
@ -128,7 +139,9 @@ def get_short_interest(
# Fetch data in parallel (faster)
all_candidates = []
with ThreadPoolExecutor(max_workers=10) as executor:
futures = {executor.submit(fetch_short_data, ticker): ticker for ticker in discovered_tickers}
futures = {
executor.submit(fetch_short_data, ticker): ticker for ticker in discovered_tickers
}
for future in as_completed(futures):
result = future.result()
@ -136,26 +149,30 @@ def get_short_interest(
all_candidates.append(result)
if not all_candidates:
if return_structured:
return []
return f"No stocks with verified short interest >{min_short_interest_pct}% (Finviz found {len(discovered_tickers)} tickers but Yahoo Finance data didn't confirm)."
# Sort by short interest percentage (highest first)
sorted_candidates = sorted(
all_candidates,
key=lambda x: x["short_interest_pct"],
reverse=True
all_candidates, key=lambda x: x["short_interest_pct"], reverse=True
)[:top_n]
# Return structured data if requested
if return_structured:
return sorted_candidates
# Format output
report = f"# Discovered High Short Interest Stocks\n\n"
report = "# Discovered High Short Interest Stocks\n\n"
report += f"**Criteria**: Short Interest >{min_short_interest_pct}%\n"
report += f"**Data Source**: Finviz Screener (Web Scraping)\n"
report += "**Data Source**: Finviz Screener (Web Scraping)\n"
report += f"**Total Discovered**: {len(all_candidates)} stocks\n\n"
report += f"**Top {len(sorted_candidates)} Candidates**:\n\n"
report += "| Ticker | Price | Market Cap | Volume | Short % | Signal |\n"
report += "|--------|-------|------------|--------|---------|--------|\n"
for candidate in sorted_candidates:
market_cap_str = format_market_cap(candidate['market_cap'])
market_cap_str = format_market_cap(candidate["market_cap"])
report += f"| {candidate['ticker']} | "
report += f"${candidate['price']:.2f} | "
report += f"{market_cap_str} | "
@ -166,38 +183,44 @@ def get_short_interest(
report += "\n\n## Signal Definitions\n\n"
report += "- **extreme_squeeze_risk**: Short interest >30% - Very high squeeze potential\n"
report += "- **high_squeeze_potential**: Short interest 20-30% - High squeeze risk\n"
report += "- **moderate_squeeze_potential**: Short interest 15-20% - Moderate squeeze risk\n"
report += (
"- **moderate_squeeze_potential**: Short interest 15-20% - Moderate squeeze risk\n"
)
report += "- **low_squeeze_potential**: Short interest 10-15% - Lower squeeze risk\n\n"
report += "**Note**: High short interest alone doesn't guarantee a squeeze. Look for positive catalysts.\n"
return report
except requests.exceptions.RequestException as e:
if return_structured:
return []
return f"Error scraping Finviz: {str(e)}"
except Exception as e:
if return_structured:
return []
return f"Unexpected error discovering short interest stocks: {str(e)}"
def parse_market_cap(market_cap_text: str) -> float:
"""Parse market cap from Finviz format (e.g., '1.23B', '456M')."""
if not market_cap_text or market_cap_text == '-':
if not market_cap_text or market_cap_text == "-":
return 0.0
market_cap_text = market_cap_text.upper().strip()
# Extract number and multiplier
match = re.match(r'([0-9.]+)([BMK])?', market_cap_text)
match = re.match(r"([0-9.]+)([BMK])?", market_cap_text)
if not match:
return 0.0
number = float(match.group(1))
multiplier = match.group(2)
if multiplier == 'B':
if multiplier == "B":
return number * 1_000_000_000
elif multiplier == 'M':
elif multiplier == "M":
return number * 1_000_000
elif multiplier == 'K':
elif multiplier == "K":
return number * 1_000
else:
return number
@ -220,3 +243,210 @@ def get_finviz_short_interest(
) -> str:
"""Alias for get_short_interest to match registry naming convention"""
return get_short_interest(min_short_interest_pct, min_days_to_cover, top_n)
def get_insider_buying_screener(
transaction_type: Annotated[str, "Transaction type: 'buy', 'sell', or 'any'"] = "buy",
lookback_days: Annotated[int, "Days to look back for transactions"] = 7,
min_value: Annotated[int, "Minimum transaction value in dollars"] = 25000,
top_n: Annotated[int, "Number of top results to return"] = 20,
return_structured: Annotated[bool, "Return list of dicts instead of markdown"] = False,
):
"""
Discover stocks with recent insider buying/selling using OpenInsider.
LEADING INDICATOR: Insiders buying their own stock before price moves.
Results are sorted by transaction value (largest first).
Args:
transaction_type: "buy" for purchases, "sell" for sales
lookback_days: Days to look back (default 7)
min_value: Minimum transaction value in dollars
top_n: Number of top results to return
return_structured: If True, returns list of dicts instead of markdown
Returns:
If return_structured=True: list of transaction dicts
If return_structured=False: Formatted markdown report
"""
try:
filter_desc = "insider buying" if transaction_type == "buy" else "insider selling"
logger.info(f"Discovering tickers with {filter_desc} from OpenInsider...")
# OpenInsider screener URL
# xp=1 means exclude private transactions
# fd=7 means last 7 days filing date
# vl=25 means minimum value $25k
if transaction_type == "buy":
url = f"http://openinsider.com/screener?s=&o=&pl=&ph=&ll=&lh=&fd={lookback_days}&fdr=&td=0&tdr=&fdlyl=&fdlyh=&dtefrom=&dteto=&xp=1&vl={min_value // 1000}&vh=&ocl=&och=&session=all&cnt=100&page=1"
else:
url = f"http://openinsider.com/screener?s=&o=&pl=&ph=&ll=&lh=&fd={lookback_days}&fdr=&td=0&tdr=&fdlyl=&fdlyh=&dtefrom=&dteto=&xs=1&vl={min_value // 1000}&vh=&ocl=&och=&sic1=-1&sicl=100&sich=9999&grp=0&nfl=&nfh=&nil=&nih=&nol=&noh=&v2l=&v2h=&oc2l=&oc2h=&sortcol=4&cnt=100&page=1"
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
"Accept": "text/html",
}
response = requests.get(url, headers=headers, timeout=60)
response.raise_for_status()
soup = BeautifulSoup(response.text, "html.parser")
# Find the main data table
table = soup.find("table", class_="tinytable")
if not table:
return f"No {filter_desc} data found on OpenInsider."
tbody = table.find("tbody")
if not tbody:
return f"No {filter_desc} data found on OpenInsider."
rows = tbody.find_all("tr")
transactions = []
for row in rows:
cells = row.find_all("td")
if len(cells) < 12:
continue
try:
# OpenInsider columns:
# 0: X (checkbox), 1: Filing Date, 2: Trade Date, 3: Ticker, 4: Company Name
# 5: Insider Name, 6: Title, 7: Trade Type, 8: Price, 9: Qty, 10: Owned, 11: ΔOwn, 12: Value
ticker_cell = cells[3]
ticker_link = ticker_cell.find("a")
ticker = ticker_link.get_text(strip=True) if ticker_link else ""
if not ticker or not re.match(r"^[A-Z]{1,5}$", ticker):
continue
company = cells[4].get_text(strip=True)[:40] if len(cells) > 4 else ""
insider_name = cells[5].get_text(strip=True)[:25] if len(cells) > 5 else ""
title_raw = cells[6].get_text(strip=True) if len(cells) > 6 else ""
# "10%" means 10% beneficial owner - clarify for readability
title = "10% Owner" if title_raw == "10%" else title_raw[:20]
trade_type = cells[7].get_text(strip=True) if len(cells) > 7 else ""
price = cells[8].get_text(strip=True) if len(cells) > 8 else ""
qty = cells[9].get_text(strip=True) if len(cells) > 9 else ""
value_str = cells[12].get_text(strip=True) if len(cells) > 12 else ""
# Filter by transaction type
trade_type_lower = trade_type.lower()
if (
transaction_type == "buy"
and "buy" not in trade_type_lower
and "p -" not in trade_type_lower
):
continue
if (
transaction_type == "sell"
and "sale" not in trade_type_lower
and "s -" not in trade_type_lower
):
continue
# Parse value for sorting
value_num = 0
if value_str:
# Remove $ and + signs, handle K/M suffixes
clean_value = (
value_str.replace("$", "").replace("+", "").replace(",", "").strip()
)
try:
if "M" in clean_value:
value_num = float(clean_value.replace("M", "")) * 1_000_000
elif "K" in clean_value:
value_num = float(clean_value.replace("K", "")) * 1_000
else:
value_num = float(clean_value)
except ValueError:
value_num = 0
transactions.append(
{
"ticker": ticker,
"company": company,
"insider": insider_name,
"title": title,
"trade_type": trade_type,
"price": price,
"qty": qty,
"value_str": value_str,
"value_num": value_num,
}
)
except Exception:
continue
if not transactions:
if return_structured:
return []
return f"No {filter_desc} transactions found in the last {lookback_days} days."
# Sort by value (largest first)
transactions.sort(key=lambda x: x["value_num"], reverse=True)
# Deduplicate by ticker, keeping the largest transaction per ticker
seen_tickers = set()
unique_transactions = []
for t in transactions:
if t["ticker"] not in seen_tickers:
seen_tickers.add(t["ticker"])
unique_transactions.append(t)
if len(unique_transactions) >= top_n:
break
logger.info(
f"Discovered {len(unique_transactions)} tickers with {filter_desc} (sorted by value)"
)
# Return structured data if requested
if return_structured:
return unique_transactions
# Format report
report_lines = [
f"# Insider {'Buying' if transaction_type == 'buy' else 'Selling'} Report",
f"*Top {len(unique_transactions)} stocks by transaction value (last {lookback_days} days)*\n",
"| Ticker | Company | Insider | Title | Value | Price |",
"|--------|---------|---------|-------|-------|-------|",
]
for t in unique_transactions:
report_lines.append(
f"| {t['ticker']} | {t['company']} | {t['insider']} | {t['title']} | {t['value_str']} | {t['price']} |"
)
report_lines.append(
f"\n**Total: {len(unique_transactions)} stocks with significant {filter_desc}**"
)
report_lines.append("*Sorted by transaction value (largest first)*")
return "\n".join(report_lines)
except requests.exceptions.RequestException as e:
if return_structured:
return []
return f"Error fetching insider data from OpenInsider: {e}"
except Exception as e:
if return_structured:
return []
return f"Error processing insider screener: {e}"
def get_finviz_insider_buying(
transaction_type: str = "buy",
lookback_days: int = 7,
min_value: int = 25000,
top_n: int = 20,
) -> str:
"""Alias for get_insider_buying_screener to match registry naming convention"""
return get_insider_buying_screener(
transaction_type=transaction_type,
lookback_days=lookback_days,
min_value=min_value,
top_n=top_n,
)

View File

@ -3,11 +3,14 @@ Yahoo Finance API - Short Interest Data using yfinance
Identifies potential short squeeze candidates with high short interest
"""
import os
import yfinance as yf
from typing import Annotated
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Annotated
from tradingagents.dataflows.market_data_utils import format_markdown_table, format_market_cap
from tradingagents.dataflows.y_finance import get_ticker_info
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
def get_short_interest(
@ -37,33 +40,70 @@ def get_short_interest(
# In a production system, this would come from a screener API
watchlist = [
# Meme stocks & high short interest candidates
"GME", "AMC", "BBBY", "BYND", "CLOV", "WISH", "PLTR", "SPCE",
"GME",
"AMC",
"BBBY",
"BYND",
"CLOV",
"WISH",
"PLTR",
"SPCE",
# EV & Tech
"RIVN", "LCID", "NIO", "TSLA", "NKLA", "PLUG", "FCEL",
"RIVN",
"LCID",
"NIO",
"TSLA",
"NKLA",
"PLUG",
"FCEL",
# Biotech (often heavily shorted)
"SAVA", "NVAX", "MRNA", "BNTX", "VXRT", "SESN", "OCGN",
"SAVA",
"NVAX",
"MRNA",
"BNTX",
"VXRT",
"SESN",
"OCGN",
# Retail & Consumer
"PTON", "W", "CVNA", "DASH", "UBER", "LYFT",
"PTON",
"W",
"CVNA",
"DASH",
"UBER",
"LYFT",
# Finance & REITs
"SOFI", "HOOD", "COIN", "SQ", "AFRM",
"SOFI",
"HOOD",
"COIN",
"SQ",
"AFRM",
# Small caps with squeeze potential
"APRN", "ATER", "BBIG", "CEI", "PROG", "SNDL",
"APRN",
"ATER",
"BBIG",
"CEI",
"PROG",
"SNDL",
# Others
"TDOC", "ZM", "PTON", "NFLX", "SNAP", "PINS",
"TDOC",
"ZM",
"PTON",
"NFLX",
"SNAP",
"PINS",
]
print(f" Checking short interest for {len(watchlist)} tickers...")
logger.info(f"Checking short interest for {len(watchlist)} tickers...")
high_si_candidates = []
# Use threading to speed up API calls
def fetch_short_data(ticker):
try:
stock = yf.Ticker(ticker)
info = stock.info
info = get_ticker_info(ticker)
# Get short interest data
short_pct = info.get('shortPercentOfFloat', info.get('sharesPercentSharesOut', 0))
short_pct = info.get("shortPercentOfFloat", info.get("sharesPercentSharesOut", 0))
if short_pct and isinstance(short_pct, (int, float)):
short_pct = short_pct * 100 # Convert to percentage
else:
@ -72,9 +112,9 @@ def get_short_interest(
# Only include if meets criteria
if short_pct >= min_short_interest_pct:
# Get other data
price = info.get('currentPrice', info.get('regularMarketPrice', 0))
market_cap = info.get('marketCap', 0)
volume = info.get('volume', info.get('regularMarketVolume', 0))
price = info.get("currentPrice", info.get("regularMarketPrice", 0))
market_cap = info.get("marketCap", 0)
volume = info.get("volume", info.get("regularMarketVolume", 0))
# Categorize squeeze potential
if short_pct >= 30:
@ -111,34 +151,40 @@ def get_short_interest(
# Sort by short interest percentage (highest first)
sorted_candidates = sorted(
high_si_candidates,
key=lambda x: x["short_interest_pct"],
reverse=True
high_si_candidates, key=lambda x: x["short_interest_pct"], reverse=True
)[:top_n]
# Format output
report = f"# High Short Interest Stocks (Yahoo Finance Data)\n\n"
report = "# High Short Interest Stocks (Yahoo Finance Data)\n\n"
report += f"**Criteria**: Short Interest >{min_short_interest_pct}%\n"
report += f"**Data Source**: Yahoo Finance via yfinance\n"
report += "**Data Source**: Yahoo Finance via yfinance\n"
report += f"**Checked**: {len(watchlist)} tickers from watchlist\n\n"
report += f"**Found**: {len(sorted_candidates)} stocks with high short interest\n\n"
report += f"**Found**: {len(sorted_candidates)} stocks with high short interest\n\n"
report += "## Potential Short Squeeze Candidates\n\n"
report += "| Ticker | Price | Market Cap | Volume | Short % | Signal |\n"
report += "|--------|-------|------------|--------|---------|--------|\n"
headers = ["Ticker", "Price", "Market Cap", "Volume", "Short %", "Signal"]
rows = []
for candidate in sorted_candidates:
market_cap_str = format_market_cap(candidate['market_cap'])
report += f"| {candidate['ticker']} | "
report += f"${candidate['price']:.2f} | "
report += f"{market_cap_str} | "
report += f"{candidate['volume']:,} | "
report += f"{candidate['short_interest_pct']:.1f}% | "
report += f"{candidate['signal']} |\n"
rows.append(
[
candidate["ticker"],
f"${candidate['price']:.2f}",
format_market_cap(candidate["market_cap"]),
f"{candidate['volume']:,}",
f"{candidate['short_interest_pct']:.1f}%",
candidate["signal"],
]
)
report += format_markdown_table(headers, rows)
report += "\n\n## Signal Definitions\n\n"
report += "- **extreme_squeeze_risk**: Short interest >30% - Very high squeeze potential\n"
report += "- **high_squeeze_potential**: Short interest 20-30% - High squeeze risk\n"
report += "- **moderate_squeeze_potential**: Short interest 15-20% - Moderate squeeze risk\n"
report += (
"- **moderate_squeeze_potential**: Short interest 15-20% - Moderate squeeze risk\n"
)
report += "- **low_squeeze_potential**: Short interest 10-15% - Lower squeeze risk\n\n"
report += "**Note**: High short interest alone doesn't guarantee a squeeze. Look for positive catalysts.\n"
report += "**Limitation**: This checks a curated watchlist. For comprehensive scanning, use a stock screener with short interest filters.\n"
@ -149,41 +195,6 @@ def get_short_interest(
return f"Unexpected error in short interest detection: {str(e)}"
def parse_market_cap(market_cap_text: str) -> float:
"""Parse market cap from Finviz format (e.g., '1.23B', '456M')."""
if not market_cap_text or market_cap_text == '-':
return 0.0
market_cap_text = market_cap_text.upper().strip()
# Extract number and multiplier
match = re.match(r'([0-9.]+)([BMK])?', market_cap_text)
if not match:
return 0.0
number = float(match.group(1))
multiplier = match.group(2)
if multiplier == 'B':
return number * 1_000_000_000
elif multiplier == 'M':
return number * 1_000_000
elif multiplier == 'K':
return number * 1_000
else:
return number
def format_market_cap(market_cap: float) -> str:
"""Format market cap for display."""
if market_cap >= 1_000_000_000:
return f"${market_cap / 1_000_000_000:.2f}B"
elif market_cap >= 1_000_000:
return f"${market_cap / 1_000_000:.2f}M"
else:
return f"${market_cap:,.0f}"
def get_fmp_short_interest(
min_short_interest_pct: float = 10.0,
min_days_to_cover: float = 2.0,

View File

@ -1,6 +1,8 @@
from typing import Annotated
from datetime import datetime
from typing import Annotated
from dateutil.relativedelta import relativedelta
from .googlenews_utils import getNewsData
@ -32,7 +34,9 @@ def get_google_news(
start_dt = datetime.strptime(curr_date, "%Y-%m-%d")
before = (start_dt - relativedelta(days=look_back_days)).strftime("%Y-%m-%d")
else:
raise ValueError("Must provide either (start_date, end_date) or (curr_date, look_back_days)")
raise ValueError(
"Must provide either (start_date, end_date) or (curr_date, look_back_days)"
)
news_results = getNewsData(search_query, before, target_date)
@ -40,7 +44,9 @@ def get_google_news(
for news in news_results:
news_str += (
f"### {news['title']} (source: {news['source']}) \n\n{news['snippet']}\n\n"
f"### {news['title']} (source: {news['source']}, date: {news['date']})\n"
f"Link: {news['link']}\n"
f"Snippet: {news['snippet']}\n\n"
)
if len(news_results) == 0:
@ -49,24 +55,18 @@ def get_google_news(
return f"## {search_query} Google News, from {before} to {target_date}:\n\n{news_str}"
def get_global_news_google(
date: str,
look_back_days: int = 3,
limit: int = 5
) -> str:
def get_global_news_google(date: str, look_back_days: int = 3, limit: int = 5) -> str:
"""Retrieve global market news using Google News.
Args:
date: Date for news, yyyy-mm-dd
look_back_days: Days to look back
limit: Max number of articles (not strictly enforced by underlying function but good for interface)
Returns:
Global news report
"""
# Query for general market topics
return get_google_news(
query="financial markets macroeconomics",
curr_date=date,
look_back_days=look_back_days
)
query="financial markets macroeconomics", curr_date=date, look_back_days=look_back_days
)

View File

@ -1,17 +1,20 @@
import json
import random
import time
from datetime import datetime
import requests
from bs4 import BeautifulSoup
from datetime import datetime
import time
import random
from tenacity import (
retry,
retry_if_result,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
retry_if_result,
)
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
def is_rate_limited(response):
"""Check if the response indicates rate limiting (status code 429)"""
@ -88,7 +91,7 @@ def getNewsData(query, start_date, end_date):
}
)
except Exception as e:
print(f"Error processing result: {e}")
logger.error(f"Error processing result: {e}")
# If one of the fields is not found, skip this result
continue
@ -102,7 +105,7 @@ def getNewsData(query, start_date, end_date):
page += 1
except Exception as e:
print(f"Failed after multiple retries: {e}")
logger.error(f"Failed after multiple retries: {e}")
break
return news_results

View File

@ -1,26 +1,4 @@
from typing import Annotated
# Import from vendor-specific modules
from .local import get_YFin_data, get_finnhub_news, get_finnhub_company_insider_sentiment, get_finnhub_company_insider_transactions, get_simfin_balance_sheet, get_simfin_cashflow, get_simfin_income_statements, get_reddit_global_news, get_reddit_company_news
from .y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_technical_analysis, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions, validate_ticker as validate_ticker_yfinance
from .google import get_google_news, get_global_news_google
from .openai import get_stock_news_openai, get_global_news_openai, get_fundamentals_openai
from .alpha_vantage import (
get_stock as get_alpha_vantage_stock,
get_top_gainers_losers as get_alpha_vantage_movers,
get_indicator as get_alpha_vantage_indicator,
get_fundamentals as get_alpha_vantage_fundamentals,
get_balance_sheet as get_alpha_vantage_balance_sheet,
get_cashflow as get_alpha_vantage_cashflow,
get_income_statement as get_alpha_vantage_income_statement,
get_insider_transactions as get_alpha_vantage_insider_transactions,
get_news as get_alpha_vantage_news,
get_global_news as get_alpha_vantage_global_news
)
from .alpha_vantage_common import AlphaVantageRateLimitError
from .reddit_api import get_reddit_news, get_reddit_global_news as get_reddit_api_global_news, get_reddit_trending_tickers, get_reddit_discussions
from .finnhub_api import get_recommendation_trends as get_finnhub_recommendation_trends
from .twitter_data import get_tweets as get_twitter_tweets, get_tweets_from_user as get_twitter_user_tweets
# ============================================================================
# LEGACY COMPATIBILITY LAYER
@ -29,6 +7,7 @@ from .twitter_data import get_tweets as get_twitter_tweets, get_tweets_from_user
# All new code should use tradingagents.tools.executor.execute_tool() directly.
# ============================================================================
def route_to_vendor(method: str, *args, **kwargs):
"""Route method calls to appropriate vendor implementation with fallback support.
@ -40,4 +19,4 @@ def route_to_vendor(method: str, *args, **kwargs):
from tradingagents.tools.executor import execute_tool
# Delegate to new system
return execute_tool(method, *args, **kwargs)
return execute_tool(method, *args, **kwargs)

View File

@ -1,13 +1,20 @@
from typing import Annotated
import pandas as pd
import os
from .config import DATA_DIR
from datetime import datetime
from dateutil.relativedelta import relativedelta
import json
from .reddit_utils import fetch_top_from_category
import os
from datetime import datetime
from typing import Annotated
import pandas as pd
from dateutil.relativedelta import relativedelta
from tqdm import tqdm
from .config import DATA_DIR
from .reddit_utils import fetch_top_from_category
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
def get_YFin_data_window(
symbol: Annotated[str, "ticker symbol of the company"],
curr_date: Annotated[str, "Start date in yyyy-mm-dd format"],
@ -30,9 +37,7 @@ def get_YFin_data_window(
data["DateOnly"] = data["Date"].str[:10]
# Filter data between the start and end dates (inclusive)
filtered_data = data[
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= curr_date)
]
filtered_data = data[(data["DateOnly"] >= start_date) & (data["DateOnly"] <= curr_date)]
# Drop the temporary column we created
filtered_data = filtered_data.drop("DateOnly", axis=1)
@ -43,10 +48,8 @@ def get_YFin_data_window(
):
df_string = filtered_data.to_string()
return (
f"## Raw Market Data for {symbol} from {start_date} to {curr_date}:\n\n"
+ df_string
)
return f"## Raw Market Data for {symbol} from {start_date} to {curr_date}:\n\n" + df_string
def get_YFin_data(
symbol: Annotated[str, "ticker symbol of the company"],
@ -70,9 +73,7 @@ def get_YFin_data(
data["DateOnly"] = data["Date"].str[:10]
# Filter data between the start and end dates (inclusive)
filtered_data = data[
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= end_date)
]
filtered_data = data[(data["DateOnly"] >= start_date) & (data["DateOnly"] <= end_date)]
# Drop the temporary column we created
filtered_data = filtered_data.drop("DateOnly", axis=1)
@ -82,6 +83,7 @@ def get_YFin_data(
return filtered_data
def get_finnhub_news(
query: Annotated[str, "Search query or ticker symbol"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
@ -109,9 +111,7 @@ def get_finnhub_news(
if len(data) == 0:
continue
for entry in data:
current_news = (
"### " + entry["headline"] + f" ({day})" + "\n" + entry["summary"]
)
current_news = "### " + entry["headline"] + f" ({day})" + "\n" + entry["summary"]
combined_result += current_news + "\n\n"
return f"## {query} News, from {start_date} to {end_date}:\n" + str(combined_result)
@ -191,6 +191,7 @@ def get_finnhub_company_insider_transactions(
+ "The change field reflects the variation in share count—here a negative number indicates a reduction in holdings—while share specifies the total number of shares involved. The transactionPrice denotes the per-share price at which the trade was executed, and transactionDate marks when the transaction occurred. The name field identifies the insider making the trade, and transactionCode (e.g., S for sale) clarifies the nature of the transaction. FilingDate records when the transaction was officially reported, and the unique id links to the specific SEC filing, as indicated by the source. Additionally, the symbol ties the transaction to a particular company, isDerivative flags whether the trade involves derivative securities, and currency notes the currency context of the transaction."
)
def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=None):
"""
Gets finnhub data saved and processed on disk.
@ -224,6 +225,7 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=
filtered_data[key] = value
return filtered_data
def get_simfin_balance_sheet(
ticker: Annotated[str, "ticker symbol"],
freq: Annotated[
@ -255,7 +257,7 @@ def get_simfin_balance_sheet(
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No balance sheet available before the given current date.")
logger.warning("No balance sheet available before the given current date.")
return ""
# Get the most recent balance sheet by selecting the row with the latest Publish Date
@ -302,7 +304,7 @@ def get_simfin_cashflow(
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No cash flow statement available before the given current date.")
logger.warning("No cash flow statement available before the given current date.")
return ""
# Get the most recent cash flow statement by selecting the row with the latest Publish Date
@ -349,7 +351,7 @@ def get_simfin_income_statements(
# Check if there are any available reports; if not, return a notification
if filtered_df.empty:
print("No income statement available before the given current date.")
logger.warning("No income statement available before the given current date.")
return ""
# Get the most recent income statement by selecting the row with the latest Publish Date
@ -472,4 +474,4 @@ def get_reddit_company_news(
else:
news_str += f"### {post['title']}\n\n{post['content']}\n\n"
return f"##{query} News Reddit, from {start_date} to {end_date}:\n\n{news_str}"
return f"##{query} News Reddit, from {start_date} to {end_date}:\n\n{news_str}"

View File

@ -0,0 +1,73 @@
import re
from typing import Any, List
def format_markdown_table(headers: List[str], rows: List[List[Any]]) -> str:
"""
Format a list of rows into a Markdown table.
Args:
headers: List of column headers
rows: List of rows, where each row is a list of values
Returns:
Formatted Markdown table string
"""
if not headers:
return ""
# Create header row
header_str = "| " + " | ".join(headers) + " |\n"
# Create separator row
separator_str = "| " + " | ".join(["---"] * len(headers)) + " |\n"
# Create data rows
body_str = ""
for row in rows:
# Convert all values to string and handle None
formatted_row = [str(val) if val is not None else "" for val in row]
body_str += "| " + " | ".join(formatted_row) + " |\n"
return header_str + separator_str + body_str
def parse_market_cap(market_cap_text: str) -> float:
"""Parse market cap from string format (e.g., '1.23B', '456M')."""
if not market_cap_text or market_cap_text == "-":
return 0.0
market_cap_text = str(market_cap_text).upper().strip()
# Extract number and multiplier
match = re.match(r"([0-9.]+)([BMK])?", market_cap_text)
if not match:
try:
return float(market_cap_text)
except ValueError:
return 0.0
number = float(match.group(1))
multiplier = match.group(2)
if multiplier == "B":
return number * 1_000_000_000
elif multiplier == "M":
return number * 1_000_000
elif multiplier == "K":
return number * 1_000
else:
return number
def format_market_cap(market_cap: float) -> str:
"""Format market cap for display (e.g. 1.5B, 200M)."""
if not isinstance(market_cap, (int, float)):
return str(market_cap)
if market_cap >= 1_000_000_000:
return f"${market_cap / 1_000_000_000:.2f}B"
elif market_cap >= 1_000_000:
return f"${market_cap / 1_000_000:.2f}M"
else:
return f"${market_cap:,.0f}"

View File

@ -0,0 +1,960 @@
"""
News Semantic Scanner
--------------------
Scans news from multiple sources, summarizes key themes, and enables semantic
matching against ticker descriptions to find relevant investment opportunities.
Sources:
- OpenAI web search (real-time market news)
- SEC EDGAR filings (regulatory news)
- Google News
- Alpha Vantage news
"""
import json
import os
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
import requests
from dotenv import load_dotenv
from langchain_google_genai import ChatGoogleGenerativeAI
from openai import OpenAI
from tradingagents.dataflows.discovery.utils import build_llm_log_entry
from tradingagents.schemas import FilingsList, NewsList
from tradingagents.utils.logger import get_logger
load_dotenv()
logger = get_logger(__name__)
class NewsSemanticScanner:
"""Scans and processes news for semantic ticker matching."""
def __init__(self, config: Dict[str, Any]):
"""
Initialize news scanner.
Args:
config: Configuration dict with:
- openai_api_key: OpenAI API key
- news_sources: List of sources to use
- max_news_items: Maximum news items to process
- news_lookback_hours: How far back to look for news (default: 24 hours)
"""
self.config = config
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
raise ValueError("OPENAI_API_KEY not found in environment")
self.openai_client = OpenAI(api_key=openai_api_key)
self.news_sources = config.get("news_sources", ["openai", "google_news"])
self.max_news_items = config.get("max_news_items", 20)
self.news_lookback_hours = config.get("news_lookback_hours", 24)
self.log_callback = config.get("log_callback")
# Calculate time window
self.cutoff_time = datetime.now() - timedelta(hours=self.news_lookback_hours)
def _emit_log(self, entry: Dict[str, Any]) -> None:
if self.log_callback:
try:
self.log_callback(entry)
except Exception:
pass
def _log_llm(
self,
step: str,
model: str,
prompt: Any,
output: Any,
error: str = "",
) -> None:
entry = build_llm_log_entry(
node="semantic_news",
step=step,
model=model,
prompt=prompt,
output=output,
error=error,
)
self._emit_log(entry)
def _get_time_phrase(self) -> str:
"""Generate human-readable time phrase for queries."""
if self.news_lookback_hours <= 1:
return "from the last hour"
elif self.news_lookback_hours <= 6:
return f"from the last {self.news_lookback_hours} hours"
elif self.news_lookback_hours <= 24:
return "from today"
elif self.news_lookback_hours <= 48:
return "from the last 2 days"
else:
days = int(self.news_lookback_hours / 24)
return f"from the last {days} days"
def _deduplicate_news(
self, news_items: List[Dict[str, Any]], similarity_threshold: float = 0.85
) -> List[Dict[str, Any]]:
"""
Deduplicate news items using semantic similarity (embeddings + cosine similarity).
Two-pass approach:
1. Fast hash-based pass for exact/near-exact duplicates
2. Embedding-based cosine similarity for semantically similar stories
Args:
news_items: List of news items from various sources
similarity_threshold: Cosine similarity threshold (0.85 = very similar)
Returns:
Deduplicated list, keeping highest importance version of each story
"""
import hashlib
import re
import numpy as np
if not news_items:
return []
def normalize_text(text: str) -> str:
"""Normalize text for comparison."""
if not text:
return ""
text = text.lower()
text = re.sub(r"[^\w\s]", "", text)
text = re.sub(r"\s+", " ", text).strip()
return text
def get_content_hash(item: Dict[str, Any]) -> str:
"""Generate hash from normalized title + summary."""
title = normalize_text(item.get("title", ""))
summary = normalize_text(item.get("summary", ""))[:100]
content = title + " " + summary
return hashlib.md5(content.encode()).hexdigest()
def get_news_text(item: Dict[str, Any]) -> str:
"""Get combined text for embedding."""
title = item.get("title", "")
summary = item.get("summary", "")
return f"{title}. {summary}"[:500] # Limit length for efficiency
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Compute cosine similarity between two vectors."""
norm_a = np.linalg.norm(a)
norm_b = np.linalg.norm(b)
if norm_a == 0 or norm_b == 0:
return 0.0
return float(np.dot(a, b) / (norm_a * norm_b))
# === PASS 1: Hash-based deduplication (fast, exact matches) ===
seen_hashes: Dict[str, Dict[str, Any]] = {}
hash_duplicates = 0
for item in news_items:
content_hash = get_content_hash(item)
if content_hash not in seen_hashes:
seen_hashes[content_hash] = item
else:
existing = seen_hashes[content_hash]
if (item.get("importance", 0) or 0) > (existing.get("importance", 0) or 0):
seen_hashes[content_hash] = item
hash_duplicates += 1
after_hash = list(seen_hashes.values())
logger.info(
f"Hash dedup: {len(news_items)}{len(after_hash)} ({hash_duplicates} exact duplicates)"
)
# === PASS 2: Embedding-based semantic similarity ===
# Only run if we have enough items to justify the cost
if len(after_hash) <= 3:
return after_hash
try:
# Generate embeddings for all remaining items
texts = [get_news_text(item) for item in after_hash]
# Use OpenAI embeddings (same as ticker_semantic_db)
response = self.openai_client.embeddings.create(
model="text-embedding-3-small",
input=texts,
)
embeddings = np.array([e.embedding for e in response.data])
# Find semantic duplicates using cosine similarity
unique_indices = []
semantic_duplicates = 0
for i in range(len(after_hash)):
is_duplicate = False
for j in unique_indices:
sim = cosine_similarity(embeddings[i], embeddings[j])
if sim >= similarity_threshold:
# This is a semantic duplicate
is_duplicate = True
semantic_duplicates += 1
# Keep higher importance version
existing_item = after_hash[j]
new_item = after_hash[i]
if (new_item.get("importance", 0) or 0) > (
existing_item.get("importance", 0) or 0
):
# Replace with higher importance
unique_indices.remove(j)
unique_indices.append(i)
logger.debug(
f"Semantic duplicate (sim={sim:.2f}): "
f"'{new_item.get('title', '')[:40]}' vs "
f"'{existing_item.get('title', '')[:40]}'"
)
break
if not is_duplicate:
unique_indices.append(i)
final_items = [after_hash[i] for i in unique_indices]
logger.info(
f"Semantic dedup: {len(after_hash)}{len(final_items)} "
f"({semantic_duplicates} similar stories merged)"
)
return final_items
except Exception as e:
logger.warning(f"Embedding-based dedup failed, using hash-only results: {e}")
return after_hash
def _filter_by_time(self, news_items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Filter news items by timestamp to respect lookback window.
Args:
news_items: List of news items with 'published_at' or 'timestamp' field
Returns:
Filtered list of news items within time window
"""
filtered = []
filtered_out_count = 0
for item in news_items:
timestamp_str = item.get("published_at") or item.get("timestamp")
title_preview = item.get("title", "")[:60]
if not timestamp_str:
# No timestamp, keep it (assume recent)
logger.debug(f"No timestamp for '{title_preview}', keeping")
filtered.append(item)
continue
item_time = self._parse_timestamp(timestamp_str, date_only_end=True)
if not item_time:
# If parsing fails, keep it
logger.debug(f"Parse failed for '{timestamp_str}' on '{title_preview}', keeping")
filtered.append(item)
continue
if item_time >= self.cutoff_time:
filtered.append(item)
else:
filtered_out_count += 1
logger.debug(
f"FILTERED OUT: '{title_preview}' | "
f"published_at='{item.get('published_at')}' | "
f"parsed={item_time.strftime('%Y-%m-%d %H:%M')} | "
f"cutoff={self.cutoff_time.strftime('%Y-%m-%d %H:%M')}"
)
if filtered_out_count > 0:
logger.info(
f"Time filter removed {filtered_out_count} items with timestamps before cutoff"
)
return filtered
def _parse_timestamp(self, timestamp_str: str, date_only_end: bool) -> Optional[datetime]:
"""Parse a timestamp string into a naive datetime, or return None if invalid."""
try:
# Handle date-only strings
if len(timestamp_str) == 10 and timestamp_str[4] == "-" and timestamp_str[7] == "-":
base_time = datetime.fromisoformat(timestamp_str)
if date_only_end:
return base_time.replace(hour=23, minute=59, second=59)
return base_time
# Parse ISO timestamp
parsed_time = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
if parsed_time.tzinfo:
parsed_time = parsed_time.astimezone().replace(tzinfo=None)
return parsed_time
except Exception:
return None
def _publish_date_range(
self, news_items: List[Dict[str, Any]]
) -> Tuple[Optional[datetime], Optional[datetime]]:
"""Get the earliest and latest publish timestamps from a list of news items."""
min_time = None
max_time = None
for item in news_items:
timestamp_str = item.get("published_at") or item.get("timestamp")
if not timestamp_str:
continue
item_time = self._parse_timestamp(timestamp_str, date_only_end=False)
if not item_time:
continue
if min_time is None or item_time < min_time:
min_time = item_time
if max_time is None or item_time > max_time:
max_time = item_time
return min_time, max_time
def _build_web_search_prompt(self, query: str = "breaking stock market news today") -> str:
"""
Build unified web search prompt for both OpenAI and Gemini.
Args:
query: Search query for news
Returns:
Formatted search prompt string
"""
time_phrase = self._get_time_phrase()
time_query = f"{query} {time_phrase}"
current_datetime = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
cutoff_datetime = self.cutoff_time.strftime("%Y-%m-%dT%H:%M:%S")
return f"""Search the web for: {time_query}
CRITICAL TIME CONSTRAINT:
- Current time: {current_datetime}
- Only include news published AFTER: {cutoff_datetime}
- Skip any articles older than {self.news_lookback_hours} hours
Find the top {self.max_news_items} most important market-moving news stories from the last {self.news_lookback_hours} hours.
Prefer company-specific or single-catalyst stories that are likely to impact only one company or a small number of companies. Avoid broad market, index, or macroeconomic headlines unless they have a clear company-specific catalyst.
Focus on:
- Earnings reports and guidance
- FDA approvals / regulatory decisions
- Mergers, acquisitions, partnerships
- Product launches
- Executive changes
- Legal/regulatory actions
- Analyst upgrades/downgrades
For each news item, extract:
- title: Headline
- summary: 2-3 sentence summary of key points
- published_at: ISO-8601 timestamp (REQUIRED - convert relative times like "2 hours ago" to full timestamp using current time {current_datetime})
- companies_mentioned: List of ticker symbols or company names mentioned
- themes: List of key themes (e.g., "earnings beat", "FDA approval", "merger")
- sentiment: one of positive, negative, neutral
- importance: 1-10 score (10 = highly market-moving)
"""
def _build_openai_input(self, system_text: str, user_text: str) -> str:
"""Build Responses API input as a single prompt string."""
if system_text:
return f"{system_text}\n\n{user_text}"
return user_text
def _fetch_openai_news(
self, query: str = "breaking stock market news today"
) -> List[Dict[str, Any]]:
"""
Fetch news using OpenAI's web search capability.
Args:
query: Search query for news
Returns:
List of news items with title, summary, published_at, timestamp
"""
try:
# Build search prompt
search_prompt = self._build_web_search_prompt(query)
# Use OpenAI web search tool for real-time news
response = self.openai_client.responses.parse(
model="gpt-4o",
tools=[{"type": "web_search"}],
input=self._build_openai_input(
"You are a financial news analyst. Search the web for the latest market news "
"and return structured summaries.",
search_prompt,
),
text_format=NewsList,
)
news_list = response.output_parsed
news_items = [item.model_dump() for item in news_list.news]
self._log_llm(
step="OpenAI web search",
model="gpt-4o",
prompt=search_prompt,
output=news_items,
)
# Add metadata
for item in news_items:
item["source"] = "openai_search"
item["timestamp"] = datetime.now().isoformat()
return news_items[: self.max_news_items]
except Exception as e:
self._log_llm(
step="OpenAI web search",
model="gpt-4o",
prompt=search_prompt if "search_prompt" in locals() else "",
output="",
error=str(e),
)
logger.error(f"Error fetching OpenAI news: {e}")
return []
def _fetch_google_news(self, query: str = "stock market") -> List[Dict[str, Any]]:
"""
Fetch news from Google News RSS.
Args:
query: Search query
Returns:
List of news items
"""
try:
# Use Google News helper
from tradingagents.dataflows.google import get_google_news
# Convert hours to days (round up)
lookback_days = max(1, int((self.news_lookback_hours + 23) / 24))
news_report = get_google_news(
query=query,
curr_date=datetime.now().strftime("%Y-%m-%d"),
look_back_days=lookback_days,
)
# Parse the report using LLM to extract structured data
current_datetime = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
cutoff_datetime = self.cutoff_time.strftime("%Y-%m-%dT%H:%M:%S")
parse_prompt = f"""Parse this news report and extract individual news items.
CRITICAL TIME CONSTRAINT:
- Current time: {current_datetime}
- Only include news published AFTER: {cutoff_datetime}
- Skip any articles older than {self.news_lookback_hours} hours
Prefer company-specific or single-catalyst stories that are likely to impact only one company or a small number of companies. Avoid broad market, index, or macroeconomic headlines unless they have a clear company-specific catalyst. If a story is broad or sector-wide without a specific company catalyst, skip it.
{news_report}
For each news item, extract:
- title: Headline
- summary: Brief summary
- published_at: ISO-8601 timestamp (REQUIRED - convert relative times like "2 hours ago" to full timestamp using current time {current_datetime})
- companies_mentioned: Companies or tickers mentioned
- themes: Key themes
- sentiment: one of positive, negative, neutral
- importance: 1-10 score
Return as JSON array with key "news"."""
response = self.openai_client.responses.parse(
model="gpt-4o-mini",
input=self._build_openai_input(
"Extract news items from this report into structured JSON format.",
parse_prompt,
),
text_format=NewsList,
)
news_list = response.output_parsed
news_items = [item.model_dump() for item in news_list.news]
self._log_llm(
step="Parse Google News",
model="gpt-4o-mini",
prompt=parse_prompt,
output=news_items,
)
# Add metadata
for item in news_items:
item["source"] = "google_news"
item["timestamp"] = datetime.now().isoformat()
return news_items[: self.max_news_items]
except Exception as e:
self._log_llm(
step="Parse Google News",
model="gpt-4o-mini",
prompt=parse_prompt if "parse_prompt" in locals() else "",
output="",
error=str(e),
)
logger.error(f"Error fetching Google News: {e}")
return []
def _fetch_sec_filings(self) -> List[Dict[str, Any]]:
"""
Fetch recent SEC filings (8-K, 13D, 13G - market-moving events).
Returns:
List of filing summaries
"""
try:
# SEC EDGAR API endpoint
# Get recent 8-K filings (material events)
url = "https://www.sec.gov/cgi-bin/browse-edgar"
params = {"action": "getcurrent", "type": "8-K", "output": "atom", "count": 20}
headers = {"User-Agent": "TradingAgents/1.0 (contact@example.com)"}
response = requests.get(url, params=params, headers=headers, timeout=10)
if response.status_code != 200:
return []
# Parse SEC filings using LLM
# (SEC returns XML/Atom feed, we'll parse with LLM for simplicity)
current_datetime = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
cutoff_datetime = self.cutoff_time.strftime("%Y-%m-%dT%H:%M:%S")
filings_prompt = f"""Parse these SEC filings and extract the most important ones.
CRITICAL TIME CONSTRAINT:
- Current time: {current_datetime}
- Only include filings submitted AFTER: {cutoff_datetime}
- Skip any filings older than {self.news_lookback_hours} hours
Prefer company-specific filings and material events; skip broad market commentary.
{response.text} # Limit to avoid token limits
For each important filing, extract:
- title: Company name and filing type
- summary: What the material event is about
- published_at: ISO-8601 timestamp (REQUIRED - extract from filing date/time)
- companies_mentioned: [company name and ticker if available]
- themes: Type of event (e.g., "acquisition", "earnings guidance", "executive change")
- sentiment: one of positive, negative, neutral
- importance: 1-10 score
Return as JSON array with key "filings"."""
llm_response = self.openai_client.responses.parse(
model="gpt-4o-mini",
input=self._build_openai_input(
"Extract important SEC 8-K filings from this data and summarize the market-moving events.",
filings_prompt,
),
text_format=FilingsList,
)
filings_list = llm_response.output_parsed
filings = [item.model_dump() for item in filings_list.filings]
self._log_llm(
step="Parse SEC filings",
model="gpt-4o-mini",
prompt=filings_prompt,
output=filings,
)
# Add metadata
for filing in filings:
filing["source"] = "sec_edgar"
filing["timestamp"] = datetime.now().isoformat()
return filings[: self.max_news_items]
except Exception as e:
self._log_llm(
step="Parse SEC filings",
model="gpt-4o-mini",
prompt=filings_prompt if "filings_prompt" in locals() else "",
output="",
error=str(e),
)
logger.error(f"Error fetching SEC filings: {e}")
return []
def _fetch_alpha_vantage_news(
self, topics: str = "earnings,technology"
) -> List[Dict[str, Any]]:
"""
Fetch news from Alpha Vantage.
Args:
topics: News topics to filter
Returns:
List of news items
"""
try:
from tradingagents.dataflows.alpha_vantage_news import get_alpha_vantage_news_feed
# Use cutoff time for Alpha Vantage
time_from = self.cutoff_time.strftime("%Y%m%dT%H%M")
news_report = get_alpha_vantage_news_feed(topics=topics, time_from=time_from, limit=50)
# Parse with LLM
current_datetime = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
cutoff_datetime = self.cutoff_time.strftime("%Y-%m-%dT%H:%M:%S")
parse_prompt = f"""Parse this news feed and extract the most important market-moving stories.
CRITICAL TIME CONSTRAINT:
- Current time: {current_datetime}
- Only include news published AFTER: {cutoff_datetime}
- Skip any articles older than {self.news_lookback_hours} hours
Prefer company-specific or single-catalyst stories that are likely to impact only one company or a small number of companies. Avoid broad market, index, or macroeconomic headlines unless they have a clear company-specific catalyst. If a story is broad or sector-wide without a specific company catalyst, skip it.
{news_report}
For each news item, extract:
- title: Headline
- summary: Key points
- published_at: ISO-8601 timestamp (REQUIRED - extract from the data or convert relative times using current time {current_datetime})
- companies_mentioned: Tickers/companies mentioned
- themes: Key themes
- sentiment: one of positive, negative, neutral
- importance: 1-10 score (10 = highly market-moving)
Return as JSON array with key "news"."""
response = self.openai_client.responses.parse(
model="gpt-4o-mini",
input=self._build_openai_input(
"Extract and summarize important market news.",
parse_prompt,
),
text_format=NewsList,
)
news_list = response.output_parsed
news_items = [item.model_dump() for item in news_list.news]
self._log_llm(
step="Parse Alpha Vantage news",
model="gpt-4o-mini",
prompt=parse_prompt,
output=news_items,
)
# Add metadata
for item in news_items:
item["source"] = "alpha_vantage"
item["timestamp"] = datetime.now().isoformat()
return news_items[: self.max_news_items]
except Exception as e:
self._log_llm(
step="Parse Alpha Vantage news",
model="gpt-4o-mini",
prompt=parse_prompt if "parse_prompt" in locals() else "",
output="",
error=str(e),
)
logger.error(f"Error fetching Alpha Vantage news: {e}")
return []
def _fetch_gemini_search_news(
self, query: str = "breaking stock market news today"
) -> List[Dict[str, Any]]:
"""
Fetch news using Google Gemini's native web search (grounding) capability.
This uses Gemini's built-in web search tool for real-time market news,
which may provide different results than OpenAI's web search.
Args:
query: Search query for news
Returns:
List of news items with title, summary, published_at, timestamp
"""
try:
import os
# Get API key
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
logger.error("GOOGLE_API_KEY not set, skipping Gemini search")
return []
# Build search prompt
search_prompt = self._build_web_search_prompt(query)
# Step 1: Execute web search using Gemini with google_search tool
search_llm = ChatGoogleGenerativeAI(
model="gemini-2.5-flash-lite", # Fast model for search
api_key=google_api_key,
temperature=1.0, # Higher temperature for diverse results
).bind_tools([{"google_search": {}}])
# Execute search
raw_response = search_llm.invoke(search_prompt)
self._log_llm(
step="Gemini search",
model="gemini-2.5-flash-lite",
prompt=search_prompt,
output=raw_response.content if hasattr(raw_response, "content") else raw_response,
)
# Step 2: Structure the results using Gemini with JSON schema
structured_llm = ChatGoogleGenerativeAI(
model="gemini-2.5-flash-lite", api_key=google_api_key
).with_structured_output(NewsList, method="json_schema")
current_datetime = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
cutoff_datetime = self.cutoff_time.strftime("%Y-%m-%dT%H:%M:%S")
structure_prompt = f"""Parse the following web search results into structured news items.
CRITICAL TIME CONSTRAINT:
- Current time: {current_datetime}
- Only include news published AFTER: {cutoff_datetime}
- Skip any articles older than {self.news_lookback_hours} hours
For each news item, extract:
- title: Headline
- summary: 2-3 sentence summary of key points
- published_at: ISO-8601 timestamp (REQUIRED - convert "X hours ago" to full timestamp using current time {current_datetime})
- companies_mentioned: List of ticker symbols or company names
- themes: List of key themes (e.g., "earnings beat", "FDA approval", "merger")
- sentiment: one of positive, negative, neutral
- importance: 1-10 score (10 = highly market-moving)
Web search results:
{raw_response.content}
Return as JSON with "news" array."""
structured_response = structured_llm.invoke(structure_prompt)
self._log_llm(
step="Gemini search structuring",
model="gemini-2.5-flash-lite",
prompt=structure_prompt,
output=structured_response,
)
# Extract news items
news_items = [item.model_dump() for item in structured_response.news]
# Add metadata
for item in news_items:
item["source"] = "gemini_search"
item["timestamp"] = datetime.now().isoformat()
return news_items[: self.max_news_items]
except Exception as e:
self._log_llm(
step="Gemini search",
model="gemini-2.5-flash-lite",
prompt=search_prompt if "search_prompt" in locals() else "",
output="",
error=str(e),
)
logger.error(f"Error fetching Gemini search news: {e}")
return []
def scan_news(self) -> List[Dict[str, Any]]:
"""
Scan news from all enabled sources.
Returns:
Aggregated list of news items sorted by importance
"""
all_news = []
logger.info("Scanning news sources...")
logger.info(f"Time window: {self._get_time_phrase()} (last {self.news_lookback_hours}h)")
logger.info(f"Cutoff: {self.cutoff_time.strftime('%Y-%m-%d %H:%M')}")
# Fetch from each enabled source
if "openai" in self.news_sources:
logger.info("Fetching OpenAI web search...")
openai_news = self._fetch_openai_news()
all_news.extend(openai_news)
logger.info(f"Found {len(openai_news)} items from OpenAI")
min_date, max_date = self._publish_date_range(openai_news)
if min_date:
logger.debug(f"Min publish date (OpenAI): {min_date.strftime('%Y-%m-%d %H:%M')}")
else:
logger.debug("Min publish date (OpenAI): N/A")
if max_date:
logger.debug(f"Max publish date (OpenAI): {max_date.strftime('%Y-%m-%d %H:%M')}")
else:
logger.debug("Max publish date (OpenAI): N/A")
if "google_news" in self.news_sources:
logger.info("Fetching Google News...")
google_news = self._fetch_google_news()
all_news.extend(google_news)
logger.info(f"Found {len(google_news)} items from Google News")
min_date, max_date = self._publish_date_range(google_news)
if min_date:
logger.debug(f"Min publish date (Google News): {min_date.strftime('%Y-%m-%d %H:%M')}")
else:
logger.debug("Min publish date (Google News): N/A")
if max_date:
logger.debug(f"Max publish date (Google News): {max_date.strftime('%Y-%m-%d %H:%M')}")
else:
logger.debug("Max publish date (Google News): N/A")
if "sec_filings" in self.news_sources:
logger.info("Fetching SEC filings...")
sec_filings = self._fetch_sec_filings()
all_news.extend(sec_filings)
logger.info(f"Found {len(sec_filings)} items from SEC")
min_date, max_date = self._publish_date_range(sec_filings)
if min_date:
logger.debug(f"Min publish date (SEC): {min_date.strftime('%Y-%m-%d %H:%M')}")
else:
logger.debug("Min publish date (SEC): N/A")
if max_date:
logger.debug(f"Max publish date (SEC): {max_date.strftime('%Y-%m-%d %H:%M')}")
else:
logger.debug("Max publish date (SEC): N/A")
if "alpha_vantage" in self.news_sources:
logger.info("Fetching Alpha Vantage news...")
av_news = self._fetch_alpha_vantage_news()
all_news.extend(av_news)
logger.info(f"Found {len(av_news)} items from Alpha Vantage")
min_date, max_date = self._publish_date_range(av_news)
if min_date:
logger.debug(f"Min publish date (Alpha Vantage): {min_date.strftime('%Y-%m-%d %H:%M')}")
else:
logger.debug("Min publish date (Alpha Vantage): N/A")
if max_date:
logger.debug(f"Max publish date (Alpha Vantage): {max_date.strftime('%Y-%m-%d %H:%M')}")
else:
logger.debug("Max publish date (Alpha Vantage): N/A")
if "gemini_search" in self.news_sources:
logger.info("Fetching Google Gemini search...")
gemini_news = self._fetch_gemini_search_news()
all_news.extend(gemini_news)
logger.info(f"Found {len(gemini_news)} items from Gemini search")
min_date, max_date = self._publish_date_range(gemini_news)
if min_date:
logger.debug(f"Min publish date (Gemini): {min_date.strftime('%Y-%m-%d %H:%M')}")
else:
logger.debug("Min publish date (Gemini): N/A")
if max_date:
logger.debug(f"Max publish date (Gemini): {max_date.strftime('%Y-%m-%d %H:%M')}")
else:
logger.debug("Max publish date (Gemini): N/A")
# Apply time filtering
logger.info(f"Collected {len(all_news)} raw news items")
all_news = self._filter_by_time(all_news)
logger.info(f"After time filtering: {len(all_news)} items")
# Deduplicate news from multiple sources (same story = same hash)
all_news = self._deduplicate_news(all_news)
logger.info(f"After deduplication: {len(all_news)} items")
# Sort by importance
all_news.sort(key=lambda x: x.get("importance", 0), reverse=True)
logger.info(f"Total news items collected: {len(all_news)}")
return all_news[: self.max_news_items]
def generate_news_summary(self, news_item: Dict[str, Any]) -> str:
"""
Generate a semantic search-optimized summary for a news item.
Args:
news_item: News item dict
Returns:
Optimized summary text for embedding/matching
"""
title = news_item.get("title", "")
summary = news_item.get("summary", "")
themes = news_item.get("themes", [])
companies = news_item.get("companies_mentioned", [])
# Create rich text for semantic matching
search_text = f"""
{title}
{summary}
Key themes: {', '.join(themes) if themes else 'General market news'}
Companies mentioned: {', '.join(companies) if companies else 'Broad market'}
""".strip()
return search_text
def main():
"""CLI for testing news scanner."""
import argparse
parser = argparse.ArgumentParser(description="Scan news for semantic ticker matching")
parser.add_argument(
"--sources",
nargs="+",
default=["openai"],
choices=["openai", "google_news", "sec_filings", "alpha_vantage", "gemini_search"],
help="News sources to use",
)
parser.add_argument("--max-items", type=int, default=10, help="Maximum news items to fetch")
parser.add_argument(
"--lookback-hours",
type=int,
default=24,
help="How far back to look for news (in hours). Examples: 1 (last hour), 6 (last 6 hours), 24 (last day), 168 (last week)",
)
parser.add_argument("--output", type=str, help="Output file for news JSON")
args = parser.parse_args()
config = {
"news_sources": args.sources,
"max_news_items": args.max_items,
"news_lookback_hours": args.lookback_hours,
}
scanner = NewsSemanticScanner(config)
news_items = scanner.scan_news()
# Display results
logger.info("\n" + "=" * 60)
logger.info(f"Top {min(5, len(news_items))} Most Important News Items:")
logger.info("=" * 60 + "\n")
for i, item in enumerate(news_items[:5], 1):
logger.info(f"{i}. {item.get('title', 'Untitled')}")
logger.info(f" Source: {item.get('source', 'unknown')}")
logger.info(f" Importance: {item.get('importance', 'N/A')}/10")
logger.info(f" Summary: {item.get('summary', '')[:150]}...")
logger.info(f" Themes: {', '.join(item.get('themes', []))}")
logger.info("")
# Save to file if specified
if args.output:
with open(args.output, "w") as f:
json.dump(news_items, f, indent=2)
logger.info(f"✅ Saved {len(news_items)} news items to {args.output}")
if __name__ == "__main__":
main()

View File

@ -1,6 +1,15 @@
import os
import warnings
from openai import OpenAI
from .config import get_config
from tradingagents.config import config
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
# Suppress Pydantic serialization warnings from OpenAI web search
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic.main")
_OPENAI_CLIENT = None
@ -8,7 +17,7 @@ _OPENAI_CLIENT = None
def _get_openai_client() -> OpenAI:
global _OPENAI_CLIENT
if _OPENAI_CLIENT is None:
_OPENAI_CLIENT = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
_OPENAI_CLIENT = OpenAI(api_key=config.validate_key("openai_api_key", "OpenAI"))
return _OPENAI_CLIENT
@ -36,7 +45,7 @@ def get_stock_news_openai(query=None, ticker=None, start_date=None, end_date=Non
response = client.responses.create(
model="gpt-4o-mini",
tools=[{"type": "web_search_preview"}],
input=f"Search Social Media and news sources for {search_query} from {start_date} to {end_date}. Make sure you only get the data posted during that period."
input=f"Search Social Media and news sources for {search_query} from {start_date} to {end_date}. Make sure you only get the data posted during that period.",
)
return response.output_text
except Exception as e:
@ -50,7 +59,7 @@ def get_global_news_openai(date, look_back_days=7, limit=5):
response = client.responses.create(
model="gpt-4o-mini",
tools=[{"type": "web_search_preview"}],
input=f"Search global or macroeconomics news from {look_back_days} days before {date} that would be informative for trading purposes. Make sure you only get the data posted during that period. Limit the results to {limit} articles."
input=f"Search global or macroeconomics news from {look_back_days} days before {date} that would be informative for trading purposes. Make sure you only get the data posted during that period. Limit the results to {limit} articles.",
)
return response.output_text
except Exception as e:
@ -64,8 +73,197 @@ def get_fundamentals_openai(ticker, curr_date):
response = client.responses.create(
model="gpt-4o-mini",
tools=[{"type": "web_search_preview"}],
input=f"Search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc"
input=f"Search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
)
return response.output_text
except Exception as e:
return f"Error fetching fundamentals from OpenAI: {str(e)}"
def get_batch_stock_news_openai(
tickers: list[str],
start_date: str,
end_date: str,
batch_size: int = 10,
) -> dict[str, str]:
"""Fetch news for multiple tickers in batched OpenAI calls.
Instead of making one API call per ticker, this batches tickers together
to significantly reduce API costs (~90% savings for 50 tickers).
Args:
tickers: List of ticker symbols
start_date: Start date yyyy-mm-dd
end_date: End date yyyy-mm-dd
batch_size: Max tickers per API call (default 10 to avoid output truncation)
Returns:
dict: {ticker: "news summary text", ...}
"""
from typing import List
from pydantic import BaseModel
# Define structured output schema (matching working snippet)
class TickerNews(BaseModel):
ticker: str
news_summary: str
date: str
class PortfolioUpdate(BaseModel):
items: List[TickerNews]
from tqdm import tqdm
client = _get_openai_client()
results = {}
# Process in batches to avoid output token limits
with tqdm(total=len(tickers), desc="📰 OpenAI batch news", unit="ticker") as pbar:
for i in range(0, len(tickers), batch_size):
batch = tickers[i : i + batch_size]
# Request comprehensive news summaries for better ranker LLM context
prompt = f"""Find the most significant news stories for {batch} from {start_date} to {end_date}.
Focus on business catalysts: earnings, product launches, partnerships, analyst changes, regulatory news.
For each ticker, provide a comprehensive summary (5-8 sentences) covering:
- What happened (the catalyst/event)
- Key numbers/metrics if applicable (revenue, earnings, deal size, etc.)
- Why it matters for investors
- Market reaction or implications
- Any forward-looking statements or guidance"""
try:
completion = client.responses.parse(
model="gpt-5-nano",
tools=[{"type": "web_search"}],
input=prompt,
text_format=PortfolioUpdate,
)
# Extract structured output
if completion.output_parsed:
for item in completion.output_parsed.items:
results[item.ticker.upper()] = item.news_summary
else:
# Fallback if parsing failed
logger.warning(f"Structured parsing returned None for batch: {batch}")
for ticker in batch:
results[ticker.upper()] = ""
except Exception as e:
logger.error(f"Error fetching batch news for {batch}: {e}")
# On error, set empty string for all tickers in batch
for ticker in batch:
results[ticker.upper()] = ""
# Update progress bar
pbar.update(len(batch))
return results
def get_batch_stock_news_google(
tickers: list[str],
start_date: str,
end_date: str,
batch_size: int = 10,
model: str = "gemini-3-flash-preview",
) -> dict[str, str]:
"""Fetch news for multiple tickers using Google Search (Gemini).
Two-step approach:
1. Use Gemini with google_search tool to gather grounded news
2. Use structured output to format into JSON
Args:
tickers: List of ticker symbols
start_date: Start date yyyy-mm-dd
end_date: End date yyyy-mm-dd
batch_size: Max tickers per API call (default 10)
model: Gemini model name (default: gemini-3-flash-preview)
Returns:
dict: {ticker: "news summary text", ...}
"""
# Create LLMs with specified model (don't use cached version)
from typing import List
from langchain_google_genai import ChatGoogleGenerativeAI
from pydantic import BaseModel
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
raise ValueError("GOOGLE_API_KEY not set in environment")
# Define schema for structured output
class TickerNews(BaseModel):
ticker: str
news_summary: str
date: str
class PortfolioUpdate(BaseModel):
items: List[TickerNews]
# Searcher: Enable web search tool
search_llm = ChatGoogleGenerativeAI(
model=model, api_key=google_api_key, temperature=1.0
).bind_tools([{"google_search": {}}])
# Formatter: Native JSON mode
structured_llm = ChatGoogleGenerativeAI(
model=model, api_key=google_api_key
).with_structured_output(PortfolioUpdate, method="json_schema")
results = {}
from tqdm import tqdm
# Process in batches
with tqdm(total=len(tickers), desc="📰 Google batch news", unit="ticker") as pbar:
for i in range(0, len(tickers), batch_size):
batch = tickers[i : i + batch_size]
# Request comprehensive news summaries for better ranker LLM context
prompt = f"""Find the most significant news stories for {batch} from {start_date} to {end_date}.
Focus on business catalysts: earnings, product launches, partnerships, analyst changes, regulatory news.
For each ticker, provide a comprehensive summary (5-8 sentences) covering:
- What happened (the catalyst/event)
- Key numbers/metrics if applicable (revenue, earnings, deal size, etc.)
- Why it matters for investors
- Market reaction or implications
- Any forward-looking statements or guidance"""
try:
# Step 1: Perform Google search (grounded response)
raw_news = search_llm.invoke(prompt)
# Step 2: Structure the grounded results
structured_result = structured_llm.invoke(
f"Using this verified news data: {raw_news.content}\n\n"
f"Format the news for these tickers into the JSON structure: {batch}\n"
f"Include all tickers from the list, even if no news was found."
)
# Extract results
if structured_result and hasattr(structured_result, "items"):
for item in structured_result.items:
results[item.ticker.upper()] = item.news_summary
else:
logger.warning(f"Structured output invalid for batch: {batch}")
for ticker in batch:
results[ticker.upper()] = ""
except Exception as e:
logger.error(f"Error fetching Google batch news for {batch}: {e}")
# On error, set empty string for all tickers in batch
for ticker in batch:
results[ticker.upper()] = ""
# Update progress bar
pbar.update(len(batch))
return results

View File

@ -1,22 +1,22 @@
import os
import praw
from datetime import datetime, timedelta
from typing import Annotated
import praw
from tradingagents.config import config
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
def get_reddit_client():
"""Initialize and return a PRAW Reddit instance."""
client_id = os.getenv("REDDIT_CLIENT_ID")
client_secret = os.getenv("REDDIT_CLIENT_SECRET")
user_agent = os.getenv("REDDIT_USER_AGENT", "trading_agents_bot/1.0")
client_id = config.validate_key("reddit_client_id", "Reddit Client ID")
client_secret = config.validate_key("reddit_client_secret", "Reddit Client Secret")
user_agent = config.reddit_user_agent
if not client_id or not client_secret:
raise ValueError("REDDIT_CLIENT_ID and REDDIT_CLIENT_SECRET must be set in environment variables.")
return praw.Reddit(client_id=client_id, client_secret=client_secret, user_agent=user_agent)
return praw.Reddit(
client_id=client_id,
client_secret=client_secret,
user_agent=user_agent
)
def get_reddit_news(
ticker: Annotated[str, "Ticker symbol"] = None,
@ -33,133 +33,163 @@ def get_reddit_news(
try:
reddit = get_reddit_client()
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
# Add one day to end_date to include the full day
end_dt = end_dt + timedelta(days=1)
# Subreddits to search
subreddits = "stocks+investing+wallstreetbets+stockmarket"
# Search queries - try multiple variations
queries = [
target_query,
f"${target_query}", # Common format on WSB
target_query.lower(),
]
posts = []
seen_ids = set() # Avoid duplicates
subreddit = reddit.subreddit(subreddits)
# Try multiple search strategies
for q in queries:
# Strategy 1: Search by relevance
for submission in subreddit.search(q, sort='relevance', time_filter='all', limit=50):
for submission in subreddit.search(q, sort="relevance", time_filter="all", limit=50):
if submission.id in seen_ids:
continue
post_date = datetime.fromtimestamp(submission.created_utc)
if start_dt <= post_date <= end_dt:
seen_ids.add(submission.id)
# Fetch top comments for this post
submission.comment_sort = 'top'
submission.comment_sort = "top"
submission.comments.replace_more(limit=0)
top_comments = []
for comment in submission.comments[:5]: # Top 5 comments
if hasattr(comment, 'body') and hasattr(comment, 'score'):
top_comments.append({
'body': comment.body[:300] + "..." if len(comment.body) > 300 else comment.body,
'score': comment.score,
'author': str(comment.author) if comment.author else '[deleted]'
})
posts.append({
"title": submission.title,
"score": submission.score,
"num_comments": submission.num_comments,
"date": post_date.strftime("%Y-%m-%d"),
"url": submission.url,
"text": submission.selftext[:500] + "..." if len(submission.selftext) > 500 else submission.selftext,
"subreddit": submission.subreddit.display_name,
"top_comments": top_comments
})
if hasattr(comment, "body") and hasattr(comment, "score"):
top_comments.append(
{
"body": (
comment.body[:300] + "..."
if len(comment.body) > 300
else comment.body
),
"score": comment.score,
"author": (
str(comment.author) if comment.author else "[deleted]"
),
}
)
posts.append(
{
"title": submission.title,
"score": submission.score,
"num_comments": submission.num_comments,
"date": post_date.strftime("%Y-%m-%d"),
"url": submission.url,
"text": (
submission.selftext[:500] + "..."
if len(submission.selftext) > 500
else submission.selftext
),
"subreddit": submission.subreddit.display_name,
"top_comments": top_comments,
}
)
# Strategy 2: Search by new (for recent posts)
for submission in subreddit.search(q, sort='new', time_filter='week', limit=50):
for submission in subreddit.search(q, sort="new", time_filter="week", limit=50):
if submission.id in seen_ids:
continue
post_date = datetime.fromtimestamp(submission.created_utc)
if start_dt <= post_date <= end_dt:
seen_ids.add(submission.id)
submission.comment_sort = 'top'
submission.comment_sort = "top"
submission.comments.replace_more(limit=0)
top_comments = []
for comment in submission.comments[:5]:
if hasattr(comment, 'body') and hasattr(comment, 'score'):
top_comments.append({
'body': comment.body[:300] + "..." if len(comment.body) > 300 else comment.body,
'score': comment.score,
'author': str(comment.author) if comment.author else '[deleted]'
})
posts.append({
"title": submission.title,
"score": submission.score,
"num_comments": submission.num_comments,
"date": post_date.strftime("%Y-%m-%d"),
"url": submission.url,
"text": submission.selftext[:500] + "..." if len(submission.selftext) > 500 else submission.selftext,
"subreddit": submission.subreddit.display_name,
"top_comments": top_comments
})
if hasattr(comment, "body") and hasattr(comment, "score"):
top_comments.append(
{
"body": (
comment.body[:300] + "..."
if len(comment.body) > 300
else comment.body
),
"score": comment.score,
"author": (
str(comment.author) if comment.author else "[deleted]"
),
}
)
posts.append(
{
"title": submission.title,
"score": submission.score,
"num_comments": submission.num_comments,
"date": post_date.strftime("%Y-%m-%d"),
"url": submission.url,
"text": (
submission.selftext[:500] + "..."
if len(submission.selftext) > 500
else submission.selftext
),
"subreddit": submission.subreddit.display_name,
"top_comments": top_comments,
}
)
if not posts:
return f"No Reddit posts found for {target_query} between {start_date} and {end_date}."
# Format output
report = f"## Reddit Discussions for {target_query} ({start_date} to {end_date})\n\n"
report += f"**Total Posts Found:** {len(posts)}\n\n"
# Sort by score (popularity)
posts.sort(key=lambda x: x["score"], reverse=True)
# Detailed view of top posts
report += "### Top Posts with Community Reactions\n\n"
for i, post in enumerate(posts[:10], 1): # Top 10 posts
report += f"#### {i}. [{post['subreddit']}] {post['title']}\n"
report += f"**Score:** {post['score']} | **Comments:** {post['num_comments']} | **Date:** {post['date']}\n\n"
if post['text']:
if post["text"]:
report += f"**Post Content:**\n{post['text']}\n\n"
if post['top_comments']:
if post["top_comments"]:
report += f"**Top Community Reactions ({len(post['top_comments'])} comments):**\n"
for j, comment in enumerate(post['top_comments'], 1):
for j, comment in enumerate(post["top_comments"], 1):
report += f"{j}. *[{comment['score']} upvotes]* u/{comment['author']}: {comment['body']}\n"
report += "\n"
report += f"**Link:** {post['url']}\n\n"
report += "---\n\n"
# Summary statistics
total_engagement = sum(p['score'] + p['num_comments'] for p in posts)
avg_score = sum(p['score'] for p in posts) / len(posts) if posts else 0
total_engagement = sum(p["score"] + p["num_comments"] for p in posts)
avg_score = sum(p["score"] for p in posts) / len(posts) if posts else 0
report += "### Summary Statistics\n"
report += f"- **Total Posts:** {len(posts)}\n"
report += f"- **Average Score:** {avg_score:.1f}\n"
report += f"- **Total Engagement:** {total_engagement:,} (upvotes + comments)\n"
report += f"- **Most Active Subreddit:** {max(posts, key=lambda x: x['score'])['subreddit']}\n"
report += (
f"- **Most Active Subreddit:** {max(posts, key=lambda x: x['score'])['subreddit']}\n"
)
return report
except Exception as e:
@ -181,43 +211,45 @@ def get_reddit_global_news(
try:
reddit = get_reddit_client()
curr_dt = datetime.strptime(target_date, "%Y-%m-%d")
start_dt = curr_dt - timedelta(days=look_back_days)
# Subreddits for global news
subreddits = "financenews+finance+economics+stockmarket"
posts = []
subreddit = reddit.subreddit(subreddits)
# For global news, we just want top posts from the period
# We can use 'top' with time_filter, but 'week' is a fixed window.
# Better to iterate top of 'week' and filter by date.
for submission in subreddit.top(time_filter='week', limit=50):
for submission in subreddit.top(time_filter="week", limit=50):
post_date = datetime.fromtimestamp(submission.created_utc)
if start_dt <= post_date <= curr_dt + timedelta(days=1):
posts.append({
"title": submission.title,
"score": submission.score,
"date": post_date.strftime("%Y-%m-%d"),
"subreddit": submission.subreddit.display_name
})
posts.append(
{
"title": submission.title,
"score": submission.score,
"date": post_date.strftime("%Y-%m-%d"),
"subreddit": submission.subreddit.display_name,
}
)
if not posts:
return f"No global news found on Reddit for the past {look_back_days} days."
# Format output
report = f"## Global News from Reddit (Last {look_back_days} days)\n\n"
posts.sort(key=lambda x: x["score"], reverse=True)
for post in posts[:limit]:
report += f"### [{post['subreddit']}] {post['title']} (Score: {post['score']})\n"
report += f"**Date:** {post['date']}\n\n"
return report
except Exception as e:
@ -234,58 +266,65 @@ def get_reddit_trending_tickers(
"""
try:
reddit = get_reddit_client()
# Subreddits to scan
subreddits = "wallstreetbets+stocks+investing+stockmarket"
subreddit = reddit.subreddit(subreddits)
posts = []
# Scan hot posts
for submission in subreddit.hot(limit=limit * 2): # Fetch more to filter by date
for submission in subreddit.hot(limit=limit * 2): # Fetch more to filter by date
# Check date
post_date = datetime.fromtimestamp(submission.created_utc)
if (datetime.now() - post_date).days > look_back_days:
continue
# Fetch top comments
submission.comment_sort = 'top'
submission.comment_sort = "top"
submission.comments.replace_more(limit=0)
top_comments = []
for comment in submission.comments[:3]:
if hasattr(comment, 'body'):
if hasattr(comment, "body"):
top_comments.append(f"- {comment.body[:200]}...")
posts.append({
"title": submission.title,
"score": submission.score,
"subreddit": submission.subreddit.display_name,
"text": submission.selftext[:500] + "..." if len(submission.selftext) > 500 else submission.selftext,
"comments": top_comments
})
posts.append(
{
"title": submission.title,
"score": submission.score,
"subreddit": submission.subreddit.display_name,
"text": (
submission.selftext[:500] + "..."
if len(submission.selftext) > 500
else submission.selftext
),
"comments": top_comments,
}
)
if len(posts) >= limit:
break
if not posts:
return "No trending discussions found."
# Format report for LLM
report = "## Trending Reddit Discussions\n\n"
for i, post in enumerate(posts, 1):
report += f"### {i}. [{post['subreddit']}] {post['title']} (Score: {post['score']})\n"
if post['text']:
if post["text"]:
report += f"**Content:** {post['text']}\n"
if post['comments']:
report += "**Top Comments:**\n" + "\n".join(post['comments']) + "\n"
if post["comments"]:
report += "**Top Comments:**\n" + "\n".join(post["comments"]) + "\n"
report += "\n---\n"
return report
except Exception as e:
return f"Error fetching trending tickers: {str(e)}"
def get_reddit_discussions(
symbol: Annotated[str, "Ticker symbol"],
from_date: Annotated[str, "Start date in yyyy-mm-dd format"],
@ -302,7 +341,7 @@ def get_reddit_undiscovered_dd(
scan_limit: Annotated[int, "Number of new posts to scan"] = 100,
top_n: Annotated[int, "Number of top DD posts to return"] = 10,
num_comments: Annotated[int, "Number of top comments to include"] = 10,
llm_evaluator = None, # Will be passed from discovery graph
llm_evaluator=None, # Will be passed from discovery graph
) -> str:
"""
Find high-quality undiscovered DD using LLM evaluation.
@ -345,47 +384,77 @@ def get_reddit_undiscovered_dd(
continue
# Get top comments for community validation
submission.comment_sort = 'top'
submission.comment_sort = "top"
submission.comments.replace_more(limit=0)
top_comments = []
for comment in submission.comments[:num_comments]:
if hasattr(comment, 'body') and hasattr(comment, 'score'):
top_comments.append({
'body': comment.body[:500], # Include more of each comment
'score': comment.score,
})
if hasattr(comment, "body") and hasattr(comment, "score"):
top_comments.append(
{
"body": comment.body[:1000], # Include more of each comment
"score": comment.score,
}
)
candidate_posts.append({
"title": submission.title,
"author": str(submission.author) if submission.author else '[deleted]',
"score": submission.score,
"num_comments": submission.num_comments,
"subreddit": submission.subreddit.display_name,
"flair": submission.link_flair_text or "None",
"date": post_date.strftime("%Y-%m-%d %H:%M"),
"url": f"https://reddit.com{submission.permalink}",
"text": submission.selftext[:1500], # First 1500 chars for LLM
"full_length": len(submission.selftext),
"hours_ago": int((datetime.now() - post_date).total_seconds() / 3600),
"top_comments": top_comments,
})
candidate_posts.append(
{
"title": submission.title,
"author": str(submission.author) if submission.author else "[deleted]",
"score": submission.score,
"num_comments": submission.num_comments,
"subreddit": submission.subreddit.display_name,
"flair": submission.link_flair_text or "None",
"date": post_date.strftime("%Y-%m-%d %H:%M"),
"url": f"https://reddit.com{submission.permalink}",
"text": submission.selftext[:1500], # First 1500 chars for LLM
"full_length": len(submission.selftext),
"hours_ago": int((datetime.now() - post_date).total_seconds() / 3600),
"top_comments": top_comments,
}
)
if not candidate_posts:
return f"# Undiscovered DD\n\nNo posts found in last {lookback_hours}h."
print(f" Scanning {len(candidate_posts)} Reddit posts with LLM...")
logger.info(f"Scanning {len(candidate_posts)} Reddit posts with LLM...")
# LLM evaluation (parallel)
if llm_evaluator:
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List
from pydantic import BaseModel, Field
from typing import List, Optional
# Define structured output schema
class DDEvaluation(BaseModel):
score: int = Field(description="Quality score 0-100")
reason: str = Field(description="Brief reasoning for the score")
tickers: List[str] = Field(default_factory=list, description="List of stock ticker symbols mentioned (empty list if none)")
tickers: List[str] = Field(
default_factory=list,
description="List of stock ticker symbols mentioned (empty list if none)",
)
# Configure LLM for Reddit content (adjust safety settings if using Gemini)
try:
# Check if using Google Gemini and configure safety settings
if (
hasattr(llm_evaluator, "model_name")
and "gemini" in llm_evaluator.model_name.lower()
):
from langchain_google_genai import HarmBlockThreshold, HarmCategory
# More permissive safety settings for financial content analysis
llm_evaluator.safety_settings = {
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
}
logger.info(
"⚙️ Configured Gemini with permissive safety settings for financial content"
)
except Exception as e:
logger.warning(f"Could not configure safety settings: {e}")
# Create structured LLM
structured_llm = llm_evaluator.with_structured_output(DDEvaluation)
@ -394,10 +463,12 @@ def get_reddit_undiscovered_dd(
try:
# Build prompt with comments if available
comments_section = ""
if post.get('top_comments') and len(post['top_comments']) > 0:
if post.get("top_comments") and len(post["top_comments"]) > 0:
comments_section = "\n\nTop Community Comments (for validation):\n"
for i, comment in enumerate(post['top_comments'], 1):
comments_section += f"{i}. [{comment['score']} upvotes] {comment['body']}\n"
for i, comment in enumerate(post["top_comments"], 1):
comments_section += (
f"{i}. [{comment['score']} upvotes] {comment['body']}\n"
)
prompt = f"""Evaluate this Reddit post for investment Due Diligence quality.
@ -420,22 +491,34 @@ Extract all stock ticker symbols mentioned in the post or comments."""
result = structured_llm.invoke(prompt)
# Handle None result (Gemini blocked content despite safety settings)
if result is None:
logger.warning(f"⚠️ Content blocked for '{post['title'][:50]}...' - Skipping")
post["quality_score"] = 0
post["quality_reason"] = (
"Content blocked by LLM safety filter. "
"Consider using OpenAI/Anthropic for Reddit content."
)
post["tickers"] = []
return post
# Extract values from structured response
post['quality_score'] = result.score
post['quality_reason'] = result.reason
post['tickers'] = result.tickers # Now a list
post["quality_score"] = result.score
post["quality_reason"] = result.reason
post["tickers"] = result.tickers # Now a list
except Exception as e:
print(f" Error evaluating '{post['title'][:50]}': {str(e)}")
post['quality_score'] = 0
post['quality_reason'] = f'Error: {str(e)}'
post['tickers'] = []
logger.error(f"Error evaluating '{post['title'][:50]}': {str(e)}")
post["quality_score"] = 0
post["quality_reason"] = f"Error: {str(e)}"
post["tickers"] = []
return post
# Parallel evaluation with progress tracking
try:
from tqdm import tqdm
use_tqdm = True
except ImportError:
use_tqdm = False
@ -446,48 +529,49 @@ Extract all stock ticker symbols mentioned in the post or comments."""
if use_tqdm:
# With progress bar
evaluated = []
for future in tqdm(as_completed(futures), total=len(futures), desc=" Evaluating posts"):
for future in tqdm(
as_completed(futures), total=len(futures), desc=" Evaluating posts"
):
evaluated.append(future.result())
else:
# Without progress bar (fallback)
evaluated = [f.result() for f in as_completed(futures)]
# Filter quality threshold (55+ = decent DD)
quality_dd = [p for p in evaluated if p['quality_score'] >= 55]
quality_dd.sort(key=lambda x: x['quality_score'], reverse=True)
quality_dd = [p for p in evaluated if p["quality_score"] >= 55]
quality_dd.sort(key=lambda x: x["quality_score"], reverse=True)
# Debug: show score distribution
all_scores = [p['quality_score'] for p in evaluated if p['quality_score'] > 0]
all_scores = [p["quality_score"] for p in evaluated if p["quality_score"] > 0]
if all_scores:
avg_score = sum(all_scores) / len(all_scores)
max_score = max(all_scores)
print(f" Score distribution: avg={avg_score:.1f}, max={max_score}, quality_posts={len(quality_dd)}")
logger.info(
f"Score distribution: avg={avg_score:.1f}, max={max_score}, quality_posts={len(quality_dd)}"
)
top_dd = quality_dd[:top_n]
else:
# No LLM - sort by length + engagement
candidate_posts.sort(
key=lambda x: x['full_length'] + (x['score'] * 10),
reverse=True
)
candidate_posts.sort(key=lambda x: x["full_length"] + (x["score"] * 10), reverse=True)
top_dd = candidate_posts[:top_n]
if not top_dd:
return f"# Undiscovered DD\n\nNo high-quality DD found (scanned {len(candidate_posts)} posts)."
# Build report
report = f"# 💎 Undiscovered DD (LLM-Filtered Quality)\n\n"
report = "# 💎 Undiscovered DD (LLM-Filtered Quality)\n\n"
report += f"**Scanned:** {len(candidate_posts)} posts\n"
report += f"**High Quality:** {len(top_dd)} DD posts (score ≥60)\n\n"
for i, post in enumerate(top_dd, 1):
report += f"## {i}. {post['title']}\n\n"
if 'quality_score' in post:
if "quality_score" in post:
report += f"**Quality:** {post['quality_score']}/100 - {post['quality_reason']}\n"
if post.get('tickers') and len(post['tickers']) > 0:
tickers_str = ', '.join([f'${t}' for t in post['tickers']])
if post.get("tickers") and len(post["tickers"]) > 0:
tickers_str = ", ".join([f"${t}" for t in post["tickers"]])
report += f"**Tickers:** {tickers_str}\n"
report += f"**r/{post['subreddit']}** | {post['hours_ago']}h ago | "
@ -500,4 +584,5 @@ Extract all stock ticker symbols mentioned in the post or comments."""
except Exception as e:
import traceback
return f"# Undiscovered DD\n\nError: {str(e)}\n{traceback.format_exc()}"

View File

@ -1,11 +1,8 @@
import requests
import time
import json
from datetime import datetime, timedelta
from contextlib import contextmanager
from typing import Annotated
import os
import re
from datetime import datetime
from typing import Annotated
ticker_to_company = {
"AAPL": "Apple",
@ -50,9 +47,7 @@ ticker_to_company = {
def fetch_top_from_category(
category: Annotated[
str, "Category to fetch top post from. Collection of subreddits."
],
category: Annotated[str, "Category to fetch top post from. Collection of subreddits."],
date: Annotated[str, "Date to fetch top posts from."],
max_limit: Annotated[int, "Maximum number of posts to fetch."],
query: Annotated[str, "Optional query to search for in the subreddit."] = None,
@ -70,9 +65,7 @@ def fetch_top_from_category(
"REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts"
)
limit_per_subreddit = max_limit // len(
os.listdir(os.path.join(base_path, category))
)
limit_per_subreddit = max_limit // len(os.listdir(os.path.join(base_path, category)))
for data_file in os.listdir(os.path.join(base_path, category)):
# check if data_file is a .jsonl file
@ -90,9 +83,9 @@ def fetch_top_from_category(
parsed_line = json.loads(line)
# select only lines that are from the date
post_date = datetime.utcfromtimestamp(
parsed_line["created_utc"]
).strftime("%Y-%m-%d")
post_date = datetime.utcfromtimestamp(parsed_line["created_utc"]).strftime(
"%Y-%m-%d"
)
if post_date != date:
continue
@ -108,9 +101,9 @@ def fetch_top_from_category(
found = False
for term in search_terms:
if re.search(
term, parsed_line["title"], re.IGNORECASE
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
if re.search(term, parsed_line["title"], re.IGNORECASE) or re.search(
term, parsed_line["selftext"], re.IGNORECASE
):
found = True
break

View File

@ -0,0 +1,575 @@
"""
Semantic Discovery System
------------------------
Combines news scanning with ticker semantic matching to discover
investment opportunities based on breaking news before they show up
in social media or price action.
Flow:
1. Scan news from multiple sources
2. Generate embeddings for each news item
3. Match news against ticker descriptions semantically
4. Filter and rank opportunities
5. Return actionable ticker candidates
"""
import re
from datetime import datetime
from typing import Any, Dict, List
from dotenv import load_dotenv
from tradingagents.dataflows.news_semantic_scanner import NewsSemanticScanner
from tradingagents.dataflows.ticker_semantic_db import TickerSemanticDB
from tradingagents.utils.logger import get_logger
load_dotenv()
logger = get_logger(__name__)
class SemanticDiscovery:
"""Discovers investment opportunities through news-ticker semantic matching."""
def __init__(self, config: Dict[str, Any]):
"""
Initialize semantic discovery system.
Args:
config: Configuration dict with settings for both
ticker DB and news scanner
"""
self.config = config
# Initialize ticker database
self.ticker_db = TickerSemanticDB(config)
# Initialize news scanner
self.news_scanner = NewsSemanticScanner(config)
# Discovery settings
self.min_similarity_threshold = config.get("min_similarity_threshold", 0.3)
self.min_news_importance = config.get("min_news_importance", 5)
self.max_tickers_per_news = config.get("max_tickers_per_news", 5)
self.max_total_candidates = config.get("max_total_candidates", 20)
self.news_sentiment_filter = config.get("news_sentiment_filter", "positive")
self.group_by_news = config.get("group_by_news", False)
def _extract_tickers(self, mentions: List[str]) -> List[str]:
from tradingagents.dataflows.discovery.utils import is_valid_ticker
tickers = set()
for mention in mentions or []:
for match in re.findall(r"\b[A-Z]{1,5}\b", str(mention)):
# APPLY VALIDATION IMMEDIATELY
if is_valid_ticker(match):
tickers.add(match)
return sorted(tickers)
def get_directly_mentioned_tickers(self) -> List[Dict[str, Any]]:
"""
Get tickers that are directly mentioned in news (highest signal).
This extracts tickers from the 'companies_mentioned' field of news items,
which represents explicit company references rather than semantic matches.
Returns:
List of ticker info dicts with news context
"""
# Scan news if not already done
news_items = self.news_scanner.scan_news()
# Filter by importance
important_news = [
item for item in news_items if item.get("importance", 0) >= self.min_news_importance
]
# Extract directly mentioned tickers
mentioned_tickers = {} # ticker -> list of news items
# Common words to exclude (not tickers)
exclude_words = {
"A",
"I",
"AN",
"AI",
"CEO",
"CFO",
"CTO",
"FDA",
"SEC",
"IPO",
"ETF",
"GDP",
"CPI",
"FED",
"NYSE",
"Q1",
"Q2",
"Q3",
"Q4",
"US",
"UK",
"EU",
"AT",
"BE",
"BY",
"DO",
"GO",
"IF",
"IN",
"IS",
"IT",
"ME",
"MY",
"NO",
"OF",
"ON",
"OR",
"SO",
"TO",
"UP",
"WE",
"ALL",
"ARE",
"FOR",
"HAS",
"NEW",
"NOW",
"OLD",
"OUR",
"OUT",
"THE",
"TOP",
"TWO",
"WAS",
"WHO",
"WHY",
"WIN",
"BUY",
"COO",
"EPS",
"P/E",
"ROE",
"ROI",
# Common business abbreviations that aren't tickers
"INC",
"CO",
"LLC",
"LTD",
"CORP",
"PLC",
"AG",
"SA",
"SE",
"NV",
"GAS",
"OIL",
"MGE",
"LG", # Common words/abbreviations from logs
# Single/two-letter words often false positives
"AM",
"AS",
}
for news_item in important_news:
companies = news_item.get("companies_mentioned", [])
extracted = self._extract_tickers(companies)
for ticker in extracted:
if ticker in exclude_words:
continue
if len(ticker) < 2:
continue
if ticker not in mentioned_tickers:
mentioned_tickers[ticker] = []
mentioned_tickers[ticker].append(
{
"news_title": news_item.get("title", ""),
"news_summary": news_item.get("summary", ""),
"sentiment": news_item.get("sentiment", "neutral"),
"importance": news_item.get("importance", 5),
"themes": news_item.get("themes", []),
"source": news_item.get("source", "unknown"),
}
)
# Convert to list format, prioritizing by news importance
result = []
for ticker, news_list in mentioned_tickers.items():
# Use the most important news item as primary
best_news = max(news_list, key=lambda x: x["importance"])
result.append(
{
"ticker": ticker,
"news_title": best_news["news_title"],
"news_summary": best_news["news_summary"],
"sentiment": best_news["sentiment"],
"importance": best_news["importance"],
"themes": best_news["themes"],
"source": best_news["source"],
"mention_count": len(news_list),
}
)
# Sort by importance and mention count
result.sort(key=lambda x: (x["importance"], x["mention_count"]), reverse=True)
logger.info(f"📌 Found {len(result)} directly mentioned tickers in news")
return result[: self.max_total_candidates]
def discover(self) -> List[Dict[str, Any]]:
"""
Run semantic discovery to find ticker opportunities.
Returns:
List of ticker candidates with news context and relevance scores
"""
logger.info("=" * 60)
logger.info("🚀 SEMANTIC DISCOVERY")
logger.info("=" * 60)
# Step 1: Scan news
news_items = self.news_scanner.scan_news()
if not news_items:
logger.info("No news items found.")
return []
# Filter news by importance threshold
important_news = [
item for item in news_items if item.get("importance", 0) >= self.min_news_importance
]
logger.info(f"📰 Processing {len(important_news)} high-importance news items...")
logger.info(f"(Filtered from {len(news_items)} total items)")
if self.news_sentiment_filter:
before_count = len(important_news)
important_news = [
item
for item in important_news
if item.get("sentiment", "").lower() == self.news_sentiment_filter
]
logger.info(
f"Sentiment filter: {self.news_sentiment_filter} "
f"({len(important_news)}/{before_count} kept)"
)
# Step 2: For each news item, find matching tickers
all_candidates = []
news_ticker_map = {} # Track which news items match which tickers
news_groups = {} # Track which tickers match each news item
for i, news_item in enumerate(important_news, 1):
title = news_item.get("title", "Untitled")
logger.info(f"{i}. {title}")
logger.debug(f"Importance: {news_item.get('importance', 0)}/10")
mentioned_tickers = self._extract_tickers(news_item.get("companies_mentioned", []))
# Generate search query from news
search_text = self.news_scanner.generate_news_summary(news_item)
# Search ticker database
matches = self.ticker_db.search_by_text(
query_text=search_text, top_k=self.max_tickers_per_news
)
# Filter by similarity threshold
relevant_matches = [
match
for match in matches
if match["similarity_score"] >= self.min_similarity_threshold
]
if relevant_matches:
logger.info(f"Found {len(relevant_matches)} relevant tickers:")
news_key = (
f"{title}|{news_item.get('source', '')}|"
f"{news_item.get('published_at') or news_item.get('timestamp', '')}"
)
if news_key not in news_groups:
news_groups[news_key] = {
"news_title": title,
"news_summary": news_item.get("summary", ""),
"news_importance": news_item.get("importance", 0),
"news_themes": news_item.get("themes", []),
"news_sentiment": news_item.get("sentiment"),
"news_source": news_item.get("source"),
"published_at": news_item.get("published_at"),
"timestamp": news_item.get("timestamp"),
"mentioned_tickers": mentioned_tickers,
"tickers": [],
}
for match in relevant_matches:
symbol = match["symbol"]
score = match["similarity_score"]
logger.debug(f"{symbol} (similarity: {score:.3f})")
# Track news-ticker mapping
if symbol not in news_ticker_map:
news_ticker_map[symbol] = []
news_ticker_map[symbol].append(
{
"news_title": title,
"news_summary": news_item.get("summary", ""),
"news_importance": news_item.get("importance", 0),
"news_themes": news_item.get("themes", []),
"news_sentiment": news_item.get("sentiment"),
"news_tickers_mentioned": mentioned_tickers,
"similarity_score": score,
"timestamp": news_item.get("timestamp"),
"source": news_item.get("source"),
}
)
if symbol not in {t["ticker"] for t in news_groups[news_key]["tickers"]}:
news_groups[news_key]["tickers"].append(
{
"ticker": symbol,
"similarity_score": score,
"ticker_name": match["metadata"]["name"],
"ticker_sector": match["metadata"]["sector"],
"ticker_industry": match["metadata"]["industry"],
}
)
# Add to candidates
all_candidates.append(
{
"ticker": symbol,
"ticker_name": match["metadata"]["name"],
"ticker_sector": match["metadata"]["sector"],
"ticker_industry": match["metadata"]["industry"],
"news_title": title,
"news_summary": news_item.get("summary", ""),
"news_importance": news_item.get("importance", 0),
"news_themes": news_item.get("themes", []),
"news_sentiment": news_item.get("sentiment"),
"news_tickers_mentioned": mentioned_tickers,
"similarity_score": score,
"news_source": news_item.get("source"),
"discovery_timestamp": datetime.now().isoformat(),
}
)
else:
logger.debug("No relevant tickers found (below threshold)")
if self.group_by_news:
grouped_candidates = []
for news_entry in news_groups.values():
tickers = news_entry["tickers"]
if not tickers:
continue
avg_similarity = sum(t["similarity_score"] for t in tickers) / len(tickers)
aggregate_score = (
(news_entry["news_importance"] * 1.5)
+ (avg_similarity * 3.0)
+ (len(tickers) * 0.5)
)
grouped_candidates.append(
{
**news_entry,
"num_tickers": len(tickers),
"avg_similarity": round(avg_similarity, 3),
"aggregate_score": round(aggregate_score, 2),
}
)
grouped_candidates.sort(key=lambda x: x["aggregate_score"], reverse=True)
grouped_candidates = grouped_candidates[: self.max_total_candidates]
logger.info("📊 Aggregating and ranking news items...")
logger.info(f"Identified {len(grouped_candidates)} news items with tickers")
return grouped_candidates
# Step 3: Aggregate and rank candidates
logger.info("📊 Aggregating and ranking candidates...")
# Group by ticker and calculate aggregate scores
ticker_aggregates = {}
for ticker, news_matches in news_ticker_map.items():
# Calculate aggregate score
# Factors: number of news matches, importance, similarity
num_matches = len(news_matches)
avg_importance = sum(n["news_importance"] for n in news_matches) / num_matches
avg_similarity = sum(n["similarity_score"] for n in news_matches) / num_matches
max_importance = max(n["news_importance"] for n in news_matches)
# Weighted score
aggregate_score = (
(num_matches * 2.0) # More news = higher score
+ (avg_importance * 1.5) # Average importance
+ (avg_similarity * 3.0) # Similarity strength
+ (max_importance * 1.0) # Bonus for having one very important match
)
ticker_aggregates[ticker] = {
"ticker": ticker,
"num_news_matches": num_matches,
"avg_importance": round(avg_importance, 2),
"avg_similarity": round(avg_similarity, 3),
"max_importance": max_importance,
"aggregate_score": round(aggregate_score, 2),
"news_matches": news_matches,
}
# Sort by aggregate score
ranked_candidates = sorted(
ticker_aggregates.values(), key=lambda x: x["aggregate_score"], reverse=True
)
# Limit to max candidates
ranked_candidates = ranked_candidates[: self.max_total_candidates]
logger.info(f"Identified {len(ranked_candidates)} unique ticker candidates")
return ranked_candidates
def format_discovery_report(self, candidates: List[Dict[str, Any]]) -> str:
"""
Format discovery results as a readable report.
Args:
candidates: List of ranked candidates
Returns:
Formatted text report
"""
if not candidates:
return "No opportunities discovered."
if "tickers" in candidates[0]:
report = "\n" + "=" * 60
report += "\n📰 NEWS-DRIVEN RESULTS"
report += "\n" + "=" * 60 + "\n"
for i, news in enumerate(candidates, 1):
title = news["news_title"]
score = news["aggregate_score"]
num_tickers = news["num_tickers"]
importance = news["news_importance"]
report += f"\n{i}. {title}"
report += f"\n Score: {score:.2f} | Tickers: {num_tickers} | Importance: {importance}/10"
report += f"\n Source: {news.get('news_source', 'unknown')}"
if news.get("news_themes"):
report += f"\n Themes: {', '.join(news['news_themes'])}"
if news.get("news_summary"):
report += f"\n Summary: {news['news_summary']}"
if news.get("mentioned_tickers"):
report += f"\n Mentioned Tickers: {', '.join(news['mentioned_tickers'])}"
tickers = sorted(news["tickers"], key=lambda x: x["similarity_score"], reverse=True)
report += "\n Related Tickers:"
for j, ticker_info in enumerate(tickers[:5], 1):
report += (
f"\n {j}. {ticker_info['ticker']} "
f"(similarity: {ticker_info['similarity_score']:.3f})"
)
if len(tickers) > 5:
report += f"\n ... and {len(tickers) - 5} more"
report += "\n"
return report
report = "\n" + "=" * 60
report += "\n🎯 SEMANTIC DISCOVERY RESULTS"
report += "\n" + "=" * 60 + "\n"
for i, candidate in enumerate(candidates, 1):
ticker = candidate["ticker"]
score = candidate["aggregate_score"]
num_matches = candidate["num_news_matches"]
avg_importance = candidate["avg_importance"]
report += f"\n{i}. {ticker}"
report += f"\n Score: {score:.2f} | Matches: {num_matches} | Avg Importance: {avg_importance}/10"
report += "\n Related News:"
for j, news in enumerate(candidate["news_matches"][:3], 1): # Show top 3 news
report += f"\n {j}. {news['news_title']}"
report += f"\n Similarity: {news['similarity_score']:.3f} | Importance: {news['news_importance']}/10"
if news.get("news_themes"):
report += f"\n Themes: {', '.join(news['news_themes'])}"
if len(candidate["news_matches"]) > 3:
report += f"\n ... and {len(candidate['news_matches']) - 3} more"
report += "\n"
return report
def main():
"""CLI for running semantic discovery."""
import argparse
import json
parser = argparse.ArgumentParser(description="Run semantic discovery")
parser.add_argument(
"--news-sources",
nargs="+",
default=["openai"],
choices=["openai", "google_news", "sec_filings", "alpha_vantage", "gemini_search"],
help="News sources to use",
)
parser.add_argument(
"--min-importance", type=int, default=5, help="Minimum news importance (1-10)"
)
parser.add_argument(
"--min-similarity", type=float, default=0.2, help="Minimum similarity threshold (0-1)"
)
parser.add_argument(
"--max-candidates", type=int, default=15, help="Maximum ticker candidates to return"
)
parser.add_argument(
"--lookback-hours",
type=int,
default=24,
help="How far back to look for news (in hours). Examples: 1, 6, 24, 168",
)
parser.add_argument("--output", type=str, help="Output file for results JSON")
parser.add_argument(
"--group-by-news", action="store_true", help="Group results by news item instead of ticker"
)
args = parser.parse_args()
# Load project config
from tradingagents.default_config import DEFAULT_CONFIG
config = {
"project_dir": DEFAULT_CONFIG["project_dir"],
"use_openai_embeddings": True,
"news_sources": args.news_sources,
"news_lookback_hours": args.lookback_hours,
"min_news_importance": args.min_importance,
"min_similarity_threshold": args.min_similarity,
"max_tickers_per_news": 5,
"max_total_candidates": args.max_candidates,
"news_sentiment_filter": "positive",
"group_by_news": args.group_by_news,
}
# Run discovery
discovery = SemanticDiscovery(config)
candidates = discovery.discover()
# Display report
report = discovery.format_discovery_report(candidates)
logger.info(report)
# Save to file if specified
if args.output:
with open(args.output, "w") as f:
json.dump(candidates, f, indent=2)
logger.info(f"✅ Saved {len(candidates)} candidates to {args.output}")
if __name__ == "__main__":
main()

View File

@ -1,9 +1,10 @@
import pandas as pd
import yfinance as yf
from stockstats import wrap
from typing import Annotated
import os
from .config import get_config, DATA_DIR
from typing import Annotated
import pandas as pd
from stockstats import wrap
from .config import DATA_DIR, get_config
class StockstatsUtils:
@ -13,9 +14,7 @@ class StockstatsUtils:
indicator: Annotated[
str, "quantitative indicators based off of the stock data for the company"
],
curr_date: Annotated[
str, "curr date for retrieving stock price data, YYYY-mm-dd"
],
curr_date: Annotated[str, "curr date for retrieving stock price data, YYYY-mm-dd"],
):
# Get config and set up data directory path
config = get_config()
@ -57,7 +56,9 @@ class StockstatsUtils:
data = pd.read_csv(data_file)
data["Date"] = pd.to_datetime(data["Date"])
else:
data = yf.download(
from .y_finance import download_history
data = download_history(
symbol,
start=start_date,
end=end_date,

View File

@ -0,0 +1,476 @@
from typing import List
import pandas as pd
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
class TechnicalAnalyst:
"""
Performs comprehensive technical analysis on stock data.
"""
def __init__(self, df: pd.DataFrame, current_price: float):
"""
Initialize with stock dataframe and current price.
Args:
df: DataFrame with stock data (must contain 'close', 'high', 'low', 'volume')
current_price: The latest price of the stock
"""
self.df = df
self.current_price = current_price
self.analysis_report = []
def add_section(self, title: str, content: List[str]):
"""Add a formatted section to the report."""
self.analysis_report.append(f"## {title}")
self.analysis_report.extend(content)
self.analysis_report.append("")
def analyze_price_action(self):
"""Analyze recent price movements."""
latest = self.df.iloc[-1]
prev = self.df.iloc[-2] if len(self.df) > 1 else latest
prev_5 = self.df.iloc[-5] if len(self.df) > 5 else latest
daily_change = ((self.current_price - float(prev["close"])) / float(prev["close"])) * 100
weekly_change = (
(self.current_price - float(prev_5["close"])) / float(prev_5["close"])
) * 100
self.add_section(
"Price Action",
[
f"- **Daily Change:** {daily_change:+.2f}%",
f"- **5-Day Change:** {weekly_change:+.2f}%",
],
)
def analyze_rsi(self):
"""Analyze Relative Strength Index."""
try:
self.df["rsi"] # Trigger calculation
rsi = float(self.df.iloc[-1]["rsi"])
rsi_prev = float(self.df.iloc[-5]["rsi"]) if len(self.df) > 5 else rsi
if rsi > 70:
rsi_signal = "OVERBOUGHT ⚠️"
elif rsi < 30:
rsi_signal = "OVERSOLD ⚡"
elif rsi > 50:
rsi_signal = "Bullish"
else:
rsi_signal = "Bearish"
rsi_trend = "" if rsi > rsi_prev else ""
self.add_section(
"RSI (14)", [f"- **Value:** {rsi:.1f} {rsi_trend}", f"- **Signal:** {rsi_signal}"]
)
except Exception as e:
logger.warning(f"RSI analysis failed: {e}")
def analyze_macd(self):
"""Analyze MACD."""
try:
self.df["macd"]
self.df["macds"]
self.df["macdh"]
macd = float(self.df.iloc[-1]["macd"])
signal = float(self.df.iloc[-1]["macds"])
histogram = float(self.df.iloc[-1]["macdh"])
hist_prev = float(self.df.iloc[-2]["macdh"]) if len(self.df) > 1 else histogram
if macd > signal and histogram > 0:
macd_signal = "BULLISH CROSSOVER ⚡" if histogram > hist_prev else "Bullish"
elif macd < signal and histogram < 0:
macd_signal = "BEARISH CROSSOVER ⚠️" if histogram < hist_prev else "Bearish"
else:
macd_signal = "Neutral"
momentum = "Strengthening ↑" if abs(histogram) > abs(hist_prev) else "Weakening ↓"
self.add_section(
"MACD",
[
f"- **MACD Line:** {macd:.3f}",
f"- **Signal Line:** {signal:.3f}",
f"- **Histogram:** {histogram:.3f} ({momentum})",
f"- **Signal:** {macd_signal}",
],
)
except Exception as e:
logger.warning(f"MACD analysis failed: {e}")
def analyze_moving_averages(self):
"""Analyze Moving Averages."""
try:
self.df["close_50_sma"]
self.df["close_200_sma"]
sma_50 = float(self.df.iloc[-1]["close_50_sma"])
sma_200 = float(self.df.iloc[-1]["close_200_sma"])
# Trend determination
if self.current_price > sma_50 > sma_200:
trend = "STRONG UPTREND ⚡"
elif self.current_price > sma_50:
trend = "Uptrend"
elif self.current_price < sma_50 < sma_200:
trend = "STRONG DOWNTREND ⚠️"
elif self.current_price < sma_50:
trend = "Downtrend"
else:
trend = "Sideways"
# Golden/Death cross detection
sma_50_prev = float(self.df.iloc[-5]["close_50_sma"]) if len(self.df) > 5 else sma_50
sma_200_prev = float(self.df.iloc[-5]["close_200_sma"]) if len(self.df) > 5 else sma_200
cross = ""
if sma_50 > sma_200 and sma_50_prev < sma_200_prev:
cross = " (GOLDEN CROSS ⚡)"
elif sma_50 < sma_200 and sma_50_prev > sma_200_prev:
cross = " (DEATH CROSS ⚠️)"
self.add_section(
"Moving Averages",
[
f"- **50 SMA:** ${sma_50:.2f} ({'+' if self.current_price > sma_50 else ''}{((self.current_price - sma_50) / sma_50 * 100):.1f}% from price)",
f"- **200 SMA:** ${sma_200:.2f} ({'+' if self.current_price > sma_200 else ''}{((self.current_price - sma_200) / sma_200 * 100):.1f}% from price)",
f"- **Trend:** {trend}{cross}",
],
)
except Exception as e:
logger.warning(f"Moving averages analysis failed: {e}")
def analyze_bollinger_bands(self):
"""Analyze Bollinger Bands."""
try:
self.df["boll"]
self.df["boll_ub"]
self.df["boll_lb"]
middle = float(self.df.iloc[-1]["boll"])
upper = float(self.df.iloc[-1]["boll_ub"])
lower = float(self.df.iloc[-1]["boll_lb"])
band_position = (
(self.current_price - lower) / (upper - lower) if upper != lower else 0.5
)
if band_position > 0.95:
bb_signal = "AT UPPER BAND - Potential reversal ⚠️"
elif band_position < 0.05:
bb_signal = "AT LOWER BAND - Potential bounce ⚡"
elif band_position > 0.8:
bb_signal = "Near upper band"
elif band_position < 0.2:
bb_signal = "Near lower band"
else:
bb_signal = "Within bands"
bandwidth = ((upper - lower) / middle) * 100
self.add_section(
"Bollinger Bands (20,2)",
[
f"- **Upper:** ${upper:.2f}",
f"- **Middle:** ${middle:.2f}",
f"- **Lower:** ${lower:.2f}",
f"- **Band Position:** {band_position:.0%}",
f"- **Bandwidth:** {bandwidth:.1f}% (volatility indicator)",
f"- **Signal:** {bb_signal}",
],
)
except Exception as e:
logger.warning(f"Bollinger bands analysis failed: {e}")
def analyze_atr(self):
"""Analyze ATR (Volatility)."""
try:
self.df["atr"]
atr = float(self.df.iloc[-1]["atr"])
atr_pct = (atr / self.current_price) * 100
if atr_pct > 5:
vol_level = "HIGH VOLATILITY ⚠️"
elif atr_pct > 2:
vol_level = "Moderate volatility"
else:
vol_level = "Low volatility"
self.add_section(
"ATR (Volatility)",
[
f"- **ATR:** ${atr:.2f} ({atr_pct:.1f}% of price)",
f"- **Level:** {vol_level}",
f"- **Suggested Stop-Loss:** ${self.current_price - (1.5 * atr):.2f} (1.5x ATR)",
],
)
except Exception as e:
logger.warning(f"ATR analysis failed: {e}")
def analyze_stochastic(self):
"""Analyze Stochastic Oscillator."""
try:
self.df["kdjk"]
self.df["kdjd"]
stoch_k = float(self.df.iloc[-1]["kdjk"])
stoch_d = float(self.df.iloc[-1]["kdjd"])
stoch_k_prev = float(self.df.iloc[-2]["kdjk"]) if len(self.df) > 1 else stoch_k
if stoch_k > 80 and stoch_d > 80:
stoch_signal = "OVERBOUGHT ⚠️"
elif stoch_k < 20 and stoch_d < 20:
stoch_signal = "OVERSOLD ⚡"
elif stoch_k > stoch_d and stoch_k_prev < stoch_d:
stoch_signal = "Bullish crossover ⚡"
elif stoch_k < stoch_d and stoch_k_prev > stoch_d:
stoch_signal = "Bearish crossover ⚠️"
elif stoch_k > 50:
stoch_signal = "Bullish"
else:
stoch_signal = "Bearish"
self.add_section(
"Stochastic (14,3,3)",
[
f"- **%K:** {stoch_k:.1f}",
f"- **%D:** {stoch_d:.1f}",
f"- **Signal:** {stoch_signal}",
],
)
except Exception as e:
logger.warning(f"Stochastic analysis failed: {e}")
def analyze_adx(self):
"""Analyze ADX (Trend Strength)."""
try:
self.df["adx"]
adx = float(self.df.iloc[-1]["adx"])
adx_prev = float(self.df.iloc[-5]["adx"]) if len(self.df) > 5 else adx
if adx > 50:
trend_strength = "VERY STRONG TREND ⚡"
elif adx > 25:
trend_strength = "Strong trend"
elif adx > 20:
trend_strength = "Trending"
else:
trend_strength = "WEAK/NO TREND (range-bound) ⚠️"
adx_direction = "Strengthening ↑" if adx > adx_prev else "Weakening ↓"
self.add_section(
"ADX (Trend Strength)",
[
f"- **ADX:** {adx:.1f} ({adx_direction})",
f"- **Interpretation:** {trend_strength}",
],
)
except Exception as e:
logger.warning(f"ADX analysis failed: {e}")
def analyze_ema(self):
"""Analyze 20 EMA."""
try:
self.df["close_20_ema"]
ema_20 = float(self.df.iloc[-1]["close_20_ema"])
pct_from_ema = ((self.current_price - ema_20) / ema_20) * 100
if self.current_price > ema_20:
ema_signal = "Price ABOVE 20 EMA (short-term bullish)"
else:
ema_signal = "Price BELOW 20 EMA (short-term bearish)"
self.add_section(
"20 EMA",
[
f"- **Value:** ${ema_20:.2f} ({pct_from_ema:+.1f}% from price)",
f"- **Signal:** {ema_signal}",
],
)
except Exception as e:
logger.warning(f"EMA analysis failed: {e}")
def analyze_obv(self):
"""Analyze On-Balance Volume."""
try:
# Check if we have enough data
if len(self.df) < 2:
logger.warning("Insufficient data for OBV analysis (need at least 2 days)")
return
obv = 0
obv_values = [0]
for i in range(1, len(self.df)):
if float(self.df.iloc[i]["close"]) > float(self.df.iloc[i - 1]["close"]):
obv += float(self.df.iloc[i]["volume"])
elif float(self.df.iloc[i]["close"]) < float(self.df.iloc[i - 1]["close"]):
obv -= float(self.df.iloc[i]["volume"])
obv_values.append(obv)
current_obv = obv_values[-1]
obv_5_ago = obv_values[-5] if len(obv_values) > 5 else obv_values[0]
# Check if we have enough data for price comparison
if len(self.df) >= 5:
price_5_ago = float(self.df.iloc[-5]["close"])
else:
price_5_ago = float(self.df.iloc[0]["close"])
if current_obv > obv_5_ago and self.current_price > price_5_ago:
obv_signal = "Confirmed uptrend (price & volume rising)"
elif current_obv < obv_5_ago and self.current_price < price_5_ago:
obv_signal = "Confirmed downtrend (price & volume falling)"
elif current_obv > obv_5_ago and self.current_price < price_5_ago:
obv_signal = "BULLISH DIVERGENCE ⚡ (accumulation)"
elif current_obv < obv_5_ago and self.current_price > price_5_ago:
obv_signal = "BEARISH DIVERGENCE ⚠️ (distribution)"
else:
obv_signal = "Neutral"
obv_formatted = (
f"{current_obv/1e6:.1f}M" if abs(current_obv) > 1e6 else f"{current_obv/1e3:.1f}K"
)
self.add_section(
"OBV (On-Balance Volume)",
[
f"- **Value:** {obv_formatted}",
f"- **5-Day Trend:** {'Rising ↑' if current_obv > obv_5_ago else 'Falling ↓'}",
f"- **Signal:** {obv_signal}",
],
)
except Exception as e:
logger.warning(f"OBV analysis failed: {e}")
def analyze_vwap(self):
"""Analyze VWAP."""
try:
# Calculate VWAP for today (simplified - using recent data)
# Calculate cumulative VWAP (last 20 periods approximation)
recent_df = self.df.tail(20)
tp_vol = ((recent_df["high"] + recent_df["low"] + recent_df["close"]) / 3) * recent_df[
"volume"
]
vwap = float(tp_vol.sum() / recent_df["volume"].sum())
pct_from_vwap = ((self.current_price - vwap) / vwap) * 100
if self.current_price > vwap:
vwap_signal = "Price ABOVE VWAP (institutional buying)"
else:
vwap_signal = "Price BELOW VWAP (institutional selling)"
self.add_section(
"VWAP (20-period)",
[
f"- **VWAP:** ${vwap:.2f}",
f"- **Current vs VWAP:** {pct_from_vwap:+.1f}%",
f"- **Signal:** {vwap_signal}",
],
)
except Exception as e:
logger.warning(f"VWAP analysis failed: {e}")
def analyze_fibonacci(self):
"""Analyze Fibonacci Retracement."""
try:
# Get high and low from last 50 periods
recent_high = float(self.df.tail(50)["high"].max())
recent_low = float(self.df.tail(50)["low"].min())
diff = recent_high - recent_low
fib_levels = {
"0.0% (High)": recent_high,
"23.6%": recent_high - (diff * 0.236),
"38.2%": recent_high - (diff * 0.382),
"50.0%": recent_high - (diff * 0.5),
"61.8%": recent_high - (diff * 0.618),
"78.6%": recent_high - (diff * 0.786),
"100% (Low)": recent_low,
}
# Find nearest support and resistance
support = None
resistance = None
for level_name, level_price in fib_levels.items():
if level_price < self.current_price and (
support is None or level_price > support[1]
):
support = (level_name, level_price)
if level_price > self.current_price and (
resistance is None or level_price < resistance[1]
):
resistance = (level_name, level_price)
content = [
f"- **Recent High:** ${recent_high:.2f}",
f"- **Recent Low:** ${recent_low:.2f}",
]
if resistance:
content.append(f"- **Next Resistance:** ${resistance[1]:.2f} ({resistance[0]})")
if support:
content.append(f"- **Next Support:** ${support[1]:.2f} ({support[0]})")
self.add_section("Fibonacci Levels (50-period)", content)
except Exception as e:
logger.warning(f"Fibonacci analysis failed: {e}")
def generate_summary(self):
"""Generate final summary section."""
signals = []
try:
rsi = float(self.df.iloc[-1]["rsi"])
if rsi > 70:
signals.append("RSI overbought")
elif rsi < 30:
signals.append("RSI oversold")
except Exception:
pass
try:
if self.current_price > float(self.df.iloc[-1]["close_50_sma"]):
signals.append("Above 50 SMA")
else:
signals.append("Below 50 SMA")
except Exception:
pass
content = []
if signals:
content.append(f"- **Key Signals:** {', '.join(signals)}")
self.add_section("Summary", content)
def generate_report(self, symbol: str, date: str) -> str:
"""Run all analyses and generate the markdown report."""
self.df = self.df.copy() # Avoid modifying original
# Header
self.analysis_report = [
f"# Technical Analysis for {symbol.upper()}",
f"**Date:** {date}",
f"**Current Price:** ${self.current_price:.2f}",
"",
]
# Run analyses
self.analyze_price_action()
self.analyze_rsi()
self.analyze_macd()
self.analyze_moving_averages()
self.analyze_bollinger_bands()
self.analyze_atr()
self.analyze_stochastic()
self.analyze_adx()
self.analyze_ema()
self.analyze_obv()
self.analyze_vwap()
self.analyze_fibonacci()
self.generate_summary()
return "\n".join(self.analysis_report)

View File

@ -0,0 +1,395 @@
"""
Ticker Semantic Database
------------------------
Creates and maintains a database of ticker descriptions with embeddings
for semantic matching against news events.
This enables news-driven discovery by finding tickers semantically related
to breaking news, rather than waiting for social media buzz or price action.
"""
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
import chromadb
from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm
from tradingagents.dataflows.y_finance import get_ticker_info
from tradingagents.utils.logger import get_logger
# Load environment variables
load_dotenv()
logger = get_logger(__name__)
class TickerSemanticDB:
"""Manages ticker descriptions and embeddings for semantic search."""
def __init__(self, config: Dict[str, Any]):
"""
Initialize the ticker semantic database.
Args:
config: Configuration dict with:
- project_dir: Base directory for storage
- use_openai_embeddings: If True, use OpenAI; else use local HF model
- embedding_model: Model name (default: text-embedding-3-small)
"""
self.config = config
self.use_openai = config.get("use_openai_embeddings", True)
# Setup embedding backend
if self.use_openai:
self.embedding_model = config.get("embedding_model", "text-embedding-3-small")
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
raise ValueError("OPENAI_API_KEY not found in environment")
self.openai_client = OpenAI(api_key=openai_api_key)
self.embedding_dim = 1536 # OpenAI text-embedding-3-small dimension
else:
# TODO: Add local HuggingFace model support
# Use sentence-transformers with a good MTEB-ranked model
from sentence_transformers import SentenceTransformer
self.embedding_model = config.get("embedding_model", "BAAI/bge-small-en-v1.5")
self.local_model = SentenceTransformer(self.embedding_model)
self.embedding_dim = self.local_model.get_sentence_embedding_dimension()
# Setup ChromaDB for persistent storage
project_dir = config.get("project_dir", ".")
embedding_model_safe = self.embedding_model.replace("/", "_").replace(" ", "_")
db_dir = os.path.join(project_dir, "ticker_semantic_db", embedding_model_safe)
os.makedirs(db_dir, exist_ok=True)
self.chroma_client = chromadb.PersistentClient(path=db_dir)
# Get or create collection
collection_name = "ticker_descriptions"
try:
self.collection = self.chroma_client.get_collection(name=collection_name)
logger.info(f"Loaded existing ticker database: {self.collection.count()} tickers")
except Exception:
self.collection = self.chroma_client.create_collection(
name=collection_name,
metadata={"description": "Ticker descriptions with metadata for semantic search"},
)
logger.info("Created new ticker database collection")
def get_embedding(self, text: str) -> List[float]:
"""Generate embedding for text using configured backend."""
if self.use_openai:
response = self.openai_client.embeddings.create(model=self.embedding_model, input=text)
return response.data[0].embedding
else:
# Local HuggingFace model
embedding = self.local_model.encode(text, convert_to_numpy=True)
return embedding.tolist()
def fetch_ticker_info(self, symbol: str) -> Optional[Dict[str, Any]]:
"""
Fetch ticker information from Yahoo Finance.
Args:
symbol: Stock ticker symbol
Returns:
Dict with ticker metadata or None if fetch fails
"""
try:
info = get_ticker_info(symbol)
# Extract relevant fields
description = info.get("longBusinessSummary", "")
if not description:
# Fallback to shorter description if available
description = info.get("description", f"{symbol} - No description available")
# Build metadata dict
ticker_data = {
"symbol": symbol.upper(),
"name": info.get("longName", info.get("shortName", symbol)),
"description": description,
"industry": info.get("industry", "Unknown"),
"sector": info.get("sector", "Unknown"),
"market_cap": info.get("marketCap", 0),
"revenue": info.get("totalRevenue", 0),
"country": info.get("country", "US"),
"website": info.get("website", ""),
"employees": info.get("fullTimeEmployees", 0),
"last_updated": datetime.now().isoformat(),
}
return ticker_data
except Exception as e:
logger.warning(f"Error fetching {symbol}: {e}")
return None
def add_ticker(self, symbol: str, force_refresh: bool = False) -> bool:
"""
Add a single ticker to the database.
Args:
symbol: Stock ticker symbol
force_refresh: If True, refresh even if ticker exists
Returns:
True if added successfully, False otherwise
"""
# Check if already exists
if not force_refresh:
try:
existing = self.collection.get(ids=[symbol.upper()])
if existing and existing["ids"]:
return True # Already exists
except Exception:
pass
# Fetch ticker info
ticker_data = self.fetch_ticker_info(symbol)
if not ticker_data:
return False
# Generate embedding from description
try:
embedding = self.get_embedding(ticker_data["description"])
except Exception as e:
logger.error(f"Error generating embedding for {symbol}: {e}")
return False
# Store in ChromaDB
try:
# Store description as document, metadata as metadata, embedding as embedding
self.collection.upsert(
ids=[symbol.upper()],
documents=[ticker_data["description"]],
embeddings=[embedding],
metadatas=[
{
"symbol": ticker_data["symbol"],
"name": ticker_data["name"],
"industry": ticker_data["industry"],
"sector": ticker_data["sector"],
"market_cap": ticker_data["market_cap"],
"revenue": ticker_data["revenue"],
"country": ticker_data["country"],
"website": ticker_data["website"],
"employees": ticker_data["employees"],
"last_updated": ticker_data["last_updated"],
}
],
)
return True
except Exception as e:
logger.error(f"Error storing {symbol}: {e}")
return False
def build_database(
self,
ticker_file: str,
max_tickers: Optional[int] = None,
skip_existing: bool = True,
batch_size: int = 100,
):
"""
Build the ticker database from a file.
Args:
ticker_file: Path to file with ticker symbols (one per line)
max_tickers: Maximum number of tickers to process (None = all)
skip_existing: If True, skip tickers already in DB
batch_size: Number of tickers to process before showing progress
"""
# Read ticker file
with open(ticker_file, "r") as f:
tickers = [line.strip().upper() for line in f if line.strip()]
if max_tickers:
tickers = tickers[:max_tickers]
logger.info("Building ticker semantic database...")
logger.info(f"Source: {ticker_file}")
logger.info(f"Total tickers: {len(tickers)}")
logger.info(f"Embedding model: {self.embedding_model}")
# Get existing tickers if skipping
existing_tickers = set()
if skip_existing:
try:
existing = self.collection.get(include=[])
existing_tickers = set(existing["ids"])
logger.info(f"Existing tickers in DB: {len(existing_tickers)}")
except Exception:
pass
# Process tickers
success_count = 0
skip_count = 0
fail_count = 0
for i, symbol in enumerate(tqdm(tickers, desc="Processing tickers")):
# Skip if exists
if skip_existing and symbol in existing_tickers:
skip_count += 1
continue
# Add ticker
if self.add_ticker(symbol, force_refresh=not skip_existing):
success_count += 1
else:
fail_count += 1
logger.info("Database build complete!")
logger.info(f"Success: {success_count}")
logger.info(f"Skipped: {skip_count}")
logger.info(f"Failed: {fail_count}")
logger.info(f"Total in DB: {self.collection.count()}")
def search_by_text(
self, query_text: str, top_k: int = 10, filters: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""
Search for tickers semantically related to query text.
Args:
query_text: Text to search for (e.g., news summary)
top_k: Number of top matches to return
filters: Optional metadata filters (e.g., {"sector": "Technology"})
Returns:
List of ticker matches with metadata and similarity scores
"""
# Generate embedding for query
query_embedding = self.get_embedding(query_text)
# Search ChromaDB
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
where=filters, # Apply metadata filters if provided
include=["documents", "metadatas", "distances"],
)
# Format results
matches = []
for i in range(len(results["ids"][0])):
distance = results["distances"][0][i]
similarity = 1 / (1 + distance)
match = {
"symbol": results["ids"][0][i],
"description": results["documents"][0][i],
"metadata": results["metadatas"][0][i],
"similarity_score": similarity, # Normalize distance to (0, 1]
}
matches.append(match)
return matches
def get_ticker_info(self, symbol: str) -> Optional[Dict[str, Any]]:
"""Get stored information for a specific ticker."""
try:
result = self.collection.get(ids=[symbol.upper()], include=["documents", "metadatas"])
if not result["ids"]:
return None
return {
"symbol": result["ids"][0],
"description": result["documents"][0],
"metadata": result["metadatas"][0],
}
except Exception:
return None
def get_stats(self) -> Dict[str, Any]:
"""Get database statistics."""
try:
count = self.collection.count()
# Get sector breakdown
all_data = self.collection.get(include=["metadatas"])
sectors = {}
industries = {}
for metadata in all_data["metadatas"]:
sector = metadata.get("sector", "Unknown")
industry = metadata.get("industry", "Unknown")
sectors[sector] = sectors.get(sector, 0) + 1
industries[industry] = industries.get(industry, 0) + 1
return {
"total_tickers": count,
"sectors": sectors,
"industries": industries,
"embedding_model": self.embedding_model,
"embedding_dimension": self.embedding_dim,
}
except Exception as e:
return {"error": str(e)}
def main():
"""CLI for building/managing the ticker database."""
import argparse
parser = argparse.ArgumentParser(description="Build ticker semantic database")
parser.add_argument("--ticker-file", default="data/tickers.txt", help="Path to ticker file")
parser.add_argument(
"--max-tickers", type=int, default=None, help="Maximum tickers to process (default: all)"
)
parser.add_argument(
"--use-local",
action="store_true",
help="Use local HuggingFace embeddings instead of OpenAI",
)
parser.add_argument(
"--force-refresh", action="store_true", help="Refresh all tickers even if they exist"
)
parser.add_argument("--stats", action="store_true", help="Show database statistics")
parser.add_argument("--search", type=str, help="Search for tickers by text query")
args = parser.parse_args()
# Load config
from tradingagents.default_config import DEFAULT_CONFIG
config = {
"project_dir": DEFAULT_CONFIG["project_dir"],
"use_openai_embeddings": not args.use_local,
}
# Initialize database
db = TickerSemanticDB(config)
# Execute command
if args.stats:
stats = db.get_stats()
logger.info("📊 Database Statistics:")
logger.info(json.dumps(stats, indent=2))
elif args.search:
logger.info(f"🔍 Searching for: {args.search}")
matches = db.search_by_text(args.search, top_k=10)
logger.info("Top matches:")
for i, match in enumerate(matches, 1):
logger.info(f"{i}. {match['symbol']} - {match['metadata']['name']}")
logger.debug(f" Sector: {match['metadata']['sector']}")
logger.debug(f" Similarity: {match['similarity_score']:.3f}")
logger.debug(f" Description: {match['description'][:150]}...")
else:
# Build database
db.build_database(
ticker_file=args.ticker_file,
max_tickers=args.max_tickers,
skip_existing=not args.force_refresh,
)
if __name__ == "__main__":
main()

View File

@ -4,10 +4,13 @@ Detects unusual options activity indicating smart money positioning
"""
import os
import requests
from datetime import datetime
from typing import Annotated, List
import requests
from tradingagents.config import config
from tradingagents.dataflows.market_data_utils import format_markdown_table
def get_unusual_options_activity(
tickers: Annotated[List[str], "List of ticker symbols to analyze"] = None,
@ -33,9 +36,10 @@ def get_unusual_options_activity(
Returns:
Formatted markdown report of unusual options activity
"""
api_key = os.getenv("TRADIER_API_KEY")
if not api_key:
return "Error: TRADIER_API_KEY not set in environment variables. Get a free key at https://tradier.com"
try:
api_key = config.validate_key("tradier_api_key", "Tradier")
except ValueError as e:
return f"Error: {str(e)}"
if not tickers or len(tickers) == 0:
return "Error: No tickers provided. This function analyzes options activity for specific tickers found by other discovery methods."
@ -45,10 +49,7 @@ def get_unusual_options_activity(
# Use production: https://api.tradier.com
base_url = os.getenv("TRADIER_BASE_URL", "https://sandbox.tradier.com")
headers = {
"Authorization": f"Bearer {api_key}",
"Accept": "application/json"
}
headers = {"Authorization": f"Bearer {api_key}", "Accept": "application/json"}
try:
# Strategy: Analyze options activity for provided tickers
@ -63,7 +64,7 @@ def get_unusual_options_activity(
params = {
"symbol": ticker,
"expiration": "", # Will get nearest expiration
"greeks": "true"
"greeks": "true",
}
response = requests.get(options_url, headers=headers, params=params, timeout=10)
@ -96,7 +97,9 @@ def get_unusual_options_activity(
total_volume = total_call_volume + total_put_volume
if total_volume > 10000: # Significant volume threshold
put_call_ratio = total_put_volume / total_call_volume if total_call_volume > 0 else 0
put_call_ratio = (
total_put_volume / total_call_volume if total_call_volume > 0 else 0
)
# Unusual signals:
# - Very low P/C ratio (<0.7) = Bullish (heavy call buying)
@ -111,46 +114,52 @@ def get_unusual_options_activity(
elif total_volume > 50000:
signal = "high_volume"
unusual_activity.append({
"ticker": ticker,
"total_volume": total_volume,
"call_volume": total_call_volume,
"put_volume": total_put_volume,
"put_call_ratio": put_call_ratio,
"signal": signal,
"call_oi": total_call_oi,
"put_oi": total_put_oi,
})
unusual_activity.append(
{
"ticker": ticker,
"total_volume": total_volume,
"call_volume": total_call_volume,
"put_volume": total_put_volume,
"put_call_ratio": put_call_ratio,
"signal": signal,
"call_oi": total_call_oi,
"put_oi": total_put_oi,
}
)
except Exception as e:
except Exception:
# Skip this ticker if there's an error
continue
# Sort by total volume (highest first)
sorted_activity = sorted(
unusual_activity,
key=lambda x: x["total_volume"],
reverse=True
)[:top_n]
sorted_activity = sorted(unusual_activity, key=lambda x: x["total_volume"], reverse=True)[
:top_n
]
# Format output
if not sorted_activity:
return "No unusual options activity detected"
report = f"# Unusual Options Activity - {date or 'Latest'}\n\n"
report += f"**Criteria**: P/C Ratio extremes (<0.7 bullish, >1.5 bearish), High volume (>50k)\n\n"
report += (
"**Criteria**: P/C Ratio extremes (<0.7 bullish, >1.5 bearish), High volume (>50k)\n\n"
)
report += f"**Found**: {len(sorted_activity)} stocks with notable options activity\n\n"
report += "## Top Options Activity\n\n"
report += "| Ticker | Total Volume | Call Vol | Put Vol | P/C Ratio | Signal |\n"
report += "|--------|--------------|----------|---------|-----------|--------|\n"
for activity in sorted_activity:
report += f"| {activity['ticker']} | "
report += f"{activity['total_volume']:,} | "
report += f"{activity['call_volume']:,} | "
report += f"{activity['put_volume']:,} | "
report += f"{activity['put_call_ratio']:.2f} | "
report += f"{activity['signal']} |\n"
report += format_markdown_table(
["Ticker", "Total Volume", "Call Vol", "Put Vol", "P/C Ratio", "Signal"],
[
[
a["ticker"],
f"{a['total_volume']:,}",
f"{a['call_volume']:,}",
f"{a['put_volume']:,}",
f"{a['put_call_ratio']:.2f}",
a["signal"],
]
for a in sorted_activity
],
)
report += "\n\n## Signal Definitions\n\n"
report += "- **bullish_calls**: P/C ratio <0.7 - Heavy call buying, bullish positioning\n"

View File

@ -1,11 +1,16 @@
import os
import tweepy
import json
import time
from datetime import datetime
from pathlib import Path
import tweepy
from dotenv import load_dotenv
from tradingagents.config import config
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
# Load environment variables
load_dotenv()
@ -16,10 +21,12 @@ USAGE_FILE = DATA_DIR / ".twitter_usage.json"
MONTHLY_LIMIT = 200
CACHE_DURATION_HOURS = 4
def _ensure_data_dir():
"""Ensure the data directory exists."""
DATA_DIR.mkdir(exist_ok=True)
def _load_json(file_path: Path) -> dict:
"""Load JSON data from a file, returning empty dict if not found."""
if not file_path.exists():
@ -30,6 +37,7 @@ def _load_json(file_path: Path) -> dict:
except (json.JSONDecodeError, IOError):
return {}
def _save_json(file_path: Path, data: dict):
"""Save dictionary to a JSON file."""
_ensure_data_dir()
@ -37,57 +45,62 @@ def _save_json(file_path: Path, data: dict):
with open(file_path, "w") as f:
json.dump(data, f, indent=2)
except IOError as e:
print(f"Warning: Could not save to {file_path}: {e}")
logger.warning(f"Could not save to {file_path}: {e}")
def _get_cache_key(prefix: str, identifier: str) -> str:
"""Generate a cache key."""
return f"{prefix}:{identifier}"
def _is_cache_valid(timestamp: float) -> bool:
"""Check if the cached entry is still valid."""
age_hours = (time.time() - timestamp) / 3600
return age_hours < CACHE_DURATION_HOURS
def _check_usage_limit() -> bool:
"""Check if the monthly usage limit has been reached."""
usage_data = _load_json(USAGE_FILE)
current_month = datetime.now().strftime("%Y-%m")
# Reset usage if it's a new month
if usage_data.get("month") != current_month:
usage_data = {"month": current_month, "count": 0}
_save_json(USAGE_FILE, usage_data)
return True
return usage_data.get("count", 0) < MONTHLY_LIMIT
def _increment_usage():
"""Increment the usage counter."""
usage_data = _load_json(USAGE_FILE)
current_month = datetime.now().strftime("%Y-%m")
if usage_data.get("month") != current_month:
usage_data = {"month": current_month, "count": 0}
usage_data["count"] = usage_data.get("count", 0) + 1
_save_json(USAGE_FILE, usage_data)
def get_tweets(query: str, count: int = 10) -> str:
"""
Fetches recent tweets matching the query using Twitter API v2.
Includes caching and rate limiting.
Args:
query (str): The search query (e.g., "AAPL", "Bitcoin").
count (int): Number of tweets to retrieve (default 10).
Returns:
str: A formatted string containing the tweets or an error message.
"""
# 1. Check Cache
cache_key = _get_cache_key("search", query)
cache = _load_json(CACHE_FILE)
if cache_key in cache:
entry = cache[cache_key]
if _is_cache_valid(entry["timestamp"]):
@ -97,26 +110,23 @@ def get_tweets(query: str, count: int = 10) -> str:
if not _check_usage_limit():
return "Error: Monthly Twitter API usage limit (200 calls) reached."
bearer_token = os.getenv("TWITTER_BEARER_TOKEN")
if not bearer_token:
return "Error: TWITTER_BEARER_TOKEN not found in environment variables."
bearer_token = config.validate_key("twitter_bearer_token", "Twitter")
try:
client = tweepy.Client(bearer_token=bearer_token)
# Search for recent tweets
safe_count = max(10, min(count, 100))
response = client.search_recent_tweets(
query=query,
query=query,
max_results=safe_count,
tweet_fields=['created_at', 'author_id', 'public_metrics']
tweet_fields=["created_at", "author_id", "public_metrics"],
)
# 3. Increment Usage
_increment_usage()
if not response.data:
result = f"No tweets found for query: {query}"
else:
@ -130,33 +140,31 @@ def get_tweets(query: str, count: int = 10) -> str:
result = formatted_tweets
# 4. Save to Cache
cache[cache_key] = {
"timestamp": time.time(),
"data": result
}
cache[cache_key] = {"timestamp": time.time(), "data": result}
_save_json(CACHE_FILE, cache)
return result
except Exception as e:
return f"Error fetching tweets: {str(e)}"
def get_tweets_from_user(username: str, count: int = 10) -> str:
"""
Fetches recent tweets from a specific user using Twitter API v2.
Includes caching and rate limiting.
Args:
username (str): The Twitter username (without @).
count (int): Number of tweets to retrieve (default 10).
Returns:
str: A formatted string containing the tweets or an error message.
"""
# 1. Check Cache
cache_key = _get_cache_key("user", username)
cache = _load_json(CACHE_FILE)
if cache_key in cache:
entry = cache[cache_key]
if _is_cache_valid(entry["timestamp"]):
@ -166,33 +174,28 @@ def get_tweets_from_user(username: str, count: int = 10) -> str:
if not _check_usage_limit():
return "Error: Monthly Twitter API usage limit (200 calls) reached."
bearer_token = os.getenv("TWITTER_BEARER_TOKEN")
if not bearer_token:
return "Error: TWITTER_BEARER_TOKEN not found in environment variables."
bearer_token = config.validate_key("twitter_bearer_token", "Twitter")
try:
client = tweepy.Client(bearer_token=bearer_token)
# First, get the user ID
user = client.get_user(username=username)
if not user.data:
return f"Error: User '@{username}' not found."
user_id = user.data.id
# max_results must be between 5 and 100 for get_users_tweets
safe_count = max(5, min(count, 100))
response = client.get_users_tweets(
id=user_id,
max_results=safe_count,
tweet_fields=['created_at', 'public_metrics']
id=user_id, max_results=safe_count, tweet_fields=["created_at", "public_metrics"]
)
# 3. Increment Usage
_increment_usage()
if not response.data:
result = f"No recent tweets found for user: @{username}"
else:
@ -204,16 +207,12 @@ def get_tweets_from_user(username: str, count: int = 10) -> str:
formatted_tweets += f" (Likes: {metrics.get('like_count', 0)}, Retweets: {metrics.get('retweet_count', 0)})\n"
formatted_tweets += "\n"
result = formatted_tweets
# 4. Save to Cache
cache[cache_key] = {
"timestamp": time.time(),
"data": result
}
cache[cache_key] = {"timestamp": time.time(), "data": result}
_save_json(CACHE_FILE, cache)
return result
except Exception as e:
return f"Error fetching tweets from user @{username}: {str(e)}"

View File

@ -1,15 +1,19 @@
import os
import json
import pandas as pd
from datetime import date, timedelta, datetime
from datetime import date, datetime, timedelta
from typing import Annotated
import pandas as pd
from tradingagents.utils.logger import get_logger
logger = get_logger(__name__)
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
if save_path:
data.to_csv(save_path)
print(f"{tag} saved to {save_path}")
logger.info(f"{tag} saved to {save_path}")
def get_current_date():

Some files were not shown because too many files have changed in this diff Show More