temp
This commit is contained in:
parent
c3e609730b
commit
868ce0b37b
|
|
@ -1,10 +1,10 @@
|
|||
from enum import Enum
|
||||
from typing import List, Optional, Dict
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AnalystType(str, Enum):
|
||||
MARKET = "market"
|
||||
SOCIAL = "social"
|
||||
NEWS = "news"
|
||||
FUNDAMENTALS = "fundamentals"
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Dict
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AnalystType(str, Enum):
|
||||
MARKET = "market"
|
||||
SOCIAL = "social"
|
||||
NEWS = "news"
|
||||
FUNDAMENTALS = "fundamentals"
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
______ ___ ___ __
|
||||
/_ __/________ _____/ (_)___ ____ _/ | ____ ____ ____ / /______
|
||||
/ / / ___/ __ `/ __ / / __ \/ __ `/ /| |/ __ `/ _ \/ __ \/ __/ ___/
|
||||
/ / / / / /_/ / /_/ / / / / / /_/ / ___ / /_/ / __/ / / / /_(__ )
|
||||
/_/ /_/ \__,_/\__,_/_/_/ /_/\__, /_/ |_\__, /\___/_/ /_/\__/____/
|
||||
/____/ /____/
|
||||
|
||||
______ ___ ___ __
|
||||
/_ __/________ _____/ (_)___ ____ _/ | ____ ____ ____ / /______
|
||||
/ / / ___/ __ `/ __ / / __ \/ __ `/ /| |/ __ `/ _ \/ __ \/ __/ ___/
|
||||
/ / / / / /_/ / /_/ / / / / / /_/ / ___ / /_/ / __/ / / / /_(__ )
|
||||
/_/ /_/ \__,_/\__,_/_/_/ /_/\__, /_/ |_\__, /\___/_/ /_/\__/____/
|
||||
/____/ /____/
|
||||
|
|
@ -1,41 +1,41 @@
|
|||
from .utils.agent_utils import Toolkit, create_msg_delete
|
||||
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
||||
from .utils.memory import FinancialSituationMemory
|
||||
|
||||
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
||||
from .analysts.market_analyst import create_market_analyst
|
||||
from .analysts.news_analyst import create_news_analyst
|
||||
from .analysts.social_media_analyst import create_social_media_analyst
|
||||
|
||||
from .researchers.bear_researcher import create_bear_researcher
|
||||
from .researchers.bull_researcher import create_bull_researcher
|
||||
|
||||
from .risk_mgmt.aggresive_debator import create_risky_debator
|
||||
from .risk_mgmt.conservative_debator import create_safe_debator
|
||||
from .risk_mgmt.neutral_debator import create_neutral_debator
|
||||
|
||||
from .managers.research_manager import create_research_manager
|
||||
from .managers.risk_manager import create_risk_manager
|
||||
|
||||
from .trader.trader import create_trader
|
||||
|
||||
__all__ = [
|
||||
"FinancialSituationMemory",
|
||||
"Toolkit",
|
||||
"AgentState",
|
||||
"create_msg_delete",
|
||||
"InvestDebateState",
|
||||
"RiskDebateState",
|
||||
"create_bear_researcher",
|
||||
"create_bull_researcher",
|
||||
"create_research_manager",
|
||||
"create_fundamentals_analyst",
|
||||
"create_market_analyst",
|
||||
"create_neutral_debator",
|
||||
"create_news_analyst",
|
||||
"create_risky_debator",
|
||||
"create_risk_manager",
|
||||
"create_safe_debator",
|
||||
"create_social_media_analyst",
|
||||
"create_trader",
|
||||
]
|
||||
from .utils.agent_utils import Toolkit, create_msg_delete
|
||||
from .utils.agent_states import AgentState, InvestDebateState, RiskDebateState
|
||||
from .utils.memory import FinancialSituationMemory
|
||||
|
||||
from .analysts.fundamentals_analyst import create_fundamentals_analyst
|
||||
from .analysts.market_analyst import create_market_analyst
|
||||
from .analysts.news_analyst import create_news_analyst
|
||||
from .analysts.social_media_analyst import create_social_media_analyst
|
||||
|
||||
from .researchers.bear_researcher import create_bear_researcher
|
||||
from .researchers.bull_researcher import create_bull_researcher
|
||||
|
||||
from .risk_mgmt.aggresive_debator import create_risky_debator
|
||||
from .risk_mgmt.conservative_debator import create_safe_debator
|
||||
from .risk_mgmt.neutral_debator import create_neutral_debator
|
||||
|
||||
from .managers.research_manager import create_research_manager
|
||||
from .managers.risk_manager import create_risk_manager
|
||||
|
||||
from .trader.trader import create_trader
|
||||
|
||||
__all__ = [
|
||||
"FinancialSituationMemory",
|
||||
"Toolkit",
|
||||
"AgentState",
|
||||
"create_msg_delete",
|
||||
"InvestDebateState",
|
||||
"RiskDebateState",
|
||||
"create_bear_researcher",
|
||||
"create_bull_researcher",
|
||||
"create_research_manager",
|
||||
"create_fundamentals_analyst",
|
||||
"create_market_analyst",
|
||||
"create_neutral_debator",
|
||||
"create_news_analyst",
|
||||
"create_risky_debator",
|
||||
"create_risk_manager",
|
||||
"create_safe_debator",
|
||||
"create_social_media_analyst",
|
||||
"create_trader",
|
||||
]
|
||||
|
|
@ -1,64 +1,64 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_fundamentals_analyst(llm, toolkit):
|
||||
def fundamentals_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
company_name = state["company_of_interest"]
|
||||
|
||||
if toolkit.config["online_tools"]:
|
||||
tools = [toolkit.get_fundamentals]
|
||||
else:
|
||||
tools = [
|
||||
toolkit.get_finnhub_company_insider_sentiment,
|
||||
toolkit.get_finnhub_company_insider_transactions,
|
||||
toolkit.get_simfin_balance_sheet,
|
||||
toolkit.get_simfin_cashflow,
|
||||
toolkit.get_simfin_income_stmt,
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||
+ " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.",
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. The company we want to look at is {ticker}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(ticker=ticker)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"fundamentals_report": report,
|
||||
}
|
||||
|
||||
return fundamentals_analyst_node
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_fundamentals_analyst(llm, toolkit):
|
||||
def fundamentals_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
company_name = state["company_of_interest"]
|
||||
|
||||
if toolkit.config["online_tools"]:
|
||||
tools = [toolkit.get_fundamentals]
|
||||
else:
|
||||
tools = [
|
||||
toolkit.get_finnhub_company_insider_sentiment,
|
||||
toolkit.get_finnhub_company_insider_transactions,
|
||||
toolkit.get_simfin_balance_sheet,
|
||||
toolkit.get_simfin_cashflow,
|
||||
toolkit.get_simfin_income_stmt,
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a researcher tasked with analyzing fundamental information over the past week about a company. Please write a comprehensive report of the company's fundamental information such as financial documents, company profile, basic company financials, company financial history, insider sentiment and insider transactions to gain a full view of the company's fundamental information to inform traders. Make sure to include as much detail as possible. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||
+ " Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read.",
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. The company we want to look at is {ticker}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(ticker=ticker)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"fundamentals_report": report,
|
||||
}
|
||||
|
||||
return fundamentals_analyst_node
|
||||
|
|
@ -1,91 +1,91 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_market_analyst(llm, toolkit):
|
||||
|
||||
def market_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
company_name = state["company_of_interest"]
|
||||
|
||||
if toolkit.config["online_tools"]:
|
||||
tools = [
|
||||
toolkit.get_YFin_data_online,
|
||||
toolkit.get_stockstats_indicators_report_online,
|
||||
]
|
||||
else:
|
||||
tools = [
|
||||
toolkit.get_YFin_data,
|
||||
toolkit.get_stockstats_indicators_report,
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
|
||||
|
||||
Moving Averages:
|
||||
- close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.
|
||||
- close_200_sma: 200 SMA: A long-term trend benchmark. Usage: Confirm overall market trend and identify golden/death cross setups. Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.
|
||||
- close_10_ema: 10 EMA: A responsive short-term average. Usage: Capture quick shifts in momentum and potential entry points. Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.
|
||||
|
||||
MACD Related:
|
||||
- macd: MACD: Computes momentum via differences of EMAs. Usage: Look for crossovers and divergence as signals of trend changes. Tips: Confirm with other indicators in low-volatility or sideways markets.
|
||||
- macds: MACD Signal: An EMA smoothing of the MACD line. Usage: Use crossovers with the MACD line to trigger trades. Tips: Should be part of a broader strategy to avoid false positives.
|
||||
- macdh: MACD Histogram: Shows the gap between the MACD line and its signal. Usage: Visualize momentum strength and spot divergence early. Tips: Can be volatile; complement with additional filters in fast-moving markets.
|
||||
|
||||
Momentum Indicators:
|
||||
- rsi: RSI: Measures momentum to flag overbought/oversold conditions. Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.
|
||||
|
||||
Volatility Indicators:
|
||||
- boll: Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. Usage: Acts as a dynamic benchmark for price movement. Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.
|
||||
- boll_ub: Bollinger Upper Band: Typically 2 standard deviations above the middle line. Usage: Signals potential overbought conditions and breakout zones. Tips: Confirm signals with other tools; prices may ride the band in strong trends.
|
||||
- boll_lb: Bollinger Lower Band: Typically 2 standard deviations below the middle line. Usage: Indicates potential oversold conditions. Tips: Use additional analysis to avoid false reversal signals.
|
||||
- atr: ATR: Averages true range to measure volatility. Usage: Set stop-loss levels and adjust position sizes based on current market volatility. Tips: It's a reactive measure, so use it as part of a broader risk management strategy.
|
||||
|
||||
Volume-Based Indicators:
|
||||
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
|
||||
|
||||
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
|
||||
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. The company we want to look at is {ticker}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(ticker=ticker)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"market_report": report,
|
||||
}
|
||||
|
||||
return market_analyst_node
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_market_analyst(llm, toolkit):
|
||||
|
||||
def market_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
company_name = state["company_of_interest"]
|
||||
|
||||
if toolkit.config["online_tools"]:
|
||||
tools = [
|
||||
toolkit.get_YFin_data_online,
|
||||
toolkit.get_stockstats_indicators_report_online,
|
||||
]
|
||||
else:
|
||||
tools = [
|
||||
toolkit.get_YFin_data,
|
||||
toolkit.get_stockstats_indicators_report,
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are a trading assistant tasked with analyzing financial markets. Your role is to select the **most relevant indicators** for a given market condition or trading strategy from the following list. The goal is to choose up to **8 indicators** that provide complementary insights without redundancy. Categories and each category's indicators are:
|
||||
|
||||
Moving Averages:
|
||||
- close_50_sma: 50 SMA: A medium-term trend indicator. Usage: Identify trend direction and serve as dynamic support/resistance. Tips: It lags price; combine with faster indicators for timely signals.
|
||||
- close_200_sma: 200 SMA: A long-term trend benchmark. Usage: Confirm overall market trend and identify golden/death cross setups. Tips: It reacts slowly; best for strategic trend confirmation rather than frequent trading entries.
|
||||
- close_10_ema: 10 EMA: A responsive short-term average. Usage: Capture quick shifts in momentum and potential entry points. Tips: Prone to noise in choppy markets; use alongside longer averages for filtering false signals.
|
||||
|
||||
MACD Related:
|
||||
- macd: MACD: Computes momentum via differences of EMAs. Usage: Look for crossovers and divergence as signals of trend changes. Tips: Confirm with other indicators in low-volatility or sideways markets.
|
||||
- macds: MACD Signal: An EMA smoothing of the MACD line. Usage: Use crossovers with the MACD line to trigger trades. Tips: Should be part of a broader strategy to avoid false positives.
|
||||
- macdh: MACD Histogram: Shows the gap between the MACD line and its signal. Usage: Visualize momentum strength and spot divergence early. Tips: Can be volatile; complement with additional filters in fast-moving markets.
|
||||
|
||||
Momentum Indicators:
|
||||
- rsi: RSI: Measures momentum to flag overbought/oversold conditions. Usage: Apply 70/30 thresholds and watch for divergence to signal reversals. Tips: In strong trends, RSI may remain extreme; always cross-check with trend analysis.
|
||||
|
||||
Volatility Indicators:
|
||||
- boll: Bollinger Middle: A 20 SMA serving as the basis for Bollinger Bands. Usage: Acts as a dynamic benchmark for price movement. Tips: Combine with the upper and lower bands to effectively spot breakouts or reversals.
|
||||
- boll_ub: Bollinger Upper Band: Typically 2 standard deviations above the middle line. Usage: Signals potential overbought conditions and breakout zones. Tips: Confirm signals with other tools; prices may ride the band in strong trends.
|
||||
- boll_lb: Bollinger Lower Band: Typically 2 standard deviations below the middle line. Usage: Indicates potential oversold conditions. Tips: Use additional analysis to avoid false reversal signals.
|
||||
- atr: ATR: Averages true range to measure volatility. Usage: Set stop-loss levels and adjust position sizes based on current market volatility. Tips: It's a reactive measure, so use it as part of a broader risk management strategy.
|
||||
|
||||
Volume-Based Indicators:
|
||||
- vwma: VWMA: A moving average weighted by volume. Usage: Confirm trends by integrating price action with volume data. Tips: Watch for skewed results from volume spikes; use in combination with other volume analyses.
|
||||
|
||||
- Select indicators that provide diverse and complementary information. Avoid redundancy (e.g., do not select both rsi and stochrsi). Also briefly explain why they are suitable for the given market context. When you tool call, please use the exact name of the indicators provided above as they are defined parameters, otherwise your call will fail. Please make sure to call get_YFin_data first to retrieve the CSV that is needed to generate indicators. Write a very detailed and nuanced report of the trends you observe. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."""
|
||||
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. The company we want to look at is {ticker}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(ticker=ticker)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"market_report": report,
|
||||
}
|
||||
|
||||
return market_analyst_node
|
||||
|
|
@ -1,60 +1,60 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_news_analyst(llm, toolkit):
|
||||
def news_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
|
||||
if toolkit.config["online_tools"]:
|
||||
tools = [toolkit.get_global_news, toolkit.get_google_news]
|
||||
else:
|
||||
tools = [
|
||||
toolkit.get_finnhub_news,
|
||||
toolkit.get_reddit_news,
|
||||
toolkit.get_google_news,
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. We are looking at the company {ticker}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(ticker=ticker)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"news_report": report,
|
||||
}
|
||||
|
||||
return news_analyst_node
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_news_analyst(llm, toolkit):
|
||||
def news_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
|
||||
if toolkit.config["online_tools"]:
|
||||
tools = [toolkit.get_global_news, toolkit.get_google_news]
|
||||
else:
|
||||
tools = [
|
||||
toolkit.get_finnhub_news,
|
||||
toolkit.get_reddit_news,
|
||||
toolkit.get_google_news,
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Look at news from EODHD, and finnhub to be comprehensive. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. We are looking at the company {ticker}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(ticker=ticker)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"news_report": report,
|
||||
}
|
||||
|
||||
return news_analyst_node
|
||||
|
|
@ -1,60 +1,60 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_social_media_analyst(llm, toolkit):
|
||||
def social_media_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
company_name = state["company_of_interest"]
|
||||
|
||||
if toolkit.config["online_tools"]:
|
||||
tools = [toolkit.get_stock_news]
|
||||
else:
|
||||
tools = [
|
||||
toolkit.get_reddit_stock_info,
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(ticker=ticker)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"sentiment_report": report,
|
||||
}
|
||||
|
||||
return social_media_analyst_node
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_social_media_analyst(llm, toolkit):
|
||||
def social_media_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
company_name = state["company_of_interest"]
|
||||
|
||||
if toolkit.config["online_tools"]:
|
||||
tools = [toolkit.get_stock_news]
|
||||
else:
|
||||
tools = [
|
||||
toolkit.get_reddit_stock_info,
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are a social media and company specific news researcher/analyst tasked with analyzing social media posts, recent company news, and public sentiment for a specific company over the past week. You will be given a company's name your objective is to write a comprehensive long report detailing your analysis, insights, and implications for traders and investors on this company's current state after looking at social media and what people are saying about that company, analyzing sentiment data of what people feel each day about the company, and looking at recent company news. Try to look at all sources possible from social media to sentiment to news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||
+ """ Make sure to append a Makrdown table at the end of the report to organize key points in the report, organized and easy to read.""",
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. The current company we want to analyze is {ticker}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(ticker=ticker)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
result = chain.invoke(state["messages"])
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"sentiment_report": report,
|
||||
}
|
||||
|
||||
return social_media_analyst_node
|
||||
|
|
@ -1,57 +1,57 @@
|
|||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_research_manager(llm, memory):
|
||||
def research_manager_node(state) -> dict:
|
||||
history = state["investment_debate_state"].get("history", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
|
||||
|
||||
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—Buy, Sell, or Hold—must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
|
||||
|
||||
Additionally, develop a detailed investment plan for the trader. This should include:
|
||||
|
||||
Your Recommendation: A decisive stance supported by the most convincing arguments.
|
||||
Rationale: An explanation of why these arguments lead to your conclusion.
|
||||
Strategic Actions: Concrete steps for implementing the recommendation.
|
||||
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
|
||||
|
||||
Here are your past reflections on mistakes:
|
||||
\"{past_memory_str}\"
|
||||
|
||||
Here is the debate:
|
||||
Debate History:
|
||||
{history}"""
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
new_investment_debate_state = {
|
||||
"judge_decision": response.content,
|
||||
"history": investment_debate_state.get("history", ""),
|
||||
"bear_history": investment_debate_state.get("bear_history", ""),
|
||||
"bull_history": investment_debate_state.get("bull_history", ""),
|
||||
"current_response": response.content,
|
||||
"count": investment_debate_state["count"],
|
||||
}
|
||||
|
||||
return {
|
||||
"investment_debate_state": new_investment_debate_state,
|
||||
"investment_plan": response.content,
|
||||
}
|
||||
|
||||
return research_manager_node
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_research_manager(llm, memory):
|
||||
def research_manager_node(state) -> dict:
|
||||
history = state["investment_debate_state"].get("history", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the portfolio manager and debate facilitator, your role is to critically evaluate this round of debate and make a definitive decision: align with the bear analyst, the bull analyst, or choose Hold only if it is strongly justified based on the arguments presented.
|
||||
|
||||
Summarize the key points from both sides concisely, focusing on the most compelling evidence or reasoning. Your recommendation—Buy, Sell, or Hold—must be clear and actionable. Avoid defaulting to Hold simply because both sides have valid points; commit to a stance grounded in the debate's strongest arguments.
|
||||
|
||||
Additionally, develop a detailed investment plan for the trader. This should include:
|
||||
|
||||
Your Recommendation: A decisive stance supported by the most convincing arguments.
|
||||
Rationale: An explanation of why these arguments lead to your conclusion.
|
||||
Strategic Actions: Concrete steps for implementing the recommendation.
|
||||
Take into account your past mistakes on similar situations. Use these insights to refine your decision-making and ensure you are learning and improving. Present your analysis conversationally, as if speaking naturally, without special formatting.
|
||||
|
||||
Here are your past reflections on mistakes:
|
||||
\"{past_memory_str}\"
|
||||
|
||||
Here is the debate:
|
||||
Debate History:
|
||||
{history}"""
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
new_investment_debate_state = {
|
||||
"judge_decision": response.content,
|
||||
"history": investment_debate_state.get("history", ""),
|
||||
"bear_history": investment_debate_state.get("bear_history", ""),
|
||||
"bull_history": investment_debate_state.get("bull_history", ""),
|
||||
"current_response": response.content,
|
||||
"count": investment_debate_state["count"],
|
||||
}
|
||||
|
||||
return {
|
||||
"investment_debate_state": new_investment_debate_state,
|
||||
"investment_plan": response.content,
|
||||
}
|
||||
|
||||
return research_manager_node
|
||||
|
|
@ -1,68 +1,68 @@
|
|||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_risk_manager(llm, memory):
|
||||
def risk_manager_node(state) -> dict:
|
||||
|
||||
company_name = state["company_of_interest"]
|
||||
|
||||
history = state["risk_debate_state"]["history"]
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
market_research_report = state["market_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["news_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
trader_plan = state["investment_plan"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
|
||||
|
||||
Guidelines for Decision-Making:
|
||||
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
|
||||
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
|
||||
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**, and adjust it based on the analysts' insights.
|
||||
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.
|
||||
|
||||
Deliverables:
|
||||
- A clear and actionable recommendation: Buy, Sell, or Hold.
|
||||
- Detailed reasoning anchored in the debate and past reflections.
|
||||
|
||||
---
|
||||
|
||||
**Analysts Debate History:**
|
||||
{history}
|
||||
|
||||
---
|
||||
|
||||
Focus on actionable insights and continuous improvement. Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
new_risk_debate_state = {
|
||||
"judge_decision": response.content,
|
||||
"history": risk_debate_state["history"],
|
||||
"risky_history": risk_debate_state["risky_history"],
|
||||
"safe_history": risk_debate_state["safe_history"],
|
||||
"neutral_history": risk_debate_state["neutral_history"],
|
||||
"latest_speaker": "Judge",
|
||||
"current_risky_response": risk_debate_state["current_risky_response"],
|
||||
"current_safe_response": risk_debate_state["current_safe_response"],
|
||||
"current_neutral_response": risk_debate_state["current_neutral_response"],
|
||||
"count": risk_debate_state["count"],
|
||||
}
|
||||
|
||||
return {
|
||||
"risk_debate_state": new_risk_debate_state,
|
||||
"final_trade_decision": response.content,
|
||||
}
|
||||
|
||||
return risk_manager_node
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_risk_manager(llm, memory):
|
||||
def risk_manager_node(state) -> dict:
|
||||
|
||||
company_name = state["company_of_interest"]
|
||||
|
||||
history = state["risk_debate_state"]["history"]
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
market_research_report = state["market_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["news_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
trader_plan = state["investment_plan"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Risk Management Judge and Debate Facilitator, your goal is to evaluate the debate between three risk analysts—Risky, Neutral, and Safe/Conservative—and determine the best course of action for the trader. Your decision must result in a clear recommendation: Buy, Sell, or Hold. Choose Hold only if strongly justified by specific arguments, not as a fallback when all sides seem valid. Strive for clarity and decisiveness.
|
||||
|
||||
Guidelines for Decision-Making:
|
||||
1. **Summarize Key Arguments**: Extract the strongest points from each analyst, focusing on relevance to the context.
|
||||
2. **Provide Rationale**: Support your recommendation with direct quotes and counterarguments from the debate.
|
||||
3. **Refine the Trader's Plan**: Start with the trader's original plan, **{trader_plan}**, and adjust it based on the analysts' insights.
|
||||
4. **Learn from Past Mistakes**: Use lessons from **{past_memory_str}** to address prior misjudgments and improve the decision you are making now to make sure you don't make a wrong BUY/SELL/HOLD call that loses money.
|
||||
|
||||
Deliverables:
|
||||
- A clear and actionable recommendation: Buy, Sell, or Hold.
|
||||
- Detailed reasoning anchored in the debate and past reflections.
|
||||
|
||||
---
|
||||
|
||||
**Analysts Debate History:**
|
||||
{history}
|
||||
|
||||
---
|
||||
|
||||
Focus on actionable insights and continuous improvement. Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
new_risk_debate_state = {
|
||||
"judge_decision": response.content,
|
||||
"history": risk_debate_state["history"],
|
||||
"risky_history": risk_debate_state["risky_history"],
|
||||
"safe_history": risk_debate_state["safe_history"],
|
||||
"neutral_history": risk_debate_state["neutral_history"],
|
||||
"latest_speaker": "Judge",
|
||||
"current_risky_response": risk_debate_state["current_risky_response"],
|
||||
"current_safe_response": risk_debate_state["current_safe_response"],
|
||||
"current_neutral_response": risk_debate_state["current_neutral_response"],
|
||||
"count": risk_debate_state["count"],
|
||||
}
|
||||
|
||||
return {
|
||||
"risk_debate_state": new_risk_debate_state,
|
||||
"final_trade_decision": response.content,
|
||||
}
|
||||
|
||||
return risk_manager_node
|
||||
|
|
@ -1,63 +1,63 @@
|
|||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_bear_researcher(llm, memory):
|
||||
def bear_node(state) -> dict:
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
history = investment_debate_state.get("history", "")
|
||||
bear_history = investment_debate_state.get("bear_history", "")
|
||||
|
||||
current_response = investment_debate_state.get("current_response", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
|
||||
|
||||
Key points to focus on:
|
||||
|
||||
- Risks and Challenges: Highlight factors like market saturation, financial instability, or macroeconomic threats that could hinder the stock's performance.
|
||||
- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation, or threats from competitors.
|
||||
- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position.
|
||||
- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning, exposing weaknesses or over-optimistic assumptions.
|
||||
- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points and debating effectively rather than simply listing facts.
|
||||
|
||||
Resources available:
|
||||
|
||||
Market research report: {market_research_report}
|
||||
Social media sentiment report: {sentiment_report}
|
||||
Latest world affairs news: {news_report}
|
||||
Company fundamentals report: {fundamentals_report}
|
||||
Conversation history of the debate: {history}
|
||||
Last bull argument: {current_response}
|
||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
||||
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
|
||||
"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Bear Analyst: {response.content}"
|
||||
|
||||
new_investment_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"bear_history": bear_history + "\n" + argument,
|
||||
"bull_history": investment_debate_state.get("bull_history", ""),
|
||||
"current_response": argument,
|
||||
"count": investment_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"investment_debate_state": new_investment_debate_state}
|
||||
|
||||
return bear_node
|
||||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_bear_researcher(llm, memory):
|
||||
def bear_node(state) -> dict:
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
history = investment_debate_state.get("history", "")
|
||||
bear_history = investment_debate_state.get("bear_history", "")
|
||||
|
||||
current_response = investment_debate_state.get("current_response", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are a Bear Analyst making the case against investing in the stock. Your goal is to present a well-reasoned argument emphasizing risks, challenges, and negative indicators. Leverage the provided research and data to highlight potential downsides and counter bullish arguments effectively.
|
||||
|
||||
Key points to focus on:
|
||||
|
||||
- Risks and Challenges: Highlight factors like market saturation, financial instability, or macroeconomic threats that could hinder the stock's performance.
|
||||
- Competitive Weaknesses: Emphasize vulnerabilities such as weaker market positioning, declining innovation, or threats from competitors.
|
||||
- Negative Indicators: Use evidence from financial data, market trends, or recent adverse news to support your position.
|
||||
- Bull Counterpoints: Critically analyze the bull argument with specific data and sound reasoning, exposing weaknesses or over-optimistic assumptions.
|
||||
- Engagement: Present your argument in a conversational style, directly engaging with the bull analyst's points and debating effectively rather than simply listing facts.
|
||||
|
||||
Resources available:
|
||||
|
||||
Market research report: {market_research_report}
|
||||
Social media sentiment report: {sentiment_report}
|
||||
Latest world affairs news: {news_report}
|
||||
Company fundamentals report: {fundamentals_report}
|
||||
Conversation history of the debate: {history}
|
||||
Last bull argument: {current_response}
|
||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
||||
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
|
||||
"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Bear Analyst: {response.content}"
|
||||
|
||||
new_investment_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"bear_history": bear_history + "\n" + argument,
|
||||
"bull_history": investment_debate_state.get("bull_history", ""),
|
||||
"current_response": argument,
|
||||
"count": investment_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"investment_debate_state": new_investment_debate_state}
|
||||
|
||||
return bear_node
|
||||
|
|
@ -1,61 +1,61 @@
|
|||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_bull_researcher(llm, memory):
|
||||
def bull_node(state) -> dict:
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
history = investment_debate_state.get("history", "")
|
||||
bull_history = investment_debate_state.get("bull_history", "")
|
||||
|
||||
current_response = investment_debate_state.get("current_response", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
|
||||
|
||||
Key points to focus on:
|
||||
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
|
||||
- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning.
|
||||
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence.
|
||||
- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing concerns thoroughly and showing why the bull perspective holds stronger merit.
|
||||
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points and debating effectively rather than just listing data.
|
||||
|
||||
Resources available:
|
||||
Market research report: {market_research_report}
|
||||
Social media sentiment report: {sentiment_report}
|
||||
Latest world affairs news: {news_report}
|
||||
Company fundamentals report: {fundamentals_report}
|
||||
Conversation history of the debate: {history}
|
||||
Last bear argument: {current_response}
|
||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
||||
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
|
||||
"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Bull Analyst: {response.content}"
|
||||
|
||||
new_investment_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"bull_history": bull_history + "\n" + argument,
|
||||
"bear_history": investment_debate_state.get("bear_history", ""),
|
||||
"current_response": argument,
|
||||
"count": investment_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"investment_debate_state": new_investment_debate_state}
|
||||
|
||||
return bull_node
|
||||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_bull_researcher(llm, memory):
|
||||
def bull_node(state) -> dict:
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
history = investment_debate_state.get("history", "")
|
||||
bull_history = investment_debate_state.get("bull_history", "")
|
||||
|
||||
current_response = investment_debate_state.get("current_response", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are a Bull Analyst advocating for investing in the stock. Your task is to build a strong, evidence-based case emphasizing growth potential, competitive advantages, and positive market indicators. Leverage the provided research and data to address concerns and counter bearish arguments effectively.
|
||||
|
||||
Key points to focus on:
|
||||
- Growth Potential: Highlight the company's market opportunities, revenue projections, and scalability.
|
||||
- Competitive Advantages: Emphasize factors like unique products, strong branding, or dominant market positioning.
|
||||
- Positive Indicators: Use financial health, industry trends, and recent positive news as evidence.
|
||||
- Bear Counterpoints: Critically analyze the bear argument with specific data and sound reasoning, addressing concerns thoroughly and showing why the bull perspective holds stronger merit.
|
||||
- Engagement: Present your argument in a conversational style, engaging directly with the bear analyst's points and debating effectively rather than just listing data.
|
||||
|
||||
Resources available:
|
||||
Market research report: {market_research_report}
|
||||
Social media sentiment report: {sentiment_report}
|
||||
Latest world affairs news: {news_report}
|
||||
Company fundamentals report: {fundamentals_report}
|
||||
Conversation history of the debate: {history}
|
||||
Last bear argument: {current_response}
|
||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
||||
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
|
||||
"""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Bull Analyst: {response.content}"
|
||||
|
||||
new_investment_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"bull_history": bull_history + "\n" + argument,
|
||||
"bear_history": investment_debate_state.get("bear_history", ""),
|
||||
"current_response": argument,
|
||||
"count": investment_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"investment_debate_state": new_investment_debate_state}
|
||||
|
||||
return bull_node
|
||||
|
|
@ -1,57 +1,57 @@
|
|||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_risky_debator(llm):
|
||||
def risky_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
risky_history = risk_debate_state.get("risky_history", "")
|
||||
|
||||
current_safe_response = risk_debate_state.get("current_safe_response", "")
|
||||
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
|
||||
|
||||
{trader_decision}
|
||||
|
||||
Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative and neutral stances to demonstrate why your high-reward perspective offers the best path forward. Incorporate insights from the following sources into your arguments:
|
||||
|
||||
Market Research Report: {market_research_report}
|
||||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||
|
||||
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Risky Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"risky_history": risky_history + "\n" + argument,
|
||||
"safe_history": risk_debate_state.get("safe_history", ""),
|
||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||
"latest_speaker": "Risky",
|
||||
"current_risky_response": argument,
|
||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||
"current_neutral_response": risk_debate_state.get(
|
||||
"current_neutral_response", ""
|
||||
),
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"risk_debate_state": new_risk_debate_state}
|
||||
|
||||
return risky_node
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_risky_debator(llm):
|
||||
def risky_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
risky_history = risk_debate_state.get("risky_history", "")
|
||||
|
||||
current_safe_response = risk_debate_state.get("current_safe_response", "")
|
||||
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Risky Risk Analyst, your role is to actively champion high-reward, high-risk opportunities, emphasizing bold strategies and competitive advantages. When evaluating the trader's decision or plan, focus intently on the potential upside, growth potential, and innovative benefits—even when these come with elevated risk. Use the provided market data and sentiment analysis to strengthen your arguments and challenge the opposing views. Specifically, respond directly to each point made by the conservative and neutral analysts, countering with data-driven rebuttals and persuasive reasoning. Highlight where their caution might miss critical opportunities or where their assumptions may be overly conservative. Here is the trader's decision:
|
||||
|
||||
{trader_decision}
|
||||
|
||||
Your task is to create a compelling case for the trader's decision by questioning and critiquing the conservative and neutral stances to demonstrate why your high-reward perspective offers the best path forward. Incorporate insights from the following sources into your arguments:
|
||||
|
||||
Market Research Report: {market_research_report}
|
||||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_safe_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||
|
||||
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Risky Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"risky_history": risky_history + "\n" + argument,
|
||||
"safe_history": risk_debate_state.get("safe_history", ""),
|
||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||
"latest_speaker": "Risky",
|
||||
"current_risky_response": argument,
|
||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||
"current_neutral_response": risk_debate_state.get(
|
||||
"current_neutral_response", ""
|
||||
),
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"risk_debate_state": new_risk_debate_state}
|
||||
|
||||
return risky_node
|
||||
|
|
@ -1,60 +1,60 @@
|
|||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_safe_debator(llm):
|
||||
def safe_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
safe_history = risk_debate_state.get("safe_history", "")
|
||||
|
||||
current_risky_response = risk_debate_state.get("current_risky_response", "")
|
||||
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Safe/Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
|
||||
|
||||
{trader_decision}
|
||||
|
||||
Your task is to actively counter the arguments of the Risky and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
|
||||
|
||||
Market Research Report: {market_research_report}
|
||||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||
|
||||
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Safe Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"risky_history": risk_debate_state.get("risky_history", ""),
|
||||
"safe_history": safe_history + "\n" + argument,
|
||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||
"latest_speaker": "Safe",
|
||||
"current_risky_response": risk_debate_state.get(
|
||||
"current_risky_response", ""
|
||||
),
|
||||
"current_safe_response": argument,
|
||||
"current_neutral_response": risk_debate_state.get(
|
||||
"current_neutral_response", ""
|
||||
),
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"risk_debate_state": new_risk_debate_state}
|
||||
|
||||
return safe_node
|
||||
from langchain_core.messages import AIMessage
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_safe_debator(llm):
|
||||
def safe_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
safe_history = risk_debate_state.get("safe_history", "")
|
||||
|
||||
current_risky_response = risk_debate_state.get("current_risky_response", "")
|
||||
current_neutral_response = risk_debate_state.get("current_neutral_response", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Safe/Conservative Risk Analyst, your primary objective is to protect assets, minimize volatility, and ensure steady, reliable growth. You prioritize stability, security, and risk mitigation, carefully assessing potential losses, economic downturns, and market volatility. When evaluating the trader's decision or plan, critically examine high-risk elements, pointing out where the decision may expose the firm to undue risk and where more cautious alternatives could secure long-term gains. Here is the trader's decision:
|
||||
|
||||
{trader_decision}
|
||||
|
||||
Your task is to actively counter the arguments of the Risky and Neutral Analysts, highlighting where their views may overlook potential threats or fail to prioritize sustainability. Respond directly to their points, drawing from the following data sources to build a convincing case for a low-risk approach adjustment to the trader's decision:
|
||||
|
||||
Market Research Report: {market_research_report}
|
||||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||
|
||||
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Safe Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"risky_history": risk_debate_state.get("risky_history", ""),
|
||||
"safe_history": safe_history + "\n" + argument,
|
||||
"neutral_history": risk_debate_state.get("neutral_history", ""),
|
||||
"latest_speaker": "Safe",
|
||||
"current_risky_response": risk_debate_state.get(
|
||||
"current_risky_response", ""
|
||||
),
|
||||
"current_safe_response": argument,
|
||||
"current_neutral_response": risk_debate_state.get(
|
||||
"current_neutral_response", ""
|
||||
),
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"risk_debate_state": new_risk_debate_state}
|
||||
|
||||
return safe_node
|
||||
|
|
@ -1,57 +1,57 @@
|
|||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_neutral_debator(llm):
|
||||
def neutral_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
neutral_history = risk_debate_state.get("neutral_history", "")
|
||||
|
||||
current_risky_response = risk_debate_state.get("current_risky_response", "")
|
||||
current_safe_response = risk_debate_state.get("current_safe_response", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies.Here is the trader's decision:
|
||||
|
||||
{trader_decision}
|
||||
|
||||
Your task is to challenge both the Risky and Safe Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
|
||||
|
||||
Market Research Report: {market_research_report}
|
||||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||
|
||||
Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Neutral Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"risky_history": risk_debate_state.get("risky_history", ""),
|
||||
"safe_history": risk_debate_state.get("safe_history", ""),
|
||||
"neutral_history": neutral_history + "\n" + argument,
|
||||
"latest_speaker": "Neutral",
|
||||
"current_risky_response": risk_debate_state.get(
|
||||
"current_risky_response", ""
|
||||
),
|
||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||
"current_neutral_response": argument,
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"risk_debate_state": new_risk_debate_state}
|
||||
|
||||
return neutral_node
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_neutral_debator(llm):
|
||||
def neutral_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
history = risk_debate_state.get("history", "")
|
||||
neutral_history = risk_debate_state.get("neutral_history", "")
|
||||
|
||||
current_risky_response = risk_debate_state.get("current_risky_response", "")
|
||||
current_safe_response = risk_debate_state.get("current_safe_response", "")
|
||||
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
trader_decision = state["trader_investment_plan"]
|
||||
|
||||
prompt = f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
As the Neutral Risk Analyst, your role is to provide a balanced perspective, weighing both the potential benefits and risks of the trader's decision or plan. You prioritize a well-rounded approach, evaluating the upsides and downsides while factoring in broader market trends, potential economic shifts, and diversification strategies.Here is the trader's decision:
|
||||
|
||||
{trader_decision}
|
||||
|
||||
Your task is to challenge both the Risky and Safe Analysts, pointing out where each perspective may be overly optimistic or overly cautious. Use insights from the following data sources to support a moderate, sustainable strategy to adjust the trader's decision:
|
||||
|
||||
Market Research Report: {market_research_report}
|
||||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
Here is the current conversation history: {history} Here is the last response from the risky analyst: {current_risky_response} Here is the last response from the safe analyst: {current_safe_response}. If there are no responses from the other viewpoints, do not halluncinate and just present your point.
|
||||
|
||||
Engage actively by analyzing both sides critically, addressing weaknesses in the risky and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
argument = f"Neutral Analyst: {response.content}"
|
||||
|
||||
new_risk_debate_state = {
|
||||
"history": history + "\n" + argument,
|
||||
"risky_history": risk_debate_state.get("risky_history", ""),
|
||||
"safe_history": risk_debate_state.get("safe_history", ""),
|
||||
"neutral_history": neutral_history + "\n" + argument,
|
||||
"latest_speaker": "Neutral",
|
||||
"current_risky_response": risk_debate_state.get(
|
||||
"current_risky_response", ""
|
||||
),
|
||||
"current_safe_response": risk_debate_state.get("current_safe_response", ""),
|
||||
"current_neutral_response": argument,
|
||||
"count": risk_debate_state["count"] + 1,
|
||||
}
|
||||
|
||||
return {"risk_debate_state": new_risk_debate_state}
|
||||
|
||||
return neutral_node
|
||||
|
|
@ -1,45 +1,45 @@
|
|||
import functools
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_trader(llm, memory):
|
||||
def trader_node(state, name):
|
||||
company_name = state["company_of_interest"]
|
||||
investment_plan = state["investment_plan"]
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
context = {
|
||||
"role": "user",
|
||||
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
|
||||
}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situatiosn you traded in and the lessons learned: {past_memory_str}""",
|
||||
},
|
||||
context,
|
||||
]
|
||||
|
||||
result = llm.invoke(messages)
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"trader_investment_plan": result.content,
|
||||
"sender": name,
|
||||
}
|
||||
|
||||
return functools.partial(trader_node, name="Trader")
|
||||
import functools
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
def create_trader(llm, memory):
|
||||
def trader_node(state, name):
|
||||
company_name = state["company_of_interest"]
|
||||
investment_plan = state["investment_plan"]
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
for i, rec in enumerate(past_memories, 1):
|
||||
past_memory_str += rec["recommendation"] + "\n\n"
|
||||
|
||||
context = {
|
||||
"role": "user",
|
||||
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
|
||||
}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"""**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are a trading agent analyzing market data to make investment decisions. Based on your analysis, provide a specific recommendation to buy, sell, or hold. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL**' to confirm your recommendation. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situatiosn you traded in and the lessons learned: {past_memory_str}""",
|
||||
},
|
||||
context,
|
||||
]
|
||||
|
||||
result = llm.invoke(messages)
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"trader_investment_plan": result.content,
|
||||
"sender": name,
|
||||
}
|
||||
|
||||
return functools.partial(trader_node, name="Trader")
|
||||
|
|
@ -1,76 +1,76 @@
|
|||
from typing import Annotated, Sequence
|
||||
from datetime import date, timedelta, datetime
|
||||
from typing_extensions import TypedDict, Optional
|
||||
from langchain_openai import ChatOpenAI
|
||||
from tradingagents.agents import *
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langgraph.graph import END, StateGraph, START, MessagesState
|
||||
|
||||
|
||||
# Researcher team state
|
||||
class InvestDebateState(TypedDict):
|
||||
bull_history: Annotated[
|
||||
str, "Bullish Conversation history"
|
||||
] # Bullish Conversation history
|
||||
bear_history: Annotated[
|
||||
str, "Bearish Conversation history"
|
||||
] # Bullish Conversation history
|
||||
history: Annotated[str, "Conversation history"] # Conversation history
|
||||
current_response: Annotated[str, "Latest response"] # Last response
|
||||
judge_decision: Annotated[str, "Final judge decision"] # Last response
|
||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
||||
|
||||
|
||||
# Risk management team state
|
||||
class RiskDebateState(TypedDict):
|
||||
risky_history: Annotated[
|
||||
str, "Risky Agent's Conversation history"
|
||||
] # Conversation history
|
||||
safe_history: Annotated[
|
||||
str, "Safe Agent's Conversation history"
|
||||
] # Conversation history
|
||||
neutral_history: Annotated[
|
||||
str, "Neutral Agent's Conversation history"
|
||||
] # Conversation history
|
||||
history: Annotated[str, "Conversation history"] # Conversation history
|
||||
latest_speaker: Annotated[str, "Analyst that spoke last"]
|
||||
current_risky_response: Annotated[
|
||||
str, "Latest response by the risky analyst"
|
||||
] # Last response
|
||||
current_safe_response: Annotated[
|
||||
str, "Latest response by the safe analyst"
|
||||
] # Last response
|
||||
current_neutral_response: Annotated[
|
||||
str, "Latest response by the neutral analyst"
|
||||
] # Last response
|
||||
judge_decision: Annotated[str, "Judge's decision"]
|
||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
||||
|
||||
|
||||
class AgentState(MessagesState):
|
||||
company_of_interest: Annotated[str, "Company that we are interested in trading"]
|
||||
trade_date: Annotated[str, "What date we are trading at"]
|
||||
|
||||
sender: Annotated[str, "Agent that sent this message"]
|
||||
|
||||
# research step
|
||||
market_report: Annotated[str, "Report from the Market Analyst"]
|
||||
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
|
||||
news_report: Annotated[
|
||||
str, "Report from the News Researcher of current world affairs"
|
||||
]
|
||||
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
|
||||
|
||||
# researcher team discussion step
|
||||
investment_debate_state: Annotated[
|
||||
InvestDebateState, "Current state of the debate on if to invest or not"
|
||||
]
|
||||
investment_plan: Annotated[str, "Plan generated by the Analyst"]
|
||||
|
||||
trader_investment_plan: Annotated[str, "Plan generated by the Trader"]
|
||||
|
||||
# risk management team discussion step
|
||||
risk_debate_state: Annotated[
|
||||
RiskDebateState, "Current state of the debate on evaluating risk"
|
||||
]
|
||||
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
||||
from typing import Annotated, Sequence
|
||||
from datetime import date, timedelta, datetime
|
||||
from typing_extensions import TypedDict, Optional
|
||||
from langchain_openai import ChatOpenAI
|
||||
from tradingagents.agents import *
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langgraph.graph import END, StateGraph, START, MessagesState
|
||||
|
||||
|
||||
# Researcher team state
|
||||
class InvestDebateState(TypedDict):
|
||||
bull_history: Annotated[
|
||||
str, "Bullish Conversation history"
|
||||
] # Bullish Conversation history
|
||||
bear_history: Annotated[
|
||||
str, "Bearish Conversation history"
|
||||
] # Bullish Conversation history
|
||||
history: Annotated[str, "Conversation history"] # Conversation history
|
||||
current_response: Annotated[str, "Latest response"] # Last response
|
||||
judge_decision: Annotated[str, "Final judge decision"] # Last response
|
||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
||||
|
||||
|
||||
# Risk management team state
|
||||
class RiskDebateState(TypedDict):
|
||||
risky_history: Annotated[
|
||||
str, "Risky Agent's Conversation history"
|
||||
] # Conversation history
|
||||
safe_history: Annotated[
|
||||
str, "Safe Agent's Conversation history"
|
||||
] # Conversation history
|
||||
neutral_history: Annotated[
|
||||
str, "Neutral Agent's Conversation history"
|
||||
] # Conversation history
|
||||
history: Annotated[str, "Conversation history"] # Conversation history
|
||||
latest_speaker: Annotated[str, "Analyst that spoke last"]
|
||||
current_risky_response: Annotated[
|
||||
str, "Latest response by the risky analyst"
|
||||
] # Last response
|
||||
current_safe_response: Annotated[
|
||||
str, "Latest response by the safe analyst"
|
||||
] # Last response
|
||||
current_neutral_response: Annotated[
|
||||
str, "Latest response by the neutral analyst"
|
||||
] # Last response
|
||||
judge_decision: Annotated[str, "Judge's decision"]
|
||||
count: Annotated[int, "Length of the current conversation"] # Conversation length
|
||||
|
||||
|
||||
class AgentState(MessagesState):
|
||||
company_of_interest: Annotated[str, "Company that we are interested in trading"]
|
||||
trade_date: Annotated[str, "What date we are trading at"]
|
||||
|
||||
sender: Annotated[str, "Agent that sent this message"]
|
||||
|
||||
# research step
|
||||
market_report: Annotated[str, "Report from the Market Analyst"]
|
||||
sentiment_report: Annotated[str, "Report from the Social Media Analyst"]
|
||||
news_report: Annotated[
|
||||
str, "Report from the News Researcher of current world affairs"
|
||||
]
|
||||
fundamentals_report: Annotated[str, "Report from the Fundamentals Researcher"]
|
||||
|
||||
# researcher team discussion step
|
||||
investment_debate_state: Annotated[
|
||||
InvestDebateState, "Current state of the debate on if to invest or not"
|
||||
]
|
||||
investment_plan: Annotated[str, "Plan generated by the Analyst"]
|
||||
|
||||
trader_investment_plan: Annotated[str, "Plan generated by the Trader"]
|
||||
|
||||
# risk management team discussion step
|
||||
risk_debate_state: Annotated[
|
||||
RiskDebateState, "Current state of the debate on evaluating risk"
|
||||
]
|
||||
final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"]
|
||||
|
|
@ -1,20 +1,20 @@
|
|||
from .embedding_providers import (
|
||||
EmbeddingProvider,
|
||||
OpenAIEmbeddingProvider,
|
||||
GeminiEmbeddingProvider,
|
||||
OllamaEmbeddingProvider
|
||||
)
|
||||
from typing import Any
|
||||
|
||||
class EmbeddingProviderFactory:
|
||||
@staticmethod
|
||||
def create_provider(config : dict[str, Any])->EmbeddingProvider:
|
||||
backend_url = config["backend_url"]
|
||||
|
||||
if "generativelanguage.googleapis.com" in backend_url:
|
||||
return GeminiEmbeddingProvider(backend_url)
|
||||
elif "localhost:11434" in backend_url:
|
||||
return OllamaEmbeddingProvider(backend_url)
|
||||
else:
|
||||
return OpenAIEmbeddingProvider(backend_url)
|
||||
from .embedding_providers import (
|
||||
EmbeddingProvider,
|
||||
OpenAIEmbeddingProvider,
|
||||
GeminiEmbeddingProvider,
|
||||
OllamaEmbeddingProvider
|
||||
)
|
||||
from typing import Any
|
||||
|
||||
class EmbeddingProviderFactory:
|
||||
@staticmethod
|
||||
def create_provider(config : dict[str, Any])->EmbeddingProvider:
|
||||
backend_url = config["backend_url"]
|
||||
|
||||
if "generativelanguage.googleapis.com" in backend_url:
|
||||
return GeminiEmbeddingProvider(backend_url)
|
||||
elif "localhost:11434" in backend_url:
|
||||
return OllamaEmbeddingProvider(backend_url)
|
||||
else:
|
||||
return OpenAIEmbeddingProvider(backend_url)
|
||||
|
||||
|
|
@ -1,66 +1,66 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from openai import OpenAI
|
||||
from google import genai
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
@abstractmethod
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_name(self)->str:
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, backend_url: str, embedding_model: str = "text-embedding-3-small"):
|
||||
self.client = OpenAI(base_url=backend_url)
|
||||
self._embedding_model = embedding_model
|
||||
|
||||
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
response = self.client.embeddings.create(
|
||||
model=self._embedding_model,
|
||||
input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
@property
|
||||
def model_name(self)->str:
|
||||
return self._embedding_model
|
||||
|
||||
|
||||
class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, backend_url: str, embedding_model: str = "gemini-embedding-exp-03-07"):
|
||||
self.client = genai.Client()
|
||||
self._embedding_model = embedding_model
|
||||
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
response = self.client.models.embed_content(
|
||||
model=self._embedding_model,
|
||||
contents=text
|
||||
)
|
||||
return response.embeddings[0].values
|
||||
|
||||
@property
|
||||
def model_name(self)->str:
|
||||
return self._embedding_model
|
||||
|
||||
class OllamaEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, backend_url: str, embedding_model: str = "nomic-embed-text"):
|
||||
self.client = OpenAI(base_url=backend_url)
|
||||
self._embedding_model = embedding_model
|
||||
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
response = self.client.embeddings.create(
|
||||
model=self._embedding_model,
|
||||
input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
@property
|
||||
def model_name(self)->str:
|
||||
return self._embedding_model
|
||||
from abc import ABC, abstractmethod
|
||||
from openai import OpenAI
|
||||
from google import genai
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
@abstractmethod
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_name(self)->str:
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, backend_url: str, embedding_model: str = "text-embedding-3-small"):
|
||||
self.client = OpenAI(base_url=backend_url)
|
||||
self._embedding_model = embedding_model
|
||||
|
||||
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
response = self.client.embeddings.create(
|
||||
model=self._embedding_model,
|
||||
input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
@property
|
||||
def model_name(self)->str:
|
||||
return self._embedding_model
|
||||
|
||||
|
||||
class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, backend_url: str, embedding_model: str = "gemini-embedding-exp-03-07"):
|
||||
self.client = genai.Client()
|
||||
self._embedding_model = embedding_model
|
||||
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
response = self.client.models.embed_content(
|
||||
model=self._embedding_model,
|
||||
contents=text
|
||||
)
|
||||
return response.embeddings[0].values
|
||||
|
||||
@property
|
||||
def model_name(self)->str:
|
||||
return self._embedding_model
|
||||
|
||||
class OllamaEmbeddingProvider(EmbeddingProvider):
|
||||
def __init__(self, backend_url: str, embedding_model: str = "nomic-embed-text"):
|
||||
self.client = OpenAI(base_url=backend_url)
|
||||
self._embedding_model = embedding_model
|
||||
|
||||
def get_embedding(self, text: str)->list[float]:
|
||||
response = self.client.embeddings.create(
|
||||
model=self._embedding_model,
|
||||
input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
@property
|
||||
def model_name(self)->str:
|
||||
return self._embedding_model
|
||||
|
||||
|
|
@ -1,112 +1,112 @@
|
|||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from openai import OpenAI
|
||||
import os
|
||||
from .embedding_provider_factory import EmbeddingProviderFactory
|
||||
from google import genai
|
||||
|
||||
class FinancialSituationMemory:
|
||||
def __init__(self, name, config):
|
||||
self.config = config
|
||||
self.backend_url = config["backend_url"]
|
||||
|
||||
self.embedding_provider = EmbeddingProviderFactory.create_provider(config)
|
||||
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||
|
||||
def get_embedding(self, text):
|
||||
"""Get embedding for a text using the appropriate API"""
|
||||
|
||||
return self.embedding_provider.get_embedding(text)
|
||||
|
||||
def add_situations(self, situations_and_advice):
|
||||
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
|
||||
|
||||
situations = []
|
||||
advice = []
|
||||
ids = []
|
||||
embeddings = []
|
||||
|
||||
offset = self.situation_collection.count()
|
||||
|
||||
for i, (situation, recommendation) in enumerate(situations_and_advice):
|
||||
situations.append(situation)
|
||||
advice.append(recommendation)
|
||||
ids.append(str(offset + i))
|
||||
embeddings.append(self.get_embedding(situation))
|
||||
|
||||
self.situation_collection.add(
|
||||
documents=situations,
|
||||
metadatas=[{"recommendation": rec} for rec in advice],
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
)
|
||||
|
||||
def get_memories(self, current_situation, n_matches=1):
|
||||
"""Find matching recommendations using embeddings"""
|
||||
query_embedding = self.get_embedding(current_situation)
|
||||
|
||||
results = self.situation_collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=n_matches,
|
||||
include=["metadatas", "documents", "distances"],
|
||||
)
|
||||
|
||||
matched_results = []
|
||||
for i in range(len(results["documents"][0])):
|
||||
matched_results.append(
|
||||
{
|
||||
"matched_situation": results["documents"][0][i],
|
||||
"recommendation": results["metadatas"][0][i]["recommendation"],
|
||||
"similarity_score": 1 - results["distances"][0][i],
|
||||
}
|
||||
)
|
||||
|
||||
return matched_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
matcher = FinancialSituationMemory()
|
||||
|
||||
# Example data
|
||||
example_data = [
|
||||
(
|
||||
"High inflation rate with rising interest rates and declining consumer spending",
|
||||
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
|
||||
),
|
||||
(
|
||||
"Tech sector showing high volatility with increasing institutional selling pressure",
|
||||
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
|
||||
),
|
||||
(
|
||||
"Strong dollar affecting emerging markets with increasing forex volatility",
|
||||
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
|
||||
),
|
||||
(
|
||||
"Market showing signs of sector rotation with rising yields",
|
||||
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
|
||||
),
|
||||
]
|
||||
|
||||
# Add the example situations and recommendations
|
||||
matcher.add_situations(example_data)
|
||||
|
||||
# Example query
|
||||
current_situation = """
|
||||
Market showing increased volatility in tech sector, with institutional investors
|
||||
reducing positions and rising interest rates affecting growth stock valuations
|
||||
"""
|
||||
|
||||
try:
|
||||
recommendations = matcher.get_memories(current_situation, n_matches=2)
|
||||
|
||||
for i, rec in enumerate(recommendations, 1):
|
||||
print(f"\nMatch {i}:")
|
||||
print(f"Similarity Score: {rec['similarity_score']:.2f}")
|
||||
print(f"Matched Situation: {rec['matched_situation']}")
|
||||
print(f"Recommendation: {rec['recommendation']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during recommendation: {str(e)}")
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from openai import OpenAI
|
||||
import os
|
||||
from .embedding_provider_factory import EmbeddingProviderFactory
|
||||
from google import genai
|
||||
|
||||
class FinancialSituationMemory:
|
||||
def __init__(self, name, config):
|
||||
self.config = config
|
||||
self.backend_url = config["backend_url"]
|
||||
|
||||
self.embedding_provider = EmbeddingProviderFactory.create_provider(config)
|
||||
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||
|
||||
def get_embedding(self, text):
|
||||
"""Get embedding for a text using the appropriate API"""
|
||||
|
||||
return self.embedding_provider.get_embedding(text)
|
||||
|
||||
def add_situations(self, situations_and_advice):
|
||||
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
|
||||
|
||||
situations = []
|
||||
advice = []
|
||||
ids = []
|
||||
embeddings = []
|
||||
|
||||
offset = self.situation_collection.count()
|
||||
|
||||
for i, (situation, recommendation) in enumerate(situations_and_advice):
|
||||
situations.append(situation)
|
||||
advice.append(recommendation)
|
||||
ids.append(str(offset + i))
|
||||
embeddings.append(self.get_embedding(situation))
|
||||
|
||||
self.situation_collection.add(
|
||||
documents=situations,
|
||||
metadatas=[{"recommendation": rec} for rec in advice],
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
)
|
||||
|
||||
def get_memories(self, current_situation, n_matches=1):
|
||||
"""Find matching recommendations using embeddings"""
|
||||
query_embedding = self.get_embedding(current_situation)
|
||||
|
||||
results = self.situation_collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=n_matches,
|
||||
include=["metadatas", "documents", "distances"],
|
||||
)
|
||||
|
||||
matched_results = []
|
||||
for i in range(len(results["documents"][0])):
|
||||
matched_results.append(
|
||||
{
|
||||
"matched_situation": results["documents"][0][i],
|
||||
"recommendation": results["metadatas"][0][i]["recommendation"],
|
||||
"similarity_score": 1 - results["distances"][0][i],
|
||||
}
|
||||
)
|
||||
|
||||
return matched_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
matcher = FinancialSituationMemory()
|
||||
|
||||
# Example data
|
||||
example_data = [
|
||||
(
|
||||
"High inflation rate with rising interest rates and declining consumer spending",
|
||||
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
|
||||
),
|
||||
(
|
||||
"Tech sector showing high volatility with increasing institutional selling pressure",
|
||||
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
|
||||
),
|
||||
(
|
||||
"Strong dollar affecting emerging markets with increasing forex volatility",
|
||||
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
|
||||
),
|
||||
(
|
||||
"Market showing signs of sector rotation with rising yields",
|
||||
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
|
||||
),
|
||||
]
|
||||
|
||||
# Add the example situations and recommendations
|
||||
matcher.add_situations(example_data)
|
||||
|
||||
# Example query
|
||||
current_situation = """
|
||||
Market showing increased volatility in tech sector, with institutional investors
|
||||
reducing positions and rising interest rates affecting growth stock valuations
|
||||
"""
|
||||
|
||||
try:
|
||||
recommendations = matcher.get_memories(current_situation, n_matches=2)
|
||||
|
||||
for i, rec in enumerate(recommendations, 1):
|
||||
print(f"\nMatch {i}:")
|
||||
print(f"Similarity Score: {rec['similarity_score']:.2f}")
|
||||
print(f"Matched Situation: {rec['matched_situation']}")
|
||||
print(f"Recommendation: {rec['recommendation']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during recommendation: {str(e)}")
|
||||
|
|
@ -1,46 +1,46 @@
|
|||
from .finnhub_utils import get_data_in_range
|
||||
from .googlenews_utils import getNewsData
|
||||
from .yfin_utils import YFinanceUtils
|
||||
from .reddit_utils import fetch_top_from_category
|
||||
from .stockstats_utils import StockstatsUtils
|
||||
from .yfin_utils import YFinanceUtils
|
||||
|
||||
from .interface import (
|
||||
# News and sentiment functions
|
||||
get_finnhub_news,
|
||||
get_finnhub_company_insider_sentiment,
|
||||
get_finnhub_company_insider_transactions,
|
||||
get_google_news,
|
||||
get_reddit_global_news,
|
||||
get_reddit_company_news,
|
||||
# Financial statements functions
|
||||
get_simfin_balance_sheet,
|
||||
get_simfin_cashflow,
|
||||
get_simfin_income_statements,
|
||||
# Technical analysis functions
|
||||
get_stock_stats_indicators_window,
|
||||
get_stockstats_indicator,
|
||||
# Market data functions
|
||||
get_YFin_data_window,
|
||||
get_YFin_data,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# News and sentiment functions
|
||||
"get_finnhub_news",
|
||||
"get_finnhub_company_insider_sentiment",
|
||||
"get_finnhub_company_insider_transactions",
|
||||
"get_google_news",
|
||||
"get_reddit_global_news",
|
||||
"get_reddit_company_news",
|
||||
# Financial statements functions
|
||||
"get_simfin_balance_sheet",
|
||||
"get_simfin_cashflow",
|
||||
"get_simfin_income_statements",
|
||||
# Technical analysis functions
|
||||
"get_stock_stats_indicators_window",
|
||||
"get_stockstats_indicator",
|
||||
# Market data functions
|
||||
"get_YFin_data_window",
|
||||
"get_YFin_data",
|
||||
]
|
||||
from .finnhub_utils import get_data_in_range
|
||||
from .googlenews_utils import getNewsData
|
||||
from .yfin_utils import YFinanceUtils
|
||||
from .reddit_utils import fetch_top_from_category
|
||||
from .stockstats_utils import StockstatsUtils
|
||||
from .yfin_utils import YFinanceUtils
|
||||
|
||||
from .interface import (
|
||||
# News and sentiment functions
|
||||
get_finnhub_news,
|
||||
get_finnhub_company_insider_sentiment,
|
||||
get_finnhub_company_insider_transactions,
|
||||
get_google_news,
|
||||
get_reddit_global_news,
|
||||
get_reddit_company_news,
|
||||
# Financial statements functions
|
||||
get_simfin_balance_sheet,
|
||||
get_simfin_cashflow,
|
||||
get_simfin_income_statements,
|
||||
# Technical analysis functions
|
||||
get_stock_stats_indicators_window,
|
||||
get_stockstats_indicator,
|
||||
# Market data functions
|
||||
get_YFin_data_window,
|
||||
get_YFin_data,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# News and sentiment functions
|
||||
"get_finnhub_news",
|
||||
"get_finnhub_company_insider_sentiment",
|
||||
"get_finnhub_company_insider_transactions",
|
||||
"get_google_news",
|
||||
"get_reddit_global_news",
|
||||
"get_reddit_company_news",
|
||||
# Financial statements functions
|
||||
"get_simfin_balance_sheet",
|
||||
"get_simfin_cashflow",
|
||||
"get_simfin_income_statements",
|
||||
# Technical analysis functions
|
||||
"get_stock_stats_indicators_window",
|
||||
"get_stockstats_indicator",
|
||||
# Market data functions
|
||||
"get_YFin_data_window",
|
||||
"get_YFin_data",
|
||||
]
|
||||
|
|
@ -1,34 +1,34 @@
|
|||
import tradingagents.default_config as default_config
|
||||
from typing import Dict, Optional
|
||||
|
||||
# Use default config but allow it to be overridden
|
||||
_config: Optional[Dict] = None
|
||||
DATA_DIR: str = ""
|
||||
|
||||
|
||||
def initialize_config():
|
||||
"""Initialize the configuration with default values."""
|
||||
global _config, DATA_DIR
|
||||
if _config is None:
|
||||
_config = default_config.DEFAULT_CONFIG.copy()
|
||||
DATA_DIR = _config["data_dir"]
|
||||
|
||||
|
||||
def set_config(config: Dict):
|
||||
"""Update the configuration with custom values."""
|
||||
global _config, DATA_DIR
|
||||
if _config is None:
|
||||
_config = default_config.DEFAULT_CONFIG.copy()
|
||||
_config.update(config)
|
||||
DATA_DIR = _config["data_dir"]
|
||||
|
||||
|
||||
def get_config() -> Dict:
|
||||
"""Get the current configuration."""
|
||||
if _config is None:
|
||||
initialize_config()
|
||||
return _config.copy()
|
||||
|
||||
|
||||
# Initialize with default config
|
||||
initialize_config()
|
||||
import tradingagents.default_config as default_config
|
||||
from typing import Dict, Optional
|
||||
|
||||
# Use default config but allow it to be overridden
|
||||
_config: Optional[Dict] = None
|
||||
DATA_DIR: str = ""
|
||||
|
||||
|
||||
def initialize_config():
|
||||
"""Initialize the configuration with default values."""
|
||||
global _config, DATA_DIR
|
||||
if _config is None:
|
||||
_config = default_config.DEFAULT_CONFIG.copy()
|
||||
DATA_DIR = _config["data_dir"]
|
||||
|
||||
|
||||
def set_config(config: Dict):
|
||||
"""Update the configuration with custom values."""
|
||||
global _config, DATA_DIR
|
||||
if _config is None:
|
||||
_config = default_config.DEFAULT_CONFIG.copy()
|
||||
_config.update(config)
|
||||
DATA_DIR = _config["data_dir"]
|
||||
|
||||
|
||||
def get_config() -> Dict:
|
||||
"""Get the current configuration."""
|
||||
if _config is None:
|
||||
initialize_config()
|
||||
return _config.copy()
|
||||
|
||||
|
||||
# Initialize with default config
|
||||
initialize_config()
|
||||
|
|
@ -1,36 +1,36 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
|
||||
def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=None):
|
||||
"""
|
||||
Gets finnhub data saved and processed on disk.
|
||||
Args:
|
||||
start_date (str): Start date in YYYY-MM-DD format.
|
||||
end_date (str): End date in YYYY-MM-DD format.
|
||||
data_type (str): Type of data from finnhub to fetch. Can be insider_trans, SEC_filings, news_data, insider_senti, or fin_as_reported.
|
||||
data_dir (str): Directory where the data is saved.
|
||||
period (str): Default to none, if there is a period specified, should be annual or quarterly.
|
||||
"""
|
||||
|
||||
if period:
|
||||
data_path = os.path.join(
|
||||
data_dir,
|
||||
"finnhub_data",
|
||||
data_type,
|
||||
f"{ticker}_{period}_data_formatted.json",
|
||||
)
|
||||
else:
|
||||
data_path = os.path.join(
|
||||
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json"
|
||||
)
|
||||
|
||||
data = open(data_path, "r")
|
||||
data = json.load(data)
|
||||
|
||||
# filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD)
|
||||
filtered_data = {}
|
||||
for key, value in data.items():
|
||||
if start_date <= key <= end_date and len(value) > 0:
|
||||
filtered_data[key] = value
|
||||
return filtered_data
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
def get_data_in_range(ticker, start_date, end_date, data_type, data_dir, period=None):
|
||||
"""
|
||||
Gets finnhub data saved and processed on disk.
|
||||
Args:
|
||||
start_date (str): Start date in YYYY-MM-DD format.
|
||||
end_date (str): End date in YYYY-MM-DD format.
|
||||
data_type (str): Type of data from finnhub to fetch. Can be insider_trans, SEC_filings, news_data, insider_senti, or fin_as_reported.
|
||||
data_dir (str): Directory where the data is saved.
|
||||
period (str): Default to none, if there is a period specified, should be annual or quarterly.
|
||||
"""
|
||||
|
||||
if period:
|
||||
data_path = os.path.join(
|
||||
data_dir,
|
||||
"finnhub_data",
|
||||
data_type,
|
||||
f"{ticker}_{period}_data_formatted.json",
|
||||
)
|
||||
else:
|
||||
data_path = os.path.join(
|
||||
data_dir, "finnhub_data", data_type, f"{ticker}_data_formatted.json"
|
||||
)
|
||||
|
||||
data = open(data_path, "r")
|
||||
data = json.load(data)
|
||||
|
||||
# filter keys (date, str in format YYYY-MM-DD) by the date range (str, str in format YYYY-MM-DD)
|
||||
filtered_data = {}
|
||||
for key, value in data.items():
|
||||
if start_date <= key <= end_date and len(value) > 0:
|
||||
filtered_data[key] = value
|
||||
return filtered_data
|
||||
|
|
@ -1,112 +1,112 @@
|
|||
import json
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from datetime import datetime
|
||||
import time
|
||||
import random
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
retry_if_result,
|
||||
)
|
||||
|
||||
|
||||
def is_rate_limited(response):
|
||||
"""Check if the response indicates rate limiting (status code 429)"""
|
||||
return response.status_code == 429
|
||||
|
||||
|
||||
@retry(
|
||||
retry=(retry_if_result(is_rate_limited)),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
stop=stop_after_attempt(5),
|
||||
)
|
||||
def make_request(url, headers):
|
||||
"""Make a request with retry logic for rate limiting"""
|
||||
# Random delay before each request to avoid detection
|
||||
time.sleep(random.uniform(2, 6))
|
||||
response = requests.get(url, headers=headers)
|
||||
return response
|
||||
|
||||
|
||||
def getNewsData(query, start_date, end_date):
|
||||
"""
|
||||
Scrape Google News search results for a given query and date range.
|
||||
query: str - search query
|
||||
start_date: str - start date in the format yyyy-mm-dd or mm/dd/yyyy
|
||||
end_date: str - end date in the format yyyy-mm-dd or mm/dd/yyyy
|
||||
"""
|
||||
if "-" in start_date:
|
||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
start_date = start_date.strftime("%m/%d/%Y")
|
||||
if "-" in end_date:
|
||||
end_date = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
end_date = end_date.strftime("%m/%d/%Y")
|
||||
|
||||
headers = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/101.0.4951.54 Safari/537.36"
|
||||
)
|
||||
}
|
||||
|
||||
news_results = []
|
||||
page = 0
|
||||
while True:
|
||||
offset = page * 10
|
||||
url = (
|
||||
f"https://www.google.com/search?q={query}"
|
||||
f"&tbs=cdr:1,cd_min:{start_date},cd_max:{end_date}"
|
||||
f"&tbm=nws&start={offset}"
|
||||
)
|
||||
|
||||
try:
|
||||
response = make_request(url, headers)
|
||||
soup = BeautifulSoup(response.content, "html.parser")
|
||||
results_on_page = soup.select("div.SoaBEf")
|
||||
|
||||
if not results_on_page:
|
||||
break # No more results found
|
||||
|
||||
for el in results_on_page:
|
||||
try:
|
||||
link = el.find("a")["href"]
|
||||
title_el = el.select_one("div.MBeuO")
|
||||
title = title_el.get_text() if title_el else ""
|
||||
snippet_el = el.select_one(".GI74Re")
|
||||
snippet = snippet_el.get_text() if snippet_el else ""
|
||||
date_el = el.select_one(".LfVVr")
|
||||
date = date_el.get_text() if date_el else ""
|
||||
source_el = el.select_one(".NUnG9d span")
|
||||
source = source_el.get_text() if source_el else ""
|
||||
news_results.append(
|
||||
{
|
||||
"link": link,
|
||||
"title": title,
|
||||
"snippet": snippet,
|
||||
"date": date,
|
||||
"source": source,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error processing result: {e}")
|
||||
# If one of the fields is not found, skip this result
|
||||
continue
|
||||
|
||||
# Update the progress bar with the current count of results scraped
|
||||
|
||||
# Check for the "Next" link (pagination)
|
||||
next_link = soup.find("a", id="pnnext")
|
||||
if not next_link:
|
||||
break
|
||||
|
||||
page += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed after multiple retries: {e}")
|
||||
break
|
||||
|
||||
return news_results
|
||||
import json
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from datetime import datetime
|
||||
import time
|
||||
import random
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
retry_if_result,
|
||||
)
|
||||
|
||||
|
||||
def is_rate_limited(response):
|
||||
"""Check if the response indicates rate limiting (status code 429)"""
|
||||
return response.status_code == 429
|
||||
|
||||
|
||||
@retry(
|
||||
retry=(retry_if_result(is_rate_limited)),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
stop=stop_after_attempt(5),
|
||||
)
|
||||
def make_request(url, headers):
|
||||
"""Make a request with retry logic for rate limiting"""
|
||||
# Random delay before each request to avoid detection
|
||||
time.sleep(random.uniform(2, 6))
|
||||
response = requests.get(url, headers=headers)
|
||||
return response
|
||||
|
||||
|
||||
def getNewsData(query, start_date, end_date):
|
||||
"""
|
||||
Scrape Google News search results for a given query and date range.
|
||||
query: str - search query
|
||||
start_date: str - start date in the format yyyy-mm-dd or mm/dd/yyyy
|
||||
end_date: str - end date in the format yyyy-mm-dd or mm/dd/yyyy
|
||||
"""
|
||||
if "-" in start_date:
|
||||
start_date = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
start_date = start_date.strftime("%m/%d/%Y")
|
||||
if "-" in end_date:
|
||||
end_date = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
end_date = end_date.strftime("%m/%d/%Y")
|
||||
|
||||
headers = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/101.0.4951.54 Safari/537.36"
|
||||
)
|
||||
}
|
||||
|
||||
news_results = []
|
||||
page = 0
|
||||
while True:
|
||||
offset = page * 10
|
||||
url = (
|
||||
f"https://www.google.com/search?q={query}"
|
||||
f"&tbs=cdr:1,cd_min:{start_date},cd_max:{end_date}"
|
||||
f"&tbm=nws&start={offset}"
|
||||
)
|
||||
|
||||
try:
|
||||
response = make_request(url, headers)
|
||||
soup = BeautifulSoup(response.content, "html.parser")
|
||||
results_on_page = soup.select("div.SoaBEf")
|
||||
|
||||
if not results_on_page:
|
||||
break # No more results found
|
||||
|
||||
for el in results_on_page:
|
||||
try:
|
||||
link = el.find("a")["href"]
|
||||
title_el = el.select_one("div.MBeuO")
|
||||
title = title_el.get_text() if title_el else ""
|
||||
snippet_el = el.select_one(".GI74Re")
|
||||
snippet = snippet_el.get_text() if snippet_el else ""
|
||||
date_el = el.select_one(".LfVVr")
|
||||
date = date_el.get_text() if date_el else ""
|
||||
source_el = el.select_one(".NUnG9d span")
|
||||
source = source_el.get_text() if source_el else ""
|
||||
news_results.append(
|
||||
{
|
||||
"link": link,
|
||||
"title": title,
|
||||
"snippet": snippet,
|
||||
"date": date,
|
||||
"source": source,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error processing result: {e}")
|
||||
# If one of the fields is not found, skip this result
|
||||
continue
|
||||
|
||||
# Update the progress bar with the current count of results scraped
|
||||
|
||||
# Check for the "Next" link (pagination)
|
||||
next_link = soup.find("a", id="pnnext")
|
||||
if not next_link:
|
||||
break
|
||||
|
||||
page += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed after multiple retries: {e}")
|
||||
break
|
||||
|
||||
return news_results
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,135 +1,135 @@
|
|||
import requests
|
||||
import time
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from contextlib import contextmanager
|
||||
from typing import Annotated
|
||||
import os
|
||||
import re
|
||||
|
||||
ticker_to_company = {
|
||||
"AAPL": "Apple",
|
||||
"MSFT": "Microsoft",
|
||||
"GOOGL": "Google",
|
||||
"AMZN": "Amazon",
|
||||
"TSLA": "Tesla",
|
||||
"NVDA": "Nvidia",
|
||||
"TSM": "Taiwan Semiconductor Manufacturing Company OR TSMC",
|
||||
"JPM": "JPMorgan Chase OR JP Morgan",
|
||||
"JNJ": "Johnson & Johnson OR JNJ",
|
||||
"V": "Visa",
|
||||
"WMT": "Walmart",
|
||||
"META": "Meta OR Facebook",
|
||||
"AMD": "AMD",
|
||||
"INTC": "Intel",
|
||||
"QCOM": "Qualcomm",
|
||||
"BABA": "Alibaba",
|
||||
"ADBE": "Adobe",
|
||||
"NFLX": "Netflix",
|
||||
"CRM": "Salesforce",
|
||||
"PYPL": "PayPal",
|
||||
"PLTR": "Palantir",
|
||||
"MU": "Micron",
|
||||
"SQ": "Block OR Square",
|
||||
"ZM": "Zoom",
|
||||
"CSCO": "Cisco",
|
||||
"SHOP": "Shopify",
|
||||
"ORCL": "Oracle",
|
||||
"X": "Twitter OR X",
|
||||
"SPOT": "Spotify",
|
||||
"AVGO": "Broadcom",
|
||||
"ASML": "ASML ",
|
||||
"TWLO": "Twilio",
|
||||
"SNAP": "Snap Inc.",
|
||||
"TEAM": "Atlassian",
|
||||
"SQSP": "Squarespace",
|
||||
"UBER": "Uber",
|
||||
"ROKU": "Roku",
|
||||
"PINS": "Pinterest",
|
||||
}
|
||||
|
||||
|
||||
def fetch_top_from_category(
|
||||
category: Annotated[
|
||||
str, "Category to fetch top post from. Collection of subreddits."
|
||||
],
|
||||
date: Annotated[str, "Date to fetch top posts from."],
|
||||
max_limit: Annotated[int, "Maximum number of posts to fetch."],
|
||||
query: Annotated[str, "Optional query to search for in the subreddit."] = None,
|
||||
data_path: Annotated[
|
||||
str,
|
||||
"Path to the data folder. Default is 'reddit_data'.",
|
||||
] = "reddit_data",
|
||||
):
|
||||
base_path = data_path
|
||||
|
||||
all_content = []
|
||||
|
||||
if max_limit < len(os.listdir(os.path.join(base_path, category))):
|
||||
raise ValueError(
|
||||
"REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts"
|
||||
)
|
||||
|
||||
limit_per_subreddit = max_limit // len(
|
||||
os.listdir(os.path.join(base_path, category))
|
||||
)
|
||||
|
||||
for data_file in os.listdir(os.path.join(base_path, category)):
|
||||
# check if data_file is a .jsonl file
|
||||
if not data_file.endswith(".jsonl"):
|
||||
continue
|
||||
|
||||
all_content_curr_subreddit = []
|
||||
|
||||
with open(os.path.join(base_path, category, data_file), "rb") as f:
|
||||
for i, line in enumerate(f):
|
||||
# skip empty lines
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
parsed_line = json.loads(line)
|
||||
|
||||
# select only lines that are from the date
|
||||
post_date = datetime.utcfromtimestamp(
|
||||
parsed_line["created_utc"]
|
||||
).strftime("%Y-%m-%d")
|
||||
if post_date != date:
|
||||
continue
|
||||
|
||||
# if is company_news, check that the title or the content has the company's name (query) mentioned
|
||||
if "company" in category and query:
|
||||
search_terms = []
|
||||
if "OR" in ticker_to_company[query]:
|
||||
search_terms = ticker_to_company[query].split(" OR ")
|
||||
else:
|
||||
search_terms = [ticker_to_company[query]]
|
||||
|
||||
search_terms.append(query)
|
||||
|
||||
found = False
|
||||
for term in search_terms:
|
||||
if re.search(
|
||||
term, parsed_line["title"], re.IGNORECASE
|
||||
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
continue
|
||||
|
||||
post = {
|
||||
"title": parsed_line["title"],
|
||||
"content": parsed_line["selftext"],
|
||||
"url": parsed_line["url"],
|
||||
"upvotes": parsed_line["ups"],
|
||||
"posted_date": post_date,
|
||||
}
|
||||
|
||||
all_content_curr_subreddit.append(post)
|
||||
|
||||
# sort all_content_curr_subreddit by upvote_ratio in descending order
|
||||
all_content_curr_subreddit.sort(key=lambda x: x["upvotes"], reverse=True)
|
||||
|
||||
all_content.extend(all_content_curr_subreddit[:limit_per_subreddit])
|
||||
|
||||
return all_content
|
||||
import requests
|
||||
import time
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from contextlib import contextmanager
|
||||
from typing import Annotated
|
||||
import os
|
||||
import re
|
||||
|
||||
ticker_to_company = {
|
||||
"AAPL": "Apple",
|
||||
"MSFT": "Microsoft",
|
||||
"GOOGL": "Google",
|
||||
"AMZN": "Amazon",
|
||||
"TSLA": "Tesla",
|
||||
"NVDA": "Nvidia",
|
||||
"TSM": "Taiwan Semiconductor Manufacturing Company OR TSMC",
|
||||
"JPM": "JPMorgan Chase OR JP Morgan",
|
||||
"JNJ": "Johnson & Johnson OR JNJ",
|
||||
"V": "Visa",
|
||||
"WMT": "Walmart",
|
||||
"META": "Meta OR Facebook",
|
||||
"AMD": "AMD",
|
||||
"INTC": "Intel",
|
||||
"QCOM": "Qualcomm",
|
||||
"BABA": "Alibaba",
|
||||
"ADBE": "Adobe",
|
||||
"NFLX": "Netflix",
|
||||
"CRM": "Salesforce",
|
||||
"PYPL": "PayPal",
|
||||
"PLTR": "Palantir",
|
||||
"MU": "Micron",
|
||||
"SQ": "Block OR Square",
|
||||
"ZM": "Zoom",
|
||||
"CSCO": "Cisco",
|
||||
"SHOP": "Shopify",
|
||||
"ORCL": "Oracle",
|
||||
"X": "Twitter OR X",
|
||||
"SPOT": "Spotify",
|
||||
"AVGO": "Broadcom",
|
||||
"ASML": "ASML ",
|
||||
"TWLO": "Twilio",
|
||||
"SNAP": "Snap Inc.",
|
||||
"TEAM": "Atlassian",
|
||||
"SQSP": "Squarespace",
|
||||
"UBER": "Uber",
|
||||
"ROKU": "Roku",
|
||||
"PINS": "Pinterest",
|
||||
}
|
||||
|
||||
|
||||
def fetch_top_from_category(
|
||||
category: Annotated[
|
||||
str, "Category to fetch top post from. Collection of subreddits."
|
||||
],
|
||||
date: Annotated[str, "Date to fetch top posts from."],
|
||||
max_limit: Annotated[int, "Maximum number of posts to fetch."],
|
||||
query: Annotated[str, "Optional query to search for in the subreddit."] = None,
|
||||
data_path: Annotated[
|
||||
str,
|
||||
"Path to the data folder. Default is 'reddit_data'.",
|
||||
] = "reddit_data",
|
||||
):
|
||||
base_path = data_path
|
||||
|
||||
all_content = []
|
||||
|
||||
if max_limit < len(os.listdir(os.path.join(base_path, category))):
|
||||
raise ValueError(
|
||||
"REDDIT FETCHING ERROR: max limit is less than the number of files in the category. Will not be able to fetch any posts"
|
||||
)
|
||||
|
||||
limit_per_subreddit = max_limit // len(
|
||||
os.listdir(os.path.join(base_path, category))
|
||||
)
|
||||
|
||||
for data_file in os.listdir(os.path.join(base_path, category)):
|
||||
# check if data_file is a .jsonl file
|
||||
if not data_file.endswith(".jsonl"):
|
||||
continue
|
||||
|
||||
all_content_curr_subreddit = []
|
||||
|
||||
with open(os.path.join(base_path, category, data_file), "rb") as f:
|
||||
for i, line in enumerate(f):
|
||||
# skip empty lines
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
parsed_line = json.loads(line)
|
||||
|
||||
# select only lines that are from the date
|
||||
post_date = datetime.utcfromtimestamp(
|
||||
parsed_line["created_utc"]
|
||||
).strftime("%Y-%m-%d")
|
||||
if post_date != date:
|
||||
continue
|
||||
|
||||
# if is company_news, check that the title or the content has the company's name (query) mentioned
|
||||
if "company" in category and query:
|
||||
search_terms = []
|
||||
if "OR" in ticker_to_company[query]:
|
||||
search_terms = ticker_to_company[query].split(" OR ")
|
||||
else:
|
||||
search_terms = [ticker_to_company[query]]
|
||||
|
||||
search_terms.append(query)
|
||||
|
||||
found = False
|
||||
for term in search_terms:
|
||||
if re.search(
|
||||
term, parsed_line["title"], re.IGNORECASE
|
||||
) or re.search(term, parsed_line["selftext"], re.IGNORECASE):
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
continue
|
||||
|
||||
post = {
|
||||
"title": parsed_line["title"],
|
||||
"content": parsed_line["selftext"],
|
||||
"url": parsed_line["url"],
|
||||
"upvotes": parsed_line["ups"],
|
||||
"posted_date": post_date,
|
||||
}
|
||||
|
||||
all_content_curr_subreddit.append(post)
|
||||
|
||||
# sort all_content_curr_subreddit by upvote_ratio in descending order
|
||||
all_content_curr_subreddit.sort(key=lambda x: x["upvotes"], reverse=True)
|
||||
|
||||
all_content.extend(all_content_curr_subreddit[:limit_per_subreddit])
|
||||
|
||||
return all_content
|
||||
|
|
@ -1,76 +1,76 @@
|
|||
from google import genai
|
||||
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
|
||||
from openai import OpenAI
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
|
||||
class SearchProvider(ABC):
|
||||
@abstractmethod
|
||||
def search(self, query: str) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class GoogleSearchProvider(SearchProvider):
|
||||
def __init__(self, model: str):
|
||||
self.client = genai.Client()
|
||||
self.model = model
|
||||
|
||||
def search(self, query: str) -> str:
|
||||
google_search_tool = Tool(
|
||||
google_search=GoogleSearch()
|
||||
)
|
||||
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=query,
|
||||
config=GenerateContentConfig(
|
||||
tools=[google_search_tool],
|
||||
response_modalities=["TEXT"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
result_text = ""
|
||||
for part in response.candidates[0].content.parts:
|
||||
if hasattr(part, 'text'):
|
||||
result_text += part.text
|
||||
|
||||
return result_text
|
||||
|
||||
|
||||
class OpenAISearchProvider(SearchProvider):
|
||||
def __init__(self, model: str, backend_url: str):
|
||||
self.client = OpenAI(base_url=backend_url)
|
||||
self.model = model
|
||||
|
||||
def search(self, query: str) -> str:
|
||||
response = self.client.responses.create(
|
||||
model=self.model,
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": query
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
text={"format": {"type": "text"}},
|
||||
reasoning={},
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"user_location": {"type": "approximate"},
|
||||
"search_context_size": "low",
|
||||
}
|
||||
],
|
||||
temperature=1,
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
)
|
||||
|
||||
from google import genai
|
||||
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
|
||||
from openai import OpenAI
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
|
||||
class SearchProvider(ABC):
|
||||
@abstractmethod
|
||||
def search(self, query: str) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class GoogleSearchProvider(SearchProvider):
|
||||
def __init__(self, model: str):
|
||||
self.client = genai.Client()
|
||||
self.model = model
|
||||
|
||||
def search(self, query: str) -> str:
|
||||
google_search_tool = Tool(
|
||||
google_search=GoogleSearch()
|
||||
)
|
||||
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=query,
|
||||
config=GenerateContentConfig(
|
||||
tools=[google_search_tool],
|
||||
response_modalities=["TEXT"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
result_text = ""
|
||||
for part in response.candidates[0].content.parts:
|
||||
if hasattr(part, 'text'):
|
||||
result_text += part.text
|
||||
|
||||
return result_text
|
||||
|
||||
|
||||
class OpenAISearchProvider(SearchProvider):
|
||||
def __init__(self, model: str, backend_url: str):
|
||||
self.client = OpenAI(base_url=backend_url)
|
||||
self.model = model
|
||||
|
||||
def search(self, query: str) -> str:
|
||||
response = self.client.responses.create(
|
||||
model=self.model,
|
||||
input=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_text",
|
||||
"text": query
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
text={"format": {"type": "text"}},
|
||||
reasoning={},
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search_preview",
|
||||
"user_location": {"type": "approximate"},
|
||||
"search_context_size": "low",
|
||||
}
|
||||
],
|
||||
temperature=1,
|
||||
max_output_tokens=4096,
|
||||
top_p=1,
|
||||
store=True,
|
||||
)
|
||||
|
||||
return response.output[1].content[0].text
|
||||
|
|
@ -1,133 +1,133 @@
|
|||
from .search_provider import SearchProvider
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Dict, Callable, Any
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ProviderSelector(ABC):
|
||||
"""Abstract base class for provider selection strategies."""
|
||||
|
||||
@abstractmethod
|
||||
def select_provider_type(self, config: Dict[str, Any]) -> str:
|
||||
"""Select provider type based on configuration."""
|
||||
pass
|
||||
|
||||
|
||||
class MappingBasedProviderSelector(ProviderSelector):
|
||||
"""Selects provider based on URL pattern mapping table."""
|
||||
|
||||
def __init__(self, mappings: Dict[str, str], default_provider: str = "openai"):
|
||||
self._mappings = mappings
|
||||
self._default_provider = default_provider
|
||||
|
||||
def select_provider_type(self, config: Dict[str, Any]) -> str:
|
||||
backend_url = config.get("backend_url", "")
|
||||
for pattern, provider_type in self._mappings.items():
|
||||
if pattern in backend_url:
|
||||
return provider_type
|
||||
return self._default_provider
|
||||
|
||||
|
||||
class SearchProviderRegistry:
|
||||
"""Registry for search provider creation functions."""
|
||||
|
||||
def __init__(self):
|
||||
self._providers: Dict[str, Callable[[Dict[str, Any]], SearchProvider]] = {}
|
||||
|
||||
def register(self, provider_type: str, creator: Callable[[Dict[str, Any]], SearchProvider]):
|
||||
"""Register a provider creator function."""
|
||||
self._providers[provider_type] = creator
|
||||
|
||||
def create(self, provider_type: str, config: Dict[str, Any]) -> SearchProvider:
|
||||
"""Create a provider instance using registered creator."""
|
||||
if provider_type not in self._providers:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
return self._providers[provider_type](config)
|
||||
|
||||
def get_available_types(self) -> list[str]:
|
||||
"""Get list of available provider types."""
|
||||
return list(self._providers.keys())
|
||||
|
||||
|
||||
class SearchProviderFactoryImpl:
|
||||
"""Enhanced factory for creating SearchProvider instances with caching and extensibility."""
|
||||
|
||||
def __init__(self, registry: SearchProviderRegistry, selector: ProviderSelector):
|
||||
self._registry = registry
|
||||
self._selector = selector
|
||||
self._cache: Dict[str, SearchProvider] = {}
|
||||
|
||||
def create_provider(self, config: Dict[str, Any]) -> SearchProvider:
|
||||
"""
|
||||
Create a SearchProvider with caching to avoid creating new instances.
|
||||
Uses config hash as cache key for efficient reuse.
|
||||
"""
|
||||
# Create cache key from relevant config values
|
||||
cache_key_data = {
|
||||
"backend_url": config.get("backend_url", ""),
|
||||
"model": config.get("quick_think_llm", "")
|
||||
}
|
||||
cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest()
|
||||
|
||||
# Return cached instance if exists
|
||||
if cache_key in self._cache:
|
||||
return self._cache[cache_key]
|
||||
|
||||
# Select and create provider
|
||||
provider_type = self._selector.select_provider_type(config)
|
||||
provider = self._registry.create(provider_type, config)
|
||||
|
||||
# Cache and return
|
||||
self._cache[cache_key] = provider
|
||||
return provider
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear the provider cache (useful for testing or config changes)."""
|
||||
self._cache.clear()
|
||||
|
||||
def get_available_provider_types(self) -> list[str]:
|
||||
"""Get list of available provider types."""
|
||||
return self._registry.get_available_types()
|
||||
|
||||
|
||||
def create_search_provider_factory() -> SearchProviderFactoryImpl:
|
||||
"""Create a configured SearchProviderFactory with default providers."""
|
||||
registry = SearchProviderRegistry()
|
||||
|
||||
# Register default providers
|
||||
def create_google_provider(config: Dict[str, Any]) -> SearchProvider:
|
||||
from .search_provider import GoogleSearchProvider
|
||||
return GoogleSearchProvider(config["quick_think_llm"])
|
||||
|
||||
def create_openai_provider(config: Dict[str, Any]) -> SearchProvider:
|
||||
from .search_provider import OpenAISearchProvider
|
||||
return OpenAISearchProvider(config["quick_think_llm"], config["backend_url"])
|
||||
|
||||
registry.register("google", create_google_provider)
|
||||
registry.register("openai", create_openai_provider)
|
||||
|
||||
# Create URL pattern mappings (easily extensible)
|
||||
url_mappings = {
|
||||
"generativelanguage.googleapis.com": "google",
|
||||
"api.openai.com": "openai",
|
||||
}
|
||||
|
||||
selector = MappingBasedProviderSelector(url_mappings, default_provider="openai")
|
||||
return SearchProviderFactoryImpl(registry, selector)
|
||||
|
||||
|
||||
# Backward compatibility - singleton instance
|
||||
_default_factory = create_search_provider_factory()
|
||||
|
||||
|
||||
class SearchProviderFactory:
|
||||
"""Backward compatibility wrapper for the old static factory."""
|
||||
|
||||
@staticmethod
|
||||
def create_provider(config: Dict[str, Any]) -> SearchProvider:
|
||||
return _default_factory.create_provider(config)
|
||||
|
||||
@staticmethod
|
||||
def clear_cache():
|
||||
from .search_provider import SearchProvider
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Dict, Callable, Any
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ProviderSelector(ABC):
|
||||
"""Abstract base class for provider selection strategies."""
|
||||
|
||||
@abstractmethod
|
||||
def select_provider_type(self, config: Dict[str, Any]) -> str:
|
||||
"""Select provider type based on configuration."""
|
||||
pass
|
||||
|
||||
|
||||
class MappingBasedProviderSelector(ProviderSelector):
|
||||
"""Selects provider based on URL pattern mapping table."""
|
||||
|
||||
def __init__(self, mappings: Dict[str, str], default_provider: str = "openai"):
|
||||
self._mappings = mappings
|
||||
self._default_provider = default_provider
|
||||
|
||||
def select_provider_type(self, config: Dict[str, Any]) -> str:
|
||||
backend_url = config.get("backend_url", "")
|
||||
for pattern, provider_type in self._mappings.items():
|
||||
if pattern in backend_url:
|
||||
return provider_type
|
||||
return self._default_provider
|
||||
|
||||
|
||||
class SearchProviderRegistry:
|
||||
"""Registry for search provider creation functions."""
|
||||
|
||||
def __init__(self):
|
||||
self._providers: Dict[str, Callable[[Dict[str, Any]], SearchProvider]] = {}
|
||||
|
||||
def register(self, provider_type: str, creator: Callable[[Dict[str, Any]], SearchProvider]):
|
||||
"""Register a provider creator function."""
|
||||
self._providers[provider_type] = creator
|
||||
|
||||
def create(self, provider_type: str, config: Dict[str, Any]) -> SearchProvider:
|
||||
"""Create a provider instance using registered creator."""
|
||||
if provider_type not in self._providers:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
return self._providers[provider_type](config)
|
||||
|
||||
def get_available_types(self) -> list[str]:
|
||||
"""Get list of available provider types."""
|
||||
return list(self._providers.keys())
|
||||
|
||||
|
||||
class SearchProviderFactoryImpl:
|
||||
"""Enhanced factory for creating SearchProvider instances with caching and extensibility."""
|
||||
|
||||
def __init__(self, registry: SearchProviderRegistry, selector: ProviderSelector):
|
||||
self._registry = registry
|
||||
self._selector = selector
|
||||
self._cache: Dict[str, SearchProvider] = {}
|
||||
|
||||
def create_provider(self, config: Dict[str, Any]) -> SearchProvider:
|
||||
"""
|
||||
Create a SearchProvider with caching to avoid creating new instances.
|
||||
Uses config hash as cache key for efficient reuse.
|
||||
"""
|
||||
# Create cache key from relevant config values
|
||||
cache_key_data = {
|
||||
"backend_url": config.get("backend_url", ""),
|
||||
"model": config.get("quick_think_llm", "")
|
||||
}
|
||||
cache_key = hashlib.md5(json.dumps(cache_key_data, sort_keys=True).encode()).hexdigest()
|
||||
|
||||
# Return cached instance if exists
|
||||
if cache_key in self._cache:
|
||||
return self._cache[cache_key]
|
||||
|
||||
# Select and create provider
|
||||
provider_type = self._selector.select_provider_type(config)
|
||||
provider = self._registry.create(provider_type, config)
|
||||
|
||||
# Cache and return
|
||||
self._cache[cache_key] = provider
|
||||
return provider
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear the provider cache (useful for testing or config changes)."""
|
||||
self._cache.clear()
|
||||
|
||||
def get_available_provider_types(self) -> list[str]:
|
||||
"""Get list of available provider types."""
|
||||
return self._registry.get_available_types()
|
||||
|
||||
|
||||
def create_search_provider_factory() -> SearchProviderFactoryImpl:
|
||||
"""Create a configured SearchProviderFactory with default providers."""
|
||||
registry = SearchProviderRegistry()
|
||||
|
||||
# Register default providers
|
||||
def create_google_provider(config: Dict[str, Any]) -> SearchProvider:
|
||||
from .search_provider import GoogleSearchProvider
|
||||
return GoogleSearchProvider(config["quick_think_llm"])
|
||||
|
||||
def create_openai_provider(config: Dict[str, Any]) -> SearchProvider:
|
||||
from .search_provider import OpenAISearchProvider
|
||||
return OpenAISearchProvider(config["quick_think_llm"], config["backend_url"])
|
||||
|
||||
registry.register("google", create_google_provider)
|
||||
registry.register("openai", create_openai_provider)
|
||||
|
||||
# Create URL pattern mappings (easily extensible)
|
||||
url_mappings = {
|
||||
"generativelanguage.googleapis.com": "google",
|
||||
"api.openai.com": "openai",
|
||||
}
|
||||
|
||||
selector = MappingBasedProviderSelector(url_mappings, default_provider="openai")
|
||||
return SearchProviderFactoryImpl(registry, selector)
|
||||
|
||||
|
||||
# Backward compatibility - singleton instance
|
||||
_default_factory = create_search_provider_factory()
|
||||
|
||||
|
||||
class SearchProviderFactory:
|
||||
"""Backward compatibility wrapper for the old static factory."""
|
||||
|
||||
@staticmethod
|
||||
def create_provider(config: Dict[str, Any]) -> SearchProvider:
|
||||
return _default_factory.create_provider(config)
|
||||
|
||||
@staticmethod
|
||||
def clear_cache():
|
||||
_default_factory.clear_cache()
|
||||
|
|
@ -1,87 +1,87 @@
|
|||
import pandas as pd
|
||||
import yfinance as yf
|
||||
from stockstats import wrap
|
||||
from typing import Annotated
|
||||
import os
|
||||
from .config import get_config
|
||||
|
||||
|
||||
class StockstatsUtils:
|
||||
@staticmethod
|
||||
def get_stock_stats(
|
||||
symbol: Annotated[str, "ticker symbol for the company"],
|
||||
indicator: Annotated[
|
||||
str, "quantitative indicators based off of the stock data for the company"
|
||||
],
|
||||
curr_date: Annotated[
|
||||
str, "curr date for retrieving stock price data, YYYY-mm-dd"
|
||||
],
|
||||
data_dir: Annotated[
|
||||
str,
|
||||
"directory where the stock data is stored.",
|
||||
],
|
||||
online: Annotated[
|
||||
bool,
|
||||
"whether to use online tools to fetch data or offline tools. If True, will use online tools.",
|
||||
] = False,
|
||||
):
|
||||
df = None
|
||||
data = None
|
||||
|
||||
if not online:
|
||||
try:
|
||||
data = pd.read_csv(
|
||||
os.path.join(
|
||||
data_dir,
|
||||
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
||||
)
|
||||
)
|
||||
df = wrap(data)
|
||||
except FileNotFoundError:
|
||||
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
|
||||
else:
|
||||
# Get today's date as YYYY-mm-dd to add to cache
|
||||
today_date = pd.Timestamp.today()
|
||||
curr_date = pd.to_datetime(curr_date)
|
||||
|
||||
end_date = today_date
|
||||
start_date = today_date - pd.DateOffset(years=15)
|
||||
start_date = start_date.strftime("%Y-%m-%d")
|
||||
end_date = end_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Get config and ensure cache directory exists
|
||||
config = get_config()
|
||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
||||
|
||||
data_file = os.path.join(
|
||||
config["data_cache_dir"],
|
||||
f"{symbol}-YFin-data-{start_date}-{end_date}.csv",
|
||||
)
|
||||
|
||||
if os.path.exists(data_file):
|
||||
data = pd.read_csv(data_file)
|
||||
data["Date"] = pd.to_datetime(data["Date"])
|
||||
else:
|
||||
data = yf.download(
|
||||
symbol,
|
||||
start=start_date,
|
||||
end=end_date,
|
||||
multi_level_index=False,
|
||||
progress=False,
|
||||
auto_adjust=True,
|
||||
)
|
||||
data = data.reset_index()
|
||||
data.to_csv(data_file, index=False)
|
||||
|
||||
df = wrap(data)
|
||||
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
||||
curr_date = curr_date.strftime("%Y-%m-%d")
|
||||
|
||||
df[indicator] # trigger stockstats to calculate the indicator
|
||||
matching_rows = df[df["Date"].str.startswith(curr_date)]
|
||||
|
||||
if not matching_rows.empty:
|
||||
indicator_value = matching_rows[indicator].values[0]
|
||||
return indicator_value
|
||||
else:
|
||||
return "N/A: Not a trading day (weekend or holiday)"
|
||||
import pandas as pd
|
||||
import yfinance as yf
|
||||
from stockstats import wrap
|
||||
from typing import Annotated
|
||||
import os
|
||||
from .config import get_config
|
||||
|
||||
|
||||
class StockstatsUtils:
|
||||
@staticmethod
|
||||
def get_stock_stats(
|
||||
symbol: Annotated[str, "ticker symbol for the company"],
|
||||
indicator: Annotated[
|
||||
str, "quantitative indicators based off of the stock data for the company"
|
||||
],
|
||||
curr_date: Annotated[
|
||||
str, "curr date for retrieving stock price data, YYYY-mm-dd"
|
||||
],
|
||||
data_dir: Annotated[
|
||||
str,
|
||||
"directory where the stock data is stored.",
|
||||
],
|
||||
online: Annotated[
|
||||
bool,
|
||||
"whether to use online tools to fetch data or offline tools. If True, will use online tools.",
|
||||
] = False,
|
||||
):
|
||||
df = None
|
||||
data = None
|
||||
|
||||
if not online:
|
||||
try:
|
||||
data = pd.read_csv(
|
||||
os.path.join(
|
||||
data_dir,
|
||||
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
||||
)
|
||||
)
|
||||
df = wrap(data)
|
||||
except FileNotFoundError:
|
||||
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
|
||||
else:
|
||||
# Get today's date as YYYY-mm-dd to add to cache
|
||||
today_date = pd.Timestamp.today()
|
||||
curr_date = pd.to_datetime(curr_date)
|
||||
|
||||
end_date = today_date
|
||||
start_date = today_date - pd.DateOffset(years=15)
|
||||
start_date = start_date.strftime("%Y-%m-%d")
|
||||
end_date = end_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Get config and ensure cache directory exists
|
||||
config = get_config()
|
||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
||||
|
||||
data_file = os.path.join(
|
||||
config["data_cache_dir"],
|
||||
f"{symbol}-YFin-data-{start_date}-{end_date}.csv",
|
||||
)
|
||||
|
||||
if os.path.exists(data_file):
|
||||
data = pd.read_csv(data_file)
|
||||
data["Date"] = pd.to_datetime(data["Date"])
|
||||
else:
|
||||
data = yf.download(
|
||||
symbol,
|
||||
start=start_date,
|
||||
end=end_date,
|
||||
multi_level_index=False,
|
||||
progress=False,
|
||||
auto_adjust=True,
|
||||
)
|
||||
data = data.reset_index()
|
||||
data.to_csv(data_file, index=False)
|
||||
|
||||
df = wrap(data)
|
||||
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
||||
curr_date = curr_date.strftime("%Y-%m-%d")
|
||||
|
||||
df[indicator] # trigger stockstats to calculate the indicator
|
||||
matching_rows = df[df["Date"].str.startswith(curr_date)]
|
||||
|
||||
if not matching_rows.empty:
|
||||
indicator_value = matching_rows[indicator].values[0]
|
||||
return indicator_value
|
||||
else:
|
||||
return "N/A: Not a trading day (weekend or holiday)"
|
||||
|
|
@ -1,39 +1,39 @@
|
|||
import os
|
||||
import json
|
||||
import pandas as pd
|
||||
from datetime import date, timedelta, datetime
|
||||
from typing import Annotated
|
||||
|
||||
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
||||
|
||||
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
|
||||
if save_path:
|
||||
data.to_csv(save_path)
|
||||
print(f"{tag} saved to {save_path}")
|
||||
|
||||
|
||||
def get_current_date():
|
||||
return date.today().strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
def decorate_all_methods(decorator):
|
||||
def class_decorator(cls):
|
||||
for attr_name, attr_value in cls.__dict__.items():
|
||||
if callable(attr_value):
|
||||
setattr(cls, attr_name, decorator(attr_value))
|
||||
return cls
|
||||
|
||||
return class_decorator
|
||||
|
||||
|
||||
def get_next_weekday(date):
|
||||
|
||||
if not isinstance(date, datetime):
|
||||
date = datetime.strptime(date, "%Y-%m-%d")
|
||||
|
||||
if date.weekday() >= 5:
|
||||
days_to_add = 7 - date.weekday()
|
||||
next_weekday = date + timedelta(days=days_to_add)
|
||||
return next_weekday
|
||||
else:
|
||||
return date
|
||||
import os
|
||||
import json
|
||||
import pandas as pd
|
||||
from datetime import date, timedelta, datetime
|
||||
from typing import Annotated
|
||||
|
||||
SavePathType = Annotated[str, "File path to save data. If None, data is not saved."]
|
||||
|
||||
def save_output(data: pd.DataFrame, tag: str, save_path: SavePathType = None) -> None:
|
||||
if save_path:
|
||||
data.to_csv(save_path)
|
||||
print(f"{tag} saved to {save_path}")
|
||||
|
||||
|
||||
def get_current_date():
|
||||
return date.today().strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
def decorate_all_methods(decorator):
|
||||
def class_decorator(cls):
|
||||
for attr_name, attr_value in cls.__dict__.items():
|
||||
if callable(attr_value):
|
||||
setattr(cls, attr_name, decorator(attr_value))
|
||||
return cls
|
||||
|
||||
return class_decorator
|
||||
|
||||
|
||||
def get_next_weekday(date):
|
||||
|
||||
if not isinstance(date, datetime):
|
||||
date = datetime.strptime(date, "%Y-%m-%d")
|
||||
|
||||
if date.weekday() >= 5:
|
||||
days_to_add = 7 - date.weekday()
|
||||
next_weekday = date + timedelta(days=days_to_add)
|
||||
return next_weekday
|
||||
else:
|
||||
return date
|
||||
|
|
@ -1,117 +1,117 @@
|
|||
# gets data/stats
|
||||
|
||||
import yfinance as yf
|
||||
from typing import Annotated, Callable, Any, Optional
|
||||
from pandas import DataFrame
|
||||
import pandas as pd
|
||||
from functools import wraps
|
||||
|
||||
from .utils import save_output, SavePathType, decorate_all_methods
|
||||
|
||||
|
||||
def init_ticker(func: Callable) -> Callable:
|
||||
"""Decorator to initialize yf.Ticker and pass it to the function."""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(symbol: Annotated[str, "ticker symbol"], *args, **kwargs) -> Any:
|
||||
ticker = yf.Ticker(symbol)
|
||||
return func(ticker, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@decorate_all_methods(init_ticker)
|
||||
class YFinanceUtils:
|
||||
|
||||
def get_stock_data(
|
||||
symbol: Annotated[str, "ticker symbol"],
|
||||
start_date: Annotated[
|
||||
str, "start date for retrieving stock price data, YYYY-mm-dd"
|
||||
],
|
||||
end_date: Annotated[
|
||||
str, "end date for retrieving stock price data, YYYY-mm-dd"
|
||||
],
|
||||
save_path: SavePathType = None,
|
||||
) -> DataFrame:
|
||||
"""retrieve stock price data for designated ticker symbol"""
|
||||
ticker = symbol
|
||||
# add one day to the end_date so that the data range is inclusive
|
||||
end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1)
|
||||
end_date = end_date.strftime("%Y-%m-%d")
|
||||
stock_data = ticker.history(start=start_date, end=end_date)
|
||||
# save_output(stock_data, f"Stock data for {ticker.ticker}", save_path)
|
||||
return stock_data
|
||||
|
||||
def get_stock_info(
|
||||
symbol: Annotated[str, "ticker symbol"],
|
||||
) -> dict:
|
||||
"""Fetches and returns latest stock information."""
|
||||
ticker = symbol
|
||||
stock_info = ticker.info
|
||||
return stock_info
|
||||
|
||||
def get_company_info(
|
||||
symbol: Annotated[str, "ticker symbol"],
|
||||
save_path: Optional[str] = None,
|
||||
) -> DataFrame:
|
||||
"""Fetches and returns company information as a DataFrame."""
|
||||
ticker = symbol
|
||||
info = ticker.info
|
||||
company_info = {
|
||||
"Company Name": info.get("shortName", "N/A"),
|
||||
"Industry": info.get("industry", "N/A"),
|
||||
"Sector": info.get("sector", "N/A"),
|
||||
"Country": info.get("country", "N/A"),
|
||||
"Website": info.get("website", "N/A"),
|
||||
}
|
||||
company_info_df = DataFrame([company_info])
|
||||
if save_path:
|
||||
company_info_df.to_csv(save_path)
|
||||
print(f"Company info for {ticker.ticker} saved to {save_path}")
|
||||
return company_info_df
|
||||
|
||||
def get_stock_dividends(
|
||||
symbol: Annotated[str, "ticker symbol"],
|
||||
save_path: Optional[str] = None,
|
||||
) -> DataFrame:
|
||||
"""Fetches and returns the latest dividends data as a DataFrame."""
|
||||
ticker = symbol
|
||||
dividends = ticker.dividends
|
||||
if save_path:
|
||||
dividends.to_csv(save_path)
|
||||
print(f"Dividends for {ticker.ticker} saved to {save_path}")
|
||||
return dividends
|
||||
|
||||
def get_income_stmt(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
||||
"""Fetches and returns the latest income statement of the company as a DataFrame."""
|
||||
ticker = symbol
|
||||
income_stmt = ticker.financials
|
||||
return income_stmt
|
||||
|
||||
def get_balance_sheet(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
||||
"""Fetches and returns the latest balance sheet of the company as a DataFrame."""
|
||||
ticker = symbol
|
||||
balance_sheet = ticker.balance_sheet
|
||||
return balance_sheet
|
||||
|
||||
def get_cash_flow(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
||||
"""Fetches and returns the latest cash flow statement of the company as a DataFrame."""
|
||||
ticker = symbol
|
||||
cash_flow = ticker.cashflow
|
||||
return cash_flow
|
||||
|
||||
def get_analyst_recommendations(symbol: Annotated[str, "ticker symbol"]) -> tuple:
|
||||
"""Fetches the latest analyst recommendations and returns the most common recommendation and its count."""
|
||||
ticker = symbol
|
||||
recommendations = ticker.recommendations
|
||||
if recommendations.empty:
|
||||
return None, 0 # No recommendations available
|
||||
|
||||
# Assuming 'period' column exists and needs to be excluded
|
||||
row_0 = recommendations.iloc[0, 1:] # Exclude 'period' column if necessary
|
||||
|
||||
# Find the maximum voting result
|
||||
max_votes = row_0.max()
|
||||
majority_voting_result = row_0[row_0 == max_votes].index.tolist()
|
||||
|
||||
return majority_voting_result[0], max_votes
|
||||
# gets data/stats
|
||||
|
||||
import yfinance as yf
|
||||
from typing import Annotated, Callable, Any, Optional
|
||||
from pandas import DataFrame
|
||||
import pandas as pd
|
||||
from functools import wraps
|
||||
|
||||
from .utils import save_output, SavePathType, decorate_all_methods
|
||||
|
||||
|
||||
def init_ticker(func: Callable) -> Callable:
|
||||
"""Decorator to initialize yf.Ticker and pass it to the function."""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(symbol: Annotated[str, "ticker symbol"], *args, **kwargs) -> Any:
|
||||
ticker = yf.Ticker(symbol)
|
||||
return func(ticker, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@decorate_all_methods(init_ticker)
|
||||
class YFinanceUtils:
|
||||
|
||||
def get_stock_data(
|
||||
symbol: Annotated[str, "ticker symbol"],
|
||||
start_date: Annotated[
|
||||
str, "start date for retrieving stock price data, YYYY-mm-dd"
|
||||
],
|
||||
end_date: Annotated[
|
||||
str, "end date for retrieving stock price data, YYYY-mm-dd"
|
||||
],
|
||||
save_path: SavePathType = None,
|
||||
) -> DataFrame:
|
||||
"""retrieve stock price data for designated ticker symbol"""
|
||||
ticker = symbol
|
||||
# add one day to the end_date so that the data range is inclusive
|
||||
end_date = pd.to_datetime(end_date) + pd.DateOffset(days=1)
|
||||
end_date = end_date.strftime("%Y-%m-%d")
|
||||
stock_data = ticker.history(start=start_date, end=end_date)
|
||||
# save_output(stock_data, f"Stock data for {ticker.ticker}", save_path)
|
||||
return stock_data
|
||||
|
||||
def get_stock_info(
|
||||
symbol: Annotated[str, "ticker symbol"],
|
||||
) -> dict:
|
||||
"""Fetches and returns latest stock information."""
|
||||
ticker = symbol
|
||||
stock_info = ticker.info
|
||||
return stock_info
|
||||
|
||||
def get_company_info(
|
||||
symbol: Annotated[str, "ticker symbol"],
|
||||
save_path: Optional[str] = None,
|
||||
) -> DataFrame:
|
||||
"""Fetches and returns company information as a DataFrame."""
|
||||
ticker = symbol
|
||||
info = ticker.info
|
||||
company_info = {
|
||||
"Company Name": info.get("shortName", "N/A"),
|
||||
"Industry": info.get("industry", "N/A"),
|
||||
"Sector": info.get("sector", "N/A"),
|
||||
"Country": info.get("country", "N/A"),
|
||||
"Website": info.get("website", "N/A"),
|
||||
}
|
||||
company_info_df = DataFrame([company_info])
|
||||
if save_path:
|
||||
company_info_df.to_csv(save_path)
|
||||
print(f"Company info for {ticker.ticker} saved to {save_path}")
|
||||
return company_info_df
|
||||
|
||||
def get_stock_dividends(
|
||||
symbol: Annotated[str, "ticker symbol"],
|
||||
save_path: Optional[str] = None,
|
||||
) -> DataFrame:
|
||||
"""Fetches and returns the latest dividends data as a DataFrame."""
|
||||
ticker = symbol
|
||||
dividends = ticker.dividends
|
||||
if save_path:
|
||||
dividends.to_csv(save_path)
|
||||
print(f"Dividends for {ticker.ticker} saved to {save_path}")
|
||||
return dividends
|
||||
|
||||
def get_income_stmt(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
||||
"""Fetches and returns the latest income statement of the company as a DataFrame."""
|
||||
ticker = symbol
|
||||
income_stmt = ticker.financials
|
||||
return income_stmt
|
||||
|
||||
def get_balance_sheet(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
||||
"""Fetches and returns the latest balance sheet of the company as a DataFrame."""
|
||||
ticker = symbol
|
||||
balance_sheet = ticker.balance_sheet
|
||||
return balance_sheet
|
||||
|
||||
def get_cash_flow(symbol: Annotated[str, "ticker symbol"]) -> DataFrame:
|
||||
"""Fetches and returns the latest cash flow statement of the company as a DataFrame."""
|
||||
ticker = symbol
|
||||
cash_flow = ticker.cashflow
|
||||
return cash_flow
|
||||
|
||||
def get_analyst_recommendations(symbol: Annotated[str, "ticker symbol"]) -> tuple:
|
||||
"""Fetches the latest analyst recommendations and returns the most common recommendation and its count."""
|
||||
ticker = symbol
|
||||
recommendations = ticker.recommendations
|
||||
if recommendations.empty:
|
||||
return None, 0 # No recommendations available
|
||||
|
||||
# Assuming 'period' column exists and needs to be excluded
|
||||
row_0 = recommendations.iloc[0, 1:] # Exclude 'period' column if necessary
|
||||
|
||||
# Find the maximum voting result
|
||||
max_votes = row_0.max()
|
||||
majority_voting_result = row_0[row_0 == max_votes].index.tolist()
|
||||
|
||||
return majority_voting_result[0], max_votes
|
||||
|
|
@ -1,17 +1,17 @@
|
|||
# TradingAgents/graph/__init__.py
|
||||
|
||||
from .trading_graph import TradingAgentsGraph
|
||||
from .conditional_logic import ConditionalLogic
|
||||
from .setup import GraphSetup
|
||||
from .propagation import Propagator
|
||||
from .reflection import Reflector
|
||||
from .signal_processing import SignalProcessor
|
||||
|
||||
__all__ = [
|
||||
"TradingAgentsGraph",
|
||||
"ConditionalLogic",
|
||||
"GraphSetup",
|
||||
"Propagator",
|
||||
"Reflector",
|
||||
"SignalProcessor",
|
||||
]
|
||||
# TradingAgents/graph/__init__.py
|
||||
|
||||
from .trading_graph import TradingAgentsGraph
|
||||
from .conditional_logic import ConditionalLogic
|
||||
from .setup import GraphSetup
|
||||
from .propagation import Propagator
|
||||
from .reflection import Reflector
|
||||
from .signal_processing import SignalProcessor
|
||||
|
||||
__all__ = [
|
||||
"TradingAgentsGraph",
|
||||
"ConditionalLogic",
|
||||
"GraphSetup",
|
||||
"Propagator",
|
||||
"Reflector",
|
||||
"SignalProcessor",
|
||||
]
|
||||
|
|
@ -1,67 +1,67 @@
|
|||
# TradingAgents/graph/conditional_logic.py
|
||||
|
||||
from tradingagents.agents.utils.agent_states import AgentState
|
||||
|
||||
|
||||
class ConditionalLogic:
|
||||
"""Handles conditional logic for determining graph flow."""
|
||||
|
||||
def __init__(self, max_debate_rounds=1, max_risk_discuss_rounds=1):
|
||||
"""Initialize with configuration parameters."""
|
||||
self.max_debate_rounds = max_debate_rounds
|
||||
self.max_risk_discuss_rounds = max_risk_discuss_rounds
|
||||
|
||||
def should_continue_market(self, state: AgentState):
|
||||
"""Determine if market analysis should continue."""
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_market"
|
||||
return "Msg Clear Market"
|
||||
|
||||
def should_continue_social(self, state: AgentState):
|
||||
"""Determine if social media analysis should continue."""
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_social"
|
||||
return "Msg Clear Social"
|
||||
|
||||
def should_continue_news(self, state: AgentState):
|
||||
"""Determine if news analysis should continue."""
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_news"
|
||||
return "Msg Clear News"
|
||||
|
||||
def should_continue_fundamentals(self, state: AgentState):
|
||||
"""Determine if fundamentals analysis should continue."""
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_fundamentals"
|
||||
return "Msg Clear Fundamentals"
|
||||
|
||||
def should_continue_debate(self, state: AgentState) -> str:
|
||||
"""Determine if debate should continue."""
|
||||
|
||||
if (
|
||||
state["investment_debate_state"]["count"] >= 2 * self.max_debate_rounds
|
||||
): # 3 rounds of back-and-forth between 2 agents
|
||||
return "Research Manager"
|
||||
if state["investment_debate_state"]["current_response"].startswith("Bull"):
|
||||
return "Bear Researcher"
|
||||
return "Bull Researcher"
|
||||
|
||||
def should_continue_risk_analysis(self, state: AgentState) -> str:
|
||||
"""Determine if risk analysis should continue."""
|
||||
if (
|
||||
state["risk_debate_state"]["count"] >= 3 * self.max_risk_discuss_rounds
|
||||
): # 3 rounds of back-and-forth between 3 agents
|
||||
return "Risk Judge"
|
||||
if state["risk_debate_state"]["latest_speaker"].startswith("Risky"):
|
||||
return "Safe Analyst"
|
||||
if state["risk_debate_state"]["latest_speaker"].startswith("Safe"):
|
||||
return "Neutral Analyst"
|
||||
return "Risky Analyst"
|
||||
# TradingAgents/graph/conditional_logic.py
|
||||
|
||||
from tradingagents.agents.utils.agent_states import AgentState
|
||||
|
||||
|
||||
class ConditionalLogic:
|
||||
"""Handles conditional logic for determining graph flow."""
|
||||
|
||||
def __init__(self, max_debate_rounds=1, max_risk_discuss_rounds=1):
|
||||
"""Initialize with configuration parameters."""
|
||||
self.max_debate_rounds = max_debate_rounds
|
||||
self.max_risk_discuss_rounds = max_risk_discuss_rounds
|
||||
|
||||
def should_continue_market(self, state: AgentState):
|
||||
"""Determine if market analysis should continue."""
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_market"
|
||||
return "Msg Clear Market"
|
||||
|
||||
def should_continue_social(self, state: AgentState):
|
||||
"""Determine if social media analysis should continue."""
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_social"
|
||||
return "Msg Clear Social"
|
||||
|
||||
def should_continue_news(self, state: AgentState):
|
||||
"""Determine if news analysis should continue."""
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_news"
|
||||
return "Msg Clear News"
|
||||
|
||||
def should_continue_fundamentals(self, state: AgentState):
|
||||
"""Determine if fundamentals analysis should continue."""
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
return "tools_fundamentals"
|
||||
return "Msg Clear Fundamentals"
|
||||
|
||||
def should_continue_debate(self, state: AgentState) -> str:
|
||||
"""Determine if debate should continue."""
|
||||
|
||||
if (
|
||||
state["investment_debate_state"]["count"] >= 2 * self.max_debate_rounds
|
||||
): # 3 rounds of back-and-forth between 2 agents
|
||||
return "Research Manager"
|
||||
if state["investment_debate_state"]["current_response"].startswith("Bull"):
|
||||
return "Bear Researcher"
|
||||
return "Bull Researcher"
|
||||
|
||||
def should_continue_risk_analysis(self, state: AgentState) -> str:
|
||||
"""Determine if risk analysis should continue."""
|
||||
if (
|
||||
state["risk_debate_state"]["count"] >= 3 * self.max_risk_discuss_rounds
|
||||
): # 3 rounds of back-and-forth between 3 agents
|
||||
return "Risk Judge"
|
||||
if state["risk_debate_state"]["latest_speaker"].startswith("Risky"):
|
||||
return "Safe Analyst"
|
||||
if state["risk_debate_state"]["latest_speaker"].startswith("Safe"):
|
||||
return "Neutral Analyst"
|
||||
return "Risky Analyst"
|
||||
|
|
@ -1,49 +1,49 @@
|
|||
# TradingAgents/graph/propagation.py
|
||||
|
||||
from typing import Dict, Any
|
||||
from tradingagents.agents.utils.agent_states import (
|
||||
AgentState,
|
||||
InvestDebateState,
|
||||
RiskDebateState,
|
||||
)
|
||||
|
||||
|
||||
class Propagator:
|
||||
"""Handles state initialization and propagation through the graph."""
|
||||
|
||||
def __init__(self, max_recur_limit=100):
|
||||
"""Initialize with configuration parameters."""
|
||||
self.max_recur_limit = max_recur_limit
|
||||
|
||||
def create_initial_state(
|
||||
self, company_name: str, trade_date: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Create the initial state for the agent graph."""
|
||||
return {
|
||||
"messages": [("human", company_name)],
|
||||
"company_of_interest": company_name,
|
||||
"trade_date": str(trade_date),
|
||||
"investment_debate_state": InvestDebateState(
|
||||
{"history": "", "current_response": "", "count": 0}
|
||||
),
|
||||
"risk_debate_state": RiskDebateState(
|
||||
{
|
||||
"history": "",
|
||||
"current_risky_response": "",
|
||||
"current_safe_response": "",
|
||||
"current_neutral_response": "",
|
||||
"count": 0,
|
||||
}
|
||||
),
|
||||
"market_report": "",
|
||||
"fundamentals_report": "",
|
||||
"sentiment_report": "",
|
||||
"news_report": "",
|
||||
}
|
||||
|
||||
def get_graph_args(self) -> Dict[str, Any]:
|
||||
"""Get arguments for the graph invocation."""
|
||||
return {
|
||||
"stream_mode": "values",
|
||||
"config": {"recursion_limit": self.max_recur_limit},
|
||||
}
|
||||
# TradingAgents/graph/propagation.py
|
||||
|
||||
from typing import Dict, Any
|
||||
from tradingagents.agents.utils.agent_states import (
|
||||
AgentState,
|
||||
InvestDebateState,
|
||||
RiskDebateState,
|
||||
)
|
||||
|
||||
|
||||
class Propagator:
|
||||
"""Handles state initialization and propagation through the graph."""
|
||||
|
||||
def __init__(self, max_recur_limit=100):
|
||||
"""Initialize with configuration parameters."""
|
||||
self.max_recur_limit = max_recur_limit
|
||||
|
||||
def create_initial_state(
|
||||
self, company_name: str, trade_date: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Create the initial state for the agent graph."""
|
||||
return {
|
||||
"messages": [("human", company_name)],
|
||||
"company_of_interest": company_name,
|
||||
"trade_date": str(trade_date),
|
||||
"investment_debate_state": InvestDebateState(
|
||||
{"history": "", "current_response": "", "count": 0}
|
||||
),
|
||||
"risk_debate_state": RiskDebateState(
|
||||
{
|
||||
"history": "",
|
||||
"current_risky_response": "",
|
||||
"current_safe_response": "",
|
||||
"current_neutral_response": "",
|
||||
"count": 0,
|
||||
}
|
||||
),
|
||||
"market_report": "",
|
||||
"fundamentals_report": "",
|
||||
"sentiment_report": "",
|
||||
"news_report": "",
|
||||
}
|
||||
|
||||
def get_graph_args(self) -> Dict[str, Any]:
|
||||
"""Get arguments for the graph invocation."""
|
||||
return {
|
||||
"stream_mode": "values",
|
||||
"config": {"recursion_limit": self.max_recur_limit},
|
||||
}
|
||||
|
|
@ -1,123 +1,123 @@
|
|||
# TradingAgents/graph/reflection.py
|
||||
|
||||
from typing import Dict, Any
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class Reflector:
|
||||
"""Handles reflection on decisions and updating memory."""
|
||||
|
||||
def __init__(self, quick_thinking_llm: ChatOpenAI):
|
||||
"""Initialize the reflector with an LLM."""
|
||||
self.quick_thinking_llm = quick_thinking_llm
|
||||
self.reflection_system_prompt = self._get_reflection_prompt()
|
||||
|
||||
def _get_reflection_prompt(self) -> str:
|
||||
"""Get the system prompt for reflection."""
|
||||
return """
|
||||
**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis.
|
||||
Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines:
|
||||
|
||||
1. Reasoning:
|
||||
- For each trading decision, determine whether it was correct or incorrect. A correct decision results in an increase in returns, while an incorrect decision does the opposite.
|
||||
- Analyze the contributing factors to each success or mistake. Consider:
|
||||
- Market intelligence.
|
||||
- Technical indicators.
|
||||
- Technical signals.
|
||||
- Price movement analysis.
|
||||
- Overall market data analysis
|
||||
- News analysis.
|
||||
- Social media and sentiment analysis.
|
||||
- Fundamental data analysis.
|
||||
- Weight the importance of each factor in the decision-making process.
|
||||
|
||||
2. Improvement:
|
||||
- For any incorrect decisions, propose revisions to maximize returns.
|
||||
- Provide a detailed list of corrective actions or improvements, including specific recommendations (e.g., changing a decision from HOLD to BUY on a particular date).
|
||||
|
||||
3. Summary:
|
||||
- Summarize the lessons learned from the successes and mistakes.
|
||||
- Highlight how these lessons can be adapted for future trading scenarios and draw connections between similar situations to apply the knowledge gained.
|
||||
|
||||
4. Query:
|
||||
- Extract key insights from the summary into a concise sentence of no more than 1000 tokens.
|
||||
- Ensure the condensed sentence captures the essence of the lessons and reasoning for easy reference.
|
||||
|
||||
Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis.
|
||||
"""
|
||||
|
||||
def _extract_current_situation(self, current_state: Dict[str, Any]) -> str:
|
||||
"""Extract the current market situation from the state."""
|
||||
curr_market_report = current_state["market_report"]
|
||||
curr_sentiment_report = current_state["sentiment_report"]
|
||||
curr_news_report = current_state["news_report"]
|
||||
curr_fundamentals_report = current_state["fundamentals_report"]
|
||||
|
||||
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
|
||||
|
||||
def _reflect_on_component(
|
||||
self, component_type: str, report: str, situation: str, returns_losses
|
||||
) -> str:
|
||||
"""Generate reflection for a component."""
|
||||
messages = [
|
||||
("system", self.reflection_system_prompt),
|
||||
(
|
||||
"human",
|
||||
f"Returns: {returns_losses}\n\nAnalysis/Decision: {report}\n\nObjective Market Reports for Reference: {situation}",
|
||||
),
|
||||
]
|
||||
|
||||
result = self.quick_thinking_llm.invoke(messages).content
|
||||
return result
|
||||
|
||||
def reflect_bull_researcher(self, current_state, returns_losses, bull_memory):
|
||||
"""Reflect on bull researcher's analysis and update memory."""
|
||||
situation = self._extract_current_situation(current_state)
|
||||
bull_debate_history = current_state["investment_debate_state"]["bull_history"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"BULL", bull_debate_history, situation, returns_losses
|
||||
)
|
||||
bull_memory.add_situations([(situation, result)])
|
||||
|
||||
def reflect_bear_researcher(self, current_state, returns_losses, bear_memory):
|
||||
"""Reflect on bear researcher's analysis and update memory."""
|
||||
situation = self._extract_current_situation(current_state)
|
||||
bear_debate_history = current_state["investment_debate_state"]["bear_history"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"BEAR", bear_debate_history, situation, returns_losses
|
||||
)
|
||||
bear_memory.add_situations([(situation, result)])
|
||||
|
||||
def reflect_trader(self, current_state, returns_losses, trader_memory):
|
||||
"""Reflect on trader's decision and update memory."""
|
||||
situation = self._extract_current_situation(current_state)
|
||||
trader_decision = current_state["trader_investment_plan"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"TRADER", trader_decision, situation, returns_losses
|
||||
)
|
||||
trader_memory.add_situations([(situation, result)])
|
||||
|
||||
def reflect_invest_judge(self, current_state, returns_losses, invest_judge_memory):
|
||||
"""Reflect on investment judge's decision and update memory."""
|
||||
situation = self._extract_current_situation(current_state)
|
||||
judge_decision = current_state["investment_debate_state"]["judge_decision"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"INVEST JUDGE", judge_decision, situation, returns_losses
|
||||
)
|
||||
invest_judge_memory.add_situations([(situation, result)])
|
||||
|
||||
def reflect_risk_manager(self, current_state, returns_losses, risk_manager_memory):
|
||||
"""Reflect on risk manager's decision and update memory."""
|
||||
situation = self._extract_current_situation(current_state)
|
||||
judge_decision = current_state["risk_debate_state"]["judge_decision"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"RISK JUDGE", judge_decision, situation, returns_losses
|
||||
)
|
||||
risk_manager_memory.add_situations([(situation, result)])
|
||||
# TradingAgents/graph/reflection.py
|
||||
|
||||
from typing import Dict, Any
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class Reflector:
|
||||
"""Handles reflection on decisions and updating memory."""
|
||||
|
||||
def __init__(self, quick_thinking_llm: ChatOpenAI):
|
||||
"""Initialize the reflector with an LLM."""
|
||||
self.quick_thinking_llm = quick_thinking_llm
|
||||
self.reflection_system_prompt = self._get_reflection_prompt()
|
||||
|
||||
def _get_reflection_prompt(self) -> str:
|
||||
"""Get the system prompt for reflection."""
|
||||
return """
|
||||
**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)
|
||||
|
||||
You are an expert financial analyst tasked with reviewing trading decisions/analysis and providing a comprehensive, step-by-step analysis.
|
||||
Your goal is to deliver detailed insights into investment decisions and highlight opportunities for improvement, adhering strictly to the following guidelines:
|
||||
|
||||
1. Reasoning:
|
||||
- For each trading decision, determine whether it was correct or incorrect. A correct decision results in an increase in returns, while an incorrect decision does the opposite.
|
||||
- Analyze the contributing factors to each success or mistake. Consider:
|
||||
- Market intelligence.
|
||||
- Technical indicators.
|
||||
- Technical signals.
|
||||
- Price movement analysis.
|
||||
- Overall market data analysis
|
||||
- News analysis.
|
||||
- Social media and sentiment analysis.
|
||||
- Fundamental data analysis.
|
||||
- Weight the importance of each factor in the decision-making process.
|
||||
|
||||
2. Improvement:
|
||||
- For any incorrect decisions, propose revisions to maximize returns.
|
||||
- Provide a detailed list of corrective actions or improvements, including specific recommendations (e.g., changing a decision from HOLD to BUY on a particular date).
|
||||
|
||||
3. Summary:
|
||||
- Summarize the lessons learned from the successes and mistakes.
|
||||
- Highlight how these lessons can be adapted for future trading scenarios and draw connections between similar situations to apply the knowledge gained.
|
||||
|
||||
4. Query:
|
||||
- Extract key insights from the summary into a concise sentence of no more than 1000 tokens.
|
||||
- Ensure the condensed sentence captures the essence of the lessons and reasoning for easy reference.
|
||||
|
||||
Adhere strictly to these instructions, and ensure your output is detailed, accurate, and actionable. You will also be given objective descriptions of the market from a price movements, technical indicator, news, and sentiment perspective to provide more context for your analysis.
|
||||
"""
|
||||
|
||||
def _extract_current_situation(self, current_state: Dict[str, Any]) -> str:
|
||||
"""Extract the current market situation from the state."""
|
||||
curr_market_report = current_state["market_report"]
|
||||
curr_sentiment_report = current_state["sentiment_report"]
|
||||
curr_news_report = current_state["news_report"]
|
||||
curr_fundamentals_report = current_state["fundamentals_report"]
|
||||
|
||||
return f"{curr_market_report}\n\n{curr_sentiment_report}\n\n{curr_news_report}\n\n{curr_fundamentals_report}"
|
||||
|
||||
def _reflect_on_component(
|
||||
self, component_type: str, report: str, situation: str, returns_losses
|
||||
) -> str:
|
||||
"""Generate reflection for a component."""
|
||||
messages = [
|
||||
("system", self.reflection_system_prompt),
|
||||
(
|
||||
"human",
|
||||
f"Returns: {returns_losses}\n\nAnalysis/Decision: {report}\n\nObjective Market Reports for Reference: {situation}",
|
||||
),
|
||||
]
|
||||
|
||||
result = self.quick_thinking_llm.invoke(messages).content
|
||||
return result
|
||||
|
||||
def reflect_bull_researcher(self, current_state, returns_losses, bull_memory):
|
||||
"""Reflect on bull researcher's analysis and update memory."""
|
||||
situation = self._extract_current_situation(current_state)
|
||||
bull_debate_history = current_state["investment_debate_state"]["bull_history"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"BULL", bull_debate_history, situation, returns_losses
|
||||
)
|
||||
bull_memory.add_situations([(situation, result)])
|
||||
|
||||
def reflect_bear_researcher(self, current_state, returns_losses, bear_memory):
|
||||
"""Reflect on bear researcher's analysis and update memory."""
|
||||
situation = self._extract_current_situation(current_state)
|
||||
bear_debate_history = current_state["investment_debate_state"]["bear_history"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"BEAR", bear_debate_history, situation, returns_losses
|
||||
)
|
||||
bear_memory.add_situations([(situation, result)])
|
||||
|
||||
def reflect_trader(self, current_state, returns_losses, trader_memory):
|
||||
"""Reflect on trader's decision and update memory."""
|
||||
situation = self._extract_current_situation(current_state)
|
||||
trader_decision = current_state["trader_investment_plan"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"TRADER", trader_decision, situation, returns_losses
|
||||
)
|
||||
trader_memory.add_situations([(situation, result)])
|
||||
|
||||
def reflect_invest_judge(self, current_state, returns_losses, invest_judge_memory):
|
||||
"""Reflect on investment judge's decision and update memory."""
|
||||
situation = self._extract_current_situation(current_state)
|
||||
judge_decision = current_state["investment_debate_state"]["judge_decision"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"INVEST JUDGE", judge_decision, situation, returns_losses
|
||||
)
|
||||
invest_judge_memory.add_situations([(situation, result)])
|
||||
|
||||
def reflect_risk_manager(self, current_state, returns_losses, risk_manager_memory):
|
||||
"""Reflect on risk manager's decision and update memory."""
|
||||
situation = self._extract_current_situation(current_state)
|
||||
judge_decision = current_state["risk_debate_state"]["judge_decision"]
|
||||
|
||||
result = self._reflect_on_component(
|
||||
"RISK JUDGE", judge_decision, situation, returns_losses
|
||||
)
|
||||
risk_manager_memory.add_situations([(situation, result)])
|
||||
|
|
@ -1,205 +1,205 @@
|
|||
# TradingAgents/graph/setup.py
|
||||
|
||||
from typing import Dict, Any
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import END, StateGraph, START
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from tradingagents.agents import *
|
||||
from tradingagents.agents.utils.agent_states import AgentState
|
||||
from tradingagents.agents.utils.agent_utils import Toolkit
|
||||
|
||||
from .conditional_logic import ConditionalLogic
|
||||
|
||||
|
||||
class GraphSetup:
|
||||
"""Handles the setup and configuration of the agent graph."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quick_thinking_llm: ChatOpenAI,
|
||||
deep_thinking_llm: ChatOpenAI,
|
||||
toolkit: Toolkit,
|
||||
tool_nodes: Dict[str, ToolNode],
|
||||
bull_memory,
|
||||
bear_memory,
|
||||
trader_memory,
|
||||
invest_judge_memory,
|
||||
risk_manager_memory,
|
||||
conditional_logic: ConditionalLogic,
|
||||
):
|
||||
"""Initialize with required components."""
|
||||
self.quick_thinking_llm = quick_thinking_llm
|
||||
self.deep_thinking_llm = deep_thinking_llm
|
||||
self.toolkit = toolkit
|
||||
self.tool_nodes = tool_nodes
|
||||
self.bull_memory = bull_memory
|
||||
self.bear_memory = bear_memory
|
||||
self.trader_memory = trader_memory
|
||||
self.invest_judge_memory = invest_judge_memory
|
||||
self.risk_manager_memory = risk_manager_memory
|
||||
self.conditional_logic = conditional_logic
|
||||
|
||||
def setup_graph(
|
||||
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
||||
):
|
||||
"""Set up and compile the agent workflow graph.
|
||||
|
||||
Args:
|
||||
selected_analysts (list): List of analyst types to include. Options are:
|
||||
- "market": Market analyst
|
||||
- "social": Social media analyst
|
||||
- "news": News analyst
|
||||
- "fundamentals": Fundamentals analyst
|
||||
"""
|
||||
if len(selected_analysts) == 0:
|
||||
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
|
||||
|
||||
# Create analyst nodes
|
||||
analyst_nodes = {}
|
||||
delete_nodes = {}
|
||||
tool_nodes = {}
|
||||
|
||||
if "market" in selected_analysts:
|
||||
analyst_nodes["market"] = create_market_analyst(
|
||||
self.quick_thinking_llm, self.toolkit
|
||||
)
|
||||
delete_nodes["market"] = create_msg_delete()
|
||||
tool_nodes["market"] = self.tool_nodes["market"]
|
||||
|
||||
if "social" in selected_analysts:
|
||||
analyst_nodes["social"] = create_social_media_analyst(
|
||||
self.quick_thinking_llm, self.toolkit
|
||||
)
|
||||
delete_nodes["social"] = create_msg_delete()
|
||||
tool_nodes["social"] = self.tool_nodes["social"]
|
||||
|
||||
if "news" in selected_analysts:
|
||||
analyst_nodes["news"] = create_news_analyst(
|
||||
self.quick_thinking_llm, self.toolkit
|
||||
)
|
||||
delete_nodes["news"] = create_msg_delete()
|
||||
tool_nodes["news"] = self.tool_nodes["news"]
|
||||
|
||||
if "fundamentals" in selected_analysts:
|
||||
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
|
||||
self.quick_thinking_llm, self.toolkit
|
||||
)
|
||||
delete_nodes["fundamentals"] = create_msg_delete()
|
||||
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
|
||||
|
||||
# Create researcher and manager nodes
|
||||
bull_researcher_node = create_bull_researcher(
|
||||
self.quick_thinking_llm, self.bull_memory
|
||||
)
|
||||
bear_researcher_node = create_bear_researcher(
|
||||
self.quick_thinking_llm, self.bear_memory
|
||||
)
|
||||
research_manager_node = create_research_manager(
|
||||
self.deep_thinking_llm, self.invest_judge_memory
|
||||
)
|
||||
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
|
||||
|
||||
# Create risk analysis nodes
|
||||
risky_analyst = create_risky_debator(self.quick_thinking_llm)
|
||||
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
|
||||
safe_analyst = create_safe_debator(self.quick_thinking_llm)
|
||||
risk_manager_node = create_risk_manager(
|
||||
self.deep_thinking_llm, self.risk_manager_memory
|
||||
)
|
||||
|
||||
# Create workflow
|
||||
workflow = StateGraph(AgentState)
|
||||
|
||||
# Add analyst nodes to the graph
|
||||
for analyst_type, node in analyst_nodes.items():
|
||||
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
|
||||
workflow.add_node(
|
||||
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type]
|
||||
)
|
||||
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])
|
||||
|
||||
# Add other nodes
|
||||
workflow.add_node("Bull Researcher", bull_researcher_node)
|
||||
workflow.add_node("Bear Researcher", bear_researcher_node)
|
||||
workflow.add_node("Research Manager", research_manager_node)
|
||||
workflow.add_node("Trader", trader_node)
|
||||
workflow.add_node("Risky Analyst", risky_analyst)
|
||||
workflow.add_node("Neutral Analyst", neutral_analyst)
|
||||
workflow.add_node("Safe Analyst", safe_analyst)
|
||||
workflow.add_node("Risk Judge", risk_manager_node)
|
||||
|
||||
# Define edges
|
||||
# Start with the first analyst
|
||||
first_analyst = selected_analysts[0]
|
||||
workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst")
|
||||
|
||||
# Connect analysts in sequence
|
||||
for i, analyst_type in enumerate(selected_analysts):
|
||||
current_analyst = f"{analyst_type.capitalize()} Analyst"
|
||||
current_tools = f"tools_{analyst_type}"
|
||||
current_clear = f"Msg Clear {analyst_type.capitalize()}"
|
||||
|
||||
# Add conditional edges for current analyst
|
||||
workflow.add_conditional_edges(
|
||||
current_analyst,
|
||||
getattr(self.conditional_logic, f"should_continue_{analyst_type}"),
|
||||
[current_tools, current_clear],
|
||||
)
|
||||
workflow.add_edge(current_tools, current_analyst)
|
||||
|
||||
# Connect to next analyst or to Bull Researcher if this is the last analyst
|
||||
if i < len(selected_analysts) - 1:
|
||||
next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
|
||||
workflow.add_edge(current_clear, next_analyst)
|
||||
else:
|
||||
workflow.add_edge(current_clear, "Bull Researcher")
|
||||
|
||||
# Add remaining edges
|
||||
workflow.add_conditional_edges(
|
||||
"Bull Researcher",
|
||||
self.conditional_logic.should_continue_debate,
|
||||
{
|
||||
"Bear Researcher": "Bear Researcher",
|
||||
"Research Manager": "Research Manager",
|
||||
},
|
||||
)
|
||||
workflow.add_conditional_edges(
|
||||
"Bear Researcher",
|
||||
self.conditional_logic.should_continue_debate,
|
||||
{
|
||||
"Bull Researcher": "Bull Researcher",
|
||||
"Research Manager": "Research Manager",
|
||||
},
|
||||
)
|
||||
workflow.add_edge("Research Manager", "Trader")
|
||||
workflow.add_edge("Trader", "Risky Analyst")
|
||||
workflow.add_conditional_edges(
|
||||
"Risky Analyst",
|
||||
self.conditional_logic.should_continue_risk_analysis,
|
||||
{
|
||||
"Safe Analyst": "Safe Analyst",
|
||||
"Risk Judge": "Risk Judge",
|
||||
},
|
||||
)
|
||||
workflow.add_conditional_edges(
|
||||
"Safe Analyst",
|
||||
self.conditional_logic.should_continue_risk_analysis,
|
||||
{
|
||||
"Neutral Analyst": "Neutral Analyst",
|
||||
"Risk Judge": "Risk Judge",
|
||||
},
|
||||
)
|
||||
workflow.add_conditional_edges(
|
||||
"Neutral Analyst",
|
||||
self.conditional_logic.should_continue_risk_analysis,
|
||||
{
|
||||
"Risky Analyst": "Risky Analyst",
|
||||
"Risk Judge": "Risk Judge",
|
||||
},
|
||||
)
|
||||
|
||||
workflow.add_edge("Risk Judge", END)
|
||||
|
||||
# Compile and return
|
||||
return workflow.compile()
|
||||
# TradingAgents/graph/setup.py
|
||||
|
||||
from typing import Dict, Any
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import END, StateGraph, START
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from tradingagents.agents import *
|
||||
from tradingagents.agents.utils.agent_states import AgentState
|
||||
from tradingagents.agents.utils.agent_utils import Toolkit
|
||||
|
||||
from .conditional_logic import ConditionalLogic
|
||||
|
||||
|
||||
class GraphSetup:
|
||||
"""Handles the setup and configuration of the agent graph."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quick_thinking_llm: ChatOpenAI,
|
||||
deep_thinking_llm: ChatOpenAI,
|
||||
toolkit: Toolkit,
|
||||
tool_nodes: Dict[str, ToolNode],
|
||||
bull_memory,
|
||||
bear_memory,
|
||||
trader_memory,
|
||||
invest_judge_memory,
|
||||
risk_manager_memory,
|
||||
conditional_logic: ConditionalLogic,
|
||||
):
|
||||
"""Initialize with required components."""
|
||||
self.quick_thinking_llm = quick_thinking_llm
|
||||
self.deep_thinking_llm = deep_thinking_llm
|
||||
self.toolkit = toolkit
|
||||
self.tool_nodes = tool_nodes
|
||||
self.bull_memory = bull_memory
|
||||
self.bear_memory = bear_memory
|
||||
self.trader_memory = trader_memory
|
||||
self.invest_judge_memory = invest_judge_memory
|
||||
self.risk_manager_memory = risk_manager_memory
|
||||
self.conditional_logic = conditional_logic
|
||||
|
||||
def setup_graph(
|
||||
self, selected_analysts=["market", "social", "news", "fundamentals"]
|
||||
):
|
||||
"""Set up and compile the agent workflow graph.
|
||||
|
||||
Args:
|
||||
selected_analysts (list): List of analyst types to include. Options are:
|
||||
- "market": Market analyst
|
||||
- "social": Social media analyst
|
||||
- "news": News analyst
|
||||
- "fundamentals": Fundamentals analyst
|
||||
"""
|
||||
if len(selected_analysts) == 0:
|
||||
raise ValueError("Trading Agents Graph Setup Error: no analysts selected!")
|
||||
|
||||
# Create analyst nodes
|
||||
analyst_nodes = {}
|
||||
delete_nodes = {}
|
||||
tool_nodes = {}
|
||||
|
||||
if "market" in selected_analysts:
|
||||
analyst_nodes["market"] = create_market_analyst(
|
||||
self.quick_thinking_llm, self.toolkit
|
||||
)
|
||||
delete_nodes["market"] = create_msg_delete()
|
||||
tool_nodes["market"] = self.tool_nodes["market"]
|
||||
|
||||
if "social" in selected_analysts:
|
||||
analyst_nodes["social"] = create_social_media_analyst(
|
||||
self.quick_thinking_llm, self.toolkit
|
||||
)
|
||||
delete_nodes["social"] = create_msg_delete()
|
||||
tool_nodes["social"] = self.tool_nodes["social"]
|
||||
|
||||
if "news" in selected_analysts:
|
||||
analyst_nodes["news"] = create_news_analyst(
|
||||
self.quick_thinking_llm, self.toolkit
|
||||
)
|
||||
delete_nodes["news"] = create_msg_delete()
|
||||
tool_nodes["news"] = self.tool_nodes["news"]
|
||||
|
||||
if "fundamentals" in selected_analysts:
|
||||
analyst_nodes["fundamentals"] = create_fundamentals_analyst(
|
||||
self.quick_thinking_llm, self.toolkit
|
||||
)
|
||||
delete_nodes["fundamentals"] = create_msg_delete()
|
||||
tool_nodes["fundamentals"] = self.tool_nodes["fundamentals"]
|
||||
|
||||
# Create researcher and manager nodes
|
||||
bull_researcher_node = create_bull_researcher(
|
||||
self.quick_thinking_llm, self.bull_memory
|
||||
)
|
||||
bear_researcher_node = create_bear_researcher(
|
||||
self.quick_thinking_llm, self.bear_memory
|
||||
)
|
||||
research_manager_node = create_research_manager(
|
||||
self.deep_thinking_llm, self.invest_judge_memory
|
||||
)
|
||||
trader_node = create_trader(self.quick_thinking_llm, self.trader_memory)
|
||||
|
||||
# Create risk analysis nodes
|
||||
risky_analyst = create_risky_debator(self.quick_thinking_llm)
|
||||
neutral_analyst = create_neutral_debator(self.quick_thinking_llm)
|
||||
safe_analyst = create_safe_debator(self.quick_thinking_llm)
|
||||
risk_manager_node = create_risk_manager(
|
||||
self.deep_thinking_llm, self.risk_manager_memory
|
||||
)
|
||||
|
||||
# Create workflow
|
||||
workflow = StateGraph(AgentState)
|
||||
|
||||
# Add analyst nodes to the graph
|
||||
for analyst_type, node in analyst_nodes.items():
|
||||
workflow.add_node(f"{analyst_type.capitalize()} Analyst", node)
|
||||
workflow.add_node(
|
||||
f"Msg Clear {analyst_type.capitalize()}", delete_nodes[analyst_type]
|
||||
)
|
||||
workflow.add_node(f"tools_{analyst_type}", tool_nodes[analyst_type])
|
||||
|
||||
# Add other nodes
|
||||
workflow.add_node("Bull Researcher", bull_researcher_node)
|
||||
workflow.add_node("Bear Researcher", bear_researcher_node)
|
||||
workflow.add_node("Research Manager", research_manager_node)
|
||||
workflow.add_node("Trader", trader_node)
|
||||
workflow.add_node("Risky Analyst", risky_analyst)
|
||||
workflow.add_node("Neutral Analyst", neutral_analyst)
|
||||
workflow.add_node("Safe Analyst", safe_analyst)
|
||||
workflow.add_node("Risk Judge", risk_manager_node)
|
||||
|
||||
# Define edges
|
||||
# Start with the first analyst
|
||||
first_analyst = selected_analysts[0]
|
||||
workflow.add_edge(START, f"{first_analyst.capitalize()} Analyst")
|
||||
|
||||
# Connect analysts in sequence
|
||||
for i, analyst_type in enumerate(selected_analysts):
|
||||
current_analyst = f"{analyst_type.capitalize()} Analyst"
|
||||
current_tools = f"tools_{analyst_type}"
|
||||
current_clear = f"Msg Clear {analyst_type.capitalize()}"
|
||||
|
||||
# Add conditional edges for current analyst
|
||||
workflow.add_conditional_edges(
|
||||
current_analyst,
|
||||
getattr(self.conditional_logic, f"should_continue_{analyst_type}"),
|
||||
[current_tools, current_clear],
|
||||
)
|
||||
workflow.add_edge(current_tools, current_analyst)
|
||||
|
||||
# Connect to next analyst or to Bull Researcher if this is the last analyst
|
||||
if i < len(selected_analysts) - 1:
|
||||
next_analyst = f"{selected_analysts[i+1].capitalize()} Analyst"
|
||||
workflow.add_edge(current_clear, next_analyst)
|
||||
else:
|
||||
workflow.add_edge(current_clear, "Bull Researcher")
|
||||
|
||||
# Add remaining edges
|
||||
workflow.add_conditional_edges(
|
||||
"Bull Researcher",
|
||||
self.conditional_logic.should_continue_debate,
|
||||
{
|
||||
"Bear Researcher": "Bear Researcher",
|
||||
"Research Manager": "Research Manager",
|
||||
},
|
||||
)
|
||||
workflow.add_conditional_edges(
|
||||
"Bear Researcher",
|
||||
self.conditional_logic.should_continue_debate,
|
||||
{
|
||||
"Bull Researcher": "Bull Researcher",
|
||||
"Research Manager": "Research Manager",
|
||||
},
|
||||
)
|
||||
workflow.add_edge("Research Manager", "Trader")
|
||||
workflow.add_edge("Trader", "Risky Analyst")
|
||||
workflow.add_conditional_edges(
|
||||
"Risky Analyst",
|
||||
self.conditional_logic.should_continue_risk_analysis,
|
||||
{
|
||||
"Safe Analyst": "Safe Analyst",
|
||||
"Risk Judge": "Risk Judge",
|
||||
},
|
||||
)
|
||||
workflow.add_conditional_edges(
|
||||
"Safe Analyst",
|
||||
self.conditional_logic.should_continue_risk_analysis,
|
||||
{
|
||||
"Neutral Analyst": "Neutral Analyst",
|
||||
"Risk Judge": "Risk Judge",
|
||||
},
|
||||
)
|
||||
workflow.add_conditional_edges(
|
||||
"Neutral Analyst",
|
||||
self.conditional_logic.should_continue_risk_analysis,
|
||||
{
|
||||
"Risky Analyst": "Risky Analyst",
|
||||
"Risk Judge": "Risk Judge",
|
||||
},
|
||||
)
|
||||
|
||||
workflow.add_edge("Risk Judge", END)
|
||||
|
||||
# Compile and return
|
||||
return workflow.compile()
|
||||
|
|
@ -1,31 +1,31 @@
|
|||
# TradingAgents/graph/signal_processing.py
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class SignalProcessor:
|
||||
"""Processes trading signals to extract actionable decisions."""
|
||||
|
||||
def __init__(self, quick_thinking_llm: ChatOpenAI):
|
||||
"""Initialize with an LLM for processing."""
|
||||
self.quick_thinking_llm = quick_thinking_llm
|
||||
|
||||
def process_signal(self, full_signal: str) -> str:
|
||||
"""
|
||||
Process a full trading signal to extract the core decision.
|
||||
|
||||
Args:
|
||||
full_signal: Complete trading signal text
|
||||
|
||||
Returns:
|
||||
Extracted decision (BUY, SELL, or HOLD)
|
||||
"""
|
||||
messages = [
|
||||
(
|
||||
"system",
|
||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. Your task is to extract the investment decision: SELL, BUY, or HOLD. Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information.",
|
||||
),
|
||||
("human", full_signal),
|
||||
]
|
||||
|
||||
return self.quick_thinking_llm.invoke(messages).content
|
||||
# TradingAgents/graph/signal_processing.py
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
class SignalProcessor:
|
||||
"""Processes trading signals to extract actionable decisions."""
|
||||
|
||||
def __init__(self, quick_thinking_llm: ChatOpenAI):
|
||||
"""Initialize with an LLM for processing."""
|
||||
self.quick_thinking_llm = quick_thinking_llm
|
||||
|
||||
def process_signal(self, full_signal: str) -> str:
|
||||
"""
|
||||
Process a full trading signal to extract the core decision.
|
||||
|
||||
Args:
|
||||
full_signal: Complete trading signal text
|
||||
|
||||
Returns:
|
||||
Extracted decision (BUY, SELL, or HOLD)
|
||||
"""
|
||||
messages = [
|
||||
(
|
||||
"system",
|
||||
"**IMPORTANT THING** Respond in Korean(한국어로 대답해주세요)\n\nYou are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. Your task is to extract the investment decision: SELL, BUY, or HOLD. Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information.",
|
||||
),
|
||||
("human", full_signal),
|
||||
]
|
||||
|
||||
return self.quick_thinking_llm.invoke(messages).content
|
||||
|
|
@ -1,264 +1,264 @@
|
|||
# TradingAgents/graph/trading_graph.py
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import json
|
||||
from datetime import date
|
||||
from typing import Dict, Any, Tuple, List, Optional
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from tradingagents.agents import *
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||
from tradingagents.agents.utils.agent_states import (
|
||||
AgentState,
|
||||
InvestDebateState,
|
||||
RiskDebateState,
|
||||
)
|
||||
from tradingagents.dataflows.interface import set_config
|
||||
|
||||
from .conditional_logic import ConditionalLogic
|
||||
from .setup import GraphSetup
|
||||
from .propagation import Propagator
|
||||
from .reflection import Reflector
|
||||
from .signal_processing import SignalProcessor
|
||||
|
||||
|
||||
class TradingAgentsGraph:
|
||||
"""Main class that orchestrates the trading agents framework."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||
debug=False,
|
||||
config: Dict[str, Any] = None,
|
||||
):
|
||||
"""Initialize the trading agents graph and components.
|
||||
|
||||
Args:
|
||||
selected_analysts: List of analyst types to include
|
||||
debug: Whether to run in debug mode
|
||||
config: Configuration dictionary. If None, uses default config
|
||||
"""
|
||||
self.debug = debug
|
||||
self.config = config or DEFAULT_CONFIG
|
||||
|
||||
# Update the interface's config
|
||||
set_config(self.config)
|
||||
|
||||
# Create necessary directories
|
||||
os.makedirs(
|
||||
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
# Initialize LLMs
|
||||
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
|
||||
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
||||
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
||||
elif self.config["llm_provider"].lower() == "anthropic":
|
||||
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
||||
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
||||
elif self.config["llm_provider"].lower() == "google":
|
||||
self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
|
||||
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
|
||||
|
||||
self.toolkit = Toolkit(config=self.config)
|
||||
|
||||
# Initialize memories
|
||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
||||
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
||||
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
||||
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
|
||||
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
|
||||
|
||||
# Create tool nodes
|
||||
self.tool_nodes = self._create_tool_nodes()
|
||||
|
||||
# Initialize components
|
||||
self.conditional_logic = ConditionalLogic()
|
||||
self.graph_setup = GraphSetup(
|
||||
self.quick_thinking_llm,
|
||||
self.deep_thinking_llm,
|
||||
self.toolkit,
|
||||
self.tool_nodes,
|
||||
self.bull_memory,
|
||||
self.bear_memory,
|
||||
self.trader_memory,
|
||||
self.invest_judge_memory,
|
||||
self.risk_manager_memory,
|
||||
self.conditional_logic,
|
||||
)
|
||||
|
||||
self.propagator = Propagator()
|
||||
self.reflector = Reflector(self.quick_thinking_llm)
|
||||
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
|
||||
|
||||
# State tracking
|
||||
self.curr_state = None
|
||||
self.ticker = None
|
||||
self.log_states_dict = {} # date to full state dict
|
||||
|
||||
# Set up the graph
|
||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
||||
|
||||
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
||||
"""Create tool nodes for different data sources."""
|
||||
return {
|
||||
"market": ToolNode(
|
||||
[
|
||||
# online tools
|
||||
self.toolkit.get_YFin_data_online,
|
||||
self.toolkit.get_stockstats_indicators_report_online,
|
||||
# offline tools
|
||||
self.toolkit.get_YFin_data,
|
||||
self.toolkit.get_stockstats_indicators_report,
|
||||
]
|
||||
),
|
||||
"social": ToolNode(
|
||||
[
|
||||
# online tools
|
||||
self.toolkit.get_stock_news,
|
||||
# offline tools
|
||||
# self.toolkit.get_reddit_stock_info,
|
||||
]
|
||||
),
|
||||
"news": ToolNode(
|
||||
[
|
||||
# online tools
|
||||
self.toolkit.get_global_news,
|
||||
self.toolkit.get_google_news,
|
||||
# offline tools
|
||||
# self.toolkit.get_finnhub_news,
|
||||
# self.toolkit.get_reddit_news,
|
||||
]
|
||||
),
|
||||
"fundamentals": ToolNode(
|
||||
[
|
||||
# online tools
|
||||
self.toolkit.get_fundamentals,
|
||||
# offline tools
|
||||
# self.toolkit.get_finnhub_company_insider_sentiment,
|
||||
# self.toolkit.get_finnhub_company_insider_transactions,
|
||||
# self.toolkit.get_simfin_balance_sheet,
|
||||
# self.toolkit.get_simfin_cashflow,
|
||||
# self.toolkit.get_simfin_income_stmt,
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
def propagate(self, company_name, trade_date):
|
||||
"""Run the trading agents graph for a company on a specific date."""
|
||||
|
||||
self.ticker = company_name
|
||||
|
||||
# Initialize state
|
||||
init_agent_state = self.propagator.create_initial_state(
|
||||
company_name, trade_date
|
||||
)
|
||||
args = self.propagator.get_graph_args()
|
||||
|
||||
if self.debug:
|
||||
# Debug mode with tracing
|
||||
trace = []
|
||||
for chunk in self.graph.stream(init_agent_state, **args):
|
||||
if len(chunk["messages"]) == 0:
|
||||
continue
|
||||
|
||||
message = chunk["messages"][-1]
|
||||
|
||||
if message.content and message.content.strip():
|
||||
|
||||
if "FINAL TRANSACTION PROPOSAL:" in message.content:
|
||||
if not hasattr(self, '_final_printed'):
|
||||
message.pretty_print()
|
||||
self._final_printed = True
|
||||
else:
|
||||
message.pretty_print()
|
||||
|
||||
trace.append(chunk)
|
||||
|
||||
final_state = trace[-1]
|
||||
else:
|
||||
# Standard mode without tracing
|
||||
final_state = self.graph.invoke(init_agent_state, **args)
|
||||
|
||||
# Store current state for reflection
|
||||
self.curr_state = final_state
|
||||
|
||||
# Log state
|
||||
self._log_state(trade_date, final_state)
|
||||
|
||||
# Return decision and processed signal
|
||||
return final_state, self.process_signal(final_state["final_trade_decision"])
|
||||
|
||||
def _log_state(self, trade_date, final_state):
|
||||
"""Log the final state to a JSON file."""
|
||||
self.log_states_dict[str(trade_date)] = {
|
||||
"company_of_interest": final_state["company_of_interest"],
|
||||
"trade_date": final_state["trade_date"],
|
||||
"market_report": final_state["market_report"],
|
||||
"sentiment_report": final_state["sentiment_report"],
|
||||
"news_report": final_state["news_report"],
|
||||
"fundamentals_report": final_state["fundamentals_report"],
|
||||
"investment_debate_state": {
|
||||
"bull_history": final_state["investment_debate_state"]["bull_history"],
|
||||
"bear_history": final_state["investment_debate_state"]["bear_history"],
|
||||
"history": final_state["investment_debate_state"]["history"],
|
||||
"current_response": final_state["investment_debate_state"][
|
||||
"current_response"
|
||||
],
|
||||
"judge_decision": final_state["investment_debate_state"][
|
||||
"judge_decision"
|
||||
],
|
||||
},
|
||||
"trader_investment_decision": final_state["trader_investment_plan"],
|
||||
"risk_debate_state": {
|
||||
"risky_history": final_state["risk_debate_state"]["risky_history"],
|
||||
"safe_history": final_state["risk_debate_state"]["safe_history"],
|
||||
"neutral_history": final_state["risk_debate_state"]["neutral_history"],
|
||||
"history": final_state["risk_debate_state"]["history"],
|
||||
"judge_decision": final_state["risk_debate_state"]["judge_decision"],
|
||||
},
|
||||
"investment_plan": final_state["investment_plan"],
|
||||
"final_trade_decision": final_state["final_trade_decision"],
|
||||
}
|
||||
|
||||
# Save to file
|
||||
directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(
|
||||
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log.json",
|
||||
"w",
|
||||
) as f:
|
||||
json.dump(self.log_states_dict, f, indent=4)
|
||||
|
||||
def reflect_and_remember(self, returns_losses):
|
||||
"""Reflect on decisions and update memory based on returns."""
|
||||
self.reflector.reflect_bull_researcher(
|
||||
self.curr_state, returns_losses, self.bull_memory
|
||||
)
|
||||
self.reflector.reflect_bear_researcher(
|
||||
self.curr_state, returns_losses, self.bear_memory
|
||||
)
|
||||
self.reflector.reflect_trader(
|
||||
self.curr_state, returns_losses, self.trader_memory
|
||||
)
|
||||
self.reflector.reflect_invest_judge(
|
||||
self.curr_state, returns_losses, self.invest_judge_memory
|
||||
)
|
||||
self.reflector.reflect_risk_manager(
|
||||
self.curr_state, returns_losses, self.risk_manager_memory
|
||||
)
|
||||
|
||||
def process_signal(self, full_signal):
|
||||
"""Process a signal to extract the core decision."""
|
||||
return self.signal_processor.process_signal(full_signal)
|
||||
# TradingAgents/graph/trading_graph.py
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import json
|
||||
from datetime import date
|
||||
from typing import Dict, Any, Tuple, List, Optional
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from tradingagents.agents import *
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||
from tradingagents.agents.utils.agent_states import (
|
||||
AgentState,
|
||||
InvestDebateState,
|
||||
RiskDebateState,
|
||||
)
|
||||
from tradingagents.dataflows.interface import set_config
|
||||
|
||||
from .conditional_logic import ConditionalLogic
|
||||
from .setup import GraphSetup
|
||||
from .propagation import Propagator
|
||||
from .reflection import Reflector
|
||||
from .signal_processing import SignalProcessor
|
||||
|
||||
|
||||
class TradingAgentsGraph:
|
||||
"""Main class that orchestrates the trading agents framework."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||
debug=False,
|
||||
config: Dict[str, Any] = None,
|
||||
):
|
||||
"""Initialize the trading agents graph and components.
|
||||
|
||||
Args:
|
||||
selected_analysts: List of analyst types to include
|
||||
debug: Whether to run in debug mode
|
||||
config: Configuration dictionary. If None, uses default config
|
||||
"""
|
||||
self.debug = debug
|
||||
self.config = config or DEFAULT_CONFIG
|
||||
|
||||
# Update the interface's config
|
||||
set_config(self.config)
|
||||
|
||||
# Create necessary directories
|
||||
os.makedirs(
|
||||
os.path.join(self.config["project_dir"], "dataflows/data_cache"),
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
# Initialize LLMs
|
||||
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
|
||||
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
||||
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
||||
elif self.config["llm_provider"].lower() == "anthropic":
|
||||
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
||||
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
||||
elif self.config["llm_provider"].lower() == "google":
|
||||
self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
|
||||
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
|
||||
|
||||
self.toolkit = Toolkit(config=self.config)
|
||||
|
||||
# Initialize memories
|
||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
||||
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
||||
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
||||
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
|
||||
self.risk_manager_memory = FinancialSituationMemory("risk_manager_memory", self.config)
|
||||
|
||||
# Create tool nodes
|
||||
self.tool_nodes = self._create_tool_nodes()
|
||||
|
||||
# Initialize components
|
||||
self.conditional_logic = ConditionalLogic()
|
||||
self.graph_setup = GraphSetup(
|
||||
self.quick_thinking_llm,
|
||||
self.deep_thinking_llm,
|
||||
self.toolkit,
|
||||
self.tool_nodes,
|
||||
self.bull_memory,
|
||||
self.bear_memory,
|
||||
self.trader_memory,
|
||||
self.invest_judge_memory,
|
||||
self.risk_manager_memory,
|
||||
self.conditional_logic,
|
||||
)
|
||||
|
||||
self.propagator = Propagator()
|
||||
self.reflector = Reflector(self.quick_thinking_llm)
|
||||
self.signal_processor = SignalProcessor(self.quick_thinking_llm)
|
||||
|
||||
# State tracking
|
||||
self.curr_state = None
|
||||
self.ticker = None
|
||||
self.log_states_dict = {} # date to full state dict
|
||||
|
||||
# Set up the graph
|
||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
||||
|
||||
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
||||
"""Create tool nodes for different data sources."""
|
||||
return {
|
||||
"market": ToolNode(
|
||||
[
|
||||
# online tools
|
||||
self.toolkit.get_YFin_data_online,
|
||||
self.toolkit.get_stockstats_indicators_report_online,
|
||||
# offline tools
|
||||
self.toolkit.get_YFin_data,
|
||||
self.toolkit.get_stockstats_indicators_report,
|
||||
]
|
||||
),
|
||||
"social": ToolNode(
|
||||
[
|
||||
# online tools
|
||||
self.toolkit.get_stock_news,
|
||||
# offline tools
|
||||
# self.toolkit.get_reddit_stock_info,
|
||||
]
|
||||
),
|
||||
"news": ToolNode(
|
||||
[
|
||||
# online tools
|
||||
self.toolkit.get_global_news,
|
||||
self.toolkit.get_google_news,
|
||||
# offline tools
|
||||
# self.toolkit.get_finnhub_news,
|
||||
# self.toolkit.get_reddit_news,
|
||||
]
|
||||
),
|
||||
"fundamentals": ToolNode(
|
||||
[
|
||||
# online tools
|
||||
self.toolkit.get_fundamentals,
|
||||
# offline tools
|
||||
# self.toolkit.get_finnhub_company_insider_sentiment,
|
||||
# self.toolkit.get_finnhub_company_insider_transactions,
|
||||
# self.toolkit.get_simfin_balance_sheet,
|
||||
# self.toolkit.get_simfin_cashflow,
|
||||
# self.toolkit.get_simfin_income_stmt,
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
def propagate(self, company_name, trade_date):
|
||||
"""Run the trading agents graph for a company on a specific date."""
|
||||
|
||||
self.ticker = company_name
|
||||
|
||||
# Initialize state
|
||||
init_agent_state = self.propagator.create_initial_state(
|
||||
company_name, trade_date
|
||||
)
|
||||
args = self.propagator.get_graph_args()
|
||||
|
||||
if self.debug:
|
||||
# Debug mode with tracing
|
||||
trace = []
|
||||
for chunk in self.graph.stream(init_agent_state, **args):
|
||||
if len(chunk["messages"]) == 0:
|
||||
continue
|
||||
|
||||
message = chunk["messages"][-1]
|
||||
|
||||
if message.content and message.content.strip():
|
||||
|
||||
if "FINAL TRANSACTION PROPOSAL:" in message.content:
|
||||
if not hasattr(self, '_final_printed'):
|
||||
message.pretty_print()
|
||||
self._final_printed = True
|
||||
else:
|
||||
message.pretty_print()
|
||||
|
||||
trace.append(chunk)
|
||||
|
||||
final_state = trace[-1]
|
||||
else:
|
||||
# Standard mode without tracing
|
||||
final_state = self.graph.invoke(init_agent_state, **args)
|
||||
|
||||
# Store current state for reflection
|
||||
self.curr_state = final_state
|
||||
|
||||
# Log state
|
||||
self._log_state(trade_date, final_state)
|
||||
|
||||
# Return decision and processed signal
|
||||
return final_state, self.process_signal(final_state["final_trade_decision"])
|
||||
|
||||
def _log_state(self, trade_date, final_state):
|
||||
"""Log the final state to a JSON file."""
|
||||
self.log_states_dict[str(trade_date)] = {
|
||||
"company_of_interest": final_state["company_of_interest"],
|
||||
"trade_date": final_state["trade_date"],
|
||||
"market_report": final_state["market_report"],
|
||||
"sentiment_report": final_state["sentiment_report"],
|
||||
"news_report": final_state["news_report"],
|
||||
"fundamentals_report": final_state["fundamentals_report"],
|
||||
"investment_debate_state": {
|
||||
"bull_history": final_state["investment_debate_state"]["bull_history"],
|
||||
"bear_history": final_state["investment_debate_state"]["bear_history"],
|
||||
"history": final_state["investment_debate_state"]["history"],
|
||||
"current_response": final_state["investment_debate_state"][
|
||||
"current_response"
|
||||
],
|
||||
"judge_decision": final_state["investment_debate_state"][
|
||||
"judge_decision"
|
||||
],
|
||||
},
|
||||
"trader_investment_decision": final_state["trader_investment_plan"],
|
||||
"risk_debate_state": {
|
||||
"risky_history": final_state["risk_debate_state"]["risky_history"],
|
||||
"safe_history": final_state["risk_debate_state"]["safe_history"],
|
||||
"neutral_history": final_state["risk_debate_state"]["neutral_history"],
|
||||
"history": final_state["risk_debate_state"]["history"],
|
||||
"judge_decision": final_state["risk_debate_state"]["judge_decision"],
|
||||
},
|
||||
"investment_plan": final_state["investment_plan"],
|
||||
"final_trade_decision": final_state["final_trade_decision"],
|
||||
}
|
||||
|
||||
# Save to file
|
||||
directory = Path(f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/")
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(
|
||||
f"eval_results/{self.ticker}/TradingAgentsStrategy_logs/full_states_log.json",
|
||||
"w",
|
||||
) as f:
|
||||
json.dump(self.log_states_dict, f, indent=4)
|
||||
|
||||
def reflect_and_remember(self, returns_losses):
|
||||
"""Reflect on decisions and update memory based on returns."""
|
||||
self.reflector.reflect_bull_researcher(
|
||||
self.curr_state, returns_losses, self.bull_memory
|
||||
)
|
||||
self.reflector.reflect_bear_researcher(
|
||||
self.curr_state, returns_losses, self.bear_memory
|
||||
)
|
||||
self.reflector.reflect_trader(
|
||||
self.curr_state, returns_losses, self.trader_memory
|
||||
)
|
||||
self.reflector.reflect_invest_judge(
|
||||
self.curr_state, returns_losses, self.invest_judge_memory
|
||||
)
|
||||
self.reflector.reflect_risk_manager(
|
||||
self.curr_state, returns_losses, self.risk_manager_memory
|
||||
)
|
||||
|
||||
def process_signal(self, full_signal):
|
||||
"""Process a signal to extract the core decision."""
|
||||
return self.signal_processor.process_signal(full_signal)
|
||||
Loading…
Reference in New Issue