Benchmarks: Add Mixture of Experts Model #679

Open · wants to merge 19 commits into main
3 changes: 2 additions & 1 deletion docs/superbench-config.mdx
@@ -329,7 +329,8 @@ A list of models to run, only supported in model-benchmark.
squeezenet1_0 | squeezenet1_1 |
vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19_bn | vgg19 |
bert-base | bert-large | gpt2-small | gpt2-medium | gpt2-large | gpt2-xl |
llama2-7b | llama2-13b | llama2-70b ]
llama2-7b | llama2-13b | llama2-70b |
mixtral-8x7b | mixtral-8x22b ]
```
* default value: `[ ]`

1 change: 1 addition & 0 deletions docs/user-tutorial/benchmarks/model-benchmarks.md
@@ -14,6 +14,7 @@ Run training or inference tasks with single or half precision for deep learning
including the following categories:
* GPT: gpt2-small, gpt2-medium, gpt2-large and gpt2-xl
* LLAMA: llama2-7b, llama2-13b, llama2-70b
* MoE: mixtral-8x7b, mixtral-8x22b
* BERT: bert-base and bert-large
* LSTM
* CNN, listed in [`torchvision.models`](https://pytorch.org/vision/0.8/models.html), including:
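In a superbench config, the new names simply join the `models` list of a model-benchmark entry. For trying the models programmatically, here is a minimal, hedged launch sketch in the style of SuperBench's other example scripts; `create_benchmark_context`, `launch_benchmark`, `Platform`, and `Framework` come from the existing codebase rather than this diff, and the parameters are illustrative only (an 8x7b run needs a large GPU):

```python
# Hedged sketch only: helper names outside this diff are assumed from the
# existing SuperBench example scripts; parameters are illustrative.
from superbench.benchmarks import BenchmarkRegistry, Framework, Platform
from superbench.common.utils import logger

if __name__ == '__main__':
    context = BenchmarkRegistry.create_benchmark_context(
        'mixtral-8x7b',
        platform=Platform.CUDA,
        parameters='--batch_size 1 --seq_len 32 --num_warmup 8 --precision float16',
        framework=Framework.PYTORCH,
    )
    benchmark = BenchmarkRegistry.launch_benchmark(context)
    if benchmark:
        logger.info(
            'benchmark: {}, return code: {}, result: {}'.format(
                benchmark.name, benchmark.return_code, benchmark.result
            )
        )
```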
42 changes: 42 additions & 0 deletions superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
@@ -3,7 +3,9 @@

"""Export PyTorch models to ONNX format."""

import sys
from pathlib import Path
from typing import Optional

from packaging import version
import torch.hub
@@ -16,6 +18,12 @@
from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel

# Check Python version and skip Mixtral if Python is 3.7 or lower
MixtralBenchmarkModel: Optional[type] = None
if sys.version_info >= (3, 8):
from transformers import MixtralConfig
from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel


class torch2onnxExporter():
"""PyTorch model to ONNX exporter."""
@@ -122,6 +130,40 @@ def __init__(self):
self.num_classes,
),
}

# Only include Mixtral models if MixtralBenchmarkModel is available
if MixtralBenchmarkModel is not None:
self.benchmark_models.update(
{
'mixtral-8x7b':
lambda: MixtralBenchmarkModel(
MixtralConfig(
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
intermediate_size=14336,
max_position_embeddings=32768,
router_aux_loss_coef=0.02,
),
self.num_classes,
),
'mixtral-8x22b':
lambda: MixtralBenchmarkModel(
MixtralConfig(
hidden_size=6144,
num_hidden_layers=56,
num_attention_heads=48,
num_key_value_heads=8,
intermediate_size=16384,
max_position_embeddings=65536,
router_aux_loss_coef=0.001,
),
self.num_classes,
),
}
)

self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
self._onnx_model_path.mkdir(parents=True, exist_ok=True)

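A hedged sketch of what the conditional registration above means for the exporter: on Python >= 3.8 (with a transformers release that ships Mixtral) the lookup table gains the two MoE entries, otherwise it silently does not. The dict stores lambdas, so the check below never builds the huge models:

```python
# Hedged sketch: membership check only; calling a stored lambda would
# construct the full multi-billion-parameter model.
from superbench.benchmarks.micro_benchmarks._export_torch_to_onnx import torch2onnxExporter

exporter = torch2onnxExporter()
print('mixtral-8x7b' in exporter.benchmark_models)   # True on Python >= 3.8
print('mixtral-8x22b' in exporter.benchmark_models)  # True on Python >= 3.8
```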
12 changes: 12 additions & 0 deletions superbench/benchmarks/model_benchmarks/__init__.py
@@ -3,11 +3,23 @@

"""A module containing all the e2e model related benchmarks."""

import sys
from typing import Optional

from superbench.benchmarks.model_benchmarks.model_base import ModelBenchmark
from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import PytorchGPT2
from superbench.benchmarks.model_benchmarks.pytorch_cnn import PytorchCNN
from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT
from superbench.benchmarks.model_benchmarks.pytorch_llama import PytorchLlama

# Conditionally import PytorchMixtral; Mixtral support requires Python 3.8 or higher
PytorchMixtral: Optional[type] = None
if sys.version_info >= (3, 8):
from superbench.benchmarks.model_benchmarks.pytorch_mixtral import PytorchMixtral

__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama']

if PytorchMixtral is not None:
__all__.append('PytorchMixtral')
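A hedged sketch of how downstream code can tell whether the class was exported; on Python < 3.8 the module attribute stays `None` and the name is left out of `__all__`:

```python
# Hedged sketch of the consumer-side check for the conditional export above.
import superbench.benchmarks.model_benchmarks as model_benchmarks

if getattr(model_benchmarks, 'PytorchMixtral', None) is not None:
    print('PytorchMixtral exported:', 'PytorchMixtral' in model_benchmarks.__all__)
else:
    print('PytorchMixtral unavailable on this Python version.')
```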
273 changes: 273 additions & 0 deletions superbench/benchmarks/model_benchmarks/pytorch_mixtral.py
@@ -0,0 +1,273 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the Pytorch Mixtral model."""

import torch
from transformers import MixtralModel, MixtralConfig
try:
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
except ImportError:
te = None

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Precision
from superbench.benchmarks.model_benchmarks.model_base import Optimizer
from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset


class MixtralBenchmarkModel(torch.nn.Module):
"""The Mixtral model for benchmarking."""
def __init__(self, config, num_classes):
"""Constructor.

Args:
config (MixtralConfig): Configurations of Mixtral model.
num_classes (int): The number of objects for classification.
"""
super().__init__()
self._Mixtral = MixtralModel(config)
self._linear = torch.nn.Linear(config.hidden_size, num_classes)

def forward(self, input):
"""Forward propagation function.

Args:
input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
shape (batch_size, sequence_length).

Return:
result (torch.FloatTensor): The last-layer hidden states of all tokens, further processed
by a Linear classification layer, shape (batch_size, sequence_length, num_classes).
"""
outputs = self._Mixtral(input)
result = self._linear(outputs[0])
return result


class PytorchMixtral(PytorchBase):
"""The Mixtral benchmark class."""
def __init__(self, name, parameters=''):
"""Constructor.

Args:
name (str): benchmark name.
parameters (str): benchmark parameters.
"""
super().__init__(name, parameters)
self._config = None
self._fp8_recipe = None
self._supported_precision = [
Precision.FLOAT32,
Precision.FLOAT16,
Precision.FP8_HYBRID,
Precision.FP8_E4M3,
]
self._optimizer_type = Optimizer.ADAMW
self._loss_fn = torch.nn.CrossEntropyLoss()

def add_parser_arguments(self):
"""Add the Mixtral-specified arguments.

Mixtral model reference: https://huggingface.co/docs/transformers/model_doc/Mixtral
"""
super().add_parser_arguments()

self._parser.add_argument('--num_classes', type=int, default=100, required=False, help='Number of classes.')
self._parser.add_argument('--hidden_size', type=int, default=4096, required=False, help='Hidden size.')
self._parser.add_argument(
'--num_hidden_layers', type=int, default=32, required=False, help='The number of hidden layers.'
)
self._parser.add_argument(
'--num_attention_heads', type=int, default=32, required=False, help='The number of attention heads.'
)
self._parser.add_argument(
'--intermediate_size',
type=int,
default=14336,
required=False,
help='Dimension of the MLP representations.'
)
self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.')
self._parser.add_argument(
'--num_key_value_heads',
type=int,
default=8,
required=False,
help='The number of key_value heads that should be used to implement Grouped Query Attention.'
)
self._parser.add_argument(
'--max_position_embeddings',
type=int,
default=None,
required=False,
help='Maximum sequence length that Mixtral supports.'
)
self._parser.add_argument(
'--router_aux_loss_coef',
type=float,
default=0.001,
required=False,
help='The aux loss factor for the total loss.'
)

def _generate_dataset(self):
"""Generate dataset for benchmarking according to shape info.

Return:
True if dataset is created successfully.
"""
self._dataset = TorchRandomDataset(
[self._args.sample_count, self._args.seq_len], self._world_size, dtype=torch.long
)
if len(self._dataset) == 0:
logger.error('Generate random dataset failed - model: {}'.format(self._name))
return False

return True

def _create_model(self, precision):
"""Construct the model for benchmarking.

Args:
precision (Precision): precision of model and input data, such as float32, float16.
"""
self._config = MixtralConfig(
hidden_size=self._args.hidden_size,
num_hidden_layers=self._args.num_hidden_layers,
num_attention_heads=self._args.num_attention_heads,
num_key_value_heads=self._args.num_key_value_heads,
intermediate_size=self._args.intermediate_size,
max_position_embeddings=self._args.max_position_embeddings,
router_aux_loss_coef=self._args.router_aux_loss_coef,
)

enable_fp8 = precision.name.startswith('FP8_')
if enable_fp8 and te is None:
logger.error(
f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
' message: Cannot find transformer_engine.'
)
return False
if enable_fp8 and not self._gpu_available:
logger.error(
f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
' message: FP8 is only supported on GPU.'
)
return False

try:
self._model = MixtralBenchmarkModel(self._config, self._args.num_classes)
if enable_fp8:
self._fp8_recipe = DelayedScaling(
fp8_format=Format[precision.name.strip('FP8_')],
amax_history_len=16,
amax_compute_algo='max',
)
self._to_te_model(self._model.to(dtype=torch.float16))
else:
self._model = self._model.to(dtype=getattr(torch, precision.value))
if self._gpu_available:
self._model = self._model.cuda()
except BaseException as e:
logger.error(
'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format(
self._name, precision, str(e)
)
)
return False

self._target = torch.LongTensor(self._args.batch_size).random_(self._args.num_classes)
if self._gpu_available:
self._target = self._target.cuda()

return True

def _train_step(self, precision):
"""Define the training process.

Args:
precision (Precision): precision of model and input data, such as float32, float16.

Return:
The step-time list of every training step.
"""
duration = []
curr_step = 0
check_frequency = 100
while True:
for idx, sample in enumerate(self._dataloader):
start = self._timer()
if self._gpu_available:
sample = sample.cuda()
self._optimizer.zero_grad()
if self._fp8_recipe is not None:
with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
output = self._model(sample)
else:
output = self._model(sample)
loss = self._loss_fn(output[range(self._args.batch_size), -1], self._target)
loss.backward()
self._optimizer.step()
end = self._timer()
curr_step += 1
if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond.
duration.append((end - start) * 1000)
self._log_step_time(curr_step, precision, duration)
if self._is_finished(curr_step, end, check_frequency):
return duration

def _inference_step(self, precision):
"""Define the inference process.

Args:
precision (Precision): precision of model and input data,
such as float32, float16.

Return:
The latency list of every inference operation.
"""
duration = []
curr_step = 0
with torch.no_grad():
self._model.eval()
while True:
for idx, sample in enumerate(self._dataloader):
start = self._timer()
if self._gpu_available:
sample = sample.cuda()
if self._fp8_recipe is not None:
with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
self._model(sample)
else:
self._model(sample)
end = self._timer()
curr_step += 1
if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond.
duration.append((end - start) * 1000)
self._log_step_time(curr_step, precision, duration)
if self._is_finished(curr_step, end):
return duration


# Register Mixtral benchmark with 8x7b parameters.
# Ref: https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
BenchmarkRegistry.register_benchmark(
'pytorch-mixtral-8x7b',
PytorchMixtral,
parameters='--hidden_size=4096 --num_hidden_layers=32 --num_attention_heads=32 --intermediate_size=14336 \
--num_key_value_heads=8 --max_position_embeddings=32768 --router_aux_loss_coef=0.02'
)

# Register Mixtral benchmark with 8x22b parameters.
# Ref: https://huggingface.co/mistralai/Mixtral-8x22B-v0.1/blob/main/config.json
BenchmarkRegistry.register_benchmark(
'pytorch-mixtral-8x22b',
PytorchMixtral,
parameters='--hidden_size=6144 --num_hidden_layers=56 --num_attention_heads=48 --intermediate_size=16384 \
--num_key_value_heads=8 --max_position_embeddings=65536 --router_aux_loss_coef=0.001'
)
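
To sanity-check the wrapper class without 8x7b-scale weights, a minimal CPU-sized sketch; the tiny config values are illustrative and not part of this PR:

```python
# Hedged sketch: tiny MixtralConfig so the forward pass runs on CPU in seconds.
import torch
from transformers import MixtralConfig

from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel

config = MixtralConfig(
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    intermediate_size=128,
    max_position_embeddings=512,
)
model = MixtralBenchmarkModel(config, num_classes=10)
tokens = torch.randint(0, config.vocab_size, (2, 16), dtype=torch.long)  # (batch, seq_len)
logits = model(tokens)
print(logits.shape)  # torch.Size([2, 16, 10]): per-token scores from the linear head
```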