From 2f79956694a6d7f39ce16c64d4874b1587801976 Mon Sep 17 00:00:00 2001 From: MarkLo Date: Wed, 26 Nov 2025 15:53:42 +0800 Subject: [PATCH] --- tradingagents/agents/managers/risk_manager.py | 10 +- tradingagents/agents/utils/output_filter.py | 129 ++++++++++++++++++ tradingagents/dataflows/openai.py | 6 +- 3 files changed, 138 insertions(+), 7 deletions(-) create mode 100644 tradingagents/agents/utils/output_filter.py diff --git a/tradingagents/agents/managers/risk_manager.py b/tradingagents/agents/managers/risk_manager.py index 642a78fe..5a079d06 100644 --- a/tradingagents/agents/managers/risk_manager.py +++ b/tradingagents/agents/managers/risk_manager.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import time import json +from tradingagents.agents.utils.output_filter import fix_common_llm_errors, validate_and_warn def create_risk_manager(llm, memory): @@ -49,13 +50,10 @@ def create_risk_manager(llm, memory): # 從記憶體中獲取過去相似情況的經驗 past_memories = memory.get_memories(curr_situation, n_matches=2) - # 將過去的經驗格式化為字串(限制長度) + # 將過去的經驗格式化為字串 past_memory_str = "" for i, rec in enumerate(past_memories, 1): recommendation = rec["recommendation"] - # 限制每條記憶的長度 - if len(recommendation) > 200: - recommendation = recommendation[:200] + "...(已截斷)" past_memory_str += recommendation + "\n\n" # 截斷辯論歷史 - 這是最容易超過限制的部分 @@ -108,6 +106,10 @@ def create_risk_manager(llm, memory): # 呼叫 LLM 生成決策 response = llm.invoke(prompt) + + # CRITICAL FIX: Apply output filtering to fix common LLM errors + response.content = fix_common_llm_errors(response.content) + validate_and_warn(response.content, "Risk_Manager") # 更新風險辯論狀態 new_risk_debate_state = { diff --git a/tradingagents/agents/utils/output_filter.py b/tradingagents/agents/utils/output_filter.py new file mode 100644 index 00000000..bfe3793c --- /dev/null +++ b/tradingagents/agents/utils/output_filter.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +""" +LLM Output Post-Processing Filter +Fixes common LLM output errors including character corruption and format issues +""" +import re + + +def count_chinese_words(text: str) -> int: + """ + Count Chinese characters in text (excluding markdown and tables) + + Args: + text: Input text + + Returns: + Number of Chinese characters + """ + # Remove code blocks + clean_text = re.sub(r'```.*?```', '', text, flags=re.DOTALL) + # Remove tables + clean_text = re.sub(r'\|.*?\|', '', clean_text, flags=re.MULTILINE) + # Remove markdown formatting + clean_text = re.sub(r'[#\*_`~\[\]\(\)]', '', clean_text) + + # Count Chinese characters (CJK Unified Ideographs) + return len([c for c in clean_text if '\u4e00' <= c <= '\u9fff']) + + +def fix_common_llm_errors(text: str) -> str: + """ + Fix common LLM character selection errors + + Args: + text: LLM output text + + Returns: + Corrected text + """ + # Common character misuse patterns + replacements = { + # '煉' misuse - should be '練' (practice/train) in most contexts + '煉習': '練習', + '訓煉': '訓練', + '**煉**': '**練**', + '(煉': '(練', + '煉)': '練)', + + # Other common errors (add as discovered) + '絓驗': '經驗', # We saw this corruption before + } + + for wrong, correct in replacements.items(): + text = text.replace(wrong, correct) + + return text + + +def validate_and_warn(content: str, agent_name: str) -> list: + """ + Validate report content and return list of warnings + + Args: + content: Report content + agent_name: Name of the agent + + Returns: + List of warning messages + """ + warnings = [] + + # Check for suspicious '煉' character + # Only flag if not in proper contexts like "冶煉", "提煉", "鍛煉" + if '煉' in content: + proper_contexts = ['冶煉', '提煉', '鍛煉', '精煉', '修煉'] + is_proper = any(ctx in content for ctx in proper_contexts) + if not is_proper: + # Find context around '煉' + idx = content.find('煉') + context = content[max(0, idx-15):min(len(content), idx+15)] + warnings.append(f"Suspicious '煉' character found. Context: ...{context}...") + + # Check word count + word_count = count_chinese_words(content) + if word_count < 800: + warnings.append(f"Too short: {word_count} words (expected 800-1500)") + elif word_count > 1500: + warnings.append(f"Too long: {word_count} words (expected 800-1500)") + + # Check for truncation markers that shouldn't be there + truncation_markers = ['...(已截斷)', '...(內容已截斷)', '...(為控制長度已精簡)'] + for marker in truncation_markers: + if marker in content: + warnings.append(f"Found truncation marker: '{marker}'") + + if warnings: + print(f"\n⚠️ {agent_name} Report Warnings:") + for warning in warnings: + print(f" - {warning}") + + return warnings + + +def post_process_agent_output(content: str, agent_name: str, retry_callback=None) -> str: + """ + Complete post-processing pipeline for agent output + + Args: + content: Raw agent output + agent_name: Name of the agent + retry_callback: Optional function to call if validation fails + + Returns: + Processed and validated content + """ + # Step 1: Fix common errors + content = fix_common_llm_errors(content) + + # Step 2: Validate and warn + warnings = validate_and_warn(content, agent_name) + + # Step 3: Critical validation - retry if needed + word_count = count_chinese_words(content) + if (word_count < 800 or word_count > 1500) and retry_callback: + print(f"\n🔄 {agent_name}: Word count {word_count} out of range, triggering retry...") + # Callback should regenerate the content + # This is optional and should be implemented in the calling code + + return content diff --git a/tradingagents/dataflows/openai.py b/tradingagents/dataflows/openai.py index 6174bbe0..6f182c43 100644 --- a/tradingagents/dataflows/openai.py +++ b/tradingagents/dataflows/openai.py @@ -42,7 +42,7 @@ def get_stock_news_openai(query, start_date, end_date): "search_context_size": "low", } ], - temperature=1, + temperature=0.7, # Reduced from 1.0 to prevent character errors like '煉' and stabilize output max_output_tokens=4096, top_p=1, store=True, @@ -90,7 +90,7 @@ def get_global_news_openai(curr_date, look_back_days=7, limit=5): "search_context_size": "low", } ], - temperature=1, + temperature=0.7, # Reduced from 1.0 to prevent character errors and stabilize output max_output_tokens=4096, top_p=1, store=True, @@ -137,7 +137,7 @@ def get_fundamentals_openai(ticker, curr_date): "search_context_size": "low", } ], - temperature=1, + temperature=0.7, # Reduced from 1.0 to prevent character errors and stabilize output max_output_tokens=4096, top_p=1, store=True,