From b362e1aef804fcb98728651e6a0db3f9141d77a0 Mon Sep 17 00:00:00 2001 From: Danilo Burbano Date: Tue, 12 Sep 2023 21:47:50 -0500 Subject: [PATCH 1/3] SPARKNLP-907 Allows setting up ONNX configs through spark session --- .../scala/com/johnsnowlabs/ml/ai/Albert.scala | 14 +- .../ml/ai/AlbertClassification.scala | 33 +++-- .../ml/ai/BartClassification.scala | 14 +- .../scala/com/johnsnowlabs/ml/ai/Bert.scala | 27 ++-- .../ml/ai/BertClassification.scala | 14 +- .../com/johnsnowlabs/ml/ai/CamemBert.scala | 14 +- .../ml/ai/CamemBertClassification.scala | 14 +- .../com/johnsnowlabs/ml/ai/DeBerta.scala | 12 +- .../ml/ai/DeBertaClassification.scala | 14 +- .../com/johnsnowlabs/ml/ai/DistilBert.scala | 14 +- .../ml/ai/DistilBertClassification.scala | 14 +- .../scala/com/johnsnowlabs/ml/ai/E5.scala | 18 ++- .../scala/com/johnsnowlabs/ml/ai/MPNet.scala | 18 ++- .../com/johnsnowlabs/ml/ai/RoBerta.scala | 14 +- .../ml/ai/RoBertaClassification.scala | 14 +- .../com/johnsnowlabs/ml/ai/Whisper.scala | 14 +- .../ml/ai/XXXForClassification.scala | 27 ++-- .../ml/ai/XlmRoBertaClassification.scala | 14 +- .../com/johnsnowlabs/ml/ai/XlmRoberta.scala | 14 +- .../ml/ai/XlnetClassification.scala | 14 +- .../ml/ai/ZeroShotNerClassification.scala | 12 +- .../ml/onnx/OnnxSerializeModel.scala | 9 +- .../johnsnowlabs/ml/onnx/OnnxWrapper.scala | 128 +++++++++++++----- .../com/johnsnowlabs/nlp/AnnotatorModel.scala | 8 +- .../nlp/annotators/audio/WhisperForCTC.scala | 15 +- .../dl/AlbertForQuestionAnswering.scala | 25 ++-- .../dl/AlbertForSequenceClassification.scala | 25 ++-- .../dl/AlbertForTokenClassification.scala | 14 +- .../dl/BertForQuestionAnswering.scala | 5 +- .../dl/BertForSequenceClassification.scala | 5 +- .../dl/BertForTokenClassification.scala | 5 +- .../dl/CamemBertForQuestionAnswering.scala | 5 +- .../CamemBertForSequenceClassification.scala | 5 +- .../dl/CamemBertForTokenClassification.scala | 5 +- .../dl/DeBertaForQuestionAnswering.scala | 5 +- .../dl/DeBertaForSequenceClassification.scala | 3 +- .../dl/DeBertaForTokenClassification.scala | 3 +- .../dl/DistilBertForQuestionAnswering.scala | 3 +- .../DistilBertForSequenceClassification.scala | 3 +- .../dl/DistilBertForTokenClassification.scala | 3 +- .../dl/LongformerForQuestionAnswering.scala | 3 +- .../LongformerForSequenceClassification.scala | 3 +- .../dl/LongformerForTokenClassification.scala | 3 +- .../dl/RoBertaForQuestionAnswering.scala | 3 +- .../dl/RoBertaForSequenceClassification.scala | 3 +- .../dl/RoBertaForTokenClassification.scala | 3 +- .../dl/XlmRoBertaForQuestionAnswering.scala | 3 +- .../XlmRoBertaForSequenceClassification.scala | 3 +- .../dl/XlmRoBertaForTokenClassification.scala | 3 +- .../dl/XlnetForSequenceClassification.scala | 3 +- .../dl/XlnetForTokenClassification.scala | 3 +- .../nlp/embeddings/AlbertEmbeddings.scala | 23 ++-- .../nlp/embeddings/BertEmbeddings.scala | 11 +- .../embeddings/BertSentenceEmbeddings.scala | 12 +- .../nlp/embeddings/CamemBertEmbeddings.scala | 11 +- .../nlp/embeddings/DeBertaEmbeddings.scala | 11 +- .../nlp/embeddings/DistilBertEmbeddings.scala | 11 +- .../nlp/embeddings/Doc2VecModel.scala | 7 - .../nlp/embeddings/E5Embeddings.scala | 11 +- .../nlp/embeddings/LongformerEmbeddings.scala | 3 +- .../nlp/embeddings/MPNetEmbeddings.scala | 11 +- .../nlp/embeddings/RoBertaEmbeddings.scala | 11 +- .../nlp/embeddings/Word2VecModel.scala | 7 - .../nlp/embeddings/XlmRoBertaEmbeddings.scala | 14 +- 64 files changed, 524 insertions(+), 269 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Albert.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Albert.scala index bd4846945dc4bc..0ce9625c3c6675 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/Albert.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Albert.scala @@ -24,6 +24,7 @@ import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignat import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ +import org.apache.spark.sql.SparkSession import scala.collection.JavaConverters._ @@ -93,12 +94,14 @@ private[johnsnowlabs] class Albert( private def sessionWarmup(): Unit = { val dummyInput = Array(101, 2292, 1005, 1055, 4010, 6279, 1996, 5219, 2005, 1996, 2034, 28937, 1012, 102) - tag(Seq(dummyInput)) + tag(Seq(dummyInput), None) } sessionWarmup() - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max val batchLength = batch.length @@ -107,7 +110,7 @@ private[johnsnowlabs] class Albert( case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(sparkSession) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) @@ -204,7 +207,8 @@ private[johnsnowlabs] class Albert( tokenizedSentences: Seq[TokenizedSentence], batchSize: Int, maxSentenceLength: Int, - caseSensitive: Boolean): Seq[WordpieceEmbeddingsSentence] = { + caseSensitive: Boolean, + sparkSession: Option[SparkSession]): Seq[WordpieceEmbeddingsSentence] = { val wordPieceTokenizedSentences = tokenizeWithAlignment(tokenizedSentences, maxSentenceLength, caseSensitive) @@ -217,7 +221,7 @@ private[johnsnowlabs] class Albert( SentenceStartTokenId, SentenceEndTokenId, SentencePadTokenId) - val vectors = tag(encoded) + val vectors = tag(encoded, sparkSession) /*Combine tokens and calculated embeddings*/ batch.zip(vectors).map { case (sentence, tokenVectors) => diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala index d66e299015ccdb..c3cbf9a7d119c6 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala @@ -21,10 +21,10 @@ import com.johnsnowlabs.ml.onnx.OnnxWrapper import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} -import com.johnsnowlabs.ml.util.LoadExternalModel.notSupportedEngineError import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{ActivationFunction, Annotation} +import org.apache.spark.sql.SparkSession import org.tensorflow.ndarray.buffer.IntDataBuffer import scala.collection.JavaConverters._ @@ -103,12 +103,15 @@ private[johnsnowlabs] class AlbertClassification( sentenceTokenPieces } - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val batchLength = batch.length val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max val rawScores = detectedEngine match { - case ONNX.name => getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true) + case ONNX.name => + getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true, sparkSession) case _ => getRawScoresWithTF(batch, maxSentenceLength) } @@ -123,12 +126,16 @@ private[johnsnowlabs] class AlbertClassification( batchScores } - def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = { + def tagSequence( + batch: Seq[Array[Int]], + activation: String, + sparkSession: Option[SparkSession]): Array[Array[Float]] = { val batchLength = batch.length val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max val rawScores = detectedEngine match { - case ONNX.name => getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true) + case ONNX.name => + getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true, sparkSession) case _ => getRawScoresWithTF(batch, maxSentenceLength) } @@ -206,12 +213,13 @@ private[johnsnowlabs] class AlbertClassification( private def getRowScoresWithOnnx( batch: Seq[Array[Int]], maxSentenceLength: Int, - sequence: Boolean): Array[Float] = { + sequence: Boolean, + sparkSession: Option[SparkSession]): Array[Float] = { val output = if (sequence) "logits" else "last_hidden_state" // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(sparkSession) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) @@ -253,11 +261,13 @@ private[johnsnowlabs] class AlbertClassification( contradictionId: Int, activation: String): Array[Array[Float]] = ??? - def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = { + def tagSpan( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]]) = { val batchLength = batch.length val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max val (startLogits, endLogits) = detectedEngine match { - case ONNX.name => computeLogitsWithOnnx(batch, maxSentenceLength) + case ONNX.name => computeLogitsWithOnnx(batch, maxSentenceLength, sparkSession) case _ => computeLogitsWithTF(batch, maxSentenceLength) } @@ -347,9 +357,10 @@ private[johnsnowlabs] class AlbertClassification( private def computeLogitsWithOnnx( batch: Seq[Array[Int]], - maxSentenceLength: Int): (Array[Float], Array[Float]) = { + maxSentenceLength: Int, + sparkSession: Option[SparkSession]): (Array[Float], Array[Float]) = { // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(sparkSession) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/BartClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/BartClassification.scala index 9a3f612aa02eb3..c73f75fffa69fb 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/BartClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/BartClassification.scala @@ -22,6 +22,7 @@ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.BpeTokenizer import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} import com.johnsnowlabs.nlp.{ActivationFunction, Annotation} +import org.apache.spark.sql.SparkSession import org.tensorflow.ndarray.buffer.IntDataBuffer import scala.collection.JavaConverters._ @@ -127,7 +128,9 @@ private[johnsnowlabs] class BartClassification( } - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -185,7 +188,10 @@ private[johnsnowlabs] class BartClassification( batchScores } - def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = { + def tagSequence( + batch: Seq[Array[Int]], + activation: String, + sparkSession: Option[SparkSession]): Array[Array[Float]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -317,7 +323,9 @@ private[johnsnowlabs] class BartClassification( .toArray } - def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = { + def tagSpan( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]]) = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala index c291b9f23c549d..384c8826556441 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala @@ -24,6 +24,7 @@ import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} +import org.apache.spark.sql.SparkSession import scala.collection.JavaConverters._ @@ -71,18 +72,20 @@ private[johnsnowlabs] class Bert( val dummyInput = Array(101, 2292, 1005, 1055, 4010, 6279, 1996, 5219, 2005, 1996, 2034, 28937, 1012, 102) if (modelArch == ModelArch.wordEmbeddings) { - tag(Seq(dummyInput)) + tag(Seq(dummyInput), None) } else if (modelArch == ModelArch.sentenceEmbeddings) { if (isSBert) tagSequenceSBert(Seq(dummyInput)) else - tagSequence(Seq(dummyInput)) + tagSequence(Seq(dummyInput), None) } } sessionWarmup() - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max val batchLength = batch.length @@ -90,7 +93,7 @@ private[johnsnowlabs] class Bert( case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(sparkSession) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) @@ -185,7 +188,9 @@ private[johnsnowlabs] class Bert( } - def tagSequence(batch: Seq[Array[Int]]): Array[Array[Float]] = { + def tagSequence( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Array[Array[Float]] = { val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max val batchLength = batch.length @@ -193,7 +198,7 @@ private[johnsnowlabs] class Bert( val embeddings = detectedEngine match { case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(sparkSession) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) @@ -346,7 +351,8 @@ private[johnsnowlabs] class Bert( originalTokenSentences: Seq[TokenizedSentence], batchSize: Int, maxSentenceLength: Int, - caseSensitive: Boolean): Seq[WordpieceEmbeddingsSentence] = { + caseSensitive: Boolean, + sparkSession: Option[SparkSession]): Seq[WordpieceEmbeddingsSentence] = { /*Run embeddings calculation by batches*/ sentences.zipWithIndex @@ -357,7 +363,7 @@ private[johnsnowlabs] class Bert( maxSentenceLength, sentenceStartTokenId, sentenceEndTokenId) - val vectors = tag(encoded) + val vectors = tag(encoded, sparkSession) /*Combine tokens and calculated embeddings*/ batch.zip(vectors).map { case (sentence, tokenVectors) => @@ -408,7 +414,8 @@ private[johnsnowlabs] class Bert( sentences: Seq[Sentence], batchSize: Int, maxSentenceLength: Int, - isLong: Boolean = false): Seq[Annotation] = { + isLong: Boolean = false, + sparkSession: Option[SparkSession]): Seq[Annotation] = { /*Run embeddings calculation by batches*/ tokens @@ -426,7 +433,7 @@ private[johnsnowlabs] class Bert( val embeddings = if (isLong) { tagSequenceSBert(encoded) } else { - tagSequence(encoded) + tagSequence(encoded, sparkSession) } sentencesBatch.zip(embeddings).map { case (sentence, vectors) => diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala index 63babcc94e019b..0159d613d12437 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala @@ -21,6 +21,7 @@ import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} import com.johnsnowlabs.nlp.{ActivationFunction, Annotation} +import org.apache.spark.sql.SparkSession import org.tensorflow.ndarray.buffer.IntDataBuffer import scala.collection.JavaConverters._ @@ -134,7 +135,9 @@ private[johnsnowlabs] class BertClassification( } } - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -198,7 +201,10 @@ private[johnsnowlabs] class BertClassification( batchScores } - def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = { + def tagSequence( + batch: Seq[Array[Int]], + activation: String, + sparkSession: Option[SparkSession]): Array[Array[Float]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -338,7 +344,9 @@ private[johnsnowlabs] class BertClassification( .toArray } - def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = { + def tagSpan( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]]) = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/CamemBert.scala b/src/main/scala/com/johnsnowlabs/ml/ai/CamemBert.scala index eb1d421b70ce0b..ac61c2cbe9425e 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/CamemBert.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/CamemBert.scala @@ -24,6 +24,7 @@ import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignat import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ +import org.apache.spark.sql.SparkSession import scala.collection.JavaConverters._ @@ -68,12 +69,14 @@ private[johnsnowlabs] class CamemBert( private def sessionWarmup(): Unit = { val dummyInput = Array(5, 54, 110, 11, 10, 15540, 215, 1280, 808, 25352, 1782, 808, 24696, 378, 17409, 9, 6) - tag(Seq(dummyInput)) + tag(Seq(dummyInput), None) } sessionWarmup() - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max val batchLength = batch.length @@ -82,7 +85,7 @@ private[johnsnowlabs] class CamemBert( case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(sparkSession) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) @@ -167,7 +170,8 @@ private[johnsnowlabs] class CamemBert( tokenizedSentences: Seq[TokenizedSentence], batchSize: Int, maxSentenceLength: Int, - caseSensitive: Boolean): Seq[WordpieceEmbeddingsSentence] = { + caseSensitive: Boolean, + sparkSession: Option[SparkSession]): Seq[WordpieceEmbeddingsSentence] = { val wordPieceTokenizedSentences = tokenizeWithAlignment(tokenizedSentences, maxSentenceLength, caseSensitive) @@ -180,7 +184,7 @@ private[johnsnowlabs] class CamemBert( SentenceStartTokenId, SentenceEndTokenId, SentencePadTokenId) - val vectors = tag(encoded) + val vectors = tag(encoded, sparkSession) /*Combine tokens and calculated embeddings*/ batch.zip(vectors).map { case (sentence, tokenVectors) => diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/CamemBertClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/CamemBertClassification.scala index 5d85697eb3f9e2..3973f603431705 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/CamemBertClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/CamemBertClassification.scala @@ -23,6 +23,7 @@ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{ActivationFunction, Annotation} import org.apache.spark.ml.param.Params import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.SparkSession import org.tensorflow.ndarray.buffer.{IntDataBuffer, LongDataBuffer} import scala.collection.JavaConverters._ @@ -111,7 +112,9 @@ private[johnsnowlabs] class CamemBertClassification( sentenceTokenPieces } - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -169,7 +172,10 @@ private[johnsnowlabs] class CamemBertClassification( batchScores } - def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = { + def tagSequence( + batch: Seq[Array[Int]], + activation: String, + sparkSession: Option[SparkSession]): Array[Array[Float]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -238,7 +244,9 @@ private[johnsnowlabs] class CamemBertClassification( contradictionId: Int, activation: String): Array[Array[Float]] = ??? - def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = { + def tagSpan( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]]) = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala b/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala index bbf4ac83b1862b..963809cfa37ad5 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala @@ -24,6 +24,7 @@ import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignat import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ +import org.apache.spark.sql.SparkSession import scala.collection.JavaConverters._ @@ -59,7 +60,9 @@ class DeBerta( private val SentencePadTokenId = spp.getSppModel.pieceToId("[PAD]") private val SentencePieceDelimiterId = spp.getSppModel.pieceToId("▁") - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { /* Actual size of each sentence to skip padding in the TF model */ val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max val batchLength = batch.length @@ -68,7 +71,7 @@ class DeBerta( case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(sparkSession) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) @@ -162,7 +165,8 @@ class DeBerta( tokenizedSentences: Seq[TokenizedSentence], batchSize: Int, maxSentenceLength: Int, - caseSensitive: Boolean): Seq[WordpieceEmbeddingsSentence] = { + caseSensitive: Boolean, + sparkSession: Option[SparkSession]): Seq[WordpieceEmbeddingsSentence] = { val wordPieceTokenizedSentences = tokenizeWithAlignment(tokenizedSentences, maxSentenceLength, caseSensitive) @@ -175,7 +179,7 @@ class DeBerta( SentenceStartTokenId, SentenceEndTokenId, SentencePadTokenId) - val vectors = tag(encoded) + val vectors = tag(encoded, sparkSession) /*Combine tokens and calculated embeddings*/ batch.zip(vectors).map { case (sentence, tokenVectors) => diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/DeBertaClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/DeBertaClassification.scala index 5022105f47d588..7a8a8801719d4c 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/DeBertaClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/DeBertaClassification.scala @@ -21,6 +21,7 @@ import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignat import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{ActivationFunction, Annotation} +import org.apache.spark.sql.SparkSession import org.tensorflow.ndarray.buffer.IntDataBuffer import scala.collection.JavaConverters._ @@ -94,7 +95,9 @@ private[johnsnowlabs] class DeBertaClassification( sentenceTokenPieces } - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -160,7 +163,10 @@ private[johnsnowlabs] class DeBertaClassification( batchScores } - def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = { + def tagSequence( + batch: Seq[Array[Int]], + activation: String, + sparkSession: Option[SparkSession]): Array[Array[Float]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -236,7 +242,9 @@ private[johnsnowlabs] class DeBertaClassification( contradictionId: Int, activation: String): Array[Array[Float]] = ??? - def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = { + def tagSpan( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]]) = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/DistilBert.scala b/src/main/scala/com/johnsnowlabs/ml/ai/DistilBert.scala index afa6a3b8bb29d5..95857b13650941 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/DistilBert.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/DistilBert.scala @@ -24,6 +24,7 @@ import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} +import org.apache.spark.sql.SparkSession import scala.collection.JavaConverters._ @@ -87,7 +88,7 @@ private[johnsnowlabs] class DistilBert( val dummyInput = Array(101, 2292, 1005, 1055, 4010, 6279, 1996, 5219, 2005, 1996, 2034, 28937, 1012, 102) if (modelArch == ModelArch.wordEmbeddings) { - tag(Seq(dummyInput)) + tag(Seq(dummyInput), None) } else if (modelArch == ModelArch.sentenceEmbeddings) { tagSequence(Seq(dummyInput)) } @@ -95,7 +96,9 @@ private[johnsnowlabs] class DistilBert( sessionWarmup() - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max val batchLength = batch.length @@ -103,7 +106,7 @@ private[johnsnowlabs] class DistilBert( val embeddings = detectedEngine match { case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(sparkSession) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) @@ -239,7 +242,8 @@ private[johnsnowlabs] class DistilBert( originalTokenSentences: Seq[TokenizedSentence], batchSize: Int, maxSentenceLength: Int, - caseSensitive: Boolean): Seq[WordpieceEmbeddingsSentence] = { + caseSensitive: Boolean, + sparkSession: Option[SparkSession]): Seq[WordpieceEmbeddingsSentence] = { /*Run embeddings calculation by batches*/ sentences.zipWithIndex @@ -250,7 +254,7 @@ private[johnsnowlabs] class DistilBert( maxSentenceLength, sentenceStartTokenId, sentenceEndTokenId) - val vectors = tag(encoded) + val vectors = tag(encoded, sparkSession) /*Combine tokens and calculated embeddings*/ batch.zip(vectors).map { case (sentence, tokenVectors) => diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/DistilBertClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/DistilBertClassification.scala index 813a83fc6b08a2..c0b5c0ae0a2d37 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/DistilBertClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/DistilBertClassification.scala @@ -21,6 +21,7 @@ import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} import com.johnsnowlabs.nlp.{ActivationFunction, Annotation} +import org.apache.spark.sql.SparkSession import org.tensorflow.ndarray.buffer.IntDataBuffer import scala.collection.JavaConverters._ @@ -134,7 +135,9 @@ private[johnsnowlabs] class DistilBertClassification( } } - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -192,7 +195,10 @@ private[johnsnowlabs] class DistilBertClassification( batchScores } - def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = { + def tagSequence( + batch: Seq[Array[Int]], + activation: String, + sparkSession: Option[SparkSession]): Array[Array[Float]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -323,7 +329,9 @@ private[johnsnowlabs] class DistilBertClassification( .grouped(dim) .toArray } - def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = { + def tagSpan( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]]) = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/E5.scala b/src/main/scala/com/johnsnowlabs/ml/ai/E5.scala index 6ce563e30ded88..f34968c6086b4c 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/E5.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/E5.scala @@ -23,6 +23,7 @@ import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} +import org.apache.spark.sql.SparkSession import scala.collection.JavaConverters._ @@ -62,10 +63,12 @@ private[johnsnowlabs] class E5( * @return * sentence embeddings */ - private def getSentenceEmbedding(batch: Seq[Array[Int]]): Array[Array[Float]] = { + private def getSentenceEmbedding( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Array[Array[Float]] = { val embeddings = detectedEngine match { case ONNX.name => - getSentenceEmbeddingFromOnnx(batch) + getSentenceEmbeddingFromOnnx(batch, sparkSession) case _ => getSentenceEmbeddingFromTF(batch) } @@ -147,11 +150,13 @@ private[johnsnowlabs] class E5( sentenceEmbeddingsFloatsArray } - private def getSentenceEmbeddingFromOnnx(batch: Seq[Array[Int]]): Array[Array[Float]] = { + private def getSentenceEmbeddingFromOnnx( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Array[Array[Float]] = { val batchLength = batch.length val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(sparkSession) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) val maskTensors = @@ -205,7 +210,8 @@ private[johnsnowlabs] class E5( sentences: Seq[Annotation], tokenizedSentences: Seq[WordpieceTokenizedSentence], batchSize: Int, - maxSentenceLength: Int): Seq[Annotation] = { + maxSentenceLength: Int, + sparkSession: Option[SparkSession]): Seq[Annotation] = { tokenizedSentences .zip(sentences) @@ -218,7 +224,7 @@ private[johnsnowlabs] class E5( Array(sentenceStartTokenId) ++ x .map(y => y.pieceId) .take(maxSentenceLength - 2) ++ Array(sentenceEndTokenId)) - val sentenceEmbeddings = getSentenceEmbedding(tokens) + val sentenceEmbeddings = getSentenceEmbedding(tokens, sparkSession) batch.zip(sentenceEmbeddings).map { case (sentence, vectors) => Annotation( diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/MPNet.scala b/src/main/scala/com/johnsnowlabs/ml/ai/MPNet.scala index 3efa6b1acaa92b..08b5b17147c6f5 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/MPNet.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/MPNet.scala @@ -23,6 +23,7 @@ import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} +import org.apache.spark.sql.SparkSession import scala.collection.JavaConverters._ @@ -63,10 +64,12 @@ private[johnsnowlabs] class MPNet( * @return * sentence embeddings */ - private def getSentenceEmbedding(batch: Seq[Array[Int]]): Array[Array[Float]] = { + private def getSentenceEmbedding( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Array[Array[Float]] = { val embeddings = detectedEngine match { case ONNX.name => - getSentenceEmbeddingFromOnnx(batch) + getSentenceEmbeddingFromOnnx(batch, sparkSession) case _ => getSentenceEmbeddingFromTF(batch) } @@ -154,11 +157,13 @@ private[johnsnowlabs] class MPNet( sentenceEmbeddingsFloatsArray } - private def getSentenceEmbeddingFromOnnx(batch: Seq[Array[Int]]): Array[Array[Float]] = { + private def getSentenceEmbeddingFromOnnx( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Array[Array[Float]] = { val batchLength = batch.length val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(sparkSession) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) val maskTensors = @@ -209,7 +214,8 @@ private[johnsnowlabs] class MPNet( sentences: Seq[Annotation], tokenizedSentences: Seq[WordpieceTokenizedSentence], batchSize: Int, - maxSentenceLength: Int): Seq[Annotation] = { + maxSentenceLength: Int, + sparkSession: Option[SparkSession]): Seq[Annotation] = { tokenizedSentences .zip(sentences) @@ -222,7 +228,7 @@ private[johnsnowlabs] class MPNet( Array(sentenceStartTokenId) ++ x .map(y => y.pieceId) .take(maxSentenceLength - 2) ++ Array(sentenceEndTokenId)) - val sentenceEmbeddings = getSentenceEmbedding(tokens) + val sentenceEmbeddings = getSentenceEmbedding(tokens, sparkSession) batch.zip(sentenceEmbeddings).map { case (sentence, vectors) => Annotation( diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/RoBerta.scala b/src/main/scala/com/johnsnowlabs/ml/ai/RoBerta.scala index 1e903ff0d4a345..db0d918b6fe2da 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/RoBerta.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/RoBerta.scala @@ -24,6 +24,7 @@ import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} +import org.apache.spark.sql.SparkSession import scala.collection.JavaConverters._ @@ -62,7 +63,7 @@ private[johnsnowlabs] class RoBerta( val dummyInput = Array(0, 7939, 18, 3279, 658, 5, 19374, 13, 5, 78, 42752, 4, 2) if (modelArch == ModelArch.wordEmbeddings) { - tag(Seq(dummyInput)) + tag(Seq(dummyInput), None) } else if (modelArch == ModelArch.sentenceEmbeddings) { tagSequence(Seq(dummyInput)) } @@ -70,7 +71,9 @@ private[johnsnowlabs] class RoBerta( sessionWarmup() - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max val batchLength = batch.length @@ -79,7 +82,7 @@ private[johnsnowlabs] class RoBerta( case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(sparkSession) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) @@ -216,7 +219,8 @@ private[johnsnowlabs] class RoBerta( originalTokenSentences: Seq[TokenizedSentence], batchSize: Int, maxSentenceLength: Int, - caseSensitive: Boolean): Seq[WordpieceEmbeddingsSentence] = { + caseSensitive: Boolean, + sparkSession: Option[SparkSession]): Seq[WordpieceEmbeddingsSentence] = { /*Run embeddings calculation by batches*/ sentences.zipWithIndex @@ -228,7 +232,7 @@ private[johnsnowlabs] class RoBerta( sentenceStartTokenId, sentenceEndTokenId, padTokenId) - val vectors = tag(encoded) + val vectors = tag(encoded, sparkSession) /*Combine tokens and calculated embeddings*/ batch.zip(vectors).map { case (sentence, tokenVectors) => diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala index 3e80bedef517b1..d53c06bd4f2567 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala @@ -22,6 +22,7 @@ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.BpeTokenizer import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} import com.johnsnowlabs.nlp.{ActivationFunction, Annotation} +import org.apache.spark.sql.SparkSession import org.tensorflow.ndarray.buffer.IntDataBuffer import scala.collection.JavaConverters._ @@ -127,7 +128,9 @@ private[johnsnowlabs] class RoBertaClassification( } - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -185,7 +188,10 @@ private[johnsnowlabs] class RoBertaClassification( batchScores } - def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = { + def tagSequence( + batch: Seq[Array[Int]], + activation: String, + sparkSession: Option[SparkSession]): Array[Array[Float]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -319,7 +325,9 @@ private[johnsnowlabs] class RoBertaClassification( .toArray } - def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = { + def tagSpan( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]]) = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala index 8cf9559da229e2..9e5b53cf00f316 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Whisper.scala @@ -34,6 +34,7 @@ import com.johnsnowlabs.ml.util._ import com.johnsnowlabs.nlp.annotators.audio.feature_extractor.WhisperPreprocessor import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.{SpecialTokens, WhisperTokenDecoder} import com.johnsnowlabs.nlp.{Annotation, AnnotationAudio, AnnotatorType} +import org.apache.spark.sql.SparkSession import org.slf4j.LoggerFactory import org.tensorflow.{Session, Tensor} @@ -171,7 +172,8 @@ private[johnsnowlabs] class Whisper( topP = 1.0, repetitionPenalty = 1.0, noRepeatNgramSize = 0, - randomSeed = None) + randomSeed = None, + sparkSession = None) } sessionWarmup() @@ -268,7 +270,8 @@ private[johnsnowlabs] class Whisper( noRepeatNgramSize: Int, randomSeed: Option[Long], task: Option[String] = None, - language: Option[String] = None): Seq[Annotation] = { + language: Option[String] = None, + sparkSession: Option[SparkSession]): Seq[Annotation] = { if (beamSize > 1) logger.warn( @@ -332,9 +335,10 @@ private[johnsnowlabs] class Whisper( tokenIds case ONNX.name => - val (encoderSession, env) = onnxWrappers.get.encoder.getSession() - val decoderSession = onnxWrappers.get.decoder.getSession()._1 - val decoderWithPastSession = onnxWrappers.get.decoderWithPast.getSession()._1 + val (encoderSession, env) = onnxWrappers.get.encoder.getSession(sparkSession) + val decoderSession = onnxWrappers.get.decoder.getSession(sparkSession)._1 + val decoderWithPastSession = + onnxWrappers.get.decoderWithPast.getSession(sparkSession)._1 val encodedBatchTensor: OnnxTensor = encode(featuresBatch, None, Some((encoderSession, env))).asInstanceOf[OnnxTensor] diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/XXXForClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/XXXForClassification.scala index 919d6aa0d17c6e..d7fc852c6ce470 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/XXXForClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/XXXForClassification.scala @@ -19,6 +19,7 @@ package com.johnsnowlabs.ml.ai import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{ActivationFunction, Annotation, AnnotatorType} +import org.apache.spark.sql.SparkSession private[johnsnowlabs] trait XXXForClassification { @@ -32,7 +33,8 @@ private[johnsnowlabs] trait XXXForClassification { batchSize: Int, maxSentenceLength: Int, caseSensitive: Boolean, - tags: Map[String, Int]): Seq[Annotation] = { + tags: Map[String, Int], + sparkSession: Option[SparkSession]): Seq[Annotation] = { val wordPieceTokenizedSentences = tokenizeWithAlignment(tokenizedSentences, maxSentenceLength, caseSensitive) @@ -42,7 +44,7 @@ private[johnsnowlabs] trait XXXForClassification { .grouped(batchSize) .flatMap { batch => val encoded = encode(batch, maxSentenceLength) - val logits = tag(encoded) + val logits = tag(encoded, sparkSession) /*Combine tokens and calculated logits*/ batch.zip(logits).flatMap { case (sentence, tokenVectors) => @@ -71,7 +73,8 @@ private[johnsnowlabs] trait XXXForClassification { caseSensitive: Boolean, coalesceSentences: Boolean = false, tags: Map[String, Int], - activation: String = ActivationFunction.softmax): Seq[Annotation] = { + activation: String = ActivationFunction.softmax, + sparkSession: Option[SparkSession]): Seq[Annotation] = { val wordPieceTokenizedSentences = tokenizeWithAlignment(tokenizedSentences, maxSentenceLength, caseSensitive) @@ -84,7 +87,7 @@ private[johnsnowlabs] trait XXXForClassification { .flatMap { batch => val tokensBatch = batch.map(x => (x._1._1, x._2)) val encoded = encode(tokensBatch, maxSentenceLength) - val logits = tagSequence(encoded, activation) + val logits = tagSequence(encoded, activation, sparkSession) activation match { case ActivationFunction.softmax => if (coalesceSentences) { @@ -246,7 +249,8 @@ private[johnsnowlabs] trait XXXForClassification { maxSentenceLength: Int, caseSensitive: Boolean, mergeTokenStrategy: String = MergeTokenStrategy.vocab, - engine: String = TensorFlow.name): Seq[Annotation] = { + engine: String = TensorFlow.name, + sparkSession: Option[SparkSession]): Seq[Annotation] = { val questionAnnot = Seq(documents.head) val contextAnnot = documents.drop(1) @@ -258,7 +262,7 @@ private[johnsnowlabs] trait XXXForClassification { val encodedInput = encodeSequence(wordPieceTokenizedQuestion, wordPieceTokenizedContext, maxSentenceLength) - val (startLogits, endLogits) = tagSpan(encodedInput) + val (startLogits, endLogits) = tagSpan(encodedInput, sparkSession) val startScores = startLogits.transpose.map(_.sum / startLogits.length) val endScores = endLogits.transpose.map(_.sum / endLogits.length) @@ -362,9 +366,12 @@ private[johnsnowlabs] trait XXXForClassification { Seq(Array(sentenceStartTokenId) ++ question ++ context) } - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] + def tag(batch: Seq[Array[Int]], sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] - def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] + def tagSequence( + batch: Seq[Array[Int]], + activation: String, + sparkSession: Option[SparkSession]): Array[Array[Float]] def tagZeroShotSequence( batch: Seq[Array[Int]], @@ -372,7 +379,9 @@ private[johnsnowlabs] trait XXXForClassification { contradictionId: Int, activation: String): Array[Array[Float]] - def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) + def tagSpan( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]]) /** Calculate softmax from returned logits * @param scores diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/XlmRoBertaClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/XlmRoBertaClassification.scala index bddf0da0bbd368..390346a9878622 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/XlmRoBertaClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/XlmRoBertaClassification.scala @@ -22,6 +22,7 @@ import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} import com.johnsnowlabs.nlp.{ActivationFunction, Annotation} +import org.apache.spark.sql.SparkSession import org.tensorflow.ndarray.buffer.IntDataBuffer import scala.collection.JavaConverters._ @@ -112,7 +113,9 @@ private[johnsnowlabs] class XlmRoBertaClassification( sentenceTokenPieces } - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -170,7 +173,10 @@ private[johnsnowlabs] class XlmRoBertaClassification( batchScores } - def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = { + def tagSequence( + batch: Seq[Array[Int]], + activation: String, + sparkSession: Option[SparkSession]): Array[Array[Float]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -302,7 +308,9 @@ private[johnsnowlabs] class XlmRoBertaClassification( .toArray } - def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = { + def tagSpan( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]]) = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/XlmRoberta.scala b/src/main/scala/com/johnsnowlabs/ml/ai/XlmRoberta.scala index 18448b1de90e90..8102bf37d35a49 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/XlmRoberta.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/XlmRoberta.scala @@ -25,6 +25,7 @@ import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} +import org.apache.spark.sql.SparkSession import scala.collection.JavaConverters._ @@ -98,7 +99,7 @@ private[johnsnowlabs] class XlmRoberta( val dummyInput = Array(0, 10842, 25, 7, 24814, 2037, 70, 148735, 100, 70, 5117, 53498, 6620, 5, 2) if (modelArch == ModelArch.wordEmbeddings) { - tag(Seq(dummyInput)) + tag(Seq(dummyInput), None) } else if (modelArch == ModelArch.sentenceEmbeddings) { tagSequence(Seq(dummyInput)) } @@ -106,7 +107,9 @@ private[johnsnowlabs] class XlmRoberta( sessionWarmup() - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max val batchLength = batch.length @@ -115,7 +118,7 @@ private[johnsnowlabs] class XlmRoberta( case ONNX.name => // [nb of encoded sentences , maxSentenceLength] - val (runner, env) = onnxWrapper.get.getSession() + val (runner, env) = onnxWrapper.get.getSession(sparkSession) val tokenTensors = OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) @@ -244,7 +247,8 @@ private[johnsnowlabs] class XlmRoberta( def predict( tokenizedSentences: Seq[TokenizedSentence], batchSize: Int, - maxSentenceLength: Int): Seq[WordpieceEmbeddingsSentence] = { + maxSentenceLength: Int, + sparkSession: Option[SparkSession]): Seq[WordpieceEmbeddingsSentence] = { val wordPieceTokenizedSentences = tokenizeWithAlignment(tokenizedSentences, maxSentenceLength) wordPieceTokenizedSentences.zipWithIndex @@ -256,7 +260,7 @@ private[johnsnowlabs] class XlmRoberta( SentenceStartTokenId, SentenceEndTokenId, SentencePadTokenId) - val vectors = tag(encoded) + val vectors = tag(encoded, sparkSession) /*Combine tokens and calculated embeddings*/ batch.zip(vectors).map { case (sentence, tokenVectors) => diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/XlnetClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/XlnetClassification.scala index 243861418fbdc4..9d8f7b0c65c667 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/XlnetClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/XlnetClassification.scala @@ -21,6 +21,7 @@ import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignat import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{ActivationFunction, Annotation} +import org.apache.spark.sql.SparkSession import org.tensorflow.ndarray.buffer.IntDataBuffer import scala.collection.JavaConverters._ @@ -86,7 +87,9 @@ private[johnsnowlabs] class XlnetClassification( Seq.empty[WordpieceTokenizedSentence] } - def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { + def tag( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -155,7 +158,10 @@ private[johnsnowlabs] class XlnetClassification( batchScores } - def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = { + def tagSequence( + batch: Seq[Array[Int]], + activation: String, + sparkSession: Option[SparkSession]): Array[Array[Float]] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -234,7 +240,9 @@ private[johnsnowlabs] class XlnetClassification( contradictionId: Int, activation: String): Array[Array[Float]] = ??? - def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = { + def tagSpan( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]]) = { (Array.empty[Array[Float]], Array.empty[Array[Float]]) } diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/ZeroShotNerClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/ZeroShotNerClassification.scala index 57a60fe26ea175..89dea3565d5e88 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/ZeroShotNerClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/ZeroShotNerClassification.scala @@ -18,6 +18,7 @@ package com.johnsnowlabs.ml.ai import com.johnsnowlabs.ml.tensorflow.TensorflowWrapper import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} +import org.apache.spark.sql.SparkSession private[johnsnowlabs] class ZeroShotNerClassification( override val tensorflowWrapper: TensorflowWrapper, @@ -41,8 +42,10 @@ private[johnsnowlabs] class ZeroShotNerClassification( merges, vocabulary) { - override def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = { - val (startLogits, endLogits) = super.tagSpan(batch) + override def tagSpan( + batch: Seq[Array[Int]], + sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]]) = { + val (startLogits, endLogits) = super.tagSpan(batch, sparkSession) val contextStartOffsets = batch.map(_.indexOf(sentenceEndTokenId)) ( @@ -63,7 +66,8 @@ private[johnsnowlabs] class ZeroShotNerClassification( maxSentenceLength: Int, caseSensitive: Boolean, mergeTokenStrategy: String, - engine: String): Seq[Annotation] = { + engine: String, + sparkSession: Option[SparkSession]): Seq[Annotation] = { val questionAnnot = Seq(documents.head) val contextAnnot = documents.drop(1) @@ -74,7 +78,7 @@ private[johnsnowlabs] class ZeroShotNerClassification( val encodedInput = encodeSequence(wordPieceTokenizedQuestion, wordPieceTokenizedContext, maxSentenceLength) - val (startLogits, endLogits) = tagSpan(encodedInput) + val (startLogits, endLogits) = tagSpan(encodedInput, sparkSession) val startScores = startLogits.map(x => x.map(y => y / x.sum)).head val endScores = endLogits.map(x => x.map(y => y / x.sum)).head diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala index 85578509a90869..e9485afee4aecb 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala @@ -76,8 +76,7 @@ trait ReadOnnxModel { spark: SparkSession, suffix: String, zipped: Boolean = true, - useBundle: Boolean = false, - sessionOptions: Option[SessionOptions] = None): OnnxWrapper = { + useBundle: Boolean = false): OnnxWrapper = { val uri = new java.net.URI(path.replaceAllLiterally("\\", "/")) val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration) @@ -98,7 +97,7 @@ trait ReadOnnxModel { localPath, zipped = zipped, useBundle = useBundle, - sessionOptions = sessionOptions) + sparkSession = Some(spark)) // 4. Remove tmp folder FileHelper.delete(tmpFolder) @@ -113,7 +112,7 @@ trait ReadOnnxModel { suffix: String, zipped: Boolean = true, useBundle: Boolean = false, - sessionOptions: Option[SessionOptions] = None): Map[String, OnnxWrapper] = { + sparkSession: Option[SparkSession]): Map[String, OnnxWrapper] = { val uri = new java.net.URI(path.replaceAllLiterally("\\", "/")) val fs = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration) @@ -136,7 +135,7 @@ trait ReadOnnxModel { localPath, zipped = zipped, useBundle = useBundle, - sessionOptions = sessionOptions) + sparkSession = sparkSession) (modelName, onnxWrapper) }).toMap diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala index 1a4b9b925b033f..316a5a07b6e168 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala @@ -22,11 +22,13 @@ import ai.onnxruntime.providers.OrtCUDAProviderOptions import ai.onnxruntime.{OrtEnvironment, OrtSession} import com.johnsnowlabs.util.{FileHelper, ZipArchiveUtil} import org.apache.commons.io.FileUtils +import org.apache.spark.sql.SparkSession import org.slf4j.{Logger, LoggerFactory} import java.io._ import java.nio.file.{Files, Paths} import java.util.UUID +import scala.util.{Failure, Success, Try} class OnnxWrapper(var onnxModel: Array[Byte]) extends Serializable { @@ -40,10 +42,10 @@ class OnnxWrapper(var onnxModel: Array[Byte]) extends Serializable { @transient private var m_env: OrtEnvironment = _ @transient private val logger = LoggerFactory.getLogger("OnnxWrapper") - def getSession(sessionOptions: Option[SessionOptions] = None): (OrtSession, OrtEnvironment) = + def getSession(sparkSession: Option[SparkSession]): (OrtSession, OrtEnvironment) = this.synchronized { if (m_session == null && m_env == null) { - val (session, env) = OnnxWrapper.withSafeOnnxModelLoader(onnxModel, sessionOptions) + val (session, env) = OnnxWrapper.withSafeOnnxModelLoader(onnxModel, sparkSession) m_env = env m_session = session } @@ -80,40 +82,22 @@ object OnnxWrapper { // TODO: make sure this.synchronized is needed or it's not a bottleneck private def withSafeOnnxModelLoader( onnxModel: Array[Byte], - sessionOptions: Option[SessionOptions] = None): (OrtSession, OrtEnvironment) = + sparkSession: Option[SparkSession]): (OrtSession, OrtEnvironment) = this.synchronized { val env = OrtEnvironment.getEnvironment() - - val opts = - if (sessionOptions.isDefined) sessionOptions.get else new OrtSession.SessionOptions() - val providers = OrtEnvironment.getAvailableProviders - if (providers.toArray.map(x => x.toString).contains("CUDA")) { - logger.info("using CUDA") - // it seems there is no easy way to use multiple GPUs - // at least not without using multiple threads - // TODO: add support for multiple GPUs - // TODO: allow user to specify which GPU to use - val gpuDeviceId = 0 // The GPU device ID to execute on - val cudaOpts = new OrtCUDAProviderOptions(gpuDeviceId) - // TODO: incorporate other cuda-related configs - // cudaOpts.add("gpu_mem_limit", "" + (512 * 1024 * 1024)) - // sessOptions.addCUDA(gpuDeviceId) - opts.addCUDA(cudaOpts) + val sessionOptions = if (providers.toArray.map(x => x.toString).contains("CUDA")) { + getCUDASessionConfig(sparkSession) } else { - logger.info("using CPUs") - // TODO: the following configs can be tested for performance - // However, so far, they seem to be slower than the ones used - // opts.setIntraOpNumThreads(Runtime.getRuntime.availableProcessors()) - // opts.setMemoryPatternOptimization(true) - // opts.setCPUArenaAllocator(false) - opts.setIntraOpNumThreads(6) - opts.setOptimizationLevel(OptLevel.ALL_OPT) - opts.setExecutionMode(ExecutionMode.SEQUENTIAL) + getCPUSessionConfig(sparkSession) + } + + sessionOptions.getConfigEntries.forEach { case (key, value) => + println(s"config: $key, value: $value") } - val session = env.createSession(onnxModel, opts) + val session = env.createSession(onnxModel, sessionOptions) (session, env) } @@ -122,7 +106,7 @@ object OnnxWrapper { zipped: Boolean = true, useBundle: Boolean = false, modelName: String = "model", - sessionOptions: Option[SessionOptions] = None): OnnxWrapper = { + sparkSession: Option[SparkSession]): OnnxWrapper = { // 1. Create tmp folder val tmpFolder = Files @@ -143,13 +127,13 @@ object OnnxWrapper { val onnxFile = Paths.get(modelPath, s"$modelName.onnx").toString val modelFile = new File(onnxFile) val modelBytes = FileUtils.readFileToByteArray(modelFile) - val (session, env) = withSafeOnnxModelLoader(modelBytes, sessionOptions) + val (session, env) = withSafeOnnxModelLoader(modelBytes, sparkSession) (session, env, modelBytes) } else { val modelFile = new File(folder).list().head val fullPath = Paths.get(folder, modelFile).toFile val modelBytes = FileUtils.readFileToByteArray(fullPath) - val (session, env) = withSafeOnnxModelLoader(modelBytes, sessionOptions) + val (session, env) = withSafeOnnxModelLoader(modelBytes, sparkSession) (session, env, modelBytes) } @@ -162,6 +146,86 @@ object OnnxWrapper { onnxWrapper } + private def getCUDASessionConfig(sparkSession: Option[SparkSession]): SessionOptions = { + + logger.info("Using CUDA") + // it seems there is no easy way to use multiple GPUs + // at least not without using multiple threads + // TODO: add support for multiple GPUs + // TODO: allow user to specify which GPU to use + var gpuDeviceId = 0 // The GPU device ID to execute on + + if (sparkSession.isDefined) { + gpuDeviceId = sparkSession.get.conf.get("spark.jsl.settings.onnx.gpuDeviceId", "0").toInt + } + + val sessionOptions = new OrtSession.SessionOptions() + val cudaOpts = new OrtCUDAProviderOptions(gpuDeviceId) + sessionOptions.addCUDA(cudaOpts) + + sessionOptions + } + + private def getCPUSessionConfig(sparkSession: Option[SparkSession]): SessionOptions = { + + val defaultIntraOpNumThreads = 6 + val defaultExecutionMode = ExecutionMode.SEQUENTIAL + val defaultOptLevel = OptLevel.ALL_OPT + + def getOptLevel(optLevel: String): OptLevel = { + Try(OptLevel.valueOf(optLevel)) match { + case Success(value) => value + case Failure(_) => { + logger.warn( + s"Error while getting OptLevel, using default value: ${defaultOptLevel.name()}") + defaultOptLevel + } + } + } + + def getExecutionMode(executionMode: String): ExecutionMode = { + Try(ExecutionMode.valueOf(executionMode)) match { + case Success(value) => value + case Failure(_) => { + logger.warn( + s"Error while getting Execution Mode, using default value: ${defaultExecutionMode.name()}") + defaultExecutionMode + } + } + } + + logger.info("Using CPUs") + // TODO: the following configs can be tested for performance + // However, so far, they seem to be slower than the ones used + // opts.setIntraOpNumThreads(Runtime.getRuntime.availableProcessors()) + // opts.setMemoryPatternOptimization(true) + // opts.setCPUArenaAllocator(false) + var intraOpNumThreads = defaultIntraOpNumThreads + var optimizationLevel = defaultOptLevel + var executionMode = defaultExecutionMode + + if (sparkSession.isDefined) { + intraOpNumThreads = sparkSession.get.conf + .get("spark.jsl.settings.onnx.intraOpNumThreads", defaultIntraOpNumThreads.toString) + .toInt + + optimizationLevel = getOptLevel( + sparkSession.get.conf + .get("spark.jsl.settings.onnx.optimizationLevel", defaultOptLevel.toString)) + + executionMode = getExecutionMode( + sparkSession.get.conf + .get("spark.jsl.settings.onnx.executionMode", defaultExecutionMode.toString)) + } + + val sessionOptions = new OrtSession.SessionOptions() + sessionOptions.setIntraOpNumThreads(intraOpNumThreads) + sessionOptions.setOptimizationLevel(optimizationLevel) + sessionOptions.setExecutionMode(executionMode) + + sessionOptions + } + case class EncoderDecoderWrappers( encoder: OnnxWrapper, decoder: OnnxWrapper, diff --git a/src/main/scala/com/johnsnowlabs/nlp/AnnotatorModel.scala b/src/main/scala/com/johnsnowlabs/nlp/AnnotatorModel.scala index 3698ad0167a94f..bad3579cb3df4c 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/AnnotatorModel.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/AnnotatorModel.scala @@ -19,7 +19,7 @@ package com.johnsnowlabs.nlp import org.apache.spark.ml.{Model, PipelineModel} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} /** This trait implements logic that applies nlp using Spark ML Pipeline transformers Should * strongly change once UsedDefinedTypes are allowed @@ -31,8 +31,12 @@ abstract class AnnotatorModel[M <: Model[M]] extends RawAnnotator[M] with CanBeL * UserDefinedTypes to @developerAPI */ protected type AnnotationContent = Seq[Row] + protected var sparkSession: Option[SparkSession] = None - protected def beforeAnnotate(dataset: Dataset[_]): Dataset[_] = dataset + protected def beforeAnnotate(dataset: Dataset[_]): Dataset[_] = { + sparkSession = Some(dataset.sparkSession) + dataset + } protected def afterAnnotate(dataset: DataFrame): DataFrame = dataset diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala index 36ef72c34449dd..c83eb81805249e 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala @@ -392,7 +392,8 @@ class WhisperForCTC(override val uid: String) noRepeatNgramSize = getNoRepeatNgramSize, randomSeed = getRandomSeed, task = getTask, - language = getLanguage) + language = getLanguage, + sparkSession = sparkSession) } else Seq.empty } } @@ -441,7 +442,8 @@ trait ReadWhisperForCTCDLModel extends ReadTensorflowModel with ReadOnnxModel { path, spark, Seq("encoder_model", "decoder_model", "decoder_with_past_model"), - WhisperForCTC.suffix) + WhisperForCTC.suffix, + sparkSession = Some(spark)) val onnxWrappers = EncoderDecoderWrappers( wrappers("encoder_model"), @@ -573,21 +575,24 @@ trait ReadWhisperForCTCDLModel extends ReadTensorflowModel with ReadOnnxModel { modelPath, zipped = false, useBundle = true, - modelName = "encoder_model") + modelName = "encoder_model", + Some(spark)) val onnxWrapperDecoder = OnnxWrapper.read( modelPath, zipped = false, useBundle = true, - modelName = "decoder_model") + modelName = "decoder_model", + Some(spark)) val onnxWrapperDecoderWithPast = OnnxWrapper.read( modelPath, zipped = false, useBundle = true, - modelName = "decoder_with_past_model") + modelName = "decoder_with_past_model", + Some(spark)) val onnxWrappers = EncoderDecoderWrappers( onnxWrapperEncoder, diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala index 3c7e1347e70be1..1b824370f366a8 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala @@ -19,19 +19,10 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl import com.johnsnowlabs.ml.ai.{AlbertClassification, MergeTokenStrategy} import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ -import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ - ReadSentencePieceModel, - SentencePieceWrapper, - WriteSentencePieceModel -} -import com.johnsnowlabs.ml.util.LoadExternalModel.{ - loadSentencePieceAsset, - modelSanityCheck, - notSupportedEngineError -} +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ReadSentencePieceModel, SentencePieceWrapper, WriteSentencePieceModel} +import com.johnsnowlabs.ml.util.LoadExternalModel.{loadSentencePieceAsset, modelSanityCheck, notSupportedEngineError} import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp._ -import com.johnsnowlabs.nlp.embeddings.BertEmbeddings import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param.{IntArrayParam, IntParam} @@ -250,7 +241,8 @@ class AlbertForQuestionAnswering(override val uid: String) $(maxSentenceLength), $(caseSensitive), MergeTokenStrategy.sentencePiece, - getEngine) + getEngine, + sparkSession) } else { Seq.empty[Annotation] } @@ -334,8 +326,7 @@ trait ReadAlbertForQuestionAnsweringDLModel spark, "_albert_classification_onnx", zipped = true, - useBundle = false, - None) + useBundle = false) instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) case _ => throw new Exception(notSupportedEngineError) @@ -373,7 +364,11 @@ trait ReadAlbertForQuestionAnsweringDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + sparkSession = Some(spark)) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala index 16b9e6c196e37d..616d594b02794b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala @@ -19,17 +19,8 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl import com.johnsnowlabs.ml.ai.AlbertClassification import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ -import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ - ReadSentencePieceModel, - SentencePieceWrapper, - WriteSentencePieceModel -} -import com.johnsnowlabs.ml.util.LoadExternalModel.{ - loadSentencePieceAsset, - loadTextAsset, - modelSanityCheck, - notSupportedEngineError -} +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ReadSentencePieceModel, SentencePieceWrapper, WriteSentencePieceModel} +import com.johnsnowlabs.ml.util.LoadExternalModel.{loadSentencePieceAsset, loadTextAsset, modelSanityCheck, notSupportedEngineError} import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ @@ -300,7 +291,8 @@ class AlbertForSequenceClassification(override val uid: String) $(caseSensitive), $(coalesceSentences), $$(labels), - $(activation)) + $(activation), + sparkSession) } else { Seq.empty[Annotation] } @@ -387,8 +379,7 @@ trait ReadAlbertForSequenceDLModel spark, "_albert_classification_onnx", zipped = true, - useBundle = false, - None) + useBundle = false) instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) case _ => throw new Exception(notSupportedEngineError) @@ -428,7 +419,11 @@ trait ReadAlbertForSequenceDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + sparkSession = Some(spark)) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala index 8f91eb208ffc4b..4f7daaa9acd968 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala @@ -37,7 +37,7 @@ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param.{IntArrayParam, IntParam} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Dataset, SparkSession} /** AlbertForTokenClassification can load ALBERT Models with a token classification head on top (a * linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) @@ -271,7 +271,8 @@ class AlbertForTokenClassification(override val uid: String) $(batchSize), $(maxSentenceLength), $(caseSensitive), - $$(labels)) + $$(labels), + sparkSession) }) else { Seq(Seq.empty[Annotation]) @@ -358,8 +359,7 @@ trait ReadAlbertForTokenDLModel spark, "_albert_classification_onnx", zipped = true, - useBundle = false, - None) + useBundle = false) instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) case _ => throw new Exception(notSupportedEngineError) @@ -399,7 +399,11 @@ trait ReadAlbertForTokenDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + sparkSession = Some(spark)) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala index d48b40dcb65c08..f48802c77babc0 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param.{IntArrayParam, IntParam} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Dataset, SparkSession} /** BertForQuestionAnswering can load Bert Models with a span classification head on top for * extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states @@ -260,7 +260,8 @@ class BertForQuestionAnswering(override val uid: String) documents, $(maxSentenceLength), $(caseSensitive), - MergeTokenStrategy.vocab) + MergeTokenStrategy.vocab, + sparkSession = sparkSession) } else { Seq.empty[Annotation] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala index ff0bb3aeb4676a..74a81b5e777d5b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala @@ -32,7 +32,7 @@ import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs, ResourceHelper} import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param.{BooleanParam, IntArrayParam, IntParam} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Dataset, SparkSession} import java.io.File @@ -313,7 +313,8 @@ class BertForSequenceClassification(override val uid: String) $(caseSensitive), $(coalesceSentences), $$(labels), - $(activation)) + $(activation), + sparkSession) } else { Seq.empty[Annotation] diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala index 0c287de7d2cd64..8effc4606e5f2d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala @@ -30,7 +30,7 @@ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param.{IntArrayParam, IntParam} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Dataset, SparkSession} /** BertForTokenClassification can load Bert Models with a token classification head on top (a * linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) @@ -280,7 +280,8 @@ class BertForTokenClassification(override val uid: String) $(batchSize), $(maxSentenceLength), $(caseSensitive), - $$(labels)) + $$(labels), + sparkSession) }) else { Seq(Seq.empty[Annotation]) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala index e55e6adf4b6cb5..8a43dd890ed09e 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala @@ -34,7 +34,7 @@ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param.{IntArrayParam, IntParam} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Dataset, SparkSession} /** CamemBertForQuestionAnswering can load CamemBERT Models with a span classification head on top * for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states @@ -244,7 +244,8 @@ class CamemBertForQuestionAnswering(override val uid: String) documents, $(maxSentenceLength), $(caseSensitive), - MergeTokenStrategy.sentencePiece) + MergeTokenStrategy.sentencePiece, + sparkSession = sparkSession) } else { Seq.empty[Annotation] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala index 9519af01f8a7ac..50e8644a2c0c20 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala @@ -36,7 +36,7 @@ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param.{BooleanParam, IntArrayParam, IntParam} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Dataset, SparkSession} /** CamemBertForSequenceClassification can load CamemBERT Models with sequence * classification/regression head on top (a linear layer on top of the pooled output) e.g. for @@ -296,7 +296,8 @@ class CamemBertForSequenceClassification(override val uid: String) $(caseSensitive), $(coalesceSentences), $$(labels), - $(activation)) + $(activation), + sparkSession) } else { Seq.empty[Annotation] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala index 275cd4bba61238..85d4aa965dc97e 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala @@ -36,7 +36,7 @@ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param.{IntArrayParam, IntParam} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Dataset, SparkSession} /** CamemBertForTokenClassification can load CamemBERT Models with a token classification head on * top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition @@ -267,7 +267,8 @@ class CamemBertForTokenClassification(override val uid: String) $(batchSize), $(maxSentenceLength), $(caseSensitive), - $$(labels)) + $$(labels), + sparkSession) }) else { Seq(Seq.empty[Annotation]) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala index 600b85da999a6d..3a96cf3034b40d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala @@ -34,7 +34,7 @@ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param.{IntArrayParam, IntParam} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Dataset, SparkSession} /** DeBertaForQuestionAnswering can load DeBERTa Models with a span classification head on top for * extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states @@ -244,7 +244,8 @@ class DeBertaForQuestionAnswering(override val uid: String) documents, $(maxSentenceLength), $(caseSensitive), - MergeTokenStrategy.sentencePiece) + MergeTokenStrategy.sentencePiece, + sparkSession = sparkSession) } else { Seq.empty[Annotation] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala index 0f025ebca7c367..48199a16f7f1b0 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala @@ -296,7 +296,8 @@ class DeBertaForSequenceClassification(override val uid: String) $(caseSensitive), $(coalesceSentences), $$(labels), - $(activation)) + $(activation), + sparkSession) } else { Seq.empty[Annotation] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala index 81b3fdff7def4b..bbd4f724b49d6e 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala @@ -268,7 +268,8 @@ class DeBertaForTokenClassification(override val uid: String) $(batchSize), $(maxSentenceLength), $(caseSensitive), - $$(labels)) + $$(labels), + sparkSession) }) else { Seq(Seq.empty[Annotation]) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala index be3709d19b6279..4babc41a8a6e5f 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala @@ -257,7 +257,8 @@ class DistilBertForQuestionAnswering(override val uid: String) documents, $(maxSentenceLength), $(caseSensitive), - MergeTokenStrategy.vocab) + MergeTokenStrategy.vocab, + sparkSession = sparkSession) } else { Seq.empty[Annotation] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala index aee25f66d01640..557018bfde0ec9 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala @@ -309,7 +309,8 @@ class DistilBertForSequenceClassification(override val uid: String) $(caseSensitive), $(coalesceSentences), $$(labels), - $(activation)) + $(activation), + sparkSession) } else { Seq.empty[Annotation] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala index 20616a8303e7fc..757c5ac1f5beae 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala @@ -280,7 +280,8 @@ class DistilBertForTokenClassification(override val uid: String) $(batchSize), $(maxSentenceLength), $(caseSensitive), - $$(labels)) + $$(labels), + sparkSession) }) else { Seq(Seq.empty[Annotation]) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala index 453b8ac7e2cb17..f2421ebc13e349 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala @@ -270,7 +270,8 @@ class LongformerForQuestionAnswering(override val uid: String) documents, $(maxSentenceLength), $(caseSensitive), - MergeTokenStrategy.vocab) + MergeTokenStrategy.vocab, + sparkSession = sparkSession) } else { Seq.empty[Annotation] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala index 6dd293f033515c..8ee0c0f110ae69 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala @@ -322,7 +322,8 @@ class LongformerForSequenceClassification(override val uid: String) $(caseSensitive), $(coalesceSentences), $$(labels), - $(activation)) + $(activation), + sparkSession) } else { Seq.empty[Annotation] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala index 176fea3d1e19f2..ea4b2f52606e61 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala @@ -293,7 +293,8 @@ class LongformerForTokenClassification(override val uid: String) $(batchSize), $(maxSentenceLength), $(caseSensitive), - $$(labels)) + $$(labels), + sparkSession) }) else { Seq(Seq.empty[Annotation]) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala index 35bd006fc4a3e5..b70d46ce5b047c 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala @@ -270,7 +270,8 @@ class RoBertaForQuestionAnswering(override val uid: String) documents, $(maxSentenceLength), $(caseSensitive), - MergeTokenStrategy.vocab) + MergeTokenStrategy.vocab, + sparkSession = sparkSession) } else { Seq.empty[Annotation] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala index 5e4b268af48f0d..52d977d1ea9edb 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala @@ -322,7 +322,8 @@ class RoBertaForSequenceClassification(override val uid: String) $(caseSensitive), $(coalesceSentences), $$(labels), - $(activation)) + $(activation), + sparkSession) } else { Seq.empty[Annotation] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala index 742306621bd376..4827a796dd5c01 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala @@ -293,7 +293,8 @@ class RoBertaForTokenClassification(override val uid: String) $(batchSize), $(maxSentenceLength), $(caseSensitive), - $$(labels)) + $$(labels), + sparkSession) }) else { Seq(Seq.empty[Annotation]) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala index 01920477d5a672..4dd30eb7fc9580 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala @@ -244,7 +244,8 @@ class XlmRoBertaForQuestionAnswering(override val uid: String) documents, $(maxSentenceLength), $(caseSensitive), - MergeTokenStrategy.sentencePiece) + MergeTokenStrategy.sentencePiece, + sparkSession = sparkSession) } else { Seq.empty[Annotation] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala index add55d9270b8be..6eb6eda1c7667f 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala @@ -295,7 +295,8 @@ class XlmRoBertaForSequenceClassification(override val uid: String) $(caseSensitive), $(coalesceSentences), $$(labels), - $(activation)) + $(activation), + sparkSession) } else { Seq.empty[Annotation] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala index ded252b097d481..7aa31b3669112c 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala @@ -267,7 +267,8 @@ class XlmRoBertaForTokenClassification(override val uid: String) $(batchSize), $(maxSentenceLength), $(caseSensitive), - $$(labels)) + $$(labels), + sparkSession) }) else { Seq(Seq.empty[Annotation]) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala index b9e786c4a869fb..9a281acf1df09a 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala @@ -296,7 +296,8 @@ class XlnetForSequenceClassification(override val uid: String) $(caseSensitive), $(coalesceSentences), $$(labels), - $(activation)) + $(activation), + sparkSession) } else { Seq.empty[Annotation] } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala index 43b1e4dcd46103..7825cdc81f7f5c 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala @@ -267,7 +267,8 @@ class XlnetForTokenClassification(override val uid: String) $(batchSize), $(maxSentenceLength), $(caseSensitive), - $$(labels)) + $$(labels), + sparkSession) }) else { Seq(Seq.empty[Annotation]) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala index ddb2f45e17b82d..057cc4efea874c 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala @@ -19,16 +19,8 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.ml.ai.Albert import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ -import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ - ReadSentencePieceModel, - SentencePieceWrapper, - WriteSentencePieceModel -} -import com.johnsnowlabs.ml.util.LoadExternalModel.{ - loadSentencePieceAsset, - modelSanityCheck, - notSupportedEngineError -} +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ReadSentencePieceModel, SentencePieceWrapper, WriteSentencePieceModel} +import com.johnsnowlabs.ml.util.LoadExternalModel.{loadSentencePieceAsset, modelSanityCheck, notSupportedEngineError} import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ @@ -312,7 +304,8 @@ class AlbertEmbeddings(override val uid: String) sentencesWithRow.map(_._1), $(batchSize), $(maxSentenceLength), - $(caseSensitive)) + $(caseSensitive), + sparkSession) // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence batchedAnnotations.indices.map(rowIndex => { @@ -405,7 +398,7 @@ trait ReadAlbertDLModel case ONNX.name => { val onnxWrapper = - readOnnxModel(path, spark, "_albert_onnx", zipped = true, useBundle = false, None) + readOnnxModel(path, spark, "_albert_onnx", zipped = true, useBundle = false) val spp = readSentencePieceModel(path, spark, "_albert_spp", sppFile) instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) } @@ -445,7 +438,11 @@ trait ReadAlbertDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + sparkSession = Some(spark)) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala index b123d93e83a310..3b4eb88aa779d5 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala @@ -347,7 +347,8 @@ class BertEmbeddings(override val uid: String) sentencesWithRow.map(_._1), $(batchSize), $(maxSentenceLength), - $(caseSensitive)) + $(caseSensitive), + sparkSession) // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence batchedAnnotations.indices.map(rowIndex => { @@ -433,7 +434,7 @@ trait ReadBertDLModel extends ReadTensorflowModel with ReadOnnxModel { case ONNX.name => val onnxWrapper = - readOnnxModel(path, spark, "_bert_onnx", zipped = true, useBundle = false, None) + readOnnxModel(path, spark, "_bert_onnx", zipped = true, useBundle = false) instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => @@ -473,7 +474,11 @@ trait ReadBertDLModel extends ReadTensorflowModel with ReadOnnxModel { .setModelIfNotSet(spark, Some(tfWrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + sparkSession = Some(spark)) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala index c2e36695688a38..579abc9e6f2b28 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala @@ -360,7 +360,8 @@ class BertSentenceEmbeddings(override val uid: String) sentencesWithRow.map(_._1), $(batchSize), $(maxSentenceLength), - getIsLong) + getIsLong, + sparkSession) // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence batchedAnnotations.indices.map(rowIndex => { @@ -461,8 +462,7 @@ trait ReadBertSentenceDLModel extends ReadTensorflowModel with ReadOnnxModel { spark, "_bert_sentence_onnx", zipped = true, - useBundle = false, - None) + useBundle = false) instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) } case _ => @@ -502,7 +502,11 @@ trait ReadBertSentenceDLModel extends ReadTensorflowModel with ReadOnnxModel { .setModelIfNotSet(spark, Some(tfWrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + sparkSession = Some(spark)) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala index f59d0d46c0fa41..1a253c6da1e425 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala @@ -267,7 +267,8 @@ class CamemBertEmbeddings(override val uid: String) sentencesWithRow.map(_._1), $(batchSize), $(maxSentenceLength), - $(caseSensitive)) + $(caseSensitive), + sparkSession) // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence batchedAnnotations.indices.map(rowIndex => { @@ -373,7 +374,7 @@ trait ReadCamemBertDLModel case ONNX.name => { val onnxWrapper = - readOnnxModel(path, spark, "_albert_onnx", zipped = true, useBundle = false, None) + readOnnxModel(path, spark, "_albert_onnx", zipped = true, useBundle = false) val spp = readSentencePieceModel(path, spark, "_albert_spp", sppFile) instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) } @@ -413,7 +414,11 @@ trait ReadCamemBertDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + sparkSession = Some(spark)) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala index 56f57238e3a84e..4100a0ca28f704 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala @@ -293,7 +293,8 @@ class DeBertaEmbeddings(override val uid: String) sentencesWithRow.map(_._1), $(batchSize), $(maxSentenceLength), - $(caseSensitive)) + $(caseSensitive), + sparkSession) // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence batchedAnnotations.indices.map(rowIndex => { @@ -386,7 +387,7 @@ trait ReadDeBertaDLModel case ONNX.name => { val onnxWrapper = - readOnnxModel(path, spark, "_deberta_onnx", zipped = true, useBundle = false, None) + readOnnxModel(path, spark, "_deberta_onnx", zipped = true, useBundle = false) instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) } case _ => @@ -425,7 +426,11 @@ trait ReadDeBertaDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + sparkSession = Some(spark)) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala index d28ce903c48eb0..24c1183079b5ec 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala @@ -349,7 +349,8 @@ class DistilBertEmbeddings(override val uid: String) sentencesWithRow.map(_._1), $(batchSize), $(maxSentenceLength), - $(caseSensitive)) + $(caseSensitive), + sparkSession) // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence batchedAnnotations.indices.map(rowIndex => { @@ -435,7 +436,7 @@ trait ReadDistilBertDLModel extends ReadTensorflowModel with ReadOnnxModel { case ONNX.name => { val onnxWrapper = - readOnnxModel(path, spark, "_distilbert_onnx", zipped = true, useBundle = false, None) + readOnnxModel(path, spark, "_distilbert_onnx", zipped = true, useBundle = false) instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) } case _ => @@ -475,7 +476,11 @@ trait ReadDistilBertDLModel extends ReadTensorflowModel with ReadOnnxModel { .setModelIfNotSet(spark, Some(tfWrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + sparkSession = Some(spark)) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/Doc2VecModel.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/Doc2VecModel.scala index 6524c9a3bd1e0e..f32e72ba1d3540 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/Doc2VecModel.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/Doc2VecModel.scala @@ -167,8 +167,6 @@ class Doc2VecModel(override val uid: String) /** @group setParam */ def setWordVectors(value: Map[String, Array[Float]]): this.type = set(wordVectors, value) - private var sparkSession: Option[SparkSession] = None - def getVectors: DataFrame = { val vectors: Map[String, Array[Float]] = $$(wordVectors) val rows = vectors.toSeq.map { case (key, values) => Row(key, values) } @@ -196,11 +194,6 @@ class Doc2VecModel(override val uid: String) res } - override def beforeAnnotate(dataset: Dataset[_]): Dataset[_] = { - sparkSession = Some(dataset.sparkSession) - dataset - } - /** takes a document and annotations and produces new annotations of this annotator's annotation * type * diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5Embeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5Embeddings.scala index 38ead9b55ac086..9099237e42740e 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5Embeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/E5Embeddings.scala @@ -311,7 +311,8 @@ class E5Embeddings(override val uid: String) sentences = allAnnotations.map(_._1), tokenizedSentences = tokenizedSentences, batchSize = $(batchSize), - maxSentenceLength = $(maxSentenceLength)) + maxSentenceLength = $(maxSentenceLength), + sparkSession) } else { Seq() } @@ -413,7 +414,7 @@ trait ReadE5DLModel extends ReadTensorflowModel with ReadOnnxModel { case ONNX.name => val onnxWrapper = - readOnnxModel(path, spark, "_e5_onnx", zipped = true, useBundle = false, None) + readOnnxModel(path, spark, "_e5_onnx", zipped = true, useBundle = false) instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => @@ -459,7 +460,11 @@ trait ReadE5DLModel extends ReadTensorflowModel with ReadOnnxModel { .setModelIfNotSet(spark, Some(wrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + sparkSession = Some(spark)) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala index a42c1b334d9d0b..b8f2705f3a3dc8 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala @@ -341,7 +341,8 @@ class LongformerEmbeddings(override val uid: String) sentencesWithRow.map(_._1), $(batchSize), $(maxSentenceLength), - $(caseSensitive)) + $(caseSensitive), + sparkSession) // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence batchedAnnotations.indices.map(rowIndex => { diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/MPNetEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/MPNetEmbeddings.scala index 80e29d4c15dadf..47b5650a0d0ea5 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/MPNetEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/MPNetEmbeddings.scala @@ -312,7 +312,8 @@ class MPNetEmbeddings(override val uid: String) sentences = allAnnotations.map(_._1), tokenizedSentences = tokenizedSentences, batchSize = $(batchSize), - maxSentenceLength = $(maxSentenceLength)) + maxSentenceLength = $(maxSentenceLength), + sparkSession = sparkSession) } else { Seq() } @@ -413,7 +414,7 @@ trait ReadMPNetDLModel extends ReadTensorflowModel with ReadOnnxModel { case ONNX.name => val onnxWrapper = - readOnnxModel(path, spark, "_mpnet_onnx", zipped = true, useBundle = false, None) + readOnnxModel(path, spark, "_mpnet_onnx", zipped = true, useBundle = false) instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) case _ => @@ -453,7 +454,11 @@ trait ReadMPNetDLModel extends ReadTensorflowModel with ReadOnnxModel { .setModelIfNotSet(spark, Some(wrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + sparkSession = Some(spark)) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala index 02c06bca1b4e77..b532e9f8428aa4 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala @@ -355,7 +355,8 @@ class RoBertaEmbeddings(override val uid: String) sentencesWithRow.map(_._1), $(batchSize), $(maxSentenceLength), - $(caseSensitive)) + $(caseSensitive), + sparkSession) // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence batchedAnnotations.indices.map(rowIndex => { @@ -452,7 +453,7 @@ trait ReadRobertaDLModel extends ReadTensorflowModel with ReadOnnxModel { case ONNX.name => { val onnxWrapper = - readOnnxModel(path, spark, "_roberta_onnx", zipped = true, useBundle = false, None) + readOnnxModel(path, spark, "_roberta_onnx", zipped = true, useBundle = false) instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) } case _ => @@ -500,7 +501,11 @@ trait ReadRobertaDLModel extends ReadTensorflowModel with ReadOnnxModel { .setModelIfNotSet(spark, Some(wrapper), None) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + sparkSession = Some(spark)) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/Word2VecModel.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/Word2VecModel.scala index 67a7388eef419f..61ca0442f9831a 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/Word2VecModel.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/Word2VecModel.scala @@ -168,8 +168,6 @@ class Word2VecModel(override val uid: String) /** @group setParam */ def setWordVectors(value: Map[String, Array[Float]]): this.type = set(wordVectors, value) - private var sparkSession: Option[SparkSession] = None - def getVectors: DataFrame = { val vectors: Map[String, Array[Float]] = $$(wordVectors) val rows = vectors.toSeq.map { case (key, values) => Row(key, values) } @@ -185,11 +183,6 @@ class Word2VecModel(override val uid: String) setDefault(inputCols -> Array(TOKEN), outputCol -> "word2vec", vectorSize -> 100) - override def beforeAnnotate(dataset: Dataset[_]): Dataset[_] = { - sparkSession = Some(dataset.sparkSession) - dataset - } - /** takes a document and annotations and produces new annotations of this annotator's annotation * type * diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala index 2d59b18fdb3292..c126c6bd5f55ec 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala @@ -295,7 +295,11 @@ class XlmRoBertaEmbeddings(override val uid: String) } val sentenceWordEmbeddings = - getModelIfNotSet.predict(sentencesWithRow.map(_._1), $(batchSize), $(maxSentenceLength)) + getModelIfNotSet.predict( + sentencesWithRow.map(_._1), + $(batchSize), + $(maxSentenceLength), + sparkSession) // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence batchedAnnotations.indices.map(rowIndex => { @@ -407,7 +411,7 @@ trait ReadXlmRobertaDLModel case ONNX.name => { val onnxWrapper = - readOnnxModel(path, spark, "_xlmroberta_onnx", zipped = true, useBundle = false, None) + readOnnxModel(path, spark, "_xlmroberta_onnx", zipped = true, useBundle = false) val spp = readSentencePieceModel(path, spark, "_xlmroberta_spp", sppFile) instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) } @@ -447,7 +451,11 @@ trait ReadXlmRobertaDLModel .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) case ONNX.name => - val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + val onnxWrapper = OnnxWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + sparkSession = Some(spark)) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) From 4f8e7667068cbb6f857c995d50ede2204e5d18a5 Mon Sep 17 00:00:00 2001 From: Danilo Burbano Date: Wed, 13 Sep 2023 15:55:15 -0500 Subject: [PATCH 2/3] [SPARKNLP-907] Adding ONNX config notebook example --- ...Spark_NLP_AlbertForQuestionAnswering.ipynb | 2478 +++++++++++++++++ .../johnsnowlabs/ml/onnx/OnnxWrapper.scala | 4 - .../dl/AlbertForQuestionAnswering.scala | 12 +- .../dl/AlbertForSequenceClassification.scala | 13 +- .../nlp/embeddings/AlbertEmbeddings.scala | 12 +- .../embeddings/BertSentenceEmbeddings.scala | 7 +- 6 files changed, 2510 insertions(+), 16 deletions(-) create mode 100644 examples/python/transformers/onnx/ONNX_Configs_in_Spark_NLP_AlbertForQuestionAnswering.ipynb diff --git a/examples/python/transformers/onnx/ONNX_Configs_in_Spark_NLP_AlbertForQuestionAnswering.ipynb b/examples/python/transformers/onnx/ONNX_Configs_in_Spark_NLP_AlbertForQuestionAnswering.ipynb new file mode 100644 index 00000000000000..8cbf25fb4b50d2 --- /dev/null +++ b/examples/python/transformers/onnx/ONNX_Configs_in_Spark_NLP_AlbertForQuestionAnswering.ipynb @@ -0,0 +1,2478 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "vfU3Ee88cwGj" + }, + "source": [ + "![JohnSnowLabs](https://sparknlp.org/assets/images/logo.png)\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/HuggingFace%20in%20Spark%20NLP%20-%20AlbertForQuestionAnswering.ipynb)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Setting ONNX configs in SparkNLP" + ], + "metadata": { + "id": "vTt4y2jTDAAa" + } + }, + { + "cell_type": "markdown", + "source": [ + "Starting from Spark NLP 5.1.2, you can configure ONNX-related settings within your Spark session. This allows you to fine-tune the behavior of the ONNX engine for your specific needs." + ], + "metadata": { + "id": "Aqb-WQJFDG0K" + } + }, + { + "cell_type": "markdown", + "source": [ + "Here are the available options for CPU:\n", + "\n", + "- **intraOpNumThreads**: This setting, `spark.jsl.settings.onnx.intraOpNumThreads`, controls the number of threads used for intra-operation parallelism when executing ONNX models. You can set this value to optimize the performance of ONNX execution. To understand how this affects your ONNX tasks, refer to the ONNX documentation.\n", + "\n", + "- **optimizationLevel**: Use `spark.jsl.settings.onnx.optimizationLevel` to specify the optimization level for ONNX execution. This setting influences how aggressively Spark NLP optimizes the execution of ONNX models. Explore the available options to determine which level suits your workload best in [this ONNX documentation]((https://onnxruntime.ai/docs/api/java/ai/onnxruntime/OrtSession.SessionOptions.OptLevel.html)).\n", + "\n", + "- **executionMode**: With `spark.jsl.settings.onnx.executionMode`, you can choose the execution mode for ONNX models. Different modes may offer trade-offs between performance and resource utilization. Review the available options to select the mode that aligns with your requirements in [this ONNX documentation]((https://onnxruntime.ai/docs/api/java/ai/onnxruntime/OrtSession.SessionOptions.ExecutionMode.html))\n" + ], + "metadata": { + "id": "EMIks45tQxwX" + } + }, + { + "cell_type": "markdown", + "source": [ + "Here are the available options for CUDA:\n", + "\n", + "- **gpuDeviceId**: Use `spark.jsl.settings.onnx.gpuDeviceId` to define the GPU device to execute on" + ], + "metadata": { + "id": "ITxGZFXfTLUL" + } + }, + { + "cell_type": "markdown", + "source": [ + "To find more information and detailed usage instructions for these ONNX configuration options, refer to the [ONNX API documentation](https://onnxruntime.ai/docs/api/)." + ], + "metadata": { + "id": "zrFwUW1aUDCy" + } + }, + { + "cell_type": "code", + "source": [ + "# Let's set our config based on our needs:\n", + "onnx_params = {\n", + " \"spark.jsl.settings.onnx.intraOpNumThreads\": \"5\",\n", + " \"spark.jsl.settings.onnx.optimizationLevel\": \"BASIC_OPT\",\n", + " \"spark.jsl.settings.onnx.executionMode\": \"SEQUENTIAL\"\n", + "}" + ], + "metadata": { + "id": "l2WwQcUNR-_P" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fM_4ix0mcwGm" + }, + "source": [ + "## Import AlbertForQuestionAnswering models from HuggingFace 🤗 into Spark NLP 🚀\n", + "\n", + "Let's keep in mind a few things before we start 😊\n", + "\n", + "- ONNX support was introduced in `Spark NLP 5.0.0`, enabling high performance inference for models.\n", + "- `AlbertForQuestionAnswering` is only available since in `Spark NLP 5.1.1` and after. So please make sure you have upgraded to the latest Spark NLP release\n", + "- You can import ALBERT models trained/fine-tuned for question answering via `AlbertForQuestionAnswering`. These models are usually under `Question Answering` category and have `albert` in their labels\n", + "- Reference: [TFAlbertForQuestionAnswering](https://huggingface.co/transformers/model_doc/albert#transformers.TFAlbertForQuestionAnswering)\n", + "- Some [example models](https://huggingface.co/models?filter=albert&pipeline_tag=question-answering)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EVzmVKX8cwGn" + }, + "source": [ + "## Export and Save HuggingFace model" + ] + }, + { + "cell_type": "markdown", + "source": [ + "- Let's install `transformers` package with the `onnx` extension and it's dependencies. You don't need `onnx` to be installed for Spark NLP, however, we need it to load and save models from HuggingFace.\n", + "- We lock `transformers` on version `4.29.1`. This doesn't mean it won't work with the future releases, but we wanted you to know which versions have been tested successfully.\n", + "- Albert uses SentencePiece, so we will have to install that as well" + ], + "metadata": { + "id": "WDSalCHsd9-z" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install -q --upgrade transformers[onnx]==4.29.1 optimum sentencepiece tensorflow" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "qSx09sNyegma", + "outputId": "d77a037d-7ff5-4397-f33a-83c7f3517ab7" + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.1/7.1 MB\u001b[0m \u001b[31m42.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m396.1/396.1 kB\u001b[0m \u001b[31m26.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m61.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m294.8/294.8 kB\u001b[0m \u001b[31m23.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m85.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.5/84.5 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m454.7/454.7 kB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.9/5.9 MB\u001b[0m \u001b[31m48.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.7/212.7 kB\u001b[0m \u001b[31m12.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m3.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m519.6/519.6 kB\u001b[0m \u001b[31m31.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.6/14.6 MB\u001b[0m \u001b[31m51.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m55.5/55.5 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.6/5.6 MB\u001b[0m \u001b[31m51.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.9/5.9 MB\u001b[0m \u001b[31m67.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.0/5.0 MB\u001b[0m \u001b[31m88.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.0/5.0 MB\u001b[0m \u001b[31m79.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.5/4.5 MB\u001b[0m \u001b[31m90.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m89.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m85.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m524.1/524.1 MB\u001b[0m \u001b[31m3.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m585.9/585.9 MB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m62.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.6/5.6 MB\u001b[0m \u001b[31m76.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m440.7/440.7 kB\u001b[0m \u001b[31m29.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.6/5.6 MB\u001b[0m \u001b[31m93.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m781.3/781.3 kB\u001b[0m \u001b[31m48.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.6/5.6 MB\u001b[0m \u001b[31m85.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.6/5.6 MB\u001b[0m \u001b[31m88.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m585.9/585.9 MB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m588.3/588.3 MB\u001b[0m \u001b[31m401.5 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m66.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m54.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.0/6.0 MB\u001b[0m \u001b[31m77.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m439.2/439.2 kB\u001b[0m \u001b[31m27.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m49.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m10.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m17.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m11.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m83.8/83.8 kB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m454.7/454.7 kB\u001b[0m \u001b[31m34.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m451.2/451.2 kB\u001b[0m \u001b[31m31.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.6/14.6 MB\u001b[0m \u001b[31m49.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.5/13.5 MB\u001b[0m \u001b[31m62.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.5/13.5 MB\u001b[0m \u001b[31m59.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.1/13.1 MB\u001b[0m \u001b[31m38.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow-datasets 4.9.2 requires protobuf>=3.20, but you have protobuf 3.19.6 which is incompatible.\n", + "tensorflow-metadata 1.14.0 requires protobuf<4.21,>=3.20.3, but you have protobuf 3.19.6 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "- HuggingFace has an extension called Optimum which offers specialized model inference, including ONNX. We can use this to import and export ONNX models with `from_pretrained` and `save_pretrained`.\n", + "- We'll use [twmkn9/albert-base-v2-squad2](https://huggingface.co/twmkn9/albert-base-v2-squad2) model from HuggingFace as an example and load it as a `ORTModelForQuestionAnswering`, representing an ONNX model." + ], + "metadata": { + "id": "uFkFe1YUewJR" + } + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "FtWcH9nycwGq", + "outputId": "325fd2f2-06e5-4e5a-86cd-dd1a6fcd18d6", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 316, + "referenced_widgets": [ + "1213b22c364a4abbabd43abb5cc1a26b", + "97fd422145944317af1bc9a89e6ce3fd", + "7ff72165cefd4c0aa48e964e41ebb3f9", + "69b18e5217224db5a3c591135c3bebf2", + "0a489803721a492e8effa522b3483189", + "09b299ec34d64a548d5368061883988f", + "69b7811c8a404e00ab82a762a956f3af", + "2a1b18093a4c4861bc68e86fea6e3104", + "28d4c36ad4b84f2dad414fada2099482", + "71b1345cd1264472a7f2daf357740e91", + "0dfcd8ca4b4a4e7f9f522e4e3098ac3d", + "7f97935275124bd9a785fb9cb423bcba", + "80c9db42e0554825b82db8368f538661", + "3a266762324443199cada1bda11b1295", + "8c5aaebf6dce4a1cbc9a0260fa90485e", + "4d72350e63104c048da7151ec68ac759", + "c6cceba13f394ba79754faae33913cfa", + "88131bbfe2094ef49654809904132f40", + "a0a593bb283948929f4f2300c2030b04", + "e963319a0bd24f1d83bc4df76a3c72d9", + "8626f23bd667426d89ddfe7d16372e32", + "29709077455147488756367c6d8a848a", + "8cecd4db22724d8b8135c40b89eb49b0", + "6d49336bfd034bf68311ed5883243c31", + "b93899282aee4c6f85e38069fa3e8b0d", + "daecf3a743ca412891db6b4d722d99b7", + "e03a26efb4474238aaf8f147d0b5fb2f", + "399fdb554ed64bf582429c54d768cae7", + "a440b40c1cfa46009c7c33394f50eb9e", + "b42a9cf3d81049ed8ad8829148078fe5", + "f78287e946eb4f909ac872426a800b65", + "6785d72ae079455eaba2f26e8ca9b925", + "4461bb6a0b09419fbdb2942267229722", + "635a35d690f940ffb510304f9d007fa0", + "e3fc3abac4d94679ae18872b3b3ae44a", + "a03f9e50021b49e3b815cfd07b8ceddb", + "0d7bace449b9472291a6d78fabb90149", + "686f3cd8af174fd4980bf3d004fc8867", + "c73d03f92f784ce8bfd6c17954bd5454", + "ac298eb697724b0fa81a2abb97a257e4", + "30200864fcf44d649b110202c8612be7", + "158bd727b1624c5c86a907fd6df23ed6", + "09b5c65e66cb4fbcab6ca7c3947ce914", + "3f297d24ad444543a3811073900f44b5", + "622d90eede064ac794d6888bfbb77365", + "fe975110389147359739c21ac9f89ac2", + "65bab2e559a5455cadff53691b83a700", + "ff0611ae5afc454e9677dab4ffb31d36", + "74105c9458364362a3862dd52a292769", + "7870ff16393641b7993fef2c847e2586", + "21716ea71935456d8f190302047edbfb", + "68d782ece6f74b3ea5191c5f92d76a07", + "fbd136a8853b493486e0896344536070", + "91bb579ba2ad49648ab8464c084dfb4e", + "04ec729f43224ff5951d5f3d3dac6c73" + ] + } + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading (…)lve/main/config.json: 0%| | 0.00/716 [00:00 - println(s"config: $key, value: $value") - } - val session = env.createSession(onnxModel, sessionOptions) (session, env) } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala index 1b824370f366a8..2ff650360bc0dc 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala @@ -19,8 +19,16 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl import com.johnsnowlabs.ml.ai.{AlbertClassification, MergeTokenStrategy} import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ -import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ReadSentencePieceModel, SentencePieceWrapper, WriteSentencePieceModel} -import com.johnsnowlabs.ml.util.LoadExternalModel.{loadSentencePieceAsset, modelSanityCheck, notSupportedEngineError} +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ + ReadSentencePieceModel, + SentencePieceWrapper, + WriteSentencePieceModel +} +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadSentencePieceAsset, + modelSanityCheck, + notSupportedEngineError +} import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala index 616d594b02794b..74207c15e3f477 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala @@ -19,8 +19,17 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl import com.johnsnowlabs.ml.ai.AlbertClassification import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ -import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ReadSentencePieceModel, SentencePieceWrapper, WriteSentencePieceModel} -import com.johnsnowlabs.ml.util.LoadExternalModel.{loadSentencePieceAsset, loadTextAsset, modelSanityCheck, notSupportedEngineError} +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ + ReadSentencePieceModel, + SentencePieceWrapper, + WriteSentencePieceModel +} +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadSentencePieceAsset, + loadTextAsset, + modelSanityCheck, + notSupportedEngineError +} import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala index 057cc4efea874c..421dac586968d7 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala @@ -19,8 +19,16 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.ml.ai.Albert import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ -import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ReadSentencePieceModel, SentencePieceWrapper, WriteSentencePieceModel} -import com.johnsnowlabs.ml.util.LoadExternalModel.{loadSentencePieceAsset, modelSanityCheck, notSupportedEngineError} +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ + ReadSentencePieceModel, + SentencePieceWrapper, + WriteSentencePieceModel +} +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadSentencePieceAsset, + modelSanityCheck, + notSupportedEngineError +} import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala index 579abc9e6f2b28..512a04219d4043 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala @@ -457,12 +457,7 @@ trait ReadBertSentenceDLModel extends ReadTensorflowModel with ReadOnnxModel { case ONNX.name => { val onnxWrapper = - readOnnxModel( - path, - spark, - "_bert_sentence_onnx", - zipped = true, - useBundle = false) + readOnnxModel(path, spark, "_bert_sentence_onnx", zipped = true, useBundle = false) instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) } case _ => From f2690d88340c0c2976b367a811ecfd44703a4ce8 Mon Sep 17 00:00:00 2001 From: Danilo Burbano Date: Tue, 26 Sep 2023 12:41:44 -0500 Subject: [PATCH 3/3] SPARKNLP-907 Adding control when spark config is null --- src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala index cfdf575913d7c1..ac9717abebc562 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala @@ -200,7 +200,7 @@ object OnnxWrapper { var optimizationLevel = defaultOptLevel var executionMode = defaultExecutionMode - if (sparkSession.isDefined) { + if (sparkSession.isDefined && sparkSession.get.conf != null) { intraOpNumThreads = sparkSession.get.conf .get("spark.jsl.settings.onnx.intraOpNumThreads", defaultIntraOpNumThreads.toString) .toInt