From 37ee6047bb817fbb6517068447c66a54a9f9545f Mon Sep 17 00:00:00 2001 From: kimheesu Date: Tue, 1 Jul 2025 13:30:43 +0900 Subject: [PATCH] [add] use cache --- tradingagents/dataflows/interface.py | 7 ++-- .../dataflows/search_provider_factory.py | 34 +++++++++++++++++-- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index e65f226d..b48dbfb4 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -707,16 +707,15 @@ def get_stock_news(ticker, curr_date): config = get_config() search_provider = SearchProviderFactory.create_provider(config) query = f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period." - return search_provider.search(query, ticker, curr_date) + return search_provider.search(query) - def get_global_news(curr_date): config = get_config() search_provider = SearchProviderFactory.create_provider(config) query = f"Search for global macroeconomic news and financial market updates from 7 days before {curr_date} to {curr_date}. Focus on central bank decisions, economic indicators, geopolitical events, and market-moving news that would be important for trading decisions." - return search_provider.search(query, curr_date) + return search_provider.search(query) @@ -724,5 +723,5 @@ def get_fundamentals(ticker, curr_date): config = get_config() search_provider = SearchProviderFactory.create_provider(config) query = f"Search for fundamental analysis data and financial metrics for {ticker} stock from the month before {curr_date} to the month of {curr_date}. Look for earnings reports, financial ratios like PE, PS, cash flow, revenue growth, analyst ratings, and any fundamental analysis discussions. Please present key metrics in a structured format." - return search_provider.search(query, ticker, curr_date) + return search_provider.search(query) diff --git a/tradingagents/dataflows/search_provider_factory.py b/tradingagents/dataflows/search_provider_factory.py index 3e57ead3..d0b7afd0 100644 --- a/tradingagents/dataflows/search_provider_factory.py +++ b/tradingagents/dataflows/search_provider_factory.py @@ -3,17 +3,45 @@ from .search_provider import ( GoogleSearchProvider, OpenAISearchProvider ) +import hashlib +import json class SearchProviderFactory: + _cache = {} # 클래스 레벨 캐시 + @staticmethod - def create_provider(config: dict[str, any])->SearchProvider: + def create_provider(config: dict[str, any]) -> SearchProvider: + """ + Create a SearchProvider with caching to avoid creating new instances. + Uses config hash as cache key for efficient reuse. + """ + # Create cache key from relevant config values + cache_key_data = { + "backend_url": config["backend_url"], + "model": config["quick_think_llm"] + } + cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest() + + # Return cached instance if exists + if cache_key in SearchProviderFactory._cache: + return SearchProviderFactory._cache[cache_key] + + # Create new instance backend_url = config["backend_url"] model = config["quick_think_llm"] if "generativelanguage.googleapis.com" in backend_url: - return GoogleSearchProvider(model) + provider = GoogleSearchProvider(model) else: - return OpenAISearchProvider(model, backend_url) + provider = OpenAISearchProvider(model, backend_url) + # Cache and return + SearchProviderFactory._cache[cache_key] = provider + return provider + + @staticmethod + def clear_cache(): + """Clear the provider cache (useful for testing or config changes).""" + SearchProviderFactory._cache.clear()