diff --git a/cli/utils.py b/cli/utils.py index 7b9682a6..6bbe8bc6 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -125,6 +125,25 @@ def select_research_depth() -> int: def select_shallow_thinking_agent(provider) -> str: """Select shallow thinking llm engine using an interactive selection.""" + # if provider is custom, let user input model name + if provider.lower() == "custom": + model = questionary.text( + "Enter your custom Quick-Thinking model name:", + validate=lambda x: len( + x.strip()) > 0 or "Please enter a valid model name.", + style=questionary.Style( + [ + ("text", "fg:green"), + ("highlighted", "noinherit"), + ] + ), + ).ask() + + if not model: + console.print("\n[red]No model name provided. Exiting...[/red]") + exit(1) + + return model.strip() # Define shallow thinking llm engine options with their corresponding model names SHALLOW_AGENT_OPTIONS = { "openai": [ @@ -183,6 +202,26 @@ def select_shallow_thinking_agent(provider) -> str: def select_deep_thinking_agent(provider) -> str: """Select deep thinking llm engine using an interactive selection.""" + # if provider is custom, let user input model name + if provider.lower() == "custom": + model = questionary.text( + "Enter your custom Deep-Thinking model name:", + validate=lambda x: len( + x.strip()) > 0 or "Please enter a valid model name.", + style=questionary.Style( + [ + ("text", "fg:green"), + ("highlighted", "noinherit"), + ] + ), + ).ask() + + if not model: + console.print("\n[red]No model name provided. Exiting...[/red]") + exit(1) + + return model.strip() + # Define deep thinking llm engine options with their corresponding model names DEEP_AGENT_OPTIONS = { "openai": [ @@ -247,7 +286,8 @@ def select_llm_provider() -> tuple[str, str]: ("Anthropic", "https://api.anthropic.com/"), ("Google", "https://generativelanguage.googleapis.com/v1"), ("Openrouter", "https://openrouter.ai/api/v1"), - ("Ollama", "http://localhost:11434/v1"), + ("Ollama", "http://localhost:11434/v1"), + ("Custom (Enter your custom backend URL)", "custom"), ] choice = questionary.select( @@ -269,7 +309,31 @@ def select_llm_provider() -> tuple[str, str]: if choice is None: console.print("\n[red]no OpenAI backend selected. Exiting...[/red]") exit(1) - + + # if custom, let user input custom url and api key + if choice[1] == "custom": + custom_url = questionary.text( + "Enter your custom API base URL:", + validate=lambda x: len(x.strip()) > 0 or "Please enter a valid backend URL.").ask() + + if not custom_url: + console.print("\n[red]No URL provided. Exiting...[/red]") + exit(1) + + api_key = questionary.text( + "Enter your API key:", + validate=lambda x: len(x.strip()) > 0 or "Please enter a valid API key.").ask() + + if not api_key: + console.print("\n[red]No API key provided. Exiting...[/red]") + exit(1) + + # set api key to environment variable + import os + os.environ["OPENAI_API_KEY"] = api_key.strip() + + return "Custom", custom_url.strip() + display_name, url = choice print(f"You selected: {display_name}\tURL: {url}") diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 80a29e53..84a6c0cc 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -58,7 +58,7 @@ class TradingAgentsGraph: ) # Initialize LLMs - if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter": + if self.config["llm_provider"].lower() in ["openai", "ollama", "openrouter", "custom"]: self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"]) self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"]) elif self.config["llm_provider"].lower() == "anthropic":