From 2198f6f8539848bd11ad44e136b4234cec9e308d Mon Sep 17 00:00:00 2001
From: 0x404 <871206929@qq.com>
Date: Sun, 22 Sep 2024 08:38:26 +0000
Subject: [PATCH 1/5] feat: add `fp16_inference` option to support fp16 inference

When `fp16_inference` is enabled, the model will be loaded with fp16
parameters for inference.
---
 libai/inference/basic.py                       |  3 +++
 libai/models/utils/model_loader/base_loader.py | 13 +++++++++++--
 projects/Llama/configs/llama_config.py         |  1 +
 3 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/libai/inference/basic.py b/libai/inference/basic.py
index b869e56cc..16dd5b8b6 100644
--- a/libai/inference/basic.py
+++ b/libai/inference/basic.py
@@ -82,6 +82,9 @@ def __init__(
             self.model._apply(dist.convert_to_distributed_default_setting)
         self.model = self.model.eval()
 
+        # Release unused memory from the device cache after loading the model
+        flow.cuda.empty_cache()
+
         # initial tokenizer
         if dist.is_main_process():
             self.tokenizer = self.build_tokenizer(self.cfg)
diff --git a/libai/models/utils/model_loader/base_loader.py b/libai/models/utils/model_loader/base_loader.py
index e5a58a22a..31eb09e9d 100644
--- a/libai/models/utils/model_loader/base_loader.py
+++ b/libai/models/utils/model_loader/base_loader.py
@@ -24,7 +24,7 @@
 from termcolor import colored
 
 import libai.utils.distributed as dist
-from libai.config import LazyCall
+from libai.config import LazyCall, try_get_key
 from libai.models.build import build_model
 
 logger = logging.getLogger(__name__)
@@ -374,6 +374,8 @@ def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
         self.base_model_prefix_2 = None  # prefix in LiBai
         self.origin_libai_cfg = copy.deepcopy(self.libai_cfg)
         self.changed_keys = set()  # Store the changed configuration
+        self.load_fp16_model = try_get_key(self.origin_libai_cfg, "fp16_inference", default=False)
+        self.load_fp16_model = bool(self.load_fp16_model)
 
     def _convert_tensor(self, tensor):
         """Convert PyTorch tensor to OneFlow tensor.
@@ -384,7 +386,10 @@ def _convert_tensor(self, tensor):
         Returns:
             flow.Tensor: The target tensor.
""" - return flow.Tensor(tensor.detach().cpu().numpy()) + if self.load_fp16_model: + return flow.tensor(tensor.detach().cpu().numpy(), dtype=flow.float16) + else: + return flow.tensor(tensor.detach().cpu().numpy()) def _convert_tensors(self, torch_state_dict): @@ -610,6 +615,10 @@ def load(self): self.model = build_model(self.model) else: self.model = build_model(LazyCall(self.model)(cfg=self.libai_cfg)) + + # If fp16 is specified, convert the model to half-precision before loading weights + if self.load_fp16_model: + self.model = self.model.half() # State_dict to global logger.info("transfering state_dict local to global...") diff --git a/projects/Llama/configs/llama_config.py b/projects/Llama/configs/llama_config.py index 36f95d126..8400b44a0 100644 --- a/projects/Llama/configs/llama_config.py +++ b/projects/Llama/configs/llama_config.py @@ -24,6 +24,7 @@ scale_mask_softmax_fusion=False, amp_enabled=True, # Inference + fp16_inference=True, is_encoder_decoder=False, max_length=256, min_length=0, From 49fc21eb25dadc1d67e5576c476eb159b6fca9ae Mon Sep 17 00:00:00 2001 From: Jianhua Zheng Date: Thu, 19 Sep 2024 10:12:42 +0800 Subject: [PATCH 2/5] update --- projects/ChatGLM/chatglm.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/projects/ChatGLM/chatglm.py b/projects/ChatGLM/chatglm.py index cea239878..fda48c103 100644 --- a/projects/ChatGLM/chatglm.py +++ b/projects/ChatGLM/chatglm.py @@ -180,7 +180,8 @@ def scaled_dot_product_attention( def forward(self, query_layer, key_layer, value_layer, attention_mask=None): # query_layer: [sq, b, np, hn] -[premute]-> [batch_size, head_num, seq_len, hidden_size] query_layer, key_layer, value_layer = [ - k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer] + # k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer] + k.transpose(1, 2).transpose(0, 2) for k in [query_layer, key_layer, value_layer] ] if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: context_layer = self.scaled_dot_product_attention( @@ -194,7 +195,8 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask=None): query_layer, key_layer, value_layer, attention_mask ) - context_layer = context_layer.permute(2, 0, 1, 3) + # context_layer = context_layer.permute(2, 0, 1, 3) + context_layer = context_layer.transpose(0, 1).transpose(0, 2) context_layer = context_layer.flatten(2) return context_layer @@ -709,7 +711,8 @@ def get_prompt(self, batch_size): ) # seq_len, b, nh, hidden_size past_key_values = self.dropout(past_key_values) - past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + # past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + past_key_values = past_key_values.transpose(0, 2).split(2) return past_key_values def forward( From 4edb33acd6dc00d375618269e017706f365cd7da Mon Sep 17 00:00:00 2001 From: 0x404 <871206929@qq.com> Date: Sun, 22 Sep 2024 17:03:40 +0800 Subject: [PATCH 3/5] support chatglm with fp16 inference --- projects/ChatGLM/configs/chatglm_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/projects/ChatGLM/configs/chatglm_config.py b/projects/ChatGLM/configs/chatglm_config.py index 9fec6f3b9..de4dac399 100644 --- a/projects/ChatGLM/configs/chatglm_config.py +++ b/projects/ChatGLM/configs/chatglm_config.py @@ -40,6 +40,7 @@ use_return_dict=True, amp_enabled=True, # Inference + fp16_inference=False, is_encoder_decoder=False, max_length=1350, min_length=0, From ec1a81a983f322fa6e51f6916b6f18a5163ba689 Mon Sep 17 00:00:00 2001 From: 0x404 
Date: Sun, 22 Sep 2024 17:05:22 +0800
Subject: [PATCH 4/5] Revert "update"

This reverts commit 49fc21eb25dadc1d67e5576c476eb159b6fca9ae.
---
 projects/ChatGLM/chatglm.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/projects/ChatGLM/chatglm.py b/projects/ChatGLM/chatglm.py
index fda48c103..cea239878 100644
--- a/projects/ChatGLM/chatglm.py
+++ b/projects/ChatGLM/chatglm.py
@@ -180,8 +180,7 @@ def scaled_dot_product_attention(
     def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
         # query_layer: [sq, b, np, hn] -[premute]-> [batch_size, head_num, seq_len, hidden_size]
         query_layer, key_layer, value_layer = [
-            # k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
-            k.transpose(1, 2).transpose(0, 2) for k in [query_layer, key_layer, value_layer]
+            k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
         ]
         if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
             context_layer = self.scaled_dot_product_attention(
@@ -195,8 +194,7 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
                 query_layer, key_layer, value_layer, attention_mask
             )
 
-        # context_layer = context_layer.permute(2, 0, 1, 3)
-        context_layer = context_layer.transpose(0, 1).transpose(0, 2)
+        context_layer = context_layer.permute(2, 0, 1, 3)
         context_layer = context_layer.flatten(2)
         return context_layer
 
@@ -711,8 +709,7 @@ def get_prompt(self, batch_size):
         )  # seq_len, b, nh, hidden_size
         past_key_values = self.dropout(past_key_values)
-        # past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
-        past_key_values = past_key_values.transpose(0, 2).split(2)
+        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
         return past_key_values
 
     def forward(

From c2a8ef6b182b39f1ec3960c1cf0c7a55872f7062 Mon Sep 17 00:00:00 2001
From: 0x404 <871206929@qq.com>
Date: Sun, 22 Sep 2024 17:10:55 +0800
Subject: [PATCH 5/5] set defaults to False

---
 projects/Llama/configs/llama_config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/projects/Llama/configs/llama_config.py b/projects/Llama/configs/llama_config.py
index 8400b44a0..98ef145e1 100644
--- a/projects/Llama/configs/llama_config.py
+++ b/projects/Llama/configs/llama_config.py
@@ -24,7 +24,7 @@
     scale_mask_softmax_fusion=False,
     amp_enabled=True,
     # Inference
-    fp16_inference=True,
+    fp16_inference=False,
     is_encoder_decoder=False,
     max_length=256,
     min_length=0,
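
Usage note: the snippet below is a minimal sketch of what the patched
`_convert_tensor` does when a project config sets `fp16_inference=True`
(e.g. in projects/Llama/configs/llama_config.py). It mirrors the conversion
logic only and is not the loader itself; the helper name `convert_tensor`
is illustrative, and it assumes `flow.tensor` accepts a NumPy array and a
`dtype` argument in the same way `torch.tensor` does.

    import oneflow as flow
    import torch

    def convert_tensor(tensor: torch.Tensor, fp16_inference: bool = False) -> flow.Tensor:
        """Copy a PyTorch tensor into a OneFlow tensor, downcasting to fp16 if requested."""
        array = tensor.detach().cpu().numpy()
        if fp16_inference:
            # Store the weight as fp16 on the OneFlow side, halving memory use.
            return flow.tensor(array, dtype=flow.float16)
        return flow.tensor(array)

    # fp32 PyTorch weights become fp16 OneFlow weights when fp16 inference is enabled.
    weight = torch.randn(4, 4)
    print(convert_tensor(weight, fp16_inference=True).dtype)  # oneflow.float16

With the option off (the default after PATCH 5/5), the conversion keeps the
source precision, so existing fp32 inference behaviour is unchanged.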