Commit
- based on models from https://github.com/Blaizzy/mlx-vlm - for #132
1 parent 7baf9bc, commit 3230887

Showing 35 changed files with 1,956 additions and 1,699 deletions.
```diff
@@ -1,116 +1,36 @@
 // Copyright © 2024 Apple Inc.
 
 import Foundation
-@preconcurrency import Hub
 import MLX
-import MLXNN
-import Tokenizers
+import MLXLMCommon
 
-/// Container for models that guarantees single threaded access.
-///
-/// Wrap models used by e.g. the UI in a ModelContainer. Callers can access
-/// the model and/or tokenizer:
-///
-/// ```swift
-/// let messages = [["role": "user", "content": prompt]]
-/// let promptTokens = try await modelContainer.perform { _, tokenizer in
-///     try tokenizer.applyChatTemplate(messages: messages)
-/// }
-/// ```
-///
-/// or:
-///
-/// ```swift
-/// let result = await modelContainer.perform { model, tokenizer in
-///     LLM.generate(
-///         promptTokens: promptTokens, parameters: generateParameters, model: model,
-///         tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
-///     ) { tokens in
-///         ...
-///     }
-/// }
-/// ```
-public actor ModelContainer {
-    let model: LLMModel
-    let tokenizer: Tokenizer
-
-    public init(model: LLMModel, tokenizer: Tokenizer) {
-        self.model = model
-        self.tokenizer = tokenizer
-    }
-
-    /// build the model and tokenizer without passing non-sendable data over isolation barriers
-    public init(
-        hub: HubApi, modelDirectory: URL, configuration: ModelConfiguration
-    ) async throws {
-        self.model = try loadSynchronous(modelDirectory: modelDirectory)
-
-        let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig(
-            configuration: configuration, hub: hub)
-        self.tokenizer = try PreTrainedTokenizer(
-            tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
-    }
-
-    /// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
-    /// `MLXArray` is not `Sendable`.
-    public func perform<R>(_ action: @Sendable (LLMModel, Tokenizer) throws -> R) rethrows -> R {
-        try action(model, tokenizer)
-    }
-}
-
-extension Module {
-
-    /// Compute the number of parameters in a possibly quantized model
-    public func numParameters() -> Int {
-        return leafModules().flattenedValues().map { mod -> Int in
-            if let qlin = mod as? QuantizedLinear {
-                return qlin.scales.size * qlin.groupSize
-            } else if let qemb = mod as? QuantizedEmbedding {
-                return qemb.scales.size * qemb.groupSize
-            } else {
-                return mod.parameters().flattenedValues().reduce(0, { $0 + $1.size })
-            }
-        }.reduce(0, +)
-    }
-}
+// TODO document
+public protocol LLMModel: LanguageModel, LoRAModel {
+}
 
-/// Interface for all LLM Models
-public protocol LLMModel: Module {
-
-    var vocabularySize: Int { get }
-
-    func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray
-
-    /// create a new array of ``KVCache`` -- automatic implementation if self
-    /// implements ``KVCacheDimensionProvider``
-    func newCache(parameters: GenerateParameters) -> [KVCache]
-
-    /// Optionally preprocess the weights and modify / remove values as needed.
-    func sanitize(weights: [String: MLXArray]) -> [String: MLXArray]
-}
-
-/// Optional protocol that can be implemented by ``LLMModel`` and will
-/// provide an automatic implementation of ``LLMModel/newCache(parameters:)``
-public protocol KVCacheDimensionProvider {
-    var kvHeads: [Int] { get }
-    var headDim: IntOrPair { get }
-}
-
-extension LLMModel {
-    public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
-        weights
-    }
-}
-
-extension LLMModel where Self: KVCacheDimensionProvider {
-    public func newCache(parameters: GenerateParameters) -> [KVCache] {
-        kvHeads.map { n in
-            KVCacheSimple(headDim: headDim, kvHeads: n)
-        }
-    }
-}
+extension LLMModel {
+
+    public func prepare(_ input: LMInput, cache: [KVCache], windowSize: Int?) throws
+        -> PrepareResult
+    {
+        let prefillStepSize = windowSize ?? 512
+        var y = input.text
+        var state: LMOutput.State? = nil
+
+        // prepare the prompt in chunks if larger than the prefill size
+        while y.tokens.size > prefillStepSize {
+            let input = y[.newAxis, ..<prefillStepSize]
+            let result = self(input, cache: cache.isEmpty ? nil : cache, state: state)
+            eval(cache)
+            y = y[prefillStepSize...]
+        }
+
+        return .tokens(y)
+    }
+}
+
+// TODO move? these cause some ambiguity -- how to resolve?
+//public typealias ModelConfiguration = MLXLMCommon.ModelConfiguration
+//public typealias GenerateParameters = MLXLMCommon.GenerateParameters
+//public typealias GenerateResult = MLXLMCommon.GenerateParameters
+//public typealias NaiveStreamingDetokenizer = MLXLMCommon.NaiveStreamingDetokenizer
```
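The heart of the added `prepare(_:cache:windowSize:)` is the prefill loop: a prompt longer than `windowSize` (default 512) is pushed through the model in fixed-size slices, with `eval(cache)` forcing each slice's keys and values into the KV cache before the next slice runs, and whatever remains is handed back as `.tokens` for the normal sampling loop. A minimal, self-contained sketch of just that slicing arithmetic, with a hypothetical `chunkedPrefill` name, plain `[Int]` standing in for the `MLXArray`-backed token buffer, and the model/eval calls stubbed out:

```swift
// Standalone sketch of the chunked prefill in prepare(_:cache:windowSize:).
// Hypothetical helper; not part of this commit.
func chunkedPrefill(promptTokens: [Int], windowSize: Int? = nil) -> [Int] {
    let prefillStepSize = windowSize ?? 512
    var y = promptTokens

    // consume the prompt in prefillStepSize chunks; in the real code each
    // chunk is fed to the model with the shared KV cache, then eval(cache)
    // runs before the next slice
    while y.count > prefillStepSize {
        let chunk = Array(y[..<prefillStepSize])
        _ = chunk  // model(chunk, cache: cache); eval(cache)
        y = Array(y[prefillStepSize...])
    }

    // at most prefillStepSize tokens remain; these go back to the caller,
    // mirroring `return .tokens(y)` in the diff above
    return y
}

// A 1,200-token prompt with the default window prefills two 512-token
// chunks and returns the trailing 176 tokens for the sampling loop.
let remainder = chunkedPrefill(promptTokens: Array(0 ..< 1200))
print(remainder.count)  // 176
```

Chunking bounds the size of any single forward pass, so a long prompt fills the cache incrementally instead of materializing one very large attention computation.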
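One other detail in the hunk worth a worked example: the removed `numParameters()` counts parameters through quantization by noting that a `QuantizedLinear` (or `QuantizedEmbedding`) stores one scale per `groupSize` original weights, so `scales.size * groupSize` recovers the unquantized weight count. A quick check of that arithmetic with hypothetical layer shapes:

```swift
// Hypothetical 4096 x 4096 linear layer quantized with group size 64.
let outFeatures = 4096
let inFeatures = 4096
let groupSize = 64

let weightCount = outFeatures * inFeatures  // 16_777_216 original weights
let scalesSize = weightCount / groupSize    // 262_144 scales stored with the packed weights
let recovered = scalesSize * groupSize      // scales.size * groupSize, as in numParameters()
print(recovered == weightCount)             // true
```

Like the original, this product covers only the weight matrix; any bias term in the quantized layer is not included in the count.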