diff --git a/.gitignore b/.gitignore
index 3369bad9..17ceaeb9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,13 @@ eval_results/
eval_data/
*.egg-info/
.env
+.claude/
+.pytest_cache/
+.specify/
+specs/
+agent-os/
+*.local.md
+build/
+.mcp.json
+*.zip
+todos.md
diff --git a/README.md b/README.md
index 7e90c60f..b8cc9304 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,7 @@
# TradingAgents: Multi-Agents LLM Financial Trading Framework
-> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
+> **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community.
>
> So we decided to fully open-source the framework. Looking forward to building impactful projects with you!
@@ -43,7 +43,7 @@
-🚀 [TradingAgents](#tradingagents-framework) | ⚡ [Installation & CLI](#installation-and-cli) | 🎬 [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | 📦 [Package Usage](#tradingagents-package) | 🤝 [Contributing](#contributing) | 📄 [Citation](#citation)
+[TradingAgents](#tradingagents-framework) | [Installation & CLI](#installation-and-cli) | [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | [Package Usage](#tradingagents-package) | [Contributing](#contributing) | [Citation](#citation)
@@ -101,15 +101,10 @@ git clone https://github.com/TauricResearch/TradingAgents.git
cd TradingAgents
```
-Create a virtual environment in any of your favorite environment managers:
+Sync virtual environment:
```bash
-conda create -n tradingagents python=3.13
-conda activate tradingagents
-```
-
-Install dependencies:
-```bash
-pip install -r requirements.txt
+uv sync
+source .venv/bin/activate
```
### Required APIs
@@ -133,7 +128,7 @@ cp .env.example .env
You can also try out the CLI directly by running:
```bash
-python -m cli.main
+uv run python -m cli.main
```
You will see a screen where you can select your desired tickers, date, LLMs, research depth, etc.
@@ -204,13 +199,9 @@ print(decision)
You can view the full list of configurations in `tradingagents/default_config.py`.
-## Contributing
+## Source
-We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
-
-## Citation
-
-Please reference our work if you find *TradingAgents* provides you with some help :)
+Thanks to Yijia Xiao, Edward Sun, Di Luo, and Wei Wang. The core agent implementation is based on [TradingAgents: Multi-Agents LLM Financial Trading Framework](https://arxiv.org/abs/2412.20138).
```
@misc{xiao2025tradingagentsmultiagentsllmfinancial,
diff --git a/cli/main.py b/cli/main.py
index 2e06d50c..d5efcf3c 100644
--- a/cli/main.py
+++ b/cli/main.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, List
import datetime
import typer
from pathlib import Path
@@ -6,7 +6,6 @@ from functools import wraps
from rich.console import Console
from dotenv import load_dotenv
-# Load environment variables from .env file
load_dotenv()
from rich.panel import Panel
from rich.spinner import Spinner
@@ -15,7 +14,7 @@ from rich.columns import Columns
from rich.markdown import Markdown
from rich.layout import Layout
from rich.text import Text
from rich.live import Live
from rich.table import Table
from collections import deque
import time
@@ -23,9 +21,19 @@ from rich.tree import Tree
from rich import box
from rich.align import Align
from rich.rule import Rule
+import questionary
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
+from tradingagents.agents.discovery.models import (
+ DiscoveryRequest,
+ DiscoveryResult,
+ DiscoveryStatus,
+ TrendingStock,
+ Sector,
+ EventCategory,
+)
+from tradingagents.agents.discovery.persistence import save_discovery_result
from cli.models import AnalystType
from cli.utils import *
@@ -34,34 +42,28 @@ console = Console()
app = typer.Typer(
name="TradingAgents",
help="TradingAgents CLI: Multi-Agents LLM Financial Trading Framework",
- add_completion=True, # Enable shell completion
+ add_completion=True,
)
-# Create a deque to store recent messages with a maximum length
class MessageBuffer:
def __init__(self, max_length=100):
self.messages = deque(maxlen=max_length)
self.tool_calls = deque(maxlen=max_length)
self.current_report = None
- self.final_report = None # Store the complete final report
+ self.final_report = None
self.agent_status = {
- # Analyst Team
"Market Analyst": "pending",
"Social Analyst": "pending",
"News Analyst": "pending",
"Fundamentals Analyst": "pending",
- # Research Team
"Bull Researcher": "pending",
"Bear Researcher": "pending",
"Research Manager": "pending",
- # Trading Team
"Trader": "pending",
- # Risk Management Team
"Risky Analyst": "pending",
"Neutral Analyst": "pending",
"Safe Analyst": "pending",
- # Portfolio Management Team
"Portfolio Manager": "pending",
}
self.current_agent = None
@@ -94,18 +96,15 @@ class MessageBuffer:
self._update_current_report()
def _update_current_report(self):
- # For the panel display, only show the most recently updated section
latest_section = None
latest_content = None
- # Find the most recently updated section
for section, content in self.report_sections.items():
if content is not None:
latest_section = section
latest_content = content
-
+
if latest_section and latest_content:
- # Format the current section for display
section_titles = {
"market_report": "Market Analysis",
"sentiment_report": "Social Sentiment",
@@ -119,13 +118,11 @@ class MessageBuffer:
f"### {section_titles[latest_section]}\n{latest_content}"
)
- # Update the final complete report
self._update_final_report()
def _update_final_report(self):
report_parts = []
- # Analyst Team Reports
if any(
self.report_sections[section]
for section in [
@@ -153,17 +150,14 @@ class MessageBuffer:
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
)
- # Research Team Reports
if self.report_sections["investment_plan"]:
report_parts.append("## Research Team Decision")
report_parts.append(f"{self.report_sections['investment_plan']}")
- # Trading Team Reports
if self.report_sections["trader_investment_plan"]:
report_parts.append("## Trading Team Plan")
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
- # Portfolio Management Decision
if self.report_sections["final_trade_decision"]:
report_parts.append("## Portfolio Management Decision")
report_parts.append(f"{self.report_sections['final_trade_decision']}")
@@ -174,284 +168,396 @@ class MessageBuffer:
message_buffer = MessageBuffer()
-def create_layout():
- layout = Layout()
- layout.split_column(
- Layout(name="header", size=3),
- Layout(name="main"),
- Layout(name="footer", size=3),
- )
- layout["main"].split_column(
- Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5)
- )
- layout["upper"].split_row(
- Layout(name="progress", ratio=2), Layout(name="messages", ratio=3)
- )
- return layout
+LOOKBACK_OPTIONS = [
+ ("Last hour (1h)", "1h"),
+ ("Last 6 hours (6h)", "6h"),
+ ("Last 24 hours (24h)", "24h"),
+ ("Last 7 days (7d)", "7d"),
+]
+
+SECTOR_OPTIONS = [
+ ("Technology", Sector.TECHNOLOGY),
+ ("Healthcare", Sector.HEALTHCARE),
+ ("Finance", Sector.FINANCE),
+ ("Energy", Sector.ENERGY),
+ ("Consumer Goods", Sector.CONSUMER_GOODS),
+ ("Industrials", Sector.INDUSTRIALS),
+ ("Other", Sector.OTHER),
+]
+
+EVENT_OPTIONS = [
+ ("Earnings", EventCategory.EARNINGS),
+ ("Merger/Acquisition", EventCategory.MERGER_ACQUISITION),
+ ("Regulatory", EventCategory.REGULATORY),
+ ("Product Launch", EventCategory.PRODUCT_LAUNCH),
+ ("Executive Change", EventCategory.EXECUTIVE_CHANGE),
+ ("Other", EventCategory.OTHER),
+]
-def update_display(layout, spinner_text=None):
- # Header with welcome message
- layout["header"].update(
- Panel(
- "[bold green]Welcome to TradingAgents CLI[/bold green]\n"
- "[dim]© [Tauric Research](https://github.com/TauricResearch)[/dim]",
- title="Welcome to TradingAgents",
- border_style="green",
- padding=(1, 2),
- expand=True,
- )
- )
+def create_question_box(title: str, prompt: str, default: str = None) -> Panel:
+ box_content = f"[bold]{title}[/bold]\n"
+ box_content += f"[dim]{prompt}[/dim]"
+ if default:
+ box_content += f"\n[dim]Default: {default}[/dim]"
+ return Panel(box_content, border_style="blue", padding=(1, 2))
- # Progress panel showing agent status
- progress_table = Table(
- show_header=True,
- header_style="bold magenta",
- show_footer=False,
- box=box.SIMPLE_HEAD, # Use simple header with horizontal lines
- title=None, # Remove the redundant Progress title
- padding=(0, 2), # Add horizontal padding
- expand=True, # Make table expand to fill available space
- )
- progress_table.add_column("Team", style="cyan", justify="center", width=20)
- progress_table.add_column("Agent", style="green", justify="center", width=20)
- progress_table.add_column("Status", style="yellow", justify="center", width=20)
- # Group agents by team
- teams = {
- "Analyst Team": [
- "Market Analyst",
- "Social Analyst",
- "News Analyst",
- "Fundamentals Analyst",
+def select_lookback_period() -> str:
+ choice = questionary.select(
+ "Select lookback period:",
+ choices=[
+ questionary.Choice(display, value=value) for display, value in LOOKBACK_OPTIONS
],
- "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
- "Trading Team": ["Trader"],
- "Risk Management": ["Risky Analyst", "Neutral Analyst", "Safe Analyst"],
- "Portfolio Management": ["Portfolio Manager"],
- }
+ instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
+ style=questionary.Style(
+ [
+ ("selected", "fg:cyan noinherit"),
+ ("highlighted", "fg:cyan noinherit"),
+ ("pointer", "fg:cyan noinherit"),
+ ]
+ ),
+ ).ask()
- for team, agents in teams.items():
- # Add first agent with team name
- first_agent = agents[0]
- status = message_buffer.agent_status[first_agent]
- if status == "in_progress":
- spinner = Spinner(
- "dots", text="[blue]in_progress[/blue]", style="bold cyan"
- )
- status_cell = spinner
- else:
- status_color = {
- "pending": "yellow",
- "completed": "green",
- "error": "red",
- }.get(status, "white")
- status_cell = f"[{status_color}]{status}[/{status_color}]"
- progress_table.add_row(team, first_agent, status_cell)
+ if choice is None:
+ console.print("\n[red]No lookback period selected. Exiting...[/red]")
+ exit(1)
- # Add remaining agents in team
- for agent in agents[1:]:
- status = message_buffer.agent_status[agent]
- if status == "in_progress":
- spinner = Spinner(
- "dots", text="[blue]in_progress[/blue]", style="bold cyan"
- )
- status_cell = spinner
- else:
- status_color = {
- "pending": "yellow",
- "completed": "green",
- "error": "red",
- }.get(status, "white")
- status_cell = f"[{status_color}]{status}[/{status_color}]"
- progress_table.add_row("", agent, status_cell)
+ return choice
- # Add horizontal line after each team
- progress_table.add_row("─" * 20, "─" * 20, "─" * 20, style="dim")
- layout["progress"].update(
- Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2))
- )
+def select_sector_filter() -> Optional[List[Sector]]:
+ use_filter = questionary.confirm(
+ "Filter by sector?",
+ default=False,
+ style=questionary.Style(
+ [
+ ("selected", "fg:cyan noinherit"),
+ ("highlighted", "fg:cyan noinherit"),
+ ]
+ ),
+ ).ask()
- # Messages panel showing recent messages and tool calls
- messages_table = Table(
+ if not use_filter:
+ return None
+
+ choices = questionary.checkbox(
+ "Select sectors to include:",
+ choices=[
+ questionary.Choice(display, value=value) for display, value in SECTOR_OPTIONS
+ ],
+ instruction="\n- Press Space to select/unselect\n- Press 'a' to select all\n- Press Enter when done",
+ style=questionary.Style(
+ [
+ ("checkbox-selected", "fg:cyan"),
+ ("selected", "fg:cyan noinherit"),
+ ("highlighted", "noinherit"),
+ ("pointer", "noinherit"),
+ ]
+ ),
+ ).ask()
+
+ if not choices:
+ return None
+
+ return choices
+
+
+def select_event_filter() -> Optional[List[EventCategory]]:
+ use_filter = questionary.confirm(
+ "Filter by event type?",
+ default=False,
+ style=questionary.Style(
+ [
+ ("selected", "fg:cyan noinherit"),
+ ("highlighted", "fg:cyan noinherit"),
+ ]
+ ),
+ ).ask()
+
+ if not use_filter:
+ return None
+
+ choices = questionary.checkbox(
+ "Select event types to include:",
+ choices=[
+ questionary.Choice(display, value=value) for display, value in EVENT_OPTIONS
+ ],
+ instruction="\n- Press Space to select/unselect\n- Press 'a' to select all\n- Press Enter when done",
+ style=questionary.Style(
+ [
+ ("checkbox-selected", "fg:cyan"),
+ ("selected", "fg:cyan noinherit"),
+ ("highlighted", "noinherit"),
+ ("pointer", "noinherit"),
+ ]
+ ),
+ ).ask()
+
+ if not choices:
+ return None
+
+ return choices
+
+
+def create_discovery_results_table(trending_stocks: List[TrendingStock]) -> Table:
+ table = Table(
show_header=True,
header_style="bold magenta",
- show_footer=False,
- expand=True, # Make table expand to fill available space
- box=box.MINIMAL, # Use minimal box style for a lighter look
- show_lines=True, # Keep horizontal lines
- padding=(0, 1), # Add some padding between columns
+ box=box.ROUNDED,
+ title="Trending Stocks",
+ title_style="bold green",
+ expand=True,
)
- messages_table.add_column("Time", style="cyan", width=8, justify="center")
- messages_table.add_column("Type", style="green", width=10, justify="center")
- messages_table.add_column(
- "Content", style="white", no_wrap=False, ratio=1
- ) # Make content column expand
- # Combine tool calls and messages
- all_messages = []
+ table.add_column("Rank", style="cyan", justify="center", width=6)
+ table.add_column("Ticker", style="bold yellow", justify="center", width=10)
+ table.add_column("Company", style="white", justify="left", width=25)
+ table.add_column("Score", style="green", justify="right", width=10)
+ table.add_column("Mentions", style="blue", justify="center", width=10)
+ table.add_column("Event Type", style="magenta", justify="center", width=18)
- # Add tool calls
- for timestamp, tool_name, args in message_buffer.tool_calls:
- # Truncate tool call args if too long
- if isinstance(args, str) and len(args) > 100:
- args = args[:97] + "..."
- all_messages.append((timestamp, "Tool", f"{tool_name}: {args}"))
+ for rank, stock in enumerate(trending_stocks, 1):
+ if rank <= 3:
+ rank_display = f"[bold green]{rank}[/bold green]"
+ ticker_display = f"[bold yellow]{stock.ticker}[/bold yellow]"
+ else:
+ rank_display = str(rank)
+ ticker_display = stock.ticker
- # Add regular messages
- for timestamp, msg_type, content in message_buffer.messages:
- # Convert content to string if it's not already
- content_str = content
- if isinstance(content, list):
- # Handle list of content blocks (Anthropic format)
- text_parts = []
- for item in content:
- if isinstance(item, dict):
- if item.get('type') == 'text':
- text_parts.append(item.get('text', ''))
- elif item.get('type') == 'tool_use':
- text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
- else:
- text_parts.append(str(item))
- content_str = ' '.join(text_parts)
- elif not isinstance(content_str, str):
- content_str = str(content)
-
- # Truncate message content if too long
- if len(content_str) > 200:
- content_str = content_str[:197] + "..."
- all_messages.append((timestamp, msg_type, content_str))
-
- # Sort by timestamp
- all_messages.sort(key=lambda x: x[0])
-
- # Calculate how many messages we can show based on available space
- # Start with a reasonable number and adjust based on content length
- max_messages = 12 # Increased from 8 to better fill the space
-
- # Get the last N messages that will fit in the panel
- recent_messages = all_messages[-max_messages:]
-
- # Add messages to table
- for timestamp, msg_type, content in recent_messages:
- # Format content with word wrapping
- wrapped_content = Text(content, overflow="fold")
- messages_table.add_row(timestamp, msg_type, wrapped_content)
-
- if spinner_text:
- messages_table.add_row("", "Spinner", spinner_text)
-
- # Add a footer to indicate if messages were truncated
- if len(all_messages) > max_messages:
- messages_table.footer = (
- f"[dim]Showing last {max_messages} of {len(all_messages)} messages[/dim]"
+ table.add_row(
+ rank_display,
+ ticker_display,
+ stock.company_name[:25] if len(stock.company_name) > 25 else stock.company_name,
+ f"{stock.score:.2f}",
+ str(stock.mention_count),
+ stock.event_type.value.replace("_", " ").title(),
)
- layout["messages"].update(
- Panel(
- messages_table,
- title="Messages & Tools",
- border_style="blue",
- padding=(1, 2),
- )
- )
-
- # Analysis panel showing current report
- if message_buffer.current_report:
- layout["analysis"].update(
- Panel(
- Markdown(message_buffer.current_report),
- title="Current Report",
- border_style="green",
- padding=(1, 2),
- )
- )
- else:
- layout["analysis"].update(
- Panel(
- "[italic]Waiting for analysis report...[/italic]",
- title="Current Report",
- border_style="green",
- padding=(1, 2),
- )
- )
-
- # Footer with statistics
- tool_calls_count = len(message_buffer.tool_calls)
- llm_calls_count = sum(
- 1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning"
- )
- reports_count = sum(
- 1 for content in message_buffer.report_sections.values() if content is not None
- )
-
- stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
- stats_table.add_column("Stats", justify="center")
- stats_table.add_row(
- f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}"
- )
-
- layout["footer"].update(Panel(stats_table, border_style="grey50"))
+ return table
-def get_user_selections():
- """Get all user selections before starting the analysis display."""
- # Display ASCII art welcome message
- with open("./cli/static/welcome.txt", "r") as f:
- welcome_ascii = f.read()
+def create_stock_detail_panel(stock: TrendingStock, rank: int) -> Panel:
+ sentiment_label = "positive" if stock.sentiment > 0.3 else "negative" if stock.sentiment < -0.3 else "neutral"
+ sentiment_color = "green" if stock.sentiment > 0.3 else "red" if stock.sentiment < -0.3 else "yellow"
- # Create welcome box content
- welcome_content = f"{welcome_ascii}\n"
- welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n"
- welcome_content += "[bold]Workflow Steps:[/bold]\n"
- welcome_content += "I. Analyst Team → II. Research Team → III. Trader → IV. Risk Management → V. Portfolio Management\n\n"
- welcome_content += (
- "[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]"
- )
+ content = f"""[bold]Rank #{rank}: {stock.ticker} - {stock.company_name}[/bold]
- # Create and center the welcome box
- welcome_box = Panel(
- welcome_content,
- border_style="green",
+[cyan]Score:[/cyan] {stock.score:.2f}
+[cyan]Sentiment:[/cyan] [{sentiment_color}]{stock.sentiment:.2f} ({sentiment_label})[/{sentiment_color}]
+[cyan]Sector:[/cyan] {stock.sector.value.replace("_", " ").title()}
+[cyan]Event Type:[/cyan] {stock.event_type.value.replace("_", " ").title()}
+[cyan]Mentions:[/cyan] {stock.mention_count}
+
+[bold]News Summary:[/bold]
+{stock.news_summary}
+
+[bold]Top Source Articles:[/bold]"""
+
+ for i, article in enumerate(stock.source_articles[:3], 1):
+ content += f"\n {i}. [{article.title[:50]}...] - {article.source}"
+
+ return Panel(
+ content,
+ title=f"Stock Details: {stock.ticker}",
+ border_style="cyan",
padding=(1, 2),
- title="Welcome to TradingAgents",
- subtitle="Multi-Agents LLM Financial Trading Framework",
)
- console.print(Align.center(welcome_box))
- console.print() # Add a blank line after the welcome box
- # Create a boxed questionnaire for each step
- def create_question_box(title, prompt, default=None):
- box_content = f"[bold]{title}[/bold]\n"
- box_content += f"[dim]{prompt}[/dim]"
- if default:
- box_content += f"\n[dim]Default: {default}[/dim]"
- return Panel(box_content, border_style="blue", padding=(1, 2))
- # Step 1: Ticker symbol
+def select_stock_for_detail(trending_stocks: List[TrendingStock]) -> Optional[TrendingStock]:
+ if not trending_stocks:
+ return None
+
+ choices = [
+ questionary.Choice(
+ f"{i+1}. {stock.ticker} - {stock.company_name} (Score: {stock.score:.2f})",
+ value=stock
+ )
+ for i, stock in enumerate(trending_stocks)
+ ]
+ choices.append(questionary.Choice("Back to menu", value=None))
+
+ selected = questionary.select(
+ "Select a stock to view details:",
+ choices=choices,
+ instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
+ style=questionary.Style(
+ [
+ ("selected", "fg:cyan noinherit"),
+ ("highlighted", "fg:cyan noinherit"),
+ ("pointer", "fg:cyan noinherit"),
+ ]
+ ),
+ ).ask()
+
+ return selected
+
+
+def discover_trending_flow():
+ console.print(Rule("[bold green]Discover Trending Stocks[/bold green]"))
+ console.print()
+
console.print(
create_question_box(
- "Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY"
+ "Step 1: Lookback Period",
+ "Select how far back to search for trending stocks"
)
)
- selected_ticker = get_ticker()
+ lookback_period = select_lookback_period()
+ console.print(f"[green]Selected lookback period:[/green] {lookback_period}")
+ console.print()
- # Step 2: Analysis date
- default_date = datetime.datetime.now().strftime("%Y-%m-%d")
console.print(
create_question_box(
- "Step 2: Analysis Date",
- "Enter the analysis date (YYYY-MM-DD)",
- default_date,
+ "Step 2: Sector Filter (Optional)",
+ "Optionally filter results by sector"
)
)
- analysis_date = get_analysis_date()
+ sector_filter = select_sector_filter()
+ if sector_filter:
+ console.print(f"[green]Selected sectors:[/green] {', '.join(s.value for s in sector_filter)}")
+ else:
+ console.print("[dim]No sector filter applied[/dim]")
+ console.print()
- # Step 3: Select analysts
console.print(
create_question_box(
- "Step 3: Analysts Team", "Select your LLM analyst agents for the analysis"
+ "Step 3: Event Filter (Optional)",
+ "Optionally filter results by event type"
+ )
+ )
+ event_filter = select_event_filter()
+ if event_filter:
+ console.print(f"[green]Selected events:[/green] {', '.join(e.value for e in event_filter)}")
+ else:
+ console.print("[dim]No event filter applied[/dim]")
+ console.print()
+
+ console.print(
+ create_question_box(
+ "Step 4: LLM Provider",
+ "Select your LLM provider for entity extraction"
+ )
+ )
+ selected_llm_provider, backend_url = select_llm_provider()
+ console.print()
+
+ console.print(
+ create_question_box(
+ "Step 5: Quick-Thinking Model",
+ "Select the model for entity extraction"
+ )
+ )
+ selected_model = select_shallow_thinking_agent(selected_llm_provider)
+ console.print()
+
+ config = DEFAULT_CONFIG.copy()
+ config["llm_provider"] = selected_llm_provider.lower()
+ config["backend_url"] = backend_url
+ config["quick_think_llm"] = selected_model
+ config["deep_think_llm"] = selected_model
+
+ request = DiscoveryRequest(
+ lookback_period=lookback_period,
+ sector_filter=sector_filter,
+ event_filter=event_filter,
+ max_results=config.get("discovery_max_results", 20),
+ )
+
+ discovery_stages = [
+ "Fetching news...",
+ "Extracting entities...",
+ "Resolving tickers...",
+ "Calculating scores...",
+ ]
+
+ result = None
+ with Live(console=console, refresh_per_second=4) as live:
+ for i, stage in enumerate(discovery_stages):
+ progress_panel = Panel(
+ f"[bold cyan]{stage}[/bold cyan]\n\n"
+ f"[dim]Stage {i+1} of {len(discovery_stages)}[/dim]",
+ title="Discovery Progress",
+ border_style="cyan",
+ padding=(2, 4),
+ )
+ live.update(Align.center(progress_panel))
+
+ if i == 0:
+ try:
+ graph = TradingAgentsGraph(config=config, debug=False)
+ result = graph.discover_trending(request)
+ except Exception as e:
+ console.print(f"\n[red]Error during discovery: {e}[/red]")
+ return
+
+ time.sleep(0.5)
+
+ if result is None:
+ console.print("\n[red]Discovery failed. Please try again.[/red]")
+ return
+
+ if result.status == DiscoveryStatus.FAILED:
+ console.print(f"\n[red]Discovery failed: {result.error_message}[/red]")
+ return
+
+ if result.status == DiscoveryStatus.COMPLETED:
+ try:
+ save_path = save_discovery_result(result)
+ console.print(f"\n[dim]Results saved to: {save_path}[/dim]")
+ except Exception as e:
+ console.print(f"\n[yellow]Warning: Could not save results: {e}[/yellow]")
+
+ console.print()
+
+ if not result.trending_stocks:
+ console.print("[yellow]No trending stocks found matching your criteria.[/yellow]")
+ return
+
+ console.print(f"[green]Found {len(result.trending_stocks)} trending stocks[/green]")
+ console.print()
+
+ results_table = create_discovery_results_table(result.trending_stocks)
+ console.print(results_table)
+ console.print()
+
+ while True:
+ selected_stock = select_stock_for_detail(result.trending_stocks)
+
+ if selected_stock is None:
+ break
+
+ rank = result.trending_stocks.index(selected_stock) + 1
+ detail_panel = create_stock_detail_panel(selected_stock, rank)
+ console.print()
+ console.print(detail_panel)
+ console.print()
+
+ analyze_choice = questionary.confirm(
+ f"Analyze {selected_stock.ticker}?",
+ default=False,
+ style=questionary.Style(
+ [
+ ("selected", "fg:green noinherit"),
+ ("highlighted", "fg:green noinherit"),
+ ]
+ ),
+ ).ask()
+
+ if analyze_choice:
+ console.print(f"\n[green]Starting analysis for {selected_stock.ticker}...[/green]\n")
+ run_analysis_for_ticker(selected_stock.ticker, config)
+ break
+
+
+def run_analysis_for_ticker(ticker: str, config: dict):
+ analysis_date = datetime.datetime.now().strftime("%Y-%m-%d")
+
+ console.print(
+ create_question_box(
+ "Analysts Team",
+ "Select your LLM analyst agents for the analysis"
)
)
selected_analysts = select_analysts()
@@ -459,302 +565,32 @@ def get_user_selections():
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
)
- # Step 4: Research depth
console.print(
create_question_box(
- "Step 4: Research Depth", "Select your research depth level"
+ "Research Depth",
+ "Select your research depth level"
)
)
selected_research_depth = select_research_depth()
- # Step 5: OpenAI backend
console.print(
create_question_box(
- "Step 5: OpenAI backend", "Select which service to talk to"
+ "Deep-Thinking Model",
+ "Select the model for deep analysis"
)
)
- selected_llm_provider, backend_url = select_llm_provider()
-
- # Step 6: Thinking agents
- console.print(
- create_question_box(
- "Step 6: Thinking Agents", "Select your thinking agents for analysis"
- )
- )
- selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
- selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
+ llm_provider = config.get("llm_provider", "openai")
+ selected_deep_thinker = select_deep_thinking_agent(llm_provider.capitalize())
- return {
- "ticker": selected_ticker,
- "analysis_date": analysis_date,
- "analysts": selected_analysts,
- "research_depth": selected_research_depth,
- "llm_provider": selected_llm_provider.lower(),
- "backend_url": backend_url,
- "shallow_thinker": selected_shallow_thinker,
- "deep_thinker": selected_deep_thinker,
- }
+ config["max_debate_rounds"] = selected_research_depth
+ config["max_risk_discuss_rounds"] = selected_research_depth
+ config["deep_think_llm"] = selected_deep_thinker
-
-def get_ticker():
- """Get ticker symbol from user input."""
- return typer.prompt("", default="SPY")
-
-
-def get_analysis_date():
- """Get the analysis date from user input."""
- while True:
- date_str = typer.prompt(
- "", default=datetime.datetime.now().strftime("%Y-%m-%d")
- )
- try:
- # Validate date format and ensure it's not in the future
- analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d")
- if analysis_date.date() > datetime.datetime.now().date():
- console.print("[red]Error: Analysis date cannot be in the future[/red]")
- continue
- return date_str
- except ValueError:
- console.print(
- "[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]"
- )
-
-
-def display_complete_report(final_state):
- """Display the complete analysis report with team-based panels."""
- console.print("\n[bold green]Complete Analysis Report[/bold green]\n")
-
- # I. Analyst Team Reports
- analyst_reports = []
-
- # Market Analyst Report
- if final_state.get("market_report"):
- analyst_reports.append(
- Panel(
- Markdown(final_state["market_report"]),
- title="Market Analyst",
- border_style="blue",
- padding=(1, 2),
- )
- )
-
- # Social Analyst Report
- if final_state.get("sentiment_report"):
- analyst_reports.append(
- Panel(
- Markdown(final_state["sentiment_report"]),
- title="Social Analyst",
- border_style="blue",
- padding=(1, 2),
- )
- )
-
- # News Analyst Report
- if final_state.get("news_report"):
- analyst_reports.append(
- Panel(
- Markdown(final_state["news_report"]),
- title="News Analyst",
- border_style="blue",
- padding=(1, 2),
- )
- )
-
- # Fundamentals Analyst Report
- if final_state.get("fundamentals_report"):
- analyst_reports.append(
- Panel(
- Markdown(final_state["fundamentals_report"]),
- title="Fundamentals Analyst",
- border_style="blue",
- padding=(1, 2),
- )
- )
-
- if analyst_reports:
- console.print(
- Panel(
- Columns(analyst_reports, equal=True, expand=True),
- title="I. Analyst Team Reports",
- border_style="cyan",
- padding=(1, 2),
- )
- )
-
- # II. Research Team Reports
- if final_state.get("investment_debate_state"):
- research_reports = []
- debate_state = final_state["investment_debate_state"]
-
- # Bull Researcher Analysis
- if debate_state.get("bull_history"):
- research_reports.append(
- Panel(
- Markdown(debate_state["bull_history"]),
- title="Bull Researcher",
- border_style="blue",
- padding=(1, 2),
- )
- )
-
- # Bear Researcher Analysis
- if debate_state.get("bear_history"):
- research_reports.append(
- Panel(
- Markdown(debate_state["bear_history"]),
- title="Bear Researcher",
- border_style="blue",
- padding=(1, 2),
- )
- )
-
- # Research Manager Decision
- if debate_state.get("judge_decision"):
- research_reports.append(
- Panel(
- Markdown(debate_state["judge_decision"]),
- title="Research Manager",
- border_style="blue",
- padding=(1, 2),
- )
- )
-
- if research_reports:
- console.print(
- Panel(
- Columns(research_reports, equal=True, expand=True),
- title="II. Research Team Decision",
- border_style="magenta",
- padding=(1, 2),
- )
- )
-
- # III. Trading Team Reports
- if final_state.get("trader_investment_plan"):
- console.print(
- Panel(
- Panel(
- Markdown(final_state["trader_investment_plan"]),
- title="Trader",
- border_style="blue",
- padding=(1, 2),
- ),
- title="III. Trading Team Plan",
- border_style="yellow",
- padding=(1, 2),
- )
- )
-
- # IV. Risk Management Team Reports
- if final_state.get("risk_debate_state"):
- risk_reports = []
- risk_state = final_state["risk_debate_state"]
-
- # Aggressive (Risky) Analyst Analysis
- if risk_state.get("risky_history"):
- risk_reports.append(
- Panel(
- Markdown(risk_state["risky_history"]),
- title="Aggressive Analyst",
- border_style="blue",
- padding=(1, 2),
- )
- )
-
- # Conservative (Safe) Analyst Analysis
- if risk_state.get("safe_history"):
- risk_reports.append(
- Panel(
- Markdown(risk_state["safe_history"]),
- title="Conservative Analyst",
- border_style="blue",
- padding=(1, 2),
- )
- )
-
- # Neutral Analyst Analysis
- if risk_state.get("neutral_history"):
- risk_reports.append(
- Panel(
- Markdown(risk_state["neutral_history"]),
- title="Neutral Analyst",
- border_style="blue",
- padding=(1, 2),
- )
- )
-
- if risk_reports:
- console.print(
- Panel(
- Columns(risk_reports, equal=True, expand=True),
- title="IV. Risk Management Team Decision",
- border_style="red",
- padding=(1, 2),
- )
- )
-
- # V. Portfolio Manager Decision
- if risk_state.get("judge_decision"):
- console.print(
- Panel(
- Panel(
- Markdown(risk_state["judge_decision"]),
- title="Portfolio Manager",
- border_style="blue",
- padding=(1, 2),
- ),
- title="V. Portfolio Manager Decision",
- border_style="green",
- padding=(1, 2),
- )
- )
-
-
-def update_research_team_status(status):
- """Update status for all research team members and trader."""
- research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"]
- for agent in research_team:
- message_buffer.update_agent_status(agent, status)
-
-def extract_content_string(content):
- """Extract string content from various message formats."""
- if isinstance(content, str):
- return content
- elif isinstance(content, list):
- # Handle Anthropic's list format
- text_parts = []
- for item in content:
- if isinstance(item, dict):
- if item.get('type') == 'text':
- text_parts.append(item.get('text', ''))
- elif item.get('type') == 'tool_use':
- text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
- else:
- text_parts.append(str(item))
- return ' '.join(text_parts)
- else:
- return str(content)
-
-def run_analysis():
- # First get all user selections
- selections = get_user_selections()
-
- # Create config with selected research depth
- config = DEFAULT_CONFIG.copy()
- config["max_debate_rounds"] = selections["research_depth"]
- config["max_risk_discuss_rounds"] = selections["research_depth"]
- config["quick_think_llm"] = selections["shallow_thinker"]
- config["deep_think_llm"] = selections["deep_thinker"]
- config["backend_url"] = selections["backend_url"]
- config["llm_provider"] = selections["llm_provider"].lower()
-
- # Initialize the graph
graph = TradingAgentsGraph(
- [analyst.value for analyst in selections["analysts"]], config=config, debug=True
+ [analyst.value for analyst in selected_analysts], config=config, debug=True
)
- # Create result directory
- results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
+ results_dir = Path(config["results_dir"]) / ticker / analysis_date
results_dir.mkdir(parents=True, exist_ok=True)
report_dir = results_dir / "reports"
report_dir.mkdir(parents=True, exist_ok=True)
@@ -767,11 +603,757 @@ def run_analysis():
def wrapper(*args, **kwargs):
func(*args, **kwargs)
timestamp, message_type, content = obj.messages[-1]
- content = content.replace("\n", " ") # Replace newlines with spaces
+ content = content.replace("\n", " ")
with open(log_file, "a") as f:
f.write(f"{timestamp} [{message_type}] {content}\n")
return wrapper
-
+
+ def save_tool_call_decorator(obj, func_name):
+ func = getattr(obj, func_name)
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ func(*args, **kwargs)
+ timestamp, tool_name, args = obj.tool_calls[-1]
+ args_str = ", ".join(f"{k}={v}" for k, v in args.items())
+ with open(log_file, "a") as f:
+ f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n")
+ return wrapper
+
+ def save_report_section_decorator(obj, func_name):
+ func = getattr(obj, func_name)
+ @wraps(func)
+ def wrapper(section_name, content):
+ func(section_name, content)
+ if section_name in obj.report_sections and obj.report_sections[section_name] is not None:
+ content = obj.report_sections[section_name]
+ if content:
+ file_name = f"{section_name}.md"
+ with open(report_dir / file_name, "w") as f:
+ f.write(content)
+ return wrapper
+
+ message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
+ message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call")
+ message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section")
+
+ layout = create_layout()
+
+ with Live(layout, refresh_per_second=4) as live:
+ update_display(layout)
+
+ message_buffer.add_message("System", f"Selected ticker: {ticker}")
+ message_buffer.add_message("System", f"Analysis date: {analysis_date}")
+ message_buffer.add_message(
+ "System",
+ f"Selected analysts: {', '.join(analyst.value for analyst in selected_analysts)}",
+ )
+ update_display(layout)
+
+ for agent in message_buffer.agent_status:
+ message_buffer.update_agent_status(agent, "pending")
+
+ for section in message_buffer.report_sections:
+ message_buffer.report_sections[section] = None
+ message_buffer.current_report = None
+ message_buffer.final_report = None
+
+ first_analyst = f"{selected_analysts[0].value.capitalize()} Analyst"
+ message_buffer.update_agent_status(first_analyst, "in_progress")
+ update_display(layout)
+
+ spinner_text = f"Analyzing {ticker} on {analysis_date}..."
+ update_display(layout, spinner_text)
+
+ init_agent_state = graph.propagator.create_initial_state(ticker, analysis_date)
+ args = graph.propagator.get_graph_args()
+
+ trace = []
+ for chunk in graph.graph.stream(init_agent_state, **args):
+ if len(chunk["messages"]) > 0:
+ last_message = chunk["messages"][-1]
+
+ if hasattr(last_message, "content"):
+ content = extract_content_string(last_message.content)
+ msg_type = "Reasoning"
+ else:
+ content = str(last_message)
+ msg_type = "System"
+
+ message_buffer.add_message(msg_type, content)
+
+ if hasattr(last_message, "tool_calls"):
+ for tool_call in last_message.tool_calls:
+ if isinstance(tool_call, dict):
+ message_buffer.add_tool_call(tool_call["name"], tool_call["args"])
+ else:
+ message_buffer.add_tool_call(tool_call.name, tool_call.args)
+
+ process_chunk_for_display(chunk, selected_analysts)
+ update_display(layout)
+
+ trace.append(chunk)
+
+ final_state = trace[-1]
+ decision = graph.process_signal(final_state["final_trade_decision"])
+
+ for agent in message_buffer.agent_status:
+ message_buffer.update_agent_status(agent, "completed")
+
+ message_buffer.add_message("Analysis", f"Completed analysis for {analysis_date}")
+
+ for section in message_buffer.report_sections.keys():
+ if section in final_state:
+ message_buffer.update_report_section(section, final_state[section])
+
+ display_complete_report(final_state)
+ update_display(layout)
+
+
+def process_chunk_for_display(chunk, selected_analysts):
+ if "market_report" in chunk and chunk["market_report"]:
+ message_buffer.update_report_section("market_report", chunk["market_report"])
+ message_buffer.update_agent_status("Market Analyst", "completed")
+ if AnalystType.SOCIAL in selected_analysts:
+ message_buffer.update_agent_status("Social Analyst", "in_progress")
+
+ if "sentiment_report" in chunk and chunk["sentiment_report"]:
+ message_buffer.update_report_section("sentiment_report", chunk["sentiment_report"])
+ message_buffer.update_agent_status("Social Analyst", "completed")
+ if AnalystType.NEWS in selected_analysts:
+ message_buffer.update_agent_status("News Analyst", "in_progress")
+
+ if "news_report" in chunk and chunk["news_report"]:
+ message_buffer.update_report_section("news_report", chunk["news_report"])
+ message_buffer.update_agent_status("News Analyst", "completed")
+ if AnalystType.FUNDAMENTALS in selected_analysts:
+ message_buffer.update_agent_status("Fundamentals Analyst", "in_progress")
+
+ if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
+ message_buffer.update_report_section("fundamentals_report", chunk["fundamentals_report"])
+ message_buffer.update_agent_status("Fundamentals Analyst", "completed")
+ update_research_team_status("in_progress")
+
+ if "investment_debate_state" in chunk and chunk["investment_debate_state"]:
+ debate_state = chunk["investment_debate_state"]
+
+ if "bull_history" in debate_state and debate_state["bull_history"]:
+ update_research_team_status("in_progress")
+ bull_responses = debate_state["bull_history"].split("\n")
+ latest_bull = bull_responses[-1] if bull_responses else ""
+ if latest_bull:
+ message_buffer.add_message("Reasoning", latest_bull)
+ message_buffer.update_report_section(
+ "investment_plan",
+ f"### Bull Researcher Analysis\n{latest_bull}",
+ )
+
+ if "bear_history" in debate_state and debate_state["bear_history"]:
+ update_research_team_status("in_progress")
+ bear_responses = debate_state["bear_history"].split("\n")
+ latest_bear = bear_responses[-1] if bear_responses else ""
+ if latest_bear:
+ message_buffer.add_message("Reasoning", latest_bear)
+ message_buffer.update_report_section(
+ "investment_plan",
+ f"{message_buffer.report_sections['investment_plan']}\n\n### Bear Researcher Analysis\n{latest_bear}",
+ )
+
+ if "judge_decision" in debate_state and debate_state["judge_decision"]:
+ update_research_team_status("in_progress")
+ message_buffer.add_message(
+ "Reasoning",
+ f"Research Manager: {debate_state['judge_decision']}",
+ )
+ message_buffer.update_report_section(
+ "investment_plan",
+ f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}",
+ )
+ update_research_team_status("completed")
+ message_buffer.update_agent_status("Risky Analyst", "in_progress")
+
+ if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]:
+ message_buffer.update_report_section("trader_investment_plan", chunk["trader_investment_plan"])
+ message_buffer.update_agent_status("Risky Analyst", "in_progress")
+
+ if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
+ risk_state = chunk["risk_debate_state"]
+
+ if "current_risky_response" in risk_state and risk_state["current_risky_response"]:
+ message_buffer.update_agent_status("Risky Analyst", "in_progress")
+ message_buffer.add_message(
+ "Reasoning",
+ f"Risky Analyst: {risk_state['current_risky_response']}",
+ )
+ message_buffer.update_report_section(
+ "final_trade_decision",
+ f"### Risky Analyst Analysis\n{risk_state['current_risky_response']}",
+ )
+
+ if "current_safe_response" in risk_state and risk_state["current_safe_response"]:
+ message_buffer.update_agent_status("Safe Analyst", "in_progress")
+ message_buffer.add_message(
+ "Reasoning",
+ f"Safe Analyst: {risk_state['current_safe_response']}",
+ )
+ message_buffer.update_report_section(
+ "final_trade_decision",
+ f"### Safe Analyst Analysis\n{risk_state['current_safe_response']}",
+ )
+
+ if "current_neutral_response" in risk_state and risk_state["current_neutral_response"]:
+ message_buffer.update_agent_status("Neutral Analyst", "in_progress")
+ message_buffer.add_message(
+ "Reasoning",
+ f"Neutral Analyst: {risk_state['current_neutral_response']}",
+ )
+ message_buffer.update_report_section(
+ "final_trade_decision",
+ f"### Neutral Analyst Analysis\n{risk_state['current_neutral_response']}",
+ )
+
+ if "judge_decision" in risk_state and risk_state["judge_decision"]:
+ message_buffer.update_agent_status("Portfolio Manager", "in_progress")
+ message_buffer.add_message(
+ "Reasoning",
+ f"Portfolio Manager: {risk_state['judge_decision']}",
+ )
+ message_buffer.update_report_section(
+ "final_trade_decision",
+ f"### Portfolio Manager Decision\n{risk_state['judge_decision']}",
+ )
+ message_buffer.update_agent_status("Risky Analyst", "completed")
+ message_buffer.update_agent_status("Safe Analyst", "completed")
+ message_buffer.update_agent_status("Neutral Analyst", "completed")
+ message_buffer.update_agent_status("Portfolio Manager", "completed")
+
+
+def create_layout():
+ layout = Layout()
+ layout.split_column(
+ Layout(name="header", size=3),
+ Layout(name="main"),
+ Layout(name="footer", size=3),
+ )
+ layout["main"].split_column(
+ Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5)
+ )
+ layout["upper"].split_row(
+ Layout(name="progress", ratio=2), Layout(name="messages", ratio=3)
+ )
+ return layout
+
+
+def update_display(layout, spinner_text=None):
+ layout["header"].update(
+ Panel(
+ "[bold green]Welcome to TradingAgents CLI[/bold green]\n"
+ "[dim]Built by Tauric Research (https://github.com/TauricResearch)[/dim]",
+ title="Welcome to TradingAgents",
+ border_style="green",
+ padding=(1, 2),
+ expand=True,
+ )
+ )
+
+ progress_table = Table(
+ show_header=True,
+ header_style="bold magenta",
+ show_footer=False,
+ box=box.SIMPLE_HEAD,
+ title=None,
+ padding=(0, 2),
+ expand=True,
+ )
+ progress_table.add_column("Team", style="cyan", justify="center", width=20)
+ progress_table.add_column("Agent", style="green", justify="center", width=20)
+ progress_table.add_column("Status", style="yellow", justify="center", width=20)
+
+ teams = {
+ "Analyst Team": [
+ "Market Analyst",
+ "Social Analyst",
+ "News Analyst",
+ "Fundamentals Analyst",
+ ],
+ "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"],
+ "Trading Team": ["Trader"],
+ "Risk Management": ["Risky Analyst", "Neutral Analyst", "Safe Analyst"],
+ "Portfolio Management": ["Portfolio Manager"],
+ }
+
+ for team, agents in teams.items():
+ first_agent = agents[0]
+ status = message_buffer.agent_status[first_agent]
+ if status == "in_progress":
+ spinner = Spinner(
+ "dots", text="[blue]in_progress[/blue]", style="bold cyan"
+ )
+ status_cell = spinner
+ else:
+ status_color = {
+ "pending": "yellow",
+ "completed": "green",
+ "error": "red",
+ }.get(status, "white")
+ status_cell = f"[{status_color}]{status}[/{status_color}]"
+ progress_table.add_row(team, first_agent, status_cell)
+
+ for agent in agents[1:]:
+ status = message_buffer.agent_status[agent]
+ if status == "in_progress":
+ spinner = Spinner(
+ "dots", text="[blue]in_progress[/blue]", style="bold cyan"
+ )
+ status_cell = spinner
+ else:
+ status_color = {
+ "pending": "yellow",
+ "completed": "green",
+ "error": "red",
+ }.get(status, "white")
+ status_cell = f"[{status_color}]{status}[/{status_color}]"
+ progress_table.add_row("", agent, status_cell)
+
+ progress_table.add_row("-" * 20, "-" * 20, "-" * 20, style="dim")
+
+ layout["progress"].update(
+ Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2))
+ )
+
+ messages_table = Table(
+ show_header=True,
+ header_style="bold magenta",
+ show_footer=False,
+ expand=True,
+ box=box.MINIMAL,
+ show_lines=True,
+ padding=(0, 1),
+ )
+ messages_table.add_column("Time", style="cyan", width=8, justify="center")
+ messages_table.add_column("Type", style="green", width=10, justify="center")
+ messages_table.add_column("Content", style="white", no_wrap=False, ratio=1)
+
+ all_messages = []
+
+ for timestamp, tool_name, args in message_buffer.tool_calls:
+ if isinstance(args, str) and len(args) > 100:
+ args = args[:97] + "..."
+ all_messages.append((timestamp, "Tool", f"{tool_name}: {args}"))
+
+ for timestamp, msg_type, content in message_buffer.messages:
+ content_str = content
+ if isinstance(content, list):
+ text_parts = []
+ for item in content:
+ if isinstance(item, dict):
+ if item.get('type') == 'text':
+ text_parts.append(item.get('text', ''))
+ elif item.get('type') == 'tool_use':
+ text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
+ else:
+ text_parts.append(str(item))
+ content_str = ' '.join(text_parts)
+ elif not isinstance(content_str, str):
+ content_str = str(content)
+
+ if len(content_str) > 200:
+ content_str = content_str[:197] + "..."
+ all_messages.append((timestamp, msg_type, content_str))
+
+ all_messages.sort(key=lambda x: x[0])
+ max_messages = 12
+ recent_messages = all_messages[-max_messages:]
+
+ for timestamp, msg_type, content in recent_messages:
+ wrapped_content = Text(content, overflow="fold")
+ messages_table.add_row(timestamp, msg_type, wrapped_content)
+
+ if spinner_text:
+ messages_table.add_row("", "Spinner", spinner_text)
+
+ if len(all_messages) > max_messages:
+ messages_table.footer = (
+ f"[dim]Showing last {max_messages} of {len(all_messages)} messages[/dim]"
+ )
+
+ layout["messages"].update(
+ Panel(
+ messages_table,
+ title="Messages & Tools",
+ border_style="blue",
+ padding=(1, 2),
+ )
+ )
+
+ if message_buffer.current_report:
+ layout["analysis"].update(
+ Panel(
+ Markdown(message_buffer.current_report),
+ title="Current Report",
+ border_style="green",
+ padding=(1, 2),
+ )
+ )
+ else:
+ layout["analysis"].update(
+ Panel(
+ "[italic]Waiting for analysis report...[/italic]",
+ title="Current Report",
+ border_style="green",
+ padding=(1, 2),
+ )
+ )
+
+ tool_calls_count = len(message_buffer.tool_calls)
+ llm_calls_count = sum(
+ 1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning"
+ )
+ reports_count = sum(
+ 1 for content in message_buffer.report_sections.values() if content is not None
+ )
+
+ stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True)
+ stats_table.add_column("Stats", justify="center")
+ stats_table.add_row(
+ f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}"
+ )
+
+ layout["footer"].update(Panel(stats_table, border_style="grey50"))
+
+
+def get_user_selections():
+ with open("./cli/static/welcome.txt", "r") as f:
+ welcome_ascii = f.read()
+
+ welcome_content = f"{welcome_ascii}\n"
+ welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n"
+ welcome_content += "[bold]Workflow Steps:[/bold]\n"
+ welcome_content += "I. Analyst Team -> II. Research Team -> III. Trader -> IV. Risk Management -> V. Portfolio Management\n\n"
+ welcome_content += "[dim]Built by Tauric Research (https://github.com/TauricResearch)[/dim]"
+
+ welcome_box = Panel(
+ welcome_content,
+ border_style="green",
+ padding=(1, 2),
+ title="Welcome to TradingAgents",
+ subtitle="Multi-Agents LLM Financial Trading Framework",
+ )
+ console.print(Align.center(welcome_box))
+ console.print()
+
+ console.print(
+ create_question_box(
+ "Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY"
+ )
+ )
+ selected_ticker = get_ticker()
+
+ default_date = datetime.datetime.now().strftime("%Y-%m-%d")
+ console.print(
+ create_question_box(
+ "Step 2: Analysis Date",
+ "Enter the analysis date (YYYY-MM-DD)",
+ default_date,
+ )
+ )
+ analysis_date = get_analysis_date()
+
+ console.print(
+ create_question_box(
+ "Step 3: Analysts Team", "Select your LLM analyst agents for the analysis"
+ )
+ )
+ selected_analysts = select_analysts()
+ console.print(
+ f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
+ )
+
+ console.print(
+ create_question_box(
+ "Step 4: Research Depth", "Select your research depth level"
+ )
+ )
+ selected_research_depth = select_research_depth()
+
+ console.print(
+ create_question_box(
+ "Step 5: OpenAI backend", "Select which service to talk to"
+ )
+ )
+ selected_llm_provider, backend_url = select_llm_provider()
+
+ console.print(
+ create_question_box(
+ "Step 6: Thinking Agents", "Select your thinking agents for analysis"
+ )
+ )
+ selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
+ selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
+
+ return {
+ "ticker": selected_ticker,
+ "analysis_date": analysis_date,
+ "analysts": selected_analysts,
+ "research_depth": selected_research_depth,
+ "llm_provider": selected_llm_provider.lower(),
+ "backend_url": backend_url,
+ "shallow_thinker": selected_shallow_thinker,
+ "deep_thinker": selected_deep_thinker,
+ }
+
+
+def get_ticker():
+ return typer.prompt("", default="SPY")
+
+
+def get_analysis_date():
+ while True:
+ date_str = typer.prompt(
+ "", default=datetime.datetime.now().strftime("%Y-%m-%d")
+ )
+ try:
+ analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d")
+ if analysis_date.date() > datetime.datetime.now().date():
+ console.print("[red]Error: Analysis date cannot be in the future[/red]")
+ continue
+ return date_str
+ except ValueError:
+ console.print(
+ "[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]"
+ )
+
+
+def display_complete_report(final_state):
+ console.print("\n[bold green]Complete Analysis Report[/bold green]\n")
+
+ analyst_reports = []
+
+ if final_state.get("market_report"):
+ analyst_reports.append(
+ Panel(
+ Markdown(final_state["market_report"]),
+ title="Market Analyst",
+ border_style="blue",
+ padding=(1, 2),
+ )
+ )
+
+ if final_state.get("sentiment_report"):
+ analyst_reports.append(
+ Panel(
+ Markdown(final_state["sentiment_report"]),
+ title="Social Analyst",
+ border_style="blue",
+ padding=(1, 2),
+ )
+ )
+
+ if final_state.get("news_report"):
+ analyst_reports.append(
+ Panel(
+ Markdown(final_state["news_report"]),
+ title="News Analyst",
+ border_style="blue",
+ padding=(1, 2),
+ )
+ )
+
+ if final_state.get("fundamentals_report"):
+ analyst_reports.append(
+ Panel(
+ Markdown(final_state["fundamentals_report"]),
+ title="Fundamentals Analyst",
+ border_style="blue",
+ padding=(1, 2),
+ )
+ )
+
+ if analyst_reports:
+ console.print(
+ Panel(
+ Columns(analyst_reports, equal=True, expand=True),
+ title="I. Analyst Team Reports",
+ border_style="cyan",
+ padding=(1, 2),
+ )
+ )
+
+ if final_state.get("investment_debate_state"):
+ research_reports = []
+ debate_state = final_state["investment_debate_state"]
+
+ if debate_state.get("bull_history"):
+ research_reports.append(
+ Panel(
+ Markdown(debate_state["bull_history"]),
+ title="Bull Researcher",
+ border_style="blue",
+ padding=(1, 2),
+ )
+ )
+
+ if debate_state.get("bear_history"):
+ research_reports.append(
+ Panel(
+ Markdown(debate_state["bear_history"]),
+ title="Bear Researcher",
+ border_style="blue",
+ padding=(1, 2),
+ )
+ )
+
+ if debate_state.get("judge_decision"):
+ research_reports.append(
+ Panel(
+ Markdown(debate_state["judge_decision"]),
+ title="Research Manager",
+ border_style="blue",
+ padding=(1, 2),
+ )
+ )
+
+ if research_reports:
+ console.print(
+ Panel(
+ Columns(research_reports, equal=True, expand=True),
+ title="II. Research Team Decision",
+ border_style="magenta",
+ padding=(1, 2),
+ )
+ )
+
+ if final_state.get("trader_investment_plan"):
+ console.print(
+ Panel(
+ Panel(
+ Markdown(final_state["trader_investment_plan"]),
+ title="Trader",
+ border_style="blue",
+ padding=(1, 2),
+ ),
+ title="III. Trading Team Plan",
+ border_style="yellow",
+ padding=(1, 2),
+ )
+ )
+
+ if final_state.get("risk_debate_state"):
+ risk_reports = []
+ risk_state = final_state["risk_debate_state"]
+
+ if risk_state.get("risky_history"):
+ risk_reports.append(
+ Panel(
+ Markdown(risk_state["risky_history"]),
+ title="Aggressive Analyst",
+ border_style="blue",
+ padding=(1, 2),
+ )
+ )
+
+ if risk_state.get("safe_history"):
+ risk_reports.append(
+ Panel(
+ Markdown(risk_state["safe_history"]),
+ title="Conservative Analyst",
+ border_style="blue",
+ padding=(1, 2),
+ )
+ )
+
+ if risk_state.get("neutral_history"):
+ risk_reports.append(
+ Panel(
+ Markdown(risk_state["neutral_history"]),
+ title="Neutral Analyst",
+ border_style="blue",
+ padding=(1, 2),
+ )
+ )
+
+ if risk_reports:
+ console.print(
+ Panel(
+ Columns(risk_reports, equal=True, expand=True),
+ title="IV. Risk Management Team Decision",
+ border_style="red",
+ padding=(1, 2),
+ )
+ )
+
+ if risk_state.get("judge_decision"):
+ console.print(
+ Panel(
+ Panel(
+ Markdown(risk_state["judge_decision"]),
+ title="Portfolio Manager",
+ border_style="blue",
+ padding=(1, 2),
+ ),
+ title="V. Portfolio Manager Decision",
+ border_style="green",
+ padding=(1, 2),
+ )
+ )
+
+
+def update_research_team_status(status):
+ research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"]
+ for agent in research_team:
+ message_buffer.update_agent_status(agent, status)
+
+
+def extract_content_string(content):
+ if isinstance(content, str):
+ return content
+ elif isinstance(content, list):
+ text_parts = []
+ for item in content:
+ if isinstance(item, dict):
+ if item.get('type') == 'text':
+ text_parts.append(item.get('text', ''))
+ elif item.get('type') == 'tool_use':
+ text_parts.append(f"[Tool: {item.get('name', 'unknown')}]")
+ else:
+ text_parts.append(str(item))
+ return ' '.join(text_parts)
+ else:
+ return str(content)
+
+
+def run_analysis():
+ selections = get_user_selections()
+
+ config = DEFAULT_CONFIG.copy()
+ config["max_debate_rounds"] = selections["research_depth"]
+ config["max_risk_discuss_rounds"] = selections["research_depth"]
+ config["quick_think_llm"] = selections["shallow_thinker"]
+ config["deep_think_llm"] = selections["deep_thinker"]
+ config["backend_url"] = selections["backend_url"]
+ config["llm_provider"] = selections["llm_provider"].lower()
+
+ graph = TradingAgentsGraph(
+ [analyst.value for analyst in selections["analysts"]], config=config, debug=True
+ )
+
+ results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
+ results_dir.mkdir(parents=True, exist_ok=True)
+ report_dir = results_dir / "reports"
+ report_dir.mkdir(parents=True, exist_ok=True)
+ log_file = results_dir / "message_tool.log"
+ log_file.touch(exist_ok=True)
+
+ def save_message_decorator(obj, func_name):
+ func = getattr(obj, func_name)
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ func(*args, **kwargs)
+ timestamp, message_type, content = obj.messages[-1]
+ content = content.replace("\n", " ")
+ with open(log_file, "a") as f:
+ f.write(f"{timestamp} [{message_type}] {content}\n")
+ return wrapper
+
def save_tool_call_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)
@@ -800,14 +1382,11 @@ def run_analysis():
message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call")
message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section")
- # Now start the display layout
layout = create_layout()
with Live(layout, refresh_per_second=4) as live:
- # Initial display
update_display(layout)
- # Add initial messages
message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}")
message_buffer.add_message(
"System", f"Analysis date: {selections['analysis_date']}"
@@ -818,55 +1397,44 @@ def run_analysis():
)
update_display(layout)
- # Reset agent statuses
for agent in message_buffer.agent_status:
message_buffer.update_agent_status(agent, "pending")
- # Reset report sections
for section in message_buffer.report_sections:
message_buffer.report_sections[section] = None
message_buffer.current_report = None
message_buffer.final_report = None
- # Update agent status to in_progress for the first analyst
first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst"
message_buffer.update_agent_status(first_analyst, "in_progress")
update_display(layout)
- # Create spinner text
spinner_text = (
f"Analyzing {selections['ticker']} on {selections['analysis_date']}..."
)
update_display(layout, spinner_text)
- # Initialize state and get graph args
init_agent_state = graph.propagator.create_initial_state(
selections["ticker"], selections["analysis_date"]
)
args = graph.propagator.get_graph_args()
- # Stream the analysis
trace = []
for chunk in graph.graph.stream(init_agent_state, **args):
if len(chunk["messages"]) > 0:
- # Get the last message from the chunk
last_message = chunk["messages"][-1]
- # Extract message content and type
if hasattr(last_message, "content"):
- content = extract_content_string(last_message.content) # Use the helper function
+ content = extract_content_string(last_message.content)
msg_type = "Reasoning"
else:
content = str(last_message)
msg_type = "System"
- # Add message to buffer
- message_buffer.add_message(msg_type, content)
+ message_buffer.add_message(msg_type, content)
- # If it's a tool call, add it to tool calls
if hasattr(last_message, "tool_calls"):
for tool_call in last_message.tool_calls:
- # Handle both dictionary and object tool calls
if isinstance(tool_call, dict):
message_buffer.add_tool_call(
tool_call["name"], tool_call["args"]
@@ -874,14 +1442,11 @@ def run_analysis():
else:
message_buffer.add_tool_call(tool_call.name, tool_call.args)
- # Update reports and agent status based on chunk content
- # Analyst Team Reports
if "market_report" in chunk and chunk["market_report"]:
message_buffer.update_report_section(
"market_report", chunk["market_report"]
)
message_buffer.update_agent_status("Market Analyst", "completed")
- # Set next analyst to in_progress
if "social" in selections["analysts"]:
message_buffer.update_agent_status(
"Social Analyst", "in_progress"
@@ -892,7 +1457,6 @@ def run_analysis():
"sentiment_report", chunk["sentiment_report"]
)
message_buffer.update_agent_status("Social Analyst", "completed")
- # Set next analyst to in_progress
if "news" in selections["analysts"]:
message_buffer.update_agent_status(
"News Analyst", "in_progress"
@@ -903,7 +1467,6 @@ def run_analysis():
"news_report", chunk["news_report"]
)
message_buffer.update_agent_status("News Analyst", "completed")
- # Set next analyst to in_progress
if "fundamentals" in selections["analysts"]:
message_buffer.update_agent_status(
"Fundamentals Analyst", "in_progress"
@@ -916,70 +1479,54 @@ def run_analysis():
message_buffer.update_agent_status(
"Fundamentals Analyst", "completed"
)
- # Set all research team members to in_progress
update_research_team_status("in_progress")
- # Research Team - Handle Investment Debate State
if (
"investment_debate_state" in chunk
and chunk["investment_debate_state"]
):
debate_state = chunk["investment_debate_state"]
- # Update Bull Researcher status and report
if "bull_history" in debate_state and debate_state["bull_history"]:
- # Keep all research team members in progress
update_research_team_status("in_progress")
- # Extract latest bull response
bull_responses = debate_state["bull_history"].split("\n")
latest_bull = bull_responses[-1] if bull_responses else ""
if latest_bull:
message_buffer.add_message("Reasoning", latest_bull)
- # Update research report with bull's latest analysis
message_buffer.update_report_section(
"investment_plan",
f"### Bull Researcher Analysis\n{latest_bull}",
)
- # Update Bear Researcher status and report
if "bear_history" in debate_state and debate_state["bear_history"]:
- # Keep all research team members in progress
update_research_team_status("in_progress")
- # Extract latest bear response
bear_responses = debate_state["bear_history"].split("\n")
latest_bear = bear_responses[-1] if bear_responses else ""
if latest_bear:
message_buffer.add_message("Reasoning", latest_bear)
- # Update research report with bear's latest analysis
message_buffer.update_report_section(
"investment_plan",
f"{message_buffer.report_sections['investment_plan']}\n\n### Bear Researcher Analysis\n{latest_bear}",
)
- # Update Research Manager status and final decision
if (
"judge_decision" in debate_state
and debate_state["judge_decision"]
):
- # Keep all research team members in progress until final decision
update_research_team_status("in_progress")
message_buffer.add_message(
"Reasoning",
f"Research Manager: {debate_state['judge_decision']}",
)
- # Update research report with final decision
message_buffer.update_report_section(
"investment_plan",
f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}",
)
- # Mark all research team members as completed
update_research_team_status("completed")
- # Set first risk analyst to in_progress
message_buffer.update_agent_status(
"Risky Analyst", "in_progress"
)
- # Trading Team
if (
"trader_investment_plan" in chunk
and chunk["trader_investment_plan"]
@@ -987,14 +1534,11 @@ def run_analysis():
message_buffer.update_report_section(
"trader_investment_plan", chunk["trader_investment_plan"]
)
- # Set first risk analyst to in_progress
message_buffer.update_agent_status("Risky Analyst", "in_progress")
- # Risk Management Team - Handle Risk Debate State
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
risk_state = chunk["risk_debate_state"]
- # Update Risky Analyst status and report
if (
"current_risky_response" in risk_state
and risk_state["current_risky_response"]
@@ -1006,13 +1550,11 @@ def run_analysis():
"Reasoning",
f"Risky Analyst: {risk_state['current_risky_response']}",
)
- # Update risk report with risky analyst's latest analysis only
message_buffer.update_report_section(
"final_trade_decision",
f"### Risky Analyst Analysis\n{risk_state['current_risky_response']}",
)
- # Update Safe Analyst status and report
if (
"current_safe_response" in risk_state
and risk_state["current_safe_response"]
@@ -1024,13 +1566,11 @@ def run_analysis():
"Reasoning",
f"Safe Analyst: {risk_state['current_safe_response']}",
)
- # Update risk report with safe analyst's latest analysis only
message_buffer.update_report_section(
"final_trade_decision",
f"### Safe Analyst Analysis\n{risk_state['current_safe_response']}",
)
- # Update Neutral Analyst status and report
if (
"current_neutral_response" in risk_state
and risk_state["current_neutral_response"]
@@ -1042,13 +1582,11 @@ def run_analysis():
"Reasoning",
f"Neutral Analyst: {risk_state['current_neutral_response']}",
)
- # Update risk report with neutral analyst's latest analysis only
message_buffer.update_report_section(
"final_trade_decision",
f"### Neutral Analyst Analysis\n{risk_state['current_neutral_response']}",
)
- # Update Portfolio Manager status and final decision
if "judge_decision" in risk_state and risk_state["judge_decision"]:
message_buffer.update_agent_status(
"Portfolio Manager", "in_progress"
@@ -1057,12 +1595,10 @@ def run_analysis():
"Reasoning",
f"Portfolio Manager: {risk_state['judge_decision']}",
)
- # Update risk report with final decision only
message_buffer.update_report_section(
"final_trade_decision",
f"### Portfolio Manager Decision\n{risk_state['judge_decision']}",
)
- # Mark risk analysts as completed
message_buffer.update_agent_status("Risky Analyst", "completed")
message_buffer.update_agent_status("Safe Analyst", "completed")
message_buffer.update_agent_status(
@@ -1072,16 +1608,13 @@ def run_analysis():
"Portfolio Manager", "completed"
)
- # Update the display
update_display(layout)
trace.append(chunk)
- # Get final state and decision
final_state = trace[-1]
decision = graph.process_signal(final_state["final_trade_decision"])
- # Update all agent statuses to completed
for agent in message_buffer.agent_status:
message_buffer.update_agent_status(agent, "completed")
@@ -1089,21 +1622,85 @@ def run_analysis():
"Analysis", f"Completed analysis for {selections['analysis_date']}"
)
- # Update final report sections
for section in message_buffer.report_sections.keys():
if section in final_state:
message_buffer.update_report_section(section, final_state[section])
- # Display the complete final report
display_complete_report(final_state)
update_display(layout)
+def show_main_menu():
+ with open("./cli/static/welcome.txt", "r") as f:
+ welcome_ascii = f.read()
+
+ welcome_content = f"{welcome_ascii}\n"
+ welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n"
+ welcome_content += "[bold]Available Options:[/bold]\n"
+ welcome_content += "1. Analyze a specific stock\n"
+ welcome_content += "2. Discover trending stocks\n\n"
+ welcome_content += "[dim]Built by Tauric Research (https://github.com/TauricResearch)[/dim]"
+
+ welcome_box = Panel(
+ welcome_content,
+ border_style="green",
+ padding=(1, 2),
+ title="Welcome to TradingAgents",
+ subtitle="Multi-Agents LLM Financial Trading Framework",
+ )
+ console.print(Align.center(welcome_box))
+ console.print()
+
+ MENU_OPTIONS = [
+ ("1. Analyze a specific stock", "analyze"),
+ ("2. Discover trending stocks", "discover"),
+ ]
+
+ choice = questionary.select(
+ "Select an option:",
+ choices=[
+ questionary.Choice(display, value=value) for display, value in MENU_OPTIONS
+ ],
+ instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
+ style=questionary.Style(
+ [
+ ("selected", "fg:green noinherit"),
+ ("highlighted", "fg:green noinherit"),
+ ("pointer", "fg:green noinherit"),
+ ]
+ ),
+ ).ask()
+
+ if choice is None:
+ console.print("\n[red]No option selected. Exiting...[/red]")
+        raise SystemExit(0)
+
+ return choice
+
+
@app.command()
def analyze():
run_analysis()
+@app.command()
+def discover():
+ discover_trending_flow()
+
+
+@app.command()
+def menu():
+ choice = show_main_menu()
+ if choice == "analyze":
+ run_analysis()
+ elif choice == "discover":
+ discover_trending_flow()
+
+
if __name__ == "__main__":
- app()
+    import sys
+
+    # Keep Typer subcommands reachable (`python cli/main.py analyze`, etc.);
+    # fall back to the interactive menu only on a bare invocation.
+    app() if len(sys.argv) > 1 else menu()
diff --git a/main.py b/main.py
index a85ee6ec..42a45a0d 100644
--- a/main.py
+++ b/main.py
@@ -8,7 +8,7 @@ load_dotenv()
# Create a custom config
config = DEFAULT_CONFIG.copy()
-config["deep_think_llm"] = "gpt-4o-mini" # Use a different model
+config["deep_think_llm"] = "gpt-5" # Use a different model
config["quick_think_llm"] = "gpt-4o-mini" # Use a different model
config["max_debate_rounds"] = 1 # Increase debate rounds
diff --git a/tests/discovery/__init__.py b/tests/discovery/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/discovery/test_api.py b/tests/discovery/test_api.py
new file mode 100644
index 00000000..700f351f
--- /dev/null
+++ b/tests/discovery/test_api.py
@@ -0,0 +1,200 @@
+import pytest
+from unittest.mock import Mock, patch, MagicMock
+from datetime import datetime, timedelta
+import signal
+
+from tradingagents.agents.discovery import (
+ DiscoveryRequest,
+ DiscoveryResult,
+ DiscoveryStatus,
+ TrendingStock,
+ NewsArticle,
+ Sector,
+ EventCategory,
+ DiscoveryTimeoutError,
+)
+
+
+def create_mock_trending_stock(
+ ticker: str = "AAPL",
+ company_name: str = "Apple Inc.",
+ score: float = 10.0,
+ sector: Sector = Sector.TECHNOLOGY,
+ event_type: EventCategory = EventCategory.EARNINGS,
+) -> TrendingStock:
+ return TrendingStock(
+ ticker=ticker,
+ company_name=company_name,
+ score=score,
+ mention_count=5,
+ sentiment=0.5,
+ sector=sector,
+ event_type=event_type,
+ news_summary="Test news summary",
+ source_articles=[],
+ )
+
+
+def create_mock_news_article() -> NewsArticle:
+ return NewsArticle(
+ title="Test Article",
+ source="Test Source",
+ url="https://example.com/article",
+ published_at=datetime.now(),
+ content_snippet="Test content about Apple stock",
+ ticker_mentions=["AAPL"],
+ )
+
+
+class TestDiscoverTrendingReturnsDiscoveryResult:
+ @patch("tradingagents.graph.trading_graph.get_bulk_news")
+ @patch("tradingagents.graph.trading_graph.extract_entities")
+ @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
+ def test_discover_trending_returns_discovery_result(
+ self, mock_scores, mock_extract, mock_bulk_news
+ ):
+ mock_bulk_news.return_value = [create_mock_news_article()]
+ mock_extract.return_value = []
+ mock_scores.return_value = [create_mock_trending_stock()]
+
+ from tradingagents.graph.trading_graph import TradingAgentsGraph
+
+ with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
+ graph = TradingAgentsGraph()
+ graph.config = {
+ "discovery_timeout": 60,
+ "discovery_hard_timeout": 120,
+ "discovery_cache_ttl": 300,
+ "discovery_max_results": 20,
+ "discovery_min_mentions": 2,
+ }
+
+ result = graph.discover_trending()
+
+ assert isinstance(result, DiscoveryResult)
+ assert result.status == DiscoveryStatus.COMPLETED
+ assert len(result.trending_stocks) > 0
+
+
+class TestAnalyzeTrendingCallsPropagate:
+ def test_analyze_trending_calls_propagate(self):
+ from tradingagents.graph.trading_graph import TradingAgentsGraph
+
+ with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
+ graph = TradingAgentsGraph()
+ graph.propagate = Mock(return_value=({"final_state": "test"}, "BUY"))
+
+ trending_stock = create_mock_trending_stock()
+
+ result = graph.analyze_trending(trending_stock)
+
+ graph.propagate.assert_called_once()
+ call_args = graph.propagate.call_args
+ assert call_args[0][0] == "AAPL"
+
+
+class TestSectorFilterParameter:
+ @patch("tradingagents.graph.trading_graph.get_bulk_news")
+ @patch("tradingagents.graph.trading_graph.extract_entities")
+ @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
+ def test_sector_filter_filters_results(
+ self, mock_scores, mock_extract, mock_bulk_news
+ ):
+ mock_bulk_news.return_value = [create_mock_news_article()]
+ mock_extract.return_value = []
+ mock_scores.return_value = [
+ create_mock_trending_stock(ticker="AAPL", sector=Sector.TECHNOLOGY),
+ create_mock_trending_stock(ticker="JPM", sector=Sector.FINANCE),
+ create_mock_trending_stock(ticker="XOM", sector=Sector.ENERGY),
+ ]
+
+ from tradingagents.graph.trading_graph import TradingAgentsGraph
+
+ with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
+ graph = TradingAgentsGraph()
+ graph.config = {
+ "discovery_timeout": 60,
+ "discovery_hard_timeout": 120,
+ "discovery_cache_ttl": 300,
+ "discovery_max_results": 20,
+ "discovery_min_mentions": 2,
+ }
+
+ request = DiscoveryRequest(
+ lookback_period="24h",
+ sector_filter=[Sector.TECHNOLOGY],
+ )
+ result = graph.discover_trending(request)
+
+ assert all(
+ stock.sector == Sector.TECHNOLOGY for stock in result.trending_stocks
+ )
+
+
+class TestEventFilterParameter:
+ @patch("tradingagents.graph.trading_graph.get_bulk_news")
+ @patch("tradingagents.graph.trading_graph.extract_entities")
+ @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
+ def test_event_filter_filters_results(
+ self, mock_scores, mock_extract, mock_bulk_news
+ ):
+ mock_bulk_news.return_value = [create_mock_news_article()]
+ mock_extract.return_value = []
+ mock_scores.return_value = [
+ create_mock_trending_stock(ticker="AAPL", event_type=EventCategory.EARNINGS),
+ create_mock_trending_stock(
+ ticker="MSFT", event_type=EventCategory.PRODUCT_LAUNCH
+ ),
+ create_mock_trending_stock(
+ ticker="GOOGL", event_type=EventCategory.MERGER_ACQUISITION
+ ),
+ ]
+
+ from tradingagents.graph.trading_graph import TradingAgentsGraph
+
+ with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
+ graph = TradingAgentsGraph()
+ graph.config = {
+ "discovery_timeout": 60,
+ "discovery_hard_timeout": 120,
+ "discovery_cache_ttl": 300,
+ "discovery_max_results": 20,
+ "discovery_min_mentions": 2,
+ }
+
+ request = DiscoveryRequest(
+ lookback_period="24h",
+ event_filter=[EventCategory.EARNINGS],
+ )
+ result = graph.discover_trending(request)
+
+ assert all(
+ stock.event_type == EventCategory.EARNINGS
+ for stock in result.trending_stocks
+ )
+
+
+class TestTimeoutHandling:
+ @patch("tradingagents.graph.trading_graph.get_bulk_news")
+ def test_timeout_raises_discovery_timeout_error(self, mock_bulk_news):
+ def slow_fetch(*args, **kwargs):
+ import time
+ time.sleep(0.5)
+ return []
+
+ mock_bulk_news.side_effect = slow_fetch
+
+ from tradingagents.graph.trading_graph import TradingAgentsGraph
+
+ with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
+ graph = TradingAgentsGraph()
+ graph.config = {
+ "discovery_timeout": 60,
+ "discovery_hard_timeout": 0.1,
+ "discovery_cache_ttl": 300,
+ "discovery_max_results": 20,
+ "discovery_min_mentions": 2,
+ }
+
+ with pytest.raises(DiscoveryTimeoutError):
+ graph.discover_trending()
diff --git a/tests/discovery/test_bulk_news.py b/tests/discovery/test_bulk_news.py
new file mode 100644
index 00000000..b6fb3e0e
--- /dev/null
+++ b/tests/discovery/test_bulk_news.py
@@ -0,0 +1,160 @@
+import pytest
+from datetime import datetime, timedelta
+from unittest.mock import patch, MagicMock
+from tradingagents.agents.discovery import NewsArticle
+from tradingagents.dataflows.alpha_vantage_common import AlphaVantageRateLimitError
+
+
+class TestGetBulkNewsReturnsNewsArticles:
+ def test_get_bulk_news_returns_list_of_news_article_objects(self):
+ mock_raw_news = [
+ {
+ "title": "Market Update: Tech stocks rally",
+ "source": "Reuters",
+ "url": "https://reuters.com/market-update",
+ "published_at": datetime.now().isoformat(),
+ "content_snippet": "Technology stocks led gains in early trading...",
+ },
+ {
+ "title": "Fed signals rate decision",
+ "source": "Bloomberg",
+ "url": "https://bloomberg.com/fed-rates",
+ "published_at": datetime.now().isoformat(),
+ "content_snippet": "Federal Reserve officials indicated...",
+ },
+ ]
+
+ from tradingagents.dataflows.interface import (
+ _bulk_news_cache,
+ get_bulk_news,
+ )
+
+ _bulk_news_cache.clear()
+
+ with patch(
+ "tradingagents.dataflows.interface._fetch_bulk_news_from_vendor"
+ ) as mock_fetch:
+ mock_fetch.return_value = mock_raw_news
+
+ result = get_bulk_news(lookback_period="24h")
+
+ assert isinstance(result, list)
+ assert len(result) == 2
+ for article in result:
+ assert isinstance(article, NewsArticle)
+ assert article.title is not None
+ assert article.source is not None
+ assert article.url is not None
+
+
+class TestLookbackPeriodParsing:
+ @pytest.mark.parametrize(
+ "lookback,expected_hours",
+ [
+ ("1h", 1),
+ ("6h", 6),
+ ("24h", 24),
+ ("7d", 168),
+ ],
+ )
+ def test_lookback_period_parsing(self, lookback, expected_hours):
+ from tradingagents.dataflows.interface import parse_lookback_period
+
+ hours = parse_lookback_period(lookback)
+ assert hours == expected_hours
+
+ def test_invalid_lookback_period_raises_error(self):
+ from tradingagents.dataflows.interface import parse_lookback_period
+
+ with pytest.raises(ValueError):
+ parse_lookback_period("invalid")
+
+
+class TestVendorFallback:
+ def test_vendor_fallback_when_primary_rate_limited(self):
+ mock_openai_news = [
+ {
+ "title": "Fallback news from OpenAI",
+ "source": "Web Search",
+ "url": "https://example.com/fallback",
+ "published_at": datetime.now().isoformat(),
+ "content_snippet": "This is fallback content...",
+ },
+ ]
+
+ from tradingagents.dataflows.interface import (
+ _bulk_news_cache,
+ )
+
+ _bulk_news_cache.clear()
+
+ with patch(
+ "tradingagents.dataflows.interface.VENDOR_METHODS",
+ {
+ "get_bulk_news": {
+ "alpha_vantage": MagicMock(side_effect=AlphaVantageRateLimitError("Rate limit")),
+ "openai": MagicMock(return_value=mock_openai_news),
+ "google": MagicMock(return_value=[]),
+ }
+ }
+ ):
+ from tradingagents.dataflows.interface import _fetch_bulk_news_from_vendor
+
+ result = _fetch_bulk_news_from_vendor("24h")
+
+ assert isinstance(result, list)
+ assert len(result) == 1
+ assert result[0]["title"] == "Fallback news from OpenAI"
+
+
+class TestBulkNewsCache:
+ def test_cache_returns_same_results_within_ttl(self):
+ from tradingagents.dataflows.interface import (
+ _bulk_news_cache,
+ _get_cached_bulk_news,
+ _set_cached_bulk_news,
+ )
+
+ _bulk_news_cache.clear()
+
+ test_articles = [
+ NewsArticle(
+ title="Cached article",
+ source="Test Source",
+ url="https://test.com/cached",
+ published_at=datetime.now(),
+ content_snippet="Cached content...",
+ ticker_mentions=[],
+ )
+ ]
+
+ _set_cached_bulk_news("24h", test_articles)
+
+ cached_result = _get_cached_bulk_news("24h")
+ assert cached_result is not None
+ assert len(cached_result) == 1
+ assert cached_result[0].title == "Cached article"
+
+ cached_result_again = _get_cached_bulk_news("24h")
+ assert cached_result_again is not None
+ assert cached_result_again[0].title == cached_result[0].title
+
+
+class TestEmptyResultHandling:
+ def test_empty_result_handling(self):
+ from tradingagents.dataflows.interface import (
+ _bulk_news_cache,
+ get_bulk_news,
+ )
+
+ _bulk_news_cache.clear()
+
+ with patch(
+ "tradingagents.dataflows.interface._fetch_bulk_news_from_vendor"
+ ) as mock_fetch:
+ mock_fetch.return_value = []
+
+ result = get_bulk_news(lookback_period="1h")
+
+ assert isinstance(result, list)
+ assert len(result) == 0
diff --git a/tests/discovery/test_cli.py b/tests/discovery/test_cli.py
new file mode 100644
index 00000000..5b0c56d7
--- /dev/null
+++ b/tests/discovery/test_cli.py
@@ -0,0 +1,127 @@
+import pytest
+from unittest.mock import Mock, patch, MagicMock
+from datetime import datetime
+from io import StringIO
+
+from tradingagents.agents.discovery.models import (
+ DiscoveryResult,
+ DiscoveryRequest,
+ DiscoveryStatus,
+ TrendingStock,
+ NewsArticle,
+ Sector,
+ EventCategory,
+)
+
+
+@pytest.fixture
+def sample_trending_stocks():
+ article = NewsArticle(
+ title="Apple announces new iPhone",
+ source="Reuters",
+ url="https://reuters.com/article",
+ published_at=datetime.now(),
+ content_snippet="Apple Inc. unveiled its latest iPhone model today...",
+ ticker_mentions=["AAPL"],
+ )
+ return [
+ TrendingStock(
+ ticker="AAPL",
+ company_name="Apple Inc.",
+ score=8.5,
+ mention_count=10,
+ sentiment=0.7,
+ sector=Sector.TECHNOLOGY,
+ event_type=EventCategory.PRODUCT_LAUNCH,
+ news_summary="Apple announced new iPhone model with enhanced AI capabilities.",
+ source_articles=[article],
+ ),
+ TrendingStock(
+ ticker="MSFT",
+ company_name="Microsoft Corporation",
+ score=7.2,
+ mention_count=8,
+ sentiment=0.5,
+ sector=Sector.TECHNOLOGY,
+ event_type=EventCategory.EARNINGS,
+ news_summary="Microsoft reported strong quarterly earnings.",
+ source_articles=[article],
+ ),
+ TrendingStock(
+ ticker="NVDA",
+ company_name="NVIDIA Corporation",
+ score=6.8,
+ mention_count=6,
+ sentiment=0.4,
+ sector=Sector.TECHNOLOGY,
+ event_type=EventCategory.PRODUCT_LAUNCH,
+ news_summary="NVIDIA unveiled new AI chips.",
+ source_articles=[article],
+ ),
+ ]
+
+
+@pytest.fixture
+def sample_discovery_result(sample_trending_stocks):
+ request = DiscoveryRequest(
+ lookback_period="24h",
+ max_results=20,
+ )
+ return DiscoveryResult(
+ request=request,
+ trending_stocks=sample_trending_stocks,
+ status=DiscoveryStatus.COMPLETED,
+ started_at=datetime.now(),
+ completed_at=datetime.now(),
+ )
+
+
+class TestDiscoveryMenuOption:
+ def test_discover_trending_flow_exists(self):
+ from cli.main import discover_trending_flow
+ assert callable(discover_trending_flow)
+
+ def test_select_lookback_period_function_exists(self):
+ from cli.main import select_lookback_period
+ assert callable(select_lookback_period)
+
+
+class TestLookbackSelection:
+ @patch("cli.main.questionary.select")
+ def test_lookback_selection_returns_valid_period(self, mock_select):
+ mock_select.return_value.ask.return_value = "24h"
+ from cli.main import select_lookback_period
+ result = select_lookback_period()
+ assert result in ["1h", "6h", "24h", "7d"]
+
+ @patch("cli.main.questionary.select")
+ def test_lookback_selection_handles_all_options(self, mock_select):
+ from cli.main import select_lookback_period
+ for period in ["1h", "6h", "24h", "7d"]:
+ mock_select.return_value.ask.return_value = period
+ result = select_lookback_period()
+ assert result == period
+
+
+class TestResultsTableDisplay:
+ def test_create_discovery_results_table(self, sample_trending_stocks):
+ from cli.main import create_discovery_results_table
+ table = create_discovery_results_table(sample_trending_stocks)
+ assert table is not None
+ assert table.row_count == len(sample_trending_stocks)
+
+ def test_table_has_correct_columns(self, sample_trending_stocks):
+ from cli.main import create_discovery_results_table
+ table = create_discovery_results_table(sample_trending_stocks)
+ column_names = [col.header for col in table.columns]
+ expected_columns = ["Rank", "Ticker", "Company", "Score", "Mentions", "Event Type"]
+ for expected in expected_columns:
+ assert expected in column_names
+
+
+class TestDetailView:
+ def test_create_stock_detail_panel(self, sample_trending_stocks):
+ from cli.main import create_stock_detail_panel
+ stock = sample_trending_stocks[0]
+ panel = create_stock_detail_panel(stock, rank=1)
+ assert panel is not None
diff --git a/tests/discovery/test_entity_extractor.py b/tests/discovery/test_entity_extractor.py
new file mode 100644
index 00000000..57f9f82b
--- /dev/null
+++ b/tests/discovery/test_entity_extractor.py
@@ -0,0 +1,278 @@
+import pytest
+from datetime import datetime
+from unittest.mock import patch, MagicMock
+from tradingagents.agents.discovery import NewsArticle, EventCategory
+
+
+class TestExtractEntitiesReturnsCompanyMentions:
+ def test_extract_entities_returns_list_of_company_mentions(self):
+ from tradingagents.agents.discovery.entity_extractor import (
+ extract_entities,
+ EntityMention,
+ )
+
+ articles = [
+ NewsArticle(
+ title="Apple announces new iPhone",
+ source="Reuters",
+ url="https://reuters.com/apple",
+ published_at=datetime.now(),
+ content_snippet="Apple Inc unveiled its latest iPhone model today with advanced AI features.",
+ ticker_mentions=[],
+ ),
+ ]
+
+ mock_response = MagicMock()
+ mock_response.entities = [
+ MagicMock(
+ company_name="Apple Inc",
+ confidence=0.95,
+ context_snippet="Apple Inc unveiled its latest iPhone",
+ event_type="product_launch",
+ sentiment=0.7,
+ )
+ ]
+
+ with patch(
+ "tradingagents.agents.discovery.entity_extractor._get_llm"
+ ) as mock_get_llm:
+ mock_llm = MagicMock()
+ mock_llm.with_structured_output.return_value.invoke.return_value = (
+ mock_response
+ )
+ mock_get_llm.return_value = mock_llm
+
+ result = extract_entities(articles)
+
+ assert isinstance(result, list)
+ assert len(result) > 0
+ assert all(isinstance(m, EntityMention) for m in result)
+ assert result[0].company_name == "Apple Inc"
+
+
+class TestConfidenceScoreRange:
+ def test_confidence_score_in_valid_range(self):
+ from tradingagents.agents.discovery.entity_extractor import (
+ extract_entities,
+ EntityMention,
+ )
+
+ articles = [
+ NewsArticle(
+ title="Tesla reports earnings",
+ source="Bloomberg",
+ url="https://bloomberg.com/tsla",
+ published_at=datetime.now(),
+ content_snippet="Tesla Inc reported strong quarterly earnings beating analyst expectations.",
+ ticker_mentions=[],
+ ),
+ ]
+
+ mock_response = MagicMock()
+ mock_response.entities = [
+ MagicMock(
+ company_name="Tesla Inc",
+ confidence=0.88,
+ context_snippet="Tesla Inc reported strong quarterly earnings",
+ event_type="earnings",
+ sentiment=0.5,
+ )
+ ]
+
+ with patch(
+ "tradingagents.agents.discovery.entity_extractor._get_llm"
+ ) as mock_get_llm:
+ mock_llm = MagicMock()
+ mock_llm.with_structured_output.return_value.invoke.return_value = (
+ mock_response
+ )
+ mock_get_llm.return_value = mock_llm
+
+ result = extract_entities(articles)
+
+ for mention in result:
+ assert 0.0 <= mention.confidence <= 1.0
+
+
+class TestContextSnippetExtraction:
+ def test_context_snippet_extraction(self):
+ from tradingagents.agents.discovery.entity_extractor import (
+ extract_entities,
+ EntityMention,
+ )
+
+ articles = [
+ NewsArticle(
+ title="Microsoft acquires gaming company",
+ source="WSJ",
+ url="https://wsj.com/msft",
+ published_at=datetime.now(),
+ content_snippet="Microsoft Corporation announced today it will acquire a major gaming studio for $10 billion.",
+ ticker_mentions=[],
+ ),
+ ]
+
+ mock_response = MagicMock()
+ mock_response.entities = [
+ MagicMock(
+ company_name="Microsoft Corporation",
+ confidence=0.92,
+ context_snippet="Microsoft Corporation announced today it will acquire",
+ event_type="merger_acquisition",
+ sentiment=0.6,
+ )
+ ]
+
+ with patch(
+ "tradingagents.agents.discovery.entity_extractor._get_llm"
+ ) as mock_get_llm:
+ mock_llm = MagicMock()
+ mock_llm.with_structured_output.return_value.invoke.return_value = (
+ mock_response
+ )
+ mock_get_llm.return_value = mock_llm
+
+ result = extract_entities(articles)
+
+ assert len(result) > 0
+ for mention in result:
+ assert mention.context_snippet is not None
+ assert len(mention.context_snippet) > 0
+ assert len(mention.context_snippet) <= 150
+
+
+class TestBatchProcessing:
+ def test_batch_processing_of_multiple_articles(self):
+ from tradingagents.agents.discovery.entity_extractor import (
+ extract_entities,
+ EntityMention,
+ BATCH_SIZE,
+ )
+
+ articles = [
+ NewsArticle(
+ title=f"News article {i}",
+ source="Reuters",
+ url=f"https://reuters.com/article{i}",
+ published_at=datetime.now(),
+ content_snippet=f"Company {i} announced major developments today.",
+ ticker_mentions=[],
+ )
+ for i in range(15)
+ ]
+
+ mock_response = MagicMock()
+ mock_response.entities = [
+ MagicMock(
+ company_name="Test Company",
+ confidence=0.85,
+ context_snippet="Company announced major developments",
+ event_type="other",
+ sentiment=0.0,
+ )
+ ]
+
+ with patch(
+ "tradingagents.agents.discovery.entity_extractor._get_llm"
+ ) as mock_get_llm:
+ mock_llm = MagicMock()
+ structured_llm = MagicMock()
+ structured_llm.invoke.return_value = mock_response
+ mock_llm.with_structured_output.return_value = structured_llm
+ mock_get_llm.return_value = mock_llm
+
+ result = extract_entities(articles)
+
+ expected_batches = (len(articles) + BATCH_SIZE - 1) // BATCH_SIZE
+ assert structured_llm.invoke.call_count == expected_batches
+
+
+class TestNoCompanyMentions:
+ def test_handling_of_articles_with_no_company_mentions(self):
+ from tradingagents.agents.discovery.entity_extractor import (
+ extract_entities,
+ EntityMention,
+ )
+
+ articles = [
+ NewsArticle(
+ title="Weather forecast for tomorrow",
+ source="Weather Channel",
+ url="https://weather.com/forecast",
+ published_at=datetime.now(),
+ content_snippet="Tomorrow will be sunny with temperatures reaching 75 degrees.",
+ ticker_mentions=[],
+ ),
+ ]
+
+ mock_response = MagicMock()
+ mock_response.entities = []
+
+ with patch(
+ "tradingagents.agents.discovery.entity_extractor._get_llm"
+ ) as mock_get_llm:
+ mock_llm = MagicMock()
+ mock_llm.with_structured_output.return_value.invoke.return_value = (
+ mock_response
+ )
+ mock_get_llm.return_value = mock_llm
+
+ result = extract_entities(articles)
+
+ assert isinstance(result, list)
+ assert len(result) == 0
+
+
+class TestEventTypeClassification:
+ @pytest.mark.parametrize(
+ "event_type",
+ [
+ "earnings",
+ "merger_acquisition",
+ "regulatory",
+ "product_launch",
+ "executive_change",
+ "other",
+ ],
+ )
+ def test_event_type_classification(self, event_type):
+ from tradingagents.agents.discovery.entity_extractor import (
+ extract_entities,
+ EntityMention,
+ )
+
+ articles = [
+ NewsArticle(
+ title="Company news",
+ source="Reuters",
+ url="https://reuters.com/news",
+ published_at=datetime.now(),
+ content_snippet="A company made an announcement today.",
+ ticker_mentions=[],
+ ),
+ ]
+
+ mock_response = MagicMock()
+ mock_response.entities = [
+ MagicMock(
+ company_name="Test Company",
+ confidence=0.90,
+ context_snippet="A company made an announcement",
+ event_type=event_type,
+ sentiment=0.0,
+ )
+ ]
+
+ with patch(
+ "tradingagents.agents.discovery.entity_extractor._get_llm"
+ ) as mock_get_llm:
+ mock_llm = MagicMock()
+ mock_llm.with_structured_output.return_value.invoke.return_value = (
+ mock_response
+ )
+ mock_get_llm.return_value = mock_llm
+
+ result = extract_entities(articles)
+
+ assert len(result) > 0
+ assert result[0].event_type == EventCategory(event_type)
diff --git a/tests/discovery/test_integration.py b/tests/discovery/test_integration.py
new file mode 100644
index 00000000..6adba188
--- /dev/null
+++ b/tests/discovery/test_integration.py
@@ -0,0 +1,489 @@
+import pytest
+import math
+from datetime import datetime, timedelta
+from unittest.mock import patch, MagicMock
+from tradingagents.agents.discovery import (
+ TrendingStock,
+ NewsArticle,
+ DiscoveryRequest,
+ DiscoveryResult,
+ DiscoveryStatus,
+ Sector,
+ EventCategory,
+ DiscoveryTimeoutError,
+ NewsUnavailableError,
+)
+from tradingagents.agents.discovery.entity_extractor import EntityMention
+
+
+class TestEndToEndDiscoveryFlow:
+ @patch("tradingagents.graph.trading_graph.get_bulk_news")
+ @patch("tradingagents.graph.trading_graph.extract_entities")
+ @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
+ def test_full_discovery_flow_from_news_to_results(
+ self, mock_scores, mock_extract, mock_bulk_news
+ ):
+ now = datetime.now()
+ mock_articles = [
+ NewsArticle(
+ title="Apple announces record earnings",
+ source="Reuters",
+ url="https://reuters.com/apple-earnings",
+ published_at=now - timedelta(hours=2),
+ content_snippet="Apple Inc reported record quarterly earnings...",
+ ticker_mentions=["AAPL"],
+ ),
+ NewsArticle(
+ title="Apple stock surges on AI news",
+ source="Bloomberg",
+ url="https://bloomberg.com/apple-ai",
+ published_at=now - timedelta(hours=1),
+ content_snippet="Shares of Apple jumped after AI announcement...",
+ ticker_mentions=["AAPL"],
+ ),
+ ]
+ mock_bulk_news.return_value = mock_articles
+
+ mock_mentions = [
+ EntityMention(
+ company_name="Apple Inc",
+ confidence=0.95,
+ context_snippet="Apple Inc reported record quarterly earnings",
+ article_id="article_0",
+ event_type=EventCategory.EARNINGS,
+ sentiment=0.8,
+ ),
+ EntityMention(
+ company_name="Apple",
+ confidence=0.92,
+ context_snippet="Shares of Apple jumped",
+ article_id="article_1",
+ event_type=EventCategory.PRODUCT_LAUNCH,
+ sentiment=0.7,
+ ),
+ ]
+ mock_extract.return_value = mock_mentions
+
+ mock_trending = [
+ TrendingStock(
+ ticker="AAPL",
+ company_name="Apple Inc.",
+ score=8.5,
+ mention_count=2,
+ sentiment=0.75,
+ sector=Sector.TECHNOLOGY,
+ event_type=EventCategory.EARNINGS,
+ news_summary="Apple reported record earnings and AI progress.",
+ source_articles=mock_articles,
+ ),
+ ]
+ mock_scores.return_value = mock_trending
+
+ from tradingagents.graph.trading_graph import TradingAgentsGraph
+
+ with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
+ graph = TradingAgentsGraph()
+ graph.config = {
+ "discovery_timeout": 60,
+ "discovery_hard_timeout": 120,
+ "discovery_cache_ttl": 300,
+ "discovery_max_results": 20,
+ "discovery_min_mentions": 2,
+ }
+
+ request = DiscoveryRequest(lookback_period="24h")
+ result = graph.discover_trending(request)
+
+ assert isinstance(result, DiscoveryResult)
+ assert result.status == DiscoveryStatus.COMPLETED
+ assert len(result.trending_stocks) == 1
+ assert result.trending_stocks[0].ticker == "AAPL"
+ assert result.trending_stocks[0].mention_count >= 2
+
+ mock_bulk_news.assert_called_once_with("24h")
+ mock_extract.assert_called_once()
+ mock_scores.assert_called_once()
+
+
+class TestEntityExtractionToScoringPipeline:
+ def test_pipeline_from_extraction_to_scoring(self):
+ from tradingagents.agents.discovery.scorer import calculate_trending_scores
+
+ now = datetime.now()
+ articles = [
+ NewsArticle(
+ title="Microsoft cloud revenue grows",
+ source="WSJ",
+ url="https://wsj.com/article1",
+ published_at=now - timedelta(hours=2),
+ content_snippet="Microsoft Corporation reported strong cloud growth.",
+ ticker_mentions=["MSFT"],
+ ),
+ NewsArticle(
+ title="Microsoft earnings beat estimates",
+ source="CNBC",
+ url="https://cnbc.com/article2",
+ published_at=now - timedelta(hours=3),
+ content_snippet="Microsoft earnings exceeded analyst expectations.",
+ ticker_mentions=["MSFT"],
+ ),
+ NewsArticle(
+ title="Tech stocks rally",
+ source="Bloomberg",
+ url="https://bloomberg.com/article3",
+ published_at=now - timedelta(hours=1),
+ content_snippet="Technology companies led market gains.",
+ ticker_mentions=[],
+ ),
+ ]
+
+ mentions = [
+ EntityMention(
+ company_name="Microsoft Corporation",
+ confidence=0.95,
+ context_snippet="Microsoft Corporation reported strong cloud growth",
+ article_id="article_0",
+ event_type=EventCategory.EARNINGS,
+ sentiment=0.7,
+ ),
+ EntityMention(
+ company_name="Microsoft",
+ confidence=0.92,
+ context_snippet="Microsoft earnings exceeded analyst expectations",
+ article_id="article_1",
+ event_type=EventCategory.EARNINGS,
+ sentiment=0.8,
+ ),
+ ]
+
+ with patch("tradingagents.agents.discovery.scorer.resolve_ticker") as mock_resolve:
+ mock_resolve.return_value = "MSFT"
+
+ with patch("tradingagents.agents.discovery.scorer.classify_sector") as mock_sector:
+ mock_sector.return_value = "technology"
+
+ result = calculate_trending_scores(mentions, articles, min_mentions=2)
+
+ assert len(result) == 1
+ assert result[0].ticker == "MSFT"
+ assert result[0].mention_count == 2
+ assert result[0].sentiment > 0
+
+
+class TestNewsVendorFailureGracefulDegradation:
+ @patch("tradingagents.graph.trading_graph.get_bulk_news")
+ def test_news_vendor_failure_with_graceful_degradation(self, mock_bulk_news):
+ mock_bulk_news.side_effect = NewsUnavailableError("All news vendors failed")
+
+ from tradingagents.graph.trading_graph import TradingAgentsGraph
+
+ with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
+ graph = TradingAgentsGraph()
+ graph.config = {
+ "discovery_timeout": 60,
+ "discovery_hard_timeout": 120,
+ "discovery_cache_ttl": 300,
+ "discovery_max_results": 20,
+ "discovery_min_mentions": 2,
+ }
+
+ result = graph.discover_trending()
+
+ assert result.status == DiscoveryStatus.FAILED
+ assert result.error_message is not None
+ assert "news" in result.error_message.lower() or "vendor" in result.error_message.lower()
+
+
+class TestTimeoutHandlingWithPartialResults:
+ @patch("tradingagents.graph.trading_graph.get_bulk_news")
+ def test_timeout_handling_returns_error(self, mock_bulk_news):
+ def slow_fetch(*args, **kwargs):
+ import time
+ time.sleep(0.3)
+ return []
+
+ mock_bulk_news.side_effect = slow_fetch
+
+ from tradingagents.graph.trading_graph import TradingAgentsGraph
+
+ with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
+ graph = TradingAgentsGraph()
+ graph.config = {
+ "discovery_timeout": 60,
+ "discovery_hard_timeout": 0.1,
+ "discovery_cache_ttl": 300,
+ "discovery_max_results": 20,
+ "discovery_min_mentions": 2,
+ }
+
+ with pytest.raises(DiscoveryTimeoutError):
+ graph.discover_trending()
+
+
class TestNoTrendingStocksFound:
    """Discovery completes (not fails) when the news yields no scored stocks."""

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    def test_no_trending_stocks_found_returns_empty_list(
        self, mock_scores, mock_extract, mock_bulk_news
    ):
        """An article with no ticker mentions yields COMPLETED + empty stock list."""
        # One generic market article that names no tickers at all.
        mock_bulk_news.return_value = [
            NewsArticle(
                title="General market update",
                source="Reuters",
                url="https://reuters.com/general",
                published_at=datetime.now(),
                content_snippet="Markets were quiet today with no major news.",
                ticker_mentions=[],
            ),
        ]
        mock_extract.return_value = []
        mock_scores.return_value = []

        from tradingagents.graph.trading_graph import TradingAgentsGraph

        # Skip the heavyweight real __init__; only config is needed here.
        with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
            graph = TradingAgentsGraph()
            graph.config = {
                "discovery_timeout": 60,
                "discovery_hard_timeout": 120,
                "discovery_cache_ttl": 300,
                "discovery_max_results": 20,
                "discovery_min_mentions": 2,
            }

            result = graph.discover_trending()

            assert result.status == DiscoveryStatus.COMPLETED
            assert len(result.trending_stocks) == 0
+
+
class TestAllStocksFilteredOutBySectorFilter:
    """A sector filter matching none of the scored stocks yields an empty result."""

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    def test_all_stocks_filtered_out_by_sector_filter(
        self, mock_scores, mock_extract, mock_bulk_news
    ):
        """Two TECHNOLOGY stocks against a HEALTHCARE-only filter -> zero survivors."""
        mock_bulk_news.return_value = []
        mock_extract.return_value = []
        # Scorer returns only technology names; the request below filters healthcare.
        mock_scores.return_value = [
            TrendingStock(
                ticker="AAPL",
                company_name="Apple Inc.",
                score=10.0,
                mention_count=5,
                sentiment=0.5,
                sector=Sector.TECHNOLOGY,
                event_type=EventCategory.EARNINGS,
                news_summary="Apple earnings",
                source_articles=[],
            ),
            TrendingStock(
                ticker="MSFT",
                company_name="Microsoft",
                score=9.0,
                mention_count=4,
                sentiment=0.4,
                sector=Sector.TECHNOLOGY,
                event_type=EventCategory.PRODUCT_LAUNCH,
                news_summary="Microsoft product",
                source_articles=[],
            ),
        ]

        from tradingagents.graph.trading_graph import TradingAgentsGraph

        # Bypass the real __init__ and inject a minimal discovery config.
        with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
            graph = TradingAgentsGraph()
            graph.config = {
                "discovery_timeout": 60,
                "discovery_hard_timeout": 120,
                "discovery_cache_ttl": 300,
                "discovery_max_results": 20,
                "discovery_min_mentions": 2,
            }

            request = DiscoveryRequest(
                lookback_period="24h",
                sector_filter=[Sector.HEALTHCARE],
            )
            result = graph.discover_trending(request)

            assert result.status == DiscoveryStatus.COMPLETED
            assert len(result.trending_stocks) == 0
+
+
class TestAllStocksFilteredOutByEventFilter:
    """An event filter matching none of the scored stocks yields an empty result."""

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    def test_all_stocks_filtered_out_by_event_filter(
        self, mock_scores, mock_extract, mock_bulk_news
    ):
        """An EARNINGS stock against a MERGER_ACQUISITION-only filter -> zero survivors."""
        mock_bulk_news.return_value = []
        mock_extract.return_value = []
        # Single earnings-event stock; the request below filters for M&A only.
        mock_scores.return_value = [
            TrendingStock(
                ticker="AAPL",
                company_name="Apple Inc.",
                score=10.0,
                mention_count=5,
                sentiment=0.5,
                sector=Sector.TECHNOLOGY,
                event_type=EventCategory.EARNINGS,
                news_summary="Apple earnings",
                source_articles=[],
            ),
        ]

        from tradingagents.graph.trading_graph import TradingAgentsGraph

        # Bypass the real __init__ and inject a minimal discovery config.
        with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
            graph = TradingAgentsGraph()
            graph.config = {
                "discovery_timeout": 60,
                "discovery_hard_timeout": 120,
                "discovery_cache_ttl": 300,
                "discovery_max_results": 20,
                "discovery_min_mentions": 2,
            }

            request = DiscoveryRequest(
                lookback_period="24h",
                event_filter=[EventCategory.MERGER_ACQUISITION],
            )
            result = graph.discover_trending(request)

            assert result.status == DiscoveryStatus.COMPLETED
            assert len(result.trending_stocks) == 0
+
+
class TestMultipleSectorsAndEventsFiltering:
    """Sector and event filters applied together keep only stocks matching both."""

    @patch("tradingagents.graph.trading_graph.get_bulk_news")
    @patch("tradingagents.graph.trading_graph.extract_entities")
    @patch("tradingagents.graph.trading_graph.calculate_trending_scores")
    def test_combined_sector_and_event_filtering(
        self, mock_scores, mock_extract, mock_bulk_news
    ):
        """AAPL/JPM pass (tech|finance + earnings); XOM (energy/regulatory) is dropped."""
        def _stock(ticker, name, score, count, sent, sector, event, summary):
            # Compact builder for the scored-stock fixtures below.
            return TrendingStock(
                ticker=ticker,
                company_name=name,
                score=score,
                mention_count=count,
                sentiment=sent,
                sector=sector,
                event_type=event,
                news_summary=summary,
                source_articles=[],
            )

        mock_bulk_news.return_value = []
        mock_extract.return_value = []
        mock_scores.return_value = [
            _stock("AAPL", "Apple Inc.", 10.0, 5, 0.5,
                   Sector.TECHNOLOGY, EventCategory.EARNINGS, "Apple earnings"),
            _stock("JPM", "JPMorgan Chase", 9.0, 4, 0.4,
                   Sector.FINANCE, EventCategory.EARNINGS, "JPM earnings"),
            _stock("XOM", "Exxon Mobil", 8.0, 3, 0.3,
                   Sector.ENERGY, EventCategory.REGULATORY, "XOM regulatory news"),
        ]

        from tradingagents.graph.trading_graph import TradingAgentsGraph

        with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None):
            graph = TradingAgentsGraph()
            graph.config = {
                "discovery_timeout": 60,
                "discovery_hard_timeout": 120,
                "discovery_cache_ttl": 300,
                "discovery_max_results": 20,
                "discovery_min_mentions": 2,
            }

            request = DiscoveryRequest(
                lookback_period="24h",
                sector_filter=[Sector.TECHNOLOGY, Sector.FINANCE],
                event_filter=[EventCategory.EARNINGS],
            )
            result = graph.discover_trending(request)

            assert result.status == DiscoveryStatus.COMPLETED
            assert len(result.trending_stocks) == 2
            surviving = {stock.ticker for stock in result.trending_stocks}
            assert "AAPL" in surviving
            assert "JPM" in surviving
            assert "XOM" not in surviving
