Skip to content

Commit

Permalink
Added custom stop token id support (#14344)
Browse files Browse the repository at this point in the history
  • Loading branch information
prabod authored Jul 14, 2024
1 parent e120e61 commit 257fd57
Show file tree
Hide file tree
Showing 14 changed files with 212 additions and 47 deletions.
16 changes: 10 additions & 6 deletions src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ private[johnsnowlabs] class LLAMA2(
*/
/** Encodes a batch of annotations into SentencePiece token ids.
  *
  * Prepends the BOS token id and a leading marker character so the model sees
  * the prompt in the format it expects.
  *
  * NOTE(review): the prepended "_" is the ASCII underscore; SentencePiece's
  * word-boundary meta symbol is "▁" (U+2581) — confirm the underscore is
  * intentional for this tokenizer.
  *
  * @param sentences input annotations; the `result` field holds the raw text
  * @return one array of token ids per input annotation
  */
def encode(sentences: Seq[Annotation]): Seq[Array[Int]] = {
  sentences.map(s => {
    val sentWithTask = "_" + s.result
    Array(bosTokenId) ++ spp.getSppModel.encodeAsIds(sentWithTask)
  })
}

Expand All @@ -97,7 +97,8 @@ private[johnsnowlabs] class LLAMA2(
randomSeed: Option[Long],
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Array[Array[Int]] = {
maxInputLength: Int,
stopTokenIds: Array[Int]): Array[Array[Int]] = {
val ignoreTokenIdsInt = ignoreTokenIds
val expandedDecoderInputsVals = batch
val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray
Expand Down Expand Up @@ -165,7 +166,8 @@ private[johnsnowlabs] class LLAMA2(
ignoreTokenIdsInt,
session,
applySoftmax = true,
ovInferRequest = ovInferRequest)
ovInferRequest = ovInferRequest,
stopTokenIds = stopTokenIds)

modelOutputs
}
Expand All @@ -184,7 +186,8 @@ private[johnsnowlabs] class LLAMA2(
randomSeed: Option[Long] = None,
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Seq[Annotation] = {
maxInputLength: Int,
stopTokenIds: Array[Int]): Seq[Annotation] = {

val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch =>
val batchSP = encode(batch)
Expand All @@ -201,7 +204,8 @@ private[johnsnowlabs] class LLAMA2(
randomSeed,
ignoreTokenIds,
beamSize,
maxInputLength)
maxInputLength,
stopTokenIds)

decode(spIds)

Expand Down
18 changes: 11 additions & 7 deletions src/main/scala/com/johnsnowlabs/ml/ai/Mistral.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ private[johnsnowlabs] class Mistral(
*/
/** Encodes a batch of annotations into SentencePiece token ids.
  *
  * Prepends the BOS token id and a leading marker character so the model sees
  * the prompt in the format it expects.
  *
  * NOTE(review): the prepended "_" is the ASCII underscore; SentencePiece's
  * word-boundary meta symbol is "▁" (U+2581) — confirm the underscore is
  * intentional for this tokenizer.
  *
  * @param sentences input annotations; the `result` field holds the raw text
  * @return one array of token ids per input annotation
  */
def encode(sentences: Seq[Annotation]): Seq[Array[Int]] = {
  sentences.map(s => {
    val sentWithTask = "_" + s.result
    Array(bosTokenId) ++ spp.getSppModel.encodeAsIds(sentWithTask)
  })
}

Expand All @@ -96,7 +96,8 @@ private[johnsnowlabs] class Mistral(
randomSeed: Option[Long],
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Array[Array[Int]] = {
maxInputLength: Int,
stopTokenIds: Array[Int] = Array()): Array[Array[Int]] = {
val ignoreTokenIdsInt = ignoreTokenIds
val expandedDecoderInputsVals = batch
val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray
Expand Down Expand Up @@ -162,8 +163,9 @@ private[johnsnowlabs] class Mistral(
randomSeed,
ignoreTokenIdsInt,
session,
applySoftmax = false,
ovInferRequest = ovInferRequest)
applySoftmax = true,
ovInferRequest = ovInferRequest,
stopTokenIds = stopTokenIds)

// decoderOutputs
modelOutputs
Expand All @@ -183,7 +185,8 @@ private[johnsnowlabs] class Mistral(
randomSeed: Option[Long] = None,
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Seq[Annotation] = {
maxInputLength: Int,
stopTokenIds: Array[Int]): Seq[Annotation] = {

val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch =>
val batchSP = encode(batch)
Expand All @@ -200,7 +203,8 @@ private[johnsnowlabs] class Mistral(
randomSeed,
ignoreTokenIds,
beamSize,
maxInputLength)
maxInputLength,
stopTokenIds)

decode(spIds)

Expand Down
12 changes: 8 additions & 4 deletions src/main/scala/com/johnsnowlabs/ml/ai/Phi2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ private[johnsnowlabs] class Phi2(
randomSeed: Option[Long],
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Array[Array[Int]] = {
maxInputLength: Int,
stopTokenIds: Array[Int]): Array[Array[Int]] = {
val ignoreTokenIdsInt = ignoreTokenIds
val expandedDecoderInputsVals = batch
val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray
Expand Down Expand Up @@ -169,7 +170,8 @@ private[johnsnowlabs] class Phi2(
ignoreTokenIdsInt,
session,
applySoftmax = false,
ovInferRequest = ovInferRequest)
ovInferRequest = ovInferRequest,
stopTokenIds = stopTokenIds)

// decoderOutputs
modelOutputs
Expand All @@ -189,7 +191,8 @@ private[johnsnowlabs] class Phi2(
randomSeed: Option[Long] = None,
ignoreTokenIds: Array[Int] = Array(),
beamSize: Int,
maxInputLength: Int): Seq[Annotation] = {
maxInputLength: Int,
stopTokenIds: Array[Int]): Seq[Annotation] = {

val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch =>
val batchSP = encode(batch)
Expand All @@ -206,7 +209,8 @@ private[johnsnowlabs] class Phi2(
randomSeed,
ignoreTokenIds,
beamSize,
maxInputLength)
maxInputLength,
stopTokenIds)

decode(spIds)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ trait Generate {
ignoreTokenIds: Array[Int] = Array(),
session: Either[Session, (OrtEnvironment, OrtSession)],
applySoftmax: Boolean = true,
ovInferRequest: Option[InferRequest] = None): Array[Array[Int]] = {
ovInferRequest: Option[InferRequest] = None,
stopTokenIds: Array[Int] = Array()): Array[Array[Int]] = {

// TODO: Add support for ignoreTokenIds

Expand All @@ -117,8 +118,8 @@ trait Generate {
noRepeatNgramSize = noRepeatNgramSize,
vocabSize = vocabSize))

logitProcessorList.addProcess(
new MinLengthLogitProcessor(eosTokenId, minOutputLength, vocabSize))
// logitProcessorList.addProcess(
// new MinLengthLogitProcessor(eosTokenId, minOutputLength, vocabSize))

logitProcessorList.addProcess(new TemperatureLogitWarper(temperature))

Expand Down Expand Up @@ -148,7 +149,8 @@ trait Generate {
randomSeed,
session,
applySoftmax,
ovInferRequest)
ovInferRequest,
stopTokenIds)
}

/** Beam Search for text generation
Expand Down Expand Up @@ -193,7 +195,8 @@ trait Generate {
randomSeed: Option[Long],
session: Either[Session, (OrtEnvironment, OrtSession)],
applySoftmax: Boolean,
ovInferRequest: Option[InferRequest] = None): Array[Array[Int]] = {
ovInferRequest: Option[InferRequest] = None,
stopTokenIds: Array[Int] = Array()): Array[Array[Int]] = {
val inputIds = inputIdsVal
val batchSize = beamScorer.getBeamHypothesesSeq.length
val numBeams = beamScorer.getNumBeams
Expand Down Expand Up @@ -227,21 +230,22 @@ trait Generate {
// Optionally Apply log softmax to model outputs
var nextTokenScores =
if (applySoftmax) nextTokenLogits.map(logSoftmax) else nextTokenLogits

// Process the logits by defined logit processors
val nextTokenScoresProcessed =
logitProcessor.process(expandedInputs, nextTokenScores, currentLength)

// Process the logits by defined logit warpers
if (doSample) {
nextTokenScores =
logitProcessor.warp(expandedInputs, nextTokenScoresProcessed, currentLength)
}
// Add previous beam scores to the output
nextTokenScores = nextTokenScoresProcessed.zipWithIndex.map { case (x, ind1) =>
nextTokenScores = nextTokenScores.zipWithIndex.map { case (x, ind1) =>
x.zipWithIndex.map { case (y, _) =>
y + beamScores(ind1)
}
}
// Process the logits by defined logit warpers
if (doSample) {
nextTokenScores = logitProcessor.warp(expandedInputs, nextTokenScores, currentLength)
}

// Reshape next token score to (batchSize, vocabSize * numBeams)
val vocabSize = nextTokenScores.head.length
val reshapedNextTokenScores =
Expand Down Expand Up @@ -290,7 +294,8 @@ trait Generate {
padTokenId,
eosTokenId,
beamIndices,
currentLength)
currentLength,
stopTokenIds)
val newBeamScores = beamOutputs._1.flatMap(_.toList)
val beamNextTokens = beamOutputs._2.flatMap(_.toList)
val beamIdx = beamOutputs._3.flatMap(_.toList)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ class TopKLogitWarper(
}

/** Returns the indices of the `k` largest logits, ignoring entries already
  * masked to `filterValue` (e.g. Float.NegativeInfinity set by an earlier
  * warper).
  *
  * @param logits logit scores for one beam/sequence
  * @param k      number of candidate indices to keep
  * @return indices of the k largest unmasked logits, best first
  */
private def getTopKIndices(logits: Array[Float], k: Int): Array[Int] = {
  logits.zipWithIndex
    .filter(_._1 != filterValue) // skip entries masked out by earlier processors
    .sortBy(-_._1)               // descending by score
    .take(k)
    .map(_._2)
}

private def maskNotTopKValues(logits: Array[Float], topKIndices: Array[Int]): Array[Float] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,40 @@ class TopPLogitWarper(val p: Double, val minTokensToKeep: Int = 1) extends Logit
val logitsUpd = scores.map(_.clone()) // Deep copy of the scores

if (p < 1.0) {
val scoresFiltered = scores.map(_.filterNot(_.isInfinite)) // Filter out infinite values
val scoresShape = Array(scoresFiltered.length, scoresFiltered.head.length)
val topPThreshold = math.ceil(p * scoresShape.last).toInt // Determine top-p threshold
val scoresFiltered = scores // Filter out infinite values
val scoresSoftmaxed = scoresFiltered.map(softmax) // Softmax the scores

for ((logits, i) <- scores.zipWithIndex) {
val topPIndices = getTopPIndices(logits, topPThreshold)
val maskedValues = maskNotTopPValues(logits, topPIndices)
for ((logits, i) <- scoresSoftmaxed.zipWithIndex) {
val topPIndices = getTopPIndices(logits, p)
// Mask the values that are not in the top-p
val maskedValues = maskNotTopPValues(logitsUpd(i), topPIndices)
logitsUpd(i) = maskedValues
}
}

logitsUpd
}

private def getTopPIndices(logits: Array[Float], k: Int): Array[Int] = {
logits.zipWithIndex.sortBy(-_._1).take(k).map(_._2)
/** Returns the indices forming the nucleus (top-p) candidate set.
  *
  * Expects `logits` to already be softmaxed probabilities (see the caller).
  * Keeps the smallest prefix of tokens, in descending probability order,
  * whose cumulative probability reaches `p`, but never fewer than
  * `minTokensToKeep` candidates. If the retained mass never reaches `p`
  * (possible because zero-probability entries are dropped), every remaining
  * candidate is kept.
  *
  * @param logits softmaxed probabilities for one beam/sequence
  * @param p      nucleus probability threshold in (0, 1)
  * @return indices of the retained tokens, most probable first
  */
private def getTopPIndices(logits: Array[Float], p: Double): Array[Int] = {
  // Sort (probability, index) pairs by descending probability and drop
  // entries with no mass (e.g. tokens masked to -inf before the softmax).
  val sorted = logits.zipWithIndex.sortBy(-_._1).filter(_._1 > 0.0f)

  // cumProbs(i) = total probability of the i most likely tokens.
  val cumProbs = sorted.map(_._1.toDouble).scanLeft(0.0)(_ + _)

  // Smallest prefix whose cumulative probability reaches p; -1 means the
  // retained mass never reaches p, in which case keep every candidate.
  val cutoff = cumProbs.indexWhere(_ >= p) match {
    case -1  => sorted.length
    case idx => idx
  }

  // Guarantee at least minTokensToKeep candidates survive the filter.
  sorted.take(math.max(cutoff, minTokensToKeep)).map(_._2)
}

private def maskNotTopPValues(logits: Array[Float], topPIndices: Array[Int]): Array[Float] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ abstract class BeamScorer() {
padTokenId: Int,
eosTokenId: Int,
beamIndices: Seq[Array[Int]],
currentLength: Int): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]])
currentLength: Int,
stopTokenIds: Array[Int]): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]])

