# from langchain.vectorstores.chroma import Chroma
from langchain_chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain.embeddings.huggingface import HuggingFaceBgeEmbeddings
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from chromadb.config import Settings
import os

# Template used to render one question/answer exchange before it is stored.
message_template = """
User: {question}
Assistant: {answer}
"""

# Embedding model shared by every Chroma handle created in this module.
model_norm = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-base-en",
    model_kwargs={
        'device': 'cpu'
    },
    encode_kwargs={
        'normalize_embeddings': True
    },
)

COLLECTION_NAME = 'latest_chat'
DB_FOLDER = 'db'


def get_chroma() -> Chroma:
    """Return a Chroma vector store bound to ``DB_FOLDER``.

    The client settings enable persistence and allow ``reset()`` so that
    ``ChatHistory.reset_history`` can wipe the database.
    """
    settings = Settings()
    settings.allow_reset = True
    settings.is_persistent = True

    return Chroma(
        persist_directory=DB_FOLDER,
        embedding_function=model_norm,
        client_settings=settings,
        collection_name=COLLECTION_NAME
    )


class ChatHistory:
    """Per-user chat history kept in memory and mirrored into Chroma.

    ``self.history`` maps a user id to the list of rendered exchanges for
    that user; each exchange is also embedded into the vector store so it
    can be retrieved later by similarity.
    """

    def __init__(self) -> None:
        # uid -> list of rendered "User/Assistant" messages, oldest first.
        self.history = {}

    def append(self, question: str, answer: str, uid: str) -> None:
        """Record one exchange for ``uid`` in memory and in the vector store."""
        new_message = message_template.format(
            question=question,
            answer=answer
        )

        self.history.setdefault(uid, []).append(new_message)
        self.embed(new_message, uid)

    def delete(self, uid: str) -> None:
        """Forget everything stored for ``uid`` (memory and vector store)."""
        self.history.pop(uid, None)

        chroma = get_chroma()
        collection = chroma._client.get_collection(name=COLLECTION_NAME)
        collection.delete(where={'user': uid})

    def reset_history(self) -> None:
        """Drop the in-memory history of every user and reset the database."""
        self.history = {}
        chroma = get_chroma()
        chroma._client.reset()

    def embed(self, text: str, uid: str) -> None:
        """Split ``text`` into chunks and store them tagged with ``uid``.

        Every chunk gets its own ``{"user": uid}`` metadata entry.  Passing a
        single-element metadata list for several chunks would leave the
        trailing chunks without the user tag (or be rejected as a length
        mismatch, depending on the vector-store version), which would hide
        them from ``get_context`` and ``delete``.
        """
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
        all_splits = text_splitter.split_text(text)
        if not all_splits:
            # Nothing to store; avoid calling the store with an empty batch.
            return

        chroma = get_chroma()
        chroma.add_texts(
            all_splits,
            metadatas=[{"user": uid} for _ in all_splits],
        )

    def get_context(self, question: str, uid: str) -> str:
        """Return up to four stored chunks of ``uid`` most similar to ``question``.

        The chunks are joined with newlines.  An empty string is returned
        when the database folder does not exist yet.
        """
        if not os.path.exists(DB_FOLDER):
            return ''

        chroma = get_chroma()
        user_filter = {
            'user': uid
        }

        documents = chroma.similarity_search(query=question, k=4, filter=user_filter)
        context = '\n'.join([d.page_content for d in documents])
        return context

    def get_last_few(self, uid: str) -> str:
        """Return the last three in-memory exchanges of ``uid`` joined by newlines.

        Returns an empty string when nothing is recorded for ``uid``.
        """
        return '\n'.join(self.history.get(uid, [])[-3:])