This commit is contained in:
MarkLo 2025-12-21 23:15:41 +08:00
parent 0f434546aa
commit 52098b375d
2 changed files with 89 additions and 8 deletions

View File

@ -29,7 +29,7 @@ from tradingagents.graph.trading_graph import TradingAgentsXGraph
from tradingagents.default_config import DEFAULT_CONFIG
from cli.models import AnalystType
from cli.utils import *
from cli.utils import select_market
from cli.utils import select_market, select_embedding_model
# 初始化 rich Console
console = Console()
@ -519,6 +519,14 @@ def get_user_selections():
)
embedding_provider, embedding_url = select_embedding_provider()
# 步驟 8.5:選擇具體的嵌入模型
console.print(
create_question_box(
"步驟 8.5:嵌入模型", "選擇具體的嵌入模型"
)
)
selected_embedding_model = select_embedding_model(embedding_provider)
# 步驟 9API Keys
console.print(
create_question_box(
@ -570,7 +578,15 @@ def get_user_selections():
default_quick_think_key = get_provider_api_key(quick_think_provider)
default_deep_think_key = get_provider_api_key(deep_think_provider)
default_embedding_key = get_provider_api_key(embedding_provider)
# 本地嵌入模型不需要 API Key
is_local_embedding = embedding_url == "local"
if is_local_embedding:
console.print("\n[green]✓ 本地嵌入模型無需 API Key[/green]")
default_embedding_key = None
embedding_api_key = None
else:
default_embedding_key = get_provider_api_key(embedding_provider)
# 快速思維模型 API Key
quick_think_api_key = get_api_key("快速思維模型", default_quick_think_key)
@ -578,8 +594,9 @@ def get_user_selections():
# 深度思維模型 API Key
deep_think_api_key = get_api_key("深度思維模型", default_deep_think_key)
# 嵌入模型 API Key
embedding_api_key = get_api_key("嵌入模型", default_embedding_key)
# 嵌入模型 API Key僅在非本地模型時詢問
if not is_local_embedding:
embedding_api_key = get_api_key("嵌入模型", default_embedding_key)
# Alpha Vantage API Key必填
alpha_vantage_key = os.getenv("ALPHA_VANTAGE_API_KEY")
@ -601,6 +618,7 @@ def get_user_selections():
"market_type": selected_market,
"embedding_provider": embedding_provider,
"embedding_url": embedding_url,
"embedding_model": selected_embedding_model,
"quick_think_api_key": quick_think_api_key,
"deep_think_api_key": deep_think_api_key,
"embedding_api_key": embedding_api_key,

View File

@ -547,11 +547,13 @@ def select_embedding_provider() -> tuple[str, str]:
返回:
tuple[str, str]: 包含供應商名稱和 API 基礎 URL 的元組
對於本地模型URL "local"
"""
# 定義嵌入模型供應商(只有 OpenAI 和自訂)
# 定義嵌入模型供應商(本地 HuggingFace、OpenAI 和自訂)
EMBEDDING_PROVIDERS = [
("OpenAI", "https://api.openai.com/v1"),
("自訂 URL", "custom")
("🖥️ 本地模型 (HuggingFace) - 免費", "local"),
("☁️ OpenAI - 收費", "https://api.openai.com/v1"),
("🔧 自訂 URL", "custom")
]
choice = questionary.select(
@ -603,12 +605,73 @@ def select_embedding_provider() -> tuple[str, str]:
display_name = "自訂供應商"
# 印出使用者的選擇
print(f"您選擇了嵌入模型:{display_name}\tURL: {url}")
if url == "local":
print(f"您選擇了:{display_name}(本地執行,無需 API Key")
else:
print(f"您選擇了嵌入模型:{display_name}\tURL: {url}")
# 返回供應商名稱和 URL
return display_name, url
def select_embedding_model(provider: str) -> str:
"""
根據供應商選擇具體的嵌入模型
參數:
provider (str): 嵌入模型供應商名稱
返回:
str: 選擇的嵌入模型名稱
"""
# 本地 HuggingFace 模型選項
LOCAL_EMBEDDING_MODELS = [
("all-MiniLM-L6-v2 (推薦) - 90MB, 輕量快速", "all-MiniLM-L6-v2"),
("all-mpnet-base-v2 - 420MB, 更高質量", "all-mpnet-base-v2"),
]
# OpenAI 嵌入模型選項
OPENAI_EMBEDDING_MODELS = [
("text-embedding-3-small (推薦) - 高性價比", "text-embedding-3-small"),
("text-embedding-3-large - 最高質量", "text-embedding-3-large"),
]
# 根據供應商判斷使用哪個模型列表
if "本地" in provider or "local" in provider.lower():
model_options = LOCAL_EMBEDDING_MODELS
prompt_text = "選擇本地嵌入模型:"
description = "\n[dim]💡 本地模型首次使用會自動下載,之後無需網路連接[/dim]"
else:
model_options = OPENAI_EMBEDDING_MODELS
prompt_text = "選擇 OpenAI 嵌入模型:"
description = "\n[dim]💡 OpenAI 模型需要 API Key 和網路連接[/dim]"
console.print(description)
choice = questionary.select(
prompt_text,
choices=[
questionary.Choice(display, value=value)
for display, value in model_options
],
instruction="\n- 使用方向鍵導覽\n- 按下 Enter 鍵選擇",
style=questionary.Style(
[
("selected", "fg:green noinherit"),
("highlighted", "fg:green noinherit"),
("pointer", "fg:green noinherit"),
]
),
).ask()
if choice is None:
console.print("\n[red]未選擇嵌入模型。正在結束程式...[/red]")
exit(1)
console.print(f"[green]✓ 已選擇:{choice}[/green]")
return choice
def get_api_key(model_type: str, default_key: Optional[str] = None) -> str:
"""
提示使用者輸入 API Key如果留空則使用預設值