Skip to content

[CIR][NFC] Add more examples for vtable-related ops #1782

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 170 additions & 16 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2607,14 +2607,57 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point", [
the vtable group (as specified by Itanium ABI), and `address_point.offset`
(address point index) the actual address point within that vtable.

The `name` argument to this operation must be the name of a C++ vtable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having a little trouble understanding this from the examples.
In the example: cir.vtable.address_point(@_ZTV1Base, ...
Is @_ZTV1Base the name of the C++ vtable? Or is it a reference to cir.global @_ZTV1Base? Or is @_ZTV1Base considered to be the 'CIR name' for the vtable (just having an @ prepended?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's kind of all of those. I believe it is technically an mlir::FlatSymbolRefAttr. We use it as a string to find the global. This is one of the most inefficient bits of MLIR I've come across. In LLVM IR, this would be a direct pointer to the global, but that doesn't seem to be happening in MLIR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense to me. It is weird that MLIR requires symbols in these cases. Thanks for the clarification.

object. The return value is the address of the virtual function pointer
array within the vtable (the vptr). This value will be written to the
vptr member of a dynamic class by the constructor of the class. Derived
classes have their own vtable, which is used to obtain the vptr stored
in instances of the derived class.

The return type is always `!cir.vptr`.

Example:
Examples:

```C++
struct Base {
Base();
virtual void f();
};
struct Derived : public Base {
Derived();
}
```

```mlir
cir.global linkonce_odr @_ZTV1B = ...
!rec_Base = !cir.record<struct "Base" {!cir.vptr}
!rec_Derived = !cir.record<struct "Derived" {!rec_Base}
...
// VTable for Base
cir.global linkonce_odr @_ZTV1Base = ...
...
%3 = cir.vtable.address_point(@_ZTV1B,
// Constructor for Base
cir.func dso_local @_ZN4BaseC2Ev ...
...
%2 = cir.vtable.address_point(@_ZTV1Base,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A comment here saying // Get the address where the vptr points, maybe?
Otherwise you have to sit with it for a bit of time to understand.

address_point = <index = 0, offset = 2>) : !cir.vptr
// The vptr is at element zero.
%3 = cir.cast(bitcast, %1 : !cir.ptr<!rec_Base>), cir.ptr<!cir.vptr>>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason this isn't a cir.vtable.get_vptr?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's because of how this gets handled by clang, maybe we can raise it during canonicalization in a later PR?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm currently working on a PR to introduce cir.vtable.get_vptr here. I have a comment in the code already saying it should be that, but I wasn't sure how that needed to interact with virtual and non-virtual base offsets. I think I have it worked out now though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wicked, sounds good!

cir.store align(8) %2, %3 : !cir.vptr, !cir.ptr<!cir.vptr>
...
// VTable for Derived
cir.global linkonce_odr @_ZTV7Derived = ...
...
// Constructor for Derived
cir.func dso_local @_ZN7DerivedC2Ev ...
// Get the address of Base within this Derived instance
%2 = cir.base_class_addr %1 : !cir.ptr<!rec_Derived> nonnull [0]
cir.call @_ZN4BaseC2Ev(%2)
%3 = cir.vtable.address_point(@_ZTV7Derived,
address_point = <index = 0, offset = 2>) : !cir.vptr
// The vptr is still at the start of the object in this case
%4 = cir.cast(bitcast, %1 : !cir.ptr<!rec_Derived>), !cir.ptr<!cir.vptr>
// This overwrites the vptr that was stored in the Base constructor call
cir.store align(8) %3, %4 : !cir.vptr, !cir.ptr<!cir.vptr>
```
}];

Expand Down Expand Up @@ -2649,9 +2692,41 @@ def CIR_VTableGetVPtrOp : CIR_Op<"vtable.get_vptr", [Pure]> {
The return type is always `!cir.ptr<!cir.vptr>`.

Example:
```C++
struct S {
virtual void f1();
virtual void f2();
};
void f3(S *s) {
s->f2();
}
```

```mlir
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
%3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
// VTable for S
cir.global external @_ZTV1S = #cir.vtable<{
#cir.const_array<[
// Offset to the base object
#cir.ptr<null> : !cir.ptr<!u8i>,
// Type info for S
#cir.global_view<@_ZTI1S> : !cir.ptr<!u8i>,
// Pointer to S::f1
#cir.global_view<@_ZN1S2f1Ev> : !cir.ptr<!u8i>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe append a comment to the left/right of this line saying vptr --> , or on the comment line above, make a note "vptr points here".
This will visually show where the vptr points.

// Pointer to S::f2
#cir.global_view<@_ZN1S2f2Ev> : !cir.ptr<!u8i>
]> : !cir.array<!cir.ptr<!u8i> x 4>}> ...
// f3()
cir.func dso_local @_Z2f3P1S(%s: !cir.ptr<!rec_S>) {
// Get the vptr -- This points to offset 2 in the vtable.
%1 = cir.vtable.get_vptr %s : !cir.ptr<!rec_S> -> !cir.ptr<!cir.vptr>
%2 = cir.load align(8) %2 : !cir.ptr<!cir.vptr>, !cir.vptr
// Get the address of b->f2() -- may be Base::f2() or Derived::f2()
%3 = cir.vtable.get_virtual_fn_addr %2[1] : !cir.vptr
-> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>>
%4 = cir.load align(8) %3
: !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>>,
!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>
cir.call %4(%b)
```
}];

Expand Down Expand Up @@ -2688,19 +2763,98 @@ def CIR_VTableGetVirtualFnAddrOp : CIR_Op<"vtable.get_virtual_fn_addr", [

The return type is a pointer-to-pointer to the function type.

Example:
Example 1:
Suppose we have two classes, Base and Derived, where Derived overrides
virtual functions that were defined in Base. When a pointer to a Base
object is used to call one of these function, we may not know at compile
time whether it points to an instance of Base or an instance of Derived.
The compiler does not need to know. It will load the vptr from the object
and use that to get the address of the correct function to call. The
vptr will have been initialized in the object's constructor to point to
the correct vtable for the object being instantiated.
```C++
// In this example, when f3 is called, we don't know at compile-time
// whether
struct Base {
virtual void f1();
virtual void f2();
};
struct Derived : public Base {
void f1() override;
void f2() override;
};
void f3(Base *b) {
b->f2();
}
```

```mlir
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
%3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
%4 = cir.load %3 : !cir.ptr<!cir.vptr>, !cir.vptr
%5 = cir.vtable.get_virtual_fn_addr %4[2] : !cir.vptr
-> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>>
%6 = cir.load align(8) %5 : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>)
-> !s32i>>>,
!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>
%7 = cir.call %6(%2) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>,
!cir.ptr<!rec_C>) -> !s32i
// VTable for Base
cir.global external @_ZTV4Base = #cir.vtable<{
#cir.const_array<[
#cir.ptr<null> : !cir.ptr<!u8i>,
#cir.global_view<@_ZTI4Base> : !cir.ptr<!u8i>,
#cir.global_view<@_ZN4Base2f1Ev> : !cir.ptr<!u8i>,
#cir.global_view<@_ZN4Base2f2Ev> : !cir.ptr<!u8i>
]> : !cir.array<!cir.ptr<!u8i> x 4>}> ...
// VTable for Derived
cir.global external @_ZTV7Derived = #cir.vtable<{
#cir.const_array<[
#cir.ptr<null> : !cir.ptr<!u8i>,
#cir.global_view<@_ZTI7Derived> : !cir.ptr<!u8i>,
#cir.global_view<@_ZN7Derived2f1Ev> : !cir.ptr<!u8i>,
#cir.global_view<@_ZN7Derived2f2Ev> : !cir.ptr<!u8i>
]> : !cir.array<!cir.ptr<!u8i> x 4>}> ...
// f3()
cir.func dso_local @_Z2f3P4Base(%b: !cir.ptr<!rec_Base>)
// Get the vptr
%1 = cir.vtable.get_vptr %b : !cir.ptr<!rec_Base> -> !cir.ptr<!cir.vptr>
%2 = cir.load align(8) %2 : !cir.ptr<!cir.vptr>, !cir.vptr
// Get the address of b->f2() -- may be Base::f2() or Derived::f2()
%3 = cir.vtable.get_virtual_fn_addr %2[1] : !cir.vptr
-> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>>
%4 = cir.load align(8) %3
: !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>>,
!cir.ptr<!cir.func<(!cir.ptr<!rec_Base>)>>
cir.call %4(%b)
```

Example 2:
Consider the case of multiple inheritance, where Base1 and Base2 both
provide virtual functions and a third class, Derived, inherits from both
bases. When a pointer to a Derived is used to call a virtual function in
Base2, we must retrieve a pointer to the Base2 portion of the Derived object
and use that pointer to get the vptr for Base2 as a base class.
```C++
struct Base1 {
virtual void f1();
};
struct Base2 {
virtual void f2();
};
struct Derived : public Base1, Base2 { };
void f3(Derived *d) {
d->f2();
}
```

```mlir
!rec_Base1 = !cir.record<struct "Base1" {!cir.vptr}
!rec_Base2 = !cir.record<struct "Base2" {!cir.vptr}
!rec_Derived = !cir.record<struct "Derived" {!rec_Base1, !rec_Base2}
cir.func dso_local @_Z2f3P7Derived(%d: !cir.ptr<!rec_Derived>)
%2 = cir.base_class_addr %d : !cir.ptr<!rec_Derived> nonnull [8]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separate note, it's weird that the offset here is [8] instead of [1]. Can these byte offsets be variable?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're getting the offset from clang::ASTRecordLayout::getVBaseClassOffset or clang::ASTRecordLayout::getBaseClassOffset but I think it's ultimately CXXABI-dependent, so not necessarily a simple matter of counting base classes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. I assumed it was something coming from clang. I don't think this is an actual issue, was just a little odd.

-> !cir.ptr<!rec_Base2>
%3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_Base2> -> !cir.ptr<!cir.vptr>
%4 = cir.load align(8) %3 : !cir.ptr<!cir.vptr>, !cir.vptr
%5 = cir.vtable.get_virtual_fn_addr %4[0] : !cir.vptr
-> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_Base2>)>>>
%6 = cir.load align(8) %5
: !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_Base2>)>>>,
!cir.ptr<!cir.func<(!cir.ptr<!rec_Base2>)>>
cir.call %6(%2) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_Base2>)>>,
!cir.ptr<!rec_Base2>) -> ()
```
}];

let arguments = (ins
Expand Down
Loading