save tmp
parent 2b9e07caa3
commit 530843d3b0

.env (new file, 4 lines)

TAVILY_API_KEY=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
OLLAMA_BOT_TOKEN=7965398310:AAF2QDGXRRwDta0eYC7lFSvRDLW7L-SQH_E
OLLAMA_BOT_TOKEN=7802269927:AAHXso9DZe5_LgQ4kN9J5jEAgJpMNQVaYI0
OLLAMA_BOT_CHAT_IDS=5015834404

.gitignore (vendored, 3 changes)
@@ -1,3 +1,4 @@
 *.log
 /logs
-**__pycache__**
+**__pycache__**
+db

.vscode/launch.json (vendored, 4 changes)
@@ -5,10 +5,10 @@
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "Python Debugger: Current File",
+            "name": "Python Debugger: bot",
             "type": "debugpy",
             "request": "launch",
-            "program": "${file}",
+            "program": "bot.py",
             "console": "integratedTerminal"
         }
     ]

.vscode/settings.json (new file, vendored, 3 lines)

{
    "python.venvPath": "/Users/zhl/miniconda/envs/agent2/bin/python"
}

@@ -1,5 +1,7 @@

 ```bash
 ollama pull llama3.2
+
+conda env create -f env/agent2.yaml
 # conda create -n agent2 python


@@ -4,6 +4,8 @@ from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
 from langgraph.graph import END, StateGraph
 from langchain_community.chat_models import ChatOllama
 from langgraph.checkpoint.memory import MemorySaver
+from typing import Annotated
+from operator import add
 # For State Graph
 from typing_extensions import TypedDict
 import json
@@ -23,7 +25,7 @@ def config_agent(llama3, llama3_json):
 <|begin_of_text|>

 <|start_header_id|>system<|end_header_id|>

 Your name is AICQ.
 You are a professional token seller.
 We have a total of 1 million CEC, and the average cost of these CECs is 1 US dollar each.
 Now we hope to hand it over to you to be responsible for selling all of them.
@@ -121,10 +123,10 @@ def config_agent(llama3, llama3_json):
             send_order: revised question for send order
             context: send_order result
         """
-        question: str
+        question: Annotated[list[str], add]
         generation: str
-        send_order: str
-        context: str
+        send_order: Annotated[list[str], add]
+        context: Annotated[list[str], add]

     # Node - Generate

@@ -141,12 +143,14 @@ def config_agent(llama3, llama3_json):

         print("Step: Generating Final Response")
         question = state["question"]
-        context = state.get("context", None)
+        context = state["context"]
+        print(context)
+        last_context = context[-1] if context else None
         # TODO:: generate the answer based on the specific content of the context
-        if context is not None and context.index("orderinfo") != -1:
-            return {"generation": context.replace("orderinfo:", "")}
+        if last_context is not None and last_context.index("orderinfo") != -1:
+            return {"generation": last_context.replace("orderinfo:", "")}
         else:
             print(question)
             generation = generate_chain.invoke(
                 {"context": context, "question": question})
             return {"generation": generation}
@@ -167,9 +171,9 @@ def config_agent(llama3, llama3_json):
         print("Step: Optimizing Query for Send Order")
         question = state['question']
         gen_query = query_chain.invoke({"question": question})
-        search_query = gen_query["count"]
-        print("send_order", search_query)
-        return {"send_order": search_query}
+        amount = str(gen_query["count"])
+        print("send_order", amount)
+        return {"send_order": [amount]}

     # Node - Send Order

@@ -184,13 +188,13 @@ def config_agent(llama3, llama3_json):
             state (dict): Appended Order Info to context
         """
         print("Step: before Send Order")
-        amount = state['send_order']
+        amount = state['send_order'][-1]
         print(amount)
         print(f'Step: build order info for : "{amount}" CEC')
         order_info = {"amount": amount, "price": 0.1,
                       "name": "CEC", "url": "https://www.example.com"}
         search_result = f"orderinfo:{json.dumps(order_info)}"
-        return {"context": search_result}
+        return {"context": [search_result]}

     # Conditional Edge, Routing

@@ -234,7 +238,7 @@ def config_agent(llama3, llama3_json):
     workflow.add_edge("generate", END)

     # Compile the workflow
-    memory = MemorySaver()
+    # memory = MemorySaver()
     # local_agent = workflow.compile(checkpointer=memory)
     local_agent = workflow.compile()
     return local_agent
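
The core of the hunks above is that several GraphState fields switch from plain strings to Annotated[list[str], add]: LangGraph uses the second Annotated argument as a reducer, so values returned by different nodes are appended instead of overwriting each other, which is also why the nodes now return single-element lists and why callers wrap the question in a list (see app.py below). A minimal, self-contained sketch of that pattern, with illustrative names that are not part of this commit:

```python
# Sketch of the Annotated-reducer pattern; node and field names are made up for illustration.
from operator import add
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, END

class State(TypedDict):
    context: Annotated[list[str], add]   # reducer: node outputs are appended
    generation: str                      # plain field: the last write wins

def step_one(state: State):
    return {"context": ["first result"]}      # wrapped in a list so the reducer can append

def step_two(state: State):
    return {"context": ["second result"],
            "generation": " / ".join(state["context"])}

graph = StateGraph(State)
graph.add_node("one", step_one)
graph.add_node("two", step_two)
graph.set_entry_point("one")
graph.add_edge("one", "two")
graph.add_edge("two", END)
app = graph.compile()
print(app.invoke({"context": [], "generation": ""}))
# "context" accumulates to ["first result", "second result"]
```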

app.py (2 changes)
@@ -9,7 +9,7 @@ local_agent = config_agent(llama3, llama3_json)


 def run_agent(query):
-    output = local_agent.invoke({"question": query})
+    output = local_agent.invoke({"question": [query]})
     print("=======")
     print(output["generation"])
     # display(Markdown(output["generation"]))

bot.py (new file, 52 lines)

#!/usr/bin/python3
from dotenv import load_dotenv
load_dotenv()
import os
from rich.traceback import install
install()

from telegram.ext import ApplicationBuilder

from handlers.start import start_handler
from handlers.message_handler import message_handler
from handlers.ollama_host import set_host_handler
from handlers.status import status_handler
from handlers.help import help_handler

import handlers.chat_id as chat_id
import handlers.models as models
import handlers.system as sys_prompt
import handlers.history as history


OLLAMA_BOT_TOKEN = os.environ.get("OLLAMA_BOT_TOKEN")
if OLLAMA_BOT_TOKEN is None:
    raise ValueError("OLLAMA_BOT_TOKEN env variable is not set")

if __name__ == "__main__":
    application = ApplicationBuilder().token(OLLAMA_BOT_TOKEN).build()

    application.add_handler(chat_id.filter_handler, group=-1)

    application.add_handlers([
        start_handler,
        help_handler,
        status_handler,

        # chat_id.get_handler,
        history.reset_handler,

        # models.list_models_handler,
        models.set_model_handler,

        sys_prompt.list_handler,
        sys_prompt.set_handler,
        sys_prompt.create_new_hanlder,
        sys_prompt.remove_handler,
        sys_prompt.on_remove_handler,
        set_host_handler,
        message_handler,
    ])

    application.run_polling()
    print("Bot started")
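
bot.py registers chat_id.filter_handler in group -1. In python-telegram-bot, handlers in lower-numbered groups run first, and raising ApplicationHandlerStop stops the update from reaching any later group, which is what turns that handler into a whitelist gate for every command and message below it. A stripped-down sketch of just that mechanism (the token and id are placeholders, not repository values):

```python
from telegram import Update
from telegram.ext import ApplicationBuilder, ApplicationHandlerStop, TypeHandler

ALLOWED_IDS = {5015834404}  # example value taken from OLLAMA_BOT_CHAT_IDS above

async def gate(update: Update, _):
    # Runs before every group-0 handler; raising here hides the update from them.
    if update.effective_user and update.effective_user.id not in ALLOWED_IDS:
        raise ApplicationHandlerStop()

app = ApplicationBuilder().token("BOT_TOKEN_HERE").build()
app.add_handler(TypeHandler(Update, gate), group=-1)
# ...register the normal command/message handlers in the default group, then app.run_polling()
```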

env/agent2.yaml (vendored, 6 changes)
@@ -3,6 +3,8 @@ channels:
   - defaults
   - conda-forge
 dependencies:
-  - python=3.10
+  - python=3.9
+  - pip==21
   - pip:
-    - streamlit==1.40.2
+    - setuptools==66.0.0
+    - wheel==0.38.4

env/requirements.txt (vendored, 9 changes)
@@ -3,4 +3,11 @@ langgraph==0.2.2
 langchain-ollama==0.1.1
 langsmith== 0.1.98
 langchain_community==0.2.11
-duckduckgo-search==6.2.13
+langchain_chroma==0.1.4
+# duckduckgo-search==6.2.13
+# need install manually
+# streamlit==1.40.2
+python-dotenv
+python-telegram-bot
+chromadb==0.4.22
+sentence_transformers

handlers/chat_id.py (new file, 25 lines)

import os
from telegram import Update
from telegram.ext import CommandHandler, ApplicationHandlerStop, TypeHandler

CHAT_IDS = os.environ.get("OLLAMA_BOT_CHAT_IDS")
valid_user_ids = []
if CHAT_IDS is not None:
    valid_user_ids = [int(id) for id in CHAT_IDS.split(',')]


async def get_chat_id(update: Update, _):
    chat_id = update.effective_chat.id
    await update.message.reply_text(chat_id)

get_handler = CommandHandler("chatid", get_chat_id)

async def filter_chat_id(update: Update, _):
    if update.effective_user.id in valid_user_ids or update.message.text == '/chatid':
        pass
    else:
        await update.effective_message.reply_text("access denied")
        print(f"access denied for {update.effective_user.id} {update.effective_user.name} {update.effective_user.full_name}")
        raise ApplicationHandlerStop()

filter_handler = TypeHandler(Update, filter_chat_id)

handlers/help.py (new file, 22 lines)

from telegram import Update
from telegram.ext import CommandHandler

help_message="""Welcome to ollama bot here's full list of commands

*Here's a full list of commands:*
/start - quick access to common commands
/status - show current bot status
/chatid - get current chat id
/models - select ollama model to chat with
/systems - select system prompt
/addsystem - create new system prompt
/rmsystem - remove system prompt
/sethost - set ollama host URL
/reset - reset chat history
/help - show this help message
"""

async def help(update: Update, _):
    await update.message.reply_text(help_message, parse_mode='markdown')

help_handler = CommandHandler("help", help)

handlers/history.py (new file, 9 lines)

from telegram import Update
from telegram.ext import CommandHandler
from llm import clinet

async def history_reset(update: Update, _):
    clinet.reset_history()
    await update.message.reply_text("history cleared")

reset_handler = CommandHandler("reset", history_reset)

handlers/message_handler.py (new file, 46 lines)

from telegram import Update
from telegram.ext import MessageHandler, filters
from llm import clinet
import re

# simple filter for proper markdown response
def escape_markdown_except_code_blocks(text):
    # Pattern to match text outside of code blocks and inline code
    pattern = r'(```.*?```|`.*?`)|(_)'

    def replace(match):
        # If group 1 (code blocks or inline code) is matched, return it unaltered
        if match.group(1):
            return match.group(1)
        # If group 2 (underscore) is matched, escape it
        elif match.group(2):
            return r'\_'

    return re.sub(pattern, replace, text, flags=re.DOTALL)


async def on_message(update: Update, _):
    msg = await update.message.reply_text("...")
    await update.effective_chat.send_action("TYPING")
    user_id = str(update.message.from_user.id)
    user_name = update.message.from_user.full_name or update.message.from_user.username
    print("msg from:", user_id, user_name)
    try:
        response = clinet.generate(update.message.text, user_id)
        try:
            await msg.edit_text(
                escape_markdown_except_code_blocks(response),
                parse_mode='markdown'
            )
        except Exception as e:
            print(e)
            await msg.edit_text(response)
    except Exception as e:
        print(e)
        await msg.edit_text("failed to generate response")

    await update.effective_chat.send_action("CANCEL")

message_handler = MessageHandler(filters.TEXT & (~filters.COMMAND), on_message)
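
The escape helper above only touches underscores that sit outside inline code and fenced blocks, so Telegram's markdown parser does not trip over identifiers like snake_case while code spans stay intact. A self-contained rerun of the same regex idea (the function name here is hypothetical, not the repository's):

```python
import re

# Same idea as the helper above: escape "_" only when it is not inside `...` or ```...``` spans.
pattern = r'(```.*?```|`.*?`)|(_)'

def escape_outside_code(text: str) -> str:
    return re.sub(pattern,
                  lambda m: m.group(1) if m.group(1) else r'\_',
                  text, flags=re.DOTALL)

print(escape_outside_code("use snake_case here, but keep `inline_code` as it is"))
# -> use snake\_case here, but keep `inline_code` as it is
```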

handlers/models.py (new file, 28 lines)

from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import CommandHandler, CallbackQueryHandler
from llm import clinet
from llm.ollama import list_models_names

async def list_models(update: Update, _):
    model_names = list_models_names()
    keybaord = [[InlineKeyboardButton(name, callback_data=f'/setmodel {name}')] for name in model_names]
    reply_markup = InlineKeyboardMarkup(keybaord)

    await update.message.reply_text(
        text='*WARNING:* chat history will be reset\nPick a model:',
        reply_markup=reply_markup,
        parse_mode='markdown'
    )

list_models_handler = CommandHandler("models", list_models)

async def set_model(update: Update, _):
    query = update.callback_query
    model_name = query.data.split(' ')[1]
    await query.answer()

    clinet.set_model(model_name)
    clinet.reset_history()
    await query.edit_message_text(text=f"selected model: {model_name}")

set_model_handler = CallbackQueryHandler(set_model, pattern='^/setmodel')

handlers/ollama_host.py (new file, 43 lines)

from telegram import Update
from telegram.ext import filters, CommandHandler, ConversationHandler, MessageHandler
from ollama import Client
from llm import clinet
from rich import print

# Define states
SET_HOST = range(1)

async def set_llama_host(update: Update, _):
    await update.message.reply_text(
        text='Enter ollama host url like, **http://127.0.0.1:11434**',
        parse_mode='markdown'
    )
    return SET_HOST


async def on_got_host(update: Update, _):
    try:
        test_client = Client(host=update.message.text)
        test_client._client.timeout = 2
        test_client.list()
        clinet.set_host(update.message.text)
        await update.message.reply_text(
            f"ollama host is set to: {update.message.text}"
        )
    except:
        await update.message.reply_text(
            f"couldn't connect to ollama server at:\n{update.message.text}"
        )
    finally:
        return ConversationHandler.END

async def cancel(update: Update, _):
    await update.message.reply_text("canceled /sethost command")
    return ConversationHandler.END

set_host_handler = ConversationHandler(
    entry_points=[CommandHandler("sethost", set_llama_host)],
    states={
        SET_HOST: [MessageHandler(filters.TEXT & ~filters.COMMAND, on_got_host)]
    },
    fallbacks=[CommandHandler('cancel', cancel)]
)

handlers/start.py (new file, 24 lines)

from telegram import Update, ReplyKeyboardMarkup
from telegram.ext import CommandHandler

async def start(update: Update, _):
    await update.message.reply_text(
        "Ask a question to start conversation",
    )
    keyboard = [
        ['/status'],
        ['/models'],
        ['/systems'],
        ['/reset']
    ]
    reply_markup = ReplyKeyboardMarkup(
        keyboard,
        one_time_keyboard=True,
        resize_keyboard=True
    )
    await update.message.reply_text(
        "You can also use this commands:",
        reply_markup=reply_markup
    )

start_handler = CommandHandler("start", start)

handlers/status.py (new file, 13 lines)

from telegram import Update
from telegram.ext import CommandHandler
from llm import clinet

async def get_status(update: Update, _):
    host, model, system = clinet.get_status()

    await update.message.reply_text(
        text=f'*ollama_host:* {host}\n\n*model:* {model}\n\n*system prompt:*\n{system}',
        parse_mode='markdown'
    )

status_handler = CommandHandler("status", get_status)

handlers/system.py (new file, 115 lines)

from llm import clinet
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import filters, ContextTypes, CommandHandler, MessageHandler, ConversationHandler, CallbackQueryHandler

import os

SYSTEM_DIR = 'system'

# list system prompts as inline keybard
# clicking a button sets system prompt for current model
async def list_system_prompts(update: Update, _):
    files = os.listdir(SYSTEM_DIR)

    keybaord = [[InlineKeyboardButton(name, callback_data=f'/setsystem {name}')] for name in files]
    reply_markup = InlineKeyboardMarkup(keybaord)

    await update.message.reply_text('select system prompt:', reply_markup=reply_markup)


list_handler = CommandHandler("systems", list_system_prompts)

async def on_set_system(update: Update, _):
    query = update.callback_query
    file_name = query.data.split(' ')[1]
    await query.answer()

    file_path = os.path.join(SYSTEM_DIR, file_name)
    if os.path.exists(file_path):
        with open(file_path) as file:
            system_prompt = file.read()
            clinet.set_system(system_prompt)
            await query.edit_message_text(text=f"the system prompt has been set to: {file_name}\n\n{system_prompt}")

set_handler = CallbackQueryHandler(on_set_system, pattern='^/setsystem')

# list system prompts as buttons
# clicking on a button removes selected system prompt
async def remove_system_prompt(update: Update, _):
    files = os.listdir(SYSTEM_DIR)

    keybaord = [[InlineKeyboardButton(name, callback_data=f'/rmsystem {name}')] for name in files]
    reply_markup = InlineKeyboardMarkup(keybaord)

    await update.message.reply_text(
        '*WARNING:* selected prompt will be removed:',
        reply_markup=reply_markup,
        parse_mode='markdown'
    )

remove_handler = CommandHandler("rmsystem", remove_system_prompt)

async def on_remove_system(update: Update, _):
    query = update.callback_query
    file_name = query.data.split(' ')[1]
    await query.answer()

    file_path = os.path.join(SYSTEM_DIR, file_name)
    if os.path.exists(file_path):
        with open(file_path) as file:
            system_prompt = file.read()
            clinet.set_system(system_prompt)
        os.remove(file_path)
        await query.edit_message_text(text=f"{file_name} system prompt has been removed\n\n{system_prompt}")

on_remove_handler = CallbackQueryHandler(on_remove_system, pattern='^/rmsystem')

# Define states
SET_NAME, SET_CONTENT = range(2)

# add new system prompt
# 1. enter name
# 2. enter content
# new prompt will be saved to SYSTEM_DIR
async def add_system(update: Update, _):
    await update.message.reply_text(
        text="Enter name for a new system prompt:",
    )
    return SET_NAME

async def on_got_name(update: Update, context: ContextTypes.DEFAULT_TYPE):
    file_name = ''.join(x for x in update.message.text if x.isalnum() or x in '._ ')
    context.user_data['new_name'] = file_name

    if os.path.exists(os.path.join(SYSTEM_DIR, file_name)):
        await update.message.reply_text(
            text=f'*Warning*, existing **{file_name}** will be rewritten\nenter prompt content to continue or /cancel',
            parse_mode='markdown'
        )
    else:
        await update.message.reply_text(
            f"New system prompt will be saved to {file_name}, enter prompt:"
        )

    return SET_CONTENT

async def on_got_content(update: Update, context: ContextTypes.DEFAULT_TYPE):
    file_name = context.user_data['new_name']
    with open(os.path.join(SYSTEM_DIR, file_name), 'w') as file:
        file.write(update.message.text)

    await update.message.reply_text(f"{file_name} content saved")
    return ConversationHandler.END

async def cancel(update: Update, _):
    await update.message.reply_text("canceled /addsystem command")
    return ConversationHandler.END

create_new_hanlder = ConversationHandler(
    entry_points=[CommandHandler('addsystem', add_system)],
    states={
        SET_NAME: [MessageHandler(filters.TEXT & ~filters.COMMAND, on_got_name)],
        SET_CONTENT: [MessageHandler(filters.TEXT & ~filters.COMMAND, on_got_content)]
    },
    fallbacks=[CommandHandler('cancel', cancel)]
)

llm/__init__.py (new file, 3 lines)

from .ollama import OllamaClient

clinet = OllamaClient()

llm/agent_util.py (new file, 245 lines)

# LangChain Dependencies
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langgraph.graph import END, StateGraph
from langchain_community.chat_models import ChatOllama
from langgraph.checkpoint.memory import MemorySaver
from typing import Annotated
from operator import add
# For State Graph
from typing_extensions import TypedDict
import json


def configure_llm(model, temperature):
    llama3 = ChatOllama(model=model, temperature=temperature)
    llama3_json = ChatOllama(
        model=model, format='json', temperature=temperature)
    return llama3, llama3_json


def config_agent(llama3, llama3_json):
    # Graph State

    class GraphState(TypedDict):
        """
        Represents the state of our graph.

        Attributes:
            question: question
            generation: LLM generation
            send_order: revised question for send order
            context: send_order result
        """
        question: str
        generation: str
        order_info: str
        context: str
        history: str

    generate_prompt = PromptTemplate(
        template="""

<|begin_of_text|>

<|start_header_id|>system<|end_header_id|>
Your name is AICQ.
You are a professional token seller.
We have a total of 1 million CEC, and the average cost of these CECs is $1.1 each, The min price is $0.9 per CEC, There is a 1% chance that the price will be lower than $1 per CEC.
Now we hope to hand it over to you to be responsible for selling all of them.
Customers will come to talk to you and bargain.
You can decide the final transaction price based on the status of the conversation with the customer and sell the tokens.
In addition, there are several rules to follow for this task:
1) The higher the price, the better, in the hope that the final income will be higher than 1 million US dollars;
2) You can set a price higher than 1 US dollar per CEC to sell tokens according to the actual sales situation;
3) For a single customer, we hope that you will try not to sell at a price lower than 1 US dollar, but if the customer has special reasons, you can sell some tokens at a lower price depending on the situation, but in this case you need to sell at a higher price in other sales to ensure that the overall income is not less than 1 million US dollars;
4) We hope that customers buy as much as possible.
Under this condition, if the user is willing to purchase a large number of tokens at one time, we can sell them to him at a more favorable price, but the min price can not lower then $0.9.
The above are all the rules for this sales task.

<|eot_id|>

<|start_header_id|>user<|end_header_id|>

{context}
{history}
Question: {question}
Answer:

<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>""",
        input_variables=["question", "context", "history"],
    )

    # Chain
    generate_chain = generate_prompt | llama3 | StrOutputParser()

    router_prompt = PromptTemplate(
        template="""

<|begin_of_text|>

<|start_header_id|>system<|end_header_id|>

You are an expert at routing a user question to either the generation stage or send order.
CEC is a famous crypto token, and you are an expert at selling CEC.
Use the generate for questions that user asks for price.
Use the send order for questions that user accepts the price you give and decides to buy a certain amount of CEC.
Otherwise, you can skip and go straight to the generation phase to respond.
You do not need to be stringent with the keywords in the question related to these topics.
Give a binary choice 'send_order' or 'generate' based on the question.
Return the JSON with a single key 'choice' with no premable or explanation.

Question to route: {question}
Context to route: {context}
History to route: {history}
<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>

""",
        input_variables=["question", "history", "context"],
    )

    # Chain
    question_router = router_prompt | llama3_json | JsonOutputParser()

    order_prompt = PromptTemplate(
        template="""

<|begin_of_text|>

<|start_header_id|>system<|end_header_id|>

You are an expert at sell CEC,
Return the JSON with a single key 'count' with amount which user want to buy.

Question to transform: {question}
Context to transform: {context}
History to transform: {history}

<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>

""",
        input_variables=["question", "context", "history"],
    )

    # Chain
    order_chain = order_prompt | llama3_json | JsonOutputParser()


    # Node - Generate

    def generate(state):
        """
        Generate answer

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): New key added to state, generation, that contains LLM generation
        """

        print("Step: Generating Final Response")
        question = state["question"]
        context = state["context"]
        order_info = state.get("order_info", None)
        # TODO:: generate the answer based on the specific content of the context
        # check if context is not None and not empty
        if order_info is not None:
            return {"generation": order_info}
        else:
            generation = generate_chain.invoke(
                {"context": context, "question": question, "history": state["history"]})
            return {"generation": generation}

    # Node - Query Transformation

    def transform_query(state):
        """
        Transform user question to order info

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): Appended amount of CEC to context
        """

        print("Step: Optimizing Query for Send Order")
        question = state['question']
        gen_query = order_chain.invoke({"question": question, "history": state["history"], "context": state["context"]})
        amount = str(gen_query["count"])
        print("order_info", amount)
        return {"order_info": [amount]}

    # Node - Send Order

    def send_order(state):
        """
        Send order based on the question

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): Appended Order Info to context
        """
        print("Step: before Send Order")
        amount = state['order_info']
        print(amount)
        print(f'Step: build order info for : "{amount}" CEC')
        order_info = {"amount": amount, "price": 0.1,
                      "name": "CEC", "url": "https://www.example.com"}
        order_result = json.dumps(order_info)
        return {"order_info": order_result}

    # Conditional Edge, Routing

    def route_question(state):
        """
        route question to send order or generation.

        Args:
            state (dict): The current graph state

        Returns:
            str: Next node to call
        """

        print("Step: Routing Query")
        question = state['question']
        output = question_router.invoke({"question": question, "history": state["history"], "context": state["context"]})
        if output['choice'] == "send_order":
            print("Step: Routing Query to Send Order")
            return "sendorder"
        elif output['choice'] == 'generate':
            print("Step: Routing Query to Generation")
            return "generate"

    # Build the nodes
    workflow = StateGraph(GraphState)
    workflow.add_node("sendorder", send_order)
    workflow.add_node("transform_query", transform_query)
    workflow.add_node("generate", generate)

    # Build the edges
    workflow.set_conditional_entry_point(
        route_question,
        {
            "sendorder": "transform_query",
            "generate": "generate",
        },
    )
    workflow.add_edge("transform_query", "sendorder")
    workflow.add_edge("sendorder", "generate")
    workflow.add_edge("generate", END)

    # Compile the workflow
    # memory = MemorySaver()
    # local_agent = workflow.compile(checkpointer=memory)
    local_agent = workflow.compile()
    return local_agent
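
Taken together, configure_llm builds the two ChatOllama instances (plain and JSON mode) and config_agent wires them into the routed StateGraph. A minimal usage sketch, assuming an Ollama server is reachable at its default address with the llama3.2 model pulled; the question text is just an example:

```python
from llm.agent_util import configure_llm, config_agent

llama3, llama3_json = configure_llm(model="llama3.2", temperature=0)
agent = config_agent(llama3, llama3_json)

# The compiled graph expects question, context and history in the input state.
result = agent.invoke({
    "question": "How much would 1000 CEC cost me?",
    "context": "",
    "history": "",
})
print(result["generation"])
```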

llm/chat_history.py (new file, 75 lines)

# from langchain.vectorstores.chroma import Chroma
from langchain_chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain.embeddings.huggingface import HuggingFaceBgeEmbeddings
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from chromadb.config import Settings
import os

message_template = """
User: {question}

Assistant: {answer}
"""

model_norm = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-base-en",
    model_kwargs={ 'device': 'cpu' },
    encode_kwargs={ 'normalize_embeddings': True },
)

def get_chroma():
    settings = Settings()
    settings.allow_reset = True
    settings.is_persistent = True

    return Chroma(
        persist_directory=DB_FOLDER,
        embedding_function=model_norm,
        client_settings=settings,
        collection_name='latest_chat'
    )

DB_FOLDER = 'db'

class ChatHistory:
    def __init__(self) -> None:
        self.history = {}

    def append(self, question: str, answer: str, uid: str):
        new_message = message_template.format(
            question=question,
            answer=answer
        )
        if uid not in self.history:
            self.history[uid] = []
        self.history[uid].append(new_message)
        self.embed(new_message, uid)

    def reset_history(self):
        self.history = {}
        chroma = get_chroma()
        chroma._client.reset()

    def embed(self, text: str, uid: str):
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
        all_splits = text_splitter.split_text(text)

        chroma = get_chroma()
        chroma.add_texts(all_splits, metadatas=[{"user": uid}])

    def get_context(self, question: str, uid: str):
        if not os.path.exists(DB_FOLDER):
            return ''
        chroma = get_chroma()
        filter = {
            'user': uid
        }
        documents = chroma.similarity_search(query=question, k=4, filter=filter)
        context = '\n'.join([d.page_content for d in documents])
        return context

    def get_last_few(self, uid: str):
        if uid not in self.history:
            return ''
        return '\n'.join(self.history[uid][-3:])
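
ChatHistory keeps the raw per-user exchanges in memory for get_last_few and mirrors them into a persistent Chroma collection (the db folder, which is why db was added to .gitignore) for similarity search. A small usage sketch; on first run it downloads the BAAI/bge-base-en embedding model, and the uid value is arbitrary:

```python
from llm.chat_history import ChatHistory

history = ChatHistory()
history.append("What is CEC?", "CEC is the token we are selling.", uid="42")

print(history.get_last_few("42"))                      # last few raw exchanges for this user
print(history.get_context("price of CEC?", uid="42"))  # embedding-based lookup from Chroma
```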

llm/ollama.py (new file, 86 lines)

from ollama import Client
# from langchain.llms.ollama import Ollama
# from langchain.llms import Ollama
from langchain_community.chat_models import ChatOllama
from .chat_history import ChatHistory
from .agent_util import config_agent
import os

ollama_model = 'llama3.2'
# OLLAMA_HOST = os.environ.get('OLLAMA_BOT_BASE_URL')
# if OLLAMA_HOST is None:
#     raise ValueError("OLLAMA_BOT_BASE_URL env variable is not set")

default_system_prompt_path = './prompts/default.md'

# pass in the system parameter
def get_client(system=''):
    llama3 = ChatOllama(model=ollama_model, temperature=0)
    llama3_json = ChatOllama(
        model=ollama_model,
        format='json',
        temperature=0
    )
    client = config_agent(llama3, llama3_json)
    return client


def list_models_names():
    # models = Client(OLLAMA_HOST).list()["models"]
    # return [m['name'] for m in models]
    return []


PROMPT_TEMPLATE = """
{context}

{history}

User: {question}

Assistant:
"""

class OllamaClient:
    def __init__(self) -> None:
        with open(default_system_prompt_path) as file:
            self.system_prompt = file.read()

        self.client = get_client(self.system_prompt)
        self.history = ChatHistory()

    def generate(self, question, uid) -> str:
        context_content = self.history.get_context(question, uid)
        context = f"CONTEXT: {context_content}" if len(context_content) > 0 else ''

        last_few = self.history.get_last_few(uid)
        history = f"LAST MESSAGES: {last_few}" if len(last_few) > 0 else ''

        prompt = PROMPT_TEMPLATE.format(
            context=context,
            history=history,
            question=question
        )
        answer = self.client.invoke({"question": question, "context": context, "history": history})['generation']
        print("answer:", answer)
        self.history.append(question, answer, uid)
        return answer

    def reset_history(self):
        self.history.reset_history()

    def set_model(self, model_name):
        self.client.model = model_name

    def set_host(self, host):
        self.client.base_url = host

    def set_system(self, system):
        self.client.system = system

    def get_status(self):
        return (
            self.client.base_url,
            self.client.model,
            self.client.system
        )
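
OllamaClient ties the pieces together: it loads prompts/default.md, builds the LangGraph agent through get_client, and threads the Chroma-backed history and context into every invoke call. The shared instance is exported by llm/__init__.py (note the clinet spelling of the name), so a rough interactive check, assuming Ollama is running with llama3.2 pulled and the question/uid values are placeholders, looks like this:

```python
from llm import clinet  # the shared OllamaClient instance created in llm/__init__.py

answer = clinet.generate("How much for 500 CEC?", uid="5015834404")
print(answer)

clinet.reset_history()  # what the bot's /reset command calls
```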

prompts/default.md (new file, 6 lines)

You're a helpful assistant.
Your goal is to help the user with their questions
If there's a previous conversation you'll be provided a context

If you don't know an answer to a given USER question don't imagine anything
just say you don't know

@@ -20,10 +20,10 @@ local_agent = config_agent(llama3, llama3_json)


 def run_agent(query):
-    # config = {"configurable": {"thread_id": "1"}}
-    # output = local_agent.invoke({"question": query}, config)
-    # print(list(local_agent.get_state_history(config)))
-    output = local_agent.invoke({"question": query})
+    config = {"configurable": {"thread_id": "1", "user_id": "1"}}
+    output = local_agent.invoke({"question": [query]}, config)
+    print(list(local_agent.get_state_history(config)))
+    # output = local_agent.invoke({"question": ['hi, my name is cz', query]})
     print("=======")
     return output["generation"]
