176 lines
6.6 KiB
Python
176 lines
6.6 KiB
Python
import hashlib
import os

import chromadb
import requests
from chromadb.config import Settings
class FinancialSituationMemory:
    """Vector memory of past financial situations and the advice given for them.

    Situations are embedded and stored in a ChromaDB collection, with the
    recommendation kept as metadata; ``get_memories`` returns the stored
    recommendations whose situations are closest to a new situation.
    """

    def __init__(self, name, config):
        """Create (or reuse) the ChromaDB collection backing this memory.

        Args:
            name: Name of the ChromaDB collection.
            config: Settings dict; ``config["backend_url"]`` selects the
                embedding provider/model.
        """
        backend_url = config["backend_url"]
        # Pick the embedding model name according to the model provider.
        if backend_url == "http://localhost:11434/v1":
            self.embedding = "nomic-embed-text"  # presumably a local Ollama server
        elif "dashscope.aliyuncs.com" in backend_url:
            self.embedding = "text-embedding-v2"  # Qwen (Tongyi Qianwen) embedding model
        elif "baidu" in backend_url:
            self.embedding = "bge-large-zh-v1.5"  # ERNIE (Wenxin Yiyan) embedding model
        else:
            self.embedding = "text-embedding-3-small"  # default OpenAI model

        self.config = config
        self.chroma_client = chromadb.Client(Settings(allow_reset=True))
        # get_or_create: creating a second memory with the same name must not
        # fail with "collection already exists" (in-memory clients may share state).
        self.situation_collection = self.chroma_client.get_or_create_collection(name=name)

    def get_embedding(self, text):
        """Return an embedding vector (list of floats) for *text*.

        DashScope backends call the Qwen embedding API; Baidu and every other
        backend currently use the local hash-based placeholder embedding.
        """
        backend_url = self.config["backend_url"]
        if "dashscope.aliyuncs.com" in backend_url:
            return self._get_qwen_embedding(text)
        if "baidu" in backend_url:
            return self._get_ernie_embedding(text)
        # Other providers: simplified placeholder embedding (not semantic).
        return self._get_simple_embedding(text)

    @staticmethod
    def _embeddings_url(backend_url):
        """Derive the OpenAI-compatible ``/embeddings`` endpoint from *backend_url*.

        Accepts a base URL (``.../v1``), a chat endpoint
        (``.../v1/chat/completions``) or an embeddings endpoint, with or
        without a trailing slash.
        """
        url = backend_url.rstrip("/")
        chat_suffix = "/chat/completions"
        if url.endswith(chat_suffix):
            url = url[: -len(chat_suffix)]
        if not url.endswith("/embeddings"):
            url += "/embeddings"
        return url

    def _get_qwen_embedding(self, text):
        """Fetch a Qwen (DashScope) embedding; fall back to the placeholder on failure."""
        headers = {
            "Authorization": f"Bearer {os.getenv('DASHSCOPE_API_KEY')}",
            "Content-Type": "application/json",
        }
        payload = {"model": self.embedding, "input": text}
        try:
            response = requests.post(
                self._embeddings_url(self.config["backend_url"]),
                headers=headers,
                json=payload,
                timeout=30,
            )
            response.raise_for_status()
            return response.json()["data"][0]["embedding"]
        except Exception as e:
            # Deliberate best effort: keep working without the API.
            # NOTE(review): placeholder vectors live in a different space than
            # real Qwen embeddings, so mixing both in one collection degrades
            # matching quality.
            print(f"⚠️ 通义千问embedding调用失败: {e}")
            return self._get_simple_embedding(text)

    def _get_ernie_embedding(self, text):
        """Return an embedding for ERNIE (Baidu) backends.

        TODO: the ERNIE embedding API is not wired up yet; this returns the
        hash-based placeholder embedding.
        """
        return self._get_simple_embedding(text)

    def _get_simple_embedding(self, text, dim=1536):
        """Return a deterministic, non-semantic placeholder embedding for *text*.

        The 16 MD5 digest bytes are mapped to ``[-1, 1)`` and tiled to *dim*
        entries (default 1536, the same size as OpenAI's embeddings). Identical
        texts get identical vectors; distances otherwise carry no meaning.
        """
        digest = hashlib.md5(text.encode("utf-8")).digest()
        return [(digest[i % len(digest)] - 128) / 128.0 for i in range(dim)]

    def add_situations(self, situations_and_advice):
        """Store situations together with their advice.

        Args:
            situations_and_advice: iterable of ``(situation, recommendation)``
                string pairs. An empty iterable is a no-op.
        """
        pairs = list(situations_and_advice)
        if not pairs:
            # Chroma validates that ids are non-empty; nothing to store anyway.
            return

        # Continue numbering after the entries already in the collection.
        offset = self.situation_collection.count()
        situations = [situation for situation, _ in pairs]
        self.situation_collection.add(
            documents=situations,
            metadatas=[{"recommendation": rec} for _, rec in pairs],
            embeddings=[self.get_embedding(situation) for situation in situations],
            ids=[str(offset + i) for i in range(len(pairs))],
        )

    def get_memories(self, current_situation, n_matches=1):
        """Return up to *n_matches* stored memories closest to *current_situation*.

        Each match is a dict with ``matched_situation``, ``recommendation`` and
        ``similarity_score`` (``1 - distance``). Returns ``[]`` when
        ``n_matches < 1`` or the memory is empty.
        """
        if n_matches < 1:
            return []
        available = self.situation_collection.count()
        if available == 0:
            return []

        results = self.situation_collection.query(
            query_embeddings=[self.get_embedding(current_situation)],
            n_results=min(n_matches, available),
            include=["metadatas", "documents", "distances"],
        )
        # NOTE(review): the collection is created without a distance setting,
        # so Chroma's default (L2) applies and "1 - distance" is a relative
        # ranking score rather than a bounded similarity.
        return [
            {
                "matched_situation": document,
                "recommendation": metadata["recommendation"],
                "similarity_score": 1 - distance,
            }
            for document, metadata, distance in zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
        ]
if __name__ == "__main__":
    # Example usage. A non-DashScope/non-Baidu backend URL keeps the demo
    # offline: embeddings come from the local hash-based placeholder.
    matcher = FinancialSituationMemory(
        "example_memory", {"backend_url": "https://api.openai.com/v1"}
    )

    # Example data: (situation, recommendation) pairs
    example_data = [
        (
            "High inflation rate with rising interest rates and declining consumer spending",
            "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
        ),
        (
            "Tech sector showing high volatility with increasing institutional selling pressure",
            "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
        ),
        (
            "Strong dollar affecting emerging markets with increasing forex volatility",
            "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
        ),
        (
            "Market showing signs of sector rotation with rising yields",
            "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
        ),
    ]

    # Add the example situations and recommendations
    matcher.add_situations(example_data)

    # Example query
    current_situation = """
    Market showing increased volatility in tech sector, with institutional investors
    reducing positions and rising interest rates affecting growth stock valuations
    """

    try:
        recommendations = matcher.get_memories(current_situation, n_matches=2)

        for i, rec in enumerate(recommendations, 1):
            print(f"\nMatch {i}:")
            print(f"Similarity Score: {rec['similarity_score']:.2f}")
            print(f"Matched Situation: {rec['matched_situation']}")
            print(f"Recommendation: {rec['recommendation']}")

    except Exception as e:
        print(f"Error during recommendation: {str(e)}")