feat: add permanent memory
This commit is contained in:
parent
309a105465
commit
356a563813
19
README.md
19
README.md
|
|
@ -197,6 +197,25 @@ print(decision)
|
|||
|
||||
You can view the full list of configurations in `tradingagents/default_config.py`.
|
||||
|
||||
## Persistent Memory and Learning
|
||||
|
||||
To allow the agents to learn from the success or failure of previous decisions, TradingAgents includes a persistent memory mechanism.
|
||||
|
||||
Each agent's reflections and the "lessons learned" from past trading sessions are stored on disk. This allows the system to build a rich, searchable history of its actions and their consequences, enabling more informed decisions in the future.
|
||||
|
||||
- **Storage**: The memory is managed by the `FinancialSituationMemory` class in `tradingagents/agents/utils/memory.py` and is persisted to the `./memory_store/` directory using a local ChromaDB database.
|
||||
- **Learning Loop**: After a trade, a `Reflector` agent analyzes the outcome (profit or loss) and generates a "lesson." This lesson is stored in the memory, linked to the market conditions at the time. Before the next trade, agents query this memory for similar past situations to retrieve relevant lessons, which are then used to inform their decision-making process.
|
||||
|
||||
### Inspecting the Memory
|
||||
|
||||
You can inspect the contents of the persistent memory to see what the agents have learned. To do this, run the memory utility script from the root of the project:
|
||||
|
||||
```bash
|
||||
python -m tradingagents.agents.utils.memory
|
||||
```
|
||||
|
||||
The first time you run this, it will populate the memory with example data. Subsequent runs will load and display the data from the `memory_store` directory, demonstrating that the memory persists across sessions.
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions from the community! Whether it's fixing a bug, improving documentation, or suggesting a new feature, your input helps make this project better. If you are interested in this line of research, please consider joining our open-source financial AI research community [Tauric Research](https://tauric.ai/).
|
||||
|
|
|
|||
52
cli/main.py
52
cli/main.py
|
|
@ -73,6 +73,10 @@ class MessageBuffer:
|
|||
"final_trade_decision": None,
|
||||
}
|
||||
|
||||
def _format_report_content(self, content):
|
||||
"""Ensures content is a string."""
|
||||
return str(content)
|
||||
|
||||
def add_message(self, message_type, content):
|
||||
timestamp = datetime.datetime.now().strftime("%H:%M:%S")
|
||||
self.messages.append((timestamp, message_type, content))
|
||||
|
|
@ -100,7 +104,7 @@ class MessageBuffer:
|
|||
for section, content in self.report_sections.items():
|
||||
if content is not None:
|
||||
latest_section = section
|
||||
latest_content = content
|
||||
latest_content = self._format_report_content(content)
|
||||
|
||||
if latest_section and latest_content:
|
||||
# Format the current section for display
|
||||
|
|
@ -136,35 +140,35 @@ class MessageBuffer:
|
|||
report_parts.append("## Analyst Team Reports")
|
||||
if self.report_sections["market_report"]:
|
||||
report_parts.append(
|
||||
f"### Market Analysis\n{self.report_sections['market_report']}"
|
||||
f"### Market Analysis\n{self._format_report_content(self.report_sections['market_report'])}"
|
||||
)
|
||||
if self.report_sections["sentiment_report"]:
|
||||
report_parts.append(
|
||||
f"### Social Sentiment\n{self.report_sections['sentiment_report']}"
|
||||
f"### Social Sentiment\n{self._format_report_content(self.report_sections['sentiment_report'])}"
|
||||
)
|
||||
if self.report_sections["news_report"]:
|
||||
report_parts.append(
|
||||
f"### News Analysis\n{self.report_sections['news_report']}"
|
||||
f"### News Analysis\n{self._format_report_content(self.report_sections['news_report'])}"
|
||||
)
|
||||
if self.report_sections["fundamentals_report"]:
|
||||
report_parts.append(
|
||||
f"### Fundamentals Analysis\n{self.report_sections['fundamentals_report']}"
|
||||
f"### Fundamentals Analysis\n{self._format_report_content(self.report_sections['fundamentals_report'])}"
|
||||
)
|
||||
|
||||
# Research Team Reports
|
||||
if self.report_sections["investment_plan"]:
|
||||
report_parts.append("## Research Team Decision")
|
||||
report_parts.append(f"{self.report_sections['investment_plan']}")
|
||||
report_parts.append(f"{self._format_report_content(self.report_sections['investment_plan'])}")
|
||||
|
||||
# Trading Team Reports
|
||||
if self.report_sections["trader_investment_plan"]:
|
||||
report_parts.append("## Trading Team Plan")
|
||||
report_parts.append(f"{self.report_sections['trader_investment_plan']}")
|
||||
report_parts.append(f"{self._format_report_content(self.report_sections['trader_investment_plan'])}")
|
||||
|
||||
# Portfolio Management Decision
|
||||
if self.report_sections["final_trade_decision"]:
|
||||
report_parts.append("## Portfolio Management Decision")
|
||||
report_parts.append(f"{self.report_sections['final_trade_decision']}")
|
||||
report_parts.append(f"{self._format_report_content(self.report_sections['final_trade_decision'])}")
|
||||
|
||||
self.final_report = "\n\n".join(report_parts) if report_parts else None
|
||||
|
||||
|
|
@ -550,6 +554,10 @@ def display_complete_report(final_state):
|
|||
"""Display the complete analysis report with team-based panels."""
|
||||
console.print("\n[bold green]Complete Analysis Report[/bold green]\n")
|
||||
|
||||
def _format_content_for_markdown(content):
|
||||
"""Ensures content is a string."""
|
||||
return str(content)
|
||||
|
||||
# User Position
|
||||
user_position = final_state.get("user_position", "none")
|
||||
cost_per_trade = final_state.get("cost_per_trade", 0.0)
|
||||
|
|
@ -562,7 +570,7 @@ def display_complete_report(final_state):
|
|||
if final_state.get("market_report"):
|
||||
analyst_reports.append(
|
||||
Panel(
|
||||
Markdown(final_state["market_report"]),
|
||||
Markdown(_format_content_for_markdown(final_state["market_report"])),
|
||||
title="Market Analyst",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
|
|
@ -573,7 +581,7 @@ def display_complete_report(final_state):
|
|||
if final_state.get("sentiment_report"):
|
||||
analyst_reports.append(
|
||||
Panel(
|
||||
Markdown(final_state["sentiment_report"]),
|
||||
Markdown(_format_content_for_markdown(final_state["sentiment_report"])),
|
||||
title="Social Analyst",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
|
|
@ -584,7 +592,7 @@ def display_complete_report(final_state):
|
|||
if final_state.get("news_report"):
|
||||
analyst_reports.append(
|
||||
Panel(
|
||||
Markdown(final_state["news_report"]),
|
||||
Markdown(_format_content_for_markdown(final_state["news_report"])),
|
||||
title="News Analyst",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
|
|
@ -595,7 +603,7 @@ def display_complete_report(final_state):
|
|||
if final_state.get("fundamentals_report"):
|
||||
analyst_reports.append(
|
||||
Panel(
|
||||
Markdown(final_state["fundamentals_report"]),
|
||||
Markdown(_format_content_for_markdown(final_state["fundamentals_report"])),
|
||||
title="Fundamentals Analyst",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
|
|
@ -621,7 +629,7 @@ def display_complete_report(final_state):
|
|||
if debate_state.get("bull_history"):
|
||||
research_reports.append(
|
||||
Panel(
|
||||
Markdown(debate_state["bull_history"]),
|
||||
Markdown(_format_content_for_markdown(debate_state["bull_history"])),
|
||||
title="Bull Researcher",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
|
|
@ -632,7 +640,7 @@ def display_complete_report(final_state):
|
|||
if debate_state.get("bear_history"):
|
||||
research_reports.append(
|
||||
Panel(
|
||||
Markdown(debate_state["bear_history"]),
|
||||
Markdown(_format_content_for_markdown(debate_state["bear_history"])),
|
||||
title="Bear Researcher",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
|
|
@ -643,7 +651,7 @@ def display_complete_report(final_state):
|
|||
if debate_state.get("judge_decision"):
|
||||
research_reports.append(
|
||||
Panel(
|
||||
Markdown(debate_state["judge_decision"]),
|
||||
Markdown(_format_content_for_markdown(debate_state["judge_decision"])),
|
||||
title="Research Manager",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
|
|
@ -665,7 +673,7 @@ def display_complete_report(final_state):
|
|||
console.print(
|
||||
Panel(
|
||||
Panel(
|
||||
Markdown(final_state["trader_investment_plan"]),
|
||||
Markdown(_format_content_for_markdown(final_state["trader_investment_plan"])),
|
||||
title="Trader",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
|
|
@ -685,7 +693,7 @@ def display_complete_report(final_state):
|
|||
if risk_state.get("risky_history"):
|
||||
risk_reports.append(
|
||||
Panel(
|
||||
Markdown(risk_state["risky_history"]),
|
||||
Markdown(_format_content_for_markdown(risk_state["risky_history"])),
|
||||
title="Aggressive Analyst",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
|
|
@ -696,7 +704,7 @@ def display_complete_report(final_state):
|
|||
if risk_state.get("safe_history"):
|
||||
risk_reports.append(
|
||||
Panel(
|
||||
Markdown(risk_state["safe_history"]),
|
||||
Markdown(_format_content_for_markdown(risk_state["safe_history"])),
|
||||
title="Conservative Analyst",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
|
|
@ -707,7 +715,7 @@ def display_complete_report(final_state):
|
|||
if risk_state.get("neutral_history"):
|
||||
risk_reports.append(
|
||||
Panel(
|
||||
Markdown(risk_state["neutral_history"]),
|
||||
Markdown(_format_content_for_markdown(risk_state["neutral_history"])),
|
||||
title="Neutral Analyst",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
|
|
@ -729,7 +737,7 @@ def display_complete_report(final_state):
|
|||
console.print(
|
||||
Panel(
|
||||
Panel(
|
||||
Markdown(risk_state["judge_decision"]),
|
||||
Markdown(_format_content_for_markdown(risk_state["judge_decision"])),
|
||||
title="Portfolio Manager",
|
||||
border_style="blue",
|
||||
padding=(1, 2),
|
||||
|
|
@ -826,7 +834,7 @@ def run_analysis():
|
|||
if content:
|
||||
file_name = f"{section_name}.md"
|
||||
with open(report_dir / file_name, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
f.write(str(content))
|
||||
return wrapper
|
||||
|
||||
message_buffer.add_message = save_message_decorator(message_buffer, "add_message")
|
||||
|
|
@ -836,7 +844,7 @@ def run_analysis():
|
|||
# Now start the display layout
|
||||
layout = create_layout()
|
||||
|
||||
with Live(layout, refresh_per_second=4) as live:
|
||||
with Live(layout, refresh_per_second=1) as live:
|
||||
# Initial display
|
||||
update_display(layout)
|
||||
|
||||
|
|
|
|||
|
|
@ -250,7 +250,7 @@ def select_llm_provider() -> tuple[str, str]:
|
|||
("OpenAI", "https://api.openai.com/v1"),
|
||||
("Anthropic", "https://api.anthropic.com/"),
|
||||
("Google", "https://generativelanguage.googleapis.com/v1"),
|
||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
||||
("OpenRouter", "https://openrouter.ai/api/v1"),
|
||||
("Ollama", "http://localhost:11434/v1"),
|
||||
]
|
||||
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -54,7 +54,8 @@ Deliverables:
|
|||
Focus on actionable insights and continuous improvement. Build on past lessons, critically evaluate all perspectives, and ensure each decision advances better outcomes."""
|
||||
|
||||
response = llm.invoke(prompt)
|
||||
|
||||
|
||||
final_decision_content = response.content
|
||||
new_risk_debate_state = {
|
||||
"judge_decision": response.content,
|
||||
"history": risk_debate_state["history"],
|
||||
|
|
@ -70,7 +71,7 @@ Focus on actionable insights and continuous improvement. Build on past lessons,
|
|||
|
||||
return {
|
||||
"risk_debate_state": new_risk_debate_state,
|
||||
"final_trade_decision": response.content,
|
||||
"final_trade_decision": final_decision_content,
|
||||
}
|
||||
|
||||
return risk_manager_node
|
||||
|
|
|
|||
|
|
@ -40,7 +40,8 @@ def create_trader(llm, memory):
|
|||
- If the user has an open short position, your recommendation can be to maintain the short position, close the short position, or close the short position and open a long position.
|
||||
- If the user has no open position, your recommendation can be to do nothing, open a long position, or open a short position.
|
||||
|
||||
Based on your analysis, provide a specific recommendation. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **YOUR_RECOMMENDATION**' to confirm your recommendation. Take into account that any transaction will incur a cost of {cost_per_trade}, so the potential profit of a transaction must be greater than this cost. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situatiosn you traded in and the lessons learned: {past_memory_str}""",
|
||||
Based on your analysis, provide a specific recommendation. End with a firm decision and always conclude your response with 'FINAL TRANSACTION PROPOSAL: **YOUR_RECOMMENDATION**' to confirm your recommendation. Take into account that any transaction will incur a cost of {cost_per_trade}, so the potential profit of a transaction must be greater than this cost. Do not forget to utilize lessons from past decisions to learn from your mistakes. Here is some reflections from similar situations you traded in and the lessons learned: {past_memory_str}
|
||||
Your output should always be in markdown format.""",
|
||||
},
|
||||
context,
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,10 +1,8 @@
|
|||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
import os
|
||||
|
||||
|
||||
class FinancialSituationMemory:
|
||||
def __init__(self, name, config):
|
||||
def __init__(self, name, config, persist_directory="./memory_store"):
|
||||
# Use local embeddings for all providers - no external API dependency
|
||||
self.use_local_embeddings = config.get("use_local_embeddings", True)
|
||||
|
||||
|
|
@ -40,19 +38,20 @@ class FinancialSituationMemory:
|
|||
self.client = OpenAI(base_url=config["backend_url"], api_key=api_key)
|
||||
self.embedding_type = "api"
|
||||
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
self.chroma_client = chromadb.PersistentClient(path=persist_directory)
|
||||
|
||||
# Create collection with or without custom embedding function
|
||||
if self.embedding_type == "chromadb_default":
|
||||
# Let ChromaDB handle embeddings with its default function
|
||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||
self.situation_collection = self.chroma_client.get_or_create_collection(name=name)
|
||||
else:
|
||||
# We'll handle embeddings ourselves
|
||||
self.situation_collection = self.chroma_client.create_collection(
|
||||
self.situation_collection = self.chroma_client.get_or_create_collection(
|
||||
name=name,
|
||||
metadata={"hnsw:space": "cosine"} # Use cosine similarity
|
||||
)
|
||||
|
||||
|
||||
def get_embedding(self, text):
|
||||
"""Get embedding for a text using local or API-based models"""
|
||||
try:
|
||||
|
|
@ -146,46 +145,80 @@ class FinancialSituationMemory:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Define the directory where memory will be stored
|
||||
PERSIST_DIRECTORY = "./memory_store"
|
||||
print(f"Memory will be persisted to: {os.path.abspath(PERSIST_DIRECTORY)}\n")
|
||||
|
||||
# Example usage
|
||||
matcher = FinancialSituationMemory()
|
||||
config_example = {"use_local_embeddings": True, "backend_url": ""}
|
||||
# Initialize memory with a name and the persistence directory
|
||||
matcher = FinancialSituationMemory(
|
||||
name="persistent_example_memory",
|
||||
config=config_example,
|
||||
persist_directory=PERSIST_DIRECTORY
|
||||
)
|
||||
|
||||
# Example data
|
||||
example_data = [
|
||||
(
|
||||
"High inflation rate with rising interest rates and declining consumer spending",
|
||||
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
|
||||
),
|
||||
(
|
||||
"Tech sector showing high volatility with increasing institutional selling pressure",
|
||||
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
|
||||
),
|
||||
(
|
||||
"Strong dollar affecting emerging markets with increasing forex volatility",
|
||||
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
|
||||
),
|
||||
(
|
||||
"Market showing signs of sector rotation with rising yields",
|
||||
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
|
||||
),
|
||||
]
|
||||
# Check if memory is already populated
|
||||
if matcher.situation_collection.count() == 0:
|
||||
print("Memory is empty. Populating with example data...")
|
||||
# Example data
|
||||
example_data = [
|
||||
(
|
||||
"High inflation rate with rising interest rates and declining consumer spending",
|
||||
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
|
||||
),
|
||||
(
|
||||
"Tech sector showing high volatility with increasing institutional selling pressure",
|
||||
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
|
||||
),
|
||||
(
|
||||
"Strong dollar affecting emerging markets with increasing forex volatility",
|
||||
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
|
||||
),
|
||||
(
|
||||
"Market showing signs of sector rotation with rising yields",
|
||||
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
|
||||
),
|
||||
]
|
||||
# Add the example situations and recommendations
|
||||
matcher.add_situations(example_data)
|
||||
print("Example data added to persistent memory.\n")
|
||||
else:
|
||||
print("Memory already contains data from a previous run.\n")
|
||||
|
||||
# Add the example situations and recommendations
|
||||
matcher.add_situations(example_data)
|
||||
# --- Inspecting the entire memory store ---
|
||||
print("--- Dumping all contents of the memory store ---")
|
||||
all_items = matcher.situation_collection.get(include=["metadatas", "documents"])
|
||||
|
||||
if not all_items or not all_items.get("ids"):
|
||||
print("Memory store is empty.")
|
||||
else:
|
||||
for i, item_id in enumerate(all_items["ids"]):
|
||||
situation = all_items["documents"][i]
|
||||
recommendation = all_items["metadatas"][i].get("recommendation", "N/A")
|
||||
print(f"ID: {item_id}")
|
||||
print(f" Situation: {situation}")
|
||||
print(f" Recommendation/Lesson: {recommendation}\n")
|
||||
print("--- End of memory dump ---")
|
||||
|
||||
# Example query
|
||||
# Example query to show it still works
|
||||
print("\n--- Running an example query ---")
|
||||
current_situation = """
|
||||
Market showing increased volatility in tech sector, with institutional investors
|
||||
reducing positions and rising interest rates affecting growth stock valuations
|
||||
"""
|
||||
print(f"Querying for situation: {current_situation.strip()}\n")
|
||||
|
||||
try:
|
||||
recommendations = matcher.get_memories(current_situation, n_matches=2)
|
||||
|
||||
for i, rec in enumerate(recommendations, 1):
|
||||
print(f"\nMatch {i}:")
|
||||
print(f"Similarity Score: {rec['similarity_score']:.2f}")
|
||||
print(f"Matched Situation: {rec['matched_situation']}")
|
||||
print(f"Recommendation: {rec['recommendation']}")
|
||||
recommendations = matcher.get_memories(current_situation, n_matches=1)
|
||||
if recommendations:
|
||||
rec = recommendations[0]
|
||||
print(f"Most similar match found:")
|
||||
print(f" Similarity Score: {rec['similarity_score']:.2f}")
|
||||
print(f" Matched Situation: {rec['matched_situation']}")
|
||||
print(f" Retrieved Recommendation: {rec['recommendation']}\n")
|
||||
else:
|
||||
print("No similar situations found in memory.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during recommendation: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
# TradingAgents/graph/signal_processing.py
|
||||
|
||||
import json
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
|
||||
|
|
@ -10,7 +11,7 @@ class SignalProcessor:
|
|||
"""Initialize with an LLM for processing."""
|
||||
self.quick_thinking_llm = quick_thinking_llm
|
||||
|
||||
def process_signal(self, full_signal: str) -> str:
|
||||
def process_signal(self, full_signal: dict) -> str:
|
||||
"""
|
||||
Process a full trading signal to extract the core decision.
|
||||
|
||||
|
|
@ -25,7 +26,7 @@ class SignalProcessor:
|
|||
"system",
|
||||
"You are an efficient assistant designed to analyze paragraphs or financial reports provided by a group of analysts. Your task is to extract the investment decision: SELL, BUY, or HOLD. Provide only the extracted decision (SELL, BUY, or HOLD) as your output, without adding any additional text or information.",
|
||||
),
|
||||
("human", full_signal),
|
||||
("human", json.dumps(full_signal)),
|
||||
]
|
||||
|
||||
return self.quick_thinking_llm.invoke(messages).content
|
||||
return self.quick_thinking_llm.invoke(messages).content
|
||||
Loading…
Reference in New Issue