agent2/llm/chat_history.py
CounterFire2023 5ee0dd9a19 优化prompt
2024-12-16 21:49:38 +08:00

84 lines
2.4 KiB
Python

# from langchain.vectorstores.chroma import Chroma
from langchain_chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain.embeddings.huggingface import HuggingFaceBgeEmbeddings
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from chromadb.config import Settings
import os
# Template used to render one question/answer exchange as plain text.
# The {question} and {answer} placeholders are filled in with str.format.
message_template = """
User: {question}
Assistant: {answer}
"""
# Embedding model built once when the module is imported and handed to Chroma
# as its embedding function. It is configured to run on the CPU and to
# normalize the embeddings it produces.
model_norm = HuggingFaceBgeEmbeddings(
model_name="BAAI/bge-base-en",
model_kwargs={ 'device': 'cpu' },
encode_kwargs={ 'normalize_embeddings': True },
)
# Name of the Chroma collection this module reads from and writes to.
COLLECTION_NAME = 'latest_chat'
def get_chroma():
    """Build a Chroma vector store bound to this module's collection.

    A new ``Settings`` object is created on every call with persistence and
    reset both switched on. It is passed to ``Chroma`` together with the
    module-level embedding model, persist directory and collection name.

    Returns:
        The newly constructed ``Chroma`` instance.
    """
    client_settings = Settings()
    client_settings.is_persistent = True
    # Reset is enabled here because ChatHistory.reset_history calls
    # reset() on the underlying client.
    client_settings.allow_reset = True
    return Chroma(
        collection_name=COLLECTION_NAME,
        client_settings=client_settings,
        embedding_function=model_norm,
        persist_directory=DB_FOLDER,
    )
# Directory given to Chroma as its persist directory. get_context also checks
# whether this path exists before querying.
DB_FOLDER = 'db'
class ChatHistory:
    """Per-user chat history kept in memory and mirrored into Chroma.

    Every question/answer exchange is rendered with ``message_template``,
    appended to an in-memory list keyed by user id, and embedded into the
    Chroma collection with ``{"user": uid}`` metadata so that it can later be
    searched or deleted per user.
    """

    def __init__(self) -> None:
        # Maps a user id to that user's rendered exchanges, oldest first.
        self.history = {}

    @staticmethod
    def _user_metadatas(uid: str, count: int) -> list:
        """Return ``count`` independent ``{"user": uid}`` metadata dicts."""
        return [{"user": uid} for _ in range(count)]

    def append(self, question: str, answer: str, uid: str) -> None:
        """Record one exchange for ``uid`` and embed it into Chroma.

        Args:
            question: The user's message.
            answer: The assistant's reply.
            uid: Identifier of the user the exchange belongs to.
        """
        new_message = message_template.format(question=question, answer=answer)
        self.history.setdefault(uid, []).append(new_message)
        self.embed(new_message, uid)

    def delete(self, uid: str) -> None:
        """Forget ``uid``: drop the in-memory list and the stored documents.

        Args:
            uid: Identifier of the user whose history is removed.
        """
        self.history.pop(uid, None)
        chroma = get_chroma()
        # NOTE(review): relies on the wrapper's private ``_client`` attribute.
        collection = chroma._client.get_collection(name=COLLECTION_NAME)
        collection.delete(where={'user': uid})

    def reset_history(self) -> None:
        """Clear the in-memory history of every user and reset the client."""
        self.history = {}
        chroma = get_chroma()
        # NOTE(review): relies on the wrapper's private ``_client`` attribute.
        chroma._client.reset()

    def embed(self, text: str, uid: str) -> None:
        """Split ``text`` into chunks and store them in Chroma for ``uid``.

        Args:
            text: The text to split (chunk size 500, no overlap) and embed.
            uid: Identifier stored under the ``user`` metadata key of every
                chunk.
        """
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
        all_splits = text_splitter.split_text(text)
        if not all_splits:
            # Nothing to store; avoid handing Chroma an empty batch.
            return
        chroma = get_chroma()
        # One metadata dict per chunk. A single-entry list would tag only the
        # first chunk, leaving the remaining chunks of a long message without
        # a user id and therefore invisible to the per-user filter used by
        # get_context and delete.
        chroma.add_texts(
            all_splits,
            metadatas=self._user_metadatas(uid, len(all_splits)),
        )

    def get_context(self, question: str, uid: str) -> str:
        """Return stored text for ``uid`` that is most similar to ``question``.

        Args:
            question: Query text for the similarity search.
            uid: Only documents whose ``user`` metadata equals this id are
                considered.

        Returns:
            The contents of up to four matching documents joined by newlines,
            or an empty string when ``DB_FOLDER`` does not exist.
        """
        if not os.path.exists(DB_FOLDER):
            return ''
        chroma = get_chroma()
        user_filter = {'user': uid}
        documents = chroma.similarity_search(query=question, k=4, filter=user_filter)
        return '\n'.join(d.page_content for d in documents)

    def get_last_few(self, uid: str) -> str:
        """Return the last three in-memory exchanges for ``uid``.

        Args:
            uid: Identifier of the user.

        Returns:
            The most recent exchanges (at most three) joined by newlines, or
            an empty string when nothing is recorded for ``uid``.
        """
        return '\n'.join(self.history.get(uid, [])[-3:])