Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add convenience methods to detect if a model is installed #236

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -266,12 +266,87 @@ open class WhisperKit {
throw error
}
}
public static func modelLocation(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should add a new line above this for code style. Will want to add docc for all of the functions in this file eventually but not a blocker.

variant: String,
downloadBase: URL? = nil,
from repo: String = "argmaxinc/whisperkit-coreml",
specificOnly: Bool = false
) throws -> URL {
let saveRoot: URL
if let downloadBase {
saveRoot = downloadBase
} else {
let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first!
saveRoot = documents.appending(component: "huggingface")
}
let modelSearchPath = "*\(variant.description)"
let repoDestination = saveRoot.appending(component: "models").appending(component: repo)

do {
Logging.debug(repoDestination.absoluteString)
let isInstalled = try? repoDestination.resourceValues(forKeys: [.isDirectoryKey]).isDirectory
guard let isInstalled, isInstalled else {
throw WhisperError.modelsUnavailable("Repo destination does not exist \(repoDestination.absoluteString)")
}
let dirFiles = try FileManager.default.contentsOfDirectory(atPath: repoDestination.path(percentEncoded: false))
let modelFiles = dirFiles.matching(glob: modelSearchPath)
var uniquePaths = Set(modelFiles.map { $0.components(separatedBy: "/").first! })

var variantPath: String? = nil

if uniquePaths.count == 0 {
throw WhisperError.modelsUnavailable("Could not find model matching \"\(modelSearchPath)\"")
} else if uniquePaths.count == 1 {
variantPath = uniquePaths.first
} else if specificOnly {
// We only want the one specific model, and won't accept fuzzy fallbacks
throw WhisperError.modelsUnavailable("Multiple models found matching \"\(modelSearchPath)\" and specificOnly was set")
} else {
// If the model name search returns more than one unique model folder, then prepend the default "openai" prefix from whisperkittools to disambiguate
Logging.debug("No definitive model matching \"\(modelSearchPath)\"")
let adjustedModelSearchPath = "*openai*\(variant.description)"
Logging.debug("Searching for models matching \"\(adjustedModelSearchPath)\" in \(repo)")
let adjustedModelFiles = dirFiles.matching(glob: adjustedModelSearchPath)
uniquePaths = Set(adjustedModelFiles.map { $0.components(separatedBy: "/").first! })

if uniquePaths.count == 1 {
variantPath = uniquePaths.first
}
}

guard let variantPath else {
// If there is still ambiguity, throw an error
throw WhisperError.modelsUnavailable("Could not find definitive model matching \"\(modelSearchPath)\"")
}

let truePath = repoDestination.appending(path: variantPath)
return truePath
} catch {
Logging.debug(error)
throw error
}
}
public static func modelInstalled(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New line above this as well

variant: String,
downloadBase: URL? = nil,
from repo: String = "argmaxinc/whisperkit-coreml",
specificOnly: Bool = false
) -> Bool {
do {
let location = try modelLocation(variant: variant, downloadBase: downloadBase, from: repo, specificOnly: specificOnly)
let isInstalled = try? location.resourceValues(forKeys: [.isDirectoryKey]).isDirectory
return isInstalled ?? false
} catch {
Logging.error(error)
return false
}
}

/// Sets up the model folder either from a local path or by downloading from a repository.
open func setupModels(
model: String?,
downloadBase: URL? = nil,
modelRepo: String?,
modelRepo: String? = nil,
modelFolder: String?,
download: Bool
) async throws {
Expand All @@ -298,6 +373,12 @@ open class WhisperKit {
Error: \(error)
""")
}
} else {
let modelSupport = WhisperKit.recommendedModels()
let modelVariant = model ?? modelSupport.default
let repo = modelRepo ?? "argmaxinc/whisperkit-coreml"
guard let folder = try? Self.modelLocation(variant: modelVariant, downloadBase: downloadBase, from: repo) else { return }
self.modelFolder = folder
}
}

Expand Down
40 changes: 40 additions & 0 deletions Tests/WhisperKitTests/UnitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,46 @@ final class UnitTests: XCTestCase {
"Failed to init WhisperKit"
)
}
func testModelIsInstalled() async throws {
XCTAssertTrue(
WhisperKit.modelInstalled(variant: "openai_whisper-tiny"),
"Model was not installed"
)
XCTAssertFalse(
WhisperKit.modelInstalled(variant: "THIS_MODEL_DOES_NOT_EXIST"),
"Model does not exist, but returned true"
)
XCTAssertTrue(
WhisperKit.modelInstalled(variant: "tiny"),
"Model was not installed"
)
XCTAssertTrue(
WhisperKit.modelInstalled(variant: "tiny.en"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll need some method to download this one ahead of time before checking it is installed.

run: make download-model MODEL=tiny
you can do that by either adding a line make download-model MODEL=tiny.en here to download it, or just removing this tiny.en check and only running this test for tiny.

"Model was not installed"
)
let location = try WhisperKit.modelLocation(variant: "tiny")
let location2 = try WhisperKit.modelLocation(variant: "openai_whisper-tiny")
XCTAssertEqual(location, location2, "Auto fallback to OpenAI model returned a different result")
}
func testOfflineMode() async throws {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New line above this one


let pipe = try await WhisperKit(WhisperKitConfig(model: "tiny.en", download: false))

let cancellable: AnyCancellable? = pipe.progress.publisher(for: \.fractionCompleted)
.removeDuplicates()
.withPrevious()
.sink { previous, current in
if let previous {
XCTAssertLessThan(previous, current)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! 👍

}
}
_ = try await pipe.transcribe(
audioPath: Bundle.current.path(forResource: "ted_60", ofType: "m4a")!,
decodeOptions: .init(chunkingStrategy: .vad)
)
cancellable?.cancel()

}

// MARK: - Config Tests

Expand Down
Loading