TradingAgents/tradingagents/agents/utils/memory.py

176 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import chromadb
from chromadb.config import Settings
import requests
import os
class FinancialSituationMemory:
def __init__(self, name, config):
# 根据不同的模型提供商设置embedding模型
if config["backend_url"] == "http://localhost:11434/v1":
self.embedding = "nomic-embed-text"
elif "dashscope.aliyuncs.com" in config["backend_url"]:
self.embedding = "text-embedding-v2" # 通义千问embedding模型
elif "baidu" in config["backend_url"]:
self.embedding = "bge-large-zh-v1.5" # 文心一言embedding模型
else:
self.embedding = "text-embedding-3-small" # 默认OpenAI模型
self.config = config
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name)
def get_embedding(self, text):
"""Get embedding for a text using the configured model"""
# 根据不同的模型提供商调用不同的API
if "dashscope.aliyuncs.com" in self.config["backend_url"]:
return self._get_qwen_embedding(text)
elif "baidu" in self.config["backend_url"]:
return self._get_ernie_embedding(text)
else:
# 对于其他模型使用简化的embedding返回固定向量
return self._get_simple_embedding(text)
def _get_qwen_embedding(self, text):
"""获取通义千问embedding"""
try:
headers = {
"Authorization": f"Bearer {os.getenv('DASHSCOPE_API_KEY')}",
"Content-Type": "application/json"
}
data = {
"model": self.embedding,
"input": text
}
response = requests.post(
f"{self.config['backend_url'].replace('/chat/completions', '/embeddings')}",
headers=headers,
json=data,
timeout=30
)
response.raise_for_status()
result = response.json()
return result["data"][0]["embedding"]
except Exception as e:
print(f"⚠️ 通义千问embedding调用失败: {e}")
return self._get_simple_embedding(text)
def _get_ernie_embedding(self, text):
"""获取文心一言embedding"""
try:
# 文心一言的embedding API调用
# 这里使用简化的实现
return self._get_simple_embedding(text)
except Exception as e:
print(f"⚠️ 文心一言embedding调用失败: {e}")
return self._get_simple_embedding(text)
def _get_simple_embedding(self, text):
"""简化的embedding实现返回固定长度的向量"""
# 使用文本的hash值生成固定长度的向量
import hashlib
hash_obj = hashlib.md5(text.encode('utf-8'))
hash_bytes = hash_obj.digest()
# 生成1536维的向量与OpenAI embedding维度相同
embedding = []
for i in range(1536):
byte_index = i % len(hash_bytes)
embedding.append((hash_bytes[byte_index] - 128) / 128.0)
return embedding
def add_situations(self, situations_and_advice):
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
situations = []
advice = []
ids = []
embeddings = []
offset = self.situation_collection.count()
for i, (situation, recommendation) in enumerate(situations_and_advice):
situations.append(situation)
advice.append(recommendation)
ids.append(str(offset + i))
embeddings.append(self.get_embedding(situation))
self.situation_collection.add(
documents=situations,
metadatas=[{"recommendation": rec} for rec in advice],
embeddings=embeddings,
ids=ids,
)
def get_memories(self, current_situation, n_matches=1):
"""Find matching recommendations using OpenAI embeddings"""
query_embedding = self.get_embedding(current_situation)
results = self.situation_collection.query(
query_embeddings=[query_embedding],
n_results=n_matches,
include=["metadatas", "documents", "distances"],
)
matched_results = []
for i in range(len(results["documents"][0])):
matched_results.append(
{
"matched_situation": results["documents"][0][i],
"recommendation": results["metadatas"][0][i]["recommendation"],
"similarity_score": 1 - results["distances"][0][i],
}
)
return matched_results
if __name__ == "__main__":
# Example usage
matcher = FinancialSituationMemory()
# Example data
example_data = [
(
"High inflation rate with rising interest rates and declining consumer spending",
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
),
(
"Tech sector showing high volatility with increasing institutional selling pressure",
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
),
(
"Strong dollar affecting emerging markets with increasing forex volatility",
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
),
(
"Market showing signs of sector rotation with rising yields",
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
),
]
# Add the example situations and recommendations
matcher.add_situations(example_data)
# Example query
current_situation = """
Market showing increased volatility in tech sector, with institutional investors
reducing positions and rising interest rates affecting growth stock valuations
"""
try:
recommendations = matcher.get_memories(current_situation, n_matches=2)
for i, rec in enumerate(recommendations, 1):
print(f"\nMatch {i}:")
print(f"Similarity Score: {rec['similarity_score']:.2f}")
print(f"Matched Situation: {rec['matched_situation']}")
print(f"Recommendation: {rec['recommendation']}")
except Exception as e:
print(f"Error during recommendation: {str(e)}")