diff --git a/LICENSE b/LICENSE
index b4d15e39c..6f4236370 100644
--- a/LICENSE
+++ b/LICENSE
@@ -703,3 +703,28 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
+
+---------------------------------------------------------------------------------
+The implementations of qlora in federatedscope/llm/model/model_builder.py and
+federatedscope/llm/model/adapter_builder.py are adapted from
+https://github.com/artidoro/qlora (MIT License)
+
+Copyright (c) 2023 Artidoro Pagnoni, Tim Dettmers
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/federatedscope/core/auxiliaries/optimizer_builder.py b/federatedscope/core/auxiliaries/optimizer_builder.py
index 75a43b031..c88687df7 100644
--- a/federatedscope/core/auxiliaries/optimizer_builder.py
+++ b/federatedscope/core/auxiliaries/optimizer_builder.py
@@ -54,6 +54,15 @@ def get_optimizer(model, type, lr, **kwargs):
                                                   **tmp_kwargs)
             else:
                 return getattr(torch.optim, type)(model, lr, **tmp_kwargs)
+        elif 'bit.' in type:
+            type = type.split('.')[-1]
+            import bitsandbytes
+            if isinstance(model, torch.nn.Module):
+                return getattr(bitsandbytes.optim, type)(model.parameters(),
+                                                          lr, **tmp_kwargs)
+            else:
+                return getattr(bitsandbytes.optim, type)(model, lr,
+                                                          **tmp_kwargs)
         else:
             raise NotImplementedError(
                 'Optimizer {} not implement'.format(type))
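For reference, a minimal sketch of how the new bit. prefix is meant to be consumed; the toy model and the hyperparameter values are illustrative only and assume bitsandbytes is installed:

import torch

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer

# Any torch.nn.Module works here; a tiny linear layer keeps the example cheap.
model = torch.nn.Linear(16, 4)

# A type starting with 'bit.' is split on '.' and resolved against
# bitsandbytes.optim, mirroring how plain 'SGD' resolves against torch.optim.
optimizer = get_optimizer(model, type='bit.SGD', lr=3e-4, momentum=0.9)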
diff --git a/federatedscope/core/configs/cfg_computation_quantization.py b/federatedscope/core/configs/cfg_computation_quantization.py
new file mode 100644
index 000000000..73110faa9
--- /dev/null
+++ b/federatedscope/core/configs/cfg_computation_quantization.py
@@ -0,0 +1,42 @@
+import logging
+
+from federatedscope.core.configs.config import CN
+from federatedscope.register import register_config
+
+logger = logging.getLogger(__name__)
+
+
+def extend_computation_quantization_cfg(cfg):
+    # ---------------------------------------------------------------------- #
+    # quantization (for memory/computation efficiency) related options
+    # ---------------------------------------------------------------------- #
+    cfg.computation_quantization = CN()
+
+    # Params
+    # ['none', 'qlora']
+    cfg.computation_quantization.method = 'none'
+    cfg.computation_quantization.nbits = 4  # [4, 8, 16]
+
+    # --------------- register corresponding check function ----------
+    cfg.register_cfg_check_fun(assert_quant_cfg)
+
+
+def assert_quant_cfg(cfg):
+    if cfg.computation_quantization.method.lower() not in ['none', 'qlora']:
+        logger.warning(
+            'Computation quantization method is expected to be one of '
+            '["none", "qlora"], but got '
+            f'"{cfg.computation_quantization.method}". So we change it to '
+            '"none".')
+        cfg.computation_quantization.method = 'none'
+
+    if cfg.computation_quantization.method.lower() != 'none' and \
+            cfg.computation_quantization.nbits not in [4, 8, 16]:
+        raise ValueError(
+            f'The value of cfg.computation_quantization.nbits is invalid, '
+            f'which is expected to be one of [4, 8, 16] but got '
+            f'{cfg.computation_quantization.nbits}.')
+
+
+register_config("computation_quantization",
+                extend_computation_quantization_cfg)
diff --git a/federatedscope/core/trainers/torch_trainer.py b/federatedscope/core/trainers/torch_trainer.py
index 86c66e8f0..da637aa40 100644
--- a/federatedscope/core/trainers/torch_trainer.py
+++ b/federatedscope/core/trainers/torch_trainer.py
@@ -30,7 +30,13 @@ class GeneralTorchTrainer(Trainer):
     def get_model_para(self):
         if self.cfg.federate.process_num > 1 or \
                 self.cfg.federate.share_local_model or \
-                self.cfg.llm.deepspeed.use:
+                self.cfg.llm.deepspeed.use or \
+                self.cfg.computation_quantization.method == 'qlora':
+            # bitsandbytes quantization does not support model discharge
+            # provided by weiruikuang@gmail.com
+            # https://github.com/huggingface/transformers/blob/
+            # fb7d246951d5f60aa36a7958841dfea72f51fc6b/src/
+            # transformers/trainer.py#L506C9-L512C1
             return self._param_filter(self.ctx.model.state_dict())
         else:
             return self._param_filter(self.ctx.model.cpu().state_dict())
@@ -467,5 +473,6 @@ def discharge_model(self):
             return
 
         if not self.cfg.federate.share_local_model and \
-                not self.cfg.llm.deepspeed.use:
+                not self.cfg.llm.deepspeed.use and \
+                not self.cfg.computation_quantization.method == 'qlora':
             self.ctx.model.to(torch.device("cpu"))
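A minimal sketch of how the new options surface on the configuration object, assuming the standard global_cfg entry point; the values mirror the baseline YAML below:

from federatedscope.core.configs.config import global_cfg

cfg = global_cfg.clone()

# The sub-config registered by cfg_computation_quantization.py above.
cfg.computation_quantization.method = 'qlora'  # 'none' disables the feature
cfg.computation_quantization.nbits = 4

# With method == 'qlora', GeneralTorchTrainer.get_model_para() reads the
# state_dict without first moving the model to CPU, and discharge_model()
# leaves it on the GPU, since the quantized model cannot be discharged.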
diff --git a/federatedscope/llm/baseline/qlora.yaml b/federatedscope/llm/baseline/qlora.yaml
new file mode 100644
index 000000000..6e7ee6062
--- /dev/null
+++ b/federatedscope/llm/baseline/qlora.yaml
@@ -0,0 +1,47 @@
+use_gpu: True
+device: 0
+early_stop:
+  patience: 0
+federate:
+  mode: standalone
+  client_num: 10
+  total_round_num: 500
+  save_to: "llama.ckpt"
+  share_local_model: True
+data:
+  root: data/
+  type: 'alpaca@llm'
+  splits: [0.98,0.01,0.01]
+  splitter: 'iid'
+llm:
+  tok_len: 1000
+  chat:
+    max_len: 2000
+  adapter:
+    use: True
+    args: [ { 'adapter_package': 'peft', 'adapter_method': 'qlora', 'r': 8, 'lora_alpha': 16, 'lora_dropout': 0.05 } ]
+dataloader:
+  batch_size: 1
+model:
+  # type: 'decapoda-research/llama-7b-hf@huggingface_llm'
+  type: 'openlm-research/open_llama_3b@huggingface_llm'
+  # type: 'gpt2@huggingface_llm'
+train:
+  local_update_steps: 30
+  batch_or_epoch: batch
+  optimizer:
+    type: bit.SGD
+    lr: 0.0003
+    weight_decay: 0.0005
+    momentum: 0.9
+  is_enable_half: True
+criterion:
+  type: CrossEntropyLoss
+trainer:
+  type: llmtrainer
+eval:
+  freq: 2
+  metrics: ['loss']
+  count_flops: False
+computation_quantization:
+  method: qlora
\ No newline at end of file
diff --git a/federatedscope/llm/model/adapter_builder.py b/federatedscope/llm/model/adapter_builder.py
index d2b46ed63..61cace208 100644
--- a/federatedscope/llm/model/adapter_builder.py
+++ b/federatedscope/llm/model/adapter_builder.py
@@ -28,6 +28,7 @@ def enable_adapter(model, package, adapter, **kwargs):
         PEFT: https://github.com/huggingface/peft
         Support methods:
             LoRA
+            QLoRA
             Prefix Tuning
             P-Tuning
            Prompt Tuning
@@ -38,6 +39,54 @@ def enable_adapter(model, package, adapter, **kwargs):
             from peft import LoraConfig
             peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, **kwargs)
             model = get_peft_model(model, peft_config)
+        elif adapter == 'qlora':
+            # The implementation of QLoRA is adapted from
+            # https://github.com/artidoro/qlora
+            import bitsandbytes as bnb
+            from peft import LoraConfig
+            from peft.tuners.lora import LoraLayer
+
+            def find_all_linear_names(bits, model):
+                # Linear layers stored in the chosen precision become the
+                # LoRA target modules.
+                cls = bnb.nn.Linear4bit if bits == 4 else \
+                    (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
+                lora_module_names = set()
+                for name, module in model.named_modules():
+                    if isinstance(module, cls):
+                        names = name.split('.')
+                        lora_module_names.add(names[0] if len(names) ==
+                                              1 else names[-1])
+                if 'lm_head' in lora_module_names:  # needed for 16-bit
+                    lora_module_names.remove('lm_head')
+                return list(lora_module_names)
+
+            peft_config = LoraConfig(
+                r=kwargs['r'],
+                lora_alpha=kwargs['lora_alpha'],
+                target_modules=find_all_linear_names(bits=4, model=model),
+                lora_dropout=kwargs['lora_dropout'],
+                bias="none",
+                task_type=TaskType.CAUSAL_LM,
+            )
+            # Without the following line, the error
+            # `element 0 of tensors does not require grad
+            # and does not have a grad_fn`
+            # would be raised during backward; see
+            # https://github.com/huggingface/peft/issues/137
+            model.enable_input_require_grads()
+            model = get_peft_model(model, peft_config)
+            # Cast LoRA layers to fp16, keep norm layers in fp32 for
+            # stability, and downcast any fp32 lm_head/embedding weights.
+            for name, module in model.named_modules():
+                if isinstance(module, LoraLayer):
+                    module = module.to(torch.float16)
+                if 'norm' in name:
+                    module = module.to(torch.float32)
+                if 'lm_head' in name or 'embed_tokens' in name:
+                    if hasattr(module, 'weight'):
+                        if module.weight.dtype == torch.float32:
+                            module = module.to(torch.float16)
         elif adapter == 'prefix':
             from peft import PrefixTuningConfig
             peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM,
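Putting the pieces together, a rough sketch of the intended call sequence; it assumes a CUDA device with bitsandbytes available, reuses the checkpoint from the baseline YAML above, and relies on the patched loader shown in model_builder.py below:

from federatedscope.core.configs.config import global_cfg
from federatedscope.llm.model.model_builder import get_model_from_huggingface
from federatedscope.llm.model.adapter_builder import enable_adapter

cfg = global_cfg.clone()
cfg.device = 0  # single-GPU placement for the 4-bit weights
cfg.computation_quantization.method = 'qlora'  # triggers the 4-bit loader

# The base model is loaded in 4-bit NF4 and prepared for k-bit training ...
model = get_model_from_huggingface('openlm-research/open_llama_3b', cfg)

# ... then LoRA adapters are attached to every quantized linear layer.
model = enable_adapter(model, package='peft', adapter='qlora',
                       r=8, lora_alpha=16, lora_dropout=0.05)
model.print_trainable_parameters()  # only the adapter weights require grad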
diff --git a/federatedscope/llm/model/model_builder.py b/federatedscope/llm/model/model_builder.py
index 7bacb55d3..b082b3fdb 100644
--- a/federatedscope/llm/model/model_builder.py
+++ b/federatedscope/llm/model/model_builder.py
@@ -18,7 +18,31 @@
 
 def get_model_from_huggingface(model_name, config):
     kwargs = {}
     if len(config.llm.cache.model):
         kwargs['cache_dir'] = config.llm.cache.model
-
+    if config.computation_quantization.method == 'qlora':
+        from transformers import BitsAndBytesConfig
+        import torch
+        from peft import prepare_model_for_kbit_training
+        # Load the base model in 4-bit NF4 with double quantization and
+        # bfloat16 compute, then prepare it for k-bit (QLoRA) training.
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            load_in_4bit=True,
+            load_in_8bit=False,
+            device_map=config.device,
+            quantization_config=BitsAndBytesConfig(
+                load_in_4bit=True,
+                load_in_8bit=False,
+                llm_int8_threshold=6.0,
+                llm_int8_has_fp16_weight=False,
+                bnb_4bit_compute_dtype=torch.bfloat16,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type='nf4'),
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=False,
+            # use_auth_token=False
+        )
+        return prepare_model_for_kbit_training(model,
+                                               use_gradient_checkpointing=True)
     return AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
diff --git a/setup.py b/setup.py
index bf7a93db1..6cee76323 100644
--- a/setup.py
+++ b/setup.py
@@ -54,7 +54,10 @@
     'tokenizers==0.13.3',
     'transformers==4.29.2',
     'accelerate==0.20.3',
-    'peft==0.3.0',
+    # required by QLoRA: prepare_model_for_kbit_training
+    'peft==0.4.0',
+    # required by QLoRA
+    'bitsandbytes==0.41.1',
     'sentencepiece==0.1.99',
 ]
 
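The version bumps matter because prepare_model_for_kbit_training is only exported by peft from 0.4.0 onwards, while bitsandbytes supplies the 4-bit NF4 kernels and the bitsandbytes.optim optimizers used above. An illustrative sanity check for the two pinned dependencies:

import bitsandbytes
import peft

# prepare_model_for_kbit_training is the symbol model_builder.py relies on.
assert hasattr(peft, 'prepare_model_for_kbit_training'), \
    'peft>=0.4.0 is required for k-bit (QLoRA) training support'
print('peft', peft.__version__, '| bitsandbytes', bitsandbytes.__version__)

With the pinned versions installed, the baseline above can be launched through the usual FederatedScope entry point, e.g. python federatedscope/main.py --cfg federatedscope/llm/baseline/qlora.yaml.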