
add VLM support, refactor common LM code into MLXLMCommon. breaking API changes #151

Merged: 14 commits merged into main from the vlm1 branch on Dec 10, 2024

Conversation

davidkoski (Collaborator) commented Nov 1, 2024:

Status: almost ready, just testing and cleaning up. Models are working. I am using a local override of mlx-swift main.

Xcode 16

Xcode 16 is required to build the example applications and tools. Older Xcode versions can still build the libraries via swiftpm, so requirements are unchanged for any applications or libraries that depend on them.

This change is required because the xcodeproj now refers to the local Package.swift file so that builds are consistent with what external users get. If needed we can switch back to using the xcodeproj for internal library builds and swiftpm for external library builds -- if this causes a problem please file an issue and it can be considered.

Additions

There are two new libraries:

- `MLXVLM` contains vision language models that combine images and text prompts to produce text results, e.g. `describe this image`. The implementations are based on models from https://github.com/Blaizzy/mlx-vlm.
- `MLXLMCommon` contains the `LanguageModel` code that is shared between `MLXLLM` and `MLXVLM`

The API between `LLM` and `VLM` is identical aside from the preparation of the `UserInput`.

```swift
let parameters = GenerateParameters()

// LLM prompt
let input = UserInput(prompt: "tell me a story")

// VLM prompt
let input = UserInput(prompt: "describe the image", images: [.url(url)])

// inference is identical
let result = try await modelContainer.perform { [generate, input] context in
    let input = try await context.processor.prepare(input: input)
    return try generate(input: input, parameters: parameters, context: context) { token in
        // print tokens as they are generated, stop early, etc.
        return .more
    }
}
```

VLM example code is available in the `llm-tool` example:

```
./mlx-run llm-tool vlm --help
OVERVIEW: evaluate prompt and images to generate text (VLM)

USAGE: llm-tool vlm <options>

OPTIONS:
  --model <model>         Name of the huggingface model or absolute path to directory
  -p, --prompt <prompt>   The message to be processed by the model.  Use @path,@path to load from files, e.g. @/tmp/prompt.txt
  --resize <resize>       Resize images to this size (width, height)
  --image <image>         Paths or urls for input images
...
```
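For example, a run against a local image might look like this (the image path is hypothetical, and `--model` is omitted here on the assumption that the tool falls back to its default VLM, as the dispatch code shown later in this PR suggests):

```
./mlx-run llm-tool vlm --prompt "describe the image" --image /path/to/photo.jpg
```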

Breaking Changes

These probably have no effect on code external to this repo:

- the mlx-swift-examples.xcodeproj now references the local `Package.swift` to build the libraries
- the example code now uses the naming matching external uses of mlx-swift-examples, e.g. `import LLM` -> `import MLXLLM`
- the library directories are now renamed to match their target names, e.g. `LLM` -> `MLXLLM`

Breaking:

- some code will now need to import both `MLXLLM` and `MLXLMCommon` (particularly code that loads models)
- `MLXLMCommon` contains the common API between LLM and VLM

```swift
import MLXLLM
import MLXLMCommon
```

- constants for models have moved from `ModelConfiguration` to `ModelRegistry`
- this is `MLXLLM.ModelRegistry` and there is also `MLXVLM.ModelRegistry`

```diff
-    let modelConfiguration = ModelConfiguration.phi3_5_4bit
+    let modelConfiguration = ModelRegistry.phi3_5_4bit
```

- the `loadModelContainer()` function is now `LLMModelFactory.shared.loadContainer()`
- there is a new `VLMModelFactory` with identical methods for loading VLMs (see the sketch after this list)

```diff
-     let modelContainer = try await LLM.loadModelContainer(configuration: modelConfiguration)
-    {
+     let modelContainer = try await LLMModelFactory.shared.loadContainer(
+          configuration: modelConfiguration
+    ) {
```

- `ModelContainer.perform` is now throwing (and in MLXLMCommon):

```diff
-     let result = await modelContainer.perform { model, tokenizer in
-          LLM.generate(
+     let result = try await modelContainer.perform { model, tokenizer in
+          try MLXLMCommon.generate(
```

- `ModelConfiguration` previously had a way to register new configurations. This is now on `LLMModelFactory` (and `VLMModelFactory` has the same):

```swift
LLMModelFactory.shared.modelRegistry.register(configurations: [modelConfiguration])
```
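As a sketch of the VLM side of the factory API: the configuration constant comes from `MLXVLM.ModelRegistry`, and the trailing closure parameter is assumed from the diff above, so treat `progress` as illustrative.

```swift
import MLXLMCommon
import MLXVLM

// Load a VLM with the new factory; the API mirrors LLMModelFactory.
let modelContainer = try await VLMModelFactory.shared.loadContainer(
    configuration: MLXVLM.ModelRegistry.paligemma3bMix448_8bit
) { progress in
    // download / load progress reporting (closure parameter assumed)
    print("loading: \(progress)")
}
```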

Deprecations

An example at the end shows all of these deprecations in context.

**Prefer to use the `ModelContext.processor` to prepare prompts.** Previously users would pass in a bare `[Int]` of tokens, but in order to support more complex inputs (VLMs) the use of bare `[Int]` is deprecated and callers should use `UserInput` and `LMInput`.

For example, previously callers might have done something like this:

```swift
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
    try tokenizer.applyChatTemplate(messages: messages)
}
```

Now that should be:

```swift
let input = try await context.processor.prepare(input: .init(prompt: prompt))
```

Which will initialize a `UserInput` from the prompt text and produce an `LMInput` that can be used to generate tokens.

**This call to `generate()` is now deprecated:**

```swift
public func generate(
    promptTokens: [Int], parameters: GenerateParameters, model: any LanguageModel,
    tokenizer: Tokenizer,
    extraEOSTokens: Set<String>? = nil,
    didGenerate: ([Int]) -> GenerateDisposition
) throws -> GenerateResult
```

This consumed the `[Int]` variety of tokens. Now this is preferred:

```swift
public func generate(
    input: LMInput, parameters: GenerateParameters, context: ModelContext,
    didGenerate: ([Int]) -> GenerateDisposition
) throws -> GenerateResult
```

**This method on `ModelContainer` is now deprecated:**

```swift
    /// Perform an action on the model and/or tokenizer.  Callers _must_ eval any `MLXArray` before returning as
    /// `MLXArray` is not `Sendable`.
    @available(*, deprecated, message: "prefer perform(_:) that uses a ModelContext")
    public func perform<R>(_ action: @Sendable (any LanguageModel, Tokenizer) throws -> R) rethrows
        -> R
```

Use this one instead (though the former still works):

```swift
    /// Perform an action on the ``ModelContext``.  Callers _must_ eval any `MLXArray` before returning as
    /// `MLXArray` is not `Sendable`.
    public func perform<R>(_ action: @Sendable (ModelContext) async throws -> R) async rethrows -> R
```

Example

Putting all of these deprecations together, previously you might have generated text like this:

```swift
let messages = [["role": "user", "content": prompt]]
let promptTokens = try await modelContainer.perform { _, tokenizer in
    try tokenizer.applyChatTemplate(messages: messages)
}

let result = await modelContainer.perform { model, tokenizer in
    LLM.generate(
        promptTokens: promptTokens, parameters: generateParameters, model: model,
        tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
    ) { tokens in ... }
}
```

now do this:

```swift
let result = try await modelContainer.perform { context in
    let input = try await context.processor.prepare(input: .init(prompt: prompt))
    return try MLXLMCommon.generate(
        input: input, parameters: generateParameters, context: context
    ) { tokens in ... }
}
```
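The VLM flow is identical apart from the `UserInput`. A sketch combining the pieces above (the image URL here is hypothetical):

```swift
let result = try await modelContainer.perform { context in
    let input = try await context.processor.prepare(
        input: .init(prompt: "describe the image", images: [.url(imageURL)]))
    return try MLXLMCommon.generate(
        input: input, parameters: generateParameters, context: context
    ) { tokens in
        return .more
    }
}
```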

@@ -1,30 +1,7 @@
// Copyright © 2024 Apple Inc.

import Foundation

public enum StringOrNumber: Codable, Equatable, Sendable {
davidkoski (Collaborator Author):

move to LMCommon


/// Container for models that guarantees single threaded access.
davidkoski (Collaborator Author):

Move to ModelContainer

}
}
}
// TODO move? these cause some ambiguity -- how to resolve?
davidkoski (Collaborator Author):

I was playing around with these to avoid breaking API -- moving types into LMCommon means callers will need to import LMCommon if they refer to them. This (the aliases) caused more trouble than I think it is worth

@@ -3,6 +3,7 @@
import Foundation
@preconcurrency import Hub
import MLX
import MLXLMCommon
import MLXNN
import MLXRandom
import Tokenizers
davidkoski (Collaborator Author):

Ultimately I would like this to move into LMCommon -- I think it can support both LLM and VLM models, but I didn't get a chance to move this yet.

import MLXNN
import MLXOptimizers
import MLXRandom
import Tokenizers

/// Layers to apply LoRA adapters to.
davidkoski (Collaborator Author):

Move to LMCommon

return y + scale * z
}
}

/// Equivalent to `lora.py/iterate_batches()`. Used internally by ``LoRATrain``.
struct LoRABatchIterator: Sequence, IteratorProtocol {
davidkoski (Collaborator Author):

Ideally the rest of this moves to LMCommon as well -- I think it can.

mutating func prompt(_ prompt: MLXArray)
func process(logits: MLXArray) -> MLXArray
mutating func didSample(token: MLXArray)
}
davidkoski (Collaborator Author):

The generate / step code has been refactored a bit and can now take custom logit samplers and processors
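A minimal sketch of what a custom processor could look like, assuming a protocol with the shape quoted above (the protocol name and the type here are hypothetical, and temperature scaling is just a simple illustration):

```swift
import MLX

// Scales logits before sampling. Conformance commented out because the
// protocol name is not shown in this diff excerpt.
struct TemperatureScaler /*: LogitProcessor */ {
    let temperature: Float

    mutating func prompt(_ prompt: MLXArray) {
        // no prompt-dependent state needed for this example
    }

    func process(logits: MLXArray) -> MLXArray {
        // divide by temperature; returns a new MLXArray
        logits / temperature
    }

    mutating func didSample(token: MLXArray) {
        // no per-token state to update in this example
    }
}
```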

public init(
prompt: MLXArray, model: any LanguageModel, cache: [KVCache]? = nil,
parameters: GenerateParameters
) throws {
davidkoski (Collaborator Author):

This now takes either a prompt (MLXArray) or an LMInput (text + image + ...) via multiple initializers.

}
}

public struct LMInput {
davidkoski (Collaborator Author):

A new union type that holds the different inputs to generate() and LanguageModel.prepare()

}
}

public struct LMOutput {
davidkoski (Collaborator Author):

Union type for the output. Some of the VLMs return additional state, which is represented here.

@@ -134,6 +135,7 @@ extension ModelConfiguration {
extraEOSTokens: ["<|end|>"]
)

// TODO the prompt formatter is replaced by the chat template
davidkoski (Collaborator Author):

Or is it? #150


import CoreImage
import Foundation
import MLX
davidkoski (Collaborator Author):

This file may be deleted -- it was some notes & thoughts along the way

// Copyright © 2024 Apple Inc.

import Foundation
import MLX
davidkoski (Collaborator Author):

Also to be deleted -- LMInput replaces this

private let context = CIContext()

// TODO documentation
public enum MediaProcessing {
davidkoski (Collaborator Author):

Needs documentation, but see PaliGemmaImageProvider which implements

```
SiglipImageProcessor {
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "SiglipImageProcessor",
  "image_seq_length": 1024,
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "processor_class": "PaliGemmaProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 448,
    "width": 448
  }
}
```

from the python transformers code.
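A rough sketch of how those settings map onto the `MediaProcessing` calls, based on the PaliGemma preprocessing that appears later in this diff (the function name and the `CIImage` types here are assumptions for illustration):

```swift
import CoreImage

// Apply the SiglipImageProcessor settings shown above:
// resize to 448x448 with bicubic resampling, then normalize with mean/std 0.5.
func preprocess(_ input: CIImage) -> CIImage {
    var image = MediaProcessing.inSRGBToneCurveSpace(input)
    image = MediaProcessing.resampleBicubic(image, to: .init(width: 448, height: 448))
    image = MediaProcessing.normalize(image, mean: (0.5, 0.5, 0.5), std: (0.5, 0.5, 0.5))
    return image
}
```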

import MLXNN
import Tokenizers

// MARK: - Language
davidkoski (Collaborator Author):

Note: this builds, loads weights and "runs" but doesn't produce any output -- still needs to be debugged.

davidkoski (Collaborator Author):

it should be usable as an example of the structure I think we need

}
}

// TODO does not suport multiple images -- how do we represent?
davidkoski (Collaborator Author):

We need a protocol for the image and text processing pieces.

image = MediaProcessing.inSRGBToneCurveSpace(image)

image = MediaProcessing.resampleBicubic(image, to: .init(width: size, height: size))
image = MediaProcessing.normalize(image, mean: (0.5, 0.5, 0.5), std: (0.5, 0.5, 0.5))
davidkoski (Collaborator Author), Nov 1, 2024:

```
SiglipImageProcessor {
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "SiglipImageProcessor",
  "image_seq_length": 1024,
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "processor_class": "PaliGemmaProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 448,
    "width": 448
  }
}
```

}
}

private func loadConfiguration(url: URL) throws -> PaliGemma {
davidkoski (Collaborator Author):

These next couple of functions are just stubs to let me try it out -- this will work much like the LLM models

private let _ropeTheta: Float?
public var ropeTheta: Float { _ropeTheta ?? 10_000 }
public let _ropeTraditional: Bool?
public var ropeTraditional: Bool { _ropeTraditional ?? false }
davidkoski (Collaborator Author), Nov 1, 2024:

Rather than doing the full implementation of Codable I went a simpler route for default values. Less code, cleaner (I think)
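A sketch of that pattern in a config struct; the key names and default values here are illustrative, not the actual PaliGemma configuration:

```swift
public struct ExampleTextConfiguration: Codable, Sendable {
    // Optional stored properties decode the key only if it is present...
    private let _ropeTheta: Float?
    private let _ropeTraditional: Bool?

    // ...and public accessors supply the default when the key was absent.
    public var ropeTheta: Float { _ropeTheta ?? 10_000 }
    public var ropeTraditional: Bool { _ropeTraditional ?? false }

    enum CodingKeys: String, CodingKey {
        case _ropeTheta = "rope_theta"
        case _ropeTraditional = "rope_traditional"
    }
}
```

This avoids writing a full `init(from:)` just to provide defaults for missing JSON keys.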

@Option var path: URL

@MainActor
mutating func run() async throws {
davidkoski (Collaborator Author):

Just stub code to exercise the model. This still needs the input processing layers, in particular the prompt processing. The image processing is in place but will need to be wrapped up API-wise.

davidkoski (Collaborator Author):

This is now the real code

Base automatically changed from v0.18.1 to main November 6, 2024 23:40
@davidkoski davidkoski force-pushed the vlm1 branch 2 times, most recently from e19f736 to 5ffe9b3 on November 19, 2024 16:15
@davidkoski davidkoski changed the title initial commit of vlm add VLM support, refactor common LM code into MLXLMCommon. breaking API changes Dec 4, 2024
@davidkoski davidkoski requested a review from awni December 4, 2024 19:16
import MLX
import MLXLLM
import MLXLMCommon
davidkoski (Collaborator Author):

See PR description -- split LLM -> LLM and LMCommon. Switched local names to match what people get via swiftpm (MLXLLM, etc.).

@@ -159,7 +160,7 @@ class LLMEvaluator {

/// This controls which model loads. `phi3_5_4bit` is one of the smaller ones, so this will fit on
/// more devices.
let modelConfiguration = ModelConfiguration.phi3_5_4bit
let modelConfiguration = ModelRegistry.phi3_5_4bit
davidkoski (Collaborator Author):

From the PR description:

- constants for models have moved from `ModelConfiguration` to `ModelRegistry`
- this is `MLXLLM.ModelRegistry` and there is also `MLXVLM.ModelRegistry`

```diff
-    let modelConfiguration = ModelConfiguration.phi3_5_4bit
+    let modelConfiguration = ModelRegistry.phi3_5_4bit
```

davidkoski (Collaborator Author):

This code is ready for review!

awni (Member) left a comment:

This is incredibly cool. I barely touched the surface but leaving a small review and going to try running it shortly.

structure something like this:

```swift
public class YourModel: Module, LLMModel, KVCacheDimensionProvider, LoRAModel {
```
awni (Member):

Btw I changed the KV cache implementation in mlx-lm to just init the keys and values the first time you call it. There is no need to initialize the KV cache with a head dim etc. so we could probably remove this interface as well. (Just a comment not something that we need to update in this PR)

davidkoski (Collaborator Author):

OK, I will take a look at it -- if it simplifies things it may be worth including here as we are already making some breaking changes.

davidkoski (Collaborator Author), Dec 9, 2024:

  • revisit KVCache / mlx-lm

Comment on lines 89 to 90
public let kvHeads: [Int]
public let headDim: IntOrPair
awni (Member):

And e.g. got rid of this which is not necessary

let (modelContainer, modelConfiguration) = try await memory.start(args.load)
let modelContainer = try await memory.start { [args] in
try await args.load(
defaultModel: "mlx-community/Mistral-7B-v0.1-hf-4bit-mlx",
awni (Member):

We should update this default model, it's pretty dated. Maybe mlx-community/Mistral-7B-Instruct-v0.3-4bit is a good option?

davidkoski (Collaborator Author), Dec 9, 2024:

Sure, I will give it a run and make sure it works!

  • test this

davidkoski (Collaborator Author):

It is one of the preset models, so good to go

@@ -203,29 +206,88 @@ struct EvaluateCommand: AsyncParsableCommand {

@MainActor
mutating func run() async throws {
let (modelContainer, modelConfiguration) = try await memory.start(args.load)
let modelContainer = try await memory.start { [args] in
awni (Member), Dec 9, 2024:

Can we rename this to LMCommand and subcommand lm, to match the VLMCommand?

Alternatively (given the complexity) it might be worth using the same subcommand and just dispatching to the vlm subroutine depending on whether an image input is provided or not.

davidkoski (Collaborator Author), Dec 9, 2024:

Interesting idea! The default model is different, as is the model factory. We could certainly switch on the presence of an image (or video) to choose, but I wonder if that complicates things over just having the two subcommands?

Let me try the refactor to fold these down into one and see if that looks reasonable.

  • try refactor of vlm -> eval (lm) command

awni (Member):

Yes it was a slightly off the cuff suggestion. It simplifies the command line but it might not be worth doing at the expense of code complexity.

davidkoski (Collaborator Author):

I think that worked well -- it came down to this (mostly):

```swift
// switch between LLM and VLM
let vlm = image.count > 0
if vlm {
    modelFactory = VLMModelFactory.shared
    defaultModel = MLXVLM.ModelRegistry.paligemma3bMix448_8bit
} else {
    modelFactory = LLMModelFactory.shared
    defaultModel = MLXLLM.ModelRegistry.mistral7B4bit
}
```
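A sketch of how the selected factory could then be used, assuming both factories share the `loadContainer(configuration:)` entry point described in the PR description (the `--model` argument handling is elided and `defaultModel` is used directly here for illustration):

```swift
// Either factory loads a ModelContainer the same way.
let modelContainer = try await modelFactory.loadContainer(configuration: defaultModel)
```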

Comment on lines 14 to 30
/// ```swift
/// let messages = [["role": "user", "content": prompt]]
/// let promptTokens = try await modelContainer.perform { context in
/// try context.tokenizer.applyChatTemplate(messages: messages)
/// }
/// ```
///
/// or:
///
/// ```swift
/// let result = await modelContainer.perform { context in
/// LLM.generate(
/// promptTokens: promptTokens, parameters: generateParameters, model: context.model,
/// tokenizer: context.tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens
/// ) { tokens in
/// ...
/// }
awni (Member):

Is this comment outdated?

davidkoski (Collaborator Author):

yes, thanks for spotting that!

Comment on lines +552 to +556
let inputEmbedding = languageModel.model.embedTokens(inputIds)
let (hiddenState, _, _) = self.visionModel(
pixelValues.transposed(0, 2, 3, 1).asType(inputEmbedding.dtype),
outputHiddenStates: true
)
awni (Member):

We have to be pretty careful with data types in these models cause it's really easy to upcast to fp32 by accident and that can slow things down a lot or use a lot more memory (or both).

One thing I recommend doing is if you have a test suite that runs the models, making sure the output type is the same as the input type.

Here you cast the pixelValues to the embedding type, which is good. But below you cast the output back to the pixelValues type, which I'm not sure about. I would just keep those in the same model type.

davidkoski (Collaborator Author):

Good spot on that!

inputEmbedding float16, hiddenState float32, pixelValues float32
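A sketch of the kind of dtype check suggested above for a model test (the helper and its use are illustrative, not the actual test suite):

```swift
import MLX
import XCTest

// Assert that running the model did not silently upcast to float32.
func assertSameDType(_ output: MLXArray, as input: MLXArray,
                     file: StaticString = #file, line: UInt = #line) {
    XCTAssertEqual(output.dtype, input.dtype, file: file, line: line)
}
```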

let embedDimension = imageFeatures.dim(2)
let (batchSize, sequenceLength) = inputIds.shape2
var scaledImageFeatures = imageFeatures / pow(Float(config.hiddenSize), 0.5)
var finalEmbedding = zeros([batchSize, sequenceLength, embedDimension])
awni (Member):

The default data type of zeros is fp32. That will cause anything that works with this finalEmbedding to be upcast to fp32.

davidkoski (Collaborator Author):

done

Comment on lines +614 to +615
let (inputEmbedding, finalAttentionMask4d) = inputEmbeddings(
inputIds: inputIds, pixelValues: image.pixels, mask: mask)
awni (Member):

We might want to cast the inputEmbedding to the LM dtype as well (get it from the embedding layer weight or something), just in case they have different types.

davidkoski (Collaborator Author):

handled inside the inputEmbeddings function:

```swift
private func inputEmbeddings(inputIds: MLXArray, pixelValues: MLXArray?, mask: MLXArray) -> (
    MLXArray, MLXArray
) {
    guard let pixelValues else {
        return (inputIds, mask)
    }

    let inputEmbedding = languageModel.model.embedTokens(inputIds)
    let (hiddenState, _, _) = self.visionModel(
        pixelValues.transposed(0, 2, 3, 1).asType(inputEmbedding.dtype),
```

imageMaskExpanded = repeated(imageMaskExpanded, count: embedDimension, axis: -1)
finalEmbedding = which(imageMaskExpanded, scaledImageFeatures, finalEmbedding)

finalEmbedding = which(padMaskExpanded, zeros(like: finalEmbedding), finalEmbedding)
awni (Member):

In python it's better to do `mx.where(mask, array, 0.0)`, since the 0 will be broadcast and inherit the type of array. I think the same is true in Swift?

davidkoski (Collaborator Author), Dec 10, 2024:

yes, to avoid the zeros float32 (and maybe faster to boot because of the broadcasting instead of a realized array). done
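A sketch of the resulting pattern (assuming `which` in mlx-swift accepts a scalar argument that broadcasts, as the exchange above suggests):

```swift
// the scalar 0 broadcasts against finalEmbedding and inherits its dtype,
// avoiding a separately realized zeros array
finalEmbedding = which(padMaskExpanded, 0, finalEmbedding)
```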


// insert image embeddings - the image mask is always less or equal to the sentence in length
var imageMaskExpanded = expandedDimensions(imageMask, axis: -1)
imageMaskExpanded = repeated(imageMaskExpanded, count: embedDimension, axis: -1)
awni (Member):

There is no need to explicitly repeat these. Just rely on the fact that which broadcasts its inputs against one another. The same is true for most of the calls to repeated above.

davidkoski (Collaborator Author):

wow, went from ~92 tokens / s -> 112 tokens / s


// insert padding and text token embeddings
finalEmbedding = which(textMaskExpanded, inputEmbedding, finalEmbedding)
finalEmbedding = which(padMaskExpanded, zeros(like: finalEmbedding), finalEmbedding)
awni (Member), Dec 9, 2024:

This zeros also should be a plain scalar and inherit the type of the finalEmbedding.

awni (Member) left a comment:

Massive! Thanks for adding this!

@davidkoski davidkoski merged commit 6ef303b into main Dec 10, 2024
1 check passed
@davidkoski davidkoski deleted the vlm1 branch December 10, 2024 19:00