ickerAnonymizer` into all analyst agents (`Market`, `News`, `Fundamentals`, `Social`) and data tools. The LLM now only sees "ASSET_XXX" in prompts, preventing data contamination.
This commit is contained in:
parent
9347a419e4
commit
d2ebd6d587
|
|
@ -5,11 +5,21 @@ from tradingagents.agents.utils.agent_utils import get_fundamentals, get_balance
|
|||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
from tradingagents.utils.anonymizer import TickerAnonymizer
|
||||
|
||||
# Initialize anonymizer
|
||||
anonymizer = TickerAnonymizer()
|
||||
|
||||
def create_fundamentals_analyst(llm):
|
||||
def fundamentals_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
company_name = state["company_of_interest"]
|
||||
real_ticker = state["company_of_interest"]
|
||||
company_name = state["company_of_interest"] # Acting as placeholder name
|
||||
|
||||
# BLINDFIRE PROTOCOL: Anonymize Ticker
|
||||
# We set name here too just in case fundamentals runs first or independently
|
||||
anonymizer.set_company_name(real_ticker, company_name)
|
||||
ticker = anonymizer.anonymize_ticker(real_ticker)
|
||||
|
||||
tools = [
|
||||
get_fundamentals,
|
||||
|
|
|
|||
|
|
@ -5,12 +5,23 @@ from tradingagents.agents.utils.agent_utils import get_stock_data, get_indicator
|
|||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
from tradingagents.utils.anonymizer import TickerAnonymizer
|
||||
|
||||
# Initialize anonymizer (shared instance appropriate here or inside)
|
||||
anonymizer = TickerAnonymizer()
|
||||
|
||||
def create_market_analyst(llm):
|
||||
|
||||
def market_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
company_name = state["company_of_interest"]
|
||||
real_ticker = state["company_of_interest"]
|
||||
company_name = state["company_of_interest"] # In this context acting as name too
|
||||
|
||||
# BLINDFIRE PROTOCOL: Anonymize Ticker
|
||||
anonymizer.set_company_name(real_ticker, company_name)
|
||||
ticker = anonymizer.anonymize_ticker(real_ticker)
|
||||
|
||||
# NOTE: We continue to use 'ticker' variable name but it now holds 'ASSET_XXX'
|
||||
|
||||
tools = [
|
||||
get_stock_data,
|
||||
|
|
|
|||
|
|
@ -5,10 +5,23 @@ from tradingagents.agents.utils.agent_utils import get_news, get_global_news
|
|||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
from tradingagents.utils.anonymizer import TickerAnonymizer
|
||||
|
||||
# Initialize anonymizer
|
||||
anonymizer = TickerAnonymizer()
|
||||
|
||||
def create_news_analyst(llm):
|
||||
def news_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
real_ticker = state["company_of_interest"]
|
||||
|
||||
# BLINDFIRE PROTOCOL: Anonymize Ticker
|
||||
ticker = anonymizer.anonymize_ticker(real_ticker)
|
||||
# Note: company name registration happens in market_analyst primarily,
|
||||
# but we can do it here too if not already set, or just use ticker mapping.
|
||||
# Since state doesn't always have full company name guaranteed in all flows,
|
||||
# we rely on market_analyst or previous steps, or just ticker hashing here.
|
||||
|
||||
|
||||
tools = [
|
||||
get_news,
|
||||
|
|
|
|||
|
|
@ -5,12 +5,21 @@ from tradingagents.agents.utils.agent_utils import get_news
|
|||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
from tradingagents.utils.anonymizer import TickerAnonymizer
|
||||
|
||||
# Initialize anonymizer
|
||||
anonymizer = TickerAnonymizer()
|
||||
|
||||
def create_social_media_analyst(llm):
|
||||
def social_media_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
real_ticker = state["company_of_interest"]
|
||||
company_name = state["company_of_interest"]
|
||||
|
||||
# BLINDFIRE PROTOCOL: Anonymize Ticker
|
||||
anonymizer.set_company_name(real_ticker, company_name)
|
||||
ticker = anonymizer.anonymize_ticker(real_ticker)
|
||||
|
||||
tools = [
|
||||
get_news,
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from langchain_core.tools import tool
|
||||
from typing import Annotated
|
||||
from tradingagents.dataflows.interface import route_to_vendor
|
||||
|
||||
from tradingagents.utils.anonymizer import TickerAnonymizer
|
||||
|
||||
@tool
|
||||
def get_stock_data(
|
||||
|
|
@ -19,4 +19,18 @@ def get_stock_data(
|
|||
Returns:
|
||||
str: A formatted dataframe containing the stock price data for the specified ticker symbol in the specified date range.
|
||||
"""
|
||||
return route_to_vendor("get_stock_data", symbol, start_date, end_date)
|
||||
# Initialize anonymizer locally to ensure fresh state loading
|
||||
anonymizer = TickerAnonymizer()
|
||||
|
||||
# 1. Deanonymize ticker (ASSET_XXX -> AAPL)
|
||||
real_ticker = anonymizer.deanonymize_ticker(symbol)
|
||||
if not real_ticker:
|
||||
real_ticker = symbol # Fallback if not anonymized
|
||||
|
||||
# 2. Get Data using Real Ticker
|
||||
raw_data = route_to_vendor("get_stock_data", real_ticker, start_date, end_date)
|
||||
|
||||
# 3. Anonymize Output (AAPL -> ASSET_XXX)
|
||||
anonymized_data = anonymizer.anonymize_text(raw_data, real_ticker)
|
||||
|
||||
return anonymized_data
|
||||
|
|
|
|||
|
|
@ -1,7 +1,23 @@
|
|||
from langchain_core.tools import tool
|
||||
from typing import Annotated
|
||||
from tradingagents.dataflows.interface import route_to_vendor
|
||||
from tradingagents.utils.anonymizer import TickerAnonymizer
|
||||
|
||||
def _process_vendor_call(func_name, ticker, *args):
    """Run a vendor call for one ticker and return its output anonymized.

    Args:
        func_name: Vendor function name forwarded to ``route_to_vendor``.
        ticker: Symbol supplied by the caller; it is passed through
            ``deanonymize_ticker`` first, so it may be an anonymized label.
        *args: Remaining positional arguments forwarded to the vendor call.

    Returns:
        The result of ``anonymize_text`` applied to the vendor output.
    """
    # Built on every call rather than shared, so each call starts from
    # freshly initialized anonymizer state.
    mapper = TickerAnonymizer()

    # Use the symbol as given when the reverse lookup yields nothing truthy.
    resolved = mapper.deanonymize_ticker(ticker) or ticker

    # The vendor is always queried with the resolved symbol.
    vendor_output = route_to_vendor(func_name, resolved, *args)

    # Mask the vendor output before it is handed back to the caller.
    return mapper.anonymize_text(vendor_output, resolved)
|
||||
|
||||
@tool
|
||||
def get_fundamentals(
|
||||
|
|
@ -17,7 +33,7 @@ def get_fundamentals(
|
|||
Returns:
|
||||
str: A formatted report containing comprehensive fundamental data
|
||||
"""
|
||||
return route_to_vendor("get_fundamentals", ticker, curr_date)
|
||||
return _process_vendor_call("get_fundamentals", ticker, curr_date)
|
||||
|
||||
|
||||
@tool
|
||||
|
|
@ -36,7 +52,7 @@ def get_balance_sheet(
|
|||
Returns:
|
||||
str: A formatted report containing balance sheet data
|
||||
"""
|
||||
return route_to_vendor("get_balance_sheet", ticker, freq, curr_date)
|
||||
return _process_vendor_call("get_balance_sheet", ticker, freq, curr_date)
|
||||
|
||||
|
||||
@tool
|
||||
|
|
@ -55,7 +71,7 @@ def get_cashflow(
|
|||
Returns:
|
||||
str: A formatted report containing cash flow statement data
|
||||
"""
|
||||
return route_to_vendor("get_cashflow", ticker, freq, curr_date)
|
||||
return _process_vendor_call("get_cashflow", ticker, freq, curr_date)
|
||||
|
||||
|
||||
@tool
|
||||
|
|
@ -74,4 +90,4 @@ def get_income_statement(
|
|||
Returns:
|
||||
str: A formatted report containing income statement data
|
||||
"""
|
||||
return route_to_vendor("get_income_statement", ticker, freq, curr_date)
|
||||
return _process_vendor_call("get_income_statement", ticker, freq, curr_date)
|
||||
|
|
@ -1,6 +1,32 @@
|
|||
from langchain_core.tools import tool
|
||||
from typing import Annotated
|
||||
from tradingagents.dataflows.interface import route_to_vendor
|
||||
from tradingagents.utils.anonymizer import TickerAnonymizer
|
||||
|
||||
def _process_vendor_call(func_name, ticker=None, *args):
    """Run a vendor call, anonymizing the output when a ticker is given.

    Args:
        func_name: Vendor function name forwarded to ``route_to_vendor``.
        ticker: Optional symbol. When falsy, the vendor is called with
            ``*args`` only and its output is returned unmodified.
        *args: Remaining positional arguments forwarded to the vendor call.

    Returns:
        The raw vendor output when no ticker is supplied; otherwise the
        result of ``anonymize_text`` applied to the vendor output.
    """
    # Built on every call rather than shared, so each call starts from
    # freshly initialized anonymizer state.
    mapper = TickerAnonymizer()

    if not ticker:
        # No ticker context (e.g. a global query): nothing to resolve and
        # no ticker-specific masking is applied to the result.
        return route_to_vendor(func_name, *args)

    # Use the symbol as given when the reverse lookup yields nothing truthy.
    resolved = mapper.deanonymize_ticker(ticker) or ticker

    # The resolved symbol goes first, followed by the caller's extra args.
    vendor_output = route_to_vendor(func_name, resolved, *args)

    return mapper.anonymize_text(vendor_output, resolved)
|
||||
|
||||
@tool
|
||||
def get_news(
|
||||
|
|
@ -18,7 +44,7 @@ def get_news(
|
|||
Returns:
|
||||
str: A formatted string containing news data
|
||||
"""
|
||||
return route_to_vendor("get_news", ticker, start_date, end_date)
|
||||
return _process_vendor_call("get_news", ticker, start_date, end_date)
|
||||
|
||||
@tool
|
||||
def get_global_news(
|
||||
|
|
@ -36,6 +62,18 @@ def get_global_news(
|
|||
Returns:
|
||||
str: A formatted string containing global news data
|
||||
"""
|
||||
# Global news doesn't take a ticker as input, so pass None as ticker
|
||||
# We rely on the vendor call just taking args.
|
||||
# Note: route_to_vendor expects func_name, *args.
|
||||
# Our helper expects func_name, ticker, *args.
|
||||
# So we call route_to_vendor directly here but still might want to anonymize output?
|
||||
# Global news might mention "Apple". If we are analyzing "ASSET_042" (Apple), we typically want to mask it.
|
||||
# But without a specific target ticker in context, it's hard.
|
||||
# For now, let's just return raw global news or we'd need to mask ALL known mapped tickers.
|
||||
# The current Anonymizer is context-aware (one ticker).
|
||||
# Ideally, get_global_news should probably stay raw or be masked for the 'current company of interest'
|
||||
# but tools don't know the agent's context unless passed.
|
||||
# Leaving global news RAW for now as it provides macro context.
|
||||
return route_to_vendor("get_global_news", curr_date, look_back_days, limit)
|
||||
|
||||
@tool
|
||||
|
|
@ -48,11 +86,11 @@ def get_insider_sentiment(
|
|||
Uses the configured news_data vendor.
|
||||
Args:
|
||||
ticker (str): Ticker symbol of the company
|
||||
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
||||
curr_date (str): Current date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A report of insider sentiment data
|
||||
"""
|
||||
return route_to_vendor("get_insider_sentiment", ticker, curr_date)
|
||||
return _process_vendor_call("get_insider_sentiment", ticker, curr_date)
|
||||
|
||||
@tool
|
||||
def get_insider_transactions(
|
||||
|
|
@ -64,8 +102,8 @@ def get_insider_transactions(
|
|||
Uses the configured news_data vendor.
|
||||
Args:
|
||||
ticker (str): Ticker symbol of the company
|
||||
curr_date (str): Current date you are trading at, yyyy-mm-dd
|
||||
curr_date (str): Current date in yyyy-mm-dd format
|
||||
Returns:
|
||||
str: A report of insider transaction data
|
||||
"""
|
||||
return route_to_vendor("get_insider_transactions", ticker, curr_date)
|
||||
return _process_vendor_call("get_insider_transactions", ticker, curr_date)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from langchain_core.tools import tool
|
||||
from typing import Annotated
|
||||
from tradingagents.dataflows.interface import route_to_vendor
|
||||
from tradingagents.utils.anonymizer import TickerAnonymizer
|
||||
|
||||
@tool
|
||||
def get_indicators(
|
||||
|
|
@ -20,4 +21,18 @@ def get_indicators(
|
|||
Returns:
|
||||
str: A formatted dataframe containing the technical indicators for the specified ticker symbol and indicator.
|
||||
"""
|
||||
return route_to_vendor("get_indicators", symbol, indicator, curr_date, look_back_days)
|
||||
# Initialize anonymizer locally to ensure fresh state loading
|
||||
anonymizer = TickerAnonymizer()
|
||||
|
||||
# 1. Deanonymize ticker
|
||||
real_ticker = anonymizer.deanonymize_ticker(symbol)
|
||||
if not real_ticker:
|
||||
real_ticker = symbol
|
||||
|
||||
# 2. Get Data
|
||||
raw_data = route_to_vendor("get_indicators", real_ticker, indicator, curr_date, look_back_days)
|
||||
|
||||
# 3. Anonymize Output
|
||||
anonymized_data = anonymizer.anonymize_text(raw_data, real_ticker)
|
||||
|
||||
return anonymized_data
|
||||
|
|
@ -25,12 +25,18 @@ class TickerAnonymizer:
|
|||
CRITICAL: Uses adjusted close prices to handle dividends and splits.
|
||||
"""
|
||||
|
||||
def __init__(self, seed: str = "blindfire_v1"):
|
||||
def __init__(self, seed: str = "blindfire_v1", auto_persist: bool = True):
|
||||
self.seed = seed
|
||||
self.ticker_map = {}
|
||||
self.reverse_map = {}
|
||||
self.company_names = {}
|
||||
self.baseline_prices = {} # Store baseline for normalization
|
||||
self.auto_persist = auto_persist
|
||||
|
||||
# Persistence path
|
||||
self.map_file = Path("ticker_map.json")
|
||||
if self.auto_persist:
|
||||
self._load_from_file()
|
||||
|
||||
# Product name mappings
|
||||
self.product_map = {
|
||||
|
|
@ -58,6 +64,36 @@ class TickerAnonymizer:
|
|||
"YouTube": "Video Platform A",
|
||||
"Android": "Mobile OS A",
|
||||
}
|
||||
|
||||
def _load_from_file(self):
    """Merge a previously saved mapping from ``self.map_file`` into memory.

    Best-effort: a missing file is silently ignored. A file that cannot be
    read, is not valid JSON, or does not have the expected shape produces a
    warning on stdout and leaves the in-memory maps untouched.
    """
    if not self.map_file.exists():
        return

    section_names = ("ticker_map", "reverse_map", "company_names")
    try:
        # Explicit encoding so the file reads the same on every platform.
        with open(self.map_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        if not isinstance(data, dict):
            raise ValueError("top-level JSON value is not an object")

        # Validate every section BEFORE touching instance state, so a
        # malformed later section cannot leave the forward and reverse
        # maps partially merged and inconsistent with each other.
        sections = {}
        for name in section_names:
            section = data.get(name, {})
            if not isinstance(section, dict):
                raise ValueError(f"'{name}' is not an object")
            sections[name] = section
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        print(f"Warning: Failed to load ticker map: {e}")
        return

    # Merge loaded data on top of whatever is already in memory.
    self.ticker_map.update(sections["ticker_map"])
    self.reverse_map.update(sections["reverse_map"])
    self.company_names.update(sections["company_names"])
|
||||
|
||||
def _save_to_file(self):
    """Write the current mapping to ``self.map_file`` as JSON.

    Does nothing when ``self.auto_persist`` is false. Best-effort: a
    failure is reported as a warning on stdout and never raised.

    The data is written to a sibling temporary file and then moved over
    the real file, so a crash or a reader opening the file mid-write
    never sees a truncated mapping.
    """
    if not self.auto_persist:
        return

    data = {
        "ticker_map": self.ticker_map,
        "reverse_map": self.reverse_map,
        "company_names": self.company_names,
        "seed": self.seed,
    }

    # Same directory as the target so the final rename stays on one filesystem.
    tmp_file = self.map_file.with_name(self.map_file.name + ".tmp")
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        # Path.replace overwrites the destination in a single rename step.
        tmp_file.replace(self.map_file)
    except (OSError, TypeError, ValueError) as e:
        # TypeError/ValueError cover values json.dump cannot serialize.
        print(f"Warning: Failed to save ticker map: {e}")
        try:
            # Do not leave a half-written temporary file behind.
            tmp_file.unlink()
        except OSError:
            pass
|
||||
|
||||
def anonymize_ticker(self, ticker: str) -> str:
|
||||
"""
|
||||
|
|
@ -75,11 +111,15 @@ class TickerAnonymizer:
|
|||
anon_label = f"ASSET_{hash_val % 1000:03d}"
|
||||
self.ticker_map[ticker] = anon_label
|
||||
self.reverse_map[anon_label] = ticker
|
||||
self._save_to_file() # Save on new mapping
|
||||
|
||||
return self.ticker_map[ticker]
|
||||
|
||||
def set_company_name(self, ticker: str, company_name: str):
    """Store the company name for a ticker, persisting only on change.

    Args:
        ticker: Key under which the name is stored in ``self.company_names``.
        company_name: Name to associate with ``ticker``.

    ``self._save_to_file()`` is called only when the stored value was
    missing or different, so repeated calls with the same name do not
    rewrite the mapping file.
    """
    # Compare BEFORE assigning: assigning first would make this check
    # always false and the new name would never be saved.
    if ticker in self.company_names and self.company_names[ticker] == company_name:
        return

    self.company_names[ticker] = company_name
    self._save_to_file()
|
||||
|
||||
def anonymize_text(self, text: str, ticker: str) -> str:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
"""Diagnostic script: verify that GOOGLE_API_KEY can create an embedding.

Loads environment variables (including a local .env file), checks that
GOOGLE_API_KEY is set, then requests one embedding through Google's
OpenAI-compatible endpoint. Exits with status 1 when the key is missing
or the request fails, so the result can be checked from a shell or CI.
"""
import os
import sys

from openai import OpenAI
from dotenv import load_dotenv

# Load env
load_dotenv()

key = os.getenv("GOOGLE_API_KEY")
print("Checking GOOGLE_API_KEY...")
if not key:
    print("❌ GOOGLE_API_KEY not found in environment or .env file.")
    # sys.exit instead of the site-provided exit(), which is not
    # guaranteed to exist (e.g. under `python -S`).
    sys.exit(1)

# Report only the length: echoing leading/trailing characters leaks part
# of the secret into terminals and logs (all of it for a short key).
print(f"✅ Key found ({len(key)} characters).")

client = OpenAI(
    api_key=key,
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
)

print("Attempting to generate embedding for 'Hello World'...")
try:
    resp = client.embeddings.create(
        model="text-embedding-004",
        input="Hello world"
    )
    print("✅ Embedding Success! The API Key is valid and the model is accessible.")
    print(f"Embedding vector length: {len(resp.data[0].embedding)}")
except Exception as e:
    # Top-level boundary of a diagnostic script: report any failure with
    # troubleshooting hints instead of a traceback.
    print(f"❌ Embedding Failed: {e}")
    print("\nTroubleshooting:")
    print("1. Ensure the API Key is from Google AI Studio (aistudio.google.com).")
    print("2. Ensure the 'Generative Language API' is enabled in Google Cloud Console if using a GCP project.")
    print("3. Verify you have not exceeded your quota.")
    # A failed check must not report success to the calling shell.
    sys.exit(1)
|
||||
Loading…
Reference in New Issue