702 lines
27 KiB
Python
702 lines
27 KiB
Python
"""
|
|
Test suite for CLI Error Handling with Rate Limit Errors.
|
|
|
|
This module tests:
|
|
1. Rate limit errors are caught and logged in main.py
|
|
2. Partial analysis is saved to JSON file when error occurs
|
|
3. User sees appropriate error message with retry guidance
|
|
4. Both terminal and file receive error logs
|
|
5. Integration with graph.stream() error handling
|
|
6. Error translation from provider errors to unified exceptions
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import pytest
|
|
import tempfile
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch, MagicMock, call
|
|
|
|
|
|
# ============================================================================
|
|
# Fixtures
|
|
# ============================================================================
|
|
|
|
@pytest.fixture
|
|
def temp_output_dir():
|
|
"""Create a temporary directory for output files."""
|
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
yield Path(tmpdir)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_graph():
|
|
"""Create a mock TradingAgentsGraph."""
|
|
mock = Mock()
|
|
mock.propagate = Mock()
|
|
return mock
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_config():
|
|
"""Create a mock configuration."""
|
|
return {
|
|
"llm_provider": "openrouter",
|
|
"deep_think_llm": "anthropic/claude-opus-4.5",
|
|
"quick_think_llm": "anthropic/claude-haiku-3.5",
|
|
"backend_url": "https://openrouter.ai/api/v1",
|
|
"max_debate_rounds": 1,
|
|
"data_vendors": {
|
|
"core_stock_apis": "yfinance",
|
|
"technical_indicators": "yfinance",
|
|
"fundamental_data": "yfinance",
|
|
"news_data": "google",
|
|
}
|
|
}
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_partial_state():
|
|
"""Create a sample partial state for testing."""
|
|
return {
|
|
"ticker": "AAPL",
|
|
"analysis_date": "2024-12-26",
|
|
"messages": [
|
|
{"role": "system", "content": "Starting analysis"},
|
|
{"role": "assistant", "content": "Fetched market data"},
|
|
],
|
|
"analyst_reports": {
|
|
"market": {"summary": "Bullish trend", "confidence": 0.8}
|
|
},
|
|
"error": "Rate limit exceeded",
|
|
"error_timestamp": datetime.now().isoformat()
|
|
}
|
|
|
|
|
|
# ============================================================================
|
|
# Test Rate Limit Error Catching in main.py
|
|
# ============================================================================
|
|
|
|
class TestMainRateLimitErrorHandling:
|
|
"""Test error handling in main.py around graph.stream()."""
|
|
|
|
@patch('main.TradingAgentsGraph')
|
|
def test_catches_rate_limit_error_from_openai(self, mock_graph_class, temp_output_dir):
|
|
"""Test that OpenAI rate limit errors are caught in main.py."""
|
|
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
|
|
|
# Setup mock to raise rate limit error
|
|
mock_instance = Mock()
|
|
mock_instance.propagate.side_effect = OpenAIRateLimitError(
|
|
"Rate limit exceeded for gpt-4",
|
|
retry_after=60,
|
|
)
|
|
mock_graph_class.return_value = mock_instance
|
|
|
|
# This import will fail initially (TDD RED phase)
|
|
# The main.py needs to be modified to catch these errors
|
|
# For now, we're testing the expected behavior
|
|
|
|
with pytest.raises(OpenAIRateLimitError) as exc_info:
|
|
mock_instance.propagate("AAPL", "2024-12-26")
|
|
|
|
assert exc_info.value.retry_after == 60
|
|
assert exc_info.value.provider == "openai"
|
|
|
|
@patch('main.TradingAgentsGraph')
|
|
def test_catches_rate_limit_error_from_anthropic(self, mock_graph_class):
|
|
"""Test that Anthropic rate limit errors are caught."""
|
|
from tradingagents.utils.exceptions import AnthropicRateLimitError
|
|
|
|
mock_instance = Mock()
|
|
mock_instance.propagate.side_effect = AnthropicRateLimitError(
|
|
"Rate limit exceeded for claude-opus-4.5",
|
|
retry_after=120,
|
|
)
|
|
mock_graph_class.return_value = mock_instance
|
|
|
|
with pytest.raises(AnthropicRateLimitError) as exc_info:
|
|
mock_instance.propagate("AAPL", "2024-12-26")
|
|
|
|
assert exc_info.value.retry_after == 120
|
|
assert exc_info.value.provider == "anthropic"
|
|
|
|
@patch('main.TradingAgentsGraph')
|
|
def test_catches_rate_limit_error_from_openrouter(self, mock_graph_class):
|
|
"""Test that OpenRouter rate limit errors are caught."""
|
|
from tradingagents.utils.exceptions import OpenRouterRateLimitError
|
|
|
|
mock_instance = Mock()
|
|
mock_instance.propagate.side_effect = OpenRouterRateLimitError(
|
|
"Rate limit exceeded for anthropic/claude-opus-4.5",
|
|
retry_after=45,
|
|
)
|
|
mock_graph_class.return_value = mock_instance
|
|
|
|
with pytest.raises(OpenRouterRateLimitError) as exc_info:
|
|
mock_instance.propagate("AAPL", "2024-12-26")
|
|
|
|
assert exc_info.value.retry_after == 45
|
|
assert exc_info.value.provider == "openrouter"
|
|
|
|
@patch('main.TradingAgentsGraph')
|
|
@patch('main.setup_dual_logger')
|
|
def test_rate_limit_error_is_logged(self, mock_logger_setup, mock_graph_class, temp_output_dir):
|
|
"""Test that rate limit errors are logged."""
|
|
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
|
|
|
# Setup mock logger
|
|
mock_logger = Mock()
|
|
mock_logger_setup.return_value = mock_logger
|
|
|
|
# Setup mock to raise error
|
|
mock_instance = Mock()
|
|
mock_instance.propagate.side_effect = OpenAIRateLimitError(
|
|
"Rate limit exceeded",
|
|
retry_after=60,
|
|
)
|
|
mock_graph_class.return_value = mock_instance
|
|
|
|
# In the modified main.py, the error should be caught and logged
|
|
# This test validates the expected logging behavior
|
|
|
|
try:
|
|
mock_instance.propagate("AAPL", "2024-12-26")
|
|
except OpenAIRateLimitError as e:
|
|
# Simulate what main.py should do
|
|
mock_logger.error(f"Rate limit error: {str(e)}")
|
|
mock_logger.info(f"Retry after: {e.retry_after} seconds")
|
|
|
|
# Verify logging calls
|
|
assert mock_logger.error.called
|
|
assert mock_logger.info.called
|
|
|
|
|
|
# ============================================================================
|
|
# Test Partial Analysis Saving
|
|
# ============================================================================
|
|
|
|
class TestPartialAnalysisSaving:
|
|
"""Test saving partial analysis to JSON when error occurs."""
|
|
|
|
def test_saves_partial_state_to_json(self, temp_output_dir, sample_partial_state):
|
|
"""Test that partial state is saved to JSON file on error."""
|
|
# This function would be in main.py or a utility module
|
|
from tradingagents.utils.error_recovery import save_partial_analysis
|
|
|
|
output_file = temp_output_dir / "partial_analysis.json"
|
|
|
|
save_partial_analysis(sample_partial_state, str(output_file))
|
|
|
|
assert output_file.exists()
|
|
|
|
with open(output_file, 'r') as f:
|
|
loaded_state = json.load(f)
|
|
|
|
assert loaded_state["ticker"] == "AAPL"
|
|
assert loaded_state["analysis_date"] == "2024-12-26"
|
|
assert "error" in loaded_state
|
|
|
|
def test_partial_state_includes_error_info(self, temp_output_dir):
|
|
"""Test that saved partial state includes error information."""
|
|
from tradingagents.utils.error_recovery import save_partial_analysis
|
|
|
|
state_with_error = {
|
|
"ticker": "TSLA",
|
|
"error": "Rate limit exceeded for gpt-4",
|
|
"error_timestamp": datetime.now().isoformat(),
|
|
"retry_after": 60,
|
|
"provider": "openai"
|
|
}
|
|
|
|
output_file = temp_output_dir / "error_state.json"
|
|
save_partial_analysis(state_with_error, str(output_file))
|
|
|
|
with open(output_file, 'r') as f:
|
|
loaded = json.load(f)
|
|
|
|
assert loaded["error"] == "Rate limit exceeded for gpt-4"
|
|
assert loaded["retry_after"] == 60
|
|
assert loaded["provider"] == "openai"
|
|
assert "error_timestamp" in loaded
|
|
|
|
def test_partial_state_includes_completed_work(self, temp_output_dir, sample_partial_state):
|
|
"""Test that partial state includes work completed before error."""
|
|
from tradingagents.utils.error_recovery import save_partial_analysis
|
|
|
|
output_file = temp_output_dir / "partial.json"
|
|
save_partial_analysis(sample_partial_state, str(output_file))
|
|
|
|
with open(output_file, 'r') as f:
|
|
loaded = json.load(f)
|
|
|
|
assert "analyst_reports" in loaded
|
|
assert "market" in loaded["analyst_reports"]
|
|
assert loaded["analyst_reports"]["market"]["summary"] == "Bullish trend"
|
|
|
|
def test_default_output_filename_format(self, temp_output_dir):
|
|
"""Test that default output filename includes ticker and timestamp."""
|
|
from tradingagents.utils.error_recovery import get_partial_analysis_filename
|
|
|
|
ticker = "AAPL"
|
|
timestamp = datetime.now()
|
|
|
|
filename = get_partial_analysis_filename(ticker, timestamp)
|
|
|
|
assert ticker in filename
|
|
assert filename.endswith(".json")
|
|
assert "partial" in filename.lower() or "error" in filename.lower()
|
|
|
|
def test_overwrites_existing_partial_file(self, temp_output_dir):
|
|
"""Test that saving overwrites existing partial analysis file."""
|
|
from tradingagents.utils.error_recovery import save_partial_analysis
|
|
|
|
output_file = temp_output_dir / "partial.json"
|
|
|
|
# Save first version
|
|
state_v1 = {"version": 1, "data": "first"}
|
|
save_partial_analysis(state_v1, str(output_file))
|
|
|
|
# Save second version
|
|
state_v2 = {"version": 2, "data": "second"}
|
|
save_partial_analysis(state_v2, str(output_file))
|
|
|
|
with open(output_file, 'r') as f:
|
|
loaded = json.load(f)
|
|
|
|
assert loaded["version"] == 2
|
|
assert loaded["data"] == "second"
|
|
|
|
def test_handles_non_serializable_data(self, temp_output_dir):
|
|
"""Test handling of non-JSON-serializable data in state."""
|
|
from tradingagents.utils.error_recovery import save_partial_analysis
|
|
|
|
# Include a Mock object which isn't JSON serializable
|
|
state = {
|
|
"ticker": "AAPL",
|
|
"mock_object": Mock(), # Not serializable
|
|
"normal_data": "test"
|
|
}
|
|
|
|
output_file = temp_output_dir / "partial.json"
|
|
|
|
# Should handle gracefully - either skip non-serializable or convert to string
|
|
save_partial_analysis(state, str(output_file))
|
|
|
|
with open(output_file, 'r') as f:
|
|
loaded = json.load(f)
|
|
|
|
assert loaded["ticker"] == "AAPL"
|
|
assert loaded["normal_data"] == "test"
|
|
# mock_object should be handled somehow (skipped or converted)
|
|
|
|
|
|
# ============================================================================
|
|
# Test User Error Messages
|
|
# ============================================================================
|
|
|
|
class TestUserErrorMessages:
|
|
"""Test user-facing error messages and guidance."""
|
|
|
|
def test_error_message_includes_retry_time(self):
|
|
"""Test that error message includes retry_after time."""
|
|
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
|
from tradingagents.utils.error_messages import format_rate_limit_error
|
|
|
|
error = OpenAIRateLimitError("Rate limit exceeded", retry_after=60)
|
|
message = format_rate_limit_error(error)
|
|
|
|
assert "60" in message
|
|
assert "second" in message.lower() or "sec" in message.lower()
|
|
|
|
def test_error_message_includes_provider(self):
|
|
"""Test that error message identifies the provider."""
|
|
from tradingagents.utils.exceptions import OpenRouterRateLimitError
|
|
from tradingagents.utils.error_messages import format_rate_limit_error
|
|
|
|
error = OpenRouterRateLimitError("Rate limit exceeded", retry_after=45)
|
|
message = format_rate_limit_error(error)
|
|
|
|
assert "openrouter" in message.lower() or "OpenRouter" in message
|
|
|
|
def test_error_message_suggests_retry(self):
|
|
"""Test that error message suggests retrying."""
|
|
from tradingagents.utils.exceptions import AnthropicRateLimitError
|
|
from tradingagents.utils.error_messages import format_rate_limit_error
|
|
|
|
error = AnthropicRateLimitError("Rate limit exceeded", retry_after=120)
|
|
message = format_rate_limit_error(error)
|
|
|
|
assert "retry" in message.lower() or "try again" in message.lower()
|
|
|
|
def test_error_message_without_retry_after(self):
|
|
"""Test error message when retry_after is not provided."""
|
|
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
|
from tradingagents.utils.error_messages import format_rate_limit_error
|
|
|
|
error = OpenAIRateLimitError("Rate limit exceeded", retry_after=None)
|
|
message = format_rate_limit_error(error)
|
|
|
|
# Should provide generic guidance
|
|
assert "later" in message.lower() or "wait" in message.lower()
|
|
|
|
def test_error_message_includes_partial_save_info(self, temp_output_dir):
|
|
"""Test that error message mentions where partial analysis was saved."""
|
|
from tradingagents.utils.error_messages import format_error_with_partial_save
|
|
|
|
error_msg = "Rate limit exceeded"
|
|
partial_file = temp_output_dir / "partial_AAPL_20241226.json"
|
|
|
|
message = format_error_with_partial_save(error_msg, str(partial_file))
|
|
|
|
assert str(partial_file) in message or partial_file.name in message
|
|
assert "saved" in message.lower()
|
|
|
|
def test_formats_retry_time_in_minutes(self):
|
|
"""Test that large retry_after times are formatted in minutes."""
|
|
from tradingagents.utils.error_messages import format_retry_time
|
|
|
|
# 300 seconds = 5 minutes
|
|
formatted = format_retry_time(300)
|
|
|
|
assert "5" in formatted
|
|
assert "minute" in formatted.lower()
|
|
|
|
def test_formats_retry_time_in_hours(self):
|
|
"""Test that very large retry_after times are formatted in hours."""
|
|
from tradingagents.utils.error_messages import format_retry_time
|
|
|
|
# 3600 seconds = 1 hour
|
|
formatted = format_retry_time(3600)
|
|
|
|
assert "1" in formatted or "60" in formatted
|
|
assert "hour" in formatted.lower() or "minute" in formatted.lower()
|
|
|
|
|
|
# ============================================================================
|
|
# Test Dual Logging of Errors
|
|
# ============================================================================
|
|
|
|
class TestDualLoggingOfErrors:
|
|
"""Test that errors are logged to both terminal and file."""
|
|
|
|
@patch('tradingagents.utils.logging_config.setup_dual_logger')
|
|
def test_error_logged_to_both_handlers(self, mock_logger_setup, temp_output_dir):
|
|
"""Test that errors are sent to both terminal and file handlers."""
|
|
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
|
|
|
# Create mock logger with two handlers
|
|
mock_logger = Mock()
|
|
mock_stream_handler = Mock(spec=logging.StreamHandler)
|
|
mock_file_handler = Mock(spec=logging.FileHandler)
|
|
|
|
mock_logger.handlers = [mock_stream_handler, mock_file_handler]
|
|
mock_logger_setup.return_value = mock_logger
|
|
|
|
# Simulate logging an error
|
|
error = OpenAIRateLimitError("Rate limit exceeded", retry_after=60)
|
|
mock_logger.error(f"Rate limit error: {str(error)}")
|
|
|
|
# Both handlers should receive the message
|
|
assert mock_logger.error.called
|
|
|
|
def test_terminal_shows_user_friendly_message(self, capsys):
|
|
"""Test that terminal output is user-friendly."""
|
|
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
|
from tradingagents.utils.error_messages import print_user_error
|
|
|
|
error = OpenAIRateLimitError("Rate limit exceeded", retry_after=60)
|
|
|
|
print_user_error(error)
|
|
|
|
captured = capsys.readouterr()
|
|
|
|
# Should be user-friendly, not a raw stack trace
|
|
assert "Rate limit" in captured.out or "Rate limit" in captured.err
|
|
assert "60" in captured.out or "60" in captured.err
|
|
|
|
def test_file_contains_detailed_error_info(self, temp_output_dir):
|
|
"""Test that file log contains detailed error information."""
|
|
from tradingagents.utils.logging_config import setup_dual_logger
|
|
from tradingagents.utils.exceptions import OpenRouterRateLimitError
|
|
|
|
log_file = temp_output_dir / "error_test.log"
|
|
logger = setup_dual_logger(name="test_error_logger", log_file=str(log_file))
|
|
|
|
error = OpenRouterRateLimitError("Rate limit exceeded", retry_after=45)
|
|
|
|
logger.error(f"Rate limit error: {str(error)}")
|
|
logger.error(f"Provider: {error.provider}")
|
|
logger.error(f"Retry after: {error.retry_after} seconds")
|
|
|
|
content = log_file.read_text()
|
|
|
|
assert "Rate limit" in content
|
|
assert "openrouter" in content.lower()
|
|
assert "45" in content
|
|
|
|
def test_sanitization_applied_to_error_logs(self, temp_output_dir):
|
|
"""Test that API keys in error messages are sanitized in logs."""
|
|
from tradingagents.utils.logging_config import setup_dual_logger, sanitize_log_message
|
|
|
|
log_file = temp_output_dir / "sanitized_error.log"
|
|
logger = setup_dual_logger(name="test_sanitize_logger", log_file=str(log_file))
|
|
|
|
# Simulate an error message that includes an API key
|
|
error_msg = "Authentication failed with key sk-test1234567890"
|
|
sanitized_msg = sanitize_log_message(error_msg)
|
|
|
|
logger.error(sanitized_msg)
|
|
|
|
content = log_file.read_text()
|
|
|
|
assert "sk-test1234567890" not in content
|
|
assert "[REDACTED-API-KEY]" in content
|
|
|
|
|
|
# ============================================================================
|
|
# Test Error Translation in Graph Setup
|
|
# ============================================================================
|
|
|
|
class TestGraphErrorTranslation:
|
|
"""Test error translation layer in tradingagents/graph/setup.py."""
|
|
|
|
def test_translates_openai_native_error(self):
|
|
"""Test translation of native OpenAI error to unified exception."""
|
|
from tradingagents.graph.error_handler import translate_llm_error
|
|
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
|
|
|
# Create a mock native OpenAI error
|
|
mock_error = Mock()
|
|
mock_error.__class__.__name__ = "RateLimitError"
|
|
mock_error.message = "Rate limit exceeded"
|
|
mock_error.response = Mock()
|
|
mock_error.response.headers = {"retry-after": "60"}
|
|
|
|
translated = translate_llm_error(mock_error, provider="openai")
|
|
|
|
assert isinstance(translated, OpenAIRateLimitError)
|
|
assert translated.retry_after == 60
|
|
|
|
def test_translates_anthropic_native_error(self):
|
|
"""Test translation of native Anthropic error to unified exception."""
|
|
from tradingagents.graph.error_handler import translate_llm_error
|
|
from tradingagents.utils.exceptions import AnthropicRateLimitError
|
|
|
|
mock_error = Mock()
|
|
mock_error.__class__.__name__ = "RateLimitError"
|
|
mock_error.message = "Rate limit exceeded"
|
|
mock_error.response = Mock()
|
|
mock_error.response.headers = {"retry-after": "120"}
|
|
|
|
translated = translate_llm_error(mock_error, provider="anthropic")
|
|
|
|
assert isinstance(translated, AnthropicRateLimitError)
|
|
assert translated.retry_after == 120
|
|
|
|
def test_translates_openrouter_native_error(self):
|
|
"""Test translation of native OpenRouter error to unified exception."""
|
|
from tradingagents.graph.error_handler import translate_llm_error
|
|
from tradingagents.utils.exceptions import OpenRouterRateLimitError
|
|
|
|
mock_error = Mock()
|
|
mock_error.__class__.__name__ = "RateLimitError"
|
|
mock_error.message = "Rate limit exceeded"
|
|
mock_error.response = Mock()
|
|
mock_error.response.headers = {"retry-after": "30"}
|
|
|
|
translated = translate_llm_error(mock_error, provider="openrouter")
|
|
|
|
assert isinstance(translated, OpenRouterRateLimitError)
|
|
assert translated.retry_after == 30
|
|
|
|
@patch('tradingagents.graph.trading_graph.TradingAgentsGraph.propagate')
|
|
def test_error_translation_in_propagate(self, mock_propagate):
|
|
"""Test that errors raised in propagate are translated."""
|
|
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
|
|
|
# Mock propagate to raise native error, which should be translated
|
|
mock_native_error = Mock()
|
|
mock_native_error.__class__.__name__ = "RateLimitError"
|
|
mock_propagate.side_effect = mock_native_error
|
|
|
|
# The graph should translate this to our unified exception
|
|
# This tests the integration point
|
|
|
|
def test_passes_through_non_rate_limit_errors(self):
|
|
"""Test that non-rate-limit errors are not translated."""
|
|
from tradingagents.graph.error_handler import translate_llm_error
|
|
|
|
mock_error = Mock()
|
|
mock_error.__class__.__name__ = "APIError"
|
|
mock_error.message = "Connection failed"
|
|
|
|
# Should raise ValueError or return None
|
|
with pytest.raises(ValueError):
|
|
translate_llm_error(mock_error, provider="openai")
|
|
|
|
|
|
# ============================================================================
|
|
# Integration Tests
|
|
# ============================================================================
|
|
|
|
class TestEndToEndErrorHandling:
|
|
"""Test complete error handling flow from graph to user output."""
|
|
|
|
@patch('main.TradingAgentsGraph')
|
|
@patch('main.setup_dual_logger')
|
|
def test_complete_error_flow(self, mock_logger_setup, mock_graph_class, temp_output_dir, capsys):
|
|
"""Test complete flow: error raised -> logged -> partial saved -> user notified."""
|
|
from tradingagents.utils.exceptions import OpenRouterRateLimitError
|
|
|
|
# Setup mocks
|
|
mock_logger = Mock()
|
|
mock_logger_setup.return_value = mock_logger
|
|
|
|
mock_instance = Mock()
|
|
mock_instance.propagate.side_effect = OpenRouterRateLimitError(
|
|
"Rate limit exceeded for anthropic/claude-opus-4.5",
|
|
retry_after=60,
|
|
)
|
|
mock_graph_class.return_value = mock_instance
|
|
|
|
# Simulate main.py execution
|
|
try:
|
|
state, decision = mock_instance.propagate("AAPL", "2024-12-26")
|
|
except OpenRouterRateLimitError as e:
|
|
# Log error
|
|
mock_logger.error(f"Rate limit error: {str(e)}")
|
|
|
|
# Save partial state
|
|
partial_file = temp_output_dir / f"partial_AAPL_{datetime.now().strftime('%Y%m%d')}.json"
|
|
partial_state = {
|
|
"ticker": "AAPL",
|
|
"error": str(e),
|
|
"retry_after": e.retry_after,
|
|
"provider": e.provider,
|
|
}
|
|
|
|
with open(partial_file, 'w') as f:
|
|
json.dump(partial_state, f)
|
|
|
|
# Print user message
|
|
print(f"\nError: {str(e)}")
|
|
print(f"Please retry in {e.retry_after} seconds")
|
|
print(f"Partial analysis saved to: {partial_file}")
|
|
|
|
# Verify all components
|
|
assert mock_logger.error.called
|
|
assert partial_file.exists()
|
|
|
|
captured = capsys.readouterr()
|
|
assert "60 seconds" in captured.out
|
|
assert "Partial analysis saved" in captured.out
|
|
|
|
def test_successful_execution_no_partial_save(self, temp_output_dir):
|
|
"""Test that successful execution doesn't save partial state."""
|
|
# When execution is successful, no partial analysis should be saved
|
|
# Only save on error
|
|
|
|
output_dir = temp_output_dir
|
|
before_files = set(output_dir.glob("*.json"))
|
|
|
|
# Simulate successful execution
|
|
# ... normal flow ...
|
|
|
|
after_files = set(output_dir.glob("*.json"))
|
|
|
|
# No new partial files should be created
|
|
assert len(after_files - before_files) == 0
|
|
|
|
@patch('main.TradingAgentsGraph')
|
|
def test_error_during_stream_operation(self, mock_graph_class):
|
|
"""Test error handling during graph.stream() operation."""
|
|
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
|
|
|
mock_instance = Mock()
|
|
|
|
# Mock stream to yield some items then raise error
|
|
def stream_generator():
|
|
yield {"step": 1, "data": "first"}
|
|
yield {"step": 2, "data": "second"}
|
|
raise OpenAIRateLimitError("Rate limit exceeded", retry_after=30)
|
|
|
|
mock_instance.stream = Mock(return_value=stream_generator())
|
|
mock_graph_class.return_value = mock_instance
|
|
|
|
# Collect partial results before error
|
|
partial_results = []
|
|
|
|
try:
|
|
for item in mock_instance.stream("AAPL", "2024-12-26"):
|
|
partial_results.append(item)
|
|
except OpenAIRateLimitError as e:
|
|
# Should have partial results
|
|
assert len(partial_results) == 2
|
|
assert e.retry_after == 30
|
|
|
|
|
|
# ============================================================================
|
|
# Edge Cases
|
|
# ============================================================================
|
|
|
|
class TestErrorHandlingEdgeCases:
|
|
"""Test edge cases in error handling."""
|
|
|
|
def test_rate_limit_error_without_retry_after(self):
|
|
"""Test handling rate limit error when retry_after is not provided."""
|
|
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
|
from tradingagents.utils.error_messages import format_rate_limit_error
|
|
|
|
error = OpenAIRateLimitError("Rate limit exceeded", retry_after=None)
|
|
message = format_rate_limit_error(error)
|
|
|
|
# Should provide generic retry guidance
|
|
assert "retry" in message.lower() or "later" in message.lower()
|
|
|
|
def test_multiple_consecutive_rate_limit_errors(self):
|
|
"""Test handling multiple rate limit errors in a row."""
|
|
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
|
|
|
errors = []
|
|
for i in range(3):
|
|
errors.append(OpenAIRateLimitError(
|
|
f"Rate limit exceeded (attempt {i+1})",
|
|
retry_after=60 * (i+1) # Increasing backoff
|
|
))
|
|
|
|
# Each error should be handled independently
|
|
for i, error in enumerate(errors):
|
|
assert error.retry_after == 60 * (i+1)
|
|
|
|
def test_error_during_partial_save(self, temp_output_dir):
|
|
"""Test handling when saving partial analysis itself fails."""
|
|
from tradingagents.utils.error_recovery import save_partial_analysis
|
|
|
|
# Try to save to invalid location
|
|
invalid_file = "/root/cannot/write/here.json"
|
|
|
|
state = {"ticker": "AAPL", "data": "test"}
|
|
|
|
# Should handle gracefully and not crash
|
|
try:
|
|
save_partial_analysis(state, invalid_file)
|
|
except (PermissionError, OSError) as e:
|
|
# Expected - cannot write to invalid location
|
|
pass
|
|
|
|
def test_unicode_in_error_messages(self, temp_output_dir):
|
|
"""Test handling unicode characters in error messages."""
|
|
from tradingagents.utils.logging_config import setup_dual_logger
|
|
|
|
log_file = temp_output_dir / "unicode_error.log"
|
|
logger = setup_dual_logger(name="unicode_test", log_file=str(log_file))
|
|
|
|
error_msg = "Rate limit exceeded for model 你好-gpt-4"
|
|
logger.error(error_msg)
|
|
|
|
content = log_file.read_text(encoding='utf-8')
|
|
assert "Rate limit" in content
|