Skip to content

Commit

Permalink
Refactor lower handlers to use overloads. (#4120)
Browse files Browse the repository at this point in the history
Renames Lower::Handle* to Lower::LowerFunctionInst. This allows writing
a templated handler for instructions, moving code out of the macro
expansion and removing some of the redundancy in things like
`HandleAddrOf(..., SemIR::AddrOf inst)`
  • Loading branch information
jonmeow authored Jul 11, 2024
1 parent bb27a4f commit 6d2d1cf
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 124 deletions.
58 changes: 30 additions & 28 deletions toolchain/lower/function_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,33 @@ auto FunctionContext::LowerBlock(SemIR::InstBlockId block_id) -> void {
}
}

// Handles typed instructions for LowerInst. Many instructions lower using
// HandleInst, but others are unsupported or have trivial lowering.
//
// This only calls HandleInst for versions that should have implementations. A
// different approach would be to have the logic below implemented as HandleInst
// overloads. However, forward declarations of HandleInst exist for all `InstT`
// types, which would make getting the right overload resolution complex.
template <typename InstT>
static auto FatalErrorIfEncountered(InstT inst) -> void {
CARBON_FATAL()
<< "Encountered an instruction that isn't expected to lower. It's "
"possible that logic needs to be changed in order to stop "
"showing this instruction in lowered contexts. Instruction: "
<< inst;
}

// For instructions that are always of type `type`, produce the trivial runtime
// representation of type `type`.
static auto SetTrivialType(FunctionContext& context, SemIR::InstId inst_id)
-> void {
context.SetLocal(inst_id, context.GetTypeAsValue());
static auto LowerInstHelper(FunctionContext& context, SemIR::InstId inst_id,
InstT inst) {
if constexpr (!InstT::Kind.is_lowered()) {
CARBON_FATAL()
<< "Encountered an instruction that isn't expected to lower. It's "
"possible that logic needs to be changed in order to stop "
"showing this instruction in lowered contexts. Instruction: "
<< inst;
} else if constexpr (InstT::Kind.constant_kind() ==
SemIR::InstConstantKind::Always) {
CARBON_FATAL() << "Missing constant value for constant instruction "
<< inst;
} else if constexpr (InstT::Kind.is_type() == SemIR::InstIsType::Always) {
// For instructions that are always of type `type`, produce the trivial
// runtime representation of type `type`.
context.SetLocal(inst_id, context.GetTypeAsValue());
} else {
HandleInst(context, inst_id, inst);
}
}

// TODO: Consider renaming Handle##Name, instead relying on typed_inst overload
Expand All @@ -82,21 +95,10 @@ auto FunctionContext::LowerInst(SemIR::InstId inst_id) -> void {
builder_.getInserter().SetCurrentInstId(inst_id);

CARBON_KIND_SWITCH(inst) {
#define CARBON_SEM_IR_INST_KIND(Name) \
case CARBON_KIND(SemIR::Name typed_inst): { \
if constexpr (!SemIR::Name::Kind.is_lowered()) { \
FatalErrorIfEncountered(typed_inst); \
} else if constexpr (SemIR::Name::Kind.constant_kind() == \
SemIR::InstConstantKind::Always) { \
CARBON_FATAL() << "Missing constant value for constant instruction " \
<< inst; \
} else if constexpr (SemIR::Name::Kind.is_type() == \
SemIR::InstIsType::Always) { \
SetTrivialType(*this, inst_id); \
} else { \
Handle##Name(*this, inst_id, typed_inst); \
} \
break; \
#define CARBON_SEM_IR_INST_KIND(Name) \
case CARBON_KIND(SemIR::Name typed_inst): { \
LowerInstHelper(*this, inst_id, typed_inst); \
break; \
}
#include "toolchain/sem_ir/inst_kind.def"
}
Expand Down
12 changes: 7 additions & 5 deletions toolchain/lower/function_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,13 @@ class FunctionContext {
Map<SemIR::InstId, llvm::Value*> locals_;
};

// Declare handlers for each SemIR::File instruction. Note that these aren't all
// defined.
#define CARBON_SEM_IR_INST_KIND(Name) \
auto Handle##Name(FunctionContext& context, SemIR::InstId inst_id, \
SemIR::Name inst) -> void;
// Provides handlers for instructions that occur in a FunctionContext. Although
// this is declared for all instructions, it should only be defined for
// instructions which are non-constant and not always typed. See
// `FunctionContext::LowerInst` for how this is used.
#define CARBON_SEM_IR_INST_KIND(Name) \
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id, \
SemIR::Name inst) -> void;
#include "toolchain/sem_ir/inst_kind.def"

} // namespace Carbon::Lower
Expand Down
104 changes: 52 additions & 52 deletions toolchain/lower/handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@

