qwen2-vl working

davidkoski committed Dec 2, 2024
1 parent daaee43 commit 36f6433
Showing 6 changed files with 120 additions and 50 deletions.
2 changes: 2 additions & 0 deletions Libraries/LMCommon/LanguageModel.swift
@@ -20,6 +20,8 @@ public struct THW: Sendable {
public var values: (Int, Int, Int) {
(t, h, w)
}

public var product: Int { t * h * w }
}

extension Array where Element == THW {
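The new product property is what sizes the <|image_pad|> runs in the Qwen2-VL processor further down. A minimal illustration, with a hypothetical grid, using the unlabeled initializer that preprocess uses below:

    // one temporal step of a 36 x 36 patch grid
    let grid = THW(1, 36, 36)
    let padCount = grid.product / 4  // 1296 / 4 = 324 <|image_pad|> tokens at mergeSize 2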
Libraries/LMCommon/UserInput.swift
@@ -4,21 +4,21 @@ import CoreImage
import Foundation
import MLX

public enum UserInputPrompt: Sendable {
case text(String)
case messages([[String: String]])

public func asMessages() -> [[String: String]] {
switch self {
case .text(let text):
return [["role": "user", "content": text]]
case .messages(let messages):
return messages
public struct UserInput: Sendable {

public enum Prompt: Sendable {
case text(String)
case messages([[String: String]])

public func asMessages() -> [[String: String]] {
switch self {
case .text(let text):
return [["role": "user", "content": text]]
case .messages(let messages):
return messages
}
}
}
}

public struct UserInput: Sendable {

public enum Image: Sendable {
case ciImage(CIImage)
@@ -83,8 +83,13 @@ public struct UserInput: Sendable {
}
}

public var prompt: UserInputPrompt
public struct Processing: Sendable {
public var resize: CGSize?
}

public var prompt: Prompt
public var images = [Image]()
public var processing: Processing = .init()

public init(prompt: String, images: [Image] = [Image]()) {
self.prompt = .text(prompt)
@@ -96,7 +101,7 @@
self.images = images
}

public init(prompt: UserInputPrompt, images: [Image] = [Image]()) {
public init(prompt: Prompt, images: [Image] = [Image]()) {
self.prompt = prompt
self.images = images
}
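Taken together, nesting Prompt inside UserInput and adding Processing gives callers a single value that carries the prompt, the images, and per-request preprocessing hints. A minimal sketch of the new API (the file path is hypothetical):

    var input = UserInput(
        prompt: .messages([["role": "user", "content": "Describe this image."]]),
        images: [.url(URL(fileURLWithPath: "/tmp/photo.jpg"))])
    input.processing.resize = CGSize(width: 448, height: 448)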
14 changes: 13 additions & 1 deletion Libraries/VLM/MediaProcessing.swift
@@ -1,7 +1,8 @@
// Copyright © 2024 Apple Inc.

@preconcurrency import CoreImage.CIFilterBuiltins
import CoreImage.CIFilterBuiltins
import MLX
import MLXLMCommon

private let context = CIContext()

@@ -111,4 +112,15 @@ public enum MediaProcessing {

return array
}

static func apply(_ image: CIImage, processing: UserInput.Processing?) -> CIImage {
var image = image

if let resize = processing?.resize {
let scale = bestFitScale(image.extent.size, in: resize)
image = image.transformed(by: CGAffineTransform(scaleX: scale, y: scale))
}

return image
}
}
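Because apply scales both axes by the same factor, a resize request is a best-fit within the target size rather than an exact output size: aspect ratio is preserved. bestFitScale itself is collapsed out of this diff; one plausible shape for it (an assumption, not the committed code):

    // assumed: the uniform scale that fits `size` inside `target`
    static func bestFitScale(_ size: CGSize, in target: CGSize) -> CGFloat {
        min(target.width / size.width, target.height / size.height)
    }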
9 changes: 6 additions & 3 deletions Libraries/VLM/Models/Paligemma.swift
@@ -450,7 +450,7 @@ public class PaligGemmaProcessor: UserInputProcessor {
self.tokenizer = tokenizer
}

public func convert(image: CIImage) -> MLXArray {
private func prepare(image: CIImage, processing: UserInput.Processing?) -> MLXArray {
// based on image_processing_siglip from transformers
var image = image

@@ -459,6 +459,9 @@
// do (implicitly by using sRGB rasters directly)
image = MediaProcessing.inSRGBToneCurveSpace(image)

// apply user instructions
image = MediaProcessing.apply(image, processing: processing)

image = MediaProcessing.resampleBicubic(image, to: config.size.cgSize)
image = MediaProcessing.normalize(
image, mean: config.imageMeanTuple, std: config.imageStdTuple)
@@ -473,7 +476,7 @@
default: throw VLMError.singleImageAllowed
}

// this doesn't have a chat template so just use the last message
// this doesn't have a chat template so just use the last message.
var prompt = input.prompt.asMessages().last?["content"] ?? ""

// based on transformers/processing_paligemma
@@ -486,7 +489,7 @@
let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
let mask = ones(like: promptArray)

let pixels = try convert(image: input.images[0].asCIImage())
let pixels = try prepare(image: input.images[0].asCIImage(), processing: input.processing)

return LMInput(text: .init(tokens: promptArray, mask: mask), image: .init(pixels: pixels))
}
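With the plumbing above, a user resize now flows from UserInput through prepare before the model-mandated bicubic resample. An end-to-end sketch, assuming an already-constructed processor and a hypothetical image path:

    var input = UserInput(
        prompt: "describe the image",
        images: [.url(URL(fileURLWithPath: "/tmp/photo.jpg"))])
    input.processing.resize = CGSize(width: 1024, height: 1024)
    let lmInput = try processor.prepare(input: input)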
93 changes: 64 additions & 29 deletions Libraries/VLM/Models/Qwen2VL.swift
@@ -283,11 +283,8 @@ private enum Vision {
}

func callAsFunction(sequenceLength: Int) -> MLXArray {
let inverseFreq =
1.0
/ (pow(
theta,
MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions))
let p = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions
let inverseFreq = 1.0 / pow(theta, p)
let seq = MLXArray(0 ..< sequenceLength).asType(inverseFreq.dtype)
let freqs = outer(seq, inverseFreq)
return freqs
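The rewrite is behavior-preserving: it only names the exponent vector p before building the usual RoPE table freqs[s, i] = s / theta^(2i / d). A worked example with assumed values dimensions = 4, theta = 10000:

    // p           = [0/4, 2/4]               = [0.0, 0.5]
    // inverseFreq = [1/10000^0, 1/10000^0.5] = [1.0, 0.01]
    // freqs[s]    = [s * 1.0, s * 0.01] for each position s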
@@ -370,16 +367,6 @@ private enum Vision {
self._proj.wrappedValue = Linear(dims, dims)
}

private func makeMask(cuSequenceLengths: MLXArray, sequenceLength: Int) -> MLXArray {
let starts = cuSequenceLengths[.newAxis, ..<(-1)]
let ends = cuSequenceLengths[.newAxis, 1...]
let indices = MLXArray(0 ..< sequenceLength)[0..., .newAxis]
var mask = (indices .>= starts) & (indices .< ends)
mask = mask.any(axis: -1)
mask = mask[.newAxis] & mask[0..., .newAxis]
return 1 - mask
}

public func callAsFunction(
_ x: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
) -> MLXArray {
@@ -391,6 +378,10 @@
let s = split(qkv, parts: 3, axis: 1)
var (q, k, v) = (s[0], s[1], s[2])

q = q.reshaped(sequenceLength, numHeads, -1)
k = k.reshaped(sequenceLength, numHeads, -1)
v = v.reshaped(sequenceLength, numHeads, -1)

q = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding)
k = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding)

@@ -530,8 +521,6 @@ private enum Vision {
let rotaryPositionEmbedFull = rotaryPositionEmbedding(sequenceLength: maxGridSize)[
indices]

print("rot_pos_emb(), \(maxGridSize) \(gridThw), \(rotaryPositionEmbedFull.shape)")

return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1)
}

@@ -634,16 +623,20 @@ public class Qwen2VLProcessor: UserInputProcessor {
return (hBar, wBar)
}

public func preprocess(images: [CIImage]) throws -> (MLXArray, THW) {
public func preprocess(images: [CIImage], processing: UserInput.Processing?) throws -> (
MLXArray, THW
) {

// image_processing_qwen2_vl._preprocess

let images = images.map { MediaProcessing.apply($0, processing: processing) }

let size = images[0].extent.size
let (resizedHeight, resizedWidth) = try targetSize(
height: Int(size.height), width: Int(size.width),
factor: config.patchSize * config.mergeSize,
minPixels: config.size.minPixels, maxPixels: config.size.maxPixels)
let resizedSize = CGSize(width: resizedHeight, height: resizedWidth)
let resizedSize = CGSize(width: resizedWidth, height: resizedHeight)

let processedImages =
try images
@@ -691,26 +684,68 @@ public class Qwen2VLProcessor: UserInputProcessor {
return (flattenedPatches, .init(gridT, gridH, gridW))
}
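For a feel for the arithmetic: targetSize snaps each side to a multiple of factor = patchSize * mergeSize while keeping the total pixel count within [minPixels, maxPixels]. A worked example assuming the usual Qwen2-VL values patchSize = 14 and mergeSize = 2 (so factor = 28) and a hypothetical 1000 x 750 input:

    let factor = 14 * 2  // 28
    let wBar = Int((1000.0 / 28).rounded()) * 28  // 36 * 28 = 1008
    let hBar = Int((750.0 / 28).rounded()) * 28   // 27 * 28 = 756
    // 1008 * 756 = 762048 pixels sits inside [minPixels, maxPixels], so no
    // further scaling; the patch grid is (1008 / 14) x (756 / 14) = 72 x 54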

public func prepare(input: UserInput) throws -> LMInput {
// this doesn't have a chat template so just use the last message
let prompt = input.prompt.asMessages().last?["content"] ?? ""
public func prepare(prompt: UserInput.Prompt, imageTHW: [THW]?) -> String {
// the tokenizer does have a chat template and it expects messages
// like this:
//
// [{'role': 'user', 'content': [{'type': 'text', 'text': 'What are these?'},
// {'type': 'image'}, {'type': 'image'}, {'type': 'image'}]}]
//
// The output of the prompt template is fed into
// image_processing_qwen2_vl.preprocess where it is further augmented
// by replacing tokens according to imageTHW.
//
// Neither the structured content nor the postprocessing of the template
// is supported by the current Tokenizer/Jinja (Swift), so handle that here.

var messages = prompt.asMessages()
if messages[0]["role"] != "system" {
messages.insert(["role": "system", "content": "You are a helpful assistant."], at: 0)
}

let lastIndex = messages.count - 1
var lastMessage = messages[lastIndex]["content"] ?? ""

// image_processing_qwen2_vl.preprocess -- inject image_pad tokens for each image
let mergeLength = config.mergeSize * config.mergeSize
for thw in imageTHW ?? [] {
lastMessage += "<|vision_start|>"
lastMessage += Array(repeating: "<|image_pad|>", count: thw.product / mergeLength)
.joined()
lastMessage += "<|vision_end|>"
}

messages[lastIndex]["content"] = lastMessage

return
messages
.map {
"<|im_start|>\($0["role"] ?? "user")\n\($0["content"] ?? "")<|im_end|>"
}
.joined(separator: "\n")
+ "\n<|im_start|>assistant\n"
}
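Concretely: for a single image whose grid came back as THW(1, 36, 36), mergeLength is 4 and 1296 / 4 = 324 pad tokens are injected, so the rendered prompt looks like this (an illustrative trace, not captured output):

    <|im_start|>system
    You are a helpful assistant.<|im_end|>
    <|im_start|>user
    What is this?<|vision_start|><|image_pad|>…324 repeats…<|vision_end|><|im_end|>
    <|im_start|>assistant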

public func prepare(input: UserInput) throws -> LMInput {
if input.images.isEmpty {
// just a straight text prompt
let prompt = prepare(prompt: input.prompt, imageTHW: nil)
let promptTokens = try tokenizer.encode(text: prompt)
return LMInput(tokens: MLXArray(promptTokens))
}

let promptTokens = try tokenizer.encode(text: prompt)
let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
let mask = ones(like: promptArray)

// image_processing_qwen2_vl.preprocess
let images = try input.images.map { try preprocess(images: [$0.asCIImage()]) }
let images = try input.images.map {
try preprocess(images: [$0.asCIImage()], processing: input.processing)
}
let pixels = concatenated(images.map { $0.0 })
let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 })

print("image \(image.pixels.shape), \(image.imageGridThw)")
// processing_qwen2_vl.Qwen2VLProcessor
let prompt = prepare(prompt: input.prompt, imageTHW: image.imageGridThw)
let promptTokens = try tokenizer.encode(text: prompt)
let promptArray = MLXArray(promptTokens).expandedDimensions(axis: 0)
let mask = ones(like: promptArray)

return LMInput(text: .init(tokens: promptArray, mask: mask), image: image)
}
@@ -773,7 +808,7 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
imageIndices.append(i)
}
}
// TODO look at the inputIds -- I think I am missing something here

inputEmbeds[0..., MLXArray(imageIndices), 0...] = imageFeatures
return inputEmbeds
}
17 changes: 15 additions & 2 deletions Tools/llm-tool/LLMTool.swift
@@ -258,14 +258,27 @@ struct VLMCommand: AsyncParsableCommand {
mutating func run() async throws {
let modelContainer = try await memory.start { [args] in
try await args.load(
defaultModel: "mlx-community/paligemma-3b-mix-448-8bit",
defaultModel: MLXVLM.ModelRegistry.paligemma3bMix4488bit.name,
modelFactory: VLMModelFactory.shared)
}
let modelConfiguration = modelContainer.configuration

let prompt = generate.prompt ?? modelConfiguration.defaultPrompt

let input = UserInput(prompt: prompt, images: image.map { .url($0) })
var input = UserInput(prompt: prompt, images: image.map { .url($0) })

if !resize.isEmpty {
let size: CGSize
if resize.count == 1 {
let v = resize[0]
size = CGSize(width: v, height: v)
} else {
let v0 = resize[0]
let v1 = resize[1]
size = CGSize(width: v0, height: v1)
}
input.processing.resize = size
}

let result = try await modelContainer.perform { [generate] context in
let input = try context.processor.prepare(input: input)
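For reference, the semantics of the resize values as consumed above (the resize option itself is declared in a collapsed part of this file; the example values are hypothetical):

    // resize == [512]       -> CGSize(width: 512, height: 512)  (square fit)
    // resize == [640, 480]  -> CGSize(width: 640, height: 480)  (width, then height)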
