diff --git a/.env b/.env
new file mode 100644
index 0000000..f0cbd71
--- /dev/null
+++ b/.env
@@ -0,0 +1,4 @@
+TAVILY_API_KEY=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
+OLLAMA_BOT_TOKEN=7965398310:AAF2QDGXRRwDta0eYC7lFSvRDLW7L-SQH_E
+OLLAMA_BOT_TOKEN=7802269927:AAHXso9DZe5_LgQ4kN9J5jEAgJpMNQVaYI0
+OLLAMA_BOT_CHAT_IDS=5015834404
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 74b0f9c..32fe8d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 *.log
 /logs
-**__pycache__**
\ No newline at end of file
+**__pycache__**
+db
\ No newline at end of file
diff --git a/.vscode/launch.json b/.vscode/launch.json
index e0121b7..59c896c 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -5,10 +5,10 @@
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "Python Debugger: Current File",
+            "name": "Python Debugger: bot",
             "type": "debugpy",
             "request": "launch",
-            "program": "${file}",
+            "program": "bot.py",
             "console": "integratedTerminal"
         }
     ]
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..c9f7a98
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+    "python.venvPath": "/Users/zhl/miniconda/envs/agent2/bin/python"
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index 5996432..5dd2087 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,7 @@
 ```bash
+ollama pull llama3.2
+
 
 conda env create -f env/agent2.yaml
 
 # conda create -n agent2 python
diff --git a/agent_util.py b/agent_util.py
index 7b0c8ad..db81d45 100644
--- a/agent_util.py
+++ b/agent_util.py
@@ -4,6 +4,8 @@ from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
 from langgraph.graph import END, StateGraph
 from langchain_community.chat_models import ChatOllama
 from langgraph.checkpoint.memory import MemorySaver
+from typing import Annotated
+from operator import add
 # For State Graph
 from typing_extensions import TypedDict
 import json
@@ -23,7 +25,7 @@ def config_agent(llama3, llama3_json):
         <|begin_of_text|>
 
         <|start_header_id|>system<|end_header_id|>
-        
+
         Your name is AICQ.
         You are a professional token seller.
         We have a total of 1 million CEC, and the average cost of these CECs is 1 US dollar each. Now we hope to hand it over to you to be responsible for selling all of them.
@@ -121,10 +123,10 @@ def config_agent(llama3, llama3_json):
             send_order: revised question for send order
            context: send_order result
         """
-        question: str
+        question: Annotated[list[str], add]
         generation: str
-        send_order: str
-        context: str
+        send_order: Annotated[list[str], add]
+        context: Annotated[list[str], add]
 
 
     # Node - Generate
@@ -141,12 +143,14 @@ def config_agent(llama3, llama3_json):
 
         print("Step: Generating Final Response")
         question = state["question"]
-        context = state.get("context", None)
+        context = state["context"]
         print(context)
+        last_context = context[-1] if context else None
         # TODO:: 根据context特定的内容生产答案
-        if context is not None and context.index("orderinfo") != -1:
-            return {"generation": context.replace("orderinfo:", "")}
+        if last_context is not None and "orderinfo" in last_context:
+            return {"generation": last_context.replace("orderinfo:", "")}
         else:
+            print(question)
             generation = generate_chain.invoke(
                 {"context": context, "question": question})
             return {"generation": generation}
@@ -167,9 +171,9 @@ def config_agent(llama3, llama3_json):
         print("Step: Optimizing Query for Send Order")
         question = state['question']
         gen_query = query_chain.invoke({"question": question})
-        search_query = gen_query["count"]
-        print("send_order", search_query)
-        return {"send_order": search_query}
+        amount = str(gen_query["count"])
+        print("send_order", amount)
+        return {"send_order": [amount]}
 
 
     # Node - Send Order
@@ -184,13 +188,13 @@ def config_agent(llama3, llama3_json):
             state (dict): Appended Order Info to context
         """
         print("Step: before Send Order")
-        amount = state['send_order']
+        amount = state['send_order'][-1]
         print(amount)
         print(f'Step: build order info for : "{amount}" CEC')
         order_info = {"amount": amount, "price": 0.1,
                       "name": "CEC", "url": "https://www.example.com"}
         search_result = f"orderinfo:{json.dumps(order_info)}"
-        return {"context": search_result}
+        return {"context": [search_result]}
 
 
     # Conditional Edge, Routing
@@ -234,7 +238,7 @@ def config_agent(llama3, llama3_json):
     workflow.add_edge("generate", END)
 
     # Compile the workflow
-    memory = MemorySaver()
+    # memory = MemorySaver()
     # local_agent = workflow.compile(checkpointer=memory)
     local_agent = workflow.compile()
     return local_agent
diff --git a/app.py b/app.py
index 97afe9b..f7c8473 100644
--- a/app.py
+++ b/app.py
@@ -9,7 +9,7 @@ local_agent = config_agent(llama3, llama3_json)
 
 
 def run_agent(query):
-    output = local_agent.invoke({"question": query})
+    output = local_agent.invoke({"question": [query]})
     print("=======")
     print(output["generation"])
     # display(Markdown(output["generation"]))
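Reviewer note on the state change above: annotating `question`, `send_order`, and `context` with `Annotated[list[str], add]` makes LangGraph use `operator.add` as a reducer, so a node's returned list is appended to the existing value instead of overwriting it — which is why the nodes now return one-element lists and `app.py` invokes with `{"question": [query]}`. A minimal, self-contained sketch of that behaviour (illustrative names only, not part of this repo):

```python
from operator import add
from typing import Annotated

from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph


class State(TypedDict):
    # `add` acts as a reducer: lists returned by nodes are concatenated
    question: Annotated[list[str], add]
    generation: str


def generate(state: State):
    # Returning a one-element list appends to state["question"] rather than replacing it
    return {"question": ["(seen by generate)"],
            "generation": f"answering: {state['question'][0]}"}


workflow = StateGraph(State)
workflow.add_node("generate", generate)
workflow.set_entry_point("generate")
workflow.add_edge("generate", END)
agent = workflow.compile()

result = agent.invoke({"question": ["hi"]})
print(result["question"])    # ['hi', '(seen by generate)'] -- accumulated, not replaced
print(result["generation"])
```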
diff --git a/bot.py b/bot.py
new file mode 100644
index 0000000..63eddf3
--- /dev/null
+++ b/bot.py
@@ -0,0 +1,52 @@
+#!/usr/bin/python3
+from dotenv import load_dotenv
+load_dotenv()
+import os
+from rich.traceback import install
+install()
+
+from telegram.ext import ApplicationBuilder
+
+from handlers.start import start_handler
+from handlers.message_handler import message_handler
+from handlers.ollama_host import set_host_handler
+from handlers.status import status_handler
+from handlers.help import help_handler
+
+import handlers.chat_id as chat_id
+import handlers.models as models
+import handlers.system as sys_prompt
+import handlers.history as history
+
+
+OLLAMA_BOT_TOKEN = os.environ.get("OLLAMA_BOT_TOKEN")
+if OLLAMA_BOT_TOKEN is None:
+    raise ValueError("OLLAMA_BOT_TOKEN env variable is not set")
+
+if __name__ == "__main__":
+    application = ApplicationBuilder().token(OLLAMA_BOT_TOKEN).build()
+
+    application.add_handler(chat_id.filter_handler, group=-1)
+
+    application.add_handlers([
+        start_handler,
+        help_handler,
+        status_handler,
+
+        # chat_id.get_handler,
+        history.reset_handler,
+
+        # models.list_models_handler,
+        models.set_model_handler,
+
+        sys_prompt.list_handler,
+        sys_prompt.set_handler,
+        sys_prompt.create_new_hanlder,
+        sys_prompt.remove_handler,
+        sys_prompt.on_remove_handler,
+        set_host_handler,
+        message_handler,
+    ])
+
+    application.run_polling()
+    print("Bot started")
\ No newline at end of file
diff --git a/env/agent2.yaml b/env/agent2.yaml
index 472d9b5..9b19e17 100644
--- a/env/agent2.yaml
+++ b/env/agent2.yaml
@@ -3,6 +3,8 @@ channels:
   - defaults
   - conda-forge
 dependencies:
-  - python=3.10
+  - python=3.9
+  - pip==21
   - pip:
-    - streamlit==1.40.2
+    - setuptools==66.0.0
+    - wheel==0.38.4
diff --git a/env/requirements.txt b/env/requirements.txt
index 7308df2..da04f82 100644
--- a/env/requirements.txt
+++ b/env/requirements.txt
@@ -3,4 +3,11 @@ langgraph==0.2.2
 langchain-ollama==0.1.1
 langsmith== 0.1.98
 langchain_community==0.2.11
-duckduckgo-search==6.2.13
\ No newline at end of file
+langchain_chroma==0.1.4
+# duckduckgo-search==6.2.13
+# need install manually
+# streamlit==1.40.2
+python-dotenv
+python-telegram-bot
+chromadb==0.4.22
+sentence_transformers
\ No newline at end of file
diff --git a/handlers/chat_id.py b/handlers/chat_id.py
new file mode 100644
index 0000000..999eee3
--- /dev/null
+++ b/handlers/chat_id.py
@@ -0,0 +1,25 @@
+import os
+from telegram import Update
+from telegram.ext import CommandHandler, ApplicationHandlerStop, TypeHandler
+
+CHAT_IDS = os.environ.get("OLLAMA_BOT_CHAT_IDS")
+valid_user_ids = []
+if CHAT_IDS is not None:
+    valid_user_ids = [int(id) for id in CHAT_IDS.split(',')]
+
+
+async def get_chat_id(update: Update, _):
+    chat_id = update.effective_chat.id
+    await update.message.reply_text(str(chat_id))
+
+get_handler = CommandHandler("chatid", get_chat_id)
+
+async def filter_chat_id(update: Update, _):
+    if update.effective_user.id in valid_user_ids or update.message.text == '/chatid':
+        pass
+    else:
+        await update.effective_message.reply_text("access denied")
+        print(f"access denied for {update.effective_user.id} {update.effective_user.name} {update.effective_user.full_name}")
+        raise ApplicationHandlerStop()
+
+filter_handler = TypeHandler(Update, filter_chat_id)
\ No newline at end of file
diff --git a/handlers/help.py b/handlers/help.py
new file mode 100644
index 0000000..13be9d3
--- /dev/null
+++ b/handlers/help.py
@@ -0,0 +1,22 @@
+from telegram import Update
+from telegram.ext import CommandHandler
+
+help_message="""Welcome to ollama bot.
+
+*Here's a full list of commands:*
+/start - quick access to common commands
+/status - show current bot status
+/chatid - get current chat id
+/models - select ollama model to chat with
+/systems - select system prompt
+/addsystem - create new system prompt
+/rmsystem - remove system prompt
+/sethost - set ollama host URL
+/reset - reset chat history
+/help - show this help message
+"""
+
+async def help(update: Update, _):
+    await update.message.reply_text(help_message, parse_mode='markdown')
+
+help_handler = CommandHandler("help", help)
diff --git a/handlers/history.py b/handlers/history.py
new file mode 100644
index 0000000..9eb0f03
--- /dev/null
+++ b/handlers/history.py
@@ -0,0 +1,9 @@
+from telegram import Update
+from telegram.ext import CommandHandler
+from llm import clinet
+
+async def history_reset(update: Update, _):
+    clinet.reset_history()
+    await update.message.reply_text("history cleared")
+
+reset_handler = CommandHandler("reset", history_reset)
\ No newline at end of file
diff --git a/handlers/message_handler.py b/handlers/message_handler.py
new file mode 100644
index 0000000..a7ecaf6
--- /dev/null
+++ b/handlers/message_handler.py
@@ -0,0 +1,46 @@
+from telegram import Update
+from telegram.ext import MessageHandler, filters
+from llm import clinet
+import re
+
+# simple filter for proper markdown response
+def escape_markdown_except_code_blocks(text):
+    # Pattern to match text outside of code blocks and inline code
+    pattern = r'(```.*?```|`.*?`)|(_)'
+
+    def replace(match):
+        # If group 1 (code blocks or inline code) is matched, return it unaltered
+        if match.group(1):
+            return match.group(1)
+        # If group 2 (underscore) is matched, escape it
+        elif match.group(2):
+            return r'\_'
+
+    return re.sub(pattern, replace, text, flags=re.DOTALL)
+
+
+
+async def on_message(update: Update, _):
+    msg = await update.message.reply_text("...")
+    await update.effective_chat.send_action("TYPING")
+    user_id = str(update.message.from_user.id)
+    user_name = update.message.from_user.full_name or update.message.from_user.username
+    print("msg from:", user_id, user_name)
+    try:
+        response = clinet.generate(update.message.text, user_id)
+        try:
+            await msg.edit_text(
+                escape_markdown_except_code_blocks(response),
+                parse_mode='markdown'
+            )
+        except Exception as e:
+            print(e)
+            await msg.edit_text(response)
+    except Exception as e:
+        print(e)
+        await msg.edit_text("failed to generate response")
+
+    await update.effective_chat.send_action("CANCEL")
+
+message_handler = MessageHandler(filters.TEXT & (~filters.COMMAND), on_message)
+
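Note on `escape_markdown_except_code_blocks` above: the regex only escapes underscores that fall outside backtick spans, so identifiers in prose stop breaking Telegram's legacy Markdown parser while inline code and fenced blocks pass through untouched. A quick illustration with a made-up string:

```python
from handlers.message_handler import escape_markdown_except_code_blocks

text = "rename snake_case fields, but keep `my_var` and ```code_block``` as-is"
print(escape_markdown_except_code_blocks(text))
# rename snake\_case fields, but keep `my_var` and ```code_block``` as-is
```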
diff --git a/handlers/models.py b/handlers/models.py
new file mode 100644
index 0000000..3da3662
--- /dev/null
+++ b/handlers/models.py
@@ -0,0 +1,28 @@
+from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
+from telegram.ext import CommandHandler, CallbackQueryHandler
+from llm import clinet
+from llm.ollama import list_models_names
+
+async def list_models(update: Update, _):
+    model_names = list_models_names()
+    keyboard = [[InlineKeyboardButton(name, callback_data=f'/setmodel {name}')] for name in model_names]
+    reply_markup = InlineKeyboardMarkup(keyboard)
+
+    await update.message.reply_text(
+        text='*WARNING:* chat history will be reset\nPick a model:',
+        reply_markup=reply_markup,
+        parse_mode='markdown'
+    )
+
+list_models_handler = CommandHandler("models", list_models)
+
+async def set_model(update: Update, _):
+    query = update.callback_query
+    model_name = query.data.split(' ')[1]
+    await query.answer()
+
+    clinet.set_model(model_name)
+    clinet.reset_history()
+    await query.edit_message_text(text=f"selected model: {model_name}")
+
+set_model_handler = CallbackQueryHandler(set_model, pattern='^/setmodel')
\ No newline at end of file
diff --git a/handlers/ollama_host.py b/handlers/ollama_host.py
new file mode 100644
index 0000000..70017b3
--- /dev/null
+++ b/handlers/ollama_host.py
@@ -0,0 +1,43 @@
+from telegram import Update
+from telegram.ext import filters, CommandHandler, ConversationHandler, MessageHandler
+from ollama import Client
+from llm import clinet
+from rich import print
+# Define states
+SET_HOST = range(1)
+
+async def set_llama_host(update: Update, _):
+    await update.message.reply_text(
+        text='Enter ollama host url, like *http://127.0.0.1:11434*',
+        parse_mode='markdown'
+    )
+    return SET_HOST
+
+
+async def on_got_host(update: Update, _):
+    try:
+        test_client = Client(host=update.message.text)
+        test_client._client.timeout = 2
+        test_client.list()
+        clinet.set_host(update.message.text)
+        await update.message.reply_text(
+            f"ollama host is set to: {update.message.text}"
+        )
+    except Exception:
+        await update.message.reply_text(
+            f"couldn't connect to ollama server at:\n{update.message.text}"
+        )
+    finally:
+        return ConversationHandler.END
+
+async def cancel(update: Update, _):
+    await update.message.reply_text("canceled /sethost command")
+    return ConversationHandler.END
+
+set_host_handler = ConversationHandler(
+    entry_points=[CommandHandler("sethost", set_llama_host)],
+    states={
+        SET_HOST: [MessageHandler(filters.TEXT & ~filters.COMMAND, on_got_host)]
+    },
+    fallbacks=[CommandHandler('cancel', cancel)]
+)
\ No newline at end of file
diff --git a/handlers/start.py b/handlers/start.py
new file mode 100644
index 0000000..a55f3da
--- /dev/null
+++ b/handlers/start.py
@@ -0,0 +1,24 @@
+from telegram import Update, ReplyKeyboardMarkup
+from telegram.ext import CommandHandler
+
+async def start(update: Update, _):
+    await update.message.reply_text(
+        "Ask a question to start conversation",
+    )
+    keyboard = [
+        ['/status'],
+        ['/models'],
+        ['/systems'],
+        ['/reset']
+    ]
+    reply_markup = ReplyKeyboardMarkup(
+        keyboard,
+        one_time_keyboard=True,
+        resize_keyboard=True
+    )
+    await update.message.reply_text(
+        "You can also use these commands:",
+        reply_markup=reply_markup
+    )
+
+start_handler = CommandHandler("start", start)
\ No newline at end of file
diff --git a/handlers/status.py b/handlers/status.py
new file mode 100644
index 0000000..f49d5d6
--- /dev/null
+++ b/handlers/status.py
@@ -0,0 +1,13 @@
+from telegram import Update
+from telegram.ext import CommandHandler
+from llm import clinet
+
+async def get_status(update: Update, _):
+    host, model, system = clinet.get_status()
+
+    await update.message.reply_text(
+        text=f'*ollama_host:* {host}\n\n*model:* {model}\n\n*system prompt:*\n{system}',
+        parse_mode='markdown'
+    )
+
+status_handler = CommandHandler("status", get_status)
\ No newline at end of file
diff --git a/handlers/system.py b/handlers/system.py
new file mode 100644
index 0000000..f222c05
--- /dev/null
+++ b/handlers/system.py
@@ -0,0 +1,115 @@
+from llm import clinet
+from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
+from telegram.ext import filters, ContextTypes, CommandHandler, MessageHandler, ConversationHandler, CallbackQueryHandler
+
+import os
+
+SYSTEM_DIR = 'system'
+
+# list system prompts as inline keyboard
+# clicking a button sets system prompt for current model
+async def list_system_prompts(update: Update, _):
+    files = os.listdir(SYSTEM_DIR)
+
+    keyboard = [[InlineKeyboardButton(name, callback_data=f'/setsystem {name}')] for name in files]
+    reply_markup = InlineKeyboardMarkup(keyboard)
+
+    await update.message.reply_text('select system prompt:', reply_markup=reply_markup)
+
+
+list_handler = CommandHandler("systems", list_system_prompts)
+
+async def on_set_system(update: Update, _):
+    query = update.callback_query
+    file_name = query.data.split(' ')[1]
+    await query.answer()
+
+    file_path = os.path.join(SYSTEM_DIR, file_name)
+    if os.path.exists(file_path):
+        with open(file_path) as file:
+            system_prompt = file.read()
+            clinet.set_system(system_prompt)
+            await query.edit_message_text(text=f"the system prompt has been set to: {file_name}\n\n{system_prompt}")
+
+set_handler = CallbackQueryHandler(on_set_system, pattern='^/setsystem')
+
+# list system prompts as buttons
+# clicking on a button removes the selected system prompt
+async def remove_system_prompt(update: Update, _):
+    files = os.listdir(SYSTEM_DIR)
+
+    keyboard = [[InlineKeyboardButton(name, callback_data=f'/rmsystem {name}')] for name in files]
+    reply_markup = InlineKeyboardMarkup(keyboard)
+
+    await update.message.reply_text(
+        '*WARNING:* selected prompt will be removed:',
+        reply_markup=reply_markup,
+        parse_mode='markdown'
+    )
+
+remove_handler = CommandHandler("rmsystem", remove_system_prompt)
+
+async def on_remove_system(update: Update, _):
+    query = update.callback_query
+    file_name = query.data.split(' ')[1]
+    await query.answer()
+
+    file_path = os.path.join(SYSTEM_DIR, file_name)
+    if os.path.exists(file_path):
+        with open(file_path) as file:
+            system_prompt = file.read()
+            clinet.set_system(system_prompt)
+        os.remove(file_path)
+        await query.edit_message_text(text=f"{file_name} system prompt has been removed\n\n{system_prompt}")
+
+on_remove_handler = CallbackQueryHandler(on_remove_system, pattern='^/rmsystem')
+
+# Define states
+SET_NAME, SET_CONTENT = range(2)
+
+# add new system prompt
+# 1. enter name
+# 2. enter content
+# new prompt will be saved to SYSTEM_DIR
+async def add_system(update: Update, _):
+    await update.message.reply_text(
+        text="Enter name for a new system prompt:",
+    )
+    return SET_NAME
+
+async def on_got_name(update: Update, context: ContextTypes.DEFAULT_TYPE):
+    file_name = ''.join(x for x in update.message.text if x.isalnum() or x in '._ ')
+    context.user_data['new_name'] = file_name
+
+    if os.path.exists(os.path.join(SYSTEM_DIR, file_name)):
+        await update.message.reply_text(
+            text=f'*Warning*: existing *{file_name}* will be overwritten.\nEnter prompt content to continue or /cancel',
+            parse_mode='markdown'
+        )
+    else:
+        await update.message.reply_text(
+            f"New system prompt will be saved to {file_name}, enter prompt:"
+        )
+
+    return SET_CONTENT
+
+async def on_got_content(update: Update, context: ContextTypes.DEFAULT_TYPE):
+    file_name = context.user_data['new_name']
+    with open(os.path.join(SYSTEM_DIR, file_name), 'w') as file:
+        file.write(update.message.text)
+
+    await update.message.reply_text(f"{file_name} content saved")
+    return ConversationHandler.END
+
+async def cancel(update: Update, _):
+    await update.message.reply_text("canceled /addsystem command")
+    return ConversationHandler.END
+
+create_new_hanlder = ConversationHandler(
+    entry_points=[CommandHandler('addsystem', add_system)],
+    states={
+        SET_NAME: [MessageHandler(filters.TEXT & ~filters.COMMAND, on_got_name)],
+        SET_CONTENT: [MessageHandler(filters.TEXT & ~filters.COMMAND, on_got_content)]
+    },
+    fallbacks=[CommandHandler('cancel', cancel)]
+)
\ No newline at end of file
diff --git a/llm/__init__.py b/llm/__init__.py
new file mode 100644
index 0000000..39aa2d4
--- /dev/null
+++ b/llm/__init__.py
@@ -0,0 +1,3 @@
+from .ollama import OllamaClient
+
+clinet = OllamaClient()
\ No newline at end of file
diff --git a/llm/agent_util.py b/llm/agent_util.py
new file mode 100644
index 0000000..ca79297
--- /dev/null
+++ b/llm/agent_util.py
@@ -0,0 +1,245 @@
+# LangChain Dependencies
+from langchain.prompts import PromptTemplate
+from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
+from langgraph.graph import END, StateGraph
+from langchain_community.chat_models import ChatOllama
+from langgraph.checkpoint.memory import MemorySaver
+from typing import Annotated
+from operator import add
+# For State Graph
+from typing_extensions import TypedDict
+import json
+
+
+def configure_llm(model, temperature):
+    llama3 = ChatOllama(model=model, temperature=temperature)
+    llama3_json = ChatOllama(
+        model=model, format='json', temperature=temperature)
+    return llama3, llama3_json
+
+
+def config_agent(llama3, llama3_json):
+    # Graph State
+
+    class GraphState(TypedDict):
+        """
+        Represents the state of our graph.
+
+        Attributes:
+            question: question
+            generation: LLM generation
+            order_info: structured order info built for send order
+            context: retrieved conversation context; history: recent chat history
+        """
+        question: str
+        generation: str
+        order_info: str
+        context: str
+        history: str
+
+    generate_prompt = PromptTemplate(
+        template="""
+
+        <|begin_of_text|>
+
+        <|start_header_id|>system<|end_header_id|>
+        Your name is AICQ.
+        You are a professional token seller.
+        We have a total of 1 million CEC, and the average cost of these CECs is $1.1 each. The min price is $0.9 per CEC, and there is a 1% chance that the price will be lower than $1 per CEC.
+        Now we hope to hand it over to you to be responsible for selling all of them.
+        Customers will come to talk to you and bargain.
+        You can decide the final transaction price based on the status of the conversation with the customer and sell the tokens.
+        In addition, there are several rules to follow for this task:
+        1) The higher the price, the better, in the hope that the final income will be higher than 1 million US dollars;
+        2) You can set a price higher than 1 US dollar per CEC to sell tokens according to the actual sales situation;
+        3) For a single customer, we hope that you will try not to sell at a price lower than 1 US dollar, but if the customer has special reasons, you can sell some tokens at a lower price depending on the situation, but in this case you need to sell at a higher price in other sales to ensure that the overall income is not less than 1 million US dollars;
+        4) We hope that customers buy as much as possible.
+        Under this condition, if the user is willing to purchase a large number of tokens at one time, we can sell them to him at a more favorable price, but the min price cannot be lower than $0.9.
+        The above are all the rules for this sales task.
+
+        <|eot_id|>
+
+        <|start_header_id|>user<|end_header_id|>
+
+        {context}
+        {history}
+        Question: {question}
+        Answer:
+
+        <|eot_id|>
+
+        <|start_header_id|>assistant<|end_header_id|>""",
+        input_variables=["question", "context", "history"],
+    )
+
+    # Chain
+    generate_chain = generate_prompt | llama3 | StrOutputParser()
+
+    router_prompt = PromptTemplate(
+        template="""
+
+        <|begin_of_text|>
+
+        <|start_header_id|>system<|end_header_id|>
+
+        You are an expert at routing a user question to either the generation stage or send order.
+        CEC is a famous crypto token, and you are an expert at selling CEC.
+        Use generate for questions where the user asks for the price.
+        Use send_order for questions where the user accepts the price you give and decides to buy a certain amount of CEC.
+        Otherwise, you can skip and go straight to the generation phase to respond.
+        You do not need to be stringent with the keywords in the question related to these topics.
+        Give a binary choice 'send_order' or 'generate' based on the question.
+        Return the JSON with a single key 'choice' with no preamble or explanation.
+
+        Question to route: {question}
+        Context to route: {context}
+        History to route: {history}
+        <|eot_id|>
+
+        <|start_header_id|>assistant<|end_header_id|>
+
+        """,
+        input_variables=["question", "history", "context"],
+    )
+
+    # Chain
+    question_router = router_prompt | llama3_json | JsonOutputParser()
+
+    order_prompt = PromptTemplate(
+        template="""
+
+        <|begin_of_text|>
+
+        <|start_header_id|>system<|end_header_id|>
+
+        You are an expert at selling CEC.
+        Return the JSON with a single key 'count' with the amount which the user wants to buy.
+
+        Question to transform: {question}
+        Context to transform: {context}
+        History to transform: {history}
+
+        <|eot_id|>
+
+        <|start_header_id|>assistant<|end_header_id|>
+
+        """,
+        input_variables=["question", "context", "history"],
+    )
+
+    # Chain
+    order_chain = order_prompt | llama3_json | JsonOutputParser()
+
+
+    # Node - Generate
+
+    def generate(state):
+        """
+        Generate answer
+
+        Args:
+            state (dict): The current graph state
+
+        Returns:
+            state (dict): New key added to state, generation, that contains LLM generation
+        """
+
+        print("Step: Generating Final Response")
+        question = state["question"]
+        context = state["context"]
+        order_info = state.get("order_info", None)
+        # TODO: generate the answer based on the specific content of the context
+        # check if context is not None and not empty
+        if order_info is not None:
+            return {"generation": order_info}
+        else:
+            generation = generate_chain.invoke(
+                {"context": context, "question": question, "history": state["history"]})
+            return {"generation": generation}
+
+    # Node - Query Transformation
+
+    def transform_query(state):
+        """
+        Transform user question to order info
+
+        Args:
+            state (dict): The current graph state
+
+        Returns:
+            state (dict): Appended amount of CEC to context
+        """
+
+        print("Step: Optimizing Query for Send Order")
+        question = state['question']
+        gen_query = order_chain.invoke({"question": question, "history": state["history"], "context": state["context"]})
+        amount = str(gen_query["count"])
+        print("order_info", amount)
+        return {"order_info": amount}
+
+    # Node - Send Order
+
+    def send_order(state):
+        """
+        Send order based on the question
+
+        Args:
+            state (dict): The current graph state
+
+        Returns:
+            state (dict): Appended Order Info to context
+        """
+        print("Step: before Send Order")
+        amount = state['order_info']
+        print(amount)
+        print(f'Step: build order info for : "{amount}" CEC')
+        order_info = {"amount": amount, "price": 0.1,
+                      "name": "CEC", "url": "https://www.example.com"}
+        order_result = json.dumps(order_info)
+        return {"order_info": order_result}
+
+    # Conditional Edge, Routing
+
+    def route_question(state):
+        """
+        Route the question to send order or generation.
+
+        Args:
+            state (dict): The current graph state
+
+        Returns:
+            str: Next node to call
+        """
+
+        print("Step: Routing Query")
+        question = state['question']
+        output = question_router.invoke({"question": question, "history": state["history"], "context": state["context"]})
+        if output['choice'] == "send_order":
+            print("Step: Routing Query to Send Order")
+            return "sendorder"
+        elif output['choice'] == 'generate':
+            print("Step: Routing Query to Generation")
+            return "generate"
+
+    # Build the nodes
+    workflow = StateGraph(GraphState)
+    workflow.add_node("sendorder", send_order)
+    workflow.add_node("transform_query", transform_query)
+    workflow.add_node("generate", generate)
+
+    # Build the edges
+    workflow.set_conditional_entry_point(
+        route_question,
+        {
+            "sendorder": "transform_query",
+            "generate": "generate",
+        },
+    )
+    workflow.add_edge("transform_query", "sendorder")
+    workflow.add_edge("sendorder", "generate")
+    workflow.add_edge("generate", END)
+
+    # Compile the workflow
+    # memory = MemorySaver()
+    # local_agent = workflow.compile(checkpointer=memory)
+    local_agent = workflow.compile()
+    return local_agent
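Usage sketch for the new `llm/agent_util.py` graph above: the router, order, and generate prompts all read `question`, `context`, and `history`, so the compiled agent has to be invoked with all three keys (this is what `OllamaClient.generate` does further down). A minimal sketch, assuming a local Ollama server with `llama3.2` pulled:

```python
from llm.agent_util import configure_llm, config_agent

# Build the two ChatOllama instances (plain + JSON mode) and compile the graph
llama3, llama3_json = configure_llm(model="llama3.2", temperature=0)
agent = config_agent(llama3, llama3_json)

# All three keys are required: route_question, transform_query and generate read them directly
state = agent.invoke({
    "question": "How much would 1000 CEC cost me?",
    "context": "",
    "history": "",
})
print(state["generation"])
```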
diff --git a/llm/chat_history.py b/llm/chat_history.py
new file mode 100644
index 0000000..362db02
--- /dev/null
+++ b/llm/chat_history.py
@@ -0,0 +1,75 @@
+# from langchain.vectorstores.chroma import Chroma
+from langchain_chroma import Chroma
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+# from langchain.embeddings.huggingface import HuggingFaceBgeEmbeddings
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+from chromadb.config import Settings
+import os
+
+message_template = """
+User: {question}
+
+Assistant: {answer}
+"""
+
+model_norm = HuggingFaceBgeEmbeddings(
+    model_name="BAAI/bge-base-en",
+    model_kwargs={ 'device': 'cpu' },
+    encode_kwargs={ 'normalize_embeddings': True },
+)
+
+def get_chroma():
+    settings = Settings()
+    settings.allow_reset = True
+    settings.is_persistent = True
+
+    return Chroma(
+        persist_directory=DB_FOLDER,
+        embedding_function=model_norm,
+        client_settings=settings,
+        collection_name='latest_chat'
+    )
+
+DB_FOLDER = 'db'
+
+class ChatHistory:
+    def __init__(self) -> None:
+        self.history = {}
+
+    def append(self, question: str, answer: str, uid: str):
+        new_message = message_template.format(
+            question=question,
+            answer=answer
+        )
+        if uid not in self.history:
+            self.history[uid] = []
+        self.history[uid].append(new_message)
+        self.embed(new_message, uid)
+
+    def reset_history(self):
+        self.history = {}
+        chroma = get_chroma()
+        chroma._client.reset()
+
+    def embed(self, text: str, uid: str):
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
+        all_splits = text_splitter.split_text(text)
+
+        chroma = get_chroma()
+        chroma.add_texts(all_splits, metadatas=[{"user": uid}])
+
+    def get_context(self, question: str, uid: str):
+        if not os.path.exists(DB_FOLDER):
+            return ''
+        chroma = get_chroma()
+        filter = {
+            'user': uid
+        }
+        documents = chroma.similarity_search(query=question, k=4, filter=filter)
+        context = '\n'.join([d.page_content for d in documents])
+        return context
+
+    def get_last_few(self, uid: str):
+        if uid not in self.history:
+            return ''
+        return '\n'.join(self.history[uid][-3:])
\ No newline at end of file
diff --git a/llm/ollama.py b/llm/ollama.py
new file mode 100644
index 0000000..b638d08
--- /dev/null
+++ b/llm/ollama.py
@@ -0,0 +1,86 @@
+from ollama import Client
+# from langchain.llms.ollama import Ollama
+# from langchain.llms import Ollama
+from langchain_community.chat_models import ChatOllama
+from .chat_history import ChatHistory
+from .agent_util import config_agent
+import os
+
+ollama_model = 'llama3.2'
+# OLLAMA_HOST = os.environ.get('OLLAMA_BOT_BASE_URL')
+# if OLLAMA_HOST is None:
+#     raise ValueError("OLLAMA_BOT_BASE_URL env variable is not set")
+
+default_system_prompt_path = './prompts/default.md'
+
+# pass in the system prompt
+def get_client(system=''):
+    llama3 = ChatOllama(model=ollama_model, temperature=0)
+    llama3_json = ChatOllama(
+        model=ollama_model,
+        format='json',
+        temperature=0
+    )
+    client = config_agent(llama3, llama3_json)
+    return client
+
+
+def list_models_names():
+    # models = Client(OLLAMA_HOST).list()["models"]
+    # return [m['name'] for m in models]
+    return []
+
+
+PROMPT_TEMPLATE = """
+{context}
+
+{history}
+
+User: {question}
+
+Assistant:
+"""
+
+class OllamaClient:
+    def __init__(self) -> None:
+        with open(default_system_prompt_path) as file:
+            self.system_prompt = file.read()
+
+        self.client = get_client(self.system_prompt)
+        self.history = ChatHistory()
+
+    def generate(self, question, uid) -> str:
+        context_content = self.history.get_context(question, uid)
+        context = f"CONTEXT: {context_content}" if len(context_content) > 0 else ''
+
+        last_few = self.history.get_last_few(uid)
+        history = f"LAST MESSAGES: {last_few}" if len(last_few) > 0 else ''
+
+        prompt = PROMPT_TEMPLATE.format(
+            context=context,
+            history=history,
+            question=question
+        )
+        answer = self.client.invoke({"question": question, "context": context, "history": history})['generation']
+        print("answer:", answer)
+        self.history.append(question, answer, uid)
+        return answer
+
+    def reset_history(self):
+        self.history.reset_history()
+
+    def set_model(self, model_name):
+        self.client.model = model_name
+
+    def set_host(self, host):
+        self.client.base_url = host
+
+    def set_system(self, system):
+        self.client.system = system
+
+    def get_status(self):
+        return (
+            self.client.base_url,
+            self.client.model,
+            self.client.system
+        )
diff --git a/prompts/default.md b/prompts/default.md
new file mode 100644
index 0000000..7427d0c
--- /dev/null
+++ b/prompts/default.md
@@ -0,0 +1,6 @@
+You're a helpful assistant.
+Your goal is to help the user with their questions.
+If there's a previous conversation, you'll be provided with context.
+
+If you don't know the answer to a given USER question, don't imagine anything;
+just say you don't know.
diff --git a/streamlit_app.py b/streamlit_app.py
index a77e468..2a7d1b6 100644
--- a/streamlit_app.py
+++ b/streamlit_app.py
@@ -20,10 +20,10 @@ local_agent = config_agent(llama3, llama3_json)
 
 
 def run_agent(query):
-    # config = {"configurable": {"thread_id": "1"}}
-    # output = local_agent.invoke({"question": query}, config)
-    # print(list(local_agent.get_state_history(config)))
-    output = local_agent.invoke({"question": query})
+    config = {"configurable": {"thread_id": "1", "user_id": "1"}}
+    output = local_agent.invoke({"question": [query]}, config)
+    print(list(local_agent.get_state_history(config)))
+    # output = local_agent.invoke({"question": ['hi, my name is cz', query]})
     print("=======")
     return output["generation"]
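One follow-up on the `streamlit_app.py` change: `get_state_history(config)` only has something to return when the graph is compiled with a checkpointer, and the `MemorySaver()` line in `agent_util.py` is now commented out, so the history call will likely fail or come back empty. If per-thread history is wanted, re-enabling the checkpointer should be enough — a sketch (assumes the `workflow` object built inside `config_agent`):

```python
from langgraph.checkpoint.memory import MemorySaver

# Inside config_agent, compile with an in-memory checkpointer again
memory = MemorySaver()
local_agent = workflow.compile(checkpointer=memory)

# Then the thread-scoped calls in streamlit_app.py have state to read
config = {"configurable": {"thread_id": "1"}}
output = local_agent.invoke({"question": ["hi, my name is cz"]}, config)
print(list(local_agent.get_state_history(config)))
```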