/** Converts the text of every annotation into SentencePiece token ids.
  *
  * Each text is prefixed with an underscore character before it is handed to the
  * SentencePiece model, and `bosTokenId` is placed in front of the returned ids.
  *
  * @param sentences
  *   annotations whose `result` holds the text to encode
  * @return
  *   one id array per annotation, in input order
  */
def encode(sentences: Seq[Annotation]): Seq[Array[Int]] =
  sentences.map { annotation =>
    val prefixedText = "_" + annotation.result
    val bosPrefix = Array(bosTokenId)
    bosPrefix ++ spp.getSppModel.encodeAsIds(prefixedText)
  }
/** Encodes the text carried by each annotation into SentencePiece token ids.
  *
  * The text is prefixed with an underscore character prior to encoding, and the encoded
  * ids are preceded by `bosTokenId`.
  *
  * @param sentences
  *   annotations whose `result` holds the text to encode
  * @return
  *   one id array per annotation, in input order
  */
def encode(sentences: Seq[Annotation]): Seq[Array[Int]] = {
  def encodeOne(text: String): Array[Int] =
    Array(bosTokenId) ++ spp.getSppModel.encodeAsIds("_" + text)

  sentences.map(sentence => encodeOne(sentence.result))
}
maxInputLength: Int): Array[Array[Int]] = { + maxInputLength: Int, + stopTokenIds: Array[Int]): Array[Array[Int]] = { val ignoreTokenIdsInt = ignoreTokenIds val expandedDecoderInputsVals = batch val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray @@ -169,7 +170,8 @@ private[johnsnowlabs] class Phi2( ignoreTokenIdsInt, session, applySoftmax = false, - ovInferRequest = ovInferRequest) + ovInferRequest = ovInferRequest, + stopTokenIds = stopTokenIds) // decoderOutputs modelOutputs @@ -189,7 +191,8 @@ private[johnsnowlabs] class Phi2( randomSeed: Option[Long] = None, ignoreTokenIds: Array[Int] = Array(), beamSize: Int, - maxInputLength: Int): Seq[Annotation] = { + maxInputLength: Int, + stopTokenIds: Array[Int]): Seq[Annotation] = { val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch => val batchSP = encode(batch) @@ -206,7 +209,8 @@ private[johnsnowlabs] class Phi2( randomSeed, ignoreTokenIds, beamSize, - maxInputLength) + maxInputLength, + stopTokenIds) decode(spIds) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala b/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala index 4e4140f7735ab2..24d2ac1d3f6696 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala @@ -104,7 +104,8 @@ trait Generate { ignoreTokenIds: Array[Int] = Array(), session: Either[Session, (OrtEnvironment, OrtSession)], applySoftmax: Boolean = true, - ovInferRequest: Option[InferRequest] = None): Array[Array[Int]] = { + ovInferRequest: Option[InferRequest] = None, + stopTokenIds: Array[Int] = Array()): Array[Array[Int]] = { // TODO: Add support for ignoreTokenIds @@ -117,8 +118,8 @@ trait Generate { noRepeatNgramSize = noRepeatNgramSize, vocabSize = vocabSize)) - logitProcessorList.addProcess( - new MinLengthLogitProcessor(eosTokenId, minOutputLength, vocabSize)) +// logitProcessorList.addProcess( +// new 
MinLengthLogitProcessor(eosTokenId, minOutputLength, vocabSize)) logitProcessorList.addProcess(new TemperatureLogitWarper(temperature)) @@ -148,7 +149,8 @@ trait Generate { randomSeed, session, applySoftmax, - ovInferRequest) + ovInferRequest, + stopTokenIds) } /** Beam Search for text generation @@ -193,7 +195,8 @@ trait Generate { randomSeed: Option[Long], session: Either[Session, (OrtEnvironment, OrtSession)], applySoftmax: Boolean, - ovInferRequest: Option[InferRequest] = None): Array[Array[Int]] = { + ovInferRequest: Option[InferRequest] = None, + stopTokenIds: Array[Int] = Array()): Array[Array[Int]] = { val inputIds = inputIdsVal val batchSize = beamScorer.getBeamHypothesesSeq.length val numBeams = beamScorer.getNumBeams @@ -227,21 +230,22 @@ trait Generate { // Optionally Apply log softmax to model outputs var nextTokenScores = if (applySoftmax) nextTokenLogits.map(logSoftmax) else nextTokenLogits - // Process the logits by defined logit processors val nextTokenScoresProcessed = logitProcessor.process(expandedInputs, nextTokenScores, currentLength) + // Process the logits by defined logit warpers + if (doSample) { + nextTokenScores = + logitProcessor.warp(expandedInputs, nextTokenScoresProcessed, currentLength) + } // Add previous beam scores to the output - nextTokenScores = nextTokenScoresProcessed.zipWithIndex.map { case (x, ind1) => + nextTokenScores = nextTokenScores.zipWithIndex.map { case (x, ind1) => x.zipWithIndex.map { case (y, _) => y + beamScores(ind1) } } - // Process the logits by defined logit warpers - if (doSample) { - nextTokenScores = logitProcessor.warp(expandedInputs, nextTokenScores, currentLength) - } + // Reshape next token score to (batchSize, vocabSize * numBeams) val vocabSize = nextTokenScores.head.length val reshapedNextTokenScores = @@ -290,7 +294,8 @@ trait Generate { padTokenId, eosTokenId, beamIndices, - currentLength) + currentLength, + stopTokenIds) val newBeamScores = beamOutputs._1.flatMap(_.toList) val beamNextTokens = 
/** Returns the positions of the `k` highest-scoring logits, best first.
  *
  * Positions whose score equals `filterValue` are never returned, so fewer than `k`
  * indices may come back.
  * NOTE(review): the original comment says this skips negative infinity, i.e. `filterValue`
  * is presumably `Float.NegativeInfinity` — confirm against the class definition.
  *
  * @param logits
  *   scores for a single row
  * @param k
  *   maximum number of indices to return
  * @return
  *   indices into `logits`, ordered by descending score
  */
private def getTopKIndices(logits: Array[Float], k: Int): Array[Int] =
  logits.indices
    .filter(position => logits(position) != filterValue)
    .sortBy(position => -logits(position))
    .take(k)
    .toArray
/** Selects the indices that form the top-p (nucleus) set of one row.
  *
  * Despite the parameter name, the caller in this class passes softmax output, so the
  * values are treated as probabilities: they are ranked in descending order and kept until
  * their running sum reaches `p`. At least `minTokensToKeep` candidates are kept whenever
  * that many are available.
  *
  * @param logits
  *   per-token probabilities for a single row
  * @param p
  *   cumulative probability mass to retain
  * @return
  *   indices of the retained tokens, ordered by descending probability
  */
private def getTopPIndices(logits: Array[Float], p: Double): Array[Int] = {
  // Only strictly positive entries are candidates. NOTE(review): the original comment says
  // this removes the negative-infinity logits, which presumably come out of softmax as 0.
  val ranked = logits.zipWithIndex.filter(_._1 > 0.0).sortBy(-_._1)

  // cumulativeMass(i) is the mass of the first i ranked tokens (scanLeft emits the seed
  // first), so the first position reaching p is exactly the number of tokens to keep.
  val cumulativeMass = ranked.scanLeft(0.0)(_ + _._1)
  val reachedAt = cumulativeMass.indexWhere(_ >= p)

  val keepCount =
    if (reachedAt < 0) ranked.length // mass never reaches p (e.g. rounding): keep all
    else math.max(reachedAt, minTokensToKeep) // honour the minimum instead of ceil(p * vocab)

  ranked.take(keepCount).map(_._2)
}
eosTokenId: Int, beamIndices: Seq[Array[Int]], - currentLength: Int): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]]) + currentLength: Int, + stopTokenIds: Array[Int]): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]]) def finalize( inputIds: Seq[Array[Int]], @@ -40,4 +41,5 @@ abstract class BeamScorer() { def getBeamHypothesesSeq: Seq[BeamHypotheses] def getNumBeams: Int def isDone: Boolean + def getDone: Array[Boolean] } diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Search/BeamSearchScorer.scala b/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Search/BeamSearchScorer.scala index fbc4cb466215bc..577da0571698ab 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Search/BeamSearchScorer.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Search/BeamSearchScorer.scala @@ -43,6 +43,8 @@ class BeamSearchScorer( override def getNumBeams: Int = numBeams private val done: Array[Boolean] = Array.fill(batchSize)(false) + override def getDone: Array[Boolean] = done + override def process( inputIds: Seq[Array[Int]], nextScores: Seq[Array[Float]], @@ -51,7 +53,8 @@ class BeamSearchScorer( padTokenId: Int, eosTokenId: Int, beamIndices: Seq[Array[Int]], - currentLength: Int): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]]) = { + currentLength: Int, + stopTokenIds: Array[Int]): (Array[Array[Float]], Array[Array[Int]], Array[Array[Int]]) = { // val currentLength = inputIds.length val batchSize = this.beamHypothesesSeq.length val nextBeamScores = Array.ofDim[Float](batchSize, this.beamSize) @@ -75,7 +78,8 @@ class BeamSearchScorer( val nextIndex = nextIndices(batchIdx)(beamTokenRank) val batchBeamIdx = batchIdx * this.beamSize + nextIndex - if (eosTokenId == nextToken) { + // either eos token or stop tokens are found + if (eosTokenId == nextToken || stopTokenIds.contains(nextToken)) { if (beamTokenRank >= this.beamSize) { break } diff --git 
/** Token ids that terminate the generation when produced.
  *
  * The param is registered under the same name as this field so that name-based lookups
  * agree with `setStopTokenIds` / `getStopTokenIds`.
  * NOTE(review): declared as `var` as in the original; consider `val` if nothing reassigns it.
  *
  * @group param
  */
var stopTokenIds: IntArrayParam =
  new IntArrayParam(this, "stopTokenIds", "Stop tokens to terminate the generation")

/** Sets the token ids that stop the generation.
  *
  * @group setParam
  */
def setStopTokenIds(value: Array[Int]): this.type = set(stopTokenIds, value)

/** Returns the token ids that stop the generation.
  *
  * @group getParam
  */
def getStopTokenIds: Array[Int] = $(stopTokenIds)
43ab7a9f6264dd..ba2cf5af900030 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MistralTransformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MistralTransformer.scala @@ -243,7 +243,8 @@ class MistralTransformer(override val uid: String) ignoreTokenIds -> Array(), batchSize -> 1, beamSize -> 1, - maxInputLength -> 4096) + maxInputLength -> 4096, + stopTokenIds -> Array()) /** takes a document and annotations and produces new annotations of this annotator's annotation * type @@ -277,7 +278,8 @@ class MistralTransformer(override val uid: String) randomSeed = this.randomSeed, ignoreTokenIds = $(ignoreTokenIds), beamSize = $(beamSize), - maxInputLength = $(maxInputLength)) + maxInputLength = $(maxInputLength), + stopTokenIds = $(stopTokenIds)) } else { Seq() } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi2Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi2Transformer.scala index fbb16fa7e13ea2..ecb8dbb88f768c 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi2Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/Phi2Transformer.scala @@ -266,7 +266,8 @@ class Phi2Transformer(override val uid: String) ignoreTokenIds -> Array(), batchSize -> 1, beamSize -> 1, - maxInputLength -> 4096) + maxInputLength -> 4096, + stopTokenIds -> Array()) /** takes a document and annotations and produces new annotations of this annotator's annotation * type @@ -300,7 +301,8 @@ class Phi2Transformer(override val uid: String) randomSeed = this.randomSeed, ignoreTokenIds = $(ignoreTokenIds), beamSize = $(beamSize), - maxInputLength = $(maxInputLength)) + maxInputLength = $(maxInputLength), + stopTokenIds = $(stopTokenIds)) } else { Seq() } diff --git a/src/test/scala/com/johnsnowlabs/ml/ai/util/Generation/Logit/LogitProcess/LogitProcessorTest.scala b/src/test/scala/com/johnsnowlabs/ml/ai/util/Generation/Logit/LogitProcess/LogitProcessorTest.scala index 
package com.johnsnowlabs.ml.ai.util.Generation.Logit.LogitWarper

import com.johnsnowlabs.tags.FastTest
import org.scalatest.flatspec.AnyFlatSpec

/** Unit tests for the logit warpers used during sampling. */
class LogitWarperTest extends AnyFlatSpec {

