Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add repo option to regression test matrix #293

Merged
merged 4 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Examples/WhisperAX/Debug.xcconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Run `make setup` to add your team here
DEVELOPMENT_TEAM=
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@
value = "$(MODEL_NAME)"
isEnabled = "YES">
</EnvironmentVariable>
<EnvironmentVariable
key = "MODEL_REPO"
value = "$(MODEL_REPO)"
isEnabled = "YES">
</EnvironmentVariable>
</EnvironmentVariables>
</LaunchAction>
<ProfileAction
Expand Down
4 changes: 4 additions & 0 deletions Sources/WhisperKit/Core/Configurations.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ open class WhisperKitConfig {
public var downloadBase: URL?
/// Repository for downloading models
public var modelRepo: String?
/// Token for downloading models from repo (if required)
public var modelToken: String?

/// Folder to store models
public var modelFolder: String?
Expand Down Expand Up @@ -47,6 +49,7 @@ open class WhisperKitConfig {
public init(model: String? = nil,
downloadBase: URL? = nil,
modelRepo: String? = nil,
modelToken: String? = nil,
modelFolder: String? = nil,
tokenizerFolder: URL? = nil,
computeOptions: ModelComputeOptions? = nil,
Expand All @@ -67,6 +70,7 @@ open class WhisperKitConfig {
self.model = model
self.downloadBase = downloadBase
self.modelRepo = modelRepo
self.modelToken = modelToken
self.modelFolder = modelFolder
self.tokenizerFolder = tokenizerFolder
self.computeOptions = computeOptions
Expand Down
6 changes: 4 additions & 2 deletions Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ open class WhisperKit {
model: config.model,
downloadBase: config.downloadBase,
modelRepo: config.modelRepo,
modelToken: config.modelToken,
modelFolder: config.modelFolder,
download: config.download
)


if let prewarm = config.prewarm, prewarm {
Logging.info("Prewarming models...")
Expand Down Expand Up @@ -295,6 +295,7 @@ open class WhisperKit {
model: String?,
downloadBase: URL? = nil,
modelRepo: String?,
modelToken: String? = nil,
modelFolder: String?,
download: Bool
) async throws {
Expand All @@ -312,7 +313,8 @@ open class WhisperKit {
variant: modelVariant,
downloadBase: downloadBase,
useBackgroundSession: useBackgroundDownloadSession,
from: repo
from: repo,
token: modelToken
)
} catch {
// Handle errors related to model downloading
Expand Down
6 changes: 6 additions & 0 deletions Tests/WhisperKitTests/RegressionTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class TestInfo: JSONCodable {
let datasetDir: String
let datasetRepo: String
let model: String
let modelRepo: String
let modelSizeMB: Double
let date: String
let timeElapsedInSeconds: TimeInterval
Expand All @@ -69,6 +70,7 @@ class TestInfo: JSONCodable {
datasetDir: String,
datasetRepo: String,
model: String,
modelRepo: String,
modelSizeMB: Double,
date: String,
timeElapsedInSeconds: TimeInterval,
Expand All @@ -83,6 +85,7 @@ class TestInfo: JSONCodable {
self.datasetDir = datasetDir
self.datasetRepo = datasetRepo
self.model = model
self.modelRepo = modelRepo
self.modelSizeMB = modelSizeMB
self.date = date
self.timeElapsedInSeconds = timeElapsedInSeconds
Expand All @@ -101,6 +104,7 @@ struct TestReport: JSONCodable {
let osType: String
let osVersion: String
let modelsTested: [String]
let modelReposTested: [String]
let failureInfo: [String: String]
let attachments: [String: String]

Expand All @@ -109,13 +113,15 @@ struct TestReport: JSONCodable {
osType: String,
osVersion: String,
modelsTested: [String],
modelReposTested: [String],
failureInfo: [String: String],
attachments: [String: String]
) {
self.deviceModel = deviceModel
self.osType = osType
self.osVersion = osVersion
self.modelsTested = modelsTested
self.modelReposTested = modelReposTested
self.failureInfo = failureInfo
self.attachments = attachments
}
Expand Down
78 changes: 56 additions & 22 deletions Tests/WhisperKitTests/RegressionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,22 @@ import WatchKit
#endif

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
final class RegressionTests: XCTestCase {
class RegressionTests: XCTestCase {
var audioFileURLs: [URL]?
var remoteFileURLs: [URL]?
var metadataURL: URL?
var testWERURLs: [URL]?
var modelsToTest: [String] = []
var modelReposToTest: [String] = []
var modelsTested: [String] = []
var modelReposTested: [String] = []
var optionsToTest: [DecodingOptions] = [DecodingOptions()]

struct TestConfig {
let dataset: String
let modelComputeOptions: ModelComputeOptions
var model: String
var modelRepo: String
let decodingOptions: DecodingOptions
}

Expand All @@ -34,6 +37,7 @@ final class RegressionTests: XCTestCase {
var datasets = ["librispeech-10mins", "earnings22-10mins"]
let debugDataset = ["earnings22-10mins"]
let debugModels = ["tiny"]
let debugRepos = ["argmaxinc/whisperkit-coreml"]

var computeOptions: [ModelComputeOptions] = [
ModelComputeOptions(audioEncoderCompute: .cpuAndNeuralEngine, textDecoderCompute: .cpuAndNeuralEngine),
Expand Down Expand Up @@ -71,16 +75,29 @@ final class RegressionTests: XCTestCase {
Logging.debug("Max memory before warning: \(maxMemory)")
}

func testEnvConfigurations(defaultModels: [String]? = nil) {
class func getModelToken() -> String? {
// Add token here or override
return nil
}

func testEnvConfigurations(defaultModels: [String]? = nil, defaultRepos: [String]? = nil) {
if let modelSizeEnv = ProcessInfo.processInfo.environment["MODEL_NAME"], !modelSizeEnv.isEmpty {
modelsToTest = [modelSizeEnv]
Logging.debug("Model size: \(modelSizeEnv)")

if let repoEnv = ProcessInfo.processInfo.environment["MODEL_REPO"] {
modelReposToTest = [repoEnv]
Logging.debug("Using repo: \(repoEnv)")
}

XCTAssertTrue(modelsToTest.count > 0, "Invalid model size: \(modelSizeEnv)")

if modelSizeEnv == "crash_test" {
fatalError("Crash test triggered")
}
} else {
modelsToTest = defaultModels ?? debugModels
modelReposToTest = defaultRepos ?? debugRepos
Logging.debug("Model size not set by env")
}
}
Expand Down Expand Up @@ -116,7 +133,7 @@ final class RegressionTests: XCTestCase {

// MARK: - Test Pipeline

private func runRegressionTests(with testMatrix: [TestConfig]) async throws {
public func runRegressionTests(with testMatrix: [TestConfig]) async throws {
var failureInfo: [String: String] = [:]
var attachments: [String: String] = [:]
let device = getCurrentDevice()
Expand Down Expand Up @@ -159,8 +176,7 @@ final class RegressionTests: XCTestCase {

// Create WhisperKit instance with checks for memory usage
let whisperKit = try await createWithMemoryCheck(
model: config.model,
computeOptions: config.modelComputeOptions,
testConfig: config,
verbose: true,
logLevel: .debug
)
Expand All @@ -169,6 +185,8 @@ final class RegressionTests: XCTestCase {
config.model = modelFile
modelsTested.append(modelFile)
modelsTested = Array(Set(modelsTested))
modelReposTested.append(config.modelRepo)
modelReposTested = Array(Set(modelReposTested))
}

for audioFilePath in audioFilePaths {
Expand Down Expand Up @@ -295,6 +313,7 @@ final class RegressionTests: XCTestCase {
datasetDir: config.dataset,
datasetRepo: datasetRepo,
model: config.model,
modelRepo: config.modelRepo,
modelSizeMB: modelSizeMB ?? -1,
date: startTime.formatted(Date.ISO8601FormatStyle().dateSeparator(.dash)),
timeElapsedInSeconds: Date().timeIntervalSince(startTime),
Expand Down Expand Up @@ -432,20 +451,23 @@ final class RegressionTests: XCTestCase {
}
}

private func getTestMatrix() -> [TestConfig] {
public func getTestMatrix() -> [TestConfig] {
var regressionTestConfigMatrix: [TestConfig] = []
for dataset in datasets {
for computeOption in computeOptions {
for options in optionsToTest {
for model in modelsToTest {
regressionTestConfigMatrix.append(
TestConfig(
dataset: dataset,
modelComputeOptions: computeOption,
model: model,
decodingOptions: options
for repo in modelReposToTest {
for model in modelsToTest {
regressionTestConfigMatrix.append(
TestConfig(
dataset: dataset,
modelComputeOptions: computeOption,
model: model,
modelRepo: repo,
decodingOptions: options
)
)
)
}
}
}
}
Expand Down Expand Up @@ -555,6 +577,7 @@ final class RegressionTests: XCTestCase {
osType: osDetails.osType,
osVersion: osDetails.osVersion,
modelsTested: modelsTested,
modelReposTested: modelReposTested,
failureInfo: failureInfo,
attachments: attachments
)
Expand Down Expand Up @@ -610,17 +633,14 @@ final class RegressionTests: XCTestCase {
return Double(modelSize / (1024 * 1024)) // Convert to MB
}

func createWithMemoryCheck(
model: String,
computeOptions: ModelComputeOptions,
verbose: Bool,
logLevel: Logging.LogLevel
) async throws -> WhisperKit {
public func initWhisperKitTask(testConfig config: TestConfig, verbose: Bool, logLevel: Logging.LogLevel) -> Task<WhisperKit, Error> {
// Create the initialization task
let initializationTask = Task { () -> WhisperKit in
let whisperKit = try await WhisperKit(WhisperKitConfig(
model: model,
computeOptions: computeOptions,
model: config.model,
modelRepo: config.modelRepo,
modelToken: Self.getModelToken(),
computeOptions: config.modelComputeOptions,
verbose: verbose,
logLevel: logLevel,
prewarm: true,
Expand All @@ -629,6 +649,20 @@ final class RegressionTests: XCTestCase {
try Task.checkCancellation()
return whisperKit
}
return initializationTask
}

func createWithMemoryCheck(
testConfig: TestConfig,
verbose: Bool,
logLevel: Logging.LogLevel
) async throws -> WhisperKit {
// Create the initialization task
let initializationTask = initWhisperKitTask(
testConfig: testConfig,
verbose: verbose,
logLevel: logLevel
)

// Start the memory monitoring task
let monitorTask = Task {
Expand Down
16 changes: 12 additions & 4 deletions fastlane/Fastfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ BASE_BENCHMARK_PATH = "#{WORKING_DIR}/benchmark_data".freeze
BASE_UPLOAD_PATH = "#{WORKING_DIR}/upload_folder".freeze
XCRESULT_PATH = File.expand_path("#{BASE_BENCHMARK_PATH}/#{COMMIT_TIMESTAMP}_#{COMMIT_HASH}/")
BENCHMARK_REPO = 'argmaxinc/whisperkit-evals-dataset'.freeze
BENCHMARK_CONFIGS = {
BENCHMARK_CONFIGS ||= {
full: {
test_identifier: 'WhisperAXTests/RegressionTests/testModelPerformance',
name: 'full',
Expand All @@ -50,12 +50,14 @@ BENCHMARK_CONFIGS = {
'openai_whisper-large-v3-v20240930_turbo',
'openai_whisper-large-v3-v20240930_626MB',
'openai_whisper-large-v3-v20240930_turbo_632MB'
]
],
repo: 'argmaxinc/whisperkit-coreml'
},
debug: {
test_identifier: 'WhisperAXTests/RegressionTests/testModelPerformanceWithDebugConfig',
name: 'debug',
models: ['tiny', 'crash_test', 'unknown_model', 'small.en']
models: ['tiny', 'crash_test', 'unknown_model', 'small.en'],
repo: 'argmaxinc/whisperkit-coreml'
}
}.freeze

Expand Down Expand Up @@ -200,7 +202,9 @@ end

def run_benchmark(devices, config)
summaries = []
BENCHMARK_CONFIGS[config][:models].each do |model|
config_data = BENCHMARK_CONFIGS[config]

config_data[:models].each do |model|
begin
# Sanitize device name for use in file path
devices_to_test = devices.map { |device_info| device_info[:name] }.compact
Expand Down Expand Up @@ -228,8 +232,12 @@ def run_benchmark(devices, config)
UI.message "Running in #{BENCHMARK_CONFIGS[config][:name]} mode"

UI.message "Running benchmark for model: #{model}"
UI.message 'Using Hugging Face:'
UI.message " • Repository: #{config_data[:repo]}"

xcargs = [
"MODEL_NAME=#{model}",
"MODEL_REPO=#{config_data[:repo]}",
'-allowProvisioningUpdates',
'-allowProvisioningDeviceRegistration'
].join(' ')
Expand Down
Loading