Skip to content

Commit

Permalink
upgrade GEPs of PHI when possible (google#1302)
Browse files Browse the repository at this point in the history
* upgrade GEPs of PHI when possible

As the last update of llvm introduced the canonicalization of GEPs to
i8, we have a case where the user of a phi nodes is now a i8. This
lead the inferering of the phi type to be lowered to i8 while it could
have a higher type.

To avoid that, track possible upgrade of GEPs of PHI and upgrade them
if all the user of the PHI can use the new higher type.

This is fixing the performance regression for the ChromeOS workloads.

Ref google#1292

* simplify test

* remove 1292-phi_users.cl

---------

Co-authored-by: Alan Baker <[email protected]>
  • Loading branch information
rjodinchr and alan-baker authored Feb 8, 2024
1 parent 00e3001 commit 1907bf6
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 0 deletions.
64 changes: 64 additions & 0 deletions lib/SimplifyPointerBitcastPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,8 @@ bool clspv::SimplifyPointerBitcastPass::runOnUpgradeableConstantCasts(
Value *ptr;
};
SmallVector<UpgradeInfo, 8> Worklist;
SmallVector<UpgradeInfo, 8> GEPsDefiningPHIsWorklist;
DenseSet<GetElementPtrInst *> GEPsDefiningPHISeen;
for (auto &F : M) {
for (auto &BB : F) {
for (auto &I : BB) {
Expand Down Expand Up @@ -703,6 +705,45 @@ bool clspv::SimplifyPointerBitcastPass::runOnUpgradeableConstantCasts(

Worklist.push_back({&I, cstVal, smallerBitWidths, dest_ty,
gep->getPointerOperand()});
} else if (auto *phi = dyn_cast<PHINode>(source)) {
auto &context = M.getContext();
auto get_geps_defining_phis_type = [&DL, phi, source_ty, dest_ty,
&context, &type_cache,
&GEPsDefiningPHISeen]() {
SmallVector<UpgradeInfo> geps;
for (auto user : phi->users()) {
auto user_ty = clspv::InferType(user, context, &type_cache);
if (user_ty != source_ty) {
continue;
}
auto gep = dyn_cast<GetElementPtrInst>(user);
if (gep == nullptr || !gep->hasAllConstantIndices()) {
geps.clear();
return geps;
}
// should not be used as all indices are constant
IRBuilder<> Builder{gep};

uint64_t cstVal;
Value *dynVal;
size_t smallerBitWidths;
ExtractOffsetFromGEP(DL, Builder, gep, cstVal, dynVal,
smallerBitWidths);
assert(dynVal == nullptr);
if (((cstVal * smallerBitWidths) % SizeInBits(DL, dest_ty)) !=
0) {
geps.clear();
return geps;
}
if (GEPsDefiningPHISeen.count(gep) == 0) {
GEPsDefiningPHISeen.insert(gep);
geps.push_back({gep, cstVal, smallerBitWidths, dest_ty,
gep->getPointerOperand()});
}
}
return geps;
};
GEPsDefiningPHIsWorklist.append(get_geps_defining_phis_type());
}
}
}
Expand Down Expand Up @@ -732,6 +773,29 @@ bool clspv::SimplifyPointerBitcastPass::runOnUpgradeableConstantCasts(
changed = true;
}

for (auto GEPInfo : GEPsDefiningPHIsWorklist) {
auto gep = dyn_cast<GetElementPtrInst>(GEPInfo.inst);
uint64_t cst = GEPInfo.cst;
size_t smallerBitWidths = GEPInfo.smallerBitWidth;
Type *dest_ty = GEPInfo.dest_ty;
Value *ptr = GEPInfo.ptr;
IRBuilder Builder{gep};

auto NewGEPIdxs =
GetIdxsForTyFromOffset(M.getDataLayout(), Builder, dest_ty, dest_ty,
cst, nullptr, smallerBitWidths, ptr);

auto new_gep = GetElementPtrInst::Create(dest_ty, ptr, NewGEPIdxs, "", gep);
LLVM_DEBUG(dbgs() << "\n##runOnUpgradeableConstantCasts:\nreplace gep "
"defining phi type: ";
gep->dump(); dbgs() << "by: "; new_gep->dump());

gep->replaceAllUsesWith(new_gep);
gep->eraseFromParent();

changed = true;
}

return changed;
}

Expand Down
48 changes: 48 additions & 0 deletions test/PointerCasts/1292-phi_users.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
; RUN: clspv-opt %s -o %t.ll --passes=simplify-pointer-bitcast
; RUN: FileCheck %s < %t.ll

; CHECK: [[ptr:%[^ ]+]] = phi ptr addrspace(1) [ {{.*}} ], [ [[add_ptr:%[^ ]+]],
; CHECK: [[add_ptr]] = getelementptr <4 x half>, ptr addrspace(1) [[ptr]], i32 32

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir-unknown-unknown"

%0 = type { <3 x i32>, <3 x i32>, <3 x i32>, %1 }
%1 = type { i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32 }

@__push_constants = local_unnamed_addr addrspace(9) global %0 zeroinitializer, !push_constants !0

define spir_kernel void @main_function(ptr addrspace(1) nocapture readnone align 8 %dst_tensor_buffer, ptr addrspace(1) nocapture readonly align 8 %src_tensor_buffer, ptr addrspace(1) nocapture readonly align 8 %weights_buffer, target("spirv.Image", void, 1, 0, 0, 0, 0, 0, 0) %biases_image2d) !clspv.pod_args_impl !17 !kernel_arg_map !18 {
entry:
%0 = load i32, ptr addrspace(9) getelementptr inbounds (%0, ptr addrspace(9) @__push_constants, i32 0, i32 3, i32 1), align 4
%cmp.i.not = icmp slt i32 %0, 10
br i1 %cmp.i.not, label %if.end.i, label %main_function.inner.exit

if.end.i: ; preds = %entry
%add.ptr.i = getelementptr inbounds <4 x half>, ptr addrspace(1) %weights_buffer, i32 0
br label %do.body.i

do.body.i: ; preds = %do.body.i, %if.end.i
%filters_loc.0.i = phi ptr addrspace(1) [ %add.ptr.i, %if.end.i ], [ %add.ptr32.i, %do.body.i ]
%s.0.i = phi i32 [ 0, %if.end.i ], [ %add19.i, %do.body.i ]
%arrayidx.i = getelementptr inbounds <4 x half>, ptr addrspace(1) %filters_loc.0.i, i32 0
%ld = load <4 x half>, ptr addrspace(1) %arrayidx.i, align 8
%add.ptr32.i = getelementptr inbounds i8, ptr addrspace(1) %filters_loc.0.i, i32 256
%add19.i = add i32 %s.0.i, 1
%cmp33.i = icmp slt i32 %add19.i, 10
br i1 %cmp33.i, label %do.body.i, label %main_function.inner.exit

main_function.inner.exit: ; preds = %do.body.i, %entry
ret void
}

!0 = !{i32 1, i32 4, i32 6, i32 7}
!17 = !{i32 3}
!18 = !{!19, !20, !21, !22, !23, !24, !25}
!19 = !{!"dst_tensor_buffer", i32 0, i32 0, i32 0, i32 0, !"buffer"}
!20 = !{!"src_tensor_buffer", i32 1, i32 1, i32 0, i32 0, !"buffer"}
!21 = !{!"weights_buffer", i32 2, i32 2, i32 0, i32 0, !"buffer"}
!22 = !{!"biases_image2d", i32 3, i32 3, i32 0, i32 0, !"ro_image"}
!23 = !{!"shared_int4_0", i32 4, i32 -1, i32 48, i32 16, !"pod_pushconstant"}
!24 = !{!"shared_int4_1", i32 5, i32 -1, i32 64, i32 16, !"pod_pushconstant"}
!25 = !{!"shared_int4_2", i32 6, i32 -1, i32 80, i32 16, !"pod_pushconstant"}

0 comments on commit 1907bf6

Please sign in to comment.