predict.py
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import os
import sys
import tempfile

# Make the vendored Hotshot-XL package importable before the hotshot_xl imports below.
sys.path.extend(["/Hotshot-XL"])

import torch
import torchvision.transforms as transforms
from cog import BasePredictor, Input, Path
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from einops import rearrange
from hotshot_xl.pipelines.hotshot_xl_pipeline import HotshotXLPipeline
from hotshot_xl.utils import save_as_gif

# Local directory holding the pre-downloaded Hotshot-XL weights.
HOTSHOTXL_CACHE = "hotshot-xl"

# Map of user-selectable scheduler names to their diffusers classes.
SCHEDULERS = {
    "EulerAncestralDiscreteScheduler": EulerAncestralDiscreteScheduler,
    "EulerDiscreteScheduler": EulerDiscreteScheduler,
}
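
# Other diffusers schedulers could be registered the same way; a hypothetical,
# untested sketch (not part of the original selection):
#   from diffusers import DDIMScheduler
#   SCHEDULERS["DDIMScheduler"] = DDIMScheduler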


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        self.pipe = HotshotXLPipeline.from_pretrained(
            HOTSHOTXL_CACHE,
            torch_dtype=torch.float16,
            use_safetensors=True,
        ).to("cuda")
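        # The pipeline stays resident on the GPU for the life of the predictor,
        # so repeated predictions skip the weight-loading cost.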

    def to_pil_images(self, video_frames: torch.Tensor, output_type="pil"):
        """Flatten a batch of decoded video tensors into a flat list of frames."""
        to_pil = transforms.ToPILImage()
        # Move the frame axis ahead of channels so each video[j] is one CHW frame.
        video_frames = rearrange(video_frames, "b c f w h -> b f c w h")
        bsz = video_frames.shape[0]
        images = []
        for i in range(bsz):
            video = video_frames[i]
            for j in range(video.shape[0]):
                if output_type == "pil":
                    images.append(to_pil(video[j]))
                else:
                    images.append(video[j])
        return images
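
    # Shape sketch (assumed layout): with one prompt and video_length=8, the
    # pipeline's `videos` tensor is (1, 3, 8, H, W); after the rearrange above,
    # iteration yields 8 CHW frames that ToPILImage can convert.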
    def predict(
        self,
        prompt: str = Input(
            description="Input prompt",
            default="a camel smoking a cigarette, hd, high quality",
        ),
        negative_prompt: str = Input(description="Negative prompt", default="blurry"),
        scheduler: str = Input(
            default="EulerAncestralDiscreteScheduler",
            choices=[
                "EulerAncestralDiscreteScheduler",
                "EulerDiscreteScheduler",
            ],
            description="Select a scheduler",
        ),
        steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=30
        ),
        mp4: bool = Input(
            description="Save as mp4, False for GIF", default=False
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run a single prediction on the model"""
        # Default text-to-GIF parameters
        width = 672
        height = 384
        target_width = 512
        target_height = 512
        og_width = 1920
        og_height = 1080
        video_length = 8       # number of frames
        video_duration = 1000  # total clip duration in milliseconds

        pipe = self.pipe
        # `scheduler` is constrained by `choices` above, so the lookup cannot fail.
        pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)
        pipe.enable_xformers_memory_efficient_attention()

        if seed is None:
            # Two random bytes give a seed in [0, 65535].
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        generator = torch.Generator().manual_seed(seed)
        images = pipe(
            prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            original_size=(og_width, og_height),
            target_size=(target_width, target_height),
            num_inference_steps=steps,
            video_length=video_length,
            generator=generator,
            output_type="tensor",
        ).videos
        images = self.to_pil_images(images, output_type="pil")

        out_path = "output.gif"
        save_as_gif(images, out_path, duration=video_duration // video_length)
        if mp4:
            # Re-encode the GIF to an H.264 mp4 with a web-friendly moov atom.
            out_path = Path(tempfile.mkdtemp()) / "out.mp4"
            os.system(
                f"ffmpeg -i output.gif -movflags faststart -pix_fmt yuv420p -qp 17 {out_path}"
            )
        return Path(out_path)
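
# Example invocation (assuming the Replicate `cog` CLI and a matching cog.yaml):
#   cog predict -i prompt="a camel smoking a cigarette, hd, high quality" -i mp4=true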