diff --git a/generative_recommenders/ops/cpp/complete_cumsum.cpp b/generative_recommenders/ops/cpp/complete_cumsum.cpp
new file mode 100644
index 0000000..ae614cf
--- /dev/null
+++ b/generative_recommenders/ops/cpp/complete_cumsum.cpp
@@ -0,0 +1,46 @@
+/* Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <torch/library.h>
+#include <torch/script.h>
+
+#include "fbgemm_gpu/sparse_ops.h" // @manual
+
+namespace gr {
+
+// CPU kernel for the "complete" cumulative sum of a 1-D tensor.
+// Delegates to FBGEMM's asynchronous_complete_cumsum, whose output has
+// len(values) + 1 elements (a cumsum prefixed with a leading zero).
+at::Tensor complete_cumsum_cpu(const at::Tensor& values) {
+  TORCH_CHECK(values.dim() == 1, "values must be 1-D, got dim=", values.dim());
+  return fbgemm_gpu::asynchronous_complete_cumsum_cpu(values);
+}
+
+// Meta (shape-only) kernel: allocates an empty [len + 1] tensor on the
+// meta device so shape propagation / tracing works without real data.
+at::Tensor complete_cumsum_meta(const at::Tensor& values) {
+  auto len = values.sym_size(0);
+  auto output = at::native::empty_meta_symint(
+      {len + 1},
+      /*dtype=*/::std::make_optional(values.scalar_type()),
+      /*layout=*/::std::make_optional(values.layout()),
+      /*device=*/::std::make_optional(c10::Device(c10::kMeta)),
+      /*pin_memory=*/::std::nullopt);
+  return output;
+}
+
+} // namespace gr
diff --git a/generative_recommenders/ops/cpp/cpp_ops.cpp b/generative_recommenders/ops/cpp/cpp_ops.cpp
index 35845de..2bc43f3 100644
--- a/generative_recommenders/ops/cpp/cpp_ops.cpp
+++ b/generative_recommenders/ops/cpp/cpp_ops.cpp
@@ -44,8 +44,12 @@ at::Tensor batched_complete_cumsum_cuda(const at::Tensor& values);
 
 at::Tensor batched_complete_cumsum_meta(const at::Tensor& values);
 
+at::Tensor complete_cumsum_cpu(const at::Tensor& values);
+
 at::Tensor complete_cumsum_cuda(const at::Tensor& values);
 
+at::Tensor complete_cumsum_meta(const at::Tensor& values);
+
 at::Tensor concat_1d_jagged_jagged_cpu(
     const at::Tensor& lengths_left,
     const at::Tensor& values_left,
@@ -78,6 +82,7 @@ TORCH_LIBRARY_IMPL(gr, CPU, m) {
   m.impl("expand_1d_jagged_to_dense", gr::expand_1d_jagged_to_dense_cpu);
   m.impl("batched_complete_cumsum", gr::batched_complete_cumsum_cpu);
   m.impl("concat_1d_jagged_jagged", gr::concat_1d_jagged_jagged_cpu);
+  m.impl("complete_cumsum", gr::complete_cumsum_cpu);
 }
 
 TORCH_LIBRARY_IMPL(gr, CUDA, m) {
@@ -91,6 +96,7 @@ TORCH_LIBRARY_IMPL(gr, Meta, m) {
   m.impl("expand_1d_jagged_to_dense", gr::expand_1d_jagged_to_dense_meta);
   m.impl("batched_complete_cumsum", gr::batched_complete_cumsum_meta);
   m.impl("concat_1d_jagged_jagged", gr::concat_1d_jagged_jagged_meta);
+  m.impl("complete_cumsum", gr::complete_cumsum_meta);
 }
 
 TORCH_LIBRARY_IMPL(gr, Autograd, m) {