diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt index 4a2b534b948d6..ba09451104479 100644 --- a/llvm/lib/Target/SPIRV/CMakeLists.txt +++ b/llvm/lib/Target/SPIRV/CMakeLists.txt @@ -44,6 +44,7 @@ add_llvm_target(SPIRVCodeGen SPIRVRegularizer.cpp SPIRVSubtarget.cpp SPIRVTargetMachine.cpp + SPIRVTargetTransformInfo.cpp SPIRVUtils.cpp SPIRVEmitNonSemanticDI.cpp diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.cpp new file mode 100644 index 0000000000000..95093d2b3c263 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.cpp @@ -0,0 +1,40 @@ +//===- SPIRVTargetTransformInfo.cpp - SPIR-V specific TTI -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "SPIRVTargetTransformInfo.h" +#include "llvm/IR/IntrinsicsSPIRV.h" + +using namespace llvm; + +bool llvm::SPIRVTTIImpl::collectFlatAddressOperands( + SmallVectorImpl &OpIndexes, Intrinsic::ID IID) const { + switch (IID) { + case Intrinsic::spv_generic_cast_to_ptr_explicit: + OpIndexes.push_back(0); + return true; + default: + return false; + } +} + +Value *llvm::SPIRVTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, + Value *OldV, + Value *NewV) const { + auto IntrID = II->getIntrinsicID(); + switch (IntrID) { + case Intrinsic::spv_generic_cast_to_ptr_explicit: { + unsigned NewAS = NewV->getType()->getPointerAddressSpace(); + unsigned DstAS = II->getType()->getPointerAddressSpace(); + return NewAS == DstAS ? NewV + : ConstantPointerNull::get( + PointerType::get(NewV->getContext(), DstAS)); + } + default: + return nullptr; + } +} diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h index 40e561ba38881..43bf6e9dd2a6e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h +++ b/llvm/lib/Target/SPIRV/SPIRVTargetTransformInfo.h @@ -50,14 +50,15 @@ class SPIRVTTIImpl final : public BasicTTIImplBase { } unsigned getFlatAddressSpace() const override { - if (ST->isShader()) - return 0; - // FIXME: Clang has 2 distinct address space maps. One where + // Clang has 2 distinct address space maps. One where // default=4=Generic, and one with default=0=Function. This depends on the - // environment. For OpenCL, we don't need to run the InferAddrSpace pass, so - // we can return -1, but we might want to fix this. - return -1; + // environment. + return ST->isShader() ? 0 : 4; } + bool collectFlatAddressOperands(SmallVectorImpl &OpIndexes, + Intrinsic::ID IID) const override; + Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV, + Value *NewV) const override; }; } // namespace llvm diff --git a/llvm/test/Transforms/InferAddressSpaces/SPIRV/generic-cast-explicit.ll b/llvm/test/Transforms/InferAddressSpaces/SPIRV/generic-cast-explicit.ll new file mode 100644 index 0000000000000..aa39797d74a10 --- /dev/null +++ b/llvm/test/Transforms/InferAddressSpaces/SPIRV/generic-cast-explicit.ll @@ -0,0 +1,102 @@ +; This test checks that the address space casts for SPIR-V generic pointer casts +; are lowered correctly by the infer-address-spaces pass. +; RUN: opt < %s -passes=infer-address-spaces -S --mtriple=spirv64-unknown-unknown | FileCheck %s + +; Casting a global pointer to a global pointer. +; The uses of c2 will be replaced with %global. +; CHECK: @kernel1(ptr addrspace(1) %global) +define i1 @kernel1(ptr addrspace(1) %global) { + %c1 = addrspacecast ptr addrspace(1) %global to ptr addrspace(4) + %c2 = call ptr addrspace(1) @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1) + ; CHECK: %b1 = icmp eq ptr addrspace(1) %global, null + %b1 = icmp eq ptr addrspace(1) %c2, null + ret i1 %b1 +} + +; Casting a global pointer to a local pointer. +; The uses of c2 will be replaced with null. +; CHECK: @kernel2(ptr addrspace(1) %global) +define i1 @kernel2(ptr addrspace(1) %global) { + %c1 = addrspacecast ptr addrspace(1) %global to ptr addrspace(4) + %c2 = call ptr addrspace(3) @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1) + ; CHECK: %b1 = icmp eq ptr addrspace(3) null, null + %b1 = icmp eq ptr addrspace(3) %c2, null + ret i1 %b1 +} + +; Casting a global pointer to a private pointer. +; The uses of c2 will be replaced with null. +; CHECK: @kernel3(ptr addrspace(1) %global) +define i1 @kernel3(ptr addrspace(1) %global) { + %c1 = addrspacecast ptr addrspace(1) %global to ptr addrspace(4) + %c2 = call ptr @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1) + ; CHECK: %b1 = icmp eq ptr null, null + %b1 = icmp eq ptr %c2, null + ret i1 %b1 +} + +; Casting a local pointer to a local pointer. +; The uses of c2 will be replaced with %local. +; CHECK: @kernel4(ptr addrspace(3) %local) +define i1 @kernel4(ptr addrspace(3) %local) { + %c1 = addrspacecast ptr addrspace(3) %local to ptr addrspace(4) + %c2 = call ptr addrspace(3) @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1) + ; CHECK: %b1 = icmp eq ptr addrspace(3) %local, null + %b1 = icmp eq ptr addrspace(3) %c2, null + ret i1 %b1 +} + +; Casting a local pointer to a global pointer. +; The uses of c2 will be replaced with null. +; CHECK: @kernel5(ptr addrspace(3) %local) +define i1 @kernel5(ptr addrspace(3) %local) { + %c1 = addrspacecast ptr addrspace(3) %local to ptr addrspace(4) + %c2 = call ptr addrspace(1) @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1) + ; CHECK: %b1 = icmp eq ptr addrspace(1) null, null + %b1 = icmp eq ptr addrspace(1) %c2, null + ret i1 %b1 +} + +; Casting a local pointer to a private pointer. +; The uses of c2 will be replaced with null. +; CHECK: @kernel6(ptr addrspace(3) %local) +define i1 @kernel6(ptr addrspace(3) %local) { + %c1 = addrspacecast ptr addrspace(3) %local to ptr addrspace(4) + %c2 = call ptr @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1) + ; CHECK: %b1 = icmp eq ptr null, null + %b1 = icmp eq ptr %c2, null + ret i1 %b1 +} + +; Casting a private pointer to a private pointer. +; The uses of c2 will be replaced with %private. +; CHECK: @kernel7(ptr %private) +define i1 @kernel7(ptr %private) { + %c1 = addrspacecast ptr %private to ptr addrspace(4) + %c2 = call ptr @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1) + ; CHECK: %b1 = icmp eq ptr %private, null + %b1 = icmp eq ptr %c2, null + ret i1 %b1 +} + +; Casting a private pointer to a global pointer. +; The uses of c2 will be replaced with null. +; CHECK: @kernel8(ptr %private) +define i1 @kernel8(ptr %private) { + %c1 = addrspacecast ptr %private to ptr addrspace(4) + %c2 = call ptr addrspace(1) @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1) + ; CHECK: %b1 = icmp eq ptr addrspace(1) null, null + %b1 = icmp eq ptr addrspace(1) %c2, null + ret i1 %b1 +} + +; Casting a private pointer to a local pointer. +; The uses of c2 will be replaced with null. +; CHECK: @kernel9(ptr %private) +define i1 @kernel9(ptr %private) { + %c1 = addrspacecast ptr %private to ptr addrspace(4) + %c2 = call ptr addrspace(3) @llvm.spv.generic.cast.to.ptr.explicit(ptr addrspace(4) %c1) + ; CHECK: %b1 = icmp eq ptr addrspace(3) null, null + %b1 = icmp eq ptr addrspace(3) %c2, null + ret i1 %b1 +}