Improve Ollama defaults, data resilience, and docs
This commit is contained in:
parent
589b351f2a
commit
d861eccd2d
27
README.md
27
README.md
|
|
@ -117,6 +117,12 @@ Install the package and its dependencies:
|
||||||
pip install .
|
pip install .
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you are actively modifying the project, prefer an editable install so the
|
||||||
|
`tradingagents` CLI always reflects your local code changes:
|
||||||
|
```bash
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
### Required APIs
|
### Required APIs
|
||||||
|
|
||||||
TradingAgents supports multiple LLM providers. Set the API key for your chosen provider:
|
TradingAgents supports multiple LLM providers. Set the API key for your chosen provider:
|
||||||
|
|
@ -132,6 +138,17 @@ export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
||||||
|
|
||||||
For local models, configure Ollama with `llm_provider: "ollama"` in your config.
|
For local models, configure Ollama with `llm_provider: "ollama"` in your config.
|
||||||
|
|
||||||
|
Example local setup for a default Ollama install running `qwen3:8b`:
|
||||||
|
```bash
|
||||||
|
export TRADINGAGENTS_LLM_PROVIDER=ollama
|
||||||
|
export TRADINGAGENTS_BACKEND_URL=http://localhost:11434/v1
|
||||||
|
export TRADINGAGENTS_QUICK_THINK_LLM=qwen3:8b
|
||||||
|
export TRADINGAGENTS_DEEP_THINK_LLM=qwen3:8b
|
||||||
|
```
|
||||||
|
|
||||||
|
The default config in this fork is also set up for Ollama on
|
||||||
|
`http://localhost:11434/v1`, with both thinking models pointing to `qwen3:8b`.
|
||||||
|
|
||||||
Alternatively, copy `.env.example` to `.env` and fill in your keys:
|
Alternatively, copy `.env.example` to `.env` and fill in your keys:
|
||||||
```bash
|
```bash
|
||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
|
|
@ -198,6 +215,16 @@ _, decision = ta.propagate("NVDA", "2026-01-15")
|
||||||
print(decision)
|
print(decision)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Local Ollama example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
config = DEFAULT_CONFIG.copy()
|
||||||
|
config["llm_provider"] = "ollama"
|
||||||
|
config["backend_url"] = "http://localhost:11434/v1"
|
||||||
|
config["deep_think_llm"] = "qwen3:8b"
|
||||||
|
config["quick_think_llm"] = "qwen3:8b"
|
||||||
|
```
|
||||||
|
|
||||||
See `tradingagents/default_config.py` for all configuration options.
|
See `tradingagents/default_config.py` for all configuration options.
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
|
||||||
17
cli/main.py
17
cli/main.py
|
|
@ -939,6 +939,15 @@ def run_analysis():
|
||||||
selected_set = {analyst.value for analyst in selections["analysts"]}
|
selected_set = {analyst.value for analyst in selections["analysts"]}
|
||||||
selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set]
|
selected_analyst_keys = [a for a in ANALYST_ORDER if a in selected_set]
|
||||||
|
|
||||||
|
# Create result directory early so graph logging uses the same run-specific base path.
|
||||||
|
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
||||||
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
report_dir = results_dir / "reports"
|
||||||
|
report_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
log_file = results_dir / "message_tool.log"
|
||||||
|
log_file.touch(exist_ok=True)
|
||||||
|
config["results_dir"] = str(results_dir)
|
||||||
|
|
||||||
# Initialize the graph with callbacks bound to LLMs
|
# Initialize the graph with callbacks bound to LLMs
|
||||||
graph = TradingAgentsGraph(
|
graph = TradingAgentsGraph(
|
||||||
selected_analyst_keys,
|
selected_analyst_keys,
|
||||||
|
|
@ -953,14 +962,6 @@ def run_analysis():
|
||||||
# Track start time for elapsed display
|
# Track start time for elapsed display
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Create result directory
|
|
||||||
results_dir = Path(config["results_dir"]) / selections["ticker"] / selections["analysis_date"]
|
|
||||||
results_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
report_dir = results_dir / "reports"
|
|
||||||
report_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
log_file = results_dir / "message_tool.log"
|
|
||||||
log_file.touch(exist_ok=True)
|
|
||||||
|
|
||||||
def save_message_decorator(obj, func_name):
|
def save_message_decorator(obj, func_name):
|
||||||
func = getattr(obj, func_name)
|
func = getattr(obj, func_name)
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
|
|
|
||||||
|
|
@ -167,7 +167,7 @@ def select_shallow_thinking_agent(provider) -> str:
|
||||||
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
|
||||||
],
|
],
|
||||||
"ollama": [
|
"ollama": [
|
||||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
("Qwen3:8B (local)", "qwen3:8b"),
|
||||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||||
],
|
],
|
||||||
|
|
@ -236,7 +236,7 @@ def select_deep_thinking_agent(provider) -> str:
|
||||||
"ollama": [
|
"ollama": [
|
||||||
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
|
||||||
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
|
||||||
("Qwen3:latest (8B, local)", "qwen3:latest"),
|
("Qwen3:8B (local)", "qwen3:8b"),
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
4
main.py
4
main.py
|
|
@ -8,8 +8,8 @@ load_dotenv()
|
||||||
|
|
||||||
# Create a custom config
|
# Create a custom config
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config["deep_think_llm"] = "gpt-5-mini" # Use a different model
|
config["deep_think_llm"] = "qwen3:8b"
|
||||||
config["quick_think_llm"] = "gpt-5-mini" # Use a different model
|
config["quick_think_llm"] = "qwen3:8b"
|
||||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||||
|
|
||||||
# Configure data vendors (default uses yfinance, no extra API keys needed)
|
# Configure data vendors (default uses yfinance, no extra API keys needed)
|
||||||
|
|
|
||||||
15
test.py
15
test.py
|
|
@ -1,11 +1,12 @@
|
||||||
import time
|
import time
|
||||||
from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
|
from tradingagents.dataflows.y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
|
||||||
|
|
||||||
print("Testing optimized implementation with 30-day lookback:")
|
if __name__ == "__main__":
|
||||||
start_time = time.time()
|
print("Testing optimized implementation with 30-day lookback:")
|
||||||
result = get_stock_stats_indicators_window("AAPL", "macd", "2024-11-01", 30)
|
start_time = time.time()
|
||||||
end_time = time.time()
|
result = get_stock_stats_indicators_window("AAPL", "macd", "2024-11-01", 30)
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
print(f"Execution time: {end_time - start_time:.2f} seconds")
|
print(f"Execution time: {end_time - start_time:.2f} seconds")
|
||||||
print(f"Result length: {len(result)} characters")
|
print(f"Result length: {len(result)} characters")
|
||||||
print(result)
|
print(result)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from tradingagents.dataflows.utils import normalize_date_range, normalize_iso_date
|
||||||
|
|
||||||
|
|
||||||
|
class DateNormalizationTests(unittest.TestCase):
|
||||||
|
def test_normalize_iso_date_keeps_valid_dates(self):
|
||||||
|
self.assertEqual(normalize_iso_date("2024-02-29"), "2024-02-29")
|
||||||
|
|
||||||
|
def test_normalize_iso_date_clamps_invalid_month_end(self):
|
||||||
|
self.assertEqual(normalize_iso_date("2026-02-29"), "2026-02-28")
|
||||||
|
self.assertEqual(normalize_iso_date("2026-04-31"), "2026-04-30")
|
||||||
|
|
||||||
|
def test_normalize_date_range_orders_dates_after_normalization(self):
|
||||||
|
self.assertEqual(
|
||||||
|
normalize_date_range("2026-03-29", "2026-02-29"),
|
||||||
|
("2026-02-28", "2026-03-29"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_normalize_iso_date_rejects_bad_format(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
normalize_iso_date("2026/02/29")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
|
|
||||||
|
|
||||||
|
class GraphLoggingTests(unittest.TestCase):
|
||||||
|
def test_log_state_uses_configured_results_dir(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
graph = TradingAgentsGraph.__new__(TradingAgentsGraph)
|
||||||
|
graph.config = {"results_dir": tmpdir}
|
||||||
|
graph.ticker = "SPY"
|
||||||
|
graph.log_states_dict = {}
|
||||||
|
|
||||||
|
final_state = {
|
||||||
|
"company_of_interest": "SPY",
|
||||||
|
"trade_date": "2026-03-29",
|
||||||
|
"market_report": "market",
|
||||||
|
"sentiment_report": "sentiment",
|
||||||
|
"news_report": "news",
|
||||||
|
"fundamentals_report": "fundamentals",
|
||||||
|
"investment_debate_state": {
|
||||||
|
"bull_history": "bull",
|
||||||
|
"bear_history": "bear",
|
||||||
|
"history": "history",
|
||||||
|
"current_response": "current",
|
||||||
|
"judge_decision": "judge",
|
||||||
|
},
|
||||||
|
"trader_investment_plan": "trade plan",
|
||||||
|
"risk_debate_state": {
|
||||||
|
"aggressive_history": "aggressive",
|
||||||
|
"conservative_history": "conservative",
|
||||||
|
"neutral_history": "neutral",
|
||||||
|
"history": "risk history",
|
||||||
|
"judge_decision": "portfolio judge",
|
||||||
|
},
|
||||||
|
"investment_plan": "investment plan",
|
||||||
|
"final_trade_decision": "HOLD",
|
||||||
|
}
|
||||||
|
|
||||||
|
graph._log_state("2026-03-29", final_state)
|
||||||
|
|
||||||
|
log_path = (
|
||||||
|
Path(tmpdir)
|
||||||
|
/ "TradingAgentsStrategy_logs"
|
||||||
|
/ "full_states_log_2026-03-29.json"
|
||||||
|
)
|
||||||
|
self.assertTrue(log_path.exists())
|
||||||
|
|
||||||
|
payload = json.loads(log_path.read_text())
|
||||||
|
self.assertIn("2026-03-29", payload)
|
||||||
|
self.assertEqual(payload["2026-03-29"]["final_trade_decision"], "HOLD")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
import unittest
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.factory import create_llm_client
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClientFactoryTests(unittest.TestCase):
|
||||||
|
def test_invalid_known_provider_model_emits_warning(self):
|
||||||
|
with warnings.catch_warnings(record=True) as caught:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
client = create_llm_client("openai", "not-a-real-openai-model")
|
||||||
|
|
||||||
|
self.assertEqual(client.provider, "openai")
|
||||||
|
self.assertTrue(
|
||||||
|
any("not-a-real-openai-model" in str(w.message) for w in caught)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_known_valid_model_does_not_emit_warning(self):
|
||||||
|
with warnings.catch_warnings(record=True) as caught:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
client = create_llm_client("openai", "gpt-5-mini")
|
||||||
|
|
||||||
|
self.assertEqual(client.provider, "openai")
|
||||||
|
self.assertEqual(len(caught), 0)
|
||||||
|
|
||||||
|
def test_ollama_custom_models_do_not_emit_warning(self):
|
||||||
|
with warnings.catch_warnings(record=True) as caught:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
client = create_llm_client("ollama", "qwen3:8b")
|
||||||
|
|
||||||
|
self.assertEqual(client.provider, "ollama")
|
||||||
|
self.assertEqual(len(caught), 0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|
@ -5,6 +5,8 @@ import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
|
from .utils import normalize_date_range, normalize_iso_date
|
||||||
|
|
||||||
API_BASE_URL = "https://www.alphavantage.co/query"
|
API_BASE_URL = "https://www.alphavantage.co/query"
|
||||||
|
|
||||||
def get_api_key() -> str:
|
def get_api_key() -> str:
|
||||||
|
|
@ -22,7 +24,7 @@ def format_datetime_for_api(date_input) -> str:
|
||||||
return date_input
|
return date_input
|
||||||
# Try to parse common date formats
|
# Try to parse common date formats
|
||||||
try:
|
try:
|
||||||
dt = datetime.strptime(date_input, "%Y-%m-%d")
|
dt = datetime.strptime(normalize_iso_date(date_input), "%Y-%m-%d")
|
||||||
return dt.strftime("%Y%m%dT0000")
|
return dt.strftime("%Y%m%dT0000")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
try:
|
try:
|
||||||
|
|
@ -100,6 +102,7 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) ->
|
||||||
return csv_data
|
return csv_data
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
start_date, end_date = normalize_date_range(start_date, end_date)
|
||||||
# Parse CSV data
|
# Parse CSV data
|
||||||
df = pd.read_csv(StringIO(csv_data))
|
df = pd.read_csv(StringIO(csv_data))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
|
from .alpha_vantage_common import _make_api_request, format_datetime_for_api
|
||||||
|
from .utils import normalize_date_range, normalize_iso_date
|
||||||
|
|
||||||
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
||||||
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.
|
"""Returns live and historical market news & sentiment data from premier news outlets worldwide.
|
||||||
|
|
@ -14,6 +15,8 @@ def get_news(ticker, start_date, end_date) -> dict[str, str] | str:
|
||||||
Dictionary containing news sentiment data or JSON string.
|
Dictionary containing news sentiment data or JSON string.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
start_date, end_date = normalize_date_range(start_date, end_date)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"tickers": ticker,
|
"tickers": ticker,
|
||||||
"time_from": format_datetime_for_api(start_date),
|
"time_from": format_datetime_for_api(start_date),
|
||||||
|
|
@ -38,6 +41,7 @@ def get_global_news(curr_date, look_back_days: int = 7, limit: int = 50) -> dict
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
# Calculate start date
|
# Calculate start date
|
||||||
|
curr_date = normalize_iso_date(curr_date)
|
||||||
curr_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
curr_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
start_dt = curr_dt - timedelta(days=look_back_days)
|
start_dt = curr_dt - timedelta(days=look_back_days)
|
||||||
start_date = start_dt.strftime("%Y-%m-%d")
|
start_date = start_dt.strftime("%Y-%m-%d")
|
||||||
|
|
@ -68,4 +72,4 @@ def get_insider_transactions(symbol: str) -> dict[str, str] | str:
|
||||||
"symbol": symbol,
|
"symbol": symbol,
|
||||||
}
|
}
|
||||||
|
|
||||||
return _make_api_request("INSIDER_TRANSACTIONS", params)
|
return _make_api_request("INSIDER_TRANSACTIONS", params)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range
|
from .alpha_vantage_common import _make_api_request, _filter_csv_by_date_range
|
||||||
|
from .utils import normalize_date_range
|
||||||
|
|
||||||
def get_stock(
|
def get_stock(
|
||||||
symbol: str,
|
symbol: str,
|
||||||
|
|
@ -18,6 +19,8 @@ def get_stock(
|
||||||
Returns:
|
Returns:
|
||||||
CSV string containing the daily adjusted time series data filtered to the date range.
|
CSV string containing the daily adjusted time series data filtered to the date range.
|
||||||
"""
|
"""
|
||||||
|
start_date, end_date = normalize_date_range(start_date, end_date)
|
||||||
|
|
||||||
# Parse dates to determine the range
|
# Parse dates to determine the range
|
||||||
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||||
today = datetime.now()
|
today = datetime.now()
|
||||||
|
|
@ -35,4 +38,4 @@ def get_stock(
|
||||||
|
|
||||||
response = _make_api_request("TIME_SERIES_DAILY_ADJUSTED", params)
|
response = _make_api_request("TIME_SERIES_DAILY_ADJUSTED", params)
|
||||||
|
|
||||||
return _filter_csv_by_date_range(response, start_date, end_date)
|
return _filter_csv_by_date_range(response, start_date, end_date)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import calendar
|
||||||
from datetime import date, timedelta, datetime
|
from datetime import date, timedelta, datetime
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
|
@ -16,6 +17,50 @@ def get_current_date():
|
||||||
return date.today().strftime("%Y-%m-%d")
|
return date.today().strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_iso_date(date_str: str) -> str:
|
||||||
|
"""Normalize YYYY-MM-DD dates, clamping invalid month-end days.
|
||||||
|
|
||||||
|
LLM tool calls occasionally produce dates like 2026-02-29 when they mean
|
||||||
|
"the end of February". For valid ISO dates this returns the input as-is.
|
||||||
|
For invalid day-of-month values within a valid year/month, it clamps to the
|
||||||
|
last valid day of that month. Other malformed values still raise ValueError.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return datetime.strptime(date_str, "%Y-%m-%d").strftime("%Y-%m-%d")
|
||||||
|
except ValueError as exc:
|
||||||
|
try:
|
||||||
|
year_str, month_str, day_str = date_str.split("-")
|
||||||
|
year = int(year_str)
|
||||||
|
month = int(month_str)
|
||||||
|
day = int(day_str)
|
||||||
|
except (AttributeError, ValueError) as parse_exc:
|
||||||
|
raise ValueError(f"Unsupported date format: {date_str}") from parse_exc
|
||||||
|
|
||||||
|
if not 1 <= month <= 12:
|
||||||
|
raise exc
|
||||||
|
if day < 1:
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
last_day = calendar.monthrange(year, month)[1]
|
||||||
|
if day > last_day:
|
||||||
|
return f"{year:04d}-{month:02d}-{last_day:02d}"
|
||||||
|
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_date_range(start_date: str, end_date: str) -> tuple[str, str]:
|
||||||
|
"""Normalize and order an ISO date range."""
|
||||||
|
normalized_start = normalize_iso_date(start_date)
|
||||||
|
normalized_end = normalize_iso_date(end_date)
|
||||||
|
|
||||||
|
start_dt = datetime.strptime(normalized_start, "%Y-%m-%d")
|
||||||
|
end_dt = datetime.strptime(normalized_end, "%Y-%m-%d")
|
||||||
|
|
||||||
|
if start_dt <= end_dt:
|
||||||
|
return normalized_start, normalized_end
|
||||||
|
return normalized_end, normalized_start
|
||||||
|
|
||||||
|
|
||||||
def decorate_all_methods(decorator):
|
def decorate_all_methods(decorator):
|
||||||
def class_decorator(cls):
|
def class_decorator(cls):
|
||||||
for attr_name, attr_value in cls.__dict__.items():
|
for attr_name, attr_value in cls.__dict__.items():
|
||||||
|
|
|
||||||
|
|
@ -4,15 +4,14 @@ from dateutil.relativedelta import relativedelta
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
import os
|
import os
|
||||||
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry
|
from .stockstats_utils import StockstatsUtils, _clean_dataframe, yf_retry
|
||||||
|
from .utils import normalize_date_range
|
||||||
|
|
||||||
def get_YFin_data_online(
|
def get_YFin_data_online(
|
||||||
symbol: Annotated[str, "ticker symbol of the company"],
|
symbol: Annotated[str, "ticker symbol of the company"],
|
||||||
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
start_date: Annotated[str, "Start date in yyyy-mm-dd format"],
|
||||||
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
end_date: Annotated[str, "End date in yyyy-mm-dd format"],
|
||||||
):
|
):
|
||||||
|
start_date, end_date = normalize_date_range(start_date, end_date)
|
||||||
datetime.strptime(start_date, "%Y-%m-%d")
|
|
||||||
datetime.strptime(end_date, "%Y-%m-%d")
|
|
||||||
|
|
||||||
# Create ticker object
|
# Create ticker object
|
||||||
ticker = yf.Ticker(symbol.upper())
|
ticker = yf.Ticker(symbol.upper())
|
||||||
|
|
@ -461,4 +460,4 @@ def get_insider_transactions(
|
||||||
return header + csv_string
|
return header + csv_string
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error retrieving insider transactions for {ticker}: {str(e)}"
|
return f"Error retrieving insider transactions for {ticker}: {str(e)}"
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ import yfinance as yf
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from dateutil.relativedelta import relativedelta
|
from dateutil.relativedelta import relativedelta
|
||||||
|
|
||||||
|
from .utils import normalize_date_range, normalize_iso_date
|
||||||
|
|
||||||
|
|
||||||
def _extract_article_data(article: dict) -> dict:
|
def _extract_article_data(article: dict) -> dict:
|
||||||
"""Extract article data from yfinance news format (handles nested 'content' structure)."""
|
"""Extract article data from yfinance news format (handles nested 'content' structure)."""
|
||||||
|
|
@ -63,6 +65,7 @@ def get_news_yfinance(
|
||||||
Formatted string containing news articles
|
Formatted string containing news articles
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
start_date, end_date = normalize_date_range(start_date, end_date)
|
||||||
stock = yf.Ticker(ticker)
|
stock = yf.Ticker(ticker)
|
||||||
news = stock.get_news(count=20)
|
news = stock.get_news(count=20)
|
||||||
|
|
||||||
|
|
@ -130,6 +133,7 @@ def get_global_news_yfinance(
|
||||||
seen_titles = set()
|
seen_titles = set()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
curr_date = normalize_iso_date(curr_date)
|
||||||
for query in search_queries:
|
for query in search_queries:
|
||||||
search = yf.Search(
|
search = yf.Search(
|
||||||
query=query,
|
query=query,
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,10 @@ DEFAULT_CONFIG = {
|
||||||
"dataflows/data_cache",
|
"dataflows/data_cache",
|
||||||
),
|
),
|
||||||
# LLM settings
|
# LLM settings
|
||||||
"llm_provider": "openai",
|
"llm_provider": os.getenv("TRADINGAGENTS_LLM_PROVIDER", "ollama"),
|
||||||
"deep_think_llm": "gpt-5.2",
|
"deep_think_llm": os.getenv("TRADINGAGENTS_DEEP_THINK_LLM", "qwen3:8b"),
|
||||||
"quick_think_llm": "gpt-5-mini",
|
"quick_think_llm": os.getenv("TRADINGAGENTS_QUICK_THINK_LLM", "qwen3:8b"),
|
||||||
"backend_url": "https://api.openai.com/v1",
|
"backend_url": os.getenv("TRADINGAGENTS_BACKEND_URL", "http://localhost:11434/v1"),
|
||||||
# Provider-specific thinking configuration
|
# Provider-specific thinking configuration
|
||||||
"google_thinking_level": None, # "high", "minimal", etc.
|
"google_thinking_level": None, # "high", "minimal", etc.
|
||||||
"openai_reasoning_effort": None, # "medium", "high", "low"
|
"openai_reasoning_effort": None, # "medium", "high", "low"
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ class GraphSetup:
|
||||||
self.conditional_logic = conditional_logic
|
self.conditional_logic = conditional_logic
|
||||||
|
|
||||||
def setup_graph(
|
def setup_graph(
|
||||||
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
self, selected_analysts=None
|
||||||
):
|
):
|
||||||
"""Set up and compile the agent workflow graph.
|
"""Set up and compile the agent workflow graph.
|
||||||
|
|
||||||
|
|
@ -49,6 +49,9 @@ class GraphSetup:
|
||||||
- "news": News analyst
|
- "news": News analyst
|
||||||
- "fundamentals": Fundamentals analyst
|
- "fundamentals": Fundamentals analyst
|
||||||
"""
|
"""
|
||||||
|
if selected_analysts is None:
|
||||||
|
selected_analysts = ["market", "social", "news", "fundamentals"]
|
||||||
|
|
||||||
if len(selected_analysts) == 0:
|
if len(selected_analysts) == 0:
|
||||||
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
|
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ class TradingAgentsGraph:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
selected_analysts=None,
|
||||||
debug=False,
|
debug=False,
|
||||||
config: Dict[str, Any] = None,
|
config: Dict[str, Any] = None,
|
||||||
callbacks: Optional[List] = None,
|
callbacks: Optional[List] = None,
|
||||||
|
|
@ -59,6 +59,8 @@ class TradingAgentsGraph:
|
||||||
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
|
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
|
||||||
"""
|
"""
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
|
if selected_analysts is None:
|
||||||
|
selected_analysts = ["market", "social", "news", "fundamentals"]
|
||||||
self.config = config or DEFAULT_CONFIG
|
self.config = config or DEFAULT_CONFIG
|
||||||
self.callbacks = callbacks or []
|
self.callbacks = callbacks or []
|
||||||
|
|
||||||
|
|
@ -66,10 +68,7 @@ class TradingAgentsGraph:
|
||||||
set_config(self.config)
|
set_config(self.config)
|
||||||
|
|
||||||
# Create necessary directories
|
# Create necessary directories
|
||||||
os.makedirs(
|
os.makedirs(self.config["data_cache_dir"], exist_ok=True)
|
||||||
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
|
|
||||||
exist_ok=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize LLMs with provider-specific thinking configuration
|
# Initialize LLMs with provider-specific thinking configuration
|
||||||
llm_kwargs = self._get_provider_kwargs()
|
llm_kwargs = self._get_provider_kwargs()
|
||||||
|
|
@ -259,11 +258,11 @@ class TradingAgentsGraph:
|
||||||
}
|
}
|
||||||
|
|
||||||
# Save to file
|
# Save to file
|
||||||
directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
|
directory = Path(self.config["results_dir"]) / "TradingAgentsStrategy_logs"
|
||||||
directory.mkdir(parents=True, exist_ok=True)
|
directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
with open(
|
with open(
|
||||||
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log_{trade_date}.json",
|
directory / f"full_states_log_{trade_date}.json",
|
||||||
"w",
|
"w",
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
) as f:
|
) as f:
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,12 @@
|
||||||
# LLM Clients - Consistency Improvements
|
# LLM Clients - Consistency Improvements
|
||||||
|
|
||||||
## Issues to Fix
|
## Completed
|
||||||
|
|
||||||
### 1. `validate_model()` is never called
|
### 1. `validate_model()` warning path
|
||||||
- Add validation call in `get_llm()` with warning (not error) for unknown models
|
- `create_llm_client()` now calls `validate_model()` and emits a warning for
|
||||||
|
unknown model names instead of failing immediately.
|
||||||
|
|
||||||
|
## Remaining Issues to Fix
|
||||||
|
|
||||||
### 2. Inconsistent parameter handling
|
### 2. Inconsistent parameter handling
|
||||||
| Client | API Key Param | Special Params |
|
| Client | API Key Param | Special Params |
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import warnings
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .base_client import BaseLLMClient
|
from .base_client import BaseLLMClient
|
||||||
|
|
@ -35,15 +36,22 @@ def create_llm_client(
|
||||||
provider_lower = provider.lower()
|
provider_lower = provider.lower()
|
||||||
|
|
||||||
if provider_lower in ("openai", "ollama", "openrouter"):
|
if provider_lower in ("openai", "ollama", "openrouter"):
|
||||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
client = OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||||
|
elif provider_lower == "xai":
|
||||||
|
client = OpenAIClient(model, base_url, provider="xai", **kwargs)
|
||||||
|
elif provider_lower == "anthropic":
|
||||||
|
client = AnthropicClient(model, base_url, **kwargs)
|
||||||
|
elif provider_lower == "google":
|
||||||
|
client = GoogleClient(model, base_url, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
||||||
if provider_lower == "xai":
|
if not client.validate_model():
|
||||||
return OpenAIClient(model, base_url, provider="xai", **kwargs)
|
warnings.warn(
|
||||||
|
f"Model '{model}' is not in the known model list for provider "
|
||||||
|
f"'{provider_lower}'. The request may still work if the provider "
|
||||||
|
"supports it, but mis-typed model names will fail at runtime.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
if provider_lower == "anthropic":
|
return client
|
||||||
return AnthropicClient(model, base_url, **kwargs)
|
|
||||||
|
|
||||||
if provider_lower == "google":
|
|
||||||
return GoogleClient(model, base_url, **kwargs)
|
|
||||||
|
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
|
||||||
|
|
|
||||||
|
|
@ -57,8 +57,8 @@ class OpenAIClient(BaseLLMClient):
|
||||||
|
|
||||||
# Provider-specific base URL and auth
|
# Provider-specific base URL and auth
|
||||||
if self.provider in _PROVIDER_CONFIG:
|
if self.provider in _PROVIDER_CONFIG:
|
||||||
base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
|
default_base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
|
||||||
llm_kwargs["base_url"] = base_url
|
llm_kwargs["base_url"] = self.base_url or default_base_url
|
||||||
if api_key_env:
|
if api_key_env:
|
||||||
api_key = os.environ.get(api_key_env)
|
api_key = os.environ.get(api_key_env)
|
||||||
if api_key:
|
if api_key:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue