SPARKNLP-907 Allows setting up ONNX configs through spark session #13979

Closed
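
This pull request threads an Option[SparkSession] from the ONNX-backed annotators down to OnnxWrapper.getSession, so ONNX Runtime session options can be resolved from the Spark configuration instead of being fixed at model load time. The sketch below shows the intended usage pattern; the config key names are illustrative assumptions, since the wrapper side is not part of this diff:

// A minimal sketch, assuming getSession reads keys like these from the Spark
// conf; the exact key names are not shown in this diff.
import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .appName("spark-nlp-onnx-config")
  .master("local[*]")
  // Hypothetical ONNX session settings picked up through the Spark session:
  .config("spark.jsl.settings.onnx.intraOpNumThreads", "6")
  .config("spark.jsl.settings.onnx.optimizationLevel", "ALL_OPT")
  .config("spark.jsl.settings.onnx.executionMode", "SEQUENTIAL")
  .getOrCreate()

When no session is available, for example during the sessionWarmup() calls below, which pass None, the wrapper can fall back to default session options.
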
14 changes: 9 additions & 5 deletions src/main/scala/com/johnsnowlabs/ml/ai/Albert.scala
@@ -24,6 +24,7 @@ import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
 import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
 import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
 import com.johnsnowlabs.nlp.annotators.common._
+import org.apache.spark.sql.SparkSession

 import scala.collection.JavaConverters._

@@ -93,12 +94,14 @@ private[johnsnowlabs] class Albert(
   private def sessionWarmup(): Unit = {
     val dummyInput =
       Array(101, 2292, 1005, 1055, 4010, 6279, 1996, 5219, 2005, 1996, 2034, 28937, 1012, 102)
-    tag(Seq(dummyInput))
+    tag(Seq(dummyInput), None)
   }

   sessionWarmup()

-  def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
+  def tag(
+      batch: Seq[Array[Int]],
+      sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = {

     val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
     val batchLength = batch.length
@@ -107,7 +110,7 @@

       case ONNX.name =>
         // [nb of encoded sentences , maxSentenceLength]
-        val (runner, env) = onnxWrapper.get.getSession()
+        val (runner, env) = onnxWrapper.get.getSession(sparkSession)

         val tokenTensors =
           OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
@@ -204,7 +207,8 @@ private[johnsnowlabs] class Albert(
       tokenizedSentences: Seq[TokenizedSentence],
       batchSize: Int,
       maxSentenceLength: Int,
-      caseSensitive: Boolean): Seq[WordpieceEmbeddingsSentence] = {
+      caseSensitive: Boolean,
+      sparkSession: Option[SparkSession]): Seq[WordpieceEmbeddingsSentence] = {

     val wordPieceTokenizedSentences =
       tokenizeWithAlignment(tokenizedSentences, maxSentenceLength, caseSensitive)
@@ -217,7 +221,7 @@
           SentenceStartTokenId,
           SentenceEndTokenId,
           SentencePadTokenId)
-        val vectors = tag(encoded)
+        val vectors = tag(encoded, sparkSession)

         /*Combine tokens and calculated embeddings*/
         batch.zip(vectors).map { case (sentence, tokenVectors) =>
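
The Albert changes call onnxWrapper.get.getSession(sparkSession), but the wrapper itself is outside this diff. Below is a hypothetical sketch of what such an overload could look like, written against the public ONNX Runtime Java API; the config key is again an assumption:

// Sketch only: the real OnnxWrapper in com.johnsnowlabs.ml.onnx is not shown here.
import ai.onnxruntime.{OrtEnvironment, OrtSession}
import org.apache.spark.sql.SparkSession

class OnnxWrapperSketch(modelBytes: Array[Byte]) {

  def getSession(sparkSession: Option[SparkSession]): (OrtSession, OrtEnvironment) = {
    val env = OrtEnvironment.getEnvironment()
    val opts = new OrtSession.SessionOptions()
    // If a Spark session was handed down, honor its conf; otherwise keep defaults.
    sparkSession
      .flatMap(_.conf.getOption("spark.jsl.settings.onnx.intraOpNumThreads"))
      .foreach(threads => opts.setIntraOpNumThreads(threads.toInt))
    (env.createSession(modelBytes, opts), env)
  }
}

This matches the (runner, env) tuple destructured at each ONNX call site above.
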
33 changes: 22 additions & 11 deletions src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala
@@ -21,10 +21,10 @@ import com.johnsnowlabs.ml.onnx.OnnxWrapper
 import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder}
 import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
 import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
-import com.johnsnowlabs.ml.util.LoadExternalModel.notSupportedEngineError
 import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
 import com.johnsnowlabs.nlp.annotators.common._
 import com.johnsnowlabs.nlp.{ActivationFunction, Annotation}
+import org.apache.spark.sql.SparkSession
 import org.tensorflow.ndarray.buffer.IntDataBuffer

 import scala.collection.JavaConverters._
@@ -103,12 +103,15 @@ private[johnsnowlabs] class AlbertClassification(
     sentenceTokenPieces
   }

-  def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
+  def tag(
+      batch: Seq[Array[Int]],
+      sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = {
     val batchLength = batch.length
     val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max

     val rawScores = detectedEngine match {
-      case ONNX.name => getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true)
+      case ONNX.name =>
+        getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true, sparkSession)
       case _ => getRawScoresWithTF(batch, maxSentenceLength)
     }

@@ -123,12 +126,16 @@
     batchScores
   }

-  def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
+  def tagSequence(
+      batch: Seq[Array[Int]],
+      activation: String,
+      sparkSession: Option[SparkSession]): Array[Array[Float]] = {
     val batchLength = batch.length
     val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max

     val rawScores = detectedEngine match {
-      case ONNX.name => getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true)
+      case ONNX.name =>
+        getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true, sparkSession)
       case _ => getRawScoresWithTF(batch, maxSentenceLength)
     }

@@ -206,12 +213,13 @@ private[johnsnowlabs] class AlbertClassification(
   private def getRowScoresWithOnnx(
       batch: Seq[Array[Int]],
       maxSentenceLength: Int,
-      sequence: Boolean): Array[Float] = {
+      sequence: Boolean,
+      sparkSession: Option[SparkSession]): Array[Float] = {

     val output = if (sequence) "logits" else "last_hidden_state"

     // [nb of encoded sentences , maxSentenceLength]
-    val (runner, env) = onnxWrapper.get.getSession()
+    val (runner, env) = onnxWrapper.get.getSession(sparkSession)

     val tokenTensors =
       OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
@@ -253,11 +261,13 @@
       contradictionId: Int,
       activation: String): Array[Array[Float]] = ???

-  def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = {
+  def tagSpan(
+      batch: Seq[Array[Int]],
+      sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]]) = {
     val batchLength = batch.length
     val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
     val (startLogits, endLogits) = detectedEngine match {
-      case ONNX.name => computeLogitsWithOnnx(batch, maxSentenceLength)
+      case ONNX.name => computeLogitsWithOnnx(batch, maxSentenceLength, sparkSession)
       case _ => computeLogitsWithTF(batch, maxSentenceLength)
     }

@@ -347,9 +357,10 @@ private[johnsnowlabs] class AlbertClassification(

   private def computeLogitsWithOnnx(
       batch: Seq[Array[Int]],
-      maxSentenceLength: Int): (Array[Float], Array[Float]) = {
+      maxSentenceLength: Int,
+      sparkSession: Option[SparkSession]): (Array[Float], Array[Float]) = {
     // [nb of encoded sentences , maxSentenceLength]
-    val (runner, env) = onnxWrapper.get.getSession()
+    val (runner, env) = onnxWrapper.get.getSession(sparkSession)

     val tokenTensors =
       OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
14 changes: 11 additions & 3 deletions src/main/scala/com/johnsnowlabs/ml/ai/BartClassification.scala
@@ -22,6 +22,7 @@ import com.johnsnowlabs.nlp.annotators.common._
 import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.BpeTokenizer
 import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder}
 import com.johnsnowlabs.nlp.{ActivationFunction, Annotation}
+import org.apache.spark.sql.SparkSession
 import org.tensorflow.ndarray.buffer.IntDataBuffer

 import scala.collection.JavaConverters._
@@ -127,7 +128,9 @@ private[johnsnowlabs] class BartClassification(

   }

-  def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
+  def tag(
+      batch: Seq[Array[Int]],
+      sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = {
     val tensors = new TensorResources()

     val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
@@ -185,7 +188,10 @@
     batchScores
   }

-  def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
+  def tagSequence(
+      batch: Seq[Array[Int]],
+      activation: String,
+      sparkSession: Option[SparkSession]): Array[Array[Float]] = {
     val tensors = new TensorResources()

     val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
@@ -317,7 +323,9 @@
       .toArray
   }

-  def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = {
+  def tagSpan(
+      batch: Seq[Array[Int]],
+      sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]]) = {
     val tensors = new TensorResources()

     val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
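
Note that BartClassification only runs on TensorFlow in the hunks above, so the new parameter is unused in these method bodies; it is presumably added to keep one signature across the ONNX- and TensorFlow-backed implementations. A presumed shape of the shared interface after this change (the actual trait in com.johnsnowlabs.ml.ai is not shown in this diff):

// Sketch of the common classification signatures all of these classes now share.
import org.apache.spark.sql.SparkSession

trait ClassificationSketch {

  def tag(batch: Seq[Array[Int]], sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]]

  def tagSequence(
      batch: Seq[Array[Int]],
      activation: String,
      sparkSession: Option[SparkSession]): Array[Array[Float]]

  def tagSpan(
      batch: Seq[Array[Int]],
      sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]])
}
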
27 changes: 17 additions & 10 deletions src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala
@@ -24,6 +24,7 @@ import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
 import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow}
 import com.johnsnowlabs.nlp.annotators.common._
 import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
+import org.apache.spark.sql.SparkSession

 import scala.collection.JavaConverters._

@@ -71,26 +72,28 @@ private[johnsnowlabs] class Bert(
     val dummyInput =
       Array(101, 2292, 1005, 1055, 4010, 6279, 1996, 5219, 2005, 1996, 2034, 28937, 1012, 102)
     if (modelArch == ModelArch.wordEmbeddings) {
-      tag(Seq(dummyInput))
+      tag(Seq(dummyInput), None)
     } else if (modelArch == ModelArch.sentenceEmbeddings) {
       if (isSBert)
         tagSequenceSBert(Seq(dummyInput))
       else
-        tagSequence(Seq(dummyInput))
+        tagSequence(Seq(dummyInput), None)
     }
   }

   sessionWarmup()

-  def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
+  def tag(
+      batch: Seq[Array[Int]],
+      sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = {
     val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
     val batchLength = batch.length

     val embeddings = detectedEngine match {

       case ONNX.name =>
         // [nb of encoded sentences , maxSentenceLength]
-        val (runner, env) = onnxWrapper.get.getSession()
+        val (runner, env) = onnxWrapper.get.getSession(sparkSession)

         val tokenTensors =
           OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
@@ -185,15 +188,17 @@

   }

-  def tagSequence(batch: Seq[Array[Int]]): Array[Array[Float]] = {
+  def tagSequence(
+      batch: Seq[Array[Int]],
+      sparkSession: Option[SparkSession]): Array[Array[Float]] = {

     val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
     val batchLength = batch.length

     val embeddings = detectedEngine match {
       case ONNX.name =>
         // [nb of encoded sentences , maxSentenceLength]
-        val (runner, env) = onnxWrapper.get.getSession()
+        val (runner, env) = onnxWrapper.get.getSession(sparkSession)

         val tokenTensors =
           OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
@@ -346,7 +351,8 @@ private[johnsnowlabs] class Bert(
       originalTokenSentences: Seq[TokenizedSentence],
       batchSize: Int,
       maxSentenceLength: Int,
-      caseSensitive: Boolean): Seq[WordpieceEmbeddingsSentence] = {
+      caseSensitive: Boolean,
+      sparkSession: Option[SparkSession]): Seq[WordpieceEmbeddingsSentence] = {

     /*Run embeddings calculation by batches*/
     sentences.zipWithIndex
@@ -357,7 +363,7 @@
           maxSentenceLength,
           sentenceStartTokenId,
           sentenceEndTokenId)
-        val vectors = tag(encoded)
+        val vectors = tag(encoded, sparkSession)

         /*Combine tokens and calculated embeddings*/
         batch.zip(vectors).map { case (sentence, tokenVectors) =>
@@ -408,7 +414,8 @@
       sentences: Seq[Sentence],
       batchSize: Int,
       maxSentenceLength: Int,
-      isLong: Boolean = false): Seq[Annotation] = {
+      isLong: Boolean = false,
+      sparkSession: Option[SparkSession]): Seq[Annotation] = {

     /*Run embeddings calculation by batches*/
     tokens
@@ -426,7 +433,7 @@
       val embeddings = if (isLong) {
         tagSequenceSBert(encoded)
       } else {
-        tagSequence(encoded)
+        tagSequence(encoded, sparkSession)
       }

       sentencesBatch.zip(embeddings).map { case (sentence, vectors) =>
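
On the caller side, a natural way to obtain the optional session is SparkSession.getActiveSession, which returns None on executors or before any session exists. A hypothetical sketch (not part of this diff) of threading it into the tag method shown above:

// Assumes a constructed Bert instance and token-id batches like the warmup input.
// Bert is private[johnsnowlabs], so this helper would live inside that package.
import org.apache.spark.sql.SparkSession

def tagWithActiveSession(bert: Bert, batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] =
  // Falls back to the wrapper's default ONNX options when no session is active.
  bert.tag(batch, SparkSession.getActiveSession)
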
14 changes: 11 additions & 3 deletions src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala
@@ -21,6 +21,7 @@ import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
 import com.johnsnowlabs.nlp.annotators.common._
 import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder}
 import com.johnsnowlabs.nlp.{ActivationFunction, Annotation}
+import org.apache.spark.sql.SparkSession
 import org.tensorflow.ndarray.buffer.IntDataBuffer

 import scala.collection.JavaConverters._
@@ -134,7 +135,9 @@ private[johnsnowlabs] class BertClassification(
     }
   }

-  def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
+  def tag(
+      batch: Seq[Array[Int]],
+      sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = {
     val tensors = new TensorResources()

     val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
@@ -198,7 +201,10 @@
     batchScores
   }

-  def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
+  def tagSequence(
+      batch: Seq[Array[Int]],
+      activation: String,
+      sparkSession: Option[SparkSession]): Array[Array[Float]] = {
     val tensors = new TensorResources()

     val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
@@ -338,7 +344,9 @@
       .toArray
   }

-  def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = {
+  def tagSpan(
+      batch: Seq[Array[Int]],
+      sparkSession: Option[SparkSession]): (Array[Array[Float]], Array[Array[Float]]) = {
     val tensors = new TensorResources()

     val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
14 changes: 9 additions & 5 deletions src/main/scala/com/johnsnowlabs/ml/ai/CamemBert.scala
@@ -24,6 +24,7 @@ import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
 import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
 import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
 import com.johnsnowlabs.nlp.annotators.common._
+import org.apache.spark.sql.SparkSession

 import scala.collection.JavaConverters._

@@ -68,12 +69,14 @@ private[johnsnowlabs] class CamemBert(
   private def sessionWarmup(): Unit = {
     val dummyInput =
       Array(5, 54, 110, 11, 10, 15540, 215, 1280, 808, 25352, 1782, 808, 24696, 378, 17409, 9, 6)
-    tag(Seq(dummyInput))
+    tag(Seq(dummyInput), None)
   }

   sessionWarmup()

-  def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
+  def tag(
+      batch: Seq[Array[Int]],
+      sparkSession: Option[SparkSession]): Seq[Array[Array[Float]]] = {

     val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
     val batchLength = batch.length
@@ -82,7 +85,7 @@

       case ONNX.name =>
         // [nb of encoded sentences , maxSentenceLength]
-        val (runner, env) = onnxWrapper.get.getSession()
+        val (runner, env) = onnxWrapper.get.getSession(sparkSession)

         val tokenTensors =
           OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
@@ -167,7 +170,8 @@
       tokenizedSentences: Seq[TokenizedSentence],
       batchSize: Int,
       maxSentenceLength: Int,
-      caseSensitive: Boolean): Seq[WordpieceEmbeddingsSentence] = {
+      caseSensitive: Boolean,
+      sparkSession: Option[SparkSession]): Seq[WordpieceEmbeddingsSentence] = {

     val wordPieceTokenizedSentences =
       tokenizeWithAlignment(tokenizedSentences, maxSentenceLength, caseSensitive)
@@ -180,7 +184,7 @@
           SentenceStartTokenId,
           SentenceEndTokenId,
           SentencePadTokenId)
-        val vectors = tag(encoded)
+        val vectors = tag(encoded, sparkSession)

         /*Combine tokens and calculated embeddings*/
         batch.zip(vectors).map { case (sentence, tokenVectors) =>
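
From the pipeline API, nothing changes for end users: the ONNX session simply inherits whatever was set on the SparkSession the pipeline runs under. A usage sketch, reusing the spark value built in the first example (the pretrained() call downloads a model, so it needs network access):

// The annotator API is untouched by this PR; only the session plumbing is new.
import com.johnsnowlabs.nlp.DocumentAssembler
import com.johnsnowlabs.nlp.annotators.Tokenizer
import com.johnsnowlabs.nlp.embeddings.AlbertEmbeddings
import org.apache.spark.ml.Pipeline

val document = new DocumentAssembler().setInputCol("text").setOutputCol("document")
val tokenizer = new Tokenizer().setInputCols(Array("document")).setOutputCol("token")
val embeddings = AlbertEmbeddings
  .pretrained() // an ONNX-exported model exercises the getSession(sparkSession) path
  .setInputCols(Array("document", "token"))
  .setOutputCol("embeddings")

val data = spark.createDataFrame(Seq(Tuple1("Let's warm up the engine."))).toDF("text")
new Pipeline().setStages(Array(document, tokenizer, embeddings)).fit(data).transform(data).show()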