# TradingAgents/tests/test_output_language_propag... (path truncated by the file viewer)
#
# Viewer metadata: 174 lines, 6.9 KiB, Python

import copy
import unittest
from tradingagents.agents.managers.portfolio_manager import create_portfolio_manager
from tradingagents.agents.managers.research_manager import create_research_manager
from tradingagents.agents.researchers.bear_researcher import create_bear_researcher
from tradingagents.agents.researchers.bull_researcher import create_bull_researcher
from tradingagents.agents.risk_mgmt.aggressive_debator import create_aggressive_debator
from tradingagents.agents.risk_mgmt.conservative_debator import create_conservative_debator
from tradingagents.agents.risk_mgmt.neutral_debator import create_neutral_debator
from tradingagents.agents.trader.trader import create_trader
from tradingagents.agents.utils.agent_utils import (
get_collaboration_stop_instruction,
normalize_chinese_role_terms,
)
from tradingagents.dataflows.config import get_config, set_config
from tradingagents.default_config import DEFAULT_CONFIG
class _FakeResponse:
    """Bare-bones reply object returned by the stub LLM in this module.

    It carries a single attribute, ``content``, holding the reply text.
    """

    def __init__(self, content="ok"):
        """Store ``content`` (defaults to ``"ok"``) as the reply text."""
        self.content = content
class _CapturingLLM:
    """Stub LLM that records every prompt and answers with a canned reply.

    ``calls`` lists the prompts in the order ``invoke`` received them, so a
    test can inspect exactly what an agent node sent.
    """

    # Fixed reply text handed back on every call, whatever the prompt was.
    _CANNED_REPLY = (
        "测试输出\n反馈快照:\n"
        "- 当前观点: x\n"
        "- 发生了什么变化: y\n"
        "- 为什么变化: z\n"
        "- 关键反驳: r\n"
        "- 下一轮教训: l"
    )

    def __init__(self):
        """Start with an empty call log."""
        self.calls = []

    def invoke(self, prompt, *_args, **_kwargs):
        """Record ``prompt`` and return the canned ``_FakeResponse``.

        Extra positional and keyword arguments are accepted and ignored,
        matching ``_EmptyMemory.get_memories``, so a caller that passes
        additional options does not fail inside the stub.
        """
        self.calls.append(prompt)
        return _FakeResponse(self._CANNED_REPLY)
class _EmptyMemory:
    """Stub memory store that never has anything to recall."""

    def get_memories(self, *_args, **_kwargs):
        """Return a new empty list, ignoring every argument."""
        return []
