From 59860b6a304290e364870ced37cafb70e5731c5c Mon Sep 17 00:00:00 2001 From: gnarayan1 Date: Sat, 29 Nov 2025 20:10:10 -0600 Subject: [PATCH] Added screening logic --- .agent/workflows/install_conda.md | 34 + .gitignore | 26 +- .python-version | 2 +- CHANGELOG.md | 40 + cli/main.py | 2218 ++++++++--------- cli/models.py | 20 +- cli/static/welcome.txt | 14 +- cli/utils.py | 552 ++-- main.py | 62 +- main_screening.py | 114 + requirements.txt | 53 +- setup.py | 87 +- test.py | 22 +- test_social_apis.py | 41 + tradingagents/agents/__init__.py | 80 +- .../agents/analysts/fundamentals_analyst.py | 126 +- .../agents/analysts/market_analyst.py | 170 +- tradingagents/agents/analysts/news_analyst.py | 116 +- .../agents/analysts/social_media_analyst.py | 118 +- .../agents/managers/research_manager.py | 110 +- tradingagents/agents/managers/risk_manager.py | 132 +- .../agents/researchers/bear_researcher.py | 122 +- .../agents/researchers/bull_researcher.py | 118 +- .../agents/risk_mgmt/aggresive_debator.py | 110 +- .../agents/risk_mgmt/conservative_debator.py | 116 +- .../agents/risk_mgmt/neutral_debator.py | 110 +- tradingagents/agents/screening_agent.py | 63 + tradingagents/agents/trader/trader.py | 92 +- tradingagents/agents/utils/agent_states.py | 152 +- tradingagents/agents/utils/agent_utils.py | 78 +- .../agents/utils/core_stock_tools.py | 64 +- .../agents/utils/fundamental_data_tools.py | 152 +- tradingagents/agents/utils/memory.py | 226 +- tradingagents/agents/utils/news_data_tools.py | 142 +- .../utils/technical_indicators_tools.py | 44 +- tradingagents/dataflows/alpha_vantage.py | 11 +- .../dataflows/alpha_vantage_common.py | 244 +- .../dataflows/alpha_vantage_fundamentals.py | 154 +- .../dataflows/alpha_vantage_indicator.py | 444 ++-- .../dataflows/alpha_vantage_market.py | 76 + tradingagents/dataflows/alpha_vantage_news.py | 84 +- .../dataflows/alpha_vantage_stock.py | 74 +- tradingagents/dataflows/config.py | 68 +- tradingagents/dataflows/google.py | 58 +- 
tradingagents/dataflows/googlenews_utils.py | 216 +- tradingagents/dataflows/interface.py | 503 ++-- tradingagents/dataflows/local.py | 948 +++---- tradingagents/dataflows/openai.py | 212 +- tradingagents/dataflows/reddit_utils.py | 270 +- tradingagents/dataflows/social_sentiment.py | 54 + tradingagents/dataflows/stockstats_utils.py | 164 +- tradingagents/dataflows/utils.py | 78 +- tradingagents/dataflows/y_finance.py | 812 +++--- tradingagents/dataflows/yfin_utils.py | 234 +- tradingagents/default_config.py | 66 +- tradingagents/graph/__init__.py | 34 +- tradingagents/graph/conditional_logic.py | 134 +- tradingagents/graph/propagation.py | 98 +- tradingagents/graph/reflection.py | 242 +- tradingagents/graph/setup.py | 404 +-- tradingagents/graph/signal_processing.py | 62 +- tradingagents/graph/trading_graph.py | 514 ++-- 62 files changed, 6225 insertions(+), 5759 deletions(-) create mode 100644 .agent/workflows/install_conda.md create mode 100644 CHANGELOG.md create mode 100644 main_screening.py create mode 100644 test_social_apis.py create mode 100644 tradingagents/agents/screening_agent.py create mode 100644 tradingagents/dataflows/alpha_vantage_market.py create mode 100644 tradingagents/dataflows/social_sentiment.py diff --git a/.agent/workflows/install_conda.md b/.agent/workflows/install_conda.md new file mode 100644 index 00000000..26d57191 --- /dev/null +++ b/.agent/workflows/install_conda.md @@ -0,0 +1,34 @@ +--- +description: How to install Miniconda in WSL2 +--- + +1. Download the Miniconda installer script: +```bash +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh +``` + +2. Run the installer script: +```bash +bash Miniconda3-latest-Linux-x86_64.sh +``` + - Press **Enter** to review the license. + - Type **yes** to accept the license terms. + - Press **Enter** to confirm the installation location. + - Type **yes** when asked to initialize Miniconda3. + +3. 
Activate the changes (or restart your terminal): +```bash +source ~/.bashrc +``` + +4. Verify installation: +```bash +conda --version +``` + +5. (Optional) Create your environment: +```bash +conda create -n tradingagents python=3.10 +conda activate tradingagents +pip install -r requirements.txt +``` diff --git a/.gitignore b/.gitignore index 3369bad9..04bf35c4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,15 @@ -.venv -results -env/ -__pycache__/ -.DS_Store -*.csv -src/ -eval_results/ -eval_data/ -*.egg-info/ -.env +.venv +results +env/ +__pycache__/ +.venv +results +env/ +__pycache__/ +.DS_Store +*.csv +src/ +eval_results/ +eval_data/ +*.egg-info/ +.env \ No newline at end of file diff --git a/.python-version b/.python-version index c8cfe395..2951d9b0 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.10 +3.10 diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..850841c8 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,40 @@ +# Changelog - November 29, 2025 + +## ๐Ÿš€ New Features + +### 1. Stock Screening Agent +- **New Agent**: Created `tradingagents/agents/screening_agent.py`. +- **Purpose**: Identifies potential stock candidates ("Hidden Gems") for further analysis by the main Trading Graph. +- **Strategy**: Implemented an "Early Bird" multi-factor strategy: + - **Social Hype**: Detects stocks trending on StockTwits and Reddit. + - **Insider Activity**: Checks for recent buying by company executives. + - **Technical Analysis**: Identifies Oversold conditions (RSI < 30) or Divergence. + - **Catalysts**: Checks for upcoming earnings reports. + +### 2. New Tools & Data Sources +- **`get_trending_social`**: New tool in `tradingagents/dataflows/social_sentiment.py` to fetch trending tickers from StockTwits and Ape Wisdom (Reddit). +- **`get_market_movers`**: Added to `tradingagents/dataflows/alpha_vantage_market.py` to find Top Gainers/Losers. 
+- **`get_earnings_calendar`**: Added to `tradingagents/dataflows/alpha_vantage_market.py`. +- **Integrated Tools**: Exposed `get_insider_transactions` and `get_indicators` to the Screening Agent. + +### 3. Execution Workflow +- **`main_screening.py`**: Created a dedicated script to run the Screening Agent. + - **Multi-Step Reasoning**: Implemented a loop allowing the agent to chain tool calls (e.g., Screen -> Check Social -> Recommend) before outputting a final decision. + +## ๐Ÿ› ๏ธ Infrastructure & Fixes + +### Dataflow & Routing +- **`interface.py`**: + - Updated `TOOLS_CATEGORIES` and `VENDOR_METHODS` to support new tools. + - **Fix**: Resolved a critical `SyntaxError` caused by a corrupted edit. +- **`agent_utils.py`**: + - **Fix**: Restored file integrity after it was corrupted during an edit. + - Added imports for all new screening tools. + +### Configuration +- **`.gitignore`**: Corrected to ensure `tradingagents/` source code is tracked by Git (removed accidental exclusion). +- **Dependencies**: Added `python-dotenv` to `requirements.txt` and `setup.py`. + +## ๐Ÿงช Verification +- Verified `main_screening.py` execution with the new loop logic. +- Confirmed fallback behavior for Social Sentiment tools (StockTwits -> Reddit). 
diff --git a/cli/main.py b/cli/main.py index 2e06d50c..166bfaef 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1,1109 +1,1109 @@ -from typing import Optional -import datetime -import typer -from pathlib import Path -from functools import wraps -from rich.console import Console -from dotenv import load_dotenv - -# Load environment variables from .env file -load_dotenv() -from rich.panel import Panel -from rich.spinner import Spinner -from rich.live import Live -from rich.columns import Columns -from rich.markdown import Markdown -from rich.layout import Layout -from rich.text import Text -from rich.live import Live -from rich.table import Table -from collections import deque -import time -from rich.tree import Tree -from rich import box -from rich.align import Align -from rich.rule import Rule - -from tradingagents.graph.trading_graph import TradingAgentsGraph -from tradingagents.default_config import DEFAULT_CONFIG -from cli.models import AnalystType -from cli.utils import * - -console = Console() - -app = typer.Typer( - name="TradingAgents", - help="TradingAgents CLI: Multi-Agents LLM Financial Trading Framework", - add_completion=True, # Enable shell completion -) - - -# Create a deque to store recent messages with a maximum length -class MessageBuffer: - def __init__(self, max_length=100): - self.messages = deque(maxlen=max_length) - self.tool_calls = deque(maxlen=max_length) - self.current_report = None - self.final_report = None # Store the complete final report - self.agent_status = { - # Analyst Team - "Market Analyst": "pending", - "Social Analyst": "pending", - "News Analyst": "pending", - "Fundamentals Analyst": "pending", - # Research Team - "Bull Researcher": "pending", - "Bear Researcher": "pending", - "Research Manager": "pending", - # Trading Team - "Trader": "pending", - # Risk Management Team - "Risky Analyst": "pending", - "Neutral Analyst": "pending", - "Safe Analyst": "pending", - # Portfolio Management Team - "Portfolio Manager": "pending", - } 
- self.current_agent = None - self.report_sections = { - "market_report": None, - "sentiment_report": None, - "news_report": None, - "fundamentals_report": None, - "investment_plan": None, - "trader_investment_plan": None, - "final_trade_decision": None, - } - - def add_message(self, message_type, content): - timestamp = datetime.datetime.now().strftime("%H:%M:%S") - self.messages.append((timestamp, message_type, content)) - - def add_tool_call(self, tool_name, args): - timestamp = datetime.datetime.now().strftime("%H:%M:%S") - self.tool_calls.append((timestamp, tool_name, args)) - - def update_agent_status(self, agent, status): - if agent in self.agent_status: - self.agent_status[agent] = status - self.current_agent = agent - - def update_report_section(self, section_name, content): - if section_name in self.report_sections: - self.report_sections[section_name] = content - self._update_current_report() - - def _update_current_report(self): - # For the panel display, only show the most recently updated section - latest_section = None - latest_content = None - - # Find the most recently updated section - for section, content in self.report_sections.items(): - if content is not None: - latest_section = section - latest_content = content - - if latest_section and latest_content: - # Format the current section for display - section_titles = { - "market_report": "Market Analysis", - "sentiment_report": "Social Sentiment", - "news_report": "News Analysis", - "fundamentals_report": "Fundamentals Analysis", - "investment_plan": "Research Team Decision", - "trader_investment_plan": "Trading Team Plan", - "final_trade_decision": "Portfolio Management Decision", - } - self.current_report = ( - f"### {section_titles[latest_section]}\n{latest_content}" - ) - - # Update the final complete report - self._update_final_report() - - def _update_final_report(self): - report_parts = [] - - # Analyst Team Reports - if any( - self.report_sections[section] - for section in [ - 
"market_report", - "sentiment_report", - "news_report", - "fundamentals_report", - ] - ): - report_parts.append("## Analyst Team Reports") - if self.report_sections["market_report"]: - report_parts.append( - f"### Market Analysis\n{self.report_sections['market_report']}" - ) - if self.report_sections["sentiment_report"]: - report_parts.append( - f"### Social Sentiment\n{self.report_sections['sentiment_report']}" - ) - if self.report_sections["news_report"]: - report_parts.append( - f"### News Analysis\n{self.report_sections['news_report']}" - ) - if self.report_sections["fundamentals_report"]: - report_parts.append( - f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}" - ) - - # Research Team Reports - if self.report_sections["investment_plan"]: - report_parts.append("## Research Team Decision") - report_parts.append(f"{self.report_sections['investment_plan']}") - - # Trading Team Reports - if self.report_sections["trader_investment_plan"]: - report_parts.append("## Trading Team Plan") - report_parts.append(f"{self.report_sections['trader_investment_plan']}") - - # Portfolio Management Decision - if self.report_sections["final_trade_decision"]: - report_parts.append("## Portfolio Management Decision") - report_parts.append(f"{self.report_sections['final_trade_decision']}") - - self.final_report = "\n\n".join(report_parts) if report_parts else None - - -message_buffer = MessageBuffer() - - -def create_layout(): - layout = Layout() - layout.split_column( - Layout(name="header", size=3), - Layout(name="main"), - Layout(name="footer", size=3), - ) - layout["main"].split_column( - Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5) - ) - layout["upper"].split_row( - Layout(name="progress", ratio=2), Layout(name="messages", ratio=3) - ) - return layout - - -def update_display(layout, spinner_text=None): - # Header with welcome message - layout["header"].update( - Panel( - "[bold green]Welcome to TradingAgents CLI[/bold green]\n" - 
"[dim]ยฉ [Tauric Research](https://github.com/TauricResearch)[/dim]", - title="Welcome to TradingAgents", - border_style="green", - padding=(1, 2), - expand=True, - ) - ) - - # Progress panel showing agent status - progress_table = Table( - show_header=True, - header_style="bold magenta", - show_footer=False, - box=box.SIMPLE_HEAD, # Use simple header with horizontal lines - title=None, # Remove the redundant Progress title - padding=(0, 2), # Add horizontal padding - expand=True, # Make table expand to fill available space - ) - progress_table.add_column("Team", style="cyan", justify="center", width=20) - progress_table.add_column("Agent", style="green", justify="center", width=20) - progress_table.add_column("Status", style="yellow", justify="center", width=20) - - # Group agents by team - teams = { - "Analyst Team": [ - "Market Analyst", - "Social Analyst", - "News Analyst", - "Fundamentals Analyst", - ], - "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"], - "Trading Team": ["Trader"], - "Risk Management": ["Risky Analyst", "Neutral Analyst", "Safe Analyst"], - "Portfolio Management": ["Portfolio Manager"], - } - - for team, agents in teams.items(): - # Add first agent with team name - first_agent = agents[0] - status = message_buffer.agent_status[first_agent] - if status == "in_progress": - spinner = Spinner( - "dots", text="[blue]in_progress[/blue]", style="bold cyan" - ) - status_cell = spinner - else: - status_color = { - "pending": "yellow", - "completed": "green", - "error": "red", - }.get(status, "white") - status_cell = f"[{status_color}]{status}[/{status_color}]" - progress_table.add_row(team, first_agent, status_cell) - - # Add remaining agents in team - for agent in agents[1:]: - status = message_buffer.agent_status[agent] - if status == "in_progress": - spinner = Spinner( - "dots", text="[blue]in_progress[/blue]", style="bold cyan" - ) - status_cell = spinner - else: - status_color = { - "pending": "yellow", - "completed": 
"green", - "error": "red", - }.get(status, "white") - status_cell = f"[{status_color}]{status}[/{status_color}]" - progress_table.add_row("", agent, status_cell) - - # Add horizontal line after each team - progress_table.add_row("โ”€" * 20, "โ”€" * 20, "โ”€" * 20, style="dim") - - layout["progress"].update( - Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2)) - ) - - # Messages panel showing recent messages and tool calls - messages_table = Table( - show_header=True, - header_style="bold magenta", - show_footer=False, - expand=True, # Make table expand to fill available space - box=box.MINIMAL, # Use minimal box style for a lighter look - show_lines=True, # Keep horizontal lines - padding=(0, 1), # Add some padding between columns - ) - messages_table.add_column("Time", style="cyan", width=8, justify="center") - messages_table.add_column("Type", style="green", width=10, justify="center") - messages_table.add_column( - "Content", style="white", no_wrap=False, ratio=1 - ) # Make content column expand - - # Combine tool calls and messages - all_messages = [] - - # Add tool calls - for timestamp, tool_name, args in message_buffer.tool_calls: - # Truncate tool call args if too long - if isinstance(args, str) and len(args) > 100: - args = args[:97] + "..." 
- all_messages.append((timestamp, "Tool", f"{tool_name}: {args}")) - - # Add regular messages - for timestamp, msg_type, content in message_buffer.messages: - # Convert content to string if it's not already - content_str = content - if isinstance(content, list): - # Handle list of content blocks (Anthropic format) - text_parts = [] - for item in content: - if isinstance(item, dict): - if item.get('type') == 'text': - text_parts.append(item.get('text', '')) - elif item.get('type') == 'tool_use': - text_parts.append(f"[Tool: {item.get('name', 'unknown')}]") - else: - text_parts.append(str(item)) - content_str = ' '.join(text_parts) - elif not isinstance(content_str, str): - content_str = str(content) - - # Truncate message content if too long - if len(content_str) > 200: - content_str = content_str[:197] + "..." - all_messages.append((timestamp, msg_type, content_str)) - - # Sort by timestamp - all_messages.sort(key=lambda x: x[0]) - - # Calculate how many messages we can show based on available space - # Start with a reasonable number and adjust based on content length - max_messages = 12 # Increased from 8 to better fill the space - - # Get the last N messages that will fit in the panel - recent_messages = all_messages[-max_messages:] - - # Add messages to table - for timestamp, msg_type, content in recent_messages: - # Format content with word wrapping - wrapped_content = Text(content, overflow="fold") - messages_table.add_row(timestamp, msg_type, wrapped_content) - - if spinner_text: - messages_table.add_row("", "Spinner", spinner_text) - - # Add a footer to indicate if messages were truncated - if len(all_messages) > max_messages: - messages_table.footer = ( - f"[dim]Showing last {max_messages} of {len(all_messages)} messages[/dim]" - ) - - layout["messages"].update( - Panel( - messages_table, - title="Messages & Tools", - border_style="blue", - padding=(1, 2), - ) - ) - - # Analysis panel showing current report - if message_buffer.current_report: - 
layout["analysis"].update( - Panel( - Markdown(message_buffer.current_report), - title="Current Report", - border_style="green", - padding=(1, 2), - ) - ) - else: - layout["analysis"].update( - Panel( - "[italic]Waiting for analysis report...[/italic]", - title="Current Report", - border_style="green", - padding=(1, 2), - ) - ) - - # Footer with statistics - tool_calls_count = len(message_buffer.tool_calls) - llm_calls_count = sum( - 1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning" - ) - reports_count = sum( - 1 for content in message_buffer.report_sections.values() if content is not None - ) - - stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True) - stats_table.add_column("Stats", justify="center") - stats_table.add_row( - f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}" - ) - - layout["footer"].update(Panel(stats_table, border_style="grey50")) - - -def get_user_selections(): - """Get all user selections before starting the analysis display.""" - # Display ASCII art welcome message - with open("./cli/static/welcome.txt", "r") as f: - welcome_ascii = f.read() - - # Create welcome box content - welcome_content = f"{welcome_ascii}\n" - welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n" - welcome_content += "[bold]Workflow Steps:[/bold]\n" - welcome_content += "I. Analyst Team โ†’ II. Research Team โ†’ III. Trader โ†’ IV. Risk Management โ†’ V. 
Portfolio Management\n\n" - welcome_content += ( - "[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]" - ) - - # Create and center the welcome box - welcome_box = Panel( - welcome_content, - border_style="green", - padding=(1, 2), - title="Welcome to TradingAgents", - subtitle="Multi-Agents LLM Financial Trading Framework", - ) - console.print(Align.center(welcome_box)) - console.print() # Add a blank line after the welcome box - - # Create a boxed questionnaire for each step - def create_question_box(title, prompt, default=None): - box_content = f"[bold]{title}[/bold]\n" - box_content += f"[dim]{prompt}[/dim]" - if default: - box_content += f"\n[dim]Default: {default}[/dim]" - return Panel(box_content, border_style="blue", padding=(1, 2)) - - # Step 1: Ticker symbol - console.print( - create_question_box( - "Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY" - ) - ) - selected_ticker = get_ticker() - - # Step 2: Analysis date - default_date = datetime.datetime.now().strftime("%Y-%m-%d") - console.print( - create_question_box( - "Step 2: Analysis Date", - "Enter the analysis date (YYYY-MM-DD)", - default_date, - ) - ) - analysis_date = get_analysis_date() - - # Step 3: Select analysts - console.print( - create_question_box( - "Step 3: Analysts Team", "Select your LLM analyst agents for the analysis" - ) - ) - selected_analysts = select_analysts() - console.print( - f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}" - ) - - # Step 4: Research depth - console.print( - create_question_box( - "Step 4: Research Depth", "Select your research depth level" - ) - ) - selected_research_depth = select_research_depth() - - # Step 5: OpenAI backend - console.print( - create_question_box( - "Step 5: OpenAI backend", "Select which service to talk to" - ) - ) - selected_llm_provider, backend_url = select_llm_provider() - - # Step 6: Thinking agents - console.print( - create_question_box( - "Step 
6: Thinking Agents", "Select your thinking agents for analysis" - ) - ) - selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) - selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider) - - return { - "ticker": selected_ticker, - "analysis_date": analysis_date, - "analysts": selected_analysts, - "research_depth": selected_research_depth, - "llm_provider": selected_llm_provider.lower(), - "backend_url": backend_url, - "shallow_thinker": selected_shallow_thinker, - "deep_thinker": selected_deep_thinker, - } - - -def get_ticker(): - """Get ticker symbol from user input.""" - return typer.prompt("", default="SPY") - - -def get_analysis_date(): - """Get the analysis date from user input.""" - while True: - date_str = typer.prompt( - "", default=datetime.datetime.now().strftime("%Y-%m-%d") - ) - try: - # Validate date format and ensure it's not in the future - analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d") - if analysis_date.date() > datetime.datetime.now().date(): - console.print("[red]Error: Analysis date cannot be in the future[/red]") - continue - return date_str - except ValueError: - console.print( - "[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]" - ) - - -def display_complete_report(final_state): - """Display the complete analysis report with team-based panels.""" - console.print("\n[bold green]Complete Analysis Report[/bold green]\n") - - # I. 
Analyst Team Reports - analyst_reports = [] - - # Market Analyst Report - if final_state.get("market_report"): - analyst_reports.append( - Panel( - Markdown(final_state["market_report"]), - title="Market Analyst", - border_style="blue", - padding=(1, 2), - ) - ) - - # Social Analyst Report - if final_state.get("sentiment_report"): - analyst_reports.append( - Panel( - Markdown(final_state["sentiment_report"]), - title="Social Analyst", - border_style="blue", - padding=(1, 2), - ) - ) - - # News Analyst Report - if final_state.get("news_report"): - analyst_reports.append( - Panel( - Markdown(final_state["news_report"]), - title="News Analyst", - border_style="blue", - padding=(1, 2), - ) - ) - - # Fundamentals Analyst Report - if final_state.get("fundamentals_report"): - analyst_reports.append( - Panel( - Markdown(final_state["fundamentals_report"]), - title="Fundamentals Analyst", - border_style="blue", - padding=(1, 2), - ) - ) - - if analyst_reports: - console.print( - Panel( - Columns(analyst_reports, equal=True, expand=True), - title="I. Analyst Team Reports", - border_style="cyan", - padding=(1, 2), - ) - ) - - # II. 
Research Team Reports - if final_state.get("investment_debate_state"): - research_reports = [] - debate_state = final_state["investment_debate_state"] - - # Bull Researcher Analysis - if debate_state.get("bull_history"): - research_reports.append( - Panel( - Markdown(debate_state["bull_history"]), - title="Bull Researcher", - border_style="blue", - padding=(1, 2), - ) - ) - - # Bear Researcher Analysis - if debate_state.get("bear_history"): - research_reports.append( - Panel( - Markdown(debate_state["bear_history"]), - title="Bear Researcher", - border_style="blue", - padding=(1, 2), - ) - ) - - # Research Manager Decision - if debate_state.get("judge_decision"): - research_reports.append( - Panel( - Markdown(debate_state["judge_decision"]), - title="Research Manager", - border_style="blue", - padding=(1, 2), - ) - ) - - if research_reports: - console.print( - Panel( - Columns(research_reports, equal=True, expand=True), - title="II. Research Team Decision", - border_style="magenta", - padding=(1, 2), - ) - ) - - # III. Trading Team Reports - if final_state.get("trader_investment_plan"): - console.print( - Panel( - Panel( - Markdown(final_state["trader_investment_plan"]), - title="Trader", - border_style="blue", - padding=(1, 2), - ), - title="III. Trading Team Plan", - border_style="yellow", - padding=(1, 2), - ) - ) - - # IV. 
Risk Management Team Reports - if final_state.get("risk_debate_state"): - risk_reports = [] - risk_state = final_state["risk_debate_state"] - - # Aggressive (Risky) Analyst Analysis - if risk_state.get("risky_history"): - risk_reports.append( - Panel( - Markdown(risk_state["risky_history"]), - title="Aggressive Analyst", - border_style="blue", - padding=(1, 2), - ) - ) - - # Conservative (Safe) Analyst Analysis - if risk_state.get("safe_history"): - risk_reports.append( - Panel( - Markdown(risk_state["safe_history"]), - title="Conservative Analyst", - border_style="blue", - padding=(1, 2), - ) - ) - - # Neutral Analyst Analysis - if risk_state.get("neutral_history"): - risk_reports.append( - Panel( - Markdown(risk_state["neutral_history"]), - title="Neutral Analyst", - border_style="blue", - padding=(1, 2), - ) - ) - - if risk_reports: - console.print( - Panel( - Columns(risk_reports, equal=True, expand=True), - title="IV. Risk Management Team Decision", - border_style="red", - padding=(1, 2), - ) - ) - - # V. Portfolio Manager Decision - if risk_state.get("judge_decision"): - console.print( - Panel( - Panel( - Markdown(risk_state["judge_decision"]), - title="Portfolio Manager", - border_style="blue", - padding=(1, 2), - ), - title="V. 
Portfolio Manager Decision", - border_style="green", - padding=(1, 2), - ) - ) - - -def update_research_team_status(status): - """Update status for all research team members and trader.""" - research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"] - for agent in research_team: - message_buffer.update_agent_status(agent, status) - -def extract_content_string(content): - """Extract string content from various message formats.""" - if isinstance(content, str): - return content - elif isinstance(content, list): - # Handle Anthropic's list format - text_parts = [] - for item in content: - if isinstance(item, dict): - if item.get('type') == 'text': - text_parts.append(item.get('text', '')) - elif item.get('type') == 'tool_use': - text_parts.append(f"[Tool: {item.get('name', 'unknown')}]") - else: - text_parts.append(str(item)) - return ' '.join(text_parts) - else: - return str(content) - -def run_analysis(): - # First get all user selections - selections = get_user_selections() - - # Create config with selected research depth - config = DEFAULT_CONFIG.copy() - config["max_debate_rounds"] = selections["research_depth"] - config["max_risk_discuss_rounds"] = selections["research_depth"] - config["quick_think_llm"] = selections["shallow_thinker"] - config["deep_think_llm"] = selections["deep_thinker"] - config["backend_url"] = selections["backend_url"] - config["llm_provider"] = selections["llm_provider"].lower() - - # Initialize the graph - graph = TradingAgentsGraph( - [analyst.value for analyst in selections["analysts"]], config=config, debug=True - ) - - # Create result directory - results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"] - results_dir.mkdir(parents=True, exist_ok=True) - report_dir = results_dir / "reports" - report_dir.mkdir(parents=True, exist_ok=True) - log_file = results_dir / "message_tool.log" - log_file.touch(exist_ok=True) - - def save_message_decorator(obj, func_name): - func = 
getattr(obj, func_name) - @wraps(func) - def wrapper(*args, **kwargs): - func(*args, **kwargs) - timestamp, message_type, content = obj.messages[-1] - content = content.replace("\n", " ") # Replace newlines with spaces - with open(log_file, "a") as f: - f.write(f"{timestamp} [{message_type}] {content}\n") - return wrapper - - def save_tool_call_decorator(obj, func_name): - func = getattr(obj, func_name) - @wraps(func) - def wrapper(*args, **kwargs): - func(*args, **kwargs) - timestamp, tool_name, args = obj.tool_calls[-1] - args_str = ", ".join(f"{k}={v}" for k, v in args.items()) - with open(log_file, "a") as f: - f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n") - return wrapper - - def save_report_section_decorator(obj, func_name): - func = getattr(obj, func_name) - @wraps(func) - def wrapper(section_name, content): - func(section_name, content) - if section_name in obj.report_sections and obj.report_sections[section_name] is not None: - content = obj.report_sections[section_name] - if content: - file_name = f"{section_name}.md" - with open(report_dir / file_name, "w") as f: - f.write(content) - return wrapper - - message_buffer.add_message = save_message_decorator(message_buffer, "add_message") - message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call") - message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section") - - # Now start the display layout - layout = create_layout() - - with Live(layout, refresh_per_second=4) as live: - # Initial display - update_display(layout) - - # Add initial messages - message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}") - message_buffer.add_message( - "System", f"Analysis date: {selections['analysis_date']}" - ) - message_buffer.add_message( - "System", - f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}", - ) - update_display(layout) - - # Reset agent statuses - for agent in 
message_buffer.agent_status: - message_buffer.update_agent_status(agent, "pending") - - # Reset report sections - for section in message_buffer.report_sections: - message_buffer.report_sections[section] = None - message_buffer.current_report = None - message_buffer.final_report = None - - # Update agent status to in_progress for the first analyst - first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst" - message_buffer.update_agent_status(first_analyst, "in_progress") - update_display(layout) - - # Create spinner text - spinner_text = ( - f"Analyzing {selections['ticker']} on {selections['analysis_date']}..." - ) - update_display(layout, spinner_text) - - # Initialize state and get graph args - init_agent_state = graph.propagator.create_initial_state( - selections["ticker"], selections["analysis_date"] - ) - args = graph.propagator.get_graph_args() - - # Stream the analysis - trace = [] - for chunk in graph.graph.stream(init_agent_state, **args): - if len(chunk["messages"]) > 0: - # Get the last message from the chunk - last_message = chunk["messages"][-1] - - # Extract message content and type - if hasattr(last_message, "content"): - content = extract_content_string(last_message.content) # Use the helper function - msg_type = "Reasoning" - else: - content = str(last_message) - msg_type = "System" - - # Add message to buffer - message_buffer.add_message(msg_type, content) - - # If it's a tool call, add it to tool calls - if hasattr(last_message, "tool_calls"): - for tool_call in last_message.tool_calls: - # Handle both dictionary and object tool calls - if isinstance(tool_call, dict): - message_buffer.add_tool_call( - tool_call["name"], tool_call["args"] - ) - else: - message_buffer.add_tool_call(tool_call.name, tool_call.args) - - # Update reports and agent status based on chunk content - # Analyst Team Reports - if "market_report" in chunk and chunk["market_report"]: - message_buffer.update_report_section( - "market_report", 
chunk["market_report"] - ) - message_buffer.update_agent_status("Market Analyst", "completed") - # Set next analyst to in_progress - if "social" in selections["analysts"]: - message_buffer.update_agent_status( - "Social Analyst", "in_progress" - ) - - if "sentiment_report" in chunk and chunk["sentiment_report"]: - message_buffer.update_report_section( - "sentiment_report", chunk["sentiment_report"] - ) - message_buffer.update_agent_status("Social Analyst", "completed") - # Set next analyst to in_progress - if "news" in selections["analysts"]: - message_buffer.update_agent_status( - "News Analyst", "in_progress" - ) - - if "news_report" in chunk and chunk["news_report"]: - message_buffer.update_report_section( - "news_report", chunk["news_report"] - ) - message_buffer.update_agent_status("News Analyst", "completed") - # Set next analyst to in_progress - if "fundamentals" in selections["analysts"]: - message_buffer.update_agent_status( - "Fundamentals Analyst", "in_progress" - ) - - if "fundamentals_report" in chunk and chunk["fundamentals_report"]: - message_buffer.update_report_section( - "fundamentals_report", chunk["fundamentals_report"] - ) - message_buffer.update_agent_status( - "Fundamentals Analyst", "completed" - ) - # Set all research team members to in_progress - update_research_team_status("in_progress") - - # Research Team - Handle Investment Debate State - if ( - "investment_debate_state" in chunk - and chunk["investment_debate_state"] - ): - debate_state = chunk["investment_debate_state"] - - # Update Bull Researcher status and report - if "bull_history" in debate_state and debate_state["bull_history"]: - # Keep all research team members in progress - update_research_team_status("in_progress") - # Extract latest bull response - bull_responses = debate_state["bull_history"].split("\n") - latest_bull = bull_responses[-1] if bull_responses else "" - if latest_bull: - message_buffer.add_message("Reasoning", latest_bull) - # Update research report with 
bull's latest analysis - message_buffer.update_report_section( - "investment_plan", - f"### Bull Researcher Analysis\n{latest_bull}", - ) - - # Update Bear Researcher status and report - if "bear_history" in debate_state and debate_state["bear_history"]: - # Keep all research team members in progress - update_research_team_status("in_progress") - # Extract latest bear response - bear_responses = debate_state["bear_history"].split("\n") - latest_bear = bear_responses[-1] if bear_responses else "" - if latest_bear: - message_buffer.add_message("Reasoning", latest_bear) - # Update research report with bear's latest analysis - message_buffer.update_report_section( - "investment_plan", - f"{message_buffer.report_sections['investment_plan']}\n\n### Bear Researcher Analysis\n{latest_bear}", - ) - - # Update Research Manager status and final decision - if ( - "judge_decision" in debate_state - and debate_state["judge_decision"] - ): - # Keep all research team members in progress until final decision - update_research_team_status("in_progress") - message_buffer.add_message( - "Reasoning", - f"Research Manager: {debate_state['judge_decision']}", - ) - # Update research report with final decision - message_buffer.update_report_section( - "investment_plan", - f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}", - ) - # Mark all research team members as completed - update_research_team_status("completed") - # Set first risk analyst to in_progress - message_buffer.update_agent_status( - "Risky Analyst", "in_progress" - ) - - # Trading Team - if ( - "trader_investment_plan" in chunk - and chunk["trader_investment_plan"] - ): - message_buffer.update_report_section( - "trader_investment_plan", chunk["trader_investment_plan"] - ) - # Set first risk analyst to in_progress - message_buffer.update_agent_status("Risky Analyst", "in_progress") - - # Risk Management Team - Handle Risk Debate State - if 
"risk_debate_state" in chunk and chunk["risk_debate_state"]: - risk_state = chunk["risk_debate_state"] - - # Update Risky Analyst status and report - if ( - "current_risky_response" in risk_state - and risk_state["current_risky_response"] - ): - message_buffer.update_agent_status( - "Risky Analyst", "in_progress" - ) - message_buffer.add_message( - "Reasoning", - f"Risky Analyst: {risk_state['current_risky_response']}", - ) - # Update risk report with risky analyst's latest analysis only - message_buffer.update_report_section( - "final_trade_decision", - f"### Risky Analyst Analysis\n{risk_state['current_risky_response']}", - ) - - # Update Safe Analyst status and report - if ( - "current_safe_response" in risk_state - and risk_state["current_safe_response"] - ): - message_buffer.update_agent_status( - "Safe Analyst", "in_progress" - ) - message_buffer.add_message( - "Reasoning", - f"Safe Analyst: {risk_state['current_safe_response']}", - ) - # Update risk report with safe analyst's latest analysis only - message_buffer.update_report_section( - "final_trade_decision", - f"### Safe Analyst Analysis\n{risk_state['current_safe_response']}", - ) - - # Update Neutral Analyst status and report - if ( - "current_neutral_response" in risk_state - and risk_state["current_neutral_response"] - ): - message_buffer.update_agent_status( - "Neutral Analyst", "in_progress" - ) - message_buffer.add_message( - "Reasoning", - f"Neutral Analyst: {risk_state['current_neutral_response']}", - ) - # Update risk report with neutral analyst's latest analysis only - message_buffer.update_report_section( - "final_trade_decision", - f"### Neutral Analyst Analysis\n{risk_state['current_neutral_response']}", - ) - - # Update Portfolio Manager status and final decision - if "judge_decision" in risk_state and risk_state["judge_decision"]: - message_buffer.update_agent_status( - "Portfolio Manager", "in_progress" - ) - message_buffer.add_message( - "Reasoning", - f"Portfolio Manager: 
{risk_state['judge_decision']}", - ) - # Update risk report with final decision only - message_buffer.update_report_section( - "final_trade_decision", - f"### Portfolio Manager Decision\n{risk_state['judge_decision']}", - ) - # Mark risk analysts as completed - message_buffer.update_agent_status("Risky Analyst", "completed") - message_buffer.update_agent_status("Safe Analyst", "completed") - message_buffer.update_agent_status( - "Neutral Analyst", "completed" - ) - message_buffer.update_agent_status( - "Portfolio Manager", "completed" - ) - - # Update the display - update_display(layout) - - trace.append(chunk) - - # Get final state and decision - final_state = trace[-1] - decision = graph.process_signal(final_state["final_trade_decision"]) - - # Update all agent statuses to completed - for agent in message_buffer.agent_status: - message_buffer.update_agent_status(agent, "completed") - - message_buffer.add_message( - "Analysis", f"Completed analysis for {selections['analysis_date']}" - ) - - # Update final report sections - for section in message_buffer.report_sections.keys(): - if section in final_state: - message_buffer.update_report_section(section, final_state[section]) - - # Display the complete final report - display_complete_report(final_state) - - update_display(layout) - - -@app.command() -def analyze(): - run_analysis() - - -if __name__ == "__main__": - app() +from typing import Optional +import datetime +import typer +from pathlib import Path +from functools import wraps +from rich.console import Console +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() +from rich.panel import Panel +from rich.spinner import Spinner +from rich.live import Live +from rich.columns import Columns +from rich.markdown import Markdown +from rich.layout import Layout +from rich.text import Text +from rich.live import Live +from rich.table import Table +from collections import deque +import time +from rich.tree import Tree +from rich 
import box +from rich.align import Align +from rich.rule import Rule + +from tradingagents.graph.trading_graph import TradingAgentsGraph +from tradingagents.default_config import DEFAULT_CONFIG +from cli.models import AnalystType +from cli.utils import * + +console = Console() + +app = typer.Typer( + name="TradingAgents", + help="TradingAgents CLI: Multi-Agents LLM Financial Trading Framework", + add_completion=True, # Enable shell completion +) + + +# Create a deque to store recent messages with a maximum length +class MessageBuffer: + def __init__(self, max_length=100): + self.messages = deque(maxlen=max_length) + self.tool_calls = deque(maxlen=max_length) + self.current_report = None + self.final_report = None # Store the complete final report + self.agent_status = { + # Analyst Team + "Market Analyst": "pending", + "Social Analyst": "pending", + "News Analyst": "pending", + "Fundamentals Analyst": "pending", + # Research Team + "Bull Researcher": "pending", + "Bear Researcher": "pending", + "Research Manager": "pending", + # Trading Team + "Trader": "pending", + # Risk Management Team + "Risky Analyst": "pending", + "Neutral Analyst": "pending", + "Safe Analyst": "pending", + # Portfolio Management Team + "Portfolio Manager": "pending", + } + self.current_agent = None + self.report_sections = { + "market_report": None, + "sentiment_report": None, + "news_report": None, + "fundamentals_report": None, + "investment_plan": None, + "trader_investment_plan": None, + "final_trade_decision": None, + } + + def add_message(self, message_type, content): + timestamp = datetime.datetime.now().strftime("%H:%M:%S") + self.messages.append((timestamp, message_type, content)) + + def add_tool_call(self, tool_name, args): + timestamp = datetime.datetime.now().strftime("%H:%M:%S") + self.tool_calls.append((timestamp, tool_name, args)) + + def update_agent_status(self, agent, status): + if agent in self.agent_status: + self.agent_status[agent] = status + self.current_agent = agent 
+ + def update_report_section(self, section_name, content): + if section_name in self.report_sections: + self.report_sections[section_name] = content + self._update_current_report() + + def _update_current_report(self): + # For the panel display, only show the most recently updated section + latest_section = None + latest_content = None + + # Find the most recently updated section + for section, content in self.report_sections.items(): + if content is not None: + latest_section = section + latest_content = content + + if latest_section and latest_content: + # Format the current section for display + section_titles = { + "market_report": "Market Analysis", + "sentiment_report": "Social Sentiment", + "news_report": "News Analysis", + "fundamentals_report": "Fundamentals Analysis", + "investment_plan": "Research Team Decision", + "trader_investment_plan": "Trading Team Plan", + "final_trade_decision": "Portfolio Management Decision", + } + self.current_report = ( + f"### {section_titles[latest_section]}\n{latest_content}" + ) + + # Update the final complete report + self._update_final_report() + + def _update_final_report(self): + report_parts = [] + + # Analyst Team Reports + if any( + self.report_sections[section] + for section in [ + "market_report", + "sentiment_report", + "news_report", + "fundamentals_report", + ] + ): + report_parts.append("## Analyst Team Reports") + if self.report_sections["market_report"]: + report_parts.append( + f"### Market Analysis\n{self.report_sections['market_report']}" + ) + if self.report_sections["sentiment_report"]: + report_parts.append( + f"### Social Sentiment\n{self.report_sections['sentiment_report']}" + ) + if self.report_sections["news_report"]: + report_parts.append( + f"### News Analysis\n{self.report_sections['news_report']}" + ) + if self.report_sections["fundamentals_report"]: + report_parts.append( + f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}" + ) + + # Research Team Reports + if 
self.report_sections["investment_plan"]: + report_parts.append("## Research Team Decision") + report_parts.append(f"{self.report_sections['investment_plan']}") + + # Trading Team Reports + if self.report_sections["trader_investment_plan"]: + report_parts.append("## Trading Team Plan") + report_parts.append(f"{self.report_sections['trader_investment_plan']}") + + # Portfolio Management Decision + if self.report_sections["final_trade_decision"]: + report_parts.append("## Portfolio Management Decision") + report_parts.append(f"{self.report_sections['final_trade_decision']}") + + self.final_report = "\n\n".join(report_parts) if report_parts else None + + +message_buffer = MessageBuffer() + + +def create_layout(): + layout = Layout() + layout.split_column( + Layout(name="header", size=3), + Layout(name="main"), + Layout(name="footer", size=3), + ) + layout["main"].split_column( + Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5) + ) + layout["upper"].split_row( + Layout(name="progress", ratio=2), Layout(name="messages", ratio=3) + ) + return layout + + +def update_display(layout, spinner_text=None): + # Header with welcome message + layout["header"].update( + Panel( + "[bold green]Welcome to TradingAgents CLI[/bold green]\n" + "[dim]ยฉ [Tauric Research](https://github.com/TauricResearch)[/dim]", + title="Welcome to TradingAgents", + border_style="green", + padding=(1, 2), + expand=True, + ) + ) + + # Progress panel showing agent status + progress_table = Table( + show_header=True, + header_style="bold magenta", + show_footer=False, + box=box.SIMPLE_HEAD, # Use simple header with horizontal lines + title=None, # Remove the redundant Progress title + padding=(0, 2), # Add horizontal padding + expand=True, # Make table expand to fill available space + ) + progress_table.add_column("Team", style="cyan", justify="center", width=20) + progress_table.add_column("Agent", style="green", justify="center", width=20) + progress_table.add_column("Status", 
style="yellow", justify="center", width=20) + + # Group agents by team + teams = { + "Analyst Team": [ + "Market Analyst", + "Social Analyst", + "News Analyst", + "Fundamentals Analyst", + ], + "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"], + "Trading Team": ["Trader"], + "Risk Management": ["Risky Analyst", "Neutral Analyst", "Safe Analyst"], + "Portfolio Management": ["Portfolio Manager"], + } + + for team, agents in teams.items(): + # Add first agent with team name + first_agent = agents[0] + status = message_buffer.agent_status[first_agent] + if status == "in_progress": + spinner = Spinner( + "dots", text="[blue]in_progress[/blue]", style="bold cyan" + ) + status_cell = spinner + else: + status_color = { + "pending": "yellow", + "completed": "green", + "error": "red", + }.get(status, "white") + status_cell = f"[{status_color}]{status}[/{status_color}]" + progress_table.add_row(team, first_agent, status_cell) + + # Add remaining agents in team + for agent in agents[1:]: + status = message_buffer.agent_status[agent] + if status == "in_progress": + spinner = Spinner( + "dots", text="[blue]in_progress[/blue]", style="bold cyan" + ) + status_cell = spinner + else: + status_color = { + "pending": "yellow", + "completed": "green", + "error": "red", + }.get(status, "white") + status_cell = f"[{status_color}]{status}[/{status_color}]" + progress_table.add_row("", agent, status_cell) + + # Add horizontal line after each team + progress_table.add_row("โ”€" * 20, "โ”€" * 20, "โ”€" * 20, style="dim") + + layout["progress"].update( + Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2)) + ) + + # Messages panel showing recent messages and tool calls + messages_table = Table( + show_header=True, + header_style="bold magenta", + show_footer=False, + expand=True, # Make table expand to fill available space + box=box.MINIMAL, # Use minimal box style for a lighter look + show_lines=True, # Keep horizontal lines + padding=(0, 
1), # Add some padding between columns + ) + messages_table.add_column("Time", style="cyan", width=8, justify="center") + messages_table.add_column("Type", style="green", width=10, justify="center") + messages_table.add_column( + "Content", style="white", no_wrap=False, ratio=1 + ) # Make content column expand + + # Combine tool calls and messages + all_messages = [] + + # Add tool calls + for timestamp, tool_name, args in message_buffer.tool_calls: + # Truncate tool call args if too long + if isinstance(args, str) and len(args) > 100: + args = args[:97] + "..." + all_messages.append((timestamp, "Tool", f"{tool_name}: {args}")) + + # Add regular messages + for timestamp, msg_type, content in message_buffer.messages: + # Convert content to string if it's not already + content_str = content + if isinstance(content, list): + # Handle list of content blocks (Anthropic format) + text_parts = [] + for item in content: + if isinstance(item, dict): + if item.get('type') == 'text': + text_parts.append(item.get('text', '')) + elif item.get('type') == 'tool_use': + text_parts.append(f"[Tool: {item.get('name', 'unknown')}]") + else: + text_parts.append(str(item)) + content_str = ' '.join(text_parts) + elif not isinstance(content_str, str): + content_str = str(content) + + # Truncate message content if too long + if len(content_str) > 200: + content_str = content_str[:197] + "..." 
+ all_messages.append((timestamp, msg_type, content_str)) + + # Sort by timestamp + all_messages.sort(key=lambda x: x[0]) + + # Calculate how many messages we can show based on available space + # Start with a reasonable number and adjust based on content length + max_messages = 12 # Increased from 8 to better fill the space + + # Get the last N messages that will fit in the panel + recent_messages = all_messages[-max_messages:] + + # Add messages to table + for timestamp, msg_type, content in recent_messages: + # Format content with word wrapping + wrapped_content = Text(content, overflow="fold") + messages_table.add_row(timestamp, msg_type, wrapped_content) + + if spinner_text: + messages_table.add_row("", "Spinner", spinner_text) + + # Add a footer to indicate if messages were truncated + if len(all_messages) > max_messages: + messages_table.footer = ( + f"[dim]Showing last {max_messages} of {len(all_messages)} messages[/dim]" + ) + + layout["messages"].update( + Panel( + messages_table, + title="Messages & Tools", + border_style="blue", + padding=(1, 2), + ) + ) + + # Analysis panel showing current report + if message_buffer.current_report: + layout["analysis"].update( + Panel( + Markdown(message_buffer.current_report), + title="Current Report", + border_style="green", + padding=(1, 2), + ) + ) + else: + layout["analysis"].update( + Panel( + "[italic]Waiting for analysis report...[/italic]", + title="Current Report", + border_style="green", + padding=(1, 2), + ) + ) + + # Footer with statistics + tool_calls_count = len(message_buffer.tool_calls) + llm_calls_count = sum( + 1 for _, msg_type, _ in message_buffer.messages if msg_type == "Reasoning" + ) + reports_count = sum( + 1 for content in message_buffer.report_sections.values() if content is not None + ) + + stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True) + stats_table.add_column("Stats", justify="center") + stats_table.add_row( + f"Tool Calls: {tool_calls_count} | LLM Calls: 
{llm_calls_count} | Generated Reports: {reports_count}" + ) + + layout["footer"].update(Panel(stats_table, border_style="grey50")) + + +def get_user_selections(): + """Get all user selections before starting the analysis display.""" + # Display ASCII art welcome message + with open("./cli/static/welcome.txt", "r") as f: + welcome_ascii = f.read() + + # Create welcome box content + welcome_content = f"{welcome_ascii}\n" + welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n" + welcome_content += "[bold]Workflow Steps:[/bold]\n" + welcome_content += "I. Analyst Team โ†’ II. Research Team โ†’ III. Trader โ†’ IV. Risk Management โ†’ V. Portfolio Management\n\n" + welcome_content += ( + "[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]" + ) + + # Create and center the welcome box + welcome_box = Panel( + welcome_content, + border_style="green", + padding=(1, 2), + title="Welcome to TradingAgents", + subtitle="Multi-Agents LLM Financial Trading Framework", + ) + console.print(Align.center(welcome_box)) + console.print() # Add a blank line after the welcome box + + # Create a boxed questionnaire for each step + def create_question_box(title, prompt, default=None): + box_content = f"[bold]{title}[/bold]\n" + box_content += f"[dim]{prompt}[/dim]" + if default: + box_content += f"\n[dim]Default: {default}[/dim]" + return Panel(box_content, border_style="blue", padding=(1, 2)) + + # Step 1: Ticker symbol + console.print( + create_question_box( + "Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY" + ) + ) + selected_ticker = get_ticker() + + # Step 2: Analysis date + default_date = datetime.datetime.now().strftime("%Y-%m-%d") + console.print( + create_question_box( + "Step 2: Analysis Date", + "Enter the analysis date (YYYY-MM-DD)", + default_date, + ) + ) + analysis_date = get_analysis_date() + + # Step 3: Select analysts + console.print( + create_question_box( + "Step 3: 
Analysts Team", "Select your LLM analyst agents for the analysis" + ) + ) + selected_analysts = select_analysts() + console.print( + f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}" + ) + + # Step 4: Research depth + console.print( + create_question_box( + "Step 4: Research Depth", "Select your research depth level" + ) + ) + selected_research_depth = select_research_depth() + + # Step 5: OpenAI backend + console.print( + create_question_box( + "Step 5: OpenAI backend", "Select which service to talk to" + ) + ) + selected_llm_provider, backend_url = select_llm_provider() + + # Step 6: Thinking agents + console.print( + create_question_box( + "Step 6: Thinking Agents", "Select your thinking agents for analysis" + ) + ) + selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) + selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider) + + return { + "ticker": selected_ticker, + "analysis_date": analysis_date, + "analysts": selected_analysts, + "research_depth": selected_research_depth, + "llm_provider": selected_llm_provider.lower(), + "backend_url": backend_url, + "shallow_thinker": selected_shallow_thinker, + "deep_thinker": selected_deep_thinker, + } + + +def get_ticker(): + """Get ticker symbol from user input.""" + return typer.prompt("", default="SPY") + + +def get_analysis_date(): + """Get the analysis date from user input.""" + while True: + date_str = typer.prompt( + "", default=datetime.datetime.now().strftime("%Y-%m-%d") + ) + try: + # Validate date format and ensure it's not in the future + analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d") + if analysis_date.date() > datetime.datetime.now().date(): + console.print("[red]Error: Analysis date cannot be in the future[/red]") + continue + return date_str + except ValueError: + console.print( + "[red]Error: Invalid date format. 
Please use YYYY-MM-DD[/red]" + ) + + +def display_complete_report(final_state): + """Display the complete analysis report with team-based panels.""" + console.print("\n[bold green]Complete Analysis Report[/bold green]\n") + + # I. Analyst Team Reports + analyst_reports = [] + + # Market Analyst Report + if final_state.get("market_report"): + analyst_reports.append( + Panel( + Markdown(final_state["market_report"]), + title="Market Analyst", + border_style="blue", + padding=(1, 2), + ) + ) + + # Social Analyst Report + if final_state.get("sentiment_report"): + analyst_reports.append( + Panel( + Markdown(final_state["sentiment_report"]), + title="Social Analyst", + border_style="blue", + padding=(1, 2), + ) + ) + + # News Analyst Report + if final_state.get("news_report"): + analyst_reports.append( + Panel( + Markdown(final_state["news_report"]), + title="News Analyst", + border_style="blue", + padding=(1, 2), + ) + ) + + # Fundamentals Analyst Report + if final_state.get("fundamentals_report"): + analyst_reports.append( + Panel( + Markdown(final_state["fundamentals_report"]), + title="Fundamentals Analyst", + border_style="blue", + padding=(1, 2), + ) + ) + + if analyst_reports: + console.print( + Panel( + Columns(analyst_reports, equal=True, expand=True), + title="I. Analyst Team Reports", + border_style="cyan", + padding=(1, 2), + ) + ) + + # II. 
Research Team Reports + if final_state.get("investment_debate_state"): + research_reports = [] + debate_state = final_state["investment_debate_state"] + + # Bull Researcher Analysis + if debate_state.get("bull_history"): + research_reports.append( + Panel( + Markdown(debate_state["bull_history"]), + title="Bull Researcher", + border_style="blue", + padding=(1, 2), + ) + ) + + # Bear Researcher Analysis + if debate_state.get("bear_history"): + research_reports.append( + Panel( + Markdown(debate_state["bear_history"]), + title="Bear Researcher", + border_style="blue", + padding=(1, 2), + ) + ) + + # Research Manager Decision + if debate_state.get("judge_decision"): + research_reports.append( + Panel( + Markdown(debate_state["judge_decision"]), + title="Research Manager", + border_style="blue", + padding=(1, 2), + ) + ) + + if research_reports: + console.print( + Panel( + Columns(research_reports, equal=True, expand=True), + title="II. Research Team Decision", + border_style="magenta", + padding=(1, 2), + ) + ) + + # III. Trading Team Reports + if final_state.get("trader_investment_plan"): + console.print( + Panel( + Panel( + Markdown(final_state["trader_investment_plan"]), + title="Trader", + border_style="blue", + padding=(1, 2), + ), + title="III. Trading Team Plan", + border_style="yellow", + padding=(1, 2), + ) + ) + + # IV. 
Risk Management Team Reports + if final_state.get("risk_debate_state"): + risk_reports = [] + risk_state = final_state["risk_debate_state"] + + # Aggressive (Risky) Analyst Analysis + if risk_state.get("risky_history"): + risk_reports.append( + Panel( + Markdown(risk_state["risky_history"]), + title="Aggressive Analyst", + border_style="blue", + padding=(1, 2), + ) + ) + + # Conservative (Safe) Analyst Analysis + if risk_state.get("safe_history"): + risk_reports.append( + Panel( + Markdown(risk_state["safe_history"]), + title="Conservative Analyst", + border_style="blue", + padding=(1, 2), + ) + ) + + # Neutral Analyst Analysis + if risk_state.get("neutral_history"): + risk_reports.append( + Panel( + Markdown(risk_state["neutral_history"]), + title="Neutral Analyst", + border_style="blue", + padding=(1, 2), + ) + ) + + if risk_reports: + console.print( + Panel( + Columns(risk_reports, equal=True, expand=True), + title="IV. Risk Management Team Decision", + border_style="red", + padding=(1, 2), + ) + ) + + # V. Portfolio Manager Decision + if risk_state.get("judge_decision"): + console.print( + Panel( + Panel( + Markdown(risk_state["judge_decision"]), + title="Portfolio Manager", + border_style="blue", + padding=(1, 2), + ), + title="V. 
Portfolio Manager Decision", + border_style="green", + padding=(1, 2), + ) + ) + + +def update_research_team_status(status): + """Update status for all research team members and trader.""" + research_team = ["Bull Researcher", "Bear Researcher", "Research Manager", "Trader"] + for agent in research_team: + message_buffer.update_agent_status(agent, status) + +def extract_content_string(content): + """Extract string content from various message formats.""" + if isinstance(content, str): + return content + elif isinstance(content, list): + # Handle Anthropic's list format + text_parts = [] + for item in content: + if isinstance(item, dict): + if item.get('type') == 'text': + text_parts.append(item.get('text', '')) + elif item.get('type') == 'tool_use': + text_parts.append(f"[Tool: {item.get('name', 'unknown')}]") + else: + text_parts.append(str(item)) + return ' '.join(text_parts) + else: + return str(content) + +def run_analysis(): + # First get all user selections + selections = get_user_selections() + + # Create config with selected research depth + config = DEFAULT_CONFIG.copy() + config["max_debate_rounds"] = selections["research_depth"] + config["max_risk_discuss_rounds"] = selections["research_depth"] + config["quick_think_llm"] = selections["shallow_thinker"] + config["deep_think_llm"] = selections["deep_thinker"] + config["backend_url"] = selections["backend_url"] + config["llm_provider"] = selections["llm_provider"].lower() + + # Initialize the graph + graph = TradingAgentsGraph( + [analyst.value for analyst in selections["analysts"]], config=config, debug=True + ) + + # Create result directory + results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"] + results_dir.mkdir(parents=True, exist_ok=True) + report_dir = results_dir / "reports" + report_dir.mkdir(parents=True, exist_ok=True) + log_file = results_dir / "message_tool.log" + log_file.touch(exist_ok=True) + + def save_message_decorator(obj, func_name): + func = 
getattr(obj, func_name) + @wraps(func) + def wrapper(*args, **kwargs): + func(*args, **kwargs) + timestamp, message_type, content = obj.messages[-1] + content = content.replace("\n", " ") # Replace newlines with spaces + with open(log_file, "a") as f: + f.write(f"{timestamp} [{message_type}] {content}\n") + return wrapper + + def save_tool_call_decorator(obj, func_name): + func = getattr(obj, func_name) + @wraps(func) + def wrapper(*args, **kwargs): + func(*args, **kwargs) + timestamp, tool_name, args = obj.tool_calls[-1] + args_str = ", ".join(f"{k}={v}" for k, v in args.items()) + with open(log_file, "a") as f: + f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n") + return wrapper + + def save_report_section_decorator(obj, func_name): + func = getattr(obj, func_name) + @wraps(func) + def wrapper(section_name, content): + func(section_name, content) + if section_name in obj.report_sections and obj.report_sections[section_name] is not None: + content = obj.report_sections[section_name] + if content: + file_name = f"{section_name}.md" + with open(report_dir / file_name, "w") as f: + f.write(content) + return wrapper + + message_buffer.add_message = save_message_decorator(message_buffer, "add_message") + message_buffer.add_tool_call = save_tool_call_decorator(message_buffer, "add_tool_call") + message_buffer.update_report_section = save_report_section_decorator(message_buffer, "update_report_section") + + # Now start the display layout + layout = create_layout() + + with Live(layout, refresh_per_second=4) as live: + # Initial display + update_display(layout) + + # Add initial messages + message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}") + message_buffer.add_message( + "System", f"Analysis date: {selections['analysis_date']}" + ) + message_buffer.add_message( + "System", + f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}", + ) + update_display(layout) + + # Reset agent statuses + for agent in 
message_buffer.agent_status: + message_buffer.update_agent_status(agent, "pending") + + # Reset report sections + for section in message_buffer.report_sections: + message_buffer.report_sections[section] = None + message_buffer.current_report = None + message_buffer.final_report = None + + # Update agent status to in_progress for the first analyst + first_analyst = f"{selections['analysts'][0].value.capitalize()} Analyst" + message_buffer.update_agent_status(first_analyst, "in_progress") + update_display(layout) + + # Create spinner text + spinner_text = ( + f"Analyzing {selections['ticker']} on {selections['analysis_date']}..." + ) + update_display(layout, spinner_text) + + # Initialize state and get graph args + init_agent_state = graph.propagator.create_initial_state( + selections["ticker"], selections["analysis_date"] + ) + args = graph.propagator.get_graph_args() + + # Stream the analysis + trace = [] + for chunk in graph.graph.stream(init_agent_state, **args): + if len(chunk["messages"]) > 0: + # Get the last message from the chunk + last_message = chunk["messages"][-1] + + # Extract message content and type + if hasattr(last_message, "content"): + content = extract_content_string(last_message.content) # Use the helper function + msg_type = "Reasoning" + else: + content = str(last_message) + msg_type = "System" + + # Add message to buffer + message_buffer.add_message(msg_type, content) + + # If it's a tool call, add it to tool calls + if hasattr(last_message, "tool_calls"): + for tool_call in last_message.tool_calls: + # Handle both dictionary and object tool calls + if isinstance(tool_call, dict): + message_buffer.add_tool_call( + tool_call["name"], tool_call["args"] + ) + else: + message_buffer.add_tool_call(tool_call.name, tool_call.args) + + # Update reports and agent status based on chunk content + # Analyst Team Reports + if "market_report" in chunk and chunk["market_report"]: + message_buffer.update_report_section( + "market_report", 
chunk["market_report"] + ) + message_buffer.update_agent_status("Market Analyst", "completed") + # Set next analyst to in_progress + if "social" in selections["analysts"]: + message_buffer.update_agent_status( + "Social Analyst", "in_progress" + ) + + if "sentiment_report" in chunk and chunk["sentiment_report"]: + message_buffer.update_report_section( + "sentiment_report", chunk["sentiment_report"] + ) + message_buffer.update_agent_status("Social Analyst", "completed") + # Set next analyst to in_progress + if "news" in selections["analysts"]: + message_buffer.update_agent_status( + "News Analyst", "in_progress" + ) + + if "news_report" in chunk and chunk["news_report"]: + message_buffer.update_report_section( + "news_report", chunk["news_report"] + ) + message_buffer.update_agent_status("News Analyst", "completed") + # Set next analyst to in_progress + if "fundamentals" in selections["analysts"]: + message_buffer.update_agent_status( + "Fundamentals Analyst", "in_progress" + ) + + if "fundamentals_report" in chunk and chunk["fundamentals_report"]: + message_buffer.update_report_section( + "fundamentals_report", chunk["fundamentals_report"] + ) + message_buffer.update_agent_status( + "Fundamentals Analyst", "completed" + ) + # Set all research team members to in_progress + update_research_team_status("in_progress") + + # Research Team - Handle Investment Debate State + if ( + "investment_debate_state" in chunk + and chunk["investment_debate_state"] + ): + debate_state = chunk["investment_debate_state"] + + # Update Bull Researcher status and report + if "bull_history" in debate_state and debate_state["bull_history"]: + # Keep all research team members in progress + update_research_team_status("in_progress") + # Extract latest bull response + bull_responses = debate_state["bull_history"].split("\n") + latest_bull = bull_responses[-1] if bull_responses else "" + if latest_bull: + message_buffer.add_message("Reasoning", latest_bull) + # Update research report with 
bull's latest analysis + message_buffer.update_report_section( + "investment_plan", + f"### Bull Researcher Analysis\n{latest_bull}", + ) + + # Update Bear Researcher status and report + if "bear_history" in debate_state and debate_state["bear_history"]: + # Keep all research team members in progress + update_research_team_status("in_progress") + # Extract latest bear response + bear_responses = debate_state["bear_history"].split("\n") + latest_bear = bear_responses[-1] if bear_responses else "" + if latest_bear: + message_buffer.add_message("Reasoning", latest_bear) + # Update research report with bear's latest analysis + message_buffer.update_report_section( + "investment_plan", + f"{message_buffer.report_sections['investment_plan']}\n\n### Bear Researcher Analysis\n{latest_bear}", + ) + + # Update Research Manager status and final decision + if ( + "judge_decision" in debate_state + and debate_state["judge_decision"] + ): + # Keep all research team members in progress until final decision + update_research_team_status("in_progress") + message_buffer.add_message( + "Reasoning", + f"Research Manager: {debate_state['judge_decision']}", + ) + # Update research report with final decision + message_buffer.update_report_section( + "investment_plan", + f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}", + ) + # Mark all research team members as completed + update_research_team_status("completed") + # Set first risk analyst to in_progress + message_buffer.update_agent_status( + "Risky Analyst", "in_progress" + ) + + # Trading Team + if ( + "trader_investment_plan" in chunk + and chunk["trader_investment_plan"] + ): + message_buffer.update_report_section( + "trader_investment_plan", chunk["trader_investment_plan"] + ) + # Set first risk analyst to in_progress + message_buffer.update_agent_status("Risky Analyst", "in_progress") + + # Risk Management Team - Handle Risk Debate State + if 
"risk_debate_state" in chunk and chunk["risk_debate_state"]: + risk_state = chunk["risk_debate_state"] + + # Update Risky Analyst status and report + if ( + "current_risky_response" in risk_state + and risk_state["current_risky_response"] + ): + message_buffer.update_agent_status( + "Risky Analyst", "in_progress" + ) + message_buffer.add_message( + "Reasoning", + f"Risky Analyst: {risk_state['current_risky_response']}", + ) + # Update risk report with risky analyst's latest analysis only + message_buffer.update_report_section( + "final_trade_decision", + f"### Risky Analyst Analysis\n{risk_state['current_risky_response']}", + ) + + # Update Safe Analyst status and report + if ( + "current_safe_response" in risk_state + and risk_state["current_safe_response"] + ): + message_buffer.update_agent_status( + "Safe Analyst", "in_progress" + ) + message_buffer.add_message( + "Reasoning", + f"Safe Analyst: {risk_state['current_safe_response']}", + ) + # Update risk report with safe analyst's latest analysis only + message_buffer.update_report_section( + "final_trade_decision", + f"### Safe Analyst Analysis\n{risk_state['current_safe_response']}", + ) + + # Update Neutral Analyst status and report + if ( + "current_neutral_response" in risk_state + and risk_state["current_neutral_response"] + ): + message_buffer.update_agent_status( + "Neutral Analyst", "in_progress" + ) + message_buffer.add_message( + "Reasoning", + f"Neutral Analyst: {risk_state['current_neutral_response']}", + ) + # Update risk report with neutral analyst's latest analysis only + message_buffer.update_report_section( + "final_trade_decision", + f"### Neutral Analyst Analysis\n{risk_state['current_neutral_response']}", + ) + + # Update Portfolio Manager status and final decision + if "judge_decision" in risk_state and risk_state["judge_decision"]: + message_buffer.update_agent_status( + "Portfolio Manager", "in_progress" + ) + message_buffer.add_message( + "Reasoning", + f"Portfolio Manager: 
{risk_state['judge_decision']}", + ) + # Update risk report with final decision only + message_buffer.update_report_section( + "final_trade_decision", + f"### Portfolio Manager Decision\n{risk_state['judge_decision']}", + ) + # Mark risk analysts as completed + message_buffer.update_agent_status("Risky Analyst", "completed") + message_buffer.update_agent_status("Safe Analyst", "completed") + message_buffer.update_agent_status( + "Neutral Analyst", "completed" + ) + message_buffer.update_agent_status( + "Portfolio Manager", "completed" + ) + + # Update the display + update_display(layout) + + trace.append(chunk) + + # Get final state and decision + final_state = trace[-1] + decision = graph.process_signal(final_state["final_trade_decision"]) + + # Update all agent statuses to completed + for agent in message_buffer.agent_status: + message_buffer.update_agent_status(agent, "completed") + + message_buffer.add_message( + "Analysis", f"Completed analysis for {selections['analysis_date']}" + ) + + # Update final report sections + for section in message_buffer.report_sections.keys(): + if section in final_state: + message_buffer.update_report_section(section, final_state[section]) + + # Display the complete final report + display_complete_report(final_state) + + update_display(layout) + + +@app.command() +def analyze(): + run_analysis() + + +if __name__ == "__main__": + app() diff --git a/cli/models.py b/cli/models.py index f68c3da1..06e67623 100644 --- a/cli/models.py +++ b/cli/models.py @@ -1,10 +1,10 @@ -from enum import Enum -from typing import List, Optional, Dict -from pydantic import BaseModel - - -class AnalystType(str, Enum): - MARKET = "market" - SOCIAL = "social" - NEWS = "news" - FUNDAMENTALS = "fundamentals" +from enum import Enum +from typing import List, Optional, Dict +from pydantic import BaseModel + + +class AnalystType(str, Enum): + MARKET = "market" + SOCIAL = "social" + NEWS = "news" + FUNDAMENTALS = "fundamentals" diff --git a/cli/static/welcome.txt 
b/cli/static/welcome.txt index f2cf641d..95a28e58 100644 --- a/cli/static/welcome.txt +++ b/cli/static/welcome.txt @@ -1,7 +1,7 @@ - - ______ ___ ___ __ - /_ __/________ _____/ (_)___ ____ _/ | ____ ____ ____ / /______ - / / / ___/ __ `/ __ / / __ \/ __ `/ /| |/ __ `/ _ \/ __ \/ __/ ___/ - / / / / / /_/ / /_/ / / / / / /_/ / ___ / /_/ / __/ / / / /_(__ ) -/_/ /_/ \__,_/\__,_/_/_/ /_/\__, /_/ |_\__, /\___/_/ /_/\__/____/ - /____/ /____/ + + ______ ___ ___ __ + /_ __/________ _____/ (_)___ ____ _/ | ____ ____ ____ / /______ + / / / ___/ __ `/ __ / / __ \/ __ `/ /| |/ __ `/ _ \/ __ \/ __/ ___/ + / / / / / /_/ / /_/ / / / / / /_/ / ___ / /_/ / __/ / / / /_(__ ) +/_/ /_/ \__,_/\__,_/_/_/ /_/\__, /_/ |_\__, /\___/_/ /_/\__/____/ + /____/ /____/ diff --git a/cli/utils.py b/cli/utils.py index 7b9682a6..6096507a 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,276 +1,276 @@ -import questionary -from typing import List, Optional, Tuple, Dict - -from cli.models import AnalystType - -ANALYST_ORDER = [ - ("Market Analyst", AnalystType.MARKET), - ("Social Media Analyst", AnalystType.SOCIAL), - ("News Analyst", AnalystType.NEWS), - ("Fundamentals Analyst", AnalystType.FUNDAMENTALS), -] - - -def get_ticker() -> str: - """Prompt the user to enter a ticker symbol.""" - ticker = questionary.text( - "Enter the ticker symbol to analyze:", - validate=lambda x: len(x.strip()) > 0 or "Please enter a valid ticker symbol.", - style=questionary.Style( - [ - ("text", "fg:green"), - ("highlighted", "noinherit"), - ] - ), - ).ask() - - if not ticker: - console.print("\n[red]No ticker symbol provided. 
Exiting...[/red]") - exit(1) - - return ticker.strip().upper() - - -def get_analysis_date() -> str: - """Prompt the user to enter a date in YYYY-MM-DD format.""" - import re - from datetime import datetime - - def validate_date(date_str: str) -> bool: - if not re.match(r"^\d{4}-\d{2}-\d{2}$", date_str): - return False - try: - datetime.strptime(date_str, "%Y-%m-%d") - return True - except ValueError: - return False - - date = questionary.text( - "Enter the analysis date (YYYY-MM-DD):", - validate=lambda x: validate_date(x.strip()) - or "Please enter a valid date in YYYY-MM-DD format.", - style=questionary.Style( - [ - ("text", "fg:green"), - ("highlighted", "noinherit"), - ] - ), - ).ask() - - if not date: - console.print("\n[red]No date provided. Exiting...[/red]") - exit(1) - - return date.strip() - - -def select_analysts() -> List[AnalystType]: - """Select analysts using an interactive checkbox.""" - choices = questionary.checkbox( - "Select Your [Analysts Team]:", - choices=[ - questionary.Choice(display, value=value) for display, value in ANALYST_ORDER - ], - instruction="\n- Press Space to select/unselect analysts\n- Press 'a' to select/unselect all\n- Press Enter when done", - validate=lambda x: len(x) > 0 or "You must select at least one analyst.", - style=questionary.Style( - [ - ("checkbox-selected", "fg:green"), - ("selected", "fg:green noinherit"), - ("highlighted", "noinherit"), - ("pointer", "noinherit"), - ] - ), - ).ask() - - if not choices: - console.print("\n[red]No analysts selected. 
Exiting...[/red]") - exit(1) - - return choices - - -def select_research_depth() -> int: - """Select research depth using an interactive selection.""" - - # Define research depth options with their corresponding values - DEPTH_OPTIONS = [ - ("Shallow - Quick research, few debate and strategy discussion rounds", 1), - ("Medium - Middle ground, moderate debate rounds and strategy discussion", 3), - ("Deep - Comprehensive research, in depth debate and strategy discussion", 5), - ] - - choice = questionary.select( - "Select Your [Research Depth]:", - choices=[ - questionary.Choice(display, value=value) for display, value in DEPTH_OPTIONS - ], - instruction="\n- Use arrow keys to navigate\n- Press Enter to select", - style=questionary.Style( - [ - ("selected", "fg:yellow noinherit"), - ("highlighted", "fg:yellow noinherit"), - ("pointer", "fg:yellow noinherit"), - ] - ), - ).ask() - - if choice is None: - console.print("\n[red]No research depth selected. Exiting...[/red]") - exit(1) - - return choice - - -def select_shallow_thinking_agent(provider) -> str: - """Select shallow thinking llm engine using an interactive selection.""" - - # Define shallow thinking llm engine options with their corresponding model names - SHALLOW_AGENT_OPTIONS = { - "openai": [ - ("GPT-4o-mini - Fast and efficient for quick tasks", "gpt-4o-mini"), - ("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"), - ("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"), - ("GPT-4o - Standard model with solid capabilities", "gpt-4o"), - ], - "anthropic": [ - ("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"), - ("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"), - ("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"), - ("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"), - ], - "google": [ - ("Gemini 2.0 
Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"), - ("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"), - ("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"), - ], - "openrouter": [ - ("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"), - ("Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B", "meta-llama/llama-3.3-8b-instruct:free"), - ("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"), - ], - "ollama": [ - ("llama3.1 local", "llama3.1"), - ("llama3.2 local", "llama3.2"), - ] - } - - choice = questionary.select( - "Select Your [Quick-Thinking LLM Engine]:", - choices=[ - questionary.Choice(display, value=value) - for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()] - ], - instruction="\n- Use arrow keys to navigate\n- Press Enter to select", - style=questionary.Style( - [ - ("selected", "fg:magenta noinherit"), - ("highlighted", "fg:magenta noinherit"), - ("pointer", "fg:magenta noinherit"), - ] - ), - ).ask() - - if choice is None: - console.print( - "\n[red]No shallow thinking llm engine selected. 
Exiting...[/red]" - ) - exit(1) - - return choice - - -def select_deep_thinking_agent(provider) -> str: - """Select deep thinking llm engine using an interactive selection.""" - - # Define deep thinking llm engine options with their corresponding model names - DEEP_AGENT_OPTIONS = { - "openai": [ - ("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"), - ("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"), - ("GPT-4o - Standard model with solid capabilities", "gpt-4o"), - ("o4-mini - Specialized reasoning model (compact)", "o4-mini"), - ("o3-mini - Advanced reasoning model (lightweight)", "o3-mini"), - ("o3 - Full advanced reasoning model", "o3"), - ("o1 - Premier reasoning and problem-solving model", "o1"), - ], - "anthropic": [ - ("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"), - ("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"), - ("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"), - ("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"), - ("Claude Opus 4 - Most powerful Anthropic model", " claude-opus-4-0"), - ], - "google": [ - ("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"), - ("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"), - ("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"), - ("Gemini 2.5 Pro", "gemini-2.5-pro-preview-06-05"), - ], - "openrouter": [ - ("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"), - ("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"), - ], - "ollama": [ - ("llama3.1 local", "llama3.1"), - ("qwen3", "qwen3"), - ] - } - - choice = questionary.select( - "Select Your [Deep-Thinking LLM Engine]:", 
- choices=[ - questionary.Choice(display, value=value) - for display, value in DEEP_AGENT_OPTIONS[provider.lower()] - ], - instruction="\n- Use arrow keys to navigate\n- Press Enter to select", - style=questionary.Style( - [ - ("selected", "fg:magenta noinherit"), - ("highlighted", "fg:magenta noinherit"), - ("pointer", "fg:magenta noinherit"), - ] - ), - ).ask() - - if choice is None: - console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]") - exit(1) - - return choice - -def select_llm_provider() -> tuple[str, str]: - """Select the OpenAI api url using interactive selection.""" - # Define OpenAI api options with their corresponding endpoints - BASE_URLS = [ - ("OpenAI", "https://api.openai.com/v1"), - ("Anthropic", "https://api.anthropic.com/"), - ("Google", "https://generativelanguage.googleapis.com/v1"), - ("Openrouter", "https://openrouter.ai/api/v1"), - ("Ollama", "http://localhost:11434/v1"), - ] - - choice = questionary.select( - "Select your LLM Provider:", - choices=[ - questionary.Choice(display, value=(display, value)) - for display, value in BASE_URLS - ], - instruction="\n- Use arrow keys to navigate\n- Press Enter to select", - style=questionary.Style( - [ - ("selected", "fg:magenta noinherit"), - ("highlighted", "fg:magenta noinherit"), - ("pointer", "fg:magenta noinherit"), - ] - ), - ).ask() - - if choice is None: - console.print("\n[red]no OpenAI backend selected. 
Exiting...[/red]") - exit(1) - - display_name, url = choice - print(f"You selected: {display_name}\tURL: {url}") - - return display_name, url +import questionary +from typing import List, Optional, Tuple, Dict + +from cli.models import AnalystType + +ANALYST_ORDER = [ + ("Market Analyst", AnalystType.MARKET), + ("Social Media Analyst", AnalystType.SOCIAL), + ("News Analyst", AnalystType.NEWS), + ("Fundamentals Analyst", AnalystType.FUNDAMENTALS), +] + + +def get_ticker() -> str: + """Prompt the user to enter a ticker symbol.""" + ticker = questionary.text( + "Enter the ticker symbol to analyze:", + validate=lambda x: len(x.strip()) > 0 or "Please enter a valid ticker symbol.", + style=questionary.Style( + [ + ("text", "fg:green"), + ("highlighted", "noinherit"), + ] + ), + ).ask() + + if not ticker: + console.print("\n[red]No ticker symbol provided. Exiting...[/red]") + exit(1) + + return ticker.strip().upper() + + +def get_analysis_date() -> str: + """Prompt the user to enter a date in YYYY-MM-DD format.""" + import re + from datetime import datetime + + def validate_date(date_str: str) -> bool: + if not re.match(r"^\d{4}-\d{2}-\d{2}$", date_str): + return False + try: + datetime.strptime(date_str, "%Y-%m-%d") + return True + except ValueError: + return False + + date = questionary.text( + "Enter the analysis date (YYYY-MM-DD):", + validate=lambda x: validate_date(x.strip()) + or "Please enter a valid date in YYYY-MM-DD format.", + style=questionary.Style( + [ + ("text", "fg:green"), + ("highlighted", "noinherit"), + ] + ), + ).ask() + + if not date: + console.print("\n[red]No date provided. 
Exiting...[/red]") + exit(1) + + return date.strip() + + +def select_analysts() -> List[AnalystType]: + """Select analysts using an interactive checkbox.""" + choices = questionary.checkbox( + "Select Your [Analysts Team]:", + choices=[ + questionary.Choice(display, value=value) for display, value in ANALYST_ORDER + ], + instruction="\n- Press Space to select/unselect analysts\n- Press 'a' to select/unselect all\n- Press Enter when done", + validate=lambda x: len(x) > 0 or "You must select at least one analyst.", + style=questionary.Style( + [ + ("checkbox-selected", "fg:green"), + ("selected", "fg:green noinherit"), + ("highlighted", "noinherit"), + ("pointer", "noinherit"), + ] + ), + ).ask() + + if not choices: + console.print("\n[red]No analysts selected. Exiting...[/red]") + exit(1) + + return choices + + +def select_research_depth() -> int: + """Select research depth using an interactive selection.""" + + # Define research depth options with their corresponding values + DEPTH_OPTIONS = [ + ("Shallow - Quick research, few debate and strategy discussion rounds", 1), + ("Medium - Middle ground, moderate debate rounds and strategy discussion", 3), + ("Deep - Comprehensive research, in depth debate and strategy discussion", 5), + ] + + choice = questionary.select( + "Select Your [Research Depth]:", + choices=[ + questionary.Choice(display, value=value) for display, value in DEPTH_OPTIONS + ], + instruction="\n- Use arrow keys to navigate\n- Press Enter to select", + style=questionary.Style( + [ + ("selected", "fg:yellow noinherit"), + ("highlighted", "fg:yellow noinherit"), + ("pointer", "fg:yellow noinherit"), + ] + ), + ).ask() + + if choice is None: + console.print("\n[red]No research depth selected. 
Exiting...[/red]") + exit(1) + + return choice + + +def select_shallow_thinking_agent(provider) -> str: + """Select shallow thinking llm engine using an interactive selection.""" + + # Define shallow thinking llm engine options with their corresponding model names + SHALLOW_AGENT_OPTIONS = { + "openai": [ + ("GPT-4o-mini - Fast and efficient for quick tasks", "gpt-4o-mini"), + ("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"), + ("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"), + ("GPT-4o - Standard model with solid capabilities", "gpt-4o"), + ], + "anthropic": [ + ("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"), + ("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"), + ("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"), + ("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"), + ], + "google": [ + ("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"), + ("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"), + ("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"), + ], + "openrouter": [ + ("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"), + ("Meta: Llama 3.3 8B Instruct - A lightweight and ultra-fast variant of Llama 3.3 70B", "meta-llama/llama-3.3-8b-instruct:free"), + ("google/gemini-2.0-flash-exp:free - Gemini Flash 2.0 offers a significantly faster time to first token", "google/gemini-2.0-flash-exp:free"), + ], + "ollama": [ + ("llama3.1 local", "llama3.1"), + ("llama3.2 local", "llama3.2"), + ] + } + + choice = questionary.select( + "Select Your [Quick-Thinking LLM Engine]:", + choices=[ + questionary.Choice(display, value=value) + for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()] + ], + instruction="\n- Use arrow keys to navigate\n- 
Press Enter to select", + style=questionary.Style( + [ + ("selected", "fg:magenta noinherit"), + ("highlighted", "fg:magenta noinherit"), + ("pointer", "fg:magenta noinherit"), + ] + ), + ).ask() + + if choice is None: + console.print( + "\n[red]No shallow thinking llm engine selected. Exiting...[/red]" + ) + exit(1) + + return choice + + +def select_deep_thinking_agent(provider) -> str: + """Select deep thinking llm engine using an interactive selection.""" + + # Define deep thinking llm engine options with their corresponding model names + DEEP_AGENT_OPTIONS = { + "openai": [ + ("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"), + ("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"), + ("GPT-4o - Standard model with solid capabilities", "gpt-4o"), + ("o4-mini - Specialized reasoning model (compact)", "o4-mini"), + ("o3-mini - Advanced reasoning model (lightweight)", "o3-mini"), + ("o3 - Full advanced reasoning model", "o3"), + ("o1 - Premier reasoning and problem-solving model", "o1"), + ], + "anthropic": [ + ("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"), + ("Claude Sonnet 3.5 - Highly capable standard model", "claude-3-5-sonnet-latest"), + ("Claude Sonnet 3.7 - Exceptional hybrid reasoning and agentic capabilities", "claude-3-7-sonnet-latest"), + ("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"), + ("Claude Opus 4 - Most powerful Anthropic model", " claude-opus-4-0"), + ], + "google": [ + ("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"), + ("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"), + ("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"), + ("Gemini 2.5 Pro", "gemini-2.5-pro-preview-06-05"), + ], + "openrouter": [ + ("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"), + 
("Deepseek - latest iteration of the flagship chat model family from the DeepSeek team.", "deepseek/deepseek-chat-v3-0324:free"), + ], + "ollama": [ + ("llama3.1 local", "llama3.1"), + ("qwen3", "qwen3"), + ] + } + + choice = questionary.select( + "Select Your [Deep-Thinking LLM Engine]:", + choices=[ + questionary.Choice(display, value=value) + for display, value in DEEP_AGENT_OPTIONS[provider.lower()] + ], + instruction="\n- Use arrow keys to navigate\n- Press Enter to select", + style=questionary.Style( + [ + ("selected", "fg:magenta noinherit"), + ("highlighted", "fg:magenta noinherit"), + ("pointer", "fg:magenta noinherit"), + ] + ), + ).ask() + + if choice is None: + console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]") + exit(1) + + return choice + +def select_llm_provider() -> tuple[str, str]: + """Select the OpenAI api url using interactive selection.""" + # Define OpenAI api options with their corresponding endpoints + BASE_URLS = [ + ("OpenAI", "https://api.openai.com/v1"), + ("Anthropic", "https://api.anthropic.com/"), + ("Google", "https://generativelanguage.googleapis.com/v1"), + ("Openrouter", "https://openrouter.ai/api/v1"), + ("Ollama", "http://localhost:11434/v1"), + ] + + choice = questionary.select( + "Select your LLM Provider:", + choices=[ + questionary.Choice(display, value=(display, value)) + for display, value in BASE_URLS + ], + instruction="\n- Use arrow keys to navigate\n- Press Enter to select", + style=questionary.Style( + [ + ("selected", "fg:magenta noinherit"), + ("highlighted", "fg:magenta noinherit"), + ("pointer", "fg:magenta noinherit"), + ] + ), + ).ask() + + if choice is None: + console.print("\n[red]no OpenAI backend selected. 
Exiting...[/red]") + exit(1) + + display_name, url = choice + print(f"You selected: {display_name}\tURL: {url}") + + return display_name, url diff --git a/main.py b/main.py index a85ee6ec..4f4474fd 100644 --- a/main.py +++ b/main.py @@ -1,31 +1,31 @@ -from tradingagents.graph.trading_graph import TradingAgentsGraph -from tradingagents.default_config import DEFAULT_CONFIG - -from dotenv import load_dotenv - -# Load environment variables from .env file -load_dotenv() - -# Create a custom config -config = DEFAULT_CONFIG.copy() -config["deep_think_llm"] = "gpt-4o-mini" # Use a different model -config["quick_think_llm"] = "gpt-4o-mini" # Use a different model -config["max_debate_rounds"] = 1 # Increase debate rounds - -# Configure data vendors (default uses yfinance and alpha_vantage) -config["data_vendors"] = { - "core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local - "technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local - "fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local - "news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local -} - -# Initialize with custom config -ta = TradingAgentsGraph(debug=True, config=config) - -# forward propagate -_, decision = ta.propagate("NVDA", "2024-05-10") -print(decision) - -# Memorize mistakes and reflect -# ta.reflect_and_remember(1000) # parameter is the position returns +from tradingagents.graph.trading_graph import TradingAgentsGraph +from tradingagents.default_config import DEFAULT_CONFIG + +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() + +# Create a custom config +config = DEFAULT_CONFIG.copy() +config["deep_think_llm"] = "gpt-4o-mini" # Use a different model +config["quick_think_llm"] = "gpt-4o-mini" # Use a different model +config["max_debate_rounds"] = 1 # Increase debate rounds + +# Configure data vendors (default uses yfinance and alpha_vantage) +config["data_vendors"] = { + 
"core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local + "technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local + "fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local + "news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local +} + +# Initialize with custom config +ta = TradingAgentsGraph(debug=True, config=config) + +# forward propagate +_, decision = ta.propagate("NVDA", "2024-05-10") +print(decision) + +# Memorize mistakes and reflect +# ta.reflect_and_remember(1000) # parameter is the position returns diff --git a/main_screening.py b/main_screening.py new file mode 100644 index 00000000..a93439ef --- /dev/null +++ b/main_screening.py @@ -0,0 +1,114 @@ +import os +from dotenv import load_dotenv +from langchain_openai import ChatOpenAI +from langchain_core.messages import HumanMessage + +from tradingagents.agents.screening_agent import create_screening_agent +from tradingagents.graph.trading_graph import TradingAgentsGraph +from tradingagents.default_config import DEFAULT_CONFIG + +# Load environment variables +load_dotenv() + +def main(): + print("--- Starting Market Screening ---") + + # 1. Initialize LLM for screening + config = DEFAULT_CONFIG.copy() + llm = ChatOpenAI(model=config["quick_think_llm"]) + + # 2. Create and run Screening Agent + screener = create_screening_agent(llm) + + # Initial state for screening + state = {"messages": [HumanMessage(content="Find me the top gainers today and pick the most interesting one to analyze.")]} + + print("Running Screening Agent...") + # In a real graph we would use langgraph, but here we can just invoke the node function for simplicity + # or build a mini-graph. Let's just invoke the node loop manually for this demo. 
+ + # Loop to handle multiple rounds of tool calls + max_iterations = 5 + iteration = 0 + + while iteration < max_iterations: + iteration += 1 + + # Invoke agent + result = screener(state) + state["messages"].extend(result["messages"]) + last_msg = result["messages"][-1] + + # Check if tool call + if not last_msg.tool_calls: + # No more tools, this is the final response + print(f"Screening Agent Recommendation: {last_msg.content}") + + # Extract ticker (simple heuristic) + import re + tickers = re.findall(r'\b[A-Z]{2,5}\b', last_msg.content) + + if tickers: + target_ticker = tickers[0] # Pick the first one + print(f"Selected Ticker: {target_ticker}") + + # 3. Run TradingAgentsGraph on the selected ticker + print(f"--- Starting Analysis for {target_ticker} ---") + ta = TradingAgentsGraph(debug=True, config=config) + + from datetime import datetime + today = datetime.now().strftime("%Y-%m-%d") + + _, decision = ta.propagate(target_ticker, today) + print("\nFinal Decision:") + print(decision) + else: + print("No tickers found in recommendation.") + + break + + else: + print(f"Tool Call (Iter {iteration}): {last_msg.tool_calls}") + + # Execute tools + from tradingagents.agents.utils.core_stock_tools import get_market_movers, get_earnings_calendar + from tradingagents.agents.utils.news_data_tools import get_insider_transactions + from tradingagents.agents.utils.technical_indicators_tools import get_indicators + from tradingagents.dataflows.social_sentiment import get_trending_social + + tool_outputs = [] + for tool_call in last_msg.tool_calls: + output = "Error: Tool not found" + try: + if tool_call["name"] == "get_market_movers": + output = get_market_movers.invoke(tool_call["args"]) + elif tool_call["name"] == "get_earnings_calendar": + output = get_earnings_calendar.invoke(tool_call["args"]) + elif tool_call["name"] == "get_insider_transactions": + args = tool_call["args"] + if "curr_date" not in args: + from datetime import datetime + args["curr_date"] = 
datetime.now().strftime("%Y-%m-%d") + output = get_insider_transactions.invoke(args) + elif tool_call["name"] == "get_indicators": + args = tool_call["args"] + if "curr_date" not in args: + from datetime import datetime + args["curr_date"] = datetime.now().strftime("%Y-%m-%d") + output = get_indicators.invoke(args) + elif tool_call["name"] == "get_trending_social": + output = get_trending_social.invoke(tool_call["args"]) + except Exception as e: + output = f"Tool execution failed: {str(e)}" + + tool_outputs.append( + {"tool_call_id": tool_call["id"], "content": str(output)} + ) + + # Add tool outputs to messages + from langchain_core.messages import ToolMessage + for output in tool_outputs: + state["messages"].append(ToolMessage(content=output["content"], tool_call_id=output["tool_call_id"])) + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index a6154cd2..5e700795 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,26 +1,27 @@ -typing-extensions -langchain-openai -langchain-experimental -pandas -yfinance -praw -feedparser -stockstats -eodhd -langgraph -chromadb -setuptools -backtrader -akshare -tushare -finnhub-python -parsel -requests -tqdm -pytz -redis -chainlit -rich -questionary -langchain_anthropic -langchain-google-genai +typing-extensions +langchain-openai +langchain-experimental +pandas +yfinance +praw +feedparser +stockstats +eodhd +langgraph +chromadb +setuptools +backtrader +akshare +tushare +finnhub-python +parsel +requests +tqdm +pytz +redis +chainlit +rich +questionary +langchain_anthropic +langchain-google-genai +python-dotenv diff --git a/setup.py b/setup.py index 793df3e6..7e999eb2 100644 --- a/setup.py +++ b/setup.py @@ -1,43 +1,44 @@ -""" -Setup script for the TradingAgents package. 
-""" - -from setuptools import setup, find_packages - -setup( - name="tradingagents", - version="0.1.0", - description="Multi-Agents LLM Financial Trading Framework", - author="TradingAgents Team", - author_email="yijia.xiao@cs.ucla.edu", - url="https://github.com/TauricResearch", - packages=find_packages(), - install_requires=[ - "langchain>=0.1.0", - "langchain-openai>=0.0.2", - "langchain-experimental>=0.0.40", - "langgraph>=0.0.20", - "numpy>=1.24.0", - "pandas>=2.0.0", - "praw>=7.7.0", - "stockstats>=0.5.4", - "yfinance>=0.2.31", - "typer>=0.9.0", - "rich>=13.0.0", - "questionary>=2.0.1", - ], - python_requires=">=3.10", - entry_points={ - "console_scripts": [ - "tradingagents=cli.main:app", - ], - }, - classifiers=[ - "Development Status :: 3 - Alpha", - "Intended Audience :: Financial and Trading Industry", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", - "Topic :: Office/Business :: Financial :: Investment", - ], -) +""" +Setup script for the TradingAgents package. 
+""" + +from setuptools import setup, find_packages + +setup( + name="tradingagents", + version="0.1.0", + description="Multi-Agents LLM Financial Trading Framework", + author="TradingAgents Team", + author_email="yijia.xiao@cs.ucla.edu", + url="https://github.com/TauricResearch", + packages=find_packages(), + install_requires=[ + "langchain>=0.1.0", + "langchain-openai>=0.0.2", + "langchain-experimental>=0.0.40", + "langgraph>=0.0.20", + "numpy>=1.24.0", + "pandas>=2.0.0", + "praw>=7.7.0", + "stockstats>=0.5.4", + "yfinance>=0.2.31", + "typer>=0.9.0", + "rich>=13.0.0", + "questionary>=2.0.1", + "python-dotenv>=1.0.0", + ], + python_requires=">=3.10", + entry_points={ + "console_scripts": [ + "tradingagents=cli.main:app", + ], + }, + classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Financial and Trading Industry", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Topic :: Office/Business :: Financial :: Investment", + ], +) diff --git a/test.py b/test.py index b73783e1..c52b917d 100644 --- a/test.py +++ b/test.py @@ -1,11 +1,11 @@ -import time -from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions - -print("Testing optimized implementation with 30-day lookback:") -start_time = time.time() -result = get_stock_stats_indicators_window("AAPL", "macd", "2024-11-01", 30) -end_time = time.time() - -print(f"Execution time: {end_time - start_time:.2f} seconds") -print(f"Result length: {len(result)} characters") -print(result) +import time +from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as 
get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions + +print("Testing optimized implementation with 30-day lookback:") +start_time = time.time() +result = get_stock_stats_indicators_window("AAPL", "macd", "2024-11-01", 30) +end_time = time.time() + +print(f"Execution time: {end_time - start_time:.2f} seconds") +print(f"Result length: {len(result)} characters") +print(result) diff --git a/test_social_apis.py b/test_social_apis.py new file mode 100644 index 00000000..5df5b3c3 --- /dev/null +++ b/test_social_apis.py @@ -0,0 +1,41 @@ +import requests +import json + +def test_stocktwits(): + print("Testing StockTwits Trending...") + url = "https://api.stocktwits.com/api/2/trending/symbols.json" + try: + response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}) + if response.status_code == 200: + data = response.json() + symbols = [s['symbol'] for s in data['symbols']] + print(f"StockTwits Trending: {symbols[:5]}") + return True + else: + print(f"StockTwits Failed: {response.status_code}") + return False + except Exception as e: + print(f"StockTwits Error: {e}") + return False + +def test_apewisdom(): + print("\nTesting Ape Wisdom (Reddit)...") + url = "https://apewisdom.io/api/v1.0/filter/all-stocks/page/1" + try: + response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}) + if response.status_code == 200: + data = response.json() + # Ape Wisdom returns a list of objects + symbols = [s['ticker'] for s in data['results']] + print(f"Ape Wisdom Trending: {symbols[:5]}") + return True + else: + print(f"Ape Wisdom Failed: {response.status_code}") + return False + except Exception as e: + print(f"Ape Wisdom Error: {e}") + return False + +if __name__ == "__main__": + st_success = test_stocktwits() + aw_success = test_apewisdom() diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index d84d9eb1..fe39b382 100644 --- 
a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -1,40 +1,40 @@ -from .utils.agent_utils import create_msg_delete -from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState -from .utils.memory import FinancialSituationMemory - -from .analysts.fundamentals_analyst import create_fundamentals_analyst -from .analysts.market_analyst import create_market_analyst -from .analysts.news_analyst import create_news_analyst -from .analysts.social_media_analyst import create_social_media_analyst - -from .researchers.bear_researcher import create_bear_researcher -from .researchers.bull_researcher import create_bull_researcher - -from .risk_mgmt.aggresive_debator import create_risky_debator -from .risk_mgmt.conservative_debator import create_safe_debator -from .risk_mgmt.neutral_debator import create_neutral_debator - -from .managers.research_manager import create_research_manager -from .managers.risk_manager import create_risk_manager - -from .trader.trader import create_trader - -__all__ = [ - "FinancialSituationMemory", - "AgentState", - "create_msg_delete", - "InvestDebateState", - "RiskDebateState", - "create_bear_researcher", - "create_bull_researcher", - "create_research_manager", - "create_fundamentals_analyst", - "create_market_analyst", - "create_neutral_debator", - "create_news_analyst", - "create_risky_debator", - "create_risk_manager", - "create_safe_debator", - "create_social_media_analyst", - "create_trader", -] +from .utils.agent_utils import create_msg_delete +from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState +from .utils.memory import FinancialSituationMemory + +from .analysts.fundamentals_analyst import create_fundamentals_analyst +from .analysts.market_analyst import create_market_analyst +from .analysts.news_analyst import create_news_analyst +from .analysts.social_media_analyst import create_social_media_analyst + +from .researchers.bear_researcher import create_bear_researcher 
+from .researchers.bull_researcher import create_bull_researcher + +from .risk_mgmt.aggresive_debator import create_risky_debator +from .risk_mgmt.conservative_debator import create_safe_debator +from .risk_mgmt.neutral_debator import create_neutral_debator + +from .managers.research_manager import create_research_manager +from .managers.risk_manager import create_risk_manager + +from .trader.trader import create_trader + +__all__ = [ + "FinancialSituationMemory", + "AgentState", + "create_msg_delete", + "InvestDebateState", + "RiskDebateState", + "create_bear_researcher", + "create_bull_researcher", + "create_research_manager", + "create_fundamentals_analyst", + "create_market_analyst", + "create_neutral_debator", + "create_news_analyst", + "create_risky_debator", + "create_risk_manager", + "create_safe_debator", + "create_social_media_analyst", + "create_trader", +] diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index e20139cb..6f8ec4ca 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -1,63 +1,63 @@ -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json -from tradingagents.agents.utils.agent_utils import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement, get_insider_sentiment, get_insider_transactions -from tradingagents.dataflows.config import get_config - - -def create_fundamentals_analyst(llm): - def fundamentals_analyst_node(state): - current_date = state["trade_date"] - ticker = state["company_of_interest"] - company_name = state["company_of_interest"] - - tools = [ - get_fundamentals, - get_balance_sheet, - get_cashflow, - get_income_statement, - ] - - system_message = ( - "You are a researcher tasked with analyzing fundamental information over the past week about a company. 
Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, and company financial history to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." - + " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read." - + " Use the available tools: `get_fundamentals` for comprehensive company analysis, `get_balance_sheet`, `get_cashflow`, and `get_income_statement` for specific financial statements.", - ) - - prompt = ChatPromptTemplate.from_messages( - [ - ( - "system", - "You are a helpful AI assistant, collaborating with other assistants." - " Use the provided tools to progress towards answering the question." - " If you are unable to fully answer, that's OK; another assistant with different tools" - " will help where you left off. Execute what you can to make progress." - " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," - " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." - " You have access to the following tools: {tool_names}.\n{system_message}" - "For your reference, the current date is {current_date}. 
The company we want to look at is {ticker}", - ), - MessagesPlaceholder(variable_name="messages"), - ] - ) - - prompt = prompt.partial(system_message=system_message) - prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) - prompt = prompt.partial(current_date=current_date) - prompt = prompt.partial(ticker=ticker) - - chain = prompt | llm.bind_tools(tools) - - result = chain.invoke(state["messages"]) - - report = "" - - if len(result.tool_calls) == 0: - report = result.content - - return { - "messages": [result], - "fundamentals_report": report, - } - - return fundamentals_analyst_node +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +import time +import json +from tradingagents.agents.utils.agent_utils import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement, get_insider_sentiment, get_insider_transactions +from tradingagents.dataflows.config import get_config + + +def create_fundamentals_analyst(llm): + def fundamentals_analyst_node(state): + current_date = state["trade_date"] + ticker = state["company_of_interest"] + company_name = state["company_of_interest"] + + tools = [ + get_fundamentals, + get_balance_sheet, + get_cashflow, + get_income_statement, + ] + + system_message = ( + "You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, and company financial history to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." + + " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read." 
+ + " Use the available tools: `get_fundamentals` for comprehensive company analysis, `get_balance_sheet`, `get_cashflow`, and `get_income_statement` for specific financial statements.", + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful AI assistant, collaborating with other assistants." + " Use the provided tools to progress towards answering the question." + " If you are unable to fully answer, that's OK; another assistant with different tools" + " will help where you left off. Execute what you can to make progress." + " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," + " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." + " You have access to the following tools: {tool_names}.\n{system_message}" + "For your reference, the current date is {current_date}. The company we want to look at is {ticker}", + ), + MessagesPlaceholder(variable_name="messages"), + ] + ) + + prompt = prompt.partial(system_message=system_message) + prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) + prompt = prompt.partial(current_date=current_date) + prompt = prompt.partial(ticker=ticker) + + chain = prompt | llm.bind_tools(tools) + + result = chain.invoke(state["messages"]) + + report = "" + + if len(result.tool_calls) == 0: + report = result.content + + return { + "messages": [result], + "fundamentals_report": report, + } + + return fundamentals_analyst_node diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index c955dd76..c1b78a16 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -1,85 +1,85 @@ -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json -from tradingagents.agents.utils.agent_utils import get_stock_data, get_indicators -from 
tradingagents.dataflows.config import get_config - - -def create_market_analyst(llm): - - def market_analyst_node(state): - current_date = state["trade_date"] - ticker = state["company_of_interest"] - company_name = state["company_of_interest"] - - tools = [ - get_stock_data, - get_indicators, - ] - - system_message = ( - """You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are: - -Moving Averages: -- close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals. -- close_200_sma: 200 SMA: A long-term trend benchmark. Usage: Confirm overall market trend and identify golden/death cross setups. Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries. -- close_10_ema: 10 EMA: A responsive short-term average. Usage: Capture quick shifts in momentum and potential entry points. Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals. - -MACD Related: -- macd: MACD: Computes momentum via differences of EMAs. Usage: Look for crossovers and divergence as signals of trend changes. Tips: Confirm with other indicators in low-volatility or sideways markets. -- macds: MACD Signal: An EMA smoothing of the MACD line. Usage: Use crossovers with the MACD line to trigger trades. Tips: Should be part of a broader strategy to avoid false positives. -- macdh: MACD Histogram: Shows the gap between the MACD line and its signal. Usage: Visualize momentum strength and spot divergence early. Tips: Can be volatile; complement with additional filters in fast-moving markets. 
- -Momentum Indicators: -- rsi: RSI: Measures momentum to flag overbought/oversold conditions. Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis. - -Volatility Indicators: -- boll: Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. Usage: Acts as a dynamic benchmark for price movement. Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals. -- boll_ub: Bollinger Upper Band: Typically 2 standard deviations above the middle line. Usage: Signals potential overbought conditions and breakout zones. Tips: Confirm signals with other tools; prices may ride the band in strong trends. -- boll_lb: Bollinger Lower Band: Typically 2 standard deviations below the middle line. Usage: Indicates potential oversold conditions. Tips: Use additional analysis to avoid false reversal signals. -- atr: ATR: Averages true range to measure volatility. Usage: Set stop-loss levels and adjust position sizes based on current market volatility. Tips: It's a reactive measure, so use it as part of a broader risk management strategy. - -Volume-Based Indicators: -- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses. - -- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. 
Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions.""" - + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" - ) - - prompt = ChatPromptTemplate.from_messages( - [ - ( - "system", - "You are a helpful AI assistant, collaborating with other assistants." - " Use the provided tools to progress towards answering the question." - " If you are unable to fully answer, that's OK; another assistant with different tools" - " will help where you left off. Execute what you can to make progress." - " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," - " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." - " You have access to the following tools: {tool_names}.\n{system_message}" - "For your reference, the current date is {current_date}. 
The company we want to look at is {ticker}", - ), - MessagesPlaceholder(variable_name="messages"), - ] - ) - - prompt = prompt.partial(system_message=system_message) - prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) - prompt = prompt.partial(current_date=current_date) - prompt = prompt.partial(ticker=ticker) - - chain = prompt | llm.bind_tools(tools) - - result = chain.invoke(state["messages"]) - - report = "" - - if len(result.tool_calls) == 0: - report = result.content - - return { - "messages": [result], - "market_report": report, - } - - return market_analyst_node +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +import time +import json +from tradingagents.agents.utils.agent_utils import get_stock_data, get_indicators +from tradingagents.dataflows.config import get_config + + +def create_market_analyst(llm): + + def market_analyst_node(state): + current_date = state["trade_date"] + ticker = state["company_of_interest"] + company_name = state["company_of_interest"] + + tools = [ + get_stock_data, + get_indicators, + ] + + system_message = ( + """You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are: + +Moving Averages: +- close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals. +- close_200_sma: 200 SMA: A long-term trend benchmark. Usage: Confirm overall market trend and identify golden/death cross setups. Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries. +- close_10_ema: 10 EMA: A responsive short-term average. 
Usage: Capture quick shifts in momentum and potential entry points. Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals. + +MACD Related: +- macd: MACD: Computes momentum via differences of EMAs. Usage: Look for crossovers and divergence as signals of trend changes. Tips: Confirm with other indicators in low-volatility or sideways markets. +- macds: MACD Signal: An EMA smoothing of the MACD line. Usage: Use crossovers with the MACD line to trigger trades. Tips: Should be part of a broader strategy to avoid false positives. +- macdh: MACD Histogram: Shows the gap between the MACD line and its signal. Usage: Visualize momentum strength and spot divergence early. Tips: Can be volatile; complement with additional filters in fast-moving markets. + +Momentum Indicators: +- rsi: RSI: Measures momentum to flag overbought/oversold conditions. Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis. + +Volatility Indicators: +- boll: Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. Usage: Acts as a dynamic benchmark for price movement. Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals. +- boll_ub: Bollinger Upper Band: Typically 2 standard deviations above the middle line. Usage: Signals potential overbought conditions and breakout zones. Tips: Confirm signals with other tools; prices may ride the band in strong trends. +- boll_lb: Bollinger Lower Band: Typically 2 standard deviations below the middle line. Usage: Indicates potential oversold conditions. Tips: Use additional analysis to avoid false reversal signals. +- atr: ATR: Averages true range to measure volatility. Usage: Set stop-loss levels and adjust position sizes based on current market volatility. Tips: It's a reactive measure, so use it as part of a broader risk management strategy. 
+ +Volume-Based Indicators: +- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses. + +- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_stock_data first to retrieve the CSV that is needed to generate indicators. Then use get_indicators with the specific indicator names. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions.""" + + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful AI assistant, collaborating with other assistants." + " Use the provided tools to progress towards answering the question." + " If you are unable to fully answer, that's OK; another assistant with different tools" + " will help where you left off. Execute what you can to make progress." + " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," + " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." + " You have access to the following tools: {tool_names}.\n{system_message}" + "For your reference, the current date is {current_date}. 
The company we want to look at is {ticker}", + ), + MessagesPlaceholder(variable_name="messages"), + ] + ) + + prompt = prompt.partial(system_message=system_message) + prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) + prompt = prompt.partial(current_date=current_date) + prompt = prompt.partial(ticker=ticker) + + chain = prompt | llm.bind_tools(tools) + + result = chain.invoke(state["messages"]) + + report = "" + + if len(result.tool_calls) == 0: + report = result.content + + return { + "messages": [result], + "market_report": report, + } + + return market_analyst_node diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index 03b4fae4..1c0ffb5a 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -1,58 +1,58 @@ -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json -from tradingagents.agents.utils.agent_utils import get_news, get_global_news -from tradingagents.dataflows.config import get_config - - -def create_news_analyst(llm): - def news_analyst_node(state): - current_date = state["trade_date"] - ticker = state["company_of_interest"] - - tools = [ - get_news, - get_global_news, - ] - - system_message = ( - "You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Use the available tools: get_news(query, start_date, end_date) for company-specific or targeted news searches, and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." 
- + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" - ) - - prompt = ChatPromptTemplate.from_messages( - [ - ( - "system", - "You are a helpful AI assistant, collaborating with other assistants." - " Use the provided tools to progress towards answering the question." - " If you are unable to fully answer, that's OK; another assistant with different tools" - " will help where you left off. Execute what you can to make progress." - " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," - " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." - " You have access to the following tools: {tool_names}.\n{system_message}" - "For your reference, the current date is {current_date}. We are looking at the company {ticker}", - ), - MessagesPlaceholder(variable_name="messages"), - ] - ) - - prompt = prompt.partial(system_message=system_message) - prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) - prompt = prompt.partial(current_date=current_date) - prompt = prompt.partial(ticker=ticker) - - chain = prompt | llm.bind_tools(tools) - result = chain.invoke(state["messages"]) - - report = "" - - if len(result.tool_calls) == 0: - report = result.content - - return { - "messages": [result], - "news_report": report, - } - - return news_analyst_node +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +import time +import json +from tradingagents.agents.utils.agent_utils import get_news, get_global_news +from tradingagents.dataflows.config import get_config + + +def create_news_analyst(llm): + def news_analyst_node(state): + current_date = state["trade_date"] + ticker = state["company_of_interest"] + + tools = [ + get_news, + get_global_news, + ] + + system_message = ( + "You are a news researcher tasked with analyzing recent news and trends over the past 
week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Use the available tools: get_news(query, start_date, end_date) for company-specific or targeted news searches, and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." + + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful AI assistant, collaborating with other assistants." + " Use the provided tools to progress towards answering the question." + " If you are unable to fully answer, that's OK; another assistant with different tools" + " will help where you left off. Execute what you can to make progress." + " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," + " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." + " You have access to the following tools: {tool_names}.\n{system_message}" + "For your reference, the current date is {current_date}. 
We are looking at the company {ticker}", + ), + MessagesPlaceholder(variable_name="messages"), + ] + ) + + prompt = prompt.partial(system_message=system_message) + prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) + prompt = prompt.partial(current_date=current_date) + prompt = prompt.partial(ticker=ticker) + + chain = prompt | llm.bind_tools(tools) + result = chain.invoke(state["messages"]) + + report = "" + + if len(result.tool_calls) == 0: + report = result.content + + return { + "messages": [result], + "news_report": report, + } + + return news_analyst_node diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index b25712d7..582bf6bb 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -1,59 +1,59 @@ -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -import time -import json -from tradingagents.agents.utils.agent_utils import get_news -from tradingagents.dataflows.config import get_config - - -def create_social_media_analyst(llm): - def social_media_analyst_node(state): - current_date = state["trade_date"] - ticker = state["company_of_interest"] - company_name = state["company_of_interest"] - - tools = [ - get_news, - ] - - system_message = ( - "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. 
Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." - + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""", - ) - - prompt = ChatPromptTemplate.from_messages( - [ - ( - "system", - "You are a helpful AI assistant, collaborating with other assistants." - " Use the provided tools to progress towards answering the question." - " If you are unable to fully answer, that's OK; another assistant with different tools" - " will help where you left off. Execute what you can to make progress." - " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," - " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." - " You have access to the following tools: {tool_names}.\n{system_message}" - "For your reference, the current date is {current_date}. 
The current company we want to analyze is {ticker}", - ), - MessagesPlaceholder(variable_name="messages"), - ] - ) - - prompt = prompt.partial(system_message=system_message) - prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) - prompt = prompt.partial(current_date=current_date) - prompt = prompt.partial(ticker=ticker) - - chain = prompt | llm.bind_tools(tools) - - result = chain.invoke(state["messages"]) - - report = "" - - if len(result.tool_calls) == 0: - report = result.content - - return { - "messages": [result], - "sentiment_report": report, - } - - return social_media_analyst_node +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +import time +import json +from tradingagents.agents.utils.agent_utils import get_news +from tradingagents.dataflows.config import get_config + + +def create_social_media_analyst(llm): + def social_media_analyst_node(state): + current_date = state["trade_date"] + ticker = state["company_of_interest"] + company_name = state["company_of_interest"] + + tools = [ + get_news, + ] + + system_message = ( + "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Use the get_news(query, start_date, end_date) tool to search for company-specific news and social media discussions. Try to look at all sources possible from social media to sentiment to news. 
Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." + + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""", + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful AI assistant, collaborating with other assistants." + " Use the provided tools to progress towards answering the question." + " If you are unable to fully answer, that's OK; another assistant with different tools" + " will help where you left off. Execute what you can to make progress." + " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," + " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." + " You have access to the following tools: {tool_names}.\n{system_message}" + "For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}", + ), + MessagesPlaceholder(variable_name="messages"), + ] + ) + + prompt = prompt.partial(system_message=system_message) + prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) + prompt = prompt.partial(current_date=current_date) + prompt = prompt.partial(ticker=ticker) + + chain = prompt | llm.bind_tools(tools) + + result = chain.invoke(state["messages"]) + + report = "" + + if len(result.tool_calls) == 0: + report = result.content + + return { + "messages": [result], + "sentiment_report": report, + } + + return social_media_analyst_node diff --git a/tradingagents/agents/managers/research_manager.py b/tradingagents/agents/managers/research_manager.py index c537fa2f..029a83b5 100644 --- a/tradingagents/agents/managers/research_manager.py +++ b/tradingagents/agents/managers/research_manager.py @@ -1,55 +1,55 @@ -import time -import json - - -def create_research_manager(llm, memory): - def 
research_manager_node(state) -> dict: - history = state["investment_debate_state"].get("history", "") - market_research_report = state["market_report"] - sentiment_report = state["sentiment_report"] - news_report = state["news_report"] - fundamentals_report = state["fundamentals_report"] - - investment_debate_state = state["investment_debate_state"] - - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" - past_memories = memory.get_memories(curr_situation, n_matches=2) - - past_memory_str = "" - for i, rec in enumerate(past_memories, 1): - past_memory_str += rec["recommendation"] + "\n\n" - - prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented. - -Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendationโ€”Buy, Sell, or Holdโ€”must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments. - -Additionally, develop a detailed investment plan for the trader. This should include: - -Your Recommendation: A decisive stance supported by the most convincing arguments. -Rationale: An explanation of why these arguments lead to your conclusion. -Strategic Actions: Concrete steps for implementing the recommendation. -Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting. 
- -Here are your past reflections on mistakes: -\"{past_memory_str}\" - -Here is the debate: -Debate History: -{history}""" - response = llm.invoke(prompt) - - new_investment_debate_state = { - "judge_decision": response.content, - "history": investment_debate_state.get("history", ""), - "bear_history": investment_debate_state.get("bear_history", ""), - "bull_history": investment_debate_state.get("bull_history", ""), - "current_response": response.content, - "count": investment_debate_state["count"], - } - - return { - "investment_debate_state": new_investment_debate_state, - "investment_plan": response.content, - } - - return research_manager_node +import time +import json + + +def create_research_manager(llm, memory): + def research_manager_node(state) -> dict: + history = state["investment_debate_state"].get("history", "") + market_research_report = state["market_report"] + sentiment_report = state["sentiment_report"] + news_report = state["news_report"] + fundamentals_report = state["fundamentals_report"] + + investment_debate_state = state["investment_debate_state"] + + curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" + past_memories = memory.get_memories(curr_situation, n_matches=2) + + past_memory_str = "" + for i, rec in enumerate(past_memories, 1): + past_memory_str += rec["recommendation"] + "\n\n" + + prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented. + +Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendationโ€”Buy, Sell, or Holdโ€”must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments. 
+ +Additionally, develop a detailed investment plan for the trader. This should include: + +Your Recommendation: A decisive stance supported by the most convincing arguments. +Rationale: An explanation of why these arguments lead to your conclusion. +Strategic Actions: Concrete steps for implementing the recommendation. +Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting. + +Here are your past reflections on mistakes: +\"{past_memory_str}\" + +Here is the debate: +Debate History: +{history}""" + response = llm.invoke(prompt) + + new_investment_debate_state = { + "judge_decision": response.content, + "history": investment_debate_state.get("history", ""), + "bear_history": investment_debate_state.get("bear_history", ""), + "bull_history": investment_debate_state.get("bull_history", ""), + "current_response": response.content, + "count": investment_debate_state["count"], + } + + return { + "investment_debate_state": new_investment_debate_state, + "investment_plan": response.content, + } + + return research_manager_node diff --git a/tradingagents/agents/managers/risk_manager.py b/tradingagents/agents/managers/risk_manager.py index fba763d6..1288cb76 100644 --- a/tradingagents/agents/managers/risk_manager.py +++ b/tradingagents/agents/managers/risk_manager.py @@ -1,66 +1,66 @@ -import time -import json - - -def create_risk_manager(llm, memory): - def risk_manager_node(state) -> dict: - - company_name = state["company_of_interest"] - - history = state["risk_debate_state"]["history"] - risk_debate_state = state["risk_debate_state"] - market_research_report = state["market_report"] - news_report = state["news_report"] - fundamentals_report = state["news_report"] - sentiment_report = state["sentiment_report"] - trader_plan = state["investment_plan"] - - curr_situation = 
f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" - past_memories = memory.get_memories(curr_situation, n_matches=2) - - past_memory_str = "" - for i, rec in enumerate(past_memories, 1): - past_memory_str += rec["recommendation"] + "\n\n" - - prompt = f"""As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analystsโ€”Risky, Neutral, and Safe/Conservativeโ€”and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness. - -Guidelines for Decision-Making: -1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context. -2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate. -3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**, and adjust it based on the analysts' insights. -4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money. - -Deliverables: -- A clear and actionable recommendation: Buy, Sell, or Hold. -- Detailed reasoning anchored in the debate and past reflections. - ---- - -**Analysts Debate History:** -{history} - ---- - -Focus on actionable insights and continuous improvement. 
Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes.""" - - response = llm.invoke(prompt) - - new_risk_debate_state = { - "judge_decision": response.content, - "history": risk_debate_state["history"], - "risky_history": risk_debate_state["risky_history"], - "safe_history": risk_debate_state["safe_history"], - "neutral_history": risk_debate_state["neutral_history"], - "latest_speaker": "Judge", - "current_risky_response": risk_debate_state["current_risky_response"], - "current_safe_response": risk_debate_state["current_safe_response"], - "current_neutral_response": risk_debate_state["current_neutral_response"], - "count": risk_debate_state["count"], - } - - return { - "risk_debate_state": new_risk_debate_state, - "final_trade_decision": response.content, - } - - return risk_manager_node +import time +import json + + +def create_risk_manager(llm, memory): + def risk_manager_node(state) -> dict: + + company_name = state["company_of_interest"] + + history = state["risk_debate_state"]["history"] + risk_debate_state = state["risk_debate_state"] + market_research_report = state["market_report"] + news_report = state["news_report"] + fundamentals_report = state["news_report"] + sentiment_report = state["sentiment_report"] + trader_plan = state["investment_plan"] + + curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" + past_memories = memory.get_memories(curr_situation, n_matches=2) + + past_memory_str = "" + for i, rec in enumerate(past_memories, 1): + past_memory_str += rec["recommendation"] + "\n\n" + + prompt = f"""As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analystsโ€”Risky, Neutral, and Safe/Conservativeโ€”and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. 
Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness. + +Guidelines for Decision-Making: +1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context. +2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate. +3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**, and adjust it based on the analysts' insights. +4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money. + +Deliverables: +- A clear and actionable recommendation: Buy, Sell, or Hold. +- Detailed reasoning anchored in the debate and past reflections. + +--- + +**Analysts Debate History:** +{history} + +--- + +Focus on actionable insights and continuous improvement. 
Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes.""" + + response = llm.invoke(prompt) + + new_risk_debate_state = { + "judge_decision": response.content, + "history": risk_debate_state["history"], + "risky_history": risk_debate_state["risky_history"], + "safe_history": risk_debate_state["safe_history"], + "neutral_history": risk_debate_state["neutral_history"], + "latest_speaker": "Judge", + "current_risky_response": risk_debate_state["current_risky_response"], + "current_safe_response": risk_debate_state["current_safe_response"], + "current_neutral_response": risk_debate_state["current_neutral_response"], + "count": risk_debate_state["count"], + } + + return { + "risk_debate_state": new_risk_debate_state, + "final_trade_decision": response.content, + } + + return risk_manager_node diff --git a/tradingagents/agents/researchers/bear_researcher.py b/tradingagents/agents/researchers/bear_researcher.py index 6634490a..47c28f66 100644 --- a/tradingagents/agents/researchers/bear_researcher.py +++ b/tradingagents/agents/researchers/bear_researcher.py @@ -1,61 +1,61 @@ -from langchain_core.messages import AIMessage -import time -import json - - -def create_bear_researcher(llm, memory): - def bear_node(state) -> dict: - investment_debate_state = state["investment_debate_state"] - history = investment_debate_state.get("history", "") - bear_history = investment_debate_state.get("bear_history", "") - - current_response = investment_debate_state.get("current_response", "") - market_research_report = state["market_report"] - sentiment_report = state["sentiment_report"] - news_report = state["news_report"] - fundamentals_report = state["fundamentals_report"] - - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" - past_memories = memory.get_memories(curr_situation, n_matches=2) - - past_memory_str = "" - for i, rec in enumerate(past_memories, 1): - 
past_memory_str += rec["recommendation"] + "\n\n" - - prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively. - -Key points to focus on: - -- Risks and Challenges: Highlight factors like market saturation, financial instability, or macroeconomic threats that could hinder the stock's performance. -- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation, or threats from competitors. -- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position. -- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning, exposing weaknesses or over-optimistic assumptions. -- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points and debating effectively rather than simply listing facts. - -Resources available: - -Market research report: {market_research_report} -Social media sentiment report: {sentiment_report} -Latest world affairs news: {news_report} -Company fundamentals report: {fundamentals_report} -Conversation history of the debate: {history} -Last bull argument: {current_response} -Reflections from similar situations and lessons learned: {past_memory_str} -Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past. 
-""" - - response = llm.invoke(prompt) - - argument = f"Bear Analyst: {response.content}" - - new_investment_debate_state = { - "history": history + "\n" + argument, - "bear_history": bear_history + "\n" + argument, - "bull_history": investment_debate_state.get("bull_history", ""), - "current_response": argument, - "count": investment_debate_state["count"] + 1, - } - - return {"investment_debate_state": new_investment_debate_state} - - return bear_node +from langchain_core.messages import AIMessage +import time +import json + + +def create_bear_researcher(llm, memory): + def bear_node(state) -> dict: + investment_debate_state = state["investment_debate_state"] + history = investment_debate_state.get("history", "") + bear_history = investment_debate_state.get("bear_history", "") + + current_response = investment_debate_state.get("current_response", "") + market_research_report = state["market_report"] + sentiment_report = state["sentiment_report"] + news_report = state["news_report"] + fundamentals_report = state["fundamentals_report"] + + curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" + past_memories = memory.get_memories(curr_situation, n_matches=2) + + past_memory_str = "" + for i, rec in enumerate(past_memories, 1): + past_memory_str += rec["recommendation"] + "\n\n" + + prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively. + +Key points to focus on: + +- Risks and Challenges: Highlight factors like market saturation, financial instability, or macroeconomic threats that could hinder the stock's performance. +- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation, or threats from competitors. 
+- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position. +- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning, exposing weaknesses or over-optimistic assumptions. +- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points and debating effectively rather than simply listing facts. + +Resources available: + +Market research report: {market_research_report} +Social media sentiment report: {sentiment_report} +Latest world affairs news: {news_report} +Company fundamentals report: {fundamentals_report} +Conversation history of the debate: {history} +Last bull argument: {current_response} +Reflections from similar situations and lessons learned: {past_memory_str} +Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past. 
+""" + + response = llm.invoke(prompt) + + argument = f"Bear Analyst: {response.content}" + + new_investment_debate_state = { + "history": history + "\n" + argument, + "bear_history": bear_history + "\n" + argument, + "bull_history": investment_debate_state.get("bull_history", ""), + "current_response": argument, + "count": investment_debate_state["count"] + 1, + } + + return {"investment_debate_state": new_investment_debate_state} + + return bear_node diff --git a/tradingagents/agents/researchers/bull_researcher.py b/tradingagents/agents/researchers/bull_researcher.py index b03ef755..3be2e4a1 100644 --- a/tradingagents/agents/researchers/bull_researcher.py +++ b/tradingagents/agents/researchers/bull_researcher.py @@ -1,59 +1,59 @@ -from langchain_core.messages import AIMessage -import time -import json - - -def create_bull_researcher(llm, memory): - def bull_node(state) -> dict: - investment_debate_state = state["investment_debate_state"] - history = investment_debate_state.get("history", "") - bull_history = investment_debate_state.get("bull_history", "") - - current_response = investment_debate_state.get("current_response", "") - market_research_report = state["market_report"] - sentiment_report = state["sentiment_report"] - news_report = state["news_report"] - fundamentals_report = state["fundamentals_report"] - - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" - past_memories = memory.get_memories(curr_situation, n_matches=2) - - past_memory_str = "" - for i, rec in enumerate(past_memories, 1): - past_memory_str += rec["recommendation"] + "\n\n" - - prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively. 
- -Key points to focus on: -- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability. -- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning. -- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence. -- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing concerns thoroughly and showing why the bull perspective holds stronger merit. -- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points and debating effectively rather than just listing data. - -Resources available: -Market research report: {market_research_report} -Social media sentiment report: {sentiment_report} -Latest world affairs news: {news_report} -Company fundamentals report: {fundamentals_report} -Conversation history of the debate: {history} -Last bear argument: {current_response} -Reflections from similar situations and lessons learned: {past_memory_str} -Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past. 
-""" - - response = llm.invoke(prompt) - - argument = f"Bull Analyst: {response.content}" - - new_investment_debate_state = { - "history": history + "\n" + argument, - "bull_history": bull_history + "\n" + argument, - "bear_history": investment_debate_state.get("bear_history", ""), - "current_response": argument, - "count": investment_debate_state["count"] + 1, - } - - return {"investment_debate_state": new_investment_debate_state} - - return bull_node +from langchain_core.messages import AIMessage +import time +import json + + +def create_bull_researcher(llm, memory): + def bull_node(state) -> dict: + investment_debate_state = state["investment_debate_state"] + history = investment_debate_state.get("history", "") + bull_history = investment_debate_state.get("bull_history", "") + + current_response = investment_debate_state.get("current_response", "") + market_research_report = state["market_report"] + sentiment_report = state["sentiment_report"] + news_report = state["news_report"] + fundamentals_report = state["fundamentals_report"] + + curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" + past_memories = memory.get_memories(curr_situation, n_matches=2) + + past_memory_str = "" + for i, rec in enumerate(past_memories, 1): + past_memory_str += rec["recommendation"] + "\n\n" + + prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively. + +Key points to focus on: +- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability. +- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning. 
+- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence. +- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing concerns thoroughly and showing why the bull perspective holds stronger merit. +- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points and debating effectively rather than just listing data. + +Resources available: +Market research report: {market_research_report} +Social media sentiment report: {sentiment_report} +Latest world affairs news: {news_report} +Company fundamentals report: {fundamentals_report} +Conversation history of the debate: {history} +Last bear argument: {current_response} +Reflections from similar situations and lessons learned: {past_memory_str} +Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past. 
+""" + + response = llm.invoke(prompt) + + argument = f"Bull Analyst: {response.content}" + + new_investment_debate_state = { + "history": history + "\n" + argument, + "bull_history": bull_history + "\n" + argument, + "bear_history": investment_debate_state.get("bear_history", ""), + "current_response": argument, + "count": investment_debate_state["count"] + 1, + } + + return {"investment_debate_state": new_investment_debate_state} + + return bull_node diff --git a/tradingagents/agents/risk_mgmt/aggresive_debator.py b/tradingagents/agents/risk_mgmt/aggresive_debator.py index 7e2b4937..d9aa34cb 100644 --- a/tradingagents/agents/risk_mgmt/aggresive_debator.py +++ b/tradingagents/agents/risk_mgmt/aggresive_debator.py @@ -1,55 +1,55 @@ -import time -import json - - -def create_risky_debator(llm): - def risky_node(state) -> dict: - risk_debate_state = state["risk_debate_state"] - history = risk_debate_state.get("history", "") - risky_history = risk_debate_state.get("risky_history", "") - - current_safe_response = risk_debate_state.get("current_safe_response", "") - current_neutral_response = risk_debate_state.get("current_neutral_response", "") - - market_research_report = state["market_report"] - sentiment_report = state["sentiment_report"] - news_report = state["news_report"] - fundamentals_report = state["fundamentals_report"] - - trader_decision = state["trader_investment_plan"] - - prompt = f"""As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefitsโ€”even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. 
Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision: - -{trader_decision} - -Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative and neutral stances to demonstrate why your high-reward perspective offers the best path forward. Incorporate insights from the following sources into your arguments: - -Market Research Report: {market_research_report} -Social Media Sentiment Report: {sentiment_report} -Latest World Affairs Report: {news_report} -Company Fundamentals Report: {fundamentals_report} -Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point. - -Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. 
Output conversationally as if you are speaking without any special formatting.""" - - response = llm.invoke(prompt) - - argument = f"Risky Analyst: {response.content}" - - new_risk_debate_state = { - "history": history + "\n" + argument, - "risky_history": risky_history + "\n" + argument, - "safe_history": risk_debate_state.get("safe_history", ""), - "neutral_history": risk_debate_state.get("neutral_history", ""), - "latest_speaker": "Risky", - "current_risky_response": argument, - "current_safe_response": risk_debate_state.get("current_safe_response", ""), - "current_neutral_response": risk_debate_state.get( - "current_neutral_response", "" - ), - "count": risk_debate_state["count"] + 1, - } - - return {"risk_debate_state": new_risk_debate_state} - - return risky_node +import time +import json + + +def create_risky_debator(llm): + def risky_node(state) -> dict: + risk_debate_state = state["risk_debate_state"] + history = risk_debate_state.get("history", "") + risky_history = risk_debate_state.get("risky_history", "") + + current_safe_response = risk_debate_state.get("current_safe_response", "") + current_neutral_response = risk_debate_state.get("current_neutral_response", "") + + market_research_report = state["market_report"] + sentiment_report = state["sentiment_report"] + news_report = state["news_report"] + fundamentals_report = state["fundamentals_report"] + + trader_decision = state["trader_investment_plan"] + + prompt = f"""As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefitsโ€”even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. 
Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision: +
+{trader_decision}
+
+Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative and neutral stances to demonstrate why your high-reward perspective offers the best path forward. Incorporate insights from the following sources into your arguments: +
+Market Research Report: {market_research_report}
+Social Media Sentiment Report: {sentiment_report}
+Latest World Affairs Report: {news_report}
+Company Fundamentals Report: {fundamentals_report}
+Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not hallucinate and just present your point.
+
+Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. 
Output conversationally as if you are speaking without any special formatting.""" + + response = llm.invoke(prompt) + + argument = f"Risky Analyst: {response.content}" + + new_risk_debate_state = { + "history": history + "\n" + argument, + "risky_history": risky_history + "\n" + argument, + "safe_history": risk_debate_state.get("safe_history", ""), + "neutral_history": risk_debate_state.get("neutral_history", ""), + "latest_speaker": "Risky", + "current_risky_response": argument, + "current_safe_response": risk_debate_state.get("current_safe_response", ""), + "current_neutral_response": risk_debate_state.get( + "current_neutral_response", "" + ), + "count": risk_debate_state["count"] + 1, + } + + return {"risk_debate_state": new_risk_debate_state} + + return risky_node diff --git a/tradingagents/agents/risk_mgmt/conservative_debator.py b/tradingagents/agents/risk_mgmt/conservative_debator.py index c56e16ad..0627a55a 100644 --- a/tradingagents/agents/risk_mgmt/conservative_debator.py +++ b/tradingagents/agents/risk_mgmt/conservative_debator.py @@ -1,58 +1,58 @@ -from langchain_core.messages import AIMessage -import time -import json - - -def create_safe_debator(llm): - def safe_node(state) -> dict: - risk_debate_state = state["risk_debate_state"] - history = risk_debate_state.get("history", "") - safe_history = risk_debate_state.get("safe_history", "") - - current_risky_response = risk_debate_state.get("current_risky_response", "") - current_neutral_response = risk_debate_state.get("current_neutral_response", "") - - market_research_report = state["market_report"] - sentiment_report = state["sentiment_report"] - news_report = state["news_report"] - fundamentals_report = state["fundamentals_report"] - - trader_decision = state["trader_investment_plan"] - - prompt = f"""As the Safe/Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. 
You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision: - -{trader_decision} - -Your task is to actively counter the arguments of the Risky and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision: - -Market Research Report: {market_research_report} -Social Media Sentiment Report: {sentiment_report} -Latest World Affairs Report: {news_report} -Company Fundamentals Report: {fundamentals_report} -Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point. - -Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. 
Output conversationally as if you are speaking without any special formatting.""" - - response = llm.invoke(prompt) - - argument = f"Safe Analyst: {response.content}" - - new_risk_debate_state = { - "history": history + "\n" + argument, - "risky_history": risk_debate_state.get("risky_history", ""), - "safe_history": safe_history + "\n" + argument, - "neutral_history": risk_debate_state.get("neutral_history", ""), - "latest_speaker": "Safe", - "current_risky_response": risk_debate_state.get( - "current_risky_response", "" - ), - "current_safe_response": argument, - "current_neutral_response": risk_debate_state.get( - "current_neutral_response", "" - ), - "count": risk_debate_state["count"] + 1, - } - - return {"risk_debate_state": new_risk_debate_state} - - return safe_node +from langchain_core.messages import AIMessage +import time +import json + + +def create_safe_debator(llm): + def safe_node(state) -> dict: + risk_debate_state = state["risk_debate_state"] + history = risk_debate_state.get("history", "") + safe_history = risk_debate_state.get("safe_history", "") + + current_risky_response = risk_debate_state.get("current_risky_response", "") + current_neutral_response = risk_debate_state.get("current_neutral_response", "") + + market_research_report = state["market_report"] + sentiment_report = state["sentiment_report"] + news_report = state["news_report"] + fundamentals_report = state["fundamentals_report"] + + trader_decision = state["trader_investment_plan"] + + prompt = f"""As the Safe/Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. 
When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision: +
+{trader_decision}
+
+Your task is to actively counter the arguments of the Risky and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision: +
+Market Research Report: {market_research_report}
+Social Media Sentiment Report: {sentiment_report}
+Latest World Affairs Report: {news_report}
+Company Fundamentals Report: {fundamentals_report}
+Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not hallucinate and just present your point.
+
+Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. 
Output conversationally as if you are speaking without any special formatting.""" + + response = llm.invoke(prompt) + + argument = f"Safe Analyst: {response.content}" + + new_risk_debate_state = { + "history": history + "\n" + argument, + "risky_history": risk_debate_state.get("risky_history", ""), + "safe_history": safe_history + "\n" + argument, + "neutral_history": risk_debate_state.get("neutral_history", ""), + "latest_speaker": "Safe", + "current_risky_response": risk_debate_state.get( + "current_risky_response", "" + ), + "current_safe_response": argument, + "current_neutral_response": risk_debate_state.get( + "current_neutral_response", "" + ), + "count": risk_debate_state["count"] + 1, + } + + return {"risk_debate_state": new_risk_debate_state} + + return safe_node diff --git a/tradingagents/agents/risk_mgmt/neutral_debator.py b/tradingagents/agents/risk_mgmt/neutral_debator.py index a6d2ef5c..aacf297b 100644 --- a/tradingagents/agents/risk_mgmt/neutral_debator.py +++ b/tradingagents/agents/risk_mgmt/neutral_debator.py @@ -1,55 +1,55 @@ -import time -import json - - -def create_neutral_debator(llm): - def neutral_node(state) -> dict: - risk_debate_state = state["risk_debate_state"] - history = risk_debate_state.get("history", "") - neutral_history = risk_debate_state.get("neutral_history", "") - - current_risky_response = risk_debate_state.get("current_risky_response", "") - current_safe_response = risk_debate_state.get("current_safe_response", "") - - market_research_report = state["market_report"] - sentiment_report = state["sentiment_report"] - news_report = state["news_report"] - fundamentals_report = state["fundamentals_report"] - - trader_decision = state["trader_investment_plan"] - - prompt = f"""As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. 
You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies.Here is the trader's decision: - -{trader_decision} - -Your task is to challenge both the Risky and Safe Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision: - -Market Research Report: {market_research_report} -Social Media Sentiment Report: {sentiment_report} -Latest World Affairs Report: {news_report} -Company Fundamentals Report: {fundamentals_report} -Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point. - -Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. 
Output conversationally as if you are speaking without any special formatting.""" - - response = llm.invoke(prompt) - - argument = f"Neutral Analyst: {response.content}" - - new_risk_debate_state = { - "history": history + "\n" + argument, - "risky_history": risk_debate_state.get("risky_history", ""), - "safe_history": risk_debate_state.get("safe_history", ""), - "neutral_history": neutral_history + "\n" + argument, - "latest_speaker": "Neutral", - "current_risky_response": risk_debate_state.get( - "current_risky_response", "" - ), - "current_safe_response": risk_debate_state.get("current_safe_response", ""), - "current_neutral_response": argument, - "count": risk_debate_state["count"] + 1, - } - - return {"risk_debate_state": new_risk_debate_state} - - return neutral_node +import time +import json + + +def create_neutral_debator(llm): + def neutral_node(state) -> dict: + risk_debate_state = state["risk_debate_state"] + history = risk_debate_state.get("history", "") + neutral_history = risk_debate_state.get("neutral_history", "") + + current_risky_response = risk_debate_state.get("current_risky_response", "") + current_safe_response = risk_debate_state.get("current_safe_response", "") + + market_research_report = state["market_report"] + sentiment_report = state["sentiment_report"] + news_report = state["news_report"] + fundamentals_report = state["fundamentals_report"] + + trader_decision = state["trader_investment_plan"] + + prompt = f"""As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies.Here is the trader's decision: + +{trader_decision} + +Your task is to challenge both the Risky and Safe Analysts, pointing out where each perspective may be overly optimistic or overly cautious. 
Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision: +
+Market Research Report: {market_research_report}
+Social Media Sentiment Report: {sentiment_report}
+Latest World Affairs Report: {news_report}
+Company Fundamentals Report: {fundamentals_report}
+Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not hallucinate and just present your point.
+
+Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. 
Output conversationally as if you are speaking without any special formatting.""" + + response = llm.invoke(prompt) + + argument = f"Neutral Analyst: {response.content}" + + new_risk_debate_state = { + "history": history + "\n" + argument, + "risky_history": risk_debate_state.get("risky_history", ""), + "safe_history": risk_debate_state.get("safe_history", ""), + "neutral_history": neutral_history + "\n" + argument, + "latest_speaker": "Neutral", + "current_risky_response": risk_debate_state.get( + "current_risky_response", "" + ), + "current_safe_response": risk_debate_state.get("current_safe_response", ""), + "current_neutral_response": argument, + "count": risk_debate_state["count"] + 1, + } + + return {"risk_debate_state": new_risk_debate_state} + + return neutral_node diff --git a/tradingagents/agents/screening_agent.py b/tradingagents/agents/screening_agent.py new file mode 100644 index 00000000..cf3f821e --- /dev/null +++ b/tradingagents/agents/screening_agent.py @@ -0,0 +1,63 @@ +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from tradingagents.agents.utils.agent_utils import ( + get_market_movers, + get_earnings_calendar, + get_insider_transactions, + get_indicators, + get_trending_social +) + +def create_screening_agent(llm): + """ + Creates a screening agent that identifies potential stocks to analyze. + """ + def screening_agent_node(state): + # Tools available to the screening agent + tools = [get_market_movers, get_earnings_calendar, get_insider_transactions, get_indicators, get_trending_social] + + system_message = ( + "You are a Market Screening Agent. Your goal is to identify 'Hidden Gem' stocks before they make a massive move." + " Do NOT just recommend stocks that have already risen 50%+ (unless there is a fresh catalyst)." + " Use a multi-factor approach:" + " 1. **Scan**: Use `get_market_movers` to find 'Most Active' or 'Top Losers' (potential reversals). 
Avoid chasing 'Top Gainers' if they are already up significantly." + " 2. **Social Hype**: Use `get_trending_social` to find stocks buzzing on Reddit/StockTwits. High chatter + low price movement = potential breakout." + " 3. **Catalyst**: Use `get_earnings_calendar` to find upcoming earnings." + " 4. **Smart Money**: Use `get_insider_transactions` on interesting tickers. If insiders are buying, it's a strong signal." + " 5. **Technicals**: Use `get_indicators` (RSI, MACD) to check if a stock is Oversold (RSI < 30) or showing divergence." + " \n" + " **Strategy**: Look for stocks that are active but haven't spiked yet, or are beaten down (Losers) with insider buying." + " Analyze the data and recommend 1-3 tickers." + " Return your recommendations as a comma-separated list of tickers in the final response, e.g., 'NVDA, TSLA, AAPL'." + " Do not include any other text in the final line, just the tickers." + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are a helpful AI assistant." 
+ " You have access to the following tools: {tool_names}.\n{system_message}", + ), + MessagesPlaceholder(variable_name="messages"), + ] + ) + + prompt = prompt.partial(system_message=system_message) + prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) + + # Bind tools to the LLM + chain = prompt | llm.bind_tools(tools) + + # Invoke the chain + # For screening, we might not have a full conversation history yet, so we can start with a user request + if not state.get("messages"): + state["messages"] = [("user", "Please screen the market and find interesting stocks.")] + + result = chain.invoke(state["messages"]) + + return { + "messages": [result], + # We could parse the result here and put it in a specific state key if needed + } + + return screening_agent_node diff --git a/tradingagents/agents/trader/trader.py b/tradingagents/agents/trader/trader.py index 1b05c35d..6291dbef 100644 --- a/tradingagents/agents/trader/trader.py +++ b/tradingagents/agents/trader/trader.py @@ -1,46 +1,46 @@ -import functools -import time -import json - - -def create_trader(llm, memory): - def trader_node(state, name): - company_name = state["company_of_interest"] - investment_plan = state["investment_plan"] - market_research_report = state["market_report"] - sentiment_report = state["sentiment_report"] - news_report = state["news_report"] - fundamentals_report = state["fundamentals_report"] - - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" - past_memories = memory.get_memories(curr_situation, n_matches=2) - - past_memory_str = "" - if past_memories: - for i, rec in enumerate(past_memories, 1): - past_memory_str += rec["recommendation"] + "\n\n" - else: - past_memory_str = "No past memories found." - - context = { - "role": "user", - "content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. 
This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.", - } - - messages = [ - { - "role": "system", - "content": f"""You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situatiosn you traded in and the lessons learned: {past_memory_str}""", - }, - context, - ] - - result = llm.invoke(messages) - - return { - "messages": [result], - "trader_investment_plan": result.content, - "sender": name, - } - - return functools.partial(trader_node, name="Trader") +import functools +import time +import json + + +def create_trader(llm, memory): + def trader_node(state, name): + company_name = state["company_of_interest"] + investment_plan = state["investment_plan"] + market_research_report = state["market_report"] + sentiment_report = state["sentiment_report"] + news_report = state["news_report"] + fundamentals_report = state["fundamentals_report"] + + curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" + past_memories = memory.get_memories(curr_situation, n_matches=2) + + past_memory_str = "" + if past_memories: + for i, rec in enumerate(past_memories, 1): + past_memory_str += rec["recommendation"] + "\n\n" + else: + past_memory_str = "No past memories found." 
+
+        context = {
+            "role": "user",
+            "content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
+        }
+
+        messages = [
+            {
+                "role": "system",
+                "content": f"""You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here are some reflections from similar situations you traded in and the lessons learned: {past_memory_str}""",
+            },
+            context,
+        ]
+
+        result = llm.invoke(messages)
+
+        return {
+            "messages": [result],
+            "trader_investment_plan": result.content,
+            "sender": name,
+        }
+
+    return functools.partial(trader_node, name="Trader")
diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py
index 3a859ea1..f347febd 100644
--- a/tradingagents/agents/utils/agent_states.py
+++ b/tradingagents/agents/utils/agent_states.py
@@ -1,76 +1,76 @@
-from typing import Annotated, Sequence
-from datetime import date, timedelta, datetime
-from typing_extensions import TypedDict, Optional
-from langchain_openai import ChatOpenAI
-from tradingagents.agents import *
-from langgraph.prebuilt import ToolNode
-from langgraph.graph import END, StateGraph, START, MessagesState
-
-
-# Researcher team state
-class InvestDebateState(TypedDict):
-    bull_history: Annotated[
-        str, "Bullish Conversation history"
-    ]  # Bullish Conversation history
-    
bear_history: Annotated[ - str, "Bearish Conversation history" - ] # Bullish Conversation history - history: Annotated[str, "Conversation history"] # Conversation history - current_response: Annotated[str, "Latest response"] # Last response - judge_decision: Annotated[str, "Final judge decision"] # Last response - count: Annotated[int, "Length of the current conversation"] # Conversation length - - -# Risk management team state -class RiskDebateState(TypedDict): - risky_history: Annotated[ - str, "Risky Agent's Conversation history" - ] # Conversation history - safe_history: Annotated[ - str, "Safe Agent's Conversation history" - ] # Conversation history - neutral_history: Annotated[ - str, "Neutral Agent's Conversation history" - ] # Conversation history - history: Annotated[str, "Conversation history"] # Conversation history - latest_speaker: Annotated[str, "Analyst that spoke last"] - current_risky_response: Annotated[ - str, "Latest response by the risky analyst" - ] # Last response - current_safe_response: Annotated[ - str, "Latest response by the safe analyst" - ] # Last response - current_neutral_response: Annotated[ - str, "Latest response by the neutral analyst" - ] # Last response - judge_decision: Annotated[str, "Judge's decision"] - count: Annotated[int, "Length of the current conversation"] # Conversation length - - -class AgentState(MessagesState): - company_of_interest: Annotated[str, "Company that we are interested in trading"] - trade_date: Annotated[str, "What date we are trading at"] - - sender: Annotated[str, "Agent that sent this message"] - - # research step - market_report: Annotated[str, "Report from the Market Analyst"] - sentiment_report: Annotated[str, "Report from the Social Media Analyst"] - news_report: Annotated[ - str, "Report from the News Researcher of current world affairs" - ] - fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"] - - # researcher team discussion step - investment_debate_state: 
Annotated[ - InvestDebateState, "Current state of the debate on if to invest or not" - ] - investment_plan: Annotated[str, "Plan generated by the Analyst"] - - trader_investment_plan: Annotated[str, "Plan generated by the Trader"] - - # risk management team discussion step - risk_debate_state: Annotated[ - RiskDebateState, "Current state of the debate on evaluating risk" - ] - final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"] +from typing import Annotated, Sequence +from datetime import date, timedelta, datetime +from typing_extensions import TypedDict, Optional +from langchain_openai import ChatOpenAI +from tradingagents.agents import * +from langgraph.prebuilt import ToolNode +from langgraph.graph import END, StateGraph, START, MessagesState + + +# Researcher team state +class InvestDebateState(TypedDict): + bull_history: Annotated[ + str, "Bullish Conversation history" + ] # Bullish Conversation history + bear_history: Annotated[ + str, "Bearish Conversation history" + ] # Bullish Conversation history + history: Annotated[str, "Conversation history"] # Conversation history + current_response: Annotated[str, "Latest response"] # Last response + judge_decision: Annotated[str, "Final judge decision"] # Last response + count: Annotated[int, "Length of the current conversation"] # Conversation length + + +# Risk management team state +class RiskDebateState(TypedDict): + risky_history: Annotated[ + str, "Risky Agent's Conversation history" + ] # Conversation history + safe_history: Annotated[ + str, "Safe Agent's Conversation history" + ] # Conversation history + neutral_history: Annotated[ + str, "Neutral Agent's Conversation history" + ] # Conversation history + history: Annotated[str, "Conversation history"] # Conversation history + latest_speaker: Annotated[str, "Analyst that spoke last"] + current_risky_response: Annotated[ + str, "Latest response by the risky analyst" + ] # Last response + current_safe_response: Annotated[ + str, 
"Latest response by the safe analyst" + ] # Last response + current_neutral_response: Annotated[ + str, "Latest response by the neutral analyst" + ] # Last response + judge_decision: Annotated[str, "Judge's decision"] + count: Annotated[int, "Length of the current conversation"] # Conversation length + + +class AgentState(MessagesState): + company_of_interest: Annotated[str, "Company that we are interested in trading"] + trade_date: Annotated[str, "What date we are trading at"] + + sender: Annotated[str, "Agent that sent this message"] + + # research step + market_report: Annotated[str, "Report from the Market Analyst"] + sentiment_report: Annotated[str, "Report from the Social Media Analyst"] + news_report: Annotated[ + str, "Report from the News Researcher of current world affairs" + ] + fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"] + + # researcher team discussion step + investment_debate_state: Annotated[ + InvestDebateState, "Current state of the debate on if to invest or not" + ] + investment_plan: Annotated[str, "Plan generated by the Analyst"] + + trader_investment_plan: Annotated[str, "Plan generated by the Trader"] + + # risk management team discussion step + risk_debate_state: Annotated[ + RiskDebateState, "Current state of the debate on evaluating risk" + ] + final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"] diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index 6cf294a1..ca6ffbe0 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -1,39 +1,39 @@ -from langchain_core.messages import HumanMessage, RemoveMessage - -# Import tools from separate utility files -from tradingagents.agents.utils.core_stock_tools import ( - get_stock_data -) -from tradingagents.agents.utils.technical_indicators_tools import ( - get_indicators -) -from tradingagents.agents.utils.fundamental_data_tools import ( - 
get_fundamentals, - get_balance_sheet, - get_cashflow, - get_income_statement -) -from tradingagents.agents.utils.news_data_tools import ( - get_news, - get_insider_sentiment, - get_insider_transactions, - get_global_news -) - -def create_msg_delete(): - def delete_messages(state): - """Clear messages and add placeholder for Anthropic compatibility""" - messages = state["messages"] - - # Remove all messages - removal_operations = [RemoveMessage(id=m.id) for m in messages] - - # Add a minimal placeholder message - placeholder = HumanMessage(content="Continue") - - return {"messages": removal_operations + [placeholder]} - - return delete_messages - - - \ No newline at end of file +from langchain_core.messages import HumanMessage, RemoveMessage + +# Import tools from separate utility files +from tradingagents.agents.utils.core_stock_tools import ( + get_stock_data, + get_market_movers, + get_earnings_calendar +) +from tradingagents.agents.utils.technical_indicators_tools import ( + get_indicators +) +from tradingagents.agents.utils.fundamental_data_tools import ( + get_fundamentals, + get_balance_sheet, + get_cashflow, + get_income_statement +) +from tradingagents.agents.utils.news_data_tools import ( + get_news, + get_insider_sentiment, + get_insider_transactions, + get_global_news +) +from tradingagents.dataflows.social_sentiment import get_trending_social + +def create_msg_delete(): + def delete_messages(state): + """Clear messages and add placeholder for Anthropic compatibility""" + messages = state["messages"] + + # Remove all messages + removal_operations = [RemoveMessage(id=m.id) for m in messages] + + # Add a minimal placeholder message + placeholder = HumanMessage(content="Continue") + + return {"messages": removal_operations + [placeholder]} + + return delete_messages \ No newline at end of file diff --git a/tradingagents/agents/utils/core_stock_tools.py b/tradingagents/agents/utils/core_stock_tools.py index 3a416622..2b4c8ba2 100644 --- 
a/tradingagents/agents/utils/core_stock_tools.py +++ b/tradingagents/agents/utils/core_stock_tools.py @@ -1,22 +1,42 @@ -from langchain_core.tools import tool -from typing import Annotated -from tradingagents.dataflows.interface import route_to_vendor - - -@tool -def get_stock_data( - symbol: Annotated[str, "ticker symbol of the company"], - start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - end_date: Annotated[str, "End date in yyyy-mm-dd format"], -) -> str: - """ - Retrieve stock price data (OHLCV) for a given ticker symbol. - Uses the configured core_stock_apis vendor. - Args: - symbol (str): Ticker symbol of the company, e.g. AAPL, TSM - start_date (str): Start date in yyyy-mm-dd format - end_date (str): End date in yyyy-mm-dd format - Returns: - str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range. - """ - return route_to_vendor("get_stock_data", symbol, start_date, end_date) +from langchain_core.tools import tool +from typing import Annotated +from tradingagents.dataflows.interface import route_to_vendor + + +@tool +def get_stock_data( + symbol: Annotated[str, "ticker symbol of the company"], + start_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], +) -> str: + """ + Retrieve stock price data (OHLCV) for a given ticker symbol. + Uses the configured core_stock_apis vendor. + Args: + symbol (str): Ticker symbol of the company, e.g. AAPL, TSM + start_date (str): Start date in yyyy-mm-dd format + end_date (str): End date in yyyy-mm-dd format + Returns: + str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range. 
+ """ + return route_to_vendor("get_stock_data", symbol, start_date, end_date) + +@tool +def get_market_movers( + metric: Annotated[str, "One of 'top_gainers', 'top_losers', 'most_actively_traded'"] = "top_gainers", + limit: Annotated[int, "Number of results to return"] = 10, +) -> str: + """ + Retrieve top market movers (gainers, losers, active) to identify potential stocks to analyze. + """ + return route_to_vendor("get_market_movers", metric, limit) + +@tool +def get_earnings_calendar( + horizon: Annotated[str, "Time horizon: '3month', '6month', '12month'"] = "3month", + symbol: Annotated[str, "Optional ticker symbol to filter by"] = None, +) -> str: + """ + Retrieve earnings calendar to identify companies reporting soon (potential volatility). + """ + return route_to_vendor("get_earnings_calendar", horizon, symbol) diff --git a/tradingagents/agents/utils/fundamental_data_tools.py b/tradingagents/agents/utils/fundamental_data_tools.py index 47f6f2eb..97783614 100644 --- a/tradingagents/agents/utils/fundamental_data_tools.py +++ b/tradingagents/agents/utils/fundamental_data_tools.py @@ -1,77 +1,77 @@ -from langchain_core.tools import tool -from typing import Annotated -from tradingagents.dataflows.interface import route_to_vendor - - -@tool -def get_fundamentals( - ticker: Annotated[str, "ticker symbol"], - curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], -) -> str: - """ - Retrieve comprehensive fundamental data for a given ticker symbol. - Uses the configured fundamental_data vendor. 
- Args: - ticker (str): Ticker symbol of the company - curr_date (str): Current date you are trading at, yyyy-mm-dd - Returns: - str: A formatted report containing comprehensive fundamental data - """ - return route_to_vendor("get_fundamentals", ticker, curr_date) - - -@tool -def get_balance_sheet( - ticker: Annotated[str, "ticker symbol"], - freq: Annotated[str, "reporting frequency: annual/quarterly"] = "quarterly", - curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"] = None, -) -> str: - """ - Retrieve balance sheet data for a given ticker symbol. - Uses the configured fundamental_data vendor. - Args: - ticker (str): Ticker symbol of the company - freq (str): Reporting frequency: annual/quarterly (default quarterly) - curr_date (str): Current date you are trading at, yyyy-mm-dd - Returns: - str: A formatted report containing balance sheet data - """ - return route_to_vendor("get_balance_sheet", ticker, freq, curr_date) - - -@tool -def get_cashflow( - ticker: Annotated[str, "ticker symbol"], - freq: Annotated[str, "reporting frequency: annual/quarterly"] = "quarterly", - curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"] = None, -) -> str: - """ - Retrieve cash flow statement data for a given ticker symbol. - Uses the configured fundamental_data vendor. - Args: - ticker (str): Ticker symbol of the company - freq (str): Reporting frequency: annual/quarterly (default quarterly) - curr_date (str): Current date you are trading at, yyyy-mm-dd - Returns: - str: A formatted report containing cash flow statement data - """ - return route_to_vendor("get_cashflow", ticker, freq, curr_date) - - -@tool -def get_income_statement( - ticker: Annotated[str, "ticker symbol"], - freq: Annotated[str, "reporting frequency: annual/quarterly"] = "quarterly", - curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"] = None, -) -> str: - """ - Retrieve income statement data for a given ticker symbol. 
- Uses the configured fundamental_data vendor. - Args: - ticker (str): Ticker symbol of the company - freq (str): Reporting frequency: annual/quarterly (default quarterly) - curr_date (str): Current date you are trading at, yyyy-mm-dd - Returns: - str: A formatted report containing income statement data - """ +from langchain_core.tools import tool +from typing import Annotated +from tradingagents.dataflows.interface import route_to_vendor + + +@tool +def get_fundamentals( + ticker: Annotated[str, "ticker symbol"], + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], +) -> str: + """ + Retrieve comprehensive fundamental data for a given ticker symbol. + Uses the configured fundamental_data vendor. + Args: + ticker (str): Ticker symbol of the company + curr_date (str): Current date you are trading at, yyyy-mm-dd + Returns: + str: A formatted report containing comprehensive fundamental data + """ + return route_to_vendor("get_fundamentals", ticker, curr_date) + + +@tool +def get_balance_sheet( + ticker: Annotated[str, "ticker symbol"], + freq: Annotated[str, "reporting frequency: annual/quarterly"] = "quarterly", + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"] = None, +) -> str: + """ + Retrieve balance sheet data for a given ticker symbol. + Uses the configured fundamental_data vendor. 
+ Args: + ticker (str): Ticker symbol of the company + freq (str): Reporting frequency: annual/quarterly (default quarterly) + curr_date (str): Current date you are trading at, yyyy-mm-dd + Returns: + str: A formatted report containing balance sheet data + """ + return route_to_vendor("get_balance_sheet", ticker, freq, curr_date) + + +@tool +def get_cashflow( + ticker: Annotated[str, "ticker symbol"], + freq: Annotated[str, "reporting frequency: annual/quarterly"] = "quarterly", + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"] = None, +) -> str: + """ + Retrieve cash flow statement data for a given ticker symbol. + Uses the configured fundamental_data vendor. + Args: + ticker (str): Ticker symbol of the company + freq (str): Reporting frequency: annual/quarterly (default quarterly) + curr_date (str): Current date you are trading at, yyyy-mm-dd + Returns: + str: A formatted report containing cash flow statement data + """ + return route_to_vendor("get_cashflow", ticker, freq, curr_date) + + +@tool +def get_income_statement( + ticker: Annotated[str, "ticker symbol"], + freq: Annotated[str, "reporting frequency: annual/quarterly"] = "quarterly", + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"] = None, +) -> str: + """ + Retrieve income statement data for a given ticker symbol. + Uses the configured fundamental_data vendor. 
+ Args: + ticker (str): Ticker symbol of the company + freq (str): Reporting frequency: annual/quarterly (default quarterly) + curr_date (str): Current date you are trading at, yyyy-mm-dd + Returns: + str: A formatted report containing income statement data + """ return route_to_vendor("get_income_statement", ticker, freq, curr_date) \ No newline at end of file diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 69b8ab8c..7de66fbf 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,113 +1,113 @@ -import chromadb -from chromadb.config import Settings -from openai import OpenAI - - -class FinancialSituationMemory: - def __init__(self, name, config): - if config["backend_url"] == "http://localhost:11434/v1": - self.embedding = "nomic-embed-text" - else: - self.embedding = "text-embedding-3-small" - self.client = OpenAI(base_url=config["backend_url"]) - self.chroma_client = chromadb.Client(Settings(allow_reset=True)) - self.situation_collection = self.chroma_client.create_collection(name=name) - - def get_embedding(self, text): - """Get OpenAI embedding for a text""" - - response = self.client.embeddings.create( - model=self.embedding, input=text - ) - return response.data[0].embedding - - def add_situations(self, situations_and_advice): - """Add financial situations and their corresponding advice. 
Parameter is a list of tuples (situation, rec)""" - - situations = [] - advice = [] - ids = [] - embeddings = [] - - offset = self.situation_collection.count() - - for i, (situation, recommendation) in enumerate(situations_and_advice): - situations.append(situation) - advice.append(recommendation) - ids.append(str(offset + i)) - embeddings.append(self.get_embedding(situation)) - - self.situation_collection.add( - documents=situations, - metadatas=[{"recommendation": rec} for rec in advice], - embeddings=embeddings, - ids=ids, - ) - - def get_memories(self, current_situation, n_matches=1): - """Find matching recommendations using OpenAI embeddings""" - query_embedding = self.get_embedding(current_situation) - - results = self.situation_collection.query( - query_embeddings=[query_embedding], - n_results=n_matches, - include=["metadatas", "documents", "distances"], - ) - - matched_results = [] - for i in range(len(results["documents"][0])): - matched_results.append( - { - "matched_situation": results["documents"][0][i], - "recommendation": results["metadatas"][0][i]["recommendation"], - "similarity_score": 1 - results["distances"][0][i], - } - ) - - return matched_results - - -if __name__ == "__main__": - # Example usage - matcher = FinancialSituationMemory() - - # Example data - example_data = [ - ( - "High inflation rate with rising interest rates and declining consumer spending", - "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.", - ), - ( - "Tech sector showing high volatility with increasing institutional selling pressure", - "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.", - ), - ( - "Strong dollar affecting emerging markets with increasing forex volatility", - "Hedge currency exposure in international positions. 
Consider reducing allocation to emerging market debt.", - ), - ( - "Market showing signs of sector rotation with rising yields", - "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.", - ), - ] - - # Add the example situations and recommendations - matcher.add_situations(example_data) - - # Example query - current_situation = """ - Market showing increased volatility in tech sector, with institutional investors - reducing positions and rising interest rates affecting growth stock valuations - """ - - try: - recommendations = matcher.get_memories(current_situation, n_matches=2) - - for i, rec in enumerate(recommendations, 1): - print(f"\nMatch {i}:") - print(f"Similarity Score: {rec['similarity_score']:.2f}") - print(f"Matched Situation: {rec['matched_situation']}") - print(f"Recommendation: {rec['recommendation']}") - - except Exception as e: - print(f"Error during recommendation: {str(e)}") +import chromadb +from chromadb.config import Settings +from openai import OpenAI + + +class FinancialSituationMemory: + def __init__(self, name, config): + if config["backend_url"] == "http://localhost:11434/v1": + self.embedding = "nomic-embed-text" + else: + self.embedding = "text-embedding-3-small" + self.client = OpenAI(base_url=config["backend_url"]) + self.chroma_client = chromadb.Client(Settings(allow_reset=True)) + self.situation_collection = self.chroma_client.create_collection(name=name) + + def get_embedding(self, text): + """Get OpenAI embedding for a text""" + + response = self.client.embeddings.create( + model=self.embedding, input=text + ) + return response.data[0].embedding + + def add_situations(self, situations_and_advice): + """Add financial situations and their corresponding advice. 
Parameter is a list of tuples (situation, rec)""" + + situations = [] + advice = [] + ids = [] + embeddings = [] + + offset = self.situation_collection.count() + + for i, (situation, recommendation) in enumerate(situations_and_advice): + situations.append(situation) + advice.append(recommendation) + ids.append(str(offset + i)) + embeddings.append(self.get_embedding(situation)) + + self.situation_collection.add( + documents=situations, + metadatas=[{"recommendation": rec} for rec in advice], + embeddings=embeddings, + ids=ids, + ) + + def get_memories(self, current_situation, n_matches=1): + """Find matching recommendations using OpenAI embeddings""" + query_embedding = self.get_embedding(current_situation) + + results = self.situation_collection.query( + query_embeddings=[query_embedding], + n_results=n_matches, + include=["metadatas", "documents", "distances"], + ) + + matched_results = [] + for i in range(len(results["documents"][0])): + matched_results.append( + { + "matched_situation": results["documents"][0][i], + "recommendation": results["metadatas"][0][i]["recommendation"], + "similarity_score": 1 - results["distances"][0][i], + } + ) + + return matched_results + + +if __name__ == "__main__": + # Example usage + matcher = FinancialSituationMemory() + + # Example data + example_data = [ + ( + "High inflation rate with rising interest rates and declining consumer spending", + "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.", + ), + ( + "Tech sector showing high volatility with increasing institutional selling pressure", + "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.", + ), + ( + "Strong dollar affecting emerging markets with increasing forex volatility", + "Hedge currency exposure in international positions. 
Consider reducing allocation to emerging market debt.", + ), + ( + "Market showing signs of sector rotation with rising yields", + "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.", + ), + ] + + # Add the example situations and recommendations + matcher.add_situations(example_data) + + # Example query + current_situation = """ + Market showing increased volatility in tech sector, with institutional investors + reducing positions and rising interest rates affecting growth stock valuations + """ + + try: + recommendations = matcher.get_memories(current_situation, n_matches=2) + + for i, rec in enumerate(recommendations, 1): + print(f"\nMatch {i}:") + print(f"Similarity Score: {rec['similarity_score']:.2f}") + print(f"Matched Situation: {rec['matched_situation']}") + print(f"Recommendation: {rec['recommendation']}") + + except Exception as e: + print(f"Error during recommendation: {str(e)}") diff --git a/tradingagents/agents/utils/news_data_tools.py b/tradingagents/agents/utils/news_data_tools.py index 0df9d047..76b6389a 100644 --- a/tradingagents/agents/utils/news_data_tools.py +++ b/tradingagents/agents/utils/news_data_tools.py @@ -1,71 +1,71 @@ -from langchain_core.tools import tool -from typing import Annotated -from tradingagents.dataflows.interface import route_to_vendor - -@tool -def get_news( - ticker: Annotated[str, "Ticker symbol"], - start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - end_date: Annotated[str, "End date in yyyy-mm-dd format"], -) -> str: - """ - Retrieve news data for a given ticker symbol. - Uses the configured news_data vendor. 
- Args: - ticker (str): Ticker symbol - start_date (str): Start date in yyyy-mm-dd format - end_date (str): End date in yyyy-mm-dd format - Returns: - str: A formatted string containing news data - """ - return route_to_vendor("get_news", ticker, start_date, end_date) - -@tool -def get_global_news( - curr_date: Annotated[str, "Current date in yyyy-mm-dd format"], - look_back_days: Annotated[int, "Number of days to look back"] = 7, - limit: Annotated[int, "Maximum number of articles to return"] = 5, -) -> str: - """ - Retrieve global news data. - Uses the configured news_data vendor. - Args: - curr_date (str): Current date in yyyy-mm-dd format - look_back_days (int): Number of days to look back (default 7) - limit (int): Maximum number of articles to return (default 5) - Returns: - str: A formatted string containing global news data - """ - return route_to_vendor("get_global_news", curr_date, look_back_days, limit) - -@tool -def get_insider_sentiment( - ticker: Annotated[str, "ticker symbol for the company"], - curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], -) -> str: - """ - Retrieve insider sentiment information about a company. - Uses the configured news_data vendor. - Args: - ticker (str): Ticker symbol of the company - curr_date (str): Current date you are trading at, yyyy-mm-dd - Returns: - str: A report of insider sentiment data - """ - return route_to_vendor("get_insider_sentiment", ticker, curr_date) - -@tool -def get_insider_transactions( - ticker: Annotated[str, "ticker symbol"], - curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], -) -> str: - """ - Retrieve insider transaction information about a company. - Uses the configured news_data vendor. 
- Args: - ticker (str): Ticker symbol of the company - curr_date (str): Current date you are trading at, yyyy-mm-dd - Returns: - str: A report of insider transaction data - """ - return route_to_vendor("get_insider_transactions", ticker, curr_date) +from langchain_core.tools import tool +from typing import Annotated +from tradingagents.dataflows.interface import route_to_vendor + +@tool +def get_news( + ticker: Annotated[str, "Ticker symbol"], + start_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], +) -> str: + """ + Retrieve news data for a given ticker symbol. + Uses the configured news_data vendor. + Args: + ticker (str): Ticker symbol + start_date (str): Start date in yyyy-mm-dd format + end_date (str): End date in yyyy-mm-dd format + Returns: + str: A formatted string containing news data + """ + return route_to_vendor("get_news", ticker, start_date, end_date) + +@tool +def get_global_news( + curr_date: Annotated[str, "Current date in yyyy-mm-dd format"], + look_back_days: Annotated[int, "Number of days to look back"] = 7, + limit: Annotated[int, "Maximum number of articles to return"] = 5, +) -> str: + """ + Retrieve global news data. + Uses the configured news_data vendor. + Args: + curr_date (str): Current date in yyyy-mm-dd format + look_back_days (int): Number of days to look back (default 7) + limit (int): Maximum number of articles to return (default 5) + Returns: + str: A formatted string containing global news data + """ + return route_to_vendor("get_global_news", curr_date, look_back_days, limit) + +@tool +def get_insider_sentiment( + ticker: Annotated[str, "ticker symbol for the company"], + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], +) -> str: + """ + Retrieve insider sentiment information about a company. + Uses the configured news_data vendor. 
+ Args: + ticker (str): Ticker symbol of the company + curr_date (str): Current date you are trading at, yyyy-mm-dd + Returns: + str: A report of insider sentiment data + """ + return route_to_vendor("get_insider_sentiment", ticker, curr_date) + +@tool +def get_insider_transactions( + ticker: Annotated[str, "ticker symbol"], + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], +) -> str: + """ + Retrieve insider transaction information about a company. + Uses the configured news_data vendor. + Args: + ticker (str): Ticker symbol of the company + curr_date (str): Current date you are trading at, yyyy-mm-dd + Returns: + str: A report of insider transaction data + """ + return route_to_vendor("get_insider_transactions", ticker, curr_date) diff --git a/tradingagents/agents/utils/technical_indicators_tools.py b/tradingagents/agents/utils/technical_indicators_tools.py index c6c08bca..0de13ebb 100644 --- a/tradingagents/agents/utils/technical_indicators_tools.py +++ b/tradingagents/agents/utils/technical_indicators_tools.py @@ -1,23 +1,23 @@ -from langchain_core.tools import tool -from typing import Annotated -from tradingagents.dataflows.interface import route_to_vendor - -@tool -def get_indicators( - symbol: Annotated[str, "ticker symbol of the company"], - indicator: Annotated[str, "technical indicator to get the analysis and report of"], - curr_date: Annotated[str, "The current trading date you are trading on, YYYY-mm-dd"], - look_back_days: Annotated[int, "how many days to look back"] = 30, -) -> str: - """ - Retrieve technical indicators for a given ticker symbol. - Uses the configured technical_indicators vendor. - Args: - symbol (str): Ticker symbol of the company, e.g. 
AAPL, TSM - indicator (str): Technical indicator to get the analysis and report of - curr_date (str): The current trading date you are trading on, YYYY-mm-dd - look_back_days (int): How many days to look back, default is 30 - Returns: - str: A formatted dataframe containing the technical indicators for the specified ticker symbol and indicator. - """ +from langchain_core.tools import tool +from typing import Annotated +from tradingagents.dataflows.interface import route_to_vendor + +@tool +def get_indicators( + symbol: Annotated[str, "ticker symbol of the company"], + indicator: Annotated[str, "technical indicator to get the analysis and report of"], + curr_date: Annotated[str, "The current trading date you are trading on, YYYY-mm-dd"], + look_back_days: Annotated[int, "how many days to look back"] = 30, +) -> str: + """ + Retrieve technical indicators for a given ticker symbol. + Uses the configured technical_indicators vendor. + Args: + symbol (str): Ticker symbol of the company, e.g. AAPL, TSM + indicator (str): Technical indicator to get the analysis and report of + curr_date (str): The current trading date you are trading on, YYYY-mm-dd + look_back_days (int): How many days to look back, default is 30 + Returns: + str: A formatted dataframe containing the technical indicators for the specified ticker symbol and indicator. 
+ """ return route_to_vendor("get_indicators", symbol, indicator, curr_date, look_back_days) \ No newline at end of file diff --git a/tradingagents/dataflows/alpha_vantage.py b/tradingagents/dataflows/alpha_vantage.py index c5177c29..5572cbc0 100644 --- a/tradingagents/dataflows/alpha_vantage.py +++ b/tradingagents/dataflows/alpha_vantage.py @@ -1,5 +1,6 @@ -# Import functions from specialized modules -from .alpha_vantage_stock import get_stock -from .alpha_vantage_indicator import get_indicator -from .alpha_vantage_fundamentals import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement -from .alpha_vantage_news import get_news, get_insider_transactions \ No newline at end of file +# Import functions from specialized modules +from .alpha_vantage_stock import get_stock +from .alpha_vantage_indicator import get_indicator +from .alpha_vantage_fundamentals import get_fundamentals, get_balance_sheet, get_cashflow, get_income_statement +from .alpha_vantage_news import get_news, get_insider_transactions +from .alpha_vantage_market import get_market_movers, get_earnings_calendar \ No newline at end of file diff --git a/tradingagents/dataflows/alpha_vantage_common.py b/tradingagents/dataflows/alpha_vantage_common.py index 409ff29e..6fd5d143 100644 --- a/tradingagents/dataflows/alpha_vantage_common.py +++ b/tradingagents/dataflows/alpha_vantage_common.py @@ -1,122 +1,122 @@ -import os -import requests -import pandas as pd -import json -from datetime import datetime -from io import StringIO - -API_BASE_URL = "https://www.alphavantage.co/query" - -def get_api_key() -> str: - """Retrieve the API key for Alpha Vantage from environment variables.""" - api_key = os.getenv("ALPHA_VANTAGE_API_KEY") - if not api_key: - raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.") - return api_key - -def format_datetime_for_api(date_input) -> str: - """Convert various date formats to YYYYMMDDTHHMM format required by Alpha Vantage API.""" - if 
isinstance(date_input, str): - # If already in correct format, return as-is - if len(date_input) == 13 and 'T' in date_input: - return date_input - # Try to parse common date formats - try: - dt = datetime.strptime(date_input, "%Y-%m-%d") - return dt.strftime("%Y%m%dT0000") - except ValueError: - try: - dt = datetime.strptime(date_input, "%Y-%m-%d %H:%M") - return dt.strftime("%Y%m%dT%H%M") - except ValueError: - raise ValueError(f"Unsupported date format: {date_input}") - elif isinstance(date_input, datetime): - return date_input.strftime("%Y%m%dT%H%M") - else: - raise ValueError(f"Date must be string or datetime object, got {type(date_input)}") - -class AlphaVantageRateLimitError(Exception): - """Exception raised when Alpha Vantage API rate limit is exceeded.""" - pass - -def _make_api_request(function_name: str, params: dict) -> dict | str: - """Helper function to make API requests and handle responses. - - Raises: - AlphaVantageRateLimitError: When API rate limit is exceeded - """ - # Create a copy of params to avoid modifying the original - api_params = params.copy() - api_params.update({ - "function": function_name, - "apikey": get_api_key(), - "source": "trading_agents", - }) - - # Handle entitlement parameter if present in params or global variable - current_entitlement = globals().get('_current_entitlement') - entitlement = api_params.get("entitlement") or current_entitlement - - if entitlement: - api_params["entitlement"] = entitlement - elif "entitlement" in api_params: - # Remove entitlement if it's None or empty - api_params.pop("entitlement", None) - - response = requests.get(API_BASE_URL, params=api_params) - response.raise_for_status() - - response_text = response.text - - # Check if response is JSON (error responses are typically JSON) - try: - response_json = json.loads(response_text) - # Check for rate limit error - if "Information" in response_json: - info_message = response_json["Information"] - if "rate limit" in info_message.lower() or "api 
key" in info_message.lower(): - raise AlphaVantageRateLimitError(f"Alpha Vantage rate limit exceeded: {info_message}") - except json.JSONDecodeError: - # Response is not JSON (likely CSV data), which is normal - pass - - return response_text - - - -def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) -> str: - """ - Filter CSV data to include only rows within the specified date range. - - Args: - csv_data: CSV string from Alpha Vantage API - start_date: Start date in yyyy-mm-dd format - end_date: End date in yyyy-mm-dd format - - Returns: - Filtered CSV string - """ - if not csv_data or csv_data.strip() == "": - return csv_data - - try: - # Parse CSV data - df = pd.read_csv(StringIO(csv_data)) - - # Assume the first column is the date column (timestamp) - date_col = df.columns[0] - df[date_col] = pd.to_datetime(df[date_col]) - - # Filter by date range - start_dt = pd.to_datetime(start_date) - end_dt = pd.to_datetime(end_date) - - filtered_df = df[(df[date_col] >= start_dt) & (df[date_col] <= end_dt)] - - # Convert back to CSV string - return filtered_df.to_csv(index=False) - - except Exception as e: - # If filtering fails, return original data with a warning - print(f"Warning: Failed to filter CSV data by date range: {e}") - return csv_data +import os +import requests +import pandas as pd +import json +from datetime import datetime +from io import StringIO + +API_BASE_URL = "https://www.alphavantage.co/query" + +def get_api_key() -> str: + """Retrieve the API key for Alpha Vantage from environment variables.""" + api_key = os.getenv("ALPHA_VANTAGE_API_KEY") + if not api_key: + raise ValueError("ALPHA_VANTAGE_API_KEY environment variable is not set.") + return api_key + +def format_datetime_for_api(date_input) -> str: + """Convert various date formats to YYYYMMDDTHHMM format required by Alpha Vantage API.""" + if isinstance(date_input, str): + # If already in correct format, return as-is + if len(date_input) == 13 and 'T' in date_input: + 
return date_input + # Try to parse common date formats + try: + dt = datetime.strptime(date_input, "%Y-%m-%d") + return dt.strftime("%Y%m%dT0000") + except ValueError: + try: + dt = datetime.strptime(date_input, "%Y-%m-%d %H:%M") + return dt.strftime("%Y%m%dT%H%M") + except ValueError: + raise ValueError(f"Unsupported date format: {date_input}") + elif isinstance(date_input, datetime): + return date_input.strftime("%Y%m%dT%H%M") + else: + raise ValueError(f"Date must be string or datetime object, got {type(date_input)}") + +class AlphaVantageRateLimitError(Exception): + """Exception raised when Alpha Vantage API rate limit is exceeded.""" + pass + +def _make_api_request(function_name: str, params: dict) -> dict | str: + """Helper function to make API requests and handle responses. + + Raises: + AlphaVantageRateLimitError: When API rate limit is exceeded + """ + # Create a copy of params to avoid modifying the original + api_params = params.copy() + api_params.update({ + "function": function_name, + "apikey": get_api_key(), + "source": "trading_agents", + }) + + # Handle entitlement parameter if present in params or global variable + current_entitlement = globals().get('_current_entitlement') + entitlement = api_params.get("entitlement") or current_entitlement + + if entitlement: + api_params["entitlement"] = entitlement + elif "entitlement" in api_params: + # Remove entitlement if it's None or empty + api_params.pop("entitlement", None) + + response = requests.get(API_BASE_URL, params=api_params) + response.raise_for_status() + + response_text = response.text + + # Check if response is JSON (error responses are typically JSON) + try: + response_json = json.loads(response_text) + # Check for rate limit error + if "Information" in response_json: + info_message = response_json["Information"] + if "rate limit" in info_message.lower() or "api key" in info_message.lower(): + raise AlphaVantageRateLimitError(f"Alpha Vantage rate limit exceeded: {info_message}") + except 
json.JSONDecodeError: + # Response is not JSON (likely CSV data), which is normal + pass + + return response_text + + + +def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) -> str: + """ + Filter CSV data to include only rows within the specified date range. + + Args: + csv_data: CSV string from Alpha Vantage API + start_date: Start date in yyyy-mm-dd format + end_date: End date in yyyy-mm-dd format + + Returns: + Filtered CSV string + """ + if not csv_data or csv_data.strip() == "": + return csv_data + + try: + # Parse CSV data + df = pd.read_csv(StringIO(csv_data)) + + # Assume the first column is the date column (timestamp) + date_col = df.columns[0] + df[date_col] = pd.to_datetime(df[date_col]) + + # Filter by date range + start_dt = pd.to_datetime(start_date) + end_dt = pd.to_datetime(end_date) + + filtered_df = df[(df[date_col] >= start_dt) & (df[date_col] <= end_dt)] + + # Convert back to CSV string + return filtered_df.to_csv(index=False) + + except Exception as e: + # If filtering fails, return original data with a warning + print(f"Warning: Failed to filter CSV data by date range: {e}") + return csv_data diff --git a/tradingagents/dataflows/alpha_vantage_fundamentals.py b/tradingagents/dataflows/alpha_vantage_fundamentals.py index 8b92faa6..402fa2f0 100644 --- a/tradingagents/dataflows/alpha_vantage_fundamentals.py +++ b/tradingagents/dataflows/alpha_vantage_fundamentals.py @@ -1,77 +1,77 @@ -from .alpha_vantage_common import _make_api_request - - -def get_fundamentals(ticker: str, curr_date: str = None) -> str: - """ - Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage. 
- - Args: - ticker (str): Ticker symbol of the company - curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage) - - Returns: - str: Company overview data including financial ratios and key metrics - """ - params = { - "symbol": ticker, - } - - return _make_api_request("OVERVIEW", params) - - -def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: - """ - Retrieve balance sheet data for a given ticker symbol using Alpha Vantage. - - Args: - ticker (str): Ticker symbol of the company - freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage - curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage) - - Returns: - str: Balance sheet data with normalized fields - """ - params = { - "symbol": ticker, - } - - return _make_api_request("BALANCE_SHEET", params) - - -def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: - """ - Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage. - - Args: - ticker (str): Ticker symbol of the company - freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage - curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage) - - Returns: - str: Cash flow statement data with normalized fields - """ - params = { - "symbol": ticker, - } - - return _make_api_request("CASH_FLOW", params) - - -def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: - """ - Retrieve income statement data for a given ticker symbol using Alpha Vantage. 
- - Args: - ticker (str): Ticker symbol of the company - freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage - curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage) - - Returns: - str: Income statement data with normalized fields - """ - params = { - "symbol": ticker, - } - - return _make_api_request("INCOME_STATEMENT", params) - +from .alpha_vantage_common import _make_api_request + + +def get_fundamentals(ticker: str, curr_date: str = None) -> str: + """ + Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage. + + Args: + ticker (str): Ticker symbol of the company + curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage) + + Returns: + str: Company overview data including financial ratios and key metrics + """ + params = { + "symbol": ticker, + } + + return _make_api_request("OVERVIEW", params) + + +def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: + """ + Retrieve balance sheet data for a given ticker symbol using Alpha Vantage. + + Args: + ticker (str): Ticker symbol of the company + freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage + curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage) + + Returns: + str: Balance sheet data with normalized fields + """ + params = { + "symbol": ticker, + } + + return _make_api_request("BALANCE_SHEET", params) + + +def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: + """ + Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage. 
+ + Args: + ticker (str): Ticker symbol of the company + freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage + curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage) + + Returns: + str: Cash flow statement data with normalized fields + """ + params = { + "symbol": ticker, + } + + return _make_api_request("CASH_FLOW", params) + + +def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str: + """ + Retrieve income statement data for a given ticker symbol using Alpha Vantage. + + Args: + ticker (str): Ticker symbol of the company + freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage + curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage) + + Returns: + str: Income statement data with normalized fields + """ + params = { + "symbol": ticker, + } + + return _make_api_request("INCOME_STATEMENT", params) + diff --git a/tradingagents/dataflows/alpha_vantage_indicator.py b/tradingagents/dataflows/alpha_vantage_indicator.py index 6225b9bb..8d5bf118 100644 --- a/tradingagents/dataflows/alpha_vantage_indicator.py +++ b/tradingagents/dataflows/alpha_vantage_indicator.py @@ -1,222 +1,222 @@ -from .alpha_vantage_common import _make_api_request - -def get_indicator( - symbol: str, - indicator: str, - curr_date: str, - look_back_days: int, - interval: str = "daily", - time_period: int = 14, - series_type: str = "close" -) -> str: - """ - Returns Alpha Vantage technical indicator values over a time window. 
- - Args: - symbol: ticker symbol of the company - indicator: technical indicator to get the analysis and report of - curr_date: The current trading date you are trading on, YYYY-mm-dd - look_back_days: how many days to look back - interval: Time interval (daily, weekly, monthly) - time_period: Number of data points for calculation - series_type: The desired price type (close, open, high, low) - - Returns: - String containing indicator values and description - """ - from datetime import datetime - from dateutil.relativedelta import relativedelta - - supported_indicators = { - "close_50_sma": ("50 SMA", "close"), - "close_200_sma": ("200 SMA", "close"), - "close_10_ema": ("10 EMA", "close"), - "macd": ("MACD", "close"), - "macds": ("MACD Signal", "close"), - "macdh": ("MACD Histogram", "close"), - "rsi": ("RSI", "close"), - "boll": ("Bollinger Middle", "close"), - "boll_ub": ("Bollinger Upper Band", "close"), - "boll_lb": ("Bollinger Lower Band", "close"), - "atr": ("ATR", None), - "vwma": ("VWMA", "close") - } - - indicator_descriptions = { - "close_50_sma": "50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.", - "close_200_sma": "200 SMA: A long-term trend benchmark. Usage: Confirm overall market trend and identify golden/death cross setups. Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.", - "close_10_ema": "10 EMA: A responsive short-term average. Usage: Capture quick shifts in momentum and potential entry points. Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.", - "macd": "MACD: Computes momentum via differences of EMAs. Usage: Look for crossovers and divergence as signals of trend changes. Tips: Confirm with other indicators in low-volatility or sideways markets.", - "macds": "MACD Signal: An EMA smoothing of the MACD line. 
Usage: Use crossovers with the MACD line to trigger trades. Tips: Should be part of a broader strategy to avoid false positives.", - "macdh": "MACD Histogram: Shows the gap between the MACD line and its signal. Usage: Visualize momentum strength and spot divergence early. Tips: Can be volatile; complement with additional filters in fast-moving markets.", - "rsi": "RSI: Measures momentum to flag overbought/oversold conditions. Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.", - "boll": "Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. Usage: Acts as a dynamic benchmark for price movement. Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.", - "boll_ub": "Bollinger Upper Band: Typically 2 standard deviations above the middle line. Usage: Signals potential overbought conditions and breakout zones. Tips: Confirm signals with other tools; prices may ride the band in strong trends.", - "boll_lb": "Bollinger Lower Band: Typically 2 standard deviations below the middle line. Usage: Indicates potential oversold conditions. Tips: Use additional analysis to avoid false reversal signals.", - "atr": "ATR: Averages true range to measure volatility. Usage: Set stop-loss levels and adjust position sizes based on current market volatility. Tips: It's a reactive measure, so use it as part of a broader risk management strategy.", - "vwma": "VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses." - } - - if indicator not in supported_indicators: - raise ValueError( - f"Indicator {indicator} is not supported. 
Please choose from: {list(supported_indicators.keys())}" - ) - - curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") - before = curr_date_dt - relativedelta(days=look_back_days) - - # Get the full data for the period instead of making individual calls - _, required_series_type = supported_indicators[indicator] - - # Use the provided series_type or fall back to the required one - if required_series_type: - series_type = required_series_type - - try: - # Get indicator data for the period - if indicator == "close_50_sma": - data = _make_api_request("SMA", { - "symbol": symbol, - "interval": interval, - "time_period": "50", - "series_type": series_type, - "datatype": "csv" - }) - elif indicator == "close_200_sma": - data = _make_api_request("SMA", { - "symbol": symbol, - "interval": interval, - "time_period": "200", - "series_type": series_type, - "datatype": "csv" - }) - elif indicator == "close_10_ema": - data = _make_api_request("EMA", { - "symbol": symbol, - "interval": interval, - "time_period": "10", - "series_type": series_type, - "datatype": "csv" - }) - elif indicator == "macd": - data = _make_api_request("MACD", { - "symbol": symbol, - "interval": interval, - "series_type": series_type, - "datatype": "csv" - }) - elif indicator == "macds": - data = _make_api_request("MACD", { - "symbol": symbol, - "interval": interval, - "series_type": series_type, - "datatype": "csv" - }) - elif indicator == "macdh": - data = _make_api_request("MACD", { - "symbol": symbol, - "interval": interval, - "series_type": series_type, - "datatype": "csv" - }) - elif indicator == "rsi": - data = _make_api_request("RSI", { - "symbol": symbol, - "interval": interval, - "time_period": str(time_period), - "series_type": series_type, - "datatype": "csv" - }) - elif indicator in ["boll", "boll_ub", "boll_lb"]: - data = _make_api_request("BBANDS", { - "symbol": symbol, - "interval": interval, - "time_period": "20", - "series_type": series_type, - "datatype": "csv" - }) - elif indicator 
== "atr": - data = _make_api_request("ATR", { - "symbol": symbol, - "interval": interval, - "time_period": str(time_period), - "datatype": "csv" - }) - elif indicator == "vwma": - # Alpha Vantage doesn't have direct VWMA, so we'll return an informative message - # In a real implementation, this would need to be calculated from OHLCV data - return f"## VWMA (Volume Weighted Moving Average) for {symbol}:\n\nVWMA calculation requires OHLCV data and is not directly available from Alpha Vantage API.\nThis indicator would need to be calculated from the raw stock data using volume-weighted price averaging.\n\n{indicator_descriptions.get('vwma', 'No description available.')}" - else: - return f"Error: Indicator {indicator} not implemented yet." - - # Parse CSV data and extract values for the date range - lines = data.strip().split('\n') - if len(lines) < 2: - return f"Error: No data returned for {indicator}" - - # Parse header and data - header = [col.strip() for col in lines[0].split(',')] - try: - date_col_idx = header.index('time') - except ValueError: - return f"Error: 'time' column not found in data for {indicator}. Available columns: {header}" - - # Map internal indicator names to expected CSV column names from Alpha Vantage - col_name_map = { - "macd": "MACD", "macds": "MACD_Signal", "macdh": "MACD_Hist", - "boll": "Real Middle Band", "boll_ub": "Real Upper Band", "boll_lb": "Real Lower Band", - "rsi": "RSI", "atr": "ATR", "close_10_ema": "EMA", - "close_50_sma": "SMA", "close_200_sma": "SMA" - } - - target_col_name = col_name_map.get(indicator) - - if not target_col_name: - # Default to the second column if no specific mapping exists - value_col_idx = 1 - else: - try: - value_col_idx = header.index(target_col_name) - except ValueError: - return f"Error: Column '{target_col_name}' not found for indicator '{indicator}'. 
Available columns: {header}" - - result_data = [] - for line in lines[1:]: - if not line.strip(): - continue - values = line.split(',') - if len(values) > value_col_idx: - try: - date_str = values[date_col_idx].strip() - # Parse the date - date_dt = datetime.strptime(date_str, "%Y-%m-%d") - - # Check if date is in our range - if before <= date_dt <= curr_date_dt: - value = values[value_col_idx].strip() - result_data.append((date_dt, value)) - except (ValueError, IndexError): - continue - - # Sort by date and format output - result_data.sort(key=lambda x: x[0]) - - ind_string = "" - for date_dt, value in result_data: - ind_string += f"{date_dt.strftime('%Y-%m-%d')}: {value}\n" - - if not ind_string: - ind_string = "No data available for the specified date range.\n" - - result_str = ( - f"## {indicator.upper()} values from {before.strftime('%Y-%m-%d')} to {curr_date}:\n\n" - + ind_string - + "\n\n" - + indicator_descriptions.get(indicator, "No description available.") - ) - - return result_str - - except Exception as e: - print(f"Error getting Alpha Vantage indicator data for {indicator}: {e}") - return f"Error retrieving {indicator} data: {str(e)}" +from .alpha_vantage_common import _make_api_request + +def get_indicator( + symbol: str, + indicator: str, + curr_date: str, + look_back_days: int, + interval: str = "daily", + time_period: int = 14, + series_type: str = "close" +) -> str: + """ + Returns Alpha Vantage technical indicator values over a time window. 
+ + Args: + symbol: ticker symbol of the company + indicator: technical indicator to get the analysis and report of + curr_date: The current trading date you are trading on, YYYY-mm-dd + look_back_days: how many days to look back + interval: Time interval (daily, weekly, monthly) + time_period: Number of data points for calculation + series_type: The desired price type (close, open, high, low) + + Returns: + String containing indicator values and description + """ + from datetime import datetime + from dateutil.relativedelta import relativedelta + + supported_indicators = { + "close_50_sma": ("50 SMA", "close"), + "close_200_sma": ("200 SMA", "close"), + "close_10_ema": ("10 EMA", "close"), + "macd": ("MACD", "close"), + "macds": ("MACD Signal", "close"), + "macdh": ("MACD Histogram", "close"), + "rsi": ("RSI", "close"), + "boll": ("Bollinger Middle", "close"), + "boll_ub": ("Bollinger Upper Band", "close"), + "boll_lb": ("Bollinger Lower Band", "close"), + "atr": ("ATR", None), + "vwma": ("VWMA", "close") + } + + indicator_descriptions = { + "close_50_sma": "50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.", + "close_200_sma": "200 SMA: A long-term trend benchmark. Usage: Confirm overall market trend and identify golden/death cross setups. Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.", + "close_10_ema": "10 EMA: A responsive short-term average. Usage: Capture quick shifts in momentum and potential entry points. Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.", + "macd": "MACD: Computes momentum via differences of EMAs. Usage: Look for crossovers and divergence as signals of trend changes. Tips: Confirm with other indicators in low-volatility or sideways markets.", + "macds": "MACD Signal: An EMA smoothing of the MACD line. 
Usage: Use crossovers with the MACD line to trigger trades. Tips: Should be part of a broader strategy to avoid false positives.", + "macdh": "MACD Histogram: Shows the gap between the MACD line and its signal. Usage: Visualize momentum strength and spot divergence early. Tips: Can be volatile; complement with additional filters in fast-moving markets.", + "rsi": "RSI: Measures momentum to flag overbought/oversold conditions. Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.", + "boll": "Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. Usage: Acts as a dynamic benchmark for price movement. Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.", + "boll_ub": "Bollinger Upper Band: Typically 2 standard deviations above the middle line. Usage: Signals potential overbought conditions and breakout zones. Tips: Confirm signals with other tools; prices may ride the band in strong trends.", + "boll_lb": "Bollinger Lower Band: Typically 2 standard deviations below the middle line. Usage: Indicates potential oversold conditions. Tips: Use additional analysis to avoid false reversal signals.", + "atr": "ATR: Averages true range to measure volatility. Usage: Set stop-loss levels and adjust position sizes based on current market volatility. Tips: It's a reactive measure, so use it as part of a broader risk management strategy.", + "vwma": "VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses." + } + + if indicator not in supported_indicators: + raise ValueError( + f"Indicator {indicator} is not supported. 
Please choose from: {list(supported_indicators.keys())}" + ) + + curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") + before = curr_date_dt - relativedelta(days=look_back_days) + + # Get the full data for the period instead of making individual calls + _, required_series_type = supported_indicators[indicator] + + # Use the provided series_type or fall back to the required one + if required_series_type: + series_type = required_series_type + + try: + # Get indicator data for the period + if indicator == "close_50_sma": + data = _make_api_request("SMA", { + "symbol": symbol, + "interval": interval, + "time_period": "50", + "series_type": series_type, + "datatype": "csv" + }) + elif indicator == "close_200_sma": + data = _make_api_request("SMA", { + "symbol": symbol, + "interval": interval, + "time_period": "200", + "series_type": series_type, + "datatype": "csv" + }) + elif indicator == "close_10_ema": + data = _make_api_request("EMA", { + "symbol": symbol, + "interval": interval, + "time_period": "10", + "series_type": series_type, + "datatype": "csv" + }) + elif indicator == "macd": + data = _make_api_request("MACD", { + "symbol": symbol, + "interval": interval, + "series_type": series_type, + "datatype": "csv" + }) + elif indicator == "macds": + data = _make_api_request("MACD", { + "symbol": symbol, + "interval": interval, + "series_type": series_type, + "datatype": "csv" + }) + elif indicator == "macdh": + data = _make_api_request("MACD", { + "symbol": symbol, + "interval": interval, + "series_type": series_type, + "datatype": "csv" + }) + elif indicator == "rsi": + data = _make_api_request("RSI", { + "symbol": symbol, + "interval": interval, + "time_period": str(time_period), + "series_type": series_type, + "datatype": "csv" + }) + elif indicator in ["boll", "boll_ub", "boll_lb"]: + data = _make_api_request("BBANDS", { + "symbol": symbol, + "interval": interval, + "time_period": "20", + "series_type": series_type, + "datatype": "csv" + }) + elif indicator 
== "atr": + data = _make_api_request("ATR", { + "symbol": symbol, + "interval": interval, + "time_period": str(time_period), + "datatype": "csv" + }) + elif indicator == "vwma": + # Alpha Vantage doesn't have direct VWMA, so we'll return an informative message + # In a real implementation, this would need to be calculated from OHLCV data + return f"## VWMA (Volume Weighted Moving Average) for {symbol}:\n\nVWMA calculation requires OHLCV data and is not directly available from Alpha Vantage API.\nThis indicator would need to be calculated from the raw stock data using volume-weighted price averaging.\n\n{indicator_descriptions.get('vwma', 'No description available.')}" + else: + return f"Error: Indicator {indicator} not implemented yet." + + # Parse CSV data and extract values for the date range + lines = data.strip().split('\n') + if len(lines) < 2: + return f"Error: No data returned for {indicator}" + + # Parse header and data + header = [col.strip() for col in lines[0].split(',')] + try: + date_col_idx = header.index('time') + except ValueError: + return f"Error: 'time' column not found in data for {indicator}. Available columns: {header}" + + # Map internal indicator names to expected CSV column names from Alpha Vantage + col_name_map = { + "macd": "MACD", "macds": "MACD_Signal", "macdh": "MACD_Hist", + "boll": "Real Middle Band", "boll_ub": "Real Upper Band", "boll_lb": "Real Lower Band", + "rsi": "RSI", "atr": "ATR", "close_10_ema": "EMA", + "close_50_sma": "SMA", "close_200_sma": "SMA" + } + + target_col_name = col_name_map.get(indicator) + + if not target_col_name: + # Default to the second column if no specific mapping exists + value_col_idx = 1 + else: + try: + value_col_idx = header.index(target_col_name) + except ValueError: + return f"Error: Column '{target_col_name}' not found for indicator '{indicator}'. 
Available columns: {header}" + + result_data = [] + for line in lines[1:]: + if not line.strip(): + continue + values = line.split(',') + if len(values) > value_col_idx: + try: + date_str = values[date_col_idx].strip() + # Parse the date + date_dt = datetime.strptime(date_str, "%Y-%m-%d") + + # Check if date is in our range + if before <= date_dt <= curr_date_dt: + value = values[value_col_idx].strip() + result_data.append((date_dt, value)) + except (ValueError, IndexError): + continue + + # Sort by date and format output + result_data.sort(key=lambda x: x[0]) + + ind_string = "" + for date_dt, value in result_data: + ind_string += f"{date_dt.strftime('%Y-%m-%d')}: {value}\n" + + if not ind_string: + ind_string = "No data available for the specified date range.\n" + + result_str = ( + f"## {indicator.upper()} values from {before.strftime('%Y-%m-%d')} to {curr_date}:\n\n" + + ind_string + + "\n\n" + + indicator_descriptions.get(indicator, "No description available.") + ) + + return result_str + + except Exception as e: + print(f"Error getting Alpha Vantage indicator data for {indicator}: {e}") + return f"Error retrieving {indicator} data: {str(e)}" diff --git a/tradingagents/dataflows/alpha_vantage_market.py b/tradingagents/dataflows/alpha_vantage_market.py new file mode 100644 index 00000000..19f97311 --- /dev/null +++ b/tradingagents/dataflows/alpha_vantage_market.py @@ -0,0 +1,76 @@ +from .alpha_vantage_common import _make_api_request +import csv +import io + +def get_market_movers( + metric: str = "top_gainers", + limit: int = 10 +) -> str: + """ + Returns the top gainers, losers, or most active stocks from Alpha Vantage. + + Args: + metric: One of "top_gainers", "top_losers", "most_actively_traded" + limit: Number of results to return (default 10) + + Returns: + CSV string containing the market movers data. 
+ """ + params = {} + response = _make_api_request("TOP_GAINERS_LOSERS", params) + + # The response is JSON for this endpoint, not CSV + # We need to parse it and convert to CSV format for consistency + import json + try: + data = json.loads(response) + except json.JSONDecodeError: + return f"Error parsing response: {response}" + + if metric not in data: + return f"Metric '{metric}' not found in response. Available: {list(data.keys())}" + + items = data[metric] + + # Sort just in case, though API usually returns sorted + # Note: "change_percentage" is a string like "10.5%", need to parse for sorting if needed + # But usually the API returns them sorted. + + items = items[:limit] + + if not items: + return f"No data found for metric '{metric}'" + + # Convert to CSV + output = io.StringIO() + writer = csv.writer(output) + + # Write header + if items: + writer.writerow(items[0].keys()) + + for item in items: + writer.writerow(item.values()) + + return output.getvalue() + +def get_earnings_calendar( + horizon: str = "3month", + symbol: str = None +) -> str: + """ + Returns the earnings calendar for the specified horizon or symbol. + + Args: + horizon: "3month", "6month", or "12month" (default "3month") + symbol: Optional ticker symbol to filter by + + Returns: + CSV string containing the earnings calendar. + """ + params = {"horizon": horizon} + if symbol: + params["symbol"] = symbol + + response = _make_api_request("EARNINGS_CALENDAR", params) + return response diff --git a/tradingagents/dataflows/alpha_vantage_news.py b/tradingagents/dataflows/alpha_vantage_news.py index 8124fb45..ef3d58de 100644 --- a/tradingagents/dataflows/alpha_vantage_news.py +++ b/tradingagents/dataflows/alpha_vantage_news.py @@ -1,43 +1,43 @@ -from .alpha_vantage_common import _make_api_request, format_datetime_for_api - -def get_news(ticker, start_date, end_date) -> dict[str, str] | str: - """Returns live and historical market news & sentiment data from premier news outlets worldwide. 
- - Covers stocks, cryptocurrencies, forex, and topics like fiscal policy, mergers & acquisitions, IPOs. - - Args: - ticker: Stock symbol for news articles. - start_date: Start date for news search. - end_date: End date for news search. - - Returns: - Dictionary containing news sentiment data or JSON string. - """ - - params = { - "tickers": ticker, - "time_from": format_datetime_for_api(start_date), - "time_to": format_datetime_for_api(end_date), - "sort": "LATEST", - "limit": "50", - } - - return _make_api_request("NEWS_SENTIMENT", params) - -def get_insider_transactions(symbol: str) -> dict[str, str] | str: - """Returns latest and historical insider transactions by key stakeholders. - - Covers transactions by founders, executives, board members, etc. - - Args: - symbol: Ticker symbol. Example: "IBM". - - Returns: - Dictionary containing insider transaction data or JSON string. - """ - - params = { - "symbol": symbol, - } - +from .alpha_vantage_common import _make_api_request, format_datetime_for_api + +def get_news(ticker, start_date, end_date) -> dict[str, str] | str: + """Returns live and historical market news & sentiment data from premier news outlets worldwide. + + Covers stocks, cryptocurrencies, forex, and topics like fiscal policy, mergers & acquisitions, IPOs. + + Args: + ticker: Stock symbol for news articles. + start_date: Start date for news search. + end_date: End date for news search. + + Returns: + Dictionary containing news sentiment data or JSON string. + """ + + params = { + "tickers": ticker, + "time_from": format_datetime_for_api(start_date), + "time_to": format_datetime_for_api(end_date), + "sort": "LATEST", + "limit": "50", + } + + return _make_api_request("NEWS_SENTIMENT", params) + +def get_insider_transactions(symbol: str) -> dict[str, str] | str: + """Returns latest and historical insider transactions by key stakeholders. + + Covers transactions by founders, executives, board members, etc. + + Args: + symbol: Ticker symbol. 
Example: "IBM". + + Returns: + Dictionary containing insider transaction data or JSON string. + """ + + params = { + "symbol": symbol, + } + return _make_api_request("INSIDER_TRANSACTIONS", params) \ No newline at end of file diff --git a/tradingagents/dataflows/alpha_vantage_stock.py b/tradingagents/dataflows/alpha_vantage_stock.py index ffd3570b..fc0f7fc4 100644 --- a/tradingagents/dataflows/alpha_vantage_stock.py +++ b/tradingagents/dataflows/alpha_vantage_stock.py @@ -1,38 +1,38 @@ -from datetime import datetime -from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range - -def get_stock( - symbol: str, - start_date: str, - end_date: str -) -> str: - """ - Returns raw daily OHLCV values, adjusted close values, and historical split/dividend events - filtered to the specified date range. - - Args: - symbol: The name of the equity. For example: symbol=IBM - start_date: Start date in yyyy-mm-dd format - end_date: End date in yyyy-mm-dd format - - Returns: - CSV string containing the daily adjusted time series data filtered to the date range. - """ - # Parse dates to determine the range - start_dt = datetime.strptime(start_date, "%Y-%m-%d") - today = datetime.now() - - # Choose outputsize based on whether the requested range is within the latest 100 days - # Compact returns latest 100 data points, so check if start_date is recent enough - days_from_today_to_start = (today - start_dt).days - outputsize = "compact" if days_from_today_to_start < 100 else "full" - - params = { - "symbol": symbol, - "outputsize": outputsize, - "datatype": "csv", - } - - response = _make_api_request("TIME_SERIES_DAILY_ADJUSTED", params) - +from datetime import datetime +from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range + +def get_stock( + symbol: str, + start_date: str, + end_date: str +) -> str: + """ + Returns raw daily OHLCV values, adjusted close values, and historical split/dividend events + filtered to the specified date range. 
+ + Args: + symbol: The name of the equity. For example: symbol=IBM + start_date: Start date in yyyy-mm-dd format + end_date: End date in yyyy-mm-dd format + + Returns: + CSV string containing the daily adjusted time series data filtered to the date range. + """ + # Parse dates to determine the range + start_dt = datetime.strptime(start_date, "%Y-%m-%d") + today = datetime.now() + + # Choose outputsize based on whether the requested range is within the latest 100 days + # Compact returns latest 100 data points, so check if start_date is recent enough + days_from_today_to_start = (today - start_dt).days + outputsize = "compact" if days_from_today_to_start < 100 else "full" + + params = { + "symbol": symbol, + "outputsize": outputsize, + "datatype": "csv", + } + + response = _make_api_request("TIME_SERIES_DAILY_ADJUSTED", params) + return _filter_csv_by_date_range(response, start_date, end_date) \ No newline at end of file diff --git a/tradingagents/dataflows/config.py b/tradingagents/dataflows/config.py index b8a8f8aa..3d233d64 100644 --- a/tradingagents/dataflows/config.py +++ b/tradingagents/dataflows/config.py @@ -1,34 +1,34 @@ -import tradingagents.default_config as default_config -from typing import Dict, Optional - -# Use default config but allow it to be overridden -_config: Optional[Dict] = None -DATA_DIR: Optional[str] = None - - -def initialize_config(): - """Initialize the configuration with default values.""" - global _config, DATA_DIR - if _config is None: - _config = default_config.DEFAULT_CONFIG.copy() - DATA_DIR = _config["data_dir"] - - -def set_config(config: Dict): - """Update the configuration with custom values.""" - global _config, DATA_DIR - if _config is None: - _config = default_config.DEFAULT_CONFIG.copy() - _config.update(config) - DATA_DIR = _config["data_dir"] - - -def get_config() -> Dict: - """Get the current configuration.""" - if _config is None: - initialize_config() - return _config.copy() - - -# Initialize with default config 
-initialize_config() +import tradingagents.default_config as default_config +from typing import Dict, Optional + +# Use default config but allow it to be overridden +_config: Optional[Dict] = None +DATA_DIR: Optional[str] = None + + +def initialize_config(): + """Initialize the configuration with default values.""" + global _config, DATA_DIR + if _config is None: + _config = default_config.DEFAULT_CONFIG.copy() + DATA_DIR = _config["data_dir"] + + +def set_config(config: Dict): + """Update the configuration with custom values.""" + global _config, DATA_DIR + if _config is None: + _config = default_config.DEFAULT_CONFIG.copy() + _config.update(config) + DATA_DIR = _config["data_dir"] + + +def get_config() -> Dict: + """Get the current configuration.""" + if _config is None: + initialize_config() + return _config.copy() + + +# Initialize with default config +initialize_config() diff --git a/tradingagents/dataflows/google.py b/tradingagents/dataflows/google.py index 3fe20f3c..1a81f389 100644 --- a/tradingagents/dataflows/google.py +++ b/tradingagents/dataflows/google.py @@ -1,30 +1,30 @@ -from typing import Annotated -from datetime import datetime -from dateutil.relativedelta import relativedelta -from .googlenews_utils import getNewsData - - -def get_google_news( - query: Annotated[str, "Query to search with"], - curr_date: Annotated[str, "Curr date in yyyy-mm-dd format"], - look_back_days: Annotated[int, "how many days to look back"], -) -> str: - query = query.replace(" ", "+") - - start_date = datetime.strptime(curr_date, "%Y-%m-%d") - before = start_date - relativedelta(days=look_back_days) - before = before.strftime("%Y-%m-%d") - - news_results = getNewsData(query, before, curr_date) - - news_str = "" - - for news in news_results: - news_str += ( - f"### {news['title']} (source: {news['source']}) \n\n{news['snippet']}\n\n" - ) - - if len(news_results) == 0: - return "" - +from typing import Annotated +from datetime import datetime +from dateutil.relativedelta 
import relativedelta +from .googlenews_utils import getNewsData + + +def get_google_news( + query: Annotated[str, "Query to search with"], + curr_date: Annotated[str, "Curr date in yyyy-mm-dd format"], + look_back_days: Annotated[int, "how many days to look back"], +) -> str: + query = query.replace(" ", "+") + + start_date = datetime.strptime(curr_date, "%Y-%m-%d") + before = start_date - relativedelta(days=look_back_days) + before = before.strftime("%Y-%m-%d") + + news_results = getNewsData(query, before, curr_date) + + news_str = "" + + for news in news_results: + news_str += ( + f"### {news['title']} (source: {news['source']}) \n\n{news['snippet']}\n\n" + ) + + if len(news_results) == 0: + return "" + return f"## {query} Google News, from {before} to {curr_date}:\n\n{news_str}" \ No newline at end of file diff --git a/tradingagents/dataflows/googlenews_utils.py b/tradingagents/dataflows/googlenews_utils.py index bdc6124d..6c799ea5 100644 --- a/tradingagents/dataflows/googlenews_utils.py +++ b/tradingagents/dataflows/googlenews_utils.py @@ -1,108 +1,108 @@ -import json -import requests -from bs4 import BeautifulSoup -from datetime import datetime -import time -import random -from tenacity import ( - retry, - stop_after_attempt, - wait_exponential, - retry_if_exception_type, - retry_if_result, -) - - -def is_rate_limited(response): - """Check if the response indicates rate limiting (status code 429)""" - return response.status_code == 429 - - -@retry( - retry=(retry_if_result(is_rate_limited)), - wait=wait_exponential(multiplier=1, min=4, max=60), - stop=stop_after_attempt(5), -) -def make_request(url, headers): - """Make a request with retry logic for rate limiting""" - # Random delay before each request to avoid detection - time.sleep(random.uniform(2, 6)) - response = requests.get(url, headers=headers) - return response - - -def getNewsData(query, start_date, end_date): - """ - Scrape Google News search results for a given query and date range. 
- query: str - search query - start_date: str - start date in the format yyyy-mm-dd or mm/dd/yyyy - end_date: str - end date in the format yyyy-mm-dd or mm/dd/yyyy - """ - if "-" in start_date: - start_date = datetime.strptime(start_date, "%Y-%m-%d") - start_date = start_date.strftime("%m/%d/%Y") - if "-" in end_date: - end_date = datetime.strptime(end_date, "%Y-%m-%d") - end_date = end_date.strftime("%m/%d/%Y") - - headers = { - "User-Agent": ( - "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " - "AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/101.0.4951.54 Safari/537.36" - ) - } - - news_results = [] - page = 0 - while True: - offset = page * 10 - url = ( - f"https://www.google.com/search?q={query}" - f"&tbs=cdr:1,cd_min:{start_date},cd_max:{end_date}" - f"&tbm=nws&start={offset}" - ) - - try: - response = make_request(url, headers) - soup = BeautifulSoup(response.content, "html.parser") - results_on_page = soup.select("div.SoaBEf") - - if not results_on_page: - break # No more results found - - for el in results_on_page: - try: - link = el.find("a")["href"] - title = el.select_one("div.MBeuO").get_text() - snippet = el.select_one(".GI74Re").get_text() - date = el.select_one(".LfVVr").get_text() - source = el.select_one(".NUnG9d span").get_text() - news_results.append( - { - "link": link, - "title": title, - "snippet": snippet, - "date": date, - "source": source, - } - ) - except Exception as e: - print(f"Error processing result: {e}") - # If one of the fields is not found, skip this result - continue - - # Update the progress bar with the current count of results scraped - - # Check for the "Next" link (pagination) - next_link = soup.find("a", id="pnnext") - if not next_link: - break - - page += 1 - - except Exception as e: - print(f"Failed after multiple retries: {e}") - break - - return news_results +import json +import requests +from bs4 import BeautifulSoup +from datetime import datetime +import time +import random +from tenacity import ( + retry, + 
stop_after_attempt, + wait_exponential, + retry_if_exception_type, + retry_if_result, +) + + +def is_rate_limited(response): + """Check if the response indicates rate limiting (status code 429)""" + return response.status_code == 429 + + +@retry( + retry=(retry_if_result(is_rate_limited)), + wait=wait_exponential(multiplier=1, min=4, max=60), + stop=stop_after_attempt(5), +) +def make_request(url, headers): + """Make a request with retry logic for rate limiting""" + # Random delay before each request to avoid detection + time.sleep(random.uniform(2, 6)) + response = requests.get(url, headers=headers) + return response + + +def getNewsData(query, start_date, end_date): + """ + Scrape Google News search results for a given query and date range. + query: str - search query + start_date: str - start date in the format yyyy-mm-dd or mm/dd/yyyy + end_date: str - end date in the format yyyy-mm-dd or mm/dd/yyyy + """ + if "-" in start_date: + start_date = datetime.strptime(start_date, "%Y-%m-%d") + start_date = start_date.strftime("%m/%d/%Y") + if "-" in end_date: + end_date = datetime.strptime(end_date, "%Y-%m-%d") + end_date = end_date.strftime("%m/%d/%Y") + + headers = { + "User-Agent": ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/101.0.4951.54 Safari/537.36" + ) + } + + news_results = [] + page = 0 + while True: + offset = page * 10 + url = ( + f"https://www.google.com/search?q={query}" + f"&tbs=cdr:1,cd_min:{start_date},cd_max:{end_date}" + f"&tbm=nws&start={offset}" + ) + + try: + response = make_request(url, headers) + soup = BeautifulSoup(response.content, "html.parser") + results_on_page = soup.select("div.SoaBEf") + + if not results_on_page: + break # No more results found + + for el in results_on_page: + try: + link = el.find("a")["href"] + title = el.select_one("div.MBeuO").get_text() + snippet = el.select_one(".GI74Re").get_text() + date = el.select_one(".LfVVr").get_text() + source = 
el.select_one(".NUnG9d span").get_text() + news_results.append( + { + "link": link, + "title": title, + "snippet": snippet, + "date": date, + "source": source, + } + ) + except Exception as e: + print(f"Error processing result: {e}") + # If one of the fields is not found, skip this result + continue + + # Update the progress bar with the current count of results scraped + + # Check for the "Next" link (pagination) + next_link = soup.find("a", id="pnnext") + if not next_link: + break + + page += 1 + + except Exception as e: + print(f"Failed after multiple retries: {e}") + break + + return news_results diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 4cd5ddef..71f97796 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -1,244 +1,261 @@ -from typing import Annotated - -# Import from vendor-specific modules -from .local import get_YFin_data, get_finnhub_news, get_finnhub_company_insider_sentiment, get_finnhub_company_insider_transactions, get_simfin_balance_sheet, get_simfin_cashflow, get_simfin_income_statements, get_reddit_global_news, get_reddit_company_news -from .y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions -from .google import get_google_news -from .openai import get_stock_news_openai, get_global_news_openai, get_fundamentals_openai -from .alpha_vantage import ( - get_stock as get_alpha_vantage_stock, - get_indicator as get_alpha_vantage_indicator, - get_fundamentals as get_alpha_vantage_fundamentals, - get_balance_sheet as get_alpha_vantage_balance_sheet, - get_cashflow as get_alpha_vantage_cashflow, - get_income_statement as get_alpha_vantage_income_statement, - get_insider_transactions as get_alpha_vantage_insider_transactions, - get_news as 
get_alpha_vantage_news -) -from .alpha_vantage_common import AlphaVantageRateLimitError - -# Configuration and routing logic -from .config import get_config - -# Tools organized by category -TOOLS_CATEGORIES = { - "core_stock_apis": { - "description": "OHLCV stock price data", - "tools": [ - "get_stock_data" - ] - }, - "technical_indicators": { - "description": "Technical analysis indicators", - "tools": [ - "get_indicators" - ] - }, - "fundamental_data": { - "description": "Company fundamentals", - "tools": [ - "get_fundamentals", - "get_balance_sheet", - "get_cashflow", - "get_income_statement" - ] - }, - "news_data": { - "description": "News (public/insiders, original/processed)", - "tools": [ - "get_news", - "get_global_news", - "get_insider_sentiment", - "get_insider_transactions", - ] - } -} - -VENDOR_LIST = [ - "local", - "yfinance", - "openai", - "google" -] - -# Mapping of methods to their vendor-specific implementations -VENDOR_METHODS = { - # core_stock_apis - "get_stock_data": { - "alpha_vantage": get_alpha_vantage_stock, - "yfinance": get_YFin_data_online, - "local": get_YFin_data, - }, - # technical_indicators - "get_indicators": { - "alpha_vantage": get_alpha_vantage_indicator, - "yfinance": get_stock_stats_indicators_window, - "local": get_stock_stats_indicators_window - }, - # fundamental_data - "get_fundamentals": { - "alpha_vantage": get_alpha_vantage_fundamentals, - "openai": get_fundamentals_openai, - }, - "get_balance_sheet": { - "alpha_vantage": get_alpha_vantage_balance_sheet, - "yfinance": get_yfinance_balance_sheet, - "local": get_simfin_balance_sheet, - }, - "get_cashflow": { - "alpha_vantage": get_alpha_vantage_cashflow, - "yfinance": get_yfinance_cashflow, - "local": get_simfin_cashflow, - }, - "get_income_statement": { - "alpha_vantage": get_alpha_vantage_income_statement, - "yfinance": get_yfinance_income_statement, - "local": get_simfin_income_statements, - }, - # news_data - "get_news": { - "alpha_vantage": get_alpha_vantage_news, - 
"openai": get_stock_news_openai, - "google": get_google_news, - "local": [get_finnhub_news, get_reddit_company_news, get_google_news], - }, - "get_global_news": { - "openai": get_global_news_openai, - "local": get_reddit_global_news - }, - "get_insider_sentiment": { - "local": get_finnhub_company_insider_sentiment - }, - "get_insider_transactions": { - "alpha_vantage": get_alpha_vantage_insider_transactions, - "yfinance": get_yfinance_insider_transactions, - "local": get_finnhub_company_insider_transactions, - }, -} - -def get_category_for_method(method: str) -> str: - """Get the category that contains the specified method.""" - for category, info in TOOLS_CATEGORIES.items(): - if method in info["tools"]: - return category - raise ValueError(f"Method '{method}' not found in any category") - -def get_vendor(category: str, method: str = None) -> str: - """Get the configured vendor for a data category or specific tool method. - Tool-level configuration takes precedence over category-level. 
- """ - config = get_config() - - # Check tool-level configuration first (if method provided) - if method: - tool_vendors = config.get("tool_vendors", {}) - if method in tool_vendors: - return tool_vendors[method] - - # Fall back to category-level configuration - return config.get("data_vendors", {}).get(category, "default") - -def route_to_vendor(method: str, *args, **kwargs): - """Route method calls to appropriate vendor implementation with fallback support.""" - category = get_category_for_method(method) - vendor_config = get_vendor(category, method) - - # Handle comma-separated vendors - primary_vendors = [v.strip() for v in vendor_config.split(',')] - - if method not in VENDOR_METHODS: - raise ValueError(f"Method '{method}' not supported") - - # Get all available vendors for this method for fallback - all_available_vendors = list(VENDOR_METHODS[method].keys()) - - # Create fallback vendor list: primary vendors first, then remaining vendors as fallbacks - fallback_vendors = primary_vendors.copy() - for vendor in all_available_vendors: - if vendor not in fallback_vendors: - fallback_vendors.append(vendor) - - # Debug: Print fallback ordering - primary_str = " โ†’ ".join(primary_vendors) - fallback_str = " โ†’ ".join(fallback_vendors) - print(f"DEBUG: {method} - Primary: [{primary_str}] | Full fallback order: [{fallback_str}]") - - # Track results and execution state - results = [] - vendor_attempt_count = 0 - any_primary_vendor_attempted = False - successful_vendor = None - - for vendor in fallback_vendors: - if vendor not in VENDOR_METHODS[method]: - if vendor in primary_vendors: - print(f"INFO: Vendor '{vendor}' not supported for method '{method}', falling back to next vendor") - continue - - vendor_impl = VENDOR_METHODS[method][vendor] - is_primary_vendor = vendor in primary_vendors - vendor_attempt_count += 1 - - # Track if we attempted any primary vendor - if is_primary_vendor: - any_primary_vendor_attempted = True - - # Debug: Print current attempt - 
vendor_type = "PRIMARY" if is_primary_vendor else "FALLBACK" - print(f"DEBUG: Attempting {vendor_type} vendor '{vendor}' for {method} (attempt #{vendor_attempt_count})") - - # Handle list of methods for a vendor - if isinstance(vendor_impl, list): - vendor_methods = [(impl, vendor) for impl in vendor_impl] - print(f"DEBUG: Vendor '{vendor}' has multiple implementations: {len(vendor_methods)} functions") - else: - vendor_methods = [(vendor_impl, vendor)] - - # Run methods for this vendor - vendor_results = [] - for impl_func, vendor_name in vendor_methods: - try: - print(f"DEBUG: Calling {impl_func.__name__} from vendor '{vendor_name}'...") - result = impl_func(*args, **kwargs) - vendor_results.append(result) - print(f"SUCCESS: {impl_func.__name__} from vendor '{vendor_name}' completed successfully") - - except AlphaVantageRateLimitError as e: - if vendor == "alpha_vantage": - print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded, falling back to next available vendor") - print(f"DEBUG: Rate limit details: {e}") - # Continue to next vendor for fallback - continue - except Exception as e: - # Log error but continue with other implementations - print(f"FAILED: {impl_func.__name__} from vendor '{vendor_name}' failed: {e}") - continue - - # Add this vendor's results - if vendor_results: - results.extend(vendor_results) - successful_vendor = vendor - result_summary = f"Got {len(vendor_results)} result(s)" - print(f"SUCCESS: Vendor '{vendor}' succeeded - {result_summary}") - - # Stopping logic: Stop after first successful vendor for single-vendor configs - # Multiple vendor configs (comma-separated) may want to collect from multiple sources - if len(primary_vendors) == 1: - print(f"DEBUG: Stopping after successful vendor '{vendor}' (single-vendor config)") - break - else: - print(f"FAILED: Vendor '{vendor}' produced no results") - - # Final result summary - if not results: - print(f"FAILURE: All {vendor_attempt_count} vendor attempts failed for method '{method}'") - raise 
RuntimeError(f"All vendor implementations failed for method '{method}'") - else: - print(f"FINAL: Method '{method}' completed with {len(results)} result(s) from {vendor_attempt_count} vendor attempt(s)") - - # Return single result if only one, otherwise concatenate as string - if len(results) == 1: - return results[0] - else: - # Convert all results to strings and concatenate +from typing import Annotated + +# Import from vendor-specific modules +from .local import get_YFin_data, get_finnhub_news, get_finnhub_company_insider_sentiment, get_finnhub_company_insider_transactions, get_simfin_balance_sheet, get_simfin_cashflow, get_simfin_income_statements, get_reddit_global_news, get_reddit_company_news +from .y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions +from .google import get_google_news +from .openai import get_stock_news_openai, get_global_news_openai, get_fundamentals_openai +from .alpha_vantage import ( + get_stock as get_alpha_vantage_stock, + get_indicator as get_alpha_vantage_indicator, + get_fundamentals as get_alpha_vantage_fundamentals, + get_balance_sheet as get_alpha_vantage_balance_sheet, + get_cashflow as get_alpha_vantage_cashflow, + get_income_statement as get_alpha_vantage_income_statement, + get_news as get_alpha_vantage_news, + get_market_movers as get_alpha_vantage_market_movers, + get_earnings_calendar as get_alpha_vantage_earnings_calendar +) +from .social_sentiment import get_trending_social as get_social_trending +from .config import get_config +from .alpha_vantage_common import AlphaVantageRateLimitError + +TOOLS_CATEGORIES = { + "core_stock_apis": { + "description": "Core stock price and volume data", + "tools": [ + "get_stock_data", + "get_market_movers", + "get_earnings_calendar" + ] + }, + 
"technical_indicators": { + "description": "Technical analysis indicators", + "tools": [ + "get_indicators" + ] + }, + "fundamental_data": { + "description": "Company fundamentals", + "tools": [ + "get_fundamentals", + "get_balance_sheet", + "get_cashflow", + "get_income_statement" + ] + }, + "news_data": { + "description": "News (public/insiders, original/processed)", + "tools": [ + "get_news", + "get_global_news", + "get_insider_sentiment", + "get_insider_transactions", + ] + }, + "social_data": { + "description": "Social media trending and sentiment", + "tools": [ + "get_trending_social" + ] + } +} + +VENDOR_LIST = [ + "local", + "yfinance", + "openai", + "google", + "alpha_vantage" +] + +# Mapping of methods to their vendor-specific implementations +VENDOR_METHODS = { + # core_stock_apis + "get_stock_data": { + "alpha_vantage": get_alpha_vantage_stock, + "yfinance": get_YFin_data_online, + "local": get_YFin_data, + }, + "get_market_movers": { + "alpha_vantage": get_alpha_vantage_market_movers, + }, + "get_earnings_calendar": { + "alpha_vantage": get_alpha_vantage_earnings_calendar, + }, + # technical_indicators + "get_indicators": { + "alpha_vantage": get_alpha_vantage_indicator, + "yfinance": get_stock_stats_indicators_window, + "local": get_stock_stats_indicators_window + }, + # fundamental_data + "get_fundamentals": { + "alpha_vantage": get_alpha_vantage_fundamentals, + "openai": get_fundamentals_openai, + }, + "get_balance_sheet": { + "alpha_vantage": get_alpha_vantage_balance_sheet, + "yfinance": get_yfinance_balance_sheet, + "local": get_simfin_balance_sheet, + }, + "get_cashflow": { + "alpha_vantage": get_alpha_vantage_cashflow, + "yfinance": get_yfinance_cashflow, + "local": get_simfin_cashflow, + }, + "get_income_statement": { + "alpha_vantage": get_alpha_vantage_income_statement, + "yfinance": get_yfinance_income_statement, + "local": get_simfin_income_statements, + }, + # news_data + "get_news": { + "alpha_vantage": get_alpha_vantage_news, + 
"openai": get_stock_news_openai, + "google": get_google_news, + "local": [get_finnhub_news, get_reddit_company_news, get_google_news], + }, + "get_global_news": { + "openai": get_global_news_openai, + "local": get_reddit_global_news, + }, + "get_insider_sentiment": { + "local": get_finnhub_company_insider_sentiment, + }, + "get_insider_transactions": { + "yfinance": get_yfinance_insider_transactions, + "local": get_finnhub_company_insider_transactions, + }, + # social_data + "get_trending_social": { + "default": get_social_trending + } +} + +def get_category_for_method(method: str) -> str: + """Find which category a method belongs to.""" + for category, data in TOOLS_CATEGORIES.items(): + if method in data["tools"]: + return category + raise ValueError(f"Method '{method}' not found in any category") + +def get_vendor(category: str, method: str = None) -> str: + """Get the configured vendor for a data category or specific tool method. + Tool-level configuration takes precedence over category-level. 
+ """ + config = get_config() + + # Check tool-level configuration first (if method provided) + if method: + tool_vendors = config.get("tool_vendors", {}) + if method in tool_vendors: + return tool_vendors[method] + + # Fall back to category-level configuration + return config.get("data_vendors", {}).get(category, "default") + +def route_to_vendor(method: str, *args, **kwargs): + """Route method calls to appropriate vendor implementation with fallback support.""" + category = get_category_for_method(method) + vendor_config = get_vendor(category, method) + + # Handle comma-separated vendors + primary_vendors = [v.strip() for v in vendor_config.split(',')] + + if method not in VENDOR_METHODS: + raise ValueError(f"Method '{method}' not supported") + + # Get all available vendors for this method for fallback + all_available_vendors = list(VENDOR_METHODS[method].keys()) + + # Create fallback vendor list: primary vendors first, then remaining vendors as fallbacks + fallback_vendors = primary_vendors.copy() + for vendor in all_available_vendors: + if vendor not in fallback_vendors: + fallback_vendors.append(vendor) + + # Debug: Print fallback ordering + primary_str = " โ†’ ".join(primary_vendors) + fallback_str = " โ†’ ".join(fallback_vendors) + print(f"DEBUG: {method} - Primary: [{primary_str}] | Full fallback order: [{fallback_str}]") + + # Track results and execution state + results = [] + vendor_attempt_count = 0 + any_primary_vendor_attempted = False + successful_vendor = None + + for vendor in fallback_vendors: + if vendor not in VENDOR_METHODS[method]: + if vendor in primary_vendors: + print(f"INFO: Vendor '{vendor}' not supported for method '{method}', falling back to next vendor") + continue + + vendor_impl = VENDOR_METHODS[method][vendor] + is_primary_vendor = vendor in primary_vendors + vendor_attempt_count += 1 + + # Track if we attempted any primary vendor + if is_primary_vendor: + any_primary_vendor_attempted = True + + # Debug: Print current attempt + 
vendor_type = "PRIMARY" if is_primary_vendor else "FALLBACK" + print(f"DEBUG: Attempting {vendor_type} vendor '{vendor}' for {method} (attempt #{vendor_attempt_count})") + + # Handle list of methods for a vendor + if isinstance(vendor_impl, list): + vendor_methods = [(impl, vendor) for impl in vendor_impl] + print(f"DEBUG: Vendor '{vendor}' has multiple implementations: {len(vendor_methods)} functions") + else: + vendor_methods = [(vendor_impl, vendor)] + + # Run methods for this vendor + vendor_results = [] + for impl_func, vendor_name in vendor_methods: + try: + print(f"DEBUG: Calling {impl_func.__name__} from vendor '{vendor_name}'...") + result = impl_func(*args, **kwargs) + vendor_results.append(result) + print(f"SUCCESS: {impl_func.__name__} from vendor '{vendor_name}' completed successfully") + + except AlphaVantageRateLimitError as e: + if vendor == "alpha_vantage": + print(f"RATE_LIMIT: Alpha Vantage rate limit exceeded, falling back to next available vendor") + print(f"DEBUG: Rate limit details: {e}") + # Continue to next vendor for fallback + continue + except Exception as e: + # Log error but continue with other implementations + print(f"FAILED: {impl_func.__name__} from vendor '{vendor_name}' failed: {e}") + continue + + # Add this vendor's results + if vendor_results: + results.extend(vendor_results) + successful_vendor = vendor + result_summary = f"Got {len(vendor_results)} result(s)" + print(f"SUCCESS: Vendor '{vendor}' succeeded - {result_summary}") + + # Stopping logic: Stop after first successful vendor for single-vendor configs + # Multiple vendor configs (comma-separated) may want to collect from multiple sources + if len(primary_vendors) == 1: + print(f"DEBUG: Stopping after successful vendor '{vendor}' (single-vendor config)") + break + else: + print(f"FAILED: Vendor '{vendor}' produced no results") + + # Final result summary + if not results: + print(f"FAILURE: All {vendor_attempt_count} vendor attempts failed for method '{method}'") + raise 
RuntimeError(f"All vendor implementations failed for method '{method}'") + else: + print(f"FINAL: Method '{method}' completed with {len(results)} result(s) from {vendor_attempt_count} vendor attempt(s)") + + # Return single result if only one, otherwise concatenate as string + if len(results) == 1: + return results[0] + else: + # Convert all results to strings and concatenate return '\n'.join(str(result) for result in results) \ No newline at end of file diff --git a/tradingagents/dataflows/local.py b/tradingagents/dataflows/local.py index 502bc43a..6dde9f2e 100644 --- a/tradingagents/dataflows/local.py +++ b/tradingagents/dataflows/local.py @@ -1,475 +1,475 @@ -from typing import Annotated -import pandas as pd -import os -from .config import DATA_DIR -from datetime import datetime -from dateutil.relativedelta import relativedelta -import json -from .reddit_utils import fetch_top_from_category -from tqdm import tqdm - -def get_YFin_data_window( - symbol: Annotated[str, "ticker symbol of the company"], - curr_date: Annotated[str, "Start date in yyyy-mm-dd format"], - look_back_days: Annotated[int, "how many days to look back"], -) -> str: - # calculate past days - date_obj = datetime.strptime(curr_date, "%Y-%m-%d") - before = date_obj - relativedelta(days=look_back_days) - start_date = before.strftime("%Y-%m-%d") - - # read in data - data = pd.read_csv( - os.path.join( - DATA_DIR, - f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv", - ) - ) - - # Extract just the date part for comparison - data["DateOnly"] = data["Date"].str[:10] - - # Filter data between the start and end dates (inclusive) - filtered_data = data[ - (data["DateOnly"] >= start_date) & (data["DateOnly"] <= curr_date) - ] - - # Drop the temporary column we created - filtered_data = filtered_data.drop("DateOnly", axis=1) - - # Set pandas display options to show the full DataFrame - with pd.option_context( - "display.max_rows", None, "display.max_columns", None, "display.width", None 
- ): - df_string = filtered_data.to_string() - - return ( - f"## Raw Market Data for {symbol} from {start_date} to {curr_date}:\n\n" - + df_string - ) - -def get_YFin_data( - symbol: Annotated[str, "ticker symbol of the company"], - start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - end_date: Annotated[str, "End date in yyyy-mm-dd format"], -) -> str: - # read in data - data = pd.read_csv( - os.path.join( - DATA_DIR, - f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv", - ) - ) - - if end_date > "2025-03-25": - raise Exception( - f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25" - ) - - # Extract just the date part for comparison - data["DateOnly"] = data["Date"].str[:10] - - # Filter data between the start and end dates (inclusive) - filtered_data = data[ - (data["DateOnly"] >= start_date) & (data["DateOnly"] <= end_date) - ] - - # Drop the temporary column we created - filtered_data = filtered_data.drop("DateOnly", axis=1) - - # remove the index from the dataframe - filtered_data = filtered_data.reset_index(drop=True) - - return filtered_data - -def get_finnhub_news( - query: Annotated[str, "Search query or ticker symbol"], - start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - end_date: Annotated[str, "End date in yyyy-mm-dd format"], -): - """ - Retrieve news about a company within a time frame - - Args - query (str): Search query or ticker symbol - start_date (str): Start date in yyyy-mm-dd format - end_date (str): End date in yyyy-mm-dd format - Returns - str: dataframe containing the news of the company in the time frame - - """ - - result = get_data_in_range(query, start_date, end_date, "news_data", DATA_DIR) - - if len(result) == 0: - return "" - - combined_result = "" - for day, data in result.items(): - if len(data) == 0: - continue - for entry in data: - current_news = ( - "### " + entry["headline"] + f" ({day})" + "\n" + entry["summary"] - ) - combined_result += 
current_news + "\n\n" - - return f"## {query} News, from {start_date} to {end_date}:\n" + str(combined_result) - - -def get_finnhub_company_insider_sentiment( - ticker: Annotated[str, "ticker symbol for the company"], - curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], -): - """ - Retrieve insider sentiment about a company (retrieved from public SEC information) for the past 15 days - Args: - ticker (str): ticker symbol of the company - curr_date (str): current date you are trading on, yyyy-mm-dd - Returns: - str: a report of the sentiment in the past 15 days starting at curr_date - """ - - date_obj = datetime.strptime(curr_date, "%Y-%m-%d") - before = date_obj - relativedelta(days=15) # Default 15 days lookback - before = before.strftime("%Y-%m-%d") - - data = get_data_in_range(ticker, before, curr_date, "insider_senti", DATA_DIR) - - if len(data) == 0: - return "" - - result_str = "" - seen_dicts = [] - for date, senti_list in data.items(): - for entry in senti_list: - if entry not in seen_dicts: - result_str += f"### {entry['year']}-{entry['month']}:\nChange: {entry['change']}\nMonthly Share Purchase Ratio: {entry['mspr']}\n\n" - seen_dicts.append(entry) - - return ( - f"## {ticker} Insider Sentiment Data for {before} to {curr_date}:\n" - + result_str - + "The change field refers to the net buying/selling from all insiders' transactions. The mspr field refers to monthly share purchase ratio." 
- ) - - -def get_finnhub_company_insider_transactions( - ticker: Annotated[str, "ticker symbol"], - curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], -): - """ - Retrieve insider transcaction information about a company (retrieved from public SEC information) for the past 15 days - Args: - ticker (str): ticker symbol of the company - curr_date (str): current date you are trading at, yyyy-mm-dd - Returns: - str: a report of the company's insider transaction/trading informtaion in the past 15 days - """ - - date_obj = datetime.strptime(curr_date, "%Y-%m-%d") - before = date_obj - relativedelta(days=15) # Default 15 days lookback - before = before.strftime("%Y-%m-%d") - - data = get_data_in_range(ticker, before, curr_date, "insider_trans", DATA_DIR) - - if len(data) == 0: - return "" - - result_str = "" - - seen_dicts = [] - for date, senti_list in data.items(): - for entry in senti_list: - if entry not in seen_dicts: - result_str += f"### Filing Date: {entry['filingDate']}, {entry['name']}:\nChange:{entry['change']}\nShares: {entry['share']}\nTransaction Price: {entry['transactionPrice']}\nTransaction Code: {entry['transactionCode']}\n\n" - seen_dicts.append(entry) - - return ( - f"## {ticker} insider transactions from {before} to {curr_date}:\n" - + result_str - + "The change field reflects the variation in share countโ€”here a negative number indicates a reduction in holdingsโ€”while share specifies the total number of shares involved. The transactionPrice denotes the per-share price at which the trade was executed, and transactionDate marks when the transaction occurred. The name field identifies the insider making the trade, and transactionCode (e.g., S for sale) clarifies the nature of the transaction. FilingDate records when the transaction was officially reported, and the unique id links to the specific SEC filing, as indicated by the source. 
Additionally, the symbol ties the transaction to a particular company, isDerivative flags whether the trade involves derivative securities, and currency notes the currency context of the transaction." - ) - -def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=None): - """ - Gets finnhub data saved and processed on disk. - Args: - start_date (str): Start date in YYYY-MM-DD format. - end_date (str): End date in YYYY-MM-DD format. - data_type (str): Type of data from finnhub to fetch. Can be insider_trans, SEC_filings, news_data, insider_senti, or fin_as_reported. - data_dir (str): Directory where the data is saved. - period (str): Default to none, if there is a period specified, should be annual or quarterly. - """ - - if period: - data_path = os.path.join( - data_dir, - "finnhub_data", - data_type, - f"{ticker}_{period}_data_formatted.json", - ) - else: - data_path = os.path.join( - data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json" - ) - - data = open(data_path, "r") - data = json.load(data) - - # filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD) - filtered_data = {} - for key, value in data.items(): - if start_date <= key <= end_date and len(value) > 0: - filtered_data[key] = value - return filtered_data - -def get_simfin_balance_sheet( - ticker: Annotated[str, "ticker symbol"], - freq: Annotated[ - str, - "reporting frequency of the company's financial history: annual / quarterly", - ], - curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], -): - data_path = os.path.join( - DATA_DIR, - "fundamental_data", - "simfin_data_all", - "balance_sheet", - "companies", - "us", - f"us-balance-{freq}.csv", - ) - df = pd.read_csv(data_path, sep=";") - - # Convert date strings to datetime objects and remove any time components - df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize() - df["Publish Date"] = pd.to_datetime(df["Publish Date"], 
utc=True).dt.normalize() - - # Convert the current date to datetime and normalize - curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize() - - # Filter the DataFrame for the given ticker and for reports that were published on or before the current date - filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)] - - # Check if there are any available reports; if not, return a notification - if filtered_df.empty: - print("No balance sheet available before the given current date.") - return "" - - # Get the most recent balance sheet by selecting the row with the latest Publish Date - latest_balance_sheet = filtered_df.loc[filtered_df["Publish Date"].idxmax()] - - # drop the SimFinID column - latest_balance_sheet = latest_balance_sheet.drop("SimFinId") - - return ( - f"## {freq} balance sheet for {ticker} released on {str(latest_balance_sheet['Publish Date'])[0:10]}: \n" - + str(latest_balance_sheet) - + "\n\nThis includes metadata like reporting dates and currency, share details, and a breakdown of assets, liabilities, and equity. Assets are grouped as current (liquid items like cash and receivables) and noncurrent (long-term investments and property). Liabilities are split between short-term obligations and long-term debts, while equity reflects shareholder funds such as paid-in capital and retained earnings. Together, these components ensure that total assets equal the sum of liabilities and equity." 
- ) - - -def get_simfin_cashflow( - ticker: Annotated[str, "ticker symbol"], - freq: Annotated[ - str, - "reporting frequency of the company's financial history: annual / quarterly", - ], - curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], -): - data_path = os.path.join( - DATA_DIR, - "fundamental_data", - "simfin_data_all", - "cash_flow", - "companies", - "us", - f"us-cashflow-{freq}.csv", - ) - df = pd.read_csv(data_path, sep=";") - - # Convert date strings to datetime objects and remove any time components - df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize() - df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize() - - # Convert the current date to datetime and normalize - curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize() - - # Filter the DataFrame for the given ticker and for reports that were published on or before the current date - filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)] - - # Check if there are any available reports; if not, return a notification - if filtered_df.empty: - print("No cash flow statement available before the given current date.") - return "" - - # Get the most recent cash flow statement by selecting the row with the latest Publish Date - latest_cash_flow = filtered_df.loc[filtered_df["Publish Date"].idxmax()] - - # drop the SimFinID column - latest_cash_flow = latest_cash_flow.drop("SimFinId") - - return ( - f"## {freq} cash flow statement for {ticker} released on {str(latest_cash_flow['Publish Date'])[0:10]}: \n" - + str(latest_cash_flow) - + "\n\nThis includes metadata like reporting dates and currency, share details, and a breakdown of cash movements. Operating activities show cash generated from core business operations, including net income adjustments for non-cash items and working capital changes. Investing activities cover asset acquisitions/disposals and investments. 
Financing activities include debt transactions, equity issuances/repurchases, and dividend payments. The net change in cash represents the overall increase or decrease in the company's cash position during the reporting period." - ) - - -def get_simfin_income_statements( - ticker: Annotated[str, "ticker symbol"], - freq: Annotated[ - str, - "reporting frequency of the company's financial history: annual / quarterly", - ], - curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], -): - data_path = os.path.join( - DATA_DIR, - "fundamental_data", - "simfin_data_all", - "income_statements", - "companies", - "us", - f"us-income-{freq}.csv", - ) - df = pd.read_csv(data_path, sep=";") - - # Convert date strings to datetime objects and remove any time components - df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize() - df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize() - - # Convert the current date to datetime and normalize - curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize() - - # Filter the DataFrame for the given ticker and for reports that were published on or before the current date - filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)] - - # Check if there are any available reports; if not, return a notification - if filtered_df.empty: - print("No income statement available before the given current date.") - return "" - - # Get the most recent income statement by selecting the row with the latest Publish Date - latest_income = filtered_df.loc[filtered_df["Publish Date"].idxmax()] - - # drop the SimFinID column - latest_income = latest_income.drop("SimFinId") - - return ( - f"## {freq} income statement for {ticker} released on {str(latest_income['Publish Date'])[0:10]}: \n" - + str(latest_income) - + "\n\nThis includes metadata like reporting dates and currency, share details, and a comprehensive breakdown of the company's financial performance. 
Starting with Revenue, it shows Cost of Revenue and resulting Gross Profit. Operating Expenses are detailed, including SG&A, R&D, and Depreciation. The statement then shows Operating Income, followed by non-operating items and Interest Expense, leading to Pretax Income. After accounting for Income Tax and any Extraordinary items, it concludes with Net Income, representing the company's bottom-line profit or loss for the period." - ) - - -def get_reddit_global_news( - curr_date: Annotated[str, "Current date in yyyy-mm-dd format"], - look_back_days: Annotated[int, "Number of days to look back"] = 7, - limit: Annotated[int, "Maximum number of articles to return"] = 5, -) -> str: - """ - Retrieve the latest top reddit news - Args: - curr_date: Current date in yyyy-mm-dd format - look_back_days: Number of days to look back (default 7) - limit: Maximum number of articles to return (default 5) - Returns: - str: A formatted string containing the latest news articles posts on reddit - """ - - curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") - before = curr_date_dt - relativedelta(days=look_back_days) - before = before.strftime("%Y-%m-%d") - - posts = [] - # iterate from before to curr_date - curr_iter_date = datetime.strptime(before, "%Y-%m-%d") - - total_iterations = (curr_date_dt - curr_iter_date).days + 1 - pbar = tqdm(desc=f"Getting Global News on {curr_date}", total=total_iterations) - - while curr_iter_date <= curr_date_dt: - curr_date_str = curr_iter_date.strftime("%Y-%m-%d") - fetch_result = fetch_top_from_category( - "global_news", - curr_date_str, - limit, - data_path=os.path.join(DATA_DIR, "reddit_data"), - ) - posts.extend(fetch_result) - curr_iter_date += relativedelta(days=1) - pbar.update(1) - - pbar.close() - - if len(posts) == 0: - return "" - - news_str = "" - for post in posts: - if post["content"] == "": - news_str += f"### {post['title']}\n\n" - else: - news_str += f"### {post['title']}\n\n{post['content']}\n\n" - - return f"## Global News 
Reddit, from {before} to {curr_date}:\n{news_str}" - - -def get_reddit_company_news( - query: Annotated[str, "Search query or ticker symbol"], - start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - end_date: Annotated[str, "End date in yyyy-mm-dd format"], -) -> str: - """ - Retrieve the latest top reddit news - Args: - query: Search query or ticker symbol - start_date: Start date in yyyy-mm-dd format - end_date: End date in yyyy-mm-dd format - Returns: - str: A formatted string containing news articles posts on reddit - """ - - start_date_dt = datetime.strptime(start_date, "%Y-%m-%d") - end_date_dt = datetime.strptime(end_date, "%Y-%m-%d") - - posts = [] - # iterate from start_date to end_date - curr_date = start_date_dt - - total_iterations = (end_date_dt - curr_date).days + 1 - pbar = tqdm( - desc=f"Getting Company News for {query} from {start_date} to {end_date}", - total=total_iterations, - ) - - while curr_date <= end_date_dt: - curr_date_str = curr_date.strftime("%Y-%m-%d") - fetch_result = fetch_top_from_category( - "company_news", - curr_date_str, - 10, # max limit per day - query, - data_path=os.path.join(DATA_DIR, "reddit_data"), - ) - posts.extend(fetch_result) - curr_date += relativedelta(days=1) - - pbar.update(1) - - pbar.close() - - if len(posts) == 0: - return "" - - news_str = "" - for post in posts: - if post["content"] == "": - news_str += f"### {post['title']}\n\n" - else: - news_str += f"### {post['title']}\n\n{post['content']}\n\n" - +from typing import Annotated +import pandas as pd +import os +from .config import DATA_DIR +from datetime import datetime +from dateutil.relativedelta import relativedelta +import json +from .reddit_utils import fetch_top_from_category +from tqdm import tqdm + +def get_YFin_data_window( + symbol: Annotated[str, "ticker symbol of the company"], + curr_date: Annotated[str, "Start date in yyyy-mm-dd format"], + look_back_days: Annotated[int, "how many days to look back"], +) -> str: + # calculate past 
days + date_obj = datetime.strptime(curr_date, "%Y-%m-%d") + before = date_obj - relativedelta(days=look_back_days) + start_date = before.strftime("%Y-%m-%d") + + # read in data + data = pd.read_csv( + os.path.join( + DATA_DIR, + f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv", + ) + ) + + # Extract just the date part for comparison + data["DateOnly"] = data["Date"].str[:10] + + # Filter data between the start and end dates (inclusive) + filtered_data = data[ + (data["DateOnly"] >= start_date) & (data["DateOnly"] <= curr_date) + ] + + # Drop the temporary column we created + filtered_data = filtered_data.drop("DateOnly", axis=1) + + # Set pandas display options to show the full DataFrame + with pd.option_context( + "display.max_rows", None, "display.max_columns", None, "display.width", None + ): + df_string = filtered_data.to_string() + + return ( + f"## Raw Market Data for {symbol} from {start_date} to {curr_date}:\n\n" + + df_string + ) + +def get_YFin_data( + symbol: Annotated[str, "ticker symbol of the company"], + start_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], +) -> str: + # read in data + data = pd.read_csv( + os.path.join( + DATA_DIR, + f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv", + ) + ) + + if end_date > "2025-03-25": + raise Exception( + f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25" + ) + + # Extract just the date part for comparison + data["DateOnly"] = data["Date"].str[:10] + + # Filter data between the start and end dates (inclusive) + filtered_data = data[ + (data["DateOnly"] >= start_date) & (data["DateOnly"] <= end_date) + ] + + # Drop the temporary column we created + filtered_data = filtered_data.drop("DateOnly", axis=1) + + # remove the index from the dataframe + filtered_data = filtered_data.reset_index(drop=True) + + return filtered_data + +def get_finnhub_news( + query: 
Annotated[str, "Search query or ticker symbol"], + start_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], +): + """ + Retrieve news about a company within a time frame + + Args + query (str): Search query or ticker symbol + start_date (str): Start date in yyyy-mm-dd format + end_date (str): End date in yyyy-mm-dd format + Returns + str: dataframe containing the news of the company in the time frame + + """ + + result = get_data_in_range(query, start_date, end_date, "news_data", DATA_DIR) + + if len(result) == 0: + return "" + + combined_result = "" + for day, data in result.items(): + if len(data) == 0: + continue + for entry in data: + current_news = ( + "### " + entry["headline"] + f" ({day})" + "\n" + entry["summary"] + ) + combined_result += current_news + "\n\n" + + return f"## {query} News, from {start_date} to {end_date}:\n" + str(combined_result) + + +def get_finnhub_company_insider_sentiment( + ticker: Annotated[str, "ticker symbol for the company"], + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], +): + """ + Retrieve insider sentiment about a company (retrieved from public SEC information) for the past 15 days + Args: + ticker (str): ticker symbol of the company + curr_date (str): current date you are trading on, yyyy-mm-dd + Returns: + str: a report of the sentiment in the past 15 days starting at curr_date + """ + + date_obj = datetime.strptime(curr_date, "%Y-%m-%d") + before = date_obj - relativedelta(days=15) # Default 15 days lookback + before = before.strftime("%Y-%m-%d") + + data = get_data_in_range(ticker, before, curr_date, "insider_senti", DATA_DIR) + + if len(data) == 0: + return "" + + result_str = "" + seen_dicts = [] + for date, senti_list in data.items(): + for entry in senti_list: + if entry not in seen_dicts: + result_str += f"### {entry['year']}-{entry['month']}:\nChange: {entry['change']}\nMonthly Share Purchase Ratio: {entry['mspr']}\n\n" + 
seen_dicts.append(entry) + + return ( + f"## {ticker} Insider Sentiment Data for {before} to {curr_date}:\n" + + result_str + + "The change field refers to the net buying/selling from all insiders' transactions. The mspr field refers to monthly share purchase ratio." + ) + + +def get_finnhub_company_insider_transactions( + ticker: Annotated[str, "ticker symbol"], + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], +): + """ + Retrieve insider transcaction information about a company (retrieved from public SEC information) for the past 15 days + Args: + ticker (str): ticker symbol of the company + curr_date (str): current date you are trading at, yyyy-mm-dd + Returns: + str: a report of the company's insider transaction/trading informtaion in the past 15 days + """ + + date_obj = datetime.strptime(curr_date, "%Y-%m-%d") + before = date_obj - relativedelta(days=15) # Default 15 days lookback + before = before.strftime("%Y-%m-%d") + + data = get_data_in_range(ticker, before, curr_date, "insider_trans", DATA_DIR) + + if len(data) == 0: + return "" + + result_str = "" + + seen_dicts = [] + for date, senti_list in data.items(): + for entry in senti_list: + if entry not in seen_dicts: + result_str += f"### Filing Date: {entry['filingDate']}, {entry['name']}:\nChange:{entry['change']}\nShares: {entry['share']}\nTransaction Price: {entry['transactionPrice']}\nTransaction Code: {entry['transactionCode']}\n\n" + seen_dicts.append(entry) + + return ( + f"## {ticker} insider transactions from {before} to {curr_date}:\n" + + result_str + + "The change field reflects the variation in share countโ€”here a negative number indicates a reduction in holdingsโ€”while share specifies the total number of shares involved. The transactionPrice denotes the per-share price at which the trade was executed, and transactionDate marks when the transaction occurred. 
The name field identifies the insider making the trade, and transactionCode (e.g., S for sale) clarifies the nature of the transaction. FilingDate records when the transaction was officially reported, and the unique id links to the specific SEC filing, as indicated by the source. Additionally, the symbol ties the transaction to a particular company, isDerivative flags whether the trade involves derivative securities, and currency notes the currency context of the transaction." + ) + +def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=None): + """ + Gets finnhub data saved and processed on disk. + Args: + start_date (str): Start date in YYYY-MM-DD format. + end_date (str): End date in YYYY-MM-DD format. + data_type (str): Type of data from finnhub to fetch. Can be insider_trans, SEC_filings, news_data, insider_senti, or fin_as_reported. + data_dir (str): Directory where the data is saved. + period (str): Default to none, if there is a period specified, should be annual or quarterly. 
+ """ + + if period: + data_path = os.path.join( + data_dir, + "finnhub_data", + data_type, + f"{ticker}_{period}_data_formatted.json", + ) + else: + data_path = os.path.join( + data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json" + ) + + data = open(data_path, "r") + data = json.load(data) + + # filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD) + filtered_data = {} + for key, value in data.items(): + if start_date <= key <= end_date and len(value) > 0: + filtered_data[key] = value + return filtered_data + +def get_simfin_balance_sheet( + ticker: Annotated[str, "ticker symbol"], + freq: Annotated[ + str, + "reporting frequency of the company's financial history: annual / quarterly", + ], + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], +): + data_path = os.path.join( + DATA_DIR, + "fundamental_data", + "simfin_data_all", + "balance_sheet", + "companies", + "us", + f"us-balance-{freq}.csv", + ) + df = pd.read_csv(data_path, sep=";") + + # Convert date strings to datetime objects and remove any time components + df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize() + df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize() + + # Convert the current date to datetime and normalize + curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize() + + # Filter the DataFrame for the given ticker and for reports that were published on or before the current date + filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)] + + # Check if there are any available reports; if not, return a notification + if filtered_df.empty: + print("No balance sheet available before the given current date.") + return "" + + # Get the most recent balance sheet by selecting the row with the latest Publish Date + latest_balance_sheet = filtered_df.loc[filtered_df["Publish Date"].idxmax()] + + # drop the SimFinID column + 
latest_balance_sheet = latest_balance_sheet.drop("SimFinId") + + return ( + f"## {freq} balance sheet for {ticker} released on {str(latest_balance_sheet['Publish Date'])[0:10]}: \n" + + str(latest_balance_sheet) + + "\n\nThis includes metadata like reporting dates and currency, share details, and a breakdown of assets, liabilities, and equity. Assets are grouped as current (liquid items like cash and receivables) and noncurrent (long-term investments and property). Liabilities are split between short-term obligations and long-term debts, while equity reflects shareholder funds such as paid-in capital and retained earnings. Together, these components ensure that total assets equal the sum of liabilities and equity." + ) + + +def get_simfin_cashflow( + ticker: Annotated[str, "ticker symbol"], + freq: Annotated[ + str, + "reporting frequency of the company's financial history: annual / quarterly", + ], + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], +): + data_path = os.path.join( + DATA_DIR, + "fundamental_data", + "simfin_data_all", + "cash_flow", + "companies", + "us", + f"us-cashflow-{freq}.csv", + ) + df = pd.read_csv(data_path, sep=";") + + # Convert date strings to datetime objects and remove any time components + df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize() + df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize() + + # Convert the current date to datetime and normalize + curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize() + + # Filter the DataFrame for the given ticker and for reports that were published on or before the current date + filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)] + + # Check if there are any available reports; if not, return a notification + if filtered_df.empty: + print("No cash flow statement available before the given current date.") + return "" + + # Get the most recent cash flow statement by selecting the 
row with the latest Publish Date + latest_cash_flow = filtered_df.loc[filtered_df["Publish Date"].idxmax()] + + # drop the SimFinID column + latest_cash_flow = latest_cash_flow.drop("SimFinId") + + return ( + f"## {freq} cash flow statement for {ticker} released on {str(latest_cash_flow['Publish Date'])[0:10]}: \n" + + str(latest_cash_flow) + + "\n\nThis includes metadata like reporting dates and currency, share details, and a breakdown of cash movements. Operating activities show cash generated from core business operations, including net income adjustments for non-cash items and working capital changes. Investing activities cover asset acquisitions/disposals and investments. Financing activities include debt transactions, equity issuances/repurchases, and dividend payments. The net change in cash represents the overall increase or decrease in the company's cash position during the reporting period." + ) + + +def get_simfin_income_statements( + ticker: Annotated[str, "ticker symbol"], + freq: Annotated[ + str, + "reporting frequency of the company's financial history: annual / quarterly", + ], + curr_date: Annotated[str, "current date you are trading at, yyyy-mm-dd"], +): + data_path = os.path.join( + DATA_DIR, + "fundamental_data", + "simfin_data_all", + "income_statements", + "companies", + "us", + f"us-income-{freq}.csv", + ) + df = pd.read_csv(data_path, sep=";") + + # Convert date strings to datetime objects and remove any time components + df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize() + df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize() + + # Convert the current date to datetime and normalize + curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize() + + # Filter the DataFrame for the given ticker and for reports that were published on or before the current date + filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)] + + # Check if there are any available 
reports; if not, return a notification + if filtered_df.empty: + print("No income statement available before the given current date.") + return "" + + # Get the most recent income statement by selecting the row with the latest Publish Date + latest_income = filtered_df.loc[filtered_df["Publish Date"].idxmax()] + + # drop the SimFinID column + latest_income = latest_income.drop("SimFinId") + + return ( + f"## {freq} income statement for {ticker} released on {str(latest_income['Publish Date'])[0:10]}: \n" + + str(latest_income) + + "\n\nThis includes metadata like reporting dates and currency, share details, and a comprehensive breakdown of the company's financial performance. Starting with Revenue, it shows Cost of Revenue and resulting Gross Profit. Operating Expenses are detailed, including SG&A, R&D, and Depreciation. The statement then shows Operating Income, followed by non-operating items and Interest Expense, leading to Pretax Income. After accounting for Income Tax and any Extraordinary items, it concludes with Net Income, representing the company's bottom-line profit or loss for the period." 
+ ) + + +def get_reddit_global_news( + curr_date: Annotated[str, "Current date in yyyy-mm-dd format"], + look_back_days: Annotated[int, "Number of days to look back"] = 7, + limit: Annotated[int, "Maximum number of articles to return"] = 5, +) -> str: + """ + Retrieve the latest top reddit news + Args: + curr_date: Current date in yyyy-mm-dd format + look_back_days: Number of days to look back (default 7) + limit: Maximum number of articles to return (default 5) + Returns: + str: A formatted string containing the latest news articles posts on reddit + """ + + curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") + before = curr_date_dt - relativedelta(days=look_back_days) + before = before.strftime("%Y-%m-%d") + + posts = [] + # iterate from before to curr_date + curr_iter_date = datetime.strptime(before, "%Y-%m-%d") + + total_iterations = (curr_date_dt - curr_iter_date).days + 1 + pbar = tqdm(desc=f"Getting Global News on {curr_date}", total=total_iterations) + + while curr_iter_date <= curr_date_dt: + curr_date_str = curr_iter_date.strftime("%Y-%m-%d") + fetch_result = fetch_top_from_category( + "global_news", + curr_date_str, + limit, + data_path=os.path.join(DATA_DIR, "reddit_data"), + ) + posts.extend(fetch_result) + curr_iter_date += relativedelta(days=1) + pbar.update(1) + + pbar.close() + + if len(posts) == 0: + return "" + + news_str = "" + for post in posts: + if post["content"] == "": + news_str += f"### {post['title']}\n\n" + else: + news_str += f"### {post['title']}\n\n{post['content']}\n\n" + + return f"## Global News Reddit, from {before} to {curr_date}:\n{news_str}" + + +def get_reddit_company_news( + query: Annotated[str, "Search query or ticker symbol"], + start_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], +) -> str: + """ + Retrieve the latest top reddit news + Args: + query: Search query or ticker symbol + start_date: Start date in yyyy-mm-dd format + end_date: End date 
in yyyy-mm-dd format + Returns: + str: A formatted string containing news articles posts on reddit + """ + + start_date_dt = datetime.strptime(start_date, "%Y-%m-%d") + end_date_dt = datetime.strptime(end_date, "%Y-%m-%d") + + posts = [] + # iterate from start_date to end_date + curr_date = start_date_dt + + total_iterations = (end_date_dt - curr_date).days + 1 + pbar = tqdm( + desc=f"Getting Company News for {query} from {start_date} to {end_date}", + total=total_iterations, + ) + + while curr_date <= end_date_dt: + curr_date_str = curr_date.strftime("%Y-%m-%d") + fetch_result = fetch_top_from_category( + "company_news", + curr_date_str, + 10, # max limit per day + query, + data_path=os.path.join(DATA_DIR, "reddit_data"), + ) + posts.extend(fetch_result) + curr_date += relativedelta(days=1) + + pbar.update(1) + + pbar.close() + + if len(posts) == 0: + return "" + + news_str = "" + for post in posts: + if post["content"] == "": + news_str += f"### {post['title']}\n\n" + else: + news_str += f"### {post['title']}\n\n{post['content']}\n\n" + return f"##{query} News Reddit, from {start_date} to {end_date}:\n\n{news_str}" \ No newline at end of file diff --git a/tradingagents/dataflows/openai.py b/tradingagents/dataflows/openai.py index 91a2258b..04347300 100644 --- a/tradingagents/dataflows/openai.py +++ b/tradingagents/dataflows/openai.py @@ -1,107 +1,107 @@ -from openai import OpenAI -from .config import get_config - - -def get_stock_news_openai(query, start_date, end_date): - config = get_config() - client = OpenAI(base_url=config["backend_url"]) - - response = client.responses.create( - model=config["quick_think_llm"], - input=[ - { - "role": "system", - "content": [ - { - "type": "input_text", - "text": f"Can you search Social Media for {query} from {start_date} to {end_date}? 
Make sure you only get the data posted during that period.", - } - ], - } - ], - text={"format": {"type": "text"}}, - reasoning={}, - tools=[ - { - "type": "web_search_preview", - "user_location": {"type": "approximate"}, - "search_context_size": "low", - } - ], - temperature=1, - max_output_tokens=4096, - top_p=1, - store=True, - ) - - return response.output[1].content[0].text - - -def get_global_news_openai(curr_date, look_back_days=7, limit=5): - config = get_config() - client = OpenAI(base_url=config["backend_url"]) - - response = client.responses.create( - model=config["quick_think_llm"], - input=[ - { - "role": "system", - "content": [ - { - "type": "input_text", - "text": f"Can you search global or macroeconomics news from {look_back_days} days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period. Limit the results to {limit} articles.", - } - ], - } - ], - text={"format": {"type": "text"}}, - reasoning={}, - tools=[ - { - "type": "web_search_preview", - "user_location": {"type": "approximate"}, - "search_context_size": "low", - } - ], - temperature=1, - max_output_tokens=4096, - top_p=1, - store=True, - ) - - return response.output[1].content[0].text - - -def get_fundamentals_openai(ticker, curr_date): - config = get_config() - client = OpenAI(base_url=config["backend_url"]) - - response = client.responses.create( - model=config["quick_think_llm"], - input=[ - { - "role": "system", - "content": [ - { - "type": "input_text", - "text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. 
List as a table, with PE/PS/Cash flow/ etc", - } - ], - } - ], - text={"format": {"type": "text"}}, - reasoning={}, - tools=[ - { - "type": "web_search_preview", - "user_location": {"type": "approximate"}, - "search_context_size": "low", - } - ], - temperature=1, - max_output_tokens=4096, - top_p=1, - store=True, - ) - +from openai import OpenAI +from .config import get_config + + +def get_stock_news_openai(query, start_date, end_date): + config = get_config() + client = OpenAI(base_url=config["backend_url"]) + + response = client.responses.create( + model=config["quick_think_llm"], + input=[ + { + "role": "system", + "content": [ + { + "type": "input_text", + "text": f"Can you search Social Media for {query} from {start_date} to {end_date}? Make sure you only get the data posted during that period.", + } + ], + } + ], + text={"format": {"type": "text"}}, + reasoning={}, + tools=[ + { + "type": "web_search_preview", + "user_location": {"type": "approximate"}, + "search_context_size": "low", + } + ], + temperature=1, + max_output_tokens=4096, + top_p=1, + store=True, + ) + + return response.output[1].content[0].text + + +def get_global_news_openai(curr_date, look_back_days=7, limit=5): + config = get_config() + client = OpenAI(base_url=config["backend_url"]) + + response = client.responses.create( + model=config["quick_think_llm"], + input=[ + { + "role": "system", + "content": [ + { + "type": "input_text", + "text": f"Can you search global or macroeconomics news from {look_back_days} days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period. 
Limit the results to {limit} articles.", + } + ], + } + ], + text={"format": {"type": "text"}}, + reasoning={}, + tools=[ + { + "type": "web_search_preview", + "user_location": {"type": "approximate"}, + "search_context_size": "low", + } + ], + temperature=1, + max_output_tokens=4096, + top_p=1, + store=True, + ) + + return response.output[1].content[0].text + + +def get_fundamentals_openai(ticker, curr_date): + config = get_config() + client = OpenAI(base_url=config["backend_url"]) + + response = client.responses.create( + model=config["quick_think_llm"], + input=[ + { + "role": "system", + "content": [ + { + "type": "input_text", + "text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc", + } + ], + } + ], + text={"format": {"type": "text"}}, + reasoning={}, + tools=[ + { + "type": "web_search_preview", + "user_location": {"type": "approximate"}, + "search_context_size": "low", + } + ], + temperature=1, + max_output_tokens=4096, + top_p=1, + store=True, + ) + return response.output[1].content[0].text \ No newline at end of file diff --git a/tradingagents/dataflows/reddit_utils.py b/tradingagents/dataflows/reddit_utils.py index 2532f0d1..d5decdea 100644 --- a/tradingagents/dataflows/reddit_utils.py +++ b/tradingagents/dataflows/reddit_utils.py @@ -1,135 +1,135 @@ -import requests -import time -import json -from datetime import datetime, timedelta -from contextlib import contextmanager -from typing import Annotated -import os -import re - -ticker_to_company = { - "AAPL": "Apple", - "MSFT": "Microsoft", - "GOOGL": "Google", - "AMZN": "Amazon", - "TSLA": "Tesla", - "NVDA": "Nvidia", - "TSM": "Taiwan Semiconductor Manufacturing Company OR TSMC", - "JPM": "JPMorgan Chase OR JP Morgan", - "JNJ": "Johnson & Johnson OR JNJ", - "V": "Visa", - "WMT": "Walmart", - "META": "Meta OR Facebook", - 
"AMD": "AMD", - "INTC": "Intel", - "QCOM": "Qualcomm", - "BABA": "Alibaba", - "ADBE": "Adobe", - "NFLX": "Netflix", - "CRM": "Salesforce", - "PYPL": "PayPal", - "PLTR": "Palantir", - "MU": "Micron", - "SQ": "Block OR Square", - "ZM": "Zoom", - "CSCO": "Cisco", - "SHOP": "Shopify", - "ORCL": "Oracle", - "X": "Twitter OR X", - "SPOT": "Spotify", - "AVGO": "Broadcom", - "ASML": "ASML ", - "TWLO": "Twilio", - "SNAP": "Snap Inc.", - "TEAM": "Atlassian", - "SQSP": "Squarespace", - "UBER": "Uber", - "ROKU": "Roku", - "PINS": "Pinterest", -} - - -def fetch_top_from_category( - category: Annotated[ - str, "Category to fetch top post from. Collection of subreddits." - ], - date: Annotated[str, "Date to fetch top posts from."], - max_limit: Annotated[int, "Maximum number of posts to fetch."], - query: Annotated[str, "Optional query to search for in the subreddit."] = None, - data_path: Annotated[ - str, - "Path to the data folder. Default is 'reddit_data'.", - ] = "reddit_data", -): - base_path = data_path - - all_content = [] - - if max_limit < len(os.listdir(os.path.join(base_path, category))): - raise ValueError( - "REDDIT FETCHING ERROR: max limit is less than the number of files in the category. 
Will not be able to fetch any posts" - ) - - limit_per_subreddit = max_limit // len( - os.listdir(os.path.join(base_path, category)) - ) - - for data_file in os.listdir(os.path.join(base_path, category)): - # check if data_file is a .jsonl file - if not data_file.endswith(".jsonl"): - continue - - all_content_curr_subreddit = [] - - with open(os.path.join(base_path, category, data_file), "rb") as f: - for i, line in enumerate(f): - # skip empty lines - if not line.strip(): - continue - - parsed_line = json.loads(line) - - # select only lines that are from the date - post_date = datetime.utcfromtimestamp( - parsed_line["created_utc"] - ).strftime("%Y-%m-%d") - if post_date != date: - continue - - # if is company_news, check that the title or the content has the company's name (query) mentioned - if "company" in category and query: - search_terms = [] - if "OR" in ticker_to_company[query]: - search_terms = ticker_to_company[query].split(" OR ") - else: - search_terms = [ticker_to_company[query]] - - search_terms.append(query) - - found = False - for term in search_terms: - if re.search( - term, parsed_line["title"], re.IGNORECASE - ) or re.search(term, parsed_line["selftext"], re.IGNORECASE): - found = True - break - - if not found: - continue - - post = { - "title": parsed_line["title"], - "content": parsed_line["selftext"], - "url": parsed_line["url"], - "upvotes": parsed_line["ups"], - "posted_date": post_date, - } - - all_content_curr_subreddit.append(post) - - # sort all_content_curr_subreddit by upvote_ratio in descending order - all_content_curr_subreddit.sort(key=lambda x: x["upvotes"], reverse=True) - - all_content.extend(all_content_curr_subreddit[:limit_per_subreddit]) - - return all_content +import requests +import time +import json +from datetime import datetime, timedelta +from contextlib import contextmanager +from typing import Annotated +import os +import re + +ticker_to_company = { + "AAPL": "Apple", + "MSFT": "Microsoft", + "GOOGL": "Google", + 
"AMZN": "Amazon", + "TSLA": "Tesla", + "NVDA": "Nvidia", + "TSM": "Taiwan Semiconductor Manufacturing Company OR TSMC", + "JPM": "JPMorgan Chase OR JP Morgan", + "JNJ": "Johnson & Johnson OR JNJ", + "V": "Visa", + "WMT": "Walmart", + "META": "Meta OR Facebook", + "AMD": "AMD", + "INTC": "Intel", + "QCOM": "Qualcomm", + "BABA": "Alibaba", + "ADBE": "Adobe", + "NFLX": "Netflix", + "CRM": "Salesforce", + "PYPL": "PayPal", + "PLTR": "Palantir", + "MU": "Micron", + "SQ": "Block OR Square", + "ZM": "Zoom", + "CSCO": "Cisco", + "SHOP": "Shopify", + "ORCL": "Oracle", + "X": "Twitter OR X", + "SPOT": "Spotify", + "AVGO": "Broadcom", + "ASML": "ASML ", + "TWLO": "Twilio", + "SNAP": "Snap Inc.", + "TEAM": "Atlassian", + "SQSP": "Squarespace", + "UBER": "Uber", + "ROKU": "Roku", + "PINS": "Pinterest", +} + + +def fetch_top_from_category( + category: Annotated[ + str, "Category to fetch top post from. Collection of subreddits." + ], + date: Annotated[str, "Date to fetch top posts from."], + max_limit: Annotated[int, "Maximum number of posts to fetch."], + query: Annotated[str, "Optional query to search for in the subreddit."] = None, + data_path: Annotated[ + str, + "Path to the data folder. Default is 'reddit_data'.", + ] = "reddit_data", +): + base_path = data_path + + all_content = [] + + if max_limit < len(os.listdir(os.path.join(base_path, category))): + raise ValueError( + "REDDIT FETCHING ERROR: max limit is less than the number of files in the category. 
Will not be able to fetch any posts" + ) + + limit_per_subreddit = max_limit // len( + os.listdir(os.path.join(base_path, category)) + ) + + for data_file in os.listdir(os.path.join(base_path, category)): + # check if data_file is a .jsonl file + if not data_file.endswith(".jsonl"): + continue + + all_content_curr_subreddit = [] + + with open(os.path.join(base_path, category, data_file), "rb") as f: + for i, line in enumerate(f): + # skip empty lines + if not line.strip(): + continue + + parsed_line = json.loads(line) + + # select only lines that are from the date + post_date = datetime.utcfromtimestamp( + parsed_line["created_utc"] + ).strftime("%Y-%m-%d") + if post_date != date: + continue + + # if is company_news, check that the title or the content has the company's name (query) mentioned + if "company" in category and query: + search_terms = [] + if "OR" in ticker_to_company[query]: + search_terms = ticker_to_company[query].split(" OR ") + else: + search_terms = [ticker_to_company[query]] + + search_terms.append(query) + + found = False + for term in search_terms: + if re.search( + term, parsed_line["title"], re.IGNORECASE + ) or re.search(term, parsed_line["selftext"], re.IGNORECASE): + found = True + break + + if not found: + continue + + post = { + "title": parsed_line["title"], + "content": parsed_line["selftext"], + "url": parsed_line["url"], + "upvotes": parsed_line["ups"], + "posted_date": post_date, + } + + all_content_curr_subreddit.append(post) + + # sort all_content_curr_subreddit by upvote_ratio in descending order + all_content_curr_subreddit.sort(key=lambda x: x["upvotes"], reverse=True) + + all_content.extend(all_content_curr_subreddit[:limit_per_subreddit]) + + return all_content diff --git a/tradingagents/dataflows/social_sentiment.py b/tradingagents/dataflows/social_sentiment.py new file mode 100644 index 00000000..d2dd818c --- /dev/null +++ b/tradingagents/dataflows/social_sentiment.py @@ -0,0 +1,54 @@ +import requests +from 
langchain_core.tools import tool +from typing import Annotated + +def get_stocktwits_trending() -> list[str]: + """Fetch trending symbols from StockTwits.""" + url = "https://api.stocktwits.com/api/2/trending/symbols.json" + try: + response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}) + if response.status_code == 200: + data = response.json() + return [s['symbol'] for s in data['symbols']] + return [] + except Exception as e: + print(f"Error fetching StockTwits trending: {e}") + return [] + +def get_apewisdom_trending() -> list[str]: + """Fetch trending tickers from Reddit via Ape Wisdom.""" + url = "https://apewisdom.io/api/v1.0/filter/all-stocks/page/1" + try: + response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}) + if response.status_code == 200: + data = response.json() + return [s['ticker'] for s in data['results']] + return [] + except Exception as e: + print(f"Error fetching Ape Wisdom trending: {e}") + return [] + +@tool +def get_trending_social( + platform: Annotated[str, "Platform to check: 'stocktwits', 'reddit', or 'all'"] = "all" +) -> str: + """ + Retrieve a list of trending stocks from social media platforms (StockTwits, Reddit). + Useful for finding 'hyped' stocks or retail sentiment plays. + """ + results = [] + + if platform in ["stocktwits", "all"]: + st_symbols = get_stocktwits_trending() + if st_symbols: + results.append(f"StockTwits Trending: {', '.join(st_symbols[:10])}") + + if platform in ["reddit", "all"]: + aw_symbols = get_apewisdom_trending() + if aw_symbols: + results.append(f"Reddit Trending (Ape Wisdom): {', '.join(aw_symbols[:10])}") + + if not results: + return "No trending data available." 
+ + return "\n\n".join(results) diff --git a/tradingagents/dataflows/stockstats_utils.py b/tradingagents/dataflows/stockstats_utils.py index e81684e0..d1def1d7 100644 --- a/tradingagents/dataflows/stockstats_utils.py +++ b/tradingagents/dataflows/stockstats_utils.py @@ -1,82 +1,82 @@ -import pandas as pd -import yfinance as yf -from stockstats import wrap -from typing import Annotated -import os -from .config import get_config, DATA_DIR - - -class StockstatsUtils: - @staticmethod - def get_stock_stats( - symbol: Annotated[str, "ticker symbol for the company"], - indicator: Annotated[ - str, "quantitative indicators based off of the stock data for the company" - ], - curr_date: Annotated[ - str, "curr date for retrieving stock price data, YYYY-mm-dd" - ], - ): - # Get config and set up data directory path - config = get_config() - online = config["data_vendors"]["technical_indicators"] != "local" - - df = None - data = None - - if not online: - try: - data = pd.read_csv( - os.path.join( - DATA_DIR, - f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv", - ) - ) - df = wrap(data) - except FileNotFoundError: - raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!") - else: - # Get today's date as YYYY-mm-dd to add to cache - today_date = pd.Timestamp.today() - curr_date = pd.to_datetime(curr_date) - - end_date = today_date - start_date = today_date - pd.DateOffset(years=15) - start_date = start_date.strftime("%Y-%m-%d") - end_date = end_date.strftime("%Y-%m-%d") - - # Get config and ensure cache directory exists - os.makedirs(config["data_cache_dir"], exist_ok=True) - - data_file = os.path.join( - config["data_cache_dir"], - f"{symbol}-YFin-data-{start_date}-{end_date}.csv", - ) - - if os.path.exists(data_file): - data = pd.read_csv(data_file) - data["Date"] = pd.to_datetime(data["Date"]) - else: - data = yf.download( - symbol, - start=start_date, - end=end_date, - multi_level_index=False, - progress=False, - auto_adjust=True, - ) - data = 
data.reset_index() - data.to_csv(data_file, index=False) - - df = wrap(data) - df["Date"] = df["Date"].dt.strftime("%Y-%m-%d") - curr_date = curr_date.strftime("%Y-%m-%d") - - df[indicator] # trigger stockstats to calculate the indicator - matching_rows = df[df["Date"].str.startswith(curr_date)] - - if not matching_rows.empty: - indicator_value = matching_rows[indicator].values[0] - return indicator_value - else: - return "N/A: Not a trading day (weekend or holiday)" +import pandas as pd +import yfinance as yf +from stockstats import wrap +from typing import Annotated +import os +from .config import get_config, DATA_DIR + + +class StockstatsUtils: + @staticmethod + def get_stock_stats( + symbol: Annotated[str, "ticker symbol for the company"], + indicator: Annotated[ + str, "quantitative indicators based off of the stock data for the company" + ], + curr_date: Annotated[ + str, "curr date for retrieving stock price data, YYYY-mm-dd" + ], + ): + # Get config and set up data directory path + config = get_config() + online = config["data_vendors"]["technical_indicators"] != "local" + + df = None + data = None + + if not online: + try: + data = pd.read_csv( + os.path.join( + DATA_DIR, + f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv", + ) + ) + df = wrap(data) + except FileNotFoundError: + raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!") + else: + # Get today's date as YYYY-mm-dd to add to cache + today_date = pd.Timestamp.today() + curr_date = pd.to_datetime(curr_date) + + end_date = today_date + start_date = today_date - pd.DateOffset(years=15) + start_date = start_date.strftime("%Y-%m-%d") + end_date = end_date.strftime("%Y-%m-%d") + + # Get config and ensure cache directory exists + os.makedirs(config["data_cache_dir"], exist_ok=True) + + data_file = os.path.join( + config["data_cache_dir"], + f"{symbol}-YFin-data-{start_date}-{end_date}.csv", + ) + + if os.path.exists(data_file): + data = pd.read_csv(data_file) + data["Date"] = 
pd.to_datetime(data["Date"]) + else: + data = yf.download( + symbol, + start=start_date, + end=end_date, + multi_level_index=False, + progress=False, + auto_adjust=True, + ) + data = data.reset_index() + data.to_csv(data_file, index=False) + + df = wrap(data) + df["Date"] = df["Date"].dt.strftime("%Y-%m-%d") + curr_date = curr_date.strftime("%Y-%m-%d") + + df[indicator] # trigger stockstats to calculate the indicator + matching_rows = df[df["Date"].str.startswith(curr_date)] + + if not matching_rows.empty: + indicator_value = matching_rows[indicator].values[0] + return indicator_value + else: + return "N/A: Not a trading day (weekend or holiday)" diff --git a/tradingagents/dataflows/utils.py b/tradingagents/dataflows/utils.py index 4523de19..1d8c2c4f 100644 --- a/tradingagents/dataflows/utils.py +++ b/tradingagents/dataflows/utils.py @@ -1,39 +1,39 @@ -import os -import json -import pandas as pd -from datetime import date, timedelta, datetime -from typing import Annotated - -SavePathType = Annotated[str, "File path to save data. 
If None, data is not saved."] - -def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None: - if save_path: - data.to_csv(save_path) - print(f"{tag} saved to {save_path}") - - -def get_current_date(): - return date.today().strftime("%Y-%m-%d") - - -def decorate_all_methods(decorator): - def class_decorator(cls): - for attr_name, attr_value in cls.__dict__.items(): - if callable(attr_value): - setattr(cls, attr_name, decorator(attr_value)) - return cls - - return class_decorator - - -def get_next_weekday(date): - - if not isinstance(date, datetime): - date = datetime.strptime(date, "%Y-%m-%d") - - if date.weekday() >= 5: - days_to_add = 7 - date.weekday() - next_weekday = date + timedelta(days=days_to_add) - return next_weekday - else: - return date +import os +import json +import pandas as pd +from datetime import date, timedelta, datetime +from typing import Annotated + +SavePathType = Annotated[str, "File path to save data. If None, data is not saved."] + +def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None: + if save_path: + data.to_csv(save_path) + print(f"{tag} saved to {save_path}") + + +def get_current_date(): + return date.today().strftime("%Y-%m-%d") + + +def decorate_all_methods(decorator): + def class_decorator(cls): + for attr_name, attr_value in cls.__dict__.items(): + if callable(attr_value): + setattr(cls, attr_name, decorator(attr_value)) + return cls + + return class_decorator + + +def get_next_weekday(date): + + if not isinstance(date, datetime): + date = datetime.strptime(date, "%Y-%m-%d") + + if date.weekday() >= 5: + days_to_add = 7 - date.weekday() + next_weekday = date + timedelta(days=days_to_add) + return next_weekday + else: + return date diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py index da7273d5..8e3ca808 100644 --- a/tradingagents/dataflows/y_finance.py +++ b/tradingagents/dataflows/y_finance.py @@ -1,407 +1,407 @@ -from typing 
import Annotated -from datetime import datetime -from dateutil.relativedelta import relativedelta -import yfinance as yf -import os -from .stockstats_utils import StockstatsUtils - -def get_YFin_data_online( - symbol: Annotated[str, "ticker symbol of the company"], - start_date: Annotated[str, "Start date in yyyy-mm-dd format"], - end_date: Annotated[str, "End date in yyyy-mm-dd format"], -): - - datetime.strptime(start_date, "%Y-%m-%d") - datetime.strptime(end_date, "%Y-%m-%d") - - # Create ticker object - ticker = yf.Ticker(symbol.upper()) - - # Fetch historical data for the specified date range - data = ticker.history(start=start_date, end=end_date) - - # Check if data is empty - if data.empty: - return ( - f"No data found for symbol '{symbol}' between {start_date} and {end_date}" - ) - - # Remove timezone info from index for cleaner output - if data.index.tz is not None: - data.index = data.index.tz_localize(None) - - # Round numerical values to 2 decimal places for cleaner display - numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"] - for col in numeric_columns: - if col in data.columns: - data[col] = data[col].round(2) - - # Convert DataFrame to CSV string - csv_string = data.to_csv() - - # Add header information - header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n" - header += f"# Total records: {len(data)}\n" - header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - - return header + csv_string - -def get_stock_stats_indicators_window( - symbol: Annotated[str, "ticker symbol of the company"], - indicator: Annotated[str, "technical indicator to get the analysis and report of"], - curr_date: Annotated[ - str, "The current trading date you are trading on, YYYY-mm-dd" - ], - look_back_days: Annotated[int, "how many days to look back"], -) -> str: - - best_ind_params = { - # Moving Averages - "close_50_sma": ( - "50 SMA: A medium-term trend indicator. 
" - "Usage: Identify trend direction and serve as dynamic support/resistance. " - "Tips: It lags price; combine with faster indicators for timely signals." - ), - "close_200_sma": ( - "200 SMA: A long-term trend benchmark. " - "Usage: Confirm overall market trend and identify golden/death cross setups. " - "Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries." - ), - "close_10_ema": ( - "10 EMA: A responsive short-term average. " - "Usage: Capture quick shifts in momentum and potential entry points. " - "Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals." - ), - # MACD Related - "macd": ( - "MACD: Computes momentum via differences of EMAs. " - "Usage: Look for crossovers and divergence as signals of trend changes. " - "Tips: Confirm with other indicators in low-volatility or sideways markets." - ), - "macds": ( - "MACD Signal: An EMA smoothing of the MACD line. " - "Usage: Use crossovers with the MACD line to trigger trades. " - "Tips: Should be part of a broader strategy to avoid false positives." - ), - "macdh": ( - "MACD Histogram: Shows the gap between the MACD line and its signal. " - "Usage: Visualize momentum strength and spot divergence early. " - "Tips: Can be volatile; complement with additional filters in fast-moving markets." - ), - # Momentum Indicators - "rsi": ( - "RSI: Measures momentum to flag overbought/oversold conditions. " - "Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. " - "Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis." - ), - # Volatility Indicators - "boll": ( - "Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. " - "Usage: Acts as a dynamic benchmark for price movement. " - "Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals." 
- ), - "boll_ub": ( - "Bollinger Upper Band: Typically 2 standard deviations above the middle line. " - "Usage: Signals potential overbought conditions and breakout zones. " - "Tips: Confirm signals with other tools; prices may ride the band in strong trends." - ), - "boll_lb": ( - "Bollinger Lower Band: Typically 2 standard deviations below the middle line. " - "Usage: Indicates potential oversold conditions. " - "Tips: Use additional analysis to avoid false reversal signals." - ), - "atr": ( - "ATR: Averages true range to measure volatility. " - "Usage: Set stop-loss levels and adjust position sizes based on current market volatility. " - "Tips: It's a reactive measure, so use it as part of a broader risk management strategy." - ), - # Volume-Based Indicators - "vwma": ( - "VWMA: A moving average weighted by volume. " - "Usage: Confirm trends by integrating price action with volume data. " - "Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses." - ), - "mfi": ( - "MFI: The Money Flow Index is a momentum indicator that uses both price and volume to measure buying and selling pressure. " - "Usage: Identify overbought (>80) or oversold (<20) conditions and confirm the strength of trends or reversals. " - "Tips: Use alongside RSI or MACD to confirm signals; divergence between price and MFI can indicate potential reversals." - ), - } - - if indicator not in best_ind_params: - raise ValueError( - f"Indicator {indicator} is not supported. 
Please choose from: {list(best_ind_params.keys())}" - ) - - end_date = curr_date - curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") - before = curr_date_dt - relativedelta(days=look_back_days) - - # Optimized: Get stock data once and calculate indicators for all dates - try: - indicator_data = _get_stock_stats_bulk(symbol, indicator, curr_date) - - # Generate the date range we need - current_dt = curr_date_dt - date_values = [] - - while current_dt >= before: - date_str = current_dt.strftime('%Y-%m-%d') - - # Look up the indicator value for this date - if date_str in indicator_data: - indicator_value = indicator_data[date_str] - else: - indicator_value = "N/A: Not a trading day (weekend or holiday)" - - date_values.append((date_str, indicator_value)) - current_dt = current_dt - relativedelta(days=1) - - # Build the result string - ind_string = "" - for date_str, value in date_values: - ind_string += f"{date_str}: {value}\n" - - except Exception as e: - print(f"Error getting bulk stockstats data: {e}") - # Fallback to original implementation if bulk method fails - ind_string = "" - curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") - while curr_date_dt >= before: - indicator_value = get_stockstats_indicator( - symbol, indicator, curr_date_dt.strftime("%Y-%m-%d") - ) - ind_string += f"{curr_date_dt.strftime('%Y-%m-%d')}: {indicator_value}\n" - curr_date_dt = curr_date_dt - relativedelta(days=1) - - result_str = ( - f"## {indicator} values from {before.strftime('%Y-%m-%d')} to {end_date}:\n\n" - + ind_string - + "\n\n" - + best_ind_params.get(indicator, "No description available.") - ) - - return result_str - - -def _get_stock_stats_bulk( - symbol: Annotated[str, "ticker symbol of the company"], - indicator: Annotated[str, "technical indicator to calculate"], - curr_date: Annotated[str, "current date for reference"] -) -> dict: - """ - Optimized bulk calculation of stock stats indicators. 
- Fetches data once and calculates indicator for all available dates. - Returns dict mapping date strings to indicator values. - """ - from .config import get_config - import pandas as pd - from stockstats import wrap - import os - - config = get_config() - online = config["data_vendors"]["technical_indicators"] != "local" - - if not online: - # Local data path - try: - data = pd.read_csv( - os.path.join( - config.get("data_cache_dir", "data"), - f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv", - ) - ) - df = wrap(data) - except FileNotFoundError: - raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!") - else: - # Online data fetching with caching - today_date = pd.Timestamp.today() - curr_date_dt = pd.to_datetime(curr_date) - - end_date = today_date - start_date = today_date - pd.DateOffset(years=15) - start_date_str = start_date.strftime("%Y-%m-%d") - end_date_str = end_date.strftime("%Y-%m-%d") - - os.makedirs(config["data_cache_dir"], exist_ok=True) - - data_file = os.path.join( - config["data_cache_dir"], - f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv", - ) - - if os.path.exists(data_file): - data = pd.read_csv(data_file) - data["Date"] = pd.to_datetime(data["Date"]) - else: - data = yf.download( - symbol, - start=start_date_str, - end=end_date_str, - multi_level_index=False, - progress=False, - auto_adjust=True, - ) - data = data.reset_index() - data.to_csv(data_file, index=False) - - df = wrap(data) - df["Date"] = df["Date"].dt.strftime("%Y-%m-%d") - - # Calculate the indicator for all rows at once - df[indicator] # This triggers stockstats to calculate the indicator - - # Create a dictionary mapping date strings to indicator values - result_dict = {} - for _, row in df.iterrows(): - date_str = row["Date"] - indicator_value = row[indicator] - - # Handle NaN/None values - if pd.isna(indicator_value): - result_dict[date_str] = "N/A" - else: - result_dict[date_str] = str(indicator_value) - - return result_dict - - -def 
get_stockstats_indicator( - symbol: Annotated[str, "ticker symbol of the company"], - indicator: Annotated[str, "technical indicator to get the analysis and report of"], - curr_date: Annotated[ - str, "The current trading date you are trading on, YYYY-mm-dd" - ], -) -> str: - - curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") - curr_date = curr_date_dt.strftime("%Y-%m-%d") - - try: - indicator_value = StockstatsUtils.get_stock_stats( - symbol, - indicator, - curr_date, - ) - except Exception as e: - print( - f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}" - ) - return "" - - return str(indicator_value) - - -def get_balance_sheet( - ticker: Annotated[str, "ticker symbol of the company"], - freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", - curr_date: Annotated[str, "current date (not used for yfinance)"] = None -): - """Get balance sheet data from yfinance.""" - try: - ticker_obj = yf.Ticker(ticker.upper()) - - if freq.lower() == "quarterly": - data = ticker_obj.quarterly_balance_sheet - else: - data = ticker_obj.balance_sheet - - if data.empty: - return f"No balance sheet data found for symbol '{ticker}'" - - # Convert to CSV string for consistency with other functions - csv_string = data.to_csv() - - # Add header information - header = f"# Balance Sheet data for {ticker.upper()} ({freq})\n" - header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - - return header + csv_string - - except Exception as e: - return f"Error retrieving balance sheet for {ticker}: {str(e)}" - - -def get_cashflow( - ticker: Annotated[str, "ticker symbol of the company"], - freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", - curr_date: Annotated[str, "current date (not used for yfinance)"] = None -): - """Get cash flow data from yfinance.""" - try: - ticker_obj = yf.Ticker(ticker.upper()) - - if freq.lower() == "quarterly": - data = 
ticker_obj.quarterly_cashflow - else: - data = ticker_obj.cashflow - - if data.empty: - return f"No cash flow data found for symbol '{ticker}'" - - # Convert to CSV string for consistency with other functions - csv_string = data.to_csv() - - # Add header information - header = f"# Cash Flow data for {ticker.upper()} ({freq})\n" - header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - - return header + csv_string - - except Exception as e: - return f"Error retrieving cash flow for {ticker}: {str(e)}" - - -def get_income_statement( - ticker: Annotated[str, "ticker symbol of the company"], - freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", - curr_date: Annotated[str, "current date (not used for yfinance)"] = None -): - """Get income statement data from yfinance.""" - try: - ticker_obj = yf.Ticker(ticker.upper()) - - if freq.lower() == "quarterly": - data = ticker_obj.quarterly_income_stmt - else: - data = ticker_obj.income_stmt - - if data.empty: - return f"No income statement data found for symbol '{ticker}'" - - # Convert to CSV string for consistency with other functions - csv_string = data.to_csv() - - # Add header information - header = f"# Income Statement data for {ticker.upper()} ({freq})\n" - header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - - return header + csv_string - - except Exception as e: - return f"Error retrieving income statement for {ticker}: {str(e)}" - - -def get_insider_transactions( - ticker: Annotated[str, "ticker symbol of the company"] -): - """Get insider transactions data from yfinance.""" - try: - ticker_obj = yf.Ticker(ticker.upper()) - data = ticker_obj.insider_transactions - - if data is None or data.empty: - return f"No insider transactions data found for symbol '{ticker}'" - - # Convert to CSV string for consistency with other functions - csv_string = data.to_csv() - - # Add header information - header = f"# Insider Transactions 
data for {ticker.upper()}\n" - header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - - return header + csv_string - - except Exception as e: +from typing import Annotated +from datetime import datetime +from dateutil.relativedelta import relativedelta +import yfinance as yf +import os +from .stockstats_utils import StockstatsUtils + +def get_YFin_data_online( + symbol: Annotated[str, "ticker symbol of the company"], + start_date: Annotated[str, "Start date in yyyy-mm-dd format"], + end_date: Annotated[str, "End date in yyyy-mm-dd format"], +): + + datetime.strptime(start_date, "%Y-%m-%d") + datetime.strptime(end_date, "%Y-%m-%d") + + # Create ticker object + ticker = yf.Ticker(symbol.upper()) + + # Fetch historical data for the specified date range + data = ticker.history(start=start_date, end=end_date) + + # Check if data is empty + if data.empty: + return ( + f"No data found for symbol '{symbol}' between {start_date} and {end_date}" + ) + + # Remove timezone info from index for cleaner output + if data.index.tz is not None: + data.index = data.index.tz_localize(None) + + # Round numerical values to 2 decimal places for cleaner display + numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"] + for col in numeric_columns: + if col in data.columns: + data[col] = data[col].round(2) + + # Convert DataFrame to CSV string + csv_string = data.to_csv() + + # Add header information + header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n" + header += f"# Total records: {len(data)}\n" + header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + + return header + csv_string + +def get_stock_stats_indicators_window( + symbol: Annotated[str, "ticker symbol of the company"], + indicator: Annotated[str, "technical indicator to get the analysis and report of"], + curr_date: Annotated[ + str, "The current trading date you are trading on, YYYY-mm-dd" + ], + look_back_days: Annotated[int, 
"how many days to look back"], +) -> str: + + best_ind_params = { + # Moving Averages + "close_50_sma": ( + "50 SMA: A medium-term trend indicator. " + "Usage: Identify trend direction and serve as dynamic support/resistance. " + "Tips: It lags price; combine with faster indicators for timely signals." + ), + "close_200_sma": ( + "200 SMA: A long-term trend benchmark. " + "Usage: Confirm overall market trend and identify golden/death cross setups. " + "Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries." + ), + "close_10_ema": ( + "10 EMA: A responsive short-term average. " + "Usage: Capture quick shifts in momentum and potential entry points. " + "Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals." + ), + # MACD Related + "macd": ( + "MACD: Computes momentum via differences of EMAs. " + "Usage: Look for crossovers and divergence as signals of trend changes. " + "Tips: Confirm with other indicators in low-volatility or sideways markets." + ), + "macds": ( + "MACD Signal: An EMA smoothing of the MACD line. " + "Usage: Use crossovers with the MACD line to trigger trades. " + "Tips: Should be part of a broader strategy to avoid false positives." + ), + "macdh": ( + "MACD Histogram: Shows the gap between the MACD line and its signal. " + "Usage: Visualize momentum strength and spot divergence early. " + "Tips: Can be volatile; complement with additional filters in fast-moving markets." + ), + # Momentum Indicators + "rsi": ( + "RSI: Measures momentum to flag overbought/oversold conditions. " + "Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. " + "Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis." + ), + # Volatility Indicators + "boll": ( + "Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. " + "Usage: Acts as a dynamic benchmark for price movement. 
" + "Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals." + ), + "boll_ub": ( + "Bollinger Upper Band: Typically 2 standard deviations above the middle line. " + "Usage: Signals potential overbought conditions and breakout zones. " + "Tips: Confirm signals with other tools; prices may ride the band in strong trends." + ), + "boll_lb": ( + "Bollinger Lower Band: Typically 2 standard deviations below the middle line. " + "Usage: Indicates potential oversold conditions. " + "Tips: Use additional analysis to avoid false reversal signals." + ), + "atr": ( + "ATR: Averages true range to measure volatility. " + "Usage: Set stop-loss levels and adjust position sizes based on current market volatility. " + "Tips: It's a reactive measure, so use it as part of a broader risk management strategy." + ), + # Volume-Based Indicators + "vwma": ( + "VWMA: A moving average weighted by volume. " + "Usage: Confirm trends by integrating price action with volume data. " + "Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses." + ), + "mfi": ( + "MFI: The Money Flow Index is a momentum indicator that uses both price and volume to measure buying and selling pressure. " + "Usage: Identify overbought (>80) or oversold (<20) conditions and confirm the strength of trends or reversals. " + "Tips: Use alongside RSI or MACD to confirm signals; divergence between price and MFI can indicate potential reversals." + ), + } + + if indicator not in best_ind_params: + raise ValueError( + f"Indicator {indicator} is not supported. 
Please choose from: {list(best_ind_params.keys())}" + ) + + end_date = curr_date + curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") + before = curr_date_dt - relativedelta(days=look_back_days) + + # Optimized: Get stock data once and calculate indicators for all dates + try: + indicator_data = _get_stock_stats_bulk(symbol, indicator, curr_date) + + # Generate the date range we need + current_dt = curr_date_dt + date_values = [] + + while current_dt >= before: + date_str = current_dt.strftime('%Y-%m-%d') + + # Look up the indicator value for this date + if date_str in indicator_data: + indicator_value = indicator_data[date_str] + else: + indicator_value = "N/A: Not a trading day (weekend or holiday)" + + date_values.append((date_str, indicator_value)) + current_dt = current_dt - relativedelta(days=1) + + # Build the result string + ind_string = "" + for date_str, value in date_values: + ind_string += f"{date_str}: {value}\n" + + except Exception as e: + print(f"Error getting bulk stockstats data: {e}") + # Fallback to original implementation if bulk method fails + ind_string = "" + curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") + while curr_date_dt >= before: + indicator_value = get_stockstats_indicator( + symbol, indicator, curr_date_dt.strftime("%Y-%m-%d") + ) + ind_string += f"{curr_date_dt.strftime('%Y-%m-%d')}: {indicator_value}\n" + curr_date_dt = curr_date_dt - relativedelta(days=1) + + result_str = ( + f"## {indicator} values from {before.strftime('%Y-%m-%d')} to {end_date}:\n\n" + + ind_string + + "\n\n" + + best_ind_params.get(indicator, "No description available.") + ) + + return result_str + + +def _get_stock_stats_bulk( + symbol: Annotated[str, "ticker symbol of the company"], + indicator: Annotated[str, "technical indicator to calculate"], + curr_date: Annotated[str, "current date for reference"] +) -> dict: + """ + Optimized bulk calculation of stock stats indicators. 
+ Fetches data once and calculates indicator for all available dates. + Returns dict mapping date strings to indicator values. + """ + from .config import get_config + import pandas as pd + from stockstats import wrap + import os + + config = get_config() + online = config["data_vendors"]["technical_indicators"] != "local" + + if not online: + # Local data path + try: + data = pd.read_csv( + os.path.join( + config.get("data_cache_dir", "data"), + f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv", + ) + ) + df = wrap(data) + except FileNotFoundError: + raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!") + else: + # Online data fetching with caching + today_date = pd.Timestamp.today() + curr_date_dt = pd.to_datetime(curr_date) + + end_date = today_date + start_date = today_date - pd.DateOffset(years=15) + start_date_str = start_date.strftime("%Y-%m-%d") + end_date_str = end_date.strftime("%Y-%m-%d") + + os.makedirs(config["data_cache_dir"], exist_ok=True) + + data_file = os.path.join( + config["data_cache_dir"], + f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv", + ) + + if os.path.exists(data_file): + data = pd.read_csv(data_file) + data["Date"] = pd.to_datetime(data["Date"]) + else: + data = yf.download( + symbol, + start=start_date_str, + end=end_date_str, + multi_level_index=False, + progress=False, + auto_adjust=True, + ) + data = data.reset_index() + data.to_csv(data_file, index=False) + + df = wrap(data) + df["Date"] = df["Date"].dt.strftime("%Y-%m-%d") + + # Calculate the indicator for all rows at once + df[indicator] # This triggers stockstats to calculate the indicator + + # Create a dictionary mapping date strings to indicator values + result_dict = {} + for _, row in df.iterrows(): + date_str = row["Date"] + indicator_value = row[indicator] + + # Handle NaN/None values + if pd.isna(indicator_value): + result_dict[date_str] = "N/A" + else: + result_dict[date_str] = str(indicator_value) + + return result_dict + + +def 
get_stockstats_indicator( + symbol: Annotated[str, "ticker symbol of the company"], + indicator: Annotated[str, "technical indicator to get the analysis and report of"], + curr_date: Annotated[ + str, "The current trading date you are trading on, YYYY-mm-dd" + ], +) -> str: + + curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") + curr_date = curr_date_dt.strftime("%Y-%m-%d") + + try: + indicator_value = StockstatsUtils.get_stock_stats( + symbol, + indicator, + curr_date, + ) + except Exception as e: + print( + f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}" + ) + return "" + + return str(indicator_value) + + +def get_balance_sheet( + ticker: Annotated[str, "ticker symbol of the company"], + freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", + curr_date: Annotated[str, "current date (not used for yfinance)"] = None +): + """Get balance sheet data from yfinance.""" + try: + ticker_obj = yf.Ticker(ticker.upper()) + + if freq.lower() == "quarterly": + data = ticker_obj.quarterly_balance_sheet + else: + data = ticker_obj.balance_sheet + + if data.empty: + return f"No balance sheet data found for symbol '{ticker}'" + + # Convert to CSV string for consistency with other functions + csv_string = data.to_csv() + + # Add header information + header = f"# Balance Sheet data for {ticker.upper()} ({freq})\n" + header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + + return header + csv_string + + except Exception as e: + return f"Error retrieving balance sheet for {ticker}: {str(e)}" + + +def get_cashflow( + ticker: Annotated[str, "ticker symbol of the company"], + freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", + curr_date: Annotated[str, "current date (not used for yfinance)"] = None +): + """Get cash flow data from yfinance.""" + try: + ticker_obj = yf.Ticker(ticker.upper()) + + if freq.lower() == "quarterly": + data = 
ticker_obj.quarterly_cashflow + else: + data = ticker_obj.cashflow + + if data.empty: + return f"No cash flow data found for symbol '{ticker}'" + + # Convert to CSV string for consistency with other functions + csv_string = data.to_csv() + + # Add header information + header = f"# Cash Flow data for {ticker.upper()} ({freq})\n" + header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + + return header + csv_string + + except Exception as e: + return f"Error retrieving cash flow for {ticker}: {str(e)}" + + +def get_income_statement( + ticker: Annotated[str, "ticker symbol of the company"], + freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly", + curr_date: Annotated[str, "current date (not used for yfinance)"] = None +): + """Get income statement data from yfinance.""" + try: + ticker_obj = yf.Ticker(ticker.upper()) + + if freq.lower() == "quarterly": + data = ticker_obj.quarterly_income_stmt + else: + data = ticker_obj.income_stmt + + if data.empty: + return f"No income statement data found for symbol '{ticker}'" + + # Convert to CSV string for consistency with other functions + csv_string = data.to_csv() + + # Add header information + header = f"# Income Statement data for {ticker.upper()} ({freq})\n" + header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + + return header + csv_string + + except Exception as e: + return f"Error retrieving income statement for {ticker}: {str(e)}" + + +def get_insider_transactions( + ticker: Annotated[str, "ticker symbol of the company"] +): + """Get insider transactions data from yfinance.""" + try: + ticker_obj = yf.Ticker(ticker.upper()) + data = ticker_obj.insider_transactions + + if data is None or data.empty: + return f"No insider transactions data found for symbol '{ticker}'" + + # Convert to CSV string for consistency with other functions + csv_string = data.to_csv() + + # Add header information + header = f"# Insider Transactions 
data for {ticker.upper()}\n" + header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + + return header + csv_string + + except Exception as e: return f"Error retrieving insider transactions for {ticker}: {str(e)}" \ No newline at end of file diff --git a/tradingagents/dataflows/yfin_utils.py b/tradingagents/dataflows/yfin_utils.py index bd7ca324..0dba2b30 100644 --- a/tradingagents/dataflows/yfin_utils.py +++ b/tradingagents/dataflows/yfin_utils.py @@ -1,117 +1,117 @@ -# gets data/stats - -import yfinance as yf -from typing import Annotated, Callable, Any, Optional -from pandas import DataFrame -import pandas as pd -from functools import wraps - -from .utils import save_output, SavePathType, decorate_all_methods - - -def init_ticker(func: Callable) -> Callable: - """Decorator to initialize yf.Ticker and pass it to the function.""" - - @wraps(func) - def wrapper(symbol: Annotated[str, "ticker symbol"], *args, **kwargs) -> Any: - ticker = yf.Ticker(symbol) - return func(ticker, *args, **kwargs) - - return wrapper - - -@decorate_all_methods(init_ticker) -class YFinanceUtils: - - def get_stock_data( - symbol: Annotated[str, "ticker symbol"], - start_date: Annotated[ - str, "start date for retrieving stock price data, YYYY-mm-dd" - ], - end_date: Annotated[ - str, "end date for retrieving stock price data, YYYY-mm-dd" - ], - save_path: SavePathType = None, - ) -> DataFrame: - """retrieve stock price data for designated ticker symbol""" - ticker = symbol - # add one day to the end_date so that the data range is inclusive - end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1) - end_date = end_date.strftime("%Y-%m-%d") - stock_data = ticker.history(start=start_date, end=end_date) - # save_output(stock_data, f"Stock data for {ticker.ticker}", save_path) - return stock_data - - def get_stock_info( - symbol: Annotated[str, "ticker symbol"], - ) -> dict: - """Fetches and returns latest stock information.""" - ticker = symbol - stock_info = 
ticker.info - return stock_info - - def get_company_info( - symbol: Annotated[str, "ticker symbol"], - save_path: Optional[str] = None, - ) -> DataFrame: - """Fetches and returns company information as a DataFrame.""" - ticker = symbol - info = ticker.info - company_info = { - "Company Name": info.get("shortName", "N/A"), - "Industry": info.get("industry", "N/A"), - "Sector": info.get("sector", "N/A"), - "Country": info.get("country", "N/A"), - "Website": info.get("website", "N/A"), - } - company_info_df = DataFrame([company_info]) - if save_path: - company_info_df.to_csv(save_path) - print(f"Company info for {ticker.ticker} saved to {save_path}") - return company_info_df - - def get_stock_dividends( - symbol: Annotated[str, "ticker symbol"], - save_path: Optional[str] = None, - ) -> DataFrame: - """Fetches and returns the latest dividends data as a DataFrame.""" - ticker = symbol - dividends = ticker.dividends - if save_path: - dividends.to_csv(save_path) - print(f"Dividends for {ticker.ticker} saved to {save_path}") - return dividends - - def get_income_stmt(symbol: Annotated[str, "ticker symbol"]) -> DataFrame: - """Fetches and returns the latest income statement of the company as a DataFrame.""" - ticker = symbol - income_stmt = ticker.financials - return income_stmt - - def get_balance_sheet(symbol: Annotated[str, "ticker symbol"]) -> DataFrame: - """Fetches and returns the latest balance sheet of the company as a DataFrame.""" - ticker = symbol - balance_sheet = ticker.balance_sheet - return balance_sheet - - def get_cash_flow(symbol: Annotated[str, "ticker symbol"]) -> DataFrame: - """Fetches and returns the latest cash flow statement of the company as a DataFrame.""" - ticker = symbol - cash_flow = ticker.cashflow - return cash_flow - - def get_analyst_recommendations(symbol: Annotated[str, "ticker symbol"]) -> tuple: - """Fetches the latest analyst recommendations and returns the most common recommendation and its count.""" - ticker = symbol - 
recommendations = ticker.recommendations - if recommendations.empty: - return None, 0 # No recommendations available - - # Assuming 'period' column exists and needs to be excluded - row_0 = recommendations.iloc[0, 1:] # Exclude 'period' column if necessary - - # Find the maximum voting result - max_votes = row_0.max() - majority_voting_result = row_0[row_0 == max_votes].index.tolist() - - return majority_voting_result[0], max_votes +# gets data/stats + +import yfinance as yf +from typing import Annotated, Callable, Any, Optional +from pandas import DataFrame +import pandas as pd +from functools import wraps + +from .utils import save_output, SavePathType, decorate_all_methods + + +def init_ticker(func: Callable) -> Callable: + """Decorator to initialize yf.Ticker and pass it to the function.""" + + @wraps(func) + def wrapper(symbol: Annotated[str, "ticker symbol"], *args, **kwargs) -> Any: + ticker = yf.Ticker(symbol) + return func(ticker, *args, **kwargs) + + return wrapper + + +@decorate_all_methods(init_ticker) +class YFinanceUtils: + + def get_stock_data( + symbol: Annotated[str, "ticker symbol"], + start_date: Annotated[ + str, "start date for retrieving stock price data, YYYY-mm-dd" + ], + end_date: Annotated[ + str, "end date for retrieving stock price data, YYYY-mm-dd" + ], + save_path: SavePathType = None, + ) -> DataFrame: + """retrieve stock price data for designated ticker symbol""" + ticker = symbol + # add one day to the end_date so that the data range is inclusive + end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1) + end_date = end_date.strftime("%Y-%m-%d") + stock_data = ticker.history(start=start_date, end=end_date) + # save_output(stock_data, f"Stock data for {ticker.ticker}", save_path) + return stock_data + + def get_stock_info( + symbol: Annotated[str, "ticker symbol"], + ) -> dict: + """Fetches and returns latest stock information.""" + ticker = symbol + stock_info = ticker.info + return stock_info + + def get_company_info( + symbol: 
Annotated[str, "ticker symbol"], + save_path: Optional[str] = None, + ) -> DataFrame: + """Fetches and returns company information as a DataFrame.""" + ticker = symbol + info = ticker.info + company_info = { + "Company Name": info.get("shortName", "N/A"), + "Industry": info.get("industry", "N/A"), + "Sector": info.get("sector", "N/A"), + "Country": info.get("country", "N/A"), + "Website": info.get("website", "N/A"), + } + company_info_df = DataFrame([company_info]) + if save_path: + company_info_df.to_csv(save_path) + print(f"Company info for {ticker.ticker} saved to {save_path}") + return company_info_df + + def get_stock_dividends( + symbol: Annotated[str, "ticker symbol"], + save_path: Optional[str] = None, + ) -> DataFrame: + """Fetches and returns the latest dividends data as a DataFrame.""" + ticker = symbol + dividends = ticker.dividends + if save_path: + dividends.to_csv(save_path) + print(f"Dividends for {ticker.ticker} saved to {save_path}") + return dividends + + def get_income_stmt(symbol: Annotated[str, "ticker symbol"]) -> DataFrame: + """Fetches and returns the latest income statement of the company as a DataFrame.""" + ticker = symbol + income_stmt = ticker.financials + return income_stmt + + def get_balance_sheet(symbol: Annotated[str, "ticker symbol"]) -> DataFrame: + """Fetches and returns the latest balance sheet of the company as a DataFrame.""" + ticker = symbol + balance_sheet = ticker.balance_sheet + return balance_sheet + + def get_cash_flow(symbol: Annotated[str, "ticker symbol"]) -> DataFrame: + """Fetches and returns the latest cash flow statement of the company as a DataFrame.""" + ticker = symbol + cash_flow = ticker.cashflow + return cash_flow + + def get_analyst_recommendations(symbol: Annotated[str, "ticker symbol"]) -> tuple: + """Fetches the latest analyst recommendations and returns the most common recommendation and its count.""" + ticker = symbol + recommendations = ticker.recommendations + if recommendations.empty: + return 
None, 0 # No recommendations available + + # Assuming 'period' column exists and needs to be excluded + row_0 = recommendations.iloc[0, 1:] # Exclude 'period' column if necessary + + # Find the maximum voting result + max_votes = row_0.max() + majority_voting_result = row_0[row_0 == max_votes].index.tolist() + + return majority_voting_result[0], max_votes diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 1f40a2a2..33fad5e0 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -1,33 +1,33 @@ -import os - -DEFAULT_CONFIG = { - "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), - "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), - "data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data", - "data_cache_dir": os.path.join( - os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), - "dataflows/data_cache", - ), - # LLM settings - "llm_provider": "openai", - "deep_think_llm": "o4-mini", - "quick_think_llm": "gpt-4o-mini", - "backend_url": "https://api.openai.com/v1", - # Debate and discussion settings - "max_debate_rounds": 1, - "max_risk_discuss_rounds": 1, - "max_recur_limit": 100, - # Data vendor configuration - # Category-level configuration (default for all tools in category) - "data_vendors": { - "core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local - "technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local - "fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local - "news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local - }, - # Tool-level configuration (takes precedence over category-level) - "tool_vendors": { - # Example: "get_stock_data": "alpha_vantage", # Override category default - # Example: "get_news": "openai", # Override category default - }, -} +import os + +DEFAULT_CONFIG = { + "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), + 
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), + "data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data", + "data_cache_dir": os.path.join( + os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), + "dataflows/data_cache", + ), + # LLM settings + "llm_provider": "openai", + "deep_think_llm": "o4-mini", + "quick_think_llm": "gpt-4o-mini", + "backend_url": "https://api.openai.com/v1", + # Debate and discussion settings + "max_debate_rounds": 1, + "max_risk_discuss_rounds": 1, + "max_recur_limit": 100, + # Data vendor configuration + # Category-level configuration (default for all tools in category) + "data_vendors": { + "core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local + "technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local + "fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local + "news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local + }, + # Tool-level configuration (takes precedence over category-level) + "tool_vendors": { + # Example: "get_stock_data": "alpha_vantage", # Override category default + # Example: "get_news": "openai", # Override category default + }, +} diff --git a/tradingagents/graph/__init__.py b/tradingagents/graph/__init__.py index 80982c19..e713b81d 100644 --- a/tradingagents/graph/__init__.py +++ b/tradingagents/graph/__init__.py @@ -1,17 +1,17 @@ -# TradingAgents/graph/__init__.py - -from .trading_graph import TradingAgentsGraph -from .conditional_logic import ConditionalLogic -from .setup import GraphSetup -from .propagation import Propagator -from .reflection import Reflector -from .signal_processing import SignalProcessor - -__all__ = [ - "TradingAgentsGraph", - "ConditionalLogic", - "GraphSetup", - "Propagator", - "Reflector", - "SignalProcessor", -] +# TradingAgents/graph/__init__.py + +from .trading_graph import TradingAgentsGraph +from .conditional_logic import ConditionalLogic +from .setup import GraphSetup +from 
.propagation import Propagator +from .reflection import Reflector +from .signal_processing import SignalProcessor + +__all__ = [ + "TradingAgentsGraph", + "ConditionalLogic", + "GraphSetup", + "Propagator", + "Reflector", + "SignalProcessor", +] diff --git a/tradingagents/graph/conditional_logic.py b/tradingagents/graph/conditional_logic.py index e7c87859..dc28c9b2 100644 --- a/tradingagents/graph/conditional_logic.py +++ b/tradingagents/graph/conditional_logic.py @@ -1,67 +1,67 @@ -# TradingAgents/graph/conditional_logic.py - -from tradingagents.agents.utils.agent_states import AgentState - - -class ConditionalLogic: - """Handles conditional logic for determining graph flow.""" - - def __init__(self, max_debate_rounds=1, max_risk_discuss_rounds=1): - """Initialize with configuration parameters.""" - self.max_debate_rounds = max_debate_rounds - self.max_risk_discuss_rounds = max_risk_discuss_rounds - - def should_continue_market(self, state: AgentState): - """Determine if market analysis should continue.""" - messages = state["messages"] - last_message = messages[-1] - if last_message.tool_calls: - return "tools_market" - return "Msg Clear Market" - - def should_continue_social(self, state: AgentState): - """Determine if social media analysis should continue.""" - messages = state["messages"] - last_message = messages[-1] - if last_message.tool_calls: - return "tools_social" - return "Msg Clear Social" - - def should_continue_news(self, state: AgentState): - """Determine if news analysis should continue.""" - messages = state["messages"] - last_message = messages[-1] - if last_message.tool_calls: - return "tools_news" - return "Msg Clear News" - - def should_continue_fundamentals(self, state: AgentState): - """Determine if fundamentals analysis should continue.""" - messages = state["messages"] - last_message = messages[-1] - if last_message.tool_calls: - return "tools_fundamentals" - return "Msg Clear Fundamentals" - - def should_continue_debate(self, state: 
AgentState) -> str: - """Determine if debate should continue.""" - - if ( - state["investment_debate_state"]["count"] >= 2 * self.max_debate_rounds - ): # 3 rounds of back-and-forth between 2 agents - return "Research Manager" - if state["investment_debate_state"]["current_response"].startswith("Bull"): - return "Bear Researcher" - return "Bull Researcher" - - def should_continue_risk_analysis(self, state: AgentState) -> str: - """Determine if risk analysis should continue.""" - if ( - state["risk_debate_state"]["count"] >= 3 * self.max_risk_discuss_rounds - ): # 3 rounds of back-and-forth between 3 agents - return "Risk Judge" - if state["risk_debate_state"]["latest_speaker"].startswith("Risky"): - return "Safe Analyst" - if state["risk_debate_state"]["latest_speaker"].startswith("Safe"): - return "Neutral Analyst" - return "Risky Analyst" +# TradingAgents/graph/conditional_logic.py + +from tradingagents.agents.utils.agent_states import AgentState + + +class ConditionalLogic: + """Handles conditional logic for determining graph flow.""" + + def __init__(self, max_debate_rounds=1, max_risk_discuss_rounds=1): + """Initialize with configuration parameters.""" + self.max_debate_rounds = max_debate_rounds + self.max_risk_discuss_rounds = max_risk_discuss_rounds + + def should_continue_market(self, state: AgentState): + """Determine if market analysis should continue.""" + messages = state["messages"] + last_message = messages[-1] + if last_message.tool_calls: + return "tools_market" + return "Msg Clear Market" + + def should_continue_social(self, state: AgentState): + """Determine if social media analysis should continue.""" + messages = state["messages"] + last_message = messages[-1] + if last_message.tool_calls: + return "tools_social" + return "Msg Clear Social" + + def should_continue_news(self, state: AgentState): + """Determine if news analysis should continue.""" + messages = state["messages"] + last_message = messages[-1] + if last_message.tool_calls: + return 
"tools_news" + return "Msg Clear News" + + def should_continue_fundamentals(self, state: AgentState): + """Determine if fundamentals analysis should continue.""" + messages = state["messages"] + last_message = messages[-1] + if last_message.tool_calls: + return "tools_fundamentals" + return "Msg Clear Fundamentals" + + def should_continue_debate(self, state: AgentState) -> str: + """Determine if debate should continue.""" + + if ( + state["investment_debate_state"]["count"] >= 2 * self.max_debate_rounds + ): # 3 rounds of back-and-forth between 2 agents + return "Research Manager" + if state["investment_debate_state"]["current_response"].startswith("Bull"): + return "Bear Researcher" + return "Bull Researcher" + + def should_continue_risk_analysis(self, state: AgentState) -> str: + """Determine if risk analysis should continue.""" + if ( + state["risk_debate_state"]["count"] >= 3 * self.max_risk_discuss_rounds + ): # 3 rounds of back-and-forth between 3 agents + return "Risk Judge" + if state["risk_debate_state"]["latest_speaker"].startswith("Risky"): + return "Safe Analyst" + if state["risk_debate_state"]["latest_speaker"].startswith("Safe"): + return "Neutral Analyst" + return "Risky Analyst" diff --git a/tradingagents/graph/propagation.py b/tradingagents/graph/propagation.py index 58ebd0a8..0eec27d1 100644 --- a/tradingagents/graph/propagation.py +++ b/tradingagents/graph/propagation.py @@ -1,49 +1,49 @@ -# TradingAgents/graph/propagation.py - -from typing import Dict, Any -from tradingagents.agents.utils.agent_states import ( - AgentState, - InvestDebateState, - RiskDebateState, -) - - -class Propagator: - """Handles state initialization and propagation through the graph.""" - - def __init__(self, max_recur_limit=100): - """Initialize with configuration parameters.""" - self.max_recur_limit = max_recur_limit - - def create_initial_state( - self, company_name: str, trade_date: str - ) -> Dict[str, Any]: - """Create the initial state for the agent graph.""" - 
return { - "messages": [("human", company_name)], - "company_of_interest": company_name, - "trade_date": str(trade_date), - "investment_debate_state": InvestDebateState( - {"history": "", "current_response": "", "count": 0} - ), - "risk_debate_state": RiskDebateState( - { - "history": "", - "current_risky_response": "", - "current_safe_response": "", - "current_neutral_response": "", - "count": 0, - } - ), - "market_report": "", - "fundamentals_report": "", - "sentiment_report": "", - "news_report": "", - } - - def get_graph_args(self) -> Dict[str, Any]: - """Get arguments for the graph invocation.""" - return { - "stream_mode": "values", - "config": {"recursion_limit": self.max_recur_limit}, - } +# TradingAgents/graph/propagation.py + +from typing import Dict, Any +from tradingagents.agents.utils.agent_states import ( + AgentState, + InvestDebateState, + RiskDebateState, +) + + +class Propagator: + """Handles state initialization and propagation through the graph.""" + + def __init__(self, max_recur_limit=100): + """Initialize with configuration parameters.""" + self.max_recur_limit = max_recur_limit + + def create_initial_state( + self, company_name: str, trade_date: str + ) -> Dict[str, Any]: + """Create the initial state for the agent graph.""" + return { + "messages": [("human", company_name)], + "company_of_interest": company_name, + "trade_date": str(trade_date), + "investment_debate_state": InvestDebateState( + {"history": "", "current_response": "", "count": 0} + ), + "risk_debate_state": RiskDebateState( + { + "history": "", + "current_risky_response": "", + "current_safe_response": "", + "current_neutral_response": "", + "count": 0, + } + ), + "market_report": "", + "fundamentals_report": "", + "sentiment_report": "", + "news_report": "", + } + + def get_graph_args(self) -> Dict[str, Any]: + """Get arguments for the graph invocation.""" + return { + "stream_mode": "values", + "config": {"recursion_limit": self.max_recur_limit}, + } diff --git 
a/tradingagents/graph/reflection.py b/tradingagents/graph/reflection.py index 33303231..57a876dc 100644 --- a/tradingagents/graph/reflection.py +++ b/tradingagents/graph/reflection.py @@ -1,121 +1,121 @@ -# TradingAgents/graph/reflection.py - -from typing import Dict, Any -from langchain_openai import ChatOpenAI - - -class Reflector: - """Handles reflection on decisions and updating memory.""" - - def __init__(self, quick_thinking_llm: ChatOpenAI): - """Initialize the reflector with an LLM.""" - self.quick_thinking_llm = quick_thinking_llm - self.reflection_system_prompt = self._get_reflection_prompt() - - def _get_reflection_prompt(self) -> str: - """Get the system prompt for reflection.""" - return """ -You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis. -Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines: - -1. Reasoning: - - For each trading decision, determine whether it was correct or incorrect. A correct decision results in an increase in returns, while an incorrect decision does the opposite. - - Analyze the contributing factors to each success or mistake. Consider: - - Market intelligence. - - Technical indicators. - - Technical signals. - - Price movement analysis. - - Overall market data analysis - - News analysis. - - Social media and sentiment analysis. - - Fundamental data analysis. - - Weight the importance of each factor in the decision-making process. - -2. Improvement: - - For any incorrect decisions, propose revisions to maximize returns. - - Provide a detailed list of corrective actions or improvements, including specific recommendations (e.g., changing a decision from HOLD to BUY on a particular date). - -3. Summary: - - Summarize the lessons learned from the successes and mistakes. 
- - Highlight how these lessons can be adapted for future trading scenarios and draw connections between similar situations to apply the knowledge gained. - -4. Query: - - Extract key insights from the summary into a concise sentence of no more than 1000 tokens. - - Ensure the condensed sentence captures the essence of the lessons and reasoning for easy reference. - -Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis. -""" - - def _extract_current_situation(self, current_state: Dict[str, Any]) -> str: - """Extract the current market situation from the state.""" - curr_market_report = current_state["market_report"] - curr_sentiment_report = current_state["sentiment_report"] - curr_news_report = current_state["news_report"] - curr_fundamentals_report = current_state["fundamentals_report"] - - return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}" - - def _reflect_on_component( - self, component_type: str, report: str, situation: str, returns_losses - ) -> str: - """Generate reflection for a component.""" - messages = [ - ("system", self.reflection_system_prompt), - ( - "human", - f"Returns: {returns_losses}\n\nAnalysis/Decision: {report}\n\nObjective Market Reports for Reference: {situation}", - ), - ] - - result = self.quick_thinking_llm.invoke(messages).content - return result - - def reflect_bull_researcher(self, current_state, returns_losses, bull_memory): - """Reflect on bull researcher's analysis and update memory.""" - situation = self._extract_current_situation(current_state) - bull_debate_history = current_state["investment_debate_state"]["bull_history"] - - result = self._reflect_on_component( - "BULL", bull_debate_history, situation, returns_losses - ) - 
bull_memory.add_situations([(situation, result)]) - - def reflect_bear_researcher(self, current_state, returns_losses, bear_memory): - """Reflect on bear researcher's analysis and update memory.""" - situation = self._extract_current_situation(current_state) - bear_debate_history = current_state["investment_debate_state"]["bear_history"] - - result = self._reflect_on_component( - "BEAR", bear_debate_history, situation, returns_losses - ) - bear_memory.add_situations([(situation, result)]) - - def reflect_trader(self, current_state, returns_losses, trader_memory): - """Reflect on trader's decision and update memory.""" - situation = self._extract_current_situation(current_state) - trader_decision = current_state["trader_investment_plan"] - - result = self._reflect_on_component( - "TRADER", trader_decision, situation, returns_losses - ) - trader_memory.add_situations([(situation, result)]) - - def reflect_invest_judge(self, current_state, returns_losses, invest_judge_memory): - """Reflect on investment judge's decision and update memory.""" - situation = self._extract_current_situation(current_state) - judge_decision = current_state["investment_debate_state"]["judge_decision"] - - result = self._reflect_on_component( - "INVEST JUDGE", judge_decision, situation, returns_losses - ) - invest_judge_memory.add_situations([(situation, result)]) - - def reflect_risk_manager(self, current_state, returns_losses, risk_manager_memory): - """Reflect on risk manager's decision and update memory.""" - situation = self._extract_current_situation(current_state) - judge_decision = current_state["risk_debate_state"]["judge_decision"] - - result = self._reflect_on_component( - "RISK JUDGE", judge_decision, situation, returns_losses - ) - risk_manager_memory.add_situations([(situation, result)]) +# TradingAgents/graph/reflection.py + +from typing import Dict, Any +from langchain_openai import ChatOpenAI + + +class Reflector: + """Handles reflection on decisions and updating memory.""" + 
+ def __init__(self, quick_thinking_llm: ChatOpenAI): + """Initialize the reflector with an LLM.""" + self.quick_thinking_llm = quick_thinking_llm + self.reflection_system_prompt = self._get_reflection_prompt() + + def _get_reflection_prompt(self) -> str: + """Get the system prompt for reflection.""" + return """ +You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis. +Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines: + +1. Reasoning: + - For each trading decision, determine whether it was correct or incorrect. A correct decision results in an increase in returns, while an incorrect decision does the opposite. + - Analyze the contributing factors to each success or mistake. Consider: + - Market intelligence. + - Technical indicators. + - Technical signals. + - Price movement analysis. + - Overall market data analysis + - News analysis. + - Social media and sentiment analysis. + - Fundamental data analysis. + - Weight the importance of each factor in the decision-making process. + +2. Improvement: + - For any incorrect decisions, propose revisions to maximize returns. + - Provide a detailed list of corrective actions or improvements, including specific recommendations (e.g., changing a decision from HOLD to BUY on a particular date). + +3. Summary: + - Summarize the lessons learned from the successes and mistakes. + - Highlight how these lessons can be adapted for future trading scenarios and draw connections between similar situations to apply the knowledge gained. + +4. Query: + - Extract key insights from the summary into a concise sentence of no more than 1000 tokens. + - Ensure the condensed sentence captures the essence of the lessons and reasoning for easy reference. + +Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. 
You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis. +""" + + def _extract_current_situation(self, current_state: Dict[str, Any]) -> str: + """Extract the current market situation from the state.""" + curr_market_report = current_state["market_report"] + curr_sentiment_report = current_state["sentiment_report"] + curr_news_report = current_state["news_report"] + curr_fundamentals_report = current_state["fundamentals_report"] + + return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}" + + def _reflect_on_component( + self, component_type: str, report: str, situation: str, returns_losses + ) -> str: + """Generate reflection for a component.""" + messages = [ + ("system", self.reflection_system_prompt), + ( + "human", + f"Returns: {returns_losses}\n\nAnalysis/Decision: {report}\n\nObjective Market Reports for Reference: {situation}", + ), + ] + + result = self.quick_thinking_llm.invoke(messages).content + return result + + def reflect_bull_researcher(self, current_state, returns_losses, bull_memory): + """Reflect on bull researcher's analysis and update memory.""" + situation = self._extract_current_situation(current_state) + bull_debate_history = current_state["investment_debate_state"]["bull_history"] + + result = self._reflect_on_component( + "BULL", bull_debate_history, situation, returns_losses + ) + bull_memory.add_situations([(situation, result)]) + + def reflect_bear_researcher(self, current_state, returns_losses, bear_memory): + """Reflect on bear researcher's analysis and update memory.""" + situation = self._extract_current_situation(current_state) + bear_debate_history = current_state["investment_debate_state"]["bear_history"] + + result = self._reflect_on_component( + "BEAR", bear_debate_history, situation, returns_losses + ) + bear_memory.add_situations([(situation, result)]) 
+ + def reflect_trader(self, current_state, returns_losses, trader_memory): + """Reflect on trader's decision and update memory.""" + situation = self._extract_current_situation(current_state) + trader_decision = current_state["trader_investment_plan"] + + result = self._reflect_on_component( + "TRADER", trader_decision, situation, returns_losses + ) + trader_memory.add_situations([(situation, result)]) + + def reflect_invest_judge(self, current_state, returns_losses, invest_judge_memory): + """Reflect on investment judge's decision and update memory.""" + situation = self._extract_current_situation(current_state) + judge_decision = current_state["investment_debate_state"]["judge_decision"] + + result = self._reflect_on_component( + "INVEST JUDGE", judge_decision, situation, returns_losses + ) + invest_judge_memory.add_situations([(situation, result)]) + + def reflect_risk_manager(self, current_state, returns_losses, risk_manager_memory): + """Reflect on risk manager's decision and update memory.""" + situation = self._extract_current_situation(current_state) + judge_decision = current_state["risk_debate_state"]["judge_decision"] + + result = self._reflect_on_component( + "RISK JUDGE", judge_decision, situation, returns_losses + ) + risk_manager_memory.add_situations([(situation, result)]) diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index b270ffc0..e99046c6 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -1,202 +1,202 @@ -# TradingAgents/graph/setup.py - -from typing import Dict, Any -from langchain_openai import ChatOpenAI -from langgraph.graph import END, StateGraph, START -from langgraph.prebuilt import ToolNode - -from tradingagents.agents import * -from tradingagents.agents.utils.agent_states import AgentState - -from .conditional_logic import ConditionalLogic - - -class GraphSetup: - """Handles the setup and configuration of the agent graph.""" - - def __init__( - self, - quick_thinking_llm: 
ChatOpenAI, - deep_thinking_llm: ChatOpenAI, - tool_nodes: Dict[str, ToolNode], - bull_memory, - bear_memory, - trader_memory, - invest_judge_memory, - risk_manager_memory, - conditional_logic: ConditionalLogic, - ): - """Initialize with required components.""" - self.quick_thinking_llm = quick_thinking_llm - self.deep_thinking_llm = deep_thinking_llm - self.tool_nodes = tool_nodes - self.bull_memory = bull_memory - self.bear_memory = bear_memory - self.trader_memory = trader_memory - self.invest_judge_memory = invest_judge_memory - self.risk_manager_memory = risk_manager_memory - self.conditional_logic = conditional_logic - - def setup_graph( - self, selected_analysts=["market", "social", "news", "fundamentals"] - ): - """Set up and compile the agent workflow graph. - - Args: - selected_analysts (list): List of analyst types to include. Options are: - - "market": Market analyst - - "social": Social media analyst - - "news": News analyst - - "fundamentals": Fundamentals analyst - """ - if len(selected_analysts) == 0: - raise ValueError("Trading Agents Graph Setup Error: no analysts selected!") - - # Create analyst nodes - analyst_nodes = {} - delete_nodes = {} - tool_nodes = {} - - if "market" in selected_analysts: - analyst_nodes["market"] = create_market_analyst( - self.quick_thinking_llm - ) - delete_nodes["market"] = create_msg_delete() - tool_nodes["market"] = self.tool_nodes["market"] - - if "social" in selected_analysts: - analyst_nodes["social"] = create_social_media_analyst( - self.quick_thinking_llm - ) - delete_nodes["social"] = create_msg_delete() - tool_nodes["social"] = self.tool_nodes["social"] - - if "news" in selected_analysts: - analyst_nodes["news"] = create_news_analyst( - self.quick_thinking_llm - ) - delete_nodes["news"] = create_msg_delete() - tool_nodes["news"] = self.tool_nodes["news"] - - if "fundamentals" in selected_analysts: - analyst_nodes["fundamentals"] = create_fundamentals_analyst( - self.quick_thinking_llm - ) - 
delete_nodes["fundamentals"] = create_msg_delete() - tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"] - - # Create researcher and manager nodes - bull_researcher_node = create_bull_researcher( - self.quick_thinking_llm, self.bull_memory - ) - bear_researcher_node = create_bear_researcher( - self.quick_thinking_llm, self.bear_memory - ) - research_manager_node = create_research_manager( - self.deep_thinking_llm, self.invest_judge_memory - ) - trader_node = create_trader(self.quick_thinking_llm, self.trader_memory) - - # Create risk analysis nodes - risky_analyst = create_risky_debator(self.quick_thinking_llm) - neutral_analyst = create_neutral_debator(self.quick_thinking_llm) - safe_analyst = create_safe_debator(self.quick_thinking_llm) - risk_manager_node = create_risk_manager( - self.deep_thinking_llm, self.risk_manager_memory - ) - - # Create workflow - workflow = StateGraph(AgentState) - - # Add analyst nodes to the graph - for analyst_type, node in analyst_nodes.items(): - workflow.add_node(f"{analyst_type.capitalize()} Analyst", node) - workflow.add_node( - f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type] - ) - workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type]) - - # Add other nodes - workflow.add_node("Bull Researcher", bull_researcher_node) - workflow.add_node("Bear Researcher", bear_researcher_node) - workflow.add_node("Research Manager", research_manager_node) - workflow.add_node("Trader", trader_node) - workflow.add_node("Risky Analyst", risky_analyst) - workflow.add_node("Neutral Analyst", neutral_analyst) - workflow.add_node("Safe Analyst", safe_analyst) - workflow.add_node("Risk Judge", risk_manager_node) - - # Define edges - # Start with the first analyst - first_analyst = selected_analysts[0] - workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst") - - # Connect analysts in sequence - for i, analyst_type in enumerate(selected_analysts): - current_analyst = f"{analyst_type.capitalize()} 
Analyst" - current_tools = f"tools_{analyst_type}" - current_clear = f"Msg Clear {analyst_type.capitalize()}" - - # Add conditional edges for current analyst - workflow.add_conditional_edges( - current_analyst, - getattr(self.conditional_logic, f"should_continue_{analyst_type}"), - [current_tools, current_clear], - ) - workflow.add_edge(current_tools, current_analyst) - - # Connect to next analyst or to Bull Researcher if this is the last analyst - if i < len(selected_analysts) - 1: - next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst" - workflow.add_edge(current_clear, next_analyst) - else: - workflow.add_edge(current_clear, "Bull Researcher") - - # Add remaining edges - workflow.add_conditional_edges( - "Bull Researcher", - self.conditional_logic.should_continue_debate, - { - "Bear Researcher": "Bear Researcher", - "Research Manager": "Research Manager", - }, - ) - workflow.add_conditional_edges( - "Bear Researcher", - self.conditional_logic.should_continue_debate, - { - "Bull Researcher": "Bull Researcher", - "Research Manager": "Research Manager", - }, - ) - workflow.add_edge("Research Manager", "Trader") - workflow.add_edge("Trader", "Risky Analyst") - workflow.add_conditional_edges( - "Risky Analyst", - self.conditional_logic.should_continue_risk_analysis, - { - "Safe Analyst": "Safe Analyst", - "Risk Judge": "Risk Judge", - }, - ) - workflow.add_conditional_edges( - "Safe Analyst", - self.conditional_logic.should_continue_risk_analysis, - { - "Neutral Analyst": "Neutral Analyst", - "Risk Judge": "Risk Judge", - }, - ) - workflow.add_conditional_edges( - "Neutral Analyst", - self.conditional_logic.should_continue_risk_analysis, - { - "Risky Analyst": "Risky Analyst", - "Risk Judge": "Risk Judge", - }, - ) - - workflow.add_edge("Risk Judge", END) - - # Compile and return - return workflow.compile() +# TradingAgents/graph/setup.py + +from typing import Dict, Any +from langchain_openai import ChatOpenAI +from langgraph.graph import END, StateGraph, 
START +from langgraph.prebuilt import ToolNode + +from tradingagents.agents import * +from tradingagents.agents.utils.agent_states import AgentState + +from .conditional_logic import ConditionalLogic + + +class GraphSetup: + """Handles the setup and configuration of the agent graph.""" + + def __init__( + self, + quick_thinking_llm: ChatOpenAI, + deep_thinking_llm: ChatOpenAI, + tool_nodes: Dict[str, ToolNode], + bull_memory, + bear_memory, + trader_memory, + invest_judge_memory, + risk_manager_memory, + conditional_logic: ConditionalLogic, + ): + """Initialize with required components.""" + self.quick_thinking_llm = quick_thinking_llm + self.deep_thinking_llm = deep_thinking_llm + self.tool_nodes = tool_nodes + self.bull_memory = bull_memory + self.bear_memory = bear_memory + self.trader_memory = trader_memory + self.invest_judge_memory = invest_judge_memory + self.risk_manager_memory = risk_manager_memory + self.conditional_logic = conditional_logic + + def setup_graph( + self, selected_analysts=["market", "social", "news", "fundamentals"] + ): + """Set up and compile the agent workflow graph. + + Args: + selected_analysts (list): List of analyst types to include. 
Options are: + - "market": Market analyst + - "social": Social media analyst + - "news": News analyst + - "fundamentals": Fundamentals analyst + """ + if len(selected_analysts) == 0: + raise ValueError("Trading Agents Graph Setup Error: no analysts selected!") + + # Create analyst nodes + analyst_nodes = {} + delete_nodes = {} + tool_nodes = {} + + if "market" in selected_analysts: + analyst_nodes["market"] = create_market_analyst( + self.quick_thinking_llm + ) + delete_nodes["market"] = create_msg_delete() + tool_nodes["market"] = self.tool_nodes["market"] + + if "social" in selected_analysts: + analyst_nodes["social"] = create_social_media_analyst( + self.quick_thinking_llm + ) + delete_nodes["social"] = create_msg_delete() + tool_nodes["social"] = self.tool_nodes["social"] + + if "news" in selected_analysts: + analyst_nodes["news"] = create_news_analyst( + self.quick_thinking_llm + ) + delete_nodes["news"] = create_msg_delete() + tool_nodes["news"] = self.tool_nodes["news"] + + if "fundamentals" in selected_analysts: + analyst_nodes["fundamentals"] = create_fundamentals_analyst( + self.quick_thinking_llm + ) + delete_nodes["fundamentals"] = create_msg_delete() + tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"] + + # Create researcher and manager nodes + bull_researcher_node = create_bull_researcher( + self.quick_thinking_llm, self.bull_memory + ) + bear_researcher_node = create_bear_researcher( + self.quick_thinking_llm, self.bear_memory + ) + research_manager_node = create_research_manager( + self.deep_thinking_llm, self.invest_judge_memory + ) + trader_node = create_trader(self.quick_thinking_llm, self.trader_memory) + + # Create risk analysis nodes + risky_analyst = create_risky_debator(self.quick_thinking_llm) + neutral_analyst = create_neutral_debator(self.quick_thinking_llm) + safe_analyst = create_safe_debator(self.quick_thinking_llm) + risk_manager_node = create_risk_manager( + self.deep_thinking_llm, self.risk_manager_memory + ) + + # Create 
workflow + workflow = StateGraph(AgentState) + + # Add analyst nodes to the graph + for analyst_type, node in analyst_nodes.items(): + workflow.add_node(f"{analyst_type.capitalize()} Analyst", node) + workflow.add_node( + f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type] + ) + workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type]) + + # Add other nodes + workflow.add_node("Bull Researcher", bull_researcher_node) + workflow.add_node("Bear Researcher", bear_researcher_node) + workflow.add_node("Research Manager", research_manager_node) + workflow.add_node("Trader", trader_node) + workflow.add_node("Risky Analyst", risky_analyst) + workflow.add_node("Neutral Analyst", neutral_analyst) + workflow.add_node("Safe Analyst", safe_analyst) + workflow.add_node("Risk Judge", risk_manager_node) + + # Define edges + # Start with the first analyst + first_analyst = selected_analysts[0] + workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst") + + # Connect analysts in sequence + for i, analyst_type in enumerate(selected_analysts): + current_analyst = f"{analyst_type.capitalize()} Analyst" + current_tools = f"tools_{analyst_type}" + current_clear = f"Msg Clear {analyst_type.capitalize()}" + + # Add conditional edges for current analyst + workflow.add_conditional_edges( + current_analyst, + getattr(self.conditional_logic, f"should_continue_{analyst_type}"), + [current_tools, current_clear], + ) + workflow.add_edge(current_tools, current_analyst) + + # Connect to next analyst or to Bull Researcher if this is the last analyst + if i < len(selected_analysts) - 1: + next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst" + workflow.add_edge(current_clear, next_analyst) + else: + workflow.add_edge(current_clear, "Bull Researcher") + + # Add remaining edges + workflow.add_conditional_edges( + "Bull Researcher", + self.conditional_logic.should_continue_debate, + { + "Bear Researcher": "Bear Researcher", + "Research Manager": "Research 
Manager", + }, + ) + workflow.add_conditional_edges( + "Bear Researcher", + self.conditional_logic.should_continue_debate, + { + "Bull Researcher": "Bull Researcher", + "Research Manager": "Research Manager", + }, + ) + workflow.add_edge("Research Manager", "Trader") + workflow.add_edge("Trader", "Risky Analyst") + workflow.add_conditional_edges( + "Risky Analyst", + self.conditional_logic.should_continue_risk_analysis, + { + "Safe Analyst": "Safe Analyst", + "Risk Judge": "Risk Judge", + }, + ) + workflow.add_conditional_edges( + "Safe Analyst", + self.conditional_logic.should_continue_risk_analysis, + { + "Neutral Analyst": "Neutral Analyst", + "Risk Judge": "Risk Judge", + }, + ) + workflow.add_conditional_edges( + "Neutral Analyst", + self.conditional_logic.should_continue_risk_analysis, + { + "Risky Analyst": "Risky Analyst", + "Risk Judge": "Risk Judge", + }, + ) + + workflow.add_edge("Risk Judge", END) + + # Compile and return + return workflow.compile() diff --git a/tradingagents/graph/signal_processing.py b/tradingagents/graph/signal_processing.py index 903e8529..d4b843e5 100644 --- a/tradingagents/graph/signal_processing.py +++ b/tradingagents/graph/signal_processing.py @@ -1,31 +1,31 @@ -# TradingAgents/graph/signal_processing.py - -from langchain_openai import ChatOpenAI - - -class SignalProcessor: - """Processes trading signals to extract actionable decisions.""" - - def __init__(self, quick_thinking_llm: ChatOpenAI): - """Initialize with an LLM for processing.""" - self.quick_thinking_llm = quick_thinking_llm - - def process_signal(self, full_signal: str) -> str: - """ - Process a full trading signal to extract the core decision. - - Args: - full_signal: Complete trading signal text - - Returns: - Extracted decision (BUY, SELL, or HOLD) - """ - messages = [ - ( - "system", - "You are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. 
Your task is to extract the investment decision: SELL, BUY, or HOLD. Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information.", - ), - ("human", full_signal), - ] - - return self.quick_thinking_llm.invoke(messages).content +# TradingAgents/graph/signal_processing.py + +from langchain_openai import ChatOpenAI + + +class SignalProcessor: + """Processes trading signals to extract actionable decisions.""" + + def __init__(self, quick_thinking_llm: ChatOpenAI): + """Initialize with an LLM for processing.""" + self.quick_thinking_llm = quick_thinking_llm + + def process_signal(self, full_signal: str) -> str: + """ + Process a full trading signal to extract the core decision. + + Args: + full_signal: Complete trading signal text + + Returns: + Extracted decision (BUY, SELL, or HOLD) + """ + messages = [ + ( + "system", + "You are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. Your task is to extract the investment decision: SELL, BUY, or HOLD. 
Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information.", + ), + ("human", full_signal), + ] + + return self.quick_thinking_llm.invoke(messages).content diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 40cdff75..877742b5 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -1,257 +1,257 @@ -# TradingAgents/graph/trading_graph.py - -import os -from pathlib import Path -import json -from datetime import date -from typing import Dict, Any, Tuple, List, Optional - -from langchain_openai import ChatOpenAI -from langchain_anthropic import ChatAnthropic -from langchain_google_genai import ChatGoogleGenerativeAI - -from langgraph.prebuilt import ToolNode - -from tradingagents.agents import * -from tradingagents.default_config import DEFAULT_CONFIG -from tradingagents.agents.utils.memory import FinancialSituationMemory -from tradingagents.agents.utils.agent_states import ( - AgentState, - InvestDebateState, - RiskDebateState, -) -from tradingagents.dataflows.config import set_config - -# Import the new abstract tool methods from agent_utils -from tradingagents.agents.utils.agent_utils import ( - get_stock_data, - get_indicators, - get_fundamentals, - get_balance_sheet, - get_cashflow, - get_income_statement, - get_news, - get_insider_sentiment, - get_insider_transactions, - get_global_news -) - -from .conditional_logic import ConditionalLogic -from .setup import GraphSetup -from .propagation import Propagator -from .reflection import Reflector -from .signal_processing import SignalProcessor - - -class TradingAgentsGraph: - """Main class that orchestrates the trading agents framework.""" - - def __init__( - self, - selected_analysts=["market", "social", "news", "fundamentals"], - debug=False, - config: Dict[str, Any] = None, - ): - """Initialize the trading agents graph and components. 
- - Args: - selected_analysts: List of analyst types to include - debug: Whether to run in debug mode - config: Configuration dictionary. If None, uses default config - """ - self.debug = debug - self.config = config or DEFAULT_CONFIG - - # Update the interface's config - set_config(self.config) - - # Create necessary directories - os.makedirs( - os.path.join(self.config["project_dir"], "dataflows/data_cache"), - exist_ok=True, - ) - - # Initialize LLMs - if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter": - self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]) - self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]) - elif self.config["llm_provider"].lower() == "anthropic": - self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]) - self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]) - elif self.config["llm_provider"].lower() == "google": - self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"]) - self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"]) - else: - raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}") - - # Initialize memories - self.bull_memory = FinancialSituationMemory("bull_memory", self.config) - self.bear_memory = FinancialSituationMemory("bear_memory", self.config) - self.trader_memory = FinancialSituationMemory("trader_memory", self.config) - self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config) - self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config) - - # Create tool nodes - self.tool_nodes = self._create_tool_nodes() - - # Initialize components - self.conditional_logic = 
ConditionalLogic() - self.graph_setup = GraphSetup( - self.quick_thinking_llm, - self.deep_thinking_llm, - self.tool_nodes, - self.bull_memory, - self.bear_memory, - self.trader_memory, - self.invest_judge_memory, - self.risk_manager_memory, - self.conditional_logic, - ) - - self.propagator = Propagator() - self.reflector = Reflector(self.quick_thinking_llm) - self.signal_processor = SignalProcessor(self.quick_thinking_llm) - - # State tracking - self.curr_state = None - self.ticker = None - self.log_states_dict = {} # date to full state dict - - # Set up the graph - self.graph = self.graph_setup.setup_graph(selected_analysts) - - def _create_tool_nodes(self) -> Dict[str, ToolNode]: - """Create tool nodes for different data sources using abstract methods.""" - return { - "market": ToolNode( - [ - # Core stock data tools - get_stock_data, - # Technical indicators - get_indicators, - ] - ), - "social": ToolNode( - [ - # News tools for social media analysis - get_news, - ] - ), - "news": ToolNode( - [ - # News and insider information - get_news, - get_global_news, - get_insider_sentiment, - get_insider_transactions, - ] - ), - "fundamentals": ToolNode( - [ - # Fundamental analysis tools - get_fundamentals, - get_balance_sheet, - get_cashflow, - get_income_statement, - ] - ), - } - - def propagate(self, company_name, trade_date): - """Run the trading agents graph for a company on a specific date.""" - - self.ticker = company_name - - # Initialize state - init_agent_state = self.propagator.create_initial_state( - company_name, trade_date - ) - args = self.propagator.get_graph_args() - - if self.debug: - # Debug mode with tracing - trace = [] - for chunk in self.graph.stream(init_agent_state, **args): - if len(chunk["messages"]) == 0: - pass - else: - chunk["messages"][-1].pretty_print() - trace.append(chunk) - - final_state = trace[-1] - else: - # Standard mode without tracing - final_state = self.graph.invoke(init_agent_state, **args) - - # Store current state for 
reflection - self.curr_state = final_state - - # Log state - self._log_state(trade_date, final_state) - - # Return decision and processed signal - return final_state, self.process_signal(final_state["final_trade_decision"]) - - def _log_state(self, trade_date, final_state): - """Log the final state to a JSON file.""" - self.log_states_dict[str(trade_date)] = { - "company_of_interest": final_state["company_of_interest"], - "trade_date": final_state["trade_date"], - "market_report": final_state["market_report"], - "sentiment_report": final_state["sentiment_report"], - "news_report": final_state["news_report"], - "fundamentals_report": final_state["fundamentals_report"], - "investment_debate_state": { - "bull_history": final_state["investment_debate_state"]["bull_history"], - "bear_history": final_state["investment_debate_state"]["bear_history"], - "history": final_state["investment_debate_state"]["history"], - "current_response": final_state["investment_debate_state"][ - "current_response" - ], - "judge_decision": final_state["investment_debate_state"][ - "judge_decision" - ], - }, - "trader_investment_decision": final_state["trader_investment_plan"], - "risk_debate_state": { - "risky_history": final_state["risk_debate_state"]["risky_history"], - "safe_history": final_state["risk_debate_state"]["safe_history"], - "neutral_history": final_state["risk_debate_state"]["neutral_history"], - "history": final_state["risk_debate_state"]["history"], - "judge_decision": final_state["risk_debate_state"]["judge_decision"], - }, - "investment_plan": final_state["investment_plan"], - "final_trade_decision": final_state["final_trade_decision"], - } - - # Save to file - directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/") - directory.mkdir(parents=True, exist_ok=True) - - with open( - f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json", - "w", - ) as f: - json.dump(self.log_states_dict, f, indent=4) - - def 
# TradingAgents/graph/trading_graph.py

import os
from pathlib import Path
import json
from datetime import date
from typing import Dict, Any, Tuple, List, Optional

from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI

from langgraph.prebuilt import ToolNode

from tradingagents.agents import *
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.agents.utils.agent_states import (
    AgentState,
    InvestDebateState,
    RiskDebateState,
)
from tradingagents.dataflows.config import set_config

# Import the new abstract tool methods from agent_utils
from tradingagents.agents.utils.agent_utils import (
    get_stock_data,
    get_indicators,
    get_fundamentals,
    get_balance_sheet,
    get_cashflow,
    get_income_statement,
    get_news,
    get_insider_sentiment,
    get_insider_transactions,
    get_global_news,
)

from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
from .reflection import Reflector
from .signal_processing import SignalProcessor


class TradingAgentsGraph:
    """Main class that orchestrates the trading agents framework.

    Wires together the LLM pair (deep/quick thinkers), per-role memories,
    the analyst tool nodes, and the LangGraph pipeline, then exposes
    ``propagate`` to run one trading decision and ``reflect_and_remember``
    to fold realized returns back into the role memories.
    """

    # Analyst teams used when the caller does not pass an explicit selection.
    _DEFAULT_ANALYSTS = ("market", "social", "news", "fundamentals")

    def __init__(
        self,
        selected_analysts=None,
        debug=False,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Initialize the trading agents graph and components.

        Args:
            selected_analysts: Iterable of analyst types to include. Defaults
                to ("market", "social", "news", "fundamentals") when None.
                (A None default avoids the shared-mutable-default pitfall of
                the previous list default; behavior is unchanged.)
            debug: Whether to run in debug mode (streams and pretty-prints
                each graph chunk instead of a single ``invoke``).
            config: Configuration dictionary. If None, uses DEFAULT_CONFIG.

        Raises:
            ValueError: If ``config["llm_provider"]`` names an unsupported
                provider.
        """
        if selected_analysts is None:
            selected_analysts = list(self._DEFAULT_ANALYSTS)

        self.debug = debug
        self.config = config or DEFAULT_CONFIG

        # Update the interface's config so the dataflows layer sees it too.
        set_config(self.config)

        # Create necessary directories (on-disk data cache).
        os.makedirs(
            os.path.join(self.config["project_dir"], "dataflows/data_cache"),
            exist_ok=True,
        )

        # Initialize LLMs. Normalize the provider name once so that e.g.
        # "Ollama" and "OpenRouter" match (previously only "openai" was
        # compared case-insensitively, so "Ollama" raised ValueError).
        provider = self.config["llm_provider"].lower()
        if provider in ("openai", "ollama", "openrouter"):
            # All three speak the OpenAI-compatible API behind backend_url.
            self.deep_thinking_llm = ChatOpenAI(
                model=self.config["deep_think_llm"],
                base_url=self.config["backend_url"],
            )
            self.quick_thinking_llm = ChatOpenAI(
                model=self.config["quick_think_llm"],
                base_url=self.config["backend_url"],
            )
        elif provider == "anthropic":
            self.deep_thinking_llm = ChatAnthropic(
                model=self.config["deep_think_llm"],
                base_url=self.config["backend_url"],
            )
            self.quick_thinking_llm = ChatAnthropic(
                model=self.config["quick_think_llm"],
                base_url=self.config["backend_url"],
            )
        elif provider == "google":
            # Google models take no base_url override here.
            self.deep_thinking_llm = ChatGoogleGenerativeAI(
                model=self.config["deep_think_llm"]
            )
            self.quick_thinking_llm = ChatGoogleGenerativeAI(
                model=self.config["quick_think_llm"]
            )
        else:
            raise ValueError(
                f"Unsupported LLM provider: {self.config['llm_provider']}"
            )

        # Initialize per-role memories used by the reflection step.
        self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
        self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
        self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
        self.invest_judge_memory = FinancialSituationMemory(
            "invest_judge_memory", self.config
        )
        self.risk_manager_memory = FinancialSituationMemory(
            "risk_manager_memory", self.config
        )

        # Create tool nodes (one ToolNode per analyst team).
        self.tool_nodes = self._create_tool_nodes()

        # Initialize components that assemble and drive the graph.
        self.conditional_logic = ConditionalLogic()
        self.graph_setup = GraphSetup(
            self.quick_thinking_llm,
            self.deep_thinking_llm,
            self.tool_nodes,
            self.bull_memory,
            self.bear_memory,
            self.trader_memory,
            self.invest_judge_memory,
            self.risk_manager_memory,
            self.conditional_logic,
        )

        self.propagator = Propagator()
        self.reflector = Reflector(self.quick_thinking_llm)
        self.signal_processor = SignalProcessor(self.quick_thinking_llm)

        # State tracking
        self.curr_state = None  # final AgentState of the most recent run
        self.ticker = None  # company symbol of the most recent run
        self.log_states_dict = {}  # trade date (str) -> full state dict

        # Set up the graph
        self.graph = self.graph_setup.setup_graph(selected_analysts)

    def _create_tool_nodes(self) -> Dict[str, ToolNode]:
        """Create tool nodes for different data sources using abstract methods.

        Returns:
            Mapping of analyst team name ("market", "social", "news",
            "fundamentals") to the ToolNode exposing that team's data tools.
        """
        return {
            "market": ToolNode(
                [
                    # Core stock data tools
                    get_stock_data,
                    # Technical indicators
                    get_indicators,
                ]
            ),
            "social": ToolNode(
                [
                    # News tools for social media analysis
                    get_news,
                ]
            ),
            "news": ToolNode(
                [
                    # News and insider information
                    get_news,
                    get_global_news,
                    get_insider_sentiment,
                    get_insider_transactions,
                ]
            ),
            "fundamentals": ToolNode(
                [
                    # Fundamental analysis tools
                    get_fundamentals,
                    get_balance_sheet,
                    get_cashflow,
                    get_income_statement,
                ]
            ),
        }

    def propagate(self, company_name, trade_date):
        """Run the trading agents graph for a company on a specific date.

        Args:
            company_name: Ticker/company symbol to analyze.
            trade_date: Date of the simulated trading decision.

        Returns:
            Tuple of (final graph state, processed trade signal extracted
            from the state's "final_trade_decision").
        """
        self.ticker = company_name

        # Initialize state
        init_agent_state = self.propagator.create_initial_state(
            company_name, trade_date
        )
        args = self.propagator.get_graph_args()

        if self.debug:
            # Debug mode with tracing: stream chunks, pretty-print each
            # non-empty message, and keep the last chunk as the final state.
            trace = []
            for chunk in self.graph.stream(init_agent_state, **args):
                if chunk["messages"]:
                    chunk["messages"][-1].pretty_print()
                trace.append(chunk)

            final_state = trace[-1]
        else:
            # Standard mode without tracing
            final_state = self.graph.invoke(init_agent_state, **args)

        # Store current state for reflection
        self.curr_state = final_state

        # Log state
        self._log_state(trade_date, final_state)

        # Return decision and processed signal
        return final_state, self.process_signal(final_state["final_trade_decision"])

    def _log_state(self, trade_date, final_state):
        """Log the final state to a JSON file.

        Accumulates into ``self.log_states_dict`` (keyed by date string) and
        rewrites the whole accumulated dict to
        ``eval_results/<ticker>/TradingAgentsStrategy_logs/full_states_log_<date>.json``.
        """
        self.log_states_dict[str(trade_date)] = {
            "company_of_interest": final_state["company_of_interest"],
            "trade_date": final_state["trade_date"],
            "market_report": final_state["market_report"],
            "sentiment_report": final_state["sentiment_report"],
            "news_report": final_state["news_report"],
            "fundamentals_report": final_state["fundamentals_report"],
            "investment_debate_state": {
                "bull_history": final_state["investment_debate_state"]["bull_history"],
                "bear_history": final_state["investment_debate_state"]["bear_history"],
                "history": final_state["investment_debate_state"]["history"],
                "current_response": final_state["investment_debate_state"][
                    "current_response"
                ],
                "judge_decision": final_state["investment_debate_state"][
                    "judge_decision"
                ],
            },
            "trader_investment_decision": final_state["trader_investment_plan"],
            "risk_debate_state": {
                "risky_history": final_state["risk_debate_state"]["risky_history"],
                "safe_history": final_state["risk_debate_state"]["safe_history"],
                "neutral_history": final_state["risk_debate_state"]["neutral_history"],
                "history": final_state["risk_debate_state"]["history"],
                "judge_decision": final_state["risk_debate_state"]["judge_decision"],
            },
            "investment_plan": final_state["investment_plan"],
            "final_trade_decision": final_state["final_trade_decision"],
        }

        # Save to file. Reuse the Path we just created instead of rebuilding
        # the same path as a raw f-string (previous code did both).
        directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
        directory.mkdir(parents=True, exist_ok=True)

        with open(directory / f"full_states_log_{trade_date}.json", "w") as f:
            json.dump(self.log_states_dict, f, indent=4)

    def reflect_and_remember(self, returns_losses):
        """Reflect on decisions and update memory based on returns.

        Runs the reflector for every role (bull, bear, trader, judge, risk
        manager) against ``self.curr_state``, so ``propagate`` must have been
        called first.
        """
        self.reflector.reflect_bull_researcher(
            self.curr_state, returns_losses, self.bull_memory
        )
        self.reflector.reflect_bear_researcher(
            self.curr_state, returns_losses, self.bear_memory
        )
        self.reflector.reflect_trader(
            self.curr_state, returns_losses, self.trader_memory
        )
        self.reflector.reflect_invest_judge(
            self.curr_state, returns_losses, self.invest_judge_memory
        )
        self.reflector.reflect_risk_manager(
            self.curr_state, returns_losses, self.risk_manager_memory
        )

    def process_signal(self, full_signal):
        """Process a signal to extract the core decision."""
        return self.signal_processor.process_signal(full_signal)