+
+
class TestDiscoveryResultPersistenceIntegration:
    """End-to-end check that a DiscoveryResult can be persisted and summarized."""

    def test_discovery_result_can_be_serialized_and_saved(self):
        """save_discovery_result writes JSON + markdown; the summary names the stock."""
        from tradingagents.agents.discovery.persistence import (
            save_discovery_result,
            generate_markdown_summary,
        )
        import tempfile
        from pathlib import Path

        article = NewsArticle(
            title="Test article",
            source="Test",
            url="https://test.com",
            published_at=datetime.now(),
            content_snippet="Test content",
            ticker_mentions=["TEST"],
        )

        stock = TrendingStock(
            ticker="TEST",
            company_name="Test Company",
            score=5.0,
            mention_count=2,
            sentiment=0.5,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.OTHER,
            news_summary="Test news summary",
            source_articles=[article],
        )

        request = DiscoveryRequest(
            lookback_period="24h",
            created_at=datetime.now(),
        )

        result = DiscoveryResult(
            request=request,
            trending_stocks=[stock],
            status=DiscoveryStatus.COMPLETED,
            started_at=datetime.now(),
            completed_at=datetime.now(),
        )

        # TemporaryDirectory guarantees cleanup even if an assertion fails,
        # replacing the manual mkdtemp / try-finally / rmtree pattern.
        with tempfile.TemporaryDirectory() as temp_dir:
            path = save_discovery_result(result, base_path=Path(temp_dir))
            assert path.exists()
            assert (path / "discovery_result.json").exists()
            assert (path / "discovery_summary.md").exists()

            markdown = generate_markdown_summary(result)
            assert "TEST" in markdown
            assert "Test Company" in markdown
diff --git a/tests/discovery/test_models.py b/tests/discovery/test_models.py
new file mode 100644
index 00000000..7717d022
--- /dev/null
+++ b/tests/discovery/test_models.py
@@ -0,0 +1,196 @@
+import pytest
+from datetime import datetime
+from tradingagents.agents.discovery import (
+ TrendingStock,
+ NewsArticle,
+ DiscoveryRequest,
+ DiscoveryResult,
+ Sector,
+ EventCategory,
+)
+from tradingagents.agents.discovery.models import DiscoveryStatus
+
+
class TestTrendingStock:
    """Construction of TrendingStock with a linked source article."""

    def test_trending_stock_creation_and_validation(self):
        """Every constructor field must round-trip unchanged through the model."""
        source_article = NewsArticle(
            title="Apple announces new iPhone",
            source="Reuters",
            url="https://reuters.com/article1",
            published_at=datetime(2024, 1, 15, 10, 30, 0),
            content_snippet="Apple Inc announced its latest iPhone model today...",
            ticker_mentions=["AAPL"],
        )

        trending = TrendingStock(
            ticker="AAPL",
            company_name="Apple Inc.",
            score=85.5,
            mention_count=10,
            sentiment=0.75,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.PRODUCT_LAUNCH,
            news_summary="Apple announced new iPhone with advanced AI features.",
            source_articles=[source_article],
        )

        # Table-driven check of all scalar attributes.
        expected_fields = {
            "ticker": "AAPL",
            "company_name": "Apple Inc.",
            "score": 85.5,
            "mention_count": 10,
            "sentiment": 0.75,
            "sector": Sector.TECHNOLOGY,
            "event_type": EventCategory.PRODUCT_LAUNCH,
        }
        for field_name, expected_value in expected_fields.items():
            assert getattr(trending, field_name) == expected_value
        assert len(trending.source_articles) == 1
+
+
class TestNewsArticle:
    """Construction of NewsArticle with all required fields."""

    def test_news_article_with_required_fields(self):
        """Fields are stored verbatim and multiple ticker mentions are preserved."""
        publish_time = datetime(2024, 1, 15, 14, 0, 0)

        article = NewsArticle(
            title="Tesla Q4 Earnings Beat Expectations",
            source="Bloomberg",
            url="https://bloomberg.com/news/tsla-earnings",
            published_at=publish_time,
            content_snippet="Tesla Inc. reported fourth quarter earnings that exceeded analyst expectations...",
            ticker_mentions=["TSLA", "F"],
        )

        assert article.title == "Tesla Q4 Earnings Beat Expectations"
        assert article.source == "Bloomberg"
        assert article.url == "https://bloomberg.com/news/tsla-earnings"
        assert article.published_at == publish_time
        assert article.content_snippet.startswith("Tesla Inc.")
        # Both the primary and the secondary ticker must be retained.
        for mentioned_ticker in ("TSLA", "F"):
            assert mentioned_ticker in article.ticker_mentions
+
+
class TestDiscoveryRequest:
    """Construction of DiscoveryRequest with explicit values and with defaults."""

    def test_discovery_request_with_lookback_period_validation(self):
        """Explicitly supplied filters and limits are stored verbatim."""
        creation_time = datetime(2024, 1, 15, 12, 0, 0)

        req = DiscoveryRequest(
            lookback_period="24h",
            sector_filter=[Sector.TECHNOLOGY, Sector.HEALTHCARE],
            event_filter=[EventCategory.EARNINGS],
            max_results=20,
            created_at=creation_time,
        )

        assert req.lookback_period == "24h"
        for expected_sector in (Sector.TECHNOLOGY, Sector.HEALTHCARE):
            assert expected_sector in req.sector_filter
        assert EventCategory.EARNINGS in req.event_filter
        assert req.max_results == 20
        assert req.created_at == creation_time

    def test_discovery_request_with_defaults(self):
        """Omitted fields default to no filters, 20 results, and an auto timestamp."""
        req = DiscoveryRequest(lookback_period="1h")

        assert req.lookback_period == "1h"
        assert req.sector_filter is None
        assert req.event_filter is None
        assert req.max_results == 20
        assert req.created_at is not None
+
+
class TestDiscoveryResult:
    """Lifecycle of DiscoveryResult.status across CREATED/PROCESSING/COMPLETED/FAILED."""

    def test_discovery_result_state_transitions(self):
        """Status can move CREATED -> PROCESSING -> COMPLETED with a completion time."""
        request = DiscoveryRequest(lookback_period="6h")
        started = datetime(2024, 1, 15, 12, 0, 0)

        result = DiscoveryResult(
            request=request,
            trending_stocks=[],
            status=DiscoveryStatus.CREATED,
            started_at=started,
        )

        assert result.status == DiscoveryStatus.CREATED

        # NOTE(review): transitions are plain attribute assignments; nothing here
        # shows the model enforcing a state machine — confirm that is intended.
        result.status = DiscoveryStatus.PROCESSING
        assert result.status == DiscoveryStatus.PROCESSING

        result.status = DiscoveryStatus.COMPLETED
        result.completed_at = datetime(2024, 1, 15, 12, 1, 0)
        assert result.status == DiscoveryStatus.COMPLETED
        assert result.completed_at is not None

    def test_discovery_result_failed_state(self):
        """A FAILED result carries a human-readable error_message."""
        request = DiscoveryRequest(lookback_period="7d")

        result = DiscoveryResult(
            request=request,
            trending_stocks=[],
            status=DiscoveryStatus.FAILED,
            started_at=datetime(2024, 1, 15, 12, 0, 0),
            error_message="News API rate limit exceeded",
        )

        assert result.status == DiscoveryStatus.FAILED
        assert result.error_message == "News API rate limit exceeded"
+
+
class TestSerializationRoundtrip:
    """to_dict/from_dict must preserve every field of a nested DiscoveryResult."""

    def test_to_dict_and_from_dict_serialization_roundtrip(self):
        """Serialize a fully populated result and compare the restored copy field-by-field."""
        article = NewsArticle(
            title="Microsoft acquires AI startup",
            source="WSJ",
            url="https://wsj.com/msft-acquisition",
            published_at=datetime(2024, 1, 15, 9, 0, 0),
            content_snippet="Microsoft Corp announced the acquisition of an AI startup...",
            ticker_mentions=["MSFT"],
        )

        stock = TrendingStock(
            ticker="MSFT",
            company_name="Microsoft Corporation",
            score=92.3,
            mention_count=15,
            sentiment=0.65,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.MERGER_ACQUISITION,
            news_summary="Microsoft announces major AI acquisition.",
            source_articles=[article],
        )

        request = DiscoveryRequest(
            lookback_period="24h",
            sector_filter=[Sector.TECHNOLOGY],
            event_filter=[EventCategory.MERGER_ACQUISITION],
            max_results=10,
            created_at=datetime(2024, 1, 15, 8, 0, 0),
        )

        result = DiscoveryResult(
            request=request,
            trending_stocks=[stock],
            status=DiscoveryStatus.COMPLETED,
            started_at=datetime(2024, 1, 15, 8, 0, 0),
            completed_at=datetime(2024, 1, 15, 8, 1, 30),
        )

        # Round-trip through the dict representation.
        result_dict = result.to_dict()
        restored_result = DiscoveryResult.from_dict(result_dict)

        assert restored_result.status == result.status
        assert restored_result.request.lookback_period == request.lookback_period
        assert len(restored_result.trending_stocks) == 1

        restored_stock = restored_result.trending_stocks[0]
        assert restored_stock.ticker == stock.ticker
        assert restored_stock.company_name == stock.company_name
        assert restored_stock.score == stock.score
        assert restored_stock.mention_count == stock.mention_count
        assert restored_stock.sentiment == stock.sentiment
        assert restored_stock.sector == stock.sector
        assert restored_stock.event_type == stock.event_type

        # Nested articles must survive the round-trip as well.
        assert len(restored_stock.source_articles) == 1
        restored_article = restored_stock.source_articles[0]
        assert restored_article.title == article.title
        assert restored_article.source == article.source
        assert restored_article.url == article.url
