Skip to content

Commit

Permalink
Adding ONNX Support to ALBERT Token and Sequence Classification and Q…
Browse files Browse the repository at this point in the history
…uestion Answering annotators (#13956)

* SPARKNLP-891 Adding ONNX support for AlbertQuestionAnswering
SPARKNLP-892 Adding ONNX support for AlbertSequenceClassification
SPARKNLP-893 Adding ONNX support for AlbertTokenClassification

* SPARKNLP-891 Adding ONNX support for AlbertQuestionAnswering
SPARKNLP-892 Adding ONNX support for AlbertSequenceClassification
SPARKNLP-893 Adding ONNX support for AlbertTokenClassification
  • Loading branch information
danilojsl authored Sep 7, 2023
1 parent 96094c3 commit fae2344
Show file tree
Hide file tree
Showing 6 changed files with 338 additions and 124 deletions.
234 changes: 155 additions & 79 deletions src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder}
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.LoadExternalModel.notSupportedEngineError
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{ActivationFunction, Annotation}
import org.tensorflow.ndarray.buffer.IntDataBuffer
Expand All @@ -37,7 +41,8 @@ import scala.collection.JavaConverters._
* TF v2 signatures in Spark NLP
*/
private[johnsnowlabs] class AlbertClassification(
val tensorflowWrapper: TensorflowWrapper,
val tensorflowWrapper: Option[TensorflowWrapper],
val onnxWrapper: Option[OnnxWrapper],
val spp: SentencePieceWrapper,
configProtoBytes: Option[Array[Byte]] = None,
tags: Map[String, Int],
Expand All @@ -48,6 +53,10 @@ private[johnsnowlabs] class AlbertClassification(

val _tfAlbertSignatures: Map[String, String] =
signatures.getOrElse(ModelSignatureManager.apply())
val detectedEngine: String =
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else TensorFlow.name

// keys representing the input and output tensors of the ALBERT model
protected val sentencePadTokenId: Int = spp.getSppModel.pieceToId("[pad]")
Expand Down Expand Up @@ -95,59 +104,13 @@ private[johnsnowlabs] class AlbertClassification(
}

def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val batchLength = batch.length
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val segmentBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)

// [nb of encoded sentences , maxSentenceLength]
val shape = Array(batch.length.toLong, maxSentenceLength)

batch.zipWithIndex
.foreach { case (sentence, idx) =>
val offset = idx * maxSentenceLength
tokenBuffers.offset(offset).write(sentence)
maskBuffers
.offset(offset)
.write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0))
}

val runner = tensorflowWrapper
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

val tokenTensors = tensors.createIntBufferTensor(shape, tokenBuffers)
val maskTensors = tensors.createIntBufferTensor(shape, maskBuffers)
val segmentTensors = tensors.createIntBufferTensor(shape, segmentBuffers)

runner
.feed(
_tfAlbertSignatures.getOrElse(
ModelSignatureConstants.InputIds.key,
"missing_input_id_key"),
tokenTensors)
.feed(
_tfAlbertSignatures
.getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"),
maskTensors)
.feed(
_tfAlbertSignatures
.getOrElse(ModelSignatureConstants.TokenTypeIds.key, "missing_segment_ids_key"),
segmentTensors)
.fetch(_tfAlbertSignatures
.getOrElse(ModelSignatureConstants.LogitsOutput.key, "missing_logits_key"))

val outs = runner.run().asScala
val rawScores = TensorResources.extractFloats(outs.head)

outs.foreach(_.close())
tensors.clearSession(outs)
tensors.clearTensors()
val rawScores = detectedEngine match {
case ONNX.name => getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true)
case _ => getRawScoresWithTF(batch, maxSentenceLength)
}

val dim = rawScores.length / (batchLength * maxSentenceLength)
val batchScores: Array[Array[Array[Float]]] = rawScores
Expand All @@ -161,17 +124,39 @@ private[johnsnowlabs] class AlbertClassification(
}

def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
val batchLength = batch.length
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max

val rawScores = detectedEngine match {
case ONNX.name => getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true)
case _ => getRawScoresWithTF(batch, maxSentenceLength)
}

val dim = rawScores.length / batchLength
val batchScores: Array[Array[Float]] =
rawScores
.grouped(dim)
.map(scores =>
activation match {
case ActivationFunction.softmax => calculateSoftmax(scores)
case ActivationFunction.sigmoid => calculateSigmoid(scores)
case _ => calculateSoftmax(scores)
})
.toArray

batchScores
}

private def getRawScoresWithTF(batch: Seq[Array[Int]], maxSentenceLength: Int): Array[Float] = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val batchLength = batch.length

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val segmentBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)

// [nb of encoded sentences , maxSentenceLength]
val shape = Array(batch.length.toLong, maxSentenceLength)
val shape = Array(batchLength.toLong, maxSentenceLength)

batch.zipWithIndex
.foreach { case (sentence, idx) =>
Expand All @@ -183,7 +168,7 @@ private[johnsnowlabs] class AlbertClassification(
segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0))
}

val runner = tensorflowWrapper
val runner = tensorflowWrapper.get
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

