diff --git a/cli/main.py b/cli/main.py index 64616ee1..7fd21d02 100644 --- a/cli/main.py +++ b/cli/main.py @@ -479,7 +479,8 @@ def get_user_selections(): ) selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider) - + selected_embedding_model = select_embedding_agent(selected_llm_provider) + return { "ticker": selected_ticker, "analysis_date": analysis_date, @@ -489,6 +490,7 @@ def get_user_selections(): "backend_url": backend_url, "shallow_thinker": selected_shallow_thinker, "deep_thinker": selected_deep_thinker, + "embedding_model": selected_embedding_model, } @@ -741,6 +743,7 @@ def run_analysis(): config["max_risk_discuss_rounds"] = selections["research_depth"] config["quick_think_llm"] = selections["shallow_thinker"] config["deep_think_llm"] = selections["deep_thinker"] + config["embedding_model"] = selections["embedding_model"] config["backend_url"] = selections["backend_url"] config["llm_provider"] = selections["llm_provider"].lower() diff --git a/cli/utils.py b/cli/utils.py index 7b9682a6..538b6a42 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -1,5 +1,5 @@ import questionary -from typing import List, Optional, Tuple, Dict +from typing import List, Optional, Tuple, Dict, Sequence from cli.models import AnalystType @@ -10,6 +10,55 @@ ANALYST_ORDER = [ ("Fundamentals Analyst", AnalystType.FUNDAMENTALS), ] +def _ask_custom_model(label: str) -> str: + """Prompt the user to type an arbitrary model name.""" + model_name = questionary.text( + f"Enter the exact Ollama model name for {label}:", + validate=lambda x: len(x.strip()) > 0 or "Model name cannot be empty.", + style=questionary.Style([("text", "fg:green")]), + ).ask() + if not model_name: + console.print(f"\n[red]No model name provided. Exiting...[/red]") + exit(1) + return model_name + +def _select_llm( + provider: str, + label: str, + options: Sequence[Tuple[str, str]], +) -> str: + """ + Generic interactive selector that optionally offers a 'custom' entry + for Ollama users. + """ + opts = list(options) + if provider.lower() == "ollama": + opts.append(("Custom model (type manually)", "__CUSTOM__")) + + choice = questionary.select( + f"Select Your [{label}] LLM Engine:", + choices=[questionary.Choice(d, v) for d, v in opts], + style=questionary.Style( + [ + ("selected", "fg:magenta noinherit"), + ("highlighted", "fg:magenta noinherit"), + ("pointer", "fg:magenta noinherit"), + ] + ), + ).ask() + + if choice is None: + console.print(f"\n[red]No {label.lower()} engine selected. Exiting...[/red]") + exit(1) + + if choice == "__CUSTOM__": + # ask for arbitrary name + model_name = _ask_custom_model(label) + if model_name is None: + console.print("\n[red]No model name provided. Exiting...[/red]") + exit(1) + return model_name.strip() + return choice def get_ticker() -> str: """Prompt the user to enter a ticker symbol.""" @@ -154,30 +203,7 @@ def select_shallow_thinking_agent(provider) -> str: ("llama3.2 local", "llama3.2"), ] } - - choice = questionary.select( - "Select Your [Quick-Thinking LLM Engine]:", - choices=[ - questionary.Choice(display, value=value) - for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()] - ], - instruction="\n- Use arrow keys to navigate\n- Press Enter to select", - style=questionary.Style( - [ - ("selected", "fg:magenta noinherit"), - ("highlighted", "fg:magenta noinherit"), - ("pointer", "fg:magenta noinherit"), - ] - ), - ).ask() - - if choice is None: - console.print( - "\n[red]No shallow thinking llm engine selected. Exiting...[/red]" - ) - exit(1) - - return choice + return _select_llm(provider, "Quick-Thinking LLM Engine", SHALLOW_AGENT_OPTIONS[provider.lower()]) def select_deep_thinking_agent(provider) -> str: @@ -217,27 +243,22 @@ def select_deep_thinking_agent(provider) -> str: ] } - choice = questionary.select( - "Select Your [Deep-Thinking LLM Engine]:", - choices=[ - questionary.Choice(display, value=value) - for display, value in DEEP_AGENT_OPTIONS[provider.lower()] + return _select_llm(provider, "Deep-Thinking LLM Engine", DEEP_AGENT_OPTIONS[provider.lower()]) + +def select_embedding_agent(provider) -> str: + """Select embedding llm engine using an interactive selection.""" + + # Define deep thinking llm engine options with their corresponding model names + EMBEDDING_AGENT_OPTIONS = { + "openai": [ + ("GPT", "text-embedding-3-small"), ], - instruction="\n- Use arrow keys to navigate\n- Press Enter to select", - style=questionary.Style( - [ - ("selected", "fg:magenta noinherit"), - ("highlighted", "fg:magenta noinherit"), - ("pointer", "fg:magenta noinherit"), - ] - ), - ).ask() - - if choice is None: - console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]") - exit(1) - - return choice + "ollama": [ + + ] + } + + return _select_llm(provider, "Embedding LLM Engine", EMBEDDING_AGENT_OPTIONS[provider.lower()]) def select_llm_provider() -> tuple[str, str]: """Select the OpenAI api url using interactive selection.""" @@ -247,7 +268,7 @@ def select_llm_provider() -> tuple[str, str]: ("Anthropic", "https://api.anthropic.com/"), ("Google", "https://generativelanguage.googleapis.com/v1"), ("Openrouter", "https://openrouter.ai/api/v1"), - ("Ollama", "http://localhost:11434/v1"), + ("Ollama", "http://localhost:11434"), ] choice = questionary.select(