diff --git a/gradio/app.py b/gradio/app.py
index 905c587e..181c059d 100644
--- a/gradio/app.py
+++ b/gradio/app.py
@@ -19,9 +19,12 @@
 import torch

 import gradio as gr
+from tempfile import NamedTemporaryFile
+import datetime

-MODEL_TYPES = ["v1.1"]
+
+MODEL_TYPES = ["v1.1-stage2", "v1.1-stage3"]
 CONFIG_MAP = {
     "v1.1-stage2": "configs/opensora-v1-1/inference/sample-ref.py",
     "v1.1-stage3": "configs/opensora-v1-1/inference/sample-ref.py",
@@ -31,12 +34,41 @@
     "v1.1-stage3": "hpcai-tech/OpenSora-STDiT-v2-stage3",
 }
 RESOLUTION_MAP = {
-    "144p": (144, 256),
-    "240p": (240, 426),
-    "360p": (360, 480),
-    "480p": (480, 858),
-    "720p": (720, 1280),
-    "1080p": (1080, 1920)
+    "144p": {
+        "16:9": (256, 144),
+        "9:16": (144, 256),
+        "4:3": (221, 165),
+        "3:4": (165, 221),
+        "1:1": (192, 192),
+    },
+    "240p": {
+        "16:9": (426, 240),
+        "9:16": (240, 426),
+        "4:3": (370, 278),
+        "3:4": (278, 370),
+        "1:1": (320, 320),
+    },
+    "360p": {
+        "16:9": (640, 360),
+        "9:16": (360, 640),
+        "4:3": (554, 416),
+        "3:4": (416, 554),
+        "1:1": (480, 480),
+    },
+    "480p": {
+        "16:9": (854, 480),
+        "9:16": (480, 854),
+        "4:3": (740, 555),
+        "3:4": (555, 740),
+        "1:1": (640, 640),
+    },
+    "720p": {
+        "16:9": (1280, 720),
+        "9:16": (720, 1280),
+        "4:3": (1108, 832),
+        "3:4": (832, 1108),
+        "1:1": (960, 960),
+    },
 }
@@ -302,37 +334,53 @@ def parse_args():

 vae, text_encoder, stdit, scheduler = build_models(args.model_type, config, enable_optimization=args.enable_optimization)

-@spaces.GPU(duration=200)
-def run_inference(mode, prompt_text, resolution, length, reference_image):
+def run_inference(mode, prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale):
+    torch.manual_seed(seed)
     with torch.inference_mode():
         # ======================
         # 1. Preparation
         # ======================
         # parse the inputs
-        resolution = RESOLUTION_MAP[resolution]
-
+        resolution = RESOLUTION_MAP[resolution][aspect_ratio]
+
+        # gather args from config
+        num_frames = config.num_frames
+        frame_interval = config.frame_interval
+        fps = config.fps
+        condition_frame_length = config.condition_frame_length
+
         # compute number of loops
-        num_seconds = int(length.rstrip('s'))
-        total_number_of_frames = num_seconds * config.fps / config.frame_interval
-        num_loop = math.ceil(total_number_of_frames / config.num_frames)
+        if mode == "Text2Image":
+            num_frames = 1
+            num_loop = 1
+        else:
+            num_seconds = int(length.rstrip('s'))
+            if num_seconds <= 16:
+                num_frames = num_seconds * fps // frame_interval
+                num_loop = 1
+            else:
+                num_frames = 16
+                total_number_of_frames = num_seconds * fps / frame_interval
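+                # each loop after the first re-uses the last condition_frame_length
+                # frames of the previous clip, so it contributes only
+                # (num_frames - condition_frame_length) new frames. Worked example
+                # with illustrative values (the real ones come from the loaded
+                # config) fps=24, frame_interval=3, condition_frame_length=4 and a
+                # hypothetical 32s request:
+                # total_number_of_frames = 32 * 24 / 3 = 256, so
+                # num_loop = ceil((256 - 4) / (16 - 4)) = 21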
+                num_loop = math.ceil((total_number_of_frames - condition_frame_length) / (num_frames - condition_frame_length))

         # prepare model args
-        model_args = dict()
-        height = torch.tensor([resolution[0]], device=device, dtype=dtype)
-        width = torch.tensor([resolution[1]], device=device, dtype=dtype)
-        num_frames = torch.tensor([config.num_frames], device=device, dtype=dtype)
-        ar = torch.tensor([resolution[0] / resolution[1]], device=device, dtype=dtype)
-        if config.num_frames == 1:
-            config.fps = IMG_FPS
-        fps = torch.tensor([config.fps], device=device, dtype=dtype)
-        model_args["height"] = height
-        model_args["width"] = width
-        model_args["num_frames"] = num_frames
-        model_args["ar"] = ar
-        model_args["fps"] = fps
+        if num_frames == 1:
+            fps = IMG_FPS
+
+        model_args = dict()
+        height_tensor = torch.tensor([resolution[0]], device=device, dtype=dtype)
+        width_tensor = torch.tensor([resolution[1]], device=device, dtype=dtype)
+        num_frames_tensor = torch.tensor([num_frames], device=device, dtype=dtype)
+        ar_tensor = torch.tensor([resolution[0] / resolution[1]], device=device, dtype=dtype)
+        fps_tensor = torch.tensor([fps], device=device, dtype=dtype)
+        model_args["height"] = height_tensor
+        model_args["width"] = width_tensor
+        model_args["num_frames"] = num_frames_tensor
+        model_args["ar"] = ar_tensor
+        model_args["fps"] = fps_tensor

         # compute latent size
-        input_size = (config.num_frames, *resolution)
+        input_size = (num_frames, *resolution)
         latent_size = vae.get_latent_size(input_size)

         # process prompt
@@ -342,24 +390,33 @@ def run_inference(mode, prompt_text, resolution, length, reference_image):
         video_clips = []

         # prepare mask strategy
-        if mode == "Text2Video":
+        if mode == "Text2Image":
             mask_strategy = [None]
-        elif mode == "Image2Video":
-            mask_strategy = ['0']
+        elif mode == "Text2Video":
+            if reference_image is not None:
+                mask_strategy = ['0']
+            else:
+                mask_strategy = [None]
         else:
             raise ValueError(f"Invalid mode: {mode}")

         # =========================
         # 2. Load reference images
         # =========================
-        if mode == "Text2Video":
+        if mode == "Text2Image":
             refs_x = collect_references_batch([None], vae, resolution)
-        elif mode == "Image2Video":
-            # save image to disk
-            from PIL import Image
-            im = Image.fromarray(reference_image)
-            im.save("test.jpg")
-            refs_x = collect_references_batch(["test.jpg"], vae, resolution)
+        elif mode == "Text2Video":
+            if reference_image is not None:
+                # save the upload to a temporary file so collect_references_batch
+                # can read it from a path; encoding happens inside the with-block,
+                # before the temporary file is deleted on close
+                from PIL import Image
+                im = Image.fromarray(reference_image)
+                with NamedTemporaryFile(suffix=".jpg") as temp_file:
+                    im.save(temp_file.name)
+                    refs_x = collect_references_batch([temp_file.name], vae, resolution)
+            else:
+                refs_x = collect_references_batch([None], vae, resolution)
         else:
             raise ValueError(f"Invalid mode: {mode}")
@@ -386,11 +443,20 @@ def run_inference(mode, prompt_text, resolution, length, reference_image):
                     mask_strategy[j] += ";"
                 mask_strategy[
                     j
-                ] += f"{loop_i},{len(refs)-1},-{config.condition_frame_length},0,{config.condition_frame_length}"
+                ] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length}"
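+                # per Open-Sora's mask-strategy convention, the fields of each
+                # semicolon-separated entry are: loop index, reference index,
+                # reference start frame, target start frame, length — i.e. condition
+                # on the last condition_frame_length frames of the newest reference
+                # (negative start = counted from its end), placed at the start of
+                # the next loop's clip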
             masks = apply_mask_strategy(z, refs_x, mask_strategy, loop_i)

             # 4.6. diffusion sampling
+            # hack to update num_sampling_steps and cfg_scale
+            scheduler_kwargs = config.scheduler.copy()
+            scheduler_kwargs.pop('type')
+            scheduler_kwargs['num_sampling_steps'] = sampling_steps
+            scheduler_kwargs['cfg_scale'] = cfg_scale
+            scheduler.__init__(**scheduler_kwargs)
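+            # NOTE: re-running __init__ on the live scheduler swaps in the
+            # user-selected step count and guidance scale without rebuilding the
+            # object from its config entry; this assumes the scheduler keeps no
+            # other state between calls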
             samples = scheduler.sample(
                 stdit,
                 text_encoder,
@@ -410,10 +476,20 @@ def run_inference(mode, prompt_text, resolution, length, reference_image):
             for i in range(1, num_loop)
         ]
         video = torch.cat(video_clips_list, dim=1)
-        save_path = f"{args.output}/sample"
-        saved_path = save_sample(video, fps=config.fps // config.frame_interval, save_path=save_path, force_video=True)
+        current_datetime = datetime.datetime.now()
+        timestamp = current_datetime.timestamp()
+        save_path = os.path.join(args.output, f"output_{timestamp}")
+        saved_path = save_sample(video, save_path=save_path, fps=config.fps // config.frame_interval)
         return saved_path

+
+@spaces.GPU(duration=200)
+def run_image_inference(prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale):
+    return run_inference("Text2Image", prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale)
+
+
+@spaces.GPU(duration=200)
+def run_video_inference(prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale):
+    return run_inference("Text2Video", prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale)
+

 def main():
     # create demo
@@ -442,31 +518,54 @@ def main():
         with gr.Row():
             with gr.Column():
-                mode = gr.Radio(
-                    choices=["Text2Video", "Image2Video"],
-                    value="Text2Video",
-                    label="Usage",
-                    info="Choose your usage scenario",
-                )
                 prompt_text = gr.Textbox(
                     label="Prompt",
                     placeholder="Describe your video here",
                     lines=4,
                 )
                 resolution = gr.Radio(
-                    choices=["144p", "240p", "360p", "480p", "720p", "1080p"],
-                    value="144p",
+                    choices=["144p", "240p", "360p", "480p", "720p"],
+                    value="240p",
                     label="Resolution",
                 )
+                aspect_ratio = gr.Radio(
+                    choices=["9:16", "16:9", "3:4", "4:3", "1:1"],
+                    value="9:16",
+                    label="Aspect Ratio (H:W)",
+                )
                 length = gr.Radio(
-                    choices=["2s", "4s", "8s"],
+                    choices=["2s", "4s", "8s", "16s"],
                     value="2s",
-                    label="Video Length",
-                    info="8s may fail as Hugging Face ZeroGPU has the limitation of max 200 seconds inference time."
+                    label="Video Length (only effective for video generation)",
+                    info="8s and 16s may fail, as Hugging Face ZeroGPU limits inference time to 200 seconds."
                 )
+                with gr.Row():
+                    seed = gr.Slider(
+                        value=1024,
+                        minimum=1,
+                        maximum=2048,
+                        step=1,
+                        label="Seed"
+                    )
+                    sampling_steps = gr.Slider(
+                        value=100,
+                        minimum=1,
+                        maximum=200,
+                        step=1,
+                        label="Sampling steps"
+                    )
+                    cfg_scale = gr.Slider(
+                        value=7.0,
+                        minimum=0.0,
+                        maximum=10.0,
+                        step=0.1,
+                        label="CFG Scale"
+                    )
+
                 reference_image = gr.Image(
-                    label="Reference Image (only used for Image2Video)",
+                    label="Reference Image (Optional)",
                 )

             with gr.Column():
@@ -476,12 +575,18 @@ def main():
                 )

         with gr.Row():
-            submit_button = gr.Button("Generate video")
+            image_gen_button = gr.Button("Generate image")
+            video_gen_button = gr.Button("Generate video")

-        submit_button.click(
-            fn=run_inference,
-            inputs=[mode, prompt_text, resolution, length, reference_image],
+        # the generated image lands back in the reference slot, so it can
+        # directly seed a follow-up video generation
+        image_gen_button.click(
+            fn=run_image_inference,
+            inputs=[prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale],
+            outputs=reference_image
+        )
+        video_gen_button.click(
+            fn=run_video_inference,
+            inputs=[prompt_text, resolution, aspect_ratio, length, reference_image, seed, sampling_steps, cfg_scale],
             outputs=output_video
         )