Expand Down Expand Up @@ -215,19 +200,51 @@ private[johnsnowlabs] class AlbertClassification(
tensors.clearSession(outs)
tensors.clearTensors()

val dim = rawScores.length / batchLength
val batchScores: Array[Array[Float]] =
rawScores
.grouped(dim)
.map(scores =>
activation match {
case ActivationFunction.softmax => calculateSoftmax(scores)
case ActivationFunction.sigmoid => calculateSigmoid(scores)
case _ => calculateSoftmax(scores)
})
.toArray
rawScores
}

batchScores
private def getRowScoresWithOnnx(
batch: Seq[Array[Int]],
maxSentenceLength: Int,
sequence: Boolean): Array[Float] = {

val output = if (sequence) "logits" else "last_hidden_state"

// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession()

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val segmentTensors =
OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray)

val inputs =
Map(
"input_ids" -> tokenTensors,
"attention_mask" -> maskTensors,
"token_type_ids" -> segmentTensors).asJava

try {
val results = runner.run(inputs)
try {
val embeddings = results
.get(output)
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()
segmentTensors.close()

embeddings
} finally if (results != null) results.close()
}
}

def tagZeroShotSequence(
Expand All @@ -237,10 +254,29 @@ private[johnsnowlabs] class AlbertClassification(
activation: String): Array[Array[Float]] = ???

def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = {
val tensors = new TensorResources()

val batchLength = batch.length
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val (startLogits, endLogits) = detectedEngine match {
case ONNX.name => computeLogitsWithOnnx(batch, maxSentenceLength)
case _ => computeLogitsWithTF(batch, maxSentenceLength)
}

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] =
endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray

val startDim = startLogits.length / batchLength
val startScores: Array[Array[Float]] =
startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray

(startScores, endScores)
}

private def computeLogitsWithTF(
batch: Seq[Array[Int]],
maxSentenceLength: Int): (Array[Float], Array[Float]) = {
val batchLength = batch.length
val tensors = new TensorResources()

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
Expand Down Expand Up @@ -271,7 +307,7 @@ private[johnsnowlabs] class AlbertClassification(
})
}

val runner = tensorflowWrapper
val runner = tensorflowWrapper.get
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

Expand Down Expand Up @@ -306,15 +342,55 @@ private[johnsnowlabs] class AlbertClassification(
tensors.clearSession(outs)
tensors.clearTensors()

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] =
endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray

val startDim = startLogits.length / batchLength
val startScores: Array[Array[Float]] =
startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray
(endLogits, startLogits)
}

(startScores, endScores)
private def computeLogitsWithOnnx(
batch: Seq[Array[Int]],
maxSentenceLength: Int): (Array[Float], Array[Float]) = {
// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession()

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val segmentTensors =
OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray)

val inputs =
Map(
"input_ids" -> tokenTensors,
"attention_mask" -> maskTensors,
"token_type_ids" -> segmentTensors).asJava

try {
val output = runner.run(inputs)
try {
val startLogits = output
.get("start_logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()

val endLogits = output
.get("end_logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()

tokenTensors.close()
maskTensors.close()
segmentTensors.close()

(startLogits.slice(1, startLogits.length), endLogits.slice(1, endLogits.length))
} finally if (output != null) output.close()
}
}

def findIndexedToken(
Expand Down
10 changes: 8 additions & 2 deletions src/main/scala/com/johnsnowlabs/ml/ai/XXXForClassification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.johnsnowlabs.ml.ai

import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{ActivationFunction, Annotation, AnnotatorType}

Expand Down Expand Up @@ -244,7 +245,8 @@ private[johnsnowlabs] trait XXXForClassification {
documents: Seq[Annotation],
maxSentenceLength: Int,
caseSensitive: Boolean,
mergeTokenStrategy: String = MergeTokenStrategy.vocab): Seq[Annotation] = {
mergeTokenStrategy: String = MergeTokenStrategy.vocab,
engine: String = TensorFlow.name): Seq[Annotation] = {

val questionAnnot = Seq(documents.head)
val contextAnnot = documents.drop(1)
Expand All @@ -264,9 +266,13 @@ private[johnsnowlabs] trait XXXForClassification {
val startIndex = startScores.zipWithIndex.maxBy(_._1)
val endIndex = endScores.zipWithIndex.maxBy(_._1)

val offsetStartIndex = if (engine == TensorFlow.name) 2 else 1
val offsetEndIndex = if (engine == TensorFlow.name) 1 else 0

val allTokenPieces =
wordPieceTokenizedQuestion.head.tokens ++ wordPieceTokenizedContext.flatMap(x => x.tokens)
val decodedAnswer = allTokenPieces.slice(startIndex._2 - 2, endIndex._2 - 1)
val decodedAnswer =
allTokenPieces.slice(startIndex._2 - offsetStartIndex, endIndex._2 - offsetEndIndex)
val content =
mergeTokenStrategy match {
case MergeTokenStrategy.vocab =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ private[johnsnowlabs] class ZeroShotNerClassification(
documents: Seq[Annotation],
maxSentenceLength: Int,
caseSensitive: Boolean,
mergeTokenStrategy: String): Seq[Annotation] = {
mergeTokenStrategy: String,
engine: String): Seq[Annotation] = {
val questionAnnot = Seq(documents.head)
val contextAnnot = documents.drop(1)

Expand Down
Loading

0 comments on commit fae2344

Please sign in to comment.