diff --git a/jupyter/mxnet/load_your_own_mxnet_bert.ipynb b/jupyter/mxnet/load_your_own_mxnet_bert.ipynb
index 1225127898e..3f64bf15208 100644
--- a/jupyter/mxnet/load_your_own_mxnet_bert.ipynb
+++ b/jupyter/mxnet/load_your_own_mxnet_bert.ipynb
@@ -39,9 +39,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
+    "%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
     "\n",
-    "%maven ai.djl:api:0.9.0\n",
+    "%maven ai.djl:api:0.10.0-SNAPSHOT\n",
     "%maven ai.djl.mxnet:mxnet-engine:0.9.0\n",
     "%maven ai.djl.mxnet:mxnet-model-zoo:0.9.0\n",
     "%maven org.slf4j:slf4j-api:1.7.26\n",
@@ -224,8 +224,8 @@
     "    @SerializedName(\"idx_to_token\")\n",
     "    List<String> idx2token;\n",
     "\n",
-    "    public static List<String> parseToken(String file) {\n",
-    "        try (InputStream is = new URL(file).openStream();\n",
+    "    public static List<String> parseToken(URL file) {\n",
+    "        try (InputStream is = file.openStream();\n",
     "                Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {\n",
     "            return JsonUtils.GSON.fromJson(reader, VocabParser.class).idx2token;\n",
     "        } catch (IOException e) {\n",
@@ -241,12 +241,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "var path = Paths.get(\"build/mxnet/bertqa/vocab.json\").toAbsolutePath();\n",
+    "URL url = Paths.get(\"build/mxnet/bertqa/vocab.json\").toUri().toURL();\n",
     "var vocabulary = SimpleVocabulary.builder()\n",
-    "                .optMinFrequency(1)\n",
-    "                .addFromCustomizedFile(\"file://\" + path.toString(), VocabParser::parseToken)\n",
-    "                .optUnknownToken(\"[UNK]\")\n",
-    "                .build();"
+    "        .optMinFrequency(1)\n",
+    "        .addFromCustomizedFile(url, VocabParser::parseToken)\n",
+    "        .optUnknownToken(\"[UNK]\")\n",
+    "        .build();"
    ]
   },
   {
@@ -339,12 +339,11 @@
     "    \n",
     "    @Override\n",
     "    public void prepare(NDManager manager, Model model) throws IOException {\n",
-    "        Path path = Paths.get(\"build/mxnet/bertqa/vocab.json\").toAbsolutePath();\n",
+    "        URL path = Paths.get(\"build/mxnet/bertqa/vocab.json\").toUri().toURL();\n",
     "        vocabulary =\n",
     "                SimpleVocabulary.builder()\n",
     "                        .optMinFrequency(1)\n",
-    "                        .addFromCustomizedFile(\n",
-    "                                \"file://\" + path.toString(), VocabParser::parseToken)\n",
+    "                        .addFromCustomizedFile(path, VocabParser::parseToken)\n",
     "                        .optUnknownToken(\"[UNK]\")\n",
     "                        .build();\n",
     "        tokenizer = new BertTokenizer();\n",
@@ -489,7 +488,7 @@
    "mimetype": "text/x-java-source",
    "name": "Java",
    "pygments_lexer": "java",
-   "version": "12.0.2+10"
+   "version": "14.0.2+12"
   }
  },
  "nbformat": 4,
diff --git a/jupyter/pytorch/load_your_own_pytorch_bert.ipynb b/jupyter/pytorch/load_your_own_pytorch_bert.ipynb
index 97b80fd0436..c21b9c46f8a 100644
--- a/jupyter/pytorch/load_your_own_pytorch_bert.ipynb
+++ b/jupyter/pytorch/load_your_own_pytorch_bert.ipynb
@@ -39,9 +39,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
+    "%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
     "\n",
-    "%maven ai.djl:api:0.9.0\n",
+    "%maven ai.djl:api:0.10.0-SNAPSHOT\n",
     "%maven ai.djl.pytorch:pytorch-engine:0.9.0\n",
     "%maven ai.djl.pytorch:pytorch-model-zoo:0.9.0\n",
     "%maven org.slf4j:slf4j-api:1.7.26\n",
@@ -227,10 +227,10 @@
    "source": [
     "var path = Paths.get(\"build/pytorch/bertqa/vocab.txt\");\n",
     "var vocabulary = SimpleVocabulary.builder()\n",
-    "                .optMinFrequency(1)\n",
-    "                .addFromTextFile(path.toString())\n",
-    "                .optUnknownToken(\"[UNK]\")\n",
-    "                .build();"
+    "        .optMinFrequency(1)\n",
+    "        .addFromTextFile(path)\n",
+    "        .optUnknownToken(\"[UNK]\")\n",
+    "        .build();"
    ]
   },
   {
@@ -302,7 +302,7 @@
     "        Path path = Paths.get(\"build/pytorch/bertqa/vocab.txt\");\n",
     "        vocabulary = SimpleVocabulary.builder()\n",
     "                        .optMinFrequency(1)\n",
-    "                        .addFromTextFile(path.toString())\n",
+    "                        .addFromTextFile(path)\n",
     "                        .optUnknownToken(\"[UNK]\")\n",
     "                        .build();\n",
     "        tokenizer = new BertTokenizer();\n",
@@ -439,7 +439,7 @@
    "mimetype": "text/x-java-source",
    "name": "Java",
    "pygments_lexer": "java",
-   "version": "12.0.2+10"
+   "version": "14.0.2+12"
   }
  },
  "nbformat": 4,
diff --git a/jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb b/jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb
index 3d4e07b46b8..57d97bcb8a2 100644
--- a/jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb
+++ b/jupyter/rank_classification_using_BERT_on_Amazon_Review.ipynb
@@ -205,11 +205,11 @@
     "}\n",
     "\n",
     "Criteria<NDList, NDList> criteria = Criteria.builder()\n",
-    "                .optApplication(Application.NLP.WORD_EMBEDDING)\n",
-    "                .setTypes(NDList.class, NDList.class)\n",
-    "                .optModelUrls(modelUrls)\n",
-    "                .optProgress(new ProgressBar())\n",
-    "                .build();\n",
+    "        .optApplication(Application.NLP.WORD_EMBEDDING)\n",
+    "        .setTypes(NDList.class, NDList.class)\n",
+    "        .optModelUrls(modelUrls)\n",
+    "        .optProgress(new ProgressBar())\n",
+    "        .build();\n",
     "ZooModel<NDList, NDList> embedding = ModelZoo.loadModel(criteria);"
    ]
   },
@@ -233,37 +233,37 @@
    "source": [
     "Predictor<NDList, NDList> embedder = embedding.newPredictor();\n",
     "Block classifier = new SequentialBlock()\n",
-    "                // text embedding layer\n",
-    "                .add(\n",
-    "                    ndList -> {\n",
-    "                        NDArray data = ndList.singletonOrThrow();\n",
-    "                        NDList inputs = new NDList();\n",
-    "                        long batchSize = data.getShape().get(0);\n",
-    "                        float maxLength = data.getShape().get(1);\n",
-    "                        \n",
-    "                        if (\"PyTorch\".equals(Engine.getInstance().getEngineName())) {\n",
-    "                            inputs.add(data.toType(DataType.INT64, false));\n",
-    "                            inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));\n",
-    "                            inputs.add(data.getManager().arange(maxLength)\n",
-    "                                    .toType(DataType.INT64, false)\n",
-    "                                    .broadcast(data.getShape()));\n",
-    "                        } else {\n",
-    "                            inputs.add(data);\n",
-    "                            inputs.add(data.getManager().full(new Shape(batchSize), maxLength));\n",
-    "                        }\n",
-    "                        // run embedding\n",
-    "                        try {\n",
-    "                            return embedder.predict(inputs);\n",
-    "                        } catch (TranslateException e) {\n",
-    "                            throw new IllegalArgumentException(\"embedding error\", e);\n",
-    "                        }\n",
-    "                    })\n",
-    "                // classification layer\n",
-    "                .add(Linear.builder().setUnits(768).build()) // pre classifier\n",
-    "                .add(Activation::relu)\n",
-    "                .add(Dropout.builder().optRate(0.2f).build())\n",
-    "                .add(Linear.builder().setUnits(5).build()) // 5 star rating\n",
-    "                .addSingleton(nd -> nd.get(\":,0\")); // Take [CLS] as the head\n",
+    "        // text embedding layer\n",
+    "        .add(\n",
+    "            ndList -> {\n",
+    "                NDArray data = ndList.singletonOrThrow();\n",
+    "                NDList inputs = new NDList();\n",
+    "                long batchSize = data.getShape().get(0);\n",
+    "                float maxLength = data.getShape().get(1);\n",
+    "\n",
+    "                if (\"PyTorch\".equals(Engine.getInstance().getEngineName())) {\n",
+    "                    inputs.add(data.toType(DataType.INT64, false));\n",
+    "                    inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));\n",
+    "                    inputs.add(data.getManager().arange(maxLength)\n",
+    "                            .toType(DataType.INT64, false)\n",
+    "                            .broadcast(data.getShape()));\n",
+    "                } else {\n",
+    "                    inputs.add(data);\n",
+    "                    inputs.add(data.getManager().full(new Shape(batchSize), maxLength));\n",
+    "                }\n",
+    "                // run embedding\n",
+    "                try {\n",
+    "                    return embedder.predict(inputs);\n",
+    "                } catch (TranslateException e) {\n",
+    "                    throw new IllegalArgumentException(\"embedding error\", e);\n",
+    "                }\n",
+    "            })\n",
+    "        // classification layer\n",
+    "        .add(Linear.builder().setUnits(768).build()) // pre classifier\n",
+    "        .add(Activation::relu)\n",
+    "        .add(Dropout.builder().optRate(0.2f).build())\n",
+    "        .add(Linear.builder().setUnits(5).build()) // 5 star rating\n",
+    "        .addSingleton(nd -> nd.get(\":,0\")); // Take [CLS] as the head\n",
     "Model model = Model.newInstance(\"AmazonReviewRatingClassification\");\n",
     "model.setBlock(classifier);"
    ]
@@ -291,10 +291,10 @@
    "source": [
     "// Prepare the vocabulary\n",
     "SimpleVocabulary vocabulary = SimpleVocabulary.builder()\n",
-    "                .optMinFrequency(1)\n",
-    "                .addFromTextFile(Paths.get(embedding.getArtifact(\"vocab.txt\").toURI()))\n",
-    "                .optUnknownToken(\"[UNK]\")\n",
-    "                .build();\n",
+    "        .optMinFrequency(1)\n",
+    "        .addFromTextFile(embedding.getArtifact(\"vocab.txt\"))\n",
+    "        .optUnknownToken(\"[UNK]\")\n",
+    "        .build();\n",
     "// Prepare dataset\n",
     "int maxTokenLength = 64; // cutoff tokens length\n",
     "int batchSize = 8;\n",
@@ -323,19 +323,19 @@
    "source": [
     "CheckpointsTrainingListener listener = new CheckpointsTrainingListener(\"build/model\");\n",
     "        listener.setSaveModelCallback(\n",
-    "            trainer -> {\n",
-    "                TrainingResult result = trainer.getTrainingResult();\n",
-    "                Model model = trainer.getModel();\n",
-    "                // track for accuracy and loss\n",
-    "                float accuracy = result.getValidateEvaluation(\"Accuracy\");\n",
-    "                model.setProperty(\"Accuracy\", String.format(\"%.5f\", accuracy));\n",
-    "                model.setProperty(\"Loss\", String.format(\"%.5f\", result.getValidateLoss()));\n",
-    "            });\n",
+    "        trainer -> {\n",
+    "            TrainingResult result = trainer.getTrainingResult();\n",
+    "            Model model = trainer.getModel();\n",
+    "            // track for accuracy and loss\n",
+    "            float accuracy = result.getValidateEvaluation(\"Accuracy\");\n",
+    "            model.setProperty(\"Accuracy\", String.format(\"%.5f\", accuracy));\n",
+    "            model.setProperty(\"Loss\", String.format(\"%.5f\", result.getValidateLoss()));\n",
+    "        });\n",
     "DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type\n",
-    "                .addEvaluator(new Accuracy())\n",
-    "                .optDevices(Device.getDevices(1)) // train using single GPU\n",
-    "                .addTrainingListeners(TrainingListener.Defaults.logging(\"build/model\"))\n",
-    "                .addTrainingListeners(listener);"
+    "        .addEvaluator(new Accuracy())\n",
+    "        .optDevices(Device.getDevices(1)) // train using single GPU\n",
+    "        .addTrainingListeners(TrainingListener.Defaults.logging(\"build/model\"))\n",
+    "        .addTrainingListeners(listener);"
    ]
   },
   {
@@ -460,7 +460,7 @@
    "mimetype": "text/x-java-source",
    "name": "Java",
    "pygments_lexer": "java",
-   "version": "11.0.5+10-LTS"
+   "version": "14.0.2+12"
   }
  },
  "nbformat": 4,
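
Note: the recurring change across all three notebooks is that SimpleVocabulary.Builder
now takes typed Path/URL arguments instead of String paths assembled with a manual
"file://" prefix. A minimal standalone sketch of the migrated calls, assuming the
ai.djl:api:0.10.0-SNAPSHOT builder API exercised in the hunks above (the class name,
the printed message, and the commented-out MXNet variant are illustrative, not part
of this patch):

    import java.io.IOException;
    import java.net.URL;
    import java.nio.file.Path;
    import java.nio.file.Paths;

    import ai.djl.modality.nlp.SimpleVocabulary;

    public class VocabLoadSketch {
        public static void main(String[] args) throws IOException {
            // 0.9.0 style: addFromTextFile(path.toString()) took a String.
            // 0.10.0-SNAPSHOT style, as in the hunks above: pass the Path itself.
            Path path = Paths.get("build/pytorch/bertqa/vocab.txt");
            SimpleVocabulary vocabulary = SimpleVocabulary.builder()
                    .optMinFrequency(1)
                    .addFromTextFile(path) // Path overload replaces the String overload
                    .optUnknownToken("[UNK]")
                    .build();

            // Custom vocab formats (the MXNet vocab.json) now take a URL plus a
            // parsing function instead of "file://" + path.toString():
            URL url = Paths.get("build/mxnet/bertqa/vocab.json").toUri().toURL();
            // SimpleVocabulary mxnetVocab = SimpleVocabulary.builder()
            //         .addFromCustomizedFile(url, VocabParser::parseToken) // VocabParser as defined in the notebook
            //         .optUnknownToken("[UNK]")
            //         .build();
            System.out.println("loaded vocabulary from " + path);
        }
    }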