diff --git a/configs/rec/dptr/dptr_parseq_finetune.yml b/configs/rec/dptr/dptr_parseq_finetune.yml new file mode 100644 index 0000000..cc6e484 --- /dev/null +++ b/configs/rec/dptr/dptr_parseq_finetune.yml @@ -0,0 +1,104 @@ +Global: + device: gpu + epoch_num: 20 + log_smooth_window: 20 + print_batch_step: 10 + output_dir: /share/ckpt/zhaoshuai/openocr/dptr_parseq/ + eval_epoch_step: [0, 1] + eval_batch_step: [0, 500] + cal_metric_during_train: True + pretrained_model: + checkpoints: + use_tensorboard: false + infer_img: + # for data or label process + character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt + max_text_length: &max_text_length 25 + use_space_char: &use_space_char False + use_amp: True + save_res_path: /share/ckpt/zhaoshuai/openocr/dptr_parseq/predicts_dptr_parseq.txt + grad_clip_val: 20 + +Optimizer: + name: AdamW + lr: 0.001485 # 2gpus 384bs/gpu + weight_decay: 0. + filter_bias_and_bn: False + +LRScheduler: + name: OneCycleLR + warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep + cycle_momentum: False + +Architecture: + model_type: rec + algorithm: DPTR + Transform: + Encoder: + name: ViT + embed_dim: 512 + num_heads: 8 + Decoder: + name: DptrParseq + decode_ar: True + refine_iters: 1 + is_pretrain: False + +Loss: + name: PARSeqLoss + +PostProcess: + name: ARLabelDecode + character_dict_path: *character_dict_path + use_space_char: *use_space_char + +Metric: + name: RecMetric + main_indicator: acc + is_filter: True + +Train: + dataset: + name: LMDBDataSet + data_dir: /share/test/zhaoshuai/parseq-data/data/train/real/ArT + transforms: + - DecodeImagePIL: # load image + img_mode: RGB + - PARSeqAugPIL: + - DPTRLabelEncode: # Class handling label + character_dict_path: *character_dict_path + use_space_char: *use_space_char + max_text_length: *max_text_length + - RecTVResize: + image_shape: [ 32, 128 ] + padding: False + - KeepKeys: + keep_keys: ['image', 'label'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 4 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: LMDBDataSet + data_dir: /share/test/zhaoshuai/parseq-data/data/val + transforms: + - DecodeImagePIL: # load image + img_mode: RGB + - PARSeqAugPIL: + - DPTRLabelEncode: # Class handling label + character_dict_path: *character_dict_path + use_space_char: *use_space_char + max_text_length: *max_text_length + - RecTVResize: + image_shape: [ 32, 128 ] + padding: False + - KeepKeys: + keep_keys: ['image', 'label'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 4 + num_workers: 2 diff --git a/configs/rec/dptr/dptr_parseq_pretrain.yml b/configs/rec/dptr/dptr_parseq_pretrain.yml new file mode 100644 index 0000000..32b8adf --- /dev/null +++ b/configs/rec/dptr/dptr_parseq_pretrain.yml @@ -0,0 +1,88 @@ +Global: + device: gpu + epoch_num: 20 + log_smooth_window: 20 + print_batch_step: 10 + output_dir: /share/ckpt/zhaoshuai/openocr/dptr_parseq/ + eval_epoch_step: [0, 1] + eval_batch_step: [0, 500] + cal_metric_during_train: True + pretrained_model: + checkpoints: + use_tensorboard: false + infer_img: + # for data or label process + character_dict_path: &character_dict_path ./tools/utils/EN_symbol_dict.txt + max_text_length: &max_text_length 25 + use_space_char: &use_space_char False + use_amp: True + save_res_path: /share/ckpt/zhaoshuai/openocr/dptr_parseq/predicts_dptr_parseq.txt + grad_clip_val: 20 + +Optimizer: + name: AdamW + lr: 0.001485 # 2gpus 384bs/gpu + weight_decay: 0. 
+ filter_bias_and_bn: False + +LRScheduler: + name: OneCycleLR + warmup_epoch: 1.5 # pct_start 0.075*20 = 1.5ep + cycle_momentum: False + +Architecture: + model_type: rec + algorithm: DPTR + Decoder: + name: DptrParseq + decode_ar: True + refine_iters: 1 + is_pretrain: True + ORP_path: /share/ckpt/zhaoshuai/parseq/clip_background.pth + +Loss: + name: PARSeqLoss + +PostProcess: + name: ARLabelDecode + character_dict_path: *character_dict_path + use_space_char: *use_space_char + +Metric: + name: RecMetric + main_indicator: acc + is_filter: True + +Train: + dataset: + name: TextLMDBDataSet + data_dir: /share/test/zhaoshuai/parseq-data/data/train/real/ArT + transforms: + - DPTRLabelEncode: # Class handling label + character_dict_path: *character_dict_path + use_space_char: *use_space_char + max_text_length: *max_text_length + - KeepKeys: + keep_keys: ['clip_label', 'label'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 256 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: TextLMDBDataSet + data_dir: /share/test/zhaoshuai/parseq-data/data/val + transforms: + - DPTRLabelEncode: # Class handling label + character_dict_path: *character_dict_path + use_space_char: *use_space_char + max_text_length: *max_text_length + - KeepKeys: + keep_keys: ['clip_label', 'label'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 2 diff --git a/openrec/modeling/clip/bpe_simple_vocab_16e6.txt.gz b/openrec/modeling/clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000..7b5088a Binary files /dev/null and b/openrec/modeling/clip/bpe_simple_vocab_16e6.txt.gz differ diff --git a/openrec/modeling/clip/clip.py b/openrec/modeling/clip/clip.py new file mode 100644 index 0000000..257511e --- /dev/null +++ b/openrec/modeling/clip/clip.py @@ -0,0 +1,237 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List +from pkg_resources import packaging + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model +from .simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": 
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. + We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
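+
+    Example (illustrative): tokenize(["a photo of a 'cat'"]) returns a tensor of shape
+    [1, 77] whose first entry is the <|startoftext|> id, followed by the BPE ids of the
+    text, the <|endoftext|> id, and zero padding.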
+ """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + else: + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/openrec/modeling/clip/model.py b/openrec/modeling/clip/model.py new file mode 100644 index 0000000..20b469d --- /dev/null +++ b/openrec/modeling/clip/model.py @@ -0,0 +1,445 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + 
k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + 
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x) + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 
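+        # logit_scale is the learnable inverse temperature of the similarity logits;
+        # initialising it to ln(1/0.07) means exp(logit_scale) starts at roughly 14.3,
+        # the equivalent of the 0.07 temperature used in the original CLIP setup.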
+ + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # take features from the eot embedding (eot_token is the highest number in each sequence) + output = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + output = torch.cat([output.unsqueeze(1), x], dim=1) + + return output + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", 
"proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + # print("embedding_dim", embed_dim) + # print("image_resolution", image_resolution) + # print("vision_layers", vision_layers) + # print("vision_width", vision_width) + # print("vision_patch_size", vision_patch_size) + # print("context_length", context_length) + # print("vocab_size", vocab_size) + # print("transformer_width", transformer_width) + # print("transformer_heads", transformer_heads) + # print("transformer_layers", transformer_layers) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/openrec/modeling/clip/simple_tokenizer.py b/openrec/modeling/clip/simple_tokenizer.py new file mode 100644 index 0000000..0a66286 --- /dev/null +++ b/openrec/modeling/clip/simple_tokenizer.py @@ -0,0 +1,132 @@ +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. 
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+        merges = merges[1:49152-256-2+1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v+'</w>' for v in vocab]
+        for merge in merges:
+            vocab.append(''.join(merge))
+        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token+'</w>'
+
+        while True:
+            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word)-1 and word[i+1] == second:
+                    new_word.append(first+second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+        return text
diff --git a/openrec/modeling/decoders/__init__.py b/openrec/modeling/decoders/__init__.py
index 7d30028..cca6675 100644
--- a/openrec/modeling/decoders/__init__.py
+++ b/openrec/modeling/decoders/__init__.py
@@ -28,6 +28,7 @@ def build_decoder(config):
     from .cam_decoder import CAMDecoder
     from .ote_decoder import OTEDecoder
     from .bus_decoder import BUSDecoder
+    from .dptr_parseq_clip_b_decoder import DptrParseq
 
     support_dict = [
         'CTCDecoder',
'NRTRDecoder', 'CPPDDecoder', 'ABINetDecoder', @@ -35,7 +36,7 @@ def build_decoder(config): 'SMTRDecoder', 'LPVDecoder', 'SARDecoder', 'RobustScannerDecoder', 'SRNDecoder', 'ASTERDecoder', 'RCTCDecoder', 'LISTERDecoder', 'GTCDecoder', 'SMTRDecoderNumAttn', 'MATRNDecoder', 'MGPDecoder', - 'DANDecoder', 'CAMDecoder', 'OTEDecoder', 'BUSDecoder' + 'DANDecoder', 'CAMDecoder', 'OTEDecoder', 'BUSDecoder', 'DptrParseq' ] module_name = config.pop('name') diff --git a/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py b/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py new file mode 100644 index 0000000..80a3d46 --- /dev/null +++ b/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py @@ -0,0 +1,611 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from itertools import permutations +from typing import Any, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn.modules import transformer +# from typing import Optional, Tuple, List +from ..clip.clip import load, tokenize + +class FMU(nn.Module): + """A Transformer decoder layer supporting two-stream attention (XLNet) + This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch.""" + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', + layer_norm_eps=1e-5): + super().__init__() + self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = transformer._get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.gelu + super().__setstate__(state) + + def forward(self, query: Tensor, memory: Tensor): + """Forward pass for a single stream (i.e. content or query) + tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency. + Both tgt_kv and memory are expected to be LayerNorm'd too. + memory is LayerNorm'd by ViT. 
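+        (Note: in FMU itself, query is the learned token query and memory is the encoder
+        output sequence; only cross-attention followed by a feed-forward block is applied.)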
+ """ + query1, ca_weights = self.cross_attn(query, memory, memory) + query = query + self.dropout1(query1) + + query2 = self.linear2(self.dropout2(self.activation(self.linear1(self.norm(query))))) + query = query + self.dropout3(query2) + + return query + + +class DecoderLayer(nn.Module): + """A Transformer decoder layer supporting two-stream attention (XLNet) This + implements a pre-LN decoder, as opposed to the post-LN default in + PyTorch.""" + + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation='gelu', + layer_norm_eps=1e-5, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, + nhead, + dropout=dropout, + batch_first=True) + self.cross_attn = nn.MultiheadAttention(d_model, + nhead, + dropout=dropout, + batch_first=True) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = transformer._get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.gelu + super().__setstate__(state) + + def forward_stream( + self, + tgt: Tensor, + tgt_norm: Tensor, + tgt_kv: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor], + tgt_key_padding_mask: Optional[Tensor], + ): + """Forward pass for a single stream (i.e. content or query) tgt_norm is + just a LayerNorm'd tgt. + + Added as a separate parameter for efficiency. Both tgt_kv and memory + are expected to be LayerNorm'd too. memory is LayerNorm'd by ViT. 
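+        In DPTR, memory is the CLIP text-encoder output during pre-training and the
+        FMU-fused visual features during fine-tuning, rather than raw ViT features.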
+ """ + tgt2, sa_weights = self.self_attn( + tgt_norm, + tgt_kv, + tgt_kv, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + + tgt = tgt + self.dropout1(tgt2) + + tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory) + self.attn_map = ca_weights + tgt = tgt + self.dropout2(tgt2) + + tgt2 = self.linear2( + self.dropout(self.activation(self.linear1(self.norm2(tgt))))) + tgt = tgt + self.dropout3(tgt2) + return tgt, sa_weights, ca_weights + + def forward( + self, + query, + content, + memory, + query_mask: Optional[Tensor] = None, + content_mask: Optional[Tensor] = None, + content_key_padding_mask: Optional[Tensor] = None, + update_content: bool = True, + ): + query_norm = self.norm_q(query) + content_norm = self.norm_c(content) + query = self.forward_stream(query, query_norm, content_norm, memory, + query_mask, content_key_padding_mask)[0] + if update_content: + content = self.forward_stream(content, content_norm, content_norm, + memory, content_mask, + content_key_padding_mask)[0] + return query, content + + +class Decoder(nn.Module): + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm): + super().__init__() + self.layers = transformer._get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + query, + content, + memory, + query_mask: Optional[Tensor] = None, + content_mask: Optional[Tensor] = None, + content_key_padding_mask: Optional[Tensor] = None, + ): + for i, mod in enumerate(self.layers): + last = i == len(self.layers) - 1 + query, content = mod( + query, + content, + memory, + query_mask, + content_mask, + content_key_padding_mask, + update_content=not last, + ) + query = self.norm(query) + return query + + +class TokenEmbedding(nn.Module): + + def __init__(self, charset_size: int, embed_dim: int): + super().__init__() + self.embedding = nn.Embedding(charset_size, embed_dim) + self.embed_dim = embed_dim + + def forward(self, tokens: torch.Tensor): + return math.sqrt(self.embed_dim) * self.embedding(tokens) + + +class DptrParseq(nn.Module): + + def __init__(self, + in_channels, + out_channels, + max_label_length=25, + embed_dim=512, + dec_num_heads=8, + dec_mlp_ratio=4, + dec_depth=6, + perm_num=6, + perm_forward=True, + perm_mirrored=True, + decode_ar=True, + refine_iters=1, + dropout=0.1, + is_pretrain=True, + ORP_path=None, + **kwargs: Any) -> None: + super().__init__() + self.pad_id = out_channels - 1 + self.eos_id = 0 + self.bos_id = out_channels - 2 + self.max_label_length = max_label_length + self.decode_ar = decode_ar + self.refine_iters = refine_iters + self.is_pretrain = is_pretrain + if not is_pretrain: + self.token_query = nn.Parameter(torch.Tensor(1, 26, embed_dim)) + self.fmu = FMU(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout) + + decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout) + self.decoder = Decoder(decoder_layer, + num_layers=dec_depth, + norm=nn.LayerNorm(embed_dim)) + + # Perm/attn mask stuff + self.rng = np.random.default_rng() + self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num + self.perm_forward = perm_forward + self.perm_mirrored = perm_mirrored + + # We don't predict nor + self.head = nn.Linear(embed_dim, out_channels - 2) + self.text_embed = TokenEmbedding(out_channels, embed_dim) + + # +1 for + self.pos_queries = nn.Parameter( + torch.Tensor(1, max_label_length + 1, embed_dim)) + self.dropout = nn.Dropout(p=dropout) + # Encoder has its own init. 
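+        # Token-id convention shared with DPTRLabelEncode: eos_id = 0,
+        # bos_id = out_channels - 2, pad_id = out_channels - 1; the output head above
+        # predicts out_channels - 2 classes because [BOS] and [PAD] are never generated.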
+        self.apply(self._init_weights)
+        nn.init.trunc_normal_(self.pos_queries, std=0.02)
+
+        if is_pretrain:
+            self.clip_encoder, preprocess = load("ViT-B/16")
+            for p in self.clip_encoder.parameters():
+                p.requires_grad = False
+            if ORP_path is None:
+                background_image_folder_path = 'background_images_folder/path'
+                self.background_features = self.get_noise(background_image_folder_path, preprocess)
+                torch.save(self.background_features, 'save/noise/to/ORP_path')
+            else:
+                self.background_features = torch.load(ORP_path, map_location='cpu')
+
+    def _init_weights(self, module: nn.Module):
+        """Initialize the weights using the typical initialization schemes used
+        in SOTA models."""
+
+        if isinstance(module, nn.Linear):
+            nn.init.trunc_normal_(module.weight, std=0.02)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            nn.init.trunc_normal_(module.weight, std=0.02)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.Conv2d):
+            nn.init.kaiming_normal_(module.weight,
+                                    mode='fan_out',
+                                    nonlinearity='relu')
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
+            nn.init.ones_(module.weight)
+            nn.init.zeros_(module.bias)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        param_names = {'text_embed.embedding.weight', 'pos_queries'}
+        return param_names
+
+    def get_noise(self, background_image_path, preprocess):
+        # os and PIL are only needed for this offline background-feature extraction,
+        # so they are imported locally; the frozen CLIP encoder determines the device.
+        import os
+        from PIL import Image
+        device = next(self.clip_encoder.parameters()).device
+        image_paths = [os.path.join(background_image_path, filename) for filename in os.listdir(background_image_path) if
+                       filename.endswith(('.png', '.jpg', '.jpeg'))]
+        features = []
+        for image_path in image_paths:
+            image = Image.open(image_path)
+            input = preprocess(image).unsqueeze(0).to(device)
+            with torch.no_grad():
+                feature = self.clip_encoder.encode_image(input)
+            features.append(feature)
+            image.close()
+        return torch.cat(features).cpu().numpy()
+
+    def clip_encode(self, labels):
+        device = next(self.clip_encoder.parameters()).device
+        text_inputs = torch.cat([tokenize(f"a photo of a '{c}'") for c in labels]).to(device)
+
+        return self.clip_encoder.encode_text(text_inputs)
+
+    def decode(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: Optional[Tensor] = None,
+        tgt_padding_mask: Optional[Tensor] = None,
+        tgt_query: Optional[Tensor] = None,
+        tgt_query_mask: Optional[Tensor] = None,
+        pos_query: torch.Tensor = None,
+    ):
+        N, L = tgt.shape
+        # <bos> stands for the null context. We only supply position information for characters after <bos>.
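+        # Illustration of the two streams assembled below: for tgt = [[BOS, t0, t1]]
+        # (L = 3, t0/t1 hypothetical character ids) the content stream is
+        # [emb(BOS), pos[0] + emb(t0), pos[1] + emb(t1)], while the query stream is just
+        # the positional queries pos[0:3]; only the content stream carries character identities.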
+ null_ctx = self.text_embed(tgt[:, :1]) + + if tgt_query is None: + tgt_query = pos_query[:, :L] + tgt_emb = pos_query[:, :L - 1] + self.text_embed(tgt[:, 1:]) + tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1)) + + tgt_query = self.dropout(tgt_query) + return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, + tgt_mask, tgt_padding_mask) + + def forward(self, memory, data=None, pos_query=None): + # print(memory.shape, data[0].shape) + if self.training: + if self.is_pretrain: + return self.training_step(None, pos_query, data[0], memory) + return self.training_step(memory, pos_query, data[0], None) + else: + if self.is_pretrain: + return self.forward_test(None, memory, pos_query) + return self.forward_test(memory, None, pos_query) + + def forward_test(self, + memory: Tensor, clip_ids, + pos_query: Tensor = None, + max_length: Optional[int] = None) -> Tensor: + testing = max_length is None + max_length = (self.max_label_length if max_length is None else min( + max_length, self.max_label_length)) + + if self.is_pretrain: + memory = self.clip_encoder.encode_text(clip_ids) + else: + bs = memory.shape[0] + token_query = self.token_query.expand(bs, -1, -1) + memory = self.fmu(token_query, memory) + _device = memory.get_device() + bs = memory.shape[0] + # +1 for at end of sequence. + num_steps = max_length + 1 + # memory = self.encode(images) + + # Query positions up to `num_steps` + if pos_query is None: + pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1) + else: + pos_queries = pos_query + + # Special case for the forward permutation. Faster than using `generate_attn_masks()` + tgt_mask = query_mask = torch.triu( + torch.full((num_steps, num_steps), float('-inf'), device=_device), + 1) + self.attn_maps = [] + if self.decode_ar: + tgt_in = torch.full((bs, num_steps), + self.pad_id, + dtype=torch.long, + device=_device) + tgt_in[:, 0] = self.bos_id + + logits = [] + for i in range(num_steps): + j = i + 1 # next token index + # Efficient decoding: + # Input the context up to the ith token. We use only one query (at position = i) at a time. + # This works because of the lookahead masking effect of the canonical (forward) AR context. + # Past tokens have no access to future tokens, hence are fixed once computed. + tgt_out = self.decode( + tgt_in[:, :j], + memory, + tgt_mask[:j, :j], + tgt_query=pos_queries[:, i:j], + tgt_query_mask=query_mask[i:j, :j], + pos_query=pos_queries, + ) + self.attn_maps.append(self.decoder.layers[-1].attn_map) + # the next token probability is in the output's ith token position + p_i = self.head(tgt_out) + logits.append(p_i) + if j < num_steps: + # greedy decode. add the next token index to the target input + tgt_in[:, j] = p_i.squeeze().argmax(-1) + # Efficient batch decoding: If all output words have at least one EOS token, end decoding. + if testing and (tgt_in == self.eos_id).any(dim=-1).all(): + break + + logits = torch.cat(logits, dim=1) + else: + # No prior context, so input is just . We query all positions. + tgt_in = torch.full((bs, 1), + self.bos_id, + dtype=torch.long, + device=_device) + tgt_out = self.decode(tgt_in, + memory, + tgt_query=pos_queries, + pos_query=pos_queries) + logits = self.head(tgt_out) + + if self.refine_iters: + # For iterative refinement, we always use a 'cloze' mask. + # We can derive it from the AR forward mask by unmasking the token context to the right. 
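+            # Worked example for num_steps = 4 (rows = query positions, columns = content
+            # positions [BOS, t0, t1, t2]; X = -inf, . = attendable):
+            #   AR mask (triu, offset 1):    cloze mask after the update below:
+            #     . X X X                      . X . .
+            #     . . X X                      . . X .
+            #     . . . X                      . . . X
+            #     . . . .                      . . . .
+            # i.e. each query position may now attend to every token except the one it is
+            # currently re-predicting.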
+ query_mask[torch.triu( + torch.ones(num_steps, + num_steps, + dtype=torch.bool, + device=_device), 2)] = 0 + bos = torch.full((bs, 1), + self.bos_id, + dtype=torch.long, + device=_device) + for i in range(self.refine_iters): + # Prior context is the previous output. + tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1) + tgt_len = tgt_in.shape[1] + tgt_padding_mask = (tgt_in == self.eos_id).int().cumsum( + -1) > 0 # mask tokens beyond the first EOS token. + tgt_out = self.decode( + tgt_in, + memory, + tgt_mask[:tgt_len, :tgt_len], + tgt_padding_mask, + tgt_query=pos_queries, + tgt_query_mask=query_mask[:, :tgt_len], + pos_query=pos_queries, + ) + logits = self.head(tgt_out) + + return F.softmax(logits, -1) + + def gen_tgt_perms(self, tgt, _device): + """Generate shared permutations for the whole batch. + + This works because the same attention mask can be used for the shorter + sequences because of the padding mask. + """ + # We don't permute the position of BOS, we permute EOS separately + max_num_chars = tgt.shape[1] - 2 + # Special handling for 1-character sequences + if max_num_chars == 1: + return torch.arange(3, device=_device).unsqueeze(0) + perms = [torch.arange(max_num_chars, device=_device) + ] if self.perm_forward else [] + # Additional permutations if needed + max_perms = math.factorial(max_num_chars) + if self.perm_mirrored: + max_perms //= 2 + num_gen_perms = min(self.max_gen_perms, max_perms) + # For 4-char sequences and shorter, we generate all permutations and sample from the pool to avoid collisions + # Note that this code path might NEVER get executed since the labels in a mini-batch typically exceed 4 chars. + if max_num_chars < 5: + # Pool of permutations to sample from. We only need the first half (if complementary option is selected) + # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves + if max_num_chars == 4 and self.perm_mirrored: + selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21] + else: + selector = list(range(max_perms)) + perm_pool = torch.as_tensor(list( + permutations(range(max_num_chars), max_num_chars)), + device=_device)[selector] + # If the forward permutation is always selected, no need to add it to the pool for sampling + if self.perm_forward: + perm_pool = perm_pool[1:] + perms = torch.stack(perms) + if len(perm_pool): + i = self.rng.choice(len(perm_pool), + size=num_gen_perms - len(perms), + replace=False) + perms = torch.cat([perms, perm_pool[i]]) + else: + perms.extend([ + torch.randperm(max_num_chars, device=_device) + for _ in range(num_gen_perms - len(perms)) + ]) + perms = torch.stack(perms) + if self.perm_mirrored: + # Add complementary pairs + comp = perms.flip(-1) + # Stack in such a way that the pairs are next to each other. + perms = torch.stack([perms, comp + ]).transpose(0, 1).reshape(-1, max_num_chars) + # NOTE: + # The only meaningful way of permuting the EOS position is by moving it one character position at a time. + # However, since the number of permutations = T! and number of EOS positions = T + 1, the number of possible EOS + # positions will always be much less than the number of permutations (unless a low perm_num is set). + # Thus, it would be simpler to just train EOS using the full and null contexts rather than trying to evenly + # distribute it across the chosen number of permutations. 
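+        # Illustration for a 3-character label under the default perm_forward/perm_mirrored
+        # settings (index 0 = BOS, 1..3 = characters, 4 = EOS): the first two rows returned
+        # below are the canonical order [0, 1, 2, 3, 4] and the fully reversed order
+        # [0, 4, 3, 2, 1] (rewritten explicitly so [EOS] is also trained with a null
+        # context); the remaining rows are mirrored pairs of random character permutations.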
+ # Add position indices of BOS and EOS + bos_idx = perms.new_zeros((len(perms), 1)) + eos_idx = perms.new_full((len(perms), 1), max_num_chars + 1) + perms = torch.cat([bos_idx, perms + 1, eos_idx], dim=1) + # Special handling for the reverse direction. This does two things: + # 1. Reverse context for the characters + # 2. Null context for [EOS] (required for learning to predict [EOS] in NAR mode) + if len(perms) > 1: + perms[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, + device=_device) + return perms + + def generate_attn_masks(self, perm, _device): + """Generate attention masks given a sequence permutation (includes pos. + for bos and eos tokens) + + :param perm: the permutation sequence. i = 0 is always the BOS + :return: lookahead attention masks + """ + sz = perm.shape[0] + mask = torch.zeros((sz, sz), device=_device) + for i in range(sz): + query_idx = perm[i] + masked_keys = perm[i + 1:] + mask[query_idx, masked_keys] = float('-inf') + content_mask = mask[:-1, :-1].clone() + mask[torch.eye(sz, dtype=torch.bool, + device=_device)] = float('-inf') # mask "self" + query_mask = mask[1:, :-1] + return content_mask, query_mask + + def training_step(self, memory, pos_query, tgt_ids, clip_ids): + bs = tgt_ids.shape[0] + if self.is_pretrain: + memory = self.clip_encoder.encode_text(clip_ids) + n = memory.shape[1] + B, N, D = self.background_features.shape + random_B = np.random.choice(B, bs, replace=False) + random_N = np.random.choice(N, n, replace=False) + noise = self.background_features[random_B][:, random_N] + noise = torch.from_numpy(noise).to(memory.get_device()) + memory = memory + noise * 1e-1 + else: + token_query = self.token_query.expand(bs, -1, -1) + memory = self.fmu(token_query, memory) + + if pos_query is None: + pos_query = self.pos_queries.expand(bs, -1, -1) + # Prepare the target sequences (input and output) + tgt_perms = self.gen_tgt_perms(tgt_ids, memory.get_device()) + tgt_in = tgt_ids[:, :-1] + tgt_out = tgt_ids[:, 1:] + + # The [EOS] token is not depended upon by any other token in any permutation ordering + tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id) + + loss = 0 + loss_numel = 0 + n = (tgt_out != self.pad_id).sum().item() + for i, perm in enumerate(tgt_perms): + tgt_mask, query_mask = self.generate_attn_masks( + perm, memory.get_device()) + # print("tgt_in:", tgt_in, "tgt_out:", tgt_out, "tgt_padding_mask:", tgt_padding_mask) + # print('tgt_mask:', tgt_mask) + # print('query_mask:', query_mask) + # print(tgt_in.shape, memory.shape, tgt_mask.shape, tgt_padding_mask.shape, query_mask.shape, pos_query.shape) + out = self.decode( + tgt_in, + memory, + tgt_mask, + tgt_padding_mask, + tgt_query_mask=query_mask, + pos_query=pos_query, + ) + # print('out:', out) + logits = self.head(out) + # print('logits:', logits) + if i == 0: + final_out = logits + loss += n * F.cross_entropy(logits.flatten(end_dim=1), + tgt_out.flatten(), + ignore_index=self.pad_id) + loss_numel += n + # After the second iteration (i.e. 
done with canonical and reverse orderings), + # remove the [EOS] tokens for the succeeding perms + if i == 1: + tgt_out = torch.where(tgt_out == self.eos_id, self.pad_id, + tgt_out) + n = (tgt_out != self.pad_id).sum().item() + loss /= loss_numel + + # self.log('loss', loss) + return [loss, final_out] diff --git a/openrec/preprocess/__init__.py b/openrec/preprocess/__init__.py index b229447..dd2c26f 100644 --- a/openrec/preprocess/__init__.py +++ b/openrec/preprocess/__init__.py @@ -23,6 +23,7 @@ from .srn_label_encode import SRNLabelEncode from .visionlan_label_encode import VisionLANLabelEncode from .cam_label_encode import CAMLabelEncode +from .dptr_label_encode import DPTRLabelEncode class KeepKeys(object): diff --git a/openrec/preprocess/dptr_label_encode.py b/openrec/preprocess/dptr_label_encode.py new file mode 100644 index 0000000..24c7796 --- /dev/null +++ b/openrec/preprocess/dptr_label_encode.py @@ -0,0 +1,157 @@ +import re +from abc import ABC, abstractmethod +from itertools import groupby +from typing import List, Optional, Tuple +import numpy as np +import torch +from torch import Tensor +from torch.nn.utils.rnn import pad_sequence +import unicodedata +from ..modeling.clip.clip import tokenize + +class CharsetAdapter: + """Transforms labels according to the target charset.""" + + def __init__(self, target_charset) -> None: + super().__init__() + self.lowercase_only = target_charset == target_charset.lower() + self.uppercase_only = target_charset == target_charset.upper() + self.unsupported = re.compile(f'[^{re.escape(target_charset)}]') + + def __call__(self, label): + if self.lowercase_only: + label = label.lower() + elif self.uppercase_only: + label = label.upper() + # Remove unsupported characters + label = self.unsupported.sub('', label) + return label + + +class BaseTokenizer(ABC): +# eos=0, a=1, bos=37, pad=38 + def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None: + self._itos = specials_first + tuple(charset) + specials_last + self._stoi = {s: i for i, s in enumerate(self._itos)} + # print("stoi:", self._stoi) + + def __len__(self): + return len(self._itos) + + def _tok2ids(self, tokens: str) -> List[int]: + # print("tokens", tokens) + return [self._stoi[s] for s in tokens] + + def _ids2tok(self, token_ids: List[int], join: bool = True) -> str: + tokens = [self._itos[i] for i in token_ids] + return ''.join(tokens) if join else tokens + + @abstractmethod + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + """Encode a batch of labels to a representation suitable for the model. + + Args: + labels: List of labels. Each can be of arbitrary length. + device: Create tensor on this device. + + Returns: + Batched tensor representation padded to the max label length. Shape: N, L + """ + raise NotImplementedError + + @abstractmethod + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + """Internal method which performs the necessary filtering prior to decoding.""" + raise NotImplementedError + + def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]: + """Decode a batch of token distributions. + + Args: + token_dists: softmax probabilities over the token distribution. 
Shape: N, L, C + raw: return unprocessed labels (will return list of list of strings) + + Returns: + list of string labels (arbitrary length) and + their corresponding sequence probabilities as a list of Tensors + """ + batch_tokens = [] + batch_probs = [] + for dist in token_dists: + probs, ids = dist.max(-1) # greedy selection + if not raw: + probs, ids = self._filter(probs, ids) + tokens = self._ids2tok(ids, not raw) + batch_tokens.append(tokens) + batch_probs.append(probs) + return batch_tokens, batch_probs + + +class Tokenizer(BaseTokenizer): + BOS = '[B]' + EOS = '[E]' + PAD = '[P]' + + def __init__(self, charset: str) -> None: + specials_first = (self.EOS,) + specials_last = (self.BOS, self.PAD) + super().__init__(charset, specials_first, specials_last) + self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last] + + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + batch = [self.bos_id] + self._tok2ids(labels) + [self.eos_id] + return batch + # return pad_sequence(batch, batch_first=True, padding_value=self.pad_id) + + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + ids = ids.tolist() + try: + eos_idx = ids.index(self.eos_id) + except ValueError: + eos_idx = len(ids) # Nothing to truncate. + # Truncate after EOS + ids = ids[:eos_idx] + probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists) + return probs, ids + +class DPTRLabelEncode(Tokenizer): + """Convert between text-label and text-index.""" + def __init__(self, max_text_length=25, character_dict_path=None, **kwargs): + self.max_length = max_text_length + charset = get_alpha(character_dict_path) + charset = ''.join(charset) + # print(charset) + super(DPTRLabelEncode, self).__init__(charset) + + def __call__(self, data, normalize_unicode=True): + text = data['label'] + + if normalize_unicode: + text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode() + text = ''.join(text.split()) + if len(text) == 0 or len(text) > self.max_length: + return None + + text_ids = self.encode(text) + clip_ids = tokenize(f"a photo of a '{text}'") + text_ids = text_ids + [self.pad_id] * (self.max_length + 2 - len(text_ids)) + # print(text, len(text_ids), len(clip_ids[0])) + data['clip_label'] = np.array(clip_ids[0]) + data['label'] = np.array(text_ids) + return data + + def add_special_char(self, dict_character): + dict_character = [self.EOS] + dict_character + [self.BOS, self.PAD] + return dict_character + +def get_alpha(alpha_path): + character_str = [] + with open(alpha_path, 'rb') as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip('\n').strip('\r\n') + character_str.append(line) + dict_character = list(character_str) + if 'arabic' in alpha_path: + reverse = True + return dict_character \ No newline at end of file diff --git a/tools/data/__init__.py b/tools/data/__init__.py index ad1a203..dbb6719 100644 --- a/tools/data/__init__.py +++ b/tools/data/__init__.py @@ -9,6 +9,7 @@ from torch.utils.data import DataLoader, DistributedSampler from tools.data.lmdb_dataset import LMDBDataSet +from tools.data.text_lmdb_dataset import TextLMDBDataSet from tools.data.lmdb_dataset_test import LMDBDataSetTest from tools.data.multi_scale_sampler import MultiScaleSampler from tools.data.ratio_dataset import RatioDataSet @@ -30,7 +31,7 @@ def build_dataloader(config, mode, logger, seed=None, epoch=3): config = copy.deepcopy(config) support_dict = [ - 'SimpleDataSet', 'LMDBDataSet', 
'MultiScaleDataSet', 'STRLMDBDataSet', + 'SimpleDataSet', 'LMDBDataSet', 'TextLMDBDataSet', 'MultiScaleDataSet', 'STRLMDBDataSet', 'LMDBDataSetTest', 'RatioDataSet', 'RatioDataSetTest', 'RatioDataSetTVResize', 'RatioDataSetTVResizeTest' ] diff --git a/tools/data/text_lmdb_dataset.py b/tools/data/text_lmdb_dataset.py new file mode 100644 index 0000000..2bd8f37 --- /dev/null +++ b/tools/data/text_lmdb_dataset.py @@ -0,0 +1,127 @@ +import os +import cv2 +import lmdb +import numpy as np +from torch.utils.data import Dataset + +from openrec.preprocess import create_operators, transform + + +class TextLMDBDataSet(Dataset): + + def __init__(self, config, mode, logger, seed=None, epoch=1): + super(TextLMDBDataSet, self).__init__() + + global_config = config['Global'] + dataset_config = config[mode]['dataset'] + loader_config = config[mode]['loader'] + loader_config['batch_size_per_card'] + data_dir = dataset_config['data_dir'] + self.do_shuffle = loader_config['shuffle'] + + self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir) + logger.info(f'Initialize indexs of datasets: {data_dir}') + self.data_idx_order_list = self.dataset_traversal() + if self.do_shuffle: + np.random.shuffle(self.data_idx_order_list) + self.ops = create_operators(dataset_config['transforms'], + global_config) + self.ext_op_transform_idx = dataset_config.get('ext_op_transform_idx', + 1) + + ratio_list = dataset_config.get('ratio_list', [1.0]) + self.need_reset = True in [x < 1 for x in ratio_list] + + def load_hierarchical_lmdb_dataset(self, data_dir): + lmdb_sets = {} + dataset_idx = 0 + for dirpath, dirnames, filenames in os.walk(data_dir + '/'): + if not dirnames: + env = lmdb.open( + dirpath, + max_readers=32, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) + txn = env.begin(write=False) + num_samples = int(txn.get('num-samples'.encode())) + lmdb_sets[dataset_idx] = { + 'dirpath': dirpath, + 'env': env, + 'txn': txn, + 'num_samples': num_samples, + } + dataset_idx += 1 + return lmdb_sets + + def dataset_traversal(self): + lmdb_num = len(self.lmdb_sets) + total_sample_num = 0 + for lno in range(lmdb_num): + total_sample_num += self.lmdb_sets[lno]['num_samples'] + data_idx_order_list = np.zeros((total_sample_num, 2)) + beg_idx = 0 + for lno in range(lmdb_num): + tmp_sample_num = self.lmdb_sets[lno]['num_samples'] + end_idx = beg_idx + tmp_sample_num + data_idx_order_list[beg_idx:end_idx, 0] = lno + data_idx_order_list[beg_idx:end_idx, + 1] = list(range(tmp_sample_num)) + data_idx_order_list[beg_idx:end_idx, 1] += 1 + beg_idx = beg_idx + tmp_sample_num + return data_idx_order_list + + def get_ext_data(self): + ext_data_num = 0 + for op in self.ops: + if hasattr(op, 'ext_data_num'): + ext_data_num = getattr(op, 'ext_data_num') + break + load_data_ops = self.ops[:self.ext_op_transform_idx] + ext_data = [] + + while len(ext_data) < ext_data_num: + lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint( + len(self))] + lmdb_idx = int(lmdb_idx) + file_idx = int(file_idx) + sample_info = self.get_lmdb_sample_info( + self.lmdb_sets[lmdb_idx]['txn'], file_idx) + if sample_info is None: + continue + label = sample_info + data = {'label': label} + data = transform(data, load_data_ops) + if data is None: + continue + ext_data.append(data) + return ext_data + + def get_lmdb_sample_info(self, txn, index, normalize_unicode=True, remove_whitespace=True, max_length=True): + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key) + if label is None: + return None + label = 
label.decode('utf-8')
+
+        return label
+
+    def __getitem__(self, idx):
+        lmdb_idx, file_idx = self.data_idx_order_list[idx]
+        lmdb_idx = int(lmdb_idx)
+        file_idx = int(file_idx)
+        sample_info = self.get_lmdb_sample_info(
+            self.lmdb_sets[lmdb_idx]['txn'], file_idx)
+        if sample_info is None:
+            # Missing or corrupt label entry: fall back to a random valid sample.
+            return self.__getitem__(np.random.randint(self.__len__()))
+        label = sample_info
+        data = {'label': label}
+        outs = transform(data, self.ops)
+        if outs is None:
+            # A transform rejected the sample (e.g. empty or over-long label): resample.
+            return self.__getitem__(np.random.randint(self.__len__()))
+        return outs
+
+    def __len__(self):
+        return self.data_idx_order_list.shape[0]
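For readers unfamiliar with PARSeq-style permutation training, the following standalone sketch illustrates what `generate_attn_masks` above produces. It is not part of the diff: `attn_masks_for` is an illustrative re-implementation, shown here for the canonical left-to-right permutation with three character positions.

```python
# Standalone illustration of the content/query masks used by the permuted-AR decoder.
import torch

def attn_masks_for(perm: torch.Tensor):
    sz = perm.shape[0]
    mask = torch.zeros((sz, sz))
    for i in range(sz):
        # position perm[i] may not attend to positions generated after it
        mask[perm[i], perm[i + 1:]] = float('-inf')
    content_mask = mask[:-1, :-1].clone()
    mask[torch.eye(sz, dtype=torch.bool)] = float('-inf')  # queries never see themselves
    query_mask = mask[1:, :-1]
    return content_mask, query_mask

# max_num_chars = 3 -> positions: 0 = BOS, 1..3 = characters, 4 = EOS
perm = torch.arange(5)  # canonical left-to-right ordering
content_mask, query_mask = attn_masks_for(perm)
print(content_mask)  # standard causal (lower-triangular) mask over [BOS, c1, c2, c3]
print(query_mask)    # shifted by one: the query for c_i sees BOS..c_{i-1} but not itself
```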
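A minimal usage sketch of the new `DPTRLabelEncode` transform. The dictionary path is a placeholder; it is assumed to be a plain-text file with one character per line that covers every character of the label. The transform returns the padded recognizer target plus the CLIP token ids that `training_step` consumes as `clip_ids` during pretraining.

```python
# Illustrative only; the dictionary path is a placeholder.
from openrec.preprocess import DPTRLabelEncode

encode = DPTRLabelEncode(max_text_length=25,
                         character_dict_path='path/to/char_dict.txt')

out = encode({'label': 'hello'})

# [BOS] + character ids + [EOS], padded with [PAD] to max_text_length + 2 positions.
print(out['label'].shape)       # (27,)
# Token ids of the prompt "a photo of a 'hello'" for clip_encoder.encode_text().
print(out['clip_label'].shape)  # (77,) with CLIP's default context length
```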
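Finally, a sketch of driving the new `TextLMDBDataSet` directly, outside `build_dataloader`. Paths are placeholders, the config dict contains only the keys the class reads, and the LMDB is assumed to follow the usual `num-samples` / `label-%09d` layout (no images are stored or read).

```python
# Illustrative only; paths are placeholders.
import logging

from tools.data.text_lmdb_dataset import TextLMDBDataSet

config = {
    'Global': {'max_text_length': 25,
               'character_dict_path': 'path/to/char_dict.txt',
               'use_space_char': False},
    'Train': {
        'dataset': {
            'name': 'TextLMDBDataSet',
            'data_dir': 'path/to/label_lmdb',
            'transforms': [
                {'DPTRLabelEncode': {'character_dict_path': 'path/to/char_dict.txt',
                                     'max_text_length': 25}},
            ],
        },
        'loader': {'shuffle': True, 'batch_size_per_card': 256,
                   'drop_last': True, 'num_workers': 4},
    },
}

dataset = TextLMDBDataSet(config, 'Train', logging.getLogger(__name__))
sample = dataset[0]  # dict with 'clip_label' and 'label' arrays
print(sample['clip_label'].shape, sample['label'].shape)
```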