diff --git a/libai/inference/basic.py b/libai/inference/basic.py
index b869e56cc..16dd5b8b6 100644
--- a/libai/inference/basic.py
+++ b/libai/inference/basic.py
@@ -82,6 +82,9 @@ def __init__(
         self.model._apply(dist.convert_to_distributed_default_setting)
         self.model = self.model.eval()
 
+        # Release unused memory from the device cache after loading the model
+        flow.cuda.empty_cache()
+
         # initial tokenizer
         if dist.is_main_process():
             self.tokenizer = self.build_tokenizer(self.cfg)
diff --git a/libai/models/utils/model_loader/base_loader.py b/libai/models/utils/model_loader/base_loader.py
index e12294cd3..746176875 100644
--- a/libai/models/utils/model_loader/base_loader.py
+++ b/libai/models/utils/model_loader/base_loader.py
@@ -24,7 +24,7 @@ from termcolor import colored
 
 import libai.utils.distributed as dist
-from libai.config import LazyCall
+from libai.config import LazyCall, try_get_key
 from libai.models.build import build_model
 
 logger = logging.getLogger(__name__)
 
@@ -374,6 +374,8 @@ def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
         self.base_model_prefix_2 = None  # prefix in LiBai
         self.origin_libai_cfg = copy.deepcopy(self.libai_cfg)
         self.changed_keys = set()  # Store the changed configuration
+        self.load_fp16_model = try_get_key(self.origin_libai_cfg, "fp16_inference", default=False)
+        self.load_fp16_model = bool(self.load_fp16_model)
 
     def _convert_tensor(self, tensor):
         """Convert PyTorch tensor to OneFlow tensor.
@@ -388,8 +390,13 @@ def _convert_tensor(self, tensor):
         if tensor.dtype == torch.bfloat16:
             data = tensor.detach().half().cpu().numpy()
-            return flow.Tensor(data)
-        return flow.Tensor(tensor.detach().cpu().numpy())
+        else:
+            data = tensor.detach().cpu().numpy()
+
+        if self.load_fp16_model:
+            return flow.tensor(data, dtype=flow.float16)
+        else:
+            return flow.tensor(data)
 
     def _convert_tensors(self, torch_state_dict):
@@ -615,6 +622,10 @@ def load(self):
             self.model = build_model(self.model)
         else:
             self.model = build_model(LazyCall(self.model)(cfg=self.libai_cfg))
+
+        # If fp16 is specified, convert the model to half-precision before loading weights
+        if self.load_fp16_model:
+            self.model = self.model.half()
 
         # State_dict to global
         logger.info("transfering state_dict local to global...")
diff --git a/projects/ChatGLM/configs/chatglm_config.py b/projects/ChatGLM/configs/chatglm_config.py
index 9fec6f3b9..de4dac399 100644
--- a/projects/ChatGLM/configs/chatglm_config.py
+++ b/projects/ChatGLM/configs/chatglm_config.py
@@ -40,6 +40,7 @@
     use_return_dict=True,
     amp_enabled=True,
     # Inference
+    fp16_inference=False,
     is_encoder_decoder=False,
     max_length=1350,
     min_length=0,
diff --git a/projects/Llama/configs/llama_config.py b/projects/Llama/configs/llama_config.py
index 36f95d126..98ef145e1 100644
--- a/projects/Llama/configs/llama_config.py
+++ b/projects/Llama/configs/llama_config.py
@@ -24,6 +24,7 @@
     scale_mask_softmax_fusion=False,
     amp_enabled=True,
     # Inference
+    fp16_inference=False,
     is_encoder_decoder=False,
     max_length=256,
     min_length=0,
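
Usage sketch (illustrative, not part of the patch): with this change, half-precision weight loading can be toggled from a project config. The snippet below assumes the Llama config touched above; the loader reads the flag via try_get_key(..., "fp16_inference", default=False), converts the model with model.half() before loading, and materializes each checkpoint tensor as flow.tensor(data, dtype=flow.float16).

    # Assumed usage, not part of this patch.
    from projects.Llama.configs.llama_config import cfg

    # Enable half-precision weight loading; the model loader picks this
    # up from the config via try_get_key.
    cfg.fp16_inference = True
    # Any loader built from this cfg now loads weights in float16,
    # roughly halving parameter memory at load time.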