Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] [interpolators] RIFE #4

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
port implmentation of rife v4.22; add modeling/hub utils
a-r-r-o-w committed Sep 11, 2024
commit 0049c7b997e0d20436a7120cac662dc3089e3323
654 changes: 654 additions & 0 deletions src/image_gen_aux/interpolators/model_mixin.py

Large diffs are not rendered by default.

Empty file.
208 changes: 208 additions & 0 deletions src/image_gen_aux/interpolators/rife/modeling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import List, Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..model_mixin import ModelMixin, register_to_config


class Encoder(nn.Module):
def __init__(
self,
in_channels: int = 3,
hidden_channels: int = 32,
out_channels: int = 8,
nonlinearity: Type[nn.Module] = nn.LeakyReLU,
) -> None:
super().__init__()

blocks = []
blocks.append(nn.Conv2d(in_channels, hidden_channels, 3, 2, 1))
blocks.append(nonlinearity())
blocks.append(nn.Conv2d(hidden_channels, hidden_channels, 3, 1, 1))
blocks.append(nonlinearity())
blocks.append(nn.Conv2d(hidden_channels, hidden_channels, 3, 1, 1))
blocks.append(nonlinearity())
blocks.append(nn.ConvTranspose2d(hidden_channels, out_channels, 4, 2, 1))
blocks.append(nonlinearity())
self.blocks = nn.ModuleList(blocks)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for block in self.blocks:
hidden_states = block(hidden_states)
return hidden_states


class ResnetBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: int = 1,
nonlinearity: Type[nn.Module] = nn.LeakyReLU,
) -> None:
super().__init__()

self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
self.beta = nn.Parameter(torch.ones((1, out_channels, 1, 1)), requires_grad=True)
self.nonlinearity = nonlinearity()

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states * self.beta
hidden_states = hidden_states + residual
hidden_states = self.nonlinearity(hidden_states)

return hidden_states


class IntermediateFlowBlock(nn.Module):
def __init__(
self,
in_channels: int,
hidden_channels: int = 64,
out_channels: Optional[int] = None,
num_hidden_blocks: int = 8,
nonlinearity: Type[nn.Module] = nn.LeakyReLU,
) -> None:
super().__init__()

out_channels = out_channels or in_channels

self.conv_input = nn.Sequential(
nn.Conv2d(in_channels, hidden_channels // 2, 3, 2, 1),
nonlinearity(),
nn.Conv2d(hidden_channels // 2, hidden_channels, 3, 2, 1),
nonlinearity(),
)

blocks = []
for _ in range(num_hidden_blocks):
blocks.append(ResnetBlock(hidden_channels, hidden_channels, 3, 1, 1, nonlinearity))
self.blocks = nn.ModuleList(blocks)

self.conv_output = nn.Sequential(
nn.ConvTranspose2d(hidden_channels, out_channels, 4, 2, 1), nn.PixelShuffle(2)
)

def forward(
self, hidden_states: torch.Tensor, flow: Optional[torch.Tensor] = None, scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = F.interpolate(hidden_states, scale_factor=1 / scale, mode="bilinear", align_corners=False)

if flow is not None:
flow = F.interpolate(hidden_states, scale_factor=1 / scale, mode="bilinear", align_corners=False) / scale
hidden_states = torch.cat([hidden_states, flow], dim=1)

hidden_states = self.conv_input(hidden_states)

for block in self.blocks:
hidden_states = block(hidden_states)

hidden_states = self.conv_output(hidden_states)

hidden_states = F.interpolate(hidden_states, scale_factor=scale, mode="bilinear", align_corners=False)
flow, mask, features = hidden_states.split_with_sizes([4, 1, hidden_states.size(1) - 5], dim=1)
flow = flow * scale

return flow, mask, features


class IntermediateFlowNet(ModelMixin):
@register_to_config
def __init__(
self,
in_channels: List[int] = [7 + 16, 8 + 4 + 16 + 8, 8 + 4 + 16 + 8, 8 + 4 + 16 + 8],
hidden_channels: List[int] = [256, 192, 96, 48],
out_channels: List[int] = [52, 52, 52, 52],
encoder_in_channels: int = 3,
encoder_out_channels: int = 32,
num_hidden_blocks: int = 8,
nonlinearity: Type[nn.Module] = functools.partial(nn.LeakyReLU, negative_slope=0.2),
) -> None:
super().__init__()

self.encoder = Encoder(encoder_in_channels, encoder_out_channels, 8, nonlinearity)

blocks = []
for in_channel, hidden_channel, out_channel in zip(in_channels, hidden_channels, out_channels):
blocks.append(
IntermediateFlowBlock(in_channel, hidden_channel, out_channel, num_hidden_blocks, nonlinearity)
)
self.blocks = nn.ModuleList(blocks)

def forward(
self, image_1: torch.Tensor, image_2: torch.Tensor, timestep: float = 0.5, scale: List[float] = [8, 4, 2, 1]
) -> torch.Tensor:
batch_size, channels, height, width = image_1.shape
assert image_1.shape == image_2.shape

timestep = image_1.new_full((1, 1, height, width), timestep)
warped_image_1 = image_1
warped_image_2 = image_2

hidden_states_1 = self.encoder(image_1)
hidden_states_2 = self.encoder(image_2)

flow, mask, features = None, None, None

for i, block in enumerate(self.blocks):
if flow is None:
hidden_states = torch.cat([image_1, image_2, hidden_states_1, hidden_states_2, timestep], dim=1)
flow, mask, features = block(hidden_states, None, scale[i])
else:
flow_1, flow_2 = flow.split(2, dim=1)
warped_flow_1 = _warp(hidden_states_1, flow_1)
warped_flow_2 = _warp(hidden_states_2, flow_2)
hidden_states = torch.cat(
[warped_image_1, warped_image_2, warped_flow_1, warped_flow_2, timestep, mask, features], dim=1
)
flow_, mask, features = block(hidden_states, flow, scale[i])
flow = flow + flow_

flow_1, flow_2 = flow.split(2, dim=1)

warped_image_1 = _warp(image_1, flow_1)
warped_image_2 = _warp(image_2, flow_2)

mask = mask.sigmoid()
output = warped_image_1 * mask + warped_image_2 * (1 - mask)
return output


def _warp(hidden_states: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
batch_size, channels, height, width = flow.shape

flow_1 = flow[:, 0:1, :, :] / ((width - 1) / 2)
flow_2 = flow[:, 1:2, :, :] / ((height - 1) / 2)
flow = torch.cat([flow_1, flow_2], dim=1)

horizontal = torch.linspace(-1, 1, width, device=flow.device, dtype=flow.dtype).view(1, 1, 1, width)
horizontal = horizontal.expand(batch_size, -1, height, -1)
vertical = torch.linspace(-1, 1, height, device=flow.device, dtype=flow.dtype).view(1, 1, height, 1)
vertical = vertical.expand(batch_size, -1, -1, width)
grid = torch.cat([horizontal, vertical], dim=1)
grid = grid + flow
grid = grid.permute(0, 2, 3, 1)

warped_flow = F.grid_sample(hidden_states, grid, mode="bilinear", padding_mode="border", align_corners=False)
return warped_flow
Empty file.
109 changes: 109 additions & 0 deletions src/image_gen_aux/interpolators/utils/config_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import os
from typing import Optional, Union

import numpy as np

from ...utils.logging import get_logger
from . import __version__


logger = get_logger(__name__)


class ModelConfig:
def __init__(self):
super().__setattr__("_internal_dict", {})

def __getattr__(self, name):
if name not in self._internal_dict:
raise AttributeError(f"Config does not have attribute: {name}")
return self._internal_dict.get(name)

def __setattr__(self, name, value):
self._internal_dict[name] = value

def __getitem__(self, name):
if name not in self._internal_dict:
raise KeyError(f"Config does not have key: {name}")
return self._internal_dict.get(name)

def __setitem__(self, name, value):
self._internal_dict[name] = value

def __delitem__(self, name):
del self._internal_dict[name]

def __iter__(self):
return iter(self._internal_dict)

def __len__(self):
return len(self._internal_dict)

def __repr__(self):
return repr(self._internal_dict)

def __str__(self):
return str(self._internal_dict)

def __contains__(self, name):
return name in self._internal_dict

@staticmethod
def _to_json_saveable(value):
if isinstance(value, np.ndarray):
value = value.tolist()
return value

def to_json_string(self, class_name: Optional[str] = None) -> str:
r"""
Serialize the configuration to a JSON formatted string.

Returns:
str:
String containing all the attributes that make up the configuration instance in JSON format.
"""
config_dict: dict = copy.deepcopy(self._internal_dict)
config_dict["_class_name"] = class_name or self.__class__.__name__
config_dict["_backbones_version"] = __version__

config_dict = {k: self._to_json_saveable(v) for k, v in config_dict.items()}
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

def to_json_file(self, path: Union[str, os.PathLike], class_name: Optional[str] = None) -> None:
r"""
Serialize the configuration to a JSON formatted file.

Args:
path (str or os.PathLike):
The path to the file where the configuration will be saved.
"""
with open(path, "w", encoding="utf-8") as f:
f.write(self.to_json_string(class_name))


def is_primitive_type(value) -> bool:
return isinstance(value, (int, float, str, bool))


def is_primitive_type_dict(value) -> bool:
return isinstance(value, dict) and all(is_primitive_type(v) for v in value.values())


def is_primitive_type_list(value) -> bool:
return isinstance(value, (list, tuple)) and all(is_primitive_type(v) for v in value)
24 changes: 24 additions & 0 deletions src/image_gen_aux/interpolators/utils/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os


CONFIG_NAME = "config.json"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
SAFETENSORS_FILE_EXTENSION = "safetensors"
SAFETENSORS_WEIGHTS_NAME = "interpolators_pytorch_model.safetensors"
SAFETENSORS_WEIGHTS_INDEX_NAME = "interpolators_pytorch_model.safetensors.index.json"
WEIGHTS_NAME = "interpolators_pytorch_model.bin"
WEIGHTS_INDEX_NAME = "interpolators_pytorch_model.bin.index.json"
414 changes: 414 additions & 0 deletions src/image_gen_aux/interpolators/utils/hub_utils.py

Large diffs are not rendered by default.

126 changes: 126 additions & 0 deletions src/image_gen_aux/interpolators/utils/model_loading_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, List, Union

import safetensors
import torch
from huggingface_hub.utils import EntryNotFoundError

from .constants import SAFETENSORS_FILE_EXTENSION, SAFETENSORS_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
from .hub_utils import _add_variant, _get_model_file


def load_state_dict(checkpoint_file: Union[str, os.PathLike]) -> Dict[str, Any]:
r"""Reads a checkpoint file, returning properly formatted errors if they arise."""
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
else:
return torch.load(
checkpoint_file,
map_location="cpu",
)
except Exception as e:
try:
with open(checkpoint_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please install "
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
"you cloned."
)
else:
raise ValueError(
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
"model. Make sure you have saved the model properly."
) from e
except (UnicodeDecodeError, ValueError):
raise OSError(
f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
)


def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
# Convert old format to new format if needed from a PyTorch state_dict
# copy state_dict so _load_from_state_dict can modify it
state_dict = state_dict.copy()
error_msgs = []

# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: torch.nn.Module, prefix: str = ""):
args = (state_dict, prefix, {}, True, [], [], error_msgs)
module._load_from_state_dict(*args)

for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")

load(model_to_load)

return error_msgs


def _fetch_index_file(
is_local,
pretrained_model_name_or_path,
subfolder,
use_safetensors,
cache_dir,
variant,
force_download,
resume_download,
proxies,
local_files_only,
token,
revision,
user_agent,
commit_hash,
):
if is_local:
index_file = Path(
pretrained_model_name_or_path,
subfolder or "",
_add_variant(SAFETENSORS_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
)
else:
index_file_in_repo = Path(
subfolder or "",
_add_variant(SAFETENSORS_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
).as_posix()
try:
index_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=index_file_in_repo,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
)
index_file = Path(index_file)
except (EntryNotFoundError, EnvironmentError):
index_file = None

return index_file