From fae23448d21cd929cfc607eba2a44b4aba758728 Mon Sep 17 00:00:00 2001 From: Danilo Burbano <37355249+danilojsl@users.noreply.github.com> Date: Thu, 7 Sep 2023 11:29:35 -0500 Subject: [PATCH] Adding ONNX Support to ALBERT Token and Sequence Classification and Question Answering annotators (#13956) * SPARKNLP-891 Adding ONNX support for AlbertQuestionAnswering SPARKNLP-892 Adding ONNX support for AlbertSequenceClassification SPARKNLP-893 Adding ONNX support for AlbertTokenClassification * SPARKNLP-891 Adding ONNX support for AlbertQuestionAnswering SPARKNLP-892 Adding ONNX support for AlbertSequenceClassification SPARKNLP-893 Adding ONNX support for AlbertTokenClassification --- .../ml/ai/AlbertClassification.scala | 234 ++++++++++++------ .../ml/ai/XXXForClassification.scala | 10 +- .../ml/ai/ZeroShotNerClassification.scala | 3 +- .../dl/AlbertForQuestionAnswering.scala | 72 ++++-- .../dl/AlbertForSequenceClassification.scala | 72 ++++-- .../dl/AlbertForTokenClassification.scala | 71 ++++-- 6 files changed, 338 insertions(+), 124 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala index aa6b561b0f34f7..d66e299015ccdb 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala @@ -16,9 +16,13 @@ package com.johnsnowlabs.ml.ai +import ai.onnxruntime.OnnxTensor +import com.johnsnowlabs.ml.onnx.OnnxWrapper import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} +import com.johnsnowlabs.ml.util.LoadExternalModel.notSupportedEngineError +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{ActivationFunction, Annotation} import org.tensorflow.ndarray.buffer.IntDataBuffer @@ -37,7 +41,8 @@ import scala.collection.JavaConverters._ * TF v2 signatures in Spark NLP */ private[johnsnowlabs] class AlbertClassification( - val tensorflowWrapper: TensorflowWrapper, + val tensorflowWrapper: Option[TensorflowWrapper], + val onnxWrapper: Option[OnnxWrapper], val spp: SentencePieceWrapper, configProtoBytes: Option[Array[Byte]] = None, tags: Map[String, Int], @@ -48,6 +53,10 @@ private[johnsnowlabs] class AlbertClassification( val _tfAlbertSignatures: Map[String, String] = signatures.getOrElse(ModelSignatureManager.apply()) + val detectedEngine: String = + if (tensorflowWrapper.isDefined) TensorFlow.name + else if (onnxWrapper.isDefined) ONNX.name + else TensorFlow.name // keys representing the input and output tensors of the ALBERT model protected val sentencePadTokenId: Int = spp.getSppModel.pieceToId("[pad]") @@ -95,59 +104,13 @@ private[johnsnowlabs] class AlbertClassification( } def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { - val tensors = new TensorResources() - - val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max val batchLength = batch.length + val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max - val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength) - val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength) - val segmentBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength) - - // [nb of encoded sentences , maxSentenceLength] - val shape = Array(batch.length.toLong, maxSentenceLength) - - batch.zipWithIndex - .foreach { case (sentence, idx) => - val offset = idx * maxSentenceLength - tokenBuffers.offset(offset).write(sentence) - maskBuffers - .offset(offset) - .write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1)) - segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0)) - } - - val runner = tensorflowWrapper - .getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false) - .runner - - val tokenTensors = tensors.createIntBufferTensor(shape, tokenBuffers) - val maskTensors = tensors.createIntBufferTensor(shape, maskBuffers) - val segmentTensors = tensors.createIntBufferTensor(shape, segmentBuffers) - - runner - .feed( - _tfAlbertSignatures.getOrElse( - ModelSignatureConstants.InputIds.key, - "missing_input_id_key"), - tokenTensors) - .feed( - _tfAlbertSignatures - .getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"), - maskTensors) - .feed( - _tfAlbertSignatures - .getOrElse(ModelSignatureConstants.TokenTypeIds.key, "missing_segment_ids_key"), - segmentTensors) - .fetch(_tfAlbertSignatures - .getOrElse(ModelSignatureConstants.LogitsOutput.key, "missing_logits_key")) - - val outs = runner.run().asScala - val rawScores = TensorResources.extractFloats(outs.head) - - outs.foreach(_.close()) - tensors.clearSession(outs) - tensors.clearTensors() + val rawScores = detectedEngine match { + case ONNX.name => getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true) + case _ => getRawScoresWithTF(batch, maxSentenceLength) + } val dim = rawScores.length / (batchLength * maxSentenceLength) val batchScores: Array[Array[Array[Float]]] = rawScores @@ -161,17 +124,39 @@ private[johnsnowlabs] class AlbertClassification( } def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = { + val batchLength = batch.length + val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max + + val rawScores = detectedEngine match { + case ONNX.name => getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true) + case _ => getRawScoresWithTF(batch, maxSentenceLength) + } + + val dim = rawScores.length / batchLength + val batchScores: Array[Array[Float]] = + rawScores + .grouped(dim) + .map(scores => + activation match { + case ActivationFunction.softmax => calculateSoftmax(scores) + case ActivationFunction.sigmoid => calculateSigmoid(scores) + case _ => calculateSoftmax(scores) + }) + .toArray + + batchScores + } + + private def getRawScoresWithTF(batch: Seq[Array[Int]], maxSentenceLength: Int): Array[Float] = { val tensors = new TensorResources() - val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max val batchLength = batch.length - val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength) val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength) val segmentBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength) // [nb of encoded sentences , maxSentenceLength] - val shape = Array(batch.length.toLong, maxSentenceLength) + val shape = Array(batchLength.toLong, maxSentenceLength) batch.zipWithIndex .foreach { case (sentence, idx) => @@ -183,7 +168,7 @@ private[johnsnowlabs] class AlbertClassification( segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0)) } - val runner = tensorflowWrapper + val runner = tensorflowWrapper.get .getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false) .runner @@ -215,19 +200,51 @@ private[johnsnowlabs] class AlbertClassification( tensors.clearSession(outs) tensors.clearTensors() - val dim = rawScores.length / batchLength - val batchScores: Array[Array[Float]] = - rawScores - .grouped(dim) - .map(scores => - activation match { - case ActivationFunction.softmax => calculateSoftmax(scores) - case ActivationFunction.sigmoid => calculateSigmoid(scores) - case _ => calculateSoftmax(scores) - }) - .toArray + rawScores + } - batchScores + private def getRowScoresWithOnnx( + batch: Seq[Array[Int]], + maxSentenceLength: Int, + sequence: Boolean): Array[Float] = { + + val output = if (sequence) "logits" else "last_hidden_state" + + // [nb of encoded sentences , maxSentenceLength] + val (runner, env) = onnxWrapper.get.getSession() + + val tokenTensors = + OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) + val maskTensors = + OnnxTensor.createTensor( + env, + batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray) + + val segmentTensors = + OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray) + + val inputs = + Map( + "input_ids" -> tokenTensors, + "attention_mask" -> maskTensors, + "token_type_ids" -> segmentTensors).asJava + + try { + val results = runner.run(inputs) + try { + val embeddings = results + .get(output) + .get() + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + tokenTensors.close() + maskTensors.close() + segmentTensors.close() + + embeddings + } finally if (results != null) results.close() + } } def tagZeroShotSequence( @@ -237,10 +254,29 @@ private[johnsnowlabs] class AlbertClassification( activation: String): Array[Array[Float]] = ??? def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = { - val tensors = new TensorResources() - + val batchLength = batch.length val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max + val (startLogits, endLogits) = detectedEngine match { + case ONNX.name => computeLogitsWithOnnx(batch, maxSentenceLength) + case _ => computeLogitsWithTF(batch, maxSentenceLength) + } + + val endDim = endLogits.length / batchLength + val endScores: Array[Array[Float]] = + endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray + + val startDim = startLogits.length / batchLength + val startScores: Array[Array[Float]] = + startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray + + (startScores, endScores) + } + + private def computeLogitsWithTF( + batch: Seq[Array[Int]], + maxSentenceLength: Int): (Array[Float], Array[Float]) = { val batchLength = batch.length + val tensors = new TensorResources() val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength) val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength) @@ -271,7 +307,7 @@ private[johnsnowlabs] class AlbertClassification( }) } - val runner = tensorflowWrapper + val runner = tensorflowWrapper.get .getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false) .runner @@ -306,15 +342,55 @@ private[johnsnowlabs] class AlbertClassification( tensors.clearSession(outs) tensors.clearTensors() - val endDim = endLogits.length / batchLength - val endScores: Array[Array[Float]] = - endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray - - val startDim = startLogits.length / batchLength - val startScores: Array[Array[Float]] = - startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray + (endLogits, startLogits) + } - (startScores, endScores) + private def computeLogitsWithOnnx( + batch: Seq[Array[Int]], + maxSentenceLength: Int): (Array[Float], Array[Float]) = { + // [nb of encoded sentences , maxSentenceLength] + val (runner, env) = onnxWrapper.get.getSession() + + val tokenTensors = + OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) + val maskTensors = + OnnxTensor.createTensor( + env, + batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray) + + val segmentTensors = + OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray) + + val inputs = + Map( + "input_ids" -> tokenTensors, + "attention_mask" -> maskTensors, + "token_type_ids" -> segmentTensors).asJava + + try { + val output = runner.run(inputs) + try { + val startLogits = output + .get("start_logits") + .get() + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + + val endLogits = output + .get("end_logits") + .get() + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + + tokenTensors.close() + maskTensors.close() + segmentTensors.close() + + (startLogits.slice(1, startLogits.length), endLogits.slice(1, endLogits.length)) + } finally if (output != null) output.close() + } } def findIndexedToken( diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/XXXForClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/XXXForClassification.scala index b6e0e18863e819..919d6aa0d17c6e 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/XXXForClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/XXXForClassification.scala @@ -16,6 +16,7 @@ package com.johnsnowlabs.ml.ai +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{ActivationFunction, Annotation, AnnotatorType} @@ -244,7 +245,8 @@ private[johnsnowlabs] trait XXXForClassification { documents: Seq[Annotation], maxSentenceLength: Int, caseSensitive: Boolean, - mergeTokenStrategy: String = MergeTokenStrategy.vocab): Seq[Annotation] = { + mergeTokenStrategy: String = MergeTokenStrategy.vocab, + engine: String = TensorFlow.name): Seq[Annotation] = { val questionAnnot = Seq(documents.head) val contextAnnot = documents.drop(1) @@ -264,9 +266,13 @@ private[johnsnowlabs] trait XXXForClassification { val startIndex = startScores.zipWithIndex.maxBy(_._1) val endIndex = endScores.zipWithIndex.maxBy(_._1) + val offsetStartIndex = if (engine == TensorFlow.name) 2 else 1 + val offsetEndIndex = if (engine == TensorFlow.name) 1 else 0 + val allTokenPieces = wordPieceTokenizedQuestion.head.tokens ++ wordPieceTokenizedContext.flatMap(x => x.tokens) - val decodedAnswer = allTokenPieces.slice(startIndex._2 - 2, endIndex._2 - 1) + val decodedAnswer = + allTokenPieces.slice(startIndex._2 - offsetStartIndex, endIndex._2 - offsetEndIndex) val content = mergeTokenStrategy match { case MergeTokenStrategy.vocab => diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/ZeroShotNerClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/ZeroShotNerClassification.scala index 0553b1a94c3028..57a60fe26ea175 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/ZeroShotNerClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/ZeroShotNerClassification.scala @@ -62,7 +62,8 @@ private[johnsnowlabs] class ZeroShotNerClassification( documents: Seq[Annotation], maxSentenceLength: Int, caseSensitive: Boolean, - mergeTokenStrategy: String): Seq[Annotation] = { + mergeTokenStrategy: String, + engine: String): Seq[Annotation] = { val questionAnnot = Seq(documents.head) val contextAnnot = documents.drop(1) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala index 217fbc6ca25947..3c7e1347e70be1 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala @@ -17,6 +17,7 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl import com.johnsnowlabs.ml.ai.{AlbertClassification, MergeTokenStrategy} +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ ReadSentencePieceModel, @@ -28,8 +29,9 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.TensorFlow +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp._ +import com.johnsnowlabs.nlp.embeddings.BertEmbeddings import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param.{IntArrayParam, IntParam} @@ -116,6 +118,7 @@ class AlbertForQuestionAnswering(override val uid: String) extends AnnotatorModel[AlbertForQuestionAnswering] with HasBatchedAnnotate[AlbertForQuestionAnswering] with WriteTensorflowModel + with WriteOnnxModel with WriteSentencePieceModel with HasCaseSensitiveProperties with HasEngine { @@ -196,13 +199,15 @@ class AlbertForQuestionAnswering(override val uid: String) /** @group setParam */ def setModelIfNotSet( spark: SparkSession, - tensorflowWrapper: TensorflowWrapper, + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper], spp: SentencePieceWrapper): AlbertForQuestionAnswering = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( new AlbertClassification( tensorflowWrapper, + onnxWrapper, spp, configProtoBytes = getConfigProtoBytes, tags = Map.empty[String, Int], @@ -244,7 +249,8 @@ class AlbertForQuestionAnswering(override val uid: String) documents, $(maxSentenceLength), $(caseSensitive), - MergeTokenStrategy.sentencePiece) + MergeTokenStrategy.sentencePiece, + getEngine) } else { Seq.empty[Annotation] } @@ -253,13 +259,26 @@ class AlbertForQuestionAnswering(override val uid: String) override def onWrite(path: String, spark: SparkSession): Unit = { super.onWrite(path, spark) - writeTensorflowModelV2( - path, - spark, - getModelIfNotSet.tensorflowWrapper, - "_albert_classification", - AlbertForQuestionAnswering.tfFile, - configProtoBytes = getConfigProtoBytes) + val suffix = "_albert_classification" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + suffix, + AlbertForQuestionAnswering.tfFile, + configProtoBytes = getConfigProtoBytes) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + suffix, + AlbertForQuestionAnswering.onnxFile) + } + writeSentencePieceModel( path, spark, @@ -291,17 +310,37 @@ trait ReadablePretrainedAlbertForQAModel trait ReadAlbertForQuestionAnsweringDLModel extends ReadTensorflowModel + with ReadOnnxModel with ReadSentencePieceModel { this: ParamsAndFeaturesReadable[AlbertForQuestionAnswering] => override val tfFile: String = "albert_classification_tensorflow" + override val onnxFile: String = "albert_classification_onnx" override val sppFile: String = "albert_spp" def readModel(instance: AlbertForQuestionAnswering, path: String, spark: SparkSession): Unit = { - val tf = readTensorflowModel(path, spark, "_albert_classification_tf", initAllTables = false) val spp = readSentencePieceModel(path, spark, "_albert_spp", sppFile) - instance.setModelIfNotSet(spark, tf, spp) + + instance.getEngine match { + case TensorFlow.name => + val tf = + readTensorflowModel(path, spark, "_albert_classification_tf", initAllTables = false) + instance.setModelIfNotSet(spark, Some(tf), None, spp) + case ONNX.name => + val onnxWrapper = + readOnnxModel( + path, + spark, + "_albert_classification_onnx", + zipped = true, + useBundle = false, + None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) + case _ => + throw new Exception(notSupportedEngineError) + } + } addReader(readModel) @@ -318,7 +357,7 @@ trait ReadAlbertForQuestionAnsweringDLModel detectedEngine match { case TensorFlow.name => - val (wrapper, signatures) = + val (tfWrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) val _signatures = signatures match { @@ -331,7 +370,12 @@ trait ReadAlbertForQuestionAnsweringDLModel */ annotatorModel .setSignatures(_signatures) - .setModelIfNotSet(spark, wrapper, spModel) + .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) + + case ONNX.name => + val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => throw new Exception(notSupportedEngineError) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala index f0d61bcaade650..16b9e6c196e37d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala @@ -17,6 +17,7 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl import com.johnsnowlabs.ml.ai.AlbertClassification +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ ReadSentencePieceModel, @@ -29,7 +30,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.TensorFlow +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -124,6 +125,7 @@ class AlbertForSequenceClassification(override val uid: String) extends AnnotatorModel[AlbertForSequenceClassification] with HasBatchedAnnotate[AlbertForSequenceClassification] with WriteTensorflowModel + with WriteOnnxModel with WriteSentencePieceModel with HasCaseSensitiveProperties with HasClassifierActivationProperties @@ -239,13 +241,15 @@ class AlbertForSequenceClassification(override val uid: String) /** @group setParam */ def setModelIfNotSet( spark: SparkSession, - tensorflowWrapper: TensorflowWrapper, + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper], spp: SentencePieceWrapper): AlbertForSequenceClassification = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( new AlbertClassification( tensorflowWrapper, + onnxWrapper, spp, configProtoBytes = getConfigProtoBytes, tags = $$(labels), @@ -305,13 +309,26 @@ class AlbertForSequenceClassification(override val uid: String) override def onWrite(path: String, spark: SparkSession): Unit = { super.onWrite(path, spark) - writeTensorflowModelV2( - path, - spark, - getModelIfNotSet.tensorflowWrapper, - "_albert_classification", - AlbertForSequenceClassification.tfFile, - configProtoBytes = getConfigProtoBytes) + val suffix = "_albert_classification" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + "_albert_classification", + AlbertForSequenceClassification.tfFile, + configProtoBytes = getConfigProtoBytes) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + suffix, + AlbertForSequenceClassification.onnxFile) + } + writeSentencePieceModel( path, spark, @@ -341,10 +358,14 @@ trait ReadablePretrainedAlbertForSequenceModel super.pretrained(name, lang, remoteLoc) } -trait ReadAlbertForSequenceDLModel extends ReadTensorflowModel with ReadSentencePieceModel { +trait ReadAlbertForSequenceDLModel + extends ReadTensorflowModel + with ReadOnnxModel + with ReadSentencePieceModel { this: ParamsAndFeaturesReadable[AlbertForSequenceClassification] => override val tfFile: String = "albert_classification_tensorflow" + override val onnxFile: String = "albert_classification_onnx" override val sppFile: String = "albert_spp" def readModel( @@ -352,9 +373,27 @@ trait ReadAlbertForSequenceDLModel extends ReadTensorflowModel with ReadSentence path: String, spark: SparkSession): Unit = { - val tf = readTensorflowModel(path, spark, "_albert_classification_tf", initAllTables = false) val spp = readSentencePieceModel(path, spark, "_albert_spp", sppFile) - instance.setModelIfNotSet(spark, tf, spp) + + instance.getEngine match { + case TensorFlow.name => + val tf = + readTensorflowModel(path, spark, "_albert_classification_tf", initAllTables = false) + instance.setModelIfNotSet(spark, Some(tf), None, spp) + case ONNX.name => + val onnxWrapper = + readOnnxModel( + path, + spark, + "_albert_classification_onnx", + zipped = true, + useBundle = false, + None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) + case _ => + throw new Exception(notSupportedEngineError) + } + } addReader(readModel) @@ -373,7 +412,7 @@ trait ReadAlbertForSequenceDLModel extends ReadTensorflowModel with ReadSentence detectedEngine match { case TensorFlow.name => - val (wrapper, signatures) = + val (tfWrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) val _signatures = signatures match { @@ -386,7 +425,12 @@ trait ReadAlbertForSequenceDLModel extends ReadTensorflowModel with ReadSentence */ annotatorModel .setSignatures(_signatures) - .setModelIfNotSet(spark, wrapper, spModel) + .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) + + case ONNX.name => + val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => throw new Exception(notSupportedEngineError) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala index 89e61223d63097..8f91eb208ffc4b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala @@ -17,6 +17,7 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl import com.johnsnowlabs.ml.ai.AlbertClassification +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ ReadSentencePieceModel, @@ -29,7 +30,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, TensorFlow} +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -123,6 +124,7 @@ class AlbertForTokenClassification(override val uid: String) extends AnnotatorModel[AlbertForTokenClassification] with HasBatchedAnnotate[AlbertForTokenClassification] with WriteTensorflowModel + with WriteOnnxModel with WriteSentencePieceModel with HasCaseSensitiveProperties with HasEngine { @@ -217,13 +219,15 @@ class AlbertForTokenClassification(override val uid: String) /** @group setParam */ def setModelIfNotSet( spark: SparkSession, - tensorflowWrapper: TensorflowWrapper, + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper], spp: SentencePieceWrapper): AlbertForTokenClassification = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( new AlbertClassification( tensorflowWrapper, + onnxWrapper, spp, configProtoBytes = getConfigProtoBytes, tags = $$(labels), @@ -276,13 +280,26 @@ class AlbertForTokenClassification(override val uid: String) override def onWrite(path: String, spark: SparkSession): Unit = { super.onWrite(path, spark) - writeTensorflowModelV2( - path, - spark, - getModelIfNotSet.tensorflowWrapper, - "_albert_classification", - AlbertForTokenClassification.tfFile, - configProtoBytes = getConfigProtoBytes) + val suffix = "_albert_classification" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + suffix, + AlbertForTokenClassification.tfFile, + configProtoBytes = getConfigProtoBytes) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + suffix, + AlbertForTokenClassification.onnxFile) + } + writeSentencePieceModel( path, spark, @@ -312,10 +329,14 @@ trait ReadablePretrainedAlbertForTokenModel remoteLoc: String): AlbertForTokenClassification = super.pretrained(name, lang, remoteLoc) } -trait ReadAlbertForTokenDLModel extends ReadTensorflowModel with ReadSentencePieceModel { +trait ReadAlbertForTokenDLModel + extends ReadTensorflowModel + with ReadOnnxModel + with ReadSentencePieceModel { this: ParamsAndFeaturesReadable[AlbertForTokenClassification] => override val tfFile: String = "albert_classification_tensorflow" + override val onnxFile: String = "albert_classification_onnx" override val sppFile: String = "albert_spp" def readModel( @@ -323,9 +344,27 @@ trait ReadAlbertForTokenDLModel extends ReadTensorflowModel with ReadSentencePie path: String, spark: SparkSession): Unit = { - val tf = readTensorflowModel(path, spark, "_albert_classification_tf", initAllTables = false) val spp = readSentencePieceModel(path, spark, "_albert_spp", sppFile) - instance.setModelIfNotSet(spark, tf, spp) + + instance.getEngine match { + case TensorFlow.name => + val tf = + readTensorflowModel(path, spark, "_albert_classification_tf", initAllTables = false) + instance.setModelIfNotSet(spark, Some(tf), None, spp) + case ONNX.name => + val onnxWrapper = + readOnnxModel( + path, + spark, + "_albert_classification_onnx", + zipped = true, + useBundle = false, + None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) + case _ => + throw new Exception(notSupportedEngineError) + } + } addReader(readModel) @@ -344,7 +383,7 @@ trait ReadAlbertForTokenDLModel extends ReadTensorflowModel with ReadSentencePie detectedEngine match { case TensorFlow.name => - val (wrapper, signatures) = + val (tfWrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) val _signatures = signatures match { @@ -357,8 +396,12 @@ trait ReadAlbertForTokenDLModel extends ReadTensorflowModel with ReadSentencePie */ annotatorModel .setSignatures(_signatures) - .setModelIfNotSet(spark, wrapper, spModel) + .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) + case ONNX.name => + val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => throw new Exception(notSupportedEngineError) }