Update jupyter notebook for Vocab API refactor
Change-Id: Ib83b46357e6455cb09d4557a4989986c34322dbd
frankfliu committed Dec 28, 2020
1 parent 4c823c7 commit 8d743f1
Showing 3 changed files with 73 additions and 74 deletions.
25 changes: 12 additions & 13 deletions jupyter/mxnet/load_your_own_mxnet_bert.ipynb
@@ -39,9 +39,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
+"%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
 "\n",
-"%maven ai.djl:api:0.9.0\n",
+"%maven ai.djl:api:0.10.0-SNAPSHOT\n",
 "%maven ai.djl.mxnet:mxnet-engine:0.9.0\n",
 "%maven ai.djl.mxnet:mxnet-model-zoo:0.9.0\n",
 "%maven org.slf4j:slf4j-api:1.7.26\n",
@@ -224,8 +224,8 @@
 "    @SerializedName(\"idx_to_token\")\n",
 "    List<String> idx2token;\n",
 "\n",
-"    public static List<String> parseToken(String file) {\n",
-"        try (InputStream is = new URL(file).openStream();\n",
+"    public static List<String> parseToken(URL file) {\n",
+"        try (InputStream is = file.openStream();\n",
 "                Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {\n",
 "            return JsonUtils.GSON.fromJson(reader, VocabParser.class).idx2token;\n",
 "        } catch (IOException e) {\n",
@@ -241,12 +241,12 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"var path = Paths.get(\"build/mxnet/bertqa/vocab.json\").toAbsolutePath();\n",
+"URL url = Paths.get(\"build/mxnet/bertqa/vocab.json\").toUri().toURL();\n",
 "var vocabulary = SimpleVocabulary.builder()\n",
-"    .optMinFrequency(1)\n",
-"    .addFromCustomizedFile(\"file://\" + path.toString(), VocabParser::parseToken)\n",
-"    .optUnknownToken(\"[UNK]\")\n",
-"    .build();"
+"        .optMinFrequency(1)\n",
+"        .addFromCustomizedFile(url, VocabParser::parseToken)\n",
+"        .optUnknownToken(\"[UNK]\")\n",
+"        .build();"
 ]
 },
 {
@@ -339,12 +339,11 @@
 "    \n",
 "    @Override\n",
 "    public void prepare(NDManager manager, Model model) throws IOException {\n",
-"        Path path = Paths.get(\"build/mxnet/bertqa/vocab.json\").toAbsolutePath();\n",
+"        URL path = Paths.get(\"build/mxnet/bertqa/vocab.json\").toUri().toURL();\n",
 "        vocabulary =\n",
 "                SimpleVocabulary.builder()\n",
 "                        .optMinFrequency(1)\n",
-"                        .addFromCustomizedFile(\n",
-"                                \"file://\" + path.toString(), VocabParser::parseToken)\n",
+"                        .addFromCustomizedFile(path, VocabParser::parseToken)\n",
 "                        .optUnknownToken(\"[UNK]\")\n",
 "                        .build();\n",
 "        tokenizer = new BertTokenizer();\n",
@@ -489,7 +488,7 @@
 "mimetype": "text/x-java-source",
 "name": "Java",
 "pygments_lexer": "java",
-"version": "12.0.2+10"
+"version": "14.0.2+12"
 }
 },
 "nbformat": 4,
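For reference, the call pattern after this refactor: addFromCustomizedFile takes a java.net.URL plus a parser function, so the cell builds a URL via Paths.get(...).toUri().toURL() instead of concatenating a "file://" string by hand. A minimal sketch of the updated MXNet cells, runnable in the notebook's Java kernel (import locations assumed from DJL 0.10.0-SNAPSHOT; vocab.json assumed already downloaded to the path below):

    import java.io.IOException;
    import java.io.InputStream;
    import java.io.InputStreamReader;
    import java.io.Reader;
    import java.net.URL;
    import java.nio.charset.StandardCharsets;
    import java.nio.file.Paths;
    import java.util.List;
    import ai.djl.modality.nlp.SimpleVocabulary;
    import ai.djl.util.JsonUtils;
    import com.google.gson.annotations.SerializedName;

    class VocabParser {
        @SerializedName("idx_to_token")
        List<String> idx2token;

        // After the refactor the parser receives a URL directly instead of a String it must re-parse.
        public static List<String> parseToken(URL file) {
            try (InputStream is = file.openStream();
                    Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {
                return JsonUtils.GSON.fromJson(reader, VocabParser.class).idx2token;
            } catch (IOException e) {
                throw new IllegalArgumentException("Invalid url: " + file, e);
            }
        }
    }

    // Paths.get(...).toUri().toURL() replaces the old "file://" + path string concatenation.
    URL url = Paths.get("build/mxnet/bertqa/vocab.json").toUri().toURL();
    SimpleVocabulary vocabulary = SimpleVocabulary.builder()
            .optMinFrequency(1)
            .addFromCustomizedFile(url, VocabParser::parseToken)
            .optUnknownToken("[UNK]")
            .build();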
16 changes: 8 additions & 8 deletions jupyter/pytorch/load_your_own_pytorch_bert.ipynb
@@ -39,9 +39,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
+"%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
 "\n",
-"%maven ai.djl:api:0.9.0\n",
+"%maven ai.djl:api:0.10.0-SNAPSHOT\n",
 "%maven ai.djl.pytorch:pytorch-engine:0.9.0\n",
 "%maven ai.djl.pytorch:pytorch-model-zoo:0.9.0\n",
 "%maven org.slf4j:slf4j-api:1.7.26\n",
@@ -227,10 +227,10 @@
 "source": [
 "var path = Paths.get(\"build/pytorch/bertqa/vocab.txt\");\n",
 "var vocabulary = SimpleVocabulary.builder()\n",
-"    .optMinFrequency(1)\n",
-"    .addFromTextFile(path.toString())\n",
-"    .optUnknownToken(\"[UNK]\")\n",
-"    .build();"
+"        .optMinFrequency(1)\n",
+"        .addFromTextFile(path)\n",
+"        .optUnknownToken(\"[UNK]\")\n",
+"        .build();"
 ]
 },
 {
@@ -302,7 +302,7 @@
 "    Path path = Paths.get(\"build/pytorch/bertqa/vocab.txt\");\n",
 "    vocabulary = SimpleVocabulary.builder()\n",
 "                    .optMinFrequency(1)\n",
-"                    .addFromTextFile(path.toString())\n",
+"                    .addFromTextFile(path)\n",
 "                    .optUnknownToken(\"[UNK]\")\n",
 "                    .build();\n",
 "    tokenizer = new BertTokenizer();\n",
@@ -439,7 +439,7 @@
 "mimetype": "text/x-java-source",
 "name": "Java",
 "pygments_lexer": "java",
-"version": "12.0.2+10"
+"version": "14.0.2+12"
 }
 },
 "nbformat": 4,
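The PyTorch notebook gets the same treatment: addFromTextFile now accepts the Path itself rather than path.toString(). A minimal sketch of the updated cell, under the same DJL 0.10.0-SNAPSHOT assumptions; vocab.txt here is the plain one-token-per-line BERT vocabulary:

    import java.nio.file.Path;
    import java.nio.file.Paths;
    import ai.djl.modality.nlp.SimpleVocabulary;

    // The refactored builder takes the Path directly and reads one token per line.
    Path path = Paths.get("build/pytorch/bertqa/vocab.txt");
    SimpleVocabulary vocabulary = SimpleVocabulary.builder()
            .optMinFrequency(1)
            .addFromTextFile(path)
            .optUnknownToken("[UNK]")
            .build();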
106 changes: 53 additions & 53 deletions jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb
@@ -205,11 +205,11 @@
 "}\n",
 "\n",
 "Criteria<NDList, NDList> criteria = Criteria.builder()\n",
-"    .optApplication(Application.NLP.WORD_EMBEDDING)\n",
-"    .setTypes(NDList.class, NDList.class)\n",
-"    .optModelUrls(modelUrls)\n",
-"    .optProgress(new ProgressBar())\n",
-"    .build();\n",
+"        .optApplication(Application.NLP.WORD_EMBEDDING)\n",
+"        .setTypes(NDList.class, NDList.class)\n",
+"        .optModelUrls(modelUrls)\n",
+"        .optProgress(new ProgressBar())\n",
+"        .build();\n",
 "ZooModel<NDList, NDList> embedding = ModelZoo.loadModel(criteria);"
 ]
 },
@@ -233,37 +233,37 @@
 "source": [
 "Predictor<NDList, NDList> embedder = embedding.newPredictor();\n",
 "Block classifier = new SequentialBlock()\n",
-"    // text embedding layer\n",
-"    .add(\n",
-"        ndList -> {\n",
-"            NDArray data = ndList.singletonOrThrow();\n",
-"            NDList inputs = new NDList();\n",
-"            long batchSize = data.getShape().get(0);\n",
-"            float maxLength = data.getShape().get(1);\n",
-"            \n",
-"            if (\"PyTorch\".equals(Engine.getInstance().getEngineName())) {\n",
-"                inputs.add(data.toType(DataType.INT64, false));\n",
-"                inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));\n",
-"                inputs.add(data.getManager().arange(maxLength)\n",
-"                          .toType(DataType.INT64, false)\n",
-"                          .broadcast(data.getShape()));\n",
-"            } else {\n",
-"                inputs.add(data);\n",
-"                inputs.add(data.getManager().full(new Shape(batchSize), maxLength));\n",
-"            }\n",
-"            // run embedding\n",
-"            try {\n",
-"                return embedder.predict(inputs);\n",
-"            } catch (TranslateException e) {\n",
-"                throw new IllegalArgumentException(\"embedding error\", e);\n",
-"            }\n",
-"        })\n",
-"    // classification layer\n",
-"    .add(Linear.builder().setUnits(768).build()) // pre classifier\n",
-"    .add(Activation::relu)\n",
-"    .add(Dropout.builder().optRate(0.2f).build())\n",
-"    .add(Linear.builder().setUnits(5).build()) // 5 star rating\n",
-"    .addSingleton(nd -> nd.get(\":,0\")); // Take [CLS] as the head\n",
+"        // text embedding layer\n",
+"        .add(\n",
+"            ndList -> {\n",
+"                NDArray data = ndList.singletonOrThrow();\n",
+"                NDList inputs = new NDList();\n",
+"                long batchSize = data.getShape().get(0);\n",
+"                float maxLength = data.getShape().get(1);\n",
+"\n",
+"                if (\"PyTorch\".equals(Engine.getInstance().getEngineName())) {\n",
+"                    inputs.add(data.toType(DataType.INT64, false));\n",
+"                    inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));\n",
+"                    inputs.add(data.getManager().arange(maxLength)\n",
+"                        .toType(DataType.INT64, false)\n",
+"                        .broadcast(data.getShape()));\n",
+"                } else {\n",
+"                    inputs.add(data);\n",
+"                    inputs.add(data.getManager().full(new Shape(batchSize), maxLength));\n",
+"                }\n",
+"                // run embedding\n",
+"                try {\n",
+"                    return embedder.predict(inputs);\n",
+"                } catch (TranslateException e) {\n",
+"                    throw new IllegalArgumentException(\"embedding error\", e);\n",
+"                }\n",
+"            })\n",
+"        // classification layer\n",
+"        .add(Linear.builder().setUnits(768).build()) // pre classifier\n",
+"        .add(Activation::relu)\n",
+"        .add(Dropout.builder().optRate(0.2f).build())\n",
+"        .add(Linear.builder().setUnits(5).build()) // 5 star rating\n",
+"        .addSingleton(nd -> nd.get(\":,0\")); // Take [CLS] as the head\n",
 "Model model = Model.newInstance(\"AmazonReviewRatingClassification\");\n",
 "model.setBlock(classifier);"
 ]
@@ -291,10 +291,10 @@
 "source": [
 "// Prepare the vocabulary\n",
 "SimpleVocabulary vocabulary = SimpleVocabulary.builder()\n",
-"    .optMinFrequency(1)\n",
-"    .addFromTextFile(Paths.get(embedding.getArtifact(\"vocab.txt\").toURI()))\n",
-"    .optUnknownToken(\"[UNK]\")\n",
-"    .build();\n",
+"        .optMinFrequency(1)\n",
+"        .addFromTextFile(embedding.getArtifact(\"vocab.txt\"))\n",
+"        .optUnknownToken(\"[UNK]\")\n",
+"        .build();\n",
 "// Prepare dataset\n",
 "int maxTokenLength = 64; // cutoff tokens length\n",
 "int batchSize = 8;\n",
@@ -323,19 +323,19 @@
 "source": [
 "CheckpointsTrainingListener listener = new CheckpointsTrainingListener(\"build/model\");\n",
 "    listener.setSaveModelCallback(\n",
-"    trainer -> {\n",
-"        TrainingResult result = trainer.getTrainingResult();\n",
-"        Model model = trainer.getModel();\n",
-"        // track for accuracy and loss\n",
-"        float accuracy = result.getValidateEvaluation(\"Accuracy\");\n",
-"        model.setProperty(\"Accuracy\", String.format(\"%.5f\", accuracy));\n",
-"        model.setProperty(\"Loss\", String.format(\"%.5f\", result.getValidateLoss()));\n",
-"    });\n",
+"        trainer -> {\n",
+"            TrainingResult result = trainer.getTrainingResult();\n",
+"            Model model = trainer.getModel();\n",
+"            // track for accuracy and loss\n",
+"            float accuracy = result.getValidateEvaluation(\"Accuracy\");\n",
+"            model.setProperty(\"Accuracy\", String.format(\"%.5f\", accuracy));\n",
+"            model.setProperty(\"Loss\", String.format(\"%.5f\", result.getValidateLoss()));\n",
+"        });\n",
 "DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type\n",
-"    .addEvaluator(new Accuracy())\n",
-"    .optDevices(Device.getDevices(1)) // train using single GPU\n",
-"    .addTrainingListeners(TrainingListener.Defaults.logging(\"build/model\"))\n",
-"    .addTrainingListeners(listener);"
+"        .addEvaluator(new Accuracy())\n",
+"        .optDevices(Device.getDevices(1)) // train using single GPU\n",
+"        .addTrainingListeners(TrainingListener.Defaults.logging(\"build/model\"))\n",
+"        .addTrainingListeners(listener);"
 ]
 },
 {
@@ -460,7 +460,7 @@
 "mimetype": "text/x-java-source",
 "name": "Java",
 "pygments_lexer": "java",
-"version": "11.0.5+10-LTS"
+"version": "14.0.2+12"
 }
 },
 "nbformat": 4,
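In this notebook the vocabulary comes straight from the model artifact: getArtifact("vocab.txt") returns a java.net.URL, which the refactored addFromTextFile accepts directly, removing the old Paths.get(...getArtifact(...).toURI()) round trip. A hedged sketch, run as a cell after the embedding ZooModel loaded above is available; the local name vocabUrl is illustrative:

    import java.net.URL;
    import ai.djl.modality.nlp.SimpleVocabulary;

    // getArtifact(...) resolves a file inside the downloaded model directory as a URL;
    // the refactored builder overload reads the token list straight from it.
    URL vocabUrl = embedding.getArtifact("vocab.txt");
    SimpleVocabulary vocabulary = SimpleVocabulary.builder()
            .optMinFrequency(1)
            .addFromTextFile(vocabUrl)
            .optUnknownToken("[UNK]")
            .build();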
