Commit fe1c754

skip bad rows

huseinzol05 committed Oct 13, 2024
1 parent 21b2f36 commit fe1c754
Showing 2 changed files with 162 additions and 11 deletions.
23 changes: 13 additions & 10 deletions session/translation/end-to-end/nanot5-small.sh
@@ -1,21 +1,24 @@
-WANDB_PROJECT="nanot5-small-malaysian-cased-translation-v3" \
+WANDB_PROJECT="nanot5-small-malaysian-cased-translation-v4" \
 torchrun \
---nproc_per_node 4 \
+--nproc_per_node 1 \
 -m run_t5_v2 \
 --model_name_or_path mesolitica/nanot5-small-malaysian-cased \
 --num_train_epochs 2 \
 --eval_steps 1000000000 \
 --logging_steps 2 \
---save_steps 1500 \
+--save_steps 200 \
 --save_total_limit 3 \
 --do_train \
---train_file malaysian-translation \
---output_dir nanot5-small-malaysian-cased-translation-v3 \
---per_device_train_batch_size=12 \
+--train_file mosaic \
+--output_dir nanot5-small-malaysian-cased-translation-v4-v2 \
+--dataloader_num_workers=10 \
+--per_device_train_batch_size=2 \
 --per_device_eval_batch_size=3 \
---gradient_accumulation_steps=2 \
---max_source_length 4096 \
---max_target_length 4096 \
+--gradient_accumulation_steps=16 \
+--max_source_length 2048 \
+--max_target_length 2048 \
 --learning_rate 2e-4 \
 --gradient_checkpointing true \
--bf16
+--weight_decay 0.01 \
+--bf16 \
+--run_name nanot5-small-malaysian-cased-translation-v4-1
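
For orientation, a back-of-the-envelope comparison of the two configurations above, assuming the effective batch per optimizer step is simply the product of nproc_per_node, per_device_train_batch_size and gradient_accumulation_steps (no other scaling), so treat the numbers as rough:

# Rough effective-batch comparison of the v3 and v4 flags above.
# Assumption: effective batch = nproc_per_node * per_device_train_batch_size * gradient_accumulation_steps.
configs = {
    "v3": {"nproc": 4, "per_device": 12, "accum": 2, "max_len": 4096},
    "v4": {"nproc": 1, "per_device": 2, "accum": 16, "max_len": 2048},
}
for name, c in configs.items():
    seqs = c["nproc"] * c["per_device"] * c["accum"]
    print(f'{name}: {seqs} sequences/step, at most {seqs * c["max_len"]} source tokens/step')
# v3: 96 sequences/step, at most 393216 source tokens/step
# v4: 32 sequences/step, at most 65536 source tokens/step

In other words, the v4 run takes optimizer steps over roughly a third as many sequences, each capped at half the previous length.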
150 changes: 149 additions & 1 deletion session/translation/end-to-end/run_t5_v2.py
@@ -23,7 +23,7 @@
 import os
 import sys
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Any, Union
 
 import torch
 import datasets
@@ -51,6 +51,9 @@
 from transformers.utils import check_min_version, send_example_telemetry
 from transformers.utils.versions import require_version
 from tokenizers import AddedToken
+from transformers.data.data_collator import pad_without_fast_tokenizer_warning
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from transformers.utils import PaddingStrategy
 from streaming import LocalDataset
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -70,6 +73,140 @@
     MBart50TokenizerFast,
     M2M100Tokenizer]
 
+@dataclass
+class DataCollatorForSeq2Seq:
+    """
+    Data collator that will dynamically pad the inputs received, as well as the labels.
+
+    Args:
+        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
+            The tokenizer used for encoding the data.
+        model ([`PreTrainedModel`], *optional*):
+            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
+            prepare the *decoder_input_ids*
+
+            This is useful when using *label_smoothing* to avoid calculating loss twice.
+        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+            among:
+
+            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
+              sequence is provided).
+            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+              acceptable input length for the model if that argument is not provided.
+            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
+        max_length (`int`, *optional*):
+            Maximum length of the returned list and optionally padding length (see above).
+        pad_to_multiple_of (`int`, *optional*):
+            If set will pad the sequence to a multiple of the provided value.
+
+            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+            7.5 (Volta).
+        label_pad_token_id (`int`, *optional*, defaults to -100):
+            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
+        return_tensors (`str`, *optional*, defaults to `"pt"`):
+            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
+    """
+
+    tokenizer: PreTrainedTokenizerBase
+    model: Optional[Any] = None
+    padding: Union[bool, str, PaddingStrategy] = True
+    max_length: Optional[int] = None
+    pad_to_multiple_of: Optional[int] = None
+    label_pad_token_id: int = -100
+    return_tensors: str = "pt"
+
+    def __call__(self, features, return_tensors=None):
+        if return_tensors is None:
+            return_tensors = self.return_tensors
+
+        label_name = 'labels'
+        labels = [feature[label_name] for feature in features if feature is not None]
+        # reconvert list[None] to None if necessary
+        # this might occur when we pass {..., "labels": None}
+        if labels is not None and all(label is None for label in labels):
+            labels = None
+        non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features if feature is not None]
+
+        # run through tokenizer without labels to ensure no side effects
+        batch = pad_without_fast_tokenizer_warning(
+            self.tokenizer,
+            non_labels_features,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors=return_tensors,
+        )
+
+        # we have to pad the labels manually as we cannot rely on `tokenizer.pad` and we need them to be of the same length to return tensors
+        no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD
+        if labels is not None:
+            if no_padding:
+                if isinstance(features[0][label_name], list):
+                    batch["labels"] = list(labels)
+                else:
+                    batch["labels"] = [np.concatenate([label, []]) for label in labels]
+            else:
+                max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
+                max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length
+                if self.pad_to_multiple_of is not None:
+                    max_label_length = (
+                        (max_label_length + self.pad_to_multiple_of - 1)
+                        // self.pad_to_multiple_of
+                        * self.pad_to_multiple_of
+                    )
+
+                padding_side = self.tokenizer.padding_side
+                if isinstance(features[0][label_name], list):
+                    batch["labels"] = [
+                        label + [self.label_pad_token_id] * (max_label_length - len(label))
+                        if padding_side == "right"
+                        else [self.label_pad_token_id] * (max_label_length - len(label)) + label
+                        for label in labels
+                    ]
+                else:
+                    batch["labels"] = [
+                        np.concatenate(
+                            [
+                                label,
+                                np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
+                            ]
+                        )
+                        if padding_side == "right"
+                        else np.concatenate(
+                            [
+                                np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
+                                label,
+                            ]
+                        )
+                        for label in labels
+                    ]
+
+        # reintroduce side effects via tokenizer that return respective datatypes for the `return_tensors` argument
+        if batch.get("labels", None) is not None:
+            if return_tensors == "pt":
+                import torch
+
+                batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
+            elif return_tensors == "tf":
+                import tensorflow as tf
+
+                batch["labels"] = tf.constant(batch["labels"], dtype=tf.int64)
+            else:
+                batch["labels"] = np.array(batch["labels"], dtype=np.int64)
+        else:
+            batch["labels"] = None
+
+        # prepare decoder_input_ids
+        if (
+            labels is not None
+            and self.model is not None
+            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
+        ):
+            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
+            batch["decoder_input_ids"] = decoder_input_ids
+
+        return batch
 
 @dataclass
 class ModelArguments:
@@ -325,6 +462,17 @@ def __getitem__(self, idx):
             max_length=max_target_length,
             truncation=True,
         )
+        left_eos = outputs["input_ids"][-1] == tokenizer.eos_token_id
+        right_eos = labels["input_ids"][-1] == tokenizer.eos_token_id
+
+        if left_eos and not right_eos:
+            print(left_eos, right_eos, 'skip')
+            return None
+
+        if not left_eos and right_eos:
+            print(left_eos, right_eos, 'skip')
+            return None
+
         return {
             "input_ids": outputs["input_ids"],
             "attention_mask": outputs["attention_mask"],

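For reference, a minimal self-contained sketch of the skip-bad-rows pattern this commit introduces: __getitem__ returns None for rows where only one side still ends with eos_token_id after truncation (a sign that one side was cut off), and the collator drops those None entries before padding. The names ToyPairDataset and collate_skip_none, the token ids, and the simplified padding are illustrative assumptions, not code from this repository.

import numpy as np

EOS = 1  # assumed eos token id for this toy example

rows = [
    {"input_ids": [5, 6, 7, EOS], "labels": [8, 9, EOS]},     # both sides end with EOS -> keep
    {"input_ids": [5, 6, 7, EOS], "labels": [8, 9, 10, 11]},  # target lost its EOS (truncated) -> skip
    {"input_ids": [5, 6, 7, 8],   "labels": [9, EOS]},        # source lost its EOS (truncated) -> skip
]

class ToyPairDataset:
    def __len__(self):
        return len(rows)

    def __getitem__(self, idx):
        row = rows[idx]
        left_eos = row["input_ids"][-1] == EOS
        right_eos = row["labels"][-1] == EOS
        if left_eos != right_eos:  # same condition as the two skip branches in the diff above
            return None
        return row

def collate_skip_none(features, label_pad_token_id=-100):
    # drop bad rows before padding, mirroring the `if feature is not None` filters in the collator
    features = [f for f in features if f is not None]
    max_len = max(len(f["labels"]) for f in features)
    labels = [f["labels"] + [label_pad_token_id] * (max_len - len(f["labels"])) for f in features]
    return {"labels": np.array(labels, dtype=np.int64)}

ds = ToyPairDataset()
batch = collate_skip_none([ds[i] for i in range(len(ds))])
print(batch["labels"])  # only the first (good) row survives

With the collator above, a batch that happens to contain skipped rows simply gets smaller for that step rather than raising an error during padding.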