reformat code
This commit is contained in:
parent
3d2e96c1fa
commit
2b9e07caa3
@ -3,15 +3,19 @@ from langchain.prompts import PromptTemplate
|
|||||||
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
|
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
|
||||||
from langgraph.graph import END, StateGraph
|
from langgraph.graph import END, StateGraph
|
||||||
from langchain_community.chat_models import ChatOllama
|
from langchain_community.chat_models import ChatOllama
|
||||||
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
# For State Graph
|
# For State Graph
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
def configure_llm(model, temperature):
|
def configure_llm(model, temperature):
|
||||||
llama3 = ChatOllama(model=model, temperature=temperature)
|
llama3 = ChatOllama(model=model, temperature=temperature)
|
||||||
llama3_json = ChatOllama(model=model, format='json', temperature=temperature)
|
llama3_json = ChatOllama(
|
||||||
|
model=model, format='json', temperature=temperature)
|
||||||
return llama3, llama3_json
|
return llama3, llama3_json
|
||||||
|
|
||||||
|
|
||||||
def config_agent(llama3, llama3_json):
|
def config_agent(llama3, llama3_json):
|
||||||
generate_prompt = PromptTemplate(
|
generate_prompt = PromptTemplate(
|
||||||
template="""
|
template="""
|
||||||
@ -106,6 +110,7 @@ def config_agent(llama3, llama3_json):
|
|||||||
# Chain
|
# Chain
|
||||||
query_chain = query_prompt | llama3_json | JsonOutputParser()
|
query_chain = query_prompt | llama3_json | JsonOutputParser()
|
||||||
# Graph State
|
# Graph State
|
||||||
|
|
||||||
class GraphState(TypedDict):
|
class GraphState(TypedDict):
|
||||||
"""
|
"""
|
||||||
Represents the state of our graph.
|
Represents the state of our graph.
|
||||||
@ -142,7 +147,8 @@ def config_agent(llama3, llama3_json):
|
|||||||
if context is not None and context.index("orderinfo") != -1:
|
if context is not None and context.index("orderinfo") != -1:
|
||||||
return {"generation": context.replace("orderinfo:", "")}
|
return {"generation": context.replace("orderinfo:", "")}
|
||||||
else:
|
else:
|
||||||
generation = generate_chain.invoke({"context": context, "question": question})
|
generation = generate_chain.invoke(
|
||||||
|
{"context": context, "question": question})
|
||||||
return {"generation": generation}
|
return {"generation": generation}
|
||||||
|
|
||||||
# Node - Query Transformation
|
# Node - Query Transformation
|
||||||
@ -181,7 +187,8 @@ def config_agent(llama3, llama3_json):
|
|||||||
amount = state['send_order']
|
amount = state['send_order']
|
||||||
print(amount)
|
print(amount)
|
||||||
print(f'Step: build order info for : "{amount}" CEC')
|
print(f'Step: build order info for : "{amount}" CEC')
|
||||||
order_info = {"amount": amount, "price": 0.1, "name": "CEC", "url": "https://www.example.com"}
|
order_info = {"amount": amount, "price": 0.1,
|
||||||
|
"name": "CEC", "url": "https://www.example.com"}
|
||||||
search_result = f"orderinfo:{json.dumps(order_info)}"
|
search_result = f"orderinfo:{json.dumps(order_info)}"
|
||||||
return {"context": search_result}
|
return {"context": search_result}
|
||||||
|
|
||||||
@ -227,5 +234,7 @@ def config_agent(llama3, llama3_json):
|
|||||||
workflow.add_edge("generate", END)
|
workflow.add_edge("generate", END)
|
||||||
|
|
||||||
# Compile the workflow
|
# Compile the workflow
|
||||||
|
memory = MemorySaver()
|
||||||
|
# local_agent = workflow.compile(checkpointer=memory)
|
||||||
local_agent = workflow.compile()
|
local_agent = workflow.compile()
|
||||||
return local_agent
|
return local_agent
|
2
app.py
2
app.py
@ -7,12 +7,14 @@ llama3, llama3_json = configure_llm(local_llm, 0)
|
|||||||
|
|
||||||
local_agent = config_agent(llama3, llama3_json)
|
local_agent = config_agent(llama3, llama3_json)
|
||||||
|
|
||||||
|
|
||||||
def run_agent(query):
|
def run_agent(query):
|
||||||
output = local_agent.invoke({"question": query})
|
output = local_agent.invoke({"question": query})
|
||||||
print("=======")
|
print("=======")
|
||||||
print(output["generation"])
|
print(output["generation"])
|
||||||
# display(Markdown(output["generation"]))
|
# display(Markdown(output["generation"]))
|
||||||
|
|
||||||
|
|
||||||
# run_agent("I want to buy 100 CEC")
|
# run_agent("I want to buy 100 CEC")
|
||||||
run_agent("I'm Cz, I want to buy 100 CEC, how about the price?")
|
run_agent("I'm Cz, I want to buy 100 CEC, how about the price?")
|
||||||
# run_agent("I'm Cz, How about CEC?")
|
# run_agent("I'm Cz, How about CEC?")
|
@ -7,20 +7,27 @@ st.sidebar.header("Configure LLM")
|
|||||||
st.title("CEC Seller Assistant")
|
st.title("CEC Seller Assistant")
|
||||||
# Model Selection
|
# Model Selection
|
||||||
model_options = ["llama3.2"]
|
model_options = ["llama3.2"]
|
||||||
selected_model = st.sidebar.selectbox("Choose the LLM Model", options=model_options, index=0)
|
selected_model = st.sidebar.selectbox(
|
||||||
|
"Choose the LLM Model", options=model_options, index=0)
|
||||||
|
|
||||||
# Temperature Setting
|
# Temperature Setting
|
||||||
temperature = st.sidebar.slider("Set the Temperature", min_value=0.0, max_value=1.0, value=0.5, step=0.1)
|
temperature = st.sidebar.slider(
|
||||||
|
"Set the Temperature", min_value=0.0, max_value=1.0, value=0.5, step=0.1)
|
||||||
|
|
||||||
llama3, llama3_json = configure_llm(selected_model, temperature)
|
llama3, llama3_json = configure_llm(selected_model, temperature)
|
||||||
|
|
||||||
local_agent = config_agent(llama3, llama3_json)
|
local_agent = config_agent(llama3, llama3_json)
|
||||||
|
|
||||||
|
|
||||||
def run_agent(query):
|
def run_agent(query):
|
||||||
|
# config = {"configurable": {"thread_id": "1"}}
|
||||||
|
# output = local_agent.invoke({"question": query}, config)
|
||||||
|
# print(list(local_agent.get_state_history(config)))
|
||||||
output = local_agent.invoke({"question": query})
|
output = local_agent.invoke({"question": query})
|
||||||
print("=======")
|
print("=======")
|
||||||
|
|
||||||
return output["generation"]
|
return output["generation"]
|
||||||
|
|
||||||
|
|
||||||
user_query = st.text_input("Enter your research question:", "")
|
user_query = st.text_input("Enter your research question:", "")
|
||||||
|
|
||||||
if st.button("Run Query"):
|
if st.button("Run Query"):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user