diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst new file mode 100644 index 0000000000..519b74e6be --- /dev/null +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst @@ -0,0 +1,6 @@ +Pooled Embedding Operators +========================== + +.. automodule:: fbgemm_gpu + +.. autofunction:: torch.ops.fbgemm.merge_pooled_embeddings diff --git a/fbgemm_gpu/docs/src/index.rst b/fbgemm_gpu/docs/src/index.rst index a71a589959..c4d98c720f 100644 --- a/fbgemm_gpu/docs/src/index.rst +++ b/fbgemm_gpu/docs/src/index.rst @@ -91,3 +91,4 @@ Table of Contents fbgemm_gpu-python-api/table_batched_embedding_ops.rst fbgemm_gpu-python-api/jagged_tensor_ops.rst + fbgemm_gpu-python-api/pooled_embedding_ops.rst diff --git a/fbgemm_gpu/fbgemm_gpu/docs/__init__.py b/fbgemm_gpu/fbgemm_gpu/docs/__init__.py index 250f9d58e9..5077a5ba3f 100644 --- a/fbgemm_gpu/fbgemm_gpu/docs/__init__.py +++ b/fbgemm_gpu/fbgemm_gpu/docs/__init__.py @@ -7,6 +7,6 @@ # Trigger the manual addition of docstrings to pybind11-generated operators try: - from . import jagged_tensor_ops, table_batched_embedding_ops # noqa: F401 + from . import jagged_tensor_ops, merge_pooled_embedding_ops # noqa: F401 except Exception: pass diff --git a/fbgemm_gpu/fbgemm_gpu/docs/merge_pooled_embedding_ops.py b/fbgemm_gpu/fbgemm_gpu/docs/merge_pooled_embedding_ops.py new file mode 100644 index 0000000000..6990946fba --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/docs/merge_pooled_embedding_ops.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 

import torch

from .common import add_docs

# Supplies the user-facing documentation for the
# `torch.ops.fbgemm.merge_pooled_embeddings` operator. The package __init__
# imports this module purely for this side effect ("manual addition of
# docstrings to pybind11-generated operators"); `add_docs` is the shared
# helper from `.common` that receives the operator and its documentation
# text.
add_docs(
    torch.ops.fbgemm.merge_pooled_embeddings,
    """
merge_pooled_embeddings(pooled_embeddings, uncat_dim_size, target_device, cat_dim=1) -> Tensor

Concatenate embedding outputs from different devices (on the same host)
onto the target device.

Args:
    pooled_embeddings (List[Tensor]): A list of embedding outputs from
        different devices on the same host. Each output has 2
        dimensions.

    uncat_dim_size (int): The size of the dimension that is not
        concatenated, i.e., if `cat_dim=0`, `uncat_dim_size` is the size
        of dim 1 and vice versa.

    target_device (torch.device): The target device that aggregates all
        the embedding outputs.

    cat_dim (int = 1): The dimension along which the tensors are
        concatenated.

Returns:
    The concatenated embedding output (2D) on the target device.
    """,
)