[add] search using GoogleSearch grounding

kimheesu 2025-07-01 10:59:44 +09:00
parent 6a1f88da24
commit a679d61724
11 changed files with 231 additions and 252 deletions

1
.gitignore vendored
View File

@@ -6,3 +6,4 @@ src/
eval_results/
eval_data/
*.egg-info/
results/

View File

@@ -140,8 +140,9 @@ def select_shallow_thinking_agent(provider) -> str:
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
],
"google": [
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
("Gemini 2.5 Flash-Lite - Cost efficiency and low latency", "gemini-2.5-flash-lite-preview-06-17"),
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash"),
],
"openrouter": [
@@ -205,7 +206,7 @@ def select_deep_thinking_agent(provider) -> str:
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
("Gemini 2.5 Pro", "gemini-2.5-pro"),
("Gemini 2.5 Pro - Most powerful Gemini model", "gemini-2.5-pro"),
],
"openrouter": [
("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"),

15
main.py
View File

@@ -1,14 +1,17 @@
from tradingagents.graph.trading_graph import TradingAgentsGraph
from tradingagents.default_config import DEFAULT_CONFIG
from dotenv import load_dotenv
import os
load_dotenv()
# Create a custom config
config = DEFAULT_CONFIG.copy()
config["llm_provider"] = "google" # Use a different model
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
config["deep_think_llm"] = "gemini-2.5-flash" # Use a different model
config["quick_think_llm"] = "gemini-2.5-flash" # Use a different model
config["max_debate_rounds"] = 1 # Increase debate rounds
config["online_tools"] = True # Increase debate rounds
config["llm_provider"] = os.getenv("LLM_PROVIDER", "openai") # Use a different model
config["backend_url"] = os.getenv("BACKEND_URL", "https://api.openai.com/v1") # Use a different backend
config["deep_think_llm"] = os.getenv("DEEP_THINK_LLM", "o4-mini") # Use a different model
config["quick_think_llm"] = os.getenv("QUICK_THINK_LLM", "gpt-4o-mini") # Use a different model
config["max_debate_rounds"] = int(os.getenv("MAX_DEBATE_ROUNDS", 1)) # Increase debate rounds
config["online_tools"] = bool(os.getenv("ONLINE_TOOLS", "True")) # Increase debate rounds
# Initialize with custom config
ta = TradingAgentsGraph(debug=True, config=config)
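Because main.py now calls load_dotenv() and reads every setting from the environment, the Gemini setup that was previously hard-coded can live in a .env file instead. A minimal sketch, with values mirroring the removed hard-coded lines (adjust as needed):

LLM_PROVIDER=google
BACKEND_URL=https://generativelanguage.googleapis.com/v1
DEEP_THINK_LLM=gemini-2.5-flash
QUICK_THINK_LLM=gemini-2.5-flash
MAX_DEBATE_ROUNDS=1
ONLINE_TOOLS=True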

View File

@@ -0,0 +1,20 @@
from typing import Any
from .embedding_providers import (
EmbeddingProvider,
OpenAIEmbeddingProvider,
GeminiEmbeddingProvider,
OllamaEmbeddingProvider
)
class EmbeddingProviderFactory:
@staticmethod
def create_provider(config: dict[str, Any]) -> EmbeddingProvider:
backend_url = config["backend_url"]
if "generativelanguage.googleapis.com" in backend_url:
return GeminiEmbeddingProvider(backend_url)
elif "localhost:11434" in backend_url:
return OllamaEmbeddingProvider(backend_url)
else:
return OpenAIEmbeddingProvider(backend_url)
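A minimal sketch of the factory's URL-based dispatch; the absolute import path is an assumption, since the commit only shows relative imports, and genai.Client() needs GEMINI_API_KEY or GOOGLE_API_KEY in the environment:

# hypothetical import path; the commit only shows the relative import
from tradingagents.agents.utils.embedding_provider_factory import EmbeddingProviderFactory

config = {"backend_url": "https://generativelanguage.googleapis.com/v1"}
provider = EmbeddingProviderFactory.create_provider(config)  # returns GeminiEmbeddingProvider for this URL
print(provider.model_name)  # "gemini-embedding-exp-03-07" by default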

View File

@@ -0,0 +1,66 @@
from abc import ABC, abstractmethod
from openai import OpenAI
from google import genai
class EmbeddingProvider(ABC):
@abstractmethod
def get_embedding(self, text: str)->list[float]:
pass
@property
@abstractmethod
def model_name(self)->str:
pass
class OpenAIEmbeddingProvider(EmbeddingProvider):
def __init__(self, backend_url: str, embedding_model: str = "text-embedding-3-small"):
self.client = OpenAI(base_url=backend_url)
self._embedding_model = embedding_model
def get_embedding(self, text: str)->list[float]:
response = self.client.embeddings.create(
model=self._embedding_model,
input=text
)
return response.data[0].embedding
@property
def model_name(self)->str:
return self._embedding_model
class GeminiEmbeddingProvider(EmbeddingProvider):
def __init__(self, backend_url: str, embedding_model: str = "gemini-embedding-exp-03-07"):
self.client = genai.Client()
self._embedding_model = embedding_model
def get_embedding(self, text: str)->list[float]:
response = self.client.models.embed_content(
model=self._embedding_model,
contents=text
)
return response.embeddings[0].values
@property
def model_name(self)->str:
return self._embedding_model
class OllamaEmbeddingProvider(EmbeddingProvider):
def __init__(self, backend_url: str, embedding_model: str = "nomic-embed-text"):
self.client = OpenAI(base_url=backend_url)
self._embedding_model = embedding_model
def get_embedding(self, text: str)->list[float]:
response = self.client.embeddings.create(
model=self._embedding_model,
input=text
)
return response.data[0].embedding
@property
def model_name(self)->str:
return self._embedding_model
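All three providers expose the same two members, so callers never branch on the backend. A small sketch using the Ollama path, which reuses the OpenAI-compatible /v1 embeddings endpoint; it assumes a local Ollama server with nomic-embed-text pulled, and OPENAI_API_KEY set to any placeholder value since the provider passes no key to the OpenAI client:

ollama = OllamaEmbeddingProvider("http://localhost:11434/v1")
vector = ollama.get_embedding("Rising rates with weakening tech earnings")
print(ollama.model_name, len(vector))  # model name and embedding dimensionality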

View File

@@ -2,6 +2,7 @@ import chromadb
from chromadb.config import Settings
from openai import OpenAI
import os
from .embedding_provider_factory import EmbeddingProviderFactory
from google import genai
class FinancialSituationMemory:
@@ -9,29 +10,7 @@ class FinancialSituationMemory:
self.config = config
self.backend_url = config["backend_url"]
# Determine embedding configuration based on provider
if self.backend_url == "http://localhost:11434/v1":
# Ollama
self.embedding_model = "nomic-embed-text"
self.use_openai_api = True
elif "openai.com" in self.backend_url:
# OpenAI
self.embedding_model = "text-embedding-3-small"
self.use_openai_api = True
elif "generativelanguage.googleapis.com" in self.backend_url:
# Google Gemini API
self.embedding_model = "gemini-embedding-exp-03-07" # Use Google's embedding model
self.use_openai_api = False
else:
# Default to OpenAI-compatible
self.embedding_model = "text-embedding-3-small"
self.use_openai_api = True
# Initialize clients
if self.use_openai_api:
self.client = OpenAI(base_url=self.backend_url)
else:
self.client = genai.Client()
self.embedding_provider = EmbeddingProviderFactory.create_provider(config)
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name)
@@ -39,18 +18,7 @@
def get_embedding(self, text):
"""Get embedding for a text using the appropriate API"""
if self.use_openai_api:
# Use OpenAI-compatible API
response = self.client.embeddings.create(
model=self.embedding_model, input=text
)
return response.data[0].embedding
else:
response = self.client.models.embed_content(
model=self.embedding_model,
contents=text
)
return response.embeddings[0].values
return self.embedding_provider.get_embedding(text)
def add_situations(self, situations_and_advice):
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""

View File

@@ -14,6 +14,7 @@ from tqdm import tqdm
import yfinance as yf
from openai import OpenAI
from .config import get_config, set_config, DATA_DIR
from .search_provider_factory import SearchProviderFactory
def get_finnhub_news(
@@ -704,212 +705,24 @@ def get_YFin_data(
def get_stock_news(ticker, curr_date):
config = get_config()
search_provider = SearchProviderFactory.create_provider(config)
query = f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period."
return search_provider.search(query, ticker, curr_date)
# Check if using Google API - implement grounding with Google Search
if "generativelanguage.googleapis.com" in config["backend_url"]:
try:
from google import genai
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
client = genai.Client()
# Create Google Search grounding tool
google_search_tool = Tool(
google_search=GoogleSearch()
)
# Generate content with grounding
response = client.models.generate_content(
model=config["quick_think_llm"],
contents=f"Can you search for recent social media and news about {ticker} stock from 7 days before {curr_date} to {curr_date}? Focus on sentiment, price movements, and any significant developments that could impact trading decisions.",
config=GenerateContentConfig(
tools=[google_search_tool],
response_modalities=["TEXT"]
)
)
# Extract text from response
result_text = ""
for part in response.candidates[0].content.parts:
if hasattr(part, 'text'):
result_text += part.text
return result_text
except Exception as e:
return f"Error retrieving stock news for {ticker}: {str(e)}"
else:
# For OpenAI and other APIs, use original implementation
client = OpenAI(base_url=config["backend_url"])
response = client.responses.create(
model=config["quick_think_llm"],
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.",
}
],
}
],
text={"format": {"type": "text"}},
reasoning={},
tools=[
{
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
],
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
)
return response.output[1].content[0].text
def get_global_news(curr_date):
config = get_config()
search_provider = SearchProviderFactory.create_provider(config)
query = f"Search for global macroeconomic news and financial market updates from 7 days before {curr_date} to {curr_date}. Focus on central bank decisions, economic indicators, geopolitical events, and market-moving news that would be important for trading decisions."
return search_provider.search(query, "", curr_date)  # global news has no ticker; both providers use only the query
# Check if using Google API - implement grounding with Google Search
if "generativelanguage.googleapis.com" in config["backend_url"]:
try:
from google import genai
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
client = genai.Client()
# Create Google Search grounding tool
google_search_tool = Tool(
google_search=GoogleSearch()
)
# Generate content with grounding
response = client.models.generate_content(
model=config["quick_think_llm"],
contents=f"Search for global macroeconomic news and financial market updates from 7 days before {curr_date} to {curr_date}. Focus on central bank decisions, economic indicators, geopolitical events, and market-moving news that would be important for trading decisions.",
config=GenerateContentConfig(
tools=[google_search_tool],
response_modalities=["TEXT"]
)
)
# Extract text from response
result_text = ""
for part in response.candidates[0].content.parts:
if hasattr(part, 'text'):
result_text += part.text
return result_text
except Exception as e:
return f"Error retrieving global news: {str(e)}"
else:
# For OpenAI and other APIs, use original implementation
client = OpenAI(base_url=config["backend_url"])
response = client.responses.create(
model=config["quick_think_llm"],
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.",
}
],
}
],
text={"format": {"type": "text"}},
reasoning={},
tools=[
{
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
],
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
)
return response.output[1].content[0].text
def get_fundamentals(ticker, curr_date):
config = get_config()
# Check if using Google API - implement grounding with Google Search
if "generativelanguage.googleapis.com" in config["backend_url"]:
try:
from google import genai
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
client = genai.Client()
# Create Google Search grounding tool
google_search_tool = Tool(
google_search=GoogleSearch()
)
# Generate content with grounding
response = client.models.generate_content(
model=config["quick_think_llm"],
contents=f"Search for fundamental analysis data and financial metrics for {ticker} stock from the month before {curr_date} to the month of {curr_date}. Look for earnings reports, financial ratios like PE, PS, cash flow, revenue growth, analyst ratings, and any fundamental analysis discussions. Please present key metrics in a structured format.",
config=GenerateContentConfig(
tools=[google_search_tool],
response_modalities=["TEXT"]
)
)
# Extract text from response
result_text = ""
for part in response.candidates[0].content.parts:
if hasattr(part, 'text'):
result_text += part.text
return result_text
except Exception as e:
return f"Error retrieving fundamentals for {ticker}: {str(e)}"
else:
# For OpenAI and other APIs, use original implementation
client = OpenAI(base_url=config["backend_url"])
response = client.responses.create(
model=config["quick_think_llm"],
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
}
],
}
],
text={"format": {"type": "text"}},
reasoning={},
tools=[
{
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
],
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
)
return response.output[1].content[0].text
search_provider = SearchProviderFactory.create_provider(config)
query = f"Search for fundamental analysis data and financial metrics for {ticker} stock from the month before {curr_date} to the month of {curr_date}. Look for earnings reports, financial ratios like PE, PS, cash flow, revenue growth, analyst ratings, and any fundamental analysis discussions. Please present key metrics in a structured format."
return search_provider.search(query, ticker, curr_date)
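get_stock_news, get_global_news, and get_fundamentals now share the same three-step body: read the config, build a provider, send a query. A sketch of what another tool would look like under that convention (the function name and query below are illustrative, not part of the commit):

def get_analyst_ratings(ticker, curr_date):  # hypothetical helper, for illustration only
    config = get_config()
    search_provider = SearchProviderFactory.create_provider(config)
    query = (
        f"Search for analyst ratings and price-target changes for {ticker} "
        f"from 7 days before {curr_date} to {curr_date}."
    )
    return search_provider.search(query, ticker, curr_date)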

View File

@@ -0,0 +1,76 @@
from google import genai
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
from openai import OpenAI
from abc import ABC, abstractmethod
class SearchProvider(ABC):
@abstractmethod
def search(self, query: str, ticker: str, curr_date: str) -> str:
pass
class GoogleSearchProvider(SearchProvider):
def __init__(self, model: str):
self.client = genai.Client()
self.model = model
def search(self, query: str, ticker: str, curr_date: str) -> str:
google_search_tool = Tool(
google_search=GoogleSearch()
)
response = self.client.models.generate_content(
model=self.model,
contents=query,
config=GenerateContentConfig(
tools=[google_search_tool],
response_modalities=["TEXT"]
)
)
result_text = ""
for part in response.candidates[0].content.parts:
if hasattr(part, 'text'):
result_text += part.text
return result_text
class OpenAISearchProvider(SearchProvider):
def __init__(self, model: str, backend_url: str):
self.client = OpenAI(base_url=backend_url)
self.model = model
def search(self, query: str, ticker: str, curr_date: str) -> str:
response = self.client.responses.create(
model=self.model,
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": query
}
],
}
],
text={"format": {"type": "text"}},
reasoning={},
tools=[
{
"type": "web_search_preview",
"user_location": {"type": "approximate"},
"search_context_size": "low",
}
],
temperature=1,
max_output_tokens=4096,
top_p=1,
store=True,
)
return response.output[1].content[0].text
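A direct-use sketch of the grounded provider; it assumes GEMINI_API_KEY (or GOOGLE_API_KEY) is available to genai.Client(), and the model name, ticker, and dates are only examples:

provider = GoogleSearchProvider("gemini-2.5-flash")
report = provider.search(
    "Search for NVDA stock news from 2025-06-24 to 2025-07-01 and summarize the sentiment.",
    "NVDA",
    "2025-07-01",
)
print(report)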

View File

@@ -0,0 +1,19 @@
from typing import Any
from .search_provider import (
SearchProvider,
GoogleSearchProvider,
OpenAISearchProvider
)
class SearchProviderFactory:
@staticmethod
def create_provider(config: dict[str, Any]) -> SearchProvider:
backend_url = config["backend_url"]
model = config["quick_think_llm"]
if "generativelanguage.googleapis.com" in backend_url:
return GoogleSearchProvider(model)
else:
return OpenAISearchProvider(model, backend_url)
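The factory applies the same backend_url heuristic as the embedding factory, so a Gemini backend automatically gets grounded search while everything else falls back to the OpenAI web_search_preview tool. A minimal sketch:

config = {
    "backend_url": "https://generativelanguage.googleapis.com/v1",
    "quick_think_llm": "gemini-2.5-flash",
}
provider = SearchProviderFactory.create_provider(config)  # GoogleSearchProvider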

View File

@@ -1,4 +1,6 @@
import os
from dotenv import load_dotenv
load_dotenv()
DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
@@ -9,14 +11,14 @@ DEFAULT_CONFIG = {
"dataflows/data_cache",
),
# LLM settings
"llm_provider": "openai",
"deep_think_llm": "o4-mini",
"quick_think_llm": "gpt-4o-mini",
"backend_url": "https://api.openai.com/v1",
"llm_provider": os.getenv("LLM_PROVIDER", "openai"),
"deep_think_llm": os.getenv("DEEP_THINK_LLM", "o4-mini"),
"quick_think_llm": os.getenv("QUICK_THINK_LLM", "gpt-4o-mini"),
"backend_url": os.getenv("BACKEND_URL", "https://api.openai.com/v1"),
# Debate and discussion settings
"max_debate_rounds": 1,
"max_debate_rounds": int(os.getenv("MAX_DEBATE_ROUNDS", 1)),
"max_risk_discuss_rounds": 1,
"max_recur_limit": 100,
# Tool settings
"online_tools": True,
"online_tools": bool(os.getenv("ONLINE_TOOLS", "True")),
}
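ONLINE_TOOLS arrives from the environment as a string, which is why it is compared against "true" rather than passed to bool(); a short illustration (plain Python semantics, nothing repo-specific):

bool("False")                                         # True: any non-empty string is truthy
os.getenv("ONLINE_TOOLS", "True").lower() == "true"   # False when ONLINE_TOOLS=False is set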

View File

@@ -170,10 +170,20 @@ class TradingAgentsGraph:
trace = []
for chunk in self.graph.stream(init_agent_state, **args):
if len(chunk["messages"]) == 0:
pass
else:
chunk["messages"][-1].pretty_print()
trace.append(chunk)
continue
message = chunk["messages"][-1]
# Filter out duplicate or empty messages
if message.content and message.content.strip():
# Print the FINAL TRANSACTION PROPOSAL only once
if "FINAL TRANSACTION PROPOSAL:" in message.content:
if not hasattr(self, '_final_printed'):
message.pretty_print()
self._final_printed = True
else:
message.pretty_print()
trace.append(chunk)
final_state = trace[-1]
else: