"""
|
|
Test suite for Output Validation Utilities.
|
|
|
|
This module tests:
|
|
1. ValidationResult dataclass behavior
|
|
2. Report completeness validation (length, markdown, sections)
|
|
3. Decision quality validation (signal extraction, reasoning)
|
|
4. Debate state validation (history, count, judge_decision)
|
|
5. Complete agent state validation (orchestration)
|
|
|
|
All tests use mocked data (no real API calls).
|
|
"""
|
|
|
|
import pytest
|
|
from typing import Dict, Any
|
|
|
|
from tradingagents.utils.output_validator import (
|
|
ValidationResult,
|
|
validate_report_completeness,
|
|
validate_decision_quality,
|
|
validate_debate_state,
|
|
validate_agent_state,
|
|
)
|
|
|
|
pytestmark = pytest.mark.unit
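
# NOTE: Rough sketch of the ValidationResult interface exercised below,
# inferred from the assertions in this file rather than taken from the
# output_validator implementation itself:
#
#     result = ValidationResult(is_valid=True)
#     result.add_error("...")           # appended to result.errors; flips is_valid to False
#     result.add_warning("...")         # appended to result.warnings; validity unchanged
#     result.add_metric("length", 500)  # stored in the result.metrics dict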


# ============================================================================
# Test ValidationResult Dataclass
# ============================================================================

class TestValidationResult:
    """Test ValidationResult dataclass behavior."""

    def test_default_valid_result(self):
        """Test ValidationResult defaults to valid with empty lists."""
        result = ValidationResult(is_valid=True)

        assert result.is_valid is True
        assert result.errors == []
        assert result.warnings == []
        assert result.metrics == {}

    def test_add_error_marks_invalid(self):
        """Test that add_error() marks result as invalid."""
        result = ValidationResult(is_valid=True)
        result.add_error("Something went wrong")

        assert result.is_valid is False
        assert len(result.errors) == 1
        assert result.errors[0] == "Something went wrong"

    def test_add_warning_keeps_valid(self):
        """Test that add_warning() doesn't change validity."""
        result = ValidationResult(is_valid=True)
        result.add_warning("This could be better")

        assert result.is_valid is True
        assert len(result.warnings) == 1
        assert result.warnings[0] == "This could be better"

    def test_add_metric(self):
        """Test that add_metric() stores key-value pairs."""
        result = ValidationResult(is_valid=True)
        result.add_metric("length", 500)
        result.add_metric("signal", "BUY")

        assert result.metrics["length"] == 500
        assert result.metrics["signal"] == "BUY"

    def test_multiple_errors_and_warnings(self):
        """Test accumulating multiple errors and warnings."""
        result = ValidationResult(is_valid=True)
        result.add_error("Error 1")
        result.add_error("Error 2")
        result.add_warning("Warning 1")
        result.add_warning("Warning 2")

        assert result.is_valid is False
        assert len(result.errors) == 2
        assert len(result.warnings) == 2


# ============================================================================
# Test Report Validation
# ============================================================================
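
# Assumed contract for validate_report_completeness(), inferred from the tests
# below (not from the implementation): it accepts a report string plus
# min_length, require_markdown_tables, and require_sections keyword arguments,
# and records metrics such as "length", "markdown_tables", "section_headers",
# and "has_bullet_points" on the returned ValidationResult.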

class TestReportValidation:
    """Test validate_report_completeness() function."""

    def test_valid_report_passes(self):
        """Test that a valid report passes validation."""
        report = "# Market Analysis\n\n" + "This is a comprehensive report. " * 50

        result = validate_report_completeness(report, min_length=500)

        assert result.is_valid is True
        assert len(result.errors) == 0
        assert result.metrics["length"] > 500

    def test_none_report_fails(self):
        """Test that None report fails validation."""
        result = validate_report_completeness(None)

        assert result.is_valid is False
        assert "None" in result.errors[0]

    def test_empty_report_fails(self):
        """Test that empty report fails validation."""
        result = validate_report_completeness("")

        assert result.is_valid is False
        assert "empty" in result.errors[0].lower()

    def test_short_report_fails(self):
        """Test that report below min_length fails."""
        short_report = "Too short"

        result = validate_report_completeness(short_report, min_length=500)

        assert result.is_valid is False
        assert any("minimum" in err.lower() for err in result.errors)
        assert result.metrics["length"] < 500

    def test_wrong_type_fails(self):
        """Test that non-string report fails validation."""
        result = validate_report_completeness(123)

        assert result.is_valid is False
        assert "string" in result.errors[0].lower()

    def test_markdown_table_detection(self):
        """Test detection of markdown tables."""
        report_with_table = """
# Analysis

| Metric | Value |
|--------|-------|
| Price | $100 |
| Volume | 1M |
""" + "Additional text. " * 50

        result = validate_report_completeness(
            report_with_table,
            min_length=200,
            require_markdown_tables=True
        )

        assert result.is_valid is True
        assert result.metrics["markdown_tables"] > 0

    def test_missing_markdown_table_fails_when_required(self):
        """Test that missing markdown tables fails when required."""
        report = "# Analysis\n\n" + "No tables here. " * 50

        result = validate_report_completeness(
            report,
            min_length=200,
            require_markdown_tables=True
        )

        assert result.is_valid is False
        assert any("table" in err.lower() for err in result.errors)

    def test_section_header_detection(self):
        """Test detection of section headers."""
        report_with_headers = """
# Main Title
## Subsection
### Details

Content here.
""" + "More content. " * 50

        result = validate_report_completeness(
            report_with_headers,
            min_length=200,
            require_sections=True
        )

        assert result.is_valid is True
        assert result.metrics["section_headers"] >= 3

    def test_missing_sections_fails_when_required(self):
        """Test that missing sections fails when required."""
        report = "Just plain text. " * 50

        result = validate_report_completeness(
            report,
            min_length=200,
            require_sections=True
        )

        assert result.is_valid is False
        assert any("section" in err.lower() for err in result.errors)

    def test_short_report_warning(self):
        """Test warning for relatively short reports."""
        # Report is above min but below 1.5x min
        report = "Short but valid. " * 40  # ~680 chars

        result = validate_report_completeness(report, min_length=500)

        assert result.is_valid is True
        assert len(result.warnings) > 0
        assert any("short" in warn.lower() for warn in result.warnings)

    def test_bullet_point_detection(self):
        """Test detection of bullet points."""
        report_with_bullets = """
# Analysis

- Point 1
- Point 2
* Point 3

""" + "Additional content. " * 50

        result = validate_report_completeness(report_with_bullets, min_length=200)

        assert result.metrics["has_bullet_points"] is True

    def test_unstructured_content_warning(self):
        """Test warning for content lacking structure."""
        unstructured_report = "Just a long stream of text without any structure. " * 50

        result = validate_report_completeness(unstructured_report, min_length=500)

        assert result.is_valid is True
        assert any("structured" in warn.lower() for warn in result.warnings)


# ============================================================================
# Test Decision Validation
# ============================================================================
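
# Assumed contract for validate_decision_quality(), inferred from the tests
# below: it extracts a BUY/SELL/HOLD signal case-insensitively, records
# "signal", "signal_count", and "has_reasoning" metrics, errors on missing
# signals or non-string input, and warns on conflicting or very short decisions.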

class TestDecisionValidation:
    """Test validate_decision_quality() function."""

    def test_valid_buy_decision(self):
        """Test that valid BUY decision passes."""
        decision = "BUY: Strong fundamentals and positive momentum"

        result = validate_decision_quality(decision)

        assert result.is_valid is True
        assert result.metrics["signal"] == "BUY"
        assert result.metrics["has_reasoning"] is True

    def test_valid_sell_decision(self):
        """Test that valid SELL decision passes."""
        decision = "SELL: Overvalued with deteriorating fundamentals"

        result = validate_decision_quality(decision)

        assert result.is_valid is True
        assert result.metrics["signal"] == "SELL"

    def test_valid_hold_decision(self):
        """Test that valid HOLD decision passes."""
        decision = "HOLD: Mixed signals, awaiting clarity"

        result = validate_decision_quality(decision)

        assert result.is_valid is True
        assert result.metrics["signal"] == "HOLD"

    def test_case_insensitive_signal_extraction(self):
        """Test that signals are extracted case-insensitively."""
        decisions = [
            "buy the stock",
            "BUY the stock",
            "Buy the stock",
            "We should buy",
        ]

        for decision in decisions:
            result = validate_decision_quality(decision)
            assert result.metrics["signal"] == "BUY"

    def test_none_decision_fails(self):
        """Test that None decision fails validation."""
        result = validate_decision_quality(None)

        assert result.is_valid is False
        assert "None" in result.errors[0]

    def test_empty_decision_fails(self):
        """Test that empty decision fails validation."""
        result = validate_decision_quality("")

        assert result.is_valid is False
        assert "empty" in result.errors[0].lower()

    def test_no_signal_fails(self):
        """Test that decision without signal fails."""
        decision = "This is a decision without a clear signal"

        result = validate_decision_quality(decision)

        assert result.is_valid is False
        assert any("signal" in err.lower() for err in result.errors)
        assert result.metrics["signal"] is None

    def test_wrong_type_fails(self):
        """Test that non-string decision fails."""
        result = validate_decision_quality({"decision": "BUY"})

        assert result.is_valid is False
        assert "string" in result.errors[0].lower()

    def test_multiple_signals_warning(self):
        """Test warning for multiple conflicting signals."""
        decision = "BUY or maybe SELL, hard to decide, could HOLD"

        result = validate_decision_quality(decision)

        # Should still extract first signal
        assert result.metrics["signal"] == "BUY"
        # But warn about conflicts
        assert len(result.warnings) > 0
        assert any("conflicting" in warn.lower() for warn in result.warnings)

    def test_short_decision_warning(self):
        """Test warning for very short decisions."""
        decision = "BUY"

        result = validate_decision_quality(decision)

        assert result.is_valid is True
        assert len(result.warnings) > 0
        assert any("short" in warn.lower() for warn in result.warnings)

    def test_decision_with_reasoning_markers(self):
        """Test that reasoning markers are detected."""
        decisions_with_reasoning = [
            "BUY: Strong fundamentals",
            "SELL. Company is overvalued.",
            "HOLD because market is uncertain",
        ]

        for decision in decisions_with_reasoning:
            result = validate_decision_quality(decision)
            assert result.metrics["has_reasoning"] is True

    def test_signal_count_metric(self):
        """Test that signal_count metric is accurate."""
        decision = "BUY BUY BUY! Strong signal to buy"

        result = validate_decision_quality(decision)

        assert result.metrics["signal_count"] == 4
        assert result.metrics["signal"] == "BUY"


# ============================================================================
# Test Debate State Validation
# ============================================================================
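
# Assumed contract for validate_debate_state(), inferred from the tests below:
# the state dict must provide "history" (str), "count" (non-negative int), and
# "judge_decision"; debate_type is "invest" or "risk", and role-specific
# histories (bull/bear or risky/safe/neutral) are counted as optional fields.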

class TestDebateStateValidation:
    """Test validate_debate_state() function."""

    def test_valid_invest_debate_state(self):
        """Test that valid invest debate state passes."""
        debate_state = {
            "history": "Round 1: Bull argues...\nRound 2: Bear counters...",
            "count": 2,
            "judge_decision": "BUY: Bulls made stronger case",
            "bull_history": "Bull argument",
            "bear_history": "Bear argument",
        }

        result = validate_debate_state(debate_state, debate_type="invest")

        assert result.is_valid is True
        assert result.metrics["history_length"] > 0
        assert result.metrics["count"] == 2
        assert result.metrics["judge_signal"] == "BUY"

    def test_valid_risk_debate_state(self):
        """Test that valid risk debate state passes."""
        debate_state = {
            "history": "Round 1: Risky argues...\nRound 2: Safe counters...",
            "count": 2,
            "judge_decision": "HOLD: Balanced risk profile",
            "risky_history": "Risky argument",
            "safe_history": "Safe argument",
            "neutral_history": "Neutral argument",
        }

        result = validate_debate_state(debate_state, debate_type="risk")

        assert result.is_valid is True
        assert result.metrics["count"] == 2

    def test_none_debate_state_fails(self):
        """Test that None debate state fails."""
        result = validate_debate_state(None)

        assert result.is_valid is False
        assert "None" in result.errors[0]

    def test_wrong_type_fails(self):
        """Test that non-dict debate state fails."""
        result = validate_debate_state("not a dict")

        assert result.is_valid is False
        assert "dict" in result.errors[0].lower()

    def test_missing_required_fields_fails(self):
        """Test that missing required fields fails."""
        incomplete_state = {
            "history": "Some history",
            # Missing count and judge_decision
        }

        result = validate_debate_state(incomplete_state)

        assert result.is_valid is False
        assert any("missing" in err.lower() for err in result.errors)

    def test_invalid_debate_type_fails(self):
        """Test that unknown debate type fails."""
        debate_state = {
            "history": "History",
            "count": 1,
            "judge_decision": "BUY",
        }

        result = validate_debate_state(debate_state, debate_type="unknown")

        assert result.is_valid is False
        assert "unknown" in result.errors[0].lower()

    def test_empty_history_warning(self):
        """Test warning for empty history."""
        debate_state = {
            "history": "",
            "count": 0,
            "judge_decision": "HOLD",
        }

        result = validate_debate_state(debate_state)

        assert result.is_valid is True
        assert any("empty" in warn.lower() for warn in result.warnings)

    def test_negative_count_fails(self):
        """Test that negative count fails."""
        debate_state = {
            "history": "History",
            "count": -1,
            "judge_decision": "BUY",
        }

        result = validate_debate_state(debate_state)

        assert result.is_valid is False
        assert any("negative" in err.lower() for err in result.errors)

    def test_high_count_warning(self):
        """Test warning for very high debate count."""
        debate_state = {
            "history": "Long debate...",
            "count": 15,
            "judge_decision": "SELL",
        }

        result = validate_debate_state(debate_state)

        assert result.is_valid is True
        assert any("high" in warn.lower() for warn in result.warnings)

    def test_invalid_judge_decision_warning(self):
        """Test warning for poor quality judge decision."""
        debate_state = {
            "history": "History",
            "count": 2,
            "judge_decision": "No clear signal here",
        }

        result = validate_debate_state(debate_state)

        assert result.is_valid is True
        assert len(result.warnings) > 0

    def test_optional_fields_metric(self):
        """Test that optional fields are counted."""
        debate_state = {
            "history": "History",
            "count": 1,
            "judge_decision": "BUY",
            "bull_history": "Bull",
            "bear_history": "Bear",
        }

        result = validate_debate_state(debate_state, debate_type="invest")

        assert result.metrics["optional_fields_present"] >= 2

    def test_wrong_history_type_fails(self):
        """Test that non-string history fails."""
        debate_state = {
            "history": 123,
            "count": 1,
            "judge_decision": "BUY",
        }

        result = validate_debate_state(debate_state)

        assert result.is_valid is False
        assert any("string" in err.lower() for err in result.errors)

    def test_wrong_count_type_fails(self):
        """Test that non-int count fails."""
        debate_state = {
            "history": "History",
            "count": "two",
            "judge_decision": "BUY",
        }

        result = validate_debate_state(debate_state)

        assert result.is_valid is False
        assert any("int" in err.lower() for err in result.errors)


# ============================================================================
# Test Agent State Validation
# ============================================================================
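
# Assumed contract for validate_agent_state(), inferred from the tests below:
# "company_of_interest" and "trade_date" are required; the four analyst
# reports, both debate states, and the final trade decision are optional but
# produce warnings (and metrics such as "reports_present" out of
# "total_reports_expected") when missing or invalid.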

class TestAgentStateValidation:
    """Test validate_agent_state() function."""

    def test_valid_complete_agent_state(self):
        """Test that complete valid agent state passes."""
        state = {
            "company_of_interest": "AAPL",
            "trade_date": "2024-01-15",
            "market_report": "# Market Analysis\n\n" + "Detailed analysis. " * 100,
            "sentiment_report": "# Sentiment Report\n\n" + "Social sentiment. " * 100,
            "news_report": "# News Report\n\n" + "Latest news. " * 100,
            "fundamentals_report": "# Fundamentals\n\n" + "Financial data. " * 100,
            "investment_debate_state": {
                "history": "Debate history",
                "count": 3,
                "judge_decision": "BUY: Strong case",
            },
            "risk_debate_state": {
                "history": "Risk debate",
                "count": 2,
                "judge_decision": "HOLD: Moderate risk",
            },
            "final_trade_decision": "BUY: All signals align positively",
        }

        result = validate_agent_state(state)

        assert result.is_valid is True
        assert result.metrics["company_of_interest"] == "AAPL"
        assert result.metrics["trade_date"] == "2024-01-15"
        assert result.metrics["reports_present"] == 4
        assert result.metrics["final_signal"] == "BUY"

    def test_none_state_fails(self):
        """Test that None state fails."""
        result = validate_agent_state(None)

        assert result.is_valid is False
        assert "None" in result.errors[0]

    def test_wrong_type_fails(self):
        """Test that non-dict state fails."""
        result = validate_agent_state("not a dict")

        assert result.is_valid is False
        assert "dict" in result.errors[0].lower()

    def test_missing_company_fails(self):
        """Test that missing company fails."""
        state = {
            "trade_date": "2024-01-15",
        }

        result = validate_agent_state(state)

        assert result.is_valid is False
        assert any("company" in err.lower() for err in result.errors)

    def test_missing_trade_date_fails(self):
        """Test that missing trade date fails."""
        state = {
            "company_of_interest": "AAPL",
        }

        result = validate_agent_state(state)

        assert result.is_valid is False
        assert any("trade_date" in err.lower() for err in result.errors)

    def test_incomplete_reports_warning(self):
        """Test warning when some reports are missing."""
        state = {
            "company_of_interest": "AAPL",
            "trade_date": "2024-01-15",
            "market_report": "Market analysis. " * 100,
            # Missing other reports
        }

        result = validate_agent_state(state)

        # Basic fields present, so valid
        assert result.is_valid is True
        # But warn about missing reports
        assert len(result.warnings) > 0
        assert result.metrics["reports_present"] < 4

    def test_invalid_report_warning(self):
        """Test warning for invalid report content."""
        state = {
            "company_of_interest": "AAPL",
            "trade_date": "2024-01-15",
            "market_report": "Too short",  # Below min length
        }

        result = validate_agent_state(state)

        assert result.is_valid is True
        assert any("market_report" in warn.lower() for warn in result.warnings)

    def test_invalid_invest_debate_warning(self):
        """Test warning for invalid investment debate."""
        state = {
            "company_of_interest": "AAPL",
            "trade_date": "2024-01-15",
            "investment_debate_state": {
                # Missing required fields
                "history": "History",
            },
        }

        result = validate_agent_state(state)

        assert result.is_valid is True
        assert any("investment debate" in warn.lower() for warn in result.warnings)

    def test_invalid_risk_debate_warning(self):
        """Test warning for invalid risk debate."""
        state = {
            "company_of_interest": "AAPL",
            "trade_date": "2024-01-15",
            "risk_debate_state": {
                "count": -1,  # Invalid
            },
        }

        result = validate_agent_state(state)

        assert result.is_valid is True
        assert any("risk debate" in warn.lower() for warn in result.warnings)

    def test_invalid_final_decision_warning(self):
        """Test warning for invalid final decision."""
        state = {
            "company_of_interest": "AAPL",
            "trade_date": "2024-01-15",
            "final_trade_decision": "No clear signal",
        }

        result = validate_agent_state(state)

        assert result.is_valid is True
        assert any("final decision" in warn.lower() for warn in result.warnings)

    def test_incomplete_state_warning(self):
        """Test warning for very incomplete state."""
        state = {
            "company_of_interest": "AAPL",
            "trade_date": "2024-01-15",
            # No debates or decision
        }

        result = validate_agent_state(state)

        assert result.is_valid is True
        assert any("incomplete" in warn.lower() for warn in result.warnings)

    def test_reports_count_metrics(self):
        """Test that report counts are tracked."""
        state = {
            "company_of_interest": "AAPL",
            "trade_date": "2024-01-15",
            "market_report": "Report. " * 100,
            "sentiment_report": "Report. " * 100,
        }

        result = validate_agent_state(state)

        assert result.metrics["reports_present"] == 2
        assert result.metrics["total_reports_expected"] == 4