Skip to content

Commit

Permalink
Fix einsum _int8_call (keras-team#19570)
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 authored Apr 20, 2024
1 parent 86b08c8 commit 261fa4e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
4 changes: 3 additions & 1 deletion keras/src/layers/core/einsum_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,9 @@ def _int8_build(
self._custom_gradient_equation,
self._kernel_reverse_transpose_axes,
) = _analyze_quantization_info(self.equation, self.input_spec.ndim)
self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
self.inputs_quantizer = quantizers.AbsMaxQuantizer(
axis=self._input_reduced_axes
)
self._kernel = self.add_weight(
name="kernel",
shape=kernel_shape,
Expand Down
22 changes: 22 additions & 0 deletions keras/src/layers/core/einsum_dense_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,28 @@ def test_quantize_int8(self):
backend.standardize_dtype(layer.kernel_scale.dtype), "float32"
)

@parameterized.named_parameters(
    ("btnh,nhd->btd", "btnh,nhd->btd", (None, 8), (1, 2, 2, 4)),
    ("btd,ndh->btnh", "btd,ndh->btnh", (None, 2, 8), (1, 2, 4)),
    ("btd,df->btf", "btd,df->btf", (None, 4), (1, 2, 4)),
)
@pytest.mark.skipif(
    backend.backend() == "numpy",
    reason=f"{backend.backend()} does not support ops.custom_gradient.",
)
def test_quantize_int8_with_specific_equations(
    self, equation, output_shape, input_shape
):
    # Build a float EinsumDense layer for the parameterized equation and
    # record its output on a random input as the reference.
    einsum_layer = layers.EinsumDense(
        equation=equation, output_shape=output_shape
    )
    einsum_layer.build(input_shape)
    inputs = ops.random.uniform(input_shape)
    reference_outputs = einsum_layer(inputs)

    # Quantize the same layer in place and run the identical input again.
    einsum_layer.quantize("int8")
    quantized_outputs = einsum_layer(inputs)

    # Weak correctness check: the mean squared error between the float and
    # int8 outputs must stay below a small threshold.
    squared_error = ops.square(reference_outputs - quantized_outputs)
    self.assertLess(ops.mean(squared_error), 1e-3)

@parameterized.named_parameters(
("int8", "int8"),
("float8", "float8"),
Expand Down

0 comments on commit 261fa4e

Please sign in to comment.