diff --git a/.azure/hpu-tests.yml b/.azure/hpu-tests.yml
index a35ae0f7..5f76bec6 100644
--- a/.azure/hpu-tests.yml
+++ b/.azure/hpu-tests.yml
@@ -136,7 +136,9 @@ jobs:
   - bash: |
       export PYTHONPATH="${PYTHONPATH}:$(pwd)"
-      python mnist_sample.py
+      python mnist_trainer.py
+      LOWER_LIST=ops_fp32_mnist.txt FP32_LIST=ops_bf16_mnist.txt \
+      python mnist_trainer.py -r autocast
     workingDirectory: examples/pytorch/
     displayName: 'Testing HPU examples'
diff --git a/CHANGELOG.md b/CHANGELOG.md
index af81b034..ecc59a5b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,12 +12,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 - Added tests for mixed precision training ([#36](https://github.com/Lightning-AI/lightning-Habana/pull/36))
+- Added example for mixed precision training ([#54](https://github.com/Lightning-AI/lightning-Habana/pull/54))
-
 ### Changed
 
 - Enabled skipped tests based on registered strategy, accelerator ([#46](https://github.com/Lightning-AI/lightning-Habana/pull/46))
--
 ### Fixed
diff --git a/examples/pytorch/mnist_sample.py b/examples/pytorch/mnist_sample.py
index 75fc6fd0..8f406c12 100644
--- a/examples/pytorch/mnist_sample.py
+++ b/examples/pytorch/mnist_sample.py
@@ -12,24 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import argparse
-
 import torch
 from lightning_utilities import module_available
 from torch.nn import functional as F  # noqa: N812
 
 if module_available("lightning"):
-    from lightning.pytorch import LightningModule, Trainer
-    from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
+    from lightning.pytorch import LightningModule
 elif module_available("pytorch_lightning"):
-    from pytorch_lightning import LightningModule, Trainer
-    from pytorch_lightning.demos.mnist_datamodule import MNISTDataModule
-
-from lightning_habana.pytorch.accelerator import HPUAccelerator
-from lightning_habana.pytorch.strategies import HPUParallelStrategy, SingleHPUStrategy
+    from pytorch_lightning import LightningModule
 
 
 class LitClassifier(LightningModule):
+    """Base model."""
+
     def __init__(self):
         super().__init__()
         self.l1 = torch.nn.Linear(28 * 28, 10)
@@ -61,20 +56,37 @@ def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=0.02)
 
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="MNIST on HPU", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument("--hpus", default=1, type=int, help="Number of hpus to be used for training")
-    parser.add_argument("-b", "--batch-size", default=32, type=int)
-    args = parser.parse_args()
-    dm = MNISTDataModule(batch_size=args.batch_size)
-    model = LitClassifier()
+class LitAutocastClassifier(LitClassifier):
+    """Base model with the torch.autocast context manager."""
 
-    hpus = args.hpus
-    _strategy = SingleHPUStrategy()
-    if hpus > 1:
-        parallel_hpus = [torch.device("hpu")] * hpus
-        _strategy = HPUParallelStrategy(parallel_devices=parallel_hpus)
-    trainer = Trainer(fast_dev_run=True, accelerator=HPUAccelerator(), devices=hpus, strategy=_strategy)
+    def __init__(self, op_override=False):
+        super().__init__()
+        self.op_override = op_override
+
+    def forward(self, x):
+        if self.op_override:
+            self.check_override(x)
+        return super().forward(x)
+
+    def check_override(self, x):
+        """Checks for op override."""
+        identity = torch.eye(x.shape[1], device=x.device, dtype=x.dtype)
+        y = torch.mm(x, identity)
+        z = torch.tan(x)
+        assert y.dtype == torch.float32
+        assert z.dtype == torch.bfloat16
-    trainer.fit(model, datamodule=dm)
-    trainer.test(model, datamodule=dm)
+
+    def training_step(self, batch, batch_idx):
+        """Training step."""
+        with torch.autocast(device_type="hpu", dtype=torch.bfloat16):
+            return super().training_step(batch, batch_idx)
+
+    def validation_step(self, batch, batch_idx):
+        """Validation step."""
+        with torch.autocast(device_type="hpu", dtype=torch.bfloat16):
+            return super().validation_step(batch, batch_idx)
+
+    def test_step(self, batch, batch_idx):
+        """Test step."""
+        with torch.autocast(device_type="hpu", dtype=torch.bfloat16):
+            return super().test_step(batch, batch_idx)
diff --git a/examples/pytorch/mnist_trainer.py b/examples/pytorch/mnist_trainer.py
new file mode 100644
index 00000000..361d6f43
--- /dev/null
+++ b/examples/pytorch/mnist_trainer.py
@@ -0,0 +1,112 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import warnings
+
+from lightning_utilities import module_available
+from mnist_sample import LitAutocastClassifier, LitClassifier
+
+if module_available("lightning"):
+    from lightning.pytorch import Trainer, seed_everything
+    from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
+    from lightning.pytorch.plugins.precision import MixedPrecisionPlugin
+elif module_available("pytorch_lightning"):
+    from pytorch_lightning import Trainer, seed_everything
+    from pytorch_lightning.demos.mnist_datamodule import MNISTDataModule
+    from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
+
+from lightning_habana import HPUAccelerator, SingleHPUStrategy
+
+RUN_TYPE = ["basic", "autocast"]
+
+
+def run_trainer(model, plugin):
+    """Run trainer.fit and trainer.test with the given model and plugins."""
+    _data_module = MNISTDataModule(batch_size=32)
+    trainer = Trainer(
+        accelerator=HPUAccelerator(),
+        devices=1,
+        strategy=SingleHPUStrategy(),
+        plugins=plugin,
+        fast_dev_run=True,
+    )
+    trainer.fit(model, _data_module)
+    trainer.test(model, _data_module)
+
+
+def check_and_init_plugins(plugins, run_type, verbose):
+    """Initialise plugins with appropriate checks."""
+    _plugins = []
+    for plugin in plugins:
+        if verbose:
+            print(f"Initializing {plugin}")
+        if plugin == "MixedPrecisionPlugin":
+            warnings.warn("Operator overriding is not supported with MixedPrecisionPlugin on Habana devices.")
+            if run_type != "autocast":
+                _plugins.append(MixedPrecisionPlugin(device="hpu", precision="bf16-mixed"))
+            else:
+                warnings.warn("Skipping MixedPrecisionPlugin. Redundant with autocast run.")
+        else:
+            print(f"Unsupported or invalid plugin: {plugin}")
+    return _plugins
+
+
+def run_model(run_type, plugins, verbose):
+    """Picks the appropriate model and plugins for the given run type."""
+    # Initialise plugins
+    _plugins = check_and_init_plugins(plugins, run_type, verbose)
+    if run_type == "basic":
+        _model = LitClassifier()
+    elif run_type == "autocast":
+        if "LOWER_LIST" in os.environ or "FP32_LIST" in os.environ:
+            _model = LitAutocastClassifier(op_override=True)
+        else:
+            _model = LitAutocastClassifier()
+            warnings.warn(
+                "To override operators with autocast, set LOWER_LIST and FP32_LIST file paths as env variables. "
+                "Example: LOWER_LIST= python example.py "
+                "https://docs.habana.ai/en/latest/PyTorch/PyTorch_Mixed_Precision/Autocast.html#override-options"
+            )
+
+    if verbose:
+        print(f"With run type: {run_type}, running model: {_model} with plugin: {_plugins}")
+    return run_trainer(_model, _plugins)
+
+
+def parse_args():
+    """Cmdline arguments parser."""
+    parser = argparse.ArgumentParser(description="Example to showcase mixed precision training with HPU.")
+
+    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbosity")
+    parser.add_argument(
+        "-r", "--run_types", nargs="+", choices=RUN_TYPE, default=RUN_TYPE, help="Select run type for example"
+    )
+    parser.add_argument(
+        "-p", "--plugins", nargs="+", default=[], choices=["MixedPrecisionPlugin"], help="Plugins for use in training"
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    # Get options
+    options = parse_args()
+    if options.verbose:
+        print(f"Running MNIST mixed precision training with options: {options}")
+
+    # Run the model for each selected run type
+    for run_type in options.run_types:
+        seed_everything(42)
+        run_model(run_type, options.plugins, options.verbose)
diff --git a/examples/pytorch/ops_bf16_mnist.txt b/examples/pytorch/ops_bf16_mnist.txt
index 53ec99c1..c8be1c21 100644
--- a/examples/pytorch/ops_bf16_mnist.txt
+++ b/examples/pytorch/ops_bf16_mnist.txt
@@ -1,2 +1,3 @@
 linear
 relu
+mm
diff --git a/examples/pytorch/ops_fp32_mnist.txt b/examples/pytorch/ops_fp32_mnist.txt
index 4509b7e5..c5b51e78 100644
--- a/examples/pytorch/ops_fp32_mnist.txt
+++ b/examples/pytorch/ops_fp32_mnist.txt
@@ -1 +1,2 @@
 cross_entropy
+tan
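Usage sketch (not part of the patch): assuming the files above live under examples/pytorch/ and lightning-habana is installed, the new example can be exercised locally the same way the CI step above runs it; the -r and -v flags come from parse_args() in mnist_trainer.py, and the LOWER_LIST/FP32_LIST variables point at the op-override files added in this change.

    cd examples/pytorch
    # basic run without autocast
    python mnist_trainer.py -r basic -v
    # autocast run with op override driven by the override files
    LOWER_LIST=ops_fp32_mnist.txt FP32_LIST=ops_bf16_mnist.txt \
        python mnist_trainer.py -r autocast -v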