diff --git a/recml/core/data/iterator.py b/recml/core/data/iterator.py index 5c389b1..23ac5cf 100644 --- a/recml/core/data/iterator.py +++ b/recml/core/data/iterator.py @@ -57,16 +57,15 @@ def __next__(self) -> clu_data.Element: if self._prefetched_batch is not None: batch = self._prefetched_batch self._prefetched_batch = None - return batch - - batch = next(self._iterator) - if self._postprocessor is not None: - batch = self._postprocessor(batch) + else: + batch = next(self._iterator) + if self._postprocessor is not None: + batch = self._postprocessor(batch) def _maybe_to_numpy( - x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor, + x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray, ) -> np.ndarray | tf.SparseTensor | tf.RaggedTensor: - if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)): + if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)): return x if hasattr(x, "_numpy"): numpy = x._numpy() # pylint: disable=protected-access @@ -83,13 +82,16 @@ def _maybe_to_numpy( @property def element_spec(self) -> clu_data.ElementSpec: if self._element_spec is not None: - batch = self._element_spec - else: - batch = self.__next__() - self._prefetched_batch = batch + return self._element_spec + + batch = next(self._iterator) + if self._postprocessor is not None: + batch = self._postprocessor(batch) + + self._prefetched_batch = batch def _to_element_spec( - x: np.ndarray | tf.SparseTensor | tf.RaggedTensor, + x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray, ) -> clu_data.ArraySpec: if isinstance(x, tf.SparseTensor): return clu_data.ArraySpec( @@ -101,6 +103,10 @@ def _to_element_spec( dtype=x.dtype.as_numpy_dtype, # pylint: disable=attribute-error shape=tuple(x.shape.as_list()), # pylint: disable=attribute-error ) + if isinstance(x, tf.Tensor): + return clu_data.ArraySpec( + dtype=x.dtype.as_numpy_dtype, shape=tuple(x.shape.as_list()) + ) return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape)) element_spec = tf.nest.map_structure(_to_element_spec, batch) diff --git a/recml/core/ops/embedding_ops.py b/recml/core/ops/embedding_ops.py new file mode 100644 index 0000000..b8abc84 --- /dev/null +++ b/recml/core/ops/embedding_ops.py @@ -0,0 +1,114 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
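As context for the `iterator.py` change above: the new `tf.Tensor` branch in `_to_element_spec` lets `element_spec` be derived from a raw, not-yet-postprocessed batch. A minimal standalone sketch of that mapping, using a hypothetical `batch` dict that is not part of this change:

```python
import tensorflow as tf
from clu import data as clu_data  # same clu_data alias used in iterator.py

batch = {"ids": tf.zeros([8, 4], dtype=tf.int64)}  # hypothetical batch

def to_spec(x: tf.Tensor) -> clu_data.ArraySpec:
  # Mirrors the tf.Tensor branch added to _to_element_spec: record the numpy
  # dtype and the static shape of the tensor.
  return clu_data.ArraySpec(
      dtype=x.dtype.as_numpy_dtype, shape=tuple(x.shape.as_list())
  )

element_spec = tf.nest.map_structure(to_spec, batch)
# e.g. {"ids": ArraySpec(dtype=int64, shape=(8, 4))}
```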
+"""Embedding lookup ops.""" + +from collections.abc import Mapping, Sequence +import dataclasses +import functools +from typing import Any + +from etils import epy +import jax +from jax.experimental import shard_map + +with epy.lazy_imports(): + # pylint: disable=g-import-not-at-top + from jax_tpu_embedding.sparsecore.lib.nn import embedding + # pylint: enable=g-import-not-at-top + + +@dataclasses.dataclass +class SparsecoreParams: + """Embedding parameters.""" + + feature_specs: embedding.Nested[Any] # Nested[FeatureSpec] + abstract_mesh: jax.sharding.AbstractMesh + data_axes: Sequence[str | None] + embedding_axes: Sequence[str | None] + sharding_strategy: str + + +@functools.partial(jax.custom_vjp, nondiff_argnums=(0,)) +def sparsecore_lookup( + sparsecore_params: SparsecoreParams, + tables: Mapping[str, tuple[jax.Array, ...]], + csr_inputs: tuple[jax.Array, ...], +): + return shard_map.shard_map( + functools.partial( + embedding.tpu_sparse_dense_matmul, + global_device_count=sparsecore_params.abstract_mesh.size, + feature_specs=sparsecore_params.feature_specs, + sharding_strategy=sparsecore_params.sharding_strategy, + ), + mesh=sparsecore_params.abstract_mesh, + in_specs=( + jax.sharding.PartitionSpec(*sparsecore_params.data_axes), + jax.sharding.PartitionSpec(*sparsecore_params.data_axes), + jax.sharding.PartitionSpec(*sparsecore_params.data_axes), + jax.sharding.PartitionSpec(*sparsecore_params.data_axes), + jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes), + ), + out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes), + check_rep=False, + )(*csr_inputs, tables) + + +def _emb_lookup_fwd( + sparsecore_params: SparsecoreParams, + tables: Mapping[str, tuple[jax.Array, ...]], + csr_inputs: tuple[jax.Array, ...], +): + out = sparsecore_lookup(sparsecore_params, tables, csr_inputs) + return out, (tables, csr_inputs) + + +def _emb_lookup_bwd( + sparsecore_params: SparsecoreParams, + res: tuple[Mapping[str, tuple[jax.Array, ...]], tuple[jax.Array, ...]], + gradients: embedding.Nested[jax.Array], +) -> tuple[embedding.Nested[jax.Array], None]: + """Backward pass for embedding lookup.""" + (tables, csr_inputs) = res + + emb_table_grads = shard_map.shard_map( + functools.partial( + embedding.tpu_sparse_dense_matmul_grad, + feature_specs=sparsecore_params.feature_specs, + sharding_strategy=sparsecore_params.sharding_strategy, + ), + mesh=sparsecore_params.abstract_mesh, + in_specs=( + jax.sharding.PartitionSpec(*sparsecore_params.data_axes), + jax.sharding.PartitionSpec(*sparsecore_params.data_axes), + jax.sharding.PartitionSpec(*sparsecore_params.data_axes), + jax.sharding.PartitionSpec(*sparsecore_params.data_axes), + jax.sharding.PartitionSpec(*sparsecore_params.data_axes), + jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes), + ), + out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes), + check_rep=False, + )(gradients, *csr_inputs, tables) + + # `tpu_sparse_dense_matmul_grad` returns a general mapping (usually a dict). + # It may not be the same type as the embedding table (e.g. FrozenDict). + # Here we use flatten / unflatten to ensure the types are the same. 
+ emb_table_grads = jax.tree.unflatten( + jax.tree.structure(tables), jax.tree.leaves(emb_table_grads) + ) + + return emb_table_grads, None + + +sparsecore_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd) diff --git a/recml/core/training/jax.py b/recml/core/training/jax.py index 14db581..3126f6f 100644 --- a/recml/core/training/jax.py +++ b/recml/core/training/jax.py @@ -26,6 +26,7 @@ from clu import periodic_actions import clu.metrics as clu_metrics from flax import struct +import flax.linen as nn import jax import jax.numpy as jnp import keras @@ -67,43 +68,85 @@ class JaxState(struct.PyTreeNode, Generic[MetaT]): step: A counter of the current step of the job. It starts at zero and it is incremented by 1 on a call to `state.update(...)`. This should be a Jax array and not a Python integer. - apply: A function that can be used to apply the forward pass of the model. - For Flax models this is usually set to `model.apply`. params: A pytree of trainable variables that will be updated by `tx` and used in `apply`. tx: An optax gradient transformation that will be used to update the parameters contained in `params` on a call to `state.update(...)`. opt_state: The optimizer state for `tx`. This is usually created by calling `tx.init(params)`. + _apply: An optional function that can be used to apply the forward pass of + the model. For Flax models this is usually set to `model.apply` while for + Haiku models this is usually set to `transform.apply`. + _model: An optional reference to a stateless Flax model for convenience. mutable: A pytree of mutable variables that are used by `apply`. meta: Arbitrary metadata that is recorded on the state. This can be useful for tracking additional references in the state. """ step: jax.Array - apply: Callable[..., Any] = struct.field(pytree_node=False) params: PyTree = struct.field(pytree_node=True) tx: optax.GradientTransformation = struct.field(pytree_node=False) opt_state: optax.OptState = struct.field(pytree_node=True) mutable: PyTree = struct.field(pytree_node=True, default_factory=dict) meta: MetaT = struct.field(pytree_node=False, default_factory=dict) + _apply: Callable[..., Any] | None = struct.field( + pytree_node=False, default_factory=None + ) + _model: nn.Module | None = struct.field(pytree_node=False, default=None) + + @property + def model(self) -> nn.Module: + """Returns a reference to the model used to create the state.""" + if self._model is None: + raise ValueError("No Flax `model` is set on the state.") + return self._model + + def apply(self, *args, **kwargs) -> Any: + """Applies the forward pass of the model.""" + if self._apply is None: + raise ValueError("No `apply` function is set on the state.") + return self._apply(*args, **kwargs) @classmethod def create( cls, *, - apply: Callable[..., Any], + apply: Callable[..., Any] | None = None, + model: nn.Module | None = None, params: PyTree, tx: optax.GradientTransformation, **kwargs, ) -> Self: - """Creates a new instance from a Jax apply function and Optax optimizer.""" + """Creates a new instance from a Jax model / apply fn and Optax optimizer. + + Args: + apply: A function that can be used to apply the forward pass of the model. + For Flax models this is usually set to `model.apply`. This cannot be set + along with `model`. + model: A reference to a stateless Flax model. This cannot be set along + with `apply`. When set the `apply` attribute of the state will be set to + `model.apply`. + params: A pytree of trainable variables that will be updated by `tx` and + used in `apply`. 
+ tx: An optax gradient transformation that will be used to update the + parameters contained in `params` on a call to `state.update(...)`. + **kwargs: Other updates to set on the new state. + + Returns: + An new instance of the state. + """ + if apply is not None and model is not None: + raise ValueError("Only one of `apply` or `model` can be provided.") + elif model is not None: + apply = model.apply + return cls( step=jnp.zeros([], dtype=jnp.int32), - apply=apply, params=params, tx=tx, opt_state=tx.init(params), + _apply=apply, + _model=model, **kwargs, ) diff --git a/recml/core/training/optax_factory.py b/recml/core/training/optax_factory.py index 0775c26..bc3fabc 100644 --- a/recml/core/training/optax_factory.py +++ b/recml/core/training/optax_factory.py @@ -29,10 +29,10 @@ def _default_weight_decay_mask(params: optax.Params) -> optax.Params: def _regex_mask(regex: str) -> Callable[[optax.Params], optax.Params]: - """Returns a weight decay mask that applies to parameters matching a regex.""" + """Returns a mask that applies to parameters matching a regex.""" def _matches_regex(path: tuple[str, ...], _: Any) -> bool: - key = "/".join([jax.tree_util.keystr((k,), simple=True) for k in path]) + key = '/'.join([jax.tree_util.keystr((k,), simple=True) for k in path]) return re.fullmatch(regex, key) is not None def _mask(params: optax.Params) -> optax.Params: @@ -54,6 +54,8 @@ class OptimizerFactory(types.Factory[optax.GradientTransformation]): magnitude of the gradients during optimization. Defaults to None. weight_decay_mask: The weight decay mask to use when applying weight decay. Defaults applying weight decay to all non-1D parameters. + freeze_mask: Optional mask to freeze parameters during optimization. + Defaults to None. Example usage: @@ -78,6 +80,7 @@ class OptimizerFactory(types.Factory[optax.GradientTransformation]): weight_decay_mask: str | Callable[[optax.Params], optax.Params] = ( _default_weight_decay_mask ) + freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None def make(self) -> optax.GradientTransformation: if self.grad_clip_norm is not None: @@ -99,13 +102,30 @@ def make(self) -> optax.GradientTransformation: else: weight_decay = optax.identity() - return optax.chain(*[ + tx = optax.chain(*[ apply_clipping, self.scaling, weight_decay, lr_scaling, ]) + if self.freeze_mask is not None: + if isinstance(self.freeze_mask, str): + mask = _regex_mask(self.freeze_mask) + else: + mask = self.freeze_mask + + def _param_labels(params: optax.Params) -> optax.Params: + return jax.tree.map( + lambda p: 'frozen' if mask(p) else 'trainable', params + ) + + tx = optax.multi_transform( + transforms={'trainable': tx, 'frozen': optax.set_to_zero()}, + param_labels=_param_labels, + ) + return tx + class AdamFactory(types.Factory[optax.GradientTransformation]): """Adam optimizer factory. @@ -121,6 +141,8 @@ class AdamFactory(types.Factory[optax.GradientTransformation]): magnitude of the gradients during optimization. Defaults to None. weight_decay_mask: The weight decay mask to use when applying weight decay. Defaults applying weight decay to all non-1D parameters. + freeze_mask: Optional mask to freeze parameters during optimization. + Defaults to None. 
Example usage: ``` @@ -143,6 +165,7 @@ class AdamFactory(types.Factory[optax.GradientTransformation]): weight_decay_mask: str | Callable[[optax.Params], optax.Params] = ( _default_weight_decay_mask ) + freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None def make(self) -> optax.GradientTransformation: return OptimizerFactory( @@ -164,6 +187,8 @@ class AdagradFactory(types.Factory[optax.GradientTransformation]): eps: The epsilon coefficient for the Adagrad optimizer. Defaults to 1e-7. grad_clip_norm: Optional gradient clipping norm to limit the maximum magnitude of the gradients during optimization. Defaults to None. + freeze_mask: Optional mask to freeze parameters during optimization. + Defaults to None. Example usage: ``` @@ -175,6 +200,7 @@ class AdagradFactory(types.Factory[optax.GradientTransformation]): initial_accumulator_value: float = 0.1 eps: float = 1e-7 grad_clip_norm: float | None = None + freeze_mask: str | Callable[[optax.Params], optax.Params] | None = None def make(self) -> optax.GradientTransformation: return OptimizerFactory( diff --git a/recml/core/training/partitioning.py b/recml/core/training/partitioning.py index 1fb7e1e..2eda2e0 100644 --- a/recml/core/training/partitioning.py +++ b/recml/core/training/partitioning.py @@ -113,7 +113,9 @@ def partition_init( def _wrapped_init(batch: PyTree) -> State: with jax.sharding.use_mesh(self.mesh): - return init_fn(batch) + state = init_fn(batch) + state = _maybe_unbox_state(state) + return state return _wrapped_init @@ -235,20 +237,11 @@ def partition_init( ) compiled_init_fn = jax.jit(init_fn, out_shardings=state_sharding) - def _maybe_unbox(x: Any) -> Any: - if isinstance(x, nn.spmd.LogicallyPartitioned): - return x.unbox() - return x - def _init(batch: PyTree) -> State: with self.mesh_context_manager(self.mesh): state = compiled_init_fn(batch) - unboxed_state = jax.tree.map( - _maybe_unbox, - state, - is_leaf=lambda k: isinstance(k, nn.spmd.LogicallyPartitioned), - ) - return unboxed_state + state = _maybe_unbox_state(state) + return state self.abstract_batch = abstract_batch self.abstract_state = abstract_state @@ -288,3 +281,16 @@ def _step(batch: PyTree, state: State) -> Any: return step_fn(batch, state) return _step + + +def _maybe_unbox_state(x: Any) -> Any: + def _maybe_unbox(x: Any) -> Any: + if isinstance(x, nn.Partitioned): + return x.unbox() + return x + + return jax.tree.map( + _maybe_unbox, + x, + is_leaf=lambda k: isinstance(k, nn.Partitioned), + ) diff --git a/recml/core/training/partitioning_test.py b/recml/core/training/partitioning_test.py index 1103473..55d286f 100644 --- a/recml/core/training/partitioning_test.py +++ b/recml/core/training/partitioning_test.py @@ -136,12 +136,10 @@ def _init(batch: jax.Array) -> jax.Array: state = partitioner.partition_init(_init, abstract_batch=sharded_inputs)( sharded_inputs ) - self.assertIsInstance(state, nn.Partitioned) - unboxed_state = state.unbox() - self.assertIsInstance(unboxed_state, jax.Array) - self.assertSequenceEqual(unboxed_state.shape, (128, 16)) - self.assertEqual(unboxed_state.sharding, partitioner.state_sharding) + self.assertIsInstance(state, jax.Array) + self.assertSequenceEqual(state.shape, (128, 16)) + self.assertEqual(state.sharding, partitioner.state_sharding) self.assertEqual( partitioner.state_sharding, jax.sharding.NamedSharding( diff --git a/recml/examples/dlrm_experiment.py b/recml/examples/dlrm_experiment.py new file mode 100644 index 0000000..d11bc50 --- /dev/null +++ b/recml/examples/dlrm_experiment.py @@ -0,0 +1,383 @@ 
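Before the new DLRM example file, a note on the `freeze_mask` option added to `OptimizerFactory` above: it routes frozen parameters to `optax.set_to_zero()` through `optax.multi_transform`. A self-contained sketch of that mechanism with made-up parameter names and a hand-written label pytree (not the factory's actual label function):

```python
import jax
import jax.numpy as jnp
import optax

params = {'dense': jnp.ones((2, 2)), 'embedding': jnp.ones((4, 2))}
labels = {'dense': 'trainable', 'embedding': 'frozen'}

tx = optax.multi_transform(
    transforms={'trainable': optax.sgd(0.1), 'frozen': optax.set_to_zero()},
    param_labels=labels,
)
state = tx.init(params)
grads = jax.tree.map(jnp.ones_like, params)
updates, state = tx.update(grads, state, params)
# updates['embedding'] is all zeros, so the frozen parameters never change.
```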
+# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DLRM experiment.""" + +from __future__ import annotations + +from collections.abc import Iterator, Mapping, Sequence +import dataclasses +from typing import Generic, Literal, TypeVar + +import fiddle as fdl +import flax.linen as nn +import jax +import jax.numpy as jnp +import jaxtyping as jt +import numpy as np +import optax +from recml import mlrx +from recml.layers.linen import sparsecore +import tensorflow as tf + + +@dataclasses.dataclass +class Feature: + name: str + + +FeatureT = TypeVar('FeatureT', bound=Feature) + + +@dataclasses.dataclass +class DenseFeature(Feature): + """Dense feature.""" + + +@dataclasses.dataclass +class SparseFeature(Feature): + """Sparse feature.""" + + vocab_size: int + embedding_dim: int + max_sequence_length: int | None = None + combiner: Literal['mean', 'sum', 'sqrtn'] = 'mean' + sparsity: float = 0.8 + + +@dataclasses.dataclass +class FeatureSet(Generic[FeatureT]): + """A collection of features.""" + + features: Sequence[FeatureT] + + def __post_init__(self): + feature_names = [f.name for f in self.features] + if len(feature_names) != len(set(feature_names)): + raise ValueError( + f'Feature names must be unique. Got names: {feature_names}.' + ) + + def dense_features(self) -> FeatureSet[DenseFeature]: + return FeatureSet([f for f in self if isinstance(f, DenseFeature)]) + + def sparse_features(self) -> FeatureSet[SparseFeature]: + return FeatureSet([f for f in self if isinstance(f, SparseFeature)]) + + def __iter__(self) -> Iterator[FeatureT]: + return iter(self.features) + + def __or__(self, other: FeatureSet[Feature]) -> FeatureSet[Feature]: + return FeatureSet([*self.features, *other.features]) + + +class DLRMModel(nn.Module): + """DLRM DCN v2 model.""" + + features: FeatureSet + embedding_optimizer: sparsecore.embedding_spec.OptimizerSpec + bottom_mlp_dims: Sequence[int] + top_mlp_dims: Sequence[int] + dcn_layers: int + dcn_inner_dim: int + + # We need to track the embedder on the Flax module to ensure it is not + # re-created on cloning. It is not possible to create an embedder inside + # setup() because it is called lazily at compile time. The embedder needs + # to be created before `model.init` so we can use it to create a preprocessor. + # A simpler pattern that works is passing `embedder` directly to the module. 
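+  # For example, that simpler pattern would look roughly like this
+  # (hypothetical, not what this module does):
+  #   embedder = sparsecore.SparsecoreEmbedder(specs=..., optimizer=...)
+  #   model = DLRMModel(..., embedder=embedder)
+  # Instead, the `embedder` property below builds the embedder lazily on first
+  # access and caches it on the module with object.__setattr__.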
+ _embedder: sparsecore.SparsecoreEmbedder | None = None + + @property + def embedder(self) -> sparsecore.SparsecoreEmbedder: + if self._embedder is not None: + return self._embedder + + embedder = sparsecore.SparsecoreEmbedder( + specs={ + f.name: sparsecore.EmbeddingSpec( + input_dim=f.vocab_size, + embedding_dim=f.embedding_dim, + max_sequence_length=f.max_sequence_length, + combiner=f.combiner, + ) + for f in self.features.sparse_features() + }, + optimizer=self.embedding_optimizer, + ) + object.__setattr__(self, '_embedder', embedder) + return embedder + + def bottom_mlp(self, inputs: Mapping[str, jt.Array]) -> jt.Array: + x = jnp.concatenate( + [inputs[f.name] for f in self.features.dense_features()], axis=-1 + ) + + for dim in self.bottom_mlp_dims: + x = nn.Dense(dim)(x) + x = nn.relu(x) + return x + + def top_mlp(self, x: jt.Array) -> jt.Array: + for dim in self.top_mlp_dims[:-1]: + x = nn.Dense(dim)(x) + x = nn.relu(x) + + x = nn.Dense(self.top_mlp_dims[-1])(x) + return x + + def dcn(self, x0: jt.Array) -> jt.Array: + xl = x0 + input_dim = x0.shape[-1] + + for i in range(self.dcn_layers): + u_kernel = self.param( + f'u_kernel_{i}', + nn.initializers.xavier_normal(), + (input_dim, self.dcn_inner_dim), + ) + v_kernel = self.param( + f'v_kernel_{i}', + nn.initializers.xavier_normal(), + (self.dcn_inner_dim, input_dim), + ) + bias = self.param(f'bias_{i}', nn.initializers.zeros, (input_dim,)) + + u = jnp.matmul(xl, u_kernel) + v = jnp.matmul(u, v_kernel) + v += bias + + xl = x0 * v + xl + + return xl + + @nn.compact + def __call__( + self, inputs: Mapping[str, jt.Array], training: bool = False + ) -> jt.Array: + dense_embeddings = self.bottom_mlp(inputs) + sparse_embeddings = self.embedder.make_sparsecore_module()(inputs) + sparse_embeddings = jax.tree.flatten(sparse_embeddings)[0] + concatenated_embeddings = jnp.concatenate( + (dense_embeddings, *sparse_embeddings), axis=-1 + ) + interaction_outputs = self.dcn(concatenated_embeddings) + predictions = self.top_mlp(interaction_outputs) + predictions = jnp.reshape(predictions, (-1,)) + return predictions + + +class CriteoFactory(mlrx.Factory[tf.data.Dataset]): + """Data loader for dummy Criteo data optimized for Jax training.""" + + features: FeatureSet + global_batch_size: int + use_cached_data: bool = False + + def make(self) -> tf.data.Dataset: + data = {} + batch_size = self.global_batch_size // jax.process_count() + + for f in self.features.dense_features(): + feature = np.random.normal(0.0, 1.0, size=(batch_size, 1)) + data[f.name] = feature.astype(np.float32) + + for f in self.features.sparse_features(): + non_zero_mask = ( + np.random.normal(size=(batch_size, f.embedding_dim)) > f.sparsity + ) + sparse_feature = np.random.randint( + low=0, + high=f.vocab_size, + size=(batch_size, f.embedding_dim), + ) + sparse_feature = np.where( + non_zero_mask, sparse_feature, np.zeros_like(sparse_feature) + ) + data[f.name] = tf.constant(sparse_feature, dtype=tf.int64) + + label = np.random.randint(0, 2, size=(batch_size,)) + + dataset = tf.data.Dataset.from_tensors((data, label)) + dataset = dataset.take(1).repeat() + dataset = dataset.prefetch(buffer_size=2048) + options = tf.data.Options() + options.deterministic = False + options.threading.private_threadpool_size = 96 + dataset = dataset.with_options(options) + return dataset + + +@dataclasses.dataclass +class PredictionTask(mlrx.JaxTask): + """Prediction task.""" + + train_data: CriteoFactory + eval_data: CriteoFactory + model: DLRMModel + optimizer: 
mlrx.Factory[optax.GradientTransformation] + + def create_datasets( + self, + ) -> tuple[mlrx.TFDatasetIterator, mlrx.TFDatasetIterator]: + global_batch_size = self.train_data.global_batch_size + train_iter = mlrx.TFDatasetIterator( + dataset=self.train_data.make(), + postprocessor=self.model.embedder.make_preprocessor(global_batch_size), + ) + eval_iter = mlrx.TFDatasetIterator( + dataset=self.eval_data.make(), + postprocessor=self.model.embedder.make_preprocessor(global_batch_size), + ) + return train_iter, eval_iter + + def create_state(self, batch: jt.PyTree, rng: jt.Array) -> mlrx.JaxState: + inputs, _ = batch + params = self.model.init(rng, inputs) + optimizer = self.optimizer.make() + return mlrx.JaxState.create(params=params, tx=optimizer) + + def train_step( + self, batch: jt.PyTree, state: mlrx.JaxState, rng: jt.Array + ) -> tuple[mlrx.JaxState, Mapping[str, mlrx.Metric]]: + inputs, label = batch + + def _loss_fn(params: jt.PyTree) -> tuple[jt.Scalar, jt.Array]: + logits = self.model.apply(params, inputs, training=True) + loss = jnp.mean(optax.sigmoid_binary_cross_entropy(logits, label), axis=0) + return loss, logits + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True, allow_int=True) + (loss, logits), grads = grad_fn(state.params) + state = state.update(grads=grads) + + metrics = { + 'loss': mlrx.metrics.scalar(loss), + 'accuracy': mlrx.metrics.binary_accuracy(label, logits, threshold=0.0), + 'auc': mlrx.metrics.aucpr(label, logits, from_logits=True), + 'aucroc': mlrx.metrics.aucroc(label, logits, from_logits=True), + 'label/mean': mlrx.metrics.mean(label), + 'prediction/mean': mlrx.metrics.mean(jax.nn.sigmoid(logits)), + } + return state, metrics + + def eval_step( + self, batch: jt.PyTree, state: mlrx.JaxState + ) -> Mapping[str, mlrx.Metric]: + inputs, label = batch + logits = self.model.apply(state.params, inputs, training=False) + loss = jnp.mean(optax.sigmoid_binary_cross_entropy(logits, label), axis=0) + + metrics = { + 'loss': mlrx.metrics.mean(loss), + 'accuracy': mlrx.metrics.binary_accuracy(label, logits, threshold=0.0), + 'auc': mlrx.metrics.aucpr(label, logits, from_logits=True), + 'aucroc': mlrx.metrics.aucroc(label, logits, from_logits=True), + 'label/mean': mlrx.metrics.mean(label), + 'prediction/mean': mlrx.metrics.mean(jax.nn.sigmoid(logits)), + } + return metrics + + +def features() -> fdl.Config[FeatureSet]: + """Creates a feature collection for the DLRM model.""" + table_sizes = [ + (40000000, 3), + (39060, 2), + (17295, 1), + (7424, 2), + (20265, 6), + (3, 1), + (7122, 1), + (1543, 1), + (63, 1), + (40000000, 7), + (3067956, 3), + (405282, 8), + (10, 1), + (2209, 6), + (11938, 9), + (155, 5), + (4, 1), + (976, 1), + (14, 1), + (40000000, 12), + (40000000, 100), + (40000000, 27), + (590152, 10), + (12973, 3), + (108, 1), + (36, 1), + ] + return fdl.Config( + FeatureSet, + features=[ + fdl.Config(DenseFeature, name=f'float-feature-{i}') for i in range(13) + ] + + [ + fdl.Config( + SparseFeature, + vocab_size=vocab_size, + embedding_dim=embedding_dim, + name=f'categorical-feature-{i}', + ) + for i, (vocab_size, embedding_dim) in enumerate(table_sizes) + ], + ) + + +def experiment() -> fdl.Config[mlrx.Experiment]: + """DLRM experiment.""" + + feature_set = features() + + task = fdl.Config( + PredictionTask, + train_data=fdl.Config( + CriteoFactory, + features=feature_set, + global_batch_size=131_072, + ), + eval_data=fdl.Config( + CriteoFactory, + features=feature_set, + global_batch_size=131_072, + use_cached_data=True, + ), + model=fdl.Config( + 
DLRMModel, + features=feature_set, + embedding_optimizer=fdl.Config( + sparsecore.embedding_spec.AdagradOptimizerSpec, + learning_rate=0.01, + ), + bottom_mlp_dims=[512, 256, 128], + top_mlp_dims=[1024, 1024, 512, 256, 1], + dcn_layers=3, + dcn_inner_dim=512, + ), + optimizer=fdl.Config( + mlrx.AdagradFactory, + learning_rate=0.01, + # Sparsecore embedding parameters are optimized in the backward pass. + freeze_mask=rf'.*{sparsecore.EMBEDDING_PARAM_NAME}.*', + ), + ) + trainer = fdl.Config( + mlrx.JaxTrainer, + partitioner=fdl.Config(mlrx.DataParallelPartitioner), + train_steps=1_000, + steps_per_eval=100, + steps_per_loop=100, + ) + return fdl.Config(mlrx.Experiment, task=task, trainer=trainer) diff --git a/recml/examples/dlrm_experiment_test.py b/recml/examples/dlrm_experiment_test.py new file mode 100644 index 0000000..502e630 --- /dev/null +++ b/recml/examples/dlrm_experiment_test.py @@ -0,0 +1,50 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the DLRM experiment.""" + +from absl.testing import absltest +import fiddle as fdl +from fiddle import selectors +import jax +import numpy as np +from recml import mlrx +from recml.examples import dlrm_experiment + + +class DLRMExperimentTest(absltest.TestCase): + + def test_dlrm_experiment(self): + if jax.devices()[0].platform != "tpu": + self.skipTest("Test only supported on TPUs.") + + np.random.seed(1337) + + experiment = dlrm_experiment.experiment() + + experiment.task.train_data.global_batch_size = 4 + experiment.task.eval_data.global_batch_size = 4 + experiment.trainer.train_steps = 12 + experiment.trainer.steps_per_loop = 4 + experiment.trainer.steps_per_eval = 4 + + for cfg in selectors.select(experiment, dlrm_experiment.SparseFeature): + cfg.vocab_size = 200 + cfg.embedding_dim = 8 + + experiment = fdl.build(experiment) + mlrx.run_experiment(experiment, mlrx.Experiment.Mode.TRAIN_AND_EVAL) + + +if __name__ == "__main__": + absltest.main() diff --git a/recml/layers/linen/sparsecore.py b/recml/layers/linen/sparsecore.py new file mode 100644 index 0000000..2e42c39 --- /dev/null +++ b/recml/layers/linen/sparsecore.py @@ -0,0 +1,404 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
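The `dlrm_experiment_test.py` above shrinks the experiment by mutating the fiddle config tree before building it. A minimal standalone sketch of that pattern; the `Layer` and `Model` classes here are made up:

```python
import dataclasses
import fiddle as fdl
from fiddle import selectors


@dataclasses.dataclass
class Layer:
  dim: int


@dataclasses.dataclass
class Model:
  layers: list[Layer]


cfg = fdl.Config(Model, layers=[fdl.Config(Layer, dim=1024) for _ in range(3)])

# Select every Layer config in the tree and override it, the same way the test
# shrinks each SparseFeature before calling fdl.build.
for layer_cfg in selectors.select(cfg, Layer):
  layer_cfg.dim = 8

model = fdl.build(cfg)
assert all(layer.dim == 8 for layer in model.layers)
```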
+"""Sparsecore embedding layers.""" + +from __future__ import annotations + +from collections.abc import Callable +import dataclasses +import functools +from typing import Any, Literal, Mapping, TypeVar + +from etils import epy +from flax import linen as nn +from flax import typing +import jax +import jax.numpy as jnp +import numpy as np +from recml.core.ops import embedding_ops +import tensorflow as tf + +with epy.lazy_imports(): + # pylint: disable=g-import-not-at-top + from jax_tpu_embedding.sparsecore.lib.flax import embed + from jax_tpu_embedding.sparsecore.lib.nn import embedding + from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec + from jax_tpu_embedding.sparsecore.utils import utils + # pylint: enable=g-import-not-at-top + +A = TypeVar('A') +CSR_INPUTS_KEY = 'csr_inputs' +EMBEDDING_PARAM_NAME = embed.EMBEDDING_PARAM_NAME + + +@dataclasses.dataclass +class EmbeddingSpec: + """Sparsecore embedding spec. + + Attributes: + input_dim: The cardinality of the input feature or size of its vocabulary. + embedding_dim: The length of each embedding vector. + max_sequence_length: An optional maximum sequence length. If set, the looked + up embeddings will not be aggregated over the sequence dimension. + Otherwise the embeddings will be aggregated over the sequence dimension + using the `combiner`. Defaults to None. + combiner: The combiner to use to aggregate the embeddings over the sequence + dimension. This is ignored when `max_sequence_length` is set. Allowed + values are 'sum', 'mean', and 'sqrtn'. Defaults to 'mean'. + initializer: The initializer to use for the embedding table. Defaults to + truncated_normal(stddev=1 / sqrt(embedding_dim)) if not set. + optimizer: An optional custom optimizer to use for the embedding table. + weight_name: An optional weight feature name to use for performing a + weighted aggregation on the output of the embedding lookup. Defaults to + None. + """ + + input_dim: int + embedding_dim: int + max_sequence_length: int | None = None + combiner: Literal['sum', 'mean', 'sqrtn'] = 'mean' + initializer: jax.nn.initializers.Initializer | None = None + optimizer: embedding_spec.OptimizerSpec | None = None + weight_name: str | None = None + + def __post_init__(self): + if self.max_sequence_length is not None and self.weight_name is not None: + raise ValueError( + '`max_sequence_length` and `weight_name` cannot both be set. Weighted' + ' aggregation can only be performed when the embeddings are' + ' aggregated over the sequence dimension.' + ) + + def __hash__(self): + return id(self) + + +@dataclasses.dataclass +class SparsecoreEmbedder: + """Sparsecore embedder. + + Attributes: + specs: A mapping from feature name to embedding specs. + optimizer: The default optimizer to use for the embedding variables. + sharding_strategy: The sharding strategy to use for the embedding table. + Defaults to 'MOD' sharding. See the sparsecore documentation for more + details. + num_sc_per_device: The number of sparsecores per Jax device. By default, a + fixed mapping is used to determine this based on device 0. This may fail + on newer TPU architectures if the mapping is not updated of if device 0 is + not a TPU device with a sparsecore. + static_buffer_size_multiplier: The multiplier to use for the static buffer + size. Defaults to 256. + + Example usage: + ```python + class DLRMModel(nn.Module): + # The embedder must be a property of the Flax model and cannot be created + # inside setup(). + embedder: sparsecore.SparsecoreEmbedder + ... 
+ + def setup(self): + self.sparsecore_module = self.embedder.make_sparsecore_module() + ... + + def __call__(self, inputs: Mapping[str, jax.Array]) -> jax.Array: + embedding_activations = self.sparsecore_module(inputs) + ... + + # Instantiate the model and the embedder. + model = DLRMModel(embedder=embedder) + + # Create the eager preprocessor. + preprocessor = model.embedder.make_preprocessor(global_batch_size) + + # Fetch and preprocess the inputs on CPU. + inputs = ... + + # Post-process the sparse features into CSR format on CPU. + processed_inputs = preprocessor(inputs) + + # Shard the inputs and put them on device. + sharded_inputs = ... + + # Initialize and call the model on TPU inside JIT as usual. + vars = model.init(jax.random.key(0), sharded_inputs) + embedding_activations = model.apply(vars, sharded_inputs) + ``` + """ + + specs: Mapping[str, EmbeddingSpec] + optimizer: embedding_spec.OptimizerSpec + sharding_strategy: str = 'MOD' + num_sc_per_device: int = dataclasses.field( + default_factory=utils.num_sparsecores_per_device + ) + static_buffer_size_multiplier: int = 256 + + def __post_init__(self): + self._feature_specs = None + self._global_batch_size = None + + def _init_feature_specs( + self, batch_size: int + ) -> Mapping[str, embedding_spec.FeatureSpec]: + """Returns the feature specs for sparsecore embedding lookup.""" + if self._feature_specs is not None: + return self._feature_specs + + feature_specs = {} + shared_tables = {} + for name, spec in self.specs.items(): + if spec in shared_tables: + table_spec = shared_tables[spec] + else: + table_spec = embedding_spec.TableSpec( + vocabulary_size=spec.input_dim, + embedding_dim=spec.embedding_dim, + initializer=( + spec.initializer + or jax.nn.initializers.truncated_normal( + stddev=1.0 / jnp.sqrt(spec.embedding_dim) + ) + ), + optimizer=spec.optimizer or self.optimizer, + combiner=spec.combiner, + name=f'{name}_table', + ) + shared_tables[spec] = table_spec + + if spec.max_sequence_length is not None: + batch_dim = batch_size * spec.max_sequence_length + else: + batch_dim = batch_size + + feature_specs[name] = embedding_spec.FeatureSpec( + name=name, + table_spec=table_spec, + input_shape=(batch_dim, 1), + output_shape=(batch_dim, spec.embedding_dim), + ) + + embedding.auto_stack_tables( + feature_specs, + jax.device_count(), + self.num_sc_per_device, + stack_to_max_ids_per_partition=lambda n, bs: bs, + stack_to_max_unique_ids_per_partition=lambda n, bs: bs, + ) + embedding.prepare_feature_specs_for_training( + feature_specs, + jax.device_count(), + self.num_sc_per_device, + ) + self._feature_specs = feature_specs + self._global_batch_size = batch_size + return feature_specs + + def make_preprocessor(self, batch_size: int) -> Callable[..., Any]: + """Returns a preprocessor for sparsecore embedding lookup.""" + feature_specs = self._init_feature_specs(batch_size) + weights_names = { + name: spec.weight_name + for name, spec in self.specs.items() + if spec.weight_name is not None + } + + def _to_np(x: Any) -> np.ndarray: + if isinstance(x, np.ndarray): + return x + if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)): + raise NotImplementedError( + 'Sparsecore embedding layer does not support sparse or' + ' raggedtensors.' 
+ ) + if isinstance(x, tf.Tensor): + return x.numpy() + if isinstance(x, jax.Array): + return jax.device_get(x) + return np.array(x) + + def _preprocessor(inputs): + if isinstance(inputs, tuple): + inputs, *rem = inputs + else: + rem = None + + sparse_features = set() + features = {} + weights = {} + for key in feature_specs: + features[key] = _to_np(inputs[key]) + sparse_features.add(key) + if key in weights_names: + weights[key] = _to_np(inputs[weights_names[key]]) + sparse_features.add(weights_names[key]) + else: + weights[key] = np.ones_like(features[key]) + + csr_inputs = embed.EmbeddingLookups( + *embedding.preprocess_sparse_dense_matmul_input( + features=features, + features_weights=weights, + feature_specs=feature_specs, + local_device_count=jax.local_device_count(), + global_device_count=jax.device_count(), + num_sc_per_device=self.num_sc_per_device, + sharding_strategy=self.sharding_strategy, + static_buffer_size_multiplier=self.static_buffer_size_multiplier, + allow_id_dropping=False, + )[:-1] + ) + processed_inputs = { + k: v for k, v in inputs.items() if k not in sparse_features + } + processed_inputs[CSR_INPUTS_KEY] = csr_inputs + + if rem is not None: + processed_inputs = (processed_inputs, *rem) + return processed_inputs + + return _preprocessor + + def make_sparsecore_module(self, **kwargs) -> _SparsecoreEmbed: + """Returns the sparsecore embedding layer.""" + if self._feature_specs is None or self._global_batch_size is None: + raise ValueError( + 'The feature specs are not initialized. Make sure to call' + ' `make_preprocessor` before calling `sparsecore_layer`.' + ) + + def _key(k: str | tuple[str, str]) -> str: + return k[0] if isinstance(k, tuple) else k + + return _SparsecoreEmbed( + feature_specs=self._feature_specs, + global_batch_size=self._global_batch_size, + sharding_axis=0, + sharding_strategy=self.sharding_strategy, + num_sc_per_device=self.num_sc_per_device, + **kwargs, + ) + + +class _SparsecoreEmbed(nn.Module): + """Sparsecore embedding layer.""" + + feature_specs: embedding.Nested[embedding_spec.FeatureSpec] + global_batch_size: int + sharding_axis: str | int + sharding_strategy: str + num_sc_per_device: int + + @property + def abstract_mesh(self) -> jax.sharding.AbstractMesh: + abstract_mesh = jax.sharding.get_abstract_mesh() + if not abstract_mesh.shape_tuple: + raise ValueError( + 'No abstract mesh shape was set with `jax.sharding.use_mesh`. Make' + ' sure to set the mesh when calling the sparsecore module.' + ) + return abstract_mesh + + @property + def sharding_axis_name(self) -> str: + if isinstance(self.sharding_axis, int): + return self.abstract_mesh.axis_names[self.sharding_axis] + return self.sharding_axis + + @property + def num_shards(self) -> int: + return self.abstract_mesh.shape[self.sharding_axis_name] + + def setup(self): + initializer = functools.partial( + embedding.init_embedding_variables, + table_specs=embedding.get_table_specs(self.feature_specs), + global_sharding=jax.sharding.NamedSharding( + self.abstract_mesh, + jax.sharding.PartitionSpec(self.sharding_axis_name, None), + ), + num_sparsecore_per_device=self.num_sc_per_device, + # We need to by-pass the mesh check to use the abstract mesh. + bypass_mesh_check=True, + ) + self.embedding_table = self.param( + name=EMBEDDING_PARAM_NAME, + init_fn=_with_sparsecore_layout( + initializer, (self.sharding_axis_name,), self.abstract_mesh + ), + ) + + def __call__( + self, inputs: Mapping[str, jax.Array] + ) -> embedding.Nested[jax.Array]: + """Computes the embedding activations. 
+ + Args: + inputs: A mapping from feature name to the feature values. The values must + have been preprocessed by the preprocessor returned by + `make_preprocessor`. + + Returns: + The activations structure with the same structure as specs. + """ + activations = embedding_ops.sparsecore_lookup( + embedding_ops.SparsecoreParams( + feature_specs=self.feature_specs, + abstract_mesh=self.abstract_mesh, + data_axes=(self.sharding_axis_name,), + embedding_axes=(self.sharding_axis_name, None), + sharding_strategy=self.sharding_strategy, + ), + self.embedding_table, + inputs[CSR_INPUTS_KEY], + ) + + # Reshape the activations if the batch size is not the same as the global + # batch size. + def _maybe_reshape_activation(activation: jax.Array) -> jax.Array: + if activation.shape[0] != self.global_batch_size: + return jnp.reshape( + activation, + ( + self.global_batch_size, + activation.shape[0] // self.global_batch_size, + activation.shape[1], + ), + ) + return activation + + return jax.tree.map(_maybe_reshape_activation, activations) + + +class _WithSparseCoreLayout(nn.Partitioned[A]): + + def get_sharding(self, _): + assert self.mesh is not None + return embed.Layout( + embed.DeviceLocalLayout(major_to_minor=(0, 1), _tiling=((8,),)), + jax.sharding.NamedSharding(self.mesh, self.get_partition_spec()), + ) + + +def _with_sparsecore_layout( + fn: Callable[..., Any], + names: typing.LogicalNames, + abstract_mesh: jax.sharding.AbstractMesh, +): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return _WithSparseCoreLayout(fn(*args, **kwargs), names, mesh=abstract_mesh) # pytype: disable=wrong-arg-types + + return wrapper diff --git a/recml/layers/linen/sparsecore_test.py b/recml/layers/linen/sparsecore_test.py new file mode 100644 index 0000000..b181118 --- /dev/null +++ b/recml/layers/linen/sparsecore_test.py @@ -0,0 +1,76 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
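The `_maybe_reshape_activation` helper above folds flattened sequence activations back into a `(batch, sequence, dim)` layout. A small shape-only illustration with made-up sizes:

```python
import jax.numpy as jnp

global_batch_size = 4
# A sequence feature with max_sequence_length=10 and embedding_dim=16 comes
# back from the lookup flattened as (batch * seq_len, dim).
activation = jnp.zeros((40, 16))

reshaped = jnp.reshape(
    activation,
    (
        global_batch_size,
        activation.shape[0] // global_batch_size,
        activation.shape[1],
    ),
)
assert reshaped.shape == (4, 10, 16)
```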
+"""Sparsecore tests.""" + +import functools + +from absl.testing import absltest +import jax +from recml.layers.linen import sparsecore +from recml.core.training import partitioning + + +class SparsecoreTest(absltest.TestCase): + + def test_sparsecore_embedder_equivalence(self): + if jax.devices()[0].platform != "tpu": + self.skipTest("Test only supported on TPUs.") + + k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) + + inputs = { + "a": jax.random.randint(k1, (32, 16), minval=1, maxval=100), + "b": jax.random.randint(k2, (32, 16), minval=1, maxval=100), + "w": jax.random.normal(k3, (32, 16)), + } + + dp_partitioner = partitioning.DataParallelPartitioner() + embedder = sparsecore.SparsecoreEmbedder( + specs={ + "a": sparsecore.EmbeddingSpec( + input_dim=100, + embedding_dim=16, + combiner="mean", + weight_name="w", + ), + "b": sparsecore.EmbeddingSpec( + input_dim=100, + embedding_dim=16, + max_sequence_length=10, + ), + }, + optimizer=sparsecore.embedding_spec.AdagradOptimizerSpec( + learning_rate=0.01 + ), + ) + preprocessor = embedder.make_preprocessor(32) + layer = embedder.make_sparsecore_module() + + sc_inputs = dp_partitioner.shard_inputs(preprocessor(inputs)) + sc_vars = dp_partitioner.partition_init(functools.partial(layer.init, k4))( + sc_inputs + ) + + def step(inputs, params): + return layer.apply(params, inputs) + + p_step = dp_partitioner.partition_step(step, training=False) + sparsecore_activations = jax.device_get(p_step(sc_inputs, sc_vars)) + + self.assertEqual(sparsecore_activations["a"].shape, (32, 16)) + self.assertEqual(sparsecore_activations["b"].shape, (32, 10, 16)) + + +if __name__ == "__main__": + absltest.main()