diff --git a/atomgpt/examples/inverse_model/config.json b/atomgpt/examples/inverse_model/config.json index e6008ea..4558a27 100644 --- a/atomgpt/examples/inverse_model/config.json +++ b/atomgpt/examples/inverse_model/config.json @@ -11,10 +11,10 @@ "per_device_train_batch_size": 2, "gradient_accumulation_steps": 4, "num_train": 2, - "num_val": 2, + "num_val": 0, "num_test": 2, "model_save_path": "lora_model_m", - "loss_type": "atomgpt_structure", + "loss_type": "default", "optim": "adamw_8bit", "lr_scheduler_type": "linear", "output_dir": "outputs", @@ -26,4 +26,4 @@ "instruction": "Below is a description of a superconductor material.", "alpaca_prompt": "### Instruction:\n{}\n### Input:\n{}\n### Output:\n{}", "output_prompt": " Generate atomic structure description with lattice lengths, angles, coordinates and atom types." -} \ No newline at end of file +} diff --git a/atomgpt/inverse_models/_utils.py b/atomgpt/inverse_models/_utils.py index 6a1d3cf..b807631 100644 --- a/atomgpt/inverse_models/_utils.py +++ b/atomgpt/inverse_models/_utils.py @@ -1,16 +1,28 @@ import torch from typing import Union, Optional, List, Any, Callable import warnings -warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch") -warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub") -warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess") -warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers") -warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate") -warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub") + +warnings.filterwarnings(action="ignore", category=UserWarning, module="torch") +warnings.filterwarnings( + action="ignore", category=UserWarning, module="huggingface_hub" +) +warnings.filterwarnings( + action="ignore", category=RuntimeWarning, module="subprocess" +) +warnings.filterwarnings( + action="ignore", category=UserWarning, module="transformers" +) +warnings.filterwarnings( + action="ignore", category=FutureWarning, module="accelerate" +) +warnings.filterwarnings( + action="ignore", category=FutureWarning, module="huggingface_hub" +) import bitsandbytes as bnb from transformers.models.llama.modeling_llama import logger from transformers import AutoTokenizer from platform import system as platform_system + platform_system = platform_system() import math import numpy as np @@ -20,20 +32,27 @@ __version__ = "2024.5" # Get Flash Attention v2 if Ampere (RTX 30xx, A100) -major_version, minor_version = torch.cuda.get_device_capability() +try: + major_version, minor_version = torch.cuda.get_device_capability() + print("major_version major_version major_version", major_version) +except Exception as exp: + major_version = 7 + pass if major_version >= 8: try: from flash_attn import flash_attn_func + # Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl" try: from flash_attn.flash_attn_interface import flash_attn_cuda + HAS_FLASH_ATTENTION = True except: logger.warning_once( - "Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\ - "A possible explanation is you have a new CUDA version which isn't\n"\ - "yet compatible with FA2? 
Please file a ticket to Unsloth or FA2.\n"\ - "We shall now use Xformers instead, which gets a 0.01% performance hit.\n"\ + "Unsloth: Your Flash Attention 2 installation seems to be broken?\n" + "A possible explanation is you have a new CUDA version which isn't\n" + "yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n" + "We shall now use Xformers instead, which gets a 0.01% performance hit.\n" "We found this negligible impact by benchmarking on 1x A100." ) HAS_FLASH_ATTENTION = False @@ -44,6 +63,7 @@ HAS_FLASH_ATTENTION = False pass import xformers.ops.fmha as xformers + xformers_attention = xformers.memory_efficient_attention from xformers import __version__ as xformers_version @@ -62,9 +82,9 @@ def prepare_model_for_kbit_training( - model : Any, - use_gradient_checkpointing : Optional = True, - use_reentrant : Optional[bool] = True, + model: Any, + use_gradient_checkpointing: Optional = True, + use_reentrant: Optional[bool] = True, ) -> Any: """ Calculates where to place the gradient checkpoints given n_layers. @@ -83,15 +103,22 @@ def prepare_model_for_kbit_training( # Freeze all parameters except LoRA import re + with torch.no_grad(): for name, param in model.named_parameters(): - if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name: + if ( + ".lora_A." in name + or ".lora_B." in name + or ".lora_magnitude_vector" in name + ): param.requires_grad_(True) # Also must be in float32! if param.dtype != torch.float32: name = name.replace("base_model", "model", 1) layer_number = re.search(r"\.[\d]{1,}\.", name).group(0) - name = name.replace(layer_number, f"[{layer_number[1:-1]}].") + name = name.replace( + layer_number, f"[{layer_number[1:-1]}]." + ) name = name.replace(".weight", "", 1) exec(f"{name}.to(torch.float32)") pass @@ -110,7 +137,7 @@ def prepare_model_for_kbit_training( original_model = original_model.model pass original_model._offloaded_gradient_checkpointing = True - + model.gradient_checkpointing_enable() elif use_gradient_checkpointing == True: @@ -122,26 +149,35 @@ def prepare_model_for_kbit_training( if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: + def make_inputs_require_grad(module, input, output): output.requires_grad_(True) - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + model.get_input_embeddings().register_forward_hook( + make_inputs_require_grad + ) return model + + pass def patch_tokenizer(model, tokenizer): """ - Phi3's pad_token isn't set. We set it to <|placeholder... - Llama-3 is <|reserved... - Llama-2 is - Check if pad_token is not the same as eos_token otherwise the loss will ignore it!! - Fixes https://github.com/unslothai/unsloth/issues/5 + Phi3's pad_token isn't set. We set it to <|placeholder... + Llama-3 is <|reserved... + Llama-2 is + Check if pad_token is not the same as eos_token otherwise the loss will ignore it!! 
+ Fixes https://github.com/unslothai/unsloth/issues/5 """ - possible_reserved_tokens = ("<|reserved", "<|placeholder",) + possible_reserved_tokens = ( + "<|reserved", + "<|placeholder", + ) if model is not None: - model.config.update({"unsloth_version" : __version__}) + model.config.update({"unsloth_version": __version__}) bad_pad_token = False if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is not None: @@ -155,7 +191,9 @@ def patch_tokenizer(model, tokenizer): if bad_pad_token: # Find a better pad token - added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()] + added_tokens = [ + str(x) for x in tokenizer.added_tokens_decoder.values() + ] possible_pad_token = None for added_token in added_tokens[::-1]: if added_token.startswith(possible_reserved_tokens): @@ -181,14 +219,18 @@ def patch_tokenizer(model, tokenizer): logger.warning_once( f"{name} does not have a padding token! Will use pad_token = {possible_pad_token}." ) - + # Edit pad_token - tokenizer.add_special_tokens({"pad_token" : possible_pad_token}) + tokenizer.add_special_tokens({"pad_token": possible_pad_token}) tokenizer.pad_token = possible_pad_token if model is not None: - config = model.config.update({"pad_token_id" : tokenizer.pad_token_id}) + config = model.config.update( + {"pad_token_id": tokenizer.pad_token_id} + ) pass return model, tokenizer + + pass @@ -196,13 +238,14 @@ def patch_tokenizer(model, tokenizer): # For mixed precision, we need it to be in float32 not float16. from peft.tuners.lora.layer import LoraLayer import inspect, re + try: source = inspect.getsource(LoraLayer.update_layer) text = "if weight is not None:\n" start = source.find(text) + len(text) end = source.find("self.to(weight.device)", start) - spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0] - source = source.replace(source[start : end], spaces) + spaces = re.findall(r"^([ ]{1,})break", source, flags=re.MULTILINE)[0] + source = source.replace(source[start:end], spaces) spaces = len(re.match(r"[\s]{1,}", source).group(0)) lines = source.split("\n") source = "\n".join(x[spaces:] for x in lines) @@ -212,12 +255,14 @@ def patch_tokenizer(model, tokenizer): # Fix up incorrect downcasting of LoRA weights from peft.tuners.lora.layer import LoraLayer + LoraLayer.update_layer = LoraLayer_update_layer from peft.tuners.lora import LoraLayer + LoraLayer.update_layer = LoraLayer_update_layer except: logger.warning_once( - "Unsloth unsuccessfully patched LoraLayer.update_layer. Please file a bug report.\n"\ + "Unsloth unsuccessfully patched LoraLayer.update_layer. Please file a bug report.\n" "Luckily, your training run will still work in the meantime!" ) pass @@ -229,20 +274,33 @@ def get_statistics(): # This is simply so we can check if some envs are broken or not. 
try: from huggingface_hub import hf_hub_download - from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled + from huggingface_hub.utils import ( + disable_progress_bars, + enable_progress_bars, + are_progress_bars_disabled, + ) import psutil - n_cpus = psutil.cpu_count(logical = False) + + n_cpus = psutil.cpu_count(logical=False) keynames = "\n" + "\n".join(os.environ.keys()) statistics = None - if "\nCOLAB_" in keynames and n_cpus == 1: statistics = "colab" - elif "\nCOLAB_" in keynames: statistics = "colabpro" - elif "\nKAGGLE_" in keynames: statistics = "kaggle" - elif "\nRUNPOD_" in keynames: statistics = "runpod" - elif "\nAWS_" in keynames: statistics = "aws" - elif "\nAZURE_" in keynames: statistics = "azure" - elif "\nK_" in keynames or "\nFUNCTION_" in keynames: statistics = "gcp" - elif "\nINVOCATION_ID" in keynames: statistics = "lambda" + if "\nCOLAB_" in keynames and n_cpus == 1: + statistics = "colab" + elif "\nCOLAB_" in keynames: + statistics = "colabpro" + elif "\nKAGGLE_" in keynames: + statistics = "kaggle" + elif "\nRUNPOD_" in keynames: + statistics = "runpod" + elif "\nAWS_" in keynames: + statistics = "aws" + elif "\nAZURE_" in keynames: + statistics = "azure" + elif "\nK_" in keynames or "\nFUNCTION_" in keynames: + statistics = "gcp" + elif "\nINVOCATION_ID" in keynames: + statistics = "lambda" if statistics is not None: disabled = False @@ -250,23 +308,30 @@ def get_statistics(): disable_progress_bars() disabled = True pass - hf_hub_download(f"unslothai/statistics-{statistics}", "README.md", force_download = True) + hf_hub_download( + f"unslothai/statistics-{statistics}", + "README.md", + force_download=True, + ) if disabled: enable_progress_bars() pass pass except: pass + + pass def _calculate_n_gradient_checkpoints( - n_layers : int, - method : Optional[Union[str, int]] = "sqrt", + n_layers: int, + method: Optional[Union[str, int]] = "sqrt", ) -> List[int]: - assert(type(n_layers) is int and n_layers > 0) + assert type(n_layers) is int and n_layers > 0 - if method is None: method = "sqrt" + if method is None: + method = "sqrt" if method == "sqrt": n_checkpoints = int(n_layers**0.5) @@ -276,39 +341,45 @@ def _calculate_n_gradient_checkpoints( raise ValueError("method must be 'sqrt' or an int >0 and <= n_layers.") size = n_layers // n_checkpoints - sizes = np.full(n_checkpoints, size, dtype = int) + sizes = np.full(n_checkpoints, size, dtype=int) leftovers = n_layers % n_checkpoints # We append leftovers from the right for k in range(leftovers): - sizes[n_checkpoints-1-k] += 1 + sizes[n_checkpoints - 1 - k] += 1 boundaries = np.hstack((0, np.cumsum(sizes))) boundaries = boundaries.tolist() return boundaries + + pass def calculate_n_gradient_checkpoints( - n_layers : int, - layers_per_checkpoint : Optional[Union[str, int]] = "sqrt", + n_layers: int, + layers_per_checkpoint: Optional[Union[str, int]] = "sqrt", ) -> List[int]: - assert(type(n_layers) is int and n_layers > 0) + assert type(n_layers) is int and n_layers > 0 if layers_per_checkpoint is None or layers_per_checkpoint == 1: return None - boundaries = _calculate_n_gradient_checkpoints(n_layers, layers_per_checkpoint) + boundaries = _calculate_n_gradient_checkpoints( + n_layers, layers_per_checkpoint + ) - assert(boundaries[0] == 0 and boundaries[-1] == n_layers) - assert(min(boundaries) == 0 and max(boundaries) == n_layers) - assert(np.diff(boundaries).min() >= 0) + assert boundaries[0] == 0 and boundaries[-1] == n_layers + assert min(boundaries) == 0 and 
max(boundaries) == n_layers + assert np.diff(boundaries).min() >= 0 return boundaries + + pass def prepare_n_gradient_checkpoints( - model : Any, - layers_per_checkpoint : Optional[Union[str, int]] = "sqrt", - use_reentrant : Optional[bool] = True, + model: Any, + layers_per_checkpoint: Optional[Union[str, int]] = "sqrt", + use_reentrant: Optional[bool] = True, ) -> None: """ Calculates where to place the gradient checkpoints given n_layers. @@ -331,7 +402,9 @@ def prepare_n_gradient_checkpoints( if hasattr(model.model, "layers"): _model = model.model if _model is None: - raise TypeError("`model` or `model.model` does not have attribute `layers`. Are you sure this is a model?") + raise TypeError( + "`model` or `model.model` does not have attribute `layers`. Are you sure this is a model?" + ) pass if use_reentrant is False: @@ -339,9 +412,13 @@ def prepare_n_gradient_checkpoints( pass n_layers = len(_model.layers) - boundaries = calculate_n_gradient_checkpoints(n_layers, layers_per_checkpoint) - _model._gradient_checkpointing_boundaries = boundaries + boundaries = calculate_n_gradient_checkpoints( + n_layers, layers_per_checkpoint + ) + _model._gradient_checkpointing_boundaries = boundaries _model._gradient_checkpointing_use_reentrant = use_reentrant + + pass @@ -350,29 +427,39 @@ class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function): Saves VRAM by smartly offloading to RAM. Tiny hit to performance, since we mask the movement via non blocking calls. """ + @staticmethod @torch.cuda.amp.custom_fwd def forward(ctx, forward_function, hidden_states, *args): - saved_hidden_states = hidden_states.to("cpu", non_blocking = True) + saved_hidden_states = hidden_states.to("cpu", non_blocking=True) with torch.no_grad(): (output,) = forward_function(hidden_states, *args) ctx.save_for_backward(saved_hidden_states) ctx.forward_function = forward_function ctx.args = args return output + pass @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, dY): (hidden_states,) = ctx.saved_tensors - hidden_states = hidden_states.to("cuda", non_blocking = True).detach() + hidden_states = hidden_states.to("cuda", non_blocking=True).detach() hidden_states.requires_grad = True with torch.enable_grad(): (output,) = ctx.forward_function(hidden_states, *ctx.args) torch.autograd.backward(output, dY) - return (None, hidden_states.grad,) + (None,)*len(ctx.args) + return ( + None, + hidden_states.grad, + ) + ( + None, + ) * len(ctx.args) + pass + + pass @@ -380,30 +467,41 @@ def backward(ctx, dY): Remove warnings about missing kwargs """ try: - from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod + from transformers.utils.quantization_config import ( + BitsAndBytesConfig, + QuantizationMethod, + ) from inspect import getsource import re + BitsAndBytesConfig__init__ = getsource(BitsAndBytesConfig.__init__) BitsAndBytesConfig__init__ = re.sub( r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n", "", BitsAndBytesConfig__init__, - flags = re.MULTILINE, + flags=re.MULTILINE, ) BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n") - length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0)) - BitsAndBytesConfig__init__ = "\n".join(x[length_spaces:] for x in BitsAndBytesConfig__init__) + length_spaces = len( + re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0) + ) + BitsAndBytesConfig__init__ = "\n".join( + x[length_spaces:] for x in BitsAndBytesConfig__init__ + ) BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.replace( "__init__", 
"_BitsAndBytesConfig__init__", ) exec(BitsAndBytesConfig__init__, globals()) - + import transformers.utils.quantization_config - transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__ + + transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = ( + _BitsAndBytesConfig__init__ + ) except: logger.warning_once( - "Unsloth unsuccessfully patched bitsandbytes. Please file a bug report.\n"\ + "Unsloth unsuccessfully patched bitsandbytes. Please file a bug report.\n" "Luckily, your training run will still work in the meantime!" ) pass diff --git a/atomgpt/inverse_models/inverse_models.py b/atomgpt/inverse_models/inverse_models.py index ad841d9..9b2111e 100644 --- a/atomgpt/inverse_models/inverse_models.py +++ b/atomgpt/inverse_models/inverse_models.py @@ -191,7 +191,8 @@ def run_atomgpt_inverse(config_file="config.json"): dat.append(info) train_ids = ids[0:num_train] - test_ids = ids[num_train:] + test_ids = ids[num_train : num_train + num_test] + # test_ids = ids[num_train:] m_train = make_alpaca_json( dataset=dat, @@ -302,7 +303,7 @@ def run_atomgpt_inverse(config_file="config.json"): load_in_4bit=config.load_in_4bit, ) FastLanguageModel.for_inference(model) # Enable native 2x faster inference - + print("Testing\n\n\n\n", len(m_test)) f = open(config.csv_out, "w") f.write("id,target,prediction\n") diff --git a/atomgpt/inverse_models/kernels/utils.py b/atomgpt/inverse_models/kernels/utils.py index 9f56d20..321d5f9 100644 --- a/atomgpt/inverse_models/kernels/utils.py +++ b/atomgpt/inverse_models/kernels/utils.py @@ -1,67 +1,102 @@ import triton + MAX_FUSED_SIZE = 65536 next_power_of_2 = triton.next_power_of_2 + def calculate_settings(n): BLOCK_SIZE = next_power_of_2(n) if BLOCK_SIZE > MAX_FUSED_SIZE: - raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\ - f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.") + raise RuntimeError( + f"Cannot launch Triton kernel since n = {n} exceeds " + f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}." 
+ ) num_warps = 4 - if BLOCK_SIZE >= 32768: num_warps = 32 - elif BLOCK_SIZE >= 8192: num_warps = 16 - elif BLOCK_SIZE >= 2048: num_warps = 8 + if BLOCK_SIZE >= 32768: + num_warps = 32 + elif BLOCK_SIZE >= 8192: + num_warps = 16 + elif BLOCK_SIZE >= 2048: + num_warps = 8 return BLOCK_SIZE, num_warps + + pass import bitsandbytes as bnb + get_ptr = bnb.functional.get_ptr import ctypes import torch -cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 -cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4 -cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4 -cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16 -cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16 + +try: + cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 + cdequantize_blockwise_fp16_nf4 = ( + bnb.functional.lib.cdequantize_blockwise_fp16_nf4 + ) + cdequantize_blockwise_bf16_nf4 = ( + bnb.functional.lib.cdequantize_blockwise_bf16_nf4 + ) + cgemm_4bit_inference_naive_fp16 = ( + bnb.functional.lib.cgemm_4bit_inference_naive_fp16 + ) + cgemm_4bit_inference_naive_bf16 = ( + bnb.functional.lib.cgemm_4bit_inference_naive_bf16 + ) +except Exception as exp: + print("Check if running on GPU") + pass def QUANT_STATE(W): return getattr(W, "quant_state", None) + + pass def get_lora_parameters(proj): # For DPO or disabled adapters - base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj) + base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj W = base_layer.weight - if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: + if ( + not hasattr(proj, "disable_adapters") + or proj.disable_adapters + or proj.merged + ): return W, QUANT_STATE(W), None, None, None pass - active_adapter = proj.active_adapters[0] if \ - hasattr(proj, "active_adapters") else proj.active_adapter - A = proj.lora_A [active_adapter].weight - B = proj.lora_B [active_adapter].weight + active_adapter = ( + proj.active_adapters[0] + if hasattr(proj, "active_adapters") + else proj.active_adapter + ) + A = proj.lora_A[active_adapter].weight + B = proj.lora_B[active_adapter].weight s = proj.scaling[active_adapter] return W, QUANT_STATE(W), A, B, s + + pass -def fast_dequantize(W, quant_state = None, out = None): - if quant_state is None: return W +def fast_dequantize(W, quant_state=None, out=None): + if quant_state is None: + return W if type(quant_state) is not list: # New quant_state as a class # https://github.com/TimDettmers/bitsandbytes/pull/763/files - absmax = quant_state.absmax - shape = quant_state.shape - dtype = quant_state.dtype - blocksize = quant_state.blocksize - offset = quant_state.offset - state2 = quant_state.state2 - absmax2 = state2.absmax - code2 = state2.code + absmax = quant_state.absmax + shape = quant_state.shape + dtype = quant_state.dtype + blocksize = quant_state.blocksize + offset = quant_state.offset + state2 = quant_state.state2 + absmax2 = state2.absmax + code2 = state2.code blocksize2 = state2.blocksize else: # Old quant_state as a list of lists @@ -72,36 +107,54 @@ def fast_dequantize(W, quant_state = None, out = None): # Create weight matrix if out is None: - out = torch.empty(shape, dtype = dtype, device = "cuda") + out = torch.empty(shape, dtype=dtype, device="cuda") else: - assert(out.shape == shape) - assert(out.dtype == dtype) + assert out.shape == shape + assert out.dtype == dtype # NF4 dequantization of 
statistics n_elements_absmax = absmax.numel() - out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda") + out_absmax = torch.empty( + n_elements_absmax, dtype=torch.float32, device="cuda" + ) # Do dequantization ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( - get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax) + get_ptr(code2), + get_ptr(absmax), + get_ptr(absmax2), + ptr_out_absmax, + ctypes.c_int(blocksize2), + ctypes.c_int(n_elements_absmax), ) out_absmax += offset - fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ - cdequantize_blockwise_bf16_nf4 - fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes.c_int(blocksize), ctypes.c_int(out.numel())) + fx = ( + cdequantize_blockwise_fp16_nf4 + if dtype == torch.float16 + else cdequantize_blockwise_bf16_nf4 + ) + fx( + get_ptr(None), + get_ptr(W), + ptr_out_absmax, + get_ptr(out), + ctypes.c_int(blocksize), + ctypes.c_int(out.numel()), + ) # Careful returning transposed data - is_transposed = (True if W.shape[0] == 1 else False) + is_transposed = True if W.shape[0] == 1 else False return out.t() if is_transposed else out + + pass -def fast_gemv(X, W, quant_state, out = None): - if quant_state is None: return torch.matmul(X, W, out = out) +def fast_gemv(X, W, quant_state, out=None): + if quant_state is None: + return torch.matmul(X, W, out=out) # For fast X @ W where seq_len == 1 # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469 _, q_len, hd = X.shape @@ -109,18 +162,26 @@ def fast_gemv(X, W, quant_state, out = None): if type(quant_state) is not list: # https://github.com/TimDettmers/bitsandbytes/pull/763/files - absmax = quant_state.absmax - shape = quant_state.shape - dtype = quant_state.dtype - blocksize = quant_state.blocksize - stats = quant_state.code - offset = quant_state.offset - state2 = quant_state.state2 - absmax2 = state2.absmax - code2 = state2.code + absmax = quant_state.absmax + shape = quant_state.shape + dtype = quant_state.dtype + blocksize = quant_state.blocksize + stats = quant_state.code + offset = quant_state.offset + state2 = quant_state.state2 + absmax2 = state2.absmax + code2 = state2.code blocksize2 = state2.blocksize else: - absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state + ( + absmax, + shape, + dtype, + blocksize, + compressed_stats, + quant_type, + stats, + ) = quant_state offset, state2 = compressed_stats absmax2, code2, blocksize2, _, _, _, _ = state2 pass @@ -128,7 +189,15 @@ def fast_gemv(X, W, quant_state, out = None): bout = shape[0] if out is None: - out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda") + out = torch.empty( + ( + 1, + 1, + bout, + ), + dtype=dtype, + device="cuda", + ) # else: # assert(out.shape == (1, 1, bout,)) # pass @@ -138,7 +207,7 @@ def fast_gemv(X, W, quant_state, out = None): k = shape[1] lda = shape[0] ldc = shape[0] - ldb = (hd+1)//2 + ldb = (hd + 1) // 2 m = ctypes.c_int32(m) n = ctypes.c_int32(n) k = ctypes.c_int32(k) @@ -146,38 +215,60 @@ def fast_gemv(X, W, quant_state, out = None): ldb = ctypes.c_int32(ldb) ldc = ctypes.c_int32(ldc) - df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda") + df = torch.empty(absmax.shape, dtype=torch.float32, device="cuda") cdequantize_blockwise_fp32( - get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), - ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), + 
get_ptr(code2), + get_ptr(absmax), + get_ptr(absmax2), + get_ptr(df), + ctypes.c_int(blocksize2), + ctypes.c_int(df.numel()), ) df += offset absmax = df - fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \ - cgemm_4bit_inference_naive_bf16 + fx = ( + cgemm_4bit_inference_naive_fp16 + if dtype == torch.float16 + else cgemm_4bit_inference_naive_bf16 + ) blocksize = ctypes.c_int32(blocksize) - fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), - lda, ldb, ldc, blocksize) + fx( + m, + n, + k, + get_ptr(X), + get_ptr(W), + get_ptr(absmax), + get_ptr(stats), + get_ptr(out), + lda, + ldb, + ldc, + blocksize, + ) return out + + pass -def fast_linear_forward(proj, X, temp_lora = None, out = None): +def fast_linear_forward(proj, X, temp_lora=None, out=None): W, W_quant, lora_A, lora_B, lora_S = get_lora_parameters(proj) bsz, q_len, in_dim = X.shape - if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S) + if q_len != 1: + return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S) if W_quant is None: - out = torch.matmul(X, W.t(), out = out) + out = torch.matmul(X, W.t(), out=out) elif bsz == 1 and q_len == 1: - out = fast_gemv(X, W, W_quant, out = out) + out = fast_gemv(X, W, W_quant, out=out) else: W = fast_dequantize(W.t(), W_quant) - out = torch.matmul(X, W, out = out) + out = torch.matmul(X, W, out=out) pass # Add in LoRA weights @@ -189,24 +280,28 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): lora_A._fast_lora = lora_A.to(dtype) lora_B._fast_lora = lora_B.to(dtype) pass - + if bsz == 1: out = out.view(out_dim) - temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora) - out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S) + temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out=temp_lora) + out.addmv_(lora_B._fast_lora, temp_lora, alpha=lora_S) else: out = out.view(bsz, out_dim) - temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora) - out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S) + temp_lora = torch.mm( + X.view(bsz, in_dim), lora_A._fast_lora.t(), out=temp_lora + ) + out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha=lora_S) pass out = out.view(bsz, 1, out_dim) pass return out + + pass -def matmul_lora(X, W, W_quant, A, B, s, out = None): +def matmul_lora(X, W, W_quant, A, B, s, out=None): dtype = X.dtype W = fast_dequantize(W.t(), W_quant) @@ -218,14 +313,17 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): reshape = False pass - out = torch.matmul(X, W, out = out) - if W_quant is not None: del W + out = torch.matmul(X, W, out=out) + if W_quant is not None: + del W if A is not None: # LoRA is enabled A, B = A.t(), B.t() out += (X @ A.to(dtype)) @ (s * B.to(dtype)) pass - + return out.view(batch, seq_len, -1) if reshape else out + + pass diff --git a/atomgpt/inverse_models/loader.py b/atomgpt/inverse_models/loader.py index 33c5b2f..400c6e4 100644 --- a/atomgpt/inverse_models/loader.py +++ b/atomgpt/inverse_models/loader.py @@ -4,67 +4,76 @@ from transformers import AutoConfig from transformers import __version__ as transformers_version from peft import PeftConfig, PeftModel -from atomgpt.inverse_models.mapper import INT_TO_FLOAT_MAPPER, FLOAT_TO_INT_MAPPER +from atomgpt.inverse_models.mapper import ( + INT_TO_FLOAT_MAPPER, + FLOAT_TO_INT_MAPPER, +) import os +from transformers import AutoModelForCausalLM # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading! 
major, minor = transformers_version.split(".")[:2] major, minor = int(major), int(minor) SUPPORTS_FOURBIT = (major > 4) or (major == 4 and minor >= 37) -SUPPORTS_GEMMA = (major > 4) or (major == 4 and minor >= 38) +SUPPORTS_GEMMA = (major > 4) or (major == 4 and minor >= 38) if SUPPORTS_GEMMA: from atomgpt.inverse_models.gemma import FastGemmaModel del major, minor -def _get_model_name(model_name, load_in_4bit = True): +def _get_model_name(model_name, load_in_4bit=True): if not SUPPORTS_FOURBIT and model_name in INT_TO_FLOAT_MAPPER: model_name = INT_TO_FLOAT_MAPPER[model_name] logger.warning_once( - f"Unsloth: Your transformers version of {transformers_version} does not support native "\ - f"4bit loading.\nThe minimum required version is 4.37.\n"\ - f'Try `pip install --upgrade "transformers>=4.37"`\n'\ - f"to obtain the latest transformers build, then restart this session.\n"\ + f"AtomGPT: Your transformers version of {transformers_version} does not support native " + f"4bit loading.\nThe minimum required version is 4.37.\n" + f'Try `pip install --upgrade "transformers>=4.37"`\n' + f"to obtain the latest transformers build, then restart this session.\n" f"For now, we shall load `{model_name}` instead (still 4bit, just slower downloading)." ) - + elif not load_in_4bit and model_name in INT_TO_FLOAT_MAPPER: new_model_name = INT_TO_FLOAT_MAPPER[model_name] logger.warning_once( - f"Unsloth: You passed in `{model_name}` which is a 4bit model, yet you set\n"\ + f"AtomGPT: You passed in `{model_name}` which is a 4bit model, yet you set\n" f"`load_in_4bit = False`. We shall load `{new_model_name}` instead." ) model_name = new_model_name - elif load_in_4bit and SUPPORTS_FOURBIT and model_name in FLOAT_TO_INT_MAPPER: + elif ( + load_in_4bit and SUPPORTS_FOURBIT and model_name in FLOAT_TO_INT_MAPPER + ): new_model_name = FLOAT_TO_INT_MAPPER[model_name] logger.warning_once( - f"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\n"\ + f"AtomGPT: You passed in `{model_name}` and `load_in_4bit = True`.\n" f"We shall load `{new_model_name}` for 4x faster loading." 
) model_name = new_model_name pass return model_name + + pass class FastLanguageModel(FastLlamaModel): @staticmethod def from_pretrained( - model_name = "unsloth/llama-3-8b-bnb-4bit", - max_seq_length = None, - dtype = None, - load_in_4bit = True, - token = None, - device_map = "sequential", - rope_scaling = None, - fix_tokenizer = True, - trust_remote_code = False, - use_gradient_checkpointing = True, - resize_model_vocab = None, - *args, **kwargs, + model_name="unsloth/llama-3-8b-bnb-4bit", + max_seq_length=None, + dtype=None, + load_in_4bit=True, + token=None, + device_map="sequential", + rope_scaling=None, + fix_tokenizer=True, + trust_remote_code=False, + use_gradient_checkpointing=True, + resize_model_vocab=None, + *args, + **kwargs, ): if token is None and "HF_TOKEN" in os.environ: token = os.environ["HF_TOKEN"] @@ -78,47 +87,63 @@ def from_pretrained( # First check if it's a normal model via AutoConfig is_peft = False try: - model_config = AutoConfig.from_pretrained(model_name, token = token) + model_config = AutoConfig.from_pretrained(model_name, token=token) is_peft = False except: try: # Most likely a PEFT model - peft_config = PeftConfig.from_pretrained(model_name, token = token) + peft_config = PeftConfig.from_pretrained( + model_name, token=token + ) except: - raise RuntimeError(f"Unsloth: `{model_name}` is not a full model or a PEFT model.") - + raise RuntimeError( + f"AtomGPT: `{model_name}` is not a full model or a PEFT model." + ) + # Check base model again for PEFT - model_name = _get_model_name(peft_config.base_model_name_or_path, load_in_4bit) - model_config = AutoConfig.from_pretrained(model_name, token = token) + model_name = _get_model_name( + peft_config.base_model_name_or_path, load_in_4bit + ) + model_config = AutoConfig.from_pretrained(model_name, token=token) is_peft = True pass model_type = model_config.model_type - if model_type == "llama": dispatch_model = FastLlamaModel - elif model_type == "mistral": dispatch_model = FastMistralModel + if model_type == "llama": + dispatch_model = FastLlamaModel + elif model_type == "mistral": + dispatch_model = FastMistralModel elif model_type == "gemma": if not SUPPORTS_GEMMA: raise RuntimeError( - f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"\ - f"The minimum required version is 4.38.\n"\ - f'Try `pip install --upgrade "transformers>=4.38"`\n'\ - f"to obtain the latest transformers build, then restart this session."\ + f"AtomGPT: Your transformers version of {transformers_version} does not support Gemma.\n" + f"The minimum required version is 4.38.\n" + f'Try `pip install --upgrade "transformers>=4.38"`\n' + f"to obtain the latest transformers build, then restart this session." 
) dispatch_model = FastGemmaModel elif model_type == "qwen2": dispatch_model = FastQwen2Model else: + print("Trying unsupported model model_type", model_type) + dispatch_model = AutoModelForCausalLM.from_pretrained(model_type) raise NotImplementedError( - f"Unsloth: {model_name} not supported yet!\n"\ - "Make an issue to https://github.com/unslothai/unsloth!", + f"AtomGPT: {model_name} not supported yet!\n" + "Make an issue to https://github.com/usnistgov/atomgpt!", ) pass - + print("dispatch_model", dispatch_model) # Check if this is local model since the tokenizer gets overwritten - if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \ - os.path.exists(os.path.join(old_model_name, "tokenizer.json")) and \ - os.path.exists(os.path.join(old_model_name, "special_tokens_map.json")): + if ( + os.path.exists( + os.path.join(old_model_name, "tokenizer_config.json") + ) + and os.path.exists(os.path.join(old_model_name, "tokenizer.json")) + and os.path.exists( + os.path.join(old_model_name, "special_tokens_map.json") + ) + ): tokenizer_name = old_model_name else: @@ -126,56 +151,73 @@ def from_pretrained( pass model, tokenizer = dispatch_model.from_pretrained( - model_name = model_name, - max_seq_length = max_seq_length, - dtype = dtype, - load_in_4bit = load_in_4bit, - token = token, - device_map = device_map, - rope_scaling = rope_scaling, - fix_tokenizer = fix_tokenizer, - model_patcher = dispatch_model, - tokenizer_name = tokenizer_name, - trust_remote_code = trust_remote_code, - *args, **kwargs, + model_name=model_name, + max_seq_length=max_seq_length, + dtype=dtype, + load_in_4bit=load_in_4bit, + token=token, + device_map=device_map, + rope_scaling=rope_scaling, + fix_tokenizer=fix_tokenizer, + model_patcher=dispatch_model, + tokenizer_name=tokenizer_name, + trust_remote_code=trust_remote_code, + *args, + **kwargs, ) - + if resize_model_vocab is not None: model.resize_token_embeddings(resize_model_vocab) # In case the model supports tagging, add the unsloth tag. if hasattr(model, "add_model_tags"): - model.add_model_tags(["unsloth",]) + model.add_model_tags( + [ + "unsloth", + ] + ) pass if hasattr(tokenizer, "add_model_tags"): - tokenizer.add_model_tags(["unsloth",]) + tokenizer.add_model_tags( + [ + "unsloth", + ] + ) pass if load_in_4bit: # Fix up bitsandbytes config - quantization_config = \ - { + quantization_config = { # Sometimes torch_dtype is not a string!! - "bnb_4bit_compute_dtype" : model.config.to_dict()["torch_dtype"], - "bnb_4bit_quant_type" : "nf4", - "bnb_4bit_use_double_quant" : True, - "llm_int8_enable_fp32_cpu_offload" : False, - "llm_int8_has_fp16_weight" : False, - "llm_int8_skip_modules" : None, - "llm_int8_threshold" : 6.0, - "load_in_4bit" : True, - "load_in_8bit" : False, - "quant_method" : "bitsandbytes", + "bnb_4bit_compute_dtype": model.config.to_dict()[ + "torch_dtype" + ], + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_use_double_quant": True, + "llm_int8_enable_fp32_cpu_offload": False, + "llm_int8_has_fp16_weight": False, + "llm_int8_skip_modules": None, + "llm_int8_threshold": 6.0, + "load_in_4bit": True, + "load_in_8bit": False, + "quant_method": "bitsandbytes", } - model.config.update({"quantization_config" : quantization_config}) + model.config.update({"quantization_config": quantization_config}) pass if is_peft: # Now add PEFT adapters - model = PeftModel.from_pretrained(model, old_model_name, token = token) + model = PeftModel.from_pretrained( + model, old_model_name, token=token + ) # Patch it as well! 
- model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing) + model = dispatch_model.patch_peft_model( + model, use_gradient_checkpointing + ) pass return model, tokenizer + pass + + pass diff --git a/atomgpt/tests/test_inverse.py b/atomgpt/tests/test_inverse.py index 40ac02c..55abaf9 100644 --- a/atomgpt/tests/test_inverse.py +++ b/atomgpt/tests/test_inverse.py @@ -1,5 +1,25 @@ +from atomgpt.inverse_models.inverse_models import run_atomgpt_inverse +import os +import atomgpt + +config = os.path.join( + atomgpt.__path__[0], "examples", "inverse_model", "config.json" +) +print("config", config) + + +def run_inverse(): + try: + run_atomgpt_inverse(config) + except Exception as exp: + print("exp", exp) + pass + + +# atomgpt/examples/inverse_model/config.json +# run_inverse() # from atomgpt.inverse_models.mistral import FastMistralModel -#def test_inverse_model(): +# def test_inverse_model(): # f = FastMistralModel.from_pretrained() # print(f)
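
For reference, a minimal usage sketch mirroring the new helper in atomgpt/tests/test_inverse.py (this is an illustrative assumption of how the packaged example would be invoked, not part of the patch; it assumes atomgpt is installed with its bundled example config, and the full run likely still needs a CUDA-capable GPU for the 4-bit bitsandbytes path that the kernels/utils.py changes guard against):

import os

import atomgpt
from atomgpt.inverse_models.inverse_models import run_atomgpt_inverse

# Locate the example config shipped with the package (same lookup the test uses).
config = os.path.join(
    atomgpt.__path__[0], "examples", "inverse_model", "config.json"
)

# The test wraps this call in try/except so CPU-only environments do not hard-fail;
# on a GPU machine it runs the quick fine-tune plus inference loop end to end.
run_atomgpt_inverse(config_file=config)

With the example config as committed here (num_val = 0, loss_type = "default", num_train = num_test = 2), this is a smoke-test-sized run rather than a full training job.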