Sparknlp 888 Add ONNX support to MPNet embeddings (#13955)
* adding ONNX support to MPNet

* remove name in test

* updating default name for MPNet models in Scala and Python

* updating default model name
ahmedlone127 authored Sep 7, 2023
1 parent 7f78be3 commit 96094c3
Showing 3 changed files with 123 additions and 28 deletions.
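
For context, here is a minimal usage sketch of the annotator this commit touches. It is not part of the diff: it assumes a running SparkSession bound to a value named `spark`, and that the new default model name (`all_mpnet_base_v2`, see MPNetEmbeddings.scala below) is resolvable from the public models hub.

```scala
import com.johnsnowlabs.nlp.DocumentAssembler
import com.johnsnowlabs.nlp.embeddings.MPNetEmbeddings
import org.apache.spark.ml.Pipeline

import spark.implicits._

// Any single string column works as input; the column names here are arbitrary.
val data = Seq("This is an example sentence.").toDF("text")

val document = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")

// pretrained() now resolves to "all_mpnet_base_v2" by default, and the model
// dispatches to the TensorFlow or ONNX backend depending on which wrapper
// (tensorflowWrapper / onnxWrapper) the loaded model carries.
val embeddings = MPNetEmbeddings
  .pretrained()
  .setInputCols(Array("document"))
  .setOutputCol("mpnet")

val pipeline = new Pipeline().setStages(Array(document, embeddings))
pipeline.fit(data).transform(data).select("mpnet.embeddings").show(truncate = false)
```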
73 changes: 68 additions & 5 deletions src/main/scala/com/johnsnowlabs/ml/ai/MPNet.scala
@@ -16,16 +16,19 @@

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}

import scala.collection.JavaConverters._

/** MPNet Sentence embeddings model
*
* @param tensorflow
* @param tensorflowWrapper
* tensorflow wrapper
* @param configProtoBytes
* config proto bytes
@@ -37,7 +40,8 @@ import scala.collection.JavaConverters._
* signatures
*/
private[johnsnowlabs] class MPNet(
val tensorflow: TensorflowWrapper,
val tensorflowWrapper: Option[TensorflowWrapper],
val onnxWrapper: Option[OnnxWrapper],
configProtoBytes: Option[Array[Byte]] = None,
sentenceStartTokenId: Int,
sentenceEndTokenId: Int,
@@ -47,8 +51,11 @@ private[johnsnowlabs] class MPNet(
private val _tfInstructorSignatures: Map[String, String] =
signatures.getOrElse(ModelSignatureManager.apply())
private val paddingTokenId = 1
private val bosTokenId = 0
private val eosTokenId = 2

val detectedEngine: String =
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else TensorFlow.name

/** Get sentence embeddings for a batch of sentences
* @param batch
@@ -57,6 +64,22 @@
* sentence embeddings
*/
private def getSentenceEmbedding(batch: Seq[Array[Int]]): Array[Array[Float]] = {
val embeddings = detectedEngine match {
case ONNX.name =>
getSentenceEmbeddingFromOnnx(batch)
case _ =>
getSentenceEmbeddingFromTF(batch)
}
embeddings
}

/** Get sentence embeddings for a batch of sentences
* @param batch
* batch of sentences
* @return
* sentence embeddings
*/
private def getSentenceEmbeddingFromTF(batch: Seq[Array[Int]]): Array[Array[Float]] = {
// get max sentence length
val sequencesLength = batch.map(x => x.length).toArray
val maxSentenceLength = sequencesLength.max
@@ -92,7 +115,7 @@
tensorEncoder.createIntBufferTensor(shape, encoderAttentionMaskBuffers)

// run model
val runner = tensorflow
val runner = tensorflowWrapper.get
.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
initAllTables = false,
@@ -131,6 +154,46 @@
sentenceEmbeddingsFloatsArray
}

private def getSentenceEmbeddingFromOnnx(batch: Seq[Array[Int]]): Array[Array[Float]] = {
val batchLength = batch.length
val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max

val (runner, env) = onnxWrapper.get.getSession()
val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val segmentTensors =
OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray)

val inputs =
Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

// TODO: A try without a catch or finally is equivalent to putting its body in a block; no exceptions are handled.
try {
val results = runner.run(inputs)
try {
val embeddings = results
.get("last_hidden_state")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()
segmentTensors.close()

val dim = embeddings.length / batchLength
// group embeddings
val sentenceEmbeddingsFloatsArray = embeddings.grouped(dim).toArray
sentenceEmbeddingsFloatsArray
} finally if (results != null) results.close()
}
}

/** Predict sentence embeddings for a batch of sentences
* @param sentences
* sentences
src/main/scala/com/johnsnowlabs/nlp/embeddings/MPNetEmbeddings.scala
@@ -17,13 +17,14 @@
package com.johnsnowlabs.nlp.embeddings

import com.johnsnowlabs.ml.ai.MPNet
import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel}
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.ml.util.LoadExternalModel.{
loadTextAsset,
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder}
@@ -145,6 +146,7 @@ class MPNetEmbeddings(override val uid: String)
extends AnnotatorModel[MPNetEmbeddings]
with HasBatchedAnnotate[MPNetEmbeddings]
with WriteTensorflowModel
with WriteOnnxModel
with HasEmbeddingsProperties
with HasStorageRef
with HasCaseSensitiveProperties
@@ -229,12 +231,14 @@ class MPNetEmbeddings(override val uid: String)
/** @group setParam */
def setModelIfNotSet(
spark: SparkSession,
tensorflowWrapper: TensorflowWrapper): MPNetEmbeddings = {
tensorflowWrapper: Option[TensorflowWrapper],
onnxWrapper: Option[OnnxWrapper]): MPNetEmbeddings = {
if (_model.isEmpty) {
_model = Some(
spark.sparkContext.broadcast(
new MPNet(
tensorflowWrapper,
onnxWrapper,
configProtoBytes = getConfigProtoBytes,
sentenceStartTokenId = sentenceStartTokenId,
sentenceEndTokenId = sentenceEndTokenId,
@@ -336,14 +340,29 @@ class MPNetEmbeddings(override val uid: String)

override def onWrite(path: String, spark: SparkSession): Unit = {
super.onWrite(path, spark)
writeTensorflowModelV2(
path,
spark,
getModelIfNotSet.tensorflow,
"_mpnet",
MPNetEmbeddings.tfFile,
configProtoBytes = getConfigProtoBytes,
savedSignatures = getSignatures)
val suffix = "_mpnet"

getEngine match {
case TensorFlow.name =>
writeTensorflowModelV2(
path,
spark,
getModelIfNotSet.tensorflowWrapper.get,
suffix,
MPNetEmbeddings.tfFile,
configProtoBytes = getConfigProtoBytes,
savedSignatures = getSignatures)
case ONNX.name =>
writeOnnxModel(
path,
spark,
getModelIfNotSet.onnxWrapper.get,
suffix,
MPNetEmbeddings.onnxFile)

case _ =>
throw new Exception(notSupportedEngineError)
}
}

/** @group getParam */
@@ -366,7 +385,7 @@ class MPNetEmbeddings(override val uid: String)
trait ReadablePretrainedMPNetModel
extends ParamsAndFeaturesReadable[MPNetEmbeddings]
with HasPretrained[MPNetEmbeddings] {
override val defaultModelName: Some[String] = Some("mpnet_small")
override val defaultModelName: Some[String] = Some("all_mpnet_base_v2")

/** Java compliant-overrides */
override def pretrained(): MPNetEmbeddings = super.pretrained()
@@ -380,19 +399,26 @@ trait ReadablePretrainedMPNetModel
super.pretrained(name, lang, remoteLoc)
}

trait ReadMPNetDLModel extends ReadTensorflowModel {
trait ReadMPNetDLModel extends ReadTensorflowModel with ReadOnnxModel {
this: ParamsAndFeaturesReadable[MPNetEmbeddings] =>

override val tfFile: String = "mpnet_tensorflow"
override val onnxFile: String = "mpnet_onnx"
def readModel(instance: MPNetEmbeddings, path: String, spark: SparkSession): Unit = {

val tf = readTensorflowModel(
path,
spark,
"_mpnet_tf",
savedSignatures = instance.getSignatures,
initAllTables = false)
instance.setModelIfNotSet(spark, tf)
instance.getEngine match {
case TensorFlow.name =>
val tfWrapper = readTensorflowModel(path, spark, "_mpnet_tf", initAllTables = false)
instance.setModelIfNotSet(spark, Some(tfWrapper), None)

case ONNX.name =>
val onnxWrapper =
readOnnxModel(path, spark, "_mpnet_onnx", zipped = true, useBundle = false, None)
instance.setModelIfNotSet(spark, None, Some(onnxWrapper))

case _ =>
throw new Exception(notSupportedEngineError)
}
}

addReader(readModel)
@@ -424,7 +450,12 @@ trait ReadMPNetDLModel extends ReadTensorflowModel {
*/
annotatorModel
.setSignatures(_signatures)
.setModelIfNotSet(spark, wrapper)
.setModelIfNotSet(spark, Some(wrapper), None)

case ONNX.name =>
val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true)
annotatorModel
.setModelIfNotSet(spark, None, Some(onnxWrapper))

case _ =>
throw new Exception(notSupportedEngineError)
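
As a companion to the ONNX branch added to loadSavedModel above, here is a hedged sketch of importing an externally exported ONNX MPNet model. The export directory path is hypothetical, it is assumed to contain the ONNX model plus the vocabulary assets the loader expects, and `spark` is the same SparkSession value as in the earlier sketch.

```scala
import com.johnsnowlabs.nlp.embeddings.MPNetEmbeddings

// Hypothetical directory holding an MPNet model exported to ONNX
// (e.g. with Hugging Face Optimum), alongside its tokenizer assets.
val exportPath = "/tmp/mpnet_onnx_export"

// loadSavedModel runs the model sanity check, detects the engine and, for ONNX,
// reads the model through OnnxWrapper.read(..., useBundle = true) as added here.
val embeddings = MPNetEmbeddings
  .loadSavedModel(exportPath, spark)
  .setInputCols(Array("document"))
  .setOutputCol("mpnet")

// The imported annotator can be saved like any Spark ML model; onWrite now
// persists either the TensorFlow graph or the ONNX file based on getEngine.
embeddings.write.overwrite().save("/tmp/mpnet_onnx_spark_nlp")
```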
src/test/scala/com/johnsnowlabs/nlp/embeddings/MPNetEmbeddingsTestSpec.scala
@@ -24,7 +24,7 @@ import org.scalatest.flatspec.AnyFlatSpec

class MPNetEmbeddingsTestSpec extends AnyFlatSpec {

"E5 Embeddings" should "correctly embed multiple sentences" taggedAs SlowTest in {
"Mpnet Embeddings" should "correctly embed multiple sentences" taggedAs SlowTest in {

import ResourceHelper.spark.implicits._

@@ -38,12 +38,13 @@ class MPNetEmbeddingsTestSpec extends AnyFlatSpec {
val embeddings = MPNetEmbeddings
.pretrained()
.setInputCols(Array("document"))
.setOutputCol("e5")
.setOutputCol("mpnet")

val pipeline = new Pipeline().setStages(Array(document, embeddings))

val pipelineDF = pipeline.fit(ddd).transform(ddd)
pipelineDF.select("e5.embeddings").show(truncate = false)
pipelineDF.select("mpnet.embeddings").show(truncate = false)

}

}
