diff --git a/README.md b/README.md index 97cbde48..52b4e3e8 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,7 @@ Launch the interactive CLI: tradingagents # installed command python -m cli.main # alternative: run directly from source ``` -You will see a screen where you can select your desired tickers, analysis date, LLM provider, research depth, and more. +You will see a screen where you can select your desired tickers, analysis date, an optional custom prompt for run-specific instructions, LLM provider, research depth, and more.

@@ -196,7 +196,11 @@ from tradingagents.default_config import DEFAULT_CONFIG ta = TradingAgentsGraph(debug=True, config=DEFAULT_CONFIG.copy()) # forward propagate -_, decision = ta.propagate("NVDA", "2026-01-15") +_, decision = ta.propagate( + "NVDA", + "2026-01-15", + custom_prompt="Long-term horizon. Focus on earnings quality, capex discipline, and new entries only.", +) print(decision) ``` @@ -213,7 +217,11 @@ config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks config["max_debate_rounds"] = 2 ta = TradingAgentsGraph(debug=True, config=config) -_, decision = ta.propagate("NVDA", "2026-01-15") +_, decision = ta.propagate( + "NVDA", + "2026-01-15", + custom_prompt="Short-term trading setup. Prioritize momentum, catalyst timing, and downside risk.", +) print(decision) ``` diff --git a/cli/main.py b/cli/main.py index 33d110fb..d8347447 100644 --- a/cli/main.py +++ b/cli/main.py @@ -25,6 +25,10 @@ from rich.align import Align from rich.rule import Rule from tradingagents.graph.trading_graph import TradingAgentsGraph +from tradingagents.custom_prompt import ( + CUSTOM_PROMPT_SECTION_TITLE, + CUSTOM_PROMPT_STATUS_LABEL, +) from tradingagents.default_config import DEFAULT_CONFIG from cli.models import AnalystType from cli.utils import * @@ -520,19 +524,28 @@ def get_user_selections(): ) analysis_date = get_analysis_date() - # Step 3: Output language + # Step 3: Optional custom prompt console.print( create_question_box( - "Step 3: Output Language", + "Step 3: Custom Prompt", + "Optionally add run-specific instructions such as short-term/long-term horizon, only new positions, or risks to emphasize" + ) + ) + custom_prompt = ask_custom_prompt() + + # Step 4: Output language + console.print( + create_question_box( + "Step 4: Output Language", "Select the language for analyst reports and final decision" ) ) output_language = ask_output_language() - # Step 4: Select analysts + # Step 5: Select analysts console.print( create_question_box( - "Step 4: Analysts Team", "Select your LLM analyst agents for the analysis" + "Step 5: Analysts Team", "Select your LLM analyst agents for the analysis" ) ) selected_analysts = select_analysts() @@ -540,32 +553,32 @@ def get_user_selections(): f"[green]Selected analysts:[/green] {', '.join(analyst.value for analyst in selected_analysts)}" ) - # Step 5: Research depth + # Step 6: Research depth console.print( create_question_box( - "Step 5: Research Depth", "Select your research depth level" + "Step 6: Research Depth", "Select your research depth level" ) ) selected_research_depth = select_research_depth() - # Step 6: LLM Provider + # Step 7: LLM Provider console.print( create_question_box( - "Step 6: LLM Provider", "Select your LLM provider" + "Step 7: LLM Provider", "Select your LLM provider" ) ) selected_llm_provider, backend_url = select_llm_provider() - # Step 7: Thinking agents + # Step 8: Thinking agents console.print( create_question_box( - "Step 7: Thinking Agents", "Select your thinking agents for analysis" + "Step 8: Thinking Agents", "Select your thinking agents for analysis" ) ) selected_shallow_thinker = select_shallow_thinking_agent(selected_llm_provider) selected_deep_thinker = select_deep_thinking_agent(selected_llm_provider) - # Step 8: Provider-specific thinking configuration + # Step 9: Provider-specific thinking configuration thinking_level = None reasoning_effort = None anthropic_effort = None @@ -574,7 +587,7 @@ def get_user_selections(): if provider_lower == "google": console.print( create_question_box( - "Step 8: Thinking Mode", + "Step 9: Thinking Mode", "Configure Gemini thinking mode" ) ) @@ -582,7 +595,7 @@ def get_user_selections(): elif provider_lower == "openai": console.print( create_question_box( - "Step 8: Reasoning Effort", + "Step 9: Reasoning Effort", "Configure OpenAI reasoning effort level" ) ) @@ -590,7 +603,7 @@ def get_user_selections(): elif provider_lower == "anthropic": console.print( create_question_box( - "Step 8: Effort Level", + "Step 9: Effort Level", "Configure Claude effort level" ) ) @@ -599,6 +612,7 @@ def get_user_selections(): return { "ticker": selected_ticker, "analysis_date": analysis_date, + "custom_prompt": custom_prompt, "analysts": selected_analysts, "research_depth": selected_research_depth, "llm_provider": selected_llm_provider.lower(), @@ -722,6 +736,8 @@ def save_report_to_disk(final_state, ticker: str, save_path: Path): # Write consolidated report header = f"# Trading Analysis Report: {ticker}\n\nGenerated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + if final_state.get("custom_prompt"): + header += f"## {CUSTOM_PROMPT_SECTION_TITLE}\n\n{final_state['custom_prompt']}\n\n" (save_path / "complete_report.md").write_text(header + "\n\n".join(sections)) return save_path / "complete_report.md" @@ -731,6 +747,16 @@ def display_complete_report(final_state): console.print() console.print(Rule("Complete Analysis Report", style="bold green")) + if final_state.get("custom_prompt"): + console.print( + Panel( + Markdown(final_state["custom_prompt"]), + title=CUSTOM_PROMPT_SECTION_TITLE, + border_style="magenta", + padding=(1, 2), + ) + ) + # I. Analyst Team Reports analysts = [] if final_state.get("market_report"): @@ -1025,6 +1051,10 @@ def run_analysis(): message_buffer.add_message( "System", f"Analysis date: {selections['analysis_date']}" ) + if selections.get("custom_prompt"): + message_buffer.add_message( + "System", f"{CUSTOM_PROMPT_STATUS_LABEL}: {selections['custom_prompt']}" + ) message_buffer.add_message( "System", f"Selected analysts: {', '.join(analyst.value for analyst in selections['analysts'])}", @@ -1044,7 +1074,9 @@ def run_analysis(): # Initialize state and get graph args with callbacks init_agent_state = graph.propagator.create_initial_state( - selections["ticker"], selections["analysis_date"] + selections["ticker"], + selections["analysis_date"], + selections.get("custom_prompt"), ) # Pass callbacks to graph config for tool execution tracking # (LLM tracking is handled separately via LLM constructor) diff --git a/cli/utils.py b/cli/utils.py index 85c282ed..057ca17d 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Dict from rich.console import Console from cli.models import AnalystType +from tradingagents.custom_prompt import CUSTOM_PROMPT_INPUT_PROMPT from tradingagents.llm_clients.model_catalog import get_model_options console = Console() @@ -76,6 +77,25 @@ def get_analysis_date() -> str: return date.strip() +def ask_custom_prompt() -> str: + """Prompt for optional run-specific instructions.""" + custom_prompt = questionary.text( + CUSTOM_PROMPT_INPUT_PROMPT, + default="", + style=questionary.Style( + [ + ("text", "fg:green"), + ("highlighted", "noinherit"), + ] + ), + ).ask() + + if custom_prompt is None: + return "" + + return custom_prompt.strip() + + def select_analysts() -> List[AnalystType]: """Select analysts using an interactive checkbox.""" choices = questionary.checkbox( diff --git a/main.py b/main.py index c94fde32..bc9e3e7f 100644 --- a/main.py +++ b/main.py @@ -24,7 +24,11 @@ config["data_vendors"] = { ta = TradingAgentsGraph(debug=True, config=config) # forward propagate -_, decision = ta.propagate("NVDA", "2024-05-10") +_, decision = ta.propagate( + "NVDA", + "2024-05-10", + custom_prompt="Long-term horizon. Focus on durable earnings power and capital allocation.", +) print(decision) # Memorize mistakes and reflect diff --git a/tests/test_custom_prompt.py b/tests/test_custom_prompt.py new file mode 100644 index 00000000..f71168c9 --- /dev/null +++ b/tests/test_custom_prompt.py @@ -0,0 +1,57 @@ +import tempfile +import unittest +from pathlib import Path + +from cli.main import save_report_to_disk +from tradingagents.agents.utils.agent_utils import build_custom_prompt_context +from tradingagents.graph.propagation import Propagator + + +class CustomPromptTests(unittest.TestCase): + def test_build_custom_prompt_context_is_empty_when_missing(self): + self.assertEqual(build_custom_prompt_context(" "), "") + + def test_build_custom_prompt_context_formats_user_guidance(self): + context = build_custom_prompt_context( + "Long-term horizon; focus on earnings quality and capex discipline." + ) + self.assertIn("Additional user instructions", context) + self.assertIn("Long-term horizon", context) + self.assertIn("explicit strategy constraints", context) + + def test_create_initial_state_stores_custom_prompt(self): + state = Propagator().create_initial_state( + "META", + "2026-04-05", + custom_prompt="Short-term swing trade; new positions only.", + ) + self.assertEqual( + state["custom_prompt"], + "Short-term swing trade; new positions only.", + ) + + def test_save_report_to_disk_includes_custom_prompt_header(self): + final_state = { + "custom_prompt": "Long-term horizon; focus on capital allocation.", + "market_report": "", + "sentiment_report": "", + "news_report": "", + "fundamentals_report": "", + "investment_debate_state": {}, + "risk_debate_state": {}, + } + + with tempfile.TemporaryDirectory() as tmpdir: + report_path = save_report_to_disk( + final_state, + "META", + Path(tmpdir), + ) + report_text = Path(report_path).read_text() + + self.assertIn("## Custom Prompt", report_text) + self.assertIn("Long-term horizon", report_text) + + +if __name__ == "__main__": + unittest.main() diff --git a/tradingagents/agents/analysts/fundamentals_analyst.py b/tradingagents/agents/analysts/fundamentals_analyst.py index 6aa49cf3..97b49b55 100644 --- a/tradingagents/agents/analysts/fundamentals_analyst.py +++ b/tradingagents/agents/analysts/fundamentals_analyst.py @@ -1,5 +1,6 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from tradingagents.agents.utils.agent_utils import ( + build_custom_prompt_context, build_instrument_context, get_balance_sheet, get_cashflow, @@ -15,6 +16,7 @@ def create_fundamentals_analyst(llm): def fundamentals_analyst_node(state): current_date = state["trade_date"] instrument_context = build_instrument_context(state["company_of_interest"]) + custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt")) tools = [ get_fundamentals, @@ -41,7 +43,7 @@ def create_fundamentals_analyst(llm): " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." " You have access to the following tools: {tool_names}.\n{system_message}" - "For your reference, the current date is {current_date}. {instrument_context}", + "For your reference, the current date is {current_date}. {instrument_context}\n{custom_prompt_context}", ), MessagesPlaceholder(variable_name="messages"), ] @@ -51,6 +53,7 @@ def create_fundamentals_analyst(llm): prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(instrument_context=instrument_context) + prompt = prompt.partial(custom_prompt_context=custom_prompt_context) chain = prompt | llm.bind_tools(tools) diff --git a/tradingagents/agents/analysts/market_analyst.py b/tradingagents/agents/analysts/market_analyst.py index fef8f751..6d3f7636 100644 --- a/tradingagents/agents/analysts/market_analyst.py +++ b/tradingagents/agents/analysts/market_analyst.py @@ -1,5 +1,6 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from tradingagents.agents.utils.agent_utils import ( + build_custom_prompt_context, build_instrument_context, get_indicators, get_language_instruction, @@ -13,6 +14,7 @@ def create_market_analyst(llm): def market_analyst_node(state): current_date = state["trade_date"] instrument_context = build_instrument_context(state["company_of_interest"]) + custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt")) tools = [ get_stock_data, @@ -60,7 +62,7 @@ Volume-Based Indicators: " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." " You have access to the following tools: {tool_names}.\n{system_message}" - "For your reference, the current date is {current_date}. {instrument_context}", + "For your reference, the current date is {current_date}. {instrument_context}\n{custom_prompt_context}", ), MessagesPlaceholder(variable_name="messages"), ] @@ -70,6 +72,7 @@ Volume-Based Indicators: prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(instrument_context=instrument_context) + prompt = prompt.partial(custom_prompt_context=custom_prompt_context) chain = prompt | llm.bind_tools(tools) diff --git a/tradingagents/agents/analysts/news_analyst.py b/tradingagents/agents/analysts/news_analyst.py index e0fe93c5..bb94f5a7 100644 --- a/tradingagents/agents/analysts/news_analyst.py +++ b/tradingagents/agents/analysts/news_analyst.py @@ -1,5 +1,6 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from tradingagents.agents.utils.agent_utils import ( + build_custom_prompt_context, build_instrument_context, get_global_news, get_language_instruction, @@ -12,6 +13,7 @@ def create_news_analyst(llm): def news_analyst_node(state): current_date = state["trade_date"] instrument_context = build_instrument_context(state["company_of_interest"]) + custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt")) tools = [ get_news, @@ -35,7 +37,7 @@ def create_news_analyst(llm): " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." " You have access to the following tools: {tool_names}.\n{system_message}" - "For your reference, the current date is {current_date}. {instrument_context}", + "For your reference, the current date is {current_date}. {instrument_context}\n{custom_prompt_context}", ), MessagesPlaceholder(variable_name="messages"), ] @@ -45,6 +47,7 @@ def create_news_analyst(llm): prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(instrument_context=instrument_context) + prompt = prompt.partial(custom_prompt_context=custom_prompt_context) chain = prompt | llm.bind_tools(tools) result = chain.invoke(state["messages"]) diff --git a/tradingagents/agents/analysts/social_media_analyst.py b/tradingagents/agents/analysts/social_media_analyst.py index 34a53c46..5347a2f6 100644 --- a/tradingagents/agents/analysts/social_media_analyst.py +++ b/tradingagents/agents/analysts/social_media_analyst.py @@ -1,5 +1,10 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction, get_news +from tradingagents.agents.utils.agent_utils import ( + build_custom_prompt_context, + build_instrument_context, + get_language_instruction, + get_news, +) from tradingagents.dataflows.config import get_config @@ -7,6 +12,7 @@ def create_social_media_analyst(llm): def social_media_analyst_node(state): current_date = state["trade_date"] instrument_context = build_instrument_context(state["company_of_interest"]) + custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt")) tools = [ get_news, @@ -29,7 +35,7 @@ def create_social_media_analyst(llm): " If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable," " prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop." " You have access to the following tools: {tool_names}.\n{system_message}" - "For your reference, the current date is {current_date}. {instrument_context}", + "For your reference, the current date is {current_date}. {instrument_context}\n{custom_prompt_context}", ), MessagesPlaceholder(variable_name="messages"), ] @@ -39,6 +45,7 @@ def create_social_media_analyst(llm): prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools])) prompt = prompt.partial(current_date=current_date) prompt = prompt.partial(instrument_context=instrument_context) + prompt = prompt.partial(custom_prompt_context=custom_prompt_context) chain = prompt | llm.bind_tools(tools) diff --git a/tradingagents/agents/managers/portfolio_manager.py b/tradingagents/agents/managers/portfolio_manager.py index 6c69ae9f..6a11f19e 100644 --- a/tradingagents/agents/managers/portfolio_manager.py +++ b/tradingagents/agents/managers/portfolio_manager.py @@ -1,10 +1,15 @@ -from tradingagents.agents.utils.agent_utils import build_instrument_context, get_language_instruction +from tradingagents.agents.utils.agent_utils import ( + build_custom_prompt_context, + build_instrument_context, + get_language_instruction, +) def create_portfolio_manager(llm, memory): def portfolio_manager_node(state) -> dict: instrument_context = build_instrument_context(state["company_of_interest"]) + custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt")) history = state["risk_debate_state"]["history"] risk_debate_state = state["risk_debate_state"] @@ -15,7 +20,7 @@ def create_portfolio_manager(llm, memory): research_plan = state["investment_plan"] trader_plan = state["trader_investment_plan"] - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" + curr_situation = f"{custom_prompt_context}\n\n{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" past_memories = memory.get_memories(curr_situation, n_matches=2) past_memory_str = "" @@ -25,6 +30,7 @@ def create_portfolio_manager(llm, memory): prompt = f"""As the Portfolio Manager, synthesize the risk analysts' debate and deliver the final trading decision. {instrument_context} +{custom_prompt_context} --- diff --git a/tradingagents/agents/managers/research_manager.py b/tradingagents/agents/managers/research_manager.py index 5b4b4fdc..167ff6c4 100644 --- a/tradingagents/agents/managers/research_manager.py +++ b/tradingagents/agents/managers/research_manager.py @@ -1,10 +1,14 @@ -from tradingagents.agents.utils.agent_utils import build_instrument_context +from tradingagents.agents.utils.agent_utils import ( + build_custom_prompt_context, + build_instrument_context, +) def create_research_manager(llm, memory): def research_manager_node(state) -> dict: instrument_context = build_instrument_context(state["company_of_interest"]) + custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt")) history = state["investment_debate_state"].get("history", "") market_research_report = state["market_report"] sentiment_report = state["sentiment_report"] @@ -13,7 +17,7 @@ def create_research_manager(llm, memory): investment_debate_state = state["investment_debate_state"] - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" + curr_situation = f"{custom_prompt_context}\n\n{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" past_memories = memory.get_memories(curr_situation, n_matches=2) past_memory_str = "" @@ -35,6 +39,7 @@ Here are your past reflections on mistakes: \"{past_memory_str}\" {instrument_context} +{custom_prompt_context} Here is the debate: Debate History: diff --git a/tradingagents/agents/researchers/bear_researcher.py b/tradingagents/agents/researchers/bear_researcher.py index a44212dc..d6d3db58 100644 --- a/tradingagents/agents/researchers/bear_researcher.py +++ b/tradingagents/agents/researchers/bear_researcher.py @@ -1,8 +1,10 @@ +from tradingagents.agents.utils.agent_utils import build_custom_prompt_context def create_bear_researcher(llm, memory): def bear_node(state) -> dict: investment_debate_state = state["investment_debate_state"] + custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt")) history = investment_debate_state.get("history", "") bear_history = investment_debate_state.get("bear_history", "") @@ -12,7 +14,7 @@ def create_bear_researcher(llm, memory): news_report = state["news_report"] fundamentals_report = state["fundamentals_report"] - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" + curr_situation = f"{custom_prompt_context}\n\n{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" past_memories = memory.get_memories(curr_situation, n_matches=2) past_memory_str = "" @@ -38,6 +40,7 @@ Company fundamentals report: {fundamentals_report} Conversation history of the debate: {history} Last bull argument: {current_response} Reflections from similar situations and lessons learned: {past_memory_str} +{custom_prompt_context} Use this information to deliver a compelling bear argument, refute the bull's claims, and engage in a dynamic debate that demonstrates the risks and weaknesses of investing in the stock. You must also address reflections and learn from lessons and mistakes you made in the past. """ diff --git a/tradingagents/agents/researchers/bull_researcher.py b/tradingagents/agents/researchers/bull_researcher.py index d23d4d76..dfafcd5d 100644 --- a/tradingagents/agents/researchers/bull_researcher.py +++ b/tradingagents/agents/researchers/bull_researcher.py @@ -1,8 +1,10 @@ +from tradingagents.agents.utils.agent_utils import build_custom_prompt_context def create_bull_researcher(llm, memory): def bull_node(state) -> dict: investment_debate_state = state["investment_debate_state"] + custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt")) history = investment_debate_state.get("history", "") bull_history = investment_debate_state.get("bull_history", "") @@ -12,7 +14,7 @@ def create_bull_researcher(llm, memory): news_report = state["news_report"] fundamentals_report = state["fundamentals_report"] - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" + curr_situation = f"{custom_prompt_context}\n\n{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" past_memories = memory.get_memories(curr_situation, n_matches=2) past_memory_str = "" @@ -36,6 +38,7 @@ Company fundamentals report: {fundamentals_report} Conversation history of the debate: {history} Last bear argument: {current_response} Reflections from similar situations and lessons learned: {past_memory_str} +{custom_prompt_context} Use this information to deliver a compelling bull argument, refute the bear's concerns, and engage in a dynamic debate that demonstrates the strengths of the bull position. You must also address reflections and learn from lessons and mistakes you made in the past. """ diff --git a/tradingagents/agents/risk_mgmt/aggressive_debator.py b/tradingagents/agents/risk_mgmt/aggressive_debator.py index 2dab1152..dca6aaab 100644 --- a/tradingagents/agents/risk_mgmt/aggressive_debator.py +++ b/tradingagents/agents/risk_mgmt/aggressive_debator.py @@ -1,8 +1,10 @@ +from tradingagents.agents.utils.agent_utils import build_custom_prompt_context def create_aggressive_debator(llm): def aggressive_node(state) -> dict: risk_debate_state = state["risk_debate_state"] + custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt")) history = risk_debate_state.get("history", "") aggressive_history = risk_debate_state.get("aggressive_history", "") @@ -26,6 +28,7 @@ Market Research Report: {market_research_report} Social Media Sentiment Report: {sentiment_report} Latest World Affairs Report: {news_report} Company Fundamentals Report: {fundamentals_report} +{custom_prompt_context} Here is the current conversation history: {history} Here are the last arguments from the conservative analyst: {current_conservative_response} Here are the last arguments from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data. Engage actively by addressing any specific concerns raised, refuting the weaknesses in their logic, and asserting the benefits of risk-taking to outpace market norms. Maintain a focus on debating and persuading, not just presenting data. Challenge each counterpoint to underscore why a high-risk approach is optimal. Output conversationally as if you are speaking without any special formatting.""" diff --git a/tradingagents/agents/risk_mgmt/conservative_debator.py b/tradingagents/agents/risk_mgmt/conservative_debator.py index 99a8315e..c1139a14 100644 --- a/tradingagents/agents/risk_mgmt/conservative_debator.py +++ b/tradingagents/agents/risk_mgmt/conservative_debator.py @@ -1,8 +1,10 @@ +from tradingagents.agents.utils.agent_utils import build_custom_prompt_context def create_conservative_debator(llm): def conservative_node(state) -> dict: risk_debate_state = state["risk_debate_state"] + custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt")) history = risk_debate_state.get("history", "") conservative_history = risk_debate_state.get("conservative_history", "") @@ -26,6 +28,7 @@ Market Research Report: {market_research_report} Social Media Sentiment Report: {sentiment_report} Latest World Affairs Report: {news_report} Company Fundamentals Report: {fundamentals_report} +{custom_prompt_context} Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the neutral analyst: {current_neutral_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data. Engage by questioning their optimism and emphasizing the potential downsides they may have overlooked. Address each of their counterpoints to showcase why a conservative stance is ultimately the safest path for the firm's assets. Focus on debating and critiquing their arguments to demonstrate the strength of a low-risk strategy over their approaches. Output conversationally as if you are speaking without any special formatting.""" diff --git a/tradingagents/agents/risk_mgmt/neutral_debator.py b/tradingagents/agents/risk_mgmt/neutral_debator.py index e99ff0af..c94a9b0b 100644 --- a/tradingagents/agents/risk_mgmt/neutral_debator.py +++ b/tradingagents/agents/risk_mgmt/neutral_debator.py @@ -1,8 +1,10 @@ +from tradingagents.agents.utils.agent_utils import build_custom_prompt_context def create_neutral_debator(llm): def neutral_node(state) -> dict: risk_debate_state = state["risk_debate_state"] + custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt")) history = risk_debate_state.get("history", "") neutral_history = risk_debate_state.get("neutral_history", "") @@ -26,6 +28,7 @@ Market Research Report: {market_research_report} Social Media Sentiment Report: {sentiment_report} Latest World Affairs Report: {news_report} Company Fundamentals Report: {fundamentals_report} +{custom_prompt_context} Here is the current conversation history: {history} Here is the last response from the aggressive analyst: {current_aggressive_response} Here is the last response from the conservative analyst: {current_conservative_response}. If there are no responses from the other viewpoints yet, present your own argument based on the available data. Engage actively by analyzing both sides critically, addressing weaknesses in the aggressive and conservative arguments to advocate for a more balanced approach. Challenge each of their points to illustrate why a moderate risk strategy might offer the best of both worlds, providing growth potential while safeguarding against extreme volatility. Focus on debating rather than simply presenting data, aiming to show that a balanced view can lead to the most reliable outcomes. Output conversationally as if you are speaking without any special formatting.""" diff --git a/tradingagents/agents/trader/trader.py b/tradingagents/agents/trader/trader.py index 07e9f262..91bfe011 100644 --- a/tradingagents/agents/trader/trader.py +++ b/tradingagents/agents/trader/trader.py @@ -1,19 +1,23 @@ import functools -from tradingagents.agents.utils.agent_utils import build_instrument_context +from tradingagents.agents.utils.agent_utils import ( + build_custom_prompt_context, + build_instrument_context, +) def create_trader(llm, memory): def trader_node(state, name): company_name = state["company_of_interest"] instrument_context = build_instrument_context(company_name) + custom_prompt_context = build_custom_prompt_context(state.get("custom_prompt")) investment_plan = state["investment_plan"] market_research_report = state["market_report"] sentiment_report = state["sentiment_report"] news_report = state["news_report"] fundamentals_report = state["fundamentals_report"] - curr_situation = f"{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" + curr_situation = f"{custom_prompt_context}\n\n{market_research_report}\n\n{sentiment_report}\n\n{news_report}\n\n{fundamentals_report}" past_memories = memory.get_memories(curr_situation, n_matches=2) past_memory_str = "" @@ -25,7 +29,7 @@ def create_trader(llm, memory): context = { "role": "user", - "content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. {instrument_context} This plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.", + "content": f"Based on a comprehensive analysis by a team of analysts, here is an investment plan tailored for {company_name}. {instrument_context}\n\n{custom_prompt_context}\n\nThis plan incorporates insights from current technical market trends, macroeconomic indicators, and social media sentiment. Use this plan as a foundation for evaluating your next trading decision.\n\nProposed Investment Plan: {investment_plan}\n\nLeverage these insights to make an informed and strategic decision.", } messages = [ diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index 6423b936..ed10b708 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -46,6 +46,7 @@ class RiskDebateState(TypedDict): class AgentState(MessagesState): company_of_interest: Annotated[str, "Company that we are interested in trading"] trade_date: Annotated[str, "What date we are trading at"] + custom_prompt: Annotated[str, "Optional run-specific instructions from the user"] sender: Annotated[str, "Agent that sent this message"] diff --git a/tradingagents/agents/utils/agent_utils.py b/tradingagents/agents/utils/agent_utils.py index 4ba40a80..87fd3c14 100644 --- a/tradingagents/agents/utils/agent_utils.py +++ b/tradingagents/agents/utils/agent_utils.py @@ -18,6 +18,10 @@ from tradingagents.agents.utils.news_data_tools import ( get_insider_transactions, get_global_news ) +from tradingagents.custom_prompt import ( + CUSTOM_PROMPT_CONTEXT_FOOTER, + CUSTOM_PROMPT_CONTEXT_HEADER, +) def get_language_instruction() -> str: @@ -34,6 +38,14 @@ def get_language_instruction() -> str: return f" Write your entire response in {lang}." +def build_custom_prompt_context(custom_prompt: str | None) -> str: + """Format optional run-specific instructions supplied by the user.""" + prompt = (custom_prompt or "").strip() + if not prompt: + return "" + return f"{CUSTOM_PROMPT_CONTEXT_HEADER}\n{prompt}\n{CUSTOM_PROMPT_CONTEXT_FOOTER}" + + def build_instrument_context(ticker: str) -> str: """Describe the exact instrument so agents preserve exchange-qualified tickers.""" return ( diff --git a/tradingagents/custom_prompt.py b/tradingagents/custom_prompt.py new file mode 100644 index 00000000..025d4751 --- /dev/null +++ b/tradingagents/custom_prompt.py @@ -0,0 +1,13 @@ +"""Shared constants for optional run-specific custom prompt handling.""" + +CUSTOM_PROMPT_SECTION_TITLE = "Custom Prompt" +CUSTOM_PROMPT_STATUS_LABEL = "Custom prompt" +CUSTOM_PROMPT_INPUT_PROMPT = ( + "Optional custom prompt for this run (examples: short-term swing trade, " + "long-term holding, only new positions, focus on earnings and capex):" +) +CUSTOM_PROMPT_CONTEXT_HEADER = "Additional user instructions for this run:" +CUSTOM_PROMPT_CONTEXT_FOOTER = ( + "Treat these as explicit strategy constraints and address them directly in " + "your analysis, trade plan, and final recommendation." +) diff --git a/tradingagents/graph/propagation.py b/tradingagents/graph/propagation.py index 0fd10c0c..312f69b4 100644 --- a/tradingagents/graph/propagation.py +++ b/tradingagents/graph/propagation.py @@ -16,13 +16,14 @@ class Propagator: self.max_recur_limit = max_recur_limit def create_initial_state( - self, company_name: str, trade_date: str + self, company_name: str, trade_date: str, custom_prompt: str | None = None ) -> Dict[str, Any]: """Create the initial state for the agent graph.""" return { "messages": [("human", company_name)], "company_of_interest": company_name, "trade_date": str(trade_date), + "custom_prompt": (custom_prompt or "").strip(), "investment_debate_state": InvestDebateState( { "bull_history": "", diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index 78bc13e5..5b0cd3b8 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -189,14 +189,14 @@ class TradingAgentsGraph: ), } - def propagate(self, company_name, trade_date): + def propagate(self, company_name, trade_date, custom_prompt: str | None = None): """Run the trading agents graph for a company on a specific date.""" self.ticker = company_name # Initialize state init_agent_state = self.propagator.create_initial_state( - company_name, trade_date + company_name, trade_date, custom_prompt=custom_prompt ) args = self.propagator.get_graph_args() @@ -229,6 +229,7 @@ class TradingAgentsGraph: self.log_states_dict[str(trade_date)] = { "company_of_interest": final_state["company_of_interest"], "trade_date": final_state["trade_date"], + "custom_prompt": final_state.get("custom_prompt", ""), "market_report": final_state["market_report"], "sentiment_report": final_state["sentiment_report"], "news_report": final_state["news_report"],