From b0e7010d4e45186083b4a33112eb8f6b9e6183b6 Mon Sep 17 00:00:00 2001 From: ahmedlone127 Date: Fri, 2 Aug 2024 17:24:07 +0500 Subject: [PATCH] adding onnx API --- .../com/johnsnowlabs/ml/ai/Instructor.scala | 134 ++++++++++++++---- .../nlp/embeddings/InstructorEmbeddings.scala | 74 ++++++++-- 2 files changed, 167 insertions(+), 41 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Instructor.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Instructor.scala index 1507da55a59de7..52577fa1d757bd 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/Instructor.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Instructor.scala @@ -16,9 +16,12 @@ package com.johnsnowlabs.ml.ai +import ai.onnxruntime.{OnnxTensor, TensorInfo} +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} +import com.johnsnowlabs.ml.util.{LinAlg, ONNX, TensorFlow} import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} import scala.collection.JavaConverters._ @@ -36,7 +39,8 @@ import scala.collection.JavaConverters._ */ private[johnsnowlabs] class Instructor( - val tensorflow: TensorflowWrapper, + val tensorflowWrapper: Option[TensorflowWrapper], + val onnxWrapper: Option[OnnxWrapper], val spp: SentencePieceWrapper, configProtoBytes: Option[Array[Byte]] = None, signatures: Option[Map[String, String]] = None) @@ -46,44 +50,98 @@ private[johnsnowlabs] class Instructor( signatures.getOrElse(ModelSignatureManager.apply()) private val paddingTokenId = 0 private val eosTokenId = 1 + val detectedEngine: String = + if (tensorflowWrapper.isDefined) TensorFlow.name + else if (onnxWrapper.isDefined) ONNX.name + else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions - /** Get sentence embeddings for a batch of sentences - * @param batch - * batch of sentences - * @param contextLengths - * context lengths - * @return - * sentence embeddings - */ - private def getSentenceEmbedding( - batch: Seq[Array[Int]], - contextLengths: Seq[Int]): Array[Array[Float]] = { - // get max sentence length - val sequencesLength = batch.map(x => x.length).toArray - val maxSentenceLength = sequencesLength.max - val batchLength = batch.length + private def getSentenceEmbeddingFromOnnx( + batch: Seq[Array[Int]], + contextLengths: Seq[Int], + maxSentenceLength: Int): Array[Array[Float]] = { + + val inputIds = batch.map(x => x.map(x => x.toLong)).toArray + val attentionMask = batch + .map(sentence => sentence.map(x => if (x == this.paddingTokenId) 0L else 1L)) + .toArray + + val contextMask = attentionMask.zipWithIndex.map { case (batchElement, idx) => + batchElement.zipWithIndex.map { case (x, i) => + if (i < contextLengths(idx)) 0L else x + } + }.toArray + + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) + val tokenTensors = OnnxTensor.createTensor(env, inputIds) + val maskTensors = OnnxTensor.createTensor(env, attentionMask) + val contextTensor = + OnnxTensor.createTensor(env, contextMask) + val inputs = + Map( + "input_ids" -> tokenTensors, + "attention_mask" -> maskTensors).asJava + + // TODO: A try without a catch or finally is equivalent to putting its body in a block; no exceptions are handled. + try { + val results = runner.run(inputs) + val lastHiddenState = results.get("token_embeddings").get() + val info = lastHiddenState.getInfo.asInstanceOf[TensorInfo] + val shape = info.getShape + try { + val flattenEmbeddings = lastHiddenState + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + val embeddings = LinAlg.avgPooling(flattenEmbeddings, contextMask, shape) + val normalizedEmbeddings = LinAlg.l2Normalize(embeddings) + LinAlg.denseMatrixToArray(normalizedEmbeddings) + } finally if (results != null) results.close() + } catch { + case e: Exception => + // Handle exceptions by logging or other means. + e.printStackTrace() + Array.empty[Array[Float]] // Return an empty array or appropriate error handling + } finally { + // Close tensors outside the try-catch to avoid repeated null checks. + // These resources are initialized before the try-catch, so they should be closed here. + tokenTensors.close() + maskTensors.close() + contextTensor.close() + } + } + + private def padArrayWithZeros(arr: Array[Int], maxLength: Int): Array[Int] = { + if (arr.length >= maxLength) { + arr + } else { + arr ++ Array.fill(maxLength - arr.length)(0) + } + } + + private def getSentenceEmbeddingFromTF( + paddedBatch: Seq[Array[Int]], + contextLengths: Seq[Int], + maxSentenceLength: Int) = { // encode batch val tensorEncoder = new TensorResources() - val inputDim = batch.length * maxSentenceLength + val inputDim = paddedBatch.length * maxSentenceLength + val batchLength = paddedBatch.length // create buffers val encoderInputBuffers = tensorEncoder.createIntBuffer(inputDim) val encoderAttentionMaskBuffers = tensorEncoder.createIntBuffer(inputDim) val encoderContextMaskBuffers = tensorEncoder.createIntBuffer(inputDim) - val shape = Array(batch.length.toLong, maxSentenceLength) + val shape = Array(paddedBatch.length.toLong, maxSentenceLength) - batch.zipWithIndex.foreach { case (tokenIds, idx) => + paddedBatch.zipWithIndex.foreach { case (tokenIds, idx) => val offset = idx * maxSentenceLength - val diff = maxSentenceLength - tokenIds.length - - // pad with 0 - val s = tokenIds.take(maxSentenceLength) ++ Array.fill[Int](diff)(this.paddingTokenId) - encoderInputBuffers.offset(offset).write(s) + encoderInputBuffers.offset(offset).write(tokenIds) // create attention mask - val mask = s.map(x => if (x != this.paddingTokenId) 1 else 0) + val mask = tokenIds.map(x => if (x != this.paddingTokenId) 1 else 0) encoderAttentionMaskBuffers.offset(offset).write(mask) // create context mask @@ -101,7 +159,7 @@ private[johnsnowlabs] class Instructor( tensorEncoder.createIntBufferTensor(shape, encoderContextMaskBuffers) // run model - val runner = tensorflow + val runner = tensorflowWrapper.get .getTFSessionWithSignature( configProtoBytes = configProtoBytes, initAllTables = false, @@ -144,6 +202,30 @@ private[johnsnowlabs] class Instructor( tensorEncoder.clearSession(sentenceEmbeddings) sentenceEmbeddingsFloatsArray + + } + /** Get sentence embeddings for a batch of sentences + * @param batch + * batch of sentences + * @param contextLengths + * context lengths + * @return + * sentence embeddings + */ + private def getSentenceEmbedding( + batch: Seq[Array[Int]], + contextLengths: Seq[Int]): Array[Array[Float]] = { + val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max + val paddedBatch = batch.map(arr => padArrayWithZeros(arr, maxSentenceLength)) + val sentenceEmbeddings: Array[Array[Float]] = detectedEngine match { + case ONNX.name => + getSentenceEmbeddingFromOnnx(paddedBatch, contextLengths, maxSentenceLength) + case _ => // TF Case + getSentenceEmbeddingFromTF(paddedBatch, contextLengths, maxSentenceLength) + } + + sentenceEmbeddings + } /** Tokenize sentences diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala index ede6caa13ee84f..6fa3bf2bc00443 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala @@ -16,7 +16,9 @@ package com.johnsnowlabs.nlp.embeddings + import com.johnsnowlabs.ml.ai.Instructor +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ ReadSentencePieceModel, @@ -28,7 +30,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.TensorFlow +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import com.johnsnowlabs.storage.HasStorageRef @@ -148,6 +150,7 @@ class InstructorEmbeddings(override val uid: String) extends AnnotatorModel[InstructorEmbeddings] with HasBatchedAnnotate[InstructorEmbeddings] with WriteTensorflowModel + with WriteOnnxModel with HasEmbeddingsProperties with HasStorageRef with WriteSentencePieceModel @@ -227,13 +230,15 @@ class InstructorEmbeddings(override val uid: String) /** @group setParam */ def setModelIfNotSet( spark: SparkSession, - tensorflowWrapper: TensorflowWrapper, + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper], spp: SentencePieceWrapper): InstructorEmbeddings = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( new Instructor( tensorflowWrapper, + onnxWrapper, spp = spp, configProtoBytes = getConfigProtoBytes, signatures = getSignatures))) @@ -319,15 +324,29 @@ class InstructorEmbeddings(override val uid: String) def getModelIfNotSet: Instructor = _model.get.value override def onWrite(path: String, spark: SparkSession): Unit = { + + + super.onWrite(path, spark) - writeTensorflowModelV2( - path, - spark, - getModelIfNotSet.tensorflow, - "_instructor", - InstructorEmbeddings.tfFile, - configProtoBytes = getConfigProtoBytes, - savedSignatures = getSignatures) + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + "_instructor", + InstructorEmbeddings.tfFile, + configProtoBytes = getConfigProtoBytes, + savedSignatures = getSignatures) + + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + "_instructor", + InstructorEmbeddings.onnxFile) + } writeSentencePieceModel( path, spark, @@ -335,6 +354,7 @@ class InstructorEmbeddings(override val uid: String) "_instructor", InstructorEmbeddings.sppFile) + } /** @group getParam */ @@ -371,21 +391,41 @@ trait ReadablePretrainedInstructorModel super.pretrained(name, lang, remoteLoc) } -trait ReadInstructorDLModel extends ReadTensorflowModel with ReadSentencePieceModel { +trait ReadInstructorDLModel extends ReadTensorflowModel with ReadSentencePieceModel with ReadOnnxModel { this: ParamsAndFeaturesReadable[InstructorEmbeddings] => override val tfFile: String = "instructor_tensorflow" override val sppFile: String = "instructor_spp" + override val onnxFile: String = "instructor_onnx" + def readModel(instance: InstructorEmbeddings, path: String, spark: SparkSession): Unit = { + val spp = readSentencePieceModel(path, spark, "_instructor_spp", sppFile) + + instance.getEngine match { + case TensorFlow.name => val tf = readTensorflowModel( path, spark, "_instructor_tf", savedSignatures = instance.getSignatures, initAllTables = false) - val spp = readSentencePieceModel(path, spark, "_instructor_spp", sppFile) - instance.setModelIfNotSet(spark, tf, spp) + instance.setModelIfNotSet(spark, Some(tf), None, spp) + + + case ONNX.name => + val onnxWrapper = + readOnnxModel( + path, + spark, + "_instructor_onnx", + zipped = true, + useBundle = false, + None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) + + } + } addReader(readModel) @@ -401,7 +441,7 @@ trait ReadInstructorDLModel extends ReadTensorflowModel with ReadSentencePieceMo val spModel = loadSentencePieceAsset(localModelPath, "spiece.model") detectedEngine match { case TensorFlow.name => - val (wrapper, signatures) = TensorflowWrapper.read( + val (tfwrapper, signatures) = TensorflowWrapper.read( localModelPath, zipped = false, useBundle = true, @@ -418,8 +458,12 @@ trait ReadInstructorDLModel extends ReadTensorflowModel with ReadSentencePieceMo */ annotatorModel .setSignatures(_signatures) - .setModelIfNotSet(spark, wrapper, spModel) + .setModelIfNotSet(spark, Some(tfwrapper), None, spModel) + case ONNX.name => + val onnxWrapper = OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => throw new Exception(notSupportedEngineError) }