def finalize(
inputIds: Seq[Array[Int]],
Expand All @@ -40,4 +41,5 @@ abstract class BeamScorer() {
def getBeamHypothesesSeq: Seq[BeamHypotheses]
def getNumBeams: Int
def isDone: Boolean
def getDone: Array[Boolean]
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class BeamSearchScorer(
override def getNumBeams: Int = numBeams
private val done: Array[Boolean] = Array.fill(batchSize)(false)

/** Exposes the internal per-batch completion flags (`done(i)` is set per batch element). */
override def getDone: Array[Boolean] = done

override def process(
inputIds: Seq[Array[Int]],
nextScores: Seq[Array[Float]],
Expand All @@ -51,7 +53,8 @@ class BeamSearchScorer(
padTokenId: Int,
eosTokenId: Int,
beamIndices: Seq[Array[Int]],
currentLength: Int): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]]) = {
currentLength: Int,
stopTokenIds: Array[Int]): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]]) = {
// val currentLength = inputIds.length
val batchSize = this.beamHypothesesSeq.length
val nextBeamScores = Array.ofDim[Float](batchSize, this.beamSize)
Expand All @@ -75,7 +78,8 @@ class BeamSearchScorer(
val nextIndex = nextIndices(batchIdx)(beamTokenRank)
val batchBeamIdx = batchIdx * this.beamSize + nextIndex

if (eosTokenId == nextToken) {
// either eos token or stop tokens are found
if (eosTokenId == nextToken || stopTokenIds.contains(nextToken)) {
if (beamTokenRank >= this.beamSize) {
break
}
Expand Down
15 changes: 15 additions & 0 deletions src/main/scala/com/johnsnowlabs/nlp/HasGeneratorProperties.scala
Original file line number Diff line number Diff line change
Expand Up @@ -222,4 +222,19 @@ trait HasGeneratorProperties {

/** @group getParam */
def getNReturnSequences: Int = $(nReturnSequences)

/** Stop tokens to terminate the generation
  *
  * Token ids that end decoding when produced, in addition to the EOS token
  * (checked alongside `eosTokenId` in the beam scorer).
  *
  * NOTE(review): the Param name string "stopTokens" does not match the field
  * name `stopTokenIds`; Spark ML resolves params by name — confirm this
  * mismatch is intentional.
  *
  * @group param
  */
var stopTokenIds =
  new IntArrayParam(this, "stopTokens", "Stop tokens to terminate the generation")

/** Sets the token ids that terminate generation when produced.
  *
  * @param value ids to treat as stop tokens
  * @group setParam
  */
def setStopTokenIds(value: Array[Int]): this.type = set(stopTokenIds, value)

/** Returns the token ids that terminate generation when produced.
  *
  * @group getParam
  */
def getStopTokenIds: Array[Int] = $(stopTokenIds)
}
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ class LLAMA2Transformer(override val uid: String)
ignoreTokenIds -> Array(),
batchSize -> 1,
beamSize -> 1,
maxInputLength -> 4096)
maxInputLength -> 4096,
stopTokenIds -> Array())

/** takes a document and annotations and produces new annotations of this annotator's annotation
* type
Expand Down Expand Up @@ -269,7 +270,8 @@ class LLAMA2Transformer(override val uid: String)
randomSeed = this.randomSeed,
ignoreTokenIds = $(ignoreTokenIds),
beamSize = $(beamSize),
maxInputLength = $(maxInputLength))
maxInputLength = $(maxInputLength),
stopTokenIds = $(stopTokenIds))
} else {
Seq()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ class MistralTransformer(override val uid: String)
ignoreTokenIds -> Array(),
batchSize -> 1,
beamSize -> 1,
maxInputLength -> 4096)
maxInputLength -> 4096,
stopTokenIds -> Array())

/** takes a document and annotations and produces new annotations of this annotator's annotation
* type
Expand Down Expand Up @@ -277,7 +278,8 @@ class MistralTransformer(override val uid: String)
randomSeed = this.randomSeed,
ignoreTokenIds = $(ignoreTokenIds),
beamSize = $(beamSize),
maxInputLength = $(maxInputLength))
maxInputLength = $(maxInputLength),
stopTokenIds = $(stopTokenIds))
} else {
Seq()
}
Expand Down
Loading

0 comments on commit 257fd57

Please sign in to comment.