From 850764ad7b1651640bd63c6445acf7643e28ec70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=90=E8=97=A4=E5=84=AA=E4=B8=80?= Date: Sun, 10 Aug 2025 23:25:04 +0900 Subject: [PATCH] Apply Black formatting to pass CI checks - Formatted 33 Python files with Black - Fixed code style to meet project standards - Ensures CI/CD pipeline passes formatting checks --- cli/main.py | 132 +++++++++++------- cli/utils.py | 1 - main.py | 4 +- run_tests.py | 52 +++---- test_hooks.py | 24 ++-- test_mypy.py | 3 +- test_setup_demo.py | 45 ++++-- tests/conftest.py | 3 +- tests/fixtures/sample_data.py | 16 ++- tests/integration/test_full_workflow.py | 42 ++++-- tests/unit/agents/test_market_analyst.py | 36 ++++- tests/unit/dataflows/test_finnhub_utils.py | 10 +- tests/unit/graph/test_trading_graph.py | 76 ++++++++-- .../agents/analysts/fundamentals_analyst.py | 2 +- .../agents/analysts/market_analyst.py | 2 +- tradingagents/agents/analysts/news_analyst.py | 2 +- .../agents/analysts/social_media_analyst.py | 2 +- .../agents/risk_mgmt/aggresive_debator.py | 3 +- .../agents/risk_mgmt/conservative_debator.py | 6 +- .../agents/risk_mgmt/neutral_debator.py | 3 +- tradingagents/agents/utils/agent_states.py | 33 +++-- tradingagents/agents/utils/agent_utils.py | 59 ++++---- tradingagents/config.py | 3 +- tradingagents/dataflows/config.py | 1 - tradingagents/dataflows/finnhub_utils.py | 5 +- tradingagents/dataflows/interface.py | 25 +++- tradingagents/dataflows/reddit_utils.py | 11 +- tradingagents/dataflows/stockstats_utils.py | 6 +- tradingagents/dataflows/yfin_utils.py | 6 +- tradingagents/graph/propagation.py | 4 +- tradingagents/graph/reflection.py | 31 +++- tradingagents/graph/setup.py | 30 ++-- tradingagents/graph/trading_graph.py | 35 +++-- 33 files changed, 475 insertions(+), 238 deletions(-) diff --git a/cli/main.py b/cli/main.py index 4041ec01..2828d230 100644 --- a/cli/main.py +++ b/cli/main.py @@ -148,7 +148,7 @@ class MessageBuffer: f"### News Analysis\n{self.report_sections['news_report']}", ) if self.report_sections["fundamentals_report"]: - fundamentals = self.report_sections['fundamentals_report'] + fundamentals = self.report_sections["fundamentals_report"] report_parts.append( f"### Fundamentals Analysis\n{fundamentals}", ) @@ -182,10 +182,12 @@ def create_layout(): Layout(name="footer", size=3), ) layout["main"].split_column( - Layout(name="upper", ratio=3), Layout(name="analysis", ratio=5), + Layout(name="upper", ratio=3), + Layout(name="analysis", ratio=5), ) layout["upper"].split_row( - Layout(name="progress", ratio=2), Layout(name="messages", ratio=3), + Layout(name="progress", ratio=2), + Layout(name="messages", ratio=3), ) return layout @@ -237,7 +239,9 @@ def update_display(layout, spinner_text=None): status = message_buffer.agent_status[first_agent] if status == "in_progress": spinner = Spinner( - "dots", text="[blue]in_progress[/blue]", style="bold cyan", + "dots", + text="[blue]in_progress[/blue]", + style="bold cyan", ) status_cell = spinner else: @@ -254,7 +258,9 @@ def update_display(layout, spinner_text=None): status = message_buffer.agent_status[agent] if status == "in_progress": spinner = Spinner( - "dots", text="[blue]in_progress[/blue]", style="bold cyan", + "dots", + text="[blue]in_progress[/blue]", + style="bold cyan", ) status_cell = spinner else: @@ -286,7 +292,10 @@ def update_display(layout, spinner_text=None): messages_table.add_column("Time", style="cyan", width=8, justify="center") messages_table.add_column("Type", style="green", width=10, justify="center") messages_table.add_column( - "Content", style="white", no_wrap=False, ratio=1, + "Content", + style="white", + no_wrap=False, + ratio=1, ) # Make content column expand # Combine tool calls and messages @@ -441,7 +450,9 @@ def get_user_selections(): # Step 1: Ticker symbol console.print( create_question_box( - "Step 1: Ticker Symbol", "Enter the ticker symbol to analyze", "SPY", + "Step 1: Ticker Symbol", + "Enter the ticker symbol to analyze", + "SPY", ), ) selected_ticker = get_ticker() @@ -460,7 +471,8 @@ def get_user_selections(): # Step 3: Select analysts console.print( create_question_box( - "Step 3: Analysts Team", "Select your LLM analyst agents for the analysis", + "Step 3: Analysts Team", + "Select your LLM analyst agents for the analysis", ), ) selected_analysts = select_analysts() @@ -471,21 +483,25 @@ def get_user_selections(): # Step 4: Research depth console.print( create_question_box( - "Step 4: Research Depth", "Select your research depth level", + "Step 4: Research Depth", + "Select your research depth level", ), ) selected_research_depth = select_research_depth() # Step 5: OpenAI backend console.print( - create_question_box("Step 5: OpenAI backend", "Select which service to talk to"), + create_question_box( + "Step 5: OpenAI backend", "Select which service to talk to" + ), ) selected_llm_provider, backend_url = select_llm_provider() # Step 6: Thinking agents console.print( create_question_box( - "Step 6: Thinking Agents", "Select your thinking agents for analysis", + "Step 6: Thinking Agents", + "Select your thinking agents for analysis", ), ) selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) @@ -737,7 +753,9 @@ def run_analysis(): # Initialize the graph graph = TradingAgentsGraph( - [analyst.value for analyst in selections["analysts"]], config=config, debug=True, + [analyst.value for analyst in selections["analysts"]], + config=config, + debug=True, ) # Create result directory @@ -796,10 +814,12 @@ def run_analysis(): message_buffer.add_message = save_message_decorator(message_buffer, "add_message") message_buffer.add_tool_call = save_tool_call_decorator( - message_buffer, "add_tool_call", + message_buffer, + "add_tool_call", ) message_buffer.update_report_section = save_report_section_decorator( - message_buffer, "update_report_section", + message_buffer, + "update_report_section", ) # Now start the display layout @@ -812,7 +832,8 @@ def run_analysis(): # Add initial messages message_buffer.add_message("System", f"Selected ticker: {selections['ticker']}") message_buffer.add_message( - "System", f"Analysis date: {selections['analysis_date']}", + "System", + f"Analysis date: {selections['analysis_date']}", ) message_buffer.add_message( "System", @@ -843,7 +864,8 @@ def run_analysis(): # Initialize state and get graph args init_agent_state = graph.propagator.create_initial_state( - selections["ticker"], selections["analysis_date"], + selections["ticker"], + selections["analysis_date"], ) args = graph.propagator.get_graph_args() @@ -873,7 +895,8 @@ def run_analysis(): # Handle both dictionary and object tool calls if isinstance(tool_call, dict): message_buffer.add_tool_call( - tool_call["name"], tool_call["args"], + tool_call["name"], + tool_call["args"], ) else: message_buffer.add_tool_call(tool_call.name, tool_call.args) @@ -882,51 +905,57 @@ def run_analysis(): # Analyst Team Reports if chunk.get("market_report"): message_buffer.update_report_section( - "market_report", chunk["market_report"], + "market_report", + chunk["market_report"], ) message_buffer.update_agent_status("Market Analyst", "completed") # Set next analyst to in_progress if "social" in selections["analysts"]: message_buffer.update_agent_status( - "Social Analyst", "in_progress", + "Social Analyst", + "in_progress", ) if chunk.get("sentiment_report"): message_buffer.update_report_section( - "sentiment_report", chunk["sentiment_report"], + "sentiment_report", + chunk["sentiment_report"], ) message_buffer.update_agent_status("Social Analyst", "completed") # Set next analyst to in_progress if "news" in selections["analysts"]: message_buffer.update_agent_status( - "News Analyst", "in_progress", + "News Analyst", + "in_progress", ) if chunk.get("news_report"): message_buffer.update_report_section( - "news_report", chunk["news_report"], + "news_report", + chunk["news_report"], ) message_buffer.update_agent_status("News Analyst", "completed") # Set next analyst to in_progress if "fundamentals" in selections["analysts"]: message_buffer.update_agent_status( - "Fundamentals Analyst", "in_progress", + "Fundamentals Analyst", + "in_progress", ) if chunk.get("fundamentals_report"): message_buffer.update_report_section( - "fundamentals_report", chunk["fundamentals_report"], + "fundamentals_report", + chunk["fundamentals_report"], ) message_buffer.update_agent_status( - "Fundamentals Analyst", "completed", + "Fundamentals Analyst", + "completed", ) # Set all research team members to in_progress update_research_team_status("in_progress") # Research Team - Handle Investment Debate State - if ( - chunk.get("investment_debate_state") - ): + if chunk.get("investment_debate_state"): debate_state = chunk["investment_debate_state"] # Update Bull Researcher status and report @@ -960,9 +989,7 @@ def run_analysis(): ) # Update Research Manager status and final decision - if ( - debate_state.get("judge_decision") - ): + if debate_state.get("judge_decision"): # Keep all research team members in progress until final decision update_research_team_status("in_progress") message_buffer.add_message( @@ -978,15 +1005,15 @@ def run_analysis(): update_research_team_status("completed") # Set first risk analyst to in_progress message_buffer.update_agent_status( - "Risky Analyst", "in_progress", + "Risky Analyst", + "in_progress", ) # Trading Team - if ( - chunk.get("trader_investment_plan") - ): + if chunk.get("trader_investment_plan"): message_buffer.update_report_section( - "trader_investment_plan", chunk["trader_investment_plan"], + "trader_investment_plan", + chunk["trader_investment_plan"], ) # Set first risk analyst to in_progress message_buffer.update_agent_status("Risky Analyst", "in_progress") @@ -996,11 +1023,10 @@ def run_analysis(): risk_state = chunk["risk_debate_state"] # Update Risky Analyst status and report - if ( - risk_state.get("current_risky_response") - ): + if risk_state.get("current_risky_response"): message_buffer.update_agent_status( - "Risky Analyst", "in_progress", + "Risky Analyst", + "in_progress", ) message_buffer.add_message( "Reasoning", @@ -1013,11 +1039,10 @@ def run_analysis(): ) # Update Safe Analyst status and report - if ( - risk_state.get("current_safe_response") - ): + if risk_state.get("current_safe_response"): message_buffer.update_agent_status( - "Safe Analyst", "in_progress", + "Safe Analyst", + "in_progress", ) message_buffer.add_message( "Reasoning", @@ -1030,11 +1055,10 @@ def run_analysis(): ) # Update Neutral Analyst status and report - if ( - risk_state.get("current_neutral_response") - ): + if risk_state.get("current_neutral_response"): message_buffer.update_agent_status( - "Neutral Analyst", "in_progress", + "Neutral Analyst", + "in_progress", ) message_buffer.add_message( "Reasoning", @@ -1049,7 +1073,8 @@ def run_analysis(): # Update Portfolio Manager status and final decision if risk_state.get("judge_decision"): message_buffer.update_agent_status( - "Portfolio Manager", "in_progress", + "Portfolio Manager", + "in_progress", ) message_buffer.add_message( "Reasoning", @@ -1064,10 +1089,12 @@ def run_analysis(): message_buffer.update_agent_status("Risky Analyst", "completed") message_buffer.update_agent_status("Safe Analyst", "completed") message_buffer.update_agent_status( - "Neutral Analyst", "completed", + "Neutral Analyst", + "completed", ) message_buffer.update_agent_status( - "Portfolio Manager", "completed", + "Portfolio Manager", + "completed", ) # Update the display @@ -1084,7 +1111,8 @@ def run_analysis(): message_buffer.update_agent_status(agent, "completed") message_buffer.add_message( - "Analysis", f"Completed analysis for {selections['analysis_date']}", + "Analysis", + f"Completed analysis for {selections['analysis_date']}", ) # Update final report sections diff --git a/cli/utils.py b/cli/utils.py index e8df4a78..f42cd85b 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,4 +1,3 @@ - import sys import questionary diff --git a/main.py b/main.py index 6c8ae3d9..22e57793 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,9 @@ from tradingagents.default_config import DEFAULT_CONFIG # Create a custom config config = DEFAULT_CONFIG.copy() config["llm_provider"] = "google" # Use a different model -config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend +config["backend_url"] = ( + "https://generativelanguage.googleapis.com/v1" # Use a different backend +) config["deep_think_llm"] = "gemini-2.0-flash" # Use a different model config["quick_think_llm"] = "gemini-2.0-flash" # Use a different model config["max_debate_rounds"] = 1 # Increase debate rounds diff --git a/run_tests.py b/run_tests.py index ed7e130d..daf7ccb9 100755 --- a/run_tests.py +++ b/run_tests.py @@ -16,10 +16,10 @@ def run_command(cmd, description=""): """Run a command and handle errors.""" if description: print(f"\n๐Ÿ”„ {description}") - + print(f"Running: {' '.join(cmd)}") result = subprocess.run(cmd, capture_output=False) - + if result.returncode != 0: print(f"โŒ Command failed with return code {result.returncode}") sys.exit(result.returncode) @@ -33,48 +33,42 @@ def main(): parser.add_argument( "test_type", choices=["unit", "integration", "all", "coverage", "fast", "slow", "lint"], - help="Type of tests to run" - ) - parser.add_argument( - "--verbose", "-v", action="store_true", help="Verbose output" + help="Type of tests to run", ) + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") parser.add_argument( "--parallel", "-p", action="store_true", help="Run tests in parallel" ) - parser.add_argument( - "--file", "-f", help="Run specific test file" - ) - parser.add_argument( - "--pattern", "-k", help="Run tests matching pattern" - ) - + parser.add_argument("--file", "-f", help="Run specific test file") + parser.add_argument("--pattern", "-k", help="Run tests matching pattern") + args = parser.parse_args() - + # Base pytest command base_cmd = ["python", "-m", "pytest"] - + if args.verbose: base_cmd.append("-v") - + if args.parallel: base_cmd.extend(["-n", "auto"]) - + if args.pattern: base_cmd.extend(["-k", args.pattern]) - + # Configure based on test type if args.test_type == "unit": cmd = base_cmd + ["tests/unit/", "-m", "unit"] run_command(cmd, "Running unit tests") - + elif args.test_type == "integration": cmd = base_cmd + ["tests/integration/", "-m", "integration"] run_command(cmd, "Running integration tests") - + elif args.test_type == "all": cmd = base_cmd + ["tests/"] run_command(cmd, "Running all tests") - + elif args.test_type == "coverage": cmd = base_cmd + [ "tests/", @@ -88,28 +82,28 @@ def main(): print("\n๐Ÿ“Š Coverage report generated:") print(" - HTML: htmlcov/index.html") print(" - XML: coverage.xml") - + elif args.test_type == "fast": cmd = base_cmd + ["tests/unit/", "-m", "unit", "--durations=10"] run_command(cmd, "Running fast unit tests") - + elif args.test_type == "slow": cmd = base_cmd + ["tests/", "-m", "slow", "--timeout=600"] run_command(cmd, "Running slow tests") - + elif args.test_type == "lint": # Run mypy cmd = ["python", "-m", "mypy", "tradingagents/", "cli/", "tests/"] run_command(cmd, "Running mypy type checking") - + # Run pytest on tests only for syntax cmd = base_cmd + ["tests/", "--collect-only"] run_command(cmd, "Validating test syntax") - + elif args.file: cmd = base_cmd + [args.file] run_command(cmd, f"Running tests in {args.file}") - + print("\n๐ŸŽ‰ All tests completed successfully!") @@ -117,5 +111,5 @@ if __name__ == "__main__": # Ensure we're in the project directory script_dir = Path(__file__).parent os.chdir(script_dir) - - main() \ No newline at end of file + + main() diff --git a/test_hooks.py b/test_hooks.py index e4dd48b4..30c393fa 100644 --- a/test_hooks.py +++ b/test_hooks.py @@ -1,19 +1,23 @@ - -def poorly_formatted_function(x,y,z): # Missing type hints +def poorly_formatted_function(x, y, z): # Missing type hints """This function has formatting issues.""" - result=x+y*z # Missing spaces around operators - if result>100: # Missing spaces - print( "Result is large" ) # Extra spaces in parentheses + result = x + y * z # Missing spaces around operators + if result > 100: # Missing spaces + print("Result is large") # Extra spaces in parentheses return result + # Long line that Black will wrap -very_long_variable_name_that_exceeds_the_standard_line_length_limit = "This is a very long string that will be wrapped by Black formatter" +very_long_variable_name_that_exceeds_the_standard_line_length_limit = ( + "This is a very long string that will be wrapped by Black formatter" +) + class MyClass: - def __init__(self,name:str,age:int): # Missing space after comma - self.name=name # Missing spaces around = - self.age=age + def __init__(self, name: str, age: int): # Missing space after comma + self.name = name # Missing spaces around = + self.age = age + # Function with wrong return type hint def get_number() -> str: - return 123 # Returns int but type hint says str \ No newline at end of file + return 123 # Returns int but type hint says str diff --git a/test_mypy.py b/test_mypy.py index a62a6e40..bb744c14 100644 --- a/test_mypy.py +++ b/test_mypy.py @@ -2,6 +2,7 @@ def add_numbers(a: int, b: int) -> int: """Add two numbers and return the result.""" return a + b + # Test the function result = add_numbers(1, 2) -print(f"Result: {result}") \ No newline at end of file +print(f"Result: {result}") diff --git a/test_setup_demo.py b/test_setup_demo.py index 75b8ad2d..5f573023 100755 --- a/test_setup_demo.py +++ b/test_setup_demo.py @@ -15,7 +15,7 @@ def run_command(cmd, description=""): """Run a command and return success status.""" if description: print(f"\n๐Ÿ”„ {description}") - + print(f"Running: {' '.join(cmd)}") try: result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) @@ -38,45 +38,60 @@ def main(): """Run setup verification tests.""" print("๐Ÿงช TradingAgents Test Setup Verification") print("=" * 50) - + # Change to project directory project_dir = Path(__file__).parent os.chdir(project_dir) - + success_count = 0 total_tests = 0 - + # Test 1: Check if pytest is installed and can discover tests total_tests += 1 - if run_command(["python", "-m", "pytest", "--version"], "Checking pytest installation"): + if run_command( + ["python", "-m", "pytest", "--version"], "Checking pytest installation" + ): success_count += 1 - + # Test 2: Test discovery total_tests += 1 - if run_command(["python", "-m", "pytest", "tests/", "--collect-only", "-q"], "Testing test discovery"): + if run_command( + ["python", "-m", "pytest", "tests/", "--collect-only", "-q"], + "Testing test discovery", + ): success_count += 1 - + # Test 3: Check if mypy can run total_tests += 1 if run_command(["python", "-m", "mypy", "--version"], "Checking mypy installation"): success_count += 1 - + # Test 4: Run a simple syntax check on test files total_tests += 1 - if run_command(["python", "-c", "import tests.conftest; print('Test imports work!')"], "Testing test imports"): + if run_command( + ["python", "-c", "import tests.conftest; print('Test imports work!')"], + "Testing test imports", + ): success_count += 1 - + # Test 5: Check if we can import the main module total_tests += 1 - if run_command(["python", "-c", "import tradingagents.config; print('Main module imports work!')"], "Testing main module imports"): + if run_command( + [ + "python", + "-c", + "import tradingagents.config; print('Main module imports work!')", + ], + "Testing main module imports", + ): success_count += 1 - + # Summary print("\n" + "=" * 50) print("๐Ÿ“Š Test Setup Verification Results:") print(f"โœ… Successful: {success_count}/{total_tests}") print(f"โŒ Failed: {total_tests - success_count}/{total_tests}") - + if success_count == total_tests: print("\n๐ŸŽ‰ All verification tests passed! Your test setup is ready.") print("\n๐Ÿ“š Next steps:") @@ -95,4 +110,4 @@ def main(): if __name__ == "__main__": - sys.exit(main()) \ No newline at end of file + sys.exit(main()) diff --git a/tests/conftest.py b/tests/conftest.py index 1f2fb16c..fcb04492 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -173,7 +173,8 @@ def mock_memory(): def pytest_configure(config): """Configure pytest with custom markers.""" config.addinivalue_line( - "markers", "integration: mark test as integration test (slow)", + "markers", + "integration: mark test as integration test (slow)", ) config.addinivalue_line("markers", "unit: mark test as unit test (fast)") config.addinivalue_line("markers", "api: mark test as requiring API access") diff --git a/tests/fixtures/sample_data.py b/tests/fixtures/sample_data.py index 94a15135..88e211c8 100644 --- a/tests/fixtures/sample_data.py +++ b/tests/fixtures/sample_data.py @@ -36,7 +36,8 @@ class SampleDataFactory: @staticmethod def create_finnhub_news_data( - ticker: str = "AAPL", count: int = 10, + ticker: str = "AAPL", + count: int = 10, ) -> dict[str, list[dict[str, Any]]]: """Create sample FinnHub news data for testing.""" base_date = datetime(2024, 5, 10) @@ -136,7 +137,8 @@ class SampleDataFactory: @staticmethod def create_financial_statements_data( - ticker: str = "AAPL", period: str = "annual", + ticker: str = "AAPL", + period: str = "annual", ) -> dict[str, list[dict[str, Any]]]: """Create sample financial statements data for testing.""" if period == "annual": @@ -271,10 +273,12 @@ class SampleDataFactory: ticker, ), "financial_annual": SampleDataFactory.create_financial_statements_data( - ticker, "annual", + ticker, + "annual", ), "financial_quarterly": SampleDataFactory.create_financial_statements_data( - ticker, "quarterly", + ticker, + "quarterly", ), "social_sentiment": SampleDataFactory.create_social_sentiment_data(ticker), "technical_indicators": SampleDataFactory.create_technical_indicators_data( @@ -343,7 +347,9 @@ def save_sample_data_to_files(base_path: str, ticker: str = "AAPL") -> None: # Save quarterly data separately quarterly_path = os.path.join( - finnhub_path, "fin_as_reported", f"{ticker}_quarterly_data_formatted.json", + finnhub_path, + "fin_as_reported", + f"{ticker}_quarterly_data_formatted.json", ) with open(quarterly_path, "w") as f: json.dump(dataset["financial_quarterly"], f, indent=2) diff --git a/tests/integration/test_full_workflow.py b/tests/integration/test_full_workflow.py index 23fbe3ce..bbe6c56a 100644 --- a/tests/integration/test_full_workflow.py +++ b/tests/integration/test_full_workflow.py @@ -31,7 +31,10 @@ class TestFullWorkflowIntegration: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_end_to_end_trading_workflow( - self, mock_toolkit, mock_chat_openai, integration_config, + self, + mock_toolkit, + mock_chat_openai, + integration_config, ): """Test complete end-to-end trading workflow.""" # Setup mocks @@ -86,7 +89,10 @@ class TestFullWorkflowIntegration: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_multiple_analysts_integration( - self, mock_toolkit, mock_chat_openai, integration_config, + self, + mock_toolkit, + mock_chat_openai, + integration_config, ): """Test integration with different analyst combinations.""" analyst_combinations = [ @@ -114,7 +120,8 @@ class TestFullWorkflowIntegration: with patch("tradingagents.graph.trading_graph.set_config"): # Test each analyst combination trading_graph = TradingAgentsGraph( - selected_analysts=analysts, config=integration_config, + selected_analysts=analysts, + config=integration_config, ) trading_graph.graph = mock_graph @@ -134,7 +141,8 @@ class TestFullWorkflowIntegration: # Execute with patch("builtins.open", create=True), patch("json.dump"): final_state, decision = trading_graph.propagate( - "TSLA", "2024-05-15", + "TSLA", + "2024-05-15", ) # Verify @@ -144,7 +152,10 @@ class TestFullWorkflowIntegration: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_memory_and_reflection_integration( - self, mock_toolkit, mock_chat_openai, integration_config, + self, + mock_toolkit, + mock_chat_openai, + integration_config, ): """Test integration of memory and reflection components.""" # Setup @@ -208,7 +219,10 @@ class TestFullWorkflowIntegration: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_debug_mode_integration( - self, mock_toolkit, mock_chat_openai, integration_config, + self, + mock_toolkit, + mock_chat_openai, + integration_config, ): """Test integration in debug mode.""" # Setup @@ -240,7 +254,8 @@ class TestFullWorkflowIntegration: with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.set_config"): trading_graph = TradingAgentsGraph( - debug=True, config=integration_config, + debug=True, + config=integration_config, ) trading_graph.graph = mock_graph @@ -276,7 +291,12 @@ class TestFullWorkflowIntegration: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_multiple_stocks_integration( - self, mock_toolkit, mock_chat_openai, ticker, date, integration_config, + self, + mock_toolkit, + mock_chat_openai, + ticker, + date, + integration_config, ): """Test integration with different stocks and dates.""" # Setup @@ -382,7 +402,11 @@ class TestPerformanceIntegration: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_multiple_consecutive_runs( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, + self, + mock_toolkit, + mock_chat_openai, + sample_config, + temp_data_dir, ): """Test multiple consecutive trading decisions.""" sample_config["project_dir"] = temp_data_dir diff --git a/tests/unit/agents/test_market_analyst.py b/tests/unit/agents/test_market_analyst.py index b44cf266..4dad8ac2 100644 --- a/tests/unit/agents/test_market_analyst.py +++ b/tests/unit/agents/test_market_analyst.py @@ -17,7 +17,10 @@ class TestMarketAnalyst: assert callable(analyst_node) def test_market_analyst_node_basic_execution( - self, mock_llm, mock_toolkit, sample_agent_state, + self, + mock_llm, + mock_toolkit, + sample_agent_state, ): """Test basic execution of market analyst node.""" # Setup @@ -39,7 +42,10 @@ class TestMarketAnalyst: assert result["market_report"] == "Market analysis complete" def test_market_analyst_uses_online_tools_when_configured( - self, mock_llm, mock_toolkit, sample_agent_state, + self, + mock_llm, + mock_toolkit, + sample_agent_state, ): """Test that analyst uses online tools when configured.""" # Setup @@ -64,7 +70,10 @@ class TestMarketAnalyst: assert "get_YFin_data_online" in str(tool_names) or len(bound_tools) == 2 def test_market_analyst_uses_offline_tools_when_configured( - self, mock_llm, mock_toolkit, sample_agent_state, + self, + mock_llm, + mock_toolkit, + sample_agent_state, ): """Test that analyst uses offline tools when configured.""" # Setup @@ -88,7 +97,10 @@ class TestMarketAnalyst: assert len(bound_tools) == 2 # Should have 2 offline tools def test_market_analyst_processes_state_variables( - self, mock_llm, mock_toolkit, sample_agent_state, + self, + mock_llm, + mock_toolkit, + sample_agent_state, ): """Test that market analyst correctly processes state variables.""" # Setup @@ -112,7 +124,10 @@ class TestMarketAnalyst: assert result["market_report"] == "Analysis for AAPL on 2024-05-10" def test_market_analyst_handles_empty_tool_calls( - self, mock_llm, mock_toolkit, sample_agent_state, + self, + mock_llm, + mock_toolkit, + sample_agent_state, ): """Test handling when no tool calls are made.""" # Setup @@ -132,7 +147,10 @@ class TestMarketAnalyst: assert result["messages"] == [mock_result] def test_market_analyst_with_tool_calls( - self, mock_llm, mock_toolkit, sample_agent_state, + self, + mock_llm, + mock_toolkit, + sample_agent_state, ): """Test handling when tool calls are present.""" # Setup @@ -153,7 +171,11 @@ class TestMarketAnalyst: @pytest.mark.parametrize("online_tools", [True, False]) def test_market_analyst_tool_configuration( - self, mock_llm, mock_toolkit, sample_agent_state, online_tools, + self, + mock_llm, + mock_toolkit, + sample_agent_state, + online_tools, ): """Test tool configuration for both online and offline modes.""" # Setup diff --git a/tests/unit/dataflows/test_finnhub_utils.py b/tests/unit/dataflows/test_finnhub_utils.py index 48953a40..4f651882 100644 --- a/tests/unit/dataflows/test_finnhub_utils.py +++ b/tests/unit/dataflows/test_finnhub_utils.py @@ -190,7 +190,10 @@ class TestFinnhubUtils: # Test without period expected_path_no_period = os.path.join( - temp_data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json", + temp_data_dir, + "finnhub_data", + data_type, + f"{ticker}_data_formatted.json", ) # Test with period @@ -248,7 +251,10 @@ class TestFinnhubUtils: ], ) def test_get_data_in_range_various_data_types( - self, temp_data_dir, data_type, period, + self, + temp_data_dir, + data_type, + period, ): """Test get_data_in_range with various data types.""" ticker = "TEST" diff --git a/tests/unit/graph/test_trading_graph.py b/tests/unit/graph/test_trading_graph.py index e5fbb923..e043d3aa 100644 --- a/tests/unit/graph/test_trading_graph.py +++ b/tests/unit/graph/test_trading_graph.py @@ -45,7 +45,11 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_init_with_debug( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, + self, + mock_toolkit, + mock_chat_openai, + sample_config, + temp_data_dir, ): """Test initialization with debug mode enabled.""" sample_config["project_dir"] = temp_data_dir @@ -63,7 +67,11 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatAnthropic") @patch("tradingagents.graph.trading_graph.Toolkit") def test_init_with_anthropic( - self, mock_toolkit, mock_chat_anthropic, sample_config, temp_data_dir, + self, + mock_toolkit, + mock_chat_anthropic, + sample_config, + temp_data_dir, ): """Test initialization with Anthropic LLM provider.""" sample_config["project_dir"] = temp_data_dir @@ -82,7 +90,11 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatGoogleGenerativeAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_init_with_google( - self, mock_toolkit, mock_chat_google, sample_config, temp_data_dir, + self, + mock_toolkit, + mock_chat_google, + sample_config, + temp_data_dir, ): """Test initialization with Google LLM provider.""" sample_config["project_dir"] = temp_data_dir @@ -100,7 +112,10 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.Toolkit") def test_init_unsupported_llm_provider( - self, mock_toolkit, sample_config, temp_data_dir, + self, + mock_toolkit, + sample_config, + temp_data_dir, ): """Test initialization with unsupported LLM provider raises error.""" sample_config["project_dir"] = temp_data_dir @@ -115,7 +130,11 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_create_tool_nodes( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, + self, + mock_toolkit, + mock_chat_openai, + sample_config, + temp_data_dir, ): """Test creation of tool nodes.""" sample_config["project_dir"] = temp_data_dir @@ -143,7 +162,11 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_propagate_basic( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, + self, + mock_toolkit, + mock_chat_openai, + sample_config, + temp_data_dir, ): """Test basic propagate functionality.""" sample_config["project_dir"] = temp_data_dir @@ -206,7 +229,11 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_propagate_debug_mode( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, + self, + mock_toolkit, + mock_chat_openai, + sample_config, + temp_data_dir, ): """Test propagate in debug mode.""" sample_config["project_dir"] = temp_data_dir @@ -245,7 +272,11 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_log_state( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, + self, + mock_toolkit, + mock_chat_openai, + sample_config, + temp_data_dir, ): """Test state logging functionality.""" sample_config["project_dir"] = temp_data_dir @@ -300,7 +331,11 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_reflect_and_remember( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, + self, + mock_toolkit, + mock_chat_openai, + sample_config, + temp_data_dir, ): """Test reflection and memory update functionality.""" sample_config["project_dir"] = temp_data_dir @@ -309,9 +344,12 @@ class TestTradingAgentsGraph: mock_toolkit_instance = Mock() mock_toolkit.return_value = mock_toolkit_instance - with patch( - "tradingagents.graph.trading_graph.FinancialSituationMemory", - ), patch("tradingagents.graph.trading_graph.set_config"): + with ( + patch( + "tradingagents.graph.trading_graph.FinancialSituationMemory", + ), + patch("tradingagents.graph.trading_graph.set_config"), + ): graph = TradingAgentsGraph(config=sample_config) # Set up current state @@ -339,7 +377,11 @@ class TestTradingAgentsGraph: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_process_signal( - self, mock_toolkit, mock_chat_openai, sample_config, temp_data_dir, + self, + mock_toolkit, + mock_chat_openai, + sample_config, + temp_data_dir, ): """Test signal processing functionality.""" sample_config["project_dir"] = temp_data_dir @@ -388,7 +430,8 @@ class TestTradingAgentsGraph: with patch("tradingagents.graph.trading_graph.FinancialSituationMemory"): with patch("tradingagents.graph.trading_graph.set_config"): TradingAgentsGraph( - selected_analysts=selected_analysts, config=sample_config, + selected_analysts=selected_analysts, + config=sample_config, ) # Verify graph was set up with selected analysts @@ -416,7 +459,10 @@ class TestTradingAgentsGraphErrorHandling: @patch("tradingagents.graph.trading_graph.ChatOpenAI") @patch("tradingagents.graph.trading_graph.Toolkit") def test_directory_creation_failure( - self, mock_toolkit, mock_chat_openai, sample_config, + self, + mock_toolkit, + mock_chat_openai, + sample_config, ): """Test handling when directory creation fails.""" sample_config["project_dir"] = "/invalid/path/that/cannot/be/created" diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index b4fa6b45..7af64723 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -20,7 +20,7 @@ def create_fundamentals_analyst(llm, toolkit): system_message = ( "You are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." - " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.", + " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.", ) prompt = ChatPromptTemplate.from_messages( diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index a741c66b..e3ed5f96 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -45,7 +45,7 @@ Volume-Based Indicators: - vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses. - Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions.""" - """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" + """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.""" ) prompt = ChatPromptTemplate.from_messages( diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index 598f890c..3e978c91 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -17,7 +17,7 @@ def create_news_analyst(llm, toolkit): system_message = ( "You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." - """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""" + """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""" ) prompt = ChatPromptTemplate.from_messages( diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index b5e3aa94..2ea84d05 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -16,7 +16,7 @@ def create_social_media_analyst(llm, toolkit): system_message = ( "You are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions." - """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""", + """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""", ) prompt = ChatPromptTemplate.from_messages( diff --git a/tradingagents/agents/risk_mgmt/aggresive_debator.py b/tradingagents/agents/risk_mgmt/aggresive_debator.py index 86c8c6b3..8da2725c 100644 --- a/tradingagents/agents/risk_mgmt/aggresive_debator.py +++ b/tradingagents/agents/risk_mgmt/aggresive_debator.py @@ -41,7 +41,8 @@ Engage actively by addressing any specific concerns raised, refuting the weaknes "current_risky_response": argument, "current_safe_response": risk_debate_state.get("current_safe_response", ""), "current_neutral_response": risk_debate_state.get( - "current_neutral_response", "", + "current_neutral_response", + "", ), "count": risk_debate_state["count"] + 1, } diff --git a/tradingagents/agents/risk_mgmt/conservative_debator.py b/tradingagents/agents/risk_mgmt/conservative_debator.py index c574d2d3..e48fa433 100644 --- a/tradingagents/agents/risk_mgmt/conservative_debator.py +++ b/tradingagents/agents/risk_mgmt/conservative_debator.py @@ -39,11 +39,13 @@ Engage by questioning their optimism and emphasizing the potential downsides the "neutral_history": risk_debate_state.get("neutral_history", ""), "latest_speaker": "Safe", "current_risky_response": risk_debate_state.get( - "current_risky_response", "", + "current_risky_response", + "", ), "current_safe_response": argument, "current_neutral_response": risk_debate_state.get( - "current_neutral_response", "", + "current_neutral_response", + "", ), "count": risk_debate_state["count"] + 1, } diff --git a/tradingagents/agents/risk_mgmt/neutral_debator.py b/tradingagents/agents/risk_mgmt/neutral_debator.py index f965a4e1..86b3c583 100644 --- a/tradingagents/agents/risk_mgmt/neutral_debator.py +++ b/tradingagents/agents/risk_mgmt/neutral_debator.py @@ -39,7 +39,8 @@ Engage actively by analyzing both sides critically, addressing weaknesses in the "neutral_history": neutral_history + "\n" + argument, "latest_speaker": "Neutral", "current_risky_response": risk_debate_state.get( - "current_risky_response", "", + "current_risky_response", + "", ), "current_safe_response": risk_debate_state.get("current_safe_response", ""), "current_neutral_response": argument, diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index d94ea284..89d77bdb 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -9,10 +9,12 @@ from typing_extensions import TypedDict # Researcher team state class InvestDebateState(TypedDict): bull_history: Annotated[ - str, "Bullish Conversation history", + str, + "Bullish Conversation history", ] # Bullish Conversation history bear_history: Annotated[ - str, "Bearish Conversation history", + str, + "Bearish Conversation history", ] # Bullish Conversation history history: Annotated[str, "Conversation history"] # Conversation history current_response: Annotated[str, "Latest response"] # Last response @@ -23,24 +25,30 @@ class InvestDebateState(TypedDict): # Risk management team state class RiskDebateState(TypedDict): risky_history: Annotated[ - str, "Risky Agent's Conversation history", + str, + "Risky Agent's Conversation history", ] # Conversation history safe_history: Annotated[ - str, "Safe Agent's Conversation history", + str, + "Safe Agent's Conversation history", ] # Conversation history neutral_history: Annotated[ - str, "Neutral Agent's Conversation history", + str, + "Neutral Agent's Conversation history", ] # Conversation history history: Annotated[str, "Conversation history"] # Conversation history latest_speaker: Annotated[str, "Analyst that spoke last"] current_risky_response: Annotated[ - str, "Latest response by the risky analyst", + str, + "Latest response by the risky analyst", ] # Last response current_safe_response: Annotated[ - str, "Latest response by the safe analyst", + str, + "Latest response by the safe analyst", ] # Last response current_neutral_response: Annotated[ - str, "Latest response by the neutral analyst", + str, + "Latest response by the neutral analyst", ] # Last response judge_decision: Annotated[str, "Judge's decision"] count: Annotated[int, "Length of the current conversation"] # Conversation length @@ -56,13 +64,15 @@ class AgentState(MessagesState): market_report: Annotated[str, "Report from the Market Analyst"] sentiment_report: Annotated[str, "Report from the Social Media Analyst"] news_report: Annotated[ - str, "Report from the News Researcher of current world affairs", + str, + "Report from the News Researcher of current world affairs", ] fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"] # researcher team discussion step investment_debate_state: Annotated[ - InvestDebateState, "Current state of the debate on if to invest or not", + InvestDebateState, + "Current state of the debate on if to invest or not", ] investment_plan: Annotated[str, "Plan generated by the Analyst"] @@ -70,6 +80,7 @@ class AgentState(MessagesState): # risk management team discussion step risk_debate_state: Annotated[ - RiskDebateState, "Current state of the debate on evaluating risk", + RiskDebateState, + "Current state of the debate on evaluating risk", ] final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"] diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index 7978ce48..5ed3f1c3 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -56,7 +56,6 @@ class Toolkit: return interface.get_reddit_global_news(curr_date, 7, 5) - @staticmethod @tool def get_finnhub_news( @@ -84,10 +83,11 @@ class Toolkit: look_back_days = (end_date - start_date).days return interface.get_finnhub_news( - ticker, end_date_str, look_back_days, + ticker, + end_date_str, + look_back_days, ) - @staticmethod @tool def get_reddit_stock_info( @@ -108,7 +108,6 @@ class Toolkit: return interface.get_reddit_company_news(ticker, curr_date, 7, 5) - @staticmethod @tool def get_YFin_data( @@ -128,7 +127,6 @@ class Toolkit: return interface.get_YFin_data(symbol, start_date, end_date) - @staticmethod @tool def get_YFin_data_online( @@ -148,16 +146,17 @@ class Toolkit: return interface.get_YFin_data_online(symbol, start_date, end_date) - @staticmethod @tool def get_stockstats_indicators_report( symbol: Annotated[str, "ticker symbol of the company"], indicator: Annotated[ - str, "technical indicator to get the analysis and report of", + str, + "technical indicator to get the analysis and report of", ], curr_date: Annotated[ - str, "The current trading date you are trading on, YYYY-mm-dd", + str, + "The current trading date you are trading on, YYYY-mm-dd", ], look_back_days: Annotated[int, "how many days to look back"] = 30, ) -> str: @@ -173,19 +172,24 @@ class Toolkit: """ return interface.get_stock_stats_indicators_window( - symbol, indicator, curr_date, look_back_days, False, + symbol, + indicator, + curr_date, + look_back_days, + False, ) - @staticmethod @tool def get_stockstats_indicators_report_online( symbol: Annotated[str, "ticker symbol of the company"], indicator: Annotated[ - str, "technical indicator to get the analysis and report of", + str, + "technical indicator to get the analysis and report of", ], curr_date: Annotated[ - str, "The current trading date you are trading on, YYYY-mm-dd", + str, + "The current trading date you are trading on, YYYY-mm-dd", ], look_back_days: Annotated[int, "how many days to look back"] = 30, ) -> str: @@ -201,10 +205,13 @@ class Toolkit: """ return interface.get_stock_stats_indicators_window( - symbol, indicator, curr_date, look_back_days, True, + symbol, + indicator, + curr_date, + look_back_days, + True, ) - @staticmethod @tool def get_finnhub_company_insider_sentiment( @@ -224,10 +231,11 @@ class Toolkit: """ return interface.get_finnhub_company_insider_sentiment( - ticker, curr_date, 30, + ticker, + curr_date, + 30, ) - @staticmethod @tool def get_finnhub_company_insider_transactions( @@ -247,10 +255,11 @@ class Toolkit: """ return interface.get_finnhub_company_insider_transactions( - ticker, curr_date, 30, + ticker, + curr_date, + 30, ) - @staticmethod @tool def get_simfin_balance_sheet( @@ -273,7 +282,6 @@ class Toolkit: return interface.get_simfin_balance_sheet(ticker, freq, curr_date) - @staticmethod @tool def get_simfin_cashflow( @@ -296,7 +304,6 @@ class Toolkit: return interface.get_simfin_cashflow(ticker, freq, curr_date) - @staticmethod @tool def get_simfin_income_stmt( @@ -318,10 +325,11 @@ class Toolkit: """ return interface.get_simfin_income_statements( - ticker, freq, curr_date, + ticker, + freq, + curr_date, ) - @staticmethod @tool def get_google_news( @@ -340,7 +348,6 @@ class Toolkit: return interface.get_google_news(query, curr_date, 7) - @staticmethod @tool def get_stock_news_openai( @@ -358,7 +365,6 @@ class Toolkit: return interface.get_stock_news_openai(ticker, curr_date) - @staticmethod @tool def get_global_news_openai( @@ -374,7 +380,6 @@ class Toolkit: return interface.get_global_news_openai(curr_date) - @staticmethod @tool def get_fundamentals_openai( @@ -391,6 +396,6 @@ class Toolkit: """ return interface.get_fundamentals_openai( - ticker, curr_date, + ticker, + curr_date, ) - diff --git a/tradingagents/config.py b/tradingagents/config.py index 77e1367c..75c1b5d4 100644 --- a/tradingagents/config.py +++ b/tradingagents/config.py @@ -21,7 +21,8 @@ def get_config(): "project_dir": str(project_root / "tradingagents"), "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), "data_dir": os.getenv( - "TRADINGAGENTS_DATA_DIR", "/Users/yluo/Documents/Code/ScAI/FR1-data", + "TRADINGAGENTS_DATA_DIR", + "/Users/yluo/Documents/Code/ScAI/FR1-data", ), "data_cache_dir": str( project_root / "tradingagents" / "dataflows" / "data_cache", diff --git a/tradingagents/dataflows/config.py b/tradingagents/dataflows/config.py index b3adcaf2..dec03021 100644 --- a/tradingagents/dataflows/config.py +++ b/tradingagents/dataflows/config.py @@ -1,4 +1,3 @@ - from tradingagents import default_config # Use default config but allow it to be overridden diff --git a/tradingagents/dataflows/finnhub_utils.py b/tradingagents/dataflows/finnhub_utils.py index 7d0a1e30..4bb6422e 100644 --- a/tradingagents/dataflows/finnhub_utils.py +++ b/tradingagents/dataflows/finnhub_utils.py @@ -22,7 +22,10 @@ def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period= ) else: data_path = os.path.join( - data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json", + data_dir, + "finnhub_data", + data_type, + f"{ticker}_data_formatted.json", ) data = open(data_path) diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 5de638cb..2e27e558 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -419,7 +419,8 @@ def get_stock_stats_indicators_window( symbol: Annotated[str, "ticker symbol of the company"], indicator: Annotated[str, "technical indicator to get the analysis and report of"], curr_date: Annotated[ - str, "The current trading date you are trading on, YYYY-mm-dd", + str, + "The current trading date you are trading on, YYYY-mm-dd", ], look_back_days: Annotated[int, "how many days to look back"], online: Annotated[bool, "to fetch data online or offline"], @@ -524,7 +525,10 @@ def get_stock_stats_indicators_window( # only do the trading dates if curr_date.strftime("%Y-%m-%d") in dates_in_df.values: indicator_value = get_stockstats_indicator( - symbol, indicator, curr_date.strftime("%Y-%m-%d"), online, + symbol, + indicator, + curr_date.strftime("%Y-%m-%d"), + online, ) ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n" @@ -535,7 +539,10 @@ def get_stock_stats_indicators_window( ind_string = "" while curr_date >= before: indicator_value = get_stockstats_indicator( - symbol, indicator, curr_date.strftime("%Y-%m-%d"), online, + symbol, + indicator, + curr_date.strftime("%Y-%m-%d"), + online, ) ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n" @@ -550,12 +557,12 @@ def get_stock_stats_indicators_window( ) - def get_stockstats_indicator( symbol: Annotated[str, "ticker symbol of the company"], indicator: Annotated[str, "technical indicator to get the analysis and report of"], curr_date: Annotated[ - str, "The current trading date you are trading on, YYYY-mm-dd", + str, + "The current trading date you are trading on, YYYY-mm-dd", ], online: Annotated[bool, "to fetch data online or offline"], ) -> str: @@ -608,7 +615,12 @@ def get_YFin_data_window( # Set pandas display options to show the full DataFrame with pd.option_context( - "display.max_rows", None, "display.max_columns", None, "display.width", None, + "display.max_rows", + None, + "display.max_columns", + None, + "display.width", + None, ): df_string = filtered_data.to_string() @@ -694,7 +706,6 @@ def get_YFin_data( return filtered_data.reset_index(drop=True) - def get_stock_news_openai(ticker, curr_date): config = get_config() client = OpenAI(base_url=config["backend_url"]) diff --git a/tradingagents/dataflows/reddit_utils.py b/tradingagents/dataflows/reddit_utils.py index 5710073c..efda0634 100644 --- a/tradingagents/dataflows/reddit_utils.py +++ b/tradingagents/dataflows/reddit_utils.py @@ -48,11 +48,14 @@ ticker_to_company = { def fetch_top_from_category( category: Annotated[ - str, "Category to fetch top post from. Collection of subreddits.", + str, + "Category to fetch top post from. Collection of subreddits.", ], date: Annotated[str, "Date to fetch top posts from."], max_limit: Annotated[int, "Maximum number of posts to fetch."], - query: Annotated[str | None, "Optional query to search for in the subreddit."] = None, + query: Annotated[ + str | None, "Optional query to search for in the subreddit." + ] = None, data_path: Annotated[ str, "Path to the data folder. Default is 'reddit_data'.", @@ -107,7 +110,9 @@ def fetch_top_from_category( found = False for term in search_terms: if re.search( - term, parsed_line["title"], re.IGNORECASE, + term, + parsed_line["title"], + re.IGNORECASE, ) or re.search(term, parsed_line["selftext"], re.IGNORECASE): found = True break diff --git a/tradingagents/dataflows/stockstats_utils.py b/tradingagents/dataflows/stockstats_utils.py index a36e150f..06858522 100644 --- a/tradingagents/dataflows/stockstats_utils.py +++ b/tradingagents/dataflows/stockstats_utils.py @@ -13,10 +13,12 @@ class StockstatsUtils: def get_stock_stats( symbol: Annotated[str, "ticker symbol for the company"], indicator: Annotated[ - str, "quantitative indicators based off of the stock data for the company", + str, + "quantitative indicators based off of the stock data for the company", ], curr_date: Annotated[ - str, "curr date for retrieving stock price data, YYYY-mm-dd", + str, + "curr date for retrieving stock price data, YYYY-mm-dd", ], data_dir: Annotated[ str, diff --git a/tradingagents/dataflows/yfin_utils.py b/tradingagents/dataflows/yfin_utils.py index f1a69df3..3b835982 100644 --- a/tradingagents/dataflows/yfin_utils.py +++ b/tradingagents/dataflows/yfin_utils.py @@ -28,10 +28,12 @@ class YFinanceUtils: def get_stock_data( self: Annotated[str, "ticker symbol"], start_date: Annotated[ - str, "start date for retrieving stock price data, YYYY-mm-dd", + str, + "start date for retrieving stock price data, YYYY-mm-dd", ], end_date: Annotated[ - str, "end date for retrieving stock price data, YYYY-mm-dd", + str, + "end date for retrieving stock price data, YYYY-mm-dd", ], save_path: SavePathType = None, ) -> DataFrame: diff --git a/tradingagents/graph/propagation.py b/tradingagents/graph/propagation.py index dc522af8..36038f6d 100644 --- a/tradingagents/graph/propagation.py +++ b/tradingagents/graph/propagation.py @@ -16,7 +16,9 @@ class Propagator: self.max_recur_limit = max_recur_limit def create_initial_state( - self, company_name: str, trade_date: str, + self, + company_name: str, + trade_date: str, ) -> dict[str, Any]: """Create the initial state for the agent graph.""" return { diff --git a/tradingagents/graph/reflection.py b/tradingagents/graph/reflection.py index 04b66224..496ea2fc 100644 --- a/tradingagents/graph/reflection.py +++ b/tradingagents/graph/reflection.py @@ -57,7 +57,11 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}" def _reflect_on_component( - self, component_type: str, report: str, situation: str, returns_losses, + self, + component_type: str, + report: str, + situation: str, + returns_losses, ) -> str: """Generate reflection for a component.""" messages = [ @@ -76,7 +80,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur bull_debate_history = current_state["investment_debate_state"]["bull_history"] result = self._reflect_on_component( - "BULL", bull_debate_history, situation, returns_losses, + "BULL", + bull_debate_history, + situation, + returns_losses, ) bull_memory.add_situations([(situation, result)]) @@ -86,7 +93,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur bear_debate_history = current_state["investment_debate_state"]["bear_history"] result = self._reflect_on_component( - "BEAR", bear_debate_history, situation, returns_losses, + "BEAR", + bear_debate_history, + situation, + returns_losses, ) bear_memory.add_situations([(situation, result)]) @@ -96,7 +106,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur trader_decision = current_state["trader_investment_plan"] result = self._reflect_on_component( - "TRADER", trader_decision, situation, returns_losses, + "TRADER", + trader_decision, + situation, + returns_losses, ) trader_memory.add_situations([(situation, result)]) @@ -106,7 +119,10 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur judge_decision = current_state["investment_debate_state"]["judge_decision"] result = self._reflect_on_component( - "INVEST JUDGE", judge_decision, situation, returns_losses, + "INVEST JUDGE", + judge_decision, + situation, + returns_losses, ) invest_judge_memory.add_situations([(situation, result)]) @@ -116,6 +132,9 @@ Adhere strictly to these instructions, and ensure your output is detailed, accur judge_decision = current_state["risk_debate_state"]["judge_decision"] result = self._reflect_on_component( - "RISK JUDGE", judge_decision, situation, returns_losses, + "RISK JUDGE", + judge_decision, + situation, + returns_losses, ) risk_manager_memory.add_situations([(situation, result)]) diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index c5f882fa..6ef2b4cd 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -55,7 +55,8 @@ class GraphSetup: self.conditional_logic = conditional_logic def setup_graph( - self, selected_analysts=None, + self, + selected_analysts=None, ): """Set up and compile the agent workflow graph. @@ -79,41 +80,48 @@ class GraphSetup: if "market" in selected_analysts: analyst_nodes["market"] = create_market_analyst( - self.quick_thinking_llm, self.toolkit, + self.quick_thinking_llm, + self.toolkit, ) delete_nodes["market"] = create_msg_delete() tool_nodes["market"] = self.tool_nodes["market"] if "social" in selected_analysts: analyst_nodes["social"] = create_social_media_analyst( - self.quick_thinking_llm, self.toolkit, + self.quick_thinking_llm, + self.toolkit, ) delete_nodes["social"] = create_msg_delete() tool_nodes["social"] = self.tool_nodes["social"] if "news" in selected_analysts: analyst_nodes["news"] = create_news_analyst( - self.quick_thinking_llm, self.toolkit, + self.quick_thinking_llm, + self.toolkit, ) delete_nodes["news"] = create_msg_delete() tool_nodes["news"] = self.tool_nodes["news"] if "fundamentals" in selected_analysts: analyst_nodes["fundamentals"] = create_fundamentals_analyst( - self.quick_thinking_llm, self.toolkit, + self.quick_thinking_llm, + self.toolkit, ) delete_nodes["fundamentals"] = create_msg_delete() tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"] # Create researcher and manager nodes bull_researcher_node = create_bull_researcher( - self.quick_thinking_llm, self.bull_memory, + self.quick_thinking_llm, + self.bull_memory, ) bear_researcher_node = create_bear_researcher( - self.quick_thinking_llm, self.bear_memory, + self.quick_thinking_llm, + self.bear_memory, ) research_manager_node = create_research_manager( - self.deep_thinking_llm, self.invest_judge_memory, + self.deep_thinking_llm, + self.invest_judge_memory, ) trader_node = create_trader(self.quick_thinking_llm, self.trader_memory) @@ -122,7 +130,8 @@ class GraphSetup: neutral_analyst = create_neutral_debator(self.quick_thinking_llm) safe_analyst = create_safe_debator(self.quick_thinking_llm) risk_manager_node = create_risk_manager( - self.deep_thinking_llm, self.risk_manager_memory, + self.deep_thinking_llm, + self.risk_manager_memory, ) # Create workflow @@ -132,7 +141,8 @@ class GraphSetup: for analyst_type, node in analyst_nodes.items(): workflow.add_node(f"{analyst_type.capitalize()} Analyst", node) workflow.add_node( - f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type], + f"Msg Clear {analyst_type.capitalize()}", + delete_nodes[analyst_type], ) workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type]) diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index c656b6b3..2594bf25 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -59,7 +59,8 @@ class TradingAgentsGraph: or self.config["llm_provider"] == "openrouter" ): self.deep_thinking_llm = ChatOpenAI( - model=self.config["deep_think_llm"], base_url=self.config["backend_url"], + model=self.config["deep_think_llm"], + base_url=self.config["backend_url"], ) self.quick_thinking_llm = ChatOpenAI( model=self.config["quick_think_llm"], @@ -67,7 +68,8 @@ class TradingAgentsGraph: ) elif self.config["llm_provider"].lower() == "anthropic": self.deep_thinking_llm = ChatAnthropic( - model=self.config["deep_think_llm"], base_url=self.config["backend_url"], + model=self.config["deep_think_llm"], + base_url=self.config["backend_url"], ) self.quick_thinking_llm = ChatAnthropic( model=self.config["quick_think_llm"], @@ -91,10 +93,12 @@ class TradingAgentsGraph: self.bear_memory = FinancialSituationMemory("bear_memory", self.config) self.trader_memory = FinancialSituationMemory("trader_memory", self.config) self.invest_judge_memory = FinancialSituationMemory( - "invest_judge_memory", self.config, + "invest_judge_memory", + self.config, ) self.risk_manager_memory = FinancialSituationMemory( - "risk_manager_memory", self.config, + "risk_manager_memory", + self.config, ) # Create tool nodes @@ -179,7 +183,8 @@ class TradingAgentsGraph: # Initialize state init_agent_state = self.propagator.create_initial_state( - company_name, trade_date, + company_name, + trade_date, ) args = self.propagator.get_graph_args() @@ -252,19 +257,29 @@ class TradingAgentsGraph: def reflect_and_remember(self, returns_losses): """Reflect on decisions and update memory based on returns.""" self.reflector.reflect_bull_researcher( - self.curr_state, returns_losses, self.bull_memory, + self.curr_state, + returns_losses, + self.bull_memory, ) self.reflector.reflect_bear_researcher( - self.curr_state, returns_losses, self.bear_memory, + self.curr_state, + returns_losses, + self.bear_memory, ) self.reflector.reflect_trader( - self.curr_state, returns_losses, self.trader_memory, + self.curr_state, + returns_losses, + self.trader_memory, ) self.reflector.reflect_invest_judge( - self.curr_state, returns_losses, self.invest_judge_memory, + self.curr_state, + returns_losses, + self.invest_judge_memory, ) self.reflector.reflect_risk_manager( - self.curr_state, returns_losses, self.risk_manager_memory, + self.curr_state, + returns_losses, + self.risk_manager_memory, ) def process_signal(self, full_signal):