TradingAgents/tests/test_custom_prompt.py

58 lines
1.9 KiB
Python

import tempfile
import unittest
from pathlib import Path
from cli.main import save_report_to_disk
from tradingagents.agents.utils.agent_utils import build_custom_prompt_context
from tradingagents.graph.propagation import Propagator
class CustomPromptTests(unittest.TestCase):
    """Tests for passing a user-supplied custom prompt through the pipeline.

    Covers three places the prompt appears:

    * the context string produced by ``build_custom_prompt_context``;
    * the initial graph state created by ``Propagator.create_initial_state``;
    * the report written to disk by the CLI's ``save_report_to_disk``.
    """

    def test_build_custom_prompt_context_is_empty_when_missing(self):
        """Blank or whitespace-only prompts produce no extra context."""
        # Exercise several spellings of "no prompt", not just a single space.
        for blank_prompt in ("", " ", " \n\t "):
            with self.subTest(prompt=blank_prompt):
                self.assertEqual(build_custom_prompt_context(blank_prompt), "")

    def test_build_custom_prompt_context_formats_user_guidance(self):
        """Non-blank guidance is echoed back inside the instruction framing."""
        context = build_custom_prompt_context(
            "Long-term horizon; focus on earnings quality and capex discipline."
        )

        self.assertIn("Additional user instructions", context)
        self.assertIn("Long-term horizon", context)
        self.assertIn("explicit strategy constraints", context)

    def test_create_initial_state_stores_custom_prompt(self):
        """The custom prompt is stored unchanged under ``custom_prompt``."""
        prompt = "Short-term swing trade; new positions only."

        state = Propagator().create_initial_state(
            "META",
            "2026-04-05",
            custom_prompt=prompt,
        )

        self.assertEqual(state["custom_prompt"], prompt)

    def test_save_report_to_disk_includes_custom_prompt_header(self):
        """The saved report has a "## Custom Prompt" heading and the prompt text."""
        # Minimal final state: every report section is empty so that only
        # the custom prompt contributes meaningful content.
        final_state = {
            "custom_prompt": "Long-term horizon; focus on capital allocation.",
            "market_report": "",
            "sentiment_report": "",
            "news_report": "",
            "fundamentals_report": "",
            "investment_debate_state": {},
            "risk_debate_state": {},
        }

        with tempfile.TemporaryDirectory() as tmpdir:
            report_path = save_report_to_disk(
                final_state,
                "META",
                Path(tmpdir),
            )
            # Read while still inside the context: TemporaryDirectory removes
            # tmpdir (and the report written into it) when the block exits.
            report_text = Path(report_path).read_text()

        self.assertIn("## Custom Prompt", report_text)
        self.assertIn("Long-term horizon", report_text)
if __name__ == "__main__":
    # Allow running this module directly as a script, not only via a test runner.
    unittest.main()