Improve Ollama defaults, data resilience, and docs

This commit is contained in:
Oğuzcan Toptaş 2026-03-29 03:22:32 +03:00
parent 589b351f2a
commit d861eccd2d
21 changed files with 274 additions and 53 deletions

View File

@ -117,6 +117,12 @@ Install the package and its dependencies:
pip install .
```
If you are actively modifying the project, prefer an editable install so the
`tradingagents` CLI always reflects your local code changes:
```bash
pip install -e .
```
### Required APIs
TradingAgents supports multiple LLM providers. Set the API key for your chosen provider:
@ -132,6 +138,17 @@ export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
For local models, configure Ollama with `llm_provider: "ollama"` in your config.
Example local setup for a default Ollama install running `qwen3:8b`:
```bash
export TRADINGAGENTS_LLM_PROVIDER=ollama
export TRADINGAGENTS_BACKEND_URL=http://localhost:11434/v1
export TRADINGAGENTS_QUICK_THINK_LLM=qwen3:8b
export TRADINGAGENTS_DEEP_THINK_LLM=qwen3:8b
```
The default config in this fork is also set up for Ollama on
`http://localhost:11434/v1`, with both thinking models pointing to `qwen3:8b`.
Alternatively, copy `.env.example` to `.env` and fill in your keys:
```bash
cp .env.example .env
@ -198,6 +215,16 @@ _, decision = ta.propagate("NVDA", "2026-01-15")
print(decision)
```
Local Ollama example:
```python
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "ollama"
config["backend_url"] = "http://localhost:11434/v1"
config["deep_think_llm"] = "qwen3:8b"
config["quick_think_llm"] = "qwen3:8b"
```
See `tradingagents/default_config.py` for all configuration options.
## Contributing

View File

@ -939,6 +939,15 @@ def run_analysis():
selected_set = {analyst.value for analyst in selections["analysts"]}
selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set]
# Create result directory early so graph logging uses the same run-specific base path.
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
results_dir.mkdir(parents=True, exist_ok=True)
report_dir = results_dir / "reports"
report_dir.mkdir(parents=True, exist_ok=True)
log_file = results_dir / "message_tool.log"
log_file.touch(exist_ok=True)
config["results_dir"] = str(results_dir)
# Initialize the graph with callbacks bound to LLMs
graph = TradingAgentsGraph(
selected_analyst_keys,
@ -953,14 +962,6 @@ def run_analysis():
# Track start time for elapsed display
start_time = time.time()
# Create result directory
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
results_dir.mkdir(parents=True, exist_ok=True)
report_dir = results_dir / "reports"
report_dir.mkdir(parents=True, exist_ok=True)
log_file = results_dir / "message_tool.log"
log_file.touch(exist_ok=True)
def save_message_decorator(obj, func_name):
func = getattr(obj, func_name)
@wraps(func)

View File

@ -167,7 +167,7 @@ def select_shallow_thinking_agent(provider) -> str:
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
],
"ollama": [
("Qwen3:latest (8B, local)", "qwen3:latest"),
("Qwen3:8B (local)", "qwen3:8b"),
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
],
@ -236,7 +236,7 @@ def select_deep_thinking_agent(provider) -> str:
"ollama": [
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
("Qwen3:latest (8B, local)", "qwen3:latest"),
("Qwen3:8B (local)", "qwen3:8b"),
],
}

View File

@ -8,8 +8,8 @@ load_dotenv()
# Create a custom config
config = DEFAULT_CONFIG.copy()
config["deep_think_llm"] = "gpt-5-mini" # Use a different model
config["quick_think_llm"] = "gpt-5-mini" # Use a different model
config["deep_think_llm"] = "qwen3:8b"
config["quick_think_llm"] = "qwen3:8b"
config["max_debate_rounds"] = 1 # Increase debate rounds
# Configure data vendors (default uses yfinance, no extra API keys needed)

15
test.py
View File

@ -1,11 +1,12 @@
import time
from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
print("Testing optimized implementation with 30-day lookback:")
start_time = time.time()
result = get_stock_stats_indicators_window("AAPL", "macd", "2024-11-01", 30)
end_time = time.time()
if __name__ == "__main__":
print("Testing optimized implementation with 30-day lookback:")
start_time = time.time()
result = get_stock_stats_indicators_window("AAPL", "macd", "2024-11-01", 30)
end_time = time.time()
print(f"Execution time: {end_time - start_time:.2f} seconds")
print(f"Result length: {len(result)} characters")
print(result)
print(f"Execution time: {end_time - start_time:.2f} seconds")
print(f"Result length: {len(result)} characters")
print(result)

1
tests/__init__.py Normal file
View File

@ -0,0 +1 @@

View File

