Improve Ollama defaults, data resilience, and docs

This commit is contained in:
Oğuzcan Toptaş 2026-03-29 03:22:32 +03:00
parent 589b351f2a
commit d861eccd2d
21 changed files with 274 additions and 53 deletions

View File

@ -117,6 +117,12 @@ Install the package and its dependencies:
pip install . pip install .
``` ```
If you are actively modifying the project, prefer an editable install so the
`tradingagents` CLI always reflects your local code changes:
```bash
pip install -e .
```
### Required APIs ### Required APIs
TradingAgents supports multiple LLM providers. Set the API key for your chosen provider: TradingAgents supports multiple LLM providers. Set the API key for your chosen provider:
@ -132,6 +138,17 @@ export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
For local models, configure Ollama with `llm_provider: "ollama"` in your config. For local models, configure Ollama with `llm_provider: "ollama"` in your config.
Example local setup for a default Ollama install running `qwen3:8b`:
```bash
export TRADINGAGENTS_LLM_PROVIDER=ollama
export TRADINGAGENTS_BACKEND_URL=http://localhost:11434/v1
export TRADINGAGENTS_QUICK_THINK_LLM=qwen3:8b
export TRADINGAGENTS_DEEP_THINK_LLM=qwen3:8b
```
The default config in this fork is also set up for Ollama on
`http://localhost:11434/v1`, with both thinking models pointing to `qwen3:8b`.
Alternatively, copy `.env.example` to `.env` and fill in your keys: Alternatively, copy `.env.example` to `.env` and fill in your keys:
```bash ```bash
cp .env.example .env cp .env.example .env
@ -198,6 +215,16 @@ _, decision = ta.propagate("NVDA", "2026-01-15")
print(decision) print(decision)
``` ```
Local Ollama example:
```python
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "ollama"
config["backend_url"] = "http://localhost:11434/v1"
config["deep_think_llm"] = "qwen3:8b"
config["quick_think_llm"] = "qwen3:8b"
```
See `tradingagents/default_config.py` for all configuration options. See `tradingagents/default_config.py` for all configuration options.
## Contributing ## Contributing

View File

@ -939,6 +939,15 @@ def run_analysis():
selected_set = {analyst.value for analyst in selections["analysts"]} selected_set = {analyst.value for analyst in selections["analysts"]}
selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set] selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set]
# Create result directory early so graph logging uses the same run-specific base path.
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
results_dir.mkdir(parents=True, exist_ok=True)
report_dir = results_dir / "reports"
report_dir.mkdir(parents=True, exist_ok=True)
log_file = results_dir / "message_tool.log"
log_file.touch(exist_ok=True)
config["results_dir"] = str(results_dir)
# Initialize the graph with callbacks bound to LLMs # Initialize the graph with callbacks bound to LLMs
graph = TradingAgentsGraph( graph = TradingAgentsGraph(
selected_analyst_keys, selected_analyst_keys,
@ -953,14 +962,6 @@ def run_analysis():
# Track start time for elapsed display # Track start time for elapsed display
start_time = time.time() start_time = time.time()
# Create result directory
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
results_dir.mkdir(parents=True, exist_ok=True)
report_dir = results_dir / "reports"
report_dir.mkdir(parents=True, exist_ok=True)
log_file = results_dir / "message_tool.log"
log_file.touch(exist_ok=True)
def save_message_decorator(obj, func_name): def save_message_decorator(obj, func_name):
func = getattr(obj, func_name) func = getattr(obj, func_name)
@wraps(func) @wraps(func)

View File

@ -167,7 +167,7 @@ def select_shallow_thinking_agent(provider) -> str:
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
], ],
"ollama": [ "ollama": [
("Qwen3:latest (8B, local)", "qwen3:latest"), ("Qwen3:8B (local)", "qwen3:8b"),
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
], ],
@ -236,7 +236,7 @@ def select_deep_thinking_agent(provider) -> str:
"ollama": [ "ollama": [
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("Qwen3:latest (8B, local)", "qwen3:latest"), ("Qwen3:8B (local)", "qwen3:8b"),
], ],
} }

