# TradingAgents/tests/test_discovery_refactor.py

#!/usr/bin/env python3
"""
Test script to verify DiscoveryGraph refactoring.
Tests: LLM Factory, CandidateFilter, CandidateRanker
"""
import os
import sys
from pathlib import Path

# Make the project root importable no matter where the script is launched from.
project_root = Path(__file__).parents[1]
sys.path.insert(0, str(project_root))
def test_llm_factory():
    """Test LLM factory initialization.

    Returns:
        bool: True if ``create_llms`` builds both LLMs, False on any failure.
    """
    print("\n=== Testing LLM Factory ===")
    try:
        from tradingagents.utils.llm_factory import create_llms

        # Placeholder key so client construction does not fail; setdefault
        # leaves a real key untouched if one is already in the environment.
        os.environ.setdefault("OPENAI_API_KEY", "sk-test-key")
        config = {
            "llm_provider": "openai",
            "deep_think_llm": "gpt-4",
            "quick_think_llm": "gpt-3.5-turbo",
        }
        deep_llm, quick_llm = create_llms(config)
        assert deep_llm is not None, "Deep LLM should be initialized"
        assert quick_llm is not None, "Quick LLM should be initialized"
        print("✅ LLM Factory: Successfully creates LLMs")
        return True
    except Exception as e:
        # Broad catch is deliberate: this is a smoke test that reports
        # pass/fail rather than crashing the whole run.
        print(f"❌ LLM Factory: Failed - {e}")
        return False
def test_candidate_filter():
    """Test CandidateFilter class construction with a mocked tool executor.

    Returns:
        bool: True if CandidateFilter initializes and stores its executor,
        False on any failure.
    """
    print("\n=== Testing CandidateFilter ===")
    try:
        from unittest.mock import MagicMock

        from tradingagents.dataflows.discovery.filter import CandidateFilter

        config = {"discovery": {}}
        mock_executor = MagicMock()
        filter_obj = CandidateFilter(config, mock_executor)
        assert hasattr(filter_obj, 'filter'), "Filter should have filter method"
        assert filter_obj.execute_tool == mock_executor, "Should store executor"
        print("✅ CandidateFilter: Successfully initialized")
        return True
    except Exception as e:
        print(f"❌ CandidateFilter: Failed - {e}")
        import traceback
        traceback.print_exc()
        return False
def test_candidate_ranker():
    """Test CandidateRanker class construction with mocked LLM and analytics.

    Returns:
        bool: True if CandidateRanker initializes and stores its LLM,
        False on any failure.
    """
    print("\n=== Testing CandidateRanker ===")
    try:
        from unittest.mock import MagicMock

        from tradingagents.dataflows.discovery.ranker import CandidateRanker

        config = {"discovery": {}}
        mock_llm = MagicMock()
        mock_analytics = MagicMock()
        ranker = CandidateRanker(config, mock_llm, mock_analytics)
        assert hasattr(ranker, 'rank'), "Ranker should have rank method"
        assert ranker.llm == mock_llm, "Should store LLM"
        print("✅ CandidateRanker: Successfully initialized")
        return True
    except Exception as e:
        print(f"❌ CandidateRanker: Failed - {e}")
        import traceback
        traceback.print_exc()
        return False
def test_discovery_graph_import():
    """Test that DiscoveryGraph still imports and initializes correctly.

    Returns:
        bool: True if DiscoveryGraph builds with the refactored components,
        False on any failure.
    """
    print("\n=== Testing DiscoveryGraph Import ===")
    try:
        from tradingagents.graph.discovery_graph import DiscoveryGraph

        # Placeholder key so LLM client construction does not fail.
        os.environ.setdefault("OPENAI_API_KEY", "sk-test-key")
        config = {
            "llm_provider": "openai",
            "deep_think_llm": "gpt-4",
            "quick_think_llm": "gpt-3.5-turbo",
            "backend_url": "https://api.openai.com/v1",
            "discovery": {},
        }
        graph = DiscoveryGraph(config=config)
        assert hasattr(graph, 'deep_thinking_llm'), "Should have deep LLM"
        assert hasattr(graph, 'quick_thinking_llm'), "Should have quick LLM"
        assert hasattr(graph, 'analytics'), "Should have analytics"
        assert hasattr(graph, 'graph'), "Should have graph"
        print("✅ DiscoveryGraph: Successfully initialized with refactored components")
        return True
    except Exception as e:
        print(f"❌ DiscoveryGraph: Failed - {e}")
        import traceback
        traceback.print_exc()
        return False
