-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
JBAI-5829 [examples] Added GPT-2 example using ORTEngine for text generation.
- Loading branch information
Showing 4 changed files with 157 additions and 169 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
60 changes: 60 additions & 0 deletions
60
examples/src/jvmMain/kotlin/io/kinference/examples/lm/KIMain.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
package io.kinference.examples.lm | ||
|
||
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer | ||
import io.kinference.core.KIEngine | ||
import io.kinference.core.data.tensor.asTensor | ||
import io.kinference.examples.downloadFile | ||
import io.kinference.examples.extractTopToken | ||
import io.kinference.examples.resourcesPath | ||
import io.kinference.ndarray.arrays.LongNDArray | ||
import io.kinference.ndarray.arrays.NDArrayCore | ||
import io.kinference.utils.CommonDataLoader | ||
import io.kinference.utils.PredictionConfigs | ||
import io.kinference.utils.inlines.InlineInt | ||
import okio.Path.Companion.toPath | ||
|
||
// Tensor names expected by the GPT-2 ONNX graph; only the logits output is read back.
private const val INPUT_TENSOR_NAME = "input1"
private const val OUTPUT_TENSOR_NAME = "output1" // We use only logits tensor

/**
 * GPT-2 text-generation example on the KInference core (KI) engine.
 *
 * Downloads the `gpt2-lm-head-10` ONNX model, encodes a fixed prompt with the
 * HuggingFace GPT-2 tokenizer, then generates 34 tokens one at a time by
 * feeding the ever-growing context back into the model and printing each
 * decoded token as it is produced.
 */
suspend fun main() {
    val modelUrl = "https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx"
    val modelName = "gpt2-lm-head-10"

    println("Downloading model from: $modelUrl")
    downloadFile(modelUrl, "$resourcesPath/$modelName.onnx")

    val modelBytes = CommonDataLoader.bytes("${resourcesPath}/$modelName.onnx".toPath())

    println("Loading model...")
    val model = KIEngine.loadModel(modelBytes, optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator)

    val tokenizer = HuggingFaceTokenizer.newInstance("gpt2", mapOf("modelMaxLength" to "1024"))
    val prompt = "Neurogenesis is most active during embryonic development and is responsible for producing " +
        "all the various types of neurons of the organism, but it continues throughout adult life " +
        "in a variety of organisms. Once born, neurons do not divide (see mitosis), and many will " +
        "live the lifespan of the animal, except under extraordinary and usually pathogenic circumstances."
    val promptIds = tokenizer.encode(prompt).ids
    val promptLength = promptIds.size

    val stepsToGenerate = 34
    val generated = LongArray(stepsToGenerate) { 0 }

    // (1, promptLength) unsqueezed to rank 3: the model consumes ids shaped (1, 1, seqLen).
    val promptTensor = LongNDArray(1, promptLength) { pos: InlineInt -> promptIds[pos.value] }.unsqueeze(0)
    var context = promptTensor.clone()

    print("Here goes the test text for generation:\n$prompt")

    for (step in 0 until stepsToGenerate) {
        val feed = listOf((context as NDArrayCore).asTensor(INPUT_TENSOR_NAME))
        val outputs = model.predict(feed)

        // Pick the next token from the logits at the last context position
        // (presumably arg-max — see extractTopToken for the exact selection rule).
        generated[step] = extractTopToken(outputs, promptLength + step, OUTPUT_TENSOR_NAME)

        // Append the new token to the running context along the sequence axis.
        val nextToken = LongNDArray(1, 1) { _: InlineInt -> generated[step] }
        context = context.concat(listOf(nextToken.unsqueeze(0)), axis = -1)
        print(tokenizer.decode(longArrayOf(generated[step])))
    }
    println("\n\nDone")
}
169 changes: 0 additions & 169 deletions
169
examples/src/jvmMain/kotlin/io/kinference/examples/lm/Main.kt
This file was deleted.
Oops, something went wrong.
74 changes: 74 additions & 0 deletions
74
examples/src/jvmMain/kotlin/io/kinference/examples/lm/ORTMain.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package io.kinference.examples.lm | ||
|
||
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer | ||
import io.kinference.core.data.tensor.KITensor | ||
import io.kinference.core.data.tensor.asTensor | ||
import io.kinference.examples.downloadFile | ||
import io.kinference.examples.extractTopToken | ||
import io.kinference.examples.resourcesPath | ||
import io.kinference.ndarray.arrays.FloatNDArray | ||
import io.kinference.ndarray.arrays.FloatNDArray.Companion.invoke | ||
import io.kinference.ort.ORTData | ||
import io.kinference.ort.ORTEngine | ||
import io.kinference.ort.data.tensor.ORTTensor | ||
import io.kinference.utils.CommonDataLoader | ||
import io.kinference.utils.inlines.InlineInt | ||
import io.kinference.utils.toIntArray | ||
import okio.Path.Companion.toPath | ||
|
||
// Tensor names expected by the GPT-2 ONNX graph; only the logits output is read back.
private const val INPUT_TENSOR_NAME = "input1"
private const val OUTPUT_TENSOR_NAME = "output1" // We use only logits tensor

/**
 * GPT-2 text-generation example on the ONNXRuntime (ORT) backend.
 *
 * Downloads the `gpt2-lm-head-10` ONNX model, encodes a fixed prompt with the
 * HuggingFace GPT-2 tokenizer, then generates 34 tokens one at a time. On each
 * step the full context (prompt + everything generated so far) is rebuilt as a
 * fresh ORT tensor and fed to the model; each decoded token is printed as it
 * is produced.
 */
suspend fun main() {
    val modelUrl = "https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx"
    val modelName = "gpt2-lm-head-10"

    println("Downloading model from: $modelUrl")
    downloadFile(modelUrl, "$resourcesPath/$modelName.onnx")

    val modelBytes = CommonDataLoader.bytes("${resourcesPath}/$modelName.onnx".toPath())

    println("Loading model...")
    val model = ORTEngine.loadModel(modelBytes)

    val tokenizer = HuggingFaceTokenizer.newInstance("gpt2", mapOf("modelMaxLength" to "1024"))
    val prompt = "Neurogenesis is most active during embryonic development and is responsible for producing " +
        "all the various types of neurons of the organism, but it continues throughout adult life " +
        "in a variety of organisms. Once born, neurons do not divide (see mitosis), and many will " +
        "live the lifespan of the animal, except under extraordinary and usually pathogenic circumstances."
    val promptIds = tokenizer.encode(prompt).ids
    val promptLength = promptIds.size

    val stepsToGenerate = 34
    val generated = LongArray(stepsToGenerate) { 0 }

    // The model consumes token ids shaped (1, 1, seqLen).
    val promptTensor = ORTTensor(promptIds, longArrayOf(1, 1, promptLength.toLong()))
    var context = promptTensor.clone(INPUT_TENSOR_NAME)

    print("Here goes the test text for generation:\n$prompt")

    for (step in 0 until stepsToGenerate) {
        val outputs = model.predict(listOf(context))

        // ORT outputs are bridged to KI tensors so the shared extractTopToken helper applies.
        generated[step] = extractTopToken(convertToKITensorMap(outputs), promptLength + step, OUTPUT_TENSOR_NAME)

        // Rebuild the whole context: original prompt followed by all tokens generated so far.
        val extendedIds = promptIds + generated.take(step + 1)
        context = ORTTensor(extendedIds, longArrayOf(1, 1, promptLength + step + 1L), INPUT_TENSOR_NAME)
        print(tokenizer.decode(longArrayOf(generated[step])))
    }
    println("\n\nDone")
}
|
||
/**
 * Bridges ORT backend outputs to KI tensors so that KI-based helpers
 * (e.g. `extractTopToken`) can consume them.
 *
 * Every output is expected to be an [ORTTensor] holding float data; its values
 * are copied element-wise into a freshly built [FloatNDArray] of the same shape.
 * Keys (tensor names) are preserved.
 */
private suspend fun convertToKITensorMap(outputs: Map<String, ORTData<*>>): Map<String, KITensor> {
    return outputs.mapValues { (name, output) ->
        val source = output as ORTTensor
        val values = source.toFloatArray()
        val dims = source.shape.toIntArray()
        FloatNDArray(dims) { flatIdx: InlineInt -> values[flatIdx.value] }.asTensor(name)
    }
}