Skip to content

Commit

Permalink
Added Openvino support
Browse files Browse the repository at this point in the history
  • Loading branch information
prabod committed Jul 17, 2024
1 parent 4e6b443 commit 4c83df7
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 39 deletions.
4 changes: 2 additions & 2 deletions python/sparknlp/annotator/seq2seq/qwen_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.QwenTransf
repetitionPenalty=1.0, noRepeatNgramSize=0, ignoreTokenIds=[], batchSize=1)

@staticmethod
def loadSavedModel(folder, spark_session):
def loadSavedModel(folder, spark_session, use_openvino=False):
"""Loads a locally saved model.
Parameters
Expand All @@ -313,7 +313,7 @@ def loadSavedModel(folder, spark_session):
The restored model
"""
from sparknlp.internal import _QwenLoader
jModel = _QwenLoader(folder, spark_session._jsparkSession)._java_obj
jModel = _QwenLoader(folder, spark_session._jsparkSession, use_openvino)._java_obj
return QwenTransformer(java_model=jModel)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions python/sparknlp/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,9 @@ def __init__(self, path, jspark, useCache):
)

class _QwenLoader(ExtendedJavaWrapper):
def __init__(self, path, jspark):
def __init__(self, path, jspark, use_openvino=False):
super(_QwenLoader, self).__init__(
"com.johnsnowlabs.nlp.annotators.seq2seq.QwenTransformer.loadSavedModel", path, jspark)
"com.johnsnowlabs.nlp.annotators.seq2seq.QwenTransformer.loadSavedModel", path, jspark, use_openvino)


class _USELoader(ExtendedJavaWrapper):
Expand Down
145 changes: 121 additions & 24 deletions src/main/scala/com/johnsnowlabs/ml/ai/Qwen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,33 @@ import com.johnsnowlabs.ml.ai.util.Generation.{Generate, GenerationConfig}
import com.johnsnowlabs.ml.onnx.OnnxSession
import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers
import com.johnsnowlabs.ml.onnx.TensorResources.implicits._
import com.johnsnowlabs.ml.openvino.OpenvinoWrapper
import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper
import com.johnsnowlabs.ml.util.{ONNX, Openvino, TensorFlow}
import com.johnsnowlabs.nlp.Annotation
import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
import com.johnsnowlabs.nlp.annotators.common.SentenceSplit
import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.{BpeTokenizer, QwenTokenizer}
import org.intel.openvino.InferRequest
import org.tensorflow.{Session, Tensor}

import scala.collection.JavaConverters._

private[johnsnowlabs] class Qwen(
val onnxWrappers: DecoderWrappers,
val onnxWrappers: Option[DecoderWrappers],
val openvinoWrapper: Option[OpenvinoWrapper],
merges: Map[(String, String), Int],
vocabulary: Map[String, Int],
generationConfig: GenerationConfig)
extends Serializable
with Generate {

private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions
val detectedEngine: String =
if (onnxWrappers.isDefined) ONNX.name
else if (openvinoWrapper.isDefined) Openvino.name
else ONNX.name
private var nextPositionId: Option[Array[Long]] = None
val bpeTokenizer: QwenTokenizer = BpeTokenizer
.forModel("qwen", merges = merges, vocab = vocabulary, padWithSequenceTokens = false)
.asInstanceOf[QwenTokenizer]
Expand Down Expand Up @@ -93,8 +102,8 @@ private[johnsnowlabs] class Qwen(
randomSeed: Option[Long],
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Array[Array[Int]] = {
val (encoderSession, env) = onnxWrappers.decoder.getSession(onnxSessionOptions)
maxInputLength: Int,
stopTokenIds: Array[Int]): Array[Array[Int]] = {
val ignoreTokenIdsInt = ignoreTokenIds
val expandedDecoderInputsVals = batch
val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray
Expand All @@ -121,10 +130,23 @@ private[johnsnowlabs] class Qwen(
// (encoderSession, env),
// maxOutputLength)

// dummy tensors for decoder encode state and attention mask
val decoderEncoderStateTensors = Right(OnnxTensor.createTensor(env, Array(0)))
val encoderAttentionMaskTensors = Right(OnnxTensor.createTensor(env, Array(1)))

val (decoderEncoderStateTensors, encoderAttentionMaskTensors, session) =
detectedEngine match {
case ONNX.name =>
// dummy tensors for decoder encode state and attention mask
val (encoderSession, env) = onnxWrappers.get.decoder.getSession(onnxSessionOptions)
(
Right(OnnxTensor.createTensor(env, Array(0))),
Right(OnnxTensor.createTensor(env, Array(1))),
Right((env, encoderSession)))
case Openvino.name =>
// not needed
(null, null, null)
}
val ovInferRequest: Option[InferRequest] = detectedEngine match {
case ONNX.name => None
case Openvino.name => Some(openvinoWrapper.get.getCompiledModel().create_infer_request())
}
// output with beam search
val modelOutputs = generate(
batch,
Expand All @@ -146,8 +168,10 @@ private[johnsnowlabs] class Qwen(
this.paddingTokenId,
randomSeed,
ignoreTokenIdsInt,
Right((env, encoderSession)),
applySoftmax = false)
session,
applySoftmax = false,
ovInferRequest = ovInferRequest,
stopTokenIds = stopTokenIds)

// decoderOutputs
modelOutputs
Expand All @@ -167,7 +191,8 @@ private[johnsnowlabs] class Qwen(
randomSeed: Option[Long] = None,
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Seq[Annotation] = {
maxInputLength: Int,
stopTokenIds: Array[Int]): Seq[Annotation] = {

val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch =>
val batchSP = encode(batch)
Expand All @@ -184,7 +209,8 @@ private[johnsnowlabs] class Qwen(
randomSeed,
ignoreTokenIds,
beamSize,
maxInputLength)
maxInputLength,
stopTokenIds)

decode(spIds)

Expand Down Expand Up @@ -239,20 +265,76 @@ private[johnsnowlabs] class Qwen(
decoderEncoderStateTensors: Either[Tensor, OnnxTensor],
encoderAttentionMaskTensors: Either[Tensor, OnnxTensor],
maxLength: Int,
session: Either[Session, (OrtEnvironment, OrtSession)]): Array[Array[Float]] = {
session: Either[Session, (OrtEnvironment, OrtSession)],
ovInferRequest: Option[InferRequest]): Array[Array[Float]] = {

session.fold(
tfSession => {
detectedEngine match {
case TensorFlow.name =>
// not implemented yet
Array()
},
onnxSession => {
val (env, decoderSession) = onnxSession
case ONNX.name =>
val (env, decoderSession) = session.right.get
val decoderOutputs =
getDecoderOutputs(decoderInputIds.toArray, onnxSession = (decoderSession, env))
decoderOutputs
})
case Openvino.name =>
val decoderOutputs =
getDecoderOutputsOv(decoderInputIds.toArray, ovInferRequest.get)
decoderOutputs
}
}

private def getDecoderOutputsOv(
inputIds: Array[Array[Int]],
inferRequest: InferRequest): (Array[Array[Float]]) = {
val (inputIdsLong, inputPositionIDsLong): (Array[Long], Array[Long]) =
if (nextPositionId.isDefined) {
val inpIdsLong = inputIds.map { tokenIds => tokenIds.last.toLong }
(inpIdsLong, nextPositionId.get)
} else {
val inpIdsLong = inputIds.flatMap { tokenIds => tokenIds.map(_.toLong) }
val posIdsLong = inputIds.flatMap { tokenIds =>
tokenIds.zipWithIndex.map { case (_, i) =>
i.toLong
}
}
(inpIdsLong, posIdsLong)
}
val attentionMask: Array[Long] =
inputIds.flatMap { tokenIds => tokenIds.map(_ => 1L) }

val batchSize: Int = inputIds.length
val beamIdx: Array[Int] = new Array[Int](batchSize)
val shape: Array[Int] = Array(batchSize, inputIdsLong.length / batchSize)

val inputIdsLongTensor: org.intel.openvino.Tensor =
new org.intel.openvino.Tensor(shape, inputIdsLong)
val decoderAttentionMask: org.intel.openvino.Tensor =
new org.intel.openvino.Tensor(Array(batchSize, inputIds.head.length), attentionMask)
val decoderPositionIDs: org.intel.openvino.Tensor =
new org.intel.openvino.Tensor(shape, inputPositionIDsLong)
val beamIdxTensor: org.intel.openvino.Tensor =
new org.intel.openvino.Tensor(Array(batchSize), beamIdx)

inferRequest.set_tensor(OpenVinoSignatures.decoderInputIDs, inputIdsLongTensor)
inferRequest.set_tensor(OpenVinoSignatures.decoderAttentionMask, decoderAttentionMask)
inferRequest.set_tensor(OpenVinoSignatures.decoderPositionIDs, decoderPositionIDs)
inferRequest.set_tensor(OpenVinoSignatures.decoderBeamIdx, beamIdxTensor)

inferRequest.infer()

val result = inferRequest.get_tensor(OpenVinoSignatures.decoderOutput)
val logitsRaw = result.data()
nextPositionId = Some(inputIds.map(tokenIds => tokenIds.length.toLong))

val sequenceLength = inputIdsLong.length / batchSize
val decoderOutputs = (0 until batchSize).map(i => {
logitsRaw
.slice(
i * sequenceLength * vocabSize + (sequenceLength - 1) * vocabSize,
i * sequenceLength * vocabSize + sequenceLength * vocabSize)
})
decoderOutputs.toArray
}
private def getDecoderOutputs(
inputIds: Array[Array[Int]],
Expand Down Expand Up @@ -285,12 +367,12 @@ private[johnsnowlabs] class Qwen(
val sequenceLength = inputIds.head.length
val batchSize = inputIds.length

// val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput)
// inputIdsLongTensor.close()
// decoderPositionIDs.close()
// decoderAttentionMask.close()
// val batchLogits = logits.grouped(vocabSize).toArray
// batchLogits
// val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput)
// inputIdsLongTensor.close()
// decoderPositionIDs.close()
// decoderAttentionMask.close()
// val batchLogits = logits.grouped(vocabSize).toArray
// batchLogits

val logitsRaw = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput)
val decoderOutputs = (0 until batchSize).map(i => {
Expand Down Expand Up @@ -358,4 +440,19 @@ private[johnsnowlabs] class Qwen(
(0 until 32).flatMap(i => Seq(s"present.$i.key", s"present.$i.value")).toArray
}

private object OpenVinoSignatures {
val encoderInputIDs: String = "input_ids"
val encoderAttentionMask: String = "attention_mask"

val encoderOutput: String = "last_hidden_state"

val decoderInputIDs: String = "input_ids"
val decoderEncoderAttentionMask: String = "encoder_attention_mask"
val decoderAttentionMask: String = "attention_mask"
val decoderPositionIDs: String = "position_ids"
val decoderBeamIdx: String = "beam_idx"
val decoderEncoderState: String = "encoder_hidden_states"

val decoderOutput: String = "logits"
}
}
Loading

0 comments on commit 4c83df7

Please sign in to comment.