-
Notifications
You must be signed in to change notification settings - Fork 365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add convenience methods to detect if a model is installed #236
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -266,12 +266,87 @@ open class WhisperKit { | |
throw error | ||
} | ||
} | ||
/// Resolves the on-disk folder of an already downloaded model variant.
///
/// The search root is `downloadBase` when provided, otherwise a `huggingface` folder inside the
/// user's documents directory. The entries of `<root>/models/<repo>` are matched against the glob
/// `*<variant>`. When more than one folder matches and `specificOnly` is `false`, the search is
/// retried with the glob `*openai*<variant>` to disambiguate.
///
/// - Parameters:
///   - variant: Model variant name (or name suffix) to look for.
///   - downloadBase: Optional root folder to search instead of the default documents location.
///   - repo: Repository path appended below the `models` folder.
///   - specificOnly: When `true`, multiple matches are an error and no "openai" fallback is attempted.
/// - Returns: The URL of the single matching model folder.
/// - Throws: `WhisperError.modelsUnavailable` when the documents directory cannot be resolved, the
///   repo folder is missing or not a directory, nothing matches, or the match stays ambiguous.
///   Errors from listing the repo folder are rethrown. Every error is logged at debug level first.
public static func modelLocation(
    variant: String,
    downloadBase: URL? = nil,
    from repo: String = "argmaxinc/whisperkit-coreml",
    specificOnly: Bool = false
) throws -> URL {
    do {
        let saveRoot: URL
        if let downloadBase {
            saveRoot = downloadBase
        } else {
            // Throw instead of force-unwrapping so a missing documents directory cannot crash the caller.
            guard let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first else {
                throw WhisperError.modelsUnavailable("Could not locate the user documents directory")
            }
            saveRoot = documents.appending(component: "huggingface")
        }

        let modelSearchPath = "*\(variant)"
        let repoDestination = saveRoot.appending(component: "models").appending(component: repo)
        Logging.debug(repoDestination.absoluteString)

        let repoIsDirectory = (try? repoDestination.resourceValues(forKeys: [.isDirectoryKey]).isDirectory) ?? false
        guard repoIsDirectory else {
            throw WhisperError.modelsUnavailable("Repo destination does not exist \(repoDestination.absoluteString)")
        }

        let dirFiles = try FileManager.default.contentsOfDirectory(atPath: repoDestination.path(percentEncoded: false))

        // Collapses matched paths to their first path component so each model folder is counted once.
        let topLevelFolders: ([String]) -> Set<String> = { paths in
            Set(paths.compactMap { $0.components(separatedBy: "/").first })
        }

        let candidates = topLevelFolders(dirFiles.matching(glob: modelSearchPath))
        let variantPath: String?

        if candidates.isEmpty {
            throw WhisperError.modelsUnavailable("Could not find model matching \"\(modelSearchPath)\"")
        } else if candidates.count == 1 {
            variantPath = candidates.first
        } else if specificOnly {
            // We only want the one specific model, and won't accept fuzzy fallbacks
            throw WhisperError.modelsUnavailable("Multiple models found matching \"\(modelSearchPath)\" and specificOnly was set")
        } else {
            // If the model name search returns more than one unique model folder, then prepend the
            // default "openai" prefix from whisperkittools to disambiguate
            Logging.debug("No definitive model matching \"\(modelSearchPath)\"")
            let adjustedModelSearchPath = "*openai*\(variant)"
            Logging.debug("Searching for models matching \"\(adjustedModelSearchPath)\" in \(repo)")
            let adjustedCandidates = topLevelFolders(dirFiles.matching(glob: adjustedModelSearchPath))
            variantPath = adjustedCandidates.count == 1 ? adjustedCandidates.first : nil
        }

        guard let variantPath else {
            // If there is still ambiguity, throw an error
            throw WhisperError.modelsUnavailable("Could not find definitive model matching \"\(modelSearchPath)\"")
        }

        return repoDestination.appending(path: variantPath)
    } catch {
        Logging.debug(error)
        throw error
    }
}
/// Reports whether a model variant is present on disk.
///
/// Resolves the variant with `modelLocation(variant:downloadBase:from:specificOnly:)` and confirms
/// that the resolved URL is a directory. A resolution failure is logged as an error and reported
/// as "not installed" rather than thrown.
///
/// - Parameters:
///   - variant: Model variant name (or name suffix) to look for.
///   - downloadBase: Optional root folder forwarded to `modelLocation`.
///   - repo: Repository path forwarded to `modelLocation`.
///   - specificOnly: Forwarded to `modelLocation`.
/// - Returns: `true` when the variant resolves to an existing directory, otherwise `false`.
public static func modelInstalled(
    variant: String,
    downloadBase: URL? = nil,
    from repo: String = "argmaxinc/whisperkit-coreml",
    specificOnly: Bool = false
) -> Bool {
    let resolvedFolder: URL
    do {
        resolvedFolder = try modelLocation(
            variant: variant,
            downloadBase: downloadBase,
            from: repo,
            specificOnly: specificOnly
        )
    } catch {
        Logging.error(error)
        return false
    }

    // A failed resource lookup counts as "not installed".
    let resourceValues = try? resolvedFolder.resourceValues(forKeys: [.isDirectoryKey])
    return resourceValues?.isDirectory ?? false
}
|
||
/// Sets up the model folder either from a local path or by downloading from a repository. | ||
open func setupModels( | ||
model: String?, | ||
downloadBase: URL? = nil, | ||
modelRepo: String?, | ||
modelRepo: String? = nil, | ||
modelFolder: String?, | ||
download: Bool | ||
) async throws { | ||
|
@@ -298,6 +373,12 @@ open class WhisperKit { | |
Error: \(error) | ||
""") | ||
} | ||
} else { | ||
let modelSupport = WhisperKit.recommendedModels() | ||
let modelVariant = model ?? modelSupport.default | ||
let repo = modelRepo ?? "argmaxinc/whisperkit-coreml" | ||
guard let folder = try? Self.modelLocation(variant: modelVariant, downloadBase: downloadBase, from: repo) else { return } | ||
self.modelFolder = folder | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -31,6 +31,46 @@ final class UnitTests: XCTestCase { | |||
"Failed to init WhisperKit" | ||||
) | ||||
} | ||||
/// Checks `WhisperKit.modelInstalled` and `WhisperKit.modelLocation` against the "tiny" model:
/// the full name and the short name must both be detected and resolve to the same location,
/// while an unknown variant must be reported as missing.
func testModelIsInstalled() async throws {
    XCTAssertTrue(
        WhisperKit.modelInstalled(variant: "openai_whisper-tiny"),
        "Model was not installed"
    )
    XCTAssertFalse(
        WhisperKit.modelInstalled(variant: "THIS_MODEL_DOES_NOT_EXIST"),
        "Model does not exist, but returned true"
    )
    XCTAssertTrue(
        WhisperKit.modelInstalled(variant: "tiny"),
        "Model was not installed"
    )
    // NOTE(review): "tiny.en" is deliberately not asserted here. Nothing in this test downloads it,
    // so the check would depend on the machine's state. Re-add it only together with a step that
    // provisions the model first (e.g. `make download-model MODEL=tiny.en`).

    let location = try WhisperKit.modelLocation(variant: "tiny")
    let location2 = try WhisperKit.modelLocation(variant: "openai_whisper-tiny")
    XCTAssertEqual(location, location2, "Auto fallback to OpenAI model returned a different result")
}
/// Builds a pipeline with `download: false`, transcribes a bundled audio file, and checks that
/// every newly published `fractionCompleted` value is greater than the previous one.
func testOfflineMode() async throws {
    let pipe = try await WhisperKit(WhisperKitConfig(model: "tiny.en", download: false))

    let cancellable: AnyCancellable? = pipe.progress.publisher(for: \.fractionCompleted)
        .removeDuplicates()
        .withPrevious()
        .sink { previous, current in
            // The first emission has no predecessor to compare against.
            if let previous {
                XCTAssertLessThan(previous, current)
            }
        }
    // Tear the subscription down on every exit path, including a throwing transcription.
    defer { cancellable?.cancel() }

    // Fail the test (rather than crash the test process) when the audio fixture is missing.
    let audioPath = try XCTUnwrap(
        Bundle.current.path(forResource: "ted_60", ofType: "m4a"),
        "Missing test resource ted_60.m4a"
    )
    _ = try await pipe.transcribe(
        audioPath: audioPath,
        decodeOptions: .init(chunkingStrategy: .vad)
    )
}
|
||||
// MARK: - Config Tests | ||||
|
||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A blank line should be added above this for code style. We will eventually want to add DocC documentation for all of the functions in this file, but that is not a blocker.