75 lines
2.2 KiB
Python
75 lines
2.2 KiB
Python
# from langchain.vectorstores.chroma import Chroma
|
|
from langchain_chroma import Chroma
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
# from langchain.embeddings.huggingface import HuggingFaceBgeEmbeddings
|
|
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
|
from chromadb.config import Settings
|
|
import os
|
|
|
|
# Format string for one stored exchange: the user's question and the
# assistant's answer, separated by a blank line. Filled in with
# ``message_template.format(question=..., answer=...)``.
message_template = (
    "\n"
    "User: {question}\n"
    "\n"
    "Assistant: {answer}\n"
)
|
|
|
|
# Embedding model handed to Chroma as its ``embedding_function`` in
# ``get_chroma``. It is instantiated once, when this module is imported,
# configured to run on the CPU and to normalize the embeddings it produces.
model_norm = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-base-en",
    model_kwargs=dict(device='cpu'),
    encode_kwargs=dict(normalize_embeddings=True),
)
|
|
|
|
def get_chroma():
    """Build a Chroma vector store for the 'latest_chat' collection.

    A fresh ``Settings`` object is created on every call, with persistence
    switched on and client resets allowed, and is passed to ``Chroma``
    together with the module-level ``DB_FOLDER`` (persist directory) and
    ``model_norm`` (embedding function). Both module-level names are looked
    up at call time, so they only need to exist before the first call.

    Returns:
        The newly constructed ``Chroma`` instance.
    """
    client_config = Settings()
    client_config.is_persistent = True
    # ChatHistory.reset_history calls reset() on the underlying client,
    # which is why resets are enabled here.
    client_config.allow_reset = True

    store_options = {
        'persist_directory': DB_FOLDER,
        'embedding_function': model_norm,
        'client_settings': client_config,
        'collection_name': 'latest_chat',
    }
    return Chroma(**store_options)
|
|
|
|
# Relative path of the directory passed to Chroma as ``persist_directory``;
# ``ChatHistory.get_context`` also checks for its existence before searching.
DB_FOLDER = 'db'
|
|
|
|
class ChatHistory:
    """Per-user chat memory backed by an in-memory log and a Chroma index.

    ``history`` maps a user id to the list of formatted question/answer
    messages appended for that user, oldest first. Every appended message is
    also split into chunks and added to the Chroma collection returned by
    ``get_chroma()``, tagged with the user id so that it can later be
    retrieved by similarity search for that user only.
    """

    def __init__(self) -> None:
        """Start with no remembered messages for any user."""
        # uid -> list of formatted messages, in the order they were appended
        self.history = {}

    def append(self, question: str, answer: str, uid: str) -> None:
        """Record one question/answer exchange for ``uid`` and index it.

        Args:
            question: The user's message.
            answer: The assistant's reply.
            uid: Identifier of the user the exchange belongs to.
        """
        new_message = message_template.format(question=question, answer=answer)
        self.history.setdefault(uid, []).append(new_message)
        self.embed(new_message, uid)

    def reset_history(self) -> None:
        """Drop the in-memory history of all users and reset the Chroma client.

        NOTE(review): ``reset()`` acts on the whole client, not on a single
        user; per the chromadb documentation it removes all collections.
        Confirm that wiping every user's data is the intended behavior.
        """
        self.history = {}
        # `_client` is a private attribute of the Chroma wrapper; resetting
        # is permitted because get_chroma() sets `allow_reset = True`.
        get_chroma()._client.reset()

    def embed(self, text: str, uid: str) -> None:
        """Split ``text`` into chunks and add them to Chroma tagged with ``uid``.

        Args:
            text: The text to index.
            uid: Identifier stored under the ``"user"`` metadata key of every
                chunk; ``get_context`` filters on this key.
        """
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
        all_splits = text_splitter.split_text(text)
        if not all_splits:
            # Nothing to index; avoid calling the store with an empty batch.
            return

        # One metadata dict per chunk. Passing a single-element list leaves
        # every chunk after the first without the "user" tag, which makes
        # those chunks invisible to the per-user filter in get_context.
        metadatas = [{"user": uid} for _ in all_splits]
        get_chroma().add_texts(all_splits, metadatas=metadatas)

    def get_context(self, question: str, uid: str) -> str:
        """Return stored text for ``uid`` that is most similar to ``question``.

        Args:
            question: Query text for the similarity search.
            uid: Only chunks whose ``"user"`` metadata equals this are searched.

        Returns:
            The page contents of up to four matching documents joined with
            newlines, or ``''`` when the ``DB_FOLDER`` directory does not exist.
        """
        if not os.path.exists(DB_FOLDER):
            return ''
        user_filter = {'user': uid}
        documents = get_chroma().similarity_search(
            query=question, k=4, filter=user_filter
        )
        return '\n'.join(d.page_content for d in documents)

    def get_last_few(self, uid: str) -> str:
        """Return the last three messages recorded for ``uid``.

        Returns:
            The (up to) three most recent formatted messages joined with
            newlines, or ``''`` when nothing is recorded for ``uid``.
        """
        return '\n'.join(self.history.get(uid, [])[-3:])