
Commit

initial commit of vlm
- based on models from https://github.com/Blaizzy/mlx-vlm
- for #132
davidkoski committed Nov 18, 2024
1 parent 7baf9bc commit 3230887
Showing 35 changed files with 1,956 additions and 1,699 deletions.
37 changes: 7 additions & 30 deletions Libraries/LLM/Configuration.swift
@@ -1,30 +1,7 @@
// Copyright © 2024 Apple Inc.

import Foundation

public enum StringOrNumber: Codable, Equatable, Sendable {
case string(String)
case float(Float)

public init(from decoder: Decoder) throws {
let values = try decoder.singleValueContainer()

if let v = try? values.decode(Float.self) {
self = .float(v)
} else {
let v = try values.decode(String.self)
self = .string(v)
}
}

public func encode(to encoder: Encoder) throws {
var container = encoder.singleValueContainer()
switch self {
case .string(let v): try container.encode(v)
case .float(let v): try container.encode(v)
}
}
}
import MLXLMCommon

private class ModelTypeRegistry: @unchecked Sendable {

@@ -34,13 +11,13 @@ private class ModelTypeRegistry: @unchecked Sendable {
private let lock = NSLock()

@Sendable
private static func createLlamaModel(url: URL) throws -> LLMModel {
private static func createLlamaModel(url: URL) throws -> any LLMModel {
let configuration = try JSONDecoder().decode(
LlamaConfiguration.self, from: Data(contentsOf: url))
return LlamaModel(configuration)
}

private var creators: [String: @Sendable (URL) throws -> LLMModel] = [
private var creators: [String: @Sendable (URL) throws -> any LLMModel] = [
"mistral": createLlamaModel,
"llama": createLlamaModel,
"phi": { url in
@@ -96,14 +73,14 @@ private class ModelTypeRegistry: @unchecked Sendable {
]

public func registerModelType(
_ type: String, creator: @Sendable @escaping (URL) throws -> LLMModel
_ type: String, creator: @Sendable @escaping (URL) throws -> any LLMModel
) {
lock.withLock {
creators[type] = creator
}
}

public func createModel(configuration: URL, rawValue: String) throws -> LLMModel {
public func createModel(configuration: URL, rawValue: String) throws -> any LLMModel {
let creator = lock.withLock {
creators[rawValue]
}
@@ -125,12 +102,12 @@ public struct ModelType: RawRepresentable, Codable, Sendable {
}

public static func registerModelType(
_ type: String, creator: @Sendable @escaping (URL) throws -> LLMModel
_ type: String, creator: @Sendable @escaping (URL) throws -> any LLMModel
) {
modelTypeRegistry.registerModelType(type, creator: creator)
}

public func createModel(configuration: URL) throws -> LLMModel {
public func createModel(configuration: URL) throws -> any LLMModel {
try modelTypeRegistry.createModel(configuration: configuration, rawValue: rawValue)
}
}
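
For context, a short sketch of how client code might use the registration API shown above. The "mymodel" type string, MyModelConfiguration, and MyModel are hypothetical stand-ins, not part of this commit:

```swift
// Hypothetical registration of a custom model type. When a downloaded
// config.json contains "model_type": "mymodel", ModelType.createModel(configuration:)
// routes to this creator closure to decode the config and build the model.
ModelType.registerModelType("mymodel") { url in
    let configuration = try JSONDecoder().decode(
        MyModelConfiguration.self, from: Data(contentsOf: url))
    return MyModel(configuration)
}
```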
1 change: 0 additions & 1 deletion Libraries/LLM/LLM.h

This file was deleted.

126 changes: 23 additions & 103 deletions Libraries/LLM/LLMModel.swift
@@ -1,116 +1,36 @@
// Copyright © 2024 Apple Inc.

import Foundation
@preconcurrency import Hub
import MLX
import MLXNN
import Tokenizers
import MLXLMCommon

/// Container for models that guarantees single threaded access.
///
/// Wrap models used by e.g. the UI in a ModelContainer. Callers can access
/// the model and/or tokenizer:
///
/// ```swift
/// let messages = [["role": "user", "content": prompt]]
/// let promptTokens = try await modelContainer.perform { _, tokenizer in
/// try tokenizer.applyChatTemplate(messages: messages)
/// }
/// ```
///
/// or:
///
/// ```swift
/// let result = await modelContainer.perform { model, tokenizer in
/// LLM.generate(
/// promptTokens: promptTokens, parameters: generateParameters, model: model,
/// tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
/// ) { tokens in
/// ...
/// }
/// }
/// ```
public actor ModelContainer {
let model: LLMModel
let tokenizer: Tokenizer

public init(model: LLMModel, tokenizer: Tokenizer) {
self.model = model
self.tokenizer = tokenizer
}

/// build the model and tokenizer without passing non-sendable data over isolation barriers
public init(
hub: HubApi, modelDirectory: URL, configuration: ModelConfiguration
) async throws {
self.model = try loadSynchronous(modelDirectory: modelDirectory)

let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig(
configuration: configuration, hub: hub)
self.tokenizer = try PreTrainedTokenizer(
tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
}

/// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
/// `MLXArray` is not `Sendable`.
public func perform<R>(_ action: @Sendable (LLMModel, Tokenizer) throws -> R) rethrows -> R {
try action(model, tokenizer)
}
}

extension Module {

/// Compute the number of parameters in a possibly quantized model
public func numParameters() -> Int {
return leafModules().flattenedValues().map {
mod -> Int in
if let qlin = mod as? QuantizedLinear {
return qlin.scales.size * qlin.groupSize
} else if let qemb = mod as? QuantizedEmbedding {
return qemb.scales.size * qemb.groupSize
} else {
return mod.parameters().flattenedValues().reduce(
0,
{
$0 + $1.size
})
}
}.reduce(0, +)
}
}

/// Interface for all LLM Models
public protocol LLMModel: Module {

    var vocabularySize: Int { get }

    func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray

    /// create a new array of ``KVCache`` -- automatic implementation if self
    /// implements ``KVCacheDimensionProvider``
    func newCache(parameters: GenerateParameters) -> [KVCache]

    /// Optionally preprocess the weights and modify / remove values as needed.
    func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
}

/// Optional protocol that can be implemented by ``LLMModel`` and will
/// provide an automatic implementation of ``LLMModel/newCache(parameters:)``
public protocol KVCacheDimensionProvider {
    var kvHeads: [Int] { get }
    var headDim: IntOrPair { get }
}

extension LLMModel {
    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        weights
    }
}

extension LLMModel where Self: KVCacheDimensionProvider {
    public func newCache(parameters: GenerateParameters) -> [KVCache] {
        kvHeads.map { n in
            KVCacheSimple(headDim: headDim, kvHeads: n)
        }
    }
}

// TODO document
public protocol LLMModel: LanguageModel, LoRAModel {
}

extension LLMModel {

    public func prepare(_ input: LMInput, cache: [KVCache], windowSize: Int?) throws
        -> PrepareResult
    {
        let prefillStepSize = windowSize ?? 512
        var y = input.text
        var state: LMOutput.State? = nil

        // prepare the prompt in chunks if larger than the prefill size
        while y.tokens.size > prefillStepSize {
            let input = y[.newAxis, ..<prefillStepSize]
            let result = self(input, cache: cache.isEmpty ? nil : cache, state: state)
            eval(cache)
            y = y[prefillStepSize...]
        }

        return .tokens(y)
    }
}

// TODO move? these cause some ambiguity -- how to resolve?
//public typealias ModelConfiguration = MLXLMCommon.ModelConfiguration
//public typealias GenerateParameters = MLXLMCommon.GenerateParameters
//public typealias GenerateResult = MLXLMCommon.GenerateParameters
//public typealias NaiveStreamingDetokenizer = MLXLMCommon.NaiveStreamingDetokenizer
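
The default prepare(_:cache:windowSize:) above consumes long prompts in windowSize-sized chunks so the KV cache is filled incrementally, then returns the unconsumed tail to the caller. A hedged sketch of how a caller might drive it; makeCache() and decodeLoop(_:cache:) are placeholders, and only the .tokens result is handled:

```swift
// Sketch (not part of this commit) of driving the default prepare above.
let cache: [KVCache] = makeCache()  // placeholder for building the model's KV caches
let prepared = try model.prepare(input, cache: cache, windowSize: 512)
if case .tokens(let remaining) = prepared {
    // the prompt prefix has already been evaluated into the KV cache in
    // 512-token chunks; only `remaining` still needs forward passes
    decodeLoop(remaining, cache: cache)
}
```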
7 changes: 4 additions & 3 deletions Libraries/LLM/Load.swift
@@ -3,6 +3,7 @@
import Foundation
@preconcurrency import Hub
import MLX
import MLXLMCommon
import MLXNN
import MLXRandom
import Tokenizers
@@ -20,7 +21,7 @@ func prepareModelDirectory(
case .id(let id):
// download the model weights
let repo = Hub.Repo(id: id)
let modelFiles = ["*.safetensors", "config.json"]
let modelFiles = ["*.safetensors", "*.json"]
return try await hub.snapshot(
from: repo, matching: modelFiles, progressHandler: progressHandler)

@@ -47,7 +48,7 @@ public func load(
public func load(
hub: HubApi = HubApi(), configuration: ModelConfiguration,
progressHandler: @Sendable @escaping (Progress) -> Void = { _ in }
) async throws -> (LLMModel, Tokenizer) {
) async throws -> (any LLMModel, Tokenizer) {
let modelDirectory = try await prepareModelDirectory(
hub: hub, configuration: configuration, progressHandler: progressHandler)
let model = try loadSynchronous(modelDirectory: modelDirectory)
@@ -56,7 +57,7 @@ public func load(
return (model, tokenizer)
}

func loadSynchronous(modelDirectory: URL) throws -> LLMModel {
func loadSynchronous(modelDirectory: URL) throws -> any LLMModel {
// create the model (no weights loaded)
let configurationURL = modelDirectory.appending(component: "config.json")
let baseConfig = try JSONDecoder().decode(
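
Given the load(hub:configuration:progressHandler:) declaration above, usage might look like the following sketch. The model id is a placeholder; any repository matched by the "*.safetensors" and "*.json" globs above should work:

```swift
// Hedged usage sketch (not part of this commit) of load(hub:configuration:progressHandler:).
let configuration = ModelConfiguration(id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit")
let (model, tokenizer) = try await load(configuration: configuration) { progress in
    print("download progress: \(progress.fractionCompleted)")
}
```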

0 comments on commit 3230887
