diff --git a/tradingagents/dataflows/local.py b/tradingagents/dataflows/local.py index 0591881a..b80dfa4f 100644 --- a/tradingagents/dataflows/local.py +++ b/tradingagents/dataflows/local.py @@ -418,6 +418,18 @@ def get_reddit_global_news( Returns: str: 包含 Reddit 上最新新聞文章貼文的格式化字串。 """ + + # 檢查數據目錄是否存在 + reddit_data_path = os.path.join(DATA_DIR, "reddit_data") + global_news_path = os.path.join(reddit_data_path, "global_news") + + if not os.path.exists(reddit_data_path): + print(f"警告:Reddit 數據目錄不存在: {reddit_data_path}。請確保數據目錄已正確設置。") + return "" + + if not os.path.exists(global_news_path): + print(f"警告:全球新聞數據目錄不存在: {global_news_path}。請確保已下載 Reddit 全球新聞數據。") + return "" curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d") before = curr_date_dt - relativedelta(days=look_back_days) @@ -432,13 +444,17 @@ def get_reddit_global_news( while curr_iter_date <= curr_date_dt: curr_date_str = curr_iter_date.strftime("%Y-%m-%d") - fetch_result = fetch_top_from_category( - "global_news", - curr_date_str, - limit, - data_path=os.path.join(DATA_DIR, "reddit_data"), - ) - posts.extend(fetch_result) + try: + fetch_result = fetch_top_from_category( + "global_news", + curr_date_str, + limit, + data_path=reddit_data_path, + ) + posts.extend(fetch_result) + except (FileNotFoundError, ValueError) as e: + # 如果特定日期的數據不存在,繼續下一天 + print(f"警告:無法獲取 {curr_date_str} 的數據: {e}") curr_iter_date += relativedelta(days=1) pbar.update(1) @@ -471,6 +487,18 @@ def get_reddit_company_news( Returns: str: 包含 Reddit 上新聞文章貼文的格式化字串。 """ + + # 檢查數據目錄是否存在 + reddit_data_path = os.path.join(DATA_DIR, "reddit_data") + company_news_path = os.path.join(reddit_data_path, "company_news") + + if not os.path.exists(reddit_data_path): + print(f"警告:Reddit 數據目錄不存在: {reddit_data_path}。請確保數據目錄已正確設置。") + return "" + + if not os.path.exists(company_news_path): + print(f"警告:公司新聞數據目錄不存在: {company_news_path}。請確保已下載 Reddit 公司新聞數據。") + return "" start_date_dt = datetime.strptime(start_date, "%Y-%m-%d") end_date_dt = datetime.strptime(end_date, "%Y-%m-%d") @@ -490,14 +518,18 @@ def get_reddit_company_news( while curr_date <= end_date_dt: curr_date_str = curr_date.strftime("%Y-%m-%d") - fetch_result = fetch_top_from_category( - "company_news", - curr_date_str, - max_per_day, - query, - data_path=os.path.join(DATA_DIR, "reddit_data"), - ) - posts.extend(fetch_result) + try: + fetch_result = fetch_top_from_category( + "company_news", + curr_date_str, + max_per_day, + query, + data_path=reddit_data_path, + ) + posts.extend(fetch_result) + except (FileNotFoundError, ValueError) as e: + # 如果特定日期的數據不存在,繼續下一天 + print(f"警告:無法獲取 {curr_date_str} 的數據: {e}") curr_date += relativedelta(days=1) pbar.update(1) diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index 80084e8e..809bd656 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -3,7 +3,7 @@ import os DEFAULT_CONFIG = { "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"), - "data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data", + "data_dir": os.path.join(os.path.expanduser("~"), "Documents/Code/ScAI/FR1-data"), "data_cache_dir": os.path.join( os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "dataflows/data_cache", @@ -23,11 +23,12 @@ DEFAULT_CONFIG = { "core_stock_apis": "yfinance", # 選項: yfinance, alpha_vantage, local "technical_indicators": "yfinance", # 選項: yfinance, alpha_vantage, local "fundamental_data": "alpha_vantage", # 選項: openai, alpha_vantage, local - "news_data": "alpha_vantage", # 選項: openai, alpha_vantage, google, local + "news_data": "openai", # 選項: openai, alpha_vantage, google, local }, # 工具層級設定 (優先於類別層級設定) "tool_vendors": { # 範例: "get_stock_data": "alpha_vantage", # 覆寫類別預設值 # 範例: "get_news": "openai", # 覆寫類別預設值 + "get_global_news": "openai", # get_global_news 不支持 alpha_vantage,使用 openai 作為主要供應商 }, } \ No newline at end of file