# Listing metadata (from the code viewer): 203 lines, 8.2 KiB, Python.
import chromadb
|
|
from chromadb.config import Settings
|
|
from openai import OpenAI
|
|
import os
|
|
|
|
# DashScope is an optional dependency. When it cannot be imported the module
# still loads: both names are bound to None and the availability flag is False
# so callers can test the flag instead of re-attempting the import.
try:
    import dashscope
    from dashscope import TextEmbedding
except ImportError:
    dashscope = None
    TextEmbedding = None
    DASHSCOPE_AVAILABLE = False
else:
    DASHSCOPE_AVAILABLE = True
|
|
|
|
|
|
class FinancialSituationMemory:
    """Store financial situations with their advice and look them up by embedding.

    Each situation text is embedded (via DashScope or an OpenAI-compatible
    endpoint, chosen from ``config``) and stored in a ChromaDB collection with
    the matching recommendation kept in the record's metadata.
    """

    def __init__(self, name, config):
        """Select the embedding backend and open the ChromaDB collection.

        Args:
            name: Name of the ChromaDB collection to open, or create if absent.
            config: Mapping read with ``.get``. ``llm_provider`` defaults to
                ``"openai"``; ``backend_url`` defaults to
                ``"https://api.openai.com/v1"``.

        Raises:
            ValueError: If the provider is DashScope/Alibaba or Google and
                neither a usable DashScope setup nor ``OPENAI_API_KEY`` exists.
        """
        self.config = config
        self.llm_provider = config.get("llm_provider", "openai").lower()

        # Read once with a default so every branch tolerates a missing key
        # (previously only the DashScope/Google branches did).
        backend_url = config.get("backend_url", "https://api.openai.com/v1")

        # Configure embedding model and client based on LLM provider.
        # The substring test also covers the exact value "dashscope".
        if "dashscope" in self.llm_provider or "alibaba" in self.llm_provider:
            dashscope_key = os.getenv('DASHSCOPE_API_KEY')
            openai_key = os.getenv('OPENAI_API_KEY')

            if DASHSCOPE_AVAILABLE and dashscope_key:
                # Use DashScope embeddings; client=None marks the DashScope path.
                self.embedding = "text-embedding-v3"
                self.client = None
                dashscope.api_key = dashscope_key
                print("✅ Using DashScope embeddings")
            elif openai_key:
                # Fallback to OpenAI embeddings
                print("⚠️ DashScope not available or not configured, falling back to OpenAI embeddings")
                self.embedding = "text-embedding-3-small"
                self.client = OpenAI(base_url=backend_url)
            else:
                # No valid API keys available
                raise ValueError(
                    "No valid API keys found. For DashScope provider, please set either:\n"
                    "1. DASHSCOPE_API_KEY (preferred for DashScope embeddings)\n"
                    "2. OPENAI_API_KEY (fallback for OpenAI embeddings)\n"
                    "Install dashscope package: pip install dashscope"
                )
        elif self.llm_provider == "google":
            # Google AI uses DashScope embedding if available, otherwise OpenAI
            dashscope_key = os.getenv('DASHSCOPE_API_KEY')
            openai_key = os.getenv('OPENAI_API_KEY')

            if dashscope_key and DASHSCOPE_AVAILABLE:
                self.embedding = "text-embedding-v3"
                self.client = None
                dashscope.api_key = dashscope_key
                print("💡 Google AI using DashScope embedding service")
            elif openai_key:
                self.embedding = "text-embedding-3-small"
                self.client = OpenAI(base_url=backend_url)
                print("⚠️ Google AI falling back to OpenAI embedding service")
            else:
                raise ValueError(
                    "No valid API keys found for Google AI embeddings. Please set either:\n"
                    "1. DASHSCOPE_API_KEY (preferred)\n"
                    "2. OPENAI_API_KEY (fallback)"
                )
        elif backend_url == "http://localhost:11434/v1":
            # This exact URL selects the nomic-embed-text model.
            self.embedding = "nomic-embed-text"
            self.client = OpenAI(base_url=backend_url)
        else:
            self.embedding = "text-embedding-3-small"
            self.client = OpenAI(base_url=backend_url)

        self.chroma_client = chromadb.Client(Settings(allow_reset=True))

        # Open the collection if it exists, otherwise create it. A single call
        # avoids catching (and hiding) unrelated errors from get_collection.
        self.situation_collection = self.chroma_client.get_or_create_collection(name=name)

    def get_embedding(self, text):
        """Return the embedding vector for ``text`` from the configured backend.

        Args:
            text: The text to embed.

        Returns:
            The embedding taken from the backend's response.

        Raises:
            Exception: If the DashScope call fails or returns a non-200 status.
                Errors from the OpenAI-compatible client propagate unchanged.
        """
        # __init__ sets client to None only when DashScope was selected and is
        # importable, so this single check identifies the DashScope path.
        if self.client is None:
            try:
                response = TextEmbedding.call(
                    model=self.embedding,
                    input=text,
                )
                if response.status_code == 200:
                    return response.output['embeddings'][0]['embedding']
                raise Exception(f"DashScope embedding error: {response.code} - {response.message}")
            except Exception as e:
                # Keep the original cause attached for debugging.
                raise Exception(f"Error getting DashScope embedding: {str(e)}") from e

        # Use OpenAI-compatible embedding model
        response = self.client.embeddings.create(
            model=self.embedding, input=text
        )
        return response.data[0].embedding

    def add_situations(self, situations_and_advice):
        """Add situations and their advice to the collection.

        Args:
            situations_and_advice: Iterable of ``(situation, recommendation)``
                pairs. An empty iterable is a no-op.
        """
        # Materialise so generators work and emptiness can be checked up front.
        pairs = list(situations_and_advice)
        if not pairs:
            # Nothing to store; skip the collection call entirely.
            return

        # New ids continue from the current record count.
        offset = self.situation_collection.count()

        situations = [situation for situation, _ in pairs]
        advice = [recommendation for _, recommendation in pairs]
        ids = [str(offset + i) for i in range(len(pairs))]
        embeddings = [self.get_embedding(situation) for situation in situations]

        self.situation_collection.add(
            documents=situations,
            metadatas=[{"recommendation": rec} for rec in advice],
            embeddings=embeddings,
            ids=ids,
        )

    def get_memories(self, current_situation, n_matches=1):
        """Find the stored situations closest to ``current_situation``.

        Args:
            current_situation: Text describing the situation to match.
            n_matches: Maximum number of matches to return.

        Returns:
            A list of dicts with keys ``matched_situation``, ``recommendation``
            and ``similarity_score`` (computed as ``1 - distance``). The list is
            empty when the collection holds no records.
        """
        available = self.situation_collection.count()
        if available == 0:
            # Nothing to match against; avoid an embedding request and a query.
            return []

        query_embedding = self.get_embedding(current_situation)

        results = self.situation_collection.query(
            query_embeddings=[query_embedding],
            # Never ask for more results than the collection contains.
            n_results=min(n_matches, available),
            include=["metadatas", "documents", "distances"],
        )

        # One query embedding was sent, so only the first result row is used.
        documents = results["documents"][0]
        metadatas = results["metadatas"][0]
        distances = results["distances"][0]

        # NOTE(review): the range of "1 - distance" depends on the collection's
        # distance metric, which is not set here — confirm before treating the
        # score as bounded.
        return [
            {
                "matched_situation": document,
                "recommendation": metadata["recommendation"],
                "similarity_score": 1 - distance,
            }
            for document, metadata, distance in zip(documents, metadatas, distances)
        ]
|
|
|
|
|
|
if __name__ == "__main__":
    # Example usage.
    # The constructor requires a collection name and a config mapping; calling
    # it with no arguments raises TypeError. This config selects the default
    # OpenAI-compatible branch of the constructor.
    example_config = {
        "llm_provider": "openai",
        "backend_url": "https://api.openai.com/v1",
    }
    matcher = FinancialSituationMemory("example_situations", example_config)

    # Example data: (situation, recommendation) pairs
    example_data = [
        (
            "High inflation rate with rising interest rates and declining consumer spending",
            "Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
        ),
        (
            "Tech sector showing high volatility with increasing institutional selling pressure",
            "Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
        ),
        (
            "Strong dollar affecting emerging markets with increasing forex volatility",
            "Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
        ),
        (
            "Market showing signs of sector rotation with rising yields",
            "Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
        ),
    ]

    # Add the example situations and recommendations
    matcher.add_situations(example_data)

    # Example query
    current_situation = """
    Market showing increased volatility in tech sector, with institutional investors
    reducing positions and rising interest rates affecting growth stock valuations
    """

    try:
        recommendations = matcher.get_memories(current_situation, n_matches=2)

        for i, rec in enumerate(recommendations, 1):
            print(f"\nMatch {i}:")
            print(f"Similarity Score: {rec['similarity_score']:.2f}")
            print(f"Matched Situation: {rec['matched_situation']}")
            print(f"Recommendation: {rec['recommendation']}")

    except Exception as e:
        print(f"Error during recommendation: {str(e)}")
|