feat: make tool naming llm provider agnostic

This commit is contained in:
kevin-bruton 2025-09-27 14:39:53 +02:00
parent aae3b17f9c
commit 0134792d19
8 changed files with 18 additions and 18 deletions

View File

@ -465,10 +465,10 @@ def get_user_selections():
) )
selected_research_depth = select_research_depth() selected_research_depth = select_research_depth()
# Step 5: OpenAI backend # Step 5: LLM Provider
console.print( console.print(
create_question_box( create_question_box(
"Step 5: OpenAI backend", "Select which service to talk to" "Step 5: LLM Provider", "Select which service to talk to"
) )
) )
selected_llm_provider, backend_url = select_llm_provider() selected_llm_provider, backend_url = select_llm_provider()

View File

@ -10,7 +10,7 @@ def create_fundamentals_analyst(llm, toolkit):
company_name = state["company_of_interest"] company_name = state["company_of_interest"]
if toolkit.config["online_tools"]: if toolkit.config["online_tools"]:
tools = [toolkit.get_fundamentals_openai] tools = [toolkit.get_fundamentals_from_llm]
else: else:
tools = [ tools = [
toolkit.get_finnhub_company_insider_sentiment, toolkit.get_finnhub_company_insider_sentiment,

View File

@ -9,7 +9,7 @@ def create_news_analyst(llm, toolkit):
ticker = state["company_of_interest"] ticker = state["company_of_interest"]
if toolkit.config["online_tools"]: if toolkit.config["online_tools"]:
tools = [toolkit.get_global_news_openai, toolkit.get_google_news] tools = [toolkit.get_global_news_from_llm, toolkit.get_google_news]
else: else:
tools = [ tools = [
toolkit.get_finnhub_news, toolkit.get_finnhub_news,

View File

@ -10,7 +10,7 @@ def create_social_media_analyst(llm, toolkit):
company_name = state["company_of_interest"] company_name = state["company_of_interest"]
if toolkit.config["online_tools"]: if toolkit.config["online_tools"]:
tools = [toolkit.get_stock_news_openai] tools = [toolkit.get_stock_news_from_llm]
else: else:
tools = [ tools = [
toolkit.get_reddit_stock_info, toolkit.get_reddit_stock_info,

View File

@ -363,7 +363,7 @@ class Toolkit:
@staticmethod @staticmethod
@tool @tool
def get_stock_news_openai( def get_stock_news_from_llm(
ticker: Annotated[str, "the company's ticker"], ticker: Annotated[str, "the company's ticker"],
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"], curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
): ):
@ -377,7 +377,7 @@ class Toolkit:
""" """
try: try:
openai_news_results = interface.get_stock_news_openai(ticker, curr_date) openai_news_results = interface.get_stock_news_from_llm(ticker, curr_date)
return openai_news_results return openai_news_results
except ValueError as e: except ValueError as e:
# Return the detailed error message to the agent # Return the detailed error message to the agent
@ -385,7 +385,7 @@ class Toolkit:
@staticmethod @staticmethod
@tool @tool
def get_global_news_openai( def get_global_news_from_llm(
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"], curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
): ):
""" """
@ -397,7 +397,7 @@ class Toolkit:
""" """
try: try:
openai_news_results = interface.get_global_news_openai(curr_date) openai_news_results = interface.get_global_news_from_llm(curr_date)
return openai_news_results return openai_news_results
except ValueError as e: except ValueError as e:
# Return the detailed error message to the agent # Return the detailed error message to the agent
@ -405,7 +405,7 @@ class Toolkit:
@staticmethod @staticmethod
@tool @tool
def get_fundamentals_openai( def get_fundamentals_from_llm(
ticker: Annotated[str, "the company's ticker"], ticker: Annotated[str, "the company's ticker"],
curr_date: Annotated[str, "Current date in yyyy-mm-dd format"], curr_date: Annotated[str, "Current date in yyyy-mm-dd format"],
): ):
@ -419,7 +419,7 @@ class Toolkit:
""" """
try: try:
openai_fundamentals_results = interface.get_fundamentals_openai( openai_fundamentals_results = interface.get_fundamentals_from_llm(
ticker, curr_date ticker, curr_date
) )
return openai_fundamentals_results return openai_fundamentals_results

View File

@ -14,7 +14,7 @@ class FinancialSituationMemory:
# Use a good general-purpose model for financial text # Use a good general-purpose model for financial text
self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2') self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
self.embedding_type = "local" self.embedding_type = "local"
print(f"✅ Using local embeddings with sentence-transformers") print(f"✅ Using local embeddings with sentence-transformers for {name}")
except ImportError: except ImportError:
print("⚠️ sentence-transformers not found. Install with: pip install sentence-transformers") print("⚠️ sentence-transformers not found. Install with: pip install sentence-transformers")
print("Falling back to ChromaDB's default embeddings...") print("Falling back to ChromaDB's default embeddings...")

View File

@ -932,19 +932,19 @@ def _call_llm_api(prompt, config):
raise ValueError(error_msg) from e raise ValueError(error_msg) from e
def get_stock_news_openai(ticker, curr_date): def get_stock_news_from_llm(ticker, curr_date):
config = get_config() config = get_config()
prompt = f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period." prompt = f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period."
return _call_llm_api(prompt, config) return _call_llm_api(prompt, config)
def get_global_news_openai(curr_date): def get_global_news_from_llm(curr_date):
config = get_config() config = get_config()
prompt = f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period." prompt = f"Can you search global or macroeconomics news from 7 days before {curr_date} to {curr_date} that would be informative for trading purposes? Make sure you only get the data posted during that period."
return _call_llm_api(prompt, config) return _call_llm_api(prompt, config)
def get_fundamentals_openai(ticker, curr_date): def get_fundamentals_from_llm(ticker, curr_date):
config = get_config() config = get_config()
prompt = f"Can you search for fundamental analysis discussions on {ticker} during the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc" prompt = f"Can you search for fundamental analysis discussions on {ticker} during the month before {curr_date} to the month of {curr_date}. Make sure you only get the data posted during that period. List as a table, with PE/PS/Cash flow/ etc"
return _call_llm_api(prompt, config) return _call_llm_api(prompt, config)

View File

@ -153,7 +153,7 @@ class TradingAgentsGraph:
"social": ToolNode( "social": ToolNode(
[ [
# online tools # online tools
self.toolkit.get_stock_news_openai, self.toolkit.get_stock_news_from_llm,
# offline tools # offline tools
self.toolkit.get_reddit_stock_info, self.toolkit.get_reddit_stock_info,
] ]
@ -161,7 +161,7 @@ class TradingAgentsGraph:
"news": ToolNode( "news": ToolNode(
[ [
# online tools # online tools
self.toolkit.get_global_news_openai, self.toolkit.get_global_news_from_llm,
self.toolkit.get_google_news, self.toolkit.get_google_news,
# offline tools # offline tools
self.toolkit.get_finnhub_news, self.toolkit.get_finnhub_news,
@ -171,7 +171,7 @@ class TradingAgentsGraph:
"fundamentals": ToolNode( "fundamentals": ToolNode(
[ [
# online tools # online tools
self.toolkit.get_fundamentals_openai, self.toolkit.get_fundamentals_from_llm,
# offline tools # offline tools
self.toolkit.get_finnhub_company_insider_sentiment, self.toolkit.get_finnhub_company_insider_sentiment,
self.toolkit.get_finnhub_company_insider_transactions, self.toolkit.get_finnhub_company_insider_transactions,