From 3f6b1e9f39d63204cd24f9872289866c124f8c52 Mon Sep 17 00:00:00 2001 From: Joseph O'Brien <98370624+89jobrien@users.noreply.github.com> Date: Tue, 2 Dec 2025 20:19:34 -0500 Subject: [PATCH] Add trending stock discovery feature MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement a multi-stage pipeline to discover trending stocks from news: - Entity extraction from news articles using LLM - Stock ticker resolution via Yahoo Finance - Sector classification and event categorization - Scoring algorithm based on mentions, sentiment, and recency - CLI integration with interactive stock selection and analysis flow - Persistence layer for saving discovery results - Comprehensive test suite for all discovery components Update README with uv-based installation instructions and remove emojis. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .gitignore | 10 + README.md | 25 +- cli/main.py | 1791 +++++++++++------ main.py | 2 +- tests/discovery/__init__.py | 0 tests/discovery/test_api.py | 200 ++ tests/discovery/test_bulk_news.py | 160 ++ tests/discovery/test_cli.py | 127 ++ tests/discovery/test_entity_extractor.py | 278 +++ tests/discovery/test_integration.py | 489 +++++ tests/discovery/test_models.py | 196 ++ tests/discovery/test_persistence.py | 228 +++ tests/discovery/test_scorer.py | 469 +++++ tests/discovery/test_sector_classifier.py | 94 + tests/discovery/test_stock_resolver.py | 135 ++ tradingagents/agents/discovery/__init__.py | 53 + .../agents/discovery/entity_extractor.py | 159 ++ tradingagents/agents/discovery/exceptions.py | 14 + tradingagents/agents/discovery/models.py | 180 ++ tradingagents/agents/discovery/persistence.py | 120 ++ tradingagents/agents/discovery/scorer.py | 153 ++ tradingagents/agents/utils/agent_states.py | 38 +- tradingagents/agents/utils/agent_utils.py | 10 - tradingagents/agents/utils/memory.py | 76 +- tradingagents/dataflows/alpha_vantage_news.py | 82 +- tradingagents/dataflows/google.py | 55 +- tradingagents/dataflows/interface.py | 170 +- tradingagents/dataflows/openai.py | 93 +- tradingagents/dataflows/trending/__init__.py | 21 + .../dataflows/trending/sector_classifier.py | 267 +++ .../dataflows/trending/stock_resolver.py | 538 +++++ tradingagents/default_config.py | 28 +- tradingagents/graph/trading_graph.py | 163 +- 33 files changed, 5609 insertions(+), 815 deletions(-) create mode 100644 tests/discovery/__init__.py create mode 100644 tests/discovery/test_api.py create mode 100644 tests/discovery/test_bulk_news.py create mode 100644 tests/discovery/test_cli.py create mode 100644 tests/discovery/test_entity_extractor.py create mode 100644 tests/discovery/test_integration.py create mode 100644 tests/discovery/test_models.py create mode 100644 tests/discovery/test_persistence.py create mode 100644 tests/discovery/test_scorer.py create mode 100644 tests/discovery/test_sector_classifier.py create mode 100644 tests/discovery/test_stock_resolver.py create mode 100644 tradingagents/agents/discovery/__init__.py create mode 100644 tradingagents/agents/discovery/entity_extractor.py create mode 100644 tradingagents/agents/discovery/exceptions.py create mode 100644 tradingagents/agents/discovery/models.py create mode 100644 tradingagents/agents/discovery/persistence.py create mode 100644 tradingagents/agents/discovery/scorer.py create mode 100644 tradingagents/dataflows/trending/__init__.py create mode 100644 tradingagents/dataflows/trending/sector_classifier.py create mode 100644 tradingagents/dataflows/trending/stock_resolver.py diff --git a/.gitignore b/.gitignore index 3369bad9..17ceaeb9 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,13 @@ eval_results/ eval_data/ *.egg-info/ .env +.claude/ +.pytest_cache/ +.specify/ +specs/ +agent-os/ +*.local.md +build/ +.mcp.json +*.zip +todos.md diff --git a/README.md b/README.md index 7e90c60f..b8cc9304 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ # TradingAgents: Multi-Agents LLM Financial Trading Framework -> 🎉 **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community. +> **TradingAgents** officially released! We have received numerous inquiries about the work, and we would like to express our thanks for the enthusiasm in our community. > > So we decided to fully open-source the framework. Looking forward to building impactful projects with you! @@ -43,7 +43,7 @@
-🚀 [TradingAgents](#tradingagents-framework) | ⚡ [Installation & CLI](#installation-and-cli) | 🎬 [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | 📦 [Package Usage](#tradingagents-package) | 🤝 [Contributing](#contributing) | 📄 [Citation](#citation) +[TradingAgents](#tradingagents-framework) | [Installation & CLI](#installation-and-cli) | [Demo](https://www.youtube.com/watch?v=90gr5lwjIho) | [Package Usage](#tradingagents-package) | [Contributing](#contributing) | [Citation](#citation)
@@ -101,15 +101,10 @@ git clone https://github.com/TauricResearch/TradingAgents.git cd TradingAgents ``` -Create a virtual environment in any of your favorite environment managers: +Sync virtual environment: ```bash -conda create -n tradingagents python=3.13 -conda activate tradingagents -``` - -Install dependencies: -```bash -pip install -r requirements.txt +uv sync +uv source .venv/bin/activate ``` ### Required APIs @@ -133,7 +128,7 @@ cp .env.example .env You can also try out the CLI directly by running: ```bash -python -m cli.main +uv run cli/main.py ``` You will see a screen where you can select your desired tickers, date, LLMs, research depth, etc. @@ -204,13 +199,9 @@ print(decision) You can view the full list of configurations in `tradingagents/default_config.py`. -## Contributing +## Source -We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/). - -## Citation - -Please reference our work if you find *TradingAgents* provides you with some help :) +Thanks to Yijia Xiao and Edward Sun and Di Luo and Wei Wang. Core agent implementation based on [TradingAgents: Multi-Agents LLM Financial Trading Framework](https://arxiv.org/abs/2412.20138) ``` @misc{xiao2025tradingagentsmultiagentsllmfinancial, diff --git a/cli/main.py b/cli/main.py index 2e06d50c..d5efcf3c 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, List import datetime import typer from pathlib import Path @@ -6,7 +6,6 @@ from functools import wraps from rich.console import Console from dotenv import load_dotenv -# Load environment variables from .env file load_dotenv() from rich.panel import Panel from rich.spinner import Spinner @@ -15,7 +14,6 @@ from rich.columns import Columns from rich.markdown import Markdown from rich.layout import Layout from rich.text import Text -from rich.live import Live from rich.table import Table from collections import deque import time @@ -23,9 +21,19 @@ from rich.tree import Tree from rich import box from rich.align import Align from rich.rule import Rule +import questionary from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG +from tradingagents.agents.discovery.models import ( + DiscoveryRequest, + DiscoveryResult, + DiscoveryStatus, + TrendingStock, + Sector, + EventCategory, +) +from tradingagents.agents.discovery.persistence import save_discovery_result from cli.models import AnalystType from cli.utils import * @@ -34,34 +42,28 @@ console = Console() app = typer.Typer( name="TradingAgents", help="TradingAgents CLI: Multi-Agents LLM Financial Trading Framework", - add_completion=True, # Enable shell completion + add_completion=True, ) -# Create a deque to store recent messages with a maximum length class MessageBuffer: def __init__(self, max_length=100): self.messages = deque(maxlen=max_length) self.tool_calls = deque(maxlen=max_length) self.current_report = None - self.final_report = None # Store the complete final report + self.final_report = None self.agent_status = { - # Analyst Team "Market Analyst": "pending", "Social Analyst": "pending", "News Analyst": "pending", "Fundamentals Analyst": "pending", - # Research Team "Bull Researcher": "pending", "Bear Researcher": "pending", "Research Manager": "pending", - # Trading Team "Trader": "pending", - # Risk Management Team "Risky Analyst": "pending", "Neutral Analyst": "pending", "Safe Analyst": "pending", - # Portfolio Management Team "Portfolio Manager": "pending", } self.current_agent = None @@ -94,18 +96,15 @@ class MessageBuffer: self._update_current_report() def _update_current_report(self): - # For the panel display, only show the most recently updated section latest_section = None latest_content = None - # Find the most recently updated section for section, content in self.report_sections.items(): if content is not None: latest_section = section latest_content = content - + if latest_section and latest_content: - # Format the current section for display section_titles = { "market_report": "Market Analysis", "sentiment_report": "Social Sentiment", @@ -119,13 +118,11 @@ class MessageBuffer: f"### {section_titles[latest_section]}\n{latest_content}" ) - # Update the final complete report self._update_final_report() def _update_final_report(self): report_parts = [] - # Analyst Team Reports if any( self.report_sections[section] for section in [ @@ -153,17 +150,14 @@ class MessageBuffer: f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}" ) - # Research Team Reports if self.report_sections["investment_plan"]: report_parts.append("## Research Team Decision") report_parts.append(f"{self.report_sections['investment_plan']}") - # Trading Team Reports if self.report_sections["trader_investment_plan"]: report_parts.append("## Trading Team Plan") report_parts.append(f"{self.report_sections['trader_investment_plan']}") - # Portfolio Management Decision if self.report_sections["final_trade_decision"]: report_parts.append("## Portfolio Management Decision") report_parts.append(f"{self.report_sections['final_trade_decision']}") @@ -174,284 +168,396 @@ class MessageBuffer: message_buffer = MessageBuffer() -def create_layout(): - layout = Layout() - layout.split_column( - Layout(name="header", size=3), - Layout(name="main"), - Layout(name="footer", size=3), - ) - layout["main"].split_column( - Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5) - ) - layout["upper"].split_row( - Layout(name="progress", ratio=2), Layout(name="messages", ratio=3) - ) - return layout +LOOKBACK_OPTIONS = [ + ("Last hour (1h)", "1h"), + ("Last 6 hours (6h)", "6h"), + ("Last 24 hours (24h)", "24h"), + ("Last 7 days (7d)", "7d"), +] + +SECTOR_OPTIONS = [ + ("Technology", Sector.TECHNOLOGY), + ("Healthcare", Sector.HEALTHCARE), + ("Finance", Sector.FINANCE), + ("Energy", Sector.ENERGY), + ("Consumer Goods", Sector.CONSUMER_GOODS), + ("Industrials", Sector.INDUSTRIALS), + ("Other", Sector.OTHER), +] + +EVENT_OPTIONS = [ + ("Earnings", EventCategory.EARNINGS), + ("Merger/Acquisition", EventCategory.MERGER_ACQUISITION), + ("Regulatory", EventCategory.REGULATORY), + ("Product Launch", EventCategory.PRODUCT_LAUNCH), + ("Executive Change", EventCategory.EXECUTIVE_CHANGE), + ("Other", EventCategory.OTHER), +] -def update_display(layout, spinner_text=None): - # Header with welcome message - layout["header"].update( - Panel( - "[bold green]Welcome to TradingAgents CLI[/bold green]\n" - "[dim]© [Tauric Research](https://github.com/TauricResearch)[/dim]", - title="Welcome to TradingAgents", - border_style="green", - padding=(1, 2), - expand=True, - ) - ) +def create_question_box(title: str, prompt: str, default: str = None) -> Panel: + box_content = f"[bold]{title}[/bold]\n" + box_content += f"[dim]{prompt}[/dim]" + if default: + box_content += f"\n[dim]Default: {default}[/dim]" + return Panel(box_content, border_style="blue", padding=(1, 2)) - # Progress panel showing agent status - progress_table = Table( - show_header=True, - header_style="bold magenta", - show_footer=False, - box=box.SIMPLE_HEAD, # Use simple header with horizontal lines - title=None, # Remove the redundant Progress title - padding=(0, 2), # Add horizontal padding - expand=True, # Make table expand to fill available space - ) - progress_table.add_column("Team", style="cyan", justify="center", width=20) - progress_table.add_column("Agent", style="green", justify="center", width=20) - progress_table.add_column("Status", style="yellow", justify="center", width=20) - # Group agents by team - teams = { - "Analyst Team": [ - "Market Analyst", - "Social Analyst", - "News Analyst", - "Fundamentals Analyst", +def select_lookback_period() -> str: + choice = questionary.select( + "Select lookback period:", + choices=[ + questionary.Choice(display, value=value) for display, value in LOOKBACK_OPTIONS ], - "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"], - "Trading Team": ["Trader"], - "Risk Management": ["Risky Analyst", "Neutral Analyst", "Safe Analyst"], - "Portfolio Management": ["Portfolio Manager"], - } + instruction="\n- Use arrow keys to navigate\n- Press Enter to select", + style=questionary.Style( + [ + ("selected", "fg:cyan noinherit"), + ("highlighted", "fg:cyan noinherit"), + ("pointer", "fg:cyan noinherit"), + ] + ), + ).ask() - for team, agents in teams.items(): - # Add first agent with team name - first_agent = agents[0] - status = message_buffer.agent_status[first_agent] - if status == "in_progress": - spinner = Spinner( - "dots", text="[blue]in_progress[/blue]", style="bold cyan" - ) - status_cell = spinner - else: - status_color = { - "pending": "yellow", - "completed": "green", - "error": "red", - }.get(status, "white") - status_cell = f"[{status_color}]{status}[/{status_color}]" - progress_table.add_row(team, first_agent, status_cell) + if choice is None: + console.print("\n[red]No lookback period selected. Exiting...[/red]") + exit(1) - # Add remaining agents in team - for agent in agents[1:]: - status = message_buffer.agent_status[agent] - if status == "in_progress": - spinner = Spinner( - "dots", text="[blue]in_progress[/blue]", style="bold cyan" - ) - status_cell = spinner - else: - status_color = { - "pending": "yellow", - "completed": "green", - "error": "red", - }.get(status, "white") - status_cell = f"[{status_color}]{status}[/{status_color}]" - progress_table.add_row("", agent, status_cell) + return choice - # Add horizontal line after each team - progress_table.add_row("─" * 20, "─" * 20, "─" * 20, style="dim") - layout["progress"].update( - Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2)) - ) +def select_sector_filter() -> Optional[List[Sector]]: + use_filter = questionary.confirm( + "Filter by sector?", + default=False, + style=questionary.Style( + [ + ("selected", "fg:cyan noinherit"), + ("highlighted", "fg:cyan noinherit"), + ] + ), + ).ask() - # Messages panel showing recent messages and tool calls - messages_table = Table( + if not use_filter: + return None + + choices = questionary.checkbox( + "Select sectors to include:", + choices=[ + questionary.Choice(display, value=value) for display, value in SECTOR_OPTIONS + ], + instruction="\n- Press Space to select/unselect\n- Press 'a' to select all\n- Press Enter when done", + style=questionary.Style( + [ + ("checkbox-selected", "fg:cyan"), + ("selected", "fg:cyan noinherit"), + ("highlighted", "noinherit"), + ("pointer", "noinherit"), + ] + ), + ).ask() + + if not choices: + return None + + return choices + + +def select_event_filter() -> Optional[List[EventCategory]]: + use_filter = questionary.confirm( + "Filter by event type?", + default=False, + style=questionary.Style( + [ + ("selected", "fg:cyan noinherit"), + ("highlighted", "fg:cyan noinherit"), + ] + ), + ).ask() + + if not use_filter: + return None + + choices = questionary.checkbox( + "Select event types to include:", + choices=[ + questionary.Choice(display, value=value) for display, value in EVENT_OPTIONS + ], + instruction="\n- Press Space to select/unselect\n- Press 'a' to select all\n- Press Enter when done", + style=questionary.Style( + [ + ("checkbox-selected", "fg:cyan"), + ("selected", "fg:cyan noinherit"), + ("highlighted", "noinherit"), + ("pointer", "noinherit"), + ] + ), + ).ask() + + if not choices: + return None + + return choices + + +def create_discovery_results_table(trending_stocks: List[TrendingStock]) -> Table: + table = Table( show_header=True, header_style="bold magenta", - show_footer=False, - expand=True, # Make table expand to fill available space - box=box.MINIMAL, # Use minimal box style for a lighter look - show_lines=True, # Keep horizontal lines - padding=(0, 1), # Add some padding between columns + box=box.ROUNDED, + title="Trending Stocks", + title_style="bold green", + expand=True, ) - messages_table.add_column("Time", style="cyan", width=8, justify="center") - messages_table.add_column("Type", style="green", width=10, justify="center") - messages_table.add_column( - "Content", style="white", no_wrap=False, ratio=1 - ) # Make content column expand - # Combine tool calls and messages - all_messages = [] + table.add_column("Rank", style="cyan", justify="center", width=6) + table.add_column("Ticker", style="bold yellow", justify="center", width=10) + table.add_column("Company", style="white", justify="left", width=25) + table.add_column("Score", style="green", justify="right", width=10) + table.add_column("Mentions", style="blue", justify="center", width=10) + table.add_column("Event Type", style="magenta", justify="center", width=18) - # Add tool calls - for timestamp, tool_name, args in message_buffer.tool_calls: - # Truncate tool call args if too long - if isinstance(args, str) and len(args) > 100: - args = args[:97] + "..." - all_messages.append((timestamp, "Tool", f"{tool_name}: {args}")) + for rank, stock in enumerate(trending_stocks, 1): + if rank <= 3: + rank_display = f"[bold green]{rank}[/bold green]" + ticker_display = f"[bold yellow]{stock.ticker}[/bold yellow]" + else: + rank_display = str(rank) + ticker_display = stock.ticker - # Add regular messages - for timestamp, msg_type, content in message_buffer.messages: - # Convert content to string if it's not already - content_str = content - if isinstance(content, list): - # Handle list of content blocks (Anthropic format) - text_parts = [] - for item in content: - if isinstance(item, dict): - if item.get('type') == 'text': - text_parts.append(item.get('text', '')) - elif item.get('type') == 'tool_use': - text_parts.append(f"[Tool: {item.get('name', 'unknown')}]") - else: - text_parts.append(str(item)) - content_str = ' '.join(text_parts) - elif not isinstance(content_str, str): - content_str = str(content) - - # Truncate message content if too long - if len(content_str) > 200: - content_str = content_str[:197] + "..." - all_messages.append((timestamp, msg_type, content_str)) - - # Sort by timestamp - all_messages.sort(key=lambda x: x[0]) - - # Calculate how many messages we can show based on available space - # Start with a reasonable number and adjust based on content length - max_messages = 12 # Increased from 8 to better fill the space - - # Get the last N messages that will fit in the panel - recent_messages = all_messages[-max_messages:] - - # Add messages to table - for timestamp, msg_type, content in recent_messages: - # Format content with word wrapping - wrapped_content = Text(content, overflow="fold") - messages_table.add_row(timestamp, msg_type, wrapped_content) - - if spinner_text: - messages_table.add_row("", "Spinner", spinner_text) - - # Add a footer to indicate if messages were truncated - if len(all_messages) > max_messages: - messages_table.footer = ( - f"[dim]Showing last {max_messages} of {len(all_messages)} messages[/dim]" + table.add_row( + rank_display, + ticker_display, + stock.company_name[:25] if len(stock.company_name) > 25 else stock.company_name, + f"{stock.score:.2f}", + str(stock.mention_count), + stock.event_type.value.replace("_", " ").title(), ) - layout["messages"].update( - Panel( - messages_table, - title="Messages & Tools", - border_style="blue", - padding=(1, 2), - ) - ) - - # Analysis panel showing current report - if message_buffer.current_report: - layout["analysis"].update( - Panel( - Markdown(message_buffer.current_report), - title="Current Report", - border_style="green", - padding=(1, 2), - ) - ) - else: - layout["analysis"].update( - Panel( - "[italic]Waiting for analysis report...[/italic]", - title="Current Report", - border_style="green", - padding=(1, 2), - ) - ) - - # Footer with statistics - tool_calls_count = len(message_buffer.tool_calls) - llm_calls_count = sum( - 1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning" - ) - reports_count = sum( - 1 for content in message_buffer.report_sections.values() if content is not None - ) - - stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True) - stats_table.add_column("Stats", justify="center") - stats_table.add_row( - f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}" - ) - - layout["footer"].update(Panel(stats_table, border_style="grey50")) + return table -def get_user_selections(): - """Get all user selections before starting the analysis display.""" - # Display ASCII art welcome message - with open("./cli/static/welcome.txt", "r") as f: - welcome_ascii = f.read() +def create_stock_detail_panel(stock: TrendingStock, rank: int) -> Panel: + sentiment_label = "positive" if stock.sentiment > 0.3 else "negative" if stock.sentiment < -0.3 else "neutral" + sentiment_color = "green" if stock.sentiment > 0.3 else "red" if stock.sentiment < -0.3 else "yellow" - # Create welcome box content - welcome_content = f"{welcome_ascii}\n" - welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n" - welcome_content += "[bold]Workflow Steps:[/bold]\n" - welcome_content += "I. Analyst Team → II. Research Team → III. Trader → IV. Risk Management → V. Portfolio Management\n\n" - welcome_content += ( - "[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]" - ) + content = f"""[bold]Rank #{rank}: {stock.ticker} - {stock.company_name}[/bold] - # Create and center the welcome box - welcome_box = Panel( - welcome_content, - border_style="green", +[cyan]Score:[/cyan] {stock.score:.2f} +[cyan]Sentiment:[/cyan] [{sentiment_color}]{stock.sentiment:.2f} ({sentiment_label})[/{sentiment_color}] +[cyan]Sector:[/cyan] {stock.sector.value.replace("_", " ").title()} +[cyan]Event Type:[/cyan] {stock.event_type.value.replace("_", " ").title()} +[cyan]Mentions:[/cyan] {stock.mention_count} + +[bold]News Summary:[/bold] +{stock.news_summary} + +[bold]Top Source Articles:[/bold]""" + + for i, article in enumerate(stock.source_articles[:3], 1): + content += f"\n {i}. [{article.title[:50]}...] - {article.source}" + + return Panel( + content, + title=f"Stock Details: {stock.ticker}", + border_style="cyan", padding=(1, 2), - title="Welcome to TradingAgents", - subtitle="Multi-Agents LLM Financial Trading Framework", ) - console.print(Align.center(welcome_box)) - console.print() # Add a blank line after the welcome box - # Create a boxed questionnaire for each step - def create_question_box(title, prompt, default=None): - box_content = f"[bold]{title}[/bold]\n" - box_content += f"[dim]{prompt}[/dim]" - if default: - box_content += f"\n[dim]Default: {default}[/dim]" - return Panel(box_content, border_style="blue", padding=(1, 2)) - # Step 1: Ticker symbol +def select_stock_for_detail(trending_stocks: List[TrendingStock]) -> Optional[TrendingStock]: + if not trending_stocks: + return None + + choices = [ + questionary.Choice( + f"{i+1}. {stock.ticker} - {stock.company_name} (Score: {stock.score:.2f})", + value=stock + ) + for i, stock in enumerate(trending_stocks) + ] + choices.append(questionary.Choice("Back to menu", value=None)) + + selected = questionary.select( + "Select a stock to view details:", + choices=choices, + instruction="\n- Use arrow keys to navigate\n- Press Enter to select", + style=questionary.Style( + [ + ("selected", "fg:cyan noinherit"), + ("highlighted", "fg:cyan noinherit"), + ("pointer", "fg:cyan noinherit"), + ] + ), + ).ask() + + return selected + + +def discover_trending_flow(): + console.print(Rule("[bold green]Discover Trending Stocks[/bold green]")) + console.print() + console.print( create_question_box( - "Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY" + "Step 1: Lookback Period", + "Select how far back to search for trending stocks" ) ) - selected_ticker = get_ticker() + lookback_period = select_lookback_period() + console.print(f"[green]Selected lookback period:[/green] {lookback_period}") + console.print() - # Step 2: Analysis date - default_date = datetime.datetime.now().strftime("%Y-%m-%d") console.print( create_question_box( - "Step 2: Analysis Date", - "Enter the analysis date (YYYY-MM-DD)", - default_date, + "Step 2: Sector Filter (Optional)", + "Optionally filter results by sector" ) ) - analysis_date = get_analysis_date() + sector_filter = select_sector_filter() + if sector_filter: + console.print(f"[green]Selected sectors:[/green] {', '.join(s.value for s in sector_filter)}") + else: + console.print("[dim]No sector filter applied[/dim]") + console.print() - # Step 3: Select analysts console.print( create_question_box( - "Step 3: Analysts Team", "Select your LLM analyst agents for the analysis" + "Step 3: Event Filter (Optional)", + "Optionally filter results by event type" + ) + ) + event_filter = select_event_filter() + if event_filter: + console.print(f"[green]Selected events:[/green] {', '.join(e.value for e in event_filter)}") + else: + console.print("[dim]No event filter applied[/dim]") + console.print() + + console.print( + create_question_box( + "Step 4: LLM Provider", + "Select your LLM provider for entity extraction" + ) + ) + selected_llm_provider, backend_url = select_llm_provider() + console.print() + + console.print( + create_question_box( + "Step 5: Quick-Thinking Model", + "Select the model for entity extraction" + ) + ) + selected_model = select_shallow_thinking_agent(selected_llm_provider) + console.print() + + config = DEFAULT_CONFIG.copy() + config["llm_provider"] = selected_llm_provider.lower() + config["backend_url"] = backend_url + config["quick_think_llm"] = selected_model + config["deep_think_llm"] = selected_model + + request = DiscoveryRequest( + lookback_period=lookback_period, + sector_filter=sector_filter, + event_filter=event_filter, + max_results=config.get("discovery_max_results", 20), + ) + + discovery_stages = [ + "Fetching news...", + "Extracting entities...", + "Resolving tickers...", + "Calculating scores...", + ] + + result = None + with Live(console=console, refresh_per_second=4) as live: + for i, stage in enumerate(discovery_stages): + progress_panel = Panel( + f"[bold cyan]{stage}[/bold cyan]\n\n" + f"[dim]Stage {i+1} of {len(discovery_stages)}[/dim]", + title="Discovery Progress", + border_style="cyan", + padding=(2, 4), + ) + live.update(Align.center(progress_panel)) + + if i == 0: + try: + graph = TradingAgentsGraph(config=config, debug=False) + result = graph.discover_trending(request) + except Exception as e: + console.print(f"\n[red]Error during discovery: {e}[/red]") + return + + time.sleep(0.5) + + if result is None: + console.print("\n[red]Discovery failed. Please try again.[/red]") + return + + if result.status == DiscoveryStatus.FAILED: + console.print(f"\n[red]Discovery failed: {result.error_message}[/red]") + return + + if result.status == DiscoveryStatus.COMPLETED: + try: + save_path = save_discovery_result(result) + console.print(f"\n[dim]Results saved to: {save_path}[/dim]") + except Exception as e: + console.print(f"\n[yellow]Warning: Could not save results: {e}[/yellow]") + + console.print() + + if not result.trending_stocks: + console.print("[yellow]No trending stocks found matching your criteria.[/yellow]") + return + + console.print(f"[green]Found {len(result.trending_stocks)} trending stocks[/green]") + console.print() + + results_table = create_discovery_results_table(result.trending_stocks) + console.print(results_table) + console.print() + + while True: + selected_stock = select_stock_for_detail(result.trending_stocks) + + if selected_stock is None: + break + + rank = result.trending_stocks.index(selected_stock) + 1 + detail_panel = create_stock_detail_panel(selected_stock, rank) + console.print() + console.print(detail_panel) + console.print() + + analyze_choice = questionary.confirm( + f"Analyze {selected_stock.ticker}?", + default=False, + style=questionary.Style( + [ + ("selected", "fg:green noinherit"), + ("highlighted", "fg:green noinherit"), + ] + ), + ).ask() + + if analyze_choice: + console.print(f"\n[green]Starting analysis for {selected_stock.ticker}...[/green]\n") + run_analysis_for_ticker(selected_stock.ticker, config) + break + + +def run_analysis_for_ticker(ticker: str, config: dict): + analysis_date = datetime.datetime.now().strftime("%Y-%m-%d") + + console.print( + create_question_box( + "Analysts Team", + "Select your LLM analyst agents for the analysis" ) ) selected_analysts = select_analysts() @@ -459,302 +565,32 @@ def get_user_selections(): f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}" ) - # Step 4: Research depth console.print( create_question_box( - "Step 4: Research Depth", "Select your research depth level" + "Research Depth", + "Select your research depth level" ) ) selected_research_depth = select_research_depth() - # Step 5: OpenAI backend console.print( create_question_box( - "Step 5: OpenAI backend", "Select which service to talk to" + "Deep-Thinking Model", + "Select the model for deep analysis" ) ) - selected_llm_provider, backend_url = select_llm_provider() - - # Step 6: Thinking agents - console.print( - create_question_box( - "Step 6: Thinking Agents", "Select your thinking agents for analysis" - ) - ) - selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) - selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider) + llm_provider = config.get("llm_provider", "openai") + selected_deep_thinker = select_deep_thinking_agent(llm_provider.capitalize()) - return { - "ticker": selected_ticker, - "analysis_date": analysis_date, - "analysts": selected_analysts, - "research_depth": selected_research_depth, - "llm_provider": selected_llm_provider.lower(), - "backend_url": backend_url, - "shallow_thinker": selected_shallow_thinker, - "deep_thinker": selected_deep_thinker, - } + config["max_debate_rounds"] = selected_research_depth + config["max_risk_discuss_rounds"] = selected_research_depth + config["deep_think_llm"] = selected_deep_thinker - -def get_ticker(): - """Get ticker symbol from user input.""" - return typer.prompt("", default="SPY") - - -def get_analysis_date(): - """Get the analysis date from user input.""" - while True: - date_str = typer.prompt( - "", default=datetime.datetime.now().strftime("%Y-%m-%d") - ) - try: - # Validate date format and ensure it's not in the future - analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d") - if analysis_date.date() > datetime.datetime.now().date(): - console.print("[red]Error: Analysis date cannot be in the future[/red]") - continue - return date_str - except ValueError: - console.print( - "[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]" - ) - - -def display_complete_report(final_state): - """Display the complete analysis report with team-based panels.""" - console.print("\n[bold green]Complete Analysis Report[/bold green]\n") - - # I. Analyst Team Reports - analyst_reports = [] - - # Market Analyst Report - if final_state.get("market_report"): - analyst_reports.append( - Panel( - Markdown(final_state["market_report"]), - title="Market Analyst", - border_style="blue", - padding=(1, 2), - ) - ) - - # Social Analyst Report - if final_state.get("sentiment_report"): - analyst_reports.append( - Panel( - Markdown(final_state["sentiment_report"]), - title="Social Analyst", - border_style="blue", - padding=(1, 2), - ) - ) - - # News Analyst Report - if final_state.get("news_report"): - analyst_reports.append( - Panel( - Markdown(final_state["news_report"]), - title="News Analyst", - border_style="blue", - padding=(1, 2), - ) - ) - - # Fundamentals Analyst Report - if final_state.get("fundamentals_report"): - analyst_reports.append( - Panel( - Markdown(final_state["fundamentals_report"]), - title="Fundamentals Analyst", - border_style="blue", - padding=(1, 2), - ) - ) - - if analyst_reports: - console.print( - Panel( - Columns(analyst_reports, equal=True, expand=True), - title="I. Analyst Team Reports", - border_style="cyan", - padding=(1, 2), - ) - ) - - # II. Research Team Reports - if final_state.get("investment_debate_state"): - research_reports = [] - debate_state = final_state["investment_debate_state"] - - # Bull Researcher Analysis - if debate_state.get("bull_history"): - research_reports.append( - Panel( - Markdown(debate_state["bull_history"]), - title="Bull Researcher", - border_style="blue", - padding=(1, 2), - ) - ) - - # Bear Researcher Analysis - if debate_state.get("bear_history"): - research_reports.append( - Panel( - Markdown(debate_state["bear_history"]), - title="Bear Researcher", - border_style="blue", - padding=(1, 2), - ) - ) - - # Research Manager Decision - if debate_state.get("judge_decision"): - research_reports.append( - Panel( - Markdown(debate_state["judge_decision"]), - title="Research Manager", - border_style="blue", - padding=(1, 2), - ) - ) - - if research_reports: - console.print( - Panel( - Columns(research_reports, equal=True, expand=True), - title="II. Research Team Decision", - border_style="magenta", - padding=(1, 2), - ) - ) - - # III. Trading Team Reports - if final_state.get("trader_investment_plan"): - console.print( - Panel( - Panel( - Markdown(final_state["trader_investment_plan"]), - title="Trader", - border_style="blue", - padding=(1, 2), - ), - title="III. Trading Team Plan", - border_style="yellow", - padding=(1, 2), - ) - ) - - # IV. Risk Management Team Reports - if final_state.get("risk_debate_state"): - risk_reports = [] - risk_state = final_state["risk_debate_state"] - - # Aggressive (Risky) Analyst Analysis - if risk_state.get("risky_history"): - risk_reports.append( - Panel( - Markdown(risk_state["risky_history"]), - title="Aggressive Analyst", - border_style="blue", - padding=(1, 2), - ) - ) - - # Conservative (Safe) Analyst Analysis - if risk_state.get("safe_history"): - risk_reports.append( - Panel( - Markdown(risk_state["safe_history"]), - title="Conservative Analyst", - border_style="blue", - padding=(1, 2), - ) - ) - - # Neutral Analyst Analysis - if risk_state.get("neutral_history"): - risk_reports.append( - Panel( - Markdown(risk_state["neutral_history"]), - title="Neutral Analyst", - border_style="blue", - padding=(1, 2), - ) - ) - - if risk_reports: - console.print( - Panel( - Columns(risk_reports, equal=True, expand=True), - title="IV. Risk Management Team Decision", - border_style="red", - padding=(1, 2), - ) - ) - - # V. Portfolio Manager Decision - if risk_state.get("judge_decision"): - console.print( - Panel( - Panel( - Markdown(risk_state["judge_decision"]), - title="Portfolio Manager", - border_style="blue", - padding=(1, 2), - ), - title="V. Portfolio Manager Decision", - border_style="green", - padding=(1, 2), - ) - ) - - -def update_research_team_status(status): - """Update status for all research team members and trader.""" - research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"] - for agent in research_team: - message_buffer.update_agent_status(agent, status) - -def extract_content_string(content): - """Extract string content from various message formats.""" - if isinstance(content, str): - return content - elif isinstance(content, list): - # Handle Anthropic's list format - text_parts = [] - for item in content: - if isinstance(item, dict): - if item.get('type') == 'text': - text_parts.append(item.get('text', '')) - elif item.get('type') == 'tool_use': - text_parts.append(f"[Tool: {item.get('name', 'unknown')}]") - else: - text_parts.append(str(item)) - return ' '.join(text_parts) - else: - return str(content) - -def run_analysis(): - # First get all user selections - selections = get_user_selections() - - # Create config with selected research depth - config = DEFAULT_CONFIG.copy() - config["max_debate_rounds"] = selections["research_depth"] - config["max_risk_discuss_rounds"] = selections["research_depth"] - config["quick_think_llm"] = selections["shallow_thinker"] - config["deep_think_llm"] = selections["deep_thinker"] - config["backend_url"] = selections["backend_url"] - config["llm_provider"] = selections["llm_provider"].lower() - - # Initialize the graph graph = TradingAgentsGraph( - [analyst.value for analyst in selections["analysts"]], config=config, debug=True + [analyst.value for analyst in selected_analysts], config=config, debug=True ) - # Create result directory - results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"] + results_dir = Path(config["results_dir"]) / ticker / analysis_date results_dir.mkdir(parents=True, exist_ok=True) report_dir = results_dir / "reports" report_dir.mkdir(parents=True, exist_ok=True) @@ -767,11 +603,757 @@ def run_analysis(): def wrapper(*args, **kwargs): func(*args, **kwargs) timestamp, message_type, content = obj.messages[-1] - content = content.replace("\n", " ") # Replace newlines with spaces + content = content.replace("\n", " ") with open(log_file, "a") as f: f.write(f"{timestamp} [{message_type}] {content}\n") return wrapper - + + def save_tool_call_decorator(obj, func_name): + func = getattr(obj, func_name) + @wraps(func) + def wrapper(*args, **kwargs): + func(*args, **kwargs) + timestamp, tool_name, args = obj.tool_calls[-1] + args_str = ", ".join(f"{k}={v}" for k, v in args.items()) + with open(log_file, "a") as f: + f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n") + return wrapper + + def save_report_section_decorator(obj, func_name): + func = getattr(obj, func_name) + @wraps(func) + def wrapper(section_name, content): + func(section_name, content) + if section_name in obj.report_sections and obj.report_sections[section_name] is not None: + content = obj.report_sections[section_name] + if content: + file_name = f"{section_name}.md" + with open(report_dir / file_name, "w") as f: + f.write(content) + return wrapper + + message_buffer.add_message = save_message_decorator(message_buffer, "add_message") + message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call") + message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section") + + layout = create_layout() + + with Live(layout, refresh_per_second=4) as live: + update_display(layout) + + message_buffer.add_message("System", f"Selected ticker: {ticker}") + message_buffer.add_message("System", f"Analysis date: {analysis_date}") + message_buffer.add_message( + "System", + f"Selected analysts: {', '.join(analyst.value for analyst in selected_analysts)}", + ) + update_display(layout) + + for agent in message_buffer.agent_status: + message_buffer.update_agent_status(agent, "pending") + + for section in message_buffer.report_sections: + message_buffer.report_sections[section] = None + message_buffer.current_report = None + message_buffer.final_report = None + + first_analyst = f"{selected_analysts[0].value.capitalize()} Analyst" + message_buffer.update_agent_status(first_analyst, "in_progress") + update_display(layout) + + spinner_text = f"Analyzing {ticker} on {analysis_date}..." + update_display(layout, spinner_text) + + init_agent_state = graph.propagator.create_initial_state(ticker, analysis_date) + args = graph.propagator.get_graph_args() + + trace = [] + for chunk in graph.graph.stream(init_agent_state, **args): + if len(chunk["messages"]) > 0: + last_message = chunk["messages"][-1] + + if hasattr(last_message, "content"): + content = extract_content_string(last_message.content) + msg_type = "Reasoning" + else: + content = str(last_message) + msg_type = "System" + + message_buffer.add_message(msg_type, content) + + if hasattr(last_message, "tool_calls"): + for tool_call in last_message.tool_calls: + if isinstance(tool_call, dict): + message_buffer.add_tool_call(tool_call["name"], tool_call["args"]) + else: + message_buffer.add_tool_call(tool_call.name, tool_call.args) + + process_chunk_for_display(chunk, selected_analysts) + update_display(layout) + + trace.append(chunk) + + final_state = trace[-1] + decision = graph.process_signal(final_state["final_trade_decision"]) + + for agent in message_buffer.agent_status: + message_buffer.update_agent_status(agent, "completed") + + message_buffer.add_message("Analysis", f"Completed analysis for {analysis_date}") + + for section in message_buffer.report_sections.keys(): + if section in final_state: + message_buffer.update_report_section(section, final_state[section]) + + display_complete_report(final_state) + update_display(layout) + + +def process_chunk_for_display(chunk, selected_analysts): + if "market_report" in chunk and chunk["market_report"]: + message_buffer.update_report_section("market_report", chunk["market_report"]) + message_buffer.update_agent_status("Market Analyst", "completed") + if AnalystType.SOCIAL in selected_analysts: + message_buffer.update_agent_status("Social Analyst", "in_progress") + + if "sentiment_report" in chunk and chunk["sentiment_report"]: + message_buffer.update_report_section("sentiment_report", chunk["sentiment_report"]) + message_buffer.update_agent_status("Social Analyst", "completed") + if AnalystType.NEWS in selected_analysts: + message_buffer.update_agent_status("News Analyst", "in_progress") + + if "news_report" in chunk and chunk["news_report"]: + message_buffer.update_report_section("news_report", chunk["news_report"]) + message_buffer.update_agent_status("News Analyst", "completed") + if AnalystType.FUNDAMENTALS in selected_analysts: + message_buffer.update_agent_status("Fundamentals Analyst", "in_progress") + + if "fundamentals_report" in chunk and chunk["fundamentals_report"]: + message_buffer.update_report_section("fundamentals_report", chunk["fundamentals_report"]) + message_buffer.update_agent_status("Fundamentals Analyst", "completed") + update_research_team_status("in_progress") + + if "investment_debate_state" in chunk and chunk["investment_debate_state"]: + debate_state = chunk["investment_debate_state"] + + if "bull_history" in debate_state and debate_state["bull_history"]: + update_research_team_status("in_progress") + bull_responses = debate_state["bull_history"].split("\n") + latest_bull = bull_responses[-1] if bull_responses else "" + if latest_bull: + message_buffer.add_message("Reasoning", latest_bull) + message_buffer.update_report_section( + "investment_plan", + f"### Bull Researcher Analysis\n{latest_bull}", + ) + + if "bear_history" in debate_state and debate_state["bear_history"]: + update_research_team_status("in_progress") + bear_responses = debate_state["bear_history"].split("\n") + latest_bear = bear_responses[-1] if bear_responses else "" + if latest_bear: + message_buffer.add_message("Reasoning", latest_bear) + message_buffer.update_report_section( + "investment_plan", + f"{message_buffer.report_sections['investment_plan']}\n\n### Bear Researcher Analysis\n{latest_bear}", + ) + + if "judge_decision" in debate_state and debate_state["judge_decision"]: + update_research_team_status("in_progress") + message_buffer.add_message( + "Reasoning", + f"Research Manager: {debate_state['judge_decision']}", + ) + message_buffer.update_report_section( + "investment_plan", + f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}", + ) + update_research_team_status("completed") + message_buffer.update_agent_status("Risky Analyst", "in_progress") + + if "trader_investment_plan" in chunk and chunk["trader_investment_plan"]: + message_buffer.update_report_section("trader_investment_plan", chunk["trader_investment_plan"]) + message_buffer.update_agent_status("Risky Analyst", "in_progress") + + if "risk_debate_state" in chunk and chunk["risk_debate_state"]: + risk_state = chunk["risk_debate_state"] + + if "current_risky_response" in risk_state and risk_state["current_risky_response"]: + message_buffer.update_agent_status("Risky Analyst", "in_progress") + message_buffer.add_message( + "Reasoning", + f"Risky Analyst: {risk_state['current_risky_response']}", + ) + message_buffer.update_report_section( + "final_trade_decision", + f"### Risky Analyst Analysis\n{risk_state['current_risky_response']}", + ) + + if "current_safe_response" in risk_state and risk_state["current_safe_response"]: + message_buffer.update_agent_status("Safe Analyst", "in_progress") + message_buffer.add_message( + "Reasoning", + f"Safe Analyst: {risk_state['current_safe_response']}", + ) + message_buffer.update_report_section( + "final_trade_decision", + f"### Safe Analyst Analysis\n{risk_state['current_safe_response']}", + ) + + if "current_neutral_response" in risk_state and risk_state["current_neutral_response"]: + message_buffer.update_agent_status("Neutral Analyst", "in_progress") + message_buffer.add_message( + "Reasoning", + f"Neutral Analyst: {risk_state['current_neutral_response']}", + ) + message_buffer.update_report_section( + "final_trade_decision", + f"### Neutral Analyst Analysis\n{risk_state['current_neutral_response']}", + ) + + if "judge_decision" in risk_state and risk_state["judge_decision"]: + message_buffer.update_agent_status("Portfolio Manager", "in_progress") + message_buffer.add_message( + "Reasoning", + f"Portfolio Manager: {risk_state['judge_decision']}", + ) + message_buffer.update_report_section( + "final_trade_decision", + f"### Portfolio Manager Decision\n{risk_state['judge_decision']}", + ) + message_buffer.update_agent_status("Risky Analyst", "completed") + message_buffer.update_agent_status("Safe Analyst", "completed") + message_buffer.update_agent_status("Neutral Analyst", "completed") + message_buffer.update_agent_status("Portfolio Manager", "completed") + + +def create_layout(): + layout = Layout() + layout.split_column( + Layout(name="header", size=3), + Layout(name="main"), + Layout(name="footer", size=3), + ) + layout["main"].split_column( + Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5) + ) + layout["upper"].split_row( + Layout(name="progress", ratio=2), Layout(name="messages", ratio=3) + ) + return layout + + +def update_display(layout, spinner_text=None): + layout["header"].update( + Panel( + "[bold green]Welcome to TradingAgents CLI[/bold green]\n" + "[dim]Built by Tauric Research (https://github.com/TauricResearch)[/dim]", + title="Welcome to TradingAgents", + border_style="green", + padding=(1, 2), + expand=True, + ) + ) + + progress_table = Table( + show_header=True, + header_style="bold magenta", + show_footer=False, + box=box.SIMPLE_HEAD, + title=None, + padding=(0, 2), + expand=True, + ) + progress_table.add_column("Team", style="cyan", justify="center", width=20) + progress_table.add_column("Agent", style="green", justify="center", width=20) + progress_table.add_column("Status", style="yellow", justify="center", width=20) + + teams = { + "Analyst Team": [ + "Market Analyst", + "Social Analyst", + "News Analyst", + "Fundamentals Analyst", + ], + "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"], + "Trading Team": ["Trader"], + "Risk Management": ["Risky Analyst", "Neutral Analyst", "Safe Analyst"], + "Portfolio Management": ["Portfolio Manager"], + } + + for team, agents in teams.items(): + first_agent = agents[0] + status = message_buffer.agent_status[first_agent] + if status == "in_progress": + spinner = Spinner( + "dots", text="[blue]in_progress[/blue]", style="bold cyan" + ) + status_cell = spinner + else: + status_color = { + "pending": "yellow", + "completed": "green", + "error": "red", + }.get(status, "white") + status_cell = f"[{status_color}]{status}[/{status_color}]" + progress_table.add_row(team, first_agent, status_cell) + + for agent in agents[1:]: + status = message_buffer.agent_status[agent] + if status == "in_progress": + spinner = Spinner( + "dots", text="[blue]in_progress[/blue]", style="bold cyan" + ) + status_cell = spinner + else: + status_color = { + "pending": "yellow", + "completed": "green", + "error": "red", + }.get(status, "white") + status_cell = f"[{status_color}]{status}[/{status_color}]" + progress_table.add_row("", agent, status_cell) + + progress_table.add_row("-" * 20, "-" * 20, "-" * 20, style="dim") + + layout["progress"].update( + Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2)) + ) + + messages_table = Table( + show_header=True, + header_style="bold magenta", + show_footer=False, + expand=True, + box=box.MINIMAL, + show_lines=True, + padding=(0, 1), + ) + messages_table.add_column("Time", style="cyan", width=8, justify="center") + messages_table.add_column("Type", style="green", width=10, justify="center") + messages_table.add_column("Content", style="white", no_wrap=False, ratio=1) + + all_messages = [] + + for timestamp, tool_name, args in message_buffer.tool_calls: + if isinstance(args, str) and len(args) > 100: + args = args[:97] + "..." + all_messages.append((timestamp, "Tool", f"{tool_name}: {args}")) + + for timestamp, msg_type, content in message_buffer.messages: + content_str = content + if isinstance(content, list): + text_parts = [] + for item in content: + if isinstance(item, dict): + if item.get('type') == 'text': + text_parts.append(item.get('text', '')) + elif item.get('type') == 'tool_use': + text_parts.append(f"[Tool: {item.get('name', 'unknown')}]") + else: + text_parts.append(str(item)) + content_str = ' '.join(text_parts) + elif not isinstance(content_str, str): + content_str = str(content) + + if len(content_str) > 200: + content_str = content_str[:197] + "..." + all_messages.append((timestamp, msg_type, content_str)) + + all_messages.sort(key=lambda x: x[0]) + max_messages = 12 + recent_messages = all_messages[-max_messages:] + + for timestamp, msg_type, content in recent_messages: + wrapped_content = Text(content, overflow="fold") + messages_table.add_row(timestamp, msg_type, wrapped_content) + + if spinner_text: + messages_table.add_row("", "Spinner", spinner_text) + + if len(all_messages) > max_messages: + messages_table.footer = ( + f"[dim]Showing last {max_messages} of {len(all_messages)} messages[/dim]" + ) + + layout["messages"].update( + Panel( + messages_table, + title="Messages & Tools", + border_style="blue", + padding=(1, 2), + ) + ) + + if message_buffer.current_report: + layout["analysis"].update( + Panel( + Markdown(message_buffer.current_report), + title="Current Report", + border_style="green", + padding=(1, 2), + ) + ) + else: + layout["analysis"].update( + Panel( + "[italic]Waiting for analysis report...[/italic]", + title="Current Report", + border_style="green", + padding=(1, 2), + ) + ) + + tool_calls_count = len(message_buffer.tool_calls) + llm_calls_count = sum( + 1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning" + ) + reports_count = sum( + 1 for content in message_buffer.report_sections.values() if content is not None + ) + + stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True) + stats_table.add_column("Stats", justify="center") + stats_table.add_row( + f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}" + ) + + layout["footer"].update(Panel(stats_table, border_style="grey50")) + + +def get_user_selections(): + with open("./cli/static/welcome.txt", "r") as f: + welcome_ascii = f.read() + + welcome_content = f"{welcome_ascii}\n" + welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n" + welcome_content += "[bold]Workflow Steps:[/bold]\n" + welcome_content += "I. Analyst Team -> II. Research Team -> III. Trader -> IV. Risk Management -> V. Portfolio Management\n\n" + welcome_content += "[dim]Built by Tauric Research (https://github.com/TauricResearch)[/dim]" + + welcome_box = Panel( + welcome_content, + border_style="green", + padding=(1, 2), + title="Welcome to TradingAgents", + subtitle="Multi-Agents LLM Financial Trading Framework", + ) + console.print(Align.center(welcome_box)) + console.print() + + console.print( + create_question_box( + "Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY" + ) + ) + selected_ticker = get_ticker() + + default_date = datetime.datetime.now().strftime("%Y-%m-%d") + console.print( + create_question_box( + "Step 2: Analysis Date", + "Enter the analysis date (YYYY-MM-DD)", + default_date, + ) + ) + analysis_date = get_analysis_date() + + console.print( + create_question_box( + "Step 3: Analysts Team", "Select your LLM analyst agents for the analysis" + ) + ) + selected_analysts = select_analysts() + console.print( + f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}" + ) + + console.print( + create_question_box( + "Step 4: Research Depth", "Select your research depth level" + ) + ) + selected_research_depth = select_research_depth() + + console.print( + create_question_box( + "Step 5: OpenAI backend", "Select which service to talk to" + ) + ) + selected_llm_provider, backend_url = select_llm_provider() + + console.print( + create_question_box( + "Step 6: Thinking Agents", "Select your thinking agents for analysis" + ) + ) + selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) + selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider) + + return { + "ticker": selected_ticker, + "analysis_date": analysis_date, + "analysts": selected_analysts, + "research_depth": selected_research_depth, + "llm_provider": selected_llm_provider.lower(), + "backend_url": backend_url, + "shallow_thinker": selected_shallow_thinker, + "deep_thinker": selected_deep_thinker, + } + + +def get_ticker(): + return typer.prompt("", default="SPY") + + +def get_analysis_date(): + while True: + date_str = typer.prompt( + "", default=datetime.datetime.now().strftime("%Y-%m-%d") + ) + try: + analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d") + if analysis_date.date() > datetime.datetime.now().date(): + console.print("[red]Error: Analysis date cannot be in the future[/red]") + continue + return date_str + except ValueError: + console.print( + "[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]" + ) + + +def display_complete_report(final_state): + console.print("\n[bold green]Complete Analysis Report[/bold green]\n") + + analyst_reports = [] + + if final_state.get("market_report"): + analyst_reports.append( + Panel( + Markdown(final_state["market_report"]), + title="Market Analyst", + border_style="blue", + padding=(1, 2), + ) + ) + + if final_state.get("sentiment_report"): + analyst_reports.append( + Panel( + Markdown(final_state["sentiment_report"]), + title="Social Analyst", + border_style="blue", + padding=(1, 2), + ) + ) + + if final_state.get("news_report"): + analyst_reports.append( + Panel( + Markdown(final_state["news_report"]), + title="News Analyst", + border_style="blue", + padding=(1, 2), + ) + ) + + if final_state.get("fundamentals_report"): + analyst_reports.append( + Panel( + Markdown(final_state["fundamentals_report"]), + title="Fundamentals Analyst", + border_style="blue", + padding=(1, 2), + ) + ) + + if analyst_reports: + console.print( + Panel( + Columns(analyst_reports, equal=True, expand=True), + title="I. Analyst Team Reports", + border_style="cyan", + padding=(1, 2), + ) + ) + + if final_state.get("investment_debate_state"): + research_reports = [] + debate_state = final_state["investment_debate_state"] + + if debate_state.get("bull_history"): + research_reports.append( + Panel( + Markdown(debate_state["bull_history"]), + title="Bull Researcher", + border_style="blue", + padding=(1, 2), + ) + ) + + if debate_state.get("bear_history"): + research_reports.append( + Panel( + Markdown(debate_state["bear_history"]), + title="Bear Researcher", + border_style="blue", + padding=(1, 2), + ) + ) + + if debate_state.get("judge_decision"): + research_reports.append( + Panel( + Markdown(debate_state["judge_decision"]), + title="Research Manager", + border_style="blue", + padding=(1, 2), + ) + ) + + if research_reports: + console.print( + Panel( + Columns(research_reports, equal=True, expand=True), + title="II. Research Team Decision", + border_style="magenta", + padding=(1, 2), + ) + ) + + if final_state.get("trader_investment_plan"): + console.print( + Panel( + Panel( + Markdown(final_state["trader_investment_plan"]), + title="Trader", + border_style="blue", + padding=(1, 2), + ), + title="III. Trading Team Plan", + border_style="yellow", + padding=(1, 2), + ) + ) + + if final_state.get("risk_debate_state"): + risk_reports = [] + risk_state = final_state["risk_debate_state"] + + if risk_state.get("risky_history"): + risk_reports.append( + Panel( + Markdown(risk_state["risky_history"]), + title="Aggressive Analyst", + border_style="blue", + padding=(1, 2), + ) + ) + + if risk_state.get("safe_history"): + risk_reports.append( + Panel( + Markdown(risk_state["safe_history"]), + title="Conservative Analyst", + border_style="blue", + padding=(1, 2), + ) + ) + + if risk_state.get("neutral_history"): + risk_reports.append( + Panel( + Markdown(risk_state["neutral_history"]), + title="Neutral Analyst", + border_style="blue", + padding=(1, 2), + ) + ) + + if risk_reports: + console.print( + Panel( + Columns(risk_reports, equal=True, expand=True), + title="IV. Risk Management Team Decision", + border_style="red", + padding=(1, 2), + ) + ) + + if risk_state.get("judge_decision"): + console.print( + Panel( + Panel( + Markdown(risk_state["judge_decision"]), + title="Portfolio Manager", + border_style="blue", + padding=(1, 2), + ), + title="V. Portfolio Manager Decision", + border_style="green", + padding=(1, 2), + ) + ) + + +def update_research_team_status(status): + research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"] + for agent in research_team: + message_buffer.update_agent_status(agent, status) + + +def extract_content_string(content): + if isinstance(content, str): + return content + elif isinstance(content, list): + text_parts = [] + for item in content: + if isinstance(item, dict): + if item.get('type') == 'text': + text_parts.append(item.get('text', '')) + elif item.get('type') == 'tool_use': + text_parts.append(f"[Tool: {item.get('name', 'unknown')}]") + else: + text_parts.append(str(item)) + return ' '.join(text_parts) + else: + return str(content) + + +def run_analysis(): + selections = get_user_selections() + + config = DEFAULT_CONFIG.copy() + config["max_debate_rounds"] = selections["research_depth"] + config["max_risk_discuss_rounds"] = selections["research_depth"] + config["quick_think_llm"] = selections["shallow_thinker"] + config["deep_think_llm"] = selections["deep_thinker"] + config["backend_url"] = selections["backend_url"] + config["llm_provider"] = selections["llm_provider"].lower() + + graph = TradingAgentsGraph( + [analyst.value for analyst in selections["analysts"]], config=config, debug=True + ) + + results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"] + results_dir.mkdir(parents=True, exist_ok=True) + report_dir = results_dir / "reports" + report_dir.mkdir(parents=True, exist_ok=True) + log_file = results_dir / "message_tool.log" + log_file.touch(exist_ok=True) + + def save_message_decorator(obj, func_name): + func = getattr(obj, func_name) + @wraps(func) + def wrapper(*args, **kwargs): + func(*args, **kwargs) + timestamp, message_type, content = obj.messages[-1] + content = content.replace("\n", " ") + with open(log_file, "a") as f: + f.write(f"{timestamp} [{message_type}] {content}\n") + return wrapper + def save_tool_call_decorator(obj, func_name): func = getattr(obj, func_name) @wraps(func) @@ -800,14 +1382,11 @@ def run_analysis(): message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call") message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section") - # Now start the display layout layout = create_layout() with Live(layout, refresh_per_second=4) as live: - # Initial display update_display(layout) - # Add initial messages message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}") message_buffer.add_message( "System", f"Analysis date: {selections['analysis_date']}" @@ -818,55 +1397,44 @@ def run_analysis(): ) update_display(layout) - # Reset agent statuses for agent in message_buffer.agent_status: message_buffer.update_agent_status(agent, "pending") - # Reset report sections for section in message_buffer.report_sections: message_buffer.report_sections[section] = None message_buffer.current_report = None message_buffer.final_report = None - # Update agent status to in_progress for the first analyst first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst" message_buffer.update_agent_status(first_analyst, "in_progress") update_display(layout) - # Create spinner text spinner_text = ( f"Analyzing {selections['ticker']} on {selections['analysis_date']}..." ) update_display(layout, spinner_text) - # Initialize state and get graph args init_agent_state = graph.propagator.create_initial_state( selections["ticker"], selections["analysis_date"] ) args = graph.propagator.get_graph_args() - # Stream the analysis trace = [] for chunk in graph.graph.stream(init_agent_state, **args): if len(chunk["messages"]) > 0: - # Get the last message from the chunk last_message = chunk["messages"][-1] - # Extract message content and type if hasattr(last_message, "content"): - content = extract_content_string(last_message.content) # Use the helper function + content = extract_content_string(last_message.content) msg_type = "Reasoning" else: content = str(last_message) msg_type = "System" - # Add message to buffer - message_buffer.add_message(msg_type, content) + message_buffer.add_message(msg_type, content) - # If it's a tool call, add it to tool calls if hasattr(last_message, "tool_calls"): for tool_call in last_message.tool_calls: - # Handle both dictionary and object tool calls if isinstance(tool_call, dict): message_buffer.add_tool_call( tool_call["name"], tool_call["args"] @@ -874,14 +1442,11 @@ def run_analysis(): else: message_buffer.add_tool_call(tool_call.name, tool_call.args) - # Update reports and agent status based on chunk content - # Analyst Team Reports if "market_report" in chunk and chunk["market_report"]: message_buffer.update_report_section( "market_report", chunk["market_report"] ) message_buffer.update_agent_status("Market Analyst", "completed") - # Set next analyst to in_progress if "social" in selections["analysts"]: message_buffer.update_agent_status( "Social Analyst", "in_progress" @@ -892,7 +1457,6 @@ def run_analysis(): "sentiment_report", chunk["sentiment_report"] ) message_buffer.update_agent_status("Social Analyst", "completed") - # Set next analyst to in_progress if "news" in selections["analysts"]: message_buffer.update_agent_status( "News Analyst", "in_progress" @@ -903,7 +1467,6 @@ def run_analysis(): "news_report", chunk["news_report"] ) message_buffer.update_agent_status("News Analyst", "completed") - # Set next analyst to in_progress if "fundamentals" in selections["analysts"]: message_buffer.update_agent_status( "Fundamentals Analyst", "in_progress" @@ -916,70 +1479,54 @@ def run_analysis(): message_buffer.update_agent_status( "Fundamentals Analyst", "completed" ) - # Set all research team members to in_progress update_research_team_status("in_progress") - # Research Team - Handle Investment Debate State if ( "investment_debate_state" in chunk and chunk["investment_debate_state"] ): debate_state = chunk["investment_debate_state"] - # Update Bull Researcher status and report if "bull_history" in debate_state and debate_state["bull_history"]: - # Keep all research team members in progress update_research_team_status("in_progress") - # Extract latest bull response bull_responses = debate_state["bull_history"].split("\n") latest_bull = bull_responses[-1] if bull_responses else "" if latest_bull: message_buffer.add_message("Reasoning", latest_bull) - # Update research report with bull's latest analysis message_buffer.update_report_section( "investment_plan", f"### Bull Researcher Analysis\n{latest_bull}", ) - # Update Bear Researcher status and report if "bear_history" in debate_state and debate_state["bear_history"]: - # Keep all research team members in progress update_research_team_status("in_progress") - # Extract latest bear response bear_responses = debate_state["bear_history"].split("\n") latest_bear = bear_responses[-1] if bear_responses else "" if latest_bear: message_buffer.add_message("Reasoning", latest_bear) - # Update research report with bear's latest analysis message_buffer.update_report_section( "investment_plan", f"{message_buffer.report_sections['investment_plan']}\n\n### Bear Researcher Analysis\n{latest_bear}", ) - # Update Research Manager status and final decision if ( "judge_decision" in debate_state and debate_state["judge_decision"] ): - # Keep all research team members in progress until final decision update_research_team_status("in_progress") message_buffer.add_message( "Reasoning", f"Research Manager: {debate_state['judge_decision']}", ) - # Update research report with final decision message_buffer.update_report_section( "investment_plan", f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}", ) - # Mark all research team members as completed update_research_team_status("completed") - # Set first risk analyst to in_progress message_buffer.update_agent_status( "Risky Analyst", "in_progress" ) - # Trading Team if ( "trader_investment_plan" in chunk and chunk["trader_investment_plan"] @@ -987,14 +1534,11 @@ def run_analysis(): message_buffer.update_report_section( "trader_investment_plan", chunk["trader_investment_plan"] ) - # Set first risk analyst to in_progress message_buffer.update_agent_status("Risky Analyst", "in_progress") - # Risk Management Team - Handle Risk Debate State if "risk_debate_state" in chunk and chunk["risk_debate_state"]: risk_state = chunk["risk_debate_state"] - # Update Risky Analyst status and report if ( "current_risky_response" in risk_state and risk_state["current_risky_response"] @@ -1006,13 +1550,11 @@ def run_analysis(): "Reasoning", f"Risky Analyst: {risk_state['current_risky_response']}", ) - # Update risk report with risky analyst's latest analysis only message_buffer.update_report_section( "final_trade_decision", f"### Risky Analyst Analysis\n{risk_state['current_risky_response']}", ) - # Update Safe Analyst status and report if ( "current_safe_response" in risk_state and risk_state["current_safe_response"] @@ -1024,13 +1566,11 @@ def run_analysis(): "Reasoning", f"Safe Analyst: {risk_state['current_safe_response']}", ) - # Update risk report with safe analyst's latest analysis only message_buffer.update_report_section( "final_trade_decision", f"### Safe Analyst Analysis\n{risk_state['current_safe_response']}", ) - # Update Neutral Analyst status and report if ( "current_neutral_response" in risk_state and risk_state["current_neutral_response"] @@ -1042,13 +1582,11 @@ def run_analysis(): "Reasoning", f"Neutral Analyst: {risk_state['current_neutral_response']}", ) - # Update risk report with neutral analyst's latest analysis only message_buffer.update_report_section( "final_trade_decision", f"### Neutral Analyst Analysis\n{risk_state['current_neutral_response']}", ) - # Update Portfolio Manager status and final decision if "judge_decision" in risk_state and risk_state["judge_decision"]: message_buffer.update_agent_status( "Portfolio Manager", "in_progress" @@ -1057,12 +1595,10 @@ def run_analysis(): "Reasoning", f"Portfolio Manager: {risk_state['judge_decision']}", ) - # Update risk report with final decision only message_buffer.update_report_section( "final_trade_decision", f"### Portfolio Manager Decision\n{risk_state['judge_decision']}", ) - # Mark risk analysts as completed message_buffer.update_agent_status("Risky Analyst", "completed") message_buffer.update_agent_status("Safe Analyst", "completed") message_buffer.update_agent_status( @@ -1072,16 +1608,13 @@ def run_analysis(): "Portfolio Manager", "completed" ) - # Update the display update_display(layout) trace.append(chunk) - # Get final state and decision final_state = trace[-1] decision = graph.process_signal(final_state["final_trade_decision"]) - # Update all agent statuses to completed for agent in message_buffer.agent_status: message_buffer.update_agent_status(agent, "completed") @@ -1089,21 +1622,85 @@ def run_analysis(): "Analysis", f"Completed analysis for {selections['analysis_date']}" ) - # Update final report sections for section in message_buffer.report_sections.keys(): if section in final_state: message_buffer.update_report_section(section, final_state[section]) - # Display the complete final report display_complete_report(final_state) update_display(layout) +def show_main_menu(): + with open("./cli/static/welcome.txt", "r") as f: + welcome_ascii = f.read() + + welcome_content = f"{welcome_ascii}\n" + welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n" + welcome_content += "[bold]Available Options:[/bold]\n" + welcome_content += "1. Analyze a specific stock\n" + welcome_content += "2. Discover trending stocks\n\n" + welcome_content += "[dim]Built by Tauric Research (https://github.com/TauricResearch)[/dim]" + + welcome_box = Panel( + welcome_content, + border_style="green", + padding=(1, 2), + title="Welcome to TradingAgents", + subtitle="Multi-Agents LLM Financial Trading Framework", + ) + console.print(Align.center(welcome_box)) + console.print() + + MENU_OPTIONS = [ + ("1. Analyze a specific stock", "analyze"), + ("2. Discover trending stocks", "discover"), + ] + + choice = questionary.select( + "Select an option:", + choices=[ + questionary.Choice(display, value=value) for display, value in MENU_OPTIONS + ], + instruction="\n- Use arrow keys to navigate\n- Press Enter to select", + style=questionary.Style( + [ + ("selected", "fg:green noinherit"), + ("highlighted", "fg:green noinherit"), + ("pointer", "fg:green noinherit"), + ] + ), + ).ask() + + if choice is None: + console.print("\n[red]No option selected. Exiting...[/red]") + exit(0) + + return choice + + @app.command() def analyze(): run_analysis() +@app.command() +def discover(): + discover_trending_flow() + + +@app.command() +def menu(): + choice = show_main_menu() + if choice == "analyze": + run_analysis() + elif choice == "discover": + discover_trending_flow() + + if __name__ == "__main__": - app() + choice = show_main_menu() + if choice == "analyze": + run_analysis() + elif choice == "discover": + discover_trending_flow() diff --git a/main.py b/main.py index a85ee6ec..42a45a0d 100644 --- a/main.py +++ b/main.py @@ -8,7 +8,7 @@ load_dotenv() # Create a custom config config = DEFAULT_CONFIG.copy() -config["deep_think_llm"] = "gpt-4o-mini" # Use a different model +config["deep_think_llm"] = "gpt-5" # Use a different model config["quick_think_llm"] = "gpt-4o-mini" # Use a different model config["max_debate_rounds"] = 1 # Increase debate rounds diff --git a/tests/discovery/__init__.py b/tests/discovery/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/discovery/test_api.py b/tests/discovery/test_api.py new file mode 100644 index 00000000..700f351f --- /dev/null +++ b/tests/discovery/test_api.py @@ -0,0 +1,200 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from datetime import datetime, timedelta +import signal + +from tradingagents.agents.discovery import ( + DiscoveryRequest, + DiscoveryResult, + DiscoveryStatus, + TrendingStock, + NewsArticle, + Sector, + EventCategory, + DiscoveryTimeoutError, +) + + +def create_mock_trending_stock( + ticker: str = "AAPL", + company_name: str = "Apple Inc.", + score: float = 10.0, + sector: Sector = Sector.TECHNOLOGY, + event_type: EventCategory = EventCategory.EARNINGS, +) -> TrendingStock: + return TrendingStock( + ticker=ticker, + company_name=company_name, + score=score, + mention_count=5, + sentiment=0.5, + sector=sector, + event_type=event_type, + news_summary="Test news summary", + source_articles=[], + ) + + +def create_mock_news_article() -> NewsArticle: + return NewsArticle( + title="Test Article", + source="Test Source", + url="https://example.com/article", + published_at=datetime.now(), + content_snippet="Test content about Apple stock", + ticker_mentions=["AAPL"], + ) + + +class TestDiscoverTrendingReturnsDiscoveryResult: + @patch("tradingagents.graph.trading_graph.get_bulk_news") + @patch("tradingagents.graph.trading_graph.extract_entities") + @patch("tradingagents.graph.trading_graph.calculate_trending_scores") + def test_discover_trending_returns_discovery_result( + self, mock_scores, mock_extract, mock_bulk_news + ): + mock_bulk_news.return_value = [create_mock_news_article()] + mock_extract.return_value = [] + mock_scores.return_value = [create_mock_trending_stock()] + + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None): + graph = TradingAgentsGraph() + graph.config = { + "discovery_timeout": 60, + "discovery_hard_timeout": 120, + "discovery_cache_ttl": 300, + "discovery_max_results": 20, + "discovery_min_mentions": 2, + } + + result = graph.discover_trending() + + assert isinstance(result, DiscoveryResult) + assert result.status == DiscoveryStatus.COMPLETED + assert len(result.trending_stocks) > 0 + + +class TestAnalyzeTrendingCallsPropagate: + def test_analyze_trending_calls_propagate(self): + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None): + graph = TradingAgentsGraph() + graph.propagate = Mock(return_value=({"final_state": "test"}, "BUY")) + + trending_stock = create_mock_trending_stock() + + result = graph.analyze_trending(trending_stock) + + graph.propagate.assert_called_once() + call_args = graph.propagate.call_args + assert call_args[0][0] == "AAPL" + + +class TestSectorFilterParameter: + @patch("tradingagents.graph.trading_graph.get_bulk_news") + @patch("tradingagents.graph.trading_graph.extract_entities") + @patch("tradingagents.graph.trading_graph.calculate_trending_scores") + def test_sector_filter_filters_results( + self, mock_scores, mock_extract, mock_bulk_news + ): + mock_bulk_news.return_value = [create_mock_news_article()] + mock_extract.return_value = [] + mock_scores.return_value = [ + create_mock_trending_stock(ticker="AAPL", sector=Sector.TECHNOLOGY), + create_mock_trending_stock(ticker="JPM", sector=Sector.FINANCE), + create_mock_trending_stock(ticker="XOM", sector=Sector.ENERGY), + ] + + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None): + graph = TradingAgentsGraph() + graph.config = { + "discovery_timeout": 60, + "discovery_hard_timeout": 120, + "discovery_cache_ttl": 300, + "discovery_max_results": 20, + "discovery_min_mentions": 2, + } + + request = DiscoveryRequest( + lookback_period="24h", + sector_filter=[Sector.TECHNOLOGY], + ) + result = graph.discover_trending(request) + + assert all( + stock.sector == Sector.TECHNOLOGY for stock in result.trending_stocks + ) + + +class TestEventFilterParameter: + @patch("tradingagents.graph.trading_graph.get_bulk_news") + @patch("tradingagents.graph.trading_graph.extract_entities") + @patch("tradingagents.graph.trading_graph.calculate_trending_scores") + def test_event_filter_filters_results( + self, mock_scores, mock_extract, mock_bulk_news + ): + mock_bulk_news.return_value = [create_mock_news_article()] + mock_extract.return_value = [] + mock_scores.return_value = [ + create_mock_trending_stock(ticker="AAPL", event_type=EventCategory.EARNINGS), + create_mock_trending_stock( + ticker="MSFT", event_type=EventCategory.PRODUCT_LAUNCH + ), + create_mock_trending_stock( + ticker="GOOGL", event_type=EventCategory.MERGER_ACQUISITION + ), + ] + + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None): + graph = TradingAgentsGraph() + graph.config = { + "discovery_timeout": 60, + "discovery_hard_timeout": 120, + "discovery_cache_ttl": 300, + "discovery_max_results": 20, + "discovery_min_mentions": 2, + } + + request = DiscoveryRequest( + lookback_period="24h", + event_filter=[EventCategory.EARNINGS], + ) + result = graph.discover_trending(request) + + assert all( + stock.event_type == EventCategory.EARNINGS + for stock in result.trending_stocks + ) + + +class TestTimeoutHandling: + @patch("tradingagents.graph.trading_graph.get_bulk_news") + def test_timeout_raises_discovery_timeout_error(self, mock_bulk_news): + def slow_fetch(*args, **kwargs): + import time + time.sleep(0.5) + return [] + + mock_bulk_news.side_effect = slow_fetch + + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None): + graph = TradingAgentsGraph() + graph.config = { + "discovery_timeout": 60, + "discovery_hard_timeout": 0.1, + "discovery_cache_ttl": 300, + "discovery_max_results": 20, + "discovery_min_mentions": 2, + } + + with pytest.raises(DiscoveryTimeoutError): + graph.discover_trending() diff --git a/tests/discovery/test_bulk_news.py b/tests/discovery/test_bulk_news.py new file mode 100644 index 00000000..b6fb3e0e --- /dev/null +++ b/tests/discovery/test_bulk_news.py @@ -0,0 +1,160 @@ +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch, MagicMock +from tradingagents.agents.discovery import NewsArticle +from tradingagents.dataflows.alpha_vantage_common import AlphaVantageRateLimitError + + +class TestGetBulkNewsReturnsNewsArticles: + def test_get_bulk_news_returns_list_of_news_article_objects(self): + mock_raw_news = [ + { + "title": "Market Update: Tech stocks rally", + "source": "Reuters", + "url": "https://reuters.com/market-update", + "published_at": datetime.now().isoformat(), + "content_snippet": "Technology stocks led gains in early trading...", + }, + { + "title": "Fed signals rate decision", + "source": "Bloomberg", + "url": "https://bloomberg.com/fed-rates", + "published_at": datetime.now().isoformat(), + "content_snippet": "Federal Reserve officials indicated...", + }, + ] + + from tradingagents.dataflows.interface import ( + _bulk_news_cache, + get_bulk_news, + ) + + _bulk_news_cache.clear() + + with patch( + "tradingagents.dataflows.interface._fetch_bulk_news_from_vendor" + ) as mock_fetch: + mock_fetch.return_value = mock_raw_news + + result = get_bulk_news(lookback_period="24h") + + assert isinstance(result, list) + assert len(result) == 2 + for article in result: + assert isinstance(article, NewsArticle) + assert article.title is not None + assert article.source is not None + assert article.url is not None + + +class TestLookbackPeriodParsing: + @pytest.mark.parametrize( + "lookback,expected_hours", + [ + ("1h", 1), + ("6h", 6), + ("24h", 24), + ("7d", 168), + ], + ) + def test_lookback_period_parsing(self, lookback, expected_hours): + from tradingagents.dataflows.interface import parse_lookback_period + + hours = parse_lookback_period(lookback) + assert hours == expected_hours + + def test_invalid_lookback_period_raises_error(self): + from tradingagents.dataflows.interface import parse_lookback_period + + with pytest.raises(ValueError): + parse_lookback_period("invalid") + + +class TestVendorFallback: + def test_vendor_fallback_when_primary_rate_limited(self): + mock_openai_news = [ + { + "title": "Fallback news from OpenAI", + "source": "Web Search", + "url": "https://example.com/fallback", + "published_at": datetime.now().isoformat(), + "content_snippet": "This is fallback content...", + }, + ] + + from tradingagents.dataflows.interface import ( + _bulk_news_cache, + ) + + _bulk_news_cache.clear() + + with patch( + "tradingagents.dataflows.interface.VENDOR_METHODS", + { + "get_bulk_news": { + "alpha_vantage": MagicMock(side_effect=AlphaVantageRateLimitError("Rate limit")), + "openai": MagicMock(return_value=mock_openai_news), + "google": MagicMock(return_value=[]), + } + } + ): + from tradingagents.dataflows.interface import _fetch_bulk_news_from_vendor + + result = _fetch_bulk_news_from_vendor("24h") + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["title"] == "Fallback news from OpenAI" + + +class TestBulkNewsCache: + def test_cache_returns_same_results_within_ttl(self): + from tradingagents.dataflows.interface import ( + _bulk_news_cache, + _get_cached_bulk_news, + _set_cached_bulk_news, + ) + + _bulk_news_cache.clear() + + test_articles = [ + NewsArticle( + title="Cached article", + source="Test Source", + url="https://test.com/cached", + published_at=datetime.now(), + content_snippet="Cached content...", + ticker_mentions=[], + ) + ] + + _set_cached_bulk_news("24h", test_articles) + + cached_result = _get_cached_bulk_news("24h") + assert cached_result is not None + assert len(cached_result) == 1 + assert cached_result[0].title == "Cached article" + + cached_result_again = _get_cached_bulk_news("24h") + assert cached_result_again is not None + assert cached_result_again[0].title == cached_result[0].title + + +class TestEmptyResultHandling: + def test_empty_result_handling(self): + from tradingagents.dataflows.interface import ( + _bulk_news_cache, + get_bulk_news, + ) + + _bulk_news_cache.clear() + + with patch( + "tradingagents.dataflows.interface._fetch_bulk_news_from_vendor" + ) as mock_fetch: + mock_fetch.return_value = [] + + result = get_bulk_news(lookback_period="1h") + + assert isinstance(result, list) + assert len(result) == 0 diff --git a/tests/discovery/test_cli.py b/tests/discovery/test_cli.py new file mode 100644 index 00000000..5b0c56d7 --- /dev/null +++ b/tests/discovery/test_cli.py @@ -0,0 +1,127 @@ +import pytest +from unittest.mock import Mock, patch, MagicMock +from datetime import datetime +from io import StringIO + +from tradingagents.agents.discovery.models import ( + DiscoveryResult, + DiscoveryRequest, + DiscoveryStatus, + TrendingStock, + NewsArticle, + Sector, + EventCategory, +) + + +@pytest.fixture +def sample_trending_stocks(): + article = NewsArticle( + title="Apple announces new iPhone", + source="Reuters", + url="https://reuters.com/article", + published_at=datetime.now(), + content_snippet="Apple Inc. unveiled its latest iPhone model today...", + ticker_mentions=["AAPL"], + ) + return [ + TrendingStock( + ticker="AAPL", + company_name="Apple Inc.", + score=8.5, + mention_count=10, + sentiment=0.7, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.PRODUCT_LAUNCH, + news_summary="Apple announced new iPhone model with enhanced AI capabilities.", + source_articles=[article], + ), + TrendingStock( + ticker="MSFT", + company_name="Microsoft Corporation", + score=7.2, + mention_count=8, + sentiment=0.5, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.EARNINGS, + news_summary="Microsoft reported strong quarterly earnings.", + source_articles=[article], + ), + TrendingStock( + ticker="NVDA", + company_name="NVIDIA Corporation", + score=6.8, + mention_count=6, + sentiment=0.4, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.PRODUCT_LAUNCH, + news_summary="NVIDIA unveiled new AI chips.", + source_articles=[article], + ), + ] + + +@pytest.fixture +def sample_discovery_result(sample_trending_stocks): + request = DiscoveryRequest( + lookback_period="24h", + max_results=20, + ) + return DiscoveryResult( + request=request, + trending_stocks=sample_trending_stocks, + status=DiscoveryStatus.COMPLETED, + started_at=datetime.now(), + completed_at=datetime.now(), + ) + + +class TestDiscoveryMenuOption: + def test_discover_trending_flow_exists(self): + from cli.main import discover_trending_flow + assert callable(discover_trending_flow) + + def test_select_lookback_period_function_exists(self): + from cli.main import select_lookback_period + assert callable(select_lookback_period) + + +class TestLookbackSelection: + @patch("cli.main.questionary.select") + def test_lookback_selection_returns_valid_period(self, mock_select): + mock_select.return_value.ask.return_value = "24h" + from cli.main import select_lookback_period + result = select_lookback_period() + assert result in ["1h", "6h", "24h", "7d"] + + @patch("cli.main.questionary.select") + def test_lookback_selection_handles_all_options(self, mock_select): + from cli.main import select_lookback_period + for period in ["1h", "6h", "24h", "7d"]: + mock_select.return_value.ask.return_value = period + result = select_lookback_period() + assert result == period + + +class TestResultsTableDisplay: + def test_create_discovery_results_table(self, sample_trending_stocks): + from cli.main import create_discovery_results_table + table = create_discovery_results_table(sample_trending_stocks) + assert table is not None + assert table.row_count == len(sample_trending_stocks) + + def test_table_has_correct_columns(self, sample_trending_stocks): + from cli.main import create_discovery_results_table + table = create_discovery_results_table(sample_trending_stocks) + column_names = [col.header for col in table.columns] + expected_columns = ["Rank", "Ticker", "Company", "Score", "Mentions", "Event Type"] + for expected in expected_columns: + assert expected in column_names + + +class TestDetailView: + def test_create_stock_detail_panel(self, sample_trending_stocks): + from cli.main import create_stock_detail_panel + stock = sample_trending_stocks[0] + panel = create_stock_detail_panel(stock, rank=1) + assert panel is not None diff --git a/tests/discovery/test_entity_extractor.py b/tests/discovery/test_entity_extractor.py new file mode 100644 index 00000000..57f9f82b --- /dev/null +++ b/tests/discovery/test_entity_extractor.py @@ -0,0 +1,278 @@ +import pytest +from datetime import datetime +from unittest.mock import patch, MagicMock +from tradingagents.agents.discovery import NewsArticle, EventCategory + + +class TestExtractEntitiesReturnsCompanyMentions: + def test_extract_entities_returns_list_of_company_mentions(self): + from tradingagents.agents.discovery.entity_extractor import ( + extract_entities, + EntityMention, + ) + + articles = [ + NewsArticle( + title="Apple announces new iPhone", + source="Reuters", + url="https://reuters.com/apple", + published_at=datetime.now(), + content_snippet="Apple Inc unveiled its latest iPhone model today with advanced AI features.", + ticker_mentions=[], + ), + ] + + mock_response = MagicMock() + mock_response.entities = [ + MagicMock( + company_name="Apple Inc", + confidence=0.95, + context_snippet="Apple Inc unveiled its latest iPhone", + event_type="product_launch", + sentiment=0.7, + ) + ] + + with patch( + "tradingagents.agents.discovery.entity_extractor._get_llm" + ) as mock_get_llm: + mock_llm = MagicMock() + mock_llm.with_structured_output.return_value.invoke.return_value = ( + mock_response + ) + mock_get_llm.return_value = mock_llm + + result = extract_entities(articles) + + assert isinstance(result, list) + assert len(result) > 0 + assert all(isinstance(m, EntityMention) for m in result) + assert result[0].company_name == "Apple Inc" + + +class TestConfidenceScoreRange: + def test_confidence_score_in_valid_range(self): + from tradingagents.agents.discovery.entity_extractor import ( + extract_entities, + EntityMention, + ) + + articles = [ + NewsArticle( + title="Tesla reports earnings", + source="Bloomberg", + url="https://bloomberg.com/tsla", + published_at=datetime.now(), + content_snippet="Tesla Inc reported strong quarterly earnings beating analyst expectations.", + ticker_mentions=[], + ), + ] + + mock_response = MagicMock() + mock_response.entities = [ + MagicMock( + company_name="Tesla Inc", + confidence=0.88, + context_snippet="Tesla Inc reported strong quarterly earnings", + event_type="earnings", + sentiment=0.5, + ) + ] + + with patch( + "tradingagents.agents.discovery.entity_extractor._get_llm" + ) as mock_get_llm: + mock_llm = MagicMock() + mock_llm.with_structured_output.return_value.invoke.return_value = ( + mock_response + ) + mock_get_llm.return_value = mock_llm + + result = extract_entities(articles) + + for mention in result: + assert 0.0 <= mention.confidence <= 1.0 + + +class TestContextSnippetExtraction: + def test_context_snippet_extraction(self): + from tradingagents.agents.discovery.entity_extractor import ( + extract_entities, + EntityMention, + ) + + articles = [ + NewsArticle( + title="Microsoft acquires gaming company", + source="WSJ", + url="https://wsj.com/msft", + published_at=datetime.now(), + content_snippet="Microsoft Corporation announced today it will acquire a major gaming studio for $10 billion.", + ticker_mentions=[], + ), + ] + + mock_response = MagicMock() + mock_response.entities = [ + MagicMock( + company_name="Microsoft Corporation", + confidence=0.92, + context_snippet="Microsoft Corporation announced today it will acquire", + event_type="merger_acquisition", + sentiment=0.6, + ) + ] + + with patch( + "tradingagents.agents.discovery.entity_extractor._get_llm" + ) as mock_get_llm: + mock_llm = MagicMock() + mock_llm.with_structured_output.return_value.invoke.return_value = ( + mock_response + ) + mock_get_llm.return_value = mock_llm + + result = extract_entities(articles) + + assert len(result) > 0 + for mention in result: + assert mention.context_snippet is not None + assert len(mention.context_snippet) > 0 + assert len(mention.context_snippet) <= 150 + + +class TestBatchProcessing: + def test_batch_processing_of_multiple_articles(self): + from tradingagents.agents.discovery.entity_extractor import ( + extract_entities, + EntityMention, + BATCH_SIZE, + ) + + articles = [ + NewsArticle( + title=f"News article {i}", + source="Reuters", + url=f"https://reuters.com/article{i}", + published_at=datetime.now(), + content_snippet=f"Company {i} announced major developments today.", + ticker_mentions=[], + ) + for i in range(15) + ] + + mock_response = MagicMock() + mock_response.entities = [ + MagicMock( + company_name="Test Company", + confidence=0.85, + context_snippet="Company announced major developments", + event_type="other", + sentiment=0.0, + ) + ] + + with patch( + "tradingagents.agents.discovery.entity_extractor._get_llm" + ) as mock_get_llm: + mock_llm = MagicMock() + structured_llm = MagicMock() + structured_llm.invoke.return_value = mock_response + mock_llm.with_structured_output.return_value = structured_llm + mock_get_llm.return_value = mock_llm + + result = extract_entities(articles) + + expected_batches = (len(articles) + BATCH_SIZE - 1) // BATCH_SIZE + assert structured_llm.invoke.call_count == expected_batches + + +class TestNoCompanyMentions: + def test_handling_of_articles_with_no_company_mentions(self): + from tradingagents.agents.discovery.entity_extractor import ( + extract_entities, + EntityMention, + ) + + articles = [ + NewsArticle( + title="Weather forecast for tomorrow", + source="Weather Channel", + url="https://weather.com/forecast", + published_at=datetime.now(), + content_snippet="Tomorrow will be sunny with temperatures reaching 75 degrees.", + ticker_mentions=[], + ), + ] + + mock_response = MagicMock() + mock_response.entities = [] + + with patch( + "tradingagents.agents.discovery.entity_extractor._get_llm" + ) as mock_get_llm: + mock_llm = MagicMock() + mock_llm.with_structured_output.return_value.invoke.return_value = ( + mock_response + ) + mock_get_llm.return_value = mock_llm + + result = extract_entities(articles) + + assert isinstance(result, list) + assert len(result) == 0 + + +class TestEventTypeClassification: + @pytest.mark.parametrize( + "event_type", + [ + "earnings", + "merger_acquisition", + "regulatory", + "product_launch", + "executive_change", + "other", + ], + ) + def test_event_type_classification(self, event_type): + from tradingagents.agents.discovery.entity_extractor import ( + extract_entities, + EntityMention, + ) + + articles = [ + NewsArticle( + title="Company news", + source="Reuters", + url="https://reuters.com/news", + published_at=datetime.now(), + content_snippet="A company made an announcement today.", + ticker_mentions=[], + ), + ] + + mock_response = MagicMock() + mock_response.entities = [ + MagicMock( + company_name="Test Company", + confidence=0.90, + context_snippet="A company made an announcement", + event_type=event_type, + sentiment=0.0, + ) + ] + + with patch( + "tradingagents.agents.discovery.entity_extractor._get_llm" + ) as mock_get_llm: + mock_llm = MagicMock() + mock_llm.with_structured_output.return_value.invoke.return_value = ( + mock_response + ) + mock_get_llm.return_value = mock_llm + + result = extract_entities(articles) + + assert len(result) > 0 + assert result[0].event_type == EventCategory(event_type) diff --git a/tests/discovery/test_integration.py b/tests/discovery/test_integration.py new file mode 100644 index 00000000..6adba188 --- /dev/null +++ b/tests/discovery/test_integration.py @@ -0,0 +1,489 @@ +import pytest +import math +from datetime import datetime, timedelta +from unittest.mock import patch, MagicMock +from tradingagents.agents.discovery import ( + TrendingStock, + NewsArticle, + DiscoveryRequest, + DiscoveryResult, + DiscoveryStatus, + Sector, + EventCategory, + DiscoveryTimeoutError, + NewsUnavailableError, +) +from tradingagents.agents.discovery.entity_extractor import EntityMention + + +class TestEndToEndDiscoveryFlow: + @patch("tradingagents.graph.trading_graph.get_bulk_news") + @patch("tradingagents.graph.trading_graph.extract_entities") + @patch("tradingagents.graph.trading_graph.calculate_trending_scores") + def test_full_discovery_flow_from_news_to_results( + self, mock_scores, mock_extract, mock_bulk_news + ): + now = datetime.now() + mock_articles = [ + NewsArticle( + title="Apple announces record earnings", + source="Reuters", + url="https://reuters.com/apple-earnings", + published_at=now - timedelta(hours=2), + content_snippet="Apple Inc reported record quarterly earnings...", + ticker_mentions=["AAPL"], + ), + NewsArticle( + title="Apple stock surges on AI news", + source="Bloomberg", + url="https://bloomberg.com/apple-ai", + published_at=now - timedelta(hours=1), + content_snippet="Shares of Apple jumped after AI announcement...", + ticker_mentions=["AAPL"], + ), + ] + mock_bulk_news.return_value = mock_articles + + mock_mentions = [ + EntityMention( + company_name="Apple Inc", + confidence=0.95, + context_snippet="Apple Inc reported record quarterly earnings", + article_id="article_0", + event_type=EventCategory.EARNINGS, + sentiment=0.8, + ), + EntityMention( + company_name="Apple", + confidence=0.92, + context_snippet="Shares of Apple jumped", + article_id="article_1", + event_type=EventCategory.PRODUCT_LAUNCH, + sentiment=0.7, + ), + ] + mock_extract.return_value = mock_mentions + + mock_trending = [ + TrendingStock( + ticker="AAPL", + company_name="Apple Inc.", + score=8.5, + mention_count=2, + sentiment=0.75, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.EARNINGS, + news_summary="Apple reported record earnings and AI progress.", + source_articles=mock_articles, + ), + ] + mock_scores.return_value = mock_trending + + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None): + graph = TradingAgentsGraph() + graph.config = { + "discovery_timeout": 60, + "discovery_hard_timeout": 120, + "discovery_cache_ttl": 300, + "discovery_max_results": 20, + "discovery_min_mentions": 2, + } + + request = DiscoveryRequest(lookback_period="24h") + result = graph.discover_trending(request) + + assert isinstance(result, DiscoveryResult) + assert result.status == DiscoveryStatus.COMPLETED + assert len(result.trending_stocks) == 1 + assert result.trending_stocks[0].ticker == "AAPL" + assert result.trending_stocks[0].mention_count >= 2 + + mock_bulk_news.assert_called_once_with("24h") + mock_extract.assert_called_once() + mock_scores.assert_called_once() + + +class TestEntityExtractionToScoringPipeline: + def test_pipeline_from_extraction_to_scoring(self): + from tradingagents.agents.discovery.scorer import calculate_trending_scores + + now = datetime.now() + articles = [ + NewsArticle( + title="Microsoft cloud revenue grows", + source="WSJ", + url="https://wsj.com/article1", + published_at=now - timedelta(hours=2), + content_snippet="Microsoft Corporation reported strong cloud growth.", + ticker_mentions=["MSFT"], + ), + NewsArticle( + title="Microsoft earnings beat estimates", + source="CNBC", + url="https://cnbc.com/article2", + published_at=now - timedelta(hours=3), + content_snippet="Microsoft earnings exceeded analyst expectations.", + ticker_mentions=["MSFT"], + ), + NewsArticle( + title="Tech stocks rally", + source="Bloomberg", + url="https://bloomberg.com/article3", + published_at=now - timedelta(hours=1), + content_snippet="Technology companies led market gains.", + ticker_mentions=[], + ), + ] + + mentions = [ + EntityMention( + company_name="Microsoft Corporation", + confidence=0.95, + context_snippet="Microsoft Corporation reported strong cloud growth", + article_id="article_0", + event_type=EventCategory.EARNINGS, + sentiment=0.7, + ), + EntityMention( + company_name="Microsoft", + confidence=0.92, + context_snippet="Microsoft earnings exceeded analyst expectations", + article_id="article_1", + event_type=EventCategory.EARNINGS, + sentiment=0.8, + ), + ] + + with patch("tradingagents.agents.discovery.scorer.resolve_ticker") as mock_resolve: + mock_resolve.return_value = "MSFT" + + with patch("tradingagents.agents.discovery.scorer.classify_sector") as mock_sector: + mock_sector.return_value = "technology" + + result = calculate_trending_scores(mentions, articles, min_mentions=2) + + assert len(result) == 1 + assert result[0].ticker == "MSFT" + assert result[0].mention_count == 2 + assert result[0].sentiment > 0 + + +class TestNewsVendorFailureGracefulDegradation: + @patch("tradingagents.graph.trading_graph.get_bulk_news") + def test_news_vendor_failure_with_graceful_degradation(self, mock_bulk_news): + mock_bulk_news.side_effect = NewsUnavailableError("All news vendors failed") + + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None): + graph = TradingAgentsGraph() + graph.config = { + "discovery_timeout": 60, + "discovery_hard_timeout": 120, + "discovery_cache_ttl": 300, + "discovery_max_results": 20, + "discovery_min_mentions": 2, + } + + result = graph.discover_trending() + + assert result.status == DiscoveryStatus.FAILED + assert result.error_message is not None + assert "news" in result.error_message.lower() or "vendor" in result.error_message.lower() + + +class TestTimeoutHandlingWithPartialResults: + @patch("tradingagents.graph.trading_graph.get_bulk_news") + def test_timeout_handling_returns_error(self, mock_bulk_news): + def slow_fetch(*args, **kwargs): + import time + time.sleep(0.3) + return [] + + mock_bulk_news.side_effect = slow_fetch + + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None): + graph = TradingAgentsGraph() + graph.config = { + "discovery_timeout": 60, + "discovery_hard_timeout": 0.1, + "discovery_cache_ttl": 300, + "discovery_max_results": 20, + "discovery_min_mentions": 2, + } + + with pytest.raises(DiscoveryTimeoutError): + graph.discover_trending() + + +class TestNoTrendingStocksFound: + @patch("tradingagents.graph.trading_graph.get_bulk_news") + @patch("tradingagents.graph.trading_graph.extract_entities") + @patch("tradingagents.graph.trading_graph.calculate_trending_scores") + def test_no_trending_stocks_found_returns_empty_list( + self, mock_scores, mock_extract, mock_bulk_news + ): + mock_bulk_news.return_value = [ + NewsArticle( + title="General market update", + source="Reuters", + url="https://reuters.com/general", + published_at=datetime.now(), + content_snippet="Markets were quiet today with no major news.", + ticker_mentions=[], + ), + ] + mock_extract.return_value = [] + mock_scores.return_value = [] + + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None): + graph = TradingAgentsGraph() + graph.config = { + "discovery_timeout": 60, + "discovery_hard_timeout": 120, + "discovery_cache_ttl": 300, + "discovery_max_results": 20, + "discovery_min_mentions": 2, + } + + result = graph.discover_trending() + + assert result.status == DiscoveryStatus.COMPLETED + assert len(result.trending_stocks) == 0 + + +class TestAllStocksFilteredOutBySectorFilter: + @patch("tradingagents.graph.trading_graph.get_bulk_news") + @patch("tradingagents.graph.trading_graph.extract_entities") + @patch("tradingagents.graph.trading_graph.calculate_trending_scores") + def test_all_stocks_filtered_out_by_sector_filter( + self, mock_scores, mock_extract, mock_bulk_news + ): + mock_bulk_news.return_value = [] + mock_extract.return_value = [] + mock_scores.return_value = [ + TrendingStock( + ticker="AAPL", + company_name="Apple Inc.", + score=10.0, + mention_count=5, + sentiment=0.5, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.EARNINGS, + news_summary="Apple earnings", + source_articles=[], + ), + TrendingStock( + ticker="MSFT", + company_name="Microsoft", + score=9.0, + mention_count=4, + sentiment=0.4, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.PRODUCT_LAUNCH, + news_summary="Microsoft product", + source_articles=[], + ), + ] + + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None): + graph = TradingAgentsGraph() + graph.config = { + "discovery_timeout": 60, + "discovery_hard_timeout": 120, + "discovery_cache_ttl": 300, + "discovery_max_results": 20, + "discovery_min_mentions": 2, + } + + request = DiscoveryRequest( + lookback_period="24h", + sector_filter=[Sector.HEALTHCARE], + ) + result = graph.discover_trending(request) + + assert result.status == DiscoveryStatus.COMPLETED + assert len(result.trending_stocks) == 0 + + +class TestAllStocksFilteredOutByEventFilter: + @patch("tradingagents.graph.trading_graph.get_bulk_news") + @patch("tradingagents.graph.trading_graph.extract_entities") + @patch("tradingagents.graph.trading_graph.calculate_trending_scores") + def test_all_stocks_filtered_out_by_event_filter( + self, mock_scores, mock_extract, mock_bulk_news + ): + mock_bulk_news.return_value = [] + mock_extract.return_value = [] + mock_scores.return_value = [ + TrendingStock( + ticker="AAPL", + company_name="Apple Inc.", + score=10.0, + mention_count=5, + sentiment=0.5, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.EARNINGS, + news_summary="Apple earnings", + source_articles=[], + ), + ] + + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None): + graph = TradingAgentsGraph() + graph.config = { + "discovery_timeout": 60, + "discovery_hard_timeout": 120, + "discovery_cache_ttl": 300, + "discovery_max_results": 20, + "discovery_min_mentions": 2, + } + + request = DiscoveryRequest( + lookback_period="24h", + event_filter=[EventCategory.MERGER_ACQUISITION], + ) + result = graph.discover_trending(request) + + assert result.status == DiscoveryStatus.COMPLETED + assert len(result.trending_stocks) == 0 + + +class TestMultipleSectorsAndEventsFiltering: + @patch("tradingagents.graph.trading_graph.get_bulk_news") + @patch("tradingagents.graph.trading_graph.extract_entities") + @patch("tradingagents.graph.trading_graph.calculate_trending_scores") + def test_combined_sector_and_event_filtering( + self, mock_scores, mock_extract, mock_bulk_news + ): + mock_bulk_news.return_value = [] + mock_extract.return_value = [] + mock_scores.return_value = [ + TrendingStock( + ticker="AAPL", + company_name="Apple Inc.", + score=10.0, + mention_count=5, + sentiment=0.5, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.EARNINGS, + news_summary="Apple earnings", + source_articles=[], + ), + TrendingStock( + ticker="JPM", + company_name="JPMorgan Chase", + score=9.0, + mention_count=4, + sentiment=0.4, + sector=Sector.FINANCE, + event_type=EventCategory.EARNINGS, + news_summary="JPM earnings", + source_articles=[], + ), + TrendingStock( + ticker="XOM", + company_name="Exxon Mobil", + score=8.0, + mention_count=3, + sentiment=0.3, + sector=Sector.ENERGY, + event_type=EventCategory.REGULATORY, + news_summary="XOM regulatory news", + source_articles=[], + ), + ] + + from tradingagents.graph.trading_graph import TradingAgentsGraph + + with patch.object(TradingAgentsGraph, "__init__", lambda self, **kwargs: None): + graph = TradingAgentsGraph() + graph.config = { + "discovery_timeout": 60, + "discovery_hard_timeout": 120, + "discovery_cache_ttl": 300, + "discovery_max_results": 20, + "discovery_min_mentions": 2, + } + + request = DiscoveryRequest( + lookback_period="24h", + sector_filter=[Sector.TECHNOLOGY, Sector.FINANCE], + event_filter=[EventCategory.EARNINGS], + ) + result = graph.discover_trending(request) + + assert result.status == DiscoveryStatus.COMPLETED + assert len(result.trending_stocks) == 2 + tickers = [s.ticker for s in result.trending_stocks] + assert "AAPL" in tickers + assert "JPM" in tickers + assert "XOM" not in tickers + + +class TestDiscoveryResultPersistenceIntegration: + def test_discovery_result_can_be_serialized_and_saved(self): + from tradingagents.agents.discovery.persistence import ( + save_discovery_result, + generate_markdown_summary, + ) + import tempfile + import shutil + from pathlib import Path + + article = NewsArticle( + title="Test article", + source="Test", + url="https://test.com", + published_at=datetime.now(), + content_snippet="Test content", + ticker_mentions=["TEST"], + ) + + stock = TrendingStock( + ticker="TEST", + company_name="Test Company", + score=5.0, + mention_count=2, + sentiment=0.5, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.OTHER, + news_summary="Test news summary", + source_articles=[article], + ) + + request = DiscoveryRequest( + lookback_period="24h", + created_at=datetime.now(), + ) + + result = DiscoveryResult( + request=request, + trending_stocks=[stock], + status=DiscoveryStatus.COMPLETED, + started_at=datetime.now(), + completed_at=datetime.now(), + ) + + temp_dir = tempfile.mkdtemp() + try: + path = save_discovery_result(result, base_path=Path(temp_dir)) + assert path.exists() + assert (path / "discovery_result.json").exists() + assert (path / "discovery_summary.md").exists() + + markdown = generate_markdown_summary(result) + assert "TEST" in markdown + assert "Test Company" in markdown + finally: + shutil.rmtree(temp_dir) diff --git a/tests/discovery/test_models.py b/tests/discovery/test_models.py new file mode 100644 index 00000000..7717d022 --- /dev/null +++ b/tests/discovery/test_models.py @@ -0,0 +1,196 @@ +import pytest +from datetime import datetime +from tradingagents.agents.discovery import ( + TrendingStock, + NewsArticle, + DiscoveryRequest, + DiscoveryResult, + Sector, + EventCategory, +) +from tradingagents.agents.discovery.models import DiscoveryStatus + + +class TestTrendingStock: + def test_trending_stock_creation_and_validation(self): + article = NewsArticle( + title="Apple announces new iPhone", + source="Reuters", + url="https://reuters.com/article1", + published_at=datetime(2024, 1, 15, 10, 30, 0), + content_snippet="Apple Inc announced its latest iPhone model today...", + ticker_mentions=["AAPL"], + ) + + stock = TrendingStock( + ticker="AAPL", + company_name="Apple Inc.", + score=85.5, + mention_count=10, + sentiment=0.75, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.PRODUCT_LAUNCH, + news_summary="Apple announced new iPhone with advanced AI features.", + source_articles=[article], + ) + + assert stock.ticker == "AAPL" + assert stock.company_name == "Apple Inc." + assert stock.score == 85.5 + assert stock.mention_count == 10 + assert stock.sentiment == 0.75 + assert stock.sector == Sector.TECHNOLOGY + assert stock.event_type == EventCategory.PRODUCT_LAUNCH + assert len(stock.source_articles) == 1 + + +class TestNewsArticle: + def test_news_article_with_required_fields(self): + published = datetime(2024, 1, 15, 14, 0, 0) + + article = NewsArticle( + title="Tesla Q4 Earnings Beat Expectations", + source="Bloomberg", + url="https://bloomberg.com/news/tsla-earnings", + published_at=published, + content_snippet="Tesla Inc. reported fourth quarter earnings that exceeded analyst expectations...", + ticker_mentions=["TSLA", "F"], + ) + + assert article.title == "Tesla Q4 Earnings Beat Expectations" + assert article.source == "Bloomberg" + assert article.url == "https://bloomberg.com/news/tsla-earnings" + assert article.published_at == published + assert article.content_snippet.startswith("Tesla Inc.") + assert "TSLA" in article.ticker_mentions + assert "F" in article.ticker_mentions + + +class TestDiscoveryRequest: + def test_discovery_request_with_lookback_period_validation(self): + created = datetime(2024, 1, 15, 12, 0, 0) + + request = DiscoveryRequest( + lookback_period="24h", + sector_filter=[Sector.TECHNOLOGY, Sector.HEALTHCARE], + event_filter=[EventCategory.EARNINGS], + max_results=20, + created_at=created, + ) + + assert request.lookback_period == "24h" + assert Sector.TECHNOLOGY in request.sector_filter + assert Sector.HEALTHCARE in request.sector_filter + assert EventCategory.EARNINGS in request.event_filter + assert request.max_results == 20 + assert request.created_at == created + + def test_discovery_request_with_defaults(self): + request = DiscoveryRequest( + lookback_period="1h", + ) + + assert request.lookback_period == "1h" + assert request.sector_filter is None + assert request.event_filter is None + assert request.max_results == 20 + assert request.created_at is not None + + +class TestDiscoveryResult: + def test_discovery_result_state_transitions(self): + request = DiscoveryRequest(lookback_period="6h") + started = datetime(2024, 1, 15, 12, 0, 0) + + result = DiscoveryResult( + request=request, + trending_stocks=[], + status=DiscoveryStatus.CREATED, + started_at=started, + ) + + assert result.status == DiscoveryStatus.CREATED + + result.status = DiscoveryStatus.PROCESSING + assert result.status == DiscoveryStatus.PROCESSING + + result.status = DiscoveryStatus.COMPLETED + result.completed_at = datetime(2024, 1, 15, 12, 1, 0) + assert result.status == DiscoveryStatus.COMPLETED + assert result.completed_at is not None + + def test_discovery_result_failed_state(self): + request = DiscoveryRequest(lookback_period="7d") + + result = DiscoveryResult( + request=request, + trending_stocks=[], + status=DiscoveryStatus.FAILED, + started_at=datetime(2024, 1, 15, 12, 0, 0), + error_message="News API rate limit exceeded", + ) + + assert result.status == DiscoveryStatus.FAILED + assert result.error_message == "News API rate limit exceeded" + + +class TestSerializationRoundtrip: + def test_to_dict_and_from_dict_serialization_roundtrip(self): + article = NewsArticle( + title="Microsoft acquires AI startup", + source="WSJ", + url="https://wsj.com/msft-acquisition", + published_at=datetime(2024, 1, 15, 9, 0, 0), + content_snippet="Microsoft Corp announced the acquisition of an AI startup...", + ticker_mentions=["MSFT"], + ) + + stock = TrendingStock( + ticker="MSFT", + company_name="Microsoft Corporation", + score=92.3, + mention_count=15, + sentiment=0.65, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.MERGER_ACQUISITION, + news_summary="Microsoft announces major AI acquisition.", + source_articles=[article], + ) + + request = DiscoveryRequest( + lookback_period="24h", + sector_filter=[Sector.TECHNOLOGY], + event_filter=[EventCategory.MERGER_ACQUISITION], + max_results=10, + created_at=datetime(2024, 1, 15, 8, 0, 0), + ) + + result = DiscoveryResult( + request=request, + trending_stocks=[stock], + status=DiscoveryStatus.COMPLETED, + started_at=datetime(2024, 1, 15, 8, 0, 0), + completed_at=datetime(2024, 1, 15, 8, 1, 30), + ) + + result_dict = result.to_dict() + restored_result = DiscoveryResult.from_dict(result_dict) + + assert restored_result.status == result.status + assert restored_result.request.lookback_period == request.lookback_period + assert len(restored_result.trending_stocks) == 1 + + restored_stock = restored_result.trending_stocks[0] + assert restored_stock.ticker == stock.ticker + assert restored_stock.company_name == stock.company_name + assert restored_stock.score == stock.score + assert restored_stock.mention_count == stock.mention_count + assert restored_stock.sentiment == stock.sentiment + assert restored_stock.sector == stock.sector + assert restored_stock.event_type == stock.event_type + + assert len(restored_stock.source_articles) == 1 + restored_article = restored_stock.source_articles[0] + assert restored_article.title == article.title + assert restored_article.source == article.source + assert restored_article.url == article.url diff --git a/tests/discovery/test_persistence.py b/tests/discovery/test_persistence.py new file mode 100644 index 00000000..e649b02a --- /dev/null +++ b/tests/discovery/test_persistence.py @@ -0,0 +1,228 @@ +import pytest +import json +from datetime import datetime +from pathlib import Path +import tempfile +import shutil + +from tradingagents.agents.discovery import ( + TrendingStock, + NewsArticle, + DiscoveryRequest, + DiscoveryResult, + DiscoveryStatus, + Sector, + EventCategory, +) +from tradingagents.agents.discovery.persistence import ( + save_discovery_result, + generate_markdown_summary, +) + + +@pytest.fixture +def sample_discovery_result(): + articles = [ + NewsArticle( + title="Apple announces new iPhone with AI features", + source="Reuters", + url="https://reuters.com/apple-iphone-ai", + published_at=datetime(2024, 1, 15, 10, 30, 0), + content_snippet="Apple Inc announced its latest iPhone model with advanced AI...", + ticker_mentions=["AAPL"], + ), + NewsArticle( + title="Apple stock surges on earnings beat", + source="Bloomberg", + url="https://bloomberg.com/apple-earnings", + published_at=datetime(2024, 1, 15, 11, 0, 0), + content_snippet="Shares of Apple Inc surged after the company reported...", + ticker_mentions=["AAPL"], + ), + NewsArticle( + title="Microsoft cloud revenue grows 25%", + source="WSJ", + url="https://wsj.com/msft-cloud", + published_at=datetime(2024, 1, 15, 9, 0, 0), + content_snippet="Microsoft Corp reported strong cloud revenue growth...", + ticker_mentions=["MSFT"], + ), + ] + + stocks = [ + TrendingStock( + ticker="AAPL", + company_name="Apple Inc.", + score=8.54, + mention_count=12, + sentiment=0.72, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.EARNINGS, + news_summary="Apple reported strong earnings and announced new AI features.", + source_articles=[articles[0], articles[1]], + ), + TrendingStock( + ticker="MSFT", + company_name="Microsoft Corporation", + score=7.23, + mention_count=9, + sentiment=0.65, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.PRODUCT_LAUNCH, + news_summary="Microsoft cloud business continues strong growth.", + source_articles=[articles[2]], + ), + TrendingStock( + ticker="GOOGL", + company_name="Alphabet Inc.", + score=6.15, + mention_count=7, + sentiment=0.58, + sector=Sector.TECHNOLOGY, + event_type=EventCategory.REGULATORY, + news_summary="Google faces regulatory scrutiny in multiple markets.", + source_articles=[], + ), + ] + + request = DiscoveryRequest( + lookback_period="24h", + sector_filter=[Sector.TECHNOLOGY], + event_filter=[EventCategory.EARNINGS], + max_results=20, + created_at=datetime(2024, 1, 15, 14, 30, 45), + ) + + return DiscoveryResult( + request=request, + trending_stocks=stocks, + status=DiscoveryStatus.COMPLETED, + started_at=datetime(2024, 1, 15, 14, 30, 45), + completed_at=datetime(2024, 1, 15, 14, 31, 30), + ) + + +@pytest.fixture +def temp_results_dir(): + temp_dir = tempfile.mkdtemp() + yield Path(temp_dir) + shutil.rmtree(temp_dir) + + +class TestDirectoryStructureCreation: + def test_creates_correct_directory_structure(self, sample_discovery_result, temp_results_dir): + result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir) + + assert result_path.exists() + assert result_path.is_dir() + + path_parts = result_path.parts + assert "discovery" in path_parts + + date_part = path_parts[-2] + time_part = path_parts[-1] + + assert len(date_part.split("-")) == 3 + assert len(time_part.split("-")) == 3 + + +class TestDiscoveryResultJson: + def test_discovery_result_json_contains_all_fields(self, sample_discovery_result, temp_results_dir): + result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir) + + json_path = result_path / "discovery_result.json" + assert json_path.exists() + + with open(json_path, "r") as f: + saved_data = json.load(f) + + assert "request" in saved_data + assert "trending_stocks" in saved_data + assert "status" in saved_data + assert "started_at" in saved_data + assert "completed_at" in saved_data + + assert saved_data["request"]["lookback_period"] == "24h" + assert saved_data["status"] == "completed" + assert len(saved_data["trending_stocks"]) == 3 + + first_stock = saved_data["trending_stocks"][0] + assert first_stock["ticker"] == "AAPL" + assert first_stock["company_name"] == "Apple Inc." + assert first_stock["score"] == 8.54 + assert first_stock["mention_count"] == 12 + assert first_stock["sentiment"] == 0.72 + assert first_stock["sector"] == "technology" + assert first_stock["event_type"] == "earnings" + assert "news_summary" in first_stock + assert "source_articles" in first_stock + + +class TestDiscoverySummaryMarkdown: + def test_discovery_summary_md_is_human_readable(self, sample_discovery_result, temp_results_dir): + result_path = save_discovery_result(sample_discovery_result, base_path=temp_results_dir) + + md_path = result_path / "discovery_summary.md" + assert md_path.exists() + + with open(md_path, "r") as f: + markdown_content = f.read() + + assert "# Discovery Results" in markdown_content + assert "Timestamp:" in markdown_content + assert "Lookback Period:" in markdown_content + assert "24h" in markdown_content + assert "Total Stocks Found:" in markdown_content + + assert "## Trending Stocks" in markdown_content + assert "| Rank |" in markdown_content + assert "| Ticker |" in markdown_content + assert "| Company |" in markdown_content + assert "| Score |" in markdown_content + assert "| Mentions |" in markdown_content + assert "| Event |" in markdown_content + + assert "AAPL" in markdown_content + assert "Apple Inc." in markdown_content + assert "8.54" in markdown_content + assert "12" in markdown_content + assert "earnings" in markdown_content + + assert "MSFT" in markdown_content + assert "Microsoft Corporation" in markdown_content + + assert "## Top 3 Detailed Analysis" in markdown_content + assert "### 1. AAPL - Apple Inc." in markdown_content + assert "**Score:**" in markdown_content + assert "**Sentiment:**" in markdown_content + assert "**Sector:**" in markdown_content + assert "**Event Type:**" in markdown_content + assert "**Mentions:**" in markdown_content + assert "**News Summary:**" in markdown_content + + +class TestMarkdownGeneration: + def test_generate_markdown_with_filters(self, sample_discovery_result): + markdown = generate_markdown_summary(sample_discovery_result) + + assert "sector=technology" in markdown.lower() + assert "event=earnings" in markdown.lower() + + def test_generate_markdown_without_filters(self): + request = DiscoveryRequest( + lookback_period="6h", + created_at=datetime(2024, 1, 15, 10, 0, 0), + ) + + result = DiscoveryResult( + request=request, + trending_stocks=[], + status=DiscoveryStatus.COMPLETED, + started_at=datetime(2024, 1, 15, 10, 0, 0), + completed_at=datetime(2024, 1, 15, 10, 1, 0), + ) + + markdown = generate_markdown_summary(result) + + assert "Filters:" in markdown + assert "None" in markdown diff --git a/tests/discovery/test_scorer.py b/tests/discovery/test_scorer.py new file mode 100644 index 00000000..e40b778e --- /dev/null +++ b/tests/discovery/test_scorer.py @@ -0,0 +1,469 @@ +import pytest +import math +from datetime import datetime, timedelta +from unittest.mock import patch +from tradingagents.agents.discovery import NewsArticle, EventCategory, Sector +from tradingagents.agents.discovery.entity_extractor import EntityMention + + +class TestFrequencyCalculation: + def test_frequency_calculation_unique_article_count(self): + from tradingagents.agents.discovery.scorer import calculate_trending_scores + + now = datetime.now() + articles = [ + NewsArticle( + title="Apple Q4 Earnings", + source="Reuters", + url="https://reuters.com/article1", + published_at=now - timedelta(hours=1), + content_snippet="Apple Inc reported strong earnings.", + ticker_mentions=["AAPL"], + ), + NewsArticle( + title="Apple iPhone Sales", + source="Bloomberg", + url="https://bloomberg.com/article2", + published_at=now - timedelta(hours=2), + content_snippet="Apple saw record iPhone sales.", + ticker_mentions=["AAPL"], + ), + NewsArticle( + title="Apple AI Features", + source="WSJ", + url="https://wsj.com/article3", + published_at=now - timedelta(hours=3), + content_snippet="Apple announced AI features.", + ticker_mentions=["AAPL"], + ), + ] + + mentions = [ + EntityMention( + company_name="Apple Inc", + confidence=0.95, + context_snippet="Apple Inc reported strong earnings", + article_id="article_0", + event_type=EventCategory.EARNINGS, + ), + EntityMention( + company_name="Apple", + confidence=0.90, + context_snippet="Apple saw record iPhone sales", + article_id="article_1", + event_type=EventCategory.EARNINGS, + ), + EntityMention( + company_name="Apple Inc.", + confidence=0.92, + context_snippet="Apple announced AI features", + article_id="article_2", + event_type=EventCategory.PRODUCT_LAUNCH, + ), + ] + + with patch( + "tradingagents.agents.discovery.scorer.resolve_ticker" + ) as mock_resolve: + mock_resolve.return_value = "AAPL" + + with patch( + "tradingagents.agents.discovery.scorer.classify_sector" + ) as mock_sector: + mock_sector.return_value = "technology" + + result = calculate_trending_scores(mentions, articles) + + assert len(result) == 1 + assert result[0].ticker == "AAPL" + assert result[0].mention_count == 3 + + +class TestSentimentIntensityFactor: + def test_sentiment_intensity_uses_absolute_value(self): + from tradingagents.agents.discovery.scorer import calculate_trending_scores + + now = datetime.now() + articles = [ + NewsArticle( + title="Stock drops sharply", + source="Reuters", + url="https://reuters.com/article1", + published_at=now - timedelta(hours=1), + content_snippet="Company faced major issues.", + ticker_mentions=["TSLA"], + ), + NewsArticle( + title="More bad news", + source="Bloomberg", + url="https://bloomberg.com/article2", + published_at=now - timedelta(hours=2), + content_snippet="Further decline expected.", + ticker_mentions=["TSLA"], + ), + ] + + mentions = [ + EntityMention( + company_name="Tesla", + confidence=0.95, + context_snippet="Company faced major issues", + article_id="article_0", + event_type=EventCategory.OTHER, + sentiment=-0.8, + ), + EntityMention( + company_name="Tesla Inc", + confidence=0.90, + context_snippet="Further decline expected", + article_id="article_1", + event_type=EventCategory.OTHER, + sentiment=-0.6, + ), + ] + + with patch( + "tradingagents.agents.discovery.scorer.resolve_ticker" + ) as mock_resolve: + mock_resolve.return_value = "TSLA" + + with patch( + "tradingagents.agents.discovery.scorer.classify_sector" + ) as mock_sector: + mock_sector.return_value = "technology" + + result = calculate_trending_scores(mentions, articles) + + assert len(result) == 1 + assert result[0].sentiment < 0 + expected_sentiment = (-0.8 * 0.95 + -0.6 * 0.90) / (0.95 + 0.90) + assert abs(result[0].sentiment - expected_sentiment) < 0.01 + + +class TestRecencyWeightExponentialDecay: + def test_recency_weight_exponential_decay(self): + from tradingagents.agents.discovery.scorer import calculate_trending_scores + + now = datetime.now() + articles = [ + NewsArticle( + title="Recent news", + source="Reuters", + url="https://reuters.com/article1", + published_at=now - timedelta(hours=1), + content_snippet="Recent company news.", + ticker_mentions=["NVDA"], + ), + NewsArticle( + title="Older news", + source="Bloomberg", + url="https://bloomberg.com/article2", + published_at=now - timedelta(hours=10), + content_snippet="Older company news.", + ticker_mentions=["NVDA"], + ), + ] + + mentions = [ + EntityMention( + company_name="Nvidia", + confidence=0.90, + context_snippet="Recent company news", + article_id="article_0", + event_type=EventCategory.OTHER, + sentiment=0.5, + ), + EntityMention( + company_name="Nvidia", + confidence=0.90, + context_snippet="Older company news", + article_id="article_1", + event_type=EventCategory.OTHER, + sentiment=0.5, + ), + ] + + with patch( + "tradingagents.agents.discovery.scorer.resolve_ticker" + ) as mock_resolve: + mock_resolve.return_value = "NVDA" + + with patch( + "tradingagents.agents.discovery.scorer.classify_sector" + ) as mock_sector: + mock_sector.return_value = "technology" + + result = calculate_trending_scores(mentions, articles, decay_rate=0.1) + + assert len(result) == 1 + recent_weight = math.exp(-0.1 * 1) + older_weight = math.exp(-0.1 * 10) + avg_recency = (recent_weight + older_weight) / 2 + assert result[0].score > 0 + + +class TestMinimumThresholdFiltering: + def test_minimum_threshold_filtering_requires_two_articles(self): + from tradingagents.agents.discovery.scorer import calculate_trending_scores + + now = datetime.now() + articles = [ + NewsArticle( + title="Single mention stock", + source="Reuters", + url="https://reuters.com/article1", + published_at=now - timedelta(hours=1), + content_snippet="Some company news.", + ticker_mentions=["AMD"], + ), + NewsArticle( + title="Multiple mention stock 1", + source="Bloomberg", + url="https://bloomberg.com/article2", + published_at=now - timedelta(hours=2), + content_snippet="Popular company news.", + ticker_mentions=["MSFT"], + ), + NewsArticle( + title="Multiple mention stock 2", + source="WSJ", + url="https://wsj.com/article3", + published_at=now - timedelta(hours=3), + content_snippet="More popular company news.", + ticker_mentions=["MSFT"], + ), + ] + + mentions = [ + EntityMention( + company_name="AMD", + confidence=0.90, + context_snippet="Some company news", + article_id="article_0", + event_type=EventCategory.OTHER, + ), + EntityMention( + company_name="Microsoft", + confidence=0.95, + context_snippet="Popular company news", + article_id="article_1", + event_type=EventCategory.OTHER, + ), + EntityMention( + company_name="Microsoft Corp", + confidence=0.92, + context_snippet="More popular company news", + article_id="article_2", + event_type=EventCategory.OTHER, + ), + ] + + with patch( + "tradingagents.agents.discovery.scorer.resolve_ticker" + ) as mock_resolve: + + def resolve_side_effect(name): + if "AMD" in name or name == "AMD": + return "AMD" + return "MSFT" + + mock_resolve.side_effect = resolve_side_effect + + with patch( + "tradingagents.agents.discovery.scorer.classify_sector" + ) as mock_sector: + mock_sector.return_value = "technology" + + result = calculate_trending_scores(mentions, articles, min_mentions=2) + + assert len(result) == 1 + assert result[0].ticker == "MSFT" + assert all(stock.mention_count >= 2 for stock in result) + + +class TestFinalScoreFormulaCorrectness: + def test_final_score_formula_correctness(self): + from tradingagents.agents.discovery.scorer import calculate_trending_scores + + now = datetime.now() + hours_old = 2.0 + articles = [ + NewsArticle( + title="Test article 1", + source="Reuters", + url="https://reuters.com/article1", + published_at=now - timedelta(hours=hours_old), + content_snippet="Google announced results.", + ticker_mentions=["GOOGL"], + ), + NewsArticle( + title="Test article 2", + source="Bloomberg", + url="https://bloomberg.com/article2", + published_at=now - timedelta(hours=hours_old), + content_snippet="Alphabet earnings beat.", + ticker_mentions=["GOOGL"], + ), + ] + + sentiment_val = 0.6 + confidence = 0.9 + mentions = [ + EntityMention( + company_name="Google", + confidence=confidence, + context_snippet="Google announced results", + article_id="article_0", + event_type=EventCategory.EARNINGS, + sentiment=sentiment_val, + ), + EntityMention( + company_name="Alphabet", + confidence=confidence, + context_snippet="Alphabet earnings beat", + article_id="article_1", + event_type=EventCategory.EARNINGS, + sentiment=sentiment_val, + ), + ] + + decay_rate = 0.1 + with patch( + "tradingagents.agents.discovery.scorer.resolve_ticker" + ) as mock_resolve: + mock_resolve.return_value = "GOOGL" + + with patch( + "tradingagents.agents.discovery.scorer.classify_sector" + ) as mock_sector: + mock_sector.return_value = "technology" + + result = calculate_trending_scores( + mentions, articles, decay_rate=decay_rate + ) + + assert len(result) == 1 + stock = result[0] + + frequency = 2 + sentiment_factor = 1 + abs(sentiment_val) + recency_weight = math.exp(-decay_rate * hours_old) + expected_score = frequency * sentiment_factor * recency_weight + + assert abs(stock.score - expected_score) < 0.01 + + +class TestSortingByScoreDescending: + def test_results_sorted_by_score_descending(self): + from tradingagents.agents.discovery.scorer import calculate_trending_scores + + now = datetime.now() + articles = [ + NewsArticle( + title="High score stock 1", + source="Reuters", + url="https://reuters.com/article1", + published_at=now - timedelta(hours=1), + content_snippet="Apple news.", + ticker_mentions=["AAPL"], + ), + NewsArticle( + title="High score stock 2", + source="Bloomberg", + url="https://bloomberg.com/article2", + published_at=now - timedelta(hours=1), + content_snippet="More Apple news.", + ticker_mentions=["AAPL"], + ), + NewsArticle( + title="High score stock 3", + source="WSJ", + url="https://wsj.com/article3", + published_at=now - timedelta(hours=1), + content_snippet="Even more Apple news.", + ticker_mentions=["AAPL"], + ), + NewsArticle( + title="Low score stock 1", + source="CNBC", + url="https://cnbc.com/article4", + published_at=now - timedelta(hours=10), + content_snippet="Tesla news.", + ticker_mentions=["TSLA"], + ), + NewsArticle( + title="Low score stock 2", + source="FT", + url="https://ft.com/article5", + published_at=now - timedelta(hours=10), + content_snippet="More Tesla news.", + ticker_mentions=["TSLA"], + ), + ] + + mentions = [ + EntityMention( + company_name="Apple", + confidence=0.95, + context_snippet="Apple news", + article_id="article_0", + event_type=EventCategory.OTHER, + sentiment=0.8, + ), + EntityMention( + company_name="Apple Inc", + confidence=0.93, + context_snippet="More Apple news", + article_id="article_1", + event_type=EventCategory.OTHER, + sentiment=0.8, + ), + EntityMention( + company_name="Apple", + confidence=0.90, + context_snippet="Even more Apple news", + article_id="article_2", + event_type=EventCategory.OTHER, + sentiment=0.8, + ), + EntityMention( + company_name="Tesla", + confidence=0.85, + context_snippet="Tesla news", + article_id="article_3", + event_type=EventCategory.OTHER, + sentiment=0.2, + ), + EntityMention( + company_name="Tesla Inc", + confidence=0.85, + context_snippet="More Tesla news", + article_id="article_4", + event_type=EventCategory.OTHER, + sentiment=0.2, + ), + ] + + with patch( + "tradingagents.agents.discovery.scorer.resolve_ticker" + ) as mock_resolve: + + def resolve_side_effect(name): + if "Apple" in name: + return "AAPL" + if "Tesla" in name: + return "TSLA" + return None + + mock_resolve.side_effect = resolve_side_effect + + with patch( + "tradingagents.agents.discovery.scorer.classify_sector" + ) as mock_sector: + mock_sector.return_value = "technology" + + result = calculate_trending_scores(mentions, articles, min_mentions=2) + + assert len(result) == 2 + for i in range(len(result) - 1): + assert result[i].score >= result[i + 1].score diff --git a/tests/discovery/test_sector_classifier.py b/tests/discovery/test_sector_classifier.py new file mode 100644 index 00000000..15458e58 --- /dev/null +++ b/tests/discovery/test_sector_classifier.py @@ -0,0 +1,94 @@ +import pytest +from unittest.mock import patch, MagicMock +from tradingagents.dataflows.trending.sector_classifier import ( + classify_sector, + TICKER_TO_SECTOR, + VALID_SECTORS, + _llm_classify_sector, + _sector_cache, +) + + +class TestStaticSectorMapping: + def test_static_sector_mapping_for_known_technology_tickers(self): + assert classify_sector("AAPL") == "technology" + assert classify_sector("MSFT") == "technology" + assert classify_sector("GOOGL") == "technology" + assert classify_sector("NVDA") == "technology" + + def test_static_sector_mapping_for_known_healthcare_tickers(self): + assert classify_sector("JNJ") == "healthcare" + assert classify_sector("PFE") == "healthcare" + assert classify_sector("UNH") == "healthcare" + + def test_static_sector_mapping_for_known_finance_tickers(self): + assert classify_sector("JPM") == "finance" + assert classify_sector("BAC") == "finance" + assert classify_sector("GS") == "finance" + + def test_static_sector_mapping_for_known_energy_tickers(self): + assert classify_sector("XOM") == "energy" + assert classify_sector("CVX") == "energy" + assert classify_sector("COP") == "energy" + + def test_static_sector_mapping_case_insensitive(self): + assert classify_sector("aapl") == "technology" + assert classify_sector("AAPL") == "technology" + assert classify_sector("Aapl") == "technology" + + +class TestLLMFallback: + @patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector") + def test_llm_fallback_for_unknown_tickers(self, mock_llm_classify): + mock_llm_classify.return_value = "technology" + _sector_cache.clear() + + result = classify_sector("UNKNOWNTICKER123") + + mock_llm_classify.assert_called_once_with("UNKNOWNTICKER123") + assert result == "technology" + + @patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector") + def test_llm_fallback_caches_results(self, mock_llm_classify): + mock_llm_classify.return_value = "healthcare" + _sector_cache.clear() + + result1 = classify_sector("NEWCO123") + result2 = classify_sector("NEWCO123") + + assert mock_llm_classify.call_count == 1 + assert result1 == "healthcare" + assert result2 == "healthcare" + + @patch("tradingagents.dataflows.trending.sector_classifier._llm_classify_sector") + def test_llm_fallback_returns_other_on_error(self, mock_llm_classify): + mock_llm_classify.side_effect = Exception("LLM error") + _sector_cache.clear() + + result = classify_sector("ERRORCO") + + assert result == "other" + + +class TestAllSectorCategories: + def test_all_sector_categories_in_valid_sectors(self): + expected_sectors = { + "technology", + "healthcare", + "finance", + "energy", + "consumer_goods", + "industrials", + "other", + } + assert VALID_SECTORS == expected_sectors + + def test_static_mapping_covers_all_sector_categories(self): + sectors_in_mapping = set(TICKER_TO_SECTOR.values()) + assert sectors_in_mapping.issubset(VALID_SECTORS) + + def test_classify_sector_always_returns_valid_sector(self): + test_tickers = ["AAPL", "JPM", "XOM", "JNJ", "WMT", "CAT"] + for ticker in test_tickers: + result = classify_sector(ticker) + assert result in VALID_SECTORS diff --git a/tests/discovery/test_stock_resolver.py b/tests/discovery/test_stock_resolver.py new file mode 100644 index 00000000..96f5b455 --- /dev/null +++ b/tests/discovery/test_stock_resolver.py @@ -0,0 +1,135 @@ +import pytest +import logging +from unittest.mock import patch, MagicMock +from tradingagents.dataflows.trending.stock_resolver import ( + resolve_ticker, + validate_us_ticker, + _normalize_company_name, + _search_yfinance_ticker, +) + + +class TestStaticLookup: + def test_static_lookup_for_known_companies(self): + assert resolve_ticker("Apple") == "AAPL" + assert resolve_ticker("Microsoft") == "MSFT" + assert resolve_ticker("Google") == "GOOGL" + assert resolve_ticker("Amazon") == "AMZN" + assert resolve_ticker("Tesla") == "TSLA" + assert resolve_ticker("Nvidia") == "NVDA" + + def test_static_lookup_case_insensitive(self): + assert resolve_ticker("APPLE") == "AAPL" + assert resolve_ticker("apple") == "AAPL" + assert resolve_ticker("ApPlE") == "AAPL" + assert resolve_ticker("microsoft") == "MSFT" + assert resolve_ticker("MICROSOFT") == "MSFT" + + +class TestNameVariationHandling: + def test_name_variation_handling_with_suffixes(self): + assert resolve_ticker("Apple Inc.") == "AAPL" + assert resolve_ticker("Apple Inc") == "AAPL" + assert resolve_ticker("Apple Corporation") == "AAPL" + assert resolve_ticker("Microsoft Corp.") == "MSFT" + assert resolve_ticker("Microsoft Corp") == "MSFT" + assert resolve_ticker("Tesla Inc") == "TSLA" + + def test_name_variation_handling_informal_names(self): + assert resolve_ticker("the iPhone maker") == "AAPL" + assert resolve_ticker("iPhone maker") == "AAPL" + assert resolve_ticker("the search giant") == "GOOGL" + assert resolve_ticker("the e-commerce giant") == "AMZN" + assert resolve_ticker("EV maker Tesla") == "TSLA" + + def test_name_variation_handling_alternate_names(self): + assert resolve_ticker("Alphabet") == "GOOGL" + assert resolve_ticker("Meta") == "META" + assert resolve_ticker("Facebook") == "META" + assert resolve_ticker("Meta Platforms") == "META" + + +class TestYfinanceFallback: + @patch("tradingagents.dataflows.trending.stock_resolver._search_yfinance_ticker") + @patch("tradingagents.dataflows.trending.stock_resolver.validate_us_ticker") + def test_yfinance_fallback_for_unknown_company(self, mock_validate, mock_search): + mock_search.return_value = "PLTR" + mock_validate.return_value = True + + result = resolve_ticker("UnknownTechStartupXYZ") + + mock_search.assert_called_once() + assert result == "PLTR" + + @patch("tradingagents.dataflows.trending.stock_resolver._search_yfinance_ticker") + def test_yfinance_fallback_returns_none_when_not_found(self, mock_search): + mock_search.return_value = None + + result = resolve_ticker("NonexistentCompanyXYZ123") + + assert result is None + + +class TestUSExchangeValidation: + @patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker") + def test_validate_us_ticker_accepts_nyse(self, mock_ticker): + mock_info = {"exchange": "NYQ"} + mock_ticker.return_value.info = mock_info + + assert validate_us_ticker("IBM") is True + + @patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker") + def test_validate_us_ticker_accepts_nasdaq(self, mock_ticker): + mock_info = {"exchange": "NMS"} + mock_ticker.return_value.info = mock_info + + assert validate_us_ticker("AAPL") is True + + @patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker") + def test_validate_us_ticker_accepts_amex(self, mock_ticker): + mock_info = {"exchange": "ASE"} + mock_ticker.return_value.info = mock_info + + assert validate_us_ticker("SPY") is True + + @patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker") + def test_validate_us_ticker_rejects_international(self, mock_ticker): + mock_info = {"exchange": "LSE"} + mock_ticker.return_value.info = mock_info + + assert validate_us_ticker("VOD.L") is False + + @patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker") + def test_validate_us_ticker_rejects_otc(self, mock_ticker): + mock_info = {"exchange": "PNK"} + mock_ticker.return_value.info = mock_info + + assert validate_us_ticker("OTCPK") is False + + +class TestAmbiguousResolutionLogging: + def test_ambiguous_resolution_logs_multiple_matches(self, caplog): + with caplog.at_level(logging.INFO, logger="tradingagents.dataflows.trending.stock_resolver"): + pass + + @patch("tradingagents.dataflows.trending.stock_resolver._search_yfinance_ticker") + @patch("tradingagents.dataflows.trending.stock_resolver.validate_us_ticker") + def test_yfinance_fallback_is_logged(self, mock_validate, mock_search, caplog): + mock_search.return_value = "RBLX" + mock_validate.return_value = True + + with caplog.at_level(logging.INFO, logger="tradingagents.dataflows.trending.stock_resolver"): + result = resolve_ticker("SomeRandomCompanyNotInMapping") + + assert any("fallback" in record.message.lower() or "yfinance" in record.message.lower() + for record in caplog.records) + + @patch("tradingagents.dataflows.trending.stock_resolver.yf.Ticker") + def test_validation_failure_is_logged(self, mock_ticker, caplog): + mock_info = {"exchange": "LSE"} + mock_ticker.return_value.info = mock_info + + with caplog.at_level(logging.WARNING, logger="tradingagents.dataflows.trending.stock_resolver"): + result = validate_us_ticker("VOD.L") + + assert result is False diff --git a/tradingagents/agents/discovery/__init__.py b/tradingagents/agents/discovery/__init__.py new file mode 100644 index 00000000..0effd7ce --- /dev/null +++ b/tradingagents/agents/discovery/__init__.py @@ -0,0 +1,53 @@ +from .models import ( + NewsArticle, + TrendingStock, + DiscoveryRequest, + DiscoveryResult, + DiscoveryStatus, + Sector, + EventCategory, +) +from .exceptions import ( + DiscoveryError, + NewsUnavailableError, + DiscoveryTimeoutError, + TickerResolutionError, +) +from .entity_extractor import ( + EntityMention, + extract_entities, + BATCH_SIZE, +) +from .scorer import ( + calculate_trending_scores, + DEFAULT_DECAY_RATE, + DEFAULT_MAX_RESULTS, + DEFAULT_MIN_MENTIONS, +) +from .persistence import ( + save_discovery_result, + generate_markdown_summary, +) + +__all__ = [ + "NewsArticle", + "TrendingStock", + "DiscoveryRequest", + "DiscoveryResult", + "DiscoveryStatus", + "Sector", + "EventCategory", + "DiscoveryError", + "NewsUnavailableError", + "DiscoveryTimeoutError", + "TickerResolutionError", + "EntityMention", + "extract_entities", + "BATCH_SIZE", + "calculate_trending_scores", + "DEFAULT_DECAY_RATE", + "DEFAULT_MAX_RESULTS", + "DEFAULT_MIN_MENTIONS", + "save_discovery_result", + "generate_markdown_summary", +] diff --git a/tradingagents/agents/discovery/entity_extractor.py b/tradingagents/agents/discovery/entity_extractor.py new file mode 100644 index 00000000..5bad3242 --- /dev/null +++ b/tradingagents/agents/discovery/entity_extractor.py @@ -0,0 +1,159 @@ +from dataclasses import dataclass, field +from typing import List, Optional +from pydantic import BaseModel, Field as PydanticField + +from langchain_openai import ChatOpenAI +from langchain_anthropic import ChatAnthropic +from langchain_google_genai import ChatGoogleGenerativeAI + +from tradingagents.default_config import DEFAULT_CONFIG +from tradingagents.agents.discovery.models import NewsArticle, EventCategory + + +BATCH_SIZE = 10 + + +@dataclass +class EntityMention: + company_name: str + confidence: float + context_snippet: str + article_id: str + event_type: EventCategory + sentiment: float = field(default=0.0) + + +class ExtractedEntity(BaseModel): + company_name: str = PydanticField(description="The name of the publicly traded company mentioned") + confidence: float = PydanticField(description="Confidence score from 0.0 to 1.0 based on mention clarity") + context_snippet: str = PydanticField(description="Surrounding context of 50-100 characters around the company mention") + event_type: str = PydanticField(description="Event category: earnings, merger_acquisition, regulatory, product_launch, executive_change, or other") + sentiment: float = PydanticField(default=0.0, description="Sentiment score from -1.0 (negative) to 1.0 (positive)") + + +class ExtractionResponse(BaseModel): + entities: List[ExtractedEntity] = PydanticField(default_factory=list, description="List of extracted company entities") + + +def _get_llm(config: Optional[dict] = None): + cfg = config or DEFAULT_CONFIG + provider = cfg.get("llm_provider", "openai").lower() + model = cfg.get("quick_think_llm", "gpt-4o-mini") + backend_url = cfg.get("backend_url", "https://api.openai.com/v1") + + if provider in ("openai", "ollama", "openrouter"): + return ChatOpenAI(model=model, base_url=backend_url) + elif provider == "anthropic": + return ChatAnthropic(model=model, base_url=backend_url) + elif provider == "google": + return ChatGoogleGenerativeAI(model=model) + else: + raise ValueError(f"Unsupported LLM provider: {provider}") + + +EXTRACTION_PROMPT = """You are an expert at identifying publicly traded companies mentioned in news articles. + +For each article provided, extract all mentions of publicly traded companies. For each company mention: + +1. Extract the company name as it appears (e.g., "Apple Inc.", "Apple", "AAPL", "the iPhone maker") +2. Assign a confidence score from 0.0 to 1.0 based on how clearly the company is mentioned: + - 0.9-1.0: Direct company name or ticker symbol + - 0.7-0.9: Clear reference with context (e.g., "the Cupertino tech giant") + - 0.5-0.7: Indirect reference requiring inference + - Below 0.5: Uncertain or ambiguous reference +3. Extract 50-100 characters of surrounding context +4. Classify the event type: + - earnings: Quarterly/annual earnings reports, revenue announcements + - merger_acquisition: Mergers, acquisitions, buyouts, takeovers + - regulatory: SEC filings, government investigations, compliance issues + - product_launch: New products, services, or features + - executive_change: CEO/CFO changes, board appointments, departures + - other: Any other business news +5. Assign a sentiment score from -1.0 to 1.0: + - -1.0: Very negative news (lawsuits, crashes, major failures) + - -0.5: Moderately negative news + - 0.0: Neutral news + - 0.5: Moderately positive news + - 1.0: Very positive news (breakthroughs, record earnings) + +Only extract companies that are publicly traded on major stock exchanges. +Handle name variations by providing the most complete company name found. + +Articles to analyze: +{articles_text} + +Extract all company mentions from the articles above.""" + + +def _format_articles_for_prompt(articles: List[NewsArticle], start_idx: int) -> str: + formatted = [] + for i, article in enumerate(articles): + article_id = f"article_{start_idx + i}" + formatted.append( + f"[{article_id}]\n" + f"Title: {article.title}\n" + f"Source: {article.source}\n" + f"Content: {article.content_snippet}\n" + ) + return "\n---\n".join(formatted) + + +def _extract_batch( + articles: List[NewsArticle], + start_idx: int, + llm, +) -> List[EntityMention]: + if not articles: + return [] + + articles_text = _format_articles_for_prompt(articles, start_idx) + prompt = EXTRACTION_PROMPT.format(articles_text=articles_text) + + structured_llm = llm.with_structured_output(ExtractionResponse) + response = structured_llm.invoke(prompt) + + mentions = [] + for entity in response.entities: + event_type_str = entity.event_type.lower().strip() + valid_event_types = {e.value for e in EventCategory} + if event_type_str not in valid_event_types: + event_type_str = "other" + + confidence = max(0.0, min(1.0, entity.confidence)) + sentiment = max(-1.0, min(1.0, entity.sentiment)) + + context = entity.context_snippet + if len(context) > 150: + context = context[:147] + "..." + + mention = EntityMention( + company_name=entity.company_name, + confidence=confidence, + context_snippet=context, + article_id=f"article_{start_idx}", + event_type=EventCategory(event_type_str), + sentiment=sentiment, + ) + mentions.append(mention) + + return mentions + + +def extract_entities( + articles: List[NewsArticle], + config: Optional[dict] = None, +) -> List[EntityMention]: + if not articles: + return [] + + llm = _get_llm(config) + all_mentions: List[EntityMention] = [] + + for batch_start in range(0, len(articles), BATCH_SIZE): + batch_end = min(batch_start + BATCH_SIZE, len(articles)) + batch = articles[batch_start:batch_end] + + batch_mentions = _extract_batch(batch, batch_start, llm) + all_mentions.extend(batch_mentions) + + return all_mentions diff --git a/tradingagents/agents/discovery/exceptions.py b/tradingagents/agents/discovery/exceptions.py new file mode 100644 index 00000000..94953c62 --- /dev/null +++ b/tradingagents/agents/discovery/exceptions.py @@ -0,0 +1,14 @@ +class DiscoveryError(Exception): + pass + + +class NewsUnavailableError(DiscoveryError): + pass + + +class DiscoveryTimeoutError(DiscoveryError): + pass + + +class TickerResolutionError(DiscoveryError): + pass diff --git a/tradingagents/agents/discovery/models.py b/tradingagents/agents/discovery/models.py new file mode 100644 index 00000000..9595f89d --- /dev/null +++ b/tradingagents/agents/discovery/models.py @@ -0,0 +1,180 @@ +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import List, Optional, Dict, Any + + +class DiscoveryStatus(Enum): + CREATED = "created" + PROCESSING = "processing" + COMPLETED = "completed" + FAILED = "failed" + + +class Sector(Enum): + TECHNOLOGY = "technology" + HEALTHCARE = "healthcare" + FINANCE = "finance" + ENERGY = "energy" + CONSUMER_GOODS = "consumer_goods" + INDUSTRIALS = "industrials" + OTHER = "other" + + +class EventCategory(Enum): + EARNINGS = "earnings" + MERGER_ACQUISITION = "merger_acquisition" + REGULATORY = "regulatory" + PRODUCT_LAUNCH = "product_launch" + EXECUTIVE_CHANGE = "executive_change" + OTHER = "other" + + +@dataclass +class NewsArticle: + title: str + source: str + url: str + published_at: datetime + content_snippet: str + ticker_mentions: List[str] + + def to_dict(self) -> Dict[str, Any]: + return { + "title": self.title, + "source": self.source, + "url": self.url, + "published_at": self.published_at.isoformat(), + "content_snippet": self.content_snippet, + "ticker_mentions": self.ticker_mentions, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "NewsArticle": + return cls( + title=data["title"], + source=data["source"], + url=data["url"], + published_at=datetime.fromisoformat(data["published_at"]), + content_snippet=data["content_snippet"], + ticker_mentions=data["ticker_mentions"], + ) + + +@dataclass +class TrendingStock: + ticker: str + company_name: str + score: float + mention_count: int + sentiment: float + sector: Sector + event_type: EventCategory + news_summary: str + source_articles: List[NewsArticle] + + def to_dict(self) -> Dict[str, Any]: + return { + "ticker": self.ticker, + "company_name": self.company_name, + "score": self.score, + "mention_count": self.mention_count, + "sentiment": self.sentiment, + "sector": self.sector.value, + "event_type": self.event_type.value, + "news_summary": self.news_summary, + "source_articles": [article.to_dict() for article in self.source_articles], + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TrendingStock": + return cls( + ticker=data["ticker"], + company_name=data["company_name"], + score=data["score"], + mention_count=data["mention_count"], + sentiment=data["sentiment"], + sector=Sector(data["sector"]), + event_type=EventCategory(data["event_type"]), + news_summary=data["news_summary"], + source_articles=[ + NewsArticle.from_dict(article) for article in data["source_articles"] + ], + ) + + +@dataclass +class DiscoveryRequest: + lookback_period: str + sector_filter: Optional[List[Sector]] = None + event_filter: Optional[List[EventCategory]] = None + max_results: int = 20 + created_at: datetime = field(default_factory=datetime.now) + + def to_dict(self) -> Dict[str, Any]: + return { + "lookback_period": self.lookback_period, + "sector_filter": ( + [s.value for s in self.sector_filter] if self.sector_filter else None + ), + "event_filter": ( + [e.value for e in self.event_filter] if self.event_filter else None + ), + "max_results": self.max_results, + "created_at": self.created_at.isoformat(), + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DiscoveryRequest": + return cls( + lookback_period=data["lookback_period"], + sector_filter=( + [Sector(s) for s in data["sector_filter"]] + if data.get("sector_filter") + else None + ), + event_filter=( + [EventCategory(e) for e in data["event_filter"]] + if data.get("event_filter") + else None + ), + max_results=data.get("max_results", 20), + created_at=datetime.fromisoformat(data["created_at"]), + ) + + +@dataclass +class DiscoveryResult: + request: DiscoveryRequest + trending_stocks: List[TrendingStock] + status: DiscoveryStatus + started_at: datetime + completed_at: Optional[datetime] = None + error_message: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "request": self.request.to_dict(), + "trending_stocks": [stock.to_dict() for stock in self.trending_stocks], + "status": self.status.value, + "started_at": self.started_at.isoformat(), + "completed_at": self.completed_at.isoformat() if self.completed_at else None, + "error_message": self.error_message, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DiscoveryResult": + return cls( + request=DiscoveryRequest.from_dict(data["request"]), + trending_stocks=[ + TrendingStock.from_dict(stock) for stock in data["trending_stocks"] + ], + status=DiscoveryStatus(data["status"]), + started_at=datetime.fromisoformat(data["started_at"]), + completed_at=( + datetime.fromisoformat(data["completed_at"]) + if data.get("completed_at") + else None + ), + error_message=data.get("error_message"), + ) diff --git a/tradingagents/agents/discovery/persistence.py b/tradingagents/agents/discovery/persistence.py new file mode 100644 index 00000000..68f2bd03 --- /dev/null +++ b/tradingagents/agents/discovery/persistence.py @@ -0,0 +1,120 @@ +import json +from datetime import datetime +from pathlib import Path +from typing import Optional + +from .models import DiscoveryResult, TrendingStock + + +def save_discovery_result( + result: DiscoveryResult, + base_path: Optional[Path] = None, +) -> Path: + if base_path is None: + base_path = Path("results") + + timestamp = result.completed_at or result.started_at + date_str = timestamp.strftime("%Y-%m-%d") + time_str = timestamp.strftime("%H-%M-%S") + + result_dir = base_path / "discovery" / date_str / time_str + result_dir.mkdir(parents=True, exist_ok=True) + + json_path = result_dir / "discovery_result.json" + with open(json_path, "w") as f: + json.dump(result.to_dict(), f, indent=2) + + md_path = result_dir / "discovery_summary.md" + markdown_content = generate_markdown_summary(result) + with open(md_path, "w") as f: + f.write(markdown_content) + + return result_dir + + +def generate_markdown_summary(result: DiscoveryResult) -> str: + lines = [] + + lines.append("# Discovery Results") + lines.append("") + + timestamp = result.completed_at or result.started_at + lines.append(f"**Timestamp:** {timestamp.strftime('%Y-%m-%d %H:%M:%S')}") + lines.append(f"**Lookback Period:** {result.request.lookback_period}") + + filters = _format_filters(result) + lines.append(f"**Filters:** {filters}") + lines.append(f"**Total Stocks Found:** {len(result.trending_stocks)}") + lines.append("") + + lines.append("## Trending Stocks") + lines.append("") + lines.append("| Rank | Ticker | Company | Score | Mentions | Event |") + lines.append("|------|--------|---------|-------|----------|-------|") + + for rank, stock in enumerate(result.trending_stocks, 1): + lines.append( + f"| {rank} | {stock.ticker} | {stock.company_name} | " + f"{stock.score:.2f} | {stock.mention_count} | {stock.event_type.value} |" + ) + + lines.append("") + + lines.append("## Top 3 Detailed Analysis") + lines.append("") + + top_stocks = result.trending_stocks[:3] + for rank, stock in enumerate(top_stocks, 1): + lines.extend(_format_stock_detail(rank, stock)) + + return "\n".join(lines) + + +def _format_filters(result: DiscoveryResult) -> str: + filter_parts = [] + + if result.request.sector_filter: + sector_values = [s.value for s in result.request.sector_filter] + filter_parts.append(f"sector={','.join(sector_values)}") + + if result.request.event_filter: + event_values = [e.value for e in result.request.event_filter] + filter_parts.append(f"event={','.join(event_values)}") + + if filter_parts: + return " ".join(filter_parts) + return "None" + + +def _format_stock_detail(rank: int, stock: TrendingStock) -> list: + lines = [] + + lines.append(f"### {rank}. {stock.ticker} - {stock.company_name}") + lines.append(f"- **Score:** {stock.score:.2f}") + + sentiment_label = _get_sentiment_label(stock.sentiment) + lines.append(f"- **Sentiment:** {stock.sentiment:.2f} ({sentiment_label})") + lines.append(f"- **Sector:** {stock.sector.value}") + lines.append(f"- **Event Type:** {stock.event_type.value}") + lines.append(f"- **Mentions:** {stock.mention_count}") + lines.append("") + + lines.append("**News Summary:**") + lines.append(stock.news_summary) + lines.append("") + + if stock.source_articles: + lines.append("**Top Sources:**") + for article in stock.source_articles[:3]: + lines.append(f"- [{article.title}] - {article.source}") + lines.append("") + + return lines + + +def _get_sentiment_label(sentiment: float) -> str: + if sentiment > 0.3: + return "positive" + elif sentiment < -0.3: + return "negative" + return "neutral" diff --git a/tradingagents/agents/discovery/scorer.py b/tradingagents/agents/discovery/scorer.py new file mode 100644 index 00000000..564bc717 --- /dev/null +++ b/tradingagents/agents/discovery/scorer.py @@ -0,0 +1,153 @@ +import math +from collections import defaultdict +from datetime import datetime +from typing import List, Dict, Optional + +from tradingagents.agents.discovery.models import ( + TrendingStock, + NewsArticle, + Sector, + EventCategory, +) +from tradingagents.agents.discovery.entity_extractor import EntityMention +from tradingagents.dataflows.trending.stock_resolver import resolve_ticker +from tradingagents.dataflows.trending.sector_classifier import classify_sector + + +DEFAULT_DECAY_RATE = 0.1 +DEFAULT_MAX_RESULTS = 20 +DEFAULT_MIN_MENTIONS = 2 + + +def _aggregate_sentiment(mentions: List[EntityMention]) -> float: + if not mentions: + return 0.0 + + total_weighted_sentiment = 0.0 + total_confidence = 0.0 + + for mention in mentions: + total_weighted_sentiment += mention.sentiment * mention.confidence + total_confidence += mention.confidence + + if total_confidence == 0: + return 0.0 + + return total_weighted_sentiment / total_confidence + + +def _calculate_recency_weight( + articles: List[NewsArticle], + article_ids: set, + decay_rate: float, +) -> float: + if not articles: + return 1.0 + + now = datetime.now() + weights = [] + + for i, article in enumerate(articles): + article_id = f"article_{i}" + if article_id in article_ids: + hours_old = (now - article.published_at).total_seconds() / 3600.0 + weight = math.exp(-decay_rate * hours_old) + weights.append(weight) + + if not weights: + return 1.0 + + return sum(weights) / len(weights) + + +def _get_most_common_event_type(mentions: List[EntityMention]) -> EventCategory: + if not mentions: + return EventCategory.OTHER + + event_counts: Dict[EventCategory, int] = defaultdict(int) + for mention in mentions: + event_counts[mention.event_type] += 1 + + return max(event_counts.keys(), key=lambda e: event_counts[e]) + + +def _build_news_summary(mentions: List[EntityMention]) -> str: + if not mentions: + return "" + + snippets = [m.context_snippet for m in mentions[:3]] + return " ".join(snippets) + + +def calculate_trending_scores( + mentions: List[EntityMention], + articles: List[NewsArticle], + decay_rate: float = DEFAULT_DECAY_RATE, + max_results: int = DEFAULT_MAX_RESULTS, + min_mentions: int = DEFAULT_MIN_MENTIONS, +) -> List[TrendingStock]: + if not mentions: + return [] + + ticker_mentions: Dict[str, List[EntityMention]] = defaultdict(list) + ticker_company_names: Dict[str, str] = {} + + for mention in mentions: + ticker = resolve_ticker(mention.company_name) + if ticker: + ticker_mentions[ticker].append(mention) + if ticker not in ticker_company_names: + ticker_company_names[ticker] = mention.company_name + + article_index: Dict[str, int] = {} + for i, article in enumerate(articles): + article_index[f"article_{i}"] = i + + trending_stocks: List[TrendingStock] = [] + + for ticker, ticker_mention_list in ticker_mentions.items(): + article_ids = {m.article_id for m in ticker_mention_list} + frequency = len(article_ids) + + if frequency < min_mentions: + continue + + sentiment = _aggregate_sentiment(ticker_mention_list) + sentiment_factor = 1 + abs(sentiment) + + recency_weight = _calculate_recency_weight(articles, article_ids, decay_rate) + + score = frequency * sentiment_factor * recency_weight + + sector_str = classify_sector(ticker) + try: + sector = Sector(sector_str) + except ValueError: + sector = Sector.OTHER + + event_type = _get_most_common_event_type(ticker_mention_list) + + source_article_list: List[NewsArticle] = [] + for article_id in article_ids: + idx = article_index.get(article_id) + if idx is not None and idx < len(articles): + source_article_list.append(articles[idx]) + + news_summary = _build_news_summary(ticker_mention_list) + + trending_stock = TrendingStock( + ticker=ticker, + company_name=ticker_company_names.get(ticker, ticker), + score=score, + mention_count=frequency, + sentiment=sentiment, + sector=sector, + event_type=event_type, + news_summary=news_summary, + source_articles=source_article_list, + ) + trending_stocks.append(trending_stock) + + trending_stocks.sort(key=lambda s: s.score, reverse=True) + + return trending_stocks[:max_results] diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index 3a859ea1..4a1ce0ce 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -7,44 +7,44 @@ from langgraph.prebuilt import ToolNode from langgraph.graph import END, StateGraph, START, MessagesState -# Researcher team state class InvestDebateState(TypedDict): + """Researcher team state""" bull_history: Annotated[ str, "Bullish Conversation history" - ] # Bullish Conversation history + ] bear_history: Annotated[ str, "Bearish Conversation history" - ] # Bullish Conversation history - history: Annotated[str, "Conversation history"] # Conversation history - current_response: Annotated[str, "Latest response"] # Last response - judge_decision: Annotated[str, "Final judge decision"] # Last response - count: Annotated[int, "Length of the current conversation"] # Conversation length + ] + history: Annotated[str, "Conversation history"] + current_response: Annotated[str, "Latest response"] + judge_decision: Annotated[str, "Final judge decision"] + count: Annotated[int, "Length of the current conversation"] -# Risk management team state class RiskDebateState(TypedDict): + """Risk management team state""" risky_history: Annotated[ str, "Risky Agent's Conversation history" - ] # Conversation history + ] safe_history: Annotated[ str, "Safe Agent's Conversation history" - ] # Conversation history + ] neutral_history: Annotated[ str, "Neutral Agent's Conversation history" - ] # Conversation history - history: Annotated[str, "Conversation history"] # Conversation history + ] + history: Annotated[str, "Conversation history"] latest_speaker: Annotated[str, "Analyst that spoke last"] current_risky_response: Annotated[ str, "Latest response by the risky analyst" - ] # Last response + ] current_safe_response: Annotated[ str, "Latest response by the safe analyst" - ] # Last response + ] current_neutral_response: Annotated[ str, "Latest response by the neutral analyst" - ] # Last response + ] judge_decision: Annotated[str, "Judge's decision"] - count: Annotated[int, "Length of the current conversation"] # Conversation length + count: Annotated[int, "Length of the current conversation"] class AgentState(MessagesState): @@ -53,7 +53,7 @@ class AgentState(MessagesState): sender: Annotated[str, "Agent that sent this message"] - # research step + # research market_report: Annotated[str, "Report from the Market Analyst"] sentiment_report: Annotated[str, "Report from the Social Media Analyst"] news_report: Annotated[ @@ -61,7 +61,7 @@ class AgentState(MessagesState): ] fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"] - # researcher team discussion step + # research investment_debate_state: Annotated[ InvestDebateState, "Current state of the debate on if to invest or not" ] @@ -69,7 +69,7 @@ class AgentState(MessagesState): trader_investment_plan: Annotated[str, "Plan generated by the Trader"] - # risk management team discussion step + # risk mgmt risk_debate_state: Annotated[ RiskDebateState, "Current state of the debate on evaluating risk" ] diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index 6cf294a1..6f01dc32 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -1,6 +1,5 @@ from langchain_core.messages import HumanMessage, RemoveMessage -# Import tools from separate utility files from tradingagents.agents.utils.core_stock_tools import ( get_stock_data ) @@ -24,16 +23,7 @@ def create_msg_delete(): def delete_messages(state): """Clear messages and add placeholder for Anthropic compatibility""" messages = state["messages"] - - # Remove all messages removal_operations = [RemoveMessage(id=m.id) for m in messages] - - # Add a minimal placeholder message placeholder = HumanMessage(content="Continue") - return {"messages": removal_operations + [placeholder]} - return delete_messages - - - \ No newline at end of file diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 69b8ab8c..9146313e 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -67,47 +67,45 @@ class FinancialSituationMemory: return matched_results -if __name__ == "__main__": - # Example usage - matcher = FinancialSituationMemory() +# if __name__ == "__main__": +# # Example usage +# matcher = FinancialSituationMemory() +# example_data = [ +# ( +# "High inflation rate with rising interest rates and declining consumer spending", +# "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.", +# ), +# ( +# "Tech sector showing high volatility with increasing institutional selling pressure", +# "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.", +# ), +# ( +# "Strong dollar affecting emerging markets with increasing forex volatility", +# "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.", +# ), +# ( +# "Market showing signs of sector rotation with rising yields", +# "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.", +# ), +# ] - # Example data - example_data = [ - ( - "High inflation rate with rising interest rates and declining consumer spending", - "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.", - ), - ( - "Tech sector showing high volatility with increasing institutional selling pressure", - "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.", - ), - ( - "Strong dollar affecting emerging markets with increasing forex volatility", - "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.", - ), - ( - "Market showing signs of sector rotation with rising yields", - "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.", - ), - ] +# # Add the example situations and recommendations +# matcher.add_situations(example_data) - # Add the example situations and recommendations - matcher.add_situations(example_data) +# # Example query +# current_situation = """ +# Market showing increased volatility in tech sector, with institutional investors +# reducing positions and rising interest rates affecting growth stock valuations +# """ - # Example query - current_situation = """ - Market showing increased volatility in tech sector, with institutional investors - reducing positions and rising interest rates affecting growth stock valuations - """ +# try: +# recommendations = matcher.get_memories(current_situation, n_matches=2) - try: - recommendations = matcher.get_memories(current_situation, n_matches=2) +# for i, rec in enumerate(recommendations, 1): +# print(f"\nMatch {i}:") +# print(f"Similarity Score: {rec['similarity_score']:.2f}") +# print(f"Matched Situation: {rec['matched_situation']}") +# print(f"Recommendation: {rec['recommendation']}") - for i, rec in enumerate(recommendations, 1): - print(f"\nMatch {i}:") - print(f"Similarity Score: {rec['similarity_score']:.2f}") - print(f"Matched Situation: {rec['matched_situation']}") - print(f"Recommendation: {rec['recommendation']}") - - except Exception as e: - print(f"Error during recommendation: {str(e)}") +# except Exception as e: +# print(f"Error during recommendation: {str(e)}") diff --git a/tradingagents/dataflows/alpha_vantage_news.py b/tradingagents/dataflows/alpha_vantage_news.py index 8124fb45..984bd4e0 100644 --- a/tradingagents/dataflows/alpha_vantage_news.py +++ b/tradingagents/dataflows/alpha_vantage_news.py @@ -1,19 +1,9 @@ +import json +from datetime import datetime, timedelta +from typing import List, Dict, Any from .alpha_vantage_common import _make_api_request, format_datetime_for_api def get_news(ticker, start_date, end_date) -> dict[str, str] | str: - """Returns live and historical market news & sentiment data from premier news outlets worldwide. - - Covers stocks, cryptocurrencies, forex, and topics like fiscal policy, mergers & acquisitions, IPOs. - - Args: - ticker: Stock symbol for news articles. - start_date: Start date for news search. - end_date: End date for news search. - - Returns: - Dictionary containing news sentiment data or JSON string. - """ - params = { "tickers": ticker, "time_from": format_datetime_for_api(start_date), @@ -21,23 +11,63 @@ def get_news(ticker, start_date, end_date) -> dict[str, str] | str: "sort": "LATEST", "limit": "50", } - + return _make_api_request("NEWS_SENTIMENT", params) def get_insider_transactions(symbol: str) -> dict[str, str] | str: - """Returns latest and historical insider transactions by key stakeholders. - - Covers transactions by founders, executives, board members, etc. - - Args: - symbol: Ticker symbol. Example: "IBM". - - Returns: - Dictionary containing insider transaction data or JSON string. - """ - params = { "symbol": symbol, } - return _make_api_request("INSIDER_TRANSACTIONS", params) \ No newline at end of file + return _make_api_request("INSIDER_TRANSACTIONS", params) + + +def get_bulk_news_alpha_vantage(lookback_hours: int) -> List[Dict[str, Any]]: + end_date = datetime.now() + start_date = end_date - timedelta(hours=lookback_hours) + + params = { + "time_from": format_datetime_for_api(start_date), + "time_to": format_datetime_for_api(end_date), + "sort": "LATEST", + "limit": "200", + "topics": "financial_markets,earnings,economy_fiscal,economy_monetary,mergers_and_acquisitions", + } + + response = _make_api_request("NEWS_SENTIMENT", params) + + if isinstance(response, str): + try: + response = json.loads(response) + except json.JSONDecodeError: + return [] + + if not isinstance(response, dict): + return [] + + feed = response.get("feed", []) + + articles = [] + for item in feed: + try: + time_published = item.get("time_published", "") + if time_published: + try: + published_at = datetime.strptime(time_published, "%Y%m%dT%H%M%S") + except ValueError: + published_at = datetime.now() + else: + published_at = datetime.now() + + article = { + "title": item.get("title", ""), + "source": item.get("source", ""), + "url": item.get("url", ""), + "published_at": published_at.isoformat(), + "content_snippet": item.get("summary", "")[:500], + } + articles.append(article) + except Exception: + continue + + return articles diff --git a/tradingagents/dataflows/google.py b/tradingagents/dataflows/google.py index 3fe20f3c..80037ed4 100644 --- a/tradingagents/dataflows/google.py +++ b/tradingagents/dataflows/google.py @@ -1,5 +1,5 @@ -from typing import Annotated -from datetime import datetime +from typing import Annotated, List, Dict, Any +from datetime import datetime, timedelta from dateutil.relativedelta import relativedelta from .googlenews_utils import getNewsData @@ -27,4 +27,53 @@ def get_google_news( if len(news_results) == 0: return "" - return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}" \ No newline at end of file + return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}" + + +def get_bulk_news_google(lookback_hours: int) -> List[Dict[str, Any]]: + end_date = datetime.now() + start_date = end_date - timedelta(hours=lookback_hours) + + start_str = start_date.strftime("%Y-%m-%d") + end_str = end_date.strftime("%Y-%m-%d") + + queries = [ + "stock market", + "trading news", + "earnings report", + ] + + all_articles = [] + seen_titles = set() + + for query in queries: + try: + news_results = getNewsData(query.replace(" ", "+"), start_str, end_str) + + for news in news_results: + title = news.get("title", "") + if title and title not in seen_titles: + seen_titles.add(title) + + date_str = news.get("date", "") + try: + if date_str: + published_at = datetime.now() + else: + published_at = datetime.now() + except ValueError: + published_at = datetime.now() + + article = { + "title": title, + "source": news.get("source", "Google News"), + "url": news.get("link", ""), + "published_at": published_at.isoformat(), + "content_snippet": news.get("snippet", "")[:500], + } + all_articles.append(article) + + except Exception: + continue + + return all_articles diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 4cd5ddef..c49909fa 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -1,10 +1,10 @@ -from typing import Annotated +from typing import Annotated, List, Dict, Any, Optional +from datetime import datetime, timedelta -# Import from vendor-specific modules from .local import get_YFin_data, get_finnhub_news, get_finnhub_company_insider_sentiment, get_finnhub_company_insider_transactions, get_simfin_balance_sheet, get_simfin_cashflow, get_simfin_income_statements, get_reddit_global_news, get_reddit_company_news from .y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions -from .google import get_google_news -from .openai import get_stock_news_openai, get_global_news_openai, get_fundamentals_openai +from .google import get_google_news, get_bulk_news_google +from .openai import get_stock_news_openai, get_global_news_openai, get_fundamentals_openai, get_bulk_news_openai from .alpha_vantage import ( get_stock as get_alpha_vantage_stock, get_indicator as get_alpha_vantage_indicator, @@ -15,12 +15,13 @@ from .alpha_vantage import ( get_insider_transactions as get_alpha_vantage_insider_transactions, get_news as get_alpha_vantage_news ) +from .alpha_vantage_news import get_bulk_news_alpha_vantage from .alpha_vantage_common import AlphaVantageRateLimitError -# Configuration and routing logic from .config import get_config -# Tools organized by category +from tradingagents.agents.discovery import NewsArticle + TOOLS_CATEGORIES = { "core_stock_apis": { "description": "OHLCV stock price data", @@ -50,6 +51,7 @@ TOOLS_CATEGORIES = { "get_global_news", "get_insider_sentiment", "get_insider_transactions", + "get_bulk_news", ] } } @@ -61,21 +63,17 @@ VENDOR_LIST = [ "google" ] -# Mapping of methods to their vendor-specific implementations VENDOR_METHODS = { - # core_stock_apis "get_stock_data": { "alpha_vantage": get_alpha_vantage_stock, "yfinance": get_YFin_data_online, "local": get_YFin_data, }, - # technical_indicators "get_indicators": { "alpha_vantage": get_alpha_vantage_indicator, "yfinance": get_stock_stats_indicators_window, "local": get_stock_stats_indicators_window }, - # fundamental_data "get_fundamentals": { "alpha_vantage": get_alpha_vantage_fundamentals, "openai": get_fundamentals_openai, @@ -95,7 +93,6 @@ VENDOR_METHODS = { "yfinance": get_yfinance_income_statement, "local": get_simfin_income_statements, }, - # news_data "get_news": { "alpha_vantage": get_alpha_vantage_news, "openai": get_stock_news_openai, @@ -114,56 +111,159 @@ VENDOR_METHODS = { "yfinance": get_yfinance_insider_transactions, "local": get_finnhub_company_insider_transactions, }, + "get_bulk_news": { + "alpha_vantage": get_bulk_news_alpha_vantage, + "openai": get_bulk_news_openai, + "google": get_bulk_news_google, + }, } +CACHE_TTL_SECONDS = 300 + +_bulk_news_cache: Dict[str, Dict[str, Any]] = {} + + +def parse_lookback_period(lookback: str) -> int: + lookback = lookback.lower().strip() + + if lookback == "1h": + return 1 + elif lookback == "6h": + return 6 + elif lookback == "24h": + return 24 + elif lookback == "7d": + return 168 + else: + raise ValueError(f"Invalid lookback period: {lookback}. Valid values: 1h, 6h, 24h, 7d") + + +def _get_cached_bulk_news(lookback_period: str) -> Optional[List[NewsArticle]]: + cache_key = lookback_period + if cache_key in _bulk_news_cache: + cached = _bulk_news_cache[cache_key] + cached_time = cached.get("timestamp") + if cached_time and (datetime.now() - cached_time).total_seconds() < CACHE_TTL_SECONDS: + return cached.get("articles") + return None + + +def _set_cached_bulk_news(lookback_period: str, articles: List[NewsArticle]) -> None: + cache_key = lookback_period + _bulk_news_cache[cache_key] = { + "timestamp": datetime.now(), + "articles": articles, + } + + +def _convert_to_news_articles(raw_articles: List[Dict[str, Any]]) -> List[NewsArticle]: + articles = [] + for item in raw_articles: + try: + published_at_str = item.get("published_at", "") + if isinstance(published_at_str, str): + try: + published_at = datetime.fromisoformat(published_at_str.replace("Z", "+00:00")) + except ValueError: + published_at = datetime.now() + elif isinstance(published_at_str, datetime): + published_at = published_at_str + else: + published_at = datetime.now() + + article = NewsArticle( + title=item.get("title", ""), + source=item.get("source", ""), + url=item.get("url", ""), + published_at=published_at, + content_snippet=item.get("content_snippet", ""), + ticker_mentions=[], + ) + articles.append(article) + except Exception: + continue + return articles + + +def _fetch_bulk_news_from_vendor(lookback_period: str) -> List[Dict[str, Any]]: + lookback_hours = parse_lookback_period(lookback_period) + + vendor_order = ["alpha_vantage", "openai", "google"] + + for vendor in vendor_order: + if vendor not in VENDOR_METHODS["get_bulk_news"]: + continue + + vendor_func = VENDOR_METHODS["get_bulk_news"][vendor] + + try: + print(f"DEBUG: Attempting bulk news from vendor '{vendor}'...") + result = vendor_func(lookback_hours) + if result: + print(f"SUCCESS: Got {len(result)} articles from vendor '{vendor}'") + return result + print(f"DEBUG: Vendor '{vendor}' returned empty results, trying next...") + except AlphaVantageRateLimitError as e: + print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded: {e}") + continue + except Exception as e: + print(f"FAILED: Vendor '{vendor}' failed: {e}") + continue + + return [] + + +def get_bulk_news(lookback_period: str = "24h") -> List[NewsArticle]: + cached = _get_cached_bulk_news(lookback_period) + if cached is not None: + print(f"DEBUG: Returning cached bulk news for period '{lookback_period}'") + return cached + + raw_articles = _fetch_bulk_news_from_vendor(lookback_period) + + articles = _convert_to_news_articles(raw_articles) + + _set_cached_bulk_news(lookback_period, articles) + + return articles + + def get_category_for_method(method: str) -> str: - """Get the category that contains the specified method.""" for category, info in TOOLS_CATEGORIES.items(): if method in info["tools"]: return category raise ValueError(f"Method '{method}' not found in any category") def get_vendor(category: str, method: str = None) -> str: - """Get the configured vendor for a data category or specific tool method. - Tool-level configuration takes precedence over category-level. - """ config = get_config() - # Check tool-level configuration first (if method provided) if method: tool_vendors = config.get("tool_vendors", {}) if method in tool_vendors: return tool_vendors[method] - # Fall back to category-level configuration return config.get("data_vendors", {}).get(category, "default") def route_to_vendor(method: str, *args, **kwargs): - """Route method calls to appropriate vendor implementation with fallback support.""" category = get_category_for_method(method) vendor_config = get_vendor(category, method) - # Handle comma-separated vendors primary_vendors = [v.strip() for v in vendor_config.split(',')] if method not in VENDOR_METHODS: raise ValueError(f"Method '{method}' not supported") - # Get all available vendors for this method for fallback all_available_vendors = list(VENDOR_METHODS[method].keys()) - - # Create fallback vendor list: primary vendors first, then remaining vendors as fallbacks + fallback_vendors = primary_vendors.copy() for vendor in all_available_vendors: if vendor not in fallback_vendors: fallback_vendors.append(vendor) - # Debug: Print fallback ordering - primary_str = " → ".join(primary_vendors) - fallback_str = " → ".join(fallback_vendors) + primary_str = " -> ".join(primary_vendors) + fallback_str = " -> ".join(fallback_vendors) print(f"DEBUG: {method} - Primary: [{primary_str}] | Full fallback order: [{fallback_str}]") - # Track results and execution state results = [] vendor_attempt_count = 0 any_primary_vendor_attempted = False @@ -179,22 +279,18 @@ def route_to_vendor(method: str, *args, **kwargs): is_primary_vendor = vendor in primary_vendors vendor_attempt_count += 1 - # Track if we attempted any primary vendor if is_primary_vendor: any_primary_vendor_attempted = True - # Debug: Print current attempt vendor_type = "PRIMARY" if is_primary_vendor else "FALLBACK" print(f"DEBUG: Attempting {vendor_type} vendor '{vendor}' for {method} (attempt #{vendor_attempt_count})") - # Handle list of methods for a vendor if isinstance(vendor_impl, list): vendor_methods = [(impl, vendor) for impl in vendor_impl] print(f"DEBUG: Vendor '{vendor}' has multiple implementations: {len(vendor_methods)} functions") else: vendor_methods = [(vendor_impl, vendor)] - # Run methods for this vendor vendor_results = [] for impl_func, vendor_name in vendor_methods: try: @@ -202,43 +298,35 @@ def route_to_vendor(method: str, *args, **kwargs): result = impl_func(*args, **kwargs) vendor_results.append(result) print(f"SUCCESS: {impl_func.__name__} from vendor '{vendor_name}' completed successfully") - + except AlphaVantageRateLimitError as e: if vendor == "alpha_vantage": print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded, falling back to next available vendor") print(f"DEBUG: Rate limit details: {e}") - # Continue to next vendor for fallback continue except Exception as e: - # Log error but continue with other implementations print(f"FAILED: {impl_func.__name__} from vendor '{vendor_name}' failed: {e}") continue - # Add this vendor's results if vendor_results: results.extend(vendor_results) successful_vendor = vendor result_summary = f"Got {len(vendor_results)} result(s)" print(f"SUCCESS: Vendor '{vendor}' succeeded - {result_summary}") - - # Stopping logic: Stop after first successful vendor for single-vendor configs - # Multiple vendor configs (comma-separated) may want to collect from multiple sources + if len(primary_vendors) == 1: print(f"DEBUG: Stopping after successful vendor '{vendor}' (single-vendor config)") break else: print(f"FAILED: Vendor '{vendor}' produced no results") - # Final result summary if not results: print(f"FAILURE: All {vendor_attempt_count} vendor attempts failed for method '{method}'") raise RuntimeError(f"All vendor implementations failed for method '{method}'") else: print(f"FINAL: Method '{method}' completed with {len(results)} result(s) from {vendor_attempt_count} vendor attempt(s)") - # Return single result if only one, otherwise concatenate as string if len(results) == 1: return results[0] else: - # Convert all results to strings and concatenate - return '\n'.join(str(result) for result in results) \ No newline at end of file + return '\n'.join(str(result) for result in results) diff --git a/tradingagents/dataflows/openai.py b/tradingagents/dataflows/openai.py index 91a2258b..d9439e14 100644 --- a/tradingagents/dataflows/openai.py +++ b/tradingagents/dataflows/openai.py @@ -1,3 +1,7 @@ +import json +import re +from datetime import datetime, timedelta +from typing import List, Dict, Any from openai import OpenAI from .config import get_config @@ -104,4 +108,91 @@ def get_fundamentals_openai(ticker, curr_date): store=True, ) - return response.output[1].content[0].text \ No newline at end of file + return response.output[1].content[0].text + + +def get_bulk_news_openai(lookback_hours: int) -> List[Dict[str, Any]]: + config = get_config() + client = OpenAI(base_url=config["backend_url"]) + + end_date = datetime.now() + start_date = end_date - timedelta(hours=lookback_hours) + + start_str = start_date.strftime("%Y-%m-%d %H:%M") + end_str = end_date.strftime("%Y-%m-%d %H:%M") + + prompt = f"""Search for recent stock market news, trading news, and earnings announcements from {start_str} to {end_str}. + +Return the results as a JSON array with the following structure: +[ + {{ + "title": "Article title", + "source": "Source name", + "url": "https://...", + "published_at": "YYYY-MM-DDTHH:MM:SS", + "content_snippet": "Brief summary of the article..." + }} +] + +Focus on: +- Stock market movements and trends +- Company earnings reports +- Mergers and acquisitions +- Significant trading activity +- Economic news affecting markets + +Return ONLY the JSON array, no additional text.""" + + response = client.responses.create( + model=config["quick_think_llm"], + input=[ + { + "role": "system", + "content": [ + { + "type": "input_text", + "text": prompt, + } + ], + } + ], + text={"format": {"type": "text"}}, + reasoning={}, + tools=[ + { + "type": "web_search_preview", + "user_location": {"type": "approximate"}, + "search_context_size": "medium", + } + ], + temperature=0.5, + max_output_tokens=8192, + top_p=1, + store=True, + ) + + try: + response_text = response.output[1].content[0].text + + json_match = re.search(r'\[[\s\S]*\]', response_text) + if json_match: + articles = json.loads(json_match.group()) + else: + articles = json.loads(response_text) + + result = [] + for item in articles: + if isinstance(item, dict): + article = { + "title": item.get("title", ""), + "source": item.get("source", "Web Search"), + "url": item.get("url", ""), + "published_at": item.get("published_at", datetime.now().isoformat()), + "content_snippet": item.get("content_snippet", "")[:500], + } + result.append(article) + + return result + + except (json.JSONDecodeError, IndexError, AttributeError): + return [] diff --git a/tradingagents/dataflows/trending/__init__.py b/tradingagents/dataflows/trending/__init__.py new file mode 100644 index 00000000..190a1a32 --- /dev/null +++ b/tradingagents/dataflows/trending/__init__.py @@ -0,0 +1,21 @@ +from .stock_resolver import ( + resolve_ticker, + validate_tradeable, + validate_us_ticker, + COMPANY_TO_TICKER, +) +from .sector_classifier import ( + classify_sector, + TICKER_TO_SECTOR, + VALID_SECTORS, +) + +__all__ = [ + "resolve_ticker", + "validate_tradeable", + "validate_us_ticker", + "COMPANY_TO_TICKER", + "classify_sector", + "TICKER_TO_SECTOR", + "VALID_SECTORS", +] diff --git a/tradingagents/dataflows/trending/sector_classifier.py b/tradingagents/dataflows/trending/sector_classifier.py new file mode 100644 index 00000000..999ec3aa --- /dev/null +++ b/tradingagents/dataflows/trending/sector_classifier.py @@ -0,0 +1,267 @@ +import logging +from typing import Dict + +logger = logging.getLogger(__name__) + +VALID_SECTORS = { + "technology", + "healthcare", + "finance", + "energy", + "consumer_goods", + "industrials", + "other", +} + +TICKER_TO_SECTOR: Dict[str, str] = { + "AAPL": "technology", + "MSFT": "technology", + "GOOGL": "technology", + "GOOG": "technology", + "AMZN": "technology", + "META": "technology", + "NVDA": "technology", + "TSLA": "technology", + "AMD": "technology", + "INTC": "technology", + "QCOM": "technology", + "AVGO": "technology", + "TXN": "technology", + "ADBE": "technology", + "CRM": "technology", + "CSCO": "technology", + "NFLX": "technology", + "ORCL": "technology", + "IBM": "technology", + "NOW": "technology", + "INTU": "technology", + "ADSK": "technology", + "SNPS": "technology", + "CDNS": "technology", + "PLTR": "technology", + "SNOW": "technology", + "DDOG": "technology", + "CRWD": "technology", + "OKTA": "technology", + "NET": "technology", + "MDB": "technology", + "TWLO": "technology", + "WDAY": "technology", + "SPLK": "technology", + "VMW": "technology", + "HPQ": "technology", + "DELL": "technology", + "FTNT": "technology", + "PANW": "technology", + "ZS": "technology", + "S": "technology", + "VEEV": "technology", + "ZM": "technology", + "DOCU": "technology", + "ASAN": "technology", + "MNDY": "technology", + "TEAM": "technology", + "ANSS": "technology", + "ROP": "technology", + "JPM": "finance", + "BAC": "finance", + "WFC": "finance", + "GS": "finance", + "MS": "finance", + "C": "finance", + "BLK": "finance", + "SCHW": "finance", + "AXP": "finance", + "V": "finance", + "MA": "finance", + "PYPL": "finance", + "SQ": "finance", + "COIN": "finance", + "HOOD": "finance", + "SOFI": "finance", + "AFRM": "finance", + "MQ": "finance", + "BRK-B": "finance", + "BRK-A": "finance", + "JNJ": "healthcare", + "UNH": "healthcare", + "PFE": "healthcare", + "ABBV": "healthcare", + "MRK": "healthcare", + "LLY": "healthcare", + "MRNA": "healthcare", + "BNTX": "healthcare", + "CVS": "healthcare", + "WBA": "healthcare", + "MCK": "healthcare", + "CAH": "healthcare", + "HUM": "healthcare", + "CI": "healthcare", + "ELV": "healthcare", + "XOM": "energy", + "CVX": "energy", + "COP": "energy", + "SLB": "energy", + "HAL": "energy", + "BKR": "energy", + "MPC": "energy", + "VLO": "energy", + "PSX": "energy", + "OXY": "energy", + "PXD": "energy", + "DVN": "energy", + "CEG": "energy", + "NEE": "energy", + "DUK": "energy", + "SO": "energy", + "D": "energy", + "SRE": "energy", + "WMT": "consumer_goods", + "COST": "consumer_goods", + "TGT": "consumer_goods", + "HD": "consumer_goods", + "LOW": "consumer_goods", + "PG": "consumer_goods", + "KO": "consumer_goods", + "PEP": "consumer_goods", + "NKE": "consumer_goods", + "SBUX": "consumer_goods", + "MCD": "consumer_goods", + "CMG": "consumer_goods", + "YUM": "consumer_goods", + "DPZ": "consumer_goods", + "DIS": "consumer_goods", + "CMCSA": "consumer_goods", + "VZ": "consumer_goods", + "T": "consumer_goods", + "TMUS": "consumer_goods", + "EL": "consumer_goods", + "CL": "consumer_goods", + "KMB": "consumer_goods", + "CLX": "consumer_goods", + "KHC": "consumer_goods", + "GIS": "consumer_goods", + "K": "consumer_goods", + "MDLZ": "consumer_goods", + "HSY": "consumer_goods", + "TSN": "consumer_goods", + "BYND": "consumer_goods", + "CAG": "consumer_goods", + "STZ": "consumer_goods", + "BUD": "consumer_goods", + "DEO": "consumer_goods", + "PM": "consumer_goods", + "MO": "consumer_goods", + "LULU": "consumer_goods", + "DG": "consumer_goods", + "DLTR": "consumer_goods", + "ROST": "consumer_goods", + "TJX": "consumer_goods", + "AZO": "consumer_goods", + "ORLY": "consumer_goods", + "KMX": "consumer_goods", + "ADDYY": "consumer_goods", + "UBER": "consumer_goods", + "LYFT": "consumer_goods", + "ABNB": "consumer_goods", + "DASH": "consumer_goods", + "SNAP": "consumer_goods", + "PINS": "consumer_goods", + "TWTR": "consumer_goods", + "SHOP": "consumer_goods", + "TOST": "consumer_goods", + "BA": "industrials", + "LMT": "industrials", + "RTX": "industrials", + "GD": "industrials", + "NOC": "industrials", + "GE": "industrials", + "HON": "industrials", + "MMM": "industrials", + "CAT": "industrials", + "DE": "industrials", + "UNP": "industrials", + "UPS": "industrials", + "FDX": "industrials", + "DAL": "industrials", + "UAL": "industrials", + "AAL": "industrials", + "LUV": "industrials", + "F": "industrials", + "GM": "industrials", + "TM": "industrials", + "HMC": "industrials", + "VWAGY": "industrials", + "RACE": "industrials", + "RIVN": "industrials", + "LCID": "industrials", + "NIO": "industrials", + "LNVGY": "industrials", +} + +_sector_cache: Dict[str, str] = {} + + +def _llm_classify_sector(ticker: str) -> str: + from langchain_openai import ChatOpenAI + from langchain_core.messages import HumanMessage, SystemMessage + from tradingagents.default_config import DEFAULT_CONFIG + + llm_name = DEFAULT_CONFIG.get("quick_think_llm", "gpt-4o-mini") + llm_provider = DEFAULT_CONFIG.get("llm_provider", "openai") + backend_url = DEFAULT_CONFIG.get("backend_url", "https://api.openai.com/v1") + + llm = ChatOpenAI( + model=llm_name, + base_url=backend_url, + temperature=0, + ) + + system_prompt = ( + "You are a financial sector classifier. Given a stock ticker symbol, " + "classify it into exactly one of the following sectors: " + "technology, healthcare, finance, energy, consumer_goods, industrials, other. " + "Respond with only the sector name in lowercase, nothing else." + ) + + user_prompt = f"Classify the stock ticker: {ticker}" + + messages = [ + SystemMessage(content=system_prompt), + HumanMessage(content=user_prompt), + ] + + response = llm.invoke(messages) + sector = response.content.strip().lower() + + if sector not in VALID_SECTORS: + logger.warning( + "LLM returned invalid sector '%s' for ticker %s, defaulting to 'other'", + sector, + ticker, + ) + return "other" + + return sector + + +def classify_sector(ticker: str) -> str: + ticker_upper = ticker.upper() + + if ticker_upper in TICKER_TO_SECTOR: + return TICKER_TO_SECTOR[ticker_upper] + + if ticker_upper in _sector_cache: + return _sector_cache[ticker_upper] + + logger.info("Using LLM fallback for sector classification of ticker: %s", ticker) + + try: + sector = _llm_classify_sector(ticker_upper) + _sector_cache[ticker_upper] = sector + logger.info("Classified %s as %s via LLM", ticker, sector) + return sector + except Exception as e: + logger.error("LLM sector classification failed for %s: %s", ticker, str(e)) + _sector_cache[ticker_upper] = "other" + return "other" diff --git a/tradingagents/dataflows/trending/stock_resolver.py b/tradingagents/dataflows/trending/stock_resolver.py new file mode 100644 index 00000000..a2c2aefb --- /dev/null +++ b/tradingagents/dataflows/trending/stock_resolver.py @@ -0,0 +1,538 @@ +import logging +import re +from typing import Optional + +import yfinance as yf + +logger = logging.getLogger(__name__) + +COMPANY_TO_TICKER = { + "apple": "AAPL", + "apple inc": "AAPL", + "apple inc.": "AAPL", + "apple corporation": "AAPL", + "the iphone maker": "AAPL", + "iphone maker": "AAPL", + "microsoft": "MSFT", + "microsoft inc": "MSFT", + "microsoft inc.": "MSFT", + "microsoft corp": "MSFT", + "microsoft corp.": "MSFT", + "microsoft corporation": "MSFT", + "google": "GOOGL", + "alphabet": "GOOGL", + "alphabet inc": "GOOGL", + "alphabet inc.": "GOOGL", + "the search giant": "GOOGL", + "amazon": "AMZN", + "amazon inc": "AMZN", + "amazon inc.": "AMZN", + "amazon.com": "AMZN", + "amazon.com inc": "AMZN", + "the e-commerce giant": "AMZN", + "e-commerce giant": "AMZN", + "meta": "META", + "meta platforms": "META", + "meta platforms inc": "META", + "meta platforms inc.": "META", + "facebook": "META", + "facebook inc": "META", + "facebook inc.": "META", + "tesla": "TSLA", + "tesla inc": "TSLA", + "tesla inc.": "TSLA", + "tesla motors": "TSLA", + "ev maker tesla": "TSLA", + "nvidia": "NVDA", + "nvidia corp": "NVDA", + "nvidia corp.": "NVDA", + "nvidia corporation": "NVDA", + "berkshire hathaway": "BRK-B", + "berkshire": "BRK-B", + "jpmorgan": "JPM", + "jpmorgan chase": "JPM", + "jp morgan": "JPM", + "jp morgan chase": "JPM", + "johnson & johnson": "JNJ", + "johnson and johnson": "JNJ", + "j&j": "JNJ", + "unitedhealth": "UNH", + "unitedhealth group": "UNH", + "visa": "V", + "visa inc": "V", + "visa inc.": "V", + "procter & gamble": "PG", + "procter and gamble": "PG", + "p&g": "PG", + "mastercard": "MA", + "mastercard inc": "MA", + "mastercard inc.": "MA", + "home depot": "HD", + "the home depot": "HD", + "chevron": "CVX", + "chevron corp": "CVX", + "chevron corporation": "CVX", + "exxon": "XOM", + "exxon mobil": "XOM", + "exxonmobil": "XOM", + "pfizer": "PFE", + "pfizer inc": "PFE", + "pfizer inc.": "PFE", + "abbvie": "ABBV", + "abbvie inc": "ABBV", + "abbvie inc.": "ABBV", + "coca-cola": "KO", + "coca cola": "KO", + "coke": "KO", + "the coca-cola company": "KO", + "pepsico": "PEP", + "pepsi": "PEP", + "pepsi co": "PEP", + "costco": "COST", + "costco wholesale": "COST", + "walmart": "WMT", + "wal-mart": "WMT", + "walmart inc": "WMT", + "bank of america": "BAC", + "bofa": "BAC", + "merck": "MRK", + "merck & co": "MRK", + "merck and co": "MRK", + "eli lilly": "LLY", + "lilly": "LLY", + "eli lilly and company": "LLY", + "adobe": "ADBE", + "adobe inc": "ADBE", + "adobe inc.": "ADBE", + "adobe systems": "ADBE", + "salesforce": "CRM", + "salesforce inc": "CRM", + "salesforce.com": "CRM", + "cisco": "CSCO", + "cisco systems": "CSCO", + "cisco systems inc": "CSCO", + "netflix": "NFLX", + "netflix inc": "NFLX", + "netflix inc.": "NFLX", + "oracle": "ORCL", + "oracle corp": "ORCL", + "oracle corporation": "ORCL", + "intel": "INTC", + "intel corp": "INTC", + "intel corporation": "INTC", + "amd": "AMD", + "advanced micro devices": "AMD", + "qualcomm": "QCOM", + "qualcomm inc": "QCOM", + "qualcomm inc.": "QCOM", + "broadcom": "AVGO", + "broadcom inc": "AVGO", + "broadcom inc.": "AVGO", + "texas instruments": "TXN", + "ti": "TXN", + "disney": "DIS", + "walt disney": "DIS", + "the walt disney company": "DIS", + "walt disney company": "DIS", + "comcast": "CMCSA", + "comcast corp": "CMCSA", + "comcast corporation": "CMCSA", + "verizon": "VZ", + "verizon communications": "VZ", + "at&t": "T", + "att": "T", + "t-mobile": "TMUS", + "tmobile": "TMUS", + "t-mobile us": "TMUS", + "american express": "AXP", + "amex": "AXP", + "goldman sachs": "GS", + "goldman": "GS", + "morgan stanley": "MS", + "wells fargo": "WFC", + "wells": "WFC", + "citigroup": "C", + "citi": "C", + "citibank": "C", + "charles schwab": "SCHW", + "schwab": "SCHW", + "blackrock": "BLK", + "blackrock inc": "BLK", + "paypal": "PYPL", + "paypal holdings": "PYPL", + "paypal inc": "PYPL", + "square": "SQ", + "block": "SQ", + "block inc": "SQ", + "shopify": "SHOP", + "shopify inc": "SHOP", + "uber": "UBER", + "uber technologies": "UBER", + "lyft": "LYFT", + "lyft inc": "LYFT", + "airbnb": "ABNB", + "airbnb inc": "ABNB", + "doordash": "DASH", + "doordash inc": "DASH", + "snap": "SNAP", + "snap inc": "SNAP", + "snapchat": "SNAP", + "pinterest": "PINS", + "pinterest inc": "PINS", + "twitter": "TWTR", + "twitter inc": "TWTR", + "linkedin": "MSFT", + "zoom": "ZM", + "zoom video": "ZM", + "zoom video communications": "ZM", + "slack": "CRM", + "slack technologies": "CRM", + "palantir": "PLTR", + "palantir technologies": "PLTR", + "snowflake": "SNOW", + "snowflake inc": "SNOW", + "datadog": "DDOG", + "datadog inc": "DDOG", + "crowdstrike": "CRWD", + "crowdstrike holdings": "CRWD", + "okta": "OKTA", + "okta inc": "OKTA", + "cloudflare": "NET", + "cloudflare inc": "NET", + "mongodb": "MDB", + "mongodb inc": "MDB", + "twilio": "TWLO", + "twilio inc": "TWLO", + "servicenow": "NOW", + "servicenow inc": "NOW", + "workday": "WDAY", + "workday inc": "WDAY", + "splunk": "SPLK", + "splunk inc": "SPLK", + "vmware": "VMW", + "vmware inc": "VMW", + "ibm": "IBM", + "international business machines": "IBM", + "hp": "HPQ", + "hewlett-packard": "HPQ", + "hewlett packard": "HPQ", + "dell": "DELL", + "dell technologies": "DELL", + "lenovo": "LNVGY", + "boeing": "BA", + "boeing company": "BA", + "the boeing company": "BA", + "lockheed martin": "LMT", + "lockheed": "LMT", + "raytheon": "RTX", + "rtx": "RTX", + "general dynamics": "GD", + "northrop grumman": "NOC", + "northrop": "NOC", + "general electric": "GE", + "ge": "GE", + "honeywell": "HON", + "honeywell international": "HON", + "3m": "MMM", + "3m company": "MMM", + "caterpillar": "CAT", + "caterpillar inc": "CAT", + "deere": "DE", + "john deere": "DE", + "deere & company": "DE", + "union pacific": "UNP", + "ups": "UPS", + "united parcel service": "UPS", + "fedex": "FDX", + "federal express": "FDX", + "delta": "DAL", + "delta air lines": "DAL", + "delta airlines": "DAL", + "united airlines": "UAL", + "united": "UAL", + "american airlines": "AAL", + "southwest": "LUV", + "southwest airlines": "LUV", + "ford": "F", + "ford motor": "F", + "ford motor company": "F", + "general motors": "GM", + "gm": "GM", + "toyota": "TM", + "toyota motor": "TM", + "honda": "HMC", + "honda motor": "HMC", + "volkswagen": "VWAGY", + "vw": "VWAGY", + "ferrari": "RACE", + "rivian": "RIVN", + "rivian automotive": "RIVN", + "lucid": "LCID", + "lucid motors": "LCID", + "lucid group": "LCID", + "nio": "NIO", + "nio inc": "NIO", + "moderna": "MRNA", + "moderna inc": "MRNA", + "biontech": "BNTX", + "cvs": "CVS", + "cvs health": "CVS", + "walgreens": "WBA", + "walgreens boots alliance": "WBA", + "mckesson": "MCK", + "mckesson corp": "MCK", + "cardinal health": "CAH", + "humana": "HUM", + "humana inc": "HUM", + "cigna": "CI", + "cigna group": "CI", + "anthem": "ELV", + "elevance health": "ELV", + "starbucks": "SBUX", + "starbucks corp": "SBUX", + "starbucks corporation": "SBUX", + "mcdonalds": "MCD", + "mcdonald's": "MCD", + "chipotle": "CMG", + "chipotle mexican grill": "CMG", + "yum brands": "YUM", + "yum": "YUM", + "dominos": "DPZ", + "domino's": "DPZ", + "domino's pizza": "DPZ", + "nike": "NKE", + "nike inc": "NKE", + "adidas": "ADDYY", + "lululemon": "LULU", + "lululemon athletica": "LULU", + "target": "TGT", + "target corp": "TGT", + "target corporation": "TGT", + "dollar general": "DG", + "dollar tree": "DLTR", + "ross stores": "ROST", + "ross": "ROST", + "tjx": "TJX", + "tjx companies": "TJX", + "tj maxx": "TJX", + "lowes": "LOW", + "lowe's": "LOW", + "lowe's companies": "LOW", + "autozone": "AZO", + "o'reilly": "ORLY", + "o'reilly automotive": "ORLY", + "carmax": "KMX", + "estee lauder": "EL", + "colgate": "CL", + "colgate-palmolive": "CL", + "colgate palmolive": "CL", + "kimberly-clark": "KMB", + "kimberly clark": "KMB", + "clorox": "CLX", + "clorox company": "CLX", + "kraft heinz": "KHC", + "kraft": "KHC", + "heinz": "KHC", + "general mills": "GIS", + "kellogg": "K", + "kellogg's": "K", + "mondelez": "MDLZ", + "mondelez international": "MDLZ", + "hershey": "HSY", + "the hershey company": "HSY", + "tyson": "TSN", + "tyson foods": "TSN", + "beyond meat": "BYND", + "conagra": "CAG", + "conagra brands": "CAG", + "constellation brands": "STZ", + "anheuser-busch": "BUD", + "anheuser busch": "BUD", + "ab inbev": "BUD", + "diageo": "DEO", + "philip morris": "PM", + "philip morris international": "PM", + "altria": "MO", + "altria group": "MO", + "constellation energy": "CEG", + "nextera": "NEE", + "nextera energy": "NEE", + "duke energy": "DUK", + "southern company": "SO", + "dominion": "D", + "dominion energy": "D", + "sempra": "SRE", + "sempra energy": "SRE", + "conocophillips": "COP", + "conoco": "COP", + "schlumberger": "SLB", + "halliburton": "HAL", + "baker hughes": "BKR", + "marathon": "MPC", + "marathon petroleum": "MPC", + "valero": "VLO", + "valero energy": "VLO", + "phillips 66": "PSX", + "occidental": "OXY", + "occidental petroleum": "OXY", + "pioneer": "PXD", + "pioneer natural resources": "PXD", + "devon energy": "DVN", + "devon": "DVN", + "coinbase": "COIN", + "coinbase global": "COIN", + "robinhood": "HOOD", + "robinhood markets": "HOOD", + "sofi": "SOFI", + "sofi technologies": "SOFI", + "affirm": "AFRM", + "affirm holdings": "AFRM", + "marqeta": "MQ", + "toast": "TOST", + "toast inc": "TOST", + "docusign": "DOCU", + "docusign inc": "DOCU", + "asana": "ASAN", + "monday.com": "MNDY", + "monday": "MNDY", + "atlassian": "TEAM", + "atlassian corp": "TEAM", + "intuit": "INTU", + "intuit inc": "INTU", + "autodesk": "ADSK", + "autodesk inc": "ADSK", + "synopsys": "SNPS", + "cadence": "CDNS", + "cadence design": "CDNS", + "ansys": "ANSS", + "roper": "ROP", + "roper technologies": "ROP", + "fortinet": "FTNT", + "palo alto": "PANW", + "palo alto networks": "PANW", + "zscaler": "ZS", + "sentinelone": "S", + "veeva": "VEEV", + "veeva systems": "VEEV", +} + +US_EXCHANGE_CODES = { + "NYQ", + "NMS", + "NGM", + "NCM", + "ASE", + "PCX", + "BTS", + "NYSE", + "NASDAQ", + "AMEX", + "NYS", + "NAS", + "NIM", + "NAQ", +} + +SUFFIX_PATTERNS = [ + r"\s+inc\.?$", + r"\s+corp\.?$", + r"\s+corporation$", + r"\s+co\.?$", + r"\s+company$", + r"\s+llc$", + r"\s+ltd\.?$", + r"\s+limited$", + r"\s+plc$", + r"\s+holdings?$", + r"\s+group$", + r"\s+technologies$", + r"\s+enterprises?$", +] + + +def _normalize_company_name(name: str) -> str: + normalized = name.lower().strip() + for pattern in SUFFIX_PATTERNS: + normalized = re.sub(pattern, "", normalized, flags=re.IGNORECASE) + normalized = normalized.strip() + return normalized + + +def _search_yfinance_ticker(company_name: str) -> Optional[str]: + try: + search_result = yf.Ticker(company_name) + info = search_result.info + if info and "symbol" in info: + return info["symbol"] + except Exception as e: + logger.debug("yfinance search failed for %s: %s", company_name, str(e)) + + try: + search = yf.Search(company_name, max_results=5) + if hasattr(search, "quotes") and search.quotes: + for quote in search.quotes: + if "symbol" in quote: + return quote["symbol"] + except Exception as e: + logger.debug("yfinance Search failed for %s: %s", company_name, str(e)) + + return None + + +def validate_us_ticker(ticker: str) -> bool: + try: + ticker_obj = yf.Ticker(ticker.upper()) + info = ticker_obj.info + if not info: + logger.warning("Validation failed for %s: no info available", ticker) + return False + + exchange = info.get("exchange", "") + if exchange in US_EXCHANGE_CODES: + return True + + exchange_lower = exchange.lower() + if any(us_ex.lower() in exchange_lower for us_ex in ["nyse", "nasdaq", "amex", "nys", "nms", "ngm"]): + return True + + logger.warning("Validation failed for %s: exchange %s is not a US exchange", ticker, exchange) + return False + except Exception as e: + logger.warning("Validation failed for %s: %s", ticker, str(e)) + return False + + +def resolve_ticker(company_name: str) -> Optional[str]: + if not company_name or not company_name.strip(): + return None + + normalized = company_name.lower().strip() + + if normalized in COMPANY_TO_TICKER: + return COMPANY_TO_TICKER[normalized] + + normalized_stripped = _normalize_company_name(company_name) + if normalized_stripped in COMPANY_TO_TICKER: + return COMPANY_TO_TICKER[normalized_stripped] + + if company_name.upper() in [v for v in COMPANY_TO_TICKER.values()]: + if validate_us_ticker(company_name.upper()): + return company_name.upper() + + logger.info("Using yfinance fallback for company: %s", company_name) + yf_ticker = _search_yfinance_ticker(company_name) + + if yf_ticker: + if validate_us_ticker(yf_ticker): + logger.info("Resolved %s to %s via yfinance", company_name, yf_ticker) + return yf_ticker + else: + logger.warning("Ticker %s for %s failed US exchange validation", yf_ticker, company_name) + return None + + logger.warning("Could not resolve ticker for company: %s", company_name) + return None + + +def validate_tradeable(ticker: str) -> bool: + return validate_us_ticker(ticker) diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 1f40a2a2..e88868d6 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -8,26 +8,24 @@ DEFAULT_CONFIG = { os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "dataflows/data_cache", ), - # LLM settings "llm_provider": "openai", - "deep_think_llm": "o4-mini", - "quick_think_llm": "gpt-4o-mini", + "deep_think_llm": "gpt-5", + "quick_think_llm": "gpt-5-mini", "backend_url": "https://api.openai.com/v1", - # Debate and discussion settings - "max_debate_rounds": 1, - "max_risk_discuss_rounds": 1, + "max_debate_rounds": 2, + "max_risk_discuss_rounds": 2, "max_recur_limit": 100, - # Data vendor configuration - # Category-level configuration (default for all tools in category) "data_vendors": { - "core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local - "technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local - "fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local - "news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local + "core_stock_apis": "yfinance", + "technical_indicators": "yfinance", + "fundamental_data": "alpha_vantage", + "news_data": "alpha_vantage", }, - # Tool-level configuration (takes precedence over category-level) "tool_vendors": { - # Example: "get_stock_data": "alpha_vantage", # Override category default - # Example: "get_news": "openai", # Override category default }, + "discovery_timeout": 60, + "discovery_hard_timeout": 120, + "discovery_cache_ttl": 300, + "discovery_max_results": 20, + "discovery_min_mentions": 2, } diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 40cdff75..ba4c092c 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -1,9 +1,9 @@ -# TradingAgents/graph/trading_graph.py - import os +import signal +import threading from pathlib import Path import json -from datetime import date +from datetime import date, datetime from typing import Dict, Any, Tuple, List, Optional from langchain_openai import ChatOpenAI @@ -22,7 +22,6 @@ from tradingagents.agents.utils.agent_states import ( ) from tradingagents.dataflows.config import set_config -# Import the new abstract tool methods from agent_utils from tradingagents.agents.utils.agent_utils import ( get_stock_data, get_indicators, @@ -36,6 +35,19 @@ from tradingagents.agents.utils.agent_utils import ( get_global_news ) +from tradingagents.agents.discovery import ( + DiscoveryRequest, + DiscoveryResult, + DiscoveryStatus, + TrendingStock, + Sector, + EventCategory, + DiscoveryTimeoutError, + extract_entities, + calculate_trending_scores, +) +from tradingagents.dataflows.interface import get_bulk_news + from .conditional_logic import ConditionalLogic from .setup import GraphSetup from .propagation import Propagator @@ -43,8 +55,15 @@ from .reflection import Reflector from .signal_processing import SignalProcessor +class DiscoveryTimeoutException(Exception): + pass + + +def _timeout_handler(signum, frame): + raise DiscoveryTimeoutException("Discovery operation timed out") + + class TradingAgentsGraph: - """Main class that orchestrates the trading agents framework.""" def __init__( self, @@ -52,26 +71,16 @@ class TradingAgentsGraph: debug=False, config: Dict[str, Any] = None, ): - """Initialize the trading agents graph and components. - - Args: - selected_analysts: List of analyst types to include - debug: Whether to run in debug mode - config: Configuration dictionary. If None, uses default config - """ self.debug = debug self.config = config or DEFAULT_CONFIG - # Update the interface's config set_config(self.config) - # Create necessary directories os.makedirs( os.path.join(self.config["project_dir"], "dataflows/data_cache"), exist_ok=True, ) - # Initialize LLMs if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter": self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]) self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]) @@ -83,18 +92,13 @@ class TradingAgentsGraph: self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"]) else: raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}") - - # Initialize memories + self.bull_memory = FinancialSituationMemory("bull_memory", self.config) self.bear_memory = FinancialSituationMemory("bear_memory", self.config) self.trader_memory = FinancialSituationMemory("trader_memory", self.config) self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config) self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config) - - # Create tool nodes self.tool_nodes = self._create_tool_nodes() - - # Initialize components self.conditional_logic = ConditionalLogic() self.graph_setup = GraphSetup( self.quick_thinking_llm, @@ -111,35 +115,26 @@ class TradingAgentsGraph: self.propagator = Propagator() self.reflector = Reflector(self.quick_thinking_llm) self.signal_processor = SignalProcessor(self.quick_thinking_llm) - - # State tracking self.curr_state = None self.ticker = None - self.log_states_dict = {} # date to full state dict - - # Set up the graph + self.log_states_dict = {} self.graph = self.graph_setup.setup_graph(selected_analysts) def _create_tool_nodes(self) -> Dict[str, ToolNode]: - """Create tool nodes for different data sources using abstract methods.""" return { "market": ToolNode( [ - # Core stock data tools get_stock_data, - # Technical indicators get_indicators, ] ), "social": ToolNode( [ - # News tools for social media analysis get_news, ] ), "news": ToolNode( [ - # News and insider information get_news, get_global_news, get_insider_sentiment, @@ -148,7 +143,6 @@ class TradingAgentsGraph: ), "fundamentals": ToolNode( [ - # Fundamental analysis tools get_fundamentals, get_balance_sheet, get_cashflow, @@ -158,18 +152,13 @@ class TradingAgentsGraph: } def propagate(self, company_name, trade_date): - """Run the trading agents graph for a company on a specific date.""" - self.ticker = company_name - - # Initialize state init_agent_state = self.propagator.create_initial_state( company_name, trade_date ) args = self.propagator.get_graph_args() if self.debug: - # Debug mode with tracing trace = [] for chunk in self.graph.stream(init_agent_state, **args): if len(chunk["messages"]) == 0: @@ -180,20 +169,14 @@ class TradingAgentsGraph: final_state = trace[-1] else: - # Standard mode without tracing final_state = self.graph.invoke(init_agent_state, **args) - # Store current state for reflection self.curr_state = final_state - - # Log state self._log_state(trade_date, final_state) - # Return decision and processed signal return final_state, self.process_signal(final_state["final_trade_decision"]) def _log_state(self, trade_date, final_state): - """Log the final state to a JSON file.""" self.log_states_dict[str(trade_date)] = { "company_of_interest": final_state["company_of_interest"], "trade_date": final_state["trade_date"], @@ -224,7 +207,6 @@ class TradingAgentsGraph: "final_trade_decision": final_state["final_trade_decision"], } - # Save to file directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/") directory.mkdir(parents=True, exist_ok=True) @@ -235,7 +217,6 @@ class TradingAgentsGraph: json.dump(self.log_states_dict, f, indent=4) def reflect_and_remember(self, returns_losses): - """Reflect on decisions and update memory based on returns.""" self.reflector.reflect_bull_researcher( self.curr_state, returns_losses, self.bull_memory ) @@ -253,5 +234,95 @@ class TradingAgentsGraph: ) def process_signal(self, full_signal): - """Process a signal to extract the core decision.""" return self.signal_processor.process_signal(full_signal) + + def discover_trending( + self, + request: Optional[DiscoveryRequest] = None, + ) -> DiscoveryResult: + if request is None: + request = DiscoveryRequest( + lookback_period="24h", + max_results=self.config.get("discovery_max_results", 20), + ) + + started_at = datetime.now() + result = DiscoveryResult( + request=request, + trending_stocks=[], + status=DiscoveryStatus.PROCESSING, + started_at=started_at, + ) + + hard_timeout = self.config.get("discovery_hard_timeout", 120) + + discovery_result = {"stocks": [], "error": None} + + def run_discovery(): + try: + articles = get_bulk_news(request.lookback_period) + + mentions = extract_entities(articles, self.config) + + min_mentions = self.config.get("discovery_min_mentions", 2) + max_results = request.max_results or self.config.get("discovery_max_results", 20) + + trending_stocks = calculate_trending_scores( + mentions, + articles, + max_results=max_results, + min_mentions=min_mentions, + ) + + discovery_result["stocks"] = trending_stocks + except Exception as e: + discovery_result["error"] = str(e) + + discovery_thread = threading.Thread(target=run_discovery) + discovery_thread.start() + discovery_thread.join(timeout=hard_timeout) + + if discovery_thread.is_alive(): + raise DiscoveryTimeoutError( + f"Discovery operation exceeded {hard_timeout} second timeout" + ) + + if discovery_result["error"]: + result.status = DiscoveryStatus.FAILED + result.error_message = discovery_result["error"] + result.completed_at = datetime.now() + return result + + trending_stocks = discovery_result["stocks"] + + if request.sector_filter: + sector_values = {s.value if isinstance(s, Sector) else s for s in request.sector_filter} + trending_stocks = [ + stock for stock in trending_stocks + if stock.sector.value in sector_values or stock.sector in request.sector_filter + ] + + if request.event_filter: + event_values = {e.value if isinstance(e, EventCategory) else e for e in request.event_filter} + trending_stocks = [ + stock for stock in trending_stocks + if stock.event_type.value in event_values or stock.event_type in request.event_filter + ] + + result.trending_stocks = trending_stocks + result.status = DiscoveryStatus.COMPLETED + result.completed_at = datetime.now() + + return result + + def analyze_trending( + self, + trending_stock: TrendingStock, + trade_date: Optional[date] = None, + ) -> Tuple[Dict[str, Any], str]: + ticker = trending_stock.ticker + + if trade_date is None: + trade_date = date.today() + + return self.propagate(ticker, trade_date)