88 lines
2.3 KiB
Python
88 lines
2.3 KiB
Python
from ollama import Client
|
|
# from langchain.llms.ollama import Ollama
|
|
# from langchain.llms import Ollama
|
|
from langchain_community.chat_models import ChatOllama
|
|
from .db_history import ChatHistory
|
|
from .agent_util import config_agent
|
|
import os
|
|
|
|
# Ollama model name shared by both chat models built in get_client().
ollama_model = "llama3.2"

# Disabled: reading the Ollama host from the environment.
# OLLAMA_HOST = os.environ.get('OLLAMA_BOT_BASE_URL')
# if OLLAMA_HOST is None:
#     raise ValueError("OLLAMA_BOT_BASE_URL env variable is not set")

# System prompt file opened by OllamaClient.__init__. The path is relative,
# so it is resolved against the process's current working directory.
default_system_prompt_path = "./prompts/default.md"
|
|
|
|
# Pass in the system parameter
|
|
def get_client(system=''):
    """Build the agent client from two ChatOllama models.

    A plain chat model and a JSON-formatted, temperature-0 model are created
    for the module-level ``ollama_model`` and handed to ``config_agent``;
    whatever ``config_agent`` returns is returned unchanged.

    Args:
        system: Accepted for interface compatibility. It is not used inside
            this function.

    Returns:
        The object produced by ``config_agent``.
    """
    # NOTE(review): `system` is never forwarded to the models or the agent --
    # confirm whether config_agent is expected to receive it.
    chat_model = ChatOllama(model=ollama_model)
    json_model = ChatOllama(model=ollama_model, format='json', temperature=0)
    return config_agent(chat_model, json_model)
|
|
|
|
|
|
def list_models_names():
    """Return the known model names; currently always an empty list.

    The server query below is disabled, so nothing is contacted.
    """
    # Disabled server query, kept for reference:
    # models = Client(OLLAMA_HOST).list()["models"]
    # return [m['name'] for m in models]
    model_names = []
    return model_names
|
|
|
|
|
|
# Prompt layout with three placeholders: {context}, {history} and {question}.
# Written line by line so the blank separator lines are explicit.
PROMPT_TEMPLATE = (
    "\n"
    "{context}\n"
    "\n"
    "{history}\n"
    "\n"
    "User: {question}\n"
    "\n"
    "Assistant:\n"
)
|
|
|
|
class OllamaClient:
    """Chat front end pairing an agent client with a chat-history store.

    The agent is whatever ``get_client`` returns and the history store is a
    ``ChatHistory``. ``generate`` performs one question/answer round trip
    using both.
    """

    def __init__(self) -> None:
        """Read the default system prompt, then create the agent and history.

        Raises:
            OSError: If the file at ``default_system_prompt_path`` cannot be
                opened.
        """
        # Decode explicitly so the prompt text does not depend on the
        # platform's locale encoding (assumes the file is UTF-8 -- confirm).
        with open(default_system_prompt_path, encoding="utf-8") as file:
            self.system_prompt = file.read()

        self.client = get_client(self.system_prompt)
        self.history = ChatHistory()

    def generate(self, question, uid) -> str:
        """Answer ``question`` for ``uid`` and record the exchange.

        Args:
            question: The user's message.
            uid: Identifier handed to the history store for lookups and for
                saving the new exchange.

        Returns:
            The ``'generation'`` entry of the agent's response.
        """
        context_content = self.history.get_context(question, uid)
        # Truthiness also covers a None result, which len() would reject.
        context = f"CONTEXT: {context_content}" if context_content else ''

        last_few = self.history.get_last_few(uid)
        history = f"LAST MESSAGES: {last_few}" if last_few else ''

        # The agent receives the pieces separately; no combined prompt string
        # is needed at this level.
        response = self.client.invoke(
            {"question": question, "context": context, "history": history}
        )
        answer = response['generation']
        print("answer:", answer)
        self.history.append(question, answer, uid)
        return answer

    def reset_history(self, uid) -> None:
        """Delete the stored history for ``uid``."""
        self.history.delete(uid)

    # NOTE(review): the three setters below assign attributes on the object
    # returned by get_client(). Whether that object honours `model`,
    # `base_url` or `system` depends on config_agent -- confirm.

    def set_model(self, model_name) -> None:
        """Store ``model_name`` on the agent's ``model`` attribute."""
        self.client.model = model_name

    def set_host(self, host) -> None:
        """Store ``host`` on the agent's ``base_url`` attribute."""
        self.client.base_url = host

    def set_system(self, system) -> None:
        """Store ``system`` on the agent's ``system`` attribute."""
        self.client.system = system

    def get_status(self):
        """Return ``(base_url, model, system)`` as read from the agent.

        Raises:
            AttributeError: If the agent object lacks any of the three
                attributes.
        """
        agent = self.client
        return (agent.base_url, agent.model, agent.system)
|