diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp
index 43db88cf42d..1db959115a2 100644
--- a/csrc/ops/composite.cpp
+++ b/csrc/ops/composite.cpp
@@ -425,23 +425,6 @@ TensorView* matmul(TensorView* tv_a, TensorView* tv_b) {
   return out;
 }
 
-namespace {
-template <typename T>
-void checkAllEqual(std::initializer_list<T> elements) {
-  for (const auto& element : elements) {
-    NVF_CHECK(
-        element == *elements.begin(),
-        "Expected all elements to be equal, but found ",
-        element,
-        " and ",
-        *elements.begin(),
-        " in [",
-        toDelimitedString(elements),
-        "]");
-  }
-}
-} // namespace
-
 SdpfaFwdResult sdpfa_fwd(
     TensorView* query,
     TensorView* key,
diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp
index 26b5e21338f..99dd13e870d 100644
--- a/csrc/tensor_metadata.cpp
+++ b/csrc/tensor_metadata.cpp
@@ -209,65 +209,66 @@ class BackwardTraverseFromLogicalToAlloc {
 };
 
 void validateAllocationSizesAndStrides(
-    const std::vector<IterDomain*>& alloc_dom_no_reductions,
+    const std::vector<IterDomain*>& alloc_dom,
     const std::vector<std::optional<bool>>& contiguity,
     c10::IntArrayRef sizes,
     c10::IntArrayRef strides) {
-  NVF_ERROR(sizes.size() == strides.size());
+  NVF_ERROR(alloc_dom.size() == contiguity.size());
+  checkAllEqual(
+      {TensorDomain::noReductions(alloc_dom).size(),
+       sizes.size(),
+       strides.size()});
 
-  // Validate contiguity
-  int64_t contiguous_stride = 1;
-  auto contiguity_rev = contiguity.crbegin();
-  for (int64_t i = (int64_t)sizes.size() - 1; i >= 0; i--) {
-    if (alloc_dom_no_reductions.at(i)->isBroadcast()) {
+  int64_t expected_stride_if_contiguous = 1;
+  int64_t dim_index = sizes.size();
+  // Go backwards because it's easier to compute the expected stride this way.
+  for (auto domain_index = static_cast<int64_t>(alloc_dom.size()) - 1;
+       domain_index >= 0;
+       domain_index--) {
+    IterDomain* alloc_id = alloc_dom[domain_index];
+    if (alloc_id->isReduction()) {
       continue;
     }
-    while (!contiguity_rev->has_value()) {
-      contiguity_rev++;
+
+    dim_index--;
+    auto size = sizes.at(dim_index);
+    auto stride = strides.at(dim_index);
+
+    if (alloc_id->isBroadcast()) {
+      NVF_CHECK(!contiguity[domain_index].has_value());
+      if (alloc_id->hasExpandedExtent()) {
+        NVF_CHECK(
+            stride == 0,
+            "Expecting an expanded dimension on dimension ",
+            dim_index,
+            " but found stride ",
+            stride);
+      }
+      continue;
     }
-    auto size = sizes.at(i);
-    auto stride = strides.at(i);
-    NVF_ERROR(!contiguity.empty());
-    auto last_contiguity = *contiguity_rev;
-    NVF_ERROR(
-        last_contiguity.has_value(),
-        "I don't think this check makes sense, but unfortunately ",
-        "clang-tidy is not smart enough to infer from the context that this is always true.");
-    if (*last_contiguity) {
+
+    if (alloc_id->isDeviceDim()) {
+      NVF_CHECK(size == 1);
+      continue;
+    }
+
+    NVF_CHECK(contiguity[domain_index].has_value());
+    if (*contiguity[domain_index]) {
       NVF_CHECK(
-          stride == contiguous_stride,
+          stride == expected_stride_if_contiguous,
           "Stride mismatch with contiguity info. ",
           " allocation domain: ",
-          ir_utils::toString(alloc_dom_no_reductions),
-          " dim: ",
-          i,
-          " expected stride: ",
-          contiguous_stride,
-          " actual stride: ",
-          stride);
-    }
-    contiguous_stride = stride * size;
-    contiguity_rev++;
-  }
-  NVF_ERROR(
-      std::none_of(
-          contiguity_rev,
-          contiguity.crend(),
-          [](auto c_flag) { return c_flag.has_value(); }),
-      "The size of contiguity mismatch with the dimensionality of allocation domain");
-
-  // Validate that for expanded broadcast, the stride must be zero.
-  for (int64_t i : c10::irange((int64_t)strides.size())) {
-    if (auto alloc_id = alloc_dom_no_reductions.at(i);
-        alloc_id->hasExpandedExtent()) {
-      auto stride = strides.at(i);
-      NVF_CHECK(
-          stride == 0,
-          "Expecting an expanded dimension on dimension ",
-          i,
-          " but found stride ",
+          ir_utils::toString(alloc_dom),
+          "; contiguity: ",
+          toDelimitedString(contiguity),
+          "; dim: ",
+          domain_index,
+          "; expected stride: ",
+          expected_stride_if_contiguous,
+          "; actual stride: ",
           stride);
     }
+    expected_stride_if_contiguous = stride * size;
   }
 }
 
@@ -278,43 +279,39 @@ inferAndValidateAllocationSizesAndStrides(
     const at::Tensor& tensor,
     TensorView* tv,
     ExpressionEvaluator ee) {
-  if (tv == nullptr || !tv->hasAllocation()) {
-    // When tv is nullptr, or tv does not have allocation, the given sizes and
-    // strides should already be in the target format. So nothing to do here.
-    std::vector<int64_t> sizes;
-    std::vector<int64_t> strides;
-    for (auto i : c10::irange(tensor.dim())) {
-      sizes.emplace_back(tensor.size(i));
-      strides.emplace_back(tensor.stride(i));
-    }
-    return {sizes, strides};
-  }
-  const auto& alloc =
-      TensorDomain::noReductions(tv->getMaybeAllocationDomain());
-  const auto& logical = TensorDomain::noReductions(tv->getLogicalDomain());
+  const auto& alloc = tv->getMaybeAllocationDomain();
+  const auto& alloc_no_reductions = TensorDomain::noReductions(alloc);
+  const auto& logical_no_reductions =
+      TensorDomain::noReductions(tv->getLogicalDomain());
 
   // active IDs and their shape and stride
   std::unordered_map<IterDomain*, std::pair<int64_t, int64_t>> active_ids;
-  NVF_ERROR((int64_t)logical.size() == tensor.dim());
-  for (int64_t i : c10::irange((int64_t)logical.size())) {
-    auto rf_id = logical.at(i);
-    active_ids[rf_id] = {tensor.size(i), tensor.stride(i)};
+  NVF_ERROR(static_cast<int64_t>(logical_no_reductions.size()) == tensor.dim());
+  for (const auto i : c10::irange(tensor.dim())) {
+    IterDomain* id = logical_no_reductions.at(i);
+    active_ids[id] = {tensor.size(i), tensor.stride(i)};
   }
 
-  ForwardTraverseFromLogicalToAlloc(ee, active_ids).run(tv, logical, alloc);
-  BackwardTraverseFromLogicalToAlloc(ee, active_ids).run(tv, logical, alloc);
+  ForwardTraverseFromLogicalToAlloc(ee, active_ids)
+      .run(tv, logical_no_reductions, alloc_no_reductions);
+  BackwardTraverseFromLogicalToAlloc(ee, active_ids)
+      .run(tv, logical_no_reductions, alloc_no_reductions);
 
   // Now active_ids should contain the final sizes and strides, unordered. We
   // need to put them to the correct order.
   std::vector<int64_t> sizes;
   std::vector<int64_t> strides;
-  sizes.reserve(alloc.size());
-  strides.reserve(alloc.size());
-  for (auto i : c10::irange(alloc.size())) {
-    auto id = alloc.at(i);
-    sizes.emplace_back(active_ids.at(id).first);
-    strides.emplace_back(active_ids.at(id).second);
+  sizes.reserve(alloc_no_reductions.size());
+  strides.reserve(alloc_no_reductions.size());
+  for (IterDomain* id : alloc_no_reductions) {
+    if (id->isDeviceDim()) {
+      sizes.push_back(1);
+    } else {
+      sizes.push_back(active_ids.at(id).first);
+    }
+    strides.push_back(active_ids.at(id).second);
   }
 
+  // Only validate final sizes and strides when we have a non-empty tensor.
   if (tensor.numel() != 0) {
     validateAllocationSizesAndStrides(
@@ -381,18 +378,12 @@ std::vector<PolymorphicValue> GetMetaData::evaluate(
     metadata->logical_stride = input.strides();
   }
 
-  if (tv->hasAllocation()) {
-    auto allocation_data =
-        inferAndValidateAllocationSizesAndStrides(input, tv, ee);
-    metadata->alloc_size_data = std::move(allocation_data.first);
-    metadata->alloc_size = c10::makeArrayRef(metadata->alloc_size_data);
-    metadata->alloc_stride_data = std::move(allocation_data.second);
-    metadata->alloc_stride = c10::makeArrayRef(metadata->alloc_stride_data);
-  } else {
-    metadata->alloc_size = metadata->logical_size;
-    metadata->alloc_stride = metadata->logical_stride;
-    // TODO: validateAllocationSizesAndStrides
-  }
+  auto allocation_data =
+      inferAndValidateAllocationSizesAndStrides(input, tv, ee);
+  metadata->alloc_size_data = std::move(allocation_data.first);
+  metadata->alloc_size = c10::makeArrayRef(metadata->alloc_size_data);
+  metadata->alloc_stride_data = std::move(allocation_data.second);
+  metadata->alloc_stride = c10::makeArrayRef(metadata->alloc_stride_data);
 
   return {PolymorphicValue(std::move(struct_))};
 }
diff --git a/csrc/utils.h b/csrc/utils.h
index f98d2e357a2..138d05c0128 100644
--- a/csrc/utils.h
+++ b/csrc/utils.h
@@ -598,4 +598,19 @@ template <typename T>
 using MaybeUniqueOwningPtr = dynamic_type::
     DynamicType<dynamic_type::NoContainers, T*, std::unique_ptr<T>>;
 
+template <typename T>
+void checkAllEqual(std::initializer_list<T> elements) {
+  for (const auto& element : elements) {
+    NVF_CHECK(
+        element == *elements.begin(),
+        "Expected all elements to be equal, but found ",
+        element,
+        " and ",
+        *elements.begin(),
+        " in [",
+        toDelimitedString(elements),
+        "]");
+  }
+}
+
 } // namespace nvfuser
diff --git a/tests/cpp/test_gpu1.cpp b/tests/cpp/test_gpu1.cpp
index 4ed46cddc46..bdd3e4e356d 100644
--- a/tests/cpp/test_gpu1.cpp
+++ b/tests/cpp/test_gpu1.cpp
@@ -6753,6 +6753,8 @@ TEST_F(NVFuserTest, FusionPersistentSoftmaxLocalShared_CUDA) {
   TensorView* sx_softmax = div(sx_exp, bcast_sum); // (M, N)
   TensorView* dx_softmax = div(dx_exp, bcast_sum); // (M, N)
 
+  sx_softmax->setContiguity(false);
+  dx_softmax->setContiguity(false);
   fusion.addOutput(sx_softmax);
   fusion.addOutput(dx_softmax);
 
diff --git a/tests/cpp/test_sharding.cpp b/tests/cpp/test_sharding.cpp
index 6738d99857c..c7ca3f2a5f8 100644
--- a/tests/cpp/test_sharding.cpp
+++ b/tests/cpp/test_sharding.cpp
@@ -147,7 +147,8 @@ TEST_P(ShardingTest, ComputeIndex) {
   c->setDeviceMesh(mesh);
   d->setDeviceMesh(mesh);
   a->axis(2)->parallelize(ParallelType::DIDx);
-  b->axis(2)->parallelize(ParallelType::DIDx);
+  TensorDomain::noReductions(b->getLoopDomain())[1]->parallelize(
+      ParallelType::DIDx);
   c->axis(2)->parallelize(ParallelType::DIDx);
   d->axis(0)->parallelize(ParallelType::DIDx);
 
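
Aside: the checkAllEqual helper that this patch promotes from an anonymous namespace in csrc/ops/composite.cpp to csrc/utils.h simply asserts that every element of a braced list compares equal to the first one. The sketch below is a minimal standalone analogue written so it compiles outside nvFuser; it is not part of the patch. NVF_CHECK and toDelimitedString are replaced with a plain exception and a hand-rolled message, and the name checkAllEqualSketch is made up for this illustration.

// Standalone sketch of the checkAllEqual idea (illustration only; the real
// helper in csrc/utils.h reports failures through NVF_CHECK and prints the
// whole list with toDelimitedString).
#include <initializer_list>
#include <iostream>
#include <sstream>
#include <stdexcept>

template <typename T>
void checkAllEqualSketch(std::initializer_list<T> elements) {
  for (const auto& element : elements) {
    if (!(element == *elements.begin())) {
      std::ostringstream oss;
      oss << "Expected all elements to be equal, but found " << element
          << " and " << *elements.begin();
      throw std::runtime_error(oss.str());
    }
  }
}

int main() {
  checkAllEqualSketch({3, 3, 3}); // passes silently
  try {
    checkAllEqualSketch({3, 4, 3}); // throws: 4 differs from 3
  } catch (const std::runtime_error& e) {
    std::cerr << e.what() << "\n";
  }
  return 0;
}

In the patch itself, validateAllocationSizesAndStrides calls the helper to require that the number of non-reduction allocation IterDomains matches both sizes.size() and strides.size().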
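
Aside: to make the new control flow in validateAllocationSizesAndStrides easier to follow, here is an illustrative re-implementation of its bookkeeping using plain enums and integers instead of IterDomain, NVF_CHECK, and contiguity flags. The walk runs from the innermost allocation dimension outwards: reduction dimensions are skipped entirely, expanded broadcasts must have stride 0, device-parallel (DIDx) dimensions must have local size 1, and each remaining contiguous dimension must have a stride equal to the running product of sizes and strides already visited. The type names and the simplified rules (for example, treating every non-broadcast, non-device dimension as contiguous) are assumptions made for this sketch only, not nvFuser API.

// Illustrative re-implementation of the expected-stride bookkeeping used by
// the patched validateAllocationSizesAndStrides. Types and rules here are
// simplified stand-ins for IterDomain queries and contiguity flags.
#include <cstdint>
#include <vector>

enum class DimKind {
  Iteration,
  Broadcast,
  ExpandedBroadcast,
  Reduction,
  Device
};

// Returns true when sizes/strides look like a contiguous, row-major local
// allocation of the given dimensions (reductions carry no size/stride entry).
bool sizesAndStridesLookContiguous(
    const std::vector<DimKind>& alloc_dom,
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides) {
  int64_t expected_stride_if_contiguous = 1;
  int64_t dim_index = static_cast<int64_t>(sizes.size());
  // Walk backwards so the expected stride can be accumulated as stride * size.
  for (int64_t i = static_cast<int64_t>(alloc_dom.size()) - 1; i >= 0; i--) {
    if (alloc_dom[i] == DimKind::Reduction) {
      continue; // reductions do not appear in the at::Tensor at all
    }
    dim_index--;
    const int64_t size = sizes.at(dim_index);
    const int64_t stride = strides.at(dim_index);
    if (alloc_dom[i] == DimKind::ExpandedBroadcast) {
      if (stride != 0) {
        return false; // expanded broadcasts must have stride 0
      }
      continue;
    }
    if (alloc_dom[i] == DimKind::Broadcast) {
      continue; // plain broadcasts carry no contiguity requirement
    }
    if (alloc_dom[i] == DimKind::Device) {
      if (size != 1) {
        return false; // a sharded dimension occupies one element locally
      }
      continue;
    }
    if (stride != expected_stride_if_contiguous) {
      return false;
    }
    expected_stride_if_contiguous = stride * size;
  }
  return true;
}

int main() {
  // A [DIDx, 4, 8] allocation: the sharded dimension has local size 1, and
  // the two iteration dimensions are laid out contiguously.
  const bool ok = sizesAndStridesLookContiguous(
      {DimKind::Device, DimKind::Iteration, DimKind::Iteration},
      {1, 4, 8},
      {32, 8, 1});
  return ok ? 0 : 1;
}

This mirrors why inferAndValidateAllocationSizesAndStrides now reports size 1 for device-parallel allocation IterDomains: the local at::Tensor only holds this rank's shard, so the sharded extent does not contribute to the expected stride.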