PyTorch Neuron (torch-neuronx) Samples for AWS Trn1/Trn1n & Inf2 Instances

This directory contains sample PyTorch Neuron inference and training scripts that can be run on AWS Trainium (Trn1/Trn1n instances) and AWS Inferentia (Inf2 instances).

For additional information on these training scripts, please refer to the tutorials found in the official Trainium documentation.

Training

The following samples are available for training:

| Name | Description | Training Parallelism |
| --- | --- | --- |
| dp_bert_hf_pretrain | Phase 1 and phase 2 pretraining of the Hugging Face BERT-large model | DataParallel |
| mnist_mlp | Examples of training a multilayer perceptron on the MNIST dataset | DataParallel |
| mnist_mlp | Examples of training a multilayer perceptron on the MNIST dataset using DDP | DataParallel |
| hf_text_classification | Fine-tuning various Hugging Face models for a text classification task | DataParallel |
| hf_image_classification | Fine-tuning Hugging Face models (e.g. ViT) for an image classification task | DataParallel |
| hf_contrastive_image_text | Fine-tuning multimodal image-and-text Hugging Face models (e.g. CLIP) | DataParallel |
| hf_language_modeling | Training Hugging Face models (e.g. GPT2) for causal language modeling (CLM) | DataParallel |
| hf_bert_jp | Fine-tuning and deployment of the Hugging Face BERT Japanese model | DataParallel |
| hf_sentiment_analysis | Examples of training the Hugging Face bert-base-cased model for a text classification task, with single-Neuron and distributed training on Trn1 | DataParallel |
| customop_mlp | Examples of training a multilayer perceptron model with a custom ReLU operator on a single Trn1 instance | DataParallel |
| tp_dp_gpt_neox_20b_hf_pretrain | Training the GPT-NeoX 20B model using neuronx-distributed | Tensor Parallel & DataParallel |
| tp_dp_gpt_neox_6.9b_hf_pretrain | Training the GPT-NeoX 6.9B model using neuronx-distributed | Tensor Parallel & DataParallel |
| tp_zero1_llama2_7b_hf_pretrain | Training the Llama-2 7B model using neuronx-distributed | Tensor Parallel |
| tp_pp_llama2_70b_hf_pretrain | Training the Llama-2 70B model using neuronx-distributed | Tensor Parallel & Pipeline Parallel |
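In data parallelism, as listed for most of the samples above, every worker holds a full copy of the model and processes its own slice of each batch. As a rough, framework-free illustration of that idea (not code from these samples and not the torch-neuronx API), the hypothetical Python sketch below simulates synchronous data parallelism for a 1-D linear model: each simulated worker computes a gradient on its own shard, the gradients are averaged in place of an all-reduce, and every replica applies the same update. The functions `grad_mse`, `shard`, and `data_parallel_step` are illustrative names only.

```python
# Hypothetical, framework-free sketch of synchronous data parallelism.
# Model: y = w * x + b, trained with mean squared error (MSE).


def grad_mse(w, b, xs, ys):
    """Return (dL/dw, dL/db) of the MSE loss over the given samples."""
    n = len(xs)
    errors = [w * x + b - y for x, y in zip(xs, ys)]
    dw = sum(2 * e * x for e, x in zip(errors, xs)) / n
    db = sum(2 * e for e in errors) / n
    return dw, db


def shard(seq, num_workers, rank):
    """Return the contiguous, equal-sized slice of seq owned by worker `rank`."""
    per = len(seq) // num_workers
    return seq[rank * per:(rank + 1) * per]


def data_parallel_step(w, b, xs, ys, num_workers, lr):
    """One training step with the global batch split across num_workers replicas."""
    assert len(xs) % num_workers == 0, "batch must divide evenly across workers"
    # Each replica computes a local gradient on its own shard of the batch.
    local_grads = [
        grad_mse(w, b, shard(xs, num_workers, r), shard(ys, num_workers, r))
        for r in range(num_workers)
    ]
    # Stand-in for an all-reduce: average the per-replica gradients.
    dw = sum(g[0] for g in local_grads) / num_workers
    db = sum(g[1] for g in local_grads) / num_workers
    # Every replica applies the identical update, so the copies stay in sync.
    return w - lr * dw, b - lr * db


# Usage: fit y = 3x + 1 with 4 simulated workers.
xs = [i / 8 for i in range(8)]
ys = [3 * x + 1 for x in xs]
w, b = 0.0, 0.0
for _ in range(2000):
    w, b = data_parallel_step(w, b, xs, ys, num_workers=4, lr=0.5)
print(f"w={w:.3f}, b={b:.3f}")
```

Because every shard has the same size, the averaged gradient equals the gradient over the full batch, so the result matches single-worker training; with unequal shards, a weighted average would be needed instead.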

Inference

The following samples are available for inference:

| Model Name | Model Task | Original Model Source |
| --- | --- | --- |
| BERT Base Uncased | Masked language modeling and next sentence prediction | bert-base-uncased |
| DistilBERT Base Uncased | Masked language modeling and next sentence prediction | distilbert-base-uncased |
| RoBERTa Large | Masked language modeling, sequence classification, and question answering | roberta-large |
| Vision Transformer (ViT) | Image classification | google/vit-base-patch16-224 |
| GPT2 | Text feature extraction | gpt2 |
| ResNet50 | Image classification | resnet50 |
| Hugging Face Stable Diffusion 1.5 (512x512) | Text-to-image generation | stable-diffusion-v1-5 |
| Hugging Face Stable Diffusion 2.1 (512x512) | Text-to-image generation | stable-diffusion-2-1-base |
| Hugging Face Stable Diffusion 2.1 (768x768) | Text-to-image generation | stable-diffusion-2-1 |
| Hugging Face Stable Diffusion XL Base 1.0 (1024x1024) | Text-to-image generation | stable-diffusion-xl-base-1.0 |
| Hugging Face Stable Diffusion XL Base & Refiner 1.0 (1024x1024) | Text-to-image generation | stable-diffusion-xl-base-1.0 |
| UNet | Image segmentation | unet |
| VGG | Image classification | vgg |
| Multimodal Perceiver | Video classification and autoencoding | multimodal-perceiver |
| Language Perceiver | Text classification | language-perceiver |
| Vision Perceiver | Image classification | vision-perceiver-conv |
| CLIP Base | Image classification | clip-vit-base-patch32 |
| CLIP Large | Image classification | clip-vit-large-patch14 |
| Wav2Vec2 Conformer with Rotary Position Embeddings | Automatic speech recognition | facebook/wav2vec2-conformer-rope-large-960h-ft |
| Wav2Vec2 Conformer with Relative Position Embeddings | Automatic speech recognition | facebook/wav2vec2-conformer-rel-pos-large-960h-ft |

The following samples are available for tensor-parallel LLM inference:

| Name | Instance Type |
| --- | --- |
| facebook/opt-13b | Inf2 & Trn1 |
| facebook/opt-30b | Inf2 & Trn1 |
| facebook/opt-66b | Inf2 |
| meta-llama/Llama-2-13b | Inf2 & Trn1 |
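In tensor parallelism, each device holds only a slice of a layer's weights instead of a full copy. The hypothetical pure-Python sketch below shows the arithmetic behind two common ways to shard a linear layer Y = X·W: splitting W by output columns (each shard produces a slice of the output columns, which are concatenated, analogous to an all-gather) and splitting W by input rows together with the matching columns of X (each shard produces a partial sum, and the partials are added, analogous to an all-reduce). It is a toy illustration of the general technique, not the implementation used by neuronx-distributed or by the samples above; `column_parallel` and `row_parallel` are illustrative names.

```python
# Hypothetical sketch: sharding a linear layer Y = X @ W across "devices".


def matmul(a, b):
    """Multiply an m x k matrix a by a k x n matrix b (lists of lists)."""
    k, n = len(b), len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]


def column_parallel(x, w, num_shards):
    """Split W by output columns; concatenate each shard's output columns."""
    n = len(w[0])
    assert n % num_shards == 0
    per = n // num_shards
    partials = [
        matmul(x, [row[s * per:(s + 1) * per] for row in w])
        for s in range(num_shards)
    ]
    # Analogous to an all-gather: stitch the column slices back together per row.
    return [sum((part[i] for part in partials), []) for i in range(len(x))]


def row_parallel(x, w, num_shards):
    """Split W by input rows (and X by matching columns); sum the partial outputs."""
    k = len(w)
    assert k % num_shards == 0
    per = k // num_shards
    partials = [
        matmul([row[s * per:(s + 1) * per] for row in x], w[s * per:(s + 1) * per])
        for s in range(num_shards)
    ]
    # Analogous to an all-reduce: element-wise sum of the partial products.
    return [
        [sum(part[i][j] for part in partials) for j in range(len(w[0]))]
        for i in range(len(x))
    ]


# Usage: both sharded forms reproduce the unsharded product exactly.
x = [[1, 2, 3, 4], [5, 6, 7, 8]]
w = [[(i * 6 + j) % 5 - 2 for j in range(6)] for i in range(4)]
full = matmul(x, w)
print(column_parallel(x, w, num_shards=3) == full, row_parallel(x, w, num_shards=2) == full)
```

With integer inputs the comparison is exact; with floating-point weights, the row-parallel sum can differ from the unsharded result by rounding, because the additions happen in a different order.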

Microbenchmarking

The following samples are available for microbenchmarking:

| Name | Description |
| --- | --- |
| tutorial | Microbenchmarking tutorial |
| matmult | Matrix multiplication microbenchmark |
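A matrix multiplication microbenchmark typically times repeated runs after a few untimed warm-up runs and converts the median latency into throughput, using the fact that multiplying an m x k matrix by a k x n matrix takes 2·m·n·k floating-point operations (one multiply and one add per term). The hypothetical pure-Python sketch below shows only that measurement logic on the CPU; it is not the matmult sample, and `benchmark` and `matmul_gflops` are illustrative names. On Neuron devices, the timed region must also force execution and wait for the device to finish, so the matmult sample remains the reference for meaningful numbers.

```python
# Hypothetical CPU-only sketch of a matmul microbenchmark harness.
import random
import statistics
import time


def matmul(a, b):
    """Multiply an m x k matrix a by a k x n matrix b (lists of lists)."""
    k, n = len(b), len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]


def benchmark(fn, warmup=2, iters=5):
    """Run fn `warmup` times untimed, then return the median of `iters` timed runs."""
    for _ in range(warmup):
        fn()  # untimed runs keep one-time costs (e.g. compilation on devices) out of the result
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def matmul_gflops(m, k, n, seconds):
    """Throughput in GFLOP/s: an (m x k) @ (k x n) product performs 2*m*n*k FLOPs."""
    return 2 * m * n * k / seconds / 1e9


# Usage: benchmark a small 64 x 64 x 64 product.
m = k = n = 64
a = [[random.random() for _ in range(k)] for _ in range(m)]
b = [[random.random() for _ in range(n)] for _ in range(k)]
latency = benchmark(lambda: matmul(a, b))
print(f"median latency {latency * 1e3:.2f} ms, {matmul_gflops(m, k, n, latency):.4f} GFLOP/s")
```

The median is used rather than the mean so that a single slow outlier run (for example, one interrupted by the OS scheduler) does not skew the reported latency.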