This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit 6fadb18
fix int8 skip module config (#1682)
Signed-off-by: Wang, Chang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
changwangss and pre-commit-ci[bot] authored Aug 9, 2024
1 parent 5df9c5f commit 6fadb18
Showing 1 changed file with 12 additions and 5 deletions.
intel_extension_for_transformers/transformers/utils/config.py (12 additions, 5 deletions)

--- a/intel_extension_for_transformers/transformers/utils/config.py
+++ b/intel_extension_for_transformers/transformers/utils/config.py
@@ -831,7 +831,10 @@ def __init__(
         self.double_quant_bits = double_quant_bits
         self.double_quant_use_sym = double_quant_use_sym
         self.double_quant_group_size = double_quant_group_size
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
+        # "transformer.output_layer" for chatglm series model.
+        # "embed_out" for dolly v2 series model.
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules",
+                                                ["lm_head", "transformer.output_layer", "embed_out"])
         self.use_ggml = use_ggml
         self.use_quant = use_quant
         self.use_neural_speed = use_neural_speed
@@ -913,7 +916,8 @@ def __init__(
         self.true_sequential = true_sequential
         self.layer_wise = layer_wise
         self.seq_len = seq_len
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules",
+                                                ["lm_head", "transformer.output_layer", "embed_out"])
         self.use_ggml = use_ggml
         self.use_quant = use_quant
         self.use_neural_speed = use_neural_speed
@@ -1011,7 +1015,8 @@ def __init__(
         self.seq_len = seq_len
         self.use_double_quant = use_double_quant
         self.double_quant_scale_dtype = double_quant_scale_dtype
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules",
+                                                ["lm_head", "transformer.output_layer", "embed_out"])
         self.use_ggml = use_ggml
         self.use_quant = use_quant
         self.use_neural_speed = use_neural_speed
@@ -1080,7 +1085,8 @@ def __init__(
         self.seq_len = seq_len
         self.use_double_quant = use_double_quant
         self.double_quant_scale_dtype = double_quant_scale_dtype
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules",
+                                                ["lm_head", "transformer.output_layer", "embed_out"])
         self.use_ggml = use_ggml
         self.use_neural_speed = use_neural_speed
         self.device = kwargs.get("device", "auto")
@@ -1156,7 +1162,8 @@ def __init__(
         self.iters = iters
         self.seq_len = seq_len
         self.quant_lm_head = quant_lm_head
-        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules", ["lm_head", "output_layer", "embed_out"])
+        self.llm_int8_skip_modules = kwargs.get("llm_int8_skip_modules",
+                                                ["lm_head", "transformer.output_layer", "embed_out"])
         if self.quant_lm_head:
             self.llm_int8_skip_modules = []
         self.use_ggml = use_ggml
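The commit applies the same one-line default change to five config classes: the bare module name "output_layer" does not match the fully qualified submodule name used by chatglm-style models (presumably what the skip list is compared against when walking the model's modules), so the layer was not being skipped during int8 quantization. Below is a minimal sketch of the default-vs-override behavior; DemoConfig is a hypothetical stand-in that only mirrors the kwargs.get(...) pattern from config.py, not a class in the library.

    # A minimal sketch, assuming a config class that follows the same
    # kwargs.get(...) pattern as the classes changed in this diff.
    class DemoConfig:
        def __init__(self, **kwargs):
            # Modules kept out of int8 quantization. "transformer.output_layer"
            # is the fully qualified name in chatglm series models; "embed_out"
            # is the output head in dolly v2 series models.
            self.llm_int8_skip_modules = kwargs.get(
                "llm_int8_skip_modules",
                ["lm_head", "transformer.output_layer", "embed_out"],
            )

    # Default: the corrected chatglm module name is picked up.
    print(DemoConfig().llm_int8_skip_modules)
    # ['lm_head', 'transformer.output_layer', 'embed_out']

    # An explicit kwarg still wins over the default, so existing callers
    # that pass their own skip list are unaffected by this change.
    print(DemoConfig(llm_int8_skip_modules=["lm_head"]).llm_int8_skip_modules)
    # ['lm_head']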
