From ec95886c04aef79993e94c35a430b3a88152b891 Mon Sep 17 00:00:00 2001 From: Schrodinger ZHU Yifan Date: Thu, 14 Nov 2024 10:48:07 -0500 Subject: [PATCH] [gccjit] use cached record type and generalize with union support --- .../Translation/TranslateToGCCJIT.h | 32 +++++++++++++++++++ src/GCCJITOps.cpp | 1 - src/Translation/TranslateToGCCJIT.cpp | 32 ++++++++++--------- 3 files changed, 49 insertions(+), 16 deletions(-) diff --git a/include/mlir-gccjit/Translation/TranslateToGCCJIT.h b/include/mlir-gccjit/Translation/TranslateToGCCJIT.h index ad94e11..bcf74fc 100644 --- a/include/mlir-gccjit/Translation/TranslateToGCCJIT.h +++ b/include/mlir-gccjit/Translation/TranslateToGCCJIT.h @@ -16,6 +16,7 @@ #define MLIR_GCCJIT_TRANSLATION_TRANSLATETOGCCJIT_H #include +#include #include #include #include @@ -94,6 +95,9 @@ class GCCJITTranslation { gcc_jit_field *operator[](size_t index) const { return gcc_jit_struct_get_field(structHandle, index); } + gcc_jit_type *getTypeHandle() const { + return gcc_jit_struct_as_type(structHandle); + } }; class UnionEntry { @@ -106,6 +110,30 @@ class GCCJITTranslation { gcc_jit_type *getRawHandle() const { return unionHandle; } size_t getFieldCount() const { return fields.size(); } gcc_jit_field *operator[](size_t index) const { return fields[index]; } + gcc_jit_type *getTypeHandle() const { return unionHandle; } + }; + + class GCCJITRecord : public llvm::PointerUnion { + public: + using PointerUnion::PointerUnion; + gcc_jit_struct *getAsStruct() const { + return this->get()->getRawHandle(); + } + gcc_jit_type *getAsType() const { + return this->is() ? get()->getTypeHandle() + : get()->getRawHandle(); + } + gcc_jit_field *operator[](size_t index) const { + if (this->is()) + return get()->operator[](index); + return get()->operator[](index); + } + size_t getFieldCount() const { + return this->is() ? get()->getFieldCount() + : get()->getFieldCount(); + } + bool isStruct() const { return this->is(); } + bool isUnion() const { return this->is(); } }; gcc_jit_context *ctxt; @@ -121,8 +149,12 @@ class GCCJITTranslation { void translateGlobalInitializers(); void translateFunctions(); +private: StructEntry &getOrCreateStructEntry(StructType type); UnionEntry &getOrCreateUnionEntry(UnionType type); + +public: + GCCJITRecord getOrCreateRecordEntry(Type type); }; llvm::Expected translateModuleToGCCJIT(ModuleOp op); diff --git a/src/GCCJITOps.cpp b/src/GCCJITOps.cpp index 5912254..acd6581 100644 --- a/src/GCCJITOps.cpp +++ b/src/GCCJITOps.cpp @@ -43,7 +43,6 @@ #include #include -#include "mlir-gccjit/IR/GCCJITDialect.h" #include "mlir-gccjit/IR/GCCJITOps.h" #include "mlir-gccjit/IR/GCCJITOpsEnums.h" #include "mlir-gccjit/IR/GCCJITTypes.h" diff --git a/src/Translation/TranslateToGCCJIT.cpp b/src/Translation/TranslateToGCCJIT.cpp index 0a7ec63..c84f931 100644 --- a/src/Translation/TranslateToGCCJIT.cpp +++ b/src/Translation/TranslateToGCCJIT.cpp @@ -15,7 +15,6 @@ #include "mlir-gccjit/Translation/TranslateToGCCJIT.h" #include -#include #include #include @@ -607,33 +606,28 @@ gcc_jit_lvalue *RegionVisitor::visitExprWithoutCache(DerefOp op) { Expr RegionVisitor::visitExprWithoutCache(AccessFieldOp op) { auto composite = visitExpr(op.getComposite()); auto *loc = getTranslator().getLocation(op.getLoc()); - auto *compositeTy = getTranslator().convertType(op.getComposite().getType()); + auto compositeTy = + getTranslator().getOrCreateRecordEntry(op.getComposite().getType()); auto index = op.getField().getZExtValue(); - // TODO: support union and query from cache instead - auto *structure = gcc_jit_type_is_struct(compositeTy); - if (!structure) - llvm_unreachable("expected struct type"); - auto *field = gcc_jit_struct_get_field(structure, index); + auto *field = compositeTy[index]; if (isa(op.getType())) return gcc_jit_lvalue_access_field(composite, loc, field); return gcc_jit_rvalue_access_field(composite, loc, field); } gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(NewStructOp op) { - auto *rawStructTy = getTranslator().convertType(op.getType()); - auto *structTy = gcc_jit_type_is_struct(rawStructTy); - if (!structTy) + auto record = getTranslator().getOrCreateRecordEntry(op.getType()); + if (!record.isStruct()) llvm_unreachable("expected struct type"); llvm::SmallVector fields; llvm::SmallVector values; for (auto field : op.getIndices()) - fields.push_back( - gcc_jit_struct_get_field(structTy, static_cast(field))); + fields.push_back(record[field]); visitExprAsRValue(op.getElements(), values); auto *loc = getTranslator().getLocation(op.getLoc()); - return gcc_jit_context_new_struct_constructor(getContext(), loc, rawStructTy, - values.size(), fields.data(), - values.data()); + return gcc_jit_context_new_struct_constructor( + getContext(), loc, record.getAsType(), values.size(), fields.data(), + values.data()); } gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(NewArrayOp op) { @@ -1066,4 +1060,12 @@ void GCCJITContextDeleter::operator()(gcc_jit_context *ctxt) const { gcc_jit_context_release(ctxt); } +GCCJITTranslation::GCCJITRecord +GCCJITTranslation::getOrCreateRecordEntry(Type type) { + return llvm::TypeSwitch(type) + .Case([&](StructType t) { return &getOrCreateStructEntry(t); }) + .Case([&](UnionType t) { return &getOrCreateUnionEntry(t); }) + .Default( + [&](Type) -> GCCJITRecord { llvm_unreachable("unexpected type"); }); +} } // namespace mlir::gccjit