Prelim Feb release (#173)
* Works?

* Update pyproject.toml

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Swiglu

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update swiglu.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* Update fast_lora.py

* attention_mask

* Update llama.py

* Update llama.py

* labels

* Update mistral.py

* Update llama.py

* attention mask

* Update save.py

* Update save.py

* Update mistral.py

* attention mask

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update dpo.py

* Patch saving

* Update save.py

* Update save.py

* patch_saving_functions

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* print

* Mistral patch

* Update mistral.py

* Update save.py

* saving

* Update llama.py

* Update llama.py

* Fast inference repatch

* Update llama.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update mistral.py

* Update __init__.py

* Fix inference

* Update mistral.py

* fast lm_head

* Remove fast path

* Update rope_embedding.py

* Update loader.py

* LlamaAttention_fast_forward_inference

* if past_key_value is not None and q_len == 1:

* revert inference

* Update loader.py

* past_key_value

* Update llama.py

* Update llama.py

* Fix SDPA

* Update llama.py

* padding

* Inference

* Update llama.py

* Revert

* Update mistral.py

* faster inference

* inference

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* inference

* Update llama.py

* Update utils.py

* faster inference

* Update llama.py

* revert

* lm_head

* Update llama.py

* inference

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* faster inference

* Update llama.py

* fast inference

* Update llama.py

* Update llama.py

* Update mistral.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* torch compile

* past_key_values

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* Update llama.py

* fast inference + saving config.json

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update mistral.py

* fast inference again

* more temp matrices

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* fast inference

* Update mistral.py

* Update llama.py

* SDPA

* attention_mask

* New version

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update utils.py

* Update utils.py

* Update save.py

* Update save.py

* Torch 2.2.0

* Update save.py

* mistral swa

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Fix SWA inference

* Fix llm_int8_skip_modules

* SWA inference

* Update save.py

* Update save.py

* Update pyproject.toml

* __version__

* __version__

* Update save.py

* Update save.py

* Update mistral.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Chat Templates

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* Update chat_templates.py

* patch tokenizer

* Update chat_templates.py

* Saving, LlamaRotaryEmbedding issues

* Update llama.py

* Update mistral.py
danielhanchen authored Feb 14, 2024
1 parent 7c46209 commit a030e80
Showing 8 changed files with 494 additions and 6 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -29,7 +29,7 @@ All notebooks are **beginner friendly**! Add your dataset, click "Run All", and
| **CodeLlama 34b** A100 | [▶️ Start on Colab](https://colab.research.google.com/drive/1y7A0AxE3y8gdj4AVkl2aZX47Xu3P1wJT?usp=sharing) | 1.9x faster | 27% less |
| **Mistral 7b** 1xT4 | [▶️ Start on Kaggle](https://www.kaggle.com/code/danielhanchen/kaggle-mistral-7b-unsloth-notebook) | 5x faster\* | 62% less |

- This [conversational notebook](https://colab.research.google.com/drive/1bMOKOBzxQWUIGZBs_B0zm8pimuEnZdfM?usp=sharing) is useful for ShareGPT ChatML datasets.
- This [conversational notebook](https://colab.research.google.com/drive/1Aau3lgPzeZKQ-98h69CCu1UJcvIBLmy2?usp=sharing) is useful for ShareGPT ChatML / Vicuna templates.
- Our [raw text notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing) is useful for text completion.
- Colab sometimes provides a free GPU. Kaggle gives 30 hrs free per week, with a 12 hr cap per session.
- \* Kaggle has 2x T4s, but we use 1. Due to overhead, 1x T4 is 5x faster. Use Colab as Kaggle takes 10 mins to install.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -41,6 +41,7 @@ huggingface = [
"peft>=0.7.1",
"tqdm",
"psutil",
"wheel>=0.42.0",
]
cu118only = [
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
1 change: 1 addition & 0 deletions unsloth/__init__.py
@@ -82,3 +82,4 @@

from .models import *
from .save import *
from .chat_templates import *
384 changes: 384 additions & 0 deletions unsloth/chat_templates.py

Large diffs are not rendered by default.
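Since the new 384-line unsloth/chat_templates.py is collapsed in this view, here is a minimal usage sketch. The helper name `get_chat_template` and the `"chatml"` template key are assumptions inferred from the commit messages ("Chat Templates", "patch tokenizer"), not taken from the hidden diff itself:

```python
# Hypothetical usage sketch -- `get_chat_template` and the "chatml" key are
# assumptions about the new module's API, since the 384-line diff is collapsed.
from transformers import AutoTokenizer
from unsloth.chat_templates import get_chat_template

tokenizer = AutoTokenizer.from_pretrained("unsloth/mistral-7b-bnb-4bit")

# Attach a ChatML-style template so ShareGPT / conversational datasets
# can be rendered with tokenizer.apply_chat_template.
tokenizer = get_chat_template(tokenizer, chat_template = "chatml")

messages = [
    {"role": "user",      "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
]
print(tokenizer.apply_chat_template(messages, tokenize = False))
```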

10 changes: 7 additions & 3 deletions unsloth/models/_utils.py
@@ -16,6 +16,7 @@
from typing import Union, Optional, List, Any, Callable
import warnings
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
import bitsandbytes as bnb
from transformers.models.llama.modeling_llama import logger
from transformers import AutoTokenizer
@@ -116,21 +117,24 @@ def make_inputs_require_grad(module, input, output):


def patch_tokenizer(model, tokenizer):
model.config.update({"unsloth_version" : __version__})
if model is not None:
model.config.update({"unsloth_version" : __version__})
if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
# Fixes https://github.com/unslothai/unsloth/issues/5
if hasattr(tokenizer, "unk_token"):
tokenizer.add_special_tokens({"pad_token" : tokenizer.unk_token})
tokenizer.pad_token = tokenizer.unk_token
else:
name = model.config._name_or_path if model is not None else "Model"
logger.warning_once(
f"{model.config._name_or_path} does not have a padding or unknown token!\n"\
f"{name} does not have a padding or unknown token!\n"\
f"Will use the EOS token of id {tokenizer.eos_token_id} as padding."
)
assert(hasattr(tokenizer, "eos_token"))
tokenizer.add_special_tokens({"pad_token" : tokenizer.eos_token})
tokenizer.pad_token = tokenizer.eos_token
config = model.config.update({"pad_token_id" : tokenizer.eos_token_id})
if model is not None:
config = model.config.update({"pad_token_id" : tokenizer.eos_token_id})
pass
return model, tokenizer
pass
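The `if model is not None` guards above let patch_tokenizer run in tokenizer-only workflows. A minimal sketch of that path, assuming a working unsloth install (bitsandbytes, CUDA) and a checkpoint whose tokenizer ships without a pad token:

```python
# Sketch of the tokenizer-only path enabled by the `model is not None` guards.
# The checkpoint name is illustrative; any tokenizer lacking a pad_token works.
from transformers import AutoTokenizer
from unsloth.models._utils import patch_tokenizer

tokenizer = AutoTokenizer.from_pretrained("unsloth/mistral-7b-bnb-4bit")
print("pad token before:", tokenizer.pad_token)

# Previously this call assumed a model was passed; now model may be None.
_, tokenizer = patch_tokenizer(model = None, tokenizer = tokenizer)
print("pad token after :", tokenizer.pad_token)   # falls back to the unk (or eos) token
```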
78 changes: 77 additions & 1 deletion unsloth/models/llama.py
@@ -540,7 +540,7 @@ def LlamaModel_fast_forward(

hidden_states = inputs_embeds

if past_key_values is None and self.gradient_checkpointing and self.training:
if past_key_values is None and self.training:
use_cache = False
# if use_cache:
# logger.warning_once(
@@ -776,6 +776,73 @@ def PeftModelForCausalLM_fast_forward(
pass


# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
pass

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
pass

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
pass
pass


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
pass

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
pass
pass


class FastLlamaModel:

@staticmethod
@@ -787,6 +854,15 @@ def pre_patch():
LlamaModel .forward = LlamaModel_fast_forward
LlamaForCausalLM .forward = LlamaForCausalLM_fast_forward
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward

# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.llama.modeling_llama
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = LlamaRotaryEmbedding
transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = LlamaLinearScalingRotaryEmbedding
return
pass

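For reference, the cos/sin caching that the restored pre-4.38 LlamaRotaryEmbedding relies on can be reproduced standalone. This sketch mirrors the math in `_set_cos_sin_cache` and `forward` above, using tiny toy dimensions:

```python
# Standalone sketch of the cached rotary-embedding math shown in the diff
# (mirrors the pre-4.38 transformers implementation that pre_patch restores).
import torch

dim, base, max_pos = 8, 10000, 16
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

t = torch.arange(max_pos, dtype = inv_freq.dtype)
freqs = torch.outer(t, inv_freq)              # [max_pos, dim // 2]
emb = torch.cat((freqs, freqs), dim = -1)     # [max_pos, dim]
cos_cached, sin_cached = emb.cos(), emb.sin()

seq_len = 10                                  # any length <= max_pos hits the cache
cos, sin = cos_cached[:seq_len], sin_cached[:seq_len]
print(cos.shape, sin.shape)                   # torch.Size([10, 8]) for both
```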
8 changes: 8 additions & 0 deletions unsloth/models/mistral.py
@@ -271,6 +271,14 @@ def pre_patch():
MistralModel .forward = LlamaModel_fast_forward
MistralForCausalLM .forward = MistralForCausalLM_fast_forward
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward

# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.mistral.modeling_mistral
transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding = LlamaRotaryEmbedding
return
pass

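A quick, hedged way to confirm the monkey-patch took effect: the module paths follow the diffs above, but the class name `FastMistralModel` is an assumption (only `def pre_patch()` is visible here), and running this requires a full unsloth install on a supported GPU.

```python
# Illustrative check that pre_patch() re-pointed the HF rotary embedding class
# to the pre-4.38 implementation defined in unsloth/models/llama.py.
import transformers.models.mistral.modeling_mistral as mm
from unsloth.models.mistral import FastMistralModel   # class name assumed
from unsloth.models.llama import LlamaRotaryEmbedding

FastMistralModel.pre_patch()
print(mm.MistralRotaryEmbedding is LlamaRotaryEmbedding)   # expected: True
```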
16 changes: 15 additions & 1 deletion unsloth/save.py
@@ -41,6 +41,7 @@
"input_layernorm", "post_attention_layernorm",
)

# https://github.com/ggerganov/llama.cpp/blob/master/examples/quantize/quantize.cpp#L19
# From https://mlabonne.github.io/blog/posts/Quantize_Llama_2_models_using_ggml.html
ALLOWED_QUANTS = \
{
@@ -59,10 +60,16 @@
"q4_0" : "Original quant method, 4-bit.",
"q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
"q4_k_s" : "Uses Q4_K for all tensors",
"q4_k" : "alias for q4_k_m",
"q5_k" : "alias for q5_k_m",
"q5_0" : "Higher accuracy, higher resource usage and slower inference.",
"q5_1" : "Even higher accuracy, resource usage and slower inference.",
"q5_k_s" : "Uses Q5_K for all tensors",
"q6_k" : "Uses Q8_K for all tensors",
"iq2_xxs" : "2.06 bpw quantization",
"iq2_xs" : "2.31 bpw quantization",
"iq3_xxs" : "3.06 bpw quantization",
"q3_k_xs" : "3-bit extra small quantization",
}

def print_quantization_methods():
@@ -246,7 +253,8 @@ def unsloth_save_model(
# If push_to_hub, we must remove the .../ part of a repo
if push_to_hub and "/" in save_directory:

new_save_directory = save_directory[save_directory.find("/"):]
# +1 solves absolute path issues
new_save_directory = save_directory[save_directory.find("/")+1:]

logger.warning_once(
f"Unsloth: You are pushing to hub, but you passed your HF username.\n"\
@@ -861,10 +869,16 @@ def unsloth_save_pretrained_gguf(
"q4_0" : "Original quant method, 4-bit.",
"q4_1" : "Higher accuracy than q4_0 but not as high as q5_0. However has quicker inference than q5 models.",
"q4_k_s" : "Uses Q4_K for all tensors",
"q4_k" : "alias for q4_k_m",
"q5_k" : "alias for q5_k_m",
"q5_0" : "Higher accuracy, higher resource usage and slower inference.",
"q5_1" : "Even higher accuracy, resource usage and slower inference.",
"q5_k_s" : "Uses Q5_K for all tensors",
"q6_k" : "Uses Q8_K for all tensors",
"iq2_xxs" : "2.06 bpw quantization",
"iq2_xs" : "2.31 bpw quantization",
"iq3_xxs" : "3.06 bpw quantization",
"q3_k_xs" : "3-bit extra small quantization",
"""
if tokenizer is None:
raise ValueError("Unsloth: Saving to GGUF must have a tokenizer.")
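The `+1` added to the slice in unsloth_save_model drops the leading "/" along with the username, so the local save path stays relative. A tiny worked example with a made-up repo id:

```python
# Worked example of the save_directory fix above: without the +1 the slice
# keeps the "/", turning the local path into an (unintended) absolute path.
save_directory = "some-user/my-finetuned-model"   # hypothetical repo id

old_behaviour = save_directory[save_directory.find("/"):]      # "/my-finetuned-model"
new_behaviour = save_directory[save_directory.find("/") + 1:]  # "my-finetuned-model"
print(old_behaviour, new_behaviour)
```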
