agent2/llm/ollama.py
CounterFire2023 530843d3b0 save tmp
2024-12-16 17:55:49 +08:00

87 lines
2.3 KiB
Python

from ollama import Client
# from langchain.llms.ollama import Ollama
# from langchain.llms import Ollama
from langchain_community.chat_models import ChatOllama
from .chat_history import ChatHistory
from .agent_util import config_agent
import os
# Name of the Ollama model used for both chat models built in get_client().
ollama_model = "llama3.2"

# Disabled host configuration (kept for reference). get_client() does not pass
# a base URL, so ChatOllama falls back to its own default endpoint.
# OLLAMA_HOST = os.environ.get('OLLAMA_BOT_BASE_URL')
# if OLLAMA_HOST is None:
#     raise ValueError("OLLAMA_BOT_BASE_URL env variable is not set")

# System prompt file read by OllamaClient. Being relative, it is resolved
# against the process's current working directory, not this module's folder.
default_system_prompt_path = "./prompts/default.md"
# Accepts a `system` (system prompt) argument.
def get_client(system=''):
    """Create the agent client from a plain-text and a JSON-mode ChatOllama model.

    Both models use ``ollama_model`` with temperature 0; the second is built
    with ``format='json'``. The two models are handed to ``config_agent`` and
    whatever it returns is used as the client.

    Args:
        system: System prompt text. NOTE(review): currently unused - it is not
            forwarded to either model or to ``config_agent``. Confirm where the
            prompt should be injected (likely inside ``config_agent``).

    Returns:
        The object produced by ``config_agent``.
    """
    shared_settings = {"model": ollama_model, "temperature": 0}
    text_model = ChatOllama(**shared_settings)
    json_model = ChatOllama(format='json', **shared_settings)
    return config_agent(text_model, json_model)
def list_models_names():
    """Return the names of models available on the Ollama server.

    Currently a stub that always returns a new empty list. The server query is
    commented out because it relied on an OLLAMA_HOST setting that this module
    does not currently define.
    """
    # Previous implementation, kept for reference:
    #     models = Client(OLLAMA_HOST).list()["models"]
    #     return [m['name'] for m in models]
    return list()
# Plain-text prompt layout for str.format(): an optional context section, an
# optional history section, then the user's question and an open
# "Assistant:" turn. Placeholders: {context}, {history}, {question}.
PROMPT_TEMPLATE = (
    "\n"
    "{context}\n"
    "{history}\n"
    "User: {question}\n"
    "Assistant:\n"
)
class OllamaClient:
    """Chat front end pairing the configured agent with per-user chat history.

    On construction it reads the system prompt from
    ``default_system_prompt_path``, builds the agent via ``get_client`` and
    creates a ``ChatHistory``. ``generate`` answers a question for a user id
    and records the exchange in that history.
    """

    def __init__(self) -> None:
        # Decode explicitly as UTF-8 so loading the prompt does not depend on
        # the platform's locale encoding (e.g. cp1252/GBK on Windows).
        with open(default_system_prompt_path, encoding="utf-8") as file:
            self.system_prompt: str = file.read()
        # NOTE(review): get_client does not currently use the prompt it is
        # given; confirm whether the agent is meant to receive it.
        self.client = get_client(self.system_prompt)
        self.history = ChatHistory()

    def generate(self, question: str, uid) -> str:
        """Answer ``question`` for user ``uid`` and record the exchange.

        Retrieved context and the user's recent messages come from the
        history. Each is passed to the agent with a label, or as an empty
        string when the history returned nothing.

        Args:
            question: The user's message.
            uid: Identifier under which history is looked up and appended.

        Returns:
            The ``'generation'`` entry of the agent's ``invoke`` result.
        """
        context_content = self.history.get_context(question, uid)
        # Truthiness (instead of len() > 0) also tolerates a None result.
        context = f"CONTEXT: {context_content}" if context_content else ''
        last_few = self.history.get_last_few(uid)
        history = f"LAST MESSAGES: {last_few}" if last_few else ''
        # The agent receives the pieces as separate fields. Presumably it
        # assembles its own prompt; confirm in config_agent.
        result = self.client.invoke(
            {"question": question, "context": context, "history": history}
        )
        answer = result['generation']
        # NOTE(review): debug output to stdout; consider the logging module.
        print("answer:", answer)
        self.history.append(question, answer, uid)
        return answer

    def reset_history(self):
        """Delegate to ``ChatHistory.reset_history`` on this client's history."""
        self.history.reset_history()

    def set_model(self, model_name):
        """Assign ``model`` on the agent object.

        NOTE(review): this only sets an attribute on whatever ``config_agent``
        returned; whether the agent reads it is not visible here.
        """
        self.client.model = model_name

    def set_host(self, host):
        """Assign ``base_url`` on the agent object (see note on ``set_model``)."""
        self.client.base_url = host

    def set_system(self, system):
        """Assign ``system`` on the agent object (see note on ``set_model``)."""
        self.client.system = system

    def get_status(self):
        """Return ``(base_url, model, system)`` read from the agent object.

        Raises AttributeError if the agent object lacks any of these
        attributes. NOTE(review): that can happen when the corresponding
        setter has not been called, unless config_agent's result defines them.
        """
        return (
            self.client.base_url,
            self.client.model,
            self.client.system,
        )