This commit is contained in:
parent
0f434546aa
commit
52098b375d
26
cli/main.py
26
cli/main.py
|
|
@ -29,7 +29,7 @@ from tradingagents.graph.trading_graph import TradingAgentsXGraph
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
from cli.models import AnalystType
|
from cli.models import AnalystType
|
||||||
from cli.utils import *
|
from cli.utils import *
|
||||||
from cli.utils import select_market
|
from cli.utils import select_market, select_embedding_model
|
||||||
|
|
||||||
# 初始化 rich Console
|
# 初始化 rich Console
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
@ -519,6 +519,14 @@ def get_user_selections():
|
||||||
)
|
)
|
||||||
embedding_provider, embedding_url = select_embedding_provider()
|
embedding_provider, embedding_url = select_embedding_provider()
|
||||||
|
|
||||||
|
# 步驟 8.5:選擇具體的嵌入模型
|
||||||
|
console.print(
|
||||||
|
create_question_box(
|
||||||
|
"步驟 8.5:嵌入模型", "選擇具體的嵌入模型"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
selected_embedding_model = select_embedding_model(embedding_provider)
|
||||||
|
|
||||||
# 步驟 9:API Keys
|
# 步驟 9:API Keys
|
||||||
console.print(
|
console.print(
|
||||||
create_question_box(
|
create_question_box(
|
||||||
|
|
@ -570,7 +578,15 @@ def get_user_selections():
|
||||||
|
|
||||||
default_quick_think_key = get_provider_api_key(quick_think_provider)
|
default_quick_think_key = get_provider_api_key(quick_think_provider)
|
||||||
default_deep_think_key = get_provider_api_key(deep_think_provider)
|
default_deep_think_key = get_provider_api_key(deep_think_provider)
|
||||||
default_embedding_key = get_provider_api_key(embedding_provider)
|
|
||||||
|
# 本地嵌入模型不需要 API Key
|
||||||
|
is_local_embedding = embedding_url == "local"
|
||||||
|
if is_local_embedding:
|
||||||
|
console.print("\n[green]✓ 本地嵌入模型無需 API Key[/green]")
|
||||||
|
default_embedding_key = None
|
||||||
|
embedding_api_key = None
|
||||||
|
else:
|
||||||
|
default_embedding_key = get_provider_api_key(embedding_provider)
|
||||||
|
|
||||||
# 快速思維模型 API Key
|
# 快速思維模型 API Key
|
||||||
quick_think_api_key = get_api_key("快速思維模型", default_quick_think_key)
|
quick_think_api_key = get_api_key("快速思維模型", default_quick_think_key)
|
||||||
|
|
@ -578,8 +594,9 @@ def get_user_selections():
|
||||||
# 深度思維模型 API Key
|
# 深度思維模型 API Key
|
||||||
deep_think_api_key = get_api_key("深度思維模型", default_deep_think_key)
|
deep_think_api_key = get_api_key("深度思維模型", default_deep_think_key)
|
||||||
|
|
||||||
# 嵌入模型 API Key
|
# 嵌入模型 API Key(僅在非本地模型時詢問)
|
||||||
embedding_api_key = get_api_key("嵌入模型", default_embedding_key)
|
if not is_local_embedding:
|
||||||
|
embedding_api_key = get_api_key("嵌入模型", default_embedding_key)
|
||||||
|
|
||||||
# Alpha Vantage API Key(必填)
|
# Alpha Vantage API Key(必填)
|
||||||
alpha_vantage_key = os.getenv("ALPHA_VANTAGE_API_KEY")
|
alpha_vantage_key = os.getenv("ALPHA_VANTAGE_API_KEY")
|
||||||
|
|
@ -601,6 +618,7 @@ def get_user_selections():
|
||||||
"market_type": selected_market,
|
"market_type": selected_market,
|
||||||
"embedding_provider": embedding_provider,
|
"embedding_provider": embedding_provider,
|
||||||
"embedding_url": embedding_url,
|
"embedding_url": embedding_url,
|
||||||
|
"embedding_model": selected_embedding_model,
|
||||||
"quick_think_api_key": quick_think_api_key,
|
"quick_think_api_key": quick_think_api_key,
|
||||||
"deep_think_api_key": deep_think_api_key,
|
"deep_think_api_key": deep_think_api_key,
|
||||||
"embedding_api_key": embedding_api_key,
|
"embedding_api_key": embedding_api_key,
|
||||||
|
|
|
||||||
71
cli/utils.py
71
cli/utils.py
|
|
@ -547,11 +547,13 @@ def select_embedding_provider() -> tuple[str, str]:
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
tuple[str, str]: 包含供應商名稱和 API 基礎 URL 的元組。
|
tuple[str, str]: 包含供應商名稱和 API 基礎 URL 的元組。
|
||||||
|
對於本地模型,URL 為 "local"。
|
||||||
"""
|
"""
|
||||||
# 定義嵌入模型供應商(只有 OpenAI 和自訂)
|
# 定義嵌入模型供應商(本地 HuggingFace、OpenAI 和自訂)
|
||||||
EMBEDDING_PROVIDERS = [
|
EMBEDDING_PROVIDERS = [
|
||||||
("OpenAI", "https://api.openai.com/v1"),
|
("🖥️ 本地模型 (HuggingFace) - 免費", "local"),
|
||||||
("自訂 URL", "custom")
|
("☁️ OpenAI - 收費", "https://api.openai.com/v1"),
|
||||||
|
("🔧 自訂 URL", "custom")
|
||||||
]
|
]
|
||||||
|
|
||||||
choice = questionary.select(
|
choice = questionary.select(
|
||||||
|
|
@ -603,12 +605,73 @@ def select_embedding_provider() -> tuple[str, str]:
|
||||||
display_name = "自訂供應商"
|
display_name = "自訂供應商"
|
||||||
|
|
||||||
# 印出使用者的選擇
|
# 印出使用者的選擇
|
||||||
print(f"您選擇了嵌入模型:{display_name}\tURL: {url}")
|
if url == "local":
|
||||||
|
print(f"您選擇了:{display_name}(本地執行,無需 API Key)")
|
||||||
|
else:
|
||||||
|
print(f"您選擇了嵌入模型:{display_name}\tURL: {url}")
|
||||||
|
|
||||||
# 返回供應商名稱和 URL
|
# 返回供應商名稱和 URL
|
||||||
return display_name, url
|
return display_name, url
|
||||||
|
|
||||||
|
|
||||||
|
def select_embedding_model(provider: str) -> str:
|
||||||
|
"""
|
||||||
|
根據供應商選擇具體的嵌入模型。
|
||||||
|
|
||||||
|
參數:
|
||||||
|
provider (str): 嵌入模型供應商名稱
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 選擇的嵌入模型名稱
|
||||||
|
"""
|
||||||
|
# 本地 HuggingFace 模型選項
|
||||||
|
LOCAL_EMBEDDING_MODELS = [
|
||||||
|
("all-MiniLM-L6-v2 (推薦) - 90MB, 輕量快速", "all-MiniLM-L6-v2"),
|
||||||
|
("all-mpnet-base-v2 - 420MB, 更高質量", "all-mpnet-base-v2"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# OpenAI 嵌入模型選項
|
||||||
|
OPENAI_EMBEDDING_MODELS = [
|
||||||
|
("text-embedding-3-small (推薦) - 高性價比", "text-embedding-3-small"),
|
||||||
|
("text-embedding-3-large - 最高質量", "text-embedding-3-large"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 根據供應商判斷使用哪個模型列表
|
||||||
|
if "本地" in provider or "local" in provider.lower():
|
||||||
|
model_options = LOCAL_EMBEDDING_MODELS
|
||||||
|
prompt_text = "選擇本地嵌入模型:"
|
||||||
|
description = "\n[dim]💡 本地模型首次使用會自動下載,之後無需網路連接[/dim]"
|
||||||
|
else:
|
||||||
|
model_options = OPENAI_EMBEDDING_MODELS
|
||||||
|
prompt_text = "選擇 OpenAI 嵌入模型:"
|
||||||
|
description = "\n[dim]💡 OpenAI 模型需要 API Key 和網路連接[/dim]"
|
||||||
|
|
||||||
|
console.print(description)
|
||||||
|
|
||||||
|
choice = questionary.select(
|
||||||
|
prompt_text,
|
||||||
|
choices=[
|
||||||
|
questionary.Choice(display, value=value)
|
||||||
|
for display, value in model_options
|
||||||
|
],
|
||||||
|
instruction="\n- 使用方向鍵導覽\n- 按下 Enter 鍵選擇",
|
||||||
|
style=questionary.Style(
|
||||||
|
[
|
||||||
|
("selected", "fg:green noinherit"),
|
||||||
|
("highlighted", "fg:green noinherit"),
|
||||||
|
("pointer", "fg:green noinherit"),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
if choice is None:
|
||||||
|
console.print("\n[red]未選擇嵌入模型。正在結束程式...[/red]")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
console.print(f"[green]✓ 已選擇:{choice}[/green]")
|
||||||
|
return choice
|
||||||
|
|
||||||
|
|
||||||
def get_api_key(model_type: str, default_key: Optional[str] = None) -> str:
|
def get_api_key(model_type: str, default_key: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
提示使用者輸入 API Key,如果留空則使用預設值。
|
提示使用者輸入 API Key,如果留空則使用預設值。
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue