diff --git a/docs/superbench-config.mdx b/docs/superbench-config.mdx
index 051abeda3..7bc8748a6 100644
--- a/docs/superbench-config.mdx
+++ b/docs/superbench-config.mdx
@@ -329,7 +329,8 @@ A list of models to run, only supported in model-benchmark.
     squeezenet1_0 | squeezenet1_1 | vgg11 | vgg11_bn | vgg13 | vgg13_bn |
     vgg16 | vgg16_bn | vgg19_bn | vgg19 |
     bert-base | bert-large | gpt2-small | gpt2-medium | gpt2-large | gpt2-xl |
-    llama2-7b | llama2-13b | llama2-70b ]
+    llama2-7b | llama2-13b | llama2-70b |
+    mixtral-8x7b | mixtral-8x22b ]
   ```
 
 * default value: `[ ]`
diff --git a/docs/user-tutorial/benchmarks/model-benchmarks.md b/docs/user-tutorial/benchmarks/model-benchmarks.md
index 71e8832cf..ba89ed6ff 100644
--- a/docs/user-tutorial/benchmarks/model-benchmarks.md
+++ b/docs/user-tutorial/benchmarks/model-benchmarks.md
@@ -14,6 +14,7 @@ Run training or inference tasks with single or half precision for deep learning
 including the following categories:
 * GPT: gpt2-small, gpt2-medium, gpt2-large and gpt2-xl
 * LLAMA: llama2-7b, llama2-13b, llama2-70b
+* MoE: mixtral-8x7b, mixtral-8x22b
 * BERT: bert-base and bert-large
 * LSTM
 * CNN, listed in [`torchvision.models`](https://pytorch.org/vision/0.8/models.html), including:
diff --git a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
index 0f28f4f6a..3ae9dd3de 100644
--- a/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
+++ b/superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
@@ -3,7 +3,9 @@
 
 """Export PyTorch models to ONNX format."""
 
+import sys
 from pathlib import Path
+from typing import Optional
 
 from packaging import version
 import torch.hub
@@ -16,6 +18,12 @@
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
 
+# Check Python version and skip Mixtral if Python is 3.7 or lower
+MixtralBenchmarkModel: Optional[type] = None
+if sys.version_info >= (3, 8):
+    from transformers import MixtralConfig
+    from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel
+
 
 class torch2onnxExporter():
     """PyTorch model to ONNX exporter."""
@@ -122,6 +130,40 @@ def __init__(self):
                 self.num_classes,
             ),
         }
+
+        # Only include Mixtral models if MixtralBenchmarkModel is available
+        if MixtralBenchmarkModel is not None:
+            self.benchmark_models.update(
+                {
+                    'mixtral-8x7b':
+                    lambda: MixtralBenchmarkModel(
+                        MixtralConfig(
+                            hidden_size=4096,
+                            num_hidden_layers=32,
+                            num_attention_heads=32,
+                            num_key_value_heads=8,
+                            intermediate_size=14336,
+                            max_position_embeddings=32768,
+                            router_aux_loss_coef=0.02,
+                        ),
+                        self.num_classes,
+                    ),
+                    'mixtral-8x22b':
+                    lambda: MixtralBenchmarkModel(
+                        MixtralConfig(
+                            hidden_size=6144,
+                            num_hidden_layers=56,
+                            num_attention_heads=48,
+                            num_key_value_heads=8,
+                            intermediate_size=16384,
+                            max_position_embeddings=65536,
+                            router_aux_loss_coef=0.001,
+                        ),
+                        self.num_classes,
+                    ),
+                }
+            )
+
         self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
         self._onnx_model_path.mkdir(parents=True, exist_ok=True)
diff --git a/superbench/benchmarks/model_benchmarks/__init__.py b/superbench/benchmarks/model_benchmarks/__init__.py
index 0829c4d33..b0c102ca7 100644
--- a/superbench/benchmarks/model_benchmarks/__init__.py
+++ b/superbench/benchmarks/model_benchmarks/__init__.py
@@ -3,11 +3,23 @@
 
 """A module containing all the e2e model related benchmarks."""
 
+import sys
+from typing import Optional
+
 from superbench.benchmarks.model_benchmarks.model_base import ModelBenchmark
 from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT
 from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import PytorchGPT2
 from superbench.benchmarks.model_benchmarks.pytorch_cnn import PytorchCNN
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
 from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT
+from superbench.benchmarks.model_benchmarks.pytorch_llama import PytorchLlama
+
+# Check for Python version > 3.7 and conditionally import PytorchMixtral
+PytorchMixtral: Optional[type] = None
+if sys.version_info >= (3, 8):
+    from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral
 
 __all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama']
+
+if PytorchMixtral is not None:
+    __all__.append('PytorchMixtral')
diff --git a/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py b/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
new file mode 100644
index 000000000..6a3d49995
--- /dev/null
+++ b/superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
@@ -0,0 +1,273 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""Module of the Pytorch Mixtral model."""
+
+import torch
+from transformers import MixtralModel, MixtralConfig
+try:
+    import transformer_engine.pytorch as te
+    from transformer_engine.common.recipe import Format, DelayedScaling
+except ImportError:
+    te = None
+
+from superbench.common.utils import logger
+from superbench.benchmarks import BenchmarkRegistry, Precision
+from superbench.benchmarks.model_benchmarks.model_base import Optimizer
+from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
+from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset
+
+
+class MixtralBenchmarkModel(torch.nn.Module):
+    """The Mixtral model for benchmarking."""
+    def __init__(self, config, num_classes):
+        """Constructor.
+
+        Args:
+            config (MixtralConfig): Configurations of Mixtral model.
+            num_classes (int): The number of objects for classification.
+        """
+        super().__init__()
+        self._Mixtral = MixtralModel(config)
+        self._linear = torch.nn.Linear(config.hidden_size, num_classes)
+
+    def forward(self, input):
+        """Forward propagation function.
+
+        Args:
+            input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
+              shape (batch_size, sequence_length).
+
+        Return:
+            result (torch.FloatTensor): The last layer hidden-states projected by a Linear layer
+              into per-token class logits, shape (batch_size, sequence_length, num_classes).
+        """
+        outputs = self._Mixtral(input)
+        result = self._linear(outputs[0])
+        return result
+
+
+class PytorchMixtral(PytorchBase):
+    """The Mixtral benchmark class."""
+    def __init__(self, name, parameters=''):
+        """Constructor.
+
+        Args:
+            name (str): benchmark name.
+            parameters (str): benchmark parameters.
+        """
+        super().__init__(name, parameters)
+        self._config = None
+        self._fp8_recipe = None
+        self._supported_precision = [
+            Precision.FLOAT32,
+            Precision.FLOAT16,
+            Precision.FP8_HYBRID,
+            Precision.FP8_E4M3,
+        ]
+        self._optimizer_type = Optimizer.ADAMW
+        self._loss_fn = torch.nn.CrossEntropyLoss()
+
+    def add_parser_arguments(self):
+        """Add the Mixtral-specific arguments.
+
+        Mixtral model reference: https://huggingface.co/docs/transformers/model_doc/Mixtral
+        """
+        super().add_parser_arguments()
+
+        self._parser.add_argument('--num_classes', type=int, default=100, required=False, help='Number of classes.')
+        self._parser.add_argument('--hidden_size', type=int, default=4096, required=False, help='Hidden size.')
+        self._parser.add_argument(
+            '--num_hidden_layers', type=int, default=32, required=False, help='The number of hidden layers.'
+        )
+        self._parser.add_argument(
+            '--num_attention_heads', type=int, default=32, required=False, help='The number of attention heads.'
+        )
+        self._parser.add_argument(
+            '--intermediate_size',
+            type=int,
+            default=14336,
+            required=False,
+            help='Dimension of the MLP representations.'
+        )
+        self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.')
+        self._parser.add_argument(
+            '--num_key_value_heads',
+            type=int,
+            default=8,
+            required=False,
+            help='The number of key_value heads that should be used to implement Grouped Query Attention.'
+        )
+        self._parser.add_argument(
+            '--max_position_embeddings',
+            type=int,
+            default=None,
+            required=False,
+            help='Maximum sequence length that Mixtral supports.'
+        )
+        self._parser.add_argument(
+            '--router_aux_loss_coef',
+            type=float,
+            default=0.001,
+            required=False,
+            help='The aux loss factor for the total loss.'
+        )
+
+    def _generate_dataset(self):
+        """Generate dataset for benchmarking according to shape info.
+
+        Return:
+            True if dataset is created successfully.
+        """
+        self._dataset = TorchRandomDataset(
+            [self._args.sample_count, self._args.seq_len], self._world_size, dtype=torch.long
+        )
+        if len(self._dataset) == 0:
+            logger.error('Generate random dataset failed - model: {}'.format(self._name))
+            return False
+
+        return True
+
+    def _create_model(self, precision):
+        """Construct the model for benchmarking.
+
+        Args:
+            precision (Precision): precision of model and input data, such as float32, float16.
+        """
+        self._config = MixtralConfig(
+            hidden_size=self._args.hidden_size,
+            num_hidden_layers=self._args.num_hidden_layers,
+            num_attention_heads=self._args.num_attention_heads,
+            num_key_value_heads=self._args.num_key_value_heads,
+            intermediate_size=self._args.intermediate_size,
+            max_position_embeddings=self._args.max_position_embeddings,
+            router_aux_loss_coef=self._args.router_aux_loss_coef,
+        )
+
+        enable_fp8 = precision.name.startswith('FP8_')
+        if enable_fp8 and te is None:
+            logger.error(
+                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
+                ' message: Cannot find transformer_engine.'
+            )
+            return False
+        if enable_fp8 and not self._gpu_available:
+            logger.error(
+                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
+                ' message: FP8 is only supported on GPU.'
+            )
+            return False
+
+        try:
+            self._model = MixtralBenchmarkModel(self._config, self._args.num_classes)
+            if enable_fp8:
+                self._fp8_recipe = DelayedScaling(
+                    fp8_format=Format[precision.name.strip('FP8_')],
+                    amax_history_len=16,
+                    amax_compute_algo='max',
+                )
+                self._to_te_model(self._model.to(dtype=torch.float16))
+            else:
+                self._model = self._model.to(dtype=getattr(torch, precision.value))
+            if self._gpu_available:
+                self._model = self._model.cuda()
+        except BaseException as e:
+            logger.error(
+                'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format(
+                    self._name, precision, str(e)
+                )
+            )
+            return False
+
+        self._target = torch.LongTensor(self._args.batch_size).random_(self._args.num_classes)
+        if self._gpu_available:
+            self._target = self._target.cuda()
+
+        return True
+
+    def _train_step(self, precision):
+        """Define the training process.
+
+        Args:
+            precision (Precision): precision of model and input data, such as float32, float16.
+
+        Return:
+            The step-time list of every training step.
+        """
+        duration = []
+        curr_step = 0
+        check_frequency = 100
+        while True:
+            for idx, sample in enumerate(self._dataloader):
+                start = self._timer()
+                if self._gpu_available:
+                    sample = sample.cuda()
+                self._optimizer.zero_grad()
+                if self._fp8_recipe is not None:
+                    with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
+                        output = self._model(sample)
+                else:
+                    output = self._model(sample)
+                loss = self._loss_fn(output[range(self._args.batch_size), -1], self._target)
+                loss.backward()
+                self._optimizer.step()
+                end = self._timer()
+                curr_step += 1
+                if curr_step > self._args.num_warmup:
+                    # Save the step time of every training/inference step, unit is millisecond.
+                    duration.append((end - start) * 1000)
+                    self._log_step_time(curr_step, precision, duration)
+                if self._is_finished(curr_step, end, check_frequency):
+                    return duration
+
+    def _inference_step(self, precision):
+        """Define the inference process.
+
+        Args:
+            precision (Precision): precision of model and input data,
+              such as float32, float16.
+
+        Return:
+            The latency list of every inference operation.
+        """
+        duration = []
+        curr_step = 0
+        with torch.no_grad():
+            self._model.eval()
+            while True:
+                for idx, sample in enumerate(self._dataloader):
+                    start = self._timer()
+                    if self._gpu_available:
+                        sample = sample.cuda()
+                    if self._fp8_recipe is not None:
+                        with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
+                            self._model(sample)
+                    else:
+                        self._model(sample)
+                    end = self._timer()
+                    curr_step += 1
+                    if curr_step > self._args.num_warmup:
+                        # Save the step time of every training/inference step, unit is millisecond.
+                        duration.append((end - start) * 1000)
+                        self._log_step_time(curr_step, precision, duration)
+                    if self._is_finished(curr_step, end):
+                        return duration
+
+
+# Register Mixtral benchmark with 8x7b parameters.
+# Ref: https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
+BenchmarkRegistry.register_benchmark(
+    'pytorch-mixtral-8x7b',
+    PytorchMixtral,
+    parameters='--hidden_size=4096 --num_hidden_layers=32 --num_attention_heads=32 --intermediate_size=14336 \
+        --num_key_value_heads=8 --max_position_embeddings=32768 --router_aux_loss_coef=0.02'
+)
+
+# Register Mixtral benchmark with 8x22b parameters.
+# Ref: https://huggingface.co/mistralai/Mixtral-8x22B-v0.1/blob/main/config.json
+BenchmarkRegistry.register_benchmark(
+    'pytorch-mixtral-8x22b',
+    PytorchMixtral,
+    parameters='--hidden_size=6144 --num_hidden_layers=56 --num_attention_heads=48 --intermediate_size=16384 \
+        --num_key_value_heads=8 --max_position_embeddings=65536 --router_aux_loss_coef=0.001'
+)
diff --git a/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
new file mode 100644
index 000000000..6e028d10d
--- /dev/null
+++ b/tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py
@@ -0,0 +1,67 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Tests for mixtral model benchmarks."""
+
+import sys
+
+from tests.helper import decorator
+from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode
+
+# Check for Python version 3.8 or greater and conditionally import PytorchMixtral
+if sys.version_info >= (3, 8):
+    from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral
+
+
+@decorator.cuda_test
+@decorator.pytorch_test
+def test_pytorch_mixtral_8x7b():
+    """Test pytorch-mixtral-8x7b benchmark for float16 train and inference."""
+    context = BenchmarkRegistry.create_benchmark_context(
+        'mixtral-8x7b',
+        platform=Platform.CUDA,
+        parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision float16 \
+            --hidden_size 1024 --max_position_embeddings 2048 --intermediate_size 3584 \
+            --model_action train inference',
+        framework=Framework.PYTORCH
+    )
+
+    assert (BenchmarkRegistry.is_benchmark_context_valid(context))
+
+    benchmark = BenchmarkRegistry.launch_benchmark(context)
+
+    # Check basic information.
+    assert (benchmark)
+    assert (isinstance(benchmark, PytorchMixtral))
+    assert (benchmark.name == 'pytorch-mixtral-8x7b')
+    assert (benchmark.type == BenchmarkType.MODEL)
+
+    # Check predefined parameters of mixtral-8x7b model.
+    assert (benchmark._args.hidden_size == 1024)
+    assert (benchmark._args.num_hidden_layers == 32)
+    assert (benchmark._args.num_attention_heads == 32)
+    assert (benchmark._args.num_key_value_heads == 8)
+    assert (benchmark._args.intermediate_size == 3584)
+    assert (benchmark._args.max_position_embeddings == 2048)
+    assert (benchmark._args.router_aux_loss_coef == 0.02)
+
+    # Check parameters specified in BenchmarkContext.
+    assert (benchmark._args.batch_size == 1)
+    assert (benchmark._args.num_classes == 100)
+    assert (benchmark._args.seq_len == 32)
+    assert (benchmark._args.num_warmup == 1)
+    assert (benchmark._args.num_steps == 2)
+
+    # Test Dataset.
+    assert (len(benchmark._dataset) == benchmark._args.sample_count * benchmark._world_size)
+
+    # Check results and metrics.
+    assert (benchmark.run_count == 1)
+    assert (benchmark.return_code == ReturnCode.SUCCESS)
+
+    for metric in [
+        'fp16_train_step_time', 'fp16_train_throughput', 'fp16_inference_step_time', 'fp16_inference_throughput'
+    ]:
+        assert (len(benchmark.raw_data[metric]) == benchmark.run_count)
+        assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps)
+        assert (len(benchmark.result[metric]) == benchmark.run_count)
diff --git a/tests/helper/decorator.py b/tests/helper/decorator.py
index ff08469ac..b626bb951 100644
--- a/tests/helper/decorator.py
+++ b/tests/helper/decorator.py
@@ -4,6 +4,7 @@
 """Unittest decorator helpers."""
 
 import os
+import sys
 import unittest
 import functools
 from pathlib import Path
@@ -12,6 +13,7 @@
 rocm_test = unittest.skipIf(os.environ.get('SB_TEST_ROCM', '0') == '0', 'Skip ROCm tests.')
 pytorch_test = unittest.skipIf(os.environ.get('SB_TEST_PYTORCH', '1') == '0', 'Skip PyTorch tests.')
+python_eol_test = unittest.skipIf(sys.version_info < (3, 8), 'Skip tests for Python 3.7 or lower.')
 directx_test = unittest.skipIf(os.environ.get('SB_TEST_DIRECTX', '0') == '0', 'Skip DirectX tests.')
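For a quick local sanity check outside the unittest harness, the registry entry added in this patch can be driven directly from Python. The snippet below is a minimal sketch that mirrors tests/benchmarks/model_benchmarks/test_pytorch_mixtral.py; it assumes Python 3.8 or later, a CUDA-capable GPU, and a transformers release that ships MixtralModel/MixtralConfig. The reduced hidden_size, max_position_embeddings, and intermediate_size overrides are illustrative only, not the real 8x7B configuration.

from superbench.benchmarks import BenchmarkRegistry, Platform, Framework

# Build a context for the newly registered 'pytorch-mixtral-8x7b' benchmark.
# The overrides shrink the model so it fits on a single development GPU.
context = BenchmarkRegistry.create_benchmark_context(
    'mixtral-8x7b',
    platform=Platform.CUDA,
    parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision float16 \
        --hidden_size 1024 --max_position_embeddings 2048 --intermediate_size 3584 \
        --model_action train inference',
    framework=Framework.PYTORCH
)

if BenchmarkRegistry.is_benchmark_context_valid(context):
    benchmark = BenchmarkRegistry.launch_benchmark(context)
    # On a successful run, per-run metrics such as 'fp16_train_step_time' and
    # 'fp16_train_throughput' are collected in benchmark.result.
    print(benchmark.return_code)
    print(benchmark.result)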