Pass model name to _get_provider_kwargs for per-model decisions
The previous fix checked only deep_think_llm when deciding whether to include reasoning_effort, but both LLMs shared the same kwargs. This meant that mixing an o-series deep model with a gpt-* quick model would still crash. Now each LLM gets its own kwargs: the model name is passed to _get_provider_kwargs, and reasoning_effort is added only when the specific model being configured is an o-series model.
This commit is contained in:
parent
a1ee0bd824
commit
19dae36f58
|
|
@ -71,24 +71,31 @@ class TradingAgentsGraph:
|
|||
exist_ok=True,
|
||||
)
|
||||
|
||||
# Initialize LLMs with provider-specific thinking configuration
|
||||
llm_kwargs = self._get_provider_kwargs()
|
||||
# Initialize LLMs with provider-specific thinking configuration.
|
||||
# Each model gets its own kwargs since options like reasoning_effort
|
||||
# are only valid for certain model families.
|
||||
deep_model = self.config["deep_think_llm"]
|
||||
quick_model = self.config["quick_think_llm"]
|
||||
|
||||
deep_kwargs = self._get_provider_kwargs(deep_model)
|
||||
quick_kwargs = self._get_provider_kwargs(quick_model)
|
||||
|
||||
# Add callbacks to kwargs if provided (passed to LLM constructor)
|
||||
if self.callbacks:
|
||||
llm_kwargs["callbacks"] = self.callbacks
|
||||
deep_kwargs["callbacks"] = self.callbacks
|
||||
quick_kwargs["callbacks"] = self.callbacks
|
||||
|
||||
deep_client = create_llm_client(
|
||||
provider=self.config["llm_provider"],
|
||||
model=self.config["deep_think_llm"],
|
||||
model=deep_model,
|
||||
base_url=self.config.get("backend_url"),
|
||||
**llm_kwargs,
|
||||
**deep_kwargs,
|
||||
)
|
||||
quick_client = create_llm_client(
|
||||
provider=self.config["llm_provider"],
|
||||
model=self.config["quick_think_llm"],
|
||||
model=quick_model,
|
||||
base_url=self.config.get("backend_url"),
|
||||
**llm_kwargs,
|
||||
**quick_kwargs,
|
||||
)
|
||||
|
||||
self.deep_thinking_llm = deep_client.get_llm()
|
||||
|
|
@ -133,8 +140,13 @@ class TradingAgentsGraph:
|
|||
# Set up the graph
|
||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
||||
|
||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
||||
"""Get provider-specific kwargs for LLM client creation."""
|
||||
def _get_provider_kwargs(self, model_name: str) -> Dict[str, Any]:
|
||||
"""Get provider-specific kwargs for LLM client creation.
|
||||
|
||||
Args:
|
||||
model_name: The model being configured, used to decide which
|
||||
provider-specific options are applicable.
|
||||
"""
|
||||
kwargs = {}
|
||||
provider = self.config.get("llm_provider", "").lower()
|
||||
|
||||
|
|
@ -148,8 +160,7 @@ class TradingAgentsGraph:
|
|||
if reasoning_effort:
|
||||
# reasoning_effort is only supported by o-series models
|
||||
# (o1, o3, o3-mini, o4-mini, etc.), not by gpt-* models
|
||||
model = self.config.get("deep_think_llm", "")
|
||||
if model.startswith("o"):
|
||||
if model_name.startswith("o"):
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
|
||||
return kwargs
|
||||
|
|
|
|||
Loading…
Reference in New Issue