Update documents
loadams committed Dec 5, 2024
1 parent c252391 commit 3fb3f21
Showing 3 changed files with 38 additions and 47 deletions.
41 changes: 19 additions & 22 deletions csrc/quantization/quant_reduce.cu
@@ -262,8 +262,6 @@ void launch_dequant_reduce(int8_t* reduced_data,
    }
}

-
-
/*
Modified loco_dequant_reduce function that performs dequantization and reduction,
and incorporates error-feedback by updating the error_feedback tensor in-place.
@@ -303,14 +301,13 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data
    __half2 local_buffer[totalChunks * storage_values];
    __half2 err_buffer[totalChunks * storage_values];

-    quantize::GroupStats<quantType> stats;
+    quantize::GroupStats<quantType> stats;

#pragma unroll
    for (int i = 0; i < totalChunks; i++) {
        __half2* iteration_buffer = local_buffer + i * storage_values;
        __half2* iter_err_buffer = err_buffer + i * storage_values;

-
#pragma unroll
        for (int j = 0; j < storage_values; j++) {
            iteration_buffer[j] = reduce::init<rop::Add, __half2>();
@@ -367,7 +364,7 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data
            }
        }
        mem_access::load_global<quantize::granularity>(
-        iter_err_buffer, error_feedback + iter_offset_err, do_loads);
+            iter_err_buffer, error_feedback + iter_offset_err, do_loads);
#pragma unroll
        for (int k = 0; k < storage_values; k++) {
            iteration_buffer[k] = __hadd2(iteration_buffer[k], iter_err_buffer[k]);
@@ -383,7 +380,7 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data
    if constexpr (quantType == quantize::Type::Asymmetric) { de_params.offset = params.offset; }

    if (tb.thread_index().x == 0) { params.store(reduced_scales, tb.group_index().x); }
-
+
#pragma unroll
    for (int i = 0; i < totalChunks; i++) {
        const int iter_offset = i * stride + base_offset;
@@ -420,8 +417,9 @@ __global__ void __launch_bounds__(1024) loco_dequant_reduce(int8_t* reduced_data
            // float2 back to __half2
            iter_err_buffer[k] = __float22half2_rn(iter_err_buf_f);
        }
-        mem_access::store_global<quantize::granularity>(error_feedback + iter_offset_err, iter_err_buffer);
-    }
+        mem_access::store_global<quantize::granularity>(error_feedback + iter_offset_err,
+                                                        iter_err_buffer);
+    }
}
}

@@ -488,19 +486,19 @@ void launch_loco_dequant_reduce_impl(int8_t* reduced_data,
    }
}

-#define LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(NUM_BITS, NUM_GPUS, QUANT_TYPE) \
-    launch_loco_dequant_reduce_impl<NUM_BITS, NUM_GPUS, QUANT_TYPE>(reduced_data, \
-                                                                    reduced_scales, \
-                                                                    input_data, \
-                                                                    input_scales, \
-                                                                    out_groups, \
-                                                                    elems_per_out_group, \
-                                                                    elems_per_in_tensor, \
-                                                                    groups_per_in_tensor, \
-                                                                    elems_per_in_group, \
-                                                                    num_gpus, \
-                                                                    error_feedback, \
-                                                                    err_beta, \
+#define LAUNCH_LOCO_DEQUANT_REDUCE_IMPL(NUM_BITS, NUM_GPUS, QUANT_TYPE)                   \
+    launch_loco_dequant_reduce_impl<NUM_BITS, NUM_GPUS, QUANT_TYPE>(reduced_data,         \
+                                                                    reduced_scales,       \
+                                                                    input_data,           \
+                                                                    input_scales,         \
+                                                                    out_groups,           \
+                                                                    elems_per_out_group,  \
+                                                                    elems_per_in_tensor,  \
+                                                                    groups_per_in_tensor, \
+                                                                    elems_per_in_group,   \
+                                                                    num_gpus,             \
+                                                                    error_feedback,       \
+                                                                    err_beta,             \
                                                                    stream);

void launch_loco_dequant_reduce(int8_t* reduced_data,
@@ -557,4 +555,3 @@ void launch_loco_dequant_reduce(int8_t* reduced_data,
        }
    }
}
-
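
A note on the error-feedback scheme: per the comment in the diff above, loco_dequant_reduce dequantizes the incoming partitions, folds in the carried residual, reduces, re-quantizes, and updates error_feedback in place. As a mental model only, here is a minimal PyTorch sketch of the err_beta == 0.0 path that the unit test below exercises; the helper names and the per-tensor (rather than per-group) scaling are illustrative assumptions, not the kernel's actual grouped layout.

    import torch

    def symmetric_quant_dequant(x: torch.Tensor, num_bits: int) -> torch.Tensor:
        # Per-tensor symmetric fake-quantization: quantize to num_bits, then dequantize.
        q_max = 2**(num_bits - 1) - 1  # 127 for 8 bits, 7 for 4 bits
        scale = x.abs().max().clamp(min=1e-8) / q_max
        return torch.clamp(torch.round(x / scale), -q_max - 1, q_max) * scale

    def loco_step(x: torch.Tensor, err: torch.Tensor, num_bits: int) -> torch.Tensor:
        # Fold the carried residual into the input, quantize, and replace the
        # residual in place (the err_beta == 0 case).
        compensated = x + err
        dequantized = symmetric_quant_dequant(compensated, num_bits)
        err.copy_(compensated - dequantized)
        return dequantized

    x = torch.randn(16, 16)
    err = torch.zeros_like(x)
    x0, err0 = x.clone(), err.clone()
    y = loco_step(x, err, num_bits=8)
    # The invariant the test asserts: input + old residual == output + new residual.
    assert torch.allclose(x0 + err0, y + err)
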
8 changes: 5 additions & 3 deletions csrc/quantization/swizzled_quantize.cu
@@ -212,7 +212,8 @@ __global__ void loco_swizzled_quant_kernel(int8_t* quantized_data,
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    // Indexing offsets, same as normal quantization for in-case
-    const int block_rank_data = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
+    const int block_rank_data =
+        blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
    const int block_offset_data = block_rank_data * elems_per_group;
    const int elem_offset = tb.thread_index().x * quantize::h_per_load;
    const int base_offset_data = block_offset_data + elem_offset;
@@ -307,8 +308,9 @@ __global__ void loco_swizzled_quant_kernel(int8_t* quantized_data,
            // float2 back to __half2
            iter_err_buffer[k] = __float22half2_rn(iter_err_buf_f);
        }
-        __half2* error_feedback_base_h2 = reinterpret_cast<__half2*>(error_feedback_base);
-        mem_access::store_global<quantize::granularity>(error_feedback_base_h2 + i_stride / 2, iter_err_buffer);
+        __half2* error_feedback_base_h2 = reinterpret_cast<__half2*>(error_feedback_base);
+        mem_access::store_global<quantize::granularity>(error_feedback_base_h2 + i_stride / 2,
+                                                        iter_err_buffer);
    }
}
}
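
The loco_swizzled_quant_kernel changes above pair the same error-feedback update with the group swizzle. The swizzle itself is just a reordering of quantization groups so that groups bound for the same rank become contiguous before communication; the test below builds its expected value with exactly such a permute. A sketch, assuming a (num_nodes, devices_per_node, rest) split; the exact reshape the test uses sits in a collapsed part of the diff.

    import torch

    def reference_swizzle(tensor: torch.Tensor, num_nodes: int, devices_per_node: int) -> torch.Tensor:
        # Split the flat data into (nodes, devices, elements) and swap the first
        # two axes so that each device's groups land contiguously.
        rest = tensor.numel() // (num_nodes * devices_per_node)
        reshaped = tensor.reshape(num_nodes, devices_per_node, rest)
        return reshaped.permute(1, 0, 2).reshape(tensor.size())
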
36 changes: 14 additions & 22 deletions tests/unit/runtime/comm/test_coalesced_collectives.py
@@ -97,13 +97,14 @@ def test_non_divisible(self):
        assert output.shape == (24, )
        assert torch.allclose(output, torch.zeros_like(output))

+
class TestLocoQuantized(DistributedTest):

    world_size = 1

    @pytest.mark.parametrize("num_bits", [4, 8])
    @pytest.mark.parametrize("tensor_size", [(16, 16), (64, 64)])
-    @pytest.mark.parametrize("devices_per_node", [4,8])
+    @pytest.mark.parametrize("devices_per_node", [4, 8])
    def test_loco_quantized_reduction(self, num_bits, tensor_size, devices_per_node):
        from deepspeed.ops.op_builder import QuantizerBuilder
        if not deepspeed.ops.__compatible_ops__[QuantizerBuilder.NAME]:
@@ -126,44 +127,35 @@ def test_loco_quantized_reduction(self, num_bits, tensor_size, devices_per_node)
        swizzled_tensor = tensor_reshaped.permute(1, 0, 2).reshape(tensor.size())

        # Perform loco_swizzle_quant
-        output, scales = quantizer_module.loco_swizzle_quant(tensor, error_feedback, 0.0,
-                                                             num_groups, num_bits,
-                                                             quantizer_module.Symmetric, 1,
-                                                             num_nodes, devices_per_node)
-
-
+        output, scales = quantizer_module.loco_swizzle_quant(tensor, error_feedback, 0.0, num_groups, num_bits,
+                                                             quantizer_module.Symmetric, 1, num_nodes,
+                                                             devices_per_node)

        # Compare swizzled_tensor with the output of loco_swizzle_quant
-        dequantized = quantizer_module.dequantize(output, scales,
-                                                  scales.numel(), num_bits,
+        dequantized = quantizer_module.dequantize(output, scales, scales.numel(), num_bits,
                                                  quantizer_module.Symmetric).view(tensor.size())

        assert torch.allclose(swizzled_tensor + error_feedback_ori, dequantized + error_feedback)

        # Calculate elements per group and groups per partition
        elements_per_group = total_elements // num_groups
        groups_per_partition = num_groups // devices_per_node

        # Reshape dequantized data to match the grouping in loco_quantized_reduction
-        dequantized_reshaped = dequantized.view(devices_per_node,
-                                                groups_per_partition,
-                                                elements_per_group)
+        dequantized_reshaped = dequantized.view(devices_per_node, groups_per_partition, elements_per_group)

        # Perform reduction across devices_per_node dimension
-        reduced_dequantized = dequantized_reshaped.cumsum(dim=0)[-1]
+        reduced_dequantized = dequantized_reshaped.cumsum(dim=0)[-1]
        # Initialize error_feedback tensor
        error_feedback = torch.randn(reduced_dequantized.shape, device=tensor.device, dtype=dequantized.dtype)
        error_feedback_ori = error_feedback.clone()

        # perform loco_quantized_reduction
-        output, scales = quantizer_module.loco_quantized_reduction(output, scales,
-                                                                   error_feedback, 0.0, num_groups,
-                                                                   num_groups // devices_per_node,
-                                                                   num_bits, quantizer_module.Symmetric,
-                                                                   devices_per_node)
-
-        dequantized_reduced = quantizer_module.dequantize(output, scales,
-                                                          scales.numel(), num_bits,
+        output, scales = quantizer_module.loco_quantized_reduction(output, scales, error_feedback, 0.0, num_groups,
+                                                                   num_groups // devices_per_node, num_bits,
+                                                                   quantizer_module.Symmetric, devices_per_node)
+
+        dequantized_reduced = quantizer_module.dequantize(output, scales, scales.numel(), num_bits,
                                                          quantizer_module.Symmetric).view(error_feedback.size())

        assert torch.allclose(reduced_dequantized + error_feedback_ori, dequantized_reduced + error_feedback)

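One reading note on the test above: dequantized_reshaped.cumsum(dim=0)[-1] keeps only the last running total, so it is equivalent to a plain sum(dim=0) across the devices_per_node axis. A quick check of that equivalence:

    import torch

    x = torch.randn(4, 8, 32)  # (devices_per_node, groups_per_partition, elements_per_group)
    assert torch.allclose(x.cumsum(dim=0)[-1], x.sum(dim=0))
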