diff --git a/cli/main.py b/cli/main.py index edbaa3ad..4041ec01 100644 --- a/cli/main.py +++ b/cli/main.py @@ -1,31 +1,32 @@ import datetime -import typer -from pathlib import Path -from functools import wraps -from rich.console import Console -from rich.panel import Panel -from rich.spinner import Spinner -from rich.live import Live -from rich.columns import Columns -from rich.markdown import Markdown -from rich.layout import Layout -from rich.text import Text -from rich.table import Table from collections import deque +from functools import wraps +from pathlib import Path + +import typer from rich import box from rich.align import Align +from rich.columns import Columns +from rich.console import Console +from rich.layout import Layout +from rich.live import Live +from rich.markdown import Markdown +from rich.panel import Panel +from rich.spinner import Spinner +from rich.table import Table +from rich.text import Text -from tradingagents.graph.trading_graph import TradingAgentsGraph -from tradingagents.default_config import DEFAULT_CONFIG from cli.utils import ( - get_ticker, get_analysis_date, + get_ticker, select_analysts, - select_research_depth, - select_shallow_thinking_agent, select_deep_thinking_agent, select_llm_provider, + select_research_depth, + select_shallow_thinking_agent, ) +from tradingagents.default_config import DEFAULT_CONFIG +from tradingagents.graph.trading_graph import TradingAgentsGraph console = Console() @@ -136,19 +137,20 @@ class MessageBuffer: report_parts.append("## Analyst Team Reports") if self.report_sections["market_report"]: report_parts.append( - f"### Market Analysis\n{self.report_sections['market_report']}" + f"### Market Analysis\n{self.report_sections['market_report']}", ) if self.report_sections["sentiment_report"]: report_parts.append( - f"### Social Sentiment\n{self.report_sections['sentiment_report']}" + f"### Social Sentiment\n{self.report_sections['sentiment_report']}", ) if self.report_sections["news_report"]: report_parts.append( - f"### News Analysis\n{self.report_sections['news_report']}" + f"### News Analysis\n{self.report_sections['news_report']}", ) if self.report_sections["fundamentals_report"]: + fundamentals = self.report_sections['fundamentals_report'] report_parts.append( - f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}" + f"### Fundamentals Analysis\n{fundamentals}", ) # Research Team Reports @@ -180,10 +182,10 @@ def create_layout(): Layout(name="footer", size=3), ) layout["main"].split_column( - Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5) + Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5), ) layout["upper"].split_row( - Layout(name="progress", ratio=2), Layout(name="messages", ratio=3) + Layout(name="progress", ratio=2), Layout(name="messages", ratio=3), ) return layout @@ -198,7 +200,7 @@ def update_display(layout, spinner_text=None): border_style="green", padding=(1, 2), expand=True, - ) + ), ) # Progress panel showing agent status @@ -235,7 +237,7 @@ def update_display(layout, spinner_text=None): status = message_buffer.agent_status[first_agent] if status == "in_progress": spinner = Spinner( - "dots", text="[blue]in_progress[/blue]", style="bold cyan" + "dots", text="[blue]in_progress[/blue]", style="bold cyan", ) status_cell = spinner else: @@ -252,7 +254,7 @@ def update_display(layout, spinner_text=None): status = message_buffer.agent_status[agent] if status == "in_progress": spinner = Spinner( - "dots", text="[blue]in_progress[/blue]", style="bold cyan" + "dots", text="[blue]in_progress[/blue]", style="bold cyan", ) status_cell = spinner else: @@ -268,7 +270,7 @@ def update_display(layout, spinner_text=None): progress_table.add_row("─" * 20, "─" * 20, "─" * 20, style="dim") layout["progress"].update( - Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2)) + Panel(progress_table, title="Progress", border_style="cyan", padding=(1, 2)), ) # Messages panel showing recent messages and tool calls @@ -284,7 +286,7 @@ def update_display(layout, spinner_text=None): messages_table.add_column("Time", style="cyan", width=8, justify="center") messages_table.add_column("Type", style="green", width=10, justify="center") messages_table.add_column( - "Content", style="white", no_wrap=False, ratio=1 + "Content", style="white", no_wrap=False, ratio=1, ) # Make content column expand # Combine tool calls and messages @@ -352,7 +354,7 @@ def update_display(layout, spinner_text=None): title="Messages & Tools", border_style="blue", padding=(1, 2), - ) + ), ) # Analysis panel showing current report @@ -363,7 +365,7 @@ def update_display(layout, spinner_text=None): title="Current Report", border_style="green", padding=(1, 2), - ) + ), ) else: layout["analysis"].update( @@ -372,7 +374,7 @@ def update_display(layout, spinner_text=None): title="Current Report", border_style="green", padding=(1, 2), - ) + ), ) # Footer with statistics @@ -386,9 +388,12 @@ def update_display(layout, spinner_text=None): stats_table = Table(show_header=False, box=None, padding=(0, 2), expand=True) stats_table.add_column("Stats", justify="center") - stats_table.add_row( - f"Tool Calls: {tool_calls_count} | LLM Calls: {llm_calls_count} | Generated Reports: {reports_count}" + stats_text = ( + f"Tool Calls: {tool_calls_count} | " + f"LLM Calls: {llm_calls_count} | " + f"Generated Reports: {reports_count}" ) + stats_table.add_row(stats_text) layout["footer"].update(Panel(stats_table, border_style="grey50")) @@ -396,14 +401,20 @@ def update_display(layout, spinner_text=None): def get_user_selections(): """Get all user selections before starting the analysis display.""" # Display ASCII art welcome message - with open("./cli/static/welcome.txt", "r") as f: + with open("./cli/static/welcome.txt") as f: welcome_ascii = f.read() # Create welcome box content welcome_content = f"{welcome_ascii}\n" - welcome_content += "[bold green]TradingAgents: Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n" + welcome_content += ( + "[bold green]TradingAgents: " + "Multi-Agents LLM Financial Trading Framework - CLI[/bold green]\n\n" + ) welcome_content += "[bold]Workflow Steps:[/bold]\n" - welcome_content += "I. Analyst Team → II. Research Team → III. Trader → IV. Risk Management → V. Portfolio Management\n\n" + welcome_content += ( + "I. Analyst Team → II. Research Team → III. Trader → " + "IV. Risk Management → V. Portfolio Management\n\n" + ) welcome_content += ( "[dim]Built by [Tauric Research](https://github.com/TauricResearch)[/dim]" ) @@ -430,8 +441,8 @@ def get_user_selections(): # Step 1: Ticker symbol console.print( create_question_box( - "Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY" - ) + "Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY", + ), ) selected_ticker = get_ticker() @@ -442,40 +453,40 @@ def get_user_selections(): "Step 2: Analysis Date", "Enter the analysis date (YYYY-MM-DD)", default_date, - ) + ), ) analysis_date = get_analysis_date() # Step 3: Select analysts console.print( create_question_box( - "Step 3: Analysts Team", "Select your LLM analyst agents for the analysis" - ) + "Step 3: Analysts Team", "Select your LLM analyst agents for the analysis", + ), ) selected_analysts = select_analysts() console.print( - f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}" + f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}", ) # Step 4: Research depth console.print( create_question_box( - "Step 4: Research Depth", "Select your research depth level" - ) + "Step 4: Research Depth", "Select your research depth level", + ), ) selected_research_depth = select_research_depth() # Step 5: OpenAI backend console.print( - create_question_box("Step 5: OpenAI backend", "Select which service to talk to") + create_question_box("Step 5: OpenAI backend", "Select which service to talk to"), ) selected_llm_provider, backend_url = select_llm_provider() # Step 6: Thinking agents console.print( create_question_box( - "Step 6: Thinking Agents", "Select your thinking agents for analysis" - ) + "Step 6: Thinking Agents", "Select your thinking agents for analysis", + ), ) selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider) @@ -492,28 +503,7 @@ def get_user_selections(): } -def get_ticker(): - """Get ticker symbol from user input.""" - return typer.prompt("", default="SPY") - - -def get_analysis_date(): - """Get the analysis date from user input.""" - while True: - date_str = typer.prompt( - "", default=datetime.datetime.now().strftime("%Y-%m-%d") - ) - try: - # Validate date format and ensure it's not in the future - analysis_date = datetime.datetime.strptime(date_str, "%Y-%m-%d") - if analysis_date.date() > datetime.datetime.now().date(): - console.print("[red]Error: Analysis date cannot be in the future[/red]") - continue - return date_str - except ValueError: - console.print( - "[red]Error: Invalid date format. Please use YYYY-MM-DD[/red]" - ) +# Functions get_ticker and get_analysis_date are imported from cli.utils def display_complete_report(final_state): @@ -531,7 +521,7 @@ def display_complete_report(final_state): title="Market Analyst", border_style="blue", padding=(1, 2), - ) + ), ) # Social Analyst Report @@ -542,7 +532,7 @@ def display_complete_report(final_state): title="Social Analyst", border_style="blue", padding=(1, 2), - ) + ), ) # News Analyst Report @@ -553,7 +543,7 @@ def display_complete_report(final_state): title="News Analyst", border_style="blue", padding=(1, 2), - ) + ), ) # Fundamentals Analyst Report @@ -564,7 +554,7 @@ def display_complete_report(final_state): title="Fundamentals Analyst", border_style="blue", padding=(1, 2), - ) + ), ) if analyst_reports: @@ -574,7 +564,7 @@ def display_complete_report(final_state): title="I. Analyst Team Reports", border_style="cyan", padding=(1, 2), - ) + ), ) # II. Research Team Reports @@ -590,7 +580,7 @@ def display_complete_report(final_state): title="Bull Researcher", border_style="blue", padding=(1, 2), - ) + ), ) # Bear Researcher Analysis @@ -601,7 +591,7 @@ def display_complete_report(final_state): title="Bear Researcher", border_style="blue", padding=(1, 2), - ) + ), ) # Research Manager Decision @@ -612,7 +602,7 @@ def display_complete_report(final_state): title="Research Manager", border_style="blue", padding=(1, 2), - ) + ), ) if research_reports: @@ -622,7 +612,7 @@ def display_complete_report(final_state): title="II. Research Team Decision", border_style="magenta", padding=(1, 2), - ) + ), ) # III. Trading Team Reports @@ -638,7 +628,7 @@ def display_complete_report(final_state): title="III. Trading Team Plan", border_style="yellow", padding=(1, 2), - ) + ), ) # IV. Risk Management Team Reports @@ -654,7 +644,7 @@ def display_complete_report(final_state): title="Aggressive Analyst", border_style="blue", padding=(1, 2), - ) + ), ) # Conservative (Safe) Analyst Analysis @@ -665,7 +655,7 @@ def display_complete_report(final_state): title="Conservative Analyst", border_style="blue", padding=(1, 2), - ) + ), ) # Neutral Analyst Analysis @@ -676,7 +666,7 @@ def display_complete_report(final_state): title="Neutral Analyst", border_style="blue", padding=(1, 2), - ) + ), ) if risk_reports: @@ -686,7 +676,7 @@ def display_complete_report(final_state): title="IV. Risk Management Team Decision", border_style="red", padding=(1, 2), - ) + ), ) # V. Portfolio Manager Decision @@ -702,7 +692,7 @@ def display_complete_report(final_state): title="V. Portfolio Manager Decision", border_style="green", padding=(1, 2), - ) + ), ) @@ -717,7 +707,7 @@ def extract_content_string(content): """Extract string content from various message formats.""" if isinstance(content, str): return content - elif isinstance(content, list): + if isinstance(content, list): # Handle Anthropic's list format text_parts = [] for item in content: @@ -729,8 +719,7 @@ def extract_content_string(content): else: text_parts.append(str(item)) return " ".join(text_parts) - else: - return str(content) + return str(content) def run_analysis(): @@ -748,7 +737,7 @@ def run_analysis(): # Initialize the graph graph = TradingAgentsGraph( - [analyst.value for analyst in selections["analysts"]], config=config, debug=True + [analyst.value for analyst in selections["analysts"]], config=config, debug=True, ) # Create result directory @@ -807,23 +796,23 @@ def run_analysis(): message_buffer.add_message = save_message_decorator(message_buffer, "add_message") message_buffer.add_tool_call = save_tool_call_decorator( - message_buffer, "add_tool_call" + message_buffer, "add_tool_call", ) message_buffer.update_report_section = save_report_section_decorator( - message_buffer, "update_report_section" + message_buffer, "update_report_section", ) # Now start the display layout layout = create_layout() - with Live(layout, refresh_per_second=4) as live: + with Live(layout, refresh_per_second=4): # Initial display update_display(layout) # Add initial messages message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}") message_buffer.add_message( - "System", f"Analysis date: {selections['analysis_date']}" + "System", f"Analysis date: {selections['analysis_date']}", ) message_buffer.add_message( "System", @@ -854,7 +843,7 @@ def run_analysis(): # Initialize state and get graph args init_agent_state = graph.propagator.create_initial_state( - selections["ticker"], selections["analysis_date"] + selections["ticker"], selections["analysis_date"], ) args = graph.propagator.get_graph_args() @@ -868,7 +857,7 @@ def run_analysis(): # Extract message content and type if hasattr(last_message, "content"): content = extract_content_string( - last_message.content + last_message.content, ) # Use the helper function msg_type = "Reasoning" else: @@ -884,65 +873,64 @@ def run_analysis(): # Handle both dictionary and object tool calls if isinstance(tool_call, dict): message_buffer.add_tool_call( - tool_call["name"], tool_call["args"] + tool_call["name"], tool_call["args"], ) else: message_buffer.add_tool_call(tool_call.name, tool_call.args) # Update reports and agent status based on chunk content # Analyst Team Reports - if "market_report" in chunk and chunk["market_report"]: + if chunk.get("market_report"): message_buffer.update_report_section( - "market_report", chunk["market_report"] + "market_report", chunk["market_report"], ) message_buffer.update_agent_status("Market Analyst", "completed") # Set next analyst to in_progress if "social" in selections["analysts"]: message_buffer.update_agent_status( - "Social Analyst", "in_progress" + "Social Analyst", "in_progress", ) - if "sentiment_report" in chunk and chunk["sentiment_report"]: + if chunk.get("sentiment_report"): message_buffer.update_report_section( - "sentiment_report", chunk["sentiment_report"] + "sentiment_report", chunk["sentiment_report"], ) message_buffer.update_agent_status("Social Analyst", "completed") # Set next analyst to in_progress if "news" in selections["analysts"]: message_buffer.update_agent_status( - "News Analyst", "in_progress" + "News Analyst", "in_progress", ) - if "news_report" in chunk and chunk["news_report"]: + if chunk.get("news_report"): message_buffer.update_report_section( - "news_report", chunk["news_report"] + "news_report", chunk["news_report"], ) message_buffer.update_agent_status("News Analyst", "completed") # Set next analyst to in_progress if "fundamentals" in selections["analysts"]: message_buffer.update_agent_status( - "Fundamentals Analyst", "in_progress" + "Fundamentals Analyst", "in_progress", ) - if "fundamentals_report" in chunk and chunk["fundamentals_report"]: + if chunk.get("fundamentals_report"): message_buffer.update_report_section( - "fundamentals_report", chunk["fundamentals_report"] + "fundamentals_report", chunk["fundamentals_report"], ) message_buffer.update_agent_status( - "Fundamentals Analyst", "completed" + "Fundamentals Analyst", "completed", ) # Set all research team members to in_progress update_research_team_status("in_progress") # Research Team - Handle Investment Debate State if ( - "investment_debate_state" in chunk - and chunk["investment_debate_state"] + chunk.get("investment_debate_state") ): debate_state = chunk["investment_debate_state"] # Update Bull Researcher status and report - if "bull_history" in debate_state and debate_state["bull_history"]: + if debate_state.get("bull_history"): # Keep all research team members in progress update_research_team_status("in_progress") # Extract latest bull response @@ -957,7 +945,7 @@ def run_analysis(): ) # Update Bear Researcher status and report - if "bear_history" in debate_state and debate_state["bear_history"]: + if debate_state.get("bear_history"): # Keep all research team members in progress update_research_team_status("in_progress") # Extract latest bear response @@ -973,8 +961,7 @@ def run_analysis(): # Update Research Manager status and final decision if ( - "judge_decision" in debate_state - and debate_state["judge_decision"] + debate_state.get("judge_decision") ): # Keep all research team members in progress until final decision update_research_team_status("in_progress") @@ -991,31 +978,29 @@ def run_analysis(): update_research_team_status("completed") # Set first risk analyst to in_progress message_buffer.update_agent_status( - "Risky Analyst", "in_progress" + "Risky Analyst", "in_progress", ) # Trading Team if ( - "trader_investment_plan" in chunk - and chunk["trader_investment_plan"] + chunk.get("trader_investment_plan") ): message_buffer.update_report_section( - "trader_investment_plan", chunk["trader_investment_plan"] + "trader_investment_plan", chunk["trader_investment_plan"], ) # Set first risk analyst to in_progress message_buffer.update_agent_status("Risky Analyst", "in_progress") # Risk Management Team - Handle Risk Debate State - if "risk_debate_state" in chunk and chunk["risk_debate_state"]: + if chunk.get("risk_debate_state"): risk_state = chunk["risk_debate_state"] # Update Risky Analyst status and report if ( - "current_risky_response" in risk_state - and risk_state["current_risky_response"] + risk_state.get("current_risky_response") ): message_buffer.update_agent_status( - "Risky Analyst", "in_progress" + "Risky Analyst", "in_progress", ) message_buffer.add_message( "Reasoning", @@ -1029,11 +1014,10 @@ def run_analysis(): # Update Safe Analyst status and report if ( - "current_safe_response" in risk_state - and risk_state["current_safe_response"] + risk_state.get("current_safe_response") ): message_buffer.update_agent_status( - "Safe Analyst", "in_progress" + "Safe Analyst", "in_progress", ) message_buffer.add_message( "Reasoning", @@ -1047,11 +1031,10 @@ def run_analysis(): # Update Neutral Analyst status and report if ( - "current_neutral_response" in risk_state - and risk_state["current_neutral_response"] + risk_state.get("current_neutral_response") ): message_buffer.update_agent_status( - "Neutral Analyst", "in_progress" + "Neutral Analyst", "in_progress", ) message_buffer.add_message( "Reasoning", @@ -1064,9 +1047,9 @@ def run_analysis(): ) # Update Portfolio Manager status and final decision - if "judge_decision" in risk_state and risk_state["judge_decision"]: + if risk_state.get("judge_decision"): message_buffer.update_agent_status( - "Portfolio Manager", "in_progress" + "Portfolio Manager", "in_progress", ) message_buffer.add_message( "Reasoning", @@ -1081,10 +1064,10 @@ def run_analysis(): message_buffer.update_agent_status("Risky Analyst", "completed") message_buffer.update_agent_status("Safe Analyst", "completed") message_buffer.update_agent_status( - "Neutral Analyst", "completed" + "Neutral Analyst", "completed", ) message_buffer.update_agent_status( - "Portfolio Manager", "completed" + "Portfolio Manager", "completed", ) # Update the display @@ -1094,18 +1077,18 @@ def run_analysis(): # Get final state and decision final_state = trace[-1] - decision = graph.process_signal(final_state["final_trade_decision"]) + graph.process_signal(final_state["final_trade_decision"]) # Update all agent statuses to completed for agent in message_buffer.agent_status: message_buffer.update_agent_status(agent, "completed") message_buffer.add_message( - "Analysis", f"Completed analysis for {selections['analysis_date']}" + "Analysis", f"Completed analysis for {selections['analysis_date']}", ) # Update final report sections - for section in message_buffer.report_sections.keys(): + for section in message_buffer.report_sections: if section in final_state: message_buffer.update_report_section(section, final_state[section]) diff --git a/cli/utils.py b/cli/utils.py index 4f4eff6b..e8df4a78 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,5 +1,7 @@ + +import sys + import questionary -from typing import List from rich.console import Console from cli.models import AnalystType @@ -23,13 +25,13 @@ def get_ticker() -> str: [ ("text", "fg:green"), ("highlighted", "noinherit"), - ] + ], ), ).ask() if not ticker: console.print("\n[red]No ticker symbol provided. Exiting...[/red]") - exit(1) + sys.exit(1) return ticker.strip().upper() @@ -56,18 +58,18 @@ def get_analysis_date() -> str: [ ("text", "fg:green"), ("highlighted", "noinherit"), - ] + ], ), ).ask() if not date: console.print("\n[red]No date provided. Exiting...[/red]") - exit(1) + sys.exit(1) return date.strip() -def select_analysts() -> List[AnalystType]: +def select_analysts() -> list[AnalystType]: """Select analysts using an interactive checkbox.""" choices = questionary.checkbox( "Select Your [Analysts Team]:", @@ -82,13 +84,13 @@ def select_analysts() -> List[AnalystType]: ("selected", "fg:green noinherit"), ("highlighted", "noinherit"), ("pointer", "noinherit"), - ] + ], ), ).ask() if not choices: console.print("\n[red]No analysts selected. Exiting...[/red]") - exit(1) + sys.exit(1) return choices @@ -114,13 +116,13 @@ def select_research_depth() -> int: ("selected", "fg:yellow noinherit"), ("highlighted", "fg:yellow noinherit"), ("pointer", "fg:yellow noinherit"), - ] + ], ), ).ask() if choice is None: console.print("\n[red]No research depth selected. Exiting...[/red]") - exit(1) + sys.exit(1) return choice @@ -200,15 +202,15 @@ def select_shallow_thinking_agent(provider) -> str: ("selected", "fg:magenta noinherit"), ("highlighted", "fg:magenta noinherit"), ("pointer", "fg:magenta noinherit"), - ] + ], ), ).ask() if choice is None: console.print( - "\n[red]No shallow thinking llm engine selected. Exiting...[/red]" + "\n[red]No shallow thinking llm engine selected. Exiting...[/red]", ) - exit(1) + sys.exit(1) return choice @@ -292,13 +294,13 @@ def select_deep_thinking_agent(provider) -> str: ("selected", "fg:magenta noinherit"), ("highlighted", "fg:magenta noinherit"), ("pointer", "fg:magenta noinherit"), - ] + ], ), ).ask() if choice is None: console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]") - exit(1) + sys.exit(1) return choice @@ -326,15 +328,14 @@ def select_llm_provider() -> tuple[str, str]: ("selected", "fg:magenta noinherit"), ("highlighted", "fg:magenta noinherit"), ("pointer", "fg:magenta noinherit"), - ] + ], ), ).ask() if choice is None: console.print("\n[red]no OpenAI backend selected. Exiting...[/red]") - exit(1) + sys.exit(1) display_name, url = choice - print(f"You selected: {display_name}\tURL: {url}") return display_name, url diff --git a/test_hooks.py b/test_hooks.py index 82af1372..e4dd48b4 100644 --- a/test_hooks.py +++ b/test_hooks.py @@ -1,10 +1,7 @@ -import os,sys # Multiple imports on one line (Ruff will fix) -from typing import List,Dict # Multiple imports on one line def poorly_formatted_function(x,y,z): # Missing type hints """This function has formatting issues.""" result=x+y*z # Missing spaces around operators - unused_variable = 42 # Unused variable (Ruff will detect) if result>100: # Missing spaces print( "Result is large" ) # Extra spaces in parentheses return result diff --git a/test_setup_demo.py b/test_setup_demo.py index 515fe8a5..75b8ad2d 100755 --- a/test_setup_demo.py +++ b/test_setup_demo.py @@ -23,11 +23,11 @@ def run_command(cmd, description=""): print(f"✅ {description or 'Command completed successfully'}") return True else: - print(f"❌ Command failed:") + print("❌ Command failed:") print(result.stderr) return False except subprocess.TimeoutExpired: - print(f"⏱️ Command timed out") + print("⏱️ Command timed out") return False except Exception as e: print(f"❌ Error running command: {e}") @@ -73,7 +73,7 @@ def main(): # Summary print("\n" + "=" * 50) - print(f"📊 Test Setup Verification Results:") + print("📊 Test Setup Verification Results:") print(f"✅ Successful: {success_count}/{total_tests}") print(f"❌ Failed: {total_tests - success_count}/{total_tests}") diff --git a/tests/conftest.py b/tests/conftest.py index b30472d6..1f2fb16c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,10 @@ """Pytest configuration and shared fixtures for TradingAgents tests.""" import os -import pytest import tempfile -from unittest.mock import Mock, MagicMock -from datetime import date, datetime -from typing import Dict, Any +from unittest.mock import Mock + +import pytest from tradingagents.default_config import DEFAULT_CONFIG @@ -22,7 +21,7 @@ def sample_config(): "deep_think_llm": "gpt-4o-mini", "quick_think_llm": "gpt-4o-mini", "project_dir": "/tmp/test_tradingagents", - } + }, ) return config @@ -174,7 +173,7 @@ def mock_memory(): def pytest_configure(config): """Configure pytest with custom markers.""" config.addinivalue_line( - "markers", "integration: mark test as integration test (slow)" + "markers", "integration: mark test as integration test (slow)", ) config.addinivalue_line("markers", "unit: mark test as unit test (fast)") config.addinivalue_line("markers", "api: mark test as requiring API access") diff --git a/tests/fixtures/sample_data.py b/tests/fixtures/sample_data.py index 58705966..94a15135 100644 --- a/tests/fixtures/sample_data.py +++ b/tests/fixtures/sample_data.py @@ -2,14 +2,14 @@ import json from datetime import datetime, timedelta -from typing import Dict, List, Any +from typing import Any class SampleDataFactory: """Factory class for creating sample test data.""" @staticmethod - def create_market_data(ticker: str = "AAPL", days: int = 30) -> Dict[str, Any]: + def create_market_data(ticker: str = "AAPL", days: int = 30) -> dict[str, Any]: """Create sample market data for testing.""" base_date = datetime(2024, 5, 1) data = {} @@ -36,8 +36,8 @@ class SampleDataFactory: @staticmethod def create_finnhub_news_data( - ticker: str = "AAPL", count: int = 10 - ) -> Dict[str, List[Dict[str, Any]]]: + ticker: str = "AAPL", count: int = 10, + ) -> dict[str, list[dict[str, Any]]]: """Create sample FinnHub news data for testing.""" base_date = datetime(2024, 5, 10) data = {} @@ -93,7 +93,7 @@ class SampleDataFactory: @staticmethod def create_insider_transactions_data( ticker: str = "AAPL", - ) -> Dict[str, List[Dict[str, Any]]]: + ) -> dict[str, list[dict[str, Any]]]: """Create sample insider transactions data for testing.""" base_date = datetime(2024, 5, 5) data = {} @@ -129,15 +129,15 @@ class SampleDataFactory: "transactionValue": transaction["shares"] * transaction["price"], "reportingName": transaction["person"], "typeOfOwner": "officer", - } + }, ] return data @staticmethod def create_financial_statements_data( - ticker: str = "AAPL", period: str = "annual" - ) -> Dict[str, List[Dict[str, Any]]]: + ticker: str = "AAPL", period: str = "annual", + ) -> dict[str, list[dict[str, Any]]]: """Create sample financial statements data for testing.""" if period == "annual": dates = ["2023-12-31", "2022-12-31", "2021-12-31"] @@ -174,7 +174,7 @@ class SampleDataFactory: @staticmethod def create_social_sentiment_data( ticker: str = "AAPL", - ) -> Dict[str, List[Dict[str, Any]]]: + ) -> dict[str, list[dict[str, Any]]]: """Create sample social media sentiment data for testing.""" base_date = datetime(2024, 5, 8) data = {} @@ -226,7 +226,7 @@ class SampleDataFactory: "subreddit": "stocks" if j % 2 else "investing", "upvotes": 10 + (j * 5), "comments": 3 + j, - } + }, ) data[date_str] = daily_posts @@ -234,7 +234,7 @@ class SampleDataFactory: return data @staticmethod - def create_technical_indicators_data(ticker: str = "AAPL") -> Dict[str, Any]: + def create_technical_indicators_data(ticker: str = "AAPL") -> dict[str, Any]: """Create sample technical indicators data for testing.""" return { "symbol": ticker, @@ -262,23 +262,23 @@ class SampleDataFactory: } @staticmethod - def create_complete_test_dataset(ticker: str = "AAPL") -> Dict[str, Dict[str, Any]]: + def create_complete_test_dataset(ticker: str = "AAPL") -> dict[str, dict[str, Any]]: """Create a complete dataset for comprehensive testing.""" return { "market_data": SampleDataFactory.create_market_data(ticker), "news_data": SampleDataFactory.create_finnhub_news_data(ticker), "insider_transactions": SampleDataFactory.create_insider_transactions_data( - ticker + ticker, ), "financial_annual": SampleDataFactory.create_financial_statements_data( - ticker, "annual" + ticker, "annual", ), "financial_quarterly": SampleDataFactory.create_financial_statements_data( - ticker, "quarterly" + ticker, "quarterly", ), "social_sentiment": SampleDataFactory.create_social_sentiment_data(ticker), "technical_indicators": SampleDataFactory.create_technical_indicators_data( - ticker + ticker, ), } @@ -343,7 +343,7 @@ def save_sample_data_to_files(base_path: str, ticker: str = "AAPL") -> None: # Save quarterly data separately quarterly_path = os.path.join( - finnhub_path, "fin_as_reported", f"{ticker}_quarterly_data_formatted.json" + finnhub_path, "fin_as_reported", f"{ticker}_quarterly_data_formatted.json", ) with open(quarterly_path, "w") as f: json.dump(dataset["financial_quarterly"], f, indent=2) diff --git a/tests/integration/test_full_workflow.py b/tests/integration/test_full_workflow.py index 94e8efb2..23fbe3ce 100644 --- a/tests/integration/test_full_workflow.py +++ b/tests/integration/test_full_workflow.py @@ -1,13 +1,11 @@ """Integration tests for the full TradingAgents workflow.""" -import pytest -import os -import tempfile -from unittest.mock import Mock, patch, MagicMock -from datetime import datetime +from unittest.mock import Mock, patch + +import pytest -from tradingagents.graph.trading_graph import TradingAgentsGraph from tradingagents.default_config import DEFAULT_CONFIG +from tradingagents.graph.trading_graph import TradingAgentsGraph @pytest.mark.integration @@ -26,14 +24,14 @@ class TestFullWorkflowIntegration: "deep_think_llm": "gpt-4o-mini", "quick_think_llm": "gpt-4o-mini", "project_dir": temp_data_dir, - } + }, ) return config @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_end_to_end_trading_workflow( - self, mock_toolkit, mock_chat_openai, integration_config + self, mock_toolkit, mock_chat_openai, integration_config, ): """Test complete end-to-end trading workflow.""" # Setup mocks @@ -69,15 +67,14 @@ class TestFullWorkflowIntegration: "company_of_interest": "AAPL", "trade_date": "2024-05-10", "messages": [], - } + }, ) trading_graph.propagator.get_graph_args = Mock(return_value={}) trading_graph.signal_processor.process_signal = Mock(return_value="BUY") # Execute the full workflow - with patch("builtins.open", create=True): - with patch("json.dump"): - final_state, decision = trading_graph.propagate("AAPL", "2024-05-10") + with patch("builtins.open", create=True), patch("json.dump"): + final_state, decision = trading_graph.propagate("AAPL", "2024-05-10") # Verify the workflow completed successfully assert final_state is not None @@ -89,7 +86,7 @@ class TestFullWorkflowIntegration: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_multiple_analysts_integration( - self, mock_toolkit, mock_chat_openai, integration_config + self, mock_toolkit, mock_chat_openai, integration_config, ): """Test integration with different analyst combinations.""" analyst_combinations = [ @@ -117,7 +114,7 @@ class TestFullWorkflowIntegration: with patch("tradingagents.graph.trading_graph.set_config"): # Test each analyst combination trading_graph = TradingAgentsGraph( - selected_analysts=analysts, config=integration_config + selected_analysts=analysts, config=integration_config, ) trading_graph.graph = mock_graph @@ -127,19 +124,18 @@ class TestFullWorkflowIntegration: "company_of_interest": "TSLA", "trade_date": "2024-05-15", "messages": [], - } + }, ) trading_graph.propagator.get_graph_args = Mock(return_value={}) trading_graph.signal_processor.process_signal = Mock( - return_value="HOLD" + return_value="HOLD", ) # Execute - with patch("builtins.open", create=True): - with patch("json.dump"): - final_state, decision = trading_graph.propagate( - "TSLA", "2024-05-15" - ) + with patch("builtins.open", create=True), patch("json.dump"): + final_state, decision = trading_graph.propagate( + "TSLA", "2024-05-15", + ) # Verify assert final_state is not None @@ -148,7 +144,7 @@ class TestFullWorkflowIntegration: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_memory_and_reflection_integration( - self, mock_toolkit, mock_chat_openai, integration_config + self, mock_toolkit, mock_chat_openai, integration_config, ): """Test integration of memory and reflection components.""" # Setup @@ -165,7 +161,7 @@ class TestFullWorkflowIntegration: mock_graph.invoke.return_value = mock_final_state with patch( - "tradingagents.graph.trading_graph.FinancialSituationMemory" + "tradingagents.graph.trading_graph.FinancialSituationMemory", ) as mock_memory: mock_memory_instance = Mock() mock_memory.return_value = mock_memory_instance @@ -180,11 +176,11 @@ class TestFullWorkflowIntegration: "company_of_interest": "NVDA", "trade_date": "2024-05-20", "messages": [], - } + }, ) trading_graph.propagator.get_graph_args = Mock(return_value={}) trading_graph.signal_processor.process_signal = Mock( - return_value="SELL" + return_value="SELL", ) # Mock reflection methods @@ -195,9 +191,8 @@ class TestFullWorkflowIntegration: trading_graph.reflector.reflect_risk_manager = Mock() # Execute workflow - with patch("builtins.open", create=True): - with patch("json.dump"): - final_state, decision = trading_graph.propagate("NVDA", "2024-05-20") + with patch("builtins.open", create=True), patch("json.dump"): + final_state, decision = trading_graph.propagate("NVDA", "2024-05-20") # Test reflection and memory update returns_losses = {"return": -0.03, "loss": -0.08} @@ -213,7 +208,7 @@ class TestFullWorkflowIntegration: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_debug_mode_integration( - self, mock_toolkit, mock_chat_openai, integration_config + self, mock_toolkit, mock_chat_openai, integration_config, ): """Test integration in debug mode.""" # Setup @@ -233,7 +228,7 @@ class TestFullWorkflowIntegration: self._create_mock_final_state(), # Final chunk ] for chunk in mock_chunks: - if "messages" in chunk and chunk["messages"]: + if chunk.get("messages"): for msg in chunk["messages"]: if hasattr(msg, "pretty_print"): msg.pretty_print = Mock() @@ -245,7 +240,7 @@ class TestFullWorkflowIntegration: with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.set_config"): trading_graph = TradingAgentsGraph( - debug=True, config=integration_config + debug=True, config=integration_config, ) trading_graph.graph = mock_graph @@ -255,15 +250,14 @@ class TestFullWorkflowIntegration: "company_of_interest": "AMZN", "trade_date": "2024-05-25", "messages": [], - } + }, ) trading_graph.propagator.get_graph_args = Mock(return_value={}) trading_graph.signal_processor.process_signal = Mock(return_value="BUY") # Execute in debug mode - with patch("builtins.open", create=True): - with patch("json.dump"): - final_state, decision = trading_graph.propagate("AMZN", "2024-05-25") + with patch("builtins.open", create=True), patch("json.dump"): + final_state, decision = trading_graph.propagate("AMZN", "2024-05-25") # Verify debug mode was used mock_graph.stream.assert_called_once() @@ -271,7 +265,7 @@ class TestFullWorkflowIntegration: assert decision == "BUY" @pytest.mark.parametrize( - "ticker,date", + ("ticker", "date"), [ ("AAPL", "2024-01-15"), ("TSLA", "2024-02-20"), @@ -282,7 +276,7 @@ class TestFullWorkflowIntegration: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_multiple_stocks_integration( - self, mock_toolkit, mock_chat_openai, ticker, date, integration_config + self, mock_toolkit, mock_chat_openai, ticker, date, integration_config, ): """Test integration with different stocks and dates.""" # Setup @@ -309,17 +303,16 @@ class TestFullWorkflowIntegration: "company_of_interest": ticker, "trade_date": date, "messages": [], - } + }, ) trading_graph.propagator.get_graph_args = Mock(return_value={}) trading_graph.signal_processor.process_signal = Mock( - return_value="HOLD" + return_value="HOLD", ) # Execute - with patch("builtins.open", create=True): - with patch("json.dump"): - final_state, decision = trading_graph.propagate(ticker, date) + with patch("builtins.open", create=True), patch("json.dump"): + final_state, decision = trading_graph.propagate(ticker, date) # Verify assert final_state["company_of_interest"] == ticker @@ -389,7 +382,7 @@ class TestPerformanceIntegration: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_multiple_consecutive_runs( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir + self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, ): """Test multiple consecutive trading decisions.""" sample_config["project_dir"] = temp_data_dir @@ -455,9 +448,8 @@ class TestPerformanceIntegration: mock_final_state["final_trade_decision"] ) - with patch("builtins.open", create=True): - with patch("json.dump"): - final_state, decision = trading_graph.propagate(ticker, date) + with patch("builtins.open", create=True), patch("json.dump"): + final_state, decision = trading_graph.propagate(ticker, date) decisions.append(decision) diff --git a/tests/unit/agents/test_market_analyst.py b/tests/unit/agents/test_market_analyst.py index f1a896c2..b44cf266 100644 --- a/tests/unit/agents/test_market_analyst.py +++ b/tests/unit/agents/test_market_analyst.py @@ -1,8 +1,9 @@ """Unit tests for market analyst agent.""" +from unittest.mock import Mock + import pytest -from unittest.mock import Mock, patch, MagicMock -from langchain_core.messages import HumanMessage, AIMessage +from langchain_core.messages import HumanMessage from tradingagents.agents.analysts.market_analyst import create_market_analyst @@ -16,7 +17,7 @@ class TestMarketAnalyst: assert callable(analyst_node) def test_market_analyst_node_basic_execution( - self, mock_llm, mock_toolkit, sample_agent_state + self, mock_llm, mock_toolkit, sample_agent_state, ): """Test basic execution of market analyst node.""" # Setup @@ -38,7 +39,7 @@ class TestMarketAnalyst: assert result["market_report"] == "Market analysis complete" def test_market_analyst_uses_online_tools_when_configured( - self, mock_llm, mock_toolkit, sample_agent_state + self, mock_llm, mock_toolkit, sample_agent_state, ): """Test that analyst uses online tools when configured.""" # Setup @@ -54,7 +55,7 @@ class TestMarketAnalyst: analyst_node = create_market_analyst(mock_llm, mock_toolkit) # Execute - result = analyst_node(sample_agent_state) + analyst_node(sample_agent_state) # Verify tools were bound correctly mock_llm.bind_tools.assert_called_once() @@ -63,7 +64,7 @@ class TestMarketAnalyst: assert "get_YFin_data_online" in str(tool_names) or len(bound_tools) == 2 def test_market_analyst_uses_offline_tools_when_configured( - self, mock_llm, mock_toolkit, sample_agent_state + self, mock_llm, mock_toolkit, sample_agent_state, ): """Test that analyst uses offline tools when configured.""" # Setup @@ -79,7 +80,7 @@ class TestMarketAnalyst: analyst_node = create_market_analyst(mock_llm, mock_toolkit) # Execute - result = analyst_node(sample_agent_state) + analyst_node(sample_agent_state) # Verify tools were bound correctly mock_llm.bind_tools.assert_called_once() @@ -87,7 +88,7 @@ class TestMarketAnalyst: assert len(bound_tools) == 2 # Should have 2 offline tools def test_market_analyst_processes_state_variables( - self, mock_llm, mock_toolkit, sample_agent_state + self, mock_llm, mock_toolkit, sample_agent_state, ): """Test that market analyst correctly processes state variables.""" # Setup @@ -111,7 +112,7 @@ class TestMarketAnalyst: assert result["market_report"] == "Analysis for AAPL on 2024-05-10" def test_market_analyst_handles_empty_tool_calls( - self, mock_llm, mock_toolkit, sample_agent_state + self, mock_llm, mock_toolkit, sample_agent_state, ): """Test handling when no tool calls are made.""" # Setup @@ -131,7 +132,7 @@ class TestMarketAnalyst: assert result["messages"] == [mock_result] def test_market_analyst_with_tool_calls( - self, mock_llm, mock_toolkit, sample_agent_state + self, mock_llm, mock_toolkit, sample_agent_state, ): """Test handling when tool calls are present.""" # Setup @@ -152,7 +153,7 @@ class TestMarketAnalyst: @pytest.mark.parametrize("online_tools", [True, False]) def test_market_analyst_tool_configuration( - self, mock_llm, mock_toolkit, sample_agent_state, online_tools + self, mock_llm, mock_toolkit, sample_agent_state, online_tools, ): """Test tool configuration for both online and offline modes.""" # Setup @@ -194,15 +195,15 @@ class TestMarketAnalystIntegration: mock_result = Mock() mock_result.content = """ # Market Analysis for TSLA (2024-05-15) - + ## Technical Analysis - RSI: 65 (slightly overbought) - MACD: Bullish crossover - 50-day SMA: Trending upward - + ## Volume Analysis - Above average volume suggests strong interest - + | Indicator | Value | Signal | |-----------|-------|--------| | RSI | 65 | Neutral | diff --git a/tests/unit/dataflows/test_finnhub_utils.py b/tests/unit/dataflows/test_finnhub_utils.py index e7686a46..48953a40 100644 --- a/tests/unit/dataflows/test_finnhub_utils.py +++ b/tests/unit/dataflows/test_finnhub_utils.py @@ -1,10 +1,9 @@ """Unit tests for FinnHub utilities.""" -import pytest import json import os -import tempfile -from unittest.mock import patch, mock_open, Mock + +import pytest from tradingagents.dataflows.finnhub_utils import get_data_in_range @@ -191,7 +190,7 @@ class TestFinnhubUtils: # Test without period expected_path_no_period = os.path.join( - temp_data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json" + temp_data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json", ) # Test with period @@ -238,7 +237,7 @@ class TestFinnhubUtils: assert len(result2) == 1 @pytest.mark.parametrize( - "data_type,period", + ("data_type", "period"), [ ("news_data", None), ("insider_trans", None), @@ -249,7 +248,7 @@ class TestFinnhubUtils: ], ) def test_get_data_in_range_various_data_types( - self, temp_data_dir, data_type, period + self, temp_data_dir, data_type, period, ): """Test get_data_in_range with various data types.""" ticker = "TEST" diff --git a/tests/unit/graph/test_trading_graph.py b/tests/unit/graph/test_trading_graph.py index f3ccfdc7..e5fbb923 100644 --- a/tests/unit/graph/test_trading_graph.py +++ b/tests/unit/graph/test_trading_graph.py @@ -1,12 +1,10 @@ """Unit tests for TradingAgentsGraph.""" +from unittest.mock import Mock, mock_open, patch + import pytest -import os -from unittest.mock import Mock, patch, MagicMock -from pathlib import Path from tradingagents.graph.trading_graph import TradingAgentsGraph -from tradingagents.default_config import DEFAULT_CONFIG class TestTradingAgentsGraph: @@ -47,7 +45,7 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_init_with_debug( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir + self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, ): """Test initialization with debug mode enabled.""" sample_config["project_dir"] = temp_data_dir @@ -65,7 +63,7 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatAnthropic") @patch("tradingagents.graph.trading_graph.Toolkit") def test_init_with_anthropic( - self, mock_toolkit, mock_chat_anthropic, sample_config, temp_data_dir + self, mock_toolkit, mock_chat_anthropic, sample_config, temp_data_dir, ): """Test initialization with Anthropic LLM provider.""" sample_config["project_dir"] = temp_data_dir @@ -77,14 +75,14 @@ class TestTradingAgentsGraph: with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.set_config"): - graph = TradingAgentsGraph(config=sample_config) + TradingAgentsGraph(config=sample_config) assert mock_chat_anthropic.call_count == 2 @patch("tradingagents.graph.trading_graph.ChatGoogleGenerativeAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_init_with_google( - self, mock_toolkit, mock_chat_google, sample_config, temp_data_dir + self, mock_toolkit, mock_chat_google, sample_config, temp_data_dir, ): """Test initialization with Google LLM provider.""" sample_config["project_dir"] = temp_data_dir @@ -96,13 +94,13 @@ class TestTradingAgentsGraph: with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.set_config"): - graph = TradingAgentsGraph(config=sample_config) + TradingAgentsGraph(config=sample_config) assert mock_chat_google.call_count == 2 @patch("tradingagents.graph.trading_graph.Toolkit") def test_init_unsupported_llm_provider( - self, mock_toolkit, sample_config, temp_data_dir + self, mock_toolkit, sample_config, temp_data_dir, ): """Test initialization with unsupported LLM provider raises error.""" sample_config["project_dir"] = temp_data_dir @@ -117,7 +115,7 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_create_tool_nodes( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir + self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, ): """Test creation of tool nodes.""" sample_config["project_dir"] = temp_data_dir @@ -145,7 +143,7 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_propagate_basic( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir + self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, ): """Test basic propagate functionality.""" sample_config["project_dir"] = temp_data_dir @@ -190,15 +188,14 @@ class TestTradingAgentsGraph: # Mock the propagator and signal processor graph.propagator.create_initial_state = Mock( - return_value={"test": "state"} + return_value={"test": "state"}, ) graph.propagator.get_graph_args = Mock(return_value={}) graph.signal_processor.process_signal = Mock(return_value="HOLD") # Execute - with patch("builtins.open", create=True): - with patch("json.dump"): - final_state, decision = graph.propagate("AAPL", "2024-05-10") + with patch("builtins.open", create=True), patch("json.dump"): + final_state, decision = graph.propagate("AAPL", "2024-05-10") # Verify assert final_state == mock_final_state @@ -209,7 +206,7 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_propagate_debug_mode( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir + self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, ): """Test propagate in debug mode.""" sample_config["project_dir"] = temp_data_dir @@ -231,15 +228,14 @@ class TestTradingAgentsGraph: # Mock other components graph.propagator.create_initial_state = Mock( - return_value={"test": "state"} + return_value={"test": "state"}, ) graph.propagator.get_graph_args = Mock(return_value={}) graph.signal_processor.process_signal = Mock(return_value="BUY") # Execute - with patch("builtins.open", create=True): - with patch("json.dump"): - final_state, decision = graph.propagate("TSLA", "2024-05-15") + with patch("builtins.open", create=True), patch("json.dump"): + final_state, decision = graph.propagate("TSLA", "2024-05-15") # Verify debug mode was used mock_graph.stream.assert_called_once() @@ -249,7 +245,7 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_log_state( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir + self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, ): """Test state logging functionality.""" sample_config["project_dir"] = temp_data_dir @@ -291,10 +287,9 @@ class TestTradingAgentsGraph: } # Mock file operations - with patch("pathlib.Path.mkdir"): - with patch("builtins.open", mock_open()) as mock_file: - with patch("json.dump") as mock_json_dump: - graph._log_state("2024-05-20", final_state) + with patch("pathlib.Path.mkdir"), patch("builtins.open", mock_open()): + with patch("json.dump"): + graph._log_state("2024-05-20", final_state) # Verify logging occurred assert "2024-05-20" in graph.log_states_dict @@ -305,7 +300,7 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_reflect_and_remember( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir + self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, ): """Test reflection and memory update functionality.""" sample_config["project_dir"] = temp_data_dir @@ -315,20 +310,19 @@ class TestTradingAgentsGraph: mock_toolkit.return_value = mock_toolkit_instance with patch( - "tradingagents.graph.trading_graph.FinancialSituationMemory" - ) as mock_memory: - with patch("tradingagents.graph.trading_graph.set_config"): - graph = TradingAgentsGraph(config=sample_config) + "tradingagents.graph.trading_graph.FinancialSituationMemory", + ), patch("tradingagents.graph.trading_graph.set_config"): + graph = TradingAgentsGraph(config=sample_config) - # Set up current state - graph.curr_state = {"test": "state"} + # Set up current state + graph.curr_state = {"test": "state"} - # Mock reflector methods - graph.reflector.reflect_bull_researcher = Mock() - graph.reflector.reflect_bear_researcher = Mock() - graph.reflector.reflect_trader = Mock() - graph.reflector.reflect_invest_judge = Mock() - graph.reflector.reflect_risk_manager = Mock() + # Mock reflector methods + graph.reflector.reflect_bull_researcher = Mock() + graph.reflector.reflect_bear_researcher = Mock() + graph.reflector.reflect_trader = Mock() + graph.reflector.reflect_invest_judge = Mock() + graph.reflector.reflect_risk_manager = Mock() returns_losses = {"return": 0.05, "loss": -0.02} @@ -345,7 +339,7 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_process_signal( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir + self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, ): """Test signal processing functionality.""" sample_config["project_dir"] = temp_data_dir @@ -393,8 +387,8 @@ class TestTradingAgentsGraph: with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.set_config"): - graph = TradingAgentsGraph( - selected_analysts=selected_analysts, config=sample_config + TradingAgentsGraph( + selected_analysts=selected_analysts, config=sample_config, ) # Verify graph was set up with selected analysts @@ -415,14 +409,14 @@ class TestTradingAgentsGraphErrorHandling: # This should still work as the class should use defaults for missing keys with patch("tradingagents.graph.trading_graph.set_config"): with pytest.raises( - KeyError + KeyError, ): # Should fail when trying to access missing config keys TradingAgentsGraph(config=invalid_config) @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_directory_creation_failure( - self, mock_toolkit, mock_chat_openai, sample_config + self, mock_toolkit, mock_chat_openai, sample_config, ): """Test handling when directory creation fails.""" sample_config["project_dir"] = "/invalid/path/that/cannot/be/created" diff --git a/tradingagents/agents/__init__.py b/tradingagents/agents/__init__.py index 6f507651..ef0067e1 100644 --- a/tradingagents/agents/__init__.py +++ b/tradingagents/agents/__init__.py @@ -1,40 +1,35 @@ -from .utils.agent_utils import Toolkit, create_msg_delete -from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState -from .utils.memory import FinancialSituationMemory - from .analysts.fundamentals_analyst import create_fundamentals_analyst from .analysts.market_analyst import create_market_analyst from .analysts.news_analyst import create_news_analyst from .analysts.social_media_analyst import create_social_media_analyst - +from .managers.research_manager import create_research_manager +from .managers.risk_manager import create_risk_manager from .researchers.bear_researcher import create_bear_researcher from .researchers.bull_researcher import create_bull_researcher - from .risk_mgmt.aggresive_debator import create_risky_debator from .risk_mgmt.conservative_debator import create_safe_debator from .risk_mgmt.neutral_debator import create_neutral_debator - -from .managers.research_manager import create_research_manager -from .managers.risk_manager import create_risk_manager - from .trader.trader import create_trader +from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState +from .utils.agent_utils import Toolkit, create_msg_delete +from .utils.memory import FinancialSituationMemory __all__ = [ - "FinancialSituationMemory", - "Toolkit", "AgentState", - "create_msg_delete", + "FinancialSituationMemory", "InvestDebateState", "RiskDebateState", + "Toolkit", "create_bear_researcher", "create_bull_researcher", - "create_research_manager", "create_fundamentals_analyst", "create_market_analyst", + "create_msg_delete", "create_neutral_debator", "create_news_analyst", - "create_risky_debator", + "create_research_manager", "create_risk_manager", + "create_risky_debator", "create_safe_debator", "create_social_media_analyst", "create_trader", diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index 7399e9b4..b4fa6b45 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -5,7 +5,7 @@ def create_fundamentals_analyst(llm, toolkit): def fundamentals_analyst_node(state): current_date = state["trade_date"] ticker = state["company_of_interest"] - company_name = state["company_of_interest"] + state["company_of_interest"] if toolkit.config["online_tools"]: tools = [toolkit.get_fundamentals_openai] @@ -20,7 +20,7 @@ def create_fundamentals_analyst(llm, toolkit): system_message = ( "You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." - + " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.", + " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.", ) prompt = ChatPromptTemplate.from_messages( @@ -37,7 +37,7 @@ def create_fundamentals_analyst(llm, toolkit): "For your reference, the current date is {current_date}. The company we want to look at is {ticker}", ), MessagesPlaceholder(variable_name="messages"), - ] + ], ) prompt = prompt.partial(system_message=system_message) diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index ebaa3775..a741c66b 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -6,7 +6,7 @@ def create_market_analyst(llm, toolkit): def market_analyst_node(state): current_date = state["trade_date"] ticker = state["company_of_interest"] - company_name = state["company_of_interest"] + state["company_of_interest"] if toolkit.config["online_tools"]: tools = [ @@ -45,7 +45,7 @@ Volume-Based Indicators: - vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses. - Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions.""" - + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" ) prompt = ChatPromptTemplate.from_messages( @@ -62,7 +62,7 @@ Volume-Based Indicators: "For your reference, the current date is {current_date}. The company we want to look at is {ticker}", ), MessagesPlaceholder(variable_name="messages"), - ] + ], ) prompt = prompt.partial(system_message=system_message) diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index 76d0a114..598f890c 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -17,7 +17,7 @@ def create_news_analyst(llm, toolkit): system_message = ( "You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." - + """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""" + """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""" ) prompt = ChatPromptTemplate.from_messages( @@ -34,7 +34,7 @@ def create_news_analyst(llm, toolkit): "For your reference, the current date is {current_date}. We are looking at the company {ticker}", ), MessagesPlaceholder(variable_name="messages"), - ] + ], ) prompt = prompt.partial(system_message=system_message) diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index 52051547..b5e3aa94 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -5,7 +5,7 @@ def create_social_media_analyst(llm, toolkit): def social_media_analyst_node(state): current_date = state["trade_date"] ticker = state["company_of_interest"] - company_name = state["company_of_interest"] + state["company_of_interest"] if toolkit.config["online_tools"]: tools = [toolkit.get_stock_news_openai] @@ -16,7 +16,7 @@ def create_social_media_analyst(llm, toolkit): system_message = ( "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." - + """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""", + """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""", ) prompt = ChatPromptTemplate.from_messages( @@ -33,7 +33,7 @@ def create_social_media_analyst(llm, toolkit): "For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}", ), MessagesPlaceholder(variable_name="messages"), - ] + ], ) prompt = prompt.partial(system_message=system_message) diff --git a/tradingagents/agents/managers/research_manager.py b/tradingagents/agents/managers/research_manager.py index 25a6ef05..f9039ccf 100644 --- a/tradingagents/agents/managers/research_manager.py +++ b/tradingagents/agents/managers/research_manager.py @@ -12,7 +12,7 @@ def create_research_manager(llm, memory): past_memories = memory.get_memories(curr_situation, n_matches=2) past_memory_str = "" - for i, rec in enumerate(past_memories, 1): + for _i, rec in enumerate(past_memories, 1): past_memory_str += rec["recommendation"] + "\n\n" prompt = f"""As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented. @@ -24,7 +24,7 @@ Additionally, develop a detailed investment plan for the trader. This should inc Your Recommendation: A decisive stance supported by the most convincing arguments. Rationale: An explanation of why these arguments lead to your conclusion. Strategic Actions: Concrete steps for implementing the recommendation. -Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting. +Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting. Here are your past reflections on mistakes: \"{past_memory_str}\" diff --git a/tradingagents/agents/managers/risk_manager.py b/tradingagents/agents/managers/risk_manager.py index 3c4b0227..7d563a6f 100644 --- a/tradingagents/agents/managers/risk_manager.py +++ b/tradingagents/agents/managers/risk_manager.py @@ -1,7 +1,7 @@ def create_risk_manager(llm, memory): def risk_manager_node(state) -> dict: - company_name = state["company_of_interest"] + state["company_of_interest"] history = state["risk_debate_state"]["history"] risk_debate_state = state["risk_debate_state"] @@ -15,7 +15,7 @@ def create_risk_manager(llm, memory): past_memories = memory.get_memories(curr_situation, n_matches=2) past_memory_str = "" - for i, rec in enumerate(past_memories, 1): + for _i, rec in enumerate(past_memories, 1): past_memory_str += rec["recommendation"] + "\n\n" prompt = f"""As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness. @@ -32,7 +32,7 @@ Deliverables: --- -**Analysts Debate History:** +**Analysts Debate History:** {history} --- diff --git a/tradingagents/agents/researchers/bear_researcher.py b/tradingagents/agents/researchers/bear_researcher.py index b7e21b05..8fade5d7 100644 --- a/tradingagents/agents/researchers/bear_researcher.py +++ b/tradingagents/agents/researchers/bear_researcher.py @@ -14,7 +14,7 @@ def create_bear_researcher(llm, memory): past_memories = memory.get_memories(curr_situation, n_matches=2) past_memory_str = "" - for i, rec in enumerate(past_memories, 1): + for _i, rec in enumerate(past_memories, 1): past_memory_str += rec["recommendation"] + "\n\n" prompt = f"""You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively. diff --git a/tradingagents/agents/researchers/bull_researcher.py b/tradingagents/agents/researchers/bull_researcher.py index 32124fe6..968fea7b 100644 --- a/tradingagents/agents/researchers/bull_researcher.py +++ b/tradingagents/agents/researchers/bull_researcher.py @@ -14,7 +14,7 @@ def create_bull_researcher(llm, memory): past_memories = memory.get_memories(curr_situation, n_matches=2) past_memory_str = "" - for i, rec in enumerate(past_memories, 1): + for _i, rec in enumerate(past_memories, 1): past_memory_str += rec["recommendation"] + "\n\n" prompt = f"""You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively. diff --git a/tradingagents/agents/risk_mgmt/aggresive_debator.py b/tradingagents/agents/risk_mgmt/aggresive_debator.py index 7abe3895..86c8c6b3 100644 --- a/tradingagents/agents/risk_mgmt/aggresive_debator.py +++ b/tradingagents/agents/risk_mgmt/aggresive_debator.py @@ -41,7 +41,7 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes "current_risky_response": argument, "current_safe_response": risk_debate_state.get("current_safe_response", ""), "current_neutral_response": risk_debate_state.get( - "current_neutral_response", "" + "current_neutral_response", "", ), "count": risk_debate_state["count"] + 1, } diff --git a/tradingagents/agents/risk_mgmt/conservative_debator.py b/tradingagents/agents/risk_mgmt/conservative_debator.py index a384e2c3..c574d2d3 100644 --- a/tradingagents/agents/risk_mgmt/conservative_debator.py +++ b/tradingagents/agents/risk_mgmt/conservative_debator.py @@ -39,11 +39,11 @@ Engage by questioning their optimism and emphasizing the potential downsides the "neutral_history": risk_debate_state.get("neutral_history", ""), "latest_speaker": "Safe", "current_risky_response": risk_debate_state.get( - "current_risky_response", "" + "current_risky_response", "", ), "current_safe_response": argument, "current_neutral_response": risk_debate_state.get( - "current_neutral_response", "" + "current_neutral_response", "", ), "count": risk_debate_state["count"] + 1, } diff --git a/tradingagents/agents/risk_mgmt/neutral_debator.py b/tradingagents/agents/risk_mgmt/neutral_debator.py index ad4f3438..f965a4e1 100644 --- a/tradingagents/agents/risk_mgmt/neutral_debator.py +++ b/tradingagents/agents/risk_mgmt/neutral_debator.py @@ -39,7 +39,7 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the "neutral_history": neutral_history + "\n" + argument, "latest_speaker": "Neutral", "current_risky_response": risk_debate_state.get( - "current_risky_response", "" + "current_risky_response", "", ), "current_safe_response": risk_debate_state.get("current_safe_response", ""), "current_neutral_response": argument, diff --git a/tradingagents/agents/trader/trader.py b/tradingagents/agents/trader/trader.py index 2f50fdd8..7234b5f7 100644 --- a/tradingagents/agents/trader/trader.py +++ b/tradingagents/agents/trader/trader.py @@ -15,7 +15,7 @@ def create_trader(llm, memory): past_memory_str = "" if past_memories: - for i, rec in enumerate(past_memories, 1): + for _i, rec in enumerate(past_memories, 1): past_memory_str += rec["recommendation"] + "\n\n" else: past_memory_str = "No past memories found." diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index c7209901..d94ea284 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -1,16 +1,18 @@ from typing import Annotated -from typing_extensions import TypedDict -from tradingagents.agents import * + from langgraph.graph import MessagesState +from typing_extensions import TypedDict + +# Import specific agent classes as needed # Researcher team state class InvestDebateState(TypedDict): bull_history: Annotated[ - str, "Bullish Conversation history" + str, "Bullish Conversation history", ] # Bullish Conversation history bear_history: Annotated[ - str, "Bearish Conversation history" + str, "Bearish Conversation history", ] # Bullish Conversation history history: Annotated[str, "Conversation history"] # Conversation history current_response: Annotated[str, "Latest response"] # Last response @@ -21,24 +23,24 @@ class InvestDebateState(TypedDict): # Risk management team state class RiskDebateState(TypedDict): risky_history: Annotated[ - str, "Risky Agent's Conversation history" + str, "Risky Agent's Conversation history", ] # Conversation history safe_history: Annotated[ - str, "Safe Agent's Conversation history" + str, "Safe Agent's Conversation history", ] # Conversation history neutral_history: Annotated[ - str, "Neutral Agent's Conversation history" + str, "Neutral Agent's Conversation history", ] # Conversation history history: Annotated[str, "Conversation history"] # Conversation history latest_speaker: Annotated[str, "Analyst that spoke last"] current_risky_response: Annotated[ - str, "Latest response by the risky analyst" + str, "Latest response by the risky analyst", ] # Last response current_safe_response: Annotated[ - str, "Latest response by the safe analyst" + str, "Latest response by the safe analyst", ] # Last response current_neutral_response: Annotated[ - str, "Latest response by the neutral analyst" + str, "Latest response by the neutral analyst", ] # Last response judge_decision: Annotated[str, "Judge's decision"] count: Annotated[int, "Length of the current conversation"] # Conversation length @@ -54,13 +56,13 @@ class AgentState(MessagesState): market_report: Annotated[str, "Report from the Market Analyst"] sentiment_report: Annotated[str, "Report from the Social Media Analyst"] news_report: Annotated[ - str, "Report from the News Researcher of current world affairs" + str, "Report from the News Researcher of current world affairs", ] fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"] # researcher team discussion step investment_debate_state: Annotated[ - InvestDebateState, "Current state of the debate on if to invest or not" + InvestDebateState, "Current state of the debate on if to invest or not", ] investment_plan: Annotated[str, "Plan generated by the Analyst"] @@ -68,6 +70,6 @@ class AgentState(MessagesState): # risk management team discussion step risk_debate_state: Annotated[ - RiskDebateState, "Current state of the debate on evaluating risk" + RiskDebateState, "Current state of the debate on evaluating risk", ] final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"] diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index 8a371358..7978ce48 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -1,9 +1,10 @@ -from langchain_core.messages import HumanMessage -from typing import Annotated -from langchain_core.messages import RemoveMessage -from langchain_core.tools import tool from datetime import datetime -import tradingagents.dataflows.interface as interface +from typing import Annotated + +from langchain_core.messages import HumanMessage, RemoveMessage +from langchain_core.tools import tool + +from tradingagents.dataflows import interface from tradingagents.default_config import DEFAULT_CONFIG @@ -18,7 +19,7 @@ def create_msg_delete(): # Add a minimal placeholder message placeholder = HumanMessage(content="Continue") - return {"messages": removal_operations + [placeholder]} + return {"messages": [*removal_operations, placeholder]} return delete_messages @@ -53,9 +54,8 @@ class Toolkit: str: A formatted dataframe containing the latest global news from Reddit in the specified time frame. """ - global_news_result = interface.get_reddit_global_news(curr_date, 7, 5) + return interface.get_reddit_global_news(curr_date, 7, 5) - return global_news_result @staticmethod @tool @@ -83,11 +83,10 @@ class Toolkit: start_date = datetime.strptime(start_date, "%Y-%m-%d") look_back_days = (end_date - start_date).days - finnhub_news_result = interface.get_finnhub_news( - ticker, end_date_str, look_back_days + return interface.get_finnhub_news( + ticker, end_date_str, look_back_days, ) - return finnhub_news_result @staticmethod @tool @@ -107,9 +106,8 @@ class Toolkit: str: A formatted dataframe containing the latest news about the company on the given date """ - stock_news_results = interface.get_reddit_company_news(ticker, curr_date, 7, 5) + return interface.get_reddit_company_news(ticker, curr_date, 7, 5) - return stock_news_results @staticmethod @tool @@ -128,9 +126,8 @@ class Toolkit: str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range. """ - result_data = interface.get_YFin_data(symbol, start_date, end_date) + return interface.get_YFin_data(symbol, start_date, end_date) - return result_data @staticmethod @tool @@ -149,19 +146,18 @@ class Toolkit: str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range. """ - result_data = interface.get_YFin_data_online(symbol, start_date, end_date) + return interface.get_YFin_data_online(symbol, start_date, end_date) - return result_data @staticmethod @tool def get_stockstats_indicators_report( symbol: Annotated[str, "ticker symbol of the company"], indicator: Annotated[ - str, "technical indicator to get the analysis and report of" + str, "technical indicator to get the analysis and report of", ], curr_date: Annotated[ - str, "The current trading date you are trading on, YYYY-mm-dd" + str, "The current trading date you are trading on, YYYY-mm-dd", ], look_back_days: Annotated[int, "how many days to look back"] = 30, ) -> str: @@ -176,21 +172,20 @@ class Toolkit: str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator. """ - result_stockstats = interface.get_stock_stats_indicators_window( - symbol, indicator, curr_date, look_back_days, False + return interface.get_stock_stats_indicators_window( + symbol, indicator, curr_date, look_back_days, False, ) - return result_stockstats @staticmethod @tool def get_stockstats_indicators_report_online( symbol: Annotated[str, "ticker symbol of the company"], indicator: Annotated[ - str, "technical indicator to get the analysis and report of" + str, "technical indicator to get the analysis and report of", ], curr_date: Annotated[ - str, "The current trading date you are trading on, YYYY-mm-dd" + str, "The current trading date you are trading on, YYYY-mm-dd", ], look_back_days: Annotated[int, "how many days to look back"] = 30, ) -> str: @@ -205,11 +200,10 @@ class Toolkit: str: A formatted dataframe containing the stock stats indicators for the specified ticker symbol and indicator. """ - result_stockstats = interface.get_stock_stats_indicators_window( - symbol, indicator, curr_date, look_back_days, True + return interface.get_stock_stats_indicators_window( + symbol, indicator, curr_date, look_back_days, True, ) - return result_stockstats @staticmethod @tool @@ -229,11 +223,10 @@ class Toolkit: str: a report of the sentiment in the past 30 days starting at curr_date """ - data_sentiment = interface.get_finnhub_company_insider_sentiment( - ticker, curr_date, 30 + return interface.get_finnhub_company_insider_sentiment( + ticker, curr_date, 30, ) - return data_sentiment @staticmethod @tool @@ -253,11 +246,10 @@ class Toolkit: str: a report of the company's insider transactions/trading information in the past 30 days """ - data_trans = interface.get_finnhub_company_insider_transactions( - ticker, curr_date, 30 + return interface.get_finnhub_company_insider_transactions( + ticker, curr_date, 30, ) - return data_trans @staticmethod @tool @@ -279,9 +271,8 @@ class Toolkit: str: a report of the company's most recent balance sheet """ - data_balance_sheet = interface.get_simfin_balance_sheet(ticker, freq, curr_date) + return interface.get_simfin_balance_sheet(ticker, freq, curr_date) - return data_balance_sheet @staticmethod @tool @@ -303,9 +294,8 @@ class Toolkit: str: a report of the company's most recent cash flow statement """ - data_cashflow = interface.get_simfin_cashflow(ticker, freq, curr_date) + return interface.get_simfin_cashflow(ticker, freq, curr_date) - return data_cashflow @staticmethod @tool @@ -327,11 +317,10 @@ class Toolkit: str: a report of the company's most recent income statement """ - data_income_stmt = interface.get_simfin_income_statements( - ticker, freq, curr_date + return interface.get_simfin_income_statements( + ticker, freq, curr_date, ) - return data_income_stmt @staticmethod @tool @@ -349,9 +338,8 @@ class Toolkit: str: A formatted string containing the latest news from Google News based on the query and date range. """ - google_news_results = interface.get_google_news(query, curr_date, 7) + return interface.get_google_news(query, curr_date, 7) - return google_news_results @staticmethod @tool @@ -368,9 +356,8 @@ class Toolkit: str: A formatted string containing the latest news about the company on the given date. """ - openai_news_results = interface.get_stock_news_openai(ticker, curr_date) + return interface.get_stock_news_openai(ticker, curr_date) - return openai_news_results @staticmethod @tool @@ -385,9 +372,8 @@ class Toolkit: str: A formatted string containing the latest macroeconomic news on the given date. """ - openai_news_results = interface.get_global_news_openai(curr_date) + return interface.get_global_news_openai(curr_date) - return openai_news_results @staticmethod @tool @@ -404,8 +390,7 @@ class Toolkit: str: A formatted string containing the latest fundamental information about the company on the given date. """ - openai_fundamentals_results = interface.get_fundamentals_openai( - ticker, curr_date + return interface.get_fundamentals_openai( + ticker, curr_date, ) - return openai_fundamentals_results diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 1f302d80..108ae735 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -59,7 +59,7 @@ class FinancialSituationMemory: "matched_situation": results["documents"][0][i], "recommendation": results["metadatas"][0][i]["recommendation"], "similarity_score": 1 - results["distances"][0][i], - } + }, ) return matched_results @@ -94,18 +94,15 @@ if __name__ == "__main__": # Example query current_situation = """ - Market showing increased volatility in tech sector, with institutional investors + Market showing increased volatility in tech sector, with institutional investors reducing positions and rising interest rates affecting growth stock valuations """ try: recommendations = matcher.get_memories(current_situation, n_matches=2) - for i, rec in enumerate(recommendations, 1): - print(f"\nMatch {i}:") - print(f"Similarity Score: {rec['similarity_score']:.2f}") - print(f"Matched Situation: {rec['matched_situation']}") - print(f"Recommendation: {rec['recommendation']}") + for _i, _rec in enumerate(recommendations, 1): + pass - except Exception as e: - print(f"Error during recommendation: {str(e)}") + except Exception: + pass diff --git a/tradingagents/config.py b/tradingagents/config.py index 5fb71dd9..77e1367c 100644 --- a/tradingagents/config.py +++ b/tradingagents/config.py @@ -5,6 +5,7 @@ Loads configuration from environment variables and .env file. import os from pathlib import Path + from dotenv import load_dotenv # Load .env file from project root @@ -20,10 +21,10 @@ def get_config(): "project_dir": str(project_root / "tradingagents"), "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), "data_dir": os.getenv( - "TRADINGAGENTS_DATA_DIR", "/Users/yluo/Documents/Code/ScAI/FR1-data" + "TRADINGAGENTS_DATA_DIR", "/Users/yluo/Documents/Code/ScAI/FR1-data", ), "data_cache_dir": str( - project_root / "tradingagents" / "dataflows" / "data_cache" + project_root / "tradingagents" / "dataflows" / "data_cache", ), # LLM settings "llm_provider": os.getenv("LLM_PROVIDER", "openai"), @@ -47,16 +48,17 @@ def get_config(): # Validate required API keys based on provider if config["llm_provider"] == "openai" and not config["openai_api_key"]: - raise ValueError("OPENAI_API_KEY is required when using OpenAI provider") - elif config["llm_provider"] == "anthropic" and not config["anthropic_api_key"]: - raise ValueError("ANTHROPIC_API_KEY is required when using Anthropic provider") - elif config["llm_provider"] == "google" and not config["google_api_key"]: - raise ValueError("GOOGLE_API_KEY is required when using Google provider") + msg = "OPENAI_API_KEY is required when using OpenAI provider" + raise ValueError(msg) + if config["llm_provider"] == "anthropic" and not config["anthropic_api_key"]: + msg = "ANTHROPIC_API_KEY is required when using Anthropic provider" + raise ValueError(msg) + if config["llm_provider"] == "google" and not config["google_api_key"]: + msg = "GOOGLE_API_KEY is required when using Google provider" + raise ValueError(msg) if not config["finnhub_api_key"]: - print( - "Warning: FINNHUB_API_KEY not set. Some financial data features may be limited." - ) + pass return config diff --git a/tradingagents/dataflows/__init__.py b/tradingagents/dataflows/__init__.py index 004d4138..522256a3 100644 --- a/tradingagents/dataflows/__init__.py +++ b/tradingagents/dataflows/__init__.py @@ -1,17 +1,13 @@ from .finnhub_utils import get_data_in_range from .googlenews_utils import getNewsData -from .yfin_utils import YFinanceUtils -from .reddit_utils import fetch_top_from_category -from .stockstats_utils import StockstatsUtils - from .interface import ( - # News and sentiment functions - get_finnhub_news, get_finnhub_company_insider_sentiment, get_finnhub_company_insider_transactions, + # News and sentiment functions + get_finnhub_news, get_google_news, - get_reddit_global_news, get_reddit_company_news, + get_reddit_global_news, # Financial statements functions get_simfin_balance_sheet, get_simfin_cashflow, @@ -19,19 +15,25 @@ from .interface import ( # Technical analysis functions get_stock_stats_indicators_window, get_stockstats_indicator, + get_YFin_data, # Market data functions get_YFin_data_window, - get_YFin_data, ) +from .reddit_utils import fetch_top_from_category +from .stockstats_utils import StockstatsUtils +from .yfin_utils import YFinanceUtils __all__ = [ - # News and sentiment functions - "get_finnhub_news", + "get_YFin_data", + # Market data functions + "get_YFin_data_window", "get_finnhub_company_insider_sentiment", "get_finnhub_company_insider_transactions", + # News and sentiment functions + "get_finnhub_news", "get_google_news", - "get_reddit_global_news", "get_reddit_company_news", + "get_reddit_global_news", # Financial statements functions "get_simfin_balance_sheet", "get_simfin_cashflow", @@ -39,7 +41,10 @@ __all__ = [ # Technical analysis functions "get_stock_stats_indicators_window", "get_stockstats_indicator", - # Market data functions - "get_YFin_data_window", - "get_YFin_data", + # Utilities and classes + "get_data_in_range", + "getNewsData", + "YFinanceUtils", + "fetch_top_from_category", + "StockstatsUtils", ] diff --git a/tradingagents/dataflows/config.py b/tradingagents/dataflows/config.py index b8a8f8aa..b3adcaf2 100644 --- a/tradingagents/dataflows/config.py +++ b/tradingagents/dataflows/config.py @@ -1,9 +1,9 @@ -import tradingagents.default_config as default_config -from typing import Dict, Optional + +from tradingagents import default_config # Use default config but allow it to be overridden -_config: Optional[Dict] = None -DATA_DIR: Optional[str] = None +_config: dict | None = None +DATA_DIR: str | None = None def initialize_config(): @@ -14,7 +14,7 @@ def initialize_config(): DATA_DIR = _config["data_dir"] -def set_config(config: Dict): +def set_config(config: dict): """Update the configuration with custom values.""" global _config, DATA_DIR if _config is None: @@ -23,7 +23,7 @@ def set_config(config: Dict): DATA_DIR = _config["data_dir"] -def get_config() -> Dict: +def get_config() -> dict: """Get the current configuration.""" if _config is None: initialize_config() diff --git a/tradingagents/dataflows/finnhub_utils.py b/tradingagents/dataflows/finnhub_utils.py index e7c7103c..7d0a1e30 100644 --- a/tradingagents/dataflows/finnhub_utils.py +++ b/tradingagents/dataflows/finnhub_utils.py @@ -22,10 +22,10 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period= ) else: data_path = os.path.join( - data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json" + data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json", ) - data = open(data_path, "r") + data = open(data_path) data = json.load(data) # filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD) diff --git a/tradingagents/dataflows/googlenews_utils.py b/tradingagents/dataflows/googlenews_utils.py index 9f454abc..501bc2d6 100644 --- a/tradingagents/dataflows/googlenews_utils.py +++ b/tradingagents/dataflows/googlenews_utils.py @@ -1,20 +1,20 @@ -import requests -from bs4 import BeautifulSoup -from datetime import datetime -import time -import random import logging +import random +import time +from datetime import datetime from urllib.parse import quote_plus -logger = logging.getLogger(__name__) - +import requests +from bs4 import BeautifulSoup from tenacity import ( retry, + retry_if_result, stop_after_attempt, wait_exponential, - retry_if_result, ) +logger = logging.getLogger(__name__) + def is_rate_limited(response): """Check if the response indicates we should back off (rate-limited or temporarily unavailable).""" @@ -35,8 +35,7 @@ def _add_jitter(retry_state): def make_request(url, headers): """Make a request with retry logic for rate limiting""" # The retry decorator already applies exponential backoff with jitter - response = requests.get(url, headers=headers, timeout=(5, 20)) - return response + return requests.get(url, headers=headers, timeout=(5, 20)) def getNewsData(query, start_date, end_date): @@ -58,7 +57,7 @@ def getNewsData(query, start_date, end_date): "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " "AppleWebKit/537.36 (KHTML, like Gecko) " "Chrome/101.0.4951.54 Safari/537.36" - ) + ), } news_results = [] @@ -103,7 +102,7 @@ def getNewsData(query, start_date, end_date): "source": ( source_el.get_text(strip=True) if source_el else "" ), - } + }, ) except Exception as e: logger.warning("Error processing result: %s", e) @@ -120,7 +119,7 @@ def getNewsData(query, start_date, end_date): page += 1 except Exception as e: - logger.error("Failed after multiple retries: %s", e) + logger.exception("Failed after multiple retries: %s", e) break return news_results diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index e7062d51..5de638cb 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -1,17 +1,18 @@ -from typing import Annotated -from .reddit_utils import fetch_top_from_category -from .yfin_utils import * -from .stockstats_utils import * -from .googlenews_utils import * -from .finnhub_utils import get_data_in_range -from dateutil.relativedelta import relativedelta -from datetime import datetime import os +from datetime import datetime +from typing import Annotated + import pandas as pd -from tqdm import tqdm import yfinance as yf +from dateutil.relativedelta import relativedelta from openai import OpenAI -from .config import get_config, DATA_DIR +from tqdm import tqdm + +from .config import DATA_DIR, get_config +from .finnhub_utils import get_data_in_range +from .googlenews_utils import getNewsData +from .reddit_utils import fetch_top_from_category +from .stockstats_utils import StockstatsUtils def get_finnhub_news( @@ -84,7 +85,7 @@ def get_finnhub_company_insider_sentiment( result_str = "" seen_dicts = [] - for date, senti_list in data.items(): + for senti_list in data.values(): for entry in senti_list: if entry not in seen_dicts: result_str += f"### {entry['year']}-{entry['month']}:\nChange: {entry['change']}\nMonthly Share Purchase Ratio: {entry['mspr']}\n\n" @@ -126,7 +127,7 @@ def get_finnhub_company_insider_transactions( result_str = "" seen_dicts = [] - for date, senti_list in data.items(): + for senti_list in data.values(): for entry in senti_list: if entry not in seen_dicts: result_str += f"### Filing Date: {entry['filingDate']}, {entry['name']}:\nChange:{entry['change']}\nShares: {entry['share']}\nTransaction Price: {entry['transactionPrice']}\nTransaction Code: {entry['transactionCode']}\n\n" @@ -170,7 +171,6 @@ def get_simfin_balance_sheet( # Check if there are any available reports; if not, return a notification if filtered_df.empty: - print("No balance sheet available before the given current date.") return "" # Get the most recent balance sheet by selecting the row with the latest Publish Date @@ -217,7 +217,6 @@ def get_simfin_cashflow( # Check if there are any available reports; if not, return a notification if filtered_df.empty: - print("No cash flow statement available before the given current date.") return "" # Get the most recent cash flow statement by selecting the row with the latest Publish Date @@ -264,7 +263,6 @@ def get_simfin_income_statements( # Check if there are any available reports; if not, return a notification if filtered_df.empty: - print("No income statement available before the given current date.") return "" # Get the most recent income statement by selecting the row with the latest Publish Date @@ -421,7 +419,7 @@ def get_stock_stats_indicators_window( symbol: Annotated[str, "ticker symbol of the company"], indicator: Annotated[str, "technical indicator to get the analysis and report of"], curr_date: Annotated[ - str, "The current trading date you are trading on, YYYY-mm-dd" + str, "The current trading date you are trading on, YYYY-mm-dd", ], look_back_days: Annotated[int, "how many days to look back"], online: Annotated[bool, "to fetch data online or offline"], @@ -501,8 +499,9 @@ def get_stock_stats_indicators_window( } if indicator not in best_ind_params: + msg = f"Indicator {indicator} is not supported. Please choose from: {list(best_ind_params.keys())}" raise ValueError( - f"Indicator {indicator} is not supported. Please choose from: {list(best_ind_params.keys())}" + msg, ) end_date = curr_date @@ -515,7 +514,7 @@ def get_stock_stats_indicators_window( os.path.join( DATA_DIR, f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv", - ) + ), ) data["Date"] = pd.to_datetime(data["Date"], utc=True) dates_in_df = data["Date"].astype(str).str[:10] @@ -525,7 +524,7 @@ def get_stock_stats_indicators_window( # only do the trading dates if curr_date.strftime("%Y-%m-%d") in dates_in_df.values: indicator_value = get_stockstats_indicator( - symbol, indicator, curr_date.strftime("%Y-%m-%d"), online + symbol, indicator, curr_date.strftime("%Y-%m-%d"), online, ) ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n" @@ -536,28 +535,27 @@ def get_stock_stats_indicators_window( ind_string = "" while curr_date >= before: indicator_value = get_stockstats_indicator( - symbol, indicator, curr_date.strftime("%Y-%m-%d"), online + symbol, indicator, curr_date.strftime("%Y-%m-%d"), online, ) ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n" curr_date = curr_date - relativedelta(days=1) - result_str = ( + return ( f"## {indicator} values from {before.strftime('%Y-%m-%d')} to {end_date}:\n\n" + ind_string + "\n\n" + best_ind_params.get(indicator, "No description available.") ) - return result_str def get_stockstats_indicator( symbol: Annotated[str, "ticker symbol of the company"], indicator: Annotated[str, "technical indicator to get the analysis and report of"], curr_date: Annotated[ - str, "The current trading date you are trading on, YYYY-mm-dd" + str, "The current trading date you are trading on, YYYY-mm-dd", ], online: Annotated[bool, "to fetch data online or offline"], ) -> str: @@ -573,10 +571,7 @@ def get_stockstats_indicator( os.path.join(DATA_DIR, "market_data", "price_data"), online=online, ) - except Exception as e: - print( - f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}" - ) + except Exception: return "" return str(indicator_value) @@ -597,7 +592,7 @@ def get_YFin_data_window( os.path.join( DATA_DIR, f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv", - ) + ), ) # Extract just the date part for comparison @@ -613,7 +608,7 @@ def get_YFin_data_window( # Set pandas display options to show the full DataFrame with pd.option_context( - "display.max_rows", None, "display.max_columns", None, "display.width", None + "display.max_rows", None, "display.max_columns", None, "display.width", None, ): df_string = filtered_data.to_string() @@ -675,12 +670,13 @@ def get_YFin_data( os.path.join( DATA_DIR, f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv", - ) + ), ) if end_date > "2025-03-25": + msg = f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25" raise Exception( - f"Get_YFin_Data: {end_date} is outside of the data range of 2015-01-01 to 2025-03-25" + msg, ) # Extract just the date part for comparison @@ -695,9 +691,8 @@ def get_YFin_data( filtered_data = filtered_data.drop("DateOnly", axis=1) # remove the index from the dataframe - filtered_data = filtered_data.reset_index(drop=True) + return filtered_data.reset_index(drop=True) - return filtered_data def get_stock_news_openai(ticker, curr_date): @@ -713,9 +708,9 @@ def get_stock_news_openai(ticker, curr_date): { "type": "input_text", "text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.", - } + }, ], - } + }, ], text={"format": {"type": "text"}}, reasoning={}, @@ -724,7 +719,7 @@ def get_stock_news_openai(ticker, curr_date): "type": "web_search_preview", "user_location": {"type": "approximate"}, "search_context_size": "low", - } + }, ], temperature=1, max_output_tokens=4096, @@ -748,9 +743,9 @@ def get_global_news_openai(curr_date): { "type": "input_text", "text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.", - } + }, ], - } + }, ], text={"format": {"type": "text"}}, reasoning={}, @@ -759,7 +754,7 @@ def get_global_news_openai(curr_date): "type": "web_search_preview", "user_location": {"type": "approximate"}, "search_context_size": "low", - } + }, ], temperature=1, max_output_tokens=4096, @@ -783,9 +778,9 @@ def get_fundamentals_openai(ticker, curr_date): { "type": "input_text", "text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc", - } + }, ], - } + }, ], text={"format": {"type": "text"}}, reasoning={}, @@ -794,7 +789,7 @@ def get_fundamentals_openai(ticker, curr_date): "type": "web_search_preview", "user_location": {"type": "approximate"}, "search_context_size": "low", - } + }, ], temperature=1, max_output_tokens=4096, diff --git a/tradingagents/dataflows/reddit_utils.py b/tradingagents/dataflows/reddit_utils.py index 5d401239..5710073c 100644 --- a/tradingagents/dataflows/reddit_utils.py +++ b/tradingagents/dataflows/reddit_utils.py @@ -1,8 +1,8 @@ import json -from datetime import datetime -from typing import Annotated import os import re +from datetime import datetime +from typing import Annotated ticker_to_company = { "AAPL": "Apple", @@ -48,11 +48,11 @@ ticker_to_company = { def fetch_top_from_category( category: Annotated[ - str, "Category to fetch top post from. Collection of subreddits." + str, "Category to fetch top post from. Collection of subreddits.", ], date: Annotated[str, "Date to fetch top posts from."], max_limit: Annotated[int, "Maximum number of posts to fetch."], - query: Annotated[str, "Optional query to search for in the subreddit."] = None, + query: Annotated[str | None, "Optional query to search for in the subreddit."] = None, data_path: Annotated[ str, "Path to the data folder. Default is 'reddit_data'.", @@ -63,12 +63,13 @@ def fetch_top_from_category( all_content = [] if max_limit < len(os.listdir(os.path.join(base_path, category))): + msg = "REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts" raise ValueError( - "REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts" + msg, ) limit_per_subreddit = max_limit // len( - os.listdir(os.path.join(base_path, category)) + os.listdir(os.path.join(base_path, category)), ) for data_file in os.listdir(os.path.join(base_path, category)): @@ -79,7 +80,7 @@ def fetch_top_from_category( all_content_curr_subreddit = [] with open(os.path.join(base_path, category, data_file), "rb") as f: - for i, line in enumerate(f): + for _i, line in enumerate(f): # skip empty lines if not line.strip(): continue @@ -88,7 +89,7 @@ def fetch_top_from_category( # select only lines that are from the date post_date = datetime.utcfromtimestamp( - parsed_line["created_utc"] + parsed_line["created_utc"], ).strftime("%Y-%m-%d") if post_date != date: continue @@ -106,7 +107,7 @@ def fetch_top_from_category( found = False for term in search_terms: if re.search( - term, parsed_line["title"], re.IGNORECASE + term, parsed_line["title"], re.IGNORECASE, ) or re.search(term, parsed_line["selftext"], re.IGNORECASE): found = True break diff --git a/tradingagents/dataflows/stockstats_utils.py b/tradingagents/dataflows/stockstats_utils.py index 78ffb220..a36e150f 100644 --- a/tradingagents/dataflows/stockstats_utils.py +++ b/tradingagents/dataflows/stockstats_utils.py @@ -1,8 +1,10 @@ +import os +from typing import Annotated + import pandas as pd import yfinance as yf from stockstats import wrap -from typing import Annotated -import os + from .config import get_config @@ -11,10 +13,10 @@ class StockstatsUtils: def get_stock_stats( symbol: Annotated[str, "ticker symbol for the company"], indicator: Annotated[ - str, "quantitative indicators based off of the stock data for the company" + str, "quantitative indicators based off of the stock data for the company", ], curr_date: Annotated[ - str, "curr date for retrieving stock price data, YYYY-mm-dd" + str, "curr date for retrieving stock price data, YYYY-mm-dd", ], data_dir: Annotated[ str, @@ -34,11 +36,12 @@ class StockstatsUtils: os.path.join( data_dir, f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv", - ) + ), ) df = wrap(data) except FileNotFoundError: - raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!") + msg = "Stockstats fail: Yahoo Finance data not fetched yet!" + raise Exception(msg) else: # Get today's date as YYYY-mm-dd to add to cache today_date = pd.Timestamp.today() @@ -81,7 +84,5 @@ class StockstatsUtils: matching_rows = df[df["Date"].str.startswith(curr_date)] if not matching_rows.empty: - indicator_value = matching_rows[indicator].values[0] - return indicator_value - else: - return "N/A: Not a trading day (weekend or holiday)" + return matching_rows[indicator].values[0] + return "N/A: Not a trading day (weekend or holiday)" diff --git a/tradingagents/dataflows/utils.py b/tradingagents/dataflows/utils.py index d21b9266..c2d82b4c 100644 --- a/tradingagents/dataflows/utils.py +++ b/tradingagents/dataflows/utils.py @@ -1,14 +1,14 @@ -import pandas as pd -from datetime import date, timedelta, datetime +from datetime import date, datetime, timedelta from typing import Annotated +import pandas as pd + SavePathType = Annotated[str, "File path to save data. If None, data is not saved."] def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None: if save_path: data.to_csv(save_path) - print(f"{tag} saved to {save_path}") def get_current_date(): @@ -32,7 +32,5 @@ def get_next_weekday(date): if date.weekday() >= 5: days_to_add = 7 - date.weekday() - next_weekday = date + timedelta(days=days_to_add) - return next_weekday - else: - return date + return date + timedelta(days=days_to_add) + return date diff --git a/tradingagents/dataflows/yfin_utils.py b/tradingagents/dataflows/yfin_utils.py index 2a59f883..f1a69df3 100644 --- a/tradingagents/dataflows/yfin_utils.py +++ b/tradingagents/dataflows/yfin_utils.py @@ -1,10 +1,12 @@ # gets data/stats -import yfinance as yf -from typing import Annotated, Callable, Any, Optional -from pandas import DataFrame -import pandas as pd +from collections.abc import Callable from functools import wraps +from typing import Annotated, Any + +import pandas as pd +import yfinance as yf +from pandas import DataFrame from .utils import SavePathType, decorate_all_methods @@ -24,38 +26,36 @@ def init_ticker(func: Callable) -> Callable: class YFinanceUtils: def get_stock_data( - symbol: Annotated[str, "ticker symbol"], + self: Annotated[str, "ticker symbol"], start_date: Annotated[ - str, "start date for retrieving stock price data, YYYY-mm-dd" + str, "start date for retrieving stock price data, YYYY-mm-dd", ], end_date: Annotated[ - str, "end date for retrieving stock price data, YYYY-mm-dd" + str, "end date for retrieving stock price data, YYYY-mm-dd", ], save_path: SavePathType = None, ) -> DataFrame: """retrieve stock price data for designated ticker symbol""" - ticker = symbol + ticker = self # add one day to the end_date so that the data range is inclusive end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1) end_date = end_date.strftime("%Y-%m-%d") - stock_data = ticker.history(start=start_date, end=end_date) + return ticker.history(start=start_date, end=end_date) # save_output(stock_data, f"Stock data for {ticker.ticker}", save_path) - return stock_data def get_stock_info( - symbol: Annotated[str, "ticker symbol"], + self: Annotated[str, "ticker symbol"], ) -> dict: """Fetches and returns latest stock information.""" - ticker = symbol - stock_info = ticker.info - return stock_info + ticker = self + return ticker.info def get_company_info( - symbol: Annotated[str, "ticker symbol"], - save_path: Optional[str] = None, + self: Annotated[str, "ticker symbol"], + save_path: str | None = None, ) -> DataFrame: """Fetches and returns company information as a DataFrame.""" - ticker = symbol + ticker = self info = ticker.info company_info = { "Company Name": info.get("shortName", "N/A"), @@ -67,42 +67,37 @@ class YFinanceUtils: company_info_df = DataFrame([company_info]) if save_path: company_info_df.to_csv(save_path) - print(f"Company info for {ticker.ticker} saved to {save_path}") return company_info_df def get_stock_dividends( - symbol: Annotated[str, "ticker symbol"], - save_path: Optional[str] = None, + self: Annotated[str, "ticker symbol"], + save_path: str | None = None, ) -> DataFrame: """Fetches and returns the latest dividends data as a DataFrame.""" - ticker = symbol + ticker = self dividends = ticker.dividends if save_path: dividends.to_csv(save_path) - print(f"Dividends for {ticker.ticker} saved to {save_path}") return dividends - def get_income_stmt(symbol: Annotated[str, "ticker symbol"]) -> DataFrame: + def get_income_stmt(self: Annotated[str, "ticker symbol"]) -> DataFrame: """Fetches and returns the latest income statement of the company as a DataFrame.""" - ticker = symbol - income_stmt = ticker.financials - return income_stmt + ticker = self + return ticker.financials - def get_balance_sheet(symbol: Annotated[str, "ticker symbol"]) -> DataFrame: + def get_balance_sheet(self: Annotated[str, "ticker symbol"]) -> DataFrame: """Fetches and returns the latest balance sheet of the company as a DataFrame.""" - ticker = symbol - balance_sheet = ticker.balance_sheet - return balance_sheet + ticker = self + return ticker.balance_sheet - def get_cash_flow(symbol: Annotated[str, "ticker symbol"]) -> DataFrame: + def get_cash_flow(self: Annotated[str, "ticker symbol"]) -> DataFrame: """Fetches and returns the latest cash flow statement of the company as a DataFrame.""" - ticker = symbol - cash_flow = ticker.cashflow - return cash_flow + ticker = self + return ticker.cashflow - def get_analyst_recommendations(symbol: Annotated[str, "ticker symbol"]) -> tuple: + def get_analyst_recommendations(self: Annotated[str, "ticker symbol"]) -> tuple: """Fetches the latest analyst recommendations and returns the most common recommendation and its count.""" - ticker = symbol + ticker = self recommendations = ticker.recommendations if recommendations.empty: return None, 0 # No recommendations available diff --git a/tradingagents/graph/__init__.py b/tradingagents/graph/__init__.py index 80982c19..4ee4d847 100644 --- a/tradingagents/graph/__init__.py +++ b/tradingagents/graph/__init__.py @@ -1,17 +1,17 @@ # TradingAgents/graph/__init__.py -from .trading_graph import TradingAgentsGraph from .conditional_logic import ConditionalLogic -from .setup import GraphSetup from .propagation import Propagator from .reflection import Reflector +from .setup import GraphSetup from .signal_processing import SignalProcessor +from .trading_graph import TradingAgentsGraph __all__ = [ - "TradingAgentsGraph", "ConditionalLogic", "GraphSetup", "Propagator", "Reflector", "SignalProcessor", + "TradingAgentsGraph", ] diff --git a/tradingagents/graph/propagation.py b/tradingagents/graph/propagation.py index db53ee32..dc522af8 100644 --- a/tradingagents/graph/propagation.py +++ b/tradingagents/graph/propagation.py @@ -1,6 +1,7 @@ # TradingAgents/graph/propagation.py -from typing import Dict, Any +from typing import Any + from tradingagents.agents.utils.agent_states import ( InvestDebateState, RiskDebateState, @@ -15,15 +16,15 @@ class Propagator: self.max_recur_limit = max_recur_limit def create_initial_state( - self, company_name: str, trade_date: str - ) -> Dict[str, Any]: + self, company_name: str, trade_date: str, + ) -> dict[str, Any]: """Create the initial state for the agent graph.""" return { "messages": [("human", company_name)], "company_of_interest": company_name, "trade_date": str(trade_date), "investment_debate_state": InvestDebateState( - {"history": "", "current_response": "", "count": 0} + {"history": "", "current_response": "", "count": 0}, ), "risk_debate_state": RiskDebateState( { @@ -32,7 +33,7 @@ class Propagator: "current_safe_response": "", "current_neutral_response": "", "count": 0, - } + }, ), "market_report": "", "fundamentals_report": "", @@ -40,7 +41,7 @@ class Propagator: "news_report": "", } - def get_graph_args(self) -> Dict[str, Any]: + def get_graph_args(self) -> dict[str, Any]: """Get arguments for the graph invocation.""" return { "stream_mode": "values", diff --git a/tradingagents/graph/reflection.py b/tradingagents/graph/reflection.py index 33303231..04b66224 100644 --- a/tradingagents/graph/reflection.py +++ b/tradingagents/graph/reflection.py @@ -1,6 +1,7 @@ # TradingAgents/graph/reflection.py -from typing import Dict, Any +from typing import Any + from langchain_openai import ChatOpenAI @@ -15,7 +16,7 @@ class Reflector: def _get_reflection_prompt(self) -> str: """Get the system prompt for reflection.""" return """ -You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis. +You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis. Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines: 1. Reasoning: @@ -25,7 +26,7 @@ Your goal is to deliver detailed insights into investment decisions and highligh - Technical indicators. - Technical signals. - Price movement analysis. - - Overall market data analysis + - Overall market data analysis - News analysis. - Social media and sentiment analysis. - Fundamental data analysis. @@ -46,7 +47,7 @@ Your goal is to deliver detailed insights into investment decisions and highligh Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis. """ - def _extract_current_situation(self, current_state: Dict[str, Any]) -> str: + def _extract_current_situation(self, current_state: dict[str, Any]) -> str: """Extract the current market situation from the state.""" curr_market_report = current_state["market_report"] curr_sentiment_report = current_state["sentiment_report"] @@ -56,7 +57,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}" def _reflect_on_component( - self, component_type: str, report: str, situation: str, returns_losses + self, component_type: str, report: str, situation: str, returns_losses, ) -> str: """Generate reflection for a component.""" messages = [ @@ -67,8 +68,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur ), ] - result = self.quick_thinking_llm.invoke(messages).content - return result + return self.quick_thinking_llm.invoke(messages).content def reflect_bull_researcher(self, current_state, returns_losses, bull_memory): """Reflect on bull researcher's analysis and update memory.""" @@ -76,7 +76,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur bull_debate_history = current_state["investment_debate_state"]["bull_history"] result = self._reflect_on_component( - "BULL", bull_debate_history, situation, returns_losses + "BULL", bull_debate_history, situation, returns_losses, ) bull_memory.add_situations([(situation, result)]) @@ -86,7 +86,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur bear_debate_history = current_state["investment_debate_state"]["bear_history"] result = self._reflect_on_component( - "BEAR", bear_debate_history, situation, returns_losses + "BEAR", bear_debate_history, situation, returns_losses, ) bear_memory.add_situations([(situation, result)]) @@ -96,7 +96,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur trader_decision = current_state["trader_investment_plan"] result = self._reflect_on_component( - "TRADER", trader_decision, situation, returns_losses + "TRADER", trader_decision, situation, returns_losses, ) trader_memory.add_situations([(situation, result)]) @@ -106,7 +106,7 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur judge_decision = current_state["investment_debate_state"]["judge_decision"] result = self._reflect_on_component( - "INVEST JUDGE", judge_decision, situation, returns_losses + "INVEST JUDGE", judge_decision, situation, returns_losses, ) invest_judge_memory.add_situations([(situation, result)]) @@ -116,6 +116,6 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur judge_decision = current_state["risk_debate_state"]["judge_decision"] result = self._reflect_on_component( - "RISK JUDGE", judge_decision, situation, returns_losses + "RISK JUDGE", judge_decision, situation, returns_losses, ) risk_manager_memory.add_situations([(situation, result)]) diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index 031922d8..c5f882fa 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -1,26 +1,26 @@ # TradingAgents/graph/setup.py -from typing import Dict + from langchain_openai import ChatOpenAI -from langgraph.graph import END, StateGraph, START +from langgraph.graph import END, START, StateGraph from langgraph.prebuilt import ToolNode from tradingagents.agents import ( - create_market_analyst, - create_social_media_analyst, - create_news_analyst, - create_fundamentals_analyst, - create_bull_researcher, - create_bear_researcher, - create_research_manager, - create_trader, - create_risky_debator, - create_neutral_debator, - create_safe_debator, - create_risk_manager, - create_msg_delete, AgentState, Toolkit, + create_bear_researcher, + create_bull_researcher, + create_fundamentals_analyst, + create_market_analyst, + create_msg_delete, + create_neutral_debator, + create_news_analyst, + create_research_manager, + create_risk_manager, + create_risky_debator, + create_safe_debator, + create_social_media_analyst, + create_trader, ) from .conditional_logic import ConditionalLogic @@ -34,7 +34,7 @@ class GraphSetup: quick_thinking_llm: ChatOpenAI, deep_thinking_llm: ChatOpenAI, toolkit: Toolkit, - tool_nodes: Dict[str, ToolNode], + tool_nodes: dict[str, ToolNode], bull_memory, bear_memory, trader_memory, @@ -55,7 +55,7 @@ class GraphSetup: self.conditional_logic = conditional_logic def setup_graph( - self, selected_analysts=["market", "social", "news", "fundamentals"] + self, selected_analysts=None, ): """Set up and compile the agent workflow graph. @@ -66,8 +66,11 @@ class GraphSetup: - "news": News analyst - "fundamentals": Fundamentals analyst """ + if selected_analysts is None: + selected_analysts = ["market", "social", "news", "fundamentals"] if len(selected_analysts) == 0: - raise ValueError("Trading Agents Graph Setup Error: no analysts selected!") + msg = "Trading Agents Graph Setup Error: no analysts selected!" + raise ValueError(msg) # Create analyst nodes analyst_nodes = {} @@ -76,41 +79,41 @@ class GraphSetup: if "market" in selected_analysts: analyst_nodes["market"] = create_market_analyst( - self.quick_thinking_llm, self.toolkit + self.quick_thinking_llm, self.toolkit, ) delete_nodes["market"] = create_msg_delete() tool_nodes["market"] = self.tool_nodes["market"] if "social" in selected_analysts: analyst_nodes["social"] = create_social_media_analyst( - self.quick_thinking_llm, self.toolkit + self.quick_thinking_llm, self.toolkit, ) delete_nodes["social"] = create_msg_delete() tool_nodes["social"] = self.tool_nodes["social"] if "news" in selected_analysts: analyst_nodes["news"] = create_news_analyst( - self.quick_thinking_llm, self.toolkit + self.quick_thinking_llm, self.toolkit, ) delete_nodes["news"] = create_msg_delete() tool_nodes["news"] = self.tool_nodes["news"] if "fundamentals" in selected_analysts: analyst_nodes["fundamentals"] = create_fundamentals_analyst( - self.quick_thinking_llm, self.toolkit + self.quick_thinking_llm, self.toolkit, ) delete_nodes["fundamentals"] = create_msg_delete() tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"] # Create researcher and manager nodes bull_researcher_node = create_bull_researcher( - self.quick_thinking_llm, self.bull_memory + self.quick_thinking_llm, self.bull_memory, ) bear_researcher_node = create_bear_researcher( - self.quick_thinking_llm, self.bear_memory + self.quick_thinking_llm, self.bear_memory, ) research_manager_node = create_research_manager( - self.deep_thinking_llm, self.invest_judge_memory + self.deep_thinking_llm, self.invest_judge_memory, ) trader_node = create_trader(self.quick_thinking_llm, self.trader_memory) @@ -119,7 +122,7 @@ class GraphSetup: neutral_analyst = create_neutral_debator(self.quick_thinking_llm) safe_analyst = create_safe_debator(self.quick_thinking_llm) risk_manager_node = create_risk_manager( - self.deep_thinking_llm, self.risk_manager_memory + self.deep_thinking_llm, self.risk_manager_memory, ) # Create workflow @@ -129,7 +132,7 @@ class GraphSetup: for analyst_type, node in analyst_nodes.items(): workflow.add_node(f"{analyst_type.capitalize()} Analyst", node) workflow.add_node( - f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type] + f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type], ) workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type]) diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index ea76cfe3..c656b6b3 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -1,25 +1,24 @@ # TradingAgents/graph/trading_graph.py +import json import os from pathlib import Path -import json -from typing import Dict, Any +from typing import Any -from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_google_genai import ChatGoogleGenerativeAI - +from langchain_openai import ChatOpenAI from langgraph.prebuilt import ToolNode from tradingagents.agents import Toolkit -from tradingagents.default_config import DEFAULT_CONFIG from tradingagents.agents.utils.memory import FinancialSituationMemory from tradingagents.dataflows.interface import set_config +from tradingagents.default_config import DEFAULT_CONFIG from .conditional_logic import ConditionalLogic -from .setup import GraphSetup from .propagation import Propagator from .reflection import Reflector +from .setup import GraphSetup from .signal_processing import SignalProcessor @@ -28,9 +27,9 @@ class TradingAgentsGraph: def __init__( self, - selected_analysts=["market", "social", "news", "fundamentals"], + selected_analysts=None, debug=False, - config: Dict[str, Any] = None, + config: dict[str, Any] | None = None, ): """Initialize the trading agents graph and components. @@ -39,6 +38,8 @@ class TradingAgentsGraph: debug: Whether to run in debug mode config: Configuration dictionary. If None, uses default config """ + if selected_analysts is None: + selected_analysts = ["market", "social", "news", "fundamentals"] self.debug = debug self.config = config or DEFAULT_CONFIG @@ -58,7 +59,7 @@ class TradingAgentsGraph: or self.config["llm_provider"] == "openrouter" ): self.deep_thinking_llm = ChatOpenAI( - model=self.config["deep_think_llm"], base_url=self.config["backend_url"] + model=self.config["deep_think_llm"], base_url=self.config["backend_url"], ) self.quick_thinking_llm = ChatOpenAI( model=self.config["quick_think_llm"], @@ -66,7 +67,7 @@ class TradingAgentsGraph: ) elif self.config["llm_provider"].lower() == "anthropic": self.deep_thinking_llm = ChatAnthropic( - model=self.config["deep_think_llm"], base_url=self.config["backend_url"] + model=self.config["deep_think_llm"], base_url=self.config["backend_url"], ) self.quick_thinking_llm = ChatAnthropic( model=self.config["quick_think_llm"], @@ -74,13 +75,14 @@ class TradingAgentsGraph: ) elif self.config["llm_provider"].lower() == "google": self.deep_thinking_llm = ChatGoogleGenerativeAI( - model=self.config["deep_think_llm"] + model=self.config["deep_think_llm"], ) self.quick_thinking_llm = ChatGoogleGenerativeAI( - model=self.config["quick_think_llm"] + model=self.config["quick_think_llm"], ) else: - raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}") + msg = f"Unsupported LLM provider: {self.config['llm_provider']}" + raise ValueError(msg) self.toolkit = Toolkit(config=self.config) @@ -89,10 +91,10 @@ class TradingAgentsGraph: self.bear_memory = FinancialSituationMemory("bear_memory", self.config) self.trader_memory = FinancialSituationMemory("trader_memory", self.config) self.invest_judge_memory = FinancialSituationMemory( - "invest_judge_memory", self.config + "invest_judge_memory", self.config, ) self.risk_manager_memory = FinancialSituationMemory( - "risk_manager_memory", self.config + "risk_manager_memory", self.config, ) # Create tool nodes @@ -125,7 +127,7 @@ class TradingAgentsGraph: # Set up the graph self.graph = self.graph_setup.setup_graph(selected_analysts) - def _create_tool_nodes(self) -> Dict[str, ToolNode]: + def _create_tool_nodes(self) -> dict[str, ToolNode]: """Create tool nodes for different data sources.""" return { "market": ToolNode( @@ -136,7 +138,7 @@ class TradingAgentsGraph: # offline tools self.toolkit.get_YFin_data, self.toolkit.get_stockstats_indicators_report, - ] + ], ), "social": ToolNode( [ @@ -144,7 +146,7 @@ class TradingAgentsGraph: self.toolkit.get_stock_news_openai, # offline tools self.toolkit.get_reddit_stock_info, - ] + ], ), "news": ToolNode( [ @@ -154,7 +156,7 @@ class TradingAgentsGraph: # offline tools self.toolkit.get_finnhub_news, self.toolkit.get_reddit_news, - ] + ], ), "fundamentals": ToolNode( [ @@ -166,7 +168,7 @@ class TradingAgentsGraph: self.toolkit.get_simfin_balance_sheet, self.toolkit.get_simfin_cashflow, self.toolkit.get_simfin_income_stmt, - ] + ], ), } @@ -177,7 +179,7 @@ class TradingAgentsGraph: # Initialize state init_agent_state = self.propagator.create_initial_state( - company_name, trade_date + company_name, trade_date, ) args = self.propagator.get_graph_args() @@ -250,19 +252,19 @@ class TradingAgentsGraph: def reflect_and_remember(self, returns_losses): """Reflect on decisions and update memory based on returns.""" self.reflector.reflect_bull_researcher( - self.curr_state, returns_losses, self.bull_memory + self.curr_state, returns_losses, self.bull_memory, ) self.reflector.reflect_bear_researcher( - self.curr_state, returns_losses, self.bear_memory + self.curr_state, returns_losses, self.bear_memory, ) self.reflector.reflect_trader( - self.curr_state, returns_losses, self.trader_memory + self.curr_state, returns_losses, self.trader_memory, ) self.reflector.reflect_invest_judge( - self.curr_state, returns_losses, self.invest_judge_memory + self.curr_state, returns_losses, self.invest_judge_memory, ) self.reflector.reflect_risk_manager( - self.curr_state, returns_losses, self.risk_manager_memory + self.curr_state, returns_losses, self.risk_manager_memory, ) def process_signal(self, full_signal):