diff --git a/docs/transformers/_toctree.yml b/docs/transformers/_toctree.yml
index 04a38d69c9..b4a0342e68 100644
--- a/docs/transformers/_toctree.yml
+++ b/docs/transformers/_toctree.yml
@@ -2,3 +2,11 @@
- local: index
title: 🤗 Transformers
title: Get started
+- sections:
+ - local: tutorials/finetune
+ title: Fine-tune a pretrained model
+ - local: tutorials/finetune_distribute
+ title: Distributed training and mixed precision
+ - local: tutorials/generation
+ title: Generation with LLMs
+ title: Tutorials
diff --git a/docs/transformers/tutorials/finetune.md b/docs/transformers/tutorials/finetune.md
new file mode 100644
index 0000000000..60dbdf9171
--- /dev/null
+++ b/docs/transformers/tutorials/finetune.md
@@ -0,0 +1,243 @@
+
+
+# Fine-tune a pretrained model
+
+There are significant benefits to using a pretrained model. It reduces computation costs, your carbon footprint, and allows you to use state-of-the-art models without having to train one from scratch. 🤗 Transformers provides access to thousands of pretrained models for a wide range of tasks. When you use a pretrained model, you train it on a dataset specific to your task. This is known as fine-tuning, an incredibly powerful training technique. In this tutorial, you will fine-tune a pretrained model with a deep learning framework of your choice:
+
+- Fine-tune a pretrained model with 🤗 Transformers Trainer.
+- Fine-tune a pretrained model in native MindSpore.
+
+## Prepare a dataset
+
+Before you can fine-tune a pretrained model, download a dataset and prepare it for training. The previous tutorial showed you how to process data for training, and now you get an opportunity to put those skills to the test!
+
+Begin by loading the Yelp Reviews dataset:
+
+```pycon
+>>> from datasets import load_dataset
+
+>>> dataset = load_dataset("yelp_review_full")
+>>> dataset["train"][100]
+{'label': 0,
+ 'text': 'My expectations for McDonalds are t rarely high. But for one to still fail so spectacularly...that takes something special!\\nThe cashier took my friends\'s order, then promptly ignored me. I had to force myself in front of a cashier who opened his register to wait on the person BEHIND me. I waited over five minutes for a gigantic order that included precisely one kid\'s meal. After watching two people who ordered after me be handed their food, I asked where mine was. The manager started yelling at the cashiers for \\"serving off their orders\\" when they didn\'t have their food. But neither cashier was anywhere near those controls, and the manager was the one serving food to customers and clearing the boards.\\nThe manager was rude when giving me my order. She didn\'t make sure that I had everything ON MY RECEIPT, and never even had the decency to apologize that I felt I was getting poor service.\\nI\'ve eaten at various McDonalds restaurants for over 30 years. I\'ve worked at more than one location. I expect bad days, bad moods, and the occasional mistake. But I have yet to have a decent experience at this store. It will remain a place I avoid unless someone in my party needs to avoid illness from low blood sugar. Perhaps I should go back to the racially biased service of Steak n Shake instead!'}
+```
+
+As you now know, you need a tokenizer to process the text and include a padding and truncation strategy to handle any variable sequence lengths. To process your dataset in one step, use 🤗 Datasets map method to apply a preprocessing function over the entire dataset:
+
+```pycon
+>>> from transformers import AutoTokenizer
+
+>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
+
+
+>>> def tokenize_function(examples):
+... return tokenizer(examples["text"], padding="max_length", truncation=True)
+
+
+>>> tokenized_datasets = dataset.map(tokenize_function, batched=True)
+```
+
+If you like, you can create a smaller subset of the full dataset to fine-tune on to reduce the time it takes:
+
+```pycon
+small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
+small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
+```
+
+## Train
+
+At this point, you should follow the section corresponding to the framework you want to use. You can use the links in the right sidebar to jump to the one you want - and if you want to hide all of the content for a given framework, just use the button at the top-right of that framework’s block!
+
+### Train with MindSpore Trainer
+
+
+
+!!! Note
+
+ Taking bert as an example, you can find the complete code in `examples/transformers/bert/finetune_with_mindspore_trainer.py`
+
+🤗 Transformers provides a Trainer class optimized for training 🤗 Transformers models, making it easier to start training without manually writing your own training loop. The Trainer API supports a wide range of training options and features such as logging, gradient accumulation, and mixed precision.
+
+Start by loading your model and specify the number of expected labels. From the Yelp Review dataset card, you know there are five labels:
+
+```pycon
+>>> from mindone.transformers.models.bert import BertForSequenceClassification
+
+>>> model = BertForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
+```
+
+!!! Note
+
+ You will see a warning about some of the pretrained weights not being used and some weights being randomly initialized. Don’t worry, this is completely normal! The pretrained head of the BERT model is discarded, and replaced with a randomly initialized classification head. You will fine-tune this new model head on your sequence classification task, transferring the knowledge of the pretrained model to it.
+
+#### Training hyperparameters
+
+Next, create a TrainingArguments class which contains all the hyperparameters you can tune as well as flags for activating different training options. For this tutorial you can start with the default training hyperparameters, but feel free to experiment with these to find your optimal settings.
+
+Specify where to save the checkpoints from your training:
+
+```pycon
+>>> from mindone.transformers.training_args import TrainingArguments
+
+>>> training_args = TrainingArguments(output_dir="test_trainer")
+```
+
+(optional but recommended) Init environment:
+
+```pycon
+>>> import mindspore as ms
+>>> from mindone.transformers.mindspore_adapter import MindSporeArguments, init_environment
+
+>>> env_args = MindSporeArguments(mode=ms.GRAPH_MODE, device_target="Ascend")
+>>> init_environment(env_args)
+```
+
+#### Trainer
+
+Create a Trainer object with your model, training arguments, training and test datasets, and evaluation function:
+
+```pycon
+>>> trainer = Trainer(
+... model=model,
+... args=training_args,
+... train_dataset=small_train_dataset,
+... eval_dataset=small_eval_dataset,
+... compute_metrics=compute_metrics,
+... )
+```
+
+Then fine-tune your model by calling train():
+
+```pycon
+>>> trainer.train()
+```
+
+
+
+### Train in native MindSpore
+
+
+
+!!! Note
+
+ Taking bert as an example, you can find the complete code in `examples/transformers/bert/finetune_in_native_mindspore.py`
+
+Trainer takes care of the training loop and allows you to fine-tune a model in a single line of code. For users who prefer to write their own training loop, you can also fine-tune a 🤗 Transformers model in native MindSpore.
+
+At this point, you may need to restart your notebook to free memory.
+
+Next, manually postprocess `tokenized_dataset` to prepare it for training.
+
+1. Remove the text column because the model does not accept raw text as an input:
+
+```pycon
+>>> tokenized_datasets = tokenized_datasets.remove_columns(["text"])
+```
+
+2. Rename the `label` column to `labels` because the model expects the argument to be named `labels`:
+
+```pycon
+>>> tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
+```
+
+#### DataLoader
+
+Create a MindSpore DataLoader for your training datasets so you can iterate over batches of data:
+
+```pycon
+>>> import mindspore as ms
+>>> from mindone.transformers.mindspore_adapter import HF2MSDataset
+
+>>> def ms_data_collator(features, batch_info):
+... batch = {}
+... for k, v in features[0]:
+... batch[k] = np.stack([f[k] for f in features]) if isinstance(v, np.ndarray) else np.array([f[k] for f in features])
+... return batch
+
+>>> batch_size, num_epochs = 1, 3
+>>> train_dataloader = ms.dataset.GeneratorDataset(HF2MSDataset(small_train_dataset), column_names="item")
+>>> train_dataloader = train_dataloader.batch(batch_size=batch_size, per_batch_map=ms_data_collator)
+>>> train_dataloader = train_dataloader.repeat(1)
+>>> train_dataloader = train_dataloader.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
+```
+
+Load your model with the number of expected labels:
+
+```pycon
+>>> from mindone.transformers.models.bert import BertForSequenceClassification
+
+>>> model = BertForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)
+```
+
+#### Optimizer
+
+Create an optimizer to fine-tune the model. Let’s use the AdamWeightDecay optimizer from MindSpore:
+
+```pycon
+>>> from mindspore import nn
+
+>>> optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=5e-6)
+```
+
+#### Train Network
+
+Create an MindSpore train network
+
+```pycon
+>>> from mindone.transformers.mindspore_adapter import TrainOneStepWrapper
+
+>>> class ReturnLoss(nn.Cell):
+... def __init__(self, model):
+... super(ReturnLoss, self).__init__(auto_prefix=False)
+... self.model = model
+...
+... def construct(self, *args, **kwargs):
+... outputs = self.model(*args, **kwargs)
+... loss = outputs[0]
+... return loss
+
+>>> train_model = TrainOneStepWrapper(ReturnLoss(model), optimizer)
+```
+
+Great, now you are ready to train! 🥳
+
+#### Training loop
+
+To keep track of your training progress, use the tqdm library to add a progress bar over the number of training steps:
+
+```pycon
+>>> from tqdm.auto import tqdm
+
+>>> num_training_steps = len(small_train_dataset) * num_epochs // batch_size
+>>> progress_bar = tqdm(range(num_training_steps))
+
+>>> train_model.train()
+>>> for step, batch in enumerate(train_dataloader):
+... batch = batch["item"]
+...
+... tuple_inputs = (
+... ms.Tensor(batch["input_ids"], ms.int32),
+... ms.Tensor(batch["attention_mask"], ms.bool_),
+... None,
+... None,
+... None,
+... None,
+... ms.tensor(batch["labels"], ms.int32)
+... )
+...
+... loss, _, overflow = train_model(*tuple_inputs)
+...
+... progress_bar.update(1)
+```
+
+
diff --git a/docs/transformers/tutorials/finetune_distribute.md b/docs/transformers/tutorials/finetune_distribute.md
new file mode 100644
index 0000000000..6aa5ff26ca
--- /dev/null
+++ b/docs/transformers/tutorials/finetune_distribute.md
@@ -0,0 +1,37 @@
+# Distributed training with mixed precision and ZeRO parallelism
+
+The Trainer supports distributed training and mixed precision, which means you can also use it in a script. To enable both of these features:
+
+See `examples/transformers/llama/finetune_with_mindspore_trainer.py` for more detail.
+
+- Add the `is_distribute` argument to enable distribute training.
+- Add the `fp16` or `bf16` argument to enable mixed precision.
+- Add the `zero_stage` argument to enable optimizer parallelism with `ZeRO` algorithm.
+- Set the number of global/local NPUs to use with the `worker_num`/`local_worker_num` argument.
+
+```shell
+msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir=outputs/parallel_logs \
+python finetune_with_mindspore_trainer.py \
+ --model_path $local_path/meta-llama/Meta-Llama-3-8B \
+ --dataset_path $local_path/yelp_review_full \
+ --output_dir ./outputs \
+ --bf16 \
+ --zero_stage 2 \
+ --is_distribute True
+```
+
+Another example implemented through native MindSpore, see `examples/transformers/llama/finetune_in_native_mindspore.py` for more detail.
+
+
+
+```shell
+msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir=outputs/parallel_logs \
+python finetune_in_native_mindspore.py \
+ --model_path meta-llama/Meta-Llama-3-8B \
+ --dataset_path Yelp/yelp_review_full \
+ --bf16 \
+ --zero_stage 2 \
+ --is_distribute True
+```
+
+
diff --git a/docs/transformers/tutorials/generation.md b/docs/transformers/tutorials/generation.md
new file mode 100644
index 0000000000..15a46fe65e
--- /dev/null
+++ b/docs/transformers/tutorials/generation.md
@@ -0,0 +1,102 @@
+
+
+# Generation with LLMs
+
+LLMs, or Large Language Models, are the key component behind text generation. In a nutshell, they consist of large pretrained transformer models trained to predict the next word (or, more precisely, token) given some input text. Since they predict one token at a time, you need to do something more elaborate to generate new sentences other than just calling the model — you need to do autoregressive generation.
+
+Autoregressive generation is the inference-time procedure of iteratively calling a model with its own generated outputs, given a few initial inputs. In 🤗 Transformers, this is handled by the generate() method, which is available to all models with generative capabilities.
+
+This tutorial will show you how to:
+
+- Generate text with an LLM
+
+Before you begin, make sure you have all the necessary libraries installed:
+
+```shell
+pip install transformers==4.42.4
+```
+
+## Generate text
+
+!!! Note
+
+ Taking llama as an example, you can find the complete code in `examples/transformers/llama/generate.py`
+ And you can compare the results of script `examples/transformers/llama/generate_pt.py` with PyTorch.
+
+A language model trained for causal language modeling takes a sequence of text tokens as input and returns the probability distribution for the next token.
+
+A critical aspect of autoregressive generation with LLMs is how to select the next token from this probability distribution. Anything goes in this step as long as you end up with a token for the next iteration. This means it can be as simple as selecting the most likely token from the probability distribution or as complex as applying a dozen transformations before sampling from the resulting distribution.
+
+The process depicted above is repeated iteratively until some stopping condition is reached. Ideally, the stopping condition is dictated by the model, which should learn when to output an end-of-sequence (EOS) token. If this is not the case, generation stops when some predefined maximum length is reached.
+
+Properly setting up the token selection step and the stopping condition is essential to make your model behave as you’d expect on your task. That is why we have a GenerationConfig file associated with each model, which contains a good default generative parameterization and is loaded alongside your model.
+
+Let’s talk code!
+
+!!! Note
+
+ If you’re interested in basic LLM usage, our high-level Pipeline interface is a great starting point. However, LLMs often require advanced features like quantization and fine control of the token selection step, which is best done through generate(). Autoregressive generation with LLMs is also resource-intensive and should be executed on a Ascend NPU for adequate throughput.
+
+First, you need to load the model.
+
+```pycon
+>>> from mindone.transformers.models.llama import LlamaForCausalLM
+
+>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
+```
+
+There are other ways to initialize a model, but this is a good baseline to begin with an LLM.
+
+Next, you need to preprocess your text input with a tokenizer.
+
+```pycon
+>>> from transformers import AutoTokenizer
+
+>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
+>>> input_ids = ms.Tensor(tokenizer(["A list of colors: red, blue"]).input_ids, ms.int32)
+```
+
+The model_inputs variable holds the tokenized text input, as well as the attention mask. While generate() does its best effort to infer the attention mask when it is not passed, we recommend passing it whenever possible for optimal results.
+
+After tokenizing the inputs, you can call the generate() method to returns the generated tokens. The generated tokens then should be converted to text before printing.
+
+```pycon
+>>> generated_ids = model.generate(
+... input_ids=input_ids,
+... max_new_tokens=30,
+... use_cache=True,
+... do_sample=False,
+... )
+
+>>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+```
+
+Finally, you don’t need to do it one sequence at a time! You can batch your inputs, which will greatly improve the throughput at a small latency and memory cost. All you need to do is to make sure you pad your inputs properly (more on that below).
+
+```pycon
+>>> tokenizer.pad_token = tokenizer.eos_token # Most LLMs don't have a pad token by default
+>>> input_ids = ms.Tensor(tokenizer(
+... ["A list of colors: red, blue", "Portugal is"], padding=True
+... ).input_ids, ms.int32)
+
+>>> generated_ids = model.generate(
+... input_ids=input_ids,
+... max_new_tokens=30,
+... use_cache=True,
+... do_sample=False,
+... )
+
+>>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+```
+
+And that’s it! In a few lines of code, you can harness the power of an LLM.
diff --git a/examples/transformers/bert/finetune_in_native_mindspore.py b/examples/transformers/bert/finetune_in_native_mindspore.py
new file mode 100644
index 0000000000..fb2610605c
--- /dev/null
+++ b/examples/transformers/bert/finetune_in_native_mindspore.py
@@ -0,0 +1,95 @@
+import argparse
+
+import numpy as np
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+import mindspore as ms
+from mindspore import nn
+
+from mindone.transformers.mindspore_adapter import HF2MSDataset, TrainOneStepWrapper
+from mindone.transformers.models.bert import BertForSequenceClassification
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_path", type=str, default="google-bert/bert-base-cased", help="pretrained model name")
+ parser.add_argument("--dataset_path", type=str, default="Yelp/yelp_review_full", help="dataset path.")
+ args = parser.parse_args()
+ print(args)
+
+ ms.set_context(mode=ms.GRAPH_MODE, jit_config={"jit_level": "O0"})
+
+ # 1. create dataset
+ dataset = load_dataset(args.dataset_path)
+ dataset["train"] = dataset["train"].shuffle(seed=42).select(range(1000))
+ dataset["test"] = dataset["test"].shuffle(seed=42).select(range(1000))
+
+ tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+
+ def tokenize_function(examples):
+ return tokenizer(
+ examples["text"],
+ padding="max_length",
+ truncation=True,
+ max_length=512, # Note: pad is need for training batch size is gather than 1.
+ )
+
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
+ tokenized_datasets = tokenized_datasets.remove_columns(["text"])
+ tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
+ small_train_dataset = tokenized_datasets["train"]
+
+ def ms_data_collator(features, batch_info):
+ batch = {}
+ for k, v in features[0].items():
+ batch[k] = (
+ np.stack([f[k] for f in features]) if isinstance(v, np.ndarray) else np.array([f[k] for f in features])
+ )
+ return batch
+
+ batch_size, num_epochs = 8, 3
+ train_dataloader = ms.dataset.GeneratorDataset(HF2MSDataset(small_train_dataset), column_names="item")
+ train_dataloader = train_dataloader.batch(batch_size=batch_size, per_batch_map=ms_data_collator)
+ train_dataloader = train_dataloader.repeat(1)
+ train_dataloader = train_dataloader.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
+
+ # 2. create train network
+ model = BertForSequenceClassification.from_pretrained(args.model_path, num_labels=5)
+ optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=5e-6)
+
+ class ReturnLoss(nn.Cell):
+ def __init__(self, model):
+ super(ReturnLoss, self).__init__(auto_prefix=False)
+ self.model = model
+
+ def construct(self, *args, **kwargs):
+ outputs = self.model(*args, **kwargs)
+ loss = outputs[0]
+ return loss
+
+ train_model = TrainOneStepWrapper(ReturnLoss(model), optimizer)
+
+ # 3. training
+ train_model.set_train()
+ for step, batch in enumerate(train_dataloader):
+ batch = batch["item"]
+
+ # inputs dict to tuple
+ tuple_inputs = (
+ ms.Tensor(batch["input_ids"], ms.int32),
+ ms.Tensor(batch["attention_mask"], ms.bool_),
+ None,
+ None,
+ None,
+ None,
+ ms.tensor(batch["labels"], ms.int32),
+ )
+
+ loss, _, overflow = train_model(*tuple_inputs)
+
+ print(f"step: {step}, loss: {loss}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/transformers/bert/finetune_with_mindspore_trainer.py b/examples/transformers/bert/finetune_with_mindspore_trainer.py
new file mode 100644
index 0000000000..63b49c2400
--- /dev/null
+++ b/examples/transformers/bert/finetune_with_mindspore_trainer.py
@@ -0,0 +1,69 @@
+from dataclasses import dataclass, field
+
+import evaluate
+import numpy as np
+from datasets import load_dataset
+from transformers import AutoTokenizer, HfArgumentParser
+
+from mindone.transformers.mindspore_adapter import MindSporeArguments, init_environment
+from mindone.transformers.models.bert import BertForSequenceClassification
+from mindone.transformers.trainer import Trainer
+from mindone.transformers.training_args import TrainingArguments
+
+
+@dataclass
+class MyArguments(MindSporeArguments, TrainingArguments):
+ model_path: str = field(default="google-bert/bert-base-cased")
+ dataset_path: str = field(default="Yelp/yelp_review_full")
+
+
+def main():
+ parser = HfArgumentParser(MyArguments)
+ args = parser.parse_args_into_dataclasses()[0]
+
+ init_environment(args)
+
+ dataset = load_dataset(args.dataset_path)
+ dataset["train"] = dataset["train"].shuffle(seed=42).select(range(1000))
+ dataset["test"] = dataset["test"].shuffle(seed=42).select(range(1000))
+
+ tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+
+ def tokenize_function(examples):
+ return tokenizer(
+ examples["text"],
+ padding="max_length",
+ truncation=True,
+ max_length=512, # Note: pad is need for training batch size is gather than 1.
+ )
+
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
+ small_train_dataset = tokenized_datasets["train"]
+ small_eval_dataset = tokenized_datasets["test"]
+
+ model = BertForSequenceClassification.from_pretrained(args.model_path, num_labels=5)
+
+ if args.do_eval:
+ metric = evaluate.load("accuracy")
+
+ def compute_metrics(eval_pred):
+ logits, labels = eval_pred
+ predictions = np.argmax(logits, axis=-1)
+ return metric.compute(predictions=predictions, references=labels)
+
+ else:
+ compute_metrics = None
+
+ trainer = Trainer(
+ model=model,
+ args=args,
+ train_dataset=small_train_dataset,
+ eval_dataset=small_eval_dataset,
+ compute_metrics=compute_metrics,
+ )
+
+ trainer.train()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/transformers/llama/finetune_in_native_mindspore.py b/examples/transformers/llama/finetune_in_native_mindspore.py
new file mode 100644
index 0000000000..765b9d20bf
--- /dev/null
+++ b/examples/transformers/llama/finetune_in_native_mindspore.py
@@ -0,0 +1,159 @@
+"""
+Llama 3 model fine-tuning script.
+This script with default values fine-tunes a pretrained Meta Llama3 on the `Yelp/yelp_review_full` dataset,
+"""
+
+
+import argparse
+import ast
+from typing import Dict
+
+import numpy as np
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+import mindspore as ms
+from mindspore import nn
+
+from mindone.transformers.mindspore_adapter import HF2MSDataset, TrainOneStepWrapper, auto_mixed_precision
+from mindone.transformers.models.llama import LlamaForSequenceClassification
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_path", type=str, default="meta-llama/Meta-Llama-3-8B", help="pretrained model name")
+ parser.add_argument("--dataset_path", type=str, default="Yelp/yelp_review_full", help="dataset path.")
+ parser.add_argument(
+ "--zero_stage", type=int, default=0, choices=[0, 1, 2], help="stage of ZeRO optimizer parallelism"
+ )
+ parser.add_argument(
+ "--fp16", action="store_true", default=False, help="whether or not to enable mix precision with float16"
+ )
+ parser.add_argument(
+ "--bf16", action="store_true", default=False, help="whether or not to enable mix precision with bfloat16"
+ )
+ parser.add_argument(
+ "--is_distribute", type=ast.literal_eval, default=False, help="whether or not to run distribute"
+ )
+ parser.add_argument("--rank", type=int, default=0, help="id of card")
+ parser.add_argument("--rank_size", type=int, default=1, help="num of cards")
+ args = parser.parse_args()
+ print(args)
+
+ # 0. set mindspore context
+ ms.set_context(mode=ms.GRAPH_MODE, jit_config={"jit_level": "O0"})
+ if args.is_distribute:
+ from mindspore.communication import get_group_size, get_rank, init
+
+ init()
+ args.rank = get_rank()
+ args.rank_size = get_group_size()
+ ms.reset_auto_parallel_context()
+ ms.set_auto_parallel_context(
+ parallel_mode=ms.ParallelMode.DATA_PARALLEL,
+ gradients_mean=True,
+ device_num=get_group_size(),
+ )
+
+ # 1. create dataset
+ dataset = load_dataset(args.dataset_path)
+ dataset["train"] = dataset["train"].shuffle(seed=42).select(range(1000))
+ dataset["test"] = dataset["test"].shuffle(seed=42).select(range(1000))
+
+ tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ def tokenize_function(examples):
+ return tokenizer(
+ examples["text"],
+ padding="max_length",
+ truncation=True,
+ max_length=512, # Note: pad is need for training batch size is gather than 1.
+ )
+
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
+ small_train_dataset = tokenized_datasets["train"]
+
+ def ms_data_collator(features, batch_info):
+ first = features[0]
+ assert isinstance(first, Dict)
+ batch = {}
+ batch["labels"] = np.array([f["label"] for f in features], dtype=np.int32)
+ for k, v in first.items():
+ if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
+ if isinstance(v, np.ndarray):
+ batch[k] = np.stack([f[k] for f in features])
+ else:
+ batch[k] = np.array([f[k] for f in features])
+ return batch
+
+ batch_size, num_epochs = 1, 3
+ train_dataloader = ms.dataset.GeneratorDataset(
+ HF2MSDataset(small_train_dataset), column_names="item", shard_id=args.rank, num_shards=args.rank_size
+ )
+ train_dataloader = train_dataloader.batch(batch_size=batch_size, per_batch_map=ms_data_collator)
+ train_dataloader = train_dataloader.repeat(1)
+ train_dataloader = train_dataloader.create_dict_iterator(num_epochs=num_epochs, output_numpy=True)
+
+ # 2. create train network and mix precision
+ model = LlamaForSequenceClassification.from_pretrained(
+ args.model_path,
+ num_labels=5,
+ use_flash_attention_2=True,
+ mindspore_dtype=ms.bfloat16 if args.bf16 else (ms.float16 if args.fp16 else None),
+ )
+ model.gradient_checkpointing_enable()
+
+ assert not (args.fp16 and args.bf16)
+ if args.fp16:
+ model = auto_mixed_precision(model, "O2", ms.float16)
+ if args.bf16:
+ model = auto_mixed_precision(model, "O2", ms.bfloat16)
+
+ if args.zero_stage == 0:
+ optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=5e-6)
+ elif args.zero_stage == 1:
+ from mindone.transformers.mindspore_adapter import AdamWeightDecayZeRO1
+
+ optimizer = AdamWeightDecayZeRO1(model.trainable_params(), learning_rate=5e-6)
+ elif args.zero_stage == 2:
+ from mindone.transformers.mindspore_adapter import AdamWeightDecayZeRO2
+
+ optimizer = AdamWeightDecayZeRO2(model.trainable_params(), learning_rate=5e-6)
+ else:
+ raise ValueError
+
+ class ReturnLoss(nn.Cell):
+ def __init__(self, model):
+ super(ReturnLoss, self).__init__(auto_prefix=False)
+ self.model = model
+
+ def construct(self, *args, **kwargs):
+ outputs = self.model(*args, **kwargs)
+ loss = outputs[0]
+ return loss
+
+ train_model = TrainOneStepWrapper(ReturnLoss(model), optimizer)
+
+ # 3. training
+ train_model.set_train()
+ for step, batch in enumerate(train_dataloader):
+ batch = batch["item"]
+
+ # inputs dict to tuple
+ tuple_inputs = (
+ ms.Tensor(batch["input_ids"], ms.int32),
+ ms.Tensor(batch["attention_mask"], ms.bool_),
+ None,
+ None,
+ None,
+ ms.tensor(batch["labels"], ms.int32),
+ )
+
+ loss, _, overflow = train_model(*tuple_inputs)
+
+ print(f"step: {step}, loss: {loss}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/transformers/llama/finetune_with_mindspore_trainer.py b/examples/transformers/llama/finetune_with_mindspore_trainer.py
new file mode 100644
index 0000000000..4bdb21a964
--- /dev/null
+++ b/examples/transformers/llama/finetune_with_mindspore_trainer.py
@@ -0,0 +1,87 @@
+"""
+Llama 3 model fine-tuning script.
+This script with default values fine-tunes a pretrained Meta Llama3 on the `Yelp/yelp_review_full` dataset,
+"""
+
+
+from dataclasses import dataclass, field
+
+import evaluate
+import numpy as np
+from datasets import load_dataset
+from transformers import AutoTokenizer, HfArgumentParser
+
+import mindspore as ms
+
+from mindone.transformers.mindspore_adapter import MindSporeArguments, init_environment
+from mindone.transformers.models.llama import LlamaForSequenceClassification
+from mindone.transformers.trainer import Trainer
+from mindone.transformers.training_args import TrainingArguments
+
+
+@dataclass
+class MyArguments(MindSporeArguments, TrainingArguments):
+ model_path: str = field(default="meta-llama/Meta-Llama-3-8B")
+ dataset_path: str = field(default="Yelp/yelp_review_full")
+ output_dir: str = field(default="./outputs")
+ enable_flash_attention: bool = field(default=True)
+ gradient_checkpointing: bool = field(default=True)
+ is_distribute: bool = field(default=False)
+
+
+def main():
+ parser = HfArgumentParser(MyArguments)
+ args = parser.parse_args_into_dataclasses()[0]
+
+ init_environment(args)
+
+ dataset = load_dataset(args.dataset_path)
+ dataset["train"] = dataset["train"].shuffle(seed=42).select(range(1000))
+ dataset["test"] = dataset["test"].shuffle(seed=42).select(range(1000))
+
+ tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ def tokenize_function(examples):
+ return tokenizer(
+ examples["text"],
+ padding="max_length",
+ truncation=True,
+ max_length=512, # Note: pad is need for training batch size is gather than 1.
+ )
+
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
+ small_train_dataset = tokenized_datasets["train"]
+ small_eval_dataset = tokenized_datasets["test"]
+
+ model = LlamaForSequenceClassification.from_pretrained(
+ args.model_path,
+ num_labels=5,
+ use_flash_attention_2=args.enable_flash_attention,
+ mindspore_dtype=ms.bfloat16 if args.bf16 else (ms.float16 if args.fp16 else None),
+ )
+
+ if args.do_eval:
+ metric = evaluate.load("accuracy")
+
+ def compute_metrics(eval_pred):
+ logits, labels = eval_pred
+ predictions = np.argmax(logits, axis=-1)
+ return metric.compute(predictions=predictions, references=labels)
+
+ else:
+ compute_metrics = None
+
+ trainer = Trainer(
+ model=model,
+ args=args,
+ train_dataset=small_train_dataset,
+ eval_dataset=small_eval_dataset,
+ compute_metrics=compute_metrics,
+ )
+
+ trainer.train()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/transformers/llama/generate.py b/examples/transformers/llama/generate.py
new file mode 100644
index 0000000000..13f2da85f7
--- /dev/null
+++ b/examples/transformers/llama/generate.py
@@ -0,0 +1,100 @@
+import argparse
+import ast
+import os
+import time
+
+from transformers import AutoTokenizer
+
+import mindspore as ms
+
+from mindone.transformers.mindspore_adapter import auto_mixed_precision
+from mindone.transformers.models.llama import LlamaForCausalLM
+
+
+def run_llama3_generate(args):
+ print("=====> test_llama3_generate:")
+ print("=====> Building model...")
+
+ s_time = time.time()
+
+ tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+ model = LlamaForCausalLM.from_pretrained(args.model_path, use_flash_attention_2=args.use_fa)
+
+ model = auto_mixed_precision(model, amp_level="O2", dtype=ms.float16)
+
+ print("=====> Building model done.")
+
+ while True:
+ prompt = input("Enter your prompt [e.g. `What's your name?`] or enter [`q`] to exit: ")
+
+ if prompt == "q":
+ print("Generate task done, see you next time!")
+ break
+
+ prompt = [
+ prompt,
+ ]
+ input_ids = ms.Tensor(tokenizer(prompt).input_ids, ms.int32)
+
+ input_kwargs = {}
+ if args.use_embed_input:
+ input_kwargs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
+ else:
+ input_kwargs["input_ids"] = input_ids
+
+ output_ids = model.generate(**input_kwargs, use_cache=args.use_cache, max_new_tokens=30, do_sample=False)
+ output_ids = output_ids.asnumpy()
+
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+
+ print(f"=====> input prompt: {prompt}, time cost: {time.time() - s_time:.2f}s")
+ print("=" * 46 + " Result " + "=" * 46)
+ print(outputs)
+ print("=" * 100)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="test")
+ parser.add_argument("--ms_mode", type=int, default=0, help="0 is Graph, 1 is Pynative")
+ parser.add_argument("--jit_level", type=str, default="O0")
+ parser.add_argument("--model_path", type=str, default="meta-llama/Meta-Llama-3-8B")
+ parser.add_argument("--use_fa", type=ast.literal_eval, default=True)
+ parser.add_argument("--use_cache", type=ast.literal_eval, default=True)
+ parser.add_argument("--use_embed_input", type=ast.literal_eval, default=True)
+ args, _ = parser.parse_known_args()
+
+ if args.ms_mode == ms.GRAPH_MODE:
+ if os.environ.get("MS_DEV_RUNTIME_CONF") is None:
+ os.environ["MS_DEV_RUNTIME_CONF"] = "synchronize:True"
+ print("WARNING: os environment MS_DEV_RUNTIME_CONF synchronize has not been set, force setting it now.")
+ else:
+ if "synchronize:True" not in os.environ.get("MS_DEV_RUNTIME_CONF"):
+ _old = os.environ.get("MS_DEV_RUNTIME_CONF")
+ _old.replace("synchronize:False,", "")
+ _old.replace(",synchronize:False", "")
+ _old.replace("synchronize:False", "")
+ _new = "synchronize:True," + _old if len(_old) > 0 else "synchronize:True"
+ os.environ["MS_DEV_RUNTIME_CONF"] = _new
+ print("WARNING: os environment MS_DEV_RUNTIME_CONF synchronize has not been set, force setting it now.")
+
+ ms.set_context(
+ mode=ms.GRAPH_MODE,
+ device_target="Ascend",
+ jit_config={"jit_level": args.jit_level},
+ max_device_memory="59GB",
+ deterministic="ON",
+ )
+
+ elif args.ms_mode == ms.PYNATIVE_MODE:
+ ms.set_context(
+ mode=ms.PYNATIVE_MODE,
+ device_target="Ascend",
+ pynative_synchronize=True,
+ max_device_memory="59GB",
+ deterministic="ON",
+ )
+
+ else:
+ raise ValueError
+
+ run_llama3_generate(args)
diff --git a/examples/transformers/llama/generate_pt.py b/examples/transformers/llama/generate_pt.py
new file mode 100644
index 0000000000..9ea393340d
--- /dev/null
+++ b/examples/transformers/llama/generate_pt.py
@@ -0,0 +1,61 @@
+import argparse
+import ast
+import time
+
+import torch
+from transformers import AutoTokenizer
+from transformers.models.llama import LlamaForCausalLM
+
+
+def run_llama3_generate_pt(args):
+ print("=====> test_llama3_generate:")
+ print("=====> Building model...")
+
+ s_time = time.time()
+
+ device = "cuda:0"
+
+ tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+ model = LlamaForCausalLM.from_pretrained(args.model_path, attn_implementation="eager")
+ model.to(device)
+
+ print("=====> Building model done.")
+
+ while True:
+ prompt = input("Enter your prompt [e.g. `What's your name?`] or enter [`q`] to exit: ")
+
+ if prompt == "q":
+ print("Generate task done, see you next time!")
+ break
+
+ prompt = [
+ prompt,
+ ]
+ input_ids = torch.tensor(tokenizer(prompt).input_ids).to(device)
+
+ input_kwargs = {}
+ if args.use_embed_input:
+ input_kwargs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
+ else:
+ input_kwargs["input_ids"] = input_ids
+
+ output_ids = model.generate(**input_kwargs, use_cache=args.use_cache, max_new_tokens=24, do_sample=False)
+ output_ids = output_ids.detach().cpu().numpy()
+
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+
+ print(f"=====> input prompt: {prompt}, time cost: {time.time() - s_time:.2f}s")
+ print("=" * 46 + " Result " + "=" * 46)
+ print(outputs)
+ print("=" * 100)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="test")
+ parser.add_argument("--model_path", type=str, default="meta-llama/Meta-Llama-3-8B")
+ parser.add_argument("--use_fa", type=ast.literal_eval, default=False) # unavailable
+ parser.add_argument("--use_cache", type=ast.literal_eval, default=True)
+ parser.add_argument("--use_embed_input", type=ast.literal_eval, default=True)
+ args, _ = parser.parse_known_args()
+
+ run_llama3_generate_pt(args)
diff --git a/mindone/transformers/README.md b/mindone/transformers/README.md
index 7c410cc17a..45e31e82ff 100644
--- a/mindone/transformers/README.md
+++ b/mindone/transformers/README.md
@@ -1,16 +1,24 @@
-# Get Pretrained Txt/Img Encoder from 🤗 Transformers
+# Make 🤗 Transformers run on MindSpore
-This MindSpore patch for [🤗 Transformers](https://github.com/huggingface/transformers) enables researchers or developers
-in the field of text-to-image (t2i) and text-to-video (t2v) generation to utilize pretrained text and image models from 🤗 Transformers on MindSpore.
-The pretrained models from 🤗 Transformers can be employed either as frozen encoders or fine-tuned with denoising networks for generative tasks.
-This approach **_aligns with the practices_** of PyTorch users[[1]](https://github.com/huggingface/diffusers)[[2]](https://github.com/Stability-AI/generative-models).
-Now, MindSpore users can benefit from the same functionality!
+
-## Philosophy
+> State-of-the-art transformers models to perform tasks on different modalities such as text, vision,
+> and audio in MindSpore. We've tried to provide a similar interface and usage with the
+> [huggingface/transformers](https://github.com/huggingface/transformers). Only necessary changes are made to
+> the [huggingface/transformers](https://github.com/huggingface/transformers) to make it seamless for users from torch.
+
+🤗 **Development Principles**
+
+- Only necessary changes are made to the [huggingface/transformers](https://github.com/huggingface/transformers)
+- Configuration, Tokenizer, etc. will utilize the original Transformers.
+
+🤗 **Currently**,
+we provides pretrained models, generation api, trainer, etc.
+to be enables researchers or developers in the field of AIGC and MLLMs to utilize Transformers on MindSpore.
+
+🤗 **Comming Soon**,
+latest state-of-the-art models, auto class, pipeline, agent, distributed and so on.
-- Only the MindSpore model definition will be implemented, which will be identical to the PyTorch model.
-- Configuration, Tokenizer, etc. will utilize the original 🤗 Transformers.
-- Models here will be limited to the scope of generative tasks.
## Quick Tour
@@ -18,18 +26,21 @@ The following lines of code are an example that shows you how to download and us
Remember that the models are from `mindone.transformers`, and anything else is from 🤗 Transformers.
```diff
-from mindspore import Tensor
-# use tokenizer from 🤗 Transformers
++from mindspore import Tensor
+
+# use tokenizer from 🤗 transformers
from transformers import AutoTokenizer
-# use model from mindone.transformers
--from transformers import CLIPTextModel
-+from mindone.transformers import CLIPTextModel
-model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
-tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+# replace model from 🤗 transformers to mindone.transformers
+-from transformers import LlamaForCausalLM
++from mindone.transformers import LlamaForCausalLM
+
+model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B)
+tokenizer.pad_token = tokenizer.eos_token # Most LLMs don't have a pad token by default
inputs = tokenizer(
- ["a photo of a cat", "a photo of a dog"],
+ ["A list of colors: red, blue", "Portugal is"],
padding=True,
- return_tensors="pt",
+ return_tensors="np"
@@ -38,8 +49,36 @@ inputs = tokenizer(
+outputs = model(Tensor(inputs.input_ids))
```
+Then run text generation.
+
+```diff
+generated_ids = model.generate(
+- **inputs,
++ input_ids=Tensor(inputs.input_ids),
+ max_new_tokens=30,
+ use_cache=True,
+ do_sample=False
+)
+
+tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+```
+
+
+## Tutorials
+
+| Section | Description |
+|------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------|
+| [Generation with LLMs](../../docs/transformers/tutorials/generation.md) | Generate text with an LLM |
+| [Training and fine-tuning](../../docs/transformers/tutorials/finetune.md) | Using the models provided by 🤗 Transformers in a native MindSpore training loop and the `Trainer` API |
+| [Distributed training and mixed precision](../../docs/transformers/tutorials/finetune_distribute.md) | Example scripts for fine-tuning models using distribute and mix precision |
+
+
## Model Zoo
+We introduced some of the provided models and basic usage, as detailed below:
+
+
+
### CLIP
The CLIP model was proposed in [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) by
@@ -191,8 +230,15 @@ logits = outputs[0]
encoder_outputs = outputs[1]
```
+
+
+
## Numerical Parity
+We compare the numerical parity with the [huggingface/transformer](https://github.com/huggingface/transformers), as detailed below:
+
+
+
MindSpore 2.2/2.3 @ Ascend **_vs._** Pytorch 2.2 @ CPU(aarch64)
Error Formula: `max(abs(ms-pt)) / mean(abs(pt))`
@@ -296,3 +342,5 @@ Error Formula: `max(abs(ms-pt)) / mean(abs(pt))`
| google-t5/t5-small | 4.88E-05 |
| DeepFloyd/t5-v1_1-xxl | 9.84E-05 |
| google/flan-t5-large | 4.55E-06 |
+
+
diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py
index bc54dcb729..c6222a059b 100644
--- a/mindone/transformers/__init__.py
+++ b/mindone/transformers/__init__.py
@@ -1,3 +1,6 @@
+__version__ = "4.42.4"
+
+
from .modeling_utils import MSPreTrainedModel
from .models.bert import (
BertForMaskedLM,
@@ -37,6 +40,7 @@
GemmaModel,
GemmaPreTrainedModel,
)
+from .models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
from .models.gemma2 import Gemma2Model, Gemma2PreTrainedModel
from .models.mt5 import (
MT5_PRETRAINED_MODEL_ARCHIVE_LIST,
diff --git a/mindone/transformers/activations.py b/mindone/transformers/activations.py
index 7309b5adbe..9ab47b67c9 100644
--- a/mindone/transformers/activations.py
+++ b/mindone/transformers/activations.py
@@ -14,7 +14,9 @@
import math
from collections import OrderedDict
+from functools import partial
+import mindspore as ms
from mindspore import Tensor, nn, ops
@@ -47,7 +49,7 @@ class GELUActivation(nn.Cell):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
- torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
+ ops.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * ops.pow(x, 3)))) This is now written in C in nn.functional
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
@@ -79,8 +81,12 @@ class QuickGELUActivation(nn.Cell):
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
"""
- def construct(self, input: Tensor) -> Tensor:
- return input * ops.sigmoid(1.702 * input)
+ def __init__(self):
+ super(QuickGELUActivation, self).__init__()
+ self.sigmoid = nn.Sigmoid()
+
+ def construct(self, input):
+ return input * self.sigmoid(1.702 * input)
class ClippedGELUActivation(nn.Cell):
@@ -93,7 +99,7 @@ class ClippedGELUActivation(nn.Cell):
initially created.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
- torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
+ ops.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * ops.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
"""
def __init__(self, min: float, max: float):
@@ -125,17 +131,17 @@ def construct(self, input: Tensor) -> Tensor:
return 0.5 * input * (1 + ops.tanh(self.precomputed_constant * (input + 0.044715 * ops.pow(input, 3))))
-class SiLUActivation(nn.Cell):
- """
- See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
- Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
- Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
- Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
- later.
- """
+class SiLUActivationFP32(nn.Cell):
+ def __init__(self):
+ super(SiLUActivationFP32, self).__init__()
+ self.sigmoid = nn.Sigmoid()
- def construct(self, input: Tensor) -> Tensor:
- return ops.silu(input)
+ def construct(self, x):
+ _dtype = x.dtype
+ x = x.to(ms.float32)
+ out = x * self.sigmoid(x)
+ out = out.to(_dtype)
+ return out
class MishActivation(nn.Cell):
@@ -189,7 +195,7 @@ def __getitem__(self, key):
ACT2CLS = {
- "gelu": GELUActivation,
+ "gelu": partial(nn.GELU, approximate=False),
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
"gelu_fast": FastGELUActivation,
"gelu_new": NewGELUActivation,
@@ -204,8 +210,8 @@ def __getitem__(self, key):
"relu2": ReLUSquaredActivation,
"relu6": nn.ReLU6,
"sigmoid": nn.Sigmoid,
- "silu": SiLUActivation,
- "swish": SiLUActivation,
+ "silu": SiLUActivationFP32,
+ "swish": SiLUActivationFP32,
"tanh": nn.Tanh,
}
ACT2FN = ClassInstantier(ACT2CLS)
diff --git a/mindone/transformers/cache_utils.py b/mindone/transformers/cache_utils.py
new file mode 100644
index 0000000000..af48673bfd
--- /dev/null
+++ b/mindone/transformers/cache_utils.py
@@ -0,0 +1,281 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+import mindspore as ms
+from mindspore import ops
+
+logger = logging.get_logger(__name__)
+
+
+def init_static_cache(config: PretrainedConfig, max_batch_size: int, max_cache_len: int, dtype=None):
+ max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
+ # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
+ head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+
+ dtype = dtype if dtype is not None else ms.float32
+ num_key_value_heads = (
+ config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+ )
+
+ key_value_cache: Tuple[Tuple[ms.Tensor, ms.Tensor]] = ()
+ cache_shape = (max_batch_size, num_key_value_heads, max_cache_len, head_dim)
+ for _layer_index in range(config.num_hidden_layers):
+ # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
+ # breaks when updating the cache.
+ new_layer_key_cache = ms.Tensor(np.zeros(cache_shape), dtype=dtype)
+ new_layer_value_cache = ms.Tensor(np.zeros(cache_shape), dtype=dtype)
+ key_value_cache += ((new_layer_key_cache, new_layer_value_cache),)
+
+ return key_value_cache
+
+
+# Notes: Only return the updated value, do not modifying the original `past_key_value` in-place !
+def update(
+ past_key_value: Tuple[ms.Tensor, ms.Tensor],
+ key_states: ms.Tensor,
+ value_states: ms.Tensor,
+ cache_position: Optional[ms.Tensor] = None,
+) -> Tuple[ms.Tensor, ms.Tensor]:
+ """
+ Notes: Only return the updated value, do not modifying the original `past_key_value` in-place !
+
+ Get the cache with the new `key_states` and `value_states` for cur layer.
+
+ Parameters:
+ past_key_value (`Tuple[ms.Tensor, ms.Tensor]`):
+ Past key/value states cache.
+ key_states (`ms.Tensor`):
+ The new key states to cache.
+ value_states (`ms.Tensor`):
+ The new value states to cache.
+ cache_position (`ms.Tensor`, `optional`):
+ Additional arguments for the cache subclass, needs the `cache_position` input
+ to know how where to write in the cache.
+
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ k_out, v_out = past_key_value[0], past_key_value[1]
+
+ if cache_position.shape[0] == 1:
+ k_out[:, :, cache_position] = key_states
+ v_out[:, :, cache_position] = value_states
+ else:
+ # assert cache_position.shape[0] == k_out.shape[2]
+
+ k_out = ops.select(
+ (ops.arange(k_out.shape[2]) == cache_position)[None, None, :, None],
+ key_states,
+ k_out,
+ )
+ v_out = ops.select(
+ (ops.arange(v_out.shape[2]) == cache_position)[None, None, :, None],
+ value_states,
+ v_out,
+ )
+
+ return k_out, v_out
+
+
+def get_seq_length(past_key_values, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states that were seen by the model."""
+ # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+ # limit the check to the first batch member and head dimension.
+ # TODO: deprecate this function in favor of `cache_position`
+ return (past_key_values[layer_idx][0][0, 0].any(axis=-1)).sum()
+
+
+def get_max_length(past_key_values) -> Optional[int]:
+ """Returns the maximum sequence length of the cached states."""
+ return past_key_values[0][0].shape[2]
+
+
+def reset(past_key_values):
+ """Resets the cache values while preserving the objects"""
+ for layer_idx in range(len(past_key_values)):
+ # In-place ops prevent breaking the static address
+ past_key_values[layer_idx][0] = ops.zeros_like(past_key_values[layer_idx][0]) # key
+ past_key_values[layer_idx][1] = ops.zeros_like(past_key_values[layer_idx][1]) # value
+
+ return past_key_values
+
+
+@dataclass
+class Cache:
+ """
+ Base, abstract class for all caches. The actual data structure is specific to each subclass.
+ """
+
+ def update(
+ self,
+ key_states: ms.Tensor,
+ value_states: ms.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[ms.Tensor, ms.Tensor]:
+ """
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+ Parameters:
+ key_states (`ms.Tensor`):
+ The new key states to cache.
+ value_states (`ms.Tensor`):
+ The new value states to cache.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
+ cache to be created.
+
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ raise NotImplementedError("Make sure to implement `update` in a subclass.")
+
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ # TODO: deprecate this function in favor of `cache_position`
+ raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
+
+ def get_max_length(self) -> Optional[int]:
+ """Returns the maximum sequence length of the cached states, if there is any."""
+ raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
+
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
+ # Cache without size limit -> all cache is usable
+ # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
+ # length, we will need to evict part of the cache (and thus not all cache is usable)
+ max_length = self.get_max_length()
+ previous_seq_length = self.get_seq_length(layer_idx)
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
+ return max_length - new_seq_length
+ return previous_seq_length
+
+ def reorder_cache(self, beam_idx: ms.Tensor):
+ """Reorders the cache for beam search, given the selected beam indices."""
+ for layer_idx in range(len(self.key_cache)):
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].gather(input_indices=beam_idx, axis=0)
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].gather(input_indices=beam_idx, axis=0)
+
+ @property
+ def seen_tokens(self):
+ logger.warning_once(
+ "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
+ "model input instead."
+ )
+ if hasattr(self, "_seen_tokens"):
+ return self._seen_tokens
+ else:
+ return None
+
+
+class StaticCache(Cache):
+ """
+ Static Cache class to be used with `static shape`.
+
+ Parameters:
+ config (`PretrainedConfig):
+ The configuration file defining the shape-related attributes required to initialize the static cache.
+ max_batch_size (`int`):
+ The maximum batch size with which the model will be used.
+ max_cache_len (`int`):
+ The maximum sequence length with which the model will be used.
+ dtype (*optional*, defaults to `ms.float32`):
+ The default `dtype` to use when initializing the layer.
+ """
+
+ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, dtype=None) -> None:
+ super().__init__()
+ self.max_batch_size = max_batch_size
+ self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
+ # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
+ self.head_dim = (
+ config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+ )
+
+ self.dtype = dtype if dtype is not None else ms.float32
+ self.num_key_value_heads = (
+ config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+ )
+
+ key_cache: List[ms.Parameter] = []
+ value_cache: List[ms.Parameter] = []
+ cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
+ for _layer_index in range(config.num_hidden_layers):
+ # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
+ # breaks when updating the cache.
+ new_layer_key_cache = ms.Parameter(
+ ms.Tensor(np.zeros(cache_shape), dtype=self.dtype),
+ name=f"key_cache_{_layer_index}",
+ requires_grad=False,
+ )
+ new_layer_value_cache = ms.Parameter(
+ ms.Tensor(np.zeros(cache_shape), dtype=self.dtype),
+ name=f"value_cache_{_layer_index}",
+ requires_grad=False,
+ )
+ key_cache.append(new_layer_key_cache)
+ value_cache.append(new_layer_value_cache)
+
+ self.key_cache = ms.ParameterTuple(key_cache)
+ self.value_cache = ms.ParameterTuple(value_cache)
+
+ def update(
+ self,
+ key_states: ms.Tensor,
+ value_states: ms.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[ms.Tensor, ms.Tensor]:
+ """
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+ Parameters:
+ key_states (`ms.Tensor`):
+ The new key states to cache.
+ value_states (`ms.Tensor`):
+ The new value states to cache.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
+ to know how where to write in the cache.
+
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ cache_position = cache_kwargs.get("cache_position")
+ k_out = self.key_cache[layer_idx]
+ v_out = self.value_cache[layer_idx]
+
+ k_out[:, :, cache_position] = key_states
+ v_out[:, :, cache_position] = value_states
+
+ # update to self.key_cache?
+ self.key_cache[layer_idx] = k_out
+ self.value_cache[layer_idx] = v_out
+
+ return k_out, v_out
+
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states that were seen by the model."""
+ # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+ # limit the check to the first batch member and head dimension.
+ # TODO: deprecate this function in favor of `cache_position`
+ return (self.key_cache[layer_idx][0, 0].any(axis=-1)).sum()
+
+ def get_max_length(self) -> Optional[int]:
+ """Returns the maximum sequence length of the cached states."""
+ return self.max_cache_len
+
+ def reset(self):
+ """Resets the cache values while preserving the objects"""
+ for layer_idx in range(len(self.key_cache)):
+ # In-place ops prevent breaking the static address
+ ops.assign(self.key_cache[layer_idx], ms.Tensor(0.0))
+ ops.assign(self.value_cache[layer_idx], ms.Tensor(0.0))
diff --git a/mindone/transformers/data/data_collator.py b/mindone/transformers/data/data_collator.py
new file mode 100644
index 0000000000..9de9640d46
--- /dev/null
+++ b/mindone/transformers/data/data_collator.py
@@ -0,0 +1,166 @@
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, NewType, Optional, Union
+
+import numpy as np
+from transformers import PreTrainedTokenizerBase
+from transformers.utils.generic import PaddingStrategy
+
+InputDataClass = NewType("InputDataClass", Any)
+
+"""
+A DataCollator is a function that takes a list of samples from a Dataset and collate them into a batch, as a dictionary
+of MindSpore/PyTorch/TensorFlow tensors or NumPy arrays.
+"""
+DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, Any]])
+
+
+class DataCollatorMixin:
+ def __call__(self, features, return_tensors=None):
+ if return_tensors is None:
+ return_tensors = self.return_tensors
+ if return_tensors == "ms":
+ return self.mindspore_call(features)
+ elif return_tensors == "tf":
+ return self.tf_call(features)
+ elif return_tensors == "pt":
+ return self.torch_call(features)
+ elif return_tensors == "np":
+ return self.numpy_call(features)
+ else:
+ raise ValueError(f"Framework '{return_tensors}' not recognized!")
+
+
+def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
+ """
+ Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
+ """
+
+ # To avoid errors when using Feature extractors
+ if not hasattr(tokenizer, "deprecation_warnings"):
+ return tokenizer.pad(*pad_args, **pad_kwargs)
+
+ # Save the state of the warning, then disable it
+ warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
+ tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+
+ try:
+ padded = tokenizer.pad(*pad_args, **pad_kwargs)
+ finally:
+ # Restore the state of the warning.
+ tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state
+
+ return padded
+
+
+def default_data_collator(features: List[InputDataClass], return_tensors="np") -> Dict[str, Any]:
+ """
+ Very simple data collator that simply collates batches of dict-like objects and performs special handling for
+ potential keys named:
+
+ - `label`: handles a single value (int or float) per object
+ - `label_ids`: handles a list of values per object
+
+ Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
+ to the model. See glue and ner for example of how it's useful.
+ """
+
+ # In this function we'll make the assumption that all `features` in the batch
+ # have the same attributes.
+ # So we will look at the first element as a proxy for what attributes exist
+ # on the whole batch.
+
+ if return_tensors == "ms":
+ raise NotImplementedError
+ elif return_tensors == "pt":
+ raise NotImplementedError
+ elif return_tensors == "tf":
+ raise NotImplementedError
+ elif return_tensors == "np":
+ return numpy_default_data_collator(features)
+ else:
+ raise ValueError
+
+
+def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
+ if not isinstance(features[0], Mapping):
+ features = [vars(f) for f in features]
+ first = features[0]
+ batch = {}
+
+ # Special handling for labels.
+ # Ensure that tensor is created with the correct type
+ # (it should be automatically the case, but let's make sure of it.)
+ if "label" in first and first["label"] is not None:
+ label = first["label"].item() if isinstance(first["label"], np.ndarray) else first["label"]
+ dtype = np.int64 if isinstance(label, int) else np.float32
+ batch["labels"] = np.array([f["label"] for f in features], dtype=dtype)
+ elif "label_ids" in first and first["label_ids"] is not None:
+ if isinstance(first["label_ids"], np.ndarray):
+ batch["labels"] = np.stack([f["label_ids"] for f in features])
+ else:
+ dtype = np.int64 if isinstance(first["label_ids"][0], int) else np.float32
+ batch["labels"] = np.array([f["label_ids"] for f in features], dtype=dtype)
+
+ # Handling of all other possible keys.
+ # Again, we will use the first element to figure out which key/values are not None for this model.
+ for k, v in first.items():
+ if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
+ if isinstance(v, np.ndarray):
+ batch[k] = np.stack([f[k] for f in features])
+ else:
+ batch[k] = np.array([f[k] for f in features])
+
+ return batch
+
+
+@dataclass
+class DataCollatorWithPadding:
+ """
+ Data collator that will dynamically pad the inputs received.
+
+ Args:
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
+ The tokenizer used for encoding the data.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+ among:
+
+ - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
+ sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value.
+
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+ 7.5 (Volta).
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
+ """
+
+ tokenizer: PreTrainedTokenizerBase
+ padding: Union[bool, str, PaddingStrategy] = True
+ max_length: Optional[int] = None
+ pad_to_multiple_of: Optional[int] = None
+ return_tensors: str = "np"
+
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+ batch = pad_without_fast_tokenizer_warning(
+ self.tokenizer,
+ features,
+ padding=self.padding,
+ max_length=self.max_length,
+ pad_to_multiple_of=self.pad_to_multiple_of,
+ return_tensors=self.return_tensors,
+ )
+ if "label" in batch:
+ batch["labels"] = batch["label"]
+ del batch["label"]
+ if "label_ids" in batch:
+ batch["labels"] = batch["label_ids"]
+ del batch["label_ids"]
+ return batch
diff --git a/mindone/transformers/debug_utils.py b/mindone/transformers/debug_utils.py
new file mode 100644
index 0000000000..f63c41e770
--- /dev/null
+++ b/mindone/transformers/debug_utils.py
@@ -0,0 +1,7 @@
+from transformers.utils.generic import ExplicitEnum
+
+
+class DebugOption(ExplicitEnum):
+ UNDERFLOW_OVERFLOW = "underflow_overflow"
+ NPU_METRICS_DEBUG = "npu_metrics_debug"
+ TPU_METRICS_DEBUG = "tpu_metrics_debug"
diff --git a/mindone/transformers/generation/__init__.py b/mindone/transformers/generation/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/mindone/transformers/generation/logits_process.py b/mindone/transformers/generation/logits_process.py
new file mode 100644
index 0000000000..d8404031af
--- /dev/null
+++ b/mindone/transformers/generation/logits_process.py
@@ -0,0 +1,446 @@
+import inspect
+from typing import Callable, List, Union
+
+import numpy as np
+from transformers.utils import add_start_docstrings
+from transformers.utils.logging import get_logger
+
+import mindspore as ms
+import mindspore.numpy as mnp
+from mindspore import ops
+
+from mindone.transformers.mindspore_adapter.utils import dtype_to_min
+
+INF = 1e5
+
+
+logger = get_logger(__name__)
+
+
+LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`ms.Tensor or numpy.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
+ scores (`ms.Tensor or numpy.ndarray` of shape `(batch_size, config.vocab_size)`):
+ Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
+ search or log softmax for each vocabulary token when using beam search
+
+ Return:
+ `ms.Tensor or numpy.ndarray` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
+
+"""
+
+
+class LogitsProcessor:
+ """Abstract base class for all logit processors that can be applied during generation."""
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(
+ self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray]
+ ) -> Union[ms.Tensor, np.ndarray]:
+ raise NotImplementedError(
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
+ )
+
+
+class LogitsWarper:
+ """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(
+ self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray]
+ ) -> Union[ms.Tensor, np.ndarray]:
+ raise NotImplementedError(
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
+ )
+
+
+class LogitsProcessorList(list):
+ """
+ This class can be used to create a list of [`LogitsProcessor`] to subsequently process a `scores` input tensor.
+ This class inherits from list and adds a specific *__call__* method to apply each [`LogitsProcessor`] to the
+ inputs.
+ """
+
+ def __call__(
+ self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray], **kwargs
+ ) -> Union[ms.Tensor, np.ndarray]:
+ r"""
+ Args:
+ input_ids (`Union[ms.Tensor, np.ndarray]` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
+ scores (`Union[ms.Tensor, np.ndarray]` of shape `(batch_size, config.vocab_size)`):
+ Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
+ beam search or log softmax for each vocabulary token when using beam search
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional kwargs that are specific to a logits processor.
+
+ Return:
+ `Union[ms.Tensor, np.ndarray]` of shape `(batch_size, config.vocab_size)`:
+ The processed prediction scores.
+
+ """
+ for processor in self:
+ function_args = inspect.signature(processor.__call__).parameters
+ if len(function_args) > 2:
+ if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
+ raise ValueError(
+ f"Make sure that all the required parameters: {list(function_args.keys())} for "
+ f"{processor.__class__} are passed to the logits processor."
+ )
+ scores = processor(input_ids, scores, **kwargs)
+ else:
+ scores = processor(input_ids, scores)
+
+ return scores
+
+
+class MinLengthLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0. Note that, for decoder-only models
+ like most LLMs, the length includes the prompt.
+
+ Args:
+ min_length (`int`):
+ The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
+ eos_token_id (`Union[int, List[int], ms.Tensor, np.ndarray]`):
+ The id(s) of the *end-of-sequence* token.
+ """
+
+ def __init__(self, min_length: int, eos_token_id: Union[int, List[int], ms.Tensor, np.ndarray], **ignore):
+ if not isinstance(min_length, int) or min_length < 0:
+ raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
+
+ # to list
+ if not isinstance(eos_token_id, ms.Tensor):
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ elif isinstance(eos_token_id, np.ndarray):
+ eos_token_id = eos_token_id.tolist()
+ else:
+ eos_token_id = eos_token_id.asnumpy().tolist()
+
+ self.min_length = min_length
+ self.eos_token_id = eos_token_id
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(
+ self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray]
+ ) -> Union[ms.Tensor, np.ndarray]:
+ if isinstance(scores, ms.Tensor):
+ vocab_tensor = ops.arange(0, scores.shape[-1])
+ eos_token_mask = mnp.isin(vocab_tensor, self.eos_token_id)
+ scores_processed = scores[:]
+ if input_ids.shape[-1] < self.min_length:
+ scores_processed = ops.where(eos_token_mask, -INF, scores)
+ elif isinstance(scores, np.ndarray):
+ vocab_tensor = np.arange(0, scores.shape[-1])
+ eos_token_mask = np.isin(vocab_tensor, self.eos_token_id)
+ scores_processed = scores[:]
+ if input_ids.shape[-1] < self.min_length:
+ scores_processed = ops.where(eos_token_mask, -INF, scores)
+ else:
+ raise NotImplementedError
+
+ return scores_processed
+
+
+class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0.
+ Contrarily to [`MinLengthLogitsProcessor`], this processor ignores the prompt.
+
+ Args:
+ prompt_length_to_skip (`int`):
+ The input tokens length. Not a valid argument when used with `generate` as it will automatically assign the
+ input length.
+ min_new_tokens (`int`):
+ The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
+ eos_token_id (`Union[int, List[int], ms.Tensor]`):
+ The id(s) of the *end-of-sequence* token.
+ """
+
+ def __init__(
+ self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int], ms.Tensor], **ignore
+ ):
+ for arg_name, arg_value in [
+ ("prompt_length_to_skip", prompt_length_to_skip),
+ ("min_new_tokens", min_new_tokens),
+ ]:
+ if not isinstance(arg_value, int) or arg_value < 0:
+ raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")
+
+ # to list
+ if not isinstance(eos_token_id, ms.Tensor):
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ elif isinstance(eos_token_id, np.ndarray):
+ eos_token_id = eos_token_id.tolist()
+ else:
+ eos_token_id = eos_token_id.asnumpy().tolist()
+
+ self.prompt_length_to_skip = prompt_length_to_skip
+ self.min_new_tokens = min_new_tokens
+ self.eos_token_id = eos_token_id
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(
+ self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray]
+ ) -> Union[ms.Tensor, np.ndarray]:
+ if isinstance(scores, ms.Tensor):
+ new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
+ scores_processed = scores[:]
+ vocab_tensor = ops.arange(0, scores.shape[-1])
+ eos_token_mask = mnp.isin(vocab_tensor, self.eos_token_id)
+ if new_tokens_length < self.min_new_tokens:
+ scores_processed = ops.where(eos_token_mask, -INF, scores)
+ elif isinstance(scores, np.ndarray):
+ new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
+ scores_processed = scores[:]
+ vocab_tensor = np.arange(0, scores.shape[-1])
+ eos_token_mask = np.isin(vocab_tensor, self.eos_token_id)
+ if new_tokens_length < self.min_new_tokens:
+ scores_processed = np.where(eos_token_mask, -INF, scores)
+ else:
+ raise NotImplementedError
+
+ return scores_processed
+
+
+class TemperatureLogitsWarper(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] for temperature (exponential scaling output probability distribution), which effectively means
+ that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and
+ [`TopKLogitsWarper`].
+
+
+
+ Make sure that `do_sample=True` is included in the `generate` arguments otherwise the temperature value won't have
+ any effect.
+
+
+
+ Args:
+ temperature (`float`):
+ Strictly positive float value used to modulate the logits distribution. A value smaller than `1` decreases
+ randomness (and vice versa), with `0` being equivalent to shifting all probability mass to the most likely
+ token.
+ """
+
+ def __init__(self, temperature: float):
+ if not isinstance(temperature, float) or not (temperature > 0):
+ except_msg = (
+ f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
+ "scores will be invalid."
+ )
+ if isinstance(temperature, float) and temperature == 0.0:
+ except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`."
+ raise ValueError(except_msg)
+
+ self.temperature = temperature
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(
+ self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray]
+ ) -> Union[ms.Tensor, np.ndarray]:
+ scores_processed = scores / self.temperature
+ return scores_processed
+
+
+class TopPLogitsWarper(LogitsWarper):
+ """
+ [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. Often
+ used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
+
+ Args:
+ top_p (`float`):
+ If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+ higher are kept for generation.
+ filter_value (`float`, *optional*, defaults to -inf):
+ All filtered values will be set to this float value.
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
+ Minimum number of tokens that cannot be filtered.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer, set_seed
+ >>> from mindone.transformers.models.llama import LlamaForCausalLM
+
+ >>> set_seed(1)
+ >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
+
+ >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="np")
+
+ >>> # With sampling, the output is unexpected -- sometimes too unexpected.
+ >>> outputs = model.generate(**inputs, do_sample=True)
+ >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
+ A sequence: 1, 2, 3 | < 4 (left-hand pointer) ;
+
+
+
+ >>> # With `top_p` sampling, the output gets restricted to high-probability tokens.
+ >>> # Pro tip: In practice, LLMs use `top_p` in the 0.9-0.95 range.
+ >>> outputs = model.generate(**inputs, do_sample=True, top_p=0.1)
+ >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
+ A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
+ ```
+ """
+
+ def __init__(self, top_p: float, filter_value: float = None, min_tokens_to_keep: int = 1):
+ top_p = float(top_p)
+ if top_p < 0 or top_p > 1.0:
+ raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
+ if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
+ raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
+
+ self.top_p = top_p
+ self.filter_value = filter_value
+ self.min_tokens_to_keep = min_tokens_to_keep
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(
+ self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray]
+ ) -> Union[ms.Tensor, np.ndarray]:
+ if isinstance(scores, ms.Tensor):
+ filter_value = self.filter_value if self.filter_value is not None else dtype_to_min(scores.dtype)
+
+ sorted_logits, sorted_indices = ops.sort(scores, descending=False)
+ cumulative_probs = sorted_logits.softmax(axis=-1).cumsum(axis=-1)
+
+ # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+ sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
+ # Keep at least min_tokens_to_keep
+ sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
+
+ # scatter sorted tensors to original indexing
+ # indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+ sorted_indices_to_remove = sorted_indices_to_remove.astype(ms.int32)
+ indices_to_remove = ops.tensor_scatter_elements(
+ sorted_indices_to_remove, indices=sorted_indices, updates=sorted_indices_to_remove, axis=1
+ )
+
+ scores_processed = scores.masked_fill(indices_to_remove.astype(ms.bool_), filter_value)
+ elif isinstance(scores, np.ndarray):
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+
+ return scores_processed
+
+
+class TopKLogitsWarper(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements. Often used
+ together with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
+
+ Args:
+ top_k (`int`):
+ The number of highest probability vocabulary tokens to keep for top-k-filtering.
+ filter_value (`float`, *optional*, defaults to -inf):
+ All filtered values will be set to this float value.
+ min_tokens_to_keep (`int`, *optional*, defaults to 1):
+ Minimum number of tokens that cannot be filtered.
+ """
+
+ def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
+ if not isinstance(top_k, int) or top_k <= 0:
+ raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
+
+ self.top_k = max(top_k, min_tokens_to_keep)
+ self.filter_value = filter_value
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(
+ self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray]
+ ) -> Union[ms.Tensor, np.ndarray]:
+ if isinstance(scores, ms.Tensor):
+ top_k = min(self.top_k, scores.shape[-1]) # Safety check
+ # Remove all tokens with a probability less than the last token of the top-k
+ indices_to_remove = scores < ops.topk(scores, top_k)[0][..., -1, None]
+ scores_processed = scores.masked_fill(indices_to_remove, self.filter_value)
+ elif isinstance(scores, np.ndarray):
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+ return scores_processed
+
+
+class PrefixConstrainedLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] that enforces constrained generation and is useful for prefix-conditioned constrained
+ generation. See [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904) for more information.
+
+ Args:
+ prefix_allowed_tokens_fn (`Callable[[int, ms.Tensor], List[int]]`):
+ This function constraints the beam search to allowed tokens only at each step. This function takes 2
+ arguments `inputs_ids` and the batch ID `batch_id`. It has to return a list with the allowed tokens for the
+ next generation step conditioned on the previously generated tokens `inputs_ids` and the batch ID
+ `batch_id`.
+ """
+
+ def __init__(self, prefix_allowed_tokens_fn: Callable[[int, ms.Tensor], List[int]], num_beams: int):
+ self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
+ self._num_beams = num_beams
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(
+ self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray]
+ ) -> Union[ms.Tensor, np.ndarray]:
+ if isinstance(input_ids, ms.Tensor):
+ assert isinstance(scores, ms.Tensor)
+ mask = ops.full_like(scores, -INF)
+ for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
+ for beam_id, sent in enumerate(beam_sent):
+ prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
+ if len(prefix_allowed_tokens) == 0:
+ raise ValueError(
+ f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
+ f"This means that the constraint is unsatisfiable. Please check your implementation"
+ f"of `prefix_allowed_tokens_fn` "
+ )
+ mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0
+ elif isinstance(input_ids, np.ndarray):
+ assert isinstance(scores, np.ndarray)
+ mask = np.full_like(scores, -INF)
+ for batch_id, beam_sent in enumerate(input_ids.reshape((-1, self._num_beams, input_ids.shape[-1]))):
+ for beam_id, sent in enumerate(beam_sent):
+ prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
+ if len(prefix_allowed_tokens) == 0:
+ raise ValueError(
+ f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
+ f"This means that the constraint is unsatisfiable. Please check your implementation"
+ f"of `prefix_allowed_tokens_fn` "
+ )
+ mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0
+
+ else:
+ raise NotImplementedError
+
+ scores_processed = scores + mask
+ return scores_processed
+
+
+class LogitNormalization(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
+ the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in
+ this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that
+ the scores are normalized when comparing the hypotheses.
+
+ """
+
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+ def __call__(
+ self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray]
+ ) -> Union[ms.Tensor, np.ndarray]:
+ if isinstance(scores, ms.Tensor):
+ scores_processed = ops.log_softmax(scores.to(ms.float32), axis=-1).to(scores.dtype)
+ elif isinstance(scores, np.ndarray):
+ exp_scores = np.exp(scores)
+ scores_processed = np.log(exp_scores / exp_scores.sum(-1))
+ else:
+ raise NotImplementedError
+
+ return scores_processed
diff --git a/mindone/transformers/generation/stopping_criteria.py b/mindone/transformers/generation/stopping_criteria.py
new file mode 100644
index 0000000000..a8b5e402c1
--- /dev/null
+++ b/mindone/transformers/generation/stopping_criteria.py
@@ -0,0 +1,181 @@
+import time
+from abc import ABC
+from collections import OrderedDict
+from typing import List, Optional, Union
+
+import numpy as np
+from transformers.utils import add_start_docstrings, logging
+
+import mindspore as ms
+import mindspore.numpy as mnp
+from mindspore import ops
+
+logger = logging.get_logger(__name__)
+# We maintain a module-level cache of the embedding vectors for the stop string criterion
+# because they are slow to compute
+STOP_STRING_EMBEDDING_CACHE = OrderedDict()
+
+
+STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`Union[ms.Tensor, numpy.ndarray]` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ scores (`Union[ms.Tensor, numpy.ndarray]` of shape `(batch_size, config.vocab_size)`):
+ Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
+ or scores for each vocabulary token after SoftMax. If this stopping criteria depends on the `scores` input,
+ make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional stopping criteria specific kwargs.
+
+ Return:
+ `Union[ms.Tensor, numpy.ndarray]`. (`Union[ms.Tensor, numpy.ndarray]` of shape `(batch_size, 1)`), where `True` indicates we stop generation
+ for a particular row, `True` indicates we should continue.
+
+"""
+
+
+class StoppingCriteria(ABC):
+ """Abstract base class for all stopping criteria that can be applied during generation.
+
+ If your stopping criteria depends on the `scores` input, make sure you pass `return_dict_in_generate=True,
+ output_scores=True` to `generate`.
+ """
+
+ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
+ def __call__(self, input_ids: ms.Tensor, scores: ms.Tensor, **kwargs) -> ms.Tensor:
+ raise NotImplementedError("StoppingCriteria needs to be subclassed")
+
+
+class MaxLengthCriteria(StoppingCriteria):
+ """
+ This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. Keep
+ in mind for decoder-only type of transformers, this will include the initial prompted tokens.
+
+ Args:
+ max_length (`int`):
+ The maximum length that the output sequence can have in number of tokens.
+ max_position_embeddings (`int`, *optional*):
+ The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
+ """
+
+ def __init__(self, max_length: int, max_position_embeddings: Optional[int] = None):
+ self.max_length = max_length
+ self.max_position_embeddings = max_position_embeddings
+
+ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
+ def __call__(
+ self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray], **kwargs
+ ) -> Union[ms.Tensor, np.ndarray]:
+ cur_len = input_ids.shape[-1]
+ is_done = cur_len >= self.max_length
+ if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
+ logger.warning_once(
+ "This is a friendly reminder - the current text generation call will exceed the model's predefined "
+ f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
+ "exceptions, performance degradation, or nothing at all."
+ )
+
+ if isinstance(input_ids, ms.Tensor):
+ return ops.full((input_ids.shape[0],), is_done, dtype=ms.bool_)
+ elif isinstance(input_ids, np.ndarray):
+ return np.full((input_ids.shape[0],), is_done, dtype=np.bool_)
+ else:
+ raise NotImplementedError
+
+
+class MaxTimeCriteria(StoppingCriteria):
+ """
+ This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
+ time will start being counted when you initialize this function. You can override this by passing an
+ `initial_time`.
+
+ Args:
+ max_time (`float`):
+ The maximum allowed time in seconds for the generation.
+ initial_time (`float`, *optional*, defaults to `time.time()`):
+ The start of the generation allowed time.
+ """
+
+ def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
+ self.max_time = max_time
+ self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
+
+ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
+ def __call__(
+ self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray], **kwargs
+ ) -> Union[ms.Tensor, np.ndarray]:
+ is_done = time.time() - self.initial_timestamp > self.max_time
+
+ if isinstance(input_ids, ms.Tensor):
+ return ops.full((input_ids.shape[0],), is_done, dtype=ms.bool_)
+ elif isinstance(input_ids, np.ndarray):
+ return np.full((input_ids.shape[0],), is_done, dtype=np.bool_)
+ else:
+ raise NotImplementedError
+
+
+class EosTokenCriteria(StoppingCriteria):
+ """
+ This class can be used to stop generation whenever the "end-of-sequence" token is generated.
+ By default, it uses the `model.generation_config.eos_token_id`.
+
+ Args:
+ eos_token_id (`Union[int, List[int], ms.Tensor]`):
+ The id(s) of the *end-of-sequence* token.
+ """
+
+ def __init__(self, eos_token_id: Union[int, List[int], ms.Tensor]):
+ # to list
+ if not isinstance(eos_token_id, ms.Tensor):
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ elif isinstance(eos_token_id, np.ndarray):
+ eos_token_id = eos_token_id.tolist()
+ else:
+ eos_token_id = eos_token_id.asnumpy().tolist()
+
+ self.eos_token_id = eos_token_id
+
+ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
+ def __call__(
+ self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray], **kwargs
+ ) -> Union[ms.Tensor, np.ndarray]:
+ if isinstance(input_ids, ms.Tensor):
+ is_done = mnp.isin(input_ids[:, -1], self.eos_token_id)
+ elif isinstance(input_ids, np.ndarray):
+ is_done = np.isin(input_ids[:, -1], self.eos_token_id)
+ else:
+ raise NotImplementedError
+
+ return is_done
+
+
+class StoppingCriteriaList(list):
+ @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
+ def __call__(
+ self, input_ids: Union[ms.Tensor, np.ndarray], scores: Union[ms.Tensor, np.ndarray], **kwargs
+ ) -> Union[ms.Tensor, np.ndarray]:
+ if isinstance(input_ids, ms.Tensor):
+ is_done = ops.full((input_ids.shape[0],), False, dtype=ms.bool_)
+ for criteria in self:
+ is_done = ops.logical_or(is_done, criteria(input_ids, scores, **kwargs))
+ elif isinstance(input_ids, np.ndarray):
+ is_done = np.full((input_ids.shape[0],), False, dtype=np.bool_)
+ for criteria in self:
+ is_done = np.logical_or(is_done, criteria(input_ids, scores, **kwargs))
+ else:
+ raise NotImplementedError
+
+ return is_done
+
+ @property
+ def max_length(self) -> Optional[int]:
+ for stopping_criterium in self:
+ if isinstance(stopping_criterium, MaxLengthCriteria):
+ return stopping_criterium.max_length
+ return None
diff --git a/mindone/transformers/generation/utils.py b/mindone/transformers/generation/utils.py
new file mode 100644
index 0000000000..1cb756aa92
--- /dev/null
+++ b/mindone/transformers/generation/utils.py
@@ -0,0 +1,1783 @@
+import copy
+import inspect
+import time
+import warnings
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+from transformers import logging
+from transformers.generation.configuration_utils import GenerationConfig, GenerationMode
+from transformers.generation.utils import GenerateNonBeamOutput
+from transformers.tokenization_utils import ExtensionsTrie
+from transformers.utils.generic import ModelOutput
+
+import mindspore as ms
+import mindspore.numpy as mnp
+from mindspore import ops
+
+from mindone.transformers.cache_utils import Cache, get_seq_length, init_static_cache, reset
+from mindone.transformers.generation.logits_process import (
+ LogitNormalization,
+ LogitsProcessorList,
+ MinLengthLogitsProcessor,
+ MinNewTokensLengthLogitsProcessor,
+ PrefixConstrainedLogitsProcessor,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+)
+from mindone.transformers.generation.stopping_criteria import (
+ EosTokenCriteria,
+ MaxLengthCriteria,
+ MaxTimeCriteria,
+ StoppingCriteria,
+ StoppingCriteriaList,
+)
+from mindone.transformers.modeling_outputs import CausalLMOutputWithPast
+
+if TYPE_CHECKING:
+ from transformers.generation.streamers import BaseStreamer
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
+ from mindone.transformers.modeling_utils import MSPreTrainedModel as PreTrainedModel
+
+logger = logging.get_logger(__name__)
+
+
+NEED_SETUP_CACHE_CLASSES_MAPPING = {}
+QUANT_BACKEND_CLASSES_MAPPING = {}
+
+
+@dataclass
+class GenerateDecoderOnlyOutput(ModelOutput):
+ """
+ Outputs of decoder-only generation models, when using non-beam methods.
+
+ Args:
+ sequences (`ms.Tensor` of shape `(batch_size, sequence_length)`):
+ The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
+ if all batches finished early due to the `eos_token_id`.
+ scores (`tuple(ms.Tensor)` *optional*, returned when `output_scores=True`):
+ Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
+ at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
+ each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
+ logits (`tuple(ms.Tensor)` *optional*, returned when `output_logits=True`):
+ Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
+ at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
+ each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
+ attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`):
+ Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ `ms.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
+ hidden_states (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_hidden_states=True`):
+ Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ `ms.Tensor` of shape `(batch_size, generated_length, hidden_size)`.
+ past_key_values (`tuple(tuple(mindspore.Tensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
+ Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+ tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+ encoder_sequence_length, embed_size_per_head)`.
+ """
+
+ sequences: ms.Tensor = None
+ scores: Optional[Tuple[ms.Tensor]] = None
+ logits: Optional[Tuple[ms.Tensor]] = None
+ attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ hidden_states: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ past_key_values: Optional[Tuple[Tuple[Tuple[ms.Tensor]]]] = None
+
+
+@dataclass
+class GenerateEncoderDecoderOutput(ModelOutput):
+ """
+ Outputs of encoder-decoder generation models, when using non-beam methods.
+
+ Args:
+ sequences (`ms.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
+ The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
+ if all batches finished early due to the `eos_token_id`.
+ scores (`tuple(ms.Tensor)` *optional*, returned when `output_scores=True`):
+ Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
+ at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
+ each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
+ logits (`tuple(ms.Tensor)` *optional*, returned when `output_logits=True`):
+ Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
+ at each generation step. Tuple of `ms.Tensor` with up to `max_new_tokens` elements (one element for
+ each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
+ encoder_attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True`):
+ Tuple of `ms.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
+ sequence_length, sequence_length)`.
+ encoder_hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True`):
+ Tuple of `ms.Tensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+ decoder_attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`):
+ Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ `ms.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
+ cross_attentions (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_attentions=True`):
+ Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ `ms.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
+ decoder_hidden_states (`tuple(tuple(ms.Tensor))`, *optional*, returned when `output_hidden_states=True`):
+ Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
+ `ms.Tensor` of shape `(batch_size, generated_length, hidden_size)`.
+ past_key_values (`tuple(tuple(mindspore.Tensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
+ Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
+ tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+ encoder_sequence_length, embed_size_per_head)`.
+ """
+
+ sequences: ms.Tensor = None
+ scores: Optional[Tuple[ms.Tensor]] = None
+ logits: Optional[Tuple[ms.Tensor]] = None
+ encoder_attentions: Optional[Tuple[ms.Tensor]] = None
+ encoder_hidden_states: Optional[Tuple[ms.Tensor]] = None
+ decoder_attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ cross_attentions: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ decoder_hidden_states: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ past_key_values: Optional[Tuple[Tuple[Tuple[ms.Tensor]]]] = None
+
+
+class GenerationMixin:
+ def prepare_inputs_for_generation(self, *args, **kwargs):
+ raise NotImplementedError(
+ "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
+ )
+
+ def _prepare_generation_config(
+ self, generation_config: Optional[GenerationConfig], **kwargs: Dict
+ ) -> Tuple[GenerationConfig, Dict]:
+ """
+ Prepares the base generation config, then applies any generation configuration options from kwargs.
+ """
+
+ # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+ if generation_config is None:
+ # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
+ # three conditions must be met
+ # 1) the generation config must have been created from the model config (`_from_model_config` field);
+ # 2) the generation config must have seen no modification since its creation (the hash is the same);
+ # 3) the user must have set generation parameters in the model config.
+ if (
+ self.generation_config._from_model_config
+ and self.generation_config._original_object_hash == hash(self.generation_config)
+ and self.config._has_non_default_generation_parameters()
+ ):
+ new_generation_config = GenerationConfig.from_model_config(self.config)
+ if new_generation_config != self.generation_config:
+ warnings.warn(
+ "You have modified the pretrained model configuration to control generation. This is a"
+ " deprecated strategy to control generation and will be removed soon, in a future version."
+ " Please use and modify the model generation configuration (see"
+ " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
+ )
+ self.generation_config = new_generation_config
+ generation_config = self.generation_config
+
+ # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled.
+ strict_kwargs = kwargs.pop("strict_kwargs", False)
+ if strict_kwargs:
+ model_kwargs = kwargs
+ generate_attributes_in_kwargs = [
+ key for key, value in kwargs.items() if getattr(generation_config, key, None) != value
+ ]
+ if len(generate_attributes_in_kwargs) > 0:
+ raise ValueError(
+ "exception: all generation configuration attributes must be passed within a "
+ f"`generation_config` instance passed to `generate` (found: {generate_attributes_in_kwargs})."
+ )
+ else:
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs)
+
+ return generation_config, model_kwargs
+
+ def _prepare_model_inputs(
+ self,
+ inputs: Optional[ms.Tensor] = None,
+ bos_token_id: Optional[ms.Tensor] = None,
+ model_kwargs: Optional[Dict[str, ms.Tensor]] = None,
+ ) -> Tuple[ms.Tensor, Optional[str], Dict[str, ms.Tensor]]:
+ """
+ This function extracts the model-specific `inputs` for generation.
+ """
+ # 1. retrieve all kwargs that are non-None or non-model input related.
+ # some encoder-decoder models have different names for model and encoder
+ if (
+ self.config.is_encoder_decoder
+ and hasattr(self, "encoder")
+ and self.encoder.main_input_name != self.main_input_name
+ ):
+ input_name = self.encoder.main_input_name
+ else:
+ input_name = self.main_input_name
+
+ model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
+
+ # 2. check whether model_input_name is passed as kwarg
+ # if yes and `inputs` is None use kwarg inputs
+ inputs_kwarg = model_kwargs.pop(input_name, None)
+ if inputs_kwarg is not None and inputs is not None:
+ raise ValueError(
+ f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. "
+ f"Make sure to either pass {inputs} or {input_name}=..."
+ )
+ elif inputs_kwarg is not None:
+ inputs = inputs_kwarg
+
+ # 3. In the presence of `inputs_embeds` for text models:
+ # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
+ # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
+ # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
+ # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
+ # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
+ if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
+ if not self.config.is_encoder_decoder:
+ has_inputs_embeds_forwarding = "inputs_embeds" in set(
+ inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
+ )
+ if not has_inputs_embeds_forwarding:
+ raise ValueError(
+ f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
+ "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
+ "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
+ )
+ # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
+ # the attention mask) can rely on the actual model input.
+ model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
+ inputs, bos_token_id, model_kwargs=model_kwargs
+ )
+ else:
+ if inputs is not None:
+ raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
+ inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
+
+ # 4. if `inputs` is still None, try to create `input_ids` from BOS token
+ inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
+ return inputs, input_name, model_kwargs
+
+ def _maybe_initialize_input_ids_for_generation(
+ self,
+ inputs: Optional[ms.Tensor] = None,
+ bos_token_id: Optional[ms.Tensor] = None,
+ model_kwargs: Optional[Dict[str, ms.Tensor]] = None,
+ ) -> ms.Tensor:
+ """Initializes input ids for generation, if necessary."""
+ if inputs is not None:
+ return inputs
+
+ encoder_outputs = model_kwargs.get("encoder_outputs")
+ if self.config.is_encoder_decoder and encoder_outputs is not None:
+ # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
+ shape = encoder_outputs.last_hidden_state.shape[:-1]
+ return ops.ones(shape, dtype=ms.int32) * -100
+
+ # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
+ # soft-prompting or in multimodal implementations built on top of decoder-only language models.
+ batch_size = 1
+ for value in model_kwargs.values():
+ if isinstance(value, ms.Tensor):
+ batch_size = value.shape[0]
+ break
+
+ if "inputs_embeds" in model_kwargs:
+ return ops.ones((batch_size, 0), dtype=ms.int32)
+
+ if bos_token_id is None:
+ raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
+
+ return ops.ones((batch_size, 1), dtype=ms.int32) * bos_token_id
+
+ def _prepare_attention_mask_for_generation(
+ self,
+ inputs: ms.Tensor,
+ pad_token_id: Optional[ms.Tensor],
+ eos_token_id: Optional[ms.Tensor],
+ ) -> ms.Tensor:
+ # No information for attention mask inference -> return default attention mask
+ default_attention_mask = ops.ones(inputs.shape[:2], dtype=ms.int32)
+ if pad_token_id is None:
+ return default_attention_mask
+
+ is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [ms.int32, ms.int64]
+ if not is_input_ids:
+ return default_attention_mask
+
+ is_pad_token_in_inputs = (pad_token_id is not None) and (
+ mnp.isin(element=inputs, test_elements=pad_token_id).any()
+ )
+ is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
+ mnp.isin(element=eos_token_id, test_elements=pad_token_id).any()
+ )
+ can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
+ attention_mask_from_padding = inputs.ne(pad_token_id).to(ms.int32)
+
+ attention_mask = (
+ attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask
+ )
+ return attention_mask
+
+ def _prepare_encoder_decoder_kwargs_for_generation(
+ self,
+ inputs_tensor: ms.Tensor,
+ model_kwargs,
+ model_input_name: Optional[str],
+ generation_config: GenerationConfig,
+ ) -> Dict[str, Any]:
+ # 1. get encoder
+ encoder = self.get_encoder()
+
+ # 2. Prepare encoder args and encoder kwargs from model kwargs and generation config.
+ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+ encoder_kwargs = {
+ argument: value
+ for argument, value in model_kwargs.items()
+ if not any(argument.startswith(p) for p in irrelevant_prefix)
+ }
+ encoder_signature = set(inspect.signature(encoder.forward).parameters)
+ encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+ if not encoder_accepts_wildcard:
+ encoder_kwargs = {
+ argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
+ }
+ encoder_kwargs["output_attentions"] = generation_config.output_attentions
+ encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+
+ # 3. make sure that encoder returns `ModelOutput`
+ model_input_name = model_input_name if model_input_name is not None else self.main_input_name
+ encoder_kwargs["return_dict"] = True
+ encoder_kwargs[model_input_name] = inputs_tensor
+ model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs)
+
+ return model_kwargs
+
+ def _prepare_decoder_input_ids_for_generation(
+ self,
+ batch_size: int,
+ model_input_name: str,
+ model_kwargs: Dict[str, ms.Tensor],
+ decoder_start_token_id: ms.Tensor,
+ **ignore_kwargs,
+ ) -> Tuple[ms.Tensor, Dict[str, ms.Tensor]]:
+ """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
+ # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
+ # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
+ if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+ decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+ elif "input_ids" in model_kwargs and model_input_name != "input_ids":
+ decoder_input_ids = model_kwargs.pop("input_ids")
+ else:
+ decoder_input_ids = None
+
+ # 2. `decoder_start_token_id` must have shape (batch_size, 1)
+ if decoder_start_token_id.ndim == 1:
+ if decoder_start_token_id.shape[0] != batch_size:
+ raise ValueError(
+ f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}"
+ )
+ decoder_start_token_id = decoder_start_token_id.view(-1, 1)
+ else:
+ decoder_start_token_id = ops.ones((batch_size, 1), dtype=ms.int32) * decoder_start_token_id
+
+ # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
+ # no user input -> use decoder_start_token_id as decoder_input_ids
+ if decoder_input_ids is None:
+ decoder_input_ids = decoder_start_token_id
+ # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the
+ # original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic.
+ # See: https://github.com/huggingface/transformers/pull/31470
+ elif "donut" in self.__class__.__name__.lower() or (
+ self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower()
+ ):
+ pass
+ elif self.config.model_type in ["whisper"]:
+ pass
+ # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
+ # decoder_attention_mask if provided)
+ elif (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item():
+ decoder_input_ids = ops.cat([decoder_start_token_id, decoder_input_ids], axis=-1)
+ if "decoder_attention_mask" in model_kwargs:
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+ decoder_attention_mask = ops.cat(
+ (ops.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
+ axis=-1,
+ )
+ model_kwargs["decoder_attention_mask"] = decoder_attention_mask
+
+ return decoder_input_ids, model_kwargs
+
+ def _get_initial_cache_position(self, input_ids, model_kwargs):
+ """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
+ if not model_kwargs.get("use_cache", True):
+ model_kwargs["cache_position"] = None
+ return model_kwargs
+
+ past_length = 0
+ if model_kwargs.get("past_key_values") is not None:
+ cache = model_kwargs["past_key_values"]
+ if isinstance(cache, Tuple):
+ past_length = get_seq_length(cache)
+
+ if model_kwargs.get("attention_mask", None) is not None:
+ attention_mask = model_kwargs["attention_mask"]
+ if "inputs_embeds" in model_kwargs:
+ max_len = model_kwargs["inputs_embeds"].shape[1]
+ else:
+ max_len = input_ids.shape[-1]
+ cur_len = int(attention_mask.sum(-1).max())
+
+ cache_position = ops.arange(past_length, cur_len, dtype=ms.int32)
+ if (cur_len - past_length) < max_len:
+ cache_position = ops.cat([cache_position, ops.zeros(max_len - (cur_len - past_length), ms.int32)])
+ else:
+ if "inputs_embeds" in model_kwargs:
+ cur_len = model_kwargs["inputs_embeds"].shape[1]
+ else:
+ cur_len = input_ids.shape[-1]
+
+ cache_position = ops.arange(past_length, cur_len, ms.int32)
+
+ model_kwargs["cache_position"] = cache_position
+ return model_kwargs
+
+ def _get_cache(
+ self, cache_implementation: str, max_batch_size: int, max_cache_len: int
+ ) -> Tuple[Tuple[ms.Tensor, ms.Tensor]]:
+ """
+ Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
+ new `generate` call requires a larger cache.
+
+ Returns the resulting cache object.
+ """
+ if cache_implementation == "sliding_window":
+ max_cache_len = min(self.config.sliding_window, max_cache_len)
+
+ need_new_cache = (
+ not hasattr(self, "_cache")
+ or (not isinstance(self._cache, tuple))
+ or (not isinstance(self._cache[0][0], ms.Tensor))
+ or self._cache[0][0].shape[0] != max_batch_size
+ or self._cache[0][0].shape[2] < max_cache_len
+ )
+
+ if need_new_cache:
+ if hasattr(self.config, "_pre_quantization_dtype"):
+ cache_dtype = self.config._pre_quantization_dtype
+ else:
+ cache_dtype = self.dtype
+
+ self._cache = init_static_cache(
+ config=self.config,
+ max_batch_size=max_batch_size,
+ max_cache_len=max_cache_len,
+ dtype=cache_dtype,
+ )
+ else:
+ self._cache = reset(self._cache)
+
+ return self._cache
+
+ def _supports_default_dynamic_cache(self) -> bool:
+ """
+ Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
+ This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which
+ uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in
+ order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed
+ for `HybridMambaAttentionDynamicCache`).
+ """
+ return self._supports_cache_class and "jamba" not in self.__class__.__name__.lower()
+
+ def _prepare_special_tokens(
+ self, generation_config: GenerationConfig, kwargs_has_attention_mask: Optional[bool] = None, **ignore_kwargs
+ ):
+ """
+ Prepares the special tokens for generation, overwriting the generation config with their processed versions
+ converted to tensor.
+
+ Note that `generation_config` is changed in place and stops being serializable after this method is called.
+ That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
+ function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
+ """
+
+ # Convert special tokens to tensors (if they exist either in kwargs or in self.config)
+ def _tensor_or_none(token_kwargs, token_self):
+ token = token_kwargs if token_kwargs is not None else token_self
+ if token is None or isinstance(token, ms.Tensor):
+ return token
+ return ms.Tensor(token, dtype=ms.int32)
+
+ bos_token_id = _tensor_or_none(generation_config.bos_token_id, self.generation_config.bos_token_id)
+ eos_token_id = _tensor_or_none(generation_config.eos_token_id, self.generation_config.eos_token_id)
+ pad_token_id = _tensor_or_none(generation_config.pad_token_id, self.generation_config.pad_token_id)
+ decoder_start_token_id = _tensor_or_none(
+ generation_config.decoder_start_token_id, self.generation_config.decoder_start_token_id
+ )
+
+ # for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892)
+ if self.config.is_encoder_decoder:
+ decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
+
+ # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
+ if eos_token_id is not None and eos_token_id.ndim == 0:
+ eos_token_id = eos_token_id.unsqueeze(0)
+
+ # Set pad token if unset (and there are conditions to do so)
+ if pad_token_id is None and eos_token_id is not None:
+ if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
+ logger.warning(
+ "The attention mask and the pad token id were not set. As a consequence, you may observe "
+ "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+ )
+ pad_token_id = eos_token_id[0]
+ logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
+
+ # we can't infer attn mask if pad token is set to be eos token in model's generation config
+ if eos_token_id is not None and mnp.isin(element=eos_token_id, test_elements=pad_token_id).any():
+ if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
+ logger.warning_once(
+ "The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
+ "As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` "
+ "to obtain reliable results."
+ )
+
+ # Sanity checks/warnings
+ if self.config.is_encoder_decoder and decoder_start_token_id is None:
+ raise ValueError(
+ "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
+ )
+ if eos_token_id is not None and (ops.is_floating_point(eos_token_id) or (eos_token_id < 0).any()):
+ logger.warning(
+ f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will not "
+ "stop until the maximum length is reached. Depending on other flags, it may even crash."
+ )
+
+ # Update generation config with the updated special tokens tensors
+ generation_config.bos_token_id = bos_token_id
+ generation_config.eos_token_id = eos_token_id
+ generation_config.pad_token_id = pad_token_id
+ generation_config.decoder_start_token_id = decoder_start_token_id
+
+ def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
+ past_key_values = None
+ cache_name = "past_key_values"
+ if "past_key_values" in outputs:
+ past_key_values = outputs.past_key_values
+ elif "mems" in outputs:
+ past_key_values = outputs.mems
+ elif "past_buckets_states" in outputs:
+ past_key_values = outputs.past_buckets_states
+ elif "cache_params" in outputs:
+ past_key_values = outputs.cache_params
+ cache_name = "cache_params"
+
+ # Bloom fix: standardizes the cache format when requested
+ if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
+ batch_size = outputs.logits.shape[0]
+ past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size)
+ return cache_name, past_key_values
+
+ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+ """Validates model kwargs for generation. Generate argument typos will also be caught here."""
+ # If a `Cache` instance is passed, checks whether the model is compatible with it
+ if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class:
+ raise ValueError(
+ f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
+ "check the model documentation for supported cache formats."
+ )
+
+ # Excludes arguments that are handled before calling any model function
+ if self.config.is_encoder_decoder:
+ for key in ["decoder_input_ids"]:
+ model_kwargs.pop(key, None)
+
+ unused_model_args = []
+ model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
+ # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+ # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
+ if "kwargs" in model_args or "model_kwargs" in model_args:
+ model_args |= set(inspect.signature(self.construct).parameters)
+
+ # Encoder-Decoder models may also need Encoder arguments from `model_kwargs`
+ if self.config.is_encoder_decoder:
+ base_model = getattr(self, self.base_model_prefix, None)
+
+ # allow encoder kwargs
+ encoder = getattr(self, "encoder", None)
+ # `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`.
+ # Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder`
+ # TODO: A better way to handle this.
+ if encoder is None and base_model is not None:
+ encoder = getattr(base_model, "encoder", None)
+
+ if encoder is not None:
+ encoder_model_args = set(inspect.signature(encoder.forward).parameters)
+ model_args |= encoder_model_args
+
+ # allow decoder kwargs
+ decoder = getattr(self, "decoder", None)
+ if decoder is None and base_model is not None:
+ decoder = getattr(base_model, "decoder", None)
+
+ if decoder is not None:
+ decoder_model_args = set(inspect.signature(decoder.forward).parameters)
+ model_args |= {f"decoder_{x}" for x in decoder_model_args}
+
+ # allow assistant_encoder_outputs to be passed if we're doing assisted generating
+ if "assistant_encoder_outputs" in model_kwargs:
+ model_args |= {"assistant_encoder_outputs"}
+
+ for key, value in model_kwargs.items():
+ if value is not None and key not in model_args:
+ unused_model_args.append(key)
+
+ if unused_model_args:
+ raise ValueError(
+ f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
+ " generate arguments will also show up in this list)"
+ )
+
+ def _validate_assistant(self, assistant_model):
+ if assistant_model is None:
+ return
+
+ if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder:
+ attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"]
+ attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check]
+ are_equal = all(
+ getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check
+ )
+ if not are_equal:
+ raise ValueError(
+ "The main model and the assistant don't have compatible encoder-dependent input shapes. "
+ "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
+ )
+
+ if not self.config.vocab_size == assistant_model.config.vocab_size:
+ raise ValueError("Make sure the main and assistant model use the same tokenizer")
+
+ def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
+ """Performs validation related to the resulting generated length"""
+
+ # 1. Max length warnings related to poor parameterization
+ if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
+ # 20 is the default max_length of the generation config
+ warnings.warn(
+ f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
+ "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
+ "generation.",
+ UserWarning,
+ )
+ if input_ids_length >= generation_config.max_length:
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+ raise ValueError(
+ f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_length` or, better yet, setting `max_new_tokens`."
+ )
+
+ # 2. Min length warnings due to unfeasible parameter combinations
+ min_length_error_suffix = (
+ " Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
+ "increase the maximum length."
+ )
+ if has_default_max_length:
+ min_length_error_suffix += (
+ f" Note that `max_length` is set to {generation_config.max_length}, its default value."
+ )
+ if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
+ warnings.warn(
+ f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"
+ f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
+ UserWarning,
+ )
+ if generation_config.min_new_tokens is not None:
+ min_length = generation_config.min_new_tokens + input_ids_length
+ if min_length > generation_config.max_length:
+ warnings.warn(
+ f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when "
+ f"added to the prompt length ({input_ids_length}), is larger than"
+ f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
+ UserWarning,
+ )
+
+ def _prepare_generated_length(
+ self,
+ generation_config,
+ has_default_max_length,
+ has_default_min_length,
+ model_input_name,
+ input_ids_length,
+ inputs_tensor,
+ ):
+ """Prepared max and min length in generaion configs to avoid clashes between similar attributes"""
+
+ if generation_config.max_new_tokens is not None:
+ if not has_default_max_length and generation_config.max_length is not None:
+ logger.warning(
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+ "Please refer to the documentation for more information. "
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+ )
+
+ logger.warning(
+ "Unlike the original transformers, `input_ids` will pad with inputs prompt token "
+ "and `max_new_tokens` contains the length of the input, "
+ "please set bigger `max_new_tokens` while considering the input length !"
+ )
+
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_length
+
+ if generation_config.max_length < inputs_tensor.shape[1]:
+ raise ValueError(
+ f"max_new_tokens `{generation_config.max_new_tokens}` is smaller than "
+ f"input length `{inputs_tensor.shape[1]}`."
+ )
+
+ # if both `inputs_embeds` and `input_ids` are passed, we do not correct the length
+ # otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length``
+ elif (
+ model_input_name == "inputs_embeds"
+ and input_ids_length != inputs_tensor.shape[1]
+ and not self.config.is_encoder_decoder
+ ):
+ generation_config.max_length -= inputs_tensor.shape[1]
+
+ # same for min length
+ if generation_config.min_new_tokens is not None:
+ if not has_default_min_length:
+ logger.warning(
+ f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(="
+ f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. "
+ "Please refer to the documentation for more information. "
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+ )
+ generation_config.min_length = generation_config.min_new_tokens + input_ids_length
+
+ elif (
+ model_input_name == "inputs_embeds"
+ and input_ids_length != inputs_tensor.shape[1]
+ and not self.config.is_encoder_decoder
+ ):
+ generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
+
+ return generation_config
+
+ def heal_tokens(self, input_ids: ms.Tensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None) -> ms.Tensor:
+ r"""
+ Generates sequences of token ids for models with a language modeling head.
+ Parameters:
+ input_ids (`ms.Tensor`): The sequence used as a prompt for the generation.
+ tokenizer (`PreTrainedTokenizerBase`, *optional*): The tokenizer used to decode the input ids.
+ Return:
+ `ms.Tensor` where each sequence has its tail token replaced with its appropriate extension.
+ """
+ if tokenizer is None:
+ raise ValueError(
+ " When generating with token healing, you must pass the model's tokenizer to the `tokenizer` "
+ "argument of `generate`."
+ )
+
+ bos_token_id, pad_token_id = tokenizer.bos_token_id, tokenizer.pad_token_id
+ vocab_trie = ExtensionsTrie(tokenizer.get_vocab())
+ generation_config = GenerationConfig(max_new_tokens=1, pad_token_id=pad_token_id)
+
+ # assumption: leading/trailing whitespace is not meaningful, so the prompts are
+ # stripped before re-tokenizing to desensitize generation to whitespace artefacts
+ prompts = [p.strip() for p in tokenizer.batch_decode(input_ids, skip_special_tokens=True)]
+ input_ids = ms.Tensor(
+ tokenizer(
+ prompts,
+ return_tensors="np",
+ padding=True,
+ ).input_ids
+ )
+
+ # replace bos with pad to not condition healing on it
+ input_ids = ops.where(input_ids == bos_token_id, pad_token_id, input_ids)
+
+ tail_ids = input_ids[:, -1].tolist()
+ space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
+ # tail tokens are used for a prefix search, thus, whitespaces are replaced with
+ # their tokenization (e.g. 'Ä ') to enable search for tokens prefixed with a whitespace
+ tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
+
+ for batch_idx, (tail_id, tail_tok) in enumerate(zip(tail_ids, tail_toks)):
+ batch_ids = input_ids[batch_idx]
+ if ops.all(batch_ids == pad_token_id).item():
+ continue # skip empty sequences (all pad ids)
+
+ # apply bias for alternatives (extensions) to the tail token
+ seq_bias = {(alt_tok,): 10.0 for alt_tok in vocab_trie.values(prefix=tail_tok)}
+ if len(seq_bias) == 1:
+ continue # skip if there are no token alternatives to heal with
+
+ # slightly favor original token to limit aggressive healing e.g. 'http' -> 'https'
+ seq_bias[(tail_id,)] += 1.0
+ generation_config.update(sequence_bias=seq_bias)
+
+ trimmed_ids = batch_ids[:-1]
+ # if the prompt is a single (non-pad) token, regenerate from bos
+ if len(batch_ids[batch_ids != pad_token_id]) == 1:
+ trimmed_ids[-1] = bos_token_id
+
+ input_ids[batch_idx] = self.generate(trimmed_ids.unsqueeze(0), generation_config=generation_config)
+
+ return input_ids
+
+ def _update_model_kwargs_for_generation(
+ self,
+ outputs: ModelOutput,
+ model_kwargs: Dict[str, Any],
+ is_encoder_decoder: bool = False,
+ standardize_cache_format: bool = False,
+ num_new_tokens: int = 1,
+ ) -> Dict[str, Any]:
+ # update past_key_values keeping its naming used in model code
+ cache_name, cache = self._extract_past_from_model_output(
+ outputs, standardize_cache_format=standardize_cache_format
+ )
+ model_kwargs[cache_name] = cache
+ if getattr(outputs, "state", None) is not None:
+ model_kwargs["state"] = outputs.state
+
+ # update token_type_ids with last value
+ if "token_type_ids" in model_kwargs:
+ token_type_ids = model_kwargs["token_type_ids"]
+ model_kwargs["token_type_ids"] = ops.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], axis=-1)
+
+ if not is_encoder_decoder:
+ # update attention mask
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+
+ cur_lens = attention_mask.sum(-1)
+ for batch_idx in range(attention_mask.shape[0]):
+ cur_len = int(cur_lens[batch_idx])
+ if cur_len < attention_mask.shape[-1]:
+ attention_mask[batch_idx, cur_len] = 1
+ else:
+ attention_mask[batch_idx, :-1] = attention_mask[batch_idx, 1:]
+ attention_mask[batch_idx, -1:] = 1
+ model_kwargs["attention_mask"] = attention_mask
+
+ # model_kwargs["attention_mask"] = ops.cat(
+ # [attention_mask, ops.ones((attention_mask.shape[0], 1), dtype=attention_mask.dtype)], axis=-1
+ # )
+ else:
+ # update decoder attention mask
+ if "decoder_attention_mask" in model_kwargs:
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+ model_kwargs["decoder_attention_mask"] = ops.cat(
+ [
+ decoder_attention_mask,
+ ops.ones((decoder_attention_mask.shape[0], 1), dtype=decoder_attention_mask.dtype),
+ ],
+ axis=-1,
+ )
+
+ if (
+ model_kwargs.get("use_cache", True)
+ and "cache_position" in model_kwargs
+ and model_kwargs["cache_position"] is not None
+ ):
+ if (
+ model_kwargs.get("attention_mask", None) is not None
+ and model_kwargs["attention_mask"].shape[-1] == model_kwargs["cache_position"].shape[0]
+ ):
+ # `cache_position` obtain effective length after 1st step
+ cur_idx = int(model_kwargs["attention_mask"].sum(-1).max()) - 1
+ past_idx = cur_idx - 1
+ model_kwargs["cache_position"] = (
+ model_kwargs["cache_position"][past_idx : past_idx + 1] + num_new_tokens
+ )
+ else:
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+
+ return model_kwargs
+
+ def _get_logits_processor(
+ self,
+ generation_config: GenerationConfig,
+ input_ids_seq_length: int,
+ encoder_input_ids: ms.Tensor,
+ prefix_allowed_tokens_fn: Callable[[int, ms.Tensor], List[int]],
+ logits_processor: Optional[LogitsProcessorList],
+ device: str = None,
+ model_kwargs: Optional[Dict[str, Any]] = None,
+ negative_prompt_ids: Optional[ms.Tensor] = None,
+ negative_prompt_attention_mask: Optional[ms.Tensor] = None,
+ ) -> LogitsProcessorList:
+ """
+ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
+ instances used to modify the scores of the language model head.
+ """
+ # instantiate processors list
+ processors = LogitsProcessorList()
+
+ if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:
+ raise NotImplementedError
+ if generation_config.sequence_bias is not None:
+ raise NotImplementedError
+
+ if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:
+ raise NotImplementedError
+ if (
+ generation_config.encoder_repetition_penalty is not None
+ and generation_config.encoder_repetition_penalty != 1.0
+ ):
+ raise NotImplementedError
+ if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
+ raise NotImplementedError
+ if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
+ raise NotImplementedError
+ if (
+ generation_config.encoder_no_repeat_ngram_size is not None
+ and generation_config.encoder_no_repeat_ngram_size > 0
+ ):
+ raise NotImplementedError
+ if generation_config.bad_words_ids is not None:
+ raise NotImplementedError
+ if (
+ generation_config.min_length is not None
+ and generation_config.eos_token_id is not None
+ and generation_config.min_length > 0
+ ):
+ processors.append(
+ MinLengthLogitsProcessor(
+ generation_config.min_length,
+ generation_config.eos_token_id,
+ )
+ )
+ if (
+ generation_config.min_new_tokens is not None
+ and generation_config.eos_token_id is not None
+ and generation_config.min_new_tokens > 0
+ ):
+ processors.append(
+ MinNewTokensLengthLogitsProcessor(
+ input_ids_seq_length,
+ generation_config.min_new_tokens,
+ generation_config.eos_token_id,
+ )
+ )
+ if prefix_allowed_tokens_fn is not None:
+ processors.append(
+ PrefixConstrainedLogitsProcessor(
+ prefix_allowed_tokens_fn,
+ generation_config.num_beams // generation_config.num_beam_groups,
+ )
+ )
+ if generation_config.forced_bos_token_id is not None:
+ raise NotImplementedError
+ if generation_config.forced_eos_token_id is not None:
+ raise NotImplementedError
+ if generation_config.remove_invalid_values is True:
+ raise NotImplementedError
+ if generation_config.exponential_decay_length_penalty is not None:
+ raise NotImplementedError
+ if generation_config.suppress_tokens is not None:
+ raise NotImplementedError
+ if generation_config.begin_suppress_tokens is not None:
+ raise NotImplementedError
+ if generation_config.forced_decoder_ids is not None:
+ raise NotImplementedError
+ if generation_config.watermarking_config is not None:
+ raise NotImplementedError
+
+ processors = self._merge_criteria_processor_list(processors, logits_processor)
+ # `LogitNormalization` should always be the last logit processor, when present
+ if generation_config.renormalize_logits is True:
+ processors.append(LogitNormalization())
+ return processors
+
+ def _get_stopping_criteria(
+ self,
+ generation_config: GenerationConfig,
+ stopping_criteria: Optional[StoppingCriteriaList],
+ tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+ **kwargs,
+ ) -> StoppingCriteriaList:
+ criteria = StoppingCriteriaList()
+ if generation_config.max_length is not None:
+ max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
+ criteria.append(
+ MaxLengthCriteria(
+ max_length=generation_config.max_length,
+ max_position_embeddings=max_position_embeddings,
+ )
+ )
+ if generation_config.max_time is not None:
+ criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
+ if generation_config.stop_strings is not None:
+ if tokenizer is None:
+ raise ValueError(
+ "There are one or more stop strings, either in the arguments to `generate` or in the "
+ "model's generation config, but we could not locate a tokenizer. When generating with "
+ "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`."
+ )
+ # criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
+ raise NotImplementedError
+ if generation_config.eos_token_id is not None:
+ criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
+ criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
+ return criteria
+
+ def _get_logits_warper(
+ self,
+ generation_config: GenerationConfig,
+ ) -> LogitsProcessorList:
+ """
+ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
+ used for multinomial sampling.
+ """
+
+ # instantiate warpers list
+ warpers = LogitsProcessorList()
+
+ # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+ # better score (i.e. keep len(list(generation_config.eos_token_id)) + 1)
+ if generation_config.num_beams > 1:
+ if isinstance(generation_config.eos_token_id, list):
+ min_tokens_to_keep = len(generation_config.eos_token_id) + 1
+ elif isinstance(generation_config.eos_token_id, ms.Tensor):
+ min_tokens_to_keep = generation_config.eos_token_id.shape[0] + 1
+ else:
+ min_tokens_to_keep = 2
+ else:
+ min_tokens_to_keep = 1
+
+ # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+ # all samplers can be found in `generation_utils_samplers.py`
+ if generation_config.temperature is not None and generation_config.temperature != 1.0:
+ warpers.append(TemperatureLogitsWarper(generation_config.temperature))
+ if generation_config.top_k is not None and generation_config.top_k != 0:
+ warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
+ if generation_config.top_p is not None and generation_config.top_p < 1.0:
+ warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
+ if generation_config.min_p is not None:
+ # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
+ # warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
+ raise NotImplementedError
+ if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
+ # warpers.append(
+ # TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
+ # )
+ raise NotImplementedError
+ if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
+ # warpers.append(
+ # EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
+ # )
+ raise NotImplementedError
+ if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
+ # warpers.append(
+ # EtaLogitsWarper(
+ # epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep
+ # )
+ # )
+ raise NotImplementedError
+ # `LogitNormalization` should always be the last logit processor, when present
+ if generation_config.renormalize_logits is True:
+ warpers.append(LogitNormalization())
+ return warpers
+
+ def _merge_criteria_processor_list(
+ self,
+ default_list: Union[LogitsProcessorList, StoppingCriteriaList],
+ custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
+ ) -> Union[LogitsProcessorList, StoppingCriteriaList]:
+ if len(custom_list) == 0:
+ return default_list
+ for default in default_list:
+ for custom in custom_list:
+ if type(custom) is type(default):
+ object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
+ raise ValueError(
+ f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
+ f" `.generate()`, but it has already been created with the values {default}. {default} has been"
+ " created by passing the corresponding arguments to generate or by the model's config default"
+ f" values. If you just want to change the default values of {object_type} consider passing"
+ f" them as arguments to `.generate()` instead of using a custom {object_type}."
+ )
+ default_list.extend(custom_list)
+ return default_list
+
+ @staticmethod
+ def _expand_inputs_for_generation(
+ expand_size: int = 1,
+ is_encoder_decoder: bool = False,
+ input_ids: Optional[ms.Tensor] = None,
+ **model_kwargs,
+ ) -> Tuple[ms.Tensor, Dict[str, Any]]:
+ """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
+
+ def _expand_dict_for_generation(dict_to_expand):
+ for key in dict_to_expand:
+ if (
+ key != "cache_position"
+ and dict_to_expand[key] is not None
+ and isinstance(dict_to_expand[key], ms.Tensor)
+ ):
+ if dict_to_expand[key].dtype == ms.bool_:
+ dict_to_expand[key] = (
+ dict_to_expand[key].to(ms.int32).repeat_interleave(expand_size, dim=0).to(ms.bool_)
+ )
+ else:
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+ return dict_to_expand
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+ if is_encoder_decoder:
+ if model_kwargs.get("encoder_outputs") is None:
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+ return input_ids, model_kwargs
+
+ def _padding_inputs(
+ self,
+ generation_config,
+ input_ids: ms.Tensor,
+ inputs_embeds: ms.Tensor = None,
+ labels: ms.Tensor = None,
+ position_ids: ms.Tensor = None,
+ attention_mask: ms.Tensor = None,
+ ):
+ # init empty array
+ bs, max_length = len(input_ids), generation_config.max_length
+ emb_length = inputs_embeds.shape[-1] if inputs_embeds is not None else 0
+ ignore_label_index = 0
+
+ padded_input_ids = ops.zeros((bs, max_length), ms.int32)
+ padded_labels = ops.full((bs, max_length), ignore_label_index, dtype=ms.int32)
+ padded_position_ids = ops.zeros((bs, max_length), ms.int32)
+ padded_attention_mask = ops.zeros((bs, max_length), ms.bool_)
+
+ padded_inputs_embeds = (
+ ops.zeros((bs, max_length, emb_length), inputs_embeds.dtype) if inputs_embeds is not None else None
+ )
+
+ _labels = labels
+ _position_ids = position_ids
+
+ if attention_mask is None:
+ if inputs_embeds is not None:
+ attention_mask = ops.ones(inputs_embeds.shape[:2], dtype=ms.bool_)
+ else:
+ attention_mask = ops.ones(input_ids.shape[:], dtype=ms.bool_)
+ else:
+ attention_mask = attention_mask.astype(ms.bool_)
+ cur_len = int(attention_mask.sum(-1).max())
+
+ if position_ids is None:
+ position_ids = ops.arange(0, cur_len, dtype=ms.int32)
+ if labels is None:
+ labels = ops.full(
+ (
+ bs,
+ cur_len,
+ ),
+ ignore_label_index,
+ dtype=ms.int32,
+ )
+
+ for batch_idx, cur_attention_mask in enumerate(attention_mask):
+ cur_len = cur_attention_mask.sum()
+
+ padded_attention_mask[batch_idx, :cur_len] = attention_mask[batch_idx][:]
+ padded_input_ids[batch_idx, : min(cur_len, input_ids[batch_idx].shape[0])] = input_ids[batch_idx][:]
+ padded_labels[batch_idx, :cur_len] = labels[batch_idx][:]
+ padded_position_ids[batch_idx, :cur_len] = ops.arange(0, cur_len, dtype=position_ids.dtype)
+
+ if inputs_embeds is not None:
+ padded_inputs_embeds[batch_idx, :cur_len] = inputs_embeds[batch_idx][:]
+
+ new_input_ids = padded_input_ids
+ new_attention_mask = padded_attention_mask
+ new_labels = None if _labels is None else padded_labels
+ new_position_ids = None if _position_ids is None else padded_position_ids
+
+ new_inputs_embeds = None if inputs_embeds is None else padded_inputs_embeds
+
+ return new_input_ids, new_inputs_embeds, new_labels, new_position_ids, new_attention_mask
+
+ def generate(
+ self,
+ inputs: Optional[ms.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, ms.Tensor], List[int]]] = None,
+ synced_gpus: Optional[bool] = None,
+ assistant_model: Optional["PreTrainedModel"] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ negative_prompt_ids: Optional[ms.Tensor] = None,
+ negative_prompt_attention_mask: Optional[ms.Tensor] = None,
+ **kwargs,
+ ) -> Union[Tuple, ms.Tensor]:
+ r"""
+
+ Generates sequences of token ids for models with a language modeling head.
+
+
+
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+ For an overview of generation strategies and code examples, check out the [following
+ guide](../generation_strategies).
+
+
+
+ Parameters:
+ inputs (`ms.Tensor` of varying shape depending on the modality, *optional*):
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+ should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+ generation_config ([`~generation.GenerationConfig`], *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which has the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logit processor is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complements the default stopping criteria built from arguments and a
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
+ generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
+ sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
+ intended for advanced users.
+ prefix_allowed_tokens_fn (`Callable[[int, ms.Tensor], List[int]]`, *optional*):
+ If provided, this function constraints the beam search to allowed tokens only at each step. If not
+ provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
+ `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
+ on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
+ for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
+ Retrieval](https://arxiv.org/abs/2010.00904).
+ synced_gpus (`bool`, *optional*):
+ Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
+ `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
+ generating before other GPUs. Otherwise it'll be set to `False`.
+ assistant_model (`PreTrainedModel`, *optional*):
+ An assistant model that can be used to accelerate generation. The assistant model must have the exact
+ same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
+ is much faster than running generation with the model you're calling generate from. As such, the
+ assistant model should be much smaller.
+ streamer (`BaseStreamer`, *optional*):
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ negative_prompt_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ The negative prompt needed for some processors such as CFG. The batch size must match the input batch
+ size. This is an experimental feature, subject to breaking API changes in future versions.
+ negative_prompt_attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Attention_mask for `negative_prompt_ids`.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+ Return:
+ [`~utils.ModelOutput`] or `ms.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+ or when `config.return_dict_in_generate=True`) or a `ms.Tensor`.
+
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GenerateDecoderOnlyOutput`],
+ - [`~generation.GenerateBeamDecoderOnlyOutput`]
+
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GenerateEncoderDecoderOutput`],
+ - [`~generation.GenerateBeamEncoderDecoderOutput`]
+ """
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+ tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
+ generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
+ self._validate_model_kwargs(model_kwargs.copy())
+ self._validate_assistant(assistant_model)
+
+ # 2. Set generation parameters if not already defined
+ synced_gpus = False # Set to `True` when zero3
+
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ accepts_attention_mask = "attention_mask" in set(inspect.signature(self.construct).parameters.keys())
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+ kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
+
+ # 3. Define model inputs
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ batch_size = inputs_tensor.shape[0]
+
+ self._prepare_special_tokens(generation_config, kwargs_has_attention_mask)
+
+ # decoder-only models must use left-padding for batched generation.
+ if not self.config.is_encoder_decoder:
+ # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
+ # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
+ if (
+ generation_config.pad_token_id is not None
+ and batch_size > 1
+ and len(inputs_tensor.shape) == 2
+ and ops.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
+ ):
+ logger.warning(
+ "A decoder-only architecture is being used, but right-padding was detected! For correct "
+ "generation results, please set `padding_side='left'` when initializing the tokenizer."
+ )
+
+ # 4. Define other model kwargs
+ # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
+ # generating the first new token or not, and we only want to use the embeddings for the first new token)
+ if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
+ model_kwargs["use_cache"] = True
+ if not generation_config.use_cache:
+ logger.warning("force `use_cache=True` when decoder-only and model_input_name is `inputs_embeds`.")
+ else:
+ model_kwargs["use_cache"] = generation_config.use_cache
+
+ if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
+ model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+ inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+ )
+
+ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
+ # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
+ inputs_tensor, model_kwargs, model_input_name, generation_config
+ )
+
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
+ if self.config.is_encoder_decoder:
+ input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+ batch_size=batch_size,
+ model_input_name=model_input_name,
+ model_kwargs=model_kwargs,
+ decoder_start_token_id=generation_config.decoder_start_token_id,
+ )
+ else:
+ input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+
+ if generation_config.token_healing:
+ input_ids = self.heal_tokens(input_ids, tokenizer)
+
+ if streamer is not None:
+ streamer.put(input_ids.asnumpy())
+
+ # 6. Prepare `max_length` depending on other stopping criteria.
+ input_ids_length = input_ids.shape[-1]
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
+ generation_config = self._prepare_generated_length(
+ generation_config=generation_config,
+ has_default_max_length=has_default_max_length,
+ has_default_min_length=has_default_min_length,
+ model_input_name=model_input_name,
+ inputs_tensor=inputs_tensor,
+ input_ids_length=input_ids_length,
+ )
+
+ use_dynamic_cache_by_default = False
+ if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
+ raise ValueError(
+ "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
+ "Cache object) is unsupported. Please use only one of the two."
+ )
+ elif generation_config.cache_implementation is not None:
+ if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
+ if generation_config.cache_implementation == "static" and not self._supports_static_cache:
+ raise ValueError(
+ "This model does not support `cache_implementation='static'`. Please check the following "
+ "issue: https://github.com/huggingface/transformers/issues/28981"
+ )
+ model_kwargs["past_key_values"] = self._get_cache(
+ generation_config.cache_implementation,
+ getattr(generation_config, "num_beams", 1) * batch_size,
+ generation_config.max_length,
+ )
+ elif generation_config.cache_implementation == "quantized":
+ if not self._supports_quantized_cache:
+ raise ValueError(
+ "This model does not support the quantized cache. If you want your model to support quantized "
+ "cache, please open an issue."
+ )
+
+ raise NotImplementedError("Not support Quantized Cache")
+ # Use DynamicCache instance by default. This will avoid back and forth from legacy format that
+ # keeps copying the cache thus using much more memory
+ elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
+ raise NotImplementedError
+
+ # Use static tuple cache by default.
+ elif (
+ generation_config.cache_implementation is None
+ and not self._supports_default_dynamic_cache()
+ and model_kwargs.get("use_cache", False)
+ ):
+ past = model_kwargs.get("past_key_values", None)
+ max_batch_size, max_cache_len, cache_dtype = (
+ getattr(generation_config, "num_beams", 1) * batch_size,
+ generation_config.max_length,
+ self.dtype,
+ )
+ need_new_cache = (
+ past is None
+ or (not isinstance(past, tuple))
+ or (not isinstance(past[0][0], ms.Tensor))
+ or past[0][0].shape[0] != max_batch_size
+ or past[0][0].shape[2] < max_cache_len
+ )
+
+ if need_new_cache:
+ model_kwargs["past_key_values"] = init_static_cache(
+ config=self.config,
+ max_batch_size=max_batch_size,
+ max_cache_len=max_cache_len,
+ dtype=cache_dtype,
+ )
+ else:
+ model_kwargs["past_key_values"] = reset(past)
+
+ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+
+ # 7. determine generation mode
+ generation_mode = generation_config.get_generation_mode(assistant_model)
+
+ if streamer is not None and (generation_config.num_beams > 1):
+ raise ValueError(
+ "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
+ )
+
+ # 8. prepare distribution pre_processing samplers
+ prepared_logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_length,
+ encoder_input_ids=inputs_tensor,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ logits_processor=logits_processor,
+ model_kwargs=model_kwargs,
+ negative_prompt_ids=negative_prompt_ids,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+
+ # 9. prepare stopping criteria
+ prepared_stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
+ )
+
+ # 10. go into different generation modes
+ if generation_mode == GenerationMode.ASSISTED_GENERATION:
+ raise NotImplementedError
+ elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
+ raise NotImplementedError
+ elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
+ # 11. prepare logits warper
+ prepared_logits_warper = self._get_logits_warper(generation_config) if generation_config.do_sample else None
+
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
+ result = self._sample(
+ input_ids,
+ logits_processor=prepared_logits_processor,
+ logits_warper=prepared_logits_warper,
+ stopping_criteria=prepared_stopping_criteria,
+ generation_config=generation_config,
+ synced_gpus=synced_gpus,
+ streamer=streamer,
+ **model_kwargs,
+ )
+
+ # 14. unlike the original transformers, need delete the length of the input
+ result = result[:, inputs_tensor.shape[1] :]
+
+ elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
+ raise NotImplementedError
+ elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
+ raise NotImplementedError
+ elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+
+ # Convert to legacy cache if needed
+ if use_dynamic_cache_by_default and generation_config.return_legacy_cache:
+ # if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
+ # if isinstance(result.past_key_values, DynamicCache):
+ # result.past_key_values = result.past_key_values.to_legacy_cache()
+ raise NotImplementedError
+
+ return result
+
+ def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool) -> bool:
+ """
+ Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
+ fed through `this_peer_finished`. ZeRO stage 3-friendly.
+ """
+ if synced_gpus:
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+ # The following logic allows an early break if all peers finished generating their sequence
+ this_peer_finished_flag = ms.Tensor(0.0 if this_peer_finished else 1.0)
+ # send 0.0 if we finished, 1.0 otherwise
+ this_peer_finished_flag = ops.AllReduce()(this_peer_finished_flag)
+ # did all peers finish? the reduced sum will be 0.0 then
+ if this_peer_finished_flag.item() == 0.0:
+ return False
+ elif this_peer_finished:
+ return False
+ return True
+
+ def _sample(
+ self,
+ input_ids: ms.Tensor,
+ logits_processor: LogitsProcessorList,
+ stopping_criteria: StoppingCriteriaList,
+ generation_config: GenerationConfig,
+ synced_gpus: bool,
+ streamer: Optional["BaseStreamer"],
+ logits_warper: Optional[LogitsProcessorList] = None,
+ **model_kwargs,
+ ) -> Union[GenerateNonBeamOutput, ms.Tensor]:
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
+ Parameters:
+ input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`):
+ The sequence used as a prompt for the generation.
+ logits_processor (`LogitsProcessorList`):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+ used to modify the prediction scores of the language modeling head applied at each generation step.
+ stopping_criteria (`StoppingCriteriaList`):
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+ used to tell if the generation loop should stop.
+ generation_config ([`~generation.GenerationConfig`]):
+ The generation configuration to be used as parametrization of the decoding method.
+ synced_gpus (`bool`):
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+ streamer (`BaseStreamer`, *optional*):
+ Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+ through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ logits_warper (`LogitsProcessorList`, *optional*):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+ to warp the prediction score distribution of the language modeling head applied before multinomial
+ sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
+ `generation_config`)
+ model_kwargs:
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+ Return:
+ [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `ms.Tensor`:
+ A `ms.Tensor` containing the generated tokens (default behaviour) or a
+ [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+ `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
+ `model.config.is_encoder_decoder=True`.
+ """
+
+ # init values
+ pad_token_id = generation_config.pad_token_id
+ output_attentions = generation_config.output_attentions
+ output_hidden_states = generation_config.output_hidden_states
+ output_scores = generation_config.output_scores
+ output_logits = generation_config.output_logits
+ return_dict_in_generate = generation_config.return_dict_in_generate
+ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+ do_sample = generation_config.do_sample
+ if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
+ raise ValueError(
+ "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
+ f"{logits_warper})."
+ )
+
+ # init attention / hidden states / scores tuples
+ scores = () if (return_dict_in_generate and output_scores) else None
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+ if return_dict_in_generate and self.config.is_encoder_decoder:
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+ encoder_hidden_states = (
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+ )
+
+ # Padding inputs to avoid dynamic shape on MindSpore 2.3.1
+ (
+ padded_input_ids,
+ padded_inputs_embeds,
+ padded_labels,
+ padded_position_ids,
+ padded_attention_mask,
+ ) = self._padding_inputs(
+ generation_config,
+ input_ids,
+ model_kwargs.get("inputs_embeds", None),
+ model_kwargs.get("labels", None),
+ model_kwargs.get("position_ids", None),
+ model_kwargs.get("attention_mask", None),
+ )
+ input_ids = padded_input_ids
+ model_kwargs["attention_mask"] = padded_attention_mask
+ if model_kwargs.get("inputs_embeds", None) is not None:
+ model_kwargs["inputs_embeds"] = padded_inputs_embeds
+ if model_kwargs.get("labels", None) is not None:
+ model_kwargs["labels"] = padded_labels
+ if model_kwargs.get("position_ids", None) is not None:
+ model_kwargs["position_ids"] = padded_position_ids
+
+ # keep track of which sequences are already finished
+ batch_size = input_ids.shape[0]
+ this_peer_finished = False
+ unfinished_sequences = ops.ones(batch_size, dtype=ms.int32)
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+
+ step = 0
+ s_time = time.time()
+
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus):
+ # prepare model inputs
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+ # forward pass to get next token
+ outputs = self(
+ **model_inputs,
+ return_dict=False if ms.get_context("mode") == ms.GRAPH_MODE else True,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ if synced_gpus and this_peer_finished:
+ continue # don't waste resources running the code we don't need
+
+ print(
+ f"======> sampling, step: {step}, sample outputs shape: "
+ f"{[o.shape for o in outputs if isinstance(o, ms.Tensor)]}, time cost: {time.time() - s_time:.3f}s"
+ )
+ s_time = time.time()
+ step += 1
+
+ if not isinstance(outputs, CausalLMOutputWithPast):
+ outputs = CausalLMOutputWithPast(
+ loss=None,
+ logits=outputs[0],
+ past_key_values=outputs[1] if model_inputs.get("use_cache", False) else None,
+ )
+
+ if model_kwargs.get("attention_mask", None) is not None:
+ attention_mask = model_kwargs["attention_mask"]
+ cur_idx = int(attention_mask.sum(-1).max()) - 1
+
+ if outputs.logits.shape[1] == attention_mask.shape[-1]:
+ next_token_logits = outputs.logits[:, cur_idx, :] # (bs, seq, dim)
+ else:
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # `input_ids` obtain effective length after 1st step
+ if input_ids.shape[1] == attention_mask.shape[1]:
+ input_ids = input_ids[:, : cur_idx + 1]
+
+ else:
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+ if do_sample:
+ next_token_scores = logits_warper(input_ids, next_token_scores)
+
+ # Store scores, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_scores:
+ scores += (next_token_scores,)
+ if output_logits:
+ raw_logits += (next_token_logits,)
+ if output_attentions:
+ decoder_attentions += (
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+ )
+ if self.config.is_encoder_decoder:
+ cross_attentions += (outputs.cross_attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (
+ (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,)
+ )
+
+ # token selection
+ if do_sample:
+ probs = ops.softmax(next_token_scores, axis=-1, dtype=ms.float32).to(next_token_scores.dtype)
+ next_tokens = ops.multinomial(probs, num_samples=1).squeeze(1)
+ else:
+ next_tokens = ops.argmax(next_token_scores, dim=-1)
+
+ # finished sentences should have their next token be a padding token
+ if has_eos_stopping_criteria:
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+ next_tokens = next_tokens.to(ms.int32)
+
+ # update generated ids, model inputs, and length for next step
+ input_ids = ops.cat([input_ids, next_tokens[:, None]], axis=-1)
+ if streamer is not None:
+ streamer.put(next_tokens.asnumpy())
+
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs,
+ model_kwargs,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ )
+
+ unfinished_sequences = unfinished_sequences & ~ms.Tensor(stopping_criteria(input_ids, scores), ms.bool_)
+ this_peer_finished = unfinished_sequences.max() == 0
+
+ if streamer is not None:
+ streamer.end()
+
+ if return_dict_in_generate:
+ if self.config.is_encoder_decoder:
+ return GenerateEncoderDecoderOutput(
+ sequences=input_ids,
+ scores=scores,
+ logits=raw_logits,
+ encoder_attentions=encoder_attentions,
+ encoder_hidden_states=encoder_hidden_states,
+ decoder_attentions=decoder_attentions,
+ cross_attentions=cross_attentions,
+ decoder_hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return GenerateDecoderOnlyOutput(
+ sequences=input_ids,
+ scores=scores,
+ logits=raw_logits,
+ attentions=decoder_attentions,
+ hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return input_ids
diff --git a/mindone/transformers/integrations/peft.py b/mindone/transformers/integrations/peft.py
index ed82a91d70..50d329e41c 100644
--- a/mindone/transformers/integrations/peft.py
+++ b/mindone/transformers/integrations/peft.py
@@ -34,7 +34,7 @@ class PeftAdapterMixin:
- AdaLora: https://arxiv.org/abs/2303.10512
Other PEFT models such as prompt tuning, prompt learning are out of scope as these adapters are not "injectable"
- into a torch module. For using these methods, please refer to the usage guide of PEFT library.
+ into a mindspore cell. For using these methods, please refer to the usage guide of PEFT library.
With this mixin, if the correct PEFT version is installed, it is possible to:
@@ -88,7 +88,7 @@ def load_adapter(
Whether to use authentication token to load the remote folder. Userful to load private repositories
that are on HuggingFace Hub. You might need to call `huggingface-cli login` and paste your tokens to
cache it.
- device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
+ device_map (`str` or `Dict[str, Union[int, str]]` or `int`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
diff --git a/mindone/transformers/mindspore_adapter/__init__.py b/mindone/transformers/mindspore_adapter/__init__.py
new file mode 100644
index 0000000000..f06708a822
--- /dev/null
+++ b/mindone/transformers/mindspore_adapter/__init__.py
@@ -0,0 +1,7 @@
+from .amp import *
+from .attention import *
+from .data import *
+from .recompute import *
+from .train_onestep_wrapper import *
+from .training_args import *
+from .utils import *
diff --git a/mindone/transformers/mindspore_adapter/adamw.py b/mindone/transformers/mindspore_adapter/adamw.py
new file mode 100644
index 0000000000..88f6122393
--- /dev/null
+++ b/mindone/transformers/mindspore_adapter/adamw.py
@@ -0,0 +1,202 @@
+import numpy as np
+
+import mindspore as ms
+from mindspore import Parameter, ParameterTuple, Tensor, nn, ops
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+
+update_params = ops.MultitypeFuncGraph("update_params")
+adamw_opt = ops.MultitypeFuncGraph("adamw_opt")
+fused_adam_weight_decay = ops.MultitypeFuncGraph("fused_adam_weight_decay")
+
+
+@adamw_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
+def _adamw_opt(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flag):
+ op_mul = P.Mul()
+ op_square = P.Square()
+ op_sqrt = P.Sqrt()
+ op_cast = P.Cast()
+ op_reshape = P.Reshape()
+ op_shape = P.Shape()
+ param_fp32 = op_cast(param, ms.float32)
+ m_fp32 = op_cast(m, ms.float32)
+ v_fp32 = op_cast(v, ms.float32)
+ gradient_fp32 = op_cast(gradient, ms.float32)
+
+ next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), ms.float32) - beta1, gradient_fp32)
+
+ next_v = op_mul(beta2, v_fp32) + op_mul(
+ op_cast(F.tuple_to_array((1.0,)), ms.float32) - beta2, op_square(gradient_fp32)
+ )
+
+ update = next_m / (eps + op_sqrt(next_v))
+ if decay_flag:
+ update = op_mul(weight_decay, param_fp32) + update
+
+ update_with_lr = op_mul(lr, update)
+ next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
+
+ # next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
+ next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
+ next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
+
+ return op_cast(next_param, F.dtype(param))
+
+
+@fused_adam_weight_decay.register(
+ "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool"
+)
+def _run_fused_adam_weight_decay_opt(
+ opt, beta1, beta2, eps, lr, weight_decay, param, moment1, moment2, gradient, decay_flags, optim_filter
+):
+ """Apply FusedAdamWeightDecay optimizer to the weight parameter using Tensor."""
+
+ beta1 = ops.cast(beta1, ms.float32)
+ beta2 = ops.cast(beta2, ms.float32)
+ eps = ops.cast(eps, ms.float32)
+ lr = ops.cast(lr, ms.float32)
+ weight_decay = ops.cast(weight_decay, ms.float32)
+
+ if optim_filter:
+ if decay_flags:
+ opt(param, moment1, moment2, lr, beta1, beta2, eps, weight_decay, P.Cast()(gradient, F.dtype(param)))
+ else:
+ opt(param, moment1, moment2, lr, beta1, beta2, eps, 0.0, P.Cast()(gradient, F.dtype(param)))
+
+ return True
+
+
+@update_params.register("Tensor", "Tensor")
+def update_params(param, update):
+ update = ops.cast(update, param.dtype)
+ success = ops.logical_not(ops.isnan(update))
+ success = ops.depend(success, ops.assign(param, update))
+ return success
+
+
+class AdamWeightDecay(nn.Optimizer):
+ def __init__(
+ self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0, enable_fuse=False
+ ):
+ super(AdamWeightDecay, self).__init__(learning_rate, params, weight_decay)
+
+ print(
+ f"WARNING: {self.__class__.__name__} \n"
+ f" beta1/beta2/eps : {beta1}/{beta2}/{eps} \n"
+ f" weight_decay : {weight_decay} \n"
+ f" enable_fuse : {enable_fuse} \n"
+ )
+
+ self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
+ self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
+ self.eps = Tensor(np.array([eps]).astype(np.float32))
+ self.moments1 = self._param_init_op(self._parameters, prefix="adam_m", init="zeros")
+ self.moments2 = self._param_init_op(self._parameters, prefix="adam_v", init="zeros")
+
+ self.enable_fuse = enable_fuse
+ if self.enable_fuse:
+ self.fused_opt = ops.AdamWeightDecay()
+
+ # print
+ param_dtype = None
+ if isinstance(params[0], Parameter):
+ param_dtype = params[0].dtype
+ elif isinstance(params[0], dict):
+ if isinstance(params[0]["params"], list) and len(params[0]["params"]) > 0:
+ param_dtype = params[0]["params"][0].dtype
+ if param_dtype == ms.float16:
+ print(f"[ERROR] {self.__class__.__name__}, param dtype fp16, may cause `sdma error` on MindSpore 2.3.0")
+ else:
+ print(
+ f"[WARNING] {self.__class__.__name__}, custom optimizer, may cause `memory leakage` on MindSpore 2.3.0"
+ )
+
+ def _param_init_op(self, params, prefix, init="zeros"):
+ news = []
+ for p in params:
+ new = p.clone(init)
+ new.name = prefix + "." + p.name
+ news.append(new)
+ return ParameterTuple(news)
+
+ @ms.jit
+ def construct(self, gradients):
+ gradients = self.flatten_gradients(gradients)
+ weight_decay = self.get_weight_decay()
+ lr = self.get_lr()
+ self.assignadd(self.global_step, self.global_step_increase_tensor)
+
+ if self.enable_fuse:
+ if self.is_group:
+ if self.is_group_lr:
+ success = self.hyper_map(
+ F.partial(fused_adam_weight_decay, self.fused_opt, self.beta1, self.beta2, self.eps),
+ lr,
+ weight_decay,
+ self._parameters,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ self.optim_filter,
+ )
+ else:
+ success = self.hyper_map(
+ F.partial(fused_adam_weight_decay, self.fused_opt, self.beta1, self.beta2, self.eps, lr),
+ weight_decay,
+ self._parameters,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ self.optim_filter,
+ )
+ else:
+ success = self.hyper_map(
+ F.partial(
+ fused_adam_weight_decay, self.fused_opt, self.beta1, self.beta2, self.eps, lr, weight_decay
+ ),
+ self._parameters,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ self.optim_filter,
+ )
+
+ else:
+ if self.is_group:
+ if self.is_group_lr:
+ optim_result = self.hyper_map_reverse(
+ F.partial(adamw_opt, self.beta1, self.beta2, self.eps),
+ lr,
+ weight_decay,
+ self._parameters,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ )
+ else:
+ optim_result = self.hyper_map_reverse(
+ F.partial(adamw_opt, self.beta1, self.beta2, self.eps, lr),
+ weight_decay,
+ self._parameters,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ )
+ else:
+ optim_result = self.hyper_map_reverse(
+ F.partial(adamw_opt, self.beta1, self.beta2, self.eps, lr, weight_decay),
+ self._parameters,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ )
+
+ success = self.hyper_map(update_params, self._parameters, optim_result)
+
+ return success
diff --git a/mindone/transformers/mindspore_adapter/adamw_zero.py b/mindone/transformers/mindspore_adapter/adamw_zero.py
new file mode 100644
index 0000000000..8fb214cba3
--- /dev/null
+++ b/mindone/transformers/mindspore_adapter/adamw_zero.py
@@ -0,0 +1,435 @@
+import numpy as np
+
+import mindspore as ms
+from mindspore import Parameter, ParameterTuple, Tensor, context, nn, ops
+from mindspore.common.initializer import initializer
+from mindspore.communication.management import GlobalComm, get_group_size, get_rank
+from mindspore.ops import functional as F
+
+from .adamw import adamw_opt, fused_adam_weight_decay
+from .utils import _is_parallel
+
+split_params = ops.MultitypeFuncGraph("split_params")
+update_params_with_all_gather = ops.MultitypeFuncGraph("update_params_with_all_gather")
+allreduce_op = ops.MultitypeFuncGraph("reduce_op")
+allreduce_and_split_op = ops.MultitypeFuncGraph("reduce_and_split_op")
+reducescatter_and_split_op = ops.MultitypeFuncGraph("reducescatter_and_split_op")
+
+
+@update_params_with_all_gather.register("Tensor", "Tensor", "Function")
+def _update_params_with_all_gather(param, update, all_gather):
+ update = all_gather(update)
+ update = update.to(param.dtype)
+ # Note: ops.isnan not support bfloat16 on MindSpore 2.3.1
+ success = ops.logical_not(ops.isnan(update.float() if update.dtype == ms.bfloat16 else update).any())
+ success = ops.depend(success, ops.assign(param, update))
+ return success
+
+
+@split_params.register("Number", "Number", "Tensor")
+def split_params(shard_id, shard_size, param):
+ if param.shape[0] % shard_size == 0:
+ # param = ops.Split(0, shard_size)(param)[shard_id]
+ param = ops.chunk(param, shard_size, axis=0)[shard_id]
+ return param
+
+
+@allreduce_op.register("Number", "Bool", "Function", "Tensor")
+def _tensors_allreduce(degree, mean, all_reduce_op, grad):
+ # allreduce
+ grad = all_reduce_op(grad)
+ if mean:
+ grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
+
+ return grad
+
+
+@allreduce_and_split_op.register("Number", "Bool", "Function", "Number", "Number", "Tensor")
+def _tensors_allreduce_and_split(degree, mean, all_reduce_op, shard_id, shard_size, grad):
+ # allreduce
+ grad = all_reduce_op(grad)
+ if mean:
+ grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
+
+ # split
+ if grad.shape[0] % shard_size == 0:
+ grad = ops.Split(0, shard_size)(grad)[shard_id]
+
+ return grad
+
+
+@reducescatter_and_split_op.register("Number", "Bool", "Function", "Function", "Number", "Tensor")
+def _tensors_reducescatter_and_split(degree, mean, reduce_scatter_op, all_reduce_op, shard_size, grad):
+ if grad.shape[0] % shard_size == 0:
+ # allreduce and split on world size
+ grad = reduce_scatter_op(grad)
+ else:
+ # allreduce
+ grad = all_reduce_op(grad)
+
+ if mean:
+ grad = F.tensor_mul(grad, F.cast(degree, F.dtype(grad)))
+
+ return grad
+
+
+class AdamWeightDecayZeRO1(nn.Optimizer):
+ def __init__(
+ self,
+ params,
+ learning_rate=1e-3,
+ beta1=0.9,
+ beta2=0.999,
+ eps=1e-6,
+ weight_decay=0.0,
+ shard_size=None,
+ enable_fuse=True,
+ momentum_dtype=ms.float32,
+ ):
+ super(AdamWeightDecayZeRO1, self).__init__(learning_rate, params, weight_decay)
+
+ self.map = ops.Map()
+ self.rank = get_rank() if _is_parallel() else 0
+ self.group_size = get_group_size() if _is_parallel() else 1
+ self.is_parallel = _is_parallel()
+
+ # group for split
+ if shard_size == 1 or not _is_parallel():
+ comm_group = None
+ g_id = 0
+ self.shard_id = self.rank
+ self.shard_size = shard_size if _is_parallel() else 1
+ print(
+ f"[WARNING] {self.__class__.__name__} shard_size is 1, will not shard optimizer parameter, "
+ f"recommended to use the `mindspore.nn.AdamWeightDecay`"
+ )
+
+ elif shard_size is None:
+ comm_group = GlobalComm.WORLD_COMM_GROUP
+ g_id = 0
+ self.shard_id = self.rank
+ self.shard_size = self.group_size
+
+ else:
+ assert (1 < shard_size <= self.group_size) and (self.group_size % shard_size == 0)
+ from mindspore.communication import create_group
+
+ g_id = self.rank // shard_size
+ s_id, e_id = g_id * shard_size, (g_id + 1) * shard_size
+ comm_group = f"sub_group_{g_id}"
+ create_group(comm_group, [_i for _i in range(s_id, e_id)])
+ self.shard_id = self.rank % shard_size
+ self.shard_size = shard_size
+
+ print(
+ f"[WARNING] {self.__class__.__name__} \n"
+ f" beta1/beta2/eps : {beta1}/{beta2}/{eps} \n"
+ f" weight_decay : {weight_decay} \n"
+ f" shard size : {self.shard_size} \n"
+ f" shard_id : {self.shard_id} \n"
+ f" comm group : {comm_group} \n"
+ f" enable_fuse : {enable_fuse} \n"
+ f" momentum_dtype : {momentum_dtype} \n"
+ )
+
+ self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
+ self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
+ self.eps = Tensor(np.array([eps]).astype(np.float32))
+
+ self.moments1 = self._param_init_op(self._parameters, prefix="adam_m", init="zeros", dtype=momentum_dtype)
+ self.moments2 = self._param_init_op(self._parameters, prefix="adam_v", init="zeros", dtype=momentum_dtype)
+ self.all_gather_ops = self._init_all_gather_ops(self._parameters, group=comm_group)
+ self.comm_group = comm_group
+
+ if _is_parallel():
+ self.all_reduce_op = ops.AllReduce()
+ self.mean = context.get_auto_parallel_context("gradients_mean")
+ self.degree = context.get_auto_parallel_context("device_num")
+ self.degree = 1.0 / self.degree
+
+ total_num = len(self.all_gather_ops)
+ split_num = sum([1 for _op in self.all_gather_ops if isinstance(_op, ops.AllGather)])
+ unsplit_num = total_num - split_num
+ print(
+ f"{self.__class__.__name__}, total param num: {total_num}, "
+ f"split num: {split_num}, unsplit num: {unsplit_num}"
+ )
+
+ self.enable_fuse = enable_fuse
+ if self.enable_fuse:
+ self.fused_opt = ops.AdamWeightDecay()
+
+ if self.shard_size > 1:
+ self._split_parameters = self._param_init_op(
+ self._parameters, prefix="adam_split_p", init="same", dtype=momentum_dtype
+ )
+
+ if momentum_dtype == ms.float16:
+ print(
+ f"[ERROR] {self.__class__.__name__}, momentum dtype fp16, may cause `sdma error` on MindSpore 2.3.0"
+ )
+ else:
+ print(
+ f"[WARNING] {self.__class__.__name__}, custom optimizer, may cause `memory leakage` on MindSpore 2.3.0"
+ )
+
+ def _init_all_gather_ops(self, params, group):
+ op_list = []
+ for x in params:
+ if x.split_op:
+ op_list.append(ops.AllGather(group=group))
+ else:
+ op_list.append(ops.identity)
+ return tuple(op_list)
+
+ def _param_init_op(self, params, prefix, init="zeros", dtype=None):
+ news = []
+ for p in params:
+ s = p.shape
+ dtype = dtype if dtype is not None else p.dtype
+ if self.shard_size == 1:
+ if init == "same":
+ new = Parameter(Tensor(p.asnumpy(), dtype=dtype), name=prefix + "." + p.name)
+ else:
+ new = Parameter(initializer(init, shape=s, dtype=dtype), name=prefix + "." + p.name)
+ setattr(p, "split_op", False)
+ elif s[0] % self.shard_size == 0:
+ s = list(s)
+ s[0] = s[0] // self.shard_size
+ s = tuple(s)
+ if init == "same":
+ new_np = p.asnumpy()
+ split_shape = (
+ self.shard_size,
+ -1,
+ *new_np.shape[1:],
+ ) # e.g. (6, 1000) -> (2, 3, 1000) -> (3, 1000)
+ new_np = np.reshape(new_np, split_shape)[self.shard_id]
+ new = Parameter(Tensor(new_np, dtype=dtype), name=prefix + "." + p.name)
+ else:
+ new = Parameter(initializer(init, shape=s, dtype=dtype), name=prefix + "." + p.name)
+ setattr(p, "split_op", True)
+ else:
+ if init == "same":
+ new_np = p.asnumpy()
+ new = Parameter(Tensor(new_np, dtype=dtype), name=prefix + "." + p.name)
+ else:
+ new = Parameter(initializer(init, shape=p.shape, dtype=dtype), name=prefix + "." + p.name)
+ setattr(p, "split_op", False)
+ print(f"[WARNING] {self.__class__.__name__} split {new.name} fail, keep original shape.")
+
+ if not isinstance(new, ms.Parameter):
+ print(f"p.name: {p.name}, type(p): {type(p)}, p.shape: {p.shape}, type(new): {type(new)}")
+
+ news.append(new)
+
+ return ParameterTuple(news)
+
+ def convert_momentum_dtype(self, momentum_list, dtype=ms.float32):
+ for p in momentum_list:
+ p.set_dtype(dtype)
+
+ @ms.jit
+ def grad_reduce(self, grads):
+ if self.is_parallel:
+ mean, degree, shard_id, shard_size = self.mean, self.degree, self.shard_id, self.shard_size
+
+ if self.shard_size == 1:
+ return self.grad_allreduce_(mean, degree, grads)
+ else:
+ return self.grad_allreduce_and_split(mean, degree, shard_id, shard_size, grads)
+ else:
+ return grads
+
+ @ms.jit
+ def grad_allreduce_(self, mean, degree, gradients):
+ gradients = ops.HyperMap()(F.partial(allreduce_op, degree, mean, self.all_reduce_op), gradients)
+ return gradients
+
+ @ms.jit
+ def grad_allreduce_and_split(self, mean, degree, shard_id, shard_size, gradients):
+ part_gradients = ops.HyperMap()(
+ F.partial(allreduce_and_split_op, degree, mean, self.all_reduce_op, shard_id, shard_size), gradients
+ )
+ return part_gradients
+
+ @ms.jit
+ def construct(self, split_gradients):
+ if self.enable_fuse:
+ if self.shard_size == 1:
+ self._optim_fuse_no_shard(split_gradients)
+ else:
+ self._optim_fuse(split_gradients)
+ else:
+ self._optim_custom(split_gradients)
+
+ def _optim_custom(self, split_gradients):
+ gradients = split_gradients
+ params = self.hyper_map(F.partial(split_params, self.shard_id, self.shard_size), self._parameters)
+ # gradients = self.hyper_map(F.partial(split_params, self.shard_id, self.shard_size), gradients)
+ # params = self.hyper_map(F.partial(split_params, self.shard_id, self.shard_size), self._parameters)
+
+ gradients = self.flatten_gradients(gradients)
+ weight_decay = self.get_weight_decay()
+ lr = self.get_lr()
+ self.assignadd(self.global_step, self.global_step_increase_tensor)
+
+ if self.is_group:
+ if self.is_group_lr:
+ optim_result = self.hyper_map(
+ F.partial(adamw_opt, self.beta1, self.beta2, self.eps),
+ lr,
+ weight_decay,
+ params,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ )
+ else:
+ optim_result = self.hyper_map(
+ F.partial(adamw_opt, self.beta1, self.beta2, self.eps, lr),
+ weight_decay,
+ params,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ )
+ else:
+ optim_result = self.hyper_map(
+ F.partial(adamw_opt, self.beta1, self.beta2, self.eps, lr, weight_decay),
+ params,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ )
+
+ success = self.hyper_map(update_params_with_all_gather, self._parameters, optim_result, self.all_gather_ops)
+
+ return success
+
+ def _optim_fuse(self, split_gradients):
+ gradients = split_gradients
+
+ gradients = self.flatten_gradients(gradients)
+ weight_decay = self.get_weight_decay()
+ lr = self.get_lr()
+ self.assignadd(self.global_step, self.global_step_increase_tensor)
+
+ if self.is_group:
+ if self.is_group_lr:
+ success = self.hyper_map(
+ F.partial(fused_adam_weight_decay, self.fused_opt, self.beta1, self.beta2, self.eps),
+ lr,
+ weight_decay,
+ self._split_parameters,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ self.optim_filter,
+ )
+ else:
+ success = self.hyper_map(
+ F.partial(fused_adam_weight_decay, self.fused_opt, self.beta1, self.beta2, self.eps, lr),
+ weight_decay,
+ self._split_parameters,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ self.optim_filter,
+ )
+ else:
+ success = self.hyper_map(
+ F.partial(fused_adam_weight_decay, self.fused_opt, self.beta1, self.beta2, self.eps, lr, weight_decay),
+ self._split_parameters,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ self.optim_filter,
+ )
+
+ success = ops.depend(
+ self.hyper_map(
+ update_params_with_all_gather, self._parameters, self._split_parameters, self.all_gather_ops
+ ),
+ success,
+ )
+
+ return success
+
+ def _optim_fuse_no_shard(self, split_gradients):
+ gradients = split_gradients
+
+ gradients = self.flatten_gradients(gradients)
+ weight_decay = self.get_weight_decay()
+ lr = self.get_lr()
+ self.assignadd(self.global_step, self.global_step_increase_tensor)
+
+ if self.is_group:
+ if self.is_group_lr:
+ success = self.hyper_map(
+ F.partial(fused_adam_weight_decay, self.fused_opt, self.beta1, self.beta2, self.eps),
+ lr,
+ weight_decay,
+ self._parameters,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ self.optim_filter,
+ )
+ else:
+ success = self.hyper_map(
+ F.partial(fused_adam_weight_decay, self.fused_opt, self.beta1, self.beta2, self.eps, lr),
+ weight_decay,
+ self._parameters,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ self.optim_filter,
+ )
+ else:
+ success = self.hyper_map(
+ F.partial(fused_adam_weight_decay, self.fused_opt, self.beta1, self.beta2, self.eps, lr, weight_decay),
+ self._parameters,
+ self.moments1,
+ self.moments2,
+ gradients,
+ self.decay_flags,
+ self.optim_filter,
+ )
+
+ return success
+
+
+class AdamWeightDecayZeRO2(AdamWeightDecayZeRO1):
+ def __init__(self, *args, **kwargs):
+ super(AdamWeightDecayZeRO2, self).__init__(*args, **kwargs)
+ self.reduce_scatter_op = ops.ReduceScatter() if _is_parallel() else nn.Identity()
+
+ def grad_reduce(self, grads):
+ if self.is_parallel:
+ mean, degree, shard_id, shard_size = self.mean, self.degree, self.shard_id, self.shard_size
+
+ if self.shard_size == 1:
+ return self.grad_allreduce_(mean, degree, grads)
+ else:
+ if self.group_size == self.shard_size:
+ return self.grad_reducescatter_and_split(mean, degree, shard_id, shard_size, grads)
+ else:
+ return self.grad_allreduce_and_split(mean, degree, shard_id, shard_size, grads)
+ else:
+ return grads
+
+ def grad_reducescatter_and_split(self, mean, degree, shard_id, shard_size, gradients):
+ part_gradients = ops.HyperMap()(
+ F.partial(reducescatter_and_split_op, degree, mean, self.reduce_scatter_op, self.all_reduce_op, shard_size),
+ gradients,
+ )
+ return part_gradients
diff --git a/mindone/transformers/mindspore_adapter/amp.py b/mindone/transformers/mindspore_adapter/amp.py
new file mode 100644
index 0000000000..ea7eb0323c
--- /dev/null
+++ b/mindone/transformers/mindspore_adapter/amp.py
@@ -0,0 +1,88 @@
+import mindspore as ms
+from mindspore import nn
+from mindspore.train.amp import _auto_black_list
+
+HALF_UNFRIENDLY_LAYERS = [
+ nn.BatchNorm1d,
+ nn.BatchNorm2d,
+ nn.BatchNorm3d,
+ nn.LayerNorm,
+ nn.GroupNorm,
+ nn.SiLU,
+ nn.GELU,
+ nn.Softmax,
+ nn.Sigmoid,
+ nn.MaxPool1d,
+ nn.MaxPool2d,
+ nn.MaxPool3d,
+ nn.AvgPool1d,
+ nn.AvgPool2d,
+ nn.AvgPool3d,
+ nn.CrossEntropyLoss,
+]
+
+
+def auto_mixed_precision(network, amp_level="O0", dtype=ms.float16):
+ """
+ auto mixed precision function.
+
+ Args:
+ network (Cell): Definition of the network.
+ amp_level (str): Supports ["O0", "O1", "O2", "O3"]. Default: "O0".
+
+ - "O0": Do not change.
+ - "O1": Cast the operators in white_list to float16, the remaining operators are kept in float32.
+ - "O2": Cast network to float16, keep operators in black_list run in float32,
+ - "O3": Cast network to float16.
+
+ Raises:
+ ValueError: If amp level is not supported.
+
+ Examples:
+ >>> network = LeNet5()
+ >>> amp_level = "O2"
+ >>> net = auto_mixed_precision(network, amp_level, dtype=ms.float16)
+ """
+
+ if not isinstance(network, nn.Cell):
+ raise TypeError("The network type should be Cell.")
+
+ if amp_level == "O0":
+ pass
+ elif amp_level == "O1":
+ raise NotImplementedError
+ elif amp_level == "O2":
+ _auto_black_list(network, HALF_UNFRIENDLY_LAYERS, dtype)
+ elif amp_level == "O3":
+ network.to_float(dtype)
+ else:
+ raise ValueError("The amp level {} is not supported".format(amp_level))
+ return network
+
+
+def auto_convert_module_dtype(model: nn.Cell, dtype=ms.float16, keep_norm_fp32=True):
+ dtype2str_map = {ms.float16: "fp16", ms.bfloat16: "bf16", ms.float32: "fp32"}
+
+ if dtype not in (ms.float16, ms.bfloat16, ms.float32):
+ raise ValueError(f"convert_module_dtype, not support dtype: {dtype}")
+
+ if model is not None:
+ assert isinstance(model, nn.Cell)
+
+ k_num, c_num = 0, 0
+ for _, p in model.parameters_and_names():
+ # filter norm parameters
+ if keep_norm_fp32 and ("norm" in p.name):
+ k_num += 1
+ # filter bool/int parameters
+ elif p.dtype in (ms.bool_, ms.int32, ms.int64, ms.uint8):
+ k_num += 1
+ elif p.dtype == dtype:
+ c_num += 1
+ else:
+ c_num += 1
+ p.set_dtype(dtype)
+
+ print(f"Convert `{type(model).__name__}` param to {dtype2str_map[dtype]}, keep/modify num {k_num}/{c_num}.")
+
+ return model
diff --git a/mindone/transformers/mindspore_adapter/attention.py b/mindone/transformers/mindspore_adapter/attention.py
new file mode 100644
index 0000000000..182e50968a
--- /dev/null
+++ b/mindone/transformers/mindspore_adapter/attention.py
@@ -0,0 +1,105 @@
+import numpy as np
+
+import mindspore as ms
+from mindspore import nn, ops
+from mindspore.ops.operations.nn_ops import FlashAttentionScore as _FlashAttention
+
+DTYPE_FP16_MIN = float(np.finfo(np.float16).min)
+
+
+def scaled_dot_product_attention(query, key, value, attn_mask=None, dtype=None):
+ # force dtype(fp16 or bf16) precision calculation
+ ori_dtype = query.dtype
+ if dtype is not None:
+ query, key, value = query.astype(dtype), key.astype(dtype), value.astype(dtype)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == ms.bool_:
+ attn_mask = attn_mask.to(ms.float32)
+ attn_mask = attn_mask.masked_fill((1 - attn_mask).to(ms.bool_), DTYPE_FP16_MIN)
+ attn_mask = attn_mask.to(query.dtype)
+
+ attn_weight = ops.softmax(
+ ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5) + attn_mask, axis=-1, dtype=ms.float32
+ ).astype(query.dtype)
+ else:
+ attn_weight = ops.softmax(
+ ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5), axis=-1, dtype=ms.float32
+ ).astype(query.dtype)
+
+ out = ops.matmul(attn_weight, value)
+ out = out.astype(ori_dtype)
+
+ return out
+
+
+class FlashAttention2(nn.Cell):
+ def __init__(
+ self,
+ head_dim: int,
+ head_num: int,
+ attention_dropout: float = 0.0,
+ input_layout: str = "BNSD",
+ dtype: ms.dtype = ms.float16,
+ ):
+ super().__init__()
+ self.input_layout = input_layout
+ if input_layout not in ["BSH", "BNSD"]:
+ raise ValueError(f"input_layout must be in ['BSH', 'BNSD'], but get {input_layout}.")
+ self.head_dim = head_dim
+
+ self.flash_attention = _FlashAttention(
+ scale_value=head_dim**-0.5,
+ head_num=head_num,
+ input_layout=input_layout,
+ keep_prob=1 - attention_dropout,
+ )
+
+ self.dtype = dtype
+ cand_d_list = [64, 80, 96, 120, 128, 256]
+ self.d_pad = 0
+ for d in cand_d_list:
+ if head_dim == d:
+ self.d_pad = 0
+ break
+ elif head_dim < d:
+ self.d_pad = d - head_dim
+ break
+ if head_dim > 256:
+ raise ValueError("head_dim must <= 256!")
+ self.need_pad = self.d_pad != 0
+
+ def _rearange_input(self, x):
+ x = x.to(self.dtype)
+ if self.need_pad:
+ if self.input_layout == "BNSD":
+ B, N, S, D = x.shape
+ pad = ops.zeros((B, N, S, self.d_pad), x.dtype)
+ else:
+ B, S = x.shape[:2]
+ x = x.reshape(B, S, -1, self.head_dim)
+ pad = ops.zeros((B, S, x.shape[2], self.d_pad), x.dtype)
+ x = ops.concat((x, pad), axis=-1)
+ if self.input_layout == "BSH":
+ B, S = x.shape[:2]
+ x = x.reshape(B, S, -1)
+ return x
+
+ def _rearange_output(self, x, dtype):
+ if self.input_layout == "BSH":
+ B, S = x.shape[:2]
+ x = x.reshape(B, S, -1, self.head_dim + self.d_pad)
+ if self.need_pad:
+ x = x[:, :, :, : self.head_dim]
+ return x.to(dtype)
+
+ def construct(self, q, k, v, mask=None):
+ q_dtype = q.dtype
+ q = self._rearange_input(q)
+ k = self._rearange_input(k)
+ v = self._rearange_input(v)
+ if mask is not None:
+ mask = mask.to(ms.uint8)
+ out = self.flash_attention(q, k, v, None, None, None, mask)[3]
+ out = self._rearange_output(out, q_dtype)
+ return out
diff --git a/mindone/transformers/mindspore_adapter/clip_grad.py b/mindone/transformers/mindspore_adapter/clip_grad.py
new file mode 100644
index 0000000000..635cf01f11
--- /dev/null
+++ b/mindone/transformers/mindspore_adapter/clip_grad.py
@@ -0,0 +1,101 @@
+import mindspore as ms
+from mindspore import ops
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+
+_clip_grad_value = ops.MultitypeFuncGraph("_clip_grad_value")
+
+
+@_clip_grad_value.register("Number", "Number", "Tensor")
+def __clip_grad_value(max_value, grad):
+ """
+ Clip gradients.
+
+ Inputs:
+ max_value (float): Specifies how much to clip.
+ grad (tuple[Tensor]): Gradients.
+
+ Outputs:
+ tuple[Tensor]: clipped gradients.
+ """
+ new_grad = C.clip_by_value(grad, -max_value, max_value)
+ return new_grad
+
+
+_apply_global_norm = ops.MultitypeFuncGraph("_apply_global_norm")
+
+
+@_apply_global_norm.register("Number", "Tensor", "Tensor")
+def __apply_global_norm(clip_coef, x):
+ x_dtype = F.dtype(x)
+ x = x * clip_coef
+ x = F.cast(x, x_dtype)
+ return x
+
+
+_square = ops.MultitypeFuncGraph("_square")
+
+
+@_square.register("Tensor")
+def __square(x):
+ return ops.square(x)
+
+
+_square_sum = ops.MultitypeFuncGraph("_square_sum")
+
+
+@_square_sum.register("Tensor")
+def __square_sum(x):
+ return ops.square(x.astype(ms.float32)).sum()
+
+
+_square_sum_and_all_reduce = ops.MultitypeFuncGraph("_square_sum_and_all_reduce")
+
+
+@_square_sum_and_all_reduce.register("Tensor")
+def __square_sum_and_all_reduce(all_reduce_op, x):
+ square_x_sum = ops.square(x.astype(ms.float32)).sum()
+ square_x_sum = all_reduce_op(square_x_sum)
+ return square_x_sum
+
+
+hyper_map_op = ops.HyperMap()
+
+
+def _clip_grad_l2norm(max_norm, grads):
+ grads_square_sum = hyper_map_op(_square_sum, grads)
+ total_norm = ops.sqrt(ops.addn(grads_square_sum))
+
+ clip_coef = max_norm / (total_norm + 1e-6)
+ clip_coef_clamped = ops.clamp(clip_coef, None, 1.0)
+
+ clipped_grads = hyper_map_op(F.partial(_apply_global_norm, clip_coef_clamped), grads)
+ return clipped_grads
+
+
+def _clip_grad_l2norm_for_zero(max_norm, all_reduce_op, part_grads):
+ grads_square_sum = hyper_map_op(F.partial(_square_sum_and_all_reduce, all_reduce_op), part_grads)
+ total_norm = ops.sqrt(ops.addn(grads_square_sum))
+
+ clip_coef = max_norm / (total_norm + 1e-6)
+ clip_coef = (
+ ops.ones((), dtype=ms.float32) * clip_coef
+ ) # necessary on MindSpore 2.3.1 to enable `clip_coef` as a Tensor
+
+ clip_coef_clamped = ops.clamp(clip_coef, None, 1.0)
+
+ clipped_part_grads = hyper_map_op(F.partial(_apply_global_norm, clip_coef_clamped), part_grads)
+ return clipped_part_grads
+
+
+def clip_grad_value(grads, max_value):
+ clipped_grads = hyper_map_op(F.partial(_clip_grad_value, max_value), grads)
+ return clipped_grads
+
+
+def clip_grad_norm(grads, max_norm):
+ return _clip_grad_l2norm(max_norm, grads)
+
+
+def clip_grad_norm_for_zero(part_grads, max_norm, all_reduce_op):
+ return _clip_grad_l2norm_for_zero(max_norm, all_reduce_op, part_grads)
diff --git a/mindone/transformers/mindspore_adapter/data.py b/mindone/transformers/mindspore_adapter/data.py
new file mode 100644
index 0000000000..d57ebe0c5e
--- /dev/null
+++ b/mindone/transformers/mindspore_adapter/data.py
@@ -0,0 +1,180 @@
+from typing import Generic, Iterator, Optional, Sized, TypeVar, Union
+
+import numpy as np
+
+import mindspore as ms
+
+T_co = TypeVar("T_co", covariant=True)
+
+
+class Sampler(Generic[T_co]):
+ r"""Base class for all Samplers.
+
+ reference to torch.utils.data.Sampler
+
+ Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
+ way to iterate over indices of dataset elements, and a :meth:`__len__` method
+ that returns the length of the returned iterators.
+ """
+
+ def __init__(self, data_source: Optional[Sized]) -> None:
+ pass
+
+ def __iter__(self) -> Iterator[T_co]:
+ raise NotImplementedError
+
+ # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
+ #
+ # Many times we have an abstract class representing a collection/iterable of
+ # data, e.g., `.data.Sampler`, with its subclasses optionally
+ # implementing a `__len__` method. In such cases, we must make sure to not
+ # provide a default implementation, because both straightforward default
+ # implementations have their issues:
+ #
+ # + `return NotImplemented`:
+ # Calling `len(subclass_instance)` raises:
+ # TypeError: 'NotImplementedType' object cannot be interpreted as an integer
+ #
+ # + `raise NotImplementedError()`:
+ # This prevents triggering some fallback behavior. E.g., the built-in
+ # `list(X)` tries to call `len(X)` first, and executes a different code
+ # path if the method is not found or `NotImplemented` is returned, while
+ # raising an `NotImplementedError` will propagate and and make the call
+ # fail where it could have use `__iter__` to complete the call.
+ #
+ # Thus, the only two sensible things to do are
+ #
+ # + **not** provide a default `__len__`.
+ #
+ # + raise a `TypeError` instead, which is what Python uses when users call
+ # a method that is not defined on an object.
+ # (@ssnl verifies that this works on at least Python 3.7.)
+
+
+class SequentialSampler(Sampler[int]):
+ r"""Samples elements sequentially, always in the same order.
+
+ Args:
+ data_source (Dataset): dataset to sample from
+ """
+ data_source: Sized
+
+ def __init__(self, data_source: Sized) -> None:
+ self.data_source = data_source
+
+ def __iter__(self) -> Iterator[int]:
+ return iter(range(len(self.data_source)))
+
+ def __len__(self) -> int:
+ return len(self.data_source)
+
+
+class RandomSampler(Sampler[int]):
+ r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
+ If with replacement, then user can specify :attr:`num_samples` to draw.
+
+ Args:
+ data_source (Dataset): dataset to sample from
+ replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
+ num_samples (int): number of samples to draw, default=`len(dataset)`.
+ generator (Generator): Generator used in sampling.
+ """
+ data_source: Sized
+ replacement: bool
+
+ def __init__(
+ self, data_source: Sized, replacement: bool = False, num_samples: Optional[int] = None, generator=None
+ ) -> None:
+ self.data_source = data_source
+ self.replacement = replacement
+ self._num_samples = num_samples
+ self.generator = generator
+
+ if not isinstance(self.replacement, bool):
+ raise TypeError("replacement should be a boolean value, but got " "replacement={}".format(self.replacement))
+
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+ raise ValueError(
+ "num_samples should be a positive integer " "value, but got num_samples={}".format(self.num_samples)
+ )
+
+ @property
+ def num_samples(self) -> int:
+ # dataset size might change at runtime
+ if self._num_samples is None:
+ return len(self.data_source)
+ return self._num_samples
+
+ def __iter__(self) -> Iterator[int]:
+ n = len(self.data_source)
+
+ if self.replacement:
+ for _ in range(self.num_samples // 32):
+ yield from np.random.randint(low=0, high=n, size=(32,), dtype=np.int64).tolist()
+ yield from np.random.randint(low=0, high=n, size=(self.num_samples % 32,), dtype=np.int64).tolist()
+ else:
+ for _ in range(self.num_samples // n):
+ yield from np.random.permutation(n).tolist()
+ yield from np.random.permutation(n).tolist()[: self.num_samples % n]
+
+ def __len__(self) -> int:
+ return self.num_samples
+
+
+class Dataset(Generic[T_co]):
+ r"""An abstract class representing a :class:`Dataset`.
+
+ reference to torch.utils.data.Dataset
+
+ All datasets that represent a map from keys to data samples should subclass
+ it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
+ data sample for a given key. Subclasses could also optionally overwrite
+ :meth:`__len__`, which is expected to return the size of the dataset by many
+ :class:`~mindspore_adapter.data.Sampler` implementations and the default options
+ of :class:`~mindspore_adapter.data.DataLoader`.
+
+ .. note::
+ :class:`~mindspore_adapter.data.DataLoader` by default constructs a index
+ sampler that yields integral indices. To make it work with a map-style
+ dataset with non-integral indices/keys, a custom sampler must be provided.
+ """
+
+ def __getitem__(self, index) -> T_co:
+ raise NotImplementedError
+
+ # No `def __len__(self)` default?
+ # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
+ # in mindone/transformers/mindspore_adapter/data.py:Sampler
+
+
+class TensorDataset(Dataset):
+ r"""Dataset wrapping tensors.
+
+ Each sample will be retrieved by indexing tensors along the first dimension.
+
+ Args:
+ *tensors (Tensor): tensors that have the same size of the first dimension.
+ """
+
+ def __init__(self, *tensors: Union[ms.Tensor, np.ndarray]) -> None:
+ assert all(tensors[0].shape[0] == tensor.shape[0] for tensor in tensors), "Size mismatch between tensors"
+ self.tensors = tensors
+
+ def __getitem__(self, index):
+ return tuple(
+ tensor[index] if isinstance(tensor, ms.Tensor) else ms.Tensor(tensor[index]) for tensor in self.tensors
+ )
+
+ def __len__(self):
+ return self.tensors[0].shape[0]
+
+
+class HF2MSDataset:
+ def __init__(self, dataset):
+ self.dataset = dataset
+
+ def __getitem__(self, item):
+ return self.dataset[int(item)]
+
+ def __len__(self):
+ return len(self.dataset)
diff --git a/mindone/transformers/mindspore_adapter/recompute.py b/mindone/transformers/mindspore_adapter/recompute.py
new file mode 100644
index 0000000000..03fda20426
--- /dev/null
+++ b/mindone/transformers/mindspore_adapter/recompute.py
@@ -0,0 +1,10 @@
+from mindspore import nn
+
+
+def recompute_except_output(cell: nn.Cell, **recompute_kwargs):
+ if not cell._has_config_recompute:
+ cell.recompute(**recompute_kwargs)
+ if isinstance(cell, nn.CellList):
+ recompute_except_output(cell[-1])
+ else:
+ cell.add_flags(output_no_recompute=True)
diff --git a/mindone/transformers/mindspore_adapter/train_onestep_wrapper.py b/mindone/transformers/mindspore_adapter/train_onestep_wrapper.py
new file mode 100644
index 0000000000..0ff7745f79
--- /dev/null
+++ b/mindone/transformers/mindspore_adapter/train_onestep_wrapper.py
@@ -0,0 +1,270 @@
+from typing import Dict
+
+import mindspore as ms
+from mindspore import ParallelMode, Tensor, context, nn, ops
+from mindspore.boost.grad_accumulation import gradient_clear_op as _grad_clear_op
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+
+try:
+ from .adamw_zero import AdamWeightDecayZeRO1, AdamWeightDecayZeRO2
+
+ is_adamw_zero_available = True
+except ImportError:
+ is_adamw_zero_available = False
+
+
+_grad_accum_op = C.MultitypeFuncGraph("gradient_accumulation_op")
+
+
+@_grad_accum_op.register("Int64", "Tensor", "Tensor")
+def cumulative_grad_process(cumulative_grad, grad):
+ """Apply gradient accumulation to cumulative grad."""
+ P.AssignAdd()(cumulative_grad, grad)
+ return cumulative_grad
+
+
+def _is_pynative_parallel():
+ parallel_mode = context.get_auto_parallel_context("parallel_mode")
+ return context.get_context("mode") == context.PYNATIVE_MODE and parallel_mode in (
+ context.ParallelMode.SEMI_AUTO_PARALLEL,
+ context.ParallelMode.AUTO_PARALLEL,
+ )
+
+
+def create_loss_scaler(ms_loss_scaler="static", scale_value=1024, scale_factor=2, scale_window=1000):
+ if ms_loss_scaler == "dynamic":
+ from mindspore.amp import DynamicLossScaler
+
+ loss_scaler = DynamicLossScaler(scale_value=scale_value, scale_factor=scale_factor, scale_window=scale_window)
+ elif ms_loss_scaler == "static":
+ from mindspore.amp import StaticLossScaler
+
+ loss_scaler = StaticLossScaler(scale_value=scale_value)
+ elif ms_loss_scaler in ("none", "None"):
+ from mindspore.amp import StaticLossScaler
+
+ loss_scaler = StaticLossScaler(1.0)
+ else:
+ raise NotImplementedError(f"Not support ms_loss_scaler: {ms_loss_scaler}")
+
+ return loss_scaler
+
+
+def _is_parallel():
+ is_parallel = (
+ context.get_auto_parallel_context("parallel_mode") in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
+ or _is_pynative_parallel()
+ )
+ return is_parallel
+
+
+def _is_cpu():
+ return context.get_context("device_target") == "CPU"
+
+
+def return_true(*args, **kwargs):
+ return ops.ones((), ms.bool_)
+
+
+def create_grad_reducer(trainable_parameters):
+ use_reducer = _is_parallel()
+
+ if use_reducer:
+ mean = context.get_auto_parallel_context("gradients_mean")
+ degree = context.get_auto_parallel_context("device_num")
+ grad_reducer = nn.DistributedGradReducer(trainable_parameters, mean, degree)
+ else:
+ grad_reducer = nn.Identity()
+ return grad_reducer
+
+
+class TrainOneStepWrapper(nn.Cell):
+ """TrainStep with ema and clip grad.
+
+ Returns:
+ Tuple of 3 Tensor, the loss, overflow flag and current loss scale value.
+ loss (Tensor) - A scalar, the loss value.
+ overflow (Tensor) - A scalar, whether overflow occur or not, the type is bool.
+ loss scale (Tensor) - The loss scale value, the shape is :math:`()` or :math:`(1,)`.
+
+ """
+
+ def __init__(
+ self,
+ network: nn.Cell,
+ optimizer: nn.Optimizer,
+ ema: nn.Cell = None,
+ drop_overflow_step: bool = True,
+ scaler: str = "default",
+ scaler_config: Dict = {},
+ gradient_accumulation_steps: int = 1,
+ clip_grad: str = "none",
+ clip_value: float = 1.0,
+ ):
+ super().__init__(auto_prefix=False)
+
+ if is_adamw_zero_available and isinstance(optimizer, (AdamWeightDecayZeRO1, AdamWeightDecayZeRO2)):
+ assert hasattr(optimizer, "grad_reduce")
+ reducer = None
+ if optimizer.shard_size > 1:
+ is_zero = True
+ self.reduce_op_for_clip_grad = ops.AllReduce(group=optimizer.comm_group)
+ else:
+ is_zero = False
+ else:
+ reducer = create_grad_reducer(network.trainable_params())
+ is_zero = False
+
+ # grad accumulation
+ assert gradient_accumulation_steps >= 1
+ self.accum_steps = gradient_accumulation_steps
+ if gradient_accumulation_steps > 1:
+ self.hyper_map = ops.HyperMap()
+ self.cur_accum_step = ms.Parameter(ms.Tensor(0, dtype=ms.int32), name="accum_step", requires_grad=False)
+
+ if is_zero:
+ self.accumulated_grads = optimizer.moments1.clone(prefix="accum_grad", init="zeros") # split grads
+ else:
+ self.accumulated_grads = optimizer.parameters.clone(prefix="accum_grad", init="zeros")
+
+ class ScalingLossForGradAccum(nn.Cell):
+ def __init__(self, net, accum_steps_):
+ super(ScalingLossForGradAccum, self).__init__(auto_prefix=False)
+ self.net = net
+ self.accum_steps_ = accum_steps_
+
+ def construct(self, *args, **kwargs):
+ loss = self.net(*args, **kwargs)
+ return loss / self.accum_steps_
+
+ network = ScalingLossForGradAccum(network, gradient_accumulation_steps)
+
+ # grad and optimizer
+ self.network = network
+ self.network.set_train()
+ self.network.set_grad()
+
+ # self.value_and_grad = ops.value_and_grad(network, grad_position=None, weights=optimizer.parameters)
+ self.grad_fn = ops.GradOperation(get_by_list=True, sens_param=True)(self.network, optimizer.parameters)
+
+ self.optimizer = optimizer
+ self.ema = ema
+
+ # scaler and reducer
+ assert "ms_loss_scaler" not in scaler_config
+ if scaler.lower() in ("default", "static"):
+ _scaler_config = {"scale_value": 1024}
+ _scaler_config.update(scaler_config)
+ scaler = create_loss_scaler("static", **_scaler_config)
+ elif scaler.lower() in ("auto", "dynamic"):
+ scaler = create_loss_scaler("dynamic", **scaler_config)
+ elif scaler.lower() == "none":
+ scaler = create_loss_scaler("none", **scaler_config)
+ else:
+ raise NotImplementedError
+
+ self.scaler = scaler
+ self.reducer = reducer
+ self.is_zero = is_zero
+ self.all_finite = ms.amp.all_finite if not _is_cpu() else return_true
+ self.all_finite_reducer = ops.AllReduce() if _is_parallel() else nn.Identity()
+ self.drop_overflow_step = Tensor(drop_overflow_step, ms.bool_)
+
+ # clip grad
+ assert clip_value > 0.0 and isinstance(
+ clip_value, float
+ ), f"clip_value must be float > 0., but got {clip_value}"
+ self.clip_value = clip_value
+ self.is_clip_norm = False
+ if clip_grad.lower() in ("norm", "l2norm", "l2_norm", "global", "global_norm", "total", "total_norm"):
+ self.is_clip_norm = True
+ if self.is_zero:
+ from mindone.transformers.mindspore_adapter.clip_grad import clip_grad_norm_for_zero
+
+ clip_grad_fn = clip_grad_norm_for_zero
+ else:
+ from mindone.transformers.mindspore_adapter.clip_grad import clip_grad_norm
+
+ clip_grad_fn = clip_grad_norm
+ elif clip_grad.lower() in ("local", "value"):
+ from mindone.transformers.mindspore_adapter.clip_grad import clip_grad_value
+
+ clip_grad_fn = clip_grad_value
+ elif clip_grad.lower() == "none":
+ clip_grad_fn = None
+ else:
+ raise NotImplementedError
+ self.clip_grad_fn = clip_grad_fn
+
+ def do_optim(self, loss, grads):
+ if self.accum_steps == 1:
+ if self.clip_grad_fn is not None:
+ if self.is_zero and self.is_clip_norm:
+ grads = self.clip_grad_fn(grads, self.clip_value, self.reduce_op_for_clip_grad)
+ else:
+ grads = self.clip_grad_fn(grads, self.clip_value)
+ loss = ops.depend(loss, self.optimizer(grads))
+ if self.ema is not None:
+ self.ema.ema_update()
+ else:
+ loss = ops.depend(loss, self.hyper_map(_grad_accum_op, self.accumulated_grads, grads))
+ loss = ops.depend(loss, ops.assign_add(self.cur_accum_step, ms.Tensor(1, ms.int32)))
+ if self.cur_accum_step % self.accum_steps == 0:
+ if self.clip_grad_fn is not None:
+ if self.is_zero and self.is_clip_norm:
+ clipped_grads = self.clip_grad_fn(
+ self.accumulated_grads, self.clip_value, self.reduce_op_for_clip_grad
+ )
+ else:
+ clipped_grads = self.clip_grad_fn(self.accumulated_grads, self.clip_value)
+
+ loss = ops.depend(loss, self.optimizer(clipped_grads))
+ else:
+ loss = ops.depend(loss, self.optimizer(self.accumulated_grads))
+
+ loss = ops.depend(loss, self.hyper_map(ops.partial(_grad_clear_op), self.accumulated_grads))
+ loss = ops.depend(loss, ops.assign(self.cur_accum_step, ms.Tensor(0, ms.int32)))
+ if self.ema is not None:
+ self.ema.ema_update()
+ else:
+ # update the optimizer global step and learning rate, do not update the parameter
+ loss = ops.depend(
+ loss, ops.assign_add(self.optimizer.global_step, self.optimizer.global_step_increase_tensor)
+ )
+
+ # unscaling loss for grad accum
+ loss = loss * self.accum_steps
+
+ return loss
+
+ def construct(self, *inputs):
+ loss = self.network(*inputs)
+ sens = ops.fill(loss.dtype, loss.shape, self.scaler.scale_value)
+ grads = self.grad_fn(*inputs, sens)
+ if self.is_zero:
+ grads = self.optimizer.grad_reduce(grads)
+ else:
+ grads = self.reducer(grads)
+ unscaled_grads = self.scaler.unscale(grads)
+
+ finite = self.all_finite(unscaled_grads)
+ finite = ops.equal(
+ self.all_finite_reducer(finite.to(ms.int32)), self.all_finite_reducer(ops.ones((), ms.int32))
+ ).to(ms.bool_)
+ finite = ops.depend(finite, self.scaler.adjust(finite)).to(ms.bool_)
+
+ if not self.drop_overflow_step:
+ loss = self.do_optim(loss, unscaled_grads)
+ loss = loss.to(ms.float32)
+ else:
+ if finite:
+ loss = self.do_optim(loss, unscaled_grads)
+ loss = loss.to(ms.float32)
+ else:
+ # FIXME: has bug when run amp fp16 on MindSpore 2.3
+ loss = loss.to(ms.float32)
+
+ overflow_tag = not finite
+
+ return loss, unscaled_grads, overflow_tag
diff --git a/mindone/transformers/mindspore_adapter/training_args.py b/mindone/transformers/mindspore_adapter/training_args.py
new file mode 100644
index 0000000000..ead7699741
--- /dev/null
+++ b/mindone/transformers/mindspore_adapter/training_args.py
@@ -0,0 +1,117 @@
+import os
+from dataclasses import dataclass, field
+from typing import Optional
+
+import mindspore as ms
+from mindspore.communication.management import get_group_size, get_rank, init
+
+
+@dataclass
+class MindSporeArguments:
+ # for mindspore
+
+ mode: int = field(default=ms.GRAPH_MODE, metadata={"help": "Graph/Pynative"})
+
+ jit_level: Optional[str] = field(default="O0", metadata={"help": ("jit level")})
+
+ device_target: str = field(default="Ascend", metadata={"help": "Ascend/GPU/CPU"})
+
+ is_distribute: Optional[bool] = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training"
+ " only). The base option should be `full_shard`, `shard_grad_op` or `no_shard` and you can add"
+ " CPU-offload to `full_shard` or `shard_grad_op` like this: full_shard offload` or `shard_grad_op"
+ " offload`. You can add auto-wrap to `full_shard` or `shard_grad_op` with the same syntax: full_shard"
+ " auto_wrap` or `shard_grad_op auto_wrap`."
+ ),
+ },
+ )
+ rank: int = field(default=0, metadata={"help": "rank id"})
+ rank_size: int = field(default=1, metadata={"help": "device num"})
+
+ enable_flash_attention: Optional[bool] = field(
+ default=False,
+ metadata={
+ "help": (
+ "if enable_flash_attention is True, model attention implementation will be set to `flash_attention_2`"
+ )
+ },
+ )
+
+ adamw_enable_fuse: Optional[bool] = field(
+ default=True,
+ metadata={"help": ("enable fuse op")},
+ )
+ adamw_zero_shard_size: Optional[int] = field(
+ default=None,
+ metadata={"help": ("setting zero parallelism shard size")},
+ )
+ max_device_memory: Optional[str] = field(
+ default=None,
+ metadata={"help": ("max device memory")},
+ )
+
+ precision_mode: Optional[str] = field(
+ default="must_keep_origin_dtype", metadata={"help": ("global precision_mode")}
+ )
+
+
+def init_environment(training_args: MindSporeArguments):
+ # FIXME, stream synchronize bug when jit_level is `O0` on MindSpore 2.3.0
+ if training_args.mode == 0:
+ if os.environ.get("MS_DEV_RUNTIME_CONF") is None:
+ os.environ["MS_DEV_RUNTIME_CONF"] = "synchronize:True"
+ print("WARNING: os environment MS_DEV_RUNTIME_CONF synchronize has not been set, force setting it now.")
+ else:
+ if "synchronize:True" not in os.environ.get("MS_DEV_RUNTIME_CONF"):
+ _old = os.environ.get("MS_DEV_RUNTIME_CONF")
+ _old.replace("synchronize:False,", "")
+ _old.replace(",synchronize:False", "")
+ _old.replace("synchronize:False", "")
+ _new = "synchronize:True," + _old if len(_old) > 0 else "synchronize:True"
+ os.environ["MS_DEV_RUNTIME_CONF"] = _new
+ print("WARNING: os environment MS_DEV_RUNTIME_CONF synchronize has not been set, force setting it now.")
+
+ # set mindspore context
+ ms.set_context(
+ mode=training_args.mode,
+ device_target=training_args.device_target,
+ jit_config={"jit_level": training_args.jit_level},
+ deterministic="ON",
+ pynative_synchronize=True,
+ memory_optimize_level="O1",
+ # jit_syntax_level=ms.STRICT
+ )
+
+ if training_args.mode == ms.PYNATIVE_MODE:
+ print("WARNING: run pynative mode, set `pynative_synchronize` True")
+
+ if training_args.max_device_memory is not None:
+ ms.set_context(max_device_memory=training_args.max_device_memory)
+
+ if training_args.precision_mode is not None:
+ ms.set_context(
+ ascend_config={"precision_mode": training_args.precision_mode},
+ )
+
+ if training_args.is_distribute:
+ init()
+ world_size = get_group_size()
+ rank_id = get_rank()
+ print(f"init_environment, rank_id: {rank_id}, world_size: {world_size}")
+
+ ms.reset_auto_parallel_context()
+
+ ms.set_auto_parallel_context(
+ parallel_mode=ms.ParallelMode.DATA_PARALLEL,
+ gradients_mean=True,
+ device_num=world_size,
+ )
+
+ training_args.rank = rank_id
+ training_args.rank_size = world_size
+ else:
+ training_args.rank = 0
+ training_args.rank_size = 1
diff --git a/mindone/transformers/mindspore_adapter/utils.py b/mindone/transformers/mindspore_adapter/utils.py
new file mode 100644
index 0000000000..aaa8999088
--- /dev/null
+++ b/mindone/transformers/mindspore_adapter/utils.py
@@ -0,0 +1,60 @@
+import numpy as np
+
+import mindspore as ms
+from mindspore import ParallelMode
+
+_DTYPE_2_STRING = {
+ ms.float16: "float16",
+ ms.bfloat16: "bfloat16",
+ ms.float32: "float32",
+ ms.float64: "float64",
+ ms.uint8: "uint8",
+ ms.int8: "int8",
+ ms.int16: "int16",
+ ms.int32: "int32",
+ ms.int64: "int64",
+ ms.bool_: "bool",
+}
+
+
+_MIN_FP16 = ms.tensor(np.finfo(np.float16).min, dtype=ms.float16)
+_MIN_FP32 = ms.tensor(np.finfo(np.float32).min, dtype=ms.float32)
+_MIN_FP64 = ms.tensor(np.finfo(np.float64).min, dtype=ms.float64)
+_MIN_BF16 = ms.tensor(float.fromhex("-0x1.fe00000000000p+127"), dtype=ms.bfloat16)
+
+
+_DTYPE_2_MIN = {
+ ms.float16: _MIN_FP16,
+ ms.float32: _MIN_FP32,
+ ms.float64: _MIN_FP64,
+ ms.bfloat16: _MIN_BF16,
+}
+
+
+def dtype_to_min(dtype):
+ if dtype in _DTYPE_2_MIN:
+ return _DTYPE_2_MIN[dtype]
+ else:
+ raise ValueError(f"Only support get minimum value of (float16, ), but got {dtype}")
+
+
+def dtype_to_str(dtype):
+ return _DTYPE_2_STRING.get(dtype, "others dtype")
+
+
+def _is_parallel():
+ return ms.context.get_auto_parallel_context("parallel_mode") not in (ParallelMode.STAND_ALONE,)
+
+
+def _is_graph():
+ return ms.context.get_context("mode") == ms.GRAPH_MODE
+
+
+def _is_ascend():
+ return ms.context.get_context("device_target") == "Ascend"
+
+
+# FIXME: Can't work on MindSpore 2.3.0
+# @ms.constexpr(reuse_result=False)
+# def _tensor_2_tuple(input):
+# return tuple(input.asnumpy().tolist())
diff --git a/mindone/transformers/mindspore_utils.py b/mindone/transformers/mindspore_utils.py
index 3a1dd6324b..a60cf11128 100644
--- a/mindone/transformers/mindspore_utils.py
+++ b/mindone/transformers/mindspore_utils.py
@@ -89,12 +89,12 @@ def prune_conv1d_layer(layer: Conv1D, index: ms.Tensor, dim: int = 1) -> Conv1D:
Used to remove heads.
Args:
- layer ([`~pytorch_utils.Conv1D`]): The layer to prune.
- index (`torch.LongTensor`): The indices to keep in the layer.
+ layer ([`~mindspore_utils.Conv1D`]): The layer to prune.
+ index (`ms.Tensor`): The indices to keep in the layer.
dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.
Returns:
- [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
+ [`~mindspore_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
"""
w = layer.weight.index_select(dim, index).clone()
if dim == 0:
@@ -115,17 +115,17 @@ def prune_conv1d_layer(layer: Conv1D, index: ms.Tensor, dim: int = 1) -> Conv1D:
def prune_layer(layer: Union[nn.Dense, Conv1D], index: ms.Tensor, dim: Optional[int] = None) -> Union[nn.Dense, Conv1D]:
"""
- Prune a Conv1D or linear layer to keep only entries in index.
+ Prune a Conv1D or Dense layer to keep only entries in index.
Used to remove heads.
Args:
- layer (`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
- index (`torch.LongTensor`): The indices to keep in the layer.
+ layer (`Union[mindspore.nn.Dense, Conv1D]`): The layer to prune.
+ index (`mindspore.Tensor`): The indices to keep in the layer.
dim (`int`, *optional*): The dimension on which to keep the indices.
Returns:
- `torch.nn.Linear` or [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
+ `mindspore.nn.Dense` or [`~mindspore_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
"""
if isinstance(layer, nn.Dense):
return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
@@ -148,7 +148,7 @@ def find_pruneable_heads_and_indices(
already_pruned_heads (`Set[int]`): A set of already pruned heads.
Returns:
- `Tuple[Set[int], torch.LongTensor]`: A tuple with the indices of heads to prune taking `already_pruned_heads`
+ `Tuple[Set[int], ms.Tensor]`: A tuple with the indices of heads to prune taking `already_pruned_heads`
into account and the indices of rows/columns to keep in the layer weight.
"""
mask = ops.ones((n_heads, head_size))
diff --git a/mindone/transformers/modeling_attn_mask_utils.py b/mindone/transformers/modeling_attn_mask_utils.py
index be7a9b2673..188a86fadc 100644
--- a/mindone/transformers/modeling_attn_mask_utils.py
+++ b/mindone/transformers/modeling_attn_mask_utils.py
@@ -263,10 +263,10 @@ def _prepare_4d_causal_attention_mask(
Args:
attention_mask (`ms.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
- input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
+ input_shape (`tuple(int)` or `list(int)`):
The input shape should be a tuple that defines `(batch_size, query_length)`.
inputs_embeds (`ms.Tensor`):
- The embedded inputs as a torch Tensor.
+ The embedded inputs as a mindspore Tensor.
past_key_values_length (`int`):
The length of the key value cache.
sliding_window (`int`, *optional*):
@@ -310,8 +310,8 @@ def _prepare_4d_attention_mask(mask: ms.Tensor, dtype: ms.Type, tgt_len: Optiona
Args:
mask (`ms.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
- dtype (`torch.dtype`):
- The torch dtype the created mask shall have.
+ dtype (`ms.dtype`):
+ The mindspore dtype the created mask shall have.
tgt_len (`int`):
The target length or query length the created mask shall have.
"""
@@ -330,8 +330,8 @@ def _create_4d_causal_attention_mask(
Args:
input_shape (`tuple(int)` or `list(int)`):
The input shape should be a tuple that defines `(batch_size, query_length)`.
- dtype (`torch.dtype`):
- The torch dtype the created mask shall have.
+ dtype (`ms.dtype`):
+ The mindspore dtype the created mask shall have.
sliding_window (`int`, *optional*):
If the model uses windowed attention, a sliding window should be passed.
"""
diff --git a/mindone/transformers/modeling_outputs.py b/mindone/transformers/modeling_outputs.py
index f62ebb5095..c278889a54 100644
--- a/mindone/transformers/modeling_outputs.py
+++ b/mindone/transformers/modeling_outputs.py
@@ -119,6 +119,45 @@ class BaseModelOutputWithPoolingAndNoAttention(ModelOutput):
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
+@dataclass
+class BaseModelOutputWithPast(ModelOutput):
+ """
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+ Args:
+ last_hidden_state (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+ encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+ input) to speed up sequential decoding.
+ hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `ms.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ last_hidden_state: ms.Tensor = None
+ past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
+ attentions: Optional[Tuple[ms.Tensor, ...]] = None
+
+
@dataclass
class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
"""
@@ -226,6 +265,42 @@ class Seq2SeqModelOutput(ModelOutput):
encoder_attentions: Optional[Tuple[ms.Tensor, ...]] = None
+@dataclass
+class CausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `ms.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ loss: Optional[ms.Tensor] = None
+ logits: ms.Tensor = None
+ past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
+ hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
+ attentions: Optional[Tuple[ms.Tensor, ...]] = None
+
+
@dataclass
class Seq2SeqLMOutput(ModelOutput):
"""
@@ -683,7 +758,7 @@ class QuestionAnsweringModelOutput(ModelOutput):
@dataclass
-class SequenceClassifierOutput(ModelOutput):
+class SequenceClassifierOutputWithPast(ModelOutput):
"""
Base class for outputs of sentence classification models.
@@ -692,74 +767,12 @@ class SequenceClassifierOutput(ModelOutput):
Classification (or regression if config.num_labels==1) loss.
logits (`ms.Tensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
- hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `ms.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
- hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
- attentions: Optional[Tuple[ms.Tensor, ...]] = None
-
-
-@dataclass
-class TokenClassifierOutput(ModelOutput):
- """
- Base class for outputs of token classification models.
-
- Args:
- loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided) :
- Classification loss.
- logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`):
- Classification scores (before SoftMax).
- hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `ms.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
-
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
-
- loss: Optional[ms.Tensor] = None
- logits: ms.Tensor = None
- hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
- attentions: Optional[Tuple[ms.Tensor, ...]] = None
-
-
-@dataclass
-class BaseModelOutputWithPast(ModelOutput):
- """
- Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
-
- Args:
- last_hidden_state (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Sequence of hidden-states at the output of the last layer of the model.
-
- If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
- hidden_size)` is output.
past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
- `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
- encoder_sequence_length, embed_size_per_head)`.
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
- Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
- `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
- input) to speed up sequential decoding.
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `ms.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
@@ -773,28 +786,23 @@ class BaseModelOutputWithPast(ModelOutput):
heads.
"""
- last_hidden_state: ms.Tensor = None
+ loss: Optional[ms.Tensor] = None
+ logits: ms.Tensor = None
past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@dataclass
-class CausalLMOutputWithPast(ModelOutput):
+class SequenceClassifierOutput(ModelOutput):
"""
- Base class for causal language model (or autoregressive) outputs.
+ Base class for outputs of sentence classification models.
Args:
loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Language modeling loss (for next-token prediction).
- logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
-
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
- `past_key_values` input) to speed up sequential decoding.
+ Classification (or regression if config.num_labels==1) loss.
+ logits (`ms.Tensor` of shape `(batch_size, config.num_labels)`):
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `ms.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
@@ -810,27 +818,20 @@ class CausalLMOutputWithPast(ModelOutput):
loss: Optional[ms.Tensor] = None
logits: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
@dataclass
-class SequenceClassifierOutputWithPast(ModelOutput):
+class TokenClassifierOutput(ModelOutput):
"""
- Base class for outputs of sentence classification models.
+ Base class for outputs of token classification models.
Args:
- loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Classification (or regression if config.num_labels==1) loss.
- logits (`ms.Tensor` of shape `(batch_size, config.num_labels)`):
- Classification (or regression if config.num_labels==1) scores (before SoftMax).
- past_key_values (`tuple(tuple(ms.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
-
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
- `past_key_values` input) to speed up sequential decoding.
+ loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided) :
+ Classification loss.
+ logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`):
+ Classification scores (before SoftMax).
hidden_states (`tuple(ms.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `ms.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
@@ -846,6 +847,5 @@ class SequenceClassifierOutputWithPast(ModelOutput):
loss: Optional[ms.Tensor] = None
logits: ms.Tensor = None
- past_key_values: Optional[Tuple[Tuple[ms.Tensor]]] = None
hidden_states: Optional[Tuple[ms.Tensor, ...]] = None
attentions: Optional[Tuple[ms.Tensor, ...]] = None
diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py
index 7d413cf204..407b436d7c 100644
--- a/mindone/transformers/modeling_utils.py
+++ b/mindone/transformers/modeling_utils.py
@@ -18,12 +18,14 @@
import json
import os
import re
+import time
import warnings
from contextlib import contextmanager, nullcontext
from typing import Callable, Dict, Optional, Tuple, Union
from transformers.configuration_utils import PretrainedConfig
from transformers.dynamic_module_utils import custom_object_save
+from transformers.generation.utils import GenerationConfig
from transformers.safetensors_conversion import auto_conversion
from transformers.utils import (
ADAPTER_SAFE_WEIGHTS_NAME,
@@ -53,8 +55,11 @@
import mindspore as ms
from mindspore import Tensor, nn, ops
+from .generation.utils import GenerationMixin
from .integrations import PeftAdapterMixin
+from .mindspore_adapter import dtype_to_str
from .modeling_attn_mask_utils import dtype_to_min
+from .utils.import_utils import is_flash_attn_2_available, is_sdpa_available
if is_safetensors_available():
from safetensors import safe_open
@@ -187,7 +192,7 @@ def shard_checkpoint(
Args:
- state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
+ state_dict (`Dict[str, ms.Tensor]`): The state dictionary of a model to save.
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
(like `"5MB"`).
@@ -323,6 +328,9 @@ class ModuleUtilsMixin:
A few utilities for `mindspore.nn.Cell`, to be used as a mixin.
"""
+ def _get_name(self):
+ return self.__class__.__name__
+
def to(self, dtype: Optional[ms.Type] = None):
for p in self.get_parameters():
p.set_dtype(dtype)
@@ -511,7 +519,7 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
return sum(total_numel)
-class MSPreTrainedModel(nn.Cell, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
+class MSPreTrainedModel(nn.Cell, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
r"""
Base class for all models.
@@ -525,7 +533,7 @@ class MSPreTrainedModel(nn.Cell, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
- **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
for this model architecture.
- - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
+ - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a MindSpore model,
taking as arguments:
- **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
@@ -571,13 +579,17 @@ class MSPreTrainedModel(nn.Cell, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# SDPA support
_supports_sdpa = False
- # Has support for a `Cache` instance as `past_key_values`
+ # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
_supports_cache_class = False
+ _supports_static_cache = False
+
+ # Has support for a `QuantoQuantizedCache` instance as `past_key_values`
+ _supports_quantized_cache = False
@property
def dummy_inputs(self) -> Dict[str, ms.Tensor]:
"""
- `Dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network.
+ `Dict[str, ms.Tensor]`: Dummy inputs to do a forward pass in the network.
"""
return {"input_ids": ms.tensor(DUMMY_INPUTS)}
@@ -600,7 +612,7 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
self.config = config
self.name_or_path = config.name_or_path
self.warnings_issued = {}
- self.generation_config = None
+ self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
# Overwrite the class attribute to make it an instance attribute, so models like
# `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
# when a different component (e.g. language_model) is used.
@@ -613,6 +625,153 @@ def post_init(self):
"""
self.init_weights()
+ @classmethod
+ def _autoset_attn_implementation(
+ cls,
+ config,
+ use_flash_attention_2: bool = False,
+ mindspore_dtype=None,
+ ):
+ """
+ Automatically checks and dispatches to a default attention implementation. In order of priority:
+ 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained).
+ 2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example)
+ 3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example)
+ 4. The default model's implementation otherwise (`LlamaAttention` for example) .
+ """
+ # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user.
+ # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
+ # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
+ requested_attn_implementation = None
+ if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None:
+ if config._attn_implementation != "flash_attention_2" and use_flash_attention_2:
+ raise ValueError(
+ f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were '
+ f"used when loading the model, which are not compatible."
+ ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
+ )
+
+ if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]:
+ message = (
+ f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. '
+ f'The only possible arguments are `attn_implementation="eager"`'
+ f" (manual attention implementation)"
+ )
+ if cls._supports_flash_attn_2:
+ message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
+ if cls._supports_sdpa:
+ message += ', `"attn_implementation=sdpa"` (implementation using scaled_dot_product_attention)'
+ raise ValueError(message + ".")
+
+ # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the
+ # user-provided config, with hard checks that the requested attention implementation is available.
+ requested_attn_implementation = config._attn_implementation_internal
+
+ if use_flash_attention_2:
+ logger.warning_once(
+ "The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a "
+ 'future release. Please use `attn_implementation="flash_attention_2"` instead.'
+ )
+ config._attn_implementation = "flash_attention_2"
+
+ if config._attn_implementation == "flash_attention_2":
+ cls._check_and_enable_flash_attn_2(
+ config,
+ mindspore_dtype=mindspore_dtype,
+ hard_check_only=False,
+ )
+ elif requested_attn_implementation in [None, "sdpa"]:
+ # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
+ config = cls._check_and_enable_sdpa(
+ config,
+ hard_check_only=False if requested_attn_implementation is None else True,
+ )
+ else:
+ config._attn_implementation = "eager"
+
+ return config
+
+ @classmethod
+ def can_generate(cls) -> bool:
+ """
+ Returns whether this model can generate sequences with `.generate()`.
+
+ Returns:
+ `bool`: Whether this model can generate sequences with `.generate()`.
+ """
+ # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
+ # Alternativelly, the model can also have a custom `generate` function.
+ if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
+ return False
+ return True
+
+ @classmethod
+ def _check_and_enable_flash_attn_2(
+ cls,
+ config,
+ mindspore_dtype=None,
+ hard_check_only: bool = False,
+ ) -> PretrainedConfig:
+ """
+ Checks the availability of Flash Attention 2 and compatibility with the current model.
+
+ If all checks pass and `hard_check_only` is False, the method will set the config attribute
+ `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
+ """
+ if not cls._supports_flash_attn_2:
+ raise ValueError(
+ f"{cls.__name__} does not support Flash Attention 2.0 yet. Please request to add support where"
+ f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
+ " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
+ )
+
+ if not is_flash_attn_2_available():
+ raise ImportError("FlashAttention2 has been toggled on, but it cannot be used due to some error")
+
+ if mindspore_dtype is None:
+ logger.warning_once(
+ "You are attempting to use Flash Attention 2.0 without specifying a MindSpore dtype. This might lead to unexpected behaviour"
+ )
+ elif mindspore_dtype is not None and mindspore_dtype not in [ms.float16, ms.bfloat16]:
+ logger.warning_once(
+ "Flash Attention 2.0 only supports ms.float16 and ms.bfloat16 dtypes, but"
+ f" the current dype in {cls.__name__} is {mindspore_dtype}. You should run training or inference using "
+ f"Automatic Mixed-Precision via the `network=auto_mix_precision(network, ...)` decorator,"
+ " or load the model with the `mindspore_dtype` argument. Example: `model = "
+ 'AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", mindspore_dtype=ms.float16)`'
+ )
+
+ if not hard_check_only:
+ config._attn_implementation = "flash_attention_2"
+ return config
+
+ @classmethod
+ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
+ """
+ Checks the availability of SDPA for a given model.
+
+ If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation`
+ to "flash_attention_2" so that the model can initialize the correct attention module.
+ """
+ if hard_check_only:
+ if not cls._supports_sdpa:
+ raise ValueError(
+ f"{cls.__name__} does not support an attention implementation through `scaled_dot_product_attention` yet."
+ " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. "
+ "If you believe this error is a bug, please open an issue in Transformers GitHub repository and "
+ 'load your model with the argument `attn_implementation="eager"` meanwhile. Example: '
+ '`model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
+ )
+ if not is_sdpa_available():
+ raise ImportError("SDPA requirements in Transformers are not met.")
+
+ if not is_sdpa_available() or not cls._supports_sdpa:
+ return config
+
+ if not hard_check_only:
+ config._attn_implementation = "sdpa"
+ return config
+
def get_input_embeddings(self) -> nn.Cell:
"""
Returns the model's input embeddings.
@@ -653,7 +812,7 @@ def _init_weights(self, module):
Initialize the weights. This method should be overridden by derived class and is
the only initialization method that will be called when loading a checkpoint
using `from_pretrained`. Any attempt to initialize outside of this function
- will be useless as the torch.nn.init function are all replaced with skip.
+ will be useless as the mindspore.common.initializer function are all replaced with skip.
"""
pass
@@ -678,7 +837,7 @@ def resize_token_embeddings(
new_num_tokens (`int`, *optional*):
The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
- returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
+ returns a pointer to the input tokens `mindspore.nn.Embedding` module of the model without doing anything.
pad_to_multiple_of (`int`, *optional*):
If set will pad the embedding matrix to a multiple of the provided value.If `new_num_tokens` is set to
`None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
@@ -736,7 +895,7 @@ def _get_resized_embeddings(
Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
- `torch.nn.Embedding` module of the model without doing anything.
+ `mindspore.nn.Embedding` module of the model without doing anything.
pad_to_multiple_of (`int`, *optional*):
If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
`None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
@@ -749,7 +908,8 @@ def _get_resized_embeddings(
if pad_to_multiple_of is not None:
if not isinstance(pad_to_multiple_of, int):
raise ValueError(
- f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer"
+ f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, "
+ f"which is not and integer. Please make sure to pass an integer"
)
if new_num_tokens is None:
new_num_tokens = old_embeddings.embedding_table.shape[0]
@@ -917,13 +1077,13 @@ def save_pretrained(
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions.
- state_dict (nested dictionary of `torch.Tensor`):
+ state_dict (nested dictionary of `ms.Tensor`):
The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
save parts of the model or if special precautions need to be taken when recovering the state dictionary
of a model (like when using model parallelism).
save_function (`Callable`):
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
- need to replace `torch.save` by another method.
+ need to replace `ms.save_checkpoint` by another method.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
@@ -990,7 +1150,7 @@ def save_pretrained(
# Only save the model itself if we are using distributed training
model_to_save = self # we don't unwrap_model(self) in mindspore
- # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
+ # save the string version of dtype to the config, e.g. convert ms.float32 => "float32"
# we currently don't use this setting automatically, but may start to use with v5
dtype = get_parameter_dtype(model_to_save)
model_to_save.config.torch_dtype = repr(dtype).split(".")[1]
@@ -1007,6 +1167,26 @@ def save_pretrained(
if is_main_process:
if not _hf_peft_config_loaded:
model_to_save.config.save_pretrained(save_directory)
+ if self.can_generate():
+ # generation config built from the model config + the model config holds generation kwargs -> generate
+ # may revert to legacy behavior if the two don't match
+ if (
+ model_to_save.generation_config._from_model_config
+ and model_to_save.config._has_non_default_generation_parameters()
+ ):
+ new_generation_config = GenerationConfig.from_model_config(model_to_save.config)
+ if new_generation_config != model_to_save.generation_config:
+ logger.warning(
+ "Your generation config was originally created from the model config, but the model "
+ "config has changed since then. Unless you pass the `generation_config` argument to this "
+ "model's `generate` calls, they will revert to the legacy behavior where the base "
+ "`generate` parameterization is loaded from the model config instead. "
+ "To avoid this behavior and this warning, we recommend you to overwrite the generation "
+ "config model attribute before calling the model's `save_pretrained`, preferably also "
+ "removing any generation kwargs from the model config. This warning will be raised to an "
+ "exception in v4.41."
+ )
+ model_to_save.generation_config.save_pretrained(save_directory)
if _hf_peft_config_loaded:
logger.info(
@@ -1016,7 +1196,8 @@ def save_pretrained(
if save_peft_format:
logger.info(
- "To match the expected format of the PEFT library, all keys of the state dict of adapters will be pre-pended with `base_model.model`."
+ "To match the expected format of the PEFT library, all keys of the state dict of adapters will "
+ "be pre-pended with `base_model.model`."
)
peft_state_dict = {}
for key, value in state_dict.items():
@@ -1027,7 +1208,8 @@ def save_pretrained(
if len(active_adapter) > 1:
raise ValueError(
- "Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one "
+ "Multiple active adapters detected, saving multiple active adapters is not supported yet. "
+ "You can save adapters separately one by one "
"by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`"
)
active_adapter = active_adapter[0]
@@ -1061,7 +1243,7 @@ def save_pretrained(
# in distributed settings to avoid race conditions.
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
- # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+ # make sure that file to be deleted matches format of sharded file, e.g. mindspore_model-00001-of-00005
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
@@ -1115,7 +1297,7 @@ def from_pretrained(
**kwargs,
):
r"""
- Instantiate a pretrained pytorch model from a pre-trained model configuration.
+ Instantiate a pretrained mindspore model from a pre-trained model configuration.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
the model, you should first set it back in training mode with `model.train()`.
@@ -1137,7 +1319,7 @@ def from_pretrained(
- A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
this case, `from_tf` should be set to `True` and a configuration object should be provided as
`config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
- PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+ MindSpore model using the provided conversion scripts and loading the MindSpore model afterwards.
- A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
`./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to
`True`.
@@ -1160,7 +1342,7 @@ def from_pretrained(
save directory.
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
configuration JSON file named *config.json* is found in the directory.
- state_dict (`Dict[str, torch.Tensor]`, *optional*):
+ state_dict (`Dict[str, ms.Tensor]`, *optional*):
A state dictionary to use instead of a state dictionary loaded from saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own
@@ -1214,9 +1396,9 @@ def from_pretrained(
Override the default `mindspore.Type` and load the model under a specific `dtype`. The different options
are:
- 1. `torch.float16` or `torch.bfloat16` or `torch.float`: load in a specified
+ 1. `ms.float16` or `ms.bfloat16` or `ms.float32`: load in a specified
`dtype`, ignoring the model's `config.mindspore_dtype` if one exists. If not specified
- - the model will get loaded in `torch.float` (fp32).
+ - the model will get loaded in `ms.float32` (fp32).
2. `"auto"` - A `mindspore_dtype` entry in the `config.json` file of the model will be
attempted to be used. If this entry isn't found then next check the `dtype` of the first weight in
@@ -1236,7 +1418,7 @@ def from_pretrained(
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
variant (`str`, *optional*):
- If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is
+ If specified load weights from `variant` filename, *e.g.* mindspore_model..bin. `variant` is
ignored when using `from_tf` or `from_flax`.
use_safetensors (`bool`, *optional*, defaults to `None`):
Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
@@ -1275,10 +1457,10 @@ def from_pretrained(
>>> # Update configuration during loading.
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
>>> assert model.config.output_attentions == True
- >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
+ >>> # Loading from a TF checkpoint file instead of a MindSpore model (slower, for example purposes, not runnable).
>>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
>>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
- >>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
+ >>> # Loading from a Flax checkpoint file instead of a MindSpore model (slower)
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)
```
@@ -1312,6 +1494,7 @@ def from_pretrained(
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
+ use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
adapter_kwargs = kwargs.pop("adapter_kwargs", {})
adapter_name = kwargs.pop("adapter_name", "default")
@@ -1672,7 +1855,7 @@ def from_pretrained(
pass
elif metadata.get("format") == "tf":
from_tf = True
- logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.")
+ logger.info("A TensorFlow safetensors file is being loaded in a MindSpore model.")
elif metadata.get("format") == "flax":
from_flax = True
logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
@@ -1729,12 +1912,21 @@ def from_pretrained(
loaded_state_dict_keys = list(state_dict.keys())
config.name_or_path = pretrained_model_name_or_path
+
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
+ config = cls._autoset_attn_implementation(
+ config, use_flash_attention_2=use_flash_attention_2, mindspore_dtype=mindspore_dtype
+ )
+
model = cls(config, *model_args, **model_kwargs)
# We cannot set default mindspore dtype. So we need to cast model weights after creating.
if mindspore_dtype is not None:
model = model.to(mindspore_dtype)
+ logger.info(
+ f"convert model:{model.__class__.__name__} parameters to mindspore_dtype {dtype_to_str(mindspore_dtype)}"
+ )
+
# make sure we use the model's config since the __init__ call might have copied it
config = model.config
@@ -1778,6 +1970,29 @@ def from_pretrained(
# Set model in evaluation mode to deactivate DropOut modules by default
model.set_train(False)
+ # If it is a model with generation capabilities, attempt to load the generation config
+ if model.can_generate() and pretrained_model_name_or_path is not None:
+ try:
+ model.generation_config = GenerationConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ _from_auto=from_auto_class,
+ _from_pipeline=from_pipeline,
+ **kwargs,
+ )
+ except OSError:
+ logger.info(
+ "Generation config file not found, using a generation config created from the model config."
+ )
+ pass
+
if output_loading_info:
if loading_info is None:
loading_info = {
@@ -1805,9 +2020,7 @@ def _load_pretrained_model(
):
# Mapping loaded_keys from pt to ms
pt2ms_mappings = _get_pt2ms_mappings(model)
- loaded_keys = [
- _get_pt2ms_mapped_kv(pt2ms_mappings, s, None, f"{model.base_model_prefix}.")[0] for s in loaded_keys
- ]
+ loaded_keys = [_get_pt2ms_mapped_kv(pt2ms_mappings, s, None, "")[0] for s in loaded_keys]
# Retrieve missing & unexpected_keys
model_state_dict = {k: v for k, v in model.parameters_and_names()}
expected_keys = list(model_state_dict.keys())
@@ -1909,6 +2122,7 @@ def _find_mismatched_keys(
if state_dict is not None:
# Whole checkpoint
state_dict = _convert_state_dict(model, state_dict, start_prefix)
+
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
@@ -1930,9 +2144,16 @@ def _find_mismatched_keys(
if len(resolved_archive_file) > 1:
resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
+
+ # loading checkpoint
+ _s_time = time.time()
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
+ print(f"====> time cost, load_state_dict: {time.time() - _s_time:.3f}s")
+ _s_time = time.time()
state_dict = _convert_state_dict(model, state_dict, start_prefix)
+ print(f"====> time cost, _convert_state_dict: {time.time() - _s_time:.3f}s")
+ _s_time = time.time()
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
@@ -1944,7 +2165,12 @@ def _find_mismatched_keys(
remove_prefix_from_model,
ignore_mismatched_sizes,
)
+ print(f"====> time cost, _find_mismatched_keys: {time.time() - _s_time:.3f}s")
+
+ _s_time = time.time()
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_sharded=True)
+ print(f"====> time cost, _load_state_dict_into_model: {time.time() - _s_time:.3f}s")
+ _s_time = time.time()
# force memory release
del state_dict
diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py
index 08379484ad..77dec5edb8 100644
--- a/mindone/transformers/models/__init__.py
+++ b/mindone/transformers/models/__init__.py
@@ -1 +1 @@
-from . import bert, bit, blip_2, clip, dpt, gemma, gemma2, t5, umt5, xlm_roberta
+from . import bert, bit, blip_2, clip, dpt, gemma, llama, gemma2, t5, umt5, xlm_roberta
diff --git a/mindone/transformers/models/bert/modeling_bert.py b/mindone/transformers/models/bert/modeling_bert.py
index 14c571c3b3..85849bda4b 100644
--- a/mindone/transformers/models/bert/modeling_bert.py
+++ b/mindone/transformers/models/bert/modeling_bert.py
@@ -292,7 +292,7 @@ def construct(
if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
logger.warning_once(
- "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+ "BertSdpaSelfAttention is used but `scaled_dot_product_attention` does not support "
"non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
"the manual attention implementation, but specifying the manual implementation will be required from "
"Transformers version v5.0.0 onwards. This warning can be removed using the argument "
@@ -339,8 +339,6 @@ def construct(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
# a causal mask in case tgt_len == 1.
# is_causal = (
@@ -1067,12 +1065,13 @@ def construct(
```python
>>> from transformers import AutoTokenizer, BertForPreTraining
- >>> import torch
+ >>> import mindspore as ms
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = BertForPreTraining.from_pretrained("google-bert/bert-base-uncased")
- >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> inputs = tokenizer("Hello, my dog is cute")
+ >>> inputs = {k:ms.Tensor(v) for k, v in inputs.items()}
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.prediction_logits
@@ -1399,14 +1398,14 @@ def construct(
```python
>>> from transformers import AutoTokenizer, BertForNextSentencePrediction
- >>> import torch
+ >>> import mindspore as ms
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = BertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
- >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
+ >>> encoding = tokenizer(prompt, next_sentence)
>>> outputs = model(**encoding, labels=ms.Tensor([1]))
>>> logits = outputs.logits
@@ -1517,24 +1516,27 @@ def construct(
if labels is not None:
if self.problem_type is None:
if self.num_labels == 1:
- self.problem_type = "regression"
+ problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == ms.int64 or labels.dtype == ms.int32):
- self.problem_type = "single_label_classification"
+ problem_type = "single_label_classification"
else:
- self.problem_type = "multi_label_classification"
+ problem_type = "multi_label_classification"
+ else:
+ problem_type = self.problem_type
- if self.problem_type == "regression":
+ if problem_type == "regression":
loss_fct = nn.MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
- elif self.problem_type == "single_label_classification":
+ elif problem_type == "single_label_classification":
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1).int())
- elif self.problem_type == "multi_label_classification":
+ elif problem_type == "multi_label_classification":
loss_fct = nn.BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
+
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
diff --git a/mindone/transformers/models/llama/__init__.py b/mindone/transformers/models/llama/__init__.py
new file mode 100644
index 0000000000..858ef6f015
--- /dev/null
+++ b/mindone/transformers/models/llama/__init__.py
@@ -0,0 +1 @@
+from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
diff --git a/mindone/transformers/models/llama/modeling_llama.py b/mindone/transformers/models/llama/modeling_llama.py
new file mode 100644
index 0000000000..c2aa3daf97
--- /dev/null
+++ b/mindone/transformers/models/llama/modeling_llama.py
@@ -0,0 +1,1235 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Tuple, Union
+
+import numpy as np
+from transformers import LlamaConfig
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
+
+import mindspore as ms
+from mindspore import Parameter, Tensor, nn, ops
+from mindspore.common import initializer as init
+
+from ...activations import ACT2FN
+from ...cache_utils import get_max_length, get_seq_length, update
+from ...mindspore_adapter import recompute_except_output
+from ...mindspore_adapter.attention import FlashAttention2
+from ...mindspore_utils import ALL_LAYERNORM_LAYERS
+from ...modeling_attn_mask_utils import _MIN_FP16
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_utils import MSPreTrainedModel as PreTrainedModel
+
+logger = logging.get_logger(__name__)
+
+
+_CONFIG_FOR_DOC = "LlamaConfig"
+
+
+class LlamaRMSNorm(nn.Cell):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = Parameter(Tensor(np.ones(hidden_size), ms.float32), name="weight")
+ self.variance_epsilon = eps
+
+ def construct(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(ms.float32)
+ variance = hidden_states.pow(2).mean(-1, keep_dims=True)
+ hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon)
+ out = self.weight * hidden_states.to(input_dtype)
+ return out
+
+
+ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
+
+
+class LlamaRotaryEmbedding(nn.Cell):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0):
+ super().__init__()
+ self.scaling_factor = scaling_factor
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (np.arange(0, self.dim, 2).astype(np.float32) / self.dim))
+ self.inv_freq = Parameter(Tensor(inv_freq, ms.float32), requires_grad=False, name="inv_freq_buffer")
+ # For BC we register cos and sin cached
+ self.max_seq_len_cached = max_position_embeddings
+
+ # with no grad
+ def construct(self, x, position_ids):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ inv_freq_expanded = self.inv_freq[None, :, None].to(ms.float32).broadcast_to((position_ids.shape[0], -1, 1))
+ position_ids_expanded = position_ids[:, None, :].to(ms.float32)
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ freqs = ops.matmul(inv_freq_expanded, position_ids_expanded).swapdims(1, 2)
+ emb = ops.cat((freqs, freqs), axis=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ cos, sin = cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+ cos, sin = ops.stop_gradient(cos), ops.stop_gradient(sin)
+ return cos, sin
+
+
+class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def construct(self, x, position_ids):
+ # difference to the original RoPE: a scaling factor is aplied to the position ids
+ position_ids = position_ids.to(ms.float32) / self.scaling_factor
+ cos, sin = super().construct(x, position_ids)
+ return cos, sin
+
+
+class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+ def construct(self, x, position_ids):
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+ seq_len = ops.max(position_ids)[0] + 1
+ if seq_len > self.max_position_embeddings:
+ base = self.base * (
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+ ) ** (self.dim / (self.dim - 2))
+ inv_freq = 1.0 / (base ** (ops.arange(0, self.dim, 2, dtype=ms.float32) / self.dim))
+ x = ops.depend(x, ops.assign(self.inv_freq, inv_freq))
+
+ cos, sin = super().construct(x, position_ids)
+ return cos, sin
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return ops.cat((-x2, x1), axis=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`ms.Tensor`): The query tensor.
+ k (`ms.Tensor`): The key tensor.
+ cos (`ms.Tensor`): The cosine part of the rotary embedding.
+ sin (`ms.Tensor`): The sine part of the rotary embedding.
+ position_ids (`ms.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(ms.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class LlamaMLP(nn.Cell):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Dense(self.hidden_size, self.intermediate_size, has_bias=config.mlp_bias)
+ self.up_proj = nn.Dense(self.hidden_size, self.intermediate_size, has_bias=config.mlp_bias)
+ self.down_proj = nn.Dense(self.intermediate_size, self.hidden_size, has_bias=config.mlp_bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ # setting config var to self attribute
+ _name_list = [
+ "pretraining_tp",
+ ]
+ for name in _name_list:
+ setattr(self, name, getattr(config, name))
+
+ def construct(self, x):
+ if self.pretraining_tp > 1:
+ slice = self.intermediate_size // self.pretraining_tp
+ gate_proj_slices = self.gate_proj.weight.split(slice, axis=0)
+ up_proj_slices = self.up_proj.weight.split(slice, axis=0)
+ down_proj_slices = self.down_proj.weight.split(slice, axis=1)
+
+ gate_proj = ops.cat([ops.dense(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], axis=-1)
+ up_proj = ops.cat([ops.dense(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], axis=-1)
+
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, axis=2)
+ down_proj = [ops.dense(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]
+ down_proj = sum(down_proj)
+ else:
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+ return down_proj
+
+
+def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor:
+ """
+ This is the equivalent of ops.repeat_interleave(x, axis=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].broadcast_to((batch, num_key_value_heads, n_rep, slen, head_dim))
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class LlamaAttention(nn.Cell):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ self.q_proj = nn.Dense(self.hidden_size, self.num_heads * self.head_dim, has_bias=config.attention_bias)
+ self.k_proj = nn.Dense(
+ self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=config.attention_bias
+ )
+ self.v_proj = nn.Dense(
+ self.hidden_size, self.num_key_value_heads * self.head_dim, has_bias=config.attention_bias
+ )
+ self.o_proj = nn.Dense(self.hidden_size, self.hidden_size, has_bias=config.attention_bias)
+ self._init_rope()
+
+ _name_list = [
+ "pretraining_tp",
+ ]
+ for name in _name_list:
+ setattr(self, name, getattr(config, name))
+
+ def _init_rope(self):
+ if self.config.rope_scaling is None:
+ self.rotary_emb = LlamaRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+ else:
+ scaling_type = self.config.rope_scaling["type"]
+ scaling_factor = self.config.rope_scaling["factor"]
+ if scaling_type == "linear":
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ elif scaling_type == "dynamic":
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ else:
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+ def construct(
+ self,
+ hidden_states: ms.Tensor,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_value: Optional[ms.Tensor] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[ms.Tensor] = None,
+ **kwargs,
+ ):
+ bsz, q_len, _ = hidden_states.shape
+
+ if self.pretraining_tp > 1:
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
+ query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, axis=0)
+ key_slices = self.k_proj.weight.split(key_value_slicing, axis=0)
+ value_slices = self.v_proj.weight.split(key_value_slicing, axis=0)
+
+ query_states = [ops.dense(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
+ query_states = ops.cat(query_states, axis=-1)
+
+ key_states = [ops.dense(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
+ key_states = ops.cat(key_states, axis=-1)
+
+ value_states = [ops.dense(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
+ value_states = ops.cat(value_states, axis=-1)
+
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).swapdims(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).swapdims(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).swapdims(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ key_states, value_states = update(past_key_value, key_states, value_states, cache_position)
+ past_key_value = (key_states, value_states)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = ops.matmul(query_states, key_states.swapdims(2, 3)) / (self.head_dim**0.5)
+
+ attn_weights = ops.cast(attn_weights, ms.float32)
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + ops.cast(causal_mask, attn_weights.dtype)
+
+ # upcast attention to fp32
+ attn_weights = ops.softmax(attn_weights, axis=-1, dtype=ms.float32).to(query_states.dtype)
+ attn_weights = ops.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = ops.matmul(attn_weights, value_states)
+
+ # assert attn_output.shape == (bsz, self.num_heads, q_len, self.head_dim)
+
+ attn_output = attn_output.swapdims(1, 2)
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ if self.pretraining_tp > 1:
+ attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, axis=2)
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, axis=1)
+ attn_output = sum([ops.dense(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
+ else:
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class LlamaFlashAttention2(LlamaAttention):
+ """
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+ super().__init__(config, layer_idx)
+
+ self.flash_attention = FlashAttention2(
+ self.head_dim, self.num_heads, self.attention_dropout, input_layout="BNSD", dtype=ms.float16
+ )
+
+ def convert_mask_to_fa_format(self, attention_mask):
+ if attention_mask is not None:
+ if attention_mask.dtype == ms.bool_:
+ # flip mask, since ms FA treats 1 as discard, 0 as retain.
+ attention_mask = 1 - attention_mask
+ attention_mask = attention_mask.to(ms.uint8)
+ else:
+ attention_mask = attention_mask.to(ms.float16)
+ attention_mask = ops.select(
+ ops.equal(attention_mask, _MIN_FP16),
+ ops.ones((), ms.uint8),
+ ops.zeros((), ms.uint8),
+ )
+
+ return attention_mask
+
+ def construct(
+ self,
+ hidden_states: ms.Tensor,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_value: Optional[Tuple[ms.Tensor, ms.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[ms.Tensor] = None,
+ **kwargs,
+ ):
+ # assert output_attentions == False
+
+ bsz, q_len, _ = hidden_states.shape
+
+ if self.pretraining_tp > 1:
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
+ query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, axis=0)
+ key_slices = self.k_proj.weight.split(key_value_slicing, axis=0)
+ value_slices = self.v_proj.weight.split(key_value_slicing, axis=0)
+
+ query_states = [ops.dense(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
+ query_states = ops.cat(query_states, axis=-1)
+
+ key_states = [ops.dense(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
+ key_states = ops.cat(key_states, axis=-1)
+
+ value_states = [ops.dense(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
+ value_states = ops.cat(value_states, axis=-1)
+
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).swapdims(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).swapdims(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).swapdims(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ key_states, value_states = update(past_key_value, key_states, value_states, cache_position)
+ past_key_value = (key_states, value_states)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # 1. flash attention
+ if attention_mask is not None: # no matter the length, we just slice it
+ attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attention_mask = self.convert_mask_to_fa_format(attention_mask)
+ attn_output = self.flash_attention(query_states, key_states, value_states, attention_mask)
+ # assert attn_output.shape == (bsz, self.num_heads, q_len, self.head_dim)
+
+ # 2. vanilla attention
+ # attn_weights = ops.matmul(query_states, key_states.swapdims(2, 3)) / (self.head_dim ** 0.5)
+ #
+ # if attention_mask is not None: # no matter the length, we just slice it
+ # causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ # attn_weights = attn_weights + causal_mask
+ #
+ # # upcast attention to fp32
+ # attn_weights = ops.softmax(attn_weights, axis=-1, dtype=ms.float32).to(query_states.dtype)
+ # attn_weights = ops.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ # attn_output = ops.matmul(attn_weights, value_states)
+ # # assert attn_output.shape == (bsz, self.num_heads, q_len, self.head_dim)
+
+ attn_output = attn_output.swapdims(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ if self.pretraining_tp > 1:
+ attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, axis=2)
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, axis=1)
+ attn_output = sum([ops.dense(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
+ else:
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+LLAMA_ATTENTION_CLASSES = {
+ "eager": LlamaAttention,
+ "flash_attention_2": LlamaFlashAttention2,
+ # "sdpa": None, # not support sdpa
+}
+
+
+class LlamaDecoderLayer(nn.Cell):
+ def __init__(self, config: LlamaConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+ self.mlp = LlamaMLP(config)
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.output_identity = nn.Identity()
+
+ def construct(
+ self,
+ hidden_states: ms.Tensor,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_value: Optional[Tuple[Tuple[ms.Tensor, ms.Tensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[ms.Tensor] = None,
+ **kwargs,
+ ) -> Tuple[ms.Tensor, Optional[Tuple[ms.Tensor, ms.Tensor]]]:
+ """
+ Args:
+ hidden_states (`ms.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`ms.Tensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(ms.Tensor)`, *optional*): cached past key and value projection states
+ cache_position (`ms.Tensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ hidden_states = self.output_identity(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a MindSpore [mindspore.nn.Cell](https://www.mindspore.cn/docs/zh-CN/r2.3.1/api_python/nn/mind
+ spore.nn.Cell.html?highlight=cell#mindspore.nn.Cell) subclass.
+ Use it as a regular MindSpore Cell and refer to the MindSpore documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LlamaConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaPreTrainedModel(PreTrainedModel):
+ config_class = LlamaConfig
+ base_model_prefix = "model"
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = False
+ _supports_cache_class = False
+ _supports_quantized_cache = False
+ _supports_static_cache = False
+
+ def _init_weights(self, cell):
+ std = self.config.initializer_range
+ if isinstance(cell, nn.Dense):
+ cell.weight.set_data(
+ init.initializer(init.Normal(mean=0.0, sigma=std), cell.weight.shape, cell.weight.dtype)
+ )
+ if cell.bias is not None:
+ cell.bias.set_data(init.initializer(init.Zero(), cell.bias.shape, cell.bias.dtype))
+ elif isinstance(cell, nn.Embedding):
+ cell.embedding_table.set_data(
+ init.initializer(
+ init.Normal(mean=0.0, sigma=std), cell.embedding_table.shape, cell.embedding_table.dtype
+ )
+ )
+ if cell.padding_idx is not None:
+ cell.embedding_table.data[cell.padding_idx] = 0.0
+
+
+LLAMA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(ms.Tensor, ms.Tensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ Only one formats are allowed:
+ - Tuple of `tuple(ms.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+ cache format.
+
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Ignore `return_dict`.
+ cache_position (`ms.Tensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class LlamaModel(LlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: LlamaConfig
+ """
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
+ self.layers = nn.CellList(
+ [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ # set congig var to self attribute
+ _name_list = [
+ "output_attentions",
+ "output_hidden_states",
+ "use_return_dict",
+ "use_cache",
+ "_attn_implementation",
+ "pretraining_tp",
+ "vocab_size",
+ ]
+ for name in _name_list:
+ setattr(self, name, getattr(config, name))
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value: nn.Embedding):
+ if not isinstance(value, nn.Embedding):
+ raise NotImplementedError
+ ori_name = value.embedding_table.name
+
+ self.embed_tokens = value
+
+ self.embed_tokens.embedding_table.name = ori_name
+
+ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+ if gradient_checkpointing_kwargs is None:
+ # gradient_checkpointing_kwargs = {"mp_comm_recompute": True, "parallel_optimizer_comm_recompute": True}
+ gradient_checkpointing_kwargs = {}
+
+ # llama layers
+ for decoder_layer in self.layers:
+ assert isinstance(decoder_layer, LlamaDecoderLayer)
+ for name, cell in decoder_layer.name_cells().items():
+ if "output_identity" in name:
+ assert isinstance(cell, nn.Identity)
+ pass
+ else:
+ # cell._recompute()
+ recompute_except_output(cell, **gradient_checkpointing_kwargs)
+ recompute_except_output(self.embed_tokens, **gradient_checkpointing_kwargs)
+ recompute_except_output(self.norm, **gradient_checkpointing_kwargs)
+
+ logger.info(f"{self.__class__.__name__}: enable recompute.")
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def construct(
+ self,
+ input_ids: ms.Tensor = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[ms.Tensor] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = False,
+ cache_position: Optional[ms.Tensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+ use_cache = use_cache if use_cache is not None else self.use_cache
+ return_dict = return_dict if return_dict is not None else self.use_return_dict
+
+ if self.training:
+ use_cache = False
+
+ # assert ((input_ids is None) and (inputs_embeds is not None)) or \
+ # ((input_ids is not None) and (inputs_embeds is None))
+ # # assert (input_ids is None) ^ (inputs_embeds is None)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = get_seq_length(past_key_values) if past_key_values is not None else 0
+ cache_position = ops.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_caches = () if use_cache else None
+
+ for layer_idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values[layer_idx] if past_key_values is not None else None,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_caches += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_caches, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_caches,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: ms.Tensor,
+ input_tensor: ms.Tensor,
+ cache_position: ms.Tensor,
+ past_key_values: Tuple[Tuple[ms.Tensor, ms.Tensor]],
+ output_attentions: bool = False,
+ ):
+ # if self._attn_implementation == "flash_attention_2":
+ # return attention_mask
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = get_seq_length(past_key_values) if past_key_values is not None else 0
+
+ sequence_length = input_tensor.shape[1]
+
+ if past_key_values is not None:
+ target_length = get_max_length(past_key_values)
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, ms.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ if attention_mask is not None and len(attention_mask.shape) == 4:
+ # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+ # if attention_mask.max() != 0:
+ # raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+ causal_mask = attention_mask
+ else:
+ causal_mask = ops.broadcast_to(_MIN_FP16, (sequence_length, target_length))
+ if sequence_length != 1:
+ causal_mask = ops.triu(causal_mask, diagonal=1)
+ _mask_position = ops.arange(target_length) > cache_position.reshape(-1, 1)
+ causal_mask *= _mask_position
+ causal_mask = causal_mask[None, None, :, :].broadcast_to((input_tensor.shape[0], 1, -1, -1))
+ if attention_mask is not None:
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, _MIN_FP16
+ )
+
+ return causal_mask
+
+
+class LlamaForCausalLM(LlamaPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = LlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Dense(config.hidden_size, config.vocab_size, has_bias=False)
+ self.cross_entropy_loss = nn.CrossEntropyLoss()
+
+ _name_list = ["output_attentions", "output_hidden_states", "use_return_dict", "pretraining_tp", "vocab_size"]
+ for name in _name_list:
+ setattr(self, name, getattr(config, name))
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+ self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def construct(
+ self,
+ input_ids: ms.Tensor = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[ms.Tensor, ms.Tensor]]] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ labels: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = False,
+ cache_position: Optional[ms.Tensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.output_attentions
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
+ return_dict = return_dict if return_dict is not None else self.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ hidden_states = outputs[0]
+
+ if self.pretraining_tp > 1:
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, axis=0)
+ logits = [ops.dense(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
+ logits = ops.cat(logits, axis=-1)
+ else:
+ logits = self.lm_head(hidden_states)
+ logits = logits.to(ms.float32)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :]
+ shift_labels = labels[..., 1:]
+ # Flatten the tokens
+ shift_logits = shift_logits.view(-1, self.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ loss = self.cross_entropy_loss(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ use_cache=False,
+ **kwargs,
+ ):
+ past_length = 0
+ if past_key_values is not None:
+ # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
+ past_length = cache_position[0] if cache_position is not None else get_seq_length(past_key_values)
+ max_cache_length = get_max_length(past_key_values) if get_max_length(past_key_values) is not None else None
+ cache_length = past_length if max_cache_length is None else ops.minimum(max_cache_length, past_length)
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
+
+ if attention_mask is not None and int(attention_mask.sum(-1).max()) > input_ids.shape[1]:
+ input_ids = input_ids[:, -(int(attention_mask.sum(-1).max()) - int(past_length)) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, int(past_length) :]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+ if (
+ max_cache_length is not None
+ and attention_mask is not None
+ and cache_length + input_ids.shape[1] > max_cache_length
+ ):
+ attention_mask = attention_mask[:, -max_cache_length:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.to(ms.int32).cumsum(-1) - 1
+ position_ids = position_ids.masked_fill(attention_mask == 0, 1)
+ if past_key_values and past_length > 0:
+ cur_len = attention_mask.sum(-1).max()
+ position_ids = position_ids[:, cur_len - input_ids.shape[1] : cur_len]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_length == 0:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ # TODO: use `next_tokens` directly instead.
+ if not isinstance(input_ids, Tensor):
+ input_ids = Tensor(input_ids, dtype=ms.int32)
+
+ # Padding to max_len when no cache
+ if past_key_values is None:
+ pad_len = max(0, attention_mask.shape[1] - input_ids.shape[1])
+ input_ids = ops.pad(input_ids, (0, pad_len), value=0)
+
+ model_inputs = {"input_ids": input_ids}
+
+ input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+
+ if cache_position is None:
+ cache_position = ops.arange(past_length, past_length + input_length)
+ elif use_cache:
+ if input_length < cache_position.shape[0]:
+ assert cache_position.shape[0] == attention_mask.shape[-1]
+ cur_len = int(attention_mask.sum(-1).max())
+ cache_position = cache_position[cur_len - input_length : cur_len]
+ else:
+ cache_position = cache_position[-input_length:]
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "cache_position": cache_position,
+ "past_key_values": ms.mutable(past_key_values) if past_key_values is not None else None,
+ "use_cache": use_cache,
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ raise NotImplementedError
+
+
+@add_start_docstrings(
+ """
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
+
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ LLAMA_START_DOCSTRING,
+)
+class LlamaForSequenceClassification(LlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = LlamaModel(config)
+ self.score = nn.Dense(config.hidden_size, self.num_labels, has_bias=False)
+
+ self.loss_fct_regression = nn.MSELoss()
+ self.loss_fct_single_label_classification = nn.CrossEntropyLoss()
+ self.loss_fct_multi_label_classification = nn.BCEWithLogitsLoss()
+
+ problem_type_map = {
+ "regression": 0,
+ "single_label_classification": 1,
+ "multi_label_classification": 2,
+ None: None,
+ }
+ self.problem_type = problem_type_map[config.problem_type]
+ self.pad_token_id = config.pad_token_id
+
+ _name_list = ["output_attentions", "output_hidden_states", "use_return_dict", "pretraining_tp", "vocab_size"]
+ for name in _name_list:
+ setattr(self, name, getattr(config, name))
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def construct(
+ self,
+ input_ids: ms.Tensor = None,
+ attention_mask: Optional[ms.Tensor] = None,
+ position_ids: Optional[ms.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[ms.Tensor, ms.Tensor]]] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ labels: Optional[ms.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = False,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`ms.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.use_return_dict
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ # if self.pad_token_id is None and batch_size != 1:
+ # raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+ sequence_lengths = ops.equal(input_ids, self.pad_token_id).to(ms.int32).argmax(-1) - 1
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[ops.arange(batch_size), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ if self.problem_type is None:
+ if self.num_labels == 1:
+ problem_type = 0 # "regression"
+ elif self.num_labels > 1 and (labels.dtype in (ms.int32, ms.int64)):
+ problem_type = 1 # "single_label_classification"
+ else:
+ problem_type = 2 # "multi_label_classification"
+ else:
+ problem_type = self.problem_type
+
+ if problem_type == 0: # "regression"
+ if self.num_labels == 1:
+ loss = self.loss_fct_regression(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = self.loss_fct_regression(pooled_logits, labels)
+ elif problem_type == 1: # "single_label_classification"
+ loss = self.loss_fct_single_label_classification(
+ pooled_logits.view(-1, self.num_labels), labels.view(-1).int()
+ )
+ elif problem_type == 2: # "multi_label_classification"
+ loss = self.loss_fct_multi_label_classification(pooled_logits, labels)
+
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+ self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
diff --git a/mindone/transformers/optimization.py b/mindone/transformers/optimization.py
new file mode 100644
index 0000000000..eafb64242d
--- /dev/null
+++ b/mindone/transformers/optimization.py
@@ -0,0 +1,192 @@
+import math
+from functools import partial
+from typing import Optional, Union
+
+from transformers.trainer_utils import SchedulerType
+
+
+def get_constant_schedule(base_lr):
+ """
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
+ """
+
+ return base_lr
+
+
+def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1.0, num_warmup_steps))
+ return 1.0
+
+
+def get_constant_schedule_with_warmup(base_lr: float, num_warmup_steps: int, num_training_steps: int):
+ """
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+ increases linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ base_lr (`float`):
+ The base learning rate for scheduler.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+
+ Return:
+ `List` with the appropriate schedule.
+ """
+
+ lr_lambda = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps)
+ return [base_lr * lr_lambda(cur_step) for cur_step in range(num_training_steps)]
+
+
+def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
+
+
+def get_linear_schedule_with_warmup(base_lr: float, num_warmup_steps: int, num_training_steps: int):
+ """
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+
+ Args:
+ base_lr (`float`):
+ The base learning rate for scheduler.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+
+ Return:
+ `List` with the appropriate schedule.
+ """
+
+ lr_lambda = partial(
+ _get_linear_schedule_with_warmup_lr_lambda,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ )
+
+ return [base_lr * lr_lambda(cur_step) for cur_step in range(num_training_steps)]
+
+
+def _get_cosine_schedule_with_warmup_lr_lambda(
+ current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
+):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+
+def get_cosine_schedule_with_warmup(
+ base_lr: float, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
+ initial lr set in the optimizer.
+
+ Args:
+ base_lr (`float`):
+ The base learning rate for scheduler.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`float`, *optional*, defaults to 0.5):
+ The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
+ following a half-cosine).
+
+ Return:
+ `List` with the appropriate schedule.
+ """
+
+ lr_lambda = partial(
+ _get_cosine_schedule_with_warmup_lr_lambda,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ num_cycles=num_cycles,
+ )
+ return [base_lr * lr_lambda(cur_step) for cur_step in range(num_training_steps)]
+
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+ SchedulerType.LINEAR: get_linear_schedule_with_warmup,
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+ SchedulerType.COSINE_WITH_RESTARTS: None,
+ SchedulerType.POLYNOMIAL: None,
+ SchedulerType.CONSTANT: get_constant_schedule,
+ SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+ SchedulerType.INVERSE_SQRT: None,
+ SchedulerType.REDUCE_ON_PLATEAU: None,
+ SchedulerType.COSINE_WITH_MIN_LR: None,
+ SchedulerType.WARMUP_STABLE_DECAY: None,
+}
+
+
+def get_scheduler(
+ name: Union[str, SchedulerType],
+ base_lr: Optional[float],
+ num_warmup_steps: Optional[int] = None,
+ num_training_steps: Optional[int] = None,
+ scheduler_specific_kwargs: Optional[dict] = None,
+):
+ """
+ Unified API to get any scheduler from its name.
+
+ Args:
+ name (`str` or `SchedulerType`):
+ The name of the scheduler to use.
+ optimizer (`mindspore.nn.Optimizer`):
+ The optimizer that will be used during training.
+ num_warmup_steps (`int`, *optional*):
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ num_training_steps (`int``, *optional*):
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ scheduler_specific_kwargs (`dict`, *optional*):
+ Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler
+ parameters will cause the scheduler function to raise a TypeError.
+ """
+ name = SchedulerType(name)
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+
+ # Note: Not support `LayerWiseDummyOptimizer` now.
+ # If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and
+ # recursively call `get_scheduler` to get the proper schedulers on each parameter
+ # if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
+ # raise NotImplementedError
+
+ if name == SchedulerType.CONSTANT:
+ return schedule_func(base_lr=base_lr)
+
+ if scheduler_specific_kwargs is None:
+ scheduler_specific_kwargs = {}
+
+ if name == SchedulerType.REDUCE_ON_PLATEAU:
+ raise NotImplementedError
+
+ # All other schedulers require `num_warmup_steps`
+ if num_warmup_steps is None:
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+ if name == SchedulerType.INVERSE_SQRT:
+ raise NotImplementedError
+
+ if name == SchedulerType.WARMUP_STABLE_DECAY:
+ raise NotImplementedError
+
+ # All other schedulers require `num_training_steps`
+ if num_training_steps is None:
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+ return schedule_func(
+ base_lr=base_lr,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ **scheduler_specific_kwargs,
+ )
diff --git a/mindone/transformers/requirements.txt b/mindone/transformers/requirements.txt
new file mode 100644
index 0000000000..6ced3b3ab5
--- /dev/null
+++ b/mindone/transformers/requirements.txt
@@ -0,0 +1,6 @@
+transformers==4.42.4
+evaluate
+datasets
+safetensors
+ezcolorlog
+ipython==8.12.3
diff --git a/mindone/transformers/trainer.py b/mindone/transformers/trainer.py
new file mode 100644
index 0000000000..d0b89f847e
--- /dev/null
+++ b/mindone/transformers/trainer.py
@@ -0,0 +1,1675 @@
+import functools
+import inspect
+import math
+import os
+import re
+import shutil
+import sys
+import time
+import warnings
+from collections.abc import Mapping
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
+
+import numpy as np
+from ezcolorlog import root_logger as logger
+from packaging import version
+from transformers import PreTrainedTokenizerBase
+from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
+from transformers.integrations import get_reporting_integration_callbacks
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+from transformers.trainer_callback import (
+ CallbackHandler,
+ DefaultFlowCallback,
+ ExportableState,
+ PrinterCallback,
+ ProgressCallback,
+ TrainerCallback,
+ TrainerControl,
+ TrainerState,
+)
+from transformers.trainer_utils import (
+ EvalPrediction,
+ RemoveColumnsCollator,
+ get_last_checkpoint,
+ has_length,
+ number_of_arguments,
+ speed_metrics,
+)
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_datasets_available, logging
+
+import mindspore as ms
+from mindspore import Tensor, nn, ops
+from mindspore.communication.management import get_group_size
+
+from ..safetensors.mindspore import save_file
+from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
+from .debug_utils import DebugOption
+from .mindspore_adapter import RandomSampler, Sampler, TrainOneStepWrapper, auto_mixed_precision
+from .mindspore_adapter.utils import _is_parallel
+from .mindspore_utils import ALL_LAYERNORM_LAYERS
+from .modeling_utils import MSPreTrainedModel as PreTrainedModel
+from .optimization import get_scheduler
+from .trainer_ms_utils import LabelSmoother, LengthGroupedSampler, get_model_param_count, get_parameter_names
+from .trainer_utils import enable_full_determinism, set_seed
+from .training_args import OptimizerNames, TrainingArguments
+from .utils import can_return_loss, find_labels
+
+if TYPE_CHECKING:
+ import optuna
+
+ if is_datasets_available():
+ import datasets
+
+
+DEFAULT_CALLBACKS = [DefaultFlowCallback]
+DEFAULT_PROGRESS_CALLBACK = ProgressCallback
+
+
+def _is_peft_model(model):
+ # TODO: support PEFT Model
+ return False
+
+
+class TrainOutput(NamedTuple):
+ global_step: int
+ training_loss: float
+ metrics: Dict[str, float]
+
+
+PREFIX_CHECKPOINT_DIR = "checkpoint"
+_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
+
+
+# Name of the files used for checkpointing
+TRAINER_STATE_NAME = "trainer_state.json"
+OPTIMIZER_NAME = "optimizer.ckpt"
+# SCHEDULER_NAME = "scheduler.ckpt" # Note: lr_scheduler is already included in the optimizer on MindSpore 2.3.1
+SCALER_NAME = "scaler.ckpt"
+
+
+class Trainer:
+ from .trainer_ms_utils import save_state
+
+ def __init__(
+ self,
+ model: Union[PreTrainedModel, nn.Cell] = None,
+ args: TrainingArguments = None,
+ data_collator: Optional[DataCollator] = None,
+ train_dataset: Optional[Iterable] = None,
+ eval_dataset: Optional[Iterable] = None,
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
+ compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+ callbacks: Optional[List[TrainerCallback]] = None,
+ optimizers: Tuple[nn.Optimizer, nn.learning_rate_schedule.LearningRateSchedule] = (None, None),
+ preprocess_logits_for_metrics: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
+ ):
+ if args is None:
+ output_dir = "tmp_trainer"
+ logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
+ args = TrainingArguments(output_dir=output_dir)
+ if args.batch_eval_metrics and compute_metrics is not None:
+ if "compute_result" not in inspect.signature(compute_metrics).parameters.keys():
+ raise ValueError(
+ "When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`"
+ " boolean argument which will be triggered after the last batch of the eval set to signal that the"
+ " summary statistics should be returned by the function."
+ )
+ self.args = args
+ # Seed must be set before instantiating the model when using model
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
+ self.hp_name = None
+ self.deepspeed = None
+ self.is_in_train = False
+
+ # self.create_accelerator_and_postprocess()
+
+ # memory metrics - must set up as early as possible
+ # self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
+ # self._memory_tracker.start()
+
+ # set the correct log level depending on the node
+ log_level = args.get_process_log_level()
+ logging.set_verbosity(log_level)
+
+ if model is None:
+ if model_init is not None:
+ self.model_init = model_init
+ model = self.call_model_init()
+ else:
+ raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
+ else:
+ if model_init is not None:
+ warnings.warn(
+ "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
+ " overwrite your model when calling the `train` method. This will become a fatal error in the next"
+ " release.",
+ FutureWarning,
+ )
+ self.model_init = model_init
+
+ if getattr(model, "is_parallelizable", False) and getattr(model, "model_parallel", False):
+ self.is_model_parallel = True
+ else:
+ self.is_model_parallel = False
+
+ _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
+ model, "_hf_peft_config_loaded", False
+ )
+ _quantization_method_supports_training = (
+ getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
+ )
+ if _is_quantized_and_base_model or _quantization_method_supports_training:
+ raise NotImplementedError
+
+ # Filter out quantized + compiled models
+ if _is_quantized_and_base_model and hasattr(model, "_orig_mod"):
+ raise ValueError(
+ "You cannot fine-tune quantized model with `ms.jit()` or `ms.GRAPH_MODE` make sure to pass a "
+ "non-compiled model when fine-tuning a quantized model with PEFT"
+ )
+
+ # At this stage the model is already loaded
+ if _is_quantized_and_base_model and not _is_peft_model(model):
+ raise ValueError(
+ "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
+ " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
+ " for more details"
+ )
+ elif _is_quantized_and_base_model and not _quantization_method_supports_training:
+ raise ValueError(
+ f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}"
+ " but that quantization method do not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers"
+ f" to request the support for training support for {model.hf_quantizer.quantization_config.quant_method}"
+ )
+
+ default_collator = (
+ DataCollatorWithPadding(tokenizer)
+ if tokenizer is not None and isinstance(tokenizer, (PreTrainedTokenizerBase, SequenceFeatureExtractor))
+ else lambda features, batch_info: default_data_collator(features, return_tensors="np")
+ )
+ self.data_collator = data_collator if data_collator is not None else default_collator
+ self.train_dataset = train_dataset
+ self.eval_dataset = eval_dataset
+ self.tokenizer = tokenizer
+
+ self.model = model
+
+ self.neftune_noise_alpha = args.neftune_noise_alpha
+
+ self.compute_metrics = compute_metrics
+ self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
+ self.optimizer, self.lr_scheduler = optimizers
+ if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
+ raise RuntimeError(
+ "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
+ "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
+ )
+ default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+ callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
+ self.callback_handler = CallbackHandler(
+ callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
+ )
+ self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
+
+ # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
+ self._loggers_initialized = False
+
+ # Create distant repo and output directory if needed
+ self.hub_model_id = None
+ if self.args.push_to_hub:
+ # self.init_hf_repo()
+ raise NotImplementedError
+ if self.args.should_save:
+ os.makedirs(self.args.output_dir, exist_ok=True)
+
+ if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
+ raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
+
+ if args.max_steps > 0 and args.num_train_epochs > 0:
+ logger.warning("max_steps is given, it will override any value given in num_train_epochs")
+
+ if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
+ raise ValueError(
+ "The train_dataset does not implement __len__, max_steps has to be specified. "
+ "The number of steps needs to be known in advance for the learning rate scheduler."
+ )
+
+ if train_dataset is not None and args.group_by_length:
+ raise NotImplementedError
+
+ self._signature_columns = None
+
+ # Mixed precision setup
+ self.use_apex = False
+ self.use_cpu_amp = False
+
+ # Label smoothing
+ if self.args.label_smoothing_factor != 0:
+ self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
+ else:
+ self.label_smoother = None
+
+ self.control = TrainerControl()
+
+ self.state = TrainerState(
+ is_local_process_zero=self.is_local_process_zero(),
+ is_world_process_zero=self.is_world_process_zero(),
+ stateful_callbacks=[
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+ ],
+ )
+ # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
+ # returned to 0 every time flos need to be logged
+ self.current_flos = 0
+ self.hp_search_backend = None
+ default_label_names = find_labels(self.model.__class__)
+ self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
+ self.can_return_loss = can_return_loss(self.model.__class__)
+ self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
+
+ # Internal variables to help with automatic batch size reduction
+ self._train_batch_size = args.train_batch_size
+ self._created_lr_scheduler = False
+
+ def add_callback(self, callback):
+ """
+ Add a callback to the current list of [`~transformers.TrainerCallback`].
+
+ Args:
+ callback (`type` or [`~transformers.TrainerCallback`]):
+ A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+ first case, will instantiate a member of that class.
+ """
+ self.callback_handler.add_callback(callback)
+
+ def pop_callback(self, callback):
+ """
+ Remove a callback from the current list of [`~transformers.TrainerCallback`] and returns it.
+
+ If the callback is not found, returns `None` (and no error is raised).
+
+ Args:
+ callback (`type` or [`~transformers.TrainerCallback`]):
+ A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+ first case, will pop the first member of that class found in the list of callbacks.
+
+ Returns:
+ [`~transformers.TrainerCallback`]: The callback removed, if found.
+ """
+ return self.callback_handler.pop_callback(callback)
+
+ def remove_callback(self, callback):
+ """
+ Remove a callback from the current list of [`~transformers.TrainerCallback`].
+
+ Args:
+ callback (`type` or [`~transformers.TrainerCallback`]):
+ A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+ first case, will remove the first member of that class found in the list of callbacks.
+ """
+ self.callback_handler.remove_callback(callback)
+
+ def _set_signature_columns_if_needed(self):
+ if self._signature_columns is None:
+ # Inspect model forward signature to keep only the arguments it accepts.
+ model_to_inspect = self.model
+ if _is_peft_model(self.model):
+ raise NotImplementedError
+ signature = inspect.signature(model_to_inspect.construct)
+ self._signature_columns = list(signature.parameters.keys())
+ # Labels may be named label or label_ids, the default data collator handles that.
+ self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
+
+ def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
+ if not self.args.remove_unused_columns:
+ return dataset
+ self._set_signature_columns_if_needed()
+ signature_columns = self._signature_columns
+
+ ignored_columns = list(set(dataset.column_names) - set(signature_columns))
+ if len(ignored_columns) > 0:
+ dset_description = "" if description is None else f"in the {description} set"
+ logger.info(
+ f"The following columns {dset_description} don't have a corresponding argument in "
+ f"`{self.model.__class__.__name__}.construct` and have been ignored: {', '.join(ignored_columns)}."
+ f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.construct`, "
+ " you can safely ignore this message."
+ )
+
+ columns = [k for k in signature_columns if k in dataset.column_names]
+ if len(columns) == 0:
+ raise ValueError(
+ "No columns in the dataset match the model's construct method signature. "
+ f"The following columns have been ignored: [{', '.join(ignored_columns)}]. "
+ "Please check the dataset and model. You may need to set `remove_unused_columns=False` in `TrainingArguments`."
+ )
+
+ if version.parse(datasets.__version__) < version.parse("1.4.0"):
+ dataset.set_format(
+ type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
+ )
+ return dataset
+ else:
+ return dataset.remove_columns(ignored_columns)
+
+ def _get_collator_with_removed_columns(
+ self, data_collator: Callable, description: Optional[str] = None
+ ) -> Callable:
+ """Wrap the data collator in a callable removing unused columns."""
+ if not self.args.remove_unused_columns:
+ return data_collator
+ self._set_signature_columns_if_needed()
+ signature_columns = self._signature_columns
+
+ remove_columns_collator = RemoveColumnsCollator(
+ data_collator=data_collator,
+ signature_columns=signature_columns,
+ logger=logger,
+ description=description,
+ model_name=self.model.__class__.__name__,
+ )
+ return remove_columns_collator
+
+ def create_optimizer_and_scheduler(self, num_training_steps: int):
+ """
+ Setup the optimizer and the learning rate scheduler.
+
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+ Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
+ `create_scheduler`) in a subclass.
+ """
+ lr_scheduler = self.create_scheduler(num_training_steps=num_training_steps)
+ self.create_optimizer(lr_scheduler)
+
+ def get_decay_parameter_names(self, model) -> List[str]:
+ """
+ Get all parameter names that weight decay will be applied to
+
+ Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
+ apply to those modules since this function only filter out instance of nn.LayerNorm
+ """
+ decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
+ return decay_parameters
+
+ def create_optimizer(self, lr_scheduler: Union[Tuple, List] = None):
+ """
+ Setup the optimizer.
+
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+ """
+ opt_model = self.model
+
+ if self.optimizer is None:
+ decay_parameters = self.get_decay_parameter_names(opt_model)
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p for n, p in opt_model.parameters_and_names() if (n in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p
+ for n, p in opt_model.parameters_and_names()
+ if (n not in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)
+
+ # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
+ # e.g. for GaLore optimizer.
+ if "params" in optimizer_kwargs:
+ optimizer_grouped_parameters = optimizer_kwargs.pop("params")
+
+ # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
+ # e.g. for LOMO optimizer.
+ if "model" in optimizer_kwargs:
+ optimizer_grouped_parameters = optimizer_kwargs.pop("model")
+
+ # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
+ # to avoid arguments conflicts.
+ if "optimizer_dict" in optimizer_kwargs:
+ optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
+
+ # Note: Init optimizer with lr scheduler on MindSpore 2.3.1
+ # Update learning rate with lr_scheduler
+ if lr_scheduler is not None:
+ optimizer_kwargs.update({"learning_rate": lr_scheduler})
+
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+
+ if optimizer_cls.__name__ == "Adam8bit":
+ raise NotImplementedError
+
+ return self.optimizer
+
+ @staticmethod
+ def get_optimizer_cls_and_kwargs(
+ args: TrainingArguments, model: Optional[PreTrainedModel] = None
+ ) -> Tuple[Any, Any]:
+ """
+ Returns the optimizer class and optimizer parameters based on the training arguments.
+
+ Args:
+ args (`transformers.training_args.TrainingArguments`):
+ The training arguments for the training session.
+
+ """
+
+ # parse args.optim_args
+ optim_args = {}
+ if args.optim_args:
+ for mapping in args.optim_args.replace(" ", "").split(","):
+ key, value = mapping.split("=")
+ optim_args[key] = value
+
+ optimizer_kwargs = {"learning_rate": args.learning_rate}
+
+ adam_kwargs = {
+ "beta1": args.adam_beta1,
+ "beta2": args.adam_beta2,
+ "eps": args.adam_epsilon,
+ }
+ if args.optim == OptimizerNames.ADAFACTOR:
+ optimizer_cls = nn.AdaFactor
+ optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+ elif args.optim == OptimizerNames.ADAMW_MINDSPORE:
+ from .mindspore_adapter.adamw import AdamWeightDecay
+
+ optimizer_cls = AdamWeightDecay
+ optimizer_kwargs.update(adam_kwargs)
+ optimizer_kwargs.update({"enable_fuse": getattr(args, "adamw_enable_fuse", True)})
+ elif args.optim in (OptimizerNames.ADAMW_ZERO1_MINDSPORE, OptimizerNames.ADAMW_ZERO2_MINDSPORE):
+ from .mindspore_adapter.adamw_zero import AdamWeightDecayZeRO1, AdamWeightDecayZeRO2
+
+ optimizer_cls = (
+ AdamWeightDecayZeRO1 if args.optim == OptimizerNames.ADAMW_ZERO1_MINDSPORE else AdamWeightDecayZeRO2
+ )
+ optimizer_kwargs.update(adam_kwargs)
+ optimizer_kwargs.update({"enable_fuse": getattr(args, "adamw_enable_fuse", True)})
+ optimizer_kwargs.update({"shard_size": getattr(args, "adamw_zero_shard_size", None)})
+ optimizer_kwargs.update({"momentum_dtype": getattr(args, "adamw_zero_momentum_dtype", ms.float32)})
+ elif args.optim == OptimizerNames.SGD:
+ optimizer_cls = nn.SGD
+ elif args.optim == OptimizerNames.Momentum:
+ optimizer_cls = nn.Momentum
+ optimizer_kwargs.update({"momentum": getattr(args, "momentum_value", 0.9)})
+ elif args.optim == OptimizerNames.ADAGRAD:
+ optimizer_cls = nn.Adagrad
+ elif args.optim == OptimizerNames.RMSPROP:
+ optimizer_cls = nn.RMSProp
+ elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+ raise NotImplementedError
+ else:
+ raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
+ return optimizer_cls, optimizer_kwargs
+
+ def create_scheduler(self, num_training_steps: int):
+ """
+ Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+ passed as an argument.
+
+ Args:
+ num_training_steps (int): The number of training steps to do.
+ """
+ if self.lr_scheduler is None:
+ self.lr_scheduler = get_scheduler(
+ self.args.lr_scheduler_type,
+ base_lr=self.args.learning_rate,
+ num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+ num_training_steps=num_training_steps,
+ scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
+ )
+ self._created_lr_scheduler = True
+ return self.lr_scheduler
+
+ def get_train_dataloader(self) -> ms.dataset.Dataset:
+ """
+ Returns the training [`~mindspore.dataset.GeneratorDataset`].
+
+ Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
+ training if necessary) otherwise.
+
+ Subclass and override this method if you want to inject some custom behavior.
+ """
+ if self.train_dataset is None:
+ raise ValueError("Trainer: training requires a train_dataset.")
+
+ train_dataset = self.train_dataset
+ data_collator = self.data_collator
+ if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+
+ class MSDataset:
+ def __init__(self, dataset: datasets.Dataset):
+ self.dataset = dataset
+
+ def __getitem__(self, item):
+ return self.dataset[int(item)]
+
+ def __len__(self):
+ return len(self.dataset)
+
+ train_dataset = self._remove_unused_columns(train_dataset, description="training")
+ train_dataset = MSDataset(train_dataset)
+
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
+
+ if self.args.dataloader_pin_memory:
+ logger.warning("Not support `dataloader_pin_memory`")
+ if self.args.dataloader_persistent_workers:
+ logger.warning("Not support `dataloader_persistent_workers`")
+
+ prefetch_factor = self.args.dataloader_prefetch_factor
+ if prefetch_factor is not None and prefetch_factor > 0:
+ ms.dataset.config.set_prefetch_size(prefetch_factor)
+
+ ds_init_params = {
+ "num_parallel_workers": self.args.dataloader_num_workers,
+ "sampler": self._get_train_sampler(),
+ "python_multiprocessing": False,
+ "num_shards": getattr(self.args, "rank_size", 1),
+ "shard_id": getattr(self.args, "rank", 0),
+ "column_names": "item",
+ }
+
+ ds_batch_params = {
+ "num_parallel_workers": self.args.dataloader_num_workers, # num workers
+ "batch_size": self.args.per_device_train_batch_size, # per device batch size
+ "per_batch_map": data_collator, # collate function
+ "drop_remainder": self.args.dataloader_drop_last, # drop last
+ }
+ ds_repeat_params = {"count": 1} # self.args.num_train_epochs # num_train_epochs, loop at train func
+
+ loader = ms.dataset.GeneratorDataset(train_dataset, **ds_init_params)
+ loader = loader.batch(**ds_batch_params)
+ loader = loader.repeat(**ds_repeat_params)
+
+ logger.info(
+ f"create dataloader success, \n"
+ f"\tshard_id/num_shards: {ds_init_params['shard_id']}/{ds_init_params['num_shards']}\n"
+ f"\tnum_parallel_workers: {ds_init_params['num_parallel_workers']}\n"
+ f"\tpython_multiprocessing: {ds_init_params['python_multiprocessing']}\n"
+ f"\tper_batch_size: {ds_batch_params['batch_size']}"
+ )
+
+ return loader
+
+ def _get_train_sampler(self) -> Optional[Sampler]:
+ if self.train_dataset is None or not has_length(self.train_dataset):
+ return None
+
+ # Build the sampler.
+ if self.args.group_by_length:
+ if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
+ lengths = (
+ self.train_dataset[self.args.length_column_name]
+ if self.args.length_column_name in self.train_dataset.column_names
+ else None
+ )
+ else:
+ lengths = None
+ model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
+ return LengthGroupedSampler(
+ self.args.train_batch_size * self.args.gradient_accumulation_steps,
+ dataset=self.train_dataset,
+ lengths=lengths,
+ model_input_name=model_input_name,
+ )
+
+ else:
+ return RandomSampler(self.train_dataset)
+
+ def num_examples(self, dataloader: ms.dataset.Dataset) -> int:
+ if not isinstance(dataloader, ms.dataset.Dataset):
+ dataset = dataloader.dataset
+ return len(dataset)
+ else: # no dataset or length, estimate by length of dataloader
+ # FIXME: Consider parallel scenarios
+ return len(dataloader) * self.args.per_device_train_batch_size
+
+ def num_tokens(self, train_dl: ms.dataset.Dataset, max_steps: Optional[int] = None) -> int:
+ """
+ Helper to get number of tokens in a [`~mindspore.dataset.Dataset`] by enumerating dataloader.
+ """
+ train_tokens = 0
+ try:
+ # FIXME: Consider padding
+ for step, batch in enumerate(train_dl):
+ if isinstance(batch["input_ids"], Tensor):
+ # tokens = batch["input_ids"].numel()
+ tokens = np.prod(batch["input_ids"].shape)
+ elif isinstance(batch["input_ids"], np.ndarray):
+ tokens = batch["input_ids"].size
+ else:
+ tokens = None
+
+ if max_steps is not None:
+ return tokens * max_steps
+ train_tokens += tokens
+ return train_tokens
+ except KeyError:
+ logger.warning("Cannot get num_tokens from dataloader")
+ return train_tokens
+
+ def mindspore_jit_model(self, model, dataloader):
+ # TODO: add pre-compile
+ logger.warning(f"wrap model[{model.__class__.__name__}] to jit model.")
+
+ class JitWarpper(nn.Cell):
+ def __init__(self, model):
+ super(JitWarpper, self).__init__(auto_prefix=False)
+ self.jit_model = model
+
+ @ms.jit
+ def construct(self, *args, **kwargs):
+ self.jit_model(*args, **kwargs)
+
+ return JitWarpper(model)
+
+ def _wrap_model(self, model, dataloader=None):
+ if self.args.jit_mode and ms.get_context("mode") == ms.PYNATIVE_MODE:
+ start_time = time.time()
+ model = self.mindspore_jit_model(model, dataloader)
+
+ # FIXME: just build model, time not included compile cost.
+ self.jit_compilation_time = round(time.time() - start_time, 4)
+
+ # enable auto mix precision
+ assert not (self.args.fp16 and self.args.bf16)
+ amp_level = self.args.amp_opt_level if self.args.amp_opt_level is not None else "O2"
+ if self.args.fp16:
+ model = auto_mixed_precision(model, amp_level, dtype=ms.float16)
+ if self.args.bf16:
+ model = auto_mixed_precision(model, amp_level, dtype=ms.bfloat16)
+
+ # Note: unlike the original transformers, support label_smoother through `Trainer._wrap_model`, and origin support it at `Trainer.compute_loss`
+ if self.label_smoother is not None:
+ signature_columns = list(inspect.signature(self.model.construct).parameters.keys())[1:]
+ input_labels_index = signature_columns.index("labels") if "labels" in signature_columns else None
+
+ class LabelSmootherModel(nn.Cell):
+ def __init__(self, model, label_smoother, labels_index):
+ super(LabelSmootherModel, self).__init__(auto_prefix=False)
+ self.model = model
+ self.label_smoother_ = label_smoother
+ self.labels_index = labels_index
+ self.shift_labels = model._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()
+
+ def construct(self, *inputs):
+ labels = None
+ if self.labels_index is not None:
+ labels = inputs[self.labels_index]
+
+ outputs = self.model(*inputs)
+ loss, logits = outputs[:2]
+ if labels is not None:
+ loss = self.label_smoother_(logits, labels, self.shift_labels)
+
+ return loss
+
+ model_ = LabelSmootherModel(model, self.label_smoother, input_labels_index)
+ else:
+
+ class ReturnLoss(nn.Cell):
+ def __init__(self, model):
+ super(ReturnLoss, self).__init__(auto_prefix=False)
+ self.model = model
+
+ def construct(self, *args, **kwargs):
+ outputs = self.model(*args, **kwargs)
+ loss = outputs[0]
+ return loss
+
+ model_ = ReturnLoss(model)
+
+ # Note: unlike the original transformers, we will define train step process
+ # that include auto mix precision, forward process, loss compute and optimizer step on `train_model`
+ train_model = TrainOneStepWrapper(
+ model_,
+ self.optimizer,
+ ema=None,
+ drop_overflow_step=True,
+ scaler="default",
+ scaler_config={"scale_value": 1024},
+ gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+ clip_grad="global_norm",
+ clip_value=self.args.max_grad_norm,
+ )
+
+ return model, train_model
+
+ def train(
+ self,
+ resume_from_checkpoint: Optional[Union[str, bool]] = None,
+ trial: Union["optuna.Trial", Dict[str, Any]] = None,
+ ignore_keys_for_eval: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ """
+ Main training entry point.
+
+ Args:
+ resume_from_checkpoint (`str` or `bool`, *optional*):
+ If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
+ `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
+ of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
+ trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
+ The trial run or the hyperparameter dictionary for hyperparameter search.
+ ignore_keys_for_eval (`List[str]`, *optional*)
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions for evaluation during the training.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional keyword arguments used to hide deprecated arguments
+ """
+ if resume_from_checkpoint is False:
+ resume_from_checkpoint = None
+
+ # memory metrics - must set up as early as possible
+ # self._memory_tracker.start()
+
+ args = self.args
+
+ self.is_in_train = True
+
+ # Attach NEFTune hooks if necessary
+ if self.neftune_noise_alpha is not None:
+ raise NotImplementedError
+
+ if "model_path" in kwargs:
+ resume_from_checkpoint = kwargs.pop("model_path")
+ warnings.warn(
+ "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
+ "instead.",
+ FutureWarning,
+ )
+ if len(kwargs) > 0:
+ raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
+ # This might change the seed so needs to run first.
+ # self._hp_search_setup(trial) # TODO, level 3, Add hyper parameters search function
+ self._train_batch_size = self.args.train_batch_size
+
+ # Model re-init
+ # model_reloaded = False
+ if self.model_init is not None:
+ # Seed must be set before instantiating the model when using model_init.
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
+ self.model = self.call_model_init(trial)
+ # model_reloaded = True
+ # Reinitializes optimizer and scheduler
+ self.optimizer, self.lr_scheduler = None, None
+
+ # Load potential model checkpoint
+ if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+ resume_from_checkpoint = get_last_checkpoint(args.output_dir)
+ if resume_from_checkpoint is None:
+ raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
+
+ if resume_from_checkpoint is not None:
+ # self._load_from_checkpoint(resume_from_checkpoint) # load weight later
+
+ if os.path.isfile(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)):
+ # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
+ state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+ if state.train_batch_size is not None:
+ self._train_batch_size = state.train_batch_size
+
+ inner_training_loop = functools.partial(
+ self._inner_training_loop, batch_size=self._train_batch_size
+ ) # TODO: level 3, Add find_executable_batch_size function
+ if args.push_to_hub:
+ raise NotImplementedError
+ else:
+ return inner_training_loop(
+ args=args,
+ resume_from_checkpoint=resume_from_checkpoint,
+ trial=trial,
+ ignore_keys_for_eval=ignore_keys_for_eval,
+ )
+
+ def _inner_training_loop(
+ self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+ ):
+ self._train_batch_size = batch_size
+ if self.args.auto_find_batch_size:
+ self.state.train_batch_size = self._train_batch_size
+ logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
+ # Data loader and number of training steps
+ train_dataloader = self.get_train_dataloader()
+
+ # Setting up training control variables:
+ # number of training epochs: num_train_epochs
+ # number of training steps per epoch: num_update_steps_per_epoch
+ # total number of training steps to execute: max_steps
+ total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
+
+ len_dataloader = None
+ num_train_tokens = None
+ if has_length(train_dataloader):
+ len_dataloader = len(train_dataloader)
+ num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
+ num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+ num_examples = self.num_examples(train_dataloader)
+ if args.max_steps > 0:
+ max_steps = args.max_steps
+ num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
+ args.max_steps % num_update_steps_per_epoch > 0
+ )
+ # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
+ # the best we can do.
+ num_train_samples = args.max_steps * total_train_batch_size
+ if args.include_tokens_per_second:
+ num_train_tokens = (
+ self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
+ )
+ else:
+ max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
+ num_train_epochs = math.ceil(args.num_train_epochs)
+ num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
+ if args.include_tokens_per_second:
+ num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
+ elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
+ max_steps = args.max_steps
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
+ num_train_epochs = sys.maxsize
+ num_update_steps_per_epoch = max_steps
+ num_examples = total_train_batch_size * args.max_steps
+ num_train_samples = args.max_steps * total_train_batch_size
+ if args.include_tokens_per_second:
+ num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
+ else:
+ raise ValueError(
+ "args.max_steps must be set to a positive value if dataloader does not have a length, was"
+ f" {args.max_steps}"
+ )
+
+ if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
+ if self.args.n_gpu > 1:
+ raise ValueError("Currently --debug underflow_overflow is not supported under DP.")
+ else:
+ raise NotImplementedError
+
+ # FIXME: Consider parallelism mode
+ delay_optimizer_creation = False
+
+ # We need to reset the scheduler, as its parameters may be different on subsequent calls
+ if self._created_lr_scheduler:
+ self.lr_scheduler = None
+ self._created_lr_scheduler = False
+
+ if not delay_optimizer_creation:
+ self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+ else:
+ raise NotImplementedError
+
+ self.state = TrainerState(
+ stateful_callbacks=[
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+ ]
+ )
+ self.state.is_hyper_param_search = trial is not None
+ self.state.train_batch_size = self._train_batch_size
+
+ # Compute absolute values for logging, eval, and save if given as ratio
+ if args.logging_steps is not None:
+ if args.logging_steps < 1:
+ self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
+ else:
+ self.state.logging_steps = args.logging_steps
+ if args.eval_steps is not None:
+ if args.eval_steps < 1:
+ self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
+ else:
+ self.state.eval_steps = args.eval_steps
+ if args.save_steps is not None:
+ if args.save_steps < 1:
+ self.state.save_steps = math.ceil(max_steps * args.save_steps)
+ else:
+ self.state.save_steps = args.save_steps
+
+ # Activate gradient checkpointing if needed
+ if args.gradient_checkpointing:
+ if args.gradient_checkpointing_kwargs is None:
+ gradient_checkpointing_kwargs = {}
+ else:
+ gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs
+
+ self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
+
+ self.model, self.train_model = self._wrap_model(self.model)
+
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+ # In this case we are in DDP + LOMO, which should be supported
+ raise NotImplementedError
+
+ # ckpt loading
+ if resume_from_checkpoint is not None:
+ logger.info("Checkpoint loading...")
+
+ self._load_from_checkpoint(resume_from_checkpoint)
+
+ # Check if saved optimizer or scheduler states exist
+ self._load_optimizer_and_scheduler(resume_from_checkpoint)
+ else:
+ logger.warning("No available resume checkpoint.")
+
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {num_examples:,}")
+ logger.info(f" Num Epochs = {num_train_epochs:,}")
+ logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
+ if self.args.per_device_train_batch_size != self._train_batch_size:
+ logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {max_steps:,}")
+ logger.info(f" Number of full parameters = {get_model_param_count(self.model, trainable_only=False):,}")
+ logger.info(f" Number of trainable parameters = {get_model_param_count(self.model, trainable_only=True):,}")
+
+ self.state.epoch = 0
+ start_time = time.time()
+ epochs_trained = 0
+ steps_trained_in_current_epoch = 0
+ steps_trained_progress_bar = None
+
+ # Check if continuing training from a checkpoint
+ if resume_from_checkpoint is not None and os.path.isfile(
+ os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
+ ):
+ raise NotImplementedError
+
+ # Update the references
+ self.callback_handler.model = self.model
+ self.callback_handler.optimizer = self.optimizer
+ self.callback_handler.lr_scheduler = self.lr_scheduler
+ self.callback_handler.train_dataloader = train_dataloader
+ if self.hp_name is not None and self._trial is not None:
+ # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
+ # parameter to Train when using DDP.
+ self.state.trial_name = self.hp_name(self._trial)
+ raise NotImplementedError
+ if trial is not None:
+ raise NotImplementedError
+ else:
+ self.state.trial_params = None
+ # This should be the same if the state has been saved but in case the training arguments changed, it's safer
+ # to set this after the load.
+ self.state.max_steps = max_steps
+ self.state.num_train_epochs = num_train_epochs
+ self.state.is_local_process_zero = self.is_local_process_zero()
+ self.state.is_world_process_zero = self.is_world_process_zero()
+
+ # tr_loss is a tensor to avoid synchronization of TPUs through .item()
+ tr_loss = 0.0
+ # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
+ self._total_loss_scalar = 0.0
+ self._globalstep_last_logged = self.state.global_step
+ grad_norm: Optional[float] = None
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
+
+ if args.eval_on_start:
+ self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
+
+ total_batched_samples = 0
+ for epoch in range(epochs_trained, num_train_epochs):
+ epoch_iterator = train_dataloader.create_dict_iterator(num_epochs=1, output_numpy=True)
+ # FIXME: consider resume, skip the previous steps
+ if hasattr(epoch_iterator, "set_epoch"):
+ epoch_iterator.set_epoch(epoch)
+
+ # Reset the past mems state at the beginning of each epoch if necessary.
+ if args.past_index >= 0:
+ self._past = None
+
+ steps_in_epoch = (
+ len(train_dataloader)
+ if len_dataloader is not None
+ else args.max_steps * args.gradient_accumulation_steps
+ )
+ self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
+
+ if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
+ # self._load_rng_state(resume_from_checkpoint) # FIXME: load rng state
+ pass
+
+ rng_to_sync = False
+ steps_skipped = 0
+ if steps_trained_in_current_epoch > 0:
+ raise NotImplementedError
+
+ step = -1
+ for step, inputs in enumerate(epoch_iterator):
+ inputs = inputs["item"]
+
+ total_batched_samples += 1
+
+ if self.args.include_num_input_tokens_seen:
+ raise NotImplementedError
+
+ if rng_to_sync:
+ # self._load_rng_state(resume_from_checkpoint)
+ # rng_to_sync = False
+ raise NotImplementedError
+
+ # Skip past any already trained steps if resuming training
+ if steps_trained_in_current_epoch > 0:
+ raise NotImplementedError
+ elif steps_trained_progress_bar is not None:
+ raise NotImplementedError
+
+ if step % args.gradient_accumulation_steps == 0:
+ self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
+
+ self.model.set_train(True)
+ self.train_model.set_train(True)
+ tr_loss_step, overflow = self.training_step(self.train_model, inputs)
+ tr_loss_step = tr_loss_step.asnumpy()
+
+ # TODO: log by callback_fn
+ logger.info(f"Epoch: {epoch}, Step: {step}, tr_loss: {tr_loss_step}, overflow: {overflow}")
+
+ if args.logging_nan_inf_filter and (np.isnan(tr_loss_step) or np.isinf(tr_loss_step)):
+ # if loss is nan or inf simply add the average of previous logged losses
+ tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
+ logger.warning("tr_loss exist nan/inf, replace to average of previous")
+ else:
+ tr_loss += tr_loss_step
+
+ self.current_flos += float(self.floating_point_ops(inputs))
+
+ is_last_step_and_steps_less_than_grad_acc = (
+ steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
+ )
+
+ if (
+ total_batched_samples % args.gradient_accumulation_steps == 0
+ or
+ # last step in epoch but step is always smaller than gradient_accumulation_steps
+ is_last_step_and_steps_less_than_grad_acc
+ ):
+ # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
+ # in accelerate. So, explicitly enable sync gradients to True in that case.
+ if is_last_step_and_steps_less_than_grad_acc:
+ logger.warning("last step not gradient_accumulation_steps, skip.")
+
+ self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
+ self.state.global_step += 1
+ self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+
+ self._maybe_log_save_evaluate(tr_loss, grad_norm, self.model, trial, epoch, ignore_keys_for_eval)
+ else:
+ self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
+
+ if self.control.should_epoch_stop or self.control.should_training_stop:
+ break
+
+ if step < 0:
+ logger.warning(
+ "There seems to be not a single sample in your epoch_iterator, stopping training at step"
+ f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
+ f" num_steps ({max_steps}) higher than the number of available samples."
+ )
+ self.control.should_training_stop = True
+
+ self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
+ self._maybe_log_save_evaluate(tr_loss, grad_norm, self.model, trial, epoch, ignore_keys_for_eval)
+
+ if self.control.should_training_stop:
+ break
+
+ if args.past_index and hasattr(self, "_past"):
+ # Clean the state at the end of training
+ delattr(self, "_past")
+
+ logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
+ if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
+ raise NotImplementedError
+
+ # add remaining tr_loss
+ self._total_loss_scalar += tr_loss.item()
+ effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError
+ train_loss = self._total_loss_scalar / effective_global_step
+
+ metrics = speed_metrics(
+ "train",
+ start_time,
+ num_samples=num_train_samples,
+ num_steps=self.state.max_steps,
+ num_tokens=num_train_tokens,
+ )
+ self.store_flos()
+ metrics["total_flos"] = self.state.total_flos
+ metrics["train_loss"] = train_loss
+
+ self.is_in_train = False
+
+ # TODO: level 3, Add memory tracker
+ # self._memory_tracker.stop_and_update_metrics(metrics)
+
+ self.log(metrics)
+
+ run_dir = self._get_output_dir(trial)
+
+ # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
+ if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
+ checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
+ for checkpoint in checkpoints_sorted:
+ if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+ shutil.rmtree(checkpoint)
+
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+
+ # Wait for the checkpoint to be uploaded.
+ self._finish_current_push()
+
+ # After training we make sure to retrieve back the original forward pass method
+ # for the embedding layer by removing the forward post hook.
+ if self.neftune_noise_alpha is not None:
+ raise NotImplementedError
+
+ return TrainOutput(self.state.global_step, train_loss, metrics)
+
+ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
+ if model is None:
+ model = self.model
+
+ if os.path.isfile(resume_from_checkpoint):
+ s_time = time.time()
+ state_dict = ms.load_checkpoint(resume_from_checkpoint)
+ m, u = ms.load_param_into_net(model, state_dict)
+
+ m = [n for n in m if ("_buffer" not in n) and (".inv_freq" not in n)]
+ if len(m) > 0:
+ logger.info(f"WARNING: missing keys num: {len(m)}, names (top 100): {m[:10]}")
+ if len(u) > 0:
+ logger.info(f"WARNING: unexpected keys num: {len(u)}, names (top 100): {u[:10]}")
+
+ logger.info(
+ f"load checkpoint from `{resume_from_checkpoint}` success, time cost: {time.time() - s_time:.2f}s"
+ )
+ else:
+ logger.warning(f"resume_from_checkpoint is not file: `{resume_from_checkpoint}`")
+
+ def _load_optimizer_and_scheduler(self, resume_from_checkpoint):
+ if resume_from_checkpoint is None:
+ return
+
+ # get path to file
+ OPTIMIZER_PATH = os.path.join(resume_from_checkpoint, OPTIMIZER_NAME)
+
+ # Note: lr_scheduler is already included in the optimizer on MindSpore 2.3.1
+ # LR_PATH = os.path.join(resume_from_checkpoint, SCHEDULER_NAME)
+
+ if os.path.isfile(OPTIMIZER_PATH):
+ optimizer_state = ms.load_checkpoint(OPTIMIZER_PATH)
+ optimizer_state = optimizer_state["optimizer_state"]
+ ms.load_param_into_net(self.optimizer, optimizer_state)
+ logger.info(f"Optimizer state successfully loaded from {OPTIMIZER_PATH}")
+ else:
+ logger.warning(f"Not exist optimizer state checkpoint path: `{OPTIMIZER_PATH}`")
+
+ # Note: lr_scheduler is already included in the optimizer on MindSpore 2.3.1
+ # if os.path.isfile(LR_PATH):
+ # lr_scheduler_state = ms.load_checkpoint(LR_PATH)
+ # ms.load_param_into_net(self.lr_scheduler, lr_scheduler_state)
+ # logger.info(f"LR scheduler state successfully loaded from {LR_PATH}")
+ # else:
+ # logger.warning(f"Not exist lr scheduler state checkpoint path: `{LR_PATH}`")
+
+ print("Loaded optimizer and lr scheduler state done.")
+
+ def _nested_reduce_sum(self, tensors, name=None):
+ """
+ Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
+ concatenating them to `gathered`
+ """
+ if tensors is None:
+ return
+
+ if self.args.framework == "mindspore":
+ if _is_parallel():
+ return ops.AllReduce()(tensors).mean()
+ else:
+ raise NotImplementedError
+
+ return tensors
+
+ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
+ if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
+ logs: Dict[str, float] = {}
+
+ # FIXME: consider parallel reduce
+ # get average loss over all processes
+ # tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+ # if _is_parallel():
+ # tr_loss_scalar = self._nested_reduce_sum(tr_loss).item() / get_group_size()
+ # else:
+ # tr_loss_scalar = tr_loss.item()
+ tr_loss_scalar = tr_loss.item() if isinstance(tr_loss, (Tensor, np.ndarray)) else tr_loss
+
+ # reset tr_loss to zero
+ tr_loss -= tr_loss
+
+ logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+ if grad_norm is not None:
+ logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, (Tensor, np.ndarray)) else grad_norm
+ # logs["learning_rate"] = _get_learning_rate(self.optimizer, self.state.global_step) # FIXME: may causl memory leak?
+
+ self._total_loss_scalar += tr_loss_scalar
+ self._globalstep_last_logged = self.state.global_step
+ self.store_flos()
+
+ self.log(logs)
+
+ metrics = None
+ if self.control.should_evaluate:
+ metrics = self._evaluate(trial, ignore_keys_for_eval)
+
+ if self.control.should_save:
+ self._save_checkpoint(model, trial, metrics=metrics)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+ def _save_checkpoint(self, model, trial, metrics=None):
+ # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
+ # want to save except FullyShardedDDP.
+ # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
+
+ # Save model checkpoint
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+ if self.hp_search_backend is None and trial is None:
+ self.store_flos()
+
+ run_dir = self._get_output_dir(trial=trial)
+ output_dir = os.path.join(run_dir, checkpoint_folder)
+ self.save_model(output_dir, _internal_call=True)
+
+ if not self.args.save_only_model:
+ # Save optimizer and scheduler
+ # self._save_optimizer_and_scheduler(output_dir)
+
+ # Save RNG state
+ # self._save_rng_state(output_dir)
+
+ raise NotImplementedError
+
+ # Determine the new best metric / best model checkpoint
+ if metrics is not None and self.args.metric_for_best_model is not None:
+ metric_to_check = self.args.metric_for_best_model
+ if not metric_to_check.startswith("eval_"):
+ metric_to_check = f"eval_{metric_to_check}"
+ try:
+ metric_value = metrics[metric_to_check]
+ except KeyError as exc:
+ raise KeyError(
+ f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
+ f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
+ ) from exc
+
+ operator = np.greater if self.args.greater_is_better else np.less
+ if (
+ self.state.best_metric is None
+ or self.state.best_model_checkpoint is None
+ or operator(metric_value, self.state.best_metric)
+ ):
+ self.state.best_metric = metric_value
+ self.state.best_model_checkpoint = output_dir
+
+ # Save the Trainer state
+ if self.args.should_save:
+ # Update the `TrainerControl` state to where we are currently
+ self.state.stateful_callbacks["TrainerControl"] = self.control.state()
+ self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+
+ if self.args.push_to_hub:
+ # self._push_from_checkpoint(output_dir)
+ raise NotImplementedError
+
+ # Maybe delete some older checkpoints.
+ if self.args.should_save:
+ # Solely rely on numerical checkpoint id for rotation.
+ # mtime is not reliable especially on some fuse fs in cloud environments.
+ self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
+
+ def log(self, logs: Dict[str, float]) -> None:
+ """
+ Log `logs` on the various objects watching training.
+
+ Subclass and override this method to inject custom behavior.
+
+ Args:
+ logs (`Dict[str, float]`):
+ The values to log.
+ """
+ if self.state.epoch is not None:
+ logs["epoch"] = self.state.epoch
+ if self.args.include_num_input_tokens_seen:
+ logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen
+
+ output = {**logs, **{"step": self.state.global_step}}
+ self.state.log_history.append(output)
+ self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
+
+ def _prepare_input(self, data: Union[Tensor, Any]) -> Union[Tensor, Any]:
+ """
+ Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
+ """
+ if isinstance(data, Mapping):
+ return type(data)({k: self._prepare_input(v) for k, v in data.items()})
+ elif isinstance(data, (tuple, list)):
+ return type(data)(self._prepare_input(v) for v in data)
+ elif isinstance(data, ms.Tensor):
+ if hasattr(self.args, "input_dtype"):
+ # NLP models inputs are int/uint and those get adjusted to the right dtype of the
+ # embedding. Other models such as wav2vec2's inputs are already float and thus
+ # may need special handling to match the dtypes of the model
+ if data.dtype in (ms.int32, ms.int64, ms.bool_):
+ return data
+
+ kwargs = {"dtype": self.args.input_dtype}
+ return data.to(**kwargs)
+
+ elif isinstance(data, np.ndarray):
+ return self._prepare_input(Tensor(data))
+
+ return data
+
+ def _prepare_inputs(self, inputs: Dict[str, Union[Tensor, Any]]) -> Dict[str, Union[Tensor, Any]]:
+ """
+ Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
+ handling potential state.
+ """
+ inputs = self._prepare_input(inputs)
+ if len(inputs) == 0:
+ raise ValueError(
+ "The batch received was empty, your model won't be able to train on it. Double-check that your "
+ f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
+ )
+ if self.args.past_index >= 0 and self._past is not None:
+ inputs["mems"] = self._past
+
+ return inputs
+
+ def _prepare_inputs_ms(self, inputs: Dict[str, Union[Tensor, Any]]):
+ if len(inputs) == 0:
+ raise ValueError(
+ "The batch received was empty, your model won't be able to train on it. Double-check that your "
+ f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
+ )
+ if self.args.past_index >= 0 and self._past is not None:
+ inputs["mems"] = self._past
+
+ # 1. get model args
+ model_to_inspect = self.model
+ signature = inspect.signature(model_to_inspect.construct)
+ for n, p in signature.parameters.items():
+ assert p.kind in (
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ inspect.Parameter.POSITIONAL_ONLY,
+ inspect.Parameter.VAR_POSITIONAL,
+ ), f"construct func input not position args, check in `class {model_to_inspect.__class__.__name__}`"
+ _signature_columns = list(signature.parameters.keys())
+ _signature_columns = _signature_columns[1:] if _signature_columns[0] == self else _signature_columns
+
+ input_keys = _signature_columns
+ dict_inputs = inputs
+ input_len = max([input_keys.index(k) for k in dict_inputs]) + 1
+
+ # 2. to tuple
+ tuple_inputs = ()
+ for k in input_keys[:input_len]:
+ if k not in dict_inputs:
+ assert not isinstance(signature.parameters[k].default, inspect._empty)
+ v = signature.parameters[k].default
+ else:
+ v = dict_inputs.pop(k)
+ if isinstance(v, (tuple, list)):
+ tuple_inputs += (*v,)
+ else:
+ tuple_inputs += (v,)
+ if len(dict_inputs) > 0:
+ logger.warning(
+ f"input args {dict_inputs.keys()} not found in {self.model.__class__.__name__}, ignore them."
+ )
+
+ # 3. to tensor
+ inputs = ()
+ for data in tuple_inputs:
+ if data is not None:
+ if hasattr(self.args, "input_dtype") and data.dtype in (np.float16, np.float32, np.float64):
+ data = ms.Tensor(data, dtype=self.args.input_dtype)
+ elif data.dtype in (np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32, np.int64):
+ data = ms.Tensor(data, dtype=ms.int32)
+ else:
+ data = ms.Tensor(data)
+ inputs += (data,)
+
+ return inputs
+
+ def call_model_init(self, trial=None):
+ model_init_argcount = number_of_arguments(self.model_init)
+ if model_init_argcount == 0:
+ model = self.model_init()
+ elif model_init_argcount == 1:
+ model = self.model_init(trial)
+ else:
+ raise RuntimeError("model_init should have 0 or 1 argument.")
+
+ if model is None:
+ raise RuntimeError("model_init should not return None.")
+
+ return model
+
+ def training_step(self, model: nn.Cell, inputs: Dict[str, Union[ms.Tensor, Any]]) -> Tuple[ms.Tensor, ms.Tensor]:
+ """
+ Perform a training step on a batch of inputs.
+
+ Subclass and override to inject custom behavior.
+
+ Args:
+ model (`nn.Cell`):
+ The model to train.
+ inputs (`Dict[str, Union[ms.Tensor, Any]]`):
+ The inputs and targets of the model.
+
+ The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+ argument `labels`. Check your model's documentation for all accepted arguments.
+
+ Return:
+ `Tuple[ms.Tensor, ms.Tensor]`: The tensor with training loss and overflow flag on this batch.
+ """
+ train_model = model
+ train_model.set_train()
+
+ tuple_inputs = self._prepare_inputs_ms(inputs)
+
+ loss, _, overflow = train_model(*tuple_inputs)
+
+ # For LOMO optimizers you need to explicitly use the learnign rate
+ if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+ raise NotImplementedError
+
+ if self.args.n_gpu > 1:
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
+
+ if self.use_apex:
+ raise NotImplementedError
+
+ return loss / self.args.gradient_accumulation_steps, overflow
+
+ def compute_loss(self, model, inputs, return_outputs=False):
+ raise NotImplementedError
+
+ def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
+ raise NotImplementedError
+
+ def _get_output_dir(self, trial):
+ if self.hp_search_backend is not None and trial is not None:
+ raise NotImplementedError
+ else:
+ run_dir = self.args.output_dir
+ return run_dir
+
+ def is_local_process_zero(self) -> bool:
+ """
+ Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
+ machines) main process.
+ """
+ return self.args.local_process_index == 0
+
+ def is_world_process_zero(self) -> bool:
+ """
+ Whether or not this process is the global main process (when training in a distributed fashion on several
+ machines, this is only going to be `True` for one process).
+ """
+ # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
+ # process index.
+ return self.args.process_index == 0
+
+ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
+ """
+ Will save the model, so you can reload it using `from_pretrained()`.
+
+ Will only save from the main process.
+ """
+
+ if output_dir is None:
+ output_dir = self.args.output_dir
+
+ if self.args.should_save:
+ self._save(output_dir)
+
+ # Push to the Hub when `save_model` is called by the user.
+ if self.args.push_to_hub and not _internal_call:
+ # self.push_to_hub(commit_message="Model save")
+ raise NotImplementedError
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ # If we are executing this function, we are the process zero, so we don't check for that.
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+ logger.info(f"Saving model checkpoint to {output_dir}")
+
+ supported_classes = (PreTrainedModel,)
+ # Save a trained model and configuration using `save_pretrained()`.
+ # They can then be reloaded using `from_pretrained()`
+ if not isinstance(self.model, supported_classes):
+ if state_dict is None:
+ state_dict = {k: v for k, v in self.model.parameters_and_names()}
+
+ logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+
+ if self.args.save_safetensors:
+ save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "ms"})
+ else:
+ ms.save_checkpoint(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+ else:
+ self.model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
+
+ if self.tokenizer is not None:
+ self.tokenizer.save_pretrained(output_dir)
+
+ # TODO: save args
+ # Good practice: save your training arguments together with the trained model
+ # torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+ def store_flos(self):
+ # Storing the number of floating-point operations that went into the model
+
+ if _is_parallel():
+ # FIXME: consider parallel reduce when dynamic size
+ # self.state.total_flos += (
+ # ops.AllReduce()(Tensor(self.current_flos, ms.float32)).item()
+ # )
+ self.state.total_flos += self.current_flos * get_group_size()
+ self.current_flos = 0
+ else:
+ self.state.total_flos += self.current_flos
+ self.current_flos = 0
+
+ def _sorted_checkpoints(
+ self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
+ ) -> List[str]:
+ ordering_and_checkpoint_path = []
+
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
+
+ for path in glob_checkpoints:
+ if use_mtime:
+ ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
+ else:
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
+ if regex_match is not None and regex_match.groups() is not None:
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
+
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
+ # Make sure we don't delete the best model.
+ if (
+ self.state.best_model_checkpoint is not None
+ and str(Path(self.state.best_model_checkpoint)) in checkpoints_sorted
+ ):
+ best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
+ for i in range(best_model_index, len(checkpoints_sorted) - 2):
+ checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
+ return checkpoints_sorted
+
+ def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
+ if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
+ return
+
+ # Check if we should delete older checkpoint(s)
+ checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
+ if len(checkpoints_sorted) <= self.args.save_total_limit:
+ return
+
+ # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
+ # we don't do to allow resuming.
+ save_total_limit = self.args.save_total_limit
+ if (
+ self.state.best_model_checkpoint is not None
+ and self.args.save_total_limit == 1
+ and checkpoints_sorted[-1] != self.state.best_model_checkpoint
+ ):
+ save_total_limit = 2
+
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
+ for checkpoint in checkpoints_to_be_deleted:
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+ shutil.rmtree(checkpoint, ignore_errors=True)
+
+ def floating_point_ops(self, inputs: Dict[str, Union[Tensor, Any]]):
+ """
+ For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point
+ operations for every backward + forward pass. If using another model, either implement such a method in the
+ model or subclass and override this method.
+
+ Args:
+ inputs (`Dict[str, Union[ms.Tensor, Any]]`):
+ The inputs and targets of the model.
+
+ Returns:
+ `int`: The number of floating-point operations.
+ """
+ if hasattr(self.model, "floating_point_ops"):
+ return self.model.floating_point_ops(inputs)
+ else:
+ return 0
+
+ def _finish_current_push(self):
+ if not hasattr(self, "push_in_progress"):
+ return
+ if self.push_in_progress is not None and not self.push_in_progress.is_done():
+ logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.")
+ self.push_in_progress.wait_until_done()
diff --git a/mindone/transformers/trainer_ms_utils.py b/mindone/transformers/trainer_ms_utils.py
new file mode 100644
index 0000000000..9af7d7ac14
--- /dev/null
+++ b/mindone/transformers/trainer_ms_utils.py
@@ -0,0 +1,210 @@
+import os
+from dataclasses import dataclass
+from typing import Iterable, List, Optional
+
+import numpy as np
+from transformers import BatchEncoding, logging
+
+import mindspore as ms
+from mindspore import Tensor, nn, ops
+
+from .mindspore_adapter import Sampler
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class LabelSmoother:
+ """
+ Adds label-smoothing on a pre-computed output from a Transformers model.
+
+ Args:
+ epsilon (`float`, *optional*, defaults to 0.1):
+ The label smoothing factor.
+ ignore_index (`int`, *optional*, defaults to -100):
+ The index in the labels to ignore when computing the loss.
+ """
+
+ epsilon: float = 0.1
+ ignore_index: int = -100
+
+ def __call__(self, logits: ms.Tensor, labels: ms.Tensor, shift_labels: bool = False):
+ if shift_labels:
+ logits = logits[..., :-1, :]
+ labels = labels[..., 1:]
+
+ log_probs = -ops.log_softmax(logits.to(ms.float32), axis=-1).to(logits.dtype)
+ if labels.ndim == log_probs.ndim - 1:
+ labels = labels.unsqueeze(-1)
+
+ padding_mask = labels.equal(self.ignore_index)
+ # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
+ # will ignore them in any case.
+ labels = ops.clamp(labels, min=0)
+ nll_loss = log_probs.gather_elements(dim=-1, index=labels)
+ # works for fp16 input tensor too, by internally upcasting it to fp32
+ smoothed_loss = log_probs.sum(axis=-1, keepdims=True, dtype=ms.float32)
+
+ nll_loss = nll_loss.masked_fill(padding_mask, 0.0)
+ smoothed_loss = smoothed_loss.masked_fill(padding_mask, 0.0)
+
+ # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
+ num_active_elements = padding_mask.numel() - padding_mask.to(ms.int32).sum()
+ nll_loss = nll_loss.sum() / num_active_elements
+ smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
+ return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss
+
+
+def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None):
+ """
+ Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
+ lengths. To do this, the indices are:
+
+ - randomly permuted
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
+ - sorted by length in each mega-batch
+
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
+ maximum length placed first, so that an OOM happens sooner rather than later.
+ """
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
+ if mega_batch_mult is None:
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
+ # Just in case, for tiny datasets
+ if mega_batch_mult == 0:
+ mega_batch_mult = 1
+
+ # We need to use numpy for the random part as a distributed sampler will set the random seed for numpy.
+ indices = np.random.permutation(len(lengths))
+ megabatch_size = mega_batch_mult * batch_size
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
+
+ # The rest is to get the biggest batch first.
+ # Since each megabatch is sorted by descending length, the longest element is the first
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
+ max_idx = np.argmax(np.array(megabatch_maximums)).item()
+ # Switch to put the longest element in first position
+ megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
+
+ return [i for megabatch in megabatches for i in megabatch]
+
+
+class LengthGroupedSampler(Sampler):
+ r"""
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
+ keeping a bit of randomness.
+ """
+
+ def __init__(
+ self,
+ batch_size: int,
+ dataset: Optional[Iterable] = None,
+ lengths: Optional[List[int]] = None,
+ model_input_name: Optional[str] = None,
+ ):
+ if dataset is None and lengths is None:
+ raise ValueError("One of dataset and lengths must be provided.")
+
+ self.batch_size = batch_size
+ if lengths is None:
+ model_input_name = model_input_name if model_input_name is not None else "input_ids"
+ if (
+ not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
+ or model_input_name not in dataset[0]
+ ):
+ raise ValueError(
+ "Can only automatically infer lengths for datasets whose items are dictionaries with an "
+ f"'{model_input_name}' key."
+ )
+ lengths = [len(feature[model_input_name]) for feature in dataset]
+ elif isinstance(lengths, ms.Tensor):
+ logger.info(
+ "If lengths is a ms.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]..."
+ )
+ lengths = lengths.tolist()
+
+ self.lengths = lengths
+
+ def __len__(self):
+ return len(self.lengths)
+
+ def __iter__(self):
+ indices = get_length_grouped_indices(self.lengths, self.batch_size)
+ return iter(indices)
+
+
+def get_model_param_count(model, trainable_only=False):
+ """
+ Calculate model's total param count. If trainable_only is True then count only those requiring grads
+ """
+
+ def numel(p):
+ # return p.numel()
+ return np.prod(p.shape)
+
+ return sum(numel(p) for p in model.get_parameters() if not trainable_only or p.requires_grad)
+
+
+def get_parameter_names(model: nn.Cell, forbidden_layer_types):
+ """
+ Returns the names of the model parameters that are not inside a forbidden layer.
+ """
+
+ # method 1
+ # _neg_result = []
+ # for name, child in model.cells_and_names():
+ # if isinstance(child, tuple(forbidden_layer_types)):
+ # _neg_result += [n for n, _ in child.parameters_and_names(expand=False)]
+ #
+ # result = []
+ # for p_name, _ in model.parameters_and_names():
+ # if p_name not in _neg_result:
+ # result += [p_name,]
+ #
+ # return result
+
+ # method 2
+ result = []
+ for name, child in model.name_cells().items():
+ result += [
+ f"{name}.{n}"
+ for n in get_parameter_names(child, forbidden_layer_types)
+ if not isinstance(child, tuple(forbidden_layer_types))
+ ]
+ # Add model specific parameters (defined with nn.Parameter) since they are not in any child.
+ result += [n for n, p in model.parameters_and_names(expand=False)]
+ return result
+
+
+def _get_learning_rate(object, global_step):
+ if isinstance(object, nn.Optimizer):
+ optimizer = object
+ if optimizer.dynamic_lr:
+ if optimizer.is_group_lr:
+ lr_cell = optimizer.learning_rate[0]
+ cur_lr = lr_cell(Tensor(global_step, ms.int32)).asnumpy().item()
+ else:
+ cur_lr = optimizer.learning_rate(Tensor(global_step, ms.int32)).asnumpy().item()
+ else:
+ cur_lr = optimizer.learning_rate.asnumpy().item()
+ elif isinstance(object, nn.learning_rate_schedule.LearningRateSchedule):
+ lr_cell = object
+ cur_lr = lr_cell(Tensor(global_step, ms.int32)).asnumpy().item()
+ else:
+ raise NotImplementedError
+
+ return cur_lr
+
+
+def save_state(self):
+ """
+ Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model
+
+ Under distributed environment this is done only for a process with rank 0.
+ """
+ if not self.is_world_process_zero():
+ return
+
+ path = os.path.join(self.args.output_dir, "trainer_state.json")
+ self.state.save_to_json(path)
diff --git a/mindone/transformers/trainer_utils.py b/mindone/transformers/trainer_utils.py
new file mode 100644
index 0000000000..1d9574cd75
--- /dev/null
+++ b/mindone/transformers/trainer_utils.py
@@ -0,0 +1,37 @@
+import os
+import random
+
+import numpy as np
+
+import mindspore as ms
+
+
+def enable_full_determinism(seed: int):
+ """
+ Helper function for reproducible behavior during distributed training. See
+ - https://www.mindspore.cn/docs/zh-CN/r2.3.1/index.html for MindSpore
+ """
+ # set seed first
+ set_seed(seed)
+
+ ms.set_context(deterministic="ON")
+ print("WARNING: Set mindspore context `deterministic=ON`")
+
+ os.environ["HCCL_DETERMINISTIC"] = "true"
+ os.environ["ASCEND_LAUNCH_BLOCKING"] = "1"
+ os.environ["TE_PARALLEL_COMPILER"] = "1"
+
+
+def set_seed(seed: int):
+ """
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, `mindspore`.
+
+ Args:
+ seed (`int`):
+ The seed to set.
+ deterministic (`bool`, *optional*, defaults to `False`):
+ Whether to use deterministic algorithms where available. Can slow down training.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ ms.set_seed(seed)
diff --git a/mindone/transformers/training_args.py b/mindone/transformers/training_args.py
new file mode 100644
index 0000000000..2cfb1b387a
--- /dev/null
+++ b/mindone/transformers/training_args.py
@@ -0,0 +1,1369 @@
+import json
+import math
+import os
+import warnings
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import List, Optional, Union
+
+from transformers import is_safetensors_available, logging
+from transformers.trainer_utils import EvaluationStrategy, HubStrategy, IntervalStrategy, SchedulerType
+from transformers.utils.generic import ExplicitEnum, cached_property
+
+import mindspore as ms
+from mindspore.communication.management import get_group_size, get_rank
+
+from .debug_utils import DebugOption
+from .mindspore_adapter.utils import _is_parallel
+
+logger = logging.get_logger(__name__)
+log_levels = logging.get_log_levels_dict().copy()
+trainer_log_levels = dict(**log_levels, passive=-1)
+
+
+def default_logdir() -> str:
+ """
+ Same default as PyTorch
+ """
+ import socket
+ from datetime import datetime
+
+ current_time = datetime.now().strftime("%b%d_%H-%M-%S")
+ return os.path.join("runs", current_time + "_" + socket.gethostname())
+
+
+class OptimizerNames(ExplicitEnum):
+ """
+ Stores the acceptable string identifiers for optimizers.
+ """
+
+ ADAMW_MINDSPORE = "adamw_mindspore"
+ ADAMW_ZERO1_MINDSPORE = "adamw_zero1_mindspore"
+ ADAMW_ZERO2_MINDSPORE = "adamw_zero2_mindspore"
+ ADAFACTOR = "adafactor"
+ SGD = "sgd"
+ Momentum = "momentum"
+ ADAGRAD = "adagrad"
+ RMSPROP = "rmsprop"
+ LOMO = "lomo"
+ ADALOMO = "adalomo"
+
+
+# Sometimes users will pass in a `str` repr of a dict in the CLI
+# We need to track what fields those can be. Each time a new arg
+# has a dict type, it must be added to this list.
+# Important: These should be typed with Optional[Union[dict,str,...]]
+_VALID_DICT_FIELDS = [
+ "gradient_checkpointing_kwargs",
+ "lr_scheduler_kwargs",
+]
+
+
+def _convert_str_dict(passed_value: dict):
+ "Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
+ for key, value in passed_value.items():
+ if isinstance(value, dict):
+ passed_value[key] = _convert_str_dict(value)
+ elif isinstance(value, str):
+ # First check for bool and convert
+ if value.lower() in ("true", "false"):
+ passed_value[key] = value.lower() == "true"
+ # Check for digit
+ elif value.isdigit():
+ passed_value[key] = int(value)
+ elif value.replace(".", "", 1).isdigit():
+ passed_value[key] = float(value)
+
+ return passed_value
+
+
+# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a
+# few keys: https://github.com/huggingface/transformers/pull/25903
+@dataclass
+class TrainingArguments:
+ """
+ TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop
+ itself**.
+
+ Using [`HfArgumentParser`] we can turn this class into
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+ command line.
+
+ Parameters:
+ output_dir (`str`):
+ The output directory where the model predictions and checkpoints will be written.
+ overwrite_output_dir (`bool`, *optional*, defaults to `False`):
+ If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir`
+ points to a checkpoint directory.
+ do_train (`bool`, *optional*, defaults to `False`):
+ Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used
+ by your training/evaluation scripts instead. See the [example
+ scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+ do_eval (`bool`, *optional*):
+ Whether to run evaluation on the validation set or not. Will be set to `True` if `eval_strategy` is
+ different from `"no"`. This argument is not directly used by [`Trainer`], it's intended to be used by your
+ training/evaluation scripts instead. See the [example
+ scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+ do_predict (`bool`, *optional*, defaults to `False`):
+ Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's
+ intended to be used by your training/evaluation scripts instead. See the [example
+ scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+ eval_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`):
+ The evaluation strategy to adopt during training. Possible values are:
+
+ - `"no"`: No evaluation is done during training.
+ - `"steps"`: Evaluation is done (and logged) every `eval_steps`.
+ - `"epoch"`: Evaluation is done at the end of each epoch.
+
+ prediction_loss_only (`bool`, *optional*, defaults to `False`):
+ When performing evaluation and generating predictions, only returns the loss.
+ per_device_train_batch_size (`int`, *optional*, defaults to 8):
+ The batch size per NPU/GPU/XPU/TPU/MPS core/CPU for training.
+ per_device_eval_batch_size (`int`, *optional*, defaults to 8):
+ The batch size per NPU/GPU/XPU/TPU/MPS core/CPU for evaluation.
+ gradient_accumulation_steps (`int`, *optional*, defaults to 1):
+ Number of updates steps to accumulate the gradients for, before performing a backward/update pass.
+
+
+
+ When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging,
+ evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training examples.
+
+
+
+ eval_accumulation_steps (`int`, *optional*):
+ Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. If
+ left unset, the whole predictions are accumulated on NPU/GPU/TPU before being moved to the CPU (faster but
+ requires more memory).
+ eval_delay (`float`, *optional*):
+ Number of epochs or steps to wait for before the first evaluation can be performed, depending on the
+ eval_strategy.
+ learning_rate (`float`, *optional*, defaults to 5e-5):
+ The initial learning rate for [`AdamW`] optimizer.
+ weight_decay (`float`, *optional*, defaults to 0):
+ The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in [`AdamW`]
+ optimizer.
+ adam_beta1 (`float`, *optional*, defaults to 0.9):
+ The beta1 hyperparameter for the [`AdamW`] optimizer.
+ adam_beta2 (`float`, *optional*, defaults to 0.999):
+ The beta2 hyperparameter for the [`AdamW`] optimizer.
+ adam_epsilon (`float`, *optional*, defaults to 1e-8):
+ The epsilon hyperparameter for the [`AdamW`] optimizer.
+ momentum_value (`float`, *optional*, defaults to 0.9):
+ The momentum hyperparameter for the [`Momentum`] optimizer.
+ max_grad_norm (`float`, *optional*, defaults to 1.0):
+ Maximum gradient norm (for gradient clipping).
+ num_train_epochs(`float`, *optional*, defaults to 3.0):
+ Total number of training epochs to perform (if not an integer, will perform the decimal part percents of
+ the last epoch before stopping training).
+ max_steps (`int`, *optional*, defaults to -1):
+ If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.
+ For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until
+ `max_steps` is reached.
+ lr_scheduler_type (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`):
+ The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values.
+ lr_scheduler_kwargs ('dict', *optional*, defaults to {}):
+ The extra arguments for the lr_scheduler. See the documentation of each scheduler for possible values.
+ warmup_ratio (`float`, *optional*, defaults to 0.0):
+ Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
+ warmup_steps (`int`, *optional*, defaults to 0):
+ Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.
+ log_level (`str`, *optional*, defaults to `passive`):
+ Logger log level to use on the main process. Possible choices are the log levels as strings: 'debug',
+ 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and keeps the
+ current log level for the Transformers library (which will be `"warning"` by default).
+ log_level_replica (`str`, *optional*, defaults to `"warning"`):
+ Logger log level to use on replicas. Same choices as `log_level`"
+ log_on_each_node (`bool`, *optional*, defaults to `True`):
+ In multinode distributed training, whether to log using `log_level` once per node, or only on the main
+ node.
+ logging_dir (`str`, *optional*):
+ [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
+ *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.
+ logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+ The logging strategy to adopt during training. Possible values are:
+
+ - `"no"`: No logging is done during training.
+ - `"epoch"`: Logging is done at the end of each epoch.
+ - `"steps"`: Logging is done every `logging_steps`.
+
+ logging_first_step (`bool`, *optional*, defaults to `False`):
+ Whether to log the first `global_step` or not.
+ logging_steps (`int` or `float`, *optional*, defaults to 500):
+ Number of update steps between two logs if `logging_strategy="steps"`. Should be an integer or a float in
+ range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.
+ logging_nan_inf_filter (`bool`, *optional*, defaults to `True`):
+ Whether to filter `nan` and `inf` losses for logging. If set to `True` the loss of every step that is `nan`
+ or `inf` is filtered and the average loss of the current logging window is taken instead.
+
+
+
+ `logging_nan_inf_filter` only influences the logging of loss values, it does not change the behavior the
+ gradient is computed or applied to the model.
+
+
+
+ save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+ The checkpoint save strategy to adopt during training. Possible values are:
+
+ - `"no"`: No save is done during training.
+ - `"epoch"`: Save is done at the end of each epoch.
+ - `"steps"`: Save is done every `save_steps`.
+
+ If `"epoch"` or `"steps"` is chosen, saving will also be performed at the
+ very end of training, always.
+ save_steps (`int` or `float`, *optional*, defaults to 500):
+ Number of updates steps before two checkpoint saves if `save_strategy="steps"`. Should be an integer or a
+ float in range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.
+ save_total_limit (`int`, *optional*):
+ If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
+ `output_dir`. When `load_best_model_at_end` is enabled, the "best" checkpoint according to
+ `metric_for_best_model` will always be retained in addition to the most recent ones. For example, for
+ `save_total_limit=5` and `load_best_model_at_end`, the four last checkpoints will always be retained
+ alongside the best model. When `save_total_limit=1` and `load_best_model_at_end`, it is possible that two
+ checkpoints are saved: the last one and the best one (if they are different).
+ save_safetensors (`bool`, *optional*, defaults to `True`):
+ Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts.
+ save_on_each_node (`bool`, *optional*, defaults to `False`):
+ When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
+ the main one.
+
+ This should not be activated when the different nodes use the same storage as the files will be saved with
+ the same names for each node.
+ save_only_model (`bool`, *optional*, defaults to `False`):
+ When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state.
+ Note that when this is true, you won't be able to resume training from checkpoint.
+ This enables you to save storage by not storing the optimizer, scheduler & rng state.
+ You can only load the model using `from_pretrained` with this option set to `True`.
+ restore_callback_states_from_checkpoint (`bool`, *optional*, defaults to `False`):
+ Whether to restore the callback states from the checkpoint. If `True`, will override
+ callbacks passed to the `Trainer` if they exist in the checkpoint."
+ use_cpu (`bool`, *optional*, defaults to `False`):
+ Whether or not to use cpu. If set to False, we will use cuda or mps device if available.
+ seed (`int`, *optional*, defaults to 42):
+ Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the
+ [`~Trainer.model_init`] function to instantiate the model if it has some randomly initialized parameters.
+ data_seed (`int`, *optional*):
+ Random seed to be used with data samplers. If not set, random generators for data sampling will use the
+ same seed as `seed`. This can be used to ensure reproducibility of data sampling, independent of the model
+ seed.
+ bf16 (`bool`, *optional*, defaults to `False`):
+ Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
+ NVIDIA architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change.
+ fp16 (`bool`, *optional*, defaults to `False`):
+ Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
+ fp16_backend (`str`, *optional*, defaults to `"auto"`):
+ This argument is deprecated. Use `half_precision_backend` instead.
+ half_precision_backend (`str`, *optional*, defaults to `"auto"`):
+ The backend to use for mixed precision training. Must be one of `"auto", "apex", "cpu_amp"`. `"auto"` will
+ use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the
+ requested backend.
+ bf16_full_eval (`bool`, *optional*, defaults to `False`):
+ Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
+ metric values. This is an experimental API and it may change.
+ fp16_full_eval (`bool`, *optional*, defaults to `False`):
+ Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm
+ metric values.
+ tf32 (`bool`, *optional*):
+ Whether to enable the TF32 mode, currently not supported.
+ local_rank (`int`, *optional*, defaults to -1):
+ Rank of the process during distributed training.
+ ddp_backend (`str`, *optional*):
+ The backend to use for distributed training. Must be one of `"nccl"`, `"mpi"`, `"ccl"`, `"gloo"`, `"hccl"`.
+ dataloader_drop_last (`bool`, *optional*, defaults to `False`):
+ Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
+ or not.
+ eval_steps (`int` or `float`, *optional*):
+ Number of update steps between two evaluations if `eval_strategy="steps"`. Will default to the same
+ value as `logging_steps` if not set. Should be an integer or a float in range `[0,1)`. If smaller than 1,
+ will be interpreted as ratio of total training steps.
+ dataloader_num_workers (`int`, *optional*, defaults to 0):
+ Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the
+ main process.
+ past_index (`int`, *optional*, defaults to -1):
+ Some models like [TransformerXL](../model_doc/transformerxl) or [XLNet](../model_doc/xlnet) can make use of
+ the past hidden states for their predictions. If this argument is set to a positive int, the `Trainer` will
+ use the corresponding output (usually index 2) as the past state and feed it to the model at the next
+ training step under the keyword argument `mems`.
+ run_name (`str`, *optional*, defaults to `output_dir`):
+ A descriptor for the run. Typically used for [wandb](https://www.wandb.com/) and
+ [mlflow](https://www.mlflow.org/) logging. If not specified, will be the same as `output_dir`.
+ disable_tqdm (`bool`, *optional*):
+ Whether or not to disable the tqdm progress bars and table of metrics produced by
+ [`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
+ set to warn or lower (default), `False` otherwise.
+ remove_unused_columns (`bool`, *optional*, defaults to `True`):
+ Whether or not to automatically remove the columns unused by the model forward method.
+ label_names (`List[str]`, *optional*):
+ The list of keys in your dictionary of inputs that correspond to the labels.
+
+ Will eventually default to the list of argument names accepted by the model that contain the word "label",
+ except if the model used is one of the `XxxForQuestionAnswering` in which case it will also include the
+ `["start_positions", "end_positions"]` keys.
+ load_best_model_at_end (`bool`, *optional*, defaults to `False`):
+ Whether or not to load the best model found during training at the end of training. When this option is
+ enabled, the best checkpoint will always be saved. See
+ [`save_total_limit`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_total_limit)
+ for more.
+
+
+
+ When set to `True`, the parameters `save_strategy` needs to be the same as `eval_strategy`, and in
+ the case it is "steps", `save_steps` must be a round multiple of `eval_steps`.
+
+
+
+ metric_for_best_model (`str`, *optional*):
+ Use in conjunction with `load_best_model_at_end` to specify the metric to use to compare two different
+ models. Must be the name of a metric returned by the evaluation with or without the prefix `"eval_"`. Will
+ default to `"loss"` if unspecified and `load_best_model_at_end=True` (to use the evaluation loss).
+
+ If you set this value, `greater_is_better` will default to `True`. Don't forget to set it to `False` if
+ your metric is better when lower.
+ greater_is_better (`bool`, *optional*):
+ Use in conjunction with `load_best_model_at_end` and `metric_for_best_model` to specify if better models
+ should have a greater metric or not. Will default to:
+
+ - `True` if `metric_for_best_model` is set to a value that doesn't end in `"loss"`.
+ - `False` if `metric_for_best_model` is not set, or set to a value that ends in `"loss"`.
+ ignore_data_skip (`bool`, *optional*, defaults to `False`):
+ When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
+ stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step
+ can take a long time) but will not yield the same results as the interrupted training would have.
+
+ label_smoothing_factor (`float`, *optional*, defaults to 0.0):
+ The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
+ labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
+ label_smoothing_factor/num_labels` respectively.
+ debug (`str` or list of [`~debug_utils.DebugOption`], *optional*, defaults to `""`):
+ Enable one or more debug features. This is an experimental feature.
+
+ Possible options are:
+
+ - `"underflow_overflow"`: detects overflow in model's input/outputs and reports the last frames that led to
+ the event
+ - `"tpu_metrics_debug"`: print debug metrics on TPU
+
+ The options should be separated by whitespaces.
+ optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_mindspore"`):
+ The optimizer to use: adamw_mindspore or adafactor.
+ optim_args (`str`, *optional*):
+ Optional arguments that are supplied to AnyPrecisionAdamW.
+ group_by_length (`bool`, *optional*, defaults to `False`):
+ Whether or not to group together samples of roughly the same length in the training dataset (to minimize
+ padding applied and be more efficient). Only useful if applying dynamic padding.
+ length_column_name (`str`, *optional*, defaults to `"length"`):
+ Column name for precomputed lengths. If the column exists, grouping by length will use these values rather
+ than computing them on train startup. Ignored unless `group_by_length` is `True` and the dataset is an
+ instance of `Dataset`.
+ report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
+ The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
+ `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`,
+ `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
+ integrations.
+ ddp_find_unused_parameters (`bool`, *optional*):
+ When using distributed training, the value of the flag `find_unused_parameters` passed to
+ `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
+ ddp_bucket_cap_mb (`int`, *optional*):
+ When using distributed training, the value of the flag `bucket_cap_mb` passed to `DistributedDataParallel`.
+ ddp_broadcast_buffers (`bool`, *optional*):
+ When using distributed training, the value of the flag `broadcast_buffers` passed to
+ `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
+ dataloader_pin_memory (`bool`, *optional*, defaults to `True`):
+ Whether you want to pin memory in data loaders or not. Will default to `True`.
+ dataloader_persistent_workers (`bool`, *optional*, defaults to `False`):
+ If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
+ This allows to maintain the workers Dataset instances alive. Can potentially speed up training, but will
+ increase RAM usage. Will default to `False`.
+ dataloader_prefetch_factor (`int`, *optional*):
+ Number of batches loaded in advance by each worker.
+ 2 means there will be a total of 2 * num_workers batches prefetched across all workers.
+ skip_memory_metrics (`bool`, *optional*, defaults to `True`):
+ Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
+ down the training and evaluation speed.
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push the model to the Hub every time the model is saved. If this is activated,
+ `output_dir` will begin a git directory synced with the repo (determined by `hub_model_id`) and the content
+ will be pushed each time a save is triggered (depending on your `save_strategy`). Calling
+ [`~Trainer.save_model`] will also trigger a push.
+
+
+
+ If `output_dir` exists, it needs to be a local clone of the repository to which the [`Trainer`] will be
+ pushed.
+
+
+
+ resume_from_checkpoint (`str`, *optional*):
+ The path to a folder with a valid checkpoint for your model. This argument is not directly used by
+ [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example
+ scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+ hub_model_id (`str`, *optional*):
+ The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in
+ which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
+ for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
+ `"organization_name/model"`. Will default to `user_name/output_dir_name` with *output_dir_name* being the
+ name of `output_dir`.
+
+ Will default to the name of `output_dir`.
+ hub_strategy (`str` or [`~trainer_utils.HubStrategy`], *optional*, defaults to `"every_save"`):
+ Defines the scope of what is pushed to the Hub and when. Possible values are:
+
+ - `"end"`: push the model, its configuration, the tokenizer (if passed along to the [`Trainer`]) and a
+ draft of a model card when the [`~Trainer.save_model`] method is called.
+ - `"every_save"`: push the model, its configuration, the tokenizer (if passed along to the [`Trainer`]) and
+ a draft of a model card each time there is a model save. The pushes are asynchronous to not block
+ training, and in case the save are very frequent, a new push is only attempted if the previous one is
+ finished. A last push is made with the final model at the end of training.
+ - `"checkpoint"`: like `"every_save"` but the latest checkpoint is also pushed in a subfolder named
+ last-checkpoint, allowing you to resume training easily with
+ `trainer.train(resume_from_checkpoint="last-checkpoint")`.
+ - `"all_checkpoints"`: like `"checkpoint"` but all checkpoints are pushed like they appear in the output
+ folder (so you will get one checkpoint folder per folder in your final repository)
+
+ hub_token (`str`, *optional*):
+ The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
+ `huggingface-cli login`.
+ hub_private_repo (`bool`, *optional*, defaults to `False`):
+ If True, the Hub repo will be set to private.
+ hub_always_push (`bool`, *optional*, defaults to `False`):
+ Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
+ gradient_checkpointing (`bool`, *optional*, defaults to `False`):
+ If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+ gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
+ Key word arguments to be passed to the `gradient_checkpointing_enable` method.
+ include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
+ Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
+ that need inputs, predictions and references for scoring calculation in Metric class.
+ eval_do_concat_batches (`bool`, *optional*, defaults to `True`):
+ Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`,
+ will instead store them as lists, with each batch kept separate.
+ auto_find_batch_size (`bool`, *optional*, defaults to `False`)
+ Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
+ CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
+ full_determinism (`bool`, *optional*, defaults to `False`)
+ If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
+ distributed training. Important: this will negatively impact the performance, so only use it for debugging.
+ ray_scope (`str`, *optional*, defaults to `"last"`):
+ The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray will
+ then use the last checkpoint of all trials, compare those, and select the best one. However, other options
+ are also available. See the [Ray documentation](
+ https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for
+ more options.
+ include_tokens_per_second (`bool`, *optional*):
+ Whether or not to compute the number of tokens per second per device for training speed metrics.
+
+ This will iterate over the entire training dataloader once beforehand,
+
+ and will slow down the entire process.
+
+ include_num_input_tokens_seen (`bool`, *optional*):
+ Whether or not to track the number of input tokens seen throughout training.
+
+ May be slower in distributed training as gather operations must be called.
+
+ neftune_noise_alpha (`Optional[float]`):
+ If not `None`, this will activate NEFTune noise embeddings. This can drastically improve model performance
+ for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the
+ [original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also
+ `PeftModel` from peft.
+ optim_target_modules (`Union[str, List[str]]`, *optional*):
+ The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
+ https://arxiv.org/abs/2403.03507
+ See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe
+ optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules
+ only.
+
+ batch_eval_metrics (`Optional[bool]`, defaults to `False`):
+ If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics
+ rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function
+ that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
+ summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
+
+ eval_on_start(`bool`, *optional*, defaults to `False`):
+ Whether to perform a evaluation step (sanity check) before the training to ensure the validation steps works correctly.
+ """
+
+ framework = "mindspore"
+ output_dir: str = field(
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
+ )
+ overwrite_output_dir: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Overwrite the content of the output directory. "
+ "Use this to continue training if output_dir points to a checkpoint directory."
+ )
+ },
+ )
+
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
+ do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
+ eval_strategy: Union[IntervalStrategy, str] = field(
+ default="no",
+ metadata={"help": "The evaluation strategy to use."},
+ )
+ prediction_loss_only: bool = field(
+ default=False,
+ metadata={"help": "When performing evaluation and predictions, only returns the loss."},
+ )
+
+ per_device_train_batch_size: int = field(
+ default=8, metadata={"help": "Batch size per NPU/GPU/TPU/MPS core/CPU for training."}
+ )
+ per_device_eval_batch_size: int = field(
+ default=8, metadata={"help": "Batch size per NPU/GPU/TPU/MPS core/CPU for evaluation."}
+ )
+
+ gradient_accumulation_steps: int = field(
+ default=1,
+ metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."},
+ )
+ eval_accumulation_steps: Optional[int] = field(
+ default=None,
+ metadata={"help": "Number of predictions steps to accumulate before moving the tensors to the CPU."},
+ )
+
+ eval_delay: Optional[float] = field(
+ default=0,
+ metadata={
+ "help": (
+ "Number of epochs or steps to wait for before the first evaluation can be performed, depending on the"
+ " eval_strategy."
+ )
+ },
+ )
+
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
+ momentum_value: float = field(default=0.9, metadata={"help": "Momentum value for Momentum optimizer"})
+ max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
+
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
+ max_steps: int = field(
+ default=-1,
+ metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
+ )
+ lr_scheduler_type: Union[SchedulerType, str] = field(
+ default="linear",
+ metadata={"help": "The scheduler type to use."},
+ )
+ lr_scheduler_kwargs: Optional[Union[dict, str]] = field(
+ default_factory=dict,
+ metadata={
+ "help": (
+ "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts."
+ )
+ },
+ )
+ warmup_ratio: float = field(
+ default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
+ )
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
+
+ log_level: Optional[str] = field(
+ default="info",
+ metadata={
+ "help": (
+ "Logger log level to use on the main node. Possible choices are the log levels as strings: 'debug',"
+ " 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and"
+ " lets the application set the level. Defaults to 'passive'."
+ ),
+ "choices": trainer_log_levels.keys(),
+ },
+ )
+ log_level_replica: Optional[str] = field(
+ default="warning",
+ metadata={
+ "help": "Logger log level to use on replica nodes. Same choices and defaults as ``log_level``",
+ "choices": trainer_log_levels.keys(),
+ },
+ )
+ log_on_each_node: bool = field(
+ default=True,
+ metadata={
+ "help": (
+ "When doing a multinode distributed training, whether to log once per node or just once on the main"
+ " node."
+ )
+ },
+ )
+ logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
+ logging_strategy: Union[IntervalStrategy, str] = field(
+ default="steps",
+ metadata={"help": "The logging strategy to use."},
+ )
+ logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
+ logging_steps: float = field(
+ default=500,
+ metadata={
+ "help": (
+ "Log every X updates steps. Should be an integer or a float in range `[0,1)`. "
+ "If smaller than 1, will be interpreted as ratio of total training steps."
+ )
+ },
+ )
+ logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
+ save_strategy: Union[IntervalStrategy, str] = field(
+ default="steps",
+ metadata={"help": "The checkpoint save strategy to use."},
+ )
+ save_steps: float = field(
+ default=500,
+ metadata={
+ "help": (
+ "Save checkpoint every X updates steps. Should be an integer or a float in range `[0,1)`. "
+ "If smaller than 1, will be interpreted as ratio of total training steps."
+ )
+ },
+ )
+ save_total_limit: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in"
+ " `output_dir`. When `load_best_model_at_end` is enabled, the 'best' checkpoint according to"
+ " `metric_for_best_model` will always be retained in addition to the most recent ones. For example,"
+ " for `save_total_limit=5` and `load_best_model_at_end=True`, the four last checkpoints will always be"
+ " retained alongside the best model. When `save_total_limit=1` and `load_best_model_at_end=True`,"
+ " it is possible that two checkpoints are saved: the last one and the best one (if they are different)."
+ " Default is unlimited checkpoints"
+ )
+ },
+ )
+ save_safetensors: Optional[bool] = field(
+ default=True,
+ metadata={"help": "Use safetensors saving and loading for state dicts."},
+ )
+ save_on_each_node: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "When doing multi-node distributed training, whether to save models and checkpoints on each node, or"
+ " only on the main one"
+ )
+ },
+ )
+ save_only_model: bool = field(
+ default=True,
+ metadata={
+ "help": (
+ "When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state."
+ "Note that when this is true, you won't be able to resume training from checkpoint."
+ "This enables you to save storage by not storing the optimizer, scheduler & rng state."
+ "You can only load the model using from_pretrained with this option set to True."
+ )
+ },
+ )
+ restore_callback_states_from_checkpoint: bool = field(
+ default=False,
+ metadata={
+ "help": "Whether to restore the callback states from the checkpoint. If `True`, will override callbacks "
+ "passed to the `Trainer` if they exist in the checkpoint."
+ },
+ )
+ use_cpu: bool = field(
+ default=False,
+ metadata={
+ "help": " Whether or not to use cpu. If set to False, we will use cuda/tpu/mps/npu device if available."
+ },
+ )
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
+ data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
+ jit_mode: bool = field(default=False, metadata={"help": "Whether or not to use MindSpore jit trace"})
+ bf16: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA"
+ " architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change."
+ )
+ },
+ )
+ fp16: bool = field(
+ default=False,
+ metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"},
+ )
+ amp_opt_level: str = field(
+ default=None,
+ metadata={
+ "help": (
+ "For fp16/bf16 auto mix-precision: AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
+ "See details at https://nvidia.github.io/apex/amp.html"
+ )
+ },
+ )
+ bf16_full_eval: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to use full bfloat16 evaluation instead of 32-bit. This is an experimental API and it may"
+ " change."
+ )
+ },
+ )
+ fp16_full_eval: bool = field(
+ default=False,
+ metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
+ )
+ tf32: Optional[bool] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Whether to enable tf32 mode, not available on MindSpore 2.3.1. This is an experimental"
+ " API and it may change."
+ )
+ },
+ )
+ local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
+ ddp_backend: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "The backend to be used for distributed training",
+ "choices": ["nccl", "gloo", "mpi", "ccl", "hccl", "cncl"],
+ },
+ )
+ debug: Union[str, List[DebugOption]] = field(
+ default="",
+ metadata={
+ "help": (
+ "Whether or not to enable debug mode. Current options: "
+ "`underflow_overflow` (Detect underflow and overflow in activations and weights), "
+ "`tpu_metrics_debug` (print debug metrics on TPU)."
+ )
+ },
+ )
+
+ dataloader_drop_last: bool = field(
+ default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
+ )
+ eval_steps: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Run an evaluation every X steps. Should be an integer or a float in range `[0,1)`. "
+ "If smaller than 1, will be interpreted as ratio of total training steps."
+ )
+ },
+ )
+ dataloader_num_workers: int = field(
+ default=1,
+ metadata={
+ "help": (
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded"
+ " in the main process."
+ )
+ },
+ )
+ dataloader_prefetch_factor: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "Number of batches loaded in advance by each worker. "
+ "2 means there will be a total of 2 * num_workers batches prefetched across all workers. "
+ )
+ },
+ )
+ past_index: int = field(
+ default=-1,
+ metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
+ )
+
+ run_name: Optional[str] = field(
+ default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
+ )
+ disable_tqdm: Optional[bool] = field(
+ default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
+ )
+
+ remove_unused_columns: Optional[bool] = field(
+ default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
+ )
+ label_names: Optional[List[str]] = field(
+ default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
+ )
+ load_best_model_at_end: Optional[bool] = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether or not to load the best model found during training at the end of training. When this option"
+ " is enabled, the best checkpoint will always be saved. See `save_total_limit` for more."
+ )
+ },
+ )
+ metric_for_best_model: Optional[str] = field(
+ default=None, metadata={"help": "The metric to use to compare two different models."}
+ )
+ greater_is_better: Optional[bool] = field(
+ default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."}
+ )
+ ignore_data_skip: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "When resuming training, whether or not to skip the first epochs and batches to get to the same"
+ " training data."
+ )
+ },
+ )
+
+ label_smoothing_factor: float = field(
+ default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
+ )
+
+ default_optim = "adamw_mindspore"
+ optim: Union[OptimizerNames, str] = field(
+ default=default_optim,
+ metadata={"help": "The optimizer to use."},
+ )
+ optim_args: Optional[str] = field(default=None, metadata={"help": "Optional arguments to supply to optimizer."})
+ zero_stage: Optional[int] = field(
+ default=None,
+ metadata={"help": ("Enable ZeRO optimizer parallelism, select from [1, 2]")},
+ )
+ adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
+ group_by_length: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
+ )
+ length_column_name: Optional[str] = field(
+ default="length",
+ metadata={"help": "Column name with precomputed lengths to use when grouping by length."},
+ )
+ report_to: Union[None, str, List[str]] = field(
+ default=None, metadata={"help": "The list of integrations to report the results and logs to."}
+ )
+ ddp_find_unused_parameters: Optional[bool] = field(
+ default=None,
+ metadata={
+ "help": (
+ "When using distributed training, the value of the flag `find_unused_parameters` passed to "
+ "`DistributedDataParallel`."
+ )
+ },
+ )
+ ddp_bucket_cap_mb: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": (
+ "When using distributed training, the value of the flag `bucket_cap_mb` passed to "
+ "`DistributedDataParallel`."
+ )
+ },
+ )
+ ddp_broadcast_buffers: Optional[bool] = field(
+ default=None,
+ metadata={
+ "help": (
+ "When using distributed training, the value of the flag `broadcast_buffers` passed to "
+ "`DistributedDataParallel`."
+ )
+ },
+ )
+ dataloader_pin_memory: bool = field(
+ default=False, metadata={"help": "Whether or not to pin memory for DataLoader."}
+ )
+ dataloader_persistent_workers: bool = field(
+ default=False,
+ metadata={
+ "help": "If True, the data loader will not shut down the worker processes after a dataset has been consumed once. "
+ "This allows to maintain the workers Dataset instances alive. Can potentially speed up training, "
+ "but will increase RAM usage."
+ },
+ )
+ skip_memory_metrics: bool = field(
+ default=True, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
+ )
+ use_legacy_prediction_loop: bool = field(
+ default=False, metadata={"help": "Whether or not to use the legacy prediction_loop in the Trainer."}
+ )
+ push_to_hub: bool = field(
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
+ )
+ resume_from_checkpoint: Optional[str] = field(
+ default=None,
+ metadata={"help": "The path to a folder with a valid checkpoint for your model."},
+ )
+ hub_model_id: Optional[str] = field(
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
+ )
+ hub_strategy: Union[HubStrategy, str] = field(
+ default="every_save",
+ metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
+ )
+ hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+ hub_private_repo: bool = field(default=False, metadata={"help": "Whether the model repository is private or not."})
+ hub_always_push: bool = field(
+ default=False,
+ metadata={"help": "Unless `True`, the Trainer will skip pushes if the previous one wasn't finished yet."},
+ )
+ gradient_checkpointing: bool = field(
+ default=False,
+ metadata={"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."},
+ )
+ gradient_checkpointing_kwargs: Optional[Union[dict, str]] = field(
+ default=None,
+ metadata={
+ "help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to "
+ "`mindspore.nn.cell.recompute` through `model.gradient_checkpointing_enable`."
+ },
+ )
+ include_inputs_for_metrics: bool = field(
+ default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
+ )
+ eval_do_concat_batches: bool = field(
+ default=True,
+ metadata={
+ "help": "Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`, "
+ "will instead store them as lists, with each batch kept separate."
+ },
+ )
+ # Deprecated arguments
+ fp16_backend: str = field(
+ default="auto",
+ metadata={
+ "help": "Deprecated. Use half_precision_backend instead",
+ "choices": ["auto", "apex", "cpu_amp"],
+ },
+ )
+ evaluation_strategy: Union[IntervalStrategy, str] = field(
+ default=None,
+ metadata={"help": "Deprecated. Use `eval_strategy` instead"},
+ )
+ push_to_hub_model_id: Optional[str] = field(
+ default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
+ )
+ push_to_hub_organization: Optional[str] = field(
+ default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."}
+ )
+ push_to_hub_token: Optional[str] = field(
+ default=None, metadata={"help": "The token to use to push to the Model Hub."}
+ )
+ _n_gpu: int = field(init=False, repr=False, default=-1)
+ mp_parameters: str = field(
+ default="",
+ metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"},
+ )
+
+ auto_find_batch_size: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to automatically decrease the batch size in half and rerun the training loop again each time"
+ " a CUDA Out-of-Memory was reached"
+ )
+ },
+ )
+ full_determinism: bool = field(
+ default=False,
+ metadata={
+ "help": (
+ "Whether to call enable_full_determinism instead of set_seed for reproducibility in distributed"
+ " training. Important: this will negatively impact the performance, so only use it for debugging."
+ )
+ },
+ )
+ ray_scope: Optional[str] = field(
+ default="last",
+ metadata={
+ "help": (
+ 'The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray'
+ " will then use the last checkpoint of all trials, compare those, and select the best one. However,"
+ " other options are also available. See the Ray documentation"
+ " (https://docs.ray.io/en/latest/tune/api_docs/analysis.html"
+ "#ray.tune.ExperimentAnalysis.get_best_trial)"
+ " for more options."
+ )
+ },
+ )
+
+ include_tokens_per_second: Optional[bool] = field(
+ default=False,
+ metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},
+ )
+
+ include_num_input_tokens_seen: Optional[bool] = field(
+ default=False,
+ metadata={
+ "help": "If set to `True`, will track the number of input tokens seen throughout training. "
+ "(May be slower in distributed training)"
+ },
+ )
+
+ neftune_noise_alpha: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "Activates neftune noise embeddings into the model. NEFTune has been proven to drastically "
+ "improve model performances for instrcution fine-tuning. Check out the original paper here: "
+ "https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune. "
+ "Only supported for `PreTrainedModel` and `PeftModel` classes."
+ },
+ )
+
+ optim_target_modules: Union[None, str, List[str]] = field(
+ default=None,
+ metadata={
+ "help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment."
+ },
+ )
+
+ batch_eval_metrics: bool = field(
+ default=False,
+ metadata={"help": "Break eval metrics calculation into batches to save memory."},
+ )
+
+ eval_on_start: bool = field(
+ default=False,
+ metadata={
+ "help": "Whether to run through the entire `evaluation` step at the very beginning of training as a sanity check."
+ },
+ )
+
+ def __post_init__(self):
+ # Parse in args that could be `dict` sent in from the CLI as a string
+ for _field in _VALID_DICT_FIELDS:
+ if not hasattr(self, _field):
+ logger.warning(f"cambrian.transformers not support args: {_field}, skip.")
+ continue
+
+ passed_value = getattr(self, _field)
+
+ # We only want to do this if the str starts with a bracket to indiciate a `dict`
+ # else its likely a filename if supported
+ if isinstance(passed_value, str) and passed_value.startswith("{"):
+ loaded_dict = json.loads(passed_value)
+ # Convert str values to types if applicable
+ loaded_dict = _convert_str_dict(loaded_dict)
+ setattr(self, _field, loaded_dict)
+
+ # expand paths, if not os.makedirs("~/bar") will make directory
+ # in the current directory instead of the actual home
+ # see https://github.com/huggingface/transformers/issues/10628
+ if self.output_dir is not None:
+ self.output_dir = os.path.expanduser(self.output_dir)
+ if self.logging_dir is None and self.output_dir is not None:
+ self.logging_dir = os.path.join(self.output_dir, default_logdir())
+ if self.logging_dir is not None:
+ self.logging_dir = os.path.expanduser(self.logging_dir)
+
+ if self.disable_tqdm is None:
+ self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
+
+ if self.evaluation_strategy is not None:
+ warnings.warn(
+ "`evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead",
+ FutureWarning,
+ )
+ self.eval_strategy = self.evaluation_strategy
+
+ if isinstance(self.eval_strategy, EvaluationStrategy):
+ warnings.warn(
+ "using `EvaluationStrategy` for `eval_strategy` is deprecated and will be removed in version 5"
+ " of 🤗 Transformers. Use `IntervalStrategy` instead",
+ FutureWarning,
+ )
+ # Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it.
+ self.eval_strategy = self.eval_strategy.value
+
+ self.eval_strategy = IntervalStrategy(self.eval_strategy)
+ self.logging_strategy = IntervalStrategy(self.logging_strategy)
+ self.save_strategy = IntervalStrategy(self.save_strategy)
+ self.hub_strategy = HubStrategy(self.hub_strategy)
+
+ self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
+ if self.do_eval is False and self.eval_strategy != IntervalStrategy.NO:
+ self.do_eval = True
+
+ # eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero
+ if self.eval_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0):
+ if self.logging_steps > 0:
+ logger.info(f"using `logging_steps` to initialize `eval_steps` to {self.logging_steps}")
+ self.eval_steps = self.logging_steps
+ else:
+ raise ValueError(
+ f"evaluation strategy {self.eval_strategy} requires either non-zero --eval_steps or"
+ " --logging_steps"
+ )
+
+ # logging_steps must be non-zero for logging_strategy that is other than 'no'
+ if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps == 0:
+ raise ValueError(f"logging strategy {self.logging_strategy} requires non-zero --logging_steps")
+
+ if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps > 1:
+ if self.logging_steps != int(self.logging_steps):
+ raise ValueError(f"--logging_steps must be an integer if bigger than 1: {self.logging_steps}")
+ self.logging_steps = int(self.logging_steps)
+ if self.eval_strategy == IntervalStrategy.STEPS and self.eval_steps > 1:
+ if self.eval_steps != int(self.eval_steps):
+ raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}")
+ self.eval_steps = int(self.eval_steps)
+ if self.save_strategy == IntervalStrategy.STEPS and self.save_steps > 1:
+ if self.save_steps != int(self.save_steps):
+ raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}")
+ self.save_steps = int(self.save_steps)
+
+ # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible.
+ if self.load_best_model_at_end:
+ if self.eval_strategy != self.save_strategy:
+ raise ValueError(
+ "--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation "
+ f"strategy: {self.eval_strategy}\n- Save strategy: {self.save_strategy}"
+ )
+ if self.eval_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0:
+ if self.eval_steps < 1 or self.save_steps < 1:
+ if not (self.eval_steps < 1 and self.save_steps < 1):
+ raise ValueError(
+ "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation "
+ "steps, which cannot get guaranteed when mixing ratio and absolute steps for save_steps "
+ f"{self.save_steps} and eval_steps {self.eval_steps}."
+ )
+ # Work around floating point precision issues
+ LARGE_MULTIPLIER = 1_000_000
+ if (self.save_steps * LARGE_MULTIPLIER) % (self.eval_steps * LARGE_MULTIPLIER) != 0:
+ raise ValueError(
+ "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation "
+ f"steps, but found {self.save_steps}, which is not a multiple of {self.eval_steps}."
+ )
+ raise ValueError(
+ "--load_best_model_at_end requires the saving steps to be a round multiple of the evaluation "
+ f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}."
+ )
+
+ safetensors_available = is_safetensors_available()
+ if self.save_safetensors and not safetensors_available:
+ raise ValueError(f"--save_safetensors={self.save_safetensors} requires safetensors to be installed!")
+ if not self.save_safetensors and safetensors_available:
+ logger.info(
+ f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. "
+ f"Safetensors should be a preferred weights saving format due to security and performance reasons. "
+ f"If your model cannot be saved by safetensors please feel free to open an issue at "
+ f"https://github.com/huggingface/safetensors!"
+ )
+
+ if (
+ self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU
+ ) and self.metric_for_best_model is None:
+ self.metric_for_best_model = "loss"
+ if self.greater_is_better is None and self.metric_for_best_model is not None:
+ self.greater_is_better = not (self.metric_for_best_model.endswith("loss"))
+ if self.run_name is None:
+ self.run_name = self.output_dir
+
+ if self.fp16 and self.bf16:
+ raise ValueError("At most one of fp16 and bf16 can be True, but not both")
+
+ if self.fp16_full_eval and self.bf16_full_eval:
+ raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both")
+
+ if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
+ if self.eval_strategy == IntervalStrategy.NO:
+ raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy")
+
+ self.optim = OptimizerNames(self.optim)
+ if self.adafactor:
+ warnings.warn(
+ "`--adafactor` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--optim"
+ " adafactor` instead",
+ FutureWarning,
+ )
+ self.optim = OptimizerNames.ADAFACTOR
+ if self.zero_stage is not None:
+ if self.zero_stage not in [1, 2]:
+ raise NotImplementedError
+ zero_stage_2_optim = {1: OptimizerNames.ADAMW_ZERO1_MINDSPORE, 2: OptimizerNames.ADAMW_ZERO2_MINDSPORE}
+ if self.optim in [
+ OptimizerNames.ADAMW_MINDSPORE,
+ OptimizerNames.ADAMW_ZERO1_MINDSPORE,
+ OptimizerNames.ADAMW_ZERO2_MINDSPORE,
+ ]:
+ optim = zero_stage_2_optim[self.zero_stage]
+ warnings.warn(f"`--zero_stage` is {self.zero_stage}, replace {self.optim} with {optim}.")
+ self.optim = optim
+
+ if self.framework == "mindspore" and self.tf32 is not None:
+ if self.tf32:
+ raise NotImplementedError
+
+ # FIXME: delete it later if not available
+ # if training args is specified, it will override the one specified in the accelerate config
+ mixed_precision_dtype = os.environ.get("MINDSPORE_MIXED_PRECISION", "no")
+ self.input_dtype = ms.float32
+ if self.fp16:
+ mixed_precision_dtype = "fp16"
+ self.input_dtype = ms.float16
+ elif self.bf16:
+ mixed_precision_dtype = "bf16"
+ self.input_dtype = ms.bfloat16
+ os.environ["MINDSPORE_MIXED_PRECISION"] = mixed_precision_dtype
+
+ if self.report_to is None:
+ # logger.info(
+ # "The default value for the training argument `--report_to` will change in v5 (from all installed "
+ # "integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as "
+ # "now. You should start updating your code and make this info disappear :-)."
+ # )
+ # self.report_to = "all"
+ self.report_to = []
+
+ if self.warmup_ratio < 0 or self.warmup_ratio > 1:
+ raise ValueError("warmup_ratio must lie in range [0,1]")
+ elif self.warmup_ratio > 0 and self.warmup_steps > 0:
+ logger.info(
+ "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio"
+ " during training"
+ )
+
+ if not isinstance(self.warmup_steps, int) or self.warmup_steps < 0 or 0 < self.warmup_steps <= 1:
+ raise ValueError("warmup_steps must be either 0 or > 1")
+
+ if isinstance(self.debug, str):
+ self.debug = [DebugOption(s) for s in self.debug.split()]
+ elif self.debug is None:
+ self.debug = []
+
+ if self.push_to_hub_token is not None:
+ raise NotImplementedError
+
+ if self.push_to_hub_model_id is not None:
+ raise NotImplementedError
+ elif self.push_to_hub_organization is not None:
+ raise NotImplementedError
+
+ @property
+ def train_batch_size(self) -> int:
+ per_device_batch_size = self.per_device_train_batch_size
+ train_batch_size = per_device_batch_size * max(1, self.n_gpu)
+ return train_batch_size
+
+ @property
+ def eval_batch_size(self) -> int:
+ per_device_batch_size = self.per_device_eval_batch_size
+ eval_batch_size = per_device_batch_size * max(1, self.n_gpu)
+ return eval_batch_size
+
+ @cached_property
+ def _setup_devices(self):
+ self._n_gpu = get_group_size() if _is_parallel() else 1
+
+ @property
+ def n_gpu(self):
+ """
+ The number of GPUs used by this process.
+
+ Note:
+ This will only be greater than one when you have multiple NPUs/GPUs available but are not using distributed
+ training. For distributed training, it will always be 1.
+ """
+ # Make sure `self._n_gpu` is properly setup.
+ if not hasattr(self, "_n_gpu"):
+ _ = self._setup_devices
+ return self._n_gpu
+
+ @property
+ def parallel_mode(self):
+ """
+ The current mode used for parallelism if multiple NPUs/GPUs/TPU cores are available.
+ """
+ if self.n_gpu > 1:
+ return ParallelMode.MINDSPORE_DATA_PARALLEL
+ else:
+ return ParallelMode.STAND_ALONE
+
+ @property
+ def world_size(self):
+ """
+ The number of processes used in parallel.
+ """
+ if self.framework == "mindspore":
+ if _is_parallel():
+ return get_group_size()
+ else:
+ raise NotImplementedError
+
+ return 1
+
+ @property
+ def process_index(self):
+ """
+ The index of the current process used.
+ """
+ if self.framework == "mindspore":
+ if _is_parallel():
+ return get_rank()
+ else:
+ raise NotImplementedError
+
+ return 0
+
+ @property
+ def local_process_index(self):
+ """
+ The index of the local process used.
+ """
+ if self.framework == "mindspore":
+ if _is_parallel():
+ return get_rank() % 8
+ else:
+ raise NotImplementedError
+
+ return 0
+
+ def get_process_log_level(self):
+ """
+ Returns the log level to be used depending on whether this process is the main process of node 0, main process
+ of node non-0, or a non-main process.
+
+ For the main process the log level defaults to the logging level set (`logging.WARNING` if you didn't do
+ anything) unless overridden by `log_level` argument.
+
+ For the replica processes the log level defaults to `logging.WARNING` unless overridden by `log_level_replica`
+ argument.
+
+ The choice between the main and replica process settings is made according to the return value of `should_log`.
+ """
+
+ # convert to int
+ log_level = trainer_log_levels[self.log_level]
+ log_level_replica = trainer_log_levels[self.log_level_replica]
+
+ log_level_main_node = logging.get_verbosity() if log_level == -1 else log_level
+ log_level_replica_node = logging.get_verbosity() if log_level_replica == -1 else log_level_replica
+ return log_level_main_node if self.should_log else log_level_replica_node
+
+ @property
+ def should_log(self):
+ """
+ Whether or not the current process should produce log.
+ """
+ if self.log_on_each_node:
+ return self.local_process_index == 0
+ else:
+ return self.process_index == 0
+
+ @property
+ def should_save(self):
+ """
+ Whether or not the current process should write to disk, e.g., to save models and checkpoints.
+ """
+ if self.save_on_each_node:
+ return self.local_process_index == 0
+ else:
+ return self.process_index == 0
+
+ def get_warmup_steps(self, num_training_steps: int):
+ """
+ Get number of steps used for a linear warmup.
+ """
+ warmup_steps = self.warmup_steps if self.warmup_steps > 0 else math.ceil(num_training_steps * self.warmup_ratio)
+ return warmup_steps
+
+
+class ParallelMode(Enum):
+ STAND_ALONE = "stand_alone"
+ MINDSPORE_MODEL_PARALLEL = "mindspore_model_parallel"
+ MINDSPORE_DATA_PARALLEL = "mindspore_data_parallel"
diff --git a/mindone/transformers/utils/__init__.py b/mindone/transformers/utils/__init__.py
index 8a0428491f..d3e5e6c5e8 100644
--- a/mindone/transformers/utils/__init__.py
+++ b/mindone/transformers/utils/__init__.py
@@ -1,4 +1,6 @@
-from .backbone_utils import BackboneMixin
+from .backbone_utils import *
+from .generic import *
+from .import_utils import *
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
diff --git a/mindone/transformers/utils/generic.py b/mindone/transformers/utils/generic.py
new file mode 100644
index 0000000000..0a8bfa82d4
--- /dev/null
+++ b/mindone/transformers/utils/generic.py
@@ -0,0 +1,33 @@
+import inspect
+
+
+def can_return_loss(model_class):
+ """
+ Check if a given model can return loss.
+
+ Args:
+ model_class (`type`): The class of the model.
+ """
+ signature = inspect.signature(model_class.construct) # MindSpore models
+
+ for p in signature.parameters:
+ if p == "return_loss" and signature.parameters[p].default is True:
+ return True
+
+ return False
+
+
+def find_labels(model_class):
+ """
+ Find the labels used by a given model.
+
+ Args:
+ model_class (`type`): The class of the model.
+ """
+ model_name = model_class.__name__
+ signature = inspect.signature(model_class.construct) # MindSpore models
+
+ if "QuestionAnswering" in model_name:
+ return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]
+ else:
+ return [p for p in signature.parameters if "label" in p]
diff --git a/mindone/transformers/utils/import_utils.py b/mindone/transformers/utils/import_utils.py
new file mode 100644
index 0000000000..090de3eca0
--- /dev/null
+++ b/mindone/transformers/utils/import_utils.py
@@ -0,0 +1,12 @@
+from ..mindspore_adapter.utils import _is_ascend
+
+
+def is_flash_attn_2_available():
+ if _is_ascend():
+ return True
+
+ return False
+
+
+def is_sdpa_available():
+ return False
diff --git a/mkdocs.yml b/mkdocs.yml
index 253c0ecc5b..b7418b3121 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -1,4 +1,5 @@
site_name: MindOne - One for All
+
site_url: https://mindspore-lab.github.io/mindone
repo_url: https://github.com/mindspore-lab/mindone
repo_name: mindspore-lab/mindone