Skip to content

Commit

Permalink
JBAI-5829 [examples] Added a new examples module with sample projec…
Browse files Browse the repository at this point in the history
…t for image classification. The classification example demonstrates a pipeline for dogs vs. cats recognition using a pre-trained CaffeNet model.
  • Loading branch information
dmitriyb committed Sep 17, 2024
1 parent 0b2bab7 commit 61ff574
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 0 deletions.
32 changes: 32 additions & 0 deletions examples/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
group = rootProject.group
version = rootProject.version

kotlin {
jvm()

sourceSets {
jvmMain {
dependencies {
api(project(":inference:inference-api"))
api(project(":inference:inference-core"))
api(project(":serialization:serializer-protobuf"))
api(project(":utils:utils-common"))

api(project(":ndarray:ndarray-api"))
api(project(":ndarray:ndarray-core"))

api(libs.wire.runtime)
implementation("org.jetbrains.kotlinx:kotlin-deeplearning-api:0.5.2")
implementation("org.jetbrains.kotlinx:kotlin-deeplearning-dataset:0.5.2") // Dataset support

implementation("io.ktor:ktor-client-core:2.3.12")
implementation("io.ktor:ktor-client-cio:2.3.12") // JVM Engine

api("org.slf4j:slf4j-api:2.0.9")
api("org.slf4j:slf4j-simple:2.0.9")

implementation("com.knuddels:jtokkit:1.1.0")
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package io.kinference.examples.classification

import io.kinference.core.KIEngine
import io.kinference.core.data.tensor.KITensor
import io.kinference.core.data.tensor.asTensor
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.FloatNDArray.Companion.invoke
import io.kinference.utils.CommonDataLoader
import io.kinference.utils.PredictionConfigs
import io.kinference.utils.inlines.InlineInt
import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.request.prepareRequest
import io.ktor.client.statement.bodyAsChannel
import io.ktor.util.cio.writeChannel
import io.ktor.utils.io.copyAndClose
import okio.Path.Companion.toPath
import org.jetbrains.kotlinx.dl.api.preprocessing.pipeline
import org.jetbrains.kotlinx.dl.dataset.OnFlyImageDataset
import org.jetbrains.kotlinx.dl.dataset.embedded.dogsCatsSmallDatasetPath
import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders
import org.jetbrains.kotlinx.dl.impl.inference.imagerecognition.InputType
import org.jetbrains.kotlinx.dl.impl.preprocessing.*
import org.jetbrains.kotlinx.dl.impl.preprocessing.image.*
import java.awt.image.BufferedImage
import java.io.File
import kotlin.collections.mutableMapOf

// Constants for input and output tensor names used in the CaffeNet model
private const val INPUT_TENSOR_NAME = "data_0"
private const val OUTPUT_TENSOR_NAME = "prob_1"

// Preprocessing pipeline for input images using KotlinDL
private val preprocessing = pipeline<BufferedImage>()
.resize {
outputWidth = 224
outputHeight = 224
interpolation = InterpolationType.BILINEAR
}
.convert { colorMode = ColorMode.BGR }
.toFloatArray { }
.call(InputType.CAFFE.preprocessing())

// Path to the small dataset of dogs vs cats images (100 images)
private val dogsVsCatsDatasetPath = dogsCatsSmallDatasetPath()

/**
* Downloads a file from the specified URL and saves it to the given output path.
* If the file already exists at the output path, the download is skipped.
*
* @param url The URL from which the file will be downloaded.
* @param outputPath The path to which the downloaded file will be saved.
*/
private suspend fun downloadFile(url: String, outputPath: String) {
// Check if the file already exists
val file = File(outputPath)
if (file.exists()) {
println("File already exists at $outputPath. Skipping download.")
return // Exit the function if the file exists
}

// Create an instance of HttpClient with custom timeout settings
val client = HttpClient {
install(HttpTimeout) {
requestTimeoutMillis = 600_000 // Set timeout to 10 minutes (600,000 milliseconds)
}
}

// Download the file and write to the specified output path
client.prepareRequest(url).execute { response ->
response.bodyAsChannel().copyAndClose(File(outputPath).writeChannel())
}

client.close()
}

/**
* Creates a Map of input tensors categorized by their respective classes (e.g., "cat" and "dog").
*
* This function reads images from the dataset, preprocesses them,
* transposes the tensors to the required format, and groups them
* based on their class label.
*
* @return A Map where the keys are the class labels (e.g., "cat" and "dog"),
* and the values are lists of KITensor objects representing the input tensors
* for each class.
*/
private suspend fun createInputs(): Map<String, List<KITensor>> {
val dataset = OnFlyImageDataset.create(
File(dogsVsCatsDatasetPath),
FromFolders(mapping = mapOf("cat" to 0, "dog" to 1)),
preprocessing
).shuffle()


val tensorShape = intArrayOf(1, 224, 224, 3) // Original tensor shape is [batch, width, height, channel]
val permuteAxis = intArrayOf(0, 3, 1, 2) // Permutations for shape [batch, channel, width, height]
val inputTensors = mutableMapOf<String, MutableList<KITensor>>()

for (i in 0 until dataset.xSize()) {
val inputData = dataset.getX(i)
val inputClass = if (dataset.getY(i).toInt() == 0) "cat" else "dog"
val floatNDArray = FloatNDArray(tensorShape) { index: InlineInt -> inputData[index.value]} // Create an NDArray from the image data
val inputTensor = floatNDArray.transpose(permuteAxis).asTensor(INPUT_TENSOR_NAME) // Transpose and create a tensor from the NDArray
inputTensors.putIfAbsent(inputClass, mutableListOf())
inputTensors[inputClass]!!.add(inputTensor)
}

return inputTensors
}

/**
* Displays the top 5 predictions with their corresponding labels and scores.
*
* @param predictions The predicted scores in a multidimensional array format.
* @param classLabels The list of class labels corresponding to the predictions.
* @param originalClass The actual class label of the instance being predicted.
*/
private fun displayTopPredictions(predictions: FloatNDArray, classLabels: List<String>, originalClass: String) {
val predictionArray = predictions.array.blocks.first()
val indexedScores = predictionArray.withIndex().sortedByDescending { it.value }.take(5)

println("\nOriginal class: $originalClass")
println("Top 5 predictions:")
for ((index, score) in indexedScores) {
val predictedClassLabel = if (index in classLabels.indices) classLabels[index] else "Unknown"
println("${predictedClassLabel}: ${"%.2f".format(score * 100)}%")
}
}

suspend fun main() {
val resourcesPath = System.getProperty("user.dir") + "/cache/"
val modelUrl = "https://github.com/onnx/models/raw/main/validated/vision/classification/caffenet/model/caffenet-12.onnx"
val synsetUrl = "https://s3.amazonaws.com/onnx-model-zoo/synset.txt"

println("Current working directory: $resourcesPath")
println("Downloading model from: $modelUrl")
downloadFile(modelUrl, "$resourcesPath/model.onnx")
println("Downloading synset from: $synsetUrl")
downloadFile(synsetUrl, "$resourcesPath/synset.txt")

val modelBytes = CommonDataLoader.bytes("$resourcesPath/model.onnx".toPath())
val classLabels = File("$resourcesPath/synset.txt").readLines()

println("Loading model...")
val model = KIEngine.loadModel(modelBytes, optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator)
println("Creating inputs...")
val inputTensors = createInputs()

println("Starting inference...")
inputTensors.forEach { dataClass ->
dataClass.value.forEach { tensor ->
val actualOutputs = model.predict(listOf(tensor))
val predictions = actualOutputs[OUTPUT_TENSOR_NAME]?.data as FloatNDArray
displayTopPredictions(predictions, classLabels, dataClass.key)
}
}
}
2 changes: 2 additions & 0 deletions settings.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ include(":adapters:kmath:adapter-kmath-core")
include(":adapters:kmath:adapter-kmath-ort")
include(":adapters:kmath:adapter-kmath-ort-gpu")

include(":examples")


pluginManagement {
repositories {
Expand Down

0 comments on commit 61ff574

Please sign in to comment.