diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala index 3b08931558a41a..6e748faa72ee63 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala @@ -20,6 +20,7 @@ import ai.onnxruntime.OrtSession.SessionOptions import ai.onnxruntime.OrtSession.SessionOptions.{ExecutionMode, OptLevel} import ai.onnxruntime.providers.OrtCUDAProviderOptions import ai.onnxruntime.{OrtEnvironment, OrtSession} +import com.johnsnowlabs.ml.util.LoadExternalModel import com.johnsnowlabs.util.{ConfigHelper, FileHelper, ZipArchiveUtil} import org.apache.spark.SparkFiles import org.apache.spark.sql.SparkSession @@ -114,9 +115,10 @@ object OnnxWrapper { .toString // 2. Unpack archive + val randomSuffix = generateRandomSuffix(onnxFileSuffix) val folder = if (zipped) - ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder), onnxFileSuffix) + ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder), randomSuffix) else modelPath @@ -151,6 +153,11 @@ object OnnxWrapper { onnxWrapper } + private def generateRandomSuffix(fileSuffix: Option[String]): Option[String] = { + val randomSuffix = Some(LoadExternalModel.generateRandomString(10)) + Some(s"${randomSuffix.get}${fileSuffix.getOrElse("")}") + } + private def mapToSessionOptionsObject(sessionOptions: Map[String, String]): SessionOptions = { val providers = OrtEnvironment.getAvailableProviders if (providers.toArray.map(x => x.toString).contains("CUDA")) { diff --git a/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala index dd8b5f466a2927..fa5908383bae97 100644 --- a/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/openvino/OpenvinoWrapper.scala @@ -17,8 +17,8 @@ package com.johnsnowlabs.ml.openvino import com.johnsnowlabs.ml.util.LoadExternalModel.notSupportedEngineError -import com.johnsnowlabs.ml.util.{ONNX, Openvino, TensorFlow} -import com.johnsnowlabs.util.{ConfigHelper, ConfigLoader, FileHelper, ZipArchiveUtil} +import com.johnsnowlabs.ml.util.{LoadExternalModel, ONNX, Openvino, TensorFlow} +import com.johnsnowlabs.util.{FileHelper, ZipArchiveUtil} import org.apache.commons.io.{FileUtils, FilenameUtils} import org.apache.spark.SparkFiles import org.apache.spark.sql.SparkSession @@ -113,9 +113,10 @@ object OpenvinoWrapper { .toAbsolutePath .toString + val randomSuffix = generateRandomSuffix(ovFileSuffix) val folder = if (zipped) - ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder), ovFileSuffix) + ZipArchiveUtil.unzip(new File(modelPath), Some(tmpFolder), randomSuffix) else modelPath @@ -151,6 +152,11 @@ object OpenvinoWrapper { openvinoWrapper } + private def generateRandomSuffix(fileSuffix: Option[String]): Option[String] = { + val randomSuffix = Some(LoadExternalModel.generateRandomString(10)) + Some(s"${randomSuffix.get}${fileSuffix.getOrElse("")}") + } + /** Convert the model at srcPath to OpenVINO IR Format and export to exportPath. * * @param srcPath diff --git a/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala b/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala index 93cab6a0a89dd7..cd0761f0f9daa3 100644 --- a/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala +++ b/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala @@ -22,6 +22,7 @@ import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs, ResourceHelper} import java.io.File import java.nio.file.Paths import scala.io.Source +import scala.util.Random object LoadExternalModel { @@ -228,4 +229,16 @@ object LoadExternalModel { f } + /** Generates a random alphanumeric string of a given length. + * + * @param n + * the length of the generated string + * @return + * a random alphanumeric string of length n + */ + def generateRandomString(n: Int): String = { + val alphanumeric = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" + (1 to n).map(_ => alphanumeric(Random.nextInt(alphanumeric.length))).mkString + } + } diff --git a/src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala b/src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala index 8c85f2915561f3..0443471e5080bc 100644 --- a/src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala +++ b/src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala @@ -135,7 +135,7 @@ object ZipArchiveUtil { val zip = new ZipFile(file) zip.entries.asScala foreach { entry => - val entryName = if (suffix.isDefined) suffix.get + "_" + entry.getName else entry.getName + val entryName = buildEntryName(entry, suffix) val entryPath = { if (entryName.startsWith(basename)) entryName.substring(0, basename.length) @@ -165,4 +165,9 @@ object ZipArchiveUtil { destDir.getPath } + private def buildEntryName(entry: ZipEntry, suffix: Option[String]): String = { + val entryName = if (suffix.isDefined) suffix.get + "_" + entry.getName else entry.getName + entryName.split("_").distinct.mkString("_") + } + }