Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Phi 3.5 MoE #116

Merged
merged 6 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Libraries/LLM/Configuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ private class ModelTypeRegistry: @unchecked Sendable {
Phi3Configuration.self, from: Data(contentsOf: url))
return Phi3Model(configuration)
},
"phimoe": { url in
let configuration = try JSONDecoder().decode(
PhiMoEConfiguration.self, from: Data(contentsOf: url))
return PhiMoEModel(configuration)
},
"gemma": { url in
let configuration = try JSONDecoder().decode(
GemmaConfiguration.self, from: Data(contentsOf: url))
Expand Down
26 changes: 19 additions & 7 deletions Libraries/LLM/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ extension ModelConfiguration {
"<s><|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
}

public static let phi3_5MoE = ModelConfiguration(
id: "mlx-community/Phi-3.5-MoE-instruct-4bit",
defaultPrompt: "What is the gravity on Mars and the moon?",
extraEOSTokens: ["<|end|>"]
) {
prompt in
"<|user|>\n\(prompt)<|end|>\n<|assistant|>\n"
}

public static let gemma2bQuantized = ModelConfiguration(
id: "mlx-community/quantized-gemma-2b-it",
overrideTokenizer: "PreTrainedTokenizer",
Expand Down Expand Up @@ -260,19 +269,22 @@ extension ModelConfiguration {
case .idle:
bootstrapState = .bootstrapping
register(configurations: [
codeLlama13b4bit,
gemma2bQuantized,
gemma_2_2b_it_4bit,
gemma_2_9b_it_4bit,
llama3_1_8B_4bit,
llama3_2_1B_4bit,
llama3_2_3B_4bit,
mistralNeMo4bit,
smolLM_135M_4bit,
llama3_8B_4bit,
mistral7B4bit,
codeLlama13b4bit,
phi4bit,
mistralNeMo4bit,
openelm270m4bit,
phi3_5MoE,
phi3_5_4bit,
gemma2bQuantized,
gemma_2_9b_it_4bit,
phi4bit,
qwen205b4bit,
openelm270m4bit,
smolLM_135M_4bit,
])
bootstrapState = .bootstrapped

Expand Down
35 changes: 20 additions & 15 deletions Libraries/LLM/Phi3.swift
Original file line number Diff line number Diff line change
Expand Up @@ -207,21 +207,25 @@ public class Phi3Model: Module, LLMModel, KVCacheDimensionProvider {
}
}

public struct Phi3Configuration: Codable, Sendable {
struct RopeScaling: Codable {
let longFactor: [Float]?
let shortFactor: [Float]?
let factor: Float?
let type: String?

enum CodingKeys: String, CodingKey {
case type
case factor
case longFactor = "long_factor"
case shortFactor = "short_factor"
}
struct RopeScalingWithFactorArrays: Codable {
let longFactor: [Float]?
let shortFactor: [Float]?
let factor: Float?
let type: String?
let longMScale: Float?
let shortMScale: Float?

enum CodingKeys: String, CodingKey {
case type
case factor
case longFactor = "long_factor"
case shortFactor = "short_factor"
case longMScale = "long_mscale"
case shortMScale = "short_mscale"
}
}

public struct Phi3Configuration: Codable, Sendable {
var hiddenSize: Int
var hiddenLayers: Int
var intermediateSize: Int
Expand All @@ -231,7 +235,7 @@ public struct Phi3Configuration: Codable, Sendable {
var kvHeads: Int
var ropeTheta: Float = 10_000
var ropeTraditional: Bool = false
var ropeScaling: RopeScaling?
var ropeScaling: RopeScalingWithFactorArrays?
var maxPositionEmbeddings: Int
var originalMaxPositionEmbeddings: Int

Expand Down Expand Up @@ -273,7 +277,8 @@ public struct Phi3Configuration: Codable, Sendable {
ropeTraditional =
try container.decodeIfPresent(
Bool.self, forKey: Phi3Configuration.CodingKeys.ropeTraditional) ?? false
ropeScaling = try container.decodeIfPresent(RopeScaling.self, forKey: .ropeScaling)
ropeScaling = try container.decodeIfPresent(
RopeScalingWithFactorArrays.self, forKey: .ropeScaling)
maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings)
originalMaxPositionEmbeddings = try container.decode(
Int.self, forKey: .originalMaxPositionEmbeddings)
Expand Down
263 changes: 263 additions & 0 deletions Libraries/LLM/PhiMoE.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
import Foundation
import MLX
import MLXFast
import MLXNN
import MLXRandom

// Port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/phimoe.py

public struct PhiMoEConfiguration: Codable, Sendable {
var modelType: String = "phimoe"
var vocabularySize: Int = 32064
var hiddenSize: Int = 4096
var intermediateSize: Int = 6400
var hiddenLayers: Int = 32
var attentionHeads: Int = 32
var kvHeads: Int = 8
var maxPositionEmbeddings: Int = 131072
var originalMaxPositionEmbeddings: Int = 4096
var rmsNormEps: Float = 1e-6
var ropeScaling: RopeScalingWithFactorArrays?
var numLocalExperts: Int = 16
var numExpertsPerToken: Int = 2
var ropeTheta: Float = 10000.0

enum CodingKeys: String, CodingKey {
case modelType = "model_type"
case vocabularySize = "vocab_size"
case hiddenSize = "hidden_size"
case intermediateSize = "intermediate_size"
case hiddenLayers = "num_hidden_layers"
case attentionHeads = "num_attention_heads"
case kvHeads = "num_key_value_heads"
case maxPositionEmbeddings = "max_position_embeddings"
case originalMaxPositionEmbeddings = "original_max_position_embeddings"
case rmsNormEps = "rms_norm_eps"
case ropeScaling = "rope_scaling"
case numLocalExperts = "num_local_experts"
case numExpertsPerToken = "num_experts_per_tok"
case ropeTheta = "rope_theta"
}
}

private class Attention: Module {
let args: PhiMoEConfiguration
let scale: Float

@ModuleInfo(key: "q_proj") var wq: Linear
@ModuleInfo(key: "k_proj") var wk: Linear
@ModuleInfo(key: "v_proj") var wv: Linear
@ModuleInfo(key: "o_proj") var wo: Linear

let rope: SuScaledRotaryEmbedding

init(_ args: PhiMoEConfiguration) {
self.args = args

let dim = args.hiddenSize
let heads = args.attentionHeads
let kvHeads = args.kvHeads

let headDim = args.hiddenSize / heads
self.scale = pow(Float(headDim), -0.5)

self._wq.wrappedValue = Linear(dim, heads * headDim, bias: true)
self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: true)
self._wo.wrappedValue = Linear(heads * headDim, dim, bias: true)

self.rope = SuScaledRotaryEmbedding(
dimensions: headDim,
base: args.ropeTheta,
maxPositionEmbeddings: args.maxPositionEmbeddings,
originalMaxPositionEmbeddings: args.originalMaxPositionEmbeddings,
longFactor: args.ropeScaling?.longFactor as? [Float] ?? [1.0],
longMScale: args.ropeScaling?.longMScale as? Float
)
}

func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?) -> MLXArray {
let (B, L, _) = (x.dim(0), x.dim(1), x.dim(2))

let queries = wq(x)
let keys = wk(x)
let values = wv(x)

// Prepare the queries, keys and values for the attention computation
var q = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3)
var k = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)
var v = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3)

if let cache {
q = rope(q, offset: cache.offset)
k = rope(k, offset: cache.offset)
(k, v) = cache.update(keys: k, values: v)
} else {
q = rope(q)
k = rope(k)
}

let output = MLXFast.scaledDotProductAttention(
queries: q, keys: k, values: v, scale: scale, mask: mask
)
.transposed(0, 2, 1, 3)
.reshaped(B, L, -1)

return wo(output)
}
}

private class PhiMoESparseMoeBlock: Module {
let hiddenDim: Int
let ffnDim: Int
let numExperts: Int
let topK: Int

@ModuleInfo(key: "gate") var gate: Linear
@ModuleInfo(key: "switch_mlp") var switchMLP: SwitchGLU

init(_ args: PhiMoEConfiguration) {
self.hiddenDim = args.hiddenSize
self.ffnDim = args.intermediateSize
self.numExperts = args.numLocalExperts
self.topK = args.numExpertsPerToken

self._gate.wrappedValue = Linear(hiddenDim, numExperts, bias: false)
self._switchMLP.wrappedValue = SwitchGLU(
inputDims: hiddenDim, hiddenDims: ffnDim, numExperts: numExperts)
}

func callAsFunction(_ x: MLXArray) -> MLXArray {
let gates = gate(x)

let k = self.topK
let inds = MLX.stopGradient(
MLX.argPartition(
-gates,
kth: k - 1,
axis: -1
)[.ellipsis, ..<k])
let scores = MLX.softmax(MLX.takeAlong(gates, inds, axis: -1), axis: -1, precise: true)

let y = switchMLP(x, inds)
return (y * scores[.ellipsis, .newAxis]).sum(axis: -2)
}
}

private class PhiMoEDecoderLayer: Module {
let hiddenSize: Int

@ModuleInfo(key: "self_attn") var selfAttn: Attention
@ModuleInfo(key: "block_sparse_moe") var blockSparseMoe: PhiMoESparseMoeBlock
@ModuleInfo(key: "input_layernorm") var inputLayerNorm: LayerNorm
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: LayerNorm

init(_ args: PhiMoEConfiguration) {
self.hiddenSize = args.hiddenSize

self._selfAttn.wrappedValue = Attention(args)
self._blockSparseMoe.wrappedValue = PhiMoESparseMoeBlock(args)
self._inputLayerNorm.wrappedValue = LayerNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
self._postAttentionLayerNorm.wrappedValue = LayerNorm(
dimensions: args.hiddenSize, eps: args.rmsNormEps)
}

func callAsFunction(_ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?) -> MLXArray {
var residual = x
var hiddenStates = inputLayerNorm(x)
hiddenStates = selfAttn(hiddenStates, mask: mask, cache: cache)
hiddenStates = residual + hiddenStates

residual = hiddenStates
hiddenStates = postAttentionLayerNorm(hiddenStates)
hiddenStates = blockSparseMoe(hiddenStates)
hiddenStates = residual + hiddenStates

return hiddenStates
}
}

private class PhiMoEModelInner: Module {
let args: PhiMoEConfiguration

@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
let layers: [PhiMoEDecoderLayer]
@ModuleInfo(key: "norm") var norm: LayerNorm

init(_ args: PhiMoEConfiguration) {
self.args = args

self._embedTokens.wrappedValue = Embedding(
embeddingCount: args.vocabularySize, dimensions: args.hiddenSize)
self.layers = (0 ..< args.hiddenLayers).map { _ in PhiMoEDecoderLayer(args) }
self._norm.wrappedValue = LayerNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps)
}

func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
var h = embedTokens(inputs)

let mask = createAttentionMask(h: h, cache: cache)

for (i, layer) in layers.enumerated() {
h = layer(h, mask: mask, cache: cache?[i])
}

return norm(h)
}
}

public class PhiMoEModel: Module, LLMModel, KVCacheDimensionProvider {
public let vocabularySize: Int
public let kvHeads: [Int]
public let headDim: IntOrPair

fileprivate let model: PhiMoEModelInner
@ModuleInfo(key: "lm_head") var lmHead: Linear

public init(_ args: PhiMoEConfiguration) {
self.vocabularySize = args.vocabularySize
self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers)
self.headDim = .init(args.hiddenSize / args.attentionHeads)
self.model = PhiMoEModelInner(args)
self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: true)
}

public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
let out = model(inputs, cache: cache)
return lmHead(out)
}

public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
var sanitizedWeights = weights
if sanitizedWeights["model.layers.0.block_sparse_moe.experts.0.w1.weight"] == nil {
return sanitizedWeights
}

for l in 0 ..< model.args.hiddenLayers {
let prefix = "model.layers.\(l)"
for (n, m) in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")] {
for k in ["weight", "scales", "biases"] {
if sanitizedWeights["\(prefix).block_sparse_moe.experts.0.\(n).\(k)"] != nil {
let toJoin = (0 ..< model.args.numLocalExperts).map { e in
sanitizedWeights.removeValue(
forKey: "\(prefix).block_sparse_moe.experts.\(e).\(n).\(k)")!
}
sanitizedWeights["\(prefix).block_sparse_moe.switch_mlp.\(m).\(k)"] =
MLX.stacked(toJoin)
}
}
}
}

return sanitizedWeights
}
}

// MARK: - LoRA

extension PhiMoEModel: LoRAModel {
public func loraLinearLayers() -> LoRALinearLayers {
model.layers.map { ($0.selfAttn, ["q_proj", "v_proj"]) }
}
}
Loading