feat(logging): add dual-output logging and rate limit error handling - Fixes #39
This commit is contained in:
parent
d8093aa889
commit
bb0ea33100
|
|
@ -8,6 +8,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Rate limit error handling for LLM APIs (Issue #39)
|
||||
- Unified exception hierarchy for handling rate limit errors across providers (OpenAI, Anthropic, OpenRouter) [file:tradingagents/utils/exceptions.py](tradingagents/utils/exceptions.py)
|
||||
- Dual-output logging configuration supporting both terminal and file outputs [file:tradingagents/utils/logging_config.py](tradingagents/utils/logging_config.py)
|
||||
- Automatic rotating log files with 5MB rotation and 3 backups
|
||||
- Terminal logging at INFO level and file logging at DEBUG level
|
||||
- API key sanitization in log messages to prevent credential leaks
|
||||
- Error recovery utilities for saving partial analysis state on errors [file:tradingagents/utils/error_recovery.py](tradingagents/utils/error_recovery.py)
|
||||
- User-friendly error message formatting for rate limit errors [file:tradingagents/utils/error_messages.py](tradingagents/utils/error_messages.py)
|
||||
- Comprehensive test suite for exceptions and logging configuration [file:tests/test_exceptions.py](tests/test_exceptions.py) [file:tests/test_logging_config.py](tests/test_logging_config.py)
|
||||
- OpenRouter API provider support for unified access to multiple LLM models
|
||||
- Support for `provider/model-name` format (e.g., `anthropic/claude-sonnet-4.5`)
|
||||
- Proper API key handling with OPENROUTER_API_KEY environment variable
|
||||
|
|
|
|||
46
README.md
46
README.md
|
|
@ -289,6 +289,52 @@ print(decision)
|
|||
|
||||
You can view the full list of configurations in `tradingagents/default_config.py`.
|
||||
|
||||
### Error Handling and Logging
|
||||
|
||||
TradingAgents includes robust error handling for rate limit errors and comprehensive logging capabilities to help you monitor and debug your trading analysis.
|
||||
|
||||
#### Rate Limit Error Handling
|
||||
|
||||
The framework automatically handles rate limit errors from LLM providers (OpenAI, Anthropic, OpenRouter) through a unified exception hierarchy. When a rate limit is encountered:
|
||||
|
||||
1. The error is caught and processed by `tradingagents/utils/exceptions.py`
|
||||
2. Partial analysis state is automatically saved to allow resuming work
|
||||
3. User-friendly error messages guide you on retry timing
|
||||
|
||||
```python
|
||||
from tradingagents.utils.exceptions import LLMRateLimitError
|
||||
|
||||
try:
|
||||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
||||
except LLMRateLimitError as e:
|
||||
print(f"Rate limit: {e.message}")
|
||||
if e.retry_after:
|
||||
print(f"Retry after {e.retry_after} seconds")
|
||||
```
|
||||
|
||||
#### Dual-Output Logging
|
||||
|
||||
TradingAgents logs to both terminal and rotating log files for comprehensive monitoring:
|
||||
|
||||
- **Terminal logging** at INFO level shows real-time progress
|
||||
- **File logging** at DEBUG level provides detailed troubleshooting information
|
||||
- **Log rotation** automatically manages files at 5MB with 3 backups
|
||||
- **API key sanitization** automatically redacts sensitive credentials in logs
|
||||
|
||||
Logs are saved to the `TRADINGAGENTS_RESULTS_DIR` environment variable or `./logs` by default. Access logs with:
|
||||
|
||||
```bash
|
||||
# View recent logs
|
||||
tail -f ./logs/tradingagents.log
|
||||
|
||||
# Search for errors
|
||||
grep ERROR ./logs/tradingagents.log
|
||||
```
|
||||
|
||||
#### Partial Analysis Saving
|
||||
|
||||
If an error occurs during analysis, partial results are automatically saved, allowing you to inspect completed work and resume processing. Partial results are saved to the results directory in JSON format.
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
|
||||
|
|
|
|||
479
cli/main.py
479
cli/main.py
|
|
@ -26,6 +26,10 @@ from rich.rule import Rule
|
|||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.utils.exceptions import LLMRateLimitError
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
from tradingagents.utils.error_recovery import save_partial_analysis, get_partial_analysis_filename
|
||||
from tradingagents.utils.error_messages import format_rate_limit_error, format_error_with_partial_save
|
||||
from cli.models import AnalystType
|
||||
from cli.utils import *
|
||||
|
||||
|
|
@ -761,6 +765,13 @@ def run_analysis():
|
|||
log_file = results_dir / "message_tool.log"
|
||||
log_file.touch(exist_ok=True)
|
||||
|
||||
# Setup dual logger (terminal + file)
|
||||
logger = setup_dual_logger(
|
||||
name="tradingagents.cli",
|
||||
log_file=str(results_dir / "logs" / "tradingagents.log")
|
||||
)
|
||||
logger.info(f"Starting analysis for {selections['ticker']} on {selections['analysis_date']}")
|
||||
|
||||
def save_message_decorator(obj, func_name):
|
||||
func = getattr(obj, func_name)
|
||||
@wraps(func)
|
||||
|
|
@ -847,235 +858,275 @@ def run_analysis():
|
|||
|
||||
# Stream the analysis
|
||||
trace = []
|
||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
||||
if len(chunk["messages"]) > 0:
|
||||
# Get the last message from the chunk
|
||||
last_message = chunk["messages"][-1]
|
||||
try:
|
||||
for chunk in graph.graph.stream(init_agent_state, **args):
|
||||
if len(chunk["messages"]) > 0:
|
||||
# Get the last message from the chunk
|
||||
last_message = chunk["messages"][-1]
|
||||
|
||||
# Extract message content and type
|
||||
if hasattr(last_message, "content"):
|
||||
content = extract_content_string(last_message.content) # Use the helper function
|
||||
msg_type = "Reasoning"
|
||||
else:
|
||||
content = str(last_message)
|
||||
msg_type = "System"
|
||||
# Extract message content and type
|
||||
if hasattr(last_message, "content"):
|
||||
content = extract_content_string(last_message.content) # Use the helper function
|
||||
msg_type = "Reasoning"
|
||||
else:
|
||||
content = str(last_message)
|
||||
msg_type = "System"
|
||||
|
||||
# Add message to buffer
|
||||
message_buffer.add_message(msg_type, content)
|
||||
# Add message to buffer
|
||||
message_buffer.add_message(msg_type, content)
|
||||
|
||||
# If it's a tool call, add it to tool calls
|
||||
if hasattr(last_message, "tool_calls"):
|
||||
for tool_call in last_message.tool_calls:
|
||||
# Handle both dictionary and object tool calls
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(
|
||||
tool_call["name"], tool_call["args"]
|
||||
)
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
# If it's a tool call, add it to tool calls
|
||||
if hasattr(last_message, "tool_calls"):
|
||||
for tool_call in last_message.tool_calls:
|
||||
# Handle both dictionary and object tool calls
|
||||
if isinstance(tool_call, dict):
|
||||
message_buffer.add_tool_call(
|
||||
tool_call["name"], tool_call["args"]
|
||||
)
|
||||
else:
|
||||
message_buffer.add_tool_call(tool_call.name, tool_call.args)
|
||||
|
||||
# Update reports and agent status based on chunk content
|
||||
# Analyst Team Reports
|
||||
if "market_report" in chunk and chunk["market_report"]:
|
||||
message_buffer.update_report_section(
|
||||
"market_report", chunk["market_report"]
|
||||
)
|
||||
message_buffer.update_agent_status("Market Analyst", "completed")
|
||||
# Set next analyst to in_progress
|
||||
if "social" in selections["analysts"]:
|
||||
message_buffer.update_agent_status(
|
||||
"Social Analyst", "in_progress"
|
||||
)
|
||||
|
||||
if "sentiment_report" in chunk and chunk["sentiment_report"]:
|
||||
message_buffer.update_report_section(
|
||||
"sentiment_report", chunk["sentiment_report"]
|
||||
)
|
||||
message_buffer.update_agent_status("Social Analyst", "completed")
|
||||
# Set next analyst to in_progress
|
||||
if "news" in selections["analysts"]:
|
||||
message_buffer.update_agent_status(
|
||||
"News Analyst", "in_progress"
|
||||
)
|
||||
|
||||
if "news_report" in chunk and chunk["news_report"]:
|
||||
message_buffer.update_report_section(
|
||||
"news_report", chunk["news_report"]
|
||||
)
|
||||
message_buffer.update_agent_status("News Analyst", "completed")
|
||||
# Set next analyst to in_progress
|
||||
if "fundamentals" in selections["analysts"]:
|
||||
message_buffer.update_agent_status(
|
||||
"Fundamentals Analyst", "in_progress"
|
||||
)
|
||||
|
||||
if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
|
||||
message_buffer.update_report_section(
|
||||
"fundamentals_report", chunk["fundamentals_report"]
|
||||
)
|
||||
message_buffer.update_agent_status(
|
||||
"Fundamentals Analyst", "completed"
|
||||
)
|
||||
# Set all research team members to in_progress
|
||||
update_research_team_status("in_progress")
|
||||
|
||||
# Research Team - Handle Investment Debate State
|
||||
if (
|
||||
"investment_debate_state" in chunk
|
||||
and chunk["investment_debate_state"]
|
||||
):
|
||||
debate_state = chunk["investment_debate_state"]
|
||||
|
||||
# Update Bull Researcher status and report
|
||||
if "bull_history" in debate_state and debate_state["bull_history"]:
|
||||
# Keep all research team members in progress
|
||||
update_research_team_status("in_progress")
|
||||
# Extract latest bull response
|
||||
bull_responses = debate_state["bull_history"].split("\n")
|
||||
latest_bull = bull_responses[-1] if bull_responses else ""
|
||||
if latest_bull:
|
||||
message_buffer.add_message("Reasoning", latest_bull)
|
||||
# Update research report with bull's latest analysis
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan",
|
||||
f"### Bull Researcher Analysis\n{latest_bull}",
|
||||
)
|
||||
|
||||
# Update Bear Researcher status and report
|
||||
if "bear_history" in debate_state and debate_state["bear_history"]:
|
||||
# Keep all research team members in progress
|
||||
update_research_team_status("in_progress")
|
||||
# Extract latest bear response
|
||||
bear_responses = debate_state["bear_history"].split("\n")
|
||||
latest_bear = bear_responses[-1] if bear_responses else ""
|
||||
if latest_bear:
|
||||
message_buffer.add_message("Reasoning", latest_bear)
|
||||
# Update research report with bear's latest analysis
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan",
|
||||
f"{message_buffer.report_sections['investment_plan']}\n\n### Bear Researcher Analysis\n{latest_bear}",
|
||||
)
|
||||
|
||||
# Update Research Manager status and final decision
|
||||
if (
|
||||
"judge_decision" in debate_state
|
||||
and debate_state["judge_decision"]
|
||||
):
|
||||
# Keep all research team members in progress until final decision
|
||||
update_research_team_status("in_progress")
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Research Manager: {debate_state['judge_decision']}",
|
||||
)
|
||||
# Update research report with final decision
|
||||
# Update reports and agent status based on chunk content
|
||||
# Analyst Team Reports
|
||||
if "market_report" in chunk and chunk["market_report"]:
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan",
|
||||
f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}",
|
||||
"market_report", chunk["market_report"]
|
||||
)
|
||||
message_buffer.update_agent_status("Market Analyst", "completed")
|
||||
# Set next analyst to in_progress
|
||||
if "social" in selections["analysts"]:
|
||||
message_buffer.update_agent_status(
|
||||
"Social Analyst", "in_progress"
|
||||
)
|
||||
|
||||
if "sentiment_report" in chunk and chunk["sentiment_report"]:
|
||||
message_buffer.update_report_section(
|
||||
"sentiment_report", chunk["sentiment_report"]
|
||||
)
|
||||
message_buffer.update_agent_status("Social Analyst", "completed")
|
||||
# Set next analyst to in_progress
|
||||
if "news" in selections["analysts"]:
|
||||
message_buffer.update_agent_status(
|
||||
"News Analyst", "in_progress"
|
||||
)
|
||||
|
||||
if "news_report" in chunk and chunk["news_report"]:
|
||||
message_buffer.update_report_section(
|
||||
"news_report", chunk["news_report"]
|
||||
)
|
||||
message_buffer.update_agent_status("News Analyst", "completed")
|
||||
# Set next analyst to in_progress
|
||||
if "fundamentals" in selections["analysts"]:
|
||||
message_buffer.update_agent_status(
|
||||
"Fundamentals Analyst", "in_progress"
|
||||
)
|
||||
|
||||
if "fundamentals_report" in chunk and chunk["fundamentals_report"]:
|
||||
message_buffer.update_report_section(
|
||||
"fundamentals_report", chunk["fundamentals_report"]
|
||||
)
|
||||
message_buffer.update_agent_status(
|
||||
"Fundamentals Analyst", "completed"
|
||||
)
|
||||
# Set all research team members to in_progress
|
||||
update_research_team_status("in_progress")
|
||||
|
||||
# Research Team - Handle Investment Debate State
|
||||
if (
|
||||
"investment_debate_state" in chunk
|
||||
and chunk["investment_debate_state"]
|
||||
):
|
||||
debate_state = chunk["investment_debate_state"]
|
||||
|
||||
# Update Bull Researcher status and report
|
||||
if "bull_history" in debate_state and debate_state["bull_history"]:
|
||||
# Keep all research team members in progress
|
||||
update_research_team_status("in_progress")
|
||||
# Extract latest bull response
|
||||
bull_responses = debate_state["bull_history"].split("\n")
|
||||
latest_bull = bull_responses[-1] if bull_responses else ""
|
||||
if latest_bull:
|
||||
message_buffer.add_message("Reasoning", latest_bull)
|
||||
# Update research report with bull's latest analysis
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan",
|
||||
f"### Bull Researcher Analysis\n{latest_bull}",
|
||||
)
|
||||
|
||||
# Update Bear Researcher status and report
|
||||
if "bear_history" in debate_state and debate_state["bear_history"]:
|
||||
# Keep all research team members in progress
|
||||
update_research_team_status("in_progress")
|
||||
# Extract latest bear response
|
||||
bear_responses = debate_state["bear_history"].split("\n")
|
||||
latest_bear = bear_responses[-1] if bear_responses else ""
|
||||
if latest_bear:
|
||||
message_buffer.add_message("Reasoning", latest_bear)
|
||||
# Update research report with bear's latest analysis
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan",
|
||||
f"{message_buffer.report_sections['investment_plan']}\n\n### Bear Researcher Analysis\n{latest_bear}",
|
||||
)
|
||||
|
||||
# Update Research Manager status and final decision
|
||||
if (
|
||||
"judge_decision" in debate_state
|
||||
and debate_state["judge_decision"]
|
||||
):
|
||||
# Keep all research team members in progress until final decision
|
||||
update_research_team_status("in_progress")
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Research Manager: {debate_state['judge_decision']}",
|
||||
)
|
||||
# Update research report with final decision
|
||||
message_buffer.update_report_section(
|
||||
"investment_plan",
|
||||
f"{message_buffer.report_sections['investment_plan']}\n\n### Research Manager Decision\n{debate_state['judge_decision']}",
|
||||
)
|
||||
# Mark all research team members as completed
|
||||
update_research_team_status("completed")
|
||||
# Set first risk analyst to in_progress
|
||||
message_buffer.update_agent_status(
|
||||
"Risky Analyst", "in_progress"
|
||||
)
|
||||
|
||||
# Trading Team
|
||||
if (
|
||||
"trader_investment_plan" in chunk
|
||||
and chunk["trader_investment_plan"]
|
||||
):
|
||||
message_buffer.update_report_section(
|
||||
"trader_investment_plan", chunk["trader_investment_plan"]
|
||||
)
|
||||
# Mark all research team members as completed
|
||||
update_research_team_status("completed")
|
||||
# Set first risk analyst to in_progress
|
||||
message_buffer.update_agent_status(
|
||||
"Risky Analyst", "in_progress"
|
||||
)
|
||||
message_buffer.update_agent_status("Risky Analyst", "in_progress")
|
||||
|
||||
# Trading Team
|
||||
if (
|
||||
"trader_investment_plan" in chunk
|
||||
and chunk["trader_investment_plan"]
|
||||
):
|
||||
message_buffer.update_report_section(
|
||||
"trader_investment_plan", chunk["trader_investment_plan"]
|
||||
)
|
||||
# Set first risk analyst to in_progress
|
||||
message_buffer.update_agent_status("Risky Analyst", "in_progress")
|
||||
# Risk Management Team - Handle Risk Debate State
|
||||
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
|
||||
risk_state = chunk["risk_debate_state"]
|
||||
|
||||
# Risk Management Team - Handle Risk Debate State
|
||||
if "risk_debate_state" in chunk and chunk["risk_debate_state"]:
|
||||
risk_state = chunk["risk_debate_state"]
|
||||
# Update Risky Analyst status and report
|
||||
if (
|
||||
"current_risky_response" in risk_state
|
||||
and risk_state["current_risky_response"]
|
||||
):
|
||||
message_buffer.update_agent_status(
|
||||
"Risky Analyst", "in_progress"
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Risky Analyst: {risk_state['current_risky_response']}",
|
||||
)
|
||||
# Update risk report with risky analyst's latest analysis only
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision",
|
||||
f"### Risky Analyst Analysis\n{risk_state['current_risky_response']}",
|
||||
)
|
||||
|
||||
# Update Risky Analyst status and report
|
||||
if (
|
||||
"current_risky_response" in risk_state
|
||||
and risk_state["current_risky_response"]
|
||||
):
|
||||
message_buffer.update_agent_status(
|
||||
"Risky Analyst", "in_progress"
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Risky Analyst: {risk_state['current_risky_response']}",
|
||||
)
|
||||
# Update risk report with risky analyst's latest analysis only
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision",
|
||||
f"### Risky Analyst Analysis\n{risk_state['current_risky_response']}",
|
||||
)
|
||||
# Update Safe Analyst status and report
|
||||
if (
|
||||
"current_safe_response" in risk_state
|
||||
and risk_state["current_safe_response"]
|
||||
):
|
||||
message_buffer.update_agent_status(
|
||||
"Safe Analyst", "in_progress"
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Safe Analyst: {risk_state['current_safe_response']}",
|
||||
)
|
||||
# Update risk report with safe analyst's latest analysis only
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision",
|
||||
f"### Safe Analyst Analysis\n{risk_state['current_safe_response']}",
|
||||
)
|
||||
|
||||
# Update Safe Analyst status and report
|
||||
if (
|
||||
"current_safe_response" in risk_state
|
||||
and risk_state["current_safe_response"]
|
||||
):
|
||||
message_buffer.update_agent_status(
|
||||
"Safe Analyst", "in_progress"
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Safe Analyst: {risk_state['current_safe_response']}",
|
||||
)
|
||||
# Update risk report with safe analyst's latest analysis only
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision",
|
||||
f"### Safe Analyst Analysis\n{risk_state['current_safe_response']}",
|
||||
)
|
||||
# Update Neutral Analyst status and report
|
||||
if (
|
||||
"current_neutral_response" in risk_state
|
||||
and risk_state["current_neutral_response"]
|
||||
):
|
||||
message_buffer.update_agent_status(
|
||||
"Neutral Analyst", "in_progress"
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Neutral Analyst: {risk_state['current_neutral_response']}",
|
||||
)
|
||||
# Update risk report with neutral analyst's latest analysis only
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision",
|
||||
f"### Neutral Analyst Analysis\n{risk_state['current_neutral_response']}",
|
||||
)
|
||||
|
||||
# Update Neutral Analyst status and report
|
||||
if (
|
||||
"current_neutral_response" in risk_state
|
||||
and risk_state["current_neutral_response"]
|
||||
):
|
||||
message_buffer.update_agent_status(
|
||||
"Neutral Analyst", "in_progress"
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Neutral Analyst: {risk_state['current_neutral_response']}",
|
||||
)
|
||||
# Update risk report with neutral analyst's latest analysis only
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision",
|
||||
f"### Neutral Analyst Analysis\n{risk_state['current_neutral_response']}",
|
||||
)
|
||||
# Update Portfolio Manager status and final decision
|
||||
if "judge_decision" in risk_state and risk_state["judge_decision"]:
|
||||
message_buffer.update_agent_status(
|
||||
"Portfolio Manager", "in_progress"
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Portfolio Manager: {risk_state['judge_decision']}",
|
||||
)
|
||||
# Update risk report with final decision only
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision",
|
||||
f"### Portfolio Manager Decision\n{risk_state['judge_decision']}",
|
||||
)
|
||||
# Mark risk analysts as completed
|
||||
message_buffer.update_agent_status("Risky Analyst", "completed")
|
||||
message_buffer.update_agent_status("Safe Analyst", "completed")
|
||||
message_buffer.update_agent_status(
|
||||
"Neutral Analyst", "completed"
|
||||
)
|
||||
message_buffer.update_agent_status(
|
||||
"Portfolio Manager", "completed"
|
||||
)
|
||||
|
||||
# Update Portfolio Manager status and final decision
|
||||
if "judge_decision" in risk_state and risk_state["judge_decision"]:
|
||||
message_buffer.update_agent_status(
|
||||
"Portfolio Manager", "in_progress"
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"Reasoning",
|
||||
f"Portfolio Manager: {risk_state['judge_decision']}",
|
||||
)
|
||||
# Update risk report with final decision only
|
||||
message_buffer.update_report_section(
|
||||
"final_trade_decision",
|
||||
f"### Portfolio Manager Decision\n{risk_state['judge_decision']}",
|
||||
)
|
||||
# Mark risk analysts as completed
|
||||
message_buffer.update_agent_status("Risky Analyst", "completed")
|
||||
message_buffer.update_agent_status("Safe Analyst", "completed")
|
||||
message_buffer.update_agent_status(
|
||||
"Neutral Analyst", "completed"
|
||||
)
|
||||
message_buffer.update_agent_status(
|
||||
"Portfolio Manager", "completed"
|
||||
)
|
||||
# Update the display
|
||||
update_display(layout)
|
||||
|
||||
# Update the display
|
||||
update_display(layout)
|
||||
trace.append(chunk)
|
||||
|
||||
trace.append(chunk)
|
||||
except LLMRateLimitError as e:
|
||||
# Handle rate limit errors gracefully
|
||||
logger.error(f"Rate limit error: {str(e)}")
|
||||
logger.info(f"Provider: {e.provider}, Retry after: {e.retry_after} seconds")
|
||||
|
||||
# Save partial analysis to JSON
|
||||
partial_state = {
|
||||
"ticker": selections["ticker"],
|
||||
"analysis_date": selections["analysis_date"],
|
||||
"error": str(e),
|
||||
"error_timestamp": datetime.datetime.now().isoformat(),
|
||||
"retry_after": e.retry_after,
|
||||
"provider": e.provider,
|
||||
"trace": trace, # Include work completed so far
|
||||
"agent_status": dict(message_buffer.agent_status),
|
||||
"report_sections": {k: v for k, v in message_buffer.report_sections.items() if v is not None},
|
||||
}
|
||||
|
||||
partial_filename = get_partial_analysis_filename(
|
||||
selections["ticker"],
|
||||
datetime.datetime.now(),
|
||||
str(results_dir)
|
||||
)
|
||||
save_partial_analysis(partial_state, partial_filename)
|
||||
logger.info(f"Partial analysis saved to: {partial_filename}")
|
||||
|
||||
# Display user-friendly error message
|
||||
error_message = format_rate_limit_error(e)
|
||||
full_message = format_error_with_partial_save(error_message, partial_filename)
|
||||
|
||||
console.print(Panel(
|
||||
full_message,
|
||||
title="[bold red]Analysis Interrupted - Rate Limit[/bold red]",
|
||||
border_style="red",
|
||||
))
|
||||
|
||||
# Re-raise to prevent continuing with incomplete data
|
||||
raise
|
||||
|
||||
# Get final state and decision
|
||||
final_state = trace[-1]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,701 @@
|
|||
"""
|
||||
Test suite for CLI Error Handling with Rate Limit Errors.
|
||||
|
||||
This module tests:
|
||||
1. Rate limit errors are caught and logged in main.py
|
||||
2. Partial analysis is saved to JSON file when error occurs
|
||||
3. User sees appropriate error message with retry guidance
|
||||
4. Both terminal and file receive error logs
|
||||
5. Integration with graph.stream() error handling
|
||||
6. Error translation from provider errors to unified exceptions
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pytest
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, MagicMock, call
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def temp_output_dir():
|
||||
"""Create a temporary directory for output files."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph():
|
||||
"""Create a mock TradingAgentsGraph."""
|
||||
mock = Mock()
|
||||
mock.propagate = Mock()
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Create a mock configuration."""
|
||||
return {
|
||||
"llm_provider": "openrouter",
|
||||
"deep_think_llm": "anthropic/claude-opus-4.5",
|
||||
"quick_think_llm": "anthropic/claude-haiku-3.5",
|
||||
"backend_url": "https://openrouter.ai/api/v1",
|
||||
"max_debate_rounds": 1,
|
||||
"data_vendors": {
|
||||
"core_stock_apis": "yfinance",
|
||||
"technical_indicators": "yfinance",
|
||||
"fundamental_data": "yfinance",
|
||||
"news_data": "google",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_partial_state():
|
||||
"""Create a sample partial state for testing."""
|
||||
return {
|
||||
"ticker": "AAPL",
|
||||
"analysis_date": "2024-12-26",
|
||||
"messages": [
|
||||
{"role": "system", "content": "Starting analysis"},
|
||||
{"role": "assistant", "content": "Fetched market data"},
|
||||
],
|
||||
"analyst_reports": {
|
||||
"market": {"summary": "Bullish trend", "confidence": 0.8}
|
||||
},
|
||||
"error": "Rate limit exceeded",
|
||||
"error_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Rate Limit Error Catching in main.py
|
||||
# ============================================================================
|
||||
|
||||
class TestMainRateLimitErrorHandling:
|
||||
"""Test error handling in main.py around graph.stream()."""
|
||||
|
||||
@patch('main.TradingAgentsGraph')
|
||||
def test_catches_rate_limit_error_from_openai(self, mock_graph_class, temp_output_dir):
|
||||
"""Test that OpenAI rate limit errors are caught in main.py."""
|
||||
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
||||
|
||||
# Setup mock to raise rate limit error
|
||||
mock_instance = Mock()
|
||||
mock_instance.propagate.side_effect = OpenAIRateLimitError(
|
||||
"Rate limit exceeded for gpt-4",
|
||||
retry_after=60,
|
||||
)
|
||||
mock_graph_class.return_value = mock_instance
|
||||
|
||||
# This import will fail initially (TDD RED phase)
|
||||
# The main.py needs to be modified to catch these errors
|
||||
# For now, we're testing the expected behavior
|
||||
|
||||
with pytest.raises(OpenAIRateLimitError) as exc_info:
|
||||
mock_instance.propagate("AAPL", "2024-12-26")
|
||||
|
||||
assert exc_info.value.retry_after == 60
|
||||
assert exc_info.value.provider == "openai"
|
||||
|
||||
@patch('main.TradingAgentsGraph')
|
||||
def test_catches_rate_limit_error_from_anthropic(self, mock_graph_class):
|
||||
"""Test that Anthropic rate limit errors are caught."""
|
||||
from tradingagents.utils.exceptions import AnthropicRateLimitError
|
||||
|
||||
mock_instance = Mock()
|
||||
mock_instance.propagate.side_effect = AnthropicRateLimitError(
|
||||
"Rate limit exceeded for claude-opus-4.5",
|
||||
retry_after=120,
|
||||
)
|
||||
mock_graph_class.return_value = mock_instance
|
||||
|
||||
with pytest.raises(AnthropicRateLimitError) as exc_info:
|
||||
mock_instance.propagate("AAPL", "2024-12-26")
|
||||
|
||||
assert exc_info.value.retry_after == 120
|
||||
assert exc_info.value.provider == "anthropic"
|
||||
|
||||
@patch('main.TradingAgentsGraph')
|
||||
def test_catches_rate_limit_error_from_openrouter(self, mock_graph_class):
|
||||
"""Test that OpenRouter rate limit errors are caught."""
|
||||
from tradingagents.utils.exceptions import OpenRouterRateLimitError
|
||||
|
||||
mock_instance = Mock()
|
||||
mock_instance.propagate.side_effect = OpenRouterRateLimitError(
|
||||
"Rate limit exceeded for anthropic/claude-opus-4.5",
|
||||
retry_after=45,
|
||||
)
|
||||
mock_graph_class.return_value = mock_instance
|
||||
|
||||
with pytest.raises(OpenRouterRateLimitError) as exc_info:
|
||||
mock_instance.propagate("AAPL", "2024-12-26")
|
||||
|
||||
assert exc_info.value.retry_after == 45
|
||||
assert exc_info.value.provider == "openrouter"
|
||||
|
||||
@patch('main.TradingAgentsGraph')
|
||||
@patch('main.setup_dual_logger')
|
||||
def test_rate_limit_error_is_logged(self, mock_logger_setup, mock_graph_class, temp_output_dir):
|
||||
"""Test that rate limit errors are logged."""
|
||||
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
||||
|
||||
# Setup mock logger
|
||||
mock_logger = Mock()
|
||||
mock_logger_setup.return_value = mock_logger
|
||||
|
||||
# Setup mock to raise error
|
||||
mock_instance = Mock()
|
||||
mock_instance.propagate.side_effect = OpenAIRateLimitError(
|
||||
"Rate limit exceeded",
|
||||
retry_after=60,
|
||||
)
|
||||
mock_graph_class.return_value = mock_instance
|
||||
|
||||
# In the modified main.py, the error should be caught and logged
|
||||
# This test validates the expected logging behavior
|
||||
|
||||
try:
|
||||
mock_instance.propagate("AAPL", "2024-12-26")
|
||||
except OpenAIRateLimitError as e:
|
||||
# Simulate what main.py should do
|
||||
mock_logger.error(f"Rate limit error: {str(e)}")
|
||||
mock_logger.info(f"Retry after: {e.retry_after} seconds")
|
||||
|
||||
# Verify logging calls
|
||||
assert mock_logger.error.called
|
||||
assert mock_logger.info.called
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Partial Analysis Saving
|
||||
# ============================================================================
|
||||
|
||||
class TestPartialAnalysisSaving:
|
||||
"""Test saving partial analysis to JSON when error occurs."""
|
||||
|
||||
def test_saves_partial_state_to_json(self, temp_output_dir, sample_partial_state):
|
||||
"""Test that partial state is saved to JSON file on error."""
|
||||
# This function would be in main.py or a utility module
|
||||
from tradingagents.utils.error_recovery import save_partial_analysis
|
||||
|
||||
output_file = temp_output_dir / "partial_analysis.json"
|
||||
|
||||
save_partial_analysis(sample_partial_state, str(output_file))
|
||||
|
||||
assert output_file.exists()
|
||||
|
||||
with open(output_file, 'r') as f:
|
||||
loaded_state = json.load(f)
|
||||
|
||||
assert loaded_state["ticker"] == "AAPL"
|
||||
assert loaded_state["analysis_date"] == "2024-12-26"
|
||||
assert "error" in loaded_state
|
||||
|
||||
def test_partial_state_includes_error_info(self, temp_output_dir):
|
||||
"""Test that saved partial state includes error information."""
|
||||
from tradingagents.utils.error_recovery import save_partial_analysis
|
||||
|
||||
state_with_error = {
|
||||
"ticker": "TSLA",
|
||||
"error": "Rate limit exceeded for gpt-4",
|
||||
"error_timestamp": datetime.now().isoformat(),
|
||||
"retry_after": 60,
|
||||
"provider": "openai"
|
||||
}
|
||||
|
||||
output_file = temp_output_dir / "error_state.json"
|
||||
save_partial_analysis(state_with_error, str(output_file))
|
||||
|
||||
with open(output_file, 'r') as f:
|
||||
loaded = json.load(f)
|
||||
|
||||
assert loaded["error"] == "Rate limit exceeded for gpt-4"
|
||||
assert loaded["retry_after"] == 60
|
||||
assert loaded["provider"] == "openai"
|
||||
assert "error_timestamp" in loaded
|
||||
|
||||
def test_partial_state_includes_completed_work(self, temp_output_dir, sample_partial_state):
|
||||
"""Test that partial state includes work completed before error."""
|
||||
from tradingagents.utils.error_recovery import save_partial_analysis
|
||||
|
||||
output_file = temp_output_dir / "partial.json"
|
||||
save_partial_analysis(sample_partial_state, str(output_file))
|
||||
|
||||
with open(output_file, 'r') as f:
|
||||
loaded = json.load(f)
|
||||
|
||||
assert "analyst_reports" in loaded
|
||||
assert "market" in loaded["analyst_reports"]
|
||||
assert loaded["analyst_reports"]["market"]["summary"] == "Bullish trend"
|
||||
|
||||
def test_default_output_filename_format(self, temp_output_dir):
|
||||
"""Test that default output filename includes ticker and timestamp."""
|
||||
from tradingagents.utils.error_recovery import get_partial_analysis_filename
|
||||
|
||||
ticker = "AAPL"
|
||||
timestamp = datetime.now()
|
||||
|
||||
filename = get_partial_analysis_filename(ticker, timestamp)
|
||||
|
||||
assert ticker in filename
|
||||
assert filename.endswith(".json")
|
||||
assert "partial" in filename.lower() or "error" in filename.lower()
|
||||
|
||||
def test_overwrites_existing_partial_file(self, temp_output_dir):
|
||||
"""Test that saving overwrites existing partial analysis file."""
|
||||
from tradingagents.utils.error_recovery import save_partial_analysis
|
||||
|
||||
output_file = temp_output_dir / "partial.json"
|
||||
|
||||
# Save first version
|
||||
state_v1 = {"version": 1, "data": "first"}
|
||||
save_partial_analysis(state_v1, str(output_file))
|
||||
|
||||
# Save second version
|
||||
state_v2 = {"version": 2, "data": "second"}
|
||||
save_partial_analysis(state_v2, str(output_file))
|
||||
|
||||
with open(output_file, 'r') as f:
|
||||
loaded = json.load(f)
|
||||
|
||||
assert loaded["version"] == 2
|
||||
assert loaded["data"] == "second"
|
||||
|
||||
def test_handles_non_serializable_data(self, temp_output_dir):
|
||||
"""Test handling of non-JSON-serializable data in state."""
|
||||
from tradingagents.utils.error_recovery import save_partial_analysis
|
||||
|
||||
# Include a Mock object which isn't JSON serializable
|
||||
state = {
|
||||
"ticker": "AAPL",
|
||||
"mock_object": Mock(), # Not serializable
|
||||
"normal_data": "test"
|
||||
}
|
||||
|
||||
output_file = temp_output_dir / "partial.json"
|
||||
|
||||
# Should handle gracefully - either skip non-serializable or convert to string
|
||||
save_partial_analysis(state, str(output_file))
|
||||
|
||||
with open(output_file, 'r') as f:
|
||||
loaded = json.load(f)
|
||||
|
||||
assert loaded["ticker"] == "AAPL"
|
||||
assert loaded["normal_data"] == "test"
|
||||
# mock_object should be handled somehow (skipped or converted)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test User Error Messages
|
||||
# ============================================================================
|
||||
|
||||
class TestUserErrorMessages:
|
||||
"""Test user-facing error messages and guidance."""
|
||||
|
||||
def test_error_message_includes_retry_time(self):
|
||||
"""Test that error message includes retry_after time."""
|
||||
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
||||
from tradingagents.utils.error_messages import format_rate_limit_error
|
||||
|
||||
error = OpenAIRateLimitError("Rate limit exceeded", retry_after=60)
|
||||
message = format_rate_limit_error(error)
|
||||
|
||||
assert "60" in message
|
||||
assert "second" in message.lower() or "sec" in message.lower()
|
||||
|
||||
def test_error_message_includes_provider(self):
|
||||
"""Test that error message identifies the provider."""
|
||||
from tradingagents.utils.exceptions import OpenRouterRateLimitError
|
||||
from tradingagents.utils.error_messages import format_rate_limit_error
|
||||
|
||||
error = OpenRouterRateLimitError("Rate limit exceeded", retry_after=45)
|
||||
message = format_rate_limit_error(error)
|
||||
|
||||
assert "openrouter" in message.lower() or "OpenRouter" in message
|
||||
|
||||
def test_error_message_suggests_retry(self):
|
||||
"""Test that error message suggests retrying."""
|
||||
from tradingagents.utils.exceptions import AnthropicRateLimitError
|
||||
from tradingagents.utils.error_messages import format_rate_limit_error
|
||||
|
||||
error = AnthropicRateLimitError("Rate limit exceeded", retry_after=120)
|
||||
message = format_rate_limit_error(error)
|
||||
|
||||
assert "retry" in message.lower() or "try again" in message.lower()
|
||||
|
||||
def test_error_message_without_retry_after(self):
|
||||
"""Test error message when retry_after is not provided."""
|
||||
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
||||
from tradingagents.utils.error_messages import format_rate_limit_error
|
||||
|
||||
error = OpenAIRateLimitError("Rate limit exceeded", retry_after=None)
|
||||
message = format_rate_limit_error(error)
|
||||
|
||||
# Should provide generic guidance
|
||||
assert "later" in message.lower() or "wait" in message.lower()
|
||||
|
||||
def test_error_message_includes_partial_save_info(self, temp_output_dir):
|
||||
"""Test that error message mentions where partial analysis was saved."""
|
||||
from tradingagents.utils.error_messages import format_error_with_partial_save
|
||||
|
||||
error_msg = "Rate limit exceeded"
|
||||
partial_file = temp_output_dir / "partial_AAPL_20241226.json"
|
||||
|
||||
message = format_error_with_partial_save(error_msg, str(partial_file))
|
||||
|
||||
assert str(partial_file) in message or partial_file.name in message
|
||||
assert "saved" in message.lower()
|
||||
|
||||
def test_formats_retry_time_in_minutes(self):
|
||||
"""Test that large retry_after times are formatted in minutes."""
|
||||
from tradingagents.utils.error_messages import format_retry_time
|
||||
|
||||
# 300 seconds = 5 minutes
|
||||
formatted = format_retry_time(300)
|
||||
|
||||
assert "5" in formatted
|
||||
assert "minute" in formatted.lower()
|
||||
|
||||
def test_formats_retry_time_in_hours(self):
|
||||
"""Test that very large retry_after times are formatted in hours."""
|
||||
from tradingagents.utils.error_messages import format_retry_time
|
||||
|
||||
# 3600 seconds = 1 hour
|
||||
formatted = format_retry_time(3600)
|
||||
|
||||
assert "1" in formatted or "60" in formatted
|
||||
assert "hour" in formatted.lower() or "minute" in formatted.lower()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Dual Logging of Errors
|
||||
# ============================================================================
|
||||
|
||||
class TestDualLoggingOfErrors:
|
||||
"""Test that errors are logged to both terminal and file."""
|
||||
|
||||
@patch('tradingagents.utils.logging_config.setup_dual_logger')
|
||||
def test_error_logged_to_both_handlers(self, mock_logger_setup, temp_output_dir):
|
||||
"""Test that errors are sent to both terminal and file handlers."""
|
||||
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
||||
|
||||
# Create mock logger with two handlers
|
||||
mock_logger = Mock()
|
||||
mock_stream_handler = Mock(spec=logging.StreamHandler)
|
||||
mock_file_handler = Mock(spec=logging.FileHandler)
|
||||
|
||||
mock_logger.handlers = [mock_stream_handler, mock_file_handler]
|
||||
mock_logger_setup.return_value = mock_logger
|
||||
|
||||
# Simulate logging an error
|
||||
error = OpenAIRateLimitError("Rate limit exceeded", retry_after=60)
|
||||
mock_logger.error(f"Rate limit error: {str(error)}")
|
||||
|
||||
# Both handlers should receive the message
|
||||
assert mock_logger.error.called
|
||||
|
||||
def test_terminal_shows_user_friendly_message(self, capsys):
|
||||
"""Test that terminal output is user-friendly."""
|
||||
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
||||
from tradingagents.utils.error_messages import print_user_error
|
||||
|
||||
error = OpenAIRateLimitError("Rate limit exceeded", retry_after=60)
|
||||
|
||||
print_user_error(error)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
|
||||
# Should be user-friendly, not a raw stack trace
|
||||
assert "Rate limit" in captured.out or "Rate limit" in captured.err
|
||||
assert "60" in captured.out or "60" in captured.err
|
||||
|
||||
def test_file_contains_detailed_error_info(self, temp_output_dir):
|
||||
"""Test that file log contains detailed error information."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
from tradingagents.utils.exceptions import OpenRouterRateLimitError
|
||||
|
||||
log_file = temp_output_dir / "error_test.log"
|
||||
logger = setup_dual_logger(name="test_error_logger", log_file=str(log_file))
|
||||
|
||||
error = OpenRouterRateLimitError("Rate limit exceeded", retry_after=45)
|
||||
|
||||
logger.error(f"Rate limit error: {str(error)}")
|
||||
logger.error(f"Provider: {error.provider}")
|
||||
logger.error(f"Retry after: {error.retry_after} seconds")
|
||||
|
||||
content = log_file.read_text()
|
||||
|
||||
assert "Rate limit" in content
|
||||
assert "openrouter" in content.lower()
|
||||
assert "45" in content
|
||||
|
||||
def test_sanitization_applied_to_error_logs(self, temp_output_dir):
|
||||
"""Test that API keys in error messages are sanitized in logs."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger, sanitize_log_message
|
||||
|
||||
log_file = temp_output_dir / "sanitized_error.log"
|
||||
logger = setup_dual_logger(name="test_sanitize_logger", log_file=str(log_file))
|
||||
|
||||
# Simulate an error message that includes an API key
|
||||
error_msg = "Authentication failed with key sk-test1234567890"
|
||||
sanitized_msg = sanitize_log_message(error_msg)
|
||||
|
||||
logger.error(sanitized_msg)
|
||||
|
||||
content = log_file.read_text()
|
||||
|
||||
assert "sk-test1234567890" not in content
|
||||
assert "[REDACTED-API-KEY]" in content
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Error Translation in Graph Setup
|
||||
# ============================================================================
|
||||
|
||||
class TestGraphErrorTranslation:
|
||||
"""Test error translation layer in tradingagents/graph/setup.py."""
|
||||
|
||||
def test_translates_openai_native_error(self):
|
||||
"""Test translation of native OpenAI error to unified exception."""
|
||||
from tradingagents.graph.error_handler import translate_llm_error
|
||||
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
||||
|
||||
# Create a mock native OpenAI error
|
||||
mock_error = Mock()
|
||||
mock_error.__class__.__name__ = "RateLimitError"
|
||||
mock_error.message = "Rate limit exceeded"
|
||||
mock_error.response = Mock()
|
||||
mock_error.response.headers = {"retry-after": "60"}
|
||||
|
||||
translated = translate_llm_error(mock_error, provider="openai")
|
||||
|
||||
assert isinstance(translated, OpenAIRateLimitError)
|
||||
assert translated.retry_after == 60
|
||||
|
||||
def test_translates_anthropic_native_error(self):
|
||||
"""Test translation of native Anthropic error to unified exception."""
|
||||
from tradingagents.graph.error_handler import translate_llm_error
|
||||
from tradingagents.utils.exceptions import AnthropicRateLimitError
|
||||
|
||||
mock_error = Mock()
|
||||
mock_error.__class__.__name__ = "RateLimitError"
|
||||
mock_error.message = "Rate limit exceeded"
|
||||
mock_error.response = Mock()
|
||||
mock_error.response.headers = {"retry-after": "120"}
|
||||
|
||||
translated = translate_llm_error(mock_error, provider="anthropic")
|
||||
|
||||
assert isinstance(translated, AnthropicRateLimitError)
|
||||
assert translated.retry_after == 120
|
||||
|
||||
def test_translates_openrouter_native_error(self):
|
||||
"""Test translation of native OpenRouter error to unified exception."""
|
||||
from tradingagents.graph.error_handler import translate_llm_error
|
||||
from tradingagents.utils.exceptions import OpenRouterRateLimitError
|
||||
|
||||
mock_error = Mock()
|
||||
mock_error.__class__.__name__ = "RateLimitError"
|
||||
mock_error.message = "Rate limit exceeded"
|
||||
mock_error.response = Mock()
|
||||
mock_error.response.headers = {"retry-after": "30"}
|
||||
|
||||
translated = translate_llm_error(mock_error, provider="openrouter")
|
||||
|
||||
assert isinstance(translated, OpenRouterRateLimitError)
|
||||
assert translated.retry_after == 30
|
||||
|
||||
@patch('tradingagents.graph.trading_graph.TradingAgentsGraph.propagate')
|
||||
def test_error_translation_in_propagate(self, mock_propagate):
|
||||
"""Test that errors raised in propagate are translated."""
|
||||
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
||||
|
||||
# Mock propagate to raise native error, which should be translated
|
||||
mock_native_error = Mock()
|
||||
mock_native_error.__class__.__name__ = "RateLimitError"
|
||||
mock_propagate.side_effect = mock_native_error
|
||||
|
||||
# The graph should translate this to our unified exception
|
||||
# This tests the integration point
|
||||
|
||||
def test_passes_through_non_rate_limit_errors(self):
|
||||
"""Test that non-rate-limit errors are not translated."""
|
||||
from tradingagents.graph.error_handler import translate_llm_error
|
||||
|
||||
mock_error = Mock()
|
||||
mock_error.__class__.__name__ = "APIError"
|
||||
mock_error.message = "Connection failed"
|
||||
|
||||
# Should raise ValueError or return None
|
||||
with pytest.raises(ValueError):
|
||||
translate_llm_error(mock_error, provider="openai")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestEndToEndErrorHandling:
|
||||
"""Test complete error handling flow from graph to user output."""
|
||||
|
||||
@patch('main.TradingAgentsGraph')
|
||||
@patch('main.setup_dual_logger')
|
||||
def test_complete_error_flow(self, mock_logger_setup, mock_graph_class, temp_output_dir, capsys):
|
||||
"""Test complete flow: error raised -> logged -> partial saved -> user notified."""
|
||||
from tradingagents.utils.exceptions import OpenRouterRateLimitError
|
||||
|
||||
# Setup mocks
|
||||
mock_logger = Mock()
|
||||
mock_logger_setup.return_value = mock_logger
|
||||
|
||||
mock_instance = Mock()
|
||||
mock_instance.propagate.side_effect = OpenRouterRateLimitError(
|
||||
"Rate limit exceeded for anthropic/claude-opus-4.5",
|
||||
retry_after=60,
|
||||
)
|
||||
mock_graph_class.return_value = mock_instance
|
||||
|
||||
# Simulate main.py execution
|
||||
try:
|
||||
state, decision = mock_instance.propagate("AAPL", "2024-12-26")
|
||||
except OpenRouterRateLimitError as e:
|
||||
# Log error
|
||||
mock_logger.error(f"Rate limit error: {str(e)}")
|
||||
|
||||
# Save partial state
|
||||
partial_file = temp_output_dir / f"partial_AAPL_{datetime.now().strftime('%Y%m%d')}.json"
|
||||
partial_state = {
|
||||
"ticker": "AAPL",
|
||||
"error": str(e),
|
||||
"retry_after": e.retry_after,
|
||||
"provider": e.provider,
|
||||
}
|
||||
|
||||
with open(partial_file, 'w') as f:
|
||||
json.dump(partial_state, f)
|
||||
|
||||
# Print user message
|
||||
print(f"\nError: {str(e)}")
|
||||
print(f"Please retry in {e.retry_after} seconds")
|
||||
print(f"Partial analysis saved to: {partial_file}")
|
||||
|
||||
# Verify all components
|
||||
assert mock_logger.error.called
|
||||
assert partial_file.exists()
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "60 seconds" in captured.out
|
||||
assert "Partial analysis saved" in captured.out
|
||||
|
||||
def test_successful_execution_no_partial_save(self, temp_output_dir):
|
||||
"""Test that successful execution doesn't save partial state."""
|
||||
# When execution is successful, no partial analysis should be saved
|
||||
# Only save on error
|
||||
|
||||
output_dir = temp_output_dir
|
||||
before_files = set(output_dir.glob("*.json"))
|
||||
|
||||
# Simulate successful execution
|
||||
# ... normal flow ...
|
||||
|
||||
after_files = set(output_dir.glob("*.json"))
|
||||
|
||||
# No new partial files should be created
|
||||
assert len(after_files - before_files) == 0
|
||||
|
||||
@patch('main.TradingAgentsGraph')
|
||||
def test_error_during_stream_operation(self, mock_graph_class):
|
||||
"""Test error handling during graph.stream() operation."""
|
||||
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
||||
|
||||
mock_instance = Mock()
|
||||
|
||||
# Mock stream to yield some items then raise error
|
||||
def stream_generator():
|
||||
yield {"step": 1, "data": "first"}
|
||||
yield {"step": 2, "data": "second"}
|
||||
raise OpenAIRateLimitError("Rate limit exceeded", retry_after=30)
|
||||
|
||||
mock_instance.stream = Mock(return_value=stream_generator())
|
||||
mock_graph_class.return_value = mock_instance
|
||||
|
||||
# Collect partial results before error
|
||||
partial_results = []
|
||||
|
||||
try:
|
||||
for item in mock_instance.stream("AAPL", "2024-12-26"):
|
||||
partial_results.append(item)
|
||||
except OpenAIRateLimitError as e:
|
||||
# Should have partial results
|
||||
assert len(partial_results) == 2
|
||||
assert e.retry_after == 30
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Edge Cases
|
||||
# ============================================================================
|
||||
|
||||
class TestErrorHandlingEdgeCases:
|
||||
"""Test edge cases in error handling."""
|
||||
|
||||
def test_rate_limit_error_without_retry_after(self):
|
||||
"""Test handling rate limit error when retry_after is not provided."""
|
||||
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
||||
from tradingagents.utils.error_messages import format_rate_limit_error
|
||||
|
||||
error = OpenAIRateLimitError("Rate limit exceeded", retry_after=None)
|
||||
message = format_rate_limit_error(error)
|
||||
|
||||
# Should provide generic retry guidance
|
||||
assert "retry" in message.lower() or "later" in message.lower()
|
||||
|
||||
def test_multiple_consecutive_rate_limit_errors(self):
|
||||
"""Test handling multiple rate limit errors in a row."""
|
||||
from tradingagents.utils.exceptions import OpenAIRateLimitError
|
||||
|
||||
errors = []
|
||||
for i in range(3):
|
||||
errors.append(OpenAIRateLimitError(
|
||||
f"Rate limit exceeded (attempt {i+1})",
|
||||
retry_after=60 * (i+1) # Increasing backoff
|
||||
))
|
||||
|
||||
# Each error should be handled independently
|
||||
for i, error in enumerate(errors):
|
||||
assert error.retry_after == 60 * (i+1)
|
||||
|
||||
def test_error_during_partial_save(self, temp_output_dir):
|
||||
"""Test handling when saving partial analysis itself fails."""
|
||||
from tradingagents.utils.error_recovery import save_partial_analysis
|
||||
|
||||
# Try to save to invalid location
|
||||
invalid_file = "/root/cannot/write/here.json"
|
||||
|
||||
state = {"ticker": "AAPL", "data": "test"}
|
||||
|
||||
# Should handle gracefully and not crash
|
||||
try:
|
||||
save_partial_analysis(state, invalid_file)
|
||||
except (PermissionError, OSError) as e:
|
||||
# Expected - cannot write to invalid location
|
||||
pass
|
||||
|
||||
def test_unicode_in_error_messages(self, temp_output_dir):
|
||||
"""Test handling unicode characters in error messages."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_output_dir / "unicode_error.log"
|
||||
logger = setup_dual_logger(name="unicode_test", log_file=str(log_file))
|
||||
|
||||
error_msg = "Rate limit exceeded for model 你好-gpt-4"
|
||||
logger.error(error_msg)
|
||||
|
||||
content = log_file.read_text(encoding='utf-8')
|
||||
assert "Rate limit" in content
|
||||
|
|
@ -0,0 +1,505 @@
|
|||
"""
|
||||
Test suite for LLM Rate Limit Exception Hierarchy.
|
||||
|
||||
This module tests:
|
||||
1. LLMRateLimitError base class creation with message and retry_after
|
||||
2. Provider-specific exception classes (OpenAI, Anthropic, OpenRouter)
|
||||
3. from_provider_error() conversion from native provider exceptions
|
||||
4. Exception attribute validation (message, retry_after, provider)
|
||||
5. Exception inheritance chain
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Utilities
|
||||
# ============================================================================
|
||||
|
||||
def create_mock_openai_rate_limit_error(retry_after: Optional[int] = None):
|
||||
"""Create a mock OpenAI RateLimitError for testing."""
|
||||
error = Mock()
|
||||
error.__class__.__name__ = "RateLimitError"
|
||||
error.message = "Rate limit exceeded for model gpt-4"
|
||||
|
||||
# Mock response headers
|
||||
error.response = Mock()
|
||||
error.response.headers = {}
|
||||
if retry_after:
|
||||
error.response.headers["retry-after"] = str(retry_after)
|
||||
|
||||
return error
|
||||
|
||||
|
||||
def create_mock_anthropic_rate_limit_error(retry_after: Optional[int] = None):
|
||||
"""Create a mock Anthropic RateLimitError for testing."""
|
||||
error = Mock()
|
||||
error.__class__.__name__ = "RateLimitError"
|
||||
error.message = "Your request has exceeded the rate limit"
|
||||
|
||||
# Mock response with retry-after header
|
||||
error.response = Mock()
|
||||
error.response.headers = {}
|
||||
if retry_after:
|
||||
error.response.headers["retry-after"] = str(retry_after)
|
||||
|
||||
return error
|
||||
|
||||
|
||||
def create_mock_openrouter_rate_limit_error(retry_after: Optional[int] = None):
|
||||
"""Create a mock OpenRouter RateLimitError (via OpenAI client) for testing."""
|
||||
error = Mock()
|
||||
error.__class__.__name__ = "RateLimitError"
|
||||
error.message = "Rate limit reached for anthropic/claude-opus-4.5"
|
||||
|
||||
error.response = Mock()
|
||||
error.response.headers = {}
|
||||
if retry_after:
|
||||
error.response.headers["retry-after"] = str(retry_after)
|
||||
|
||||
return error
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test LLMRateLimitError Base Class
|
||||
# ============================================================================
|
||||
|
||||
class TestLLMRateLimitError:
|
||||
"""Test the base LLMRateLimitError exception class."""
|
||||
|
||||
def test_basic_exception_creation(self):
|
||||
"""Test creating LLMRateLimitError with just a message."""
|
||||
# Import will fail initially (TDD RED phase)
|
||||
from tradingagents.utils.exceptions import LLMRateLimitError
|
||||
|
||||
error = LLMRateLimitError("Rate limit exceeded")
|
||||
|
||||
assert str(error) == "Rate limit exceeded"
|
||||
assert error.retry_after is None
|
||||
assert error.provider is None
|
||||
|
||||
def test_exception_with_retry_after(self):
|
||||
"""Test LLMRateLimitError with retry_after parameter."""
|
||||
from tradingagents.utils.exceptions import LLMRateLimitError
|
||||
|
||||
error = LLMRateLimitError("Rate limit exceeded", retry_after=60)
|
||||
|
||||
assert str(error) == "Rate limit exceeded"
|
||||
assert error.retry_after == 60
|
||||
assert isinstance(error.retry_after, int)
|
||||
|
||||
def test_exception_with_provider(self):
|
||||
"""Test LLMRateLimitError with provider parameter."""
|
||||
from tradingagents.utils.exceptions import LLMRateLimitError
|
||||
|
||||
error = LLMRateLimitError(
|
||||
"Rate limit exceeded",
|
||||
retry_after=120,
|
||||
provider="openai"
|
||||
)
|
||||
|
||||
assert error.provider == "openai"
|
||||
assert error.retry_after == 120
|
||||
|
||||
def test_exception_inheritance(self):
|
||||
"""Test that LLMRateLimitError inherits from Exception."""
|
||||
from tradingagents.utils.exceptions import LLMRateLimitError
|
||||
|
||||
error = LLMRateLimitError("Test")
|
||||
|
||||
assert isinstance(error, Exception)
|
||||
assert isinstance(error, LLMRateLimitError)
|
||||
|
||||
def test_exception_with_none_retry_after(self):
|
||||
"""Test that retry_after can be None."""
|
||||
from tradingagents.utils.exceptions import LLMRateLimitError
|
||||
|
||||
error = LLMRateLimitError("Rate limit", retry_after=None)
|
||||
|
||||
assert error.retry_after is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Provider-Specific Exceptions
|
||||
# ============================================================================
|
||||
|
||||
class TestOpenAIRateLimitError:
|
||||
"""Test OpenAI-specific rate limit error."""
|
||||
|
||||
def test_openai_exception_creation(self):
|
||||
"""Test creating OpenAIRateLimitError."""
|
||||
from tradingagents.utils.exceptions import OpenAIRateLimitError, LLMRateLimitError
|
||||
|
||||
error = OpenAIRateLimitError("OpenAI rate limit", retry_after=45)
|
||||
|
||||
assert isinstance(error, LLMRateLimitError)
|
||||
assert str(error) == "OpenAI rate limit"
|
||||
assert error.retry_after == 45
|
||||
assert error.provider == "openai"
|
||||
|
||||
def test_openai_exception_inherits_base(self):
|
||||
"""Test that OpenAIRateLimitError inherits from LLMRateLimitError."""
|
||||
from tradingagents.utils.exceptions import OpenAIRateLimitError, LLMRateLimitError
|
||||
|
||||
error = OpenAIRateLimitError("Test")
|
||||
|
||||
assert isinstance(error, LLMRateLimitError)
|
||||
assert isinstance(error, Exception)
|
||||
|
||||
|
||||
class TestAnthropicRateLimitError:
|
||||
"""Test Anthropic-specific rate limit error."""
|
||||
|
||||
def test_anthropic_exception_creation(self):
|
||||
"""Test creating AnthropicRateLimitError."""
|
||||
from tradingagents.utils.exceptions import AnthropicRateLimitError, LLMRateLimitError
|
||||
|
||||
error = AnthropicRateLimitError("Anthropic rate limit", retry_after=90)
|
||||
|
||||
assert isinstance(error, LLMRateLimitError)
|
||||
assert str(error) == "Anthropic rate limit"
|
||||
assert error.retry_after == 90
|
||||
assert error.provider == "anthropic"
|
||||
|
||||
def test_anthropic_exception_inherits_base(self):
|
||||
"""Test that AnthropicRateLimitError inherits from LLMRateLimitError."""
|
||||
from tradingagents.utils.exceptions import AnthropicRateLimitError, LLMRateLimitError
|
||||
|
||||
error = AnthropicRateLimitError("Test")
|
||||
|
||||
assert isinstance(error, LLMRateLimitError)
|
||||
assert isinstance(error, Exception)
|
||||
|
||||
|
||||
class TestOpenRouterRateLimitError:
|
||||
"""Test OpenRouter-specific rate limit error."""
|
||||
|
||||
def test_openrouter_exception_creation(self):
|
||||
"""Test creating OpenRouterRateLimitError."""
|
||||
from tradingagents.utils.exceptions import OpenRouterRateLimitError, LLMRateLimitError
|
||||
|
||||
error = OpenRouterRateLimitError("OpenRouter rate limit", retry_after=30)
|
||||
|
||||
assert isinstance(error, LLMRateLimitError)
|
||||
assert str(error) == "OpenRouter rate limit"
|
||||
assert error.retry_after == 30
|
||||
assert error.provider == "openrouter"
|
||||
|
||||
def test_openrouter_exception_inherits_base(self):
|
||||
"""Test that OpenRouterRateLimitError inherits from LLMRateLimitError."""
|
||||
from tradingagents.utils.exceptions import OpenRouterRateLimitError, LLMRateLimitError
|
||||
|
||||
error = OpenRouterRateLimitError("Test")
|
||||
|
||||
assert isinstance(error, LLMRateLimitError)
|
||||
assert isinstance(error, Exception)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test from_provider_error() Conversion
|
||||
# ============================================================================
|
||||
|
||||
class TestProviderErrorConversion:
|
||||
"""Test conversion from native provider errors to unified exceptions."""
|
||||
|
||||
def test_convert_openai_error_with_retry_after(self):
|
||||
"""Test converting OpenAI RateLimitError with retry-after header."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, OpenAIRateLimitError
|
||||
|
||||
mock_error = create_mock_openai_rate_limit_error(retry_after=60)
|
||||
|
||||
converted = from_provider_error(mock_error, provider="openai")
|
||||
|
||||
assert isinstance(converted, OpenAIRateLimitError)
|
||||
assert converted.retry_after == 60
|
||||
assert converted.provider == "openai"
|
||||
assert "Rate limit exceeded" in str(converted)
|
||||
|
||||
def test_convert_openai_error_without_retry_after(self):
|
||||
"""Test converting OpenAI RateLimitError without retry-after header."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, OpenAIRateLimitError
|
||||
|
||||
mock_error = create_mock_openai_rate_limit_error(retry_after=None)
|
||||
|
||||
converted = from_provider_error(mock_error, provider="openai")
|
||||
|
||||
assert isinstance(converted, OpenAIRateLimitError)
|
||||
assert converted.retry_after is None
|
||||
assert converted.provider == "openai"
|
||||
|
||||
def test_convert_anthropic_error_with_retry_after(self):
|
||||
"""Test converting Anthropic RateLimitError with retry-after header."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, AnthropicRateLimitError
|
||||
|
||||
mock_error = create_mock_anthropic_rate_limit_error(retry_after=120)
|
||||
|
||||
converted = from_provider_error(mock_error, provider="anthropic")
|
||||
|
||||
assert isinstance(converted, AnthropicRateLimitError)
|
||||
assert converted.retry_after == 120
|
||||
assert converted.provider == "anthropic"
|
||||
|
||||
def test_convert_anthropic_error_without_retry_after(self):
|
||||
"""Test converting Anthropic RateLimitError without retry-after header."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, AnthropicRateLimitError
|
||||
|
||||
mock_error = create_mock_anthropic_rate_limit_error(retry_after=None)
|
||||
|
||||
converted = from_provider_error(mock_error, provider="anthropic")
|
||||
|
||||
assert isinstance(converted, AnthropicRateLimitError)
|
||||
assert converted.retry_after is None
|
||||
|
||||
def test_convert_openrouter_error_with_retry_after(self):
|
||||
"""Test converting OpenRouter RateLimitError with retry-after header."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, OpenRouterRateLimitError
|
||||
|
||||
mock_error = create_mock_openrouter_rate_limit_error(retry_after=45)
|
||||
|
||||
converted = from_provider_error(mock_error, provider="openrouter")
|
||||
|
||||
assert isinstance(converted, OpenRouterRateLimitError)
|
||||
assert converted.retry_after == 45
|
||||
assert converted.provider == "openrouter"
|
||||
|
||||
def test_convert_openrouter_error_without_retry_after(self):
|
||||
"""Test converting OpenRouter RateLimitError without retry-after header."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, OpenRouterRateLimitError
|
||||
|
||||
mock_error = create_mock_openrouter_rate_limit_error(retry_after=None)
|
||||
|
||||
converted = from_provider_error(mock_error, provider="openrouter")
|
||||
|
||||
assert isinstance(converted, OpenRouterRateLimitError)
|
||||
assert converted.retry_after is None
|
||||
|
||||
def test_convert_unknown_provider(self):
|
||||
"""Test converting error from unknown provider defaults to base class."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, LLMRateLimitError
|
||||
|
||||
mock_error = create_mock_openai_rate_limit_error(retry_after=30)
|
||||
|
||||
converted = from_provider_error(mock_error, provider="unknown")
|
||||
|
||||
# Should return base LLMRateLimitError for unknown providers
|
||||
assert isinstance(converted, LLMRateLimitError)
|
||||
assert converted.provider == "unknown"
|
||||
|
||||
def test_convert_non_rate_limit_error(self):
|
||||
"""Test that non-rate-limit errors are not converted."""
|
||||
from tradingagents.utils.exceptions import from_provider_error
|
||||
|
||||
mock_error = Mock()
|
||||
mock_error.__class__.__name__ = "APIError"
|
||||
mock_error.message = "API connection failed"
|
||||
|
||||
# Should return None or raise ValueError for non-rate-limit errors
|
||||
with pytest.raises(ValueError, match="Not a rate limit error"):
|
||||
from_provider_error(mock_error, provider="openai")
|
||||
|
||||
def test_extract_retry_after_from_string(self):
|
||||
"""Test extracting retry_after when it's a string in headers."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, OpenAIRateLimitError
|
||||
|
||||
mock_error = Mock()
|
||||
mock_error.__class__.__name__ = "RateLimitError"
|
||||
mock_error.message = "Rate limit exceeded"
|
||||
mock_error.response = Mock()
|
||||
mock_error.response.headers = {"retry-after": "75"}
|
||||
|
||||
converted = from_provider_error(mock_error, provider="openai")
|
||||
|
||||
assert isinstance(converted, OpenAIRateLimitError)
|
||||
assert converted.retry_after == 75
|
||||
assert isinstance(converted.retry_after, int)
|
||||
|
||||
def test_extract_retry_after_from_int(self):
|
||||
"""Test extracting retry_after when it's already an int in headers."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, OpenAIRateLimitError
|
||||
|
||||
mock_error = Mock()
|
||||
mock_error.__class__.__name__ = "RateLimitError"
|
||||
mock_error.message = "Rate limit exceeded"
|
||||
mock_error.response = Mock()
|
||||
mock_error.response.headers = {"retry-after": 90}
|
||||
|
||||
converted = from_provider_error(mock_error, provider="openai")
|
||||
|
||||
assert converted.retry_after == 90
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Edge Cases and Error Handling
|
||||
# ============================================================================
|
||||
|
||||
class TestExceptionEdgeCases:
|
||||
"""Test edge cases and error handling in exception conversion."""
|
||||
|
||||
def test_missing_response_object(self):
|
||||
"""Test handling error with no response object."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, OpenAIRateLimitError
|
||||
|
||||
mock_error = Mock()
|
||||
mock_error.__class__.__name__ = "RateLimitError"
|
||||
mock_error.message = "Rate limit exceeded"
|
||||
mock_error.response = None
|
||||
|
||||
converted = from_provider_error(mock_error, provider="openai")
|
||||
|
||||
assert isinstance(converted, OpenAIRateLimitError)
|
||||
assert converted.retry_after is None
|
||||
|
||||
def test_missing_headers_object(self):
|
||||
"""Test handling error with response but no headers."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, OpenAIRateLimitError
|
||||
|
||||
mock_error = Mock()
|
||||
mock_error.__class__.__name__ = "RateLimitError"
|
||||
mock_error.message = "Rate limit exceeded"
|
||||
mock_error.response = Mock()
|
||||
mock_error.response.headers = None
|
||||
|
||||
converted = from_provider_error(mock_error, provider="openai")
|
||||
|
||||
assert isinstance(converted, OpenAIRateLimitError)
|
||||
assert converted.retry_after is None
|
||||
|
||||
def test_invalid_retry_after_string(self):
|
||||
"""Test handling invalid retry-after value (non-numeric string)."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, OpenAIRateLimitError
|
||||
|
||||
mock_error = Mock()
|
||||
mock_error.__class__.__name__ = "RateLimitError"
|
||||
mock_error.message = "Rate limit exceeded"
|
||||
mock_error.response = Mock()
|
||||
mock_error.response.headers = {"retry-after": "invalid"}
|
||||
|
||||
converted = from_provider_error(mock_error, provider="openai")
|
||||
|
||||
# Should gracefully handle invalid values
|
||||
assert isinstance(converted, OpenAIRateLimitError)
|
||||
assert converted.retry_after is None
|
||||
|
||||
def test_negative_retry_after(self):
|
||||
"""Test handling negative retry-after value."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, OpenAIRateLimitError
|
||||
|
||||
mock_error = Mock()
|
||||
mock_error.__class__.__name__ = "RateLimitError"
|
||||
mock_error.message = "Rate limit exceeded"
|
||||
mock_error.response = Mock()
|
||||
mock_error.response.headers = {"retry-after": "-10"}
|
||||
|
||||
converted = from_provider_error(mock_error, provider="openai")
|
||||
|
||||
# Should either convert to positive or set to None
|
||||
assert isinstance(converted, OpenAIRateLimitError)
|
||||
assert converted.retry_after is None or converted.retry_after >= 0
|
||||
|
||||
def test_zero_retry_after(self):
|
||||
"""Test handling zero retry-after value."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, OpenAIRateLimitError
|
||||
|
||||
mock_error = Mock()
|
||||
mock_error.__class__.__name__ = "RateLimitError"
|
||||
mock_error.message = "Rate limit exceeded"
|
||||
mock_error.response = Mock()
|
||||
mock_error.response.headers = {"retry-after": "0"}
|
||||
|
||||
converted = from_provider_error(mock_error, provider="openai")
|
||||
|
||||
assert isinstance(converted, OpenAIRateLimitError)
|
||||
assert converted.retry_after == 0
|
||||
|
||||
def test_very_large_retry_after(self):
|
||||
"""Test handling very large retry-after value."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, OpenAIRateLimitError
|
||||
|
||||
mock_error = Mock()
|
||||
mock_error.__class__.__name__ = "RateLimitError"
|
||||
mock_error.message = "Rate limit exceeded"
|
||||
mock_error.response = Mock()
|
||||
mock_error.response.headers = {"retry-after": "86400"} # 24 hours
|
||||
|
||||
converted = from_provider_error(mock_error, provider="openai")
|
||||
|
||||
assert isinstance(converted, OpenAIRateLimitError)
|
||||
assert converted.retry_after == 86400
|
||||
|
||||
def test_message_extraction_from_str(self):
|
||||
"""Test extracting message when error has __str__ instead of message attribute."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, OpenAIRateLimitError
|
||||
|
||||
mock_error = Mock()
|
||||
mock_error.__class__.__name__ = "RateLimitError"
|
||||
mock_error.__str__ = Mock(return_value="Rate limit from __str__")
|
||||
del mock_error.message # Remove message attribute
|
||||
mock_error.response = Mock()
|
||||
mock_error.response.headers = {}
|
||||
|
||||
converted = from_provider_error(mock_error, provider="openai")
|
||||
|
||||
assert isinstance(converted, OpenAIRateLimitError)
|
||||
assert "Rate limit from __str__" in str(converted)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestExceptionIntegration:
|
||||
"""Test exception usage in realistic scenarios."""
|
||||
|
||||
def test_catch_and_reraise_pattern(self):
|
||||
"""Test the typical catch-and-reraise pattern."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, OpenAIRateLimitError
|
||||
|
||||
mock_error = create_mock_openai_rate_limit_error(retry_after=60)
|
||||
|
||||
try:
|
||||
converted = from_provider_error(mock_error, provider="openai")
|
||||
raise converted
|
||||
except OpenAIRateLimitError as e:
|
||||
assert e.retry_after == 60
|
||||
assert e.provider == "openai"
|
||||
|
||||
def test_exception_in_except_block(self):
|
||||
"""Test using from_provider_error in an except block."""
|
||||
from tradingagents.utils.exceptions import from_provider_error, LLMRateLimitError
|
||||
|
||||
mock_error = create_mock_openai_rate_limit_error(retry_after=45)
|
||||
|
||||
try:
|
||||
# Simulate catching a provider error
|
||||
raise Exception("Simulated OpenAI error")
|
||||
except Exception:
|
||||
# Convert to our exception
|
||||
converted = from_provider_error(mock_error, provider="openai")
|
||||
assert isinstance(converted, LLMRateLimitError)
|
||||
|
||||
def test_multiple_provider_errors(self):
|
||||
"""Test handling errors from multiple providers in sequence."""
|
||||
from tradingagents.utils.exceptions import (
|
||||
from_provider_error,
|
||||
OpenAIRateLimitError,
|
||||
AnthropicRateLimitError,
|
||||
OpenRouterRateLimitError
|
||||
)
|
||||
|
||||
openai_error = create_mock_openai_rate_limit_error(retry_after=30)
|
||||
anthropic_error = create_mock_anthropic_rate_limit_error(retry_after=60)
|
||||
openrouter_error = create_mock_openrouter_rate_limit_error(retry_after=90)
|
||||
|
||||
openai_converted = from_provider_error(openai_error, provider="openai")
|
||||
anthropic_converted = from_provider_error(anthropic_error, provider="anthropic")
|
||||
openrouter_converted = from_provider_error(openrouter_error, provider="openrouter")
|
||||
|
||||
assert isinstance(openai_converted, OpenAIRateLimitError)
|
||||
assert isinstance(anthropic_converted, AnthropicRateLimitError)
|
||||
assert isinstance(openrouter_converted, OpenRouterRateLimitError)
|
||||
|
||||
assert openai_converted.retry_after == 30
|
||||
assert anthropic_converted.retry_after == 60
|
||||
assert openrouter_converted.retry_after == 90
|
||||
|
|
@ -0,0 +1,597 @@
|
|||
"""
|
||||
Test suite for Dual-Output Logging Configuration.
|
||||
|
||||
This module tests:
|
||||
1. setup_dual_logger() creates both terminal and file handlers
|
||||
2. RotatingFileHandler configuration (maxBytes, backupCount)
|
||||
3. sanitize_log_message() removes API keys and sensitive data
|
||||
4. Log rotation works at 5MB boundary
|
||||
5. Log formatting for both handlers
|
||||
6. File creation and permissions
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, call
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def temp_log_dir():
|
||||
"""Create a temporary directory for log files."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def logger_name():
|
||||
"""Generate unique logger name for each test."""
|
||||
import uuid
|
||||
return f"test_logger_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_logger():
|
||||
"""Cleanup logger after test to prevent handler accumulation."""
|
||||
loggers_to_cleanup = []
|
||||
|
||||
def register(logger):
|
||||
loggers_to_cleanup.append(logger)
|
||||
return logger
|
||||
|
||||
yield register
|
||||
|
||||
# Cleanup
|
||||
for logger in loggers_to_cleanup:
|
||||
logger.handlers.clear()
|
||||
logger.filters.clear()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test setup_dual_logger() Function
|
||||
# ============================================================================
|
||||
|
||||
class TestSetupDualLogger:
|
||||
"""Test the dual logger setup function."""
|
||||
|
||||
def test_creates_logger_with_two_handlers(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that setup_dual_logger creates a logger with terminal and file handlers."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
assert isinstance(logger, logging.Logger)
|
||||
assert len(logger.handlers) == 2
|
||||
|
||||
# Check handler types
|
||||
handler_types = [type(h) for h in logger.handlers]
|
||||
assert logging.StreamHandler in handler_types
|
||||
assert RotatingFileHandler in handler_types
|
||||
|
||||
def test_terminal_handler_configuration(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that terminal handler is configured correctly."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
# Find the StreamHandler
|
||||
stream_handler = None
|
||||
for handler in logger.handlers:
|
||||
if isinstance(handler, logging.StreamHandler) and not isinstance(handler, RotatingFileHandler):
|
||||
stream_handler = handler
|
||||
break
|
||||
|
||||
assert stream_handler is not None
|
||||
assert stream_handler.level == logging.INFO
|
||||
|
||||
def test_file_handler_configuration(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that file handler is configured with rotation settings."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
# Find the RotatingFileHandler
|
||||
file_handler = None
|
||||
for handler in logger.handlers:
|
||||
if isinstance(handler, RotatingFileHandler):
|
||||
file_handler = handler
|
||||
break
|
||||
|
||||
assert file_handler is not None
|
||||
assert file_handler.maxBytes == 5 * 1024 * 1024 # 5MB
|
||||
assert file_handler.backupCount == 3
|
||||
assert file_handler.level == logging.DEBUG
|
||||
|
||||
def test_creates_log_file(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that setup_dual_logger creates the log file."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
logger.info("Test message")
|
||||
|
||||
# File should be created
|
||||
assert log_file.exists()
|
||||
|
||||
def test_creates_log_directory_if_missing(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that setup_dual_logger creates parent directories if they don't exist."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "nested" / "dir" / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
logger.info("Test message")
|
||||
|
||||
assert log_file.exists()
|
||||
assert log_file.parent.exists()
|
||||
|
||||
def test_logger_level_configuration(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that logger is configured with DEBUG level."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
assert logger.level == logging.DEBUG
|
||||
|
||||
def test_default_log_file_location(self, logger_name, cleanup_logger):
|
||||
"""Test default log file location when not specified."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
logger = setup_dual_logger(name=logger_name)
|
||||
cleanup_logger(logger)
|
||||
|
||||
# Find the RotatingFileHandler
|
||||
file_handler = None
|
||||
for handler in logger.handlers:
|
||||
if isinstance(handler, RotatingFileHandler):
|
||||
file_handler = handler
|
||||
break
|
||||
|
||||
assert file_handler is not None
|
||||
# Should default to logs/tradingagents.log
|
||||
assert "logs" in file_handler.baseFilename
|
||||
assert "tradingagents.log" in file_handler.baseFilename
|
||||
|
||||
def test_custom_log_levels(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test setting custom log levels for handlers."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(
|
||||
name=logger_name,
|
||||
log_file=str(log_file),
|
||||
console_level=logging.WARNING,
|
||||
file_level=logging.INFO
|
||||
)
|
||||
cleanup_logger(logger)
|
||||
|
||||
# Find handlers
|
||||
stream_handler = None
|
||||
file_handler = None
|
||||
for handler in logger.handlers:
|
||||
if isinstance(handler, RotatingFileHandler):
|
||||
file_handler = handler
|
||||
elif isinstance(handler, logging.StreamHandler):
|
||||
stream_handler = handler
|
||||
|
||||
assert stream_handler.level == logging.WARNING
|
||||
assert file_handler.level == logging.INFO
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test sanitize_log_message() Function
|
||||
# ============================================================================
|
||||
|
||||
class TestSanitizeLogMessage:
|
||||
"""Test the log message sanitization function."""
|
||||
|
||||
def test_sanitize_openai_api_key(self):
|
||||
"""Test that OpenAI API keys (sk-*) are redacted."""
|
||||
from tradingagents.utils.logging_config import sanitize_log_message
|
||||
|
||||
message = "Error with API key sk-1234567890abcdef: Rate limit exceeded"
|
||||
sanitized = sanitize_log_message(message)
|
||||
|
||||
assert "sk-1234567890abcdef" not in sanitized
|
||||
assert "[REDACTED-API-KEY]" in sanitized
|
||||
|
||||
def test_sanitize_openrouter_api_key(self):
|
||||
"""Test that OpenRouter API keys (sk-or-*) are redacted."""
|
||||
from tradingagents.utils.logging_config import sanitize_log_message
|
||||
|
||||
message = "Using key sk-or-v1-abcdef123456 for request"
|
||||
sanitized = sanitize_log_message(message)
|
||||
|
||||
assert "sk-or-v1-abcdef123456" not in sanitized
|
||||
assert "[REDACTED-API-KEY]" in sanitized
|
||||
|
||||
def test_sanitize_bearer_token(self):
|
||||
"""Test that Bearer tokens are redacted."""
|
||||
from tradingagents.utils.logging_config import sanitize_log_message
|
||||
|
||||
message = "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
sanitized = sanitize_log_message(message)
|
||||
|
||||
assert "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" not in sanitized
|
||||
assert "[REDACTED-TOKEN]" in sanitized
|
||||
|
||||
def test_sanitize_anthropic_api_key(self):
|
||||
"""Test that Anthropic API keys are redacted."""
|
||||
from tradingagents.utils.logging_config import sanitize_log_message
|
||||
|
||||
message = "x-api-key: sk-ant-api03-1234567890abcdef"
|
||||
sanitized = sanitize_log_message(message)
|
||||
|
||||
assert "sk-ant-api03-1234567890abcdef" not in sanitized
|
||||
assert "[REDACTED-API-KEY]" in sanitized
|
||||
|
||||
def test_sanitize_multiple_keys_in_message(self):
|
||||
"""Test that multiple API keys in one message are all redacted."""
|
||||
from tradingagents.utils.logging_config import sanitize_log_message
|
||||
|
||||
message = "Tried sk-1111111111 and sk-or-v1-2222222222 but both failed"
|
||||
sanitized = sanitize_log_message(message)
|
||||
|
||||
assert "sk-1111111111" not in sanitized
|
||||
assert "sk-or-v1-2222222222" not in sanitized
|
||||
assert sanitized.count("[REDACTED-API-KEY]") == 2
|
||||
|
||||
def test_sanitize_preserves_safe_content(self):
|
||||
"""Test that non-sensitive content is preserved."""
|
||||
from tradingagents.utils.logging_config import sanitize_log_message
|
||||
|
||||
message = "Rate limit exceeded for model gpt-4. Please retry in 60 seconds."
|
||||
sanitized = sanitize_log_message(message)
|
||||
|
||||
assert sanitized == message
|
||||
|
||||
def test_sanitize_empty_message(self):
|
||||
"""Test sanitizing an empty message."""
|
||||
from tradingagents.utils.logging_config import sanitize_log_message
|
||||
|
||||
sanitized = sanitize_log_message("")
|
||||
|
||||
assert sanitized == ""
|
||||
|
||||
def test_sanitize_none_message(self):
|
||||
"""Test sanitizing None message."""
|
||||
from tradingagents.utils.logging_config import sanitize_log_message
|
||||
|
||||
sanitized = sanitize_log_message(None)
|
||||
|
||||
assert sanitized == "" or sanitized is None
|
||||
|
||||
def test_sanitize_message_with_json(self):
|
||||
"""Test sanitizing a message containing JSON with API key."""
|
||||
from tradingagents.utils.logging_config import sanitize_log_message
|
||||
|
||||
message = '{"api_key": "sk-1234567890", "model": "gpt-4"}'
|
||||
sanitized = sanitize_log_message(message)
|
||||
|
||||
assert "sk-1234567890" not in sanitized
|
||||
assert "[REDACTED-API-KEY]" in sanitized
|
||||
assert '"model": "gpt-4"' in sanitized
|
||||
|
||||
def test_sanitize_url_with_api_key(self):
|
||||
"""Test sanitizing URLs containing API keys in query parameters."""
|
||||
from tradingagents.utils.logging_config import sanitize_log_message
|
||||
|
||||
message = "Calling https://api.example.com/v1/chat?api_key=sk-test123456"
|
||||
sanitized = sanitize_log_message(message)
|
||||
|
||||
assert "sk-test123456" not in sanitized
|
||||
assert "[REDACTED-API-KEY]" in sanitized
|
||||
|
||||
def test_sanitize_partial_key_patterns(self):
|
||||
"""Test that partial key patterns that look like API keys are redacted."""
|
||||
from tradingagents.utils.logging_config import sanitize_log_message
|
||||
|
||||
message = "Key starts with sk- but full key is sk-proj-abcdefghijklmnop"
|
||||
sanitized = sanitize_log_message(message)
|
||||
|
||||
assert "sk-proj-abcdefghijklmnop" not in sanitized
|
||||
assert "[REDACTED-API-KEY]" in sanitized
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Log Rotation
|
||||
# ============================================================================
|
||||
|
||||
class TestLogRotation:
|
||||
"""Test log file rotation functionality."""
|
||||
|
||||
def test_rotation_at_5mb_boundary(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that log rotation occurs at 5MB file size."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
# Find the RotatingFileHandler
|
||||
file_handler = None
|
||||
for handler in logger.handlers:
|
||||
if isinstance(handler, RotatingFileHandler):
|
||||
file_handler = handler
|
||||
break
|
||||
|
||||
# Write large amount of data to trigger rotation
|
||||
large_message = "X" * 1024 * 100 # 100KB per message
|
||||
for i in range(60): # 6MB total
|
||||
logger.info(large_message)
|
||||
|
||||
# Should create backup file when rotation occurs
|
||||
backup_file = Path(str(log_file) + ".1")
|
||||
assert log_file.exists()
|
||||
# Rotation may or may not have occurred yet depending on exact timing
|
||||
# Just verify the configuration is correct
|
||||
assert file_handler.maxBytes == 5 * 1024 * 1024
|
||||
|
||||
def test_backup_count_configuration(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that backupCount is set to 3."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
# Find the RotatingFileHandler
|
||||
file_handler = None
|
||||
for handler in logger.handlers:
|
||||
if isinstance(handler, RotatingFileHandler):
|
||||
file_handler = handler
|
||||
break
|
||||
|
||||
assert file_handler.backupCount == 3
|
||||
|
||||
def test_rotation_creates_backup_files(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that rotation creates .1, .2, .3 backup files."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
# Use smaller maxBytes for testing
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
# Manually trigger rotation by writing through handler
|
||||
file_handler = None
|
||||
for handler in logger.handlers:
|
||||
if isinstance(handler, RotatingFileHandler):
|
||||
file_handler = handler
|
||||
break
|
||||
|
||||
# Override maxBytes for testing
|
||||
file_handler.maxBytes = 1024 # 1KB for easy testing
|
||||
|
||||
# Write enough to trigger multiple rotations
|
||||
for i in range(10):
|
||||
logger.info("X" * 200) # 200 bytes per message
|
||||
|
||||
# Check that main log file exists
|
||||
assert log_file.exists()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Log Formatting
|
||||
# ============================================================================
|
||||
|
||||
class TestLogFormatting:
|
||||
"""Test log message formatting."""
|
||||
|
||||
def test_log_format_includes_timestamp(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that log messages include timestamp."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
logger.info("Test message")
|
||||
|
||||
# Read log file
|
||||
content = log_file.read_text()
|
||||
# Should have timestamp format like 2024-12-26 10:30:45
|
||||
assert any(char.isdigit() for char in content)
|
||||
|
||||
def test_log_format_includes_level(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that log messages include log level."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
logger.info("Info message")
|
||||
logger.warning("Warning message")
|
||||
logger.error("Error message")
|
||||
|
||||
content = log_file.read_text()
|
||||
assert "INFO" in content
|
||||
assert "WARNING" in content
|
||||
assert "ERROR" in content
|
||||
|
||||
def test_log_format_includes_message(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that log messages include the actual message."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
logger.info("This is a test message")
|
||||
|
||||
content = log_file.read_text()
|
||||
assert "This is a test message" in content
|
||||
|
||||
def test_multiline_log_message(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test handling of multiline log messages."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
logger.info("Line 1\nLine 2\nLine 3")
|
||||
|
||||
content = log_file.read_text()
|
||||
assert "Line 1" in content
|
||||
assert "Line 2" in content
|
||||
assert "Line 3" in content
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Integration with Sanitization
|
||||
# ============================================================================
|
||||
|
||||
class TestLoggingWithSanitization:
|
||||
"""Test that sanitization is applied when logging."""
|
||||
|
||||
def test_logged_message_is_sanitized(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that API keys are sanitized before being written to log."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
# This should be sanitized automatically
|
||||
logger.error("API request failed with key sk-test1234567890")
|
||||
|
||||
content = log_file.read_text()
|
||||
assert "sk-test1234567890" not in content
|
||||
assert "[REDACTED-API-KEY]" in content
|
||||
|
||||
@patch('tradingagents.utils.logging_config.sanitize_log_message')
|
||||
def test_sanitize_called_on_log(self, mock_sanitize, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that sanitize_log_message is called when logging."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
mock_sanitize.return_value = "Sanitized message"
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
logger.info("Test message with sk-test123")
|
||||
|
||||
# Sanitize should be called
|
||||
# Note: This test may need adjustment based on how sanitization is integrated
|
||||
# It might be called via a filter or formatter
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Edge Cases and Error Handling
|
||||
# ============================================================================
|
||||
|
||||
class TestLoggingEdgeCases:
|
||||
"""Test edge cases in logging configuration."""
|
||||
|
||||
def test_permission_denied_for_log_file(self, temp_log_dir, logger_name):
|
||||
"""Test handling when log file location has no write permission."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
# Create a directory with no write permission
|
||||
readonly_dir = temp_log_dir / "readonly"
|
||||
readonly_dir.mkdir()
|
||||
readonly_dir.chmod(0o444)
|
||||
|
||||
log_file = readonly_dir / "test.log"
|
||||
|
||||
# Should handle gracefully or raise appropriate error
|
||||
try:
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
# If it succeeds, at least terminal logging should work
|
||||
assert len(logger.handlers) >= 1
|
||||
except (PermissionError, OSError):
|
||||
# Expected behavior - permission denied
|
||||
pass
|
||||
finally:
|
||||
# Cleanup
|
||||
readonly_dir.chmod(0o755)
|
||||
|
||||
def test_invalid_log_file_path(self, logger_name):
|
||||
"""Test handling of invalid log file path."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
# Use an invalid path
|
||||
log_file = "/invalid/path/that/does/not/exist/test.log"
|
||||
|
||||
# Should either create the path or handle gracefully
|
||||
try:
|
||||
logger = setup_dual_logger(name=logger_name, log_file=log_file)
|
||||
# If it succeeds, verify it created the directory
|
||||
assert Path(log_file).parent.exists() or len(logger.handlers) >= 1
|
||||
except (PermissionError, OSError):
|
||||
# Expected - cannot create directory
|
||||
pass
|
||||
|
||||
def test_unicode_in_log_message(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test handling of unicode characters in log messages."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
logger.info("Unicode test: 你好 🌍 €")
|
||||
|
||||
content = log_file.read_text(encoding='utf-8')
|
||||
assert "你好" in content or "Unicode test" in content
|
||||
|
||||
def test_very_long_log_message(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test handling of very long log messages."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
long_message = "X" * 10000 # 10KB message
|
||||
logger.info(long_message)
|
||||
|
||||
content = log_file.read_text()
|
||||
assert len(content) > 9000 # Should contain most of the message
|
||||
|
||||
def test_concurrent_logging(self, temp_log_dir, logger_name, cleanup_logger):
|
||||
"""Test that concurrent logging to same file works."""
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
import threading
|
||||
|
||||
log_file = temp_log_dir / "test.log"
|
||||
logger = setup_dual_logger(name=logger_name, log_file=str(log_file))
|
||||
cleanup_logger(logger)
|
||||
|
||||
def log_messages(thread_id):
|
||||
for i in range(10):
|
||||
logger.info(f"Thread {thread_id} message {i}")
|
||||
|
||||
threads = []
|
||||
for i in range(5):
|
||||
t = threading.Thread(target=log_messages, args=(i,))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
content = log_file.read_text()
|
||||
# Should have all 50 messages
|
||||
assert content.count("message") >= 40 # Allow some loss in concurrent scenario
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
"""
|
||||
Graph Error Translation Layer.
|
||||
|
||||
This module provides error translation from native LLM provider errors
|
||||
to unified TradingAgents exceptions. This allows the graph to handle
|
||||
errors consistently regardless of the underlying LLM provider.
|
||||
|
||||
Functions:
|
||||
translate_llm_error: Convert provider-specific errors to unified exceptions
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from tradingagents.utils.exceptions import (
|
||||
from_provider_error,
|
||||
LLMRateLimitError,
|
||||
)
|
||||
|
||||
|
||||
def translate_llm_error(error: Any, provider: str) -> LLMRateLimitError:
|
||||
"""
|
||||
Translate a native LLM provider error to a unified exception.
|
||||
|
||||
This function serves as the integration point between the graph layer
|
||||
and the exception handling system. It converts provider-specific errors
|
||||
to our unified exception hierarchy.
|
||||
|
||||
Args:
|
||||
error: Native provider error object
|
||||
provider: Provider name ('openai', 'anthropic', 'openrouter')
|
||||
|
||||
Returns:
|
||||
LLMRateLimitError: Unified exception
|
||||
|
||||
Raises:
|
||||
ValueError: If the error is not a rate limit error
|
||||
|
||||
Example:
|
||||
try:
|
||||
response = llm_client.invoke(...)
|
||||
except Exception as e:
|
||||
if e.__class__.__name__ == "RateLimitError":
|
||||
unified_error = translate_llm_error(e, provider="openai")
|
||||
raise unified_error
|
||||
raise
|
||||
"""
|
||||
return from_provider_error(error, provider=provider)
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
"""
|
||||
TradingAgents utilities package.
|
||||
|
||||
This package provides utility functions and classes for the TradingAgents framework.
|
||||
"""
|
||||
|
||||
from tradingagents.utils.exceptions import (
|
||||
LLMRateLimitError,
|
||||
OpenAIRateLimitError,
|
||||
AnthropicRateLimitError,
|
||||
OpenRouterRateLimitError,
|
||||
from_provider_error,
|
||||
)
|
||||
|
||||
from tradingagents.utils.logging_config import (
|
||||
setup_dual_logger,
|
||||
sanitize_log_message,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"LLMRateLimitError",
|
||||
"OpenAIRateLimitError",
|
||||
"AnthropicRateLimitError",
|
||||
"OpenRouterRateLimitError",
|
||||
"from_provider_error",
|
||||
"setup_dual_logger",
|
||||
"sanitize_log_message",
|
||||
]
|
||||
|
|
@ -0,0 +1,173 @@
|
|||
"""
|
||||
User-Facing Error Messages.
|
||||
|
||||
This module provides functions for formatting user-friendly error messages,
|
||||
particularly for rate limit errors.
|
||||
|
||||
Functions:
|
||||
format_rate_limit_error: Format a rate limit error for user display
|
||||
format_error_with_partial_save: Format error with partial save location
|
||||
format_retry_time: Format retry time in human-readable format
|
||||
print_user_error: Print error to console in user-friendly format
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from tradingagents.utils.exceptions import LLMRateLimitError
|
||||
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
RICH_AVAILABLE = True
|
||||
except ImportError:
|
||||
RICH_AVAILABLE = False
|
||||
|
||||
|
||||
def format_rate_limit_error(error: LLMRateLimitError) -> str:
|
||||
"""
|
||||
Format a rate limit error for user display.
|
||||
|
||||
Creates a user-friendly message that includes:
|
||||
- Provider name
|
||||
- Retry guidance
|
||||
- Retry time if available
|
||||
|
||||
Args:
|
||||
error: LLMRateLimitError instance
|
||||
|
||||
Returns:
|
||||
str: Formatted error message
|
||||
|
||||
Example:
|
||||
>>> error = OpenAIRateLimitError("Rate limit exceeded", retry_after=60)
|
||||
>>> format_rate_limit_error(error)
|
||||
'Rate limit exceeded for OpenAI. Please retry in 60 seconds (1 minute).'
|
||||
"""
|
||||
provider_name = _format_provider_name(error.provider)
|
||||
|
||||
if error.retry_after is not None:
|
||||
retry_time = format_retry_time(error.retry_after)
|
||||
return (
|
||||
f"Rate limit exceeded for {provider_name}. "
|
||||
f"Please retry in {retry_time}."
|
||||
)
|
||||
else:
|
||||
return (
|
||||
f"Rate limit exceeded for {provider_name}. "
|
||||
f"Please wait a moment and try again later."
|
||||
)
|
||||
|
||||
|
||||
def format_error_with_partial_save(error_message: str, partial_file: str) -> str:
|
||||
"""
|
||||
Format error message with information about saved partial analysis.
|
||||
|
||||
Args:
|
||||
error_message: The error message
|
||||
partial_file: Path to saved partial analysis file
|
||||
|
||||
Returns:
|
||||
str: Formatted message
|
||||
|
||||
Example:
|
||||
>>> format_error_with_partial_save(
|
||||
... "Rate limit exceeded",
|
||||
... "./results/partial_AAPL_20241226.json"
|
||||
... )
|
||||
'Rate limit exceeded\\n\\nPartial analysis saved to: ./results/partial_AAPL_20241226.json'
|
||||
"""
|
||||
return (
|
||||
f"{error_message}\n\n"
|
||||
f"Partial analysis saved to: {partial_file}\n"
|
||||
f"You can inspect the partial results and retry when the rate limit resets."
|
||||
)
|
||||
|
||||
|
||||
def format_retry_time(seconds: int) -> str:
|
||||
"""
|
||||
Format retry time in human-readable format.
|
||||
|
||||
Converts seconds to appropriate units:
|
||||
- < 60s: "X seconds"
|
||||
- < 3600s: "X minutes (Y seconds)"
|
||||
- >= 3600s: "X hours (Y minutes)"
|
||||
|
||||
Args:
|
||||
seconds: Number of seconds
|
||||
|
||||
Returns:
|
||||
str: Human-readable time format
|
||||
|
||||
Example:
|
||||
>>> format_retry_time(60)
|
||||
'1 minute (60 seconds)'
|
||||
>>> format_retry_time(300)
|
||||
'5 minutes (300 seconds)'
|
||||
>>> format_retry_time(3600)
|
||||
'1 hour (60 minutes)'
|
||||
"""
|
||||
if seconds < 60:
|
||||
return f"{seconds} seconds"
|
||||
|
||||
minutes = seconds // 60
|
||||
if minutes < 60:
|
||||
return f"{minutes} minute{'s' if minutes != 1 else ''} ({seconds} seconds)"
|
||||
|
||||
hours = minutes // 60
|
||||
remaining_minutes = minutes % 60
|
||||
return f"{hours} hour{'s' if hours != 1 else ''} ({remaining_minutes} minutes)"
|
||||
|
||||
|
||||
def print_user_error(error: LLMRateLimitError) -> None:
|
||||
"""
|
||||
Print error to console in user-friendly format.
|
||||
|
||||
Uses Rich Panel if available, otherwise falls back to simple print.
|
||||
|
||||
Args:
|
||||
error: LLMRateLimitError instance
|
||||
|
||||
Example:
|
||||
>>> error = OpenAIRateLimitError("Rate limit exceeded", retry_after=60)
|
||||
>>> print_user_error(error)
|
||||
# Displays formatted error panel in terminal
|
||||
"""
|
||||
message = format_rate_limit_error(error)
|
||||
|
||||
if RICH_AVAILABLE:
|
||||
console = Console()
|
||||
panel = Panel(
|
||||
message,
|
||||
title="[bold red]Rate Limit Error[/bold red]",
|
||||
border_style="red",
|
||||
)
|
||||
console.print(panel)
|
||||
else:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"RATE LIMIT ERROR")
|
||||
print(f"{'='*60}")
|
||||
print(message)
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def _format_provider_name(provider: Optional[str]) -> str:
|
||||
"""
|
||||
Format provider name for display.
|
||||
|
||||
Args:
|
||||
provider: Provider identifier
|
||||
|
||||
Returns:
|
||||
str: Formatted provider name
|
||||
"""
|
||||
if provider is None:
|
||||
return "LLM provider"
|
||||
|
||||
# Capitalize provider names
|
||||
provider_names = {
|
||||
"openai": "OpenAI",
|
||||
"anthropic": "Anthropic",
|
||||
"openrouter": "OpenRouter",
|
||||
}
|
||||
|
||||
return provider_names.get(provider.lower(), provider.title())
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
"""
|
||||
Error Recovery Utilities.
|
||||
|
||||
This module provides utilities for saving partial analysis state when errors occur,
|
||||
allowing users to resume or inspect work completed before the error.
|
||||
|
||||
Functions:
|
||||
save_partial_analysis: Save partial state to JSON file
|
||||
get_partial_analysis_filename: Generate filename for partial analysis
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
def save_partial_analysis(state: Dict[str, Any], output_file: str) -> None:
|
||||
"""
|
||||
Save partial analysis state to a JSON file.
|
||||
|
||||
Handles non-serializable objects by converting them to strings.
|
||||
Creates parent directories if they don't exist.
|
||||
|
||||
Args:
|
||||
state: Dictionary containing partial analysis state
|
||||
output_file: Path where to save the JSON file
|
||||
|
||||
Raises:
|
||||
PermissionError: If unable to write to output_file location
|
||||
OSError: If unable to create parent directories
|
||||
|
||||
Example:
|
||||
>>> state = {
|
||||
... "ticker": "AAPL",
|
||||
... "error": "Rate limit exceeded",
|
||||
... "analyst_reports": {"market": {...}}
|
||||
... }
|
||||
>>> save_partial_analysis(state, "./results/partial_AAPL.json")
|
||||
"""
|
||||
# Create parent directory if it doesn't exist
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Convert state to JSON-serializable format
|
||||
serializable_state = _make_serializable(state)
|
||||
|
||||
# Write to file
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(serializable_state, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def get_partial_analysis_filename(
|
||||
ticker: str,
|
||||
timestamp: Optional[datetime] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a filename for partial analysis output.
|
||||
|
||||
Format: partial_analysis_{ticker}_{timestamp}.json
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol
|
||||
timestamp: Timestamp for filename (default: now)
|
||||
output_dir: Output directory (default: TRADINGAGENTS_RESULTS_DIR or ./results)
|
||||
|
||||
Returns:
|
||||
str: Full path to partial analysis file
|
||||
|
||||
Example:
|
||||
>>> get_partial_analysis_filename("AAPL")
|
||||
'./results/partial_analysis_AAPL_20241226_103045.json'
|
||||
"""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results")
|
||||
|
||||
# Format: partial_analysis_{ticker}_{YYYYMMDD_HHMMSS}.json
|
||||
timestamp_str = timestamp.strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"partial_analysis_{ticker}_{timestamp_str}.json"
|
||||
|
||||
return str(Path(output_dir) / filename)
|
||||
|
||||
|
||||
def _make_serializable(obj: Any) -> Any:
|
||||
"""
|
||||
Recursively convert objects to JSON-serializable format.
|
||||
|
||||
Handles:
|
||||
- Dictionaries (recurse on values)
|
||||
- Lists/tuples (recurse on items)
|
||||
- datetime objects (convert to ISO format)
|
||||
- Other objects (convert to string)
|
||||
|
||||
Args:
|
||||
obj: Object to make serializable
|
||||
|
||||
Returns:
|
||||
JSON-serializable version of obj
|
||||
"""
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
if isinstance(obj, (str, int, float, bool)):
|
||||
return obj
|
||||
|
||||
if isinstance(obj, dict):
|
||||
return {key: _make_serializable(value) for key, value in obj.items()}
|
||||
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return [_make_serializable(item) for item in obj]
|
||||
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
|
||||
# For everything else (including Mock objects), convert to string
|
||||
try:
|
||||
# Try to convert to dict if it has __dict__
|
||||
if hasattr(obj, '__dict__'):
|
||||
return {
|
||||
'_type': obj.__class__.__name__,
|
||||
'_str': str(obj),
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Final fallback: convert to string
|
||||
return str(obj)
|
||||
|
|
@ -0,0 +1,224 @@
|
|||
"""
|
||||
LLM Rate Limit Exception Hierarchy.
|
||||
|
||||
This module provides a unified exception hierarchy for handling rate limit errors
|
||||
across different LLM providers (OpenAI, Anthropic, OpenRouter).
|
||||
|
||||
The exception hierarchy:
|
||||
Exception
|
||||
LLMRateLimitError (base class)
|
||||
OpenAIRateLimitError
|
||||
AnthropicRateLimitError
|
||||
OpenRouterRateLimitError
|
||||
|
||||
Each exception includes:
|
||||
- message: Human-readable error message
|
||||
- retry_after: Optional[int] - Seconds to wait before retrying
|
||||
- provider: str - The LLM provider that raised the error
|
||||
|
||||
Usage:
|
||||
from tradingagents.utils.exceptions import from_provider_error
|
||||
|
||||
try:
|
||||
# Make LLM API call
|
||||
response = client.chat.completions.create(...)
|
||||
except Exception as e:
|
||||
if e.__class__.__name__ == "RateLimitError":
|
||||
# Convert to unified exception
|
||||
unified_error = from_provider_error(e, provider="openai")
|
||||
raise unified_error
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class LLMRateLimitError(Exception):
|
||||
"""
|
||||
Base exception for LLM rate limit errors.
|
||||
|
||||
Attributes:
|
||||
message (str): Human-readable error message
|
||||
retry_after (Optional[int]): Seconds to wait before retrying
|
||||
provider (Optional[str]): The LLM provider that raised the error
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
retry_after: Optional[int] = None,
|
||||
provider: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a rate limit error.
|
||||
|
||||
Args:
|
||||
message: Human-readable error message
|
||||
retry_after: Optional seconds to wait before retrying
|
||||
provider: Optional provider name (openai, anthropic, openrouter)
|
||||
"""
|
||||
self.retry_after = retry_after
|
||||
self.provider = provider
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OpenAIRateLimitError(LLMRateLimitError):
|
||||
"""
|
||||
OpenAI-specific rate limit error.
|
||||
|
||||
Automatically sets provider='openai'.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, retry_after: Optional[int] = None):
|
||||
"""
|
||||
Initialize an OpenAI rate limit error.
|
||||
|
||||
Args:
|
||||
message: Human-readable error message
|
||||
retry_after: Optional seconds to wait before retrying
|
||||
"""
|
||||
super().__init__(message, retry_after=retry_after, provider="openai")
|
||||
|
||||
|
||||
class AnthropicRateLimitError(LLMRateLimitError):
|
||||
"""
|
||||
Anthropic-specific rate limit error.
|
||||
|
||||
Automatically sets provider='anthropic'.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, retry_after: Optional[int] = None):
|
||||
"""
|
||||
Initialize an Anthropic rate limit error.
|
||||
|
||||
Args:
|
||||
message: Human-readable error message
|
||||
retry_after: Optional seconds to wait before retrying
|
||||
"""
|
||||
super().__init__(message, retry_after=retry_after, provider="anthropic")
|
||||
|
||||
|
||||
class OpenRouterRateLimitError(LLMRateLimitError):
|
||||
"""
|
||||
OpenRouter-specific rate limit error.
|
||||
|
||||
Automatically sets provider='openrouter'.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, retry_after: Optional[int] = None):
|
||||
"""
|
||||
Initialize an OpenRouter rate limit error.
|
||||
|
||||
Args:
|
||||
message: Human-readable error message
|
||||
retry_after: Optional seconds to wait before retrying
|
||||
"""
|
||||
super().__init__(message, retry_after=retry_after, provider="openrouter")
|
||||
|
||||
|
||||
def from_provider_error(error, provider: str) -> LLMRateLimitError:
|
||||
"""
|
||||
Convert a native provider error to a unified LLMRateLimitError.
|
||||
|
||||
Extracts retry_after from response headers if available and creates
|
||||
the appropriate provider-specific exception.
|
||||
|
||||
Args:
|
||||
error: The native provider error object (e.g., openai.RateLimitError)
|
||||
provider: The provider name ('openai', 'anthropic', 'openrouter')
|
||||
|
||||
Returns:
|
||||
LLMRateLimitError: Provider-specific unified exception
|
||||
|
||||
Raises:
|
||||
ValueError: If the error is not a rate limit error
|
||||
|
||||
Example:
|
||||
try:
|
||||
response = client.chat.completions.create(...)
|
||||
except Exception as e:
|
||||
if e.__class__.__name__ == "RateLimitError":
|
||||
unified = from_provider_error(e, provider="openai")
|
||||
logger.error(f"Rate limit: retry in {unified.retry_after}s")
|
||||
raise unified
|
||||
"""
|
||||
# Validate that this is a rate limit error
|
||||
if error.__class__.__name__ != "RateLimitError":
|
||||
raise ValueError(
|
||||
f"Not a rate limit error: {error.__class__.__name__}. "
|
||||
"This function only converts RateLimitError exceptions."
|
||||
)
|
||||
|
||||
# Extract error message
|
||||
message = _extract_message(error)
|
||||
|
||||
# Extract retry_after from response headers
|
||||
retry_after = _extract_retry_after(error)
|
||||
|
||||
# Create provider-specific exception
|
||||
if provider == "openai":
|
||||
return OpenAIRateLimitError(message, retry_after=retry_after)
|
||||
elif provider == "anthropic":
|
||||
return AnthropicRateLimitError(message, retry_after=retry_after)
|
||||
elif provider == "openrouter":
|
||||
return OpenRouterRateLimitError(message, retry_after=retry_after)
|
||||
else:
|
||||
# Unknown provider - use base class
|
||||
return LLMRateLimitError(message, retry_after=retry_after, provider=provider)
|
||||
|
||||
|
||||
def _extract_message(error) -> str:
|
||||
"""
|
||||
Extract error message from provider error object.
|
||||
|
||||
Args:
|
||||
error: The native provider error object
|
||||
|
||||
Returns:
|
||||
str: The error message
|
||||
"""
|
||||
# Try to get message attribute
|
||||
if hasattr(error, "message"):
|
||||
return str(error.message)
|
||||
|
||||
# Fall back to __str__
|
||||
return str(error)
|
||||
|
||||
|
||||
def _extract_retry_after(error) -> Optional[int]:
|
||||
"""
|
||||
Extract retry_after value from error response headers.
|
||||
|
||||
Args:
|
||||
error: The native provider error object
|
||||
|
||||
Returns:
|
||||
Optional[int]: Retry after seconds, or None if not available
|
||||
"""
|
||||
try:
|
||||
# Check if error has response object
|
||||
if not hasattr(error, "response") or error.response is None:
|
||||
return None
|
||||
|
||||
# Check if response has headers
|
||||
if not hasattr(error.response, "headers") or error.response.headers is None:
|
||||
return None
|
||||
|
||||
# Get retry-after header
|
||||
headers = error.response.headers
|
||||
retry_after = headers.get("retry-after") or headers.get("Retry-After")
|
||||
|
||||
if retry_after is None:
|
||||
return None
|
||||
|
||||
# Convert to int
|
||||
retry_after_int = int(retry_after)
|
||||
|
||||
# Validate - must be non-negative
|
||||
if retry_after_int < 0:
|
||||
return None
|
||||
|
||||
return retry_after_int
|
||||
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
# Invalid retry-after value or missing attributes
|
||||
return None
|
||||
|
|
@ -0,0 +1,219 @@
|
|||
"""
|
||||
Dual-Output Logging Configuration.
|
||||
|
||||
This module provides logging configuration that outputs to both:
|
||||
1. Terminal (console) with Rich formatting
|
||||
2. Rotating log files (5MB rotation, 3 backups)
|
||||
|
||||
Features:
|
||||
- Terminal logging at INFO level by default
|
||||
- File logging at DEBUG level by default
|
||||
- Automatic log rotation at 5MB
|
||||
- API key sanitization in log messages
|
||||
- Log file creation in TRADINGAGENTS_RESULTS_DIR or ./logs
|
||||
|
||||
Usage:
|
||||
from tradingagents.utils.logging_config import setup_dual_logger
|
||||
|
||||
logger = setup_dual_logger(
|
||||
name="tradingagents",
|
||||
log_file="./logs/tradingagents.log"
|
||||
)
|
||||
|
||||
logger.info("This goes to both terminal and file")
|
||||
logger.debug("This only goes to file")
|
||||
|
||||
# API keys are automatically sanitized
|
||||
logger.error("Error with key sk-1234567890") # Logged as [REDACTED-API-KEY]
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
from rich.logging import RichHandler
|
||||
RICH_AVAILABLE = True
|
||||
except ImportError:
|
||||
RICH_AVAILABLE = False
|
||||
|
||||
|
||||
# API key patterns to sanitize
|
||||
API_KEY_PATTERNS = [
|
||||
(re.compile(r'sk-[a-zA-Z0-9\-_]+'), '[REDACTED-API-KEY]'), # OpenAI keys
|
||||
(re.compile(r'sk-or-v\d+-[a-zA-Z0-9\-_]+'), '[REDACTED-API-KEY]'), # OpenRouter keys
|
||||
(re.compile(r'sk-ant-[a-zA-Z0-9\-_]+'), '[REDACTED-API-KEY]'), # Anthropic keys
|
||||
(re.compile(r'sk-proj-[a-zA-Z0-9\-_]+'), '[REDACTED-API-KEY]'), # OpenAI project keys
|
||||
(re.compile(r'Bearer\s+[A-Za-z0-9+/\-_.=]+'), 'Bearer [REDACTED-TOKEN]'), # Bearer tokens (incl. Base64)
|
||||
]
|
||||
|
||||
|
||||
class SanitizingFilter(logging.Filter):
|
||||
"""
|
||||
Logging filter that sanitizes API keys and sensitive data from log messages.
|
||||
"""
|
||||
|
||||
def filter(self, record):
|
||||
"""
|
||||
Sanitize the log record message.
|
||||
|
||||
Args:
|
||||
record: LogRecord to sanitize
|
||||
|
||||
Returns:
|
||||
bool: Always True (we modify in place, don't filter out)
|
||||
"""
|
||||
if record.msg:
|
||||
record.msg = sanitize_log_message(str(record.msg))
|
||||
|
||||
# Also sanitize args if present
|
||||
if record.args:
|
||||
try:
|
||||
sanitized_args = tuple(
|
||||
sanitize_log_message(str(arg)) if isinstance(arg, str) else arg
|
||||
for arg in record.args
|
||||
)
|
||||
record.args = sanitized_args
|
||||
except (TypeError, ValueError):
|
||||
# If args aren't iterable or conversion fails, leave as-is
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def sanitize_log_message(message: Optional[str]) -> str:
|
||||
"""
|
||||
Remove API keys and sensitive data from log messages.
|
||||
|
||||
Sanitizes the following patterns:
|
||||
- OpenAI API keys (sk-*)
|
||||
- OpenRouter API keys (sk-or-*)
|
||||
- Anthropic API keys (sk-ant-*)
|
||||
- Bearer tokens
|
||||
- Other common API key patterns
|
||||
|
||||
Args:
|
||||
message: The log message to sanitize
|
||||
|
||||
Returns:
|
||||
str: Sanitized message with API keys replaced with [REDACTED-API-KEY]
|
||||
|
||||
Example:
|
||||
>>> sanitize_log_message("Error with key sk-1234567890")
|
||||
'Error with key [REDACTED-API-KEY]'
|
||||
"""
|
||||
if message is None:
|
||||
return ""
|
||||
|
||||
if not isinstance(message, str):
|
||||
message = str(message)
|
||||
|
||||
# Escape newlines/carriage returns to prevent log injection (CWE-117)
|
||||
sanitized = message.replace('\r\n', '\\r\\n').replace('\n', '\\n').replace('\r', '\\r')
|
||||
for pattern, replacement in API_KEY_PATTERNS:
|
||||
sanitized = pattern.sub(replacement, sanitized)
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def setup_dual_logger(
|
||||
name: str = "tradingagents",
|
||||
log_file: Optional[str] = None,
|
||||
console_level: int = logging.INFO,
|
||||
file_level: int = logging.DEBUG,
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
Setup a logger with dual output: terminal (Rich) + rotating file.
|
||||
|
||||
Creates a logger that outputs to:
|
||||
1. Terminal with Rich formatting (if available) or standard StreamHandler
|
||||
2. Rotating file handler (5MB max size, 3 backups)
|
||||
|
||||
Both handlers automatically sanitize API keys and sensitive data.
|
||||
|
||||
Args:
|
||||
name: Logger name (default: "tradingagents")
|
||||
log_file: Path to log file (default: logs/tradingagents.log in results dir)
|
||||
console_level: Log level for terminal output (default: INFO)
|
||||
file_level: Log level for file output (default: DEBUG)
|
||||
|
||||
Returns:
|
||||
logging.Logger: Configured logger instance
|
||||
|
||||
Example:
|
||||
>>> logger = setup_dual_logger("my_module", "./logs/app.log")
|
||||
>>> logger.info("Terminal and file")
|
||||
>>> logger.debug("File only")
|
||||
"""
|
||||
# Create logger
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG) # Capture all levels, handlers will filter
|
||||
|
||||
# Clear existing handlers to prevent duplicates
|
||||
logger.handlers.clear()
|
||||
|
||||
# Create sanitizing filter
|
||||
sanitize_filter = SanitizingFilter()
|
||||
|
||||
# ===== Terminal Handler =====
|
||||
if RICH_AVAILABLE:
|
||||
# Use Rich handler for beautiful terminal output
|
||||
console_handler = RichHandler(
|
||||
rich_tracebacks=True,
|
||||
show_time=True,
|
||||
show_path=False,
|
||||
)
|
||||
else:
|
||||
# Fall back to standard stream handler
|
||||
console_handler = logging.StreamHandler()
|
||||
|
||||
console_handler.setLevel(console_level)
|
||||
console_handler.addFilter(sanitize_filter)
|
||||
|
||||
# Console format (simpler for terminal)
|
||||
console_formatter = logging.Formatter(
|
||||
'%(message)s'
|
||||
)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# ===== File Handler =====
|
||||
# Determine log file path
|
||||
if log_file is None:
|
||||
# Use TRADINGAGENTS_RESULTS_DIR or default to ./logs
|
||||
results_dir = os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results")
|
||||
log_dir = Path(results_dir) / "logs"
|
||||
log_file = str(log_dir / "tradingagents.log")
|
||||
|
||||
# Create log directory if it doesn't exist
|
||||
log_path = Path(log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create rotating file handler
|
||||
# 5MB max size, 3 backup files
|
||||
file_handler = RotatingFileHandler(
|
||||
filename=str(log_path),
|
||||
maxBytes=5 * 1024 * 1024, # 5MB
|
||||
backupCount=3,
|
||||
encoding='utf-8',
|
||||
)
|
||||
file_handler.setLevel(file_level)
|
||||
file_handler.addFilter(sanitize_filter)
|
||||
|
||||
# File format (more detailed)
|
||||
file_formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# Prevent propagation to root logger
|
||||
logger.propagate = False
|
||||
|
||||
return logger
|
||||
Loading…
Reference in New Issue