This commit is contained in:
MarkLo 2025-11-26 15:53:42 +08:00
parent c93d118308
commit 2f79956694
3 changed files with 138 additions and 7 deletions

View File

@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
import time
import json
from tradingagents.agents.utils.output_filter import fix_common_llm_errors, validate_and_warn
def create_risk_manager(llm, memory):
@@ -49,13 +50,10 @@ def create_risk_manager(llm, memory):
# 從記憶體中獲取過去相似情況的經驗
past_memories = memory.get_memories(curr_situation, n_matches=2)
# 將過去的經驗格式化為字串(限制長度)
# 將過去的經驗格式化為字串
past_memory_str = ""
for i, rec in enumerate(past_memories, 1):
recommendation = rec["recommendation"]
# 限制每條記憶的長度
if len(recommendation) > 200:
recommendation = recommendation[:200] + "...(已截斷)"
past_memory_str += recommendation + "\n\n"
# 截斷辯論歷史 - 這是最容易超過限制的部分
@@ -108,6 +106,10 @@ def create_risk_manager(llm, memory):
# 呼叫 LLM 生成決策
response = llm.invoke(prompt)
# CRITICAL FIX: Apply output filtering to fix common LLM errors
response.content = fix_common_llm_errors(response.content)
validate_and_warn(response.content, "Risk_Manager")
# 更新風險辯論狀態
new_risk_debate_state = {

View File

@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
"""
LLM Output Post-Processing Filter
Fixes common LLM output errors including character corruption and format issues
"""
import re
def count_chinese_words(text: str) -> int:
"""
Count Chinese characters in text (excluding markdown and tables)
Args:
text: Input text
Returns:
Number of Chinese characters
"""
# Remove code blocks
clean_text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
# Remove tables
clean_text = re.sub(r'\|.*?\|', '', clean_text, flags=re.MULTILINE)
# Remove markdown formatting
clean_text = re.sub(r'[#\*_`~\[\]\(\)]', '', clean_text)
# Count Chinese characters (CJK Unified Ideographs)
return len([c for c in clean_text if '\u4e00' <= c <= '\u9fff'])
def fix_common_llm_errors(text: str) -> str:
    """
    Fix common LLM character selection errors.

    Applies a fixed table of literal substring substitutions covering
    characters the LLM is known to confuse (e.g. '煉' where '練' is
    meant in practice/training contexts).

    Args:
        text: LLM output text

    Returns:
        Corrected text
    """
    # Known bad -> good substring pairs; extend as new corruptions are found.
    replacements = {
        # '煉' misuse - should be '練' (practice/train) in most contexts
        '煉習': '練習',
        '訓煉': '訓練',
        '**煉**': '**練**',
        '(煉': '(練',
        '煉)': '練)',
        # Other common errors (add as discovered)
        '絓驗': '經驗',  # corruption observed in earlier runs
    }
    for bad, good in replacements.items():
        text = text.replace(bad, good)
    return text
def validate_and_warn(content: str, agent_name: str) -> list:
    """
    Validate report content and return list of warnings.

    Checks performed:
      * suspicious use of the character '煉' outside legitimate words,
      * Chinese word count within the expected 800-1500 range,
      * leftover truncation markers that should never reach the report.

    Any warnings found are also printed to stdout.

    Args:
        content: Report content
        agent_name: Name of the agent (used in the printed header)

    Returns:
        List of warning messages (empty if the content passed all checks)
    """
    warnings = []

    # Check for suspicious '煉' character
    # BUGFIX: this previously tested `'' in content` (an empty string),
    # which is always True, so every report was flagged and find('')
    # returned index 0, producing a bogus context window. Test for the
    # actual character instead.
    # Only flag if not in proper contexts like "冶煉", "提煉", "鍛煉"
    if '煉' in content:
        proper_contexts = ['冶煉', '提煉', '鍛煉', '精煉', '修煉']
        is_proper = any(ctx in content for ctx in proper_contexts)
        if not is_proper:
            # Show a +/-15 character window around the first occurrence
            idx = content.find('煉')
            context = content[max(0, idx-15):min(len(content), idx+15)]
            warnings.append(f"Suspicious '煉' character found. Context: ...{context}...")

    # Check word count against the expected report length
    word_count = count_chinese_words(content)
    if word_count < 800:
        warnings.append(f"Too short: {word_count} words (expected 800-1500)")
    elif word_count > 1500:
        warnings.append(f"Too long: {word_count} words (expected 800-1500)")

    # Check for truncation markers that shouldn't be there
    truncation_markers = ['...(已截斷)', '...(內容已截斷)', '...(為控制長度已精簡)']
    for marker in truncation_markers:
        if marker in content:
            warnings.append(f"Found truncation marker: '{marker}'")

    if warnings:
        print(f"\n⚠️ {agent_name} Report Warnings:")
        for warning in warnings:
            print(f"  - {warning}")

    return warnings
def post_process_agent_output(content: str, agent_name: str, retry_callback=None) -> str:
    """
    Complete post-processing pipeline for agent output.

    Runs character-error fixing, then validation (which prints any
    warnings), then a length check that announces — but does not itself
    perform — a retry.

    Args:
        content: Raw agent output
        agent_name: Name of the agent
        retry_callback: Optional function to call if validation fails

    Returns:
        Processed and validated content
    """
    # Step 1: normalize known character-level corruptions.
    content = fix_common_llm_errors(content)

    # Step 2: run validation for its warning side effects.
    validate_and_warn(content, agent_name)

    # Step 3: announce a retry when the length is out of range and a
    # callback was supplied. Actually regenerating the content is left
    # to the caller; this function only signals the condition.
    word_count = count_chinese_words(content)
    if retry_callback and (word_count < 800 or word_count > 1500):
        print(f"\n🔄 {agent_name}: Word count {word_count} out of range, triggering retry...")

    return content

View File

@@ -42,7 +42,7 @@ def get_stock_news_openai(query, start_date, end_date):
"search_context_size": "low",
}
],
temperature=1,
temperature=0.7, # Reduced from 1.0 to prevent character errors like '煉' and stabilize output
max_output_tokens=4096,
top_p=1,
store=True,
@@ -90,7 +90,7 @@ def get_global_news_openai(curr_date, look_back_days=7, limit=5):
"search_context_size": "low",
}
],
temperature=1,
temperature=0.7, # Reduced from 1.0 to prevent character errors and stabilize output
max_output_tokens=4096,
top_p=1,
store=True,
@@ -137,7 +137,7 @@ def get_fundamentals_openai(ticker, curr_date):
"search_context_size": "low",
}
],
temperature=1,
temperature=0.7, # Reduced from 1.0 to prevent character errors and stabilize output
max_output_tokens=4096,
top_p=1,
store=True,