Skip to content

Commit

Permalink
[SPARKNLP-1052] Adding random suffix to avoid duplication in spark fi…
Browse files Browse the repository at this point in the history
…les (#14340)
  • Loading branch information
danilojsl authored Jul 14, 2024
1 parent 0cc970a commit a070adc
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 5 deletions.
9 changes: 8 additions & 1 deletion src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import ai.onnxruntime.OrtSession.SessionOptions
import ai.onnxruntime.OrtSession.SessionOptions.{ExecutionMode, OptLevel}
import ai.onnxruntime.providers.OrtCUDAProviderOptions
import ai.onnxruntime.{OrtEnvironment, OrtSession}
import com.johnsnowlabs.ml.util.LoadExternalModel
import com.johnsnowlabs.util.{ConfigHelper, FileHelper, ZipArchiveUtil}
import org.apache.spark.SparkFiles
import org.apache.spark.sql.SparkSession
Expand Down Expand Up @@ -114,9 +115,10 @@ object OnnxWrapper {
.toString

// 2. Unpack archive
val randomSuffix = generateRandomSuffix(onnxFileSuffix)
val folder =
if (zipped)
ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder), onnxFileSuffix)
ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder), randomSuffix)
else
modelPath

Expand Down Expand Up @@ -151,6 +153,11 @@ object OnnxWrapper {
onnxWrapper
}

private def generateRandomSuffix(fileSuffix: Option[String]): Option[String] = {
val randomSuffix = Some(LoadExternalModel.generateRandomString(10))
Some(s"${randomSuffix.get}${fileSuffix.getOrElse("")}")
}

private def mapToSessionOptionsObject(sessionOptions: Map[String, String]): SessionOptions = {
val providers = OrtEnvironment.getAvailableProviders
if (providers.toArray.map(x => x.toString).contains("CUDA")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
package com.johnsnowlabs.ml.openvino

import com.johnsnowlabs.ml.util.LoadExternalModel.notSupportedEngineError
import com.johnsnowlabs.ml.util.{ONNX, Openvino, TensorFlow}
import com.johnsnowlabs.util.{ConfigHelper, ConfigLoader, FileHelper, ZipArchiveUtil}
import com.johnsnowlabs.ml.util.{LoadExternalModel, ONNX, Openvino, TensorFlow}
import com.johnsnowlabs.util.{FileHelper, ZipArchiveUtil}
import org.apache.commons.io.{FileUtils, FilenameUtils}
import org.apache.spark.SparkFiles
import org.apache.spark.sql.SparkSession
Expand Down Expand Up @@ -113,9 +113,10 @@ object OpenvinoWrapper {
.toAbsolutePath
.toString

val randomSuffix = generateRandomSuffix(ovFileSuffix)
val folder =
if (zipped)
ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder), ovFileSuffix)
ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder), randomSuffix)
else
modelPath

Expand Down Expand Up @@ -151,6 +152,11 @@ object OpenvinoWrapper {
openvinoWrapper
}

private def generateRandomSuffix(fileSuffix: Option[String]): Option[String] = {
val randomSuffix = Some(LoadExternalModel.generateRandomString(10))
Some(s"${randomSuffix.get}${fileSuffix.getOrElse("")}")
}

/** Convert the model at srcPath to OpenVINO IR Format and export to exportPath.
*
* @param srcPath
Expand Down
13 changes: 13 additions & 0 deletions src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs, ResourceHelper}
import java.io.File
import java.nio.file.Paths
import scala.io.Source
import scala.util.Random

object LoadExternalModel {

Expand Down Expand Up @@ -228,4 +229,16 @@ object LoadExternalModel {
f
}

/** Generates a random alphanumeric string of a given length.
*
* @param n
* the length of the generated string
* @return
* a random alphanumeric string of length n
*/
def generateRandomString(n: Int): String = {
val alphanumeric = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
(1 to n).map(_ => alphanumeric(Random.nextInt(alphanumeric.length))).mkString
}

}
7 changes: 6 additions & 1 deletion src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ object ZipArchiveUtil {

val zip = new ZipFile(file)
zip.entries.asScala foreach { entry =>
val entryName = if (suffix.isDefined) suffix.get + "_" + entry.getName else entry.getName
val entryName = buildEntryName(entry, suffix)
val entryPath = {
if (entryName.startsWith(basename))
entryName.substring(0, basename.length)
Expand Down Expand Up @@ -165,4 +165,9 @@ object ZipArchiveUtil {
destDir.getPath
}

private def buildEntryName(entry: ZipEntry, suffix: Option[String]): String = {
val entryName = if (suffix.isDefined) suffix.get + "_" + entry.getName else entry.getName
entryName.split("_").distinct.mkString("_")
}

}

0 comments on commit a070adc

Please sign in to comment.