#!/usr/bin/env python3 """ Test script to verify DiscoveryGraph refactoring. Tests: LLM Factory, CandidateFilter, CandidateRanker """ import os import sys from pathlib import Path # Add project root to path project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) def test_llm_factory(): """Test LLM factory initialization.""" print("\n=== Testing LLM Factory ===") try: from tradingagents.utils.llm_factory import create_llms # Mock API key os.environ.setdefault("OPENAI_API_KEY", "sk-test-key") config = { "llm_provider": "openai", "deep_think_llm": "gpt-4", "quick_think_llm": "gpt-3.5-turbo" } deep_llm, quick_llm = create_llms(config) assert deep_llm is not None, "Deep LLM should be initialized" assert quick_llm is not None, "Quick LLM should be initialized" print("✅ LLM Factory: Successfully creates LLMs") return True except Exception as e: print(f"❌ LLM Factory: Failed - {e}") return False def test_candidate_filter(): """Test CandidateFilter class.""" print("\n=== Testing CandidateFilter ===") try: from unittest.mock import MagicMock from tradingagents.dataflows.discovery.filter import CandidateFilter config = {"discovery": {}} mock_executor = MagicMock() filter_obj = CandidateFilter(config, mock_executor) assert hasattr(filter_obj, 'filter'), "Filter should have filter method" assert filter_obj.execute_tool == mock_executor, "Should store executor" print("✅ CandidateFilter: Successfully initialized") return True except Exception as e: print(f"❌ CandidateFilter: Failed - {e}") import traceback traceback.print_exc() return False def test_candidate_ranker(): """Test CandidateRanker class.""" print("\n=== Testing CandidateRanker ===") try: from unittest.mock import MagicMock from tradingagents.dataflows.discovery.ranker import CandidateRanker config = {"discovery": {}} mock_llm = MagicMock() mock_analytics = MagicMock() ranker = CandidateRanker(config, mock_llm, mock_analytics) assert hasattr(ranker, 'rank'), "Ranker should have rank method" assert ranker.llm == mock_llm, "Should store LLM" print("✅ CandidateRanker: Successfully initialized") return True except Exception as e: print(f"❌ CandidateRanker: Failed - {e}") import traceback traceback.print_exc() return False def test_discovery_graph_import(): """Test that DiscoveryGraph still imports correctly.""" print("\n=== Testing DiscoveryGraph Import ===") try: from tradingagents.graph.discovery_graph import DiscoveryGraph # Mock API key os.environ.setdefault("OPENAI_API_KEY", "sk-test-key") config = { "llm_provider": "openai", "deep_think_llm": "gpt-4", "quick_think_llm": "gpt-3.5-turbo", "backend_url": "https://api.openai.com/v1", "discovery": {} } graph = DiscoveryGraph(config=config) assert hasattr(graph, 'deep_thinking_llm'), "Should have deep LLM" assert hasattr(graph, 'quick_thinking_llm'), "Should have quick LLM" assert hasattr(graph, 'analytics'), "Should have analytics" assert hasattr(graph, 'graph'), "Should have graph" print("✅ DiscoveryGraph: Successfully initialized with refactored components") return True except Exception as e: print(f"❌ DiscoveryGraph: Failed - {e}") import traceback traceback.print_exc() return False def test_trading_graph_import(): """Test that TradingAgentsGraph still imports correctly.""" print("\n=== Testing TradingAgentsGraph Import ===") try: from tradingagents.graph.trading_graph import TradingAgentsGraph # Mock API key os.environ.setdefault("OPENAI_API_KEY", "sk-test-key") config = { "llm_provider": "openai", "deep_think_llm": "gpt-4", "quick_think_llm": "gpt-3.5-turbo", "project_dir": str(project_root), "enable_memory": False } graph = TradingAgentsGraph(config=config) assert hasattr(graph, 'deep_thinking_llm'), "Should have deep LLM" assert hasattr(graph, 'quick_thinking_llm'), "Should have quick LLM" print("✅ TradingAgentsGraph: Successfully initialized with LLM factory") return True except Exception as e: print(f"❌ TradingAgentsGraph: Failed - {e}") import traceback traceback.print_exc() return False def test_utils(): """Test utility functions.""" print("\n=== Testing Utilities ===") try: from tradingagents.dataflows.discovery.utils import ( extract_technical_summary, is_valid_ticker, ) # Test ticker validation assert is_valid_ticker("AAPL") == True, "AAPL should be valid" assert is_valid_ticker("AAPL.WS") == False, "Warrant should be invalid" assert is_valid_ticker("AAPL-RT") == False, "Rights should be invalid" # Test technical summary extraction tech_report = "RSI Value: 45.5" summary = extract_technical_summary(tech_report) assert "RSI:45" in summary or "RSI:46" in summary, "Should extract RSI" print("✅ Utils: All utility functions work correctly") return True except Exception as e: print(f"❌ Utils: Failed - {e}") import traceback traceback.print_exc() return False def main(): """Run all tests.""" print("=" * 60) print("DISCOVERY GRAPH REFACTORING VERIFICATION") print("=" * 60) results = [] # Run all tests results.append(("LLM Factory", test_llm_factory())) results.append(("Candidate Filter", test_candidate_filter())) results.append(("Candidate Ranker", test_candidate_ranker())) results.append(("Utils", test_utils())) results.append(("DiscoveryGraph", test_discovery_graph_import())) results.append(("TradingAgentsGraph", test_trading_graph_import())) # Summary print("\n" + "=" * 60) print("SUMMARY") print("=" * 60) passed = sum(1 for _, result in results if result) total = len(results) for name, result in results: status = "✅ PASS" if result else "❌ FAIL" print(f"{status}: {name}") print(f"\n{passed}/{total} tests passed") if passed == total: print("\n🎉 All refactoring tests passed!") return 0 else: print(f"\n⚠️ {total - passed} test(s) failed") return 1 if __name__ == "__main__": sys.exit(main())