-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
JBAI-5829 [examples] Added a new examples module with a sample project for image classification. The classification example demonstrates a pipeline for dogs vs. cats recognition using a pre-trained CaffeNet model.
- Loading branch information
Showing 3 changed files with 192 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Module coordinates are inherited from the root project.
group = rootProject.group
version = rootProject.version

kotlin {
    // The examples are built for the JVM target only.
    jvm()

    sourceSets {
        jvmMain.dependencies {
            // KInference modules used by the examples
            api(project(":inference:inference-api"))
            api(project(":inference:inference-core"))
            api(project(":serialization:serializer-protobuf"))
            api(project(":utils:utils-common"))

            // NDArray API and its core implementation
            api(project(":ndarray:ndarray-api"))
            api(project(":ndarray:ndarray-core"))

            // Wire runtime plus KotlinDL (image preprocessing and datasets)
            api(libs.wire.runtime)
            implementation("org.jetbrains.kotlinx:kotlin-deeplearning-api:0.5.2")
            implementation("org.jetbrains.kotlinx:kotlin-deeplearning-dataset:0.5.2") // Dataset support

            // Ktor HTTP client, used to download model files
            implementation("io.ktor:ktor-client-core:2.3.12")
            implementation("io.ktor:ktor-client-cio:2.3.12") // JVM Engine

            // Logging facade and its simple binding
            api("org.slf4j:slf4j-api:2.0.9")
            api("org.slf4j:slf4j-simple:2.0.9")

            // NOTE(review): not referenced by the classification example shown in this commit;
            // presumably needed by other examples — confirm.
            implementation("com.knuddels:jtokkit:1.1.0")
        }
    }
}
158 changes: 158 additions & 0 deletions
158
examples/src/jvmMain/kotlin/io/kinference/examples/classification/Main.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
package io.kinference.examples.classification | ||
|
||
import io.kinference.core.KIEngine | ||
import io.kinference.core.data.tensor.KITensor | ||
import io.kinference.core.data.tensor.asTensor | ||
import io.kinference.ndarray.arrays.* | ||
import io.kinference.ndarray.arrays.FloatNDArray.Companion.invoke | ||
import io.kinference.utils.CommonDataLoader | ||
import io.kinference.utils.PredictionConfigs | ||
import io.kinference.utils.inlines.InlineInt | ||
import io.ktor.client.HttpClient | ||
import io.ktor.client.plugins.HttpTimeout | ||
import io.ktor.client.request.prepareRequest | ||
import io.ktor.client.statement.bodyAsChannel | ||
import io.ktor.util.cio.writeChannel | ||
import io.ktor.utils.io.copyAndClose | ||
import okio.Path.Companion.toPath | ||
import org.jetbrains.kotlinx.dl.api.preprocessing.pipeline | ||
import org.jetbrains.kotlinx.dl.dataset.OnFlyImageDataset | ||
import org.jetbrains.kotlinx.dl.dataset.embedded.dogsCatsSmallDatasetPath | ||
import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders | ||
import org.jetbrains.kotlinx.dl.impl.inference.imagerecognition.InputType | ||
import org.jetbrains.kotlinx.dl.impl.preprocessing.* | ||
import org.jetbrains.kotlinx.dl.impl.preprocessing.image.* | ||
import java.awt.image.BufferedImage | ||
import java.io.File | ||
import kotlin.collections.mutableMapOf | ||
|
||
/** Name under which the input image tensor is passed to the CaffeNet model. */
private const val INPUT_TENSOR_NAME = "data_0"

/** Name of the CaffeNet model output that holds the prediction scores. */
private const val OUTPUT_TENSOR_NAME = "prob_1"

/**
 * KotlinDL preprocessing applied to every input image:
 * bilinear resize to 224x224, conversion to BGR, conversion to a float array,
 * and finally the Caffe-specific input preprocessing.
 */
private val preprocessing = pipeline<BufferedImage>()
    .resize {
        interpolation = InterpolationType.BILINEAR
        outputHeight = 224
        outputWidth = 224
    }
    .convert { colorMode = ColorMode.BGR }
    .toFloatArray {}
    .call(InputType.CAFFE.preprocessing())

/** Location of the small dogs-vs-cats image dataset (100 images) provided by KotlinDL. */
private val dogsVsCatsDatasetPath = dogsCatsSmallDatasetPath()
|
||
/**
 * Downloads a file from the specified URL and saves it to the given output path.
 * If the file already exists at the output path, the download is skipped.
 *
 * Missing parent directories of [outputPath] are created before the download starts.
 * A non-successful HTTP status fails the download instead of storing the error body,
 * and a file left behind by a failed download is removed so that the next run does not
 * mistake it for a valid cached copy.
 *
 * @param url The URL from which the file will be downloaded.
 * @param outputPath The path to which the downloaded file will be saved.
 * @throws io.ktor.client.plugins.ResponseException if the server answers with a non-2xx status.
 */
private suspend fun downloadFile(url: String, outputPath: String) {
    val target = File(outputPath)
    if (target.exists()) {
        println("File already exists at $outputPath. Skipping download.")
        return
    }

    // The file write channel does not create missing directories (e.g. a fresh "cache" folder).
    target.absoluteFile.parentFile?.mkdirs()

    val client = HttpClient {
        // Fail on non-2xx responses rather than saving an error page as the downloaded file.
        expectSuccess = true
        install(HttpTimeout) {
            requestTimeoutMillis = 600_000 // 10 minutes; model files can be large
        }
    }

    // `use` closes the client even when the request or the copy throws.
    client.use { http ->
        try {
            http.prepareRequest(url).execute { response ->
                response.bodyAsChannel().copyAndClose(target.writeChannel())
            }
        } catch (e: Throwable) {
            // Best effort: drop a partial file so the existence check above does not skip
            // the download next time. The original failure is always rethrown.
            target.delete()
            throw e
        }
    }
}
|
||
/**
 * Builds the model inputs, grouped by the class of the source image.
 *
 * Every image of the dogs-vs-cats dataset is run through [preprocessing], wrapped into
 * a float NDArray, transposed with the axis permutation `[0, 3, 1, 2]` and converted
 * into a tensor named [INPUT_TENSOR_NAME].
 *
 * @return A Map from class label ("cat" or "dog") to the list of input tensors
 *         created from the images of that class.
 */
private suspend fun createInputs(): Map<String, List<KITensor>> {
    val images = OnFlyImageDataset.create(
        File(dogsVsCatsDatasetPath),
        FromFolders(mapping = mapOf("cat" to 0, "dog" to 1)),
        preprocessing
    ).shuffle()

    // Shape used to wrap the flat float data of one preprocessed image (channel is the last axis).
    val sourceShape = intArrayOf(1, 224, 224, 3)
    // Moves the last (channel) axis to position 1, right after the batch axis.
    val channelsFirst = intArrayOf(0, 3, 1, 2)

    val tensorsByLabel = mutableMapOf<String, MutableList<KITensor>>()

    repeat(images.xSize()) { sampleIdx ->
        val pixels = images.getX(sampleIdx)
        // Label 0 is "cat"; every other label is reported as "dog".
        val label = when (images.getY(sampleIdx).toInt()) {
            0 -> "cat"
            else -> "dog"
        }

        val ndArray = FloatNDArray(sourceShape) { index: InlineInt -> pixels[index.value] }
        val tensor = ndArray.transpose(channelsFirst).asTensor(INPUT_TENSOR_NAME)

        tensorsByLabel.getOrPut(label) { mutableListOf() }.add(tensor)
    }

    return tensorsByLabel
}
|
||
/**
 * Displays the top 5 predictions with their corresponding labels and scores.
 *
 * Scores from all storage blocks of [predictions] are ranked, not only those in the first block.
 * The position of a score in the flattened data is used as its class index; indices without
 * a matching entry in [classLabels] are printed as "Unknown".
 *
 * @param predictions The predicted scores in a multidimensional array format.
 * @param classLabels The list of class labels corresponding to the predictions.
 * @param originalClass The actual class label of the instance being predicted.
 */
private fun displayTopPredictions(predictions: FloatNDArray, classLabels: List<String>, originalClass: String) {
    // Flatten every block: ranking only the first one would drop scores stored in later blocks.
    // NOTE(review): assumes blocks hold consecutive elements in order — confirm against the tiled-array layout.
    val scores = predictions.array.blocks.flatMap { it.asList() }
    val topScores = scores.withIndex().sortedByDescending { it.value }.take(5)

    println("\nOriginal class: $originalClass")
    println("Top 5 predictions:")
    for ((index, score) in topScores) {
        val predictedClassLabel = classLabels.getOrElse(index) { "Unknown" }
        println("${predictedClassLabel}: ${"%.2f".format(score * 100)}%")
    }
}
|
||
/**
 * Entry point of the image classification example.
 *
 * Downloads the CaffeNet ONNX model and the class label list into a `cache` directory under
 * the current working directory (skipping files that are already present), loads the model,
 * builds input tensors from the dogs-vs-cats dataset and prints the top predictions for
 * every image together with its original class.
 */
suspend fun main() {
    val workingDir = System.getProperty("user.dir")
    // No trailing separator here, so the paths below do not contain a doubled "//".
    val resourcesPath = "$workingDir/cache"
    val modelPath = "$resourcesPath/model.onnx"
    val synsetPath = "$resourcesPath/synset.txt"

    val modelUrl = "https://github.com/onnx/models/raw/main/validated/vision/classification/caffenet/model/caffenet-12.onnx"
    val synsetUrl = "https://s3.amazonaws.com/onnx-model-zoo/synset.txt"

    // Report the actual working directory, not the cache directory derived from it.
    println("Current working directory: $workingDir")
    println("Downloading model from: $modelUrl")
    downloadFile(modelUrl, modelPath)
    println("Downloading synset from: $synsetUrl")
    downloadFile(synsetUrl, synsetPath)

    val modelBytes = CommonDataLoader.bytes(modelPath.toPath())
    val classLabels = File(synsetPath).readLines()

    println("Loading model...")
    val model = KIEngine.loadModel(modelBytes, optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator)
    println("Creating inputs...")
    val inputTensors = createInputs()

    println("Starting inference...")
    for ((originalClass, tensors) in inputTensors) {
        for (tensor in tensors) {
            val actualOutputs = model.predict(listOf(tensor))
            // Fail with a descriptive message instead of an opaque cast error when the output is absent.
            val predictions = actualOutputs[OUTPUT_TENSOR_NAME]?.data as? FloatNDArray
                ?: error("Model output '$OUTPUT_TENSOR_NAME' is missing or is not a FloatNDArray")
            displayTopPredictions(predictions, classLabels, originalClass)
        }
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters