# TradingAgents/tradingagents/dataflows/tushare.py
# (viewer metadata: 576 lines, 20 KiB, Python — not part of the module logic)
from __future__ import annotations
import os
from datetime import datetime, timedelta
from functools import lru_cache
from typing import Callable
import pandas as pd
from stockstats import wrap
from .exceptions import DataVendorUnavailable
# Exchange suffixes this vendor adapter recognizes after normalization.
# NOTE(review): not referenced in this file's visible code — presumably
# consumed by callers elsewhere; kept as-is.
_SUPPORTED_EXCHANGES = {"SH", "SZ", "BJ", "HK"}
# Maps common ticker-suffix aliases (Yahoo-style "SS", exchange acronyms
# "SSE"/"SZSE"/"BSE"/"SEHK", etc.) to the canonical suffix tushare expects.
_SUFFIX_MAP = {
    "SH": "SH",
    "SS": "SH",
    "SSE": "SH",
    "SZ": "SZ",
    "SZSE": "SZ",
    "BJ": "BJ",
    "BSE": "BJ",
    "HK": "HK",
    "HKG": "HK",
    "SEHK": "HK",
}
# Canonical suffixes denoting mainland-China A-share exchanges.
_A_SHARE_EXCHANGES = {"SH", "SZ", "BJ"}
def _parse_date(date_str: str) -> datetime:
return datetime.strptime(date_str, "%Y-%m-%d")
def _to_api_date(date_str: str) -> str:
    """Re-format a validated ``YYYY-MM-DD`` string into tushare's ``YYYYMMDD``."""
    as_dt = _parse_date(date_str)
    return as_dt.strftime("%Y%m%d")
def _classify_market(ts_code: str) -> str:
    """Classify a normalized ts_code into ``"a_share"``, ``"hk"``, or ``"us"``.

    Codes without a ``.`` suffix (plain US symbols) classify as ``"us"``.
    """
    suffix = ts_code.rsplit(".", 1)[1] if "." in ts_code else None
    if suffix in _A_SHARE_EXCHANGES:
        return "a_share"
    if suffix == "HK":
        return "hk"
    return "us"
def _normalize_ts_code(symbol: str) -> str:
raw = symbol.strip().upper()
if "." in raw:
code, suffix = raw.split(".", 1)
suffix = _SUFFIX_MAP.get(suffix, suffix)
if suffix in _A_SHARE_EXCHANGES and code.isdigit():
return f"{code.zfill(6)}.{suffix}"
if suffix == "HK" and code.isdigit():
return f"{code.zfill(5)}.HK"
raise DataVendorUnavailable(
f"Tushare currently supports A-share, Hong Kong, and US tickers only, got '{symbol}'."
)
if raw.isdigit() and len(raw) <= 6:
code = raw.zfill(6)
if code.startswith(("6", "9", "5")):
return f"{code}.SH"
if code.startswith(("0", "2", "3")):
return f"{code}.SZ"
if code.startswith(("4", "8")):
return f"{code}.BJ"
return f"{raw.zfill(5)}.HK"
if raw.replace("-", "").isalnum():
return raw
raise DataVendorUnavailable(
f"Cannot map ticker '{symbol}' to a supported Tushare market automatically."
)
@lru_cache(maxsize=1)
def _get_pro_client():
    """Create and memoize a single authenticated tushare ``pro_api`` client.

    The token is read from the first non-empty of TUSHARE_TOKEN,
    TUSHARE_API_TOKEN, TS_TOKEN.

    Raises:
        DataVendorUnavailable: when no token is configured, the tushare
            package is missing, or client initialization fails.
    """
    token = None
    for env_name in ("TUSHARE_TOKEN", "TUSHARE_API_TOKEN", "TS_TOKEN"):
        token = os.getenv(env_name)
        if token:
            break
    if not token:
        raise DataVendorUnavailable(
            "TUSHARE_TOKEN is not set. Configure token or use fallback vendor."
        )
    # Imported lazily so the module loads even when tushare is absent.
    try:
        import tushare as ts
    except ImportError as exc:
        raise DataVendorUnavailable(
            "tushare package is not installed. Install it to enable tushare vendor."
        ) from exc
    try:
        ts.set_token(token)
        return ts.pro_api(token)
    except Exception as exc:
        raise DataVendorUnavailable(f"Failed to initialize tushare client: {exc}") from exc
def _to_csv_with_header(df: pd.DataFrame, title: str) -> str:
if df is None or df.empty:
return f"No {title.lower()} data found."
header = f"# {title}\n"
header += f"# Total records: {len(df)}\n"
header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
return header + df.to_csv(index=False)
def _filter_statement(df: pd.DataFrame, freq: str, curr_date: str | None) -> pd.DataFrame:
if df is None or df.empty:
return df
output = df.copy()
if curr_date and "end_date" in output.columns:
cutoff = _to_api_date(curr_date)
output = output[output["end_date"].astype(str) <= cutoff]
if freq.lower() == "annual" and "end_date" in output.columns:
output = output[output["end_date"].astype(str).str.endswith("1231")]
sort_col = "end_date" if "end_date" in output.columns else output.columns[0]
output = output.sort_values(sort_col, ascending=False).head(8)
return output
def _fetch_price_data(pro, ts_code: str, start_api: str, end_api: str) -> pd.DataFrame:
    """Dispatch a daily-bar query to the endpoint matching the ticker's market.

    Dates are in tushare's ``YYYYMMDD`` form; unknown markets fall back to
    the US endpoint, mirroring ``_classify_market``.
    """
    market = _classify_market(ts_code)
    if market == "a_share":
        endpoint = pro.daily
    elif market == "hk":
        endpoint = pro.hk_daily
    else:
        endpoint = pro.us_daily
    return endpoint(ts_code=ts_code, start_date=start_api, end_date=end_api)
def get_stock(symbol: str, start_date: str, end_date: str) -> str:
    """Fetch daily OHLCV bars for *symbol* between two ISO dates as CSV text.

    Args:
        symbol: Ticker in any form accepted by ``_normalize_ts_code``.
        start_date: Inclusive start, ``YYYY-MM-DD``.
        end_date: Inclusive end, ``YYYY-MM-DD``.

    Returns:
        CSV text with a comment header, or a "not found" message when the
        vendor returned no rows.
    """
    client = _get_pro_client()
    ts_code = _normalize_ts_code(symbol)
    frame = _fetch_price_data(
        client, ts_code, _to_api_date(start_date), _to_api_date(end_date)
    )
    if frame is None or frame.empty:
        return f"No stock data found for '{ts_code}' between {start_date} and {end_date}."
    renamed = frame.rename(
        columns={
            "trade_date": "Date",
            "open": "Open",
            "high": "High",
            "low": "Low",
            "close": "Close",
            "vol": "Volume",
            "amount": "Amount",
            "pct_chg": "PctChg",
            "pre_close": "PrevClose",
            "change": "Change",
        }
    )
    if "Date" in renamed.columns:
        # Vendor dates are YYYYMMDD strings; reformat and order oldest-first.
        renamed["Date"] = pd.to_datetime(renamed["Date"], format="%Y%m%d").dt.strftime(
            "%Y-%m-%d"
        )
        renamed = renamed.sort_values("Date", ascending=True)
    ordered = [
        col
        for col in (
            "Date",
            "Open",
            "High",
            "Low",
            "Close",
            "PrevClose",
            "Change",
            "PctChg",
            "Volume",
            "Amount",
        )
        if col in renamed.columns
    ]
    return _to_csv_with_header(
        renamed[ordered],
        f"Tushare stock data for {ts_code} from {start_date} to {end_date}",
    )
def _load_price_frame(symbol: str, curr_date: str, look_back_days: int = 260) -> pd.DataFrame:
    """Load an ascending OHLCV frame ending at *curr_date* for indicator math.

    Args:
        symbol: Ticker in any supported form.
        curr_date: Inclusive end date, ``YYYY-MM-DD``.
        look_back_days: Calendar days of history to request.

    Raises:
        DataVendorUnavailable: when the vendor returns no rows.
    """
    client = _get_pro_client()
    ts_code = _normalize_ts_code(symbol)
    window_end = _parse_date(curr_date)
    window_start = window_end - timedelta(days=look_back_days)
    raw = _fetch_price_data(
        client,
        ts_code,
        window_start.strftime("%Y%m%d"),
        window_end.strftime("%Y%m%d"),
    )
    if raw is None or raw.empty:
        raise DataVendorUnavailable(
            f"No tushare price data found for '{ts_code}' before {curr_date}."
        )
    frame = raw.rename(
        columns={
            "trade_date": "Date",
            "open": "Open",
            "high": "High",
            "low": "Low",
            "close": "Close",
            "vol": "Volume",
        }
    ).copy()
    frame["Date"] = pd.to_datetime(frame["Date"], format="%Y%m%d")
    frame = frame.sort_values("Date", ascending=True)
    return frame[["Date", "Open", "High", "Low", "Close", "Volume"]]
def get_indicator(
    symbol: str,
    indicator: str,
    curr_date: str,
    look_back_days: int,
) -> str:
    """Report daily values of a stockstats *indicator* over a look-back window.

    Args:
        symbol: Ticker in any form accepted by ``_normalize_ts_code``.
        indicator: One of the supported stockstats column names below.
        curr_date: Inclusive end date, ``YYYY-MM-DD``.
        look_back_days: Calendar days before *curr_date* to report.

    Returns:
        A markdown-ish text block listing one value per calendar day
        (newest first), followed by a usage description of the indicator.

    Raises:
        ValueError: when *indicator* is not supported.
        DataVendorUnavailable: when no price data is available.
    """
    descriptions = {
        "close_50_sma": "50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.",
        "close_200_sma": "200 SMA: A long-term trend benchmark. Usage: Confirm overall market trend and identify golden/death cross setups. Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.",
        "close_10_ema": "10 EMA: A responsive short-term average. Usage: Capture quick shifts in momentum and potential entry points. Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.",
        "macd": "MACD: Computes momentum via differences of EMAs. Usage: Look for crossovers and divergence as signals of trend changes. Tips: Confirm with other indicators in low-volatility or sideways markets.",
        "macds": "MACD Signal: An EMA smoothing of the MACD line. Usage: Use crossovers with the MACD line to trigger trades. Tips: Should be part of a broader strategy to avoid false positives.",
        "macdh": "MACD Histogram: Shows the gap between the MACD line and its signal. Usage: Visualize momentum strength and spot divergence early. Tips: Can be volatile; complement with additional filters in fast-moving markets.",
        "rsi": "RSI: Measures momentum to flag overbought/oversold conditions. Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.",
        "boll": "Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. Usage: Acts as a dynamic benchmark for price movement. Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.",
        "boll_ub": "Bollinger Upper Band: Typically 2 standard deviations above the middle line. Usage: Signals potential overbought conditions and breakout zones. Tips: Confirm signals with other tools; prices may ride the band in strong trends.",
        "boll_lb": "Bollinger Lower Band: Typically 2 standard deviations below the middle line. Usage: Indicates potential oversold conditions. Tips: Use additional analysis to avoid false reversal signals.",
        "atr": "ATR: Averages true range to measure volatility. Usage: Set stop-loss levels and adjust position sizes based on current market volatility. Tips: It's a reactive measure, so use it as part of a broader risk management strategy.",
        "vwma": "VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.",
        "mfi": "MFI: Uses both price and volume to measure buying and selling pressure. Usage: Identify overbought (>80) or oversold (<20) conditions and confirm trends or reversals.",
    }
    if indicator not in descriptions:
        raise ValueError(
            f"Indicator {indicator} is not supported. Please choose from: {list(descriptions.keys())}"
        )
    current_dt = _parse_date(curr_date)
    start_dt = current_dt - timedelta(days=look_back_days)
    # FIX: load enough history to cover the report window PLUS ~260 extra
    # days of warm-up for long-horizon indicators (e.g. close_200_sma).
    # Previously a fixed 260-day default was used regardless of
    # look_back_days, so for windows longer than 260 days real trading days
    # fell outside the loaded data and were mislabeled "Not a trading day".
    stats_df = wrap(
        _load_price_frame(symbol, curr_date, look_back_days=look_back_days + 260)
    )
    stats_df["Date"] = stats_df["Date"].dt.strftime("%Y-%m-%d")
    # Accessing the column triggers stockstats' lazy indicator computation.
    stats_df[indicator]
    lines = []
    probe_dt = current_dt
    while probe_dt >= start_dt:
        date_str = probe_dt.strftime("%Y-%m-%d")
        row = stats_df[stats_df["Date"] == date_str]
        if row.empty:
            # No bar for this calendar day; also covers vendor data gaps.
            lines.append(f"{date_str}: N/A: Not a trading day (weekend or holiday)")
        else:
            value = row.iloc[0][indicator]
            if pd.isna(value):
                # Bar exists but the indicator has no value yet (warm-up).
                lines.append(f"{date_str}: N/A")
            else:
                lines.append(f"{date_str}: {value}")
        probe_dt -= timedelta(days=1)
    return (
        f"## {indicator} values from {start_dt.strftime('%Y-%m-%d')} to {curr_date}:\n\n"
        + "\n".join(lines)
        + "\n\n"
        + descriptions[indicator]
    )
