From 1b4e184c20fe4a3c9f346224ffd2b5191eb1cc8e Mon Sep 17 00:00:00 2001 From: Justin Tan Date: Sun, 13 Sep 2020 22:09:58 +1000 Subject: [PATCH] First stable release --- assets/HiFIC_torch_colab_demo.ipynb | 7 +- src/compression/ans.py | 1 - src/compression/compression_utils.py | 14 ++ src/compression/entropy_coding.py | 198 +++++++++++++++++++++++---- src/compression/prior_model.py | 4 +- src/compression/reversed_coders.py | 150 -------------------- 6 files changed, 190 insertions(+), 184 deletions(-) delete mode 100644 src/compression/reversed_coders.py diff --git a/assets/HiFIC_torch_colab_demo.ipynb b/assets/HiFIC_torch_colab_demo.ipynb index cad28da..79f5e74 100644 --- a/assets/HiFIC_torch_colab_demo.ipynb +++ b/assets/HiFIC_torch_colab_demo.ipynb @@ -592,6 +592,7 @@ " clocktower=\"9cbf2594f339c0d3d0f0ea25c62af52b.png\",\n", " fresco=\"8181526d9f238726d3e1d3ec3cc56fb7.png\",\n", " islet=\"c6658d87c608b631f5cc3fb5a8d89731.png\",\n", + " mountain=\"d3688a7285d7b2b81febe1cd72e6e22c.png\",\n", " pasta=\"f5be5054c01d8efc834d78a991356ad6.png\",\n", " pines=\"e903c4f4684100a6dbac1f0b9b4de760.png\",\n", " plaza=\"d78b363974ac79908b79012f48de715d.png\",\n", @@ -812,7 +813,7 @@ "source": [ "# Choose default images from CLIC2020 dataset\n", "# Skip if uploading custom images\n", - "default_image = \"street\" #@param [\"cafe\", \"cat\", \"city\", \"clocktower\", \"fresco\", \"islet\", \"pasta\", \"pines\", \"plaza\", \"portrait\", \"shoreline\", \"street\", \"tundra\"]" + "default_image = \"street\" #@param [\"cafe\", \"cat\", \"city\", \"clocktower\", \"fresco\", \"islet\", \"mountain\", \"pasta\", \"pines\", \"plaza\", \"portrait\", \"shoreline\", \"street\", \"tundra\"]" ], "execution_count": null, "outputs": [] @@ -1111,7 +1112,9 @@ "colab_type": "text" }, "source": [ - "You can compress new images by going back to the \"Prepare Images\" heading and selecting a different default image or upload your own for compression, then running the cells below in sequence. Please open an issue if you encounter an error when running this demo." + "You can compress new images by going back to the \"Prepare Images\" heading and selecting a different default image or upload your own for compression, then running the cells below in sequence. Note that each model cannot decompress the output generated by a different model, and you need to delete the contents of `/content/out` if you want to try a different model.\n", + "\n", + "Please open an issue if you encounter an error when running this demo." ] }, { diff --git a/src/compression/ans.py b/src/compression/ans.py index a37a974..632eea4 100644 --- a/src/compression/ans.py +++ b/src/compression/ans.py @@ -73,7 +73,6 @@ def push(x, starts, freqs, precisions): tail = stack_extend(tail, np.uint32(head[idxs])) # Can also modulo with bitand head = np.copy(head) # Ensure no side-effects head[idxs] >>= 32 - head_div_freqs, head_mod_freqs = np.divmod(head, freqs) return (head_div_freqs << np.uint(precisions)) + head_mod_freqs + starts, tail diff --git a/src/compression/compression_utils.py b/src/compression/compression_utils.py index 792141f..14c5edd 100644 --- a/src/compression/compression_utils.py +++ b/src/compression/compression_utils.py @@ -3,7 +3,10 @@ import numpy as np import functools import os +import autograd.numpy as np +from autograd import make_vjp +from autograd.extend import vspace, VSpace from collections import namedtuple from src.helpers import utils @@ -76,6 +79,17 @@ def estimate_tails(cdf, target, shape, dtype=torch.float32, extra_counts=24): return tails +def view_update(data, view_fun): + view_vjp, item = make_vjp(view_fun)(data) + item_vs = vspace(item) + def update(new_item): + assert item_vs == vspace(new_item), \ + "Please ensure new_item shape and dtype match the data view." + diff = view_vjp(item_vs.add(new_item, + item_vs.scalar_mul(item, -np.uint64(1)))) + return vspace(data).add(data, diff) + return item, update + def decompose(x, n_channels, patch_size=PATCH_SIZE): # Decompose input x into spatial patches if isinstance(x, torch.Tensor) is False: diff --git a/src/compression/entropy_coding.py b/src/compression/entropy_coding.py index e865fcd..b1dfe6a 100644 --- a/src/compression/entropy_coding.py +++ b/src/compression/entropy_coding.py @@ -23,7 +23,7 @@ Codec = namedtuple('Codec', ['push', 'pop']) cast2u64 = lambda x: np.array(x, dtype=np.uint64) -def base_codec(enc_statfun, dec_statfun, precision): +def base_codec(enc_statfun, dec_statfun, precision, log=False): if np.any(precision >= 24): warn('Detected precision over 28. Codecs lose accuracy at high ' 'precision.') @@ -32,7 +32,7 @@ def push(message, symbol): start, freq = enc_statfun(symbol) return vrans.push(message, start, freq, precision) - def pop(message): + def pop(message, log=log): cf, pop_fun = vrans.pop(message, precision) symbol = dec_statfun(cf) start, freq = enc_statfun(symbol) @@ -93,7 +93,8 @@ def _dec_statfun(value): # (coding_shape) = (C,H,W) by default but can be generalized # cdf_i: [(coding_shape), pmf_length + 2] # value: [(coding_shape)] - assert value.shape == coding_shape, "CDF-value shape mismatch!" + assert value.shape == coding_shape, ( + f"CDF-value shape mismatch! {value.shape} v. {coding_shape}") sym_flat = np.array( [np.searchsorted(cb, v_i, 'right') - 1 for (cb, v_i) in zip(cdf_i_flat_ragged, value.flatten())]) @@ -283,7 +284,17 @@ def vec_ans_index_buffered_encoder(symbols, indices, cdf, cdf_length, cdf_offset symbols = symbols.astype(np.int32) indices = indices.astype(np.int32) cdf_index = indices - + + max_overflow = (1 << overflow_width) - 1 + overflow_cdf_size = (1 << overflow_width) + 1 + overflow_cdf = np.arange(overflow_cdf_size, dtype=np.uint64)[None, None, None, :] + + enc_statfun_overflow = _vec_indexed_cdf_to_enc_statfun(overflow_cdf) + dec_statfun_overflow = _vec_indexed_cdf_to_dec_statfun(overflow_cdf, + np.ones_like(overflow_cdf) * len(overflow_cdf)) + overflow_push, overflow_pop = base_codec(enc_statfun_overflow, + dec_statfun_overflow, overflow_width) + assert bool(np.all(cdf_index >= 0)) and bool(np.all(cdf_index < cdf.shape[0])), ( "Invalid index.") @@ -297,14 +308,11 @@ def vec_ans_index_buffered_encoder(symbols, indices, cdf, cdf_length, cdf_offset # If outside of this range, map value to non-negative integer overflow. overflow = np.zeros_like(values) - - of_mask = values < 0 - overflow = np.where(of_mask, -2 * values - 1, overflow) - values = np.where(of_mask, max_value, values) - - of_mask = values >= max_value - overflow = np.where(of_mask, 2 * (values - max_value), overflow) - values = np.where(of_mask, max_value, values) + of_mask_lower = values < 0 + overflow = np.where(of_mask_lower, -2 * values - 1, overflow) + of_mask_upper = values >= max_value + overflow = np.where(of_mask_upper, 2 * (values - max_value), overflow) + values = np.where(np.logical_or(of_mask_lower, of_mask_upper), max_value, values) assert bool(np.all(values >= 0)), ( "Invalid shifted value for current symbol - values must be non-negative.") @@ -320,13 +328,18 @@ def vec_ans_index_buffered_encoder(symbols, indices, cdf, cdf_length, cdf_offset factor=PATCH_SIZE).cpu().numpy().astype(np.int32) indices = utils.pad_factor(torch.Tensor(indices), symbols_shape[2:], factor=PATCH_SIZE).cpu().numpy().astype(np.int32) + overflow = utils.pad_factor(torch.Tensor(overflow), symbols_shape[2:], + factor=PATCH_SIZE).cpu().numpy().astype(np.int32) assert (values.shape[2] % PATCH_SIZE[0] == 0) and (values.shape[3] % PATCH_SIZE[1] == 0) assert (indices.shape[2] % PATCH_SIZE[0] == 0) and (indices.shape[3] % PATCH_SIZE[1] == 0) values, _ = compression_utils.decompose(values, n_channels) + overflow, _ = compression_utils.decompose(overflow, n_channels) cdf_index, unfolded_shape = compression_utils.decompose(indices, n_channels) coding_shape = values.shape[1:] + assert coding_shape == cdf_index.shape[1:] + # LIFO - last item in buffer is first item decompressed for i in range(len(cdf_index)): # loop over batch dimension @@ -334,46 +347,119 @@ def vec_ans_index_buffered_encoder(symbols, indices, cdf, cdf_length, cdf_offset value_i = values[i] cdf_index_i = cdf_index[i] cdf_i = cdf[cdf_index_i] - cdf_i_length = cdf_length[cdf_index_i] + cdf_length_i = cdf_length[cdf_index_i] + max_value_i = cdf_length_i - 2 enc_statfun = _vec_indexed_cdf_to_enc_statfun(cdf_i) - dec_statfun = _vec_indexed_cdf_to_dec_statfun(cdf_i, cdf_i_length) + dec_statfun = _vec_indexed_cdf_to_dec_statfun(cdf_i, cdf_length_i) symbol_push, symbol_pop = base_codec(enc_statfun, dec_statfun, precision) start, freq = enc_statfun(value_i) - instructions.append((start, freq, False)) + instructions.append((start, freq, False, precision, 0)) """ Encode overflows here """ - # of_mask = values == max_value - # widths = np.zeros_like(values) - # widths[of_mask] = 1 + # No-op + empty_start = np.zeros_like(value_i).astype(np.uint) + empty_freq = np.ones_like(value_i).astype(np.uint) - # overflow = - - # cond_mask = overflow >> (widths * overflow_width) != 0 - # while np.all(cond_mask) is False: - # widths[cond_mask] += 1 + overflow_i = overflow[i] + of_mask = value_i == max_value_i + + if np.any(of_mask): + + widths = np.zeros_like(value_i) + cond_mask = (overflow_i >> (widths * overflow_width)) != 0 + + while np.any(cond_mask): + widths = np.where(cond_mask, widths+1, widths) + cond_mask = (overflow_i >> (widths * overflow_width)) != 0 + + val = widths + cond_mask = val >= max_overflow + while np.any(cond_mask): + print('Warning: Undefined behaviour.') + val_push = cast2u64(max_overflow) + overflow_start, overflow_freq = enc_statfun_overflow(val_push) + start = overflow_start[of_mask] + freq = overflow_start[of_mask] + instructions.append((start, freq, True, int(overflow_width), of_mask)) + # val[cond_mask] -= max_overflow + val = np.where(cond_mask, val-max_overflow, val) + cond_mask = val >= max_overflow + + val_push = cast2u64(val) + overflow_start, overflow_freq = enc_statfun_overflow(val_push) + start = overflow_start[of_mask] + freq = overflow_freq[of_mask] + instructions.append((start, freq, True, int(overflow_width), of_mask)) + + cond_mask = widths != 0 + while np.any(cond_mask): + counter = 0 + encoding = (overflow_i >> (counter * overflow_width)) & max_overflow + val = np.where(cond_mask, encoding, val) + val_push = cast2u64(val) + overflow_start, overflow_freq = enc_statfun_overflow(val_push) + start = overflow_start[of_mask] + freq = overflow_freq[of_mask] + instructions.append((start, freq, True, int(overflow_width), of_mask)) + widths = np.where(cond_mask, widths-1, widths) + cond_mask = widths != 0 + counter += 1 - # val = widths - return instructions, coding_shape +def overflow_view(value, mask): + return value[mask] + +def substack(codec, view_fun): + """ + Apply a codec on a subset of a message head. + view_fun should be a function: head -> subhead, for example + view_fun = lambda head: head[0] + to run the codec on only the first element of the head + """ + def push(message, start, freq, precision, mask): + head, tail = message + view_fun_ = lambda x: view_fun(x, mask) + subhead, update = compression_utils.view_update(head, view_fun_) + subhead, tail = vrans.push((subhead, tail), start, freq, precision) + return update(subhead), tail + + def pop(message, precision, mask, *args, **kwargs): + head, tail = message + view_fun_ = lambda x: view_fun(x, mask) + subhead, update = compression_utils.view_update(head, view_fun_) + + cf, pop_fun = vrans.pop((subhead, tail), precision) + symbol = cf + start, freq = symbol, 1 + + assert np.all(start <= cf) and np.all(cf < start + freq) + (subhead, tail), data = pop_fun(start, freq), symbol + updated_head = update(subhead) + return (updated_head, tail), data + + return Codec(push, pop) + def vec_ans_index_encoder_flush(instructions, precision, coding_shape, overflow_width=OVERFLOW_WIDTH, **kwargs): message = vrans.empty_message(coding_shape) - + overflow_push, _ = substack(codec=None, view_fun=overflow_view) # LIFO - last item compressed is first item decompressed for i in reversed(range(len(instructions))): - start, freq, flag = instructions[i] + start, freq, flag, precision_i, mask = instructions[i] if flag is False: message = vrans.push(message, start, freq, precision) else: - message = vrans.push(message, start, freq, overflow_width) + # Substack on overflow values + overflow_precision = precision_i + message = overflow_push(message, start, freq, overflow_precision, mask) encoded = vrans.flatten(message) message_length = len(encoded) @@ -487,7 +573,17 @@ def vec_ans_index_decoder(encoded, indices, cdf, cdf_length, cdf_offset, precisi message = vrans.unflatten(encoded, coding_shape) indices = indices.astype(np.int32) cdf_index = indices - + + max_overflow = (1 << overflow_width) - 1 + overflow_cdf_size = (1 << overflow_width) + 1 + overflow_cdf = np.arange(overflow_cdf_size, dtype=np.uint64)[None, :] + + enc_statfun_overflow = _vec_indexed_cdf_to_enc_statfun(overflow_cdf) + dec_statfun_overflow = _vec_indexed_cdf_to_dec_statfun(overflow_cdf, + np.ones_like(overflow_cdf) * len(overflow_cdf)) + overflow_codec = base_codec(enc_statfun_overflow, + dec_statfun_overflow, overflow_width) + assert bool(np.all(cdf_index >= 0)) and bool(np.all(cdf_index < cdf.shape[0])), ( "Invalid index.") @@ -510,6 +606,8 @@ def vec_ans_index_decoder(encoded, indices, cdf, cdf_length, cdf_offset, precisi symbols = [] + _, overflow_pop = substack(codec=overflow_codec, view_fun=overflow_view) + for i in range(len(cdf_index)): cdf_index_i = cdf_index[i] cdf_i = cdf[cdf_index_i] @@ -520,9 +618,51 @@ def vec_ans_index_decoder(encoded, indices, cdf, cdf_length, cdf_offset, precisi symbol_push, symbol_pop = base_codec(enc_statfun, dec_statfun, precision) message, value = symbol_pop(message) + + max_value_i = cdf_length_i - 2 + of_mask = value == max_value_i + + if np.any(of_mask): + + message, val = overflow_pop(message, overflow_width, of_mask) + val = cast2u64(val) + widths = val + + cond_mask = val == max_overflow + while np.any(cond_mask): + message, val = overflow_pop(message, overflow_width, of_mask) + val = cast2u64(val) + widths = np.where(cond_mask, widths + val, widths) + cond_mask = val == max_overflow + + overflow = np.zeros_like(val) + cond_mask = widths != 0 + + while np.any(cond_mask): + counter = 0 + message, val = overflow_pop(message, overflow_width, of_mask) + val = cast2u64(val) + assert np.all(val <= max_overflow) + + op = overflow | (val << (counter * overflow_width)) + overflow = np.where(cond_mask, op, overflow) + widths = np.where(cond_mask, widths-1, widths) + cond_mask = widths != 0 + counter += 1 + + overflow_broadcast = value + overflow_broadcast[of_mask] = overflow + overflow = overflow_broadcast + value = np.where(of_mask, overflow >> 1, value) + cond_mask = np.logical_and(of_mask, overflow & 1) + value = np.where(cond_mask, -value - 1, value) + cond_mask = np.logical_and(of_mask, np.logical_not(overflow & 1)) + value = np.where(cond_mask, value + max_value_i, value) + symbol = value + cdf_offset[cdf_index_i] symbols.append(symbol) + if B == 1: decoded = compression_utils.reconstitute(np.stack(symbols, axis=0), padded_shape, unfolded_shape) diff --git a/src/compression/prior_model.py b/src/compression/prior_model.py index 7a28328..ef5f20b 100644 --- a/src/compression/prior_model.py +++ b/src/compression/prior_model.py @@ -318,7 +318,7 @@ def forward(self, x, mean, scale, **kwargs): import time - n_channels = 4 + n_channels = 24 use_blocks = True vectorize = True prior_density = PriorDensity(n_channels) @@ -326,7 +326,7 @@ def forward(self, x, mean, scale, **kwargs): loc, scale = 2.401, 3.43 n_data = 1 - toy_shape = (n_data, n_channels, 32, 32) + toy_shape = (n_data, n_channels, 64, 64) bottleneck, means = torch.randn(toy_shape), torch.randn(toy_shape) scales = torch.randn(toy_shape) * np.sqrt(scale) + loc scales = torch.clamp(scales, min=MIN_SCALE) diff --git a/src/compression/reversed_coders.py b/src/compression/reversed_coders.py deleted file mode 100644 index 1f03729..0000000 --- a/src/compression/reversed_coders.py +++ /dev/null @@ -1,150 +0,0 @@ -def ans_index_encoder_reversed(symbols, indices, cdf, cdf_length, cdf_offset, precision, - overflow_width=OVERFLOW_WIDTH, **kwargs): - - message = vrans.empty_message(()) - coding_shape = symbols.shape[1:] - symbols = symbols.astype(np.int32).flatten() - indices = indices.astype(np.int32).flatten() - - max_overflow = (1 << overflow_width) - 1 - overflow_cdf_size = (1 << overflow_width) + 1 - overflow_cdf = np.arange(overflow_cdf_size, dtype=np.uint64) - - enc_statfun_overflow = _indexed_cdf_to_enc_statfun(overflow_cdf) - dec_statfun_overflow = _indexed_cdf_to_dec_statfun(overflow_cdf, - len(overflow_cdf)) - overflow_push, overflow_pop = base_codec(enc_statfun_overflow, - dec_statfun_overflow, overflow_width) - - # LIFO - last item compressed is first item decompressed - for i in reversed(range(len(indices))): # loop over flattened axis - - cdf_index = indices[i] - cdf_i = cdf[cdf_index] - cdf_length_i = cdf_length[cdf_index] - - assert (cdf_index >= 0 and cdf_index < cdf.shape[0]), ( - f"Invalid index {cdf_index} for symbol {i}") - - max_value = cdf_length_i - 2 - - assert max_value >= 0 and max_value < cdf.shape[1] - 1, ( - f"Invalid max length {max_value} for symbol {i}") - - # Data in range [offset[cdf_index], offset[cdf_index] + m - 2] is ANS-encoded - # Map values with tracked probabilities to range [0, ..., max_value] - value = symbols[i] - value -= cdf_offset[cdf_index] - - # If outside of this range, map value to non-negative integer overflow. - overflow = 0 - if (value < 0): - overflow = -2 * value - 1 - value = max_value - elif (value >= max_value): - overflow = 2 * (value - max_value) - value = max_value - - assert value >= 0 and value < cdf_length_i - 1, ( - f"Invalid shifted value {value} for symbol {i} w/ " - f"cdf_length {cdf_length[cdf_index]}") - - # Bin of discrete CDF that value belongs to - enc_statfun = _indexed_cdf_to_enc_statfun(cdf_i) - dec_statfun = _indexed_cdf_to_dec_statfun(cdf_i, cdf_length_i) - symbol_push, symbol_pop = base_codec(enc_statfun, dec_statfun, precision) - - message = symbol_push(message, value) - - if value == max_value: - pass - - encoded = vrans.flatten(message) - message_length = len(encoded) - return encoded, coding_shape - -def vec_ans_index_encoder_reversed(symbols, indices, cdf, cdf_length, cdf_offset, precision, - coding_shape, overflow_width=OVERFLOW_WIDTH, **kwargs): - """ - Vectorized version of `ans_index_encoder`. Incurs constant bit overhead, - but is faster. - - ANS-encodes unbounded integer data using an indexed probability table. - """ - - symbols_shape = symbols.shape - B, n_channels = symbols_shape[:2] - symbols = symbols.astype(np.int32) - indices = indices.astype(np.int32) - cdf_index = indices - - assert bool(np.all(cdf_index >= 0)) and bool(np.all(cdf_index < cdf.shape[0])), ( - "Invalid index.") - - max_value = cdf_length[cdf_index] - 2 - - assert bool(np.all(max_value >= 0)) and bool(np.all(max_value < cdf.shape[1] - 1)), ( - "Invalid max length.") - - # Map values with tracked probabilities to range [0, ..., max_value] - values = symbols - cdf_offset[cdf_index] - - # If outside of this range, map value to non-negative integer overflow. - overflow = np.zeros_like(values) - - of_mask = values < 0 - overflow = np.where(of_mask, -2 * values - 1, overflow) - values = np.where(of_mask, max_value, values) - - of_mask = values >= max_value - overflow = np.where(of_mask, 2 * (values - max_value), overflow) - values = np.where(of_mask, max_value, values) - - assert bool(np.all(values >= 0)), ( - "Invalid shifted value for current symbol - values must be non-negative.") - - assert bool(np.all(values < cdf_length[cdf_index] - 1)), ( - "Invalid shifted value for current symbol - outside cdf index bounds.") - - if B == 1: - # Vectorize on patches - there's probably a way to interlace patches with - # batch elements for B > 1 ... - if ((symbols_shape[2] % PATCH_SIZE[0] == 0) and (symbols_shape[3] % PATCH_SIZE[1] == 0)) is False: - values = utils.pad_factor(torch.Tensor(values), symbols_shape[2:], - factor=PATCH_SIZE).cpu().numpy().astype(np.int32) - indices = utils.pad_factor(torch.Tensor(indices), symbols_shape[2:], - factor=PATCH_SIZE).cpu().numpy().astype(np.int32) - - assert (values.shape[2] % PATCH_SIZE[0] == 0) and (values.shape[3] % PATCH_SIZE[1] == 0) - assert (indices.shape[2] % PATCH_SIZE[0] == 0) and (indices.shape[3] % PATCH_SIZE[1] == 0) - - values, _ = compression_utils.decompose(values, n_channels) - cdf_index, unfolded_shape = compression_utils.decompose(indices, n_channels) - coding_shape = values.shape[1:] - - message = vrans.empty_message(coding_shape) - - # LIFO - last item compressed is first item decompressed - for i in reversed(range(len(cdf_index))): # loop over batch dimension - # Bin of discrete CDF that value belongs to - value_i = values[i] - cdf_index_i = cdf_index[i] - cdf_i = cdf[cdf_index_i] - cdf_i_length = cdf_length[cdf_index_i] - - enc_statfun = _vec_indexed_cdf_to_enc_statfun(cdf_i) - dec_statfun = _vec_indexed_cdf_to_dec_statfun(cdf_i, cdf_i_length) - symbol_push, symbol_pop = base_codec(enc_statfun, dec_statfun, precision) - - message = symbol_push(message, value_i) - - """ - Encode overflows here - """ - - encoded = vrans.flatten(message) - message_length = len(encoded) - - # print('{} symbols compressed to {:.3f} bits.'.format(B, 32 * message_length)) - - return encoded, coding_shape