Commit

Merge remote-tracking branch 'master' into tp_grad_fix
inkcherry committed Apr 15, 2024
2 parents 3ebed5e + 54c0687 commit fc537b8
Showing 44 changed files with 244 additions and 221 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -133,7 +133,7 @@ DeepSpeed has been integrated with several different popular open-source DL fram
| AMD | [![amd-mi200](https://github.com/microsoft/DeepSpeed/actions/workflows/amd-mi200.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/amd-mi200.yml) |
| CPU | [![torch-latest-cpu](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-torch-latest.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-torch-latest.yml) [![cpu-inference](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-inference.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/cpu-inference.yml) |
| Intel Gaudi | [![hpu-gaudi2](https://github.com/microsoft/DeepSpeed/actions/workflows/hpu-gaudi2.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/hpu-gaudi2.yml) |
-| Intel XPU | [![xpu-max1100](https://github.com/microsoft/DeepSpeed/actions/workflows/xpu-max1100.yml/badge.svg)?branch=master](https://github.com/microsoft/DeepSpeed/actions/workflows/xpu-max1100.yml) |
+| Intel XPU | [![xpu-max1100](https://github.com/microsoft/DeepSpeed/actions/workflows/xpu-max1100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/xpu-max1100.yml) |
| PyTorch Nightly | [![nv-torch-nightly-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml) |
| Integrations | [![nv-transformers-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-transformers-v100.yml) [![nv-lightning-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-lightning-v100.yml) [![nv-accelerate-v100](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-accelerate-v100.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-accelerate-v100.yml) [![nv-mii](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-mii.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-mii.yml) [![nv-ds-chat](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-ds-chat.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-ds-chat.yml) [![nv-sd](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-sd.yml/badge.svg)](https://github.com/microsoft/DeepSpeed/actions/workflows/nv-sd.yml) |
| Misc | [![Formatting](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/formatting.yml) [![pages-build-deployment](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/microsoft/DeepSpeed/actions/workflows/pages/pages-build-deployment) [![Documentation Status](https://readthedocs.org/projects/deepspeed/badge/?version=latest)](https://deepspeed.readthedocs.io/en/latest/?badge=latest)[![python](https://github.com/microsoft/DeepSpeed/actions/workflows/python.yml/badge.svg?branch=master)](https://github.com/microsoft/DeepSpeed/actions/workflows/python.yml) |
@@ -163,7 +163,7 @@ dynamically link them at runtime.
| ----------- | -------- | ---------------- | --------------------- | ------------------ |
| Intel | Intel(R) Gaudi(R) 2 AI accelerator | hpu | Yes | Yes |
| Intel | Intel(R) Xeon(R) Processors | cpu | Yes | Yes |
-| Intel | Intel(R) Data Center GPU Max series | xpu | Yes | No |
+| Intel | Intel(R) Data Center GPU Max series | xpu | Yes | Yes |

## PyPI
We regularly push releases to [PyPI](https://pypi.org/project/deepspeed/) and encourage users to install from there in most cases.
168 changes: 68 additions & 100 deletions csrc/fp_quantizer/quantize.cu
@@ -219,119 +219,100 @@ __global__ void apply_quantization(T* val,
}

template <typename T,
-          int unroll,
          int q_mantisa_bits,
          int total_q_bits = 16,
          int _mantisa_bits = 3,
          int _exponent_bits = 4>
-__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size)
+__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int total_num_elements)
{
-    int tidx = threadIdx.x;
-    int wid = tidx >> 5;
-    int lane = tidx & 0x1f;
-    int gid = blockIdx.x * quantization::warps + wid;
+    constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
+    int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size;

    constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
    constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
    constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
    constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
    constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);

-    constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
-    constexpr uint32_t load_stride = vector_size * hw_warp_size;
-    const uint32_t thread_offset = lane * vector_size;
-    const uint32_t thread_load_offset = lane * vector_size * quantized_bits / 8;
-    const uint32_t base_load_offset =
-        gid * (group_size * quantized_bits / 8 + 4) + thread_load_offset;  // 4-byte scale offset
-    const uint32_t base_store_offset = gid * group_size + thread_offset;
-    const uint8_t* load_base_ptr = val + base_load_offset;
+    const uint32_t g_index = (tidx / group_size);
+    const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
+    const uint8_t* load_base_ptr =
+        val + g_index * (group_size_bytes + 4) + (tidx % group_size) * quantized_bits / 8;

    int mantisa_mask = ((1 << q_mantisa_bits) - 1);
    mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);

-    T* store_base_ptr = q_val + base_store_offset;
-    float scale;  //= q_scale[gid];
+    T* store_base_ptr = q_val + tidx;
+    float scale;

    uint8_t* scale_as_int8 = reinterpret_cast<uint8_t*>(&scale);
    if (quantized_bits == 6) {
        mem_access::load_global<quantization::quanitzed_access_granularity>(
-            scale_as_int8,
-            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));
+            scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
        mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
            scale_as_int8 + quantization::quanitzed_access_granularity_6bits,
-            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8) +
+            val + g_index * (group_size_bytes + 4) + group_size_bytes +
                quantization::quanitzed_access_granularity_6bits);
    } else
        mem_access::load_global<quantization::quanitzed_access_granularity>(
-            scale_as_int8,
-            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));
-
-#pragma unroll
-    for (int i = 0; i < unroll; i++) {
-        if (i * load_stride + thread_offset < group_size) {
-            uint64_t q_buf_in;
-            uint64_t q_buf_in1;
-            uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
-            uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
-            uint32_t loading_offset = i * load_stride * quantized_bits / 8;
-            if (quantized_bits == 6) {
-                mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
-                    int8_data, load_base_ptr + loading_offset);
-                mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
-                    int8_data + quantization::quanitzed_access_granularity_6bits,
-                    load_base_ptr + loading_offset +
-                        quantization::quanitzed_access_granularity_6bits);
-                mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
-                    int8_data + quantization::quanitzed_access_granularity_6bits * 2,
-                    load_base_ptr + loading_offset +
-                        quantization::quanitzed_access_granularity_6bits * 2);
-            } else {
-                mem_access::load_global<quantization::quanitzed_access_granularity>(
-                    int8_data, load_base_ptr + loading_offset);
-                if (quantized_bits > 4) {
-                    mem_access::load_global<quantization::quanitzed_access_granularity>(
-                        int8_data + quantization::quanitzed_access_granularity,
-                        load_base_ptr + loading_offset +
-                            quantization::quanitzed_access_granularity);
-                    if (quantized_bits == 12) {
-                        mem_access::load_global<quantization::quanitzed_access_granularity>(
-                            int8_data1,
-                            load_base_ptr + loading_offset +
-                                quantization::quanitzed_access_granularity * 2);
-                    }
-                }
-            }
-            T store_buf[vector_size];
-            uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
+            scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
+
+    if (tidx < total_num_elements) {
+        uint64_t q_buf_in;
+        uint64_t q_buf_in1;
+        uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
+        uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
+        if (quantized_bits == 6) {
+            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
+                int8_data, load_base_ptr);
+            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
+                int8_data + quantization::quanitzed_access_granularity_6bits,
+                load_base_ptr + quantization::quanitzed_access_granularity_6bits);
+            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
+                int8_data + quantization::quanitzed_access_granularity_6bits * 2,
+                load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2);
+        } else {
+            mem_access::load_global<quantization::quanitzed_access_granularity>(int8_data,
+                                                                                load_base_ptr);
+            if (quantized_bits > 4) {
+                mem_access::load_global<quantization::quanitzed_access_granularity>(
+                    int8_data + quantization::quanitzed_access_granularity,
+                    load_base_ptr + quantization::quanitzed_access_granularity);
+                if (quantized_bits == 12) {
+                    mem_access::load_global<quantization::quanitzed_access_granularity>(
+                        int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2);
+                }
+            }
+        }
+        T store_buf[vector_size];
+        uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
#pragma unroll
-            for (int j = 0; j < vector_size; j++) {
-                uint16_t new_data;
-                if (j < 5 || quantized_bits != 12) {
-                    new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
-                } else {
-                    if (j == 5) {
-                        new_data = (uint16_t)(q_buf_in1);
-                        new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
-                    } else
-                        new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
-                }
+        for (int j = 0; j < vector_size; j++) {
+            uint16_t new_data;
+            if (j < 5 || quantized_bits != 12) {
+                new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
+            } else {
+                if (j == 5) {
+                    new_data = (uint16_t)(q_buf_in1);
+                    new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
+                } else
+                    new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
+            }

-                uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
-                uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
-                uint16_t dst_mantisa = (new_data & _mantisa_mask);
+            uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
+            uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
+            uint16_t dst_mantisa = (new_data & _mantisa_mask);

-                if (dst_exponent != (1 << q_exponent_bits) - 1)
-                    dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
-                                   (1 << (q_exponent_bits - 1)) - 1;
+            if (dst_exponent != (1 << q_exponent_bits) - 1)
+                dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
+                               (1 << (q_exponent_bits - 1)) - 1;

-                q_buf[j] = ((sign << (q_exponent_bits + q_mantisa_bits)) |
-                            (dst_exponent << q_mantisa_bits) |
-                            (dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
-                float up_cast = conversion::to<float>(store_buf[j]);
-                store_buf[j] = conversion::to<T>(up_cast * scale);
-            }
-            mem_access::store_global<quantization::access_granularity>(
-                store_base_ptr + i * load_stride, store_buf);
-        }
+            q_buf[j] =
+                ((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) |
+                 (dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
+            float up_cast = conversion::to<float>(store_buf[j]);
+            store_buf[j] = conversion::to<T>(up_cast * scale);
        }
+        mem_access::store_global<quantization::access_granularity>(store_base_ptr, store_buf);
    }
}

@@ -386,12 +367,6 @@ INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8);
#endif
INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8);

-#define LAUNCH_FOR_DEQUANTIZATION_UNROLL(COUNT)                                                  \
-    case COUNT:                                                                                  \
-        apply_dequantization<T, COUNT, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS> \
-            <<<grid, block, 0, stream>>>(val, q_val, group_size);                                \
-        break;
-
template <typename T, int mantisa>
void launch_dequantization(uint8_t* val,
                           T* q_val,
@@ -401,21 +376,14 @@ void launch_dequantization(uint8_t* val,
                           int q_exponent_bits,
                           cudaStream_t stream)
{
-    const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps);
+    int blocks = ((num_groups * group_size) - 1) /
+                     (quantization::threads * (quantization::access_granularity / sizeof(T))) +
+                 1;
+    const dim3 grid(blocks);
    const dim3 block(quantization::threads);

-    constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T);
-    const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll;
-
    DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
-        switch (copy_unroll) {
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(1)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(2)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(3)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(4)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(5)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(6)
-        }
+        apply_dequantization<T, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS>
+            <<<grid, block, 0, stream>>>(val, q_val, group_size, (num_groups * group_size));
    });
}
#define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \
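
In summary, the quantize.cu change drops the `unroll` template parameter and the warp-per-group indexing: each thread now dequantizes a single `vector_size`-wide slice at a flat global offset, guarded by `tidx < total_num_elements`, and the launcher sizes the grid to match. The sketch below is not part of the commit; the thread count, access granularity, and E4M3 bit layout are assumed example values. It reproduces the two pieces of arithmetic involved: the new block-count formula and the exponent re-bias performed in the inner loop.

```python
# Standalone sketch of the arithmetic in the new dequantization path.
# Assumed example values: 1024 threads/block (quantization::threads),
# 16-byte vector loads (quantization::access_granularity), fp16 output,
# and an E4M3 input format (1 sign, 4 exponent, 3 mantissa bits).

# (1) Grid sizing: one thread handles vector_size elements, so the launcher
# needs ceil(total_elements / (threads * vector_size)) blocks.
threads, access_granularity, fp16_bytes = 1024, 16, 2
vector_size = access_granularity // fp16_bytes           # 8 fp16 values per thread
num_groups, group_size = 512, 2048
total_elements = num_groups * group_size
blocks = (total_elements - 1) // (threads * vector_size) + 1
print(blocks)                                            # 128

# (2) Exponent re-bias when widening E4M3 bits to an fp16 bit pattern,
# mirroring the kernel's inner loop (special-case exponent handling omitted).
_exponent_bits, _mantisa_bits = 4, 3                     # source format
q_exponent_bits, q_mantisa_bits = 5, 10                  # fp16 target
new_data = 0b0_0111_100                                  # E4M3 encoding of 1.5
sign = new_data >> (_mantisa_bits + _exponent_bits)
dst_exponent = (new_data >> _mantisa_bits) & ((1 << _exponent_bits) - 1)
dst_mantisa = new_data & ((1 << _mantisa_bits) - 1)
# Subtract the source bias (7) and add the target bias (15).
dst_exponent = dst_exponent - ((1 << (_exponent_bits - 1)) - 1) + (1 << (q_exponent_bits - 1)) - 1
fp16_bits = (sign << (q_exponent_bits + q_mantisa_bits)) | \
            (dst_exponent << q_mantisa_bits) | \
            (dst_mantisa << (q_mantisa_bits - _mantisa_bits))
print(hex(fp16_bits))                                    # 0x3e00 == 1.5 in fp16
```

With `unroll` gone, the `LAUNCH_FOR_DEQUANTIZATION_UNROLL` macro and its `switch` dispatch become dead code, which accounts for most of the 100 lines deleted from this file.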
2 changes: 1 addition & 1 deletion deepspeed/checkpoint/__init__.py
@@ -15,6 +15,6 @@

from .zero_checkpoint import ZeROCheckpoint

-from .universal_checkpoint import enable_universal_checkpoint
+from .universal_checkpoint import enable_universal_checkpoint, SubparamShape

from .constants import *
4 changes: 4 additions & 0 deletions deepspeed/checkpoint/constants.py
@@ -74,10 +74,14 @@
# Similarly, load_hp_checkpoint_state has to take the needed actions when loading from universal.
PARAM_N_SUB_PARAMS = "param_n_sub_params"

+SUB_PARAM_SHAPE = "sub_param_shape"
+
# Regex list of parameters that require special handling
VOCABULARY_PARAMETER_PATTERNS = 'vocabulary_parameter_patterns'
PIPELINE_REPLICATED_PARAMETER_PATTERNS = 'pipeline_replicated_parameter_patterns'
PARAMETER_TO_AVERAGE_PATTERNS = 'parameter_to_average_patterns'
PARAMETER_WITH_ROW_PARALLELISM_PATTERNS = 'parameter_with_row_parallelism_patterns'
TP_REPLICATED_PARAMETER_PATTERNS = 'tp_replicated_parameter_patterns'
PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0 = 'parameter_with_2_sub_params_cat_dim_0'
+PARAMETER_WITH_SUB_PARAMS = 'parameter_with_sub_params'
+SUB_PARAMS_SHAPE = 'sub_params_shape'
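
The new constants are plain string keys into a checkpoint's universal-checkpoint metadata, alongside the existing pattern lists. Below is a minimal sketch of how such keys group parameter-name regexes; the constants are copied from the diff above, but the dict contents are hypothetical examples, not values shipped by DeepSpeed.

```python
# Constants as defined in deepspeed/checkpoint/constants.py (from the diff above).
TP_REPLICATED_PARAMETER_PATTERNS = 'tp_replicated_parameter_patterns'
PARAMETER_WITH_SUB_PARAMS = 'parameter_with_sub_params'

# Hypothetical metadata: regexes selecting which parameters get which
# special handling during universal-checkpoint conversion.
universal_checkpoint_info = {
    TP_REPLICATED_PARAMETER_PATTERNS: [r'.*layernorm\.weight', r'.*layernorm\.bias'],
    PARAMETER_WITH_SUB_PARAMS: [r'.*query_key_value\.weight'],
}

for key, patterns in universal_checkpoint_info.items():
    print(f'{key}: {patterns}')
```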
(The remaining 40 changed files are not shown here.)