@ -0,0 +1,26 @@
import unittest
from tradingagents.dataflows.utils import normalize_date_range, normalize_iso_date
class DateNormalizationTests(unittest.TestCase):
def test_normalize_iso_date_keeps_valid_dates(self):
self.assertEqual(normalize_iso_date("2024-02-29"), "2024-02-29")
def test_normalize_iso_date_clamps_invalid_month_end(self):
self.assertEqual(normalize_iso_date("2026-02-29"), "2026-02-28")
self.assertEqual(normalize_iso_date("2026-04-31"), "2026-04-30")
def test_normalize_date_range_orders_dates_after_normalization(self):
self.assertEqual(
normalize_date_range("2026-03-29", "2026-02-29"),
("2026-02-28", "2026-03-29"),
)
def test_normalize_iso_date_rejects_bad_format(self):
with self.assertRaises(ValueError):
normalize_iso_date("2026/02/29")
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,58 @@
import json
import tempfile
import unittest
from pathlib import Path
from tradingagents.graph.trading_graph import TradingAgentsGraph
class GraphLoggingTests(unittest.TestCase):
def test_log_state_uses_configured_results_dir(self):
with tempfile.TemporaryDirectory() as tmpdir:
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
graph.config = {"results_dir": tmpdir}
graph.ticker = "SPY"
graph.log_states_dict = {}
final_state = {
"company_of_interest": "SPY",
"trade_date": "2026-03-29",
"market_report": "market",
"sentiment_report": "sentiment",
"news_report": "news",
"fundamentals_report": "fundamentals",
"investment_debate_state": {
"bull_history": "bull",
"bear_history": "bear",
"history": "history",
"current_response": "current",
"judge_decision": "judge",
},
"trader_investment_plan": "trade plan",
"risk_debate_state": {
"aggressive_history": "aggressive",
"conservative_history": "conservative",
"neutral_history": "neutral",
"history": "risk history",
"judge_decision": "portfolio judge",
},
"investment_plan": "investment plan",
"final_trade_decision": "HOLD",
}
graph._log_state("2026-03-29", final_state)
log_path = (
Path(tmpdir)
/ "TradingAgentsStrategy_logs"
/ "full_states_log_2026-03-29.json"
)
self.assertTrue(log_path.exists())
payload = json.loads(log_path.read_text())
self.assertIn("2026-03-29", payload)
self.assertEqual(payload["2026-03-29"]["final_trade_decision"], "HOLD")
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,36 @@
import unittest
import warnings
from tradingagents.llm_clients.factory import create_llm_client
class LLMClientFactoryTests(unittest.TestCase):
def test_invalid_known_provider_model_emits_warning(self):
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
client = create_llm_client("openai", "not-a-real-openai-model")
self.assertEqual(client.provider, "openai")
self.assertTrue(
any("not-a-real-openai-model" in str(w.message) for w in caught)
)
def test_known_valid_model_does_not_emit_warning(self):
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
client = create_llm_client("openai", "gpt-5-mini")
self.assertEqual(client.provider, "openai")
self.assertEqual(len(caught), 0)
def test_ollama_custom_models_do_not_emit_warning(self):
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
client = create_llm_client("ollama", "qwen3:8b")
self.assertEqual(client.provider, "ollama")
self.assertEqual(len(caught), 0)
if __name__ == "__main__":
unittest.main()

View File

@ -5,6 +5,8 @@ import json
from datetime import datetime
from io import StringIO
from .utils import normalize_date_range, normalize_iso_date
API_BASE_URL = "https://www.alphavantage.co/query"
def get_api_key() -> str:
@ -22,7 +24,7 @@ def format_datetime_for_api(date_input) -> str:
return date_input
# Try to parse common date formats
try:
dt = datetime.strptime(date_input, "%Y-%m-%d")
dt = datetime.strptime(normalize_iso_date(date_input), "%Y-%m-%d")
return dt.strftime("%Y%m%dT0000")
except ValueError:
try:
@ -100,6 +102,7 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) ->
return csv_data
try:
start_date, end_date = normalize_date_range(start_date, end_date)
# Parse CSV data
df = pd.read_csv(StringIO(csv_data))

View File

