Commit

update prompt template (pytorch#3372)
ravi9 authored Dec 5, 2024
1 parent 3182443 commit 0985386
Showing 10 changed files with 62 additions and 30 deletions.
@@ -20,7 +20,7 @@ echo "ROOT_DIR: $ROOT_DIR"

# Build docker image for the application
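# Note: --load (added below) asks buildx to export the built image into the local
# Docker image store so it can be used with `docker run`; depending on the buildx
# driver, the result may otherwise stay only in the build cache.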
docker_build_cmd="DOCKER_BUILDKIT=1 \
docker buildx build \
docker buildx build --load \
--platform=linux/amd64 \
--file ${EXAMPLE_DIR}/Dockerfile \
--build-arg BASE_IMAGE=\"${BASE_IMAGE}\" \
62 changes: 46 additions & 16 deletions examples/usecases/llm_diffusion_serving_app/docker/client_app.py
@@ -27,6 +27,15 @@
st.session_state.gen_captions = st.session_state.get("gen_captions", [])
st.session_state.llm_prompts = st.session_state.get("llm_prompts", None)
st.session_state.llm_time = st.session_state.get("llm_time", 0)
st.session_state.num_images = st.session_state.get("num_images", 2)
st.session_state.max_new_tokens = st.session_state.get("max_new_tokens", 100)


def update_max_tokens():
# Update the max_new_tokens value in session state and the UI widget.
# Each generated prompt is a description of roughly 50 tokens, so the budget scales with the number of images.
st.session_state.max_new_tokens = 50 * st.session_state.num_images


with st.sidebar:
st.title("Image Generation with Llama, SDXL, torch.compile and OpenVINO")
@@ -76,13 +85,23 @@ def get_model_status(model_name):
)

# Client App Parameters
num_images = st.sidebar.number_input(
"Number of images to generate", min_value=1, max_value=8, value=2, step=1
# Default values for num_images and max_new_tokens are set via the session_state variables above
st.sidebar.number_input(
"Number of images to generate",
min_value=1,
max_value=8,
step=1,
key="num_images",
on_change=update_max_tokens,
)

st.subheader("LLM Model parameters")
max_new_tokens = st.sidebar.number_input(
"max_new_tokens", min_value=30, max_value=250, value=40, step=5
st.sidebar.number_input(
"max_new_tokens",
min_value=100,
max_value=1250,
step=10,
key="max_new_tokens",
)
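
For context, a minimal standalone sketch (hypothetical labels, same Streamlit APIs used above) of the keyed-widget pattern these inputs rely on: giving a widget a key keeps its value in st.session_state, and on_change lets the image-count input rescale the token budget without reading the widget's return value.

import streamlit as st

# Seed defaults once, mirroring the pattern at the top of client_app.py.
st.session_state.num_images = st.session_state.get("num_images", 2)
st.session_state.max_new_tokens = st.session_state.get("max_new_tokens", 100)

def scale_tokens():
    # Roughly 50 tokens per generated prompt, so the budget scales with the image count.
    st.session_state.max_new_tokens = 50 * st.session_state.num_images

st.number_input("Number of images", min_value=1, max_value=8, step=1,
                key="num_images", on_change=scale_tokens)
st.number_input("max_new_tokens", min_value=100, max_value=1250, step=10,
                key="max_new_tokens")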

temperature = st.sidebar.number_input(
@@ -159,11 +178,19 @@ def sd_response_postprocess(response):


def preprocess_llm_input(user_prompt, num_images=2):
template = """ Below is an instruction that describes a task. Write a response that appropriately completes the request.
Generate {} unique prompts similar to '{}' by changing the context, keeping the core theme intact.
Give the output in square brackets separated by semicolons.
template = """ Generate expanded and descriptive prompts for a image generation model based on the user input.
Each prompt should build upon the original concept, adding layers of detail and context to create a more vivid and engaging scene for image generation.
Format each prompt distinctly within square brackets.
Ensure that each prompt is a standalone description that significantly elaborates on the original input as shown in the example below:
Example: For the input 'A futuristic cityscape with flying cars,' generate:
[A futuristic cityscape with sleek, silver flying cars zipping through the sky, set against a backdrop of towering skyscrapers and neon-lit streets.]
[A futuristic cityscape at dusk, with flying cars of various colors and shapes flying in formation.]
[A futuristic cityscape at night, with flying cars illuminated by the city's vibrant nightlife.]
Aim for a tone that is rich in imagination and visual appeal, capturing the essence of the scene with depth and creativity.
Do not generate text beyond the specified output format. Do not explain your response.
### Response:
Generate {} similar detailed prompts for the user's input: {}.
Organize the output such that each prompt is within square brackets. Refer to the example above.
"""

prompt_template_with_user_input = template.format(num_images, user_prompt)
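
As a quick illustration (hypothetical values), the two {} placeholders in the template above are filled by str.format(): the first receives the requested prompt count and the second the raw user input.

# Reusing the `template` defined above; values here are made up for illustration.
example = template.format(3, "A lighthouse in a storm")
# `example` now asks the LLM for 3 bracketed prompts elaborating on "A lighthouse in a storm".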
@@ -206,7 +233,7 @@ def generate_llm_model_response(prompt_template_with_user_input, user_prompt):
{
"prompt_template": prompt_template_with_user_input,
"user_prompt": user_prompt,
"max_new_tokens": max_new_tokens,
"max_new_tokens": st.session_state.max_new_tokens,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
@@ -260,7 +287,7 @@ def generate_llm_model_response(prompt_template_with_user_input, user_prompt):
)

user_prompt = st.text_input("Enter a prompt for image generation:")
include_user_prompt = st.checkbox("Include orginal prompt", value=False)
include_user_prompt = st.checkbox("Include original prompt", value=False)

prompt_container = st.container()
status_container = st.container()
@@ -287,15 +314,18 @@ def display_prompts():
llm_start_time = time.time()

st.session_state.llm_prompts = [user_prompt]
if num_images > 1:
if st.session_state.num_images > 1:
prompt_template_with_user_input = preprocess_llm_input(
user_prompt, num_images
user_prompt, st.session_state.num_images
)
llm_prompts = generate_llm_model_response(
prompt_template_with_user_input, user_prompt
)
st.session_state.llm_prompts = postprocess_llm_response(
llm_prompts, user_prompt, num_images, include_user_prompt
llm_prompts,
user_prompt,
st.session_state.num_images,
include_user_prompt,
)

st.session_state.llm_time = time.time() - llm_start_time
@@ -306,11 +336,11 @@ def display_prompts():
prompt_container.write(
"Enter Image Generation Prompt and Click Generate Prompts !"
)
elif len(st.session_state.llm_prompts) < num_images:
elif len(st.session_state.llm_prompts) < st.session_state.num_images:
prompt_container.warning(
f"""Insufficient prompts. Regenerate prompts !
Num Images Requested: {num_images}, Prompts Generated: {len(st.session_state.llm_prompts)}
{f"Consider increasing the max_new_tokens parameter !" if num_images > 4 else ""}""",
Num Images Requested: {st.session_state.num_images}, Prompts Generated: {len(st.session_state.llm_prompts)}
{f"Consider increasing the max_new_tokens parameter !" if st.session_state.num_images > 4 else ""}""",
icon="⚠️",
)
else:
[7 of the 10 changed files could not be displayed in the diff view.]
@@ -27,6 +27,7 @@ def __init__(self):
self.user_prompt = []
self.prompt_template = ""

@timed
def initialize(self, ctx):
self.context = ctx
self.manifest = ctx.manifest
@@ -48,7 +49,7 @@ def initialize(self, ctx):
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForCausalLM.from_pretrained(model_dir)

# Get backend for model-confil.yaml. Defaults to "openvino"
# Get backend for model-config.yaml. Defaults to "openvino"
compile_options = {}
pt2_config = ctx.model_yaml_config.get("pt2", {})
compile_options = {
@@ -115,21 +116,22 @@ def inference(self, input_data):

return generated_text

@timed
def postprocess(self, generated_text):
logger.info(f"LLM Generated Output: {generated_text}")
# Initialize with user prompt
# Remove input prompt from generated_text
generated_text = generated_text.replace(self.prompt_template, "", 1)
# Clean up LLM output
generated_text = generated_text.replace("\n", " ").replace(" ", " ").strip()

logger.info(f"LLM Generated Output without input prompt: {generated_text}")
prompt_list = [self.user_prompt]
try:
logger.info("Parsing LLM Generated Output to extract prompts within []...")
response_match = re.search(r"\[(.*?)\]", generated_text)
# Extract the result if match is found
if response_match:
# Split the extracted string by semicolon and strip any leading/trailing spaces
response_list = response_match.group(1)
extracted_prompts = [item.strip() for item in response_list.split(";")]
prompt_list.extend(extracted_prompts)
else:
logger.warning("No match found in the generated output text !!!")
# Use regular expressions to find strings within square brackets
pattern = re.compile(r"\[.*?\]")
matches = pattern.findall(generated_text)
# Clean up the matches and remove square brackets
extracted_prompts = [match.strip("[]").strip() for match in matches]
prompt_list.extend(extracted_prompts)
except Exception as e:
logger.error(f"An error occurred while parsing the generated text: {e}")
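
For reference, a tiny self-contained sketch of the bracket-extraction step above, run on a made-up LLM output (the sample text is illustrative, not actual model output):

import re

generated_text = (
    "[A castle at dawn, mist rolling over the ramparts.] "
    "[The same castle at night, lit by hundreds of lanterns.]"
)

pattern = re.compile(r"\[.*?\]")           # non-greedy: one match per bracketed prompt
matches = pattern.findall(generated_text)
prompts = [m.strip("[]").strip() for m in matches]
print(prompts)
# ['A castle at dawn, mist rolling over the ramparts.',
#  'The same castle at night, lit by hundreds of lanterns.']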
