TradingAgents/tradingagents/dataflows/yfin_utils.py

106 lines
3.3 KiB
Python

import logging
import yfinance as yf
from typing import Annotated, Callable, Any, Optional
from pandas import DataFrame
import pandas as pd
from functools import wraps
from .utils import SavePathType, decorate_all_methods
# Module-level logger named after this module; used for the "saved to" messages below.
logger = logging.getLogger(__name__)
def init_ticker(func: Callable) -> Callable:
    """Adapt ``func`` so it is called with a symbol string instead of a ticker.

    The returned callable builds ``yf.Ticker(symbol)`` and hands it to ``func``
    as the first positional argument; every other positional and keyword
    argument is forwarded unchanged, and ``func``'s result is returned as is.

    Args:
        func: Callable whose first positional parameter expects a ticker
            object.

    Returns:
        A wrapper carrying ``func``'s metadata (via ``functools.wraps``).
    """

    def call_with_ticker(
        symbol: Annotated[str, "ticker symbol"], *args, **kwargs
    ) -> Any:
        # Swap the raw symbol for a yfinance handle before delegating.
        return func(yf.Ticker(symbol), *args, **kwargs)

    # Equivalent to decorating the inner function with @wraps(func).
    return wraps(func)(call_with_ticker)
@decorate_all_methods(init_ticker)
class YFinanceUtils:
    """Collection of yfinance lookups addressed by ticker symbol.

    The methods take no ``self``. Each body uses its first argument as a
    ticker object (it reads ``.history``, ``.info``, ``.dividends``, ...),
    even though that parameter is named and annotated as the ``symbol``
    string.

    NOTE(review): this presumably works because ``decorate_all_methods``
    applies ``init_ticker`` to every method, so the symbol string supplied by
    the caller is turned into a ``yf.Ticker`` before the body runs -- confirm
    against ``.utils``.
    """

    def get_stock_data(
        symbol: Annotated[str, "ticker symbol"],
        start_date: Annotated[
            str, "start date for retrieving stock price data, YYYY-mm-dd"
        ],
        end_date: Annotated[
            str, "end date for retrieving stock price data, YYYY-mm-dd"
        ],
        save_path: SavePathType = None,
    ) -> DataFrame:
        """Return the price history between ``start_date`` and ``end_date``.

        Args:
            symbol: Ticker symbol (used as a ticker object inside the body).
            start_date: First day to retrieve, ``YYYY-mm-dd``.
            end_date: Last day to retrieve, ``YYYY-mm-dd``.
            save_path: Optional CSV destination. When truthy, the history is
                written there, matching the other ``save_path`` methods of
                this class.

        Returns:
            Whatever ``ticker.history`` returns for the requested window.
        """
        ticker = symbol
        # Query up to the day after ``end_date`` so the requested last day is
        # part of the result (yfinance documents ``end`` as exclusive).
        day_after_end = pd.to_datetime(end_date) + pd.DateOffset(days=1)
        stock_data = ticker.history(
            start=start_date, end=day_after_end.strftime("%Y-%m-%d")
        )
        # Previously ``save_path`` was accepted but silently ignored.
        if save_path:
            stock_data.to_csv(save_path)
            logger.info("Stock data for %s saved to %s", ticker.ticker, save_path)
        return stock_data

    def get_stock_info(
        symbol: Annotated[str, "ticker symbol"],
    ) -> dict:
        """Return the ticker's ``info`` mapping unchanged."""
        ticker = symbol
        return ticker.info

    def get_company_info(
        symbol: Annotated[str, "ticker symbol"],
        save_path: Optional[str] = None,
    ) -> DataFrame:
        """Return a one-row DataFrame of basic company details.

        The row holds "Company Name", "Industry", "Sector", "Country" and
        "Website", read from the ticker's ``info`` mapping; a missing key
        becomes ``"N/A"``.

        Args:
            symbol: Ticker symbol (used as a ticker object inside the body).
            save_path: Optional CSV destination; written only when truthy.
        """
        ticker = symbol
        info = ticker.info
        # Output column name -> key to look up in ``info``.
        field_keys = {
            "Company Name": "shortName",
            "Industry": "industry",
            "Sector": "sector",
            "Country": "country",
            "Website": "website",
        }
        company_info = {
            column: info.get(key, "N/A") for column, key in field_keys.items()
        }
        company_info_df = DataFrame([company_info])
        if save_path:
            company_info_df.to_csv(save_path)
            logger.info("Company info for %s saved to %s", ticker.ticker, save_path)
        return company_info_df

    def get_stock_dividends(
        symbol: Annotated[str, "ticker symbol"],
        save_path: Optional[str] = None,
    ) -> DataFrame:
        """Return the ticker's ``dividends`` data, optionally saving it as CSV.

        NOTE(review): yfinance appears to expose ``dividends`` as a Series
        rather than a DataFrame; the annotation is left unchanged -- confirm.

        Args:
            symbol: Ticker symbol (used as a ticker object inside the body).
            save_path: Optional CSV destination; written only when truthy.
        """
        ticker = symbol
        dividends = ticker.dividends
        if save_path:
            dividends.to_csv(save_path)
            logger.info("Dividends for %s saved to %s", ticker.ticker, save_path)
        return dividends

    def get_income_stmt(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
        """Return the ticker's ``financials`` attribute (income statement)."""
        ticker = symbol
        return ticker.financials

    def get_balance_sheet(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
        """Return the ticker's ``balance_sheet`` attribute."""
        ticker = symbol
        return ticker.balance_sheet

    def get_cash_flow(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
        """Return the ticker's ``cashflow`` attribute."""
        ticker = symbol
        return ticker.cashflow

    def get_analyst_recommendations(symbol: Annotated[str, "ticker symbol"]) -> tuple:
        """Return the top-voted recommendation in the first row and its count.

        Only the first row of ``ticker.recommendations`` is considered, and
        its first column is skipped. On a tie, the first label in column
        order wins.

        Returns:
            ``(label, votes)``, or ``(None, 0)`` when there is nothing to
            vote on: no recommendations object, an empty table, no vote
            columns, or a row with no comparable maximum (e.g. all NaN).
        """
        ticker = symbol
        recommendations = ticker.recommendations
        # Guard ``None`` as well as an empty table so ``.empty`` cannot raise.
        if recommendations is None or recommendations.empty:
            return None, 0

        # NOTE(review): the first column is skipped on the assumption that it
        # is a non-vote label such as the period -- confirm against yfinance.
        votes = recommendations.iloc[0, 1:]
        if votes.empty:
            return None, 0

        max_votes = votes.max()
        leaders = votes[votes == max_votes].index.tolist()
        if not leaders:
            # A NaN maximum equals nothing, so no label can be selected.
            return None, 0
        return leaders[0], max_votes