Skip to content

Commit

Permalink
[gccjit] use cached record type and generalize with union support
Browse files Browse the repository at this point in the history
  • Loading branch information
SchrodingerZhu committed Nov 14, 2024
1 parent f20a2a8 commit ec95886
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 16 deletions.
32 changes: 32 additions & 0 deletions include/mlir-gccjit/Translation/TranslateToGCCJIT.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define MLIR_GCCJIT_TRANSLATION_TRANSLATETOGCCJIT_H

#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/PointerUnion.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/MLIRContext.h>
Expand Down Expand Up @@ -94,6 +95,9 @@ class GCCJITTranslation {
gcc_jit_field *operator[](size_t index) const {
return gcc_jit_struct_get_field(structHandle, index);
}
gcc_jit_type *getTypeHandle() const {
return gcc_jit_struct_as_type(structHandle);
}
};

class UnionEntry {
Expand All @@ -106,6 +110,30 @@ class GCCJITTranslation {
gcc_jit_type *getRawHandle() const { return unionHandle; }
size_t getFieldCount() const { return fields.size(); }
gcc_jit_field *operator[](size_t index) const { return fields[index]; }
gcc_jit_type *getTypeHandle() const { return unionHandle; }
};

class GCCJITRecord : public llvm::PointerUnion<UnionEntry *, StructEntry *> {
public:
using PointerUnion::PointerUnion;
gcc_jit_struct *getAsStruct() const {
return this->get<StructEntry *>()->getRawHandle();
}
gcc_jit_type *getAsType() const {
return this->is<StructEntry *>() ? get<StructEntry *>()->getTypeHandle()
: get<UnionEntry *>()->getRawHandle();
}
gcc_jit_field *operator[](size_t index) const {
if (this->is<StructEntry *>())
return get<StructEntry *>()->operator[](index);
return get<UnionEntry *>()->operator[](index);
}
size_t getFieldCount() const {
return this->is<StructEntry *>() ? get<StructEntry *>()->getFieldCount()
: get<UnionEntry *>()->getFieldCount();
}
bool isStruct() const { return this->is<StructEntry *>(); }
bool isUnion() const { return this->is<UnionEntry *>(); }
};

gcc_jit_context *ctxt;
Expand All @@ -121,8 +149,12 @@ class GCCJITTranslation {
void translateGlobalInitializers();
void translateFunctions();

private:
StructEntry &getOrCreateStructEntry(StructType type);
UnionEntry &getOrCreateUnionEntry(UnionType type);

public:
GCCJITRecord getOrCreateRecordEntry(Type type);
};

llvm::Expected<GCCJITContext> translateModuleToGCCJIT(ModuleOp op);
Expand Down
1 change: 0 additions & 1 deletion src/GCCJITOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
#include <mlir/Support/LLVM.h>
#include <mlir/Support/LogicalResult.h>

#include "mlir-gccjit/IR/GCCJITDialect.h"
#include "mlir-gccjit/IR/GCCJITOps.h"
#include "mlir-gccjit/IR/GCCJITOpsEnums.h"
#include "mlir-gccjit/IR/GCCJITTypes.h"
Expand Down
32 changes: 17 additions & 15 deletions src/Translation/TranslateToGCCJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "mlir-gccjit/Translation/TranslateToGCCJIT.h"

#include <algorithm>
#include <cstddef>
#include <utility>

#include <llvm/ADT/SmallVector.h>
Expand Down Expand Up @@ -607,33 +606,28 @@ gcc_jit_lvalue *RegionVisitor::visitExprWithoutCache(DerefOp op) {
Expr RegionVisitor::visitExprWithoutCache(AccessFieldOp op) {
auto composite = visitExpr(op.getComposite());
auto *loc = getTranslator().getLocation(op.getLoc());
auto *compositeTy = getTranslator().convertType(op.getComposite().getType());
auto compositeTy =
getTranslator().getOrCreateRecordEntry(op.getComposite().getType());
auto index = op.getField().getZExtValue();
// TODO: support union and query from cache instead
auto *structure = gcc_jit_type_is_struct(compositeTy);
if (!structure)
llvm_unreachable("expected struct type");
auto *field = gcc_jit_struct_get_field(structure, index);
auto *field = compositeTy[index];
if (isa<LValueType>(op.getType()))
return gcc_jit_lvalue_access_field(composite, loc, field);
return gcc_jit_rvalue_access_field(composite, loc, field);
}

gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(NewStructOp op) {
auto *rawStructTy = getTranslator().convertType(op.getType());
auto *structTy = gcc_jit_type_is_struct(rawStructTy);
if (!structTy)
auto record = getTranslator().getOrCreateRecordEntry(op.getType());
if (!record.isStruct())
llvm_unreachable("expected struct type");
llvm::SmallVector<gcc_jit_field *> fields;
llvm::SmallVector<gcc_jit_rvalue *> values;
for (auto field : op.getIndices())
fields.push_back(
gcc_jit_struct_get_field(structTy, static_cast<size_t>(field)));
fields.push_back(record[field]);
visitExprAsRValue(op.getElements(), values);
auto *loc = getTranslator().getLocation(op.getLoc());
return gcc_jit_context_new_struct_constructor(getContext(), loc, rawStructTy,
values.size(), fields.data(),
values.data());
return gcc_jit_context_new_struct_constructor(
getContext(), loc, record.getAsType(), values.size(), fields.data(),
values.data());
}

gcc_jit_rvalue *RegionVisitor::visitExprWithoutCache(NewArrayOp op) {
Expand Down Expand Up @@ -1066,4 +1060,12 @@ void GCCJITContextDeleter::operator()(gcc_jit_context *ctxt) const {
gcc_jit_context_release(ctxt);
}

GCCJITTranslation::GCCJITRecord
GCCJITTranslation::getOrCreateRecordEntry(Type type) {
return llvm::TypeSwitch<Type, GCCJITRecord>(type)
.Case([&](StructType t) { return &getOrCreateStructEntry(t); })
.Case([&](UnionType t) { return &getOrCreateUnionEntry(t); })
.Default(
[&](Type) -> GCCJITRecord { llvm_unreachable("unexpected type"); });
}
} // namespace mlir::gccjit

0 comments on commit ec95886

Please sign in to comment.