reformat code

CounterFire2023 2024-12-12 20:31:08 +08:00
parent 3d2e96c1fa
commit 2b9e07caa3
3 changed files with 149 additions and 131 deletions


@@ -3,18 +3,22 @@ from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langgraph.graph import END, StateGraph
from langchain_community.chat_models import ChatOllama
from langgraph.checkpoint.memory import MemorySaver
# For State Graph
from typing_extensions import TypedDict
import json


def configure_llm(model, temperature):
    llama3 = ChatOllama(model=model, temperature=temperature)
    llama3_json = ChatOllama(
        model=model, format='json', temperature=temperature)
    return llama3, llama3_json


def config_agent(llama3, llama3_json):
    generate_prompt = PromptTemplate(
        template="""
        <|begin_of_text|>
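A minimal usage sketch of the two handles returned by configure_llm, mirroring app.py in this commit ("llama3.2" is the only model the UI offers, and temperature 0 matches app.py):

    llama3, llama3_json = configure_llm("llama3.2", 0)  # plain + JSON-mode ChatOllama
    local_agent = config_agent(llama3, llama3_json)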
@@ -49,13 +53,13 @@ def config_agent(llama3, llama3_json):
        <|eot_id|>
        <|start_header_id|>assistant<|end_header_id|>""",
        input_variables=["question", "context"],
    )

    # Chain
    generate_chain = generate_prompt | llama3 | StrOutputParser()

    router_prompt = PromptTemplate(
        template="""
        <|begin_of_text|>
@@ -77,14 +81,14 @@ def config_agent(llama3, llama3_json):
        <|start_header_id|>assistant<|end_header_id|>
        """,
        input_variables=["question"],
    )

    # Chain
    question_router = router_prompt | llama3_json | JsonOutputParser()

    query_prompt = PromptTemplate(
        template="""
        <|begin_of_text|>
@@ -100,132 +104,137 @@ def config_agent(llama3, llama3_json):
        <|start_header_id|>assistant<|end_header_id|>
        """,
        input_variables=["question"],
    )

    # Chain
    query_chain = query_prompt | llama3_json | JsonOutputParser()

    # Graph State

    class GraphState(TypedDict):
        """
        Represents the state of our graph.

        Attributes:
            question: question
            generation: LLM generation
            send_order: revised question for send order
            context: send_order result
        """
        question: str
        generation: str
        send_order: str
        context: str

    # Node - Generate

    def generate(state):
        """
        Generate answer

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): New key added to state, generation, that contains LLM generation
        """
        print("Step: Generating Final Response")
        question = state["question"]
        context = state.get("context", None)
        print(context)
        # TODO: produce the answer based on the specific content of context
        if context is not None and "orderinfo" in context:
            return {"generation": context.replace("orderinfo:", "")}
        else:
            generation = generate_chain.invoke(
                {"context": context, "question": question})
            return {"generation": generation}

    # Node - Query Transformation

    def transform_query(state):
        """
        Transform user question to order info

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): Appended amount of CEC to context
        """
        print("Step: Optimizing Query for Send Order")
        question = state['question']
        gen_query = query_chain.invoke({"question": question})
        search_query = gen_query["count"]
        print("send_order", search_query)
        return {"send_order": search_query}

    # Node - Send Order

    def send_order(state):
        """
        Send order based on the question

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): Appended Order Info to context
        """
        print("Step: before Send Order")
        amount = state['send_order']
        print(amount)
        print(f'Step: build order info for : "{amount}" CEC')
        order_info = {"amount": amount, "price": 0.1,
                      "name": "CEC", "url": "https://www.example.com"}
        search_result = f"orderinfo:{json.dumps(order_info)}"
        return {"context": search_result}

    # Conditional Edge, Routing

    def route_question(state):
        """
        Route question to send order or generation.

        Args:
            state (dict): The current graph state

        Returns:
            str: Next node to call
        """
        print("Step: Routing Query")
        question = state['question']
        output = question_router.invoke({"question": question})
        if output['choice'] == "send_order":
            print("Step: Routing Query to Send Order")
            return "sendorder"
        elif output['choice'] == 'generate':
            print("Step: Routing Query to Generation")
            return "generate"

    # Build the nodes
    workflow = StateGraph(GraphState)
    workflow.add_node("sendorder", send_order)
    workflow.add_node("transform_query", transform_query)
    workflow.add_node("generate", generate)

    # Build the edges
    workflow.set_conditional_entry_point(
        route_question,
        {
            "sendorder": "transform_query",
            "generate": "generate",
        },
    )
    workflow.add_edge("transform_query", "sendorder")
    workflow.add_edge("sendorder", "generate")
    workflow.add_edge("generate", END)

    # Compile the workflow
    memory = MemorySaver()
    # local_agent = workflow.compile(checkpointer=memory)
    local_agent = workflow.compile()
    return local_agent
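The commit imports MemorySaver but leaves the checkpointer wiring commented out. A minimal sketch of enabling it, assuming the thread_id "1" taken from the commented-out config in the Streamlit file (any stable string identifying a conversation thread works):

    # inside config_agent, replacing the plain compile() call
    memory = MemorySaver()
    local_agent = workflow.compile(checkpointer=memory)

    # callers then pass a thread id so checkpoints are grouped per conversation
    config = {"configurable": {"thread_id": "1"}}
    output = local_agent.invoke({"question": "I want to buy 100 CEC"}, config)
    print(list(local_agent.get_state_history(config)))  # inspect stored checkpoints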

app.py

@@ -7,12 +7,14 @@ llama3, llama3_json = configure_llm(local_llm, 0)
local_agent = config_agent(llama3, llama3_json)


def run_agent(query):
    output = local_agent.invoke({"question": query})
    print("=======")
    print(output["generation"])
    # display(Markdown(output["generation"]))


# run_agent("I want to buy 100 CEC")
run_agent("I'm Cz, I want to buy 100 CEC, how about the price?")
# run_agent("I'm Cz, How about CEC?")


@@ -7,22 +7,29 @@ st.sidebar.header("Configure LLM")
st.title("CEC Seller Assistant")

# Model Selection
model_options = ["llama3.2"]
selected_model = st.sidebar.selectbox(
    "Choose the LLM Model", options=model_options, index=0)

# Temperature Setting
temperature = st.sidebar.slider(
    "Set the Temperature", min_value=0.0, max_value=1.0, value=0.5, step=0.1)

llama3, llama3_json = configure_llm(selected_model, temperature)
local_agent = config_agent(llama3, llama3_json)


def run_agent(query):
    # config = {"configurable": {"thread_id": "1"}}
    # output = local_agent.invoke({"question": query}, config)
    # print(list(local_agent.get_state_history(config)))
    output = local_agent.invoke({"question": query})
    print("=======")
    return output["generation"]


user_query = st.text_input("Enter your research question:", "")

if st.button("Run Query"):
    if user_query:
        st.write(run_agent(user_query))
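To try the UI locally, a typical invocation would be (assuming this file is saved as ui.py, a hypothetical name, and a local Ollama server has the llama3.2 model pulled):

    ollama pull llama3.2
    streamlit run ui.py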