Skip to content

Commit

Permalink
improve zip util code and add tests for both ZipArchiveUtil and OnnxW… (
Browse files Browse the repository at this point in the history
#14056)

* improve zip util code and add tests for both ZipArchiveUtil and OnnxWrapper

* fix onnx test

* fix zipUtil test element ordering issue

---------

Co-authored-by: Angelo Anqi Ni <[email protected]>
  • Loading branch information
anqini and Angelo Anqi Ni authored Dec 7, 2023
1 parent 97a541b commit 37b24b9
Show file tree
Hide file tree
Showing 4 changed files with 297 additions and 30 deletions.
81 changes: 51 additions & 30 deletions src/main/scala/com/johnsnowlabs/util/ZipArchiveUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,52 +23,47 @@ import java.util.zip.{ZipEntry, ZipFile, ZipOutputStream}
import scala.collection.JavaConverters._

object ZipArchiveUtil {

private def listFiles(file: File, outputFilename: String): List[String] = {
// Recursively lists all files in a given directory, returning a list of their absolute paths
private[util] def listFilesRecursive(file: File): List[File] = {
file match {
case file if file.isFile =>
if (file.getName != outputFilename)
List(file.getAbsoluteFile.toString)
else
List()
case file if file.isFile => List(new File(file.getAbsoluteFile.toString))
case file if file.isDirectory =>
val fList = file.list
// Add all files in current dir to list and recur on subdirs
fList.foldLeft(List[String]())((pList: List[String], path: String) =>
pList ++ listFiles(new File(file, path), outputFilename))
fList.foldLeft(List[File]())((pList: List[File], path: String) =>
pList ++ listFilesRecursive(new File(file, path)))
case _ => throw new IOException("Bad path. No file or directory found.")
}
}

private def addFileToZipEntry(
filename: String,
parentPath: String,
filePathsCount: Int): ZipEntry = {
if (filePathsCount <= 1)
new ZipEntry(new File(filename).getName)
else {
// use relative path to avoid adding absolute path directories
val relative = new File(parentPath).toURI.relativize(new File(filename).toURI).getPath
private[util] def addFileToZipEntry(
filename: File,
parentPath: File,
useRelativePath: Boolean = false): ZipEntry = {
if (!useRelativePath) // use absolute path
new ZipEntry(filename.getName)
else { // use relative path
val relative = parentPath.toURI.relativize(filename.toURI).getPath
new ZipEntry(relative)
}
}

private def createZip(
filePaths: List[String],
outputFilename: String,
parentPath: String): Unit = {
private[util] def createZip(
filePaths: List[File],
outputFilePath: File,
parentPath: File): Unit = {

val Buffer = 2 * 1024
val data = new Array[Byte](Buffer)
try {
val zipFileOS = new FileOutputStream(outputFilename)
val zipFileOS = new FileOutputStream(outputFilePath)
val zip = new ZipOutputStream(zipFileOS)
zip.setLevel(0)
filePaths.foreach((name: String) => {
val zipEntry = addFileToZipEntry(name, parentPath, filePaths.size)
filePaths.foreach((file: File) => {
val zipEntry = addFileToZipEntry(file, parentPath, filePaths.size > 1)
// add zip entry to output stream
zip.putNextEntry(new ZipEntry(zipEntry))
val in = new BufferedInputStream(new FileInputStream(name), Buffer)
val in = new BufferedInputStream(new FileInputStream(file), Buffer)
var b = in.read(data, 0, Buffer)
while (b != -1) {
zip.write(data, 0, b)
Expand All @@ -86,10 +81,36 @@ object ZipArchiveUtil {
}
}

def zip(fileName: String, outputFileName: String): Unit = {
val file = new File(fileName)
val filePaths = listFiles(file, outputFileName)
createZip(filePaths, outputFileName, fileName)
/** Compresses a single file into the archive at `outputFilePath`.
  *
  * @param soureFile the file to compress; its absolute form is handed to `createZip`
  * @param outputFilePath the zip archive to write
  */
private[util] def zipFile(soureFile: File, outputFilePath: File): Unit = {
  // Only one entry is archived, so no parent directory is supplied for relativising paths.
  val singleEntry = List(soureFile.getAbsoluteFile)
  createZip(singleEntry, outputFilePath, null)
}

/** Compresses every file found under `sourceDir` into the archive at `outputFilePath`.
  *
  * @param sourceDir the directory whose files are collected via `listFilesRecursive`;
  *                  it is also passed to `createZip` as the parent path
  * @param outputFilePath the zip archive to write
  */
private[util] def zipDir(sourceDir: File, outputFilePath: File): Unit =
  createZip(listFilesRecursive(sourceDir), outputFilePath, sourceDir)

/** Zips a file or a directory into a new archive after validating both paths.
  *
  * Directories are delegated to `zipDir`, regular files to `zipFile`.
  *
  * @param sourcePath path of the file or directory to compress
  * @param outputFilePath path of the zip archive to create; it must not exist yet and its
  *                       parent directory must already exist
  * @throws IllegalArgumentException if both paths are identical, or if the source is neither
  *                                  a directory nor a regular file
  * @throws IOException if the parent directory of the target is missing, the source does not
  *                     exist, or the target already exists
  */
def zip(sourcePath: String, outputFilePath: String): Unit = {
  val sourceFile = new File(sourcePath)
  val outputFile = new File(outputFilePath)
  if (sourceFile.equals(outputFile))
    throw new IllegalArgumentException("source path cannot be identical to target path")

  // getParentFile is null for a bare name such as "out.zip"; resolving the path to its
  // absolute form first (and wrapping in Option) avoids a NullPointerException here.
  val parentExists = Option(outputFile.getAbsoluteFile.getParentFile).exists(_.exists)
  if (!parentExists)
    throw new IOException("the parent directory of output file doesn't exist")

  if (!sourceFile.exists())
    throw new IOException("zip source path must exsit")

  if (outputFile.exists())
    throw new IOException("zip target file exsits")

  if (sourceFile.isDirectory())
    zipDir(sourceFile, outputFile)
  else if (sourceFile.isFile())
    zipFile(sourceFile, outputFile)
  else
    throw new IllegalArgumentException("only folder and file input are valid")
}

def unzip(file: File, destDirPath: Option[String] = None): String = {
Expand Down
Binary file added src/test/resources/onnx/models/dummy_model.onnx
Binary file not shown.
82 changes: 82 additions & 0 deletions src/test/scala/com/johnsnowlabs/ml/onnx/OnnxWrapperTestSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright 2017-2022 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.johnsnowlabs.ml.onnx

import com.johnsnowlabs.tags.FastTest
import org.scalatest.flatspec.AnyFlatSpec
import java.nio.file.{Files, Paths, Path}
import java.io.File
import com.johnsnowlabs.util.FileHelper
import org.scalatest.BeforeAndAfter
import java.util.UUID

/** Smoke tests for `OnnxWrapper` backed by a small dummy ONNX model from the test resources. */
class OnnxWrapperTestSpec extends AnyFlatSpec with BeforeAndAfter {
  /*
   * Dummy model was created with the following python script
   """
   import torch
   import torch.nn as nn
   import torch.onnx
   # Define a simple neural network model
   class DummyModel(nn.Module):
       def __init__(self):
           super(DummyModel, self).__init__()
           self.linear = nn.Linear(in_features=10, out_features=5)
       def forward(self, x):
           return self.linear(x)
   # Create the model and dummy input
   model = DummyModel()
   dummy_input = torch.randn(1, 10) # batch size of 1, 10 features
   # Export the model to ONNX format
   torch.onnx.export(model, dummy_input, "dummy_model.onnx", verbose=True)
   """
   *
   */
  private val modelPath: String = "src/test/resources/onnx/models/dummy_model.onnx"

  // Randomised name so each run gets its own scratch directory.
  private val tmpDirPath: String = s"${UUID.randomUUID().toString.takeRight(12)}_onnx"
  var tmpFolder: String = _

  before {
    val created = Files.createDirectory(Paths.get(tmpDirPath))
    tmpFolder = created.toAbsolutePath.toString
  }

  after {
    FileHelper.delete(tmpFolder)
  }

  /** Builds a wrapper from the raw bytes of the dummy model file. */
  private def loadDummyWrapper(): OnnxWrapper = {
    val modelBytes: Array[Byte] = Files.readAllBytes(Paths.get(modelPath))
    new OnnxWrapper(modelBytes)
  }

  "a dummy onnx wrapper" should "get session correctly" taggedAs FastTest in {
    val wrapper = loadDummyWrapper()
    wrapper.getSession()
  }

  "a dummy onnx wrapper" should "saveToFile correctly" taggedAs FastTest in {
    val wrapper = loadDummyWrapper()
    val target = Paths.get(tmpFolder, "modelFromTest.zip").toString
    wrapper.saveToFile(target)
    // verify file existence
    assert(new File(tmpFolder, "modelFromTest.zip").exists())
  }
}
164 changes: 164 additions & 0 deletions src/test/scala/com/johnsnowlabs/util/ZipArchiveUtilTestSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
* Copyright 2017-2022 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.johnsnowlabs.util

import com.johnsnowlabs.tags.FastTest
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers._
import java.nio.file.{Files, Paths, Path}
import java.io.File
import com.johnsnowlabs.util.FileHelper
import org.scalatest.BeforeAndAfter
import java.util.UUID

/** Tests for `ZipArchiveUtil`: recursive file listing, zip-entry naming and archive creation.
  *
  * Every test runs against a freshly created scratch directory laid out as:
  * {{{
  * <tmpFolder>/fileA
  * <tmpFolder>/dir/fileA
  * <tmpFolder>/dir/fileB
  * <tmpFolder>/dir/dir2/fileC
  * }}}
  */
class ZipArchiveUtilTestSpec extends AnyFlatSpec with BeforeAndAfter {
  private val tmpDirPath: String = UUID.randomUUID().toString.takeRight(12) + "_onnx"
  var tmpFolder: String = _

  before {
    // create temp dir for testing
    tmpFolder = Files
      .createDirectory(Paths.get(tmpDirPath))
      .toAbsolutePath
      .toString

    // create files and dirs for recursive testing
    new File(tmpFolder, "fileA").createNewFile()
    Files.createDirectory(Paths.get(tmpFolder, "dir"))
    Files.createFile(Paths.get(tmpFolder, "dir", "fileA"))
    Files.createFile(Paths.get(tmpFolder, "dir", "fileB"))
    Files.createDirectory(Paths.get(tmpFolder, "dir", "dir2"))
    Files.createFile(Paths.get(tmpFolder, "dir", "dir2", "fileC"))
  }

  after {
    // delete the temp directory
    FileHelper.delete(tmpFolder)
  }

  "listFilesRecursive" should "throw exception if the file doesn't exist" taggedAs FastTest in {
    // Use a path inside the scratch directory so the test cannot be affected by
    // whatever happens to exist in the working directory.
    val missing = new File(tmpFolder, "doesNotExist")
    assertThrows[java.io.IOException] {
      ZipArchiveUtil.listFilesRecursive(missing)
    }
  }

  "listFilesRecursive" should "return a single item list if give a file" taggedAs FastTest in {
    val list = ZipArchiveUtil.listFilesRecursive(new File(tmpFolder, "fileA"))
    assert(list.length == 1)
    assert(list.head.equals(new File(tmpFolder, "fileA")))
  }

  "listFilesRecursive" should "return a single item list if give a file within folder" taggedAs FastTest in {
    val list = ZipArchiveUtil.listFilesRecursive(new File(tmpFolder, "dir/fileA"))
    assert(list.length == 1)
    assert(list.head.equals(new File(tmpFolder, "dir/fileA")))
  }

  "listFilesRecursive" should "return a list with 3 items if give the dir folder" taggedAs FastTest in {
    val list = ZipArchiveUtil.listFilesRecursive(new File(tmpFolder, "dir"))
    assert(list.length == 3)

    // Compared as sets because the listing order is not asserted by this test.
    list.toSet should contain theSameElementsAs Set(
      new File(tmpFolder, "dir/dir2/fileC"),
      new File(tmpFolder, "dir/fileA"),
      new File(tmpFolder, "dir/fileB"))
  }

  "addFileToZipEntry" should "return zip entry with absolute setting" taggedAs FastTest in {
    val zipEntry = ZipArchiveUtil.addFileToZipEntry(new File("fileA"), null, false)
    assert(zipEntry.getName == "fileA")
  }

  "addFileToZipEntry" should "return zip entry with relative setting" taggedAs FastTest in {
    val zipEntry = ZipArchiveUtil.addFileToZipEntry(new File("dir/fileA"), new File("dir"), true)
    assert(zipEntry.getName == "fileA")
  }

  "addFileToZipEntry" should "return zip entry full path with absolute setting" taggedAs FastTest in {
    val zipEntry = ZipArchiveUtil.addFileToZipEntry(new File("dir/fileA"), new File("dir"), false)
    assert(zipEntry.getName == "fileA")
  }

  "createZip" should "create zip for a single file" taggedAs FastTest in {
    ZipArchiveUtil.createZip(
      List(new File(tmpFolder, "dir/fileA")),
      new File(tmpFolder, "targetA.zip"),
      null)

    assert(new File(tmpFolder, "targetA.zip").exists())
  }

  "createZip" should "create zip" taggedAs FastTest in {
    ZipArchiveUtil.createZip(
      List(
        new File(Paths.get(tmpFolder, "dir", "fileA").toString),
        new File(Paths.get(tmpFolder, "dir", "fileB").toString)),
      new File(tmpFolder, "targetDir.zip"),
      new File(Paths.get(tmpFolder).toString))

    assert(new File(tmpFolder, "targetDir.zip").exists())
  }

  "zipFile" should "zip a single file" taggedAs FastTest in {
    ZipArchiveUtil.zipFile(
      new File(Paths.get(tmpFolder, "dir", "fileA").toString),
      new File(tmpFolder, "targetA.zip"))
    assert(new File(tmpFolder, "targetA.zip").exists())
  }

  "zipDir" should "zip a directory" taggedAs FastTest in {
    ZipArchiveUtil.zipDir(
      new File(Paths.get(tmpFolder, "dir").toString),
      new File(tmpFolder, "targetDir.zip"))
    assert(new File(tmpFolder, "targetDir.zip").exists())
  }

  "zip" should "zip a single file with String input" taggedAs FastTest in {
    ZipArchiveUtil.zip(
      Paths.get(tmpFolder, "dir", "fileA").toString,
      Paths.get(tmpFolder, "targetA.zip").toString)
    assert(new File(tmpFolder, "targetA.zip").exists())
  }

  "zip" should "zip a dir with String input" taggedAs FastTest in {
    ZipArchiveUtil.zip(
      Paths.get(tmpFolder, "dir").toString,
      Paths.get(tmpFolder, "targetDir.zip").toString)
    assert(new File(tmpFolder, "targetDir.zip").exists())
  }

  "zip" should "throw exception if the folder not exist since we are not responsible to create folders" taggedAs FastTest in {
    // "otherdir" is never created, so the target's parent directory is missing.
    assertThrows[java.io.IOException] {
      ZipArchiveUtil.zip(
        Paths.get(tmpFolder, "dir").toString,
        Paths.get(tmpFolder, "otherdir/targetDir.zip").toString)
    }
  }
}

0 comments on commit 37b24b9

Please sign in to comment.