Add int support to dense_to_jagged_forward on GPU (pytorch#1587)
Summary:
Pull Request resolved: pytorch#1587

As titled

Reviewed By: jianyuh

Differential Revision: D43132884

fbshipit-source-id: ab3b00b4f0e2cd79134e3e89822a73456a2fd97b
sryap authored and facebook-github-bot committed Feb 23, 2023
1 parent 4ebdb1a commit 111f696
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -1202,22 +1202,24 @@ Tensor dense_to_jagged_forward(
   at::cuda::OptionalCUDAGuard device_guard;
   device_guard.set_index(dense.get_device());
 
+#define DISPATCH_DENSE_TO_JAGGED_CASE(TYPE)                          \
+  AT_DISPATCH_CASE(TYPE, [&] {                                       \
+    jagged_dense_elementwise_jagged_output_opt_<scalar_t>(           \
+        values,                                                      \
+        offsets,                                                     \
+        dense,                                                       \
+        output,                                                      \
+        [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t { \
+          return y;                                                  \
+        });                                                          \
+  })
+
+  // clang-format off
   AT_DISPATCH_SWITCH(
       values.scalar_type(),
       "dense_to_jagged_gpu_op_forward",
-      AT_DISPATCH_CASE(
-          at::ScalarType::Half,
-          [&] {
-            jagged_dense_elementwise_jagged_output_opt_<scalar_t>(
-                values,
-                offsets,
-                dense,
-                output,
-                [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
-                  return y;
-                }); // device lambda
-          } // lambda
-          ) // CASE
+      DISPATCH_DENSE_TO_JAGGED_CASE(at::ScalarType::Half)
+      DISPATCH_DENSE_TO_JAGGED_CASE(at::ScalarType::Int)
       AT_DISPATCH_CASE_FLOATING_TYPES_AND2(
           at::ScalarType::Long,
           at::ScalarType::BFloat16,
@@ -1233,6 +1235,9 @@ Tensor dense_to_jagged_forward(
           } // lambda
           ) // CASE_FLOATING_TYPES_AND
   ); // SWITCH
+  // clang-format on
+
+#undef DISPATCH_DENSE_TO_JAGGED_CASE
 
   return output;
 }
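For context on the pattern the commit uses: AT_DISPATCH_SWITCH selects a case from the tensor's runtime scalar type, and each AT_DISPATCH_CASE binds scalar_t for its body, so factoring the body into a case macro makes adding a new dtype (here Int) a one-line change inside the switch. Below is a minimal CPU-side sketch of the same idea, not taken from this commit; copy_impl, copy_any, and DISPATCH_COPY_CASE are hypothetical names, not part of FBGEMM.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical per-dtype body; the commit's real body calls
// jagged_dense_elementwise_jagged_output_opt_<scalar_t> instead.
template <typename scalar_t>
void copy_impl(const at::Tensor& src, at::Tensor& dst) {
  const auto* in = src.data_ptr<scalar_t>();
  auto* out = dst.data_ptr<scalar_t>();
  for (int64_t i = 0; i < src.numel(); ++i) {
    out[i] = in[i]; // element-wise copy, instantiated once per dtype
  }
}

// One macro per dispatch case, as in the commit: adding a dtype is one line.
#define DISPATCH_COPY_CASE(TYPE) \
  AT_DISPATCH_CASE(TYPE, [&] { copy_impl<scalar_t>(src, dst); })

void copy_any(const at::Tensor& src, at::Tensor& dst) {
  AT_DISPATCH_SWITCH(
      src.scalar_type(),
      "copy_any",
      DISPATCH_COPY_CASE(at::ScalarType::Half)
      DISPATCH_COPY_CASE(at::ScalarType::Int) // the newly supported case
      AT_DISPATCH_CASE_FLOATING_TYPES_AND2(
          at::ScalarType::Long,
          at::ScalarType::BFloat16,
          [&] { copy_impl<scalar_t>(src, dst); }));
}

#undef DISPATCH_COPY_CASE

Mirroring the commit, the macro is #undef'd immediately after the switch so it stays local to this one dispatch site rather than leaking into the rest of the translation unit.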