diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index f4d1c0850c71..be7347c34caf 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -2607,14 +2607,57 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point", [ the vtable group (as specified by Itanium ABI), and `address_point.offset` (address point index) the actual address point within that vtable. + The `name` argument to this operation must be the name of a C++ vtable + object. The return value is the address of the virtual function pointer + array within the vtable (the vptr). This value will be written to the + vptr member of a dynamic class by the constructor of the class. Derived + classes have their own vtable, which is used to obtain the vptr stored + in instances of the derived class. + The return type is always `!cir.vptr`. - Example: + Examples: + + ```C++ + struct Base { + Base(); + virtual void f(); + }; + struct Derived : public Base { + Derived(); + } + ``` + ```mlir - cir.global linkonce_odr @_ZTV1B = ... + !rec_Base = !cir.record) : !cir.vptr + // The vptr is at element zero. + %3 = cir.cast(bitcast, %1 : !cir.ptr), cir.ptr> + cir.store align(8) %2, %3 : !cir.vptr, !cir.ptr + ... + // VTable for Derived + cir.global linkonce_odr @_ZTV7Derived = ... + ... + // Constructor for Derived + cir.func dso_local @_ZN7DerivedC2Ev ... + // Get the address of Base within this Derived instance + %2 = cir.base_class_addr %1 : !cir.ptr nonnull [0] + cir.call @_ZN4BaseC2Ev(%2) + %3 = cir.vtable.address_point(@_ZTV7Derived, + address_point = ) : !cir.vptr + // The vptr is still at the start of the object in this case + %4 = cir.cast(bitcast, %1 : !cir.ptr), !cir.ptr + // This overwrites the vptr that was stored in the Base constructor call + cir.store align(8) %3, %4 : !cir.vptr, !cir.ptr ``` }]; @@ -2649,9 +2692,41 @@ def CIR_VTableGetVPtrOp : CIR_Op<"vtable.get_vptr", [Pure]> { The return type is always `!cir.ptr`. Example: + ```C++ + struct S { + virtual void f1(); + virtual void f2(); + }; + void f3(S *s) { + s->f2(); + } + ``` + ```mlir - %2 = cir.load %0 : !cir.ptr>, !cir.ptr - %3 = cir.vtable.get_vptr %2 : !cir.ptr -> !cir.ptr + // VTable for S + cir.global external @_ZTV1S = #cir.vtable<{ + #cir.const_array<[ + // Offset to the base object + #cir.ptr : !cir.ptr, + // Type info for S + #cir.global_view<@_ZTI1S> : !cir.ptr, + // Pointer to S::f1 + #cir.global_view<@_ZN1S2f1Ev> : !cir.ptr, + // Pointer to S::f2 + #cir.global_view<@_ZN1S2f2Ev> : !cir.ptr + ]> : !cir.array x 4>}> ... + // f3() + cir.func dso_local @_Z2f3P1S(%s: !cir.ptr) { + // Get the vptr -- This points to offset 2 in the vtable. + %1 = cir.vtable.get_vptr %s : !cir.ptr -> !cir.ptr + %2 = cir.load align(8) %2 : !cir.ptr, !cir.vptr + // Get the address of b->f2() -- may be Base::f2() or Derived::f2() + %3 = cir.vtable.get_virtual_fn_addr %2[1] : !cir.vptr + -> !cir.ptr)>>> + %4 = cir.load align(8) %3 + : !cir.ptr)>>>, + !cir.ptr)>> + cir.call %4(%b) ``` }]; @@ -2688,19 +2763,98 @@ def CIR_VTableGetVirtualFnAddrOp : CIR_Op<"vtable.get_virtual_fn_addr", [ The return type is a pointer-to-pointer to the function type. - Example: + Example 1: + Suppose we have two classes, Base and Derived, where Derived overrides + virtual functions that were defined in Base. When a pointer to a Base + object is used to call one of these function, we may not know at compile + time whether it points to an instance of Base or an instance of Derived. + The compiler does not need to know. It will load the vptr from the object + and use that to get the address of the correct function to call. The + vptr will have been initialized in the object's constructor to point to + the correct vtable for the object being instantiated. + ```C++ + // In this example, when f3 is called, we don't know at compile-time + // whether + struct Base { + virtual void f1(); + virtual void f2(); + }; + struct Derived : public Base { + void f1() override; + void f2() override; + }; + void f3(Base *b) { + b->f2(); + } + ``` + ```mlir - %2 = cir.load %0 : !cir.ptr>, !cir.ptr - %3 = cir.vtable.get_vptr %2 : !cir.ptr -> !cir.ptr - %4 = cir.load %3 : !cir.ptr, !cir.vptr - %5 = cir.vtable.get_virtual_fn_addr %4[2] : !cir.vptr - -> !cir.ptr) -> !s32i>>> - %6 = cir.load align(8) %5 : !cir.ptr) - -> !s32i>>>, - !cir.ptr) -> !s32i>> - %7 = cir.call %6(%2) : (!cir.ptr) -> !s32i>>, - !cir.ptr) -> !s32i + // VTable for Base + cir.global external @_ZTV4Base = #cir.vtable<{ + #cir.const_array<[ + #cir.ptr : !cir.ptr, + #cir.global_view<@_ZTI4Base> : !cir.ptr, + #cir.global_view<@_ZN4Base2f1Ev> : !cir.ptr, + #cir.global_view<@_ZN4Base2f2Ev> : !cir.ptr + ]> : !cir.array x 4>}> ... + // VTable for Derived + cir.global external @_ZTV7Derived = #cir.vtable<{ + #cir.const_array<[ + #cir.ptr : !cir.ptr, + #cir.global_view<@_ZTI7Derived> : !cir.ptr, + #cir.global_view<@_ZN7Derived2f1Ev> : !cir.ptr, + #cir.global_view<@_ZN7Derived2f2Ev> : !cir.ptr + ]> : !cir.array x 4>}> ... + // f3() + cir.func dso_local @_Z2f3P4Base(%b: !cir.ptr) + // Get the vptr + %1 = cir.vtable.get_vptr %b : !cir.ptr -> !cir.ptr + %2 = cir.load align(8) %2 : !cir.ptr, !cir.vptr + // Get the address of b->f2() -- may be Base::f2() or Derived::f2() + %3 = cir.vtable.get_virtual_fn_addr %2[1] : !cir.vptr + -> !cir.ptr)>>> + %4 = cir.load align(8) %3 + : !cir.ptr)>>>, + !cir.ptr)>> + cir.call %4(%b) ``` + + Example 2: + Consider the case of multiple inheritance, where Base1 and Base2 both + provide virtual functions and a third class, Derived, inherits from both + bases. When a pointer to a Derived is used to call a virtual function in + Base2, we must retrieve a pointer to the Base2 portion of the Derived object + and use that pointer to get the vptr for Base2 as a base class. + ```C++ + struct Base1 { + virtual void f1(); + }; + struct Base2 { + virtual void f2(); + }; + struct Derived : public Base1, Base2 { }; + void f3(Derived *d) { + d->f2(); + } + ``` + + ```mlir + !rec_Base1 = !cir.record) + %2 = cir.base_class_addr %d : !cir.ptr nonnull [8] + -> !cir.ptr + %3 = cir.vtable.get_vptr %2 : !cir.ptr -> !cir.ptr + %4 = cir.load align(8) %3 : !cir.ptr, !cir.vptr + %5 = cir.vtable.get_virtual_fn_addr %4[0] : !cir.vptr + -> !cir.ptr)>>> + %6 = cir.load align(8) %5 + : !cir.ptr)>>>, + !cir.ptr)>> + cir.call %6(%2) : (!cir.ptr)>>, + !cir.ptr) -> () + ``` }]; let arguments = (ins