forked from facebookresearch/chameleon
-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_tokenizer.py
124 lines (98 loc) · 4.27 KB
/
image_tokenizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) Meta Platforms, Inc. and affiliates
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import PIL
import torch
import yaml
from PIL import Image
from chameleon.inference.vqgan import VQModel
class ImageTokenizer:
    """Convert between PIL images and VQGAN codebook tokens.

    Wraps a ``VQModel`` built from a YAML config plus a checkpoint. Encoding
    whitens transparency onto a white background, resizes/center-crops the
    image to a square, scales pixels to [-1, 1] and runs ``VQModel.encode``.
    Decoding looks tokens up in the model's codebook, runs
    ``VQModel.decode`` and converts the result back to an RGB PIL image.
    """

    def __init__(
        self,
        cfg_path: str,
        ckpt_path: str,
        device: str | torch.device | None = None,
    ):
        """Build the VQGAN model from a config file and checkpoint.

        Args:
            cfg_path: Path to a YAML file whose ``model.params`` mapping is
                passed as keyword arguments to ``VQModel``.
            ckpt_path: Checkpoint path, forwarded to ``VQModel`` as the
                ``ckpt_path`` keyword argument.
            device: Device to move the model to. When ``None`` the model is
                left where ``VQModel`` created its parameters, and that
                device is recorded instead.

        Raises:
            AssertionError: if ``device`` is ``None`` and the parameters live
                on more than one device, or if the parameters have more than
                one dtype. (Asserts are skipped under ``python -O``.)
        """
        with open(cfg_path, encoding="utf-8") as f:
            config = yaml.safe_load(f)

        params = config["model"]["params"]
        # Do not forward the loss config to VQModel (presumably it is only
        # relevant for training; confirm against VQModel's signature).
        params.pop("lossconfig", None)
        params["ckpt_path"] = ckpt_path

        self._vq_model = VQModel(**params)
        self._vq_model.eval()

        if device is None:
            devices = {p.device for p in self._vq_model.parameters()}
            assert len(devices) == 1, f"expected a single device, got {devices}"
            device = devices.pop()
        else:
            self._vq_model.to(device)
        self._device = device

        # Inputs are cast to the model's parameter dtype before encoding.
        dtypes = {p.dtype for p in self._vq_model.parameters()}
        assert len(dtypes) == 1, f"expected a single dtype, got {dtypes}"
        self._dtype = dtypes.pop()

    def _whiten_transparency(self, img: PIL.Image.Image) -> PIL.Image.Image:
        """Return an RGB version of ``img`` with transparency blended onto white.

        RGB images are returned unchanged. Any other mode is inspected via an
        RGBA conversion: if no pixel has alpha < 255 the plain RGB conversion
        is returned, otherwise every pixel is alpha-composited over white.
        """
        if img.mode == "RGB":
            return img

        vals_rgba = np.array(img.convert("RGBA"))
        # No transparent pixels: a simple mode conversion is enough.
        if not (vals_rgba[:, :, 3] < 255).any():
            return img.convert("RGB")

        # Composite over white: out = (1 - a) * 255 + a * color, a in [0, 1].
        alpha = vals_rgba[:, :, 3] / 255.0
        vals_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[
            :, :, np.newaxis
        ] * vals_rgba[:, :, :3]
        # astype truncates rather than rounds; kept deliberately so the pixels
        # fed to the encoder do not shift. An (H, W, 3) uint8 array is
        # inferred as an RGB image by fromarray.
        return PIL.Image.fromarray(vals_rgb.astype("uint8"))

    def _vqgan_input_from(
        self, img: PIL.Image.Image, target_image_size=512
    ) -> torch.Tensor:
        """Turn an image into a normalized, square VQGAN input batch.

        The shorter side is scaled to ``target_image_size`` (aspect ratio
        preserved, LANCZOS resampling), the central
        ``target_image_size`` x ``target_image_size`` square is cropped, and
        pixel values are mapped from [0, 255] to [-1, 1].

        Args:
            img: Input image; expected to be RGB, e.g. the output of
                ``_whiten_transparency``.
            target_image_size: Side length of the square output, in pixels.

        Returns:
            A float32 tensor of shape (1, C, target_image_size,
            target_image_size), with C == 3 for an RGB input.
        """
        # Resize so the shorter side equals target_image_size.
        s = min(img.size)
        scale = target_image_size / s
        new_size = (round(scale * img.size[0]), round(scale * img.size[1]))
        img = img.resize(new_size, PIL.Image.LANCZOS)

        # Center crop to a target_image_size square.
        x0 = (img.width - target_image_size) // 2
        y0 = (img.height - target_image_size) // 2
        img = img.crop((x0, y0, x0 + target_image_size, y0 + target_image_size))

        np_img = np.array(img) / 255.0  # [0, 255] -> [0, 1]
        np_img = np_img * 2 - 1  # [0, 1] -> [-1, 1]
        # HWC (NumPy/PIL layout) -> CHW, then add a leading batch dimension.
        tensor_img = torch.from_numpy(np_img).permute(2, 0, 1).float()
        return tensor_img.unsqueeze(0)

    @torch.no_grad()
    def img_tokens_from_pil(self, image: PIL.Image.Image) -> list[int]:
        """Encode a PIL image into VQGAN tokens.

        Runs without autograd: the model is used for inference only, so no
        graph needs to be recorded during ``encode``.

        Returns:
            The last item of the third element returned by
            ``VQModel.encode``. NOTE(review): this is presumably a tensor of
            codebook indices (``pil_from_img_toks`` accepts a
            ``torch.Tensor``), so the ``list[int]`` annotation may be
            inaccurate -- confirm against ``VQModel.encode``.
        """
        image = self._whiten_transparency(image)
        vqgan_input = self._vqgan_input_from(image).to(self._device).to(self._dtype)
        _, _, [_, _, img_toks] = self._vq_model.encode(vqgan_input)
        return img_toks

    def _pil_from_chw_tensor(self, chw_tensor: torch.Tensor) -> PIL.Image.Image:
        """Convert a CHW tensor with values in [-1, 1] into an RGB PIL image.

        Values are clamped to [-1, 1], mapped to [0, 255] and truncated to
        uint8; the result is converted to RGB if Pillow inferred another mode.
        """
        detached_chw_tensor = chw_tensor.detach().cpu()
        # NumPy has no bfloat16 dtype, so Tensor.numpy() would raise for a
        # bfloat16 model's output; upcast only in that case so other dtypes
        # produce exactly the same pixels as before.
        if detached_chw_tensor.dtype == torch.bfloat16:
            detached_chw_tensor = detached_chw_tensor.float()

        # [-1, 1] -> [0, 1]
        normalized_chw_tensor = (
            torch.clamp(detached_chw_tensor, -1.0, 1.0) + 1.0
        ) / 2.0
        # CHW -> HWC, the layout Image.fromarray expects.
        hwc_array = normalized_chw_tensor.permute(1, 2, 0).numpy()
        image_array_uint8 = (hwc_array * 255).astype(np.uint8)

        pil_image = Image.fromarray(image_array_uint8)
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")
        return pil_image

    @torch.no_grad()
    def pil_from_img_toks(self, img_tensor: torch.Tensor) -> PIL.Image.Image:
        """Decode VQGAN tokens back into an RGB PIL image.

        Runs without autograd: the decoded pixels are only converted to an
        image, so no graph needs to be recorded during ``decode``.

        Args:
            img_tensor: Token indices to look up in the model's codebook.

        NOTE(review): the latent grid is hard-coded to (1, 32, 32, emb_dim),
        i.e. a single image of 32 x 32 tokens -- presumably matching the
        512px encoder input; confirm against the model config before using
        other image sizes.
        """
        emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
        codebook_entry = self._vq_model.quantize.get_codebook_entry(
            img_tensor, (1, 32, 32, emb_dim)
        )
        pixels = self._vq_model.decode(codebook_entry)
        # Decode returns a batch; convert its single image.
        return self._pil_from_chw_tensor(pixels[0])