From 99d5d2122dfba3cdf99a2fdf046bdef8f0f547ba Mon Sep 17 00:00:00 2001 From: Hewei603 Date: Sat, 14 Feb 2026 16:19:07 +0800 Subject: [PATCH] Fix yfinance rate limit and session compatibility in dataflows/y_finance.py --- cli/main.py | 8 +-- tradingagents/dataflows/y_finance.py | 97 +++++++++++++++++++++------- 2 files changed, 79 insertions(+), 26 deletions(-) diff --git a/cli/main.py b/cli/main.py index fb97d189..2237032e 100644 --- a/cli/main.py +++ b/cli/main.py @@ -462,7 +462,7 @@ def update_display(layout, spinner_text=None, stats_handler=None, start_time=Non def get_user_selections(): """Get all user selections before starting the analysis display.""" # Display ASCII art welcome message - with open("./cli/static/welcome.txt", "r") as f: + with open("./cli/static/welcome.txt", "r", encoding='utf-8') as f: welcome_ascii = f.read() # Create welcome box content @@ -948,7 +948,7 @@ def run_analysis(): func(*args, **kwargs) timestamp, message_type, content = obj.messages[-1] content = content.replace("\n", " ") # Replace newlines with spaces - with open(log_file, "a") as f: + with open(log_file, "a", encoding='utf-8') as f: f.write(f"{timestamp} [{message_type}] {content}\n") return wrapper @@ -959,7 +959,7 @@ def run_analysis(): func(*args, **kwargs) timestamp, tool_name, args = obj.tool_calls[-1] args_str = ", ".join(f"{k}={v}" for k, v in args.items()) - with open(log_file, "a") as f: + with open(log_file, "a", encoding='utf-8') as f: f.write(f"{timestamp} [Tool Call] {tool_name}({args_str})\n") return wrapper @@ -972,7 +972,7 @@ def run_analysis(): content = obj.report_sections[section_name] if content: file_name = f"{section_name}.md" - with open(report_dir / file_name, "w") as f: + with open(report_dir / file_name, "w", encoding='utf-8') as f: f.write(content) return wrapper diff --git a/tradingagents/dataflows/y_finance.py b/tradingagents/dataflows/y_finance.py index bc78d8b3..53b9067b 100644 --- a/tradingagents/dataflows/y_finance.py +++ b/tradingagents/dataflows/y_finance.py @@ -1,51 +1,104 @@ from typing import Annotated from datetime import datetime from dateutil.relativedelta import relativedelta +import time +from yfinance.exceptions import YFRateLimitError +import requests import yfinance as yf import os from .stockstats_utils import StockstatsUtils -def get_YFin_data_online( +import pandas as pd # for polygon data processing +from polygon import RESTClient # for polygon client (need to pip install polygon-api-client) + +# Polygon API Key: free registration https://polygon.io/ (free layer 5min/day enough) +POLYGON_API_KEY = "your key" # replace with your key + +def get_YFin_data_online( # rename? or keep, add fallback inside symbol: Annotated[str, "ticker symbol of the company"], start_date: Annotated[str, "Start date in yyyy-mm-dd format"], end_date: Annotated[str, "End date in yyyy-mm-dd format"], + max_retries: int = 5, + base_delay: float = 2.0, ): + """ + original yfinance + Polygon fallback: yfinance rate limit automatically switch to Polygon. + """ + # date validation + try: + datetime.strptime(start_date, "%Y-%m-%d") + datetime.strptime(end_date, "%Y-%m-%d") + except ValueError: + raise ValueError("date format must be YYYY-MM-DD") - datetime.strptime(start_date, "%Y-%m-%d") - datetime.strptime(end_date, "%Y-%m-%d") + time.sleep(1.0) # prevent concurrency - # Create ticker object - ticker = yf.Ticker(symbol.upper()) + # first try yfinance + for attempt in range(max_retries): + try: + ticker = yf.Ticker(symbol.upper()) + data = ticker.history(start=start_date, end=end_date) + if not data.empty: + return _format_data_to_csv(data, symbol, start_date, end_date) # extract formatting function + else: + print("yfinance returned empty data, trying Polygon...") + break # 空数据也fallback + except YFRateLimitError: + if attempt == max_retries - 1: + print("yfinance all failed, switching to Polygon...") + break + wait_time = base_delay * (2 ** attempt) + print(f"⚠️ Yahoo rate limit, waiting {wait_time:.0f} seconds... ({attempt+1}/{max_retries})") + time.sleep(wait_time) + except Exception as e: + print(f"yfinance other error: {e}") + if attempt == max_retries - 1: + break + time.sleep(base_delay) - # Fetch historical data for the specified date range - data = ticker.history(start=start_date, end=end_date) - - # Check if data is empty - if data.empty: - return ( - f"No data found for symbol '{symbol}' between {start_date} and {end_date}" + # Fallback: use Polygon to pull data (stable, free layer enough) + try: + client = RESTClient(api_key=POLYGON_API_KEY) + aggs = client.get_aggs( + ticker=symbol.upper(), + multiplier=1, # 1天 + timespan="day", + from_=start_date, + to=end_date, ) + if aggs: + df = pd.DataFrame([{ + 'Date': pd.to_datetime(agg.timestamp, unit='ms').date(), + 'Open': agg.open, + 'High': agg.high, + 'Low': agg.low, + 'Close': agg.close, + 'Volume': agg.volume, + } for agg in aggs]) + df.set_index('Date', inplace=True) + df.sort_index(inplace=True) + return _format_data_to_csv(df, symbol, start_date, end_date) + else: + return f"No data from Polygon for '{symbol}' between {start_date} and {end_date}" + except Exception as e: + return f"Polygon fallback failed: {str(e)}. Please check API Key." - # Remove timezone info from index for cleaner output +def _format_data_to_csv(data, symbol, start_date, end_date): + """统一格式化输出""" + """standardize data output""" if data.index.tz is not None: data.index = data.index.tz_localize(None) - - # Round numerical values to 2 decimal places for cleaner display - numeric_columns = ["Open", "High", "Low", "Close", "Adj Close"] + numeric_columns = ["Open", "High", "Low", "Close"] # Adj Close Polygon has no, simplify for col in numeric_columns: if col in data.columns: data[col] = data[col].round(2) - - # Convert DataFrame to CSV string + if "Adj Close" in data.columns: + data["Adj Close"] = data["Adj Close"].round(2) csv_string = data.to_csv() - - # Add header information header = f"# Stock data for {symbol.upper()} from {start_date} to {end_date}\n" header += f"# Total records: {len(data)}\n" header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" - return header + csv_string - def get_stock_stats_indicators_window( symbol: Annotated[str, "ticker symbol of the company"], indicator: Annotated[str, "technical indicator to get the analysis and report of"],