diff --git a/bfp/bfp_ops.py b/bfp/bfp_ops.py
index 3fd356d..8e02f7c 100644
--- a/bfp/bfp_ops.py
+++ b/bfp/bfp_ops.py
@@ -101,7 +101,7 @@ def float_to_bfp_batched(t, mant_bits, epsilon, rounding_mode, device, bfp_tile_
     assert num_format == 'bfp'
     orig_shape = t.size()
 
-    t = t.view(t.size()[0], -1)
+    t = t.contiguous().view(orig_shape[0], -1)
     o = _float_to_bfp(t, mant_bits, epsilon, rounding_mode, device)
     return o.view(orig_shape)
 
@@ -112,7 +112,7 @@ def tensor_to_tiled(t, orig_shape, bfp_tile_size):
     Output: the tiled tensor, the number of tiles in each dimension,
     the dimensions before and after the tiling to help 'untiling'
     """
-    t = t.view(orig_shape[0], -1)
+    t = t.contiguous().view(orig_shape[0], -1)
     matrix_h, matrix_w = t.size()
 
     numberOf_h_tiles = (matrix_h + bfp_tile_size - 1) // bfp_tile_size
@@ -211,7 +211,7 @@ def _gen_bfp_op(op, name, bfp_args):
     class NewOpIn(torch.autograd.Function):
         @staticmethod
         def forward(ctx, x, w):
-            return (float_to_bfp_batched(x, **bfp_args), w)
+            return (float_to_bfp_batched(x, **bfp_args), float_to_bfp_batched(w, **bfp_args))
 
         @staticmethod
         def backward(ctx, grad_x, grad_w):
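
Note (illustrative, not part of the patch): the added `.contiguous()` calls guard against the RuntimeError that `view()` raises when the input tensor's strides are non-contiguous, e.g. after a transpose. A minimal sketch of that failure mode, assuming only stock PyTorch; the tensor shapes here are arbitrary:

    import torch

    # Hypothetical repro, not taken from bfp_ops.py: a transposed tensor has
    # non-contiguous strides, so view() cannot reinterpret its storage in place.
    t = torch.randn(4, 8).transpose(0, 1)

    try:
        t.view(t.size()[0], -1)      # raises RuntimeError on non-contiguous input
    except RuntimeError as err:
        print("view() failed:", err)

    # Copying into contiguous memory first makes the same reshape legal,
    # which is what the patched lines do before calling _float_to_bfp.
    o = t.contiguous().view(t.size()[0], -1)
    print(o.shape)  # torch.Size([8, 4])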