Pass model name to _get_provider_kwargs for per-model decisions

The previous fix only checked deep_think_llm when deciding whether
to include reasoning_effort, but both LLMs shared the same kwargs.
This meant mixing an o-series deep model with a gpt-* quick model
would still crash.

Now each LLM gets its own kwargs by passing the model name to
_get_provider_kwargs. reasoning_effort is only added when the
specific model being configured is an o-series model.
This commit is contained in:
Charlie Tonneslan 2026-03-22 12:47:32 -04:00
parent a1ee0bd824
commit 19dae36f58
1 changed file with 22 additions and 11 deletions

View File

@@ -71,24 +71,31 @@ class TradingAgentsGraph:
exist_ok=True,
)
# Initialize LLMs with provider-specific thinking configuration
llm_kwargs = self._get_provider_kwargs()
# Initialize LLMs with provider-specific thinking configuration.
# Each model gets its own kwargs since options like reasoning_effort
# are only valid for certain model families.
deep_model = self.config["deep_think_llm"]
quick_model = self.config["quick_think_llm"]
deep_kwargs = self._get_provider_kwargs(deep_model)
quick_kwargs = self._get_provider_kwargs(quick_model)
# Add callbacks to kwargs if provided (passed to LLM constructor)
if self.callbacks:
llm_kwargs["callbacks"] = self.callbacks
deep_kwargs["callbacks"] = self.callbacks
quick_kwargs["callbacks"] = self.callbacks
deep_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["deep_think_llm"],
model=deep_model,
base_url=self.config.get("backend_url"),
**llm_kwargs,
**deep_kwargs,
)
quick_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["quick_think_llm"],
model=quick_model,
base_url=self.config.get("backend_url"),
**llm_kwargs,
**quick_kwargs,
)
self.deep_thinking_llm = deep_client.get_llm()
@@ -133,8 +140,13 @@ class TradingAgentsGraph:
# Set up the graph
self.graph = self.graph_setup.setup_graph(selected_analysts)
def _get_provider_kwargs(self) -> Dict[str, Any]:
"""Get provider-specific kwargs for LLM client creation."""
def _get_provider_kwargs(self, model_name: str) -> Dict[str, Any]:
"""Get provider-specific kwargs for LLM client creation.
Args:
model_name: The model being configured, used to decide which
provider-specific options are applicable.
"""
kwargs = {}
provider = self.config.get("llm_provider", "").lower()
@@ -148,8 +160,7 @@ class TradingAgentsGraph:
if reasoning_effort:
# reasoning_effort is only supported by o-series models
# (o1, o3, o3-mini, o4-mini, etc.), not by gpt-* models
model = self.config.get("deep_think_llm", "")
if model.startswith("o"):
if model_name.startswith("o"):
kwargs["reasoning_effort"] = reasoning_effort
return kwargs