add VLM support, refactor common LM code into MLXLMCommon. breaking API changes #151
Conversation
```diff
@@ -1,30 +1,7 @@
 // Copyright © 2024 Apple Inc.

 import Foundation

 public enum StringOrNumber: Codable, Equatable, Sendable {
```
move to LMCommon
```swift
/// Container for models that guarantees single threaded access.
```
Move to ModelContainer
Libraries/LLM/LLMModel.swift (Outdated)

```swift
        }
    }
}
// TODO move? these cause some ambiguity -- how to resolve?
```
I was playing around with these to avoid breaking API -- moving types into LMCommon means callers will need to import LMCommon if they refer to them. The aliases caused more trouble than I think they are worth.
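The sort of re-export that was tried and then dropped would look something like this (illustrative only):

```swift
// Keep old call sites compiling after the type moved to MLXLMCommon.
// Dropped because it introduces ambiguity when both modules are imported.
public typealias StringOrNumber = MLXLMCommon.StringOrNumber
```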
```diff
@@ -3,6 +3,7 @@
 import Foundation
 @preconcurrency import Hub
 import MLX
+import MLXLMCommon
 import MLXNN
 import MLXRandom
 import Tokenizers
```
Ultimately I would like this to move into LMCommon -- I think it can support both LLM and VLM models, but I didn't get a chance to move this yet.
```swift
import MLXNN
import MLXOptimizers
import MLXRandom
import Tokenizers

/// Layers to apply LoRA adapters to.
```
Move to LMCommon
Libraries/LLM/LoraTrain.swift (Outdated)

```swift
        return y + scale * z
    }
}

/// Equivalent to `lora.py/iterate_batches()`. Used internally by ``LoRATrain``.
struct LoRABatchIterator: Sequence, IteratorProtocol {
```
Ideally the rest of this moves to LMCommon as well -- I think it can.
Libraries/LMCommon/Evaluate.swift (Outdated)

```swift
    mutating func prompt(_ prompt: MLXArray)
    func process(logits: MLXArray) -> MLXArray
    mutating func didSample(token: MLXArray)
}
```
The generate / step code has been refactored a bit and can now take custom logit samplers and processors
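As a sketch of what a custom processor could look like against the hooks quoted above (the protocol name is not shown here, so the conformance is omitted; this is not the exact API):

```swift
import MLX

// Illustrative only: scale logits by a temperature using the
// prompt/process/didSample hooks shown above.
struct TemperatureScaler {
    let temperature: Float

    mutating func prompt(_ prompt: MLXArray) {
        // no per-prompt state needed for simple temperature scaling
    }

    func process(logits: MLXArray) -> MLXArray {
        logits * (1 / temperature)
    }

    mutating func didSample(token: MLXArray) {
        // nothing to track after sampling
    }
}
```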
Libraries/LMCommon/Evaluate.swift (Outdated)

```swift
public init(
    prompt: MLXArray, model: any LanguageModel, cache: [KVCache]? = nil,
    parameters: GenerateParameters
) throws {
```
This now takes either a prompt (MLXArray) or an LMInput (text + image + ...) via multiple initializers.
```swift
    }
}

public struct LMInput {
```
A new union type that holds the different inputs to `generate()` and `LanguageModel.prepare()`.
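A rough sketch of the idea (the member names below are illustrative assumptions, not the actual definition):

```swift
import MLX

// Illustrative only: a union-style input where text tokens are always present
// and image data is optional (supplied by the VLM processors).
public struct LMInputSketch {
    public struct Text {
        public let tokens: MLXArray
    }

    public struct ProcessedImage {
        public let pixels: MLXArray
    }

    public let text: Text
    public let image: ProcessedImage?
}
```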
```swift
    }
}

public struct LMOutput {
```
Union type for the output. Some of the VLMs return additional state, which is represented here.
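Conceptually something like this (a sketch; the exact shape of the extra state is a guess and varies by model):

```swift
import MLX

// Illustrative only: the logits for the current step plus optional
// model-specific state that a VLM can carry between calls.
public struct LMOutputSketch {
    public let logits: MLXArray
    public let state: [String: MLXArray]?
}
```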
Libraries/LMCommon/Models.swift (Outdated)

```diff
@@ -134,6 +135,7 @@ extension ModelConfiguration {
         extraEOSTokens: ["<|end|>"]
     )

 // TODO the prompt formatter is replaced by the chat template
```
Or is it? #150
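For reference, the chat-template path that replaces a hand-written prompt formatter is the same pattern used elsewhere in this PR, e.g.:

```swift
import Tokenizers

// Encode a user prompt via the tokenizer's chat template instead of a
// per-model prompt formatter.
func encodePrompt(_ prompt: String, tokenizer: Tokenizer) throws -> [Int] {
    let messages = [["role": "user", "content": prompt]]
    return try tokenizer.applyChatTemplate(messages: messages)
}
```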
Libraries/LMCommon/Processor.swift (Outdated)

```swift
import CoreImage
import Foundation
import MLX
```
This file may be deleted -- it was some notes & thoughts along the way
Libraries/LMCommon/Prompt.swift (Outdated)

```swift
// Copyright © 2024 Apple Inc.

import Foundation
import MLX
```
Also to be deleted -- `LMInput` replaces this.
Libraries/VLM/MediaProcessing.swift (Outdated)

```swift
private let context = CIContext()

// TODO documentation
public enum MediaProcessing {
```
Needs documentation, but see `PaliGemmaImageProvider`, which implements `SiglipImageProcessor` from the Python transformers code:

```json
{
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [0.5, 0.5, 0.5],
  "image_processor_type": "SiglipImageProcessor",
  "image_seq_length": 1024,
  "image_std": [0.5, 0.5, 0.5],
  "processor_class": "PaliGemmaProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": { "height": 448, "width": 448 }
}
```
Libraries/VLM/Models/Paligemma.swift (Outdated)

```swift
import MLXNN
import Tokenizers

// MARK: - Language
```
First cut at a port of https://github.com/Blaizzy/mlx-vlm/tree/main/mlx_vlm/models/paligemma
Note: this builds, loads weights and "runs" but doesn't produce any output -- still needs to be debugged.
it should be usable as an example of the structure I think we need
Libraries/VLM/Models/Paligemma.swift (Outdated)

```swift
    }
}

// TODO does not support multiple images -- how do we represent?
```
We need a protocol for the image and text processing pieces.
Libraries/VLM/Models/Paligemma.swift (Outdated)

```swift
image = MediaProcessing.inSRGBToneCurveSpace(image)

image = MediaProcessing.resampleBicubic(image, to: .init(width: size, height: size))
image = MediaProcessing.normalize(image, mean: (0.5, 0.5, 0.5), std: (0.5, 0.5, 0.5))
```
`SiglipImageProcessor`:

```json
{
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [0.5, 0.5, 0.5],
  "image_processor_type": "SiglipImageProcessor",
  "image_seq_length": 1024,
  "image_std": [0.5, 0.5, 0.5],
  "processor_class": "PaliGemmaProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": { "height": 448, "width": 448 }
}
```
Libraries/VLM/Models/Paligemma.swift (Outdated)

```swift
    }
}

private func loadConfiguration(url: URL) throws -> PaliGemma {
```
These next couple of functions are just stubs to let me try it out -- this will work much like the LLM models
Libraries/VLM/Models/Paligemma.swift (Outdated)

```swift
private let _ropeTheta: Float?
public var ropeTheta: Float { _ropeTheta ?? 10_000 }
public let _ropeTraditional: Bool?
public var ropeTraditional: Bool { _ropeTraditional ?? false }
```
Rather than doing the full implementation of Codable I went a simpler route for default values. Less code, cleaner (I think)
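Spelled out, the pattern is: decode the raw optional and expose a computed property with the default, instead of hand-writing `init(from:)`. A minimal sketch (the JSON key name here is the usual HF config key and is an assumption):

```swift
// The optional stored property decodes the JSON value if present;
// the computed property supplies the default.
struct TextConfigurationSketch: Codable {
    private let _ropeTheta: Float?
    public var ropeTheta: Float { _ropeTheta ?? 10_000 }

    enum CodingKeys: String, CodingKey {
        case _ropeTheta = "rope_theta"
    }
}
```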
Tools/llm-tool/LLMTool.swift (Outdated)

```swift
@Option var path: URL

@MainActor
mutating func run() async throws {
```
Just stub code to exercise the model. This still needs the input processing layers, in particular the prompt processing. The image processing is in place but will need to be wrapped up API-wise.
This is now the real code
```swift
import MLX
import MLXLLM
import MLXLMCommon
```
See PR description -- split LLM -> LLM and LMCommon. Switched local names to match what people get via swiftpm (MLXLLM, etc.).
```diff
@@ -159,7 +160,7 @@ class LLMEvaluator {

     /// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
     /// more devices.
-    let modelConfiguration = ModelConfiguration.phi3_5_4bit
+    let modelConfiguration = ModelRegistry.phi3_5_4bit
```
From the PR description:

- constants for models have moved from `ModelConfiguration` to `ModelRegistry`
- this is `MLXLM.ModelRegistry` and there is also `MLXVLM.ModelRegistry`

```diff
- let modelConfiguration = ModelConfiguration.phi3_5_4bit
+ let modelConfiguration = ModelRegistry.phi3_5_4bit
```
This code is ready for review!
This is incredibly cool. I barely touched the surface but leaving a small review and going to try running it shortly.
structure something like this:

```swift
public class YourModel: Module, LLMModel, KVCacheDimensionProvider, LoRAModel {
```
Btw I changed the KV cache implementation in mlx-lm to just init the keys and values the first time you call it. There is no need to initialize the KV cache with a head dim etc. so we could probably remove this interface as well. (Just a comment not something that we need to update in this PR)
OK, I will take a look at it -- if it simplifies things it may be worth including here as we are already making some breaking changes.
- revisit KVCache / mlx-lm
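A minimal sketch of the lazily-initialized cache described above (not the implementation in this PR; it just illustrates why no head-dim bookkeeping is needed up front):

```swift
import MLX

// Keys/values are allocated on first update, so the cache needs no shape
// or dtype information at construction time.
class LazyKVCacheSketch {
    var keys: MLXArray?
    var values: MLXArray?
    var offset = 0

    func update(keys newKeys: MLXArray, values newValues: MLXArray) -> (MLXArray, MLXArray) {
        if let keys, let values {
            // append along the sequence axis ([B, heads, seq, headDim] layout assumed)
            self.keys = concatenated([keys, newKeys], axis: 2)
            self.values = concatenated([values, newValues], axis: 2)
        } else {
            self.keys = newKeys
            self.values = newValues
        }
        offset += newKeys.dim(2)
        return (self.keys!, self.values!)
    }
}
```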
Libraries/MLXLLM/README.md (Outdated)

```swift
public let kvHeads: [Int]
public let headDim: IntOrPair
```
And e.g. got rid of this which is not necessary
Tools/llm-tool/LLMTool.swift (Outdated)

```diff
-        let (modelContainer, modelConfiguration) = try await memory.start(args.load)
+        let modelContainer = try await memory.start { [args] in
+            try await args.load(
+                defaultModel: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
```
We should update this default model, it's pretty dated. Maybe `mlx-community/Mistral-7B-Instruct-v0.3-4bit` is a good option?
Sure, I will give it a run and make sure it works!
- test this
It is one of the preset models, so good to go
Tools/llm-tool/LLMTool.swift (Outdated)

```diff
@@ -203,29 +206,88 @@ struct EvaluateCommand: AsyncParsableCommand {

     @MainActor
     mutating func run() async throws {
-        let (modelContainer, modelConfiguration) = try await memory.start(args.load)
+        let modelContainer = try await memory.start { [args] in
```
Can we rename this to `LMCommand` and the subcommand to `lm`, to match the `VLMCommand`?

Alternatively (given the complexity) it might be worth using the same subcommand and just dispatching to the vlm subroutine depending on whether an image input is provided.
Interesting idea! The default model is different, as is the model factory. We could certainly switch on the presence of an image (or video) to choose, but I wonder if that complicates things over just having the two subcommands?
Let me try the refactor to fold these down into one and see if that looks reasonable.
- try refactor of vlm -> eval (lm) command
Yes, it was a slightly off-the-cuff suggestion. It simplifies the command line, but it might not be worth doing at the expense of code complexity.
I think that worked well -- it came down to this (mostly):

```swift
// switch between LLM and VLM
let vlm = image.count > 0
if vlm {
    modelFactory = VLMModelFactory.shared
    defaultModel = MLXVLM.ModelRegistry.paligemma3bMix448_8bit
} else {
    modelFactory = LLMModelFactory.shared
    defaultModel = MLXLLM.ModelRegistry.mistral7B4bit
}
```
```swift
/// ```swift
/// let messages = [["role": "user", "content": prompt]]
/// let promptTokens = try await modelContainer.perform { context in
///     try context.tokenizer.applyChatTemplate(messages: messages)
/// }
/// ```
///
/// or:
///
/// ```swift
/// let result = await modelContainer.perform { context in
///     LLM.generate(
///         promptTokens: promptTokens, parameters: generateParameters, model: context.model,
///         tokenizer: context.tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
///     ) { tokens in
///         ...
///     }
```
Is this comment outdated?
yes, thanks for spotting that!
```swift
let inputEmbedding = languageModel.model.embedTokens(inputIds)
let (hiddenState, _, _) = self.visionModel(
    pixelValues.transposed(0, 2, 3, 1).asType(inputEmbedding.dtype),
    outputHiddenStates: true
)
```
We have to be pretty careful with data types in these models because it's really easy to upcast to fp32 by accident, and that can slow things down a lot or use a lot more memory (or both).

One thing I recommend: if you have a test suite that runs the models, make sure the output type is the same as the input type.

Here you cast the `pixelValues` to the embedding type, which is good. But below you cast the output back to the `pixelValues` type, which I'm not sure about -- I would just keep those in the same model dtype.
Good spot on that! For reference: `inputEmbedding` is float16, `hiddenState` is float32, `pixelValues` is float32.
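In other words, convert the vision tower's float32 output to the embedding dtype rather than back to the `pixelValues` dtype, e.g. (sketch):

```swift
// Keep image features in the language model's dtype so nothing downstream
// silently upcasts to float32.
let imageFeatures = hiddenState.asType(inputEmbedding.dtype)
```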
```swift
let embedDimension = imageFeatures.dim(2)
let (batchSize, sequenceLength) = inputIds.shape2
var scaledImageFeatures = imageFeatures / pow(Float(config.hiddenSize), 0.5)
var finalEmbedding = zeros([batchSize, sequenceLength, embedDimension])
```
The default data type of `zeros` is fp32. That will cause anything that works with this `finalEmbedding` to be upcast to fp32.
done
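The fix amounts to giving `finalEmbedding` the embedding dtype instead of the float32 default, e.g. (one possible sketch using the `asType` conversion already present in this file):

```swift
// Avoid the float32 default for the freshly allocated embedding buffer.
var finalEmbedding = zeros([batchSize, sequenceLength, embedDimension])
    .asType(inputEmbedding.dtype)
```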
```swift
let (inputEmbedding, finalAttentionMask4d) = inputEmbeddings(
    inputIds: inputIds, pixelValues: image.pixels, mask: mask)
```
We might want to cast the `inputEmbedding` to the LM dtype as well (get it from the embedding layer weight or something), just in case they have different types.
Handled inside the `inputEmbeddings` function:

```swift
private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, mask: MLXArray) -> (
    MLXArray, MLXArray
) {
    guard let pixelValues else {
        return (inputIds, mask)
    }
    let inputEmbedding = languageModel.model.embedTokens(inputIds)
    let (hiddenState, _, _) = self.visionModel(
        pixelValues.transposed(0, 2, 3, 1).asType(inputEmbedding.dtype),
```
```swift
imageMaskExpanded = repeated(imageMaskExpanded, count: embedDimension, axis: -1)
finalEmbedding = which(imageMaskExpanded, scaledImageFeatures, finalEmbedding)

finalEmbedding = which(padMaskExpanded, zeros(like: finalEmbedding), finalEmbedding)
```
In Python it's better to do `mx.where(mask, array, 0.0)`, since the `0` will be broadcast and inherit the type of `array`. I think the same is true in Swift?
Yes, to avoid the zeros float32 (and maybe faster to boot because of broadcasting instead of a realized array). Done.
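In Swift the change looks roughly like this (a sketch; it assumes `which` broadcasts a scalar the same way `mx.where` does, which the exchange above suggests):

```swift
// Before: materializes a float32 zeros array
// finalEmbedding = which(padMaskExpanded, zeros(like: finalEmbedding), finalEmbedding)

// After: the scalar broadcasts and inherits finalEmbedding's dtype
finalEmbedding = which(padMaskExpanded, 0, finalEmbedding)
```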
```swift
// insert image embeddings - the image mask is always less or equal to the sentence in length
var imageMaskExpanded = expandedDimensions(imageMask, axis: -1)
imageMaskExpanded = repeated(imageMaskExpanded, count: embedDimension, axis: -1)
```
There is no need to explicitly repeat these -- just rely on the fact that `which` broadcasts its inputs against one another. The same is true for most of the calls to `repeated` above.
Wow, went from ~92 tokens/s to ~112 tokens/s.
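The change is essentially dropping the explicit `repeated` call and letting `which` broadcast the expanded mask (sketch):

```swift
// Before: materialize the mask at full embedding width
// imageMaskExpanded = repeated(imageMaskExpanded, count: embedDimension, axis: -1)

// After: expand to [batch, seq, 1] and let broadcasting handle the last axis
let imageMaskExpanded = expandedDimensions(imageMask, axis: -1)
finalEmbedding = which(imageMaskExpanded, scaledImageFeatures, finalEmbedding)
```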
```swift
// insert padding and text token embeddings
finalEmbedding = which(textMaskExpanded, inputEmbedding, finalEmbedding)
finalEmbedding = which(padMaskExpanded, zeros(like: finalEmbedding), finalEmbedding)
```
This `zeros` also should be a plain scalar and inherit the type of the `finalEmbedding`.
Massive! Thanks for adding this!
Status: almost ready, just testing and cleaning up. Models are working. I am using a local override of mlx-swift main.

### Xcode 16

Xcode 16 is required to build the example applications and tools. Older Xcode can still build the libraries via swiftpm (so no changes in requirements to any applications or libraries that refer to this).

This change is required because the xcodeproj now refers to the local `Package.swift` file to get builds consistent with external users. If needed we can switch back to using xcodeproj for library builds (internal) and swiftpm for library builds (external) -- if there is a problem please file an issue and it can be considered.

### Additions

There are two new libraries, based on models from https://github.com/Blaizzy/mlx-vlm:

- `MLXVLM` contains vision language models that combine images and text prompts to produce text results, e.g. `describe this image`
- `MLXLMCommon` contains the `LanguageModel` code that is shared between `MLXLLM` and `MLXVLM`

The API between `LLM` and `VLM` is identical aside from the preparation of the `UserInput`.

```swift
let parameters = GenerateParameters()

// LLM prompt
let input = UserInput(prompt: "tell me a story")

// VLM prompt
let input = UserInput(prompt: "describe the image", images: [.url(url)])

// inference is identical
let result = try await modelContainer.perform { [generate, input] context in
    let input = try await context.processor.prepare(input: input)
    return try generate(input: input, parameters: parameters, context: context) { token in
        // print tokens as they are generated, stop early, etc.
        return .more
    }
}
```

VLM example code is available in the `llm-tool` example:

```
./mlx-run llm-tool vlm --help
OVERVIEW: evaluate prompt and images to generate text (VLM)

USAGE: llm-tool vlm <options>

OPTIONS:
  --model <model>         Name of the huggingface model or absolute path to directory
  -p, --prompt <prompt>   The message to be processed by the model. Use @path,@path to load from files, e.g. @/tmp/prompt.txt
  --resize <resize>       Resize images to this size (width, height)
  --image <image>         Paths or urls for input images
  ...
```

### Breaking Changes

Probably no effect to code external to this repo:

- the mlx-swift-examples.xcodeproj now references the local `Package.swift` to build the libraries
- the example code now uses the naming matching external uses of mlx-swift-examples, e.g. `import LLM` -> `import MLXLLM`
- the library directories are now renamed to match their target names, e.g. `LLM` -> `MLXLLM`

Breaking:

- some code will now need to import both `MLXLLM` and `MLXLMCommon` (particularly code that loads models)
- `MLXLMCommon` contains the common API between LLM and VLM

```swift
import MLXLLM
import MLXLMCommon
```

- constants for models have moved from `ModelConfiguration` to `ModelRegistry`
- this is `MLXLM.ModelRegistry` and there is also `MLXVLM.ModelRegistry`

```diff
- let modelConfiguration = ModelConfiguration.phi3_5_4bit
+ let modelConfiguration = ModelRegistry.phi3_5_4bit
```

- the `loadModelContainer()` function is now `LLMModelFactory.shared.loadContainer()`
- there is a new `VLMModelFactory` with identical methods for loading VLMs

```diff
- let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
- {
+ let modelContainer = try await LLMModelFactory.shared.loadContainer(
+     configuration: modelConfiguration
+ ) {
```

- `ModelContainer.perform` is now throwing (and in MLXLMCommon):

```diff
- let result = await modelContainer.perform { model, tokenizer in
-     LLM.generate(
+ let result = try await modelContainer.perform { model, tokenizer in
+     try MLXLMCommon.generate(
```

- `ModelConfiguration` previously had a way to register new configurations. This is now on `LLMModelFactory` (and `VLMModelFactory` has the same):

```swift
LLMModelFactory.shared.modelRegistry.register(configurations: [modelConfiguration])
```

### Deprecations

An example at the end shows all of these deprecations in context.

Prefer to use the `ModelContext.processor` to prepare prompts. Previously users would pass in a bare `[Int]` of tokens, but in order to support more complex inputs (VLMs) the use of bare `[Int]` is deprecated and callers should use `UserInput` and `LMInput`.

For example, previously callers might have done something like this:

```swift
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
    try tokenizer.applyChatTemplate(messages: messages)
}
```

Now that should be:

```swift
let input = try await context.processor.prepare(input: .init(prompt: prompt))
```

Which will initialize a `UserInput` from the prompt text and produce an `LMInput` that can be used to generate tokens.

This call to `generate()` is now deprecated:

```swift
public func generate(
    promptTokens: [Int], parameters: GenerateParameters, model: any LanguageModel,
    tokenizer: Tokenizer, extraEOSTokens: Set<String>? = nil,
    didGenerate: ([Int]) -> GenerateDisposition
) throws -> GenerateResult
```

This consumed the `[Int]` variety of tokens. Now this is preferred:

```swift
public func generate(
    input: LMInput, parameters: GenerateParameters, context: ModelContext,
    didGenerate: ([Int]) -> GenerateDisposition
) throws -> GenerateResult
```

This method on `ModelContainer` is now deprecated:

```swift
/// Perform an action on the model and/or tokenizer. Callers _must_ eval any `MLXArray` before returning as
/// `MLXArray` is not `Sendable`.
@available(*, deprecated, message: "prefer perform(_:) that uses a ModelContext")
public func perform<R>(_ action: @Sendable (any LanguageModel, Tokenizer) throws -> R) rethrows -> R
```

Use this one instead (though the former still works):

```swift
/// Perform an action on the ``ModelContext``. Callers _must_ eval any `MLXArray` before returning as
/// `MLXArray` is not `Sendable`.
public func perform<R>(_ action: @Sendable (ModelContext) async throws -> R) async rethrows -> R
```

### Example

Putting all of these deprecations together, previously you might have generated text like this:

```swift
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
    try tokenizer.applyChatTemplate(messages: messages)
}

let result = await modelContainer.perform { model, tokenizer in
    LLM.generate(
        promptTokens: promptTokens, parameters: generateParameters, model: model,
        tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
    ) { tokens in
        ...
    }
}
```

Now do this:

```swift
let result = try await modelContainer.perform { context in
    let input = try await context.processor.prepare(input: .init(prompt: prompt))
    return try MLXLMCommon.generate(
        input: input, parameters: generateParameters, context: context
    ) { tokens in
        ...
    }
}
```