From 6a1f88da24d8e4cc92eb7a10fd19602ab58e1525 Mon Sep 17 00:00:00 2001 From: kimheesu Date: Tue, 1 Jul 2025 10:05:06 +0900 Subject: [PATCH] Add Gemini embedding and Google Search grounding support --- cli/utils.py | 4 +- main.py | 4 +- requirements.txt | 1 + .../agents/analysts/fundamentals_analyst.py | 2 +- tradingagents/agents/analysts/news_analyst.py | 2 +- .../agents/analysts/social_media_analyst.py | 2 +- tradingagents/agents/utils/agent_utils.py | 24 +- tradingagents/agents/utils/memory.py | 53 +++- tradingagents/dataflows/interface.py | 288 ++++++++++++------ tradingagents/graph/trading_graph.py | 6 +- 10 files changed, 263 insertions(+), 123 deletions(-) diff --git a/cli/utils.py b/cli/utils.py index 7b9682a6..f9fb2cba 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -142,7 +142,7 @@ def select_shallow_thinking_agent(provider) -> str: "google": [ ("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"), ("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"), - ("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"), + ("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash"), ], "openrouter": [ ("Meta: Llama 4 Scout", "meta-llama/llama-4-scout:free"), @@ -205,7 +205,7 @@ def select_deep_thinking_agent(provider) -> str: ("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"), ("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"), ("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"), - ("Gemini 2.5 Pro", "gemini-2.5-pro-preview-06-05"), + ("Gemini 2.5 Pro", "gemini-2.5-pro"), ], "openrouter": [ ("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"), diff --git a/main.py b/main.py index 6c8ae3d9..4b895199 100644 --- a/main.py +++ b/main.py @@ -5,8 +5,8 @@ from tradingagents.default_config import DEFAULT_CONFIG config = 
DEFAULT_CONFIG.copy() config["llm_provider"] = "google" # Use a different model config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend -config["deep_think_llm"] = "gemini-2.0-flash" # Use a different model -config["quick_think_llm"] = "gemini-2.0-flash" # Use a different model +config["deep_think_llm"] = "gemini-2.5-flash" # Use a different model +config["quick_think_llm"] = "gemini-2.5-flash" # Use a different model config["max_debate_rounds"] = 1 # Increase debate rounds config["online_tools"] = True # Increase debate rounds diff --git a/requirements.txt b/requirements.txt index a6154cd2..097dcdb6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,3 +24,4 @@ rich questionary langchain_anthropic langchain-google-genai +google-genai diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index 716d4de1..4412864f 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -10,7 +10,7 @@ def create_fundamentals_analyst(llm, toolkit): company_name = state["company_of_interest"] if toolkit.config["online_tools"]: - tools = [toolkit.get_fundamentals_openai] + tools = [toolkit.get_fundamentals] else: tools = [ toolkit.get_finnhub_company_insider_sentiment, diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index e1f03aa4..8e2d9c6a 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -9,7 +9,7 @@ def create_news_analyst(llm, toolkit): ticker = state["company_of_interest"] if toolkit.config["online_tools"]: - tools = [toolkit.get_global_news_openai, toolkit.get_google_news] + tools = [toolkit.get_global_news, toolkit.get_google_news] else: tools = [ toolkit.get_finnhub_news, diff --git a/tradingagents/agents/analysts/social_media_analyst.py 
b/tradingagents/agents/analysts/social_media_analyst.py index d556f73a..32d21a6b 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -10,7 +10,7 @@ def create_social_media_analyst(llm, toolkit): company_name = state["company_of_interest"] if toolkit.config["online_tools"]: - tools = [toolkit.get_stock_news_openai] + tools = [toolkit.get_stock_news] else: tools = [ toolkit.get_reddit_stock_info, diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index 0b07f044..27597b45 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -363,12 +363,12 @@ class Toolkit: @staticmethod @tool - def get_stock_news_openai( + def get_stock_news( ticker: Annotated[str, "the company's ticker"], curr_date: Annotated[str, "Current date in yyyy-mm-dd format"], ): """ - Retrieve the latest news about a given stock by using OpenAI's news API. + Retrieve the latest news about a given stock by using LLM's web search capabilities. Args: ticker (str): Ticker of a company. e.g. AAPL, TSM curr_date (str): Current date in yyyy-mm-dd format @@ -376,35 +376,35 @@ class Toolkit: str: A formatted string containing the latest news about the company on the given date. """ - openai_news_results = interface.get_stock_news_openai(ticker, curr_date) + results = interface.get_stock_news(ticker, curr_date) - return openai_news_results + return results @staticmethod @tool - def get_global_news_openai( + def get_global_news( curr_date: Annotated[str, "Current date in yyyy-mm-dd format"], ): """ - Retrieve the latest macroeconomics news on a given date using OpenAI's macroeconomics news API. + Retrieve the latest macroeconomics news on a given date using LLM's web search capabilities. Args: curr_date (str): Current date in yyyy-mm-dd format Returns: str: A formatted string containing the latest macroeconomic news on the given date. 
""" - openai_news_results = interface.get_global_news_openai(curr_date) + results = interface.get_global_news(curr_date) - return openai_news_results + return results @staticmethod @tool - def get_fundamentals_openai( + def get_fundamentals( ticker: Annotated[str, "the company's ticker"], curr_date: Annotated[str, "Current date in yyyy-mm-dd format"], ): """ - Retrieve the latest fundamental information about a given stock on a given date by using OpenAI's news API. + Retrieve the latest fundamental information about a given stock on a given date by using LLM's web search capabilities. Args: ticker (str): Ticker of a company. e.g. AAPL, TSM curr_date (str): Current date in yyyy-mm-dd format @@ -412,8 +412,8 @@ class Toolkit: str: A formatted string containing the latest fundamental information about the company on the given date. """ - openai_fundamentals_results = interface.get_fundamentals_openai( + results = interface.get_fundamentals( ticker, curr_date ) - return openai_fundamentals_results + return results diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 69b8ab8c..91f63211 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,25 +1,56 @@ import chromadb from chromadb.config import Settings from openai import OpenAI - +import os +from google import genai class FinancialSituationMemory: def __init__(self, name, config): - if config["backend_url"] == "http://localhost:11434/v1": - self.embedding = "nomic-embed-text" + self.config = config + self.backend_url = config["backend_url"] + + # Determine embedding configuration based on provider + if self.backend_url == "http://localhost:11434/v1": + # Ollama + self.embedding_model = "nomic-embed-text" + self.use_openai_api = True + elif "openai.com" in self.backend_url: + # OpenAI + self.embedding_model = "text-embedding-3-small" + self.use_openai_api = True + elif "generativelanguage.googleapis.com" in self.backend_url: + # 
Google Gemini API + self.embedding_model = "gemini-embedding-exp-03-07" # Use Google's embedding model + self.use_openai_api = False else: - self.embedding = "text-embedding-3-small" - self.client = OpenAI(base_url=config["backend_url"]) + # Default to OpenAI-compatible + self.embedding_model = "text-embedding-3-small" + self.use_openai_api = True + + # Initialize clients + if self.use_openai_api: + self.client = OpenAI(base_url=self.backend_url) + else: + self.client = genai.Client() + self.chroma_client = chromadb.Client(Settings(allow_reset=True)) self.situation_collection = self.chroma_client.create_collection(name=name) def get_embedding(self, text): - """Get OpenAI embedding for a text""" + """Get embedding for a text using the appropriate API""" - response = self.client.embeddings.create( - model=self.embedding, input=text - ) - return response.data[0].embedding + if self.use_openai_api: + # Use OpenAI-compatible API + response = self.client.embeddings.create( + model=self.embedding_model, input=text + ) + return response.data[0].embedding + else: + response = self.client.models.embed_content( + model=self.embedding_model, + contents=text + ) + return response.embeddings[0].values def add_situations(self, situations_and_advice): """Add financial situations and their corresponding advice. 
Parameter is a list of tuples (situation, rec)""" @@ -45,7 +76,7 @@ class FinancialSituationMemory: ) def get_memories(self, current_situation, n_matches=1): - """Find matching recommendations using OpenAI embeddings""" + """Find matching recommendations using embeddings""" query_embedding = self.get_embedding(current_situation) results = self.situation_collection.query( diff --git a/tradingagents/dataflows/interface.py b/tradingagents/dataflows/interface.py index 7fffbb4f..aa489441 100644 --- a/tradingagents/dataflows/interface.py +++ b/tradingagents/dataflows/interface.py @@ -702,106 +702,214 @@ def get_YFin_data( return filtered_data -def get_stock_news_openai(ticker, curr_date): +def get_stock_news(ticker, curr_date): config = get_config() - client = OpenAI(base_url=config["backend_url"]) + + # Check if using Google API - implement grounding with Google Search + if "generativelanguage.googleapis.com" in config["backend_url"]: + try: + from google import genai + from google.genai.types import Tool, GenerateContentConfig, GoogleSearch + + client = genai.Client() + + # Create Google Search grounding tool + google_search_tool = Tool( + google_search=GoogleSearch() + ) + + # Generate content with grounding + response = client.models.generate_content( + model=config["quick_think_llm"], + contents=f"Can you search for recent social media and news about {ticker} stock from 7 days before {curr_date} to {curr_date}? 
Focus on sentiment, price movements, and any significant developments that could impact trading decisions.", + config=GenerateContentConfig( + tools=[google_search_tool], + response_modalities=["TEXT"] + ) + ) + + # Extract text from response + result_text = "" + for part in response.candidates[0].content.parts: + if hasattr(part, 'text'): + result_text += part.text + + return result_text + + except Exception as e: + return f"Error retrieving stock news for {ticker}: {str(e)}" + else: + # For OpenAI and other APIs, use original implementation + client = OpenAI(base_url=config["backend_url"]) - response = client.responses.create( - model=config["quick_think_llm"], - input=[ - { - "role": "system", - "content": [ - { - "type": "input_text", - "text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.", - } - ], - } - ], - text={"format": {"type": "text"}}, - reasoning={}, - tools=[ - { - "type": "web_search_preview", - "user_location": {"type": "approximate"}, - "search_context_size": "low", - } - ], - temperature=1, - max_output_tokens=4096, - top_p=1, - store=True, - ) + response = client.responses.create( + model=config["quick_think_llm"], + input=[ + { + "role": "system", + "content": [ + { + "type": "input_text", + "text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? 
Make sure you only get the data posted during that period.", + } + ], + } + ], + text={"format": {"type": "text"}}, + reasoning={}, + tools=[ + { + "type": "web_search_preview", + "user_location": {"type": "approximate"}, + "search_context_size": "low", + } + ], + temperature=1, + max_output_tokens=4096, + top_p=1, + store=True, + ) - return response.output[1].content[0].text + return response.output[1].content[0].text -def get_global_news_openai(curr_date): +def get_global_news(curr_date): config = get_config() - client = OpenAI(base_url=config["backend_url"]) + + # Check if using Google API - implement grounding with Google Search + if "generativelanguage.googleapis.com" in config["backend_url"]: + try: + from google import genai + from google.genai.types import Tool, GenerateContentConfig, GoogleSearch + + client = genai.Client() + + # Create Google Search grounding tool + google_search_tool = Tool( + google_search=GoogleSearch() + ) + + # Generate content with grounding + response = client.models.generate_content( + model=config["quick_think_llm"], + contents=f"Search for global macroeconomic news and financial market updates from 7 days before {curr_date} to {curr_date}. 
Focus on central bank decisions, economic indicators, geopolitical events, and market-moving news that would be important for trading decisions.", + config=GenerateContentConfig( + tools=[google_search_tool], + response_modalities=["TEXT"] + ) + ) + + # Extract text from response + result_text = "" + for part in response.candidates[0].content.parts: + if hasattr(part, 'text'): + result_text += part.text + + return result_text + + except Exception as e: + return f"Error retrieving global news: {str(e)}" + else: + # For OpenAI and other APIs, use original implementation + client = OpenAI(base_url=config["backend_url"]) - response = client.responses.create( - model=config["quick_think_llm"], - input=[ - { - "role": "system", - "content": [ - { - "type": "input_text", - "text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.", - } - ], - } - ], - text={"format": {"type": "text"}}, - reasoning={}, - tools=[ - { - "type": "web_search_preview", - "user_location": {"type": "approximate"}, - "search_context_size": "low", - } - ], - temperature=1, - max_output_tokens=4096, - top_p=1, - store=True, - ) + response = client.responses.create( + model=config["quick_think_llm"], + input=[ + { + "role": "system", + "content": [ + { + "type": "input_text", + "text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? 
Make sure you only get the data posted during that period.", + } + ], + } + ], + text={"format": {"type": "text"}}, + reasoning={}, + tools=[ + { + "type": "web_search_preview", + "user_location": {"type": "approximate"}, + "search_context_size": "low", + } + ], + temperature=1, + max_output_tokens=4096, + top_p=1, + store=True, + ) - return response.output[1].content[0].text + return response.output[1].content[0].text -def get_fundamentals_openai(ticker, curr_date): +def get_fundamentals(ticker, curr_date): config = get_config() - client = OpenAI(base_url=config["backend_url"]) + + # Check if using Google API - implement grounding with Google Search + if "generativelanguage.googleapis.com" in config["backend_url"]: + try: + from google import genai + from google.genai.types import Tool, GenerateContentConfig, GoogleSearch + + client = genai.Client() + + # Create Google Search grounding tool + google_search_tool = Tool( + google_search=GoogleSearch() + ) + + # Generate content with grounding + response = client.models.generate_content( + model=config["quick_think_llm"], + contents=f"Search for fundamental analysis data and financial metrics for {ticker} stock from the month before {curr_date} to the month of {curr_date}. Look for earnings reports, financial ratios like PE, PS, cash flow, revenue growth, analyst ratings, and any fundamental analysis discussions. 
Please present key metrics in a structured format.", + config=GenerateContentConfig( + tools=[google_search_tool], + response_modalities=["TEXT"] + ) + ) + + # Extract text from response + result_text = "" + for part in response.candidates[0].content.parts: + if hasattr(part, 'text'): + result_text += part.text + + return result_text + + except Exception as e: + return f"Error retrieving fundamentals for {ticker}: {str(e)}" + else: + # For OpenAI and other APIs, use original implementation + client = OpenAI(base_url=config["backend_url"]) - response = client.responses.create( - model=config["quick_think_llm"], - input=[ - { - "role": "system", - "content": [ - { - "type": "input_text", - "text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc", - } - ], - } - ], - text={"format": {"type": "text"}}, - reasoning={}, - tools=[ - { - "type": "web_search_preview", - "user_location": {"type": "approximate"}, - "search_context_size": "low", - } - ], - temperature=1, - max_output_tokens=4096, - top_p=1, - store=True, - ) + response = client.responses.create( + model=config["quick_think_llm"], + input=[ + { + "role": "system", + "content": [ + { + "type": "input_text", + "text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. 
List as a table, with PE/PS/Cash flow/ etc", + } + ], + } + ], + text={"format": {"type": "text"}}, + reasoning={}, + tools=[ + { + "type": "web_search_preview", + "user_location": {"type": "approximate"}, + "search_context_size": "low", + } + ], + temperature=1, + max_output_tokens=4096, + top_p=1, + store=True, + ) - return response.output[1].content[0].text + return response.output[1].content[0].text diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index eb06cf43..6456ce14 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -125,7 +125,7 @@ class TradingAgentsGraph: "social": ToolNode( [ # online tools - self.toolkit.get_stock_news_openai, + self.toolkit.get_stock_news, # offline tools self.toolkit.get_reddit_stock_info, ] @@ -133,7 +133,7 @@ class TradingAgentsGraph: "news": ToolNode( [ # online tools - self.toolkit.get_global_news_openai, + self.toolkit.get_global_news, self.toolkit.get_google_news, # offline tools self.toolkit.get_finnhub_news, @@ -143,7 +143,7 @@ class TradingAgentsGraph: "fundamentals": ToolNode( [ # online tools - self.toolkit.get_fundamentals_openai, + self.toolkit.get_fundamentals, # offline tools self.toolkit.get_finnhub_company_insider_sentiment, self.toolkit.get_finnhub_company_insider_transactions,