diff --git a/cli/main.py b/cli/main.py index 413ec411..c8e6b18c 100644 --- a/cli/main.py +++ b/cli/main.py @@ -29,7 +29,7 @@ from tradingagents.graph.trading_graph import TradingAgentsXGraph from tradingagents.default_config import DEFAULT_CONFIG from cli.models import AnalystType from cli.utils import * -from cli.utils import select_market +from cli.utils import select_market, select_embedding_model # 初始化 rich Console console = Console() @@ -519,6 +519,14 @@ def get_user_selections(): ) embedding_provider, embedding_url = select_embedding_provider() + # 步驟 8.5:選擇具體的嵌入模型 + console.print( + create_question_box( + "步驟 8.5:嵌入模型", "選擇具體的嵌入模型" + ) + ) + selected_embedding_model = select_embedding_model(embedding_provider) + # 步驟 9:API Keys console.print( create_question_box( @@ -570,7 +578,15 @@ def get_user_selections(): default_quick_think_key = get_provider_api_key(quick_think_provider) default_deep_think_key = get_provider_api_key(deep_think_provider) - default_embedding_key = get_provider_api_key(embedding_provider) + + # 本地嵌入模型不需要 API Key + is_local_embedding = embedding_url == "local" + if is_local_embedding: + console.print("\n[green]✓ 本地嵌入模型無需 API Key[/green]") + default_embedding_key = None + embedding_api_key = None + else: + default_embedding_key = get_provider_api_key(embedding_provider) # 快速思維模型 API Key quick_think_api_key = get_api_key("快速思維模型", default_quick_think_key) @@ -578,8 +594,9 @@ def get_user_selections(): # 深度思維模型 API Key deep_think_api_key = get_api_key("深度思維模型", default_deep_think_key) - # 嵌入模型 API Key - embedding_api_key = get_api_key("嵌入模型", default_embedding_key) + # 嵌入模型 API Key(僅在非本地模型時詢問) + if not is_local_embedding: + embedding_api_key = get_api_key("嵌入模型", default_embedding_key) # Alpha Vantage API Key(必填) alpha_vantage_key = os.getenv("ALPHA_VANTAGE_API_KEY") @@ -601,6 +618,7 @@ def get_user_selections(): "market_type": selected_market, "embedding_provider": embedding_provider, "embedding_url": embedding_url, + "embedding_model": selected_embedding_model, "quick_think_api_key": quick_think_api_key, "deep_think_api_key": deep_think_api_key, "embedding_api_key": embedding_api_key, diff --git a/cli/utils.py b/cli/utils.py index 864c3e79..4994c9f6 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -547,11 +547,13 @@ def select_embedding_provider() -> tuple[str, str]: 返回: tuple[str, str]: 包含供應商名稱和 API 基礎 URL 的元組。 + 對於本地模型,URL 為 "local"。 """ - # 定義嵌入模型供應商(只有 OpenAI 和自訂) + # 定義嵌入模型供應商(本地 HuggingFace、OpenAI 和自訂) EMBEDDING_PROVIDERS = [ - ("OpenAI", "https://api.openai.com/v1"), - ("自訂 URL", "custom") + ("🖥️ 本地模型 (HuggingFace) - 免費", "local"), + ("☁️ OpenAI - 收費", "https://api.openai.com/v1"), + ("🔧 自訂 URL", "custom") ] choice = questionary.select( @@ -603,12 +605,73 @@ def select_embedding_provider() -> tuple[str, str]: display_name = "自訂供應商" # 印出使用者的選擇 - print(f"您選擇了嵌入模型:{display_name}\tURL: {url}") + if url == "local": + print(f"您選擇了:{display_name}(本地執行,無需 API Key)") + else: + print(f"您選擇了嵌入模型:{display_name}\tURL: {url}") # 返回供應商名稱和 URL return display_name, url +def select_embedding_model(provider: str) -> str: + """ + 根據供應商選擇具體的嵌入模型。 + + 參數: + provider (str): 嵌入模型供應商名稱 + + 返回: + str: 選擇的嵌入模型名稱 + """ + # 本地 HuggingFace 模型選項 + LOCAL_EMBEDDING_MODELS = [ + ("all-MiniLM-L6-v2 (推薦) - 90MB, 輕量快速", "all-MiniLM-L6-v2"), + ("all-mpnet-base-v2 - 420MB, 更高質量", "all-mpnet-base-v2"), + ] + + # OpenAI 嵌入模型選項 + OPENAI_EMBEDDING_MODELS = [ + ("text-embedding-3-small (推薦) - 高性價比", "text-embedding-3-small"), + ("text-embedding-3-large - 最高質量", "text-embedding-3-large"), + ] + + # 根據供應商判斷使用哪個模型列表 + if "本地" in provider or "local" in provider.lower(): + model_options = LOCAL_EMBEDDING_MODELS + prompt_text = "選擇本地嵌入模型:" + description = "\n[dim]💡 本地模型首次使用會自動下載,之後無需網路連接[/dim]" + else: + model_options = OPENAI_EMBEDDING_MODELS + prompt_text = "選擇 OpenAI 嵌入模型:" + description = "\n[dim]💡 OpenAI 模型需要 API Key 和網路連接[/dim]" + + console.print(description) + + choice = questionary.select( + prompt_text, + choices=[ + questionary.Choice(display, value=value) + for display, value in model_options + ], + instruction="\n- 使用方向鍵導覽\n- 按下 Enter 鍵選擇", + style=questionary.Style( + [ + ("selected", "fg:green noinherit"), + ("highlighted", "fg:green noinherit"), + ("pointer", "fg:green noinherit"), + ] + ), + ).ask() + + if choice is None: + console.print("\n[red]未選擇嵌入模型。正在結束程式...[/red]") + exit(1) + + console.print(f"[green]✓ 已選擇:{choice}[/green]") + return choice + + def get_api_key(model_type: str, default_key: Optional[str] = None) -> str: """ 提示使用者輸入 API Key,如果留空則使用預設值。