feat: Partial migration from OpenAI to Gemini
This commit includes the following changes as part of an ongoing effort to migrate the project from OpenAI models and SDKs to Google's Gemini:
1. **Updated `requirements.txt`**:
* Removed `langchain-openai` and `openai`.
* Added `langchain-google-genai` and `google-generativeai`.
2. **Modified `tradingagents/agents/utils/agent_utils.py`**:
* Replaced the import of `ChatOpenAI` from `langchain_openai` with `ChatGoogleGenerativeAI` from `langchain_google_genai`.
* Updated tool function names (e.g., `get_stock_news_openai` to `get_stock_news_gemini`) and their internal calls to `interface` methods to reflect changes in `interface.py`.
3. **Refactored `tradingagents/dataflows/interface.py`**:
* Removed direct usage of the `openai` package.
* Added `google.generativeai` for direct Gemini API calls.
* Configured Gemini API key handling via the `GEMINI_API_KEY` environment variable.
* Rewrote `get_stock_news_openai`, `get_global_news_openai`, and `get_fundamentals_openai` to use `gemini-pro` via `genai.GenerativeModel`. These functions were renamed to `_gemini` equivalents (e.g., `get_stock_news_gemini`).
* Removed the `web_search_preview` function as it's not directly available in the `google-generativeai` SDK in the same way.
**Work In Progress / Next Steps:**
The migration is not yet complete. Your feedback indicated that many files still contain the "openai" keyword, which would need to be addressed. Key remaining tasks before this migration can be considered complete include:
* **Locating and updating main Langchain LLM instantiations**: The `ChatOpenAI` models (e.g., `quick_thinking_llm`, `deep_thinking_llm` passed to `GraphSetup`) are instantiated outside of the files modified so far. These instantiations need to be found (likely in `main.py` or `cli/main.py`) and changed to `ChatGoogleGenerativeAI`.
* **Updating all type hints**: Type hints for `ChatOpenAI` need to be changed to `ChatGoogleGenerativeAI` in `tradingagents/graph/setup.py` and throughout all agent creation functions and their usages in `tradingagents/agents/`.
* **Addressing all remaining "openai" keywords**: A global search for "openai" should be performed to identify and address any remaining occurrences in comments, documentation, or other utility functions throughout the project.
* **Thorough testing**: Once all code changes are made, comprehensive testing (unit, integration, manual) will be required.
I was in the process of identifying the central LLM instantiation points when the work was submitted.
This commit is contained in:
parent
570644d939
commit
72d352d401
|
|
@ -1,5 +1,4 @@
|
|||
typing-extensions
|
||||
langchain-openai
|
||||
langchain-experimental
|
||||
pandas
|
||||
yfinance
|
||||
|
|
@ -7,6 +6,8 @@ praw
|
|||
feedparser
|
||||
stockstats
|
||||
eodhd
|
||||
langchain-google-genai
|
||||
google-generativeai
|
||||
langgraph
|
||||
chromadb
|
||||
setuptools
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import functools
|
|||
import pandas as pd
|
||||
import os
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
import tradingagents.dataflows.interface as interface
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
|
|
@ -355,12 +355,12 @@ class Toolkit:
|
|||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_stock_news_openai(
|
||||
def get_stock_news_gemini(
|
||||
ticker: Annotated[str, "the company's ticker"],
|
||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||
):
|
||||
"""
|
||||
Retrieve the latest news about a given stock by using OpenAI's news API.
|
||||
Retrieve the latest news about a given stock by using Gemini's API.
|
||||
Args:
|
||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
||||
curr_date (str): Current date in yyyy-mm-dd format
|
||||
|
|
@ -368,35 +368,35 @@ class Toolkit:
|
|||
str: A formatted string containing the latest news about the company on the given date.
|
||||
"""
|
||||
|
||||
openai_news_results = interface.get_stock_news_openai(ticker, curr_date)
|
||||
gemini_news_results = interface.get_stock_news_gemini(ticker, curr_date)
|
||||
|
||||
return openai_news_results
|
||||
return gemini_news_results
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_global_news_openai(
|
||||
def get_global_news_gemini(
|
||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||
):
|
||||
"""
|
||||
Retrieve the latest macroeconomics news on a given date using OpenAI's macroeconomics news API.
|
||||
Retrieve the latest macroeconomics news on a given date using Gemini's macroeconomics news API.
|
||||
Args:
|
||||
curr_date (str): Current date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A formatted string containing the latest macroeconomic news on the given date.
|
||||
"""
|
||||
|
||||
openai_news_results = interface.get_global_news_openai(curr_date)
|
||||
gemini_news_results = interface.get_global_news_gemini(curr_date)
|
||||
|
||||
return openai_news_results
|
||||
return gemini_news_results
|
||||
|
||||
@staticmethod
|
||||
@tool
|
||||
def get_fundamentals_openai(
|
||||
def get_fundamentals_gemini(
|
||||
ticker: Annotated[str, "the company's ticker"],
|
||||
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
|
||||
):
|
||||
"""
|
||||
Retrieve the latest fundamental information about a given stock on a given date by using OpenAI's news API.
|
||||
Retrieve the latest fundamental information about a given stock on a given date by using Gemini's API.
|
||||
Args:
|
||||
ticker (str): Ticker of a company. e.g. AAPL, TSM
|
||||
curr_date (str): Current date in yyyy-mm-dd format
|
||||
|
|
@ -404,8 +404,8 @@ class Toolkit:
|
|||
str: A formatted string containing the latest fundamental information about the company on the given date.
|
||||
"""
|
||||
|
||||
openai_fundamentals_results = interface.get_fundamentals_openai(
|
||||
gemini_fundamentals_results = interface.get_fundamentals_gemini(
|
||||
ticker, curr_date
|
||||
)
|
||||
|
||||
return openai_fundamentals_results
|
||||
return gemini_fundamentals_results
|
||||
|
|
|
|||
|
|
@ -12,9 +12,15 @@ import os
|
|||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
import yfinance as yf
|
||||
from openai import OpenAI
|
||||
import google.generativeai as genai
|
||||
from .config import get_config, set_config, DATA_DIR
|
||||
|
||||
# Configure the Gemini API key
|
||||
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
||||
if not gemini_api_key:
|
||||
raise ValueError("GEMINI_API_KEY environment variable not set.")
|
||||
genai.configure(api_key=gemini_api_key)
|
||||
|
||||
|
||||
def get_finnhub_news(
|
||||
ticker: Annotated[
|
||||
|
|
@ -702,103 +708,22 @@ def get_YFin_data(
|
|||
return filtered_data
|
||||
|
||||
|
||||
def get_stock_news_openai(ticker, curr_date):
|
||||
client = OpenAI()
|
||||
|
||||
response = client.responses.create(
|
||||
model="gpt-4.1-mini",
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period.",
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
text={"format": {"type": "text"}},
|
||||
reasoning={},
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"user_location": {"type": "approximate"},
|
||||
"search_context_size": "low",
|
||||
}
|
||||
],
|
||||
temperature=1,
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
)
|
||||
|
||||
return response.output[1].content[0].text
|
||||
def get_stock_news_gemini(ticker, curr_date):
|
||||
model = genai.GenerativeModel('gemini-pro')
|
||||
prompt = f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period."
|
||||
response = model.generate_content(prompt)
|
||||
return response.text
|
||||
|
||||
|
||||
def get_global_news_openai(curr_date):
|
||||
client = OpenAI()
|
||||
|
||||
response = client.responses.create(
|
||||
model="gpt-4.1-mini",
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period.",
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
text={"format": {"type": "text"}},
|
||||
reasoning={},
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"user_location": {"type": "approximate"},
|
||||
"search_context_size": "low",
|
||||
}
|
||||
],
|
||||
temperature=1,
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
)
|
||||
|
||||
return response.output[1].content[0].text
|
||||
def get_global_news_gemini(curr_date):
|
||||
model = genai.GenerativeModel('gemini-pro')
|
||||
prompt = f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period."
|
||||
response = model.generate_content(prompt)
|
||||
return response.text
|
||||
|
||||
|
||||
def get_fundamentals_openai(ticker, curr_date):
|
||||
client = OpenAI()
|
||||
|
||||
response = client.responses.create(
|
||||
model="gpt-4.1-mini",
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc",
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
text={"format": {"type": "text"}},
|
||||
reasoning={},
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"user_location": {"type": "approximate"},
|
||||
"search_context_size": "low",
|
||||
}
|
||||
],
|
||||
temperature=1,
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
)
|
||||
|
||||
return response.output[1].content[0].text
|
||||
def get_fundamentals_gemini(ticker, curr_date):
|
||||
model = genai.GenerativeModel('gemini-pro')
|
||||
prompt = f"Can you search Fundamental for discussions on {ticker} during of the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc"
|
||||
response = model.generate_content(prompt)
|
||||
return response.text
|
||||
|
|
|
|||
Loading…
Reference in New Issue