diff --git a/06-tutorial-classification/6.4-eval-with-model.md b/06-tutorial-classification/6.4-eval-with-model.md new file mode 100644 index 0000000..d56a20a --- /dev/null +++ b/06-tutorial-classification/6.4-eval-with-model.md @@ -0,0 +1,169 @@ +# 使用训练的模型运行推理服务 + +## 下载模型 + +在之前创建的训练进行过程中,经过一些 epoch 会生成模型,训练代码中的 AVA-SDK 会将这些模型上传至 AVA 平台。用户在 "模型列表 -> 训练产生模型" 的页面中,可以查看到对应的训练任务,以及训练任务产生的模型列表,如下图: + +![](/images/ch-06/6.4/snapshot-models.png) + +点击右侧的 "下载",即可将训练产生的模型下载至本地。 + +```bash +➜ ll models +total 6256 +-rw-r--r--@ 1 quxiao staff 2.7M 7 1 10:35 5b337b3a5710da0001333594-snapshot-model_epoch_30.model.tar.gz +``` + +解压之后,就可以看到模型对应的网络结构和权重文件了。 + +```bash +➜ tar -zxvf 5b337b3a5710da0001333594-snapshot-model_epoch_30.model.tar.gz +x snapshot-symbol.json +x snapshot-0030.params +➜ ll +total 12464 +-rw-r--r--@ 1 quxiao staff 2.7M 7 1 10:35 5b337b3a5710da0001333594-snapshot-model_epoch_30.model.tar.gz +-rw-r--r--@ 1 quxiao staff 2.9M 6 27 20:24 snapshot-0030.params +-rw-r--r--@ 1 quxiao staff 105K 6 27 20:24 snapshot-symbol.json +``` + +## 编写推理代码 + +接着,我们可以根据 mxnet 提供的教程,对上面的推理代码进行一些修改。demo 代码如下: + +```python +import os +import base64 +import mxnet as mx +import numpy as np +from collections import namedtuple +from bottle import route, run +import urllib +from urlparse import urlparse + +Batch = namedtuple('Batch', ['data']) + + +class ImagePredictor (object): + def __init__(self): + self.Batch = namedtuple('Batch', ['data']) + self.model_dir = "model" + self.data_dir = "data" + self.mod = None + self.labels = [] + self.symbol_filename = 'deploy.symbol.json' + self.weight_filename = 'weight.params' + self.label_filename = 'labels.csv' + + def load_model(self, custom_model_dir=None, symbol_fn=None, weight_fn=None, label_fn=None): + if custom_model_dir != None: + self.model_dir = custom_model_dir + if symbol_fn != None: + self.symbol_filename = symbol_fn + if weight_fn != None: + self.weight_filename = weight_fn + if label_fn != None: + self.label_filename = label_fn + + print "set cpu/gpu model" + # set the context on CPU, switch to GPU if there is one available + ctx = mx.cpu() + + print "loading model" + sym, arg_params, aux_params, labels = self._load_model() + self.mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None) + self.mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))], + label_shapes=self.mod._label_shapes) + self.mod.set_params(arg_params, aux_params, allow_missing=True) + self.labels = labels + + def _load_model(self): + sym = mx.symbol.load(os.path.join(self.model_dir, self.symbol_filename)) + save_dict = mx.ndarray.load(os.path.join(self.model_dir, self.weight_filename)) + arg_params = {} + aux_params = {} + for k, v in save_dict.items(): + tp, name = k.split(':', 1) + if tp == 'arg': + arg_params[name] = v + if tp == 'aux': + aux_params[name] = v + + labels = [] + with open(os.path.join(self.model_dir, self.label_filename), 'r') as f: + labels = [l.rstrip() for l in f] + + return sym, arg_params, aux_params, labels + + def get_image(self, uri): + # download and show the image + # fname = mx.test_utils.download(url, dirname=self.data_dir) + # img = mx.image.imread(fname) + content = urllib.urlopen(uri).read() + img = mx.image.imdecode(content) + if img is None: + return None + # convert into format (batch, RGB, width, height) + img = mx.image.imresize(img, 224, 224) # resize + img = img.transpose((2, 0, 1)) # Channel first + img = img.expand_dims(axis=0) # batchify + return img + + def predict(self, uri, topN=5): + img = self.get_image(uri) + # compute the predict probabilities + self.mod.forward(Batch([img])) + prob = self.mod.get_outputs()[0].asnumpy() + # print the top-5 + prob = np.squeeze(prob) + a = np.argsort(prob)[::-1] + ret = [] + for i in a[0:topN]: + print('probability=%f, class=%s' %(prob[i], self.labels[i])) + ret.append({'prob': float(prob[i]), 'label': self.labels[i]}) + return ret + +``` + +另外,我们还需要准备一份分类新的文件 `label.csv`,包含了 cifar10 数据集中数据分类 ID 和名称的映射关系,文件可参考:https://github.com/quxiao/cifar10-example/blob/master/image-predict/models/ava-snapshot-model/labels.csv + +## 查看推理结果 + +之后,我们运行几张图片的推理: + +```python +def main(): + worker = predictor.ImagePredictor() + worker.load_model( + custom_model_dir='models/ava-snapshot-model', + symbol_fn='snapshot-symbol.json', + weight_fn='snapshot-0030.params', + label_fn='labels.csv') + worker.predict('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true') + worker.predict('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/dog.jpg?raw=true') +``` + +就可以得到推理的结果: + +``` +➜ python local.py +set cpu/gpu model +loading model +[12:15:51] src/nnvm/legacy_json_util.cc:209: Loading symbol saved by previous version v1.1.0. Attempting to upgrade... +[12:15:51] src/nnvm/legacy_json_util.cc:217: Symbol successfully upgraded! +uri: https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true +probability=0.533429, class=n01491361 tiger shark, Galeocerdo cuvieri +probability=0.309057, class=n01494475 hammerhead, hammerhead shark +probability=0.089500, class=n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias +probability=0.042574, class=n01498041 stingray +probability=0.024384, class=n01496331 electric ray, crampfish, numbfish, torpedo + +uri: https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/dog.jpg?raw=true +probability=0.656869, class=n01491361 tiger shark, Galeocerdo cuvieri +probability=0.165087, class=n01496331 electric ray, crampfish, numbfish, torpedo +probability=0.063514, class=n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias +probability=0.056640, class=n01494475 hammerhead, hammerhead shark +probability=0.037440, class=n01498041 stingray +``` + +PS:从结果可以看出,推理的效果与预期存在差距,这是由多方面原因造成的,例如训练数据集编码的参数、训练代码的超参等等。 \ No newline at end of file diff --git a/images/ch-06/6.4/snapshot-models.png b/images/ch-06/6.4/snapshot-models.png new file mode 100644 index 0000000..17742ce Binary files /dev/null and b/images/ch-06/6.4/snapshot-models.png differ