Add support for TensorFlow SparseTensors: Dot merging layer.
Added `tf.SparseTensor` support for ops:
- matmul (sparse support was very partial before this PR)
- squeeze
- expand_dims

Added `tf.SparseTensor` support for the merging layer:
- Dot
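End to end, the ops now behave as follows — a minimal usage sketch, not part of the commit (assumes the TensorFlow backend is selected; values are illustrative):

    import numpy as np
    import tensorflow as tf

    from keras_core import ops

    x = tf.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 3)
    )
    y = np.ones((3, 4), dtype="float32")

    ops.matmul(x, y)            # rank-2 sparse x dense returns a dense tensor
    ops.expand_dims(x, axis=1)  # stays sparse; dense_shape becomes (2, 1, 3)
    ops.squeeze(ops.expand_dims(x, axis=1), axis=1)  # back to dense_shape (2, 3)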
hertschuh committed Sep 21, 2023
1 parent bf04179 commit e33e1cf
Showing 4 changed files with 225 additions and 34 deletions.
102 changes: 92 additions & 10 deletions keras_core/backend/tensorflow/numpy.py
@@ -1,9 +1,11 @@
import builtins
import functools
import math
import warnings

import tensorflow as tf
from tensorflow.experimental import numpy as tfnp
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

from keras_core.backend import config
from keras_core.backend.tensorflow.core import convert_to_tensor
@@ -52,22 +54,82 @@ def subtract(x1, x2):


def matmul(x1, x2):
    def with_combined_batch_dimensions(a, b, fn_3d):
        batch_shape = (
            b.shape[:-2] if isinstance(b, tf.SparseTensor) else a.shape[:-2]
        )
        batch_size = math.prod(batch_shape)
        a_3d = reshape(a, [batch_size] + a.shape[-2:])
        b_3d = reshape(b, [batch_size] + b.shape[-2:])
        result = fn_3d(a_3d, b_3d)
        return reshape(result, batch_shape + result.shape[1:])

    def sparse_sparse_matmul(a, b):
        dtype = a.values.dtype
        # Convert SparseTensors to CSR SparseMatrix.
        a_csr = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
            a.indices, a.values, a.dense_shape
        )
        b_csr = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
            b.indices, b.values, b.dense_shape
        )
        # Compute the CSR SparseMatrix matrix multiplication.
        result_csr = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul(
            a_csr, b_csr, dtype
        )
        # Convert the CSR SparseMatrix back to a SparseTensor.
        res = sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
            result_csr, dtype
        )
        return tf.SparseTensor(res.indices, res.values, res.dense_shape)

    def embedding_lookup_sparse_dense_matmul(a, b):
        # We need at least one id per row for embedding_lookup_sparse,
        # otherwise there will be missing rows in the output.
        a, _ = tf.sparse.fill_empty_rows(a, 0)
        # We need to split `a` into separate ids and weights tensors. The
        # ids should be the column indices of `a` and the values of the
        # weights can continue to be the actual `a`. The column arrangement
        # of ids and weights does not matter as we sum over columns. See
        # details in the documentation for sparse_ops.sparse_tensor_dense_matmul.
        ids = tf.SparseTensor(
            indices=a.indices,
            values=a.indices[:, 1],
            dense_shape=a.dense_shape,
        )
        return tf.nn.embedding_lookup_sparse(b, ids, a, combiner="sum")

    # Either a or b is sparse.
    def sparse_dense_matmul_3d(a, b):
        return tf.map_fn(
            lambda x: tf.sparse.sparse_dense_matmul(x[0], x[1]),
            elems=(a, b),
            fn_output_signature=a.dtype,
        )

    x1_sparse = isinstance(x1, tf.SparseTensor)
    x2_sparse = isinstance(x2, tf.SparseTensor)
    if x1_sparse and x2_sparse:
        if x1.shape.rank <= 3:
            return sparse_sparse_matmul(x1, x2)
        else:
            return with_combined_batch_dimensions(x1, x2, sparse_sparse_matmul)
    elif x1_sparse or x2_sparse:
        # Sparse * dense or dense * sparse.
        sparse_rank = x1.shape.rank if x1_sparse else x2.shape.rank

        # Special case: embedding_lookup_sparse for sparse * dense, rank 2.
        if x1_sparse and sparse_rank == 2:
            return embedding_lookup_sparse_dense_matmul(x1, x2)
        elif sparse_rank == 2:
            return tf.sparse.sparse_dense_matmul(x1, x2)
        elif sparse_rank == 3:
            return sparse_dense_matmul_3d(x1, x2)
        else:
            return with_combined_batch_dimensions(
                x1, x2, sparse_dense_matmul_3d
            )

    return tfnp.matmul(x1, x2)
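A quick numerical check of the dispatch above — a sketch, not part of the commit (assumes the TensorFlow backend; values are illustrative, and `matmul` is the module-level function just defined):

    import numpy as np
    import tensorflow as tf

    a = tf.sparse.from_dense(np.array([[1.0, 0.0], [0.0, 2.0]]))
    b = tf.sparse.from_dense(np.array([[0.0, 3.0], [4.0, 0.0]]))

    out = matmul(a, b)  # sparse x sparse takes the CSR route and stays sparse
    assert isinstance(out, tf.SparseTensor)
    np.testing.assert_allclose(
        tf.sparse.to_dense(out).numpy(), [[0.0, 3.0], [8.0, 0.0]]
    )

    dense_out = matmul(a, np.ones((2, 2)))  # a dense operand densifies the result
    assert not isinstance(dense_out, tf.SparseTensor)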


@@ -354,6 +416,8 @@ def exp(x):


def expand_dims(x, axis):
if isinstance(x, tf.SparseTensor):
return tf.sparse.expand_dims(x, axis)
return tfnp.expand_dims(x, axis)
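The sparse branch simply defers to tf.sparse.expand_dims, which splices the unit axis into both the indices and the dense_shape. A tiny sketch (values are illustrative):

    import tensorflow as tf

    st = tf.SparseTensor(indices=[[0, 1]], values=[5.0], dense_shape=(2, 3))
    out = expand_dims(st, axis=1)
    # out is a tf.SparseTensor with dense_shape (2, 1, 3); the nonzero entry
    # moves from index [0, 1] to [0, 0, 1].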


@@ -764,6 +828,24 @@ def sqrt(x):


def squeeze(x, axis=None):
if isinstance(x, tf.SparseTensor):
new_shape = list(x.shape)
gather_indices = list(range(len(new_shape)))
if axis is None:
for i in range(len(new_shape) - 1, -1, -1):
if new_shape[i] == 1:
del new_shape[i]
del gather_indices[i]
else:
if new_shape[axis] != 1:
raise ValueError(
f"Cannot squeeze axis {axis}, because the "
"dimension is not 1."
)
del new_shape[axis]
del gather_indices[axis]
new_indices = tf.gather(x.indices, gather_indices, axis=1)
return tf.SparseTensor(new_indices, x.values, tuple(new_shape))
return tfnp.squeeze(x, axis=axis)
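Worked through on a small input, the sparse branch drops each unit axis from the shape and removes the matching column from the indices with tf.gather — a sketch (values are illustrative):

    import tensorflow as tf

    st = tf.SparseTensor(
        indices=[[0, 0, 0], [1, 0, 2]], values=[1.0, 2.0], dense_shape=(2, 1, 3)
    )
    out = squeeze(st, axis=1)
    # gather_indices is [0, 2], so the index rows [[0, 0, 0], [1, 0, 2]]
    # become [[0, 0], [1, 2]] and out has dense_shape (2, 3).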


3 changes: 0 additions & 3 deletions keras_core/layers/merging/merging_test.py
@@ -237,9 +237,6 @@ def test_sparse(
):
import tensorflow as tf

-        if layer_class == layers.Dot:
-            pytest.skip("Dot layer does not support sparse tensors.")
-
self.run_layer_test(
layer_class,
init_kwargs=init_kwargs,
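With the skip gone, Dot is exercised by the shared sparse test like the other merging layers. Roughly what that now covers, as a standalone sketch (assumes the TensorFlow backend; shapes and values are illustrative):

    import tensorflow as tf
    from keras_core import layers

    x1 = tf.sparse.from_dense(tf.constant([[1.0, 0.0, 2.0]]))
    x2 = tf.sparse.from_dense(tf.constant([[0.0, 3.0, 4.0]]))
    out = layers.Dot(axes=1)([x1, x2])  # batched inner product; dense view is [[8.]]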
11 changes: 7 additions & 4 deletions keras_core/ops/numpy.py
Expand Up @@ -2395,7 +2395,7 @@ def compute_output_spec(self, x):
else:
axis = self.axis
output_shape = x_shape[:axis] + [1] + x_shape[axis:]
return KerasTensor(output_shape, dtype=x.dtype, sparse=x.sparse)


@keras_core_export(
@@ -3368,7 +3368,10 @@ def compute_output_spec(self, x1, x2):
del output_shape[-2]
if len(x2.shape) == 1:
del output_shape[-1]
x1_sparse = getattr(x1, "sparse", True)
x2_sparse = getattr(x2, "sparse", True)
output_sparse = x1_sparse and x2_sparse
return KerasTensor(output_shape, dtype=x1.dtype, sparse=output_sparse)


@keras_core_export(["keras_core.ops.matmul", "keras_core.ops.numpy.matmul"])
@@ -5309,15 +5312,15 @@ def compute_output_spec(self, x):
input_shape = list(x.shape)
if self.axis is None:
output_shape = list(filter((1).__ne__, input_shape))
return KerasTensor(output_shape, dtype=x.dtype, sparse=x.sparse)
else:
if input_shape[self.axis] != 1:
raise ValueError(
f"Cannot squeeze axis {self.axis}, because the dimension "
"is not 1."
)
del input_shape[self.axis]
return KerasTensor(input_shape, dtype=x.dtype, sparse=x.sparse)


@keras_core_export(["keras_core.ops.squeeze", "keras_core.ops.numpy.squeeze"])
143 changes: 126 additions & 17 deletions keras_core/ops/numpy_test.py
@@ -464,6 +464,23 @@ def test_matmul(self):
y = KerasTensor([2, 3, 4])
knp.matmul(x, y)

def test_matmul_sparse(self):
x = KerasTensor((2, 3), sparse=True)
y = KerasTensor((3, 2))
result = knp.matmul(x, y)
self.assertEqual(result.shape, (2, 2))

x = KerasTensor((2, 3))
y = KerasTensor((3, 2), sparse=True)
result = knp.matmul(x, y)
self.assertEqual(result.shape, (2, 2))

x = KerasTensor((2, 3), sparse=True)
y = KerasTensor((3, 2), sparse=True)
result = knp.matmul(x, y)
self.assertEqual(result.shape, (2, 2))
self.assertTrue(result.sparse)

def test_power(self):
x = KerasTensor([2, 3])
y = KerasTensor([2, 3])
@@ -1322,6 +1339,13 @@ def test_reshape(self):
self.assertEqual(knp.reshape(x, (3, 2)).shape, (3, 2))
self.assertEqual(knp.reshape(x, (3, -1)).shape, (3, None))

def test_reshape_sparse(self):
x = KerasTensor([None, 3], sparse=True)
self.assertTrue(knp.reshape(x, (3, 2)).sparse)
self.assertEqual(knp.reshape(x, (3, 2)).shape, (3, 2))
self.assertTrue(knp.reshape(x, (3, -1)).sparse)
self.assertEqual(knp.reshape(x, (3, -1)).shape, (3, None))

def test_roll(self):
x = KerasTensor([None, 3])
self.assertEqual(knp.roll(x, 1).shape, (None, 3))
@@ -1475,6 +1499,27 @@ def test_squeeze(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.squeeze(x).shape, (2, 3))

x = KerasTensor([2, 1, 3])
self.assertEqual(knp.squeeze(x).shape, (2, 3))
self.assertEqual(knp.squeeze(x, axis=1).shape, (2, 3))
self.assertEqual(knp.squeeze(x, axis=-2).shape, (2, 3))

with self.assertRaises(ValueError):
knp.squeeze(x, axis=0)

def test_squeeze_sparse(self):
x = KerasTensor([2, 3], sparse=True)
self.assertTrue(knp.squeeze(x).sparse)
self.assertEqual(knp.squeeze(x).shape, (2, 3))

x = KerasTensor([2, 1, 3], sparse=True)
self.assertTrue(knp.squeeze(x).sparse)
self.assertEqual(knp.squeeze(x).shape, (2, 3))
self.assertTrue(knp.squeeze(x, axis=1).sparse)
self.assertEqual(knp.squeeze(x, axis=1).shape, (2, 3))
self.assertTrue(knp.squeeze(x, axis=-2).sparse)
self.assertEqual(knp.squeeze(x, axis=-2).shape, (2, 3))

def test_transpose(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.transpose(x).shape, (3, 2))
@@ -1641,6 +1686,15 @@ def test_expand_dims(self):
self.assertEqual(knp.expand_dims(x, 1).shape, (2, 1, 3, 4))
self.assertEqual(knp.expand_dims(x, -2).shape, (2, 3, 1, 4))

def test_expand_dims_sparse(self):
x = KerasTensor([2, 3, 4], sparse=True)
self.assertTrue(knp.expand_dims(x, 0).sparse)
self.assertEqual(knp.expand_dims(x, 0).shape, (1, 2, 3, 4))
self.assertTrue(knp.expand_dims(x, 1).sparse)
self.assertEqual(knp.expand_dims(x, 1).shape, (2, 1, 3, 4))
self.assertTrue(knp.expand_dims(x, -2).sparse)
self.assertEqual(knp.expand_dims(x, -2).shape, (2, 3, 1, 4))

def test_expm1(self):
x = KerasTensor([2, 3])
self.assertEqual(knp.expm1(x).shape, (2, 3))
@@ -1957,30 +2011,45 @@ def test_matmul(self):
self.assertAllClose(knp.Matmul()(x, y), np.matmul(x, y))
self.assertAllClose(knp.Matmul()(x, z), np.matmul(x, z))

@parameterized.product(
(
{"x_shape": (5, 3), "y_shape": (3, 4)},
{"x_shape": (2, 5, 3), "y_shape": (2, 3, 4)},
{"x_shape": (2, 2, 5, 3), "y_shape": (2, 2, 3, 4)},
),
dtype=["float16", "float32", "float64", "int32"],
x_sparse=[False, True],
y_sparse=[False, True],
)
@pytest.mark.skipif(
not backend.SUPPORTS_SPARSE_TENSORS,
reason="Backend does not support sparse tensors.",
)
def test_matmul_sparse(self, dtype, x_shape, y_shape, x_sparse, y_sparse):
import tensorflow as tf

if x_sparse and y_sparse and dtype in ("float16", "int32"):
pytest.skip(f"Sparse sparse matmul unsupported for {dtype}")

rng = np.random.default_rng(0)
if x_sparse:
x = 4 * rng.standard_normal(x_shape)
x = tf.sparse.from_dense(tf.cast(tf.nn.dropout(x, 0.7), dtype))
x_np = tf.sparse.to_dense(x).numpy()
else:
x = x_np = (4 * rng.standard_normal(x_shape)).astype(dtype)
if y_sparse:
y = 4 * rng.standard_normal(y_shape)
y = tf.sparse.from_dense(tf.cast(tf.nn.dropout(y, 0.7), dtype))
y_np = tf.sparse.to_dense(y).numpy()
else:
y = y_np = (4 * rng.standard_normal(y_shape)).astype(dtype)

atol = 0.1 if dtype == "float16" else 1e-5
self.assertAllClose(knp.matmul(x, y), np.matmul(x_np, y_np), atol=atol)
if x_sparse and y_sparse:
self.assertIsInstance(knp.matmul(x, y), tf.SparseTensor)

def test_power(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
@@ -2722,13 +2791,32 @@ def test_absolute(self):
self.assertAllClose(knp.Absolute()(x), np.absolute(x))

def test_squeeze(self):
x = np.ones([1, 3, 1, 5])
self.assertAllClose(knp.squeeze(x), np.squeeze(x))
self.assertAllClose(knp.squeeze(x, axis=0), np.squeeze(x, axis=0))

self.assertAllClose(knp.Squeeze()(x), np.squeeze(x))
self.assertAllClose(knp.Squeeze(axis=0)(x), np.squeeze(x, axis=0))

@pytest.mark.skipif(
not backend.SUPPORTS_SPARSE_TENSORS,
reason="Backend does not support sparse tensors.",
)
def test_squeeze_sparse(self):
import tensorflow as tf

x = tf.SparseTensor(
indices=[[0, 0, 0, 0], [0, 2, 0, 4]],
values=[1, 2],
dense_shape=(1, 3, 1, 5),
)
x_np = tf.sparse.to_dense(x).numpy()
self.assertAllClose(knp.squeeze(x), np.squeeze(x_np))
self.assertAllClose(knp.squeeze(x, axis=0), np.squeeze(x_np, axis=0))

self.assertAllClose(knp.Squeeze()(x), np.squeeze(x_np))
self.assertAllClose(knp.Squeeze(axis=0)(x), np.squeeze(x_np, axis=0))

def test_transpose(self):
x = np.ones([1, 2, 3, 4, 5])
self.assertAllClose(knp.transpose(x), np.transpose(x))
@@ -3189,6 +3277,27 @@ def test_expand_dims(self):
self.assertAllClose(knp.ExpandDims(1)(x), np.expand_dims(x, 1))
self.assertAllClose(knp.ExpandDims(-2)(x), np.expand_dims(x, -2))

@pytest.mark.skipif(
not backend.SUPPORTS_SPARSE_TENSORS,
reason="Backend does not support sparse tensors.",
)
def test_expand_dims_sparse(self):
import tensorflow as tf

x = tf.SparseTensor(
indices=[[0, 0], [1, 2]],
values=[1, 2],
dense_shape=(2, 3),
)
x_np = tf.sparse.to_dense(x).numpy()
self.assertAllClose(knp.expand_dims(x, 0), np.expand_dims(x_np, 0))
self.assertAllClose(knp.expand_dims(x, 1), np.expand_dims(x_np, 1))
self.assertAllClose(knp.expand_dims(x, -2), np.expand_dims(x_np, -2))

self.assertAllClose(knp.ExpandDims(0)(x), np.expand_dims(x_np, 0))
self.assertAllClose(knp.ExpandDims(1)(x), np.expand_dims(x_np, 1))
self.assertAllClose(knp.ExpandDims(-2)(x), np.expand_dims(x_np, -2))

def test_expm1(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(knp.expm1(x), np.expm1(x))
