diff --git a/src/main/scala/com/johnsnowlabs/ml/gguf/GGUFWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/gguf/GGUFWrapper.scala
index 495e8cb2a6b0f9..ef7091c3b5cd12 100644
--- a/src/main/scala/com/johnsnowlabs/ml/gguf/GGUFWrapper.scala
+++ b/src/main/scala/com/johnsnowlabs/ml/gguf/GGUFWrapper.scala
@@ -16,6 +16,8 @@ package com.johnsnowlabs.ml.gguf
 
 import com.johnsnowlabs.nlp.llama.{LlamaModel, ModelParameters}
+import com.johnsnowlabs.nlp.util.io.ResourceHelper
+import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.spark.SparkFiles
 import org.apache.spark.sql.SparkSession
 import org.slf4j.{Logger, LoggerFactory}
 
@@ -72,7 +74,7 @@ object GGUFWrapper {
   // TODO: make sure this.synchronized is needed or it's not a bottleneck
   private def withSafeGGUFModelLoader(modelParameters: ModelParameters): LlamaModel =
     this.synchronized {
-      new LlamaModel(modelParameters) // TODO: Model parameters
+      new LlamaModel(modelParameters)
     }
 
   def read(sparkSession: SparkSession, modelPath: String): GGUFWrapper = {
@@ -89,4 +91,31 @@ object GGUFWrapper {
 
     new GGUFWrapper(modelFile.getName, modelFile.getParent)
   }
+
+  def readModel(modelFolderPath: String, spark: SparkSession): GGUFWrapper = {
+    def findGGUFModelInFolder(folderPath: String): String = {
+      val folder = new File(folderPath)
+      if (folder.exists && folder.isDirectory) {
+        val ggufFile: String = folder.listFiles
+          .filter(_.isFile)
+          .filter(_.getName.endsWith(".gguf"))
+          .map(_.getAbsolutePath)
+          .headOption // Should only be one file
+          .getOrElse(
+            throw new IllegalArgumentException(s"Could not find GGUF model in $folderPath"))
+
+        new File(ggufFile).getAbsolutePath
+      } else {
+        throw new IllegalArgumentException(s"Path $folderPath is not a directory")
+      }
+    }
+
+    val uri = new java.net.URI(modelFolderPath.replaceAllLiterally("\\", "/"))
+    // In case the path belongs to a different file system but doesn't have the scheme prepended (e.g. dbfs)
+    val fileSystem: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
+    val actualFolderPath = fileSystem.resolvePath(new Path(modelFolderPath)).toString
+    val localFolder = ResourceHelper.copyToLocal(actualFolderPath)
+    val modelFile = findGGUFModelInFolder(localFolder)
+    read(spark, modelFile)
+  }
 }
diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFModel.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFModel.scala
index 20c9053b361d7b..385b9ddc0e983d 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFModel.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/AutoGGUFModel.scala
@@ -235,25 +235,8 @@ trait ReadAutoGGUFModel {
   this: ParamsAndFeaturesReadable[AutoGGUFModel] =>
 
   def readModel(instance: AutoGGUFModel, path: String, spark: SparkSession): Unit = {
-    def findGGUFModelInFolder(): String = {
-      val folder =
-        new java.io.File(
-          path.replace("file:", "")
-        ) // File should be local at this point. TODO: Except if its HDFS?
-      if (folder.exists && folder.isDirectory) {
-        folder.listFiles
-          .filter(_.isFile)
-          .filter(_.getName.endsWith(".gguf"))
-          .map(_.getAbsolutePath)
-          .headOption // Should only be one file
-          .getOrElse(throw new IllegalArgumentException(s"Could not find GGUF model in $path"))
-      } else {
-        throw new IllegalArgumentException(s"Path $path is not a directory")
-      }
-    }
-
-    val model = AutoGGUFModel.loadSavedModel(findGGUFModelInFolder(), spark)
-    instance.setModelIfNotSet(spark, model.getModelIfNotSet)
+    val model: GGUFWrapper = GGUFWrapper.readModel(path, spark)
+    instance.setModelIfNotSet(spark, model)
   }
 
   addReader(readModel)