Skip to content

Commit

Permalink
SPARKNLP-938 E5 and MPNet embeddings crash on a sentence basis - miss…
Browse files Browse the repository at this point in the history
…ing pool average (#14051)

* [SPARKNLP-938] Padding with zeros to make sure all sequences have same length

* [SPARKNLP-938] Removing unused tests

* [SPARKNLP-939] Fix embeddings dimension size in ONNX models

* [SPARKNLP-939] Removing unnecessary tests

* [SPARKNLP-939] Adding average pooling and normalizing E5 and MPNeT embeddings

* [SPARKNLP-939] Reformatting code
  • Loading branch information
danilojsl authored Dec 7, 2023
1 parent 4850abf commit 999a21b
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 50 deletions.
62 changes: 37 additions & 25 deletions src/main/scala/com/johnsnowlabs/ml/ai/E5.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.{OnnxTensor, TensorInfo}
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.ml.util.{LinAlg, ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}

Expand Down Expand Up @@ -63,19 +63,28 @@ private[johnsnowlabs] class E5(
* sentence embeddings
*/
private def getSentenceEmbedding(batch: Seq[Array[Int]]): Array[Array[Float]] = {
val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
val paddedBatch = batch.map(arr => padArrayWithZeros(arr, maxSentenceLength))
val embeddings = detectedEngine match {
case ONNX.name =>
getSentenceEmbeddingFromOnnx(batch)
getSentenceEmbeddingFromOnnx(paddedBatch, maxSentenceLength)
case _ =>
getSentenceEmbeddingFromTF(batch)
getSentenceEmbeddingFromTF(paddedBatch, maxSentenceLength)
}
embeddings
}

private def getSentenceEmbeddingFromTF(batch: Seq[Array[Int]]): Array[Array[Float]] = {
// get max sentence length
val sequencesLength = batch.map(x => x.length).toArray
val maxSentenceLength = sequencesLength.max
private def padArrayWithZeros(arr: Array[Int], maxLength: Int): Array[Int] = {
if (arr.length >= maxLength) {
arr
} else {
arr ++ Array.fill(maxLength - arr.length)(0)
}
}

private def getSentenceEmbeddingFromTF(
batch: Seq[Array[Int]],
maxSentenceLength: Int): Array[Array[Float]] = {
val batchLength = batch.length

// encode batch
Expand Down Expand Up @@ -147,17 +156,17 @@ private[johnsnowlabs] class E5(
sentenceEmbeddingsFloatsArray
}

private def getSentenceEmbeddingFromOnnx(batch: Seq[Array[Int]]): Array[Array[Float]] = {
val batchLength = batch.length
val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
private def getSentenceEmbeddingFromOnnx(
batch: Seq[Array[Int]],
maxSentenceLength: Int): Array[Array[Float]] = {

val inputIds = batch.map(x => x.map(x => x.toLong)).toArray
val attentionMask = batch.map(sentence => sentence.map(x => if (x < 0L) 0L else 1L)).toArray

val (runner, env) = onnxWrapper.get.getSession()
val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val tokenTensors = OnnxTensor.createTensor(env, inputIds)
val maskTensors = OnnxTensor.createTensor(env, attentionMask)

val segmentTensors =
OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray)
Expand All @@ -171,21 +180,23 @@ private[johnsnowlabs] class E5(
// TODO: A try without a catch or finally is equivalent to putting its body in a block; no exceptions are handled.
try {
val results = runner.run(inputs)
val lastHiddenState = results.get("last_hidden_state").get()
val info = lastHiddenState.getInfo.asInstanceOf[TensorInfo]
val shape = info.getShape
try {
val embeddings = results
.get("last_hidden_state")
.get()
val embeddings = lastHiddenState
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()
segmentTensors.close()

val dim = embeddings.length / batchLength
// group embeddings
val sentenceEmbeddingsFloatsArray = embeddings.grouped(dim).toArray
sentenceEmbeddingsFloatsArray
val dim = shape.last.toInt
val avgPooling = LinAlg.avgPooling(embeddings, attentionMask(0), dim)
val normalizedSentenceEmbeddings = LinAlg.normalizeArray(avgPooling)

Array(normalizedSentenceEmbeddings)
} finally if (results != null) results.close()
}
}
Expand Down Expand Up @@ -213,11 +224,12 @@ private[johnsnowlabs] class E5(
.grouped(batchSize)
.toArray
.flatMap { batch =>
val tokensBatch = batch.map(x => (x._1._1.tokens))
val tokensBatch = batch.map(x => x._1._1.tokens)
val tokens = tokensBatch.map(x =>
Array(sentenceStartTokenId) ++ x
.map(y => y.pieceId)
.take(maxSentenceLength - 2) ++ Array(sentenceEndTokenId))

val sentenceEmbeddings = getSentenceEmbedding(tokens)

batch.zip(sentenceEmbeddings).map { case (sentence, vectors) =>
Expand Down
51 changes: 28 additions & 23 deletions src/main/scala/com/johnsnowlabs/ml/ai/MPNet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.{OnnxTensor, TensorInfo}
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.{LinAlg, ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -64,15 +64,25 @@ private[johnsnowlabs] class MPNet(
* sentence embeddings
*/
private def getSentenceEmbedding(batch: Seq[Array[Int]]): Array[Array[Float]] = {
val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
val paddedBatch = batch.map(arr => padArrayWithZeros(arr, maxSentenceLength))
val embeddings = detectedEngine match {
case ONNX.name =>
getSentenceEmbeddingFromOnnx(batch)
getSentenceEmbeddingFromOnnx(paddedBatch)
case _ =>
getSentenceEmbeddingFromTF(batch)
getSentenceEmbeddingFromTF(paddedBatch)
}
embeddings
}

private def padArrayWithZeros(arr: Array[Int], maxLength: Int): Array[Int] = {
if (arr.length >= maxLength) {
arr
} else {
arr ++ Array.fill(maxLength - arr.length)(0)
}
}

/** Get sentence embeddings for a batch of sentences
* @param batch
* batch of sentences
Expand Down Expand Up @@ -155,41 +165,36 @@ private[johnsnowlabs] class MPNet(
}

private def getSentenceEmbeddingFromOnnx(batch: Seq[Array[Int]]): Array[Array[Float]] = {
val batchLength = batch.length
val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max

val (runner, env) = onnxWrapper.get.getSession()
val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)
val inputIds = batch.map(x => x.map(x => x.toLong)).toArray
val attentionMask = batch.map(sentence => sentence.map(x => if (x < 0L) 0L else 1L)).toArray

val segmentTensors =
OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray)
val (runner, env) = onnxWrapper.get.getSession()
val tokenTensors = OnnxTensor.createTensor(env, inputIds)
val maskTensors = OnnxTensor.createTensor(env, attentionMask)

val inputs =
Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

// TODO: A try without a catch or finally is equivalent to putting its body in a block; no exceptions are handled.
try {
val results = runner.run(inputs)
val lastHiddenState = results.get("last_hidden_state").get()
val info = lastHiddenState.getInfo.asInstanceOf[TensorInfo]
val shape = info.getShape
try {
val embeddings = results
.get("last_hidden_state")
.get()
val embeddings = lastHiddenState
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()
segmentTensors.close()

val dim = embeddings.length / batchLength
// group embeddings
val sentenceEmbeddingsFloatsArray = embeddings.grouped(dim).toArray
sentenceEmbeddingsFloatsArray
val dim = shape.last.toInt
val avgPooling = LinAlg.avgPooling(embeddings, attentionMask(0), dim)
val normalizedSentenceEmbeddings = LinAlg.normalizeArray(avgPooling)

Array(normalizedSentenceEmbeddings)
} finally if (results != null) results.close()
}
}
Expand Down
46 changes: 46 additions & 0 deletions src/main/scala/com/johnsnowlabs/ml/util/LinAlg.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.johnsnowlabs.ml.util

import breeze.linalg.{DenseMatrix, tile}
import scala.math.sqrt

object LinAlg {

Expand Down Expand Up @@ -56,4 +57,49 @@ object LinAlg {

}

def avgPooling(embeddings: Array[Float], attentionMask: Array[Long], dim: Int): Array[Float] = {
val expandedAttentionMask = new Array[Float](embeddings.length)
// Expand attentionMask to match the length of embeddings
var j = 0
for (i <- embeddings.indices) {
expandedAttentionMask(i) = attentionMask(j)
j += 1
if (j == attentionMask.length) {
j = 0 // reset j when we reach the end of attentionMask
}
}

val sentenceEmbeddingsMatrix = embeddings.grouped(dim).toArray
val attentionMaskMatrix = expandedAttentionMask.grouped(dim).toArray

val elementWiseProduct =
computeElementWiseProduct(sentenceEmbeddingsMatrix, attentionMaskMatrix)
val weightedSum: Array[Float] = elementWiseProduct.transpose.map(_.sum)

val sumAlongDimension2: Array[Float] = attentionMaskMatrix.transpose.map(_.sum)
// Clamp each element to a minimum value of 1e-9
val totalWeight: Array[Float] = sumAlongDimension2.map(x => math.max(x, 1e-9.toFloat))
computeElementWiseDivision(weightedSum, totalWeight)
}

def computeElementWiseProduct(
arrayA: Array[Array[Float]],
arrayB: Array[Array[Float]]): Array[Array[Float]] = {
arrayA.zip(arrayB).map { case (row1, row2) =>
row1.zip(row2).map { case (a, b) => a * b }
}
}

def computeElementWiseDivision(arrayA: Array[Float], arrayB: Array[Float]): Array[Float] = {
arrayA.zip(arrayB).map { case (a, b) =>
if (b != 0.0f) a / b else 0.0f // Avoid division by zero
}
}

def normalizeArray(array: Array[Float]): Array[Float] = {
val l2Norm: Float = sqrt(array.map(x => x * x).sum).toFloat
// Normalize each element in the array
array.map(value => if (l2Norm != 0.0f) value / l2Norm else 0.0f)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

package com.johnsnowlabs.nlp.embeddings

import com.johnsnowlabs.nlp.annotators.sentence_detector_dl.SentenceDetectorDLModel
import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.tags.{SlowTest}
import com.johnsnowlabs.tags.SlowTest
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.functions.{col, size}
import org.scalatest.flatspec.AnyFlatSpec

class E5EmbeddingsTestSpec extends AnyFlatSpec {
Expand Down Expand Up @@ -54,4 +56,61 @@ class E5EmbeddingsTestSpec extends AnyFlatSpec {
pipelineDF.select("e5.embeddings").show(truncate = false)

}

it should "have embeddings of the same size" taggedAs SlowTest in {
import ResourceHelper.spark.implicits._
val testDf = Seq(
"I like apples",
"I like bananas \\n and other things \\n like icream \\n and cats",
"I like rockets")
.toDF("text")

val document = new DocumentAssembler()
.setInputCol("text")
.setOutputCol("document")

val embeddings = E5Embeddings
.pretrained()
.setInputCols(Array("document"))
.setOutputCol("e5")

val pipeline = new Pipeline().setStages(Array(document, embeddings))

val pipelineDF = pipeline.fit(testDf).transform(testDf)

val embeddingsDF = pipelineDF.withColumn("embeddings", col("e5.embeddings").getItem(0))

val sizesArray: Array[Int] = embeddingsDF
.select(size(col("embeddings")).as("size"))
.collect()
.map(row => row.getAs[Int]("size"))

assert(sizesArray.forall(_ == sizesArray.head))
}

it should "work with sentences" taggedAs SlowTest in {
import ResourceHelper.spark.implicits._
val testData = "I really enjoy my job. This is amazing"
val testDf = Seq(testData).toDF("text")

val document = new DocumentAssembler()
.setInputCol("text")
.setOutputCol("document")

val sentenceDetectorDL = SentenceDetectorDLModel
.pretrained("sentence_detector_dl", "en")
.setInputCols(Array("document"))
.setOutputCol("sentences")

val embeddings = E5Embeddings
.pretrained()
.setInputCols(Array("sentences"))
.setOutputCol("e5")

val pipeline = new Pipeline().setStages(Array(document, sentenceDetectorDL, embeddings))

val pipelineDF = pipeline.fit(testDf).transform(testDf)
pipelineDF.select("e5.embeddings").show(false)
}

}
Loading

0 comments on commit 999a21b

Please sign in to comment.