@ -1,4 +1,5 @@
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
from .utils import normalize_date_range, normalize_iso_date
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.
@ -14,6 +15,8 @@ def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
Dictionary containing news sentiment data or JSON string.
"""
start_date, end_date = normalize_date_range(start_date, end_date)
params = {
"tickers": ticker,
"time_from": format_datetime_for_api(start_date),
@ -38,6 +41,7 @@ def get_global_news(curr_date, look_back_days: int = 7, limit: int = 50) -> dict
from datetime import datetime, timedelta
# Calculate start date
curr_date = normalize_iso_date(curr_date)
curr_dt = datetime.strptime(curr_date, "%Y-%m-%d")
start_dt = curr_dt - timedelta(days=look_back_days)
start_date = start_dt.strftime("%Y-%m-%d")
@ -68,4 +72,4 @@ def get_insider_transactions(symbol: str) -> dict[str, str] | str:
"symbol": symbol,
}
return _make_api_request("INSIDER_TRANSACTIONS", params)
return _make_api_request("INSIDER_TRANSACTIONS", params)

View File

@ -1,5 +1,6 @@
from datetime import datetime
from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range
from .utils import normalize_date_range
def get_stock(
symbol: str,
@ -18,6 +19,8 @@ def get_stock(
Returns:
CSV string containing the daily adjusted time series data filtered to the date range.
"""
start_date, end_date = normalize_date_range(start_date, end_date)
# Parse dates to determine the range
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
today = datetime.now()
@ -35,4 +38,4 @@ def get_stock(
response = _make_api_request("TIME_SERIES_DAILY_ADJUSTED", params)
return _filter_csv_by_date_range(response, start_date, end_date)
return _filter_csv_by_date_range(response, start_date, end_date)

View File

@ -1,6 +1,7 @@
import os
import json
import pandas as pd
import calendar
from datetime import date, timedelta, datetime
from typing import Annotated
@ -16,6 +17,50 @@ def get_current_date():
return date.today().strftime("%Y-%m-%d")
def normalize_iso_date(date_str: str) -> str:
"""Normalize YYYY-MM-DD dates, clamping invalid month-end days.
LLM tool calls occasionally produce dates like 2026-02-29 when they mean
"the end of February". For valid ISO dates this returns the input as-is.
For invalid day-of-month values within a valid year/month, it clamps to the
last valid day of that month. Other malformed values still raise ValueError.
"""
try:
return datetime.strptime(date_str, "%Y-%m-%d").strftime("%Y-%m-%d")
except ValueError as exc:
try:
year_str, month_str, day_str = date_str.split("-")
year = int(year_str)
month = int(month_str)
day = int(day_str)
except (AttributeError, ValueError) as parse_exc:
raise ValueError(f"Unsupported date format: {date_str}") from parse_exc
if not 1 <= month <= 12:
raise exc
if day < 1:
raise exc
last_day = calendar.monthrange(year, month)[1]
if day > last_day:
return f"{year:04d}-{month:02d}-{last_day:02d}"
raise exc
def normalize_date_range(start_date: str, end_date: str) -> tuple[str, str]:
"""Normalize and order an ISO date range."""
normalized_start = normalize_iso_date(start_date)
normalized_end = normalize_iso_date(end_date)
start_dt = datetime.strptime(normalized_start, "%Y-%m-%d")
end_dt = datetime.strptime(normalized_end, "%Y-%m-%d")
if start_dt <= end_dt:
return normalized_start, normalized_end
return normalized_end, normalized_start
def decorate_all_methods(decorator):
def class_decorator(cls):
for attr_name, attr_value in cls.__dict__.items():

View File

@ -4,15 +4,14 @@ from dateutil.relativedelta import relativedelta
import yfinance as yf
import os
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry
from .utils import normalize_date_range
def get_YFin_data_online(
symbol: Annotated[str, "ticker symbol of the company"],
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
):
datetime.strptime(start_date, "%Y-%m-%d")
datetime.strptime(end_date, "%Y-%m-%d")
start_date, end_date = normalize_date_range(start_date, end_date)
# Create ticker object
ticker = yf.Ticker(symbol.upper())
@ -461,4 +460,4 @@ def get_insider_transactions(
return header + csv_string
except Exception as e:
return f"Error retrieving insider transactions for {ticker}: {str(e)}"
return f"Error retrieving insider transactions for {ticker}: {str(e)}"

View File

@ -4,6 +4,8 @@ import yfinance as yf
from datetime import datetime
from dateutil.relativedelta import relativedelta
from .utils import normalize_date_range, normalize_iso_date
def _extract_article_data(article: dict) -> dict:
"""Extract article data from yfinance news format (handles nested 'content' structure)."""
@ -63,6 +65,7 @@ def get_news_yfinance(
Formatted string containing news articles
"""
try:
start_date, end_date = normalize_date_range(start_date, end_date)
stock = yf.Ticker(ticker)
news = stock.get_news(count=20)
@ -130,6 +133,7 @@ def get_global_news_yfinance(
seen_titles = set()
try:
curr_date = normalize_iso_date(curr_date)
for query in search_queries:
search = yf.Search(
query=query,

View File

@ -8,10 +8,10 @@ DEFAULT_CONFIG = {
"dataflows/data_cache",
),
# LLM settings
"llm_provider": "openai",
"deep_think_llm": "gpt-5.2",
"quick_think_llm": "gpt-5-mini",
"backend_url": "https://api.openai.com/v1",
"llm_provider": os.getenv("TRADINGAGENTS_LLM_PROVIDER", "ollama"),
"deep_think_llm": os.getenv("TRADINGAGENTS_DEEP_THINK_LLM", "qwen3:8b"),
"quick_think_llm": os.getenv("TRADINGAGENTS_QUICK_THINK_LLM", "qwen3:8b"),
"backend_url": os.getenv("TRADINGAGENTS_BACKEND_URL", "http://localhost:11434/v1"),
# Provider-specific thinking configuration
"google_thinking_level": None, # "high", "minimal", etc.
"openai_reasoning_effort": None, # "medium", "high", "low"

View File

@ -38,7 +38,7 @@ class GraphSetup:
self.conditional_logic = conditional_logic
def setup_graph(
self, selected_analysts=["market", "social", "news", "fundamentals"]
self, selected_analysts=None
):
"""Set up and compile the agent workflow graph.
@ -49,6 +49,9 @@ class GraphSetup:
- "news": News analyst
- "fundamentals": Fundamentals analyst
"""
if selected_analysts is None:
selected_analysts = ["market", "social", "news", "fundamentals"]
if len(selected_analysts) == 0:
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")

View File

@ -45,7 +45,7 @@ class TradingAgentsGraph:
def __init__(
self,
selected_analysts=["market", "social", "news", "fundamentals"],
selected_analysts=None,
debug=False,
config: Dict[str, Any] = None,
callbacks: Optional[List] = None,
@ -59,6 +59,8 @@ class TradingAgentsGraph:
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
"""
self.debug = debug
if selected_analysts is None:
selected_analysts = ["market", "social", "news", "fundamentals"]
self.config = config or DEFAULT_CONFIG
self.callbacks = callbacks or []
@ -66,10 +68,7 @@ class TradingAgentsGraph:
set_config(self.config)
# Create necessary directories
os.makedirs(
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
exist_ok=True,
)
os.makedirs(self.config["data_cache_dir"], exist_ok=True)
# Initialize LLMs with provider-specific thinking configuration
llm_kwargs = self._get_provider_kwargs()
@ -259,11 +258,11 @@ class TradingAgentsGraph:
}
# Save to file
directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
directory = Path(self.config["results_dir"]) / "TradingAgentsStrategy_logs"
directory.mkdir(parents=True, exist_ok=True)
with open(
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json",
directory / f"full_states_log_{trade_date}.json",
"w",
encoding="utf-8",
) as f:

View File

@ -1,9 +1,12 @@
# LLM Clients - Consistency Improvements
## Issues to Fix
## Completed
### 1. `validate_model()` is never called
- Add validation call in `get_llm()` with warning (not error) for unknown models
### 1. `validate_model()` warning path
- `create_llm_client()` now calls `validate_model()` and emits a warning for
unknown model names instead of failing immediately.
## Remaining Issues to Fix
### 2. Inconsistent parameter handling
| Client | API Key Param | Special Params |

View File

@ -1,3 +1,4 @@
import warnings
from typing import Optional
from .base_client import BaseLLMClient
@ -35,15 +36,22 @@ def create_llm_client(
provider_lower = provider.lower()
if provider_lower in ("openai", "ollama", "openrouter"):
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
client = OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
elif provider_lower == "xai":
client = OpenAIClient(model, base_url, provider="xai", **kwargs)
elif provider_lower == "anthropic":
client = AnthropicClient(model, base_url, **kwargs)
elif provider_lower == "google":
client = GoogleClient(model, base_url, **kwargs)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
if provider_lower == "xai":
return OpenAIClient(model, base_url, provider="xai", **kwargs)
if not client.validate_model():
warnings.warn(
f"Model '{model}' is not in the known model list for provider "
f"'{provider_lower}'. The request may still work if the provider "
"supports it, but mis-typed model names will fail at runtime.",
stacklevel=2,
)
if provider_lower == "anthropic":
return AnthropicClient(model, base_url, **kwargs)
if provider_lower == "google":
return GoogleClient(model, base_url, **kwargs)
raise ValueError(f"Unsupported LLM provider: {provider}")
return client

View File

@ -57,8 +57,8 @@ class OpenAIClient(BaseLLMClient):
# Provider-specific base URL and auth
if self.provider in _PROVIDER_CONFIG:
base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
llm_kwargs["base_url"] = base_url
default_base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
llm_kwargs["base_url"] = self.base_url or default_base_url
if api_key_env:
api_key = os.environ.get(api_key_env)
if api_key: