Skip to content

Commit

Permalink
Gradient Accumulation Fix (#1134)
Browse files Browse the repository at this point in the history
* Unsloth Zoo

* Update trainer.py

* Update trainer.py

* Update cross_entropy_loss.py

* n_items

* Update llama.py

* kwargs

* Remove extraneous f prefixes (#1133)

Co-authored-by: Emil Sadek <[email protected]>

* Update __init__.py

---------

Co-authored-by: Emil Sadek <[email protected]>
Co-authored-by: Emil Sadek <[email protected]>
  • Loading branch information
3 people authored Oct 15, 2024
1 parent a2f4c97 commit 38663b0
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 544 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ exclude = ["images*"]

[project.optional-dependencies]
huggingface = [
"unsloth_zoo",
"packaging",
"tyro",
"transformers>=4.44.2",
Expand Down Expand Up @@ -210,6 +211,7 @@ colab-ampere-torch220 = [
"flash-attn>=2.6.3",
]
colab-new = [
"unsloth_zoo",
"packaging",
"tyro",
"transformers>=4.44.2",
Expand Down
9 changes: 8 additions & 1 deletion unsloth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
# pass
# pass

# Check for unsloth_zoo
try:
import unsloth_zoo
except:
raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth-zoo`")
pass

# Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so
# enabling it will require much more work, so we have to prioritize. Please understand!
# We do have a beta version, which you can contact us about!
Expand Down Expand Up @@ -124,7 +131,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16

# Try linking cuda folder, or everything in local
if len(possible_cudas) == 0:
os.system(f"ldconfig /usr/local/")
os.system("ldconfig /usr/local/")
else:
find_number = re.compile(r"([\d\.]{2,})")
latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0]
Expand Down
198 changes: 6 additions & 192 deletions unsloth/chat_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
from .tokenizer_utils import *
from .models._utils import patch_tokenizer
import re

from unsloth_zoo.dataset_utils import (
train_on_responses_only,
)
CHAT_TEMPLATES = {}

# =========================================== Unsloth
Expand Down Expand Up @@ -910,7 +912,7 @@ def get_chat_template(
# Check fast tokenizer
if not is_fast_tokenizer:
print(
f"Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
"Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
"Please log a Github issue if you want this as a new feature!\n"\
"Your chat template will still work, but it won't add or edit tokens."
)
Expand Down Expand Up @@ -1236,7 +1238,7 @@ def __convert_to_sharegpt__(examples):
n_extensions = max(conversation_extension-1, 0)
if n_extensions == 0: return dataset

dataset = dataset.rename_columns({"conversations" : f"conversations0"})
dataset = dataset.rename_columns({"conversations" : "conversations0"})
all_shuffled = [dataset]
for j in range(1, n_extensions+1):
shuffled = dataset.shuffle(seed = random_state+j).rename_columns({"conversations0" : f"conversations{j}"})
Expand All @@ -1254,7 +1256,7 @@ def __convert_to_sharegpt__(examples):
f"in zip({', '.join(f'conversations{j}__' for j in range(n_extensions))}):\n"
function += f"{' '*8}convos.append("\
f"{'+'.join(f'conversations{j}' for j in range(n_extensions))})\n"
function += f"{' '*4}return " + "{ " + f"'conversations' : convos" + " }"
function += f"{' '*4}return " + "{ " + "'conversations' : convos" + " }"

# Map function
exec(function, globals())
Expand Down Expand Up @@ -1812,194 +1814,6 @@ def formatting_prompts_func(examples):
pass


# From https://www.geeksforgeeks.org/longest-common-substring-array-strings/
# Longest Common Substring in an Array of Strings
def _longest_common_substring(arr):
n = len(arr)
s = arr[0]
l = len(s)
res = ""
for i in range(l):
for j in range(i + 1, l + 1):
stem = s[i:j]
k = 1
for k in range(1, n):
if stem not in arr[k]:
break
if (k + 1 == n and len(res) < len(stem)):
res = stem
return res
pass


def _find_common_token_ids(component, tokenizer):
"""
\n### User:\n\n
\n\n### User:\n\n
etc
we need to find the middle most repeatted part.
Tokenizers can tokenize newlines or spaces as 1 token!
"""
right_text = ""
if component.endswith (" "): right_text = " "
elif component.endswith("\n"): right_text = "\n"
left_text = ""
if component.startswith (" "): left_text = " "
elif component.startswith("\n"): left_text = "\n"
stripped = component.strip()

# Add current pieces and also newlines
all_input_ids = []
for left in range(3):
for right in range(3):
x = left*left_text + stripped + right*right_text
x = tokenizer(x, add_special_tokens = False).input_ids
all_input_ids.append(x)

x = left*"\n" + stripped + right*"\n"
x = tokenizer(x, add_special_tokens = False).input_ids
all_input_ids.append(x)
pass
pass
substring = _longest_common_substring([str(x + [0]) for x in all_input_ids])
substring = substring.split(", ")[:-1]
substring = [int(x) for x in substring]

# Also get rest of tokenized string
original = tokenizer(component, add_special_tokens = False).input_ids
# Get optional left and right
for j in range(len(original)):
if original[j : j + len(substring)] == substring: break
optional_left = original[:j]
optional_right = original[j+len(substring):]
return substring, optional_left, optional_right
pass


def train_on_responses_only(
trainer,
instruction_part = None,
response_part = None,
):
"""
Trains only on responses and not on the instruction by masking out
the labels with -100 for the instruction part.
"""
tokenizer = trainer.tokenizer

if not hasattr(tokenizer, "_unsloth_input_part") or \
not hasattr(tokenizer, "_unsloth_output_part"):

if instruction_part is None or response_part is None:
raise ValueError("Unsloth: instruction_part and response_part must be given!")
pass
elif (instruction_part is not None or response_part is not None) and \
(hasattr(tokenizer, "_unsloth_input_part") or hasattr(tokenizer, "_unsloth_output_part")):

raise ValueError("Unsloth: Your tokenizer already has instruction and response parts set - do not give custom ones!")
else:
instruction_part = tokenizer._unsloth_input_part
response_part = tokenizer._unsloth_output_part
pass

# Get most common tokens since tokenizers can tokenize stuff differently!
Q_must, Q_left, Q_right = _find_common_token_ids(instruction_part, tokenizer)
A_must, A_left, A_right = _find_common_token_ids(response_part, tokenizer)

# Store some temporary stuff
A_first = A_must[0]
len_A_must = len(A_must)
A_left_reversed = A_left[::-1]
A_right_forward = A_right

Q_first = Q_must[0]
len_Q_must = len(Q_must)
Q_left_reversed = Q_left[::-1]
Q_right_forward = Q_right

def _train_on_responses_only(examples):
input_ids_ = examples["input_ids"]
all_labels = []

for input_ids in input_ids_:
n = len(input_ids)
labels = [-100] * n
n_minus_1 = n - 1
j = 0
while j < n:
# Find <assistant>
if (input_ids[j] == A_first) and \
(input_ids[j : (k := j + len_A_must)] == A_must):

# Now backtrack to get previous optional tokens
for optional_left in A_left_reversed:
if j < 1: break
if optional_left == input_ids[j-1]: j -= 1
else: break
pass
# And forwards look as well
for optional_right in A_right_forward:
if k >= n_minus_1: break
if optional_right == input_ids[k+1]: k += 1
else: break
pass
# assistant_j = j
assistant_k = k

j = assistant_k
# Given <assistant>, now find next user
while j < n:
# Find <user>
# Also accept last final item if assistant is the last turn
if (j == n_minus_1) or \
((input_ids[j] == Q_first) and \
(input_ids[j : (k := j + len_Q_must)] == Q_must)):

# Now backtrack to get previous optional tokens
for optional_left in Q_left_reversed:
if j < 1: break
if optional_left == input_ids[j-1]: j -= 1
else: break
pass
# And forwards look as well
for optional_right in Q_right_forward:
if k >= n_minus_1: break
if optional_right == input_ids[k+1]: k += 1
else: break
pass
user_j = j
# Account for last item
if user_j != n_minus_1:
# user_k = k
# j = user_k
j = k
else:
user_j = n
k = n
pass
# Now copy input_ids to labels
labels[assistant_k : user_j] = input_ids[assistant_k : user_j]
# print(assistant_j, assistant_k, user_j, user_k)
break
pass
j += 1
pass
pass
j += 1
pass
all_labels.append(labels)
pass
return { "labels" : all_labels }
pass

if hasattr(trainer, "train_dataset") and trainer.train_dataset is not None:
trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True)
if hasattr(trainer, "eval_dataset") and trainer.eval_dataset is not None:
trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batched = True)
return trainer
pass


def create_stopping_criteria(tokenizer, stop_word = "eos_token"):
class StoppingCriteriaSub(StoppingCriteria):
__slots__ = "stop_token", "single_match", "length",
Expand Down
5 changes: 4 additions & 1 deletion unsloth/kernels/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def fast_cross_entropy_loss(
labels,
logit_softcapping = 0,
logit_scaling = 0,
n_items = None,
):
"""
Arguments:
Expand All @@ -372,7 +373,8 @@ def fast_cross_entropy_loss(
logit_softcapping,
logit_scaling,
)
n_items = torch.count_nonzero(labels != -100)
if n_items is None:
n_items = torch.count_nonzero(labels != -100)
return loss.sum() / n_items
pass

Expand Down Expand Up @@ -409,6 +411,7 @@ def fast_cross_entropy_loss(
labels = shift_labels,
logit_softcapping = logit_softcapping,
logit_scaling = logit_scaling,
n_items = kwargs.get("n_items", None),
)
else:
if logit_scaling != 0:
Expand Down
2 changes: 1 addition & 1 deletion unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "2024.9.post4"
__version__ = "2024.10.0"

__all__ = [
"prepare_model_for_kbit_training",
Expand Down
7 changes: 4 additions & 3 deletions unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,13 +975,14 @@ def _CausalLM_fast_forward(
# Fixes https://github.com/unslothai/unsloth/issues/10
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0")
pass

shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
logit_softcapping = logit_softcapping,
logit_scaling = logit_scaling,
n_items = kwargs.get("n_items", None),
)
else:
if logit_scaling != 0:
Expand Down Expand Up @@ -2019,8 +2020,8 @@ def get_peft_model(
if loftq_config == {}:
from peft import LoftQConfig
logger.warning_once(
f"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"\
f"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
"Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"\
"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
)
loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
pass
Expand Down
8 changes: 4 additions & 4 deletions unsloth/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def unsloth_save_model(
# max_ram = max(max_ram - W.nbytes, 0)
else:
# Save to Disk
logger.warning_once(f"We will save to Disk and not RAM now.")
logger.warning_once("We will save to Disk and not RAM now.")
filename = os.path.join(temporary_location, f"{name}.pt")
torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,)
# weights_only = True weirdly fails?
Expand Down Expand Up @@ -1460,7 +1460,7 @@ def fix_tokenizer_bos_token(tokenizer):

fix_bos_token = True
logger.warning(
f"Unsloth: ##### The current model auto adds a BOS token.\n"\
"Unsloth: ##### The current model auto adds a BOS token.\n"\
"Unsloth: ##### Your chat template has a BOS token. We shall remove it temporarily."
)

Expand Down Expand Up @@ -1671,7 +1671,7 @@ def unsloth_save_pretrained_gguf(

if fix_bos_token:
logger.warning(
f"Unsloth: ##### The current model auto adds a BOS token.\n"\
"Unsloth: ##### The current model auto adds a BOS token.\n"\
"Unsloth: ##### We removed it in GGUF's chat template for you."
)
pass
Expand Down Expand Up @@ -1867,7 +1867,7 @@ def unsloth_push_to_hub_gguf(

if fix_bos_token:
logger.warning(
f"Unsloth: ##### The current model auto adds a BOS token.\n"\
"Unsloth: ##### The current model auto adds a BOS token.\n"\
"Unsloth: ##### We removed it in GGUF's chat template for you."
)
pass
Expand Down
Loading

0 comments on commit 38663b0

Please sign in to comment.