class OutputLanguagePropagationTests(unittest.TestCase):
    """Check that agent prompts reflect ``output_language = "Chinese"``.

    Each prompt test builds an agent node around a prompt-capturing stub LLM,
    runs the node once, and asserts on the first prompt the stub received.
    """

    # Sentence every agent prompt is expected to contain under this config.
    _LANGUAGE_DIRECTIVE = "Write your entire response in Chinese."
    # Prompt fragment that forbids placeholder answers in the snapshot fields.
    _NO_PLACEHOLDER_RULE = "禁止填写“未明确说明”“暂无”“同上”“无变化”"

    def setUp(self):
        """Switch the global config to Chinese output and build the state."""
        # Snapshot the active config so tearDown can hand it back.
        self.original_config = copy.deepcopy(get_config())
        cfg = copy.deepcopy(DEFAULT_CONFIG)
        cfg["output_language"] = "Chinese"
        set_config(cfg)
        self.base_state = {
            "company_of_interest": "002155.SZ",
            "investment_plan": "Plan",
            "market_report": "Market report",
            "sentiment_report": "Sentiment report",
            "news_report": "News report",
            "fundamentals_report": "Fundamentals report",
            "trader_investment_plan": "Trader plan",
            "risk_debate_state": {
                "history": "",
                "aggressive_history": "",
                "conservative_history": "",
                "neutral_history": "",
                "latest_speaker": "",
                "current_aggressive_response": "",
                "current_conservative_response": "",
                "current_neutral_response": "",
                "aggressive_snapshot": "",
                "conservative_snapshot": "",
                "neutral_snapshot": "",
                "debate_brief": "",
                "judge_decision": "",
                "count": 0,
            },
            "investment_debate_state": {
                "history": "",
                "bear_history": "",
                "bull_history": "",
                "current_response": "",
                "bull_snapshot": "",
                "bear_snapshot": "",
                "debate_brief": "",
                "latest_speaker": "",
                "count": 0,
            },
        }

    def tearDown(self):
        """Pass the config captured in ``setUp`` back to ``set_config``."""
        set_config(self.original_config)

    def _first_prompt(self, factory, *factory_args):
        """Run the node built by ``factory`` once; return its first prompt.

        The node is created as ``factory(llm, *factory_args)`` with a fresh
        ``_CapturingLLM``. It receives a deep copy of ``self.base_state`` so
        that a node which mutates its input cannot affect the next node run
        in the same test.
        """
        llm = _CapturingLLM()
        node = factory(llm, *factory_args)
        node(copy.deepcopy(self.base_state))
        # Fail with a readable message instead of an IndexError below.
        self.assertTrue(llm.calls, f"{factory!r} never invoked the LLM")
        return llm.calls[0]

    def test_trader_prompt_respects_output_language(self):
        """The trader's first message content carries the directive."""
        messages = self._first_prompt(create_trader, _EmptyMemory())
        system_prompt = messages[0]["content"]
        self.assertIn(self._LANGUAGE_DIRECTIVE, system_prompt)

    def test_research_manager_prompt_respects_output_language(self):
        """The research manager prompt uses the Chinese role/rating terms."""
        prompt = self._first_prompt(create_research_manager, _EmptyMemory())
        for fragment in (
            self._LANGUAGE_DIRECTIVE,
            "多头分析师",
            "空头分析师",
            "买入",
            "持有",
        ):
            self.assertIn(fragment, prompt)

    def test_bull_bear_researcher_prompts_require_chinese_body_and_chinese_final_title(self):
        """Bull and bear prompts contain the same set of Chinese fragments."""
        expected_fragments = (
            "written entirely in Chinese",
            "最终交易建议",
            self._LANGUAGE_DIRECTIVE,
            "Do not use variants like",
            "牛派分析师",
            "熊派分析师",
            self._NO_PLACEHOLDER_RULE,
        )
        for factory in (create_bull_researcher, create_bear_researcher):
            # subTest reports which factory failed and still runs the other.
            with self.subTest(factory=factory):
                prompt = self._first_prompt(factory, _EmptyMemory())
                for fragment in expected_fragments:
                    self.assertIn(fragment, prompt)

    def test_normalize_chinese_role_terms_replaces_bull_bear_variants(self):
        """The 牛派/熊派 role variants are rewritten to 多头/空头 terms."""
        text = "我是熊派分析师,也不同意牛派分析师和熊派投资者的说法。"
        normalized = normalize_chinese_role_terms(text)
        for variant in ("熊派分析师", "牛派分析师", "熊派投资者"):
            self.assertNotIn(variant, normalized)
        for canonical in ("空头分析师", "多头分析师", "空头投资者"):
            self.assertIn(canonical, normalized)

    def test_risk_team_prompts_respect_output_language(self):
        """All three risk debators get the directive and snapshot rules."""
        expected_fragments = (
            self._LANGUAGE_DIRECTIVE,
            "反馈快照",
            self._NO_PLACEHOLDER_RULE,
        )
        for factory in (
            create_aggressive_debator,
            create_conservative_debator,
            create_neutral_debator,
        ):
            # subTest reports which factory failed and still runs the rest.
            with self.subTest(factory=factory):
                prompt = self._first_prompt(factory)
                for fragment in expected_fragments:
                    self.assertIn(fragment, prompt)

    def test_portfolio_manager_prompt_respects_output_language(self):
        """The portfolio manager prompt lists Chinese roles and ratings."""
        prompt = self._first_prompt(create_portfolio_manager, _EmptyMemory())
        for fragment in (
            self._LANGUAGE_DIRECTIVE,
            "反馈快照",
            "激进分析师",
            "保守分析师",
            "中性分析师",
            "评级体系",
            "买入",
            "增持",
            "持有",
            "减持",
            "卖出",
            self._NO_PLACEHOLDER_RULE,
        ):
            self.assertIn(fragment, prompt)

    def test_collaboration_stop_instruction_prefers_chinese_display(self):
        """The stop instruction contains both Chinese and English markers."""
        instruction = get_collaboration_stop_instruction()
        self.assertIn("最终交易建议: **买入/持有/卖出**", instruction)
        self.assertIn("FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**", instruction)
if __name__ == "__main__":
    # Run the tests when this module is executed directly as a script.
    unittest.main()