diff --git a/auto_round/data_type/fp8.py b/auto_round/data_type/fp8.py
index d92c311a..6541384b 100644
--- a/auto_round/data_type/fp8.py
+++ b/auto_round/data_type/fp8.py
@@ -76,26 +76,28 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, **kwargs):
         - Placeholder for zp (None).
     """
     orig_shape = tensor.shape
-    info = torch.finfo(torch.float8_e4m3fn)
+    info = torch.finfo(torch.float8_e5m2)
     orig_dtype = tensor.dtype
-    if tensor_max is None:  ##dynamic per-token
-        tensor = tensor.reshape(-1, orig_shape[-1])
-        max_tensor = torch.max(torch.abs(tensor), dim=-1)[
-                         0] * max_scale
-    elif isinstance(tensor_max, torch.Tensor):
-        max_tensor = tensor_max.clone().detach().to(tensor.device) * max_scale
-    else:
-        max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
+    # if tensor_max is None:  ##dynamic per-token
+    #     tensor = tensor.reshape(-1, orig_shape[-1])
+    #     max_tensor = torch.max(torch.abs(tensor), dim=-1)[
+    #                      0] * max_scale
+    # elif isinstance(tensor_max, torch.Tensor):
+    #     max_tensor = tensor_max.clone().detach().to(tensor.device) * max_scale
+    # else:
+    #     max_tensor = torch.tensor(tensor_max).to(tensor.device) * max_scale
+    max_tensor = torch.max(torch.abs(tensor))
     scale = max_tensor.to(torch.float32) / info.max
     min_scaling_factor = float(1.0 / (info.max * 512.0))  ##copy from vllm
     scale = torch.clip(scale, min=min_scaling_factor)
     if tensor.dtype == torch.float16:  ## Avoid NaN gradients with float16
         tensor = tensor.to(torch.bfloat16)
-    scale = scale.unsqueeze(dim=-1)
+    # scale = scale.unsqueeze(dim=-1)
+    scale = torch.ones((1), device=tensor.device)
     fp8_res = (tensor / scale)
     fp8_res = torch.clip(fp8_res, info.min, info.max)
-    fp8_res = float8_e4m3fn_ste(fp8_res)
+    fp8_res = fp8_res.to(torch.float8_e5m2).to(torch.bfloat16)
     qdq_res = fp8_res * scale
     qdq_res = qdq_res.to(orig_dtype).reshape(orig_shape)
     return qdq_res, scale, None
 
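The hunk above swaps the per-token float8_e4m3fn path for a per-tensor float8_e5m2 round trip and then pins the scale to 1.0. A minimal standalone sketch of the resulting quant-dequant behaviour, assuming a PyTorch build that exposes torch.float8_e5m2; this is not the library function and it omits the straight-through estimator used during tuning:

import torch

def qdq_fp8_e5m2_per_tensor(tensor: torch.Tensor):
    # Sketch of the patched quant_fp8_sym path: per-tensor E5M2 with the scale forced to 1.0.
    info = torch.finfo(torch.float8_e5m2)
    orig_dtype = tensor.dtype
    max_tensor = torch.max(torch.abs(tensor))                    # per-tensor abs-max
    scale = max_tensor.to(torch.float32) / info.max
    scale = torch.clip(scale, min=float(1.0 / (info.max * 512.0)))
    if tensor.dtype == torch.float16:                            # avoid NaN gradients with float16
        tensor = tensor.to(torch.bfloat16)
    scale = torch.ones((1,), device=tensor.device)               # hard-coded to 1.0, as in the patch
    fp8_res = torch.clip(tensor / scale, info.min, info.max)
    fp8_res = fp8_res.to(torch.float8_e5m2).to(torch.bfloat16)   # round trip through E5M2
    return (fp8_res * scale).to(orig_dtype), scale, None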
diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py
index 899abf17..39605364 100644
--- a/auto_round/data_type/int.py
+++ b/auto_round/data_type/int.py
@@ -123,11 +123,11 @@ def quant_tensor_asym_dq(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_
     scale = torch.clamp(scale, q_scale_thresh)
     wmin_m = wmin_m.view(-1, 1)
 
-    int_w = round_ste(tensor / scale + v)
-    q = torch.clamp(int_w + round_ste(wmin_m / scale), 0, maxq)
+    int_w = round_ste((tensor + wmin_m) / scale + v)
+    q = torch.clamp(int_w, 0, maxq)
     qdq_result = (scale * q - wmin_m).to(tensor.dtype)
     qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
-    zp = round_ste(wmin_m / scale)  # remove this later
+    # zp = round_ste(wmin_m / scale)  # remove this later
     return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin_m": wmin_m, "d_wmin_m": d_wmin_m}
 
 
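The int.py hunk folds wmin_m into the value before rounding instead of rounding a separate zero point and adding it afterwards; dequantization stays scale * q - wmin_m. A rough side-by-side of the two variants on plain tensors, with torch.round standing in for round_ste (so no straight-through gradient):

import torch

def qdq_asym_old(tensor, scale, wmin_m, maxq, v=0):
    int_w = torch.round(tensor / scale + v)
    q = torch.clamp(int_w + torch.round(wmin_m / scale), 0, maxq)   # zero point rounded separately
    return scale * q - wmin_m

def qdq_asym_new(tensor, scale, wmin_m, maxq, v=0):
    int_w = torch.round((tensor + wmin_m) / scale + v)              # wmin_m folded in before rounding
    q = torch.clamp(int_w, 0, maxq)
    return scale * q - wmin_m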
diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py
index beffe62a..b3df745b 100644
--- a/auto_round/export/export_to_autoround/export.py
+++ b/auto_round/export/export_to_autoround/export.py
@@ -133,6 +133,117 @@ def pack_qact_layer(name, model):
     qlayer.to(device)
 
 
+# def pack_layer(layer_name, model, backend):
+#     """
+#     Packs a model layer for quantization based on its type and configuration.
+#
+#     This function retrieves the specified layer from the model, checks its
+#     compatibility for quantization, and replaces it with a quantized version
+#     if applicable. The quantization process depends on the layer's bit-width,
+#     group size, symmetry, and activation bits.
+#
+#     Args:
+#         layer_name (str): The name of the layer to be packed.
+#         model (torch.nn.Module): The model containing the layer.
+#         backend (str): The backend framework to be used for quantization.
+#
+#     Returns:
+#         None: The function modifies the model in place.
+#     """
+#     layer = get_module(model, layer_name)
+#     if hasattr(layer, "orig_layer"):
+#         layer = layer.orig_layer
+#
+#     if not isinstance(layer, supported_layer_types):  ##already packed
+#         return
+#
+#     if int(layer.act_bits) <= 8:
+#         return pack_qact_layer(layer_name, model)
+#
+#     if not check_to_quantized(layer):
+#         return
+#
+#     device = layer.weight.device
+#     bits = layer.bits
+#     group_size = layer.group_size
+#     sym = layer.sym
+#     act_bits = layer.act_bits
+#
+#     scale = layer.scale
+#     zp = layer.zp
+#     QuantLinear = dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym, act_bits)
+#
+#     if isinstance(layer, nn.Linear):
+#         in_features = layer.in_features
+#         out_features = layer.out_features
+#     elif isinstance(layer, nn.Conv2d):
+#         in_features = layer.in_channels
+#         out_features = layer.out_channels
+#     elif isinstance(layer, transformers.pytorch_utils.Conv1D):
+#         in_features = layer.weight.shape[0]
+#         out_features = layer.weight.shape[1]
+#     bias = layer.bias is not None
+#
+#     if "awq" not in backend:
+#         new_layer = QuantLinear(  ##pylint: disable=E1123
+#             bits, group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype
+#         )
+#         new_layer.device = device
+#         set_module(model, layer_name, new_layer)
+#         qlayer = new_layer
+#         import auto_round.export.export_to_autoround.qlinear_triton
+#         if sym and isinstance(QuantLinear, (auto_round.export.export_to_autoround.qlinear_triton.QuantLinear,
+#                                             auto_round_extension.cuda.qlinear_tritonv2.QuantLinear)):
+#             zp = int(zp.flatten()[0])
+#
+#         qlayer.to("cpu")
+#         ##force to float32 to be compatible with torch 2.0
+#         sig = inspect.signature(qlayer.pack)
+#         param_count = len(sig.parameters)
+#         if param_count == 2:
+#             qlayer.pack(layer, scale)
+#         else:
+#             qlayer.pack(layer, scale, zp, None)
+#         qlayer.to(device)
+#     else:
+#         scale, zp = scale.to(torch.float32), zp.to(torch.float32)
+#         scale = scale.t().contiguous()
+#         zp = zp.t().contiguous()
+#         if sym:
+#             zp = int(zp.flatten()[0])
+#
+#         if bits != 4:
+#             logger.error("AutoAWQ format only supports 4-bits quantization.")
+#         qlayer = QuantLinear.from_linear(
+#             linear=layer,
+#             w_bit=bits,
+#             group_size=group_size,
+#             init_only=False,
+#             scales=scale,
+#             zeros=zp,
+#         )
+#         qlayer.to(device)
+#         set_module(model, layer_name, qlayer)
+
+
+class MyLinear(torch.nn.Module):
+    def __init__(self, in_features, out_features, bias=True, device=None,
+                 dtype=None):
+        factory_kwargs = {"device": device}
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = torch.nn.Parameter(
+            torch.empty((out_features, in_features), dtype=torch.float8_e5m2, **factory_kwargs)
+        )
+        if bias:
+            self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
+        else:
+            self.register_parameter("bias", None)
+        self.register_buffer('weight_scale', torch.ones((1), dtype=torch.bfloat16))
+
+
 def pack_layer(layer_name, model, backend):
     """
     Packs a model layer for quantization based on its type and configuration.
@@ -171,7 +282,10 @@ def pack_layer(layer_name, model, backend):
 
     scale = layer.scale
     zp = layer.zp
-    QuantLinear = dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym, act_bits)
+    weight = layer.weight
+    q_weight = weight / scale
+
+    # QuantLinear = dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym, act_bits)
 
     if isinstance(layer, nn.Linear):
         in_features = layer.in_features
@@ -183,47 +297,53 @@ def pack_layer(layer_name, model, backend):
         in_features = layer.weight.shape[0]
         out_features = layer.weight.shape[1]
     bias = layer.bias is not None
-
-    if "awq" not in backend:
-        new_layer = QuantLinear(  ##pylint: disable=E1123
-            bits, group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype
-        )
-        new_layer.device = device
-        set_module(model, layer_name, new_layer)
-        qlayer = new_layer
-        import auto_round.export.export_to_autoround.qlinear_triton
-        if sym and isinstance(QuantLinear, (auto_round.export.export_to_autoround.qlinear_triton.QuantLinear,
-                                            auto_round_extension.cuda.qlinear_tritonv2.QuantLinear)):
-            zp = int(zp.flatten()[0])
-
-        qlayer.to("cpu")
-        ##force to float32 to be compatible with torch 2.0
-        sig = inspect.signature(qlayer.pack)
-        param_count = len(sig.parameters)
-        if param_count == 2:
-            qlayer.pack(layer, scale)
-        else:
-            qlayer.pack(layer, scale, zp, None)
-        qlayer.to(device)
-    else:
-        scale, zp = scale.to(torch.float32), zp.to(torch.float32)
-        scale = scale.t().contiguous()
-        zp = zp.t().contiguous()
-        if sym:
-            zp = int(zp.flatten()[0])
-
-        if bits != 4:
-            logger.error("AutoAWQ format only supports 4-bits quantization.")
-        qlayer = QuantLinear.from_linear(
-            linear=layer,
-            w_bit=bits,
-            group_size=group_size,
-            init_only=False,
-            scales=scale,
-            zeros=zp,
-        )
-        qlayer.to(device)
-        set_module(model, layer_name, qlayer)
+    my_linear = MyLinear(in_features, out_features, bias)
+    my_linear.weight_scale.data.copy_(scale)
+    my_linear.weight.data.copy_(q_weight.to(torch.float8_e5m2))
+    if bias:
+        my_linear.bias.data.copy_(layer.bias)
+
+    #
+    # if "awq" not in backend:
+    #     new_layer = QuantLinear(  ##pylint: disable=E1123
+    #         bits, group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype
+    #     )
+    #     new_layer.device = device
+    #     set_module(model, layer_name, new_layer)
+    #     qlayer = new_layer
+    #     import auto_round.export.export_to_autoround.qlinear_triton
+    #     if sym and isinstance(QuantLinear, (auto_round.export.export_to_autoround.qlinear_triton.QuantLinear,
+    #                                         auto_round_extension.cuda.qlinear_tritonv2.QuantLinear)):
+    #         zp = int(zp.flatten()[0])
+    #
+    #     qlayer.to("cpu")
+    #     ##force to float32 to be compatible with torch 2.0
+    #     sig = inspect.signature(qlayer.pack)
+    #     param_count = len(sig.parameters)
+    #     if param_count == 2:
+    #         qlayer.pack(layer, scale)
+    #     else:
+    #         qlayer.pack(layer, scale, zp, None)
+    #     qlayer.to(device)
+    # else:
+    #     scale, zp = scale.to(torch.float32), zp.to(torch.float32)
+    #     scale = scale.t().contiguous()
+    #     zp = zp.t().contiguous()
+    #     if sym:
+    #         zp = int(zp.flatten()[0])
+    #
+    #     if bits != 4:
+    #         logger.error("AutoAWQ format only supports 4-bits quantization.")
+    #     qlayer = QuantLinear.from_linear(
+    #         linear=layer,
+    #         w_bit=bits,
+    #         group_size=group_size,
+    #         init_only=False,
+    #         scales=scale,
+    #         zeros=zp,
+    #     )
+    my_linear.to(device)
+    set_module(model, layer_name, my_linear)
 
 
 def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:exllamav2", **kwargs):
@@ -261,6 +381,8 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
     layer_config = kwargs["layer_config"]
     quantization_config = kwargs["serialization_dict"]
     quantization_config["quant_method"] = "auto-round"
+    quantization_config["fmt"] = "e5m2"
+    quantization_config["activation_scheme"] = "dynamic"
     if quantization_config["bits"] == 3:
         backend = "auto_round:auto_gptq"
     quantization_config["packing_format"] = backend
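The rewritten pack_layer above stores each quantized linear as a MyLinear holding a float8_e5m2 weight plus a bfloat16 weight_scale buffer. A rough round-trip sketch of that container; the per-tensor abs-max scale and the helper names are illustrative only (the patch copies layer.scale produced during tuning instead), and it assumes the patched export.py is importable:

import torch
from auto_round.export.export_to_autoround.export import MyLinear  # class added in this patch

def pack_linear_to_e5m2(linear: torch.nn.Linear) -> MyLinear:
    # Illustrative per-tensor scale from the weight's abs-max.
    info = torch.finfo(torch.float8_e5m2)
    scale = (linear.weight.abs().max().to(torch.float32) / info.max).to(torch.bfloat16)
    packed = MyLinear(linear.in_features, linear.out_features, bias=linear.bias is not None)
    packed.weight_scale.data.copy_(scale)
    packed.weight.data.copy_((linear.weight / scale).to(torch.float8_e5m2))
    if linear.bias is not None:
        packed.bias.data.copy_(linear.bias)
    return packed

def dequant_weight(packed: MyLinear) -> torch.Tensor:
    # float8_e5m2 tensors cannot be used in matmul directly: upcast, then rescale.
    return packed.weight.to(torch.bfloat16) * packed.weight_scale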
diff --git a/auto_round/script/llm.py b/auto_round/script/llm.py
index b3685d9f..aeec7808 100644
--- a/auto_round/script/llm.py
+++ b/auto_round/script/llm.py
@@ -525,7 +525,7 @@ def tune(args):
             for file in os.listdir(eval_folder):
                 gguf_file = file
             user_model = AutoModelForCausalLM.from_pretrained(
-                eval_folder, gguf_file=gguf_file, device_map="auto" if use_auto_mapping else None)
+                eval_folder, gguf_file=gguf_file, device_map="auto")
             tokenizer = AutoTokenizer.from_pretrained(eval_folder, gguf_file=gguf_file)
         else:
             if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py
index a624d9f2..39f0d188 100644
--- a/auto_round/wrapper.py
+++ b/auto_round/wrapper.py
@@ -262,6 +262,8 @@ def _set_dict_attr(attr_dict, attr_name):
 
         if isinstance(scale, dict):
            _set_dict_attr(scale, "scale")
+        elif scale.numel() == 1:
+            self.orig_layer.scale = scale.to("cpu")
         else:
             self.orig_layer.scale = scale.reshape(shape[0], -1).to("cpu")
 
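The wrapper.py hunk adds a branch for per-tensor scales: a scalar scale is attached to the original layer unchanged instead of being reshaped to one row per output channel. A small sketch of that dispatch; attach_scale and out_features are illustrative names, not the wrapper's actual API:

import torch

def attach_scale(orig_layer, scale, out_features):
    # Hypothetical helper mirroring the branch added above.
    if isinstance(scale, dict):
        for name, value in scale.items():          # e.g. {"scale": ..., "d_scale": ...}
            setattr(orig_layer, name, value.to("cpu"))
    elif scale.numel() == 1:                       # per-tensor: keep the scalar as-is
        orig_layer.scale = scale.to("cpu")
    else:                                          # per-channel/group: one row per output channel
        orig_layer.scale = scale.reshape(out_features, -1).to("cpu")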