Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for LTX-Video model in ImageToVideo Pipeline #394

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
43 changes: 37 additions & 6 deletions runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
import inspect
import os
import time
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Type

import PIL
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers import DiffusionPipeline, LTXImageToVideoPipeline, StableVideoDiffusionPipeline
from huggingface_hub import file_download
from PIL import ImageFile

Expand All @@ -22,6 +23,8 @@

class ImageToVideoPipeline(Pipeline):
def __init__(self, model_id: str):
self.pipeline_name = ""

self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

Expand All @@ -41,9 +44,17 @@ def __init__(self, model_id: str):
kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs)
logger.info("Loading DiffusionPipeline for model_id: %s", model_id)
self.ldm = DiffusionPipeline.from_pretrained(model_id, **kwargs)

if any(substring in model_id.lower() for substring in ("ltx-video", "ltx")):
logger.info("Adjusting to LTXImageToVideoPipeline for model_id: %s", model_id)
self.ldm = LTXImageToVideoPipeline.from_pipe(self.ldm)

self.ldm.to(get_torch_device())

self.pipeline_name = type(self.ldm).__name__

sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
ad-astra-video marked this conversation as resolved.
Show resolved Hide resolved
deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true"
if sfast_enabled and deepcache_enabled:
Expand All @@ -52,7 +63,9 @@ def __init__(self, model_id: str):
"as it may lead to suboptimal performance. Please disable one of them."
)

if sfast_enabled:
if sfast_enabled and self.pipeline_name == "LTXImageToVideoPipeline":
logger.warning("StableFast optimization is not compatible with LTXImageToVideoPipeline so,skipping.")
elif sfast_enabled:
logger.info(
"ImageToVideoPipeline will be dynamically compiled with stable-fast "
"for %s",
Expand Down Expand Up @@ -95,9 +108,11 @@ def __init__(self, model_id: str):
)
logger.info("Total warmup time: %s seconds", total_time)

if deepcache_enabled:
if deepcache_enabled and self.pipeline_name == "LTXImageToVideoPipeline":
logger.warning("DeepCache optimization is not compatible with LTXImageToVideoPipeline so,skipping.")
elif deepcache_enabled:
logger.info(
"TextToImagePipeline will be optimized with DeepCache for %s",
"ImageToVideoPipeline will be optimized with DeepCache for %s",
model_id,
)
from app.pipelines.optim.deepcache import enable_deepcache
Expand Down Expand Up @@ -132,6 +147,13 @@ def __call__(
):
del kwargs["num_inference_steps"]

if self.pipeline_name == "LTXImageToVideoPipeline":
pipeline_class = LTXImageToVideoPipeline
elif self.pipeline_name == "StableVideoDiffusionPipeline":
pipeline_class = StableVideoDiffusionPipeline

kwargs = self._filter_valid_kwargs(pipeline_class, kwargs)

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images([image])
else:
Expand All @@ -146,5 +168,14 @@ def __call__(

return outputs.frames, has_nsfw_concept

@staticmethod
def _filter_valid_kwargs(pipeline_class: Type, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""
Filters the kwargs to just include keys that are necesssary for the pipeline_class.
"""

valid_kwargs = inspect.signature(pipeline_class.__call__).parameters.keys()
return {k: v for k, v in kwargs.items() if k in valid_kwargs}

def __str__(self) -> str:
    """Return a concise, human-readable description of this pipeline."""
    return "ImageToVideoPipeline model_id={}".format(self.model_id)
19 changes: 19 additions & 0 deletions runner/app/routes/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,19 @@ async def image_to_video(
UploadFile,
File(description="Uploaded image to generate a video from."),
],
prompt: Annotated[
str,
Form(description="Text prompt(s) to guide video generation for prompt accepting models.")
] = "",
negative_prompt: Annotated[
str,
Form(
description=(
"Text prompt(s) to guide what to exclude from video generation for prompt accepting models. "
"Ignored if guidance_scale < 1."
)
),
] = "",
model_id: Annotated[
str, Form(description="Hugging Face model ID used for video generation.")
] = "",
Expand Down Expand Up @@ -123,6 +136,9 @@ async def image_to_video(
)
),
] = 25, # NOTE: Hardcoded due to varying pipeline values.
num_frames: Annotated[
int, Form(description="The number of video frames to generate.")
] = 25, # NOTE: Default of 25 chosen because `stable-video-diffusion-img2vid-xt` uses a smaller default frame count than LTX-Video's pipeline.
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
Expand Down Expand Up @@ -159,6 +175,9 @@ async def image_to_video(
try:
batch_frames, has_nsfw_concept = pipeline(
image=Image.open(image.file).convert("RGB"),
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=num_frames,
height=height,
width=width,
fps=fps,
Expand Down
1 change: 1 addition & 0 deletions runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ function download_all_models() {

# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models
huggingface-cli download Lightricks/LTX-Video --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models

# Download image-to-text models.
huggingface-cli download Salesforce/blip-image-captioning-large --include "*.safetensors" "*.json" --cache-dir models
Expand Down
16 changes: 16 additions & 0 deletions runner/gateway.openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,17 @@ components:
format: binary
title: Image
description: Uploaded image to generate a video from.
prompt:
type: string
title: Prompt
description: Text prompt(s) to guide video generation for prompt accepting models.
default: ''
negative_prompt:
type: string
title: Negative Prompt
description: Text prompt(s) to guide what to exclude from video generation for prompt accepting models.
Ignored if guidance_scale < 1.
default: ''
model_id:
type: string
title: Model Id
Expand Down Expand Up @@ -709,6 +720,11 @@ components:
description: Number of denoising steps. More steps usually lead to higher
quality images but slower inference. Modulated by strength.
default: 25
num_frames:
type: integer
title: Num Frames
description: The number of video frames to generate.
default: 25
type: object
required:
- image
Expand Down
16 changes: 16 additions & 0 deletions runner/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,17 @@ components:
format: binary
title: Image
description: Uploaded image to generate a video from.
prompt:
type: string
title: Prompt
description: Text prompt(s) to guide video generation for prompt accepting models.
default: ''
negative_prompt:
type: string
title: Negative Prompt
description: Text prompt(s) to guide what to exclude from video generation for prompt accepting models.
Ignored if guidance_scale < 1.
default: ''
model_id:
type: string
title: Model Id
Expand Down Expand Up @@ -744,6 +755,11 @@ components:
description: Number of denoising steps. More steps usually lead to higher
quality images but slower inference. Modulated by strength.
default: 25
num_frames:
type: integer
title: Num Frames
description: The number of video frames to generate.
default: 25
type: object
required:
- image
Expand Down
2 changes: 1 addition & 1 deletion runner/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
diffusers==0.31.0
diffusers==0.32.1
accelerate==0.30.1
transformers==4.43.3
fastapi==0.111.0
Expand Down
15 changes: 15 additions & 0 deletions worker/multipart.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ func NewImageToVideoMultipartWriter(w io.Writer, req GenImageToVideoMultipartReq
return nil, fmt.Errorf("failed to copy image to multipart request imageBytes=%v copiedBytes=%v", imageSize, copied)
}

if req.Prompt != nil {
if err := mw.WriteField("prompt", *req.Prompt); err != nil {
return nil, err
}
}
if req.NegativePrompt != nil {
if err := mw.WriteField("negative_prompt", *req.NegativePrompt); err != nil {
return nil, err
}
}
if req.ModelId != nil {
if err := mw.WriteField("model_id", *req.ModelId); err != nil {
return nil, err
Expand Down Expand Up @@ -157,6 +167,11 @@ func NewImageToVideoMultipartWriter(w io.Writer, req GenImageToVideoMultipartReq
return nil, err
}
}
if req.NumFrames != nil {
if err := mw.WriteField("num_frames", strconv.Itoa(*req.NumFrames)); err != nil {
return nil, err
}
}

if err := mw.Close(); err != nil {
return nil, err
Expand Down
Loading
Loading