fix: pass routed base urls to provider clients

This commit is contained in:
Garrick 2026-03-24 13:33:48 -07:00
parent cd145f14d0
commit 73e8974182
3 changed files with 48 additions and 0 deletions

View File

@ -3,6 +3,8 @@ import json
import tradingagents.dataflows.config as dataflow_config
import tradingagents.default_config as default_config
from tradingagents.llm_clients.anthropic_client import AnthropicClient
from tradingagents.llm_clients.google_client import GoogleClient
from tradingagents.graph.trading_graph import TradingAgentsGraph
@ -256,6 +258,48 @@ def test_provider_normalization_avoids_duplicate_legacy_client_creation(monkeypa
assert created_clients.count(("openai", "gpt-5-mini")) == 1
def test_anthropic_client_passes_base_url_to_langchain(monkeypatch):
    """A routed ``base_url`` must reach ChatAnthropic as ``anthropic_api_url``."""
    seen = {}

    class _RecordingChatAnthropic:
        # Stand-in for the langchain class: just record constructor kwargs.
        def __init__(self, **kwargs):
            seen.update(kwargs)

    monkeypatch.setattr(
        "tradingagents.llm_clients.anthropic_client.NormalizedChatAnthropic",
        _RecordingChatAnthropic,
    )

    AnthropicClient(
        model="claude-sonnet-4-6",
        base_url="https://anthropic.example/v1",
    ).get_llm()

    assert seen["anthropic_api_url"] == "https://anthropic.example/v1"
def test_google_client_passes_base_url_to_langchain(monkeypatch):
    """A routed ``base_url`` must be forwarded verbatim to ChatGoogleGenerativeAI."""
    seen = {}

    class _RecordingChatGoogleGenerativeAI:
        # Stand-in for the langchain class: just record constructor kwargs.
        def __init__(self, **kwargs):
            seen.update(kwargs)

    monkeypatch.setattr(
        "tradingagents.llm_clients.google_client.NormalizedChatGoogleGenerativeAI",
        _RecordingChatGoogleGenerativeAI,
    )

    GoogleClient(
        model="gemini-2.5-pro",
        base_url="https://google.example/v1beta",
    ).get_llm()

    assert seen["base_url"] == "https://google.example/v1beta"
def test_dataflow_config_returns_isolated_nested_routing(monkeypatch):
monkeypatch.setattr(dataflow_config, "_config", None)
dataflow_config.initialize_config()

View File

@ -32,6 +32,8 @@ class AnthropicClient(BaseLLMClient):
def get_llm(self) -> Any:
"""Return configured ChatAnthropic instance."""
llm_kwargs = {"model": self.model}
if self.base_url:
llm_kwargs["anthropic_api_url"] = self.base_url
for key in _PASSTHROUGH_KWARGS:
if key in self.kwargs:

View File

@ -26,6 +26,8 @@ class GoogleClient(BaseLLMClient):
def get_llm(self) -> Any:
"""Return configured ChatGoogleGenerativeAI instance."""
llm_kwargs = {"model": self.model}
if self.base_url:
llm_kwargs["base_url"] = self.base_url
for key in ("timeout", "max_retries", "google_api_key", "callbacks", "http_client", "http_async_client"):
if key in self.kwargs: