Add complete_cumsum cpu and meta ops
Differential Revision: D66226634
Jiyuan Zhang authored and facebook-github-bot committed Nov 20, 2024
1 parent 6e9dc67 commit 57110d4
Showing 2 changed files with 50 additions and 0 deletions.
44 changes: 44 additions & 0 deletions generative_recommenders/ops/cpp/complete_cumsum.cpp
@@ -0,0 +1,44 @@
/* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

#include "fbgemm_gpu/sparse_ops.h" // @manual

namespace gr {

at::Tensor complete_cumsum_cpu(const at::Tensor& values) {
TORCH_CHECK(values.dim() == 1);
// Delegate to FBGEMM's CPU kernel. The result has size values.size(0) + 1:
// a leading zero followed by the inclusive prefix sums of `values`.
return fbgemm_gpu::asynchronous_complete_cumsum_cpu(values);
}

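// Meta kernel: computes only the output shape and dtype ({len + 1}) so the op
// can be traced and compiled without running a real kernel.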
at::Tensor complete_cumsum_meta(const at::Tensor& values) {
auto len = values.sym_size(0);
auto output = at::native::empty_meta_symint(
{len + 1},
/*dtype=*/::std::make_optional(values.scalar_type()),
/*layout=*/::std::make_optional(values.layout()),
/*device=*/::std::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/::std::nullopt);
return output;
}

} // namespace gr
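For context (not part of the commit), the behavior implied by the fbgemm_gpu delegation and the {len + 1} meta shape is an offsets-style cumsum: a leading zero followed by the inclusive prefix sums of a 1-D tensor, e.g. [3, 1, 2] -> [0, 3, 4, 6]. Below is a minimal ATen-only sketch under that assumption; complete_cumsum_reference is a hypothetical name used only for illustration.

#include <ATen/ATen.h>

// Hypothetical reference, assuming "complete cumsum" means offsets with a leading zero.
at::Tensor complete_cumsum_reference(const at::Tensor& values) {
  TORCH_CHECK(values.dim() == 1);
  // out[0] = 0; out[i + 1] = values[0] + ... + values[i]; size is len + 1.
  at::Tensor out = at::zeros({values.size(0) + 1}, values.options());
  out.slice(/*dim=*/0, /*start=*/1).copy_(at::cumsum(values, /*dim=*/0));
  return out;
}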
6 changes: 6 additions & 0 deletions generative_recommenders/ops/cpp/cpp_ops.cpp
@@ -44,8 +44,12 @@ at::Tensor batched_complete_cumsum_cuda(const at::Tensor& values);

at::Tensor batched_complete_cumsum_meta(const at::Tensor& values);

at::Tensor complete_cumsum_cpu(const at::Tensor& values);

at::Tensor complete_cumsum_cuda(const at::Tensor& values);

at::Tensor complete_cumsum_meta(const at::Tensor& values);

at::Tensor concat_1d_jagged_jagged_cpu(
const at::Tensor& lengths_left,
const at::Tensor& values_left,
@@ -78,6 +82,7 @@ TORCH_LIBRARY_IMPL(gr, CPU, m) {
m.impl("expand_1d_jagged_to_dense", gr::expand_1d_jagged_to_dense_cpu);
m.impl("batched_complete_cumsum", gr::batched_complete_cumsum_cpu);
m.impl("concat_1d_jagged_jagged", gr::concat_1d_jagged_jagged_cpu);
m.impl("complete_cumsum", gr::complete_cumsum_cpu);
}

TORCH_LIBRARY_IMPL(gr, CUDA, m) {
@@ -91,6 +96,7 @@ TORCH_LIBRARY_IMPL(gr, Meta, m) {
m.impl("expand_1d_jagged_to_dense", gr::expand_1d_jagged_to_dense_meta);
m.impl("batched_complete_cumsum", gr::batched_complete_cumsum_meta);
m.impl("concat_1d_jagged_jagged", gr::concat_1d_jagged_jagged_meta);
m.impl("complete_cumsum", gr::complete_cumsum_meta);
}

TORCH_LIBRARY_IMPL(gr, Autograd, m) {
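Note that this diff adds only the CPU and Meta dispatch registrations; the "complete_cumsum" schema is presumably already declared in the library's TORCH_LIBRARY block, next to batched_complete_cumsum. A hedged sketch of what that declaration and a C++ call site could look like, assuming the schema "complete_cumsum(Tensor values) -> Tensor"; the TORCH_LIBRARY_FRAGMENT block and call_complete_cumsum helper are illustrative, not part of the commit.

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

// Assumed schema declaration; in the actual repo this likely already exists.
TORCH_LIBRARY_FRAGMENT(gr, m) {
  m.def("complete_cumsum(Tensor values) -> Tensor");
}

// Illustrative call site: the dispatcher routes to complete_cumsum_cpu,
// complete_cumsum_cuda, or complete_cumsum_meta based on the device of `values`.
at::Tensor call_complete_cumsum(const at::Tensor& values) {
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("gr::complete_cumsum", "")
      .typed<at::Tensor(const at::Tensor&)>();
  return op.call(values);
}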
