First stable release
Justin-Tan committed Sep 13, 2020
1 parent 45bbbc2 commit 1b4e184
Showing 6 changed files with 190 additions and 184 deletions.
7 changes: 5 additions & 2 deletions assets/HiFIC_torch_colab_demo.ipynb
@@ -592,6 +592,7 @@
" clocktower=\"9cbf2594f339c0d3d0f0ea25c62af52b.png\",\n",
" fresco=\"8181526d9f238726d3e1d3ec3cc56fb7.png\",\n",
" islet=\"c6658d87c608b631f5cc3fb5a8d89731.png\",\n",
" mountain=\"d3688a7285d7b2b81febe1cd72e6e22c.png\",\n",
" pasta=\"f5be5054c01d8efc834d78a991356ad6.png\",\n",
" pines=\"e903c4f4684100a6dbac1f0b9b4de760.png\",\n",
" plaza=\"d78b363974ac79908b79012f48de715d.png\",\n",
@@ -812,7 +813,7 @@
"source": [
"# Choose default images from CLIC2020 dataset\n",
"# Skip if uploading custom images\n",
"default_image = \"street\" #@param [\"cafe\", \"cat\", \"city\", \"clocktower\", \"fresco\", \"islet\", \"pasta\", \"pines\", \"plaza\", \"portrait\", \"shoreline\", \"street\", \"tundra\"]"
"default_image = \"street\" #@param [\"cafe\", \"cat\", \"city\", \"clocktower\", \"fresco\", \"islet\", \"mountain\", \"pasta\", \"pines\", \"plaza\", \"portrait\", \"shoreline\", \"street\", \"tundra\"]"
],
"execution_count": null,
"outputs": []
@@ -1111,7 +1112,9 @@
"colab_type": "text"
},
"source": [
"You can compress new images by going back to the \"Prepare Images\" heading and selecting a different default image or upload your own for compression, then running the cells below in sequence. Please open an issue if you encounter an error when running this demo."
"You can compress new images by going back to the \"Prepare Images\" heading and selecting a different default image or upload your own for compression, then running the cells below in sequence. Note that each model cannot decompress the output generated by a different model, and you need to delete the contents of `/content/out` if you want to try a different model.\n",
"\n",
"Please open an issue if you encounter an error when running this demo."
]
},
{
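A side note on the new caveat about `/content/out`: a minimal cleanup cell, assuming the demo writes all compressed bitstreams and reconstructions under that directory, could be run before switching models.

import pathlib
import shutil

# Assumed output location, taken from the note above; adjust if the demo is configured differently.
out_dir = pathlib.Path("/content/out")
if out_dir.exists():
    shutil.rmtree(out_dir)  # drop outputs tied to the previously selected model
out_dir.mkdir(parents=True, exist_ok=True)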
1 change: 0 additions & 1 deletion src/compression/ans.py
@@ -73,7 +73,6 @@ def push(x, starts, freqs, precisions):
tail = stack_extend(tail, np.uint32(head[idxs])) # Can also modulo with bitand
head = np.copy(head) # Ensure no side-effects
head[idxs] >>= 32

head_div_freqs, head_mod_freqs = np.divmod(head, freqs)
return (head_div_freqs << np.uint(precisions)) + head_mod_freqs + starts, tail

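The `ans.py` hunk above sits in the core vectorized rANS push: after spilling the low 32 bits of oversized heads to the tail, the state update is `((head // freq) << precision) + head % freq + start`. A toy scalar sketch of that update and its inverse, using unbounded Python integers and skipping the renormalization and vectorization that `vrans` handles:

import numpy as np

def toy_push(x, start, freq, precision):
    # Encode one symbol with probability freq / 2**precision into integer state x.
    return ((x // freq) << precision) + (x % freq) + start

def toy_pop(x, starts, freqs, precision):
    # Invert toy_push: read the cumulative slot, identify the symbol, rewind the state.
    cf = x & ((1 << precision) - 1)
    s = int(np.searchsorted(starts, cf, 'right') - 1)
    x = int(freqs[s]) * (x >> precision) + cf - int(starts[s])
    return x, s

# Two-symbol alphabet with P(0) = 3/4, P(1) = 1/4 at precision 2.
precision = 2
starts, freqs = np.array([0, 3]), np.array([3, 1])
x = 1 << 16
for s in [0, 0, 1]:
    x = toy_push(x, int(starts[s]), int(freqs[s]), precision)
x, s3 = toy_pop(x, starts, freqs, precision)  # 1 -- rANS is LIFO: last pushed, first popped
x, s2 = toy_pop(x, starts, freqs, precision)  # 0
x, s1 = toy_pop(x, starts, freqs, precision)  # 0
assert (s1, s2, s3) == (0, 0, 1) and x == 1 << 16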
14 changes: 14 additions & 0 deletions src/compression/compression_utils.py
@@ -3,7 +3,10 @@
import numpy as np
import functools
import os
import autograd.numpy as np

from autograd import make_vjp
from autograd.extend import vspace, VSpace
from collections import namedtuple

from src.helpers import utils
@@ -76,6 +79,17 @@ def estimate_tails(cdf, target, shape, dtype=torch.float32, extra_counts=24):

return tails

def view_update(data, view_fun):
view_vjp, item = make_vjp(view_fun)(data)
item_vs = vspace(item)
def update(new_item):
assert item_vs == vspace(new_item), \
"Please ensure new_item shape and dtype match the data view."
diff = view_vjp(item_vs.add(new_item,
item_vs.scalar_mul(item, -np.uint64(1))))
return vspace(data).add(data, diff)
return item, update

def decompose(x, n_channels, patch_size=PATCH_SIZE):
# Decompose input x into spatial patches
if isinstance(x, torch.Tensor) is False:
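The `view_update` helper added above returns the current view of an array together with a closure that writes a modified view back by pushing the difference through the view's vector-Jacobian product; the `substack`/`overflow_view` code in `entropy_coding.py` below relies on it to run a codec on a masked subset of the message head. A minimal usage sketch, assuming `autograd` is installed and supports VJPs of boolean-mask indexing on uint64 arrays (the same setting the codec uses):

import numpy as np
from src.compression.compression_utils import view_update  # helper defined above

head = np.arange(6, dtype=np.uint64)
mask = np.array([True, False, True, False, False, True])

item, update = view_update(head, lambda x: x[mask])
print(item)                             # [0 2 5] -- the selected sub-head
new_head = update(item + np.uint64(1))  # write edits back through the view
print(new_head)                         # [1 1 3 3 4 6]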
198 changes: 169 additions & 29 deletions src/compression/entropy_coding.py
@@ -23,7 +23,7 @@
Codec = namedtuple('Codec', ['push', 'pop'])
cast2u64 = lambda x: np.array(x, dtype=np.uint64)

def base_codec(enc_statfun, dec_statfun, precision):
def base_codec(enc_statfun, dec_statfun, precision, log=False):
if np.any(precision >= 24):
warn('Detected precision of 24 or higher. Codecs lose accuracy at high '
'precision.')
@@ -32,7 +32,7 @@ def push(message, symbol):
start, freq = enc_statfun(symbol)
return vrans.push(message, start, freq, precision)

def pop(message):
def pop(message, log=log):
cf, pop_fun = vrans.pop(message, precision)
symbol = dec_statfun(cf)
start, freq = enc_statfun(symbol)
@@ -93,7 +93,8 @@ def _dec_statfun(value):
# (coding_shape) = (C,H,W) by default but can be generalized
# cdf_i: [(coding_shape), pmf_length + 2]
# value: [(coding_shape)]
assert value.shape == coding_shape, "CDF-value shape mismatch!"
assert value.shape == coding_shape, (
f"CDF-value shape mismatch! {value.shape} v. {coding_shape}")
sym_flat = np.array(
[np.searchsorted(cb, v_i, 'right') - 1 for (cb, v_i) in
zip(cdf_i_flat_ragged, value.flatten())])
@@ -283,7 +284,17 @@ def vec_ans_index_buffered_encoder(symbols, indices, cdf, cdf_length, cdf_offset
symbols = symbols.astype(np.int32)
indices = indices.astype(np.int32)
cdf_index = indices


max_overflow = (1 << overflow_width) - 1
overflow_cdf_size = (1 << overflow_width) + 1
overflow_cdf = np.arange(overflow_cdf_size, dtype=np.uint64)[None, None, None, :]

enc_statfun_overflow = _vec_indexed_cdf_to_enc_statfun(overflow_cdf)
dec_statfun_overflow = _vec_indexed_cdf_to_dec_statfun(overflow_cdf,
np.ones_like(overflow_cdf) * len(overflow_cdf))
overflow_push, overflow_pop = base_codec(enc_statfun_overflow,
dec_statfun_overflow, overflow_width)

assert bool(np.all(cdf_index >= 0)) and bool(np.all(cdf_index < cdf.shape[0])), (
"Invalid index.")

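For reference, the escape CDF built above is just `np.arange`, i.e. a uniform distribution over the `2**overflow_width` chunk values, so (assuming the usual indexed-CDF convention of `start = cdf[v]`, `freq = cdf[v + 1] - cdf[v]`) every escape chunk costs exactly `overflow_width` bits:

import numpy as np

overflow_width = 4  # assumed chunk width, for illustration only
overflow_cdf = np.arange((1 << overflow_width) + 1, dtype=np.uint64)
v = 9
start, freq = overflow_cdf[v], overflow_cdf[v + 1] - overflow_cdf[v]
assert (int(start), int(freq)) == (9, 1)  # uniform: start equals the value, freq is 1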
@@ -297,14 +308,11 @@ def vec_ans_index_buffered_encoder(symbols, indices, cdf, cdf_length, cdf_offset

# If outside of this range, map value to non-negative integer overflow.
overflow = np.zeros_like(values)

of_mask = values < 0
overflow = np.where(of_mask, -2 * values - 1, overflow)
values = np.where(of_mask, max_value, values)

of_mask = values >= max_value
overflow = np.where(of_mask, 2 * (values - max_value), overflow)
values = np.where(of_mask, max_value, values)
of_mask_lower = values < 0
overflow = np.where(of_mask_lower, -2 * values - 1, overflow)
of_mask_upper = values >= max_value
overflow = np.where(of_mask_upper, 2 * (values - max_value), overflow)
values = np.where(np.logical_or(of_mask_lower, of_mask_upper), max_value, values)

assert bool(np.all(values >= 0)), (
"Invalid shifted value for current symbol - values must be non-negative.")
@@ -320,60 +328,138 @@ def vec_ans_index_buffered_encoder(symbols, indices, cdf, cdf_length, cdf_offset
factor=PATCH_SIZE).cpu().numpy().astype(np.int32)
indices = utils.pad_factor(torch.Tensor(indices), symbols_shape[2:],
factor=PATCH_SIZE).cpu().numpy().astype(np.int32)
overflow = utils.pad_factor(torch.Tensor(overflow), symbols_shape[2:],
factor=PATCH_SIZE).cpu().numpy().astype(np.int32)

assert (values.shape[2] % PATCH_SIZE[0] == 0) and (values.shape[3] % PATCH_SIZE[1] == 0)
assert (indices.shape[2] % PATCH_SIZE[0] == 0) and (indices.shape[3] % PATCH_SIZE[1] == 0)

values, _ = compression_utils.decompose(values, n_channels)
overflow, _ = compression_utils.decompose(overflow, n_channels)
cdf_index, unfolded_shape = compression_utils.decompose(indices, n_channels)
coding_shape = values.shape[1:]
assert coding_shape == cdf_index.shape[1:]


# LIFO - last item in buffer is first item decompressed
for i in range(len(cdf_index)): # loop over batch dimension
# Bin of discrete CDF that value belongs to
value_i = values[i]
cdf_index_i = cdf_index[i]
cdf_i = cdf[cdf_index_i]
cdf_i_length = cdf_length[cdf_index_i]
cdf_length_i = cdf_length[cdf_index_i]
max_value_i = cdf_length_i - 2

enc_statfun = _vec_indexed_cdf_to_enc_statfun(cdf_i)
dec_statfun = _vec_indexed_cdf_to_dec_statfun(cdf_i, cdf_i_length)
dec_statfun = _vec_indexed_cdf_to_dec_statfun(cdf_i, cdf_length_i)
symbol_push, symbol_pop = base_codec(enc_statfun, dec_statfun, precision)

start, freq = enc_statfun(value_i)
instructions.append((start, freq, False))
instructions.append((start, freq, False, precision, 0))

"""
Encode overflows here
"""
# of_mask = values == max_value
# widths = np.zeros_like(values)
# widths[of_mask] = 1
# No-op
empty_start = np.zeros_like(value_i).astype(np.uint)
empty_freq = np.ones_like(value_i).astype(np.uint)

# overflow =

# cond_mask = overflow >> (widths * overflow_width) != 0
# while np.all(cond_mask) is False:
# widths[cond_mask] += 1
overflow_i = overflow[i]
of_mask = value_i == max_value_i

if np.any(of_mask):

widths = np.zeros_like(value_i)
cond_mask = (overflow_i >> (widths * overflow_width)) != 0

while np.any(cond_mask):
widths = np.where(cond_mask, widths+1, widths)
cond_mask = (overflow_i >> (widths * overflow_width)) != 0

val = widths
cond_mask = val >= max_overflow
while np.any(cond_mask):
print('Warning: Undefined behaviour.')
val_push = cast2u64(max_overflow)
overflow_start, overflow_freq = enc_statfun_overflow(val_push)
start = overflow_start[of_mask]
freq = overflow_freq[of_mask]  # frequency of the escape symbol, not its start
instructions.append((start, freq, True, int(overflow_width), of_mask))
# val[cond_mask] -= max_overflow
val = np.where(cond_mask, val-max_overflow, val)
cond_mask = val >= max_overflow

val_push = cast2u64(val)
overflow_start, overflow_freq = enc_statfun_overflow(val_push)
start = overflow_start[of_mask]
freq = overflow_freq[of_mask]
instructions.append((start, freq, True, int(overflow_width), of_mask))

counter = 0  # chunk index must persist across iterations of the loop below
cond_mask = widths != 0
while np.any(cond_mask):
encoding = (overflow_i >> (counter * overflow_width)) & max_overflow
val = np.where(cond_mask, encoding, val)
val_push = cast2u64(val)
overflow_start, overflow_freq = enc_statfun_overflow(val_push)
start = overflow_start[of_mask]
freq = overflow_freq[of_mask]
instructions.append((start, freq, True, int(overflow_width), of_mask))
widths = np.where(cond_mask, widths-1, widths)
cond_mask = widths != 0
counter += 1

# val = widths

return instructions, coding_shape

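To summarise the encoder's control flow above: out-of-range values are first folded into a non-negative overflow code (odd codes for negatives, even codes for values at or past `max_value`), and that code is then emitted in `overflow_width`-bit chunks preceded by a chunk count, with the count itself escaped when it reaches the maximum chunk value. A scalar sketch of the mapping and its inverse; the vectorized code above applies the same arithmetic per element under `of_mask`:

def to_overflow(value, max_value):
    # Fold an out-of-range value into a non-negative overflow code; in-range values pass through.
    if value < 0:
        return max_value, -2 * value - 1            # negatives -> odd codes
    if value >= max_value:
        return max_value, 2 * (value - max_value)   # too-large values -> even codes
    return value, None

def from_overflow(max_value, overflow):
    # Inverse mapping, applied by the decoder once the overflow code is reassembled.
    return -(overflow >> 1) - 1 if overflow & 1 else (overflow >> 1) + max_value

def split_chunks(overflow, overflow_width):
    # Chunk count (pushed first), then the chunks themselves, least-significant first.
    max_overflow = (1 << overflow_width) - 1
    widths = 0
    while overflow >> (widths * overflow_width):
        widths += 1
    return widths, [(overflow >> (i * overflow_width)) & max_overflow for i in range(widths)]

# Example: value -5 against max_value 255 with 4-bit chunks.
clamped, overflow = to_overflow(-5, 255)    # (255, 9)
widths, chunks = split_chunks(overflow, 4)  # (1, [9])
assert from_overflow(255, overflow) == -5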

def overflow_view(value, mask):
return value[mask]

def substack(codec, view_fun):
"""
Apply a codec on a subset of a message head.
view_fun should be a function: head -> subhead, for example
view_fun = lambda head: head[0]
to run the codec on only the first element of the head
"""
def push(message, start, freq, precision, mask):
head, tail = message
view_fun_ = lambda x: view_fun(x, mask)
subhead, update = compression_utils.view_update(head, view_fun_)
subhead, tail = vrans.push((subhead, tail), start, freq, precision)
return update(subhead), tail

def pop(message, precision, mask, *args, **kwargs):
head, tail = message
view_fun_ = lambda x: view_fun(x, mask)
subhead, update = compression_utils.view_update(head, view_fun_)

cf, pop_fun = vrans.pop((subhead, tail), precision)
symbol = cf
start, freq = symbol, 1

assert np.all(start <= cf) and np.all(cf < start + freq)
(subhead, tail), data = pop_fun(start, freq), symbol
updated_head = update(subhead)
return (updated_head, tail), data

return Codec(push, pop)

def vec_ans_index_encoder_flush(instructions, precision, coding_shape, overflow_width=OVERFLOW_WIDTH, **kwargs):

message = vrans.empty_message(coding_shape)

overflow_push, _ = substack(codec=None, view_fun=overflow_view)
# LIFO - last item compressed is first item decompressed
for i in reversed(range(len(instructions))):

start, freq, flag = instructions[i]
start, freq, flag, precision_i, mask = instructions[i]

if flag is False:
message = vrans.push(message, start, freq, precision)
else:
message = vrans.push(message, start, freq, overflow_width)
# Substack on overflow values
overflow_precision = precision_i
message = overflow_push(message, start, freq, overflow_precision, mask)

encoded = vrans.flatten(message)
message_length = len(encoded)
@@ -487,7 +573,17 @@ def vec_ans_index_decoder(encoded, indices, cdf, cdf_length, cdf_offset, precisi
message = vrans.unflatten(encoded, coding_shape)
indices = indices.astype(np.int32)
cdf_index = indices


max_overflow = (1 << overflow_width) - 1
overflow_cdf_size = (1 << overflow_width) + 1
overflow_cdf = np.arange(overflow_cdf_size, dtype=np.uint64)[None, :]

enc_statfun_overflow = _vec_indexed_cdf_to_enc_statfun(overflow_cdf)
dec_statfun_overflow = _vec_indexed_cdf_to_dec_statfun(overflow_cdf,
np.ones_like(overflow_cdf) * len(overflow_cdf))
overflow_codec = base_codec(enc_statfun_overflow,
dec_statfun_overflow, overflow_width)

assert bool(np.all(cdf_index >= 0)) and bool(np.all(cdf_index < cdf.shape[0])), (
"Invalid index.")

@@ -510,6 +606,8 @@ def vec_ans_index_decoder(encoded, indices, cdf, cdf_length, cdf_offset, precisi


symbols = []
_, overflow_pop = substack(codec=overflow_codec, view_fun=overflow_view)

for i in range(len(cdf_index)):
cdf_index_i = cdf_index[i]
cdf_i = cdf[cdf_index_i]
@@ -520,9 +618,51 @@ def vec_ans_index_decoder(encoded, indices, cdf, cdf_length, cdf_offset, precisi
symbol_push, symbol_pop = base_codec(enc_statfun, dec_statfun, precision)

message, value = symbol_pop(message)

max_value_i = cdf_length_i - 2
of_mask = value == max_value_i

if np.any(of_mask):

message, val = overflow_pop(message, overflow_width, of_mask)
val = cast2u64(val)
widths = val

cond_mask = val == max_overflow
while np.any(cond_mask):
message, val = overflow_pop(message, overflow_width, of_mask)
val = cast2u64(val)
widths = np.where(cond_mask, widths + val, widths)
cond_mask = val == max_overflow

overflow = np.zeros_like(val)
counter = 0  # chunk index must persist across iterations of the loop below
cond_mask = widths != 0

while np.any(cond_mask):
message, val = overflow_pop(message, overflow_width, of_mask)
val = cast2u64(val)
assert np.all(val <= max_overflow)

op = overflow | (val << (counter * overflow_width))
overflow = np.where(cond_mask, op, overflow)
widths = np.where(cond_mask, widths-1, widths)
cond_mask = widths != 0
counter += 1

overflow_broadcast = value
overflow_broadcast[of_mask] = overflow
overflow = overflow_broadcast
value = np.where(of_mask, overflow >> 1, value)
cond_mask = np.logical_and(of_mask, overflow & 1)
value = np.where(cond_mask, -value - 1, value)
cond_mask = np.logical_and(of_mask, np.logical_not(overflow & 1))
value = np.where(cond_mask, value + max_value_i, value)

symbol = value + cdf_offset[cdf_index_i]
symbols.append(symbol)


if B == 1:
decoded = compression_utils.reconstitute(np.stack(symbols, axis=0), padded_shape, unfolded_shape)

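The decoder's final `reconstitute` call undoes the patch decomposition performed before encoding. A self-contained illustration of that round trip with torch's unfold/fold, using assumed toy shapes and a 4x4 patch size (the repository's `decompose`/`reconstitute` appear to follow this pattern, with padding to a multiple of `PATCH_SIZE` handled separately, as in the encoder above):

import torch
import torch.nn.functional as F

B, C, H, W = 2, 3, 8, 8
patch = (4, 4)
x = torch.arange(B * C * H * W, dtype=torch.float32).reshape(B, C, H, W)

# Decompose: cut (B, C, H, W) into non-overlapping patches stacked along the batch dim.
unfolded = F.unfold(x, kernel_size=patch, stride=patch)     # (B, C*ph*pw, L)
patches = unfolded.permute(0, 2, 1).reshape(-1, C, *patch)  # (B*L, C, ph, pw)

# Reconstitute: invert the reshape/permute, then fold the patches back into place.
refolded = patches.reshape(B, -1, C * patch[0] * patch[1]).permute(0, 2, 1)
x_rec = F.fold(refolded, output_size=(H, W), kernel_size=patch, stride=patch)
assert torch.equal(x, x_rec)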
4 changes: 2 additions & 2 deletions src/compression/prior_model.py
@@ -318,15 +318,15 @@ def forward(self, x, mean, scale, **kwargs):

import time

n_channels = 4
n_channels = 24
use_blocks = True
vectorize = True
prior_density = PriorDensity(n_channels)
prior_entropy_model = PriorEntropyModel(distribution=prior_density)

loc, scale = 2.401, 3.43
n_data = 1
toy_shape = (n_data, n_channels, 32, 32)
toy_shape = (n_data, n_channels, 64, 64)
bottleneck, means = torch.randn(toy_shape), torch.randn(toy_shape)
scales = torch.randn(toy_shape) * np.sqrt(scale) + loc
scales = torch.clamp(scales, min=MIN_SCALE)
(diff of the remaining changed file not shown)
