Enable Zero Offload in seqmse (#3663)
Signed-off-by: Huan Zhao <[email protected]>
quic-huzh authored Dec 20, 2024
1 parent 794bfe2 commit 01fc213
Showing 2 changed files with 143 additions and 31 deletions.
54 changes: 27 additions & 27 deletions TrainingExtensions/torch/src/python/aimet_torch/v2/seq_mse.py
@@ -51,6 +51,7 @@
from aimet_torch.v2.nn.base import BaseQuantizationMixin
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v2.utils import reduce, _is_reducible
from aimet_torch.v2.deepspeed_utils import SafeGatheredParameters

__all__ = [
'SequentialMse',
@@ -189,9 +190,7 @@ def compute_param_encodings(cls,
with quantize_dequantize.compute_encodings():
_ = quantize_dequantize(torch.stack([x_min, x_max]))

with torch.no_grad():
quantizer.min.copy_(quantize_dequantize.min)
quantizer.max.copy_(quantize_dequantize.max)
quantizer.set_range(quantize_dequantize.min, quantize_dequantize.max)

@classmethod
def _is_symmetric_quantizer(cls, quantizer: AffineQuantizerBase):
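
The removed no_grad copies above are now encapsulated by the quantizer's set_range call. A minimal sketch of the equivalent behavior (illustration only, not the actual aimet implementation):

import torch

def set_range_sketch(quantizer, new_min: torch.Tensor, new_max: torch.Tensor):
    # Equivalent of the replaced lines: overwrite the learned min/max
    # parameters in place without recording gradients.
    with torch.no_grad():
        quantizer.min.copy_(new_min)
        quantizer.max.copy_(new_max)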
@@ -250,7 +249,7 @@ def get_min_and_max_for_candidate_selection(cls, quant_module: BaseQuantizationM
block_size = quant_module.param_quantizers['weight'].block_size
if block_size is None:
# Per tensor or per channel case
assert _is_reducible(quant_module.weight.shape, quant_module.param_quantizers['weight'].min.shape)
assert _is_reducible(quant_module.weight.shape, quant_module.param_quantizers['weight'].shape)
if cls._is_symmetric_quantizer(quant_module.param_quantizers['weight']):
max_tensor = reduce(quant_module.weight.abs(),
quant_module.param_quantizers['weight'].shape, torch.max).values
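
For intuition, the symmetric per-channel case reduces |W| down to the quantizer's encoding shape. An illustrative equivalent using plain torch ops, with hypothetical shapes and amax standing in for aimet's reduce helper:

import torch

weight = torch.randn(8, 16)      # (output_channels, input_channels)
quantizer_shape = (8, 1)         # per-channel: one encoding per output channel
max_tensor = weight.abs().amax(dim=1, keepdim=True)
assert max_tensor.shape == quantizer_shape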
@@ -348,29 +347,30 @@ def optimize_module(cls,
:param params: Sequential MSE parameters
"""
# pylint: disable=too-many-locals
min_tensor, max_tensor = cls.get_min_and_max_for_candidate_selection(quant_module)

total_loss = []
for i in range(params.num_candidates):
cand_min, cand_max = cls._get_candidate(i, params.num_candidates, min_tensor, max_tensor)
cls.compute_param_encodings(quant_module.param_quantizers['weight'], cand_min, cand_max)
w = quant_module.weight
wq = cls._get_quantized_weight(quant_module)
with torch.no_grad():
for batch_idx in range(params.num_batches):
if batch_idx == 0:
loss = cls._compute_loss(quant_module, x[batch_idx], xq[batch_idx], w, wq, params)
else:
loss += cls._compute_loss(quant_module, x[batch_idx], xq[batch_idx], w, wq, params)
total_loss.append(loss)

best_indices = torch.stack(total_loss).min(0)[1]
block_size = cls._get_input_channel_block_size(quant_module)
# In the input_channels dimension, best_indices is of size num_blocks. We use repeat_interleave to expand
# each blockwise index into block_size indices. This grows best_indices' input_channels dimension to
# size num_blocks * block_size, allowing elementwise operations with min_tensor and max_tensor.
if block_size != quant_module.weight.shape[1]:
best_indices = best_indices.repeat_interleave(block_size, dim=-1)
with SafeGatheredParameters(quant_module.parameters(recurse=False)):
min_tensor, max_tensor = cls.get_min_and_max_for_candidate_selection(quant_module)

total_loss = []
for i in range(params.num_candidates):
cand_min, cand_max = cls._get_candidate(i, params.num_candidates, min_tensor, max_tensor)
cls.compute_param_encodings(quant_module.param_quantizers['weight'], cand_min, cand_max)
w = quant_module.weight
wq = cls._get_quantized_weight(quant_module)
with torch.no_grad():
for batch_idx in range(params.num_batches):
if batch_idx == 0:
loss = cls._compute_loss(quant_module, x[batch_idx], xq[batch_idx], w, wq, params)
else:
loss += cls._compute_loss(quant_module, x[batch_idx], xq[batch_idx], w, wq, params)
total_loss.append(loss)

best_indices = torch.stack(total_loss).min(0)[1]
block_size = cls._get_input_channel_block_size(quant_module)
# In the input_channels dimension, best_indices is of size num_blocks. We use repeat_interleave to expand
# each blockwise index into block_size indices. This grows best_indices' input_channels dimension to
# size num_blocks * block_size, allowing elementwise operations with min_tensor and max_tensor.
if block_size != quant_module.weight.shape[1]:
best_indices = best_indices.repeat_interleave(block_size, dim=-1)
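
A self-contained illustration of the repeat_interleave expansion described in the comment above (block count and block size are hypothetical):

import torch

best_indices = torch.tensor([0, 2, 1])  # one winning candidate per block (num_blocks = 3)
block_size = 4                          # input channels covered by each block
expanded = best_indices.repeat_interleave(block_size, dim=-1)
# expanded: tensor([0, 0, 0, 0, 2, 2, 2, 2, 1, 1, 1, 1])
# -> one index per input channel, usable elementwise against min_tensor / max_tensor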

# Unsqueeze best_indices until it matches dim length of max_tensor
while best_indices.dim() < max_tensor.dim():
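SafeGatheredParameters (the new import above) lets candidate selection see full weight tensors even when deepspeed zero3 has partitioned them. Its implementation is not shown in this diff; a plausible sketch, assuming it wraps deepspeed's GatheredParameters with a no-op fallback when parameters are not partitioned:

import contextlib

@contextlib.contextmanager
def safe_gathered_parameters(params, modifier_rank=None):
    # Sketch only. Gather zero3-partitioned parameters for the duration of
    # the block; deepspeed re-partitions them on exit. If deepspeed is absent
    # or the parameters are not partitioned, do nothing.
    params = list(params)
    try:
        import deepspeed
        partitioned = any(hasattr(p, 'ds_shape') for p in params)
    except ImportError:
        partitioned = False
    if partitioned:
        with deepspeed.zero.GatheredParameters(params, modifier_rank=modifier_rank):
            yield
    else:
        yield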
120 changes: 116 additions & 4 deletions TrainingExtensions/torch/test/python/v2/test_deepspeed.py
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@
import json
from packaging import version

from torch.utils.data import Dataset, DataLoader, RandomSampler
from torch.utils.data import Dataset, DataLoader

from aimet_common import quantsim_config
from aimet_common.defs import QuantScheme
@@ -66,6 +66,8 @@
from aimet_torch.v2.quantization.base.quantizer import QuantizerBase
from aimet_torch.v2.quantization import DequantizedTensor
from aimet_torch.v2.deepspeed_utils import SafeGatheredParameters
# SEQ-MSE
from aimet_torch.v2.seq_mse import apply_seq_mse, SeqMseParams


class Net(nn.Module):
@@ -207,9 +209,7 @@ def __len__(self):
return len(self.data)

dataset = MyDataset([torch.randn(1, 28, 28) for _ in range(10)])
# TODO: (huzh) Change RandomSampler to DistributedSampler for testing with multiple GPUs
sampler = RandomSampler(dataset)
return DataLoader(dataset, sampler=sampler, batch_size=1)
return DataLoader(dataset, batch_size=1, shuffle=False)
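# Note: a fixed batch order (no sampler, no shuffling) keeps the data stream
# identical across the baseline and deepspeed runs being compared.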


@pytest.fixture
@@ -668,6 +668,118 @@ def test_deepspeed_zero3_offload_fallback(unlabeled_data_loader,
ds_after = ds_params_after[param_name]
assert not torch.equal(ds_before, ds_after)

@pytest.mark.parametrize("inp_symmetry", ['asym', 'symfp', 'symqt'])
@pytest.mark.parametrize("loss_fn", ['mse', 'l1', 'sqnr'])
@pytest.mark.cuda
def test_seqmse_with_zero3_offload(per_channel_quantsim_config,
init_process_group,
unlabeled_data_loader,
deepspeed_zero3_offload_config,
inp_symmetry, loss_fn):
# Baseline model without deepspeed
model_baseline = Net().cuda().eval()
baseline_state_dict = model_baseline.state_dict()
sim_baseline = QuantizationSimModel(model_baseline,
torch.randn(1, 1, 28, 28).cuda(),
default_param_bw=4,
config_file=per_channel_quantsim_config,
quant_scheme=QuantScheme.training_range_learning_with_tf_init,
in_place=False)

"""
Given: Model pre-partitioned with deepspeed zero3 offload
"""
with ds.zero.Init(config_dict_or_path=deepspeed_zero3_offload_config):
# The ds.zero.Init context pre-partitions pytorch models at instantiation time.
# PyTorch modules instantiated under this context will only hold a partition
# of their parameters
model = Net().cuda().eval()
assert all(param.numel() == 0 for param in model.parameters()) # sanity check
assert all(hasattr(param, 'ds_shape') for param in model.parameters()) # sanity check

# Copy the parameters/buffers of the baseline model to the deepspeed pre-partitioned model
# so that outputs can be asserted equal with or without deepspeed
with ds.runtime.zero.GatheredParameters(model.parameters(), modifier_rank=0), torch.no_grad():
model.load_state_dict(baseline_state_dict)

"""
When: Create quantsim with the pre-partitioned model
Then: Quantizers should be instantiated with the correct shapes
"""
sim_deepspeed = QuantizationSimModel(model,
torch.randn(1, 1, 28, 28).cuda(),
default_param_bw=4,
config_file=per_channel_quantsim_config,
quant_scheme=QuantScheme.training_range_learning_with_tf_init,
in_place=True)


"""
When: Initialize quantsim model with deepspeed zero3 offload
Then:
1) All parameters must be initialized with the deepspeed zero3 parameter partitioning mechanism
2) Forward pass outputs must be equal with or without deepspeed
"""
if "optimizer" in deepspeed_zero3_offload_config:
del deepspeed_zero3_offload_config["optimizer"]
engine, ds_optimizer, *_ = ds.initialize(model=sim_deepspeed.model,
model_parameters=sim_deepspeed.model.parameters(),
config=deepspeed_zero3_offload_config,
mpu=CustomMPU(init_process_group))
assert all(hasattr(param, 'ds_shape') for param in model.parameters())


"""
When: Apply SEQ-MSE and compute encodings after deepspeed initialization
Then:
1) All parameters in the quantizer should have requires_grad set to False
2) Parameters in the quantizer should be frozen after applying SEQ-MSE
3) All parameters must be updated in (almost) the same way with or without deepspeed
"""

sim_deepspeed.model.requires_grad_(True)
params = SeqMseParams(num_batches=2, inp_symmetry=inp_symmetry, loss_fn=loss_fn)
apply_seq_mse(model_baseline, sim_deepspeed, unlabeled_data_loader, params)
assert not sim_deepspeed.model.fc1.param_quantizers['weight'].min.requires_grad
assert not sim_deepspeed.model.fc1.param_quantizers['weight'].max.requires_grad
assert not sim_deepspeed.model.fc1.param_quantizers['weight']._allow_overwrite
assert not sim_deepspeed.model.fc2.param_quantizers['weight'].min.requires_grad
assert not sim_deepspeed.model.fc2.param_quantizers['weight'].max.requires_grad
assert not sim_deepspeed.model.fc2.param_quantizers['weight']._allow_overwrite

with ds.runtime.zero.GatheredParameters(sim_deepspeed.model.fc1.param_quantizers.parameters()):
enc_before = sim_deepspeed.model.fc1.param_quantizers['weight'].get_encoding()

# Apply seq-mse for baseline fp32 model
apply_seq_mse(model_baseline, sim_baseline, unlabeled_data_loader, params)

# Compute encodings for all the activations and remaining non-supported modules
with aimet.nn.compute_encodings(sim_deepspeed.model),\
aimet.nn.compute_encodings(sim_baseline.model):
for data in itertools.islice(unlabeled_data_loader, 3):
data = data.cuda()
_ = sim_deepspeed.model(data)
_ = sim_baseline.model(data)
with ds.runtime.zero.GatheredParameters(sim_deepspeed.model.fc1.param_quantizers.parameters()):
enc_after = sim_deepspeed.model.fc1.param_quantizers['weight'].get_encoding()
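# Seq-mse froze the param encodings (_allow_overwrite=False), so
# compute_encodings must leave the weight encoding unchanged: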
assert enc_before.scale == enc_after.scale

bs_params = {
name: param.clone().detach() for name, param in sim_baseline.model.named_parameters()
}

with ds.runtime.zero.GatheredParameters(sim_deepspeed.model.parameters()):
ds_params = {
name: param.clone().detach() for name, param in sim_deepspeed.model.named_parameters()
}

assert bs_params.keys() == ds_params.keys()
for param_name in bs_params:
bs_param = bs_params[param_name]
ds_param = ds_params[param_name]
# Baseline and deepspeed parameters should match within numerical tolerance
assert torch.allclose(bs_param, ds_param, rtol=1e-3)
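
For reference, the minimal call pattern this test exercises (names taken from the diff; argument values illustrative):

from aimet_torch.v2.seq_mse import apply_seq_mse, SeqMseParams

params = SeqMseParams(num_batches=2, inp_symmetry='asym', loss_fn='mse')
apply_seq_mse(fp32_model, sim, unlabeled_data_loader, params)  # sim: a QuantizationSimModel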

@pytest.mark.cuda
def test_conv_transpose(per_channel_quantsim_config,
init_process_group,
Expand Down
