Skip to content

Commit

Permalink
Fix pretrained models not being found on dbfs systems
Browse files: browse the repository at this point in the history
  • Loading branch information
DevinTDHa committed Oct 19, 2024
1 parent 9db3332 commit 6d7ae78
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
31 changes: 30 additions & 1 deletion src/main/scala/com/johnsnowlabs/ml/gguf/GGUFWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package com.johnsnowlabs.ml.gguf

import com.johnsnowlabs.nlp.llama.{LlamaModel, ModelParameters}
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkFiles
import org.apache.spark.sql.SparkSession
import org.slf4j.{Logger, LoggerFactory}
Expand Down Expand Up @@ -72,7 +74,7 @@ object GGUFWrapper {
// TODO: make sure this.synchronized is needed or it's not a bottleneck
/** Creates a [[LlamaModel]] from the given parameters.
  *
  * Construction happens inside `this.synchronized`, so concurrent callers of this method are
  * serialized on this object's monitor. Exactly one model instance is created per call.
  *
  * @param modelParameters
  *   Parameters handed to the [[LlamaModel]] constructor
  * @return
  *   The newly constructed model
  */
private def withSafeGGUFModelLoader(modelParameters: ModelParameters): LlamaModel =
  this.synchronized {
    new LlamaModel(modelParameters)
  }

def read(sparkSession: SparkSession, modelPath: String): GGUFWrapper = {
Expand All @@ -89,4 +91,31 @@ object GGUFWrapper {

new GGUFWrapper(modelFile.getName, modelFile.getParent)
}

/** Loads a GGUF model from a folder that may live on a non-local file system.
  *
  * The folder path is resolved against the Hadoop file system selected for it, passed to
  * `ResourceHelper.copyToLocal`, and the first `.gguf` file found in the resulting folder is
  * handed to [[read]].
  *
  * @param modelFolderPath
  *   Folder containing the GGUF model file; backslashes are accepted as separators
  * @param spark
  *   Session whose Hadoop configuration is used to look up the file system
  * @return
  *   Wrapper for the GGUF model found in the folder
  * @throws IllegalArgumentException
  *   If the local copy is not a directory or contains no `.gguf` file
  */
def readModel(modelFolderPath: String, spark: SparkSession): GGUFWrapper = {
  def findGGUFModelInFolder(folderPath: String): String = {
    val folder = new File(folderPath)
    // isDirectory is false for non-existent paths, so no separate exists check is needed
    if (!folder.isDirectory)
      throw new IllegalArgumentException(s"Path $folderPath is not a directory")

    // listFiles returns null on I/O errors; treat that the same as an empty folder
    Option(folder.listFiles)
      .getOrElse(Array.empty[File])
      .collectFirst { // Should only be one file
        case file if file.isFile && file.getName.endsWith(".gguf") => file.getAbsolutePath
      }
      .getOrElse(
        throw new IllegalArgumentException(s"Could not find GGUF model in $folderPath"))
  }

  val uri = new java.net.URI(modelFolderPath.replace("\\", "/"))
  // In case the path belongs to a different file system but doesn't have the scheme prepended (e.g. dbfs)
  val fileSystem: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
  val actualFolderPath = fileSystem.resolvePath(new Path(modelFolderPath)).toString
  val localFolder = ResourceHelper.copyToLocal(actualFolderPath)
  val modelFile = findGGUFModelInFolder(localFolder)
  read(spark, modelFile)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -235,25 +235,8 @@ trait ReadAutoGGUFModel {
this: ParamsAndFeaturesReadable[AutoGGUFModel] =>

/** Reader hook that restores the GGUF model of a saved [[AutoGGUFModel]].
  *
  * Locating the `.gguf` file is delegated to `GGUFWrapper.readModel`, which resolves the folder
  * through the Hadoop file system instead of assuming a plain local path. The loaded wrapper
  * is then attached to the instance via `setModelIfNotSet`.
  *
  * @param instance
  *   Annotator instance that receives the loaded model
  * @param path
  *   Folder the model was saved to
  * @param spark
  *   Active Spark session
  */
def readModel(instance: AutoGGUFModel, path: String, spark: SparkSession): Unit = {
  val model: GGUFWrapper = GGUFWrapper.readModel(path, spark)
  instance.setModelIfNotSet(spark, model)
}

addReader(readModel)
Expand Down

0 comments on commit 6d7ae78

Please sign in to comment.