From d861eccd2db8f65ea9a273ff2ab1757e3932a7ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?O=C4=9Fuzcan=20Topta=C5=9F?= Date: Sun, 29 Mar 2026 03:22:32 +0300 Subject: [PATCH] Improve Ollama defaults, data resilience, and docs --- README.md | 27 +++++++++ cli/main.py | 17 +++--- cli/utils.py | 4 +- main.py | 4 +- test.py | 15 ++--- tests/__init__.py | 1 + tests/test_date_normalization.py | 26 +++++++++ tests/test_graph_logging.py | 58 +++++++++++++++++++ tests/test_llm_client_factory.py | 36 ++++++++++++ .../dataflows/alpha_vantage_common.py | 5 +- tradingagents/dataflows/alpha_vantage_news.py | 6 +- .../dataflows/alpha_vantage_stock.py | 5 +- tradingagents/dataflows/utils.py | 45 ++++++++++++++ tradingagents/dataflows/y_finance.py | 7 +-- tradingagents/dataflows/yfinance_news.py | 4 ++ tradingagents/default_config.py | 8 +-- tradingagents/graph/setup.py | 5 +- tradingagents/graph/trading_graph.py | 13 ++--- tradingagents/llm_clients/TODO.md | 9 ++- tradingagents/llm_clients/factory.py | 28 +++++---- tradingagents/llm_clients/openai_client.py | 4 +- 21 files changed, 274 insertions(+), 53 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/test_date_normalization.py create mode 100644 tests/test_graph_logging.py create mode 100644 tests/test_llm_client_factory.py diff --git a/README.md b/README.md index 4c4856d1..35e96e10 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,12 @@ Install the package and its dependencies: pip install . ``` +If you are actively modifying the project, prefer an editable install so the +`tradingagents` CLI always reflects your local code changes: +```bash +pip install -e . +``` + ### Required APIs TradingAgents supports multiple LLM providers. Set the API key for your chosen provider: @@ -132,6 +138,17 @@ export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage For local models, configure Ollama with `llm_provider: "ollama"` in your config. +Example local setup for a default Ollama install running `qwen3:8b`: +```bash +export TRADINGAGENTS_LLM_PROVIDER=ollama +export TRADINGAGENTS_BACKEND_URL=http://localhost:11434/v1 +export TRADINGAGENTS_QUICK_THINK_LLM=qwen3:8b +export TRADINGAGENTS_DEEP_THINK_LLM=qwen3:8b +``` + +The default config in this fork is also set up for Ollama on +`http://localhost:11434/v1`, with both thinking models pointing to `qwen3:8b`. + Alternatively, copy `.env.example` to `.env` and fill in your keys: ```bash cp .env.example .env @@ -198,6 +215,16 @@ _, decision = ta.propagate("NVDA", "2026-01-15") print(decision) ``` +Local Ollama example: + +```python +config = DEFAULT_CONFIG.copy() +config["llm_provider"] = "ollama" +config["backend_url"] = "http://localhost:11434/v1" +config["deep_think_llm"] = "qwen3:8b" +config["quick_think_llm"] = "qwen3:8b" +``` + See `tradingagents/default_config.py` for all configuration options. ## Contributing diff --git a/cli/main.py b/cli/main.py index 53837db2..321fa013 100644 --- a/cli/main.py +++ b/cli/main.py @@ -939,6 +939,15 @@ def run_analysis(): selected_set = {analyst.value for analyst in selections["analysts"]} selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set] + # Create result directory early so graph logging uses the same run-specific base path. + results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"] + results_dir.mkdir(parents=True, exist_ok=True) + report_dir = results_dir / "reports" + report_dir.mkdir(parents=True, exist_ok=True) + log_file = results_dir / "message_tool.log" + log_file.touch(exist_ok=True) + config["results_dir"] = str(results_dir) + # Initialize the graph with callbacks bound to LLMs graph = TradingAgentsGraph( selected_analyst_keys, @@ -953,14 +962,6 @@ def run_analysis(): # Track start time for elapsed display start_time = time.time() - # Create result directory - results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"] - results_dir.mkdir(parents=True, exist_ok=True) - report_dir = results_dir / "reports" - report_dir.mkdir(parents=True, exist_ok=True) - log_file = results_dir / "message_tool.log" - log_file.touch(exist_ok=True) - def save_message_decorator(obj, func_name): func = getattr(obj, func_name) @wraps(func) diff --git a/cli/utils.py b/cli/utils.py index 18abc3a7..8dc7cff2 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -167,7 +167,7 @@ def select_shallow_thinking_agent(provider) -> str: ("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"), ], "ollama": [ - ("Qwen3:latest (8B, local)", "qwen3:latest"), + ("Qwen3:8B (local)", "qwen3:8b"), ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), ], @@ -236,7 +236,7 @@ def select_deep_thinking_agent(provider) -> str: "ollama": [ ("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"), ("GPT-OSS:latest (20B, local)", "gpt-oss:latest"), - ("Qwen3:latest (8B, local)", "qwen3:latest"), + ("Qwen3:8B (local)", "qwen3:8b"), ], } diff --git a/main.py b/main.py index 26cab658..19258284 100644 --- a/main.py +++ b/main.py @@ -8,8 +8,8 @@ load_dotenv() # Create a custom config config = DEFAULT_CONFIG.copy() -config["deep_think_llm"] = "gpt-5-mini" # Use a different model -config["quick_think_llm"] = "gpt-5-mini" # Use a different model +config["deep_think_llm"] = "qwen3:8b" +config["quick_think_llm"] = "qwen3:8b" config["max_debate_rounds"] = 1 # Increase debate rounds # Configure data vendors (default uses yfinance, no extra API keys needed) diff --git a/test.py b/test.py index b73783e1..d4b4557a 100644 --- a/test.py +++ b/test.py @@ -1,11 +1,12 @@ import time from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions -print("Testing optimized implementation with 30-day lookback:") -start_time = time.time() -result = get_stock_stats_indicators_window("AAPL", "macd", "2024-11-01", 30) -end_time = time.time() +if __name__ == "__main__": + print("Testing optimized implementation with 30-day lookback:") + start_time = time.time() + result = get_stock_stats_indicators_window("AAPL", "macd", "2024-11-01", 30) + end_time = time.time() -print(f"Execution time: {end_time - start_time:.2f} seconds") -print(f"Result length: {len(result)} characters") -print(result) + print(f"Execution time: {end_time - start_time:.2f} seconds") + print(f"Result length: {len(result)} characters") + print(result) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/test_date_normalization.py b/tests/test_date_normalization.py new file mode 100644 index 00000000..75edfd29 --- /dev/null +++ b/tests/test_date_normalization.py @@ -0,0 +1,26 @@ +import unittest + +from tradingagents.dataflows.utils import normalize_date_range, normalize_iso_date + + +class DateNormalizationTests(unittest.TestCase): + def test_normalize_iso_date_keeps_valid_dates(self): + self.assertEqual(normalize_iso_date("2024-02-29"), "2024-02-29") + + def test_normalize_iso_date_clamps_invalid_month_end(self): + self.assertEqual(normalize_iso_date("2026-02-29"), "2026-02-28") + self.assertEqual(normalize_iso_date("2026-04-31"), "2026-04-30") + + def test_normalize_date_range_orders_dates_after_normalization(self): + self.assertEqual( + normalize_date_range("2026-03-29", "2026-02-29"), + ("2026-02-28", "2026-03-29"), + ) + + def test_normalize_iso_date_rejects_bad_format(self): + with self.assertRaises(ValueError): + normalize_iso_date("2026/02/29") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_graph_logging.py b/tests/test_graph_logging.py new file mode 100644 index 00000000..fd58abbb --- /dev/null +++ b/tests/test_graph_logging.py @@ -0,0 +1,58 @@ +import json +import tempfile +import unittest +from pathlib import Path + +from tradingagents.graph.trading_graph import TradingAgentsGraph + + +class GraphLoggingTests(unittest.TestCase): + def test_log_state_uses_configured_results_dir(self): + with tempfile.TemporaryDirectory() as tmpdir: + graph = TradingAgentsGraph.__new__(TradingAgentsGraph) + graph.config = {"results_dir": tmpdir} + graph.ticker = "SPY" + graph.log_states_dict = {} + + final_state = { + "company_of_interest": "SPY", + "trade_date": "2026-03-29", + "market_report": "market", + "sentiment_report": "sentiment", + "news_report": "news", + "fundamentals_report": "fundamentals", + "investment_debate_state": { + "bull_history": "bull", + "bear_history": "bear", + "history": "history", + "current_response": "current", + "judge_decision": "judge", + }, + "trader_investment_plan": "trade plan", + "risk_debate_state": { + "aggressive_history": "aggressive", + "conservative_history": "conservative", + "neutral_history": "neutral", + "history": "risk history", + "judge_decision": "portfolio judge", + }, + "investment_plan": "investment plan", + "final_trade_decision": "HOLD", + } + + graph._log_state("2026-03-29", final_state) + + log_path = ( + Path(tmpdir) + / "TradingAgentsStrategy_logs" + / "full_states_log_2026-03-29.json" + ) + self.assertTrue(log_path.exists()) + + payload = json.loads(log_path.read_text()) + self.assertIn("2026-03-29", payload) + self.assertEqual(payload["2026-03-29"]["final_trade_decision"], "HOLD") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_llm_client_factory.py b/tests/test_llm_client_factory.py new file mode 100644 index 00000000..101ffd77 --- /dev/null +++ b/tests/test_llm_client_factory.py @@ -0,0 +1,36 @@ +import unittest +import warnings + +from tradingagents.llm_clients.factory import create_llm_client + + +class LLMClientFactoryTests(unittest.TestCase): + def test_invalid_known_provider_model_emits_warning(self): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + client = create_llm_client("openai", "not-a-real-openai-model") + + self.assertEqual(client.provider, "openai") + self.assertTrue( + any("not-a-real-openai-model" in str(w.message) for w in caught) + ) + + def test_known_valid_model_does_not_emit_warning(self): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + client = create_llm_client("openai", "gpt-5-mini") + + self.assertEqual(client.provider, "openai") + self.assertEqual(len(caught), 0) + + def test_ollama_custom_models_do_not_emit_warning(self): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + client = create_llm_client("ollama", "qwen3:8b") + + self.assertEqual(client.provider, "ollama") + self.assertEqual(len(caught), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tradingagents/dataflows/alpha_vantage_common.py b/tradingagents/dataflows/alpha_vantage_common.py index 409ff29e..e082f382 100644 --- a/tradingagents/dataflows/alpha_vantage_common.py +++ b/tradingagents/dataflows/alpha_vantage_common.py @@ -5,6 +5,8 @@ import json from datetime import datetime from io import StringIO +from .utils import normalize_date_range, normalize_iso_date + API_BASE_URL = "https://www.alphavantage.co/query" def get_api_key() -> str: @@ -22,7 +24,7 @@ def format_datetime_for_api(date_input) -> str: return date_input # Try to parse common date formats try: - dt = datetime.strptime(date_input, "%Y-%m-%d") + dt = datetime.strptime(normalize_iso_date(date_input), "%Y-%m-%d") return dt.strftime("%Y%m%dT0000") except ValueError: try: @@ -100,6 +102,7 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) -> return csv_data try: + start_date, end_date = normalize_date_range(start_date, end_date) # Parse CSV data df = pd.read_csv(StringIO(csv_data)) diff --git a/tradingagents/dataflows/alpha_vantage_news.py b/tradingagents/dataflows/alpha_vantage_news.py index 4cf7bb0e..2d1e7624 100644 --- a/tradingagents/dataflows/alpha_vantage_news.py +++ b/tradingagents/dataflows/alpha_vantage_news.py @@ -1,4 +1,5 @@ from .alpha_vantage_common import _make_api_request, format_datetime_for_api +from .utils import normalize_date_range, normalize_iso_date def get_news(ticker, start_date, end_date) -> dict[str, str] | str: """Returns live and historical market news & sentiment data from premier news outlets worldwide. @@ -14,6 +15,8 @@ def get_news(ticker, start_date, end_date) -> dict[str, str] | str: Dictionary containing news sentiment data or JSON string. """ + start_date, end_date = normalize_date_range(start_date, end_date) + params = { "tickers": ticker, "time_from": format_datetime_for_api(start_date), @@ -38,6 +41,7 @@ def get_global_news(curr_date, look_back_days: int = 7, limit: int = 50) -> dict from datetime import datetime, timedelta # Calculate start date + curr_date = normalize_iso_date(curr_date) curr_dt = datetime.strptime(curr_date, "%Y-%m-%d") start_dt = curr_dt - timedelta(days=look_back_days) start_date = start_dt.strftime("%Y-%m-%d") @@ -68,4 +72,4 @@ def get_insider_transactions(symbol: str) -> dict[str, str] | str: "symbol": symbol, } - return _make_api_request("INSIDER_TRANSACTIONS", params) \ No newline at end of file + return _make_api_request("INSIDER_TRANSACTIONS", params) diff --git a/tradingagents/dataflows/alpha_vantage_stock.py b/tradingagents/dataflows/alpha_vantage_stock.py index ffd3570b..06a2d5be 100644 --- a/tradingagents/dataflows/alpha_vantage_stock.py +++ b/tradingagents/dataflows/alpha_vantage_stock.py @@ -1,5 +1,6 @@ from datetime import datetime from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range +from .utils import normalize_date_range def get_stock( symbol: str, @@ -18,6 +19,8 @@ def get_stock( Returns: CSV string containing the daily adjusted time series data filtered to the date range. """ + start_date, end_date = normalize_date_range(start_date, end_date) + # Parse dates to determine the range start_dt = datetime.strptime(start_date, "%Y-%m-%d") today = datetime.now() @@ -35,4 +38,4 @@ def get_stock( response = _make_api_request("TIME_SERIES_DAILY_ADJUSTED", params) - return _filter_csv_by_date_range(response, start_date, end_date) \ No newline at end of file + return _filter_csv_by_date_range(response, start_date, end_date) diff --git a/tradingagents/dataflows/utils.py b/tradingagents/dataflows/utils.py index 4523de19..145ecb23 100644 --- a/tradingagents/dataflows/utils.py +++ b/tradingagents/dataflows/utils.py @@ -1,6 +1,7 @@ import os import json import pandas as pd +import calendar from datetime import date, timedelta, datetime from typing import Annotated @@ -16,6 +17,50 @@ def get_current_date(): return date.today().strftime("%Y-%m-%d") +def normalize_iso_date(date_str: str) -> str: + """Normalize YYYY-MM-DD dates, clamping invalid month-end days. + + LLM tool calls occasionally produce dates like 2026-02-29 when they mean + "the end of February". For valid ISO dates this returns the input as-is. + For invalid day-of-month values within a valid year/month, it clamps to the + last valid day of that month. Other malformed values still raise ValueError. + """ + try: + return datetime.strptime(date_str, "%Y-%m-%d").strftime("%Y-%m-%d") + except ValueError as exc: + try: + year_str, month_str, day_str = date_str.split("-") + year = int(year_str) + month = int(month_str) + day = int(day_str) + except (AttributeError, ValueError) as parse_exc: + raise ValueError(f"Unsupported date format: {date_str}") from parse_exc + + if not 1 <= month <= 12: + raise exc + if day < 1: + raise exc + + last_day = calendar.monthrange(year, month)[1] + if day > last_day: + return f"{year:04d}-{month:02d}-{last_day:02d}" + + raise exc + + +def normalize_date_range(start_date: str, end_date: str) -> tuple[str, str]: + """Normalize and order an ISO date range.""" + normalized_start = normalize_iso_date(start_date) + normalized_end = normalize_iso_date(end_date) + + start_dt = datetime.strptime(normalized_start, "%Y-%m-%d") + end_dt = datetime.strptime(normalized_end, "%Y-%m-%d") + + if start_dt <= end_dt: + return normalized_start, normalized_end + return normalized_end, normalized_start + + def decorate_all_methods(decorator): def class_decorator(cls): for attr_name, attr_value in cls.__dict__.items(): diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py index 3682a01d..2c7e6eec 100644 --- a/tradingagents/dataflows/y_finance.py +++ b/tradingagents/dataflows/y_finance.py @@ -4,15 +4,14 @@ from dateutil.relativedelta import relativedelta import yfinance as yf import os from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry +from .utils import normalize_date_range def get_YFin_data_online( symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"], ): - - datetime.strptime(start_date, "%Y-%m-%d") - datetime.strptime(end_date, "%Y-%m-%d") + start_date, end_date = normalize_date_range(start_date, end_date) # Create ticker object ticker = yf.Ticker(symbol.upper()) @@ -461,4 +460,4 @@ def get_insider_transactions( return header + csv_string except Exception as e: - return f"Error retrieving insider transactions for {ticker}: {str(e)}" \ No newline at end of file + return f"Error retrieving insider transactions for {ticker}: {str(e)}" diff --git a/tradingagents/dataflows/yfinance_news.py b/tradingagents/dataflows/yfinance_news.py index 20e9120d..7ce0b87f 100644 --- a/tradingagents/dataflows/yfinance_news.py +++ b/tradingagents/dataflows/yfinance_news.py @@ -4,6 +4,8 @@ import yfinance as yf from datetime import datetime from dateutil.relativedelta import relativedelta +from .utils import normalize_date_range, normalize_iso_date + def _extract_article_data(article: dict) -> dict: """Extract article data from yfinance news format (handles nested 'content' structure).""" @@ -63,6 +65,7 @@ def get_news_yfinance( Formatted string containing news articles """ try: + start_date, end_date = normalize_date_range(start_date, end_date) stock = yf.Ticker(ticker) news = stock.get_news(count=20) @@ -130,6 +133,7 @@ def get_global_news_yfinance( seen_titles = set() try: + curr_date = normalize_iso_date(curr_date) for query in search_queries: search = yf.Search( query=query, diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 898e1e1e..412778b7 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -8,10 +8,10 @@ DEFAULT_CONFIG = { "dataflows/data_cache", ), # LLM settings - "llm_provider": "openai", - "deep_think_llm": "gpt-5.2", - "quick_think_llm": "gpt-5-mini", - "backend_url": "https://api.openai.com/v1", + "llm_provider": os.getenv("TRADINGAGENTS_LLM_PROVIDER", "ollama"), + "deep_think_llm": os.getenv("TRADINGAGENTS_DEEP_THINK_LLM", "qwen3:8b"), + "quick_think_llm": os.getenv("TRADINGAGENTS_QUICK_THINK_LLM", "qwen3:8b"), + "backend_url": os.getenv("TRADINGAGENTS_BACKEND_URL", "http://localhost:11434/v1"), # Provider-specific thinking configuration "google_thinking_level": None, # "high", "minimal", etc. "openai_reasoning_effort": None, # "medium", "high", "low" diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index e0771c65..e9ab144a 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -38,7 +38,7 @@ class GraphSetup: self.conditional_logic = conditional_logic def setup_graph( - self, selected_analysts=["market", "social", "news", "fundamentals"] + self, selected_analysts=None ): """Set up and compile the agent workflow graph. @@ -49,6 +49,9 @@ class GraphSetup: - "news": News analyst - "fundamentals": Fundamentals analyst """ + if selected_analysts is None: + selected_analysts = ["market", "social", "news", "fundamentals"] + if len(selected_analysts) == 0: raise ValueError("Trading Agents Graph Setup Error: no analysts selected!") diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index c8cd7492..3435358b 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -45,7 +45,7 @@ class TradingAgentsGraph: def __init__( self, - selected_analysts=["market", "social", "news", "fundamentals"], + selected_analysts=None, debug=False, config: Dict[str, Any] = None, callbacks: Optional[List] = None, @@ -59,6 +59,8 @@ class TradingAgentsGraph: callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats) """ self.debug = debug + if selected_analysts is None: + selected_analysts = ["market", "social", "news", "fundamentals"] self.config = config or DEFAULT_CONFIG self.callbacks = callbacks or [] @@ -66,10 +68,7 @@ class TradingAgentsGraph: set_config(self.config) # Create necessary directories - os.makedirs( - os.path.join(self.config["project_dir"], "dataflows/data_cache"), - exist_ok=True, - ) + os.makedirs(self.config["data_cache_dir"], exist_ok=True) # Initialize LLMs with provider-specific thinking configuration llm_kwargs = self._get_provider_kwargs() @@ -259,11 +258,11 @@ class TradingAgentsGraph: } # Save to file - directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/") + directory = Path(self.config["results_dir"]) / "TradingAgentsStrategy_logs" directory.mkdir(parents=True, exist_ok=True) with open( - f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json", + directory / f"full_states_log_{trade_date}.json", "w", encoding="utf-8", ) as f: diff --git a/tradingagents/llm_clients/TODO.md b/tradingagents/llm_clients/TODO.md index d5b5ac9c..500efbae 100644 --- a/tradingagents/llm_clients/TODO.md +++ b/tradingagents/llm_clients/TODO.md @@ -1,9 +1,12 @@ # LLM Clients - Consistency Improvements -## Issues to Fix +## Completed -### 1. `validate_model()` is never called -- Add validation call in `get_llm()` with warning (not error) for unknown models +### 1. `validate_model()` warning path +- `create_llm_client()` now calls `validate_model()` and emits a warning for + unknown model names instead of failing immediately. + +## Remaining Issues to Fix ### 2. Inconsistent parameter handling | Client | API Key Param | Special Params | diff --git a/tradingagents/llm_clients/factory.py b/tradingagents/llm_clients/factory.py index 93c2a7d3..410e9301 100644 --- a/tradingagents/llm_clients/factory.py +++ b/tradingagents/llm_clients/factory.py @@ -1,3 +1,4 @@ +import warnings from typing import Optional from .base_client import BaseLLMClient @@ -35,15 +36,22 @@ def create_llm_client( provider_lower = provider.lower() if provider_lower in ("openai", "ollama", "openrouter"): - return OpenAIClient(model, base_url, provider=provider_lower, **kwargs) + client = OpenAIClient(model, base_url, provider=provider_lower, **kwargs) + elif provider_lower == "xai": + client = OpenAIClient(model, base_url, provider="xai", **kwargs) + elif provider_lower == "anthropic": + client = AnthropicClient(model, base_url, **kwargs) + elif provider_lower == "google": + client = GoogleClient(model, base_url, **kwargs) + else: + raise ValueError(f"Unsupported LLM provider: {provider}") - if provider_lower == "xai": - return OpenAIClient(model, base_url, provider="xai", **kwargs) + if not client.validate_model(): + warnings.warn( + f"Model '{model}' is not in the known model list for provider " + f"'{provider_lower}'. The request may still work if the provider " + "supports it, but mis-typed model names will fail at runtime.", + stacklevel=2, + ) - if provider_lower == "anthropic": - return AnthropicClient(model, base_url, **kwargs) - - if provider_lower == "google": - return GoogleClient(model, base_url, **kwargs) - - raise ValueError(f"Unsupported LLM provider: {provider}") + return client diff --git a/tradingagents/llm_clients/openai_client.py b/tradingagents/llm_clients/openai_client.py index fd9b4e33..147b8e59 100644 --- a/tradingagents/llm_clients/openai_client.py +++ b/tradingagents/llm_clients/openai_client.py @@ -57,8 +57,8 @@ class OpenAIClient(BaseLLMClient): # Provider-specific base URL and auth if self.provider in _PROVIDER_CONFIG: - base_url, api_key_env = _PROVIDER_CONFIG[self.provider] - llm_kwargs["base_url"] = base_url + default_base_url, api_key_env = _PROVIDER_CONFIG[self.provider] + llm_kwargs["base_url"] = self.base_url or default_base_url if api_key_env: api_key = os.environ.get(api_key_env) if api_key: