diff --git a/README.md b/README.md index 03a05480..410229f4 100644 --- a/README.md +++ b/README.md @@ -197,6 +197,25 @@ print(decision) You can view the full list of configurations in `tradingagents/default_config.py`. +## Persistent Memory and Learning + +To allow the agents to learn from the success or failure of previous decisions, TradingAgents includes a persistent memory mechanism. + +Each agent's reflections and the "lessons learned" from past trading sessions are stored on disk. This allows the system to build a rich, searchable history of its actions and their consequences, enabling more informed decisions in the future. + +- **Storage**: The memory is managed by the `FinancialSituationMemory` class in `tradingagents/agents/utils/memory.py` and is persisted to the `./memory_store/` directory using a local ChromaDB database. +- **Learning Loop**: After a trade, a `Reflector` agent analyzes the outcome (profit or loss) and generates a "lesson." This lesson is stored in the memory, linked to the market conditions at the time. Before the next trade, agents query this memory for similar past situations to retrieve relevant lessons, which are then used to inform their decision-making process. + +### Inspecting the Memory + +You can inspect the contents of the persistent memory to see what the agents have learned. To do this, run the memory utility script from the root of the project: + +```bash +python -m tradingagents.agents.utils.memory +``` + +The first time you run this, it will populate the memory with example data. Subsequent runs will load and display the data from the `memory_store` directory, demonstrating that the memory persists across sessions. + ## Contributing We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/). diff --git a/cli/main.py b/cli/main.py index c148c522..03c1858b 100644 --- a/cli/main.py +++ b/cli/main.py @@ -73,6 +73,10 @@ class MessageBuffer: "final_trade_decision": None, } + def _format_report_content(self, content): + """Ensures content is a string.""" + return str(content) + def add_message(self, message_type, content): timestamp = datetime.datetime.now().strftime("%H:%M:%S") self.messages.append((timestamp, message_type, content)) @@ -100,7 +104,7 @@ class MessageBuffer: for section, content in self.report_sections.items(): if content is not None: latest_section = section - latest_content = content + latest_content = self._format_report_content(content) if latest_section and latest_content: # Format the current section for display @@ -136,35 +140,35 @@ class MessageBuffer: report_parts.append("## Analyst Team Reports") if self.report_sections["market_report"]: report_parts.append( - f"### Market Analysis\n{self.report_sections['market_report']}" + f"### Market Analysis\n{self._format_report_content(self.report_sections['market_report'])}" ) if self.report_sections["sentiment_report"]: report_parts.append( - f"### Social Sentiment\n{self.report_sections['sentiment_report']}" + f"### Social Sentiment\n{self._format_report_content(self.report_sections['sentiment_report'])}" ) if self.report_sections["news_report"]: report_parts.append( - f"### News Analysis\n{self.report_sections['news_report']}" + f"### News Analysis\n{self._format_report_content(self.report_sections['news_report'])}" ) if self.report_sections["fundamentals_report"]: report_parts.append( - f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}" + f"### Fundamentals Analysis\n{self._format_report_content(self.report_sections['fundamentals_report'])}" ) # Research Team Reports if self.report_sections["investment_plan"]: report_parts.append("## Research Team Decision") - report_parts.append(f"{self.report_sections['investment_plan']}") + report_parts.append(f"{self._format_report_content(self.report_sections['investment_plan'])}") # Trading Team Reports if self.report_sections["trader_investment_plan"]: report_parts.append("## Trading Team Plan") - report_parts.append(f"{self.report_sections['trader_investment_plan']}") + report_parts.append(f"{self._format_report_content(self.report_sections['trader_investment_plan'])}") # Portfolio Management Decision if self.report_sections["final_trade_decision"]: report_parts.append("## Portfolio Management Decision") - report_parts.append(f"{self.report_sections['final_trade_decision']}") + report_parts.append(f"{self._format_report_content(self.report_sections['final_trade_decision'])}") self.final_report = "\n\n".join(report_parts) if report_parts else None @@ -550,6 +554,10 @@ def display_complete_report(final_state): """Display the complete analysis report with team-based panels.""" console.print("\n[bold green]Complete Analysis Report[/bold green]\n") + def _format_content_for_markdown(content): + """Ensures content is a string.""" + return str(content) + # User Position user_position = final_state.get("user_position", "none") cost_per_trade = final_state.get("cost_per_trade", 0.0) @@ -562,7 +570,7 @@ def display_complete_report(final_state): if final_state.get("market_report"): analyst_reports.append( Panel( - Markdown(final_state["market_report"]), + Markdown(_format_content_for_markdown(final_state["market_report"])), title="Market Analyst", border_style="blue", padding=(1, 2), @@ -573,7 +581,7 @@ def display_complete_report(final_state): if final_state.get("sentiment_report"): analyst_reports.append( Panel( - Markdown(final_state["sentiment_report"]), + Markdown(_format_content_for_markdown(final_state["sentiment_report"])), title="Social Analyst", border_style="blue", padding=(1, 2), @@ -584,7 +592,7 @@ def display_complete_report(final_state): if final_state.get("news_report"): analyst_reports.append( Panel( - Markdown(final_state["news_report"]), + Markdown(_format_content_for_markdown(final_state["news_report"])), title="News Analyst", border_style="blue", padding=(1, 2), @@ -595,7 +603,7 @@ def display_complete_report(final_state): if final_state.get("fundamentals_report"): analyst_reports.append( Panel( - Markdown(final_state["fundamentals_report"]), + Markdown(_format_content_for_markdown(final_state["fundamentals_report"])), title="Fundamentals Analyst", border_style="blue", padding=(1, 2), @@ -621,7 +629,7 @@ def display_complete_report(final_state): if debate_state.get("bull_history"): research_reports.append( Panel( - Markdown(debate_state["bull_history"]), + Markdown(_format_content_for_markdown(debate_state["bull_history"])), title="Bull Researcher", border_style="blue", padding=(1, 2), @@ -632,7 +640,7 @@ def display_complete_report(final_state): if debate_state.get("bear_history"): research_reports.append( Panel( - Markdown(debate_state["bear_history"]), + Markdown(_format_content_for_markdown(debate_state["bear_history"])), title="Bear Researcher", border_style="blue", padding=(1, 2), @@ -643,7 +651,7 @@ def display_complete_report(final_state): if debate_state.get("judge_decision"): research_reports.append( Panel( - Markdown(debate_state["judge_decision"]), + Markdown(_format_content_for_markdown(debate_state["judge_decision"])), title="Research Manager", border_style="blue", padding=(1, 2), @@ -665,7 +673,7 @@ def display_complete_report(final_state): console.print( Panel( Panel( - Markdown(final_state["trader_investment_plan"]), + Markdown(_format_content_for_markdown(final_state["trader_investment_plan"])), title="Trader", border_style="blue", padding=(1, 2), @@ -685,7 +693,7 @@ def display_complete_report(final_state): if risk_state.get("risky_history"): risk_reports.append( Panel( - Markdown(risk_state["risky_history"]), + Markdown(_format_content_for_markdown(risk_state["risky_history"])), title="Aggressive Analyst", border_style="blue", padding=(1, 2), @@ -696,7 +704,7 @@ def display_complete_report(final_state): if risk_state.get("safe_history"): risk_reports.append( Panel( - Markdown(risk_state["safe_history"]), + Markdown(_format_content_for_markdown(risk_state["safe_history"])), title="Conservative Analyst", border_style="blue", padding=(1, 2), @@ -707,7 +715,7 @@ def display_complete_report(final_state): if risk_state.get("neutral_history"): risk_reports.append( Panel( - Markdown(risk_state["neutral_history"]), + Markdown(_format_content_for_markdown(risk_state["neutral_history"])), title="Neutral Analyst", border_style="blue", padding=(1, 2), @@ -729,7 +737,7 @@ def display_complete_report(final_state): console.print( Panel( Panel( - Markdown(risk_state["judge_decision"]), + Markdown(_format_content_for_markdown(risk_state["judge_decision"])), title="Portfolio Manager", border_style="blue", padding=(1, 2), @@ -826,7 +834,7 @@ def run_analysis(): if content: file_name = f"{section_name}.md" with open(report_dir / file_name, "w", encoding="utf-8") as f: - f.write(content) + f.write(str(content)) return wrapper message_buffer.add_message = save_message_decorator(message_buffer, "add_message") @@ -836,7 +844,7 @@ def run_analysis(): # Now start the display layout layout = create_layout() - with Live(layout, refresh_per_second=4) as live: + with Live(layout, refresh_per_second=1) as live: # Initial display update_display(layout) diff --git a/cli/utils.py b/cli/utils.py index a90b2fc2..bfa9beff 100644 --- a/cli/utils.py +++ b/cli/utils.py @@ -250,7 +250,7 @@ def select_llm_provider() -> tuple[str, str]: ("OpenAI", "https://api.openai.com/v1"), ("Anthropic", "https://api.anthropic.com/"), ("Google", "https://generativelanguage.googleapis.com/v1"), - ("Openrouter", "https://openrouter.ai/api/v1"), + ("OpenRouter", "https://openrouter.ai/api/v1"), ("Ollama", "http://localhost:11434/v1"), ] diff --git a/memory_store/chroma.sqlite3 b/memory_store/chroma.sqlite3 new file mode 100644 index 00000000..61f5629e Binary files /dev/null and b/memory_store/chroma.sqlite3 differ diff --git a/tradingagents/agents/managers/risk_manager.py b/tradingagents/agents/managers/risk_manager.py index 0a3b6d89..2afb1c1d 100644 --- a/tradingagents/agents/managers/risk_manager.py +++ b/tradingagents/agents/managers/risk_manager.py @@ -54,7 +54,8 @@ Deliverables: Focus on actionable insights and continuous improvement. Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes.""" response = llm.invoke(prompt) - + + final_decision_content = response.content new_risk_debate_state = { "judge_decision": response.content, "history": risk_debate_state["history"], @@ -70,7 +71,7 @@ Focus on actionable insights and continuous improvement. Build on past lessons, return { "risk_debate_state": new_risk_debate_state, - "final_trade_decision": response.content, + "final_trade_decision": final_decision_content, } return risk_manager_node diff --git a/tradingagents/agents/trader/trader.py b/tradingagents/agents/trader/trader.py index ec8f96dd..c4d600e7 100644 --- a/tradingagents/agents/trader/trader.py +++ b/tradingagents/agents/trader/trader.py @@ -40,7 +40,8 @@ def create_trader(llm, memory): - If the user has an open short position, your recommendation can be to maintain the short position, close the short position, or close the short position and open a long position. - If the user has no open position, your recommendation can be to do nothing, open a long position, or open a short position. -Based on your analysis, provide a specific recommendation. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **YOUR_RECOMMENDATION**' to confirm your recommendation. Take into account that any transaction will incur a cost of {cost_per_trade}, so the potential profit of a transaction must be greater than this cost. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situatiosn you traded in and the lessons learned: {past_memory_str}""", +Based on your analysis, provide a specific recommendation. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **YOUR_RECOMMENDATION**' to confirm your recommendation. Take into account that any transaction will incur a cost of {cost_per_trade}, so the potential profit of a transaction must be greater than this cost. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situations you traded in and the lessons learned: {past_memory_str} +Your output should always be in markdown format.""", }, context, ] diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 34b94dd8..546af58c 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -1,10 +1,8 @@ import chromadb -from chromadb.config import Settings import os - class FinancialSituationMemory: - def __init__(self, name, config): + def __init__(self, name, config, persist_directory="./memory_store"): # Use local embeddings for all providers - no external API dependency self.use_local_embeddings = config.get("use_local_embeddings", True) @@ -40,19 +38,20 @@ class FinancialSituationMemory: self.client = OpenAI(base_url=config["backend_url"], api_key=api_key) self.embedding_type = "api" - self.chroma_client = chromadb.Client(Settings(allow_reset=True)) + self.chroma_client = chromadb.PersistentClient(path=persist_directory) # Create collection with or without custom embedding function if self.embedding_type == "chromadb_default": # Let ChromaDB handle embeddings with its default function - self.situation_collection = self.chroma_client.create_collection(name=name) + self.situation_collection = self.chroma_client.get_or_create_collection(name=name) else: # We'll handle embeddings ourselves - self.situation_collection = self.chroma_client.create_collection( + self.situation_collection = self.chroma_client.get_or_create_collection( name=name, metadata={"hnsw:space": "cosine"} # Use cosine similarity ) + def get_embedding(self, text): """Get embedding for a text using local or API-based models""" try: @@ -146,46 +145,80 @@ class FinancialSituationMemory: if __name__ == "__main__": + # Define the directory where memory will be stored + PERSIST_DIRECTORY = "./memory_store" + print(f"Memory will be persisted to: {os.path.abspath(PERSIST_DIRECTORY)}\n") + # Example usage - matcher = FinancialSituationMemory() + config_example = {"use_local_embeddings": True, "backend_url": ""} + # Initialize memory with a name and the persistence directory + matcher = FinancialSituationMemory( + name="persistent_example_memory", + config=config_example, + persist_directory=PERSIST_DIRECTORY + ) - # Example data - example_data = [ - ( - "High inflation rate with rising interest rates and declining consumer spending", - "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.", - ), - ( - "Tech sector showing high volatility with increasing institutional selling pressure", - "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.", - ), - ( - "Strong dollar affecting emerging markets with increasing forex volatility", - "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.", - ), - ( - "Market showing signs of sector rotation with rising yields", - "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.", - ), - ] + # Check if memory is already populated + if matcher.situation_collection.count() == 0: + print("Memory is empty. Populating with example data...") + # Example data + example_data = [ + ( + "High inflation rate with rising interest rates and declining consumer spending", + "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.", + ), + ( + "Tech sector showing high volatility with increasing institutional selling pressure", + "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.", + ), + ( + "Strong dollar affecting emerging markets with increasing forex volatility", + "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.", + ), + ( + "Market showing signs of sector rotation with rising yields", + "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.", + ), + ] + # Add the example situations and recommendations + matcher.add_situations(example_data) + print("Example data added to persistent memory.\n") + else: + print("Memory already contains data from a previous run.\n") - # Add the example situations and recommendations - matcher.add_situations(example_data) + # --- Inspecting the entire memory store --- + print("--- Dumping all contents of the memory store ---") + all_items = matcher.situation_collection.get(include=["metadatas", "documents"]) + + if not all_items or not all_items.get("ids"): + print("Memory store is empty.") + else: + for i, item_id in enumerate(all_items["ids"]): + situation = all_items["documents"][i] + recommendation = all_items["metadatas"][i].get("recommendation", "N/A") + print(f"ID: {item_id}") + print(f" Situation: {situation}") + print(f" Recommendation/Lesson: {recommendation}\n") + print("--- End of memory dump ---") - # Example query + # Example query to show it still works + print("\n--- Running an example query ---") current_situation = """ Market showing increased volatility in tech sector, with institutional investors reducing positions and rising interest rates affecting growth stock valuations """ + print(f"Querying for situation: {current_situation.strip()}\n") try: - recommendations = matcher.get_memories(current_situation, n_matches=2) - - for i, rec in enumerate(recommendations, 1): - print(f"\nMatch {i}:") - print(f"Similarity Score: {rec['similarity_score']:.2f}") - print(f"Matched Situation: {rec['matched_situation']}") - print(f"Recommendation: {rec['recommendation']}") + recommendations = matcher.get_memories(current_situation, n_matches=1) + if recommendations: + rec = recommendations[0] + print(f"Most similar match found:") + print(f" Similarity Score: {rec['similarity_score']:.2f}") + print(f" Matched Situation: {rec['matched_situation']}") + print(f" Retrieved Recommendation: {rec['recommendation']}\n") + else: + print("No similar situations found in memory.") except Exception as e: print(f"Error during recommendation: {str(e)}") diff --git a/tradingagents/graph/signal_processing.py b/tradingagents/graph/signal_processing.py index 903e8529..48ce6d9e 100644 --- a/tradingagents/graph/signal_processing.py +++ b/tradingagents/graph/signal_processing.py @@ -1,5 +1,6 @@ # TradingAgents/graph/signal_processing.py +import json from langchain_openai import ChatOpenAI @@ -10,7 +11,7 @@ class SignalProcessor: """Initialize with an LLM for processing.""" self.quick_thinking_llm = quick_thinking_llm - def process_signal(self, full_signal: str) -> str: + def process_signal(self, full_signal: dict) -> str: """ Process a full trading signal to extract the core decision. @@ -25,7 +26,7 @@ class SignalProcessor: "system", "You are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. Your task is to extract the investment decision: SELL, BUY, or HOLD. Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information.", ), - ("human", full_signal), + ("human", json.dumps(full_signal)), ] - return self.quick_thinking_llm.invoke(messages).content + return self.quick_thinking_llm.invoke(messages).content \ No newline at end of file