Skip to content

Commit

Permalink
MXNet Neo bug fixes (aws#2767)
Browse files Browse the repository at this point in the history
* MXNet Neo bug fixes/V2

* Add entry point script & MXNet model, fix prediction content type

* Fix formatting

* Add missing PIL.Image import
  • Loading branch information
jkroll-aws authored Jul 6, 2021
1 parent d57eba1 commit 00267d9
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,10 @@
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"import sagemaker\n",
"from sagemaker import get_execution_role\n",
"\n",
"role = get_execution_role()\n",
"print(role)\n",
"\n",
"sess = sagemaker.Session()\n",
"bucket = sess.default_bucket()\n",
Expand All @@ -66,9 +64,11 @@
"metadata": {},
"outputs": [],
"source": [
"from sagemaker.amazon.amazon_estimator import get_image_uri\n",
"from sagemaker import image_uris\n",
"\n",
"training_image = get_image_uri(sess.boto_region_name, \"image-classification\", repo_version=\"latest\")\n",
"training_image = image_uris.retrieve(\n",
" region=sess.boto_region_name, framework=\"image-classification\", version=\"latest\"\n",
")\n",
"print(training_image)"
]
},
Expand Down Expand Up @@ -158,10 +158,10 @@
"ic = sagemaker.estimator.Estimator(\n",
" training_image,\n",
" role,\n",
" train_instance_count=1,\n",
" train_instance_type=\"ml.p2.xlarge\",\n",
" train_volume_size=50,\n",
" train_max_run=360000,\n",
" instance_count=1,\n",
" instance_type=\"ml.p2.xlarge\",\n",
" volume_size=50,\n",
" max_run=360000,\n",
" input_mode=\"File\",\n",
" output_path=s3_output_location,\n",
" sagemaker_session=sess,\n",
Expand Down Expand Up @@ -218,13 +218,13 @@
"metadata": {},
"outputs": [],
"source": [
"train_data = sagemaker.session.s3_input(\n",
"train_data = sagemaker.inputs.TrainingInput(\n",
" s3train,\n",
" distribution=\"FullyReplicated\",\n",
" content_type=\"application/x-recordio\",\n",
" s3_data_type=\"S3Prefix\",\n",
")\n",
"validation_data = sagemaker.session.s3_input(\n",
"validation_data = sagemaker.inputs.TrainingInput(\n",
" s3validation,\n",
" distribution=\"FullyReplicated\",\n",
" content_type=\"application/x-recordio\",\n",
Expand Down Expand Up @@ -270,22 +270,42 @@
"metadata": {},
"outputs": [],
"source": [
"optimized_ic = ic\n",
"if ic.create_model().check_neo_region(boto3.Session().region_name) is False:\n",
" print(\"Neo is not currently supported in\", boto3.Session().region_name)\n",
"else:\n",
" output_path = \"/\".join(ic.output_path.split(\"/\")[:-1])\n",
" optimized_ic = ic.compile_model(\n",
" target_instance_family=\"ml_m4\",\n",
" input_shape={\"data\": [1, 3, 224, 224]}, # Batch size 1, 3 channels, 224x224 Images.\n",
" output_path=output_path,\n",
" framework=\"mxnet\",\n",
" framework_version=\"1.2.1\",\n",
" )\n",
" optimized_ic.image = get_image_uri(\n",
" sess.boto_region_name, \"image-classification-neo\", repo_version=\"latest\"\n",
" )\n",
" optimized_ic.name = \"deployed-image-classification\""
"output_path = \"/\".join(ic.output_path.split(\"/\")[:-1])\n",
"optimized_ic = ic.compile_model(\n",
" target_instance_family=\"ml_m4\",\n",
" input_shape={\"data\": [1, 3, 224, 224]}, # Batch size 1, 3 channels, 224x224 Images.\n",
" output_path=output_path,\n",
" framework=\"mxnet\",\n",
" framework_version=\"1.8\",\n",
" env={\"MMS_DEFAULT_RESPONSE_TIMEOUT\": \"500\"},\n",
")\n",
"optimized_ic.image = image_uris.retrieve(\n",
" region=sess.boto_region_name, framework=\"image-classification-neo\", version=\"latest\"\n",
")\n",
"optimized_ic.name = \"deployed-image-classification\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sagemaker.mxnet.model import MXNetModel\n",
"\n",
"s3_custom_code_location = \"s3://{}/{}/custom_code\".format(bucket, prefix)\n",
"\n",
"optimized_ic_model = MXNetModel(\n",
" model_data=optimized_ic.model_data,\n",
" image_uri=optimized_ic.image_uri,\n",
" framework_version=\"1.8\",\n",
" role=role,\n",
" sagemaker_session=sess,\n",
" entry_point=\"inference.py\",\n",
" py_version=\"py37\",\n",
" env={\"MMS_DEFAULT_RESPONSE_TIMEOUT\": \"500\"},\n",
" code_location=s3_custom_code_location,\n",
")"
]
},
{
Expand All @@ -305,7 +325,9 @@
"metadata": {},
"outputs": [],
"source": [
"ic_classifier = optimized_ic.deploy(initial_instance_count=1, instance_type=\"ml.m4.xlarge\")"
"ic_classifier = optimized_ic_model.deploy(\n",
" initial_instance_count=1, instance_type=\"ml.m4.xlarge\", use_compiled_model=True\n",
")"
]
},
{
Expand All @@ -321,8 +343,8 @@
"metadata": {},
"outputs": [],
"source": [
"!wget -O /tmp/test.jpg http://www.vision.caltech.edu/Image_Datasets/Caltech256/images/008.bathtub/008_0007.jpg\n",
"file_name = \"/tmp/test.jpg\"\n",
"!wget -O test.jpg http://sagemaker-sample-files.s3.amazonaws.com/datasets/image/caltech-256/256_ObjectCategories/008.bathtub/008_0007.jpg\n",
"file_name = \"test.jpg\"\n",
"# test image\n",
"from IPython.display import Image\n",
"\n",
Expand All @@ -348,13 +370,12 @@
"source": [
"import json\n",
"import numpy as np\n",
"import PIL.Image\n",
"\n",
"with open(file_name, \"rb\") as f:\n",
" payload = f.read()\n",
" payload = bytearray(payload)\n",
"test_image = PIL.Image.open(file_name)\n",
"payload = np.asarray(test_image.resize((224, 224)))\n",
"\n",
"ic_classifier.content_type = \"application/x-image\"\n",
"result = json.loads(ic_classifier.predict(payload))\n",
"result = ic_classifier.predict(payload)\n",
"# the result will output the probabilities for all classes\n",
"# find the class with maximum probability and print the class index\n",
"index = np.argmax(result)\n",
Expand Down Expand Up @@ -636,15 +657,9 @@
"metadata": {},
"outputs": [],
"source": [
"ic_classifier.delete_model()\n",
"ic_classifier.delete_endpoint()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -663,7 +678,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.6.13"
},
"notice": "Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import io
import json
import logging
import os

import mxnet as mx

# Please make sure to import neomx
import neomx # noqa: F401
import numpy as np

# Change the context to mx.gpu() if deploying to a GPU endpoint
ctx = mx.cpu()


def model_fn(model_dir):
logging.info("Invoking user-defined model_fn")
# The compiled model artifacts are saved with the prefix 'compiled'
sym, arg_params, aux_params = mx.model.load_checkpoint(os.path.join(model_dir, "compiled"), 0)
mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
exe = mod.bind(
for_training=False, data_shapes=[("data", (1, 3, 224, 224))], label_shapes=mod._label_shapes
)
mod.set_params(arg_params, aux_params, allow_missing=True)
# Run warm-up inference on empty data during model load (required for GPU)
data = mx.nd.empty((1, 3, 224, 224), ctx=ctx)
mod.predict(data)
return mod


def transform_fn(mod, data, input_content_type, output_content_type):
logging.info("Invoking user-defined transform_fn")
if output_content_type == "application/json":
# pre-processing
data = json.loads(data)
mx_ndarray = mx.nd.array(data)
resized = mx.image.imresize(mx_ndarray, 224, 224)
transposed = resized.transpose((2, 0, 1))
batchified = transposed.expand_dims(axis=0)
processed_input = batchified.as_in_context(ctx)

# prediction/inference
prediction_result = mod.predict(processed_input)

# post-processing
prediction = prediction_result.asnumpy().tolist()
prediction_json = json.dumps(prediction[0])
return prediction_json, output_content_type

0 comments on commit 00267d9

Please sign in to comment.