This commit is contained in:
parent
c93d118308
commit
2f79956694
|
|
@ -1,6 +1,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import time
|
||||
import json
|
||||
from tradingagents.agents.utils.output_filter import fix_common_llm_errors, validate_and_warn
|
||||
|
||||
|
||||
def create_risk_manager(llm, memory):
|
||||
|
|
@ -49,13 +50,10 @@ def create_risk_manager(llm, memory):
|
|||
# 從記憶體中獲取過去相似情況的經驗
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
# 將過去的經驗格式化為字串(限制長度)
|
||||
# 將過去的經驗格式化為字串
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
recommendation = rec["recommendation"]
|
||||
# 限制每條記憶的長度
|
||||
if len(recommendation) > 200:
|
||||
recommendation = recommendation[:200] + "...(已截斷)"
|
||||
past_memory_str += recommendation + "\n\n"
|
||||
|
||||
# 截斷辯論歷史 - 這是最容易超過限制的部分
|
||||
|
|
@ -108,6 +106,10 @@ def create_risk_manager(llm, memory):
|
|||
|
||||
# 呼叫 LLM 生成決策
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
# CRITICAL FIX: Apply output filtering to fix common LLM errors
|
||||
response.content = fix_common_llm_errors(response.content)
|
||||
validate_and_warn(response.content, "Risk_Manager")
|
||||
|
||||
# 更新風險辯論狀態
|
||||
new_risk_debate_state = {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,129 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
LLM Output Post-Processing Filter
|
||||
Fixes common LLM output errors including character corruption and format issues
|
||||
"""
|
||||
import re
|
||||
|
||||
|
||||
def count_chinese_words(text: str) -> int:
    """Count Chinese characters in *text*, excluding code blocks, tables, and markdown.

    Args:
        text: Input text (may contain fenced code blocks, markdown tables,
            and inline markdown formatting).

    Returns:
        Number of CJK Unified Ideographs (U+4E00..U+9FFF) remaining after
        the non-prose regions are stripped.
    """
    # Drop fenced code blocks first so their contents are never counted.
    clean_text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
    # Drop whole markdown table rows (lines that start and end with '|').
    # FIX: the previous non-greedy pair pattern r'\|.*?\|' consumed pipes in
    # pairs and left every other table cell behind (and its re.MULTILINE flag
    # was a no-op, since the pattern had no ^/$ anchors). Matching the full
    # line removes the entire row in one pass.
    clean_text = re.sub(r'^\s*\|.*\|\s*$', '', clean_text, flags=re.MULTILINE)
    # Strip markdown punctuation (headers, emphasis, links, backticks).
    clean_text = re.sub(r'[#\*_`~\[\]\(\)]', '', clean_text)

    # Count CJK Unified Ideographs.
    return sum(1 for c in clean_text if '\u4e00' <= c <= '\u9fff')
|
||||
|
||||
|
||||
def fix_common_llm_errors(text: str) -> str:
    """Repair known character-selection mistakes in LLM output.

    Args:
        text: Raw LLM output text.

    Returns:
        The text with every known wrong/correct pair substituted.
    """
    # Ordered (wrong, correct) pairs, applied sequentially.
    corrections = (
        # '煉' misuse - should be '練' (practice/train) in most contexts
        ('煉習', '練習'),
        ('訓煉', '訓練'),
        ('**煉**', '**練**'),
        ('(煉', '(練'),
        ('煉)', '練)'),
        # Other common errors (add as discovered)
        ('絓驗', '經驗'),  # corruption observed previously
    )

    fixed = text
    for bad, good in corrections:
        fixed = fixed.replace(bad, good)
    return fixed
|
||||
|
||||
|
||||
def validate_and_warn(content: str, agent_name: str, *, min_words: int = 800, max_words: int = 1500) -> list:
    """Validate report content and return a list of warning messages.

    Side effect: when any warnings are found they are also printed,
    prefixed with the agent name.

    Args:
        content: Report content to check.
        agent_name: Name of the agent that produced the content.
        min_words: Minimum acceptable Chinese character count (default 800,
            matching the previous hard-coded bound).
        max_words: Maximum acceptable Chinese character count (default 1500,
            matching the previous hard-coded bound).

    Returns:
        List of warning message strings (empty when the content passes).
    """
    warnings = []

    # Check for suspicious '煉' character.
    # Only flag if not in proper contexts like "冶煉", "提煉", "鍛煉".
    # NOTE(review): a single proper context anywhere suppresses the warning
    # for every occurrence, and only the first occurrence is reported.
    if '煉' in content:
        proper_contexts = ['冶煉', '提煉', '鍛煉', '精煉', '修煉']
        is_proper = any(ctx in content for ctx in proper_contexts)
        if not is_proper:
            # Show the surrounding text so the misuse can be located quickly.
            idx = content.find('煉')
            context = content[max(0, idx-15):min(len(content), idx+15)]
            warnings.append(f"Suspicious '煉' character found. Context: ...{context}...")

    # Check word count against the configured range.
    word_count = count_chinese_words(content)
    if word_count < min_words:
        warnings.append(f"Too short: {word_count} words (expected {min_words}-{max_words})")
    elif word_count > max_words:
        warnings.append(f"Too long: {word_count} words (expected {min_words}-{max_words})")

    # Check for truncation markers that shouldn't be there.
    truncation_markers = ['...(已截斷)', '...(內容已截斷)', '...(為控制長度已精簡)']
    for marker in truncation_markers:
        if marker in content:
            warnings.append(f"Found truncation marker: '{marker}'")

    if warnings:
        print(f"\n⚠️ {agent_name} Report Warnings:")
        for warning in warnings:
            print(f" - {warning}")

    return warnings
|
||||
|
||||
|
||||
def post_process_agent_output(content: str, agent_name: str, retry_callback=None) -> str:
    """Run the complete post-processing pipeline over one agent's output.

    Args:
        content: Raw agent output.
        agent_name: Name of the agent that produced the output.
        retry_callback: Optional regeneration hook. NOTE: it is only
            checked for presence here — it is never invoked; the calling
            code is expected to implement the actual retry.

    Returns:
        The corrected and validated content.
    """
    # Step 1: repair known character-level errors.
    processed = fix_common_llm_errors(content)

    # Step 2: run validation and surface any warnings.
    validate_and_warn(processed, agent_name)

    # Step 3: flag an out-of-range length when a retry hook was supplied.
    length = count_chinese_words(processed)
    if retry_callback and (length < 800 or length > 1500):
        print(f"\n🔄 {agent_name}: Word count {length} out of range, triggering retry...")
        # Regeneration itself is left to the caller (see retry_callback note).

    return processed
|
||||
|
|
@ -42,7 +42,7 @@ def get_stock_news_openai(query, start_date, end_date):
|
|||
"search_context_size": "low",
|
||||
}
|
||||
],
|
||||
temperature=1,
|
||||
temperature=0.7, # Reduced from 1.0 to prevent character errors like '煉' and stabilize output
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
|
|
@ -90,7 +90,7 @@ def get_global_news_openai(curr_date, look_back_days=7, limit=5):
|
|||
"search_context_size": "low",
|
||||
}
|
||||
],
|
||||
temperature=1,
|
||||
temperature=0.7, # Reduced from 1.0 to prevent character errors and stabilize output
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
|
|
@ -137,7 +137,7 @@ def get_fundamentals_openai(ticker, curr_date):
|
|||
"search_context_size": "low",
|
||||
}
|
||||
],
|
||||
temperature=1,
|
||||
temperature=0.7, # Reduced from 1.0 to prevent character errors and stabilize output
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
|
|
|
|||
Loading…
Reference in New Issue