Merge branch 'temp' into gemini
This commit is contained in:
commit
624571ff22
|
|
@ -8,3 +8,5 @@ eval_data/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
results/
|
results/
|
||||||
.env
|
.env
|
||||||
|
tradingagents/dataflows/data_cache/
|
||||||
|
CLAUDE.md
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
{
|
||||||
|
// Use IntelliSense to learn about possible attributes.
|
||||||
|
// Hover to view descriptions of existing attributes.
|
||||||
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Python Debugger: main.py",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${workspaceFolder}/main.py",
|
||||||
|
"console": "integratedTerminal"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
@ -132,6 +132,8 @@ def select_shallow_thinking_agent(provider) -> str:
|
||||||
("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
|
("GPT-4.1-nano - Ultra-lightweight model for basic operations", "gpt-4.1-nano"),
|
||||||
("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"),
|
("GPT-4.1-mini - Compact model with good performance", "gpt-4.1-mini"),
|
||||||
("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
|
("GPT-4o - Standard model with solid capabilities", "gpt-4o"),
|
||||||
|
("o4-mini - Specialized reasoning model (compact)", "o4-mini"),
|
||||||
|
("o3 - Full advanced reasoning model", "o3"),
|
||||||
],
|
],
|
||||||
"anthropic": [
|
"anthropic": [
|
||||||
("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
|
("Claude Haiku 3.5 - Fast inference and standard capabilities", "claude-3-5-haiku-latest"),
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
version: '3.8'
|
||||||
|
|
||||||
|
services:
|
||||||
|
mysql:
|
||||||
|
image: mysql:8.0
|
||||||
|
container_name: tradingagents_mysql
|
||||||
|
restart: unless-stopped
|
||||||
|
environment:
|
||||||
|
MYSQL_ROOT_PASSWORD: ${DB_PASSWORD:-password}
|
||||||
|
MYSQL_DATABASE: ${DB_NAME:-tradingagents_db}
|
||||||
|
MYSQL_USER: ${DB_USER:-tradinguser}
|
||||||
|
MYSQL_PASSWORD: ${DB_PASSWORD:-password}
|
||||||
|
ports:
|
||||||
|
- "3306:3306"
|
||||||
|
volumes:
|
||||||
|
- /home/hskim/mysql_data:/var/lib/mysql
|
||||||
|
- /home/hskim/docker/mysql/init:/docker-entrypoint-initdb.d
|
||||||
|
command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||||
|
networks:
|
||||||
|
- tradingagents_network
|
||||||
|
|
||||||
|
redis:
|
||||||
|
image: redis:7-alpine
|
||||||
|
container_name: tradingagents_redis
|
||||||
|
restart: unless-stopped
|
||||||
|
ports:
|
||||||
|
- "6379:6379"
|
||||||
|
volumes:
|
||||||
|
- redis_data:/data
|
||||||
|
command: redis-server --appendonly yes
|
||||||
|
networks:
|
||||||
|
- tradingagents_network
|
||||||
|
|
||||||
|
# 개발용 phpMyAdmin (선택사항)
|
||||||
|
# phpmyadmin:
|
||||||
|
# image: phpmyadmin/phpmyadmin
|
||||||
|
# container_name: tradingagents_phpmyadmin
|
||||||
|
# restart: unless-stopped
|
||||||
|
# environment:
|
||||||
|
# PMA_HOST: mysql
|
||||||
|
# PMA_PORT: 3306
|
||||||
|
# PMA_USER: root
|
||||||
|
# PMA_PASSWORD: ${DB_PASSWORD:-password}
|
||||||
|
# ports:
|
||||||
|
# - "8080:80"
|
||||||
|
# depends_on:
|
||||||
|
# - mysql
|
||||||
|
# networks:
|
||||||
|
# - tradingagents_network
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
mysql_data:
|
||||||
|
driver: local
|
||||||
|
redis_data:
|
||||||
|
driver: local
|
||||||
|
|
||||||
|
networks:
|
||||||
|
tradingagents_network:
|
||||||
|
driver: bridge
|
||||||
21
main.py
21
main.py
|
|
@ -1,21 +0,0 @@
|
||||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
|
||||||
|
|
||||||
# Create a custom config
|
|
||||||
config = DEFAULT_CONFIG.copy()
|
|
||||||
config["llm_provider"] = "google" # Use a different model
|
|
||||||
config["backend_url"] = "https://generativelanguage.googleapis.com/v1" # Use a different backend
|
|
||||||
config["deep_think_llm"] = "gemini-2.5-pro" # Use a different model
|
|
||||||
config["quick_think_llm"] = "gemini-2.5-flash-lite-preview-06-17" # Use a different model
|
|
||||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
|
||||||
config["online_tools"] = True # Increase debate rounds
|
|
||||||
|
|
||||||
# Initialize with custom config
|
|
||||||
ta = TradingAgentsGraph(debug=True, config=config)
|
|
||||||
|
|
||||||
# forward propagate
|
|
||||||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
|
||||||
print(decision)
|
|
||||||
|
|
||||||
# Memorize mistakes and reflect
|
|
||||||
# ta.reflect_and_remember(1000) # parameter is the position returns
|
|
||||||
|
|
@ -4,11 +4,11 @@ from .embedding_providers import (
|
||||||
GeminiEmbeddingProvider,
|
GeminiEmbeddingProvider,
|
||||||
OllamaEmbeddingProvider
|
OllamaEmbeddingProvider
|
||||||
)
|
)
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
class EmbeddingProviderFactory:
|
class EmbeddingProviderFactory:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_provider(config : dict[str, any])->EmbeddingProvider:
|
def create_provider(config : dict[str, Any])->EmbeddingProvider:
|
||||||
backend_url = config["backend_url"]
|
backend_url = config["backend_url"]
|
||||||
|
|
||||||
if "generativelanguage.googleapis.com" in backend_url:
|
if "generativelanguage.googleapis.com" in backend_url:
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from typing import Dict, Optional
|
||||||
|
|
||||||
# Use default config but allow it to be overridden
|
# Use default config but allow it to be overridden
|
||||||
_config: Optional[Dict] = None
|
_config: Optional[Dict] = None
|
||||||
DATA_DIR: Optional[str] = None
|
DATA_DIR: str = ""
|
||||||
|
|
||||||
|
|
||||||
def initialize_config():
|
def initialize_config():
|
||||||
|
|
|
||||||
|
|
@ -74,10 +74,14 @@ def getNewsData(query, start_date, end_date):
|
||||||
for el in results_on_page:
|
for el in results_on_page:
|
||||||
try:
|
try:
|
||||||
link = el.find("a")["href"]
|
link = el.find("a")["href"]
|
||||||
title = el.select_one("div.MBeuO").get_text()
|
title_el = el.select_one("div.MBeuO")
|
||||||
snippet = el.select_one(".GI74Re").get_text()
|
title = title_el.get_text() if title_el else ""
|
||||||
date = el.select_one(".LfVVr").get_text()
|
snippet_el = el.select_one(".GI74Re")
|
||||||
source = el.select_one(".NUnG9d span").get_text()
|
snippet = snippet_el.get_text() if snippet_el else ""
|
||||||
|
date_el = el.select_one(".LfVVr")
|
||||||
|
date = date_el.get_text() if date_el else ""
|
||||||
|
source_el = el.select_one(".NUnG9d span")
|
||||||
|
source = source_el.get_text() if source_el else ""
|
||||||
news_results.append(
|
news_results.append(
|
||||||
{
|
{
|
||||||
"link": link,
|
"link": link,
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Annotated, Dict
|
from typing import Annotated, Dict, Tuple
|
||||||
from .reddit_utils import fetch_top_from_category
|
from .reddit_utils import fetch_top_from_category
|
||||||
from .yfin_utils import *
|
from .yfin_utils import *
|
||||||
from .stockstats_utils import *
|
from .stockstats_utils import *
|
||||||
|
|
@ -14,7 +14,24 @@ from tqdm import tqdm
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from .config import get_config, set_config, DATA_DIR
|
from .config import get_config, set_config, DATA_DIR
|
||||||
from .search_provider_factory import SearchProviderFactory
|
from .search_provider_factory import SearchProviderFactory, create_search_provider_factory
|
||||||
|
|
||||||
|
|
||||||
|
def parse_date_range(curr_date: str, look_back_days: int) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Parse date range and return start and end dates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
curr_date: Current date in yyyy-mm-dd format
|
||||||
|
look_back_days: Number of days to look back
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (start_date, end_date) as strings
|
||||||
|
"""
|
||||||
|
end_date = curr_date
|
||||||
|
start_date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
|
before = start_date_obj - relativedelta(days=look_back_days)
|
||||||
|
return before.strftime("%Y-%m-%d"), end_date
|
||||||
|
|
||||||
|
|
||||||
def get_finnhub_news(
|
def get_finnhub_news(
|
||||||
|
|
@ -37,9 +54,7 @@ def get_finnhub_news(
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start_date = datetime.strptime(curr_date, "%Y-%m-%d")
|
before, _ = parse_date_range(curr_date, look_back_days)
|
||||||
before = start_date - relativedelta(days=look_back_days)
|
|
||||||
before = before.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
result = get_data_in_range(ticker, before, curr_date, "news_data", DATA_DIR)
|
result = get_data_in_range(ticker, before, curr_date, "news_data", DATA_DIR)
|
||||||
|
|
||||||
|
|
@ -76,9 +91,7 @@ def get_finnhub_company_insider_sentiment(
|
||||||
str: a report of the sentiment in the past 15 days starting at curr_date
|
str: a report of the sentiment in the past 15 days starting at curr_date
|
||||||
"""
|
"""
|
||||||
|
|
||||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
before, _ = parse_date_range(curr_date, look_back_days)
|
||||||
before = date_obj - relativedelta(days=look_back_days)
|
|
||||||
before = before.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
data = get_data_in_range(ticker, before, curr_date, "insider_senti", DATA_DIR)
|
data = get_data_in_range(ticker, before, curr_date, "insider_senti", DATA_DIR)
|
||||||
|
|
||||||
|
|
@ -117,9 +130,7 @@ def get_finnhub_company_insider_transactions(
|
||||||
str: a report of the company's insider transaction/trading informtaion in the past 15 days
|
str: a report of the company's insider transaction/trading informtaion in the past 15 days
|
||||||
"""
|
"""
|
||||||
|
|
||||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
before, _ = parse_date_range(curr_date, look_back_days)
|
||||||
before = date_obj - relativedelta(days=look_back_days)
|
|
||||||
before = before.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
data = get_data_in_range(ticker, before, curr_date, "insider_trans", DATA_DIR)
|
data = get_data_in_range(ticker, before, curr_date, "insider_trans", DATA_DIR)
|
||||||
|
|
||||||
|
|
@ -290,9 +301,7 @@ def get_google_news(
|
||||||
) -> str:
|
) -> str:
|
||||||
query = query.replace(" ", "+")
|
query = query.replace(" ", "+")
|
||||||
|
|
||||||
start_date = datetime.strptime(curr_date, "%Y-%m-%d")
|
before, _ = parse_date_range(curr_date, look_back_days)
|
||||||
before = start_date - relativedelta(days=look_back_days)
|
|
||||||
before = before.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
news_results = getNewsData(query, before, curr_date)
|
news_results = getNewsData(query, before, curr_date)
|
||||||
|
|
||||||
|
|
@ -323,18 +332,17 @@ def get_reddit_global_news(
|
||||||
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
|
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
before, start_date_str = parse_date_range(start_date, look_back_days)
|
||||||
before = start_date - relativedelta(days=look_back_days)
|
start_date_dt = datetime.strptime(start_date_str, "%Y-%m-%d")
|
||||||
before = before.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
posts = []
|
posts = []
|
||||||
# iterate from start_date to end_date
|
# iterate from start_date to end_date
|
||||||
curr_date = datetime.strptime(before, "%Y-%m-%d")
|
curr_date = datetime.strptime(before, "%Y-%m-%d")
|
||||||
|
|
||||||
total_iterations = (start_date - curr_date).days + 1
|
total_iterations = (start_date_dt - curr_date).days + 1
|
||||||
pbar = tqdm(desc=f"Getting Global News on {start_date}", total=total_iterations)
|
pbar = tqdm(desc=f"Getting Global News on {start_date_dt}", total=total_iterations)
|
||||||
|
|
||||||
while curr_date <= start_date:
|
while curr_date <= start_date_dt:
|
||||||
curr_date_str = curr_date.strftime("%Y-%m-%d")
|
curr_date_str = curr_date.strftime("%Y-%m-%d")
|
||||||
fetch_result = fetch_top_from_category(
|
fetch_result = fetch_top_from_category(
|
||||||
"global_news",
|
"global_news",
|
||||||
|
|
@ -377,21 +385,20 @@ def get_reddit_company_news(
|
||||||
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
|
str: A formatted dataframe containing the latest news articles posts on reddit and meta information in these columns: "created_utc", "id", "title", "selftext", "score", "num_comments", "url"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
before, start_date_str = parse_date_range(start_date, look_back_days)
|
||||||
before = start_date - relativedelta(days=look_back_days)
|
start_date_dt = datetime.strptime(start_date_str, "%Y-%m-%d")
|
||||||
before = before.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
posts = []
|
posts = []
|
||||||
# iterate from start_date to end_date
|
# iterate from start_date to end_date
|
||||||
curr_date = datetime.strptime(before, "%Y-%m-%d")
|
curr_date = datetime.strptime(before, "%Y-%m-%d")
|
||||||
|
|
||||||
total_iterations = (start_date - curr_date).days + 1
|
total_iterations = (start_date_dt - curr_date).days + 1
|
||||||
pbar = tqdm(
|
pbar = tqdm(
|
||||||
desc=f"Getting Company News for {ticker} on {start_date}",
|
desc=f"Getting Company News for {ticker} on {start_date_dt}",
|
||||||
total=total_iterations,
|
total=total_iterations,
|
||||||
)
|
)
|
||||||
|
|
||||||
while curr_date <= start_date:
|
while curr_date <= start_date_dt:
|
||||||
curr_date_str = curr_date.strftime("%Y-%m-%d")
|
curr_date_str = curr_date.strftime("%Y-%m-%d")
|
||||||
fetch_result = fetch_top_from_category(
|
fetch_result = fetch_top_from_category(
|
||||||
"company_news",
|
"company_news",
|
||||||
|
|
@ -509,8 +516,9 @@ def get_stock_stats_indicators_window(
|
||||||
)
|
)
|
||||||
|
|
||||||
end_date = curr_date
|
end_date = curr_date
|
||||||
curr_date = datetime.strptime(curr_date, "%Y-%m-%d")
|
before_str, _ = parse_date_range(curr_date, look_back_days)
|
||||||
before = curr_date - relativedelta(days=look_back_days)
|
before_dt = datetime.strptime(before_str, "%Y-%m-%d")
|
||||||
|
curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
|
|
||||||
if not online:
|
if not online:
|
||||||
# read from YFin data
|
# read from YFin data
|
||||||
|
|
@ -524,30 +532,30 @@ def get_stock_stats_indicators_window(
|
||||||
dates_in_df = data["Date"].astype(str).str[:10]
|
dates_in_df = data["Date"].astype(str).str[:10]
|
||||||
|
|
||||||
ind_string = ""
|
ind_string = ""
|
||||||
while curr_date >= before:
|
while curr_date_dt >= before_dt:
|
||||||
# only do the trading dates
|
# only do the trading dates
|
||||||
if curr_date.strftime("%Y-%m-%d") in dates_in_df.values:
|
if curr_date_dt.strftime("%Y-%m-%d") in dates_in_df.values:
|
||||||
indicator_value = get_stockstats_indicator(
|
indicator_value = get_stockstats_indicator(
|
||||||
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online
|
symbol, indicator, curr_date_dt.strftime("%Y-%m-%d"), online
|
||||||
)
|
)
|
||||||
|
|
||||||
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
ind_string += f"{curr_date_dt.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
||||||
|
|
||||||
curr_date = curr_date - relativedelta(days=1)
|
curr_date_dt = curr_date_dt - relativedelta(days=1)
|
||||||
else:
|
else:
|
||||||
# online gathering
|
# online gathering
|
||||||
ind_string = ""
|
ind_string = ""
|
||||||
while curr_date >= before:
|
while curr_date_dt >= before_dt:
|
||||||
indicator_value = get_stockstats_indicator(
|
indicator_value = get_stockstats_indicator(
|
||||||
symbol, indicator, curr_date.strftime("%Y-%m-%d"), online
|
symbol, indicator, curr_date_dt.strftime("%Y-%m-%d"), online
|
||||||
)
|
)
|
||||||
|
|
||||||
ind_string += f"{curr_date.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
ind_string += f"{curr_date_dt.strftime('%Y-%m-%d')}: {indicator_value}\n"
|
||||||
|
|
||||||
curr_date = curr_date - relativedelta(days=1)
|
curr_date_dt = curr_date_dt - relativedelta(days=1)
|
||||||
|
|
||||||
result_str = (
|
result_str = (
|
||||||
f"## {indicator} values from {before.strftime('%Y-%m-%d')} to {end_date}:\n\n"
|
f"## {indicator} values from {before_dt.strftime('%Y-%m-%d')} to {end_date}:\n\n"
|
||||||
+ ind_string
|
+ ind_string
|
||||||
+ "\n\n"
|
+ "\n\n"
|
||||||
+ best_ind_params.get(indicator, "No description available.")
|
+ best_ind_params.get(indicator, "No description available.")
|
||||||
|
|
@ -565,20 +573,20 @@ def get_stockstats_indicator(
|
||||||
online: Annotated[bool, "to fetch data online or offline"],
|
online: Annotated[bool, "to fetch data online or offline"],
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
||||||
curr_date = datetime.strptime(curr_date, "%Y-%m-%d")
|
curr_date_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||||
curr_date = curr_date.strftime("%Y-%m-%d")
|
curr_date_str = curr_date_dt.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
indicator_value = StockstatsUtils.get_stock_stats(
|
indicator_value = StockstatsUtils.get_stock_stats(
|
||||||
symbol,
|
symbol,
|
||||||
indicator,
|
indicator,
|
||||||
curr_date,
|
curr_date_str,
|
||||||
os.path.join(DATA_DIR, "market_data", "price_data"),
|
os.path.join(DATA_DIR, "market_data", "price_data"),
|
||||||
online=online,
|
online=online,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(
|
print(
|
||||||
f"Error getting stockstats indicator data for indicator {indicator} on {curr_date}: {e}"
|
f"Error getting stockstats indicator data for indicator {indicator} on {curr_date_str}: {e}"
|
||||||
)
|
)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
@ -591,9 +599,7 @@ def get_YFin_data_window(
|
||||||
look_back_days: Annotated[int, "how many days to look back"],
|
look_back_days: Annotated[int, "how many days to look back"],
|
||||||
) -> str:
|
) -> str:
|
||||||
# calculate past days
|
# calculate past days
|
||||||
date_obj = datetime.strptime(curr_date, "%Y-%m-%d")
|
start_date, _ = parse_date_range(curr_date, look_back_days)
|
||||||
before = date_obj - relativedelta(days=look_back_days)
|
|
||||||
start_date = before.strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
# read in data
|
# read in data
|
||||||
data = pd.read_csv(
|
data = pd.read_csv(
|
||||||
|
|
@ -703,9 +709,13 @@ def get_YFin_data(
|
||||||
return filtered_data
|
return filtered_data
|
||||||
|
|
||||||
|
|
||||||
|
# Enhanced search provider factory instance (singleton)
|
||||||
|
_search_factory = create_search_provider_factory()
|
||||||
|
|
||||||
|
|
||||||
def get_stock_news(ticker, curr_date):
|
def get_stock_news(ticker, curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
search_provider = SearchProviderFactory.create_provider(config)
|
search_provider = _search_factory.create_provider(config)
|
||||||
query = f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period."
|
query = f"Can you search Social Media for {ticker} from 7 days before {curr_date} to {curr_date}? Make sure you only get the data posted during that period."
|
||||||
return search_provider.search(query)
|
return search_provider.search(query)
|
||||||
|
|
||||||
|
|
@ -713,7 +723,7 @@ def get_stock_news(ticker, curr_date):
|
||||||
|
|
||||||
def get_global_news(curr_date):
|
def get_global_news(curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
search_provider = SearchProviderFactory.create_provider(config)
|
search_provider = _search_factory.create_provider(config)
|
||||||
query = f"Search for global macroeconomic news and financial market updates from 7 days before {curr_date} to {curr_date}. Focus on central bank decisions, economic indicators, geopolitical events, and market-moving news that would be important for trading decisions."
|
query = f"Search for global macroeconomic news and financial market updates from 7 days before {curr_date} to {curr_date}. Focus on central bank decisions, economic indicators, geopolitical events, and market-moving news that would be important for trading decisions."
|
||||||
return search_provider.search(query)
|
return search_provider.search(query)
|
||||||
|
|
||||||
|
|
@ -721,7 +731,7 @@ def get_global_news(curr_date):
|
||||||
|
|
||||||
def get_fundamentals(ticker, curr_date):
|
def get_fundamentals(ticker, curr_date):
|
||||||
config = get_config()
|
config = get_config()
|
||||||
search_provider = SearchProviderFactory.create_provider(config)
|
search_provider = _search_factory.create_provider(config)
|
||||||
query = f"Search for fundamental analysis data and financial metrics for {ticker} stock from the month before {curr_date} to the month of {curr_date}. Look for earnings reports, financial ratios like PE, PS, cash flow, revenue growth, analyst ratings, and any fundamental analysis discussions. Please present key metrics in a structured format."
|
query = f"Search for fundamental analysis data and financial metrics for {ticker} stock from the month before {curr_date} to the month of {curr_date}. Look for earnings reports, financial ratios like PE, PS, cash flow, revenue growth, analyst ratings, and any fundamental analysis discussions. Please present key metrics in a structured format."
|
||||||
return search_provider.search(query)
|
return search_provider.search(query)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
|
||||||
|
|
||||||
class SearchProvider(ABC):
|
class SearchProvider(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search(self, query: str, ticker: str, curr_date: str) -> str:
|
def search(self, query: str) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,47 +1,133 @@
|
||||||
from .search_provider import (
|
from .search_provider import SearchProvider
|
||||||
SearchProvider,
|
|
||||||
GoogleSearchProvider,
|
|
||||||
OpenAISearchProvider
|
|
||||||
)
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
from typing import Dict, Callable, Any
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
class SearchProviderFactory:
|
class ProviderSelector(ABC):
|
||||||
_cache = {} # 클래스 레벨 캐시
|
"""Abstract base class for provider selection strategies."""
|
||||||
|
|
||||||
@staticmethod
|
@abstractmethod
|
||||||
def create_provider(config: dict[str, any]) -> SearchProvider:
|
def select_provider_type(self, config: Dict[str, Any]) -> str:
|
||||||
|
"""Select provider type based on configuration."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MappingBasedProviderSelector(ProviderSelector):
|
||||||
|
"""Selects provider based on URL pattern mapping table."""
|
||||||
|
|
||||||
|
def __init__(self, mappings: Dict[str, str], default_provider: str = "openai"):
|
||||||
|
self._mappings = mappings
|
||||||
|
self._default_provider = default_provider
|
||||||
|
|
||||||
|
def select_provider_type(self, config: Dict[str, Any]) -> str:
|
||||||
|
backend_url = config.get("backend_url", "")
|
||||||
|
for pattern, provider_type in self._mappings.items():
|
||||||
|
if pattern in backend_url:
|
||||||
|
return provider_type
|
||||||
|
return self._default_provider
|
||||||
|
|
||||||
|
|
||||||
|
class SearchProviderRegistry:
|
||||||
|
"""Registry for search provider creation functions."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._providers: Dict[str, Callable[[Dict[str, Any]], SearchProvider]] = {}
|
||||||
|
|
||||||
|
def register(self, provider_type: str, creator: Callable[[Dict[str, Any]], SearchProvider]):
|
||||||
|
"""Register a provider creator function."""
|
||||||
|
self._providers[provider_type] = creator
|
||||||
|
|
||||||
|
def create(self, provider_type: str, config: Dict[str, Any]) -> SearchProvider:
|
||||||
|
"""Create a provider instance using registered creator."""
|
||||||
|
if provider_type not in self._providers:
|
||||||
|
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||||
|
return self._providers[provider_type](config)
|
||||||
|
|
||||||
|
def get_available_types(self) -> list[str]:
|
||||||
|
"""Get list of available provider types."""
|
||||||
|
return list(self._providers.keys())
|
||||||
|
|
||||||
|
|
||||||
|
class SearchProviderFactoryImpl:
|
||||||
|
"""Enhanced factory for creating SearchProvider instances with caching and extensibility."""
|
||||||
|
|
||||||
|
def __init__(self, registry: SearchProviderRegistry, selector: ProviderSelector):
|
||||||
|
self._registry = registry
|
||||||
|
self._selector = selector
|
||||||
|
self._cache: Dict[str, SearchProvider] = {}
|
||||||
|
|
||||||
|
def create_provider(self, config: Dict[str, Any]) -> SearchProvider:
|
||||||
"""
|
"""
|
||||||
Create a SearchProvider with caching to avoid creating new instances.
|
Create a SearchProvider with caching to avoid creating new instances.
|
||||||
Uses config hash as cache key for efficient reuse.
|
Uses config hash as cache key for efficient reuse.
|
||||||
"""
|
"""
|
||||||
# Create cache key from relevant config values
|
# Create cache key from relevant config values
|
||||||
cache_key_data = {
|
cache_key_data = {
|
||||||
"backend_url": config["backend_url"],
|
"backend_url": config.get("backend_url", ""),
|
||||||
"model": config["quick_think_llm"]
|
"model": config.get("quick_think_llm", "")
|
||||||
}
|
}
|
||||||
cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest()
|
cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest()
|
||||||
|
|
||||||
# Return cached instance if exists
|
# Return cached instance if exists
|
||||||
if cache_key in SearchProviderFactory._cache:
|
if cache_key in self._cache:
|
||||||
return SearchProviderFactory._cache[cache_key]
|
return self._cache[cache_key]
|
||||||
|
|
||||||
# Create new instance
|
# Select and create provider
|
||||||
backend_url = config["backend_url"]
|
provider_type = self._selector.select_provider_type(config)
|
||||||
model = config["quick_think_llm"]
|
provider = self._registry.create(provider_type, config)
|
||||||
|
|
||||||
if "generativelanguage.googleapis.com" in backend_url:
|
|
||||||
provider = GoogleSearchProvider(model)
|
|
||||||
else:
|
|
||||||
provider = OpenAISearchProvider(model, backend_url)
|
|
||||||
|
|
||||||
# Cache and return
|
# Cache and return
|
||||||
SearchProviderFactory._cache[cache_key] = provider
|
self._cache[cache_key] = provider
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
def clear_cache(self):
|
||||||
|
"""Clear the provider cache (useful for testing or config changes)."""
|
||||||
|
self._cache.clear()
|
||||||
|
|
||||||
|
def get_available_provider_types(self) -> list[str]:
|
||||||
|
"""Get list of available provider types."""
|
||||||
|
return self._registry.get_available_types()
|
||||||
|
|
||||||
|
|
||||||
|
def create_search_provider_factory() -> SearchProviderFactoryImpl:
|
||||||
|
"""Create a configured SearchProviderFactory with default providers."""
|
||||||
|
registry = SearchProviderRegistry()
|
||||||
|
|
||||||
|
# Register default providers
|
||||||
|
def create_google_provider(config: Dict[str, Any]) -> SearchProvider:
|
||||||
|
from .search_provider import GoogleSearchProvider
|
||||||
|
return GoogleSearchProvider(config["quick_think_llm"])
|
||||||
|
|
||||||
|
def create_openai_provider(config: Dict[str, Any]) -> SearchProvider:
|
||||||
|
from .search_provider import OpenAISearchProvider
|
||||||
|
return OpenAISearchProvider(config["quick_think_llm"], config["backend_url"])
|
||||||
|
|
||||||
|
registry.register("google", create_google_provider)
|
||||||
|
registry.register("openai", create_openai_provider)
|
||||||
|
|
||||||
|
# Create URL pattern mappings (easily extensible)
|
||||||
|
url_mappings = {
|
||||||
|
"generativelanguage.googleapis.com": "google",
|
||||||
|
"api.openai.com": "openai",
|
||||||
|
}
|
||||||
|
|
||||||
|
selector = MappingBasedProviderSelector(url_mappings, default_provider="openai")
|
||||||
|
return SearchProviderFactoryImpl(registry, selector)
|
||||||
|
|
||||||
|
|
||||||
|
# Backward compatibility - singleton instance
|
||||||
|
_default_factory = create_search_provider_factory()
|
||||||
|
|
||||||
|
|
||||||
|
class SearchProviderFactory:
|
||||||
|
"""Backward compatibility wrapper for the old static factory."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_provider(config: Dict[str, Any]) -> SearchProvider:
|
||||||
|
return _default_factory.create_provider(config)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def clear_cache():
|
def clear_cache():
|
||||||
"""Clear the provider cache (useful for testing or config changes)."""
|
_default_factory.clear_cache()
|
||||||
SearchProviderFactory._cache.clear()
|
|
||||||
|
|
||||||
|
|
@ -3,7 +3,7 @@ import os
|
||||||
DEFAULT_CONFIG = {
|
DEFAULT_CONFIG = {
|
||||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||||
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
|
"data_dir": os.getenv("TRADINGAGENTS_DATA_DIR", "./data"),
|
||||||
"data_cache_dir": os.path.join(
|
"data_cache_dir": os.path.join(
|
||||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||||
"dataflows/data_cache",
|
"dataflows/data_cache",
|
||||||
|
|
|
||||||
|
|
@ -127,7 +127,7 @@ class TradingAgentsGraph:
|
||||||
# online tools
|
# online tools
|
||||||
self.toolkit.get_stock_news,
|
self.toolkit.get_stock_news,
|
||||||
# offline tools
|
# offline tools
|
||||||
self.toolkit.get_reddit_stock_info,
|
# self.toolkit.get_reddit_stock_info,
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
"news": ToolNode(
|
"news": ToolNode(
|
||||||
|
|
@ -136,8 +136,8 @@ class TradingAgentsGraph:
|
||||||
self.toolkit.get_global_news,
|
self.toolkit.get_global_news,
|
||||||
self.toolkit.get_google_news,
|
self.toolkit.get_google_news,
|
||||||
# offline tools
|
# offline tools
|
||||||
self.toolkit.get_finnhub_news,
|
# self.toolkit.get_finnhub_news,
|
||||||
self.toolkit.get_reddit_news,
|
# self.toolkit.get_reddit_news,
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
"fundamentals": ToolNode(
|
"fundamentals": ToolNode(
|
||||||
|
|
@ -145,11 +145,11 @@ class TradingAgentsGraph:
|
||||||
# online tools
|
# online tools
|
||||||
self.toolkit.get_fundamentals,
|
self.toolkit.get_fundamentals,
|
||||||
# offline tools
|
# offline tools
|
||||||
self.toolkit.get_finnhub_company_insider_sentiment,
|
# self.toolkit.get_finnhub_company_insider_sentiment,
|
||||||
self.toolkit.get_finnhub_company_insider_transactions,
|
# self.toolkit.get_finnhub_company_insider_transactions,
|
||||||
self.toolkit.get_simfin_balance_sheet,
|
# self.toolkit.get_simfin_balance_sheet,
|
||||||
self.toolkit.get_simfin_cashflow,
|
# self.toolkit.get_simfin_cashflow,
|
||||||
self.toolkit.get_simfin_income_stmt,
|
# self.toolkit.get_simfin_income_stmt,
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue