CounterFire2023 2024-12-16 17:55:49 +08:00
parent 2b9e07caa3
commit 530843d3b0
25 changed files with 839 additions and 24 deletions

4
.env Normal file

@ -0,0 +1,4 @@
TAVILY_API_KEY=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
OLLAMA_BOT_TOKEN=7965398310:AAF2QDGXRRwDta0eYC7lFSvRDLW7L-SQH_E
OLLAMA_BOT_TOKEN=7802269927:AAHXso9DZe5_LgQ4kN9J5jEAgJpMNQVaYI0
OLLAMA_BOT_CHAT_IDS=5015834404

3
.gitignore vendored

@ -1,3 +1,4 @@
 *.log
 /logs
 **__pycache__**
+db

4
.vscode/launch.json vendored

@ -5,10 +5,10 @@
"version": "0.2.0", "version": "0.2.0",
"configurations": [ "configurations": [
{ {
"name": "Python Debugger: Current File", "name": "Python Debugger: bot",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"program": "${file}", "program": "bot.py",
"console": "integratedTerminal" "console": "integratedTerminal"
} }
] ]

3
.vscode/settings.json vendored Normal file

@ -0,0 +1,3 @@
{
"python.venvPath": "/Users/zhl/miniconda/envs/agent2/bin/python"
}

View File

@ -1,5 +1,7 @@
 ```bash
+ollama pull llama3.2
 conda env create -f env/agent2.yaml
 # conda create -n agent2 python

View File

@ -4,6 +4,8 @@ from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
 from langgraph.graph import END, StateGraph
 from langchain_community.chat_models import ChatOllama
 from langgraph.checkpoint.memory import MemorySaver
+from typing import Annotated
+from operator import add
 # For State Graph
 from typing_extensions import TypedDict
 import json
@ -23,7 +25,7 @@ def config_agent(llama3, llama3_json):
 <|begin_of_text|>
 <|start_header_id|>system<|end_header_id|>
+Your name is AICQ.
 You are a professional token seller.
 We have a total of 1 million CEC, and the average cost of these CECs is 1 US dollar each.
 Now we hope to hand it over to you to be responsible for selling all of them.
@ -121,10 +123,10 @@ def config_agent(llama3, llama3_json):
         send_order: revised question for send order
         context: send_order result
         """
-        question: str
+        question: Annotated[list[str], add]
         generation: str
-        send_order: str
-        context: str
+        send_order: Annotated[list[str], add]
+        context: Annotated[list[str], add]
     # Node - Generate
@ -141,12 +143,14 @@ def config_agent(llama3, llama3_json):
         print("Step: Generating Final Response")
         question = state["question"]
-        context = state.get("context", None)
+        context = state["context"]
         print(context)
+        last_context = context[-1] if context else None
         # TODO: generate the answer based on the specific content of the context
-        if context is not None and context.index("orderinfo") != -1:
-            return {"generation": context.replace("orderinfo:", "")}
+        if last_context is not None and last_context.index("orderinfo") != -1:
+            return {"generation": last_context.replace("orderinfo:", "")}
         else:
+            print(question)
             generation = generate_chain.invoke(
                 {"context": context, "question": question})
             return {"generation": generation}
@ -167,9 +171,9 @@ def config_agent(llama3, llama3_json):
         print("Step: Optimizing Query for Send Order")
         question = state['question']
         gen_query = query_chain.invoke({"question": question})
-        search_query = gen_query["count"]
-        print("send_order", search_query)
-        return {"send_order": search_query}
+        amount = str(gen_query["count"])
+        print("send_order", amount)
+        return {"send_order": [amount]}
     # Node - Send Order
@ -184,13 +188,13 @@ def config_agent(llama3, llama3_json):
         state (dict): Appended Order Info to context
         """
         print("Step: before Send Order")
-        amount = state['send_order']
+        amount = state['send_order'][-1]
         print(amount)
         print(f'Step: build order info for : "{amount}" CEC')
         order_info = {"amount": amount, "price": 0.1,
                       "name": "CEC", "url": "https://www.example.com"}
         search_result = f"orderinfo:{json.dumps(order_info)}"
-        return {"context": search_result}
+        return {"context": [search_result]}
     # Conditional Edge, Routing
@ -234,7 +238,7 @@ def config_agent(llama3, llama3_json):
     workflow.add_edge("generate", END)
     # Compile the workflow
-    memory = MemorySaver()
+    # memory = MemorySaver()
     # local_agent = workflow.compile(checkpointer=memory)
     local_agent = workflow.compile()
     return local_agent
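The state fields above switch from plain strings to `Annotated[list[str], add]`, so LangGraph merges each node's return value into the running list instead of overwriting it; that is why `send_order` and `context` are now returned as one-element lists and read with `[-1]`. A minimal self-contained sketch of that reducer behaviour (the graph, node names and values below are illustrative, not part of the commit):

```python
from operator import add
from typing import Annotated

from langgraph.graph import END, StateGraph
from typing_extensions import TypedDict


class DemoState(TypedDict):
    # `add` is the reducer: each node's output list is concatenated onto the existing list
    context: Annotated[list[str], add]


def node_a(state: DemoState):
    return {"context": ["first"]}


def node_b(state: DemoState):
    return {"context": ["second"]}


graph = StateGraph(DemoState)
graph.add_node("a", node_a)
graph.add_node("b", node_b)
graph.set_entry_point("a")
graph.add_edge("a", "b")
graph.add_edge("b", END)

print(graph.compile().invoke({"context": []}))
# -> {'context': ['first', 'second']}  (without the reducer the second write would replace the first)
```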

2
app.py

@ -9,7 +9,7 @@ local_agent = config_agent(llama3, llama3_json)
 def run_agent(query):
-    output = local_agent.invoke({"question": query})
+    output = local_agent.invoke({"question": [query]})
     print("=======")
     print(output["generation"])
     # display(Markdown(output["generation"]))

52
bot.py Normal file

@ -0,0 +1,52 @@
#!/usr/bin/python3
from dotenv import load_dotenv
load_dotenv()
import os
from rich.traceback import install
install()
from telegram.ext import ApplicationBuilder
from handlers.start import start_handler
from handlers.message_handler import message_handler
from handlers.ollama_host import set_host_handler
from handlers.status import status_handler
from handlers.help import help_handler
import handlers.chat_id as chat_id
import handlers.models as models
import handlers.system as sys_prompt
import handlers.history as history
OLLAMA_BOT_TOKEN = os.environ.get("OLLAMA_BOT_TOKEN")
if OLLAMA_BOT_TOKEN is None:
raise ValueError("OLLAMA_BOT_TOKEN env variable is not set")
if __name__ == "__main__":
application = ApplicationBuilder().token(OLLAMA_BOT_TOKEN).build()
application.add_handler(chat_id.filter_handler, group=-1)
application.add_handlers([
start_handler,
help_handler,
status_handler,
# chat_id.get_handler,
history.reset_handler,
# models.list_models_handler,
models.set_model_handler,
sys_prompt.list_handler,
sys_prompt.set_handler,
sys_prompt.create_new_hanlder,
sys_prompt.remove_handler,
sys_prompt.on_remove_handler,
set_host_handler,
message_handler,
])
application.run_polling()
print("Bot started")

6
env/agent2.yaml vendored

@ -3,6 +3,8 @@ channels:
   - defaults
   - conda-forge
 dependencies:
-  - python=3.10
+  - python=3.9
+  - pip==21
   - pip:
-    - streamlit==1.40.2
+    - setuptools==66.0.0
+    - wheel==0.38.4

View File

@ -3,4 +3,11 @@ langgraph==0.2.2
 langchain-ollama==0.1.1
 langsmith== 0.1.98
 langchain_community==0.2.11
-duckduckgo-search==6.2.13
+langchain_chroma==0.1.4
+# duckduckgo-search==6.2.13
+# need install manually
+# streamlit==1.40.2
+python-dotenv
+python-telegram-bot
+chromadb==0.4.22
+sentence_transformers

25
handlers/chat_id.py Normal file

@ -0,0 +1,25 @@
import os
from telegram import Update
from telegram.ext import CommandHandler, ApplicationHandlerStop, TypeHandler
CHAT_IDS = os.environ.get("OLLAMA_BOT_CHAT_IDS")
valid_user_ids = []
if CHAT_IDS is not None:
valid_user_ids = [int(id) for id in CHAT_IDS.split(',')]
async def get_chat_id(update: Update, _):
chat_id = update.effective_chat.id
await update.message.reply_text(chat_id)
get_handler = CommandHandler("chatid", get_chat_id)
async def filter_chat_id(update: Update, _):
if update.effective_user.id in valid_user_ids or update.message.text == '/chatid':
pass
else:
await update.effective_message.reply_text("access denied")
print(f"access denied for {update.effective_user.id} {update.effective_user.name} {update.effective_user.full_name}")
raise ApplicationHandlerStop()
filter_handler = TypeHandler(Update, filter_chat_id)
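For context on the access check above: registering the `TypeHandler` in group `-1` (as `bot.py` does) makes it run before the group `0` command handlers, and raising `ApplicationHandlerStop` prevents those later handlers from seeing the update at all. A hypothetical standalone sketch of that mechanism (the token and allow-list are placeholders, not values from this commit):

```python
from telegram import Update
from telegram.ext import (ApplicationBuilder, ApplicationHandlerStop,
                          CommandHandler, TypeHandler)

ALLOWED_IDS = {123456789}  # placeholder allow-list, analogous to OLLAMA_BOT_CHAT_IDS


async def gate(update: Update, _):
    # Runs first because it is registered in group -1; stopping here blocks
    # every handler in the default group 0 for this update.
    if update.effective_user and update.effective_user.id not in ALLOWED_IDS:
        raise ApplicationHandlerStop()


async def start(update: Update, _):
    await update.message.reply_text("hello")


app = ApplicationBuilder().token("123456:placeholder-token").build()
app.add_handler(TypeHandler(Update, gate), group=-1)
app.add_handler(CommandHandler("start", start))  # default group 0, only reached if the gate passes
# app.run_polling()  # left commented out: the token above is a placeholder
```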

22
handlers/help.py Normal file

@ -0,0 +1,22 @@
from telegram import Update
from telegram.ext import CommandHandler
help_message="""Welcome to ollama bot here's full list of commands
*Here's a full list of commands:*
/start - quick access to common commands
/status - show current bot status
/chatid - get current chat id
/models - select ollama model to chat with
/systems - select system prompt
/addsystem - create new system prompt
/rmsystem - remove system prompt
/sethost - set ollama host URL
/reset - reset chat history
/help - show this help message
"""
async def help(update: Update, _):
await update.message.reply_text(help_message, parse_mode='markdown')
help_handler = CommandHandler("help", help)

9
handlers/history.py Normal file

@ -0,0 +1,9 @@
from telegram import Update
from telegram.ext import CommandHandler
from llm import clinet
async def history_reset(update: Update, _):
clinet.reset_history()
await update.message.reply_text("history cleared")
reset_handler = CommandHandler("reset", history_reset)

View File

@ -0,0 +1,46 @@
from telegram import Update
from telegram.ext import MessageHandler, filters
from llm import clinet
import re
# simple filter for proper markdown response
def escape_markdown_except_code_blocks(text):
# Pattern to match text outside of code blocks and inline code
pattern = r'(```.*?```|`.*?`)|(_)'
def replace(match):
# If group 1 (code blocks or inline code) is matched, return it unaltered
if match.group(1):
return match.group(1)
# If group 2 (underscore) is matched, escape it
elif match.group(2):
return r'\_'
return re.sub(pattern, replace, text, flags=re.DOTALL)
async def on_message(update: Update, _):
msg = await update.message.reply_text("...")
await update.effective_chat.send_action("TYPING")
user_id = str(update.message.from_user.id)
user_name = update.message.from_user.full_name or update.message.from_user.username
print("msg from:", user_id, user_name)
try:
response = clinet.generate(update.message.text, user_id)
try:
await msg.edit_text(
escape_markdown_except_code_blocks(response),
parse_mode='markdown'
)
except Exception as e:
print(e)
await msg.edit_text(response)
except Exception as e:
print(e)
await msg.edit_text("failed to generate response")
await update.effective_chat.send_action("CANCEL")
message_handler = MessageHandler(filters.TEXT & (~filters.COMMAND), on_message)
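A quick, hypothetical check of the escaping helper above (assuming it is imported from the new `handlers/message_handler.py` module with the project's dependencies installed): underscores inside inline code are left alone, while underscores outside are escaped so Telegram's markdown parser does not treat them as italics markers.

```python
from handlers.message_handler import escape_markdown_except_code_blocks

text = "snake_case outside `code_block` stays intact"
print(escape_markdown_except_code_blocks(text))
# -> snake\_case outside `code_block` stays intact
```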

28
handlers/models.py Normal file

@ -0,0 +1,28 @@
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import CommandHandler, CallbackQueryHandler
from llm import clinet
from llm.ollama import list_models_names
async def list_models(update: Update, _):
model_names = list_models_names()
keybaord = [[InlineKeyboardButton(name, callback_data=f'/setmodel {name}')] for name in model_names]
reply_markup = InlineKeyboardMarkup(keybaord)
await update.message.reply_text(
text='*WARNING:* chat history will be reset\nPick a model:',
reply_markup=reply_markup,
parse_mode='markdown'
)
list_models_handler = CommandHandler("models", list_models)
async def set_model(update: Update, _):
query = update.callback_query
model_name = query.data.split(' ')[1]
await query.answer()
clinet.set_model(model_name)
clinet.reset_history()
await query.edit_message_text(text=f"selected model: {model_name}")
set_model_handler = CallbackQueryHandler(set_model, pattern='^/setmodel')

43
handlers/ollama_host.py Normal file

@ -0,0 +1,43 @@
from telegram import Update
from telegram.ext import filters, CommandHandler, ConversationHandler, MessageHandler
from ollama import Client
from llm import clinet
from rich import print
# Define states
SET_HOST = range(1)
async def set_llama_host(update: Update, _):
await update.message.reply_text(
text='Enter ollama host url like, **http://127.0.0.1:11434**',
parse_mode='markdown'
)
return SET_HOST
async def on_got_host(update: Update, _):
try:
test_client = Client(host=update.message.text)
test_client._client.timeout = 2
test_client.list()
clinet.set_host(update.message.text)
await update.message.reply_text(
f"ollama host is set to: {update.message.text}"
)
except:
await update.message.reply_text(
f"couldn't connect to ollama server at:\n{update.message.text}"
)
finally:
return ConversationHandler.END
async def cancel(update: Update, _):
await update.message.reply_text("canceled /sethost command")
return ConversationHandler.END
set_host_handler = ConversationHandler(
entry_points=[CommandHandler("sethost", set_llama_host)],
states={
SET_HOST: [MessageHandler(filters.TEXT & ~filters.COMMAND, on_got_host)]
},
fallbacks=[CommandHandler('cancel', cancel)]
)

24
handlers/start.py Normal file

@ -0,0 +1,24 @@
from telegram import Update, ReplyKeyboardMarkup
from telegram.ext import CommandHandler
async def start(update: Update, _):
await update.message.reply_text(
"Ask a question to start conversation",
)
keyboard = [
['/status'],
['/models'],
['/systems'],
['/reset']
]
reply_markup = ReplyKeyboardMarkup(
keyboard,
one_time_keyboard=True,
resize_keyboard=True
)
await update.message.reply_text(
"You can also use this commands:",
reply_markup=reply_markup
)
start_handler = CommandHandler("start", start)

13
handlers/status.py Normal file

@ -0,0 +1,13 @@
from telegram import Update
from telegram.ext import CommandHandler
from llm import clinet
async def get_status(update: Update, _):
host, model, system = clinet.get_status()
await update.message.reply_text(
text=f'*ollama_host:* {host}\n\n*model:* {model}\n\n*system prompt:*\n{system}',
parse_mode='markdown'
)
status_handler = CommandHandler("status", get_status)

115
handlers/system.py Normal file

@ -0,0 +1,115 @@
from llm import clinet
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import filters, ContextTypes, CommandHandler, MessageHandler, ConversationHandler, CallbackQueryHandler
import os
SYSTEM_DIR = 'system'
# list system prompts as inline keybard
# clicking a button sets system prompt for current model
async def list_system_prompts(update: Update, _):
files = os.listdir(SYSTEM_DIR)
keybaord = [[InlineKeyboardButton(name, callback_data=f'/setsystem {name}')] for name in files]
reply_markup = InlineKeyboardMarkup(keybaord)
await update.message.reply_text('select system prompt:', reply_markup=reply_markup)
list_handler = CommandHandler("systems", list_system_prompts)
async def on_set_system(update: Update, _):
query = update.callback_query
file_name = query.data.split(' ')[1]
await query.answer()
file_path = os.path.join(SYSTEM_DIR, file_name)
if os.path.exists(file_path):
with open(file_path) as file:
system_prompt = file.read()
clinet.set_system(system_prompt)
await query.edit_message_text(text=f"the system prompt has been set to: {file_name}\n\n{system_prompt}")
set_handler = CallbackQueryHandler(on_set_system, pattern='^/setsystem')
# list system prompts as buttons
# clicking on a button removes selected system prompt
async def remove_system_prompt(update: Update, _):
files = os.listdir(SYSTEM_DIR)
keybaord = [[InlineKeyboardButton(name, callback_data=f'/rmsystem {name}')] for name in files]
reply_markup = InlineKeyboardMarkup(keybaord)
await update.message.reply_text(
'*WARNING:* selected prompt will be removed:',
reply_markup=reply_markup,
parse_mode='markdown'
)
remove_handler = CommandHandler("rmsystem", remove_system_prompt)
async def on_remove_system(update: Update, _):
query = update.callback_query
file_name = query.data.split(' ')[1]
await query.answer()
file_path = os.path.join(SYSTEM_DIR, file_name)
if os.path.exists(file_path):
with open(file_path) as file:
system_prompt = file.read()
clinet.set_system(system_prompt)
os.remove(file_path)
await query.edit_message_text(text=f"{file_name} system prompt has been removed\n\n{system_prompt}")
on_remove_handler = CallbackQueryHandler(on_remove_system, pattern='^/rmsystem')
# Define states
SET_NAME, SET_CONTENT = range(2)
# add new system prompt
# 1. enter name
# 2. enter content
# new prompt will be saved to SYSTEM_DIR
async def add_system(update: Update, _):
await update.message.reply_text(
text="Enter name for a new system prompt:",
)
return SET_NAME
async def on_got_name(update: Update, context: ContextTypes.DEFAULT_TYPE):
file_name = ''.join(x for x in update.message.text if x.isalnum() or x in '._ ')
context.user_data['new_name'] = file_name
if os.path.exists(os.path.join(SYSTEM_DIR, file_name)):
await update.message.reply_text(
text=f'*Warning*, existing **{file_name}** will be rewritten\nenter prompt content to continue or /cancel',
parse_mode='markdown'
)
else:
await update.message.reply_text(
f"New system prompt will be saved to {file_name}, enter prompt:"
)
return SET_CONTENT
async def on_got_content(update: Update, context: ContextTypes.DEFAULT_TYPE):
file_name = context.user_data['new_name']
with open(os.path.join(SYSTEM_DIR, file_name), 'w') as file:
file.write(update.message.text)
await update.message.reply_text(f"{file_name} content saved")
return ConversationHandler.END
async def cancel(update: Update, _):
await update.message.reply_text("canceled /addsystem command")
return ConversationHandler.END
create_new_hanlder = ConversationHandler(
entry_points=[CommandHandler('addsystem', add_system)],
states={
SET_NAME: [MessageHandler(filters.TEXT & ~filters.COMMAND, on_got_name)],
SET_CONTENT: [MessageHandler(filters.TEXT & ~filters.COMMAND, on_got_content)]
},
fallbacks=[CommandHandler('cancel', cancel)]
)

3
llm/__init__.py Normal file

@ -0,0 +1,3 @@
from .ollama import OllamaClient
clinet = OllamaClient()

245
llm/agent_util.py Normal file

@ -0,0 +1,245 @@
# LangChain Dependencies
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langgraph.graph import END, StateGraph
from langchain_community.chat_models import ChatOllama
from langgraph.checkpoint.memory import MemorySaver
from typing import Annotated
from operator import add
# For State Graph
from typing_extensions import TypedDict
import json
def configure_llm(model, temperature):
llama3 = ChatOllama(model=model, temperature=temperature)
llama3_json = ChatOllama(
model=model, format='json', temperature=temperature)
return llama3, llama3_json
def config_agent(llama3, llama3_json):
# Graph State
class GraphState(TypedDict):
"""
Represents the state of our graph.
Attributes:
question: question
generation: LLM generation
send_order: revised question for send order
context: send_order result
"""
question: str
generation: str
order_info: str
context: str
history: str
generate_prompt = PromptTemplate(
template="""
<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
Your name is AICQ.
You are a professional token seller.
We have a total of 1 million CEC, and the average cost of these CECs is $1.1 each, The min price is $0.9 per CEC, There is a 1% chance that the price will be lower than $1 per CEC.
Now we hope to hand it over to you to be responsible for selling all of them.
Customers will come to talk to you and bargain.
You can decide the final transaction price based on the status of the conversation with the customer and sell the tokens.
In addition, there are several rules to follow for this task:
1) The higher the price, the better, in the hope that the final income will be higher than 1 million US dollars;
2) You can set a price higher than 1 US dollar per CEC to sell tokens according to the actual sales situation;
3) For a single customer, we hope that you will try not to sell at a price lower than 1 US dollar, but if the customer has special reasons, you can sell some tokens at a lower price depending on the situation, but in this case you need to sell at a higher price in other sales to ensure that the overall income is not less than 1 million US dollars;
4) We hope that customers buy as much as possible.
Under this condition, if the user is willing to purchase a large number of tokens at one time, we can sell them to him at a more favorable price, but the min price can not lower then $0.9.
The above are all the rules for this sales task.
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
{context}
{history}
Question: {question}
Answer:
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>""",
input_variables=["question", "context", "history"],
)
# Chain
generate_chain = generate_prompt | llama3 | StrOutputParser()
router_prompt = PromptTemplate(
template="""
<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
You are an expert at routing a user question to either the generation stage or send order.
CEC is a famous crypto token, and you are an expert at selling CEC.
Use the generate for questions that user asks for price.
Use the send order for questions that user accepts the price you give and decides to buy a certain amount of CEC.
Otherwise, you can skip and go straight to the generation phase to respond.
You do not need to be stringent with the keywords in the question related to these topics.
Give a binary choice 'send_order' or 'generate' based on the question.
Return the JSON with a single key 'choice' with no premable or explanation.
Question to route: {question}
Context to route: {context}
History to route: {history}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
""",
input_variables=["question", "history", "context"],
)
# Chain
question_router = router_prompt | llama3_json | JsonOutputParser()
order_prompt = PromptTemplate(
template="""
<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
You are an expert at sell CEC,
Return the JSON with a single key 'count' with amount which user want to buy.
Question to transform: {question}
Context to transform: {context}
History to transform: {history}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
""",
input_variables=["question", "context", "history"],
)
# Chain
order_chain = order_prompt | llama3_json | JsonOutputParser()
# Node - Generate
def generate(state):
"""
Generate answer
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, generation, that contains LLM generation
"""
print("Step: Generating Final Response")
question = state["question"]
context = state["context"]
order_info = state.get("order_info", None)
# TODO: generate the answer based on the specific content of the context
# check if context is not None and not empty
if order_info is not None:
return {"generation": order_info}
else:
generation = generate_chain.invoke(
{"context": context, "question": question, "history": state["history"]})
return {"generation": generation}
# Node - Query Transformation
def transform_query(state):
"""
Transform user question to order info
Args:
state (dict): The current graph state
Returns:
state (dict): Appended amount of CEC to context
"""
print("Step: Optimizing Query for Send Order")
question = state['question']
gen_query = order_chain.invoke({"question": question, "history": state["history"], "context": state["context"]})
amount = str(gen_query["count"])
print("order_info", amount)
return {"order_info": [amount] }
# Node - Send Order
def send_order(state):
"""
Send order based on the question
Args:
state (dict): The current graph state
Returns:
state (dict): Appended Order Info to context
"""
print("Step: before Send Order")
amount = state['order_info']
print(amount)
print(f'Step: build order info for : "{amount}" CEC')
order_info = {"amount": amount, "price": 0.1,
"name": "CEC", "url": "https://www.example.com"}
order_result = json.dumps(order_info)
return {"order_info": order_result}
# Conditional Edge, Routing
def route_question(state):
"""
route question to send order or generation.
Args:
state (dict): The current graph state
Returns:
str: Next node to call
"""
print("Step: Routing Query")
question = state['question']
output = question_router.invoke({"question": question, "history": state["history"], "context": state["context"]})
if output['choice'] == "send_order":
print("Step: Routing Query to Send Order")
return "sendorder"
elif output['choice'] == 'generate':
print("Step: Routing Query to Generation")
return "generate"
# Build the nodes
workflow = StateGraph(GraphState)
workflow.add_node("sendorder", send_order)
workflow.add_node("transform_query", transform_query)
workflow.add_node("generate", generate)
# Build the edges
workflow.set_conditional_entry_point(
route_question,
{
"sendorder": "transform_query",
"generate": "generate",
},
)
workflow.add_edge("transform_query", "sendorder")
workflow.add_edge("sendorder", "generate")
workflow.add_edge("generate", END)
# Compile the workflow
# memory = MemorySaver()
# local_agent = workflow.compile(checkpointer=memory)
local_agent = workflow.compile()
return local_agent
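A hypothetical smoke test for the graph defined above, assuming a local Ollama server with the `llama3.2` model pulled. It supplies the three state keys the router and generator read (`question`, `context`, `history`), mirroring how `OllamaClient.generate` invokes the agent:

```python
from llm.agent_util import config_agent, configure_llm

# build the plain and JSON-mode ChatOllama instances, then compile the graph
llama3, llama3_json = configure_llm(model="llama3.2", temperature=0)
agent = config_agent(llama3, llama3_json)

out = agent.invoke({
    "question": "How much would 100 CEC cost me?",
    "context": "",
    "history": "",
})
print(out["generation"])
```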

75
llm/chat_history.py Normal file

@ -0,0 +1,75 @@
# from langchain.vectorstores.chroma import Chroma
from langchain_chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain.embeddings.huggingface import HuggingFaceBgeEmbeddings
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from chromadb.config import Settings
import os
message_template = """
User: {question}
Assistant: {answer}
"""
model_norm = HuggingFaceBgeEmbeddings(
model_name="BAAI/bge-base-en",
model_kwargs={ 'device': 'cpu' },
encode_kwargs={ 'normalize_embeddings': True },
)
def get_chroma():
settings = Settings()
settings.allow_reset = True
settings.is_persistent = True
return Chroma(
persist_directory=DB_FOLDER,
embedding_function=model_norm,
client_settings=settings,
collection_name='latest_chat'
)
DB_FOLDER = 'db'
class ChatHistory:
def __init__(self) -> None:
self.history = {}
def append(self, question: str, answer: str, uid: str):
new_message = message_template.format(
question=question,
answer=answer
)
if uid not in self.history:
self.history[uid] = []
self.history[uid].append(new_message)
self.embed(new_message, uid)
def reset_history(self):
self.history = {}
chroma = get_chroma()
chroma._client.reset()
def embed(self, text: str, uid: str):
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = text_splitter.split_text(text)
chroma = get_chroma()
chroma.add_texts(all_splits, metadatas=[{"user": uid}])
def get_context(self, question: str, uid: str):
if not os.path.exists(DB_FOLDER):
return ''
chroma = get_chroma()
filter = {
'user': uid
}
documents = chroma.similarity_search(query=question, k=4, filter=filter)
context = '\n'.join([d.page_content for d in documents])
return context
def get_last_few(self, uid: str):
if uid not in self.history:
return ''
return '\n'.join(self.history[uid][-3:])
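Hypothetical usage of `ChatHistory` (not part of the commit; instantiating it downloads the `BAAI/bge-base-en` embedding model, and `append` writes to the local `db` folder): each turn is kept in memory for `get_last_few` and embedded into Chroma for `get_context` similarity search, filtered per user.

```python
from llm.chat_history import ChatHistory

history = ChatHistory()
history.append("What is CEC?", "CEC is a crypto token.", uid="42")  # stored in memory and embedded
print(history.get_last_few("42"))                  # last 3 raw turns for this user
print(history.get_context("price of CEC", "42"))   # similarity-searched snippets for this user
```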

86
llm/ollama.py Normal file

@ -0,0 +1,86 @@
from ollama import Client
# from langchain.llms.ollama import Ollama
# from langchain.llms import Ollama
from langchain_community.chat_models import ChatOllama
from .chat_history import ChatHistory
from .agent_util import config_agent
import os
ollama_model = 'llama3.2'
# OLLAMA_HOST = os.environ.get('OLLAMA_BOT_BASE_URL')
# if OLLAMA_HOST is None:
# raise ValueError("OLLAMA_BOT_BASE_URL env variable is not set")
default_system_prompt_path = './prompts/default.md'
# pass in the system prompt via the system parameter
def get_client(system=''):
llama3 = ChatOllama(model=ollama_model, temperature=0)
llama3_json = ChatOllama(
model=ollama_model,
format='json',
temperature=0
)
client = config_agent(llama3, llama3_json)
return client
def list_models_names():
# models = Client(OLLAMA_HOST).list()["models"]
# return [m['name'] for m in models]
return []
PROMPT_TEMPLATE = """
{context}
{history}
User: {question}
Assistant:
"""
class OllamaClient:
def __init__(self) -> None:
with open(default_system_prompt_path) as file:
self.system_prompt = file.read()
self.client = get_client(self.system_prompt)
self.history = ChatHistory()
def generate(self, question, uid) -> str:
context_content = self.history.get_context(question, uid)
context = f"CONTEXT: {context_content}" if len(context_content) > 0 else ''
last_few = self.history.get_last_few(uid)
history = f"LAST MESSAGES: {last_few}" if len(last_few) > 0 else ''
prompt = PROMPT_TEMPLATE.format(
context=context,
history=history,
question=question
)
answer = self.client.invoke({"question": question, "context": context, "history": history})['generation']
print("answer:", answer)
self.history.append(question, answer, uid)
return answer
def reset_history(self):
self.history.reset_history()
def set_model(self, model_name):
self.client.model = model_name
def set_host(self, host):
self.client.base_url = host
def set_system(self, system):
self.client.system = system
def get_status(self):
return (
self.client.base_url,
self.client.model,
self.client.system
)

6
prompts/default.md Normal file

@ -0,0 +1,6 @@
You're a helpful assistant.
Your goal is to help the user with their questions
If there's a previous conversation you'll be provided a context
If you don't know an answer to a given USER question don't imagine anything
just say you don't know

View File

@ -20,10 +20,10 @@ local_agent = config_agent(llama3, llama3_json)
 def run_agent(query):
-    # config = {"configurable": {"thread_id": "1"}}
-    # output = local_agent.invoke({"question": query}, config)
-    # print(list(local_agent.get_state_history(config)))
-    output = local_agent.invoke({"question": query})
+    config = {"configurable": {"thread_id": "1", "user_id": "1"}}
+    output = local_agent.invoke({"question": [query]}, config)
+    print(list(local_agent.get_state_history(config)))
+    # output = local_agent.invoke({"question": ['hi, my name is cz', query]})
     print("=======")
     return output["generation"]