def _append_income_lines(lines: list, income: pd.DataFrame | None) -> None:
    """Append the latest-period rows of an hk/us income frame to *lines*.

    The hk/us income endpoints return long-format rows; only the most
    recent ``end_date`` is reported, capped at 12 rows, formatted as
    ``ind_name: ind_value`` pairs.
    """
    if income is None or income.empty:
        return
    latest_end = income["end_date"].astype(str).max()
    lines.append(f"Latest Financial Period: {latest_end}")
    sample = income[income["end_date"].astype(str) == latest_end].head(12)
    for _, rec in sample.iterrows():
        lines.append(f"{rec.get('ind_name')}: {rec.get('ind_value')}")


def get_fundamentals(ticker: str, curr_date: str | None = None) -> str:
    """Build a plain-text fundamentals snapshot for *ticker*.

    Combines company listing info, the latest price/valuation row, and the
    most recent financial indicators (A-share) or income-statement rows
    (HK/US).

    Args:
        ticker: Ticker in any form accepted by ``_normalize_ts_code``.
        curr_date: Optional ``YYYY-MM-DD`` reference date; defaults to now.

    Returns:
        A commented, line-oriented text report.
    """
    pro = _get_pro_client()
    ts_code = _normalize_ts_code(ticker)
    market = _classify_market(ts_code)
    if curr_date:
        curr_dt = _parse_date(curr_date)
    else:
        curr_dt = datetime.now()
        curr_date = curr_dt.strftime("%Y-%m-%d")
    end_api = curr_dt.strftime("%Y%m%d")
    # 40 days of lookback guarantees at least one trading day; 400 days
    # covers a full annual reporting cycle for financial indicators.
    start_api_40d = (curr_dt - timedelta(days=40)).strftime("%Y%m%d")
    start_api_400d = (curr_dt - timedelta(days=400)).strftime("%Y%m%d")
    if market == "a_share":
        basic = pro.stock_basic(
            ts_code=ts_code,
            fields="ts_code,symbol,name,area,industry,market,list_date,list_status",
        )
        latest_price = pro.daily_basic(
            ts_code=ts_code,
            start_date=start_api_40d,
            end_date=end_api,
        )
        fina_indicator = pro.fina_indicator(
            ts_code=ts_code,
            start_date=start_api_400d,
            end_date=end_api,
        )
    elif market == "hk":
        basic = pro.hk_basic(ts_code=ts_code)
        latest_price = pro.hk_daily(ts_code=ts_code, start_date=start_api_40d, end_date=end_api)
        fina_indicator = None
    else:
        basic = pro.us_basic(ts_code=ts_code)
        latest_price = pro.us_daily(ts_code=ts_code, start_date=start_api_40d, end_date=end_api)
        fina_indicator = None
    lines = [
        f"Ticker: {ts_code}",
        f"Market: {market}",
        f"Reference date: {curr_date}",
    ]
    # Company/listing information (field names differ per market).
    if basic is not None and not basic.empty:
        row = basic.iloc[0]
        if market == "a_share":
            field_map = {
                "name": "Name",
                "area": "Area",
                "industry": "Industry",
                "market": "Market",
                "list_date": "List Date",
                "list_status": "List Status",
            }
        elif market == "hk":
            field_map = {
                "name": "Name",
                "fullname": "Full Name",
                "enname": "English Name",
                "market": "Market",
                "curr_type": "Currency",
                "list_date": "List Date",
                "list_status": "List Status",
            }
        else:
            field_map = {
                "name": "Name",
                "enname": "English Name",
                "classify": "Classify",
                "list_date": "List Date",
                "delist_date": "Delist Date",
            }
        for field, label in field_map.items():
            value = row.get(field)
            if pd.notna(value):
                lines.append(f"{label}: {value}")
    # Most recent price/valuation row within the 40-day window.
    if latest_price is not None and not latest_price.empty:
        row = latest_price.sort_values("trade_date", ascending=False).iloc[0]
        if market == "a_share":
            field_map = {
                "trade_date": "Latest Trade Date",
                "close": "Close",
                "turnover_rate": "Turnover Rate",
                "pe": "PE",
                "pb": "PB",
                "ps": "PS",
                "dv_ratio": "Dividend Yield Ratio",
                "total_mv": "Total Market Value",
                "circ_mv": "Circulating Market Value",
            }
        else:
            field_map = {
                "trade_date": "Latest Trade Date",
                "close": "Close",
                "open": "Open",
                "high": "High",
                "low": "Low",
                "pre_close": "Prev Close",
                "change": "Change",
                "pct_chg": "Pct Change",
                "vol": "Volume",
                "amount": "Amount",
            }
        for field, label in field_map.items():
            value = row.get(field)
            if pd.notna(value):
                lines.append(f"{label}: {value}")
    # FIX: branch on market first. Previously the hk/us branches were chained
    # as elif/else to the fina_indicator emptiness check, so an A-share
    # ticker with an empty fina_indicator frame fell through to the US
    # income endpoint with an A-share ts_code.
    if market == "a_share":
        if fina_indicator is not None and not fina_indicator.empty:
            row = fina_indicator.sort_values("end_date", ascending=False).iloc[0]
            field_map = {
                "end_date": "Latest Financial Period",
                "roe": "ROE",
                "roa": "ROA",
                "grossprofit_margin": "Gross Margin",
                "netprofit_margin": "Net Margin",
                "debt_to_assets": "Debt to Assets",
                "ocf_to_or": "OCF to Revenue",
            }
            for field, label in field_map.items():
                value = row.get(field)
                if pd.notna(value):
                    lines.append(f"{label}: {value}")
    elif market == "hk":
        _append_income_lines(lines, pro.hk_income(ts_code=ts_code, end_date=end_api))
    else:
        _append_income_lines(lines, pro.us_income(ts_code=ts_code, end_date=end_api))
    header = f"# Tushare fundamentals for {ts_code}\n"
    header += f"# Data retrieved on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
    return header + "\n".join(lines)
def _statement_common(
    ticker: str,
    freq: str,
    curr_date: str | None,
    fetcher: Callable,
    title: str,
) -> str:
    """Shared driver for the financial-statement getters.

    Normalizes the ticker, delegates the raw fetch to *fetcher* (called as
    ``fetcher(pro, ts_code, market)``), trims the result with
    ``_filter_statement``, and serializes it as annotated CSV.
    """
    client = _get_pro_client()
    ts_code = _normalize_ts_code(ticker)
    statements = fetcher(client, ts_code, _classify_market(ts_code))
    trimmed = _filter_statement(statements, freq, curr_date)
    return _to_csv_with_header(trimmed, f"Tushare {title} for {ts_code} ({freq})")
def get_balance_sheet(
    ticker: str,
    freq: str = "quarterly",
    curr_date: str | None = None,
) -> str:
    """Return the balance sheet for *ticker* as annotated CSV text.

    Args:
        ticker: Ticker in any supported form.
        freq: "quarterly" or "annual" period filtering.
        curr_date: Optional ``YYYY-MM-DD`` cutoff for reporting periods.
    """

    def _fetch(pro, ts_code, market):
        # Route to the balance-sheet endpoint matching the ticker's market.
        if market == "a_share":
            return pro.balancesheet(ts_code=ts_code)
        if market == "hk":
            return pro.hk_balancesheet(ts_code=ts_code)
        return pro.us_balancesheet(ts_code=ts_code)

    return _statement_common(ticker, freq, curr_date, _fetch, "balance sheet")
def get_cashflow(
    ticker: str,
    freq: str = "quarterly",
    curr_date: str | None = None,
) -> str:
    """Return the cash-flow statement for *ticker* as annotated CSV text.

    Args:
        ticker: Ticker in any supported form.
        freq: "quarterly" or "annual" period filtering.
        curr_date: Optional ``YYYY-MM-DD`` cutoff for reporting periods.
    """

    def _fetch(pro, ts_code, market):
        # Route to the cash-flow endpoint matching the ticker's market.
        if market == "a_share":
            return pro.cashflow(ts_code=ts_code)
        if market == "hk":
            return pro.hk_cashflow(ts_code=ts_code)
        return pro.us_cashflow(ts_code=ts_code)

    return _statement_common(ticker, freq, curr_date, _fetch, "cashflow")
def get_income_statement(
    ticker: str,
    freq: str = "quarterly",
    curr_date: str | None = None,
) -> str:
    """Return the income statement for *ticker* as annotated CSV text.

    Args:
        ticker: Ticker in any supported form.
        freq: "quarterly" or "annual" period filtering.
        curr_date: Optional ``YYYY-MM-DD`` cutoff for reporting periods.
    """

    def _fetch(pro, ts_code, market):
        # Route to the income-statement endpoint matching the ticker's market.
        if market == "a_share":
            return pro.income(ts_code=ts_code)
        if market == "hk":
            return pro.hk_income(ts_code=ts_code)
        return pro.us_income(ts_code=ts_code)

    return _statement_common(ticker, freq, curr_date, _fetch, "income statement")
def get_insider_transactions(ticker: str) -> str:
    """Return the past year's major-holder trades for an A-share ticker.

    Args:
        ticker: Ticker in any supported form; must resolve to an A-share.

    Returns:
        Annotated CSV text, or a "not found" message when no rows exist.

    Raises:
        DataVendorUnavailable: for non-A-share tickers or fetch failures.
    """
    client = _get_pro_client()
    ts_code = _normalize_ts_code(ticker)
    market = _classify_market(ts_code)
    if market != "a_share":
        raise DataVendorUnavailable(
            f"Tushare insider transactions currently support A-share tickers only, got '{ts_code}'."
        )
    now = datetime.now()
    year_ago = now - timedelta(days=365)
    try:
        raw = client.stk_holdertrade(
            ts_code=ts_code,
            start_date=year_ago.strftime("%Y%m%d"),
            end_date=now.strftime("%Y%m%d"),
        )
    except Exception as exc:
        raise DataVendorUnavailable(
            f"Failed to retrieve tushare insider transactions for '{ts_code}': {exc}"
        ) from exc
    if raw is None or raw.empty:
        return f"No tushare insider transactions found for '{ts_code}'."
    frame = raw.rename(
        columns={
            "ann_date": "AnnouncementDate",
            "holder_name": "HolderName",
            "holder_type": "HolderType",
            "in_de": "Direction",
            "change_vol": "ChangeVolume",
            "change_ratio": "ChangeRatio",
            "after_share": "AfterShareholding",
            "after_ratio": "AfterRatio",
            "avg_price": "AveragePrice",
            "total_share": "TotalShareholding",
            "begin_date": "StartDate",
            "close_date": "EndDate",
        }
    ).copy()
    # Re-format YYYYMMDD date columns; unparseable values become NaT/NaN.
    for date_col in ("AnnouncementDate", "StartDate", "EndDate"):
        if date_col in frame.columns:
            frame[date_col] = pd.to_datetime(
                frame[date_col], format="%Y%m%d", errors="coerce"
            ).dt.strftime("%Y-%m-%d")
    wanted = [
        "AnnouncementDate",
        "HolderName",
        "HolderType",
        "Direction",
        "ChangeVolume",
        "ChangeRatio",
        "AfterShareholding",
        "AfterRatio",
        "AveragePrice",
        "TotalShareholding",
        "StartDate",
        "EndDate",
    ]
    present = [name for name in wanted if name in frame.columns]
    if present:
        frame = frame[present]
    order_key = "AnnouncementDate" if "AnnouncementDate" in frame.columns else frame.columns[0]
    frame = frame.sort_values(order_key, ascending=False)
    return _to_csv_with_header(frame, f"Tushare insider transactions for {ts_code}")