Pass model name to _get_provider_kwargs for per-model decisions
The previous fix only checked deep_think_llm when deciding whether to include reasoning_effort, but both LLMs shared the same kwargs. This meant that mixing an o-series deep model with a gpt-* quick model would still crash, because the gpt-* model received the unsupported reasoning_effort option. Now each LLM gets its own kwargs: the model name is passed to _get_provider_kwargs, and reasoning_effort is only added when the specific model being configured is an o-series model.
This commit is contained in:
parent
a1ee0bd824
commit
19dae36f58
|
|
@ -71,24 +71,31 @@ class TradingAgentsGraph:
|
||||||
exist_ok=True,
|
exist_ok=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize LLMs with provider-specific thinking configuration
|
# Initialize LLMs with provider-specific thinking configuration.
|
||||||
llm_kwargs = self._get_provider_kwargs()
|
# Each model gets its own kwargs since options like reasoning_effort
|
||||||
|
# are only valid for certain model families.
|
||||||
|
deep_model = self.config["deep_think_llm"]
|
||||||
|
quick_model = self.config["quick_think_llm"]
|
||||||
|
|
||||||
|
deep_kwargs = self._get_provider_kwargs(deep_model)
|
||||||
|
quick_kwargs = self._get_provider_kwargs(quick_model)
|
||||||
|
|
||||||
# Add callbacks to kwargs if provided (passed to LLM constructor)
|
# Add callbacks to kwargs if provided (passed to LLM constructor)
|
||||||
if self.callbacks:
|
if self.callbacks:
|
||||||
llm_kwargs["callbacks"] = self.callbacks
|
deep_kwargs["callbacks"] = self.callbacks
|
||||||
|
quick_kwargs["callbacks"] = self.callbacks
|
||||||
|
|
||||||
deep_client = create_llm_client(
|
deep_client = create_llm_client(
|
||||||
provider=self.config["llm_provider"],
|
provider=self.config["llm_provider"],
|
||||||
model=self.config["deep_think_llm"],
|
model=deep_model,
|
||||||
base_url=self.config.get("backend_url"),
|
base_url=self.config.get("backend_url"),
|
||||||
**llm_kwargs,
|
**deep_kwargs,
|
||||||
)
|
)
|
||||||
quick_client = create_llm_client(
|
quick_client = create_llm_client(
|
||||||
provider=self.config["llm_provider"],
|
provider=self.config["llm_provider"],
|
||||||
model=self.config["quick_think_llm"],
|
model=quick_model,
|
||||||
base_url=self.config.get("backend_url"),
|
base_url=self.config.get("backend_url"),
|
||||||
**llm_kwargs,
|
**quick_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.deep_thinking_llm = deep_client.get_llm()
|
self.deep_thinking_llm = deep_client.get_llm()
|
||||||
|
|
@ -133,8 +140,13 @@ class TradingAgentsGraph:
|
||||||
# Set up the graph
|
# Set up the graph
|
||||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
||||||
|
|
||||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
def _get_provider_kwargs(self, model_name: str) -> Dict[str, Any]:
|
||||||
"""Get provider-specific kwargs for LLM client creation."""
|
"""Get provider-specific kwargs for LLM client creation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: The model being configured, used to decide which
|
||||||
|
provider-specific options are applicable.
|
||||||
|
"""
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
provider = self.config.get("llm_provider", "").lower()
|
provider = self.config.get("llm_provider", "").lower()
|
||||||
|
|
||||||
|
|
@ -148,8 +160,7 @@ class TradingAgentsGraph:
|
||||||
if reasoning_effort:
|
if reasoning_effort:
|
||||||
# reasoning_effort is only supported by o-series models
|
# reasoning_effort is only supported by o-series models
|
||||||
# (o1, o3, o3-mini, o4-mini, etc.), not by gpt-* models
|
# (o1, o3, o3-mini, o4-mini, etc.), not by gpt-* models
|
||||||
model = self.config.get("deep_think_llm", "")
|
if model_name.startswith("o"):
|
||||||
if model.startswith("o"):
|
|
||||||
kwargs["reasoning_effort"] = reasoning_effort
|
kwargs["reasoning_effort"] = reasoning_effort
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue