This commit is contained in:
parent
0f434546aa
commit
52098b375d
26
cli/main.py
26
cli/main.py
|
|
@ -29,7 +29,7 @@ from tradingagents.graph.trading_graph import TradingAgentsXGraph
|
|||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from cli.models import AnalystType
|
||||
from cli.utils import *
|
||||
from cli.utils import select_market
|
||||
from cli.utils import select_market, select_embedding_model
|
||||
|
||||
# 初始化 rich Console
|
||||
console = Console()
|
||||
|
|
@ -519,6 +519,14 @@ def get_user_selections():
|
|||
)
|
||||
embedding_provider, embedding_url = select_embedding_provider()
|
||||
|
||||
# 步驟 8.5:選擇具體的嵌入模型
|
||||
console.print(
|
||||
create_question_box(
|
||||
"步驟 8.5:嵌入模型", "選擇具體的嵌入模型"
|
||||
)
|
||||
)
|
||||
selected_embedding_model = select_embedding_model(embedding_provider)
|
||||
|
||||
# 步驟 9:API Keys
|
||||
console.print(
|
||||
create_question_box(
|
||||
|
|
@ -570,7 +578,15 @@ def get_user_selections():
|
|||
|
||||
default_quick_think_key = get_provider_api_key(quick_think_provider)
|
||||
default_deep_think_key = get_provider_api_key(deep_think_provider)
|
||||
default_embedding_key = get_provider_api_key(embedding_provider)
|
||||
|
||||
# 本地嵌入模型不需要 API Key
|
||||
is_local_embedding = embedding_url == "local"
|
||||
if is_local_embedding:
|
||||
console.print("\n[green]✓ 本地嵌入模型無需 API Key[/green]")
|
||||
default_embedding_key = None
|
||||
embedding_api_key = None
|
||||
else:
|
||||
default_embedding_key = get_provider_api_key(embedding_provider)
|
||||
|
||||
# 快速思維模型 API Key
|
||||
quick_think_api_key = get_api_key("快速思維模型", default_quick_think_key)
|
||||
|
|
@ -578,8 +594,9 @@ def get_user_selections():
|
|||
# 深度思維模型 API Key
|
||||
deep_think_api_key = get_api_key("深度思維模型", default_deep_think_key)
|
||||
|
||||
# 嵌入模型 API Key
|
||||
embedding_api_key = get_api_key("嵌入模型", default_embedding_key)
|
||||
# 嵌入模型 API Key(僅在非本地模型時詢問)
|
||||
if not is_local_embedding:
|
||||
embedding_api_key = get_api_key("嵌入模型", default_embedding_key)
|
||||
|
||||
# Alpha Vantage API Key(必填)
|
||||
alpha_vantage_key = os.getenv("ALPHA_VANTAGE_API_KEY")
|
||||
|
|
@ -601,6 +618,7 @@ def get_user_selections():
|
|||
"market_type": selected_market,
|
||||
"embedding_provider": embedding_provider,
|
||||
"embedding_url": embedding_url,
|
||||
"embedding_model": selected_embedding_model,
|
||||
"quick_think_api_key": quick_think_api_key,
|
||||
"deep_think_api_key": deep_think_api_key,
|
||||
"embedding_api_key": embedding_api_key,
|
||||
|
|
|
|||
71
cli/utils.py
71
cli/utils.py
|
|
@ -547,11 +547,13 @@ def select_embedding_provider() -> tuple[str, str]:
|
|||
|
||||
返回:
|
||||
tuple[str, str]: 包含供應商名稱和 API 基礎 URL 的元組。
|
||||
對於本地模型,URL 為 "local"。
|
||||
"""
|
||||
# 定義嵌入模型供應商(只有 OpenAI 和自訂)
|
||||
# 定義嵌入模型供應商(本地 HuggingFace、OpenAI 和自訂)
|
||||
EMBEDDING_PROVIDERS = [
|
||||
("OpenAI", "https://api.openai.com/v1"),
|
||||
("自訂 URL", "custom")
|
||||
("🖥️ 本地模型 (HuggingFace) - 免費", "local"),
|
||||
("☁️ OpenAI - 收費", "https://api.openai.com/v1"),
|
||||
("🔧 自訂 URL", "custom")
|
||||
]
|
||||
|
||||
choice = questionary.select(
|
||||
|
|
@ -603,12 +605,73 @@ def select_embedding_provider() -> tuple[str, str]:
|
|||
display_name = "自訂供應商"
|
||||
|
||||
# 印出使用者的選擇
|
||||
print(f"您選擇了嵌入模型:{display_name}\tURL: {url}")
|
||||
if url == "local":
|
||||
print(f"您選擇了:{display_name}(本地執行,無需 API Key)")
|
||||
else:
|
||||
print(f"您選擇了嵌入模型:{display_name}\tURL: {url}")
|
||||
|
||||
# 返回供應商名稱和 URL
|
||||
return display_name, url
|
||||
|
||||
|
||||
def select_embedding_model(provider: str) -> str:
|
||||
"""
|
||||
根據供應商選擇具體的嵌入模型。
|
||||
|
||||
參數:
|
||||
provider (str): 嵌入模型供應商名稱
|
||||
|
||||
返回:
|
||||
str: 選擇的嵌入模型名稱
|
||||
"""
|
||||
# 本地 HuggingFace 模型選項
|
||||
LOCAL_EMBEDDING_MODELS = [
|
||||
("all-MiniLM-L6-v2 (推薦) - 90MB, 輕量快速", "all-MiniLM-L6-v2"),
|
||||
("all-mpnet-base-v2 - 420MB, 更高質量", "all-mpnet-base-v2"),
|
||||
]
|
||||
|
||||
# OpenAI 嵌入模型選項
|
||||
OPENAI_EMBEDDING_MODELS = [
|
||||
("text-embedding-3-small (推薦) - 高性價比", "text-embedding-3-small"),
|
||||
("text-embedding-3-large - 最高質量", "text-embedding-3-large"),
|
||||
]
|
||||
|
||||
# 根據供應商判斷使用哪個模型列表
|
||||
if "本地" in provider or "local" in provider.lower():
|
||||
model_options = LOCAL_EMBEDDING_MODELS
|
||||
prompt_text = "選擇本地嵌入模型:"
|
||||
description = "\n[dim]💡 本地模型首次使用會自動下載,之後無需網路連接[/dim]"
|
||||
else:
|
||||
model_options = OPENAI_EMBEDDING_MODELS
|
||||
prompt_text = "選擇 OpenAI 嵌入模型:"
|
||||
description = "\n[dim]💡 OpenAI 模型需要 API Key 和網路連接[/dim]"
|
||||
|
||||
console.print(description)
|
||||
|
||||
choice = questionary.select(
|
||||
prompt_text,
|
||||
choices=[
|
||||
questionary.Choice(display, value=value)
|
||||
for display, value in model_options
|
||||
],
|
||||
instruction="\n- 使用方向鍵導覽\n- 按下 Enter 鍵選擇",
|
||||
style=questionary.Style(
|
||||
[
|
||||
("selected", "fg:green noinherit"),
|
||||
("highlighted", "fg:green noinherit"),
|
||||
("pointer", "fg:green noinherit"),
|
||||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
if choice is None:
|
||||
console.print("\n[red]未選擇嵌入模型。正在結束程式...[/red]")
|
||||
exit(1)
|
||||
|
||||
console.print(f"[green]✓ 已選擇:{choice}[/green]")
|
||||
return choice
|
||||
|
||||
|
||||
def get_api_key(model_type: str, default_key: Optional[str] = None) -> str:
|
||||
"""
|
||||
提示使用者輸入 API Key,如果留空則使用預設值。
|
||||
|
|
|
|||
Loading…
Reference in New Issue