# agent2/llm/db_history.py
# 2024-12-18 11:58:04 +08:00
#
# 59 lines
# 1.7 KiB
# Python

from pymongo import MongoClient
import os
from datetime import datetime
# Text form of one question/answer exchange; filled via str.format with
# the keys ``question`` and ``answer``.
message_template = "\nUser: {question}\nAnswer: {answer}\n"

# MongoDB collection that stores every exchange.
COLLECTION_NAME = "chat_history"
class ChatHistory:
    """Persist per-user question/answer exchanges in MongoDB.

    Every exchange is inserted into the ``COLLECTION_NAME`` collection of the
    database named by the ``DB_NAME`` environment variable, on the server
    given by ``DB_MAIN``. A small in-memory cache of the most recent
    formatted messages is also kept per user for ``get_last_few``. That
    cache is not loaded from the database, so it starts empty for every new
    ``ChatHistory`` instance.
    """

    # Formatted messages kept in memory per user; get_last_few returns at
    # most this many, so anything older is dropped.
    _CACHE_SIZE = 3
    # Number of most recent stored exchanges included by get_context.
    _CONTEXT_SIZE = 10

    def __init__(self) -> None:
        """Connect to MongoDB using ``DB_MAIN``/``DB_NAME`` and index ``user``."""
        # uid -> newest formatted messages, oldest first.
        self.history: dict[str, list[str]] = {}
        db_url = os.environ.get('DB_MAIN')
        db_name = os.environ.get('DB_NAME')
        self.db_client = MongoClient(db_url)
        # NOTE(review): if DB_NAME is unset, db_name is None here and the
        # lookup presumably fails inside pymongo; consider validating it.
        self.db = self.db_client[db_name]
        self.collection = self.db[COLLECTION_NAME]
        # Every query in this class filters on 'user'.
        self.collection.create_index('user')

    def _remember(self, uid: str, message: str) -> None:
        """Cache *message* for *uid*, keeping only the newest ``_CACHE_SIZE``."""
        cached = self.history.setdefault(uid, [])
        cached.append(message)
        # Only the tail is ever read, so trim to avoid unbounded growth in a
        # long-running process.
        del cached[:-self._CACHE_SIZE]

    def append(self, question: str, answer: str, uid: str) -> None:
        """Record one exchange for *uid* in memory and in MongoDB.

        Args:
            question: The user's question.
            answer: The answer that was given.
            uid: Identifier of the user the exchange belongs to.

        Side effects:
            Inserts one document with keys ``user``, ``question``, ``answer``
            and ``created_at`` (Unix time in whole seconds).
        """
        self._remember(
            uid, message_template.format(question=question, answer=answer)
        )
        self.collection.insert_one({
            'user': uid,
            'question': question,
            'answer': answer,
            'created_at': int(datetime.now().timestamp()),
        })

    def delete(self, uid: str) -> None:
        """Forget every exchange for *uid*, both cached and stored."""
        self.history.pop(uid, None)
        self.collection.delete_many({'user': uid})

    def get_context(self, question: str, uid: str) -> str:
        """Return the user's most recent stored exchanges as prompt text.

        Args:
            question: Currently unused; kept for interface compatibility.
            uid: Identifier of the user whose history is read.

        Returns:
            Up to ``_CONTEXT_SIZE`` most recent exchanges, each rendered with
            ``message_template``, in chronological order (oldest first) and
            joined by newlines. Returns ``''`` when nothing is stored.
        """
        # Newest-first lets the server apply the limit to the latest records.
        cursor = (
            self.collection.find({'user': uid})
            .sort('_id', -1)
            .limit(self._CONTEXT_SIZE)
        )
        records = list(cursor)
        # Present them oldest-first, matching get_last_few and reading
        # order for a conversation.
        records.reverse()
        return '\n'.join(
            message_template.format(
                question=record['question'],
                answer=record['answer'],
            )
            for record in records
        )

    def get_last_few(self, uid: str) -> str:
        """Return the user's cached recent messages (oldest first), or ``''``."""
        return '\n'.join(self.history.get(uid, [])[-self._CACHE_SIZE:])