  "TopKLogitWarper" should "process correctly" taggedAs FastTest in {
    val topK = 5

    val logitWarper = new TopKLogitWarper(k = topK, minTokensToKeep = 1)
    val scoresBatches: Array[Array[Float]] =
      Array(Array(0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f))

    val processedScores = logitWarper.call(Seq.empty, scoresBatches, 1).head

    // The five lowest scores are masked out ...
    (0 to 4).foreach { i =>
      assert(processedScores(i) == Float.NegativeInfinity)
    }
    // ... and the five highest scores are passed through unchanged.
    assert(processedScores(5) == 0.6f)
    assert(processedScores(6) == 0.7f)
    assert(processedScores(7) == 0.8f)
    assert(processedScores(8) == 0.9f)
    assert(processedScores(9) == 1.0f)
  }

  "TemperatureLogitWarper" should "process correctly" taggedAs FastTest in {
    val temperature = 0.5f

    val logitWarper = new TemperatureLogitWarper(temperature = temperature)
    val scoresBatches: Array[Array[Float]] =
      Array(Array(0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f))

    val processedScores = logitWarper.call(Seq.empty, scoresBatches, 1).head

    // Every score must be divided by the temperature.
    processedScores.zipWithIndex.foreach { case (score, i) =>
      assert(score == scoresBatches(0)(i) / temperature)
    }
  }

  "TopPLogitWarper" should "process correctly" taggedAs FastTest in {
    val topP = 0.5f

    val logitWarper = new TopPLogitWarper(p = topP, minTokensToKeep = 1)
    val scoresBatches: Array[Array[Float]] =
      Array(Array(0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, Float.NegativeInfinity))

    val processedScores = logitWarper.call(Seq.empty, scoresBatches, 1).head

    // The four most probable tokens (indices 5 to 8) together exceed p = 0.5 and are kept;
    // the less probable tokens are masked.
    (0 to 4).foreach { i =>
      assert(processedScores(i) == Float.NegativeInfinity)
    }
    (5 to 8).foreach { i =>
      assert(processedScores(i) != Float.NegativeInfinity)
    }
    // The entry that was already negative infinity on input stays masked.
    assert(processedScores(9) == Float.NegativeInfinity)
  }
}