From 9e06e8d075bf4163e6df228cb5705c0384484490 Mon Sep 17 00:00:00 2001 From: Andy Kaylor Date: Fri, 8 Aug 2025 13:18:30 -0700 Subject: [PATCH] [CIR] Use cir.get_vptr to initialize vptr members Previously, following classic codegen, we were using a bitcast to get the address of the vptr member when initializing the vptr in a constructor. This change updates the handling to use cir.get_vptr instead. This is functionally equivalent because the vptr is always at offset zero for the CXXABIs we currently support, but using cir.get_vptr makes what we are doing more explicit. --- clang/lib/CIR/CodeGen/CIRGenClass.cpp | 15 +++----- clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp | 3 +- clang/test/CIR/CodeGen/multi-vtable.cpp | 6 +-- clang/test/CIR/CodeGen/vtable-rtti.cpp | 4 +- clang/test/CIR/CodeGen/vtt.cpp | 37 +++++++++++-------- 5 files changed, 34 insertions(+), 31 deletions(-) diff --git a/clang/lib/CIR/CodeGen/CIRGenClass.cpp b/clang/lib/CIR/CodeGen/CIRGenClass.cpp index 73b9ebbf2346..11a256e7de10 100644 --- a/clang/lib/CIR/CodeGen/CIRGenClass.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenClass.cpp @@ -795,10 +795,10 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc, } // Apply the offsets. - Address VTableField = LoadCXXThisAddress(); + Address ClassAddr = LoadCXXThisAddress(); if (!NonVirtualOffset.isZero() || VirtualOffset) { - VTableField = ApplyNonVirtualAndVirtualOffset( - loc, *this, VTableField, NonVirtualOffset, VirtualOffset, + ClassAddr = ApplyNonVirtualAndVirtualOffset( + loc, *this, ClassAddr, NonVirtualOffset, VirtualOffset, Vptr.VTableClass, Vptr.NearestVBase, BaseValueTy); } @@ -806,13 +806,10 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc, // // vtable field is derived from `this` pointer, therefore they should be in // the same addr space. - // TODO(cir): We should be using cir.get_vptr rather than a bitcast to get - // the vptr field, but the call to ApplyNonVirtualAndVirtualOffset - // will also need to be adjusted. That should probably be using - // cir.base_class_addr. assert(!cir::MissingFeatures::addressSpace()); - VTableField = builder.createElementBitCast(loc, VTableField, - VTableAddressPoint.getType()); + auto VTablePtr = builder.create( + loc, builder.getPtrToVPtrType(), ClassAddr.getPointer()); + Address VTableField = Address(VTablePtr, ClassAddr.getAlignment()); auto storeOp = builder.createStore(loc, VTableAddressPoint, VTableField); TBAAAccessInfo TBAAInfo = CGM.getTBAAVTablePtrAccessInfo(VTableAddressPoint.getType()); diff --git a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp index e595d23497e3..a60afa1f13bb 100644 --- a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp @@ -992,7 +992,8 @@ mlir::Value CIRGenItaniumCXXABI::getVTableAddressPointInStructorWithVTT( VTTPtr = CGF.getBuilder().createVTTAddrPoint(Loc, VTTPtr.getType(), VTTPtr, VirtualPointerIndex); // And load the address point from the VTT. - return CGF.getBuilder().createAlignedLoad(Loc, CGF.VoidPtrTy, VTTPtr, + auto VPtrType = cir::VPtrType::get(CGF.getBuilder().getContext()); + return CGF.getBuilder().createAlignedLoad(Loc, VPtrType, VTTPtr, CGF.getPointerAlign()); } diff --git a/clang/test/CIR/CodeGen/multi-vtable.cpp b/clang/test/CIR/CodeGen/multi-vtable.cpp index 0545075b7f34..7f3e9a4626f4 100644 --- a/clang/test/CIR/CodeGen/multi-vtable.cpp +++ b/clang/test/CIR/CodeGen/multi-vtable.cpp @@ -40,7 +40,7 @@ int main() { // CIR: cir.func linkonce_odr @_ZN6MotherC2Ev(%arg0: !cir.ptr // CIR: %{{[0-9]+}} = cir.vtable.address_point(@_ZTV6Mother, address_point = ) : !cir.vptr -// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr), !cir.ptr +// CIR: %{{[0-9]+}} = cir.vtable.get_vptr %{{[0-9]+}} : !cir.ptr -> !cir.ptr // CIR: cir.store{{.*}} %2, %{{[0-9]+}} : !cir.vptr, !cir.ptr // CIR: cir.return // CIR: } @@ -53,11 +53,11 @@ int main() { // CIR: cir.func linkonce_odr @_ZN5ChildC2Ev(%arg0: !cir.ptr // CIR: %{{[0-9]+}} = cir.vtable.address_point(@_ZTV5Child, address_point = ) : !cir.vptr -// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr), !cir.ptr +// CIR: %{{[0-9]+}} = cir.vtable.get_vptr %1 : !cir.ptr -> !cir.ptr // CIR: cir.store{{.*}} %{{[0-9]+}}, %{{[0-9]+}} : !cir.vptr, !cir.ptr // CIR: %{{[0-9]+}} = cir.vtable.address_point(@_ZTV5Child, address_point = ) : !cir.vptr // CIR: %7 = cir.base_class_addr %1 : !cir.ptr nonnull [8] -> !cir.ptr -// CIR: %8 = cir.cast(bitcast, %7 : !cir.ptr), !cir.ptr +// CIR: %8 = cir.vtable.get_vptr %7 : !cir.ptr -> !cir.ptr // CIR: cir.store{{.*}} %{{[0-9]+}}, %{{[0-9]+}} : !cir.vptr, !cir.ptr // CIR: cir.return // CIR: } diff --git a/clang/test/CIR/CodeGen/vtable-rtti.cpp b/clang/test/CIR/CodeGen/vtable-rtti.cpp index 711312a11911..684608f3d52a 100644 --- a/clang/test/CIR/CodeGen/vtable-rtti.cpp +++ b/clang/test/CIR/CodeGen/vtable-rtti.cpp @@ -46,7 +46,7 @@ class B : public A // CHECK: %2 = cir.base_class_addr %1 : !cir.ptr nonnull [0] -> !cir.ptr // CHECK: cir.call @_ZN1AC2Ev(%2) : (!cir.ptr) -> () // CHECK: %3 = cir.vtable.address_point(@_ZTV1B, address_point = ) : !cir.vptr -// CHECK: %4 = cir.cast(bitcast, %1 : !cir.ptr), !cir.ptr +// CHECK: %4 = cir.vtable.get_vptr %1 : !cir.ptr -> !cir.ptr // CHECK: cir.store{{.*}} %3, %4 : !cir.vptr, !cir.ptr // CHECK: cir.return // CHECK: } @@ -74,7 +74,7 @@ class B : public A // CHECK: cir.store{{.*}} %arg0, %0 : !cir.ptr, !cir.ptr> // CHECK: %1 = cir.load %0 : !cir.ptr>, !cir.ptr // CHECK: %2 = cir.vtable.address_point(@_ZTV1A, address_point = ) : !cir.vptr -// CHECK: %3 = cir.cast(bitcast, %1 : !cir.ptr), !cir.ptr +// CHECK: %3 = cir.vtable.get_vptr %1 : !cir.ptr -> !cir.ptr // CHECK: cir.store{{.*}} %2, %3 : !cir.vptr, !cir.ptr // CHECK: cir.return // CHECK: } diff --git a/clang/test/CIR/CodeGen/vtt.cpp b/clang/test/CIR/CodeGen/vtt.cpp index 4df9d9b95002..4db8face56a4 100644 --- a/clang/test/CIR/CodeGen/vtt.cpp +++ b/clang/test/CIR/CodeGen/vtt.cpp @@ -39,7 +39,7 @@ int f() { // Class A constructor // CIR: cir.func linkonce_odr @_ZN1AC2Ev(%arg0: !cir.ptr // CIR: %{{[0-9]+}} = cir.vtable.address_point(@_ZTV1A, address_point = ) : !cir.vptr -// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr), !cir.ptr +// CIR: %{{[0-9]+}} = cir.vtable.get_vptr %{{.*}} : !cir.ptr -> !cir.ptr // CIR: cir.store{{.*}} %{{[0-9]+}}, %{{[0-9]+}} : !cir.vptr, !cir.ptr // CIR: } @@ -51,12 +51,14 @@ int f() { // Class B constructor // CIR: cir.func linkonce_odr @_ZN1BC2Ev(%arg0: !cir.ptr // CIR: %{{[0-9]+}} = cir.vtt.address_point %{{[0-9]+}} : !cir.ptr>, offset = 0 -> !cir.ptr> -// CIR: %{{[0-9]+}} = cir.load align(8) %{{[0-9]+}} : !cir.ptr>, !cir.ptr -// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr), !cir.ptr> -// CIR: cir.store{{.*}} %{{[0-9]+}}, %{{[0-9]+}} : !cir.ptr, !cir.ptr> +// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr>), !cir.ptr +// CIR: %{{[0-9]+}} = cir.load align(8) %{{[0-9]+}} : !cir.ptr, !cir.vptr +// CIR: %{{[0-9]+}} = cir.vtable.get_vptr %{{[0-9]+}} : !cir.ptr -> !cir.ptr +// CIR: cir.store{{.*}} %{{[0-9]+}}, %{{[0-9]+}} : !cir.vptr, !cir.ptr // CIR: %{{[0-9]+}} = cir.vtt.address_point %{{[0-9]+}} : !cir.ptr>, offset = 1 -> !cir.ptr> -// CIR: %{{[0-9]+}} = cir.load align(8) %{{[0-9]+}} : !cir.ptr>, !cir.ptr +// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr>), !cir.ptr +// CIR: %{{[0-9]+}} = cir.load align(8) %{{[0-9]+}} : !cir.ptr, !cir.vptr // CIR: %{{[0-9]+}} = cir.vtable.get_vptr %{{.*}} : !cir.ptr -> !cir.ptr // CIR: %{{[0-9]+}} = cir.load{{.*}} %{{[0-9]+}} : !cir.ptr, !cir.vptr // CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.vptr), !cir.ptr @@ -65,8 +67,8 @@ int f() { // CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr), !cir.ptr // CIR: %{{[0-9]+}} = cir.load{{.*}} %{{[0-9]+}} : !cir.ptr, !s64i // CIR: %{{[0-9]+}} = cir.ptr_stride(%{{[0-9]+}} : !cir.ptr, %{{[0-9]+}} : !s64i), !cir.ptr -// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr), !cir.ptr> -// CIR: cir.store{{.*}} %{{[0-9]+}}, %{{[0-9]+}} : !cir.ptr, !cir.ptr> +// CIR: %{{[0-9]+}} = cir.vtable.get_vptr %{{[0-9]+}} : !cir.ptr -> !cir.ptr +// CIR: cir.store{{.*}} %{{[0-9]+}}, %{{[0-9]+}} : !cir.vptr, !cir.ptr // CIR: } // LLVM-LABEL: @_ZN1BC2Ev @@ -83,12 +85,14 @@ int f() { // Class C constructor // CIR: cir.func linkonce_odr @_ZN1CC2Ev(%arg0: !cir.ptr // CIR: %{{[0-9]+}} = cir.vtt.address_point %{{[0-9]+}} : !cir.ptr>, offset = 0 -> !cir.ptr> -// CIR: %{{[0-9]+}} = cir.load align(8) %{{[0-9]+}} : !cir.ptr>, !cir.ptr -// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr), !cir.ptr> -// CIR: cir.store{{.*}} %{{[0-9]+}}, %{{[0-9]+}} : !cir.ptr, !cir.ptr> +// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr>), !cir.ptr +// CIR: %{{[0-9]+}} = cir.load align(8) %{{[0-9]+}} : !cir.ptr, !cir.vptr +// CIR: %{{[0-9]+}} = cir.vtable.get_vptr %{{[0-9]+}} : !cir.ptr -> !cir.ptr +// CIR: cir.store{{.*}} %{{[0-9]+}}, %{{[0-9]+}} : !cir.vptr, !cir.ptr // CIR: %{{[0-9]+}} = cir.vtt.address_point %{{[0-9]+}} : !cir.ptr>, offset = 1 -> !cir.ptr> -// CIR: %{{[0-9]+}} = cir.load align(8) %{{[0-9]+}} : !cir.ptr>, !cir.ptr +// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr>), !cir.ptr +// CIR: %{{[0-9]+}} = cir.load align(8) %{{[0-9]+}} : !cir.ptr, !cir.vptr // CIR: %{{[0-9]+}} = cir.vtable.get_vptr %{{[0-9]+}} : !cir.ptr -> !cir.ptr // CIR: %{{[0-9]+}} = cir.load{{.*}} %{{[0-9]+}} : !cir.ptr, !cir.vptr // CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.vptr), !cir.ptr @@ -97,8 +101,9 @@ int f() { // CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr), !cir.ptr // CIR: %{{[0-9]+}} = cir.load{{.*}} %{{[0-9]+}} : !cir.ptr, !s64i // CIR: %{{[0-9]+}} = cir.ptr_stride(%{{[0-9]+}} : !cir.ptr, %{{[0-9]+}} : !s64i), !cir.ptr -// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr), !cir.ptr> -// CIR: cir.store{{.*}} %{{[0-9]+}}, %{{[0-9]+}} : !cir.ptr, !cir.ptr> +// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr), !cir.ptr +// CIR: %{{[0-9]+}} = cir.vtable.get_vptr %{{[0-9]+}} : !cir.ptr -> !cir.ptr +// CIR: cir.store{{.*}} %{{[0-9]+}}, %{{[0-9]+}} : !cir.vptr, !cir.ptr // CIR: } // Class D constructor @@ -118,17 +123,17 @@ int f() { // CIR: cir.call @_ZN1CC2Ev(%[[C_PTR]], %[[VTT_D_TO_C]]) : (!cir.ptr, !cir.ptr>) -> () // CIR: %{{[0-9]+}} = cir.vtable.address_point(@_ZTV1D, address_point = ) : !cir.vptr -// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr), !cir.ptr +// CIR: %{{[0-9]+}} = cir.vtable.get_vptr %{{[0-9]+}} : !cir.ptr -> !cir.ptr // CIR: cir.store{{.*}} %{{[0-9]+}}, %{{[0-9]+}} : !cir.vptr, !cir.ptr // CIR: %{{[0-9]+}} = cir.vtable.address_point(@_ZTV1D, address_point = ) : !cir.vptr // CIR: %{{[0-9]+}} = cir.base_class_addr %{{[0-9]+}} : !cir.ptr nonnull [40] -> !cir.ptr -// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr), !cir.ptr +// CIR: %{{[0-9]+}} = cir.vtable.get_vptr %{{[0-9]+}} : !cir.ptr -> !cir.ptr // CIR: cir.store{{.*}} %{{[0-9]+}}, %{{[0-9]+}} : !cir.vptr, !cir.ptr // CIR: %{{[0-9]+}} = cir.vtable.address_point(@_ZTV1D, address_point = ) : !cir.vptr // CIR: cir.base_class_addr %{{[0-9]+}} : !cir.ptr nonnull [16] -> !cir.ptr -// CIR: cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr), !cir.ptr +// CIR: %{{[0-9]+}} = cir.vtable.get_vptr %{{[0-9]+}} : !cir.ptr -> !cir.ptr // CIR: cir.store{{.*}} %{{[0-9]+}}, %{{[0-9]+}} : !cir.vptr, !cir.ptr // CIR: cir.return // CIR: }