Skip to content

Commit

Permalink
adding onnx API
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmedlone127 committed Aug 2, 2024
1 parent 5a01057 commit b0e7010
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 41 deletions.
134 changes: 108 additions & 26 deletions src/main/scala/com/johnsnowlabs/ml/ai/Instructor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.{OnnxTensor, TensorInfo}
import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper}
import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.{LinAlg, ONNX, TensorFlow}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}

import scala.collection.JavaConverters._
Expand All @@ -36,7 +39,8 @@ import scala.collection.JavaConverters._
*/

private[johnsnowlabs] class Instructor(
val tensorflow: TensorflowWrapper,
val tensorflowWrapper: Option[TensorflowWrapper],
val onnxWrapper: Option[OnnxWrapper],
val spp: SentencePieceWrapper,
configProtoBytes: Option[Array[Byte]] = None,
signatures: Option[Map[String, String]] = None)
Expand All @@ -46,44 +50,98 @@ private[johnsnowlabs] class Instructor(
signatures.getOrElse(ModelSignatureManager.apply())
private val paddingTokenId = 0
private val eosTokenId = 1
// Inference backend selection: prefer TensorFlow when a TF wrapper is supplied,
// fall back to ONNX when only an ONNX wrapper is supplied, and default to
// TensorFlow when neither is defined (the TF path will then fail on use).
val detectedEngine: String =
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else TensorFlow.name
// Options handed to ONNX Runtime when a session is created for this model.
private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions

/** Get sentence embeddings for a batch of sentences
* @param batch
* batch of sentences
* @param contextLengths
* context lengths
* @return
* sentence embeddings
*/
private def getSentenceEmbedding(
batch: Seq[Array[Int]],
contextLengths: Seq[Int]): Array[Array[Float]] = {
// get max sentence length
val sequencesLength = batch.map(x => x.length).toArray
val maxSentenceLength = sequencesLength.max
val batchLength = batch.length
/** Run the ONNX model on an already-padded batch and return pooled, L2-normalized
  * sentence embeddings.
  *
  * @param batch
  *   padded batch of sentence-piece ids (all rows have equal length)
  * @param contextLengths
  *   per-sentence count of leading instruction tokens to exclude from pooling
  * @param maxSentenceLength
  *   padded row length (kept for signature parity with the TF path)
  * @return
  *   one embedding per sentence, or an empty array if inference fails
  */
private def getSentenceEmbeddingFromOnnx(
    batch: Seq[Array[Int]],
    contextLengths: Seq[Int],
    maxSentenceLength: Int): Array[Array[Float]] = {

  val inputIds = batch.map(x => x.map(x => x.toLong)).toArray
  val attentionMask = batch
    .map(sentence => sentence.map(x => if (x == this.paddingTokenId) 0L else 1L))
    .toArray

  // Zero out the instruction-prefix positions so avgPooling only aggregates
  // over the actual sentence tokens (see the i < contextLengths(idx) guard).
  val contextMask = attentionMask.zipWithIndex.map { case (batchElement, idx) =>
    batchElement.zipWithIndex.map { case (x, i) =>
      if (i < contextLengths(idx)) 0L else x
    }
  }

  val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)

  val tokenTensors = OnnxTensor.createTensor(env, inputIds)
  val maskTensors = OnnxTensor.createTensor(env, attentionMask)
  // NOTE: the model only consumes input_ids and attention_mask; the context mask
  // is applied on the JVM side during pooling, so no tensor is created for it.
  val inputs =
    Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

  try {
    val results = runner.run(inputs)
    try {
      val lastHiddenState = results.get("token_embeddings").get()
      val info = lastHiddenState.getInfo.asInstanceOf[TensorInfo]
      val shape = info.getShape
      val flattenEmbeddings = lastHiddenState
        .asInstanceOf[OnnxTensor]
        .getFloatBuffer
        .array()
      val embeddings = LinAlg.avgPooling(flattenEmbeddings, contextMask, shape)
      val normalizedEmbeddings = LinAlg.l2Normalize(embeddings)
      LinAlg.denseMatrixToArray(normalizedEmbeddings)
    } finally if (results != null) results.close()
  } catch {
    case e: Exception =>
      // Best effort: log and return no embeddings rather than failing the batch.
      e.printStackTrace()
      Array.empty[Array[Float]]
  } finally {
    // Input tensors are created unconditionally above, so they are always closed here.
    tokenTensors.close()
    maskTensors.close()
  }
}

/** Right-pads `arr` with zeros up to `maxLength`; the original array is returned
  * untouched when it is already at least that long.
  */
private def padArrayWithZeros(arr: Array[Int], maxLength: Int): Array[Int] = {
  val deficit = maxLength - arr.length
  if (deficit <= 0) arr
  else arr ++ Array.fill(deficit)(0)
}

private def getSentenceEmbeddingFromTF(
paddedBatch: Seq[Array[Int]],
contextLengths: Seq[Int],
maxSentenceLength: Int) = {
// encode batch
val tensorEncoder = new TensorResources()
val inputDim = batch.length * maxSentenceLength
val inputDim = paddedBatch.length * maxSentenceLength
val batchLength = paddedBatch.length

// create buffers
val encoderInputBuffers = tensorEncoder.createIntBuffer(inputDim)
val encoderAttentionMaskBuffers = tensorEncoder.createIntBuffer(inputDim)
val encoderContextMaskBuffers = tensorEncoder.createIntBuffer(inputDim)

val shape = Array(batch.length.toLong, maxSentenceLength)
val shape = Array(paddedBatch.length.toLong, maxSentenceLength)

batch.zipWithIndex.foreach { case (tokenIds, idx) =>
paddedBatch.zipWithIndex.foreach { case (tokenIds, idx) =>
val offset = idx * maxSentenceLength
val diff = maxSentenceLength - tokenIds.length

// pad with 0
val s = tokenIds.take(maxSentenceLength) ++ Array.fill[Int](diff)(this.paddingTokenId)
encoderInputBuffers.offset(offset).write(s)
encoderInputBuffers.offset(offset).write(tokenIds)

// create attention mask
val mask = s.map(x => if (x != this.paddingTokenId) 1 else 0)
val mask = tokenIds.map(x => if (x != this.paddingTokenId) 1 else 0)
encoderAttentionMaskBuffers.offset(offset).write(mask)

// create context mask
Expand All @@ -101,7 +159,7 @@ private[johnsnowlabs] class Instructor(
tensorEncoder.createIntBufferTensor(shape, encoderContextMaskBuffers)

// run model
val runner = tensorflow
val runner = tensorflowWrapper.get
.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
initAllTables = false,
Expand Down Expand Up @@ -144,6 +202,30 @@ private[johnsnowlabs] class Instructor(
tensorEncoder.clearSession(sentenceEmbeddings)

sentenceEmbeddingsFloatsArray

}
/** Compute sentence embeddings for a batch of tokenized sentences.
  *
  * Pads every sentence to the batch maximum, then dispatches to the backend
  * selected by `detectedEngine`.
  *
  * @param batch
  *   batch of sentences as sentence-piece id arrays
  * @param contextLengths
  *   per-sentence instruction-prefix lengths forwarded to the backend
  * @return
  *   one embedding vector per input sentence
  */
private def getSentenceEmbedding(
    batch: Seq[Array[Int]],
    contextLengths: Seq[Int]): Array[Array[Float]] = {
  val longestSentence = batch.map(_.length).max
  val paddedSentences = batch.map(ids => padArrayWithZeros(ids, longestSentence))

  if (detectedEngine == ONNX.name)
    getSentenceEmbeddingFromOnnx(paddedSentences, contextLengths, longestSentence)
  else // TF case (also the default engine)
    getSentenceEmbeddingFromTF(paddedSentences, contextLengths, longestSentence)
}

/** Tokenize sentences
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

package com.johnsnowlabs.nlp.embeddings


import com.johnsnowlabs.ml.ai.Instructor
import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel}
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.ml.tensorflow.sentencepiece.{
ReadSentencePieceModel,
Expand All @@ -28,7 +30,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.serialization.MapFeature
import com.johnsnowlabs.storage.HasStorageRef
Expand Down Expand Up @@ -148,6 +150,7 @@ class InstructorEmbeddings(override val uid: String)
extends AnnotatorModel[InstructorEmbeddings]
with HasBatchedAnnotate[InstructorEmbeddings]
with WriteTensorflowModel
with WriteOnnxModel
with HasEmbeddingsProperties
with HasStorageRef
with WriteSentencePieceModel
Expand Down Expand Up @@ -227,13 +230,15 @@ class InstructorEmbeddings(override val uid: String)
/** @group setParam */
def setModelIfNotSet(
spark: SparkSession,
tensorflowWrapper: TensorflowWrapper,
tensorflowWrapper: Option[TensorflowWrapper],
onnxWrapper: Option[OnnxWrapper],
spp: SentencePieceWrapper): InstructorEmbeddings = {
if (_model.isEmpty) {
_model = Some(
spark.sparkContext.broadcast(
new Instructor(
tensorflowWrapper,
onnxWrapper,
spp = spp,
configProtoBytes = getConfigProtoBytes,
signatures = getSignatures)))
Expand Down Expand Up @@ -319,22 +324,37 @@ class InstructorEmbeddings(override val uid: String)
def getModelIfNotSet: Instructor = _model.get.value

/** Serializes the backing model (TF graph or ONNX file) plus the SentencePiece
  * model next to the annotator's own metadata.
  *
  * @param path
  *   destination path of the annotator
  * @param spark
  *   active Spark session used for writing
  */
override def onWrite(path: String, spark: SparkSession): Unit = {
  super.onWrite(path, spark)
  getEngine match {
    case TensorFlow.name =>
      writeTensorflowModelV2(
        path,
        spark,
        getModelIfNotSet.tensorflowWrapper.get,
        "_instructor",
        InstructorEmbeddings.tfFile,
        configProtoBytes = getConfigProtoBytes,
        savedSignatures = getSignatures)

    case ONNX.name =>
      writeOnnxModel(
        path,
        spark,
        getModelIfNotSet.onnxWrapper.get,
        "_instructor",
        InstructorEmbeddings.onnxFile)

    case _ =>
      // Fail fast with a descriptive error instead of an opaque MatchError.
      throw new Exception(notSupportedEngineError)
  }
  writeSentencePieceModel(
    path,
    spark,
    getModelIfNotSet.spp,
    "_instructor",
    InstructorEmbeddings.sppFile)
}

/** @group getParam */
Expand Down Expand Up @@ -371,21 +391,41 @@ trait ReadablePretrainedInstructorModel
super.pretrained(name, lang, remoteLoc)
}

trait ReadInstructorDLModel extends ReadTensorflowModel with ReadSentencePieceModel {
trait ReadInstructorDLModel extends ReadTensorflowModel with ReadSentencePieceModel with ReadOnnxModel {
this: ParamsAndFeaturesReadable[InstructorEmbeddings] =>

override val tfFile: String = "instructor_tensorflow"
override val sppFile: String = "instructor_spp"
override val onnxFile: String = "instructor_onnx"

/** Restores the serialized backend model and SentencePiece model into `instance`,
  * choosing the loader by the engine recorded on the annotator.
  *
  * @param instance
  *   annotator being deserialized
  * @param path
  *   source path the annotator was written to
  * @param spark
  *   active Spark session used for reading
  */
def readModel(instance: InstructorEmbeddings, path: String, spark: SparkSession): Unit = {
  // The SentencePiece model is engine-independent, so it is read once up front.
  val spp = readSentencePieceModel(path, spark, "_instructor_spp", sppFile)

  instance.getEngine match {
    case TensorFlow.name =>
      val tf = readTensorflowModel(
        path,
        spark,
        "_instructor_tf",
        savedSignatures = instance.getSignatures,
        initAllTables = false)
      instance.setModelIfNotSet(spark, Some(tf), None, spp)

    case ONNX.name =>
      val onnxWrapper =
        readOnnxModel(
          path,
          spark,
          "_instructor_onnx",
          zipped = true,
          useBundle = false,
          None)
      instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp)

    case _ =>
      // Keep this consistent with loadSavedModel, which throws the same error
      // for unsupported engines, instead of surfacing a MatchError.
      throw new Exception(notSupportedEngineError)
  }
}

addReader(readModel)
Expand All @@ -401,7 +441,7 @@ trait ReadInstructorDLModel extends ReadTensorflowModel with ReadSentencePieceMo
val spModel = loadSentencePieceAsset(localModelPath, "spiece.model")
detectedEngine match {
case TensorFlow.name =>
val (wrapper, signatures) = TensorflowWrapper.read(
val (tfwrapper, signatures) = TensorflowWrapper.read(
localModelPath,
zipped = false,
useBundle = true,
Expand All @@ -418,8 +458,12 @@ trait ReadInstructorDLModel extends ReadTensorflowModel with ReadSentencePieceMo
*/
annotatorModel
.setSignatures(_signatures)
.setModelIfNotSet(spark, wrapper, spModel)
.setModelIfNotSet(spark, Some(tfwrapper), None, spModel)

case ONNX.name =>
val onnxWrapper = OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true)
annotatorModel
.setModelIfNotSet(spark, None, Some(onnxWrapper), spModel)
case _ =>
throw new Exception(notSupportedEngineError)
}
Expand Down

0 comments on commit b0e7010

Please sign in to comment.