diff --git a/.env.example b/.env.example index 1328b838..1708f962 100644 --- a/.env.example +++ b/.env.example @@ -4,3 +4,6 @@ GOOGLE_API_KEY= ANTHROPIC_API_KEY= XAI_API_KEY= OPENROUTER_API_KEY= +OLLAMA_BASE_URL= +OLLAMA_MODEL= +OLLAMA_API_KEY= diff --git a/README.md b/README.md index 8cf085e8..971e7c0d 100644 --- a/README.md +++ b/README.md @@ -130,7 +130,16 @@ export OPENROUTER_API_KEY=... # OpenRouter export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage ``` -For local models, configure Ollama with `llm_provider: "ollama"` in your config. +For Ollama (cloud or self-hosted), you can set: + +```bash +export OLLAMA_BASE_URL=https://ollama.com +export OLLAMA_MODEL=gpt-oss:120b +export OLLAMA_API_KEY=... +``` + +Then configure `llm_provider: "ollama"` in your config (OLLAMA_* env vars override CLI selection). +Make sure `OLLAMA_MODEL` exists in `curl $OLLAMA_BASE_URL/api/tags`. Alternatively, copy `.env.example` to `.env` and fill in your keys: ```bash diff --git a/cli/main.py b/cli/main.py index adda48fc..0ca80f48 100644 --- a/cli/main.py +++ b/cli/main.py @@ -543,6 +543,69 @@ def get_user_selections(): ) ) selected_llm_provider, backend_url = select_llm_provider() + + ollama_models = None + if selected_llm_provider.lower() == "ollama": + import os + import logging + from cli.utils import fetch_ollama_models, _normalize_ollama_base_url + + env_base_url = os.getenv("OLLAMA_BASE_URL") + resolved_base_url = _normalize_ollama_base_url(env_base_url or backend_url) + env_model = os.getenv("OLLAMA_MODEL") + env_api_key = os.getenv("OLLAMA_API_KEY") + + # Enable Ollama client logs (URL/model) in CLI + ollama_logger = logging.getLogger("tradingagents.llm_clients.ollama_client") + if not ollama_logger.handlers: + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter("[ollama] %(message)s")) + ollama_logger.addHandler(handler) + ollama_logger.setLevel(logging.INFO) + ollama_logger.propagate = False + + tags_url = f"{resolved_base_url}/api/tags" if 
resolved_base_url else "N/A" + ollama_models = fetch_ollama_models(resolved_base_url, env_api_key) + + info_lines = [ + "[bold]Ollama Environment Variables[/bold]", + "Set `OLLAMA_BASE_URL` (default: https://ollama.com), " + "`OLLAMA_MODEL` (must exist in /api/tags), and " + "`OLLAMA_API_KEY` (if required).", + "If set, these override the CLI selection for both quick and deep models.", + "", + f"Resolved base URL: {resolved_base_url}", + f"Env model: {env_model or '(unset)'}", + f"Tags endpoint: {tags_url}", + f"Models fetched: {len(ollama_models) if ollama_models else 0}", + ] + + console.print( + Panel( + "\n".join(info_lines), + border_style="yellow", + padding=(1, 2), + ) + ) + + if ollama_models: + table = Table(title="Ollama Models", box=box.SIMPLE) + table.add_column("Model", style="cyan") + for model_name in ollama_models: + table.add_row(model_name) + console.print(table) + if env_model and env_model not in ollama_models: + console.print( + Panel( + f"[yellow]Warning:[/yellow] `OLLAMA_MODEL={env_model}` " + "does not exist in /api/tags. " + "The server may return an error. 
def _normalize_ollama_base_url(url: str) -> str:
    """Return *url* reduced to the Ollama server root.

    Trailing slashes are stripped first, then a single trailing ``/v1``
    segment is dropped. Falsy input (``None`` or ``""``) is returned
    unchanged so callers can keep using their own "not configured" check.
    """
    if not url:
        return url
    url = url.rstrip("/")
    if url.endswith("/v1"):
        url = url[:-3]
    return url


def _ollama_auth_headers(api_key: Optional[str]) -> Optional[Dict[str, str]]:
    """Build the ``Authorization`` header mapping for an Ollama API key.

    A key that does not already start with an auth scheme (``Bearer``,
    ``Basic`` or ``Token``) is sent as ``Bearer <key>``. Returns ``None``
    when no usable key is supplied, so the request goes out without the
    header.
    """
    value = (api_key or "").strip()
    if not value:
        return None
    if not value.lower().startswith(("bearer ", "basic ", "token ")):
        value = f"Bearer {value}"
    return {"Authorization": value}


def fetch_ollama_models(
    base_url: str,
    api_key: Optional[str] = None,
    timeout: float = 5.0,
) -> List[str]:
    """Fetch available model names from the Ollama ``/api/tags`` endpoint.

    This is a best-effort lookup: any transport error, HTTP error status,
    or unparseable/unexpected payload yields an empty list rather than an
    exception.

    Args:
        base_url: Server URL; a trailing ``/`` or ``/v1`` is tolerated.
        api_key: Optional credential sent in the ``Authorization`` header.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        The ``name`` (falling back to ``model``) of each entry under the
        response's ``models`` key, in response order. Empty when
        ``base_url`` is falsy or the lookup fails.
    """
    if not base_url:
        return []
    url = f"{_normalize_ollama_base_url(base_url)}/api/tags"

    try:
        resp = requests.get(
            url,
            headers=_ollama_auth_headers(api_key),
            timeout=timeout,
        )
        resp.raise_for_status()
        data = resp.json()
    except (requests.exceptions.RequestException, ValueError):
        # ValueError covers JSON decode failures on requests versions whose
        # decode error is not a RequestException subclass.
        return []

    # Guard against payloads such as `[]` or `{"models": null}`.
    entries = data.get("models") if isinstance(data, dict) else None
    if not isinstance(entries, list):
        return []

    models = []
    for item in entries:
        if not isinstance(item, dict):
            continue
        name = item.get("name") or item.get("model")
        if name:
            models.append(name)
    return models
ollama_models: + choices = [questionary.Choice(m, value=m) for m in ollama_models] + else: + choices = [ questionary.Choice(display, value=value) for display, value in DEEP_AGENT_OPTIONS[provider.lower()] - ], + ] + + choice = questionary.select( + "Select Your [Deep-Thinking LLM Engine]:", + choices=choices, instruction="\n- Use arrow keys to navigate\n- Press Enter to select", style=questionary.Style( [ diff --git a/pyproject.toml b/pyproject.toml index 4c91a733..8d994245 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "langchain-anthropic>=0.3.15", "langchain-experimental>=0.3.4", "langchain-google-genai>=2.1.5", + "langchain-ollama", "langchain-openai>=0.3.23", "langgraph>=0.4.8", "pandas>=2.3.0", diff --git a/requirements.txt b/requirements.txt index 184468b8..b41b5b5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,4 @@ typer questionary langchain_anthropic langchain-google-genai +langchain-ollama diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index 93c2a7d3..74998107 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -2,6 +2,7 @@ from typing import Optional from .base_client import BaseLLMClient from .openai_client import OpenAIClient +from .ollama_client import OllamaClient from .anthropic_client import AnthropicClient from .google_client import GoogleClient @@ -34,9 +35,12 @@ def create_llm_client( """ provider_lower = provider.lower() - if provider_lower in ("openai", "ollama", "openrouter"): + if provider_lower in ("openai", "openrouter"): return OpenAIClient(model, base_url, provider=provider_lower, **kwargs) + if provider_lower == "ollama": + return OllamaClient(model, base_url, **kwargs) + if provider_lower == "xai": return OpenAIClient(model, base_url, provider="xai", **kwargs) diff --git a/tradingagents/llm_clients/ollama_client.py b/tradingagents/llm_clients/ollama_client.py new file mode 100644 index 
import os
from typing import Any, Optional

from langchain_ollama import ChatOllama

from .base_client import BaseLLMClient


class OllamaClient(BaseLLMClient):
    """Client for Ollama models using ChatOllama.

    Settings are resolved in this order, first match wins:

    * base URL: ``OLLAMA_BASE_URL`` env var, then the ``base_url`` given
      to the constructor (unless it is the OpenAI default), then
      ``https://ollama.com``.
    * model: ``OLLAMA_MODEL`` env var, then the ``model`` given to the
      constructor, then ``llama4``.

    Environment variables therefore take priority over explicit arguments.
    """

    # The OpenAI backend URL is ignored when it shows up as base_url, so an
    # Ollama run does not inherit it.
    _OPENAI_DEFAULT_BASE_URL = "https://api.openai.com/v1"
    _DEFAULT_BASE_URL = "https://ollama.com"
    _DEFAULT_MODEL = "llama4"
    # Prefixes that mark OLLAMA_API_KEY as already carrying an auth scheme.
    _AUTH_SCHEMES = ("bearer ", "basic ", "token ")

    def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
        super().__init__(model, base_url, **kwargs)

    def _normalize_base_url(self, url: str) -> str:
        """Strip trailing slashes and a trailing ``/v1`` from *url*.

        ChatOllama expects the server root, not a ``/v1`` suffix. Falsy
        input is returned unchanged.
        """
        if not url:
            return url
        url = url.rstrip("/")
        if url.endswith("/v1"):
            return url[:-3]
        return url

    def get_llm(self) -> Any:
        """Return a configured ChatOllama instance.

        Reads ``self.base_url``, ``self.model`` and ``self.kwargs``
        (presumably populated by ``BaseLLMClient.__init__`` — confirm
        against the base class) plus the ``OLLAMA_BASE_URL``,
        ``OLLAMA_MODEL`` and ``OLLAMA_API_KEY`` environment variables.
        """
        base_url = os.getenv("OLLAMA_BASE_URL")
        if not base_url:
            if self.base_url and self.base_url != self._OPENAI_DEFAULT_BASE_URL:
                base_url = self.base_url
            else:
                base_url = self._DEFAULT_BASE_URL
        base_url = self._normalize_base_url(base_url)

        model = os.getenv("OLLAMA_MODEL") or self.model or self._DEFAULT_MODEL

        # NOTE(review): confirm that ChatOllama actually honours `stream`
        # (and `timeout` / `max_retries` below) rather than ignoring them.
        llm_kwargs = {
            "model": model,
            "base_url": base_url,
            "temperature": self.kwargs.get("temperature", 0),
            "stream": False,
        }

        # Headers are passed via client_kwargs (httpx). Surrounding
        # whitespace is stripped so a padded key cannot produce a malformed
        # header value.
        api_key = (os.getenv("OLLAMA_API_KEY") or "").strip()
        if api_key:
            if not api_key.lower().startswith(self._AUTH_SCHEMES):
                api_key = f"Bearer {api_key}"
            llm_kwargs["client_kwargs"] = {
                "headers": {"Authorization": api_key}
            }

        # Forward only the optional settings the caller actually supplied.
        for key in ("timeout", "max_retries", "callbacks"):
            if key in self.kwargs:
                llm_kwargs[key] = self.kwargs[key]

        return ChatOllama(**llm_kwargs)

    def validate_model(self) -> bool:
        """Validate model for Ollama (accept any)."""
        return True