Skip to content

Commit

Permalink
qwen2 image processing
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkoski committed Dec 2, 2024
1 parent 501666a commit daaee43
Showing 1 changed file with 27 additions and 114 deletions.
141 changes: 27 additions & 114 deletions Libraries/VLM/Models/Qwen2VL.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,55 +25,6 @@ private func rotateHalf(_ x: MLXArray) -> MLXArray {

private enum Language {

/// Rotary position embedding for the Qwen2 language model.
///
/// Precomputes cos/sin tables up to `maxPositionEmbeddings` positions and
/// lazily regrows them whenever a longer sequence is requested.
fileprivate class Qwen2RotaryEmbedding {

    private let dimensions: Int
    private let maxPositionEmbeddings: Int
    private let base: Float

    // Per-dimension inverse frequencies: 1 / base^(2i / dimensions) for even i.
    private let inverseFreq: MLXArray

    // Cached tables; regenerated only when a longer sequence is seen.
    private var tableLength = 0
    private var cosTable = MLXArray(0)
    private var sinTable = MLXArray(0)

    /// - Parameters:
    ///   - dimensions: size of the rotary embedding (head dimension)
    ///   - maxPositionEmbeddings: initial table length to precompute
    ///   - base: rope theta (frequency base)
    public init(dimensions: Int, maxPositionEmbeddings: Int, base: Float) {
        self.dimensions = dimensions
        self.maxPositionEmbeddings = maxPositionEmbeddings
        self.base = base

        let exponents = MLXArray(stride(from: 0, to: dimensions, by: 2)).asType(.float32) / dimensions
        self.inverseFreq = 1.0 / pow(base, exponents)

        rebuildTables(length: maxPositionEmbeddings)
    }

    /// Regenerates the cos/sin tables for `length` positions.
    private func rebuildTables(length: Int) {
        tableLength = length

        let positions = MLXArray(0 ..< tableLength).asType(.float32)
        let freqs = outer(positions, inverseFreq)

        // Different from paper, but it uses a different permutation in order to obtain the same calculation
        let angles = concatenated([freqs, freqs], axis: -1)
        cosTable = cos(angles)
        sinTable = sin(angles)
    }

    /// Returns `(cos, sin)` tables of the first `sequenceLength` positions,
    /// cast to the dtype of `x`; grows the cache if needed.
    public func callAsFunction(_ x: MLXArray, sequenceLength: Int) -> (MLXArray, MLXArray) {
        if sequenceLength > tableLength {
            rebuildTables(length: sequenceLength)
        }

        let cosOut = cosTable[0 ..< sequenceLength].asType(x.dtype)
        let sinOut = sinTable[0 ..< sequenceLength].asType(x.dtype)
        return (cosOut, sinOut)
    }
}

/// Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors
static private func applyMultimodalRotaryPositionEmbedding(
q: MLXArray, k: MLXArray, cos: MLXArray, sin: MLXArray,
Expand Down Expand Up @@ -114,7 +65,7 @@ private enum Language {
@ModuleInfo(key: "v_proj") var wv: Linear
@ModuleInfo(key: "o_proj") var wo: Linear

let rotaryEmbedding: Qwen2RotaryEmbedding
@ModuleInfo(key: "rotary_emb") var rotaryEmbedding: RoPE

public init(_ args: Qwen2VLConfiguration.TextConfiguration) {
let dim = args.hiddenSize
Expand Down Expand Up @@ -143,9 +94,8 @@ private enum Language {
fatalError("rope_scaling['mrope_section'] must be an array of integers")
}

self.rotaryEmbedding = Qwen2RotaryEmbedding(
dimensions: headDim, maxPositionEmbeddings: args.maxpPositionEmbeddings,
base: args.ropeTheta)
self._rotaryEmbedding.wrappedValue = RoPE(
dimensions: headDim, traditional: args.ropeTraditional, base: args.ropeTheta)
}

public func callAsFunction(
Expand All @@ -162,30 +112,11 @@ private enum Language {
keys = keys.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3)
values = values.reshaped(B, L, kvHeads, headDim).transposed(0, 2, 1, 3)

var kvSequenceLength = keys.dim(-2)
var positionIds: MLXArray
if let cache {
kvSequenceLength += cache.offset + 1
positionIds = MLXArray(cache.offset ..< (cache.offset + L))
} else {
positionIds = MLXArray(0 ..< L)
}

positionIds = expandedDimensions(positionIds, axis: 0)
positionIds = tiled(positionIds, repetitions: [3, 1, 1])

let (cos, sin) = rotaryEmbedding(values, sequenceLength: kvSequenceLength)

let mask: MLXArray? =
if var mask {
mask[.newAxis, .newAxis, 0..., 0...][0..., 0..., 0..., ..<keys.dim(-2)]
} else {
nil
}
let offset = cache?.offset ?? 0
let mask = mask?[0..., 0 ..< keys.dim(-2)]

(queries, keys) = applyMultimodalRotaryPositionEmbedding(
q: queries, k: keys, cos: cos, sin: sin, positionIds: positionIds,
mropeSection: mropeSection)
queries = rotaryEmbedding(queries, offset: offset)
keys = rotaryEmbedding(keys, offset: offset)

if let cache {
(keys, values) = cache.update(keys: keys, values: values)
Expand Down Expand Up @@ -450,34 +381,25 @@ private enum Vision {
}

public func callAsFunction(
_ x: MLXArray, cuSequenceLengths: MLXArray, rotaryPositionEmbedding: MLXArray
_ x: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
) -> MLXArray {
let sequenceLength = x.dim(0)
let B = gridThw[0].t
let L = sequenceLength / B

let qkv = qkv(x)
.reshaped(sequenceLength, 3, self.numHeads, -1)
.transposed(1, 0, 2, 3)
let s = split(qkv, parts: 3)
let qkv = qkv(x).reshaped(sequenceLength, 3, -1)
let s = split(qkv, parts: 3, axis: 1)
var (q, k, v) = (s[0], s[1], s[2])

print("rotaryPositionEmbedding \(rotaryPositionEmbedding.shape)")

q =
applyMultimodalRotaryPositionEmbedding(
expandedDimensions(q, axis: 0), freqs: rotaryPositionEmbedding)[0]
k =
applyMultimodalRotaryPositionEmbedding(
expandedDimensions(k, axis: 0), freqs: rotaryPositionEmbedding)[0]

q = q.transposed(0, 2, 1, 3)
k = k.transposed(0, 2, 1, 3)
v = v.transposed(0, 2, 1, 3)
q = applyMultimodalRotaryPositionEmbedding(q, freqs: rotaryPositionEmbedding)
k = applyMultimodalRotaryPositionEmbedding(k, freqs: rotaryPositionEmbedding)

let mask = makeMask(
cuSequenceLengths: cuSequenceLengths, sequenceLength: sequenceLength)
q = q.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3)
k = k.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3)
v = v.reshaped(B, L, numHeads, -1).transposed(0, 2, 1, 3)

let output = MLXFast.scaledDotProductAttention(
queries: q, keys: k, values: v, scale: scale, mask: mask
queries: q, keys: k, values: v, scale: scale, mask: nil
)
.transposed(0, 2, 1, 3)
.reshaped(sequenceLength, -1)
Expand Down Expand Up @@ -523,13 +445,13 @@ private enum Vision {
}

func callAsFunction(
_ hiddenStates: MLXArray, cuSequenceLengths: MLXArray, rotaryPositionEmbedding: MLXArray
_ hiddenStates: MLXArray, gridThw: [THW], rotaryPositionEmbedding: MLXArray
) -> MLXArray {
var hiddenStates =
hiddenStates
+ attention(
norm1(hiddenStates),
cuSequenceLengths: cuSequenceLengths,
gridThw: gridThw,
rotaryPositionEmbedding: rotaryPositionEmbedding
)
hiddenStates = hiddenStates + mlp(norm2(hiddenStates))
Expand Down Expand Up @@ -619,22 +541,9 @@ private enum Vision {

let batchSize = gridThw.count

// Calculate cu_seqlens for each item in the batch
var collect = [MLXArray]()
for thw in gridThw {
let sequenceLength = thw.h * thw.w
collect.append(repeated(MLXArray(sequenceLength), count: thw.t))
}

// Concatenate the cu_seqlens for all items in the batch
var cuSeqLengths = concatenated(collect)

cuSeqLengths = cumsum(cuSeqLengths.asType(Int32.self), axis: 0)
cuSeqLengths = padded(cuSeqLengths, width: [1, 0], mode: .constant, value: MLXArray(0))

for block in blocks {
hiddenStates = block(
hiddenStates, cuSequenceLengths: cuSeqLengths,
hiddenStates, gridThw: gridThw,
rotaryPositionEmbedding: rotaryPositionEmbedding)
}

Expand Down Expand Up @@ -797,7 +706,7 @@ public class Qwen2VLProcessor: UserInputProcessor {
let mask = ones(like: promptArray)

// image_processing_qwen2_vl.preprocess
let images = try input.images.map { try preprocess(images: [$0]) }
let images = try input.images.map { try preprocess(images: [$0.asCIImage()]) }
let pixels = concatenated(images.map { $0.0 })
let image = LMInput.ProcessedImage(pixels: pixels, imageGridThw: images.map { $0.1 })

Expand Down Expand Up @@ -873,8 +782,12 @@ public class Qwen2VL: Module, VLMModel, KVCacheDimensionProvider {
-> PrepareResult
{
let gridThw = input.image?.imageGridThw

let dtype = visionModel.patchEmbed.proj.weight.dtype
let pixels = input.image?.pixels.asType(dtype)

let inputEmbeddings = self.inputEmbeddings(
inputIds: input.text.tokens, pixelValues: input.image?.pixels, gridThw: gridThw)
inputIds: input.text.tokens, pixelValues: pixels, gridThw: gridThw)

let result = languageModel(nil, cache: cache, inputEmbedding: inputEmbeddings)

Expand Down

0 comments on commit daaee43

Please sign in to comment.