Llama2 on inf2 example tests, bug fixes and documentation (#2607)
* skip specifying neuron library versions in requirements.txt

* add test for text iterator batch streamer

* add test for micro batch index API

* Add details about supported Neuron SDK version

* Add accelerator memory details for inf2

* fix linter error

---------

Co-authored-by: Naman Nandan <[email protected]>
namannandan and Naman Nandan committed Sep 26, 2023
1 parent ab69b69, commit c3ca259
Showing 5 changed files with 73 additions and 14 deletions.
examples/large_models/inferentia2/llama2/Readme.md (14 changes: 11 additions & 3 deletions)
@@ -21,6 +21,10 @@ Instructions on how to use the AOT compiled model artifacts are shown below.
Get an Inf2 instance (note: this example was tested on instance type `inf2.24xlarge`), ssh to it, and make sure to use the following DLAMI, as it comes with PyTorch and the necessary packages for the AWS Neuron SDK pre-installed.
DLAMI Name: `Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20230720 Amazon Machine Image (AMI)` or higher.

+**Note**: The `inf2.24xlarge` instance consists of 6 Neuron chips with 2 Neuron cores each. The total accelerator memory is 192GB.
+Based on the configuration used in [model-config.yaml](model-config.yaml), with `tp_degree` set to 6, 3 of the 6 Neuron chips (i.e., 6 Neuron cores) are used.
+On loading the model, the accelerator memory consumed is 38.1GB (12.7GB per chip).
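
For reference, a minimal sketch of how the tensor parallelism setting above appears in [model-config.yaml](model-config.yaml); only `tp_degree` is confirmed by this note, and the surrounding field names are illustrative assumptions, so consult the actual file:

```yaml
# Sketch of the relevant model-config.yaml fragment; every field except
# tp_degree is an illustrative assumption, not the committed file contents.
handler:
    amp: "bf16"      # assumed: reduced precision for Neuron compilation
    tp_degree: 6     # shards the model across 6 NeuronCores (3 Neuron chips)
    max_length: 100  # assumed: generation length cap
```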

### Step 2: Package Installations

Follow the steps below to complete package installations
@@ -29,9 +33,10 @@ Follow the steps below to complete package installations
sudo apt-get update
sudo apt-get upgrade

-# Update Neuron Runtime
-sudo apt-get install aws-neuronx-collectives=2.* -y
-sudo apt-get install aws-neuronx-runtime-lib=2.* -y
+# Install Neuron libraries, SDK 2.12.2: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/prev/content.html#id8
+sudo apt-get install aws-neuronx-dkms=2.11.9.0
+sudo apt-get install aws-neuronx-collectives=2.15.16.0*
+sudo apt-get install aws-neuronx-runtime-lib=2.15.14.0*

# Activate Python venv
source /opt/aws_neuron_venv_pytorch/bin/activate
@@ -46,6 +51,9 @@ python ts_scripts/install_dependencies.py --neuronx --environment=dev
# Install torchserve and torch-model-archiver
python ts_scripts/install_from_src.py

+# Install additional neuron packages, SDK 2.12.2: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/prev/content.html#id8
+python -m pip install neuronx-cc==2.8.0.25 torch-neuronx==1.13.1.1.9.1 transformers-neuronx==0.5.58
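# (Illustrative check, not part of this commit or the original README:
# confirm the pinned Neuron package versions installed above)
python -m pip list | grep -E "neuronx-cc|torch-neuronx|transformers-neuronx"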

# Navigate to `examples/large_models/inferentia2/llama2` directory
cd examples/large_models/inferentia2/llama2/

examples/large_models/inferentia2/llama2/requirements.txt (3 changes: 0 additions & 3 deletions)
@@ -1,6 +1,3 @@
---extra-index-url https://pip.repos.neuron.amazonaws.com
-torch-neuronx==1.13.1.1.9.0
-transformers-neuronx==0.5.58
transformers==4.31.0
tokenizers==0.13.3
sentencepiece==0.1.99
ts/tests/unit_tests/test_hf_batch_streamer.py (30 changes: 30 additions & 0 deletions)
@@ -0,0 +1,30 @@
import torch
from transformers import AutoTokenizer

from ts.handler_utils.hf_batch_streamer import TextIteratorStreamerBatch


def test_hf_batch_streamer():
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    streamer = TextIteratorStreamerBatch(
        tokenizer=tokenizer, batch_size=2, skip_special_tokens=True
    )

    input1 = "hello world"
    input2 = "good day"

    # Interleave the token ids of both inputs to emulate a batch of size 2
    for inputs in zip(tokenizer(input1)["input_ids"], tokenizer(input2)["input_ids"]):
        streamer.put(torch.tensor(inputs))

    streamer.end()

    output1 = ""
    output2 = ""

    # Drain the streamer; each item holds one decoded chunk per batch entry
    for data in streamer:
        assert len(data) == 2
        output1 += data[0]
        output2 += data[1]

    assert output1 == input1
    assert output2 == input2
ts/tests/unit_tests/test_micro_batching.py (39 changes: 31 additions & 8 deletions)
@@ -2,20 +2,33 @@
Unit test for MicroBatchHandler class.
"""
import json
+import math
import random
import sys
from pathlib import Path

import pytest
from torchvision.models.resnet import ResNet18_Weights

+from ts.handler_utils.micro_batching import MicroBatching
from ts.torch_handler.image_classifier import ImageClassifier
from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext
from ts.torch_handler.unit_tests.test_utils.model_dir import copy_files, download_model

REPO_DIR = Path(__file__).parents[3]


+class MicroBatchingTestHandler(ImageClassifier):
+    def __init__(self, micro_batch_size):
+        super().__init__()
+        self.micro_batch_idx_set = set()
+        self.handle = MicroBatching(self, micro_batch_size)
+
+    def preprocess(self, data):
+        self.micro_batch_idx_set.add(self.handle.get_micro_batch_idx())
+        return super().preprocess(data)


def read_image_bytes(filename):
    with open(
        filename,
@@ -97,19 +110,13 @@ def context(model_dir, model_name):

@pytest.fixture(scope="module", params=[1, 8])
def handler(context, request):
handler = ImageClassifier()

from ts.handler_utils.micro_batching import MicroBatching

mb_handle = MicroBatching(handler, micro_batch_size=request.param)
handler = MicroBatchingTestHandler(micro_batch_size=request.param)
handler.initialize(context)

handler.handle = mb_handle
handler.handle.parallelism = context.model_yaml_config["mb_parallelism"]

yield handler

mb_handle.shutdown()
handler.handle.shutdown()


@pytest.fixture(scope="module", params=[1, 16])
Expand All @@ -129,13 +136,24 @@ def mixed_batch(kitten_image_bytes, dog_image_bytes, request):
    return test_data, labels


+def verify_micro_batch_idx_set(micro_batch_idx_set, test_data_size, micro_batch_size):
+    assert micro_batch_idx_set == set(
+        [val for val in range(0, math.ceil(test_data_size / micro_batch_size))]
+    )


def test_handle(context, mixed_batch, handler):
    test_data, labels = mixed_batch
    results = handler.handle(test_data, context)
    assert len(results) == len(labels)
    for l, r in zip(labels, results):
        assert l in r

+    verify_micro_batch_idx_set(
+        handler.micro_batch_idx_set, len(test_data), handler.handle.micro_batch_size
+    )
+    handler.micro_batch_idx_set.clear()


def test_handle_explain(context, kitten_image_bytes, handler):
    context.explain = True
@@ -144,6 +162,11 @@ def test_handle_explain(context, kitten_image_bytes, handler):
    assert len(results) == 2
    assert results[0]

+    verify_micro_batch_idx_set(
+        handler.micro_batch_idx_set, len(test_data), handler.handle.micro_batch_size
+    )
+    handler.micro_batch_idx_set.clear()


def test_micro_batching_handler_threads(handler):
    assert len(handler.handle.thread_groups["preprocess"]) == 1
ts_scripts/spellcheck_conf/wordlist.txt (1 change: 1 addition & 0 deletions)
@@ -1095,3 +1095,4 @@ PreprocessCallCount
AOT
microbatches
tokenization
+tp
