diff --git a/Libraries/VLM/Models/Qwen2VL.swift b/Libraries/VLM/Models/Qwen2VL.swift
index 8c8f055..6bd85b1 100644
--- a/Libraries/VLM/Models/Qwen2VL.swift
+++ b/Libraries/VLM/Models/Qwen2VL.swift
@@ -460,140 +460,152 @@ private enum Vision {
         }
     }
 
-    fileprivate class PhiMLP: Module, UnaryLayer {
+    fileprivate class MLP: Module, UnaryLayer {
 
+        @ModuleInfo var activation: GELU
         @ModuleInfo var fc1: Linear
         @ModuleInfo var fc2: Linear
 
-        public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
-            self.fc1 = Linear(config.hiddenSize, config.intermediateSize, bias: true)
-            self.fc2 = Linear(config.intermediateSize, config.hiddenSize, bias: true)
+        public init(dimensions: Int, hiddenDimensions: Int) {
+            self.activation = GELU(approximation: .fast)
+            self.fc1 = Linear(dimensions, hiddenDimensions)
+            self.fc2 = Linear(hiddenDimensions, dimensions)
         }
 
         public func callAsFunction(_ x: MLXArray) -> MLXArray {
-            fc2(geluApproximate(fc1(x)))
-        }
-    }
-
-    fileprivate class EncoderLayer: Module {
-
-        @ModuleInfo(key: "self_attn") var attention: Attention
-        @ModuleInfo(key: "layer_norm1") var layerNorm1: LayerNorm
-        @ModuleInfo var mlp: PhiMLP
-        @ModuleInfo(key: "layer_norm2") var layerNorm2: LayerNorm
-
-        public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
-            self._attention.wrappedValue = Attention(
-                dims: config.hiddenSize, numHeads: config.attentionHeads, bias: true)
-            self._layerNorm1.wrappedValue = LayerNorm(
-                dimensions: config.hiddenSize, eps: config.layerNormEps)
-            self.mlp = PhiMLP(config)
-            self._layerNorm2.wrappedValue = LayerNorm(
-                dimensions: config.hiddenSize, eps: config.layerNormEps)
-        }
-
-        public func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil) -> MLXArray {
-            var r = attention(layerNorm1(x), mask: mask)
-            let h = x + r
-            r = mlp(layerNorm2(h))
-            return h + r
+            fc2(activation(fc1(x)))
         }
     }
-
-    fileprivate class Encoder: Module {
-        var layers: [EncoderLayer]
-
+
+    fileprivate class Qwen2VLVisionBlock: Module {
+
+        @ModuleInfo var norm1: LayerNorm
+        @ModuleInfo var norm2: LayerNorm
+        @ModuleInfo(key: "attn") var attention: Attention
+        @ModuleInfo var mlp: MLP
+
         public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
-            self.layers = (0 ..< config.hiddenLayers).map { _ in
-                EncoderLayer(config)
-            }
-        }
-
-        public func callAsFunction(
-            _ x: MLXArray, outputHiddenStates: Bool = false, mask: MLXArray? = nil
-        ) -> (MLXArray, [MLXArray]?) {
-            var encoderStates: [MLXArray]? = outputHiddenStates ? [] : nil
-            var h = x
-            var x = x
-            for l in layers {
-                x = l(x, mask: mask)
-                if outputHiddenStates {
-                    encoderStates?.append(x)
-                }
-                h = x[0]
-            }
-            return (h, encoderStates)
+            self.norm1 = LayerNorm(dimensions: config.embedDimensions, eps: 1e-6)
+            self.norm2 = LayerNorm(dimensions: config.embedDimensions, eps: 1e-6)
+
+            self._attention.wrappedValue = Attention(dims: config.embedDimensions, numHeads: config.numHeads)
+
+            let mlpHiddenDimensions = Int(Float(config.embedDimensions) * config.mlpRatio)
+            self.mlp = MLP(dimensions: config.embedDimensions, hiddenDimensions: mlpHiddenDimensions)
         }
-    }
-
-    fileprivate class VisionEmbeddings: Module, UnaryLayer {
-
-        @ModuleInfo(key: "patch_embedding") var patchEmbedding: Conv2d
-        @ModuleInfo(key: "position_embedding") var positionEmbedding: Embedding
-
-        let positions: Int
-        let positionIds: MLXArray
-
-        public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
-            self._patchEmbedding.wrappedValue = Conv2d(
-                inputChannels: config.channels, outputChannels: config.hiddenSize,
-                kernelSize: .init(config.patchSize), stride: .init(config.patchSize)
-            )
-            let d = config.imageSize / config.patchSize
-            self.positions = d * d
-            self._positionEmbedding.wrappedValue = Embedding(
-                embeddingCount: positions, dimensions: config.hiddenSize
+
+        func callAsFunction(_ hiddenStates: MLXArray, cuSequenceLengths: MLXArray, rotaryPositionEmbedding: MLXArray) -> MLXArray {
+            var hiddenStates = hiddenStates + attention(
+                norm1(hiddenStates),
+                cuSequenceLengths: cuSequenceLengths,
+                rotaryPositionEmbedding: rotaryPositionEmbedding
             )
-            self.positionIds = MLXArray(0 ..< positions)[.newAxis, 0...]
-        }
-
-        public func callAsFunction(_ x: MLXArray) -> MLXArray {
-            var patchEmbeddings = self.patchEmbedding(x)
-            patchEmbeddings = patchEmbeddings.flattened(start: 1, end: 2)
-            let embeddings = patchEmbeddings + self.positionEmbedding(self.positionIds)
-            return embeddings
-        }
-    }
-
-    fileprivate class SigLipVisionModel: Module {
-
-        @ModuleInfo var embeddings: VisionEmbeddings
-        @ModuleInfo var encoder: Encoder
-        @ModuleInfo(key: "post_layernorm") var postLayerNorm: LayerNorm
-
-        public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
-            self.embeddings = VisionEmbeddings(config)
-            self.encoder = Encoder(config)
-            self._postLayerNorm.wrappedValue = LayerNorm(dimensions: config.hiddenSize)
-        }
-
-        public func callAsFunction(_ x: MLXArray, outputHiddenStates: Bool = false) -> (
-            MLXArray, MLXArray, MLXArray?
-        ) {
-            let x = embeddings(x)
-
-            let (encoderOutput, hiddenStates) = encoder(x, outputHiddenStates: outputHiddenStates)
-            let poolerOutput = postLayerNorm(encoderOutput)
-
-            return (poolerOutput, x, hiddenStates?.last)
+            hiddenStates = hiddenStates + mlp(norm2(hiddenStates))
+            return hiddenStates
         }
     }
 
     fileprivate class VisionModel: Module {
 
-        @ModuleInfo(key: "vision_model") var visionModel: SigLipVisionModel
+        @ModuleInfo(key: "patch_embed") var patchEmbed: PatchEmbed
+        @ModuleInfo(key: "rotary_pos_emb") var rotaryPositionEmbedding: VisionRotaryEmbedding
+        @ModuleInfo(key: "blocks") var blocks: [Qwen2VLVisionBlock]
+        @ModuleInfo(key: "merger") var patchMerger: PatchMerger
+
+        let spatialMergeSize: Int
 
         public init(_ config: Qwen2VLConfiguration.VisionConfiguration) {
             precondition(
-                config.modelType == "siglip_vision_model",
+                config.modelType == "qwen2_vl",
                 "Unsupported modelType: \(config.modelType)")
-            self._visionModel.wrappedValue = SigLipVisionModel(config)
+
+            self.spatialMergeSize = config.spatialMergeSize
+
+            self._patchEmbed.wrappedValue = PatchEmbed(
+                patchSize: config.patchSize,
+                temporalPatchSize: config.temporalPatchSize,
+                inChannels: config.inChannels,
+                embedDimensions: config.embedDimensions)
+
+            let headDimensions = config.embedDimensions / config.numHeads
+            self._rotaryPositionEmbedding.wrappedValue = VisionRotaryEmbedding(dimensions: headDimensions, theta: 10_000)
+
+            self._blocks.wrappedValue = (0 ..< config.depth).map { _ in
+                Qwen2VLVisionBlock(config)
+            }
+            self.patchMerger = PatchMerger(dimensions: config.hiddenSize, contextDimensions: config.embedDimensions, spatialMergeSize: 2)
+        }
+
+        func rotaryPositionEmbedding(_ gridThw: MLXArray) -> MLXArray {
+            var positionIds = [MLXArray]()
+
+            for row in gridThw {
+                // TODO NOTE: this evaluates gridThw -- it shouldn't do that
+                let t = row[0].item(Int.self)
+                let h = row[1].item(Int.self)
+                let w = row[2].item(Int.self)
+
+                var hposIds = expandedDimensions(MLXArray(0 ..< h), axis: 1)
+                hposIds = repeated(hposIds, count: w, axis: 1)
+                hposIds = hposIds
+                    .reshaped(
+                        h / spatialMergeSize,
+                        spatialMergeSize,
+                        w / spatialMergeSize,
+                        spatialMergeSize)
+                    .transposed(0, 2, 1, 3)
+                    .flattened()
+
+                var wposIds = expandedDimensions(MLXArray(0 ..< w), axis: 0)
+                wposIds = repeated(wposIds, count: h, axis: 0)
+                wposIds = wposIds
+                    .reshaped(
+                        h / spatialMergeSize,
+                        spatialMergeSize,
+                        w / spatialMergeSize,
+                        spatialMergeSize)
+                    .transposed(0, 2, 1, 3)
+                    .flattened()
+
+                let stackedPosIds = stacked([hposIds, wposIds], axis: -1)
+                positionIds.append(repeated(stackedPosIds, count: t, axis: 0))
+            }
+
+            let indices = concatenated(positionIds, axis: 0)
+            let maxGridSize = max(gridThw[0..., 1...])
+            let rotaryPositionEmbedFull = rotaryPositionEmbedding(maxGridSize)[indices]
+
+            return rotaryPositionEmbedFull.reshaped(indices.dim(0), -1)
        }
 
-        public func callAsFunction(_ x: MLXArray, outputHiddenStates: Bool = false) -> (
-            MLXArray, MLXArray, MLXArray?
-        ) {
-            visionModel(x, outputHiddenStates: outputHiddenStates)
+        public func callAsFunction(_ hiddenStates: MLXArray, gridThw: MLXArray) -> MLXArray {
+            var hiddenStates = patchEmbed(hiddenStates)
+            let rotaryPositionEmbedding = rotaryPositionEmbedding(gridThw)
+
+            // Assuming grid_thw has shape (batch_size, 3)
+            let batchSize = gridThw.dim(0)
+
+            // Calculate cu_seqlens for each item in the batch
+            var collect = [MLXArray]()
+            for i in 0 ..< batchSize {
+                let sequenceLength = gridThw[i, 1] * gridThw[i, 2]
+
+                // TODO NOTE: this evaluates gridThw -- it shouldn't do that
+                let t = gridThw[i, 0].item(Int.self)
+                collect.append(repeated(sequenceLength, count: t))
+            }
+
+            // Concatenate the cu_seqlens for all items in the batch
+            var cuSeqLengths = concatenated(collect)
+
+            cuSeqLengths = cumsum(cuSeqLengths.asType(Int32.self), axis: 0)
+            cuSeqLengths = padded(cuSeqLengths, width: [1, 0], mode: .constant, value: MLXArray(0))
+
+            for block in blocks {
+                hiddenStates = block(hiddenStates, cuSequenceLengths: cuSeqLengths, rotaryPositionEmbedding: rotaryPositionEmbedding)
+            }
+
+            return patchMerger(hiddenStates)
         }
 
         private func isMLXWeight(_ array: MLXArray) -> Bool {
@@ -616,7 +628,10 @@ private enum Vision {
                 if k.contains("position_id") {
                     // Remove unused position_ids
                     continue
-                } else if k.contains("patch_embedding.weight") {
+                } else if k.contains("patch_embed.proj.weight") {
+                    // TODO: this comment doesn't match -- the extra axis here is the
+                    // temporal (kT) kernel dimension of the 3D patch embedding, not batch
+
                     // PyTorch conv2d weight tensors have shape:
                     //   [out_channels, in_channels, kH, KW]
                     // MLX conv2d expects the weight be of shape:
@@ -624,7 +639,7 @@ private enum Vision {
                     if isMLXWeight(v) {
                         sanitizedWeights[k] = v
                     } else {
-                        sanitizedWeights[k] = v.transposed(0, 2, 3, 1)
+                        sanitizedWeights[k] = v.transposed(0, 2, 3, 4, 1)
                     }
                 } else {
                     sanitizedWeights[k] = v
@@ -882,8 +897,8 @@ public struct Qwen2VLConfiguration: Codable, Sendable {
         public let patchSize: Int
         public let vocabularySize: Int
         public let mlpRatio: Float
-        public let _channels: Int?
-        public var channels: Int { _channels ?? 3 }
+        public let _inChannels: Int?
+        public var inChannels: Int { _inChannels ?? 3 }
         public let _layerNormEps: Float?
         public var layerNormEps: Float { _layerNormEps ?? 1e-6 }
         public let spatialPatchSize: Int
@@ -900,7 +915,7 @@ public struct Qwen2VLConfiguration: Codable, Sendable {
             case patchSize = "patch_size"
             case vocabularySize = "vocab_size"
             case mlpRatio = "mlp_ratio"
-            case _channels = "num_channels"
+            case _inChannels = "in_channels"
             case _layerNormEps = "layer_norm_eps"
             case spatialPatchSize = "spatial_patch_size"
             case spatialMergeSize = "spatial_merge_size"
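
Note on the patch_embed.proj.weight sanitization above: the extra axis handled by transposed(0, 2, 3, 4, 1) is the temporal kernel dimension of the new 3D patch embedding, not a batch dimension. A minimal sketch of the layout change, with shapes assumed from the usual Qwen2-VL vision defaults (1280 embed dimensions, 3 input channels, temporal patch size 2, spatial patch size 14) rather than taken from this diff:

    import MLX

    // PyTorch stores a 3D conv weight as [outChannels, inChannels, kT, kH, kW];
    // MLX expects [outChannels, kT, kH, kW, inChannels], hence the (0, 2, 3, 4, 1) permutation.
    let torchWeight = zeros([1280, 3, 2, 14, 14])
    let mlxWeight = torchWeight.transposed(0, 2, 3, 4, 1)
    print(mlxWeight.shape)  // [1280, 2, 14, 14, 3]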