feat: add optional custom prompt support
This commit is contained in:
parent
10c136f49c
commit
4fed0d3ee9
14
README.md
14
README.md
|
|
@ -158,7 +158,7 @@ Launch the interactive CLI:
|
|||
tradingagents # installed command
|
||||
python -m cli.main # alternative: run directly from source
|
||||
```
|
||||
You will see a screen where you can select your desired tickers, analysis date, LLM provider, research depth, and more.
|
||||
You will see a screen where you can select your desired tickers, analysis date, an optional custom prompt for run-specific instructions, LLM provider, research depth, and more.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/cli/cli_init.png" width="100%" style="display: inline-block; margin: 0 2%;">
|
||||
|
|
@ -191,7 +191,11 @@ from tradingagents.default_config import DEFAULT_CONFIG
|
|||
ta = TradingAgentsGraph(debug=True, config=DEFAULT_CONFIG.copy())
|
||||
|
||||
# forward propagate
|
||||
_, decision = ta.propagate("NVDA", "2026-01-15")
|
||||
_, decision = ta.propagate(
|
||||
"NVDA",
|
||||
"2026-01-15",
|
||||
custom_prompt="Long-term horizon. Focus on earnings quality, capex discipline, and new entries only.",
|
||||
)
|
||||
print(decision)
|
||||
```
|
||||
|
||||
|
|
@ -208,7 +212,11 @@ config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks
|
|||
config["max_debate_rounds"] = 2
|
||||
|
||||
ta = TradingAgentsGraph(debug=True, config=config)
|
||||
_, decision = ta.propagate("NVDA", "2026-01-15")
|
||||
_, decision = ta.propagate(
|
||||
"NVDA",
|
||||
"2026-01-15",
|
||||
custom_prompt="Short-term trading setup. Prioritize momentum, catalyst timing, and downside risk.",
|
||||
)
|
||||
print(decision)
|
||||
```
|
||||
|
||||
|
|
|
|||
62
cli/main.py
62
cli/main.py
|
|
@ -24,6 +24,10 @@ from rich.align import Align
|
|||
from rich.rule import Rule
|
||||
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.custom_prompt import (
|
||||
CUSTOM_PROMPT_SECTION_TITLE,
|
||||
CUSTOM_PROMPT_STATUS_LABEL,
|
||||
)
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from cli.models import AnalystType
|
||||
from cli.utils import *
|
||||
|
|
@ -519,19 +523,28 @@ def get_user_selections():
|
|||
)
|
||||
analysis_date = get_analysis_date()
|
||||
|
||||
# Step 3: Output language
|
||||
# Step 3: Optional custom prompt
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 3: Output Language",
|
||||
"Step 3: Custom Prompt",
|
||||
"Optionally add run-specific instructions such as short-term/long-term horizon, only new positions, or risks to emphasize"
|
||||
)
|
||||
)
|
||||
custom_prompt = ask_custom_prompt()
|
||||
|
||||
# Step 4: Output language
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 4: Output Language",
|
||||
"Select the language for analyst reports and final decision"
|
||||
)
|
||||
)
|
||||
output_language = ask_output_language()
|
||||
|
||||
# Step 4: Select analysts
|
||||
# Step 5: Select analysts
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 4: Analysts Team", "Select your LLM analyst agents for the analysis"
|
||||
"Step 5: Analysts Team", "Select your LLM analyst agents for the analysis"
|
||||
)
|
||||
)
|
||||
selected_analysts = select_analysts()
|
||||
|
|
@ -539,32 +552,32 @@ def get_user_selections():
|
|||
f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}"
|
||||
)
|
||||
|
||||
# Step 5: Research depth
|
||||
# Step 6: Research depth
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 5: Research Depth", "Select your research depth level"
|
||||
"Step 6: Research Depth", "Select your research depth level"
|
||||
)
|
||||
)
|
||||
selected_research_depth = select_research_depth()
|
||||
|
||||
# Step 6: LLM Provider
|
||||
# Step 7: LLM Provider
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 6: LLM Provider", "Select your LLM provider"
|
||||
"Step 7: LLM Provider", "Select your LLM provider"
|
||||
)
|
||||
)
|
||||
selected_llm_provider, backend_url = select_llm_provider()
|
||||
|
||||
# Step 7: Thinking agents
|
||||
# Step 8: Thinking agents
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 7: Thinking Agents", "Select your thinking agents for analysis"
|
||||
"Step 8: Thinking Agents", "Select your thinking agents for analysis"
|
||||
)
|
||||
)
|
||||
selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider)
|
||||
selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider)
|
||||
|
||||
# Step 8: Provider-specific thinking configuration
|
||||
# Step 9: Provider-specific thinking configuration
|
||||
thinking_level = None
|
||||
reasoning_effort = None
|
||||
anthropic_effort = None
|
||||
|
|
@ -573,7 +586,7 @@ def get_user_selections():
|
|||
if provider_lower == "google":
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 8: Thinking Mode",
|
||||
"Step 9: Thinking Mode",
|
||||
"Configure Gemini thinking mode"
|
||||
)
|
||||
)
|
||||
|
|
@ -581,7 +594,7 @@ def get_user_selections():
|
|||
elif provider_lower == "openai":
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 8: Reasoning Effort",
|
||||
"Step 9: Reasoning Effort",
|
||||
"Configure OpenAI reasoning effort level"
|
||||
)
|
||||
)
|
||||
|
|
@ -589,7 +602,7 @@ def get_user_selections():
|
|||
elif provider_lower == "anthropic":
|
||||
console.print(
|
||||
create_question_box(
|
||||
"Step 8: Effort Level",
|
||||
"Step 9: Effort Level",
|
||||
"Configure Claude effort level"
|
||||
)
|
||||
)
|
||||
|
|
@ -598,6 +611,7 @@ def get_user_selections():
|
|||
return {
|
||||
"ticker": selected_ticker,
|
||||
"analysis_date": analysis_date,
|
||||
"custom_prompt": custom_prompt,
|
||||
"analysts": selected_analysts,
|
||||
"research_depth": selected_research_depth,
|
||||
"llm_provider": selected_llm_provider.lower(),
|
||||
|
|
@ -721,6 +735,8 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path):
|
|||
|
||||
# Write consolidated report
|
||||
header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
||||
if final_state.get("custom_prompt"):
|
||||
header += f"## {CUSTOM_PROMPT_SECTION_TITLE}\n\n{final_state['custom_prompt']}\n\n"
|
||||
(save_path / "complete_report.md").write_text(header + "\n\n".join(sections))
|
||||
return save_path / "complete_report.md"
|
||||
|
||||
|
|
@ -730,6 +746,16 @@ def display_complete_report(final_state):
|
|||
console.print()
|
||||
console.print(Rule("Complete Analysis Report", style="bold green"))
|
||||
|
||||
if final_state.get("custom_prompt"):
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(final_state["custom_prompt"]),
|
||||
title=CUSTOM_PROMPT_SECTION_TITLE,
|
||||
border_style="magenta",
|
||||
padding=(1, 2),
|
||||
)
|
||||
)
|
||||
|
||||
# I. Analyst Team Reports
|
||||
analysts = []
|
||||
if final_state.get("market_report"):
|
||||
|
|
@ -1024,6 +1050,10 @@ def run_analysis():
|
|||
message_buffer.add_message(
|
||||
"System", f"Analysis date: {selections['analysis_date']}"
|
||||
)
|
||||
if selections.get("custom_prompt"):
|
||||
message_buffer.add_message(
|
||||
"System", f"{CUSTOM_PROMPT_STATUS_LABEL}: {selections['custom_prompt']}"
|
||||
)
|
||||
message_buffer.add_message(
|
||||
"System",
|
||||
f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}",
|
||||
|
|
@ -1043,7 +1073,9 @@ def run_analysis():
|
|||
|
||||
# Initialize state and get graph args with callbacks
|
||||
init_agent_state = graph.propagator.create_initial_state(
|
||||
selections["ticker"], selections["analysis_date"]
|
||||
selections["ticker"],
|
||||
selections["analysis_date"],
|
||||
selections.get("custom_prompt"),
|
||||
)
|
||||
# Pass callbacks to graph config for tool execution tracking
|
||||
# (LLM tracking is handled separately via LLM constructor)
|
||||
|
|
|
|||
20
cli/utils.py
20
cli/utils.py
|
|
@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Dict
|
|||
from rich.console import Console
|
||||
|
||||
from cli.models import AnalystType
|
||||
from tradingagents.custom_prompt import CUSTOM_PROMPT_INPUT_PROMPT
|
||||
from tradingagents.llm_clients.model_catalog import get_model_options
|
||||
|
||||
console = Console()
|
||||
|
|
@ -76,6 +77,25 @@ def get_analysis_date() -> str:
|
|||
return date.strip()
|
||||
|
||||
|
||||
def ask_custom_prompt() -> str:
|
||||
"""Prompt for optional run-specific instructions."""
|
||||
custom_prompt = questionary.text(
|
||||
CUSTOM_PROMPT_INPUT_PROMPT,
|
||||
default="",
|
||||
style=questionary.Style(
|
||||
[
|
||||
("text", "fg:green"),
|
||||
("highlighted", "noinherit"),
|
||||
]
|
||||
),
|
||||
).ask()
|
||||
|
||||
if custom_prompt is None:
|
||||
return ""
|
||||
|
||||
return custom_prompt.strip()
|
||||
|
||||
|
||||
def select_analysts() -> List[AnalystType]:
|
||||
"""Select analysts using an interactive checkbox."""
|
||||
choices = questionary.checkbox(
|
||||
|
|
|
|||
6
main.py
6
main.py
|
|
@ -24,7 +24,11 @@ config["data_vendors"] = {
|
|||
ta = TradingAgentsGraph(debug=True, config=config)
|
||||
|
||||
# forward propagate
|
||||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
||||
_, decision = ta.propagate(
|
||||
"NVDA",
|
||||
"2024-05-10",
|
||||
custom_prompt="Long-term horizon. Focus on durable earnings power and capital allocation.",
|
||||
)
|
||||
print(decision)
|
||||
|
||||
# Memorize mistakes and reflect
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from cli.main import save_report_to_disk
|
||||
from tradingagents.agents.utils.agent_utils import build_custom_prompt_context
|
||||
from tradingagents.graph.propagation import Propagator
|
||||
|
||||
|
||||
class CustomPromptTests(unittest.TestCase):
|
||||
def test_build_custom_prompt_context_is_empty_when_missing(self):
|
||||
self.assertEqual(build_custom_prompt_context(" "), "")
|
||||
|
||||
def test_build_custom_prompt_context_formats_user_guidance(self):
|
||||
context = build_custom_prompt_context(
|
||||
"Long-term horizon; focus on earnings quality and capex discipline."
|
||||
)
|
||||
self.assertIn("Additional user instructions", context)
|
||||
self.assertIn("Long-term horizon", context)
|
||||
self.assertIn("explicit strategy constraints", context)
|
||||
|
||||
def test_create_initial_state_stores_custom_prompt(self):
|
||||
state = Propagator().create_initial_state(
|
||||
"META",
|
||||
"2026-04-05",
|
||||
custom_prompt="Short-term swing trade; new positions only.",
|
||||
)
|
||||
self.assertEqual(
|
||||
state["custom_prompt"],
|
||||
"Short-term swing trade; new positions only.",
|
||||
)
|
||||
|
||||
def test_save_report_to_disk_includes_custom_prompt_header(self):
|
||||
final_state = {
|
||||
"custom_prompt": "Long-term horizon; focus on capital allocation.",
|
||||
"market_report": "",
|
||||
"sentiment_report": "",
|
||||
"news_report": "",
|
||||
"fundamentals_report": "",
|
||||
"investment_debate_state": {},
|
||||
"risk_debate_state": {},
|
||||
}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
report_path = save_report_to_disk(
|
||||
final_state,
|
||||
"META",
|
||||
Path(tmpdir),
|
||||
)
|
||||
report_text = Path(report_path).read_text()
|
||||
|
||||
self.assertIn("## Custom Prompt", report_text)
|
||||
self.assertIn("Long-term horizon", report_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_custom_prompt_context,
|
||||
build_instrument_context,
|
||||
get_balance_sheet,
|
||||
get_cashflow,
|
||||
|
|
@ -15,6 +16,7 @@ def create_fundamentals_analyst(llm):
|
|||
def fundamentals_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||
custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt"))
|
||||
|
||||
tools = [
|
||||
get_fundamentals,
|
||||
|
|
@ -41,7 +43,7 @@ def create_fundamentals_analyst(llm):
|
|||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. {instrument_context}",
|
||||
"For your reference, the current date is {current_date}. {instrument_context}\n{custom_prompt_context}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
|
|
@ -51,6 +53,7 @@ def create_fundamentals_analyst(llm):
|
|||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
prompt = prompt.partial(custom_prompt_context=custom_prompt_context)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_custom_prompt_context,
|
||||
build_instrument_context,
|
||||
get_indicators,
|
||||
get_language_instruction,
|
||||
|
|
@ -13,6 +14,7 @@ def create_market_analyst(llm):
|
|||
def market_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||
custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt"))
|
||||
|
||||
tools = [
|
||||
get_stock_data,
|
||||
|
|
@ -60,7 +62,7 @@ Volume-Based Indicators:
|
|||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. {instrument_context}",
|
||||
"For your reference, the current date is {current_date}. {instrument_context}\n{custom_prompt_context}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
|
|
@ -70,6 +72,7 @@ Volume-Based Indicators:
|
|||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
prompt = prompt.partial(custom_prompt_context=custom_prompt_context)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_custom_prompt_context,
|
||||
build_instrument_context,
|
||||
get_global_news,
|
||||
get_language_instruction,
|
||||
|
|
@ -12,6 +13,7 @@ def create_news_analyst(llm):
|
|||
def news_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||
custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt"))
|
||||
|
||||
tools = [
|
||||
get_news,
|
||||
|
|
@ -35,7 +37,7 @@ def create_news_analyst(llm):
|
|||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. {instrument_context}",
|
||||
"For your reference, the current date is {current_date}. {instrument_context}\n{custom_prompt_context}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
|
|
@ -45,6 +47,7 @@ def create_news_analyst(llm):
|
|||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
prompt = prompt.partial(custom_prompt_context=custom_prompt_context)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_custom_prompt_context,
|
||||
build_instrument_context,
|
||||
get_language_instruction,
|
||||
get_news,
|
||||
)
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
||||
|
|
@ -7,6 +12,7 @@ def create_social_media_analyst(llm):
|
|||
def social_media_analyst_node(state):
|
||||
current_date = state["trade_date"]
|
||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||
custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt"))
|
||||
|
||||
tools = [
|
||||
get_news,
|
||||
|
|
@ -29,7 +35,7 @@ def create_social_media_analyst(llm):
|
|||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. {instrument_context}",
|
||||
"For your reference, the current date is {current_date}. {instrument_context}\n{custom_prompt_context}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
|
|
@ -39,6 +45,7 @@ def create_social_media_analyst(llm):
|
|||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(instrument_context=instrument_context)
|
||||
prompt = prompt.partial(custom_prompt_context=custom_prompt_context)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,15 @@
|
|||
from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_custom_prompt_context,
|
||||
build_instrument_context,
|
||||
get_language_instruction,
|
||||
)
|
||||
|
||||
|
||||
def create_portfolio_manager(llm, memory):
|
||||
def portfolio_manager_node(state) -> dict:
|
||||
|
||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||
custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt"))
|
||||
|
||||
history = state["risk_debate_state"]["history"]
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
|
|
@ -15,7 +20,7 @@ def create_portfolio_manager(llm, memory):
|
|||
research_plan = state["investment_plan"]
|
||||
trader_plan = state["trader_investment_plan"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
curr_situation = f"{custom_prompt_context}\n\n{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
|
|
@ -25,6 +30,7 @@ def create_portfolio_manager(llm, memory):
|
|||
prompt = f"""As the Portfolio Manager, synthesize the risk analysts' debate and deliver the final trading decision.
|
||||
|
||||
{instrument_context}
|
||||
{custom_prompt_context}
|
||||
|
||||
---
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_custom_prompt_context,
|
||||
build_instrument_context,
|
||||
)
|
||||
|
||||
|
||||
def create_research_manager(llm, memory):
|
||||
def research_manager_node(state) -> dict:
|
||||
instrument_context = build_instrument_context(state["company_of_interest"])
|
||||
custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt"))
|
||||
history = state["investment_debate_state"].get("history", "")
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
|
|
@ -13,7 +17,7 @@ def create_research_manager(llm, memory):
|
|||
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
curr_situation = f"{custom_prompt_context}\n\n{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
|
|
@ -35,6 +39,7 @@ Here are your past reflections on mistakes:
|
|||
\"{past_memory_str}\"
|
||||
|
||||
{instrument_context}
|
||||
{custom_prompt_context}
|
||||
|
||||
Here is the debate:
|
||||
Debate History:
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from tradingagents.agents.utils.agent_utils import build_custom_prompt_context
|
||||
|
||||
|
||||
def create_bear_researcher(llm, memory):
|
||||
def bear_node(state) -> dict:
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt"))
|
||||
history = investment_debate_state.get("history", "")
|
||||
bear_history = investment_debate_state.get("bear_history", "")
|
||||
|
||||
|
|
@ -12,7 +14,7 @@ def create_bear_researcher(llm, memory):
|
|||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
curr_situation = f"{custom_prompt_context}\n\n{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
|
|
@ -38,6 +40,7 @@ Company fundamentals report: {fundamentals_report}
|
|||
Conversation history of the debate: {history}
|
||||
Last bull argument: {current_response}
|
||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
||||
{custom_prompt_context}
|
||||
Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past.
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from tradingagents.agents.utils.agent_utils import build_custom_prompt_context
|
||||
|
||||
|
||||
def create_bull_researcher(llm, memory):
|
||||
def bull_node(state) -> dict:
|
||||
investment_debate_state = state["investment_debate_state"]
|
||||
custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt"))
|
||||
history = investment_debate_state.get("history", "")
|
||||
bull_history = investment_debate_state.get("bull_history", "")
|
||||
|
||||
|
|
@ -12,7 +14,7 @@ def create_bull_researcher(llm, memory):
|
|||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
curr_situation = f"{custom_prompt_context}\n\n{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
|
|
@ -36,6 +38,7 @@ Company fundamentals report: {fundamentals_report}
|
|||
Conversation history of the debate: {history}
|
||||
Last bear argument: {current_response}
|
||||
Reflections from similar situations and lessons learned: {past_memory_str}
|
||||
{custom_prompt_context}
|
||||
Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past.
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from tradingagents.agents.utils.agent_utils import build_custom_prompt_context
|
||||
|
||||
|
||||
def create_aggressive_debator(llm):
|
||||
def aggressive_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt"))
|
||||
history = risk_debate_state.get("history", "")
|
||||
aggressive_history = risk_debate_state.get("aggressive_history", "")
|
||||
|
||||
|
|
@ -26,6 +28,7 @@ Market Research Report: {market_research_report}
|
|||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
{custom_prompt_context}
|
||||
Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_conservative_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||
|
||||
Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from tradingagents.agents.utils.agent_utils import build_custom_prompt_context
|
||||
|
||||
|
||||
def create_conservative_debator(llm):
|
||||
def conservative_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt"))
|
||||
history = risk_debate_state.get("history", "")
|
||||
conservative_history = risk_debate_state.get("conservative_history", "")
|
||||
|
||||
|
|
@ -26,6 +28,7 @@ Market Research Report: {market_research_report}
|
|||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
{custom_prompt_context}
|
||||
Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||
|
||||
Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from tradingagents.agents.utils.agent_utils import build_custom_prompt_context
|
||||
|
||||
|
||||
def create_neutral_debator(llm):
|
||||
def neutral_node(state) -> dict:
|
||||
risk_debate_state = state["risk_debate_state"]
|
||||
custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt"))
|
||||
history = risk_debate_state.get("history", "")
|
||||
neutral_history = risk_debate_state.get("neutral_history", "")
|
||||
|
||||
|
|
@ -26,6 +28,7 @@ Market Research Report: {market_research_report}
|
|||
Social Media Sentiment Report: {sentiment_report}
|
||||
Latest World Affairs Report: {news_report}
|
||||
Company Fundamentals Report: {fundamentals_report}
|
||||
{custom_prompt_context}
|
||||
Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the conservative analyst: {current_conservative_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data.
|
||||
|
||||
Engage actively by analyzing both sides critically, addressing weaknesses in the aggressive and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting."""
|
||||
|
|
|
|||
|
|
@ -1,19 +1,23 @@
|
|||
import functools
|
||||
|
||||
from tradingagents.agents.utils.agent_utils import build_instrument_context
|
||||
from tradingagents.agents.utils.agent_utils import (
|
||||
build_custom_prompt_context,
|
||||
build_instrument_context,
|
||||
)
|
||||
|
||||
|
||||
def create_trader(llm, memory):
|
||||
def trader_node(state, name):
|
||||
company_name = state["company_of_interest"]
|
||||
instrument_context = build_instrument_context(company_name)
|
||||
custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt"))
|
||||
investment_plan = state["investment_plan"]
|
||||
market_research_report = state["market_report"]
|
||||
sentiment_report = state["sentiment_report"]
|
||||
news_report = state["news_report"]
|
||||
fundamentals_report = state["fundamentals_report"]
|
||||
|
||||
curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
curr_situation = f"{custom_prompt_context}\n\n{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}"
|
||||
past_memories = memory.get_memories(curr_situation, n_matches=2)
|
||||
|
||||
past_memory_str = ""
|
||||
|
|
@ -25,7 +29,7 @@ def create_trader(llm, memory):
|
|||
|
||||
context = {
|
||||
"role": "user",
|
||||
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. {instrument_context} This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
|
||||
"content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. {instrument_context}\n\n{custom_prompt_context}\n\nThis plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.",
|
||||
}
|
||||
|
||||
messages = [
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ class RiskDebateState(TypedDict):
|
|||
class AgentState(MessagesState):
|
||||
company_of_interest: Annotated[str, "Company that we are interested in trading"]
|
||||
trade_date: Annotated[str, "What date we are trading at"]
|
||||
custom_prompt: Annotated[str, "Optional run-specific instructions from the user"]
|
||||
|
||||
sender: Annotated[str, "Agent that sent this message"]
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ from tradingagents.agents.utils.news_data_tools import (
|
|||
get_insider_transactions,
|
||||
get_global_news
|
||||
)
|
||||
from tradingagents.custom_prompt import (
|
||||
CUSTOM_PROMPT_CONTEXT_FOOTER,
|
||||
CUSTOM_PROMPT_CONTEXT_HEADER,
|
||||
)
|
||||
|
||||
|
||||
def get_language_instruction() -> str:
|
||||
|
|
@ -34,6 +38,14 @@ def get_language_instruction() -> str:
|
|||
return f" Write your entire response in {lang}."
|
||||
|
||||
|
||||
def build_custom_prompt_context(custom_prompt: str | None) -> str:
|
||||
"""Format optional run-specific instructions supplied by the user."""
|
||||
prompt = (custom_prompt or "").strip()
|
||||
if not prompt:
|
||||
return ""
|
||||
return f"{CUSTOM_PROMPT_CONTEXT_HEADER}\n{prompt}\n{CUSTOM_PROMPT_CONTEXT_FOOTER}"
|
||||
|
||||
|
||||
def build_instrument_context(ticker: str) -> str:
|
||||
"""Describe the exact instrument so agents preserve exchange-qualified tickers."""
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
"""Shared constants for optional run-specific custom prompt handling."""
|
||||
|
||||
CUSTOM_PROMPT_SECTION_TITLE = "Custom Prompt"
|
||||
CUSTOM_PROMPT_STATUS_LABEL = "Custom prompt"
|
||||
CUSTOM_PROMPT_INPUT_PROMPT = (
|
||||
"Optional custom prompt for this run (examples: short-term swing trade, "
|
||||
"long-term holding, only new positions, focus on earnings and capex):"
|
||||
)
|
||||
CUSTOM_PROMPT_CONTEXT_HEADER = "Additional user instructions for this run:"
|
||||
CUSTOM_PROMPT_CONTEXT_FOOTER = (
|
||||
"Treat these as explicit strategy constraints and address them directly in "
|
||||
"your analysis, trade plan, and final recommendation."
|
||||
)
|
||||
|
|
@ -16,13 +16,14 @@ class Propagator:
|
|||
self.max_recur_limit = max_recur_limit
|
||||
|
||||
def create_initial_state(
|
||||
self, company_name: str, trade_date: str
|
||||
self, company_name: str, trade_date: str, custom_prompt: str | None = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create the initial state for the agent graph."""
|
||||
return {
|
||||
"messages": [("human", company_name)],
|
||||
"company_of_interest": company_name,
|
||||
"trade_date": str(trade_date),
|
||||
"custom_prompt": (custom_prompt or "").strip(),
|
||||
"investment_debate_state": InvestDebateState(
|
||||
{
|
||||
"bull_history": "",
|
||||
|
|
|
|||
|
|
@ -191,14 +191,14 @@ class TradingAgentsGraph:
|
|||
),
|
||||
}
|
||||
|
||||
def propagate(self, company_name, trade_date):
|
||||
def propagate(self, company_name, trade_date, custom_prompt: str | None = None):
|
||||
"""Run the trading agents graph for a company on a specific date."""
|
||||
|
||||
self.ticker = company_name
|
||||
|
||||
# Initialize state
|
||||
init_agent_state = self.propagator.create_initial_state(
|
||||
company_name, trade_date
|
||||
company_name, trade_date, custom_prompt=custom_prompt
|
||||
)
|
||||
args = self.propagator.get_graph_args()
|
||||
|
||||
|
|
@ -231,6 +231,7 @@ class TradingAgentsGraph:
|
|||
self.log_states_dict[str(trade_date)] = {
|
||||
"company_of_interest": final_state["company_of_interest"],
|
||||
"trade_date": final_state["trade_date"],
|
||||
"custom_prompt": final_state.get("custom_prompt", ""),
|
||||
"market_report": final_state["market_report"],
|
||||
"sentiment_report": final_state["sentiment_report"],
|
||||
"news_report": final_state["news_report"],
|
||||
|
|
|
|||
Loading…
Reference in New Issue