Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Nov 8, 2024
1 parent e072c92 commit d10691c
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 103 deletions.
17 changes: 0 additions & 17 deletions csrc/ops/composite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,23 +425,6 @@ TensorView* matmul(TensorView* tv_a, TensorView* tv_b) {
return out;
}

namespace {
template <typename T>
void checkAllEqual(std::initializer_list<T> elements) {
for (const auto& element : elements) {
NVF_CHECK(
element == *elements.begin(),
"Expected all elements to be equal, but found ",
element,
" and ",
*elements.begin(),
" in [",
toDelimitedString(elements),
"]");
}
}
} // namespace

SdpfaFwdResult sdpfa_fwd(
TensorView* query,
TensorView* key,
Expand Down
161 changes: 76 additions & 85 deletions csrc/tensor_metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,65 +209,66 @@ class BackwardTraverseFromLogicalToAlloc {
};

void validateAllocationSizesAndStrides(
const std::vector<IterDomain*>& alloc_dom_no_reductions,
const std::vector<IterDomain*>& alloc_dom,
const std::vector<std::optional<bool>>& contiguity,
c10::IntArrayRef sizes,
c10::IntArrayRef strides) {
NVF_ERROR(sizes.size() == strides.size());
NVF_ERROR(alloc_dom.size() == contiguity.size());
checkAllEqual(
{TensorDomain::noReductions(alloc_dom).size(),
sizes.size(),
strides.size()});

// Validate contiguity
int64_t contiguous_stride = 1;
auto contiguity_rev = contiguity.crbegin();
for (int64_t i = (int64_t)sizes.size() - 1; i >= 0; i--) {
if (alloc_dom_no_reductions.at(i)->isBroadcast()) {
int64_t expected_stride_if_contiguous = 1;
int64_t dim_index = sizes.size();
// Go backwards because it's easier to compute the expected stride this way.
for (auto domain_index = static_cast<int64_t>(alloc_dom.size()) - 1;
domain_index >= 0;
domain_index--) {
IterDomain* alloc_id = alloc_dom[domain_index];
if (alloc_id->isReduction()) {
continue;
}
while (!contiguity_rev->has_value()) {
contiguity_rev++;

dim_index--;
auto size = sizes.at(dim_index);
auto stride = strides.at(dim_index);

if (alloc_id->isBroadcast()) {
NVF_CHECK(!contiguity[domain_index].has_value());
if (alloc_id->hasExpandedExtent()) {
NVF_CHECK(
stride == 0,
"Expecting an expanded dimension on dimension ",
dim_index,
" but found stride ",
stride);
}
continue;
}
auto size = sizes.at(i);
auto stride = strides.at(i);
NVF_ERROR(!contiguity.empty());
auto last_contiguity = *contiguity_rev;
NVF_ERROR(
last_contiguity.has_value(),
"I don't think this check makes sense, but unfortunately ",
"clang-tidy is not smart enough to infer from the context that this is always true.");
if (*last_contiguity) {

if (alloc_id->isDeviceDim()) {
NVF_CHECK(size == 1);
continue;
}

NVF_CHECK(contiguity[domain_index].has_value());
if (*contiguity[domain_index]) {
NVF_CHECK(
stride == contiguous_stride,
stride == expected_stride_if_contiguous,
"Stride mismatch with contiguity info. ",
" allocation domain: ",
ir_utils::toString(alloc_dom_no_reductions),
" dim: ",
i,
" expected stride: ",
contiguous_stride,
" actual stride: ",
stride);
}
contiguous_stride = stride * size;
contiguity_rev++;
}
NVF_ERROR(
std::none_of(
contiguity_rev,
contiguity.crend(),
[](auto c_flag) { return c_flag.has_value(); }),
"The size of contiguity mismatch with the dimensionality of allocation domain");

// Validate that for expanded broadcast, the stride must be zero.
for (int64_t i : c10::irange((int64_t)strides.size())) {
if (auto alloc_id = alloc_dom_no_reductions.at(i);
alloc_id->hasExpandedExtent()) {
auto stride = strides.at(i);
NVF_CHECK(
stride == 0,
"Expecting an expanded dimension on dimension ",
i,
" but found stride ",
ir_utils::toString(alloc_dom),
"; contiguity: ",
toDelimitedString(contiguity),
"; dim: ",
domain_index,
"; expected stride: ",
expected_stride_if_contiguous,
"; actual stride: ",
stride);
}
expected_stride_if_contiguous = stride * size;
}
}

Expand All @@ -278,43 +279,39 @@ inferAndValidateAllocationSizesAndStrides(
const at::Tensor& tensor,
TensorView* tv,
ExpressionEvaluator ee) {
if (tv == nullptr || !tv->hasAllocation()) {
// When tv is nullptr, or tv does not have allocation, the given sizes and
// strides should already be in the target format. So nothing to do here.
std::vector<int64_t> sizes;
std::vector<int64_t> strides;
for (auto i : c10::irange(tensor.dim())) {
sizes.emplace_back(tensor.size(i));
strides.emplace_back(tensor.stride(i));
}
return {sizes, strides};
}
const auto& alloc =
TensorDomain::noReductions(tv->getMaybeAllocationDomain());
const auto& logical = TensorDomain::noReductions(tv->getLogicalDomain());
const auto& alloc = tv->getMaybeAllocationDomain();
const auto& alloc_no_reductions = TensorDomain::noReductions(alloc);
const auto& logical_no_reductions =
TensorDomain::noReductions(tv->getLogicalDomain());

// active IDs and their shape and stride
std::unordered_map<IterDomain*, std::pair<int64_t, int64_t>> active_ids;
NVF_ERROR((int64_t)logical.size() == tensor.dim());
for (int64_t i : c10::irange((int64_t)logical.size())) {
auto rf_id = logical.at(i);
active_ids[rf_id] = {tensor.size(i), tensor.stride(i)};
NVF_ERROR(static_cast<int64_t>(logical_no_reductions.size()) == tensor.dim());
for (const auto i : c10::irange(tensor.dim())) {
IterDomain* id = logical_no_reductions.at(i);
active_ids[id] = {tensor.size(i), tensor.stride(i)};
}

ForwardTraverseFromLogicalToAlloc(ee, active_ids).run(tv, logical, alloc);
BackwardTraverseFromLogicalToAlloc(ee, active_ids).run(tv, logical, alloc);
ForwardTraverseFromLogicalToAlloc(ee, active_ids)
.run(tv, logical_no_reductions, alloc_no_reductions);
BackwardTraverseFromLogicalToAlloc(ee, active_ids)
.run(tv, logical_no_reductions, alloc_no_reductions);

// Now active_ids should contain the final sizes and strides, unordered. We
// need to put them to the correct order.
std::vector<int64_t> sizes;
std::vector<int64_t> strides;
sizes.reserve(alloc.size());
strides.reserve(alloc.size());
for (auto i : c10::irange(alloc.size())) {
auto id = alloc.at(i);
sizes.emplace_back(active_ids.at(id).first);
strides.emplace_back(active_ids.at(id).second);
sizes.reserve(alloc_no_reductions.size());
strides.reserve(alloc_no_reductions.size());
for (IterDomain* id : alloc_no_reductions) {
if (id->isDeviceDim()) {
sizes.push_back(1);
} else {
sizes.push_back(active_ids.at(id).first);
}
strides.push_back(active_ids.at(id).second);
}

// Only validate final sizes and strides when we have a non-empty tensor.
if (tensor.numel() != 0) {
validateAllocationSizesAndStrides(
Expand Down Expand Up @@ -381,18 +378,12 @@ std::vector<PolymorphicValue> GetMetaData::evaluate(
metadata->logical_stride = input.strides();
}

if (tv->hasAllocation()) {
auto allocation_data =
inferAndValidateAllocationSizesAndStrides(input, tv, ee);
metadata->alloc_size_data = std::move(allocation_data.first);
metadata->alloc_size = c10::makeArrayRef(metadata->alloc_size_data);
metadata->alloc_stride_data = std::move(allocation_data.second);
metadata->alloc_stride = c10::makeArrayRef(metadata->alloc_stride_data);
} else {
metadata->alloc_size = metadata->logical_size;
metadata->alloc_stride = metadata->logical_stride;
// TODO: validateAllocationSizesAndStrides
}
auto allocation_data =
inferAndValidateAllocationSizesAndStrides(input, tv, ee);
metadata->alloc_size_data = std::move(allocation_data.first);
metadata->alloc_size = c10::makeArrayRef(metadata->alloc_size_data);
metadata->alloc_stride_data = std::move(allocation_data.second);
metadata->alloc_stride = c10::makeArrayRef(metadata->alloc_stride_data);
return {PolymorphicValue(std::move(struct_))};
}

Expand Down
15 changes: 15 additions & 0 deletions csrc/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -598,4 +598,19 @@ template <typename T>
using MaybeUniqueOwningPtr = dynamic_type::
DynamicType<dynamic_type::NoContainers, T*, std::unique_ptr<T>>;

template <typename T>
void checkAllEqual(std::initializer_list<T> elements) {
for (const auto& element : elements) {
NVF_CHECK(
element == *elements.begin(),
"Expected all elements to be equal, but found ",
element,
" and ",
*elements.begin(),
" in [",
toDelimitedString(elements),
"]");
}
}

} // namespace nvfuser
2 changes: 2 additions & 0 deletions tests/cpp/test_gpu1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6753,6 +6753,8 @@ TEST_F(NVFuserTest, FusionPersistentSoftmaxLocalShared_CUDA) {

TensorView* sx_softmax = div(sx_exp, bcast_sum); // (M, N)
TensorView* dx_softmax = div(dx_exp, bcast_sum); // (M, N)
sx_softmax->setContiguity(false);
dx_softmax->setContiguity(false);
fusion.addOutput(sx_softmax);
fusion.addOutput(dx_softmax);

Expand Down
3 changes: 2 additions & 1 deletion tests/cpp/test_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ TEST_P(ShardingTest, ComputeIndex) {
c->setDeviceMesh(mesh);
d->setDeviceMesh(mesh);
a->axis(2)->parallelize(ParallelType::DIDx);
b->axis(2)->parallelize(ParallelType::DIDx);
TensorDomain::noReductions(b->getLoopDomain())[1]->parallelize(
ParallelType::DIDx);
c->axis(2)->parallelize(ParallelType::DIDx);
d->axis(0)->parallelize(ParallelType::DIDx);

Expand Down

0 comments on commit d10691c

Please sign in to comment.