def test_trading_graph_import():
    """Test that TradingAgentsGraph still imports and initializes correctly.

    Returns:
        bool: True if TradingAgentsGraph builds via the LLM factory,
        False on any failure.
    """
    print("\n=== Testing TradingAgentsGraph Import ===")
    try:
        from tradingagents.graph.trading_graph import TradingAgentsGraph

        # Placeholder key so LLM client construction does not fail.
        os.environ.setdefault("OPENAI_API_KEY", "sk-test-key")
        config = {
            "llm_provider": "openai",
            "deep_think_llm": "gpt-4",
            "quick_think_llm": "gpt-3.5-turbo",
            # project_root is set at module level when the path is configured.
            "project_dir": str(project_root),
            "enable_memory": False,
        }
        graph = TradingAgentsGraph(config=config)
        assert hasattr(graph, 'deep_thinking_llm'), "Should have deep LLM"
        assert hasattr(graph, 'quick_thinking_llm'), "Should have quick LLM"
        print("✅ TradingAgentsGraph: Successfully initialized with LLM factory")
        return True
    except Exception as e:
        print(f"❌ TradingAgentsGraph: Failed - {e}")
        import traceback
        traceback.print_exc()
        return False
def test_utils():
    """Test discovery utility functions (ticker validation, RSI extraction).

    Returns:
        bool: True if all utility checks pass, False on any failure.
    """
    print("\n=== Testing Utilities ===")
    try:
        from tradingagents.dataflows.discovery.utils import (
            extract_technical_summary,
            is_valid_ticker,
        )

        # Ticker validation: plain symbols pass, warrant/rights suffixes fail.
        assert is_valid_ticker("AAPL"), "AAPL should be valid"
        assert not is_valid_ticker("AAPL.WS"), "Warrant should be invalid"
        assert not is_valid_ticker("AAPL-RT"), "Rights should be invalid"

        # Technical summary extraction: accept either rounding of 45.5.
        tech_report = "RSI Value: 45.5"
        summary = extract_technical_summary(tech_report)
        assert "RSI:45" in summary or "RSI:46" in summary, "Should extract RSI"
        print("✅ Utils: All utility functions work correctly")
        return True
    except Exception as e:
        print(f"❌ Utils: Failed - {e}")
        import traceback
        traceback.print_exc()
        return False
def main():
    """Run all refactoring verification tests and print a summary.

    Returns:
        int: 0 if every test passed, 1 otherwise (suitable for sys.exit).
    """
    print("=" * 60)
    print("DISCOVERY GRAPH REFACTORING VERIFICATION")
    print("=" * 60)

    # (name, passed) pairs, in execution order.
    results = [
        ("LLM Factory", test_llm_factory()),
        ("Candidate Filter", test_candidate_filter()),
        ("Candidate Ranker", test_candidate_ranker()),
        ("Utils", test_utils()),
        ("DiscoveryGraph", test_discovery_graph_import()),
        ("TradingAgentsGraph", test_trading_graph_import()),
    ]

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    passed = sum(1 for _, result in results if result)
    total = len(results)
    for name, result in results:
        status = "✅ PASS" if result else "❌ FAIL"
        print(f"{status}: {name}")
    print(f"\n{passed}/{total} tests passed")
    if passed == total:
        print("\n🎉 All refactoring tests passed!")
        return 0
    print(f"\n⚠️ {total - passed} test(s) failed")
    return 1
if __name__ == "__main__":
    # Propagate the pass/fail status to the shell (0 = all tests passed).
    sys.exit(main())