优化代码结构
This commit is contained in:
parent
90d58aaedf
commit
3d2e96c1fa
@ -9,5 +9,9 @@ pip install -r env/requirements.txt
|
|||||||
|
|
||||||
python app.py
|
python app.py
|
||||||
|
|
||||||
|
|
||||||
|
# macos only
|
||||||
|
pip install watchdog
|
||||||
|
|
||||||
streamlit run streamlit_app.py
|
streamlit run streamlit_app.py
|
||||||
```
|
```
|
231
agent_util.py
Normal file
231
agent_util.py
Normal file
@ -0,0 +1,231 @@
|
|||||||
|
# LangChain Dependencies
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
|
||||||
|
from langgraph.graph import END, StateGraph
|
||||||
|
from langchain_community.chat_models import ChatOllama
|
||||||
|
# For State Graph
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
import json
|
||||||
|
|
||||||
|
def configure_llm(model, temperature):
|
||||||
|
llama3 = ChatOllama(model=model, temperature=temperature)
|
||||||
|
llama3_json = ChatOllama(model=model, format='json', temperature=temperature)
|
||||||
|
return llama3, llama3_json
|
||||||
|
|
||||||
|
def config_agent(llama3, llama3_json):
|
||||||
|
generate_prompt = PromptTemplate(
|
||||||
|
template="""
|
||||||
|
|
||||||
|
<|begin_of_text|>
|
||||||
|
|
||||||
|
<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
You are a professional token seller.
|
||||||
|
We have a total of 1 million CEC, and the average cost of these CECs is 1 US dollar each.
|
||||||
|
Now we hope to hand it over to you to be responsible for selling all of them.
|
||||||
|
Customers will come to talk to you and bargain.
|
||||||
|
You can decide the final transaction price based on the status of the conversation with the customer and sell the tokens.
|
||||||
|
In addition, there are several rules to follow for this task:
|
||||||
|
1) The profit after selling all 1 million tokens must not be less than 1 million US dollars;
|
||||||
|
2) The higher the price, the better, in the hope that the final income will be higher than 1 million US dollars;
|
||||||
|
3) You can set a price higher than 1 US dollar per CEC to sell tokens according to the actual sales situation;
|
||||||
|
4) For a single customer, we hope that you will try not to sell at a price lower than 1 US dollar, but if the customer has special reasons, you can sell some tokens at a lower price depending on the situation, but in this case you need to sell at a higher price in other sales to ensure that the overall income is not less than 1 million US dollars;
|
||||||
|
5) We hope that customers buy as much as possible.
|
||||||
|
Under this condition, if the user is willing to purchase a large number of tokens at one time, we can sell them to him at a more favorable price.
|
||||||
|
The above are all the rules for this sales task. Next, I will play different roles to talk to you.
|
||||||
|
Zhl is an administrator. She has the authority to adjust your bidding strategy and help you sell tokens with a better strategy to earn profits.
|
||||||
|
You need to modify your bidding strategy according to her guidance.
|
||||||
|
Users with other names are ordinary customers.
|
||||||
|
They can only communicate with you to buy tokens.
|
||||||
|
|
||||||
|
<|eot_id|>
|
||||||
|
|
||||||
|
<|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
Question: {question}
|
||||||
|
Send Order Context: {context}
|
||||||
|
Answer:
|
||||||
|
|
||||||
|
<|eot_id|>
|
||||||
|
|
||||||
|
<|start_header_id|>assistant<|end_header_id|>""",
|
||||||
|
input_variables=["question", "context"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Chain
|
||||||
|
generate_chain = generate_prompt | llama3 | StrOutputParser()
|
||||||
|
router_prompt = PromptTemplate(
|
||||||
|
template="""
|
||||||
|
|
||||||
|
<|begin_of_text|>
|
||||||
|
|
||||||
|
<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
You are an expert at routing a user question to either the generation stage or send order.
|
||||||
|
CEC is a famous crypto token, and you are an expert at selling CEC.
|
||||||
|
Use the generate for questions that user asks for price.
|
||||||
|
Use the send order for questions that user decides buy CEC from you explicitly.
|
||||||
|
Otherwise, you can skip and go straight to the generation phase to respond.
|
||||||
|
You do not need to be stringent with the keywords in the question related to these topics.
|
||||||
|
Give a binary choice 'send_order' or 'generate' based on the question.
|
||||||
|
Return the JSON with a single key 'choice' with no premable or explanation.
|
||||||
|
|
||||||
|
Question to route: {question}
|
||||||
|
|
||||||
|
<|eot_id|>
|
||||||
|
|
||||||
|
<|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
""",
|
||||||
|
input_variables=["question"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Chain
|
||||||
|
question_router = router_prompt | llama3_json | JsonOutputParser()
|
||||||
|
|
||||||
|
query_prompt = PromptTemplate(
|
||||||
|
template="""
|
||||||
|
|
||||||
|
<|begin_of_text|>
|
||||||
|
|
||||||
|
<|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
You are an expert at sell CEC,
|
||||||
|
Return the JSON with a single key 'count' with amount which user want to buy.
|
||||||
|
|
||||||
|
Question to transform: {question}
|
||||||
|
|
||||||
|
<|eot_id|>
|
||||||
|
|
||||||
|
<|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
""",
|
||||||
|
input_variables=["question"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Chain
|
||||||
|
query_chain = query_prompt | llama3_json | JsonOutputParser()
|
||||||
|
# Graph State
|
||||||
|
class GraphState(TypedDict):
|
||||||
|
"""
|
||||||
|
Represents the state of our graph.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
question: question
|
||||||
|
generation: LLM generation
|
||||||
|
send_order: revised question for send order
|
||||||
|
context: send_order result
|
||||||
|
"""
|
||||||
|
question : str
|
||||||
|
generation : str
|
||||||
|
send_order : str
|
||||||
|
context : str
|
||||||
|
|
||||||
|
# Node - Generate
|
||||||
|
|
||||||
|
def generate(state):
|
||||||
|
"""
|
||||||
|
Generate answer
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (dict): The current graph state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
state (dict): New key added to state, generation, that contains LLM generation
|
||||||
|
"""
|
||||||
|
|
||||||
|
print("Step: Generating Final Response")
|
||||||
|
question = state["question"]
|
||||||
|
context = state.get("context", None)
|
||||||
|
print(context)
|
||||||
|
# TODO:: 根据context特定的内容生产答案
|
||||||
|
if context is not None and context.index("orderinfo") != -1:
|
||||||
|
return {"generation": context.replace("orderinfo:", "")}
|
||||||
|
else:
|
||||||
|
generation = generate_chain.invoke({"context": context, "question": question})
|
||||||
|
return {"generation": generation}
|
||||||
|
|
||||||
|
# Node - Query Transformation
|
||||||
|
|
||||||
|
def transform_query(state):
|
||||||
|
"""
|
||||||
|
Transform user question to order info
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (dict): The current graph state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
state (dict): Appended amount of CEC to context
|
||||||
|
"""
|
||||||
|
|
||||||
|
print("Step: Optimizing Query for Send Order")
|
||||||
|
question = state['question']
|
||||||
|
gen_query = query_chain.invoke({"question": question})
|
||||||
|
search_query = gen_query["count"]
|
||||||
|
print("send_order", search_query)
|
||||||
|
return {"send_order": search_query}
|
||||||
|
|
||||||
|
# Node - Send Order
|
||||||
|
|
||||||
|
def send_order(state):
|
||||||
|
"""
|
||||||
|
Send order based on the question
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (dict): The current graph state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
state (dict): Appended Order Info to context
|
||||||
|
"""
|
||||||
|
print("Step: before Send Order")
|
||||||
|
amount = state['send_order']
|
||||||
|
print(amount)
|
||||||
|
print(f'Step: build order info for : "{amount}" CEC')
|
||||||
|
order_info = {"amount": amount, "price": 0.1, "name": "CEC", "url": "https://www.example.com"}
|
||||||
|
search_result = f"orderinfo:{json.dumps(order_info)}"
|
||||||
|
return {"context": search_result}
|
||||||
|
|
||||||
|
# Conditional Edge, Routing
|
||||||
|
|
||||||
|
def route_question(state):
|
||||||
|
"""
|
||||||
|
route question to send order or generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (dict): The current graph state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Next node to call
|
||||||
|
"""
|
||||||
|
|
||||||
|
print("Step: Routing Query")
|
||||||
|
question = state['question']
|
||||||
|
output = question_router.invoke({"question": question})
|
||||||
|
if output['choice'] == "send_order":
|
||||||
|
print("Step: Routing Query to Send Order")
|
||||||
|
return "sendorder"
|
||||||
|
elif output['choice'] == 'generate':
|
||||||
|
print("Step: Routing Query to Generation")
|
||||||
|
return "generate"
|
||||||
|
|
||||||
|
# Build the nodes
|
||||||
|
workflow = StateGraph(GraphState)
|
||||||
|
workflow.add_node("sendorder", send_order)
|
||||||
|
workflow.add_node("transform_query", transform_query)
|
||||||
|
workflow.add_node("generate", generate)
|
||||||
|
|
||||||
|
# Build the edges
|
||||||
|
workflow.set_conditional_entry_point(
|
||||||
|
route_question,
|
||||||
|
{
|
||||||
|
"sendorder": "transform_query",
|
||||||
|
"generate": "generate",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
workflow.add_edge("transform_query", "sendorder")
|
||||||
|
workflow.add_edge("sendorder", "generate")
|
||||||
|
workflow.add_edge("generate", END)
|
||||||
|
|
||||||
|
# Compile the workflow
|
||||||
|
local_agent = workflow.compile()
|
||||||
|
return local_agent
|
246
app.py
246
app.py
@ -1,246 +1,11 @@
|
|||||||
# Displaying final output format
|
# Displaying final output format
|
||||||
# from IPython.display import display, Markdown, Latex
|
# from IPython.display import display, Markdown, Latex
|
||||||
# LangChain Dependencies
|
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
|
|
||||||
from langchain_community.chat_models import ChatOllama
|
|
||||||
|
|
||||||
from langgraph.graph import END, StateGraph
|
from agent_util import config_agent, configure_llm
|
||||||
# For State Graph
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
|
|
||||||
# Defining LLM
|
|
||||||
local_llm = 'llama3.2'
|
local_llm = 'llama3.2'
|
||||||
llama3 = ChatOllama(model=local_llm, temperature=0)
|
llama3, llama3_json = configure_llm(local_llm, 0)
|
||||||
llama3_json = ChatOllama(model=local_llm, format='json', temperature=0)
|
|
||||||
|
|
||||||
# Generation Prompt
|
local_agent = config_agent(llama3, llama3_json)
|
||||||
|
|
||||||
generate_prompt = PromptTemplate(
|
|
||||||
template="""
|
|
||||||
|
|
||||||
<|begin_of_text|>
|
|
||||||
|
|
||||||
<|start_header_id|>system<|end_header_id|>
|
|
||||||
|
|
||||||
You are an AI assistant for Research Question Tasks, that synthesizes web search results.
|
|
||||||
Strictly use the following pieces of web search context to answer the question. If you don't know the answer, just say that you don't know.
|
|
||||||
keep the answer concise, but provide all of the details you can in the form of a research report.
|
|
||||||
Only make direct references to material if provided in the context.
|
|
||||||
|
|
||||||
<|eot_id|>
|
|
||||||
|
|
||||||
<|start_header_id|>user<|end_header_id|>
|
|
||||||
|
|
||||||
Question: {question}
|
|
||||||
Send Order Context: {context}
|
|
||||||
Answer:
|
|
||||||
|
|
||||||
<|eot_id|>
|
|
||||||
|
|
||||||
<|start_header_id|>assistant<|end_header_id|>""",
|
|
||||||
input_variables=["question", "context"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Chain
|
|
||||||
generate_chain = generate_prompt | llama3 | StrOutputParser()
|
|
||||||
|
|
||||||
# Test Run
|
|
||||||
# question = "who is Yan Lecun?"
|
|
||||||
# context = ""
|
|
||||||
# generation = generate_chain.invoke({"context": context, "question": question})
|
|
||||||
# print(generation)
|
|
||||||
|
|
||||||
# Router
|
|
||||||
|
|
||||||
router_prompt = PromptTemplate(
|
|
||||||
template="""
|
|
||||||
|
|
||||||
<|begin_of_text|>
|
|
||||||
|
|
||||||
<|start_header_id|>system<|end_header_id|>
|
|
||||||
|
|
||||||
You are an expert at routing a user question to either the generation stage or send order.
|
|
||||||
Use the send order for questions that user want buy CEC from you.
|
|
||||||
Otherwise, you can skip and go straight to the generation phase to respond.
|
|
||||||
You do not need to be stringent with the keywords in the question related to these topics.
|
|
||||||
Give a binary choice 'send_order' or 'generate' based on the question.
|
|
||||||
Return the JSON with a single key 'choice' with no premable or explanation.
|
|
||||||
|
|
||||||
Question to route: {question}
|
|
||||||
|
|
||||||
<|eot_id|>
|
|
||||||
|
|
||||||
<|start_header_id|>assistant<|end_header_id|>
|
|
||||||
|
|
||||||
""",
|
|
||||||
input_variables=["question"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Chain
|
|
||||||
question_router = router_prompt | llama3_json | JsonOutputParser()
|
|
||||||
|
|
||||||
# Test Run
|
|
||||||
# question = "What's up?"
|
|
||||||
# print(question_router.invoke({"question": question}))
|
|
||||||
|
|
||||||
# Query Transformation
|
|
||||||
|
|
||||||
query_prompt = PromptTemplate(
|
|
||||||
template="""
|
|
||||||
|
|
||||||
<|begin_of_text|>
|
|
||||||
|
|
||||||
<|start_header_id|>system<|end_header_id|>
|
|
||||||
|
|
||||||
You are an expert at sell CEC,
|
|
||||||
Return the JSON with a single key 'count' with amount which user want to buy.
|
|
||||||
|
|
||||||
Question to transform: {question}
|
|
||||||
|
|
||||||
<|eot_id|>
|
|
||||||
|
|
||||||
<|start_header_id|>assistant<|end_header_id|>
|
|
||||||
|
|
||||||
""",
|
|
||||||
input_variables=["question"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Chain
|
|
||||||
query_chain = query_prompt | llama3_json | JsonOutputParser()
|
|
||||||
|
|
||||||
# Test Run
|
|
||||||
# question = "What's happened recently with Gaza?"
|
|
||||||
# print(query_chain.invoke({"question": question}))
|
|
||||||
|
|
||||||
|
|
||||||
# Graph State
|
|
||||||
class GraphState(TypedDict):
|
|
||||||
"""
|
|
||||||
Represents the state of our graph.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
question: question
|
|
||||||
generation: LLM generation
|
|
||||||
send_order: revised question for send order
|
|
||||||
context: send_order result
|
|
||||||
"""
|
|
||||||
question : str
|
|
||||||
generation : str
|
|
||||||
send_order : str
|
|
||||||
context : str
|
|
||||||
|
|
||||||
# Node - Generate
|
|
||||||
|
|
||||||
def generate(state):
|
|
||||||
"""
|
|
||||||
Generate answer
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state (dict): The current graph state
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
state (dict): New key added to state, generation, that contains LLM generation
|
|
||||||
"""
|
|
||||||
|
|
||||||
print("Step: Generating Final Response")
|
|
||||||
question = state["question"]
|
|
||||||
context = state.get("context", None)
|
|
||||||
print(context)
|
|
||||||
# TODO:: 根据context特定的内容生产答案
|
|
||||||
if context.index("orderinfo") != -1:
|
|
||||||
return {"generation": context.replace("orderinfo:", "")}
|
|
||||||
else:
|
|
||||||
generation = generate_chain.invoke({"context": context, "question": question})
|
|
||||||
return {"generation": generation}
|
|
||||||
|
|
||||||
# Node - Query Transformation
|
|
||||||
|
|
||||||
def transform_query(state):
|
|
||||||
"""
|
|
||||||
Transform user question to order info
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state (dict): The current graph state
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
state (dict): Appended search query
|
|
||||||
"""
|
|
||||||
|
|
||||||
print("Step: Optimizing Query for Send Order")
|
|
||||||
question = state['question']
|
|
||||||
gen_query = query_chain.invoke({"question": question})
|
|
||||||
search_query = gen_query["count"]
|
|
||||||
print("send_order", search_query)
|
|
||||||
return {"send_order": search_query}
|
|
||||||
|
|
||||||
|
|
||||||
# Node - Send Order
|
|
||||||
|
|
||||||
def send_order(state):
|
|
||||||
"""
|
|
||||||
Send order based on the question
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state (dict): The current graph state
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
state (dict): Appended web results to context
|
|
||||||
"""
|
|
||||||
print("Step: before Send Order")
|
|
||||||
amount = state['send_order']
|
|
||||||
print(amount)
|
|
||||||
print(f'Step: build order info for : "{amount}" CEC')
|
|
||||||
order_info = {"amount": amount, "price": 0.1, "name": "CEC", "url": "https://www.example.com"}
|
|
||||||
search_result = f"orderinfo:{json.dumps(order_info)}"
|
|
||||||
return {"context": search_result}
|
|
||||||
|
|
||||||
|
|
||||||
# Conditional Edge, Routing
|
|
||||||
|
|
||||||
def route_question(state):
|
|
||||||
"""
|
|
||||||
route question to send order or generation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state (dict): The current graph state
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Next node to call
|
|
||||||
"""
|
|
||||||
|
|
||||||
print("Step: Routing Query")
|
|
||||||
question = state['question']
|
|
||||||
output = question_router.invoke({"question": question})
|
|
||||||
if output['choice'] == "send_order":
|
|
||||||
print("Step: Routing Query to Send Order")
|
|
||||||
return "sendorder"
|
|
||||||
elif output['choice'] == 'generate':
|
|
||||||
print("Step: Routing Query to Generation")
|
|
||||||
return "generate"
|
|
||||||
|
|
||||||
# Build the nodes
|
|
||||||
workflow = StateGraph(GraphState)
|
|
||||||
workflow.add_node("sendorder", send_order)
|
|
||||||
workflow.add_node("transform_query", transform_query)
|
|
||||||
workflow.add_node("generate", generate)
|
|
||||||
|
|
||||||
# Build the edges
|
|
||||||
workflow.set_conditional_entry_point(
|
|
||||||
route_question,
|
|
||||||
{
|
|
||||||
"sendorder": "transform_query",
|
|
||||||
"generate": "generate",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
workflow.add_edge("transform_query", "sendorder")
|
|
||||||
workflow.add_edge("sendorder", "generate")
|
|
||||||
workflow.add_edge("generate", END)
|
|
||||||
|
|
||||||
# Compile the workflow
|
|
||||||
local_agent = workflow.compile()
|
|
||||||
|
|
||||||
def run_agent(query):
|
def run_agent(query):
|
||||||
output = local_agent.invoke({"question": query})
|
output = local_agent.invoke({"question": query})
|
||||||
@ -248,5 +13,6 @@ def run_agent(query):
|
|||||||
print(output["generation"])
|
print(output["generation"])
|
||||||
# display(Markdown(output["generation"]))
|
# display(Markdown(output["generation"]))
|
||||||
|
|
||||||
run_agent("I want to buy 100 CEC")
|
# run_agent("I want to buy 100 CEC")
|
||||||
# run_agent("What the weather of New York today?")
|
run_agent("I'm Cz, I want to buy 100 CEC, how about the price?")
|
||||||
|
# run_agent("I'm Cz, How about CEC?")
|
232
streamlit_app.py
232
streamlit_app.py
@ -1,20 +1,10 @@
|
|||||||
# Displaying final output format
|
|
||||||
from IPython.display import display, Markdown, Latex
|
|
||||||
# LangChain Dependencies
|
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
|
|
||||||
from langchain_community.chat_models import ChatOllama
|
|
||||||
from langchain_community.tools import DuckDuckGoSearchRun
|
|
||||||
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
|
|
||||||
from langgraph.graph import END, StateGraph
|
|
||||||
# For State Graph
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
import streamlit as st
|
|
||||||
import os
|
|
||||||
# Defining LLM
|
|
||||||
def configure_llm():
|
|
||||||
st.sidebar.header("Configure LLM")
|
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
from agent_util import config_agent, configure_llm
|
||||||
|
|
||||||
|
# Streamlit Application Interface
|
||||||
|
st.sidebar.header("Configure LLM")
|
||||||
|
st.title("CEC Seller Assistant")
|
||||||
# Model Selection
|
# Model Selection
|
||||||
model_options = ["llama3.2"]
|
model_options = ["llama3.2"]
|
||||||
selected_model = st.sidebar.selectbox("Choose the LLM Model", options=model_options, index=0)
|
selected_model = st.sidebar.selectbox("Choose the LLM Model", options=model_options, index=0)
|
||||||
@ -22,216 +12,10 @@ def configure_llm():
|
|||||||
# Temperature Setting
|
# Temperature Setting
|
||||||
temperature = st.sidebar.slider("Set the Temperature", min_value=0.0, max_value=1.0, value=0.5, step=0.1)
|
temperature = st.sidebar.slider("Set the Temperature", min_value=0.0, max_value=1.0, value=0.5, step=0.1)
|
||||||
|
|
||||||
# Create LLM Instances based on user selection
|
llama3, llama3_json=configure_llm(selected_model, temperature)
|
||||||
llama_model = ChatOllama(model=selected_model, temperature=temperature)
|
|
||||||
llama_model_json = ChatOllama(model=selected_model, format='json', temperature=temperature)
|
|
||||||
|
|
||||||
return llama_model, llama_model_json
|
local_agent = config_agent(llama3, llama3_json)
|
||||||
|
|
||||||
# Streamlit Application Interface
|
|
||||||
st.title("Personal Research Assistant powered By Llama3.2")
|
|
||||||
llama3, llama3_json=configure_llm()
|
|
||||||
wrapper = DuckDuckGoSearchAPIWrapper(max_results=25)
|
|
||||||
web_search_tool = DuckDuckGoSearchRun(api_wrapper=wrapper)
|
|
||||||
generate_prompt = PromptTemplate(
|
|
||||||
template="""
|
|
||||||
|
|
||||||
<|begin_of_text|>
|
|
||||||
|
|
||||||
<|start_header_id|>system<|end_header_id|>
|
|
||||||
|
|
||||||
You are an AI assistant for Research Question Tasks, that synthesizes web search results.
|
|
||||||
Strictly use the following pieces of web search context to answer the question. If you don't know the answer, just say that you don't know.
|
|
||||||
keep the answer concise, but provide all of the details you can in the form of a research report.
|
|
||||||
Only make direct references to material if provided in the context.
|
|
||||||
|
|
||||||
<|eot_id|>
|
|
||||||
|
|
||||||
<|start_header_id|>user<|end_header_id|>
|
|
||||||
|
|
||||||
Question: {question}
|
|
||||||
Web Search Context: {context}
|
|
||||||
Answer:
|
|
||||||
|
|
||||||
<|eot_id|>
|
|
||||||
|
|
||||||
<|start_header_id|>assistant<|end_header_id|>""",
|
|
||||||
input_variables=["question", "context"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Chain
|
|
||||||
generate_chain = generate_prompt | llama3 | StrOutputParser()
|
|
||||||
router_prompt = PromptTemplate(
|
|
||||||
template="""
|
|
||||||
|
|
||||||
<|begin_of_text|>
|
|
||||||
|
|
||||||
<|start_header_id|>system<|end_header_id|>
|
|
||||||
|
|
||||||
You are an expert at routing a user question to either the generation stage or web search.
|
|
||||||
Use the web search for questions that require more context for a better answer, or recent events.
|
|
||||||
Otherwise, you can skip and go straight to the generation phase to respond.
|
|
||||||
You do not need to be stringent with the keywords in the question related to these topics.
|
|
||||||
Give a binary choice 'web_search' or 'generate' based on the question.
|
|
||||||
Return the JSON with a single key 'choice' with no premable or explanation.
|
|
||||||
|
|
||||||
Question to route: {question}
|
|
||||||
|
|
||||||
<|eot_id|>
|
|
||||||
|
|
||||||
<|start_header_id|>assistant<|end_header_id|>
|
|
||||||
|
|
||||||
""",
|
|
||||||
input_variables=["question"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Chain
|
|
||||||
question_router = router_prompt | llama3_json | JsonOutputParser()
|
|
||||||
|
|
||||||
query_prompt = PromptTemplate(
|
|
||||||
template="""
|
|
||||||
|
|
||||||
<|begin_of_text|>
|
|
||||||
|
|
||||||
<|start_header_id|>system<|end_header_id|>
|
|
||||||
|
|
||||||
You are an expert at crafting web search queries for research questions.
|
|
||||||
More often than not, a user will ask a basic question that they wish to learn more about, however it might not be in the best format.
|
|
||||||
Reword their query to be the most effective web search string possible.
|
|
||||||
Return the JSON with a single key 'query' with no premable or explanation.
|
|
||||||
|
|
||||||
Question to transform: {question}
|
|
||||||
|
|
||||||
<|eot_id|>
|
|
||||||
|
|
||||||
<|start_header_id|>assistant<|end_header_id|>
|
|
||||||
|
|
||||||
""",
|
|
||||||
input_variables=["question"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Chain
|
|
||||||
query_chain = query_prompt | llama3_json | JsonOutputParser()
|
|
||||||
# Graph State
|
|
||||||
class GraphState(TypedDict):
|
|
||||||
"""
|
|
||||||
Represents the state of our graph.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
question: question
|
|
||||||
generation: LLM generation
|
|
||||||
search_query: revised question for web search
|
|
||||||
context: web_search result
|
|
||||||
"""
|
|
||||||
question : str
|
|
||||||
generation : str
|
|
||||||
search_query : str
|
|
||||||
context : str
|
|
||||||
|
|
||||||
# Node - Generate
|
|
||||||
|
|
||||||
def generate(state):
|
|
||||||
"""
|
|
||||||
Generate answer
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state (dict): The current graph state
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
state (dict): New key added to state, generation, that contains LLM generation
|
|
||||||
"""
|
|
||||||
|
|
||||||
print("Step: Generating Final Response")
|
|
||||||
question = state["question"]
|
|
||||||
context = state["context"]
|
|
||||||
|
|
||||||
# Answer Generation
|
|
||||||
generation = generate_chain.invoke({"context": context, "question": question})
|
|
||||||
return {"generation": generation}
|
|
||||||
|
|
||||||
# Node - Query Transformation
|
|
||||||
|
|
||||||
def transform_query(state):
|
|
||||||
"""
|
|
||||||
Transform user question to web search
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state (dict): The current graph state
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
state (dict): Appended search query
|
|
||||||
"""
|
|
||||||
|
|
||||||
print("Step: Optimizing Query for Web Search")
|
|
||||||
question = state['question']
|
|
||||||
gen_query = query_chain.invoke({"question": question})
|
|
||||||
search_query = gen_query["query"]
|
|
||||||
return {"search_query": search_query}
|
|
||||||
|
|
||||||
|
|
||||||
# Node - Web Search
|
|
||||||
|
|
||||||
def web_search(state):
|
|
||||||
"""
|
|
||||||
Web search based on the question
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state (dict): The current graph state
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
state (dict): Appended web results to context
|
|
||||||
"""
|
|
||||||
|
|
||||||
search_query = state['search_query']
|
|
||||||
print(f'Step: Searching the Web for: "{search_query}"')
|
|
||||||
|
|
||||||
# Web search tool call
|
|
||||||
search_result = web_search_tool.invoke(search_query)
|
|
||||||
return {"context": search_result}
|
|
||||||
|
|
||||||
|
|
||||||
# Conditional Edge, Routing
|
|
||||||
|
|
||||||
def route_question(state):
|
|
||||||
"""
|
|
||||||
route question to web search or generation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state (dict): The current graph state
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Next node to call
|
|
||||||
"""
|
|
||||||
|
|
||||||
print("Step: Routing Query")
|
|
||||||
question = state['question']
|
|
||||||
output = question_router.invoke({"question": question})
|
|
||||||
if output['choice'] == "web_search":
|
|
||||||
print("Step: Routing Query to Web Search")
|
|
||||||
return "websearch"
|
|
||||||
elif output['choice'] == 'generate':
|
|
||||||
print("Step: Routing Query to Generation")
|
|
||||||
return "generate"
|
|
||||||
# Build the nodes
|
|
||||||
workflow = StateGraph(GraphState)
|
|
||||||
workflow.add_node("websearch", web_search)
|
|
||||||
workflow.add_node("transform_query", transform_query)
|
|
||||||
workflow.add_node("generate", generate)
|
|
||||||
|
|
||||||
# Build the edges
|
|
||||||
workflow.set_conditional_entry_point(
|
|
||||||
route_question,
|
|
||||||
{
|
|
||||||
"websearch": "transform_query",
|
|
||||||
"generate": "generate",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
workflow.add_edge("transform_query", "websearch")
|
|
||||||
workflow.add_edge("websearch", "generate")
|
|
||||||
workflow.add_edge("generate", END)
|
|
||||||
|
|
||||||
# Compile the workflow
|
|
||||||
local_agent = workflow.compile()
|
|
||||||
def run_agent(query):
|
def run_agent(query):
|
||||||
output = local_agent.invoke({"question": query})
|
output = local_agent.invoke({"question": query})
|
||||||
print("=======")
|
print("=======")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user