This commit is contained in:
parent
c93d118308
commit
2f79956694
|
|
@ -1,6 +1,7 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
from tradingagents.agents.utils.output_filter import fix_common_llm_errors, validate_and_warn
|
||||||
|
|
||||||
|
|
||||||
def create_risk_manager(llm, memory):
|
def create_risk_manager(llm, memory):
|
||||||
|
|
@ -49,13 +50,10 @@ def create_risk_manager(llm, memory):
|
||||||
# 從記憶體中獲取過去相似情況的經驗
|
# 從記憶體中獲取過去相似情況的經驗
|
||||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||||
|
|
||||||
# 將過去的經驗格式化為字串(限制長度)
|
# 將過去的經驗格式化為字串
|
||||||
past_memory_str = ""
|
past_memory_str = ""
|
||||||
for i, rec in enumerate(past_memories, 1):
|
for i, rec in enumerate(past_memories, 1):
|
||||||
recommendation = rec["recommendation"]
|
recommendation = rec["recommendation"]
|
||||||
# 限制每條記憶的長度
|
|
||||||
if len(recommendation) > 200:
|
|
||||||
recommendation = recommendation[:200] + "...(已截斷)"
|
|
||||||
past_memory_str += recommendation + "\n\n"
|
past_memory_str += recommendation + "\n\n"
|
||||||
|
|
||||||
# 截斷辯論歷史 - 這是最容易超過限制的部分
|
# 截斷辯論歷史 - 這是最容易超過限制的部分
|
||||||
|
|
@ -108,6 +106,10 @@ def create_risk_manager(llm, memory):
|
||||||
|
|
||||||
# 呼叫 LLM 生成決策
|
# 呼叫 LLM 生成決策
|
||||||
response = llm.invoke(prompt)
|
response = llm.invoke(prompt)
|
||||||
|
|
||||||
|
# CRITICAL FIX: Apply output filtering to fix common LLM errors
|
||||||
|
response.content = fix_common_llm_errors(response.content)
|
||||||
|
validate_and_warn(response.content, "Risk_Manager")
|
||||||
|
|
||||||
# 更新風險辯論狀態
|
# 更新風險辯論狀態
|
||||||
new_risk_debate_state = {
|
new_risk_debate_state = {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,129 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
LLM Output Post-Processing Filter
|
||||||
|
Fixes common LLM output errors including character corruption and format issues
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def count_chinese_words(text: str) -> int:
|
||||||
|
"""
|
||||||
|
Count Chinese characters in text (excluding markdown and tables)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of Chinese characters
|
||||||
|
"""
|
||||||
|
# Remove code blocks
|
||||||
|
clean_text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
|
||||||
|
# Remove tables
|
||||||
|
clean_text = re.sub(r'\|.*?\|', '', clean_text, flags=re.MULTILINE)
|
||||||
|
# Remove markdown formatting
|
||||||
|
clean_text = re.sub(r'[#\*_`~\[\]\(\)]', '', clean_text)
|
||||||
|
|
||||||
|
# Count Chinese characters (CJK Unified Ideographs)
|
||||||
|
return len([c for c in clean_text if '\u4e00' <= c <= '\u9fff'])
|
||||||
|
|
||||||
|
|
||||||
|
def fix_common_llm_errors(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Fix common LLM character selection errors
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: LLM output text
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Corrected text
|
||||||
|
"""
|
||||||
|
# Common character misuse patterns
|
||||||
|
replacements = {
|
||||||
|
# '煉' misuse - should be '練' (practice/train) in most contexts
|
||||||
|
'煉習': '練習',
|
||||||
|
'訓煉': '訓練',
|
||||||
|
'**煉**': '**練**',
|
||||||
|
'(煉': '(練',
|
||||||
|
'煉)': '練)',
|
||||||
|
|
||||||
|
# Other common errors (add as discovered)
|
||||||
|
'絓驗': '經驗', # We saw this corruption before
|
||||||
|
}
|
||||||
|
|
||||||
|
for wrong, correct in replacements.items():
|
||||||
|
text = text.replace(wrong, correct)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def validate_and_warn(content: str, agent_name: str) -> list:
|
||||||
|
"""
|
||||||
|
Validate report content and return list of warnings
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Report content
|
||||||
|
agent_name: Name of the agent
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of warning messages
|
||||||
|
"""
|
||||||
|
warnings = []
|
||||||
|
|
||||||
|
# Check for suspicious '煉' character
|
||||||
|
# Only flag if not in proper contexts like "冶煉", "提煉", "鍛煉"
|
||||||
|
if '煉' in content:
|
||||||
|
proper_contexts = ['冶煉', '提煉', '鍛煉', '精煉', '修煉']
|
||||||
|
is_proper = any(ctx in content for ctx in proper_contexts)
|
||||||
|
if not is_proper:
|
||||||
|
# Find context around '煉'
|
||||||
|
idx = content.find('煉')
|
||||||
|
context = content[max(0, idx-15):min(len(content), idx+15)]
|
||||||
|
warnings.append(f"Suspicious '煉' character found. Context: ...{context}...")
|
||||||
|
|
||||||
|
# Check word count
|
||||||
|
word_count = count_chinese_words(content)
|
||||||
|
if word_count < 800:
|
||||||
|
warnings.append(f"Too short: {word_count} words (expected 800-1500)")
|
||||||
|
elif word_count > 1500:
|
||||||
|
warnings.append(f"Too long: {word_count} words (expected 800-1500)")
|
||||||
|
|
||||||
|
# Check for truncation markers that shouldn't be there
|
||||||
|
truncation_markers = ['...(已截斷)', '...(內容已截斷)', '...(為控制長度已精簡)']
|
||||||
|
for marker in truncation_markers:
|
||||||
|
if marker in content:
|
||||||
|
warnings.append(f"Found truncation marker: '{marker}'")
|
||||||
|
|
||||||
|
if warnings:
|
||||||
|
print(f"\n⚠️ {agent_name} Report Warnings:")
|
||||||
|
for warning in warnings:
|
||||||
|
print(f" - {warning}")
|
||||||
|
|
||||||
|
return warnings
|
||||||
|
|
||||||
|
|
||||||
|
def post_process_agent_output(content: str, agent_name: str, retry_callback=None) -> str:
|
||||||
|
"""
|
||||||
|
Complete post-processing pipeline for agent output
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Raw agent output
|
||||||
|
agent_name: Name of the agent
|
||||||
|
retry_callback: Optional function to call if validation fails
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed and validated content
|
||||||
|
"""
|
||||||
|
# Step 1: Fix common errors
|
||||||
|
content = fix_common_llm_errors(content)
|
||||||
|
|
||||||
|
# Step 2: Validate and warn
|
||||||
|
warnings = validate_and_warn(content, agent_name)
|
||||||
|
|
||||||
|
# Step 3: Critical validation - retry if needed
|
||||||
|
word_count = count_chinese_words(content)
|
||||||
|
if (word_count < 800 or word_count > 1500) and retry_callback:
|
||||||
|
print(f"\n🔄 {agent_name}: Word count {word_count} out of range, triggering retry...")
|
||||||
|
# Callback should regenerate the content
|
||||||
|
# This is optional and should be implemented in the calling code
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
@ -42,7 +42,7 @@ def get_stock_news_openai(query, start_date, end_date):
|
||||||
"search_context_size": "low",
|
"search_context_size": "low",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
temperature=1,
|
temperature=0.7, # Reduced from 1.0 to prevent character errors like '煉' and stabilize output
|
||||||
max_output_tokens=4096,
|
max_output_tokens=4096,
|
||||||
top_p=1,
|
top_p=1,
|
||||||
store=True,
|
store=True,
|
||||||
|
|
@ -90,7 +90,7 @@ def get_global_news_openai(curr_date, look_back_days=7, limit=5):
|
||||||
"search_context_size": "low",
|
"search_context_size": "low",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
temperature=1,
|
temperature=0.7, # Reduced from 1.0 to prevent character errors and stabilize output
|
||||||
max_output_tokens=4096,
|
max_output_tokens=4096,
|
||||||
top_p=1,
|
top_p=1,
|
||||||
store=True,
|
store=True,
|
||||||
|
|
@ -137,7 +137,7 @@ def get_fundamentals_openai(ticker, curr_date):
|
||||||
"search_context_size": "low",
|
"search_context_size": "low",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
temperature=1,
|
temperature=0.7, # Reduced from 1.0 to prevent character errors and stabilize output
|
||||||
max_output_tokens=4096,
|
max_output_tokens=4096,
|
||||||
top_p=1,
|
top_p=1,
|
||||||
store=True,
|
store=True,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue