From b7aaadfe5620f77a3cb6fbec8080feb2892b984d Mon Sep 17 00:00:00 2001 From: Iman Hosseini Date: Mon, 5 Dec 2022 22:42:14 -0500 Subject: [PATCH 1/7] adding LinearGELU kernel --- torch_int/kernels/include/linear.h | 8 +++ torch_int/kernels/linear.cu | 106 +++++++++++++++++++++++++++++ torch_int/nn/linear.py | 43 ++++++++++++ 3 files changed, 157 insertions(+) diff --git a/torch_int/kernels/include/linear.h b/torch_int/kernels/include/linear.h index 5df6ac6..ddccd97 100644 --- a/torch_int/kernels/include/linear.h +++ b/torch_int/kernels/include/linear.h @@ -32,6 +32,14 @@ torch::Tensor linear_relu_a8_w8_b8_o8(torch::Tensor input, // INT8 float beta // FP32 ); +// used by fc1, return INT8 +torch::Tensor linear_gelu_a8_w8_b8_o8(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // INT8 + float alpha, // FP32 + float beta // FP32 +); + // used by q_proj, k_proj, v_proj, return INT8 torch::Tensor linear_a8_w8_b8_o8(torch::Tensor input, // INT8 torch::Tensor weight, // INT8 diff --git a/torch_int/kernels/linear.cu b/torch_int/kernels/linear.cu index 0e11d7b..d87159c 100644 --- a/torch_int/kernels/linear.cu +++ b/torch_int/kernels/linear.cu @@ -487,5 +487,111 @@ torch::Tensor linear_relu_a8_w8_b8_o8(torch::Tensor input, // INT8 std::to_string((int)status)); } + return out; +} + +// used by fc1 +torch::Tensor linear_gelu_a8_w8_b8_o8(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // INT8 + float alpha, // FP32 + float beta // FP32 +) { + auto M = input.size(0); + auto N = weight.size(0); + auto K = input.size(1); + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + using ElementInputA = int8_t; // <- data type of elements in input matrix A + using ElementInputB = int8_t; // <- data type of elements in input matrix B + + // The code section below describes matrix layout of input and output + // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major + // for Matrix C + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits< + ElementOutput>::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. 
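+                                   //    (for the int8 ElementOutput used
+                                   //    here that is 128 / 8 = 16 elements.)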
This + // becomes the vector width of math + // instructions in epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue // <- data type for alpha in linear combination + // function + >; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3>; + + auto input_size = cutlass::MatrixCoord(M, K); + auto weight_size = cutlass::MatrixCoord(K, N); + auto output_size = cutlass::MatrixCoord(M, N); + auto device = input.device(); + // use the broadcasted bias as the output + auto out = bias.to(device).view({1, -1}).repeat({M, 1}); + + // constexpr int kSparse = Gemm::kSparse; + // How many elements of A are covered per ElementE + // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; + // The size of individual meta data + // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + cutlass::gemm::GemmCoord problem_size(M, N, K); + + cutlass::TensorRef input_ref( + input.data_ptr(), LayoutInputA::packed(input_size)); + cutlass::TensorRef weight_ref( + weight.data_ptr(), LayoutInputB::packed(weight_size)); + cutlass::TensorRef out_ref( + out.data_ptr(), LayoutOutput::packed(output_size)); + + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + input_ref, // <- reference to matrix A on device + weight_ref, // <- reference to matrix B on device + out_ref, // <- reference to matrix C on device + out_ref, // <- reference to matrix D on device + {alpha, beta}, 1}; + Gemm gemm_op; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement, status: " + + std::to_string((int)status)); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize, status: " + + std::to_string((int)status)); + } + + status = gemm_op(); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot run, status: " + + std::to_string((int)status)); + } + return out; } \ No newline at end of file diff --git a/torch_int/nn/linear.py b/torch_int/nn/linear.py index 1a6e7b7..bde41c2 100644 --- a/torch_int/nn/linear.py +++ b/torch_int/nn/linear.py @@ -56,6 +56,49 @@ def from_float(module: torch.nn.Linear, input_scale, output_scale): int8_module.b = beta return int8_module +class W8A8B8O8LinearGELU(torch.nn.Module): + # For fc1 + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer('weight', torch.randint(-127, 127, (self.out_features, + self.in_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('bias', torch.zeros( + 
(1, self.out_features), dtype=torch.int8, requires_grad=False)) + self.register_buffer('a', torch.tensor(alpha)) + self.register_buffer('b', torch.tensor(beta)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_gelu_a8_w8_b8_o8(x, self.weight, self.bias, + self.a.item(), self.b.item()) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale, output_scale): + # TODO: add zero-point to prevent the bit waste + int8_module = W8A8B8O8LinearGELU( + module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias) + alpha = input_scale * weight_scale / output_scale + beta = bias_scale / output_scale + int8_module.weight = int8_weight + int8_module.bias = int8_bias + int8_module.a = alpha + int8_module.b = beta + return int8_module class W8A8B8O8LinearReLU(torch.nn.Module): # For fc1 From 49b6ec766b275e874b2f9f6f8008aaee8501e765 Mon Sep 17 00:00:00 2001 From: Iman Hosseini Date: Tue, 6 Dec 2022 01:29:20 -0500 Subject: [PATCH 2/7] adding a test for linear_gelu --- tests/test_linear_kernels.py | 21 ++++++++++++++++++++- torch_int/kernels/bindings.cpp | 2 ++ torch_int/nn/linear.py | 1 + 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/test_linear_kernels.py b/tests/test_linear_kernels.py index b134579..c649d05 100644 --- a/tests/test_linear_kernels.py +++ b/tests/test_linear_kernels.py @@ -1,5 +1,5 @@ import torch -from torch_int._CUDA import linear_a8_w8_b32_o32, linear_relu_a8_w8_b8_o8, linear_a8_w8_b8_o8, linear_a8_w8_b32_o32_with_scaling, linear_a8_w8_bfp32_ofp32 +from torch_int._CUDA import linear_a8_w8_b32_o32, linear_relu_a8_w8_b8_o8, linear_a8_w8_b8_o8, linear_a8_w8_b32_o32_with_scaling, linear_a8_w8_bfp32_ofp32, linear_gelu_a8_w8_b8_o8 from icecream import ic @@ -85,6 +85,23 @@ def test_quant_linear_relu_a8_w8_b8_o8(): ic(torch.allclose(y_gt.float(), y.float().cpu(), atol=1)) +@torch.no_grad() +def test_quant_linear_gelu_a8_w8_b8_o8(): + B, M, N = 128, 512, 1024 + weight = torch.randint(-128, 127, (N, M), dtype=torch.int8) + bias = torch.randint(-128, 127, (N,), dtype=torch.int8) + x = torch.randint(-128, 127, (B, M), dtype=torch.int8) + alpha, beta = 0.001, 0.01 + linear = torch.nn.Linear(M, N, bias=True) + linear.weight.data = weight.float() * alpha + linear.bias.data = bias.float() * beta + y_gt = linear(x.float()) + y_gt = y_gt.clamp(0, 127).round().long() + y = linear_gelu_a8_w8_b8_o8(x.cuda(), weight.cuda(), + bias.cuda(), alpha, beta).cpu().long() + ic(torch.allclose(y_gt.float(), y.float().cpu(), atol=1)) + + if __name__ == '__main__': print('test_quant_linear_a8_w8_b32_o32') test_quant_linear_a8_w8_b32_o32() @@ -96,3 +113,5 @@ def test_quant_linear_relu_a8_w8_b8_o8(): test_quant_linear_a8_w8_b8_o8() print('test_quant_linear_relu_a8_w8_b8_o8') test_quant_linear_relu_a8_w8_b8_o8() + print('test_quant_linear_gelu_a8_w8_b8_o8') + test_quant_linear_gelu_a8_w8_b8_o8() diff --git a/torch_int/kernels/bindings.cpp b/torch_int/kernels/bindings.cpp index 4eaf7bc..bbe3398 100644 --- a/torch_int/kernels/bindings.cpp +++ b/torch_int/kernels/bindings.cpp @@ -5,6 +5,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("linear_relu_a8_w8_b8_o8", &linear_relu_a8_w8_b8_o8, "Linear ReLU 
(INT8)"); + m.def("linear_gelu_a8_w8_b8_o8", &linear_relu_a8_w8_b8_o8, + "Linear ReLU (INT8)"); m.def("linear_a8_w8_b32_o32", &linear_a8_w8_b32_o32, "Linear (INT32)"); m.def("linear_a8_w8_bfp32_ofp32", &linear_a8_w8_bfp32_ofp32, "Linear (I8-OFP32)"); diff --git a/torch_int/nn/linear.py b/torch_int/nn/linear.py index bde41c2..dfa70de 100644 --- a/torch_int/nn/linear.py +++ b/torch_int/nn/linear.py @@ -1,6 +1,7 @@ import torch from .._CUDA import (linear_a8_w8_b32_o32, linear_relu_a8_w8_b8_o8, + linear_gelu_a8_w8_b8_o8, linear_a8_w8_b8_o8, linear_a8_w8_b32_o32_with_scaling, linear_a8_w8_bfp32_ofp32 From 60a57adee1bc915c2acffc869f455d08b7db9dd6 Mon Sep 17 00:00:00 2001 From: Iman Hosseini Date: Tue, 6 Dec 2022 04:19:07 -0500 Subject: [PATCH 3/7] adding support for GPT-J + a test for it --- tests/test_gptj_attention.py | 57 +++++ torch_int/models/gptj.py | 422 +++++++++++++++++++++++++++++++++++ 2 files changed, 479 insertions(+) create mode 100644 tests/test_gptj_attention.py create mode 100644 torch_int/models/gptj.py diff --git a/tests/test_gptj_attention.py b/tests/test_gptj_attention.py new file mode 100644 index 0000000..576c30e --- /dev/null +++ b/tests/test_gptj_attention.py @@ -0,0 +1,57 @@ +import torch +from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJConfig +from torch_int.models.gptj import Int8GPTJAttention +from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU +from typing import Tuple +from icecream import ic +from functools import partial + + +def store_act(module, x, y, act_dict, name): + if isinstance(x, tuple): + x = x[0] + if isinstance(y, tuple): + y = y[0] + act_dict[name] = (x, y) + + + +@torch.no_grad() +def test_gptj_attention(): + B, L, D, H = 1, 16, 16, 1 + x = torch.randn(B, L, D) + x_scale = x.abs().max() / 127 + config = GPTJConfig() + config.n_embd = D + config.n_head = H + attn = GPTJAttention(config) + attn.eval() + act_dict = {} + for name, module in attn.named_modules(): + if isinstance(module, torch.nn.Linear): + module.register_forward_hook( + partial(store_act, act_dict=act_dict, name=name)) + y = attn(x)[0] + + q_output_scale = act_dict['q_proj'][1].abs().max() / 127 + k_output_scale = act_dict['k_proj'][1].abs().max() / 127 + v_output_scale = act_dict['v_proj'][1].abs().max() / 127 + out_input_scale = act_dict['out_proj'][0].abs().max() / 127 + int8_attn = Int8GPTJAttention.from_float( + attn, x_scale, q_output_scale, k_output_scale, v_output_scale, out_input_scale).cuda() + int8_attn.eval() + q_act_dict = {} + for name, module in int8_attn.named_modules(): + if isinstance(module, (W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU)): + module.register_forward_hook( + partial(store_act, act_dict=q_act_dict, name=name)) + q_x = (x / x_scale).round().to(torch.int8) + y_hat = int8_attn(q_x.cuda())[0].cpu() + + # ic(y_hat) + r2 = (y - y_hat).pow(2).mean() / y.pow(2).mean() + ic(r2) + + +if __name__ == '__main__': + test_gptj_attention() diff --git a/torch_int/models/gptj.py b/torch_int/models/gptj.py new file mode 100644 index 0000000..8da1f77 --- /dev/null +++ b/torch_int/models/gptj.py @@ -0,0 +1,422 @@ +import torch +from torch import nn +from transformers.models.gptj.modeling_gptj import ( + GPTJConfig, + GPTJForCausalLM, + GPTJModel, + GPTJPreTrainedModel, + GPTJAttention, + GPTJMLP, + GPTJBlock, + BaseModelOutputWithPast +) + +from typing import Optional, Tuple, List +from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearReLU +from torch_int.nn.fused import 
LayerNormQ +from transformers.utils import logging +from torch_int.nn.bmm import BMM_S8T_S8N_S8T, BMM_S8T_S8N_F32T + +def fixed_pos_embedding(x, seq_dim=1, seq_len=None): + dim = x.shape[-1] + if seq_len is None: + seq_len = x.shape[seq_dim] + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) + sinusoid_inp = ( + torch.einsum("i , j -> i j", torch.arange(seq_len, dtype=torch.float), inv_freq).to(x.device).float() + ) + return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + + +def rotate_every_two(x): + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') + + +def duplicate_interleave(m): + """ + A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. + """ + dim0 = m.shape[0] + m = m.view(-1, 1) # flatten the matrix + m = m.repeat(1, 2) # repeat all elements into the 2nd dimension + m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy + return m + + +def apply_rotary_pos_emb(x, sincos, offset=0): + sin, cos = map(lambda t: duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :], sincos) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +class Int8GPTJAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, n_embd, n_head, max_position_embeddings, rotary_dim = None): + super().__init__() + + max_positions = max_position_embeddings + self.embed_dim = n_embd + self.num_attention_heads = n_head + self.head_dim = n_embd // n_head + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), + ) + + if (self.head_dim * self.num_attention_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_attention_heads})." 
+ ) + + self.attention_weight_scale = 1.0 + self.qk_bmm = BMM_S8T_S8N_F32T(1.0) + self.pv_bmm = BMM_S8T_S8N_S8T(1.0) + self.k_proj = W8A8B8O8Linear(n_embd, n_embd) + self.v_proj = W8A8B8O8Linear(n_embd, n_embd) + self.q_proj = W8A8B8O8Linear(n_embd, n_embd) + self.out_proj = W8A8BFP32OFP32Linear(n_embd, n_embd) + self.rotary_dim = None + if rotary_dim is not None: + self.rotary_dim = rotary_dim + + @staticmethod + @torch.no_grad() + def from_float(module: GPTJAttention, + input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + out_input_scale: float): + int8_module = Int8GPTJAttention(module.embed_dim, module.num_attention_heads, module.bias.shape[3], module.rotary_dim) + # Fuse the scaling into the q_proj output scale + q_output_scale = q_output_scale * module.scale_attn + # TODO: GPTJ has no bias, find a way to elide these later + module.q_proj.bias = torch.nn.Parameter(torch.zeros(module.embed_dim, dtype=float)) + module.v_proj.bias = torch.nn.Parameter(torch.zeros(module.embed_dim, dtype=float)) + module.k_proj.bias = torch.nn.Parameter(torch.zeros(module.embed_dim, dtype=float)) + module.out_proj.bias = torch.nn.Parameter(torch.zeros(module.embed_dim, dtype=float)) + module.q_proj.weight *= module.scale_attn + int8_module.q_proj = W8A8B8O8Linear.from_float( + module.q_proj, input_scale, q_output_scale) + int8_module.k_proj = W8A8B8O8Linear.from_float( + module.k_proj, input_scale, k_output_scale) + int8_module.v_proj = W8A8B8O8Linear.from_float( + module.v_proj, input_scale, v_output_scale) + int8_module.out_proj = W8A8BFP32OFP32Linear.from_float( + module.out_proj, out_input_scale) + int8_module.qk_bmm = BMM_S8T_S8N_F32T.from_scale( + q_output_scale, k_output_scale) + + # alpha = s_prob * s_v / s_out, where s_prob = 1 / 127 + int8_module.pv_bmm = BMM_S8T_S8N_S8T.from_scale( + 1.0 / 127, v_output_scale, out_input_scale) + return int8_module + + def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): + """ + Splits hidden dim into attn_head_size and num_attention_heads + """ + new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) + tensor = tensor.view(new_shape) + if rotary: + return tensor + if len(tensor.shape) == 5: + # (batch, blocks, head, block_length, head_features) + return tensor.permute(0, 1, 3, 2, 4) + elif len(tensor.shape) == 4: + # (batch, head, seq_length, head_features) + return tensor.permute(0, 2, 1, 3) + else: + raise ValueError( + f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + + def _merge_heads(self, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden dim + """ + if len(tensor.shape) == 5: + tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() + elif len(tensor.shape) == 4: + tensor = tensor.permute(0, 2, 1, 3).contiguous() + else: + raise ValueError( + f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + new_shape = tensor.size()[:-2] + \ + (num_attention_heads * attn_head_size,) + return tensor.view(new_shape) + + def _attn( + self, + query, + key, + value, + attention_mask=None, + head_mask=None, + ): + + # compute causal mask from causal mask buffer + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - + query_length: key_length, :key_length].to(torch.bool) + + # Keep the attention weights computation in fp32 to avoid overflow issues + query = query.to(torch.int8) + key = key.to(torch.int8) + + # attn_weights = 
torch.matmul(query, key.transpose(-1, -2)) + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + key = key.view(*proj_shape) + query = self._shape( + query_states, tgt_len, bsz).view(*proj_shape) + attn_weights = self.qk_bmm(query, key) + + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor( + mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + attn_weights = attn_weights / self.scale_attn + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + # attn_weights = attn_weights.to(value.dtype) + attn_weights.mul_(127).round_() + attn_weights = attn_weights.to(torch.int8) + # attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + # attn_output = torch.matmul(attn_weights, value) + attn_output = self.pv_bmm(attn_weights, value) + + return attn_output, attn_weights + + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads( + query, self.num_attention_heads, self.head_dim, True) + key = self._split_heads( + key, self.num_attention_heads, self.head_dim, True) + value = self._split_heads( + value, self.num_attention_heads, self.head_dim, False) + + seq_len = key.shape[1] + offset = 0 + + if layer_past is not None: + offset = layer_past[0].shape[-2] + seq_len += offset + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim:] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim:] + + sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) + k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) + q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) + + key = torch.cat([k_rot, k_pass], dim=-1) + query = torch.cat([q_rot, q_pass], dim=-1) + else: + sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) + key = apply_rotary_pos_emb(key, sincos, offset=offset) + query = apply_rotary_pos_emb(query, sincos, offset=offset) + + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + # compute self-attention: V x Softmax(QK^T) + attn_output, attn_weights = self._attn( + query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads( + attn_output, self.num_attention_heads, self.head_dim) + attn_output = self.out_proj(attn_output) + # attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, 
(attentions) + + +class Int8GPTJMLP(nn.Module): + # in MLP: intermediate_size= 4 * embed_dim + def __init__(self, intermediate_size, embed_dim): + super().__init__() + + self.fc1 = W8A8B8O8LinearGELU(embed_dim, intermediate_size) + self.fc2 = W8A8BFP32OFP32Linear(intermediate_size, embed_dim) + + def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + @staticmethod + def from_float(module: GPTJMLP, fc1_input_scale: float, fc2_input_scale: float): + int8_module = Int8GPTJMLP( + module.fc_in.out_features, module.fc_in.in_features) + int8_module.fc1 = W8A8B8O8LinearGELU.from_float( + module.fc_in, fc1_input_scale) + int8_module.fc2 = W8A8BFP32OFP32Linear.from_float( + module.fc_out, fc2_input_scale) + return int8_module + + +class Int8GPTJBlock(nn.Module): + def __init__(self, config): + super().__init__() + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd + self.ln_1 = LayerNormQ(config.n_embd) + self.attn = Int8GPTJAttention(config) + self.mlp = Int8GPTJMLP(inner_dim, config.n_embd) + + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = attn_output + feed_forward_hidden_states + residual + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions) + + @staticmethod + def from_float(module, attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + out_input_scale: float, + fc1_input_scale: float, + fc2_input_scale: float): + int8_module = Int8GPTJBlock(config) + int8_module.mlp = Int8GPTJMLP.from_float( + module.mlp, fc1_input_scale, fc2_input_scale) + int8_module.ln_1 = LayerNormQ.from_float(module.ln_1, attn_input_scale) + int8_module.attn = Int8GPTJAttention.from_float( + module.attn, attn_input_scale, q_output_scale, k_output_scale, v_output_scale, out_input_scale) + + +class Int8GPTJModel(GPTJPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.n_embd + self.vocab_size = config.vocab_size + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.h = nn.ModuleList([Int8PTJBlock(config) + for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + get_input_embeddings = GPTJModel.get_input_embeddings + set_input_embeddings = GPTJModel.set_input_embeddings + forward = GPTJModel.forward + + @staticmethod + def from_float(module, decoder_layer_scales): + int8_module = Int8GPTJModel(module.config) + int8_module.h = nn.ModuleList( + [Int8GPTJBlock.from_float(mm, 
decoder_layer_scales) for mm in module.h]) + return int8_module + + +class Int8GPTJForCausalLM(GPTJPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = Int8GPTJModel(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @staticmethod + def from_float(module, decoder_layer_scales): + int8_module = Int8GPTJForCausalLM(module.config) + int8_module.transformer = Int8GPTJModel(config, decoder_layer_scales) + int8_module.lm_head = module.lm_head + return int8_module + + get_input_embeddings = GPTJForCausalLM.get_input_embeddings + set_input_embeddings = GPTJForCausalLM.set_input_embeddings + get_output_embeddings = GPTJForCausalLM.get_output_embeddings + set_output_embeddings = GPTJForCausalLM.set_output_embeddings + forward = GPTJForCausalLM.forward + prepare_inputs_for_generation = GPTJForCausalLM.prepare_inputs_for_generation + _reorder_cache = GPTJForCausalLM._reorder_cache + parallelize = GPTJForCausalLM.parallelize + deparallelize = GPTJForCausalLM.deparallelize From fc015ec9ed1405ccc8e7e02c64172719cd4f5098 Mon Sep 17 00:00:00 2001 From: Iman Hosseini Date: Tue, 6 Dec 2022 04:43:21 -0500 Subject: [PATCH 4/7] trying to fix bugs --- torch_int/models/gptj.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/torch_int/models/gptj.py b/torch_int/models/gptj.py index 8da1f77..8c09629 100644 --- a/torch_int/models/gptj.py +++ b/torch_int/models/gptj.py @@ -85,7 +85,11 @@ def __init__(self, n_embd, n_head, max_position_embeddings, rotary_dim = None): self.rotary_dim = None if rotary_dim is not None: self.rotary_dim = rotary_dim + self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()) + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_attention_heads, self.head_dim).transpose(1, 2).contiguous() + @staticmethod @torch.no_grad() def from_float(module: GPTJAttention, @@ -133,6 +137,8 @@ def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): elif len(tensor.shape) == 4: # (batch, head, seq_length, head_features) return tensor.permute(0, 2, 1, 3) + elif len(tensor.shape) == 3: + return tensor.permute(1, 0, 2) else: raise ValueError( f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") @@ -145,6 +151,8 @@ def _merge_heads(self, tensor, num_attention_heads, attn_head_size): tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() elif len(tensor.shape) == 4: tensor = tensor.permute(0, 2, 1, 3).contiguous() + elif len(tensor.shape) == 3: + tensor = tensor.permute(1, 0, 2).contiguous() else: raise ValueError( f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") @@ -171,10 +179,10 @@ def _attn( key = key.to(torch.int8) # attn_weights = torch.matmul(query, key.transpose(-1, -2)) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) + proj_shape = (self.bsz * self.num_attention_heads, -1, self.head_dim) key = key.view(*proj_shape) query = self._shape( - query_states, tgt_len, bsz).view(*proj_shape) + query, self.tgt_len, 1).view(*proj_shape) attn_weights = self.qk_bmm(query, key) mask_value = torch.finfo(attn_weights.dtype).min @@ -201,8 +209,13 @@ def _attn( attn_weights = attn_weights * head_mask # attn_output = 
torch.matmul(attn_weights, value) + value = value[0] + attn_weights = attn_weights[0] + # print(attn_weights.shape) + # print(value.shape) attn_output = self.pv_bmm(attn_weights, value) - + value = value.view(1, *value.shape) + attn_weights = attn_weights.view(1, *attn_weights.shape) return attn_output, attn_weights def forward( @@ -214,7 +227,7 @@ def forward( use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, ): - + self.bsz, self.tgt_len, _ = hidden_states.size() query = self.q_proj(hidden_states) key = self.k_proj(hidden_states) value = self.v_proj(hidden_states) From 88e17b2e62917fe3c285d94e3272782736006a4b Mon Sep 17 00:00:00 2001 From: Iman Hosseini Date: Wed, 7 Dec 2022 12:21:08 -0500 Subject: [PATCH 5/7] -> updates + test, mlp has high r2 :( --- tests/test_gptj.py | 64 ++++++++++++++ tests/test_gptj_attention.py | 10 +-- tests/test_gptj_block.py | 82 ++++++++++++++++++ tests/test_gptj_mlp.py | 54 ++++++++++++ tests/test_gptj_model.py | 61 +++++++++++++ torch_int/models/gptj.py | 162 ++++++++++++++++++++++------------- 6 files changed, 367 insertions(+), 66 deletions(-) create mode 100644 tests/test_gptj.py create mode 100644 tests/test_gptj_block.py create mode 100644 tests/test_gptj_mlp.py create mode 100644 tests/test_gptj_model.py diff --git a/tests/test_gptj.py b/tests/test_gptj.py new file mode 100644 index 0000000..dcaecef --- /dev/null +++ b/tests/test_gptj.py @@ -0,0 +1,64 @@ +import torch +from torch_int.models.opt import Int8OPTForCausalLM +from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer, OPTForCausalLM +from icecream import ic +from transformers import GPT2Tokenizer +from datasets import load_dataset +from tqdm import tqdm + + +class Evaluator: + def __init__(self, dataset, tokenizer, device): + self.dataset = dataset + self.tokenizer = tokenizer + self.device = device + + # tokenize the dataset + def tokenize_function(examples): + example = self.tokenizer(examples['text']) + return example + + self.dataset = self.dataset.map(tokenize_function, batched=True) + self.dataset.set_format(type='torch', columns=['input_ids']) + + @torch.no_grad() + def evaluate(self, model): + model.eval() + # The task is to predict the last token of the input. 
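+        # Last-token accuracy: for each example the label is its final token,
+        # the prediction is the argmax of the logits at the second-to-last
+        # position (the model's distribution over that final token), and
+        # accuracy is the fraction of exact matches.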
+ total, hit = 0, 0 + pbar = tqdm(self.dataset, desc='Evaluating') + for batch in pbar: + input_ids = batch['input_ids'].to(self.device).unsqueeze(0) + # label is the last token which is not the padding token + label = input_ids[:, -1] + outputs = model(input_ids) + last_token_logits = outputs.logits[:, -2, :] + pred = last_token_logits.argmax(dim=-1) + total += label.size(0) + hit += (pred == label).sum().item() + pbar.set_postfix({'acc': hit / total}) + acc = hit / total + return acc + + +@torch.no_grad() +def test_opt(): + dataset = load_dataset('lambada', split='validation[:1000]') + tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-13b') + evaluator = Evaluator(dataset, tokenizer, 'cuda') + int8_model_path = '/dataset/opt/opt-13b-smoothquant' + # precision = 'fp16' + precision = 'int8' + if precision == 'int8': + model = Int8OPTForCausalLM.from_pretrained(int8_model_path, + device_map='auto', torch_dtype=torch.float16) + else: + model = OPTForCausalLM.from_pretrained('facebook/opt-13b', + device_map='auto', + torch_dtype=torch.float16) + acc = evaluator.evaluate(model) + ic(acc) + + +if __name__ == '__main__': + test_opt() diff --git a/tests/test_gptj_attention.py b/tests/test_gptj_attention.py index 576c30e..dcaaff2 100644 --- a/tests/test_gptj_attention.py +++ b/tests/test_gptj_attention.py @@ -6,8 +6,8 @@ from icecream import ic from functools import partial - def store_act(module, x, y, act_dict, name): + # print(f"{name}: {y.mean()}") if isinstance(x, tuple): x = x[0] if isinstance(y, tuple): @@ -15,15 +15,15 @@ def store_act(module, x, y, act_dict, name): act_dict[name] = (x, y) - @torch.no_grad() def test_gptj_attention(): - B, L, D, H = 1, 16, 16, 1 + B, L, D, H = 1, 32, 128, 1 x = torch.randn(B, L, D) x_scale = x.abs().max() / 127 config = GPTJConfig() config.n_embd = D config.n_head = H + config.rotary_dim = None attn = GPTJAttention(config) attn.eval() act_dict = {} @@ -31,7 +31,8 @@ def test_gptj_attention(): if isinstance(module, torch.nn.Linear): module.register_forward_hook( partial(store_act, act_dict=act_dict, name=name)) - y = attn(x)[0] + y = attn(x) + y = y[0] q_output_scale = act_dict['q_proj'][1].abs().max() / 127 k_output_scale = act_dict['k_proj'][1].abs().max() / 127 @@ -48,7 +49,6 @@ def test_gptj_attention(): q_x = (x / x_scale).round().to(torch.int8) y_hat = int8_attn(q_x.cuda())[0].cpu() - # ic(y_hat) r2 = (y - y_hat).pow(2).mean() / y.pow(2).mean() ic(r2) diff --git a/tests/test_gptj_block.py b/tests/test_gptj_block.py new file mode 100644 index 0000000..49348dc --- /dev/null +++ b/tests/test_gptj_block.py @@ -0,0 +1,82 @@ +import torch +from transformers.models.gptj.modeling_gptj import GPTJBlock, GPTJConfig +from torch_int.models.gptj import Int8GPTJBlock +from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU +from typing import Tuple +from icecream import ic +from functools import partial +import matplotlib.pyplot as plt + +def store_act(module, x, y, act_dict, name): + # print(f"{name}: {y.mean()}") + if isinstance(x, tuple): + x = x[0] + if isinstance(y, tuple): + y = y[0] + act_dict[name] = (x, y) + + +@torch.no_grad() +def test_gptj_block(): + config : GPTJConfig = GPTJConfig.from_pretrained('Salesforce/codegen-350M-mono') + B, L, D, H = 1, 256, config.n_embd, config.n_head + x = torch.randn(B, L, D)*20 + blk = GPTJBlock(config) + blk.eval() + act_dict = {} + for name, module in blk.named_modules(): + if isinstance(module, torch.nn.Linear): + module.register_forward_hook( + partial(store_act, 
act_dict=act_dict, name=name)) + if isinstance(module, torch.nn.LayerNorm): + module.register_forward_hook( + partial(store_act, act_dict=act_dict, name=name)) + + y = blk(x) + y = y[0].cpu() + print(act_dict.keys()) + # exit(0) + ln1_input_scale = act_dict['ln_1'][1].abs().max() / 127 + attn_input_scale = act_dict['attn.q_proj'][0].abs().max() / 127 + q_output_scale = act_dict['attn.q_proj'][1].abs().max() / 127 + k_output_scale = act_dict['attn.k_proj'][1].abs().max() / 127 + v_output_scale = act_dict['attn.v_proj'][1].abs().max() / 127 + out_input_scale = act_dict['attn.out_proj'][0].abs().max() / 127 + fc1_input_scale = act_dict['mlp.fc_in'][0].abs().max() / 127 + fc2_input_scale = act_dict['mlp.fc_out'][0].abs().max() / 127 + int8_blk = Int8GPTJBlock.from_float( + blk, attn_input_scale, q_output_scale, k_output_scale, v_output_scale, out_input_scale, fc1_input_scale, fc2_input_scale, + ln1_input_scale).cuda() + int8_blk.eval() + q_act_dict = {} + + y_hat = int8_blk(x.cuda())[0].cpu() + rd = blk.dbgi + md = int8_blk.dbgi + RN = 256 + ra = rd['atto'].cpu().flatten()[:RN] + ma = md['attoX'].cpu().flatten()[:RN] + rf = rd['ffn'].cpu().flatten()[:RN] + mf = md['ffnX'].cpu().flatten()[:RN] + rr = rd['resi'].cpu().flatten()[:RN] + mr = md['resiX'].cpu().flatten()[:RN] + # + # plt.plot(ra.flatten()) + print(f"MAX: a:{ra.abs().max()} f:{rf.abs().max()} r:{rr.abs().max()+0.0000001}") + plt.plot(ma - ra, color='r') + plt.savefig("Xa.jpg", dpi=300) + plt.cla() + # plt.plot(rf) + plt.plot(mf - rf, color='r') + plt.savefig("Xf.jpg", dpi=300) + plt.cla() + # plt.plot(rr.flatten()) + plt.plot(mr - rr, color='r') + plt.savefig("Xr.jpg", dpi=300) + + r2 = (y - y_hat).pow(2).mean() / y.pow(2).mean() + ic(r2) + + +if __name__ == '__main__': + test_gptj_block() diff --git a/tests/test_gptj_mlp.py b/tests/test_gptj_mlp.py new file mode 100644 index 0000000..1e21a00 --- /dev/null +++ b/tests/test_gptj_mlp.py @@ -0,0 +1,54 @@ +import torch +from transformers.models.gptj.modeling_gptj import GPTJMLP, GPTJConfig +from torch_int.models.gptj import Int8GPTJMLP +from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU +from typing import Tuple +from icecream import ic +from functools import partial +from torch_int.nn.fused import LayerNormQ +from torch.nn import LayerNorm + +def store_act(module, x, y, act_dict, name): + # print(f"{name}: {y.mean()}") + if isinstance(x, tuple): + x = x[0] + if isinstance(y, tuple): + y = y[0] + act_dict[name] = (x, y) + + +@torch.no_grad() +def test_gptj_mlp(): + B, L, D, H = 1, 16, 32, 1 + x = torch.randn(B, L, D)*40 + x = torch.clamp(x, -127, 127) + x_scale = x.abs().max() / 127 + config = GPTJConfig() + config.n_embd = D + config.n_head = H + intermediate_size = 4*D + config.rotary_dim = None + mlp = GPTJMLP(intermediate_size, config) + mlp.eval() + act_dict = {} + for name, module in mlp.named_modules(): + if isinstance(module, torch.nn.Linear): + module.register_forward_hook( + partial(store_act, act_dict=act_dict, name=name)) + y = mlp(x) + y = y[0] + + fc_in_scale = act_dict['fc_in'][0].abs().max() / 127 + fc_out_scale = act_dict['fc_out'][0].abs().max() / 127 + int8_mlp = Int8GPTJMLP.from_float( + mlp, fc_in_scale, fc_out_scale).cuda() + int8_mlp.eval() + q_x = x.round().to(torch.int8) + y_hat = int8_mlp(q_x.cuda()).cpu() + print(y_hat.shape) + r2 = (y - y_hat).pow(2).mean() / y.pow(2).mean() + ic(r2) + + +if __name__ == '__main__': + test_gptj_mlp() diff --git a/tests/test_gptj_model.py b/tests/test_gptj_model.py new file mode 100644 index 
0000000..6a27cb0 --- /dev/null +++ b/tests/test_gptj_model.py @@ -0,0 +1,61 @@ +import torch +from transformers.models.gptj.modeling_gptj import GPTJModel, GPTJConfig +from torch_int.models.gptj import Int8GPTJModel +from icecream import ic +from functools import partial + + +def store_act(module, x, y, act_dict, name): + if isinstance(x, tuple): + x = x[0] + if isinstance(y, tuple): + y = y[0] + act_dict[name] = (x, y) + + +@torch.no_grad() +def test_gptj_model_layer(): + config = GPTJConfig.from_pretrained('Salesforce/codegen-350M-mono') + + B, L, D, H = 1, 256, config.n_embd, config.n_head + + x = torch.randint(0, config.vocab_size, (B, L)) + model = GPTJModel(config) + model.eval() + act_dict = {} + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + module.register_forward_hook( + partial(store_act, act_dict=act_dict, name=name)) + y = model(x)[0].cuda() + decoder_layer_scales = [] + for idx in range(config.n_layer): + scale_dict = {} + scale_dict["attn_input_scale"] = act_dict[f"h.{idx}.attn.q_proj"][0].abs( + ).max() / 127 + scale_dict["q_output_scale"] = act_dict[f"h.{idx}.attn.q_proj"][1].abs( + ).max() / 127 + scale_dict["k_output_scale"] = act_dict[f"h.{idx}.attn.k_proj"][1].abs( + ).max() / 127 + scale_dict["v_output_scale"] = act_dict[f"h.{idx}.attn.v_proj"][1].abs( + ).max() / 127 + scale_dict["out_input_scale"] = act_dict[f"h.{idx}.attn.out_proj"][0].abs( + ).max() / 127 + scale_dict["fc1_input_scale"] = act_dict[f"h.{idx}.mlp.fc_in"][0].abs( + ).max() / 127 + scale_dict["fc2_input_scale"] = act_dict[f"h.{idx}.mlp.fc_out"][0].abs( + ).max() / 127 + decoder_layer_scales.append(scale_dict) + + int8_model = Int8GPTJModel.from_float(model, decoder_layer_scales).cuda() + int8_model.eval() + + y_hat = int8_model(x.cuda())[0] + + # # ic(y_hat) + r2 = (y - y_hat).pow(2).mean() / y.pow(2).mean() + ic(r2) + + +if __name__ == '__main__': + test_gptj_model_layer() diff --git a/torch_int/models/gptj.py b/torch_int/models/gptj.py index 8c09629..63c2154 100644 --- a/torch_int/models/gptj.py +++ b/torch_int/models/gptj.py @@ -12,7 +12,7 @@ ) from typing import Optional, Tuple, List -from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearReLU +from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU from torch_int.nn.fused import LayerNormQ from transformers.utils import logging from torch_int.nn.bmm import BMM_S8T_S8N_S8T, BMM_S8T_S8N_F32T @@ -49,15 +49,16 @@ def duplicate_interleave(m): def apply_rotary_pos_emb(x, sincos, offset=0): sin, cos = map(lambda t: duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :], sincos) # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) - return (x * cos) + (rotate_every_two(x) * sin) + r = (x.to(torch.float) * cos) + (rotate_every_two(x) * sin) + return r class Int8GPTJAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, n_embd, n_head, max_position_embeddings, rotary_dim = None): - super().__init__() - + super().__init__() + self.dbgi = {} max_positions = max_position_embeddings self.embed_dim = n_embd self.num_attention_heads = n_head @@ -75,7 +76,6 @@ def __init__(self, n_embd, n_head, max_position_embeddings, rotary_dim = None): f" and `num_heads`: {self.num_attention_heads})." 
) - self.attention_weight_scale = 1.0 self.qk_bmm = BMM_S8T_S8N_F32T(1.0) self.pv_bmm = BMM_S8T_S8N_S8T(1.0) self.k_proj = W8A8B8O8Linear(n_embd, n_embd) @@ -100,13 +100,16 @@ def from_float(module: GPTJAttention, out_input_scale: float): int8_module = Int8GPTJAttention(module.embed_dim, module.num_attention_heads, module.bias.shape[3], module.rotary_dim) # Fuse the scaling into the q_proj output scale - q_output_scale = q_output_scale * module.scale_attn + scale_h = module.head_dim**-0.5 + q_output_scale = q_output_scale * scale_h + module.q_proj.weight *= scale_h + # k_output_scale = k_output_scale * scale_h + # module.k_proj.weight *= scale_h # TODO: GPTJ has no bias, find a way to elide these later module.q_proj.bias = torch.nn.Parameter(torch.zeros(module.embed_dim, dtype=float)) module.v_proj.bias = torch.nn.Parameter(torch.zeros(module.embed_dim, dtype=float)) module.k_proj.bias = torch.nn.Parameter(torch.zeros(module.embed_dim, dtype=float)) module.out_proj.bias = torch.nn.Parameter(torch.zeros(module.embed_dim, dtype=float)) - module.q_proj.weight *= module.scale_attn int8_module.q_proj = W8A8B8O8Linear.from_float( module.q_proj, input_scale, q_output_scale) int8_module.k_proj = W8A8B8O8Linear.from_float( @@ -119,6 +122,7 @@ def from_float(module: GPTJAttention, q_output_scale, k_output_scale) # alpha = s_prob * s_v / s_out, where s_prob = 1 / 127 + print(f"{v_output_scale}/{out_input_scale}") int8_module.pv_bmm = BMM_S8T_S8N_S8T.from_scale( 1.0 / 127, v_output_scale, out_input_scale) return int8_module @@ -132,16 +136,11 @@ def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary): if rotary: return tensor if len(tensor.shape) == 5: - # (batch, blocks, head, block_length, head_features) - return tensor.permute(0, 1, 3, 2, 4) + return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features) elif len(tensor.shape) == 4: - # (batch, head, seq_length, head_features) - return tensor.permute(0, 2, 1, 3) - elif len(tensor.shape) == 3: - return tensor.permute(1, 0, 2) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) else: - raise ValueError( - f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") def _merge_heads(self, tensor, num_attention_heads, attn_head_size): """ @@ -151,13 +150,9 @@ def _merge_heads(self, tensor, num_attention_heads, attn_head_size): tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() elif len(tensor.shape) == 4: tensor = tensor.permute(0, 2, 1, 3).contiguous() - elif len(tensor.shape) == 3: - tensor = tensor.permute(1, 0, 2).contiguous() else: - raise ValueError( - f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") - new_shape = tensor.size()[:-2] + \ - (num_attention_heads * attn_head_size,) + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) return tensor.view(new_shape) def _attn( @@ -175,16 +170,27 @@ def _attn( query_length: key_length, :key_length].to(torch.bool) # Keep the attention weights computation in fp32 to avoid overflow issues - query = query.to(torch.int8) - key = key.to(torch.int8) + # query = query.to(torch.int8) + # key = key.to(torch.int8) # attn_weights = torch.matmul(query, key.transpose(-1, -2)) + # proj_shape = (self.bsz * self.num_attention_heads, -1, self.head_dim) + # key = key.view(*proj_shape) + # query = 
self._shape( + # query, self.tgt_len, 1).view(*proj_shape) + + # key = key.transpose(-1, -2) proj_shape = (self.bsz * self.num_attention_heads, -1, self.head_dim) - key = key.view(*proj_shape) - query = self._shape( - query, self.tgt_len, 1).view(*proj_shape) + key = key.reshape(*proj_shape) + query = query.view(*proj_shape) + query = query.contiguous() + key = key.contiguous() + print(f"I8key:{key.shape}, query:{query.shape}") attn_weights = self.qk_bmm(query, key) - + self.dbgi["qk"] = attn_weights.clone() + print(f"I8OUT: {attn_weights.shape}") + attn_weights = attn_weights.view(self.bsz, self.num_attention_heads, self.tgt_len, key_length) + print(f"I8OUTpost: {attn_weights.shape}") mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` @@ -192,7 +198,7 @@ def _attn( mask_value, dtype=attn_weights.dtype).to(attn_weights.device) attn_weights = torch.where(causal_mask, attn_weights, mask_value) - attn_weights = attn_weights / self.scale_attn + # attn_weights = attn_weights / self.scale_attn if attention_mask is not None: # Apply the attention mask @@ -202,25 +208,35 @@ def _attn( # attn_weights = attn_weights.to(value.dtype) attn_weights.mul_(127).round_() attn_weights = attn_weights.to(torch.int8) - # attn_weights = self.attn_dropout(attn_weights) # Mask heads if we want to if head_mask is not None: attn_weights = attn_weights * head_mask - - # attn_output = torch.matmul(attn_weights, value) - value = value[0] - attn_weights = attn_weights[0] - # print(attn_weights.shape) - # print(value.shape) + self.dbgi["Am1"] = value.clone() + attn_weights = attn_weights.view(self.bsz * self.num_attention_heads, -1, self.tgt_len).contiguous() + print(f"VAL:{value.shape}") + value = value.transpose(2,3) + print(f"VAL:{value.shape}") + value = value.reshape(self.num_attention_heads * self.bsz, self.head_dim, self.tgt_len).contiguous() + # value = value.reshape(self.num_attention_heads * self.bsz, self.head_dim, self.tgt_len).contiguous() + print(f"I8: att:{attn_weights.shape}, v: {value.shape}") + self.dbgi["pv_a"] = attn_weights.clone() + self.dbgi["pv_v"] = value.clone() + print(f"ATTNPROBS:{attn_weights.to(torch.float).abs().mean()}|VAL:{value.to(torch.float).abs().mean()}") + print(f"att___:{attn_weights.shape}, value__:{value.shape}") attn_output = self.pv_bmm(attn_weights, value) - value = value.view(1, *value.shape) - attn_weights = attn_weights.view(1, *attn_weights.shape) + # attn_output = torch.matmul(attn_weights, value) + # print(f"===F:{attn_output[:16]}") + self.dbgi["pv"] = attn_output.clone() + print(f"ASIZE_I8: {torch.numel(attn_output)}") + attn_weights = attn_weights.view(self.bsz, self.num_attention_heads, self.tgt_len, key_length) + attn_output = attn_output.view(self.bsz, self.num_attention_heads, self.tgt_len, self.head_dim) + print(f"MOUT: W:{attn_weights.shape}, O: {attn_output.shape}") return attn_output, attn_weights def forward( self, - hidden_states: Optional[torch.FloatTensor], + hidden_states: Optional[torch.Tensor], attention_mask: Optional[torch.FloatTensor] = None, layer_past: Optional[Tuple[torch.Tensor]] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -228,9 +244,11 @@ def forward( output_attentions: Optional[bool] = False, ): self.bsz, self.tgt_len, _ = hidden_states.size() + print(f"HS: {hidden_states.shape}") query = self.q_proj(hidden_states) key = 
self.k_proj(hidden_states) value = self.v_proj(hidden_states) + self.dbgi["vO"] = value.clone() query = self._split_heads( query, self.num_attention_heads, self.head_dim, True) @@ -257,12 +275,12 @@ def forward( k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) - key = torch.cat([k_rot, k_pass], dim=-1) - query = torch.cat([q_rot, q_pass], dim=-1) + key = torch.cat([k_rot, k_pass], dim=-1).to(torch.int8) + query = torch.cat([q_rot, q_pass], dim=-1).to(torch.int8) else: sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) - key = apply_rotary_pos_emb(key, sincos, offset=offset) - query = apply_rotary_pos_emb(query, sincos, offset=offset) + key = apply_rotary_pos_emb(key, sincos, offset=offset).to(torch.int8) + query = apply_rotary_pos_emb(query, sincos, offset=offset).to(torch.int8) key = key.permute(0, 2, 1, 3) query = query.permute(0, 2, 1, 3) @@ -277,13 +295,17 @@ def forward( present = (key, value) else: present = None - + # tvals = self.dbgi[0] + # r2q = (tvals[0] - query).pow(2).mean() / tvals[0].pow(2).mean() + # r2k = (tvals[0] - query).pow(2).mean() / tvals[0].pow(2).mean() + # r2v = (tvals[0] - query).pow(2).mean() / tvals[0].pow(2).mean() # compute self-attention: V x Softmax(QK^T) attn_output, attn_weights = self._attn( query, key, value, attention_mask, head_mask) - + print(f"I8-attO: {attn_output.shape}") attn_output = self._merge_heads( attn_output, self.num_attention_heads, self.head_dim) + attn_output = attn_output.contiguous() attn_output = self.out_proj(attn_output) # attn_output = self.resid_dropout(attn_output) @@ -312,19 +334,19 @@ def from_float(module: GPTJMLP, fc1_input_scale: float, fc2_input_scale: float): int8_module = Int8GPTJMLP( module.fc_in.out_features, module.fc_in.in_features) int8_module.fc1 = W8A8B8O8LinearGELU.from_float( - module.fc_in, fc1_input_scale) + module.fc_in, fc1_input_scale, fc2_input_scale) int8_module.fc2 = W8A8BFP32OFP32Linear.from_float( module.fc_out, fc2_input_scale) return int8_module class Int8GPTJBlock(nn.Module): - def __init__(self, config): + def __init__(self, inner_dim, n_embd): super().__init__() - inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd - self.ln_1 = LayerNormQ(config.n_embd) - self.attn = Int8GPTJAttention(config) - self.mlp = Int8GPTJMLP(inner_dim, config.n_embd) + self.ln_1 = LayerNormQ(n_embd) + # self.attn = Int8GPTJAttention(config) + # self.mlp = Int8GPTJMLP(inner_dim, n_embd) + self.dbgi = {} def forward( self, @@ -347,15 +369,20 @@ def forward( ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] - + # print(f"MLPIN MEAN: {hidden_states.to(torch.float).abs().mean()}") + # mxx = hidden_states.to(torch.float).abs().max() + # scc = 127.0/mxx + # hidden_states = hidden_states*scc.round().to(torch.int8) feed_forward_hidden_states = self.mlp(hidden_states) + self.dbgi['attoX'] = attn_output.clone() + self.dbgi['ffnX'] = feed_forward_hidden_states.clone() + self.dbgi['resiX'] = residual.clone() hidden_states = attn_output + feed_forward_hidden_states + residual if use_cache: outputs = (hidden_states,) + outputs else: outputs = (hidden_states,) + outputs[1:] - return outputs # hidden_states, present, (attentions) @staticmethod @@ -366,24 +393,32 @@ def from_float(module, attn_input_scale: float, out_input_scale: float, fc1_input_scale: float, fc2_input_scale: float): - int8_module = Int8GPTJBlock(config) + inner_dim = module.mlp.fc_out.in_features + n_embd = 
module.ln_1.normalized_shape + # eps = module.ln_1.eps + int8_module = Int8GPTJBlock(inner_dim, n_embd) int8_module.mlp = Int8GPTJMLP.from_float( module.mlp, fc1_input_scale, fc2_input_scale) int8_module.ln_1 = LayerNormQ.from_float(module.ln_1, attn_input_scale) + int8_module.ln_1.eps = module.ln_1.eps int8_module.attn = Int8GPTJAttention.from_float( module.attn, attn_input_scale, q_output_scale, k_output_scale, v_output_scale, out_input_scale) + return int8_module class Int8GPTJModel(GPTJPreTrainedModel): def __init__(self, config): super().__init__(config) - + n_layer = config.n_layer + inner_dim = 4 * config.n_embd self.embed_dim = config.n_embd self.vocab_size = config.vocab_size + print(f"EMBEDDING: {config.vocab_size}x{self.embed_dim}") self.wte = nn.Embedding(config.vocab_size, self.embed_dim) - self.h = nn.ModuleList([Int8PTJBlock(config) - for _ in range(config.n_layer)]) - self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.drop = nn.Identity() + # self.h = nn.ModuleList([Int8GPTJBlock(inner_dim, self.embed_dim) + # for _ in range(config.n_layer)]) + # self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) # Model parallel self.model_parallel = False @@ -395,10 +430,15 @@ def __init__(self, config): forward = GPTJModel.forward @staticmethod - def from_float(module, decoder_layer_scales): - int8_module = Int8GPTJModel(module.config) - int8_module.h = nn.ModuleList( - [Int8GPTJBlock.from_float(mm, decoder_layer_scales) for mm in module.h]) + def from_float(module : GPTJModel, decoder_layer_scales): + config = GPTJConfig(vocab_size=module.vocab_size, n_embd=module.embed_dim, n_layer=len(module.h), rotary_dim=module.h[0].attn.rotary_dim + , n_inner=4*module.embed_dim) + int8_module = Int8GPTJModel(config) + int8_module.h = nn.ModuleList() + for i, layer in enumerate(module.h): + int8_module.h.insert(i, Int8GPTJBlock.from_float( + layer, **decoder_layer_scales[i])) + int8_module.ln_f = module.ln_f return int8_module From e7ad7ad8357bd33a7d41fb47bfcdcb25874fe4d0 Mon Sep 17 00:00:00 2001 From: Iman Hosseini Date: Wed, 7 Dec 2022 12:24:15 -0500 Subject: [PATCH 6/7] fix --- tests/test_gptj_block.py | 45 ++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/tests/test_gptj_block.py b/tests/test_gptj_block.py index 49348dc..e1f55b0 100644 --- a/tests/test_gptj_block.py +++ b/tests/test_gptj_block.py @@ -20,7 +20,7 @@ def store_act(module, x, y, act_dict, name): def test_gptj_block(): config : GPTJConfig = GPTJConfig.from_pretrained('Salesforce/codegen-350M-mono') B, L, D, H = 1, 256, config.n_embd, config.n_head - x = torch.randn(B, L, D)*20 + x = torch.randn(B, L, D)*10 blk = GPTJBlock(config) blk.eval() act_dict = {} @@ -45,34 +45,33 @@ def test_gptj_block(): fc1_input_scale = act_dict['mlp.fc_in'][0].abs().max() / 127 fc2_input_scale = act_dict['mlp.fc_out'][0].abs().max() / 127 int8_blk = Int8GPTJBlock.from_float( - blk, attn_input_scale, q_output_scale, k_output_scale, v_output_scale, out_input_scale, fc1_input_scale, fc2_input_scale, - ln1_input_scale).cuda() + blk, attn_input_scale, q_output_scale, k_output_scale, v_output_scale, out_input_scale, fc1_input_scale, fc2_input_scale).cuda() int8_blk.eval() q_act_dict = {} y_hat = int8_blk(x.cuda())[0].cpu() - rd = blk.dbgi - md = int8_blk.dbgi - RN = 256 - ra = rd['atto'].cpu().flatten()[:RN] - ma = md['attoX'].cpu().flatten()[:RN] - rf = rd['ffn'].cpu().flatten()[:RN] - mf = md['ffnX'].cpu().flatten()[:RN] - rr = 
rd['resi'].cpu().flatten()[:RN] - mr = md['resiX'].cpu().flatten()[:RN] + # rd = blk.dbgi + # md = int8_blk.dbgi + # RN = 256 + # ra = rd['atto'].cpu().flatten()[:RN] + # ma = md['attoX'].cpu().flatten()[:RN] + # rf = rd['ffn'].cpu().flatten()[:RN] + # mf = md['ffnX'].cpu().flatten()[:RN] + # rr = rd['resi'].cpu().flatten()[:RN] + # mr = md['resiX'].cpu().flatten()[:RN] # # plt.plot(ra.flatten()) - print(f"MAX: a:{ra.abs().max()} f:{rf.abs().max()} r:{rr.abs().max()+0.0000001}") - plt.plot(ma - ra, color='r') - plt.savefig("Xa.jpg", dpi=300) - plt.cla() - # plt.plot(rf) - plt.plot(mf - rf, color='r') - plt.savefig("Xf.jpg", dpi=300) - plt.cla() - # plt.plot(rr.flatten()) - plt.plot(mr - rr, color='r') - plt.savefig("Xr.jpg", dpi=300) + # print(f"MAX: a:{ra.abs().max()} f:{rf.abs().max()} r:{rr.abs().max()+0.0000001}") + # plt.plot(ma - ra, color='r') + # plt.savefig("Xa.jpg", dpi=300) + # plt.cla() + # # plt.plot(rf) + # plt.plot(mf - rf, color='r') + # plt.savefig("Xf.jpg", dpi=300) + # plt.cla() + # # plt.plot(rr.flatten()) + # plt.plot(mr - rr, color='r') + # plt.savefig("Xr.jpg", dpi=300) r2 = (y - y_hat).pow(2).mean() / y.pow(2).mean() ic(r2) From 2163a169748edff67586c2bf0158f4c7f0718fc6 Mon Sep 17 00:00:00 2001 From: Iman Hosseini Date: Thu, 22 Dec 2022 23:15:10 -0500 Subject: [PATCH 7/7] changes --- tests/model_dec_scales.json | 1 + tests/test_gptj.py | 71 ++++++++++---- tests/test_gptj_block.py | 2 +- torch_int/models/gptj.py | 179 ++++++++++++++++++------------------ torch_int/nn/linear.py | 4 +- 5 files changed, 148 insertions(+), 109 deletions(-) create mode 100644 tests/model_dec_scales.json diff --git a/tests/model_dec_scales.json b/tests/model_dec_scales.json new file mode 100644 index 0000000..841008a --- /dev/null +++ b/tests/model_dec_scales.json @@ -0,0 +1 @@ +[{"attn_input_scale": 0.031619094488188976, "q_output_scale": 0.1687992125984252, "k_output_scale": 0.1347194881889764, "v_output_scale": 0.02297613188976378, "out_input_scale": 0.01796259842519685, "fc1_input_scale": 0.031619094488188976, "fc2_input_scale": 0.007831723671259843}, {"attn_input_scale": 0.011095903051181102, "q_output_scale": 0.14013287401574803, "k_output_scale": 0.14160925196850394, "v_output_scale": 0.046475147637795276, "out_input_scale": 0.03595595472440945, "fc1_input_scale": 0.011095903051181102, "fc2_input_scale": 0.00655142716535433}, {"attn_input_scale": 0.00863527312992126, "q_output_scale": 0.18553149606299213, "k_output_scale": 0.156373031496063, "v_output_scale": 0.03021961122047244, "out_input_scale": 0.030096579724409447, "fc1_input_scale": 0.00863527312992126, "fc2_input_scale": 0.029789000984251968}, {"attn_input_scale": 0.019192913385826772, "q_output_scale": 0.2233021653543307, "k_output_scale": 0.15563484251968504, "v_output_scale": 0.03804749015748032, "out_input_scale": 0.03168061023622047, "fc1_input_scale": 0.019192913385826772, "fc2_input_scale": 0.03337229330708662}, {"attn_input_scale": 0.01287217027559055, "q_output_scale": 0.13041338582677164, "k_output_scale": 0.1392716535433071, "v_output_scale": 0.062100147637795276, "out_input_scale": 0.05361097440944882, "fc1_input_scale": 0.01287217027559055, "fc2_input_scale": 0.002772053395669291}, {"attn_input_scale": 0.016901451771653545, "q_output_scale": 0.17691929133858267, "k_output_scale": 0.17704232283464566, "v_output_scale": 0.025298351377952756, "out_input_scale": 0.024913877952755906, "fc1_input_scale": 0.016901451771653545, "fc2_input_scale": 0.00285279281496063}, {"attn_input_scale": 0.016378567913385825, 
"q_output_scale": 0.13188976377952755, "k_output_scale": 0.15243602362204725, "v_output_scale": 0.02449864665354331, "out_input_scale": 0.020100270669291338, "fc1_input_scale": 0.016378567913385825, "fc2_input_scale": 0.0020415538877952754}, {"attn_input_scale": 0.014563853346456693, "q_output_scale": 0.15526574803149606, "k_output_scale": 0.1625246062992126, "v_output_scale": 0.02995816929133858, "out_input_scale": 0.02109990157480315, "fc1_input_scale": 0.014563853346456693, "fc2_input_scale": 0.002793199434055118}, {"attn_input_scale": 0.016701525590551183, "q_output_scale": 0.15255905511811024, "k_output_scale": 0.18061023622047245, "v_output_scale": 0.021345964566929134, "out_input_scale": 0.01842396653543307, "fc1_input_scale": 0.016701525590551183, "fc2_input_scale": 0.00299312561515748}, {"attn_input_scale": 0.017685777559055118, "q_output_scale": 0.16289370078740156, "k_output_scale": 0.18393208661417323, "v_output_scale": 0.02875861220472441, "out_input_scale": 0.026113435039370077, "fc1_input_scale": 0.017685777559055118, "fc2_input_scale": 0.0021876537893700788}, {"attn_input_scale": 0.01819328248031496, "q_output_scale": 0.1875, "k_output_scale": 0.17285925196850394, "v_output_scale": 0.03186515748031496, "out_input_scale": 0.0296505905511811, "fc1_input_scale": 0.01819328248031496, "fc2_input_scale": 0.001685915969488189}, {"attn_input_scale": 0.014271653543307087, "q_output_scale": 0.14480807086614172, "k_output_scale": 0.16510826771653545, "v_output_scale": 0.023622047244094488, "out_input_scale": 0.01714751476377953, "fc1_input_scale": 0.014271653543307087, "fc2_input_scale": 0.0016195943036417322}, {"attn_input_scale": 0.01624015748031496, "q_output_scale": 0.1733513779527559, "k_output_scale": 0.18713090551181102, "v_output_scale": 0.04856668307086614, "out_input_scale": 0.029389148622047244, "fc1_input_scale": 0.01624015748031496, "fc2_input_scale": 0.0015542338213582678}, {"attn_input_scale": 0.016670767716535435, "q_output_scale": 0.1546505905511811, "k_output_scale": 0.18639271653543307, "v_output_scale": 0.03380290354330709, "out_input_scale": 0.03257258858267716, "fc1_input_scale": 0.016670767716535435, "fc2_input_scale": 0.002921998031496063}, {"attn_input_scale": 0.014686884842519685, "q_output_scale": 0.16203248031496062, "k_output_scale": 0.1969734251968504, "v_output_scale": 0.03071173720472441, "out_input_scale": 0.02066929133858268, "fc1_input_scale": 0.014686884842519685, "fc2_input_scale": 0.0026105745570866143}, {"attn_input_scale": 0.016670767716535435, "q_output_scale": 0.1592027559055118, "k_output_scale": 0.18011811023622049, "v_output_scale": 0.028420275590551183, "out_input_scale": 0.014148622047244094, "fc1_input_scale": 0.016670767716535435, "fc2_input_scale": 0.005417230561023622}, {"attn_input_scale": 0.017854945866141732, "q_output_scale": 0.17568897637795275, "k_output_scale": 0.19672736220472442, "v_output_scale": 0.023452878937007874, "out_input_scale": 0.02251476377952756, "fc1_input_scale": 0.017854945866141732, "fc2_input_scale": 0.0013398898868110236}, {"attn_input_scale": 0.015286663385826772, "q_output_scale": 0.1671998031496063, "k_output_scale": 0.14271653543307086, "v_output_scale": 0.019239050196850394, "out_input_scale": 0.017593503937007874, "fc1_input_scale": 0.015286663385826772, "fc2_input_scale": 0.0022145669291338582}, {"attn_input_scale": 0.016070989173228346, "q_output_scale": 0.15514271653543307, "k_output_scale": 0.15231299212598426, "v_output_scale": 0.019408218503937008, "out_input_scale": 0.016424704724409447, 
"fc1_input_scale": 0.016070989173228346, "fc2_input_scale": 0.006243848425196851}, {"attn_input_scale": 0.017009104330708662, "q_output_scale": 0.1422244094488189, "k_output_scale": 0.16117125984251968, "v_output_scale": 0.025221456692913386, "out_input_scale": 0.019500492125984252, "fc1_input_scale": 0.017009104330708662, "fc2_input_scale": 0.01803949311023622}] \ No newline at end of file diff --git a/tests/test_gptj.py b/tests/test_gptj.py index dcaecef..28b838c 100644 --- a/tests/test_gptj.py +++ b/tests/test_gptj.py @@ -1,11 +1,15 @@ import torch -from torch_int.models.opt import Int8OPTForCausalLM +from torch_int.models.gptj import Int8GPTJForCausalLM, Int8GPTJBlock, Int8GPTJMLP, Int8GPTJAttention, Int8GPTJModel from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoderLayer, OPTForCausalLM +from transformers.models.gptj.modeling_gptj import GPTJModel, GPTJConfig, GPTJForCausalLM +from transformers import AutoModelForCausalLM, AutoTokenizer from icecream import ic -from transformers import GPT2Tokenizer +from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU +# from transformers import GPTJTok from datasets import load_dataset from tqdm import tqdm - +import json +import copy class Evaluator: def __init__(self, dataset, tokenizer, device): @@ -17,21 +21,21 @@ def __init__(self, dataset, tokenizer, device): def tokenize_function(examples): example = self.tokenizer(examples['text']) return example - self.dataset = self.dataset.map(tokenize_function, batched=True) self.dataset.set_format(type='torch', columns=['input_ids']) @torch.no_grad() - def evaluate(self, model): + def evaluate2(self, model): model.eval() # The task is to predict the last token of the input. total, hit = 0, 0 + idx = 0 pbar = tqdm(self.dataset, desc='Evaluating') for batch in pbar: input_ids = batch['input_ids'].to(self.device).unsqueeze(0) - # label is the last token which is not the padding token label = input_ids[:, -1] - outputs = model(input_ids) + outputs = model(input_ids.cuda()) + idx += 1 last_token_logits = outputs.logits[:, -2, :] pred = last_token_logits.argmax(dim=-1) total += label.size(0) @@ -40,23 +44,52 @@ def evaluate(self, model): acc = hit / total return acc + @torch.no_grad() + def evaluate(self, modelX, model): + model.eval() + # The task is to predict the last token of the input. 
+        idx = 0
+        total, hit = 0, 0
+        hit2 = 0
+        pbar = tqdm(self.dataset, desc='Evaluating')
+        for batch in pbar:
+            input_ids = batch['input_ids'].to(self.device).unsqueeze(0)
+            label = input_ids[:, -1]
+            outputs = model(input_ids.to('cuda'))
+            outputs2 = modelX(input_ids.to('cuda'))
+            model.transformer.d.clear()
+            modelX.transformer.d.clear()
+            idx += 1
+            last_token_logits = outputs.logits[:, -2, :]
+            last_token_logits2 = outputs2.logits[:, -2, :]
+            pred = last_token_logits.argmax(dim=-1)
+            pred2 = last_token_logits2.argmax(dim=-1)
+            total += label.size(0)
+            hit += (pred == label).sum().item()
+            hit2 += (pred2 == label).sum().item()
+            pbar.set_postfix({'acc': hit / total, 'accX': hit2 / total})
+        acc = hit / total
+        return acc

+MP = "/home/iman/fgg/smoothquant/SF/codegen-350M-multiX.pt"

 @torch.no_grad()
 def test_opt():
     dataset = load_dataset('lambada', split='validation[:1000]')
-    tokenizer = GPT2Tokenizer.from_pretrained('facebook/opt-13b')
+    dataset = dataset.shuffle(seed=42)
+    checkpoint = "moyix/codegen-350M-multi-gptj"
+    # checkpoint = "Salesforce/codegen-350M-multi"
+    config = GPTJConfig.from_pretrained('moyix/codegen-350M-multi-gptj')
+    model = GPTJForCausalLM.from_pretrained(checkpoint, device_map = 'auto', torch_dtype = 'auto').cuda()
+    tokenizer = AutoTokenizer.from_pretrained('Salesforce/codegen-350M-multi')
     evaluator = Evaluator(dataset, tokenizer, 'cuda')
-    int8_model_path = '/dataset/opt/opt-13b-smoothquant'
-    # precision = 'fp16'
-    precision = 'int8'
-    if precision == 'int8':
-        model = Int8OPTForCausalLM.from_pretrained(int8_model_path,
-                                                   device_map='auto', torch_dtype=torch.float16)
-    else:
-        model = OPTForCausalLM.from_pretrained('facebook/opt-13b',
-                                               device_map='auto',
-                                               torch_dtype=torch.float16)
-    acc = evaluator.evaluate(model)
+    dlsj = "./tests/model_dec_scales.json"
+    decoder_layer_scales = []
+    with open(dlsj, 'r') as fp:
+        decoder_layer_scales = json.load(fp)
+    # these layers will not be quantized
+    layers_to_keep = list(range(13))
+    int8_model = Int8GPTJForCausalLM.from_float(model, decoder_layer_scales, k = layers_to_keep)
+    acc = evaluator.evaluate2(int8_model.to('cuda'))
     ic(acc)

diff --git a/tests/test_gptj_block.py b/tests/test_gptj_block.py
index e1f55b0..00f8cf8 100644
--- a/tests/test_gptj_block.py
+++ b/tests/test_gptj_block.py
@@ -20,7 +20,7 @@ def store_act(module, x, y, act_dict, name):
 def test_gptj_block():
     config : GPTJConfig = GPTJConfig.from_pretrained('Salesforce/codegen-350M-mono')
     B, L, D, H = 1, 256, config.n_embd, config.n_head
-    x = torch.randn(B, L, D)*10
+    x = torch.randn(B, L, D)
     blk = GPTJBlock(config)
     blk.eval()
     act_dict = {}
diff --git a/torch_int/models/gptj.py b/torch_int/models/gptj.py
index 63c2154..ba03bf3 100644
--- a/torch_int/models/gptj.py
+++ b/torch_int/models/gptj.py
@@ -11,13 +11,26 @@
     BaseModelOutputWithPast
 )

+@torch.no_grad()
+def quantize_per_tensor_absmax(t):
+    scale = t.abs().max() / 127
+    if not t.is_cuda:
+        # half rounding is not supported on CPU
+        t = t.float()
+    # use inplace operation to save memory
+    t.div_(scale).round_()
+    t_q = t.to(torch.int8)
+    return t_q, scale
+
 from typing import Optional, Tuple, List
 from torch_int.nn.linear import W8A8BFP32OFP32Linear, W8A8B8O8Linear, W8A8B8O8LinearGELU
 from torch_int.nn.fused import LayerNormQ
 from transformers.utils import logging
 from torch_int.nn.bmm import BMM_S8T_S8N_S8T, BMM_S8T_S8N_F32T
+from transformers.activations import ACT2FN

 def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
+    dim = x.shape[-1]
     if seq_len is None:
         seq_len = x.shape[seq_dim]
@@ -47,9 +60,11 @@
def duplicate_interleave(m): def apply_rotary_pos_emb(x, sincos, offset=0): + x_ = x.to(torch.float32) sin, cos = map(lambda t: duplicate_interleave(t)[None, offset : x.shape[1] + offset, None, :], sincos) # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) - r = (x.to(torch.float) * cos) + (rotate_every_two(x) * sin) + r = ((x_.to(torch.float) * cos) + (rotate_every_two(x_.to(torch.float)) * sin)) + r = r.clamp(-128, 127).to(torch.int8) return r @@ -58,8 +73,8 @@ class Int8GPTJAttention(nn.Module): def __init__(self, n_embd, n_head, max_position_embeddings, rotary_dim = None): super().__init__() - self.dbgi = {} max_positions = max_position_embeddings + self.max_position = max_positions self.embed_dim = n_embd self.num_attention_heads = n_head self.head_dim = n_embd // n_head @@ -100,29 +115,34 @@ def from_float(module: GPTJAttention, out_input_scale: float): int8_module = Int8GPTJAttention(module.embed_dim, module.num_attention_heads, module.bias.shape[3], module.rotary_dim) # Fuse the scaling into the q_proj output scale - scale_h = module.head_dim**-0.5 - q_output_scale = q_output_scale * scale_h - module.q_proj.weight *= scale_h - # k_output_scale = k_output_scale * scale_h - # module.k_proj.weight *= scale_h + # scale_h = module.head_dim**-0.5 + ## scaling + # qoo = q_output_scale + # q_output_scale = q_output_scale * scale_h + # module.q_proj.weight *= scale_h + # qs2 = q_output_scale * scale_h + ## scaling # TODO: GPTJ has no bias, find a way to elide these later - module.q_proj.bias = torch.nn.Parameter(torch.zeros(module.embed_dim, dtype=float)) - module.v_proj.bias = torch.nn.Parameter(torch.zeros(module.embed_dim, dtype=float)) - module.k_proj.bias = torch.nn.Parameter(torch.zeros(module.embed_dim, dtype=float)) - module.out_proj.bias = torch.nn.Parameter(torch.zeros(module.embed_dim, dtype=float)) + module.q_proj.bias = torch.nn.Parameter(torch.zeros((1,module.embed_dim), dtype=module.q_proj.weight.dtype)) + module.v_proj.bias = torch.nn.Parameter(torch.zeros((1,module.embed_dim), dtype=module.v_proj.weight.dtype)) + module.k_proj.bias = torch.nn.Parameter(torch.zeros((1,module.embed_dim), dtype=module.k_proj.weight.dtype)) + module.out_proj.bias = torch.nn.Parameter(torch.zeros((1,module.embed_dim), dtype=module.out_proj.weight.dtype)) + module.cuda() int8_module.q_proj = W8A8B8O8Linear.from_float( module.q_proj, input_scale, q_output_scale) + wc = module.k_proj.weight.clone() int8_module.k_proj = W8A8B8O8Linear.from_float( module.k_proj, input_scale, k_output_scale) + int8_weight, weight_scale = quantize_per_tensor_absmax(wc) int8_module.v_proj = W8A8B8O8Linear.from_float( module.v_proj, input_scale, v_output_scale) + int8_module.v_proj.requires_grad = False int8_module.out_proj = W8A8BFP32OFP32Linear.from_float( module.out_proj, out_input_scale) int8_module.qk_bmm = BMM_S8T_S8N_F32T.from_scale( q_output_scale, k_output_scale) - # alpha = s_prob * s_v / s_out, where s_prob = 1 / 127 - print(f"{v_output_scale}/{out_input_scale}") + # print(f"{v_output_scale}/{out_input_scale}") int8_module.pv_bmm = BMM_S8T_S8N_S8T.from_scale( 1.0 / 127, v_output_scale, out_input_scale) return int8_module @@ -167,17 +187,7 @@ def _attn( # compute causal mask from causal mask buffer query_length, key_length = query.size(-2), key.size(-2) causal_mask = self.bias[:, :, key_length - - query_length: key_length, :key_length].to(torch.bool) - - # Keep the attention weights computation in fp32 to avoid overflow issues - # query = query.to(torch.int8) - # 
key = key.to(torch.int8) - - # attn_weights = torch.matmul(query, key.transpose(-1, -2)) - # proj_shape = (self.bsz * self.num_attention_heads, -1, self.head_dim) - # key = key.view(*proj_shape) - # query = self._shape( - # query, self.tgt_len, 1).view(*proj_shape) + query_length: key_length, :key_length].to(torch.bool).cuda() # key = key.transpose(-1, -2) proj_shape = (self.bsz * self.num_attention_heads, -1, self.head_dim) @@ -185,12 +195,8 @@ def _attn( query = query.view(*proj_shape) query = query.contiguous() key = key.contiguous() - print(f"I8key:{key.shape}, query:{query.shape}") attn_weights = self.qk_bmm(query, key) - self.dbgi["qk"] = attn_weights.clone() - print(f"I8OUT: {attn_weights.shape}") attn_weights = attn_weights.view(self.bsz, self.num_attention_heads, self.tgt_len, key_length) - print(f"I8OUTpost: {attn_weights.shape}") mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` @@ -198,40 +204,25 @@ def _attn( mask_value, dtype=attn_weights.dtype).to(attn_weights.device) attn_weights = torch.where(causal_mask, attn_weights, mask_value) - # attn_weights = attn_weights / self.scale_attn + attn_weights = attn_weights / self.scale_attn if attention_mask is not None: # Apply the attention mask attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - # attn_weights = attn_weights.to(value.dtype) attn_weights.mul_(127).round_() attn_weights = attn_weights.to(torch.int8) # Mask heads if we want to if head_mask is not None: attn_weights = attn_weights * head_mask - self.dbgi["Am1"] = value.clone() attn_weights = attn_weights.view(self.bsz * self.num_attention_heads, -1, self.tgt_len).contiguous() - print(f"VAL:{value.shape}") value = value.transpose(2,3) - print(f"VAL:{value.shape}") value = value.reshape(self.num_attention_heads * self.bsz, self.head_dim, self.tgt_len).contiguous() - # value = value.reshape(self.num_attention_heads * self.bsz, self.head_dim, self.tgt_len).contiguous() - print(f"I8: att:{attn_weights.shape}, v: {value.shape}") - self.dbgi["pv_a"] = attn_weights.clone() - self.dbgi["pv_v"] = value.clone() - print(f"ATTNPROBS:{attn_weights.to(torch.float).abs().mean()}|VAL:{value.to(torch.float).abs().mean()}") - print(f"att___:{attn_weights.shape}, value__:{value.shape}") attn_output = self.pv_bmm(attn_weights, value) - # attn_output = torch.matmul(attn_weights, value) - # print(f"===F:{attn_output[:16]}") - self.dbgi["pv"] = attn_output.clone() - print(f"ASIZE_I8: {torch.numel(attn_output)}") attn_weights = attn_weights.view(self.bsz, self.num_attention_heads, self.tgt_len, key_length) attn_output = attn_output.view(self.bsz, self.num_attention_heads, self.tgt_len, self.head_dim) - print(f"MOUT: W:{attn_weights.shape}, O: {attn_output.shape}") return attn_output, attn_weights def forward( @@ -244,11 +235,10 @@ def forward( output_attentions: Optional[bool] = False, ): self.bsz, self.tgt_len, _ = hidden_states.size() - print(f"HS: {hidden_states.shape}") + # self.out_proj.cuda() query = self.q_proj(hidden_states) key = self.k_proj(hidden_states) value = self.v_proj(hidden_states) - self.dbgi["vO"] = value.clone() query = self._split_heads( query, self.num_attention_heads, self.head_dim, True) @@ -275,12 +265,12 @@ def forward( k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) q_rot = apply_rotary_pos_emb(q_rot, 
sincos, offset=offset) - key = torch.cat([k_rot, k_pass], dim=-1).to(torch.int8) - query = torch.cat([q_rot, q_pass], dim=-1).to(torch.int8) + key = torch.cat([k_rot, k_pass.to(torch.int8)], dim=-1) + query = torch.cat([q_rot, q_pass.to(torch.int8)], dim=-1) else: sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) - key = apply_rotary_pos_emb(key, sincos, offset=offset).to(torch.int8) - query = apply_rotary_pos_emb(query, sincos, offset=offset).to(torch.int8) + key = apply_rotary_pos_emb(key, sincos, offset=offset) + query = apply_rotary_pos_emb(query, sincos, offset=offset) key = key.permute(0, 2, 1, 3) query = query.permute(0, 2, 1, 3) @@ -295,19 +285,13 @@ def forward( present = (key, value) else: present = None - # tvals = self.dbgi[0] - # r2q = (tvals[0] - query).pow(2).mean() / tvals[0].pow(2).mean() - # r2k = (tvals[0] - query).pow(2).mean() / tvals[0].pow(2).mean() - # r2v = (tvals[0] - query).pow(2).mean() / tvals[0].pow(2).mean() # compute self-attention: V x Softmax(QK^T) attn_output, attn_weights = self._attn( query, key, value, attention_mask, head_mask) - print(f"I8-attO: {attn_output.shape}") attn_output = self._merge_heads( attn_output, self.num_attention_heads, self.head_dim) attn_output = attn_output.contiguous() attn_output = self.out_proj(attn_output) - # attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present) if output_attentions: @@ -325,9 +309,10 @@ def __init__(self, intermediate_size, embed_dim): self.fc2 = W8A8BFP32OFP32Linear(intermediate_size, embed_dim) def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor: + # hidden_states = hidden_states.to(torch.float) hidden_states = self.fc1(hidden_states) hidden_states = self.fc2(hidden_states) - return hidden_states + return hidden_states @staticmethod def from_float(module: GPTJMLP, fc1_input_scale: float, fc2_input_scale: float): @@ -341,12 +326,11 @@ def from_float(module: GPTJMLP, fc1_input_scale: float, fc2_input_scale: float): class Int8GPTJBlock(nn.Module): - def __init__(self, inner_dim, n_embd): + def __init__(self, inner_dim, n_embd, n_head, max_position_embeddings, rotary_dim = None): super().__init__() self.ln_1 = LayerNormQ(n_embd) - # self.attn = Int8GPTJAttention(config) - # self.mlp = Int8GPTJMLP(inner_dim, n_embd) - self.dbgi = {} + self.attn = Int8GPTJAttention(n_embd, n_head, max_position_embeddings, rotary_dim) + self.mlp = Int8GPTJMLP(inner_dim, n_embd) def forward( self, @@ -369,16 +353,8 @@ def forward( ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] - # print(f"MLPIN MEAN: {hidden_states.to(torch.float).abs().mean()}") - # mxx = hidden_states.to(torch.float).abs().max() - # scc = 127.0/mxx - # hidden_states = hidden_states*scc.round().to(torch.int8) - feed_forward_hidden_states = self.mlp(hidden_states) - self.dbgi['attoX'] = attn_output.clone() - self.dbgi['ffnX'] = feed_forward_hidden_states.clone() - self.dbgi['resiX'] = residual.clone() + feed_forward_hidden_states = self.mlp(hidden_states) hidden_states = attn_output + feed_forward_hidden_states + residual - if use_cache: outputs = (hidden_states,) + outputs else: @@ -394,31 +370,32 @@ def from_float(module, attn_input_scale: float, fc1_input_scale: float, fc2_input_scale: float): inner_dim = module.mlp.fc_out.in_features - n_embd = module.ln_1.normalized_shape - # eps = module.ln_1.eps - int8_module = Int8GPTJBlock(inner_dim, n_embd) + n_embd = module.ln_1.normalized_shape[0] + int8_module = Int8GPTJBlock(inner_dim, n_embd, 
module.attn.num_attention_heads, module.attn.bias.shape[0], module.attn.rotary_dim) int8_module.mlp = Int8GPTJMLP.from_float( module.mlp, fc1_input_scale, fc2_input_scale) int8_module.ln_1 = LayerNormQ.from_float(module.ln_1, attn_input_scale) - int8_module.ln_1.eps = module.ln_1.eps int8_module.attn = Int8GPTJAttention.from_float( module.attn, attn_input_scale, q_output_scale, k_output_scale, v_output_scale, out_input_scale) return int8_module class Int8GPTJModel(GPTJPreTrainedModel): + # TODO: have to add padding! def __init__(self, config): + self.d = {} super().__init__(config) n_layer = config.n_layer inner_dim = 4 * config.n_embd self.embed_dim = config.n_embd self.vocab_size = config.vocab_size - print(f"EMBEDDING: {config.vocab_size}x{self.embed_dim}") self.wte = nn.Embedding(config.vocab_size, self.embed_dim) self.drop = nn.Identity() - # self.h = nn.ModuleList([Int8GPTJBlock(inner_dim, self.embed_dim) - # for _ in range(config.n_layer)]) - # self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.padding_idx = config.pad_token_id + # self.h = nn.ModuleList() + self.h = nn.ModuleList([Int8GPTJBlock(inner_dim, self.embed_dim, config.n_head, config.n_positions, config.rotary_dim) + for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) # Model parallel self.model_parallel = False @@ -427,18 +404,46 @@ def __init__(self, config): get_input_embeddings = GPTJModel.get_input_embeddings set_input_embeddings = GPTJModel.set_input_embeddings - forward = GPTJModel.forward + old_forward = GPTJModel.forward + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + from torch.nn.functional import pad + input_len = input_ids.shape[1] + if input_len % 16 != 0: + padding_len = 16 - input_len % 16 + input_ids = pad(input_ids, (0, padding_len), value=self.padding_idx) + if attention_mask is not None: + attention_mask = pad(attention_mask, (0, padding_len), value=0) + output = self.old_forward(input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict) + if input_len % 16 != 0: + output.last_hidden_state = output.last_hidden_state[:,:input_len, :] + return output @staticmethod - def from_float(module : GPTJModel, decoder_layer_scales): + def from_float(module : GPTJModel, decoder_layer_scales, k = None): config = GPTJConfig(vocab_size=module.vocab_size, n_embd=module.embed_dim, n_layer=len(module.h), rotary_dim=module.h[0].attn.rotary_dim , n_inner=4*module.embed_dim) int8_module = Int8GPTJModel(config) - int8_module.h = nn.ModuleList() for i, layer in enumerate(module.h): - int8_module.h.insert(i, Int8GPTJBlock.from_float( - layer, **decoder_layer_scales[i])) - int8_module.ln_f = module.ln_f + if k is not None and i in k: + int8_module.h[i] = layer + else: + int8_module.h[i] = Int8GPTJBlock.from_float(layer, **decoder_layer_scales[i]) + int8_module.ln_f = module.ln_f.to(torch.float) + int8_module.wte = module.wte return 
int8_module
@@ -458,10 +463,10 @@ def __init__(self, config):
         self.post_init()

     @staticmethod
-    def from_float(module, decoder_layer_scales):
+    def from_float(module, decoder_layer_scales, k = None):
         int8_module = Int8GPTJForCausalLM(module.config)
-        int8_module.transformer = Int8GPTJModel(config, decoder_layer_scales)
-        int8_module.lm_head = module.lm_head
+        int8_module.transformer = Int8GPTJModel.from_float(module.transformer, decoder_layer_scales, k)
+        int8_module.lm_head = module.lm_head.to(torch.float)
         return int8_module

     get_input_embeddings = GPTJForCausalLM.get_input_embeddings
diff --git a/torch_int/nn/linear.py b/torch_int/nn/linear.py
index dfa70de..2efb247 100644
--- a/torch_int/nn/linear.py
+++ b/torch_int/nn/linear.py
@@ -96,7 +96,7 @@ def from_float(module: torch.nn.Linear, input_scale, output_scale):
         alpha = input_scale * weight_scale / output_scale
         beta = bias_scale / output_scale
         int8_module.weight = int8_weight
-        int8_module.bias = int8_bias
+        int8_module.bias = int8_bias.reshape(int8_module.bias.shape)
         int8_module.a = alpha
         int8_module.b = beta
         return int8_module
@@ -266,7 +266,7 @@ def from_float(module: torch.nn.Linear, input_scale):
         int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
         alpha = input_scale * weight_scale
         int8_module.weight = int8_weight
-        int8_module.bias = module.bias.to(torch.float32)
+        int8_module.bias = module.bias.to(torch.float32).reshape(int8_module.bias.shape)
         int8_module.a = alpha
         int8_module.input_scale = input_scale
         int8_module.weight_scale = weight_scale
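
Note: the bias reshapes in the last two hunks only change how the quantized bias is stored; the scale algebra in `from_float` is unchanged. The sketch below is illustrative only and not part of the patch series (the tensor shapes, the `absmax_quantize` helper, and the toy derivation of `output_scale` are assumptions; in the repo the per-layer scales come from activation calibration, cf. tests/model_dec_scales.json). It checks the relation y_q ~= (x_q @ W_q.T) * alpha + b_q * beta with alpha = input_scale * weight_scale / output_scale and beta = bias_scale / output_scale on random data.

# Illustrative sketch only: verify the scale algebra used by the INT8 linear layers.
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)      # FP32 activations (batch x in_features)
w = torch.randn(16, 8)     # FP32 weight (out_features x in_features)
b = torch.randn(16)        # FP32 bias

def absmax_quantize(t):
    # symmetric per-tensor quantization, mirroring quantize_per_tensor_absmax
    scale = t.abs().max() / 127
    return (t / scale).round().clamp(-127, 127).to(torch.int8), scale

x_q, input_scale = absmax_quantize(x)
w_q, weight_scale = absmax_quantize(w)
b_q, bias_scale = absmax_quantize(b)

y_fp32 = x @ w.T + b
output_scale = y_fp32.abs().max() / 127   # stand-in for a calibrated activation scale

alpha = (input_scale * weight_scale / output_scale).item()
beta = (bias_scale / output_scale).item()

acc = x_q.long() @ w_q.long().T           # wide integer accumulator (INT32 on the GPU)
y_q = (acc * alpha + b_q.long() * beta).round().clamp(-127, 127).to(torch.int8)

err = (y_q.float() * output_scale - y_fp32).abs().max().item()
print(f"max abs error after the INT8 round trip: {err:.4f}")

With per-tensor scales the whole epilogue reduces to one multiply-add per output element, which is the kind of fused epilogue the INT8 linear modules above rely on.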