From 2b9e07caa3ddc1305fa3598749b165bd66d634c5 Mon Sep 17 00:00:00 2001 From: CounterFire2023 <136581895+CounterFire2023@users.noreply.github.com> Date: Thu, 12 Dec 2024 20:31:08 +0800 Subject: [PATCH] reformat code --- agent_util.py | 259 ++++++++++++++++++++++++----------------------- app.py | 4 +- streamlit_app.py | 17 +++- 3 files changed, 149 insertions(+), 131 deletions(-) diff --git a/agent_util.py b/agent_util.py index 3118023..7b0c8ad 100644 --- a/agent_util.py +++ b/agent_util.py @@ -3,18 +3,22 @@ from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser, StrOutputParser from langgraph.graph import END, StateGraph from langchain_community.chat_models import ChatOllama -# For State Graph +from langgraph.checkpoint.memory import MemorySaver +# For State Graph from typing_extensions import TypedDict import json + def configure_llm(model, temperature): - llama3 = ChatOllama(model=model, temperature=temperature) - llama3_json = ChatOllama(model=model, format='json', temperature=temperature) - return llama3, llama3_json + llama3 = ChatOllama(model=model, temperature=temperature) + llama3_json = ChatOllama( + model=model, format='json', temperature=temperature) + return llama3, llama3_json + def config_agent(llama3, llama3_json): - generate_prompt = PromptTemplate( - template=""" + generate_prompt = PromptTemplate( + template=""" <|begin_of_text|> @@ -49,13 +53,13 @@ def config_agent(llama3, llama3_json): <|eot_id|> <|start_header_id|>assistant<|end_header_id|>""", - input_variables=["question", "context"], - ) + input_variables=["question", "context"], + ) - # Chain - generate_chain = generate_prompt | llama3 | StrOutputParser() - router_prompt = PromptTemplate( - template=""" + # Chain + generate_chain = generate_prompt | llama3 | StrOutputParser() + router_prompt = PromptTemplate( + template=""" <|begin_of_text|> @@ -77,14 +81,14 @@ def config_agent(llama3, llama3_json): <|start_header_id|>assistant<|end_header_id|> """, - input_variables=["question"], - ) + input_variables=["question"], + ) - # Chain - question_router = router_prompt | llama3_json | JsonOutputParser() + # Chain + question_router = router_prompt | llama3_json | JsonOutputParser() - query_prompt = PromptTemplate( - template=""" + query_prompt = PromptTemplate( + template=""" <|begin_of_text|> @@ -100,132 +104,137 @@ def config_agent(llama3, llama3_json): <|start_header_id|>assistant<|end_header_id|> """, - input_variables=["question"], - ) + input_variables=["question"], + ) - # Chain - query_chain = query_prompt | llama3_json | JsonOutputParser() - # Graph State - class GraphState(TypedDict): - """ - Represents the state of our graph. + # Chain + query_chain = query_prompt | llama3_json | JsonOutputParser() + # Graph State - Attributes: - question: question - generation: LLM generation - send_order: revised question for send order - context: send_order result - """ - question : str - generation : str - send_order : str - context : str + class GraphState(TypedDict): + """ + Represents the state of our graph. - # Node - Generate + Attributes: + question: question + generation: LLM generation + send_order: revised question for send order + context: send_order result + """ + question: str + generation: str + send_order: str + context: str - def generate(state): - """ - Generate answer + # Node - Generate - Args: - state (dict): The current graph state + def generate(state): + """ + Generate answer - Returns: - state (dict): New key added to state, generation, that contains LLM generation - """ - - print("Step: Generating Final Response") - question = state["question"] - context = state.get("context", None) - print(context) - # TODO:: 根据context特定的内容生产答案 - if context is not None and context.index("orderinfo") != -1: - return {"generation": context.replace("orderinfo:", "")} - else: - generation = generate_chain.invoke({"context": context, "question": question}) - return {"generation": generation} + Args: + state (dict): The current graph state - # Node - Query Transformation + Returns: + state (dict): New key added to state, generation, that contains LLM generation + """ - def transform_query(state): - """ - Transform user question to order info + print("Step: Generating Final Response") + question = state["question"] + context = state.get("context", None) + print(context) + # TODO:: 根据context特定的内容生产答案 + if context is not None and context.index("orderinfo") != -1: + return {"generation": context.replace("orderinfo:", "")} + else: + generation = generate_chain.invoke( + {"context": context, "question": question}) + return {"generation": generation} - Args: - state (dict): The current graph state + # Node - Query Transformation - Returns: - state (dict): Appended amount of CEC to context - """ - - print("Step: Optimizing Query for Send Order") - question = state['question'] - gen_query = query_chain.invoke({"question": question}) - search_query = gen_query["count"] - print("send_order", search_query) - return {"send_order": search_query} + def transform_query(state): + """ + Transform user question to order info - # Node - Send Order + Args: + state (dict): The current graph state - def send_order(state): - """ - Send order based on the question + Returns: + state (dict): Appended amount of CEC to context + """ - Args: - state (dict): The current graph state + print("Step: Optimizing Query for Send Order") + question = state['question'] + gen_query = query_chain.invoke({"question": question}) + search_query = gen_query["count"] + print("send_order", search_query) + return {"send_order": search_query} - Returns: - state (dict): Appended Order Info to context - """ - print("Step: before Send Order") - amount = state['send_order'] - print(amount) - print(f'Step: build order info for : "{amount}" CEC') - order_info = {"amount": amount, "price": 0.1, "name": "CEC", "url": "https://www.example.com"} - search_result = f"orderinfo:{json.dumps(order_info)}" - return {"context": search_result} + # Node - Send Order - # Conditional Edge, Routing + def send_order(state): + """ + Send order based on the question - def route_question(state): - """ - route question to send order or generation. + Args: + state (dict): The current graph state - Args: - state (dict): The current graph state + Returns: + state (dict): Appended Order Info to context + """ + print("Step: before Send Order") + amount = state['send_order'] + print(amount) + print(f'Step: build order info for : "{amount}" CEC') + order_info = {"amount": amount, "price": 0.1, + "name": "CEC", "url": "https://www.example.com"} + search_result = f"orderinfo:{json.dumps(order_info)}" + return {"context": search_result} - Returns: - str: Next node to call - """ + # Conditional Edge, Routing - print("Step: Routing Query") - question = state['question'] - output = question_router.invoke({"question": question}) - if output['choice'] == "send_order": - print("Step: Routing Query to Send Order") - return "sendorder" - elif output['choice'] == 'generate': - print("Step: Routing Query to Generation") - return "generate" + def route_question(state): + """ + route question to send order or generation. - # Build the nodes - workflow = StateGraph(GraphState) - workflow.add_node("sendorder", send_order) - workflow.add_node("transform_query", transform_query) - workflow.add_node("generate", generate) + Args: + state (dict): The current graph state - # Build the edges - workflow.set_conditional_entry_point( - route_question, - { - "sendorder": "transform_query", - "generate": "generate", - }, - ) - workflow.add_edge("transform_query", "sendorder") - workflow.add_edge("sendorder", "generate") - workflow.add_edge("generate", END) + Returns: + str: Next node to call + """ - # Compile the workflow - local_agent = workflow.compile() - return local_agent \ No newline at end of file + print("Step: Routing Query") + question = state['question'] + output = question_router.invoke({"question": question}) + if output['choice'] == "send_order": + print("Step: Routing Query to Send Order") + return "sendorder" + elif output['choice'] == 'generate': + print("Step: Routing Query to Generation") + return "generate" + + # Build the nodes + workflow = StateGraph(GraphState) + workflow.add_node("sendorder", send_order) + workflow.add_node("transform_query", transform_query) + workflow.add_node("generate", generate) + + # Build the edges + workflow.set_conditional_entry_point( + route_question, + { + "sendorder": "transform_query", + "generate": "generate", + }, + ) + workflow.add_edge("transform_query", "sendorder") + workflow.add_edge("sendorder", "generate") + workflow.add_edge("generate", END) + + # Compile the workflow + memory = MemorySaver() + # local_agent = workflow.compile(checkpointer=memory) + local_agent = workflow.compile() + return local_agent diff --git a/app.py b/app.py index 9270ff0..97afe9b 100644 --- a/app.py +++ b/app.py @@ -7,12 +7,14 @@ llama3, llama3_json = configure_llm(local_llm, 0) local_agent = config_agent(llama3, llama3_json) + def run_agent(query): output = local_agent.invoke({"question": query}) print("=======") print(output["generation"]) # display(Markdown(output["generation"])) + # run_agent("I want to buy 100 CEC") run_agent("I'm Cz, I want to buy 100 CEC, how about the price?") -# run_agent("I'm Cz, How about CEC?") \ No newline at end of file +# run_agent("I'm Cz, How about CEC?") diff --git a/streamlit_app.py b/streamlit_app.py index 1e927aa..a77e468 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -7,22 +7,29 @@ st.sidebar.header("Configure LLM") st.title("CEC Seller Assistant") # Model Selection model_options = ["llama3.2"] -selected_model = st.sidebar.selectbox("Choose the LLM Model", options=model_options, index=0) +selected_model = st.sidebar.selectbox( + "Choose the LLM Model", options=model_options, index=0) # Temperature Setting -temperature = st.sidebar.slider("Set the Temperature", min_value=0.0, max_value=1.0, value=0.5, step=0.1) +temperature = st.sidebar.slider( + "Set the Temperature", min_value=0.0, max_value=1.0, value=0.5, step=0.1) -llama3, llama3_json=configure_llm(selected_model, temperature) +llama3, llama3_json = configure_llm(selected_model, temperature) local_agent = config_agent(llama3, llama3_json) + def run_agent(query): + # config = {"configurable": {"thread_id": "1"}} + # output = local_agent.invoke({"question": query}, config) + # print(list(local_agent.get_state_history(config))) output = local_agent.invoke({"question": query}) print("=======") - return output["generation"] + + user_query = st.text_input("Enter your research question:", "") if st.button("Run Query"): if user_query: - st.write(run_agent(user_query)) \ No newline at end of file + st.write(run_agent(user_query))