
Add checks for dimensions of pooled_embs #4159


Closed · wants to merge 1 commit
14 changes: 14 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/tensor_utils.h
@@ -46,6 +46,20 @@ inline std::string torch_tensor_device_name(
  }
}

// Returns a human-readable string of a tensor's shape, e.g. "[8, 192]".
inline const std::string torch_tensor_shape_str(const at::Tensor& ten) {
  std::stringstream ss;
  const auto sizes = ten.sizes();
  ss << "[";
  for (auto i = 0; i < sizes.size(); ++i) {
    ss << sizes[i];
    if (i != sizes.size() - 1) {
      ss << ", ";
    }
  }
  ss << "]";
  return ss.str();
}

inline bool torch_tensor_on_same_device_check(
    const at::Tensor& ten1,
    const at::Tensor& ten2) {
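For context, here is a minimal standalone sketch (not part of this diff) of how the new torch_tensor_shape_str helper behaves. The main() wrapper, the tensor sizes, and the assumption that fbgemm_gpu/utils/tensor_utils.h is on the include path are illustrative only.

// Standalone sketch (not from this PR): exercise torch_tensor_shape_str on a
// pooled-embedding-shaped tensor and print the formatted shape.
#include <iostream>
#include <ATen/ATen.h>
#include "fbgemm_gpu/utils/tensor_utils.h"

int main() {
  // Hypothetical pooled output: B_local = 8 rows, Sum_T_global(D) = 192 columns.
  const auto pooled_embs = at::zeros({8, 192});
  std::cout << torch_tensor_shape_str(pooled_embs) << std::endl;  // prints "[8, 192]"
  return 0;
}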
@@ -69,6 +69,12 @@ Tensor permute_pooled_embs_gpu_impl(
    return pooled_embs;
  }

  TORCH_CHECK(
      pooled_embs.dim() == 2,
      "pooled_embs must be 2-D tensor of size [B_local][Sum_T_global(D)], "
      "current shape is: ",
      torch_tensor_shape_str(pooled_embs));

  // inv_permute_list is not being used so it's not checked here.
  TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
      pooled_embs, offset_dim_list, permute_list, inv_offset_dim_list);
@@ -10,6 +10,7 @@
#include <vector>
#include "fbgemm_gpu/permute_pooled_embedding_ops.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/utils/tensor_utils.h"

using Tensor = at::Tensor;

@@ -25,12 +26,19 @@ Tensor permute_pooled_embs_cpu_impl(
  if (pooled_embs.numel() == 0) {
    return pooled_embs;
  }

  TORCH_CHECK(
      pooled_embs.dim() == 2,
      "pooled_embs must be 2-D tensor of size [B_local][Sum_T_global(D)], "
      "current shape is: ",
      torch_tensor_shape_str(pooled_embs));
  TORCH_CHECK(
      offset_dim_list.scalar_type() == at::ScalarType::Long,
      "offset_dim_list needs to have long/int64 type")
  TORCH_CHECK(
      permute_list.scalar_type() == at::ScalarType::Long,
      "permute_list needs to have long/int64 type")

  auto permute = permute_list.data_ptr<int64_t>();
  const auto n = permute_list.numel();
  const auto dims_size = allow_duplicates ? offset_dim_list.numel() : n;
@@ -68,6 +68,13 @@ Tensor permute_pooled_embs_split_gpu_impl(
  if (pooled_embs.numel() == 0) {
    return pooled_embs;
  }

  TORCH_CHECK(
      pooled_embs.dim() == 2,
      "pooled_embs must be 2-D tensor of size [B_local][Sum_T_global(D)], "
      "current shape is: ",
      torch_tensor_shape_str(pooled_embs));

  // inv_permute_list is not being used so it's not checked here.
  TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
      pooled_embs, offset_dim_list, permute_list, inv_offset_dim_list);
@@ -15,6 +15,7 @@
#include "fbgemm_gpu/permute_pooled_embedding_ops_split.h"
#include "fbgemm_gpu/permute_pooled_embs_function_split.h"
#include "fbgemm_gpu/utils/ops_utils.h"
#include "fbgemm_gpu/utils/tensor_utils.h"

using Tensor = at::Tensor;

@@ -34,12 +35,19 @@ Tensor permute_pooled_embs_split_cpu_impl(
  if (pooled_embs.numel() == 0) {
    return pooled_embs;
  }

  TORCH_CHECK(
      pooled_embs.dim() == 2,
      "pooled_embs must be 2-D tensor of size [B_local][Sum_T_global(D)], "
      "current shape is: ",
      torch_tensor_shape_str(pooled_embs));
  TORCH_CHECK(
      offset_dim_list.scalar_type() == at::ScalarType::Long,
      "offset_dim_list needs to have long/int64 type")
  TORCH_CHECK(
      permute_list.scalar_type() == at::ScalarType::Long,
      "permute_list needs to have long/int64 type")

  auto permute = permute_list.data_ptr<int64_t>();
  const auto n = permute_list.numel();
  const auto dims_size = allow_duplicates ? offset_dim_list.numel() : n;
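To illustrate what the added checks reject, here is a minimal standalone sketch (not part of this diff) that applies the same TORCH_CHECK pattern to a 1-D tensor. The main() wrapper and the tensor size are illustrative; it assumes an ATen build with fbgemm_gpu/utils/tensor_utils.h on the include path.

// Standalone sketch (not from this PR): reproduce the new dimension check
// against a 1-D tensor and print the resulting error message.
#include <iostream>
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include "fbgemm_gpu/utils/tensor_utils.h"

int main() {
  const auto pooled_embs = at::zeros({192});  // 1-D, so the check fails
  try {
    TORCH_CHECK(
        pooled_embs.dim() == 2,
        "pooled_embs must be 2-D tensor of size [B_local][Sum_T_global(D)], "
        "current shape is: ",
        torch_tensor_shape_str(pooled_embs));
  } catch (const c10::Error& e) {
    std::cout << e.what() << std::endl;  // message ends with "current shape is: [192]"
  }
  return 0;
}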