Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add complete_cumsum cpu and meta ops #140

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions generative_recommenders/ops/cpp/complete_cumsum.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

#include "fbgemm_gpu/sparse_ops.h" // @manual

namespace gr {

at::Tensor complete_cumsum_cpu(const at::Tensor& values) {
TORCH_CHECK(values.dim() == 1);
auto len = values.size(0);
const torch::Tensor index = at::range(0, len, at::kLong).cpu();
auto output = fbgemm_gpu::asynchronous_complete_cumsum_cpu(values);
return output;
}

at::Tensor complete_cumsum_meta(const at::Tensor& values) {
auto len = values.sym_size(0);
auto output = at::native::empty_meta_symint(
{len + 1},
/*dtype=*/::std::make_optional(values.scalar_type()),
/*layout=*/::std::make_optional(values.layout()),
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/::std::nullopt);
return output;
}

} // namespace gr
6 changes: 6 additions & 0 deletions generative_recommenders/ops/cpp/cpp_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,12 @@ at::Tensor batched_complete_cumsum_cuda(const at::Tensor& values);

at::Tensor batched_complete_cumsum_meta(const at::Tensor& values);

at::Tensor complete_cumsum_cpu(const at::Tensor& values);

at::Tensor complete_cumsum_cuda(const at::Tensor& values);

at::Tensor complete_cumsum_meta(const at::Tensor& values);

at::Tensor concat_1d_jagged_jagged_cpu(
const at::Tensor& lengths_left,
const at::Tensor& values_left,
Expand Down Expand Up @@ -78,6 +82,7 @@ TORCH_LIBRARY_IMPL(gr, CPU, m) {
m.impl("expand_1d_jagged_to_dense", gr::expand_1d_jagged_to_dense_cpu);
m.impl("batched_complete_cumsum", gr::batched_complete_cumsum_cpu);
m.impl("concat_1d_jagged_jagged", gr::concat_1d_jagged_jagged_cpu);
m.impl("complete_cumsum", gr::complete_cumsum_cpu);
}

TORCH_LIBRARY_IMPL(gr, CUDA, m) {
Expand All @@ -91,6 +96,7 @@ TORCH_LIBRARY_IMPL(gr, Meta, m) {
m.impl("expand_1d_jagged_to_dense", gr::expand_1d_jagged_to_dense_meta);
m.impl("batched_complete_cumsum", gr::batched_complete_cumsum_meta);
m.impl("concat_1d_jagged_jagged", gr::concat_1d_jagged_jagged_meta);
m.impl("complete_cumsum", gr::complete_cumsum_meta);
}

TORCH_LIBRARY_IMPL(gr, Autograd, m) {
Expand Down