💬 支持用户自定义分析视角注入(user_context 参数)
This commit is contained in:
parent
a125f5c906
commit
8daa8b3477
|
|
@ -4,7 +4,8 @@ TradingAgents 分析脚本
|
||||||
用法:
|
用法:
|
||||||
python run_analysis.py NVDA # 分析 NVDA,使用今日日期
|
python run_analysis.py NVDA # 分析 NVDA,使用今日日期
|
||||||
python run_analysis.py VOO # 分析 VOO
|
python run_analysis.py VOO # 分析 VOO
|
||||||
python run_analysis.py NVDA 2025-03-01 # 分析指定日期
|
python run_analysis.py NVDA 2026-03-20 # 指定日期
|
||||||
|
python run_analysis.py NVDA 2026-03-20 "中东地缘冲突..." # 注入自定义分析视角
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -23,7 +24,7 @@ config = DEFAULT_CONFIG.copy()
|
||||||
config["llm_provider"] = "google"
|
config["llm_provider"] = "google"
|
||||||
config["deep_think_llm"] = "gemini-2.5-flash"
|
config["deep_think_llm"] = "gemini-2.5-flash"
|
||||||
config["quick_think_llm"] = "gemini-2.5-flash"
|
config["quick_think_llm"] = "gemini-2.5-flash"
|
||||||
config["max_debate_rounds"] = 1 # 减少辩论轮次,降低单次请求量
|
config["max_debate_rounds"] = 1
|
||||||
config["max_risk_discuss_rounds"] = 1
|
config["max_risk_discuss_rounds"] = 1
|
||||||
config["data_vendors"] = {
|
config["data_vendors"] = {
|
||||||
"core_stock_apis": "yfinance",
|
"core_stock_apis": "yfinance",
|
||||||
|
|
@ -35,12 +36,21 @@ config["data_vendors"] = {
|
||||||
# ── 入参解析 ─────────────────────────────────────────────────────────────────
|
# ── 入参解析 ─────────────────────────────────────────────────────────────────
|
||||||
ticker = sys.argv[1].upper() if len(sys.argv) > 1 else "NVDA"
|
ticker = sys.argv[1].upper() if len(sys.argv) > 1 else "NVDA"
|
||||||
analysis_date = sys.argv[2] if len(sys.argv) > 2 else str(date.today())
|
analysis_date = sys.argv[2] if len(sys.argv) > 2 else str(date.today())
|
||||||
|
user_context = sys.argv[3] if len(sys.argv) > 3 else ""
|
||||||
|
|
||||||
|
# 如果没有传入上下文,交互式询问
|
||||||
|
if not user_context and sys.stdin.isatty():
|
||||||
|
print("\n💬 是否有自定义分析视角?(直接回车跳过)")
|
||||||
|
print(" 例:当前主导美股的核心因素是中东地缘冲突,请重点考虑此风险")
|
||||||
|
user_context = input(" > ").strip()
|
||||||
|
|
||||||
print(f"\n{'='*60}")
|
print(f"\n{'='*60}")
|
||||||
print(f"🤖 TradingAgents 多 Agent 分析")
|
print(f"🤖 TradingAgents 多 Agent 分析")
|
||||||
print(f" 标的:{ticker}")
|
print(f" 标的:{ticker}")
|
||||||
print(f" 日期:{analysis_date}")
|
print(f" 日期:{analysis_date}")
|
||||||
print(f" 模型:Gemini 2.5 Flash")
|
print(f" 模型:Gemini 2.5 Flash")
|
||||||
|
if user_context:
|
||||||
|
print(f" 用户视角:{user_context[:60]}{'...' if len(user_context) > 60 else ''}")
|
||||||
print(f"{'='*60}\n")
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
# ── 执行分析(带 retry)────────────────────────────────────────────────────────
|
# ── 执行分析(带 retry)────────────────────────────────────────────────────────
|
||||||
|
|
@ -53,8 +63,8 @@ for attempt in range(1, MAX_RETRIES + 1):
|
||||||
debug=False,
|
debug=False,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
final_state, decision = ta.propagate(ticker, analysis_date)
|
final_state, decision = ta.propagate(ticker, analysis_date, user_context)
|
||||||
break # 成功则退出重试
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"⚠️ 第 {attempt} 次失败: {type(e).__name__}: {str(e)[:120]}")
|
print(f"⚠️ 第 {attempt} 次失败: {type(e).__name__}: {str(e)[:120]}")
|
||||||
if attempt < MAX_RETRIES:
|
if attempt < MAX_RETRIES:
|
||||||
|
|
@ -76,8 +86,10 @@ os.makedirs(output_dir, exist_ok=True)
|
||||||
output_file = os.path.join(output_dir, f"{ticker}_{analysis_date}.txt")
|
output_file = os.path.join(output_dir, f"{ticker}_{analysis_date}.txt")
|
||||||
|
|
||||||
with open(output_file, "w", encoding="utf-8") as f:
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
f.write(f"标的:{ticker}\n日期:{analysis_date}\n最终决策:{decision}\n\n")
|
f.write(f"标的:{ticker}\n日期:{analysis_date}\n最终决策:{decision}\n")
|
||||||
f.write("="*60 + "\n")
|
if user_context:
|
||||||
|
f.write(f"用户视角:{user_context}\n")
|
||||||
|
f.write("\n" + "="*60 + "\n")
|
||||||
|
|
||||||
f.write("【交易员决策报告】\n")
|
f.write("【交易员决策报告】\n")
|
||||||
f.write(str(final_state.get("trader_investment_plan", "N/A")))
|
f.write(str(final_state.get("trader_investment_plan", "N/A")))
|
||||||
|
|
|
||||||
|
|
@ -16,11 +16,16 @@ class Propagator:
|
||||||
self.max_recur_limit = max_recur_limit
|
self.max_recur_limit = max_recur_limit
|
||||||
|
|
||||||
def create_initial_state(
|
def create_initial_state(
|
||||||
self, company_name: str, trade_date: str
|
self, company_name: str, trade_date: str, user_context: str = ""
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Create the initial state for the agent graph."""
|
"""Create the initial state for the agent graph."""
|
||||||
|
# 支持注入用户自定义分析视角
|
||||||
|
if user_context:
|
||||||
|
human_msg = f"{company_name}\n\n[用户补充视角]\n{user_context}"
|
||||||
|
else:
|
||||||
|
human_msg = company_name
|
||||||
return {
|
return {
|
||||||
"messages": [("human", company_name)],
|
"messages": [("human", human_msg)],
|
||||||
"company_of_interest": company_name,
|
"company_of_interest": company_name,
|
||||||
"trade_date": str(trade_date),
|
"trade_date": str(trade_date),
|
||||||
"investment_debate_state": InvestDebateState(
|
"investment_debate_state": InvestDebateState(
|
||||||
|
|
|
||||||
|
|
@ -186,14 +186,20 @@ class TradingAgentsGraph:
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
def propagate(self, company_name, trade_date):
|
def propagate(self, company_name, trade_date, user_context: str = ""):
|
||||||
"""Run the trading agents graph for a company on a specific date."""
|
"""Run the trading agents graph for a company on a specific date.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
company_name: Stock ticker symbol
|
||||||
|
trade_date: Analysis date (YYYY-MM-DD)
|
||||||
|
user_context: Optional user-provided analysis perspective or focus areas
|
||||||
|
"""
|
||||||
|
|
||||||
self.ticker = company_name
|
self.ticker = company_name
|
||||||
|
|
||||||
# Initialize state
|
# Initialize state
|
||||||
init_agent_state = self.propagator.create_initial_state(
|
init_agent_state = self.propagator.create_initial_state(
|
||||||
company_name, trade_date
|
company_name, trade_date, user_context
|
||||||
)
|
)
|
||||||
args = self.propagator.get_graph_args()
|
args = self.propagator.get_graph_args()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue