From a679d61724cb34a1df89ad02fb051a7e258b9802 Mon Sep 17 00:00:00 2001
From: kimheesu
Date: Tue, 1 Jul 2025 10:59:44 +0900
Subject: [PATCH] [add] search using GoogleSearch grounding

---
 .gitignore                                 |   1 +
 cli/utils.py                               |   5 +-
 main.py                                    |  15 +-
 .../utils/embedding_provider_factory.py    |  20 ++
 .../agents/utils/embedding_providers.py    |  66 ++++++
 tradingagents/agents/utils/memory.py       |  38 +---
 tradingagents/dataflows/interface.py       | 211 +-----------------
 tradingagents/dataflows/search_provider.py |  76 +++++++
 .../dataflows/search_provider_factory.py   |  19 ++
 tradingagents/default_config.py            |  14 +-
 tradingagents/graph/trading_graph.py       |  18 +-
 11 files changed, 231 insertions(+), 252 deletions(-)
 create mode 100644 tradingagents/agents/utils/embedding_provider_factory.py
 create mode 100644 tradingagents/agents/utils/embedding_providers.py
 create mode 100644 tradingagents/dataflows/search_provider.py
 create mode 100644 tradingagents/dataflows/search_provider_factory.py

diff --git a/.gitignore b/.gitignore
index 8313619e..a73fa468 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@ src/
 eval_results/
 eval_data/
 *.egg-info/
+results/
\ No newline at end of file
diff --git a/cli/utils.py b/cli/utils.py
index f9fb2cba..bfc2bba8 100644
--- a/cli/utils.py
+++ b/cli/utils.py
@@ -140,8 +140,9 @@ def select_shallow_thinking_agent(provider) -> str:
         ("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
     ],
     "google": [
-        ("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
+        ("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
+        ("Gemini 2.5 Flash-Lite - Cost efficiency and low latency", "gemini-2.5-flash-lite-preview-06-17"),
         ("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash"),
     ],
     "openrouter": [
@@ -205,7 +206,7 @@ def select_deep_thinking_agent(provider) -> str:
         ("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
         ("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
         ("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
-        ("Gemini 2.5 Pro", "gemini-2.5-pro"),
+        ("Gemini 2.5 Pro - Most powerful Gemini model", "gemini-2.5-pro"),
     ],
     "openrouter": [
         ("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"),
diff --git a/main.py b/main.py
index 4b895199..2d9d4cbc 100644
--- a/main.py
+++ b/main.py
@@ -1,14 +1,17 @@
 from tradingagents.graph.trading_graph import TradingAgentsGraph
 from tradingagents.default_config import DEFAULT_CONFIG
+from dotenv import load_dotenv
+import os
+load_dotenv()
 
 # Create a custom config
 config = DEFAULT_CONFIG.copy()
-config["llm_provider"] = "google"  # Use a different model
-config["backend_url"] = "https://generativelanguage.googleapis.com/v1"  # Use a different backend
-config["deep_think_llm"] = "gemini-2.5-flash"  # Use a different model
-config["quick_think_llm"] = "gemini-2.5-flash"  # Use a different model
-config["max_debate_rounds"] = 1  # Increase debate rounds
-config["online_tools"] = True  # Increase debate rounds
+config["llm_provider"] = os.getenv("LLM_PROVIDER", "openai")  # LLM provider to use
+config["backend_url"] = os.getenv("BACKEND_URL", "https://api.openai.com/v1")  # Backend endpoint for that provider
+config["deep_think_llm"] = os.getenv("DEEP_THINK_LLM", "o4-mini")  # Model for deep reasoning
+config["quick_think_llm"] = os.getenv("QUICK_THINK_LLM", "gpt-4o-mini")  # Model for quick tasks
+config["max_debate_rounds"] = int(os.getenv("MAX_DEBATE_ROUNDS", "1"))  # Number of debate rounds
+config["online_tools"] = os.getenv("ONLINE_TOOLS", "True").lower() in ("true", "1", "yes")  # bool("False") would be True, so parse the string
 
 # Initialize with custom config
 ta = TradingAgentsGraph(debug=True, config=config)
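main.py now reads every tunable from the environment. Booleans need care here: `bool(os.getenv("ONLINE_TOOLS", "True"))` is true for any non-empty string, including "False". A minimal sketch of safe parsing, assuming python-dotenv is installed; `env_bool` and `env_int` are hypothetical helpers, not part of this patch:

```python
import os
from dotenv import load_dotenv

load_dotenv()  # pull variables from a local .env file, if present

def env_bool(name: str, default: bool = False) -> bool:
    # bool("False") would be True, so compare the string explicitly
    return os.getenv(name, str(default)).strip().lower() in ("true", "1", "yes")

def env_int(name: str, default: int) -> int:
    # Fall back to the default when the variable is unset or malformed
    try:
        return int(os.getenv(name, str(default)))
    except ValueError:
        return default

print(env_bool("ONLINE_TOOLS", True), env_int("MAX_DEBATE_ROUNDS", 1))
```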
+config["max_debate_rounds"] = int(os.getenv("MAX_DEBATE_ROUNDS", 1)) # Increase debate rounds +config["online_tools"] = bool(os.getenv("ONLINE_TOOLS", "True")) # Increase debate rounds # Initialize with custom config ta = TradingAgentsGraph(debug=True, config=config) diff --git a/tradingagents/agents/utils/embedding_provider_factory.py b/tradingagents/agents/utils/embedding_provider_factory.py new file mode 100644 index 00000000..8e004ebd --- /dev/null +++ b/tradingagents/agents/utils/embedding_provider_factory.py @@ -0,0 +1,20 @@ +from .embedding_providers import ( + EmbeddingProvider, + OpenAIEmbeddingProvider, + GeminiEmbeddingProvider, + OllamaEmbeddingProvider +) + + +class EmbeddingProviderFactory: + @staticmethod + def create_provider(config : dict[str, any])->EmbeddingProvider: + backend_url = config["backend_url"] + + if "generativelanguage.googleapis.com" in backend_url: + return GeminiEmbeddingProvider(backend_url) + elif "localhost:11434" in backend_url: + return OllamaEmbeddingProvider(backend_url) + else: + return OpenAIEmbeddingProvider(backend_url) + \ No newline at end of file diff --git a/tradingagents/agents/utils/embedding_providers.py b/tradingagents/agents/utils/embedding_providers.py new file mode 100644 index 00000000..7dd8ea15 --- /dev/null +++ b/tradingagents/agents/utils/embedding_providers.py @@ -0,0 +1,66 @@ +from abc import ABC, abstractmethod +from openai import OpenAI +from google import genai + + +class EmbeddingProvider(ABC): + @abstractmethod + def get_embedding(self, text: str)->list[float]: + pass + + @property + @abstractmethod + def model_name(self)->str: + pass + + +class OpenAIEmbeddingProvider(EmbeddingProvider): + def __init__(self, backend_url: str, embedding_model: str = "text-embedding-3-small"): + self.client = OpenAI(base_url=backend_url) + self._embedding_model = embedding_model + + + def get_embedding(self, text: str)->list[float]: + response = self.client.embeddings.create( + model=self._embedding_model, + input=text + ) + return response.data[0].embedding + + @property + def model_name(self)->str: + return self._embedding_model + + +class GeminiEmbeddingProvider(EmbeddingProvider): + def __init__(self, backend_url: str, embedding_model: str = "gemini-embedding-exp-03-07"): + self.client = genai.Client() + self._embedding_model = embedding_model + + def get_embedding(self, text: str)->list[float]: + response = self.client.models.embed_content( + model=self._embedding_model, + contents=text + ) + return response.embeddings[0].values + + @property + def model_name(self)->str: + return self._embedding_model + +class OllamaEmbeddingProvider(EmbeddingProvider): + def __init__(self, backend_url: str, embedding_model: str = "nomic-embed-text"): + self.client = OpenAI(base_url=backend_url) + self._embedding_model = embedding_model + + def get_embedding(self, text: str)->list[float]: + response = self.client.embeddings.create( + model=self._embedding_model, + input=text + ) + return response.data[0].embedding + + @property + def model_name(self)->str: + return self._embedding_model + \ No newline at end of file diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 91f63211..9faadb09 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -2,6 +2,7 @@ import chromadb from chromadb.config import Settings from openai import OpenAI import os +from .embedding_provider_factory import EmbeddingProviderFactory from google import genai class FinancialSituationMemory: @@ -9,29 
diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py
index 91f63211..9faadb09 100644
--- a/tradingagents/agents/utils/memory.py
+++ b/tradingagents/agents/utils/memory.py
@@ -2,6 +2,7 @@ import chromadb
 from chromadb.config import Settings
 from openai import OpenAI
 import os
+from .embedding_provider_factory import EmbeddingProviderFactory
 from google import genai
 
 class FinancialSituationMemory:
@@ -9,29 +10,7 @@ class FinancialSituationMemory:
         self.config = config
         self.backend_url = config["backend_url"]
 
-        # Determine embedding configuration based on provider
-        if self.backend_url == "http://localhost:11434/v1":
-            # Ollama
-            self.embedding_model = "nomic-embed-text"
-            self.use_openai_api = True
-        elif "openai.com" in self.backend_url:
-            # OpenAI
-            self.embedding_model = "text-embedding-3-small"
-            self.use_openai_api = True
-        elif "generativelanguage.googleapis.com" in self.backend_url:
-            # Google Gemini API
-            self.embedding_model = "gemini-embedding-exp-03-07"  # Use Google's embedding model
-            self.use_openai_api = False
-        else:
-            # Default to OpenAI-compatible
-            self.embedding_model = "text-embedding-3-small"
-            self.use_openai_api = True
-
-        # Initialize clients
-        if self.use_openai_api:
-            self.client = OpenAI(base_url=self.backend_url)
-        else:
-            self.client = genai.Client()
+        self.embedding_provider = EmbeddingProviderFactory.create_provider(config)
 
         self.chroma_client = chromadb.Client(Settings(allow_reset=True))
         self.situation_collection = self.chroma_client.create_collection(name=name)
@@ -39,18 +18,7 @@ class FinancialSituationMemory:
 
     def get_embedding(self, text):
         """Get embedding for a text using the appropriate API"""
-        if self.use_openai_api:
-            # Use OpenAI-compatible API
-            response = self.client.embeddings.create(
-                model=self.embedding_model, input=text
-            )
-            return response.data[0].embedding
-        else:
-            response = self.client.models.embed_content(
-                model=self.embedding_model,
-                contents=text
-            )
-            return response.embeddings[0].values
+        return self.embedding_provider.get_embedding(text)
 
     def add_situations(self, situations_and_advice):
         """Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
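With the provider branching gone, FinancialSituationMemory only delegates to the injected provider. A usage sketch, assuming the `(name, config)` constructor signature the hunk above implies and an OpenAI key in the environment:

```python
from tradingagents.agents.utils.memory import FinancialSituationMemory

config = {"backend_url": "https://api.openai.com/v1"}
memory = FinancialSituationMemory("bull_memory", config)  # assumed signature

memory.add_situations([
    ("Rates rising while tech multiples compress", "Trim high-beta growth names"),
])
print(len(memory.get_embedding("Semiconductor demand is accelerating")))
```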
diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py
index aa489441..0b45f3ab 100644
--- a/tradingagents/dataflows/interface.py
+++ b/tradingagents/dataflows/interface.py
@@ -14,6 +14,7 @@ from tqdm import tqdm
 import yfinance as yf
 from openai import OpenAI
 from .config import get_config, set_config, DATA_DIR
+from .search_provider_factory import SearchProviderFactory
 
 
 def get_finnhub_news(
@@ -704,212 +705,24 @@ def get_YFin_data(
 
 def get_stock_news(ticker, curr_date):
     config = get_config()
+    search_provider = SearchProviderFactory.create_provider(config)
+    query = f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period."
+    return search_provider.search(query, ticker, curr_date)
 
-    # Check if using Google API - implement grounding with Google Search
-    if "generativelanguage.googleapis.com" in config["backend_url"]:
-        try:
-            from google import genai
-            from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
-
-            client = genai.Client()
-
-            # Create Google Search grounding tool
-            google_search_tool = Tool(
-                google_search=GoogleSearch()
-            )
-
-            # Generate content with grounding
-            response = client.models.generate_content(
-                model=config["quick_think_llm"],
-                contents=f"Can you search for recent social media and news about {ticker} stock from 7 days before {curr_date} to {curr_date}? Focus on sentiment, price movements, and any significant developments that could impact trading decisions.",
-                config=GenerateContentConfig(
-                    tools=[google_search_tool],
-                    response_modalities=["TEXT"]
-                )
-            )
-
-            # Extract text from response
-            result_text = ""
-            for part in response.candidates[0].content.parts:
-                if hasattr(part, 'text'):
-                    result_text += part.text
-
-            return result_text
-
-        except Exception as e:
-            return f"Error retrieving stock news for {ticker}: {str(e)}"
-    else:
-        # For OpenAI and other APIs, use original implementation
-        client = OpenAI(base_url=config["backend_url"])
-        response = client.responses.create(
-            model=config["quick_think_llm"],
-            input=[
-                {
-                    "role": "system",
-                    "content": [
-                        {
-                            "type": "input_text",
-                            "text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.",
-                        }
-                    ],
-                }
-            ],
-            text={"format": {"type": "text"}},
-            reasoning={},
-            tools=[
-                {
-                    "type": "web_search_preview",
-                    "user_location": {"type": "approximate"},
-                    "search_context_size": "low",
-                }
-            ],
-            temperature=1,
-            max_output_tokens=4096,
-            top_p=1,
-            store=True,
-        )
-
-        return response.output[1].content[0].text
 
 
 def get_global_news(curr_date):
     config = get_config()
+    search_provider = SearchProviderFactory.create_provider(config)
+    query = f"Search for global macroeconomic news and financial market updates from 7 days before {curr_date} to {curr_date}. Focus on central bank decisions, economic indicators, geopolitical events, and market-moving news that would be important for trading decisions."
+    return search_provider.search(query, curr_date=curr_date)  # no ticker for global news
 
-    # Check if using Google API - implement grounding with Google Search
-    if "generativelanguage.googleapis.com" in config["backend_url"]:
-        try:
-            from google import genai
-            from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
-
-            client = genai.Client()
-
-            # Create Google Search grounding tool
-            google_search_tool = Tool(
-                google_search=GoogleSearch()
-            )
-
-            # Generate content with grounding
-            response = client.models.generate_content(
-                model=config["quick_think_llm"],
-                contents=f"Search for global macroeconomic news and financial market updates from 7 days before {curr_date} to {curr_date}. Focus on central bank decisions, economic indicators, geopolitical events, and market-moving news that would be important for trading decisions.",
-                config=GenerateContentConfig(
-                    tools=[google_search_tool],
-                    response_modalities=["TEXT"]
-                )
-            )
-
-            # Extract text from response
-            result_text = ""
-            for part in response.candidates[0].content.parts:
-                if hasattr(part, 'text'):
-                    result_text += part.text
-
-            return result_text
-
-        except Exception as e:
-            return f"Error retrieving global news: {str(e)}"
-    else:
-        # For OpenAI and other APIs, use original implementation
-        client = OpenAI(base_url=config["backend_url"])
-
-        response = client.responses.create(
-            model=config["quick_think_llm"],
-            input=[
-                {
-                    "role": "system",
-                    "content": [
-                        {
-                            "type": "input_text",
-                            "text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.",
-                        }
-                    ],
-                }
-            ],
-            text={"format": {"type": "text"}},
-            reasoning={},
-            tools=[
-                {
-                    "type": "web_search_preview",
-                    "user_location": {"type": "approximate"},
-                    "search_context_size": "low",
-                }
-            ],
-            temperature=1,
-            max_output_tokens=4096,
-            top_p=1,
-            store=True,
-        )
-
-        return response.output[1].content[0].text
-
+
 
 def get_fundamentals(ticker, curr_date):
     config = get_config()
-
-    # Check if using Google API - implement grounding with Google Search
-    if "generativelanguage.googleapis.com" in config["backend_url"]:
-        try:
-            from google import genai
-            from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
-
-            client = genai.Client()
-
-            # Create Google Search grounding tool
-            google_search_tool = Tool(
-                google_search=GoogleSearch()
-            )
-
-            # Generate content with grounding
-            response = client.models.generate_content(
-                model=config["quick_think_llm"],
-                contents=f"Search for fundamental analysis data and financial metrics for {ticker} stock from the month before {curr_date} to the month of {curr_date}. Look for earnings reports, financial ratios like PE, PS, cash flow, revenue growth, analyst ratings, and any fundamental analysis discussions. Please present key metrics in a structured format.",
-                config=GenerateContentConfig(
-                    tools=[google_search_tool],
-                    response_modalities=["TEXT"]
-                )
-            )
-
-            # Extract text from response
-            result_text = ""
-            for part in response.candidates[0].content.parts:
-                if hasattr(part, 'text'):
-                    result_text += part.text
-
-            return result_text
-
-        except Exception as e:
-            return f"Error retrieving fundamentals for {ticker}: {str(e)}"
-    else:
-        # For OpenAI and other APIs, use original implementation
-        client = OpenAI(base_url=config["backend_url"])
-
-        response = client.responses.create(
-            model=config["quick_think_llm"],
-            input=[
-                {
-                    "role": "system",
-                    "content": [
-                        {
-                            "type": "input_text",
-                            "text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
-                        }
-                    ],
-                }
-            ],
-            text={"format": {"type": "text"}},
-            reasoning={},
-            tools=[
-                {
-                    "type": "web_search_preview",
-                    "user_location": {"type": "approximate"},
-                    "search_context_size": "low",
-                }
-            ],
-            temperature=1,
-            max_output_tokens=4096,
-            top_p=1,
-            store=True,
-        )
-
-        return response.output[1].content[0].text
+    search_provider = SearchProviderFactory.create_provider(config)
+    query = f"Search for fundamental analysis data and financial metrics for {ticker} stock from the month before {curr_date} to the month of {curr_date}. Look for earnings reports, financial ratios like PE, PS, cash flow, revenue growth, analyst ratings, and any fundamental analysis discussions. Please present key metrics in a structured format."
+    return search_provider.search(query, ticker, curr_date)
+
\ No newline at end of file
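All three news and fundamentals helpers now share one code path, which also makes them easy to exercise without network access. A hedged sketch that swaps in a stub provider; `StubSearchProvider` is test scaffolding invented here, and the sketch assumes the dataflows config is already initialized:

```python
from tradingagents.dataflows import interface
from tradingagents.dataflows.search_provider import SearchProvider

class StubSearchProvider(SearchProvider):
    def search(self, query: str, ticker: str = "", curr_date: str = "") -> str:
        return f"[stub] {ticker or 'global'} @ {curr_date}: {query[:40]}..."

# Replace the factory hook so no LLM or web call is made.
interface.SearchProviderFactory.create_provider = lambda config: StubSearchProvider()

print(interface.get_stock_news("NVDA", "2025-07-01"))
print(interface.get_global_news("2025-07-01"))
print(interface.get_fundamentals("NVDA", "2025-07-01"))
```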
diff --git a/tradingagents/dataflows/search_provider.py b/tradingagents/dataflows/search_provider.py
new file mode 100644
index 00000000..e4985bdb
--- /dev/null
+++ b/tradingagents/dataflows/search_provider.py
@@ -0,0 +1,76 @@
+from google import genai
+from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
+from openai import OpenAI
+from abc import ABC, abstractmethod
+
+
+class SearchProvider(ABC):
+    @abstractmethod
+    def search(self, query: str, ticker: str = "", curr_date: str = "") -> str:
+        pass
+
+
+class GoogleSearchProvider(SearchProvider):
+    def __init__(self, model: str):
+        self.client = genai.Client()
+        self.model = model
+
+    def search(self, query: str, ticker: str = "", curr_date: str = "") -> str:
+        # Ground the model's answer in live Google Search results
+        google_search_tool = Tool(
+            google_search=GoogleSearch()
+        )
+
+        response = self.client.models.generate_content(
+            model=self.model,
+            contents=query,
+            config=GenerateContentConfig(
+                tools=[google_search_tool],
+                response_modalities=["TEXT"]
+            )
+        )
+
+        # Concatenate the text parts of the grounded response
+        result_text = ""
+        for part in response.candidates[0].content.parts:
+            if hasattr(part, 'text'):
+                result_text += part.text
+
+        return result_text
+
+
+class OpenAISearchProvider(SearchProvider):
+    def __init__(self, model: str, backend_url: str):
+        self.client = OpenAI(base_url=backend_url)
+        self.model = model
+
+    def search(self, query: str, ticker: str = "", curr_date: str = "") -> str:
+        response = self.client.responses.create(
+            model=self.model,
+            input=[
+                {
+                    "role": "system",
+                    "content": [
+                        {
+                            "type": "input_text",
+                            "text": query
+                        }
+                    ],
+                }
+            ],
+            text={"format": {"type": "text"}},
+            reasoning={},
+            tools=[
+                {
+                    "type": "web_search_preview",
+                    "user_location": {"type": "approximate"},
+                    "search_context_size": "low",
+                }
+            ],
+            temperature=1,
+            max_output_tokens=4096,
+            top_p=1,
+            store=True,
+        )
+
+        return response.output[1].content[0].text
\ No newline at end of file
diff --git a/tradingagents/dataflows/search_provider_factory.py b/tradingagents/dataflows/search_provider_factory.py
new file mode 100644
index 00000000..3e57ead3
--- /dev/null
+++ b/tradingagents/dataflows/search_provider_factory.py
@@ -0,0 +1,19 @@
+from .search_provider import (
+    SearchProvider,
+    GoogleSearchProvider,
+    OpenAISearchProvider
+)
+
+
+class SearchProviderFactory:
+    @staticmethod
+    def create_provider(config: dict) -> SearchProvider:
+        backend_url = config["backend_url"]
+        model = config["quick_think_llm"]
+
+        if "generativelanguage.googleapis.com" in backend_url:
+            return GoogleSearchProvider(model)
+        else:
+            return OpenAISearchProvider(model, backend_url)
+
+
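A minimal selection sketch for the new factory; it assumes GOOGLE_API_KEY is set when the Gemini backend is chosen, and that the model name comes from the same config used elsewhere:

```python
from tradingagents.dataflows.search_provider_factory import SearchProviderFactory

config = {
    "backend_url": "https://generativelanguage.googleapis.com/v1",
    "quick_think_llm": "gemini-2.5-flash",
}
provider = SearchProviderFactory.create_provider(config)  # -> GoogleSearchProvider

report = provider.search(
    "Search for NVDA news from 2025-06-24 to 2025-07-01.",
    ticker="NVDA",
    curr_date="2025-07-01",
)
print(report[:200])
```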
"max_debate_rounds": int(os.getenv("MAX_DEBATE_ROUNDS", 1)), "max_risk_discuss_rounds": 1, "max_recur_limit": 100, # Tool settings - "online_tools": True, + "online_tools": bool(os.getenv("ONLINE_TOOLS", "True")), } diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 6456ce14..f25d7932 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -170,10 +170,20 @@ class TradingAgentsGraph: trace = [] for chunk in self.graph.stream(init_agent_state, **args): if len(chunk["messages"]) == 0: - pass - else: - chunk["messages"][-1].pretty_print() - trace.append(chunk) + continue + + message = chunk["messages"][-1] + # 중복 메시지 필터링 + if message.content and message.content.strip(): + # FINAL PROPOSAL 중복 방지 + if "FINAL TRANSACTION PROPOSAL:" in message.content: + if not hasattr(self, '_final_printed'): + message.pretty_print() + self._final_printed = True + else: + message.pretty_print() + + trace.append(chunk) final_state = trace[-1] else: