wip MoE refactor #2600

Open · wants to merge 1 commit into base: main
98 changes: 55 additions & 43 deletions torchao/_models/mixtral-moe/generate.py
@@ -14,6 +14,9 @@
import torch._inductor.config

from torchao.utils import get_model_size_in_bytes
from torchao.prototype.moe_quant import MoEFeedForwardAOQuantizable
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from model import MoEFeedForward

torch.manual_seed(0)

@@ -187,7 +190,6 @@ def _load_model(checkpoint_path, device, precision):

B_INST, E_INST = "[INST]", "[/INST]"


def main(
prompt: str = "Hello, my name is",
interactive: bool = False,
@@ -199,6 +201,7 @@ def main(
checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
compile: bool = True,
compile_prefill: bool = False,
compile_mode: str = "reduce-overhead",
moe_quant: Optional[str] = None,
profile: Optional[Path] = None,
memory_profile: Optional[Path] = None,
@@ -212,6 +215,13 @@
precision = torch.bfloat16
is_chat = "chat" in str(checkpoint_path)

if batch_size > 1 and moe_quant is None:
print(
"Warning: batch_size > 1 but no moe_quant was specified. The default MoE implementation uses a lot of memory when batched;"
" if it OOMs, pass --moe_quant noquant to run the AO quantizable module without quantization."
)

if device == "cuda" and memory_profile is not None:
torch.cuda.memory._record_memory_history(
True, trace_alloc_max_entries=500000, trace_alloc_record_context=True
@@ -236,10 +246,12 @@ def main(
]
)

from torchao.prototype.moe_quant.utils import (
from torchao.prototype.moe_quant import (
MoEQuantConfig,
MoEMapping,
UseFakeExtraDimTensor,
cond_ffn_filter,
MoEFeedForwardAOQuantizable,
)
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
@@ -255,71 +267,64 @@

if moe_quant:
torch._dynamo.config.capture_dynamic_output_shape_ops = True
config = None
config = MoEQuantConfig(mapping=MoEMapping(target_module_type=MoEFeedForward))
if "int8wo-base" in moe_quant:
config = MoEQuantConfig(Int8WeightOnlyConfig())
config.base_config = Int8WeightOnlyConfig()

elif "int8wo" in moe_quant:
config = MoEQuantConfig(
Int8WeightOnlyConfig(),
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
config.base_config = Int8WeightOnlyConfig()
config.use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE

elif "int8dq-base" in moe_quant:
config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
config.base_config = Int8DynamicActivationInt8WeightConfig()

elif "int8dq" in moe_quant:
config = MoEQuantConfig(
Int8DynamicActivationInt8WeightConfig(),
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
config.base_config = Int8DynamicActivationInt8WeightConfig()
config.use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE

elif "int4wo-base" in moe_quant:
config = MoEQuantConfig(Int4WeightOnlyConfig())
config.base_config = Int4WeightOnlyConfig()

elif "int4wo" in moe_quant:
config = MoEQuantConfig(
Int4WeightOnlyConfig(),
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
config.base_config = Int4WeightOnlyConfig()
config.use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE

elif "fp8wo-base" in moe_quant:
config = MoEQuantConfig(Float8WeightOnlyConfig())
config.base_config = Float8WeightOnlyConfig()

elif "fp8wo" in moe_quant:
config = MoEQuantConfig(
Float8WeightOnlyConfig(),
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
config.base_config = Float8WeightOnlyConfig()
config.use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE

elif "fp8dq-base" in moe_quant:
config = MoEQuantConfig(
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
)
config.base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

elif "fp8dq" in moe_quant:
config = MoEQuantConfig(
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
config.base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
config.use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE

elif "intxdq" in moe_quant:
config = MoEQuantConfig(
Int8DynamicActivationIntxWeightConfig(
config.base_config = Int8DynamicActivationIntxWeightConfig(
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
),
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
config.use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE
elif "noquant" in moe_quant:
pass
else:
raise ValueError(
f"expected moe_quant to match one of the options but got {moe_quant}"
)

if config is not None:
quantize_(model, config, filter_fn=cond_ffn_filter, device=device)
print(
f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds"
)
def filter_fn(mod, fqn):
return isinstance(mod, MoEFeedForward)

model.layers = model.layers

quantize_(model, config, filter_fn=filter_fn, device=device)

print(
f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds"
)

model.to(device=device)
device_sync(device=device)
@@ -335,12 +340,12 @@ def main(

global decode_one_token, prefill

if batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant):
if True and (batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant)):
decode_one_token = torch.compile(
decode_one_token, mode="reduce-overhead", fullgraph=True
decode_one_token, mode=compile_mode, fullgraph=True
)
else:
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead")
decode_one_token = torch.compile(decode_one_token, mode=compile_mode)

if args.compile_prefill:
prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
@@ -474,6 +479,12 @@ def callback(x):
action="store_true",
help="Whether to compile the prefill (improves prefill perf, but higher compile times)",
)
parser.add_argument(
"--compile_mode",
type=str,
default="reduce-overhead",
help="which torch.compile mode to use: reduce-overhead or max-autotune, does nothing if --compile is not set.",
)
# parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8')
parser.add_argument(
"--moe_quant",
@@ -499,6 +510,7 @@ def callback(x):
args.checkpoint_path,
args.compile,
args.compile_prefill,
args.compile_mode,
args.moe_quant,
args.profile,
args.memory_profile,
136 changes: 98 additions & 38 deletions torchao/_models/mixtral-moe/model.py
@@ -156,7 +156,7 @@ class TransformerBlock(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.attention = Attention(config)
self.block_sparse_moe = MOEFeedForwardAOQuantizable(config)
self.block_sparse_moe = MoEFeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)

@@ -225,41 +225,39 @@ def forward(
y = self.wo(y)
return y

class MoEFeedForward(nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
self.cond_ffn = ConditionalFeedForward(config)
self.dim = config.dim
self.num_activated_experts = config.num_activated_experts
def forward(self, x: Tensor) -> Tensor:
x = x.view(-1, self.dim)
# T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
# x: [T, D]
scores = self.gate(x) # [T, E]
expert_weights = F.softmax(scores, dim=-1)
expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A]
expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A]
expert_outs = self.cond_ffn(x, expert_indices)
return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)

# class ConditionalFeedForward(nn.Module):
# def __init__(self, config):
# super().__init__()
# self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
# self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
# self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))

# def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
# w1_weights = self.w1[expert_indices] # [T, A, D, D]
# w3_weights = self.w3[expert_indices] # [T, A, D, D]
# w2_weights = self.w2[expert_indices] # [T, A, D, D]
# x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
# x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
# expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
# return expert_outs


# class MOEFeedForward(nn.Module):
# def __init__(self, config) -> None:
# super().__init__()
# self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
# self.cond_ffn = ConditionalFeedForward(config)
# self.dim = config.dim
# self.num_activated_experts = config.num_activated_experts
# def forward(self, x: Tensor) -> Tensor:
# x = x.view(-1, self.dim)
# # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
# # x: [T, D]
# scores = self.gate(x) # [T, E]
# expert_weights = F.softmax(scores, dim=-1)
# expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A]
# expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A]
# expert_outs = self.cond_ffn(x, expert_indices)
# return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)
class ConditionalFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))

def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
w1_weights = self.w1[expert_indices] # [T, A, D, D]
w3_weights = self.w3[expert_indices] # [T, A, D, D]
w2_weights = self.w2[expert_indices] # [T, A, D, D]
x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
return expert_outs


class RMSNorm(nn.Module):
@@ -301,6 +299,8 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
x_out2 = x_out2.flatten(3)
return x_out2.type_as(x)

#TODO delete


# T tokens
# E experts
@@ -310,7 +310,7 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
# T'(e) tokens for expert e


class MOEFeedForwardAOQuantizable(nn.Module):
class MoEFeedForwardAOQuantizable(nn.Module):
Contributor: It seems unlikely that people are going to swap their MoE module to AO's version. Can we just target torch._grouped_mm calls directly without requiring a module swap?

Collaborator: What would it mean to "target" it specifically? If the given model is compiled, the compiled version of this operator will be used anyway; not sure what else torchao could do about it...
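
For context, a minimal sketch of the module-swap path this PR currently wires up in generate.py (wip prototype API; names are taken from the diff above and may change):

from torchao.prototype.moe_quant import MoEQuantConfig, MoEMapping
from torchao.quantization.quant_api import Int8WeightOnlyConfig, quantize_
from model import MoEFeedForward

config = MoEQuantConfig(mapping=MoEMapping(target_module_type=MoEFeedForward))
config.base_config = Int8WeightOnlyConfig()
# only modules of the user's MoE type are matched and quantized; a grouped_mm-targeting
# approach would instead have to pattern-match on the op inside an arbitrary module
quantize_(model, config, filter_fn=lambda mod, fqn: isinstance(mod, MoEFeedForward), device="cuda")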

def __init__(self, config) -> None:
super().__init__()
self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
@@ -337,7 +337,7 @@ class ConditionalFeedForwardAOQuantizable(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.w1 = nn.Parameter(
self.w1 = nn.Parameter( # num exp, expert_dim, hidden_dim
torch.empty(config.num_experts, config.intermediate_size, config.dim)
) # E, I, D
self.w2 = nn.Parameter(
Expand All @@ -347,6 +347,14 @@ def __init__(self, config):
torch.empty(config.num_experts, config.intermediate_size, config.dim)
) # E, I, D
self.num_experts = config.num_experts
self.perf_is_optimized = False

def optimize_perf(self):
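# fuse w1/w3 into a single weight so both projections run in one grouped matmul, and
# store the weights so the .transpose(-2, -1) view taken at matmul time is contiguous in memory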
self.w13 = torch.cat((self.w1, self.w3), dim=1)
self.w13 = torch.nn.Parameter(self.w13.transpose(-2,-1).contiguous().transpose(-2,-1))
self.w2 = torch.nn.Parameter(self.w2.transpose(-2, -1).contiguous().transpose(-2, -1))
del self.w1, self.w3
self.perf_is_optimized = True

def forward(
self,
@@ -355,8 +363,60 @@ def forward(
expert_weights: Tensor, # T, A
num_activated_experts: int,
) -> Tensor:


num_tokens, dim = x.shape
num_token_activations = num_tokens * num_activated_experts
num_token_activations = expert_indices.numel()


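# group token activations by expert: argsort the expert ids so each expert's tokens are contiguous,
# then use per-expert counts (histc -> cumsum) as the offsets for torch._grouped_mm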
ordered_token_activations = expert_indices.view(-1).argsort(stable=True)
ordered_token_indices = (
ordered_token_activations.div(num_activated_experts)
.floor()
.to(torch.int32)
) # [T]

indices_for_histc = expert_indices.view(-1) if expert_indices.is_cuda else expert_indices.float().view(-1) # histc doesn't work on cpu for integers
num_tokens_per_expert = torch.histc(
indices_for_histc,
bins=self.num_experts,
min=0,
max=self.num_experts,
)
offs = num_tokens_per_expert.cumsum(dim=0).to(torch.int32)
ordered_inputs = x[ordered_token_indices]

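# fused path (after optimize_perf): one grouped matmul yields both the w1 and w3 projections, then split;
# otherwise w1 and w3 are applied as two separate grouped matmuls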
if self.perf_is_optimized:
x1, x3 = torch._grouped_mm(ordered_inputs, self.w13.transpose(-2, -1), offs).split(self.config.intermediate_size, dim=1)
y1 = F.silu(x1) * x3
else:
x1 = F.silu(torch._grouped_mm(ordered_inputs, self.w1.transpose(-2,-1), offs))
x3 = torch._grouped_mm(ordered_inputs, self.w3.transpose(-2,-1), offs)
y1 = x1 * x3
ordered_outs = torch._grouped_mm(y1, self.w2.transpose(-2,-1), offs)
# ordered_outs = torch._grouped_mm(y1, self.w2, offs)

ordered_token_activation_weights = expert_weights.view(-1, 1)[
ordered_token_activations
].view(-1, 1) # [T*A, 1]
weighted_ordered_outs = (
ordered_outs * ordered_token_activation_weights
) # [T*A, D]

# sum weighted token-activation outputs together for each token
final_out = torch.zeros_like(x) # [T, D]
final_out = final_out.scatter_add(
dim=0,
index=ordered_token_indices.unsqueeze(-1)
.expand(num_token_activations, dim)
.to(torch.int64),
src=weighted_ordered_outs,
)

return final_out



if x.shape[0] == 1 and not isinstance(
self.w1, FakeExtraDimTensor
): # only 1 token (can be done without graph breaks when compiled)