diff --git a/tests/discovery/test_persistence.py b/tests/discovery/test_persistence.py
new file mode 100644
index 00000000..e649b02a
--- /dev/null
+++ b/tests/discovery/test_persistence.py
@@ -0,0 +1,228 @@
+import pytest
+import json
+from datetime import datetime
+from pathlib import Path
+import tempfile
+import shutil
+
+from tradingagents.agents.discovery import (
+ TrendingStock,
+ NewsArticle,
+ DiscoveryRequest,
+ DiscoveryResult,
+ DiscoveryStatus,
+ Sector,
+ EventCategory,
+)
+from tradingagents.agents.discovery.persistence import (
+ save_discovery_result,
+ generate_markdown_summary,
+)
+
+
@pytest.fixture
def sample_discovery_result():
    """A COMPLETED DiscoveryResult with three tech stocks (AAPL, MSFT, GOOGL).

    AAPL links two articles, MSFT one, GOOGL none — exercising the
    source_articles paths in persistence and markdown generation.
    """
    articles = [
        NewsArticle(
            title="Apple announces new iPhone with AI features",
            source="Reuters",
            url="https://reuters.com/apple-iphone-ai",
            published_at=datetime(2024, 1, 15, 10, 30, 0),
            content_snippet="Apple Inc announced its latest iPhone model with advanced AI...",
            ticker_mentions=["AAPL"],
        ),
        NewsArticle(
            title="Apple stock surges on earnings beat",
            source="Bloomberg",
            url="https://bloomberg.com/apple-earnings",
            published_at=datetime(2024, 1, 15, 11, 0, 0),
            content_snippet="Shares of Apple Inc surged after the company reported...",
            ticker_mentions=["AAPL"],
        ),
        NewsArticle(
            title="Microsoft cloud revenue grows 25%",
            source="WSJ",
            url="https://wsj.com/msft-cloud",
            published_at=datetime(2024, 1, 15, 9, 0, 0),
            content_snippet="Microsoft Corp reported strong cloud revenue growth...",
            ticker_mentions=["MSFT"],
        ),
    ]

    stocks = [
        TrendingStock(
            ticker="AAPL",
            company_name="Apple Inc.",
            score=8.54,
            mention_count=12,
            sentiment=0.72,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.EARNINGS,
            news_summary="Apple reported strong earnings and announced new AI features.",
            source_articles=[articles[0], articles[1]],
        ),
        TrendingStock(
            ticker="MSFT",
            company_name="Microsoft Corporation",
            score=7.23,
            mention_count=9,
            sentiment=0.65,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.PRODUCT_LAUNCH,
            news_summary="Microsoft cloud business continues strong growth.",
            source_articles=[articles[2]],
        ),
        TrendingStock(
            ticker="GOOGL",
            company_name="Alphabet Inc.",
            score=6.15,
            mention_count=7,
            sentiment=0.58,
            sector=Sector.TECHNOLOGY,
            event_type=EventCategory.REGULATORY,
            news_summary="Google faces regulatory scrutiny in multiple markets.",
            source_articles=[],
        ),
    ]

    # Request uses both filter kinds so markdown tests can assert on them.
    request = DiscoveryRequest(
        lookback_period="24h",
        sector_filter=[Sector.TECHNOLOGY],
        event_filter=[EventCategory.EARNINGS],
        max_results=20,
        created_at=datetime(2024, 1, 15, 14, 30, 45),
    )

    return DiscoveryResult(
        request=request,
        trending_stocks=stocks,
        status=DiscoveryStatus.COMPLETED,
        started_at=datetime(2024, 1, 15, 14, 30, 45),
        completed_at=datetime(2024, 1, 15, 14, 31, 30),
    )
+
+
@pytest.fixture
def temp_results_dir():
    """Yield a throw-away directory as a Path, removed after the test.

    Uses tempfile.TemporaryDirectory as a context manager so cleanup is
    handled by the stdlib instead of a manual mkdtemp/rmtree pair.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        yield Path(temp_dir)
+
+
class TestDirectoryStructureCreation:
    """save_discovery_result must create a dated results directory tree."""

    def test_creates_correct_directory_structure(self, sample_discovery_result, temp_results_dir):
        """The returned path is an existing directory under a 'discovery' segment."""
        result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir)

        assert result_path.exists()
        assert result_path.is_dir()

        path_parts = result_path.parts
        assert "discovery" in path_parts

        # assumes layout <base>/discovery/<date>/<time> with dash-separated
        # date (YYYY-MM-DD) and time (HH-MM-SS) parts — TODO confirm against
        # the persistence implementation; only the dash counts are checked.
        date_part = path_parts[-2]
        time_part = path_parts[-1]

        assert len(date_part.split("-")) == 3
        assert len(time_part.split("-")) == 3
+
+
class TestDiscoveryResultJson:
    """The saved discovery_result.json must contain the full serialized result."""

    def test_discovery_result_json_contains_all_fields(self, sample_discovery_result, temp_results_dir):
        """Top-level keys, request fields, and the first stock's fields are all present."""
        result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir)

        json_path = result_path / "discovery_result.json"
        assert json_path.exists()

        with open(json_path, "r") as f:
            saved_data = json.load(f)

        # Top-level structure of the serialized DiscoveryResult.
        assert "request" in saved_data
        assert "trending_stocks" in saved_data
        assert "status" in saved_data
        assert "started_at" in saved_data
        assert "completed_at" in saved_data

        assert saved_data["request"]["lookback_period"] == "24h"
        assert saved_data["status"] == "completed"
        assert len(saved_data["trending_stocks"]) == 3

        # Stocks are expected in fixture order, AAPL first; enums serialize
        # to their lowercase string values.
        first_stock = saved_data["trending_stocks"][0]
        assert first_stock["ticker"] == "AAPL"
        assert first_stock["company_name"] == "Apple Inc."
        assert first_stock["score"] == 8.54
        assert first_stock["mention_count"] == 12
        assert first_stock["sentiment"] == 0.72
        assert first_stock["sector"] == "technology"
        assert first_stock["event_type"] == "earnings"
        assert "news_summary" in first_stock
        assert "source_articles" in first_stock
+
+
class TestDiscoverySummaryMarkdown:
    """The saved discovery_summary.md must be a readable markdown report."""

    def test_discovery_summary_md_is_human_readable(self, sample_discovery_result, temp_results_dir):
        """Header, stock table, and top-3 detail sections are all rendered."""
        result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir)

        md_path = result_path / "discovery_summary.md"
        assert md_path.exists()

        with open(md_path, "r") as f:
            markdown_content = f.read()

        # Report header and run metadata.
        assert "# Discovery Results" in markdown_content
        assert "Timestamp:" in markdown_content
        assert "Lookback Period:" in markdown_content
        assert "24h" in markdown_content
        assert "Total Stocks Found:" in markdown_content

        # Markdown table of all trending stocks.
        assert "## Trending Stocks" in markdown_content
        assert "| Rank |" in markdown_content
        assert "| Ticker |" in markdown_content
        assert "| Company |" in markdown_content
        assert "| Score |" in markdown_content
        assert "| Mentions |" in markdown_content
        assert "| Event |" in markdown_content

        # Fixture values surfaced in the table.
        assert "AAPL" in markdown_content
        assert "Apple Inc." in markdown_content
        assert "8.54" in markdown_content
        assert "12" in markdown_content
        assert "earnings" in markdown_content

        assert "MSFT" in markdown_content
        assert "Microsoft Corporation" in markdown_content

        # Per-stock detail section for the top three entries.
        assert "## Top 3 Detailed Analysis" in markdown_content
        assert "### 1. AAPL - Apple Inc." in markdown_content
        assert "**Score:**" in markdown_content
        assert "**Sentiment:**" in markdown_content
        assert "**Sector:**" in markdown_content
        assert "**Event Type:**" in markdown_content
        assert "**Mentions:**" in markdown_content
        assert "**News Summary:**" in markdown_content
+
+
class TestMarkdownGeneration:
    """generate_markdown_summary renders active filters, or 'None' when absent."""

    def test_generate_markdown_with_filters(self, sample_discovery_result):
        """Sector and event filters from the request appear in the summary."""
        rendered = generate_markdown_summary(sample_discovery_result).lower()

        assert "sector=technology" in rendered
        assert "event=earnings" in rendered

    def test_generate_markdown_without_filters(self):
        """A filter-less request renders a 'Filters:' line reading 'None'."""
        bare_request = DiscoveryRequest(
            lookback_period="6h",
            created_at=datetime(2024, 1, 15, 10, 0, 0),
        )
        empty_result = DiscoveryResult(
            request=bare_request,
            trending_stocks=[],
            status=DiscoveryStatus.COMPLETED,
            started_at=datetime(2024, 1, 15, 10, 0, 0),
            completed_at=datetime(2024, 1, 15, 10, 1, 0),
        )

        rendered = generate_markdown_summary(empty_result)

        assert "Filters:" in rendered
        assert "None" in rendered
diff --git a/tests/discovery/test_scorer.py b/tests/discovery/test_scorer.py
new file mode 100644
index 00000000..e40b778e
--- /dev/null
+++ b/tests/discovery/test_scorer.py
@@ -0,0 +1,469 @@
+import pytest
+import math
+from datetime import datetime, timedelta
+from unittest.mock import patch
+from tradingagents.agents.discovery import NewsArticle, EventCategory, Sector
+from tradingagents.agents.discovery.entity_extractor import EntityMention
+
+
class TestFrequencyCalculation:
    """Mention frequency is the count of distinct articles naming the company."""

    def test_frequency_calculation_unique_article_count(self):
        """Three name variants of Apple across three articles collapse to one
        AAPL entry with mention_count == 3."""
        from tradingagents.agents.discovery.scorer import calculate_trending_scores

        now = datetime.now()
        articles = [
            NewsArticle(
                title="Apple Q4 Earnings",
                source="Reuters",
                url="https://reuters.com/article1",
                published_at=now - timedelta(hours=1),
                content_snippet="Apple Inc reported strong earnings.",
                ticker_mentions=["AAPL"],
            ),
            NewsArticle(
                title="Apple iPhone Sales",
                source="Bloomberg",
                url="https://bloomberg.com/article2",
                published_at=now - timedelta(hours=2),
                content_snippet="Apple saw record iPhone sales.",
                ticker_mentions=["AAPL"],
            ),
            NewsArticle(
                title="Apple AI Features",
                source="WSJ",
                url="https://wsj.com/article3",
                published_at=now - timedelta(hours=3),
                content_snippet="Apple announced AI features.",
                ticker_mentions=["AAPL"],
            ),
        ]

        # Deliberately varied company-name spellings; ticker resolution is
        # mocked below to map all of them to AAPL.
        mentions = [
            EntityMention(
                company_name="Apple Inc",
                confidence=0.95,
                context_snippet="Apple Inc reported strong earnings",
                article_id="article_0",
                event_type=EventCategory.EARNINGS,
            ),
            EntityMention(
                company_name="Apple",
                confidence=0.90,
                context_snippet="Apple saw record iPhone sales",
                article_id="article_1",
                event_type=EventCategory.EARNINGS,
            ),
            EntityMention(
                company_name="Apple Inc.",
                confidence=0.92,
                context_snippet="Apple announced AI features",
                article_id="article_2",
                event_type=EventCategory.PRODUCT_LAUNCH,
            ),
        ]

        with patch(
            "tradingagents.agents.discovery.scorer.resolve_ticker"
        ) as mock_resolve:
            mock_resolve.return_value = "AAPL"

            with patch(
                "tradingagents.agents.discovery.scorer.classify_sector"
            ) as mock_sector:
                mock_sector.return_value = "technology"

                result = calculate_trending_scores(mentions, articles)

        assert len(result) == 1
        assert result[0].ticker == "AAPL"
        assert result[0].mention_count == 3
+
+
class TestSentimentIntensityFactor:
    """Aggregated sentiment keeps its sign and is confidence-weighted."""

    def test_sentiment_intensity_uses_absolute_value(self):
        """Two negative mentions yield a negative confidence-weighted mean sentiment."""
        from tradingagents.agents.discovery.scorer import calculate_trending_scores

        now = datetime.now()
        articles = [
            NewsArticle(
                title="Stock drops sharply",
                source="Reuters",
                url="https://reuters.com/article1",
                published_at=now - timedelta(hours=1),
                content_snippet="Company faced major issues.",
                ticker_mentions=["TSLA"],
            ),
            NewsArticle(
                title="More bad news",
                source="Bloomberg",
                url="https://bloomberg.com/article2",
                published_at=now - timedelta(hours=2),
                content_snippet="Further decline expected.",
                ticker_mentions=["TSLA"],
            ),
        ]

        mentions = [
            EntityMention(
                company_name="Tesla",
                confidence=0.95,
                context_snippet="Company faced major issues",
                article_id="article_0",
                event_type=EventCategory.OTHER,
                sentiment=-0.8,
            ),
            EntityMention(
                company_name="Tesla Inc",
                confidence=0.90,
                context_snippet="Further decline expected",
                article_id="article_1",
                event_type=EventCategory.OTHER,
                sentiment=-0.6,
            ),
        ]

        with patch(
            "tradingagents.agents.discovery.scorer.resolve_ticker"
        ) as mock_resolve:
            mock_resolve.return_value = "TSLA"

            with patch(
                "tradingagents.agents.discovery.scorer.classify_sector"
            ) as mock_sector:
                mock_sector.return_value = "technology"

                result = calculate_trending_scores(mentions, articles)

        assert len(result) == 1
        assert result[0].sentiment < 0
        # Expected: confidence-weighted mean of the per-mention sentiments.
        expected_sentiment = (-0.8 * 0.95 + -0.6 * 0.90) / (0.95 + 0.90)
        assert abs(result[0].sentiment - expected_sentiment) < 0.01
+
+
class TestRecencyWeightExponentialDecay:
    """Older articles must contribute less to the score via exp(-decay_rate * age)."""

    def test_recency_weight_exponential_decay(self):
        """With mentions 1h and 10h old, the score must fall strictly between the
        all-fresh and all-stale extremes of the scoring formula.

        The previous version of this test computed the decay weights but never
        asserted on them (only ``score > 0``), so it could not detect a scorer
        that ignored recency entirely. Using the formula pinned down by
        ``TestFinalScoreFormulaCorrectness`` (frequency * (1 + |sentiment|) *
        recency weight), mixed-age mentions must score below the 1h-only bound
        and above the 10h-only bound.
        """
        from tradingagents.agents.discovery.scorer import calculate_trending_scores

        now = datetime.now()
        articles = [
            NewsArticle(
                title="Recent news",
                source="Reuters",
                url="https://reuters.com/article1",
                published_at=now - timedelta(hours=1),
                content_snippet="Recent company news.",
                ticker_mentions=["NVDA"],
            ),
            NewsArticle(
                title="Older news",
                source="Bloomberg",
                url="https://bloomberg.com/article2",
                published_at=now - timedelta(hours=10),
                content_snippet="Older company news.",
                ticker_mentions=["NVDA"],
            ),
        ]

        mentions = [
            EntityMention(
                company_name="Nvidia",
                confidence=0.90,
                context_snippet="Recent company news",
                article_id="article_0",
                event_type=EventCategory.OTHER,
                sentiment=0.5,
            ),
            EntityMention(
                company_name="Nvidia",
                confidence=0.90,
                context_snippet="Older company news",
                article_id="article_1",
                event_type=EventCategory.OTHER,
                sentiment=0.5,
            ),
        ]

        with patch(
            "tradingagents.agents.discovery.scorer.resolve_ticker"
        ) as mock_resolve:
            mock_resolve.return_value = "NVDA"

            with patch(
                "tradingagents.agents.discovery.scorer.classify_sector"
            ) as mock_sector:
                mock_sector.return_value = "technology"

                result = calculate_trending_scores(mentions, articles, decay_rate=0.1)

        assert len(result) == 1
        frequency = 2
        sentiment_factor = 1 + abs(0.5)
        fresh_weight = math.exp(-0.1 * 1)
        stale_weight = math.exp(-0.1 * 10)
        # Strict bounds: the mixed-age recency term lies between the per-article
        # extremes, so decay must actually influence the final score.
        assert frequency * sentiment_factor * stale_weight < result[0].score
        assert result[0].score < frequency * sentiment_factor * fresh_weight
+
+
class TestMinimumThresholdFiltering:
    """Stocks below min_mentions distinct articles are dropped from the results."""

    def test_minimum_threshold_filtering_requires_two_articles(self):
        """AMD (1 mention) is filtered out; MSFT (2 mentions) survives."""
        from tradingagents.agents.discovery.scorer import calculate_trending_scores

        now = datetime.now()

        # (title, source, url, hours_old, snippet, tickers)
        article_specs = [
            ("Single mention stock", "Reuters", "https://reuters.com/article1",
             1, "Some company news.", ["AMD"]),
            ("Multiple mention stock 1", "Bloomberg", "https://bloomberg.com/article2",
             2, "Popular company news.", ["MSFT"]),
            ("Multiple mention stock 2", "WSJ", "https://wsj.com/article3",
             3, "More popular company news.", ["MSFT"]),
        ]
        articles = [
            NewsArticle(
                title=title,
                source=source,
                url=url,
                published_at=now - timedelta(hours=age),
                content_snippet=snippet,
                ticker_mentions=tickers,
            )
            for title, source, url, age, snippet, tickers in article_specs
        ]

        # (company_name, confidence, context, article_id)
        mention_specs = [
            ("AMD", 0.90, "Some company news", "article_0"),
            ("Microsoft", 0.95, "Popular company news", "article_1"),
            ("Microsoft Corp", 0.92, "More popular company news", "article_2"),
        ]
        mentions = [
            EntityMention(
                company_name=company,
                confidence=conf,
                context_snippet=context,
                article_id=art_id,
                event_type=EventCategory.OTHER,
            )
            for company, conf, context, art_id in mention_specs
        ]

        with patch(
            "tradingagents.agents.discovery.scorer.resolve_ticker"
        ) as mock_resolve:

            def _ticker_for(name):
                # "AMD" names map to AMD; everything else in this fixture is Microsoft.
                return "AMD" if "AMD" in name else "MSFT"

            mock_resolve.side_effect = _ticker_for

            with patch(
                "tradingagents.agents.discovery.scorer.classify_sector"
            ) as mock_sector:
                mock_sector.return_value = "technology"

                result = calculate_trending_scores(mentions, articles, min_mentions=2)

        assert len(result) == 1
        assert result[0].ticker == "MSFT"
        assert all(stock.mention_count >= 2 for stock in result)
+
+
class TestFinalScoreFormulaCorrectness:
    """Pins the scoring formula: frequency * (1 + |sentiment|) * exp(-decay * age)."""

    def test_final_score_formula_correctness(self):
        """With two same-age, same-sentiment mentions the score is computed exactly."""
        from tradingagents.agents.discovery.scorer import calculate_trending_scores

        now = datetime.now()
        # Both articles share the same age so the recency term is a single weight.
        hours_old = 2.0
        articles = [
            NewsArticle(
                title="Test article 1",
                source="Reuters",
                url="https://reuters.com/article1",
                published_at=now - timedelta(hours=hours_old),
                content_snippet="Google announced results.",
                ticker_mentions=["GOOGL"],
            ),
            NewsArticle(
                title="Test article 2",
                source="Bloomberg",
                url="https://bloomberg.com/article2",
                published_at=now - timedelta(hours=hours_old),
                content_snippet="Alphabet earnings beat.",
                ticker_mentions=["GOOGL"],
            ),
        ]

        sentiment_val = 0.6
        confidence = 0.9
        mentions = [
            EntityMention(
                company_name="Google",
                confidence=confidence,
                context_snippet="Google announced results",
                article_id="article_0",
                event_type=EventCategory.EARNINGS,
                sentiment=sentiment_val,
            ),
            EntityMention(
                company_name="Alphabet",
                confidence=confidence,
                context_snippet="Alphabet earnings beat",
                article_id="article_1",
                event_type=EventCategory.EARNINGS,
                sentiment=sentiment_val,
            ),
        ]

        decay_rate = 0.1
        with patch(
            "tradingagents.agents.discovery.scorer.resolve_ticker"
        ) as mock_resolve:
            mock_resolve.return_value = "GOOGL"

            with patch(
                "tradingagents.agents.discovery.scorer.classify_sector"
            ) as mock_sector:
                mock_sector.return_value = "technology"

                result = calculate_trending_scores(
                    mentions, articles, decay_rate=decay_rate
                )

        assert len(result) == 1
        stock = result[0]

        # score = frequency * (1 + |sentiment|) * exp(-decay_rate * hours_old)
        frequency = 2
        sentiment_factor = 1 + abs(sentiment_val)
        recency_weight = math.exp(-decay_rate * hours_old)
        expected_score = frequency * sentiment_factor * recency_weight

        assert abs(stock.score - expected_score) < 0.01
+
+
class TestSortingByScoreDescending:
    """Results are returned highest score first."""

    def test_results_sorted_by_score_descending(self):
        """AAPL (fresh, positive, 3 articles) must outrank TSLA (stale, 2 articles)."""
        from tradingagents.agents.discovery.scorer import calculate_trending_scores

        now = datetime.now()
        # AAPL: three 1h-old articles with strong sentiment (high score).
        # TSLA: two 10h-old articles with weak sentiment (low score).
        articles = [
            NewsArticle(
                title="High score stock 1",
                source="Reuters",
                url="https://reuters.com/article1",
                published_at=now - timedelta(hours=1),
                content_snippet="Apple news.",
                ticker_mentions=["AAPL"],
            ),
            NewsArticle(
                title="High score stock 2",
                source="Bloomberg",
                url="https://bloomberg.com/article2",
                published_at=now - timedelta(hours=1),
                content_snippet="More Apple news.",
                ticker_mentions=["AAPL"],
            ),
            NewsArticle(
                title="High score stock 3",
                source="WSJ",
                url="https://wsj.com/article3",
                published_at=now - timedelta(hours=1),
                content_snippet="Even more Apple news.",
                ticker_mentions=["AAPL"],
            ),
            NewsArticle(
                title="Low score stock 1",
                source="CNBC",
                url="https://cnbc.com/article4",
                published_at=now - timedelta(hours=10),
                content_snippet="Tesla news.",
                ticker_mentions=["TSLA"],
            ),
            NewsArticle(
                title="Low score stock 2",
                source="FT",
                url="https://ft.com/article5",
                published_at=now - timedelta(hours=10),
                content_snippet="More Tesla news.",
                ticker_mentions=["TSLA"],
            ),
        ]

        mentions = [
            EntityMention(
                company_name="Apple",
                confidence=0.95,
                context_snippet="Apple news",
                article_id="article_0",
                event_type=EventCategory.OTHER,
                sentiment=0.8,
            ),
            EntityMention(
                company_name="Apple Inc",
                confidence=0.93,
                context_snippet="More Apple news",
                article_id="article_1",
                event_type=EventCategory.OTHER,
                sentiment=0.8,
            ),
            EntityMention(
                company_name="Apple",
                confidence=0.90,
                context_snippet="Even more Apple news",
                article_id="article_2",
                event_type=EventCategory.OTHER,
                sentiment=0.8,
            ),
            EntityMention(
                company_name="Tesla",
                confidence=0.85,
                context_snippet="Tesla news",
                article_id="article_3",
                event_type=EventCategory.OTHER,
                sentiment=0.2,
            ),
            EntityMention(
                company_name="Tesla Inc",
                confidence=0.85,
                context_snippet="More Tesla news",
                article_id="article_4",
                event_type=EventCategory.OTHER,
                sentiment=0.2,
            ),
        ]

        with patch(
            "tradingagents.agents.discovery.scorer.resolve_ticker"
        ) as mock_resolve:

            def resolve_side_effect(name):
                # Canonicalize name variants to their tickers.
                if "Apple" in name:
                    return "AAPL"
                if "Tesla" in name:
                    return "TSLA"
                return None

            mock_resolve.side_effect = resolve_side_effect

            with patch(
                "tradingagents.agents.discovery.scorer.classify_sector"
            ) as mock_sector:
                mock_sector.return_value = "technology"

                result = calculate_trending_scores(mentions, articles, min_mentions=2)

        assert len(result) == 2
        # Pairwise check of the descending-score invariant.
        for i in range(len(result) - 1):
            assert result[i].score >= result[i + 1].score
