Compare commits
10 Commits
589b351f2a
...
c61242a28c
| Author | SHA1 | Date |
|---|---|---|
|
|
c61242a28c | |
|
|
58e99421bd | |
|
|
46e1b600b8 | |
|
|
ae8c8aebe8 | |
|
|
f3f58bdbdc | |
|
|
e1113880a1 | |
|
|
bd6a5b75b5 | |
|
|
8793336dad | |
|
|
047b38971c | |
|
|
f5026009f9 |
81
cli/utils.py
81
cli/utils.py
|
|
@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Dict
|
|||
from rich.console import Console
|
||||
|
||||
from cli.models import AnalystType
|
||||
from tradingagents.llm_clients.model_catalog import get_model_options
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -136,48 +137,11 @@ def select_research_depth() -> int:
|
|||
def select_shallow_thinking_agent(provider) -> str:
|
||||
"""Select shallow thinking llm engine using an interactive selection."""
|
||||
|
||||
# Define shallow thinking llm engine options with their corresponding model names
|
||||
# Ordering: medium → light → heavy (balanced first for quick tasks)
|
||||
# Within same tier, newer models first
|
||||
SHALLOW_AGENT_OPTIONS = {
|
||||
"openai": [
|
||||
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),
|
||||
("GPT-5 Nano - High-throughput, simple tasks", "gpt-5-nano"),
|
||||
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
|
||||
("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"),
|
||||
],
|
||||
"anthropic": [
|
||||
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
||||
("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"),
|
||||
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
|
||||
],
|
||||
"google": [
|
||||
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
||||
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
||||
("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"),
|
||||
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
|
||||
],
|
||||
"xai": [
|
||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
|
||||
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
||||
],
|
||||
"openrouter": [
|
||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||
],
|
||||
"ollama": [
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
],
|
||||
}
|
||||
|
||||
choice = questionary.select(
|
||||
"Select Your [Quick-Thinking LLM Engine]:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value)
|
||||
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()]
|
||||
for display, value in get_model_options(provider, "quick")
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
|
|
@ -201,50 +165,11 @@ def select_shallow_thinking_agent(provider) -> str:
|
|||
def select_deep_thinking_agent(provider) -> str:
|
||||
"""Select deep thinking llm engine using an interactive selection."""
|
||||
|
||||
# Define deep thinking llm engine options with their corresponding model names
|
||||
# Ordering: heavy → medium → light (most capable first for deep tasks)
|
||||
# Within same tier, newer models first
|
||||
DEEP_AGENT_OPTIONS = {
|
||||
"openai": [
|
||||
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
|
||||
("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"),
|
||||
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),
|
||||
("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"),
|
||||
],
|
||||
"anthropic": [
|
||||
("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"),
|
||||
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
|
||||
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
||||
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
|
||||
],
|
||||
"google": [
|
||||
("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"),
|
||||
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
||||
("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"),
|
||||
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
||||
],
|
||||
"xai": [
|
||||
("Grok 4 - Flagship model", "grok-4-0709"),
|
||||
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
||||
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
|
||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||
],
|
||||
"openrouter": [
|
||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||
],
|
||||
"ollama": [
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
],
|
||||
}
|
||||
|
||||
choice = questionary.select(
|
||||
"Select Your [Deep-Thinking LLM Engine]:",
|
||||
choices=[
|
||||
questionary.Choice(display, value=value)
|
||||
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
|
||||
for display, value in get_model_options(provider, "deep")
|
||||
],
|
||||
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
|
||||
style=questionary.Style(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from tradingagents.llm_clients.google_client import GoogleClient
|
||||
|
||||
|
||||
class TestGoogleApiKeyStandardization(unittest.TestCase):
|
||||
"""Verify GoogleClient accepts unified api_key parameter."""
|
||||
|
||||
@patch("tradingagents.llm_clients.google_client.NormalizedChatGoogleGenerativeAI")
|
||||
def test_api_key_handling(self, mock_chat):
|
||||
test_cases = [
|
||||
("unified api_key is mapped", {"api_key": "test-key-123"}, "test-key-123"),
|
||||
("legacy google_api_key still works", {"google_api_key": "legacy-key-456"}, "legacy-key-456"),
|
||||
("unified api_key takes precedence", {"api_key": "unified", "google_api_key": "legacy"}, "unified"),
|
||||
]
|
||||
|
||||
for msg, kwargs, expected_key in test_cases:
|
||||
with self.subTest(msg=msg):
|
||||
mock_chat.reset_mock()
|
||||
client = GoogleClient("gemini-2.5-flash", **kwargs)
|
||||
client.get_llm()
|
||||
call_kwargs = mock_chat.call_args[1]
|
||||
self.assertEqual(call_kwargs.get("google_api_key"), expected_key)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
import unittest
|
||||
import warnings
|
||||
|
||||
from tradingagents.llm_clients.base_client import BaseLLMClient
|
||||
from tradingagents.llm_clients.model_catalog import get_known_models
|
||||
from tradingagents.llm_clients.validators import validate_model
|
||||
|
||||
|
||||
class DummyLLMClient(BaseLLMClient):
|
||||
def __init__(self, provider: str, model: str):
|
||||
self.provider = provider
|
||||
super().__init__(model)
|
||||
|
||||
def get_llm(self):
|
||||
self.warn_if_unknown_model()
|
||||
return object()
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
return validate_model(self.provider, self.model)
|
||||
|
||||
|
||||
class ModelValidationTests(unittest.TestCase):
|
||||
def test_cli_catalog_models_are_all_validator_approved(self):
|
||||
for provider, models in get_known_models().items():
|
||||
if provider in ("ollama", "openrouter"):
|
||||
continue
|
||||
|
||||
for model in models:
|
||||
with self.subTest(provider=provider, model=model):
|
||||
self.assertTrue(validate_model(provider, model))
|
||||
|
||||
def test_unknown_model_emits_warning_for_strict_provider(self):
|
||||
client = DummyLLMClient("openai", "not-a-real-openai-model")
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
client.get_llm()
|
||||
|
||||
self.assertEqual(len(caught), 1)
|
||||
self.assertIn("not-a-real-openai-model", str(caught[0].message))
|
||||
self.assertIn("openai", str(caught[0].message))
|
||||
|
||||
def test_openrouter_and_ollama_accept_custom_models_without_warning(self):
|
||||
for provider in ("openrouter", "ollama"):
|
||||
client = DummyLLMClient(provider, "custom-model-name")
|
||||
|
||||
with self.subTest(provider=provider):
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
warnings.simplefilter("always")
|
||||
client.get_llm()
|
||||
|
||||
self.assertEqual(caught, [])
|
||||
|
|
@ -23,9 +23,10 @@ def get_indicators(
|
|||
# LLMs sometimes pass multiple indicators as a comma-separated string;
|
||||
# split and process each individually.
|
||||
indicators = [i.strip() for i in indicator.split(",") if i.strip()]
|
||||
if len(indicators) > 1:
|
||||
results = []
|
||||
for ind in indicators:
|
||||
results = []
|
||||
for ind in indicators:
|
||||
try:
|
||||
results.append(route_to_vendor("get_indicators", symbol, ind, curr_date, look_back_days))
|
||||
return "\n\n".join(results)
|
||||
return route_to_vendor("get_indicators", symbol, indicator.strip(), curr_date, look_back_days)
|
||||
except ValueError as e:
|
||||
results.append(str(e))
|
||||
return "\n\n".join(results)
|
||||
|
|
@ -1,6 +1,23 @@
|
|||
from .alpha_vantage_common import _make_api_request
|
||||
|
||||
|
||||
def _filter_reports_by_date(result, curr_date: str):
|
||||
"""Filter annualReports/quarterlyReports to exclude entries after curr_date.
|
||||
|
||||
Prevents look-ahead bias by removing fiscal periods that end after
|
||||
the simulation's current date.
|
||||
"""
|
||||
if not curr_date or not isinstance(result, dict):
|
||||
return result
|
||||
for key in ("annualReports", "quarterlyReports"):
|
||||
if key in result:
|
||||
result[key] = [
|
||||
r for r in result[key]
|
||||
if r.get("fiscalDateEnding", "") <= curr_date
|
||||
]
|
||||
return result
|
||||
|
||||
|
||||
def get_fundamentals(ticker: str, curr_date: str = None) -> str:
|
||||
"""
|
||||
Retrieve comprehensive fundamental data for a given ticker symbol using Alpha Vantage.
|
||||
|
|
@ -19,59 +36,20 @@ def get_fundamentals(ticker: str, curr_date: str = None) -> str:
|
|||
return _make_api_request("OVERVIEW", params)
|
||||
|
||||
|
||||
def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
||||
"""
|
||||
Retrieve balance sheet data for a given ticker symbol using Alpha Vantage.
|
||||
|
||||
Args:
|
||||
ticker (str): Ticker symbol of the company
|
||||
freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage
|
||||
curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage)
|
||||
|
||||
Returns:
|
||||
str: Balance sheet data with normalized fields
|
||||
"""
|
||||
params = {
|
||||
"symbol": ticker,
|
||||
}
|
||||
|
||||
return _make_api_request("BALANCE_SHEET", params)
|
||||
def get_balance_sheet(ticker: str, freq: str = "quarterly", curr_date: str = None):
|
||||
"""Retrieve balance sheet data for a given ticker symbol using Alpha Vantage."""
|
||||
result = _make_api_request("BALANCE_SHEET", {"symbol": ticker})
|
||||
return _filter_reports_by_date(result, curr_date)
|
||||
|
||||
|
||||
def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
||||
"""
|
||||
Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage.
|
||||
|
||||
Args:
|
||||
ticker (str): Ticker symbol of the company
|
||||
freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage
|
||||
curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage)
|
||||
|
||||
Returns:
|
||||
str: Cash flow statement data with normalized fields
|
||||
"""
|
||||
params = {
|
||||
"symbol": ticker,
|
||||
}
|
||||
|
||||
return _make_api_request("CASH_FLOW", params)
|
||||
def get_cashflow(ticker: str, freq: str = "quarterly", curr_date: str = None):
|
||||
"""Retrieve cash flow statement data for a given ticker symbol using Alpha Vantage."""
|
||||
result = _make_api_request("CASH_FLOW", {"symbol": ticker})
|
||||
return _filter_reports_by_date(result, curr_date)
|
||||
|
||||
|
||||
def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None) -> str:
|
||||
"""
|
||||
Retrieve income statement data for a given ticker symbol using Alpha Vantage.
|
||||
|
||||
Args:
|
||||
ticker (str): Ticker symbol of the company
|
||||
freq (str): Reporting frequency: annual/quarterly (default quarterly) - not used for Alpha Vantage
|
||||
curr_date (str): Current date you are trading at, yyyy-mm-dd (not used for Alpha Vantage)
|
||||
|
||||
Returns:
|
||||
str: Income statement data with normalized fields
|
||||
"""
|
||||
params = {
|
||||
"symbol": ticker,
|
||||
}
|
||||
|
||||
return _make_api_request("INCOME_STATEMENT", params)
|
||||
def get_income_statement(ticker: str, freq: str = "quarterly", curr_date: str = None):
|
||||
"""Retrieve income statement data for a given ticker symbol using Alpha Vantage."""
|
||||
result = _make_api_request("INCOME_STATEMENT", {"symbol": ticker})
|
||||
return _filter_reports_by_date(result, curr_date)
|
||||
|
||||
|
|
|
|||
|
|
@ -44,6 +44,64 @@ def _clean_dataframe(data: pd.DataFrame) -> pd.DataFrame:
|
|||
return data
|
||||
|
||||
|
||||
def load_ohlcv(symbol: str, curr_date: str) -> pd.DataFrame:
|
||||
"""Fetch OHLCV data with caching, filtered to prevent look-ahead bias.
|
||||
|
||||
Downloads 15 years of data up to today and caches per symbol. On
|
||||
subsequent calls the cache is reused. Rows after curr_date are
|
||||
filtered out so backtests never see future prices.
|
||||
"""
|
||||
config = get_config()
|
||||
curr_date_dt = pd.to_datetime(curr_date)
|
||||
|
||||
# Cache uses a fixed window (15y to today) so one file per symbol
|
||||
today_date = pd.Timestamp.today()
|
||||
start_date = today_date - pd.DateOffset(years=5)
|
||||
start_str = start_date.strftime("%Y-%m-%d")
|
||||
end_str = today_date.strftime("%Y-%m-%d")
|
||||
|
||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
||||
data_file = os.path.join(
|
||||
config["data_cache_dir"],
|
||||
f"{symbol}-YFin-data-{start_str}-{end_str}.csv",
|
||||
)
|
||||
|
||||
if os.path.exists(data_file):
|
||||
data = pd.read_csv(data_file, on_bad_lines="skip")
|
||||
else:
|
||||
data = yf_retry(lambda: yf.download(
|
||||
symbol,
|
||||
start=start_str,
|
||||
end=end_str,
|
||||
multi_level_index=False,
|
||||
progress=False,
|
||||
auto_adjust=True,
|
||||
))
|
||||
data = data.reset_index()
|
||||
data.to_csv(data_file, index=False)
|
||||
|
||||
data = _clean_dataframe(data)
|
||||
|
||||
# Filter to curr_date to prevent look-ahead bias in backtesting
|
||||
data = data[data["Date"] <= curr_date_dt]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def filter_financials_by_date(data: pd.DataFrame, curr_date: str) -> pd.DataFrame:
|
||||
"""Drop financial statement columns (fiscal period timestamps) after curr_date.
|
||||
|
||||
yfinance financial statements use fiscal period end dates as columns.
|
||||
Columns after curr_date represent future data and are removed to
|
||||
prevent look-ahead bias.
|
||||
"""
|
||||
if not curr_date or data.empty:
|
||||
return data
|
||||
cutoff = pd.Timestamp(curr_date)
|
||||
mask = pd.to_datetime(data.columns, errors="coerce") <= cutoff
|
||||
return data.loc[:, mask]
|
||||
|
||||
|
||||
class StockstatsUtils:
|
||||
@staticmethod
|
||||
def get_stock_stats(
|
||||
|
|
@ -55,42 +113,10 @@ class StockstatsUtils:
|
|||
str, "curr date for retrieving stock price data, YYYY-mm-dd"
|
||||
],
|
||||
):
|
||||
config = get_config()
|
||||
|
||||
today_date = pd.Timestamp.today()
|
||||
curr_date_dt = pd.to_datetime(curr_date)
|
||||
|
||||
end_date = today_date
|
||||
start_date = today_date - pd.DateOffset(years=15)
|
||||
start_date_str = start_date.strftime("%Y-%m-%d")
|
||||
end_date_str = end_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Ensure cache directory exists
|
||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
||||
|
||||
data_file = os.path.join(
|
||||
config["data_cache_dir"],
|
||||
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
|
||||
)
|
||||
|
||||
if os.path.exists(data_file):
|
||||
data = pd.read_csv(data_file, on_bad_lines="skip")
|
||||
else:
|
||||
data = yf_retry(lambda: yf.download(
|
||||
symbol,
|
||||
start=start_date_str,
|
||||
end=end_date_str,
|
||||
multi_level_index=False,
|
||||
progress=False,
|
||||
auto_adjust=True,
|
||||
))
|
||||
data = data.reset_index()
|
||||
data.to_csv(data_file, index=False)
|
||||
|
||||
data = _clean_dataframe(data)
|
||||
data = load_ohlcv(symbol, curr_date)
|
||||
df = wrap(data)
|
||||
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
||||
curr_date_str = curr_date_dt.strftime("%Y-%m-%d")
|
||||
curr_date_str = pd.to_datetime(curr_date).strftime("%Y-%m-%d")
|
||||
|
||||
df[indicator] # trigger stockstats to calculate the indicator
|
||||
matching_rows = df[df["Date"].str.startswith(curr_date_str)]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from datetime import datetime
|
|||
from dateutil.relativedelta import relativedelta
|
||||
import yfinance as yf
|
||||
import os
|
||||
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry
|
||||
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry, load_ohlcv, filter_financials_by_date
|
||||
|
||||
def get_YFin_data_online(
|
||||
symbol: Annotated[str, "ticker symbol of the company"],
|
||||
|
|
@ -194,58 +194,9 @@ def _get_stock_stats_bulk(
|
|||
Fetches data once and calculates indicator for all available dates.
|
||||
Returns dict mapping date strings to indicator values.
|
||||
"""
|
||||
from .config import get_config
|
||||
import pandas as pd
|
||||
from stockstats import wrap
|
||||
import os
|
||||
|
||||
config = get_config()
|
||||
online = config["data_vendors"]["technical_indicators"] != "local"
|
||||
|
||||
if not online:
|
||||
# Local data path
|
||||
try:
|
||||
data = pd.read_csv(
|
||||
os.path.join(
|
||||
config.get("data_cache_dir", "data"),
|
||||
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
||||
),
|
||||
on_bad_lines="skip",
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
|
||||
else:
|
||||
# Online data fetching with caching
|
||||
today_date = pd.Timestamp.today()
|
||||
curr_date_dt = pd.to_datetime(curr_date)
|
||||
|
||||
end_date = today_date
|
||||
start_date = today_date - pd.DateOffset(years=15)
|
||||
start_date_str = start_date.strftime("%Y-%m-%d")
|
||||
end_date_str = end_date.strftime("%Y-%m-%d")
|
||||
|
||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
||||
|
||||
data_file = os.path.join(
|
||||
config["data_cache_dir"],
|
||||
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
|
||||
)
|
||||
|
||||
if os.path.exists(data_file):
|
||||
data = pd.read_csv(data_file, on_bad_lines="skip")
|
||||
else:
|
||||
data = yf_retry(lambda: yf.download(
|
||||
symbol,
|
||||
start=start_date_str,
|
||||
end=end_date_str,
|
||||
multi_level_index=False,
|
||||
progress=False,
|
||||
auto_adjust=True,
|
||||
))
|
||||
data = data.reset_index()
|
||||
data.to_csv(data_file, index=False)
|
||||
|
||||
data = _clean_dataframe(data)
|
||||
data = load_ohlcv(symbol, curr_date)
|
||||
df = wrap(data)
|
||||
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
||||
|
||||
|
|
@ -353,7 +304,7 @@ def get_fundamentals(
|
|||
def get_balance_sheet(
|
||||
ticker: Annotated[str, "ticker symbol of the company"],
|
||||
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
|
||||
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
|
||||
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
||||
):
|
||||
"""Get balance sheet data from yfinance."""
|
||||
try:
|
||||
|
|
@ -363,7 +314,9 @@ def get_balance_sheet(
|
|||
data = yf_retry(lambda: ticker_obj.quarterly_balance_sheet)
|
||||
else:
|
||||
data = yf_retry(lambda: ticker_obj.balance_sheet)
|
||||
|
||||
|
||||
data = filter_financials_by_date(data, curr_date)
|
||||
|
||||
if data.empty:
|
||||
return f"No balance sheet data found for symbol '{ticker}'"
|
||||
|
||||
|
|
@ -383,7 +336,7 @@ def get_balance_sheet(
|
|||
def get_cashflow(
|
||||
ticker: Annotated[str, "ticker symbol of the company"],
|
||||
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
|
||||
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
|
||||
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
||||
):
|
||||
"""Get cash flow data from yfinance."""
|
||||
try:
|
||||
|
|
@ -393,7 +346,9 @@ def get_cashflow(
|
|||
data = yf_retry(lambda: ticker_obj.quarterly_cashflow)
|
||||
else:
|
||||
data = yf_retry(lambda: ticker_obj.cashflow)
|
||||
|
||||
|
||||
data = filter_financials_by_date(data, curr_date)
|
||||
|
||||
if data.empty:
|
||||
return f"No cash flow data found for symbol '{ticker}'"
|
||||
|
||||
|
|
@ -413,7 +368,7 @@ def get_cashflow(
|
|||
def get_income_statement(
|
||||
ticker: Annotated[str, "ticker symbol of the company"],
|
||||
freq: Annotated[str, "frequency of data: 'annual' or 'quarterly'"] = "quarterly",
|
||||
curr_date: Annotated[str, "current date (not used for yfinance)"] = None
|
||||
curr_date: Annotated[str, "current date in YYYY-MM-DD format"] = None
|
||||
):
|
||||
"""Get income statement data from yfinance."""
|
||||
try:
|
||||
|
|
@ -423,7 +378,9 @@ def get_income_statement(
|
|||
data = yf_retry(lambda: ticker_obj.quarterly_income_stmt)
|
||||
else:
|
||||
data = yf_retry(lambda: ticker_obj.income_stmt)
|
||||
|
||||
|
||||
data = filter_financials_by_date(data, curr_date)
|
||||
|
||||
if data.empty:
|
||||
return f"No income statement data found for symbol '{ticker}'"
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import yfinance as yf
|
|||
from datetime import datetime
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
from .stockstats_utils import yf_retry
|
||||
|
||||
|
||||
def _extract_article_data(article: dict) -> dict:
|
||||
"""Extract article data from yfinance news format (handles nested 'content' structure)."""
|
||||
|
|
@ -64,7 +66,7 @@ def get_news_yfinance(
|
|||
"""
|
||||
try:
|
||||
stock = yf.Ticker(ticker)
|
||||
news = stock.get_news(count=20)
|
||||
news = yf_retry(lambda: stock.get_news(count=20))
|
||||
|
||||
if not news:
|
||||
return f"No news found for {ticker}"
|
||||
|
|
@ -131,11 +133,11 @@ def get_global_news_yfinance(
|
|||
|
||||
try:
|
||||
for query in search_queries:
|
||||
search = yf.Search(
|
||||
query=query,
|
||||
search = yf_retry(lambda q=query: yf.Search(
|
||||
query=q,
|
||||
news_count=limit,
|
||||
enable_fuzzy_query=True,
|
||||
)
|
||||
))
|
||||
|
||||
if search.news:
|
||||
for article in search.news:
|
||||
|
|
@ -167,6 +169,11 @@ def get_global_news_yfinance(
|
|||
# Handle both flat and nested structures
|
||||
if "content" in article:
|
||||
data = _extract_article_data(article)
|
||||
# Skip articles published after curr_date (look-ahead guard)
|
||||
if data.get("pub_date"):
|
||||
pub_naive = data["pub_date"].replace(tzinfo=None) if hasattr(data["pub_date"], "replace") else data["pub_date"]
|
||||
if pub_naive > curr_dt + relativedelta(days=1):
|
||||
continue
|
||||
title = data["title"]
|
||||
publisher = data["publisher"]
|
||||
link = data["link"]
|
||||
|
|
|
|||
|
|
@ -5,20 +5,11 @@
|
|||
### 1. `validate_model()` is never called
|
||||
- Add validation call in `get_llm()` with warning (not error) for unknown models
|
||||
|
||||
### 2. Inconsistent parameter handling
|
||||
| Client | API Key Param | Special Params |
|
||||
|--------|---------------|----------------|
|
||||
| OpenAI | `api_key` | `reasoning_effort` |
|
||||
| Anthropic | `api_key` | `thinking_config` → `thinking` |
|
||||
| Google | `google_api_key` | `thinking_budget` |
|
||||
### 2. ~~Inconsistent parameter handling~~ (Fixed)
|
||||
- GoogleClient now accepts unified `api_key` and maps it to `google_api_key`
|
||||
|
||||
**Fix:** Standardize with unified `api_key` that maps to provider-specific keys
|
||||
### 3. ~~`base_url` accepted but ignored~~ (Fixed)
|
||||
- All clients now pass `base_url` to their respective LLM constructors
|
||||
|
||||
### 3. `base_url` accepted but ignored
|
||||
- `AnthropicClient`: accepts `base_url` but never uses it
|
||||
- `GoogleClient`: accepts `base_url` but never uses it (correct - Google doesn't support it)
|
||||
|
||||
**Fix:** Remove unused `base_url` from clients that don't support it
|
||||
|
||||
### 4. Update validators.py with models from CLI
|
||||
- Sync `VALID_MODELS` dict with CLI model options after Feature 2 is complete
|
||||
### 4. ~~Update validators.py with models from CLI~~ (Fixed)
|
||||
- Synced in v0.2.2
|
||||
|
|
|
|||
|
|
@ -31,8 +31,12 @@ class AnthropicClient(BaseLLMClient):
|
|||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured ChatAnthropic instance."""
|
||||
self.warn_if_unknown_model()
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
if self.base_url:
|
||||
llm_kwargs["base_url"] = self.base_url
|
||||
|
||||
for key in _PASSTHROUGH_KWARGS:
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
|
||||
|
||||
def normalize_content(response):
|
||||
|
|
@ -29,6 +30,27 @@ class BaseLLMClient(ABC):
|
|||
self.base_url = base_url
|
||||
self.kwargs = kwargs
|
||||
|
||||
def get_provider_name(self) -> str:
|
||||
"""Return the provider name used in warning messages."""
|
||||
provider = getattr(self, "provider", None)
|
||||
if provider:
|
||||
return str(provider)
|
||||
return self.__class__.__name__.removesuffix("Client").lower()
|
||||
|
||||
def warn_if_unknown_model(self) -> None:
|
||||
"""Warn when the model is outside the known list for the provider."""
|
||||
if self.validate_model():
|
||||
return
|
||||
|
||||
warnings.warn(
|
||||
(
|
||||
f"Model '{self.model}' is not in the known model list for "
|
||||
f"provider '{self.get_provider_name()}'. Continuing anyway."
|
||||
),
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_llm(self) -> Any:
|
||||
"""Return the configured LLM instance."""
|
||||
|
|
|
|||
|
|
@ -25,12 +25,21 @@ class GoogleClient(BaseLLMClient):
|
|||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured ChatGoogleGenerativeAI instance."""
|
||||
self.warn_if_unknown_model()
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
for key in ("timeout", "max_retries", "google_api_key", "callbacks", "http_client", "http_async_client"):
|
||||
if self.base_url:
|
||||
llm_kwargs["base_url"] = self.base_url
|
||||
|
||||
for key in ("timeout", "max_retries", "callbacks", "http_client", "http_async_client"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
# Unified api_key maps to provider-specific google_api_key
|
||||
google_api_key = self.kwargs.get("api_key") or self.kwargs.get("google_api_key")
|
||||
if google_api_key:
|
||||
llm_kwargs["google_api_key"] = google_api_key
|
||||
|
||||
# Map thinking_level to appropriate API param based on model
|
||||
# Gemini 3 Pro: low, high
|
||||
# Gemini 3 Flash: minimal, low, medium, high
|
||||
|
|
|
|||
|
|
@ -0,0 +1,107 @@
|
|||
"""Shared model catalog for CLI selections and validation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
ModelOption = Tuple[str, str]
|
||||
ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]]
|
||||
|
||||
|
||||
MODEL_OPTIONS: ProviderModeOptions = {
|
||||
"openai": {
|
||||
"quick": [
|
||||
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),
|
||||
("GPT-5 Nano - High-throughput, simple tasks", "gpt-5-nano"),
|
||||
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
|
||||
("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"),
|
||||
],
|
||||
"deep": [
|
||||
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
|
||||
("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"),
|
||||
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),
|
||||
("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"),
|
||||
],
|
||||
},
|
||||
"anthropic": {
|
||||
"quick": [
|
||||
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
||||
("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"),
|
||||
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
|
||||
],
|
||||
"deep": [
|
||||
("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"),
|
||||
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
|
||||
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
|
||||
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
|
||||
],
|
||||
},
|
||||
"google": {
|
||||
"quick": [
|
||||
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
||||
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
||||
("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"),
|
||||
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
|
||||
],
|
||||
"deep": [
|
||||
("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"),
|
||||
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
|
||||
("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"),
|
||||
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
|
||||
],
|
||||
},
|
||||
"xai": {
|
||||
"quick": [
|
||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
|
||||
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
||||
],
|
||||
"deep": [
|
||||
("Grok 4 - Flagship model", "grok-4-0709"),
|
||||
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
|
||||
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
|
||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||
],
|
||||
},
|
||||
"openrouter": {
|
||||
"quick": [
|
||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||
],
|
||||
"deep": [
|
||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
|
||||
],
|
||||
},
|
||||
"ollama": {
|
||||
"quick": [
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
],
|
||||
"deep": [
|
||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_model_options(provider: str, mode: str) -> List[ModelOption]:
|
||||
"""Return shared model options for a provider and selection mode."""
|
||||
return MODEL_OPTIONS[provider.lower()][mode]
|
||||
|
||||
|
||||
def get_known_models() -> Dict[str, List[str]]:
|
||||
"""Build known model names from the shared CLI catalog."""
|
||||
return {
|
||||
provider: sorted(
|
||||
{
|
||||
value
|
||||
for options in mode_options.values()
|
||||
for _, value in options
|
||||
}
|
||||
)
|
||||
for provider, mode_options in MODEL_OPTIONS.items()
|
||||
}
|
||||
|
|
@ -53,6 +53,7 @@ class OpenAIClient(BaseLLMClient):
|
|||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured ChatOpenAI instance."""
|
||||
self.warn_if_unknown_model()
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
# Provider-specific base URL and auth
|
||||
|
|
|
|||
|
|
@ -1,53 +1,12 @@
|
|||
"""Model name validators for each provider.
|
||||
"""Model name validators for each provider."""
|
||||
|
||||
from .model_catalog import get_known_models
|
||||
|
||||
Only validates model names - does NOT enforce limits.
|
||||
Let LLM providers use their own defaults for unspecified params.
|
||||
"""
|
||||
|
||||
VALID_MODELS = {
|
||||
"openai": [
|
||||
# GPT-5 series
|
||||
"gpt-5.4-pro",
|
||||
"gpt-5.4",
|
||||
"gpt-5.2",
|
||||
"gpt-5.1",
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"gpt-5-nano",
|
||||
# GPT-4.1 series
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4.1-nano",
|
||||
],
|
||||
"anthropic": [
|
||||
# Claude 4.6 series (latest)
|
||||
"claude-opus-4-6",
|
||||
"claude-sonnet-4-6",
|
||||
# Claude 4.5 series
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-haiku-4-5",
|
||||
],
|
||||
"google": [
|
||||
# Gemini 3.1 series (preview)
|
||||
"gemini-3.1-pro-preview",
|
||||
"gemini-3.1-flash-lite-preview",
|
||||
# Gemini 3 series (preview)
|
||||
"gemini-3-flash-preview",
|
||||
# Gemini 2.5 series
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
],
|
||||
"xai": [
|
||||
# Grok 4.1 series
|
||||
"grok-4-1-fast-reasoning",
|
||||
"grok-4-1-fast-non-reasoning",
|
||||
# Grok 4 series
|
||||
"grok-4-0709",
|
||||
"grok-4-fast-reasoning",
|
||||
"grok-4-fast-non-reasoning",
|
||||
],
|
||||
provider: models
|
||||
for provider, models in get_known_models().items()
|
||||
if provider not in ("ollama", "openrouter")
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue