Bring mlp and moe into variations
Refactor to bring MLP and MoE in as swappable variations.
gkielian committed Dec 19, 2024
1 parent dddfa5c commit 94d86c5
Showing 4 changed files with 196 additions and 170 deletions.
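
For orientation before the file-by-file diff, here is a minimal, self-contained sketch of the registry-plus-factory pattern this commit moves to. DemoMLP, demo_mlp_dictionary, and the SimpleNamespace config are illustrative stand-ins, not code from this repository; the real factory added below is get_mlp_instance in variations/mlp_variations.py.

import torch
import torch.nn as nn
from types import SimpleNamespace

class DemoMLP(nn.Module):
    """Stand-in MLP used only to illustrate dictionary dispatch."""
    def __init__(self, config):
        super().__init__()
        self.up = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.down = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)))

demo_mlp_dictionary = {"mlp": DemoMLP}

def demo_get_mlp_instance(config):
    # Look up the class by name and fail loudly on unknown variants,
    # mirroring get_mlp_instance in variations/mlp_variations.py below.
    mlp_class = demo_mlp_dictionary.get(config.mlp_variant)
    if mlp_class is None:
        raise ValueError(f"Unsupported MLP variant: {config.mlp_variant}")
    return mlp_class(config)

cfg = SimpleNamespace(n_embd=64, mlp_variant="mlp")
mlp = demo_get_mlp_instance(cfg)
out = mlp(torch.randn(2, 8, 64))  # (batch, seq, n_embd) -> same shape

Swapping implementations then amounts to registering another class under a new key and selecting it by name through the config.
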
175 changes: 5 additions & 170 deletions model.py
@@ -27,6 +27,8 @@
import torch.utils.checkpoint as checkpoint

# Variations
from variations.mlp_variations import get_mlp_instance
from variations.moe_variations import MoELayer
from variations.lsv_variations import lsv_dictionary
from variations.softmax_variations import softmax_dictionary
from variations.norm_variations import norm_dictionary
@@ -35,6 +37,7 @@
from variations.linear_variations import linear_dictionary
from variations.router_variations import router_dictionary
from quantization.quantize import quantize_dictionary, dequantize, fake_quantize_act
from quantization.quant_utils import set_variant, create_activation_buffers

def create_shared_param_group(layer_type, config):

@@ -68,7 +71,7 @@ def create_shared_param_group(layer_type, config):
# this iter is an moe layer iter
layer_block = MoELayer(config)
else:
layer_block = MLP(config)
layer_block = get_mlp_instance(config)
elif layer_type == "attn":
layer_block = CausalSelfAttention(config, fire_pos_enc=fire_pos_enc)
else:
@@ -95,18 +98,6 @@ def create_shared_param_group(layer_type, config):
return shared_group

def set_variant(variant, default_variant):
# If variant is false or None, then set to provided default value
if not variant:
return default_variant
return variant

def create_activation_buffers(obj, arg):
arg_str = arg.split("quantize_")[1]
obj.register_buffer(arg_str, None)
obj.register_buffer(f"{arg_str}_scale", None)
obj.register_buffer(f"{arg_str}_zero_point", None)

class CausalSelfAttention(nn.Module):
def __init__(self, config, fire_pos_enc=None):
super().__init__()
@@ -426,114 +417,6 @@ def forward(self, x, iter_num):

return y


class MLP(nn.Module):
def __init__(self, config):
super().__init__()

self.full_quant_iteration = config.full_quant_iteration
self.eval_interval = config.eval_interval

# Select "mlp variant"
self.mlp_variant = config.mlp_variant

self.start_quant_level = config.start_quant_level
self.quant_scheduler = config.quant_scheduler

# If "MLP Variant" is KAN, then we skip MLP specific items
if self.mlp_variant == "kan":
self.kan = linear_dictionary["kan"](config.n_embd, config.n_embd, config=config)
else:
# Select activation variant
self.activation_variant = activation_dictionary[config.activation_variant](config=config)

# Sets the class of linear for MLP
self.linear_variant_mlp_up = linear_dictionary[set_variant(config.linear_variant_mlp_up, config.linear_variant_mlp)]
self.linear_variant_mlp_down = linear_dictionary[set_variant(config.linear_variant_mlp_down, config.linear_variant_mlp)]

self.quantization_mlp_dict = {}
self.quantization_mlp_dict["activations_quant_method"] = config.activations_quant_method

# Set quantization parameters for MLP
for arg, val in vars(config).items():
# Set MLP Activation precision and quantization method
if arg.startswith("quantize_") and "mlp_act" in arg and arg.endswith("_bits"):
self.quantization_mlp_dict[arg] = set_variant(val, config.quantize_mlp_act_bits)
elif arg.startswith("quantize_") and "mlp_act" in arg:
self.quantization_mlp_dict[arg] = set_variant(val, config.quantize_mlp_act)
if config.store_activations and arg != "quantize_mlp_act" and self.quantization_mlp_dict[arg]:
create_activation_buffers(self, arg)
# Set MLP Linear Weight precision and quantization method
elif arg.startswith("quantize_") and "linear_mlp" in arg and arg.endswith("_bits"):
self.quantization_mlp_dict[arg] = set_variant(val, config.quantize_linear_bits)
elif arg.startswith("quantize_") and "linear_mlp" in arg and arg.endswith("_method"):
self.quantization_mlp_dict[arg] = set_variant(val, config.quantize_linear_method)

# Instantiate Linear Layers
if self.mlp_variant == "mlp":
self.c_fc = self.linear_variant_mlp_up(config.n_embd, config.mlp_expansion_factor * config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_up_method"], self.quantization_mlp_dict["quantize_linear_mlp_up_bits"], bias=config.bias)
self.c_proj = self.linear_variant_mlp_down(config.mlp_expansion_factor * config.n_embd, config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_down_method"], self.quantization_mlp_dict["quantize_linear_mlp_down_bits"], bias=config.bias)
elif self.mlp_variant == "swiglu":
self.c_fc_in1 = self.linear_variant_mlp_up(config.n_embd, config.mlp_expansion_factor * config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_up_method"], self.quantization_mlp_dict["quantize_linear_mlp_up_bits"])
self.c_fc_in2 = self.linear_variant_mlp_up(config.n_embd, config.mlp_expansion_factor * config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_up_method"], self.quantization_mlp_dict["quantize_linear_mlp_up_bits"])
self.c_fc_out = self.linear_variant_mlp_down(config.mlp_expansion_factor * config.n_embd, config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_down_method"], self.quantization_mlp_dict["quantize_linear_mlp_down_bits"])

self.dropout = nn.Dropout(config.dropout)

def forward(self, x, iter_num):

if self.quantization_mlp_dict["quantize_mlp_act_input"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_input_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x = fake_quantize_act(self, "mlp_act_input", x, num_bits, quant_method, iter_num)

if self.mlp_variant == "kan":
x = self.kan(x)

elif self.mlp_variant == "mlp":
x = self.c_fc(x)

if self.quantization_mlp_dict["quantize_mlp_act_activation_input"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_input_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x = fake_quantize_act(self, "mlp_act_activation_input", x, num_bits, quant_method, iter_num)

x = self.activation_variant(x)

if self.quantization_mlp_dict["quantize_mlp_act_activation_output"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_output_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x = fake_quantize_act(self, "mlp_act_activation_output", x, num_bits, quant_method, iter_num)

x = self.c_proj(x)

elif self.mlp_variant == "swiglu":
x_in1 = self.c_fc_in1(x)

if self.quantization_mlp_dict["quantize_mlp_act_activation_input"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_input_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x_in1 = fake_quantize_act(self, "mlp_act_activation_input", x_in1, num_bits, quant_method, iter_num)

x_in1 = self.activation_variant(x_in1)

if self.quantization_mlp_dict["quantize_mlp_act_activation_output"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_output_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x_in1 = fake_quantize_act(self, "mlp_act_activation_output", x_in1, num_bits, quant_method, iter_num)

x_in2 = self.c_fc_in2(x)
x_out = x_in1 * x_in2
x = self.c_fc_out(x_out)

x = self.dropout(x)

if self.quantization_mlp_dict["quantize_mlp_act_output"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_output_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x = fake_quantize_act(self, "mlp_act_output", x, num_bits, quant_method, iter_num)
return x

class Block(nn.Module):
def __init__(self, config, mlp=None, attn=None):
super().__init__()
@@ -556,7 +439,7 @@ def __init__(self, config, mlp=None, attn=None):

# Allow for sharing mlp between blocks
if mlp is None:
self.mlp = MLP(config)
self.mlp = get_mlp_instance(config)
else:
self.mlp = mlp

@@ -1090,51 +973,3 @@ def generate_with_stop(self, idx, max_new_tokens, stop_string, decode, temperatu

return idx, generated_text


class MoELayer(nn.Module):
""" Mixture of Experts layer to replace FFN (or every other FFN) """

def __init__(self, config):
super().__init__()
self.top_k = config.moe_top_k
# TODO: implement expert capacity throttling
# self.expert_capacity = config.expert_capacity
self.num_experts = config.n_experts
self.router = router_dictionary[config.moe_router_scheme](config)
self.experts = nn.ModuleList([MLP(config) for _ in range(config.n_experts)])

def forward(self, x):
# Assuming x has shape [batch_size, seq_len, n_embd]
batch_size, seq_len, _ = x.shape
gating_output, indices = self.router(x)
# print(f"gating_output.shape: {gating_output.shape}")
# print(f"indices 1 count: {indices}")
final_output = torch.zeros_like(x)

# Flatten the batch and sequence dimensions to treat each token independently
flat_x = x.view(-1, x.size(-1))
# print(f"x.shape() = {x.shape}")
# print(f"flat_x = {flat_x.shape}")
flat_gating_output = gating_output.view(-1, gating_output.size(-1))
# print(f"flat_gating_output.shape = {flat_gating_output.shape}")

# Process each expert in parallel
for i, expert in enumerate(self.experts):
# Create a mask for the inputs where the current expert is in top-k
expert_mask = (indices == i).any(dim=-1)
flat_mask = expert_mask.view(-1)
# print(f"expert_mask shape = {expert_mask.shape}")
# print(f"flat_mask shape = {flat_mask.shape}")

if flat_mask.any():
expert_input = flat_x[flat_mask]
expert_output = expert(expert_input)

# Extract and apply gating scores
gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
weighted_output = expert_output * gating_scores

# Update final output additively by indexing and adding
final_output[expert_mask] += weighted_output.squeeze(1)
# print(f"final_output.shape = {final_output.shape}\n")
return final_output
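
The MoELayer removed above is, per the new import at the top of model.py, re-homed in variations/moe_variations.py. Its forward pass dispatches each token to its top-k experts and sums the gate-weighted expert outputs. Below is a tiny self-contained sketch of that dispatch with a stand-in router and stand-in experts in place of the real modules; shapes and variable names follow the code above.

import torch

batch, seq, d, n_experts, top_k = 2, 3, 4, 4, 2
x = torch.randn(batch, seq, d)

# Stand-in router: softmax gate over experts plus top-k expert indices per token.
gating_output = torch.softmax(torch.randn(batch, seq, n_experts), dim=-1)
_, indices = gating_output.topk(top_k, dim=-1)       # [batch, seq, top_k]

final_output = torch.zeros_like(x)
flat_x = x.view(-1, d)
flat_gating_output = gating_output.view(-1, n_experts)

for i in range(n_experts):
    expert_mask = (indices == i).any(dim=-1)         # tokens routed to expert i
    flat_mask = expert_mask.view(-1)
    if flat_mask.any():
        expert_output = flat_x[flat_mask] * (i + 1)  # stand-in "expert" computation
        gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
        final_output[expert_mask] += expert_output * gating_scores
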
12 changes: 12 additions & 0 deletions quantization/quant_utils.py
@@ -0,0 +1,12 @@

def set_variant(variant, default_variant):
# If variant is false or None, then set to provided default value
if not variant:
return default_variant
return variant

def create_activation_buffers(obj, arg):
arg_str = arg.split("quantize_")[1]
obj.register_buffer(arg_str, None)
obj.register_buffer(f"{arg_str}_scale", None)
obj.register_buffer(f"{arg_str}_zero_point", None)
130 changes: 130 additions & 0 deletions variations/mlp_variations.py
@@ -0,0 +1,130 @@
# mlp_variations.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from variations.activation_variations import activation_dictionary
from variations.linear_variations import linear_dictionary
from quantization.quantize import fake_quantize_act
from quantization.quant_utils import set_variant, create_activation_buffers

class OriginalMLP(nn.Module):
def __init__(self, config):
super().__init__()

self.full_quant_iteration = config.full_quant_iteration
self.eval_interval = config.eval_interval

# Select "mlp variant"
self.mlp_variant = config.mlp_variant

self.start_quant_level = config.start_quant_level
self.quant_scheduler = config.quant_scheduler

# If "MLP Variant" is KAN, then we skip MLP specific items
if self.mlp_variant == "kan":
self.kan = linear_dictionary["kan"](config.n_embd, config.n_embd, config=config)
else:
# Select activation variant
self.activation_variant = activation_dictionary[config.activation_variant](config=config)

# Sets the class of linear for MLP
self.linear_variant_mlp_up = linear_dictionary[set_variant(config.linear_variant_mlp_up, config.linear_variant_mlp)]
self.linear_variant_mlp_down = linear_dictionary[set_variant(config.linear_variant_mlp_down, config.linear_variant_mlp)]

self.quantization_mlp_dict = {}
self.quantization_mlp_dict["activations_quant_method"] = config.activations_quant_method

# Set quantization parameters for MLP
for arg, val in vars(config).items():
# Set MLP Activation precision and quantization method
if arg.startswith("quantize_") and "mlp_act" in arg and arg.endswith("_bits"):
self.quantization_mlp_dict[arg] = set_variant(val, config.quantize_mlp_act_bits)
elif arg.startswith("quantize_") and "mlp_act" in arg:
self.quantization_mlp_dict[arg] = set_variant(val, config.quantize_mlp_act)
if config.store_activations and arg != "quantize_mlp_act" and self.quantization_mlp_dict[arg]:
create_activation_buffers(self, arg)
# Set MLP Linear Weight precision and quantization method
elif arg.startswith("quantize_") and "linear_mlp" in arg and arg.endswith("_bits"):
self.quantization_mlp_dict[arg] = set_variant(val, config.quantize_linear_bits)
elif arg.startswith("quantize_") and "linear_mlp" in arg and arg.endswith("_method"):
self.quantization_mlp_dict[arg] = set_variant(val, config.quantize_linear_method)

# Instantiate Linear Layers
if self.mlp_variant == "mlp":
self.c_fc = self.linear_variant_mlp_up(config.n_embd, config.mlp_expansion_factor * config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_up_method"], self.quantization_mlp_dict["quantize_linear_mlp_up_bits"], bias=config.bias)
self.c_proj = self.linear_variant_mlp_down(config.mlp_expansion_factor * config.n_embd, config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_down_method"], self.quantization_mlp_dict["quantize_linear_mlp_down_bits"], bias=config.bias)
elif self.mlp_variant == "swiglu":
self.c_fc_in1 = self.linear_variant_mlp_up(config.n_embd, config.mlp_expansion_factor * config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_up_method"], self.quantization_mlp_dict["quantize_linear_mlp_up_bits"])
self.c_fc_in2 = self.linear_variant_mlp_up(config.n_embd, config.mlp_expansion_factor * config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_up_method"], self.quantization_mlp_dict["quantize_linear_mlp_up_bits"])
self.c_fc_out = self.linear_variant_mlp_down(config.mlp_expansion_factor * config.n_embd, config.n_embd, config, self.quantization_mlp_dict["quantize_linear_mlp_down_method"], self.quantization_mlp_dict["quantize_linear_mlp_down_bits"])

self.dropout = nn.Dropout(config.dropout)

def forward(self, x, iter_num=None):

if self.quantization_mlp_dict["quantize_mlp_act_input"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_input_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x = fake_quantize_act(self, "mlp_act_input", x, num_bits, quant_method, iter_num)

if self.mlp_variant == "kan":
x = self.kan(x)

elif self.mlp_variant == "mlp":
x = self.c_fc(x)

if self.quantization_mlp_dict["quantize_mlp_act_activation_input"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_input_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x = fake_quantize_act(self, "mlp_act_activation_input", x, num_bits, quant_method, iter_num)

x = self.activation_variant(x)

if self.quantization_mlp_dict["quantize_mlp_act_activation_output"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_output_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x = fake_quantize_act(self, "mlp_act_activation_output", x, num_bits, quant_method, iter_num)

x = self.c_proj(x)

elif self.mlp_variant == "swiglu":
x_in1 = self.c_fc_in1(x)

if self.quantization_mlp_dict["quantize_mlp_act_activation_input"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_input_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x_in1 = fake_quantize_act(self, "mlp_act_activation_input", x_in1, num_bits, quant_method, iter_num)

x_in1 = self.activation_variant(x_in1)

if self.quantization_mlp_dict["quantize_mlp_act_activation_output"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_output_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x_in1 = fake_quantize_act(self, "mlp_act_activation_output", x_in1, num_bits, quant_method, iter_num)

x_in2 = self.c_fc_in2(x)
x_out = x_in1 * x_in2
x = self.c_fc_out(x_out)

x = self.dropout(x)

if self.quantization_mlp_dict["quantize_mlp_act_output"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_output_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x = fake_quantize_act(self, "mlp_act_output", x, num_bits, quant_method, iter_num)
return x


mlp_dictionary = {
"mlp": OriginalMLP
}

def get_mlp_instance(config):
mlp_type = config.mlp_variant
mlp_class = mlp_dictionary.get(mlp_type)
if mlp_class is None:
raise ValueError(f"Unsupported MLP variant: {mlp_type}")
return mlp_class(config)
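
One payoff of the dictionary is that new variants can be registered without touching model.py. A hypothetical extension sketch follows; GatedMLP is illustrative only and not part of this commit, and it assumes the same forward(x, iter_num) calling convention as OriginalMLP.

import torch
import torch.nn as nn
from variations.mlp_variations import mlp_dictionary

class GatedMLP(nn.Module):
    """Illustrative variant: sigmoid-gated up projection."""
    def __init__(self, config):
        super().__init__()
        self.gate = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.up = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.down = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x, iter_num=None):
        return self.down(torch.sigmoid(self.gate(x)) * self.up(x))

mlp_dictionary["gated"] = GatedMLP
# With config.mlp_variant = "gated", get_mlp_instance(config) now builds GatedMLP.
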

(diff for the fourth changed file, presumably variations/moe_variations.py, did not load)