View File

@ -8,8 +8,8 @@ load_dotenv()
# Create a custom config # Create a custom config
config = DEFAULT_CONFIG.copy() config = DEFAULT_CONFIG.copy()
config["deep_think_llm"] = "gpt-5-mini" # Use a different model config["deep_think_llm"] = "qwen3:8b"
config["quick_think_llm"] = "gpt-5-mini" # Use a different model config["quick_think_llm"] = "qwen3:8b"
config["max_debate_rounds"] = 1 # Increase debate rounds config["max_debate_rounds"] = 1 # Increase debate rounds
# Configure data vendors (default uses yfinance, no extra API keys needed) # Configure data vendors (default uses yfinance, no extra API keys needed)

15
test.py
View File

@ -1,11 +1,12 @@
import time import time
from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
print("Testing optimized implementation with 30-day lookback:") if __name__ == "__main__":
start_time = time.time() print("Testing optimized implementation with 30-day lookback:")
result = get_stock_stats_indicators_window("AAPL", "macd", "2024-11-01", 30) start_time = time.time()
end_time = time.time() result = get_stock_stats_indicators_window("AAPL", "macd", "2024-11-01", 30)
end_time = time.time()
print(f"Execution time: {end_time - start_time:.2f} seconds") print(f"Execution time: {end_time - start_time:.2f} seconds")
print(f"Result length: {len(result)} characters") print(f"Result length: {len(result)} characters")
print(result) print(result)

1
tests/__init__.py Normal file
View File

@ -0,0 +1 @@

View File

@ -0,0 +1,26 @@
import unittest
from tradingagents.dataflows.utils import normalize_date_range, normalize_iso_date
class DateNormalizationTests(unittest.TestCase):
    """Unit tests for the ISO-date normalization helpers in dataflows.utils."""

    def test_normalize_iso_date_keeps_valid_dates(self):
        # A genuine leap day must pass through untouched.
        self.assertEqual("2024-02-29", normalize_iso_date("2024-02-29"))

    def test_normalize_iso_date_clamps_invalid_month_end(self):
        # Non-existent month-end days are clamped to the month's last valid day.
        for given, expected in (
            ("2026-02-29", "2026-02-28"),
            ("2026-04-31", "2026-04-30"),
        ):
            with self.subTest(given=given):
                self.assertEqual(expected, normalize_iso_date(given))

    def test_normalize_date_range_orders_dates_after_normalization(self):
        # Endpoints are normalized first, then returned in chronological order.
        normalized = normalize_date_range("2026-03-29", "2026-02-29")
        self.assertEqual(("2026-02-28", "2026-03-29"), normalized)

    def test_normalize_iso_date_rejects_bad_format(self):
        # Only dash-separated YYYY-MM-DD strings are accepted.
        with self.assertRaises(ValueError):
            normalize_iso_date("2026/02/29")
if __name__ == "__main__":
    # Allow running this test module directly without a test runner.
    unittest.main()

View File

@ -0,0 +1,58 @@
import json
import tempfile
import unittest
from pathlib import Path
from tradingagents.graph.trading_graph import TradingAgentsGraph
class GraphLoggingTests(unittest.TestCase):
    """Verify that TradingAgentsGraph._log_state writes under config["results_dir"]."""

    @staticmethod
    def _build_final_state():
        """Return a minimal final-state payload covering every field _log_state reads."""
        return {
            "company_of_interest": "SPY",
            "trade_date": "2026-03-29",
            "market_report": "market",
            "sentiment_report": "sentiment",
            "news_report": "news",
            "fundamentals_report": "fundamentals",
            "investment_debate_state": {
                "bull_history": "bull",
                "bear_history": "bear",
                "history": "history",
                "current_response": "current",
                "judge_decision": "judge",
            },
            "trader_investment_plan": "trade plan",
            "risk_debate_state": {
                "aggressive_history": "aggressive",
                "conservative_history": "conservative",
                "neutral_history": "neutral",
                "history": "risk history",
                "judge_decision": "portfolio judge",
            },
            "investment_plan": "investment plan",
            "final_trade_decision": "HOLD",
        }

    def test_log_state_uses_configured_results_dir(self):
        with tempfile.TemporaryDirectory() as results_root:
            # Bypass __init__ so the test does not touch LLM/provider setup;
            # only the attributes _log_state reads are assigned.
            graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
            graph.config = {"results_dir": results_root}
            graph.ticker = "SPY"
            graph.log_states_dict = {}

            graph._log_state("2026-03-29", self._build_final_state())

            expected_log = (
                Path(results_root)
                / "TradingAgentsStrategy_logs"
                / "full_states_log_2026-03-29.json"
            )
            self.assertTrue(expected_log.exists())
            saved = json.loads(expected_log.read_text())
            self.assertIn("2026-03-29", saved)
            self.assertEqual("HOLD", saved["2026-03-29"]["final_trade_decision"])
if __name__ == "__main__":
    # Allow running this test module directly without a test runner.
    unittest.main()

View File

@ -0,0 +1,36 @@
import unittest
import warnings
from tradingagents.llm_clients.factory import create_llm_client
class LLMClientFactoryTests(unittest.TestCase):
    """Tests for the unknown-model warning path in create_llm_client."""

    def _create_and_capture(self, provider, model):
        """Create a client while recording every warning emitted during creation."""
        with warnings.catch_warnings(record=True) as recorded:
            warnings.simplefilter("always")
            client = create_llm_client(provider, model)
        return client, recorded

    def test_invalid_known_provider_model_emits_warning(self):
        client, recorded = self._create_and_capture(
            "openai", "not-a-real-openai-model"
        )
        self.assertEqual("openai", client.provider)
        # The warning text should name the suspicious model.
        messages = [str(entry.message) for entry in recorded]
        self.assertTrue(any("not-a-real-openai-model" in text for text in messages))

    def test_known_valid_model_does_not_emit_warning(self):
        client, recorded = self._create_and_capture("openai", "gpt-5-mini")
        self.assertEqual("openai", client.provider)
        self.assertEqual(0, len(recorded))

    def test_ollama_custom_models_do_not_emit_warning(self):
        # Ollama model names are user-defined, so no warning should fire.
        client, recorded = self._create_and_capture("ollama", "qwen3:8b")
        self.assertEqual("ollama", client.provider)
        self.assertEqual(0, len(recorded))
if __name__ == "__main__":
    # Allow running this test module directly without a test runner.
    unittest.main()

View File

@ -5,6 +5,8 @@ import json
from datetime import datetime from datetime import datetime
from io import StringIO from io import StringIO
from .utils import normalize_date_range, normalize_iso_date
API_BASE_URL = "https://www.alphavantage.co/query" API_BASE_URL = "https://www.alphavantage.co/query"
def get_api_key() -> str: def get_api_key() -> str:
@ -22,7 +24,7 @@ def format_datetime_for_api(date_input) -> str:
return date_input return date_input
# Try to parse common date formats # Try to parse common date formats
try: try:
dt = datetime.strptime(date_input, "%Y-%m-%d") dt = datetime.strptime(normalize_iso_date(date_input), "%Y-%m-%d")
return dt.strftime("%Y%m%dT0000") return dt.strftime("%Y%m%dT0000")
except ValueError: except ValueError:
try: try:
@ -100,6 +102,7 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) ->
return csv_data return csv_data
try: try:
start_date, end_date = normalize_date_range(start_date, end_date)
# Parse CSV data # Parse CSV data
df = pd.read_csv(StringIO(csv_data)) df = pd.read_csv(StringIO(csv_data))

View File