namespace Carbon::Lower {

auto HandleAddrOf(FunctionContext& context, SemIR::InstId inst_id,
SemIR::AddrOf inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::AddrOf inst) -> void {
context.SetLocal(inst_id, context.GetValue(inst.lvalue_id));
}

auto HandleAddrPattern(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
SemIR::AddrPattern /*inst*/) -> void {
auto HandleInst(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
SemIR::AddrPattern /*inst*/) -> void {
CARBON_FATAL() << "`addr` should be lowered by `BuildFunctionDefinition`";
}

auto HandleArrayIndex(FunctionContext& context, SemIR::InstId inst_id,
SemIR::ArrayIndex inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::ArrayIndex inst) -> void {
auto* array_value = context.GetValue(inst.array_id);
auto* llvm_type =
context.GetType(context.sem_ir().insts().Get(inst.array_id).type_id());
Expand All @@ -41,25 +41,25 @@ auto HandleArrayIndex(FunctionContext& context, SemIR::InstId inst_id,
indexes, "array.index"));
}

auto HandleArrayInit(FunctionContext& context, SemIR::InstId inst_id,
SemIR::ArrayInit inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::ArrayInit inst) -> void {
// The result of initialization is the return slot of the initializer.
context.SetLocal(inst_id, context.GetValue(inst.dest_id));
}

auto HandleAsCompatible(FunctionContext& context, SemIR::InstId inst_id,
SemIR::AsCompatible inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::AsCompatible inst) -> void {
context.SetLocal(inst_id, context.GetValue(inst.source_id));
}

auto HandleAssign(FunctionContext& context, SemIR::InstId /*inst_id*/,
SemIR::Assign inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId /*inst_id*/,
SemIR::Assign inst) -> void {
auto storage_type_id = context.sem_ir().insts().Get(inst.lhs_id).type_id();
context.FinishInit(storage_type_id, inst.lhs_id, inst.rhs_id);
}

auto HandleBindAlias(FunctionContext& context, SemIR::InstId inst_id,
SemIR::BindAlias inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::BindAlias inst) -> void {
auto type_inst_id = context.sem_ir().types().GetInstId(inst.type_id);
if (type_inst_id == SemIR::InstId::BuiltinNamespaceType) {
return;
Expand All @@ -68,8 +68,8 @@ auto HandleBindAlias(FunctionContext& context, SemIR::InstId inst_id,
context.SetLocal(inst_id, context.GetValue(inst.value_id));
}

auto HandleExportDecl(FunctionContext& context, SemIR::InstId inst_id,
SemIR::ExportDecl inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::ExportDecl inst) -> void {
auto type_inst_id = context.sem_ir().types().GetInstId(inst.type_id);
if (type_inst_id == SemIR::InstId::BuiltinNamespaceType) {
return;
Expand All @@ -78,30 +78,30 @@ auto HandleExportDecl(FunctionContext& context, SemIR::InstId inst_id,
context.SetLocal(inst_id, context.GetValue(inst.value_id));
}

auto HandleBindName(FunctionContext& context, SemIR::InstId inst_id,
SemIR::BindName inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::BindName inst) -> void {
context.SetLocal(inst_id, context.GetValue(inst.value_id));
}

auto HandleBindSymbolicName(FunctionContext& context, SemIR::InstId inst_id,
SemIR::BindSymbolicName inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::BindSymbolicName inst) -> void {
context.SetLocal(inst_id, context.GetValue(inst.value_id));
}

auto HandleBlockArg(FunctionContext& context, SemIR::InstId inst_id,
SemIR::BlockArg inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::BlockArg inst) -> void {
context.SetLocal(inst_id, context.GetBlockArg(inst.block_id, inst.type_id));
}

auto HandleBoundMethod(FunctionContext& context, SemIR::InstId inst_id,
SemIR::BoundMethod inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::BoundMethod inst) -> void {
// Propagate just the function; the object is separately provided to the
// enclosing call as an implicit argument.
context.SetLocal(inst_id, context.GetValue(inst.function_id));
}

auto HandleBranch(FunctionContext& context, SemIR::InstId /*inst_id*/,
SemIR::Branch inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId /*inst_id*/,
SemIR::Branch inst) -> void {
// Opportunistically avoid creating a BasicBlock that contains just a branch.
// TODO: Don't do this if it would remove a loop preheader block.
llvm::BasicBlock* block = context.builder().GetInsertBlock();
Expand All @@ -114,17 +114,17 @@ auto HandleBranch(FunctionContext& context, SemIR::InstId /*inst_id*/,
context.builder().ClearInsertionPoint();
}

auto HandleBranchIf(FunctionContext& context, SemIR::InstId /*inst_id*/,
SemIR::BranchIf inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId /*inst_id*/,
SemIR::BranchIf inst) -> void {
llvm::Value* cond = context.GetValue(inst.cond_id);
llvm::BasicBlock* then_block = context.GetBlock(inst.target_id);
llvm::BasicBlock* else_block = context.MakeSyntheticBlock();
context.builder().CreateCondBr(cond, then_block, else_block);
context.builder().SetInsertPoint(else_block);
}

auto HandleBranchWithArg(FunctionContext& context, SemIR::InstId /*inst_id*/,
SemIR::BranchWithArg inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId /*inst_id*/,
SemIR::BranchWithArg inst) -> void {
llvm::Value* arg = context.GetValue(inst.arg_id);
SemIR::TypeId arg_type_id =
context.sem_ir().insts().Get(inst.arg_id).type_id();
Expand All @@ -150,29 +150,29 @@ auto HandleBranchWithArg(FunctionContext& context, SemIR::InstId /*inst_id*/,
context.builder().ClearInsertionPoint();
}

auto HandleConverted(FunctionContext& context, SemIR::InstId inst_id,
SemIR::Converted inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::Converted inst) -> void {
context.SetLocal(inst_id, context.GetValue(inst.result_id));
}

auto HandleDeref(FunctionContext& context, SemIR::InstId inst_id,
SemIR::Deref inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::Deref inst) -> void {
context.SetLocal(inst_id, context.GetValue(inst.pointer_id));
}

auto HandleFacetTypeAccess(FunctionContext& context, SemIR::InstId inst_id,
SemIR::FacetTypeAccess /*inst*/) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::FacetTypeAccess /*inst*/) -> void {
context.SetLocal(inst_id, context.GetTypeAsValue());
}

auto HandleInitializeFrom(FunctionContext& context, SemIR::InstId /*inst_id*/,
SemIR::InitializeFrom inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId /*inst_id*/,
SemIR::InitializeFrom inst) -> void {
auto storage_type_id = context.sem_ir().insts().Get(inst.dest_id).type_id();
context.FinishInit(storage_type_id, inst.dest_id, inst.src_id);
}

auto HandleNameRef(FunctionContext& context, SemIR::InstId inst_id,
SemIR::NameRef inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::NameRef inst) -> void {
auto type_inst_id = context.sem_ir().types().GetInstId(inst.type_id);
if (type_inst_id == SemIR::InstId::BuiltinNamespaceType) {
return;
Expand All @@ -181,18 +181,18 @@ auto HandleNameRef(FunctionContext& context, SemIR::InstId inst_id,
context.SetLocal(inst_id, context.GetValue(inst.value_id));
}

auto HandleParam(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
SemIR::Param /*inst*/) -> void {
auto HandleInst(FunctionContext& /*context*/, SemIR::InstId /*inst_id*/,
SemIR::Param /*inst*/) -> void {
CARBON_FATAL() << "Parameters should be lowered by `BuildFunctionDefinition`";
}

auto HandleReturn(FunctionContext& context, SemIR::InstId /*inst_id*/,
SemIR::Return /*inst*/) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId /*inst_id*/,
SemIR::Return /*inst*/) -> void {
context.builder().CreateRetVoid();
}

auto HandleReturnExpr(FunctionContext& context, SemIR::InstId /*inst_id*/,
SemIR::ReturnExpr inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId /*inst_id*/,
SemIR::ReturnExpr inst) -> void {
auto result_type_id = context.sem_ir().insts().Get(inst.expr_id).type_id();
switch (SemIR::GetInitRepr(context.sem_ir(), result_type_id).kind) {
case SemIR::InitRepr::None:
Expand All @@ -210,20 +210,20 @@ auto HandleReturnExpr(FunctionContext& context, SemIR::InstId /*inst_id*/,
}
}

auto HandleSpliceBlock(FunctionContext& context, SemIR::InstId inst_id,
SemIR::SpliceBlock inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::SpliceBlock inst) -> void {
context.LowerBlock(inst.block_id);
context.SetLocal(inst_id, context.GetValue(inst.result_id));
}

auto HandleUnaryOperatorNot(FunctionContext& context, SemIR::InstId inst_id,
SemIR::UnaryOperatorNot inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::UnaryOperatorNot inst) -> void {
context.SetLocal(
inst_id, context.builder().CreateNot(context.GetValue(inst.operand_id)));
}

auto HandleVarStorage(FunctionContext& context, SemIR::InstId inst_id,
SemIR::VarStorage inst) -> void {
auto HandleInst(FunctionContext& context, SemIR::InstId inst_id,
SemIR::VarStorage inst) -> void {
context.SetLocal(inst_id,
context.builder().CreateAlloca(context.GetType(inst.type_id),
/*ArraySize=*/nullptr));
Expand Down
Loading

0 comments on commit 6d2d1cf

Please sign in to comment.