[add] searh using GoogleSearch grounding
This commit is contained in:
parent
6a1f88da24
commit
a679d61724
|
|
@ -6,3 +6,4 @@ src/
|
||||||
eval_results/
|
eval_results/
|
||||||
eval_data/
|
eval_data/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
|
results/
|
||||||
|
|
@ -140,8 +140,9 @@ def select_shallow_thinking_agent(provider) -> str:
|
||||||
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
|
("Claude Sonnet 4 - High performance and excellent reasoning", "claude-sonnet-4-0"),
|
||||||
],
|
],
|
||||||
"google": [
|
"google": [
|
||||||
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
|
|
||||||
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
|
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
|
||||||
|
("Gemini 2.5 Flash-Lite - Cost efficiency and low latency", "gemini-2.5-flash-lite-preview-06-17"),
|
||||||
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash"),
|
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash"),
|
||||||
],
|
],
|
||||||
"openrouter": [
|
"openrouter": [
|
||||||
|
|
@ -205,7 +206,7 @@ def select_deep_thinking_agent(provider) -> str:
|
||||||
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
|
("Gemini 2.0 Flash-Lite - Cost efficiency and low latency", "gemini-2.0-flash-lite"),
|
||||||
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
|
("Gemini 2.0 Flash - Next generation features, speed, and thinking", "gemini-2.0-flash"),
|
||||||
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
|
("Gemini 2.5 Flash - Adaptive thinking, cost efficiency", "gemini-2.5-flash-preview-05-20"),
|
||||||
("Gemini 2.5 Pro", "gemini-2.5-pro"),
|
("Gemini 2.5 Pro - Most powerful Gemini model", "gemini-2.5-pro"),
|
||||||
],
|
],
|
||||||
"openrouter": [
|
"openrouter": [
|
||||||
("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"),
|
("DeepSeek V3 - a 685B-parameter, mixture-of-experts model", "deepseek/deepseek-chat-v3-0324:free"),
|
||||||
|
|
|
||||||
15
main.py
15
main.py
|
|
@ -1,14 +1,17 @@
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
# Create a custom config
|
# Create a custom config
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config["llm_provider"] = "google" # Use a different model
|
config["llm_provider"] = os.getenv("LLM_PROVIDER", "openai") # Use a different model
|
||||||
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
|
config["backend_url"] = os.getenv("BACKEND_URL", "https://api.openai.com/v1") # Use a different backend
|
||||||
config["deep_think_llm"] = "gemini-2.5-flash" # Use a different model
|
config["deep_think_llm"] = os.getenv("DEEP_THINK_LLM", "o4-mini") # Use a different model
|
||||||
config["quick_think_llm"] = "gemini-2.5-flash" # Use a different model
|
config["quick_think_llm"] = os.getenv("QUICK_THINK_LLM", "gpt-4o-mini") # Use a different model
|
||||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
config["max_debate_rounds"] = int(os.getenv("MAX_DEBATE_ROUNDS", 1)) # Increase debate rounds
|
||||||
config["online_tools"] = True # Increase debate rounds
|
config["online_tools"] = bool(os.getenv("ONLINE_TOOLS", "True")) # Increase debate rounds
|
||||||
|
|
||||||
# Initialize with custom config
|
# Initialize with custom config
|
||||||
ta = TradingAgentsGraph(debug=True, config=config)
|
ta = TradingAgentsGraph(debug=True, config=config)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,20 @@
|
||||||
|
from .embedding_providers import (
|
||||||
|
EmbeddingProvider,
|
||||||
|
OpenAIEmbeddingProvider,
|
||||||
|
GeminiEmbeddingProvider,
|
||||||
|
OllamaEmbeddingProvider
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingProviderFactory:
|
||||||
|
@staticmethod
|
||||||
|
def create_provider(config : dict[str, any])->EmbeddingProvider:
|
||||||
|
backend_url = config["backend_url"]
|
||||||
|
|
||||||
|
if "generativelanguage.googleapis.com" in backend_url:
|
||||||
|
return GeminiEmbeddingProvider(backend_url)
|
||||||
|
elif "localhost:11434" in backend_url:
|
||||||
|
return OllamaEmbeddingProvider(backend_url)
|
||||||
|
else:
|
||||||
|
return OpenAIEmbeddingProvider(backend_url)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,66 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from openai import OpenAI
|
||||||
|
from google import genai
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingProvider(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def get_embedding(self, text: str)->list[float]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def model_name(self)->str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||||
|
def __init__(self, backend_url: str, embedding_model: str = "text-embedding-3-small"):
|
||||||
|
self.client = OpenAI(base_url=backend_url)
|
||||||
|
self._embedding_model = embedding_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding(self, text: str)->list[float]:
|
||||||
|
response = self.client.embeddings.create(
|
||||||
|
model=self._embedding_model,
|
||||||
|
input=text
|
||||||
|
)
|
||||||
|
return response.data[0].embedding
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_name(self)->str:
|
||||||
|
return self._embedding_model
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||||
|
def __init__(self, backend_url: str, embedding_model: str = "gemini-embedding-exp-03-07"):
|
||||||
|
self.client = genai.Client()
|
||||||
|
self._embedding_model = embedding_model
|
||||||
|
|
||||||
|
def get_embedding(self, text: str)->list[float]:
|
||||||
|
response = self.client.models.embed_content(
|
||||||
|
model=self._embedding_model,
|
||||||
|
contents=text
|
||||||
|
)
|
||||||
|
return response.embeddings[0].values
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_name(self)->str:
|
||||||
|
return self._embedding_model
|
||||||
|
|
||||||
|
class OllamaEmbeddingProvider(EmbeddingProvider):
|
||||||
|
def __init__(self, backend_url: str, embedding_model: str = "nomic-embed-text"):
|
||||||
|
self.client = OpenAI(base_url=backend_url)
|
||||||
|
self._embedding_model = embedding_model
|
||||||
|
|
||||||
|
def get_embedding(self, text: str)->list[float]:
|
||||||
|
response = self.client.embeddings.create(
|
||||||
|
model=self._embedding_model,
|
||||||
|
input=text
|
||||||
|
)
|
||||||
|
return response.data[0].embedding
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_name(self)->str:
|
||||||
|
return self._embedding_model
|
||||||
|
|
||||||
|
|
@ -2,6 +2,7 @@ import chromadb
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import os
|
import os
|
||||||
|
from .embedding_provider_factory import EmbeddingProviderFactory
|
||||||
from google import genai
|
from google import genai
|
||||||
|
|
||||||
class FinancialSituationMemory:
|
class FinancialSituationMemory:
|
||||||
|
|
@ -9,29 +10,7 @@ class FinancialSituationMemory:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.backend_url = config["backend_url"]
|
self.backend_url = config["backend_url"]
|
||||||
|
|
||||||
# Determine embedding configuration based on provider
|
self.embedding_provider = EmbeddingProviderFactory.create_provider(config)
|
||||||
if self.backend_url == "http://localhost:11434/v1":
|
|
||||||
# Ollama
|
|
||||||
self.embedding_model = "nomic-embed-text"
|
|
||||||
self.use_openai_api = True
|
|
||||||
elif "openai.com" in self.backend_url:
|
|
||||||
# OpenAI
|
|
||||||
self.embedding_model = "text-embedding-3-small"
|
|
||||||
self.use_openai_api = True
|
|
||||||
elif "generativelanguage.googleapis.com" in self.backend_url:
|
|
||||||
# Google Gemini API
|
|
||||||
self.embedding_model = "gemini-embedding-exp-03-07" # Use Google's embedding model
|
|
||||||
self.use_openai_api = False
|
|
||||||
else:
|
|
||||||
# Default to OpenAI-compatible
|
|
||||||
self.embedding_model = "text-embedding-3-small"
|
|
||||||
self.use_openai_api = True
|
|
||||||
|
|
||||||
# Initialize clients
|
|
||||||
if self.use_openai_api:
|
|
||||||
self.client = OpenAI(base_url=self.backend_url)
|
|
||||||
else:
|
|
||||||
self.client = genai.Client()
|
|
||||||
|
|
||||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||||
|
|
@ -39,18 +18,7 @@ class FinancialSituationMemory:
|
||||||
def get_embedding(self, text):
|
def get_embedding(self, text):
|
||||||
"""Get embedding for a text using the appropriate API"""
|
"""Get embedding for a text using the appropriate API"""
|
||||||
|
|
||||||
if self.use_openai_api:
|
return self.embedding_provider.get_embedding(text)
|
||||||
# Use OpenAI-compatible API
|
|
||||||
response = self.client.embeddings.create(
|
|
||||||
model=self.embedding_model, input=text
|
|
||||||
)
|
|
||||||
return response.data[0].embedding
|
|
||||||
else:
|
|
||||||
response = self.client.models.embed_content(
|
|
||||||
model=self.embedding_model,
|
|
||||||
contents=text
|
|
||||||
)
|
|
||||||
return response.embeddings[0].values
|
|
||||||
|
|
||||||
def add_situations(self, situations_and_advice):
|
def add_situations(self, situations_and_advice):
|
||||||
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
|
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ from tqdm import tqdm
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from .config import get_config, set_config, DATA_DIR
|
from .config import get_config, set_config, DATA_DIR
|
||||||
|
from .search_provider_factory import SearchProviderFactory
|
||||||
|
|
||||||
|
|
||||||
def get_finnhub_news(
|
def get_finnhub_news(
|
||||||
|
|
@ -704,212 +705,24 @@ def get_YFin_data(
|
||||||
|
|
||||||
def get_stock_news(ticker, curr_date):
|
def get_stock_news(ticker, curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
|
search_provider = SearchProviderFactory.create_provider(config)
|
||||||
|
query = f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period."
|
||||||
|
return search_provider.search(query, ticker, curr_date)
|
||||||
|
|
||||||
# Check if using Google API - implement grounding with Google Search
|
|
||||||
if "generativelanguage.googleapis.com" in config["backend_url"]:
|
|
||||||
try:
|
|
||||||
from google import genai
|
|
||||||
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
|
|
||||||
|
|
||||||
client = genai.Client()
|
|
||||||
|
|
||||||
# Create Google Search grounding tool
|
|
||||||
google_search_tool = Tool(
|
|
||||||
google_search=GoogleSearch()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate content with grounding
|
|
||||||
response = client.models.generate_content(
|
|
||||||
model=config["quick_think_llm"],
|
|
||||||
contents=f"Can you search for recent social media and news about {ticker} stock from 7 days before {curr_date} to {curr_date}? Focus on sentiment, price movements, and any significant developments that could impact trading decisions.",
|
|
||||||
config=GenerateContentConfig(
|
|
||||||
tools=[google_search_tool],
|
|
||||||
response_modalities=["TEXT"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract text from response
|
|
||||||
result_text = ""
|
|
||||||
for part in response.candidates[0].content.parts:
|
|
||||||
if hasattr(part, 'text'):
|
|
||||||
result_text += part.text
|
|
||||||
|
|
||||||
return result_text
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return f"Error retrieving stock news for {ticker}: {str(e)}"
|
|
||||||
else:
|
|
||||||
# For OpenAI and other APIs, use original implementation
|
|
||||||
client = OpenAI(base_url=config["backend_url"])
|
|
||||||
|
|
||||||
response = client.responses.create(
|
|
||||||
model=config["quick_think_llm"],
|
|
||||||
input=[
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "input_text",
|
|
||||||
"text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
text={"format": {"type": "text"}},
|
|
||||||
reasoning={},
|
|
||||||
tools=[
|
|
||||||
{
|
|
||||||
"type": "web_search_preview",
|
|
||||||
"user_location": {"type": "approximate"},
|
|
||||||
"search_context_size": "low",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
temperature=1,
|
|
||||||
max_output_tokens=4096,
|
|
||||||
top_p=1,
|
|
||||||
store=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return response.output[1].content[0].text
|
|
||||||
|
|
||||||
|
|
||||||
def get_global_news(curr_date):
|
def get_global_news(curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
|
search_provider = SearchProviderFactory.create_provider(config)
|
||||||
|
query = f"Search for global macroeconomic news and financial market updates from 7 days before {curr_date} to {curr_date}. Focus on central bank decisions, economic indicators, geopolitical events, and market-moving news that would be important for trading decisions."
|
||||||
|
return search_provider.search(query, curr_date)
|
||||||
|
|
||||||
# Check if using Google API - implement grounding with Google Search
|
|
||||||
if "generativelanguage.googleapis.com" in config["backend_url"]:
|
|
||||||
try:
|
|
||||||
from google import genai
|
|
||||||
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
|
|
||||||
|
|
||||||
client = genai.Client()
|
|
||||||
|
|
||||||
# Create Google Search grounding tool
|
|
||||||
google_search_tool = Tool(
|
|
||||||
google_search=GoogleSearch()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate content with grounding
|
|
||||||
response = client.models.generate_content(
|
|
||||||
model=config["quick_think_llm"],
|
|
||||||
contents=f"Search for global macroeconomic news and financial market updates from 7 days before {curr_date} to {curr_date}. Focus on central bank decisions, economic indicators, geopolitical events, and market-moving news that would be important for trading decisions.",
|
|
||||||
config=GenerateContentConfig(
|
|
||||||
tools=[google_search_tool],
|
|
||||||
response_modalities=["TEXT"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract text from response
|
|
||||||
result_text = ""
|
|
||||||
for part in response.candidates[0].content.parts:
|
|
||||||
if hasattr(part, 'text'):
|
|
||||||
result_text += part.text
|
|
||||||
|
|
||||||
return result_text
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return f"Error retrieving global news: {str(e)}"
|
|
||||||
else:
|
|
||||||
# For OpenAI and other APIs, use original implementation
|
|
||||||
client = OpenAI(base_url=config["backend_url"])
|
|
||||||
|
|
||||||
response = client.responses.create(
|
|
||||||
model=config["quick_think_llm"],
|
|
||||||
input=[
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "input_text",
|
|
||||||
"text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
text={"format": {"type": "text"}},
|
|
||||||
reasoning={},
|
|
||||||
tools=[
|
|
||||||
{
|
|
||||||
"type": "web_search_preview",
|
|
||||||
"user_location": {"type": "approximate"},
|
|
||||||
"search_context_size": "low",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
temperature=1,
|
|
||||||
max_output_tokens=4096,
|
|
||||||
top_p=1,
|
|
||||||
store=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return response.output[1].content[0].text
|
|
||||||
|
|
||||||
|
|
||||||
def get_fundamentals(ticker, curr_date):
|
def get_fundamentals(ticker, curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
|
search_provider = SearchProviderFactory.create_provider(config)
|
||||||
|
query = f"Search for fundamental analysis data and financial metrics for {ticker} stock from the month before {curr_date} to the month of {curr_date}. Look for earnings reports, financial ratios like PE, PS, cash flow, revenue growth, analyst ratings, and any fundamental analysis discussions. Please present key metrics in a structured format."
|
||||||
|
return search_provider.search(query, ticker, curr_date)
|
||||||
|
|
||||||
# Check if using Google API - implement grounding with Google Search
|
|
||||||
if "generativelanguage.googleapis.com" in config["backend_url"]:
|
|
||||||
try:
|
|
||||||
from google import genai
|
|
||||||
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
|
|
||||||
|
|
||||||
client = genai.Client()
|
|
||||||
|
|
||||||
# Create Google Search grounding tool
|
|
||||||
google_search_tool = Tool(
|
|
||||||
google_search=GoogleSearch()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate content with grounding
|
|
||||||
response = client.models.generate_content(
|
|
||||||
model=config["quick_think_llm"],
|
|
||||||
contents=f"Search for fundamental analysis data and financial metrics for {ticker} stock from the month before {curr_date} to the month of {curr_date}. Look for earnings reports, financial ratios like PE, PS, cash flow, revenue growth, analyst ratings, and any fundamental analysis discussions. Please present key metrics in a structured format.",
|
|
||||||
config=GenerateContentConfig(
|
|
||||||
tools=[google_search_tool],
|
|
||||||
response_modalities=["TEXT"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract text from response
|
|
||||||
result_text = ""
|
|
||||||
for part in response.candidates[0].content.parts:
|
|
||||||
if hasattr(part, 'text'):
|
|
||||||
result_text += part.text
|
|
||||||
|
|
||||||
return result_text
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return f"Error retrieving fundamentals for {ticker}: {str(e)}"
|
|
||||||
else:
|
|
||||||
# For OpenAI and other APIs, use original implementation
|
|
||||||
client = OpenAI(base_url=config["backend_url"])
|
|
||||||
|
|
||||||
response = client.responses.create(
|
|
||||||
model=config["quick_think_llm"],
|
|
||||||
input=[
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "input_text",
|
|
||||||
"text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
text={"format": {"type": "text"}},
|
|
||||||
reasoning={},
|
|
||||||
tools=[
|
|
||||||
{
|
|
||||||
"type": "web_search_preview",
|
|
||||||
"user_location": {"type": "approximate"},
|
|
||||||
"search_context_size": "low",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
temperature=1,
|
|
||||||
max_output_tokens=4096,
|
|
||||||
top_p=1,
|
|
||||||
store=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return response.output[1].content[0].text
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
from google import genai
|
||||||
|
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
|
||||||
|
from openai import OpenAI
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SearchProvider(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def search(self, query: str, ticker: str, curr_date: str) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleSearchProvider(SearchProvider):
|
||||||
|
def __init__(self, model: str):
|
||||||
|
self.client = genai.Client()
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def search(self, query: str, ticker: str, curr_date: str) -> str:
|
||||||
|
google_search_tool = Tool(
|
||||||
|
google_search=GoogleSearch()
|
||||||
|
)
|
||||||
|
|
||||||
|
response = self.client.models.generate_content(
|
||||||
|
model=self.model,
|
||||||
|
contents=query,
|
||||||
|
config=GenerateContentConfig(
|
||||||
|
tools=[google_search_tool],
|
||||||
|
response_modalities=["TEXT"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
result_text = ""
|
||||||
|
for part in response.candidates[0].content.parts:
|
||||||
|
if hasattr(part, 'text'):
|
||||||
|
result_text += part.text
|
||||||
|
|
||||||
|
return result_text
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAISearchProvider(SearchProvider):
|
||||||
|
def __init__(self, model: str, backend_url: str):
|
||||||
|
self.client = OpenAI(base_url=backend_url)
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def search(self, query: str, ticker: str, curr_date: str) -> str:
|
||||||
|
response = self.client.responses.create(
|
||||||
|
model=self.model,
|
||||||
|
input=[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "input_text",
|
||||||
|
"text": query
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
text={"format": {"type": "text"}},
|
||||||
|
reasoning={},
|
||||||
|
tools=[
|
||||||
|
{
|
||||||
|
"type": "web_search_preview",
|
||||||
|
"user_location": {"type": "approximate"},
|
||||||
|
"search_context_size": "low",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
temperature=1,
|
||||||
|
max_output_tokens=4096,
|
||||||
|
top_p=1,
|
||||||
|
store=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.output[1].content[0].text
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
from .search_provider import (
|
||||||
|
SearchProvider,
|
||||||
|
GoogleSearchProvider,
|
||||||
|
OpenAISearchProvider
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchProviderFactory:
|
||||||
|
@staticmethod
|
||||||
|
def create_provider(config: dict[str, any])->SearchProvider:
|
||||||
|
backend_url = config["backend_url"]
|
||||||
|
model = config["quick_think_llm"]
|
||||||
|
|
||||||
|
if "generativelanguage.googleapis.com" in backend_url:
|
||||||
|
return GoogleSearchProvider(model)
|
||||||
|
else:
|
||||||
|
return OpenAISearchProvider(model, backend_url)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
import os
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
DEFAULT_CONFIG = {
|
DEFAULT_CONFIG = {
|
||||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||||
|
|
@ -9,14 +11,14 @@ DEFAULT_CONFIG = {
|
||||||
"dataflows/data_cache",
|
"dataflows/data_cache",
|
||||||
),
|
),
|
||||||
# LLM settings
|
# LLM settings
|
||||||
"llm_provider": "openai",
|
"llm_provider": os.getenv("LLM_PROVIDER", "openai"),
|
||||||
"deep_think_llm": "o4-mini",
|
"deep_think_llm": os.getenv("DEEP_THINK_LLM", "o4-mini"),
|
||||||
"quick_think_llm": "gpt-4o-mini",
|
"quick_think_llm": os.getenv("QUICK_THINK_LLM", "gpt-4o-mini"),
|
||||||
"backend_url": "https://api.openai.com/v1",
|
"backend_url": os.getenv("BACKEND_URL", "https://api.openai.com/v1"),
|
||||||
# Debate and discussion settings
|
# Debate and discussion settings
|
||||||
"max_debate_rounds": 1,
|
"max_debate_rounds": int(os.getenv("MAX_DEBATE_ROUNDS", 1)),
|
||||||
"max_risk_discuss_rounds": 1,
|
"max_risk_discuss_rounds": 1,
|
||||||
"max_recur_limit": 100,
|
"max_recur_limit": 100,
|
||||||
# Tool settings
|
# Tool settings
|
||||||
"online_tools": True,
|
"online_tools": bool(os.getenv("ONLINE_TOOLS", "True")),
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -170,10 +170,20 @@ class TradingAgentsGraph:
|
||||||
trace = []
|
trace = []
|
||||||
for chunk in self.graph.stream(init_agent_state, **args):
|
for chunk in self.graph.stream(init_agent_state, **args):
|
||||||
if len(chunk["messages"]) == 0:
|
if len(chunk["messages"]) == 0:
|
||||||
pass
|
continue
|
||||||
else:
|
|
||||||
chunk["messages"][-1].pretty_print()
|
message = chunk["messages"][-1]
|
||||||
trace.append(chunk)
|
# 중복 메시지 필터링
|
||||||
|
if message.content and message.content.strip():
|
||||||
|
# FINAL PROPOSAL 중복 방지
|
||||||
|
if "FINAL TRANSACTION PROPOSAL:" in message.content:
|
||||||
|
if not hasattr(self, '_final_printed'):
|
||||||
|
message.pretty_print()
|
||||||
|
self._final_printed = True
|
||||||
|
else:
|
||||||
|
message.pretty_print()
|
||||||
|
|
||||||
|
trace.append(chunk)
|
||||||
|
|
||||||
final_state = trace[-1]
|
final_state = trace[-1]
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue