Commit 74b6453

update gradio (#310)
* update gradio

* update gradio
FrankLeeeee authored Apr 25, 2024
1 parent 68b8f60 commit 74b6453
Showing 1 changed file with 160 additions and 55 deletions: gradio/app.py
@@ -19,9 +19,12 @@
import torch

import gradio as gr
from tempfile import NamedTemporaryFile
import datetime


MODEL_TYPES = ["v1.1"]

MODEL_TYPES = ["v1.1-stage2", "v1.1-stage3"]
CONFIG_MAP = {
"v1.1-stage2": "configs/opensora-v1-1/inference/sample-ref.py",
"v1.1-stage3": "configs/opensora-v1-1/inference/sample-ref.py",
@@ -31,12 +34,41 @@
"v1.1-stage3": "hpcai-tech/OpenSora-STDiT-v2-stage3",
}
RESOLUTION_MAP = {
"144p": (144, 256),
"240p": (240, 426),
"360p": (360, 480),
"480p": (480, 858),
"720p": (720, 1280),
"1080p": (1080, 1920)
"144p": {
"16:9": (256, 144),
"9:16": (144, 256),
"4:3": (221, 165),
"3:4": (165, 221),
"1:1": (192, 192),
},
"240p": {
"16:9": (426, 240),
"9:16": (240, 426),
"4:3": (370, 278),
"3:4": (278, 370),
"1:1": (320, 320),
},
"360p": {
"16:9": (640, 360),
"9:16": (360, 640),
"4:3": (554, 416),
"3:4": (416, 554),
"1:1": (480, 480),
},
"480p": {
"16:9": (854, 480),
"9:16": (480, 854),
"4:3": (740, 555),
"3:4": (555, 740),
"1:1": (640, 640),
},
"720p": {
"16:9": (1280, 720),
"9:16": (720, 1280),
"4:3": (1108, 832),
"3:4": (832, 1110),
"1:1": (960, 960),
},
}
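
RESOLUTION_MAP is now a two-level lookup: resolution label first, then H:W aspect ratio, with each leaf holding a (height, width) pair. That is why the inference code below indexes the map twice. A quick illustration of the convention:

# Each leaf is a (height, width) pair, keyed by resolution and then by H:W ratio,
# matching the "Aspect Ratio (H:W)" radio added to the UI later in this diff.
height, width = RESOLUTION_MAP["240p"]["9:16"]  # -> (240, 426), a widescreen frame
assert abs(height / width - 9 / 16) < 1e-2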


@@ -302,37 +334,53 @@ def parse_args():
vae, text_encoder, stdit, scheduler = build_models(args.model_type, config, enable_optimization=args.enable_optimization)


@spaces.GPU(duration=200)
def run_inference(mode, prompt_text, resolution, length, reference_image):
def run_inference(mode, prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale):
torch.manual_seed(seed)
with torch.inference_mode():
# ======================
# 1. Preparation
# ======================
# parse the inputs
resolution = RESOLUTION_MAP[resolution]

resolution = RESOLUTION_MAP[resolution][aspect_ratio]

# gather args from config
num_frames = config.num_frames
frame_interval = config.frame_interval
fps = config.fps
condition_frame_length = config.condition_frame_length

# compute number of loops
num_seconds = int(length.rstrip('s'))
total_number_of_frames = num_seconds * config.fps / config.frame_interval
num_loop = math.ceil(total_number_of_frames / config.num_frames)
if mode == "Text2Image":
num_frames = 1
num_loop = 1
else:
num_seconds = int(length.rstrip('s'))
if num_seconds <= 16:
num_frames = num_seconds * fps // frame_interval
num_loop = 1
else:
config.num_frames = 16
total_number_of_frames = num_seconds * fps / frame_interval
num_loop = math.ceil((total_number_of_frames - condition_frame_length) / (num_frames - condition_frame_length))
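
The looping arithmetic above handles clips longer than one model window: after the first loop, each pass re-conditions on the last condition_frame_length frames of the previous clip, so it contributes only num_frames - condition_frame_length new frames. A worked example with hypothetical config values (fps, frame_interval, and condition_frame_length all come from the config in the real code):

import math

# Hypothetical config values, for illustration only.
fps, frame_interval = 24, 3
num_frames, condition_frame_length = 16, 4
num_seconds = 32  # longer than 16s, so multiple loops are needed

total_number_of_frames = num_seconds * fps / frame_interval  # 256.0
# Every loop after the first re-uses condition_frame_length frames as
# conditioning, so each loop adds only (num_frames - condition_frame_length)
# genuinely new frames.
num_loop = math.ceil((total_number_of_frames - condition_frame_length)
                     / (num_frames - condition_frame_length))  # ceil(252 / 12) = 21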

# prepare model args
model_args = dict()
height = torch.tensor([resolution[0]], device=device, dtype=dtype)
width = torch.tensor([resolution[1]], device=device, dtype=dtype)
num_frames = torch.tensor([config.num_frames], device=device, dtype=dtype)
ar = torch.tensor([resolution[0] / resolution[1]], device=device, dtype=dtype)
if config.num_frames == 1:
config.fps = IMG_FPS
fps = torch.tensor([config.fps], device=device, dtype=dtype)
model_args["height"] = height
model_args["width"] = width
model_args["num_frames"] = num_frames
model_args["ar"] = ar
model_args["fps"] = fps
fps = IMG_FPS

model_args = dict()
height_tensor = torch.tensor([resolution[0]], device=device, dtype=dtype)
width_tensor = torch.tensor([resolution[1]], device=device, dtype=dtype)
num_frames_tensor = torch.tensor([num_frames], device=device, dtype=dtype)
ar_tensor = torch.tensor([resolution[0] / resolution[1]], device=device, dtype=dtype)
fps_tensor = torch.tensor([fps], device=device, dtype=dtype)
model_args["height"] = height_tensor
model_args["width"] = width_tensor
model_args["num_frames"] = num_frames_tensor
model_args["ar"] = ar_tensor
model_args["fps"] = fps_tensor

# compute latent size
input_size = (config.num_frames, *resolution)
input_size = (num_frames, *resolution)
latent_size = vae.get_latent_size(input_size)

# process prompt
@@ -342,24 +390,33 @@ def run_inference(mode, prompt_text, resolution, length, reference_image):
video_clips = []

# prepare mask strategy
if mode == "Text2Video":
if mode == "Text2Image":
mask_strategy = [None]
elif mode == "Image2Video":
mask_strategy = ['0']
elif mode == "Text2Video":
if reference_image is not None:
mask_strategy = ['0']
else:
mask_strategy = [None]
else:
raise ValueError(f"Invalid mode: {mode}")

# =========================
# 2. Load reference images
# =========================
if mode == "Text2Video":
if mode == "Text2Image":
refs_x = collect_references_batch([None], vae, resolution)
elif mode == "Image2Video":
# save image to disk
from PIL import Image
im = Image.fromarray(reference_image)
im.save("test.jpg")
refs_x = collect_references_batch(["test.jpg"], vae, resolution)
elif mode == "Text2Video":
if reference_image is not None:
# save image to disk
from PIL import Image
im = Image.fromarray(reference_image)
idx = os.environ['CUDA_VISIBLE_DEVICES']

with NamedTemporaryFile(suffix=".jpg") as temp_file:
im.save(temp_file.name)
refs_x = collect_references_batch([temp_file.name], vae, resolution)
else:
refs_x = collect_references_batch([None], vae, resolution)
else:
raise ValueError(f"Invalid mode: {mode}")
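
One subtlety in the tempfile change above: NamedTemporaryFile deletes the file when its context exits, so collect_references_batch must run inside the with block while the JPEG still exists on disk. A minimal standalone sketch of that pattern:

from tempfile import NamedTemporaryFile
import numpy as np
from PIL import Image

frame = np.zeros((144, 256, 3), dtype=np.uint8)  # stand-in for the uploaded image
with NamedTemporaryFile(suffix=".jpg") as tmp:
    Image.fromarray(frame).save(tmp.name)
    loaded = Image.open(tmp.name).copy()  # must read while the file still exists
# once the block exits, tmp.name has been removed from disk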

@@ -386,11 +443,20 @@ def run_inference(mode, prompt_text, resolution, length, reference_image):
mask_strategy[j] += ";"
mask_strategy[
j
] += f"{loop_i},{len(refs)-1},-{config.condition_frame_length},0,{config.condition_frame_length}"
] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length}"

masks = apply_mask_strategy(z, refs_x, mask_strategy, loop_i)

# 4.6. diffusion sampling
# hack to update num_sampling_steps and cfg_scale
scheduler_kwargs = config.scheduler.copy()
scheduler_kwargs.pop('type')
scheduler_kwargs['num_sampling_steps'] = sampling_steps
scheduler_kwargs['cfg_scale'] = cfg_scale

scheduler.__init__(
**scheduler_kwargs
)
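
Calling __init__ on the live scheduler rebinds its attributes in place, which is why the comment above calls it a hack: it swaps in the user-selected step count and CFG scale without rebuilding the object through the model registry. A toy illustration of why this works:

# __init__ is an ordinary method, so invoking it again on an existing
# instance simply resets the attributes it assigns.
class Sampler:
    def __init__(self, num_sampling_steps=100, cfg_scale=7.0):
        self.num_sampling_steps = num_sampling_steps
        self.cfg_scale = cfg_scale

sampler = Sampler()
sampler.__init__(num_sampling_steps=50, cfg_scale=8.5)  # same object, new settings
assert sampler.num_sampling_steps == 50

It stays correct only as long as __init__ assigns every piece of state the sampler relies on.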
samples = scheduler.sample(
stdit,
text_encoder,
Expand All @@ -410,10 +476,20 @@ def run_inference(mode, prompt_text, resolution, length, reference_image):
for i in range(1, num_loop)
]
video = torch.cat(video_clips_list, dim=1)
save_path = f"{args.output}/sample"
saved_path = save_sample(video, fps=config.fps // config.frame_interval, save_path=save_path, force_video=True)
current_datetime = datetime.datetime.now()
timestamp = current_datetime.timestamp()
save_path = os.path.join(args.output, f"output_{timestamp}")
saved_path = save_sample(video, save_path=save_path, fps=config.fps // config.frame_interval)
return saved_path
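
datetime.datetime.now().timestamp() is a float, so the saved file is named along the lines of output_1714028412.123456, which is unique enough for a demo. A purely cosmetic alternative (not what this commit does) would be a formatted stamp:

import datetime
import os

# Hypothetical readability tweak, not part of the commit:
stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
save_path = os.path.join("outputs", f"output_{stamp}")  # e.g. outputs/output_20240425-093012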

@spaces.GPU(duration=200)
def run_image_inference(prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale):
return run_inference("Text2Image", prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale)

@spaces.GPU(duration=200)
def run_video_inference(prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale):
return run_inference("Text2Video", prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale)


def main():
# create demo
@@ -442,31 +518,54 @@ def main():

with gr.Row():
with gr.Column():
mode = gr.Radio(
choices=["Text2Video", "Image2Video"],
value="Text2Video",
label="Usage",
info="Choose your usage scenario",
)
prompt_text = gr.Textbox(
label="Prompt",
placeholder="Describe your video here",
lines=4,
)
resolution = gr.Radio(
choices=["144p", "240p", "360p", "480p", "720p", "1080p"],
value="144p",
choices=["144p", "240p", "360p", "480p", "720p"],
value="240p",
label="Resolution",
)
aspect_ratio = gr.Radio(
choices=["9:16", "16:9", "3:4", "4:3", "1:1"],
value="9:16",
label="Aspect Ratio (H:W)",
)
length = gr.Radio(
choices=["2s", "4s", "8s"],
choices=["2s", "4s", "8s", "16s"],
value="2s",
label="Video Length",
label="Video Length (only effective for video generation)",
info="8s may fail as Hugging Face ZeroGPU has the limitation of max 200 seconds inference time."
)

with gr.Row():
seed = gr.Slider(
value=1024,
minimum=1,
maximum=2048,
step=1,
label="Seed"
)

sampling_steps = gr.Slider(
value=100,
minimum=1,
maximum=200,
step=1,
label="Sampling steps"
)
cfg_scale = gr.Slider(
value=7.0,
minimum=0.0,
maximum=10.0,
step=0.1,
label="CFG Scale"
)

reference_image = gr.Image(
label="Reference Image (only used for Image2Video)",
label="Reference Image (Optional)",
)

with gr.Column():
@@ -476,12 +575,18 @@
)

with gr.Row():
submit_button = gr.Button("Generate video")
image_gen_button = gr.Button("Generate image")
video_gen_button = gr.Button("Generate video")


submit_button.click(
fn=run_inference,
inputs=[mode, prompt_text, resolution, length, reference_image],
image_gen_button.click(
fn=run_image_inference,
inputs=[prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale],
outputs=reference_image
)
video_gen_button.click(
fn=run_video_inference,
inputs=[prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale],
outputs=output_video
)
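
Note how the two buttons are chained: the image button writes its result back into the reference_image component, so a generated still can immediately seed a Text2Video run. A self-contained sketch of that wiring with stand-in functions (component names here are illustrative, not the app's):

import numpy as np
import gradio as gr

def fake_image(prompt):
    return np.zeros((144, 256, 3), dtype=np.uint8)  # stand-in generator

def fake_video(prompt, ref):
    return None  # the real app returns the saved video path

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    ref = gr.Image(label="Reference Image (Optional)")
    out = gr.Video(label="Output Video")
    gr.Button("Generate image").click(fake_image, inputs=[prompt], outputs=ref)
    gr.Button("Generate video").click(fake_video, inputs=[prompt, ref], outputs=out)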