diff --git a/tests/discovery/test_sector_classifier.py b/tests/discovery/test_sector_classifier.py
new file mode 100644
index 00000000..15458e58
--- /dev/null
+++ b/tests/discovery/test_sector_classifier.py
@@ -0,0 +1,94 @@
+import pytest
+from unittest.mock import patch, MagicMock
+from tradingagents.dataflows.trending.sector_classifier import (
+ classify_sector,
+ TICKER_TO_SECTOR,
+ VALID_SECTORS,
+ _llm_classify_sector,
+ _sector_cache,
+)
+
+
+class TestStaticSectorMapping:
+ def test_static_sector_mapping_for_known_technology_tickers(self):
+ assert classify_sector("AAPL") == "technology"
+ assert classify_sector("MSFT") == "technology"
+ assert classify_sector("GOOGL") == "technology"
+ assert classify_sector("NVDA") == "technology"
+
+ def test_static_sector_mapping_for_known_healthcare_tickers(self):
+ assert classify_sector("JNJ") == "healthcare"
+ assert classify_sector("PFE") == "healthcare"
+ assert classify_sector("UNH") == "healthcare"
+
+ def test_static_sector_mapping_for_known_finance_tickers(self):
+ assert classify_sector("JPM") == "finance"
+ assert classify_sector("BAC") == "finance"
+ assert classify_sector("GS") == "finance"
+
+ def test_static_sector_mapping_for_known_energy_tickers(self):
+ assert classify_sector("XOM") == "energy"
+ assert classify_sector("CVX") == "energy"
+ assert classify_sector("COP") == "energy"
+
+ def test_static_sector_mapping_case_insensitive(self):
+ assert classify_sector("aapl") == "technology"
+ assert classify_sector("AAPL") == "technology"
+ assert classify_sector("Aapl") == "technology"
+
+
+class TestLLMFallback:
+ @patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector")
+ def test_llm_fallback_for_unknown_tickers(self, mock_llm_classify):
+ mock_llm_classify.return_value = "technology"
+ _sector_cache.clear()
+
+ result = classify_sector("UNKNOWNTICKER123")
+
+ mock_llm_classify.assert_called_once_with("UNKNOWNTICKER123")
+ assert result == "technology"
+
+ @patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector")
+ def test_llm_fallback_caches_results(self, mock_llm_classify):
+ mock_llm_classify.return_value = "healthcare"
+ _sector_cache.clear()
+
+ result1 = classify_sector("NEWCO123")
+ result2 = classify_sector("NEWCO123")
+
+ assert mock_llm_classify.call_count == 1
+ assert result1 == "healthcare"
+ assert result2 == "healthcare"
+
+ @patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector")
+ def test_llm_fallback_returns_other_on_error(self, mock_llm_classify):
+ mock_llm_classify.side_effect = Exception("LLM error")
+ _sector_cache.clear()
+
+ result = classify_sector("ERRORCO")
+
+ assert result == "other"
+
+
+class TestAllSectorCategories:
+ def test_all_sector_categories_in_valid_sectors(self):
+ expected_sectors = {
+ "technology",
+ "healthcare",
+ "finance",
+ "energy",
+ "consumer_goods",
+ "industrials",
+ "other",
+ }
+ assert VALID_SECTORS == expected_sectors
+
+ def test_static_mapping_covers_all_sector_categories(self):
+ sectors_in_mapping = set(TICKER_TO_SECTOR.values())
+ assert sectors_in_mapping.issubset(VALID_SECTORS)
+
+ def test_classify_sector_always_returns_valid_sector(self):
+ test_tickers = ["AAPL", "JPM", "XOM", "JNJ", "WMT", "CAT"]
+ for ticker in test_tickers:
+ result = classify_sector(ticker)
+ assert result in VALID_SECTORS
diff --git a/tests/discovery/test_stock_resolver.py b/tests/discovery/test_stock_resolver.py
new file mode 100644
index 00000000..96f5b455
--- /dev/null
+++ b/tests/discovery/test_stock_resolver.py
@@ -0,0 +1,135 @@
+import pytest
+import logging
+from unittest.mock import patch, MagicMock
+from tradingagents.dataflows.trending.stock_resolver import (
+ resolve_ticker,
+ validate_us_ticker,
+ _normalize_company_name,
+ _search_yfinance_ticker,
+)
+
+
+class TestStaticLookup:
+ def test_static_lookup_for_known_companies(self):
+ assert resolve_ticker("Apple") == "AAPL"
+ assert resolve_ticker("Microsoft") == "MSFT"
+ assert resolve_ticker("Google") == "GOOGL"
+ assert resolve_ticker("Amazon") == "AMZN"
+ assert resolve_ticker("Tesla") == "TSLA"
+ assert resolve_ticker("Nvidia") == "NVDA"
+
+ def test_static_lookup_case_insensitive(self):
+ assert resolve_ticker("APPLE") == "AAPL"
+ assert resolve_ticker("apple") == "AAPL"
+ assert resolve_ticker("ApPlE") == "AAPL"
+ assert resolve_ticker("microsoft") == "MSFT"
+ assert resolve_ticker("MICROSOFT") == "MSFT"
+
+
+class TestNameVariationHandling:
+ def test_name_variation_handling_with_suffixes(self):
+ assert resolve_ticker("Apple Inc.") == "AAPL"
+ assert resolve_ticker("Apple Inc") == "AAPL"
+ assert resolve_ticker("Apple Corporation") == "AAPL"
+ assert resolve_ticker("Microsoft Corp.") == "MSFT"
+ assert resolve_ticker("Microsoft Corp") == "MSFT"
+ assert resolve_ticker("Tesla Inc") == "TSLA"
+
+ def test_name_variation_handling_informal_names(self):
+ assert resolve_ticker("the iPhone maker") == "AAPL"
+ assert resolve_ticker("iPhone maker") == "AAPL"
+ assert resolve_ticker("the search giant") == "GOOGL"
+ assert resolve_ticker("the e-commerce giant") == "AMZN"
+ assert resolve_ticker("EV maker Tesla") == "TSLA"
+
+ def test_name_variation_handling_alternate_names(self):
+ assert resolve_ticker("Alphabet") == "GOOGL"
+ assert resolve_ticker("Meta") == "META"
+ assert resolve_ticker("Facebook") == "META"
+ assert resolve_ticker("Meta Platforms") == "META"
+
+
+class TestYfinanceFallback:
+ @patch("tradingagents.dataflows.trending.stock_resolver._search_yfinance_ticker")
+ @patch("tradingagents.dataflows.trending.stock_resolver.validate_us_ticker")
+ def test_yfinance_fallback_for_unknown_company(self, mock_validate, mock_search):
+ mock_search.return_value = "PLTR"
+ mock_validate.return_value = True
+
+ result = resolve_ticker("UnknownTechStartupXYZ")
+
+ mock_search.assert_called_once()
+ assert result == "PLTR"
+
+ @patch("tradingagents.dataflows.trending.stock_resolver._search_yfinance_ticker")
+ def test_yfinance_fallback_returns_none_when_not_found(self, mock_search):
+ mock_search.return_value = None
+
+ result = resolve_ticker("NonexistentCompanyXYZ123")
+
+ assert result is None
+
+
+class TestUSExchangeValidation:
+ @patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
+ def test_validate_us_ticker_accepts_nyse(self, mock_ticker):
+ mock_info = {"exchange": "NYQ"}
+ mock_ticker.return_value.info = mock_info
+
+ assert validate_us_ticker("IBM") is True
+
+ @patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
+ def test_validate_us_ticker_accepts_nasdaq(self, mock_ticker):
+ mock_info = {"exchange": "NMS"}
+ mock_ticker.return_value.info = mock_info
+
+ assert validate_us_ticker("AAPL") is True
+
+ @patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
+ def test_validate_us_ticker_accepts_amex(self, mock_ticker):
+ mock_info = {"exchange": "ASE"}
+ mock_ticker.return_value.info = mock_info
+
+ assert validate_us_ticker("SPY") is True
+
+ @patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
+ def test_validate_us_ticker_rejects_international(self, mock_ticker):
+ mock_info = {"exchange": "LSE"}
+ mock_ticker.return_value.info = mock_info
+
+ assert validate_us_ticker("VOD.L") is False
+
+ @patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
+ def test_validate_us_ticker_rejects_otc(self, mock_ticker):
+ mock_info = {"exchange": "PNK"}
+ mock_ticker.return_value.info = mock_info
+
+ assert validate_us_ticker("OTCPK") is False
+
+
+class TestAmbiguousResolutionLogging:
+ def test_ambiguous_resolution_logs_multiple_matches(self, caplog):
+ with caplog.at_level(logging.INFO, logger="tradingagents.dataflows.trending.stock_resolver"):
+            pass  # TODO: implement — this test currently asserts nothing
+
+ @patch("tradingagents.dataflows.trending.stock_resolver._search_yfinance_ticker")
+ @patch("tradingagents.dataflows.trending.stock_resolver.validate_us_ticker")
+ def test_yfinance_fallback_is_logged(self, mock_validate, mock_search, caplog):
+ mock_search.return_value = "RBLX"
+ mock_validate.return_value = True
+
+ with caplog.at_level(logging.INFO, logger="tradingagents.dataflows.trending.stock_resolver"):
+ result = resolve_ticker("SomeRandomCompanyNotInMapping")
+
+ assert any("fallback" in record.message.lower() or "yfinance" in record.message.lower()
+ for record in caplog.records)
+
+ @patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker")
+ def test_validation_failure_is_logged(self, mock_ticker, caplog):
+ mock_info = {"exchange": "LSE"}
+ mock_ticker.return_value.info = mock_info
+
+ with caplog.at_level(logging.WARNING, logger="tradingagents.dataflows.trending.stock_resolver"):
+ result = validate_us_ticker("VOD.L")
+
+ assert result is False
diff --git a/tradingagents/agents/discovery/__init__.py b/tradingagents/agents/discovery/__init__.py
new file mode 100644
index 00000000..0effd7ce
--- /dev/null
+++ b/tradingagents/agents/discovery/__init__.py
@@ -0,0 +1,53 @@
+from .models import (
+ NewsArticle,
+ TrendingStock,
+ DiscoveryRequest,
+ DiscoveryResult,
+ DiscoveryStatus,
+ Sector,
+ EventCategory,
+)
+from .exceptions import (
+ DiscoveryError,
+ NewsUnavailableError,
+ DiscoveryTimeoutError,
+ TickerResolutionError,
+)
+from .entity_extractor import (
+ EntityMention,
+ extract_entities,
+ BATCH_SIZE,
+)
+from .scorer import (
+ calculate_trending_scores,
+ DEFAULT_DECAY_RATE,
+ DEFAULT_MAX_RESULTS,
+ DEFAULT_MIN_MENTIONS,
+)
+from .persistence import (
+ save_discovery_result,
+ generate_markdown_summary,
+)
+
+__all__ = [
+ "NewsArticle",
+ "TrendingStock",
+ "DiscoveryRequest",
+ "DiscoveryResult",
+ "DiscoveryStatus",
+ "Sector",
+ "EventCategory",
+ "DiscoveryError",
+ "NewsUnavailableError",
+ "DiscoveryTimeoutError",
+ "TickerResolutionError",
+ "EntityMention",
+ "extract_entities",
+ "BATCH_SIZE",
+ "calculate_trending_scores",
+ "DEFAULT_DECAY_RATE",
+ "DEFAULT_MAX_RESULTS",
+ "DEFAULT_MIN_MENTIONS",
+ "save_discovery_result",
+ "generate_markdown_summary",
+]
diff --git a/tradingagents/agents/discovery/entity_extractor.py b/tradingagents/agents/discovery/entity_extractor.py
new file mode 100644
index 00000000..5bad3242
--- /dev/null
+++ b/tradingagents/agents/discovery/entity_extractor.py
@@ -0,0 +1,159 @@
+from dataclasses import dataclass, field
+from typing import List, Optional
+from pydantic import BaseModel, Field as PydanticField
+
+from langchain_openai import ChatOpenAI
+from langchain_anthropic import ChatAnthropic
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+from tradingagents.default_config import DEFAULT_CONFIG
+from tradingagents.agents.discovery.models import NewsArticle, EventCategory
+
+
+BATCH_SIZE = 10
+
+
+@dataclass
+class EntityMention:
+ company_name: str
+ confidence: float
+ context_snippet: str
+ article_id: str
+ event_type: EventCategory
+ sentiment: float = field(default=0.0)
+
+
+class ExtractedEntity(BaseModel):
+ company_name: str = PydanticField(description="The name of the publicly traded company mentioned")
+ confidence: float = PydanticField(description="Confidence score from 0.0 to 1.0 based on mention clarity")
+ context_snippet: str = PydanticField(description="Surrounding context of 50-100 characters around the company mention")
+ event_type: str = PydanticField(description="Event category: earnings, merger_acquisition, regulatory, product_launch, executive_change, or other")
+ sentiment: float = PydanticField(default=0.0, description="Sentiment score from -1.0 (negative) to 1.0 (positive)")
+
+
+class ExtractionResponse(BaseModel):
+ entities: List[ExtractedEntity] = PydanticField(default_factory=list, description="List of extracted company entities")
+
+
+def _get_llm(config: Optional[dict] = None):
+ cfg = config or DEFAULT_CONFIG
+ provider = cfg.get("llm_provider", "openai").lower()
+ model = cfg.get("quick_think_llm", "gpt-4o-mini")
+ backend_url = cfg.get("backend_url", "https://api.openai.com/v1")
+
+ if provider in ("openai", "ollama", "openrouter"):
+ return ChatOpenAI(model=model, base_url=backend_url)
+ elif provider == "anthropic":
+ return ChatAnthropic(model=model, base_url=backend_url)
+ elif provider == "google":
+ return ChatGoogleGenerativeAI(model=model)
+ else:
+ raise ValueError(f"Unsupported LLM provider: {provider}")
+
+
+EXTRACTION_PROMPT = """You are an expert at identifying publicly traded companies mentioned in news articles.
+
+For each article provided, extract all mentions of publicly traded companies. For each company mention:
+
+1. Extract the company name as it appears (e.g., "Apple Inc.", "Apple", "AAPL", "the iPhone maker")
+2. Assign a confidence score from 0.0 to 1.0 based on how clearly the company is mentioned:
+ - 0.9-1.0: Direct company name or ticker symbol
+ - 0.7-0.9: Clear reference with context (e.g., "the Cupertino tech giant")
+ - 0.5-0.7: Indirect reference requiring inference
+ - Below 0.5: Uncertain or ambiguous reference
+3. Extract 50-100 characters of surrounding context
+4. Classify the event type:
+ - earnings: Quarterly/annual earnings reports, revenue announcements
+ - merger_acquisition: Mergers, acquisitions, buyouts, takeovers
+ - regulatory: SEC filings, government investigations, compliance issues
+ - product_launch: New products, services, or features
+ - executive_change: CEO/CFO changes, board appointments, departures
+ - other: Any other business news
+5. Assign a sentiment score from -1.0 to 1.0:
+ - -1.0: Very negative news (lawsuits, crashes, major failures)
+ - -0.5: Moderately negative news
+ - 0.0: Neutral news
+ - 0.5: Moderately positive news
+ - 1.0: Very positive news (breakthroughs, record earnings)
+
+Only extract companies that are publicly traded on major stock exchanges.
+Handle name variations by providing the most complete company name found.
+
+Articles to analyze:
+{articles_text}
+
+Extract all company mentions from the articles above."""
+
+
+def _format_articles_for_prompt(articles: List[NewsArticle], start_idx: int) -> str:
+ formatted = []
+ for i, article in enumerate(articles):
+ article_id = f"article_{start_idx + i}"
+ formatted.append(
+ f"[{article_id}]\n"
+ f"Title: {article.title}\n"
+ f"Source: {article.source}\n"
+ f"Content: {article.content_snippet}\n"
+ )
+ return "\n---\n".join(formatted)
+
+
+def _extract_batch(
+ articles: List[NewsArticle],
+ start_idx: int,
+ llm,
+) -> List[EntityMention]:
+ if not articles:
+ return []
+
+ articles_text = _format_articles_for_prompt(articles, start_idx)
+ prompt = EXTRACTION_PROMPT.format(articles_text=articles_text)
+
+ structured_llm = llm.with_structured_output(ExtractionResponse)
+ response = structured_llm.invoke(prompt)
+
+ mentions = []
+ for entity in response.entities:
+ event_type_str = entity.event_type.lower().strip()
+ valid_event_types = {e.value for e in EventCategory}
+ if event_type_str not in valid_event_types:
+ event_type_str = "other"
+
+ confidence = max(0.0, min(1.0, entity.confidence))
+ sentiment = max(-1.0, min(1.0, entity.sentiment))
+
+ context = entity.context_snippet
+ if len(context) > 150:
+ context = context[:147] + "..."
+
+ mention = EntityMention(
+ company_name=entity.company_name,
+ confidence=confidence,
+ context_snippet=context,
+            article_id=f"article_{start_idx}",  # BUG: every mention in the batch gets the FIRST article's id; ExtractedEntity needs an article_id field so mentions are attributed per-article
+ event_type=EventCategory(event_type_str),
+ sentiment=sentiment,
+ )
+ mentions.append(mention)
+
+ return mentions
+
+
+def extract_entities(
+ articles: List[NewsArticle],
+ config: Optional[dict] = None,
+) -> List[EntityMention]:
+ if not articles:
+ return []
+
+ llm = _get_llm(config)
+ all_mentions: List[EntityMention] = []
+
+ for batch_start in range(0, len(articles), BATCH_SIZE):
+ batch_end = min(batch_start + BATCH_SIZE, len(articles))
+ batch = articles[batch_start:batch_end]
+
+ batch_mentions = _extract_batch(batch, batch_start, llm)
+ all_mentions.extend(batch_mentions)
+
+ return all_mentions
diff --git a/tradingagents/agents/discovery/exceptions.py b/tradingagents/agents/discovery/exceptions.py
new file mode 100644
index 00000000..94953c62
--- /dev/null
+++ b/tradingagents/agents/discovery/exceptions.py
@@ -0,0 +1,14 @@
+class DiscoveryError(Exception):
+ pass
+
+
+class NewsUnavailableError(DiscoveryError):
+ pass
+
+
+class DiscoveryTimeoutError(DiscoveryError):
+ pass
+
+
+class TickerResolutionError(DiscoveryError):
+ pass
diff --git a/tradingagents/agents/discovery/models.py b/tradingagents/agents/discovery/models.py
new file mode 100644
index 00000000..9595f89d
--- /dev/null
+++ b/tradingagents/agents/discovery/models.py
@@ -0,0 +1,180 @@
+from dataclasses import dataclass, field
+from datetime import datetime
+from enum import Enum
+from typing import List, Optional, Dict, Any
+
+
+class DiscoveryStatus(Enum):
+ CREATED = "created"
+ PROCESSING = "processing"
+ COMPLETED = "completed"
+ FAILED = "failed"
+
+
+class Sector(Enum):
+ TECHNOLOGY = "technology"
+ HEALTHCARE = "healthcare"
+ FINANCE = "finance"
+ ENERGY = "energy"
+ CONSUMER_GOODS = "consumer_goods"
+ INDUSTRIALS = "industrials"
+ OTHER = "other"
+
+
+class EventCategory(Enum):
+ EARNINGS = "earnings"
+ MERGER_ACQUISITION = "merger_acquisition"
+ REGULATORY = "regulatory"
+ PRODUCT_LAUNCH = "product_launch"
+ EXECUTIVE_CHANGE = "executive_change"
+ OTHER = "other"
+
+
+@dataclass
+class NewsArticle:
+ title: str
+ source: str
+ url: str
+ published_at: datetime
+ content_snippet: str
+ ticker_mentions: List[str]
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "title": self.title,
+ "source": self.source,
+ "url": self.url,
+ "published_at": self.published_at.isoformat(),
+ "content_snippet": self.content_snippet,
+ "ticker_mentions": self.ticker_mentions,
+ }
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> "NewsArticle":
+ return cls(
+ title=data["title"],
+ source=data["source"],
+ url=data["url"],
+ published_at=datetime.fromisoformat(data["published_at"]),
+ content_snippet=data["content_snippet"],
+ ticker_mentions=data["ticker_mentions"],
+ )
+
+
+@dataclass
+class TrendingStock:
+ ticker: str
+ company_name: str
+ score: float
+ mention_count: int
+ sentiment: float
+ sector: Sector
+ event_type: EventCategory
+ news_summary: str
+ source_articles: List[NewsArticle]
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "ticker": self.ticker,
+ "company_name": self.company_name,
+ "score": self.score,
+ "mention_count": self.mention_count,
+ "sentiment": self.sentiment,
+ "sector": self.sector.value,
+ "event_type": self.event_type.value,
+ "news_summary": self.news_summary,
+ "source_articles": [article.to_dict() for article in self.source_articles],
+ }
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> "TrendingStock":
+ return cls(
+ ticker=data["ticker"],
+ company_name=data["company_name"],
+ score=data["score"],
+ mention_count=data["mention_count"],
+ sentiment=data["sentiment"],
+ sector=Sector(data["sector"]),
+ event_type=EventCategory(data["event_type"]),
+ news_summary=data["news_summary"],
+ source_articles=[
+ NewsArticle.from_dict(article) for article in data["source_articles"]
+ ],
+ )
+
+
+@dataclass
+class DiscoveryRequest:
+ lookback_period: str
+ sector_filter: Optional[List[Sector]] = None
+ event_filter: Optional[List[EventCategory]] = None
+ max_results: int = 20
+ created_at: datetime = field(default_factory=datetime.now)
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "lookback_period": self.lookback_period,
+ "sector_filter": (
+ [s.value for s in self.sector_filter] if self.sector_filter else None
+ ),
+ "event_filter": (
+ [e.value for e in self.event_filter] if self.event_filter else None
+ ),
+ "max_results": self.max_results,
+ "created_at": self.created_at.isoformat(),
+ }
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> "DiscoveryRequest":
+ return cls(
+ lookback_period=data["lookback_period"],
+ sector_filter=(
+ [Sector(s) for s in data["sector_filter"]]
+ if data.get("sector_filter")
+ else None
+ ),
+ event_filter=(
+ [EventCategory(e) for e in data["event_filter"]]
+ if data.get("event_filter")
+ else None
+ ),
+ max_results=data.get("max_results", 20),
+ created_at=datetime.fromisoformat(data["created_at"]),
+ )
+
+
+@dataclass
+class DiscoveryResult:
+ request: DiscoveryRequest
+ trending_stocks: List[TrendingStock]
+ status: DiscoveryStatus
+ started_at: datetime
+ completed_at: Optional[datetime] = None
+ error_message: Optional[str] = None
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "request": self.request.to_dict(),
+ "trending_stocks": [stock.to_dict() for stock in self.trending_stocks],
+ "status": self.status.value,
+ "started_at": self.started_at.isoformat(),
+ "completed_at": self.completed_at.isoformat() if self.completed_at else None,
+ "error_message": self.error_message,
+ }
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, Any]) -> "DiscoveryResult":
+ return cls(
+ request=DiscoveryRequest.from_dict(data["request"]),
+ trending_stocks=[
+ TrendingStock.from_dict(stock) for stock in data["trending_stocks"]
+ ],
+ status=DiscoveryStatus(data["status"]),
+ started_at=datetime.fromisoformat(data["started_at"]),
+ completed_at=(
+ datetime.fromisoformat(data["completed_at"])
+ if data.get("completed_at")
+ else None
+ ),
+ error_message=data.get("error_message"),
+ )
diff --git a/tradingagents/agents/discovery/persistence.py b/tradingagents/agents/discovery/persistence.py
new file mode 100644
index 00000000..68f2bd03
--- /dev/null
+++ b/tradingagents/agents/discovery/persistence.py
@@ -0,0 +1,120 @@
+import json
+from datetime import datetime
+from pathlib import Path
+from typing import Optional
+
+from .models import DiscoveryResult, TrendingStock
+
+
+def save_discovery_result(
+ result: DiscoveryResult,
+ base_path: Optional[Path] = None,
+) -> Path:
+ if base_path is None:
+ base_path = Path("results")
+
+ timestamp = result.completed_at or result.started_at
+ date_str = timestamp.strftime("%Y-%m-%d")
+ time_str = timestamp.strftime("%H-%M-%S")
+
+ result_dir = base_path / "discovery" / date_str / time_str
+ result_dir.mkdir(parents=True, exist_ok=True)
+
+ json_path = result_dir / "discovery_result.json"
+    with open(json_path, "w", encoding="utf-8") as f:
+        json.dump(result.to_dict(), f, indent=2, ensure_ascii=False)
+
+ md_path = result_dir / "discovery_summary.md"
+ markdown_content = generate_markdown_summary(result)
+    with open(md_path, "w", encoding="utf-8") as f:
+ f.write(markdown_content)
+
+ return result_dir
+
+
+def generate_markdown_summary(result: DiscoveryResult) -> str:
+ lines = []
+
+ lines.append("# Discovery Results")
+ lines.append("")
+
+ timestamp = result.completed_at or result.started_at
+ lines.append(f"**Timestamp:** {timestamp.strftime('%Y-%m-%d %H:%M:%S')}")
+ lines.append(f"**Lookback Period:** {result.request.lookback_period}")
+
+ filters = _format_filters(result)
+ lines.append(f"**Filters:** {filters}")
+ lines.append(f"**Total Stocks Found:** {len(result.trending_stocks)}")
+ lines.append("")
+
+ lines.append("## Trending Stocks")
+ lines.append("")
+ lines.append("| Rank | Ticker | Company | Score | Mentions | Event |")
+ lines.append("|------|--------|---------|-------|----------|-------|")
+
+ for rank, stock in enumerate(result.trending_stocks, 1):
+ lines.append(
+ f"| {rank} | {stock.ticker} | {stock.company_name} | "
+ f"{stock.score:.2f} | {stock.mention_count} | {stock.event_type.value} |"
+ )
+
+ lines.append("")
+
+ lines.append("## Top 3 Detailed Analysis")
+ lines.append("")
+
+ top_stocks = result.trending_stocks[:3]
+ for rank, stock in enumerate(top_stocks, 1):
+ lines.extend(_format_stock_detail(rank, stock))
+
+ return "\n".join(lines)
+
+
+def _format_filters(result: DiscoveryResult) -> str:
+ filter_parts = []
+
+ if result.request.sector_filter:
+ sector_values = [s.value for s in result.request.sector_filter]
+ filter_parts.append(f"sector={','.join(sector_values)}")
+
+ if result.request.event_filter:
+ event_values = [e.value for e in result.request.event_filter]
+ filter_parts.append(f"event={','.join(event_values)}")
+
+ if filter_parts:
+ return " ".join(filter_parts)
+ return "None"
+
+
+def _format_stock_detail(rank: int, stock: TrendingStock) -> list:
+ lines = []
+
+ lines.append(f"### {rank}. {stock.ticker} - {stock.company_name}")
+ lines.append(f"- **Score:** {stock.score:.2f}")
+
+ sentiment_label = _get_sentiment_label(stock.sentiment)
+ lines.append(f"- **Sentiment:** {stock.sentiment:.2f} ({sentiment_label})")
+ lines.append(f"- **Sector:** {stock.sector.value}")
+ lines.append(f"- **Event Type:** {stock.event_type.value}")
+ lines.append(f"- **Mentions:** {stock.mention_count}")
+ lines.append("")
+
+ lines.append("**News Summary:**")
+ lines.append(stock.news_summary)
+ lines.append("")
+
+ if stock.source_articles:
+ lines.append("**Top Sources:**")
+ for article in stock.source_articles[:3]:
+            lines.append(f"- [{article.title}]({article.url}) - {article.source}")
+ lines.append("")
+
+ return lines
+
+
+def _get_sentiment_label(sentiment: float) -> str:
+ if sentiment > 0.3:
+ return "positive"
+ elif sentiment < -0.3:
+ return "negative"
+ return "neutral"
diff --git a/tradingagents/agents/discovery/scorer.py b/tradingagents/agents/discovery/scorer.py
new file mode 100644
index 00000000..564bc717
--- /dev/null
+++ b/tradingagents/agents/discovery/scorer.py
@@ -0,0 +1,153 @@
+import math
+from collections import defaultdict
+from datetime import datetime
+from typing import List, Dict, Optional
+
+from tradingagents.agents.discovery.models import (
+ TrendingStock,
+ NewsArticle,
+ Sector,
+ EventCategory,
+)
+from tradingagents.agents.discovery.entity_extractor import EntityMention
+from tradingagents.dataflows.trending.stock_resolver import resolve_ticker
+from tradingagents.dataflows.trending.sector_classifier import classify_sector
+
+
+DEFAULT_DECAY_RATE = 0.1
+DEFAULT_MAX_RESULTS = 20
+DEFAULT_MIN_MENTIONS = 2
+
+
+def _aggregate_sentiment(mentions: List[EntityMention]) -> float:
+ if not mentions:
+ return 0.0
+
+ total_weighted_sentiment = 0.0
+ total_confidence = 0.0
+
+ for mention in mentions:
+ total_weighted_sentiment += mention.sentiment * mention.confidence
+ total_confidence += mention.confidence
+
+ if total_confidence == 0:
+ return 0.0
+
+ return total_weighted_sentiment / total_confidence
+
+
+def _calculate_recency_weight(
+ articles: List[NewsArticle],
+ article_ids: set,
+ decay_rate: float,
+) -> float:
+ if not articles:
+ return 1.0
+
+ now = datetime.now()
+ weights = []
+
+ for i, article in enumerate(articles):
+ article_id = f"article_{i}"
+ if article_id in article_ids:
+ hours_old = (now - article.published_at).total_seconds() / 3600.0
+ weight = math.exp(-decay_rate * hours_old)
+ weights.append(weight)
+
+ if not weights:
+ return 1.0
+
+ return sum(weights) / len(weights)
+
+
+def _get_most_common_event_type(mentions: List[EntityMention]) -> EventCategory:
+ if not mentions:
+ return EventCategory.OTHER
+
+ event_counts: Dict[EventCategory, int] = defaultdict(int)
+ for mention in mentions:
+ event_counts[mention.event_type] += 1
+
+ return max(event_counts.keys(), key=lambda e: event_counts[e])
+
+
+def _build_news_summary(mentions: List[EntityMention]) -> str:
+ if not mentions:
+ return ""
+
+ snippets = [m.context_snippet for m in mentions[:3]]
+ return " ".join(snippets)
+
+
+def calculate_trending_scores(
+ mentions: List[EntityMention],
+ articles: List[NewsArticle],
+ decay_rate: float = DEFAULT_DECAY_RATE,
+ max_results: int = DEFAULT_MAX_RESULTS,
+ min_mentions: int = DEFAULT_MIN_MENTIONS,
+) -> List[TrendingStock]:
+ if not mentions:
+ return []
+
+ ticker_mentions: Dict[str, List[EntityMention]] = defaultdict(list)
+ ticker_company_names: Dict[str, str] = {}
+
+ for mention in mentions:
+ ticker = resolve_ticker(mention.company_name)
+ if ticker:
+ ticker_mentions[ticker].append(mention)
+ if ticker not in ticker_company_names:
+ ticker_company_names[ticker] = mention.company_name
+
+ article_index: Dict[str, int] = {}
+ for i, article in enumerate(articles):
+ article_index[f"article_{i}"] = i
+
+ trending_stocks: List[TrendingStock] = []
+
+ for ticker, ticker_mention_list in ticker_mentions.items():
+ article_ids = {m.article_id for m in ticker_mention_list}
+ frequency = len(article_ids)
+
+ if frequency < min_mentions:
+ continue
+
+ sentiment = _aggregate_sentiment(ticker_mention_list)
+ sentiment_factor = 1 + abs(sentiment)
+
+ recency_weight = _calculate_recency_weight(articles, article_ids, decay_rate)
+
+ score = frequency * sentiment_factor * recency_weight
+
+ sector_str = classify_sector(ticker)
+ try:
+ sector = Sector(sector_str)
+ except ValueError:
+ sector = Sector.OTHER
+
+ event_type = _get_most_common_event_type(ticker_mention_list)
+
+ source_article_list: List[NewsArticle] = []
+ for article_id in article_ids:
+ idx = article_index.get(article_id)
+ if idx is not None and idx < len(articles):
+ source_article_list.append(articles[idx])
+
+ news_summary = _build_news_summary(ticker_mention_list)
+
+ trending_stock = TrendingStock(
+ ticker=ticker,
+ company_name=ticker_company_names.get(ticker, ticker),
+ score=score,
+ mention_count=frequency,
+ sentiment=sentiment,
+ sector=sector,
+ event_type=event_type,
+ news_summary=news_summary,
+ source_articles=source_article_list,
+ )
+ trending_stocks.append(trending_stock)
+
+ trending_stocks.sort(key=lambda s: s.score, reverse=True)
+
+ return trending_stocks[:max_results]
diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py
index 3a859ea1..4a1ce0ce 100644
--- a/tradingagents/agents/utils/agent_states.py
+++ b/tradingagents/agents/utils/agent_states.py
@@ -7,44 +7,44 @@ from langgraph.prebuilt import ToolNode
from langgraph.graph import END, StateGraph, START, MessagesState
-# Researcher team state
class InvestDebateState(TypedDict):
+ """Researcher team state"""
bull_history: Annotated[
str, "Bullish Conversation history"
- ] # Bullish Conversation history
+ ]
bear_history: Annotated[
str, "Bearish Conversation history"
- ] # Bullish Conversation history
- history: Annotated[str, "Conversation history"] # Conversation history
- current_response: Annotated[str, "Latest response"] # Last response
- judge_decision: Annotated[str, "Final judge decision"] # Last response
- count: Annotated[int, "Length of the current conversation"] # Conversation length
+ ]
+ history: Annotated[str, "Conversation history"]
+ current_response: Annotated[str, "Latest response"]
+ judge_decision: Annotated[str, "Final judge decision"]
+ count: Annotated[int, "Length of the current conversation"]
-# Risk management team state
class RiskDebateState(TypedDict):
+ """Risk management team state"""
risky_history: Annotated[
str, "Risky Agent's Conversation history"
- ] # Conversation history
+ ]
safe_history: Annotated[
str, "Safe Agent's Conversation history"
- ] # Conversation history
+ ]
neutral_history: Annotated[
str, "Neutral Agent's Conversation history"
- ] # Conversation history
- history: Annotated[str, "Conversation history"] # Conversation history
+ ]
+ history: Annotated[str, "Conversation history"]
latest_speaker: Annotated[str, "Analyst that spoke last"]
current_risky_response: Annotated[
str, "Latest response by the risky analyst"
- ] # Last response
+ ]
current_safe_response: Annotated[
str, "Latest response by the safe analyst"
- ] # Last response
+ ]
current_neutral_response: Annotated[
str, "Latest response by the neutral analyst"
- ] # Last response
+ ]
judge_decision: Annotated[str, "Judge's decision"]
- count: Annotated[int, "Length of the current conversation"] # Conversation length
+ count: Annotated[int, "Length of the current conversation"]
class AgentState(MessagesState):
@@ -53,7 +53,7 @@ class AgentState(MessagesState):
sender: Annotated[str, "Agent that sent this message"]
- # research step
+ # research
market_report: Annotated[str, "Report from the Market Analyst"]
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
news_report: Annotated[
@@ -61,7 +61,7 @@ class AgentState(MessagesState):
]
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
- # researcher team discussion step
+ # research
investment_debate_state: Annotated[
InvestDebateState, "Current state of the debate on if to invest or not"
]
@@ -69,7 +69,7 @@ class AgentState(MessagesState):
trader_investment_plan: Annotated[str, "Plan generated by the Trader"]
- # risk management team discussion step
+ # risk mgmt
risk_debate_state: Annotated[
RiskDebateState, "Current state of the debate on evaluating risk"
]
diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py
index 6cf294a1..6f01dc32 100644
--- a/tradingagents/agents/utils/agent_utils.py
+++ b/tradingagents/agents/utils/agent_utils.py
@@ -1,6 +1,5 @@
from langchain_core.messages import HumanMessage, RemoveMessage
-# Import tools from separate utility files
from tradingagents.agents.utils.core_stock_tools import (
get_stock_data
)
@@ -24,16 +23,7 @@ def create_msg_delete():
def delete_messages(state):
"""Clear messages and add placeholder for Anthropic compatibility"""
messages = state["messages"]
-
- # Remove all messages
removal_operations = [RemoveMessage(id=m.id) for m in messages]
-
- # Add a minimal placeholder message
placeholder = HumanMessage(content="Continue")
-
return {"messages": removal_operations + [placeholder]}
-
return delete_messages
-
-
-
\ No newline at end of file
diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py
index 69b8ab8c..9146313e 100644
--- a/tradingagents/agents/utils/memory.py
+++ b/tradingagents/agents/utils/memory.py
@@ -67,47 +67,45 @@ class FinancialSituationMemory:
return matched_results
-if __name__ == "__main__":
- # Example usage
- matcher = FinancialSituationMemory()
+# if __name__ == "__main__":
+# # Example usage
+# matcher = FinancialSituationMemory()
+# example_data = [
+# (
+# "High inflation rate with rising interest rates and declining consumer spending",
+# "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
+# ),
+# (
+# "Tech sector showing high volatility with increasing institutional selling pressure",
+# "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
+# ),
+# (
+# "Strong dollar affecting emerging markets with increasing forex volatility",
+# "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
+# ),
+# (
+# "Market showing signs of sector rotation with rising yields",
+# "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
+# ),
+# ]
- # Example data
- example_data = [
- (
- "High inflation rate with rising interest rates and declining consumer spending",
- "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
- ),
- (
- "Tech sector showing high volatility with increasing institutional selling pressure",
- "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
- ),
- (
- "Strong dollar affecting emerging markets with increasing forex volatility",
- "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
- ),
- (
- "Market showing signs of sector rotation with rising yields",
- "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
- ),
- ]
+# # Add the example situations and recommendations
+# matcher.add_situations(example_data)
- # Add the example situations and recommendations
- matcher.add_situations(example_data)
+# # Example query
+# current_situation = """
+# Market showing increased volatility in tech sector, with institutional investors
+# reducing positions and rising interest rates affecting growth stock valuations
+# """
- # Example query
- current_situation = """
- Market showing increased volatility in tech sector, with institutional investors
- reducing positions and rising interest rates affecting growth stock valuations
- """
+# try:
+# recommendations = matcher.get_memories(current_situation, n_matches=2)
- try:
- recommendations = matcher.get_memories(current_situation, n_matches=2)
+# for i, rec in enumerate(recommendations, 1):
+# print(f"\nMatch {i}:")
+# print(f"Similarity Score: {rec['similarity_score']:.2f}")
+# print(f"Matched Situation: {rec['matched_situation']}")
+# print(f"Recommendation: {rec['recommendation']}")
- for i, rec in enumerate(recommendations, 1):
- print(f"\nMatch {i}:")
- print(f"Similarity Score: {rec['similarity_score']:.2f}")
- print(f"Matched Situation: {rec['matched_situation']}")
- print(f"Recommendation: {rec['recommendation']}")
-
- except Exception as e:
- print(f"Error during recommendation: {str(e)}")
+# except Exception as e:
+# print(f"Error during recommendation: {str(e)}")
diff --git a/tradingagents/dataflows/alpha_vantage_news.py b/tradingagents/dataflows/alpha_vantage_news.py
index 8124fb45..984bd4e0 100644
--- a/tradingagents/dataflows/alpha_vantage_news.py
+++ b/tradingagents/dataflows/alpha_vantage_news.py
@@ -1,19 +1,9 @@
+import json
+from datetime import datetime, timedelta
+from typing import List, Dict, Any
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
- """Returns live and historical market news & sentiment data from premier news outlets worldwide.
-
- Covers stocks, cryptocurrencies, forex, and topics like fiscal policy, mergers & acquisitions, IPOs.
-
- Args:
- ticker: Stock symbol for news articles.
- start_date: Start date for news search.
- end_date: End date for news search.
-
- Returns:
- Dictionary containing news sentiment data or JSON string.
- """
-
params = {
"tickers": ticker,
"time_from": format_datetime_for_api(start_date),
@@ -21,23 +11,63 @@ def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
"sort": "LATEST",
"limit": "50",
}
-
+
return _make_api_request("NEWS_SENTIMENT", params)
def get_insider_transactions(symbol: str) -> dict[str, str] | str:
- """Returns latest and historical insider transactions by key stakeholders.
-
- Covers transactions by founders, executives, board members, etc.
-
- Args:
- symbol: Ticker symbol. Example: "IBM".
-
- Returns:
- Dictionary containing insider transaction data or JSON string.
- """
-
params = {
"symbol": symbol,
}
- return _make_api_request("INSIDER_TRANSACTIONS", params)
\ No newline at end of file
+ return _make_api_request("INSIDER_TRANSACTIONS", params)
+
+
+def get_bulk_news_alpha_vantage(lookback_hours: int) -> List[Dict[str, Any]]:
+ end_date = datetime.now()
+ start_date = end_date - timedelta(hours=lookback_hours)
+
+ params = {
+ "time_from": format_datetime_for_api(start_date),
+ "time_to": format_datetime_for_api(end_date),
+ "sort": "LATEST",
+ "limit": "200",
+ "topics": "financial_markets,earnings,economy_fiscal,economy_monetary,mergers_and_acquisitions",
+ }
+
+ response = _make_api_request("NEWS_SENTIMENT", params)
+
+ if isinstance(response, str):
+ try:
+ response = json.loads(response)
+ except json.JSONDecodeError:
+ return []
+
+ if not isinstance(response, dict):
+ return []
+
+ feed = response.get("feed", [])
+
+ articles = []
+ for item in feed:
+ try:
+ time_published = item.get("time_published", "")
+ if time_published:
+ try:
+ published_at = datetime.strptime(time_published, "%Y%m%dT%H%M%S")
+ except ValueError:
+ published_at = datetime.now()
+ else:
+ published_at = datetime.now()
+
+ article = {
+ "title": item.get("title", ""),
+ "source": item.get("source", ""),
+ "url": item.get("url", ""),
+ "published_at": published_at.isoformat(),
+ "content_snippet": item.get("summary", "")[:500],
+ }
+ articles.append(article)
+ except Exception:
+ continue
+
+ return articles
diff --git a/tradingagents/dataflows/google.py b/tradingagents/dataflows/google.py
index 3fe20f3c..80037ed4 100644
--- a/tradingagents/dataflows/google.py
+++ b/tradingagents/dataflows/google.py
@@ -1,5 +1,5 @@
-from typing import Annotated
-from datetime import datetime
+from typing import Annotated, List, Dict, Any
+from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from .googlenews_utils import getNewsData
@@ -27,4 +27,53 @@ def get_google_news(
if len(news_results) == 0:
return ""
- return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"
\ No newline at end of file
+ return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}"
+
+
+def get_bulk_news_google(lookback_hours: int) -> List[Dict[str, Any]]:
+ end_date = datetime.now()
+ start_date = end_date - timedelta(hours=lookback_hours)
+
+ start_str = start_date.strftime("%Y-%m-%d")
+ end_str = end_date.strftime("%Y-%m-%d")
+
+ queries = [
+ "stock market",
+ "trading news",
+ "earnings report",
+ ]
+
+ all_articles = []
+ seen_titles = set()
+
+ for query in queries:
+ try:
+ news_results = getNewsData(query.replace(" ", "+"), start_str, end_str)
+
+ for news in news_results:
+ title = news.get("title", "")
+ if title and title not in seen_titles:
+ seen_titles.add(title)
+
+                    date_str = news.get("date", "")
+                    try:
+                        if date_str:
+                            published_at = datetime.strptime(date_str, "%b %d, %Y")
+                        else:
+                            published_at = datetime.now()
+                    except ValueError:
+                        published_at = datetime.now()
+
+ article = {
+ "title": title,
+ "source": news.get("source", "Google News"),
+ "url": news.get("link", ""),
+ "published_at": published_at.isoformat(),
+ "content_snippet": news.get("snippet", "")[:500],
+ }
+ all_articles.append(article)
+
+ except Exception:
+ continue
+
+ return all_articles
diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py
index 4cd5ddef..c49909fa 100644
--- a/tradingagents/dataflows/interface.py
+++ b/tradingagents/dataflows/interface.py
@@ -1,10 +1,10 @@
-from typing import Annotated
+from typing import Annotated, List, Dict, Any, Optional
+from datetime import datetime, timedelta
-# Import from vendor-specific modules
from .local import get_YFin_data, get_finnhub_news, get_finnhub_company_insider_sentiment, get_finnhub_company_insider_transactions, get_simfin_balance_sheet, get_simfin_cashflow, get_simfin_income_statements, get_reddit_global_news, get_reddit_company_news
from .y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
-from .google import get_google_news
-from .openai import get_stock_news_openai, get_global_news_openai, get_fundamentals_openai
+from .google import get_google_news, get_bulk_news_google
+from .openai import get_stock_news_openai, get_global_news_openai, get_fundamentals_openai, get_bulk_news_openai
from .alpha_vantage import (
get_stock as get_alpha_vantage_stock,
get_indicator as get_alpha_vantage_indicator,
@@ -15,12 +15,13 @@ from .alpha_vantage import (
get_insider_transactions as get_alpha_vantage_insider_transactions,
get_news as get_alpha_vantage_news
)
+from .alpha_vantage_news import get_bulk_news_alpha_vantage
from .alpha_vantage_common import AlphaVantageRateLimitError
-# Configuration and routing logic
from .config import get_config
-# Tools organized by category
+from tradingagents.agents.discovery import NewsArticle
+
TOOLS_CATEGORIES = {
"core_stock_apis": {
"description": "OHLCV stock price data",
@@ -50,6 +51,7 @@ TOOLS_CATEGORIES = {
"get_global_news",
"get_insider_sentiment",
"get_insider_transactions",
+ "get_bulk_news",
]
}
}
@@ -61,21 +63,17 @@ VENDOR_LIST = [
"google"
]
-# Mapping of methods to their vendor-specific implementations
VENDOR_METHODS = {
- # core_stock_apis
"get_stock_data": {
"alpha_vantage": get_alpha_vantage_stock,
"yfinance": get_YFin_data_online,
"local": get_YFin_data,
},
- # technical_indicators
"get_indicators": {
"alpha_vantage": get_alpha_vantage_indicator,
"yfinance": get_stock_stats_indicators_window,
"local": get_stock_stats_indicators_window
},
- # fundamental_data
"get_fundamentals": {
"alpha_vantage": get_alpha_vantage_fundamentals,
"openai": get_fundamentals_openai,
@@ -95,7 +93,6 @@ VENDOR_METHODS = {
"yfinance": get_yfinance_income_statement,
"local": get_simfin_income_statements,
},
- # news_data
"get_news": {
"alpha_vantage": get_alpha_vantage_news,
"openai": get_stock_news_openai,
@@ -114,56 +111,159 @@ VENDOR_METHODS = {
"yfinance": get_yfinance_insider_transactions,
"local": get_finnhub_company_insider_transactions,
},
+ "get_bulk_news": {
+ "alpha_vantage": get_bulk_news_alpha_vantage,
+ "openai": get_bulk_news_openai,
+ "google": get_bulk_news_google,
+ },
}
+CACHE_TTL_SECONDS = 300
+
+_bulk_news_cache: Dict[str, Dict[str, Any]] = {}
+
+
+def parse_lookback_period(lookback: str) -> int:
+ lookback = lookback.lower().strip()
+
+ if lookback == "1h":
+ return 1
+ elif lookback == "6h":
+ return 6
+ elif lookback == "24h":
+ return 24
+ elif lookback == "7d":
+ return 168
+ else:
+ raise ValueError(f"Invalid lookback period: {lookback}. Valid values: 1h, 6h, 24h, 7d")
+
+
+def _get_cached_bulk_news(lookback_period: str) -> Optional[List[NewsArticle]]:
+ cache_key = lookback_period
+ if cache_key in _bulk_news_cache:
+ cached = _bulk_news_cache[cache_key]
+ cached_time = cached.get("timestamp")
+ if cached_time and (datetime.now() - cached_time).total_seconds() < CACHE_TTL_SECONDS:
+ return cached.get("articles")
+ return None
+
+
+def _set_cached_bulk_news(lookback_period: str, articles: List[NewsArticle]) -> None:
+ cache_key = lookback_period
+ _bulk_news_cache[cache_key] = {
+ "timestamp": datetime.now(),
+ "articles": articles,
+ }
+
+
+def _convert_to_news_articles(raw_articles: List[Dict[str, Any]]) -> List[NewsArticle]:
+ articles = []
+ for item in raw_articles:
+ try:
+            published_raw = item.get("published_at", "")
+            if isinstance(published_raw, str):
+                try:
+                    published_at = datetime.fromisoformat(published_raw.replace("Z", "+00:00"))
+                except ValueError:
+                    published_at = datetime.now()
+            elif isinstance(published_raw, datetime):
+                published_at = published_raw
+ else:
+ published_at = datetime.now()
+
+ article = NewsArticle(
+ title=item.get("title", ""),
+ source=item.get("source", ""),
+ url=item.get("url", ""),
+ published_at=published_at,
+ content_snippet=item.get("content_snippet", ""),
+ ticker_mentions=[],
+ )
+ articles.append(article)
+ except Exception:
+ continue
+ return articles
+
+
+def _fetch_bulk_news_from_vendor(lookback_period: str) -> List[Dict[str, Any]]:
+ lookback_hours = parse_lookback_period(lookback_period)
+
+ vendor_order = ["alpha_vantage", "openai", "google"]
+
+ for vendor in vendor_order:
+ if vendor not in VENDOR_METHODS["get_bulk_news"]:
+ continue
+
+ vendor_func = VENDOR_METHODS["get_bulk_news"][vendor]
+
+ try:
+ print(f"DEBUG: Attempting bulk news from vendor '{vendor}'...")
+ result = vendor_func(lookback_hours)
+ if result:
+ print(f"SUCCESS: Got {len(result)} articles from vendor '{vendor}'")
+ return result
+ print(f"DEBUG: Vendor '{vendor}' returned empty results, trying next...")
+ except AlphaVantageRateLimitError as e:
+ print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded: {e}")
+ continue
+ except Exception as e:
+ print(f"FAILED: Vendor '{vendor}' failed: {e}")
+ continue
+
+ return []
+
+
+def get_bulk_news(lookback_period: str = "24h") -> List[NewsArticle]:
+ cached = _get_cached_bulk_news(lookback_period)
+ if cached is not None:
+ print(f"DEBUG: Returning cached bulk news for period '{lookback_period}'")
+ return cached
+
+ raw_articles = _fetch_bulk_news_from_vendor(lookback_period)
+
+ articles = _convert_to_news_articles(raw_articles)
+
+ _set_cached_bulk_news(lookback_period, articles)
+
+ return articles
+
+
def get_category_for_method(method: str) -> str:
- """Get the category that contains the specified method."""
for category, info in TOOLS_CATEGORIES.items():
if method in info["tools"]:
return category
raise ValueError(f"Method '{method}' not found in any category")
def get_vendor(category: str, method: str = None) -> str:
- """Get the configured vendor for a data category or specific tool method.
- Tool-level configuration takes precedence over category-level.
- """
config = get_config()
- # Check tool-level configuration first (if method provided)
if method:
tool_vendors = config.get("tool_vendors", {})
if method in tool_vendors:
return tool_vendors[method]
- # Fall back to category-level configuration
return config.get("data_vendors", {}).get(category, "default")
def route_to_vendor(method: str, *args, **kwargs):
- """Route method calls to appropriate vendor implementation with fallback support."""
category = get_category_for_method(method)
vendor_config = get_vendor(category, method)
- # Handle comma-separated vendors
primary_vendors = [v.strip() for v in vendor_config.split(',')]
if method not in VENDOR_METHODS:
raise ValueError(f"Method '{method}' not supported")
- # Get all available vendors for this method for fallback
all_available_vendors = list(VENDOR_METHODS[method].keys())
-
- # Create fallback vendor list: primary vendors first, then remaining vendors as fallbacks
+
fallback_vendors = primary_vendors.copy()
for vendor in all_available_vendors:
if vendor not in fallback_vendors:
fallback_vendors.append(vendor)
- # Debug: Print fallback ordering
- primary_str = " → ".join(primary_vendors)
- fallback_str = " → ".join(fallback_vendors)
+ primary_str = " -> ".join(primary_vendors)
+ fallback_str = " -> ".join(fallback_vendors)
print(f"DEBUG: {method} - Primary: [{primary_str}] | Full fallback order: [{fallback_str}]")
- # Track results and execution state
results = []
vendor_attempt_count = 0
any_primary_vendor_attempted = False
@@ -179,22 +279,18 @@ def route_to_vendor(method: str, *args, **kwargs):
is_primary_vendor = vendor in primary_vendors
vendor_attempt_count += 1
- # Track if we attempted any primary vendor
if is_primary_vendor:
any_primary_vendor_attempted = True
- # Debug: Print current attempt
vendor_type = "PRIMARY" if is_primary_vendor else "FALLBACK"
print(f"DEBUG: Attempting {vendor_type} vendor '{vendor}' for {method} (attempt #{vendor_attempt_count})")
- # Handle list of methods for a vendor
if isinstance(vendor_impl, list):
vendor_methods = [(impl, vendor) for impl in vendor_impl]
print(f"DEBUG: Vendor '{vendor}' has multiple implementations: {len(vendor_methods)} functions")
else:
vendor_methods = [(vendor_impl, vendor)]
- # Run methods for this vendor
vendor_results = []
for impl_func, vendor_name in vendor_methods:
try:
@@ -202,43 +298,35 @@ def route_to_vendor(method: str, *args, **kwargs):
result = impl_func(*args, **kwargs)
vendor_results.append(result)
print(f"SUCCESS: {impl_func.__name__} from vendor '{vendor_name}' completed successfully")
-
+
except AlphaVantageRateLimitError as e:
if vendor == "alpha_vantage":
print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded, falling back to next available vendor")
print(f"DEBUG: Rate limit details: {e}")
- # Continue to next vendor for fallback
continue
except Exception as e:
- # Log error but continue with other implementations
print(f"FAILED: {impl_func.__name__} from vendor '{vendor_name}' failed: {e}")
continue
- # Add this vendor's results
if vendor_results:
results.extend(vendor_results)
successful_vendor = vendor
result_summary = f"Got {len(vendor_results)} result(s)"
print(f"SUCCESS: Vendor '{vendor}' succeeded - {result_summary}")
-
- # Stopping logic: Stop after first successful vendor for single-vendor configs
- # Multiple vendor configs (comma-separated) may want to collect from multiple sources
+
if len(primary_vendors) == 1:
print(f"DEBUG: Stopping after successful vendor '{vendor}' (single-vendor config)")
break
else:
print(f"FAILED: Vendor '{vendor}' produced no results")
- # Final result summary
if not results:
print(f"FAILURE: All {vendor_attempt_count} vendor attempts failed for method '{method}'")
raise RuntimeError(f"All vendor implementations failed for method '{method}'")
else:
print(f"FINAL: Method '{method}' completed with {len(results)} result(s) from {vendor_attempt_count} vendor attempt(s)")
- # Return single result if only one, otherwise concatenate as string
if len(results) == 1:
return results[0]
else:
- # Convert all results to strings and concatenate
- return '\n'.join(str(result) for result in results)
\ No newline at end of file
+ return '\n'.join(str(result) for result in results)
diff --git a/tradingagents/dataflows/openai.py b/tradingagents/dataflows/openai.py
index 91a2258b..d9439e14 100644
--- a/tradingagents/dataflows/openai.py
+++ b/tradingagents/dataflows/openai.py
@@ -1,3 +1,7 @@
+import json
+import re
+from datetime import datetime, timedelta
+from typing import List, Dict, Any
from openai import OpenAI
from .config import get_config
@@ -104,4 +108,91 @@ def get_fundamentals_openai(ticker, curr_date):
store=True,
)
- return response.output[1].content[0].text
\ No newline at end of file
+ return response.output[1].content[0].text
+
+
+def get_bulk_news_openai(lookback_hours: int) -> List[Dict[str, Any]]:
+ config = get_config()
+ client = OpenAI(base_url=config["backend_url"])
+
+ end_date = datetime.now()
+ start_date = end_date - timedelta(hours=lookback_hours)
+
+ start_str = start_date.strftime("%Y-%m-%d %H:%M")
+ end_str = end_date.strftime("%Y-%m-%d %H:%M")
+
+ prompt = f"""Search for recent stock market news, trading news, and earnings announcements from {start_str} to {end_str}.
+
+Return the results as a JSON array with the following structure:
+[
+ {{
+ "title": "Article title",
+ "source": "Source name",
+ "url": "https://...",
+ "published_at": "YYYY-MM-DDTHH:MM:SS",
+ "content_snippet": "Brief summary of the article..."
+ }}
+]
+
+Focus on:
+- Stock market movements and trends
+- Company earnings reports
+- Mergers and acquisitions
+- Significant trading activity
+- Economic news affecting markets
+
+Return ONLY the JSON array, no additional text."""
+
+ response = client.responses.create(
+ model=config["quick_think_llm"],
+ input=[
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "input_text",
+ "text": prompt,
+ }
+ ],
+ }
+ ],
+ text={"format": {"type": "text"}},
+ reasoning={},
+ tools=[
+ {
+ "type": "web_search_preview",
+ "user_location": {"type": "approximate"},
+ "search_context_size": "medium",
+ }
+ ],
+ temperature=0.5,
+ max_output_tokens=8192,
+ top_p=1,
+ store=True,
+ )
+
+ try:
+ response_text = response.output[1].content[0].text
+
+ json_match = re.search(r'\[[\s\S]*\]', response_text)
+ if json_match:
+ articles = json.loads(json_match.group())
+ else:
+ articles = json.loads(response_text)
+
+ result = []
+ for item in articles:
+ if isinstance(item, dict):
+ article = {
+ "title": item.get("title", ""),
+ "source": item.get("source", "Web Search"),
+ "url": item.get("url", ""),
+ "published_at": item.get("published_at", datetime.now().isoformat()),
+ "content_snippet": item.get("content_snippet", "")[:500],
+ }
+ result.append(article)
+
+ return result
+
+ except (json.JSONDecodeError, IndexError, AttributeError):
+ return []
diff --git a/tradingagents/dataflows/trending/__init__.py b/tradingagents/dataflows/trending/__init__.py
new file mode 100644
index 00000000..190a1a32
--- /dev/null
+++ b/tradingagents/dataflows/trending/__init__.py
@@ -0,0 +1,21 @@
+from .stock_resolver import (
+ resolve_ticker,
+ validate_tradeable,
+ validate_us_ticker,
+ COMPANY_TO_TICKER,
+)
+from .sector_classifier import (
+ classify_sector,
+ TICKER_TO_SECTOR,
+ VALID_SECTORS,
+)
+
+__all__ = [
+ "resolve_ticker",
+ "validate_tradeable",
+ "validate_us_ticker",
+ "COMPANY_TO_TICKER",
+ "classify_sector",
+ "TICKER_TO_SECTOR",
+ "VALID_SECTORS",
+]
diff --git a/tradingagents/dataflows/trending/sector_classifier.py b/tradingagents/dataflows/trending/sector_classifier.py
new file mode 100644
index 00000000..999ec3aa
--- /dev/null
+++ b/tradingagents/dataflows/trending/sector_classifier.py
@@ -0,0 +1,267 @@
+import logging
+from typing import Dict
+
+logger = logging.getLogger(__name__)
+
+VALID_SECTORS = {
+ "technology",
+ "healthcare",
+ "finance",
+ "energy",
+ "consumer_goods",
+ "industrials",
+ "other",
+}
+
+TICKER_TO_SECTOR: Dict[str, str] = {
+ "AAPL": "technology",
+ "MSFT": "technology",
+ "GOOGL": "technology",
+ "GOOG": "technology",
+ "AMZN": "technology",
+ "META": "technology",
+ "NVDA": "technology",
+ "TSLA": "technology",
+ "AMD": "technology",
+ "INTC": "technology",
+ "QCOM": "technology",
+ "AVGO": "technology",
+ "TXN": "technology",
+ "ADBE": "technology",
+ "CRM": "technology",
+ "CSCO": "technology",
+ "NFLX": "technology",
+ "ORCL": "technology",
+ "IBM": "technology",
+ "NOW": "technology",
+ "INTU": "technology",
+ "ADSK": "technology",
+ "SNPS": "technology",
+ "CDNS": "technology",
+ "PLTR": "technology",
+ "SNOW": "technology",
+ "DDOG": "technology",
+ "CRWD": "technology",
+ "OKTA": "technology",
+ "NET": "technology",
+ "MDB": "technology",
+ "TWLO": "technology",
+ "WDAY": "technology",
+ "SPLK": "technology",
+ "VMW": "technology",
+ "HPQ": "technology",
+ "DELL": "technology",
+ "FTNT": "technology",
+ "PANW": "technology",
+ "ZS": "technology",
+ "S": "technology",
+ "VEEV": "technology",
+ "ZM": "technology",
+ "DOCU": "technology",
+ "ASAN": "technology",
+ "MNDY": "technology",
+ "TEAM": "technology",
+ "ANSS": "technology",
+ "ROP": "technology",
+ "JPM": "finance",
+ "BAC": "finance",
+ "WFC": "finance",
+ "GS": "finance",
+ "MS": "finance",
+ "C": "finance",
+ "BLK": "finance",
+ "SCHW": "finance",
+ "AXP": "finance",
+ "V": "finance",
+ "MA": "finance",
+ "PYPL": "finance",
+ "SQ": "finance",
+ "COIN": "finance",
+ "HOOD": "finance",
+ "SOFI": "finance",
+ "AFRM": "finance",
+ "MQ": "finance",
+ "BRK-B": "finance",
+ "BRK-A": "finance",
+ "JNJ": "healthcare",
+ "UNH": "healthcare",
+ "PFE": "healthcare",
+ "ABBV": "healthcare",
+ "MRK": "healthcare",
+ "LLY": "healthcare",
+ "MRNA": "healthcare",
+ "BNTX": "healthcare",
+ "CVS": "healthcare",
+ "WBA": "healthcare",
+ "MCK": "healthcare",
+ "CAH": "healthcare",
+ "HUM": "healthcare",
+ "CI": "healthcare",
+ "ELV": "healthcare",
+ "XOM": "energy",
+ "CVX": "energy",
+ "COP": "energy",
+ "SLB": "energy",
+ "HAL": "energy",
+ "BKR": "energy",
+ "MPC": "energy",
+ "VLO": "energy",
+ "PSX": "energy",
+ "OXY": "energy",
+ "PXD": "energy",
+ "DVN": "energy",
+ "CEG": "energy",
+ "NEE": "energy",
+ "DUK": "energy",
+ "SO": "energy",
+ "D": "energy",
+ "SRE": "energy",
+ "WMT": "consumer_goods",
+ "COST": "consumer_goods",
+ "TGT": "consumer_goods",
+ "HD": "consumer_goods",
+ "LOW": "consumer_goods",
+ "PG": "consumer_goods",
+ "KO": "consumer_goods",
+ "PEP": "consumer_goods",
+ "NKE": "consumer_goods",
+ "SBUX": "consumer_goods",
+ "MCD": "consumer_goods",
+ "CMG": "consumer_goods",
+ "YUM": "consumer_goods",
+ "DPZ": "consumer_goods",
+ "DIS": "consumer_goods",
+ "CMCSA": "consumer_goods",
+ "VZ": "consumer_goods",
+ "T": "consumer_goods",
+ "TMUS": "consumer_goods",
+ "EL": "consumer_goods",
+ "CL": "consumer_goods",
+ "KMB": "consumer_goods",
+ "CLX": "consumer_goods",
+ "KHC": "consumer_goods",
+ "GIS": "consumer_goods",
+ "K": "consumer_goods",
+ "MDLZ": "consumer_goods",
+ "HSY": "consumer_goods",
+ "TSN": "consumer_goods",
+ "BYND": "consumer_goods",
+ "CAG": "consumer_goods",
+ "STZ": "consumer_goods",
+ "BUD": "consumer_goods",
+ "DEO": "consumer_goods",
+ "PM": "consumer_goods",
+ "MO": "consumer_goods",
+ "LULU": "consumer_goods",
+ "DG": "consumer_goods",
+ "DLTR": "consumer_goods",
+ "ROST": "consumer_goods",
+ "TJX": "consumer_goods",
+ "AZO": "consumer_goods",
+ "ORLY": "consumer_goods",
+ "KMX": "consumer_goods",
+ "ADDYY": "consumer_goods",
+ "UBER": "consumer_goods",
+ "LYFT": "consumer_goods",
+ "ABNB": "consumer_goods",
+ "DASH": "consumer_goods",
+ "SNAP": "consumer_goods",
+ "PINS": "consumer_goods",
+ "TWTR": "consumer_goods",
+ "SHOP": "consumer_goods",
+ "TOST": "consumer_goods",
+ "BA": "industrials",
+ "LMT": "industrials",
+ "RTX": "industrials",
+ "GD": "industrials",
+ "NOC": "industrials",
+ "GE": "industrials",
+ "HON": "industrials",
+ "MMM": "industrials",
+ "CAT": "industrials",
+ "DE": "industrials",
+ "UNP": "industrials",
+ "UPS": "industrials",
+ "FDX": "industrials",
+ "DAL": "industrials",
+ "UAL": "industrials",
+ "AAL": "industrials",
+ "LUV": "industrials",
+ "F": "industrials",
+ "GM": "industrials",
+ "TM": "industrials",
+ "HMC": "industrials",
+ "VWAGY": "industrials",
+ "RACE": "industrials",
+ "RIVN": "industrials",
+ "LCID": "industrials",
+ "NIO": "industrials",
+ "LNVGY": "industrials",
+}
+
+_sector_cache: Dict[str, str] = {}
+
+
+def _llm_classify_sector(ticker: str) -> str:
+ from langchain_openai import ChatOpenAI
+ from langchain_core.messages import HumanMessage, SystemMessage
+ from tradingagents.default_config import DEFAULT_CONFIG
+
+ llm_name = DEFAULT_CONFIG.get("quick_think_llm", "gpt-4o-mini")
+ llm_provider = DEFAULT_CONFIG.get("llm_provider", "openai")
+ backend_url = DEFAULT_CONFIG.get("backend_url", "https://api.openai.com/v1")
+
+ llm = ChatOpenAI(
+ model=llm_name,
+ base_url=backend_url,
+ temperature=0,
+ )
+
+ system_prompt = (
+ "You are a financial sector classifier. Given a stock ticker symbol, "
+ "classify it into exactly one of the following sectors: "
+ "technology, healthcare, finance, energy, consumer_goods, industrials, other. "
+ "Respond with only the sector name in lowercase, nothing else."
+ )
+
+ user_prompt = f"Classify the stock ticker: {ticker}"
+
+ messages = [
+ SystemMessage(content=system_prompt),
+ HumanMessage(content=user_prompt),
+ ]
+
+ response = llm.invoke(messages)
+ sector = response.content.strip().lower()
+
+ if sector not in VALID_SECTORS:
+ logger.warning(
+ "LLM returned invalid sector '%s' for ticker %s, defaulting to 'other'",
+ sector,
+ ticker,
+ )
+ return "other"
+
+ return sector
+
+
+def classify_sector(ticker: str) -> str:
+ ticker_upper = ticker.upper()
+
+ if ticker_upper in TICKER_TO_SECTOR:
+ return TICKER_TO_SECTOR[ticker_upper]
+
+ if ticker_upper in _sector_cache:
+ return _sector_cache[ticker_upper]
+
+ logger.info("Using LLM fallback for sector classification of ticker: %s", ticker)
+
+ try:
+ sector = _llm_classify_sector(ticker_upper)
+ _sector_cache[ticker_upper] = sector
+ logger.info("Classified %s as %s via LLM", ticker, sector)
+ return sector
+ except Exception as e:
+ logger.error("LLM sector classification failed for %s: %s", ticker, str(e))
+ _sector_cache[ticker_upper] = "other"
+ return "other"
diff --git a/tradingagents/dataflows/trending/stock_resolver.py b/tradingagents/dataflows/trending/stock_resolver.py
new file mode 100644
index 00000000..a2c2aefb
--- /dev/null
+++ b/tradingagents/dataflows/trending/stock_resolver.py
@@ -0,0 +1,538 @@
+import logging
+import re
+from typing import Optional
+
+import yfinance as yf
+
+logger = logging.getLogger(__name__)
+
+COMPANY_TO_TICKER = {
+ "apple": "AAPL",
+ "apple inc": "AAPL",
+ "apple inc.": "AAPL",
+ "apple corporation": "AAPL",
+ "the iphone maker": "AAPL",
+ "iphone maker": "AAPL",
+ "microsoft": "MSFT",
+ "microsoft inc": "MSFT",
+ "microsoft inc.": "MSFT",
+ "microsoft corp": "MSFT",
+ "microsoft corp.": "MSFT",
+ "microsoft corporation": "MSFT",
+ "google": "GOOGL",
+ "alphabet": "GOOGL",
+ "alphabet inc": "GOOGL",
+ "alphabet inc.": "GOOGL",
+ "the search giant": "GOOGL",
+ "amazon": "AMZN",
+ "amazon inc": "AMZN",
+ "amazon inc.": "AMZN",
+ "amazon.com": "AMZN",
+ "amazon.com inc": "AMZN",
+ "the e-commerce giant": "AMZN",
+ "e-commerce giant": "AMZN",
+ "meta": "META",
+ "meta platforms": "META",
+ "meta platforms inc": "META",
+ "meta platforms inc.": "META",
+ "facebook": "META",
+ "facebook inc": "META",
+ "facebook inc.": "META",
+ "tesla": "TSLA",
+ "tesla inc": "TSLA",
+ "tesla inc.": "TSLA",
+ "tesla motors": "TSLA",
+ "ev maker tesla": "TSLA",
+ "nvidia": "NVDA",
+ "nvidia corp": "NVDA",
+ "nvidia corp.": "NVDA",
+ "nvidia corporation": "NVDA",
+ "berkshire hathaway": "BRK-B",
+ "berkshire": "BRK-B",
+ "jpmorgan": "JPM",
+ "jpmorgan chase": "JPM",
+ "jp morgan": "JPM",
+ "jp morgan chase": "JPM",
+ "johnson & johnson": "JNJ",
+ "johnson and johnson": "JNJ",
+ "j&j": "JNJ",
+ "unitedhealth": "UNH",
+ "unitedhealth group": "UNH",
+ "visa": "V",
+ "visa inc": "V",
+ "visa inc.": "V",
+ "procter & gamble": "PG",
+ "procter and gamble": "PG",
+ "p&g": "PG",
+ "mastercard": "MA",
+ "mastercard inc": "MA",
+ "mastercard inc.": "MA",
+ "home depot": "HD",
+ "the home depot": "HD",
+ "chevron": "CVX",
+ "chevron corp": "CVX",
+ "chevron corporation": "CVX",
+ "exxon": "XOM",
+ "exxon mobil": "XOM",
+ "exxonmobil": "XOM",
+ "pfizer": "PFE",
+ "pfizer inc": "PFE",
+ "pfizer inc.": "PFE",
+ "abbvie": "ABBV",
+ "abbvie inc": "ABBV",
+ "abbvie inc.": "ABBV",
+ "coca-cola": "KO",
+ "coca cola": "KO",
+ "coke": "KO",
+ "the coca-cola company": "KO",
+ "pepsico": "PEP",
+ "pepsi": "PEP",
+ "pepsi co": "PEP",
+ "costco": "COST",
+ "costco wholesale": "COST",
+ "walmart": "WMT",
+ "wal-mart": "WMT",
+ "walmart inc": "WMT",
+ "bank of america": "BAC",
+ "bofa": "BAC",
+ "merck": "MRK",
+ "merck & co": "MRK",
+ "merck and co": "MRK",
+ "eli lilly": "LLY",
+ "lilly": "LLY",
+ "eli lilly and company": "LLY",
+ "adobe": "ADBE",
+ "adobe inc": "ADBE",
+ "adobe inc.": "ADBE",
+ "adobe systems": "ADBE",
+ "salesforce": "CRM",
+ "salesforce inc": "CRM",
+ "salesforce.com": "CRM",
+ "cisco": "CSCO",
+ "cisco systems": "CSCO",
+ "cisco systems inc": "CSCO",
+ "netflix": "NFLX",
+ "netflix inc": "NFLX",
+ "netflix inc.": "NFLX",
+ "oracle": "ORCL",
+ "oracle corp": "ORCL",
+ "oracle corporation": "ORCL",
+ "intel": "INTC",
+ "intel corp": "INTC",
+ "intel corporation": "INTC",
+ "amd": "AMD",
+ "advanced micro devices": "AMD",
+ "qualcomm": "QCOM",
+ "qualcomm inc": "QCOM",
+ "qualcomm inc.": "QCOM",
+ "broadcom": "AVGO",
+ "broadcom inc": "AVGO",
+ "broadcom inc.": "AVGO",
+ "texas instruments": "TXN",
+ "ti": "TXN",
+ "disney": "DIS",
+ "walt disney": "DIS",
+ "the walt disney company": "DIS",
+ "walt disney company": "DIS",
+ "comcast": "CMCSA",
+ "comcast corp": "CMCSA",
+ "comcast corporation": "CMCSA",
+ "verizon": "VZ",
+ "verizon communications": "VZ",
+ "at&t": "T",
+ "att": "T",
+ "t-mobile": "TMUS",
+ "tmobile": "TMUS",
+ "t-mobile us": "TMUS",
+ "american express": "AXP",
+ "amex": "AXP",
+ "goldman sachs": "GS",
+ "goldman": "GS",
+ "morgan stanley": "MS",
+ "wells fargo": "WFC",
+ "wells": "WFC",
+ "citigroup": "C",
+ "citi": "C",
+ "citibank": "C",
+ "charles schwab": "SCHW",
+ "schwab": "SCHW",
+ "blackrock": "BLK",
+ "blackrock inc": "BLK",
+ "paypal": "PYPL",
+ "paypal holdings": "PYPL",
+ "paypal inc": "PYPL",
+ "square": "SQ",
+ "block": "SQ",
+ "block inc": "SQ",
+ "shopify": "SHOP",
+ "shopify inc": "SHOP",
+ "uber": "UBER",
+ "uber technologies": "UBER",
+ "lyft": "LYFT",
+ "lyft inc": "LYFT",
+ "airbnb": "ABNB",
+ "airbnb inc": "ABNB",
+ "doordash": "DASH",
+ "doordash inc": "DASH",
+ "snap": "SNAP",
+ "snap inc": "SNAP",
+ "snapchat": "SNAP",
+ "pinterest": "PINS",
+ "pinterest inc": "PINS",
+ "twitter": "TWTR",
+ "twitter inc": "TWTR",
+ "linkedin": "MSFT",
+ "zoom": "ZM",
+ "zoom video": "ZM",
+ "zoom video communications": "ZM",
+ "slack": "CRM",
+ "slack technologies": "CRM",
+ "palantir": "PLTR",
+ "palantir technologies": "PLTR",
+ "snowflake": "SNOW",
+ "snowflake inc": "SNOW",
+ "datadog": "DDOG",
+ "datadog inc": "DDOG",
+ "crowdstrike": "CRWD",
+ "crowdstrike holdings": "CRWD",
+ "okta": "OKTA",
+ "okta inc": "OKTA",
+ "cloudflare": "NET",
+ "cloudflare inc": "NET",
+ "mongodb": "MDB",
+ "mongodb inc": "MDB",
+ "twilio": "TWLO",
+ "twilio inc": "TWLO",
+ "servicenow": "NOW",
+ "servicenow inc": "NOW",
+ "workday": "WDAY",
+ "workday inc": "WDAY",
+ "splunk": "SPLK",
+ "splunk inc": "SPLK",
+ "vmware": "VMW",
+ "vmware inc": "VMW",
+ "ibm": "IBM",
+ "international business machines": "IBM",
+ "hp": "HPQ",
+ "hewlett-packard": "HPQ",
+ "hewlett packard": "HPQ",
+ "dell": "DELL",
+ "dell technologies": "DELL",
+ "lenovo": "LNVGY",
+ "boeing": "BA",
+ "boeing company": "BA",
+ "the boeing company": "BA",
+ "lockheed martin": "LMT",
+ "lockheed": "LMT",
+ "raytheon": "RTX",
+ "rtx": "RTX",
+ "general dynamics": "GD",
+ "northrop grumman": "NOC",
+ "northrop": "NOC",
+ "general electric": "GE",
+ "ge": "GE",
+ "honeywell": "HON",
+ "honeywell international": "HON",
+ "3m": "MMM",
+ "3m company": "MMM",
+ "caterpillar": "CAT",
+ "caterpillar inc": "CAT",
+ "deere": "DE",
+ "john deere": "DE",
+ "deere & company": "DE",
+ "union pacific": "UNP",
+ "ups": "UPS",
+ "united parcel service": "UPS",
+ "fedex": "FDX",
+ "federal express": "FDX",
+ "delta": "DAL",
+ "delta air lines": "DAL",
+ "delta airlines": "DAL",
+ "united airlines": "UAL",
+ "united": "UAL",
+ "american airlines": "AAL",
+ "southwest": "LUV",
+ "southwest airlines": "LUV",
+ "ford": "F",
+ "ford motor": "F",
+ "ford motor company": "F",
+ "general motors": "GM",
+ "gm": "GM",
+ "toyota": "TM",
+ "toyota motor": "TM",
+ "honda": "HMC",
+ "honda motor": "HMC",
+ "volkswagen": "VWAGY",
+ "vw": "VWAGY",
+ "ferrari": "RACE",
+ "rivian": "RIVN",
+ "rivian automotive": "RIVN",
+ "lucid": "LCID",
+ "lucid motors": "LCID",
+ "lucid group": "LCID",
+ "nio": "NIO",
+ "nio inc": "NIO",
+ "moderna": "MRNA",
+ "moderna inc": "MRNA",
+ "biontech": "BNTX",
+ "cvs": "CVS",
+ "cvs health": "CVS",
+ "walgreens": "WBA",
+ "walgreens boots alliance": "WBA",
+ "mckesson": "MCK",
+ "mckesson corp": "MCK",
+ "cardinal health": "CAH",
+ "humana": "HUM",
+ "humana inc": "HUM",
+ "cigna": "CI",
+ "cigna group": "CI",
+ "anthem": "ELV",
+ "elevance health": "ELV",
+ "starbucks": "SBUX",
+ "starbucks corp": "SBUX",
+ "starbucks corporation": "SBUX",
+ "mcdonalds": "MCD",
+ "mcdonald's": "MCD",
+ "chipotle": "CMG",
+ "chipotle mexican grill": "CMG",
+ "yum brands": "YUM",
+ "yum": "YUM",
+ "dominos": "DPZ",
+ "domino's": "DPZ",
+ "domino's pizza": "DPZ",
+ "nike": "NKE",
+ "nike inc": "NKE",
+ "adidas": "ADDYY",
+ "lululemon": "LULU",
+ "lululemon athletica": "LULU",
+ "target": "TGT",
+ "target corp": "TGT",
+ "target corporation": "TGT",
+ "dollar general": "DG",
+ "dollar tree": "DLTR",
+ "ross stores": "ROST",
+ "ross": "ROST",
+ "tjx": "TJX",
+ "tjx companies": "TJX",
+ "tj maxx": "TJX",
+ "lowes": "LOW",
+ "lowe's": "LOW",
+ "lowe's companies": "LOW",
+ "autozone": "AZO",
+ "o'reilly": "ORLY",
+ "o'reilly automotive": "ORLY",
+ "carmax": "KMX",
+ "estee lauder": "EL",
+ "colgate": "CL",
+ "colgate-palmolive": "CL",
+ "colgate palmolive": "CL",
+ "kimberly-clark": "KMB",
+ "kimberly clark": "KMB",
+ "clorox": "CLX",
+ "clorox company": "CLX",
+ "kraft heinz": "KHC",
+ "kraft": "KHC",
+ "heinz": "KHC",
+ "general mills": "GIS",
+ "kellogg": "K",
+ "kellogg's": "K",
+ "mondelez": "MDLZ",
+ "mondelez international": "MDLZ",
+ "hershey": "HSY",
+ "the hershey company": "HSY",
+ "tyson": "TSN",
+ "tyson foods": "TSN",
+ "beyond meat": "BYND",
+ "conagra": "CAG",
+ "conagra brands": "CAG",
+ "constellation brands": "STZ",
+ "anheuser-busch": "BUD",
+ "anheuser busch": "BUD",
+ "ab inbev": "BUD",
+ "diageo": "DEO",
+ "philip morris": "PM",
+ "philip morris international": "PM",
+ "altria": "MO",
+ "altria group": "MO",
+ "constellation energy": "CEG",
+ "nextera": "NEE",
+ "nextera energy": "NEE",
+ "duke energy": "DUK",
+ "southern company": "SO",
+ "dominion": "D",
+ "dominion energy": "D",
+ "sempra": "SRE",
+ "sempra energy": "SRE",
+ "conocophillips": "COP",
+ "conoco": "COP",
+ "schlumberger": "SLB",
+ "halliburton": "HAL",
+ "baker hughes": "BKR",
+ "marathon": "MPC",
+ "marathon petroleum": "MPC",
+ "valero": "VLO",
+ "valero energy": "VLO",
+ "phillips 66": "PSX",
+ "occidental": "OXY",
+ "occidental petroleum": "OXY",
+ "pioneer": "PXD",
+ "pioneer natural resources": "PXD",
+ "devon energy": "DVN",
+ "devon": "DVN",
+ "coinbase": "COIN",
+ "coinbase global": "COIN",
+ "robinhood": "HOOD",
+ "robinhood markets": "HOOD",
+ "sofi": "SOFI",
+ "sofi technologies": "SOFI",
+ "affirm": "AFRM",
+ "affirm holdings": "AFRM",
+ "marqeta": "MQ",
+ "toast": "TOST",
+ "toast inc": "TOST",
+ "docusign": "DOCU",
+ "docusign inc": "DOCU",
+ "asana": "ASAN",
+ "monday.com": "MNDY",
+ "monday": "MNDY",
+ "atlassian": "TEAM",
+ "atlassian corp": "TEAM",
+ "intuit": "INTU",
+ "intuit inc": "INTU",
+ "autodesk": "ADSK",
+ "autodesk inc": "ADSK",
+ "synopsys": "SNPS",
+ "cadence": "CDNS",
+ "cadence design": "CDNS",
+ "ansys": "ANSS",
+ "roper": "ROP",
+ "roper technologies": "ROP",
+ "fortinet": "FTNT",
+ "palo alto": "PANW",
+ "palo alto networks": "PANW",
+ "zscaler": "ZS",
+ "sentinelone": "S",
+ "veeva": "VEEV",
+ "veeva systems": "VEEV",
+}
+
+US_EXCHANGE_CODES = {
+ "NYQ",
+ "NMS",
+ "NGM",
+ "NCM",
+ "ASE",
+ "PCX",
+ "BTS",
+ "NYSE",
+ "NASDAQ",
+ "AMEX",
+ "NYS",
+ "NAS",
+ "NIM",
+ "NAQ",
+}
+
+SUFFIX_PATTERNS = [
+ r"\s+inc\.?$",
+ r"\s+corp\.?$",
+ r"\s+corporation$",
+ r"\s+co\.?$",
+ r"\s+company$",
+ r"\s+llc$",
+ r"\s+ltd\.?$",
+ r"\s+limited$",
+ r"\s+plc$",
+ r"\s+holdings?$",
+ r"\s+group$",
+ r"\s+technologies$",
+ r"\s+enterprises?$",
+]
+
+
+def _normalize_company_name(name: str) -> str:
+ normalized = name.lower().strip()
+ for pattern in SUFFIX_PATTERNS:
+ normalized = re.sub(pattern, "", normalized, flags=re.IGNORECASE)
+ normalized = normalized.strip()
+ return normalized
+
+
+def _search_yfinance_ticker(company_name: str) -> Optional[str]:
+ try:
+ search_result = yf.Ticker(company_name)
+ info = search_result.info
+ if info and "symbol" in info:
+ return info["symbol"]
+ except Exception as e:
+ logger.debug("yfinance search failed for %s: %s", company_name, str(e))
+
+ try:
+ search = yf.Search(company_name, max_results=5)
+ if hasattr(search, "quotes") and search.quotes:
+ for quote in search.quotes:
+ if "symbol" in quote:
+ return quote["symbol"]
+ except Exception as e:
+ logger.debug("yfinance Search failed for %s: %s", company_name, str(e))
+
+ return None
+
+
+def validate_us_ticker(ticker: str) -> bool:
+ try:
+ ticker_obj = yf.Ticker(ticker.upper())
+ info = ticker_obj.info
+ if not info:
+ logger.warning("Validation failed for %s: no info available", ticker)
+ return False
+
+ exchange = info.get("exchange", "")
+ if exchange in US_EXCHANGE_CODES:
+ return True
+
+ exchange_lower = exchange.lower()
+ if any(us_ex.lower() in exchange_lower for us_ex in ["nyse", "nasdaq", "amex", "nys", "nms", "ngm"]):
+ return True
+
+ logger.warning("Validation failed for %s: exchange %s is not a US exchange", ticker, exchange)
+ return False
+ except Exception as e:
+ logger.warning("Validation failed for %s: %s", ticker, str(e))
+ return False
+
+
+def resolve_ticker(company_name: str) -> Optional[str]:
+ if not company_name or not company_name.strip():
+ return None
+
+ normalized = company_name.lower().strip()
+
+ if normalized in COMPANY_TO_TICKER:
+ return COMPANY_TO_TICKER[normalized]
+
+ normalized_stripped = _normalize_company_name(company_name)
+ if normalized_stripped in COMPANY_TO_TICKER:
+ return COMPANY_TO_TICKER[normalized_stripped]
+
+    if company_name.upper() in COMPANY_TO_TICKER.values():
+ if validate_us_ticker(company_name.upper()):
+ return company_name.upper()
+
+ logger.info("Using yfinance fallback for company: %s", company_name)
+ yf_ticker = _search_yfinance_ticker(company_name)
+
+ if yf_ticker:
+ if validate_us_ticker(yf_ticker):
+ logger.info("Resolved %s to %s via yfinance", company_name, yf_ticker)
+ return yf_ticker
+ else:
+ logger.warning("Ticker %s for %s failed US exchange validation", yf_ticker, company_name)
+ return None
+
+ logger.warning("Could not resolve ticker for company: %s", company_name)
+ return None
+
+
+def validate_tradeable(ticker: str) -> bool:
+ return validate_us_ticker(ticker)
diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py
index 1f40a2a2..e88868d6 100644
--- a/tradingagents/default_config.py
+++ b/tradingagents/default_config.py
@@ -8,26 +8,24 @@ DEFAULT_CONFIG = {
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
),
- # LLM settings
"llm_provider": "openai",
- "deep_think_llm": "o4-mini",
- "quick_think_llm": "gpt-4o-mini",
+ "deep_think_llm": "gpt-5",
+ "quick_think_llm": "gpt-5-mini",
"backend_url": "https://api.openai.com/v1",
- # Debate and discussion settings
- "max_debate_rounds": 1,
- "max_risk_discuss_rounds": 1,
+ "max_debate_rounds": 2,
+ "max_risk_discuss_rounds": 2,
"max_recur_limit": 100,
- # Data vendor configuration
- # Category-level configuration (default for all tools in category)
"data_vendors": {
- "core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local
- "technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local
- "fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local
- "news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
+ "core_stock_apis": "yfinance",
+ "technical_indicators": "yfinance",
+ "fundamental_data": "alpha_vantage",
+ "news_data": "alpha_vantage",
},
- # Tool-level configuration (takes precedence over category-level)
"tool_vendors": {
- # Example: "get_stock_data": "alpha_vantage", # Override category default
- # Example: "get_news": "openai", # Override category default
},
+ "discovery_timeout": 60,
+ "discovery_hard_timeout": 120,
+ "discovery_cache_ttl": 300,
+ "discovery_max_results": 20,
+ "discovery_min_mentions": 2,
}
diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py
index 40cdff75..ba4c092c 100644
--- a/tradingagents/graph/trading_graph.py
+++ b/tradingagents/graph/trading_graph.py
@@ -1,9 +1,9 @@
-# TradingAgents/graph/trading_graph.py
-
import os
+import signal
+import threading
from pathlib import Path
import json
-from datetime import date
+from datetime import date, datetime
from typing import Dict, Any, Tuple, List, Optional
from langchain_openai import ChatOpenAI
@@ -22,7 +22,6 @@ from tradingagents.agents.utils.agent_states import (
)
from tradingagents.dataflows.config import set_config
-# Import the new abstract tool methods from agent_utils
from tradingagents.agents.utils.agent_utils import (
get_stock_data,
get_indicators,
@@ -36,6 +35,19 @@ from tradingagents.agents.utils.agent_utils import (
get_global_news
)
+from tradingagents.agents.discovery import (
+ DiscoveryRequest,
+ DiscoveryResult,
+ DiscoveryStatus,
+ TrendingStock,
+ Sector,
+ EventCategory,
+ DiscoveryTimeoutError,
+ extract_entities,
+ calculate_trending_scores,
+)
+from tradingagents.dataflows.interface import get_bulk_news
+
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
@@ -43,8 +55,15 @@ from .reflection import Reflector
from .signal_processing import SignalProcessor
+class DiscoveryTimeoutException(Exception):
+ pass
+
+
+def _timeout_handler(signum, frame):
+ raise DiscoveryTimeoutException("Discovery operation timed out")
+
+
class TradingAgentsGraph:
- """Main class that orchestrates the trading agents framework."""
def __init__(
self,
@@ -52,26 +71,16 @@ class TradingAgentsGraph:
debug=False,
config: Dict[str, Any] = None,
):
- """Initialize the trading agents graph and components.
-
- Args:
- selected_analysts: List of analyst types to include
- debug: Whether to run in debug mode
- config: Configuration dictionary. If None, uses default config
- """
self.debug = debug
self.config = config or DEFAULT_CONFIG
- # Update the interface's config
set_config(self.config)
- # Create necessary directories
os.makedirs(
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
exist_ok=True,
)
- # Initialize LLMs
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
@@ -83,18 +92,13 @@ class TradingAgentsGraph:
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
else:
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
-
- # Initialize memories
+
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
-
- # Create tool nodes
self.tool_nodes = self._create_tool_nodes()
-
- # Initialize components
self.conditional_logic = ConditionalLogic()
self.graph_setup = GraphSetup(
self.quick_thinking_llm,
@@ -111,35 +115,26 @@ class TradingAgentsGraph:
self.propagator = Propagator()
self.reflector = Reflector(self.quick_thinking_llm)
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
-
- # State tracking
self.curr_state = None
self.ticker = None
- self.log_states_dict = {} # date to full state dict
-
- # Set up the graph
+ self.log_states_dict = {}
self.graph = self.graph_setup.setup_graph(selected_analysts)
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
- """Create tool nodes for different data sources using abstract methods."""
return {
"market": ToolNode(
[
- # Core stock data tools
get_stock_data,
- # Technical indicators
get_indicators,
]
),
"social": ToolNode(
[
- # News tools for social media analysis
get_news,
]
),
"news": ToolNode(
[
- # News and insider information
get_news,
get_global_news,
get_insider_sentiment,
@@ -148,7 +143,6 @@ class TradingAgentsGraph:
),
"fundamentals": ToolNode(
[
- # Fundamental analysis tools
get_fundamentals,
get_balance_sheet,
get_cashflow,
@@ -158,18 +152,13 @@ class TradingAgentsGraph:
}
def propagate(self, company_name, trade_date):
- """Run the trading agents graph for a company on a specific date."""
-
self.ticker = company_name
-
- # Initialize state
init_agent_state = self.propagator.create_initial_state(
company_name, trade_date
)
args = self.propagator.get_graph_args()
if self.debug:
- # Debug mode with tracing
trace = []
for chunk in self.graph.stream(init_agent_state, **args):
if len(chunk["messages"]) == 0:
@@ -180,20 +169,14 @@ class TradingAgentsGraph:
final_state = trace[-1]
else:
- # Standard mode without tracing
final_state = self.graph.invoke(init_agent_state, **args)
- # Store current state for reflection
self.curr_state = final_state
-
- # Log state
self._log_state(trade_date, final_state)
- # Return decision and processed signal
return final_state, self.process_signal(final_state["final_trade_decision"])
def _log_state(self, trade_date, final_state):
- """Log the final state to a JSON file."""
self.log_states_dict[str(trade_date)] = {
"company_of_interest": final_state["company_of_interest"],
"trade_date": final_state["trade_date"],
@@ -224,7 +207,6 @@ class TradingAgentsGraph:
"final_trade_decision": final_state["final_trade_decision"],
}
- # Save to file
directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
directory.mkdir(parents=True, exist_ok=True)
@@ -235,7 +217,6 @@ class TradingAgentsGraph:
json.dump(self.log_states_dict, f, indent=4)
def reflect_and_remember(self, returns_losses):
- """Reflect on decisions and update memory based on returns."""
self.reflector.reflect_bull_researcher(
self.curr_state, returns_losses, self.bull_memory
)
@@ -253,5 +234,95 @@ class TradingAgentsGraph:
)
def process_signal(self, full_signal):
- """Process a signal to extract the core decision."""
return self.signal_processor.process_signal(full_signal)
+
+ def discover_trending(
+ self,
+ request: Optional[DiscoveryRequest] = None,
+ ) -> DiscoveryResult:
+ if request is None:
+ request = DiscoveryRequest(
+ lookback_period="24h",
+ max_results=self.config.get("discovery_max_results", 20),
+ )
+
+ started_at = datetime.now()
+ result = DiscoveryResult(
+ request=request,
+ trending_stocks=[],
+ status=DiscoveryStatus.PROCESSING,
+ started_at=started_at,
+ )
+
+ hard_timeout = self.config.get("discovery_hard_timeout", 120)
+
+ discovery_result = {"stocks": [], "error": None}
+
+ def run_discovery():
+ try:
+ articles = get_bulk_news(request.lookback_period)
+
+ mentions = extract_entities(articles, self.config)
+
+ min_mentions = self.config.get("discovery_min_mentions", 2)
+ max_results = request.max_results or self.config.get("discovery_max_results", 20)
+
+ trending_stocks = calculate_trending_scores(
+ mentions,
+ articles,
+ max_results=max_results,
+ min_mentions=min_mentions,
+ )
+
+ discovery_result["stocks"] = trending_stocks
+ except Exception as e:
+ discovery_result["error"] = str(e)
+
+        discovery_thread = threading.Thread(target=run_discovery, daemon=True)
+ discovery_thread.start()
+ discovery_thread.join(timeout=hard_timeout)
+
+ if discovery_thread.is_alive():
+ raise DiscoveryTimeoutError(
+ f"Discovery operation exceeded {hard_timeout} second timeout"
+ )
+
+ if discovery_result["error"]:
+ result.status = DiscoveryStatus.FAILED
+ result.error_message = discovery_result["error"]
+ result.completed_at = datetime.now()
+ return result
+
+ trending_stocks = discovery_result["stocks"]
+
+ if request.sector_filter:
+ sector_values = {s.value if isinstance(s, Sector) else s for s in request.sector_filter}
+ trending_stocks = [
+ stock for stock in trending_stocks
+ if stock.sector.value in sector_values or stock.sector in request.sector_filter
+ ]
+
+ if request.event_filter:
+ event_values = {e.value if isinstance(e, EventCategory) else e for e in request.event_filter}
+ trending_stocks = [
+ stock for stock in trending_stocks
+ if stock.event_type.value in event_values or stock.event_type in request.event_filter
+ ]
+
+ result.trending_stocks = trending_stocks
+ result.status = DiscoveryStatus.COMPLETED
+ result.completed_at = datetime.now()
+
+ return result
+
+ def analyze_trending(
+ self,
+ trending_stock: TrendingStock,
+ trade_date: Optional[date] = None,
+ ) -> Tuple[Dict[str, Any], str]:
+ ticker = trending_stock.ticker
+
+ if trade_date is None:
+ trade_date = date.today()
+
+ return self.propagate(ticker, trade_date)