From 39b8b6db84a3261e46542c65ac51759ebe5e7e1b Mon Sep 17 00:00:00 2001 From: ElMoorish Date: Sun, 5 Apr 2026 04:14:05 +0100 Subject: [PATCH] feat: Add Trade Strategist node + SQLite checkpointing + .gitignore restore - Add Trade Strategist agent (trade_strategist_node.py): generates 5 actionable trade setups (Entry, SL, TP, R:R, Win%) after Portfolio Manager - Wire Trade Strategist into LangGraph pipeline (setup.py) - Add 'trade_possibilities' field to AgentState schema - Initialize trade_possibilities in Propagator initial state - Integrate SqliteSaver (langgraph-checkpoint-sqlite) for node-by-node state persistence to trading_agents_state.sqlite - Add unique thread_id (ticker_date) to graph config for checkpoint isolation - Update CLI (main.py) to display Trade Strategist progress + save report section - Restore full standard Python .gitignore (200+ rules) stripped in prior PR - Fix: convert trade_strategist_node to factory function (create_trade_strategist) to resolve LangGraph TypeError on missing 'llm' argument - Fix: use sqlite3.connect() directly instead of SqliteSaver.from_conn_string() to avoid _GeneratorContextManager TypeError --- .gitignore | 192 +++++++++++++++++- cli/main.py | 19 +- tradingagents/agents/trade_strategist_node.py | 55 +++++ tradingagents/agents/utils/agent_states.py | 3 + tradingagents/graph/propagation.py | 13 +- tradingagents/graph/setup.py | 15 +- tradingagents/graph/trading_graph.py | 3 +- 7 files changed, 287 insertions(+), 13 deletions(-) create mode 100644 tradingagents/agents/trade_strategist_node.py diff --git a/.gitignore b/.gitignore index 7d6db80c..311eda45 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,195 @@ -venv/ +# ============================ +# TradingAgents Custom Ignores +# ============================ + +# API Keys / Secrets .env -__pycache__/ -*.egg-info/ +.envrc + +# Virtual environments +venv/ +.venv +env/ +ENV/ +env.bak/ +venv.bak/ + +# Reports and results (generated output) reports/ results/ +eval_results/ + +# SQLite state database +*.sqlite + +# ============================ +# Standard Python Ignores +# ============================ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python build/ +develop-eggs/ dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# Pyenv +.python-version + +# pipenv +Pipfile.lock + +# poetry +poetry.lock + +# pdm +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582 +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# Ruff +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# ============================ +# IDE / Editor Ignores +# ============================ + +# PyCharm +.idea/ + +# Visual Studio Code +.vscode/ + +# ============================ +# OS Ignores +# ============================ + +# macOS .DS_Store +.AppleDouble +.LSOverride + +# Windows +Thumbs.db +ehthumbs.db +Desktop.ini +$RECYCLE.BIN/ + +# ============================ +# Cache / Data +# ============================ + +**/data_cache/ +.streamlit/secrets.toml diff --git a/cli/main.py b/cli/main.py index 5c59d23d..db87409d 100644 --- a/cli/main.py +++ b/cli/main.py @@ -46,7 +46,7 @@ class MessageBuffer: "Research Team": ["Bull Researcher", "Bear Researcher", "Research Manager"], "Trading Team": ["Trader"], "Risk Management": ["Aggressive Analyst", "Neutral Analyst", "Conservative Analyst"], - "Portfolio Management": ["Portfolio Manager"], + "Portfolio Management": ["Portfolio Manager", "Trade Strategist"], } # Analyst name mapping @@ -68,6 +68,7 @@ class MessageBuffer: "investment_plan": (None, "Research Manager"), "trader_investment_plan": (None, "Trader"), "final_trade_decision": (None, "Portfolio Manager"), + "trade_possibilities": (None, "Trade Strategist"), } def __init__(self, max_length=100): @@ -1046,8 +1047,11 @@ def run_analysis(): selections["ticker"], selections["analysis_date"] ) # Pass callbacks to graph config for tool execution tracking - # (LLM tracking is handled separately via LLM constructor) - args = graph.propagator.get_graph_args(callbacks=[stats_handler]) + args = graph.propagator.get_graph_args( + selections["ticker"], + selections["analysis_date"], + callbacks=[stats_handler] + ) # Stream the analysis trace = [] @@ -1148,6 +1152,15 @@ def run_analysis(): message_buffer.update_agent_status("Conservative Analyst", "completed") message_buffer.update_agent_status("Neutral Analyst", "completed") message_buffer.update_agent_status("Portfolio Manager", "completed") + message_buffer.update_agent_status("Trade Strategist", "in_progress") + + # Trade Strategist + if chunk.get("trade_possibilities"): + message_buffer.update_report_section( + "trade_possibilities", chunk["trade_possibilities"] + ) + if message_buffer.agent_status.get("Trade Strategist") != "completed": + message_buffer.update_agent_status("Trade Strategist", "completed") # Update the display update_display(layout, stats_handler=stats_handler, start_time=start_time) diff --git a/tradingagents/agents/trade_strategist_node.py b/tradingagents/agents/trade_strategist_node.py new file mode 100644 index 00000000..e936a982 --- /dev/null +++ b/tradingagents/agents/trade_strategist_node.py @@ -0,0 +1,55 @@ +from langchain_core.prompts import ChatPromptTemplate +from tradingagents.agents.utils.agent_states import AgentState + +def create_trade_strategist(llm): + def trade_strategist_node(state: AgentState): + """ + Agent that analyzes the final trade decision and outputs 5 distinct trade setups. + """ + + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + """You are an elite Trade Strategist at a premier quantitative hedge fund. +Your job is to take the final consensus decision from the Portfolio Manager and the Trader's investment plan, and synthesize them into exactly 5 specific, actionable trade setups. + +For the given asset, you must provide exactly 5 trade possibilities with the following parameters explicitly defined for each: +- Trade Direction (Long/Short, Options, etc.) +- Entry Price / Condition (e.g., Buy at market, Limit buy at $X, Wait for breakout above $X) +- Stop Loss (SL) (Specific price level) +- Take Profit (TP) (Specific price level) +- Risk/Reward Ratio +- Estimated Win Percentage (Probability of success based on current technicals/fundamentals, e.g., 65%) +- Brief Rationale (1-2 sentences explaining why this setup makes sense) + +Format your output as a clean, highly readable Markdown document. +Do not output anything besides the 5 trades and a brief introductory/concluding sentence. +Use bullet points and bold text for the parameters so they are easily scannable.""" + ), + ( + "human", + """Asset: {company} + +Portfolio Manager's Final Decision: +{final_decision} + +Trader's Investment Plan: +{trader_plan} + +Please formulate the 5 Trade Possibilities based on the above data.""" + ), + ] + ) + + chain = prompt | llm + + result = chain.invoke({ + "company": state.get("company_of_interest", ""), + "final_decision": state.get("final_trade_decision", ""), + "trader_plan": state.get("trader_investment_plan", "") + }) + + return {"trade_possibilities": result.content} + + return trade_strategist_node diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index 813b00ee..7644ed65 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -74,3 +74,6 @@ class AgentState(MessagesState): RiskDebateState, "Current state of the debate on evaluating risk" ] final_trade_decision: Annotated[str, "Final decision made by the Risk Analysts"] + + # final strategies + trade_possibilities: Annotated[str, "5 Trade Possibilities with percentages, SL/TP"] diff --git a/tradingagents/graph/propagation.py b/tradingagents/graph/propagation.py index 0fd10c0c..dafbe5f0 100644 --- a/tradingagents/graph/propagation.py +++ b/tradingagents/graph/propagation.py @@ -11,7 +11,7 @@ from tradingagents.agents.utils.agent_states import ( class Propagator: """Handles state initialization and propagation through the graph.""" - def __init__(self, max_recur_limit=100): + def __init__(self, max_recur_limit=1000): """Initialize with configuration parameters.""" self.max_recur_limit = max_recur_limit @@ -51,16 +51,23 @@ class Propagator: "fundamentals_report": "", "sentiment_report": "", "news_report": "", + "trade_possibilities": "", } - def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]: + def get_graph_args(self, company_name: str, trade_date: str, callbacks: Optional[List] = None) -> Dict[str, Any]: """Get arguments for the graph invocation. Args: + company_name: The ticker being analyzed + trade_date: The date of analysis callbacks: Optional list of callback handlers for tool execution tracking. Note: LLM callbacks are handled separately via LLM constructor. """ - config = {"recursion_limit": self.max_recur_limit} + thread_id = f"{company_name}_{trade_date}" + config = { + "recursion_limit": self.max_recur_limit, + "configurable": {"thread_id": thread_id} + } if callbacks: config["callbacks"] = callbacks return { diff --git a/tradingagents/graph/setup.py b/tradingagents/graph/setup.py index e0771c65..57a60ca1 100644 --- a/tradingagents/graph/setup.py +++ b/tradingagents/graph/setup.py @@ -2,11 +2,15 @@ from typing import Dict, Any from langchain_openai import ChatOpenAI + from langgraph.graph import END, StateGraph, START from langgraph.prebuilt import ToolNode +import sqlite3 +from langgraph.checkpoint.sqlite import SqliteSaver from tradingagents.agents import * from tradingagents.agents.utils.agent_states import AgentState +from tradingagents.agents.trade_strategist_node import create_trade_strategist from .conditional_logic import ConditionalLogic @@ -104,6 +108,7 @@ class GraphSetup: portfolio_manager_node = create_portfolio_manager( self.deep_thinking_llm, self.portfolio_manager_memory ) + trade_strategist_node = create_trade_strategist(self.quick_thinking_llm) # Create workflow workflow = StateGraph(AgentState) @@ -125,6 +130,7 @@ class GraphSetup: workflow.add_node("Neutral Analyst", neutral_analyst) workflow.add_node("Conservative Analyst", conservative_analyst) workflow.add_node("Portfolio Manager", portfolio_manager_node) + workflow.add_node("Trade Strategist", trade_strategist_node) # Define edges # Start with the first analyst @@ -196,7 +202,10 @@ class GraphSetup: }, ) - workflow.add_edge("Portfolio Manager", END) + workflow.add_edge("Portfolio Manager", "Trade Strategist") + workflow.add_edge("Trade Strategist", END) - # Compile and return - return workflow.compile() + # Compile and return with SQLite memory + conn = sqlite3.connect("trading_agents_state.sqlite", check_same_thread=False) + memory = SqliteSaver(conn) + return workflow.compile(checkpointer=memory) diff --git a/tradingagents/graph/trading_graph.py b/tradingagents/graph/trading_graph.py index c8cd7492..1103783b 100644 --- a/tradingagents/graph/trading_graph.py +++ b/tradingagents/graph/trading_graph.py @@ -200,7 +200,7 @@ class TradingAgentsGraph: init_agent_state = self.propagator.create_initial_state( company_name, trade_date ) - args = self.propagator.get_graph_args() + args = self.propagator.get_graph_args(company_name, trade_date) if self.debug: # Debug mode with tracing @@ -256,6 +256,7 @@ class TradingAgentsGraph: }, "investment_plan": final_state["investment_plan"], "final_trade_decision": final_state["final_trade_decision"], + "trade_possibilities": final_state.get("trade_possibilities", ""), } # Save to file