diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift
index 88a665f..1562a59 100644
--- a/Sources/WhisperKit/Core/WhisperKit.swift
+++ b/Sources/WhisperKit/Core/WhisperKit.swift
@@ -266,12 +266,112 @@ open class WhisperKit {
             throw error
         }
     }
+
+    /// Resolves the folder of a model variant that is already present in the local model cache.
+    ///
+    /// Lists `<downloadBase>/models/<repo>` and filters the entries with the glob `*<variant>`. When more
+    /// than one model folder matches and `specificOnly` is `false`, the search is retried with the glob
+    /// `*openai*<variant>`.
+    ///
+    /// - Parameters:
+    ///   - variant: Model variant name, or its trailing part (for example `"tiny"`).
+    ///   - downloadBase: Root of the local model cache. When `nil`, the `huggingface` folder inside the
+    ///     user's documents directory is used.
+    ///   - repo: Repository the models belong to; used as the sub-path below `models`.
+    ///   - specificOnly: When `true`, more than one match for `*<variant>` throws instead of being
+    ///     retried with the `openai` prefix.
+    /// - Returns: The URL of the single model folder that matched.
+    /// - Throws: `WhisperError.modelsUnavailable` when the documents directory or the repo folder cannot be
+    ///   found, or when the search does not end with exactly one model folder. Errors raised while listing
+    ///   the repo folder are rethrown as well.
+    public static func modelLocation(
+        variant: String,
+        downloadBase: URL? = nil,
+        from repo: String = "argmaxinc/whisperkit-coreml",
+        specificOnly: Bool = false
+    ) throws -> URL {
+        do {
+            let saveRoot: URL
+            if let downloadBase {
+                saveRoot = downloadBase
+            } else {
+                // `first` is optional: report a missing documents directory instead of crashing on a force unwrap
+                guard let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first else {
+                    throw WhisperError.modelsUnavailable("Could not locate the documents directory")
+                }
+                saveRoot = documents.appending(component: "huggingface")
+            }
+            let modelSearchPath = "*\(variant)"
+            let repoDestination = saveRoot.appending(component: "models").appending(component: repo)
+
+            Logging.debug(repoDestination.absoluteString)
+            let isInstalled = try? repoDestination.resourceValues(forKeys: [.isDirectoryKey]).isDirectory
+            guard isInstalled == true else {
+                throw WhisperError.modelsUnavailable("Repo destination does not exist \(repoDestination.absoluteString)")
+            }
+            let dirFiles = try FileManager.default.contentsOfDirectory(atPath: repoDestination.path(percentEncoded: false))
+            // Only the first path component is kept, so every model folder is counted once
+            let uniquePaths = Set(dirFiles.matching(glob: modelSearchPath).compactMap { $0.components(separatedBy: "/").first })
+
+            let variantPath: String
+            if uniquePaths.isEmpty {
+                throw WhisperError.modelsUnavailable("Could not find model matching \"\(modelSearchPath)\"")
+            } else if uniquePaths.count == 1, let onlyMatch = uniquePaths.first {
+                variantPath = onlyMatch
+            } else if specificOnly {
+                // We only want the one specific model, and won't accept fuzzy fallbacks
+                throw WhisperError.modelsUnavailable("Multiple models found matching \"\(modelSearchPath)\" and specificOnly was set")
+            } else {
+                // If the model name search returns more than one unique model folder, then prepend the default "openai" prefix from whisperkittools to disambiguate
+                Logging.debug("No definitive model matching \"\(modelSearchPath)\"")
+                let adjustedModelSearchPath = "*openai*\(variant)"
+                Logging.debug("Searching for models matching \"\(adjustedModelSearchPath)\" in \(repo)")
+                let adjustedPaths = Set(dirFiles.matching(glob: adjustedModelSearchPath).compactMap { $0.components(separatedBy: "/").first })
+
+                guard adjustedPaths.count == 1, let onlyMatch = adjustedPaths.first else {
+                    // If there is still ambiguity, throw an error
+                    throw WhisperError.modelsUnavailable("Could not find definitive model matching \"\(modelSearchPath)\"")
+                }
+                variantPath = onlyMatch
+            }
+
+            return repoDestination.appending(path: variantPath)
+        } catch {
+            Logging.debug(error)
+            throw error
+        }
+    }
+
+    /// Reports whether a model variant is already present in the local model cache.
+    ///
+    /// - Parameters:
+    ///   - variant: Model variant name, forwarded to `modelLocation`.
+    ///   - downloadBase: Root of the local model cache, forwarded to `modelLocation`.
+    ///   - repo: Repository the models belong to, forwarded to `modelLocation`.
+    ///   - specificOnly: Forwarded to `modelLocation`.
+    /// - Returns: `true` when `modelLocation` resolves a folder and that folder is a directory, `false`
+    ///   otherwise. An error thrown by `modelLocation` is logged and reported as `false`.
+    public static func modelInstalled(
+        variant: String,
+        downloadBase: URL? = nil,
+        from repo: String = "argmaxinc/whisperkit-coreml",
+        specificOnly: Bool = false
+    ) -> Bool {
+        do {
+            let location = try modelLocation(variant: variant, downloadBase: downloadBase, from: repo, specificOnly: specificOnly)
+            let isInstalled = try? location.resourceValues(forKeys: [.isDirectoryKey]).isDirectory
+            return isInstalled ?? false
+        } catch {
+            Logging.error(error)
+            return false
+        }
+    }
 
     /// Sets up the model folder either from a local path or by downloading from a repository.
open func setupModels( model: String?, downloadBase: URL? = nil, - modelRepo: String?, + modelRepo: String? = nil, modelFolder: String?, download: Bool ) async throws { @@ -298,6 +373,12 @@ open class WhisperKit { Error: \(error) """) } + } else { + let modelSupport = WhisperKit.recommendedModels() + let modelVariant = model ?? modelSupport.default + let repo = modelRepo ?? "argmaxinc/whisperkit-coreml" + guard let folder = try? Self.modelLocation(variant: modelVariant, downloadBase: downloadBase, from: repo) else { return } + self.modelFolder = folder } } diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift index e633558..47298e5 100644 --- a/Tests/WhisperKitTests/UnitTests.swift +++ b/Tests/WhisperKitTests/UnitTests.swift @@ -31,6 +31,46 @@ final class UnitTests: XCTestCase { "Failed to init WhisperKit" ) } + func testModelIsInstalled() async throws { + XCTAssertTrue( + WhisperKit.modelInstalled(variant: "openai_whisper-tiny"), + "Model was not installed" + ) + XCTAssertFalse( + WhisperKit.modelInstalled(variant: "THIS_MODEL_DOES_NOT_EXIST"), + "Model does not exist, but returned true" + ) + XCTAssertTrue( + WhisperKit.modelInstalled(variant: "tiny"), + "Model was not installed" + ) + XCTAssertTrue( + WhisperKit.modelInstalled(variant: "tiny.en"), + "Model was not installed" + ) + let location = try WhisperKit.modelLocation(variant: "tiny") + let location2 = try WhisperKit.modelLocation(variant: "openai_whisper-tiny") + XCTAssertEqual(location, location2, "Auto fallback to OpenAI model returned a different result") + } + func testOfflineMode() async throws { + + let pipe = try await WhisperKit(WhisperKitConfig(model: "tiny.en", download: false)) + + let cancellable: AnyCancellable? 
= pipe.progress.publisher(for: \.fractionCompleted) + .removeDuplicates() + .withPrevious() + .sink { previous, current in + if let previous { + XCTAssertLessThan(previous, current) + } + } + _ = try await pipe.transcribe( + audioPath: Bundle.current.path(forResource: "ted_60", ofType: "m4a")!, + decodeOptions: .init(chunkingStrategy: .vad) + ) + cancellable?.cancel() + + } // MARK: - Config Tests