@ -1,4 +1,5 @@
from .alpha_vantage_common import _make_api_request, format_datetime_for_api from .alpha_vantage_common import _make_api_request, format_datetime_for_api
from .utils import normalize_date_range, normalize_iso_date
def get_news(ticker, start_date, end_date) -> dict[str, str] | str: def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
"""Returns live and historical market news & sentiment data from premier news outlets worldwide. """Returns live and historical market news & sentiment data from premier news outlets worldwide.
@ -14,6 +15,8 @@ def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
Dictionary containing news sentiment data or JSON string. Dictionary containing news sentiment data or JSON string.
""" """
start_date, end_date = normalize_date_range(start_date, end_date)
params = { params = {
"tickers": ticker, "tickers": ticker,
"time_from": format_datetime_for_api(start_date), "time_from": format_datetime_for_api(start_date),
@ -38,6 +41,7 @@ def get_global_news(curr_date, look_back_days: int = 7, limit: int = 50) -> dict
from datetime import datetime, timedelta from datetime import datetime, timedelta
# Calculate start date # Calculate start date
curr_date = normalize_iso_date(curr_date)
curr_dt = datetime.strptime(curr_date, "%Y-%m-%d") curr_dt = datetime.strptime(curr_date, "%Y-%m-%d")
start_dt = curr_dt - timedelta(days=look_back_days) start_dt = curr_dt - timedelta(days=look_back_days)
start_date = start_dt.strftime("%Y-%m-%d") start_date = start_dt.strftime("%Y-%m-%d")

View File

@ -1,5 +1,6 @@
from datetime import datetime from datetime import datetime
from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range
from .utils import normalize_date_range
def get_stock( def get_stock(
symbol: str, symbol: str,
@ -18,6 +19,8 @@ def get_stock(
Returns: Returns:
CSV string containing the daily adjusted time series data filtered to the date range. CSV string containing the daily adjusted time series data filtered to the date range.
""" """
start_date, end_date = normalize_date_range(start_date, end_date)
# Parse dates to determine the range # Parse dates to determine the range
start_dt = datetime.strptime(start_date, "%Y-%m-%d") start_dt = datetime.strptime(start_date, "%Y-%m-%d")
today = datetime.now() today = datetime.now()

View File

@ -1,6 +1,7 @@
import os import os
import json import json
import pandas as pd import pandas as pd
import calendar
from datetime import date, timedelta, datetime from datetime import date, timedelta, datetime
from typing import Annotated from typing import Annotated
@ -16,6 +17,50 @@ def get_current_date():
return date.today().strftime("%Y-%m-%d") return date.today().strftime("%Y-%m-%d")
def normalize_iso_date(date_str: str) -> str:
    """Normalize YYYY-MM-DD dates, clamping invalid month-end days.

    LLM tool calls occasionally produce dates like 2026-02-29 when they mean
    "the end of February". Valid ISO dates are returned as-is (zero-padded).
    An out-of-range day within a valid year/month is clamped to that month's
    last day. Any other malformed value raises ValueError.
    """
    try:
        parsed = datetime.strptime(date_str, "%Y-%m-%d")
    except ValueError as strict_error:
        # Fall back to manual component parsing so we can detect the
        # clamp-to-month-end case; anything else re-raises an error.
        try:
            year_text, month_text, day_text = date_str.split("-")
            year, month, day = int(year_text), int(month_text), int(day_text)
        except (AttributeError, ValueError) as component_error:
            raise ValueError(f"Unsupported date format: {date_str}") from component_error
        if not (1 <= month <= 12) or day < 1:
            # Month/day are nonsense, not merely past the month's end.
            raise strict_error
        last_day = calendar.monthrange(year, month)[1]
        if day > last_day:
            return f"{year:04d}-{month:02d}-{last_day:02d}"
        # Components look plausible yet strptime still failed (e.g. bad year).
        raise strict_error
    return parsed.strftime("%Y-%m-%d")
def normalize_date_range(start_date: str, end_date: str) -> tuple[str, str]:
    """Normalize and order an ISO date range.

    Both endpoints are passed through normalize_iso_date, then returned in
    chronological order regardless of the order they were supplied in.
    """
    first = normalize_iso_date(start_date)
    second = normalize_iso_date(end_date)
    # Compare as datetimes so ordering never depends on string formatting.
    if datetime.strptime(first, "%Y-%m-%d") > datetime.strptime(second, "%Y-%m-%d"):
        first, second = second, first
    return first, second
def decorate_all_methods(decorator): def decorate_all_methods(decorator):
def class_decorator(cls): def class_decorator(cls):
for attr_name, attr_value in cls.__dict__.items(): for attr_name, attr_value in cls.__dict__.items():

View File

@ -4,15 +4,14 @@ from dateutil.relativedelta import relativedelta
import yfinance as yf import yfinance as yf
import os import os
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry
from .utils import normalize_date_range
def get_YFin_data_online( def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"], symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"],
): ):
start_date, end_date = normalize_date_range(start_date, end_date)
datetime.strptime(start_date, "%Y-%m-%d")
datetime.strptime(end_date, "%Y-%m-%d")
# Create ticker object # Create ticker object
ticker = yf.Ticker(symbol.upper()) ticker = yf.Ticker(symbol.upper())

View File

@ -4,6 +4,8 @@ import yfinance as yf
from datetime import datetime from datetime import datetime
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
from .utils import normalize_date_range, normalize_iso_date
def _extract_article_data(article: dict) -> dict: def _extract_article_data(article: dict) -> dict:
"""Extract article data from yfinance news format (handles nested 'content' structure).""" """Extract article data from yfinance news format (handles nested 'content' structure)."""
@ -63,6 +65,7 @@ def get_news_yfinance(
Formatted string containing news articles Formatted string containing news articles
""" """
try: try:
start_date, end_date = normalize_date_range(start_date, end_date)
stock = yf.Ticker(ticker) stock = yf.Ticker(ticker)
news = stock.get_news(count=20) news = stock.get_news(count=20)
@ -130,6 +133,7 @@ def get_global_news_yfinance(
seen_titles = set() seen_titles = set()
try: try:
curr_date = normalize_iso_date(curr_date)
for query in search_queries: for query in search_queries:
search = yf.Search( search = yf.Search(
query=query, query=query,

View File

@ -8,10 +8,10 @@ DEFAULT_CONFIG = {
"dataflows/data_cache", "dataflows/data_cache",
), ),
# LLM settings # LLM settings
"llm_provider": "openai", "llm_provider": os.getenv("TRADINGAGENTS_LLM_PROVIDER", "ollama"),
"deep_think_llm": "gpt-5.2", "deep_think_llm": os.getenv("TRADINGAGENTS_DEEP_THINK_LLM", "qwen3:8b"),
"quick_think_llm": "gpt-5-mini", "quick_think_llm": os.getenv("TRADINGAGENTS_QUICK_THINK_LLM", "qwen3:8b"),
"backend_url": "https://api.openai.com/v1", "backend_url": os.getenv("TRADINGAGENTS_BACKEND_URL", "http://localhost:11434/v1"),
# Provider-specific thinking configuration # Provider-specific thinking configuration
"google_thinking_level": None, # "high", "minimal", etc. "google_thinking_level": None, # "high", "minimal", etc.
"openai_reasoning_effort": None, # "medium", "high", "low" "openai_reasoning_effort": None, # "medium", "high", "low"

View File

@ -38,7 +38,7 @@ class GraphSetup:
self.conditional_logic = conditional_logic self.conditional_logic = conditional_logic
def setup_graph( def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"] self, selected_analysts=None
): ):
"""Set up and compile the agent workflow graph. """Set up and compile the agent workflow graph.
@ -49,6 +49,9 @@ class GraphSetup:
- "news": News analyst - "news": News analyst
- "fundamentals": Fundamentals analyst - "fundamentals": Fundamentals analyst
""" """
if selected_analysts is None:
selected_analysts = ["market", "social", "news", "fundamentals"]
if len(selected_analysts) == 0: if len(selected_analysts) == 0:
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!") raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")

View File

@ -45,7 +45,7 @@ class TradingAgentsGraph:
def __init__( def __init__(
self, self,
selected_analysts=["market", "social", "news", "fundamentals"], selected_analysts=None,
debug=False, debug=False,
config: Dict[str, Any] = None, config: Dict[str, Any] = None,
callbacks: Optional[List] = None, callbacks: Optional[List] = None,
@ -59,6 +59,8 @@ class TradingAgentsGraph:
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats) callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
""" """
self.debug = debug self.debug = debug
if selected_analysts is None:
selected_analysts = ["market", "social", "news", "fundamentals"]
self.config = config or DEFAULT_CONFIG self.config = config or DEFAULT_CONFIG
self.callbacks = callbacks or [] self.callbacks = callbacks or []
@ -66,10 +68,7 @@ class TradingAgentsGraph:
set_config(self.config) set_config(self.config)
# Create necessary directories # Create necessary directories
os.makedirs( os.makedirs(self.config["data_cache_dir"], exist_ok=True)
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
exist_ok=True,
)
# Initialize LLMs with provider-specific thinking configuration # Initialize LLMs with provider-specific thinking configuration
llm_kwargs = self._get_provider_kwargs() llm_kwargs = self._get_provider_kwargs()
@ -259,11 +258,11 @@ class TradingAgentsGraph:
} }
# Save to file # Save to file
directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/") directory = Path(self.config["results_dir"]) / "TradingAgentsStrategy_logs"
directory.mkdir(parents=True, exist_ok=True) directory.mkdir(parents=True, exist_ok=True)
with open( with open(
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json", directory / f"full_states_log_{trade_date}.json",
"w", "w",
encoding="utf-8", encoding="utf-8",
) as f: ) as f:

View File

@ -1,9 +1,12 @@
# LLM Clients - Consistency Improvements # LLM Clients - Consistency Improvements
## Issues to Fix ## Completed
### 1. `validate_model()` is never called ### 1. `validate_model()` warning path
- Add validation call in `get_llm()` with warning (not error) for unknown models - `create_llm_client()` now calls `validate_model()` and emits a warning for
unknown model names instead of failing immediately.
## Remaining Issues to Fix
### 2. Inconsistent parameter handling ### 2. Inconsistent parameter handling
| Client | API Key Param | Special Params | | Client | API Key Param | Special Params |

View File

@ -1,3 +1,4 @@
import warnings
from typing import Optional from typing import Optional
from .base_client import BaseLLMClient from .base_client import BaseLLMClient
@ -35,15 +36,22 @@ def create_llm_client(
provider_lower = provider.lower() provider_lower = provider.lower()
if provider_lower in ("openai", "ollama", "openrouter"): if provider_lower in ("openai", "ollama", "openrouter"):
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs) client = OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
elif provider_lower == "xai":
if provider_lower == "xai": client = OpenAIClient(model, base_url, provider="xai", **kwargs)
return OpenAIClient(model, base_url, provider="xai", **kwargs) elif provider_lower == "anthropic":
client = AnthropicClient(model, base_url, **kwargs)
if provider_lower == "anthropic": elif provider_lower == "google":
return AnthropicClient(model, base_url, **kwargs) client = GoogleClient(model, base_url, **kwargs)
else:
if provider_lower == "google":
return GoogleClient(model, base_url, **kwargs)
raise ValueError(f"Unsupported LLM provider: {provider}") raise ValueError(f"Unsupported LLM provider: {provider}")
if not client.validate_model():
warnings.warn(
f"Model '{model}' is not in the known model list for provider "
f"'{provider_lower}'. The request may still work if the provider "
"supports it, but mis-typed model names will fail at runtime.",
stacklevel=2,
)
return client

View File

@ -57,8 +57,8 @@ class OpenAIClient(BaseLLMClient):
# Provider-specific base URL and auth # Provider-specific base URL and auth
if self.provider in _PROVIDER_CONFIG: if self.provider in _PROVIDER_CONFIG:
base_url, api_key_env = _PROVIDER_CONFIG[self.provider] default_base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
llm_kwargs["base_url"] = base_url llm_kwargs["base_url"] = self.base_url or default_base_url
if api_key_env: if api_key_env:
api_key = os.environ.get(api_key_env) api_key = os.environ.get(api_key_env)
if api_key: if api_key: