Praxis 1.2.0 release #32

Open · wants to merge 5 commits into main
8 changes: 8 additions & 0 deletions RELEASE.md
@@ -1,3 +1,11 @@
# Version: 1.2.0
## Major Features and Improvements
## Breaking changes
## Deprecations
## Note
* Version: 1.2.0
* Build Date: 20231016
* Praxis commit: 7bd63412bf86a68e09fcd9455f76a4909d19377e
# Version: 1.1.0
## Major Features and Improvements
* Move to python 3.10 as the minimal python requirement (previously on python 3.8).
18 changes: 13 additions & 5 deletions praxis/layers/quantization/operations.py
@@ -240,6 +240,17 @@ def einsum(
Returns:
A JTensor.
"""
# Non-performant equation for inference testing purposes.
# TODO: b/305735188 - Improve the performance by using the integer einsum op.
if zp_act is not None:
dequantized_x = jnp.multiply(x, scale_act) - zp_act
# explicit broadcast if necessary.
if w.ndim == 3 and scale.ndim == 1:
scale = jnp.expand_dims(scale, (1, 2))
dequantized_w = jnp.multiply(w, scale)
if zp is not None:
dequantized_w = dequantized_w - zp
return jnp.einsum(eqn, dequantized_x, dequantized_w)
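A standalone sketch of what this fallback computes (not part of the diff; the quantization convention qx = round((x + zp) / scale) is an assumption chosen so that qx * scale - zp recovers x):

import jax.numpy as jnp

eqn = '...y,yz->...z'
x = jnp.array([[0.1, 0.7, -0.3, 0.5]])        # float activations, shape (1, 4)
w = jnp.array([[0.2], [-0.4], [0.6], [0.1]])  # float weights, shape (4, 1)

scale_act, zp_act = 0.01, 0.3                 # hypothetical activation scale / zero point
qx = jnp.round((x + zp_act) / scale_act)      # asymmetric activation quantization (assumed convention)
scale = jnp.max(jnp.abs(w), axis=0) / 127.0   # symmetric per-channel weight scale
qw = jnp.round(w / scale)

dequantized_x = jnp.multiply(qx, scale_act) - zp_act   # mirrors the branch above
dequantized_w = jnp.multiply(qw, scale)
approx = jnp.einsum(eqn, dequantized_x, dequantized_w)
exact = jnp.einsum(eqn, x, w)                 # approx matches exact up to rounding error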

use_int_dot_general = (
x.dtype in QUANTIZED_TYPES and w.dtype in QUANTIZED_TYPES
@@ -302,11 +313,6 @@ def einsum(
offset = compute_offset(x, zp, eqn)
ret = ret - offset

if zp_act is not None:
# Non performent equation for inference testing purposes
dequantized_x = scale_act * x - zp_act
dequantized_w = scale * w - zp
ret = jnp.einsum(eqn, dequantized_x, dequantized_w)
return ret


@@ -623,6 +629,8 @@ def reduce_einsum_activation_precision(

if squeeze:
scale = jnp.squeeze(scale, axis=contract_dims)
if zp is not None:
zp = jnp.squeeze(zp, axis=contract_dims)
return t, scale, zp
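The added `zp` squeeze keeps the zero point's shape in step with the scale so the two broadcast identically downstream. A shape-only sketch, with assumed reductions standing in for the real quantization math:

import jax.numpy as jnp

t = jnp.arange(8.0).reshape(2, 4)   # activations; the contracting axis is 1
contract_dims = (1,)
scale = jnp.max(jnp.abs(t), axis=contract_dims, keepdims=True) / 127.0  # shape (2, 1)
zp = jnp.min(t, axis=contract_dims, keepdims=True)                      # shape (2, 1), hypothetical zero point
scale = jnp.squeeze(scale, axis=contract_dims)  # shape (2,)
zp = jnp.squeeze(zp, axis=contract_dims)        # shape (2,); without the new squeeze it stays (2, 1)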


19 changes: 18 additions & 1 deletion praxis/layers/quantization/operations_test.py
@@ -118,7 +118,24 @@ def test_quantized_einsum_with_asym_weight_act(self, eqn):

ret = operations.einsum(eqn, qx, qw, sw, zpw, sx, zpx)
expected = jnp.einsum(eqn, x, w)
self.assertAllClose(ret, expected, rtol=0.1, atol=0.5)
self.assertAllClose(ret, expected, rtol=0.02, atol=0.02)

@parameterized.named_parameters(
('eqn_with_dot', '...y,yz->...z'),
)
def test_quantized_einsum_with_sym_weight_asym_act(self, eqn):
w = jax.random.uniform(jax.random.PRNGKey(0), (4, 3))
x = jax.random.uniform(jax.random.PRNGKey(0), (2, 4))
qw, sw, zpw = operations.reduce_einsum_weight_precision(
eqn, w, use_symmetric=True
)
qx, sx, zpx = operations.reduce_einsum_activation_precision(
eqn, x, symmetric=False
)

ret = operations.einsum(eqn, qx, qw, sw, zpw, sx, zpx)
expected = jnp.einsum(eqn, x, w)
self.assertAllClose(ret, expected, rtol=0.02, atol=0.02)

@parameterized.parameters(
('ab,bc->ac', (10, 4), (4, 5)),
37 changes: 34 additions & 3 deletions praxis/layers/quantization/quantize.py
@@ -371,6 +371,8 @@ def for_transformer(
quantize_init_from_checkpoint_rules_task: bool = False,
block_size: int = 0,
# Internal quantization parameters.
num_bits_act: int | None = None,
use_symmetric_act: bool | None = None,
):
"""Find and quantize transformer.

@@ -420,6 +422,10 @@
that are defined in task_p.train.init_from_checkpoint_rules.values()
block_size: block size for sub-channel quantization. Defaults to 0, which
means off.
num_bits_act: The number of bits used for activation quantization. Only
valid when weight_quant_only is false.
use_symmetric_act: Use symmetric activation quantization. Only valid when
weight_quant_only is false.

Returns:
A modifier that quantizes transformers when applied to a config.
@@ -463,6 +469,8 @@ def task(self):
quantize_self_attention=quantize_self_attention,
quantize_cross_attention=quantize_cross_attention,
softmax_only=softmax_only,
use_symmetric_act=use_symmetric_act,
num_bits_act=num_bits_act,
)
return task_p
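For context, a hedged sketch of the decorator usage with the new knobs; the experiment classes are placeholders, and num_bits / weight_quant_only are assumed from the docstring above rather than shown in this diff:

from praxis.layers.quantization import quantize

@quantize.for_transformer(
    num_bits=8,                # weight precision (assumed argument name)
    weight_quant_only=False,   # the activation knobs are only honored in this mode
    num_bits_act=8,
    use_symmetric_act=False,
)
class MyQuantizedExperiment(MyBaseExperiment):  # MyBaseExperiment is a placeholder experiment
  pass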

@@ -568,6 +576,8 @@ def set_transformer_quantization(
use_int4_packed_weights: bool = True,
int4_packed_weights_container_dtype: jnp.dtype = jnp.int32,
# Internal quantization parameters.
num_bits_act: int | None = None,
use_symmetric_act: bool | None = None,
):
"""Sets quantization params for TransformerLm or TransformerEncoderDecoder.

@@ -611,6 +621,10 @@
False int4 weights will be kept in int8.
int4_packed_weights_container_dtype: Container type for int4 weights: int32
to pack 8 int4s, or int8 to pack 2 int4s.
num_bits_act: The number of bits used for activation quantization. Only
valid when weight_quant_only is false.
use_symmetric_act: Use symmetric activation quantization. Only valid when
weight_quant_only is false.
"""
weight_quantization_params = WeightQuantizationParams(
precision=num_bits,
@@ -621,9 +635,26 @@
int4_packed_weights_container_dtype=int4_packed_weights_container_dtype,
# Pass internal quantization parameters.
)
act_quantization_params = (
None if weight_quant_only else ActQuantizationParams(precision=num_bits)
)
act_quantization_params = None
if (
num_bits_act is not None or use_symmetric_act is not None
) and weight_quant_only:
raise ValueError(
'Activation quantization params (`num_bits_act` and'
' `use_symmetric_act`) should not be set when `weight_quant_only` is'
' set to True.'
)
if not weight_quant_only:
if num_bits_act is None or use_symmetric_act is None:
raise ValueError(
'Activation quantization params (`num_bits_act` and'
' `use_symmetric_act`) have to be set when `weight_quant_only` is'
' set to False.'
)
act_quantization_params = ActQuantizationParams(
precision=num_bits_act,
symmetric=use_symmetric_act,
)

transformer_tpls = utils.find_target_tpl(
config, layers.transformers.Transformer
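A hedged call-site sketch for the direct function path; `task_p` is a hypothetical config containing Transformer layers, and the argument names are taken from the docstring above:

from praxis.layers.quantization import quantize

# Weight-and-activation quantization: both new arguments must now be supplied,
# otherwise the ValueError added above is raised.
quantize.set_transformer_quantization(
    task_p,                    # hypothetical config with Transformer layers
    num_bits=8,
    weight_quant_only=False,
    num_bits_act=8,
    use_symmetric_act=False,   # asymmetric activations exercise the new einsum fallback
)

# Weight-only quantization must leave num_bits_act / use_symmetric_act unset.
quantize.set_transformer_quantization(task_p, num_bits=8, weight_quant_only=True)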
2 changes: 0 additions & 2 deletions praxis/layers/transformer_models.py
@@ -511,8 +511,6 @@ def setup(self) -> None:
xformer_params = xformer_params.pipeline_stage
if issubclass(xformer_params.cls, transformers.StackedTransformerRepeated):
xformer_params = xformer_params.block
if not issubclass(xformer_params.cls, transformers.StackedTransformer):
assert False, f'{xformer_params.cls} not supported.'
assert (
xformer_params.model_dims == 0
or xformer_params.model_dims == self.model_dims
2 changes: 1 addition & 1 deletion praxis/pip_package/cloudbuild-release.yaml
@@ -16,7 +16,7 @@ steps:

substitutions:
_PYTHON_VERSION: '3.10'
_RELEASE_VERSION: '1.1.0' # or rX.Y
_RELEASE_VERSION: '1.2.0' # or rX.Y
_IMAGE_NAME: 'praxis_${_RELEASE_VERSION}_${_PYTHON_VERSION}'
_WHEEL_FOLDER: '/tmp/wheels'
options: