This commit is contained in:
MarkLo 2025-12-05 23:34:50 +08:00
parent 727fd682b2
commit 5dee4b323c
13 changed files with 644 additions and 572 deletions

View File

@ -1,7 +1,7 @@
"""
Price data service for loading and processing stock price data
"""
import pandas as pd
import polars as pl
from pathlib import Path
from typing import List, Dict, Any, Optional
import logging
@ -13,7 +13,7 @@ class PriceService:
"""Service for loading and processing price data from data_cache"""
@staticmethod
def load_price_data(ticker: str, data_cache_dir: str) -> Optional[pd.DataFrame]:
def load_price_data(ticker: str, data_cache_dir: str) -> Optional[pl.DataFrame]:
"""
Load price data from data_cache CSV files
@ -54,10 +54,10 @@ class PriceService:
logger.info(f"Loading price data from {latest_file}")
df = pd.read_csv(latest_file)
df['Date'] = pd.to_datetime(df['Date'])
df = pl.read_csv(str(latest_file))
df = df.with_columns(pl.col("Date").str.to_datetime())
return df.sort_values('Date')
return df.sort("Date")
except Exception as e:
logger.error(f"Error loading price data for {ticker}: {e}")
@ -82,7 +82,7 @@ class PriceService:
return cache_age_hours < max_age_hours
@staticmethod
def _fetch_and_cache_data(ticker: str, data_cache_dir: str, max_retries: int = 3) -> Optional[pd.DataFrame]:
def _fetch_and_cache_data(ticker: str, data_cache_dir: str, max_retries: int = 3) -> Optional[pl.DataFrame]:
"""
Fetch data from yfinance and cache it
@ -130,10 +130,10 @@ class PriceService:
logger.info(f"成功獲取並緩存 {ticker} 數據到 {cache_file}")
# Prepare and return DataFrame
df = pd.read_csv(cache_file)
df['Date'] = pd.to_datetime(df['Date'])
return df.sort_values('Date')
# Prepare and return DataFrame - convert to polars
df = pl.read_csv(str(cache_file))
df = df.with_columns(pl.col("Date").str.to_datetime())
return df.sort("Date")
except Exception as e:
logger.warning(f"{attempt} 次嘗試失敗: {e}")
@ -149,7 +149,7 @@ class PriceService:
@staticmethod
def calculate_stats(df: pd.DataFrame) -> Dict[str, Any]:
def calculate_stats(df: pl.DataFrame) -> Dict[str, Any]:
"""
Calculate price statistics
@ -159,22 +159,22 @@ class PriceService:
Returns:
Dictionary with statistics
"""
start_price = float(df.iloc[0]['Close'])
end_price = float(df.iloc[-1]['Close'])
start_price = float(df.row(0, named=True)["Close"])
end_price = float(df.row(-1, named=True)["Close"])
growth_rate = ((end_price - start_price) / start_price) * 100
duration_days = (df.iloc[-1]['Date'] - df.iloc[0]['Date']).days
duration_days = (df.row(-1, named=True)["Date"] - df.row(0, named=True)["Date"]).days
return {
"growth_rate": round(growth_rate, 2),
"duration_days": int(duration_days),
"start_date": df.iloc[0]['Date'].strftime('%Y-%m-%d'),
"end_date": df.iloc[-1]['Date'].strftime('%Y-%m-%d'),
"start_date": df.row(0, named=True)["Date"].strftime('%Y-%m-%d'),
"end_date": df.row(-1, named=True)["Date"].strftime('%Y-%m-%d'),
"start_price": round(start_price, 2),
"end_price": round(end_price, 2),
}
@staticmethod
def prepare_chart_data(df: pd.DataFrame, limit: int = 365) -> List[Dict[str, Any]]:
def prepare_chart_data(df: pl.DataFrame, limit: int = 365) -> List[Dict[str, Any]]:
"""
Prepare price data for charting (limit to recent data)
@ -188,9 +188,9 @@ class PriceService:
# Get recent data
recent_df = df.tail(limit)
# Convert to list of dicts
# Convert to list of dicts using polars to_dicts()
data = []
for _, row in recent_df.iterrows():
for row in recent_df.iter_rows(named=True):
data.append({
"Date": row['Date'].strftime('%Y-%m-%d'),
"Open": round(float(row['Open']), 2),

View File

@ -16,7 +16,8 @@ python-dotenv==1.0.0
typing-extensions
langchain-openai
langchain-experimental
pandas
polars
pyarrow
yfinance
praw
feedparser

View File

@ -23,7 +23,7 @@
"clsx": "^2.1.1",
"date-fns": "^4.1.0",
"lucide-react": "^0.554.0",
"next": "16.0.3",
"next": "16.0.7",
"next-themes": "^0.4.6",
"react": "19.2.0",
"react-day-picker": "^9.11.3",
@ -42,6 +42,7 @@
"@types/react": "^19",
"@types/react-dom": "^19",
"babel-plugin-react-compiler": "1.0.0",
"baseline-browser-mapping": "^2.9.2",
"eslint": "^9",
"eslint-config-next": "16.0.3",
"tailwindcss": "^4",

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,3 @@
ignoredBuiltDependencies:
- sharp
- unrs-resolver

View File

@ -1,7 +1,8 @@
typing-extensions
langchain-openai
langchain-experimental
pandas
polars
pyarrow
yfinance
praw
feedparser

View File

@ -5,16 +5,31 @@ from tradingagents.dataflows.interface import route_to_vendor
@tool
def get_indicators(
symbol: Annotated[str, "公司的股票代碼"],
indicator: Annotated[str, "要獲取分析和報告的技術指標"],
indicator: Annotated[str, """技術指標名稱。
常用指標MA (移動平均), RSI, MACD
使用簡寫名稱如'MA', 'RSI', 'MACD'
或指定期間如'close_50_sma', 'close_200_sma'
範例'MA' 'RSI' 'MACD'"""],
curr_date: Annotated[str, "您正在交易的當前交易日期,格式為 YYYY-mm-dd"],
look_back_days: Annotated[int, "回溯天數"] = 30,
) -> str:
"""
檢索給定股票代碼的技術指標
使用設定的技術指標供應商
支持的指標
- MA/SMA: 簡單移動平均線使用 look_back_days 或指定 'close_50_sma', 'close_200_sma'
- EMA: 指數移動平均線
- RSI: 相對強弱指數
- MACD: 移動平均收斂背離
- BOLL: 布林通道
- ATR: 平均真實波幅
- VWMA: 成交量加權移動平均
- MFI: 資金流量指數
Args:
symbol (str): 公司的股票代碼例如 AAPL, TSM
indicator (str): 要獲取分析和報告的技術指標
indicator (str): 技術指標名稱使用簡寫MA, RSI, MACD
curr_date (str): 您正在交易的當前交易日期格式為 YYYY-mm-dd
look_back_days (int): 回溯天數預設為 30
Returns:
@ -23,34 +38,56 @@ def get_indicators(
# 規範化指標名稱以匹配供應商的預期格式
indicator_lower = indicator.lower().strip()
# 處理常見的變體
if "50" in indicator_lower and ("ma" in indicator_lower or "avg" in indicator_lower):
# 處理常見的變體 - 包含 "moving average" 的完整詞
if "50" in indicator_lower and ("ma" in indicator_lower or "avg" in indicator_lower or "moving" in indicator_lower):
normalized_indicator = "close_50_sma"
elif "200" in indicator_lower and ("ma" in indicator_lower or "avg" in indicator_lower):
elif "200" in indicator_lower and ("ma" in indicator_lower or "avg" in indicator_lower or "moving" in indicator_lower):
normalized_indicator = "close_200_sma"
elif "10" in indicator_lower and "ema" in indicator_lower:
normalized_indicator = "close_10_ema"
# 處理通用指標名稱,使用 look_back_days
elif indicator_lower in ["sma", "ma"]:
elif indicator_lower in ["sma", "ma", "moving average", "simple moving average"]:
normalized_indicator = f"close_{look_back_days}_sma"
elif indicator_lower == "ema":
elif indicator_lower in ["ema", "exponential moving average"]:
normalized_indicator = f"close_{look_back_days}_ema"
else:
# 常見指標名稱映射
# 常見指標名稱映射 - 擴充版
mapping = {
# SMA 變體
"sma50": "close_50_sma",
"sma200": "close_200_sma",
"ema10": "close_10_ema",
"bbands": "boll",
"bollinger": "boll",
"bollinger bands": "boll",
"macd_signal": "macds",
"macd_hist": "macdh",
"50-day ma": "close_50_sma",
"200-day ma": "close_200_sma",
"50 day ma": "close_50_sma",
"200 day ma": "close_200_sma",
"50-day moving average": "close_50_sma",
"200-day moving average": "close_200_sma",
"50 day moving average": "close_50_sma",
"200 day moving average": "close_200_sma",
"50-day simple moving average": "close_50_sma",
"200-day simple moving average": "close_200_sma",
# EMA 變體
"ema10": "close_10_ema",
"10-day ema": "close_10_ema",
"10 day ema": "close_10_ema",
# Bollinger Bands
"bbands": "boll",
"bollinger": "boll",
"bollinger bands": "boll",
"bb": "boll",
# MACD 變體
"macd_signal": "macds",
"macd signal": "macds",
"macd_hist": "macdh",
"macd histogram": "macdh",
# 其他常見別名
"relative strength index": "rsi",
"average true range": "atr",
"money flow index": "mfi",
}
# 如果在映射中,使用映射名稱

View File

@ -1,6 +1,6 @@
import os
import requests
import pandas as pd
import polars as pl
import json
from datetime import datetime
from io import StringIO
@ -102,20 +102,22 @@ def _filter_csv_by_date_range(csv_data: str, start_date: str, end_date: str) ->
try:
# 解析 CSV 數據
df = pd.read_csv(StringIO(csv_data))
df = pl.read_csv(StringIO(csv_data))
# 假設第一欄是日期欄 (時間戳)
date_col = df.columns[0]
df[date_col] = pd.to_datetime(df[date_col])
df = df.with_columns(pl.col(date_col).str.to_datetime())
# 按日期範圍過濾
start_dt = pd.to_datetime(start_date)
end_dt = pd.to_datetime(end_date)
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
filtered_df = df[(df[date_col] >= start_dt) & (df[date_col] <= end_dt)]
filtered_df = df.filter(
(pl.col(date_col) >= start_dt) & (pl.col(date_col) <= end_dt)
)
# 轉換回 CSV 字串
return filtered_df.to_csv(index=False)
return filtered_df.write_csv()
except Exception as e:
# 如果過濾失敗,返回原始數據並附帶警告

View File

@ -1,5 +1,5 @@
from typing import Annotated
import pandas as pd
import polars as pl
import os
from .config import DATA_DIR
from datetime import datetime
@ -30,29 +30,28 @@ def get_YFin_data_window(
start_date = before.strftime("%Y-%m-%d")
# 讀取數據
data = pd.read_csv(
data = pl.read_csv(
os.path.join(
DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
# 僅提取日期部分進行比較
data["DateOnly"] = data["Date"].str[:10]
# 節取日期部分
data = data.with_columns(
pl.col("Date").str.slice(0, 10).alias("DateOnly")
)
# 過濾指定日期範圍內的數據 (包含起訖日期)
filtered_data = data[
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= curr_date)
]
filtered_data = data.filter(
(pl.col("DateOnly") >= start_date) & (pl.col("DateOnly") <= curr_date)
)
# 刪除我們創建的臨時欄位
filtered_data = filtered_data.drop("DateOnly", axis=1)
filtered_data = filtered_data.drop("DateOnly")
# 設定 pandas 顯示選項以顯示完整的 DataFrame
with pd.option_context(
"display.max_rows", None, "display.max_columns", None, "display.width", None
):
df_string = filtered_data.to_string()
# polars 的字串輸出
df_string = str(filtered_data)
return (
f"## {symbol}{start_date}{curr_date} 的原始市場數據:\n\n"
@ -76,7 +75,7 @@ def get_YFin_data(
pd.DataFrame: 包含過濾後數據的 DataFrame
"""
# 讀取數據
data = pd.read_csv(
data = pl.read_csv(
os.path.join(
DATA_DIR,
f"market_data/price_data/{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
@ -88,19 +87,21 @@ def get_YFin_data(
f"Get_YFin_Data{end_date} 超出 2015-01-01 到 2025-03-25 的數據範圍"
)
# 僅提取日期部分進行比較
data["DateOnly"] = data["Date"].str[:10]
# 節取日期部分
data = data.with_columns(
pl.col("Date").str.slice(0, 10).alias("DateOnly")
)
# 過濾指定日期範圍內的數據 (包含起訖日期)
filtered_data = data[
(data["DateOnly"] >= start_date) & (data["DateOnly"] <= end_date)
]
# 過濾指定日期範內的數據 (包含起訖日期)
filtered_data = data.filter(
(pl.col("DateOnly") >= start_date) & (pl.col("DateOnly") <= end_date)
)
# 刪除我們創建的臨時欄位
filtered_data = filtered_data.drop("DateOnly", axis=1)
filtered_data = filtered_data.drop("DateOnly")
# 從數據框中移除索引
filtered_data = filtered_data.reset_index(drop=True)
# 重置索引 (在 polars 中不需要,但保持一致性)
# filtered_data = filtered_data.with_row_count(name="index", offset=0)
return filtered_data
@ -280,28 +281,34 @@ def get_simfin_balance_sheet(
"us",
f"us-balance-{freq}.csv",
)
df = pd.read_csv(data_path, sep=";")
df = pl.read_csv(data_path, separator=";")
# 將日期字串轉換為日期時間物件並移除任何時間部分
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
df = df.with_columns(
pl.col("Report Date").str.to_datetime().dt.date().alias("Report Date"),
pl.col("Publish Date").str.to_datetime().dt.date().alias("Publish Date")
)
# 將當前日期轉換為日期時間並標準化
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
from datetime import datetime as dt
curr_date_dt = dt.strptime(curr_date, "%Y-%m-%d").date()
# 過濾 DataFrame篩選出給定股票代碼且報告發布日期在當前日期或之前的報告
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
filtered_df = df.filter(
(pl.col("Ticker") == ticker) & (pl.col("Publish Date") <= curr_date_dt)
)
# 檢查是否有可用的報告;如果沒有,則返回通知
if filtered_df.empty:
if filtered_df.is_empty():
print("在給定的當前日期之前沒有可用的資產負債表。")
return ""
# 通過選擇具有最新發布日期的行來獲取最新的資產負債表
latest_balance_sheet = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
max_date_idx = filtered_df.select(pl.col("Publish Date")).arg_max()
latest_balance_sheet = filtered_df.row(max_date_idx, named=True)
# 刪除 SimFinID 欄位
latest_balance_sheet = latest_balance_sheet.drop("SimFinId")
# latest_balance_sheet = latest_balance_sheet.drop("SimFinId")
return (
f"## {ticker}{str(latest_balance_sheet['Publish Date'])[0:10]} 發布的 {freq} 資產負債表:\n"
@ -327,25 +334,31 @@ def get_simfin_cashflow(
"us",
f"us-cashflow-{freq}.csv",
)
df = pd.read_csv(data_path, sep=";")
df = pl.read_csv(data_path, separator=";")
# 將日期字串轉換為日期時間物件並移除任何時間部分
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
df = df.with_columns(
pl.col("Report Date").str.to_datetime().dt.date().alias("Report Date"),
pl.col("Publish Date").str.to_datetime().dt.date().alias("Publish Date")
)
# 將當前日期轉換為日期時間並標準化
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
from datetime import datetime as dt
curr_date_dt = dt.strptime(curr_date, "%Y-%m-%d").date()
# 過濾 DataFrame篩選出給定股票代碼且報告發布日期在當前日期或之前的報告
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
filtered_df = df.filter(
(pl.col("Ticker") == ticker) & (pl.col("Publish Date") <= curr_date_dt)
)
# 檢查是否有可用的報告;如果沒有,則返回通知
if filtered_df.empty:
if filtered_df.is_empty():
print("在給定的當前日期之前沒有可用的現金流量表。")
return ""
# 通過選擇具有最新發布日期的行來獲取最新的現金流量表
latest_cash_flow = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
max_date_idx = filtered_df.select(pl.col("Publish Date")).arg_max()
latest_cash_flow = filtered_df.row(max_date_idx, named=True)
# 刪除 SimFinID 欄位
latest_cash_flow = latest_cash_flow.drop("SimFinId")
@ -374,25 +387,31 @@ def get_simfin_income_statements(
"us",
f"us-income-{freq}.csv",
)
df = pd.read_csv(data_path, sep=";")
df = pl.read_csv(data_path, separator=";")
# 將日期字串轉換為日期時間物件並移除任何時間部分
df["Report Date"] = pd.to_datetime(df["Report Date"], utc=True).dt.normalize()
df["Publish Date"] = pd.to_datetime(df["Publish Date"], utc=True).dt.normalize()
df = df.with_columns(
pl.col("Report Date").str.to_datetime().dt.date().alias("Report Date"),
pl.col("Publish Date").str.to_datetime().dt.date().alias("Publish Date")
)
# 將當前日期轉換為日期時間並標準化
curr_date_dt = pd.to_datetime(curr_date, utc=True).normalize()
from datetime import datetime as dt
curr_date_dt = dt.strptime(curr_date, "%Y-%m-%d").date()
# 過濾 DataFrame篩選出給定股票代碼且報告發布日期在當前日期或之前的報告
filtered_df = df[(df["Ticker"] == ticker) & (df["Publish Date"] <= curr_date_dt)]
filtered_df = df.filter(
(pl.col("Ticker") == ticker) & (pl.col("Publish Date") <= curr_date_dt)
)
# 檢查是否有可用的報告;如果沒有,則返回通知
if filtered_df.empty:
if filtered_df.is_empty():
print("在給定的當前日期之前沒有可用的損益表。")
return ""
# 通過選擇具有最新發布日期的行來獲取最新的損益表
latest_income = filtered_df.loc[filtered_df["Publish Date"].idxmax()]
max_date_idx = filtered_df.select(pl.col("Publish Date")).arg_max()
latest_income = filtered_df.row(max_date_idx, named=True)
# 刪除 SimFinID 欄位
latest_income = latest_income.drop("SimFinId")

View File

@ -1,4 +1,4 @@
import pandas as pd
import polars as pl
import yfinance as yf
from stockstats import wrap
from typing import Annotated
@ -9,6 +9,7 @@ from .config import get_config, DATA_DIR
class StockstatsUtils:
"""
一個提供股票統計功能的工具類別
注意: stockstats 函式庫需要 pandas DataFrame所以需要進行 pandas/polars 轉換
"""
@staticmethod
def get_stock_stats(
@ -31,6 +32,8 @@ class StockstatsUtils:
Returns:
float or str: 指標值或錯誤訊息
"""
from datetime import datetime, timedelta
# 獲取設定並設定數據目錄路徑
config = get_config()
online = config["data_vendors"]["technical_indicators"] != "local"
@ -40,51 +43,56 @@ class StockstatsUtils:
if not online:
try:
data = pd.read_csv(
data = pl.read_csv(
os.path.join(
DATA_DIR,
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
df = wrap(data)
# stockstats 需要 pandas DataFrame
data_pd = data.to_pandas()
df = wrap(data_pd)
except FileNotFoundError:
raise Exception("Stockstats 失敗:尚未獲取 Yahoo Finance 數據!")
else:
# 獲取今天的日期 (YYYY-mm-dd) 以添加到快取
today_date = pd.Timestamp.today()
curr_date = pd.to_datetime(curr_date)
today_date = datetime.now()
curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date = start_date.strftime("%Y-%m-%d")
end_date = end_date.strftime("%Y-%m-%d")
start_date = today_date - timedelta(days=365*15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
# 獲取設定並確保快取目錄存在
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_date}-{end_date}.csv",
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file)
data["Date"] = pd.to_datetime(data["Date"])
data = pl.read_csv(data_file)
data = data.with_columns(pl.col("Date").str.to_datetime())
else:
data = yf.download(
data_yf = yf.download(
symbol,
start=start_date,
end=end_date,
start=start_date_str,
end=end_date_str,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
data = data.reset_index()
data.to_csv(data_file, index=False)
data_yf = data_yf.reset_index()
data_yf.to_csv(data_file, index=False)
data = pl.from_pandas(data_yf)
df = wrap(data)
# stockstats 需要 pandas DataFrame
data_pd = data.to_pandas()
df = wrap(data_pd)
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
curr_date = curr_date.strftime("%Y-%m-%d")
curr_date = curr_date_dt.strftime("%Y-%m-%d")
df[indicator] # 觸發 stockstats 計算指標
matching_rows = df[df["Date"].str.startswith(curr_date)]

View File

@ -1,22 +1,22 @@
import os
import json
import pandas as pd
import polars as pl
from datetime import date, timedelta, datetime
from typing import Annotated
SavePathType = Annotated[str, "儲存資料的檔案路徑。如果為 None則不儲存資料。"]
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
def save_output(data: pl.DataFrame, tag: str, save_path: SavePathType = None) -> None:
"""
DataFrame 儲存到 CSV 檔案
Args:
data (pd.DataFrame): 要儲存的 DataFrame
data (pl.DataFrame): 要儲存的 DataFrame
tag (str): 用於在控制台中打印的標籤
save_path (SavePathType, optional): 儲存檔案的路徑預設為 None
"""
if save_path:
data.to_csv(save_path)
data.write_csv(save_path)
print(f"{tag} 已儲存至 {save_path}")

View File

@ -229,6 +229,7 @@ def _get_stock_stats_bulk(
返回將日期字串映射到指標值的字典
"""
from .config import get_config
import polars as pl
import pandas as pd
from stockstats import wrap
import os
@ -239,22 +240,25 @@ def _get_stock_stats_bulk(
if not online:
# 本地數據路徑
try:
data = pd.read_csv(
data = pl.read_csv(
os.path.join(
config.get("data_cache_dir", "data"),
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
df = wrap(data)
# stockstats 需要 pandas DataFrame
data_pd = data.to_pandas()
df = wrap(data_pd)
except FileNotFoundError:
raise Exception("Stockstats 失敗:尚未獲取 Yahoo Finance 數據!")
else:
# 帶有快取的線上數據獲取
today_date = pd.Timestamp.today()
curr_date_dt = pd.to_datetime(curr_date)
from datetime import datetime as dt, timedelta
today_date = dt.now()
curr_date_dt = dt.strptime(curr_date, "%Y-%m-%d")
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date = today_date - timedelta(days=365*15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
@ -279,8 +283,10 @@ def _get_stock_stats_bulk(
logger.info(f"{symbol} 緩存過期(年齡:{cache_age_hours:.1f} 小時),將重新下載")
if cache_valid:
data = pd.read_csv(data_file)
data["Date"] = pd.to_datetime(data["Date"])
data_pl = pl.read_csv(data_file)
data_pl = data_pl.with_columns(pl.col("Date").str.to_datetime())
# stockstats 需要 pandas DataFrame
data = data_pl.to_pandas()
else:
# 使用重試機制下載數據
@retry(max_attempts=3, backoff=2.0)
@ -305,8 +311,9 @@ def _get_stock_stats_bulk(
# 如果下載失敗但有舊緩存,使用舊緩存
if os.path.exists(data_file):
logger.warning(f"使用過期緩存作為備援")
data = pd.read_csv(data_file)
data["Date"] = pd.to_datetime(data["Date"])
data_pl = pl.read_csv(data_file)
data_pl = data_pl.with_columns(pl.col("Date").str.to_datetime())
data = data_pl.to_pandas()
else:
raise

View File

@ -2,8 +2,7 @@
import yfinance as yf
from typing import Annotated, Callable, Any, Optional
from pandas import DataFrame
import pandas as pd
import polars as pl
from functools import wraps
from .utils import save_output, SavePathType, decorate_all_methods
@ -35,15 +34,18 @@ class YFinanceUtils:
str, "檢索股價數據的結束日期,格式為 YYYY-mm-dd"
],
save_path: SavePathType = None,
) -> DataFrame:
) -> pl.DataFrame:
"""檢索指定股票代碼的股價數據"""
from datetime import datetime, timedelta
ticker = symbol
# 將結束日期加一天,使數據範圍包含結束日期
end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1)
end_date = end_date.strftime("%Y-%m-%d")
end_date_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
end_date = end_date_dt.strftime("%Y-%m-%d")
stock_data = ticker.history(start=start_date, end=end_date)
# save_output(stock_data, f"{ticker.ticker} 的股票數據", save_path)
return stock_data
# 轉換為 polars DataFrame
stock_data_pl = pl.from_pandas(stock_data.reset_index())
# save_output(stock_data_pl, f"{ticker.ticker} 的股票數據", save_path)
return stock_data_pl
def get_stock_info(
symbol: Annotated[str, "股票代碼"],
@ -56,7 +58,7 @@ class YFinanceUtils:
def get_company_info(
symbol: Annotated[str, "股票代碼"],
save_path: Optional[str] = None,
) -> DataFrame:
) -> pl.DataFrame:
"""獲取並以 DataFrame 形式返回公司資訊。"""
ticker = symbol
info = ticker.info
@ -67,41 +69,42 @@ class YFinanceUtils:
"國家": info.get("country", "N/A"),
"網站": info.get("website", "N/A"),
}
company_info_df = DataFrame([company_info])
company_info_df = pl.DataFrame([company_info])
if save_path:
company_info_df.to_csv(save_path)
company_info_df.write_csv(save_path)
print(f"{ticker.ticker} 的公司資訊已儲存至 {save_path}")
return company_info_df
def get_stock_dividends(
symbol: Annotated[str, "股票代碼"],
save_path: Optional[str] = None,
) -> DataFrame:
) -> pl.DataFrame:
"""獲取並以 DataFrame 形式返回最新的股息數據。"""
ticker = symbol
dividends = ticker.dividends
dividends_pl = pl.from_pandas(dividends.reset_index())
if save_path:
dividends.to_csv(save_path)
dividends_pl.write_csv(save_path)
print(f"{ticker.ticker} 的股息已儲存至 {save_path}")
return dividends
return dividends_pl
def get_income_stmt(symbol: Annotated[str, "股票代碼"]) -> DataFrame:
def get_income_stmt(symbol: Annotated[str, "股票代碼"]) -> pl.DataFrame:
"""獲取並以 DataFrame 形式返回公司最新的損益表。"""
ticker = symbol
income_stmt = ticker.financials
return income_stmt
return pl.from_pandas(income_stmt.reset_index())
def get_balance_sheet(symbol: Annotated[str, "股票代碼"]) -> DataFrame:
def get_balance_sheet(symbol: Annotated[str, "股票代碼"]) -> pl.DataFrame:
"""獲取並以 DataFrame 形式返回公司最新的資產負債表。"""
ticker = symbol
balance_sheet = ticker.balance_sheet
return balance_sheet
return pl.from_pandas(balance_sheet.reset_index())
def get_cash_flow(symbol: Annotated[str, "股票代碼"]) -> DataFrame:
def get_cash_flow(symbol: Annotated[str, "股票代碼"]) -> pl.DataFrame:
"""獲取並以 DataFrame 形式返回公司最新的現金流量表。"""
ticker = symbol
cash_flow = ticker.cashflow
return cash_flow
return pl.from_pandas(cash_flow.reset_index())
def get_analyst_recommendations(symbol: Annotated[str, "股票代碼"]) -> tuple:
"""獲取最新的分析師建議,並返回最常見的建議及其計數。"""