reformat code

parent 3d2e96c1fa
commit 2b9e07caa3

agent_util.py (259 changed lines)
@@ -3,18 +3,22 @@ from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langgraph.graph import END, StateGraph
from langchain_community.chat_models import ChatOllama
# For State Graph
from langgraph.checkpoint.memory import MemorySaver
# For State Graph
from typing_extensions import TypedDict
import json


def configure_llm(model, temperature):
    llama3 = ChatOllama(model=model, temperature=temperature)
    llama3_json = ChatOllama(
        model=model, format='json', temperature=temperature)
    return llama3, llama3_json


def config_agent(llama3, llama3_json):
    generate_prompt = PromptTemplate(
        template="""

    <|begin_of_text|>
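For orientation (not part of the diff): configure_llm is consumed by the app code later in this commit roughly as in the sketch below. The "llama3.2" model name is taken from the UI's option list and the temperature of 0 mirrors app.py.

# Usage sketch only; values are taken from the app code further down in this commit.
llama3, llama3_json = configure_llm("llama3.2", 0)
local_agent = config_agent(llama3, llama3_json)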
@@ -49,13 +53,13 @@ def config_agent(llama3, llama3_json):
    <|eot_id|>

    <|start_header_id|>assistant<|end_header_id|>""",
        input_variables=["question", "context"],
    )

    # Chain
    generate_chain = generate_prompt | llama3 | StrOutputParser()
    router_prompt = PromptTemplate(
        template="""

    <|begin_of_text|>
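As a quick illustration (again not part of the diff), generate_chain pairs the prompt with the plain-text model, so a direct call looks roughly like this; the question string is one of the sample queries used in app.py.

# Illustrative call; generate() below does exactly this when no order info is in the context.
answer = generate_chain.invoke({"context": None, "question": "I'm Cz, How about CEC?"})
print(answer)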
@@ -77,14 +81,14 @@ def config_agent(llama3, llama3_json):
    <|start_header_id|>assistant<|end_header_id|>

    """,
        input_variables=["question"],
    )

    # Chain
    question_router = router_prompt | llama3_json | JsonOutputParser()

    query_prompt = PromptTemplate(
        template="""

    <|begin_of_text|>
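The router template's wording is elided from this hunk, but route_question() in the next hunk reads output['choice'], so the JSON-mode model is expected to answer with a single routing key. The query chain defined below is read the same way via gen_query["count"]. Illustrative shapes only, not literal model output:

# {"choice": "send_order"}   -> route to transform_query / send_order
# {"choice": "generate"}     -> answer directly with the LLM
# {"count": "100"}           -> amount of CEC extracted by query_chain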
@@ -100,132 +104,137 @@ def config_agent(llama3, llama3_json):
    <|start_header_id|>assistant<|end_header_id|>

    """,
        input_variables=["question"],
    )

    # Chain
    query_chain = query_prompt | llama3_json | JsonOutputParser()

    # Graph State
    class GraphState(TypedDict):
        """
        Represents the state of our graph.

        Attributes:
            question: question
            generation: LLM generation
            send_order: revised question for send order
            context: send_order result
        """
        question: str
        generation: str
        send_order: str
        context: str

    # Node - Generate
    def generate(state):
        """
        Generate answer

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): New key added to state, generation, that contains LLM generation
        """

        print("Step: Generating Final Response")
        question = state["question"]
        context = state.get("context", None)
        print(context)
        # TODO: generate the answer from the specific content of the context
        if context is not None and context.index("orderinfo") != -1:
            return {"generation": context.replace("orderinfo:", "")}
        else:
            generation = generate_chain.invoke(
                {"context": context, "question": question})
            return {"generation": generation}

    # Node - Query Transformation
    def transform_query(state):
        """
        Transform user question to order info

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): Appended amount of CEC to context
        """

        print("Step: Optimizing Query for Send Order")
        question = state['question']
        gen_query = query_chain.invoke({"question": question})
        search_query = gen_query["count"]
        print("send_order", search_query)
        return {"send_order": search_query}

    # Node - Send Order
    def send_order(state):
        """
        Send order based on the question

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): Appended Order Info to context
        """
        print("Step: before Send Order")
        amount = state['send_order']
        print(amount)
        print(f'Step: build order info for : "{amount}" CEC')
        order_info = {"amount": amount, "price": 0.1,
                      "name": "CEC", "url": "https://www.example.com"}
        search_result = f"orderinfo:{json.dumps(order_info)}"
        return {"context": search_result}

    # Conditional Edge, Routing
    def route_question(state):
        """
        Route the question to send order or generation.

        Args:
            state (dict): The current graph state

        Returns:
            str: Next node to call
        """

        print("Step: Routing Query")
        question = state['question']
        output = question_router.invoke({"question": question})
        if output['choice'] == "send_order":
            print("Step: Routing Query to Send Order")
            return "sendorder"
        elif output['choice'] == 'generate':
            print("Step: Routing Query to Generation")
            return "generate"

    # Build the nodes
    workflow = StateGraph(GraphState)
    workflow.add_node("sendorder", send_order)
    workflow.add_node("transform_query", transform_query)
    workflow.add_node("generate", generate)

    # Build the edges
    workflow.set_conditional_entry_point(
        route_question,
        {
            "sendorder": "transform_query",
            "generate": "generate",
        },
    )
    workflow.add_edge("transform_query", "sendorder")
    workflow.add_edge("sendorder", "generate")
    workflow.add_edge("generate", END)

    # Compile the workflow
    memory = MemorySaver()
    # local_agent = workflow.compile(checkpointer=memory)
    local_agent = workflow.compile()
    return local_agent
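One note on the new MemorySaver lines above: the checkpointer is created but the compile call that would use it stays commented out. If it were switched on, the wiring would look roughly like the sketch below, mirroring the commented-out lines elsewhere in this commit; the thread_id value is only an example.

# Sketch, not part of the diff: enabling the commented-out checkpointer.
memory = MemorySaver()
local_agent = workflow.compile(checkpointer=memory)
# Callers would then pass a thread id so conversation state is kept per thread:
# config = {"configurable": {"thread_id": "1"}}
# output = local_agent.invoke({"question": query}, config)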
app.py (4 changed lines)

@@ -7,12 +7,14 @@ llama3, llama3_json = configure_llm(local_llm, 0)

local_agent = config_agent(llama3, llama3_json)


def run_agent(query):
    output = local_agent.invoke({"question": query})
    print("=======")
    print(output["generation"])
    # display(Markdown(output["generation"]))


# run_agent("I want to buy 100 CEC")
run_agent("I'm Cz, I want to buy 100 CEC, how about the price?")
# run_agent("I'm Cz, How about CEC?")
@@ -7,22 +7,29 @@ st.sidebar.header("Configure LLM")
st.title("CEC Seller Assistant")
# Model Selection
model_options = ["llama3.2"]
selected_model = st.sidebar.selectbox(
    "Choose the LLM Model", options=model_options, index=0)

# Temperature Setting
temperature = st.sidebar.slider(
    "Set the Temperature", min_value=0.0, max_value=1.0, value=0.5, step=0.1)

llama3, llama3_json = configure_llm(selected_model, temperature)

local_agent = config_agent(llama3, llama3_json)


def run_agent(query):
    # config = {"configurable": {"thread_id": "1"}}
    # output = local_agent.invoke({"question": query}, config)
    # print(list(local_agent.get_state_history(config)))
    output = local_agent.invoke({"question": query})
    print("=======")

    return output["generation"]


user_query = st.text_input("Enter your research question:", "")

if st.button("Run Query"):
    if user_query:
        st.write(run_agent(user_query))