Improve Ollama defaults, data resilience, and docs
This commit is contained in:
parent
589b351f2a
commit
d861eccd2d
27
README.md
27
README.md
|
|
@ -117,6 +117,12 @@ Install the package and its dependencies:
|
|||
pip install .
|
||||
```
|
||||
|
||||
If you are actively modifying the project, prefer an editable install so the
|
||||
`tradingagents` CLI always reflects your local code changes:
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### Required APIs
|
||||
|
||||
TradingAgents supports multiple LLM providers. Set the API key for your chosen provider:
|
||||
|
|
@ -132,6 +138,17 @@ export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
|||
|
||||
For local models, configure Ollama with `llm_provider: "ollama"` in your config.
|
||||
|
||||
Example local setup for a default Ollama install running `qwen3:8b`:
|
||||
```bash
|
||||
export TRADINGAGENTS_LLM_PROVIDER=ollama
|
||||
export TRADINGAGENTS_BACKEND_URL=http://localhost:11434/v1
|
||||
export TRADINGAGENTS_QUICK_THINK_LLM=qwen3:8b
|
||||
export TRADINGAGENTS_DEEP_THINK_LLM=qwen3:8b
|
||||
```
|
||||
|
||||
The default config in this fork is also set up for Ollama on
|
||||
`http://localhost:11434/v1`, with both thinking models pointing to `qwen3:8b`.
|
||||
|
||||
Alternatively, copy `.env.example` to `.env` and fill in your keys:
|
||||
```bash
|
||||
cp .env.example .env
|
||||
|
|
@ -198,6 +215,16 @@ _, decision = ta.propagate("NVDA", "2026-01-15")
|
|||
print(decision)
|
||||
```
|
||||
|
||||
Local Ollama example:
|
||||
|
||||
```python
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["llm_provider"] = "ollama"
|
||||
config["backend_url"] = "http://localhost:11434/v1"
|
||||
config["deep_think_llm"] = "qwen3:8b"
|
||||
config["quick_think_llm"] = "qwen3:8b"
|
||||
```
|
||||
|
||||
See `tradingagents/default_config.py` for all configuration options.
|
||||
|
||||
## Contributing
|
||||
|
|
|
|||
17
cli/main.py
17
cli/main.py
|
|
@ -939,6 +939,15 @@ def run_analysis():
|
|||
selected_set = {analyst.value for analyst in selections["analysts"]}
|
||||
selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set]
|
||||
|
||||
# Create result directory early so graph logging uses the same run-specific base path.
|
||||
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
report_dir = results_dir / "reports"
|
||||
report_dir.mkdir(parents=True, exist_ok=True)
|
||||
log_file = results_dir / "message_tool.log"
|
||||
log_file.touch(exist_ok=True)
|
||||
config["results_dir"] = str(results_dir)
|
||||
|
||||
# Initialize the graph with callbacks bound to LLMs
|
||||
graph = TradingAgentsGraph(
|
||||
selected_analyst_keys,
|
||||
|
|
@ -953,14 +962,6 @@ def run_analysis():
|
|||
# Track start time for elapsed display
|
||||
start_time = time.time()
|
||||
|
||||
# Create result directory
|
||||
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
report_dir = results_dir / "reports"
|
||||
report_dir.mkdir(parents=True, exist_ok=True)
|
||||
log_file = results_dir / "message_tool.log"
|
||||
log_file.touch(exist_ok=True)
|
||||
|
||||
def save_message_decorator(obj, func_name):
|
||||
func = getattr(obj, func_name)
|
||||
@wraps(func)
|
||||
|
|
|
|||
|
|
@ -167,7 +167,7 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||
],
|
||||
"ollama": [
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
("Qwen3:8B (local)", "qwen3:8b"),
|
||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
],
|
||||
|
|
@ -236,7 +236,7 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
"ollama": [
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
("Qwen3:8B (local)", "qwen3:8b"),
|
||||
],
|
||||
}
|
||||
|
||||
|
|
|
|||
4
main.py
4
main.py
|
|
@ -8,8 +8,8 @@ load_dotenv()
|
|||
|
||||
# Create a custom config
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["deep_think_llm"] = "gpt-5-mini" # Use a different model
|
||||
config["quick_think_llm"] = "gpt-5-mini" # Use a different model
|
||||
config["deep_think_llm"] = "qwen3:8b"
|
||||
config["quick_think_llm"] = "qwen3:8b"
|
||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||
|
||||
# Configure data vendors (default uses yfinance, no extra API keys needed)
|
||||
|
|
|
|||
15
test.py
15
test.py
|
|
@ -1,11 +1,12 @@
|
|||
import time
|
||||
from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
|
||||
|
||||
print("Testing optimized implementation with 30-day lookback:")
|
||||
start_time = time.time()
|
||||
result = get_stock_stats_indicators_window("AAPL", "macd", "2024-11-01", 30)
|
||||
end_time = time.time()
|
||||
if __name__ == "__main__":
|
||||
print("Testing optimized implementation with 30-day lookback:")
|
||||
start_time = time.time()
|
||||
result = get_stock_stats_indicators_window("AAPL", "macd", "2024-11-01", 30)
|
||||
end_time = time.time()
|
||||
|
||||
print(f"Execution time: {end_time - start_time:.2f} seconds")
|
||||
print(f"Result length: {len(result)} characters")
|
||||
print(result)
|
||||
print(f"Execution time: {end_time - start_time:.2f} seconds")
|
||||
print(f"Result length: {len(result)} characters")
|
||||
print(result)
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
import unittest
|
||||
|
||||
from tradingagents.dataflows.utils import normalize_date_range, normalize_iso_date
|
||||
|
||||
|
||||
class DateNormalizationTests(unittest.TestCase):
|
||||
def test_normalize_iso_date_keeps_valid_dates(self):
|
||||
self.assertEqual(normalize_iso_date("2024-02-29"), "2024-02-29")
|
||||
|
||||
def test_normalize_iso_date_clamps_invalid_month_end(self):
|
||||
self.assertEqual(normalize_iso_date("2026-02-29"), "2026-02-28")
|
||||
self.assertEqual(normalize_iso_date("2026-04-31"), "2026-04-30")
|
||||
|
||||
def test_normalize_date_range_orders_dates_after_normalization(self):
|
||||
self.assertEqual(
|
||||
normalize_date_range("2026-03-29", "2026-02-29"),
|
||||
("2026-02-28", "2026-03-29"),
|
||||
)
|
||||
|
||||
def test_normalize_iso_date_rejects_bad_format(self):
|
||||
with self.assertRaises(ValueError):
|
||||
normalize_iso_date("2026/02/29")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
|
||||
class GraphLoggingTests(unittest.TestCase):
|
||||
def test_log_state_uses_configured_results_dir(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
|
||||
graph.config = {"results_dir": tmpdir}
|
||||
graph.ticker = "SPY"
|
||||
graph.log_states_dict = {}
|
||||
|
||||
final_state = {
|
||||
"company_of_interest": "SPY",
|
||||
"trade_date": "2026-03-29",
|
||||
"market_report": "market",
|
||||
"sentiment_report": "sentiment",
|
||||
"news_report": "news",
|
||||
"fundamentals_report": "fundamentals",
|
||||
"investment_debate_state": {
|
||||
"bull_history": "bull",
|
||||
"bear_history": "bear",
|
||||
"history": "history",
|
||||
"current_response": "current",
|
||||
"judge_decision": "judge",
|
||||
},
|
||||
"trader_investment_plan": "trade plan",
|
||||
"risk_debate_state": {
|
||||
"aggressive_history": "aggressive",
|
||||
"conservative_history": "conservative",
|
||||
"neutral_history": "neutral",
|
||||
"history": "risk history",
|
||||
"judge_decision": "portfolio judge",
|
||||
},
|
||||
"investment_plan": "investment plan",
|
||||
"final_trade_decision": "HOLD",
|
||||
}
|
||||
|
||||
graph._log_state("2026-03-29", final_state)
|
||||
|
||||
log_path = (
|
||||
Path(tmpdir)
|
||||
/ "TradingAgentsStrategy_logs"
|
||||
/ "full_states_log_2026-03-29.json"
|
||||
)
|
||||
self.assertTrue(log_path.exists())
|
||||
|
||||
payload = json.loads(log_path.read_text())
|
||||
self.assertIn("2026-03-29", payload)
|
||||
self.assertEqual(payload["2026-03-29"]["final_trade_decision"], "HOLD")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
import unittest
|
||||
import warnings
|
||||
|
||||
from tradingagents.llm_clients.factory import create_llm_client
|
||||
|
||||
|
||||
class LLMClientFactoryTests(unittest.TestCase):
|
||||
def test_invalid_known_provider_model_emits_warning(self):
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
client = create_llm_client("openai", "not-a-real-openai-model")
|
||||
|
||||
self.assertEqual(client.provider, "openai")
|
||||
self.assertTrue(
|
||||
any("not-a-real-openai-model" in str(w.message) for w in caught)
|
||||
)
|
||||
|
||||
def test_known_valid_model_does_not_emit_warning(self):
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
client = create_llm_client("openai", "gpt-5-mini")
|
||||
|
||||
self.assertEqual(client.provider, "openai")
|
||||
self.assertEqual(len(caught), 0)
|
||||
|
||||
def test_ollama_custom_models_do_not_emit_warning(self):
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
client = create_llm_client("ollama", "qwen3:8b")
|
||||
|
||||
self.assertEqual(client.provider, "ollama")
|
||||
self.assertEqual(len(caught), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -5,6 +5,8 @@ import json
|
|||
from datetime import datetime
|
||||
from io import StringIO
|
||||
|
||||
from .utils import normalize_date_range, normalize_iso_date
|
||||
|
||||
API_BASE_URL = "https://www.alphavantage.co/query"
|
||||
|
||||
def get_api_key() -> str:
|
||||
|
|
@ -22,7 +24,7 @@ def format_datetime_for_api(date_input) -> str:
|
|||
return date_input
|
||||
# Try to parse common date formats
|
||||
try:
|
||||
dt = datetime.strptime(date_input, "%Y-%m-%d")
|
||||
dt = datetime.strptime(normalize_iso_date(date_input), "%Y-%m-%d")
|
||||
return dt.strftime("%Y%m%dT0000")
|
||||
except ValueError:
|
||||
try:
|
||||
|
|
@ -100,6 +102,7 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) ->
|
|||
return csv_data
|
||||
|
||||
try:
|
||||
start_date, end_date = normalize_date_range(start_date, end_date)
|
||||
# Parse CSV data
|
||||
df = pd.read_csv(StringIO(csv_data))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
|
||||
from .utils import normalize_date_range, normalize_iso_date
|
||||
|
||||
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
||||
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.
|
||||
|
|
@ -14,6 +15,8 @@ def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
|||
Dictionary containing news sentiment data or JSON string.
|
||||
"""
|
||||
|
||||
start_date, end_date = normalize_date_range(start_date, end_date)
|
||||
|
||||
params = {
|
||||
"tickers": ticker,
|
||||
"time_from": format_datetime_for_api(start_date),
|
||||
|
|
@ -38,6 +41,7 @@ def get_global_news(curr_date, look_back_days: int = 7, limit: int = 50) -> dict
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
# Calculate start date
|
||||
curr_date = normalize_iso_date(curr_date)
|
||||
curr_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
start_dt = curr_dt - timedelta(days=look_back_days)
|
||||
start_date = start_dt.strftime("%Y-%m-%d")
|
||||
|
|
@ -68,4 +72,4 @@ def get_insider_transactions(symbol: str) -> dict[str, str] | str:
|
|||
"symbol": symbol,
|
||||
}
|
||||
|
||||
return _make_api_request("INSIDER_TRANSACTIONS", params)
|
||||
return _make_api_request("INSIDER_TRANSACTIONS", params)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from datetime import datetime
|
||||
from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range
|
||||
from .utils import normalize_date_range
|
||||
|
||||
def get_stock(
|
||||
symbol: str,
|
||||
|
|
@ -18,6 +19,8 @@ def get_stock(
|
|||
Returns:
|
||||
CSV string containing the daily adjusted time series data filtered to the date range.
|
||||
"""
|
||||
start_date, end_date = normalize_date_range(start_date, end_date)
|
||||
|
||||
# Parse dates to determine the range
|
||||
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
today = datetime.now()
|
||||
|
|
@ -35,4 +38,4 @@ def get_stock(
|
|||
|
||||
response = _make_api_request("TIME_SERIES_DAILY_ADJUSTED", params)
|
||||
|
||||
return _filter_csv_by_date_range(response, start_date, end_date)
|
||||
return _filter_csv_by_date_range(response, start_date, end_date)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
import json
|
||||
import pandas as pd
|
||||
import calendar
|
||||
from datetime import date, timedelta, datetime
|
||||
from typing import Annotated
|
||||
|
||||
|
|
@ -16,6 +17,50 @@ def get_current_date():
|
|||
return date.today().strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
def normalize_iso_date(date_str: str) -> str:
|
||||
"""Normalize YYYY-MM-DD dates, clamping invalid month-end days.
|
||||
|
||||
LLM tool calls occasionally produce dates like 2026-02-29 when they mean
|
||||
"the end of February". For valid ISO dates this returns the input as-is.
|
||||
For invalid day-of-month values within a valid year/month, it clamps to the
|
||||
last valid day of that month. Other malformed values still raise ValueError.
|
||||
"""
|
||||
try:
|
||||
return datetime.strptime(date_str, "%Y-%m-%d").strftime("%Y-%m-%d")
|
||||
except ValueError as exc:
|
||||
try:
|
||||
year_str, month_str, day_str = date_str.split("-")
|
||||
year = int(year_str)
|
||||
month = int(month_str)
|
||||
day = int(day_str)
|
||||
except (AttributeError, ValueError) as parse_exc:
|
||||
raise ValueError(f"Unsupported date format: {date_str}") from parse_exc
|
||||
|
||||
if not 1 <= month <= 12:
|
||||
raise exc
|
||||
if day < 1:
|
||||
raise exc
|
||||
|
||||
last_day = calendar.monthrange(year, month)[1]
|
||||
if day > last_day:
|
||||
return f"{year:04d}-{month:02d}-{last_day:02d}"
|
||||
|
||||
raise exc
|
||||
|
||||
|
||||
def normalize_date_range(start_date: str, end_date: str) -> tuple[str, str]:
|
||||
"""Normalize and order an ISO date range."""
|
||||
normalized_start = normalize_iso_date(start_date)
|
||||
normalized_end = normalize_iso_date(end_date)
|
||||
|
||||
start_dt = datetime.strptime(normalized_start, "%Y-%m-%d")
|
||||
end_dt = datetime.strptime(normalized_end, "%Y-%m-%d")
|
||||
|
||||
if start_dt <= end_dt:
|
||||
return normalized_start, normalized_end
|
||||
return normalized_end, normalized_start
|
||||
|
||||
|
||||
def decorate_all_methods(decorator):
|
||||
def class_decorator(cls):
|
||||
for attr_name, attr_value in cls.__dict__.items():
|
||||
|
|
|
|||
|
|
@ -4,15 +4,14 @@ from dateutil.relativedelta import relativedelta
|
|||
import yfinance as yf
|
||||
import os
|
||||
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry
|
||||
from .utils import normalize_date_range
|
||||
|
||||
def get_YFin_data_online(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||
):
|
||||
|
||||
datetime.strptime(start_date, "%Y-%m-%d")
|
||||
datetime.strptime(end_date, "%Y-%m-%d")
|
||||
start_date, end_date = normalize_date_range(start_date, end_date)
|
||||
|
||||
# Create ticker object
|
||||
ticker = yf.Ticker(symbol.upper())
|
||||
|
|
@ -461,4 +460,4 @@ def get_insider_transactions(
|
|||
return header + csv_string
|
||||
|
||||
except Exception as e:
|
||||
return f"Error retrieving insider transactions for {ticker}: {str(e)}"
|
||||
return f"Error retrieving insider transactions for {ticker}: {str(e)}"
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import yfinance as yf
|
|||
from datetime import datetime
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
from .utils import normalize_date_range, normalize_iso_date
|
||||
|
||||
|
||||
def _extract_article_data(article: dict) -> dict:
|
||||
"""Extract article data from yfinance news format (handles nested 'content' structure)."""
|
||||
|
|
@ -63,6 +65,7 @@ def get_news_yfinance(
|
|||
Formatted string containing news articles
|
||||
"""
|
||||
try:
|
||||
start_date, end_date = normalize_date_range(start_date, end_date)
|
||||
stock = yf.Ticker(ticker)
|
||||
news = stock.get_news(count=20)
|
||||
|
||||
|
|
@ -130,6 +133,7 @@ def get_global_news_yfinance(
|
|||
seen_titles = set()
|
||||
|
||||
try:
|
||||
curr_date = normalize_iso_date(curr_date)
|
||||
for query in search_queries:
|
||||
search = yf.Search(
|
||||
query=query,
|
||||
|
|
|
|||
|
|
@ -8,10 +8,10 @@ DEFAULT_CONFIG = {
|
|||
"dataflows/data_cache",
|
||||
),
|
||||
# LLM settings
|
||||
"llm_provider": "openai",
|
||||
"deep_think_llm": "gpt-5.2",
|
||||
"quick_think_llm": "gpt-5-mini",
|
||||
"backend_url": "https://api.openai.com/v1",
|
||||
"llm_provider": os.getenv("TRADINGAGENTS_LLM_PROVIDER", "ollama"),
|
||||
"deep_think_llm": os.getenv("TRADINGAGENTS_DEEP_THINK_LLM", "qwen3:8b"),
|
||||
"quick_think_llm": os.getenv("TRADINGAGENTS_QUICK_THINK_LLM", "qwen3:8b"),
|
||||
"backend_url": os.getenv("TRADINGAGENTS_BACKEND_URL", "http://localhost:11434/v1"),
|
||||
# Provider-specific thinking configuration
|
||||
"google_thinking_level": None, # "high", "minimal", etc.
|
||||
"openai_reasoning_effort": None, # "medium", "high", "low"
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class GraphSetup:
|
|||
self.conditional_logic = conditional_logic
|
||||
|
||||
def setup_graph(
|
||||
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
||||
self, selected_analysts=None
|
||||
):
|
||||
"""Set up and compile the agent workflow graph.
|
||||
|
||||
|
|
@ -49,6 +49,9 @@ class GraphSetup:
|
|||
- "news": News analyst
|
||||
- "fundamentals": Fundamentals analyst
|
||||
"""
|
||||
if selected_analysts is None:
|
||||
selected_analysts = ["market", "social", "news", "fundamentals"]
|
||||
|
||||
if len(selected_analysts) == 0:
|
||||
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
|
||||
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class TradingAgentsGraph:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||
selected_analysts=None,
|
||||
debug=False,
|
||||
config: Dict[str, Any] = None,
|
||||
callbacks: Optional[List] = None,
|
||||
|
|
@ -59,6 +59,8 @@ class TradingAgentsGraph:
|
|||
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
|
||||
"""
|
||||
self.debug = debug
|
||||
if selected_analysts is None:
|
||||
selected_analysts = ["market", "social", "news", "fundamentals"]
|
||||
self.config = config or DEFAULT_CONFIG
|
||||
self.callbacks = callbacks or []
|
||||
|
||||
|
|
@ -66,10 +68,7 @@ class TradingAgentsGraph:
|
|||
set_config(self.config)
|
||||
|
||||
# Create necessary directories
|
||||
os.makedirs(
|
||||
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
|
||||
exist_ok=True,
|
||||
)
|
||||
os.makedirs(self.config["data_cache_dir"], exist_ok=True)
|
||||
|
||||
# Initialize LLMs with provider-specific thinking configuration
|
||||
llm_kwargs = self._get_provider_kwargs()
|
||||
|
|
@ -259,11 +258,11 @@ class TradingAgentsGraph:
|
|||
}
|
||||
|
||||
# Save to file
|
||||
directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
|
||||
directory = Path(self.config["results_dir"]) / "TradingAgentsStrategy_logs"
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(
|
||||
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json",
|
||||
directory / f"full_states_log_{trade_date}.json",
|
||||
"w",
|
||||
encoding="utf-8",
|
||||
) as f:
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
# LLM Clients - Consistency Improvements
|
||||
|
||||
## Issues to Fix
|
||||
## Completed
|
||||
|
||||
### 1. `validate_model()` is never called
|
||||
- Add validation call in `get_llm()` with warning (not error) for unknown models
|
||||
### 1. `validate_model()` warning path
|
||||
- `create_llm_client()` now calls `validate_model()` and emits a warning for
|
||||
unknown model names instead of failing immediately.
|
||||
|
||||
## Remaining Issues to Fix
|
||||
|
||||
### 2. Inconsistent parameter handling
|
||||
| Client | API Key Param | Special Params |
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
from .base_client import BaseLLMClient
|
||||
|
|
@ -35,15 +36,22 @@ def create_llm_client(
|
|||
provider_lower = provider.lower()
|
||||
|
||||
if provider_lower in ("openai", "ollama", "openrouter"):
|
||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||
client = OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||
elif provider_lower == "xai":
|
||||
client = OpenAIClient(model, base_url, provider="xai", **kwargs)
|
||||
elif provider_lower == "anthropic":
|
||||
client = AnthropicClient(model, base_url, **kwargs)
|
||||
elif provider_lower == "google":
|
||||
client = GoogleClient(model, base_url, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
if provider_lower == "xai":
|
||||
return OpenAIClient(model, base_url, provider="xai", **kwargs)
|
||||
if not client.validate_model():
|
||||
warnings.warn(
|
||||
f"Model '{model}' is not in the known model list for provider "
|
||||
f"'{provider_lower}'. The request may still work if the provider "
|
||||
"supports it, but mis-typed model names will fail at runtime.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if provider_lower == "anthropic":
|
||||
return AnthropicClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "google":
|
||||
return GoogleClient(model, base_url, **kwargs)
|
||||
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
return client
|
||||
|
|
|
|||
|
|
@ -57,8 +57,8 @@ class OpenAIClient(BaseLLMClient):
|
|||
|
||||
# Provider-specific base URL and auth
|
||||
if self.provider in _PROVIDER_CONFIG:
|
||||
base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
|
||||
llm_kwargs["base_url"] = base_url
|
||||
default_base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
|
||||
llm_kwargs["base_url"] = self.base_url or default_base_url
|
||||
if api_key_env:
|
||||
api_key = os.environ.get(api_key_env)
|
||||
if api_key:
|
||||
|
|
|
|||
Loading…
Reference in New Issue