Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[skip ci] WIP on index_fill batch rule #370

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions functorch/csrc/BatchRulesScatterOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,99 @@ std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
return std::make_tuple(at::stack(results), 0);
}

std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, optional<int64_t> index_bdim,
const Scalar& value) {

if (!index_bdim) {
// Handle scalar tensors... self, other can be scalar tensors
const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
auto self_ = moveBatchDimToFront(self, self_bdim);
if (self_logical_rank == 0) {
self_ = self_.unsqueeze(-1);
}
dim = maybe_wrap_dim(dim, self_logical_rank);

optional<int64_t> out_bdim = nullopt;
if (self_bdim) {
const auto batch_size = self.size(*self_bdim);
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
dim = dim + 1;
out_bdim = 0;
}

auto result = self_.index_fill(dim, index, value);
if (self_logical_rank == 0) {
result = result.squeeze(-1);
}
return std::make_tuple(result, out_bdim);
}

// SAME AS FOR index_add
// Index is batched. For-loop and stack is the best thing I can come up with
// right now. We really want generalized index_fill kernel in PyTorch
auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
std::vector<Tensor> results;
results.reserve(batch_size);
for (const auto i : c10::irange(0, batch_size)) {
const auto& self_slice = self_bdim.has_value() ?
self.select(*self_bdim, i) : self;
const auto& index_slice = index_bdim.has_value() ?
index.select(*index_bdim, i) : index;
results.push_back(at::index_fill(self_slice, dim, index_slice, value));
}
return std::make_tuple(at::stack(results), 0);
}

std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
int64_t dim,
const Tensor& index, optional<int64_t> index_bdim,
const Tensor& value, optional<int64_t> value_bdim) {

if (!index_bdim && !value_bdim) {
// Handle scalar tensors... self, other can be scalar tensors
const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
auto self_ = moveBatchDimToFront(self, self_bdim);
if (self_logical_rank == 0) {
self_ = self_.unsqueeze(-1);
}
dim = maybe_wrap_dim(dim, self_logical_rank);

optional<int64_t> out_bdim = nullopt;
if (self_bdim) {
const auto batch_size = self.size(*self_bdim);
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
dim = dim + 1;
out_bdim = 0;
}
auto result = self_.index_fill(dim, index, value);
if (self_logical_rank == 0) {
result = result.squeeze(-1);
}
return std::make_tuple(result, out_bdim);
}

// SAME AS FOR index_add
// Index is batched. For-loop and stack is the best thing I can come up with
// right now. We really want generalized index_fill kernel in PyTorch
auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
std::vector<Tensor> results;
results.reserve(batch_size);
for (const auto i : c10::irange(0, batch_size)) {
const auto& self_slice = self_bdim.has_value() ?
self.select(*self_bdim, i) : self;
const auto& index_slice = index_bdim.has_value() ?
index.select(*index_bdim, i) : index;
const auto& value_slice = value_bdim.has_value() ?
value.select(*value_bdim, i) : value;
results.push_back(at::index_fill(self_slice, dim, index_slice, value_slice));
}
return std::make_tuple(at::stack(results), 0);
}

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
m.impl("index.Tensor", index_plumbing);
m.impl("index_put_", index_put__plumbing);
Expand All @@ -550,6 +643,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
m.impl("index_copy", index_copy_decomp);
m.impl("index_select", index_select_decomp);
VMAP_SUPPORT("index_add", index_add_batch_rule);
VMAP_SUPPORT("index_fill.int_Scalar", index_fill_int_scalar_batch_rule);
VMAP_SUPPORT("index_fill.int_Tensor", index_fill_int_tensor_batch_rule);
VMAP_SUPPORT("diagonal_scatter", diagonal_scatter_batch_rule);
VMAP_SUPPORT("gather", gather_batch_rule);
VMAP_SUPPORT("gather_backward", gather_backward_batch_rule);
Expand Down
5 changes: 1 addition & 4 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,6 @@ def vjp_of_vjp(*args_and_cotangents):
xfail('fmax'),
xfail('fmin'),
xfail('index_copy'),
xfail('index_fill'),
xfail('linalg.det', ''),
xfail('linalg.eigh'),
xfail('linalg.householder_product'),
Expand Down Expand Up @@ -595,7 +594,6 @@ def test_vmapvjp(self, device, dtype, op):
xfail('block_diag'), # TODO: We expect this to fail in core, but it doesn't
xfail('index_copy'),
xfail('index_put'),
xfail('index_fill'),
xfail('masked_fill'),
xfail('masked_scatter'),

Expand Down Expand Up @@ -701,7 +699,6 @@ def test_vmapjvp(self, device, dtype, op):
xfail('max', 'binary'),
xfail('nn.functional.gaussian_nll_loss'),
xfail('min', 'binary'),
xfail('index_fill'),
xfail('index_put'),
xfail('std_mean'),
xfail('double', 'channels_last'),
Expand Down Expand Up @@ -760,7 +757,7 @@ def test_vmapjvpall(self, device, dtype, op):
xfail('fmax'),
xfail('fmin'),
xfail('index_copy'),
xfail('index_fill'),
xfail('index_fill'), # RuntimeError: aten::_unique hit the vmap fallback which is currently disabled
xfail('linalg.cholesky'),
xfail('linalg.cholesky_ex'),
xfail('linalg.det'),
Expand Down
1 change: 0 additions & 1 deletion test/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3181,7 +3181,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('gradient'),
xfail('histogram'),
xfail('hsplit'),
xfail('index_fill'),
xfail('index_put'),
xfail('isin'),
xfail('linalg.cholesky'),
Expand Down