Enhance layer-wise quant and fix bug (#27)
Signed-off-by: Wang, Mengni <[email protected]>
mengniwang95 authored Jul 24, 2024
1 parent 892f93d commit 4d596e1
Showing 14 changed files with 395 additions and 140 deletions.
@@ -326,10 +326,9 @@ def rewind(self):
if args.tune:
model_name = "model.onnx" # require optimum >= 1.14.0
model_path = os.path.join(args.model_path, model_name)

best_model = None
if args.algorithm.upper() == "RTN":
algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig()
algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig(layer_wise_quant=True)
quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
model_path,
n_bits=4,
@@ -358,7 +357,7 @@ def rewind(self):
elif args.algorithm.upper() == "GPTQ":
calibration_data_reader = GPTQDataloader(model_path, seqlen=args.seqlen, batch_size=1)
algo_config = matmul_nbits_quantizer.GPTQWeightOnlyQuantConfig(
calibration_data_reader=calibration_data_reader,
calibration_data_reader=calibration_data_reader, layer_wise_quant=True
)
quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
model_path,
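For reference, a minimal usage sketch of the enhanced path (the import path, file names, and any quantizer arguments beyond the config are assumptions, not taken from this diff): passing layer_wise_quant=True in the RTN or GPTQ config asks the quantizer to process the model split by split instead of loading it whole.

from onnx_neural_compressor.quantization import matmul_nbits_quantizer  # assumed import path

# RTN weight-only quantization with layer-wise processing enabled
algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig(layer_wise_quant=True)
quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
    "model.onnx",  # assumed path to a float ONNX model
    n_bits=4,
    algo_config=algo_config,
)
quant.process()  # assumed entry point, mirroring the example script above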
82 changes: 43 additions & 39 deletions onnx_neural_compressor/algorithms/layer_wise/core.py
@@ -46,17 +46,15 @@ def layer_wise_quant(
Returns:
_type_: _description_
"""
# check whether model shape is inferred
if not _check_model_with_infer_shapes(model):
logger.error(
"Before applying layer-wise quantization, please make sure to "
"run symbolic shape inference on your model like follows:\n"
"import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer\n"
"model = onnx.load(your_model_path)\n"
"out = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)\n"
"onnx.save(out, infer_shape_model_path)\n"
)
raise ValueError("Fail to run layer-wise quantization.")
logger.warning(
"Layer-wise quantization requires data_type info for some tensors. "
"We will try to infer the data_type automatically if it doesn't exist. "
"Alternatively, you can run symbolic shape inference on the model before layer-wise quantization, as follows:\n"
"import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer\n"
"model = onnx.load(your_model_path)\n"
"out = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(model, auto_merge=True)\n"
"onnx.save_model(out, infer_shape_model_path, save_as_external_data=True)\n"
)

if not isinstance(model, onnx_model.ONNXModel):
model = onnx_model.ONNXModel(model, ignore_warning=True, load_external_data=False)
@@ -101,7 +99,7 @@ def layer_wise_quant(
split_model = model_to_split.pop(0)
split_node = split_nodes.pop(0)
if require_data_reader:
current_data_reader = lwq_data_reader.pop(0)
complete_data_reader = lwq_data_reader.pop(0)

# if no remaining split nodes, it means this is the last split, and the two split models will be saved.
save_both_split_models = True if len(split_nodes) == 0 else False
@@ -110,20 +108,28 @@
split_model_part_1, split_model_part_2 = split_model.split_model_with_node(
split_node.name, model.model_path, save_both_split_models
)

if not save_both_split_models:
# append split_model_part_2 to do next split
model_to_split.append(split_model_part_2)

logger.info("Quantize split model {}".format(split_idx))

if require_data_reader:
# process data_reader for current split and next split
current_data_reader = _filter_data_reader_for_current_split_model(
split_model_part_1.model, current_data_reader
split_model_part_1.model, complete_data_reader
)
next_data_reader = _prepare_data_reader_for_next_split_model(
split_model_part_1.model_path, current_data_reader, providers

# complete_data_reader contains split_model_part_1 output data
complete_data_reader = _prepare_data_reader_for_next_split_model(
split_model_part_1.model_path,
[i.name for i in split_model_part_2.model.graph.input],
complete_data_reader,
providers,
)
lwq_data_reader.append(next_data_reader)

lwq_data_reader.append(complete_data_reader)

# perform quantization
split_model_part_1_quantized = quant_func(
@@ -141,7 +147,7 @@

# check split model is valid
try:
ort.InferenceSession(split_model_part_1_quantized.model.SerializeToString(), providers=providers)
ort.InferenceSession(split_model_part_1_quantized.model_path, providers=providers)
except Exception as e:
logger.error(
"Layer-wise quantized model {} can't be inferred correctly. "
@@ -166,7 +172,7 @@
# process data_reader for current split
current_data_reader = lwq_data_reader.pop(0)
current_data_reader = _filter_data_reader_for_current_split_model(
split_model_part_2.model, current_data_reader
split_model_part_2.model, complete_data_reader
)

# perform quantization
@@ -185,7 +191,7 @@

# check split model is valid
try:
ort.InferenceSession(split_model_part_2_quantized.model.SerializeToString(), providers=providers)
ort.InferenceSession(split_model_part_2_quantized.model_path, providers=providers)
except Exception as e:
logger.error(
"Layer-wise quantized model {} can't be inferred correctly. "
Expand All @@ -204,7 +210,6 @@ def layer_wise_quant(
onnx.external_data_helper.load_external_data_for_model(
quantized_model_merged.model, os.path.dirname(quantized_model_merged.model_path)
)

return quantized_model_merged


@@ -222,31 +227,38 @@ def rewind(self):
self.iter_next = iter(self.data_list)


def _filter_data_reader_for_current_split_model(model: onnx.ModelProto, data_reader: data_reader.CalibrationDataReader):
def _filter_data_reader_for_current_split_model(
model: onnx.ModelProto,
current_data_reader: data_reader.CalibrationDataReader,
):
"""Filter data reader to remove data that is not in model input.
Args:
model (onnx.ModelProto): onnx model.
data_reader (data_reader.CalibrationDataReader): data reader.
current_data_reader (data_reader.CalibrationDataReader): data reader of current split model.
Returns:
data_reader.CalibrationDataReader: filtered data reader.
"""
filter_inputs = []
input_names = [input.name for input in model.graph.input]
current_data_reader.rewind()

while True:
inputs = data_reader.get_next()
inputs = current_data_reader.get_next()
if not inputs:
break
filter_input = {
input_name: input_tensor for input_name, input_tensor in inputs.items() if input_name in input_names
}
filter_inputs.append(filter_input)

return DataReader(filter_inputs)


def _prepare_data_reader_for_next_split_model(
model_path: str,
next_model_input_names: list,
data_reader: data_reader.CalibrationDataReader,
providers: List[str] = ["CPUExecutionProvider"],
):
@@ -262,27 +274,19 @@ def _prepare_data_reader_for_next_split_model(
Returns:
data_reader.CalibrationDataReader: data reader for next split model.
"""
data_reader = copy.deepcopy(data_reader)

data_reader.rewind()
data_reader_for_next_split_model = []
session = ort.InferenceSession(model_path, providers=providers)
output_names = [output.name for output in session.get_outputs()]
input_names = [input.name for input in session.get_inputs()]
while True:
inputs = data_reader.get_next()
if not inputs:
break
out = session.run(None, inputs)
inputs.update({name: value for name, value in zip(output_names, out)})
data_reader_for_next_split_model.append(inputs)
out = session.run(None, {name: inputs[name] for name in input_names})
filter_input = {name: value for name, value in zip(output_names, out)}
for name, value in inputs.items():
if name in next_model_input_names and name not in filter_input:
filter_input[name] = value
data_reader_for_next_split_model.append(filter_input)
return DataReader(data_reader_for_next_split_model)


def _check_model_with_infer_shapes(model):
"""Check if the model has been shape inferred."""
if isinstance(model, (pathlib.Path, str)):
model = onnx.load(model, load_external_data=False)
elif isinstance(model, onnx_model.ONNXModel):
model = model.model
if len(model.graph.value_info) > 0:
return True
return False
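To make the data-reader fix above concrete, here is a minimal sketch of how calibration data is meant to flow between two consecutive splits (the split file names and calib_data contents are assumptions): inputs are filtered down to what the first split consumes, the split is run, and its outputs are merged with any original graph inputs the second split still needs — the case the new next_model_input_names argument handles.

import onnx
import onnxruntime as ort

part1_path = "split_model_part_1.onnx"  # assumed split file names
part2_path = "split_model_part_2.onnx"
calib_data = []  # fill with {input_name: numpy array} samples

sess1 = ort.InferenceSession(part1_path, providers=["CPUExecutionProvider"])
part1_input_names = [i.name for i in sess1.get_inputs()]
part1_output_names = [o.name for o in sess1.get_outputs()]
part2_input_names = [i.name for i in onnx.load(part2_path, load_external_data=False).graph.input]

next_calib_data = []
for sample in calib_data:
    # keep only the tensors the first split actually consumes
    feed = {k: v for k, v in sample.items() if k in part1_input_names}
    outs = sess1.run(None, feed)
    nxt = dict(zip(part1_output_names, outs))
    # carry over original graph inputs that the second split still needs
    for k, v in sample.items():
        if k in part2_input_names and k not in nxt:
            nxt[k] = v
    next_calib_data.append(nxt)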
13 changes: 8 additions & 5 deletions onnx_neural_compressor/algorithms/weight_only/awq.py
@@ -148,8 +148,10 @@ def _apply_awq_scale(model, weight_config, absorb_pairs, output_dicts):
if init_share_num == 1:
model.remove_initializer(weight_tensor)

if parent is None:
continue
parent = model.get_node(parent)
if parent.name in updated_nodes:
if parent is None or parent.name in updated_nodes:
continue

if parent.op_type in ["LayerNormalization", "BatchNormalization", "InstanceNormalization"] and len(
@@ -363,8 +365,9 @@ def awq_quantize(
output_name_to_node = model.output_name_to_node()
input_name_to_nodes = model.input_name_to_nodes()
for input_name in output_names:
parent = output_name_to_node[input_name]
dump_pairs = {parent.name: []}
# input_name may be a graph input, in which case there is no parent node
parent = output_name_to_node[input_name].name if input_name in output_name_to_node else None
dump_pairs = {parent: []}

for node in input_name_to_nodes[input_name]:
# check op_type of node is MatMul
@@ -375,9 +378,9 @@
and model.get_initializer(node.input[1]) is not None
and weight_config.get(node.name, {}).get("weight_dtype", "fp32") != "fp32"
):
dump_pairs[parent.name].append(model.get_node(node.name))
dump_pairs[parent].append(model.get_node(node.name))

if len(dump_pairs[parent.name]) == 0: # pragma: no cover
if len(dump_pairs[parent]) == 0: # pragma: no cover
continue

output_dicts = {}
5 changes: 3 additions & 2 deletions onnx_neural_compressor/algorithms/weight_only/gptq.py
@@ -334,7 +334,7 @@ def gptq_quantize(
k_blocks = (org_shape[0] + group_size - 1) // group_size
q_weight = quant_utils.pad_tensor(q_weight, group_size, k_blocks)
_, _, zp, scale, q_weight = quant_utils.quantize_data(
q_weight.T,
q_weight.T.reshape((-1, group_size)),
"uint" + str(num_bits),
sym,
axis=1,
@@ -345,7 +345,7 @@
num_bits=num_bits,
group_size=group_size,
k_blocks=k_blocks,
q_weight=q_weight.astype("uint8"),
q_weight=q_weight,
scale=scale.astype(dtype),
zero_point=zp if not sym else None,
accuracy_level=accuracy_level,
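The reshape above is the core of the fix: quantize_data now sees rows of exactly group_size elements, so scales and zero points are computed once per group rather than once per full row of the transposed weight. A self-contained numpy illustration of the same idea (the shapes, 4-bit range, and symmetric scheme are for illustration only, not the library's exact quantize_data):

import numpy as np

group_size = 32
weight = np.random.randn(128, 64).astype(np.float32)  # (in_features, out_features)
groups = weight.T.reshape(-1, group_size)              # one row per quantization group
scale = np.abs(groups).max(axis=1, keepdims=True) / 7  # symmetric 4-bit: codes in [-7, 7]
q = np.clip(np.round(groups / scale) + 8, 0, 15).astype(np.uint8)  # stored as unsigned codes
dequant = (q.astype(np.float32) - 8) * scale           # per-group reconstruction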
@@ -380,6 +380,7 @@ def gptq_quantize(
if return_modelproto:
return model.model
else:
model.save(model.model_path + "_quant.onnx")
return model


3 changes: 2 additions & 1 deletion onnx_neural_compressor/algorithms/weight_only/rtn.py
@@ -124,7 +124,7 @@ def rtn_quantize(
num_bits=num_bits,
group_size=group_size,
k_blocks=k_blocks,
q_weight=q_weight.astype("uint8"),
q_weight=q_weight,
scale=scale.astype(dtype),
zero_point=zp if not sym else None,
accuracy_level=accuracy_level,
@@ -167,6 +167,7 @@ def rtn_quantize(
if return_modelproto:
return model.model
else:
model.save(model.model_path + "_quant.onnx")
return model


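The model.save(...) lines added to both gptq_quantize and rtn_quantize write each quantized split to disk next to the original model path, which is what lets the layer-wise flow above create an InferenceSession from model_path and reload external data when merging. A rough plain-onnx sketch of that save-with-external-data pattern (file names assumed; not necessarily what ONNXModel.save does internally):

import onnx

model_proto = onnx.load("model_quant.onnx")  # assumed quantized model
onnx.save_model(
    model_proto,
    "model_quant_lwq.onnx",
    save_as_external_data=True,       # keep large initializers outside the protobuf
    all_tensors_to_one_file=True,
    location="model_quant_lwq.onnx.data",
)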